"""Build a resnet50d-style custom ResNet, load timm's pretrained weights, and push it to the Hub.

The custom config and model classes are registered with the ``transformers``
Auto API, so the pushed repository carries its modeling code. It can then be
loaded through the Auto classes with ``trust_remote_code=True``.
"""
import timm

from resnet_model.configuration_resnet import ResnetConfig
from resnet_model.modeling_resnet import ResnetModel, ResnetModelForImageClassification
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification

# Register the custom classes for the Auto API so their source is saved next to
# the weights and an `auto_map` entry is written into the config.
# Map each class to the Auto class that matches its role. The backbone goes to
# AutoModel. The classification head goes to AutoModelForImageClassification.
# Registering the head under "AutoModel" would point AutoModel at the
# classifier and leave AutoModelForImageClassification unresolvable.
ResnetConfig.register_for_auto_class()
ResnetModel.register_for_auto_class("AutoModel")
ResnetModelForImageClassification.register_for_auto_class(
    "AutoModelForImageClassification"
)

# Alternative: register the classes only for the current process instead of
# exporting code with the checkpoint. Kept for reference.
# AutoConfig.register("rgbdsod-resnet", ResnetConfig)
# AutoModel.register(ResnetConfig, ResnetModel)
# AutoModelForImageClassification.register(
#     ResnetConfig, ResnetModelForImageClassification
# )

# Hyper-parameters selecting the "resnet50d" variant: bottleneck blocks, a deep
# stem of width 32, and average-pool downsampling in the shortcut.
resnet50d_config = ResnetConfig(
    block_type="bottleneck",
    stem_width=32,
    stem_type="deep",
    avg_down=True,
)
resnet50d = ResnetModelForImageClassification(resnet50d_config)

# Copy timm's pretrained resnet50d weights into the custom model.
# NOTE(review): this assumes the classification wrapper exposes the timm-compatible
# backbone at `.model.model` (the head wraps ResnetModel, which wraps the timm
# network). Confirm this against resnet_model/modeling_resnet.py.
# load_state_dict is strict by default, so any key mismatch raises rather than
# silently leaving weights uninitialized.
pretrained_model = timm.create_model("resnet50d", pretrained=True)
resnet50d.model.model.load_state_dict(pretrained_model.state_dict())

# Upload config, weights and (because of the registration above) modeling code.
resnet50d.push_to_hub("RGBD-SOD/custom-resnet50d")