|
@@ -76,16 +76,21 @@ Train YOLO11n-cls on the MNIST160 dataset for 100 [epochs](https://www.ultralyti
|
|
|
|
|
|
Ultralytics YOLO classification uses [torchvision.transforms.RandomResizedCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.RandomResizedCrop.html) for training augmentation and [torchvision.transforms.CenterCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.CenterCrop.html) for validation/inference.
|
|
Ultralytics YOLO classification uses [torchvision.transforms.RandomResizedCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.RandomResizedCrop.html) for training augmentation and [torchvision.transforms.CenterCrop](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.CenterCrop.html) for validation/inference.
|
|
For images with extreme aspect ratios, consider using [torchvision.transforms.Resize](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) instead. The example below shows how to customize augmentations for classification training.
|
|
For images with extreme aspect ratios, consider using [torchvision.transforms.Resize](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) instead. The example below shows how to customize augmentations for classification training.
|
|
|
|
+
|
|
```python
|
|
```python
|
|
import torch
|
|
import torch
|
|
import torchvision.transforms as T
|
|
import torchvision.transforms as T
|
|
|
|
|
|
|
|
+ from ultralytics import YOLO
|
|
from ultralytics.data.dataset import ClassificationDataset
|
|
from ultralytics.data.dataset import ClassificationDataset
|
|
from ultralytics.models.yolo.classify import ClassificationTrainer
|
|
from ultralytics.models.yolo.classify import ClassificationTrainer
|
|
|
|
|
|
|
|
|
|
class CustomizedDataset(ClassificationDataset):
|
|
class CustomizedDataset(ClassificationDataset):
|
|
|
|
+ """A customized dataset class for image classification with enhanced data augmentation transforms."""
|
|
|
|
+
|
|
def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
|
|
def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
|
|
|
|
+ """Initialize a customized classification dataset with enhanced data augmentation transforms."""
|
|
super().__init__(root, args, augment, prefix)
|
|
super().__init__(root, args, augment, prefix)
|
|
train_transforms = T.Compose(
|
|
train_transforms = T.Compose(
|
|
[
|
|
[
|
|
@@ -110,12 +115,13 @@ Train YOLO11n-cls on the MNIST160 dataset for 100 [epochs](https://www.ultralyti
|
|
|
|
|
|
|
|
|
|
class CustomizedTrainer(ClassificationTrainer):
|
|
class CustomizedTrainer(ClassificationTrainer):
|
|
|
|
+ """A customized trainer class for YOLO classification models with enhanced dataset handling."""
|
|
|
|
+
|
|
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
|
|
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
|
|
|
|
+ """Build a customized dataset for classification training or validation."""
|
|
return CustomizedDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
|
|
return CustomizedDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
|
|
|
|
|
|
|
|
|
|
- from ultralytics import YOLO
|
|
|
|
-
|
|
|
|
model = YOLO("yolo11n-cls.pt")
|
|
model = YOLO("yolo11n-cls.pt")
|
|
model.train(data="imagenet1000", trainer=CustomizedTrainer, epochs=10, imgsz=224, batch=64)
|
|
model.train(data="imagenet1000", trainer=CustomizedTrainer, epochs=10, imgsz=224, batch=64)
|
|
```
|
|
```
|