Spaces:
Running
Running
update
Browse files
examples/sound_classification_by_lstm/step_6_export_onnx_model.py
CHANGED
|
@@ -30,6 +30,18 @@ def get_args():
|
|
| 30 |
|
| 31 |
parser.add_argument("--serialization_dir", default="file_dir/best", type=str)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
args = parser.parse_args()
|
| 34 |
return args
|
| 35 |
|
|
@@ -102,13 +114,17 @@ def main():
|
|
| 102 |
"logits", "new_h", "new_c",
|
| 103 |
],
|
| 104 |
dynamic_axes={
|
| 105 |
-
"inputs": {
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
ort_session = ort.InferenceSession(output_file.as_posix())
|
| 114 |
input_feed = {
|
|
|
|
| 30 |
|
| 31 |
parser.add_argument("--serialization_dir", default="file_dir/best", type=str)
|
| 32 |
|
| 33 |
+
# parser.add_argument(
|
| 34 |
+
# "--vocabulary_dir",
|
| 35 |
+
# default=r"D:\Users\tianx\HuggingSpaces\cc_audio_8\trained_models\voicemail-zh-tw-2-ch64-lstm\voicemail-zh-tw-2-ch64-lstm\vocabulary",
|
| 36 |
+
# type=str
|
| 37 |
+
# )
|
| 38 |
+
# parser.add_argument(
|
| 39 |
+
# "--model_dir",
|
| 40 |
+
# default=r"D:\Users\tianx\HuggingSpaces\cc_audio_8\trained_models\voicemail-zh-tw-2-ch64-lstm\voicemail-zh-tw-2-ch64-lstm",
|
| 41 |
+
# type=str
|
| 42 |
+
# )
|
| 43 |
+
# parser.add_argument("--serialization_dir", default="./", type=str)
|
| 44 |
+
|
| 45 |
args = parser.parse_args()
|
| 46 |
return args
|
| 47 |
|
|
|
|
| 114 |
"logits", "new_h", "new_c",
|
| 115 |
],
|
| 116 |
dynamic_axes={
|
| 117 |
+
"inputs": {1: "time_steps"},
|
| 118 |
+
}
|
| 119 |
+
# dynamic_axes={
|
| 120 |
+
# "inputs": {0: "batch_size", 1: "time_steps"},
|
| 121 |
+
# "h": {1: "batch_size"},
|
| 122 |
+
# "c": {1: "batch_size"},
|
| 123 |
+
# "logits": {0: "batch_size"},
|
| 124 |
+
# "new_h": {1: "batch_size"},
|
| 125 |
+
# "new_c": {1: "batch_size"},
|
| 126 |
+
# }
|
| 127 |
+
)
|
| 128 |
|
| 129 |
ort_session = ort.InferenceSession(output_file.as_posix())
|
| 130 |
input_feed = {
|
toolbox/torchaudio/models/cnn_audio_classifier/modeling_cnn_audio_classifier.py
CHANGED
|
@@ -308,12 +308,13 @@ class ClsHead(nn.Module):
|
|
| 308 |
def forward(self, inputs: torch.Tensor):
|
| 309 |
# inputs: [batch_size, seq_length, spec_dim]
|
| 310 |
x = self.feedforward(inputs)
|
|
|
|
| 311 |
|
| 312 |
-
# x: [batch_size, spec_dim]
|
| 313 |
x = torch.mean(x, dim=1)
|
|
|
|
| 314 |
|
| 315 |
-
# logits: [batch_size, num_labels]
|
| 316 |
logits = self.output_project_layer.forward(x)
|
|
|
|
| 317 |
return logits
|
| 318 |
|
| 319 |
|
|
|
|
| 308 |
def forward(self, inputs: torch.Tensor):
|
| 309 |
# inputs: [batch_size, seq_length, spec_dim]
|
| 310 |
x = self.feedforward(inputs)
|
| 311 |
+
# x: [batch_size, seq_length, hidden_size]
|
| 312 |
|
|
|
|
| 313 |
x = torch.mean(x, dim=1)
|
| 314 |
+
# x: [batch_size, hidden_size]
|
| 315 |
|
|
|
|
| 316 |
logits = self.output_project_layer.forward(x)
|
| 317 |
+
# logits: [batch_size, num_labels]
|
| 318 |
return logits
|
| 319 |
|
| 320 |
|