AB739 commited on
Commit
0fe12cf
·
verified ·
1 Parent(s): deeac22

Update tasks/model.py

Browse files
Files changed (1) hide show
  1. tasks/model.py +17 -1
tasks/model.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn.functional as F
4
  from torch.utils.data import DataLoader, TensorDataset
5
  from torchaudio import transforms
6
  from torchvision import models
 
7
 
8
 
9
 
@@ -104,4 +105,19 @@ class BlazeFaceModel(nn.Module):
104
  def forward(self, x):
105
  features = self.blazeface_backbone(x)
106
  output = self.fc(features)
107
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from torch.utils.data import DataLoader, TensorDataset
5
  from torchaudio import transforms
6
  from torchvision import models
7
+ from torch.quantization import QuantStub, DeQuantStub
8
 
9
 
10
 
 
105
  def forward(self, x):
106
  features = self.blazeface_backbone(x)
107
  output = self.fc(features)
108
+ return output
109
+
110
+ class QuantizedBlazeFaceModel(nn.Module):
111
+ def __init__(self, model_fp32):
112
+ super(QuantizedBlazeFaceModel, self).__init__()
113
+ self.quant = QuantStub()
114
+ self.dequant = DeQuantStub()
115
+ self.backbone = model_fp32.blazeface_backbone
116
+ self.fc = model_fp32.fc
117
+
118
+ def forward(self, x):
119
+ x = self.quant(x)
120
+ x = self.backbone(x)
121
+ x = self.fc(x)
122
+ x = self.dequant(x)
123
+ return x