burtenshaw committed

Commit a224632 · 1 Parent(s): e3af5c6

improve structure and layout
app/src/content/article.mdx CHANGED
@@ -570,8 +570,6 @@ In `modular_nanochat.py`, we don't need to write this logic at all. As seen in t

It's very clear that Andrej Karpathy's implementation offers 10 times more to learn from than the transformers version, which inherits almost entirely from existing models and features. That said, we can still take more away from the inherited modular modeling implementation. Models like Llama, Llama4, Gemma2, Qwen3, and CLIP are all reused to create a genuinely canonical implementation of a transformer.

- # Hands-on Tutorial
-
Ok. Let's cut the philosophy and see what we can do with `nanochat` in transformers.

<Inference />
app/src/content/chapters/inference.mdx CHANGED
@@ -1,4 +1,10 @@
- ## Inference on `nano` in Transformers
+ ## Example 1: Inference on nanochat in Transformers
+
+ <Sidenote>
+
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/#fileId=https://huggingface.co/datasets/nanochat-students/notebooks/blob/main/inference.ipynb)
+
+ </Sidenote>

The first bonus tutorial will help you do basic inference in `transformers`:

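In code, the basic flow of that tutorial looks roughly like this (a minimal sketch, reusing the model id and revision from the `inference.ipynb` added in this commit):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "karpathy/nanochat-d32"
revision = "refs/pr/1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# bfloat16 on GPU, float32 on CPU, as in the notebook
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
).to(device)

# Plain autoregressive completion
inputs = tokenizer("The Eiffel Tower stands in Paris and", return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(outputs[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```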
app/src/content/chapters/sft.mdx CHANGED
@@ -1,7 +1,7 @@
import Sidenote from '../../components/Sidenote.astro'
import Note from '../../components/Note.astro'

- # [BONUS 2] Supervised Fine-tuning in `torch`
+ ## Example 2: Supervised Fine-tuning in torch

<Sidenote>

@@ -19,7 +19,7 @@ In this tutorial, we'll fine-tune the NanoChat model using pure PyTorch, giving

</Note>

- ## Import model and tokenizer
+ ### Import model and tokenizer

<Sidenote>

@@ -51,7 +51,7 @@ model = AutoModelForCausalLM.from_pretrained(

We use `bfloat16` precision on GPU to reduce memory usage while maintaining training stability. On CPU, we fall back to `float32` for compatibility.

- ## Setup LoRA
+ ### Setup LoRA

<Sidenote>

@@ -87,7 +87,7 @@ trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627

With LoRA, we're only training **0.06%** of the model's parameters—just over 1 million weights instead of 1.8 billion. This makes fine-tuning feasible on consumer hardware.

- ## Demo the model
+ ### Demo the model

<Sidenote>

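For reference, the LoRA setup those numbers come from (it appears verbatim in the `sft.ipynb` diff further down) is:

```python
from peft import LoraConfig, get_peft_model

# Rank-1 adapters on every attention and MLP projection: tiny, but enough
# to demonstrate the workflow on consumer hardware.
lora_config = LoraConfig(
    r=1,
    lora_alpha=2,
    lora_dropout=0.00,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627
```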
@@ -179,7 +179,7 @@ Generated: The capital of France is Paris.<|assistant_end|>

Notice the special tokens: `<|bos|>`, `<|user_start|>`, `<|assistant_start|>`, etc. These delimiters help the model understand conversation structure.

- ## Dataset
+ ### Dataset

<Sidenote>

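Those delimiters come from the tokenizer's chat template; a sketch based on the notebook's TEST 2 cell (assuming the `tokenizer` and `device` from the setup above):

```python
conversation = [
    {"role": "user", "content": "What is the capital of France?"},
]

# apply_chat_template wraps each turn in the special delimiter tokens
inputs = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
).to(device)

print(tokenizer.decode(inputs["input_ids"][0]))
# <|bos|><|user_start|>What is the capital of France?<|user_end|><|assistant_start|>
```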
@@ -196,8 +196,6 @@ train_dataset = splits["train"]
eval_dataset = splits["test"]
```

- ### Process the Dataset
-
<Sidenote>

The [datasets map](https://huggingface.co/docs/datasets/process#map) function applies transformations efficiently with caching and multiprocessing support.
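Concretely, the mapping step looks like this (a sketch of the notebook's `format_example`: each conversation is rendered with the chat template once, up front, so training only sees token ids):

```python
max_length = 2048

def format_example(example):
    # Tokenize the full conversation; no generation prompt since we train on it
    formatted = tokenizer.apply_chat_template(
        example["messages"],
        add_generation_prompt=False,
        truncation=True,
        max_length=max_length,
        return_dict=True,
        return_tensors="pt",
    )
    return {
        "input_ids": formatted["input_ids"][0].tolist(),
        "attention_mask": formatted["attention_mask"][0].tolist(),
    }

train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)
```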
@@ -236,31 +234,6 @@ eval_dataset = eval_dataset.select(range(min(len(eval_dataset), max_eval_example
eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)
```

- ## Training Configuration
-
- These hyperparameters control the training dynamics. We use conservative values that work well across different hardware:
-
- ```python
- train_batch_size = 2
- eval_batch_size = 2
- num_epochs = 1
- gradient_accumulation_steps = 4
- learning_rate = 1e-5
- weight_decay = 0.0
- warmup_ratio = 0.03
- logging_frequency = 10
- ```
-
- <Sidenote>
-
- **Gradient accumulation** simulates larger batch sizes by accumulating gradients over multiple forward passes before updating weights. Effective batch size = `train_batch_size × gradient_accumulation_steps` = 8.
-
- </Sidenote>
-
- Key configuration choices include using a low learning rate (`1e-5`), as LoRA generally requires smaller learning rates given that the base model weights are kept frozen. Additionally, gradient accumulation is employed to enable larger effective batch sizes, which helps when training on GPUs with limited memory.
-
- ## Create a DataLoader
-
<Sidenote>

PyTorch's [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) handles batching, shuffling, and parallel data loading automatically.
@@ -288,7 +261,30 @@ eval_loader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False

Setting padding tokens to `-100` in labels tells PyTorch's cross-entropy loss to ignore them—we don't want to penalize the model for not predicting padding.

- ## Optimizer
+ ## Training
+
+ These hyperparameters control the training dynamics. We use conservative values that work well across different hardware:
+
+ ```python
+ train_batch_size = 2
+ eval_batch_size = 2
+ num_epochs = 1
+ gradient_accumulation_steps = 4
+ learning_rate = 1e-5
+ weight_decay = 0.0
+ warmup_ratio = 0.03
+ logging_frequency = 10
+ ```
+
+ <Sidenote>
+
+ **Gradient accumulation** simulates larger batch sizes by accumulating gradients over multiple forward passes before updating weights. Effective batch size = `train_batch_size × gradient_accumulation_steps` = 8.
+
+ </Sidenote>
+
+ Key configuration choices include using a low learning rate (`1e-5`), as LoRA generally requires smaller learning rates given that the base model weights are kept frozen. Additionally, gradient accumulation is employed to enable larger effective batch sizes, which helps when training on GPUs with limited memory.
+
+ ### Optimizer

<Sidenote>

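In the training loop itself, accumulation amounts to scaling and summing losses over several micro-batches before each optimizer step. A minimal sketch, following the loop in the accompanying `sft.ipynb` (names as defined above):

```python
for step, batch in enumerate(train_loader, start=1):
    batch = {key: value.to(device) for key, value in batch.items()}
    outputs = model(**batch)

    # Scale so the accumulated gradients average over the effective batch
    loss = outputs.loss / gradient_accumulation_steps
    loss.backward()  # gradients accumulate in .grad across micro-batches

    if step % gradient_accumulation_steps == 0:
        optimizer.step()       # one update per 4 micro-batches => effective batch of 8
        scheduler.step()
        optimizer.zero_grad()
```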
@@ -306,7 +302,7 @@ optimizer = torch.optim.AdamW(
)
```

- ## Learning Rate Scheduler
+ ### Learning Rate Scheduler

<Sidenote>

@@ -323,8 +319,6 @@ warmup_steps = max(1, int(max_train_steps * warmup_ratio))
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, max_train_steps)
```

- ## The Training Loop
-
<Sidenote>

For distributed training across multiple GPUs, consider [Accelerate](https://huggingface.co/docs/accelerate/index), which wraps this loop with minimal code changes.
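For the curious, the adaptation the sidenote hints at looks roughly like this. A hedged sketch, not part of this commit; variable names follow the tutorial, and the calls are standard Accelerate API:

```python
from accelerate import Accelerator

# Accelerate owns device placement, gradient accumulation, and (when launched
# with `accelerate launch`) distribution across processes and GPUs.
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, optimizer, train_loader, scheduler = accelerator.prepare(
    model, optimizer, train_loader, scheduler
)

for batch in train_loader:
    with accelerator.accumulate(model):
        outputs = model(**batch)
        accelerator.backward(outputs.loss)  # replaces loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```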
@@ -398,3 +392,63 @@ step=00040 | loss=1.7935 | lr=5.33e-06
step=00050 | loss=1.8029 | lr=6.67e-06
...
```
+
+ ## Example 3: Fine-tuning with TRL
+
+ <Sidenote>
+
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/#fileId=https://github.com/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb)
+
+ </Sidenote>
+
+ Finally, we can implement the training loop above with TRL, which simplifies the code considerably and abstracts away a lot of the complexity (trading away some of the educational value), but brings all the bells and whistles of a production-ready solution.
+
+ We can define the training arguments like this:
+
+ ```python
+ from trl import SFTConfig
+
+ training_args = SFTConfig(
+     per_device_train_batch_size=1,  # Batch size per GPU
+     gradient_accumulation_steps=4,  # Gradients are accumulated over multiple steps → effective batch size = 1 × 4 = 4
+     warmup_steps=5,
+     # num_train_epochs=1,           # Number of full dataset passes. For shorter training, use `max_steps` instead (this case)
+     max_steps=30,
+     learning_rate=2e-4,             # Learning rate for the optimizer
+     optim="paged_adamw_8bit",       # Optimizer
+
+     # Logging / reporting
+     logging_steps=1,                # Log training metrics every N steps
+     report_to="trackio",            # Experiment tracking tool
+     trackio_space_id=output_dir,    # HF Space where the experiment tracking will be saved
+     output_dir=output_dir,          # Where to save model checkpoints and logs
+
+     max_length=1024,                # Maximum input sequence length
+     use_liger_kernel=True,          # Enable Liger kernel optimizations for faster training
+     activation_offloading=True,     # Offload activations to CPU to reduce GPU memory usage
+     gradient_checkpointing=True,    # Save memory by re-computing activations during backpropagation
+
+     # Hub integration
+     push_to_hub=True,               # Automatically push the trained model to the Hugging Face Hub
+     # The model will be saved under your Hub account in the repository named `output_dir`
+ )
+ ```
+
+ Then we can create the trainer object, and TRL will deal with data loading, batching, and the training loop:
+
+ ```python
+ from trl import SFTTrainer
+
+ trainer = SFTTrainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+     peft_config=peft_config,
+ )
+ ```
+
+ Finally, we kick off training:
+
+ ```python
+ trainer_stats = trainer.train()
+ ```
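As a follow-up (a sketch, not part of this commit): with `push_to_hub=True` the trained adapter is uploaded to the Hub under `output_dir` when training finishes, and you can also save or push explicitly with the trainer's standard methods:

```python
trainer.save_model(output_dir)  # write the final model/adapter locally
trainer.push_to_hub()           # upload it to your Hugging Face Hub account
```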
inference.ipynb ADDED
@@ -0,0 +1,187 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "b7eb261b",
6
+ "metadata": {
7
+ "id": "b7eb261b"
8
+ },
9
+ "source": [
10
+ "# NanoChat Easy - Inference\n"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "id": "8b8a04a8",
16
+ "metadata": {
17
+ "id": "8b8a04a8"
18
+ },
19
+ "source": [
20
+ "## Import model and tokenizer\n"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "3e48247c",
27
+ "metadata": {
28
+ "id": "3e48247c",
29
+ "outputId": "882fcf01-34fb-4123-e84c-deefdf477814"
30
+ },
31
+ "outputs": [
32
+ {
33
+ "name": "stderr",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
37
+ " from .autonotebook import tqdm as notebook_tqdm\n",
38
+ "`torch_dtype` is deprecated! Use `dtype` instead!\n"
39
+ ]
40
+ }
41
+ ],
42
+ "source": [
43
+ "import torch\n",
44
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
45
+ "\n",
46
+ "\n",
47
+ "model_id = \"karpathy/nanochat-d32\"\n",
48
+ "revision = \"refs/pr/1\"\n",
49
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
50
+ "\n",
51
+ "\n",
52
+ "tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)\n",
53
+ "model = AutoModelForCausalLM.from_pretrained(\n",
54
+ " model_id,\n",
55
+ " revision=revision,\n",
56
+ " torch_dtype=torch.bfloat16 if device.type == \"cuda\" else torch.float32,\n",
57
+ ").to(device)\n"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "markdown",
62
+ "id": "4810af1a",
63
+ "metadata": {
64
+ "id": "4810af1a"
65
+ },
66
+ "source": [
67
+ "## Demo the model\n"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "id": "b3e81aa9",
74
+ "metadata": {
75
+ "id": "b3e81aa9",
76
+ "outputId": "1cde7e69-7ff1-4bfe-aa9f-9ded20249d82"
77
+ },
78
+ "outputs": [
79
+ {
80
+ "name": "stdout",
81
+ "output_type": "stream",
82
+ "text": [
83
+ "================================================================================\n",
84
+ "TEST 1: Plain Autoregressive Prompt\n",
85
+ "================================================================================\n",
86
+ "Prompt: The Eiffel Tower stands in Paris and\n",
87
+ "\n",
88
+ "Generated: is one of the most famous landmarks in the world. It is located on the Champ de Mars in the heart of the city. The tower was built for the 1889 World's Fair. It was designed by the French engineer Gustave Eiffel and took 2 years to build. The Eiffel Tower stands 324 meters\n",
89
+ "================================================================================\n"
90
+ ]
91
+ }
92
+ ],
93
+ "source": [
94
+ "print(\"=\" * 80)\n",
95
+ "print(\"TEST 1: Plain Autoregressive Prompt\")\n",
96
+ "print(\"=\" * 80)\n",
97
+ "prompt = \"The Eiffel Tower stands in Paris and\"\n",
98
+ "test_inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
99
+ "\n",
100
+ "\n",
101
+ "with torch.no_grad():\n",
102
+ " test_outputs = model.generate(\n",
103
+ " **test_inputs,\n",
104
+ " max_new_tokens=64,\n",
105
+ " do_sample=False,\n",
106
+ " pad_token_id=tokenizer.pad_token_id,\n",
107
+ " )\n",
108
+ "\n",
109
+ "generated_tokens = test_outputs[0, test_inputs[\"input_ids\"].shape[1] :]\n",
110
+ "print(f\"Prompt: {prompt}\")\n",
111
+ "print(f\"\\nGenerated: {tokenizer.decode(generated_tokens, skip_special_tokens=True)}\")\n",
112
+ "print(\"=\" * 80)\n"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "id": "8e7b275c",
119
+ "metadata": {
120
+ "id": "8e7b275c",
121
+ "outputId": "719e986e-61b4-4fd5-db15-4a9ef8f97396"
122
+ },
123
+ "outputs": [
124
+ {
125
+ "name": "stdout",
126
+ "output_type": "stream",
127
+ "text": [
128
+ "================================================================================\n",
129
+ "TEST 2: Chat Template\n",
130
+ "================================================================================\n",
131
+ "Formatted prompt: <|bos|><|user_start|>What is the capital of France?<|user_end|><|assistant_start|>\n",
132
+ "Input IDs: [65527, 65528, 1442, 309, 261, 3429, 281, 4215, 63, 65529, 65530]\n",
133
+ "\n",
134
+ "Generated: The capital of France is Paris.<|assistant_end|>\n",
135
+ "================================================================================\n"
136
+ ]
137
+ }
138
+ ],
139
+ "source": [
140
+ "print(\"=\" * 80)\n",
141
+ "print(\"TEST 2: Chat Template\")\n",
142
+ "print(\"=\" * 80)\n",
143
+ "conversation = [\n",
144
+ " {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
145
+ "]\n",
146
+ "\n",
147
+ "inputs = tokenizer.apply_chat_template(\n",
148
+ " conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors=\"pt\"\n",
149
+ ").to(device)\n",
150
+ "\n",
151
+ "print(f\"Formatted prompt: {tokenizer.decode(inputs['input_ids'][0])}\")\n",
152
+ "print(f\"Input IDs: {inputs['input_ids'][0].tolist()}\")\n",
153
+ "\n",
154
+ "with torch.no_grad():\n",
155
+ " outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)\n",
156
+ "\n",
157
+ "generated_tokens = outputs[0, inputs[\"input_ids\"].shape[1] :]\n",
158
+ "print(f\"\\nGenerated: {tokenizer.decode(generated_tokens)}\")\n",
159
+ "print(\"=\" * 80)\n"
160
+ ]
161
+ }
162
+ ],
163
+ "metadata": {
164
+ "colab": {
165
+ "provenance": []
166
+ },
167
+ "kernelspec": {
168
+ "display_name": ".venv",
169
+ "language": "python",
170
+ "name": "python3"
171
+ },
172
+ "language_info": {
173
+ "codemirror_mode": {
174
+ "name": "ipython",
175
+ "version": 3
176
+ },
177
+ "file_extension": ".py",
178
+ "mimetype": "text/x-python",
179
+ "name": "python",
180
+ "nbconvert_exporter": "python",
181
+ "pygments_lexer": "ipython3",
182
+ "version": "3.10.18"
183
+ }
184
+ },
185
+ "nbformat": 4,
186
+ "nbformat_minor": 5
187
+ }
sft.ipynb CHANGED
@@ -59,48 +59,6 @@
59
  ").to(device)\n"
60
  ]
61
  },
62
- {
63
- "cell_type": "markdown",
64
- "id": "c9a9c0a4",
65
- "metadata": {
66
- "id": "c9a9c0a4"
67
- },
68
- "source": [
69
- "## Setup LoRA\n"
70
- ]
71
- },
72
- {
73
- "cell_type": "code",
74
- "execution_count": null,
75
- "id": "dd9a698a",
76
- "metadata": {
77
- "id": "dd9a698a",
78
- "outputId": "0aae9ecc-7af9-436e-a95b-a4cd023997fd"
79
- },
80
- "outputs": [
81
- {
82
- "name": "stdout",
83
- "output_type": "stream",
84
- "text": [
85
- "trainable params: 1,179,648 || all params: 1,880,227,840 || trainable%: 0.0627\n"
86
- ]
87
- }
88
- ],
89
- "source": [
90
- "from peft import LoraConfig, get_peft_model\n",
91
- "\n",
92
- "lora_config = LoraConfig(\n",
93
- " r=1,\n",
94
- " lora_alpha=2,\n",
95
- " lora_dropout=0.00,\n",
96
- " task_type=\"CAUSAL_LM\",\n",
97
- " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"fc1\", \"fc2\"]\n",
98
- ")\n",
99
- "\n",
100
- "model = get_peft_model(model, lora_config)\n",
101
- "model.print_trainable_parameters()\n"
102
- ]
103
- },
104
  {
105
  "cell_type": "markdown",
106
  "id": "4810af1a",
@@ -206,365 +164,12 @@
206
  "print(f\"\\nGenerated: {tokenizer.decode(generated_tokens)}\")\n",
207
  "print(\"=\" * 80)\n"
208
  ]
209
- },
210
- {
211
- "cell_type": "markdown",
212
- "id": "44cb321a",
213
- "metadata": {
214
- "id": "44cb321a"
215
- },
216
- "source": [
217
- "## Dataset\n"
218
- ]
219
- },
220
- {
221
- "cell_type": "code",
222
- "execution_count": null,
223
- "id": "e1a75c14",
224
- "metadata": {
225
- "id": "e1a75c14"
226
- },
227
- "outputs": [],
228
- "source": [
229
- "raw_dataset = load_dataset(\"HuggingFaceTB/smoltalk2\", \"SFT\", split=\"OpenThoughts3_1.2M_think\")\n",
230
- "splits = raw_dataset.train_test_split(test_size=0.1, seed=42)\n",
231
- "train_dataset = splits[\"train\"]\n",
232
- "eval_dataset = splits[\"test\"]\n"
233
- ]
234
- },
235
- {
236
- "cell_type": "markdown",
237
- "id": "8b29399d",
238
- "metadata": {
239
- "id": "8b29399d"
240
- },
241
- "source": [
242
- "### Process the Dataset\n"
243
- ]
244
- },
245
- {
246
- "cell_type": "code",
247
- "execution_count": null,
248
- "id": "451542b4",
249
- "metadata": {
250
- "id": "451542b4",
251
- "outputId": "caa727dd-f9d8-4c67-d193-79bcc0836b49"
252
- },
253
- "outputs": [
254
- {
255
- "name": "stderr",
256
- "output_type": "stream",
257
- "text": [
258
- "Map: 0%| | 0/20000 [00:00<?, ? examples/s]"
259
- ]
260
- },
261
- {
262
- "name": "stderr",
263
- "output_type": "stream",
264
- "text": [
265
- "Map: 100%|██████████| 20000/20000 [06:27<00:00, 51.68 examples/s]\n",
266
- "Map: 100%|██████████| 1000/1000 [00:19<00:00, 52.12 examples/s]\n"
267
- ]
268
- }
269
- ],
270
- "source": [
271
- "max_length = 2048\n",
272
- "max_train_examples = 20000\n",
273
- "max_eval_examples = 1000\n",
274
- "\n",
275
- "def format_example(example):\n",
276
- " formatted = tokenizer.apply_chat_template(\n",
277
- " example[\"messages\"],\n",
278
- " add_generation_prompt=False,\n",
279
- " truncation=True,\n",
280
- " max_length=max_length,\n",
281
- " padding=False,\n",
282
- " return_dict=True,\n",
283
- " return_tensors=\"pt\",\n",
284
- " )\n",
285
- " return {\n",
286
- " \"input_ids\": formatted[\"input_ids\"][0].tolist(),\n",
287
- " \"attention_mask\": formatted[\"attention_mask\"][0].tolist(),\n",
288
- " }\n",
289
- "\n",
290
- "\n",
291
- "if max_train_examples is not None:\n",
292
- " train_dataset = train_dataset.select(range(min(len(train_dataset), max_train_examples)))\n",
293
- " train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)\n",
294
- "else:\n",
295
- " train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)\n",
296
- "\n",
297
- "if max_eval_examples is not None:\n",
298
- " eval_dataset = eval_dataset.select(range(min(len(eval_dataset), max_eval_examples)))\n",
299
- " eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)\n",
300
- "else:\n",
301
- " eval_dataset = eval_dataset.map(format_example, remove_columns=eval_dataset.column_names)\n"
302
- ]
303
- },
304
- {
305
- "cell_type": "markdown",
306
- "id": "ecd33dd7",
307
- "metadata": {
308
- "id": "ecd33dd7"
309
- },
310
- "source": [
311
- "## Training Configuration"
312
- ]
313
- },
314
- {
315
- "cell_type": "code",
316
- "execution_count": null,
317
- "id": "f9d837ee",
318
- "metadata": {
319
- "id": "f9d837ee"
320
- },
321
- "outputs": [],
322
- "source": [
323
- "train_batch_size = 2\n",
324
- "eval_batch_size = 2\n",
325
- "num_epochs = 1\n",
326
- "gradient_accumulation_steps = 4\n",
327
- "learning_rate = 1e-5\n",
328
- "weight_decay = 0.0\n",
329
- "warmup_ratio = 0.03\n",
330
- "logging_frequency = 10"
331
- ]
332
- },
333
- {
334
- "cell_type": "markdown",
335
- "id": "1cf11e96",
336
- "metadata": {
337
- "id": "1cf11e96"
338
- },
339
- "source": [
340
- "## Create a `DataLoader` 👴"
341
- ]
342
- },
343
- {
344
- "cell_type": "code",
345
- "execution_count": null,
346
- "id": "1bc4fa24",
347
- "metadata": {
348
- "id": "1bc4fa24"
349
- },
350
- "outputs": [],
351
- "source": [
352
- "def collate_fn(batch):\n",
353
- " batch_dict = {\n",
354
- " \"input_ids\": [record[\"input_ids\"] for record in batch],\n",
355
- " \"attention_mask\": [record[\"attention_mask\"] for record in batch],\n",
356
- " }\n",
357
- " padded = tokenizer.pad(batch_dict, padding=True, return_tensors=\"pt\")\n",
358
- " labels = padded[\"input_ids\"].clone()\n",
359
- " labels[padded[\"attention_mask\"] == 0] = -100\n",
360
- " padded[\"labels\"] = labels\n",
361
- " return padded\n",
362
- "\n",
363
- "\n",
364
- "TrainLoader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=collate_fn)\n",
365
- "EvalLoader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn)\n"
366
- ]
367
- },
368
- {
369
- "cell_type": "markdown",
370
- "id": "f5965d1b",
371
- "metadata": {
372
- "id": "f5965d1b"
373
- },
374
- "source": [
375
- "## Optimizer"
376
- ]
377
- },
378
- {
379
- "cell_type": "code",
380
- "execution_count": null,
381
- "id": "f57c7be2",
382
- "metadata": {
383
- "id": "f57c7be2"
384
- },
385
- "outputs": [],
386
- "source": [
387
- "optimizer = torch.optim.AdamW(\n",
388
- " model.parameters(),\n",
389
- " lr=learning_rate,\n",
390
- " weight_decay=weight_decay,\n",
391
- ")\n"
392
- ]
393
- },
394
- {
395
- "cell_type": "markdown",
396
- "id": "215f8782",
397
- "metadata": {
398
- "id": "215f8782"
399
- },
400
- "source": [
401
- "# Learning Rate Scheduler"
402
- ]
403
- },
404
- {
405
- "cell_type": "code",
406
- "execution_count": null,
407
- "id": "034e2903",
408
- "metadata": {
409
- "id": "034e2903"
410
- },
411
- "outputs": [],
412
- "source": [
413
- "num_update_steps_per_epoch = max(len(TrainLoader) // gradient_accumulation_steps, 1)\n",
414
- "max_train_steps = num_epochs * num_update_steps_per_epoch\n",
415
- "warmup_steps = max(1, int(max_train_steps * warmup_ratio))\n",
416
- "scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, max_train_steps)\n"
417
- ]
418
- },
419
- {
420
- "cell_type": "markdown",
421
- "id": "0f0090b6",
422
- "metadata": {
423
- "id": "0f0090b6"
424
- },
425
- "source": [
426
- "# The Training Loop"
427
- ]
428
- },
429
- {
430
- "cell_type": "code",
431
- "execution_count": null,
432
- "id": "1540e30a",
433
- "metadata": {
434
- "id": "1540e30a",
435
- "outputId": "747badd7-18df-441f-8026-7aa4f30c2fd7"
436
- },
437
- "outputs": [
438
- {
439
- "name": "stderr",
440
- "output_type": "stream",
441
- "text": [
442
- "You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
443
- ]
444
- },
445
- {
446
- "name": "stdout",
447
- "output_type": "stream",
448
- "text": [
449
- "Epoch 1/1\n",
450
- "step=00010 | loss=1.7586 | lr=1.33e-06\n",
451
- "step=00020 | loss=1.8188 | lr=2.67e-06\n",
452
- "step=00030 | loss=1.8235 | lr=4.00e-06\n",
453
- "step=00040 | loss=1.7935 | lr=5.33e-06\n",
454
- "step=00050 | loss=1.8029 | lr=6.67e-06\n",
455
- "step=00060 | loss=1.8433 | lr=8.00e-06\n",
456
- "step=00070 | loss=1.8616 | lr=9.33e-06\n",
457
- "step=00080 | loss=1.8238 | lr=9.98e-06\n",
458
- "step=00090 | loss=1.7774 | lr=9.94e-06\n",
459
- "step=00100 | loss=1.8081 | lr=9.90e-06\n",
460
- "step=00110 | loss=1.7437 | lr=9.86e-06\n",
461
- "step=00120 | loss=1.7830 | lr=9.81e-06\n",
462
- "step=00130 | loss=1.8064 | lr=9.77e-06\n",
463
- "step=00140 | loss=1.8541 | lr=9.73e-06\n",
464
- "step=00150 | loss=1.8301 | lr=9.69e-06\n",
465
- "step=00160 | loss=1.7725 | lr=9.65e-06\n",
466
- "step=00170 | loss=1.7635 | lr=9.61e-06\n",
467
- "step=00180 | loss=1.7963 | lr=9.57e-06\n",
468
- "step=00190 | loss=1.7563 | lr=9.53e-06\n",
469
- "step=00200 | loss=1.6950 | lr=9.48e-06\n",
470
- "step=00210 | loss=1.7680 | lr=9.44e-06\n",
471
- "step=00220 | loss=1.8906 | lr=9.40e-06\n",
472
- "step=00230 | loss=1.7120 | lr=9.36e-06\n",
473
- "step=00240 | loss=1.8390 | lr=9.32e-06\n",
474
- "step=00250 | loss=1.7180 | lr=9.28e-06\n",
475
- "step=00260 | loss=1.7709 | lr=9.24e-06\n",
476
- "step=00270 | loss=1.7598 | lr=9.20e-06\n",
477
- "step=00280 | loss=1.7981 | lr=9.15e-06\n",
478
- "step=00290 | loss=1.7540 | lr=9.11e-06\n",
479
- "step=00300 | loss=1.7695 | lr=9.07e-06\n",
480
- "step=00310 | loss=1.7468 | lr=9.03e-06\n"
481
- ]
482
- },
483
- {
484
- "ename": "KeyboardInterrupt",
485
- "evalue": "",
486
- "output_type": "error",
487
- "traceback": [
488
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
489
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
490
- "Cell \u001b[0;32mIn[14], line 11\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step, batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(TrainLoader, start\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m):\n\u001b[1;32m 10\u001b[0m batch \u001b[38;5;241m=\u001b[39m {key: value\u001b[38;5;241m.\u001b[39mto(device) \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m batch\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m---> 11\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m loss \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;241m/\u001b[39m gradient_accumulation_steps\n\u001b[1;32m 13\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n",
491
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1771\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1772\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1773\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
492
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1784\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1779\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1780\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1781\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1782\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1783\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1786\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1787\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
493
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/peft/peft_model.py:1850\u001b[0m, in \u001b[0;36mPeftModelForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)\u001b[0m\n\u001b[1;32m 1848\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_enable_peft_forward_hooks(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 1849\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m {k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mspecial_peft_forward_args}\n\u001b[0;32m-> 1850\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1851\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1852\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1853\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1854\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1855\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1856\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1857\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1858\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1859\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1861\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m _get_batch_size(input_ids, inputs_embeds)\n\u001b[1;32m 1862\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1863\u001b[0m \u001b[38;5;66;03m# concat prompt attention mask\u001b[39;00m\n",
494
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1771\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1772\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1773\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
495
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1784\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1779\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1780\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1781\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1782\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1783\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1786\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1787\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
496
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:222\u001b[0m, in \u001b[0;36mBaseTuner.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any):\n\u001b[0;32m--> 222\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
497
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/utils/generic.py:757\u001b[0m, in \u001b[0;36mcan_return_tuple.<locals>.wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 755\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_dict_passed \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 756\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict_passed\n\u001b[0;32m--> 757\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 758\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, \u001b[38;5;28mtuple\u001b[39m):\n\u001b[1;32m 759\u001b[0m output \u001b[38;5;241m=\u001b[39m output\u001b[38;5;241m.\u001b[39mto_tuple()\n",
498
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:474\u001b[0m, in \u001b[0;36mNanoChatForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, cache_position, logits_to_keep, **kwargs)\u001b[0m\n\u001b[1;32m 435\u001b[0m \u001b[38;5;129m@can_return_tuple\u001b[39m\n\u001b[1;32m 436\u001b[0m \u001b[38;5;129m@auto_docstring\u001b[39m\n\u001b[1;32m 437\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 448\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Unpack[TransformersKwargs],\n\u001b[1;32m 449\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m CausalLMOutputWithPast:\n\u001b[1;32m 450\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 451\u001b[0m \u001b[38;5;124;03m Example:\u001b[39;00m\n\u001b[1;32m 452\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 472\u001b[0m \u001b[38;5;124;03m >>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)\u001b[39;00m\n\u001b[1;32m 473\u001b[0m \u001b[38;5;124;03m ```\"\"\"\u001b[39;00m\n\u001b[0;32m--> 474\u001b[0m outputs: BaseModelOutputWithPast \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 475\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 476\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 477\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 478\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 479\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 480\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 481\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 482\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 483\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 485\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlast_hidden_state\n\u001b[1;32m 486\u001b[0m slice_indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mslice\u001b[39m(\u001b[38;5;241m-\u001b[39mlogits_to_keep, \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(logits_to_keep, \u001b[38;5;28mint\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m logits_to_keep\n",
499
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1771\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1772\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1773\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
500
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1784\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1779\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1780\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1781\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1782\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1783\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1786\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1787\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
501
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/utils/generic.py:927\u001b[0m, in \u001b[0;36mcheck_model_inputs.<locals>.wrapped_fn.<locals>.wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 924\u001b[0m monkey_patched_layers\u001b[38;5;241m.\u001b[39mappend((module, original_forward))\n\u001b[1;32m 926\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 927\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 928\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m original_exception:\n\u001b[1;32m 929\u001b[0m \u001b[38;5;66;03m# If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.\u001b[39;00m\n\u001b[1;32m 930\u001b[0m \u001b[38;5;66;03m# Get a TypeError even after removing the recordable kwargs -> re-raise the original exception\u001b[39;00m\n\u001b[1;32m 931\u001b[0m \u001b[38;5;66;03m# Otherwise -> we're probably missing `**kwargs` in the decorated function\u001b[39;00m\n\u001b[1;32m 932\u001b[0m kwargs_without_recordable \u001b[38;5;241m=\u001b[39m {k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m recordable_keys}\n",
502
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:401\u001b[0m, in \u001b[0;36mNanoChatModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, cache_position, **kwargs)\u001b[0m\n\u001b[1;32m 398\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minitial_norm(hidden_states)\n\u001b[1;32m 400\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m decoder_layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m--> 401\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 402\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 403\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcausal_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 404\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 405\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 406\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 407\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_position\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_position\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 408\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_embeddings\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_embeddings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 409\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 410\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 412\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm(hidden_states)\n\u001b[1;32m 414\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m BaseModelOutputWithPast(\n\u001b[1;32m 415\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mhidden_states,\n\u001b[1;32m 416\u001b[0m past_key_values\u001b[38;5;241m=\u001b[39mpast_key_values \u001b[38;5;28;01mif\u001b[39;00m use_cache \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 417\u001b[0m )\n",
503
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/modeling_layers.py:94\u001b[0m, in \u001b[0;36mGradientCheckpointingLayer.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 91\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning_once(message)\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(partial(\u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs), \u001b[38;5;241m*\u001b[39margs)\n\u001b[0;32m---> 94\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__call__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
504
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1771\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1772\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1773\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
505
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1784\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1779\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1780\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1781\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1782\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1783\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1786\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1787\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
506
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:279\u001b[0m, in \u001b[0;36mNanoChatDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_values, use_cache, cache_position, position_embeddings, **kwargs)\u001b[0m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 268\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 269\u001b[0m hidden_states: torch\u001b[38;5;241m.\u001b[39mTensor,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 276\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Unpack[TransformersKwargs],\n\u001b[1;32m 277\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[1;32m 278\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[0;32m--> 279\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_layernorm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[1;32m 281\u001b[0m hidden_states, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mself_attn(\n\u001b[1;32m 282\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mhidden_states,\n\u001b[1;32m 283\u001b[0m attention_mask\u001b[38;5;241m=\u001b[39mattention_mask,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 290\u001b[0m )\n",
507
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1771\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1772\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1773\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
508
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/nanochat_/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1784\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1779\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1780\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1781\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1782\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1783\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1786\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1787\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
509
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:53\u001b[0m, in \u001b[0;36mNanoChatRMSNorm.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 53\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_norm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mtype_as(x)\n",
510
- "File \u001b[0;32m/fsx/benjamin_burtenshaw/transformers/src/transformers/models/nanochat/modeling_nanochat.py:50\u001b[0m, in \u001b[0;36mNanoChatRMSNorm._norm\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_norm\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 50\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mrsqrt(\u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeepdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39meps)\n",
511
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
512
- ]
513
- }
514
- ],
515
- "source": [
516
- "\n",
517
- "model.train()\n",
518
- "global_step = 0\n",
519
- "running_loss = 0.0\n",
520
- "running_steps = 0\n",
521
- "\n",
522
- "for epoch in range(num_epochs):\n",
523
- " print(f\"Epoch {epoch + 1}/{num_epochs}\")\n",
524
- " optimizer.zero_grad(set_to_none=True)\n",
525
- " for step, batch in enumerate(TrainLoader, start=1):\n",
526
- " batch = {key: value.to(device) for key, value in batch.items()}\n",
527
- " outputs = model(**batch)\n",
528
- " loss = outputs.loss / gradient_accumulation_steps\n",
529
- " loss.backward()\n",
530
- "\n",
531
- " running_loss += outputs.loss.float().item()\n",
532
- " running_steps += 1\n",
533
- "\n",
534
- " if step % gradient_accumulation_steps == 0 or step == len(TrainLoader):\n",
535
- " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
536
- " optimizer.step()\n",
537
- " scheduler.step()\n",
538
- " optimizer.zero_grad(set_to_none=True)\n",
539
- " global_step += 1\n",
540
- "\n",
541
- " if global_step % logging_frequency == 0:\n",
542
- " current_lr = scheduler.get_last_lr()[0]\n",
543
- " mean_loss = running_loss / running_steps\n",
544
- " print(f\"step={global_step:05d} | loss={mean_loss:.4f} | lr={current_lr:.2e}\")\n",
545
- " running_loss = 0.0\n",
546
- " running_steps = 0\n",
547
- "\n",
548
- " train_loss = running_loss / running_steps if running_steps > 0 else float(\"nan\")\n",
549
- " print(f\"Training loss after epoch {epoch + 1}: {train_loss:.4f}\")\n",
550
- "\n",
551
- " model.eval()\n",
552
- " losses = []\n",
553
- " with torch.no_grad():\n",
554
- " for _, batch in enumerate(EvalLoader, start=1):\n",
555
- " batch = {key: value.to(device) for key, value in batch.items()}\n",
556
- " loss = model(**batch).loss\n",
557
- " losses.append(loss.float().item())\n",
558
- " model.train()\n",
559
- " val_loss = sum(losses) / len(losses) if losses else float(\"nan\")\n",
560
- "\n",
561
- " print(f\"Validation loss after epoch {epoch + 1}: {val_loss:.4f}\")\n",
562
- "\n",
563
- "print(\"Training complete.\")\n"
564
- ]
565
  }
566
  ],
567
  "metadata": {
568
  "kernelspec": {
569
  "display_name": ".venv",
570
  "language": "python",
@@ -581,11 +186,8 @@
581
  "nbconvert_exporter": "python",
582
  "pygments_lexer": "ipython3",
583
  "version": "3.10.18"
584
- },
585
- "colab": {
586
- "provenance": []
587
  }
588
  },
589
  "nbformat": 4,
590
  "nbformat_minor": 5
591
- }
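
The escaped JSON above is hard to read, so here is the removed training cell reconstructed as plain Python and wrapped in a function. The loop body mirrors the cell: gradient accumulation with loss scaling, gradient clipping, an LR scheduler stepped per optimizer step, periodic logging, and a per-epoch validation pass. Only the function signature, the lowercase argument names, and the default hyperparameter values are mine; the notebook binds `TrainLoader`, `EvalLoader`, `optimizer`, `scheduler`, and the hyperparameters in earlier cells.

```python
import torch


def train(model, train_loader, eval_loader, optimizer, scheduler, device,
          num_epochs=1, gradient_accumulation_steps=8, logging_frequency=10):
    """Gradient-accumulation fine-tuning loop, reconstructed from the removed
    notebook cell above. Default hyperparameter values are illustrative."""
    model.train()
    global_step = 0
    running_loss, running_steps = 0.0, 0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        optimizer.zero_grad(set_to_none=True)
        for step, batch in enumerate(train_loader, start=1):
            batch = {key: value.to(device) for key, value in batch.items()}
            outputs = model(**batch)
            # Scale the loss so accumulated gradients average over the window.
            (outputs.loss / gradient_accumulation_steps).backward()

            running_loss += outputs.loss.float().item()
            running_steps += 1

            # Step once per accumulation window (and on the final partial window).
            if step % gradient_accumulation_steps == 0 or step == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                if global_step % logging_frequency == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    mean_loss = running_loss / running_steps
                    print(f"step={global_step:05d} | loss={mean_loss:.4f} | lr={current_lr:.2e}")
                    running_loss, running_steps = 0.0, 0

        # Per-epoch validation pass with gradients disabled.
        model.eval()
        losses = []
        with torch.no_grad():
            for batch in eval_loader:
                batch = {key: value.to(device) for key, value in batch.items()}
                losses.append(model(**batch).loss.float().item())
        model.train()
        val_loss = sum(losses) / len(losses) if losses else float("nan")
        print(f"Validation loss after epoch {epoch + 1}: {val_loss:.4f}")

    print("Training complete.")
```
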
 
59
  ").to(device)\n"
60
  ]
61
  },
62
  {
63
  "cell_type": "markdown",
64
  "id": "4810af1a",
 
164
  "print(f\"\\nGenerated: {tokenizer.decode(generated_tokens)}\")\n",
165
  "print(\"=\" * 80)\n"
166
  ]
167
  }
168
  ],
169
  "metadata": {
170
+ "colab": {
171
+ "provenance": []
172
+ },
173
  "kernelspec": {
174
  "display_name": ".venv",
175
  "language": "python",
 
186
  "nbconvert_exporter": "python",
187
  "pygments_lexer": "ipython3",
188
  "version": "3.10.18"
189
  }
190
  },
191
  "nbformat": 4,
192
  "nbformat_minor": 5
193
+ }
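
The second notebook's generation cell ends with the `tokenizer.decode(generated_tokens)` print visible in the diff above. As a rough, self-contained sketch of what such a cell does: the checkpoint id is a placeholder you would replace with the actual repo id, and the prompt and `max_new_tokens` are illustrative choices.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/nanochat-checkpoint"  # placeholder: substitute the real repo id
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# Tokenize a prompt and move the input ids to the model's device.
inputs = tokenizer("The capital of France is", return_tensors="pt").to(device)

# Greedy decoding; max_new_tokens is an illustrative budget.
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32, do_sample=False)

# Strip the prompt and decode only the newly generated tokens,
# mirroring the print statements in the notebook cell.
generated_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
print(f"\nGenerated: {tokenizer.decode(generated_tokens)}")
print("=" * 80)
```
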