Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

prepare_img.py 3.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) 2021. Jeffrey Nirschl. All rights reserved.
  3. #
  4. # Licensed under the MIT license. See the LICENSE file in the project
  5. # root directory for license information.
  6. #
  7. # Time-stamp: <>
  8. # ======================================================================
  9. import argparse
  10. import os
  11. from pathlib import Path
  12. import cv2
  13. import numpy as np
  14. import pandas as pd
  15. from src.img import transforms
  16. def main(data_path, ext="png",
  17. img_shape=(28, 28, 1),
  18. output="./data/interim/",
  19. prefix="",
  20. na_rep="nan"):
  21. """Accept a numpy array of flattened images and
  22. save as images."""
  23. # create output director, if needed
  24. output = Path(output).resolve().joinpath(prefix)
  25. if not os.path.isdir(output):
  26. os.mkdir(output)
  27. # check for errors
  28. assert os.path.isfile(data_path), FileNotFoundError
  29. assert os.path.isdir(output), NotADirectoryError
  30. # remove period from ext
  31. ext = ext.replace(".", "")
  32. # read file
  33. img_array = pd.read_csv(data_path, sep=",",
  34. header=0)
  35. # pop target column and save
  36. if "label" in img_array.columns:
  37. target = pd.DataFrame(img_array.pop("label"))
  38. else:
  39. target = pd.DataFrame({"label": np.full_like(img_array[img_array.columns[0]],
  40. np.nan, dtype=np.float32)})
  41. # create mean image
  42. mean_image = transforms.mean_image(img_array)
  43. cv2.imwrite(str(output.parent.joinpath(f"{prefix}_mean_image.png")), mean_image)
  44. # save individual images
  45. filenames = save_images(img_array, target, img_shape, ext, output)
  46. # save a mapfile with the filename and label
  47. mapfile = pd.DataFrame({"filenames": filenames,
  48. target.columns[0]: target[target.columns[0]]},
  49. index=target.index)
  50. mapfile.to_csv(output.parent.joinpath(f"{prefix}_mapfile.csv"),
  51. na_rep=na_rep)
  52. def save_images(img_array, target, img_shape, ext, output):
  53. """Subfunction to process flattened images in dataframe"""
  54. # process dataframe line by line
  55. print(f"Reshaping flattened images in numpy array into 2D images...")
  56. filenames = []
  57. for idx in range(img_array.shape[0]):
  58. if (idx + 1) % 10000 == 0:
  59. print(f"\tProcessed {idx + 1} images")
  60. # select image and reshape
  61. temp_img = np.reshape(img_array.iloc[idx].to_numpy(),
  62. img_shape).astype(np.float32)
  63. # set filename
  64. temp_name = f"{idx:06d}_{target.iloc[idx].to_numpy()[0]}.{ext}"
  65. filenames.append(str(output.joinpath(temp_name)))
  66. if not cv2.imwrite(filenames[idx], temp_img):
  67. raise SystemError
  68. print(f"\tProcessed {idx + 1} images")
  69. return filenames
  70. if __name__ == '__main__':
  71. parser = argparse.ArgumentParser()
  72. parser.add_argument("-tr", "--train_data", dest="train_data",
  73. required=True, help="Train CSV file")
  74. parser.add_argument("-te", "--test_data", dest="test_data",
  75. required=True, help="Test CSV file")
  76. parser.add_argument("-ex", "--ext", dest="ext",
  77. default=".png",
  78. required=False, help="Train CSV file")
  79. parser.add_argument("-o", "--out-dir", dest="output_dir",
  80. default=Path("./data/interim").resolve(),
  81. required=False, help="output directory")
  82. args = parser.parse_args()
  83. # categorical variables into integer codes
  84. main(args.train_data, prefix="train", ext=args.ext, output=args.output_dir)
  85. main(args.test_data, prefix="test", ext=args.ext, output=args.output_dir)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...