kero2111 commited on
Commit
0b88a08
·
verified ·
1 Parent(s): 36eb585

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +72 -69
inference.py CHANGED
@@ -3,63 +3,75 @@ import numpy as np
3
  from PIL import Image
4
  import os
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def load_model(model_path):
7
- """
8
- Load the pre-trained plant disease classification model
9
- """
10
  try:
11
- model = tf.keras.models.load_model(model_path)
12
  print("Model loaded successfully!")
13
  return model
14
  except Exception as e:
15
  print(f"Error loading model: {e}")
16
  return None
17
 
18
- def preprocess_image(image_path, target_size=(224, 224)):
19
- """
20
- Preprocess image for model inference
21
- """
22
  try:
23
- # Load image
24
  img = Image.open(image_path)
25
-
26
- # Convert to RGB if necessary
27
  if img.mode != 'RGB':
28
  img = img.convert('RGB')
29
-
30
- # Resize image
31
  img = img.resize(target_size)
32
-
33
- # Convert to numpy array and normalize
34
- img_array = np.array(img) / 255.0
35
-
36
- # Add batch dimension
37
  img_array = np.expand_dims(img_array, axis=0)
38
-
39
  return img_array
40
  except Exception as e:
41
  print(f"Error preprocessing image: {e}")
42
  return None
43
 
 
 
 
44
  def predict_disease(model, image_array):
45
- """
46
- Make prediction on preprocessed image
47
- """
48
  try:
49
- # Make prediction
50
  prediction = model.predict(image_array)
51
  predicted_class = np.argmax(prediction[0])
52
  confidence = prediction[0][predicted_class]
53
-
54
  return predicted_class, confidence, prediction[0]
55
  except Exception as e:
56
  print(f"Error making prediction: {e}")
57
  return None, None, None
58
 
 
 
 
59
  def get_class_name(class_index):
60
- """
61
- Get class name from class index
62
- """
63
  classes = [
64
  "Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___healthy",
65
  "Blueberry___healthy", "Cherry_(including_sour)___Powdery_mildew", "Cherry_(including_sour)___healthy",
@@ -75,60 +87,51 @@ def get_class_name(class_index):
75
  "Tomato___Target_Spot", "Tomato___Tomato_Yellow_Leaf_Curl_Virus", "Tomato___Tomato_mosaic_virus",
76
  "Tomato___healthy"
77
  ]
78
-
79
  if 0 <= class_index < len(classes):
80
  return classes[class_index]
81
  else:
82
  return "Unknown"
83
 
 
 
 
84
  def main():
85
- """
86
- Main function to demonstrate model usage
87
- """
88
- # Model path (update this path to your model location)
89
  model_path = "Pretrained_model.h5"
90
-
91
- # Check if model exists
92
  if not os.path.exists(model_path):
93
  print(f"Model file not found at: {model_path}")
94
- print("Please ensure the model file is in the current directory or update the path.")
95
  return
96
-
97
- # Load model
 
 
 
98
  model = load_model(model_path)
99
  if model is None:
100
  return
101
-
102
- # Example usage with a sample image
103
- # Replace 'sample_image.jpg' with your actual image path
104
- sample_image_path = "sample_image.jpg"
105
-
106
- if os.path.exists(sample_image_path):
107
- # Preprocess image
108
- image_array = preprocess_image(sample_image_path)
109
- if image_array is None:
110
- return
111
-
112
- # Make prediction
113
- predicted_class, confidence, all_predictions = predict_disease(model, image_array)
114
-
115
- if predicted_class is not None:
116
- class_name = get_class_name(predicted_class)
117
- print(f"\nPrediction Results:")
118
- print(f"Predicted Class: {class_name}")
119
- print(f"Confidence: {confidence:.2%}")
120
- print(f"Class Index: {predicted_class}")
121
-
122
- # Show top 3 predictions
123
- top_3_indices = np.argsort(all_predictions)[-3:][::-1]
124
- print(f"\nTop 3 Predictions:")
125
- for i, idx in enumerate(top_3_indices):
126
- class_name = get_class_name(idx)
127
- confidence = all_predictions[idx]
128
- print(f"{i+1}. {class_name}: {confidence:.2%}")
129
- else:
130
- print(f"Sample image not found at: {sample_image_path}")
131
- print("Please provide a valid image path to test the model.")
132
 
133
  if __name__ == "__main__":
134
- main()
 
3
  from PIL import Image
4
  import os
5
 
