Browse Source

Moving migrations to subdirectory
Add option for custom system messages per room
Fixing some methods in store

Kumi 1 year ago
parent
commit
2bbc6a33ca
10 changed files with 358 additions and 207 deletions
  1. 3 0
      README.md
  2. 46 162
      classes/store.py
  3. 2 0
      commands/__init__.py
  4. 34 0
      commands/systemmessage.py
  5. 7 2
      config.dist.ini
  6. 61 43
      gptbot.py
  7. 11 0
      migrations/__init__.py
  8. 32 0
      migrations/migration_1.py
  9. 138 0
      migrations/migration_2.py
  10. 24 0
      migrations/migration_3.py

+ 3 - 0
README.md

@@ -6,6 +6,9 @@ to generate responses to messages in a Matrix room.
 It will also save a log of the spent tokens to a DuckDB database
 It will also save a log of the spent tokens to a DuckDB database
 (database.db in the working directory, by default).
 (database.db in the working directory, by default).
 
 
+Note that this bot does not yet support encryption - this is still work in
+progress.
+
 ## Installation
 ## Installation
 
 
 Simply clone this repository and install the requirements.
 Simply clone this repository and install the requirements.

+ 46 - 162
classes/store.py

@@ -1,6 +1,6 @@
 import duckdb
 import duckdb
 
 
-from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore
+from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore, Session
 from nio.crypto import OlmAccount, OlmDevice
 from nio.crypto import OlmAccount, OlmDevice
 
 
 from random import SystemRandom
 from random import SystemRandom
@@ -24,150 +24,6 @@ class DuckDBStore(MatrixStore):
         self.conn = duckdb_conn
         self.conn = duckdb_conn
         self.user_id = user_id
         self.user_id = user_id
         self.device_id = device_id
         self.device_id = device_id
