bot.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852
  1. import markdown2
  2. import duckdb
  3. import tiktoken
  4. import asyncio
  5. import functools
  6. from PIL import Image
  7. from nio import (
  8. AsyncClient,
  9. AsyncClientConfig,
  10. WhoamiResponse,
  11. DevicesResponse,
  12. Event,
  13. Response,
  14. MatrixRoom,
  15. Api,
  16. RoomMessagesError,
  17. MegolmEvent,
  18. GroupEncryptionError,
  19. EncryptionError,
  20. RoomMessageText,
  21. RoomSendResponse,
  22. SyncResponse,
  23. RoomMessageNotice,
  24. JoinError,
  25. RoomLeaveError,
  26. RoomSendError,
  27. RoomVisibility,
  28. RoomCreateError,
  29. )
  30. from nio.crypto import Olm
  31. from typing import Optional, List
  32. from configparser import ConfigParser
  33. from datetime import datetime
  34. from io import BytesIO
  35. from pathlib import Path
  36. import uuid
  37. import traceback
  38. import json
  39. from .logging import Logger
  40. from migrations import migrate
  41. from callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
  42. from commands import COMMANDS
  43. from .store import DuckDBStore
  44. from .openai import OpenAI
  45. from .wolframalpha import WolframAlpha
  46. from .trackingmore import TrackingMore
  47. class GPTBot:
  48. # Default values
  49. database: Optional[duckdb.DuckDBPyConnection] = None
  50. # Default name of rooms created by the bot
  51. display_name = default_room_name = "GPTBot"
  52. default_system_message: str = "You are a helpful assistant."
  53. # Force default system message to be included even if a custom room message is set
  54. force_system_message: bool = False
  55. max_tokens: int = 3000 # Maximum number of input tokens
  56. max_messages: int = 30 # Maximum number of messages to consider as input
  57. matrix_client: Optional[AsyncClient] = None
  58. sync_token: Optional[str] = None
  59. logger: Optional[Logger] = Logger()
  60. chat_api: Optional[OpenAI] = None
  61. image_api: Optional[OpenAI] = None
  62. classification_api: Optional[OpenAI] = None
  63. parcel_api: Optional[TrackingMore] = None
  64. operator: Optional[str] = None
  65. room_ignore_list: List[str] = [] # List of rooms to ignore invites from
  66. debug: bool = False
  67. logo: Optional[Image.Image] = None
  68. logo_uri: Optional[str] = None
  69. allowed_users: List[str] = []
  70. @classmethod
  71. def from_config(cls, config: ConfigParser):
  72. """Create a new GPTBot instance from a config file.
  73. Args:
  74. config (ConfigParser): ConfigParser instance with the bot's config.
  75. Returns:
  76. GPTBot: The new GPTBot instance.
  77. """
  78. # Create a new GPTBot instance
  79. bot = cls()
  80. # Set the database connection
  81. bot.database = duckdb.connect(
  82. config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None
  83. # Override default values
  84. if "GPTBot" in config:
  85. bot.operator = config["GPTBot"].get("Operator", bot.operator)
  86. bot.default_room_name = config["GPTBot"].get(
  87. "DefaultRoomName", bot.default_room_name)
  88. bot.default_system_message = config["GPTBot"].get(
  89. "SystemMessage", bot.default_system_message)
  90. bot.force_system_message = config["GPTBot"].getboolean(
  91. "ForceSystemMessage", bot.force_system_message)
  92. bot.debug = config["GPTBot"].getboolean("Debug", bot.debug)
  93. logo_path = config["GPTBot"].get("Logo", str(
  94. Path(__file__).parent.parent / "assets/logo.png"))
  95. bot.logger.log(f"Loading logo from {logo_path}")
  96. if Path(logo_path).exists() and Path(logo_path).is_file():
  97. bot.logo = Image.open(logo_path)
  98. bot.display_name = config["GPTBot"].get(
  99. "DisplayName", bot.display_name)
  100. if "AllowedUsers" in config["GPTBot"]:
  101. bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"])
  102. bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
  103. config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger)
  104. bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
  105. bot.max_messages = config["OpenAI"].getint(
  106. "MaxMessages", bot.max_messages)
  107. # Set up WolframAlpha
  108. if "WolframAlpha" in config:
  109. bot.calculation_api = WolframAlpha(
  110. config["WolframAlpha"]["APIKey"], bot.logger)
  111. # Set up TrackingMore
  112. if "TrackingMore" in config:
  113. bot.parcel_api = TrackingMore(
  114. config["TrackingMore"]["APIKey"], bot.logger)
  115. # Set up the Matrix client
  116. assert "Matrix" in config, "Matrix config not found"
  117. homeserver = config["Matrix"]["Homeserver"]
  118. bot.matrix_client = AsyncClient(homeserver)
  119. bot.matrix_client.access_token = config["Matrix"]["AccessToken"]
  120. bot.matrix_client.user_id = config["Matrix"].get("UserID")
  121. bot.matrix_client.device_id = config["Matrix"].get("DeviceID")
  122. # Return the new GPTBot instance
  123. return bot
  124. async def _get_user_id(self) -> str:
  125. """Get the user ID of the bot from the whoami endpoint.
  126. Requires an access token to be set up.
  127. Returns:
  128. str: The user ID of the bot.
  129. """
  130. assert self.matrix_client, "Matrix client not set up"
  131. user_id = self.matrix_client.user_id
  132. if not user_id:
  133. assert self.matrix_client.access_token, "Access token not set up"
  134. response = await self.matrix_client.whoami()
  135. if isinstance(response, WhoamiResponse):
  136. user_id = response.user_id
  137. else:
  138. raise Exception(f"Could not get user ID: {response}")
  139. return user_id
  140. async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int]):
  141. messages = []
  142. n = n or self.max_messages
  143. room_id = room.room_id if isinstance(room, MatrixRoom) else room
  144. self.logger.log(
  145. f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...")
  146. response = await self.matrix_client.room_messages(
  147. room_id=room_id,
  148. start=self.sync_token,
  149. limit=2*n,
  150. )
  151. if isinstance(response, RoomMessagesError):
  152. raise Exception(
  153. f"Error fetching messages: {response.message} (status code {response.status_code})", "error")
  154. for event in response.chunk:
  155. if len(messages) >= n:
  156. break
  157. if isinstance(event, MegolmEvent):
  158. try:
  159. event = await self.matrix_client.decrypt_event(event)
  160. except (GroupEncryptionError, EncryptionError):
  161. self.logger.log(
  162. f"Could not decrypt message {event.event_id} in room {room_id}", "error")
  163. continue
  164. if isinstance(event, (RoomMessageText, RoomMessageNotice)):
  165. if event.body.startswith("!gptbot ignoreolder"):
  166. break
  167. if (not event.body.startswith("!")) or (event.body.startswith("!gptbot")):
  168. messages.append(event)
  169. self.logger.log(f"Found {len(messages)} messages (limit: {n})")
  170. # Reverse the list so that messages are in chronological order
  171. return messages[::-1]
  172. def _truncate(self, messages: list, max_tokens: Optional[int] = None,
  173. model: Optional[str] = None, system_message: Optional[str] = None):
  174. max_tokens = max_tokens or self.max_tokens
  175. model = model or self.chat_api.chat_model
  176. system_message = self.default_system_message if system_message is None else system_message
  177. encoding = tiktoken.encoding_for_model(model)
  178. total_tokens = 0
  179. system_message_tokens = 0 if not system_message else (
  180. len(encoding.encode(system_message)) + 1)
  181. if system_message_tokens > max_tokens:
  182. self.logger.log(
  183. f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
  184. return []
  185. total_tokens += system_message_tokens
  186. total_tokens = len(system_message) + 1
  187. truncated_messages = []
  188. for message in [messages[0]] + list(reversed(messages[1:])):
  189. content = message["content"]
  190. tokens = len(encoding.encode(content)) + 1
  191. if total_tokens + tokens > max_tokens:
  192. break
  193. total_tokens += tokens
  194. truncated_messages.append(message)
  195. return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
  196. async def _get_device_id(self) -> str:
  197. """Guess the device ID of the bot.
  198. Requires an access token to be set up.
  199. Returns:
  200. str: The guessed device ID.
  201. """
  202. assert self.matrix_client, "Matrix client not set up"
  203. device_id = self.matrix_client.device_id
  204. if not device_id:
  205. assert self.matrix_client.access_token, "Access token not set up"
  206. devices = await self.matrix_client.devices()
  207. if isinstance(devices, DevicesResponse):
  208. device_id = devices.devices[0].id
  209. return device_id
  210. async def process_command(self, room: MatrixRoom, event: RoomMessageText):
  211. """Process a command. Called from the event_callback() method.
  212. Delegates to the appropriate command handler.
  213. Args:
  214. room (MatrixRoom): The room the command was sent in.
  215. event (RoomMessageText): The event containing the command.
  216. """
  217. self.logger.log(
  218. f"Received command {event.body} from {event.sender} in room {room.room_id}")
  219. command = event.body.split()[1] if event.body.split()[1:] else None
  220. await COMMANDS.get(command, COMMANDS[None])(room, event, self)
  221. def room_uses_classification(self, room: MatrixRoom | str) -> bool:
  222. """Check if a room uses classification.
  223. Args:
  224. room (MatrixRoom | str): The room to check.
  225. Returns:
  226. bool: Whether the room uses classification.
  227. """
  228. room_id = room.room_id if isinstance(room, MatrixRoom) else room
  229. with self.database.cursor() as cursor:
  230. cursor.execute(
  231. "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_classification"))
  232. result = cursor.fetchone()
  233. return False if not result else bool(int(result[0]))
  234. async def _event_callback(self, room: MatrixRoom, event: Event):
  235. self.logger.log("Received event: " + str(event.event_id), "debug")
  236. try:
  237. for eventtype, callback in EVENT_CALLBACKS.items():
  238. if isinstance(event, eventtype):
  239. await callback(room, event, self)
  240. except Exception as e:
  241. self.logger.log(
  242. f"Error in event callback for {event.__class__}: {e}", "error")
  243. if self.debug:
  244. await self.send_message(room, f"Error: {e}\n\n```\n{traceback.format_exc()}\n```", True)
  245. def user_is_allowed(self, user_id: str) -> bool:
  246. """Check if a user is allowed to use the bot.
  247. Args:
  248. user_id (str): The user ID to check.
  249. Returns:
  250. bool: Whether the user is allowed to use the bot.
  251. """
  252. return (
  253. user_id in self.allowed_users or
  254. f"*:{user_id.split(':')[1]}" in self.allowed_users or
  255. f"@*:{user_id.split(':')[1]}" in self.allowed_users
  256. ) if self.allowed_users else True
  257. async def event_callback(self, room: MatrixRoom, event: Event):
  258. """Callback for events.
  259. Args:
  260. room (MatrixRoom): The room the event was sent in.
  261. event (Event): The event.
  262. """
  263. if event.sender == self.matrix_client.user_id:
  264. return
  265. if not self.user_is_allowed(event.sender):
  266. if len(room.users) == 2:
  267. await self.matrix_client.room_send(
  268. room.room_id,
  269. "m.room.message",
  270. {
  271. "msgtype": "m.notice",
  272. "body": f"You are not allowed to use this bot. Please contact {self.operator} for more information."
  273. }
  274. )
  275. return
  276. task = asyncio.create_task(self._event_callback(room, event))
  277. def room_uses_timing(self, room: MatrixRoom):
  278. """Check if a room uses timing.
  279. Args:
  280. room (MatrixRoom): The room to check.
  281. Returns:
  282. bool: Whether the room uses timing.
  283. """
  284. room_id = room.room_id
  285. with self.database.cursor() as cursor:
  286. cursor.execute(
  287. "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_timing"))
  288. result = cursor.fetchone()
  289. return False if not result else bool(int(result[0]))
  290. async def _response_callback(self, response: Response):
  291. for response_type, callback in RESPONSE_CALLBACKS.items():
  292. if isinstance(response, response_type):
  293. await callback(response, self)
  294. async def response_callback(self, response: Response):
  295. task = asyncio.create_task(self._response_callback(response))
  296. async def accept_pending_invites(self):
  297. """Accept all pending invites."""
  298. assert self.matrix_client, "Matrix client not set up"
  299. invites = self.matrix_client.invited_rooms
  300. for invite in invites.keys():
  301. if invite in self.room_ignore_list:
  302. self.logger.log(
  303. f"Ignoring invite to room {invite} (room is in ignore list)")
  304. continue
  305. self.logger.log(f"Accepting invite to room {invite}")
  306. response = await self.matrix_client.join(invite)
  307. if isinstance(response, JoinError):
  308. self.logger.log(
  309. f"Error joining room {invite}: {response.message}. Not trying again.", "error")
  310. leave_response = await self.matrix_client.room_leave(invite)
  311. if isinstance(leave_response, RoomLeaveError):
  312. self.logger.log(
  313. f"Error leaving room {invite}: {leave_response.message}", "error")
  314. self.room_ignore_list.append(invite)
  315. async def upload_file(self, file: bytes, filename: str = "file", mime: str = "application/octet-stream") -> str:
  316. """Upload a file to the homeserver.
  317. Args:
  318. file (bytes): The file to upload.
  319. filename (str, optional): The name of the file. Defaults to "file".
  320. mime (str, optional): The MIME type of the file. Defaults to "application/octet-stream".
  321. Returns:
  322. str: The MXC URI of the uploaded file.
  323. """
  324. bio = BytesIO(file)
  325. bio.seek(0)
  326. response, _ = await self.matrix_client.upload(
  327. bio,
  328. content_type=mime,
  329. filename=filename,
  330. filesize=len(file)
  331. )
  332. return response.content_uri
  333. async def send_image(self, room: MatrixRoom, image: bytes, message: Optional[str] = None):
  334. """Send an image to a room.
  335. Args:
  336. room (MatrixRoom): The room to send the image to.
  337. image (bytes): The image to send.
  338. message (str, optional): The message to send with the image. Defaults to None.
  339. """
  340. self.logger.log(
  341. f"Sending image of size {len(image)} bytes to room {room.room_id}")
  342. bio = BytesIO(image)
  343. img = Image.open(bio)
  344. mime = Image.MIME[img.format]
  345. (width, height) = img.size
  346. self.logger.log(
  347. f"Uploading - Image size: {width}x{height} pixels, MIME type: {mime}")
  348. content_uri = await self.upload_file(image, "image", mime)
  349. self.logger.log("Uploaded image - sending message...")
  350. content = {
  351. "body": message or "",
  352. "info": {
  353. "mimetype": mime,
  354. "size": len(image),
  355. "w": width,
  356. "h": height,
  357. },
  358. "msgtype": "m.image",
  359. "url": content_uri
  360. }
  361. status = await self.matrix_client.room_send(
  362. room.room_id,
  363. "m.room.message",
  364. content
  365. )
  366. self.logger.log(str(status), "debug")
  367. self.logger.log("Sent image")
  368. async def send_message(self, room: MatrixRoom | str, message: str, notice: bool = False):
  369. """Send a message to a room.
  370. Args:
  371. room (MatrixRoom): The room to send the message to.
  372. message (str): The message to send.
  373. notice (bool): Whether to send the message as a notice. Defaults to False.
  374. """
  375. if isinstance(room, str):
  376. room = self.matrix_client.rooms[room]
  377. markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
  378. formatted_body = markdowner.convert(message)
  379. msgtype = "m.notice" if notice else "m.text"
  380. msgcontent = {"msgtype": msgtype, "body": message,
  381. "format": "org.matrix.custom.html", "formatted_body": formatted_body}
  382. content = None
  383. if self.matrix_client.olm and room.encrypted:
  384. try:
  385. if not room.members_synced:
  386. responses = []
  387. responses.append(await self.matrix_client.joined_members(room.room_id))
  388. if self.matrix_client.olm.should_share_group_session(room.room_id):
  389. try:
  390. event = self.matrix_client.sharing_session[room.room_id]
  391. await event.wait()
  392. except KeyError:
  393. await self.matrix_client.share_group_session(
  394. room.room_id,
  395. ignore_unverified_devices=True,
  396. )
  397. if msgtype != "m.reaction":
  398. response = self.matrix_client.encrypt(
  399. room.room_id, "m.room.message", msgcontent)
  400. msgtype, content = response
  401. except Exception as e:
  402. self.logger.log(
  403. f"Error encrypting message: {e} - sending unencrypted", "error")
  404. raise
  405. if not content:
  406. msgtype = "m.room.message"
  407. content = msgcontent
  408. method, path, data = Api.room_send(
  409. self.matrix_client.access_token, room.room_id, msgtype, content, uuid.uuid4()
  410. )
  411. response = await self.matrix_client._send(RoomSendResponse, method, path, data, (room.room_id,))
  412. if isinstance(response, RoomSendError):
  413. self.logger.log(
  414. f"Error sending message: {response.message}", "error")
  415. return
  416. def log_api_usage(self, message: Event | str, room: MatrixRoom | str, api: str, tokens: int):
  417. """Log API usage to the database.
  418. Args:
  419. message (Event): The event that triggered the API usage.
  420. room (MatrixRoom | str): The room the event was sent in.
  421. api (str): The API that was used.
  422. tokens (int): The number of tokens used.
  423. """
  424. if not self.database:
  425. return
  426. if isinstance(message, Event):
  427. message = message.event_id
  428. if isinstance(room, MatrixRoom):
  429. room = room.room_id
  430. self.database.execute(
  431. "INSERT INTO token_usage (message_id, room_id, tokens, api, timestamp) VALUES (?, ?, ?, ?, ?)",
  432. (message, room, tokens, api, datetime.now())
  433. )
  434. async def run(self):
  435. """Start the bot."""
  436. # Set up the Matrix client
  437. assert self.matrix_client, "Matrix client not set up"
  438. assert self.matrix_client.access_token, "Access token not set up"
  439. if not self.matrix_client.user_id:
  440. self.matrix_client.user_id = await self._get_user_id()
  441. if not self.matrix_client.device_id:
  442. self.matrix_client.device_id = await self._get_device_id()
  443. # Set up database
  444. IN_MEMORY = False
  445. if not self.database:
  446. self.logger.log(
  447. "No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
  448. IN_MEMORY = True
  449. self.database = duckdb.DuckDBPyConnection(":memory:")
  450. self.logger.log("Running migrations...")
  451. before, after = migrate(self.database)
  452. if before != after:
  453. self.logger.log(f"Migrated from version {before} to {after}.")
  454. else:
  455. self.logger.log(f"Already at latest version {after}.")
  456. if IN_MEMORY:
  457. client_config = AsyncClientConfig(
  458. store_sync_tokens=True, encryption_enabled=False)
  459. else:
  460. matrix_store = DuckDBStore
  461. client_config = AsyncClientConfig(
  462. store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
  463. self.matrix_client.config = client_config
  464. self.matrix_client.store = matrix_store(
  465. self.matrix_client.user_id,
  466. self.matrix_client.device_id,
  467. self.database
  468. )
  469. self.matrix_client.olm = Olm(
  470. self.matrix_client.user_id,
  471. self.matrix_client.device_id,
  472. self.matrix_client.store
  473. )
  474. self.matrix_client.encrypted_rooms = self.matrix_client.store.load_encrypted_rooms()
  475. # Run initial sync (now includes joining rooms)
  476. sync = await self.matrix_client.sync(timeout=30000)
  477. if isinstance(sync, SyncResponse):
  478. await self.response_callback(sync)
  479. else:
  480. self.logger.log(f"Initial sync failed, aborting: {sync}", "error")
  481. return
  482. # Set up callbacks
  483. self.matrix_client.add_event_callback(self.event_callback, Event)
  484. self.matrix_client.add_response_callback(
  485. self.response_callback, Response)
  486. # Set custom name / logo
  487. if self.display_name:
  488. self.logger.log(f"Setting display name to {self.display_name}")
  489. await self.matrix_client.set_displayname(self.display_name)
  490. if self.logo:
  491. self.logger.log("Setting avatar...")
  492. logo_bio = BytesIO()
  493. self.logo.save(logo_bio, format=self.logo.format)
  494. uri = await self.upload_file(logo_bio.getvalue(), "logo", Image.MIME[self.logo.format])
  495. self.logo_uri = uri
  496. asyncio.create_task(self.matrix_client.set_avatar(uri))
  497. for room in self.matrix_client.rooms.keys():
  498. self.logger.log(f"Setting avatar for {room}...", "debug")
  499. asyncio.create_task(self.matrix_client.room_put_state(room, "m.room.avatar", {
  500. "url": uri
  501. }, ""))
  502. # Start syncing events
  503. self.logger.log("Starting sync loop...")
  504. try:
  505. await self.matrix_client.sync_forever(timeout=30000)
  506. finally:
  507. self.logger.log("Syncing one last time...")
  508. await self.matrix_client.sync(timeout=30000)
  509. async def create_space(self, name, visibility=RoomVisibility.private) -> str:
  510. """Create a space.
  511. Args:
  512. name (str): The name of the space.
  513. visibility (RoomVisibility, optional): The visibility of the space. Defaults to RoomVisibility.private.
  514. Returns:
  515. MatrixRoom: The created space.
  516. """
  517. response = await self.matrix_client.room_create(
  518. name=name, visibility=visibility, space=True)
  519. if isinstance(response, RoomCreateError):
  520. self.logger.log(
  521. f"Error creating space: {response.message}", "error")
  522. return
  523. return response.room_id
  524. async def add_rooms_to_space(self, space: MatrixRoom | str, rooms: List[MatrixRoom | str]):
  525. """Add rooms to a space.
  526. Args:
  527. space (MatrixRoom | str): The space to add the rooms to.
  528. rooms (List[MatrixRoom | str]): The rooms to add to the space.
  529. """
  530. if isinstance(space, MatrixRoom):
  531. space = space.room_id
  532. for room in rooms:
  533. if isinstance(room, MatrixRoom):
  534. room = room.room_id
  535. if space == room:
  536. self.logger.log(
  537. f"Refusing to add {room} to itself", "warning")
  538. continue
  539. self.logger.log(f"Adding {room} to {space}...")
  540. await self.matrix_client.room_put_state(space, "m.space.child", {
  541. "via": [room.split(":")[1], space.split(":")[1]],
  542. }, room)
  543. await self.matrix_client.room_put_state(room, "m.room.parent", {
  544. "via": [space.split(":")[1], room.split(":")[1]],
  545. "canonical": True
  546. }, space)
  547. def respond_to_room_messages(self, room: MatrixRoom | str) -> bool:
  548. """Check whether the bot should respond to all messages sent in a room.
  549. Args:
  550. room (MatrixRoom | str): The room to check.
  551. Returns:
  552. bool: Whether the bot should respond to all messages sent in the room.
  553. """
  554. if isinstance(room, MatrixRoom):
  555. room = room.room_id
  556. with self.database.cursor() as cursor:
  557. cursor.execute(
  558. "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room, "always_reply"))
  559. result = cursor.fetchone()
  560. return True if not result else bool(int(result[0]))
  561. async def process_query(self, room: MatrixRoom, event: RoomMessageText, from_chat_command: bool = False):
  562. """Process a query message. Generates a response and sends it to the room.
  563. Args:
  564. room (MatrixRoom): The room the message was sent in.
  565. event (RoomMessageText): The event that triggered the query.
  566. from_chat_command (bool, optional): Whether the query was sent via the `!gptbot chat` command. Defaults to False.
  567. """
  568. if not (from_chat_command or self.respond_to_room_messages(room) or self.matrix_client.user_id in event.body):
  569. return
  570. await self.matrix_client.room_typing(room.room_id, True)
  571. await self.matrix_client.room_read_markers(room.room_id, event.event_id)
  572. if (not from_chat_command) and self.room_uses_classification(room):
  573. try:
  574. classification, tokens = await self.classification_api.classify_message(
  575. event.body, room.room_id)
  576. except Exception as e:
  577. self.logger.log(f"Error classifying message: {e}", "error")
  578. await self.send_message(
  579. room, "Something went wrong. Please try again.", True)
  580. return
  581. self.log_api_usage(
  582. event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens)
  583. if not classification["type"] == "chat":
  584. event.body = f"!gptbot {classification['type']} {classification['prompt']}"
  585. await self.process_command(room, event)
  586. return
  587. try:
  588. last_messages = await self._last_n_messages(room.room_id, 20)
  589. except Exception as e:
  590. self.logger.log(f"Error getting last messages: {e}", "error")
  591. await self.send_message(
  592. room, "Something went wrong. Please try again.", True)
  593. return
  594. system_message = self.get_system_message(room)
  595. chat_messages = [{"role": "system", "content": system_message}]
  596. for message in last_messages:
  597. role = "assistant" if message.sender == self.matrix_client.user_id else "user"
  598. if not message.event_id == event.event_id:
  599. chat_messages.append({"role": role, "content": message.body})
  600. chat_messages.append({"role": "user", "content": event.body})
  601. # Truncate messages to fit within the token limit
  602. truncated_messages = self._truncate(
  603. chat_messages, self.max_tokens - 1, system_message=system_message)
  604. try:
  605. response, tokens_used = await self.chat_api.generate_chat_response(
  606. chat_messages, user=room.room_id)
  607. except Exception as e:
  608. self.logger.log(f"Error generating response: {e}", "error")
  609. await self.send_message(
  610. room, "Something went wrong. Please try again.", True)
  611. return
  612. if response:
  613. self.log_api_usage(
  614. event, room, f"{self.chat_api.api_code}-{self.chat_api.chat_api}", tokens_used)
  615. self.logger.log(f"Sending response to room {room.room_id}...")
  616. # Convert markdown to HTML
  617. message = await self.send_message(room, response)
  618. else:
  619. # Send a notice to the room if there was an error
  620. self.logger.log("Didn't get a response from GPT API", "error")
  621. await self.send_message(
  622. room, "Something went wrong. Please try again.", True)
  623. await self.matrix_client.room_typing(room.room_id, False)
  624. def get_system_message(self, room: MatrixRoom | str) -> str:
  625. """Get the system message for a room.
  626. Args:
  627. room (MatrixRoom | str): The room to get the system message for.
  628. Returns:
  629. str: The system message.
  630. """
  631. default = self.default_system_message
  632. if isinstance(room, str):
  633. room_id = room
  634. else:
  635. room_id = room.room_id
  636. with self.database.cursor() as cur:
  637. cur.execute(
  638. "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
  639. (room_id, "system_message")
  640. )
  641. system_message = cur.fetchone()
  642. complete = ((default if ((not system_message) or self.force_system_message) else "") + (
  643. "\n\n" + system_message[0] if system_message else "")).strip()
  644. return complete
  645. def __del__(self):
  646. """Close the bot."""
  647. if self.matrix_client:
  648. asyncio.run(self.matrix_client.close())
  649. if self.database:
  650. self.database.close()