123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481 |
- import openai
- import markdown2
- import duckdb
- import tiktoken
- import asyncio
- from nio import (
- AsyncClient,
- AsyncClientConfig,
- WhoamiResponse,
- DevicesResponse,
- Event,
- Response,
- MatrixRoom,
- Api,
- RoomMessagesError,
- MegolmEvent,
- GroupEncryptionError,
- EncryptionError,
- RoomMessageText,
- RoomSendResponse,
- SyncResponse
- )
- from nio.crypto import Olm
- from typing import Optional, List, Dict
- from configparser import ConfigParser
- from datetime import datetime
- import uuid
- from .logging import Logger
- from migrations import migrate
- from callbacks import RESPONSE_CALLBACKS, EVENT_CALLBACKS
- from commands import COMMANDS
- from .store import DuckDBStore
class GPTBot:
    """Matrix bot that answers room messages through the OpenAI chat API.

    The bot keeps a DuckDB connection (used for migrations, the Matrix
    store, per-room system messages and token accounting) and a matrix-nio
    ``AsyncClient``. Instances are normally built with ``from_config`` and
    started with ``run``.
    """

    # Default values
    database: Optional[duckdb.DuckDBPyConnection] = None
    default_room_name: str = "GPTBot"  # Default name of rooms created by the bot
    default_system_message: str = "You are a helpful assistant."
    # Force default system message to be included even if a custom room message is set
    force_system_message: bool = False
    max_tokens: int = 3000  # Maximum number of input tokens
    max_messages: int = 30  # Maximum number of messages to consider as input
    model: str = "gpt-3.5-turbo"  # OpenAI chat model to use
    matrix_client: Optional[AsyncClient] = None
    sync_token: Optional[str] = None
    logger: Optional[Logger] = Logger()
    openai_api_key: Optional[str] = None

    @classmethod
    def from_config(cls, config: ConfigParser):
        """Create a new GPTBot instance from a config file.

        Args:
            config (ConfigParser): ConfigParser instance with the bot's config.
                The ``OpenAI`` section (with ``APIKey``) and the ``Matrix``
                section (with ``Homeserver`` and ``AccessToken``) are
                required; ``Database`` and ``GPTBot`` are optional.

        Returns:
            GPTBot: The new GPTBot instance.

        Raises:
            KeyError: If a required section or option is missing.
            AssertionError: If the ``Matrix`` section is missing.
        """
        # Create a new GPTBot instance
        bot = cls()

        # Set the database connection (left as None when no path is configured)
        if "Database" in config and "Path" in config["Database"]:
            bot.database = duckdb.connect(config["Database"]["Path"])
        else:
            bot.database = None

        # Override default values
        if "GPTBot" in config:
            bot_section = config["GPTBot"]
            bot.default_room_name = bot_section.get(
                "DefaultRoomName", bot.default_room_name)
            bot.default_system_message = bot_section.get(
                "SystemMessage", bot.default_system_message)
            bot.force_system_message = bot_section.getboolean(
                "ForceSystemMessage", bot.force_system_message)

        openai_section = config["OpenAI"]
        bot.max_tokens = openai_section.getint("MaxTokens", bot.max_tokens)
        bot.max_messages = openai_section.getint(
            "MaxMessages", bot.max_messages)
        bot.model = openai_section.get("Model", bot.model)
        bot.openai_api_key = openai_section["APIKey"]

        # Set up the Matrix client
        assert "Matrix" in config, "Matrix config not found"
        matrix_section = config["Matrix"]
        bot.matrix_client = AsyncClient(matrix_section["Homeserver"])
        bot.matrix_client.access_token = matrix_section["AccessToken"]
        bot.matrix_client.user_id = matrix_section.get("UserID")
        bot.matrix_client.device_id = matrix_section.get("DeviceID")

        # Return the new GPTBot instance
        return bot

    async def _get_user_id(self) -> str:
        """Get the user ID of the bot from the whoami endpoint.

        Requires an access token to be set up. The whoami endpoint is only
        queried when the client does not already have a user ID.

        Returns:
            str: The user ID of the bot.

        Raises:
            Exception: If the whoami request does not return a WhoamiResponse.
        """
        assert self.matrix_client, "Matrix client not set up"

        user_id = self.matrix_client.user_id

        if not user_id:
            assert self.matrix_client.access_token, "Access token not set up"

            response = await self.matrix_client.whoami()

            if isinstance(response, WhoamiResponse):
                user_id = response.user_id
            else:
                raise Exception(f"Could not get user ID: {response}")

        return user_id

    async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int] = None):
        """Fetch up to ``n`` recent text messages from a room.

        Requests ``2 * n`` events starting at ``self.sync_token`` and keeps
        text messages that do not start with ``!``. Collection stops at the
        first ``!gptbot ignoreolder`` message encountered.

        Args:
            room (str | MatrixRoom): The room, or its room ID.
            n (Optional[int]): Maximum number of messages to return. Falls
                back to ``self.max_messages`` when omitted or falsy.

        Returns:
            list: The collected message events in chronological order.

        Raises:
            Exception: If the homeserver returns a RoomMessagesError.
        """
        messages = []
        # The original referenced an undefined name ``bot`` here.
        n = n or self.max_messages
        room_id = room.room_id if isinstance(room, MatrixRoom) else room

        self.logger.log(
            f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...")

        response = await self.matrix_client.room_messages(
            room_id=room_id,
            start=self.sync_token,
            limit=2*n,
        )

        if isinstance(response, RoomMessagesError):
            raise Exception(
                f"Error fetching messages: {response.message} (status code {response.status_code})")

        for event in response.chunk:
            if len(messages) >= n:
                break

            if isinstance(event, MegolmEvent):
                try:
                    # decrypt_event is a synchronous method in matrix-nio,
                    # so it must not be awaited.
                    event = self.matrix_client.decrypt_event(event)
                except (GroupEncryptionError, EncryptionError):
                    self.logger.log(
                        f"Could not decrypt message {event.event_id} in room {room_id}", "error")
                    continue

            if isinstance(event, RoomMessageText):
                if event.body.startswith("!gptbot ignoreolder"):
                    break
                if not event.body.startswith("!"):
                    messages.append(event)

        self.logger.log(f"Found {len(messages)} messages (limit: {n})")

        # Reverse the list so that messages are in chronological order
        return messages[::-1]

    def _truncate(self, messages: list, max_tokens: Optional[int] = None,
                  model: Optional[str] = None, system_message: Optional[str] = None):
        """Drop the oldest messages until the conversation fits a token budget.

        The first message is always considered first, then the remaining
        messages from newest to oldest; iteration stops at the first message
        that would exceed the budget. Each message costs its encoded length
        plus one token.

        Args:
            messages (list): Chat messages (dicts with a ``"content"`` key).
            max_tokens (Optional[int]): Token budget; defaults to
                ``self.max_tokens``.
            model (Optional[str]): Model whose tokenizer is used; defaults to
                ``self.model``.
            system_message (Optional[str]): System message whose tokens are
                reserved up front; defaults to ``self.default_system_message``.

        Returns:
            list: The first message followed by the retained messages in
            their original order, or an empty list if nothing fits.
        """
        max_tokens = max_tokens or self.max_tokens
        model = model or self.model
        system_message = self.default_system_message if system_message is None else system_message

        encoding = tiktoken.encoding_for_model(model)

        system_message_tokens = 0 if not system_message else (
            len(encoding.encode(system_message)) + 1)

        if system_message_tokens > max_tokens:
            self.logger.log(
                f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
            return []

        if not messages:
            return []

        # Reserve room for the system message, measured in tokens. The
        # original overwrote this with the character length of the message.
        # NOTE(review): when messages[0] is itself the system message it is
        # counted again in the loop below, so the budget errs on the safe side.
        total_tokens = system_message_tokens

        truncated_messages = []

        for message in [messages[0]] + list(reversed(messages[1:])):
            tokens = len(encoding.encode(message["content"])) + 1
            if total_tokens + tokens > max_tokens:
                break
            total_tokens += tokens
            truncated_messages.append(message)

        if not truncated_messages:
            return []

        return [truncated_messages[0]] + list(reversed(truncated_messages[1:]))

    async def _get_device_id(self) -> Optional[str]:
        """Guess the device ID of the bot.

        Requires an access token to be set up. When the client has no device
        ID, the first device returned by the devices endpoint is used.

        Returns:
            Optional[str]: The guessed device ID, or None if none was found.
        """
        assert self.matrix_client, "Matrix client not set up"

        device_id = self.matrix_client.device_id

        if not device_id:
            assert self.matrix_client.access_token, "Access token not set up"

            devices = await self.matrix_client.devices()

            # Guard against an empty device list instead of raising IndexError.
            if isinstance(devices, DevicesResponse) and devices.devices:
                device_id = devices.devices[0].id

        return device_id

    async def process_command(self, room: MatrixRoom, event: RoomMessageText):
        """Dispatch a command message to its handler in ``COMMANDS``.

        The command name is the second whitespace-separated word of the
        message body; unknown or missing names use the ``COMMANDS[None]``
        handler.

        Args:
            room (MatrixRoom): The room the command was sent in.
            event (RoomMessageText): The message event containing the command.
        """
        self.logger.log(
            f"Received command {event.body} from {event.sender} in room {room.room_id}")

        words = event.body.split()
        command = words[1] if len(words) > 1 else None

        await COMMANDS.get(command, COMMANDS[None])(room, event, self)

    async def event_callback(self, room: MatrixRoom, event: Event):
        """Run every callback in ``EVENT_CALLBACKS`` whose type matches the event.

        Args:
            room (MatrixRoom): The room the event belongs to.
            event (Event): The received event.
        """
        self.logger.log("Received event: " + str(event), "debug")

        for eventtype, callback in EVENT_CALLBACKS.items():
            if isinstance(event, eventtype):
                await callback(room, event, self)

    async def response_callback(self, response: Response):
        """Run every callback in ``RESPONSE_CALLBACKS`` whose type matches.

        Args:
            response (Response): The response received from the homeserver.
        """
        for response_type, callback in RESPONSE_CALLBACKS.items():
            if isinstance(response, response_type):
                await callback(response, self)

    async def accept_pending_invites(self):
        """Accept all pending invites."""
        assert self.matrix_client, "Matrix client not set up"

        # Iterate over a snapshot so the dict may change while joins are awaited.
        for room_id in list(self.matrix_client.invited_rooms):
            await self.matrix_client.join(room_id)

    async def send_message(self, room: MatrixRoom, message: str, notice: bool = False):
        """Send a Markdown message to a room, encrypting it when applicable.

        The message is rendered to HTML with markdown2 and sent as
        ``m.text`` (or ``m.notice``). If the client has Olm set up and the
        room is encrypted, the content is encrypted first.

        Args:
            room (MatrixRoom): The room to send to.
            message (str): The Markdown message body.
            notice (bool): Send as ``m.notice`` instead of ``m.text``.

        Returns:
            The result of the client's room-send request.

        Raises:
            Exception: Re-raises any error that occurs while encrypting.
        """
        markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
        formatted_body = markdowner.convert(message)

        msgtype = "m.notice" if notice else "m.text"

        msgcontent = {"msgtype": msgtype, "body": message,
                      "format": "org.matrix.custom.html", "formatted_body": formatted_body}

        content = None

        if self.matrix_client.olm and room.encrypted:
            try:
                if not room.members_synced:
                    await self.matrix_client.joined_members(room.room_id)

                if self.matrix_client.olm.should_share_group_session(room.room_id):
                    try:
                        # A share is already in progress: wait for it to finish.
                        event = self.matrix_client.sharing_session[room.room_id]
                        await event.wait()
                    except KeyError:
                        await self.matrix_client.share_group_session(
                            room.room_id,
                            ignore_unverified_devices=True,
                        )

                # msgtype is always m.text or m.notice here, so the content
                # is always encrypted.
                msgtype, content = self.matrix_client.encrypt(
                    room.room_id, "m.room.message", msgcontent)

            except Exception as e:
                # The original log text claimed the message would be sent
                # unencrypted, but the error is re-raised and nothing is sent.
                self.logger.log(
                    f"Error encrypting message: {e} - message not sent", "error")
                raise

        if not content:
            msgtype = "m.room.message"
            content = msgcontent

        method, path, data = Api.room_send(
            self.matrix_client.access_token, room.room_id, msgtype, content, uuid.uuid4()
        )

        return await self.matrix_client._send(RoomSendResponse, method, path, data, (room.room_id,))

    async def run(self):
        """Start the bot.

        Resolves the user and device IDs if needed, prepares the database
        (falling back to an in-memory one), runs migrations, configures the
        Matrix client, performs an initial sync, registers callbacks,
        accepts pending invites and then syncs forever.
        """
        # Set up the Matrix client
        assert self.matrix_client, "Matrix client not set up"
        assert self.matrix_client.access_token, "Access token not set up"

        if not self.matrix_client.user_id:
            self.matrix_client.user_id = await self._get_user_id()

        if not self.matrix_client.device_id:
            self.matrix_client.device_id = await self._get_device_id()

        # Set up database
        in_memory = False
        if not self.database:
            self.logger.log(
                "No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
            in_memory = True
            # The original called an undefined bare name DuckDBPyConnection.
            self.database = duckdb.connect(":memory:")

        self.logger.log("Running migrations...")
        before, after = migrate(self.database)

        if before != after:
            self.logger.log(f"Migrated from version {before} to {after}.")
        else:
            self.logger.log(f"Already at latest version {after}.")

        if in_memory:
            client_config = AsyncClientConfig(
                store_sync_tokens=True, encryption_enabled=False)
        else:
            matrix_store = DuckDBStore
            client_config = AsyncClientConfig(
                store_sync_tokens=True, encryption_enabled=True, store=matrix_store)

        self.matrix_client.config = client_config

        # The store and Olm machinery only exist for the persistent database;
        # matrix_store is not defined in the in-memory case.
        if not in_memory:
            self.matrix_client.store = matrix_store(
                self.matrix_client.user_id,
                self.matrix_client.device_id,
                self.database
            )

            self.matrix_client.olm = Olm(
                self.matrix_client.user_id,
                self.matrix_client.device_id,
                self.matrix_client.store
            )

            self.matrix_client.encrypted_rooms = self.matrix_client.store.load_encrypted_rooms()

        # Run initial sync
        sync = await self.matrix_client.sync(timeout=30000)
        if isinstance(sync, SyncResponse):
            await self.response_callback(sync)
        else:
            self.logger.log(f"Initial sync failed, aborting: {sync}", "error")
            return

        # Set up callbacks
        self.matrix_client.add_event_callback(self.event_callback, Event)
        self.matrix_client.add_response_callback(self.response_callback, Response)

        # Accept pending invites
        self.logger.log("Accepting pending invites...")
        await self.accept_pending_invites()

        # Start syncing events
        self.logger.log("Starting sync loop...")
        try:
            await self.matrix_client.sync_forever(timeout=30000)
        finally:
            self.logger.log("Syncing one last time...")
            await self.matrix_client.sync(timeout=30000)

    async def process_query(self, room: MatrixRoom, event: RoomMessageText):
        """Answer a user message with a generated chat response.

        Shows a typing indicator while working and always clears it
        afterwards, including on the error paths.

        Args:
            room (MatrixRoom): The room the message was sent in.
            event (RoomMessageText): The message to answer.
        """
        await self.matrix_client.room_typing(room.room_id, True)

        try:
            await self.matrix_client.room_read_markers(room.room_id, event.event_id)
            await self._respond_to_query(room, event)
        finally:
            # The original skipped this on its early returns.
            await self.matrix_client.room_typing(room.room_id, False)

    async def _respond_to_query(self, room: MatrixRoom, event: RoomMessageText):
        """Build the chat context, generate a reply, send it and log token usage."""
        try:
            # Honour the configured limit instead of a hard-coded 20.
            last_messages = await self._last_n_messages(room.room_id, self.max_messages)
        except Exception as e:
            self.logger.log(f"Error getting last messages: {e}", "error")
            await self.send_message(
                room, "Something went wrong. Please try again.", True)
            return

        system_message = self.get_system_message(room)

        chat_messages = [{"role": "system", "content": system_message}]

        for message in last_messages:
            # The triggering message is appended last, so skip it here.
            if message.event_id == event.event_id:
                continue
            role = "assistant" if message.sender == self.matrix_client.user_id else "user"
            chat_messages.append({"role": role, "content": message.body})

        chat_messages.append({"role": "user", "content": event.body})

        # Truncate messages to fit within the token limit
        truncated_messages = self._truncate(
            chat_messages, self.max_tokens - 1, system_message=system_message)

        try:
            response, tokens_used = await self.generate_chat_response(truncated_messages)
        except Exception as e:
            self.logger.log(f"Error generating response: {e}", "error")
            await self.send_message(
                room, "Something went wrong. Please try again.", True)
            return

        if not response:
            # Send a notice to the room if there was an error
            self.logger.log("Didn't get a response from GPT API", "error")
            # The original called a bare, un-awaited send_message here.
            await self.send_message(
                room, "Something went wrong. Please try again.", True)
            return

        self.logger.log(f"Sending response to room {room.room_id}...")

        # send_message converts the Markdown response to HTML
        message = await self.send_message(room, response)

        if not isinstance(message, RoomSendResponse):
            # An error response has no event_id to record.
            self.logger.log(f"Error sending response: {message}", "error")
            return

        if self.database:
            self.logger.log("Storing record of tokens used...")

            with self.database.cursor() as cursor:
                cursor.execute(
                    "INSERT INTO token_usage (message_id, room_id, tokens, timestamp) VALUES (?, ?, ?, ?)",
                    (message.event_id, room.room_id, tokens_used, datetime.now()))
                self.database.commit()

    async def generate_chat_response(self, messages: List[Dict[str, str]]) -> tuple[str, int]:
        """Generate a response to a chat message.

        Args:
            messages (List[Dict[str, str]]): A list of messages to use as context.

        Returns:
            tuple[str, int]: The response text and the total number of
            tokens used, as reported by the API.
        """
        self.logger.log(f"Generating response to {len(messages)} messages...")

        # ChatCompletion.create is a blocking call; run it in a worker thread
        # so the event loop keeps running while waiting.
        response = await asyncio.to_thread(
            openai.ChatCompletion.create,
            model=self.model,
            messages=messages,
            api_key=self.openai_api_key,
        )

        result_text = response.choices[0].message['content']
        tokens_used = response.usage["total_tokens"]
        self.logger.log(f"Generated response with {tokens_used} tokens.")
        return result_text, tokens_used

    def get_system_message(self, room: MatrixRoom | int | str) -> str:
        """Build the system message for a room.

        Uses the most recent row for the room in ``system_messages``. The
        default system message is used when the room has no custom message,
        and is prepended to the custom one when ``force_system_message`` is
        set.

        Args:
            room (MatrixRoom | int | str): The room, or its room ID.

        Returns:
            str: The complete system message, stripped of surrounding whitespace.
        """
        default = self.default_system_message

        # Room IDs are passed as strings elsewhere in this class; the
        # original only recognised int and failed on a str.
        room_id = room if isinstance(room, (int, str)) else room.room_id

        with self.database.cursor() as cur:
            cur.execute(
                "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
                (room_id,)
            )
            system_message = cur.fetchone()

        use_default = (not system_message) or self.force_system_message
        prefix = default if use_default else ""
        custom = ("\n\n" + system_message[0]) if system_message else ""

        return (prefix + custom).strip()

    def __del__(self):
        """Close the bot."""
        try:
            if self.matrix_client:
                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    # No loop is running in this thread: safe to block.
                    asyncio.run(self.matrix_client.close())
                else:
                    # asyncio.run() cannot be used inside a running loop.
                    loop.create_task(self.matrix_client.close())
        finally:
            # Close the database even if closing the client failed.
            if self.database:
                self.database.close()
|