Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

presets.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
  1. import torch
  2. import transforms as T
  class OpticalFlowPresetEval(torch.nn.Module):
      """Deterministic (non-augmenting) preprocessing for optical-flow evaluation samples.

      The pipeline, run in order through ``T.Compose``, is:
      tensor conversion (``T.PILToTensor``), cast to ``torch.float32``,
      normalization with mean 0.5 / std 0.5 (mapping values in [0, 1] onto
      [-1, 1]), and finally ``T.ValidateModelInput``. Exact per-stage
      semantics live in the project-local ``transforms`` module.
      """

      def __init__(self):
          super().__init__()
          stages = [
              T.PILToTensor(),
              T.ConvertImageDtype(torch.float32),
              # (x - 0.5) / 0.5 sends [0, 1] to [-1, 1]
              T.Normalize(mean=0.5, std=0.5),
              # presumably a sanity check on the final sample — confirm in transforms.py
              T.ValidateModelInput(),
          ]
          self.transforms = T.Compose(stages)

      def forward(self, img1, img2, flow, valid):
          """Run the evaluation pipeline on one sample.

          The four inputs are handed to the composed pipeline unchanged, and its
          result is returned as-is (presumably the transformed
          ``(img1, img2, flow, valid)`` — confirm against ``T.Compose``).
          """
          pipeline = self.transforms
          return pipeline(img1, img2, flow, valid)
  class OpticalFlowPresetTrain(torch.nn.Module):
      """Randomized augmentation pipeline for optical-flow training samples.

      Stages, run in order through ``T.Compose``:

      1. ``T.PILToTensor``
      2. ``T.AsymmetricColorJitter`` (color jitter on the image pair)
      3. ``T.RandomResizeAndCrop`` to ``crop_size``
      4. optionally (``do_flip``) ``T.RandomHorizontalFlip(p=hflip_prob)`` then
         ``T.RandomVerticalFlip(p=vflip_prob)``
      5. cast to ``torch.float32`` and normalize with mean 0.5 / std 0.5,
         mapping [0, 1] onto [-1, 1]
      6. ``T.RandomErasing(max_erase=2)``, ``T.MakeValidFlowMask``,
         ``T.ValidateModelInput``

      All constructor arguments are keyword-only.

      Args:
          crop_size: Output crop size forwarded to ``T.RandomResizeAndCrop``.
          min_scale, max_scale: Scale range for ``T.RandomResizeAndCrop``. The
              negative default for ``min_scale`` suggests these are exponents
              (e.g. log2 scale factors) rather than raw factors — confirm in
              the local ``transforms`` module.
          stretch_prob: Forwarded to ``T.RandomResizeAndCrop``.
          brightness, contrast, saturation, hue: Jitter strengths forwarded to
              ``T.AsymmetricColorJitter``.
          asymmetric_jitter_prob: Forwarded as ``p`` to
              ``T.AsymmetricColorJitter``. Judging by the name, this is the
              probability of jittering the two frames independently — confirm.
          do_flip: When False, no flip transforms are added.
          hflip_prob: Probability passed to ``T.RandomHorizontalFlip``
              (default 0.5, the previously hard-coded value).
          vflip_prob: Probability passed to ``T.RandomVerticalFlip``
              (default 0.1, the previously hard-coded value).
      """

      def __init__(
          self,
          *,
          # RandomResizeAndCrop params
          crop_size,
          min_scale=-0.2,
          max_scale=0.5,
          stretch_prob=0.8,
          # AsymmetricColorJitter params
          brightness=0.4,
          contrast=0.4,
          saturation=0.4,
          hue=0.5 / 3.14,
          asymmetric_jitter_prob=0.2,
          # Random[H,V]Flip params
          do_flip=True,
          hflip_prob=0.5,
          vflip_prob=0.1,
      ):
          super().__init__()

          transforms = [
              T.PILToTensor(),
              T.AsymmetricColorJitter(
                  brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
              ),
              T.RandomResizeAndCrop(
                  crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
              ),
          ]

          if do_flip:
              transforms += [T.RandomHorizontalFlip(p=hflip_prob), T.RandomVerticalFlip(p=vflip_prob)]

          transforms += [
              T.ConvertImageDtype(torch.float32),
              T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
              # judging by the name, erases at most 2 regions — confirm in transforms.py
              T.RandomErasing(max_erase=2),
              T.MakeValidFlowMask(),
              T.ValidateModelInput(),
          ]
          self.transforms = T.Compose(transforms)

      def forward(self, img1, img2, flow, valid):
          """Run the training pipeline on one sample.

          The four inputs are handed to the composed pipeline unchanged, and its
          result is returned as-is (presumably the augmented
          ``(img1, img2, flow, valid)`` — confirm against ``T.Compose``).
          """
          return self.transforms(img1, img2, flow, valid)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...