Quellcode durchsuchen

Enable per-room model overrides and clean up code

Introduced the ability to specify and retrieve different OpenAI models on a per-room basis, thereby allowing enhanced customization of the bot's response behavior according to the preferences for each room. Cleaned up code formatting across the bot implementation files for improved readability and maintainability. Additional logic now checks for model overrides when generating responses, ensuring the correct model is used as configured.

Refactors include streamlined database and API initializations and a refined method for processing message formatting to accommodate images, texts, and system messages consistently. This change differentiates default behavior from room-specific configurations, catering to diverse user needs without compromising on default settings.
Kumi vor 10 Monaten
Ursprung
Commit
87173ae284
3 geänderte Dateien mit 184 neuen und 47 gelöschten Zeilen
  1. 140 44
      src/gptbot/classes/bot.py
  2. 7 3
      src/gptbot/classes/openai.py
  3. 37 0
      src/gptbot/commands/roomsettings.py

+ 140 - 44
src/gptbot/classes/bot.py

@@ -87,7 +87,7 @@ class GPTBot:
     @property
     def allowed_users(self) -> List[str]:
         """List of users allowed to use the bot.
-        
+
         Returns:
             List[str]: List of user IDs. Defaults to [], which means all users are allowed.
         """
@@ -99,7 +99,7 @@ class GPTBot:
     @property
     def display_name(self) -> str:
         """Display name of the bot user.
-        
+
         Returns:
             str: The display name of the bot user. Defaults to "GPTBot".
         """
@@ -108,7 +108,7 @@ class GPTBot:
     @property
     def default_room_name(self) -> str:
         """Default name of rooms created by the bot.
-        
+
         Returns:
             str: The default name of rooms created by the bot. Defaults to the display name of the bot.
         """
