Browse Source

Some refactoring, starting implementation of encryption

Kumi 1 year ago
parent
commit
f20b762558
11 changed files with 976 additions and 116 deletions
  1. 1 0
      classes/__init__.py
  2. 730 0
      classes/store.py
  3. 1 1
      commands/__init__.py
  4. 4 4
      commands/botinfo.py
  5. 3 4
      commands/coin.py
  6. 3 4
      commands/help.py
  7. 2 4
      commands/ignoreolder.py
  8. 4 3
      commands/newroom.py
  9. 5 9
      commands/stats.py
  10. 3 4
      commands/unknown.py
  11. 220 83
      gptbot.py

+ 1 - 0
classes/__init__.py

@@ -0,0 +1 @@
+from .store import DuckDBStore

+ 730 - 0
classes/store.py

@@ -0,0 +1,730 @@
+import duckdb
+
+from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore
+from nio.crypto import OlmAccount, OlmDevice
+
+from random import SystemRandom
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple
+
+import json
+
+
+class DuckDBStore(MatrixStore):
+    @property
+    def account_id(self):
+        id = self._get_account()[0] if self._get_account() else None
+
+        if id is None:
+            id = SystemRandom().randint(0, 2**16)
+
+        return id
+
+    def __init__(self, user_id, device_id, duckdb_conn):
+        self.conn = duckdb_conn
+        self.user_id = user_id
+        self.device_id = device_id
+        self._create_tables()
+
+    def _create_tables(self):
+        with self.conn.cursor() as cursor:
+            cursor.execute("""
+            DROP TABLE IF EXISTS sync_tokens CASCADE;
+            DROP TABLE IF EXISTS encrypted_rooms CASCADE;
+            DROP TABLE IF EXISTS outgoing_key_requests CASCADE;
+            DROP TABLE IF EXISTS forwarded_chains CASCADE;
+            DROP TABLE IF EXISTS outbound_group_sessions CASCADE;
+            DROP TABLE IF EXISTS inbound_group_sessions CASCADE;
+            DROP TABLE IF EXISTS olm_sessions CASCADE;
+            DROP TABLE IF EXISTS device_trust_state CASCADE;
+            DROP TABLE IF EXISTS keys CASCADE;
+            DROP TABLE IF EXISTS device_keys_key CASCADE;
+            DROP TABLE IF EXISTS device_keys CASCADE;
+            DROP TABLE IF EXISTS accounts CASCADE;
+            """)
+
+            # Create accounts table
+            cursor.execute("""
+            CREATE TABLE IF NOT EXISTS accounts (
+                id INTEGER PRIMARY KEY,
+                user_id VARCHAR NOT NULL,
+                device_id VARCHAR NOT NULL,
+                shared_account INTEGER NOT NULL,
+                pickle VARCHAR NOT NULL
+            );
+            """)
+
+            # Create device_keys table
+            cursor.execute("""
+            CREATE TABLE IF NOT EXISTS device_keys (
+                device_id TEXT PRIMARY KEY,
+                account_id INTEGER NOT NULL,
+                user_id TEXT NOT NULL,
+                display_name TEXT,
+                deleted BOOLEAN NOT NULL DEFAULT 0,
+                UNIQUE (account_id, user_id, device_id),
+                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+            );
+
+            CREATE TABLE IF NOT EXISTS keys (
+                key_type TEXT NOT NULL,
+                key TEXT NOT NULL,
+                device_id VARCHAR NOT NULL,
+                UNIQUE (key_type, device_id),
+                FOREIGN KEY (device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
+            );
+            """)
+
+            # Create device_trust_state table
+            cursor.execute("""
+            CREATE TABLE IF NOT EXISTS device_trust_state (
+                device_id VARCHAR PRIMARY KEY,
+                state INTEGER NOT NULL,
+                FOREIGN KEY(device_id) REFERENCES device_keys(device_id) ON DELETE CASCADE
+            );
+            """)
+
+            # Create olm_sessions table
+            cursor.execute("""
+            CREATE SEQUENCE IF NOT EXISTS olm_sessions_id_seq START 1;
+
+            CREATE TABLE IF NOT EXISTS olm_sessions (
+                id INTEGER PRIMARY KEY DEFAULT nextval('olm_sessions_id_seq'),
+                account_id INTEGER NOT NULL,
+                sender_key TEXT NOT NULL,
+                session BLOB NOT NULL,
+                session_id VARCHAR NOT NULL,
+                creation_time TIMESTAMP NOT NULL,
+                last_usage_date TIMESTAMP NOT NULL,
+                FOREIGN KEY (account_id) REFERENCES accounts (id) ON DELETE CASCADE
+            );
+            """)
+
+            # Create inbound_group_sessions table
+            cursor.execute("""
+            CREATE SEQUENCE IF NOT EXISTS inbound_group_sessions_id_seq START 1;
+
+            CREATE TABLE IF NOT EXISTS inbound_group_sessions (
+                id INTEGER PRIMARY KEY DEFAULT nextval('inbound_group_sessions_id_seq'),
+                account_id INTEGER NOT NULL,
+                session TEXT NOT NULL,
+                fp_key TEXT NOT NULL,
+                sender_key TEXT NOT NULL,
+                room_id TEXT NOT NULL,
+                UNIQUE (account_id, sender_key, fp_key, room_id),
+                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+            );
+
+            CREATE TABLE IF NOT EXISTS forwarded_chains (
+                id INTEGER PRIMARY KEY,
+                session_id INTEGER NOT NULL,
+                sender_key TEXT NOT NULL,
+                FOREIGN KEY (session_id) REFERENCES inbound_group_sessions(id) ON DELETE CASCADE
+            );
+            """)
+
+            # Create outbound_group_sessions table
+            cursor.execute("""
+            CREATE TABLE IF NOT EXISTS outbound_group_sessions (
+                id INTEGER PRIMARY KEY,
+                account_id INTEGER NOT NULL,
+                room_id VARCHAR NOT NULL,
+                session_id VARCHAR NOT NULL UNIQUE,
+                session BLOB NOT NULL,
+                FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
+            );
+            """)
+
+            # Create outgoing_key_requests table
+            cursor.execute("""
+            CREATE TABLE IF NOT EXISTS outgoing_key_requests (
+                id INTEGER PRIMARY KEY,
+                account_id INTEGER NOT NULL,
+                request_id TEXT NOT NULL,
+                session_id TEXT NOT NULL,
+                room_id TEXT NOT NULL,
+                algorithm TEXT NOT NULL,
+                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
+                UNIQUE (account_id, request_id)
+            );
+
+            """)
+
+            # Create encrypted_rooms table
+            cursor.execute("""
+            CREATE TABLE IF NOT EXISTS encrypted_rooms (
+                room_id TEXT NOT NULL,
+                account_id INTEGER NOT NULL,
+                PRIMARY KEY (room_id, account_id),
+                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+            );
+            """)
+
+            # Create sync_tokens table
+            cursor.execute("""
+            CREATE TABLE IF NOT EXISTS sync_tokens (
+                account_id INTEGER PRIMARY KEY,
+                token TEXT NOT NULL,
+                FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
+            );
+            """)
+
+    def _get_account(self):
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "SELECT * FROM accounts WHERE user_id = ? AND device_id = ?",
+            (self.user_id, self.device_id),
+        )
+        account = cursor.fetchone()
+        cursor.close()
+        return account
+
+    def _get_device(self, device):
+        acc = self._get_account()
+
+        if not acc:
+            return None
+
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? AND account_id = ?",
+            (device.user_id, device.id, acc[0]),
+        )
+        device_entry = cursor.fetchone()
+        cursor.close()
+
+        return device_entry
+
+    # Implementing methods with DuckDB equivalents
+    def verify_device(self, device):
+        if self.is_device_verified(device):
+            return False
+
+        d = self._get_device(device)
+        assert d
+
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
+            (d[0], TrustState.verified),
+        )
+        self.conn.commit()
+        cursor.close()
+
+        device.trust_state = TrustState.verified
+
+        return True
+
+    def unverify_device(self, device):
+        if not self.is_device_verified(device):
+            return False
+
+        d = self._get_device(device)
+        assert d
+
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
+            (d[0], TrustState.unset),
+        )
+        self.conn.commit()
+        cursor.close()
+
+        device.trust_state = TrustState.unset
+
+        return True
+
+    def is_device_verified(self, device):
+        d = self._get_device(device)
+
+        if not d:
+            return False
+
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
+        )
+        trust_state = cursor.fetchone()
+        cursor.close()
+
+        if not trust_state:
+            return False
+
+        return trust_state[0] == TrustState.verified
+
+    def blacklist_device(self, device):
+        if self.is_device_blacklisted(device):
+            return False
+
+        d = self._get_device(device)
+        assert d
+
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
+            (d[0], TrustState.blacklisted),
+        )
+        self.conn.commit()
+        cursor.close()
+
+        device.trust_state = TrustState.blacklisted
+
+        return True
+
+    def unblacklist_device(self, device):
+        if not self.is_device_blacklisted(device):
+            return False
+
+        d = self._get_device(device)
+        assert d
+
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
+            (d[0], TrustState.unset),
+        )
+        self.conn.commit()
+        cursor.close()
+
+        device.trust_state = TrustState.unset
+
+        return True
+
+    def is_device_blacklisted(self, device):
+        d = self._get_device(device)
+
+        if not d:
+            return False
+
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
+        )
+        trust_state = cursor.fetchone()
+        cursor.close()
+
+        if not trust_state:
+            return False
+
+        return trust_state[0] == TrustState.blacklisted
+
+    def ignore_device(self, device):
+        if self.is_device_ignored(device):
+            return False
+
+        d = self._get_device(device)
+        assert d
+
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
+            (d[0], int(TrustState.ignored.value)),
+        )
+        self.conn.commit()
+        cursor.close()
+
+        return True
+
+    def ignore_devices(self, devices):
+        for device in devices:
+            self.ignore_device(device)
+
+    def unignore_device(self, device):
+        if not self.is_device_ignored(device):
+            return False
+
+        d = self._get_device(device)
+        assert d
+
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
+            (d[0], TrustState.unset),
+        )
+        self.conn.commit()
+        cursor.close()
+
+        device.trust_state = TrustState.unset
+
+        return True
+
+    def is_device_ignored(self, device):
+        d = self._get_device(device)
+
+        if not d:
+            return False
+
+        cursor = self.conn.cursor()
+        cursor.execute(
+            "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
+        )
+        trust_state = cursor.fetchone()
+        cursor.close()
+
+        if not trust_state:
+            return False
+
+        return trust_state[0] == TrustState.ignored
+
+    def load_device_keys(self):
+        """Load all the device keys from the database.
+
+        Returns DeviceStore containing the OlmDevices with the device keys.
+        """
+        store = DeviceStore()
+        account = self.account_id
+
+        if not account:
+            return store
+
+        with self.conn.cursor() as cur:
+            cur.execute(
+                "SELECT * FROM device_keys WHERE account_id = ?",
+                (account,)
+            )
+            device_keys = cur.fetchall()
+
+            for d in device_keys:
+                cur.execute(
+                    "SELECT * FROM keys WHERE device_id = ?",
+                    (d["id"],)
+                )
+                keys = cur.fetchall()
+                key_dict = {k["key_type"]: k["key"] for k in keys}
+
+                store.add(
+                    OlmDevice(
+                        d["user_id"],
+                        d["device_id"],
+                        key_dict,
+                        display_name=d["display_name"],
+                        deleted=d["deleted"],
+                    )
+                )
+
+        return store
+
+    def save_device_keys(self, device_keys):
+        """Save the provided device keys to the database."""
+        account = self.account_id
+        assert account
+        rows = []
+
+        for user_id, devices_dict in device_keys.items():
+            for device_id, device in devices_dict.items():
+                rows.append(
+                    {
+                        "account_id": account,
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "display_name": device.display_name,
+                        "deleted": device.deleted,
+                    }
+                )
+
+        if not rows:
+            return
+
+        with self.conn.cursor() as cur:
+            for idx in range(0, len(rows), 100):
+                data = rows[idx: idx + 100]
+                cur.executemany(
+                    "INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)",
+                    [(r["account_id"], r["user_id"], r["device_id"],
+                      r["display_name"], r["deleted"]) for r in data]
+                )
+
+            for user_id, devices_dict in device_keys.items():
+                for device_id, device in devices_dict.items():
+                    cur.execute(
+                        "UPDATE device_keys SET deleted = ? WHERE device_id = ?",
+                        (device.deleted, device_id)
+                    )
+
+                    for key_type, key in device.keys.items():
+                        cur.execute("""
+                            INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?)
+                            ON CONFLICT (key_type, device_id) DO UPDATE SET key = ?
+                            """,
+                                    (key_type, key, device_id, key)
+                                    )
+            self.conn.commit()
+
+    def save_group_sessions(self, sessions):
+        with self.conn.cursor() as cur:
+            for session in sessions:
+                cur.execute("""
+                    INSERT OR REPLACE INTO inbound_group_sessions (
+                        session_id, sender_key, signing_key, room_id, pickle, account_id
+                    ) VALUES (?, ?, ?, ?, ?, ?)
+                """, (
+                    session.id,
+                    session.sender_key,
+                    session.signing_key,
+                    session.room_id,
+                    session.pickle,
+                    self.account_id
+                ))
+
+            self.conn.commit()
+
+    def save_olm_sessions(self, sessions):
+        with self.conn.cursor() as cur:
+            for session in sessions:
+                cur.execute("""
+                    INSERT OR REPLACE INTO olm_sessions (
+                        session_id, sender_key, pickle, account_id
+                    ) VALUES (?, ?, ?, ?)
+                """, (
+                    session.id,
+                    session.sender_key,
+                    session.pickle,
+                    self.account_id
+                ))
+
+            self.conn.commit()
+
+    def save_outbound_group_sessions(self, sessions):
+        with self.conn.cursor() as cur:
+            for session in sessions:
+                cur.execute("""
+                    INSERT OR REPLACE INTO outbound_group_sessions (
+                        room_id, session_id, pickle, account_id
+                    ) VALUES (?, ?, ?, ?)
+                """, (
+                    session.room_id,
+                    session.id,
+                    session.pickle,
+                    self.account_id
+                ))
+
+            self.conn.commit()
+
+    def save_account(self, account: OlmAccount):
+        with self.conn.cursor() as cur:
+            cur.execute("""
+                INSERT OR REPLACE INTO accounts (
+                    id, user_id, device_id, shared_account, pickle
+                ) VALUES (?, ?, ?, ?, ?)
+            """, (
+                self.account_id,
+                self.user_id,
+                self.device_id,
+                account.shared,
+                account.pickle(self.pickle_key),
+            ))
+
+            self.conn.commit()
+
+    def load_sessions(self):
+        session_store = SessionStore()
+
+        with self.conn.cursor() as cur:
+            cur.execute("""
+                SELECT
+                    os.sender_key, os.session, os.creation_time
+                FROM
+                    olm_sessions os
+                INNER JOIN
+                    accounts a ON os.account_id = a.id
+                WHERE
+                    a.id = ?
+            """, (self.account_id,))
+
+            for row in cur.fetchall():
+                sender_key, session_pickle, creation_time = row
+                session = Session.from_pickle(
+                    session_pickle, creation_time, self.pickle_key)
+                session_store.add(sender_key, session)
+
+        return session_store
+
+    def load_inbound_group_sessions(self):
+        # type: () -> GroupSessionStore
+        """Load all Olm sessions from the database.
+
+        Returns:
+            ``GroupSessionStore`` object, containing all the loaded sessions.
+
+        """
+        store = GroupSessionStore()
+
+        account = self.account_id
+
+        if not account:
+            return store
+
+        with self.conn.cursor() as cursor:
+            cursor.execute(
+                "SELECT * FROM inbound_group_sessions WHERE account_id = ?", (
+                    account,)
+            )
+
+            for row in cursor.fetchall():
+                session = InboundGroupSession.from_pickle(
+                    row["session"],
+                    row["fp_key"],
+                    row["sender_key"],
+                    row["room_id"],
+                    self.pickle_key,
+                    [
+                        chain["sender_key"]
+                        for chain in cursor.execute(
+                            "SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
+                            (row["id"],),
+                        )
+                    ],
+                )
+                store.add(session)
+
+        return store
+
+    def load_outgoing_key_requests(self):
+        # type: () -> dict
+        """Load all outgoing key requests from the database.
+
+        Returns:
+            ``OutgoingKeyRequestStore`` object, containing all the loaded key requests.
+        """
+        account = self.account_id
+
+        if not account:
+            return store
+
+        with self.conn.cursor() as cur:
+            cur.execute(
+                "SELECT * FROM outgoing_key_requests WHERE account_id = ?",
+                (account,)
+            )
+            rows = cur.fetchall()
+
+        return {
+            request.request_id: OutgoingKeyRequest.from_database(request)
+            for request in rows
+        }
+
+    def load_encrypted_rooms(self):
+        """Load the set of encrypted rooms for this account.
+
+        Returns:
+            ``Set`` containing room ids of encrypted rooms.
+        """
+        account = self.account_id
+
+        if not account:
+            return set()
+
+        with self.conn.cursor() as cur:
+            cur.execute(
+                "SELECT room_id FROM encrypted_rooms WHERE account_id = ?",
+                (account,)
+            )
+            rows = cur.fetchall()
+
+        return {row["room_id"] for row in rows}
+
+    def save_sync_token(self, token):
+        """Save the given token"""
+        account = self.account_id
+        assert account
+
+        with self.conn.cursor() as cur:
+            cur.execute(
+                "INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)",
+                (account, token)
+            )
+            self.conn.commit()
+
+    def save_encrypted_rooms(self, rooms):
+        """Save the set of room ids for this account."""
+        account = self.account_id
+        assert account
+
+        data = [(room_id, account) for room_id in rooms]
+
+        with self.conn.cursor() as cur:
+            for idx in range(0, len(data), 400):
+                rows = data[idx: idx + 400]
+                cur.executemany(
+                    "INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)",
+                    rows
+                )
+            self.conn.commit()
+
+    def save_session(self, sender_key, session):
+        """Save the provided Olm session to the database.
+
+        Args:
+            sender_key (str): The curve key that owns the Olm session.
+            session (Session): The Olm session that will be pickled and
+                saved in the database.
+        """
+        account = self.account_id
+        assert account
+
+        pickled_session = session.pickle(self.pickle_key)
+
+        with self.conn.cursor() as cur:
+
+            cur.execute(
+                "INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)",
+                (account, sender_key, pickled_session, session.id,
+                 session.creation_time, session.use_time)
+            )
+            self.conn.commit()
+
+    def save_inbound_group_session(self, session):
+        """Save the provided Megolm inbound group session to the database.
+
+        Args:
+            session (InboundGroupSession): The session to save.
+        """
+        account = self.account_id
+        assert account
+
+        with self.conn.cursor() as cur:
+
+            # Insert a new session or update the existing one
+            query = """
+            INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session)
+            VALUES (?, ?, ?, ?, ?)
+            ON CONFLICT (account_id, sender_key, fp_key, room_id)
+            DO UPDATE SET session = excluded.session
+            """
+            cur.execute(query, (account, session.sender_key,
+                                session.ed25519, session.room_id, session.pickle(self.pickle_key)))
+
+            # Delete existing forwarded chains for the session
+            delete_query = """
+            DELETE FROM forwarded_chains WHERE session_id = (SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?)
+            """
+            cur.execute(
+                delete_query, (account, session.sender_key, session.ed25519, session.room_id))
+
+            # Insert new forwarded chains for the session
+            insert_query = """
+            INSERT INTO forwarded_chains (session_id, sender_key)
+            VALUES ((SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?), ?)
+            """
+
+            for chain in session.forwarding_chain:
+                cur.execute(
+                    insert_query, (account, session.sender_key, session.ed25519, session.room_id, chain))
+
+    def add_outgoing_key_request(self, key_request):
+        account_id = self.account_id
+        with self.conn.cursor() as cursor:
+            cursor.execute(
+                """
+                INSERT INTO outgoing_key_requests (account_id, request_id, session_id, room_id, algorithm)
+                VALUES (?, ?, ?, ?, ?)
+                ON CONFLICT (account_id, request_id) DO NOTHING
+                """,
+                (
+                    account_id,
+                    key_request.request_id,
+                    key_request.session_id,
+                    key_request.room_id,
+                    key_request.algorithm,
+                )
+            )

