Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#581 Bug/sg 512 shuffle bugfix in recipe dataloaders

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bug/SG-512_shuffle_bugfix_in_recipe_datalaoders
1 changed file with 7 additions and 6 deletions
  1. 7
    6
      src/super_gradients/training/dataloaders/dataloaders.py
@@ -105,27 +105,28 @@ def _process_dataloader_params(cfg, dataloader_params, dataset, train):
 
 
 def _process_sampler_params(dataloader_params, dataset, default_dataloader_params):
 def _process_sampler_params(dataloader_params, dataset, default_dataloader_params):
     is_dist = super_gradients.is_distributed()
     is_dist = super_gradients.is_distributed()
+    dataloader_params = override_default_params_without_nones(dataloader_params, default_dataloader_params)
     if get_param(dataloader_params, "sampler") is not None:
     if get_param(dataloader_params, "sampler") is not None:
         dataloader_params = _instantiate_sampler(dataset, dataloader_params)
         dataloader_params = _instantiate_sampler(dataset, dataloader_params)
-    elif get_param(default_dataloader_params, "sampler") is not None:
-        default_dataloader_params = _instantiate_sampler(dataset, default_dataloader_params)
     elif is_dist:
     elif is_dist:
-        default_dataloader_params["sampler"] = {"DistributedSampler": {}}
-        default_dataloader_params = _instantiate_sampler(dataset, default_dataloader_params)
-    dataloader_params = override_default_params_without_nones(dataloader_params, default_dataloader_params)
+        dataloader_params["sampler"] = {"DistributedSampler": {}}
+        dataloader_params = _instantiate_sampler(dataset, dataloader_params)
     if get_param(dataloader_params, "batch_sampler"):
     if get_param(dataloader_params, "batch_sampler"):
         sampler = dataloader_params.pop("sampler")
         sampler = dataloader_params.pop("sampler")
         batch_size = dataloader_params.pop("batch_size")
         batch_size = dataloader_params.pop("batch_size")
         if "drop_last" in dataloader_params:
         if "drop_last" in dataloader_params:
             drop_last = dataloader_params.pop("drop_last")
             drop_last = dataloader_params.pop("drop_last")
         else:
         else:
-            drop_last = default_dataloader_params["drop_last"]
+            drop_last = dataloader_params["drop_last"]
         dataloader_params["batch_sampler"] = BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=drop_last)
         dataloader_params["batch_sampler"] = BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=drop_last)
     return dataloader_params
     return dataloader_params
 
 
 
 
 def _instantiate_sampler(dataset, dataloader_params):
 def _instantiate_sampler(dataset, dataloader_params):
     sampler_name = list(dataloader_params["sampler"].keys())[0]
     sampler_name = list(dataloader_params["sampler"].keys())[0]
+    if "shuffle" in dataloader_params.keys():
+        # SHUFFLE IS MUTUALLY EXCLUSIVE WITH SAMPLER ARG IN DATALOADER INIT
+        dataloader_params["sampler"][sampler_name]["shuffle"] = dataloader_params.pop("shuffle")
     dataloader_params["sampler"][sampler_name]["dataset"] = dataset
     dataloader_params["sampler"][sampler_name]["dataset"] = dataset
     dataloader_params["sampler"] = SamplersFactory().get(dataloader_params["sampler"])
     dataloader_params["sampler"] = SamplersFactory().get(dataloader_params["sampler"])
     return dataloader_params
     return dataloader_params
Discard