Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

moving_mnist.py 3.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
  1. import os.path
  2. from pathlib import Path
  3. from typing import Callable, Optional, Union
  4. import numpy as np
  5. import torch
  6. from torchvision.datasets.utils import download_url, verify_str_arg
  7. from torchvision.datasets.vision import VisionDataset
  class MovingMNIST(VisionDataset):
      """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

      Every sample is one video clip returned as a tensor laid out as ``[T, C, H, W]``.
      ``C`` is a singleton channel axis that is inserted when the data is loaded.

      Args:
          root (str or ``pathlib.Path``): Root directory of the dataset, i.e. the directory
              containing ``MovingMNIST/mnist_test_seq.npy``.
          split (string, optional): Which frames of each clip to return. Supported values:
              - ``None`` (default) returns every frame.
              - ``"train"`` returns the leading frames.
              - ``"test"`` returns the trailing frames.
          split_ratio (int, optional): Frame index at which each clip is cut.
              - With ``split="train"``, the frames ``data[:, :split_ratio]`` are returned.
              - With ``split="test"``, the frames ``data[:, split_ratio:]`` are returned.
              - With ``split=None``, the value does not affect which frames are returned.
              It is always validated and must be an integer with ``1 <= split_ratio <= 19``.
              The upper bound presumably reflects the 20 frames stored per clip.
          download (bool, optional): If true, download the dataset from the internet into
              ``root``. A file that is already present is not downloaded again.
          transform (callable, optional): A function/transform that takes a torch Tensor
              and returns a transformed version, e.g. ``transforms.RandomCrop``.

      Raises:
          TypeError: If ``split_ratio`` is not an ``int``.
          ValueError: If ``split_ratio`` is outside ``[1, 19]``. An invalid ``split`` is
              rejected earlier by ``verify_str_arg``.
          RuntimeError: If the data file is still missing after the optional download.
      """

      _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"
      # Checksum handed to ``download_url`` (as ``md5``) for the file at ``_URL``.
      _MD5 = "be083ec986bfe91a449d63653c411eb2"

      def __init__(
          self,
          root: Union[str, Path],
          split: Optional[str] = None,
          split_ratio: int = 10,
          download: bool = False,
          transform: Optional[Callable] = None,
      ) -> None:
          super().__init__(root, transform=transform)

          # The data file lives at ``<root>/<class name>/<basename of _URL>``.
          self._base_folder = os.path.join(self.root, type(self).__name__)
          self._filename = self._URL.rsplit("/", 1)[-1]

          # Validate every argument before touching the filesystem.
          if split is not None:
              verify_str_arg(split, "split", ("train", "test"))
          self.split = split

          if not isinstance(split_ratio, int):
              raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
          if split_ratio < 1 or split_ratio > 19:
              raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
          self.split_ratio = split_ratio

          if download:
              self.download()
          if not self._check_exists():
              raise RuntimeError("Dataset not found. You can use download=True to download it.")

          frames = torch.from_numpy(np.load(self._data_path()))

          # The raw array is cut along its first (frame) axis, before axes 0 and 1
          # are swapped below. With split=None no entry matches, so nothing is cut.
          cut = self.split_ratio
          windows = {"train": slice(None, cut), "test": slice(cut, None)}
          if self.split in windows:
              frames = frames[windows[self.split]]

          clips = frames.transpose(0, 1)
          # Insert a singleton channel axis at position 2 so one clip indexes as [T, C, H, W].
          self.data = clips.unsqueeze(2).contiguous()

      def __getitem__(self, idx: int) -> torch.Tensor:
          """
          Args:
              idx (int): Index of the clip.

          Returns:
              torch.Tensor: Video frames (torch Tensor[T, C, H, W]), where ``T`` is the number
              of frames, passed through ``transform`` when one was given.
          """
          clip = self.data[idx]
          return clip if self.transform is None else self.transform(clip)

      def __len__(self) -> int:
          """Return the number of clips held in ``self.data``."""
          return len(self.data)

      def _data_path(self) -> str:
          """Return the full path of the data file inside ``_base_folder``."""
          return os.path.join(self._base_folder, self._filename)

      def _check_exists(self) -> bool:
          """Report whether the data file is present on disk."""
          return os.path.exists(self._data_path())

      def download(self) -> None:
          """Download the data file into ``_base_folder`` unless it already exists."""
          if not self._check_exists():
              download_url(
                  url=self._URL,
                  root=self._base_folder,
                  filename=self._filename,
                  md5=self._MD5,
              )
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...