# gptbot.py
import os
import inspect
import signal
import uuid
import asyncio

import openai
import markdown2
import tiktoken
import duckdb

from nio import (AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent,
                 AsyncClientConfig, MegolmEvent, GroupEncryptionError,
                 EncryptionError, Api)
from nio.responses import (RoomMessagesError, SyncResponse, WhoamiResponse,
                           JoinResponse, RoomSendResponse)
from nio.crypto import Olm

from configparser import ConfigParser
from datetime import datetime
from argparse import ArgumentParser
from typing import Optional

from commands import COMMANDS
from classes import DuckDBStore
from migrations import MIGRATIONS


def logging(message: str, log_level: str = "info"):
    # Minimal stdout logger, used throughout in place of the stdlib logging module
    caller = inspect.currentframe().f_back.f_code.co_name
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")
    print(f"[{timestamp}] - {caller} - [{log_level.upper()}] {message}")


# Global defaults - overridden by the config file in init()
CONTEXT = {
    "database": False,
    "default_room_name": "GPTBot",
    "system_message": "You are a helpful assistant.",
    "force_system_message": False,
    "max_tokens": 3000,
    "max_messages": 20,
    "model": "gpt-3.5-turbo",
    "client": None,
    "sync_token": None,
    "logger": logging
}


async def gpt_query(messages: list, model: Optional[str] = None):
    model = model or CONTEXT["model"]
    logging(f"Querying GPT with {len(messages)} messages")
    logging(messages, "debug")
    try:
        # Use the async variant so the API call doesn't block the event loop
        response = await openai.ChatCompletion.acreate(
            model=model,
            messages=messages
        )
        result_text = response.choices[0].message["content"]
        tokens_used = response.usage["total_tokens"]
        logging(f"Used {tokens_used} tokens")
        return result_text, tokens_used
    except Exception as e:
        logging(f"Error during GPT API call: {e}", "error")
        return None, 0
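
# Example (hypothetical): querying the model directly with a minimal
# conversation. Assumes openai.api_key has already been set, as done in init().
#
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"},
#   ]
#   response, tokens = await gpt_query(messages, model="gpt-3.5-turbo")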


async def fetch_last_n_messages(room_id: str, n: Optional[int] = None,
                                client: Optional[AsyncClient] = None,
                                sync_token: Optional[str] = None):
    messages = []
    n = n or CONTEXT["max_messages"]
    client = client or CONTEXT["client"]
    sync_token = sync_token or CONTEXT["sync_token"]
    logging(
        f"Fetching last {2*n} messages from room {room_id} (starting at {sync_token})...")
    response = await client.room_messages(
        room_id=room_id,
        start=sync_token,
        limit=2*n,
    )
    if isinstance(response, RoomMessagesError):
        logging(
            f"Error fetching messages: {response.message} (status code {response.status_code})", "error")
        return []
    for event in response.chunk:
        if len(messages) >= n:
            break
        if isinstance(event, MegolmEvent):
            try:
                event = await client.decrypt_event(event)
            except (GroupEncryptionError, EncryptionError):
                logging(
                    f"Could not decrypt message {event.event_id} in room {room_id}", "error")
                continue
        if isinstance(event, RoomMessageText):
            if event.body.startswith("!gptbot ignoreolder"):
                break
            if not event.body.startswith("!"):
                messages.append(event)
    logging(f"Found {len(messages)} messages (limit: {n})")
    # Reverse the list so that messages are in chronological order
    return messages[::-1]


def truncate_messages_to_fit_tokens(messages: list, max_tokens: Optional[int] = None,
                                    model: Optional[str] = None,
                                    system_message: Optional[str] = None):
    max_tokens = max_tokens or CONTEXT["max_tokens"]
    model = model or CONTEXT["model"]
    system_message = system_message or CONTEXT["system_message"]
    encoding = tiktoken.encoding_for_model(model)
    total_tokens = 0
    system_message_tokens = len(encoding.encode(system_message)) + 1
    if system_message_tokens > max_tokens:
        logging(
            f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
        return []
    total_tokens += system_message_tokens
    truncated_messages = []
    # Keep the first (system) message, then walk the rest newest-first so the
    # most recent messages survive truncation
    for message in [messages[0]] + list(reversed(messages[1:])):
        content = message["content"]
        tokens = len(encoding.encode(content)) + 1
        if total_tokens + tokens > max_tokens:
            break
        total_tokens += tokens
        truncated_messages.append(message)
    if not truncated_messages:
        return []
    # Restore chronological order before returning
    return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
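
# Example (hypothetical): once the token budget is exhausted, older turns are
# dropped first; the system message and the newest turns survive, and the
# result comes back in chronological order again.
#
#   msgs = [{"role": "system", "content": "Be brief."},
#           {"role": "user", "content": "First question"},
#           {"role": "assistant", "content": "First answer"},
#           {"role": "user", "content": "Second question"}]
#   kept = truncate_messages_to_fit_tokens(msgs, max_tokens=20)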


async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
    client = kwargs.get("client") or CONTEXT["client"]
    database = kwargs.get("database") or CONTEXT["database"]
    max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
    system_message = kwargs.get("system_message") or CONTEXT["system_message"]
    force_system_message = kwargs.get("force_system_message") or CONTEXT["force_system_message"]
    await client.room_typing(room.room_id, True)
    await client.room_read_markers(room.room_id, event.event_id)
    last_messages = await fetch_last_n_messages(room.room_id, 20)
    system_message = get_system_message(room, {
        "database": database,
        "system_message": system_message,
        "force_system_message": force_system_message,
    })
    chat_messages = [{"role": "system", "content": system_message}]
    for message in last_messages:
        role = "assistant" if message.sender == client.user_id else "user"
        if message.event_id != event.event_id:
            chat_messages.append({"role": role, "content": message.body})
    chat_messages.append({"role": "user", "content": event.body})
    # Truncate messages to fit within the token limit
    truncated_messages = truncate_messages_to_fit_tokens(
        chat_messages, max_tokens - 1, system_message=system_message)
    response, tokens_used = await gpt_query(truncated_messages)
    if response:
        logging(f"Sending response to room {room.room_id}...")
        # Convert markdown to HTML
        message = await send_message(room, response)
        if database:
            logging("Logging tokens used...")
            with database.cursor() as cursor:
                cursor.execute(
                    "INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)",
                    (message.event_id, room.room_id, tokens_used, datetime.now()))
            database.commit()
    else:
        # Send a notice to the room if there was an error
        logging("Error during GPT API call - sending notice to room", "error")
        await send_message(
            room, "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later.", True)
    await client.room_typing(room.room_id, False)


async def process_command(room: MatrixRoom, event: RoomMessageText, context: Optional[dict] = None):
    context = context or CONTEXT
    logging(
        f"Received command {event.body} from {event.sender} in room {room.room_id}")
    command = event.body.split()[1] if event.body.split()[1:] else None
    message = await COMMANDS.get(command, COMMANDS[None])(room, event, context)
    if message:
        room_id, event, content = message
        await context["client"].joined_rooms()
        await send_message(context["client"].rooms[room_id], content["body"],
                           content["msgtype"] == "m.notice", context["client"])


def get_system_message(room: MatrixRoom, context: Optional[dict] = None) -> str:
    context = context or CONTEXT
    default = context.get("system_message")
    with context["database"].cursor() as cur:
        cur.execute(
            "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
            (room.room_id,)
        )
        system_message = cur.fetchone()
    # Use the default if there is no room-specific message (or if the config
    # forces it), then append the room-specific message if one exists
    complete = ((default if ((not system_message) or context["force_system_message"]) else "") + (
        "\n\n" + system_message[0] if system_message else "")).strip()
    return complete


async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
    context = kwargs.get("context") or CONTEXT
    logging(f"Received message from {event.sender} in room {room.room_id}")
    if isinstance(event, MegolmEvent):
        try:
            event = await context["client"].decrypt_event(event)
        except Exception as e:
            try:
                logging("Requesting new encryption keys...")
                await context["client"].request_room_key(event)
            except Exception:
                pass
            logging(f"Error decrypting message: {e}", "error")
            await send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True, context["client"])
            return
    if event.sender == context["client"].user_id:
        logging("Message is from bot itself - ignoring")
    elif event.body.startswith("!gptbot"):
        await process_command(room, event)
    elif event.body.startswith("!"):
        logging("Might be a command, but not for this bot - ignoring")
    else:
        await process_query(room, event, context=context)


async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs):
    client: AsyncClient = kwargs.get("client") or CONTEXT["client"]
    if room.room_id in client.rooms:
        logging(f"Already in room {room.room_id} - ignoring invite")
        return
    logging(f"Received invite to room {room.room_id} - joining...")
    response = await client.join(room.room_id)
    if isinstance(response, JoinResponse):
        await send_message(
            room, "Hello! I'm a helpful assistant. How can I help you today?", client=client)
    else:
        logging(f"Error joining room {room.room_id}: {response}", "error")


async def send_message(room: MatrixRoom, message: str, notice: bool = False,
                       client: Optional[AsyncClient] = None):
    client = client or CONTEXT["client"]
    markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
    formatted_body = markdowner.convert(message)
    msgtype = "m.notice" if notice else "m.text"
    msgcontent = {"msgtype": msgtype, "body": message,
                  "format": "org.matrix.custom.html", "formatted_body": formatted_body}
    content = None
    if client.olm and room.encrypted:
        try:
            if not room.members_synced:
                await client.joined_members(room.room_id)
            if client.olm.should_share_group_session(room.room_id):
                try:
                    event = client.sharing_session[room.room_id]
                    await event.wait()
                except KeyError:
                    await client.share_group_session(
                        room.room_id,
                        ignore_unverified_devices=True,
                    )
            if msgtype != "m.reaction":
                response = client.encrypt(
                    room.room_id, "m.room.message", msgcontent)
                msgtype, content = response
        except Exception as e:
            # Fall through to the unencrypted path below
            logging(
                f"Error encrypting message: {e} - sending unencrypted", "error")
    if not content:
        msgtype = "m.room.message"
        content = msgcontent
    method, path, data = Api.room_send(
        client.access_token, room.room_id, msgtype, content, uuid.uuid4()
    )
    return await client._send(RoomSendResponse, method, path, data, (room.room_id,))
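
# Example (hypothetical room ID): sending a regular message vs. a notice to a
# room the bot has already joined.
#
#   room = CONTEXT["client"].rooms["!someroom:example.com"]
#   await send_message(room, "Hello, **world**!")          # sent as m.text
#   await send_message(room, "Maintenance at 5pm.", True)  # sent as m.notice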


async def accept_pending_invites(client: Optional[AsyncClient] = None):
    client = client or CONTEXT["client"]
    logging("Accepting pending invites...")
    for room_id in list(client.invited_rooms.keys()):
        logging(f"Joining room {room_id}...")
        response = await client.join(room_id)
        if isinstance(response, JoinResponse):
            logging(response, "debug")
            await client.joined_rooms()
            await send_message(
                client.rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?",
                client=client)
        else:
            logging(f"Error joining room {room_id}: {response}", "error")


async def sync_cb(response, write_global: bool = True):
    logging(
        f"Sync response received (next batch: {response.next_batch})", "debug")
    if write_global:
        CONTEXT["sync_token"] = response.next_batch


async def test_callback(room: MatrixRoom, event: Event, **kwargs):
    logging(
        f"Received event {event.__class__.__name__} in room {room.room_id}", "debug")


async def init(config: ConfigParser):
    # Set up Matrix client
    try:
        assert "Matrix" in config
        assert "Homeserver" in config["Matrix"]
        assert "AccessToken" in config["Matrix"]
    except AssertionError:
        logging("Matrix config not found or incomplete", "critical")
        exit(1)
    homeserver = config["Matrix"]["Homeserver"]
    access_token = config["Matrix"]["AccessToken"]
    device_id, user_id = await get_device_id(access_token, homeserver)
    device_id = config["Matrix"].get("DeviceID", device_id)
    user_id = config["Matrix"].get("UserID", user_id)
    # Set up database
    database = None
    if "Database" in config and config["Database"].get("Path"):
        database = CONTEXT["database"] = initialize_database(
            config["Database"]["Path"])
        matrix_store = DuckDBStore
        client_config = AsyncClientConfig(
            store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
    else:
        client_config = AsyncClientConfig(
            store_sync_tokens=True, encryption_enabled=False)
    client = AsyncClient(homeserver, config=client_config)
    # Set credentials before the encryption setup, which needs them
    client.access_token = access_token
    client.user_id = user_id
    client.device_id = device_id
    if client.config.encryption_enabled:
        client.store = client.config.store(
            user_id,
            device_id,
            database
        )
        assert client.store
        client.olm = Olm(client.user_id, client.device_id, client.store)
        client.encrypted_rooms = client.store.load_encrypted_rooms()
    CONTEXT["client"] = client
    # Set up GPT API
    try:
        assert "OpenAI" in config
        assert "APIKey" in config["OpenAI"]
    except AssertionError:
        logging("OpenAI config not found or incomplete", "critical")
        exit(1)
    openai.api_key = config["OpenAI"]["APIKey"]
    if "Model" in config["OpenAI"]:
        CONTEXT["model"] = config["OpenAI"]["Model"]
    if "MaxTokens" in config["OpenAI"]:
        CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
    if "MaxMessages" in config["OpenAI"]:
        CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
    # Override defaults with config
    if "GPTBot" in config:
        if "SystemMessage" in config["GPTBot"]:
            CONTEXT["system_message"] = config["GPTBot"]["SystemMessage"]
        if "DefaultRoomName" in config["GPTBot"]:
            CONTEXT["default_room_name"] = config["GPTBot"]["DefaultRoomName"]
        if "ForceSystemMessage" in config["GPTBot"]:
            CONTEXT["force_system_message"] = config["GPTBot"].getboolean(
                "ForceSystemMessage")


async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
    if not client and not CONTEXT.get("client"):
        await init(config)
    client = client or CONTEXT["client"]
    try:
        assert client.user_id
    except AssertionError:
        logging(
            "Failed to get user ID - check your access token or try setting it manually", "critical")
        await client.close()
        return

    # Listen for SIGTERM
    def sigterm_handler(_signo, _stack_frame):
        logging("Received SIGTERM - exiting...")
        exit()
    signal.signal(signal.SIGTERM, sigterm_handler)
    logging("Starting bot...")
    client.add_response_callback(sync_cb, SyncResponse)
    logging("Syncing...")
    await client.sync(timeout=30000)
    client.add_event_callback(message_callback, RoomMessageText)
    client.add_event_callback(message_callback, MegolmEvent)
    client.add_event_callback(room_invite_callback, InviteEvent)
    client.add_event_callback(test_callback, Event)
    await accept_pending_invites()  # Accept pending invites
    logging("Bot started")
    try:
        # Continue syncing events
        await client.sync_forever(timeout=30000)
    finally:
        logging("Syncing one last time...")
        await client.sync(timeout=30000)
        await client.close()  # Properly close the aiohttp client session
        logging("Bot stopped")


def initialize_database(path: os.PathLike):
    logging("Initializing database...")
    conn = duckdb.connect(path)
    with conn.cursor() as cursor:
        # Get the latest migration ID if the migrations table exists
        try:
            cursor.execute(
                """
                SELECT MAX(id) FROM migrations
                """
            )
            latest_migration = int(cursor.fetchone()[0])
        except Exception:
            latest_migration = 0
        # Apply any migrations newer than the latest one already applied
        for migration, function in MIGRATIONS.items():
            if latest_migration < migration:
                logging(f"Running migration {migration}...")
                function(conn)
                latest_migration = migration
    return conn
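
# Sketch (assumption): based on how it is used above, MIGRATIONS, imported
# from migrations.py, maps an integer migration ID to a function that applies
# that migration to the connection, e.g.:
#
#   def migration_1(conn):
#       with conn.cursor() as cur:
#           cur.execute("CREATE TABLE IF NOT EXISTS migrations (id INTEGER)")
#
#   MIGRATIONS = {1: migration_1}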


async def get_device_id(access_token, homeserver):
    client = AsyncClient(homeserver)
    client.access_token = access_token
    logging(f"Obtaining device ID for access token {access_token}...", "debug")
    response = await client.whoami()
    if isinstance(response, WhoamiResponse):
        logging(f"Authenticated as {response.user_id}.")
        user_id = response.user_id
        devices = await client.devices()
        device_id = devices.devices[0].id
        await client.close()
        return device_id, user_id
    else:
        logging(f"Failed to obtain device ID: {response}", "error")
        await client.close()
        return None, None


if __name__ == "__main__":
    # Parse command line arguments
    parser = ArgumentParser()
    parser.add_argument(
        "--config", help="Path to config file (default: config.ini in working directory)", default="config.ini")
    args = parser.parse_args()
    # Read config file
    config = ConfigParser()
    config.read(args.config)
    # Start bot loop
    try:
        asyncio.run(main(config))
    except KeyboardInterrupt:
        logging("Received KeyboardInterrupt - exiting...")
    except SystemExit:
        logging("Received SIGTERM - exiting...")
    finally:
        if CONTEXT["database"]:
            CONTEXT["database"].close()