@@ -181,8 +181,17 @@ class GPTBot:
             str: The path to the logo of the bot. Defaults to "assets/logo.png" in the bot's directory.
         """
         return self.config["GPTBot"].get(
-                "Logo", str(Path(__file__).parent.parent / "assets/logo.png")
-            )
+            "Logo", str(Path(__file__).parent.parent / "assets/logo.png")
+        )
+
+    @property
+    def allow_model_override(self) -> bool:
+        """Whether to allow per-room model overrides.
+
+        Returns:
+            bool: Whether to allow per-room model overrides. Defaults to False.
+        """
+        return self.config["GPTBot"].getboolean("AllowModelOverride", False)
 
     # User agent to use for HTTP requests
     USER_AGENT = "matrix-gptbot/dev (+https://kumig.it/kumitterer/matrix-gptbot)"
@@ -208,11 +217,7 @@ class GPTBot:
             if "Database" in config and "Path" in config["Database"]
             else None
         )
-        bot.database = (
-            sqlite3.connect(bot.database_path)
-            if bot.database_path
-            else None
-        )
+        bot.database = sqlite3.connect(bot.database_path) if bot.database_path else None
 
         # Override default values
         if "GPTBot" in config:
@@ -224,14 +229,16 @@ class GPTBot:
             if Path(bot.logo_path).exists() and Path(bot.logo_path).is_file():
                 bot.logo = Image.open(bot.logo_path)
 
-        bot.chat_api = bot.image_api = bot.classification_api = bot.tts_api = bot.stt_api = OpenAI(
+        bot.chat_api = (
+            bot.image_api
+        ) = bot.classification_api = bot.tts_api = bot.stt_api = OpenAI(
             bot=bot,
-            api_key=config["OpenAI"]["APIKey"], 
+            api_key=config["OpenAI"]["APIKey"],
             chat_model=config["OpenAI"].get("Model"),
             image_model=config["OpenAI"].get("ImageModel"),
             tts_model=config["OpenAI"].get("TTSModel"),
             stt_model=config["OpenAI"].get("STTModel"),
-            base_url=config["OpenAI"].get("BaseURL")
+            base_url=config["OpenAI"].get("BaseURL"),
         )
 
         if "BaseURL" in config["OpenAI"]:
@@ -285,7 +292,12 @@ class GPTBot:
 
         return user_id
 
-    async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int], ignore_bot_commands: bool = False):
+    async def _last_n_messages(
+        self,
+        room: str | MatrixRoom,
+        n: Optional[int],
+        ignore_bot_commands: bool = False,
+    ):
         messages = []
         n = n or self.max_messages
         room_id = room.room_id if isinstance(room, MatrixRoom) else room
@@ -362,7 +374,13 @@ class GPTBot:
         truncated_messages = []
 
         for message in [messages[0]] + list(reversed(messages[1:])):
-            content = message["content"] if isinstance(message["content"], str) else message["content"][0]["text"] if isinstance(message["content"][0].get("text"), str) else ""
+            content = (
+                message["content"]
+                if isinstance(message["content"], str)
+                else message["content"][0]["text"]
+                if isinstance(message["content"][0].get("text"), str)
+                else ""
+            )
             tokens = len(encoding.encode(content)) + 1
             if total_tokens + tokens > max_tokens:
                 break
@@ -658,9 +676,7 @@ class GPTBot:
             "url": content_uri,
         }
 
-        status = await self.matrix_client.room_send(
-            room, "m.room.message", content
-        )
+        status = await self.matrix_client.room_send(room, "m.room.message", content)
 
         self.logger.log("Sent image", "debug")
 
@@ -694,9 +710,7 @@ class GPTBot:
             "url": content_uri,
         }
 
-        status = await self.matrix_client.room_send(
-            room, "m.room.message", content
-        )
+        status = await self.matrix_client.room_send(room, "m.room.message", content)
 
         self.logger.log("Sent file", "debug")
 
@@ -789,7 +803,9 @@ class GPTBot:
             self.matrix_client.device_id = await self._get_device_id()
 
         if not self.database:
-            self.database = sqlite3.connect(Path(__file__).parent.parent / "database.db")
+            self.database = sqlite3.connect(
+                Path(__file__).parent.parent / "database.db"
+            )
 
         self.logger.log("Running migrations...")
 
@@ -987,6 +1003,28 @@ class GPTBot:
 
         return True if not result else bool(int(result[0]))
 
+    async def get_room_model(self, room: MatrixRoom | str) -> str:
+        """Get the model used for a room.
+
+        Args:
+            room (MatrixRoom | str): The room to check.
+
+        Returns:
+            str: The model used for the room.
+        """
+
+        if isinstance(room, MatrixRoom):
+            room = room.room_id
+
+        with closing(self.database.cursor()) as cursor:
+            cursor.execute(
+                "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
+                (room, "model"),
+            )
+            result = cursor.fetchone()
+
+        return result[0] if result else self.chat_api.chat_model
+
     async def process_query(
         self, room: MatrixRoom, event: RoomMessageText, from_chat_command: bool = False
     ):
@@ -1053,28 +1091,46 @@ class GPTBot:
         for message in last_messages:
             if isinstance(message, (RoomMessageNotice, RoomMessageText)):
                 role = (
-                    "assistant" if message.sender == self.matrix_client.user_id else "user"
+                    "assistant"
+                    if message.sender == self.matrix_client.user_id
+                    else "user"
                 )
                 if message == event or (not message.event_id == event.event_id):
-                    message_body = message.body if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message.body}]
+                    message_body = (
+                        message.body
+                        if not self.chat_api.supports_chat_images()
+                        else [{"type": "text", "text": message.body}]
+                    )
                     chat_messages.append({"role": role, "content": message_body})
 
-            elif isinstance(message, RoomMessageAudio) or (isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")):
+            elif isinstance(message, RoomMessageAudio) or (
+                isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")
+            ):
                 role = (
-                    "assistant" if message.sender == self.matrix_client.user_id else "user"
+                    "assistant"
+                    if message.sender == self.matrix_client.user_id
+                    else "user"
                 )
                 if message == event or (not message.event_id == event.event_id):
                     if self.room_uses_stt(room):
                         try:
                             download = await self.download_file(message.url)
-                            message_text = await self.stt_api.speech_to_text(download.body)
+                            message_text = await self.stt_api.speech_to_text(
+                                download.body
+                            )
                         except Exception as e:
-                            self.logger.log(f"Error generating text from audio: {e}", "error")
+                            self.logger.log(
+                                f"Error generating text from audio: {e}", "error"
+                            )
                             message_text = message.body
                     else:
                         message_text = message.body
 
-                    message_body = message_text if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message_text}]
+                    message_body = (
+                        message_text
+                        if not self.chat_api.supports_chat_images()
+                        else [{"type": "text", "text": message_text}]
+                    )
                     chat_messages.append({"role": role, "content": message_body})
 
             elif isinstance(message, RoomMessageFile):
@@ -1092,38 +1148,72 @@ class GPTBot:
                                 if message.sender == self.matrix_client.user_id
                                 else "user"
                             )
-                            if message == event or (not message.event_id == event.event_id):
-                                message_body = text if not self.chat_api.supports_chat_images() else [{"type": "text", "text": text}]
-                                chat_messages.append({"role": role, "content": message_body})
+                            if message == event or (
+                                not message.event_id == event.event_id
+                            ):
+                                message_body = (
+                                    text
+                                    if not self.chat_api.supports_chat_images()
+                                    else [{"type": "text", "text": text}]
+                                )
+                                chat_messages.append(
+                                    {"role": role, "content": message_body}
+                                )
 
                 except Exception as e:
                     self.logger.log(f"Error generating text from file: {e}", "error")
-                    message_body = message.body if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message.body}]
+                    message_body = (
+                        message.body
+                        if not self.chat_api.supports_chat_images()
+                        else [{"type": "text", "text": message.body}]
+                    )
                     chat_messages.append({"role": "system", "content": message_body})
 
-            elif self.chat_api.supports_chat_images() and isinstance(message, RoomMessageImage):
+            elif self.chat_api.supports_chat_images() and isinstance(
+                message, RoomMessageImage
+            ):
                 try:
                     image_url = message.url
                     download = await self.download_file(image_url)
 
                     if download:
                         encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}"
-                        parent = chat_messages[-1] if chat_messages and chat_messages[-1]["role"] == ("assistant" if message.sender == self.matrix_client.user_id else "user") else None
+                        parent = (
+                            chat_messages[-1]
+                            if chat_messages
+                            and chat_messages[-1]["role"]
+                            == (
+                                "assistant"
+                                if message.sender == self.matrix_client.user_id
+                                else "user"
+                            )
+                            else None
+                        )
 
                         if not parent:
-                            chat_messages.append({"role": ("assistant" if message.sender == self.matrix_client.user_id else "user"), "content": []})
+                            chat_messages.append(
+                                {
+                                    "role": (
+                                        "assistant"
+                                        if message.sender == self.matrix_client.user_id
+                                        else "user"
+                                    ),
+                                    "content": [],
+                                }
+                            )
                             parent = chat_messages[-1]
 
-                        parent["content"].append({
-                            "type": "image_url",
-                            "image_url": {
-                                "url": encoded_url
-                            }
-                        })
+                        parent["content"].append(
+                            {"type": "image_url", "image_url": {"url": encoded_url}}
+                        )
 
                 except Exception as e:
                     self.logger.log(f"Error generating image from file: {e}", "error")
-                    message_body = message.body if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message.body}]
+                    message_body = (
+                        message.body
+                        if not self.chat_api.supports_chat_images()
+                        else [{"type": "text", "text": message.body}]
+                    )
                     chat_messages.append({"role": "system", "content": message_body})
 
         # Truncate messages to fit within the token limit
@@ -1131,9 +1221,15 @@ class GPTBot:
             chat_messages[1:], self.max_tokens - 1, system_message=system_message
         )
 
+        # Check for a model override
+        if self.allow_model_override:
+            model = await self.get_room_model(room)
+        else:
+            model = self.chat_api.chat_model
+
         try:
             response, tokens_used = await self.chat_api.generate_chat_response(
-                chat_messages, user=event.sender, room=room.room_id
+                chat_messages, user=event.sender, room=room.room_id, model=model
             )
         except Exception as e:
             print(traceback.format_exc())

+ 7 - 3
src/gptbot/classes/openai.py

@@ -124,6 +124,7 @@ class OpenAI:
         room: Optional[str] = None,
         allow_override: bool = True,
         use_tools: bool = True,
+        model: Optional[str] = None,
     ) -> Tuple[str, int]:
         """Generate a response to a chat message.
 
