Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

model.py 1.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
  1. from .model_architecture import build_model
  2. from .const_model import *
  3. from .functions import *
  4. import os
  5. import mlflow
  # Training script: load the preprocessed modalities, train the model, save its weights.
# Names not defined here (modalities_path, base_path, input_shape,
# output_channels, batch_size, epochs, model_weights_path, load_Nift2np,
# load_raw_labels, np, tqdm) come from the star imports of .const_model /
# .functions.
t1, t2, flair, t1ce, seg = modalities_path(base_path.replace('raw', 'processed'))

# Validate with explicit raises instead of `assert`: assert statements are
# removed when Python runs with -O, which would silently skip these checks.
num_samples: int = len(t1)
if not num_samples == len(t2) == len(flair) == len(t1ce):
    raise ValueError('Missing modalities in some patients')
if num_samples == 0:
    raise ValueError('The lists of paths are empty')
# load_raw_labels below is asked for num_samples labels, so the segmentation
# list must line up one-to-one with the modality lists.
if len(seg) != num_samples:
    raise ValueError('Missing segmentation labels for some patients')

# TODO: change this into a function so that it will not be redundant.
data: np.ndarray = np.empty((num_samples,) + input_shape, dtype=np.float32)
for index, modal_paths in enumerate(tqdm(zip(t1, t2, flair, t1ce), desc="Loading Data", total=num_samples)):
    # Stack this sample's four modalities along a new leading axis; assumes
    # input_shape starts with that 4-wide axis (channels-first) -- TODO confirm.
    data[index] = np.array([load_Nift2np(modal_path) for modal_path in modal_paths])

labels = load_raw_labels(seg, num_samples, input_shape, output_channels)  # TODO: Load the processed images.

# Enable MLflow autologging before fit() so the training run is recorded.
mlflow.keras.autolog()
model = build_model(input_shape=input_shape, output_channels=output_channels)
# Two targets: the labels and the input data itself (presumably a
# reconstruction output of the model -- confirm in model_architecture).
model.fit(data, [labels, data], batch_size=batch_size, epochs=epochs)
# NOTE(review): plain string concatenation; this only forms a valid path if
# model_weights_path begins with a path separator -- confirm in const_model.
model.save_weights(os.getcwd() + model_weights_path, overwrite=True)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...