|
@@ -23,6 +23,7 @@ from typing import List, Dict, Union, Optional
|
|
|
|
|
|
from commands import COMMANDS
|
|
from commands import COMMANDS
|
|
from classes import DuckDBStore
|
|
from classes import DuckDBStore
|
|
|
|
+from migrations import MIGRATIONS
|
|
|
|
|
|
|
|
|
|
def logging(message: str, log_level: str = "info"):
|
|
def logging(message: str, log_level: str = "info"):
|
|
@@ -35,6 +36,7 @@ CONTEXT = {
|
|
"database": False,
|
|
"database": False,
|
|
"default_room_name": "GPTBot",
|
|
"default_room_name": "GPTBot",
|
|
"system_message": "You are a helpful assistant.",
|
|
"system_message": "You are a helpful assistant.",
|
|
|
|
+ "force_system_message": False,
|
|
"max_tokens": 3000,
|
|
"max_tokens": 3000,
|
|
"max_messages": 20,
|
|
"max_messages": 20,
|
|
"model": "gpt-3.5-turbo",
|
|
"model": "gpt-3.5-turbo",
|
|
@@ -48,6 +50,8 @@ async def gpt_query(messages: list, model: Optional[str] = None):
|
|
model = model or CONTEXT["model"]
|
|
model = model or CONTEXT["model"]
|
|
|
|
|
|
logging(f"Querying GPT with {len(messages)} messages")
|
|
logging(f"Querying GPT with {len(messages)} messages")
|
|
|
|
+ logging(messages, "debug")
|
|
|
|
+
|
|
try:
|
|
try:
|
|
response = openai.ChatCompletion.create(
|
|
response = openai.ChatCompletion.create(
|
|
model=model,
|
|
model=model,
|
|
@@ -143,8 +147,9 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
|
|
|
|
|
client = kwargs.get("client") or CONTEXT["client"]
|
|
client = kwargs.get("client") or CONTEXT["client"]
|
|
database = kwargs.get("database") or CONTEXT["database"]
|
|
database = kwargs.get("database") or CONTEXT["database"]
|
|
- system_message = kwargs.get("system_message") or CONTEXT["system_message"]
|
|
|
|
max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
|
|
max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
|
|
|
|
+ system_message = kwargs.get("system_message") or CONTEXT["system_message"]
|
|
|
|
+ force_system_message = kwargs.get("force_system_message") or CONTEXT["force_system_message"]
|
|
|
|
|
|
await client.room_typing(room.room_id, True)
|
|
await client.room_typing(room.room_id, True)
|
|
|
|
|
|
@@ -152,6 +157,12 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
|
|
|
|
|
last_messages = await fetch_last_n_messages(room.room_id, 20)
|
|
last_messages = await fetch_last_n_messages(room.room_id, 20)
|
|
|
|
|
|
|
|
+ system_message = get_system_message(room, {
|
|
|
|
+ "database": database,
|
|
|
|
+ "system_message": system_message,
|
|
|
|
+ "force_system_message": force_system_message,
|
|
|
|
+ })
|
|
|
|
+
|
|
chat_messages = [{"role": "system", "content": system_message}]
|
|
chat_messages = [{"role": "system", "content": system_message}]
|
|
|
|
|
|
for message in last_messages:
|
|
for message in last_messages:
|
|
@@ -163,8 +174,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
|
|
|
|
|
|
# Truncate messages to fit within the token limit
|
|
# Truncate messages to fit within the token limit
|
|
truncated_messages = truncate_messages_to_fit_tokens(
|
|
truncated_messages = truncate_messages_to_fit_tokens(
|
|
- chat_messages, max_tokens - 1)
|
|
|
|
-
|
|
|
|
|
|
+ chat_messages, max_tokens - 1, system_message=system_message)
|
|
response, tokens_used = await gpt_query(truncated_messages)
|
|
response, tokens_used = await gpt_query(truncated_messages)
|
|
|
|
|
|
if response:
|
|
if response:
|
|
@@ -204,10 +214,29 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt
|
|
|
|
|
|
if message:
|
|
if message:
|
|
room_id, event, content = message
|
|
room_id, event, content = message
|
|
|
|
+ rooms = await context["client"].joined_rooms()
|
|
await send_message(context["client"].rooms[room_id], content["body"],
|
|
await send_message(context["client"].rooms[room_id], content["body"],
|
|
True if content["msgtype"] == "m.notice" else False, context["client"])
|
|
True if content["msgtype"] == "m.notice" else False, context["client"])
|
|
|
|
|
|
|
|
|
|
|
|
+def get_system_message(room: MatrixRoom, context: Optional[dict]) -> str:
|
|
|
|
+ context = context or CONTEXT
|
|
|
|
+
|
|
|
|
+ default = context.get("system_message")
|
|
|
|
+
|
|
|
|
+ with context["database"].cursor() as cur:
|
|
|
|
+ cur.execute(
|
|
|
|
+ "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
|
|
|
|
+ (room.room_id,)
|
|
|
|
+ )
|
|
|
|
+ system_message = cur.fetchone()
|
|
|
|
+
|
|
|
|
+ complete = ((default if ((not system_message) or context["force_system_message"]) else "") + (
|
|
|
|
+ "\n\n" + system_message[0] if system_message else "")).strip()
|
|
|
|
+
|
|
|
|
+ return complete
|
|
|
|
+
|
|
|
|
+
|
|
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
|
|
async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
|
|
context = kwargs.get("context") or CONTEXT
|
|
context = kwargs.get("context") or CONTEXT
|
|
|
|
|
|
@@ -286,7 +315,8 @@ async def send_message(room: MatrixRoom, message: str, notice: bool = False, cli
|
|
)
|
|
)
|
|
|
|
|
|
if msgtype != "m.reaction":
|
|
if msgtype != "m.reaction":
|
|
- response = client.encrypt(room.room_id, "m.room.message", msgcontent)
|
|
|
|
|
|
+ response = client.encrypt(
|
|
|
|
+ room.room_id, "m.room.message", msgcontent)
|
|
msgtype, content = response
|
|
msgtype, content = response
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -318,7 +348,7 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None):
|
|
if isinstance(response, JoinResponse):
|
|
if isinstance(response, JoinResponse):
|
|
logging(response, "debug")
|
|
logging(response, "debug")
|
|
rooms = await client.joined_rooms()
|
|
rooms = await client.joined_rooms()
|
|
- await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
|
|
|
|
|
|
+ await send_message(client.rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
|
|
else:
|
|
else:
|
|
logging(f"Error joining room {room_id}: {response}", "error")
|
|
logging(f"Error joining room {room_id}: {response}", "error")
|
|
|
|
|
|
@@ -408,13 +438,16 @@ async def init(config: ConfigParser):
|
|
if "MaxMessages" in config["OpenAI"]:
|
|
if "MaxMessages" in config["OpenAI"]:
|
|
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
|
|
CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
|
|
|
|
|
|
- # Listen for SIGTERM
|
|
|
|
-
|
|
|
|
- def sigterm_handler(_signo, _stack_frame):
|
|
|
|
- logging("Received SIGTERM - exiting...")
|
|
|
|
- exit()
|
|
|
|
|
|
+ # Override defaults with config
|
|
|
|
|
|
- signal.signal(signal.SIGTERM, sigterm_handler)
|
|
|
|
|
|
+ if "GPTBot" in config:
|
|
|
|
+ if "SystemMessage" in config["GPTBot"]:
|
|
|
|
+ CONTEXT["system_message"] = config["GPTBot"]["SystemMessage"]
|
|
|
|
+ if "DefaultRoomName" in config["GPTBot"]:
|
|
|
|
+ CONTEXT["default_room_name"] = config["GPTBot"]["DefaultRoomName"]
|
|
|
|
+ if "ForceSystemMessage" in config["GPTBot"]:
|
|
|
|
+ CONTEXT["force_system_message"] = config["GPTBot"].getboolean(
|
|
|
|
+ "ForceSystemMessage")
|
|
|
|
|
|
|
|
|
|
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
|
|
async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
|
|
@@ -431,6 +464,14 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
|
|
await client.close()
|
|
await client.close()
|
|
return
|
|
return
|
|
|
|
|
|
|
|
+ # Listen for SIGTERM
|
|
|
|
+
|
|
|
|
+ def sigterm_handler(_signo, _stack_frame):
|
|
|
|
+ logging("Received SIGTERM - exiting...")
|
|
|
|
+ exit()
|
|
|
|
+
|
|
|
|
+ signal.signal(signal.SIGTERM, sigterm_handler)
|
|
|
|
+
|
|
logging("Starting bot...")
|
|
logging("Starting bot...")
|
|
|
|
|
|
client.add_response_callback(sync_cb, SyncResponse)
|
|
client.add_response_callback(sync_cb, SyncResponse)
|
|
@@ -460,9 +501,9 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
|
|
|
|
|
|
def initialize_database(path: os.PathLike):
|
|
def initialize_database(path: os.PathLike):
|
|
logging("Initializing database...")
|
|
logging("Initializing database...")
|
|
- database = duckdb.connect(path)
|
|
|
|
|
|
+ conn = duckdb.connect(path)
|
|
|
|
|
|
- with database.cursor() as cursor:
|
|
|
|
|
|
+ with conn.cursor() as cursor:
|
|
# Get the latest migration ID if the migrations table exists
|
|
# Get the latest migration ID if the migrations table exists
|
|
try:
|
|
try:
|
|
cursor.execute(
|
|
cursor.execute(
|
|
@@ -472,40 +513,17 @@ def initialize_database(path: os.PathLike):
|
|
)
|
|
)
|
|
|
|
|
|
latest_migration = int(cursor.fetchone()[0])
|
|
latest_migration = int(cursor.fetchone()[0])
|
|
|
|
+
|
|
except:
|
|
except:
|
|
latest_migration = 0
|
|
latest_migration = 0
|
|
|
|
|
|
- # Version 1
|
|
|
|
-
|
|
|
|
- if latest_migration < 1:
|
|
|
|
- cursor.execute(
|
|
|
|
- """
|
|
|
|
- CREATE TABLE IF NOT EXISTS token_usage (
|
|
|
|
- message_id TEXT PRIMARY KEY,
|
|
|
|
- room_id TEXT NOT NULL,
|
|
|
|
- tokens INTEGER NOT NULL,
|
|
|
|
- timestamp TIMESTAMP NOT NULL
|
|
|
|
- )
|
|
|
|
- """
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- cursor.execute(
|
|
|
|
- """
|
|
|
|
- CREATE TABLE IF NOT EXISTS migrations (
|
|
|
|
- id INTEGER NOT NULL,
|
|
|
|
- timestamp TIMESTAMP NOT NULL
|
|
|
|
- )
|
|
|
|
- """
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- cursor.execute(
|
|
|
|
- "INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
|
|
|
|
- (datetime.now(),)
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- database.commit()
|
|
|
|
|
|
+ for migration, function in MIGRATIONS.items():
|
|
|
|
+ if latest_migration < migration:
|
|
|
|
+ logging(f"Running migration {migration}...")
|
|
|
|
+ function(conn)
|
|
|
|
+ latest_migration = migration
|
|
|
|
|
|
- return database
|
|
|
|
|
|
+ return conn
|
|
|
|
|
|
|
|
|
|
async def get_device_id(access_token, homeserver):
|
|
async def get_device_id(access_token, homeserver):
|