|
@@ -105,27 +105,28 @@ def _process_dataloader_params(cfg, dataloader_params, dataset, train):
|
|
|
|
|
|
def _process_sampler_params(dataloader_params, dataset, default_dataloader_params):
|
|
def _process_sampler_params(dataloader_params, dataset, default_dataloader_params):
|
|
is_dist = super_gradients.is_distributed()
|
|
is_dist = super_gradients.is_distributed()
|
|
|
|
+ dataloader_params = override_default_params_without_nones(dataloader_params, default_dataloader_params)
|
|
if get_param(dataloader_params, "sampler") is not None:
|
|
if get_param(dataloader_params, "sampler") is not None:
|
|
dataloader_params = _instantiate_sampler(dataset, dataloader_params)
|
|
dataloader_params = _instantiate_sampler(dataset, dataloader_params)
|
|
- elif get_param(default_dataloader_params, "sampler") is not None:
|
|
|
|
- default_dataloader_params = _instantiate_sampler(dataset, default_dataloader_params)
|
|
|
|
elif is_dist:
|
|
elif is_dist:
|
|
- default_dataloader_params["sampler"] = {"DistributedSampler": {}}
|
|
|
|
- default_dataloader_params = _instantiate_sampler(dataset, default_dataloader_params)
|
|
|
|
- dataloader_params = override_default_params_without_nones(dataloader_params, default_dataloader_params)
|
|
|
|
|
|
+ dataloader_params["sampler"] = {"DistributedSampler": {}}
|
|
|
|
+ dataloader_params = _instantiate_sampler(dataset, dataloader_params)
|
|
if get_param(dataloader_params, "batch_sampler"):
|
|
if get_param(dataloader_params, "batch_sampler"):
|
|
sampler = dataloader_params.pop("sampler")
|
|
sampler = dataloader_params.pop("sampler")
|
|
batch_size = dataloader_params.pop("batch_size")
|
|
batch_size = dataloader_params.pop("batch_size")
|
|
if "drop_last" in dataloader_params:
|
|
if "drop_last" in dataloader_params:
|
|
drop_last = dataloader_params.pop("drop_last")
|
|
drop_last = dataloader_params.pop("drop_last")
|
|
else:
|
|
else:
|
|
- drop_last = default_dataloader_params["drop_last"]
|
|
|
|
|
|
+ drop_last = dataloader_params["drop_last"]
|
|
dataloader_params["batch_sampler"] = BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=drop_last)
|
|
dataloader_params["batch_sampler"] = BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=drop_last)
|
|
return dataloader_params
|
|
return dataloader_params
|
|
|
|
|
|
|
|
|
|
def _instantiate_sampler(dataset, dataloader_params):
|
|
def _instantiate_sampler(dataset, dataloader_params):
|
|
sampler_name = list(dataloader_params["sampler"].keys())[0]
|
|
sampler_name = list(dataloader_params["sampler"].keys())[0]
|
|
|
|
+ if "shuffle" in dataloader_params.keys():
|
|
|
|
+ # SHUFFLE IS MUTUALLY EXCLUSIVE WITH SAMPLER ARG IN DATALOADER INIT
|
|
|
|
+ dataloader_params["sampler"][sampler_name]["shuffle"] = dataloader_params.pop("shuffle")
|
|
dataloader_params["sampler"][sampler_name]["dataset"] = dataset
|
|
dataloader_params["sampler"][sampler_name]["dataset"] = dataset
|
|
dataloader_params["sampler"] = SamplersFactory().get(dataloader_params["sampler"])
|
|
dataloader_params["sampler"] = SamplersFactory().get(dataloader_params["sampler"])
|
|
return dataloader_params
|
|
return dataloader_params
|