Commit a224632
burtenshaw committed
Parent(s): e3af5c6

improve structure and layout

Browse files:
- app/src/content/article.mdx  +0 -2
- app/src/content/chapters/inference.mdx  +7 -1
- app/src/content/chapters/sft.mdx  +90 -36
- inference.ipynb  +187 -0
- sft.ipynb  +4 -402
app/src/content/article.mdx CHANGED

@@ -570,8 +570,6 @@ In `modular_nanochat.py`, we don't need to write this logic at all. As seen in t
 
 It's very clear that Andrej Karpathy's implementation offers 10 times more to learn from than the transformer version, which inherits almost entirely from existing models or features. That said, we can still take more away from the inherited modular modeling implementation. Models like Llama, Llama4, Gemma2, Qwen3, and CLIP are all reused to create a genuinely canonical implementation of a transformer.
 
-# Hands-on Tutorial
-
 Ok. Let's cut the philosophy and see what we can do with `nanochat` in transformers.
 
 <Inference />
app/src/content/chapters/inference.mdx CHANGED

@@ -1,4 +1,10 @@
-## Inference on
+## Example 1: Inference on nanochat in Transformers
+
+<Sidenote>
+
+[](https://colab.research.google.com/#fileId=https://huggingface.co/datasets/nanochat-students/notebooks/blob/main/inference.ipynb)
+
+</Sidenote>
 
 The first bonus tutorial will help you do basic inference in `transformers`:
 
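The linked notebook (added as `inference.ipynb` further down in this commit) reduces to a few lines; as a rough sketch of what Example 1 covers, using only calls that appear in that notebook (the notebook itself uses the older `torch_dtype` spelling, which transformers now reports as deprecated in favour of `dtype`):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "karpathy/nanochat-d32"
revision = "refs/pr/1"  # nanochat support lives on this PR revision
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
).to(device)

# Wrap the question in nanochat's chat template and greedily decode the answer.
conversation = [{"role": "user", "content": "What is the capital of France?"}]
inputs = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
).to(device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)

generated = outputs[0, inputs["input_ids"].shape[1]:]
print(tokenizer.decode(generated))  # -> "The capital of France is Paris.<|assistant_end|>"
```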
app/src/content/chapters/sft.mdx CHANGED

@@ -1,7 +1,7 @@
 import Sidenote from '../../components/Sidenote.astro'
 import Note from '../../components/Note.astro'
 
-
+## Example 2: Supervised Fine-tuning in torch
 
 <Sidenote>
 
@@ -19,7 +19,7 @@ In this tutorial, we'll fine-tune the NanoChat model using pure PyTorch, giving
 
 </Note>
 
-
+### Import model and tokenizer
 
 <Sidenote>
 
@@ -51,7 +51,7 @@ model = AutoModelForCausalLM.from_pretrained(
 
 We use `bfloat16` precision on GPU to reduce memory usage while maintaining training stability. On CPU, we fall back to `float32` for compatibility.
 
-
+### Setup LoRA
 
 <Sidenote>
 
@@ -87,7 +87,7 @@ trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627
 
 With LoRA, we're only training **0.06%** of the model's parameters—just over 1 million weights instead of 1.8 billion. This makes fine-tuning feasible on consumer hardware.
 
-
+### Demo the model
 
 <Sidenote>
 
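For reference, the 0.06% figure in the context above comes from the LoRA cell that this commit removes from `sft.ipynb` (visible at the end of this diff); a self-contained sketch of that setup:

```python
from peft import LoraConfig, get_peft_model

# Mirrors the removed notebook cell: rank-1 adapters on the attention and MLP
# projections, so only ~1.2M of the 1.88B parameters receive gradients.
lora_config = LoraConfig(
    r=1,
    lora_alpha=2,
    lora_dropout=0.0,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"],
)

model = get_peft_model(model, lora_config)  # `model` is the NanoChat model loaded above
model.print_trainable_parameters()
# trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627
```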
@@ -179,7 +179,7 @@ Generated: The capital of France is Paris.<|assistant_end|>
 
 Notice the special tokens: `<|bos|>`, `<|user_start|>`, `<|assistant_start|>`, etc. These delimiters help the model understand conversation structure.
 
-
+### Dataset
 
 <Sidenote>
 
@@ -196,8 +196,6 @@ train_dataset = splits["train"]
 eval_dataset = splits["test"]
 ```
 
-### Process the Dataset
-
 <Sidenote>
 
 The [datasets map](https://huggingface.co/docs/datasets/process#map) function applies transformations efficiently with caching and multiprocessing support.
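The `map` usage the sidenote refers to lives in the "Process the Dataset" cell this commit removes from `sft.ipynb` (shown later in the diff); trimmed to its core, it looks roughly like this:

```python
max_length = 2048

def format_example(example):
    # Render the message list with nanochat's chat template, truncating long chats.
    formatted = tokenizer.apply_chat_template(
        example["messages"],
        add_generation_prompt=False,
        truncation=True,
        max_length=max_length,
        padding=False,
        return_dict=True,
        return_tensors="pt",
    )
    return {
        "input_ids": formatted["input_ids"][0].tolist(),
        "attention_mask": formatted["attention_mask"][0].tolist(),
    }

# map() caches results and can drop the raw columns in the same pass.
train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
```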
@@ -236,31 +234,6 @@ eval_dataset = eval_dataset.select(range(min(len(eval_dataset), max_eval_example
 eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)
 ```
 
-## Training Configuration
-
-These hyperparameters control the training dynamics. We use conservative values that work well across different hardware:
-
-```python
-train_batch_size = 2
-eval_batch_size = 2
-num_epochs = 1
-gradient_accumulation_steps = 4
-learning_rate = 1e-5
-weight_decay = 0.0
-warmup_ratio = 0.03
-logging_frequency = 10
-```
-
-<Sidenote>
-
-**Gradient accumulation** simulates larger batch sizes by accumulating gradients over multiple forward passes before updating weights. Effective batch size = `train_batch_size × gradient_accumulation_steps` = 8.
-
-</Sidenote>
-
-Key configuration choices include using a low learning rate (`1e-5`), as LoRA generally requires smaller learning rates given that the base model weights are kept frozen. Additionally, gradient accumulation is employed to enable larger effective batch sizes, which helps when training on GPUs with limited memory.
-
-## Create a DataLoader
-
 <Sidenote>
 
 PyTorch's [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) handles batching, shuffling, and parallel data loading automatically.
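The configuration block removed here only moves; it reappears under `## Training` in the next hunk. The gradient-accumulation sidenote it carries corresponds to the update pattern in the notebook's training loop (the forward/backward lines are visible in the traceback at the end of this diff; the optimizer and scheduler stepping is the standard pattern, sketched here as an assumption):

```python
optimizer.zero_grad()
for step, batch in enumerate(train_loader, start=1):
    batch = {key: value.to(device) for key, value in batch.items()}
    # Scale the loss so the accumulated gradient averages over micro-batches.
    loss = model(**batch).loss / gradient_accumulation_steps
    loss.backward()  # gradients add up across micro-batches
    if step % gradient_accumulation_steps == 0:
        optimizer.step()       # one weight update per effective batch (2 × 4 = 8 examples)
        scheduler.step()
        optimizer.zero_grad()
```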
@@ -288,7 +261,30 @@ eval_loader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False
 
 Setting padding tokens to `-100` in labels tells PyTorch's cross-entropy loss to ignore them—we don't want to penalize the model for not predicting padding.
 
-## Optimizer
+## Training
+
+These hyperparameters control the training dynamics. We use conservative values that work well across different hardware:
+
+```python
+train_batch_size = 2
+eval_batch_size = 2
+num_epochs = 1
+gradient_accumulation_steps = 4
+learning_rate = 1e-5
+weight_decay = 0.0
+warmup_ratio = 0.03
+logging_frequency = 10
+```
+
+<Sidenote>
+
+**Gradient accumulation** simulates larger batch sizes by accumulating gradients over multiple forward passes before updating weights. Effective batch size = `train_batch_size × gradient_accumulation_steps` = 8.
+
+</Sidenote>
+
+Key configuration choices include using a low learning rate (`1e-5`), as LoRA generally requires smaller learning rates given that the base model weights are kept frozen. Additionally, gradient accumulation is employed to enable larger effective batch sizes, which helps when training on GPUs with limited memory.
+
+### Optimizer
 
 <Sidenote>
 
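The `-100` convention in the hunk's leading context comes from the notebook's collate function (removed from `sft.ipynb` later in this commit); reproduced here as a sketch:

```python
def collate_fn(batch):
    batch_dict = {
        "input_ids": [record["input_ids"] for record in batch],
        "attention_mask": [record["attention_mask"] for record in batch],
    }
    # Pad every sequence in the batch to the longest one.
    padded = tokenizer.pad(batch_dict, padding=True, return_tensors="pt")
    labels = padded["input_ids"].clone()
    labels[padded["attention_mask"] == 0] = -100  # -100 is ignored by torch cross-entropy
    padded["labels"] = labels
    return padded
```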
@@ -306,7 +302,7 @@ optimizer = torch.optim.AdamW(
 )
 ```
 
-
+### Learning Rate Scheduler
 
 <Sidenote>
 
@@ -323,8 +319,6 @@ warmup_steps = max(1, int(max_train_steps * warmup_ratio))
 scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, max_train_steps)
 ```
 
-## The Training Loop
-
 <Sidenote>
 
 For distributed training across multiple GPUs, consider [Accelerate](https://huggingface.co/docs/accelerate/index) which wraps this loop with minimal code changes.
@@ -398,3 +392,63 @@ step=00040 | loss=1.7935 | lr=5.33e-06
 step=00050 | loss=1.8029 | lr=6.67e-06
 ...
 ```
+
+## Example 3: Fine-tuning with TRL
+
+<Sidenote>
+
+[](https://colab.research.google.com/#fileId=https://github.com/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb)
+
+</Sidenote>
+
+Finally, we can implement the training loop above with TRL, which simplifies the code and abstracts away a lot of the complexity (and, with it, some of the education), but brings all the bells and whistles of a production-ready solution.
+
+We can define the training arguments and create the trainer object like this:
+
+```python
+from trl import SFTConfig
+
+training_args = SFTConfig(
+    per_device_train_batch_size=1,  # Batch size per GPU
+    gradient_accumulation_steps=4,  # Gradients are accumulated over multiple steps → effective batch size = 1 * 4 = 4
+    warmup_steps=5,
+    # num_train_epochs = 1,         # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)
+    max_steps=30,
+    learning_rate=2e-4,             # Learning rate for the optimizer
+    optim="paged_adamw_8bit",       # Optimizer
+
+    # Logging / reporting
+    logging_steps=1,                # Log training metrics every N steps
+    report_to="trackio",            # Experiment tracking tool
+    trackio_space_id=output_dir,    # HF Space where the experiment tracking will be saved
+    output_dir=output_dir,          # Where to save model checkpoints and logs
+
+    max_length=1024,                # Maximum input sequence length
+    use_liger_kernel=True,          # Enable Liger kernel optimizations for faster training
+    activation_offloading=True,     # Offload activations to CPU to reduce GPU memory usage
+    gradient_checkpointing=True,    # Save memory by re-computing activations during backpropagation
+
+    # Hub integration
+    push_to_hub=True,               # Automatically push the trained model to the Hugging Face Hub
+    # The model will be saved under your Hub account in the repository named `output_dir`
+)
+```
+
+Then we can create the trainer like this, and TRL will deal with data loading, batching, and training:
+
+```python
+from trl import SFTTrainer
+
+trainer = SFTTrainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    peft_config=peft_config
+)
+```
+
+And then we can train the model like this:
+
+```python
+trainer_stats = trainer.train()
+```
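Note that the TRL snippet assumes `output_dir`, `model`, `train_dataset`, and `peft_config` are already defined; the diff does not show those definitions, so something like the following (hypothetical names) is implied:

```python
# Hypothetical definitions assumed by the TRL snippet above (not shown in the diff):
output_dir = "nanochat-d32-sft"  # Hub repo name doubles as the local output folder
peft_config = lora_config        # reuse the LoraConfig from the LoRA setup earlier
```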
|
| 1 |
import Sidenote from '../../components/Sidenote.astro'
|
| 2 |
import Note from '../../components/Note.astro'
|
| 3 |
|
| 4 |
+
## Example 2: Supervised Fine-tuning in torch
|
| 5 |
|
| 6 |
<Sidenote>
|
| 7 |
|
|
|
|
| 19 |
|
| 20 |
</Note>
|
| 21 |
|
| 22 |
+
### Import model and tokenizer
|
| 23 |
|
| 24 |
<Sidenote>
|
| 25 |
|
|
|
|
| 51 |
|
| 52 |
We use `bfloat16` precision on GPU to reduce memory usage while maintaining training stability. On CPU, we fall back to `float32` for compatibility.
|
| 53 |
|
| 54 |
+
### Setup LoRA
|
| 55 |
|
| 56 |
<Sidenote>
|
| 57 |
|
|
|
|
| 87 |
|
| 88 |
With LoRA, we're only training **0.06%** of the model's parameters—just over 1 million weights instead of 1.8 billion. This makes fine-tuning feasible on consumer hardware.
|
| 89 |
|
| 90 |
+
### Demo the model
|
| 91 |
|
| 92 |
<Sidenote>
|
| 93 |
|
|
|
|
| 179 |
|
| 180 |
Notice the special tokens: `<|bos|>`, `<|user_start|>`, `<|assistant_start|>`, etc. These delimiters help the model understand conversation structure.
|
| 181 |
|
| 182 |
+
### Dataset
|
| 183 |
|
| 184 |
<Sidenote>
|
| 185 |
|
|
|
|
| 196 |
eval_dataset = splits["test"]
|
| 197 |
```
|
| 198 |
|
|
|
|
|
|
|
| 199 |
<Sidenote>
|
| 200 |
|
| 201 |
The [datasets map](https://huggingface.co/docs/datasets/process#map) function applies transformations efficiently with caching and multiprocessing support.
|
|
|
|
| 234 |
eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)
|
| 235 |
```
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
<Sidenote>
|
| 238 |
|
| 239 |
PyTorch's [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) handles batching, shuffling, and parallel data loading automatically.
|
|
|
|
| 261 |
|
| 262 |
Setting padding tokens to `-100` in labels tells PyTorch's cross-entropy loss to ignore them—we don't want to penalize the model for not predicting padding.
|
| 263 |
|
| 264 |
+
## Training
|
| 265 |
+
|
| 266 |
+
These hyperparameters control the training dynamics. We use conservative values that work well across different hardware:
|
| 267 |
+
|
| 268 |
+
```python
|
| 269 |
+
train_batch_size = 2
|
| 270 |
+
eval_batch_size = 2
|
| 271 |
+
num_epochs = 1
|
| 272 |
+
gradient_accumulation_steps = 4
|
| 273 |
+
learning_rate = 1e-5
|
| 274 |
+
weight_decay = 0.0
|
| 275 |
+
warmup_ratio = 0.03
|
| 276 |
+
logging_frequency = 10
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
<Sidenote>
|
| 280 |
+
|
| 281 |
+
**Gradient accumulation** simulates larger batch sizes by accumulating gradients over multiple forward passes before updating weights. Effective batch size = `train_batch_size × gradient_accumulation_steps` = 8.
|
| 282 |
+
|
| 283 |
+
</Sidenote>
|
| 284 |
+
|
| 285 |
+
Key configuration choices include using a low learning rate (`1e-5`), as LoRA generally requires smaller learning rates given that the base model weights are kept frozen. Additionally, gradient accumulation is employed to enable larger effective batch sizes, which helps when training on GPUs with limited memory.
|
| 286 |
+
|
| 287 |
+
### Optimizer
|
| 288 |
|
| 289 |
<Sidenote>
|
| 290 |
|
|
|
|
| 302 |
)
|
| 303 |
```
|
| 304 |
|
| 305 |
+
### Learning Rate Scheduler
|
| 306 |
|
| 307 |
<Sidenote>
|
| 308 |
|
|
|
|
| 319 |
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, max_train_steps)
|
| 320 |
```
|
| 321 |
|
|
|
|
|
|
|
| 322 |
<Sidenote>
|
| 323 |
|
| 324 |
For distributed training across multiple GPUs, consider [Accelerate](https://huggingface.co/docs/accelerate/index) which wraps this loop with minimal code changes.
|
|
|
|
| 392 |
step=00050 | loss=1.8029 | lr=6.67e-06
|
| 393 |
...
|
| 394 |
```
|
| 395 |
+
|
| 396 |
+
## Example 3: Fine-tuning with TRL
|
| 397 |
+
|
| 398 |
+
<Sidenote>
|
| 399 |
+
|
| 400 |
+
[](https://colab.research.google.com/#fileId=https://github.com/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb)
|
| 401 |
+
|
| 402 |
+
</Sidenote>
|
| 403 |
+
|
| 404 |
+
Finally, we can implement the training loop above with TRL. Which definitely simplifies the code and abstracts away a lot of the complexity (education). But it's got all the bells and whistles of a production-ready solution.
|
| 405 |
+
|
| 406 |
+
We can define the training arguments and create the trainer object like this:
|
| 407 |
+
|
| 408 |
+
```python
|
| 409 |
+
from trl import SFTConfig
|
| 410 |
+
|
| 411 |
+
training_args = SFTConfig(
|
| 412 |
+
per_device_train_batch_size = 1, # Batch size per GPU
|
| 413 |
+
gradient_accumulation_steps = 4, # Gradients are accumulated over multiple steps → effective batch size = 2 * 8 = 16
|
| 414 |
+
warmup_steps = 5,
|
| 415 |
+
# num_train_epochs = 1, # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)
|
| 416 |
+
max_steps = 30,
|
| 417 |
+
learning_rate = 2e-4, # Learning rate for the optimizer
|
| 418 |
+
optim = "paged_adamw_8bit", # Optimizer
|
| 419 |
+
|
| 420 |
+
# Logging / reporting
|
| 421 |
+
logging_steps=1, # Log training metrics every N steps
|
| 422 |
+
report_to="trackio", # Experiment tracking tool
|
| 423 |
+
trackio_space_id=output_dir, # HF Space where the experiment tracking will be saved
|
| 424 |
+
output_dir=output_dir, # Where to save model checkpoints and logs
|
| 425 |
+
|
| 426 |
+
max_length=1024, # Maximum input sequence length
|
| 427 |
+
use_liger_kernel=True, # Enable Liger kernel optimizations for faster training
|
| 428 |
+
activation_offloading=True, # Offload activations to CPU to reduce GPU memory usage
|
| 429 |
+
gradient_checkpointing=True, # Save memory by re-computing activations during backpropagation
|
| 430 |
+
|
| 431 |
+
# Hub integration
|
| 432 |
+
push_to_hub=True, # Automatically push the trained model to the Hugging Face Hub
|
| 433 |
+
# The model will be saved under your Hub account in the repository named `output_dir`
|
| 434 |
+
)
|
| 435 |
+
```
|
| 436 |
+
|
| 437 |
+
Then we can train the model like this and TRL will deal with data loading, batching, and training.
|
| 438 |
+
|
| 439 |
+
```
|
| 440 |
+
from trl import SFTTrainer
|
| 441 |
+
|
| 442 |
+
trainer = SFTTrainer(
|
| 443 |
+
model=model,
|
| 444 |
+
args=training_args,
|
| 445 |
+
train_dataset=train_dataset,
|
| 446 |
+
peft_config=peft_config
|
| 447 |
+
)
|
| 448 |
+
```
|
| 449 |
+
|
| 450 |
+
And then we can train the model like this:
|
| 451 |
+
|
| 452 |
+
```
|
| 453 |
+
trainer_stats = trainer.train()
|
| 454 |
+
```
|
inference.ipynb ADDED

@@ -0,0 +1,187 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "b7eb261b",
+   "metadata": {
+    "id": "b7eb261b"
+   },
+   "source": [
+    "# NanoChat Easy - SFT Training\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8b8a04a8",
+   "metadata": {
+    "id": "8b8a04a8"
+   },
+   "source": [
+    "## Import model and tokenizer\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3e48247c",
+   "metadata": {
+    "id": "3e48247c",
+    "outputId": "882fcf01-34fb-4123-e84c-deefdf477814"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n",
+      "`torch_dtype` is deprecated! Use `dtype` instead!\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
+    "\n",
+    "\n",
+    "model_id = \"karpathy/nanochat-d32\"\n",
+    "revision = \"refs/pr/1\"\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)\n",
+    "model = AutoModelForCausalLM.from_pretrained(\n",
+    "    model_id,\n",
+    "    revision=revision,\n",
+    "    torch_dtype=torch.bfloat16 if device.type == \"cuda\" else torch.float32,\n",
+    ").to(device)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4810af1a",
+   "metadata": {
+    "id": "4810af1a"
+   },
+   "source": [
+    "## Demo the model\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b3e81aa9",
+   "metadata": {
+    "id": "b3e81aa9",
+    "outputId": "1cde7e69-7ff1-4bfe-aa9f-9ded20249d82"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "================================================================================\n",
+      "TEST 1: Plain Autoregressive Prompt\n",
+      "================================================================================\n",
+      "Prompt: The Eiffel Tower stands in Paris and\n",
+      "\n",
+      "Generated: is one of the most famous landmarks in the world. It is located on the Champ de Mars in the heart of the city. The tower was built for the 1889 World's Fair. It was designed by the French engineer Gustave Eiffel and took 2 years to build. The Eiffel Tower stands 324 meters\n",
+      "================================================================================\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"=\" * 80)\n",
+    "print(\"TEST 1: Plain Autoregressive Prompt\")\n",
+    "print(\"=\" * 80)\n",
+    "prompt = \"The Eiffel Tower stands in Paris and\"\n",
+    "test_inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
+    "\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    test_outputs = model.generate(\n",
+    "        **test_inputs,\n",
+    "        max_new_tokens=64,\n",
+    "        do_sample=False,\n",
+    "        pad_token_id=tokenizer.pad_token_id,\n",
+    "    )\n",
+    "\n",
+    "generated_tokens = test_outputs[0, test_inputs[\"input_ids\"].shape[1] :]\n",
+    "print(f\"Prompt: {prompt}\")\n",
+    "print(f\"\\nGenerated: {tokenizer.decode(generated_tokens, skip_special_tokens=True)}\")\n",
+    "print(\"=\" * 80)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8e7b275c",
+   "metadata": {
+    "id": "8e7b275c",
+    "outputId": "719e986e-61b4-4fd5-db15-4a9ef8f97396"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "================================================================================\n",
+      "TEST 2: Chat Template\n",
+      "================================================================================\n",
+      "Formatted prompt: <|bos|><|user_start|>What is the capital of France?<|user_end|><|assistant_start|>\n",
+      "Input IDs: [65527, 65528, 1442, 309, 261, 3429, 281, 4215, 63, 65529, 65530]\n",
+      "\n",
+      "Generated: The capital of France is Paris.<|assistant_end|>\n",
+      "================================================================================\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"=\" * 80)\n",
+    "print(\"TEST 2: Chat Template\")\n",
+    "print(\"=\" * 80)\n",
+    "conversation = [\n",
+    "    {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
+    "]\n",
+    "\n",
+    "inputs = tokenizer.apply_chat_template(\n",
+    "    conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors=\"pt\"\n",
+    ").to(device)\n",
+    "\n",
+    "print(f\"Formatted prompt: {tokenizer.decode(inputs['input_ids'][0])}\")\n",
+    "print(f\"Input IDs: {inputs['input_ids'][0].tolist()}\")\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)\n",
+    "\n",
+    "generated_tokens = outputs[0, inputs[\"input_ids\"].shape[1] :]\n",
+    "print(f\"\\nGenerated: {tokenizer.decode(generated_tokens)}\")\n",
+    "print(\"=\" * 80)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
sft.ipynb CHANGED

@@ -59,48 +59,6 @@
     ").to(device)\n"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "id": "c9a9c0a4",
-   "metadata": {
-    "id": "c9a9c0a4"
-   },
-   "source": [
-    "## Setup LoRA\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "dd9a698a",
-   "metadata": {
-    "id": "dd9a698a",
-    "outputId": "0aae9ecc-7af9-436e-a95b-a4cd023997fd"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627\n"
-     ]
-    }
-   ],
-   "source": [
-    "from peft import LoraConfig, get_peft_model\n",
-    "\n",
-    "lora_config = LoraConfig(\n",
-    "    r=1,\n",
-    "    lora_alpha=2,\n",
-    "    lora_dropout=0.00,\n",
-    "    task_type=\"CAUSAL_LM\",\n",
-    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"fc1\", \"fc2\"]\n",
-    ")\n",
-    "\n",
-    "model = get_peft_model(model, lora_config)\n",
-    "model.print_trainable_parameters()\n"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "4810af1a",
@@ -206,365 +164,12 @@
     "print(f\"\\nGenerated: {tokenizer.decode(generated_tokens)}\")\n",
     "print(\"=\" * 80)\n"
    ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "44cb321a",
-   "metadata": {
-    "id": "44cb321a"
-   },
-   "source": [
-    "## Dataset\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "e1a75c14",
-   "metadata": {
-    "id": "e1a75c14"
-   },
-   "outputs": [],
-   "source": [
-    "raw_dataset = load_dataset(\"HuggingFaceTB/smoltalk2\", \"SFT\", split=\"OpenThoughts3_1.2M_think\")\n",
-    "splits = raw_dataset.train_test_split(test_size=0.1, seed=42)\n",
-    "train_dataset = splits[\"train\"]\n",
-    "eval_dataset = splits[\"test\"]\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "8b29399d",
-   "metadata": {
-    "id": "8b29399d"
-   },
-   "source": [
-    "### Process the Dataset\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "451542b4",
-   "metadata": {
-    "id": "451542b4",
-    "outputId": "caa727dd-f9d8-4c67-d193-79bcc0836b49"
-   },
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Map:   0%|          | 0/20000 [00:00<?, ? examples/s]"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Map: 100%|██████████| 20000/20000 [06:27<00:00, 51.68 examples/s]\n",
-      "Map: 100%|██████████| 1000/1000 [00:19<00:00, 52.12 examples/s]\n"
-     ]
-    }
-   ],
-   "source": [
-    "max_length = 2048\n",
-    "max_train_examples = 20000\n",
-    "max_eval_examples = 1000\n",
-    "\n",
-    "def format_example(example):\n",
-    "    formatted = tokenizer.apply_chat_template(\n",
-    "        example[\"messages\"],\n",
-    "        add_generation_prompt=False,\n",
-    "        truncation=True,\n",
-    "        max_length=max_length,\n",
-    "        padding=False,\n",
-    "        return_dict=True,\n",
-    "        return_tensors=\"pt\",\n",
-    "    )\n",
-    "    return {\n",
-    "        \"input_ids\": formatted[\"input_ids\"][0].tolist(),\n",
-    "        \"attention_mask\": formatted[\"attention_mask\"][0].tolist(),\n",
-    "    }\n",
-    "\n",
-    "\n",
-    "if max_train_examples is not None:\n",
-    "    train_dataset = train_dataset.select(range(min(len(train_dataset), max_train_examples)))\n",
-    "    train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)\n",
-    "else:\n",
-    "    train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)\n",
-    "\n",
-    "if max_eval_examples is not None:\n",
-    "    eval_dataset = eval_dataset.select(range(min(len(eval_dataset), max_eval_examples)))\n",
-    "    eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)\n",
-    "else:\n",
-    "    eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "ecd33dd7",
-   "metadata": {
-    "id": "ecd33dd7"
-   },
-   "source": [
-    "## Training Configuration"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "f9d837ee",
-   "metadata": {
-    "id": "f9d837ee"
-   },
-   "outputs": [],
-   "source": [
-    "train_batch_size = 2\n",
-    "eval_batch_size = 2\n",
-    "num_epochs = 1\n",
-    "gradient_accumulation_steps = 4\n",
-    "learning_rate = 1e-5\n",
-    "weight_decay = 0.0\n",
-    "warmup_ratio = 0.03\n",
-    "logging_frequency = 10"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "1cf11e96",
-   "metadata": {
-    "id": "1cf11e96"
-   },
-   "source": [
-    "## Create a `DataLoader` 👴"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "1bc4fa24",
-   "metadata": {
-    "id": "1bc4fa24"
-   },
-   "outputs": [],
-   "source": [
-    "def collate_fn(batch):\n",
-    "    batch_dict = {\n",
-    "        \"input_ids\": [record[\"input_ids\"] for record in batch],\n",
-    "        \"attention_mask\": [record[\"attention_mask\"] for record in batch],\n",
-    "    }\n",
-    "    padded = tokenizer.pad(batch_dict, padding=True, return_tensors=\"pt\")\n",
-    "    labels = padded[\"input_ids\"].clone()\n",
-    "    labels[padded[\"attention_mask\"] == 0] = -100\n",
-    "    padded[\"labels\"] = labels\n",
-    "    return padded\n",
-    "\n",
-    "\n",
-    "TrainLoader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)\n",
-    "EvalLoader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn)\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "f5965d1b",
-   "metadata": {
-    "id": "f5965d1b"
-   },
-   "source": [
-    "## Optimizer"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "f57c7be2",
-   "metadata": {
-    "id": "f57c7be2"
-   },
-   "outputs": [],
-   "source": [
-    "optimizer = torch.optim.AdamW(\n",
-    "    model.parameters(),\n",
-    "    lr=learning_rate,\n",
-    "    weight_decay=weight_decay,\n",
-    ")\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "215f8782",
-   "metadata": {
-    "id": "215f8782"
-   },
-   "source": [
-    "# Learning Rate Scheduler"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "034e2903",
-   "metadata": {
-    "id": "034e2903"
-   },
-   "outputs": [],
-   "source": [
-    "num_update_steps_per_epoch = max(len(TrainLoader) // gradient_accumulation_steps, 1)\n",
-    "max_train_steps = num_epochs * num_update_steps_per_epoch\n",
-    "warmup_steps = max(1, int(max_train_steps * warmup_ratio))\n",
-    "scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, max_train_steps)\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "0f0090b6",
-   "metadata": {
-    "id": "0f0090b6"
-   },
-   "source": [
-    "# The Training Loop"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "1540e30a",
-   "metadata": {
-    "id": "1540e30a",
-    "outputId": "747badd7-18df-441f-8026-7aa4f30c2fd7"
-   },
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Epoch 1/1\n",
-      "step=00010 | loss=1.7586 | lr=1.33e-06\n",
-      "step=00020 | loss=1.8188 | lr=2.67e-06\n",
-      "step=00030 | loss=1.8235 | lr=4.00e-06\n",
-      "step=00040 | loss=1.7935 | lr=5.33e-06\n",
-      "step=00050 | loss=1.8029 | lr=6.67e-06\n",
-      "step=00060 | loss=1.8433 | lr=8.00e-06\n",
-      "step=00070 | loss=1.8616 | lr=9.33e-06\n",
-      "step=00080 | loss=1.8238 | lr=9.98e-06\n",
-      "step=00090 | loss=1.7774 | lr=9.94e-06\n",
-      "step=00100 | loss=1.8081 | lr=9.90e-06\n",
-      "step=00110 | loss=1.7437 | lr=9.86e-06\n",
-      "step=00120 | loss=1.7830 | lr=9.81e-06\n",
-      "step=00130 | loss=1.8064 | lr=9.77e-06\n",
-      "step=00140 | loss=1.8541 | lr=9.73e-06\n",
-      "step=00150 | loss=1.8301 | lr=9.69e-06\n",
-      "step=00160 | loss=1.7725 | lr=9.65e-06\n",
-      "step=00170 | loss=1.7635 | lr=9.61e-06\n",
-      "step=00180 | loss=1.7963 | lr=9.57e-06\n",
-      "step=00190 | loss=1.7563 | lr=9.53e-06\n",
-      "step=00200 | loss=1.6950 | lr=9.48e-06\n",
-      "step=00210 | loss=1.7680 | lr=9.44e-06\n",
-      "step=00220 | loss=1.8906 | lr=9.40e-06\n",
-      "step=00230 | loss=1.7120 | lr=9.36e-06\n",
-      "step=00240 | loss=1.8390 | lr=9.32e-06\n",
-      "step=00250 | loss=1.7180 | lr=9.28e-06\n",
-      "step=00260 | loss=1.7709 | lr=9.24e-06\n",
-      "step=00270 | loss=1.7598 | lr=9.20e-06\n",
-      "step=00280 | loss=1.7981 | lr=9.15e-06\n",
-      "step=00290 | loss=1.7540 | lr=9.11e-06\n",
-      "step=00300 | loss=1.7695 | lr=9.07e-06\n",
-      "step=00310 | loss=1.7468 | lr=9.03e-06\n"
-     ]
-    },
-    {
-     "ename": "KeyboardInterrupt",
-     "evalue": "",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
-      "..."
|
| 507 |
-
"File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1771\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1772\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1773\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 508 |
-
"File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1784\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1779\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1780\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1781\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1782\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1783\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1786\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1787\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
|
| 509 |
-
"File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:53\u001b[0m, in \u001b[0;36mNanoChatRMSNorm.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 53\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mtype_as(x)\n",
|
| 510 |
-
"File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:50\u001b[0m, in \u001b[0;36mNanoChatRMSNorm._norm\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_norm\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 50\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mrsqrt(\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeepdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39meps)\n",
|
| 511 |
-
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
| 512 |
-
]
|
| 513 |
-
}
| 514 | -  ],
| 515 | -  "source": [
| 516 | -   "\n",
| 517 | -   "model.train()\n",
| 518 | -   "global_step = 0\n",
| 519 | -   "running_loss = 0.0\n",
| 520 | -   "running_steps = 0\n",
| 521 | -   "\n",
| 522 | -   "for epoch in range(num_epochs):\n",
| 523 | -   "    print(f\"Epoch {epoch + 1}/{num_epochs}\")\n",
| 524 | -   "    optimizer.zero_grad(set_to_none=True)\n",
| 525 | -   "    for step, batch in enumerate(TrainLoader, start=1):\n",
| 526 | -   "        batch = {key: value.to(device) for key, value in batch.items()}\n",
| 527 | -   "        outputs = model(**batch)\n",
| 528 | -   "        loss = outputs.loss / gradient_accumulation_steps\n",
| 529 | -   "        loss.backward()\n",
| 530 | -   "\n",
| 531 | -   "        running_loss += outputs.loss.float().item()\n",
| 532 | -   "        running_steps += 1\n",
| 533 | -   "\n",
| 534 | -   "        if step % gradient_accumulation_steps == 0 or step == len(TrainLoader):\n",
| 535 | -   "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
| 536 | -   "            optimizer.step()\n",
| 537 | -   "            scheduler.step()\n",
| 538 | -   "            optimizer.zero_grad(set_to_none=True)\n",
| 539 | -   "            global_step += 1\n",
| 540 | -   "\n",
| 541 | -   "            if global_step % logging_frequency == 0:\n",
| 542 | -   "                current_lr = scheduler.get_last_lr()[0]\n",
| 543 | -   "                mean_loss = running_loss / running_steps\n",
| 544 | -   "                print(f\"step={global_step:05d} | loss={mean_loss:.4f} | lr={current_lr:.2e}\")\n",
| 545 | -   "                running_loss = 0.0\n",
| 546 | -   "                running_steps = 0\n",
| 547 | -   "\n",
| 548 | -   "    train_loss = running_loss / running_steps if running_steps > 0 else float(\"nan\")\n",
| 549 | -   "    print(f\"Training loss after epoch {epoch + 1}: {train_loss:.4f}\")\n",
| 550 | -   "\n",
| 551 | -   "    model.eval()\n",
| 552 | -   "    losses = []\n",
| 553 | -   "    with torch.no_grad():\n",
| 554 | -   "        for _, batch in enumerate(EvalLoader, start=1):\n",
| 555 | -   "            batch = {key: value.to(device) for key, value in batch.items()}\n",
| 556 | -   "            loss = model(**batch).loss\n",
| 557 | -   "            losses.append(loss.float().item())\n",
| 558 | -   "    model.train()\n",
| 559 | -   "    val_loss = sum(losses) / len(losses) if losses else float(\"nan\")\n",
| 560 | -   "\n",
| 561 | -   "    print(f\"Validation loss after epoch {epoch + 1}: {val_loss:.4f}\")\n",
| 562 | -   "\n",
| 563 | -   "print(\"Training complete.\")\n"
| 564 | -  ]
| 565 |    }
| 566 |   ],
| 567 |   "metadata": {
| 568 |    "kernelspec": {
| 569 |     "display_name": ".venv",
| 570 |     "language": "python",
@@ -581,11 +186,8 @@
| 581 |     "nbconvert_exporter": "python",
| 582 |     "pygments_lexer": "ipython3",
| 583 |     "version": "3.10.18"
| 584 | -   },
| 585 | -   "colab": {
| 586 | -    "provenance": []
| 587 |    }
| 588 |   },
| 589 |   "nbformat": 4,
| 590 |   "nbformat_minor": 5
| 591 | - }
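
The cell removed above is the notebook's pure-PyTorch training loop. One detail it relies on: each micro-batch loss is divided by `gradient_accumulation_steps` before `backward()`, so the gradients summed across micro-batches match the gradient of a single step over the combined batch. A minimal sketch verifying that equivalence (the tensors and the two-way chunking here are illustrative, not taken from the notebook):

import torch

# Accumulating gradients of `loss / k` over k micro-batches reproduces the
# gradient of one step over the full batch, which is why the removed loop
# scales the loss by 1 / gradient_accumulation_steps before backward().
torch.manual_seed(0)
w = torch.zeros(3, requires_grad=True)
data = torch.randn(4, 3)

# One step over the full batch.
(data @ w - 1).pow(2).mean().backward()
grad_full = w.grad.clone()

# Two micro-batches, each loss scaled by 1/2, gradients accumulated.
w.grad = None
for chunk in data.chunk(2):
    ((chunk @ w - 1).pow(2).mean() / 2).backward()

print(torch.allclose(w.grad, grad_full))  # True
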
| 59 |     ").to(device)\n"
| 60 |    ]
| 61 |   },
| 62 |   {
| 63 |    "cell_type": "markdown",
| 64 |    "id": "4810af1a",
⋮
| 164 |    "print(f\"\\nGenerated: {tokenizer.decode(generated_tokens)}\")\n",
| 165 |    "print(\"=\" * 80)\n"
| 166 |   ]
| 167 |   }
| 168 |  ],
| 169 |  "metadata": {
| 170 | +  "colab": {
| 171 | +   "provenance": []
| 172 | +  },
| 173 |   "kernelspec": {
| 174 |    "display_name": ".venv",
| 175 |    "language": "python",
⋮
| 186 |    "nbconvert_exporter": "python",
| 187 |    "pygments_lexer": "ipython3",
| 188 |    "version": "3.10.18"
| 189 |   }
| 190 |  },
| 191 |  "nbformat": 4,
| 192 |  "nbformat_minor": 5
| 193 | + }
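
One detail surfaces in the KeyboardInterrupt traceback stripped from sft.ipynb above: the interrupted call sits in `NanoChatRMSNorm._norm`, whose body is visible in the frame itself, `x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)`, wrapped by a `forward` that upcasts to float32 and casts back. A standalone sketch of that normalization (the `eps` value here is an assumed placeholder; the model reads it from its config):

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same computation as NanoChatRMSNorm._norm in the traceback:
    # scale each vector by the reciprocal of its root mean square.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 8, dtype=torch.bfloat16)
# NanoChatRMSNorm.forward normalizes in float32 for stability, then
# casts back to the input dtype: self._norm(x.float()).type_as(x)
y = rms_norm(x.float()).type_as(x)
print(y.dtype, y.shape)  # torch.bfloat16 torch.Size([2, 8])
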