Spaces:
Running
Running
update
Browse files
examples/sound_classification_by_lstm/step_9_evaluation_onnx_model.py
CHANGED
|
@@ -42,7 +42,8 @@ def get_args():
|
|
| 42 |
parser.add_argument("--model_dir", default="best", type=str)
|
| 43 |
parser.add_argument("--onnx_model_file", default="model.onnx", type=str)
|
| 44 |
parser.add_argument("--output_file", default="evaluation_onnx.xlsx", type=str)
|
| 45 |
-
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=str)
|
|
|
|
| 46 |
|
| 47 |
parser.add_argument("--max_count", default=10000, type=int)
|
| 48 |
|
|
@@ -78,7 +79,10 @@ def main():
|
|
| 78 |
evaluation_file = Path(args.dataset)
|
| 79 |
|
| 80 |
config = WaveClassifierConfig.from_pretrained(config_file.as_posix())
|
| 81 |
-
ort_session = ort.InferenceSession(
|
|
|
|
|
|
|
|
|
|
| 82 |
vocabulary = Vocabulary.from_files(vocab_path.as_posix())
|
| 83 |
|
| 84 |
# transform
|
|
|
|
| 42 |
parser.add_argument("--model_dir", default="best", type=str)
|
| 43 |
parser.add_argument("--onnx_model_file", default="model.onnx", type=str)
|
| 44 |
parser.add_argument("--output_file", default="evaluation_onnx.xlsx", type=str)
|
| 45 |
+
# parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=str)
|
| 46 |
+
parser.add_argument("--device", default="cpu", type=str)
|
| 47 |
|
| 48 |
parser.add_argument("--max_count", default=10000, type=int)
|
| 49 |
|
|
|
|
| 79 |
evaluation_file = Path(args.dataset)
|
| 80 |
|
| 81 |
config = WaveClassifierConfig.from_pretrained(config_file.as_posix())
|
| 82 |
+
ort_session = ort.InferenceSession(
|
| 83 |
+
onnx_model_file.as_posix(),
|
| 84 |
+
providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
|
| 85 |
+
)
|
| 86 |
vocabulary = Vocabulary.from_files(vocab_path.as_posix())
|
| 87 |
|
| 88 |
# transform
|
requirements.txt
CHANGED
|
@@ -12,8 +12,8 @@ evaluate
|
|
| 12 |
gradio
|
| 13 |
python-dotenv
|
| 14 |
numpy
|
| 15 |
-
onnxruntime
|
| 16 |
-
scipy
|
| 17 |
onnx
|
| 18 |
onnxruntime
|
|
|
|
|
|
|
| 19 |
tenacity
|
|
|
|
| 12 |
gradio
|
| 13 |
python-dotenv
|
| 14 |
numpy
|
|
|
|
|
|
|
| 15 |
onnx
|
| 16 |
onnxruntime
|
| 17 |
+
onnxruntime-gpu
|
| 18 |
+
scipy
|
| 19 |
tenacity
|