ovinduG committed on
Commit
3560106
·
verified ·
1 Parent(s): b5a26dd

Add inference example script

Browse files
Files changed (1) hide show
  1. inference_example.py +85 -0
inference_example.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-Domain Classifier - Inference Example
3
+ Repository: https://huggingface.co/ovinduG/multi-domain-classifier-phi3
4
+ """
5
+
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ from peft import PeftModel
8
+ import torch
9
+ import json
10
+
11
class MultiDomainClassifier:
    """Classify free-text queries into domains with a LoRA-adapted Phi-3 model.

    The constructor loads the ``microsoft/Phi-3-mini-4k-instruct`` base model,
    attaches the PEFT adapter identified by ``model_id``, and loads the
    tokenizer from the same ``model_id``. ``predict`` then prompts the model
    and parses the JSON object it generates.
    """

    def __init__(self, model_id="ovinduG/multi-domain-classifier-phi3"):
        """Load the base model, the adapter, and the tokenizer.

        Args:
            model_id: Identifier passed to both ``PeftModel.from_pretrained``
                (adapter weights) and ``AutoTokenizer.from_pretrained``.
        """
        print("Loading model...")

        # Load base model
        self.base_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-4k-instruct",
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Load LoRA adapter
        self.model = PeftModel.from_pretrained(self.base_model, model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model.eval()

        print("✅ Model loaded!")

    def predict(self, query: str) -> dict:
        """Classify a query into domains.

        Args:
            query: The text to classify; it is interpolated into the prompt.

        Returns:
            The JSON object generated by the model, parsed into a ``dict``.
            If no JSON object can be parsed from the generated text, returns
            ``{"error": "Failed to parse response", "raw": <generated text>}``.
        """
        # Literal braces are doubled so the f-string emits them verbatim;
        # single braces would be parsed as replacement fields and fail.
        # NOTE(review): the inner indentation of this template could not be
        # recovered from the source — confirm it matches the training prompt.
        prompt = f"""Classify this query: {query}

Output JSON format:
{{
  "primary_domain": "domain_name",
  "primary_confidence": 0.95,
  "is_multi_domain": true/false,
  "secondary_domains": [{{"domain": "name", "confidence": 0.85}}]
}}"""

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            # Greedy decoding: no sampling temperature is passed because it
            # has no effect when do_sample=False.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=False,
                use_cache=False,
            )

        # The generated sequence starts with the prompt tokens. Decode only
        # what follows, otherwise the JSON *template* in the prompt would be
        # picked up (and rejected) by the parser instead of the model's answer.
        completion = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        # Parse the first JSON object in the completion, tolerating any text
        # the model emits before or after it.
        start = completion.find("{")
        if start != -1:
            try:
                result, _ = json.JSONDecoder().raw_decode(completion, start)
            except ValueError:  # json.JSONDecodeError is a ValueError
                pass
            else:
                if isinstance(result, dict):
                    return result
        return {"error": "Failed to parse response", "raw": completion}
62
+
63
+
64
# Example usage: classify a handful of sample queries and print each result.
if __name__ == "__main__":
    # Build the classifier once; every query below reuses it.
    classifier = MultiDomainClassifier()

    # Separator lines used to frame the printed output.
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    # Sample queries to classify.
    queries = [
        "Write a Python function to calculate factorial",
        "Build ML model to analyze sales data and create API endpoints",
        "What is quantum entanglement?",
        "Create a REST API for healthcare diabetes prediction",
    ]

    print("\n" + heavy_rule)
    print("CLASSIFICATION EXAMPLES")
    print(heavy_rule)

    for text in queries:
        print(f"\nQuery: {text}")
        prediction = classifier.predict(text)
        print(f"Result: {json.dumps(prediction, indent=2)}")
        print(light_rule)