store.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. import duckdb
  2. from nio.store.database import MatrixStore, DeviceTrustState, OlmDevice, TrustState, InboundGroupSession, SessionStore, OlmSessions, GroupSessionStore, OutgoingKeyRequest, DeviceStore, Session
  3. from nio.crypto import OlmAccount, OlmDevice
  4. from random import SystemRandom
  5. from collections import defaultdict
  6. from typing import Dict, List, Optional, Tuple
  7. import json
  8. class DuckDBStore(MatrixStore):
  @property
  def account_id(self):
      """Return the database id of this store's account.

      When an ``accounts`` row exists for ``(user_id, device_id)`` its first
      column is returned. Otherwise a provisional random id is generated once
      and cached on the instance, so every caller sees the same value until
      ``save_account`` persists a row under it.
      """
      account = self._get_account()
      if account and account[0] is not None:
          return account[0]
      # Cache the provisional id: regenerating it on every access would make
      # consecutive writes issued before save_account() use different ids.
      pending = getattr(self, "_pending_account_id", None)
      if pending is None:
          # Start at 1 because callers treat a falsy id as "no account"
          # (``assert account`` / ``if not account``).
          pending = SystemRandom().randint(1, 2**16)
          self._pending_account_id = pending
      return pending
  def __init__(self, user_id, device_id, duckdb_conn, pickle_key=""):
      """Bind the store to an already-open DuckDB connection.

      Args:
          user_id: Matrix user id that owns the account.
          device_id: Device id of the account.
          duckdb_conn: Open DuckDB connection used for every query.
          pickle_key: Passphrase used when pickling/unpickling Olm objects.
              Defaults to ``""``.
      """
      # MatrixStore.__init__ is not called, so every attribute the methods
      # read must be set here; ``pickle_key`` is read by save_account,
      # load_account, save_session, load_sessions and others.
      self.conn = duckdb_conn
      self.user_id = user_id
      self.device_id = device_id
      self.pickle_key = pickle_key
  def _get_account(self):
      """Fetch the ``accounts`` row for this user/device pair.

      Returns:
          The row tuple (column 0 is used as the account id), or ``None`` when
          no account has been saved yet.
      """
      # ``with`` closes the cursor even if execute() raises.
      with self.conn.cursor() as cursor:
          cursor.execute(
              "SELECT * FROM accounts WHERE user_id = ? AND device_id = ?",
              (self.user_id, self.device_id),
          )
          return cursor.fetchone()
  def _get_device(self, device):
      """Fetch the ``device_keys`` row for *device* under this account.

      Args:
          device: Object exposing ``user_id`` and ``id``.

      Returns:
          The row tuple, or ``None`` if there is no account yet or the device
          is not stored.
      """
      account = self._get_account()
      if not account:
          return None
      # ``with`` closes the cursor even if execute() raises.
      with self.conn.cursor() as cursor:
          cursor.execute(
              "SELECT * FROM device_keys WHERE user_id = ? AND device_id = ? AND account_id = ?",
              (device.user_id, device.id, account[0]),
          )
          return cursor.fetchone()
  40. # Implementing methods with DuckDB equivalents
  def verify_device(self, device):
      """Mark *device* as verified in ``device_trust_state``.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.verified`` on success.

      Returns:
          ``False`` if the device was already verified, ``True`` otherwise.

      Raises:
          AssertionError: (with assertions enabled) if the device has no
              ``device_keys`` row.
      """
      if self.is_device_verified(device):
          return False
      row = self._get_device(device)
      assert row
      with self.conn.cursor() as cursor:
          # Store the enum's raw value, the same representation ignore_device
          # writes, rather than the TrustState member object.
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              (row[0], int(TrustState.verified.value)),
          )
      self.conn.commit()
      device.trust_state = TrustState.verified
      return True
  def unverify_device(self, device):
      """Reset a verified *device* back to the unset trust state.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.unset`` on success.

      Returns:
          ``False`` if the device was not verified, ``True`` otherwise.

      Raises:
          AssertionError: (with assertions enabled) if the device has no
              ``device_keys`` row.
      """
      if not self.is_device_verified(device):
          return False
      row = self._get_device(device)
      assert row
      with self.conn.cursor() as cursor:
          # Store the enum's raw value, the same representation ignore_device
          # writes, rather than the TrustState member object.
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              (row[0], int(TrustState.unset.value)),
          )
      self.conn.commit()
      device.trust_state = TrustState.unset
      return True
  def is_device_verified(self, device):
      """Return ``True`` if *device* is stored with the verified trust state.

      Unknown devices and devices without a ``device_trust_state`` row are
      reported as not verified.
      """
      row = self._get_device(device)
      if not row:
          return False
      with self.conn.cursor() as cursor:
          cursor.execute(
              "SELECT state FROM device_trust_state WHERE device_id = ?", (row[0],)
          )
          trust_state = cursor.fetchone()
      if not trust_state:
          return False
      # The column holds the raw enum value (ignore_device writes
      # int(TrustState.ignored.value)). Comparing it to the TrustState member
      # itself is never equal unless TrustState is an IntEnum.
      return trust_state[0] == int(TrustState.verified.value)
  def blacklist_device(self, device):
      """Mark *device* as blacklisted in ``device_trust_state``.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.blacklisted`` on success.

      Returns:
          ``False`` if the device was already blacklisted, ``True`` otherwise.

      Raises:
          AssertionError: (with assertions enabled) if the device has no
              ``device_keys`` row.
      """
      if self.is_device_blacklisted(device):
          return False
      row = self._get_device(device)
      assert row
      with self.conn.cursor() as cursor:
          # Store the enum's raw value, the same representation ignore_device
          # writes, rather than the TrustState member object.
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              (row[0], int(TrustState.blacklisted.value)),
          )
      self.conn.commit()
      device.trust_state = TrustState.blacklisted
      return True
  def unblacklist_device(self, device):
      """Reset a blacklisted *device* back to the unset trust state.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.unset`` on success.

      Returns:
          ``False`` if the device was not blacklisted, ``True`` otherwise.

      Raises:
          AssertionError: (with assertions enabled) if the device has no
              ``device_keys`` row.
      """
      if not self.is_device_blacklisted(device):
          return False
      row = self._get_device(device)
      assert row
      with self.conn.cursor() as cursor:
          # Store the enum's raw value, the same representation ignore_device
          # writes, rather than the TrustState member object.
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              (row[0], int(TrustState.unset.value)),
          )
      self.conn.commit()
      device.trust_state = TrustState.unset
      return True
  def is_device_blacklisted(self, device):
      """Return ``True`` if *device* is stored with the blacklisted trust state.

      Unknown devices and devices without a ``device_trust_state`` row are
      reported as not blacklisted.
      """
      row = self._get_device(device)
      if not row:
          return False
      with self.conn.cursor() as cursor:
          cursor.execute(
              "SELECT state FROM device_trust_state WHERE device_id = ?", (row[0],)
          )
          trust_state = cursor.fetchone()
      if not trust_state:
          return False
      # The column holds the raw enum value (ignore_device writes
      # int(TrustState.ignored.value)). Comparing it to the TrustState member
      # itself is never equal unless TrustState is an IntEnum.
      return trust_state[0] == int(TrustState.blacklisted.value)
  def ignore_device(self, device):
      """Mark *device* as ignored in ``device_trust_state``.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.ignored`` on success.

      Returns:
          ``False`` if the device was already ignored, ``True`` otherwise.

      Raises:
          AssertionError: (with assertions enabled) if the device has no
              ``device_keys`` row.
      """
      if self.is_device_ignored(device):
          return False
      row = self._get_device(device)
      assert row
      with self.conn.cursor() as cursor:
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              (row[0], int(TrustState.ignored.value)),
          )
      self.conn.commit()
      # Keep the in-memory object in sync, as every other trust-state
      # mutator in this class does.
      device.trust_state = TrustState.ignored
      return True
  def ignore_devices(self, devices):
      """Apply ``ignore_device`` to each device in *devices*, in order.

      Return values of the individual calls are discarded; each call commits
      on its own.
      """
      mark_ignored = self.ignore_device
      for candidate in devices:
          mark_ignored(candidate)
  def unignore_device(self, device):
      """Reset an ignored *device* back to the unset trust state.

      Args:
          device: Object exposing ``user_id`` and ``id``; its ``trust_state``
              attribute is set to ``TrustState.unset`` on success.

      Returns:
          ``False`` if the device was not ignored, ``True`` otherwise.

      Raises:
          AssertionError: (with assertions enabled) if the device has no
              ``device_keys`` row.
      """
      if not self.is_device_ignored(device):
          return False
      row = self._get_device(device)
      assert row
      with self.conn.cursor() as cursor:
          # Store the enum's raw value, the same representation ignore_device
          # writes, rather than the TrustState member object.
          cursor.execute(
              "INSERT OR REPLACE INTO device_trust_state (device_id, state) VALUES (?, ?)",
              (row[0], int(TrustState.unset.value)),
          )
      self.conn.commit()
      device.trust_state = TrustState.unset
      return True
  def is_device_ignored(self, device):
      """Return ``True`` if *device* is stored with the ignored trust state.

      Unknown devices and devices without a ``device_trust_state`` row are
      reported as not ignored.
      """
      row = self._get_device(device)
      if not row:
          return False
      with self.conn.cursor() as cursor:
          cursor.execute(
              "SELECT state FROM device_trust_state WHERE device_id = ?", (row[0],)
          )
          trust_state = cursor.fetchone()
      if not trust_state:
          return False
      # ignore_device writes int(TrustState.ignored.value), so compare against
      # that raw value; a plain Enum member never equals an int.
      return trust_state[0] == int(TrustState.ignored.value)
  def load_device_keys(self):
      """Read every stored device of this account into a DeviceStore.

      Each ``device_keys`` row is paired with the ``keys`` rows that share its
      first column, turned into a ``key_type -> key`` dict, and wrapped in an
      ``OlmDevice``.

      Returns:
          ``DeviceStore``; empty when ``account_id`` is falsy.
      """
      device_store = DeviceStore()
      owner = self.account_id
      if not owner:
          return device_store
      with self.conn.cursor() as cur:
          cur.execute(
              "SELECT * FROM device_keys WHERE account_id = ?",
              (owner,)
          )
          # fetchall() materialises the outer rows before the cursor is
          # reused for the per-device key lookups below.
          for entry in cur.fetchall():
              cur.execute(
                  "SELECT * FROM keys WHERE device_id = ?",
                  (entry[0],)
              )
              identity_keys = dict((key_row[0], key_row[1]) for key_row in cur.fetchall())
              device_store.add(
                  OlmDevice(
                      entry[2],
                      entry[0],
                      identity_keys,
                      display_name=entry[3],
                      deleted=entry[4],
                  )
              )
      return device_store
  def save_device_keys(self, device_keys):
      """Persist devices and their identity keys for this account.

      Args:
          device_keys: Mapping ``user_id -> {device_id: device}``. Each device
              must expose ``display_name``, ``deleted`` and ``keys`` (a
              ``key_type -> key`` mapping).

      New devices are inserted; for already-stored devices only the
      ``deleted`` flag is refreshed. Keys are upserted per
      ``(key_type, device_id)``.
      """
      account = self.account_id
      assert account
      rows = [
          (account, user_id, device_id, device.display_name, device.deleted)
          for user_id, devices_dict in device_keys.items()
          for device_id, device in devices_dict.items()
      ]
      if not rows:
          return
      with self.conn.cursor() as cur:
          # executemany binds row by row, so no manual chunking is needed.
          cur.executemany(
              "INSERT OR IGNORE INTO device_keys (account_id, user_id, device_id, display_name, deleted) VALUES (?, ?, ?, ?, ?)",
              rows,
          )
          for user_id, devices_dict in device_keys.items():
              for device_id, device in devices_dict.items():
                  # Scope the update the same way _get_device identifies a row
                  # (account, user, device); filtering on device_id alone
                  # rewrote the flag of other users'/accounts' devices.
                  cur.execute(
                      "UPDATE device_keys SET deleted = ? WHERE account_id = ? AND user_id = ? AND device_id = ?",
                      (device.deleted, account, user_id, device_id),
                  )
                  # NOTE(review): ``keys`` is keyed by (key_type, device_id)
                  # only, so equal device ids of different users share rows —
                  # confirm against the table definition.
                  for key_type, key in device.keys.items():
                      cur.execute(
                          """
                          INSERT INTO keys (key_type, key, device_id) VALUES (?, ?, ?)
                          ON CONFLICT (key_type, device_id) DO UPDATE SET key = excluded.key
                          """,
                          (key_type, key, device_id),
                      )
      self.conn.commit()
  def save_group_sessions(self, sessions):
      """Insert or replace a batch of inbound group sessions.

      Args:
          sessions: Iterable of session objects exposing ``id``,
              ``sender_key``, ``signing_key``, ``room_id`` and ``pickle``.
      """
      # NOTE(review): save_inbound_group_session writes this table through
      # ``fp_key``/``session`` columns, while this method uses
      # ``session_id``/``signing_key``/``pickle`` — confirm which match the schema.
      account = self.account_id  # one lookup instead of one per session
      with self.conn.cursor() as cur:
          for session in sessions:
              pickled = session.pickle
              # Elsewhere in this class ``session.pickle`` is a method called
              # with the pickle key; the bound method itself is not storable.
              if callable(pickled):
                  pickled = pickled(self.pickle_key)
              cur.execute(
                  """
                  INSERT OR REPLACE INTO inbound_group_sessions (
                      session_id, sender_key, signing_key, room_id, pickle, account_id
                  ) VALUES (?, ?, ?, ?, ?, ?)
                  """,
                  (
                      session.id,
                      session.sender_key,
                      session.signing_key,
                      session.room_id,
                      pickled,
                      account,
                  ),
              )
      self.conn.commit()
  def save_olm_sessions(self, sessions):
      """Insert or replace a batch of Olm sessions in ``olm_sessions``.

      Args:
          sessions: Iterable of session objects exposing ``id``,
              ``sender_key`` and ``pickle``.
      """
      # NOTE(review): save_session stores the pickle in a ``session`` column
      # and load_sessions reads ``os.session``; this method writes a
      # ``pickle`` column instead — confirm which one the schema defines.
      account = self.account_id  # one lookup instead of one per session
      with self.conn.cursor() as cur:
          for session in sessions:
              pickled = session.pickle
              # Elsewhere in this class ``session.pickle`` is a method called
              # with the pickle key; the bound method itself is not storable.
              if callable(pickled):
                  pickled = pickled(self.pickle_key)
              cur.execute(
                  """
                  INSERT OR REPLACE INTO olm_sessions (
                      session_id, sender_key, pickle, account_id
                  ) VALUES (?, ?, ?, ?)
                  """,
                  (session.id, session.sender_key, pickled, account),
              )
      self.conn.commit()
  def save_outbound_group_sessions(self, sessions):
      """Insert or replace a batch of outbound group sessions.

      Args:
          sessions: Iterable of session objects exposing ``room_id``, ``id``
              and ``pickle`` (presumably a pickling method — confirm for the
              session type callers pass).
      """
      account = self.account_id  # one lookup instead of one per session
      with self.conn.cursor() as cur:
          for session in sessions:
              pickled = session.pickle
              # Elsewhere in this class ``session.pickle`` is a method called
              # with the pickle key; the bound method itself is not storable.
              if callable(pickled):
                  pickled = pickled(self.pickle_key)
              cur.execute(
                  """
                  INSERT OR REPLACE INTO outbound_group_sessions (
                      room_id, session_id, pickle, account_id
                  ) VALUES (?, ?, ?, ?)
                  """,
                  (session.room_id, session.id, pickled, account),
              )
      self.conn.commit()
  def save_account(self, account: OlmAccount):
      """Write (or overwrite) this store's row in ``accounts``.

      Args:
          account: The Olm account; its ``shared`` flag is stored as-is and
              its state is pickled with ``self.pickle_key``.
      """
      statement = """
          INSERT OR REPLACE INTO accounts (
              id, user_id, device_id, shared_account, pickle
          ) VALUES (?, ?, ?, ?, ?)
      """
      with self.conn.cursor() as cur:
          row_id = self.account_id
          params = (
              row_id,
              self.user_id,
              self.device_id,
              account.shared,
              account.pickle(self.pickle_key),
          )
          cur.execute(statement, params)
      self.conn.commit()
  def load_sessions(self):
      """Load all Olm sessions stored for this account.

      Each ``olm_sessions`` row is unpickled with ``self.pickle_key`` and its
      stored creation time.

      Returns:
          ``SessionStore`` with every session added under its sender key.
      """
      query = """
          SELECT os.sender_key, os.session, os.creation_time
          FROM olm_sessions os
          INNER JOIN accounts a ON os.account_id = a.id
          WHERE a.id = ?
      """
      loaded = SessionStore()
      with self.conn.cursor() as cur:
          cur.execute(query, (self.account_id,))
          for sender, pickled, created_at in cur.fetchall():
              restored = Session.from_pickle(pickled, created_at, self.pickle_key)
              loaded.add(sender, restored)
      return loaded
  def load_inbound_group_sessions(self):
      # type: () -> GroupSessionStore
      """Load all Megolm inbound group sessions stored for this account.

      Returns:
          ``GroupSessionStore`` containing every loaded session, each with
          its forwarding chain; empty when ``account_id`` is falsy.
      """
      store = GroupSessionStore()
      account = self.account_id
      if not account:
          return store
      with self.conn.cursor() as cursor:
          cursor.execute(
              "SELECT * FROM inbound_group_sessions WHERE account_id = ?",
              (account,),
          )
          # Materialise all rows first; the cursor is reused for the
          # per-session forwarded_chains lookup below.
          rows = cursor.fetchall()
          for row in rows:
              # NOTE(review): positions come from ``SELECT *``; this assumes
              # row[1] is the key referenced by forwarded_chains.session_id and
              # row[2] is the pickle — confirm against the table definition.
              cursor.execute(
                  "SELECT sender_key FROM forwarded_chains WHERE session_id = ?",
                  (row[1],),
              )
              chains = [chain[0] for chain in cursor.fetchall()]
              pickled = row[2]
              # The pickle may come back as text or as bytes depending on the
              # column type; from_pickle receives bytes either way.
              if isinstance(pickled, str):
                  pickled = pickled.encode()
              session = InboundGroupSession.from_pickle(
                  pickled,
                  row[3],
                  row[4],
                  row[5],
                  self.pickle_key,
                  chains,
              )
              store.add(session)
      return store
  def load_outgoing_key_requests(self):
      # type: () -> dict
      """Load all outgoing key requests stored for this account.

      Returns:
          ``dict`` mapping each request id to an ``OutgoingKeyRequest``; an
          empty dict when ``account_id`` is falsy.
      """
      from types import SimpleNamespace

      account = self.account_id
      if not account:
          # Previously ``return store`` referenced an undefined name; return
          # an empty mapping of the same type as the normal path.
          return {}
      with self.conn.cursor() as cur:
          # Select by name: fetched rows are plain tuples, so the attribute
          # access from_database needs is provided via an explicit record.
          # Columns mirror those written by add_outgoing_key_request.
          cur.execute(
              "SELECT request_id, session_id, room_id, algorithm "
              "FROM outgoing_key_requests WHERE account_id = ?",
              (account,),
          )
          rows = cur.fetchall()
      requests = {}
      for request_id, session_id, room_id, algorithm in rows:
          # Presumably from_database reads exactly these attribute names —
          # confirm against OutgoingKeyRequest.from_database.
          record = SimpleNamespace(
              request_id=request_id,
              session_id=session_id,
              room_id=room_id,
              algorithm=algorithm,
          )
          requests[request_id] = OutgoingKeyRequest.from_database(record)
      return requests
  def load_encrypted_rooms(self):
      """Return the ids of rooms recorded as encrypted for this account.

      Returns:
          ``set`` of ``room_id`` values; an empty set when ``account_id`` is
          falsy.
      """
      owner = self.account_id
      if owner:
          with self.conn.cursor() as cur:
              cur.execute(
                  "SELECT room_id FROM encrypted_rooms WHERE account_id = ?",
                  (owner,)
              )
              fetched = cur.fetchall()
          return set(entry[0] for entry in fetched)
      return set()
  def save_sync_token(self, token):
      """Store *token* as this account's sync token, replacing any earlier one."""
      owner = self.account_id
      assert owner
      statement = "INSERT OR REPLACE INTO sync_tokens (account_id, token) VALUES (?, ?)"
      with self.conn.cursor() as cur:
          cur.execute(statement, (owner, token))
      self.conn.commit()
  def save_encrypted_rooms(self, rooms):
      """Record every room id in *rooms* as encrypted for this account.

      Rows already present are left untouched (``INSERT OR IGNORE``); inserts
      are sent in batches of 400 rows.
      """
      owner = self.account_id
      assert owner
      pairs = [(room_id, owner) for room_id in rooms]
      batch_size = 400
      with self.conn.cursor() as cur:
          start = 0
          while start < len(pairs):
              cur.executemany(
                  "INSERT OR IGNORE INTO encrypted_rooms (room_id, account_id) VALUES (?, ?)",
                  pairs[start:start + batch_size]
              )
              start += batch_size
      self.conn.commit()
  def save_session(self, sender_key, session):
      """Insert or replace a single Olm session row for this account.

      Args:
          sender_key (str): Curve key of the party that owns the session.
          session (Session): Session to store; it is pickled with
              ``self.pickle_key`` and saved with its ``id``,
              ``creation_time`` and ``use_time``.
      """
      owner = self.account_id
      assert owner
      blob = session.pickle(self.pickle_key)
      statement = (
          "INSERT OR REPLACE INTO olm_sessions (account_id, sender_key, session, session_id, creation_time, last_usage_date) VALUES (?, ?, ?, ?, ?, ?)"
      )
      with self.conn.cursor() as cur:
          values = (
              owner,
              sender_key,
              blob,
              session.id,
              session.creation_time,
              session.use_time,
          )
          cur.execute(statement, values)
      self.conn.commit()
  def save_inbound_group_session(self, session):
      """Upsert a Megolm inbound group session and replace its forwarding chain.

      Args:
          session (InboundGroupSession): Session to save. Its ``sender_key``,
              ``ed25519`` and ``room_id`` (plus this account) identify the row;
              it is pickled with ``self.pickle_key``; every entry of
              ``forwarding_chain`` is stored in ``forwarded_chains``.
      """
      account = self.account_id
      assert account
      identity = (account, session.sender_key, session.ed25519, session.room_id)
      with self.conn.cursor() as cur:
          # Insert a new session or update the stored pickle of an existing one.
          cur.execute(
              """
              INSERT INTO inbound_group_sessions (account_id, sender_key, fp_key, room_id, session)
              VALUES (?, ?, ?, ?, ?)
              ON CONFLICT (account_id, sender_key, fp_key, room_id)
              DO UPDATE SET session = excluded.session
              """,
              identity + (session.pickle(self.pickle_key),),
          )
          # Resolve the row id once instead of repeating the same subquery
          # for the delete and for every chain entry.
          cur.execute(
              "SELECT id FROM inbound_group_sessions WHERE account_id = ? AND sender_key = ? AND fp_key = ? AND room_id = ?",
              identity,
          )
          found = cur.fetchone()
          session_row_id = found[0] if found else None
          # Drop the previously stored chain, then write the current one.
          cur.execute(
              "DELETE FROM forwarded_chains WHERE session_id = ?",
              (session_row_id,),
          )
          chain_rows = [(session_row_id, chain) for chain in session.forwarding_chain]
          if chain_rows:
              cur.executemany(
                  "INSERT INTO forwarded_chains (session_id, sender_key) VALUES (?, ?)",
                  chain_rows,
              )
      # Commit like every other writer in this class.
      self.conn.commit()
  def add_outgoing_key_request(self, key_request):
      """Record *key_request* for this account unless its request id is stored.

      Args:
          key_request: Object exposing ``request_id``, ``session_id``,
              ``room_id`` and ``algorithm``. Existing
              ``(account_id, request_id)`` pairs are left unchanged.
      """
      owner = self.account_id
      statement = """
          INSERT INTO outgoing_key_requests (account_id, request_id, session_id, room_id, algorithm)
          VALUES (?, ?, ?, ?, ?)
          ON CONFLICT (account_id, request_id) DO NOTHING
      """
      with self.conn.cursor() as cursor:
          values = (
              owner,
              key_request.request_id,
              key_request.session_id,
              key_request.room_id,
              key_request.algorithm,
          )
          cursor.execute(statement, values)
  def load_account(self):
      # type: () -> Optional[OlmAccount]
      """Load the Olm account for this user/device pair.

      Returns:
          ``OlmAccount`` unpickled with ``self.pickle_key`` and the stored
          shared flag, or ``None`` if no row exists for the current
          ``user_id`` and ``device_id``.
      """
      with self.conn.cursor() as cursor:
          # Match the row the same way _get_account does (user_id AND
          # device_id); filtering on device_id alone could pick up another
          # user's account.
          cursor.execute(
              "SELECT pickle, shared_account FROM accounts WHERE user_id = ? AND device_id = ?",
              (self.user_id, self.device_id),
          )
          result = cursor.fetchone()
      if not result:
          return None
      account_pickle, shared = result
      # The stored pickle may come back as text or bytes depending on the
      # column type; from_pickle receives bytes either way.
      if isinstance(account_pickle, str):
          account_pickle = account_pickle.encode()
      return OlmAccount.from_pickle(account_pickle, self.pickle_key, shared)