| import torch | |
class Params:
    """Bundle of configuration values stored as plain instance attributes.

    Every instance is created with the same defaults; callers may overwrite
    any attribute after construction. Two instances compare equal when all
    of their attributes match.
    """

    def __init__(self):
        """Populate the instance with its default values."""
        # NOTE(review): the attribute names suggest training hyperparameters
        # (batching, optimizer, LR schedule) -- confirm against the code that
        # consumes this object.
        self.batch_size = 148
        self.name = "resnet_50"
        self.num_workers = 48
        self.lr = 0.165
        self.momentum = 0.9
        self.weight_decay = 1e-4
        self.lr_step_size = 30
        self.lr_gamma = 0.1
        self.num_epochs = 50

    def __repr__(self):
        """Return the string form of the instance's attribute dictionary."""
        return str(self.__dict__)

    def __eq__(self, other):
        """Return True when *other* is a ``Params`` with identical attributes.

        Returns ``NotImplemented`` for any other type so that Python falls
        back to its default comparison (yielding ``False`` for ``==``)
        instead of raising ``AttributeError`` on objects without a
        ``__dict__`` such as ``None`` or ``int``.
        """
        if not isinstance(other, Params):
            return NotImplemented
        return self.__dict__ == other.__dict__

    # Instances are mutable and define __eq__, so they are deliberately
    # unhashable. Python already sets this implicitly; it is spelled out
    # here so the intent is visible.
    __hash__ = None
def get_device():
    """Return the name of the preferred available torch backend.

    Backends are probed in priority order: ``"cuda"`` first, then ``"mps"``;
    ``"cpu"`` is returned when neither reports itself as available.
    """
    if torch.cuda.is_available():
        return "cuda"
    # The MPS probe runs only when CUDA is unavailable, matching the
    # short-circuit order of the original conditional expression.
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"