Commit
·
db7d3b8
1
Parent(s):
695e3cb
Update
Browse files- resnet_model/configuration_resnet.py +1 -1
- run.py +8 -0
- test.py +6 -0
resnet_model/configuration_resnet.py
CHANGED
|
@@ -20,7 +20,7 @@ class ResnetConfig(PretrainedConfig):
|
|
| 20 |
Defining a model_type for your configuration (here model_type="resnet") is not mandatory,
|
| 21 |
unless you want to register your model with the auto classes (see last section)."""
|
| 22 |
|
| 23 |
-
model_type = "resnet"
|
| 24 |
|
| 25 |
def __init__(
|
| 26 |
self,
|
|
|
|
| 20 |
Defining a model_type for your configuration (here model_type="resnet") is not mandatory,
|
| 21 |
unless you want to register your model with the auto classes (see last section)."""
|
| 22 |
|
| 23 |
+
model_type = "rgbdsod-resnet"
|
| 24 |
|
| 25 |
def __init__(
|
| 26 |
self,
|
run.py
CHANGED
|
@@ -2,11 +2,19 @@ import timm
|
|
| 2 |
|
| 3 |
from resnet_model.configuration_resnet import ResnetConfig
|
| 4 |
from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassification
|
|
|
|
| 5 |
|
| 6 |
ResnetConfig.register_for_auto_class()
|
| 7 |
ResnetModel.register_for_auto_class("AutoModel")
|
| 8 |
ResnetModelForImageClassification.register_for_auto_class("AutoModel")
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
resnet50d_config = ResnetConfig(
|
| 11 |
block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True
|
| 12 |
)
|
|
|
|
| 2 |
|
| 3 |
from resnet_model.configuration_resnet import ResnetConfig
|
| 4 |
from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassification
|
| 5 |
+
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification
|
| 6 |
|
| 7 |
ResnetConfig.register_for_auto_class()
|
| 8 |
ResnetModel.register_for_auto_class("AutoModel")
|
| 9 |
ResnetModelForImageClassification.register_for_auto_class("AutoModel")
|
| 10 |
|
| 11 |
+
|
| 12 |
+
# AutoConfig.register("rgbdsod-resnet", ResnetConfig)
|
| 13 |
+
# AutoModel.register(ResnetConfig, ResnetModel)
|
| 14 |
+
# AutoModelForImageClassification.register(
|
| 15 |
+
# ResnetConfig, ResnetModelForImageClassification
|
| 16 |
+
# )
|
| 17 |
+
|
| 18 |
resnet50d_config = ResnetConfig(
|
| 19 |
block_type="bottleneck", stem_width=32, stem_type="deep", avg_down=True
|
| 20 |
)
|
test.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModel
|
| 2 |
+
|
| 3 |
+
model = AutoModel.from_pretrained(
|
| 4 |
+
"RGBD-SOD/custom-resnet50d", trust_remote_code=True
|
| 5 |
+
)
|
| 6 |
+
print(model)
|