Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

tl_VGG19.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
  1. from tensorflow import keras
  2. from datetime import datetime
  3. from ..sup.evaluation import *
  4. from ..sup.support import *
  5. from ..sup.test_set_eval import *
  6. model_name = ""
  7. classes = ['heart','brain','eye','kidney','skull','other']
  8. root_dir = '../../datasets/'
  9. train_dir = os.path.join(root_dir,'train')
  10. validation_dir = os.path.join(root_dir,'validation')
  11. tr_heart_dir,tr_brain_dir,tr_eye_dir,tr_kidney_dir,tr_skull_dir = path_update(train_dir,classes)
  12. vl_heart_dir,vl_brain_dir,vl_eye_dir,vl_kidney_dir,vl_skull_dir = path_update(validation_dir,classes)
  13. plot_sample_of_img(4,4,os.listdir(tr_heart_dir)+os.listdir(tr_eye_dir))
  14. train_gen_tmp = ImageDataGenerator(rescale=1./255,
  15. rotation_range=40,
  16. width_shift_range=0.2,
  17. height_shift_range=0.2,
  18. shear_range=0.2,
  19. zoom_range=0.2,
  20. horizontal_flip=True,
  21. fill_mode='nearest')
  22. validation_gen_tmp = ImageDataGenerator(rescale=1/225.)
  23. train_gen = train_gen_tmp.flow_from_directory(train_dir,
  24. target_size=(300,300),
  25. color_mode='rgb',
  26. class_mode='categorical',
  27. batch_size= 100,
  28. shuffle=True,
  29. seed=42)
  30. validation_gen = validation_gen_tmp.flow_from_directory(validation_dir,
  31. target_size=(300,300),
  32. color_mode='rgb',
  33. class_mode='categorical',
  34. batch_size= 100,
  35. shuffle=True,
  36. seed=42)
  37. STEP_SIZE_TRAIN=train_gen.n//train_gen.batch_size
  38. STEP_SIZE_VALID=validation_gen.n//validation_gen.batch_size
  39. clToInt_dict = train_gen.class_indices
  40. clToInt_dict = dict((v,k) for v,k in clToInt_dict.items())
  41. model = keras.models.Sequential()
  42. model.compile()
  43. # Define the Keras TensorBoard callback.
  44. logdir="logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
  45. tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
  46. history = model.fit()
  47. #visualize_model(model,img_path)
  48. acc_n_loss(history)
  49. model.evaluate_generator(validation_gen,
  50. steps=STEP_SIZE_VALID)
  51. y_pred,y_test = test_eval(model,classes)
  52. plot_confusion_metrix(y_test,y_pred,classes)
  53. ROC_classes(6,y_test,y_pred,classes)
  54. model_path,model_weight_path = save(model,datetime.now()+model_name)
  55. #rnd_predict(model_path,model_weight_path,img_path,clToInt_dict)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...