NeuroBLAST-V3-SYNTH-EC-150000

⚠️ EXPERIMENTAL EARLY CHECKPOINT ⚠️

This is an Early Checkpoint (EC) of the NeuroBLAST V3 architecture, a novel hybrid model designed with a biologically inspired "cortical" structure.

This specific checkpoint (150k steps) represents the "pre-decay" phase of training. It has been trained on short contexts with a high learning rate and is intended for architectural evaluation and research purposes.

Model Details

  • Architecture: NeuroBLAST V3 (Custom Hybrid Architecture)
  • Checkpoint Step: 150,000
  • Parameters: 596,728,320
  • Num layers: 72
    • Sensory layers: 24
    • Associative layers: 32
    • Motor layers: 16
  • Hidden size: 512
  • Vocab size: 65538
  • Intermediate size: 3072
  • Num attention heads: 16
  • Num kv heads: 8
  • Head dim: 128
  • Tie word embeddings: False
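The dimensions above fit together as sketched below. The dataclass and its field names are hypothetical and are not necessarily the keys in the model's config.json. The derived quantities assume standard grouped-query attention, where queries use num_attention_heads × head_dim and keys/values use num_kv_heads × head_dim.

```python
from dataclasses import dataclass


# Hypothetical container; field names are illustrative, not the real config keys.
@dataclass(frozen=True)
class NeuroBLASTV3Shape:
    hidden_size: int = 512
    vocab_size: int = 65538
    intermediate_size: int = 3072
    num_attention_heads: int = 16
    num_kv_heads: int = 8
    head_dim: int = 128
    sensory_layers: int = 24
    associative_layers: int = 32
    motor_layers: int = 16
    tie_word_embeddings: bool = False

    @property
    def num_layers(self) -> int:
        return self.sensory_layers + self.associative_layers + self.motor_layers

    @property
    def q_proj_dim(self) -> int:
        # Assumed query width: heads * head_dim (differs from hidden_size here)
        return self.num_attention_heads * self.head_dim

    @property
    def kv_proj_dim(self) -> int:
        # Assumed key/value width: kv_heads * head_dim
        return self.num_kv_heads * self.head_dim

    @property
    def queries_per_kv_head(self) -> int:
        # Grouped-query attention: each KV head is shared by this many query heads
        return self.num_attention_heads // self.num_kv_heads

    @property
    def embedding_params(self) -> int:
        # Input embedding plus a separate LM head, since embeddings are untied
        per_matrix = self.vocab_size * self.hidden_size
        return per_matrix if self.tie_word_embeddings else 2 * per_matrix


cfg = NeuroBLASTV3Shape()
print(cfg.num_layers, cfg.q_proj_dim, cfg.kv_proj_dim, cfg.queries_per_kv_head)
print(f"{cfg.embedding_params:,}")
```

Under these assumptions the query width (2048) is four times the hidden size (512), and because the embeddings are untied, the input embedding and the LM head together hold about 67.1M of the 596.7M parameters.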

Architecture Highlights

NeuroBLAST differs from standard Transformers in using a three-stage cortical design, with the stages linked by deep residual bridges:

  1. Sensory Cortex: Hybrid layers alternating between Attention and Dilated Causal 2D Convolutions.
  2. Associative Cortex: Hybrid layers with alternating RoPE usage.
  3. Motor Cortex: Pure Attention layers.
  4. Deep Residual Bridges: Long-range residual connections injecting the original embeddings (and their negations) between cortical stages to improve signal propagation.
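The stage layout and the bridges can be sketched in plain Python. This is a toy under stated assumptions, not the code in modeling_neuroblast.py: the card does not say which layer type comes first in each alternating stage, nor exactly where the positive and negated embedding injections happen. The sketch assumes sensory layers start with attention, associative layers start with a RoPE layer, and the embeddings are added after the sensory stage and subtracted after the associative stage. The stage callables are hypothetical stand-ins for whole stages.

```python
from typing import Callable, List

Vector = List[float]


def build_layer_schedule(n_sensory: int = 24, n_associative: int = 32, n_motor: int = 16) -> List[str]:
    """Hypothetical per-layer labels for the three cortical stages.

    Assumed (not confirmed by the card): sensory layers start with attention,
    associative layers start with a RoPE layer.
    """
    schedule = []
    for i in range(n_sensory):
        # Sensory cortex: attention and dilated causal 2D conv alternate
        schedule.append("sensory/attention" if i % 2 == 0 else "sensory/dilated_conv2d")
    for i in range(n_associative):
        # Associative cortex: RoPE is used on every other layer
        schedule.append("associative/rope" if i % 2 == 0 else "associative/no_rope")
    # Motor cortex: pure attention
    schedule.extend(["motor/attention"] * n_motor)
    return schedule


def forward_with_bridges(
    embeddings: Vector,
    sensory: Callable[[Vector], Vector],
    associative: Callable[[Vector], Vector],
    motor: Callable[[Vector], Vector],
) -> Vector:
    """Toy forward pass; bridge placement and signs are assumptions."""
    h = sensory(embeddings)
    h = [x + e for x, e in zip(h, embeddings)]  # bridge: inject original embeddings
    h = associative(h)
    h = [x - e for x, e in zip(h, embeddings)]  # bridge: inject negated embeddings
    return motor(h)


def identity(v: Vector) -> Vector:
    return v


schedule = build_layer_schedule()
print(len(schedule), schedule[:2], schedule[-1])
print(forward_with_bridges([1.0, -2.0], identity, identity, identity))
```

With identity stages the two bridges cancel and the embeddings pass through unchanged, which makes the wiring easy to check.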

[Figure: NeuroBLAST V3 architecture diagram]

Training Details

This model is currently being trained using the Google TPU Research Cloud (TRC).

  • Dataset: PleIAs/SYNTH
  • Tokens Processed: ~118 Billion
  • Hardware: TPUv4-16
  • Training Time: ~8 Days
  • Effective Batch Size: 1024
  • Context Length: 768 tokens (Current phase)
  • Learning Rate: 4e-3
  • Weight Decay: 0.0
  • Optimizer: AdamW
  • Precision: BFloat16
  • Current State: Pre-decay phase (learning rate not yet lowered; no weight decay applied yet).
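The ~118 billion token figure follows from the step count, batch size and context length. A quick check, assuming every sequence in a batch is packed to the full 768 tokens (no padding) and treating the ~8 days as wall-clock time for all 150,000 steps:

```python
steps = 150_000          # Checkpoint Step
batch_size = 1024        # Effective Batch Size (sequences per optimizer step)
context_length = 768     # tokens per sequence in the current phase

# Assumes fully packed sequences, so every position counts as a training token
tokens_per_step = batch_size * context_length
tokens_processed = steps * tokens_per_step
print(f"{tokens_per_step:,} tokens/step")
print(f"{tokens_processed / 1e9:.1f}B tokens")

# Average throughput, treating ~8 days as the wall-clock time of the whole run
training_days = 8
tokens_per_second = tokens_processed / (training_days * 24 * 3600)
print(f"~{tokens_per_second / 1e3:.0f}k tokens/s")
```

The per-step figure is exact under the packing assumption (786,432 tokens); the throughput is only as precise as the "~8 Days" estimate.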

[Figure: evaluation loss curve (eval_loss)]

Roadmap

This checkpoint marks the end of the initial warmup/learning phase. The next steps in training are:

  1. Significantly extending the context length.
  2. Lowering the learning rate.
  3. Introducing weight decay for convergence.
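The roadmap does not give values for the decay phase. As a purely hypothetical illustration of steps 2 and 3, the sketch below keeps the current settings (learning rate 4e-3, weight decay 0.0) until a decay_start step, then linearly lowers the learning rate and switches weight decay on. decay_steps, final_lr and weight_decay are placeholders, and the linear shape is an assumption, not the author's plan.

```python
from typing import Tuple


def lr_and_weight_decay(
    step: int,
    decay_start: int = 150_000,   # this checkpoint's step
    decay_steps: int = 50_000,    # hypothetical length of the decay phase
    peak_lr: float = 4e-3,        # current learning rate
    final_lr: float = 4e-4,       # hypothetical target learning rate
    weight_decay: float = 0.1,    # hypothetical weight decay once enabled
) -> Tuple[float, float]:
    """Return (learning_rate, weight_decay) for an optimizer step."""
    if step < decay_start:
        # Pre-decay phase: constant peak learning rate, no weight decay
        return peak_lr, 0.0
    # Linear interpolation from peak_lr to final_lr, then hold at final_lr
    progress = min((step - decay_start) / decay_steps, 1.0)
    return peak_lr + (final_lr - peak_lr) * progress, weight_decay


for s in (100_000, 150_000, 175_000, 200_000, 250_000):
    lr, wd = lr_and_weight_decay(s)
    print(f"step {s:>7,}: lr={lr:.2e}, weight_decay={wd}")
```

Context extension (step 1) is left out; it changes the data pipeline rather than the optimizer settings.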

Usage

Note: You must pass trust_remote_code=True because this model uses custom modeling code (modeling_neuroblast.py).

import torch
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM

model_id = "mkurman/NeuroBLAST-V3-SYNTH-EC-150000"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the model; trust_remote_code=True is required for the custom NeuroBLAST architecture
model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16, 
    device_map='cuda', 
    trust_remote_code=True
).eval()

streamer = TextStreamer(
    tokenizer, skip_prompt=False, decode_kwargs={"skip_special_tokens": False}
)

# Prepare input
input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "what is hypertension?"}], 
    tokenize=True, 
    return_tensors="pt", 
    add_generation_prompt=True
)

print(f"Input IDs: {input_ids}")

# Generate
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids.to(model.device),
        max_new_tokens=128,
        streamer=streamer,
        use_cache=True,
        # Important: Keep repetition_penalty at 1.0 for this early checkpoint
        repetition_penalty=1.0, 
    )

The underlying JAX implementation is in the neuroblastv3_jax folder. The JAX weights are published in a separate repository, mkurman/NeuroBLAST-V3-SYNTH-EC-150000-JAX, which the script below loads.


import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from neuroblast3_jax.modeling_neuroblast_jax import NeuroBLASTForCausalLM as NeuroBLASTForCausalLMJax

def generate_text(model, tokenizer, text, max_new_tokens=50, temperature=0.7, top_k=50):
    inputs = tokenizer(f"user\n{text}<|im_end|><|im_start|>assistant\n", return_tensors="np")
    original_input_ids = inputs["input_ids"]
    batch_size, prompt_len = original_input_ids.shape
    total_len = prompt_len + max_new_tokens
    
    # Pad input_ids to total_len
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    input_ids = jnp.full((batch_size, total_len), pad_id, dtype=jnp.int32)
    input_ids = input_ids.at[:, :prompt_len].set(original_input_ids)
    
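    # All-ones mask over the padded buffer: assuming every layer is strictly causal,
    # the pad positions at or after current_len never influence the logits we read.
    attention_mask = jnp.ones((batch_size, total_len), dtype=jnp.int32)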
    params = model.params

    @jax.jit
    def model_step(params, input_ids, attention_mask, rng):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, params=params, train=False)
        return outputs.logits

    rng = jax.random.PRNGKey(0)
    
    print("Generating...")
    current_len = prompt_len
    printed_len = 0
    
    for i in range(max_new_tokens):
        rng, step_rng = jax.random.split(rng)
        
        # Run model
        logits = model_step(params, input_ids, attention_mask, step_rng)
        
        # Get logits for the last valid token (current_len - 1)
        next_token_logits = logits[:, current_len - 1, :]
        
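        # Sampling: temperature scaling, then keep only the top_k highest logits
        scaled_logits = next_token_logits / temperature
        if top_k is not None and top_k > 0:
            kth_largest = jax.lax.top_k(scaled_logits, top_k)[0][:, -1:]
            scaled_logits = jnp.where(scaled_logits < kth_largest, -jnp.inf, scaled_logits)
        next_token = jax.random.categorical(step_rng, scaled_logits, axis=-1)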
        
        # Update input_ids
        # We need to update the next position
        input_ids = input_ids.at[:, current_len].set(next_token)
        
        current_len += 1
        
        # Streaming output
        valid_ids = input_ids[0, :current_len]
        current_text = tokenizer.decode(valid_ids, skip_special_tokens=False)
        
        new_text = current_text[printed_len:]
        if new_text:
            print(new_text, end="", flush=True)
            printed_len += len(new_text)
        
        # Check EOS
        if next_token[0] == tokenizer.eos_token_id:
            break
            
    valid_ids = input_ids[0, :current_len]
    return tokenizer.decode(valid_ids, skip_special_tokens=False)


if __name__ == "__main__":
    checkpoint = "mkurman/NeuroBLAST-V3-SYNTH-EC-150000-JAX"

    print(f"Loading model from {checkpoint}...")
    tokenizer = AutoTokenizer.from_pretrained(
        checkpoint,
        use_fast=True,
        trust_remote_code=True,
    )

    print(f"Available devices: {jax.devices()}")

    model = NeuroBLASTForCausalLMJax.from_pretrained(
        checkpoint,
        dtype=jnp.bfloat16,
        trust_remote_code=True,
        is_decoder=True,
    )

    generated_text = generate_text(model, tokenizer, "what is hypertension?", max_new_tokens=128)

    print("\nGenerated Text:")
    print("-" * 20)
    print(generated_text)
    print("-" * 20)

Acknowledgments

This model was trained using Cloud TPUs provided by Google's TPU Research Cloud (TRC) program.

Special thanks to Pierre-Carl Langlais and the PleIAs team for the high-quality SYNTH dataset.
