store.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637
  1. import duckdb
  2. from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore, Session
  3. from nio.crypto import OlmAccount, OlmDevice
  4. from random import SystemRandom
  5. from collections import defaultdict
  6. from typing import Dict, List, Optional, Tuple
  7. from .dict import AttrDict
  8. import json
  9. class DuckDBStore(MatrixStore):
  @property
  def account_id(self):
      """Return the numeric id used for this store's account.

      The id is read from the first column of the ``accounts`` row that
      matches this store's user and device. When no such row exists yet, a
      random id is generated once and remembered on the instance, so that
      repeated accesses keep returning the same value.

      Returns:
          int: The stored account id, or the cached provisional id.
      """
      account = self._get_account()
      if account and account[0] is not None:
          return account[0]

      # No persisted account yet: hand out one stable, non-zero id. Zero is
      # excluded because other methods treat a falsy id as "no account".
      provisional = getattr(self, "_provisional_account_id", None)
      if provisional is None:
          provisional = SystemRandom().randint(1, 2**16)
          self._provisional_account_id = provisional
      return provisional
  def __init__(self, user_id, device_id, duckdb_conn):
      """Bind the store to a user/device pair and a DuckDB connection.

      Args:
          user_id: Identifier of the user owning the account.
          device_id: Identifier of the device owning the account.
          duckdb_conn: Connection object; the other methods call its
              ``cursor()`` and ``commit()``.
      """
      # NOTE(review): the base-class initializer is not invoked and
      # ``pickle_key`` (read by several methods) is not assigned here;
      # presumably it comes from the base class -- confirm.
      self.user_id, self.device_id = user_id, device_id
      self.conn = duckdb_conn
  def _get_account(self):
      """Fetch the ``accounts`` row for this store's user and device.

      Returns:
          The first matching row as returned by ``fetchone()``, or ``None``
          when no row matches.
      """
      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "SELECT * FROM accounts WHERE user_id = ? AND device_id = ?",
              (self.user_id, self.device_id),
          )
          return cursor.fetchone()
      finally:
          # Release the cursor even when the query raises.
          cursor.close()
  def _get_device(self, device):
      """Fetch the ``device_keys`` row of a device for the current account.

      Args:
          device: Object exposing ``user_id`` and ``id`` attributes.

      Returns:
          The first matching row as returned by ``fetchone()``, or ``None``
          when either the account or the device has no row.
      """
      acc = self._get_account()
      if not acc:
          return None

      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? AND account_id = ?",
              (device.user_id, device.id, acc[0]),
          )
          return cursor.fetchone()
      finally:
          # Release the cursor even when the query raises.
          cursor.close()
  41. # Implementing methods with DuckDB equivalents
  def verify_device(self, device):
      """Mark a device as verified.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.verified`` on success.

      Returns:
          bool: ``False`` if the device was already verified, ``True`` if
          the state was written.

      Raises:
          AssertionError: If the device has no ``device_keys`` row.
      """
      if self.is_device_verified(device):
          return False

      d = self._get_device(device)
      assert d, "device is not stored for this account"

      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              # Bind the enum's integer value, the same representation
              # ignore_device writes, instead of the enum member itself.
              (d[0], int(TrustState.verified.value)),
          )
          self.conn.commit()
      finally:
          cursor.close()

      device.trust_state = TrustState.verified
      return True
  def unverify_device(self, device):
      """Reset a verified device back to the unset trust state.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.unset`` on success.

      Returns:
          bool: ``False`` if the device was not verified, ``True`` if the
          state was written.

      Raises:
          AssertionError: If the device has no ``device_keys`` row.
      """
      if not self.is_device_verified(device):
          return False

      d = self._get_device(device)
      assert d, "device is not stored for this account"

      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              # Bind the enum's integer value, the same representation
              # ignore_device writes, instead of the enum member itself.
              (d[0], int(TrustState.unset.value)),
          )
          self.conn.commit()
      finally:
          cursor.close()

      device.trust_state = TrustState.unset
      return True
  def is_device_verified(self, device):
      """Report whether the stored trust state of a device is "verified".

      Args:
          device: Object exposing ``user_id`` and ``id``.

      Returns:
          bool: ``True`` only when the device has a trust-state row whose
          value denotes ``TrustState.verified``; ``False`` when the device
          or its trust-state row is missing.
      """
      d = self._get_device(device)
      if not d:
          return False

      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
          )
          trust_state = cursor.fetchone()
      finally:
          cursor.close()

      if not trust_state:
          return False
      # The column may hold the enum's raw value (ignore_device binds
      # ``int(TrustState.ignored.value)``), so accept both representations.
      return trust_state[0] in (TrustState.verified, TrustState.verified.value)
  def blacklist_device(self, device):
      """Mark a device as blacklisted.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.blacklisted`` on success.

      Returns:
          bool: ``False`` if the device was already blacklisted, ``True`` if
          the state was written.

      Raises:
          AssertionError: If the device has no ``device_keys`` row.
      """
      if self.is_device_blacklisted(device):
          return False

      d = self._get_device(device)
      assert d, "device is not stored for this account"

      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              # Bind the enum's integer value, the same representation
              # ignore_device writes, instead of the enum member itself.
              (d[0], int(TrustState.blacklisted.value)),
          )
          self.conn.commit()
      finally:
          cursor.close()

      device.trust_state = TrustState.blacklisted
      return True
  def unblacklist_device(self, device):
      """Reset a blacklisted device back to the unset trust state.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.unset`` on success.

      Returns:
          bool: ``False`` if the device was not blacklisted, ``True`` if the
          state was written.

      Raises:
          AssertionError: If the device has no ``device_keys`` row.
      """
      if not self.is_device_blacklisted(device):
          return False

      d = self._get_device(device)
      assert d, "device is not stored for this account"

      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              # Bind the enum's integer value, the same representation
              # ignore_device writes, instead of the enum member itself.
              (d[0], int(TrustState.unset.value)),
          )
          self.conn.commit()
      finally:
          cursor.close()

      device.trust_state = TrustState.unset
      return True
  def is_device_blacklisted(self, device):
      """Report whether the stored trust state of a device is "blacklisted".

      Args:
          device: Object exposing ``user_id`` and ``id``.

      Returns:
          bool: ``True`` only when the device has a trust-state row whose
          value denotes ``TrustState.blacklisted``; ``False`` when the
          device or its trust-state row is missing.
      """
      d = self._get_device(device)
      if not d:
          return False

      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
          )
          trust_state = cursor.fetchone()
      finally:
          cursor.close()

      if not trust_state:
          return False
      # The column may hold the enum's raw value (ignore_device binds
      # ``int(TrustState.ignored.value)``), so accept both representations.
      return trust_state[0] in (
          TrustState.blacklisted,
          TrustState.blacklisted.value,
      )
  def ignore_device(self, device):
      """Mark a device as ignored.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.ignored`` on success.

      Returns:
          bool: ``False`` if the device was already ignored, ``True`` if the
          state was written.

      Raises:
          AssertionError: If the device has no ``device_keys`` row.
      """
      if self.is_device_ignored(device):
          return False

      d = self._get_device(device)
      assert d, "device is not stored for this account"

      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              (d[0], int(TrustState.ignored.value)),
          )
          self.conn.commit()
      finally:
          cursor.close()

      # Keep the in-memory device in step with the stored state, as the
      # verify/blacklist setters do.
      device.trust_state = TrustState.ignored
      return True
  def ignore_devices(self, devices):
      """Mark every device in ``devices`` as ignored.

      Each element is handed to ``ignore_device`` in iteration order; the
      per-device return values are discarded.

      Args:
          devices: Iterable of device objects accepted by ``ignore_device``.
      """
      for entry in devices:
          self.ignore_device(entry)
  def unignore_device(self, device):
      """Reset an ignored device back to the unset trust state.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.unset`` on success.

      Returns:
          bool: ``False`` if the device was not ignored, ``True`` if the
          state was written.

      Raises:
          AssertionError: If the device has no ``device_keys`` row.
      """
      if not self.is_device_ignored(device):
          return False

      d = self._get_device(device)
      assert d, "device is not stored for this account"

      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              # Bind the enum's integer value, the same representation
              # ignore_device writes, instead of the enum member itself.
              (d[0], int(TrustState.unset.value)),
          )
          self.conn.commit()
      finally:
          cursor.close()

      device.trust_state = TrustState.unset
      return True
  def is_device_ignored(self, device):
      """Report whether the stored trust state of a device is "ignored".

      Args:
          device: Object exposing ``user_id`` and ``id``.

      Returns:
          bool: ``True`` only when the device has a trust-state row whose
          value denotes ``TrustState.ignored``; ``False`` when the device or
          its trust-state row is missing.
      """
      d = self._get_device(device)
      if not d:
          return False

      cursor = self.conn.cursor()
      try:
          cursor.execute(
              "SELECT state FROM device_trust_state WHERE device_id = ?", (d[0],)
          )
          trust_state = cursor.fetchone()
      finally:
          cursor.close()

      if not trust_state:
          return False
      # ignore_device stores ``int(TrustState.ignored.value)``, so the column
      # comes back as the raw value; accept the enum member as well.
      return trust_state[0] in (TrustState.ignored, TrustState.ignored.value)
  def load_device_keys(self):
      """Build a ``DeviceStore`` from the device rows of this account.

      For every ``device_keys`` row of the account, the matching rows of the
      ``keys`` table are collected into a dict and an ``OlmDevice`` is added
      to the store.

      Returns:
          DeviceStore: The populated store; empty when the account id is
          falsy.
      """
      devices = DeviceStore()
      account_id = self.account_id
      if not account_id:
          return devices

      with self.conn.cursor() as cursor:
          cursor.execute(
              "SELECT * FROM device_keys WHERE account_id = ?",
              (account_id,)
          )
          # NOTE(review): both queries use ``SELECT *`` and read columns by
          # position, so the mapping depends on the table column order --
          # confirm against the schema.
          for row in cursor.fetchall():
              cursor.execute(
                  "SELECT * FROM keys WHERE device_id = ?",
                  (row[0],)
              )
              keys_by_type = {}
              for key_row in cursor.fetchall():
                  keys_by_type[key_row[0]] = key_row[1]

              olm_device = OlmDevice(
                  row[2],
                  row[0],
                  keys_by_type,
                  display_name=row[3],
                  deleted=row[4],
              )
              devices.add(olm_device)
      return devices
  def save_device_keys(self, device_keys):
      """Save the provided device keys to the database.

      New devices are inserted in batches (existing rows are left alone),
      then the ``deleted`` flag of every device is refreshed and each of its
      keys is upserted.

      Args:
          device_keys: Mapping of user id to a mapping of device id to a
              device object exposing ``display_name``, ``deleted`` and
              ``keys`` (a mapping of key type to key).

      Raises:
          AssertionError: If the account id is falsy.
      """
      account = self.account_id
      assert account

      rows = [
          (account, user_id, device_id, device.display_name, device.deleted)
          for user_id, devices_dict in device_keys.items()
          for device_id, device in devices_dict.items()
      ]
      if not rows:
          return

      with self.conn.cursor() as cur:
          for idx in range(0, len(rows), 100):
              cur.executemany(
                  "INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)",
                  rows[idx: idx + 100],
              )

          for user_id, devices_dict in device_keys.items():
              for device_id, device in devices_dict.items():
                  # Scope the update to this user and account so devices of
                  # other users or accounts that share a device id are left
                  # untouched.
                  cur.execute(
                      "UPDATE device_keys SET deleted = ? "
                      "WHERE device_id = ? AND user_id = ? AND account_id = ?",
                      (device.deleted, device_id, user_id, account),
                  )
                  for key_type, key in device.keys.items():
                      cur.execute(
                          """
                          INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?)
                          ON CONFLICT (key_type, device_id) DO UPDATE SET key = ?
                          """,
                          (key_type, key, device_id, key),
                      )
      self.conn.commit()
  def save_group_sessions(self, sessions):
      """Insert or replace the given sessions in ``inbound_group_sessions``.

      Args:
          sessions: Iterable of objects exposing ``id``, ``sender_key``,
              ``signing_key``, ``room_id`` and ``pickle``.
      """
      # NOTE(review): the column list here (session_id, signing_key, pickle)
      # differs from the one save_inbound_group_session writes (fp_key,
      # session) -- confirm which matches the schema.
      # Resolve the account id once; the property queries the database on
      # every access.
      account_id = self.account_id
      with self.conn.cursor() as cur:
          for session in sessions:
              pickled = session.pickle
              if callable(pickled):
                  # Elsewhere in this class ``pickle`` is a method taking the
                  # pickle key; a bound method cannot be stored as a value.
                  pickled = pickled(self.pickle_key)
              cur.execute(
                  """
                  INSERT OR REPLACE INTO inbound_group_sessions (
                      session_id, sender_key, signing_key, room_id, pickle, account_id
                  ) VALUES (?, ?, ?, ?, ?, ?)
                  """,
                  (
                      session.id,
                      session.sender_key,
                      session.signing_key,
                      session.room_id,
                      pickled,
                      account_id,
                  ),
              )
      self.conn.commit()
  def save_olm_sessions(self, sessions):
      """Insert or replace the given sessions in ``olm_sessions``.

      Args:
          sessions: Iterable of objects exposing ``id``, ``sender_key`` and
              ``pickle``.
      """
      # NOTE(review): the column list here (pickle) differs from the one
      # save_session writes and load_sessions reads (session,
      # creation_time) -- confirm which matches the schema.
      # Resolve the account id once; the property queries the database on
      # every access.
      account_id = self.account_id
      with self.conn.cursor() as cur:
          for session in sessions:
              pickled = session.pickle
              if callable(pickled):
                  # Elsewhere in this class ``pickle`` is a method taking the
                  # pickle key; a bound method cannot be stored as a value.
                  pickled = pickled(self.pickle_key)
              cur.execute(
                  """
                  INSERT OR REPLACE INTO olm_sessions (
                      session_id, sender_key, pickle, account_id
                  ) VALUES (?, ?, ?, ?)
                  """,
                  (
                      session.id,
                      session.sender_key,
                      pickled,
                      account_id,
                  ),
              )
      self.conn.commit()
  def save_outbound_group_sessions(self, sessions):
      """Insert or replace the given sessions in ``outbound_group_sessions``.

      Args:
          sessions: Iterable of objects exposing ``room_id``, ``id`` and
              ``pickle``.
      """
      # Resolve the account id once; the property queries the database on
      # every access.
      account_id = self.account_id
      with self.conn.cursor() as cur:
          for session in sessions:
              pickled = session.pickle
              if callable(pickled):
                  # Elsewhere in this class ``pickle`` is a method taking the
                  # pickle key; a bound method cannot be stored as a value.
                  pickled = pickled(self.pickle_key)
              cur.execute(
                  """
                  INSERT OR REPLACE INTO outbound_group_sessions (
                      room_id, session_id, pickle, account_id
                  ) VALUES (?, ?, ?, ?)
                  """,
                  (
                      session.room_id,
                      session.id,
                      pickled,
                      account_id,
                  ),
              )
      self.conn.commit()
  def save_account(self, account: OlmAccount):
      """Insert or replace the ``accounts`` row for this user and device.

      The row holds the account id, the user and device ids, the account's
      ``shared`` flag and the account pickled with ``self.pickle_key``.

      Args:
          account: The account to persist; must expose ``shared`` and
              ``pickle(key)``.
      """
      statement = """
      INSERT OR REPLACE INTO accounts (
      id, user_id, device_id, shared_account, pickle
      ) VALUES (?, ?, ?, ?, ?)
      """
      with self.conn.cursor() as cursor:
          values = (
              self.account_id,
              self.user_id,
              self.device_id,
              account.shared,
              account.pickle(self.pickle_key),
          )
          cursor.execute(statement, values)
      self.conn.commit()
  def load_sessions(self):
      """Load the Olm sessions of the current account.

      Every ``olm_sessions`` row joined to this account is unpickled with
      ``self.pickle_key`` and added to the store under its sender key.

      Returns:
          SessionStore: Store holding all loaded sessions.
      """
      sessions = SessionStore()
      query = """
      SELECT
      os.sender_key, os.session, os.creation_time
      FROM
      olm_sessions os
      INNER JOIN
      accounts a ON os.account_id = a.id
      WHERE
      a.id = ?
      """
      with self.conn.cursor() as cursor:
          cursor.execute(query, (self.account_id,))
          for sender_key, pickled, created_at in cursor.fetchall():
              restored = Session.from_pickle(pickled, created_at, self.pickle_key)
              sessions.add(sender_key, restored)
      return sessions
  def load_inbound_group_sessions(self):
      # type: () -> GroupSessionStore
      """Load the inbound group sessions of the current account.

      For each ``inbound_group_sessions`` row, the sender keys of its
      ``forwarded_chains`` rows are collected and the session is restored
      with ``InboundGroupSession.from_pickle``.

      Returns:
          ``GroupSessionStore`` object, containing all the loaded sessions;
          empty when the account id is falsy.
      """
      sessions = GroupSessionStore()
      account_id = self.account_id
      if not account_id:
          return sessions

      with self.conn.cursor() as cur:
          cur.execute(
              "SELECT * FROM inbound_group_sessions WHERE account_id = ?", (
                  account_id,)
          )
          # NOTE(review): ``SELECT *`` rows are read by position, so the
          # mapping depends on the table column order -- confirm against
          # the schema.
          session_rows = cur.fetchall()
          for record in session_rows:
              cur.execute(
                  "SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
                  (record[1],),
              )
              forwarding_chain = [chain_row[0] for chain_row in cur.fetchall()]
              restored = InboundGroupSession.from_pickle(
                  record[2].encode(),
                  record[3],
                  record[4],
                  record[5],
                  self.pickle_key,
                  forwarding_chain,
              )
              sessions.add(restored)
      return sessions
  def load_outgoing_key_requests(self):
      # type: () -> dict
      """Load all outgoing key requests of the current account.

      Returns:
          dict: Maps each request id to the ``OutgoingKeyRequest`` built
          from its row; empty when the account id is falsy or no rows
          exist.
      """
      account = self.account_id
      if not account:
          # Previously returned an undefined name; an empty mapping matches
          # the declared return type.
          return {}

      # Columns are named explicitly (the same names
      # add_outgoing_key_request inserts) so the mapping below does not
      # depend on table column order.
      fields = (
          "id",
          "account_id",
          "request_id",
          "session_id",
          "room_id",
          "algorithm",
      )
      with self.conn.cursor() as cur:
          cur.execute(
              "SELECT id, account_id, request_id, session_id, room_id, algorithm "
              "FROM outgoing_key_requests WHERE account_id = ?",
              (account,),
          )
          rows = cur.fetchall()

      requests = {}
      for row in rows:
          # Key by request id (third column): keying by the account id
          # would collapse every request into a single entry.
          requests[row[2]] = OutgoingKeyRequest.from_response(
              AttrDict(dict(zip(fields, row)))
          )
      return requests
  def load_encrypted_rooms(self):
      """Return the ids of the rooms recorded as encrypted for this account.

      Returns:
          ``Set`` of room ids read from ``encrypted_rooms``; empty when the
          account id is falsy.
      """
      account_id = self.account_id
      if not account_id:
          return set()

      with self.conn.cursor() as cursor:
          cursor.execute(
              "SELECT room_id FROM encrypted_rooms WHERE account_id = ?",
              (account_id,)
          )
          fetched = cursor.fetchall()

      room_ids = set()
      for record in fetched:
          room_ids.add(record[0])
      return room_ids
  def save_sync_token(self, token):
      """Store ``token`` as the sync token of the current account.

      Args:
          token: The value written to the ``token`` column.

      Raises:
          AssertionError: If the account id is falsy.
      """
      account_id = self.account_id
      assert account_id

      statement = "INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)"
      with self.conn.cursor() as cursor:
          cursor.execute(statement, (account_id, token))
      self.conn.commit()
  def save_encrypted_rooms(self, rooms):
      """Record the given room ids as encrypted for the current account.

      Rows are inserted in batches of 400; rooms already present are
      skipped by ``INSERT OR IGNORE``.

      Args:
          rooms: Iterable of room ids.

      Raises:
          AssertionError: If the account id is falsy.
      """
      account_id = self.account_id
      assert account_id

      pairs = [(room_id, account_id) for room_id in rooms]
      statement = "INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)"
      with self.conn.cursor() as cursor:
          start = 0
          while start < len(pairs):
              cursor.executemany(statement, pairs[start: start + 400])
              start += 400
      self.conn.commit()
  def save_session(self, sender_key, session):
      """Persist one Olm session for the current account.

      Args:
          sender_key (str): The curve key that owns the Olm session.
          session (Session): The session to store; it is pickled with
              ``self.pickle_key`` and saved together with its ``id``,
              ``creation_time`` and ``use_time``.

      Raises:
          AssertionError: If the account id is falsy.
      """
      account_id = self.account_id
      assert account_id

      pickled = session.pickle(self.pickle_key)
      statement = "INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)"
      with self.conn.cursor() as cursor:
          values = (
              account_id,
              sender_key,
              pickled,
              session.id,
              session.creation_time,
              session.use_time,
          )
          cursor.execute(statement, values)
      self.conn.commit()
  def save_inbound_group_session(self, session):
      """Save the provided Megolm inbound group session to the database.

      The session row is upserted, its existing ``forwarded_chains`` rows
      are deleted and one row per entry of ``session.forwarding_chain`` is
      inserted. The changes are committed at the end, as the other save
      methods do.

      Args:
          session (InboundGroupSession): The session to save; must expose
              ``sender_key``, ``ed25519``, ``room_id``, ``forwarding_chain``
              and ``pickle(key)``.

      Raises:
          AssertionError: If the account id is falsy.
      """
      account = self.account_id
      assert account

      # The same four values identify the session row in every statement.
      identity = (account, session.sender_key, session.ed25519, session.room_id)
      pickled = session.pickle(self.pickle_key)

      upsert_query = """
      INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session)
      VALUES (?, ?, ?, ?, ?)
      ON CONFLICT (account_id, sender_key, fp_key, room_id)
      DO UPDATE SET session = excluded.session
      """
      delete_query = """
      DELETE FROM forwarded_chains WHERE session_id = (SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?)
      """
      insert_query = """
      INSERT INTO forwarded_chains (session_id, sender_key)
      VALUES ((SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?), ?)
      """

      with self.conn.cursor() as cur:
          # Insert a new session or update the existing one.
          cur.execute(upsert_query, identity + (pickled,))
          # Replace the forwarded chains: drop the old rows, then add the
          # current ones.
          cur.execute(delete_query, identity)
          for chain in session.forwarding_chain:
              cur.execute(insert_query, identity + (chain,))
      self.conn.commit()
  def add_outgoing_key_request(self, key_request):
      """Add a new outgoing key request to the database.

      The row id is derived from the current maximum ``id`` in the table.
      A request whose ``(account_id, request_id)`` pair already exists is
      left unchanged. The insert is committed, as the other save methods
      do.

      Args:
          key_request (OutgoingKeyRequest): The key request to add; must
              expose ``request_id``, ``session_id``, ``room_id`` and
              ``algorithm``.
      """
      account_id = self.account_id
      with self.conn.cursor() as cursor:
          cursor.execute(
              """
              SELECT MAX(id) FROM outgoing_key_requests
              """
          )
          row = cursor.fetchone()
          # MAX(id) is NULL on an empty table; start numbering at 1.
          next_row_id = row[0] + 1 if row[0] else 1
          cursor.execute(
              """
              INSERT INTO outgoing_key_requests (id, account_id, request_id, session_id, room_id, algorithm)
              VALUES (?, ?, ?, ?, ?, ?)
              ON CONFLICT (account_id, request_id) DO NOTHING
              """,
              (
                  next_row_id,
                  account_id,
                  key_request.request_id,
                  key_request.session_id,
                  key_request.room_id,
                  key_request.algorithm,
              ),
          )
      self.conn.commit()
  def load_account(self):
      # type: () -> Optional[OlmAccount]
      """Load the Olm account from the database.

      The row is looked up by both user id and device id, the same pair
      ``_get_account`` uses.

      Returns:
          ``OlmAccount`` object, or ``None`` if it wasn't found for the
          current user_id and device_id.
      """
      query = """
      SELECT pickle, shared_account
      FROM accounts
      WHERE user_id = ? AND device_id = ?;
      """
      cursor = self.conn.cursor()
      try:
          cursor.execute(query, (self.user_id, self.device_id))
          result = cursor.fetchone()
      finally:
          # Release the cursor on every path, including a failed query.
          cursor.close()

      if not result:
          return None

      account_pickle, shared = result
      # Only text needs encoding; a value that already comes back as bytes
      # is passed through unchanged.
      if isinstance(account_pickle, str):
          account_pickle = account_pickle.encode()
      return OlmAccount.from_pickle(account_pickle, self.pickle_key, shared)