|
@@ -231,6 +231,7 @@ class MultiscalePrePredictionCallback(AbstractPrePredictionCallback):
|
|
|
self.rank = None
|
|
|
self.is_distributed = None
|
|
|
self.sampled_imres_once = False
|
|
|
+ self.new_input_size = None
|
|
|
|
|
|
def __call__(self, inputs, targets, batch_idx):
|
|
|
if self.rank is None:
|
|
@@ -239,9 +240,9 @@ class MultiscalePrePredictionCallback(AbstractPrePredictionCallback):
|
|
|
self.is_distributed = get_world_size() > 1
|
|
|
|
|
|
# GENERATE A NEW SIZE AND BROADCAST IT TO THE THE OTHER RANKS SO THEY HAVE THE SAME SCALE
|
|
|
+ input_size = inputs.shape[2:]
|
|
|
if batch_idx % self.frequency == 0:
|
|
|
- tensor = torch.LongTensor(2).cuda()
|
|
|
- input_size = inputs.shape[2:]
|
|
|
+ tensor = torch.LongTensor(2).to(inputs.device)
|
|
|
|
|
|
if self.rank == 0:
|
|
|
size_factor = input_size[1] * 1.0 / input_size[0]
|
|
@@ -262,12 +263,12 @@ class MultiscalePrePredictionCallback(AbstractPrePredictionCallback):
|
|
|
dist.barrier()
|
|
|
dist.broadcast(tensor, 0)
|
|
|
|
|
|
- new_input_size = (tensor[0].item(), tensor[1].item())
|
|
|
+ self.new_input_size = (tensor[0].item(), tensor[1].item())
|
|
|
|
|
|
- scale_y = new_input_size[0] / input_size[0]
|
|
|
- scale_x = new_input_size[1] / input_size[1]
|
|
|
- if scale_x != 1 or scale_y != 1:
|
|
|
- inputs = torch.nn.functional.interpolate(inputs, size=new_input_size, mode="bilinear", align_corners=False)
|
|
|
+ scale_y = self.new_input_size[0] / input_size[0]
|
|
|
+ scale_x = self.new_input_size[1] / input_size[1]
|
|
|
+ if scale_x != 1 or scale_y != 1:
|
|
|
+ inputs = torch.nn.functional.interpolate(inputs, size=self.new_input_size, mode="bilinear", align_corners=False)
|
|
|
return inputs, targets
|
|
|
|
|
|
|