extract_mnist.py

```python
#!/usr/bin/env python
# coding: utf-8
"""
The MNIST dataset is distributed as gzip-compressed files. This script
converts every image in the dataset to a .png file and saves it to a
specified directory.
"""
import gzip
import os

import numpy
import PIL.Image as Image


def _read32(bytestream):
    """Read one big-endian (MSB-first) unsigned 32-bit integer from the stream."""
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return int(numpy.frombuffer(bytestream.read(4), dtype=dt)[0])
```
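`_read32` decodes four bytes as a big-endian unsigned 32-bit integer, and an IDX image file starts with four such fields: the magic number, the image count, the row count and the column count. The same 16-byte header can also be decoded with the standard-library `struct` module. The sketch below is an illustrative alternative, not part of the original script, and `parse_idx3_header` is a hypothetical name. It also shows why the script expects 2051: in hexadecimal that is 0x00000803, i.e. two zero bytes, data type 0x08 (unsigned byte) and 3 dimensions (images, rows, columns).

```python
import struct

# IDX3 image-file header: four big-endian ('>') unsigned 32-bit ints ('I').
IDX3_HEADER = struct.Struct('>IIII')


def parse_idx3_header(header_bytes):
    """Hypothetical helper: decode the 16-byte header of an IDX3 image file."""
    magic, num_images, rows, cols = IDX3_HEADER.unpack(header_bytes)
    # 2051 == 0x00000803: two zero bytes, data type 0x08 (unsigned byte), 3 dimensions
    data_type = (magic >> 8) & 0xFF
    num_dims = magic & 0xFF
    if (magic >> 16) != 0 or data_type != 0x08 or num_dims != 3:
        raise ValueError('Not an unsigned-byte IDX3 file, magic number %d' % magic)
    return num_images, rows, cols


# Example: a header like that of the MNIST test-image file (10000 images, 28x28 pixels).
example = IDX3_HEADER.pack(2051, 10000, 28, 28)
print(example.hex())               # 00000803000027100000001c0000001c
print(parse_idx3_header(example))  # (10000, 28, 28)
```

Unpacking all four fields at once keeps the byte order in a single format string, whereas the script reads them one at a time with NumPy.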
```python
def extract_and_save_mnist(file_path, save_dir):
    """
    file_path: path to an MNIST image file, e.g. "./data/t10k-images-idx3-ubyte.gz"
    save_dir: target directory for the .png files, e.g. "/home/u2/data"
    """
    out_dir = os.path.abspath(save_dir)
    os.makedirs(out_dir, exist_ok=True)  # create the target directory if it is missing
    with open(file_path, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream:
        print('Extracting', f.name)
        # Header: magic number, image count, rows, cols (big-endian uint32 each)
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                             (magic, f.name))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        # Pixels: one unsigned byte each, image after image, row-major
        buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols)
    # Save the images one by one as <index>.png
    for index, d in enumerate(data):
        img = Image.fromarray(d).convert('L')
        img.save(os.path.join(out_dir, str(index) + '.png'), 'png')


def main():
    TEST_IMAGES_PATH = 'data/t10k-images-idx3-ubyte.gz'
    extract_and_save_mnist(TEST_IMAGES_PATH, './img')


if __name__ == '__main__':
    main()
```
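Running the script needs the real MNIST file. To exercise the same file layout without it, one can write a tiny gzipped file in the IDX3 format and read it back. The sketch below uses only the standard library (`gzip`, `struct`, `tempfile`) instead of NumPy and PIL; `write_idx3_gz` and `read_idx3_gz` are hypothetical helpers for illustration, and the 2x3-pixel images are made-up test data, not MNIST content.

```python
import gzip
import os
import struct
import tempfile


def write_idx3_gz(path, images):
    """Hypothetical helper: write equally sized grayscale images (rows of 0-255 ints) as gzipped IDX3."""
    rows, cols = len(images[0]), len(images[0][0])
    with gzip.open(path, 'wb') as out:
        out.write(struct.pack('>IIII', 2051, len(images), rows, cols))
        for image in images:
            for row in image:
                out.write(bytes(row))  # one unsigned byte per pixel, row-major


def read_idx3_gz(path):
    """Hypothetical helper: read a gzipped IDX3 file back into nested lists."""
    with gzip.open(path, 'rb') as stream:
        magic, num_images, rows, cols = struct.unpack('>IIII', stream.read(16))
        if magic != 2051:
            raise ValueError('Invalid magic number %d' % magic)
        pixels = stream.read(num_images * rows * cols)
    if len(pixels) != num_images * rows * cols:
        raise ValueError('File is truncated')
    size = rows * cols
    # Slice the flat pixel buffer back into images, then rows
    return [
        [list(pixels[i * size + r * cols:i * size + (r + 1) * cols]) for r in range(rows)]
        for i in range(num_images)
    ]


images = [
    [[0, 128, 255], [255, 128, 0]],
    [[10, 20, 30], [40, 50, 60]],
]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'tiny-images-idx3-ubyte.gz')
    write_idx3_gz(path, images)
    decoded = read_idx3_gz(path)
print(decoded == images)  # True
```

Because the tiny file uses the same magic number and layout, it could also be passed to `extract_and_save_mnist` (before the temporary directory is removed) as a quick check that one .png is written per image.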