@@ -21,7 +21,6 @@ from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TO
 
 DEFAULT_MEAN = (0.0, 0.0, 0.0)
 DEFAULT_STD = (1.0, 1.0, 1.0)
-DEFAULT_CROP_FRACTION = 1.0
 
 
 class BaseTransform:
@@ -2446,7 +2445,7 @@ def classify_transforms(
     mean=DEFAULT_MEAN,
     std=DEFAULT_STD,
     interpolation="BILINEAR",
-    crop_fraction: float = DEFAULT_CROP_FRACTION,
+    crop_fraction=None,
 ):
     """
     Creates a composition of image transforms for classification tasks.
@@ -2461,7 +2460,7 @@ def classify_transforms(
         mean (tuple): Mean values for each RGB channel used in normalization.
         std (tuple): Standard deviation values for each RGB channel used in normalization.
         interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.
-        crop_fraction (float): Fraction of the image to be cropped.
+        crop_fraction (float): Deprecated, will be removed in a future version.
 
     Returns:
         (torchvision.transforms.Compose): A composition of torchvision transforms.
@@ -2473,12 +2472,12 @@ def classify_transforms(
     """
     import torchvision.transforms as T  # scope for faster 'import ultralytics'
 
-    if isinstance(size, (tuple, list)):
-        assert len(size) == 2, f"'size' tuples must be length 2, not length {len(size)}"
-        scale_size = tuple(math.floor(x / crop_fraction) for x in size)
-    else:
-        scale_size = math.floor(size / crop_fraction)
-        scale_size = (scale_size, scale_size)
+    scale_size = size if isinstance(size, (tuple, list)) and len(size) == 2 else (size, size)
+
+    if crop_fraction:
+        raise DeprecationWarning(
+            "'crop_fraction' arg of classify_transforms is deprecated, will be removed in a future version."
+        )
 
     # Aspect ratio is preserved, crops center within image, no borders are added, image is lost
     if scale_size[0] == scale_size[1]:
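A worked example of the behavioral change in the hunk above (values hypothetical, not part of the patch): previously a non-unit `crop_fraction` inflated the resize target before the center crop; now the resize target is the requested size itself. With the old default `crop_fraction=1.0` the two paths were already identical, so default behavior is unchanged.

```python
import math

size, crop_fraction = 224, 0.875

# Old behavior: resize shortest edge to floor(224 / 0.875) = 256, then center-crop to 224
old_scale_size = math.floor(size / crop_fraction)
assert old_scale_size == 256

# New behavior: resize target equals the requested size; 224 -> (224, 224)
new_scale_size = (size, size)
assert new_scale_size == (224, 224)

# With the old default crop_fraction=1.0, floor(224 / 1.0) == 224, matching the new path
assert math.floor(size / 1.0) == size
```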
@@ -2487,13 +2486,7 @@ def classify_transforms(
     else:
         # Resize the shortest edge to matching target dim for non-square target
         tfl = [T.Resize(scale_size)]
-    tfl.extend(
-        [
-            T.CenterCrop(size),
-            T.ToTensor(),
-            T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
-        ]
-    )
+    tfl += [T.CenterCrop(size), T.ToTensor(), T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
     return T.Compose(tfl)
 
 
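For context, a minimal usage sketch of the changed API (not part of the patch). It assumes `size` remains the first parameter of `classify_transforms` (its default is not shown in these hunks) and that `torchvision`, `numpy`, and `Pillow` are installed. Note that `raise DeprecationWarning(...)` raises an exception (`DeprecationWarning` subclasses `Exception`), so a truthy `crop_fraction` now fails fast rather than rescaling the resize target.

```python
# Minimal sketch of the post-patch behavior; `size=224` as first parameter is an
# assumption not confirmed by these hunks.
import numpy as np
from PIL import Image

from ultralytics.data.augment import classify_transforms

img = Image.fromarray(np.zeros((320, 480, 3), dtype=np.uint8))  # dummy RGB image

# Default eval pipeline: Resize -> CenterCrop -> ToTensor -> Normalize
tf = classify_transforms(size=224)
print(tf(img).shape)  # torch.Size([3, 224, 224])

# Passing a truthy crop_fraction is now a hard error, since the warning class is raised
try:
    classify_transforms(size=224, crop_fraction=0.875)
except DeprecationWarning as err:
    print(err)
```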