6
+ # ========================
7
+ # Custom layer (مطلوبة)
8
+ # ========================
9
+ from tensorflow.keras.layers import Layer
10
+
11
+ class CustomScaleLayer(Layer):
12
+ def __init__(self, scale=1.0, **kwargs):
13
+ super(CustomScaleLayer, self).__init__(**kwargs)
14
+ self.scale = scale
15
+
16
+ def call(self, inputs):
17
+ if isinstance(inputs, (list, tuple)):
18
+ x = tf.add_n(inputs)
19
+ else:
20
+ x = inputs
21
+ return x * self.scale
22
+
23
+ def get_config(self):
24
+ config = super().get_config()
25
+ config.update({"scale": self.scale})
26
+ return config
27
+
28
+ # ========================
29
+ # Load the model
30
+ # ========================
31
  def load_model(model_path):
 
 
 
32
  try:
33
+ model = tf.keras.models.load_model(model_path, custom_objects={'CustomScaleLayer': CustomScaleLayer})
34
  print("Model loaded successfully!")
35
  return model
36
  except Exception as e:
37
  print(f"Error loading model: {e}")
38
  return None
39
 
40
+ # ========================
41
+ # Image preprocessing
42
+ # ========================
43
+ def preprocess_image(image_path, target_size=(299, 299), normalize=True):
44
  try:
 
45
  img = Image.open(image_path)
 
 
46
  if img.mode != 'RGB':
47
  img = img.convert('RGB')
 
 
48
  img = img.resize(target_size)
49
+ img_array = np.array(img)
50
+ if normalize:
51
+ img_array = img_array / 255.0
 
 
52
  img_array = np.expand_dims(img_array, axis=0)
 
53
  return img_array
54
  except Exception as e:
55
  print(f"Error preprocessing image: {e}")
56
  return None
57
 
58
+ # ========================
59
+ # Prediction
60
+ # ========================
61
  def predict_disease(model, image_array):
 
 
 
62
  try:
 
63
  prediction = model.predict(image_array)
64
  predicted_class = np.argmax(prediction[0])
65
  confidence = prediction[0][predicted_class]
 
66
  return predicted_class, confidence, prediction[0]
67
  except Exception as e:
68
  print(f"Error making prediction: {e}")
69
  return None, None, None
70
 
71
+ # ========================
72
+ # Class names
73
+ # ========================
74
  def get_class_name(class_index):
 
 
 
75
  classes = [
76
  "Apple___Apple_scab", "Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___healthy",
77
  "Blueberry___healthy", "Cherry_(including_sour)___Powdery_mildew", "Cherry_(including_sour)___healthy",
 
87
  "Tomato___Target_Spot", "Tomato___Tomato_Yellow_Leaf_Curl_Virus", "Tomato___Tomato_mosaic_virus",
88
  "Tomato___healthy"
89
  ]
 
90
  if 0 <= class_index < len(classes):
91
  return classes[class_index]
92
  else:
93
  return "Unknown"
94
 
95
+ # ========================
96
+ # Main function
97
+ # ========================
98
  def main():
 
 
 
 
99
  model_path = "Pretrained_model.h5"
100
+ sample_image_path = "sample_image.jpg" # 👈 ضع هنا اسم الصورة
101
+
102
  if not os.path.exists(model_path):
103
  print(f"Model file not found at: {model_path}")
 
104
  return
105
+
106
+ if not os.path.exists(sample_image_path):
107
+ print(f"Image file not found at: {sample_image_path}")
108
+ return
109
+
110
  model = load_model(model_path)
111
  if model is None:
112
  return
113
+
114
+ # تحقق هل الموديل فيه طبقة Rescaling
115
+ has_rescaling = any(isinstance(layer, tf.keras.layers.Rescaling) for layer in model.layers)
116
+ image_array = preprocess_image(sample_image_path, target_size=(299, 299), normalize=not has_rescaling)
117
+ if image_array is None:
118
+ return
119
+
120
+ predicted_class, confidence, all_predictions = predict_disease(model, image_array)
121
+ if predicted_class is not None:
122
+ class_name = get_class_name(predicted_class)
123
+ print(f"\nPrediction Results:")
124
+ print(f"Predicted Class: {class_name}")
125
+ print(f"Confidence: {confidence:.2%}")
126
+ print(f"Class Index: {predicted_class}")
127
+
128
+ # Show top 3 predictions
129
+ top_3_indices = np.argsort(all_predictions)[-3:][::-1]
130
+ print(f"\nTop 3 Predictions:")
131
+ for i, idx in enumerate(top_3_indices):
132
+ class_name = get_class_name(idx)
133
+ confidence = all_predictions[idx]
134
+ print(f"{i+1}. {class_name}: {confidence:.2%}")
 
 
 
 
 
 
 
 
 
135
 
136
  if __name__ == "__main__":
137
+ main()