-        self._create_tables()
-
-    def _create_tables(self):
-        with self.conn.cursor() as cursor:
-            cursor.execute("""
-            DROP TABLE IF EXISTS sync_tokens CASCADE;
-            DROP TABLE IF EXISTS encrypted_rooms CASCADE;
-            DROP TABLE IF EXISTS outgoing_key_requests CASCADE;
-            DROP TABLE IF EXISTS forwarded_chains CASCADE;
-            DROP TABLE IF EXISTS outbound_group_sessions CASCADE;
-            DROP TABLE IF EXISTS inbound_group_sessions CASCADE;
-            DROP TABLE IF EXISTS olm_sessions CASCADE;
-            DROP TABLE IF EXISTS device_trust_state CASCADE;
-            DROP TABLE IF EXISTS keys CASCADE;
-            DROP TABLE IF EXISTS device_keys_key CASCADE;
-            DROP TABLE IF EXISTS device_keys CASCADE;
-            DROP TABLE IF EXISTS accounts CASCADE;
-            """)
-
-            # Create accounts table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS accounts (
-                id INTEGER PRIMARY KEY,
-                user_id VARCHAR NOT NULL,
-                device_id VARCHAR NOT NULL,
-                shared_account INTEGER NOT NULL,
-                pickle VARCHAR NOT NULL
-            );
-            """)
-
-            # Create device_keys table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS device_keys (
-                device_id TEXT PRIMARY KEY,
-                account_id INTEGER NOT NULL,
-                user_id TEXT NOT NULL,
-                display_name TEXT,
-                deleted BOOLEAN NOT NULL DEFAULT 0,
-                UNIQUE (account_id, user_id, device_id),
-                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
-            );
-
-            CREATE TABLE IF NOT EXISTS keys (
-                key_type TEXT NOT NULL,
-                key TEXT NOT NULL,
-                device_id VARCHAR NOT NULL,
-                UNIQUE (key_type, device_id),
-                FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create device_trust_state table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS device_trust_state (
-                device_id VARCHAR PRIMARY KEY,
-                state INTEGER NOT NULL,
-                FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create olm_sessions table
-            cursor.execute("""
-            CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
-
-            CREATE TABLE IF NOT EXISTS olm_sessions (
-                id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
-                account_id INTEGER NOT NULL,
-                sender_key TEXT NOT NULL,
-                session BLOB NOT NULL,
-                session_id VARCHAR NOT NULL,
-                creation_time TIMESTAMP NOT NULL,
-                last_usage_date TIMESTAMP NOT NULL,
-                FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create inbound_group_sessions table
-            cursor.execute("""
-            CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
-
-            CREATE TABLE IF NOT EXISTS inbound_group_sessions (
-                id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
-                account_id INTEGER NOT NULL,
-                session TEXT NOT NULL,
-                fp_key TEXT NOT NULL,
-                sender_key TEXT NOT NULL,
-                room_id TEXT NOT NULL,
-                UNIQUE (account_id, sender_key, fp_key, room_id),
-                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
-            );
-
-            CREATE TABLE IF NOT EXISTS forwarded_chains (
-                id INTEGER PRIMARY KEY,
-                session_id INTEGER NOT NULL,
-                sender_key TEXT NOT NULL,
-                FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create outbound_group_sessions table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS outbound_group_sessions (
-                id INTEGER PRIMARY KEY,
-                account_id INTEGER NOT NULL,
-                room_id VARCHAR NOT NULL,
-                session_id VARCHAR NOT NULL UNIQUE,
-                session BLOB NOT NULL,
-                FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create outgoing_key_requests table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS outgoing_key_requests (
-                id INTEGER PRIMARY KEY,
-                account_id INTEGER NOT NULL,
-                request_id TEXT NOT NULL,
-                session_id TEXT NOT NULL,
-                room_id TEXT NOT NULL,
-                algorithm TEXT NOT NULL,
-                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
-                UNIQUE (account_id, request_id)
-            );
-
-            """)
-
-            # Create encrypted_rooms table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS encrypted_rooms (
-                room_id TEXT NOT NULL,
-                account_id INTEGER NOT NULL,
-                PRIMARY KEY (room_id, account_id),
-                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
-            );
-            """)
-
-            # Create sync_tokens table
-            cursor.execute("""
-            CREATE TABLE IF NOT EXISTS sync_tokens (
-                account_id INTEGER PRIMARY KEY,
-                token TEXT NOT NULL,
-                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
-            );
-            """)
 
 
     def _get_account(self):
     def _get_account(self):
         cursor = self.conn.cursor()
         cursor = self.conn.cursor()
@@ -387,18 +243,18 @@ class DuckDBStore(MatrixStore):
             for d in device_keys:
             for d in device_keys:
                 cur.execute(
                 cur.execute(
                     "SELECT * FROM keys WHERE device_id = ?",
                     "SELECT * FROM keys WHERE device_id = ?",
-                    (d["id"],)
+                    (d[0],)
                 )
                 )
                 keys = cur.fetchall()
                 keys = cur.fetchall()
-                key_dict = {k["key_type"]: k["key"] for k in keys}
+                key_dict = {k[0]: k[1] for k in keys}
 
 
                 store.add(
                 store.add(
                     OlmDevice(
                     OlmDevice(
-                        d["user_id"],
-                        d["device_id"],
+                        d[2],
+                        d[0],
                         key_dict,
                         key_dict,
-                        display_name=d["display_name"],
-                        deleted=d["deleted"],
+                        display_name=d[3],
+                        deleted=d[4],
                     )
                     )
                 )
                 )
 
 
@@ -561,18 +417,21 @@ class DuckDBStore(MatrixStore):
             )
             )
 
 
             for row in cursor.fetchall():
             for row in cursor.fetchall():
+                cursor.execute(
+                    "SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
+                    (row[1],),
+                )
+                chains = cursor.fetchall()
+
                 session = InboundGroupSession.from_pickle(
                 session = InboundGroupSession.from_pickle(
-                    row["session"],
-                    row["fp_key"],
-                    row["sender_key"],
-                    row["room_id"],
+                    row[2].encode(),
+                    row[3],
+                    row[4],
+                    row[5],
                     self.pickle_key,
                     self.pickle_key,
                     [
                     [
-                        chain["sender_key"]
-                        for chain in cursor.execute(
-                            "SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
-                            (row["id"],),
-                        )
+                        chain[0]
+                        for chain in chains
                     ],
                     ],
                 )
                 )
                 store.add(session)
                 store.add(session)
@@ -621,7 +480,7 @@ class DuckDBStore(MatrixStore):
             )
             )
             rows = cur.fetchall()
             rows = cur.fetchall()
 
 
-        return {row["room_id"] for row in rows}
+        return {row[0] for row in rows}
 
 
     def save_sync_token(self, token):
     def save_sync_token(self, token):
         """Save the given token"""
         """Save the given token"""
@@ -727,4 +586,29 @@ class DuckDBStore(MatrixStore):
                     key_request.room_id,
                     key_request.room_id,
                     key_request.algorithm,
                     key_request.algorithm,
                 )
                 )
-            )
+            )
+
+    def load_account(self):
+        # type: () -> Optional[OlmAccount]
+        """Load the Olm account from the database.
+
+        Returns:
+            ``OlmAccount`` object, or ``None`` if it wasn't found for the
+                current device_id.
+
+        """
+        cursor = self.conn.cursor()
+        query = """
+            SELECT pickle, shared_account
+            FROM accounts
+            WHERE device_id = ?;
+        """
+        cursor.execute(query, (self.device_id,))
+
+        result = cursor.fetchone()
+
+        if not result:
+            return None
+
+        account_pickle, shared = result
+        return OlmAccount.from_pickle(account_pickle.encode(), self.pickle_key, shared)

+ 2 - 0
commands/__init__.py

@@ -5,6 +5,7 @@ from .botinfo import command_botinfo
 from .unknown import command_unknown
 from .unknown import command_unknown
 from .coin import command_coin
 from .coin import command_coin
 from .ignoreolder import command_ignoreolder
 from .ignoreolder import command_ignoreolder
+from .systemmessage import command_systemmessage
 
 
 COMMANDS = {
 COMMANDS = {
     "help": command_help,
     "help": command_help,
@@ -13,5 +14,6 @@ COMMANDS = {
     "botinfo": command_botinfo,
     "botinfo": command_botinfo,
     "coin": command_coin,
     "coin": command_coin,
     "ignoreolder": command_ignoreolder,
     "ignoreolder": command_ignoreolder,
+    "systemmessage": command_systemmessage,
     None: command_unknown,
     None: command_unknown,
 }
 }

+ 34 - 0
commands/systemmessage.py

@@ -0,0 +1,34 @@
+from nio.events.room_events import RoomMessageText
+from nio.rooms import MatrixRoom
+
+
+async def command_systemmessage(room: MatrixRoom, event: RoomMessageText, context: dict):
+    system_message = " ".join(event.body.split()[2:])
+
+    if system_message:
+        context["logger"]("Adding system message...")
+
+        with context["database"].cursor() as cur:
+            cur.execute(
+                "INSERT INTO system_messages (room_id, message_id, user_id, body, timestamp) VALUES (?, ?, ?, ?, ?)",
+                (room.room_id, event.event_id, event.sender,
+                 system_message, event.server_timestamp)
+            )
+
+        return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message stored: {system_message}"}
+
+    context["logger"]("Retrieving system message...")
+
+    with context["database"].cursor() as cur:
+        cur.execute(
+            "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
+            (room.room_id,)
+        )
+        system_message = cur.fetchone()
+
+    if system_message is None:
+        system_message = context.get("system_message", "No system message set")
+    elif context.get("force_system_message") and context.get("system_message"):
+        system_message = system_message + "\n\n" + context["system_message"]
+
+    return room.room_id, "m.room.message", {"msgtype": "m.notice", "body": f"System message: {system_message}"}

+ 7 - 2
config.dist.ini

@@ -69,10 +69,15 @@ AccessToken = syt_yoursynapsetoken
 #
 #
 # SystemMessage = You are a helpful bot.
 # SystemMessage = You are a helpful bot.
 
 
+# Force inclusion of the SystemMessage defined above if one is defined on per-room level
+# If no custom message is defined for the room, SystemMessage is always included
+#
+# ForceSystemMessage = 0
+
 [Database]
 [Database]
 
 
 # Settings for the DuckDB database.
 # Settings for the DuckDB database.
-# Currently only used to store details on spent tokens per room.
-# If not defined, the bot will not store this data.
+# If not defined, the bot will not be able to remember anything, and will not support encryption
+# N.B.: Encryption doesn't work as it is supposed to anyway.
 
 
 Path = database.db
 Path = database.db

+ 61 - 43
gptbot.py

@@ -23,6 +23,7 @@ from typing import List, Dict, Union, Optional
 
 
 from commands import COMMANDS
 from commands import COMMANDS
 from classes import DuckDBStore
 from classes import DuckDBStore
+from migrations import MIGRATIONS
 
 
 
 
 def logging(message: str, log_level: str = "info"):
 def logging(message: str, log_level: str = "info"):
@@ -35,6 +36,7 @@ CONTEXT = {
     "database": False,
     "database": False,
     "default_room_name": "GPTBot",
     "default_room_name": "GPTBot",
     "system_message": "You are a helpful assistant.",
     "system_message": "You are a helpful assistant.",
+    "force_system_message": False,
     "max_tokens": 3000,
     "max_tokens": 3000,
     "max_messages": 20,
     "max_messages": 20,
     "model": "gpt-3.5-turbo",
     "model": "gpt-3.5-turbo",
@@ -48,6 +50,8 @@ async def gpt_query(messages: list, model: Optional[str] = None):
     model = model or CONTEXT["model"]
     model = model or CONTEXT["model"]
 
 
     logging(f"Querying GPT with {len(messages)} messages")
     logging(f"Querying GPT with {len(messages)} messages")
+    logging(messages, "debug")
+    
     try:
     try:
         response = openai.ChatCompletion.create(
         response = openai.ChatCompletion.create(
             model=model,
             model=model,
@@ -143,8 +147,9 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
 
 
     client = kwargs.get("client") or CONTEXT["client"]
     client = kwargs.get("client") or CONTEXT["client"]
     database = kwargs.get("database") or CONTEXT["database"]
     database = kwargs.get("database") or CONTEXT["database"]
-    system_message = kwargs.get("system_message") or CONTEXT["system_message"]
     max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
     max_tokens = kwargs.get("max_tokens") or CONTEXT["max_tokens"]
+    system_message = kwargs.get("system_message") or CONTEXT["system_message"]
+    force_system_message = kwargs.get("force_system_message") or CONTEXT["force_system_message"]
 
 
     await client.room_typing(room.room_id, True)
     await client.room_typing(room.room_id, True)
 
 
@@ -152,6 +157,12 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
 
 
     last_messages = await fetch_last_n_messages(room.room_id, 20)
     last_messages = await fetch_last_n_messages(room.room_id, 20)
 
 
+    system_message = get_system_message(room, {
+        "database": database,
+        "system_message": system_message,
+        "force_system_message": force_system_message,
+    })
+
     chat_messages = [{"role": "system", "content": system_message}]
     chat_messages = [{"role": "system", "content": system_message}]
 
 
     for message in last_messages:
     for message in last_messages:
@@ -163,8 +174,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
 
 
     # Truncate messages to fit within the token limit
     # Truncate messages to fit within the token limit
     truncated_messages = truncate_messages_to_fit_tokens(
     truncated_messages = truncate_messages_to_fit_tokens(
-        chat_messages, max_tokens - 1)
-
+        chat_messages, max_tokens - 1, system_message=system_message)
     response, tokens_used = await gpt_query(truncated_messages)
     response, tokens_used = await gpt_query(truncated_messages)
 
 
     if response:
     if response:
@@ -204,10 +214,29 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt
 
 
     if message:
     if message:
         room_id, event, content = message
         room_id, event, content = message
+        rooms = await context["client"].joined_rooms()
         await send_message(context["client"].rooms[room_id], content["body"],
         await send_message(context["client"].rooms[room_id], content["body"],
                            True if content["msgtype"] == "m.notice" else False, context["client"])
                            True if content["msgtype"] == "m.notice" else False, context["client"])
 
 
 
 
+def get_system_message(room: MatrixRoom, context: Optional[dict]) -> str:
+    context = context or CONTEXT
+
+    default = context.get("system_message")
+
+    with context["database"].cursor() as cur:
+        cur.execute(
+            "SELECT body FROM system_messages WHERE room_id = ? ORDER BY timestamp DESC LIMIT 1",
+            (room.room_id,)
+        )
+        system_message = cur.fetchone()
+
+    complete = ((default if ((not system_message) or context["force_system_message"]) else "") + (
+        "\n\n" + system_message[0] if system_message else "")).strip()
+
+    return complete
+
+
 async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
 async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
     context = kwargs.get("context") or CONTEXT
     context = kwargs.get("context") or CONTEXT
 
 
@@ -286,7 +315,8 @@ async def send_message(room: MatrixRoom, message: str, notice: bool = False, cli
                     )
                     )
 
 
             if msgtype != "m.reaction":
             if msgtype != "m.reaction":
-                response = client.encrypt(room.room_id, "m.room.message", msgcontent)
+                response = client.encrypt(
+                    room.room_id, "m.room.message", msgcontent)
                 msgtype, content = response
                 msgtype, content = response
 
 
         except Exception as e:
         except Exception as e:
@@ -318,7 +348,7 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None):
         if isinstance(response, JoinResponse):
         if isinstance(response, JoinResponse):
             logging(response, "debug")
             logging(response, "debug")
             rooms = await client.joined_rooms()
             rooms = await client.joined_rooms()
-            await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
+            await send_message(client.rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
         else:
         else:
             logging(f"Error joining room {room_id}: {response}", "error")
             logging(f"Error joining room {room_id}: {response}", "error")
 
 
@@ -408,13 +438,16 @@ async def init(config: ConfigParser):
     if "MaxMessages" in config["OpenAI"]:
     if "MaxMessages" in config["OpenAI"]:
         CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
         CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
 
 
-    # Listen for SIGTERM
-
-    def sigterm_handler(_signo, _stack_frame):
-        logging("Received SIGTERM - exiting...")
-        exit()
+    # Override defaults with config
 
 
-    signal.signal(signal.SIGTERM, sigterm_handler)
+    if "GPTBot" in config:
+        if "SystemMessage" in config["GPTBot"]:
+            CONTEXT["system_message"] = config["GPTBot"]["SystemMessage"]
+        if "DefaultRoomName" in config["GPTBot"]:
+            CONTEXT["default_room_name"] = config["GPTBot"]["DefaultRoomName"]
+        if "ForceSystemMessage" in config["GPTBot"]:
+            CONTEXT["force_system_message"] = config["GPTBot"].getboolean(
+                "ForceSystemMessage")
 
 
 
 
 async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
 async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
@@ -431,6 +464,14 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
         await client.close()
         await client.close()
         return
         return
 
 
+    # Listen for SIGTERM
+
+    def sigterm_handler(_signo, _stack_frame):
+        logging("Received SIGTERM - exiting...")
+        exit()
+
+    signal.signal(signal.SIGTERM, sigterm_handler)
+
     logging("Starting bot...")
     logging("Starting bot...")
 
 
     client.add_response_callback(sync_cb, SyncResponse)
     client.add_response_callback(sync_cb, SyncResponse)
@@ -460,9 +501,9 @@ async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClie
 
 
 def initialize_database(path: os.PathLike):
 def initialize_database(path: os.PathLike):
     logging("Initializing database...")
     logging("Initializing database...")
-    database = duckdb.connect(path)
+    conn = duckdb.connect(path)
 
 
-    with database.cursor() as cursor:
+    with conn.cursor() as cursor:
         # Get the latest migration ID if the migrations table exists
         # Get the latest migration ID if the migrations table exists
         try:
         try:
             cursor.execute(
             cursor.execute(
@@ -472,40 +513,17 @@ def initialize_database(path: os.PathLike):
             )
             )
 
 
             latest_migration = int(cursor.fetchone()[0])
             latest_migration = int(cursor.fetchone()[0])
+
         except:
         except:
             latest_migration = 0
             latest_migration = 0
 
 
-        # Version 1
-
-        if latest_migration < 1:
-            cursor.execute(
-                """
-                CREATE TABLE IF NOT EXISTS token_usage (
-                    message_id TEXT PRIMARY KEY,
-                    room_id TEXT NOT NULL,
-                    tokens INTEGER NOT NULL,
-                    timestamp TIMESTAMP NOT NULL
-                )
-                """
-            )
-
-            cursor.execute(
-                """
-                CREATE TABLE IF NOT EXISTS migrations (
-                    id INTEGER NOT NULL,
-                    timestamp TIMESTAMP NOT NULL
-                )
-                """
-            )
-
-            cursor.execute(
-                "INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
-                (datetime.now(),)
-            )
-
-        database.commit()
+    for migration, function in MIGRATIONS.items():
+        if latest_migration < migration:
+            logging(f"Running migration {migration}...")
+            function(conn)
+            latest_migration = migration
 
 
-        return database
+    return conn
 
 
 
 
 async def get_device_id(access_token, homeserver):
 async def get_device_id(access_token, homeserver):

+ 11 - 0
migrations/__init__.py

@@ -0,0 +1,11 @@
+from collections import OrderedDict
+
+from .migration_1 import migration as migration_1
+from .migration_2 import migration as migration_2
+from .migration_3 import migration as migration_3
+
+MIGRATIONS = OrderedDict()
+
+MIGRATIONS[1] = migration_1
+MIGRATIONS[2] = migration_2
+MIGRATIONS[3] = migration_3

+ 32 - 0
migrations/migration_1.py

@@ -0,0 +1,32 @@
+# Initial migration, token usage logging
+
+from datetime import datetime
+
+def migration(conn):
+    with conn.cursor() as cursor:
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS token_usage (
+                message_id TEXT PRIMARY KEY,
+                room_id TEXT NOT NULL,
+                tokens INTEGER NOT NULL,
+                timestamp TIMESTAMP NOT NULL
+            )
+            """
+        )
+
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS migrations (
+                id INTEGER NOT NULL,
+                timestamp TIMESTAMP NOT NULL
+            )
+            """
+        )
+
+        cursor.execute(
+            "INSERT INTO migrations (id, timestamp) VALUES (1, ?)",
+            (datetime.now(),)
+        )
+
+        conn.commit()

+ 138 - 0
migrations/migration_2.py

@@ -0,0 +1,138 @@
+# Migration for Matrix Store
+
+from datetime import datetime
+
+def migration(conn):
+    with conn.cursor() as cursor:
+        # Create accounts table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS accounts (
+            id INTEGER PRIMARY KEY,
+            user_id VARCHAR NOT NULL,
+            device_id VARCHAR NOT NULL,
+            shared_account INTEGER NOT NULL,
+            pickle VARCHAR NOT NULL
+        );
+        """)
+
+        # Create device_keys table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS device_keys (
+            device_id TEXT PRIMARY KEY,
+            account_id INTEGER NOT NULL,
+            user_id TEXT NOT NULL,
+            display_name TEXT,
+            deleted BOOLEAN NOT NULL DEFAULT 0,
+            UNIQUE (account_id, user_id, device_id),
+            FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+        );
+
+        CREATE TABLE IF NOT EXISTS keys (
+            key_type TEXT NOT NULL,
+            key TEXT NOT NULL,
+            device_id VARCHAR NOT NULL,
+            UNIQUE (key_type, device_id),
+            FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create device_trust_state table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS device_trust_state (
+            device_id VARCHAR PRIMARY KEY,
+            state INTEGER NOT NULL,
+            FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create olm_sessions table
+        cursor.execute("""
+        CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
+
+        CREATE TABLE IF NOT EXISTS olm_sessions (
+            id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
+            account_id INTEGER NOT NULL,
+            sender_key TEXT NOT NULL,
+            session BLOB NOT NULL,
+            session_id VARCHAR NOT NULL,
+            creation_time TIMESTAMP NOT NULL,
+            last_usage_date TIMESTAMP NOT NULL,
+            FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create inbound_group_sessions table
+        cursor.execute("""
+        CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
+
+        CREATE TABLE IF NOT EXISTS inbound_group_sessions (
+            id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
+            account_id INTEGER NOT NULL,
+            session TEXT NOT NULL,
+            fp_key TEXT NOT NULL,
+            sender_key TEXT NOT NULL,
+            room_id TEXT NOT NULL,
+            UNIQUE (account_id, sender_key, fp_key, room_id),
+            FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+        );
+
+        CREATE TABLE IF NOT EXISTS forwarded_chains (
+            id INTEGER PRIMARY KEY,
+            session_id INTEGER NOT NULL,
+            sender_key TEXT NOT NULL,
+            FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create outbound_group_sessions table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS outbound_group_sessions (
+            id INTEGER PRIMARY KEY,
+            account_id INTEGER NOT NULL,
+            room_id VARCHAR NOT NULL,
+            session_id VARCHAR NOT NULL UNIQUE,
+            session BLOB NOT NULL,
+            FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create outgoing_key_requests table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS outgoing_key_requests (
+            id INTEGER PRIMARY KEY,
+            account_id INTEGER NOT NULL,
+            request_id TEXT NOT NULL,
+            session_id TEXT NOT NULL,
+            room_id TEXT NOT NULL,
+            algorithm TEXT NOT NULL,
+            FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
+            UNIQUE (account_id, request_id)
+        );
+
+        """)
+
+        # Create encrypted_rooms table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS encrypted_rooms (
+            room_id TEXT NOT NULL,
+            account_id INTEGER NOT NULL,
+            PRIMARY KEY (room_id, account_id),
+            FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+        );
+        """)
+
+        # Create sync_tokens table
+        cursor.execute("""
+        CREATE TABLE IF NOT EXISTS sync_tokens (
+            account_id INTEGER PRIMARY KEY,
+            token TEXT NOT NULL,
+            FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+        );
+        """)
+        
+        cursor.execute(
+            "INSERT INTO migrations (id, timestamp) VALUES (2, ?)",
+            (datetime.now(),)
+        )
+
+        conn.commit()

+ 24 - 0
migrations/migration_3.py

@@ -0,0 +1,24 @@
+# Migration for custom system messages
+
+from datetime import datetime
+
+def migration(conn):
+    with conn.cursor() as cursor:
+        cursor.execute(
+            """
+            CREATE TABLE IF NOT EXISTS system_messages (
+                room_id TEXT NOT NULL,
+                message_id TEXT NOT NULL,
+                user_id TEXT NOT NULL,
+                body TEXT NOT NULL,
+                timestamp BIGINT NOT NULL,
+            )
+            """
+        )
+
+        cursor.execute(
+            "INSERT INTO migrations (id, timestamp) VALUES (3, ?)",
+            (datetime.now(),)
+        )
+
+        conn.commit()