@@ -133,6 +134,7 @@ class OpenAI:
             room (Optional[str], optional): The room to use the assistant for. Defaults to None.
             allow_override (bool, optional): Whether to allow the chat model to be overridden. Defaults to True.
             use_tools (bool, optional): Whether to use tools. Defaults to True.
+            model (Optional[str], optional): The model to use. Defaults to None, which uses the default chat model.
 
         Returns:
             Tuple[str, int]: The response text and the number of tokens used.
@@ -141,6 +143,8 @@ class OpenAI:
             f"Generating response to {len(messages)} messages for user {user} in room {room}..."
         )
 
+        chat_model = model or self.chat_model
+
         # Check current recursion depth to prevent infinite loops
 
         if use_tools:
@@ -157,6 +161,7 @@ class OpenAI:
                     room=room,
                     allow_override=False, # TODO: Could this be a problem?
                     use_tools=False,
+                    model=model,
                 )
 
         tools = [
@@ -171,10 +176,9 @@ class OpenAI:
             for tool_name, tool_class in TOOLS.items()
         ]
 
-        chat_model = self.chat_model
         original_messages = messages
 
-        if allow_override and not "gpt-3.5-turbo" in self.chat_model:
+        if allow_override and not "gpt-3.5-turbo" in model:
             if self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False):
                 self.logger.log(f"Overriding chat model to use tools")
                 chat_model = "gpt-3.5-turbo-1106"