+ 1 - 1
commands/__init__.py

@@ -14,4 +14,4 @@ COMMANDS = {
     "coin": command_coin,
     "ignoreolder": command_ignoreolder,
     None: command_unknown,
-}
+}

+ 4 - 4
commands/botinfo.py

@@ -1,12 +1,12 @@
 from nio.events.room_events import RoomMessageText
 from nio.rooms import MatrixRoom
 
+
 async def command_botinfo(room: MatrixRoom, event: RoomMessageText, context: dict):
     logging("Showing bot info...")
 
-    await context["client"].room_send(
-        room.room_id, "m.room.message", {"msgtype": "m.notice",
-                                         "body": f"""GPT Info:
+    return room.room_id, "m.room.message", {"msgtype": "m.notice",
+                                            "body": f"""GPT Info:
 
 Model: {context["model"]}
 Maximum context tokens: {context["max_tokens"]}
@@ -19,4 +19,4 @@ Bot user ID: {context["client"].user_id}
 Current room ID: {room.room_id}
 
 For usage statistics, run !gptbot stats
-"""})
+"""}

+ 3 - 4
commands/coin.py

@@ -3,12 +3,11 @@ from nio.rooms import MatrixRoom
 
 from random import SystemRandom
 
+
 async def command_coin(room: MatrixRoom, event: RoomMessageText, context: dict):
     context["logger"]("Flipping a coin...")
 
     heads = SystemRandom().choice([True, False])
 
-    await context["client"].room_send(
-        room.room_id, "m.room.message", {"msgtype": "m.notice",
-                                         "body": "Heads!" if heads else "Tails!"}
-    )
+    return room.room_id, "m.room.message", {"msgtype": "m.notice",
+                                            "body": "Heads!" if heads else "Tails!"}

+ 3 - 4
commands/help.py

@@ -1,10 +1,10 @@
 from nio.events.room_events import RoomMessageText
 from nio.rooms import MatrixRoom
 
+
 async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
-    await context["client"].room_send(
-        room.room_id, "m.room.message", {"msgtype": "m.notice",
-                                         "body": """Available commands:
+    return room.guest_accessroom_id, "m.room.message", {"msgtype": "m.notice",
+                                                        "body": """Available commands:
 
 !gptbot help - Show this message
 !gptbot newroom <room name> - Create a new room and invite yourself to it
@@ -13,4 +13,3 @@ async def command_help(room: MatrixRoom, event: RoomMessageText, context: dict):
 !gptbot coin - Flip a coin (heads or tails)
 !gptbot ignoreolder - Ignore messages before this point as context
 """}
-    )

+ 2 - 4
commands/ignoreolder.py

@@ -2,9 +2,7 @@ from nio.events.room_events import RoomMessageText
 from nio.rooms import MatrixRoom
 
 async def command_ignoreolder(room: MatrixRoom, event: RoomMessageText, context: dict):
-    await context["client"].room_send(
-        room.room_id, "m.room.message", {"msgtype": "m.notice",
+    return room.room_id, "m.room.message", {"msgtype": "m.notice",
                                          "body": """Alright, messages before this point will not be processed as context anymore.
                                          
-If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""}
-    )
+If you ever reconsider, you can simply delete your message and I will start processing messages before it again."""}

+ 4 - 3
commands/newroom.py

@@ -1,8 +1,10 @@
 from nio.events.room_events import RoomMessageText
 from nio.rooms import MatrixRoom
 
+
 async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dict):
-    room_name = " ".join(event.body.split()[2:]) or context["default_room_name"]
+    room_name = " ".join(event.body.split()[
+                         2:]) or context["default_room_name"]
 
     context["logger"]("Creating new room...")
     new_room = await context["client"].room_create(name=room_name)
@@ -12,5 +14,4 @@ async def command_newroom(room: MatrixRoom, event: RoomMessageText, context: dic
     await context["client"].room_put_state(
         new_room.room_id, "m.room.power_levels", {"users": {event.sender: 100}})
 
-    await context["client"].room_send(
-        new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"})
+    return new_room.room_id, "m.room.message", {"msgtype": "m.text", "body": "Welcome to the new room!"}

+ 5 - 9
commands/stats.py

@@ -1,23 +1,19 @@
 from nio.events.room_events import RoomMessageText
 from nio.rooms import MatrixRoom
 
+
 async def command_stats(room: MatrixRoom, event: RoomMessageText, context: dict):
     context["logger"]("Showing stats...")
 
     if not (database := context.get("database")):
         context["logger"]("No database connection - cannot show stats")
-        context["client"].room_send(
-            room.room_id, "m.room.message", {"msgtype": "m.notice",
-                                             "body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."}
-        )
-        return
+        return room.room_id, "m.room.message", {"msgtype": "m.notice",
+                                                "body": "Sorry, I'm not connected to a database, so I don't have any statistics on your usage."}
 
     with database.cursor() as cursor:
         cursor.execute(
             "SELECT SUM(tokens) FROM token_usage WHERE room_id = ?", (room.room_id,))
         total_tokens = cursor.fetchone()[0] or 0
 
-    await context["client"].room_send(
-        room.room_id, "m.room.message", {"msgtype": "m.notice",
-                                         "body": f"Total tokens used: {total_tokens}"}
-    )
+    return room.room_id, "m.room.message", {"msgtype": "m.notice",
+                                            "body": f"Total tokens used: {total_tokens}"}

+ 3 - 4
commands/unknown.py

@@ -1,10 +1,9 @@
 from nio.events.room_events import RoomMessageText
 from nio.rooms import MatrixRoom
 
+
 async def command_unknown(room: MatrixRoom, event: RoomMessageText, context: dict):
     context["logger"]("Unknown command")
 
-    await context["client"].room_send(
-        room.room_id, "m.room.message", {"msgtype": "m.notice",
-                                         "body": "Unknown command - try !gptbot help"}
-    )
+    return room.room_id, "m.room.message", {"msgtype": "m.notice",
+                                            "body": "Unknown command - try !gptbot help"}

+ 220 - 83
gptbot.py

@@ -3,6 +3,7 @@ import inspect
 import logging
 import signal
 import random
+import uuid
 
 import openai
 import asyncio
@@ -10,9 +11,10 @@ import markdown2
 import tiktoken
 import duckdb
 
-from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent
+from nio import AsyncClient, RoomMessageText, MatrixRoom, Event, InviteEvent, AsyncClientConfig, MegolmEvent, GroupEncryptionError, EncryptionError, HttpClient, Api
 from nio.api import MessageDirection
-from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError
+from nio.responses import RoomMessagesError, SyncResponse, RoomRedactError, WhoamiResponse, JoinResponse, RoomSendResponse
+from nio.crypto import Olm
 
 from configparser import ConfigParser
 from datetime import datetime
@@ -20,6 +22,7 @@ from argparse import ArgumentParser
 from typing import List, Dict, Union, Optional
 
 from commands import COMMANDS
+from classes import DuckDBStore
 
 
 def logging(message: str, log_level: str = "info"):
@@ -85,6 +88,13 @@ async def fetch_last_n_messages(room_id: str, n: Optional[int] = None,
     for event in response.chunk:
         if len(messages) >= n:
             break
+        if isinstance(event, MegolmEvent):
+            try:
+                event = await client.decrypt_event(event)
+            except (GroupEncryptionError, EncryptionError):
+                logging(
+                    f"Could not decrypt message {event.event_id} in room {room_id}", "error")
+                continue
         if isinstance(event, RoomMessageText):
             if event.body.startswith("!gptbot ignoreolder"):
                 break
@@ -162,14 +172,7 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
 
         # Convert markdown to HTML
 
-        markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
-        formatted_body = markdowner.convert(response)
-
-        message = await client.room_send(
-            room.room_id, "m.room.message",
-            {"msgtype": "m.text", "body": response,
-             "format": "org.matrix.custom.html", "formatted_body": formatted_body}
-        )
+        message = await send_message(room, response)
 
         if database:
             logging("Logging tokens used...")
@@ -183,11 +186,8 @@ async def process_query(room: MatrixRoom, event: RoomMessageText, **kwargs):
         # Send a notice to the room if there was an error
 
         logging("Error during GPT API call - sending notice to room")
-
-        await client.room_send(
-            room.room_id, "m.room.message", {
-                "msgtype": "m.notice", "body": "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later."}
-        )
+        send_message(
+            room, "Sorry, I'm having trouble connecting to the GPT API right now. Please try again later.", True)
         print("No response from GPT API")
 
     await client.room_typing(room.room_id, False)
@@ -199,14 +199,34 @@ async def process_command(room: MatrixRoom, event: RoomMessageText, context: Opt
     logging(
         f"Received command {event.body} from {event.sender} in room {room.room_id}")
     command = event.body.split()[1] if event.body.split()[1:] else None
-    await COMMANDS.get(command, COMMANDS[None])(room, event, context)
 
+    message = await COMMANDS.get(command, COMMANDS[None])(room, event, context)
+
+    if message:
+        room_id, event, content = message
+        await send_message(context["client"].rooms[room_id], content["body"],
+                           True if content["msgtype"] == "m.notice" else False, context["client"])
 
-async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs):
+
+async def message_callback(room: MatrixRoom, event: RoomMessageText | MegolmEvent, **kwargs):
     context = kwargs.get("context") or CONTEXT
-    
+
     logging(f"Received message from {event.sender} in room {room.room_id}")
 
+    if isinstance(event, MegolmEvent):
+        try:
+            event = await context["client"].decrypt_event(event)
+        except Exception as e:
+            try:
+                logging("Requesting new encryption keys...")
+                await context["client"].request_room_key(event)
+            except:
+                pass
+
+            logging(f"Error decrypting message: {e}", "error")
+            await send_message(room, "Sorry, I couldn't decrypt that message. Please try again later or switch to a room without encryption.", True, context["client"])
+            return
+
     if event.sender == context["client"].user_id:
         logging("Message is from bot itself - ignoring")
 
@@ -221,18 +241,69 @@ async def message_callback(room: MatrixRoom, event: RoomMessageText, **kwargs):
 
 
 async def room_invite_callback(room: MatrixRoom, event: InviteEvent, **kwargs):
-    client = kwargs.get("client") or CONTEXT["client"]
+    client: AsyncClient = kwargs.get("client") or CONTEXT["client"]
+
+    if room.room_id in client.rooms:
+        logging(f"Already in room {room.room_id} - ignoring invite")
+        return
 
     logging(f"Received invite to room {room.room_id} - joining...")
 
-    await client.join(room.room_id)
-    await client.room_send(
-        room.room_id,
-        "m.room.message",
-        {"msgtype": "m.text",
-            "body": "Hello! I'm a helpful assistant. How can I help you today?"}
+    response = await client.join(room.room_id)
+    if isinstance(response, JoinResponse):
+        await send_message(room, "Hello! I'm a helpful assistant. How can I help you today?", client)
+    else:
+        logging(f"Error joining room {room.room_id}: {response}", "error")
+
+
+async def send_message(room: MatrixRoom, message: str, notice: bool = False, client: Optional[AsyncClient] = None):
+    client = client or CONTEXT["client"]
+
+    markdowner = markdown2.Markdown(extras=["fenced-code-blocks"])
+    formatted_body = markdowner.convert(message)
+
+    msgtype = "m.notice" if notice else "m.text"
+
+    msgcontent = {"msgtype": msgtype, "body": message,
+                  "format": "org.matrix.custom.html", "formatted_body": formatted_body}
+
+    content = None
+
+    if client.olm and room.encrypted:
+        try:
+            if not room.members_synced:
+                responses = []
+                responses.append(await client.joined_members(room.room_id))
+
+            if client.olm.should_share_group_session(room.room_id):
+                try:
+                    event = client.sharing_session[room.room_id]
+                    await event.wait()
+                except KeyError:
+                    await client.share_group_session(
+                        room.room_id,
+                        ignore_unverified_devices=True,
+                    )
+
+            if msgtype != "m.reaction":
+                response = client.encrypt(room.room_id, "m.room.message", msgcontent)
+                msgtype, content = response
+
+        except Exception as e:
+            logging(
+                f"Error encrypting message: {e} - sending unencrypted", "error")
+            raise
+
+    if not content:
+        msgtype = "m.room.message"
+        content = msgcontent
+
+    method, path, data = Api.room_send(
+        client.access_token, room.room_id, msgtype, content, uuid.uuid4()
     )
 
+    return await client._send(RoomSendResponse, method, path, data, (room.room_id,))
+
 
 async def accept_pending_invites(client: Optional[AsyncClient] = None):
     client = client or CONTEXT["client"]
@@ -242,13 +313,14 @@ async def accept_pending_invites(client: Optional[AsyncClient] = None):
     for room_id in list(client.invited_rooms.keys()):
         logging(f"Joining room {room_id}...")
 
-        await client.join(room_id)
-        await client.room_send(
-            room_id,
-            "m.room.message",
-            {"msgtype": "m.text",
-                "body": "Hello! I'm a helpful assistant. How can I help you today?"}
-        )
+        response = await client.join(room_id)
+
+        if isinstance(response, JoinResponse):
+            logging(response, "debug")
+            rooms = await client.joined_rooms()
+            await send_message(rooms[room_id], "Hello! I'm a helpful assistant. How can I help you today?", client)
+        else:
+            logging(f"Error joining room {room_id}: {response}", "error")
 
 
 async def sync_cb(response, write_global: bool = True):
@@ -261,12 +333,95 @@ async def sync_cb(response, write_global: bool = True):
         CONTEXT["sync_token"] = SYNC_TOKEN
 
 
-async def main(client: Optional[AsyncClient] = None):
-    client = client or CONTEXT["client"]
+async def test_callback(room: MatrixRoom, event: Event, **kwargs):
+    logging(
+        f"Received event {event.__class__.__name__} in room {room.room_id}", "debug")
 
-    if not client.user_id:
-        whoami = await client.whoami()
-        client.user_id = whoami.user_id
+
+async def init(config: ConfigParser):
+    # Set up Matrix client
+    try:
+        assert "Matrix" in config
+        assert "Homeserver" in config["Matrix"]
+        assert "AccessToken" in config["Matrix"]
+    except:
+        logging("Matrix config not found or incomplete", "critical")
+        exit(1)
+
+    homeserver = config["Matrix"]["Homeserver"]
+    access_token = config["Matrix"]["AccessToken"]
+
+    device_id, user_id = await get_device_id(access_token, homeserver)
+
+    device_id = config["Matrix"].get("DeviceID", device_id)
+    user_id = config["Matrix"].get("UserID", user_id)
+
+    # Set up database
+    if "Database" in config and config["Database"].get("Path"):
+        database = CONTEXT["database"] = initialize_database(
+            config["Database"]["Path"])
+        matrix_store = DuckDBStore
+
+        client_config = AsyncClientConfig(
+            store_sync_tokens=True, encryption_enabled=True, store=matrix_store)
+
+    else:
+        client_config = AsyncClientConfig(
+            store_sync_tokens=True, encryption_enabled=False)
+
+    client = AsyncClient(
+        config["Matrix"]["Homeserver"], config=client_config)
+
+    if client.config.encryption_enabled:
+        client.store = client.config.store(
+            user_id,
+            device_id,
+            database
+        )
+        assert client.store
+
+        client.olm = Olm(client.user_id, client.device_id, client.store)
+        client.encrypted_rooms = client.store.load_encrypted_rooms()
+
+    CONTEXT["client"] = client
+
+    CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
+    CONTEXT["client"].user_id = user_id
+    CONTEXT["client"].device_id = device_id
+
+    # Set up GPT API
+    try:
+        assert "OpenAI" in config
+        assert "APIKey" in config["OpenAI"]
+    except:
+        logging("OpenAI config not found or incomplete", "critical")
+        exit(1)
+
+    openai.api_key = config["OpenAI"]["APIKey"]
+
+    if "Model" in config["OpenAI"]:
+        CONTEXT["model"] = config["OpenAI"]["Model"]
+
+    if "MaxTokens" in config["OpenAI"]:
+        CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
+
+    if "MaxMessages" in config["OpenAI"]:
+        CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
+
+    # Listen for SIGTERM
+
+    def sigterm_handler(_signo, _stack_frame):
+        logging("Received SIGTERM - exiting...")
+        exit()
+
+    signal.signal(signal.SIGTERM, sigterm_handler)
+
+
+async def main(config: Optional[ConfigParser] = None, client: Optional[AsyncClient] = None):
+    if not client and not CONTEXT.get("client"):
+        await init(config)
+
+    client = client or CONTEXT["client"]
 
     try:
         assert client.user_id
@@ -285,7 +440,9 @@ async def main(client: Optional[AsyncClient] = None):
     await client.sync(timeout=30000)
 
     client.add_event_callback(message_callback, RoomMessageText)
+    client.add_event_callback(message_callback, MegolmEvent)
     client.add_event_callback(room_invite_callback, InviteEvent)
+    client.add_event_callback(test_callback, Event)
 
     await accept_pending_invites()  # Accept pending invites
 
@@ -351,6 +508,31 @@ def initialize_database(path: os.PathLike):
         return database
 
 
+async def get_device_id(access_token, homeserver):
+    client = AsyncClient(homeserver)
+    client.access_token = access_token
+
+    logging(f"Obtaining device ID for access token {access_token}...", "debug")
+    response = await client.whoami()
+    if isinstance(response, WhoamiResponse):
+        logging(
+            f"Authenticated as {response.user_id}.")
+        user_id = response.user_id
+        devices = await client.devices()
+        device_id = devices.devices[0].id
+
+        await client.close()
+
+        return device_id, user_id
+
+    else:
+        logging(f"Failed to obtain device ID: {response}", "error")
+
+        await client.close()
+
+        return None, None
+
+
 if __name__ == "__main__":
     # Parse command line arguments
     parser = ArgumentParser()
@@ -362,54 +544,9 @@ if __name__ == "__main__":
     config = ConfigParser()
     config.read(args.config)
 
-    # Set up Matrix client
-    try:
-        assert "Matrix" in config
-        assert "Homeserver" in config["Matrix"]
-        assert "AccessToken" in config["Matrix"]
-    except:
-        logging("Matrix config not found or incomplete", "critical")
-        exit(1)
-
-    CONTEXT["client"] = AsyncClient(config["Matrix"]["Homeserver"])
-
-    CONTEXT["client"].access_token = config["Matrix"]["AccessToken"]
-    CONTEXT["client"].user_id = config["Matrix"].get("UserID")
-
-    # Set up GPT API
-    try:
-        assert "OpenAI" in config
-        assert "APIKey" in config["OpenAI"]
-    except:
-        logging("OpenAI config not found or incomplete", "critical")
-        exit(1)
-
-    openai.api_key = config["OpenAI"]["APIKey"]
-
-    if "Model" in config["OpenAI"]:
-        CONTEXT["model"] = config["OpenAI"]["Model"]
-
-    if "MaxTokens" in config["OpenAI"]:
-        CONTEXT["max_tokens"] = int(config["OpenAI"]["MaxTokens"])
-
-    if "MaxMessages" in config["OpenAI"]:
-        CONTEXT["max_messages"] = int(config["OpenAI"]["MaxMessages"])
-
-    # Set up database
-    if "Database" in config and config["Database"].get("Path"):
-        CONTEXT["database"] = initialize_database(config["Database"]["Path"])
-
-    # Listen for SIGTERM
-
-    def sigterm_handler(_signo, _stack_frame):
-        logging("Received SIGTERM - exiting...")
-        exit()
-
-    signal.signal(signal.SIGTERM, sigterm_handler)
-
     # Start bot loop
     try:
-        asyncio.run(main())
+        asyncio.run(main(config))
     except KeyboardInterrupt:
         logging("Received KeyboardInterrupt - exiting...")
     except SystemExit: