bot.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  import asyncio
  import inspect
  import uuid
  from configparser import ConfigParser
  from datetime import datetime
  from io import BytesIO
  from typing import Optional, List, Dict, Tuple

  import duckdb
  import magic
  import markdown2
  import tiktoken
  from nio import (
      AsyncClient,
      AsyncClientConfig,
      WhoamiResponse,
      DevicesResponse,
      Event,
      Response,
      MatrixRoom,
      Api,
      RoomMessagesError,
      MegolmEvent,
      GroupEncryptionError,
      EncryptionError,
      RoomMessageText,
      RoomSendResponse,
      SyncResponse
  )
  from nio.crypto import Olm
  from PIL import Image

  from .logging import Logger
  from migrations import migrate
  from callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
  from commands import COMMANDS
  from .store import DuckDBStore
  from .openai import OpenAI
  from .wolframalpha import WolframAlpha
  class GPTBot:
      """Matrix bot that feeds room conversations to a chat-completion API and posts the replies.

      Instances are normally built with ``from_config()`` and started with ``run()``.
      """

      # Default values (overridden per instance by from_config())
      database: Optional[duckdb.DuckDBPyConnection] = None
      default_room_name: str = "GPTBot"  # Default name of rooms created by the bot
      default_system_message: str = "You are a helpful assistant."
      # Force default system message to be included even if a custom room message is set
      force_system_message: bool = False
      max_tokens: int = 3000  # Maximum number of input tokens
      max_messages: int = 30  # Maximum number of messages to consider as input
      matrix_client: Optional[AsyncClient] = None
      sync_token: Optional[str] = None
      # NOTE: created once when the class is defined, so it is shared by all instances
      logger: Optional[Logger] = Logger()
      chat_api: Optional[OpenAI] = None
      image_api: Optional[OpenAI] = None
      # Only set by from_config() when a [WolframAlpha] section exists; declared
      # here so the attribute always exists (None otherwise).
      calculation_api: Optional[WolframAlpha] = None

      @classmethod
      def from_config(cls, config: ConfigParser):
          """Create a new GPTBot instance from a config file.

          Reads the optional ``[Database]`` (``Path``), ``[GPTBot]``
          (``DefaultRoomName``, ``SystemMessage``, ``ForceSystemMessage``) and
          ``[WolframAlpha]`` (``APIKey``) sections, plus the required ``[OpenAI]``
          (``APIKey``; optional ``Model``, ``MaxTokens``, ``MaxMessages``) and
          ``[Matrix]`` (``Homeserver``, ``AccessToken``; optional ``UserID``,
          ``DeviceID``) sections.

          Args:
              config (ConfigParser): ConfigParser instance with the bot's config.

          Returns:
              GPTBot: The new GPTBot instance.

          Raises:
              KeyError: If a required section or key is missing.
              AssertionError: If the ``[Matrix]`` section is missing.
          """
          bot = cls()

          # Persistent database if configured; run() falls back to an in-memory one otherwise
          if "Database" in config and "Path" in config["Database"]:
              bot.database = duckdb.connect(config["Database"]["Path"])
          else:
              bot.database = None

          # Override default values
          if "GPTBot" in config:
              section = config["GPTBot"]
              bot.default_room_name = section.get(
                  "DefaultRoomName", bot.default_room_name)
              bot.default_system_message = section.get(
                  "SystemMessage", bot.default_system_message)
              bot.force_system_message = section.getboolean(
                  "ForceSystemMessage", bot.force_system_message)

          # A single OpenAI client serves both chat and image generation
          bot.chat_api = bot.image_api = OpenAI(
              config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger)
          bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
          bot.max_messages = config["OpenAI"].getint(
              "MaxMessages", bot.max_messages)

          # Set up WolframAlpha
          if "WolframAlpha" in config:
              bot.calculation_api = WolframAlpha(
                  config["WolframAlpha"]["APIKey"], bot.logger)

          # Set up the Matrix client
          assert "Matrix" in config, "Matrix config not found"

          homeserver = config["Matrix"]["Homeserver"]
          bot.matrix_client = AsyncClient(homeserver)
          bot.matrix_client.access_token = config["Matrix"]["AccessToken"]
          bot.matrix_client.user_id = config["Matrix"].get("UserID")
          bot.matrix_client.device_id = config["Matrix"].get("DeviceID")

          return bot

      async def _get_user_id(self) -> str:
          """Get the user ID of the bot from the whoami endpoint.

          Uses the client's configured user ID when set; otherwise asks the
          homeserver. Requires an access token to be set up.

          Returns:
              str: The user ID of the bot.

          Raises:
              Exception: If the whoami request fails.
          """
          assert self.matrix_client, "Matrix client not set up"

          user_id = self.matrix_client.user_id

          if not user_id:
              assert self.matrix_client.access_token, "Access token not set up"

              response = await self.matrix_client.whoami()

              if isinstance(response, WhoamiResponse):
                  user_id = response.user_id
              else:
                  raise Exception(f"Could not get user ID: {response}")

          return user_id

      async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int] = None):
          """Fetch up to ``n`` recent text messages from a room, oldest first.

          Up to ``2 * n`` events are requested starting at ``self.sync_token``.
          Megolm events are decrypted when possible (undecryptable ones are
          logged and skipped), messages starting with ``!`` are skipped, and
          scanning stops at the first ``!gptbot ignoreolder`` message.

          Args:
              room: The room, or its room ID.
              n: Maximum number of messages to return; defaults to
                  ``self.max_messages`` when ``None`` or 0.

          Returns:
              list[RoomMessageText]: The collected messages in chronological order.

          Raises:
              Exception: If the homeserver returns an error for the messages request.
          """
          messages = []
          n = n or self.max_messages
          room_id = room.room_id if isinstance(room, MatrixRoom) else room

          self.logger.log(
              f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...")

          # Over-fetch (2*n events), presumably because commands, undecryptable
          # and non-text events are filtered out below.
          response = await self.matrix_client.room_messages(
              room_id=room_id,
              start=self.sync_token,
              limit=2*n,
          )

          if isinstance(response, RoomMessagesError):
              raise Exception(
                  f"Error fetching messages: {response.message} (status code {response.status_code})")

          for event in response.chunk:
              if len(messages) >= n:
                  break

              if isinstance(event, MegolmEvent):
                  try:
                      decrypted = self.matrix_client.decrypt_event(event)
                      # NOTE(review): matrix-nio's decrypt_event appears to be a
                      # plain (synchronous) method; only await if it returned an
                      # awaitable, so both variants work.
                      if inspect.isawaitable(decrypted):
                          decrypted = await decrypted
                  except (GroupEncryptionError, EncryptionError):
                      self.logger.log(
                          f"Could not decrypt message {event.event_id} in room {room_id}", "error")
                      continue
                  event = decrypted

              if isinstance(event, RoomMessageText):
                  # Marker that cuts the conversation context off at this point
                  if event.body.startswith("!gptbot ignoreolder"):
                      break
                  # Bot commands are not part of the conversation
                  if not event.body.startswith("!"):
                      messages.append(event)

          self.logger.log(f"Found {len(messages)} messages (limit: {n})")

          # Reverse the list so that messages are in chronological order
          return messages[::-1]

      def _truncate(self, messages: list, max_tokens: Optional[int] = None,
                    model: Optional[str] = None, system_message: Optional[str] = None):
          """Trim a chat history so that it fits within a token budget.

          ``messages[0]`` is always kept first. The remaining messages are
          taken newest-first until the next one would exceed the budget, and
          are returned in their original (chronological) order. Each message
          costs its encoded content length plus one token.

          Args:
              messages: Chat messages as dicts with a ``"content"`` string.
              max_tokens: Token budget; defaults to ``self.max_tokens``.
              model: Model whose tiktoken encoding is used; defaults to
                  ``self.chat_api.chat_model``.
              system_message: Text whose tokens are reserved up front; defaults
                  to ``self.default_system_message``.

          Returns:
              list: The truncated messages. Returns ``[]`` if ``messages`` is
              empty, or if the system message or ``messages[0]`` alone does
              not fit.
          """
          max_tokens = max_tokens or self.max_tokens
          model = model or self.chat_api.chat_model
          system_message = self.default_system_message if system_message is None else system_message

          if not messages:
              return []

          encoding = tiktoken.encoding_for_model(model)

          system_message_tokens = (
              len(encoding.encode(system_message)) + 1) if system_message else 0

          if system_message_tokens > max_tokens:
              self.logger.log(
                  f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
              return []

          # Start from the tokens reserved for the system message.
          # NOTE(review): when messages[0] is that same system message (as built
          # by process_query), it is counted again in the loop below, so the
          # budget is applied conservatively.
          total_tokens = system_message_tokens

          kept = []
          for message in [messages[0], *reversed(messages[1:])]:
              tokens = len(encoding.encode(message["content"])) + 1
              if total_tokens + tokens > max_tokens:
                  break
              total_tokens += tokens
              kept.append(message)

          if not kept:
              # Not even the pinned first message fits
              return []

          return [kept[0], *reversed(kept[1:])]

      async def _get_device_id(self) -> str:
          """Guess the device ID of the bot.

          Uses the client's configured device ID when set; otherwise picks the
          first device returned by the homeserver's devices endpoint.
          Requires an access token to be set up.

          Returns:
              str: The (guessed) device ID.

          Raises:
              Exception: If the devices request fails or returns no devices.
          """
          assert self.matrix_client, "Matrix client not set up"

          device_id = self.matrix_client.device_id

          if not device_id:
              assert self.matrix_client.access_token, "Access token not set up"

              devices = await self.matrix_client.devices()

              # Fail loudly (like _get_user_id) instead of silently returning None
              if not isinstance(devices, DevicesResponse) or not devices.devices:
                  raise Exception(f"Could not get device ID: {devices}")

              device_id = devices.devices[0].id

          return device_id

      async def process_command(self, room: MatrixRoom, event: RoomMessageText):
          """Run the bot command named by the second word of ``event.body``.

          A missing or unknown command name falls back to ``COMMANDS[None]``.
          """
          self.logger.log(
              f"Received command {event.body} from {event.sender} in room {room.room_id}")

          words = event.body.split()
          command = words[1] if len(words) > 1 else None

          await COMMANDS.get(command, COMMANDS[None])(room, event, self)

      async def event_callback(self, room: MatrixRoom, event: Event):
          """Dispatch a room event to every handler in EVENT_CALLBACKS whose type matches.

          Errors are logged rather than raised. An error stops the remaining
          handlers for this event.
          """
          self.logger.log("Received event: " + str(event.event_id), "debug")

          try:
              for eventtype, callback in EVENT_CALLBACKS.items():
                  if isinstance(event, eventtype):
                      await callback(room, event, self)
          except Exception as e:
              self.logger.log(f"Error in event callback for {event.__class__}: {e}", "error")

      async def response_callback(self, response: Response):
          """Dispatch a client response to every handler in RESPONSE_CALLBACKS whose type matches.

          Errors raised by handlers propagate to the caller.
          """
          for response_type, callback in RESPONSE_CALLBACKS.items():
              if isinstance(response, response_type):
                  await callback(response, self)

      async def accept_pending_invites(self):
          """Accept all pending invites."""
          assert self.matrix_client, "Matrix client not set up"

          # Snapshot the room IDs: invited_rooms may be updated (e.g. by a
          # concurrent sync) while join() is awaited.
          for room_id in list(self.matrix_client.invited_rooms):
              await self.matrix_client.join(room_id)

      async def send_image(self, room: MatrixRoom, image: bytes, message: Optional[str] = None):
          """Upload an image and post it to a room as an ``m.image`` message.

          Args:
              room: Target room.
              image: Raw image bytes in a format Pillow can open.
              message: Optional body text (defaults to an empty string).

          Raises:
              KeyError: If Pillow has no MIME type for the detected image format.
              Exception: If the upload does not return a content URI.
          """
          self.logger.log(
              f"Sending image of size {len(image)} bytes to room {room.room_id}")

          bio = BytesIO(image)
          img = Image.open(bio)
          mime = Image.MIME[img.format]

          (width, height) = img.size
          self.logger.log(
              f"Uploading - Image size: {width}x{height} pixels, MIME type: {mime}")

          # Rewind after Pillow read from the buffer so the whole image is uploaded
          bio.seek(0)

          response, _ = await self.matrix_client.upload(
              bio,
              content_type=mime,
              filename="image",
              filesize=len(image)
          )

          content_uri = getattr(response, "content_uri", None)
          if not content_uri:
              raise Exception(f"Could not upload image: {response}")

          self.logger.log("Uploaded image - sending message...")

          content = {
              "body": message or "",
              "info": {
                  "mimetype": mime,
                  "size": len(image),
                  "w": width,
                  "h": height,
              },
              "msgtype": "m.image",
              "url": content_uri
          }

          status = await self.matrix_client.room_send(
              room.room_id,
              "m.room.message",
              content
          )

          self.logger.log(str(status), "debug")
          self.logger.log("Sent image")

      async def send_message(self, room: MatrixRoom, message: str, notice: bool = False):
          """Send a Markdown message to a room, encrypting it when the room is encrypted.

          Args:
              room: Target room.
              message: Message text. It is sent as the plain ``body`` and also
                  rendered to HTML (with fenced code blocks) as ``formatted_body``.
              notice: Send as ``m.notice`` instead of ``m.text``.

          Returns:
              The result of the room-send request (``RoomSendResponse`` on success).

          Raises:
              Exception: Any error while preparing encryption is logged and
                  re-raised, so nothing is sent unencrypted to an encrypted room.
          """
          markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
          formatted_body = markdowner.convert(message)

          msgtype = "m.notice" if notice else "m.text"

          msgcontent = {"msgtype": msgtype, "body": message,
                        "format": "org.matrix.custom.html", "formatted_body": formatted_body}

          event_type = "m.room.message"
          content = None

          if self.matrix_client.olm and room.encrypted:
              try:
                  if not room.members_synced:
                      # Member list not synced yet: fetch it before sharing the group session
                      await self.matrix_client.joined_members(room.room_id)

                  if self.matrix_client.olm.should_share_group_session(room.room_id):
                      try:
                          # A share for this room is already in progress: wait for it
                          event = self.matrix_client.sharing_session[room.room_id]
                          await event.wait()
                      except KeyError:
                          await self.matrix_client.share_group_session(
                              room.room_id,
                              ignore_unverified_devices=True,
                          )

                  event_type, content = self.matrix_client.encrypt(
                      room.room_id, "m.room.message", msgcontent)

              except Exception as e:
                  self.logger.log(
                      f"Error encrypting message: {e}", "error")
                  raise

          if not content:
              event_type = "m.room.message"
              content = msgcontent

          method, path, data = Api.room_send(
              self.matrix_client.access_token, room.room_id, event_type, content, uuid.uuid4()
          )

          return await self.matrix_client._send(RoomSendResponse, method, path, data, (room.room_id,))

      async def run(self):
          """Start the bot.

          Resolves a missing user/device ID, ensures a database exists (falling
          back to an in-memory DuckDB, in which case end-to-end encryption is
          disabled), runs migrations, and sets up the encryption store. It then
          performs an initial sync, registers callbacks, accepts pending
          invites and syncs until cancelled.
          """
          # Set up the Matrix client
          assert self.matrix_client, "Matrix client not set up"
          assert self.matrix_client.access_token, "Access token not set up"

          if not self.matrix_client.user_id:
              self.matrix_client.user_id = await self._get_user_id()

          if not self.matrix_client.device_id:
              self.matrix_client.device_id = await self._get_device_id()

          # Set up database
          in_memory = False
          if not self.database:
              self.logger.log(
                  "No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
              in_memory = True
              self.database = duckdb.connect(":memory:")

          self.logger.log("Running migrations...")
          before, after = migrate(self.database)
          if before != after:
              self.logger.log(f"Migrated from version {before} to {after}.")
          else:
              self.logger.log(f"Already at latest version {after}.")

          if in_memory:
              # No persistent database: run without end-to-end encryption,
              # so no crypto store or Olm machine is set up.
              self.matrix_client.config = AsyncClientConfig(
                  store_sync_tokens=True, encryption_enabled=False)
          else:
              matrix_store = DuckDBStore
              self.matrix_client.config = AsyncClientConfig(
                  store_sync_tokens=True, encryption_enabled=True, store=matrix_store)

              self.matrix_client.store = matrix_store(
                  self.matrix_client.user_id,
                  self.matrix_client.device_id,
                  self.database
              )

              self.matrix_client.olm = Olm(
                  self.matrix_client.user_id,
                  self.matrix_client.device_id,
                  self.matrix_client.store
              )

              self.matrix_client.encrypted_rooms = self.matrix_client.store.load_encrypted_rooms()

          # Run initial sync
          sync = await self.matrix_client.sync(timeout=30000)
          if isinstance(sync, SyncResponse):
              await self.response_callback(sync)
          else:
              self.logger.log(f"Initial sync failed, aborting: {sync}", "error")
              return

          # Set up callbacks
          self.matrix_client.add_event_callback(self.event_callback, Event)
          self.matrix_client.add_response_callback(self.response_callback, Response)

          # Accept pending invites
          self.logger.log("Accepting pending invites...")
          await self.accept_pending_invites()

          # Start syncing events
          self.logger.log("Starting sync loop...")
          try:
              await self.matrix_client.sync_forever(timeout=30000)
          finally:
              self.logger.log("Syncing one last time...")
              await self.matrix_client.sync(timeout=30000)

      async def process_query(self, room: MatrixRoom, event: RoomMessageText):
          """Answer ``event`` with a chat-completion reply in ``room``.

          Marks the event as read and shows a typing indicator while the reply
          is generated. The indicator is cleared on every exit path, including
          errors.

          Args:
              room: Room the query was sent in.
              event: The user message to answer.
          """
          await self.matrix_client.room_typing(room.room_id, True)
          await self.matrix_client.room_read_markers(room.room_id, event.event_id)

          try:
              await self._answer_query(room, event)
          finally:
              await self.matrix_client.room_typing(room.room_id, False)

      async def _answer_query(self, room: MatrixRoom, event: RoomMessageText):
          """Build the chat context for ``event``, query the chat API, post the reply and record token usage."""
          try:
              last_messages = await self._last_n_messages(room.room_id, self.max_messages)
          except Exception as e:
              self.logger.log(f"Error getting last messages: {e}", "error")
              await self.send_message(
                  room, "Something went wrong. Please try again.", True)
              return

          system_message = self.get_system_message(room)

          chat_messages = [{"role": "system", "content": system_message}]

          for message in last_messages:
              # The triggering event is appended last below, so skip it here
              if message.event_id == event.event_id:
                  continue
              role = "assistant" if message.sender == self.matrix_client.user_id else "user"
              chat_messages.append({"role": role, "content": message.body})

          chat_messages.append({"role": "user", "content": event.body})

          # Truncate messages to fit within the token limit
          truncated_messages = self._truncate(
              chat_messages, self.max_tokens - 1, system_message=system_message)

          try:
              response, tokens_used = await self.generate_chat_response(truncated_messages)
          except Exception as e:
              self.logger.log(f"Error generating response: {e}", "error")
              await self.send_message(
                  room, "Something went wrong. Please try again.", True)
              return

          if not response:
              # Send a notice to the room if there was an error
              self.logger.log("Didn't get a response from GPT API", "error")
              await self.send_message(
                  room, "Something went wrong. Please try again.", True)
              return

          self.logger.log(f"Sending response to room {room.room_id}...")

          # send_message renders the Markdown response to HTML
          message = await self.send_message(room, response)

          if not isinstance(message, RoomSendResponse):
              self.logger.log(f"Error sending response: {message}", "error")
              return

          if self.database:
              self.logger.log("Storing record of tokens used...")

              with self.database.cursor() as cursor:
                  cursor.execute(
                      "INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)",
                      (message.event_id, room.room_id, tokens_used, datetime.now()))
              self.database.commit()

      async def generate_chat_response(self, messages: List[Dict[str, str]]) -> Tuple[str, int]:
          """Generate a response to a chat message.

          Args:
              messages (List[Dict[str, str]]): A list of messages to use as context.

          Returns:
              Tuple[str, int]: The response text and the number of tokens used.
          """
          result = self.chat_api.generate_chat_response(messages)

          # NOTE(review): the chat API wrapper's signature is not visible here.
          # If its method is a coroutine function, await it so callers always
          # receive the (text, tokens) tuple instead of a coroutine.
          if inspect.isawaitable(result):
              result = await result

          return result

      def get_system_message(self, room: MatrixRoom | str | int) -> str:
          """Build the system message for a room.

          Uses the room's most recent custom entry from ``system_messages``.
          The default system message comes first (separated by a blank line)
          when there is no custom entry or ``force_system_message`` is set.

          Args:
              room: The room, or its ID. Matrix room IDs are strings; ``int``
                  is still accepted for backward compatibility.

          Returns:
              str: The combined system message, stripped of surrounding whitespace.
          """
          room_id = room if isinstance(room, (str, int)) else room.room_id

          row = None
          # Without a database (e.g. before run() sets one up) only the default applies
          if self.database:
              with self.database.cursor() as cur:
                  cur.execute(
                      "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
                      (room_id,)
                  )
                  row = cur.fetchone()

          parts = []
          if row is None or self.force_system_message:
              parts.append(self.default_system_message)
          if row is not None:
              parts.append(row[0])

          return "\n\n".join(parts).strip()

      def __del__(self):
          """Close the Matrix client and the database connection, if they are set up."""
          try:
              if self.matrix_client:
                  try:
                      loop = asyncio.get_running_loop()
                  except RuntimeError:
                      # No event loop running: drive the close to completion here
                      asyncio.run(self.matrix_client.close())
                  else:
                      # asyncio.run() cannot be used inside a running loop; schedule instead
                      loop.create_task(self.matrix_client.close())
          finally:
              # Close the database even if closing the Matrix client failed
              if self.database:
                  self.database.close()