assistant.py

import tiktoken
import openai
from datetime import datetime
from typing import Any
from time import sleep

from memory_manager import MemoryManager
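
# NOTE: memory_manager is a companion module that is not included in this file.
# Judging from how it is used below (an inferred interface, not a documented
# contract), a MemoryManager is expected to expose at least:
#   create_collection(collection_name)         -> None
#   search_points(vector, collection_name, k)  -> list of points, each with a .payload dict
#   insert_points(collection_name, points)     -> None
# The point/payload/collection vocabulary matches a vector store such as Qdrant.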

class OpenAIAssistant:
    """
    ChatGPT wrapper for the OpenAI API
    """
    def __init__(
        self,
        api_key: str,
        chat_model: str = 'gpt-3.5-turbo',
        embedding_model: Any = 'text-embedding-ada-002',
        enc: str = 'gpt2',
        short_term_memory_summary_prompt: str = None,
        long_term_memory_summary_prompt: str = None,
        system_prompt: str = "You are a helpful assistant. Your name is SERPy.",
        short_term_memory_max_tokens: int = 750,
        long_term_memory_max_tokens: int = 500,
        knowledge_retrieval_max_tokens: int = 1000,
        short_term_memory_summary_max_tokens: int = 300,
        long_term_memory_summary_max_tokens: int = 300,
        knowledge_retrieval_summary_max_tokens: int = 600,
        summarize_short_term_memory: bool = False,
        summarize_long_term_memory: bool = False,
        summarize_knowledge_retrieval: bool = False,
        use_long_term_memory: bool = False,
        long_term_memory_collection_name: str = 'long_term_memory',
        use_short_term_memory: bool = False,
        use_knowledge_retrieval: bool = False,
        knowledge_retrieval_collection_name: str = 'knowledge_retrieval',
        price_per_token: float = 0.000002,
        max_seq_len: int = 4096,
        memory_manager: MemoryManager = None,
        debug: bool = False
    ) -> None:
  39. """
  40. Initialize the OpenAIAssistant
  41. Parameters:
  42. api_key (str): The OpenAI API key
  43. chat_model (str): The model to use for chat
  44. embedding_model (Any): The model to use for embeddings
  45. enc (str): The encoding to use for the model
  46. short_term_memory_summary_prompt (str): The prompt to use for short term memory summarization
  47. long_term_memory_summary_prompt (str): The prompt to use for long term memory summarization
  48. system_prompt (str): The system prompt to use for the model
  49. short_term_memory_max_tokens (int): The maximum number of tokens to store in short term memory
  50. long_term_memory_max_tokens (int): The maximum number of tokens to store in long term memory
  51. knowledge_retrieval_max_tokens (int): The maximum number of tokens to store in knowledge retrieval
  52. short_term_memory_summary_max_tokens (int): The maximum number of tokens to store in short term memory summary
  53. long_term_memory_summary_max_tokens (int): The maximum number of tokens to store in long term memory summary
  54. knowledge_retrieval_summary_max_tokens (int): The maximum number of tokens to store in knowledge retrieval summary
  55. summarize_short_term_memory (bool): Whether to use short term memory summarization
  56. summarize_long_term_memory (bool): Whether to use long term memory summarization
  57. summarize_knowledge_retrieval (bool): Whether to use knowledge retrieval summarization
  58. use_long_term_memory (bool): Whether to use long term memory
  59. long_term_memory_collection_name (str): The name of the long term memory collection
  60. use_short_term_memory (bool): Whether to use short term memory
  61. use_knowledge_retrieval (bool): Whether to use knowledge retrieval
  62. knowledge_retrieval_collection_name (str): The name of the knowledge retrieval collection
  63. price_per_token (float): The price per token in USD
  64. max_seq_len (int): The maximum sequence length
  65. memory_manager (MemoryManager): The memory manager to use for long term memory and knowledge retrieval
  66. debug (bool): Whether to enable debug mode
  67. """
        openai.api_key = api_key
        self.api_key = api_key
        self.chat_model = chat_model
        self.embedding_model = embedding_model
        self.enc = tiktoken.get_encoding(enc)
        self.memory_manager = memory_manager
        self.price_per_token = price_per_token
        self.short_term_memory = []
        self.short_term_memory_summary = ''
        self.long_term_memory_summary = ''
        self.knowledge_retrieval_summary = ''
        self.debug = debug
        self.summarize_short_term_memory = summarize_short_term_memory
        self.summarize_long_term_memory = summarize_long_term_memory
        self.summarize_knowledge_retrieval = summarize_knowledge_retrieval
        self.use_long_term_memory = use_long_term_memory
        self.long_term_memory_collection_name = 'long_term_memory' if long_term_memory_collection_name is None else long_term_memory_collection_name
        self.use_knowledge_retrieval = use_knowledge_retrieval
        self.knowledge_retrieval_collection_name = 'knowledge_retrieval' if knowledge_retrieval_collection_name is None else knowledge_retrieval_collection_name
        # Long term memory and knowledge retrieval require a MemoryManager;
        # disable them if none was provided
        if self.memory_manager is None:
            self.use_long_term_memory = False
            self.use_knowledge_retrieval = False
        if self.use_long_term_memory and self.memory_manager is not None:
            self.memory_manager.create_collection(self.long_term_memory_collection_name)
        if self.use_knowledge_retrieval and self.memory_manager is not None:
            self.memory_manager.create_collection(self.knowledge_retrieval_collection_name)
        self.use_short_term_memory = use_short_term_memory
        self.short_term_memory_summary_max_tokens = short_term_memory_summary_max_tokens
        self.long_term_memory_summary_max_tokens = long_term_memory_summary_max_tokens
        self.knowledge_retrieval_summary_max_tokens = knowledge_retrieval_summary_max_tokens
        self.short_term_memory_max_tokens = short_term_memory_max_tokens
        self.long_term_memory_max_tokens = long_term_memory_max_tokens
        self.knowledge_retrieval_max_tokens = knowledge_retrieval_max_tokens
        self.system_prompt = system_prompt
        if short_term_memory_summary_prompt is None:
            self.short_term_memory_summary_prompt = "Summarize the following conversation:\n\nPrevious Summary: {previous_summary}\n\nConversation: {conversation}"
        else:
            self.short_term_memory_summary_prompt = short_term_memory_summary_prompt
        if long_term_memory_summary_prompt is None:
            self.long_term_memory_summary_prompt = "Summarize the following (out of order) conversation messages:\n\nPrevious Summary: {previous_summary}\n\nMessages: {conversation}"
        else:
            self.long_term_memory_summary_prompt = long_term_memory_summary_prompt
        self.max_seq_len = max_seq_len

    def _construct_messages(self, prompt: str, inject_messages: list = []) -> list:
        """
        Construct the messages for the chat completion

        Parameters:
        prompt (str): The prompt to construct the messages for
        inject_messages (list): The messages to inject into the chat completion

        Returns:
        list: The messages to use for the chat completion
        """
        messages = []
        if self.system_prompt is not None and self.system_prompt != "":
            messages.append({
                "role": "system",
                "content": self.system_prompt
            })
        if self.use_long_term_memory:
            long_term_memory = self.query_long_term_memory(prompt, summarize=self.summarize_long_term_memory)
            if long_term_memory is not None and long_term_memory != '':
                messages.append({
                    "role": "system",
                    "content": long_term_memory
                })
        if self.summarize_short_term_memory:
            if self.short_term_memory_summary != '' and self.short_term_memory_summary is not None:
                messages.append({
                    "role": "system",
                    "content": self.short_term_memory_summary
                })
        if self.use_short_term_memory:
            for message in self.short_term_memory:
                messages.append(message)
        # Each entry in inject_messages is a single-item dict mapping an insertion
        # index to a message; work on a copy so the caller's list is not mutated
        # and no entries are skipped while removing matches
        if inject_messages is not None and inject_messages != []:
            remaining = list(inject_messages)
            for i in range(len(messages)):
                for message in list(remaining):
                    if i == list(message.keys())[0]:
                        messages.insert(i, list(message.values())[0])
                        remaining.remove(message)
            # Messages whose index falls beyond the end are appended
            for message in remaining:
                messages.append(list(message.values())[0])
        if prompt is None or prompt == "":
            return messages
        messages.append({
            "role": "user",
            "content": prompt
        })
        return messages

    def change_system_prompt(self, system_prompt: str) -> None:
        """
        Change the system prompt

        Parameters:
        system_prompt (str): The new system prompt to use
        """
        self.system_prompt = system_prompt

    def calculate_num_tokens(self, text: str) -> int:
        """
        Calculate the number of tokens in a given text

        Parameters:
        text (str): The text to calculate the number of tokens for

        Returns:
        int: The number of tokens in the text
        """
        return len(self.enc.encode(text))

    def calculate_short_term_memory_tokens(self) -> int:
        """
        Calculate the number of tokens in short term memory

        Returns:
        int: The number of tokens in short term memory
        """
        return sum(self.calculate_num_tokens(message['content']) for message in self.short_term_memory)

    def query_long_term_memory(self, query: str, summarize: bool = False) -> str:
        """
        Query long term memory

        Parameters:
        query (str): The query to use for long term memory
        summarize (bool): Whether to summarize the long term memory

        Returns:
        str: The long term memory
        """
        embedding = self.get_embedding(query).data[0].embedding
        points = self.memory_manager.search_points(vector=embedding, collection_name=self.long_term_memory_collection_name, k=20)
        if len(points) == 0:
            return ''
        long_term_memory = ''
        if summarize:
            long_term_memory += 'Summary of previous related conversations from long term memory:' + self.generate_long_term_memory_summary(points) + '\n\n'
        if self.long_term_memory_max_tokens > 0:
            long_term_memory += 'Previous related conversations from long term memory:\n\n'
            for point in points:
                payload = point.payload
                entry = f"{payload['user_message']['role'].title()}: {payload['user_message']['content']}\n\n{payload['assistant_message']['role'].title()}: {payload['assistant_message']['content']}\n----------\n"
                # Skip any entry that would push the section past its token budget
                if self.calculate_num_tokens(long_term_memory + entry) > self.long_term_memory_max_tokens:
                    continue
                long_term_memory += entry
        # Drop the header again if no conversation actually fit under the budget
        if long_term_memory == 'Previous related conversations from long term memory:\n\n':
            return ''
        elif long_term_memory.endswith('\n\nPrevious related conversations from long term memory:\n\n'):
            long_term_memory = long_term_memory.replace('\n\nPrevious related conversations from long term memory:\n\n', '')
        return long_term_memory.strip()

    def add_message_to_short_term_memory(self, user_message: dict, assistant_message: dict) -> None:
        """
        Add a message to short term memory

        Parameters:
        user_message (dict): The user message to add to short term memory
        assistant_message (dict): The assistant message to add to short term memory
        """
        self.short_term_memory.append(user_message)
        self.short_term_memory.append(assistant_message)
        # Evict the oldest user/assistant pair until memory fits its token budget,
        # summarizing evicted turns first if summarization is enabled
        while self.calculate_short_term_memory_tokens() > self.short_term_memory_max_tokens:
            if self.summarize_short_term_memory:
                self.generate_short_term_memory_summary()
            self.short_term_memory.pop(0)  # Remove the oldest message (user message)
            self.short_term_memory.pop(0)  # Remove the oldest message (assistant message)

    def add_message_to_long_term_memory(self, user_message: dict, assistant_message: dict) -> None:
        """
        Add a message to long term memory

        Parameters:
        user_message (dict): The user message to add to long term memory
        assistant_message (dict): The assistant message to add to long term memory
        """
        points = [
            {
                "vector": self.get_embedding(f'User: {user_message["content"]}\n\nAssistant: {assistant_message["content"]}').data[0].embedding,
                "payload": {
                    "user_message": user_message,
                    "assistant_message": assistant_message,
                    "timestamp": datetime.now().timestamp()
                }
            }
        ]
        self.memory_manager.insert_points(collection_name=self.long_term_memory_collection_name, points=points)

    def generate_short_term_memory_summary(self) -> None:
        """
        Generate a summary of short term memory
        """
        prompt = self.short_term_memory_summary_prompt.format(
            previous_summary=self.short_term_memory_summary,
            conversation=f'User: {self.short_term_memory[0]["content"]}\n\nAssistant: {self.short_term_memory[1]["content"]}'
        )
        # Truncate the prompt so the summary still fits within the context window
        if self.calculate_num_tokens(prompt) > self.max_seq_len - self.short_term_memory_summary_max_tokens:
            prompt = self.enc.decode(self.enc.encode(prompt)[:self.max_seq_len - self.short_term_memory_summary_max_tokens])
        summary_agent = OpenAIAssistant(self.api_key, system_prompt=None)
        self.short_term_memory_summary = summary_agent.get_chat_response(prompt, max_tokens=self.short_term_memory_summary_max_tokens).choices[0].message.content

    def generate_long_term_memory_summary(self, points: list) -> str:
        """
        Summarize long term memory

        Parameters:
        points (list): The points to summarize

        Returns:
        str: The summary of long term memory
        """
        prompt = self.long_term_memory_summary_prompt.format(
            previous_summary=self.long_term_memory_summary,
            conversation='\n\n'.join([f'User: {point.payload["user_message"]["content"]}\n\nAssistant: {point.payload["assistant_message"]["content"]}' for point in points])
        )
        if self.calculate_num_tokens(prompt) > self.max_seq_len - self.long_term_memory_summary_max_tokens:
            prompt = self.enc.decode(self.enc.encode(prompt)[:self.max_seq_len - self.long_term_memory_summary_max_tokens])
        summary_agent = OpenAIAssistant(self.api_key, system_prompt=None)
        self.long_term_memory_summary = summary_agent.get_chat_response(prompt, max_tokens=self.long_term_memory_summary_max_tokens).choices[0].message.content
        return self.long_term_memory_summary

    def calculate_price(self, prompt: str = None, num_tokens: int = None) -> float:
        """
        Calculate the price of a prompt (or number of tokens) in USD

        Parameters:
        prompt (str): The prompt to calculate the price of
        num_tokens (int): The number of tokens to calculate the price of

        Returns:
        float: The price of the generation in USD
        """
        assert prompt or num_tokens, "You must provide either a prompt or number of tokens"
        if prompt:
            num_tokens = self.calculate_num_tokens(prompt)
        return num_tokens * self.price_per_token

    def get_embedding(self, input: str, user: str = '', instructor_instruction: str = None) -> Any:
        """
        Get the embedding for the given text

        Parameters:
        input (str): The text to get the embedding for
        user (str): The user to get the embedding for
        instructor_instruction (str): The instruction to embed the text with (for instruction-tuned embedding models)

        Returns:
        Any: The embedding response for the input
        """
        if self.embedding_model is None:
            return None
        elif self.embedding_model == 'text-embedding-ada-002':
            # OpenAI embedding endpoint (legacy 0.x SDK)
            return openai.Embedding.create(
                model=self.embedding_model,
                input=input,
                user=user
            )
        else:
            # Otherwise assume a local model with a sentence-transformers style encode()
            if instructor_instruction is not None:
                return self.embedding_model.encode([[instructor_instruction, input]])
            return self.embedding_model.encode([input])

    def get_chat_response(
        self,
        prompt: str,
        max_tokens: int = None,
        temperature: float = 1.0,
        top_p: float = 1.0,
        n: int = 1,
        stream: bool = False,
        frequency_penalty: float = 0,
        presence_penalty: float = 0,
        stop: list = None,
        logit_bias: dict = {},
        user: str = '',
        max_retries: int = 3,
        inject_messages: list = []
    ) -> Any:
        """
        Get a chat response from the model

        Parameters:
        prompt (str): The prompt to generate a response for
        max_tokens (int): The maximum number of tokens to generate
        temperature (float): The temperature of the model
        top_p (float): The top_p of the model
        n (int): The number of responses to generate
        stream (bool): Whether to stream the response
        frequency_penalty (float): The frequency penalty of the model
        presence_penalty (float): The presence penalty of the model
        stop (list): The stop sequences for the model
        logit_bias (dict): The logit bias of the model
        user (str): The user to generate the response for
        max_retries (int): The maximum number of retries to generate a response
        inject_messages (list): The messages to inject into the constructed message list (each entry is a dict whose key is the index to insert at, 0 to prepend before all messages, and whose value is the message to inject)

        Returns:
        Any: The full chat completion response
        """
        messages = self._construct_messages(prompt, inject_messages=inject_messages)
        if self.debug:
            print(f'Messages: {messages}')
        iteration = 0
        while True:
            try:
                response = openai.ChatCompletion.create(
                    model=self.chat_model,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    n=n,
                    stream=stream,
                    stop=stop,
                    max_tokens=max_tokens,
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty,
                    logit_bias=logit_bias,
                    user=user
                )
                if self.use_short_term_memory:
                    self.add_message_to_short_term_memory(user_message={
                        "role": "user",
                        "content": prompt
                    }, assistant_message=response.choices[0].message.to_dict())
                if self.use_long_term_memory:
                    self.add_message_to_long_term_memory(user_message={
                        "role": "user",
                        "content": prompt
                    }, assistant_message=response.choices[0].message.to_dict())
                return response
            except Exception as e:
                # Retry transient API errors up to max_retries times, then re-raise
                iteration += 1
                if iteration >= max_retries:
                    raise e
                print('Error communicating with ChatGPT:', e)
                sleep(1)
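

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original class). Assumes the
# legacy openai 0.x SDK that this module targets (openai.ChatCompletion /
# openai.Embedding were removed in the 1.x SDK); 'YOUR_API_KEY' is a
# placeholder. Without a MemoryManager, long term memory and knowledge
# retrieval remain disabled and only short term memory is used.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    assistant = OpenAIAssistant(
        api_key='YOUR_API_KEY',        # placeholder
        use_short_term_memory=True     # keep recent turns in the prompt context
    )
    response = assistant.get_chat_response('What is your name?')
    print(response.choices[0].message.content)
    # Estimate the cost of the call from the token usage reported by the API
    print(f"Estimated cost: ${assistant.calculate_price(num_tokens=response.usage.total_tokens):.6f}")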