|
@@ -36,6 +36,7 @@ from typing import Optional, List, Dict, Tuple
|
|
|
from configparser import ConfigParser
|
|
|
from datetime import datetime
|
|
|
from io import BytesIO
|
|
|
+from pathlib import Path
|
|
|
|
|
|
import uuid
|
|
|
import traceback
|
|
@@ -53,7 +54,7 @@ from .trackingmore import TrackingMore
|
|
|
class GPTBot:
|
|
|
# Default values
|
|
|
database: Optional[duckdb.DuckDBPyConnection] = None
|
|
|
- default_room_name: str = "GPTBot" # Default name of rooms created by the bot
|
|
|
+ display_name = default_room_name = "GPTBot" # Default name of rooms created by the bot
|
|
|
default_system_message: str = "You are a helpful assistant."
|
|
|
# Force default system message to be included even if a custom room message is set
|
|
|
force_system_message: bool = False
|
|
@@ -69,6 +70,8 @@ class GPTBot:
|
|
|
operator: Optional[str] = None
|
|
|
room_ignore_list: List[str] = [] # List of rooms to ignore invites from
|
|
|
debug: bool = False
|
|
|
+ logo: Optional[Image.Image] = None
|
|
|
+ logo_uri: Optional[str] = None
|
|
|
|
|
|
@classmethod
|
|
|
def from_config(cls, config: ConfigParser):
|
|
@@ -99,6 +102,15 @@ class GPTBot:
|
|
|
"ForceSystemMessage", bot.force_system_message)
|
|
|
bot.debug = config["GPTBot"].getboolean("Debug", bot.debug)
|
|
|
|
|
|
+ logo_path = config["GPTBot"].get("Logo", str(Path(__file__).parent.parent / "assets/logo.png"))
|
|
|
+
|
|
|
+ bot.logger.log(f"Loading logo from {logo_path}")
|
|
|
+
|
|
|
+ if Path(logo_path).exists() and Path(logo_path).is_file():
|
|
|
+ bot.logo = Image.open(logo_path)
|
|
|
+
|
|
|
+ bot.display_name = config["GPTBot"].get("DisplayName", bot.display_name)
|
|
|
+
|
|
|
bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
|
|
|
config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger)
|
|
|
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
|
|
@@ -340,6 +352,30 @@ class GPTBot:
|
|
|
f"Error leaving room {invite}: {leave_response.message}", "error")
|
|
|
self.room_ignore_list.append(invite)
|
|
|
|
|
|
+ async def upload_file(self, file: bytes, filename: str = "file", mime: str = "application/octet-stream") -> str:
|
|
|
+ """Upload a file to the homeserver.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ file (bytes): The file to upload.
|
|
|
+ filename (str, optional): The name of the file. Defaults to "file".
|
|
|
+ mime (str, optional): The MIME type of the file. Defaults to "application/octet-stream".
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str: The MXC URI of the uploaded file.
|
|
|
+ """
|
|
|
+
|
|
|
+ bio = BytesIO(file)
|
|
|
+ bio.seek(0)
|
|
|
+
|
|
|
+ response, _ = await self.matrix_client.upload(
|
|
|
+ bio,
|
|
|
+ content_type=mime,
|
|
|
+ filename=filename,
|
|
|
+ filesize=len(file)
|
|
|
+ )
|
|
|
+
|
|
|
+ return response.content_uri
|
|
|
+
|
|
|
async def send_image(self, room: MatrixRoom, image: bytes, message: Optional[str] = None):
|
|
|
"""Send an image to a room.
|
|
|
|
|
@@ -361,14 +397,7 @@ class GPTBot:
|
|
|
self.logger.log(
|
|
|
f"Uploading - Image size: {width}x{height} pixels, MIME type: {mime}")
|
|
|
|
|
|
- bio.seek(0)
|
|
|
-
|
|
|
- response, _ = await self.matrix_client.upload(
|
|
|
- bio,
|
|
|
- content_type=mime,
|
|
|
- filename="image",
|
|
|
- filesize=len(image)
|
|
|
- )
|
|
|
+ content_uri = await self.upload_file(image, "image", mime)
|
|
|
|
|
|
self.logger.log("Uploaded image - sending message...")
|
|
|
|
|
@@ -381,7 +410,7 @@ class GPTBot:
|
|
|
"h": height,
|
|
|
},
|
|
|
"msgtype": "m.image",
|
|
|
- "url": response.content_uri
|
|
|
+ "url": content_uri
|
|
|
}
|
|
|
|
|
|
status = await self.matrix_client.room_send(
|
|
@@ -547,6 +576,26 @@ class GPTBot:
|
|
|
self.matrix_client.add_response_callback(
|
|
|
self.response_callback, Response)
|
|
|
|
|
|
+ # Set custom name / logo
|
|
|
+
|
|
|
+ if self.display_name:
|
|
|
+ self.logger.log(f"Setting display name to {self.display_name}")
|
|
|
+ await self.matrix_client.set_displayname(self.display_name)
|
|
|
+ if self.logo:
|
|
|
+ self.logger.log("Setting avatar...")
|
|
|
+ logo_bio = BytesIO()
|
|
|
+ self.logo.save(logo_bio, format=self.logo.format)
|
|
|
+ uri = await self.upload_file(logo_bio.getvalue(), "logo", Image.MIME[self.logo.format])
|
|
|
+ self.logo_uri = uri
|
|
|
+
|
|
|
+ asyncio.create_task(self.matrix_client.set_avatar(uri))
|
|
|
+
|
|
|
+ for room in self.matrix_client.rooms.keys():
|
|
|
+ self.logger.log(f"Setting avatar for {room}...", "debug")
|
|
|
+ asyncio.create_task(self.matrix_client.room_put_state(room, "m.room.avatar", {
|
|
|
+ "url": uri
|
|
|
+ }, ""))
|
|
|
+
|
|
|
# Start syncing events
|
|
|
self.logger.log("Starting sync loop...")
|
|
|
try:
|