@@ -204,7 +208,7 @@ class OpenAI:
             use_tools
             and self.bot.config.getboolean("OpenAI", "EmulateTools", fallback=False)
             and not self.bot.config.getboolean("OpenAI", "ForceTools", fallback=False)
-            and not "gpt-3.5-turbo" in self.chat_model
+            and not "gpt-3.5-turbo" in chat_model
         ):
             self.bot.logger.log("Using tool emulation mode.", "debug")
 

+ 37 - 0
src/gptbot/commands/roomsettings.py

@@ -80,6 +80,40 @@ async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
         await bot.send_message(room, f"The current {setting} status is: '{value}'.", True)
         return
 
+    if bot.allow_model_override and setting == "model":
+        if value:
+            bot.logger.log(f"Setting chat model for {room.room_id} to {value}...")
+
+            with closing(bot.database.cursor()) as cur:
+                cur.execute(
+                    """INSERT INTO room_settings (room_id, setting, value) VALUES (?, ?, ?)
+                    ON CONFLICT (room_id, setting) DO UPDATE SET value = ?;""",
+                    (room.room_id, "model", value, value)
+                )
+
+            bot.database.commit()
+
+            await bot.send_message(room, f"Alright, I've set the chat model to: '{value}'.", True)
+            return
+
+        bot.logger.log(f"Retrieving chat model for {room.room_id}...")
+
+        with closing(bot.database.cursor()) as cur:
+            cur.execute(
+                """SELECT value FROM room_settings WHERE room_id = ? AND setting = ?;""",
+                (room.room_id, "model")
+            )
+
+            value = cur.fetchone()[0]
+
+            if not value:
+                value = bot.chat_api.chat_model
+            else:
+                value = str(value)
+
+        await bot.send_message(room, f"The current chat model is: '{value}'.", True)
+        return
+
     message = f"""The following settings are available:
 
 - system_message [message]: Get or set the system message to be sent to the chat model
@@ -90,4 +124,7 @@ async def command_roomsettings(room: MatrixRoom, event: RoomMessageText, bot):
 - timing [true/false]: Get or set whether the bot should return information about the time it took to generate a response
 """
 
+    if bot.allow_model_override:
+        message += "- model [model]: Get or set the chat model to be used for this room"
+
     await bot.send_message(room, message, True)