# gptbot.py — Matrix chat bot backed by the OpenAI chat completion API
  1. import os
  2. import inspect
  3. import logging
  4. import signal
  5. import random
  6. import openai
  7. import asyncio
  8. import markdown2
  9. import tiktoken
  10. import duckdb
  11. from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent
  12. from nio.api import MessageDirection
  13. from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError
  14. from configparser import ConfigParser
  15. from datetime import datetime
  16. from argparse import ArgumentParser
  17. from typing import List, Dict, Union, Optional
  18. from commands import COMMANDS
  19. def logging(message: str, log_level: str = "info"):
  20. caller = inspect.currentframe().f_back.f_code.co_name
  21. timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S:%f")
  22. print(f"[{timestamp}] - {caller} - [{log_level.upper()}] {message}")
  23. CONTEXT = {
  24. "database": False,
  25. "default_room_name": "GPTBot",
  26. "system_message": "You are a helpful assistant.",
  27. "max_tokens": 3000,
  28. "max_messages": 20,
  29. "model": "gpt-3.5-turbo",
  30. "client": None,
  31. "sync_token": None,
  32. "logger": logging
  33. }
  34. async def gpt_query(messages: list, model: Optional[str] = None):
  35. model = model or CONTEXT["model"]
  36. logging(f"Querying GPT with {len(messages)} messages")
  37. try:
  38. response = openai.ChatCompletion.create(
  39. model=model,
  40. messages=messages
  41. )
  42. result_text = response.choices[0].message['content']
  43. tokens_used = response.usage["total_tokens"]
  44. logging(f"Used {tokens_used} tokens")
  45. return result_text, tokens_used
  46. except Exception as e:
  47. logging(f"Error during GPT API call: {e}", "error")
  48. return None, 0
  49. async def fetch_last_n_messages(room_id: str, n: Optional[int] = None,
  50. client: Optional[AsyncClient] = None, sync_token: Optional[str] = None):
  51. messages = []
  52. n = n or CONTEXT["max_messages"]
  53. client = client or CONTEXT["client"]
  54. sync_token = sync_token or CONTEXT["sync_token"]
  55. logging(
  56. f"Fetching last {2*n} messages from room {room_id} (starting at {sync_token})...")
  57. response = await client.room_messages(
  58. room_id=room_id,
  59. start=sync_token,
  60. limit=2*n,
  61. )
  62. if isinstance(response, RoomMessagesError):
  63. logging(
  64. f"Error fetching messages: {response.message} (status code {response.status_code})", "error")
  65. return []
  66. for event in response.chunk:
  67. if len(messages) >= n:
  68. break
  69. if isinstance(event, RoomMessageText):
  70. if event.body.startswith("!gptbot ignoreolder"):
  71. break
  72. if not event.body.startswith("!"):
  73. messages.append(event)
  74. logging(f"Found {len(messages)} messages (limit: {n})")
  75. # Reverse the list so that messages are in chronological order
  76. return messages[::-1]
  77. def truncate_messages_to_fit_tokens(messages: list, max_tokens: Optional[int] = None,
  78. model: Optional[str] = None, system_message: Optional[str] = None):
  79. max_tokens = max_tokens or CONTEXT["max_tokens"]
  80. model = model or CONTEXT["model"]
  81. system_message = system_message or CONTEXT["system_message"]
  82. encoding = tiktoken.encoding_for_model(model)
  83. total_tokens = 0
  84. system_message_tokens = len(encoding.encode(system_message)) + 1
  85. if system_message_tokens > max_tokens:
  86. logging(
  87. f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
  88. return []
  89. total_tokens += system_message_tokens
  90. total_tokens = len(system_message) + 1
  91. truncated_messages = []
  92. for message in [messages[0]] + list(reversed(messages[1:])):
  93. content = message["content"]
  94. tokens = len(encoding.encode(content)) + 1
  95. if total_tokens + tokens > max_tokens:
  96. break
  97. total_tokens += tokens
  98. truncated_messages.append(message)
  99. return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))
  100. async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
  101. client = kwargs.get("client") or CONTEXT["client"]
  102. database = kwargs.get("database") or CONTEXT["database"]
  103. system_message = kwargs.get("system_message") or CONTEXT["system_message"]
  104. max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
  105. await client.room_typing(room.room_id, True)
  106. await client.room_read_markers(room.room_id, event.event_id)
  107. last_messages = await fetch_last_n_messages(room.room_id, 20)
  108. chat_messages = [{"role": "system", "content": system_message}]
  109. for message in last_messages:
  110. role = "assistant" if message.sender == client.user_id else "user"
  111. if not message.event_id == event.event_id:
  112. chat_messages.append({"role": role, "content": message.body})
  113. chat_messages.append({"role": "user", "content": event.body})
  114. # Truncate messages to fit within the token limit
  115. truncated_messages = truncate_messages_to_fit_tokens(
  116. chat_messages, max_tokens - 1)
  117. response, tokens_used = await gpt_query(truncated_messages)
  118. if response:
  119. logging(f"Sending response to room {room.room_id}...")
  120. # Convert markdown to HTML
  121. markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
  122. formatted_body = markdowner.convert(response)
  123. message = await client.room_send(
  124. room.room_id, "m.room.message",
  125. {"msgtype": "m.text", "body": response,
  126. "format": "org.matrix.custom.html", "formatted_body": formatted_body}
  127. )
  128. if database:
  129. logging("Logging tokens used...")
  130. with database.cursor() as cursor:
  131. cursor.execute(
  132. "INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)",
  133. (message.event_id, room.room_id, tokens_used, datetime.now()))
  134. database.commit()
  135. else:
  136. # Send a notice to the room if there was an error
  137. logging("Error during GPT API call - sending notice to room")
  138. await client.room_send(
  139. room.room_id, "m.room.message", {
  140. "msgtype": "m.notice", "body": "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later."}
  141. )
  142. print("No response from GPT API")
  143. await client.room_typing(room.room_id, False)
  144. async def process_command(room: MatrixRoom, event: RoomMessageText, context: Optional[dict] = None):
  145. context = context or CONTEXT
  146. logging(
  147. f"Received command {event.body} from {event.sender} in room {room.room_id}")
  148. command = event.body.split()[1] if event.body.split()[1:] else None
  149. await COMMANDS.get(command, COMMANDS[None])(room, event, context)
  150. async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs):
  151. context = kwargs.get("context") or CONTEXT
  152. logging(f"Received message from {event.sender} in room {room.room_id}")
  153. if event.sender == context["client"].user_id:
  154. logging("Message is from bot itself - ignoring")
  155. elif event.body.startswith("!gptbot"):
  156. await process_command(room, event)
  157. elif event.body.startswith("!"):
  158. logging("Might be a command, but not for this bot - ignoring")
  159. else:
  160. await process_query(room, event, context=context)
  161. async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs):
  162. client = kwargs.get("client") or CONTEXT["client"]
  163. logging(f"Received invite to room {room.room_id} - joining...")
  164. await client.join(room.room_id)
  165. await client.room_send(
  166. room.room_id,
  167. "m.room.message",
  168. {"msgtype": "m.text",
  169. "body": "Hello! I'm a helpful assistant. How can I help you today?"}
  170. )
  171. async def accept_pending_invites(client: Optional[AsyncClient] = None):
  172. client = client or CONTEXT["client"]
  173. logging("Accepting pending invites...")
  174. for room_id in list(client.invited_rooms.keys()):
  175. logging(f"Joining room {room_id}...")
  176. await client.join(room_id)
  177. await client.room_send(
  178. room_id,
  179. "m.room.message",
  180. {"msgtype": "m.text",
  181. "body": "Hello! I'm a helpful assistant. How can I help you today?"}
  182. )
  183. async def sync_cb(response, write_global: bool = True):
  184. logging(
  185. f"Sync response received (next batch: {response.next_batch})", "debug")
  186. SYNC_TOKEN = response.next_batch
  187. if write_global:
  188. global CONTEXT
  189. CONTEXT["sync_token"] = SYNC_TOKEN
  190. async def main(client: Optional[AsyncClient] = None):
  191. client = client or CONTEXT["client"]
  192. if not client.user_id:
  193. whoami = await client.whoami()
  194. client.user_id = whoami.user_id
  195. try:
  196. assert client.user_id
  197. except AssertionError:
  198. logging(
  199. "Failed to get user ID - check your access token or try setting it manually", "critical")
  200. await client.close()
  201. return
  202. logging("Starting bot...")
  203. client.add_response_callback(sync_cb, SyncResponse)
  204. logging("Syncing...")
  205. await client.sync(timeout=30000)
  206. client.add_event_callback(message_callback, RoomMessageText)
  207. client.add_event_callback(room_invite_callback, InviteEvent)
  208. await accept_pending_invites() # Accept pending invites
  209. logging("Bot started")
  210. try:
  211. # Continue syncing events
  212. await client.sync_forever(timeout=30000)
  213. finally:
  214. logging("Syncing one last time...")
  215. await client.sync(timeout=30000)
  216. await client.close() # Properly close the aiohttp client session
  217. logging("Bot stopped")
  218. def initialize_database(path: os.PathLike):
  219. logging("Initializing database...")
  220. database = duckdb.connect(path)
  221. with database.cursor() as cursor:
  222. # Get the latest migration ID if the migrations table exists
  223. try:
  224. cursor.execute(
  225. """
  226. SELECT MAX(id) FROM migrations
  227. """
  228. )
  229. latest_migration = int(cursor.fetchone()[0])
  230. except:
  231. latest_migration = 0
  232. # Version 1
  233. if latest_migration < 1:
  234. cursor.execute(
  235. """
  236. CREATE TABLE IF NOT EXISTS token_usage (
  237. message_id TEXT PRIMARY KEY,
  238. room_id TEXT NOT NULL,
  239. tokens INTEGER NOT NULL,
  240. timestamp TIMESTAMP NOT NULL
  241. )
  242. """
  243. )
  244. cursor.execute(
  245. """
  246. CREATE TABLE IF NOT EXISTS migrations (
  247. id INTEGER NOT NULL,
  248. timestamp TIMESTAMP NOT NULL
  249. )
  250. """
  251. )
  252. cursor.execute(
  253. "INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
  254. (datetime.now(),)
  255. )
  256. database.commit()
  257. return database
  258. if __name__ == "__main__":
  259. # Parse command line arguments
  260. parser = ArgumentParser()
  261. parser.add_argument(
  262. "--config", help="Path to config file (default: config.ini in working directory)", default="config.ini")
  263. args = parser.parse_args()
  264. # Read config file
  265. config = ConfigParser()
  266. config.read(args.config)
  267. # Set up Matrix client
  268. try:
  269. assert "Matrix" in config
  270. assert "Homeserver" in config["Matrix"]
  271. assert "AccessToken" in config["Matrix"]
  272. except:
  273. logging("Matrix config not found or incomplete", "critical")
  274. exit(1)
  275. CONTEXT["client"] = AsyncClient(config["Matrix"]["Homeserver"])
  276. CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
  277. CONTEXT["client"].user_id = config["Matrix"].get("UserID")
  278. # Set up GPT API
  279. try:
  280. assert "OpenAI" in config
  281. assert "APIKey" in config["OpenAI"]
  282. except:
  283. logging("OpenAI config not found or incomplete", "critical")
  284. exit(1)
  285. openai.api_key = config["OpenAI"]["APIKey"]
  286. if "Model" in config["OpenAI"]:
  287. CONTEXT["model"] = config["OpenAI"]["Model"]
  288. if "MaxTokens" in config["OpenAI"]:
  289. CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
  290. if "MaxMessages" in config["OpenAI"]:
  291. CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
  292. # Set up database
  293. if "Database" in config and config["Database"].get("Path"):
  294. CONTEXT["database"] = initialize_database(config["Database"]["Path"])
  295. # Listen for SIGTERM
  296. def sigterm_handler(_signo, _stack_frame):
  297. logging("Received SIGTERM - exiting...")
  298. exit()
  299. signal.signal(signal.SIGTERM, sigterm_handler)
  300. # Start bot loop
  301. try:
  302. asyncio.run(main())
  303. except KeyboardInterrupt:
  304. logging("Received KeyboardInterrupt - exiting...")
  305. except SystemExit:
  306. logging("Received SIGTERM - exiting...")
  307. finally:
  308. if CONTEXT["database"]:
  309. CONTEXT["database"].close()