|
@@ -53,10 +53,11 @@ from .openai import OpenAI
|
|
|
from .wolframalpha import WolframAlpha
|
|
|
from .trackingmore import TrackingMore
|
|
|
|
|
|
+
|
|
|
class GPTBot:
|
|
|
# Default values
|
|
|
database: Optional[sqlite3.Connection] = None
|
|
|
- crypto_store_path: Optional[str|Path] = None
|
|
|
+ crypto_store_path: Optional[str | Path] = None
|
|
|
# Default name of rooms created by the bot
|
|
|
display_name = default_room_name = "GPTBot"
|
|
|
default_system_message: str = "You are a helpful assistant."
|
|
@@ -93,51 +94,64 @@ class GPTBot:
|
|
|
bot = cls()
|
|
|
|
|
|
# Set the database connection
|
|
|
- bot.database = sqlite3.connect(
|
|
|
- config["Database"]["Path"]) if "Database" in config and "Path" in config["Database"] else None
|
|
|
+ bot.database = (
|
|
|
+ sqlite3.connect(config["Database"]["Path"])
|
|
|
+ if "Database" in config and "Path" in config["Database"]
|
|
|
+ else None
|
|
|
+ )
|
|
|
|
|
|
- bot.crypto_store_path = config["Database"]["CryptoStore"] if "Database" in config and "CryptoStore" in config["Database"] else None
|
|
|
+ bot.crypto_store_path = (
|
|
|
+ config["Database"]["CryptoStore"]
|
|
|
+ if "Database" in config and "CryptoStore" in config["Database"]
|
|
|
+ else None
|
|
|
+ )
|
|
|
|
|
|
# Override default values
|
|
|
if "GPTBot" in config:
|
|
|
bot.operator = config["GPTBot"].get("Operator", bot.operator)
|
|
|
bot.default_room_name = config["GPTBot"].get(
|
|
|
- "DefaultRoomName", bot.default_room_name)
|
|
|
+ "DefaultRoomName", bot.default_room_name
|
|
|
+ )
|
|
|
bot.default_system_message = config["GPTBot"].get(
|
|
|
- "SystemMessage", bot.default_system_message)
|
|
|
+ "SystemMessage", bot.default_system_message
|
|
|
+ )
|
|
|
bot.force_system_message = config["GPTBot"].getboolean(
|
|
|
- "ForceSystemMessage", bot.force_system_message)
|
|
|
+ "ForceSystemMessage", bot.force_system_message
|
|
|
+ )
|
|
|
bot.debug = config["GPTBot"].getboolean("Debug", bot.debug)
|
|
|
|
|
|
- logo_path = config["GPTBot"].get("Logo", str(
|
|
|
- Path(__file__).parent.parent / "assets/logo.png"))
|
|
|
+ if "LogLevel" in config["GPTBot"]:
|
|
|
+ bot.logger = Logger(config["GPTBot"]["LogLevel"])
|
|
|
+
|
|
|
+ logo_path = config["GPTBot"].get(
|
|
|
+ "Logo", str(Path(__file__).parent.parent / "assets/logo.png")
|
|
|
+ )
|
|
|
|
|
|
- bot.logger.log(f"Loading logo from {logo_path}")
|
|
|
+ bot.logger.log(f"Loading logo from {logo_path}", "debug")
|
|
|
|
|
|
if Path(logo_path).exists() and Path(logo_path).is_file():
|
|
|
bot.logo = Image.open(logo_path)
|
|
|
|
|
|
- bot.display_name = config["GPTBot"].get(
|
|
|
- "DisplayName", bot.display_name)
|
|
|
+ bot.display_name = config["GPTBot"].get("DisplayName", bot.display_name)
|
|
|
|
|
|
if "AllowedUsers" in config["GPTBot"]:
|
|
|
bot.allowed_users = json.loads(config["GPTBot"]["AllowedUsers"])
|
|
|
|
|
|
bot.chat_api = bot.image_api = bot.classification_api = OpenAI(
|
|
|
- config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger)
|
|
|
+ config["OpenAI"]["APIKey"], config["OpenAI"].get("Model"), bot.logger
|
|
|
+ )
|
|
|
bot.max_tokens = config["OpenAI"].getint("MaxTokens", bot.max_tokens)
|
|
|
- bot.max_messages = config["OpenAI"].getint(
|
|
|
- "MaxMessages", bot.max_messages)
|
|
|
+ bot.max_messages = config["OpenAI"].getint("MaxMessages", bot.max_messages)
|
|
|
|
|
|
# Set up WolframAlpha
|
|
|
if "WolframAlpha" in config:
|
|
|
bot.calculation_api = WolframAlpha(
|
|
|
- config["WolframAlpha"]["APIKey"], bot.logger)
|
|
|
+ config["WolframAlpha"]["APIKey"], bot.logger
|
|
|
+ )
|
|
|
|
|
|
# Set up TrackingMore
|
|
|
if "TrackingMore" in config:
|
|
|
- bot.parcel_api = TrackingMore(
|
|
|
- config["TrackingMore"]["APIKey"], bot.logger)
|
|
|
+ bot.parcel_api = TrackingMore(config["TrackingMore"]["APIKey"], bot.logger)
|
|
|
|
|
|
# Set up the Matrix client
|
|
|
|
|
@@ -182,17 +196,21 @@ class GPTBot:
|
|
|
room_id = room.room_id if isinstance(room, MatrixRoom) else room
|
|
|
|
|
|
self.logger.log(
|
|
|
- f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...")
|
|
|
+ f"Fetching last {2*n} messages from room {room_id} (starting at {self.sync_token})...",
|
|
|
+ "debug",
|
|
|
+ )
|
|
|
|
|
|
response = await self.matrix_client.room_messages(
|
|
|
room_id=room_id,
|
|
|
start=self.sync_token,
|
|
|
- limit=2*n,
|
|
|
+ limit=2 * n,
|
|
|
)
|
|
|
|
|
|
if isinstance(response, RoomMessagesError):
|
|
|
raise Exception(
|
|
|
- f"Error fetching messages: {response.message} (status code {response.status_code})", "error")
|
|
|
+ f"Error fetching messages: {response.message} (status code {response.status_code})",
|
|
|
+ "error",
|
|
|
+ )
|
|
|
|
|
|
for event in response.chunk:
|
|
|
if len(messages) >= n:
|
|
@@ -202,34 +220,48 @@ class GPTBot:
|
|
|
event = await self.matrix_client.decrypt_event(event)
|
|
|
except (GroupEncryptionError, EncryptionError):
|
|
|
self.logger.log(
|
|
|
- f"Could not decrypt message {event.event_id} in room {room_id}", "error")
|
|
|
+ f"Could not decrypt message {event.event_id} in room {room_id}",
|
|
|
+ "error",
|
|
|
+ )
|
|
|
continue
|
|
|
if isinstance(event, (RoomMessageText, RoomMessageNotice)):
|
|
|
if event.body.startswith("!gptbot ignoreolder"):
|
|
|
break
|
|
|
- if (not event.body.startswith("!")) or (event.body.startswith("!gptbot")):
|
|
|
+ if (not event.body.startswith("!")) or (
|
|
|
+ event.body.startswith("!gptbot")
|
|
|
+ ):
|
|
|
messages.append(event)
|
|
|
|
|
|
- self.logger.log(f"Found {len(messages)} messages (limit: {n})")
|
|
|
+ self.logger.log(f"Found {len(messages)} messages (limit: {n})", "debug")
|
|
|
|
|
|
# Reverse the list so that messages are in chronological order
|
|
|
return messages[::-1]
|
|
|
|
|
|
- def _truncate(self, messages: list, max_tokens: Optional[int] = None,
|
|
|
- model: Optional[str] = None, system_message: Optional[str] = None):
|
|
|
+ def _truncate(
|
|
|
+ self,
|
|
|
+ messages: list,
|
|
|
+ max_tokens: Optional[int] = None,
|
|
|
+ model: Optional[str] = None,
|
|
|
+ system_message: Optional[str] = None,
|
|
|
+ ):
|
|
|
max_tokens = max_tokens or self.max_tokens
|
|
|
model = model or self.chat_api.chat_model
|
|
|
- system_message = self.default_system_message if system_message is None else system_message
|
|
|
+ system_message = (
|
|
|
+ self.default_system_message if system_message is None else system_message
|
|
|
+ )
|
|
|
|
|
|
encoding = tiktoken.encoding_for_model(model)
|
|
|
total_tokens = 0
|
|
|
|
|
|
- system_message_tokens = 0 if not system_message else (
|
|
|
- len(encoding.encode(system_message)) + 1)
|
|
|
+ system_message_tokens = (
|
|
|
+ 0 if not system_message else (len(encoding.encode(system_message)) + 1)
|
|
|
+ )
|
|
|
|
|
|
if system_message_tokens > max_tokens:
|
|
|
self.logger.log(
|
|
|
- f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed", "error")
|
|
|
+ f"System message is too long to fit within token limit ({system_message_tokens} tokens) - cannot proceed",
|
|
|
+ "error",
|
|
|
+ )
|
|
|
return []
|
|
|
|
|
|
total_tokens += system_message_tokens
|
|
@@ -279,7 +311,9 @@ class GPTBot:
|
|
|
"""
|
|
|
|
|
|
self.logger.log(
|
|
|
- f"Received command {event.body} from {event.sender} in room {room.room_id}")
|
|
|
+ f"Received command {event.body} from {event.sender} in room {room.room_id}",
|
|
|
+ "debug",
|
|
|
+ )
|
|
|
command = event.body.split()[1] if event.body.split()[1:] else None
|
|
|
|
|
|
await COMMANDS.get(command, COMMANDS[None])(room, event, self)
|
|
@@ -297,7 +331,9 @@ class GPTBot:
|
|
|
|
|
|
with closing(self.database.cursor()) as cursor:
|
|
|
cursor.execute(
|
|
|
- "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_classification"))
|
|
|
+ "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
|
|
|
+ (room_id, "use_classification"),
|
|
|
+ )
|
|
|
result = cursor.fetchone()
|
|
|
|
|
|
return False if not result else bool(int(result[0]))
|
|
@@ -310,10 +346,13 @@ class GPTBot:
|
|
|
await callback(room, event, self)
|
|
|
except Exception as e:
|
|
|
self.logger.log(
|
|
|
- f"Error in event callback for {event.__class__}: {e}", "error")
|
|
|
+ f"Error in event callback for {event.__class__}: {e}", "error"
|
|
|
+ )
|
|
|
|
|
|
if self.debug:
|
|
|
- await self.send_message(room, f"Error: {e}\n\n```\n{traceback.format_exc()}\n```", True)
|
|
|
+ await self.send_message(
|
|
|
+ room, f"Error: {e}\n\n```\n{traceback.format_exc()}\n```", True
|
|
|
+ )
|
|
|
|
|
|
def user_is_allowed(self, user_id: str) -> bool:
|
|
|
"""Check if a user is allowed to use the bot.
|
|
@@ -326,10 +365,14 @@ class GPTBot:
|
|
|
"""
|
|
|
|
|
|
return (
|
|
|
- user_id in self.allowed_users or
|
|
|
- f"*:{user_id.split(':')[1]}" in self.allowed_users or
|
|
|
- f"@*:{user_id.split(':')[1]}" in self.allowed_users
|
|
|
- ) if self.allowed_users else True
|
|
|
+ (
|
|
|
+ user_id in self.allowed_users
|
|
|
+ or f"*:{user_id.split(':')[1]}" in self.allowed_users
|
|
|
+ or f"@*:{user_id.split(':')[1]}" in self.allowed_users
|
|
|
+ )
|
|
|
+ if self.allowed_users
|
|
|
+ else True
|
|
|
+ )
|
|
|
|
|
|
async def event_callback(self, room: MatrixRoom, event: Event):
|
|
|
"""Callback for events.
|
|
@@ -349,8 +392,8 @@ class GPTBot:
|
|
|
"m.room.message",
|
|
|
{
|
|
|
"msgtype": "m.notice",
|
|
|
- "body": f"You are not allowed to use this bot. Please contact {self.operator} for more information."
|
|
|
- }
|
|
|
+ "body": f"You are not allowed to use this bot. Please contact {self.operator} for more information.",
|
|
|
+ },
|
|
|
)
|
|
|
return
|
|
|
|
|
@@ -369,7 +412,9 @@ class GPTBot:
|
|
|
|
|
|
with closing(self.database.cursor()) as cursor:
|
|
|
cursor.execute(
|
|
|
- "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_timing"))
|
|
|
+ "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
|
|
|
+ (room_id, "use_timing"),
|
|
|
+ )
|
|
|
result = cursor.fetchone()
|
|
|
|
|
|
return False if not result else bool(int(result[0]))
|
|
@@ -392,7 +437,9 @@ class GPTBot:
|
|
|
for invite in invites.keys():
|
|
|
if invite in self.room_ignore_list:
|
|
|
self.logger.log(
|
|
|
- f"Ignoring invite to room {invite} (room is in ignore list)")
|
|
|
+ f"Ignoring invite to room {invite} (room is in ignore list)",
|
|
|
+ "debug",
|
|
|
+ )
|
|
|
continue
|
|
|
|
|
|
self.logger.log(f"Accepting invite to room {invite}")
|
|
@@ -401,16 +448,25 @@ class GPTBot:
|
|
|
|
|
|
if isinstance(response, JoinError):
|
|
|
self.logger.log(
|
|
|
- f"Error joining room {invite}: {response.message}. Not trying again.", "error")
|
|
|
+ f"Error joining room {invite}: {response.message}. Not trying again.",
|
|
|
+ "error",
|
|
|
+ )
|
|
|
|
|
|
leave_response = await self.matrix_client.room_leave(invite)
|
|
|
|
|
|
if isinstance(leave_response, RoomLeaveError):
|
|
|
self.logger.log(
|
|
|
- f"Error leaving room {invite}: {leave_response.message}", "error")
|
|
|
+ f"Error leaving room {invite}: {leave_response.message}",
|
|
|
+ "error",
|
|
|
+ )
|
|
|
self.room_ignore_list.append(invite)
|
|
|
|
|
|
- async def upload_file(self, file: bytes, filename: str = "file", mime: str = "application/octet-stream") -> str:
|
|
|
+ async def upload_file(
|
|
|
+ self,
|
|
|
+ file: bytes,
|
|
|
+ filename: str = "file",
|
|
|
+ mime: str = "application/octet-stream",
|
|
|
+ ) -> str:
|
|
|
"""Upload a file to the homeserver.
|
|
|
|
|
|
Args:
|
|
@@ -426,15 +482,14 @@ class GPTBot:
|
|
|
bio.seek(0)
|
|
|
|
|
|
response, _ = await self.matrix_client.upload(
|
|
|
- bio,
|
|
|
- content_type=mime,
|
|
|
- filename=filename,
|
|
|
- filesize=len(file)
|
|
|
+ bio, content_type=mime, filename=filename, filesize=len(file)
|
|
|
)
|
|
|
|
|
|
return response.content_uri
|
|
|
|
|
|
- async def send_image(self, room: MatrixRoom, image: bytes, message: Optional[str] = None):
|
|
|
+ async def send_image(
|
|
|
+ self, room: MatrixRoom, image: bytes, message: Optional[str] = None
|
|
|
+ ):
|
|
|
"""Send an image to a room.
|
|
|
|
|
|
Args:
|
|
@@ -444,7 +499,8 @@ class GPTBot:
|
|
|
"""
|
|
|
|
|
|
self.logger.log(
|
|
|
- f"Sending image of size {len(image)} bytes to room {room.room_id}")
|
|
|
+ f"Sending image of size {len(image)} bytes to room {room.room_id}", "debug"
|
|
|
+ )
|
|
|
|
|
|
bio = BytesIO(image)
|
|
|
img = Image.open(bio)
|
|
@@ -453,11 +509,13 @@ class GPTBot:
|
|
|
(width, height) = img.size
|
|
|
|
|
|
self.logger.log(
|
|
|
- f"Uploading - Image size: {width}x{height} pixels, MIME type: {mime}")
|
|
|
+ f"Uploading - Image size: {width}x{height} pixels, MIME type: {mime}",
|
|
|
+ "debug",
|
|
|
+ )
|
|
|
|
|
|
content_uri = await self.upload_file(image, "image", mime)
|
|
|
|
|
|
- self.logger.log("Uploaded image - sending message...")
|
|
|
+ self.logger.log("Uploaded image - sending message...", "debug")
|
|
|
|
|
|
content = {
|
|
|
"body": message or "",
|
|
@@ -468,20 +526,18 @@ class GPTBot:
|
|
|
"h": height,
|
|
|
},
|
|
|
"msgtype": "m.image",
|
|
|
- "url": content_uri
|
|
|
+ "url": content_uri,
|
|
|
}
|
|
|
|
|
|
status = await self.matrix_client.room_send(
|
|
|
- room.room_id,
|
|
|
- "m.room.message",
|
|
|
- content
|
|
|
+ room.room_id, "m.room.message", content
|
|
|
)
|
|
|
|
|
|
- self.logger.log(str(status), "debug")
|
|
|
-
|
|
|
- self.logger.log("Sent image")
|
|
|
+ self.logger.log("Sent image", "debug")
|
|
|
|
|
|
- async def send_message(self, room: MatrixRoom | str, message: str, notice: bool = False):
|
|
|
+ async def send_message(
|
|
|
+ self, room: MatrixRoom | str, message: str, notice: bool = False
|
|
|
+ ):
|
|
|
"""Send a message to a room.
|
|
|
|
|
|
Args:
|
|
@@ -498,8 +554,12 @@ class GPTBot:
|
|
|
|
|
|
msgtype = "m.notice" if notice else "m.text"
|
|
|
|
|
|
- msgcontent = {"msgtype": msgtype, "body": message,
|
|
|
- "format": "org.matrix.custom.html", "formatted_body": formatted_body}
|
|
|
+ msgcontent = {
|
|
|
+ "msgtype": msgtype,
|
|
|
+ "body": message,
|
|
|
+ "format": "org.matrix.custom.html",
|
|
|
+ "formatted_body": formatted_body,
|
|
|
+ }
|
|
|
|
|
|
content = None
|
|
|
|
|
@@ -507,7 +567,9 @@ class GPTBot:
|
|
|
try:
|
|
|
if not room.members_synced:
|
|
|
responses = []
|
|
|
- responses.append(await self.matrix_client.joined_members(room.room_id))
|
|
|
+ responses.append(
|
|
|
+ await self.matrix_client.joined_members(room.room_id)
|
|
|
+ )
|
|
|
|
|
|
if self.matrix_client.olm.should_share_group_session(room.room_id):
|
|
|
try:
|
|
@@ -521,12 +583,14 @@ class GPTBot:
|
|
|
|
|
|
if msgtype != "m.reaction":
|
|
|
response = self.matrix_client.encrypt(
|
|
|
- room.room_id, "m.room.message", msgcontent)
|
|
|
+ room.room_id, "m.room.message", msgcontent
|
|
|
+ )
|
|
|
msgtype, content = response
|
|
|
|
|
|
except Exception as e:
|
|
|
self.logger.log(
|
|
|
- f"Error encrypting message: {e} - sending unencrypted", "error")
|
|
|
+ f"Error encrypting message: {e} - sending unencrypted", "warning"
|
|
|
+ )
|
|
|
raise
|
|
|
|
|
|
if not content:
|
|
@@ -534,17 +598,24 @@ class GPTBot:
|
|
|
content = msgcontent
|
|
|
|
|
|
method, path, data = Api.room_send(
|
|
|
- self.matrix_client.access_token, room.room_id, msgtype, content, uuid.uuid4()
|
|
|
+ self.matrix_client.access_token,
|
|
|
+ room.room_id,
|
|
|
+ msgtype,
|
|
|
+ content,
|
|
|
+ uuid.uuid4(),
|
|
|
)
|
|
|
|
|
|
- response = await self.matrix_client._send(RoomSendResponse, method, path, data, (room.room_id,))
|
|
|
+ response = await self.matrix_client._send(
|
|
|
+ RoomSendResponse, method, path, data, (room.room_id,)
|
|
|
+ )
|
|
|
|
|
|
if isinstance(response, RoomSendError):
|
|
|
- self.logger.log(
|
|
|
- f"Error sending message: {response.message}", "error")
|
|
|
+ self.logger.log(f"Error sending message: {response.message}", "error")
|
|
|
return
|
|
|
|
|
|
- def log_api_usage(self, message: Event | str, room: MatrixRoom | str, api: str, tokens: int):
|
|
|
+ def log_api_usage(
|
|
|
+ self, message: Event | str, room: MatrixRoom | str, api: str, tokens: int
|
|
|
+ ):
|
|
|
"""Log API usage to the database.
|
|
|
|
|
|
Args:
|
|
@@ -565,7 +636,7 @@ class GPTBot:
|
|
|
|
|
|
self.database.execute(
|
|
|
"INSERT INTO token_usage (message_id, room_id, tokens, api, timestamp) VALUES (?, ?, ?, ?, ?)",
|
|
|
- (message, room, tokens, api, datetime.now())
|
|
|
+ (message, room, tokens, api, datetime.now()),
|
|
|
)
|
|
|
|
|
|
async def run(self):
|
|
@@ -587,7 +658,9 @@ class GPTBot:
|
|
|
IN_MEMORY = False
|
|
|
if not self.database:
|
|
|
self.logger.log(
|
|
|
- "No database connection set up, using in-memory database. Data will be lost on bot shutdown.")
|
|
|
+ "No database connection set up, using in-memory database. Data will be lost on bot shutdown.",
|
|
|
+ "warning",
|
|
|
+ )
|
|
|
IN_MEMORY = True
|
|
|
self.database = sqlite3.connect(":memory:")
|
|
|
|
|
@@ -596,8 +669,12 @@ class GPTBot:
|
|
|
try:
|
|
|
before, after = migrate(self.database)
|
|
|
except sqlite3.DatabaseError as e:
|
|
|
- self.logger.log(f"Error migrating database: {e}", "fatal")
|
|
|
- self.logger.log("If you have just updated the bot, the previous version of the database may be incompatible with this version. Please delete the database file and try again.", "fatal")
|
|
|
+ self.logger.log(f"Error migrating database: {e}", "critical")
|
|
|
+
|
|
|
+ self.logger.log(
|
|
|
+ "If you have just updated the bot, the previous version of the database may be incompatible with this version. Please delete the database file and try again.",
|
|
|
+ "critical",
|
|
|
+ )
|
|
|
exit(1)
|
|
|
|
|
|
if before != after:
|
|
@@ -607,66 +684,73 @@ class GPTBot:
|
|
|
|
|
|
if IN_MEMORY:
|
|
|
client_config = AsyncClientConfig(
|
|
|
- store_sync_tokens=True, encryption_enabled=False)
|
|
|
+ store_sync_tokens=True, encryption_enabled=False
|
|
|
+ )
|
|
|
else:
|
|
|
matrix_store = SqliteStore
|
|
|
client_config = AsyncClientConfig(
|
|
|
- store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
|
|
|
+ store_sync_tokens=True, encryption_enabled=True, store=matrix_store
|
|
|
+ )
|
|
|
self.matrix_client.config = client_config
|
|
|
self.matrix_client.store = matrix_store(
|
|
|
self.matrix_client.user_id,
|
|
|
self.matrix_client.device_id,
|
|
|
- self.crypto_store_path or ""
|
|
|
+ self.crypto_store_path or "",
|
|
|
)
|
|
|
|
|
|
self.matrix_client.olm = Olm(
|
|
|
self.matrix_client.user_id,
|
|
|
self.matrix_client.device_id,
|
|
|
- self.matrix_client.store
|
|
|
+ self.matrix_client.store,
|
|
|
)
|
|
|
|
|
|
- self.matrix_client.encrypted_rooms = self.matrix_client.store.load_encrypted_rooms()
|
|
|
+ self.matrix_client.encrypted_rooms = (
|
|
|
+ self.matrix_client.store.load_encrypted_rooms()
|
|
|
+ )
|
|
|
|
|
|
# Run initial sync (now includes joining rooms)
|
|
|
sync = await self.matrix_client.sync(timeout=30000)
|
|
|
if isinstance(sync, SyncResponse):
|
|
|
await self.response_callback(sync)
|
|
|
else:
|
|
|
- self.logger.log(f"Initial sync failed, aborting: {sync}", "error")
|
|
|
- return
|
|
|
+ self.logger.log(f"Initial sync failed, aborting: {sync}", "critical")
|
|
|
+ exit(1)
|
|
|
|
|
|
# Set up callbacks
|
|
|
|
|
|
self.matrix_client.add_event_callback(self.event_callback, Event)
|
|
|
- self.matrix_client.add_response_callback(
|
|
|
- self.response_callback, Response)
|
|
|
+ self.matrix_client.add_response_callback(self.response_callback, Response)
|
|
|
|
|
|
# Set custom name / logo
|
|
|
|
|
|
if self.display_name:
|
|
|
- self.logger.log(f"Setting display name to {self.display_name}")
|
|
|
- await self.matrix_client.set_displayname(self.display_name)
|
|
|
+ self.logger.log(f"Setting display name to {self.display_name}", "debug")
|
|
|
+ asyncio.create_task(self.matrix_client.set_displayname(self.display_name))
|
|
|
if self.logo:
|
|
|
self.logger.log("Setting avatar...")
|
|
|
logo_bio = BytesIO()
|
|
|
self.logo.save(logo_bio, format=self.logo.format)
|
|
|
- uri = await self.upload_file(logo_bio.getvalue(), "logo", Image.MIME[self.logo.format])
|
|
|
+ uri = await self.upload_file(
|
|
|
+ logo_bio.getvalue(), "logo", Image.MIME[self.logo.format]
|
|
|
+ )
|
|
|
self.logo_uri = uri
|
|
|
|
|
|
asyncio.create_task(self.matrix_client.set_avatar(uri))
|
|
|
|
|
|
for room in self.matrix_client.rooms.keys():
|
|
|
self.logger.log(f"Setting avatar for {room}...", "debug")
|
|
|
- asyncio.create_task(self.matrix_client.room_put_state(room, "m.room.avatar", {
|
|
|
- "url": uri
|
|
|
- }, ""))
|
|
|
+ asyncio.create_task(
|
|
|
+ self.matrix_client.room_put_state(
|
|
|
+ room, "m.room.avatar", {"url": uri}, ""
|
|
|
+ )
|
|
|
+ )
|
|
|
|
|
|
# Start syncing events
|
|
|
- self.logger.log("Starting sync loop...")
|
|
|
+ self.logger.log("Starting sync loop...", "warning")
|
|
|
try:
|
|
|
await self.matrix_client.sync_forever(timeout=30000)
|
|
|
finally:
|
|
|
- self.logger.log("Syncing one last time...")
|
|
|
+ self.logger.log("Syncing one last time...", "warning")
|
|
|
await self.matrix_client.sync(timeout=30000)
|
|
|
|
|
|
async def create_space(self, name, visibility=RoomVisibility.private) -> str:
|
|
@@ -681,16 +765,18 @@ class GPTBot:
|
|
|
"""
|
|
|
|
|
|
response = await self.matrix_client.room_create(
|
|
|
- name=name, visibility=visibility, space=True)
|
|
|
+ name=name, visibility=visibility, space=True
|
|
|
+ )
|
|
|
|
|
|
if isinstance(response, RoomCreateError):
|
|
|
- self.logger.log(
|
|
|
- f"Error creating space: {response.message}", "error")
|
|
|
+ self.logger.log(f"Error creating space: {response.message}", "error")
|
|
|
return
|
|
|
|
|
|
return response.room_id
|
|
|
|
|
|
- async def add_rooms_to_space(self, space: MatrixRoom | str, rooms: List[MatrixRoom | str]):
|
|
|
+ async def add_rooms_to_space(
|
|
|
+ self, space: MatrixRoom | str, rooms: List[MatrixRoom | str]
|
|
|
+ ):
|
|
|
"""Add rooms to a space.
|
|
|
|
|
|
Args:
|
|
@@ -706,20 +792,26 @@ class GPTBot:
|
|
|
room = room.room_id
|
|
|
|
|
|
if space == room:
|
|
|
- self.logger.log(
|
|
|
- f"Refusing to add {room} to itself", "warning")
|
|
|
+ self.logger.log(f"Refusing to add {room} to itself", "warning")
|
|
|
continue
|
|
|
|
|
|
- self.logger.log(f"Adding {room} to {space}...")
|
|
|
+ self.logger.log(f"Adding {room} to {space}...", "debug")
|
|
|
|
|
|
- await self.matrix_client.room_put_state(space, "m.space.child", {
|
|
|
- "via": [room.split(":")[1], space.split(":")[1]],
|
|
|
- }, room)
|
|
|
+ await self.matrix_client.room_put_state(
|
|
|
+ space,
|
|
|
+ "m.space.child",
|
|
|
+ {
|
|
|
+ "via": [room.split(":")[1], space.split(":")[1]],
|
|
|
+ },
|
|
|
+ room,
|
|
|
+ )
|
|
|
|
|
|
- await self.matrix_client.room_put_state(room, "m.room.parent", {
|
|
|
- "via": [space.split(":")[1], room.split(":")[1]],
|
|
|
- "canonical": True
|
|
|
- }, space)
|
|
|
+ await self.matrix_client.room_put_state(
|
|
|
+ room,
|
|
|
+ "m.room.parent",
|
|
|
+ {"via": [space.split(":")[1], room.split(":")[1]], "canonical": True},
|
|
|
+ space,
|
|
|
+ )
|
|
|
|
|
|
def respond_to_room_messages(self, room: MatrixRoom | str) -> bool:
|
|
|
"""Check whether the bot should respond to all messages sent in a room.
|
|
@@ -736,12 +828,16 @@ class GPTBot:
|
|
|
|
|
|
with closing(self.database.cursor()) as cursor:
|
|
|
cursor.execute(
|
|
|
- "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room, "always_reply"))
|
|
|
+ "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
|
|
|
+ (room, "always_reply"),
|
|
|
+ )
|
|
|
result = cursor.fetchone()
|
|
|
|
|
|
return True if not result else bool(int(result[0]))
|
|
|
|
|
|
- async def process_query(self, room: MatrixRoom, event: RoomMessageText, from_chat_command: bool = False):
|
|
|
+ async def process_query(
|
|
|
+ self, room: MatrixRoom, event: RoomMessageText, from_chat_command: bool = False
|
|
|
+ ):
|
|
|
"""Process a query message. Generates a response and sends it to the room.
|
|
|
|
|
|
Args:
|
|
@@ -750,7 +846,11 @@ class GPTBot:
|
|
|
from_chat_command (bool, optional): Whether the query was sent via the `!gptbot chat` command. Defaults to False.
|
|
|
"""
|
|
|
|
|
|
- if not (from_chat_command or self.respond_to_room_messages(room) or self.matrix_client.user_id in event.body):
|
|
|
+ if not (
|
|
|
+ from_chat_command
|
|
|
+ or self.respond_to_room_messages(room)
|
|
|
+ or self.matrix_client.user_id in event.body
|
|
|
+ ):
|
|
|
return
|
|
|
|
|
|
await self.matrix_client.room_typing(room.room_id, True)
|
|
@@ -760,18 +860,26 @@ class GPTBot:
|
|
|
if (not from_chat_command) and self.room_uses_classification(room):
|
|
|
try:
|
|
|
classification, tokens = await self.classification_api.classify_message(
|
|
|
- event.body, room.room_id)
|
|
|
+ event.body, room.room_id
|
|
|
+ )
|
|
|
except Exception as e:
|
|
|
self.logger.log(f"Error classifying message: {e}", "error")
|
|
|
await self.send_message(
|
|
|
- room, "Something went wrong. Please try again.", True)
|
|
|
+ room, "Something went wrong. Please try again.", True
|
|
|
+ )
|
|
|
return
|
|
|
|
|
|
self.log_api_usage(
|
|
|
- event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens)
|
|
|
+ event,
|
|
|
+ room,
|
|
|
+ f"{self.classification_api.api_code}-{self.classification_api.classification_api}",
|
|
|
+ tokens,
|
|
|
+ )
|
|
|
|
|
|
if not classification["type"] == "chat":
|
|
|
- event.body = f"!gptbot {classification['type']} {classification['prompt']}"
|
|
|
+ event.body = (
|
|
|
+ f"!gptbot {classification['type']} {classification['prompt']}"
|
|
|
+ )
|
|
|
await self.process_command(room, event)
|
|
|
return
|
|
|
|
|
@@ -780,7 +888,8 @@ class GPTBot:
|
|
|
except Exception as e:
|
|
|
self.logger.log(f"Error getting last messages: {e}", "error")
|
|
|
await self.send_message(
|
|
|
- room, "Something went wrong. Please try again.", True)
|
|
|
+ room, "Something went wrong. Please try again.", True
|
|
|
+ )
|
|
|
return
|
|
|
|
|
|
system_message = self.get_system_message(room)
|
|
@@ -788,7 +897,9 @@ class GPTBot:
|
|
|
chat_messages = [{"role": "system", "content": system_message}]
|
|
|
|
|
|
for message in last_messages:
|
|
|
- role = "assistant" if message.sender == self.matrix_client.user_id else "user"
|
|
|
+ role = (
|
|
|
+ "assistant" if message.sender == self.matrix_client.user_id else "user"
|
|
|
+ )
|
|
|
if not message.event_id == event.event_id:
|
|
|
chat_messages.append({"role": role, "content": message.body})
|
|
|
|
|
@@ -796,20 +907,27 @@ class GPTBot:
|
|
|
|
|
|
# Truncate messages to fit within the token limit
|
|
|
truncated_messages = self._truncate(
|
|
|
- chat_messages, self.max_tokens - 1, system_message=system_message)
|
|
|
+ chat_messages, self.max_tokens - 1, system_message=system_message
|
|
|
+ )
|
|
|
|
|
|
try:
|
|
|
response, tokens_used = await self.chat_api.generate_chat_response(
|
|
|
- chat_messages, user=room.room_id)
|
|
|
+ chat_messages, user=room.room_id
|
|
|
+ )
|
|
|
except Exception as e:
|
|
|
self.logger.log(f"Error generating response: {e}", "error")
|
|
|
await self.send_message(
|
|
|
- room, "Something went wrong. Please try again.", True)
|
|
|
+ room, "Something went wrong. Please try again.", True
|
|
|
+ )
|
|
|
return
|
|
|
|
|
|
if response:
|
|
|
self.log_api_usage(
|
|
|
- event, room, f"{self.chat_api.api_code}-{self.chat_api.chat_api}", tokens_used)
|
|
|
+ event,
|
|
|
+ room,
|
|
|
+ f"{self.chat_api.api_code}-{self.chat_api.chat_api}",
|
|
|
+ tokens_used,
|
|
|
+ )
|
|
|
|
|
|
self.logger.log(f"Sending response to room {room.room_id}...")
|
|
|
|
|
@@ -821,7 +939,8 @@ class GPTBot:
|
|
|
# Send a notice to the room if there was an error
|
|
|
self.logger.log("Didn't get a response from GPT API", "error")
|
|
|
await self.send_message(
|
|
|
- room, "Something went wrong. Please try again.", True)
|
|
|
+ room, "Something went wrong. Please try again.", True
|
|
|
+ )
|
|
|
|
|
|
await self.matrix_client.room_typing(room.room_id, False)
|
|
|
|
|
@@ -845,12 +964,14 @@ class GPTBot:
|
|
|
with closing(self.database.cursor()) as cur:
|
|
|
cur.execute(
|
|
|
"SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
|
|
|
- (room_id, "system_message")
|
|
|
+ (room_id, "system_message"),
|
|
|
)
|
|
|
system_message = cur.fetchone()
|
|
|
|
|
|
- complete = ((default if ((not system_message) or self.force_system_message) else "") + (
|
|
|
- "\n\n" + system_message[0] if system_message else "")).strip()
|
|
|
+ complete = (
|
|
|
+ (default if ((not system_message) or self.force_system_message) else "")
|
|
|
+ + ("\n\n" + system_message[0] if system_message else "")
|
|
|
+ ).strip()
|
|
|
|
|
|
return complete
|
|
|
|