Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

input_data.py 3.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
  1. #!/usr/bin/env python
  2. #coding: utf-8
  3. import numpy as np
  4. import mnist_loader
  5. import collections
  6. Datasets = collections.namedtuple('Datasets', ['train', 'test'])
  7. class DataSet(object):
  8. def __init__(self,
  9. images,
  10. labels):
  11. self._num_examples = images.shape[0]
  12. self._images = images
  13. self._labels = labels
  14. self._epochs_completed = 0
  15. self._index_in_epoch = 0
  16. @property
  17. def images(self):
  18. return self._images
  19. @property
  20. def labels(self):
  21. return self._labels
  22. @property
  23. def num_examples(self):
  24. return self._num_examples
  25. @property
  26. def epochs_completed(self):
  27. return self._epochs_completed
  28. def mini_batches(self,mini_batch_size):
  29. """
  30. return: list of tuple(x,y)
  31. """
  32. # Shuffle the data
  33. perm = np.arange(self._num_examples)
  34. np.random.shuffle(perm)
  35. self._images = self._images[perm]
  36. self._labels = self._labels[perm]
  37. n = self.images.shape[0]
  38. mini_batches = [(self._images[k:k+mini_batch_size],self._labels[k:k+mini_batch_size])
  39. for k in range(0, n, mini_batch_size)]
  40. if len(mini_batches[-1]) != mini_batch_size:
  41. return mini_batches[:-1]
  42. else:
  43. return mini_batches
  44. def _next_batch(self, batch_size, fake_data=False):
  45. """Return the next `batch_size` examples from this data set."""
  46. start = self._index_in_epoch
  47. self._index_in_epoch += batch_size
  48. if self._index_in_epoch > self._num_examples:
  49. # Finished epoch
  50. self._epochs_completed += 1
  51. # Shuffle the data
  52. perm = np.arange(self._num_examples)
  53. np.random.shuffle(perm)
  54. self._images = self._images[perm]
  55. self._labels = self._labels[perm]
  56. # Start next epoch
  57. start = 0
  58. self._index_in_epoch = batch_size
  59. assert batch_size <= self._num_examples
  60. end = self._index_in_epoch
  61. return self._images[start:end], self._labels[start:end]
  62. def read_data_sets():
  63. """
  64. Function:读取训练集(TrainSet)和测试集(TestSet)。
  65. Notes
  66. ----------
  67. TrainSet: include imgs_train and labels_train.
  68. TestSet: include imgs_test and labels_test.
  69. the shape of imgs_train and imgs_test are:(batch_size,height,width). namely (n, 28L, 28L)
  70. the shape of labels_train and labels_test are:(batch_size,num_classes). namely (n, 10L)
  71. """
  72. imgs_train, imgs_test, labels_train, labels_test = mnist_loader.read_data_sets()
  73. train = DataSet(imgs_train, labels_train)
  74. test = DataSet(imgs_test, labels_test)
  75. return Datasets(train=train, test=test)
  76. def _test():
  77. dataset = read_data_sets()
  78. print("dataset.train.images.shape:",dataset.train.images.shape)
  79. print("dataset.train.labels.shape:",dataset.train.labels.shape)
  80. print("dataset.test.images.shape:",dataset.test.images.shape)
  81. print("dataset.test.labels.shape:",dataset.test.labels.shape)
  82. print(dataset.test.images[0])
  83. print(dataset.test.labels[0])
  84. # _test()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...