#!/usr/bin/env python3
"""
Maaza Nano 9.6M - The 99ms Brain

Simple inference script for tool routing.

Usage:
    python inference.py "search for cats"
    python inference.py "read the file config.json"
    python inference.py "send an email to bob@example.com"
"""

import json
import sys
import time
from pathlib import Path

import torch

# Add this script's directory to the path so the local modules import cleanly
sys.path.insert(0, str(Path(__file__).parent))

from model import MaazaNanoModel, MaazaNanoConfig
from tokenizer import BPETokenizer


def load_model(model_dir: str = "."):
    """Load the Maaza Nano model, tokenizer, and target device."""
    model_dir = Path(model_dir)

    # Load tokenizer
    tokenizer = BPETokenizer.load(str(model_dir / "tokenizer.json"))

    # Load model config
    with open(model_dir / "config.json") as f:
        cfg = json.load(f)

    config = MaazaNanoConfig(
        vocab_size=cfg["vocab_size"],
        hidden_size=cfg["hidden_size"],
        num_layers=cfg["num_layers"],
        num_heads=cfg["num_heads"],
        intermediate_size=cfg["intermediate_size"],
        max_position_embeddings=cfg["max_position_embeddings"],
    )

    # Load model weights
    model = MaazaNanoModel(config)
    model.load_state_dict(torch.load(model_dir / "model.pt", weights_only=True))

    # Use GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    return model, tokenizer, device


def route_tool(prompt: str, model, tokenizer, device, max_tokens: int = 100):
    """Route a natural-language prompt to a tool call.

    Returns (tool_calls, latency_ms), where tool_calls is None if the
    model output did not contain a parseable JSON tool-call array.
    """
    # Format the prompt with the chat template the model was trained on
    full_prompt = f"<|user|>{prompt}<|assistant|>"
    tokens = tokenizer.encode(full_prompt)
    input_ids = torch.tensor([tokens], device=device)

    # Look up the EOS token id once; if it's missing from the vocab we
    # simply generate until max_tokens (comparing against None would
    # never match, so the original early-stop check was a no-op then)
    eos_id = tokenizer.vocab.get("<|eos|>")

    # Greedy decoding: append the argmax token one step at a time
    start_time = time.time()
    with torch.no_grad():
        for _ in range(max_tokens):
            outputs = model(input_ids)
            logits = outputs["logits"]
            next_token = logits[0, -1].argmax().item()
            input_ids = torch.cat(
                [input_ids, torch.tensor([[next_token]], device=device)], dim=1
            )

            # Stop once the model emits the end-of-sequence token
            if eos_id is not None and next_token == eos_id:
                break

    if device == "cuda":
        torch.cuda.synchronize()  # flush queued GPU work before reading the clock
    latency_ms = (time.time() - start_time) * 1000

    # Decode the full sequence (prompt + generation)
    generated = tokenizer.decode(input_ids[0].tolist())

    # Extract the first JSON tool-call array ("[{...}]") from the output
    json_start = generated.find("[{")
    if json_start >= 0:
        json_end = generated.find("}]", json_start)
        if json_end >= 0:
            try:
                tool_calls = json.loads(generated[json_start : json_end + 2])
                return tool_calls, latency_ms
            except json.JSONDecodeError:
                pass

    return None, latency_ms


def main():
    if len(sys.argv) < 2:
        print('Usage: python inference.py "your prompt here"')
        print("\nExamples:")
        print('  python inference.py "search for cats"')
        print('  python inference.py "read config.json"')
        print('  python inference.py "send email to bob@example.com"')
        sys.exit(1)

    prompt = sys.argv[1]

    print("Loading Maaza Nano 9.6M...")
    model, tokenizer, device = load_model()
    print(f"Model loaded on {device}")
    print()

    print(f"Prompt: {prompt}")
    tool_calls, latency = route_tool(prompt, model, tokenizer, device)

    print(f"Latency: {latency:.1f}ms")
    print()

    if tool_calls:
        print("Tool call:")
        print(json.dumps(tool_calls, indent=2))
    else:
        print("Failed to parse a tool call from the model output")


if __name__ == "__main__":
    main()
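

# ----------------------------------------------------------------------
# Optional helper: a rough latency benchmark over the routing pipeline.
# This is a minimal sketch layered on the load_model()/route_tool()
# helpers above; the default prompts are the illustrative examples from
# the docstring, not a real evaluation set, and the `benchmark` name is
# our own addition. Invoke it manually, e.g.:
#     python -c "import inference; inference.benchmark()"
# ----------------------------------------------------------------------
def benchmark(prompts=None, warmup: int = 1):
    """Print per-prompt and average routing latency."""
    if prompts is None:
        prompts = [
            "search for cats",
            "read the file config.json",
            "send an email to bob@example.com",
        ]

    model, tokenizer, device = load_model()

    # Warm-up runs so one-time costs (CUDA context, kernel compilation,
    # memory allocator) don't skew the measured numbers
    for _ in range(warmup):
        route_tool(prompts[0], model, tokenizer, device)

    latencies = []
    for p in prompts:
        _, ms = route_tool(p, model, tokenizer, device)
        latencies.append(ms)
        print(f"{ms:7.1f}ms  {p}")

    print(f"Average: {sum(latencies) / len(latencies):.1f}ms "
          f"over {len(latencies)} prompts")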