Update app.py
Browse files
app.py
CHANGED
|
@@ -133,7 +133,8 @@ def encode_and_trace(text, selected_roles):
|
|
| 133 |
masked_input = ids.where(mask_flags, MASK_ID)
|
| 134 |
|
| 135 |
encoded_m = encode(masked_input, attn)
|
| 136 |
-
logits = mlm_head(encoded_m)
|
|
|
|
| 137 |
preds = logits.argmax(-1)
|
| 138 |
|
| 139 |
masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)
|
|
|
|
| 133 |
masked_input = ids.where(mask_flags, MASK_ID)
|
| 134 |
|
| 135 |
encoded_m = encode(masked_input, attn)
|
| 136 |
+
logits = mlm_head(encoded_m)[0] # shape: (S, V)
|
| 137 |
+
|
| 138 |
preds = logits.argmax(-1)
|
| 139 |
|
| 140 |
masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)
|