|
@@ -87,7 +87,7 @@ class GPTBot:
|
|
|
@property
|
|
|
def allowed_users(self) -> List[str]:
|
|
|
"""List of users allowed to use the bot.
|
|
|
-
|
|
|
+
|
|
|
Returns:
|
|
|
List[str]: List of user IDs. Defaults to [], which means all users are allowed.
|
|
|
"""
|
|
@@ -99,7 +99,7 @@ class GPTBot:
|
|
|
@property
|
|
|
def display_name(self) -> str:
|
|
|
"""Display name of the bot user.
|
|
|
-
|
|
|
+
|
|
|
Returns:
|
|
|
str: The display name of the bot user. Defaults to "GPTBot".
|
|
|
"""
|
|
@@ -108,7 +108,7 @@ class GPTBot:
|
|
|
@property
|
|
|
def default_room_name(self) -> str:
|
|
|
"""Default name of rooms created by the bot.
|
|
|
-
|
|
|
+
|
|
|
Returns:
|
|
|
str: The default name of rooms created by the bot. Defaults to the display name of the bot.
|
|
|
"""
|
|
@@ -181,8 +181,17 @@ class GPTBot:
|
|
|
str: The path to the logo of the bot. Defaults to "assets/logo.png" in the bot's directory.
|
|
|
"""
|
|
|
return self.config["GPTBot"].get(
|
|
|
- "Logo", str(Path(__file__).parent.parent / "assets/logo.png")
|
|
|
- )
|
|
|
+ "Logo", str(Path(__file__).parent.parent / "assets/logo.png")
|
|
|
+ )
|
|
|
+
|
|
|
+ @property
|
|
|
+ def allow_model_override(self) -> bool:
|
|
|
+ """Whether to allow per-room model overrides.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: Whether to allow per-room model overrides. Defaults to False.
|
|
|
+ """
|
|
|
+ return self.config["GPTBot"].getboolean("AllowModelOverride", False)
|
|
|
|
|
|
# User agent to use for HTTP requests
|
|
|
USER_AGENT = "matrix-gptbot/dev (+https://kumig.it/kumitterer/matrix-gptbot)"
|
|
@@ -208,11 +217,7 @@ class GPTBot:
|
|
|
if "Database" in config and "Path" in config["Database"]
|
|
|
else None
|
|
|
)
|
|
|
- bot.database = (
|
|
|
- sqlite3.connect(bot.database_path)
|
|
|
- if bot.database_path
|
|
|
- else None
|
|
|
- )
|
|
|
+ bot.database = sqlite3.connect(bot.database_path) if bot.database_path else None
|
|
|
|
|
|
# Override default values
|
|
|
if "GPTBot" in config:
|
|
@@ -224,14 +229,16 @@ class GPTBot:
|
|
|
if Path(bot.logo_path).exists() and Path(bot.logo_path).is_file():
|
|
|
bot.logo = Image.open(bot.logo_path)
|
|
|
|
|
|
- bot.chat_api = bot.image_api = bot.classification_api = bot.tts_api = bot.stt_api = OpenAI(
|
|
|
+ bot.chat_api = (
|
|
|
+ bot.image_api
|
|
|
+ ) = bot.classification_api = bot.tts_api = bot.stt_api = OpenAI(
|
|
|
bot=bot,
|
|
|
- api_key=config["OpenAI"]["APIKey"],
|
|
|
+ api_key=config["OpenAI"]["APIKey"],
|
|
|
chat_model=config["OpenAI"].get("Model"),
|
|
|
image_model=config["OpenAI"].get("ImageModel"),
|
|
|
tts_model=config["OpenAI"].get("TTSModel"),
|
|
|
stt_model=config["OpenAI"].get("STTModel"),
|
|
|
- base_url=config["OpenAI"].get("BaseURL")
|
|
|
+ base_url=config["OpenAI"].get("BaseURL"),
|
|
|
)
|
|
|
|
|
|
if "BaseURL" in config["OpenAI"]:
|
|
@@ -285,7 +292,12 @@ class GPTBot:
|
|
|
|
|
|
return user_id
|
|
|
|
|
|
- async def _last_n_messages(self, room: str | MatrixRoom, n: Optional[int], ignore_bot_commands: bool = False):
|
|
|
+ async def _last_n_messages(
|
|
|
+ self,
|
|
|
+ room: str | MatrixRoom,
|
|
|
+ n: Optional[int],
|
|
|
+ ignore_bot_commands: bool = False,
|
|
|
+ ):
|
|
|
messages = []
|
|
|
n = n or self.max_messages
|
|
|
room_id = room.room_id if isinstance(room, MatrixRoom) else room
|
|
@@ -362,7 +374,13 @@ class GPTBot:
|
|
|
truncated_messages = []
|
|
|
|
|
|
for message in [messages[0]] + list(reversed(messages[1:])):
|
|
|
- content = message["content"] if isinstance(message["content"], str) else message["content"][0]["text"] if isinstance(message["content"][0].get("text"), str) else ""
|
|
|
+ content = (
|
|
|
+ message["content"]
|
|
|
+ if isinstance(message["content"], str)
|
|
|
+ else message["content"][0]["text"]
|
|
|
+ if isinstance(message["content"][0].get("text"), str)
|
|
|
+ else ""
|
|
|
+ )
|
|
|
tokens = len(encoding.encode(content)) + 1
|
|
|
if total_tokens + tokens > max_tokens:
|
|
|
break
|
|
@@ -658,9 +676,7 @@ class GPTBot:
|
|
|
"url": content_uri,
|
|
|
}
|
|
|
|
|
|
- status = await self.matrix_client.room_send(
|
|
|
- room, "m.room.message", content
|
|
|
- )
|
|
|
+ status = await self.matrix_client.room_send(room, "m.room.message", content)
|
|
|
|
|
|
self.logger.log("Sent image", "debug")
|
|
|
|
|
@@ -694,9 +710,7 @@ class GPTBot:
|
|
|
"url": content_uri,
|
|
|
}
|
|
|
|
|
|
- status = await self.matrix_client.room_send(
|
|
|
- room, "m.room.message", content
|
|
|
- )
|
|
|
+ status = await self.matrix_client.room_send(room, "m.room.message", content)
|
|
|
|
|
|
self.logger.log("Sent file", "debug")
|
|
|
|
|
@@ -789,7 +803,9 @@ class GPTBot:
|
|
|
self.matrix_client.device_id = await self._get_device_id()
|
|
|
|
|
|
if not self.database:
|
|
|
- self.database = sqlite3.connect(Path(__file__).parent.parent / "database.db")
|
|
|
+ self.database = sqlite3.connect(
|
|
|
+ Path(__file__).parent.parent / "database.db"
|
|
|
+ )
|
|
|
|
|
|
self.logger.log("Running migrations...")
|
|
|
|
|
@@ -987,6 +1003,28 @@ class GPTBot:
|
|
|
|
|
|
return True if not result else bool(int(result[0]))
|
|
|
|
|
|
+ async def get_room_model(self, room: MatrixRoom | str) -> str:
|
|
|
+ """Get the model used for a room.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ room (MatrixRoom | str): The room to check.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str: The model used for the room.
|
|
|
+ """
|
|
|
+
|
|
|
+ if isinstance(room, MatrixRoom):
|
|
|
+ room = room.room_id
|
|
|
+
|
|
|
+ with closing(self.database.cursor()) as cursor:
|
|
|
+ cursor.execute(
|
|
|
+ "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
|
|
|
+ (room, "model"),
|
|
|
+ )
|
|
|
+ result = cursor.fetchone()
|
|
|
+
|
|
|
+ return result[0] if result else self.chat_api.chat_model
|
|
|
+
|
|
|
async def process_query(
|
|
|
self, room: MatrixRoom, event: RoomMessageText, from_chat_command: bool = False
|
|
|
):
|
|
@@ -1053,28 +1091,46 @@ class GPTBot:
|
|
|
for message in last_messages:
|
|
|
if isinstance(message, (RoomMessageNotice, RoomMessageText)):
|
|
|
role = (
|
|
|
- "assistant" if message.sender == self.matrix_client.user_id else "user"
|
|
|
+ "assistant"
|
|
|
+ if message.sender == self.matrix_client.user_id
|
|
|
+ else "user"
|
|
|
)
|
|
|
if message == event or (not message.event_id == event.event_id):
|
|
|
- message_body = message.body if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message.body}]
|
|
|
+ message_body = (
|
|
|
+ message.body
|
|
|
+ if not self.chat_api.supports_chat_images()
|
|
|
+ else [{"type": "text", "text": message.body}]
|
|
|
+ )
|
|
|
chat_messages.append({"role": role, "content": message_body})
|
|
|
|
|
|
- elif isinstance(message, RoomMessageAudio) or (isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")):
|
|
|
+ elif isinstance(message, RoomMessageAudio) or (
|
|
|
+ isinstance(message, RoomMessageFile) and message.body.endswith(".mp3")
|
|
|
+ ):
|
|
|
role = (
|
|
|
- "assistant" if message.sender == self.matrix_client.user_id else "user"
|
|
|
+ "assistant"
|
|
|
+ if message.sender == self.matrix_client.user_id
|
|
|
+ else "user"
|
|
|
)
|
|
|
if message == event or (not message.event_id == event.event_id):
|
|
|
if self.room_uses_stt(room):
|
|
|
try:
|
|
|
download = await self.download_file(message.url)
|
|
|
- message_text = await self.stt_api.speech_to_text(download.body)
|
|
|
+ message_text = await self.stt_api.speech_to_text(
|
|
|
+ download.body
|
|
|
+ )
|
|
|
except Exception as e:
|
|
|
- self.logger.log(f"Error generating text from audio: {e}", "error")
|
|
|
+ self.logger.log(
|
|
|
+ f"Error generating text from audio: {e}", "error"
|
|
|
+ )
|
|
|
message_text = message.body
|
|
|
else:
|
|
|
message_text = message.body
|
|
|
|
|
|
- message_body = message_text if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message_text}]
|
|
|
+ message_body = (
|
|
|
+ message_text
|
|
|
+ if not self.chat_api.supports_chat_images()
|
|
|
+ else [{"type": "text", "text": message_text}]
|
|
|
+ )
|
|
|
chat_messages.append({"role": role, "content": message_body})
|
|
|
|
|
|
elif isinstance(message, RoomMessageFile):
|
|
@@ -1092,38 +1148,72 @@ class GPTBot:
|
|
|
if message.sender == self.matrix_client.user_id
|
|
|
else "user"
|
|
|
)
|
|
|
- if message == event or (not message.event_id == event.event_id):
|
|
|
- message_body = text if not self.chat_api.supports_chat_images() else [{"type": "text", "text": text}]
|
|
|
- chat_messages.append({"role": role, "content": message_body})
|
|
|
+ if message == event or (
|
|
|
+ not message.event_id == event.event_id
|
|
|
+ ):
|
|
|
+ message_body = (
|
|
|
+ text
|
|
|
+ if not self.chat_api.supports_chat_images()
|
|
|
+ else [{"type": "text", "text": text}]
|
|
|
+ )
|
|
|
+ chat_messages.append(
|
|
|
+ {"role": role, "content": message_body}
|
|
|
+ )
|
|
|
|
|
|
except Exception as e:
|
|
|
self.logger.log(f"Error generating text from file: {e}", "error")
|
|
|
- message_body = message.body if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message.body}]
|
|
|
+ message_body = (
|
|
|
+ message.body
|
|
|
+ if not self.chat_api.supports_chat_images()
|
|
|
+ else [{"type": "text", "text": message.body}]
|
|
|
+ )
|
|
|
chat_messages.append({"role": "system", "content": message_body})
|
|
|
|
|
|
- elif self.chat_api.supports_chat_images() and isinstance(message, RoomMessageImage):
|
|
|
+ elif self.chat_api.supports_chat_images() and isinstance(
|
|
|
+ message, RoomMessageImage
|
|
|
+ ):
|
|
|
try:
|
|
|
image_url = message.url
|
|
|
download = await self.download_file(image_url)
|
|
|
|
|
|
if download:
|
|
|
encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}"
|
|
|
- parent = chat_messages[-1] if chat_messages and chat_messages[-1]["role"] == ("assistant" if message.sender == self.matrix_client.user_id else "user") else None
|
|
|
+ parent = (
|
|
|
+ chat_messages[-1]
|
|
|
+ if chat_messages
|
|
|
+ and chat_messages[-1]["role"]
|
|
|
+ == (
|
|
|
+ "assistant"
|
|
|
+ if message.sender == self.matrix_client.user_id
|
|
|
+ else "user"
|
|
|
+ )
|
|
|
+ else None
|
|
|
+ )
|
|
|
|
|
|
if not parent:
|
|
|
- chat_messages.append({"role": ("assistant" if message.sender == self.matrix_client.user_id else "user"), "content": []})
|
|
|
+ chat_messages.append(
|
|
|
+ {
|
|
|
+ "role": (
|
|
|
+ "assistant"
|
|
|
+ if message.sender == self.matrix_client.user_id
|
|
|
+ else "user"
|
|
|
+ ),
|
|
|
+ "content": [],
|
|
|
+ }
|
|
|
+ )
|
|
|
parent = chat_messages[-1]
|
|
|
|
|
|
- parent["content"].append({
|
|
|
- "type": "image_url",
|
|
|
- "image_url": {
|
|
|
- "url": encoded_url
|
|
|
- }
|
|
|
- })
|
|
|
+ parent["content"].append(
|
|
|
+ {"type": "image_url", "image_url": {"url": encoded_url}}
|
|
|
+ )
|
|
|
|
|
|
except Exception as e:
|
|
|
self.logger.log(f"Error generating image from file: {e}", "error")
|
|
|
- message_body = message.body if not self.chat_api.supports_chat_images() else [{"type": "text", "text": message.body}]
|
|
|
+ message_body = (
|
|
|
+ message.body
|
|
|
+ if not self.chat_api.supports_chat_images()
|
|
|
+ else [{"type": "text", "text": message.body}]
|
|
|
+ )
|
|
|
chat_messages.append({"role": "system", "content": message_body})
|
|
|
|
|
|
# Truncate messages to fit within the token limit
|
|
@@ -1131,9 +1221,15 @@ class GPTBot:
|
|
|
chat_messages[1:], self.max_tokens - 1, system_message=system_message
|
|
|
)
|
|
|
|
|
|
+ # Check for a model override
|
|
|
+ if self.allow_model_override:
|
|
|
+ model = await self.get_room_model(room)
|
|
|
+ else:
|
|
|
+ model = self.chat_api.chat_model
|
|
|
+
|
|
|
try:
|
|
|
response, tokens_used = await self.chat_api.generate_chat_response(
|
|
|
- chat_messages, user=event.sender, room=room.room_id
|
|
|
+ chat_messages, user=event.sender, room=room.room_id, model=model
|
|
|
)
|
|
|
except Exception as e:
|
|
|
print(traceback.format_exc())
|