harshinde commited on
Commit
4e79cd8
·
verified ·
1 Parent(s): f087569

Upload src/streamlit_app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +176 -70
src/streamlit_app.py CHANGED
@@ -1,44 +1,75 @@
1
  import streamlit as st
 
2
  import torch
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  import yaml
6
  import os
7
- from pathlib import Path
8
- from model_downloader import ModelDownloader
9
 
10
  # Import models
11
- from .deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
12
- from .vgg16_model import LandslideModel as VGG16Model
13
- from .resnet34_model import LandslideModel as ResNet34Model
14
- from .efficientnetb0_model import LandslideModel as EfficientNetB0Model
15
- from .mitb1_model import LandslideModel as MiTB1Model
16
- from .inceptionv4_model import LandslideModel as InceptionV4Model
17
- from .densenet121_model import LandslideModel as DenseNet121Model
18
- from .resnext50_32x4d_model import LandslideModel as ResNeXt50_32X4DModel
19
- from .se_resnet50_model import LandslideModel as SEResNet50Model
20
- from .se_resnext50_32x4d_model import LandslideModel as SEResNeXt50_32X4DModel
21
- from .segformer_model import LandslideModel as SegFormerB2Model
22
- from .inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
23
-
24
- # Initialize model downloader
25
- model_downloader = ModelDownloader()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Model descriptions
28
  model_descriptions = {
29
- "MobileNetV2": {"type": "mobilenet_v2", "description": "MobileNetV2 is a lightweight deep learning model for image classification and segmentation."},
30
- "VGG16": {"type": "vgg16", "description": "VGG16 is a popular deep learning model known for its simplicity and depth."},
31
- "ResNet34": {"type": "resnet34", "description": "ResNet34 is a deep residual network that helps in training very deep networks."},
32
- "EfficientNetB0": {"type": "efficientnet_b0", "description": "EfficientNetB0 is part of the EfficientNet family, known for its efficiency and performance."},
33
- "MiT-B1": {"type": "mit_b1", "description": "MiT-B1 is a transformer-based model designed for segmentation tasks."},
34
- "InceptionV4": {"type": "inceptionv4", "description": "InceptionV4 is a convolutional neural network known for its inception modules."},
35
- "DeepLabV3+": {"type": "deeplabv3plus", "description": "DeepLabV3+ is an advanced model for semantic image segmentation."},
36
- "DenseNet121": {"type": "densenet121", "description": "DenseNet121 is a densely connected convolutional network for image classification and segmentation."},
37
- "ResNeXt50_32X4D": {"type": "resnext50_32x4d", "description": "ResNeXt50_32X4D is a highly modularized network aimed at improving accuracy."},
38
- "SEResNet50": {"type": "se_resnet50", "description": "SEResNet50 is a ResNet model with squeeze-and-excitation blocks for better feature recalibration."},
39
- "SEResNeXt50_32X4D": {"type": "se_resnext50_32x4d", "description": "SEResNeXt50_32X4D combines ResNeXt and SE blocks for improved performance."},
40
- "SegFormerB2": {"type": "segformer", "description": "SegFormerB2 is a transformer-based model for semantic segmentation."},
41
- "InceptionResNetV2": {"type": "inceptionresnetv2", "description": "InceptionResNetV2 is a hybrid model combining Inception and ResNet architectures."},
42
  }
43
 
44
  # Streamlit app
@@ -47,7 +78,7 @@ st.set_page_config(page_title="Landslide Detection", layout="wide")
47
  st.title("Landslide Detection")
