HoneyTian commited on
Commit
a064312
·
1 Parent(s): 58d9724
examples/sound_classification_by_lstm/step_6_export_onnx_model.py CHANGED
@@ -30,6 +30,18 @@ def get_args():
30
 
31
  parser.add_argument("--serialization_dir", default="file_dir/best", type=str)
32
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  args = parser.parse_args()
34
  return args
35
 
@@ -102,13 +114,17 @@ def main():
102
  "logits", "new_h", "new_c",
103
  ],
104
  dynamic_axes={
105
- "inputs": {0: "batch_size", 1: "time_steps"},
106
- "h": {1: "batch_size"},
107
- "c": {1: "batch_size"},
108
- "logits": {0: "batch_size"},
109
- "new_h": {1: "batch_size"},
110
- "new_c": {1: "batch_size"},
111
- })
 
 
 
 
112
 
113
  ort_session = ort.InferenceSession(output_file.as_posix())
114
  input_feed = {
 
30
 
31
  parser.add_argument("--serialization_dir", default="file_dir/best", type=str)
32
 
33
+ # parser.add_argument(
34
+ # "--vocabulary_dir",
35
+ # default=r"D:\Users\tianx\HuggingSpaces\cc_audio_8\trained_models\voicemail-zh-tw-2-ch64-lstm\voicemail-zh-tw-2-ch64-lstm\vocabulary",
36
+ # type=str
37
+ # )
38
+ # parser.add_argument(
39
+ # "--model_dir",
40
+ # default=r"D:\Users\tianx\HuggingSpaces\cc_audio_8\trained_models\voicemail-zh-tw-2-ch64-lstm\voicemail-zh-tw-2-ch64-lstm",
41
+ # type=str
42
+ # )
43
+ # parser.add_argument("--serialization_dir", default="./", type=str)
44
+
45
  args = parser.parse_args()
46
  return args
47
 
 
114
  "logits", "new_h", "new_c",
115
  ],
116
  dynamic_axes={
117
+ "inputs": {1: "time_steps"},
118
+ }
119
+ # dynamic_axes={
120
+ # "inputs": {0: "batch_size", 1: "time_steps"},
121
+ # "h": {1: "batch_size"},
122
+ # "c": {1: "batch_size"},
123
+ # "logits": {0: "batch_size"},
124
+ # "new_h": {1: "batch_size"},
125
+ # "new_c": {1: "batch_size"},
126
+ # }
127
+ )
128
 
129
  ort_session = ort.InferenceSession(output_file.as_posix())
130
  input_feed = {
toolbox/torchaudio/models/cnn_audio_classifier/modeling_cnn_audio_classifier.py CHANGED
@@ -308,12 +308,13 @@ class ClsHead(nn.Module):
308
  def forward(self, inputs: torch.Tensor):
309
  # inputs: [batch_size, seq_length, spec_dim]
310
  x = self.feedforward(inputs)
 
311
 
312
- # x: [batch_size, spec_dim]
313
  x = torch.mean(x, dim=1)
 
314
 
315
- # logits: [batch_size, num_labels]
316
  logits = self.output_project_layer.forward(x)
 
317
  return logits
318
 
319
 
 
308
  def forward(self, inputs: torch.Tensor):
309
  # inputs: [batch_size, seq_length, spec_dim]
310
  x = self.feedforward(inputs)
311
+ # x: [batch_size, seq_length, hidden_size]
312
 
 
313
  x = torch.mean(x, dim=1)
314
+ # x: [batch_size, hidden_size]
315
 
 
316
  logits = self.output_project_layer.forward(x)
317
+ # logits: [batch_size, num_labels]
318
  return logits
319
 
320