Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

mnist_loader.py 3.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
  1. #!/usr/bin/env python
  2. #coding: utf-8
  3. import gzip
  4. import numpy
  def _read32(bytestream):
      """Read one big-endian unsigned 32-bit integer from ``bytestream``.

      Every field of an MNIST (IDX) header is stored as a 4-byte
      big-endian uint32.

      Args:
        bytestream: A binary file-like object with a ``read(n)`` method.

      Returns:
        The decoded value as a Python ``int``.

      Raises:
        ValueError: If fewer than 4 bytes remain in the stream.
      """
      raw = bytestream.read(4)
      if len(raw) != 4:
          # Without this check numpy.frombuffer yields an empty (or
          # undersized) buffer and the [0] lookup fails with an unhelpful
          # IndexError / ValueError deep inside numpy.
          raise ValueError(
              'Unexpected end of stream while reading a 32-bit header field')
      dt = numpy.dtype(numpy.uint32).newbyteorder('>')
      # Return a plain Python int so that header products such as
      # rows * cols * num_images cannot wrap around in uint32 arithmetic.
      return int(numpy.frombuffer(raw, dtype=dt)[0])
  def _extract_images(f):
      """Extract MNIST images into a 3D float64 numpy array [index, y, x].

      Pixel bytes (0..255) are scaled into the range [0.0, 1.0].

      Args:
        f: A binary file object containing gzip-compressed IDX image data,
          suitable for ``gzip.GzipFile(fileobj=f)``.

      Returns:
        data: A float64 numpy array of shape (num_images, rows, cols).

      Raises:
        ValueError: If the bytestream does not start with 2051, or if it ends
          before all pixel bytes declared in the header have been read.
      """
      # Fall back to a placeholder so in-memory streams (no .name) also work.
      name = getattr(f, 'name', '<unnamed stream>')
      print('Extracting', name)
      with gzip.GzipFile(fileobj=f) as bytestream:
          magic = _read32(bytestream)
          if magic != 2051:
              raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                               (magic, name))
          # int() guards against fixed-width overflow in the size product.
          num_images = int(_read32(bytestream))
          rows = int(_read32(bytestream))
          cols = int(_read32(bytestream))
          expected = rows * cols * num_images
          buf = bytestream.read(expected)
          if len(buf) != expected:
              raise ValueError(
                  'MNIST image file %s is truncated: expected %d pixel bytes, '
                  'got %d' % (name, expected, len(buf)))
          data = numpy.frombuffer(buf, dtype=numpy.uint8)
          data = data.reshape(num_images, rows, cols)
      # Multiply by the reciprocal (as before) so the float values are
      # bit-identical to earlier versions of this loader.
      return numpy.multiply(data, 1.0 / 255.0)
  def _dense_to_one_hot(labels_dense, num_classes):
      """Convert class labels from scalars to one-hot vectors.

      Args:
        labels_dense: A numpy array of integer class ids, one per entry along
          axis 0 (the array is flattened with ``ravel``).
        num_classes: Number of columns in the one-hot encoding.

      Returns:
        A float64 numpy array of shape (num_labels, num_classes) holding a
        single 1.0 per row, in the column given by that row's label.

      Raises:
        ValueError: If any label lies outside ``[0, num_classes)``.
      """
      labels = labels_dense.ravel()
      num_labels = labels_dense.shape[0]
      if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
          # The flat-offset write below would otherwise spill an out-of-range
          # label into a neighbouring row (or wrap around for negatives)
          # without any error.
          raise ValueError('Labels must lie in [0, %d), got range [%d, %d]' %
                           (num_classes, labels.min(), labels.max()))
      # Row i starts at flat position i * num_classes in the output.
      index_offset = numpy.arange(num_labels) * num_classes
      labels_one_hot = numpy.zeros((num_labels, num_classes))
      labels_one_hot.flat[index_offset + labels] = 1
      return labels_one_hot
  def _extract_labels(f, one_hot=False, num_classes=10):
      """Extract MNIST labels, optionally one-hot encoded.

      Args:
        f: A binary file object containing gzip-compressed IDX label data,
          suitable for ``gzip.GzipFile(fileobj=f)``.
        one_hot: If True, return one-hot vectors instead of class ids.
        num_classes: Number of classes for the one hot encoding.

      Returns:
        labels: A 1D uint8 numpy array [index] when ``one_hot`` is False;
          otherwise a float64 array of shape (num_items, num_classes).

      Raises:
        ValueError: If the bytestream does not start with 2049, or if it ends
          before all labels declared in the header have been read.
      """
      # Fall back to a placeholder so in-memory streams (no .name) also work.
      name = getattr(f, 'name', '<unnamed stream>')
      print('Extracting', name)
      with gzip.GzipFile(fileobj=f) as bytestream:
          magic = _read32(bytestream)
          if magic != 2049:
              raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                               (magic, name))
          num_items = int(_read32(bytestream))
          buf = bytestream.read(num_items)
          if len(buf) != num_items:
              # frombuffer would otherwise quietly return fewer labels than the
              # header declares, misaligning labels with their images.
              raise ValueError(
                  'MNIST label file %s is truncated: expected %d labels, got %d'
                  % (name, num_items, len(buf)))
          labels = numpy.frombuffer(buf, dtype=numpy.uint8)
      if one_hot:
          labels = _dense_to_one_hot(labels, num_classes)
      return labels
  def read_data_sets(data_dir='data', one_hot=True):
      """Load the four standard gzip-compressed MNIST IDX files.

      Args:
        data_dir: Directory containing the files
          ``train-images-idx3-ubyte.gz``, ``train-labels-idx1-ubyte.gz``,
          ``t10k-images-idx3-ubyte.gz`` and ``t10k-labels-idx1-ubyte.gz``.
          Defaults to ``'data'`` (relative to the working directory), matching
          the previously hard-coded location.
        one_hot: Whether labels are returned one-hot encoded (default True,
          matching the previous behaviour).

      Returns:
        A tuple ``(train_images, test_images, train_labels, test_labels)``.
        Note the ordering: both image arrays first, then both label arrays.
      """
      import os  # local import: the module header only imports gzip and numpy

      def _open(file_name):
          """Open ``file_name`` inside ``data_dir`` for binary reading."""
          return open(os.path.join(data_dir, file_name), 'rb')

      with _open('train-images-idx3-ubyte.gz') as f:
          train_images = _extract_images(f)
      with _open('train-labels-idx1-ubyte.gz') as f:
          train_labels = _extract_labels(f, one_hot=one_hot)
      with _open('t10k-images-idx3-ubyte.gz') as f:
          test_images = _extract_images(f)
      with _open('t10k-labels-idx1-ubyte.gz') as f:
          test_labels = _extract_labels(f, one_hot=one_hot)
      return train_images, test_images, train_labels, test_labels
  79. # read_data_sets()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...