Update level_classifier_tool_2.py
level_classifier_tool_2.py  +30 -15  CHANGED

@@ -8,9 +8,6 @@ from transformers import AutoTokenizer, AutoModel
 #import tensorflow
 Agg = Literal["mean", "max", "topk_mean"]
 
-
-# --------------------------- Embedding backend ---------------------------
-
 @dataclass
 class HFEmbeddingBackend:
     """
@@ -18,18 +15,36 @@ class HFEmbeddingBackend:
     Uses mean pooling over last_hidden_state and L2 normalizes the result.
     """
     model_name: str = "google/embeddinggemma-300m"
-    device =
-
-
-    MODEL
-
-    def
+    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
+    # Lazy-initialized in __post_init__
+    TOK: Any = field(init=False, repr=False)
+    MODEL: Any = field(init=False, repr=False)
+
+    def __post_init__(self):
+
+        os.environ.setdefault("SPACES_ZERO_DISABLED", "1")
+
+        try:
+            torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False)
+        except Exception:
+            pass
+
+        self.TOK = AutoTokenizer.from_pretrained(self.model_name)
+        self.MODEL = AutoModel.from_pretrained(self.model_name, attn_implementation="eager")
+        try:
+            self.MODEL.config.attn_implementation = "eager"
+        except Exception:
+            pass
+
+        self.MODEL.to(self.device).eval()
+
+    def encode(self, texts: Iterable[str], batch_size: int = 32) -> "Tuple[torch.Tensor, List[str]]":
         """
         Returns (embeddings, texts_list). Embeddings have shape [N, D] and are unit-normalized.
         """
         texts_list = list(texts)
         if not texts_list:
-            return torch.empty((0, self.MODEL.config.hidden_size)), []
+            return torch.empty((0, self.MODEL.config.hidden_size)), []
 
         all_out = []
         with torch.inference_mode():
@@ -39,18 +54,18 @@ class HFEmbeddingBackend:
                 out = self.MODEL(**enc)
                 last = out.last_hidden_state                # [B, T, H]
                 mask = enc["attention_mask"].unsqueeze(-1)  # [B, T, 1]
-
+
+                # Mean pool
                 summed = (last * mask).sum(dim=1)
                 counts = mask.sum(dim=1).clamp(min=1)
                 pooled = summed / counts
+
                 # L2 normalize
                 pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)
                 all_out.append(pooled.cpu())
-        embs = torch.cat(all_out, dim=0) if all_out else torch.empty((0, self.MODEL.config.hidden_size))  # type: ignore
-        return embs, texts_list
 
-
-
+        embs = torch.cat(all_out, dim=0) if all_out else torch.empty((0, self.MODEL.config.hidden_size))  # type: ignore
+        return embs, texts_list
 # --------------------------- Utilities ---------------------------
 
 def _normalize_whitespace(s: str) -> str:
     return " ".join(s.strip().split())
|