HoneyTian commited on
Commit
f9d6521
·
1 Parent(s): 555ea1a
examples/sound_classification_by_lstm/step_9_evaluation_onnx_model.py CHANGED
@@ -42,7 +42,8 @@ def get_args():
42
  parser.add_argument("--model_dir", default="best", type=str)
43
  parser.add_argument("--onnx_model_file", default="model.onnx", type=str)
44
  parser.add_argument("--output_file", default="evaluation_onnx.xlsx", type=str)
45
- parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=str)
 
46
 
47
  parser.add_argument("--max_count", default=10000, type=int)
48
 
@@ -78,7 +79,10 @@ def main():
78
  evaluation_file = Path(args.dataset)
79
 
80
  config = WaveClassifierConfig.from_pretrained(config_file.as_posix())
81
- ort_session = ort.InferenceSession(onnx_model_file.as_posix())
 
 
 
82
  vocabulary = Vocabulary.from_files(vocab_path.as_posix())
83
 
84
  # transform
 
42
  parser.add_argument("--model_dir", default="best", type=str)
43
  parser.add_argument("--onnx_model_file", default="model.onnx", type=str)
44
  parser.add_argument("--output_file", default="evaluation_onnx.xlsx", type=str)
45
+ # parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=str)
46
+ parser.add_argument("--device", default="cpu", type=str)
47
 
48
  parser.add_argument("--max_count", default=10000, type=int)
49
 
 
79
  evaluation_file = Path(args.dataset)
80
 
81
  config = WaveClassifierConfig.from_pretrained(config_file.as_posix())
82
+ ort_session = ort.InferenceSession(
83
+ onnx_model_file.as_posix(),
84
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
85
+ )
86
  vocabulary = Vocabulary.from_files(vocab_path.as_posix())
87
 
88
  # transform
requirements.txt CHANGED
@@ -12,8 +12,8 @@ evaluate
12
  gradio
13
  python-dotenv
14
  numpy
15
- onnxruntime
16
- scipy
17
  onnx
18
  onnxruntime
 
 
19
  tenacity
 
12
  gradio
13
  python-dotenv
14
  numpy
 
 
15
  onnx
16
  onnxruntime
17
+ onnxruntime-gpu
18
+ scipy
19
  tenacity