ssheroz committed on
Commit
95e00bd
·
verified ·
1 Parent(s): 9214390

Upload 3 files

Browse files
industrial-document-classifier-clip-lora.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00051355383afc43c51aea8897012b518646e7150e4e2228608793db0fae3e20
3
+ size 407433429
main.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from pipeline import DocumentClassifier


def _print_result(result):
    """Pretty-print a single classification result dict to stdout."""
    print(f"\nImage: {result['image_path']}")
    if result['error_response']:
        print(f"Error: {result['error_response']}")
        return
    print("Predictions:")
    ranked = sorted(result['predictions'].items(), key=lambda item: item[1], reverse=True)
    for class_name, probability in ranked:
        print(f" {class_name}: {probability:.4f}")


def main():
    """Demo driver: classify a few image paths and print ranked scores."""
    classifier = DocumentClassifier()
    image_paths = [
        "path/to/image1.jpg",
        "path/to/image2.png",
        "path/to/image3.jpeg",
    ]
    for result in classifier.predict(image_paths):
        _print_result(result)
    classifier.unload()


if __name__ == "__main__":
    main()
pipeline.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Dict, List, Union
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+ import torch
7
+ import torch.nn as nn
8
+ from PIL import Image
9
+ from transformers import CLIPVisionModel, CLIPProcessor
10
+ from peft import LoraConfig, get_peft_model
11
+ import warnings
12
+ warnings.filterwarnings(action="ignore")
class CONSTANTS:
    """Static configuration shared by the whole classification pipeline."""

    # Model identifiers: CLIP backbone + LoRA-tuned checkpoint file.
    BASE_MODEL_NAME = "openai/clip-vit-base-patch16"
    TUNED_MODEL_NAME = "industrial-document-classifier-clip-lora.pt"

    # Architecture dimensions.
    EMBEDDING_DIM = 768
    NUM_PARENT_CLASSES = 6
    NUM_CHILD_CLASSES = 13

    # LoRA adapter hyper-parameters.
    LORA_R = 32
    LORA_ALPHA = 64
    LORA_DROPOUT = 0.1
    LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"]

    # Runtime settings.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 16
    MAX_WORKERS = os.cpu_count()

    # Accepted file suffixes for path-level validation.
    VALID_IMAGE_EXTENSIONS = {
        '.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff',
        '.tif', '.webp', '.ico', '.heic', '.heif',
    }

    # Parent class-id -> human-readable label mapping.
    PARENT_CLASS_NAMES = {
        0: "product_information",
        1: "engineering_drawings",
        2: "instructional_guides",
        3: "compliance_certificates",
        4: "energy_ratings",
        5: "warranty_documents",
    }
class HierarchicalDocumentClassifier(nn.Module):
    """CLIP vision backbone with two linear heads (parent and child labels).

    The backbone's pooled embedding feeds both heads; ``forward`` returns
    the raw logits of each head along with the embedding itself.
    """

    def __init__(self, model_name: str, num_parent_classes: int, num_child_classes: int, embedding_dim: int):
        super().__init__()
        # NOTE(review): use_safetensors=False forces the pickled-weights
        # loading path — confirm that is intended for this checkpoint.
        self.vision_model = CLIPVisionModel.from_pretrained(model_name, use_safetensors=False)
        self.parent_classifier = nn.Linear(embedding_dim, num_parent_classes)
        self.child_classifier = nn.Linear(embedding_dim, num_child_classes)

    def forward(self, pixel_values):
        """Return ``(parent_logits, child_logits, pooled_embeddings)``."""
        pooled = self.vision_model(pixel_values=pixel_values).pooler_output
        return (
            self.parent_classifier(pooled),
            self.child_classifier(pooled),
            pooled,
        )
