Spaces:
Sleeping
Sleeping
Update tasks/audio.py
Browse files- tasks/audio.py +1 -7
tasks/audio.py
CHANGED
|
@@ -228,13 +228,8 @@ async def evaluate_audio(request: AudioEvaluationRequest):
|
|
| 228 |
int8_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') # Optional if not defined in saved model
|
| 229 |
|
| 230 |
# Load the state dictionary
|
| 231 |
-
int8_model.load_state_dict(torch.load(quantized_model_path, weights_only=True))
|
| 232 |
int8_model.eval() # Set to evaluation mode
|
| 233 |
-
|
| 234 |
-
# Move model to device
|
| 235 |
-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 236 |
-
int8_model.to(device)
|
| 237 |
-
|
| 238 |
#model.load_state_dict(torch.load("./best_blazeface_model_second.pth", map_location=torch.device('cpu'), weights_only=True))
|
| 239 |
|
| 240 |
#model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
|
@@ -250,7 +245,6 @@ async def evaluate_audio(request: AudioEvaluationRequest):
|
|
| 250 |
predictions = []
|
| 251 |
with torch.inference_mode():
|
| 252 |
for data, target in test_loader:
|
| 253 |
-
waveforms, labels = waveforms.to(device), labels.to(device)
|
| 254 |
outputs = int8_model(waveforms)
|
| 255 |
_, predicted = torch.max(outputs, 1)
|
| 256 |
predictions.extend(predicted.tolist())
|
|
|
|
| 228 |
int8_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') # Optional if not defined in saved model
|
| 229 |
|
| 230 |
# Load the state dictionary
|
| 231 |
+
int8_model.load_state_dict(torch.load(quantized_model_path, map_location=torch.device('cpu'), weights_only=True))
|
| 232 |
int8_model.eval() # Set to evaluation mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
#model.load_state_dict(torch.load("./best_blazeface_model_second.pth", map_location=torch.device('cpu'), weights_only=True))
|
| 234 |
|
| 235 |
#model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
|
|
|
|
| 245 |
predictions = []
|
| 246 |
with torch.inference_mode():
|
| 247 |
for data, target in test_loader:
|
|
|
|
| 248 |
outputs = int8_model(waveforms)
|
| 249 |
_, predicted = torch.max(outputs, 1)
|
| 250 |
predictions.extend(predicted.tolist())
|