|
@@ -238,65 +238,76 @@ class ResNet(SgModule):
|
|
self.linear = nn.Linear(width_multiplier(512, self.width_mult) * self.expansion, new_num_classes)
|
|
self.linear = nn.Linear(width_multiplier(512, self.width_mult) * self.expansion, new_num_classes)
|
|
|
|
|
|
|
|
|
|
-def ResNet18(arch_params, num_classes=None):
|
|
|
|
- return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
- droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
- backbone_mode=get_param(arch_params, "backbone_mode", False))
|
|
|
|
|
|
+class ResNet18(ResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
+ droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
+ backbone_mode=get_param(arch_params, "backbone_mode", False))
|
|
|
|
|
|
|
|
|
|
-def ResNet18Cifar(arch_params, num_classes=None):
|
|
|
|
- return CifarResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes)
|
|
|
|
|
|
+class ResNet18Cifar(CifarResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=num_classes or arch_params.num_classes)
|
|
|
|
|
|
|
|
|
|
-def ResNet34(arch_params, num_classes=None):
|
|
|
|
- return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
- droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
- backbone_mode=get_param(arch_params, "backbone_mode", False))
|
|
|
|
|
|
+class ResNet34(ResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(BasicBlock, [3, 4, 6, 3], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
+ droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
+ backbone_mode=get_param(arch_params, "backbone_mode", False))
|
|
|
|
|
|
|
|
|
|
-def ResNet50(arch_params, num_classes=None):
|
|
|
|
- return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
- droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
- backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
|
|
|
|
|
|
+class ResNet50(ResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(Bottleneck, [3, 4, 6, 3], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
+ droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
+ backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
|
|
|
|
|
|
|
|
|
|
-def ResNet50_3343(arch_params, num_classes=None):
|
|
|
|
- return ResNet(Bottleneck, [3, 3, 4, 3], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
- droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
- backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
|
|
|
|
|
|
+class ResNet50_3343(ResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(Bottleneck, [3, 3, 4, 3], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
+ droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
+ backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
|
|
|
|
|
|
|
|
|
|
-def ResNet101(arch_params, num_classes=None):
|
|
|
|
- return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
- droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
- backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
|
|
|
|
|
|
+class ResNet101(ResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(Bottleneck, [3, 4, 23, 3], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
+ droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
+ backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
|
|
|
|
|
|
|
|
|
|
-def ResNet152(arch_params, num_classes=None):
|
|
|
|
- return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
- droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
- backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
|
|
|
|
|
|
+class ResNet152(ResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(Bottleneck, [3, 8, 36, 3], num_classes=num_classes or arch_params.num_classes,
|
|
|
|
+ droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
+ backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
|
|
|
|
|
|
|
|
|
|
-def CustomizedResnetCifar(arch_params, num_classes=None):
|
|
|
|
- return CifarResNet(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult,
|
|
|
|
- num_classes=num_classes or arch_params.num_classes)
|
|
|
|
|
|
+class CustomizedResnetCifar(CifarResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult,
|
|
|
|
+ num_classes=num_classes or arch_params.num_classes)
|
|
|
|
|
|
|
|
|
|
-def CustomizedResnet50Cifar(arch_params, num_classes=None):
|
|
|
|
- return CifarResNet(Bottleneck, arch_params.structure, width_mult=arch_params.width_mult,
|
|
|
|
- num_classes=num_classes or arch_params.num_classes, expansion=4)
|
|
|
|
|
|
+class CustomizedResnet50Cifar(CifarResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(Bottleneck, arch_params.structure, width_mult=arch_params.width_mult,
|
|
|
|
+ num_classes=num_classes or arch_params.num_classes, expansion=4)
|
|
|
|
|
|
|
|
|
|
-def CustomizedResnet(arch_params, num_classes=None):
|
|
|
|
- return ResNet(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult,
|
|
|
|
- num_classes=num_classes or arch_params.num_classes,
|
|
|
|
- droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
- backbone_mode=get_param(arch_params, "backbone_mode", False))
|
|
|
|
|
|
+class CustomizedResnet(ResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(BasicBlock, arch_params.structure, width_mult=arch_params.width_mult,
|
|
|
|
+ num_classes=num_classes or arch_params.num_classes,
|
|
|
|
+ droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
+ backbone_mode=get_param(arch_params, "backbone_mode", False))
|
|
|
|
|
|
|
|
|
|
-def CustomizedResnet50(arch_params, num_classes=None):
|
|
|
|
- return ResNet(Bottleneck, arch_params.structure, width_mult=arch_params.width_mult,
|
|
|
|
- num_classes=num_classes or arch_params.num_classes,
|
|
|
|
- droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
- backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
|
|
|
|
|
|
+class CustomizedResnet50(ResNet):
|
|
|
|
+ def __init__(self, arch_params, num_classes=None):
|
|
|
|
+ super().__init__(Bottleneck, arch_params.structure, width_mult=arch_params.width_mult,
|
|
|
|
+ num_classes=num_classes or arch_params.num_classes,
|
|
|
|
+ droppath_prob=get_param(arch_params, "droppath_prob", 0),
|
|
|
|
+ backbone_mode=get_param(arch_params, "backbone_mode", False), expansion=4)
|