Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

functions.py 3.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
  1. import nibabel as nib
  2. import glob
  3. from scipy.ndimage import zoom
  4. from tqdm import tqdm
  5. from pathlib import PurePath
  6. import numpy as np
  7. import os
  8. # TODO: Add docstring to functions
  9. def load_Nift2np(path):
  10. """
  11. Reads a .nii.gz file and loads it into numpy array
  12. """
  13. return np.array(nib.load(path).dataobj)
  14. def resize(img, shape, mode='constant', orig_shape=(240, 240, 155)):
  15. """
  16. Wrapper for scipy.ndimage.zoom suited for MRI images.
  17. """
  18. assert len(shape) == 3, "Can not have more than 3 dimensions"
  19. factors = (
  20. shape[0] / orig_shape[0],
  21. shape[1] / orig_shape[1],
  22. shape[2] / orig_shape[2]
  23. )
  24. # Resize to the given shape
  25. return zoom(img, factors, mode=mode)
  26. def process_img(img, out_shape=None):
  27. """
  28. Preprocess the image.
  29. Just an example, you can add more preprocessing steps if you wish to.
  30. """
  31. if out_shape is not None:
  32. img = resize(img, out_shape, mode='constant')
  33. # Normalize the image
  34. mean = img.mean()
  35. std = img.std()
  36. return (img - mean) / std
  37. def save_img(img, raw_img_path):
  38. # Create output path
  39. process_img_path = PurePath(raw_img_path.replace('raw', 'processed'))
  40. # Makedir if necessary
  41. if not os.path.isdir(process_img_path.parent):
  42. os.makedirs(process_img_path.parent)
  43. # Save img to file
  44. ni_img = nib.Nifti1Image(img, affine=np.eye(4))
  45. nib.save(ni_img, process_img_path)
  46. def preprocess_label(img, out_shape=None, mode='nearest'):
  47. """
  48. Separates out the 3 labels from the segmentation provided, namely:
  49. GD-enhancing tumor (ET — label 4), the peritumoral edema (ED — label 2))
  50. and the necrotic and non-enhancing tumor core (NCR/NET — label 1)
  51. """
  52. ncr = img == 1 # Necrotic and Non-Enhancing Tumor (NCR/NET)
  53. ed = img == 2 # Peritumoral Edema (ED)
  54. et = img == 4 # GD-enhancing Tumor (ET)
  55. if out_shape:
  56. ncr = resize(ncr, out_shape, mode=mode)
  57. ed = resize(ed, out_shape, mode=mode)
  58. et = resize(et, out_shape, mode=mode)
  59. return np.array([ncr, ed, et], dtype=np.uint8)
  60. def modalities_path(base_path):
  61. t1 = glob.glob(os.path.join(base_path + '/*t1.nii.gz'))
  62. t2 = glob.glob(os.path.join(base_path + '/*t2.nii.gz'))
  63. flair = glob.glob(os.path.join(base_path + '/*flair.nii.gz'))
  64. t1ce = glob.glob(os.path.join(base_path + '/*t1ce.nii.gz'))
  65. seg = glob.glob(os.path.join(base_path + '/*seg.nii.gz'))
  66. return t1, t2, flair, t1ce, seg
  67. def load_raw_data(t1, t2, flair, t1ce, num_samples, input_shape):
  68. data: np.array = np.empty((num_samples,) + input_shape, dtype=np.float32)
  69. for index, item in enumerate(tqdm(zip(t1, t2, flair, t1ce), desc="Loading Data", total=len(t1))):
  70. data[index] = np.array(
  71. [process_img(load_Nift2np(modal_path), out_shape=input_shape[1:]) for modal_path in item])
  72. return data
  73. def load_raw_labels(seg, num_samples, input_shape, output_channels):
  74. labels: np.array = np.empty((num_samples, output_channels) + input_shape[1:], dtype=np.float32)
  75. for index, modal_path in enumerate(tqdm(seg, total=len(seg), desc="Loading Labels")):
  76. labels[index] = preprocess_label(load_Nift2np(modal_path), out_shape=input_shape[1:])[None, ...]
  77. return labels
  78. def save_data(t1, t2, flair, t1ce, data):
  79. for index, item in enumerate(tqdm(zip(t1, t2, flair, t1ce), desc="Saving Processed Data", total=len(t1))):
  80. for modal, modal_path in enumerate(item):
  81. save_img(data[index][modal], modal_path)
  82. def save_label(labels, seg, file_type):
  83. for label_index, modal_path in enumerate(tqdm(seg, total=len(seg), desc="Loading Labels")):
  84. for c_type_index, c_type in enumerate(("_NCR_NET", "_ED", "_ET")):
  85. save_img(labels[label_index][c_type_index], modal_path.split(file_type)[0] + c_type + file_type)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...