Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#561 Feature/sg 193 extend output formator

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-193-extend_detection_target_transform
@@ -7,8 +7,13 @@ import onnx
 import onnxruntime as ort
 import onnxruntime as ort
 import torch.jit
 import torch.jit
 
 
-from super_gradients.training.utils.bbox_formats import NormalizedXYWHCoordinateFormat, CXCYWHCoordinateFormat, YXYXCoordinateFormat
-from super_gradients.training.utils.output_adapters import DetectionOutputAdapter, ConcatenatedTensorFormat, BoundingBoxesTensorSliceItem, TensorSliceItem
+from super_gradients.training.datasets.data_formats.bbox_formats import NormalizedXYWHCoordinateFormat, CXCYWHCoordinateFormat, YXYXCoordinateFormat
+from super_gradients.training.datasets.data_formats.output_adapters.detection_adapter import DetectionOutputAdapter
+from super_gradients.training.datasets.data_formats import (
+    ConcatenatedTensorFormat,
+    BoundingBoxesTensorSliceItem,
+    TensorSliceItem,
+)
 
 
 NORMALIZED_XYWH_SCORES_LABELS = ConcatenatedTensorFormat(
 NORMALIZED_XYWH_SCORES_LABELS = ConcatenatedTensorFormat(
     layout=(
     layout=(
@@ -119,7 +124,7 @@ class TestDetectionOutputAdapter(unittest.TestCase):
 
 
             with tempfile.TemporaryDirectory() as tmpdirname:
             with tempfile.TemporaryDirectory() as tmpdirname:
                 adapter_fname = os.path.join(tmpdirname, "adapter.onnx")
                 adapter_fname = os.path.join(tmpdirname, "adapter.onnx")
-                torch.onnx.export(adapter, inp, f=adapter_fname, input_names=["predictions"], output_names=["output_predictions"])
+                torch.onnx.export(adapter, inp, f=adapter_fname, input_names=["predictions"], output_names=["output_predictions"], opset_version=11)
 
 
                 onnx_model = onnx.load(adapter_fname)
                 onnx_model = onnx.load(adapter_fname)
                 onnx.checker.check_model(onnx_model)
                 onnx.checker.check_model(onnx_model)
Discard
Tip!

Press p or to see the previous file or, n or to see the next file