# space/tools/explain_tool.py
import os
import io
import json
import base64
from typing import Dict, Optional

import shap
import pandas as pd
import matplotlib.pyplot as plt
import joblib
from huggingface_hub import hf_hub_download

from utils.config import AppConfig
from utils.tracing import Tracer


class ExplainTool:
    """
    Generates lightweight global SHAP visualizations (bar + beeswarm) for a
    sample of the current DataFrame. Designed to run on CPU in HF Spaces.

    The model (``model.pkl``) and the optional ``feature_metadata.json`` are
    downloaded from ``cfg.hf_model_repo`` on the first ``run`` call that has
    data to explain, then cached on the instance.
    """

    def __init__(self, cfg: AppConfig, tracer: Tracer):
        self.cfg = cfg
        self.tracer = tracer
        # Loaded lazily by _ensure_model() so construction does no network I/O.
        self._model = None
        # Optional column order read from feature_metadata.json; None means
        # "use the DataFrame's columns as-is".
        self._feature_order = None

    def _ensure_model(self) -> None:
        """Download and load the model (and optional feature order) once per instance."""
        if self._model is not None:
            return

        token = os.getenv("HF_TOKEN")
        repo = self.cfg.hf_model_repo
        model_path = hf_hub_download(repo_id=repo, filename="model.pkl", token=token)
        # SECURITY NOTE: joblib.load unpickles arbitrary Python objects. Only
        # point hf_model_repo at a repository you trust.
        self._model = joblib.load(model_path)

        # Read optional feature metadata to keep column order consistent.
        # This is best effort: a missing file, invalid JSON or an unexpected
        # structure all fall back to "no explicit order".
        try:
            meta_path = hf_hub_download(
                repo_id=repo, filename="feature_metadata.json", token=token
            )
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f) or {}
            order = meta.get("feature_order")
            # Accept only a real sequence of column names. A bare string would
            # otherwise be iterated character by character.
            self._feature_order = list(order) if isinstance(order, (list, tuple)) else None
        except Exception:
            self._feature_order = None

    @staticmethod
    def _to_data_uri(fig) -> str:
        """Render *fig* to a base64 PNG ``data:`` URI and close the figure."""
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
        finally:
            # Always release the figure, even if rendering fails, so repeated
            # calls don't accumulate open pyplot figures.
            plt.close(fig)
        return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    def _render_plot(self, plot_fn, sv) -> str:
        """Draw ``plot_fn(sv)`` on a fresh pyplot figure and return it as a data URI."""
        fig = plt.figure()
        try:
            # The shap.plots.* functions draw on pyplot's current figure,
            # which is the one just created.
            plot_fn(sv, show=False)
        except Exception:
            plt.close(fig)
            raise
        return self._to_data_uri(fig)

    def _select_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a copy of *df* restricted to, and ordered by, the known feature order.

        Columns named in the metadata but absent from *df* are skipped (best
        effort). Without metadata, all columns of *df* are used.
        """
        if not self._feature_order:
            return df.copy()
        present = [c for c in self._feature_order if c in df.columns]
        return df[present].copy()

    def run(self, df: Optional[pd.DataFrame], *, max_rows: int = 500) -> Dict[str, str]:
        """
        Returns dict of {plot_name: data_uri_png}.
        If df is None/empty, returns {} without downloading the model.

        Args:
            df: Rows to explain. Columns are selected and ordered using the
                model's feature metadata when it is available.
            max_rows: Upper bound on rows passed to SHAP. Larger inputs are
                down-sampled with a fixed seed for speed and reproducibility.

        Returns:
            ``{"global_bar": <data uri>, "beeswarm": <data uri>}``, or ``{}``
            when there is nothing to explain.

        Raises:
            ValueError: if ``max_rows < 1`` or none of the expected feature
                columns are present in *df*.
        """
        # Check for empty input first so that returning {} never triggers the
        # model download.
        if df is None or len(df) == 0:
            return {}
        if max_rows < 1:
            raise ValueError(f"max_rows must be >= 1, got {max_rows}")

        self._ensure_model()

        X = self._select_features(df)
        if X.shape[1] == 0:
            raise ValueError(
                "None of the model's feature columns are present in the DataFrame."
            )

        # Small sample for speed
        n = min(len(X), max_rows)
        sample = X.sample(n, random_state=42) if len(X) > n else X

        # Build the explainer with the sample as background data, then compute
        # SHAP values for that same sample.
        explainer = shap.Explainer(self._model, sample)
        sv = explainer(sample)

        # NOTE(review): bar/beeswarm expect a single-output Explanation. A
        # multi-class model (3-D sv.values) may need slicing to one output
        # first. Confirm against the deployed model.
        bar_uri = self._render_plot(shap.plots.bar, sv)
        bee_uri = self._render_plot(shap.plots.beeswarm, sv)

        # Tracing is best effort and must never break the explanation itself.
        try:
            self.tracer.trace_event("explain", {"rows": int(n)})
        except Exception:
            pass

        return {"global_bar": bar_uri, "beeswarm": bee_uri}