Browse Source

Moving migrations to subdirectory
Add option for custom system messages per room
Fixing some methods in store

Kumi 1 year ago
parent
commit
2bbc6a33ca
10 changed files with 358 additions and 207 deletions
  1. 3 0
      README.md
  2. 46 162
      classes/store.py
  3. 2 0
      commands/__init__.py
  4. 34 0
      commands/systemmessage.py
  5. 7 2
      config.dist.ini
  6. 61 43
      gptbot.py
  7. 11 0
      migrations/__init__.py
  8. 32 0
      migrations/migration_1.py
  9. 138 0
      migrations/migration_2.py
  10. 24 0
      migrations/migration_3.py

+ 3 - 0
README.md

@@ -6,6 +6,9 @@ to generate responses to messages in a Matrix room.
 It will also save a log of the spent tokens to a DuckDB database
 (database.db in the working directory, by default).
 
+Note that this bot does not yet support encryption - this is still work in
+progress.
+
 ## Installation
 
 Simply clone this repository and install the requirements.

+ 46 - 162
classes/store.py

@@ -1,6 +1,6 @@
 import duckdb
 
-from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore
+from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore, Session
 from nio.crypto import OlmAccount, OlmDevice
 
 from random import SystemRandom
@@ -24,150 +24,6 @@ class DuckDBStore(MatrixStore):
         self.conn = duckdb_conn
         self.user_id = user_id
         self.device_id = device_id
