KarmanovaLidiia commited on
Commit
c69034b
·
1 Parent(s): 4392241

feat: auto-download CatBoost models and on_topic from HF Space

Browse files
Files changed (1) hide show
  1. src/predict.py +60 -7
src/predict.py CHANGED
@@ -12,6 +12,7 @@ import numpy as np
12
  import pandas as pd
13
  import joblib
14
  from catboost import CatBoostRegressor
 
15
 
16
  # --- импорты проекта ---
17
  HERE = Path(__file__).parent
@@ -26,10 +27,13 @@ except ModuleNotFoundError:
26
  # если файл лежит в src/
27
  from src.feature_engineering import FeatureExtractor # type: ignore
28
 
29
- # --- пути ---
30
  MODELS_DIR = ROOT / "models" # catboost_Q1.cbm ... catboost_Q4.cbm
31
  ON_TOPIC_PATH = MODELS_DIR / "on_topic.pkl" # опционально
32
 
 
 
 
33
  # --- служебные колонки (не подавать в модель) ---
34
  NON_NUMERIC_KEEP = {"question_number", "question_text", "answer_text"}
35
  TARGET_COLS = {"score", "Оценка экзаменатора"}
@@ -75,11 +79,55 @@ def _clip_by_q(qnum: int, preds: np.ndarray) -> np.ndarray:
75
  return np.clip(preds, lo, hi)
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def _load_model(qnum: int) -> CatBoostRegressor:
79
- """Загрузка CatBoost-модели для указанного вопроса."""
80
- model_path = MODELS_DIR / f"catboost_Q{qnum}.cbm"
81
- if not model_path.exists():
82
- raise FileNotFoundError(f"Не найден файл модели: {model_path}")
83
  model = CatBoostRegressor()
84
  model.load_model(str(model_path))
85
  return model
@@ -98,10 +146,15 @@ def _align_to_model_features(model: CatBoostRegressor, X: pd.DataFrame) -> pd.Da
98
 
99
  def _maybe_add_on_topic(df_feats: pd.DataFrame) -> pd.DataFrame:
100
  """
101
- Если есть on_topic.pkl (pack = {'model': clf, 'features': [...]})
102
  — добавляем вероятность 'on_topic_prob'. Иначе 0.0.
103
  """
104
  out = df_feats.copy()
 
 
 
 
 
105
  if not ON_TOPIC_PATH.exists():
106
  out["on_topic_prob"] = 0.0
107
  return out
@@ -180,7 +233,7 @@ def pipeline_infer(input_csv: Path, output_csv: Path) -> None:
180
  pq = np.asarray(pq, dtype=float).reshape(-1)
181
  preds[mask.values] = _clip_by_q(q, pq)
182
 
183
- # --- новое надёжное округление (без .loc по индексам) ---
184
  qnums = feats["question_number"].astype(int).to_numpy()
185
  rounded = np.rint(preds).astype(np.float32)
186
  mask13 = (qnums == 1) | (qnums == 3)
 
12
  import pandas as pd
13
  import joblib
14
  from catboost import CatBoostRegressor
15
+ from huggingface_hub import hf_hub_download # <— автодозагрузка файлов из HF
16
 
17
  # --- импорты проекта ---
18
  HERE = Path(__file__).parent
 
27
  # если файл лежит в src/
28
  from src.feature_engineering import FeatureExtractor # type: ignore
29
 
30
+ # --- пути/константы ---
31
  MODELS_DIR = ROOT / "models" # catboost_Q1.cbm ... catboost_Q4.cbm
32
  ON_TOPIC_PATH = MODELS_DIR / "on_topic.pkl" # опционально
33
 
34
+ # репозиторий Space, откуда подтягиваем артефакты, если их нет локально
35
+ SPACE_REPO = os.environ.get("SPACE_REPO", "lidiiakarmanova/exam-evaluator")
36
+
37
  # --- служебные колонки (не подавать в модель) ---
