|
@@ -2,7 +2,9 @@ import unittest
|
|
|
|
|
|
import torch
|
|
import torch
|
|
|
|
|
|
|
|
+from super_gradients.training.losses import YoloXDetectionLoss, YoloXFastDetectionLoss
|
|
from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X
|
|
from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X
|
|
|
|
+from super_gradients.training.utils.detection_utils import DetectionCollateFN
|
|
from super_gradients.training.utils.utils import HpmStruct
|
|
from super_gradients.training.utils.utils import HpmStruct
|
|
|
|
|
|
|
|
|
|
@@ -39,6 +41,33 @@ class TestYOLOX(unittest.TestCase):
|
|
output_augment = yolo_model(dummy_input)
|
|
output_augment = yolo_model(dummy_input)
|
|
self.assertIsNotNone(output_augment)
|
|
self.assertIsNotNone(output_augment)
|
|
|
|
|
|
|
|
+ def test_yolox_loss(self):
|
|
|
|
+ samples = [
|
|
|
|
+ (torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
|
|
|
|
+ (torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
|
|
|
|
+ (torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
|
|
|
|
+ (torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
|
|
|
|
+ (torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
|
|
|
|
+ ]
|
|
|
|
+ collate = DetectionCollateFN()
|
|
|
|
+ _, targets = collate(samples)
|
|
|
|
+
|
|
|
|
+ predictions = [
|
|
|
|
+ torch.randn((5, 1, 256 // 8, 256 // 8, 4 + 1 + 10)),
|
|
|
|
+ torch.randn((5, 1, 256 // 16, 256 // 16, 4 + 1 + 10)),
|
|
|
|
+ torch.randn((5, 1, 256 // 32, 256 // 32, 4 + 1 + 10)),
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ for loss in [
|
|
|
|
+ YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True, iou_type="giou"),
|
|
|
|
+ YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True, iou_type="iou"),
|
|
|
|
+ YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=False),
|
|
|
|
+ YoloXFastDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True),
|
|
|
|
+ YoloXFastDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=False),
|
|
|
|
+ ]:
|
|
|
|
+ result = loss(predictions, targets)
|
|
|
|
+ print(result)
|
|
|
|
+
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
unittest.main()
|
|
unittest.main()
|