102 lines
3.1 KiB
Python
102 lines
3.1 KiB
Python
import sqlite3
|
|
from semver import Version
|
|
from typing import Dict, Optional, List
|
|
|
|
|
|
VERSION_TABLE = '''CREATE TABLE version (
|
|
id TEXT,
|
|
version TEXT,
|
|
PRIMARY KEY (id)
|
|
);'''
|
|
KEY_TABLE = '''CREATE TABLE key (
|
|
chapter_id INT,
|
|
user_id INT,
|
|
key TEXT,
|
|
PRIMARY KEY (chapter_id, user_id)
|
|
);'''
|
|
DIVISION_TABLE = '''CREATE TABLE division(
|
|
division_id INT,
|
|
is_linear BOOL,
|
|
PRIMARY KEY(division_id)
|
|
);'''
|
|
|
|
|
|
class CwmDb:
|
|
def __init__(self, db_path):
|
|
self._db = sqlite3.connect(db_path, check_same_thread=False)
|
|
self.version = Version(0, 0, 0, 1)
|
|
if not self.__check_database():
|
|
self.__create_table()
|
|
|
|
def __check_database(self):
|
|
self.__update_exists_tables()
|
|
v = self.__read_version()
|
|
if v is None:
|
|
return False
|
|
if v < self.version:
|
|
if v < Version(0, 0, 0, 1):
|
|
self._db.execute(DIVISION_TABLE)
|
|
self.__update_exists_tables()
|
|
self.__write_version()
|
|
return True
|
|
|
|
def __create_table(self):
|
|
if 'version' not in self._exist_tables:
|
|
self._db.execute(VERSION_TABLE)
|
|
self.__write_version()
|
|
if 'key' not in self._exist_tables:
|
|
self._db.execute(KEY_TABLE)
|
|
self._db.commit()
|
|
|
|
def __write_version(self):
|
|
self._db.execute('INSERT OR REPLACE INTO version VALUES (?, ?);', [
|
|
'main', str(self.version)])
|
|
self._db.commit()
|
|
|
|
def __read_version(self) -> Optional[Version]:
|
|
if 'version' not in self._exist_tables:
|
|
return None
|
|
cur = self._db.execute(
|
|
'SELECT version FROM version WHERE id = ?;', ['main'])
|
|
for i in cur:
|
|
try:
|
|
return Version.parse(i[0])
|
|
except Exception:
|
|
return None
|
|
|
|
def __update_exists_tables(self):
|
|
cur = self._db.execute('SELECT * FROM main.sqlite_master;')
|
|
self._exist_tables = {}
|
|
for i in cur:
|
|
if i[0] == 'table':
|
|
self._exist_tables[i[1]] = i
|
|
|
|
def add_key(self, chapter_id: int, user_id: int, key: str):
|
|
self._db.execute('INSERT OR REPLACE INTO key VALUES (?, ?, ?);', [
|
|
chapter_id, user_id, key])
|
|
|
|
def get_all_keys_as_origin(self) -> Dict[str, str]:
|
|
cur = self._db.execute('SELECT chapter_id, user_id, key FROM key;')
|
|
return {f'{i[0]}{i[1]}': i[2] for i in cur}
|
|
|
|
def commit(self):
|
|
self._db.commit()
|
|
|
|
def get_key(self, chapter_id: int) -> List[str]:
|
|
cur = self._db.execute('SELECT key FROM key WHERE chapter_id = ?;', [
|
|
chapter_id])
|
|
return [i[0] for i in cur]
|
|
|
|
def get_mark(self, division_id: int):
|
|
cur = self._db.execute(
|
|
'SELECT is_linear FROM division WHERE division_id = ?;',
|
|
[division_id])
|
|
for i in cur:
|
|
return bool(i[0])
|
|
return True
|
|
|
|
def set_mark(self, division_id: int, is_linear: bool):
|
|
self._db.execute('INSERT OR REPLACE INTO division VALUES (?, ?);', [
|
|
division_id, is_linear])
|
|
self._db.commit()
|