Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#812 Fix YoloX loss to handle negative batch correctly

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-000-fix-yolox-loss-on-negative-batch
1 changed files with 29 additions and 0 deletions
  1. 29
    0
      tests/unit_tests/yolox_unit_test.py
@@ -2,7 +2,9 @@ import unittest
 
 
 import torch
 import torch
 
 
+from super_gradients.training.losses import YoloXDetectionLoss, YoloXFastDetectionLoss
 from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X
 from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X
+from super_gradients.training.utils.detection_utils import DetectionCollateFN
 from super_gradients.training.utils.utils import HpmStruct
 from super_gradients.training.utils.utils import HpmStruct
 
 
 
 
@@ -39,6 +41,33 @@ class TestYOLOX(unittest.TestCase):
                 output_augment = yolo_model(dummy_input)
                 output_augment = yolo_model(dummy_input)
                 self.assertIsNotNone(output_augment)
                 self.assertIsNotNone(output_augment)
 
 
+    def test_yolox_loss(self):
+        samples = [
+            (torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
+            (torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
+            (torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
+            (torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
+            (torch.zeros((3, 256, 256)), torch.zeros((100, 5))),
+        ]
+        collate = DetectionCollateFN()
+        _, targets = collate(samples)
+
+        predictions = [
+            torch.randn((5, 1, 256 // 8, 256 // 8, 4 + 1 + 10)),
+            torch.randn((5, 1, 256 // 16, 256 // 16, 4 + 1 + 10)),
+            torch.randn((5, 1, 256 // 32, 256 // 32, 4 + 1 + 10)),
+        ]
+
+        for loss in [
+            YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True, iou_type="giou"),
+            YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True, iou_type="iou"),
+            YoloXDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=False),
+            YoloXFastDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=True),
+            YoloXFastDetectionLoss(strides=[8, 16, 32], num_classes=10, use_l1=False),
+        ]:
+            result = loss(predictions, targets)
+            print(result)
+
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     unittest.main()
     unittest.main()
Discard