Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

appv1.py 4.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
  1. import gradio as gr
  2. import os
  3. import torch
  4. import pandas as pd
  5. import numpy as np
  6. from mlProject.utils.bertModel import preprocess_text
  7. from mlProject.utils.common import load_product_image
  8. from PIL import Image
  9. import os
  10. from dotenv import load_dotenv
  11. from pinecone import Pinecone, ServerlessSpec
  12. import io
  13. import base64
  14. load_dotenv()
  15. pkey = os.environ.get("pkey")
  16. pc=Pinecone(api_key=pkey)
  17. index = pc.Index('stv1-embeddings')
  18. sampled_data = pd.read_csv('artifacts/data_ingestion/data_tar_extracted/processed_dataset_target_data_with_captions_only.csv')
  19. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  20. # TODO: Add path to config and use untar image
  21. tarfilepath = 'artifacts/data_ingestion/abo-images-small.tar'
  22. # Dummy preprocess_text and generate_image_url functions
  23. '''
  24. def preprocess_text(text):
  25. inputs = bert_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
  26. outputs = bert_model(**inputs)
  27. return outputs.pooler_output.detach().numpy()
  28. '''
  29. def generate_image_url(path):
  30. # Specify the directory path to where your images are stored
  31. base_path = "/Users/user/Documents/MLProjects/project6/artifacts/data_ingestion/abo-images-small/images/resize/"
  32. full_path = os.path.join(base_path, path)
  33. try:
  34. with Image.open(full_path) as img:
  35. # Resize the image to 224x224
  36. img = img.resize((224, 224), Image.Resampling.LANCZOS)
  37. # Save the resized image to a byte buffer
  38. img_byte_arr = io.BytesIO()
  39. img.save(img_byte_arr, format='JPEG')
  40. img_byte_arr = img_byte_arr.getvalue()
  41. # Encode the image to base64 and format it as a data URL
  42. base64_img = base64.b64encode(img_byte_arr).decode('utf-8')
  43. return f"data:image/jpeg;base64,{base64_img}"
  44. except Exception as e:
  45. print(f"Error processing image {path}: {e}")
  46. return None
  47. def search_similar_products(query_text):
  48. query_embedding = preprocess_text(query_text).flatten()
  49. results = index.query(vector=query_embedding.tolist(), top_k=5, include_metadata=True)
  50. images = []
  51. captions = []
  52. for result in results['matches']:
  53. item_id = result['id']
  54. metadata = result.get('metadata', {})
  55. item_details = sampled_data[sampled_data['item_id'] == item_id].iloc[0]
  56. item_name = item_details['item_name_in_en']
  57. brand_name = item_details['brand']
  58. product_type = item_details['product_type']
  59. item_image_path = item_details['path']
  60. image_url = generate_image_url(item_image_path)
  61. images.append(image_url)
  62. caption = f"Product: {item_name}, Brand: {brand_name}, Type: {product_type}"
  63. captions.append(caption)
  64. return images, captions
  65. def clear_all(chatbot, image_display):
  66. chatbot.clear()
  67. image_display.clear()
  68. return [], []
  69. def setup_ui(query):
  70. try:
  71. images, captions = search_similar_products(query)
  72. gallery_data = [(image, caption) for image, caption in zip(images, captions)]
  73. return gallery_data
  74. except Exception as e:
  75. print(f"Error in setup_ui: {e}")
  76. return [] # Return an empty list in case of an error
  77. css = """
  78. #gallery {font-size: 24px !important}
  79. """
  80. with gr.Blocks() as demo:
  81. input_text = gr.Textbox(label="Enter product query:")
  82. submit_button = gr.Button("Submit")
  83. gallery = gr.Gallery(label="Product Images and Descriptions",show_label= False,
  84. elem_id="gallery", columns=[8], rows=[1], object_fit='scale-down', height=50)
  85. submit_button.click(fn=setup_ui, inputs=[input_text], outputs=[gallery])
  86. demo.launch(height= 60, width ="50%")
  87. '''
  88. def setup_ui(query):
  89. results = search_similar_products(query)
  90. output_image = []
  91. output_data =[]
  92. # Check results and prepare data for output
  93. for result in results:
  94. # Append a tuple of (image_url, description) for each result
  95. output_image.append(result["image"])
  96. output_data.append(result["description"])
  97. print("Output Data ", output_data, output_image)
  98. return output_image, output_data '''
  99. '''
  100. with gr.Blocks() as demo:
  101. with gr.Row():
  102. gr.Markdown("# ShopTalk Product Search")
  103. with gr.Row():
  104. input_text = gr.Textbox(label="Enter Product Name:")
  105. with gr.Row():
  106. output_text = gr.Textbox(label="Results")
  107. with gr.Row():
  108. output_image = gr.Image(label="Generated Image")
  109. submit_button = gr.Button("Submit")
  110. submit_button.click(
  111. fn=search_similar_products,
  112. inputs=input_text,
  113. outputs=[output_text, output_image]
  114. )
  115. clear_button = gr.Button("Clear")
  116. clear_button.click(
  117. fn=clear_all,
  118. inputs=[],
  119. outputs=[output_text, output_image]
  120. )
  121. demo.launch()
  122. '''
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...