Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

streamlit_client.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  1. import cv2
  2. import numpy as np
  3. import os
  4. import streamlit as st
  5. from omegaconf import OmegaConf
  6. from PIL import Image
  7. __import__('sys').path.append('.')
  8. from deployment.src.client import predict_api, predict_sagemaker
  9. @st.cache
  10. def load_image(image_file):
  11. pil_image = Image.open(image_file)
  12. image = np.array(pil_image)
  13. image = cv2.resize(image, (224, 224))
  14. image = np.float32(image / np.float32(255))
  15. if len(image.shape) == 2:
  16. image = image[..., np.newaxis].repeat(3, axis=-1)
  17. return image
  18. @st.cache
  19. def predict(image_tensor: np.ndarray, mode: str, endpoint: str):
  20. if mode == 'API Endpoint':
  21. return predict_api(image_tensor[np.newaxis, ...], api_endpoint=endpoint)[0]
  22. else:
  23. return predict_sagemaker(image_tensor[np.newaxis, ...], sagemaker_endpoint=endpoint)[0]
  24. DEPLOYMENT_CONST_PATH = os.path.join('deployment', 'src', 'const.yaml')
  25. GENERAL_CONST_PATH = os.path.join('src', 'const.yaml')
  26. deployment_const = OmegaConf.load(os.path.join(os.getcwd(), DEPLOYMENT_CONST_PATH))
  27. general_const = OmegaConf.load(os.path.join(os.getcwd(), GENERAL_CONST_PATH))
  28. st.title('CheXNet Inference')
  29. mode = st.selectbox('Method', ['API Endpoint', 'AWS SageMaker CLI'])
  30. if mode == 'API Endpoint':
  31. api_endpoint = st.text_input('API Endpoint URL', deployment_const.API_ENDPOINT)
  32. elif mode == 'AWS SageMaker CLI':
  33. sagemamer_endpoint = st.text_input('SageMaker Endpoint Name', deployment_const.SAGEMAKER_ENDPOINT)
  34. st.markdown('---')
  35. st.header('Select image')
  36. image_file = st.file_uploader('Upload Image', type=['png', 'jpg', 'jpeg', 'gif', 'bmp'])
  37. if image_file is not None:
  38. image = load_image(image_file)
  39. st.image(image)
  40. raw_output = predict(image, mode=mode, endpoint=api_endpoint if mode == 'API Endpoint' else sagemamer_endpoint)
  41. st.write("Raw model output:")
  42. st.write(raw_output[np.newaxis])
  43. mask = (raw_output >= 0.5)[:eval(general_const.N_CLASSES)]
  44. n_ailments = np.sum(mask)
  45. st.header(f"This image shows {n_ailments} ailment{'s' if n_ailments != 1 else ''}:")
  46. output = np.array(general_const.CLASSES)[mask]
  47. st.write(output)
Tip!

Press p or ← to see the previous file, or n or → to see the next file.

Comments

Loading...