# bot.py


import asyncio
import uuid
from configparser import ConfigParser
from datetime import datetime
from io import BytesIO
from typing import Optional, List, Dict, Tuple

import duckdb
import magic
import markdown2
import tiktoken
from PIL import Image
from nio import (
    AsyncClient,
    AsyncClientConfig,
    WhoamiResponse,
    DevicesResponse,
    Event,
    Response,
    MatrixRoom,
    Api,
    RoomMessagesError,
    MegolmEvent,
    GroupEncryptionError,
    EncryptionError,
    RoomMessageText,
    RoomSendResponse,
    SyncResponse,
    RoomMessageNotice,
    UploadResponse,
)
from nio.crypto import Olm

from .logging import Logger
from migrations import migrate
from callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
from commands import COMMANDS
from .store import DuckDBStore
from .openai import OpenAI
from .wolframalpha import WolframAlpha
  38. class GPTBot:
  39. # Default values
  40. database: Optional[duckdb.DuckDBPyConnection] = None
  41. default_room_name: str = "GPTBot" # Default name of rooms created by the bot
  42. default_system_message: str = "You are a helpful assistant."
  43. # Force default system message to be included even if a custom room message is set
  44. force_system_message: bool = False
  45. max_tokens: int = 3000 # Maximum number of input tokens
  46. max_messages: int = 30 # Maximum number of messages to consider as input
  47. matrix_client: Optional[AsyncClient] = None
  48. sync_token: Optional[str] = None
  49. logger: Optional[Logger] = Logger()
  50. chat_api: Optional[OpenAI] = None
  51. image_api: Optional[OpenAI] = None
  52. classification_api: Optional[OpenAI] = None
  53. operator: Optional[str] = None
  54. @classmethod
  55. def from_config(cls, config: ConfigParser):
  56. """Create a new GPTBot instance from a config file.
  57. Args:
  58. config (ConfigParser): ConfigParser instance with the bot's config.
  59. Returns:
  60. GPTBot: The new GPTBot instance.
  61. """
  62. # Create a new GPTBot instance
  63. bot = cls()
  64. # Set the database connection
  65. bot.database = duckdb.connect(
  66. config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None
  67. # Override default values
  68. if "GPTBot" in config:
  69. bot.operator = config["GPTBot"].get("Operator", bot.operator)
  70. bot.default_room_name = config["GPTBot"].get(
  71. "DefaultRoomName", bot.default_room_name)
  72. bot.default_system_message = config["GPTBot"].get(
  73. "SystemMessage", bot.default_system_message)
  74. bot.force_system_message = config["GPTBot"].getboolean(
  75. "ForceSystemMessage", bot.force_system_message)
  76. bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
  77. config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger)
  78. bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
  79. bot.max_messages = config["OpenAI"].getint(
  80. "MaxMessages", bot.max_messages)
  81. # Set up WolframAlpha
  82. if "WolframAlpha" in config:
  83. bot.calculation_api = WolframAlpha(
  84. config["WolframAlpha"]["APIKey"], bot.logger)
  85. # Set up the Matrix client
  86. assert "Matrix" in config, "Matrix config not found"
  87. homeserver = config["Matrix"]["Homeserver"]
  88. bot.matrix_client = AsyncClient(homeserver)
  89. bot.matrix_client.access_token = config["Matrix"]["AccessToken"]
  90. bot.matrix_client.user_id = config["Matrix"].get("UserID")
  91. bot.matrix_client.device_id = config["Matrix"].get("DeviceID")
  92. # Return the new GPTBot instance
  93. return bot
  94. async def _get_user_id(self) -> str:
  95. """Get the user ID of the bot from the whoami endpoint.
  96. Requires an access token to be set up.
  97. Returns:
  98. str: The user ID of the bot.
  99. """
  100. assert self.matrix_client, "Matrix client not set up"
  101. user_id = self.matrix_client.user_id
  102. if not user_id:
  103. assert self.matrix_client.access_token, "Access token not set up"
  104. response = await self.matrix_client.whoami()
  105. if isinstance(response, WhoamiResponse):
  106. user_id = response.user_id
  107. else:
  108. raise Exception(f"Could not get user ID: {response}")
  109. return user_id
  110. async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
  111. messages = []
  112. n = n or bot.max_messages
  113. room_id = room.room_id if isinstance(room, MatrixRoom) else room
  114. self.logger.log(
  115. f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...")
  116. response = await self.matrix_client.room_messages(
  117. room_id=room_id,
  118. start=self.sync_token,
  119. limit=2*n,
  120. )
  121. if isinstance(response, RoomMessagesError):
  122. raise Exception(
  123. f"Error fetching messages: {response.message} (status code {response.status_code})", "error")
  124. for event in response.chunk:
  125. if len(messages) >= n:
  126. break
  127. if isinstance(event, MegolmEvent):
  128. try:
  129. event = await self.matrix_client.decrypt_event(event)
  130. except (GroupEncryptionError, EncryptionError):
  131. self.logger.log(
  132. f"Could not decrypt message {event.event_id} in room {room_id}", "error")
  133. continue
  134. if isinstance(event, (RoomMessageText, RoomMessageNotice)):
  135. if event.body.startswith("!gptbot ignoreolder"):
  136. break
  137. if (not event.body.startswith("!")) or (event.body.startswith("!gptbot")):
  138. messages.append(event)
  139. self.logger.log(f"Found {len(messages)} messages (limit: {n})")
  140. # Reverse the list so that messages are in chronological order
  141. return messages[::-1]
  142. def _truncate(self, messages: list, max_tokens: Optional[int] = None,
  143. model: Optional[str] = None, system_message: Optional[str] = None):
  144. max_tokens = max_tokens or self.max_tokens
  145. model = model or self.chat_api.chat_model
  146. system_message = self.default_system_message if system_message is None else system_message
  147. encoding = tiktoken.encoding_for_model(model)
  148. total_tokens = 0
  149. system_message_tokens = 0 if not system_message else (
  150. len(encoding.encode(system_message)) + 1)
  151. if system_message_tokens > max_tokens:
  152. self.logger.log(
  153. f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
  154. return []
  155. total_tokens += system_message_tokens
  156. total_tokens = len(system_message) + 1
  157. truncated_messages = []
  158. for message in [messages[0]] + list(reversed(messages[1:])):
  159. content = message["content"]
  160. tokens = len(encoding.encode(content)) + 1
  161. if total_tokens + tokens > max_tokens:
  162. break
  163. total_tokens += tokens
  164. truncated_messages.append(message)
  165. return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
  166. async def _get_device_id(self) -> str:
  167. """Guess the device ID of the bot.
  168. Requires an access token to be set up.
  169. Returns:
  170. str: The guessed device ID.
  171. """
  172. assert self.matrix_client, "Matrix client not set up"
  173. device_id = self.matrix_client.device_id
  174. if not device_id:
  175. assert self.matrix_client.access_token, "Access token not set up"
  176. devices = await self.matrix_client.devices()
  177. if isinstance(devices, DevicesResponse):
  178. device_id = devices.devices[0].id
  179. return device_id
  180. async def process_command(self, room: MatrixRoom, event: RoomMessageText):
  181. self.logger.log(
  182. f"Received command {event.body} from {event.sender} in room {room.room_id}")
  183. command = event.body.split()[1] if event.body.split()[1:] else None
  184. await COMMANDS.get(command, COMMANDS[None])(room, event, self)
  185. def room_uses_classification(self, room: MatrixRoom | int) -> bool:
  186. """Check if a room uses classification.
  187. Args:
  188. room (MatrixRoom): The room to check.
  189. Returns:
  190. bool: Whether the room uses classification.
  191. """
  192. room_id = room.room_id if isinstance(room, MatrixRoom) else room
  193. with self.database.cursor() as cursor:
  194. cursor.execute(
  195. "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_classification"))
  196. result = cursor.fetchone()
  197. return False if not result else bool(int(result[0]))
  198. async def event_callback(self, room: MatrixRoom, event: Event):
  199. self.logger.log("Received event: " + str(event.event_id), "debug")
  200. try:
  201. for eventtype, callback in EVENT_CALLBACKS.items():
  202. if isinstance(event, eventtype):
  203. await callback(room, event, self)
  204. except Exception as e:
  205. self.logger.log(
  206. f"Error in event callback for {event.__class__}: {e}", "error")
  207. async def response_callback(self, response: Response):
  208. for response_type, callback in RESPONSE_CALLBACKS.items():
  209. if isinstance(response, response_type):
  210. await callback(response, self)
  211. async def accept_pending_invites(self):
  212. """Accept all pending invites."""
  213. assert self.matrix_client, "Matrix client not set up"
  214. invites = self.matrix_client.invited_rooms
  215. for invite in invites.keys():
  216. await self.matrix_client.join(invite)
  217. async def send_image(self, room: MatrixRoom, image: bytes, message: Optional[str] = None):
  218. self.logger.log(
  219. f"Sending image of size {len(image)} bytes to room {room.room_id}")
  220. bio = BytesIO(image)
  221. img = Image.open(bio)
  222. mime = Image.MIME[img.format]
  223. (width, height) = img.size
  224. self.logger.log(
  225. f"Uploading - Image size: {width}x{height} pixels, MIME type: {mime}")
  226. bio.seek(0)
  227. response, _ = await self.matrix_client.upload(
  228. bio,
  229. content_type=mime,
  230. filename="image",
  231. filesize=len(image)
  232. )
  233. self.logger.log("Uploaded image - sending message...")
  234. content = {
  235. "body": message or "",
  236. "info": {
  237. "mimetype": mime,
  238. "size": len(image),
  239. "w": width,
  240. "h": height,
  241. },
  242. "msgtype": "m.image",
  243. "url": response.content_uri
  244. }
  245. status = await self.matrix_client.room_send(
  246. room.room_id,
  247. "m.room.message",
  248. content
  249. )
  250. self.logger.log(str(status), "debug")
  251. self.logger.log("Sent image")
  252. async def send_message(self, room: MatrixRoom, message: str, notice: bool = False):
  253. markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
  254. formatted_body = markdowner.convert(message)
  255. msgtype = "m.notice" if notice else "m.text"
  256. msgcontent = {"msgtype": msgtype, "body": message,
  257. "format": "org.matrix.custom.html", "formatted_body": formatted_body}
  258. content = None
  259. if self.matrix_client.olm and room.encrypted:
  260. try:
  261. if not room.members_synced:
  262. responses = []
  263. responses.append(await self.matrix_client.joined_members(room.room_id))
  264. if self.matrix_client.olm.should_share_group_session(room.room_id):
  265. try:
  266. event = self.matrix_client.sharing_session[room.room_id]
  267. await event.wait()
  268. except KeyError:
  269. await self.matrix_client.share_group_session(
  270. room.room_id,
  271. ignore_unverified_devices=True,
  272. )
  273. if msgtype != "m.reaction":
  274. response = self.matrix_client.encrypt(
  275. room.room_id, "m.room.message", msgcontent)
  276. msgtype, content = response
  277. except Exception as e:
  278. self.logger.log(
  279. f"Error encrypting message: {e} - sending unencrypted", "error")
  280. raise
  281. if not content:
  282. msgtype = "m.room.message"
  283. content = msgcontent
  284. method, path, data = Api.room_send(
  285. self.matrix_client.access_token, room.room_id, msgtype, content, uuid.uuid4()
  286. )
  287. return await self.matrix_client._send(RoomSendResponse, method, path, data, (room.room_id,))
  288. def log_api_usage(self, message: Event | str, room: MatrixRoom | int, api: str, tokens: int):
  289. """Log API usage to the database.
  290. Args:
  291. message (Event): The event that triggered the API usage.
  292. room (MatrixRoom | int): The room the event was sent in.
  293. api (str): The API that was used.
  294. tokens (int): The number of tokens used.
  295. """
  296. if not self.database:
  297. return
  298. if isinstance(message, Event):
  299. message = message.event_id
  300. if isinstance(room, MatrixRoom):
  301. room = room.room_id
  302. self.database.execute(
  303. "INSERT INTO token_usage (message_id, room_id, tokens, api, timestamp) VALUES (?, ?, ?, ?, ?)",
  304. (message, room, tokens, api, datetime.now())
  305. )
  306. async def run(self):
  307. """Start the bot."""
  308. # Set up the Matrix client
  309. assert self.matrix_client, "Matrix client not set up"
  310. assert self.matrix_client.access_token, "Access token not set up"
  311. if not self.matrix_client.user_id:
  312. self.matrix_client.user_id = await self._get_user_id()
  313. if not self.matrix_client.device_id:
  314. self.matrix_client.device_id = await self._get_device_id()
  315. # Set up database
  316. IN_MEMORY = False
  317. if not self.database:
  318. self.logger.log(
  319. "No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
  320. IN_MEMORY = True
  321. self.database = DuckDBPyConnection(":memory:")
  322. self.logger.log("Running migrations...")
  323. before, after = migrate(self.database)
  324. if before != after:
  325. self.logger.log(f"Migrated from version {before} to {after}.")
  326. else:
  327. self.logger.log(f"Already at latest version {after}.")
  328. if IN_MEMORY:
  329. client_config = AsyncClientConfig(
  330. store_sync_tokens=True, encryption_enabled=False)
  331. else:
  332. matrix_store = DuckDBStore
  333. client_config = AsyncClientConfig(
  334. store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
  335. self.matrix_client.config = client_config
  336. self.matrix_client.store = matrix_store(
  337. self.matrix_client.user_id,
  338. self.matrix_client.device_id,
  339. self.database
  340. )
  341. self.matrix_client.olm = Olm(
  342. self.matrix_client.user_id,
  343. self.matrix_client.device_id,
  344. self.matrix_client.store
  345. )
  346. self.matrix_client.encrypted_rooms = self.matrix_client.store.load_encrypted_rooms()
  347. # Run initial sync
  348. sync = await self.matrix_client.sync(timeout=30000)
  349. if isinstance(sync, SyncResponse):
  350. await self.response_callback(sync)
  351. else:
  352. self.logger.log(f"Initial sync failed, aborting: {sync}", "error")
  353. return
  354. # Set up callbacks
  355. self.matrix_client.add_event_callback(self.event_callback, Event)
  356. self.matrix_client.add_response_callback(
  357. self.response_callback, Response)
  358. # Accept pending invites
  359. self.logger.log("Accepting pending invites...")
  360. await self.accept_pending_invites()
  361. # Start syncing events
  362. self.logger.log("Starting sync loop...")
  363. try:
  364. await self.matrix_client.sync_forever(timeout=30000)
  365. finally:
  366. self.logger.log("Syncing one last time...")
  367. await self.matrix_client.sync(timeout=30000)
  368. async def process_query(self, room: MatrixRoom, event: RoomMessageText, allow_classify: bool = True):
  369. await self.matrix_client.room_typing(room.room_id, True)
  370. await self.matrix_client.room_read_markers(room.room_id, event.event_id)
  371. if allow_classify and self.room_uses_classification(room):
  372. classification, tokens = self.classification_api.classify_message(event.body, room.room_id)
  373. self.log_api_usage(event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens)
  374. if not classification["type"] == "chat":
  375. event.body = f"!gptbot {classification['type']} {classification['prompt']}"
  376. await self.process_command(room, event)
  377. return
  378. try:
  379. last_messages = await self._last_n_messages(room.room_id, 20)
  380. except Exception as e:
  381. self.logger.log(f"Error getting last messages: {e}", "error")
  382. await self.send_message(
  383. room, "Something went wrong. Please try again.", True)
  384. return
  385. system_message = self.get_system_message(room)
  386. chat_messages = [{"role": "system", "content": system_message}]
  387. for message in last_messages:
  388. role = "assistant" if message.sender == self.matrix_client.user_id else "user"
  389. if not message.event_id == event.event_id:
  390. chat_messages.append({"role": role, "content": message.body})
  391. chat_messages.append({"role": "user", "content": event.body})
  392. # Truncate messages to fit within the token limit
  393. truncated_messages = self._truncate(
  394. chat_messages, self.max_tokens - 1, system_message=system_message)
  395. try:
  396. response, tokens_used = self.chat_api.generate_chat_response(
  397. chat_messages, user=room.room_id)
  398. except Exception as e:
  399. self.logger.log(f"Error generating response: {e}", "error")
  400. await self.send_message(
  401. room, "Something went wrong. Please try again.", True)
  402. return
  403. if response:
  404. self.log_api_usage(event, room, f"{self.chat_api.api_code}-{self.chat_api.chat_api}", tokens_used)
  405. self.logger.log(f"Sending response to room {room.room_id}...")
  406. # Convert markdown to HTML
  407. message = await self.send_message(room, response)
  408. else:
  409. # Send a notice to the room if there was an error
  410. self.logger.log("Didn't get a response from GPT API", "error")
  411. await send_message(
  412. room, "Something went wrong. Please try again.", True)
  413. await self.matrix_client.room_typing(room.room_id, False)
  414. def get_system_message(self, room: MatrixRoom | int) -> str:
  415. default = self.default_system_message
  416. if isinstance(room, int):
  417. room_id = room
  418. else:
  419. room_id = room.room_id
  420. with self.database.cursor() as cur:
  421. cur.execute(
  422. "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
  423. (room_id, "system_message")
  424. )
  425. system_message = cur.fetchone()
  426. complete = ((default if ((not system_message) or self.force_system_message) else "") + (
  427. "\n\n" + system_message[0] if system_message else "")).strip()
  428. return complete
  429. def __del__(self):
  430. """Close the bot."""
  431. if self.matrix_client:
  432. asyncio.run(self.matrix_client.close())
  433. if self.database:
  434. self.database.close()