AB739 committed on
Commit
bd4f556
·
verified ·
1 Parent(s): 1572cba

Update tasks/audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +1 -7
tasks/audio.py CHANGED
@@ -228,13 +228,8 @@ async def evaluate_audio(request: AudioEvaluationRequest):
228
  int8_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') # Optional if not defined in saved model
229
 
230
  # Load the state dictionary
231
- int8_model.load_state_dict(torch.load(quantized_model_path, weights_only=True))
232
  int8_model.eval() # Set to evaluation mode
233
-
234
- # Move model to device
235
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
236
- int8_model.to(device)
237
-
238
  #model.load_state_dict(torch.load("./best_blazeface_model_second.pth", map_location=torch.device('cpu'), weights_only=True))
239
 
240
  #model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
@@ -250,7 +245,6 @@ async def evaluate_audio(request: AudioEvaluationRequest):
250
  predictions = []
251
  with torch.inference_mode():
252
  for data, target in test_loader:
253
- waveforms, labels = waveforms.to(device), labels.to(device)
254
  outputs = int8_model(waveforms)
255
  _, predicted = torch.max(outputs, 1)
256
  predictions.extend(predicted.tolist())
 
228
  int8_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') # Optional if not defined in saved model
229
 
230
  # Load the state dictionary
231
+ int8_model.load_state_dict(torch.load(quantized_model_path, map_location=torch.device('cpu'), weights_only=True))
232
  int8_model.eval() # Set to evaluation mode
 
 
 
 
 
233
  #model.load_state_dict(torch.load("./best_blazeface_model_second.pth", map_location=torch.device('cpu'), weights_only=True))
234
 
235
  #model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
 
245
  predictions = []
246
  with torch.inference_mode():
247
  for data, target in test_loader:
 
248
  outputs = int8_model(waveforms)
249
  _, predicted = torch.max(outputs, 1)
250
  predictions.extend(predicted.tolist())