class DocumentClassifier:
    """End-to-end inference wrapper.

    Loads the LoRA-tuned CLIP classifier once, then predicts parent-class
    probabilities for lists of image file paths via ``predict``.
    """

    def __init__(self):
        self.device = torch.device(CONSTANTS.DEVICE)
        self.processor = CLIPProcessor.from_pretrained(CONSTANTS.BASE_MODEL_NAME, use_fast=True)
        self.model = self._load_model()
        self.model.eval()  # inference only: disables dropout (incl. LoRA dropout)
        self._clear_cache()

    def _load_model(self) -> nn.Module:
        """Build the base model, attach LoRA adapters, load the checkpoint.

        Returns the fine-tuned model moved to the target device.
        """
        model = HierarchicalDocumentClassifier(
            model_name=CONSTANTS.BASE_MODEL_NAME,
            num_parent_classes=CONSTANTS.NUM_PARENT_CLASSES,
            num_child_classes=CONSTANTS.NUM_CHILD_CLASSES,
            embedding_dim=CONSTANTS.EMBEDDING_DIM)
        lora_config = LoraConfig(
            r=CONSTANTS.LORA_R,
            lora_alpha=CONSTANTS.LORA_ALPHA,
            target_modules=CONSTANTS.LORA_TARGET_MODULES,
            lora_dropout=CONSTANTS.LORA_DROPOUT,
            bias="none")
        # Adapters must be attached before load_state_dict so the checkpoint's
        # LoRA parameter keys line up with the module tree.
        model.vision_model = get_peft_model(model.vision_model, lora_config)
        # NOTE(review): torch.load uses pickle — only load checkpoints from a
        # trusted source (consider weights_only=True on recent torch versions).
        checkpoint = torch.load(CONSTANTS.TUNED_MODEL_NAME, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        del checkpoint  # drop the checkpoint copy before moving the model
        self._clear_cache()
        return model.to(self.device)

    def _clear_cache(self):
        # Best-effort release of CUDA cache and Python heap garbage.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def _is_valid_image(self, image_path: str) -> bool:
        """Cheap path-level validation: exists, is a regular file, and has a
        whitelisted image extension (content is not inspected here)."""
        path = Path(image_path)
        if not path.exists():
            return False
        if not path.is_file():
            return False
        if path.suffix.lower() not in CONSTANTS.VALID_IMAGE_EXTENSIONS:
            return False
        return True

    def _load_image(self, image_path: str) -> Union[Image.Image, None]:
        """Open an image and normalize it to RGB; returns None on any failure."""
        try:
            img = Image.open(image_path).convert('RGB')
            return img
        except Exception:
            # Deliberate best-effort: a corrupt/unreadable file becomes None
            # and is reported per-image by the callers.
            return None

    def _process_batch(self, images: List[Image.Image]):
        """Run one forward pass over a list of PIL images.

        Returns parent-class probabilities as a CPU tensor of shape
        (len(images), NUM_PARENT_CLASSES), or None if preprocessing or
        inference raises.
        """
        try:
            inputs = self.processor(images=images, return_tensors="pt")
            pixel_values = inputs['pixel_values'].to(self.device)

            with torch.no_grad():
                parent_logits, _, _ = self.model(pixel_values)
                probabilities = torch.softmax(parent_logits, dim=1)
            return probabilities.cpu()
        except Exception:
            # NOTE(review): broad catch; callers treat None as "inference
            # failed". Consider logging the exception for diagnosis.
            return None

    def _process_single_image(self, image_path: str) -> Dict:
        """Classify one image.

        Returns a dict with keys ``image_path``, ``predictions`` (class name
        -> probability) and ``error_response`` ("" on success).
        """
        result = {
            "image_path": image_path,
            "predictions": {},
            "error_response": ""
        }
        if not self._is_valid_image(image_path):
            result["error_response"] = "Invalid image path or unsupported format"
            return result
        img = self._load_image(image_path)
        if img is None:
            result["error_response"] = "Failed to load image"
            return result
        probabilities = self._process_batch([img])

        if probabilities is None:
            result["error_response"] = "Model inference failed"
            return result
        probs = probabilities[0]
        predictions_dict = {}
        for class_id, class_name in CONSTANTS.PARENT_CLASS_NAMES.items():
            predictions_dict[class_name] = float(probs[class_id])
        result["predictions"] = predictions_dict
        return result

    def _process_batch_images(self, image_paths: List[str]) -> List[Dict]:
        """Classify a batch of paths with a single forward pass.

        Invalid or unreadable paths get an error result immediately; the
        surviving images are batched together. Results keep input order.
        """
        batch_results = []
        batch_images = []
        batch_valid_paths = []
        for img_path in image_paths:
            result = {
                "image_path": img_path,
                "predictions": {},
                "error_response": ""
            }
            if not self._is_valid_image(img_path):
                result["error_response"] = "Invalid image path or unsupported format"
                batch_results.append(result)
                continue
            img = self._load_image(img_path)
            if img is None:
                result["error_response"] = "Failed to load image"
                batch_results.append(result)
                continue
            batch_images.append(img)
            batch_valid_paths.append(img_path)
            batch_results.append(result)
        if len(batch_images) > 0:
            probabilities = self._process_batch(batch_images)
            if probabilities is None:
                # Mark every still-pending (error-free) entry as failed.
                for result in batch_results:
                    if result["error_response"] == "":
                        result["error_response"] = "Model inference failed"
            else:
                # Valid entries appear in batch_results in the same order as
                # batch_images, so a running index pairs probability rows
                # with their result dicts.
                valid_idx = 0
                for result in batch_results:
                    if result["error_response"] == "":
                        probs = probabilities[valid_idx]
                        predictions_dict = {}
                        for class_id, class_name in CONSTANTS.PARENT_CLASS_NAMES.items():
                            predictions_dict[class_name] = float(probs[class_id])
                        result["predictions"] = predictions_dict
                        valid_idx += 1
        return batch_results

    def predict(self, image_paths: List[str]) -> List[Dict]:
        """Classify many images.

        Splits the paths into BATCH_SIZE chunks, fans the chunks out over a
        thread pool, then re-emits the results in the caller's input order.
        Returns one result dict per input path (see ``_process_single_image``).
        """
        if len(image_paths) == 0:
            return []
        batches = []
        for i in range(0, len(image_paths), CONSTANTS.BATCH_SIZE):
            batches.append(image_paths[i:i + CONSTANTS.BATCH_SIZE])
        results_map = {}
        # NOTE(review): worker threads share one model instance; this is safe
        # only if the underlying forward pass tolerates concurrent calls —
        # confirm behavior under CUDA before raising MAX_WORKERS.
        with ThreadPoolExecutor(max_workers=CONSTANTS.MAX_WORKERS) as executor:
            future_to_batch = {executor.submit(self._process_batch_images, batch): batch for batch in batches}
            for future in as_completed(future_to_batch):
                batch = future_to_batch[future]
                batch_results = future.result()
                for result in batch_results:
                    results_map[result["image_path"]] = result
        ordered_results = []
        # Restore input order (results_map is keyed by path, so duplicate
        # input paths share one result entry).
        for img_path in image_paths:
            ordered_results.append(results_map[img_path])
        return ordered_results

    def unload(self):
        """Release the model and processor references and clear caches."""
        if hasattr(self, 'model') and self.model is not None:
            del self.model
        if hasattr(self, 'processor') and self.processor is not None:
            del self.processor
        self._clear_cache()