48
  st.markdown("""
49
  ## Instructions
50
- 1. Select a model from the sidebar.
51
  2. Upload one or more `.h5` files.
52
  3. The app will process the files and display the input image, prediction, and overlay.
53
  4. You can download the prediction results.
@@ -55,46 +86,121 @@ st.markdown("""
55
 
56
  # Sidebar for model selection
57
  st.sidebar.title("Model Selection")
58
- model_type = st.sidebar.selectbox("Select Model", list(model_descriptions.keys()))
59
-
60
- # Get model details
61
- config = {
62
- 'model_config': {
63
- 'model_type': model_descriptions[model_type]['type'],
64
- 'in_channels': 14,
65
- 'num_classes': 1
66
- }
67
- }
68
-
69
- # Show model description
70
- st.sidebar.markdown(f"**Model Type:** {model_descriptions[model_type]['type']}")
71
- st.sidebar.markdown(f"**Description:** {model_descriptions[model_type]['description']}")
72
-
73
- try:
74
- # Get the appropriate model class
75
  if model_type == "DeepLabV3+":
76
  model_class = DeepLabV3PlusModel
77
  else:
78
  model_class = locals()[model_type.replace("-", "") + "Model"]
79
-
80
- # Get model path from downloader
81
- model_name = model_descriptions[model_type]['type'].replace("+", "plus").lower()
82
- model_path = model_downloader.get_model_path(model_name)
83
- st.success(f"Model {model_type} loaded successfully!")
84
-
85
- # File uploader
86
- uploaded_files = st.file_uploader("Upload H5 files", type=['h5'], accept_multiple_files=True)
87
-
88
- if uploaded_files:
89
- # Process each uploaded file
90
- for uploaded_file in uploaded_files:
91
- st.write(f"Processing {uploaded_file.name}...")
92
- # Add your file processing logic here
93
-
94
- except FileNotFoundError as e:
95
- st.error(f"Model file not found: {str(e)}")
96
- st.error("Please ensure all model files are present in the models directory")
97
- st.stop()
98
- except Exception as e:
99
- st.error(f"Error: {str(e)}")
100
- st.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import h5py
3
  import torch
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  import yaml
7
  import os
 
 
8
 
9
  # Import models
10
+ from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
11
+ from src.vgg16_model import LandslideModel as VGG16Model
12
+ from src.resnet34_model import LandslideModel as ResNet34Model
13
+ from src.efficientnetb0_model import LandslideModel as EfficientNetB0Model
14
+ from src.mitb1_model import LandslideModel as MiTB1Model
15
+ from src.inceptionv4_model import LandslideModel as InceptionV4Model
16
+ from src.densenet121_model import LandslideModel as DenseNet121Model
17
+ from src.deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
18
+ from src.resnext50_32x4d_model import LandslideModel as ResNeXt50_32X4DModel
19
+ from src.se_resnet50_model import LandslideModel as SEResNet50Model
20
+ from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50_32X4DModel
21
+ from segformer_model import LandslideModel as SegFormerB2Model
22
+ from inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
23
+
24
+ # Load the configuration file
25
+ config = """
26
+ model_config:
27
+ model_type: "mobilenet_v2"
28
+ in_channels: 14
29
+ num_classes: 1
30
+ encoder_weights: "imagenet"
31
+ wce_weight: 0.5
32
+
33
+ dataset_config:
34
+ num_classes: 1
35
+ num_channels: 14
36
+ channels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
37
+ normalize: False
38
+
39
+ train_config:
40
+ dataset_path: ""
41
+ checkpoint_path: "checkpoints"
42
+ seed: 42
43
+ train_val_split: 0.8
44
+ batch_size: 16
45
+ num_epochs: 100
46
+ lr: 0.001
47
+ device: "cuda:0"
48
+ save_config: True
49
+ experiment_name: "mobilenet_v2"
50
+
51
+ logging_config:
52
+ wandb_project: "l4s"
53
+ wandb_entity: "Silvamillion"
54
+ """
55
+
56
+ config = yaml.safe_load(config)
57
 
58
  # Model descriptions
59
  model_descriptions = {
60
+ "MobileNetV2": {"path": "mobilenetv2.pth", "type": "mobilenet_v2", "description": "MobileNetV2 is a lightweight deep learning model for image classification and segmentation."},
61
+ "VGG16": {"path": "vgg16.pth", "type": "vgg16", "description": "VGG16 is a popular deep learning model known for its simplicity and depth."},
62
+ "ResNet34": {"path": "resnet34.pth", "type": "resnet34", "description": "ResNet34 is a deep residual network that helps in training very deep networks."},
63
+ "EfficientNetB0": {"path": "effucientnetb0.pth", "type": "efficientnet_b0", "description": "EfficientNetB0 is part of the EfficientNet family, known for its efficiency and performance."},
64
+ "MiT-B1": {"path": "mitb1.pth", "type": "mit_b1", "description": "MiT-B1 is a transformer-based model designed for segmentation tasks."},
65
+ "InceptionV4": {"path": "inceptionv4.pth", "type": "inceptionv4", "description": "InceptionV4 is a convolutional neural network known for its inception modules."},
66
+ "DeepLabV3+": {"path": "deeplabv3.pth", "type": "deeplabv3+", "description": "DeepLabV3+ is an advanced model for semantic image segmentation."},
67
+ "DenseNet121": {"path": "densenet121.pth", "type": "densenet121", "description": "DenseNet121 is a densely connected convolutional network for image classification and segmentation."},
68
+ "ResNeXt50_32X4D": {"path": "resnext50-32x4d.pth", "type": "resnext50_32x4d", "description": "ResNeXt50_32X4D is a highly modularized network aimed at improving accuracy."},
69
+ "SEResNet50": {"path": "se_resnet50.pth", "type": "se_resnet50", "description": "SEResNet50 is a ResNet model with squeeze-and-excitation blocks for better feature recalibration."},
70
+ "SEResNeXt50_32X4D": {"path": "se_resnext50_32x4d.pth", "type": "se_resnext50_32x4d", "description": "SEResNeXt50_32X4D combines ResNeXt and SE blocks for improved performance."},
71
+ "SegFormerB2": {"path": "segformer.pth", "type": "segformer_b2", "description": "SegFormerB2 is a transformer-based model for semantic segmentation."},
72
+ "InceptionResNetV2": {"path": "inceptionresnetv2.pth", "type": "inceptionresnetv2", "description": "InceptionResNetV2 is a hybrid model combining Inception and ResNet architectures."},
73
  }
74
 
75
  # Streamlit app
 
78
  st.title("Landslide Detection")
79
  st.markdown("""
80
  ## Instructions
81
+ 1. Select a model from the sidebar or choose to run all models.
82
  2. Upload one or more `.h5` files.
83
  3. The app will process the files and display the input image, prediction, and overlay.
84
  4. You can download the prediction results.
 
86
 
87
  # Sidebar for model selection
88
  st.sidebar.title("Model Selection")
89
+ model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])
90
+ if model_option == "Select a single model":
91
+ model_type = st.sidebar.selectbox("Select Model", list(model_descriptions.keys()))
92
+ config['model_config']['model_type'] = model_descriptions[model_type]['type']
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  if model_type == "DeepLabV3+":
94
  model_class = DeepLabV3PlusModel
95
  else:
96
  model_class = locals()[model_type.replace("-", "") + "Model"]
97
+ model_path = model_descriptions[model_type]['path']
98
+
99
+ # Display model details in the sidebar
100
+ st.sidebar.markdown(f"**Model Type:** {model_descriptions[model_type]['type']}")
101
+ st.sidebar.markdown(f"**Model Path:** {model_descriptions[model_type]['path']}")
102
+ st.sidebar.markdown(f"**Description:** {model_descriptions[model_type]['description']}")
103
+
104
+ # Main content
105
+ st.header("Upload Data")
106
+ uploaded_files = st.file_uploader("Choose .h5 files...", type="h5", accept_multiple_files=True)
107
+ if uploaded_files:
108
+ for uploaded_file in uploaded_files:
109
+ st.write(f"Processing file: {uploaded_file.name}")
110
+ with st.spinner('Classifying...'):
111
+ with h5py.File(uploaded_file, 'r') as hdf:
112
+ data = np.array(hdf.get('img'))
113
+ data[np.isnan(data)] = 0.000001
114
+ channels = config["dataset_config"]["channels"]
115
+ image = np.zeros((128, 128, len(channels)))
116
+ for i, channel in enumerate(channels):
117
+ image[:, :, i] = data[:, :, channel-1]
118
+
119
+ # Transform the image to the required format
120
+ image = image.transpose((2, 0, 1)) # (H, W, C) to (C, H, W)
121
+ image = torch.from_numpy(image).float().unsqueeze(0) # Add batch dimension
122
+
123
+ if model_option == "Select a single model":
124
+ # Process the image with the selected model
125
+ st.write(f"Using model: {model_type}")
126
+
127
+ # Load the model
128
+ model = model_class(config)
129
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
130
+ model.eval()
131
+
132
+ # Make prediction
133
+ with torch.no_grad():
134
+ prediction = model(image)
135
+ prediction = torch.sigmoid(prediction).cpu().numpy()
136
+
137
+ # Display prediction
138
+ st.header(f"Prediction Results - {model_type}")
139
+ fig, ax = plt.subplots(1, 3, figsize=(15, 5))
140
+ img = image.squeeze().permute(1, 2, 0).numpy()
141
+ img = (img - img.min()) / (img.max() - img.min()) # Normalize the image to [0, 1] range for display
142
+ ax[0].imshow(img[:, :, 1:4]) # Display first three channels as RGB
143
+ ax[0].set_title("Input Image")
144
+ ax[1].imshow(prediction.squeeze() > 0.5, cmap='plasma') # Apply threshold
145
+ ax[1].set_title("Prediction")
146
+ ax[2].imshow(img[:, :, :3]) # Display first three channels as RGB
147
+ ax[2].imshow(prediction.squeeze() > 0.5, cmap='plasma', alpha=0.3) # Overlay prediction
148
+ ax[2].set_title("Overlay")
149
+ st.pyplot(fig)
150
+
151
+ # Option to download the prediction
152
+ st.write(f"Download the prediction as a .npy file for {model_type}:")
153
+ npy_data = prediction.squeeze()
154
+ st.download_button(
155
+ label=f"Download Prediction - {model_type}",
156
+ data=npy_data.tobytes(),
157
+ file_name=f"{uploaded_file.name.split('.')[0]}_{model_type}_prediction.npy",
158
+ mime="application/octet-stream"
159
+ )
160
+
161
+ else:
162
+ # Process the image with each model
163
+ for model_name, model_info in model_descriptions.items():
164
+ st.write(f"Using model: {model_name}")
165
+ if model_name == "DeepLabV3+":
166
+ model_class = DeepLabV3PlusModel
167
+ else:
168
+ model_class = locals()[model_name.replace("-", "") + "Model"]
169
+ model_path = model_info['path']
170
+ config['model_config']['model_type'] = model_info['type']
171
+
172
+ # Load the model
173
+ model = model_class(config)
174
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
175
+ model.eval()
176
+
177
+ # Make prediction
178
+ with torch.no_grad():
179
+ prediction = model(image)
180
+ prediction = torch.sigmoid(prediction).cpu().numpy()
181
+
182
+ # Display prediction
183
+ st.header(f"Prediction Results - {model_name}")
184
+ fig, ax = plt.subplots(1, 3, figsize=(15, 5))
185
+ img = image.squeeze().permute(1, 2, 0).numpy()
186
+ img = (img - img.min()) / (img.max() - img.min()) # Normalize the image to [0, 1] range for display
187
+ ax[0].imshow(img[:, :, :3]) # Display first three channels as RGB
188
+ ax[0].set_title("Input Image")
189
+ ax[1].imshow(prediction.squeeze() > 0.5, cmap='plasma') # Apply threshold
190
+ ax[1].set_title("Prediction")
191
+ ax[2].imshow(img[:, :, :3]) # Display first three channels as RGB
192
+ ax[2].imshow(prediction.squeeze() > 0.5, cmap='plasma', alpha=0.3) # Overlay prediction
193
+ ax[2].set_title("Overlay")
194
+ st.pyplot(fig)
195
+
196
+ # Option to download the prediction
197
+ st.write(f"Download the prediction as a .npy file for {model_name}:")
198
+ npy_data = prediction.squeeze()
199
+ st.download_button(
200
+ label=f"Download Prediction - {model_name}",
201
+ data=npy_data.tobytes(),
202
+ file_name=f"{uploaded_file.name.split('.')[0]}_{model_name}_prediction.npy",
203
+ mime="application/octet-stream"
204
+ )
205
+
206
+ st.success('Done!')