streamlit_app.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
# TODO: Edit SHORT_DESCRIPTION.
# TODO: Make the validation set tracked by git and change the image path here.
# TODO: Add docstrings to functions.
# TODO (general):
# - Add weights to the model.
# - Change the base path in const to the general const.
  7. import streamlit as st
  8. import cv2 as cv
  9. from PIL import Image
  10. import numpy as np
  11. import tensorflow as tf
  12. from src.const.general_const import PROD_MODEL_PATH, IMG_SIZE, CLASS_NAME_PATH
  13. from task_5_streamlit.src.const.streamlit_const import \
  14. DAGSHUB_IMAGE_PATH, HEALTHY_IMAGE_ONE_PATH,HEALTHY_IMAGE_TWO_PATH, SICK_IMAGE_ONE_PATH, SICK_IMAGE_TWO_PATH,HEADER,\
  15. SUB_HEADER, SHORT_DESCRIPTION, IMAGE_POOL_DESCRIPTION, SELECT_BOX_TEXT, SUPPORTED_IMG_TYPE, WARNING_MSG, SUCCESS_MSG, \
  16. BIG_FONT, MID_FONT, SMALL_FONT
  17. def markdown_format(font_size,content):
  18. st.markdown(f"<{font_size} style='text-align: center;'>{content}</{font_size}>",
  19. unsafe_allow_html=True)
  20. def load_n_resize_image(image_path):
  21. pil_img = Image.open(image_path)
  22. return cv.resize(cv.cvtColor(np.array(pil_img), cv.COLOR_RGB2BGR), IMG_SIZE)
  23. def load_image_pool():
  24. healthy = [load_n_resize_image(HEALTHY_IMAGE_ONE_PATH), load_n_resize_image(HEALTHY_IMAGE_TWO_PATH)]
  25. sick = [load_n_resize_image(SICK_IMAGE_ONE_PATH), load_n_resize_image(SICK_IMAGE_TWO_PATH)]
  26. return {'healthy': healthy, 'sick': sick}
  27. def present_pool(col, col_name, img_list):
  28. name_list = []
  29. for row in range(len(img_list)):
  30. col.image(img_list[row], use_column_width=True, caption=col_name + f" {row + 1}")
  31. name_list.append(col_name + f" {row + 1}")
  32. return name_list
  33. def display_prediction(pred):
  34. if pred == 'sick':
  35. st.warning(WARNING_MSG)
  36. else:
  37. st.success(SUCCESS_MSG)
  38. @st.cache(suppress_st_warning=True)
  39. def get_prediction(img):
  40. with open(CLASS_NAME_PATH, "r") as textfile:
  41. class_names = textfile.read().split(',')
  42. img_expand = np.expand_dims(img, 0)
  43. model = tf.keras.models.load_model(PROD_MODEL_PATH)
  44. predictions = model.predict(img_expand)
  45. display_prediction(class_names[np.rint(predictions[0][0]).astype(int)])
  46. def predict_for_selectbox(selectbox, my_bar, latest_iteration):
  47. img_class = selectbox.split()[0]
  48. img_position = int(selectbox.split()[-1]) - 1
  49. img = dict_of_img_lists[img_class][img_position]
  50. my_bar.progress(50)
  51. latest_iteration.text('Processing image')
  52. get_prediction(img)
  53. my_bar.progress(100)
  54. def predict_for_file_buffer(file_buffer, my_bar, latest_iteration):
  55. latest_iteration.text('Loading image')
  56. img = load_n_resize_image(file_buffer)
  57. markdown_format(MID_FONT, "Your chest X-ray")
  58. st.image(img, use_column_width=True)
  59. my_bar.progress(50)
  60. latest_iteration.text('Processing image')
  61. get_prediction(img)
  62. my_bar.progress(100)
  63. if __name__ == '__main__':
  64. # Page configuration
  65. st.set_page_config(page_title=HEADER, page_icon="🤒",
  66. initial_sidebar_state='expanded')
  67. # Base Design
  68. st.image(image=DAGSHUB_IMAGE_PATH)
  69. markdown_format(BIG_FONT, HEADER)
  70. markdown_format(MID_FONT, SUB_HEADER)
  71. markdown_format(SMALL_FONT, SHORT_DESCRIPTION)
  72. latest_iteration = st.empty()
  73. my_bar = st.progress(0)
  74. # Show pool of images
  75. dict_of_img_lists = load_image_pool()
  76. with st.beta_expander("Image Pool"):
  77. markdown_format(MID_FONT, IMAGE_POOL_DESCRIPTION)
  78. col1, col2 = st.beta_columns(2)
  79. healthy_sidebar_list = present_pool(col1, "healthy", dict_of_img_lists['healthy'])
  80. sick_sidebar_list = present_pool(col2, "sick", dict_of_img_lists['sick'])
  81. # Sidebar
  82. selectbox = st.sidebar.selectbox(SELECT_BOX_TEXT,
  83. [None] + healthy_sidebar_list + sick_sidebar_list)
  84. file_buffer = st.sidebar.file_uploader("", type=SUPPORTED_IMG_TYPE)
  85. # Predict for user selection
  86. if selectbox:
  87. predict_for_selectbox(selectbox, my_bar, latest_iteration)
  88. dict_of_img_lists = load_image_pool()
  89. if file_buffer:
  90. predict_for_file_buffer(file_buffer, my_bar, latest_iteration)
Tip!

Press p or to see the previous file or, n or to see the next file