# fido_changes / src / classifier.py
# szymskul's picture
# update files
# 00cccb0
# newQlasifier.py
from __future__ import annotations
from pathlib import Path
from functools import lru_cache
import numpy as np
import pickle
import os
# Default cache locations and framework switches, applied before the
# sentence_transformers import further down. setdefault() only fills in
# variables that are not already present, so values supplied by the
# deployment environment are left untouched.
# NOTE(review): the two *_NO_* flags presumably make transformers skip its
# TensorFlow / Flax integrations -- confirm against the installed version.
for _env_name, _env_default in (
    ("SENTENCE_TRANSFORMERS_HOME", "/app/.cache/sentence_transformers"),
    ("HF_HOME", "/app/.cache/huggingface"),
    ("HUGGINGFACE_HUB_CACHE", "/app/.cache/huggingface"),
    ("TRANSFORMERS_CACHE", "/app/.cache/huggingface/transformers"),
    ("TRANSFORMERS_NO_TF", "1"),
    ("TRANSFORMERS_NO_FLAX", "1"),
):
    os.environ.setdefault(_env_name, _env_default)
# Keep the module namespace identical to the hand-written version.
del _env_name, _env_default
# ---- Constants and paths (ABSOLUTE, relative to this file) ----
# Directory containing this module. Every artifact path below is anchored to
# it, so the paths do not depend on the current working directory.
THIS_DIR = Path(__file__).resolve().parent

# Saved Keras models A and B.
MODEL_A_PATH = THIS_DIR / "best_model_70%.keras"
MODEL_B_PATH = THIS_DIR / "best_model_70%1.keras"

# Pickled label objects for models A and B; the predictors read `.classes_`
# from them to turn an output index into a label.
MLB_A_PATH = THIS_DIR / "mlb.pkl"
MLB_B_PATH = THIS_DIR / "mlb1.pkl"

# NOTE: EMBED_NAME is intentionally not assigned here. The value that used to
# sit in this block was overwritten, before anything could read it, by the
# assignment placed right after the SentenceTransformer import, so it was
# dead and showed a different model id than the one actually used.
# ---- Loading heavy dependencies (lazy + cache) ----
from sentence_transformers import SentenceTransformer
EMBED_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"  # your model: the name handed to SentenceTransformer
@lru_cache(maxsize=1)
def _embedder():
    """Create the sentence-embedding model on first call and reuse it afterwards.

    The model named by ``EMBED_NAME`` is instantiated with its cache folder
    taken from the ``SENTENCE_TRANSFORMERS_HOME`` environment variable
    (default ``/app/.cache/sentence_transformers``). ``lru_cache`` keeps the
    single instance, so later calls return the same object.

    Returns:
        The ``SentenceTransformer`` instance.
    """
    return SentenceTransformer(
        EMBED_NAME,
        cache_folder=os.environ.get(
            "SENTENCE_TRANSFORMERS_HOME",
            "/app/.cache/sentence_transformers",
        ),
    )
def _load_with_fallback(model_path: Path):
"""
Najpierw spr贸buj tf.keras, a je艣li trafi si臋 konflikt deserializacji (np. 'batch_shape'),
spr贸buj standalone 'keras'. Dzi臋ki temu dzia艂a w r贸偶nych 艣rodowiskach.
"""
# 1) tf.keras
try:
import tensorflow as tf
return tf.keras.models.load_model(str(model_path), compile=False)
except TypeError as e:
# typowy b艂膮d z 'batch_shape' przy niezgodnych wersjach
err = str(e).lower()
if "unrecognized keyword arguments" in err or "batch_shape" in err:
pass # spr贸bujemy standalone keras
else:
raise
except Exception:
# inne problemy te偶 spr贸bujmy obej艣膰 via keras
pass
# 2) standalone keras
import keras
return keras.models.load_model(str(model_path), compile=False)
@lru_cache(maxsize=1)
def _model_a():
    """Return model A, loading it from ``MODEL_A_PATH`` on the first call only."""
    loaded_model = _load_with_fallback(MODEL_A_PATH)
    return loaded_model
@lru_cache(maxsize=1)
def _model_b():
    """Return model B, loading it from ``MODEL_B_PATH`` on the first call only."""
    loaded_model = _load_with_fallback(MODEL_B_PATH)
    return loaded_model
@lru_cache(maxsize=1)
def _mlb_a():
    """Return the object unpickled from ``MLB_A_PATH``, reading the file once.

    The result is cached by ``lru_cache``, so the pickle is only read on the
    first call.
    """
    # The pickle is a local artifact addressed by a module constant; never
    # point this path at untrusted data.
    with open(MLB_A_PATH, mode="rb") as stream:
        label_object = pickle.load(stream)
    return label_object
@lru_cache(maxsize=1)
def _mlb_b():
    """Return the object unpickled from ``MLB_B_PATH``, reading the file once.

    The result is cached by ``lru_cache``, so the pickle is only read on the
    first call.
    """
    # The pickle is a local artifact addressed by a module constant; never
    # point this path at untrusted data.
    with open(MLB_B_PATH, mode="rb") as stream:
        label_object = pickle.load(stream)
    return label_object
# ---- API: functions to be called from other files ----
def encode_text(text: str) -> np.ndarray:
    """Embed a single text with the shared sentence encoder.

    Args:
        text: The text to embed.

    Returns:
        The embedding as a float32 array of shape (1, d).
    """
    vectors = _embedder().encode(
        [text],  # one-element batch
        convert_to_numpy=True,
        show_progress_bar=False,
    )
    return np.asarray(vectors, dtype=np.float32)
def predict_raw(text: str) -> str:
    """Predict with model A (best_model_70%.keras) and return the label (string).

    The text is embedded, scored by model A, and the class with the highest
    score is looked up in ``classes_`` of the matching label object.

    Args:
        text: The text to classify.

    Returns:
        The label stored at the winning class index.
    """
    # Embed first, then fetch the model: keeps the original lazy-loading order.
    features = encode_text(text)  # (1, d)
    scores = _model_a().predict(features, verbose=0)
    top_index = int(np.argmax(scores[0]))  # scores[0]: (n_classes,)
    labels = _mlb_a().classes_
    return labels[top_index]
def predict_raw1(text: str) -> str:
    """Predict with model B (best_model_70%1.keras) and return the label (string).

    The text is embedded, scored by model B, and the class with the highest
    score is looked up in ``classes_`` of the matching label object.

    Args:
        text: The text to classify.

    Returns:
        The label stored at the winning class index.
    """
    # Embed first, then fetch the model: keeps the original lazy-loading order.
    features = encode_text(text)
    scores = _model_b().predict(features, verbose=0)
    top_index = int(np.argmax(scores[0]))
    labels = _mlb_b().classes_
    return labels[top_index]