# tools/sql_tool.py
import os
import re
from typing import List, Optional, Tuple

import duckdb
import pandas as pd

# ------------------------------------------------------------
# Connection config
# ------------------------------------------------------------
DUCKDB_PATH = os.getenv("DUCKDB_PATH", "alm.duckdb")

# If you need to attach a catalog (e.g., MotherDuck), put the full ATTACH here.
# Example: DUCKDB_ATTACH_SQL=ATTACH 'md:my_db' AS my_db;
# Empty (the default) means nothing is attached.
DUCKDB_ATTACH_SQL = os.getenv("DUCKDB_ATTACH_SQL", "").strip()

# Preferred identifiers (we will fall back automatically if they don't exist)
PREF_CATALOG = os.getenv("SQL_DEFAULT_DB", "my_db")  # catalog (optional)
PREF_SCHEMA = os.getenv("SQL_DEFAULT_SCHEMA", "main")  # schema
PREF_TABLE = os.getenv("SQL_DEFAULT_TABLE", "masterdataset_v")  # table

# NL parsing patterns, compiled once and reused for every message.
_TOP_N_RE = re.compile(r"\btop\s+(\d+)")
_CURRENCY_RE = re.compile(r"\b(currency|in)\s+([a-z]{3})\b")
_SEGMENT_RE = re.compile(r"(segment|for)\s+([a-z0-9_\- ]+)")

# Three-letter words that commonly follow "in" in ordinary phrasing
# ("deposits in the retail segment") and must not be read as currency codes.
_NON_CURRENCY_WORDS = frozenset({"the", "all", "any", "our", "top", "and", "for"})


def _qualify(*parts: Optional[str]) -> str:
    """Join the non-empty identifier parts with dots, e.g. ("c", None, "t") -> "c.t"."""
    return ".".join(p for p in parts if p)


