AB739 committed on
Commit
deeac22
·
verified ·
1 Parent(s): bd4f556

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +4 -136
tasks/audio.py CHANGED
@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader, TensorDataset
10
  from torchaudio import transforms
11
  from torchvision import models
12
 
 
13
  from .utils.evaluation import AudioEvaluationRequest
14
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
15
 
@@ -87,145 +88,12 @@ async def evaluate_audio(request: AudioEvaluationRequest):
87
  'spectrogram_length': 64,
88
  'dct_coefficient_count': 481,
89
  'label_count': 2
90
- }
91
-
92
- # Create model
93
- #model = BlazeFaceModel(input_channels=1, label_count=model_settings['label_count'], use_double_block=False, activation='relu', use_optional_block=False)
94
- from torch.quantization import QuantStub, DeQuantStub
95
- class BlazeFace(nn.Module):
96
- def __init__(self, input_channels=1, use_double_block=False, activation="relu", use_optional_block=True):
97
- super(BlazeFace, self).__init__()
98
- self.activation = activation
99
- self.use_double_block = use_double_block
100
- self.use_optional_block = use_optional_block
101
-
102
- def conv_block(in_channels, out_channels, kernel_size, stride, padding):
103
- return nn.Sequential(
104
- nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
105
- nn.BatchNorm2d(out_channels),
106
- nn.ReLU() if activation == "relu" else nn.Sigmoid() # Apply ReLU activation (default) or Sigmoid
107
- )
108
-
109
- def depthwise_separable_block(in_channels, out_channels, stride):
110
- return nn.Sequential(
111
- nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=stride, padding=2, groups=in_channels, bias=False),
112
- nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
113
- nn.BatchNorm2d(out_channels),
114
- nn.ReLU() if activation == "relu" else nn.Sigmoid()
115
- )
116
-
117
- def double_block(in_channels, filters_1, filters_2, stride):
118
- return nn.Sequential(
119
- depthwise_separable_block(in_channels, filters_1, stride),
120
- depthwise_separable_block(filters_1, filters_2, 1)
121
- )
122
-
123
- # Define layers (first part: conv layers)
124
- self.conv1 = conv_block(input_channels, 24, kernel_size=5, stride=2, padding=2)
125
-
126
- # Define single blocks (subsequent conv blocks)
127
- self.single_blocks = nn.ModuleList([
128
- depthwise_separable_block(24, 24, stride=1),
129
- depthwise_separable_block(24, 24, stride=1),
130
- depthwise_separable_block(24, 48, stride=2),
131
- depthwise_separable_block(48, 48, stride=1),
132
- depthwise_separable_block(48, 48, stride=1)
133
- ])
134
-
135
- # Define double blocks if `use_double_block` is True
136
- if self.use_double_block:
137
- self.double_blocks = nn.ModuleList([
138
- double_block(48, 24, 96, stride=2),
139
- double_block(96, 24, 96, stride=1),
140
- double_block(96, 24, 96, stride=2),
141
- double_block(96, 24, 96, stride=1),
142
- double_block(96, 24, 96, stride=2)
143
- ])
144
- else:
145
- self.double_blocks = nn.ModuleList([
146
- depthwise_separable_block(48, 96, stride=2),
147
- depthwise_separable_block(96, 96, stride=1),
148
- depthwise_separable_block(96, 96, stride=2),
149
- depthwise_separable_block(96, 96, stride=1),
150
- depthwise_separable_block(96, 96, stride=2)
151
- ])
152
-
153
- # Final convolutional head
154
- self.conv_head = nn.Conv2d(96, 64, kernel_size=1, stride=1)
155
- self.bn_head = nn.BatchNorm2d(64)
156
-
157
- # Global Average Pooling
158
- self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
159
-
160
- def forward(self, x):
161
- # First conv layer
162
- x = self.conv1(x)
163
-
164
- # Apply single blocks
165
- for block in self.single_blocks:
166
- x = block(x)
167
-
168
- # Apply double blocks
169
- for block in self.double_blocks:
170
- x = block(x)
171
-
172
- # Final head
173
- x = self.conv_head(x)
174
- x = self.bn_head(x)
175
- x = F.relu(x)
176
-
177
- # Global Average Pooling and Flatten
178
- x = self.global_avg_pooling(x)
179
- x = torch.flatten(x, 1)
180
-
181
- return x
182
-
183
- class BlazeFaceModel(nn.Module):
184
- def __init__(self, input_channels, label_count, use_double_block=False, activation="relu", use_optional_block=True):
185
- super(BlazeFaceModel, self).__init__()
186
- self.blazeface_backbone = BlazeFace(input_channels=input_channels, use_double_block=use_double_block, activation=activation, use_optional_block=use_optional_block)
187
- self.fc = nn.Linear(64, label_count)
188
-
189
- def forward(self, x):
190
- features = self.blazeface_backbone(x)
191
- output = self.fc(features)
192
- return output
193
-
194
- # Example Usage
195
- model_settings = {
196
- 'spectrogram_length': 64,
197
- 'dct_coefficient_count': 481,
198
- 'label_count': 2
199
- }
200
-
201
-
202
-
203
- # Define a quantized BlazeFace model
204
- class QuantizedBlazeFaceModel(nn.Module):
205
- def __init__(self, model_fp32):
206
- super(QuantizedBlazeFaceModel, self).__init__()
207
- self.quant = QuantStub()
208
- self.dequant = DeQuantStub()
209
- self.backbone = model_fp32.blazeface_backbone
210
- self.fc = model_fp32.fc
211
-
212
- def forward(self, x):
213
- x = self.quant(x)
214
- x = self.backbone(x)
215
- x = self.fc(x)
216
- x = self.dequant(x)
217
- return x
218
-
219
- # Load the trained model
220
- model_settings = {
221
- 'label_count': 2
222
- }
223
-
224
- model_fp32 = BlazeFaceModel(input_channels=1, label_count=2) # Assume label_count is 2
225
  quantized_model_path = "./qat_int8_blazeface_model.pth"
226
 
227
  int8_model = QuantizedBlazeFaceModel(model_fp32)
228
- int8_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') # Optional if not defined in saved model
229
 
230
  # Load the state dictionary
231
  int8_model.load_state_dict(torch.load(quantized_model_path, map_location=torch.device('cpu'), weights_only=True))
 
10
  from torchaudio import transforms
11
  from torchvision import models
12
 
13
+ from .model import BlazeFaceModel, QuantizedBlazeFaceModel
14
  from .utils.evaluation import AudioEvaluationRequest
15
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
16
 
 
88
  'spectrogram_length': 64,
89
  'dct_coefficient_count': 481,
90
  'label_count': 2
91
+ }
92
+ model = BlazeFaceModel(input_channels=1, label_count=model_settings['label_count'], use_double_block=False, activation='relu', use_optional_block=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  quantized_model_path = "./qat_int8_blazeface_model.pth"
94
 
95
  int8_model = QuantizedBlazeFaceModel(model_fp32)
96
+ int8_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
97
 
98
  # Load the state dictionary
99
  int8_model.load_state_dict(torch.load(quantized_model_path, map_location=torch.device('cpu'), weights_only=True))