|
|
|
|
|
import base64
import io
import json
import logging
import os
from typing import Dict, Optional

import joblib
import matplotlib.pyplot as plt
import pandas as pd
import shap
from huggingface_hub import hf_hub_download

from utils.config import AppConfig
from utils.tracing import Tracer
|
|
|
|
|
|
|
|
class ExplainTool:
    """Generate lightweight global SHAP visualizations for a DataFrame sample.

    Two plots are produced (a bar plot and a beeswarm plot) from SHAP values
    computed on a sample of the supplied DataFrame. Designed to run on CPU in
    HF Spaces.

    The model is fetched lazily from the Hugging Face Hub repo named by
    ``cfg.hf_model_repo`` the first time :meth:`run` receives non-empty data.
    """

    # Cap on the number of rows handed to SHAP, and the seed that keeps the
    # sample reproducible between calls.
    _MAX_SAMPLE_ROWS = 500
    _SAMPLE_SEED = 42

    _log = logging.getLogger(__name__)

    def __init__(self, cfg: "AppConfig", tracer: "Tracer") -> None:
        """Store collaborators; no I/O happens until the model is needed.

        Args:
            cfg: Application config; ``hf_model_repo`` is read when loading.
            tracer: Tracer whose ``trace_event`` is called after each run.
        """
        self.cfg = cfg
        self.tracer = tracer
        self._model = None  # populated by _ensure_model()
        self._feature_order = None  # optional column order from metadata

    def _ensure_model(self) -> None:
        """Download and load the model (and optional feature metadata) once.

        Subsequent calls are no-ops. A failure to download or load
        ``model.pkl`` propagates to the caller; a failure to obtain
        ``feature_metadata.json`` is logged and leaves ``_feature_order`` as
        ``None`` so that all DataFrame columns are used.
        """
        if self._model is not None:
            return

        token = os.getenv("HF_TOKEN")
        repo = self.cfg.hf_model_repo

        model_path = hf_hub_download(
            repo_id=repo,
            filename="model.pkl",
            token=token,
        )
        # NOTE(security): joblib.load unpickles the file and can therefore
        # execute arbitrary code. Only point hf_model_repo at a trusted repo.
        self._model = joblib.load(model_path)

        # Feature metadata is optional: any failure here (file absent,
        # network error, malformed JSON, non-dict payload) is best-effort.
        try:
            meta_path = hf_hub_download(
                repo_id=repo,
                filename="feature_metadata.json",
                token=token,
            )
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f) or {}
            self._feature_order = meta.get("feature_order")
        except Exception:
            self._log.warning(
                "Could not load feature_metadata.json from %s; "
                "using DataFrame columns as-is",
                repo,
                exc_info=True,
            )
            self._feature_order = None

    @staticmethod
    def _to_data_uri(fig) -> str:
        """Render *fig* to PNG and return it as a base64 ``data:`` URI.

        The figure is always closed, even if saving fails, so repeated calls
        do not accumulate open matplotlib figures.
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
        finally:
            plt.close(fig)
        encoded = base64.b64encode(buf.getvalue()).decode()
        return "data:image/png;base64," + encoded

    def _render_plot(self, plot_fn, shap_values) -> str:
        """Draw one SHAP plot on a fresh figure and return its data URI."""
        # The SHAP plot functions draw on the current figure, so a new one is
        # made current immediately before each call.
        fig = plt.figure()
        try:
            plot_fn(shap_values, show=False)
        except Exception:
            # _to_data_uri will not run, so close the figure here.
            plt.close(fig)
            raise
        return self._to_data_uri(fig)

    def run(self, df: Optional[pd.DataFrame]) -> Dict[str, str]:
        """Compute SHAP values on a sample of *df* and render global plots.

        Args:
            df: Data to explain. When feature metadata is available, only the
                listed columns that exist in *df* are used, in metadata order.

        Returns:
            ``{"global_bar": <data URI>, "beeswarm": <data URI>}``, or ``{}``
            when *df* is ``None``, has no rows, or has no usable columns.

        Raises:
            Exception: Whatever model download/loading or SHAP raises.
        """
        # Check for empty input first so that nothing is downloaded or
        # unpickled when there is nothing to explain.
        if df is None or len(df) == 0:
            return {}

        self._ensure_model()

        if self._feature_order:
            present = [c for c in self._feature_order if c in df.columns]
            if len(present) != len(self._feature_order):
                missing = [c for c in self._feature_order if c not in df.columns]
                self._log.warning(
                    "DataFrame is missing expected feature columns: %s",
                    missing,
                )
            X = df[present].copy()
        else:
            X = df.copy()

        # No columns left means there is nothing for SHAP to explain.
        if X.shape[1] == 0:
            return {}

        if len(X) > self._MAX_SAMPLE_ROWS:
            sample = X.sample(
                self._MAX_SAMPLE_ROWS, random_state=self._SAMPLE_SEED
            )
        else:
            sample = X
        n = len(sample)

        # The sample serves as both the background data and the rows explained.
        explainer = shap.Explainer(self._model, sample)
        sv = explainer(sample)

        bar_uri = self._render_plot(shap.plots.bar, sv)
        bee_uri = self._render_plot(shap.plots.beeswarm, sv)

        # Tracing is best-effort: a tracer failure must not discard the plots.
        try:
            self.tracer.trace_event("explain", {"rows": int(n)})
        except Exception:
            self._log.debug("trace_event('explain') failed", exc_info=True)

        return {"global_bar": bar_uri, "beeswarm": bee_uri}
|
|
|