class SQLTool:
    """
    NL→SQL helper for DuckDB with:
    - optional pre-attach SQL (DUCKDB_ATTACH_SQL)
    - robust table path resolution (tries 3-part → 2-part → 1-part → information_schema scan)
    """

    def __init__(self, db_path: Optional[str] = None):
        """
        Open a DuckDB connection and resolve the working table path.

        Args:
            db_path: database file to open; defaults to DUCKDB_PATH.
        """
        self.db_path = db_path or DUCKDB_PATH
        self.con = duckdb.connect(self.db_path)

        # Optional: run user-supplied ATTACH (skipped when the env var is empty).
        if DUCKDB_ATTACH_SQL:
            try:
                self.con.execute(DUCKDB_ATTACH_SQL)
            except Exception as e:
                # Don't crash the app on attach issues; we still try local tables.
                print(f"[WARN] DUCKDB_ATTACH_SQL failed: {e}")

        self.full_table = self._resolve_full_table(PREF_CATALOG, PREF_SCHEMA, PREF_TABLE)

    def close(self) -> None:
        """Close the underlying DuckDB connection."""
        self.con.close()

    # ------------------------------------------------------------
    # Resolution helpers
    # ------------------------------------------------------------
    def _try_probe(self, path: str) -> bool:
        """Return True if ``SELECT * FROM <path> LIMIT 1`` succeeds."""
        try:
            self.con.execute(f"SELECT * FROM {path} LIMIT 1")
            return True
        except Exception:
            # Any failure (unknown catalog/schema/table, ...) just means this
            # path is not usable; the caller moves on to the next candidate.
            return False

    def _scan_information_schema(
        self,
        table_name: str,
        catalog: Optional[str] = PREF_CATALOG,
        schema: Optional[str] = PREF_SCHEMA,
    ) -> Optional[str]:
        """
        Find a queryable path for ``table_name`` via information_schema.tables.

        The name is matched case-insensitively. Candidates are probed in this
        order, and the first one that ``_try_probe`` accepts is returned:
          1. rows whose catalog AND schema equal the preferred ones,
          2. rows whose schema equals the preferred one (as a 2-part path),
          3. any row (3-part when the row has a catalog, else 2-part).

        Args:
            table_name: bare table name to look for.
            catalog: preferred catalog (None = no preference).
            schema: preferred schema (None = no preference).

        Returns:
            The first working path, or None if nothing matched or none could be queried.
        """
        rows = self.con.execute(
            """
            SELECT table_catalog, table_schema, table_name
            FROM information_schema.tables
            WHERE lower(table_name) = ?
            ORDER BY table_catalog, table_schema
            """,
            [table_name.lower()],
        ).fetchall()
        if not rows:
            return None

        pref_cat = (catalog or "").lower()
        pref_sch = (schema or "").lower()

        # 1) exact catalog + schema
        for cat, sch, t in rows:
            if (cat or "").lower() == pref_cat and (sch or "").lower() == pref_sch:
                candidate = _qualify(cat, sch, t)
                if self._try_probe(candidate):
                    return candidate

        # 2) exact schema (2-part)
        for _cat, sch, t in rows:
            if (sch or "").lower() == pref_sch:
                candidate = _qualify(sch, t)
                if self._try_probe(candidate):
                    return candidate

        # 3) first working row (prefer 3-part if catalog present)
        for cat, sch, t in rows:
            candidate = _qualify(cat, sch, t)
            if self._try_probe(candidate):
                return candidate

        return None

    def _resolve_full_table(self, catalog: Optional[str], schema: Optional[str], table: str) -> str:
        """
        Return a working qualified path for ``table``.

        Paths are tried in this order:
          - catalog.schema.table (3-part, only when catalog and schema are given)
          - schema.table (2-part, only when schema is given)
          - table (1-part)
          - information_schema scan (best effort)

        If nothing works, the most-qualified path built from the non-empty parts
        is returned. The first real query against it will then raise.
        """
        candidates: List[str] = []
        if catalog and schema:
            candidates.append(_qualify(catalog, schema, table))
        if schema:
            candidates.append(_qualify(schema, table))
        candidates.append(table)

        for path in candidates:
            if self._try_probe(path):
                print(f"[INFO] Using table path: {path}")
                return path

        # Fallback: scan information_schema
        scanned = self._scan_information_schema(table, catalog, schema)
        if scanned:
            print(f"[INFO] Using table path (scanned): {scanned}")
            return scanned

        # Last resort: keep the preferred path (will raise on first query).
        # Empty parts are skipped so we never emit "None.table".
        fallback = _qualify(catalog, schema, table)
        print(f"[WARN] Could not resolve table path; falling back to: {fallback}")
        return fallback

    # ------------------------------------------------------------
    # Run SQL directly
    # ------------------------------------------------------------
    def run_sql(self, sql: str) -> pd.DataFrame:
        """Execute ``sql`` as-is (trusted input only) and return the result as a DataFrame."""
        return self.con.execute(sql).df()

    # ------------------------------------------------------------
    # NL → SQL
    # ------------------------------------------------------------
    def _nl_to_sql(self, message: str) -> Tuple[str, str]:
        """
        Translate a natural-language request into ``(sql, explanation)`` via keyword rules.

        The first matching rule wins:
          1. top-N fixed deposits by Portfolio_value,
          2. top-N assets by Portfolio_value,
          3. SUM/AVG of Portfolio_value grouped by segments or currency,
          4. generic row filter on product / currency / segment, with an
             optional LIMIT taken from "top N".

        Text taken from the message is limited to digits and to tokens allowed
        by the module-level regexes (letters, digits, ``_``, ``-``, space).
        No quote character can therefore reach the SQL literals.
        """
        full_table = self.full_table
        m = (message or "").strip().lower()

        def has_any(txt: str, words) -> bool:
            return any(w in txt for w in words)

        # Extract "top N"
        m_top = _TOP_N_RE.search(m)
        limit = int(m_top.group(1)) if m_top else None

        wants_top = has_any(m, ["top", "largest", "biggest"])
        by_value = has_any(m, ["portfolio value", "portfolio_value"])

        # 1. Top N FDs
        if has_any(m, ["fd", "fixed deposit", "deposits"]) and wants_top and by_value:
            n = limit or 10
            sql = f"""
            SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
            FROM {full_table}
            WHERE lower(product) = 'fd'
            ORDER BY Portfolio_value DESC
            LIMIT {n};
            """
            why = f"Top {n} fixed deposits by Portfolio_value from {full_table}"
            return sql, why

        # 2. Top N Assets
        if has_any(m, ["asset", "loan", "advances"]) and wants_top and by_value:
            n = limit or 10
            sql = f"""
            SELECT contract_number, Portfolio_value, Interest_rate, currency, segments
            FROM {full_table}
            WHERE lower(product) = 'assets'
            ORDER BY Portfolio_value DESC
            LIMIT {n};
            """
            why = f"Top {n} assets by Portfolio_value from {full_table}"
            return sql, why

        # 3. Aggregate by segment/currency
        if has_any(m, ["sum", "total", "avg", "average"]) and has_any(m, ["segment", "currency"]):
            agg = "SUM" if has_any(m, ["sum", "total"]) else "AVG"
            dim = "segments" if "segment" in m else "currency"
            sql = f"""
            SELECT {dim}, {agg}(Portfolio_value) AS {agg.lower()}_Portfolio_value
            FROM {full_table}
            GROUP BY 1
            ORDER BY 2 DESC;
            """
            why = f"{agg} Portfolio_value grouped by {dim} from {full_table}"
            return sql, why

        # 4. Generic filters
        product = None
        if "fd" in m or "deposit" in m:
            product = "fd"
        elif "asset" in m or "loan" in m or "advance" in m:
            product = "assets"

        parts = [f"SELECT * FROM {full_table} WHERE 1=1"]
        why_parts = [f"Filtered rows from {full_table}"]

        if product:
            parts.append(f"AND lower(product) = '{product}'")
            why_parts.append(f"product = {product}")

        # Use the first "currency xxx" / "in xxx" whose token is not an ordinary
        # word; otherwise "in the ..." would produce currency = 'THE'.
        cur = next(
            (
                mt.group(2).upper()
                for mt in _CURRENCY_RE.finditer(m)
                if mt.group(2) not in _NON_CURRENCY_WORDS
            ),
            None,
        )
        if cur:
            parts.append(f"AND upper(currency) = '{cur}'")
            why_parts.append(f"currency = {cur}")

        seg_match = _SEGMENT_RE.search(m)
        if seg_match:
            seg = seg_match.group(2).strip()
            if seg:
                parts.append(f"AND lower(segments) LIKE '%{seg.lower()}%'")
                why_parts.append(f"segments like '{seg}'")

        if limit:
            parts.append(f"LIMIT {limit}")

        fallback_sql = " ".join(parts) + ";"
        fallback_why = "; ".join(why_parts)
        return fallback_sql, fallback_why

    # ------------------------------------------------------------
    # Public wrappers
    # ------------------------------------------------------------
    def query_from_nl(self, message: str) -> Tuple[pd.DataFrame, str, str]:
        """Translate ``message`` to SQL, run it, and return ``(df, sql, why)``."""
        sql, why = self._nl_to_sql(message)
        df = self.run_sql(sql)
        return df, sql, why

    def get_full_table_path(self) -> str:
        """Return the table path resolved at construction time."""
        return self.full_table