bhardwaj08sarthak committed on
Commit 00dcf6d · verified · 1 Parent(s): 4b2ffda

Update level_classifier_tool_2.py

Files changed (1)
  1. level_classifier_tool_2.py +30 -15
level_classifier_tool_2.py CHANGED
@@ -8,9 +8,6 @@ from transformers import AutoTokenizer, AutoModel
 #import tensorflow
 Agg = Literal["mean", "max", "topk_mean"]
 
-
-# --------------------------- Embedding backend ---------------------------
-
 @dataclass
 class HFEmbeddingBackend:
     """
@@ -18,18 +15,36 @@ class HFEmbeddingBackend:
     Uses mean pooling over last_hidden_state and L2 normalizes the result.
     """
     model_name: str = "google/embeddinggemma-300m"
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    TOK = AutoTokenizer.from_pretrained(model_name)
-    MODEL = AutoModel.from_pretrained(model_name)
-    MODEL.to(device).eval()
-
-    def encode(self, texts: Iterable[str], batch_size: int = 32) -> "tuple[torch.Tensor, list[str]]":
+    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
+    # Lazy-initialized in __post_init__
+    TOK: Any = field(init=False, repr=False)
+    MODEL: Any = field(init=False, repr=False)
+
+    def __post_init__(self):
+
+        os.environ.setdefault("SPACES_ZERO_DISABLED", "1")
+
+        try:
+            torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False)
+        except Exception:
+            pass
+
+        self.TOK = AutoTokenizer.from_pretrained(self.model_name)
+        self.MODEL = AutoModel.from_pretrained(self.model_name, attn_implementation="eager")
+        try:
+            self.MODEL.config.attn_implementation = "eager"
+        except Exception:
+            pass
+
+        self.MODEL.to(self.device).eval()
+
+    def encode(self, texts: Iterable[str], batch_size: int = 32) -> "Tuple[torch.Tensor, List[str]]":
         """
         Returns (embeddings, texts_list). Embeddings have shape [N, D] and are unit-normalized.
        """
         texts_list = list(texts)
         if not texts_list:
-            return torch.empty((0, self.MODEL.config.hidden_size)), []  # type: ignore
+            return torch.empty((0, self.MODEL.config.hidden_size)), []
 
         all_out = []
         with torch.inference_mode():
@@ -39,18 +54,18 @@
                 out = self.MODEL(**enc)
                 last = out.last_hidden_state                # [B, T, H]
                 mask = enc["attention_mask"].unsqueeze(-1)  # [B, T, 1]
-                # mean pool
+
+                # Mean pool
                 summed = (last * mask).sum(dim=1)
                 counts = mask.sum(dim=1).clamp(min=1)
                 pooled = summed / counts
+
                 # L2 normalize
                 pooled = pooled / pooled.norm(dim=1, keepdim=True).clamp(min=1e-12)
                 all_out.append(pooled.cpu())
-        embs = torch.cat(all_out, dim=0) if all_out else torch.empty((0, self.MODEL.config.hidden_size))  # type: ignore
-        return embs, texts_list
 
-
-# --------------------------- Utilities ---------------------------
+        embs = torch.cat(all_out, dim=0) if all_out else torch.empty((0, self.MODEL.config.hidden_size))  # type: ignore
+        return embs, texts_list
 
 def _normalize_whitespace(s: str) -> str:
     return " ".join(s.strip().split())
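For reference, a minimal usage sketch of the refactored backend follows. It is not part of the commit: it assumes the rest of level_classifier_tool_2.py still provides the module-level imports the new fields rely on (os, torch, dataclasses.field, typing.Any/Iterable/Tuple/List, AutoTokenizer/AutoModel) and that the embedding model can be downloaded; the example texts are placeholders.

# Hypothetical usage sketch; not part of the commit.
from level_classifier_tool_2 import HFEmbeddingBackend

# __post_init__ now runs at instantiation: tokenizer and model are loaded once,
# moved to CUDA if available (else CPU), and put in eval mode.
backend = HFEmbeddingBackend()  # or HFEmbeddingBackend(model_name=...) for another encoder

# encode() returns unit-normalized embeddings, so cosine similarity is a plain matmul.
embs, texts = backend.encode(["intro to derivatives", "limits and continuity"])
print(embs.shape)     # torch.Size([2, D]); D is the model's hidden size
print(embs @ embs.T)  # pairwise cosine similarities

Design-wise, the main effect of the change is that from_pretrained no longer runs at import time (the old class-level attributes executed as the class body was evaluated); each HFEmbeddingBackend instance now loads the tokenizer and model in __post_init__ and moves them to the selected device.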