|
@@ -27,6 +27,12 @@ from nio import (
|
|
|
RoomSendError,
|
|
|
RoomVisibility,
|
|
|
RoomCreateError,
|
|
|
+ RoomMessageMedia,
|
|
|
+ RoomMessageImage,
|
|
|
+ RoomMessageFile,
|
|
|
+ RoomMessageAudio,
|
|
|
+ DownloadError,
|
|
|
+ DownloadResponse,
|
|
|
)
|
|
|
from nio.crypto import Olm
|
|
|
from nio.store import SqliteStore
|
|
@@ -38,6 +44,7 @@ from io import BytesIO
|
|
|
from pathlib import Path
|
|
|
from contextlib import closing
|
|
|
|
|
|
+import base64
|
|
|
import uuid
|
|
|
import traceback
|
|
|
import json
|
|
@@ -139,7 +146,7 @@ class GPTBot:
|
|
|
|
|
|
bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
|
|
|
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"),
|
|
|
- config["OpenAI"].get("ImageModel"), bot.logger
|
|
|
+ config["OpenAI"].get("ImageModel"), config["OpenAI"].get("BaseURL"), bot.logger
|
|
|
)
|
|
|
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
|
|
|
bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
|
|
@@ -220,6 +227,7 @@ class GPTBot:
|
|
|
for event in response.chunk:
|
|
|
if len(messages) >= n:
|
|
|
break
|
|
|
+
|
|
|
if isinstance(event, MegolmEvent):
|
|
|
try:
|
|
|
event = await self.matrix_client.decrypt_event(event)
|
|
@@ -229,14 +237,22 @@ class GPTBot:
|
|
|
"error",
|
|
|
)
|
|
|
continue
|
|
|
- if isinstance(event, (RoomMessageText, RoomMessageNotice)):
|
|
|
+
|
|
|
+ if isinstance(event, RoomMessageText):
|
|
|
if event.body.startswith("!gptbot ignoreolder"):
|
|
|
break
|
|
|
- if (not event.body.startswith("!")) or (
|
|
|
- event.body.startswith("!gptbot") and not ignore_bot_commands
|
|
|
- ):
|
|
|
+ if (not event.body.startswith("!")) or (not ignore_bot_commands):
|
|
|
messages.append(event)
|
|
|
|
|
|
+ if isinstance(event, RoomMessageNotice):
|
|
|
+ if not ignore_bot_commands:
|
|
|
+ messages.append(event)
|
|
|
+
|
|
|
+ if isinstance(event, RoomMessageMedia):
|
|
|
+ if event.sender != self.matrix_client.user_id:
|
|
|
+ if len(messages) < 2 or isinstance(messages[-1], RoomMessageMedia):
|
|
|
+ messages.append(event)
|
|
|
+
|
|
|
self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
|
|
|
|
|
|
# Reverse the list so that messages are in chronological order
|
|
@@ -275,7 +291,7 @@ class GPTBot:
|
|
|
truncated_messages = []
|
|
|
|
|
|
for message in [messages[0]] + list(reversed(messages[1:])):
|
|
|
- content = message["content"]
|
|
|
+ content = message["content"] if isinstance(message["content"], str) else message["content"][0]["text"] if isinstance(message["content"][0].get("text"), str) else ""
|
|
|
tokens = len(encoding.encode(content)) + 1
|
|
|
if total_tokens + tokens > max_tokens:
|
|
|
break
|
|
@@ -906,14 +922,39 @@ class GPTBot:
|
|
|
|
|
|
chat_messages = [{"role": "system", "content": system_message}]
|
|
|
|
|
|
- for message in last_messages:
|
|
|
+ text_messages = list(filter(lambda x: not isinstance(x, RoomMessageMedia), last_messages))
|
|
|
+
|
|
|
+ for message in text_messages:
|
|
|
role = (
|
|
|
"assistant" if message.sender == self.matrix_client.user_id else "user"
|
|
|
)
|
|
|
if not message.event_id == event.event_id:
|
|
|
chat_messages.append({"role": role, "content": message.body})
|
|
|
|
|
|
- chat_messages.append({"role": "user", "content": event.body})
|
|
|
+ if not self.chat_api.supports_chat_images():
|
|
|
+ event_body = event.body
|
|
|
+ else:
|
|
|
+ event_body = [
|
|
|
+ {
|
|
|
+ "type": "text",
|
|
|
+ "text": event.body
|
|
|
+ }
|
|
|
+ ]
|
|
|
+
|
|
|
+ for m in list(filter(lambda x: isinstance(x, RoomMessageMedia), last_messages)):
|
|
|
+ image_url = m.url
|
|
|
+ download = await self.download_file(image_url)
|
|
|
+
|
|
|
+ if download:
|
|
|
+ encoded_url = f"data:{download.content_type};base64,{base64.b64encode(download.body).decode('utf-8')}"
|
|
|
+ event_body.append({
|
|
|
+ "type": "image_url",
|
|
|
+ "image_url": {
|
|
|
+ "url": encoded_url
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ chat_messages.append({"role": "user", "content": event_body})
|
|
|
|
|
|
# Truncate messages to fit within the token limit
|
|
|
truncated_messages = self._truncate(
|
|
@@ -926,6 +967,7 @@ class GPTBot:
|
|
|
)
|
|
|
except Exception as e:
|
|
|
self.logger.log(f"Error generating response: {e}", "error")
|
|
|
+
|
|
|
await self.send_message(
|
|
|
room, "Something went wrong. Please try again.", True
|
|
|
)
|
|
@@ -954,6 +996,24 @@ class GPTBot:
|
|
|
|
|
|
await self.matrix_client.room_typing(room.room_id, False)
|
|
|
|
|
|
+ def download_file(self, mxc) -> Optional[bytes]:
|
|
|
+ """Download a file from the homeserver.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ mxc (str): The MXC URI of the file to download.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Optional[bytes]: The downloaded file, or None if there was an error.
|
|
|
+ """
|
|
|
+
|
|
|
+ download = self.matrix_client.download(mxc)
|
|
|
+
|
|
|
+ if isinstance(download, DownloadError):
|
|
|
+ self.logger.log(f"Error downloading file: {download.message}", "error")
|
|
|
+ return
|
|
|
+
|
|
|
+ return download
|
|
|
+
|
|
|
def get_system_message(self, room: MatrixRoom | str) -> str:
|
|
|
"""Get the system message for a room.
|
|
|
|