Spaces:
Sleeping
Sleeping
KarmanovaLidiia
commited on
Commit
·
c69034b
1
Parent(s):
4392241
feat: auto-download CatBoost models and on_topic from HF Space
Browse files- src/predict.py +60 -7
src/predict.py
CHANGED
|
@@ -12,6 +12,7 @@ import numpy as np
|
|
| 12 |
import pandas as pd
|
| 13 |
import joblib
|
| 14 |
from catboost import CatBoostRegressor
|
|
|
|
| 15 |
|
| 16 |
# --- импорты проекта ---
|
| 17 |
HERE = Path(__file__).parent
|
|
@@ -26,10 +27,13 @@ except ModuleNotFoundError:
|
|
| 26 |
# если файл лежит в src/
|
| 27 |
from src.feature_engineering import FeatureExtractor # type: ignore
|
| 28 |
|
| 29 |
-
# ---
|
| 30 |
MODELS_DIR = ROOT / "models" # catboost_Q1.cbm ... catboost_Q4.cbm
|
| 31 |
ON_TOPIC_PATH = MODELS_DIR / "on_topic.pkl" # опционально
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
# --- служебные колонки (не подавать в модель) ---
|
| 34 |
NON_NUMERIC_KEEP = {"question_number", "question_text", "answer_text"}
|
| 35 |
TARGET_COLS = {"score", "Оценка экзаменатора"}
|
|
@@ -75,11 +79,55 @@ def _clip_by_q(qnum: int, preds: np.ndarray) -> np.ndarray:
|
|
| 75 |
return np.clip(preds, lo, hi)
|
| 76 |
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
def _load_model(qnum: int) -> CatBoostRegressor:
|
| 79 |
-
"""Загрузка CatBoost-модели для указанного
|
| 80 |
-
model_path =
|
| 81 |
-
if not model_path.exists():
|
| 82 |
-
raise FileNotFoundError(f"Не найден файл модели: {model_path}")
|
| 83 |
model = CatBoostRegressor()
|
| 84 |
model.load_model(str(model_path))
|
| 85 |
return model
|
|
@@ -98,10 +146,15 @@ def _align_to_model_features(model: CatBoostRegressor, X: pd.DataFrame) -> pd.Da
|
|
| 98 |
|
| 99 |
def _maybe_add_on_topic(df_feats: pd.DataFrame) -> pd.DataFrame:
|
| 100 |
"""
|
| 101 |
-
Если есть on_topic.pkl (pack = {'model': clf, 'features': [...]})
|
| 102 |
— добавляем вероятность 'on_topic_prob'. Иначе 0.0.
|
| 103 |
"""
|
| 104 |
out = df_feats.copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
if not ON_TOPIC_PATH.exists():
|
| 106 |
out["on_topic_prob"] = 0.0
|
| 107 |
return out
|
|
@@ -180,7 +233,7 @@ def pipeline_infer(input_csv: Path, output_csv: Path) -> None:
|
|
| 180 |
pq = np.asarray(pq, dtype=float).reshape(-1)
|
| 181 |
preds[mask.values] = _clip_by_q(q, pq)
|
| 182 |
|
| 183 |
-
# ---
|
| 184 |
qnums = feats["question_number"].astype(int).to_numpy()
|
| 185 |
rounded = np.rint(preds).astype(np.float32)
|
| 186 |
mask13 = (qnums == 1) | (qnums == 3)
|
|
|
|
| 12 |
import pandas as pd
|
| 13 |
import joblib
|
| 14 |
from catboost import CatBoostRegressor
|
| 15 |
+
from huggingface_hub import hf_hub_download # <— автодозагрузка файлов из HF
|
| 16 |
|
| 17 |
# --- импорты проекта ---
|
| 18 |
HERE = Path(__file__).parent
|
|
|
|
| 27 |
# если файл лежит в src/
|
| 28 |
from src.feature_engineering import FeatureExtractor # type: ignore
|
| 29 |
|
| 30 |
+
# --- пути/константы ---
|
| 31 |
MODELS_DIR = ROOT / "models" # catboost_Q1.cbm ... catboost_Q4.cbm
|
| 32 |
ON_TOPIC_PATH = MODELS_DIR / "on_topic.pkl" # опционально
|
| 33 |
|
| 34 |
+
# репозиторий Space, откуда подтягиваем артефакты, если их нет локально
|
| 35 |
+
SPACE_REPO = os.environ.get("SPACE_REPO", "lidiiakarmanova/exam-evaluator")
|
| 36 |
+
|
| 37 |
# --- служебные колонки (не подавать в модель) ---
|
| 38 |
NON_NUMERIC_KEEP = {"question_number", "question_text", "answer_text"}
|
| 39 |
TARGET_COLS = {"score", "Оценка экзаменатора"}
|
|
|
|
| 79 |
return np.clip(preds, lo, hi)
|
| 80 |
|
| 81 |
|
| 82 |
+
def _ensure_model_file(qnum: int) -> Path:
|
| 83 |
+
"""
|
| 84 |
+
Гарантирует наличие файла модели Q{qnum} локально.
|
| 85 |
+
Если файла нет — скачивает из Space (путь в репо: models/catboost_Q{q}.cbm).
|
| 86 |
+
"""
|
| 87 |
+
MODELS_DIR.mkdir(parents=True, exist_ok=True)
|
| 88 |
+
local_path = MODELS_DIR / f"catboost_Q{qnum}.cbm"
|
| 89 |
+
if local_path.exists():
|
| 90 |
+
return local_path
|
| 91 |
+
|
| 92 |
+
remote_filename = f"models/catboost_Q{qnum}.cbm"
|
| 93 |
+
print(f"[i] Модель Q{qnum} не найдена локально, скачиваем из {SPACE_REPO}:{remote_filename}")
|
| 94 |
+
cache_path = hf_hub_download(
|
| 95 |
+
repo_id=SPACE_REPO,
|
| 96 |
+
repo_type="space",
|
| 97 |
+
filename=remote_filename,
|
| 98 |
+
)
|
| 99 |
+
# скопируем из кэша в models/ — Space может чистить кэш между рестартами
|
| 100 |
+
Path(local_path).write_bytes(Path(cache_path).read_bytes())
|
| 101 |
+
return local_path
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _ensure_on_topic_file() -> Path | None:
|
| 105 |
+
"""
|
| 106 |
+
Если используем on_topic.pkl — аналогично подтянем из Space (models/on_topic.pkl),
|
| 107 |
+
иначе вернём None.
|
| 108 |
+
"""
|
| 109 |
+
if ON_TOPIC_PATH.exists():
|
| 110 |
+
return ON_TOPIC_PATH
|
| 111 |
+
|
| 112 |
+
remote_filename = "models/on_topic.pkl"
|
| 113 |
+
try:
|
| 114 |
+
print(f"[i] on_topic.pkl не найден локально, пробуем скачать из {SPACE_REPO}:{remote_filename}")
|
| 115 |
+
cache_path = hf_hub_download(
|
| 116 |
+
repo_id=SPACE_REPO,
|
| 117 |
+
repo_type="space",
|
| 118 |
+
filename=remote_filename,
|
| 119 |
+
)
|
| 120 |
+
ON_TOPIC_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 121 |
+
Path(ON_TOPIC_PATH).write_bytes(Path(cache_path).read_bytes())
|
| 122 |
+
return ON_TOPIC_PATH
|
| 123 |
+
except Exception as e:
|
| 124 |
+
print(f"[!] Не удалось скачать on_topic.pkl: {e}")
|
| 125 |
+
return None
|
| 126 |
+
|
| 127 |
+
|
| 128 |
def _load_model(qnum: int) -> CatBoostRegressor:
|
| 129 |
+
"""Загрузка CatBoost-модели для указанного вопроса (с автодозагрузкой из HF)."""
|
| 130 |
+
model_path = _ensure_model_file(qnum)
|
|
|
|
|
|
|
| 131 |
model = CatBoostRegressor()
|
| 132 |
model.load_model(str(model_path))
|
| 133 |
return model
|
|
|
|
| 146 |
|
| 147 |
def _maybe_add_on_topic(df_feats: pd.DataFrame) -> pd.DataFrame:
|
| 148 |
"""
|
| 149 |
+
Если есть on_topic.pkl (pack = {'model': clf, 'features': [...]}),
|
| 150 |
— добавляем вероятность 'on_topic_prob'. Иначе 0.0.
|
| 151 |
"""
|
| 152 |
out = df_feats.copy()
|
| 153 |
+
|
| 154 |
+
# попытаемся подтянуть on_topic.pkl из Space при необходимости
|
| 155 |
+
if not ON_TOPIC_PATH.exists():
|
| 156 |
+
_ensure_on_topic_file()
|
| 157 |
+
|
| 158 |
if not ON_TOPIC_PATH.exists():
|
| 159 |
out["on_topic_prob"] = 0.0
|
| 160 |
return out
|
|
|
|
| 233 |
pq = np.asarray(pq, dtype=float).reshape(-1)
|
| 234 |
preds[mask.values] = _clip_by_q(q, pq)
|
| 235 |
|
| 236 |
+
# --- надёжное округление ---
|
| 237 |
qnums = feats["question_number"].astype(int).to_numpy()
|
| 238 |
rounded = np.rint(preds).astype(np.float32)
|
| 239 |
mask13 = (qnums == 1) | (qnums == 3)
|