-        self._create_tables()
-
-    def _create_tables(self):
-        with self.conn.cursor() as cursor:
-            cursor.execute("""
-            DROP TABLE IF EXISTS sync_tokens CASCADE;
-            DROP TABLE IF EXISTS encrypted_rooms CASCADE;
-            DROP TABLE IF EXISTS outgoing_key_requests CASCADE;
-            DROP TABLE IF EXISTS forwarded_chains CASCADE;
-            DROP TABLE IF EXISTS outbound_group_sessions CASCADE;
-            DROP TABLE IF EXISTS inbound_group_sessions CASCADE;
-            DROP TABLE IF EXISTS olm_sessions CASCADE;
-            DROP TABLE IF EXISTS device_trust_state CASCADE;
-            DROP TABLE IF EXISTS keys CASCADE;
-            DROP TABLE IF EXISTS device_keys_key CASCADE;
-            DROP TABLE IF EXISTS device_keys CASCADE;
-            DROP TABLE IF EXISTS accounts CASCADE;
-            """)
-
-            # Create accounts table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS accounts (
-                id INTEGER PRIMARY KEY,
-                user_id VARCHAR NOT NULL,
-                device_id VARCHAR NOT NULL,
-                shared_account INTEGER NOT NULL,
-                pickle VARCHAR NOT NULL
-            );
-            """)
-
-            # Create device_keys table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS device_keys (
-                device_id TEXT PRIMARY KEY,
-                account_id INTEGER NOT NULL,
-                user_id TEXT NOT NULL,
-                display_name TEXT,
-                deleted BOOLEAN NOT NULL DEFAULT 0,
-                UNIQUE (account_id, user_id, device_id),
-                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
-            );
-
-            CREATE TABLE IF NOT EXISTS keys (
-                key_type TEXT NOT NULL,
-                key TEXT NOT NULL,
-                device_id VARCHAR NOT NULL,
-                UNIQUE (key_type, device_id),
-                FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create device_trust_state table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS device_trust_state (
-                device_id VARCHAR PRIMARY KEY,
-                state INTEGER NOT NULL,
-                FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create olm_sessions table
-            cursor.execute("""
-            CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
-
-            CREATE TABLE IF NOT EXISTS olm_sessions (
-                id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
-                account_id INTEGER NOT NULL,
-                sender_key TEXT NOT NULL,
-                session BLOB NOT NULL,
-                session_id VARCHAR NOT NULL,
-                creation_time TIMESTAMP NOT NULL,
-                last_usage_date TIMESTAMP NOT NULL,
-                FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create inbound_group_sessions table
-            cursor.execute("""
-            CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
-
-            CREATE TABLE IF NOT EXISTS inbound_group_sessions (
-                id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
-                account_id INTEGER NOT NULL,
-                session TEXT NOT NULL,
-                fp_key TEXT NOT NULL,
-                sender_key TEXT NOT NULL,
-                room_id TEXT NOT NULL,
-                UNIQUE (account_id, sender_key, fp_key, room_id),
-                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
-            );
-
-            CREATE TABLE IF NOT EXISTS forwarded_chains (
-                id INTEGER PRIMARY KEY,
-                session_id INTEGER NOT NULL,
-                sender_key TEXT NOT NULL,
-                FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create outbound_group_sessions table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS outbound_group_sessions (
-                id INTEGER PRIMARY KEY,
-                account_id INTEGER NOT NULL,
-                room_id VARCHAR NOT NULL,
-                session_id VARCHAR NOT NULL UNIQUE,
-                session BLOB NOT NULL,
-                FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create outgoing_key_requests table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS outgoing_key_requests (
-                id INTEGER PRIMARY KEY,
-                account_id INTEGER NOT NULL,
-                request_id TEXT NOT NULL,
-                session_id TEXT NOT NULL,
-                room_id TEXT NOT NULL,
-                algorithm TEXT NOT NULL,
-                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
-                UNIQUE (account_id, request_id)
-            );
-
-            """)
-
-            # Create encrypted_rooms table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS encrypted_rooms (
-                room_id TEXT NOT NULL,
-                account_id INTEGER NOT NULL,
-                PRIMARY KEY (room_id, account_id),
-                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create sync_tokens table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS sync_tokens (
-                account_id INTEGER PRIMARY KEY,
-                token TEXT NOT NULL,
-                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
-            );
-            """)
 
     def _get_account(self):
         cursor = self.conn.cursor()
@@ -387,18 +243,18 @@ class DuckDBStore(MatrixStore):
             for d in device_keys:
                 cur.execute(
                     "SELECT * FROM keys WHERE device_id = ?",
-                    (d["id"],)
+                    (d[0],)
                 )
                 keys = cur.fetchall()
-                key_dict = {k["key_type"]: k["key"] for k in keys}
+                key_dict = {k[0]: k[1] for k in keys}
 
                 store.add(
                     OlmDevice(
-                        d["user_id"],
-                        d["device_id"],
+                        d[2],
+                        d[0],
                         key_dict,
-                        display_name=d["display_name"],
-                        deleted=d["deleted"],
+                        display_name=d[3],
+                        deleted=d[4],
                     )
                 )
 
@@ -561,18 +417,21 @@ class DuckDBStore(MatrixStore):
             )
 
             for row in cursor.fetchall():
+                cursor.execute(
+                    "SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
+                    (row[1],),
+                )
+                chains = cursor.fetchall()
+
                 session = InboundGroupSession.from_pickle(
-                    row["session"],
-                    row["fp_key"],
-                    row["sender_key"],
-                    row["room_id"],
+                    row[2].encode(),
+                    row[3],
+                    row[4],
+                    row[5],
                     self.pickle_key,
                     [
-                        chain["sender_key"]
-                        for chain in cursor.execute(
-                            "SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
-                            (row["id"],),
-                        )
+                        chain[0]
+                        for chain in chains
                     ],
                 )
                 store.add(session)
@@ -621,7 +480,7 @@ class DuckDBStore(MatrixStore):
             )
             rows = cur.fetchall()
 
-        return {row["room_id"] for row in rows}
+        return {row[0] for row in rows}
 
     def save_sync_token(self, token):
         """Save the given token"""
@@ -727,4 +586,29 @@ class DuckDBStore(MatrixStore):
                     key_request.room_id,
                     key_request.algorithm,
                 )
-            )
+            )
+
+    def load_account(self):
+        # type: () -> Optional[OlmAccount]
+        """Load the Olm account from the database.
+
+        Returns:
+            ``OlmAccount`` object, or ``None`` if it wasn't found for the
+                current device_id.
+
+        """
+        cursor = self.conn.cursor()
+        query = """
+            SELECT pickle, shared_account
+            FROM accounts
+            WHERE device_id = ?;
+        """
+        cursor.execute(query, (self.device_id,))
+
+        result = cursor.fetchone()
+
+        if not result:
+            return None
+
+        account_pickle, shared = result
+        return OlmAccount.from_pickle(account_pickle.encode(), self.pickle_key, shared)

+ 2 - 0
commands/__init__.py

@@ -5,6 +5,7 @@ from .botinfo import command_botinfo
 from .unknown import command_unknown
 from .coin import command_coin
 from .ignoreolder import command_ignoreolder
+from .systemmessage import command_systemmessage
 
 COMMANDS = {
     "help": command_help,
@@ -13,5 +14,6 @@ COMMANDS = {
     "botinfo": command_botinfo,
     "coin": command_coin,
     "ignoreolder": command_ignoreolder,
+    "systemmessage": command_systemmessage,
     None: command_unknown,
 }

+ 34 - 0
commands/systemmessage.py

@@ -0,0 +1,34 @@
+from nio.events.room_events import RoomMessageText
+from nio.rooms import MatrixRoom
+
+
+async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, context: dict):
+    system_message = " ".join(event.body.split()[2:])
+
+    if system_message:
+        context["logger"]("Adding system message...")
+
+        with context["database"].cursor() as cur:
+            cur.execute(
+                "INSERT INTO system_messages (room_id, message_id, user_id, body, timestamp) VALUES (?, ?, ?, ?, ?)",
+                (room.room_id, event.event_id, event.sender,
+                 system_message, event.server_timestamp)
+            )
+
+        return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message stored: {system_message}"}
+
+    context["logger"]("Retrieving system message...")
+
+    with context["database"].cursor() as cur:
+        cur.execute(
+            "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
+            (room.room_id,)
+        )
+        system_message = cur.fetchone()
+
+    if system_message is None:
+        system_message = context.get("system_message", "No system message set")
+    elif context.get("force_system_message") and context.get("system_message"):
+        system_message = system_message + "\n\n" + context["system_message"]
+
+    return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message: {system_message}"}

+ 7 - 2
config.dist.ini

@@ -69,10 +69,15 @@ AccessToken = syt_yoursynapsetoken
 #
 # SystemMessage = You are a helpful bot.
 
+# Force inclusion of the SystemMessage defined above if one is defined on per-room level
+# If no custom message is defined for the room, SystemMessage is always included
+#
+# ForceSystemMessage = 0
+
 [Database]
 
 # Settings for the DuckDB database.
-# Currently only used to store details on spent tokens per room.
-# If not defined, the bot will not store this data.
+# If not defined, the bot will not be able to remember anything, and will not support encryption
+# N.B.: Encryption doesn't work as it is supposed to anyway.
 
 Path = database.db

+ 61 - 43
gptbot.py

@@ -23,6 +23,7 @@ from typing import List, Dict, Union, Optional
 
 from commands import COMMANDS
 from classes import DuckDBStore
+from migrations import MIGRATIONS
 
 
 def logging(message: str, log_level: str = "info"):
@@ -35,6 +36,7 @@ CONTEXT = {
     "database": False,
     "default_room_name": "GPTBot",
     "system_message": "You are a helpful assistant.",
+    "force_system_message": False,
     "max_tokens": 3000,
     "max_messages": 20,
     "model": "gpt-3.5-turbo",
@@ -48,6 +50,8 @@ async def gpt_query(messages: list, model: Optional[str] = None):
     model = model or CONTEXT["model"]
 
     logging(f"Querying GPT with {len(messages)} messages")
+    logging(messages, "debug")
+    
     try:
         response = openai.ChatCompletion.create(
             model=model,
@@ -143,8 +147,9 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
 
     client = kwargs.get("client") or CONTEXT["client"]
     database = kwargs.get("database") or CONTEXT["database"]
-    system_message = kwargs.get("system_message") or CONTEXT["system_message"]
     max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
+    system_message = kwargs.get("system_message") or CONTEXT["system_message"]
+    force_system_message = kwargs.get("force_system_message") or CONTEXT["force_system_message"]
 
     await client.room_typing(room.room_id, True)
 
@@ -152,6 +157,12 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
 
     last_messages = await fetch_last_n_messages(room.room_id, 20)
 
+    system_message = get_system_message(room, {
+        "database": database,
+        "system_message": system_message,
+        "force_system_message": force_system_message,
+    })
+
     chat_messages = [{"role": "system", "content": system_message}]
 
     for message in last_messages:
@@ -163,8 +174,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
 
     # Truncate messages to fit within the token limit
     truncated_messages = truncate_messages_to_fit_tokens(
-        chat_messages, max_tokens - 1)
-
+        chat_messages, max_tokens - 1, system_message=system_message)
     response, tokens_used = await gpt_query(truncated_messages)
 
     if response:
@@ -204,10 +214,29 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt
 
     if message:
         room_id, event, content = message
+        rooms = await context["client"].joined_rooms()
         await send_message(context["client"].rooms[room_id], content["body"],
                            True if content["msgtype"] == "m.notice" else False, context["client"])
 
 
+def get_system_message(room: MatrixRoom, context: Optional[dict]) -> str:
+    context = context or CONTEXT
+
+    default = context.get("system_message")
+
+    with context["database"].cursor() as cur:
+        cur.execute(
+            "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
+            (room.room_id,)
+        )
+        system_message = cur.fetchone()
+
+    complete = ((default if ((not system_message) or context["force_system_message"]) else "") + (
+        "\n\n" + system_message[0] if system_message else "")).strip()
+
+    return complete
+
+
 async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
     context = kwargs.get("context") or CONTEXT
 
@@ -286,7 +315,8 @@ async def send_message(room: MatrixRoom, message: str, notice: bool = False, cli
                     )
 
             if msgtype != "m.reaction":
-                response = client.encrypt(room.room_id, "m.room.message", msgcontent)
+                response = client.encrypt(
+                    room.room_id, "m.room.message", msgcontent)
                 msgtype, content = response
 
         except Exception as e:
@@ -318,7 +348,7 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None):
         if isinstance(response, JoinResponse):
             logging(response, "debug")
             rooms = await client.joined_rooms()
-            await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
+            await send_message(client.rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
         else:
             logging(f"Error joining room {room_id}: {response}", "error")
 
@@ -408,13 +438,16 @@ async def init(config: ConfigParser):
     if "MaxMessages" in config["OpenAI"]:
         CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
 
-    # Listen for SIGTERM
-
-    def sigterm_handler(_signo, _stack_frame):
-        logging("Received SIGTERM - exiting...")
-        exit()
+    # Override defaults with config
 
-    signal.signal(signal.SIGTERM, sigterm_handler)
+    if "GPTBot" in config:
+        if "SystemMessage" in config["GPTBot"]:
+            CONTEXT["system_message"] = config["GPTBot"]["SystemMessage"]
+        if "DefaultRoomName" in config["GPTBot"]:
+            CONTEXT["default_room_name"] = config["GPTBot"]["DefaultRoomName"]
+        if "ForceSystemMessage" in config["GPTBot"]:
+            CONTEXT["force_system_message"] = config["GPTBot"].getboolean(
+                "ForceSystemMessage")
 
 
 async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
@@ -431,6 +464,14 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
         await client.close()
         return
 
+    # Listen for SIGTERM
+
+    def sigterm_handler(_signo, _stack_frame):
+        logging("Received SIGTERM - exiting...")
+        exit()
+
+    signal.signal(signal.SIGTERM, sigterm_handler)
+
     logging("Starting bot...")
 
     client.add_response_callback(sync_cb, SyncResponse)
@@ -460,9 +501,9 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
 
 def initialize_database(path: os.PathLike):
     logging("Initializing database...")
-    database = duckdb.connect(path)
+    conn = duckdb.connect(path)
 
-    with database.cursor() as cursor:
+    with conn.cursor() as cursor:
         # Get the latest migration ID if the migrations table exists
         try:
             cursor.execute(
@@ -472,40 +513,17 @@ def initialize_database(path: os.PathLike):
             )
 
             latest_migration = int(cursor.fetchone()[0])
+
         except:
             latest_migration = 0
 
-        # Version 1
-
-        if latest_migration < 1:
-            cursor.execute(
-                """
-                CREATE TABLE IF NOT EXISTS token_usage (
-                    message_id TEXT PRIMARY KEY,
-                    room_id TEXT NOT NULL,
-                    tokens INTEGER NOT NULL,
-                    timestamp TIMESTAMP NOT NULL
-                )
-                """
-            )
-
-            cursor.execute(
-                """
-                CREATE TABLE IF NOT EXISTS migrations (
-                    id INTEGER NOT NULL,
-                    timestamp TIMESTAMP NOT NULL
-                )
-                """
-            )
-
-            cursor.execute(
-                "INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
-                (datetime.now(),)
-            )
-
-        database.commit()
+    for migration, function in MIGRATIONS.items():
+        if latest_migration < migration:
+            logging(f"Running migration {migration}...")
+            function(conn)
+            latest_migration = migration
 
-        return database
+    return conn
 
 
 async def get_device_id(access_token, homeserver):

+ 11 - 0
migrations/__init__.py

@@ -0,0 +1,11 @@
+from collections import OrderedDict
+
+from .migration_1 import migration as migration_1
+from .migration_2 import migration as migration_2
+from .migration_3 import migration as migration_3
+
+MIGRATIONS = OrderedDict()
+
+MIGRATIONS[1] = migration_1
+MIGRATIONS[2] = migration_2
+MIGRATIONS[3] = migration_3

+ 32 - 0
migrations/migration_1.py

@@ -0,0 +1,32 @@
+# Initial migration, token usage logging
+
+from datetime import datetime
+
+def migration(conn):
+    with conn.cursor() as cursor:
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS token_usage (
+                message_id TEXT PRIMARY KEY,
+                room_id TEXT NOT NULL,
+                tokens INTEGER NOT NULL,
+                timestamp TIMESTAMP NOT NULL
+            )
+            """
+        )
+
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS migrations (
+                id INTEGER NOT NULL,
+                timestamp TIMESTAMP NOT NULL
+            )
+            """
+        )
+
+        cursor.execute(
+            "INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
+            (datetime.now(),)
+        )
+
+        conn.commit()

+ 138 - 0
migrations/migration_2.py

@@ -0,0 +1,138 @@
+# Migration for Matrix Store
+
+from datetime import datetime
+
+def migration(conn):
+    with conn.cursor() as cursor:
+        # Create accounts table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS accounts (
+            id INTEGER PRIMARY KEY,
+            user_id VARCHAR NOT NULL,
+            device_id VARCHAR NOT NULL,
+            shared_account INTEGER NOT NULL,
+            pickle VARCHAR NOT NULL
+        );
+        """)
+
+        # Create device_keys table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS device_keys (
+            device_id TEXT PRIMARY KEY,
+            account_id INTEGER NOT NULL,
+            user_id TEXT NOT NULL,
+            display_name TEXT,
+            deleted BOOLEAN NOT NULL DEFAULT 0,
+            UNIQUE (account_id, user_id, device_id),
+            FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+        );
+
+        CREATE TABLE IF NOT EXISTS keys (
+            key_type TEXT NOT NULL,
+            key TEXT NOT NULL,
+            device_id VARCHAR NOT NULL,
+            UNIQUE (key_type, device_id),
+            FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create device_trust_state table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS device_trust_state (
+            device_id VARCHAR PRIMARY KEY,
+            state INTEGER NOT NULL,
+            FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create olm_sessions table
+        cursor.execute("""
+        CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
+
+        CREATE TABLE IF NOT EXISTS olm_sessions (
+            id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
+            account_id INTEGER NOT NULL,
+            sender_key TEXT NOT NULL,
+            session BLOB NOT NULL,
+            session_id VARCHAR NOT NULL,
+            creation_time TIMESTAMP NOT NULL,
+            last_usage_date TIMESTAMP NOT NULL,
+            FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create inbound_group_sessions table
+        cursor.execute("""
+        CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
+
+        CREATE TABLE IF NOT EXISTS inbound_group_sessions (
+            id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
+            account_id INTEGER NOT NULL,
+            session TEXT NOT NULL,
+            fp_key TEXT NOT NULL,
+            sender_key TEXT NOT NULL,
+            room_id TEXT NOT NULL,
+            UNIQUE (account_id, sender_key, fp_key, room_id),
+            FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+        );
+
+        CREATE TABLE IF NOT EXISTS forwarded_chains (
+            id INTEGER PRIMARY KEY,
+            session_id INTEGER NOT NULL,
+            sender_key TEXT NOT NULL,
+            FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create outbound_group_sessions table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS outbound_group_sessions (
+            id INTEGER PRIMARY KEY,
+            account_id INTEGER NOT NULL,
+            room_id VARCHAR NOT NULL,
+            session_id VARCHAR NOT NULL UNIQUE,
+            session BLOB NOT NULL,
+            FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create outgoing_key_requests table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS outgoing_key_requests (
+            id INTEGER PRIMARY KEY,
+            account_id INTEGER NOT NULL,
+            request_id TEXT NOT NULL,
+            session_id TEXT NOT NULL,
+            room_id TEXT NOT NULL,
+            algorithm TEXT NOT NULL,
+            FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
+            UNIQUE (account_id, request_id)
+        );
+
+        """)
+
+        # Create encrypted_rooms table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS encrypted_rooms (
+            room_id TEXT NOT NULL,
+            account_id INTEGER NOT NULL,
+            PRIMARY KEY (room_id, account_id),
+            FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create sync_tokens table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS sync_tokens (
+            account_id INTEGER PRIMARY KEY,
+            token TEXT NOT NULL,
+            FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+        );
+        """)
+        
+        cursor.execute(
+            "INSERT INTO migrations (id, timestamp) VALUES (2, ?)",
+            (datetime.now(),)
+        )
+
+        conn.commit()

+ 24 - 0
migrations/migration_3.py

@@ -0,0 +1,24 @@
+# Migration for custom system messages
+
+from datetime import datetime
+
+def migration(conn):
+    with conn.cursor() as cursor:
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS system_messages (
+                room_id TEXT NOT NULL,
+                message_id TEXT NOT NULL,
+                user_id TEXT NOT NULL,
+                body TEXT NOT NULL,
+                timestamp BIGINT NOT NULL,
+            )
+            """
+        )
+
+        cursor.execute(
+            "INSERT INTO migrations (id, timestamp) VALUES (3, ?)",
+            (datetime.now(),)
+        )
+
+        conn.commit()