|
@@ -235,6 +235,24 @@ class GPTBot:
|
|
|
|
|
|
await COMMANDS.get(command, COMMANDS[None])(room, event, self)
|
|
await COMMANDS.get(command, COMMANDS[None])(room, event, self)
|
|
|
|
|
|
|
|
+ def room_uses_classification(self, room: MatrixRoom | int) -> bool:
|
|
|
|
+ """Check if a room uses classification.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ room (MatrixRoom): The room to check.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ bool: Whether the room uses classification.
|
|
|
|
+ """
|
|
|
|
+ room_id = room.room_id if isinstance(room, MatrixRoom) else room
|
|
|
|
+
|
|
|
|
+ with self.database.cursor() as cursor:
|
|
|
|
+ cursor.execute(
|
|
|
|
+ "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?", (room_id, "use_classification"))
|
|
|
|
+ result = cursor.fetchone()
|
|
|
|
+
|
|
|
|
+ return False if not result else bool(int(result[0]))
|
|
|
|
+
|
|
async def event_callback(self, room: MatrixRoom, event: Event):
|
|
async def event_callback(self, room: MatrixRoom, event: Event):
|
|
self.logger.log("Received event: " + str(event.event_id), "debug")
|
|
self.logger.log("Received event: " + str(event.event_id), "debug")
|
|
try:
|
|
try:
|
|
@@ -456,11 +474,21 @@ class GPTBot:
|
|
self.logger.log("Syncing one last time...")
|
|
self.logger.log("Syncing one last time...")
|
|
await self.matrix_client.sync(timeout=30000)
|
|
await self.matrix_client.sync(timeout=30000)
|
|
|
|
|
|
- async def process_query(self, room: MatrixRoom, event: RoomMessageText):
|
|
|
|
|
|
+ async def process_query(self, room: MatrixRoom, event: RoomMessageText, allow_classify: bool = True):
|
|
await self.matrix_client.room_typing(room.room_id, True)
|
|
await self.matrix_client.room_typing(room.room_id, True)
|
|
|
|
|
|
await self.matrix_client.room_read_markers(room.room_id, event.event_id)
|
|
await self.matrix_client.room_read_markers(room.room_id, event.event_id)
|
|
|
|
|
|
|
|
+ if allow_classify and self.room_uses_classification(room):
|
|
|
|
+ classification, tokens = self.classification_api.classify_message(event.body, room.room_id)
|
|
|
|
+
|
|
|
|
+ self.log_api_usage(event, room, f"{self.classification_api.api_code}-{self.classification_api.classification_api}", tokens)
|
|
|
|
+
|
|
|
|
+ if not classification["type"] == "chat":
|
|
|
|
+ event.body = f"!gptbot {classification['type']} {classification['prompt']}"
|
|
|
|
+ await self.process_command(room, event)
|
|
|
|
+ return
|
|
|
|
+
|
|
try:
|
|
try:
|
|
last_messages = await self._last_n_messages(room.room_id, 20)
|
|
last_messages = await self._last_n_messages(room.room_id, 20)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -520,8 +548,8 @@ class GPTBot:
|
|
|
|
|
|
with self.database.cursor() as cur:
|
|
with self.database.cursor() as cur:
|
|
cur.execute(
|
|
cur.execute(
|
|
- "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
|
|
|
|
- (room_id,)
|
|
|
|
|
|
+ "SELECT value FROM room_settings WHERE room_id = ? AND setting = ?",
|
|
|
|
+ (room_id, "system_message")
|
|
)
|
|
)
|
|
system_message = cur.fetchone()
|
|
system_message = cur.fetchone()
|
|
|
|
|