38
  NON_NUMERIC_KEEP = {"question_number", "question_text", "answer_text"}
39
  TARGET_COLS = {"score", "Оценка экзаменатора"}
 
79
  return np.clip(preds, lo, hi)
80
 
81
 
82
+ def _ensure_model_file(qnum: int) -> Path:
83
+ """
84
+ Гарантирует наличие файла модели Q{qnum} локально.
85
+ Если файла нет — скачивает из Space (путь в репо: models/catboost_Q{q}.cbm).
86
+ """
87
+ MODELS_DIR.mkdir(parents=True, exist_ok=True)
88
+ local_path = MODELS_DIR / f"catboost_Q{qnum}.cbm"
89
+ if local_path.exists():
90
+ return local_path
91
+
92
+ remote_filename = f"models/catboost_Q{qnum}.cbm"
93
+ print(f"[i] Модель Q{qnum} не найдена локально, скачиваем из {SPACE_REPO}:{remote_filename}")
94
+ cache_path = hf_hub_download(
95
+ repo_id=SPACE_REPO,
96
+ repo_type="space",
97
+ filename=remote_filename,
98
+ )
99
+ # скопируем из кэша в models/ — Space может чистить кэш между рестартами
100
+ Path(local_path).write_bytes(Path(cache_path).read_bytes())
101
+ return local_path
102
+
103
+
104
+ def _ensure_on_topic_file() -> Path | None:
105
+ """
106
+ Если используем on_topic.pkl — аналогично подтянем из Space (models/on_topic.pkl),
107
+ иначе вернём None.
108
+ """
109
+ if ON_TOPIC_PATH.exists():
110
+ return ON_TOPIC_PATH
111
+
112
+ remote_filename = "models/on_topic.pkl"
113
+ try:
114
+ print(f"[i] on_topic.pkl не найден локально, пробуем скачать из {SPACE_REPO}:{remote_filename}")
115
+ cache_path = hf_hub_download(
116
+ repo_id=SPACE_REPO,
117
+ repo_type="space",
118
+ filename=remote_filename,
119
+ )
120
+ ON_TOPIC_PATH.parent.mkdir(parents=True, exist_ok=True)
121
+ Path(ON_TOPIC_PATH).write_bytes(Path(cache_path).read_bytes())
122
+ return ON_TOPIC_PATH
123
+ except Exception as e:
124
+ print(f"[!] Не удалось скачать on_topic.pkl: {e}")
125
+ return None
126
+
127
+
128
  def _load_model(qnum: int) -> CatBoostRegressor:
129
+ """Загрузка CatBoost-модели для указанного вопроса (с автодозагрузкой из HF)."""
130
+ model_path = _ensure_model_file(qnum)
 
 
131
  model = CatBoostRegressor()
132
  model.load_model(str(model_path))
133
  return model
 
146
 
147
  def _maybe_add_on_topic(df_feats: pd.DataFrame) -> pd.DataFrame:
148
  """
149
+ Если есть on_topic.pkl (pack = {'model': clf, 'features': [...]}),
150
  — добавляем вероятность 'on_topic_prob'. Иначе 0.0.
151
  """
152
  out = df_feats.copy()
153
+
154
+ # попытаемся подтянуть on_topic.pkl из Space при необходимости
155
+ if not ON_TOPIC_PATH.exists():
156
+ _ensure_on_topic_file()
157
+
158
  if not ON_TOPIC_PATH.exists():
159
  out["on_topic_prob"] = 0.0
160
  return out
 
233
  pq = np.asarray(pq, dtype=float).reshape(-1)
234
  preds[mask.values] = _clip_by_q(q, pq)
235
 
236
+ # --- надёжное округление ---
237
  qnums = feats["question_number"].astype(int).to_numpy()
238
  rounded = np.rint(preds).astype(np.float32)
239
  mask13 = (qnums == 1) | (qnums == 3)