diff --git a/actual/__init__.py b/actual/__init__.py index 82baac5..841653f 100644 --- a/actual/__init__.py +++ b/actual/__init__.py @@ -1,63 +1,43 @@ -import enum +from __future__ import annotations + +import contextlib import io +import json import pathlib +import re +import sqlite3 import tempfile +import uuid import zipfile -from typing import List, Union +from typing import TYPE_CHECKING, List, Union -import pydantic -import requests import sqlalchemy import sqlalchemy.orm from sqlalchemy.orm import joinedload -from actual.database import Accounts, Categories, Transactions, get_class_by_table_name -from actual.models import RemoteFile -from actual.protobuf_models import Message, SyncRequest, SyncResponse - - -class Endpoints(enum.Enum): - LOGIN = "account/login" - INFO = "info" - ACCOUNT_VALIDATE = "account/validate" - NEEDS_BOOTSTRAP = "account/needs-bootstrap" - SYNC = "sync/sync" - LIST_USER_FILES = "sync/list-user-files" - GET_USER_FILE_INFO = "sync/get-user-file-info" - DOWNLOAD_USER_FILE = "sync/download-user-file" - UPLOAD_USER_FILE = "sync/upload-user-file" - RESET_USER_FILE = "sync/reset-user-file" - # data related - DATA_FILE_INDEX = "data-file-index.txt" - DEFAULT_DB = "data/default-db.sqlite" - MIGRATIONS = "data/migrations" - - def __str__(self): - return self.value - - -class ActualError(Exception): - pass - - -class AuthorizationError(ActualError): - pass - +from actual.api import ActualServer, RemoteFileListDTO +from actual.database import ( + Accounts, + Categories, + Payees, + Transactions, + get_attribute_by_table_name, + get_class_by_table_name, +) +from actual.exceptions import InvalidZipFile, UnknownFileId +from actual.protobuf_models import Message, SyncRequest -class UnknownFileId(ActualError): - pass +if TYPE_CHECKING: + from actual.database import BaseModel -class InvalidZipFile(ActualError): - pass - - -class Actual: +class Actual(ActualServer): def __init__( self, base_url: str = "http://localhost:5006", token: str = None, password: str = None, + file: str = None, data_dir: Union[str, pathlib.Path] = None, ): """ @@ -70,92 +50,141 @@ def __init__( :param base_url: url of the running Actual server :param token: the token for authentication, if this is available (optional) :param password: the password for authentication. It will be used on the .login() method to retrieve the token. + :param file: the name or id of the file to be set :param data_dir: where to store the downloaded files from the server. If not specified, a temporary folder will be created instead. """ - self.api_url = base_url - self._token = token - self._password = password - self._file: RemoteFile | None = None - self._data_dir = data_dir + super().__init__(base_url, token, password) + self._file: RemoteFileListDTO | None = None + self._data_dir = pathlib.Path(data_dir) if data_dir else None self._session_maker = None - if token is None and password is None: - raise ValueError("Either provide a valid token or a password.") - - def login(self) -> str: - """Logs in on the Actual server using the password provided. Raises `AuthorizationError` if it fails to - authenticate the user.""" - if not self._password: - raise AuthorizationError("Trying to login but not password was provided.") - response = requests.post(f"{self.api_url}/{Endpoints.LOGIN}", json={"password": self._password}) - response.raise_for_status() - token = response.json()["data"]["token"] - if token is None: - raise AuthorizationError("Could not validate password on login.") - self._password = None # erase password - self._token = token - return self._token - - def headers(self, file_id: str = None, extra_headers: dict = None) -> dict: - """Generates headers by retrieving a token, if one is not provided, and auto-filling the file id.""" - if not self._token: - self.login() - headers = {"X-ACTUAL-TOKEN": self._token} - if self._file and self._file.file_id: - headers["X-ACTUAL-FILE-ID"] = file_id or self._file.file_id - if extra_headers: - headers = headers | extra_headers - return headers - - def user_files(self) -> List[RemoteFile]: - """Lists user files from remote. Requires authentication to return all results.""" - response = requests.get(f"{self.api_url}/{Endpoints.LIST_USER_FILES}", headers=self.headers()) - response.raise_for_status() - files = response.json() - return pydantic.parse_obj_as(List[RemoteFile], files["data"]) - - def set_file(self, file_id: Union[str, RemoteFile]) -> RemoteFile: + self._session: sqlalchemy.orm.Session | None = None + # set the correct file + if file: + self.set_file(file) + + def __enter__(self) -> Actual: + if self._file: + self.download_budget() + self._session = self._session_maker() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._session: + self._session.close() + + @contextlib.contextmanager + def with_session(self) -> sqlalchemy.orm.Session: + s = self._session if self._session else self._session_maker() + try: + yield s + finally: + if not self._session: + s.close() + + def set_file(self, file_id: Union[str, RemoteFileListDTO]) -> RemoteFileListDTO: """Sets the file id for the class for further requests. The file_id argument can be either a name or remote id from the file. If the name is provided, only the first match is taken.""" - if isinstance(file_id, RemoteFile): + if isinstance(file_id, RemoteFileListDTO): self._file = file_id return file_id else: - user_files = self.user_files() - for file in user_files: + selected_files = [] + user_files = self.list_user_files() + for file in user_files.data: if (file.file_id == file_id or file.name == file_id) and file.deleted == 0: - return self.set_file(file) - raise UnknownFileId(f"Could not find a file id or identifier '{file_id}'") - - def reset_file(self, file_id: Union[str, RemoteFile] = None) -> bool: - if not self._file: - if file_id is None: - raise UnknownFileId("Could not reset the file without a valid 'file_id'") - self.set_file(file_id) - request = requests.post( - f"{self.api_url}/{Endpoints.RESET_USER_FILE}", json={"fileId": self._file.file_id, "token": self._token} + selected_files.append(file) + if len(selected_files) == 0: + raise UnknownFileId(f"Could not find a file id or identifier '{file_id}'") + elif len(selected_files) > 1: + raise UnknownFileId(f"Multiple files found with identifier '{file_id}'") + return self.set_file(selected_files[0]) + + def run_migrations(self, migration_files: list[str]): + """Runs the migration files, skipping the ones that have already been run. The files can be retrieved from + .data_file_index() method. This first file is the base database, and the following files are migrations. + Migrations can also be .js files. In this case, we have to extract and execute queries from the standard JS.""" + conn = sqlite3.connect(self._data_dir / "db.sqlite") + for file in migration_files[1:]: + file_id = file.split("_")[0].split("/")[1] + if conn.execute(f"SELECT id FROM __migrations__ WHERE id = '{file_id}';").fetchall(): + continue # skip migration as it was already ran + migration = self.data_file(file) # retrieves file from actual server + sql_statements = migration.decode() + if file.endswith(".js"): + # there is one migration which is Javascript. All entries inside db.execQuery(`...`) must be executed + exec_entries = re.findall(r"db\.execQuery\(`([^`]*)`\)", sql_statements, re.DOTALL) + sql_statements = "\n".join(exec_entries) + conn.executescript(sql_statements) + conn.execute(f"INSERT INTO __migrations__ (id) VALUES ({file_id});") + conn.commit() + conn.close() + + def create_budget(self, budget_name: str): + """Creates a budget using the remote server default database and migrations.""" + migration_files = self.data_file_index() + # create folder for the files + if not self._data_dir: + self._data_dir = pathlib.Path(tempfile.mkdtemp()) + # first migration file is the default database + migration = self.data_file(migration_files[0]) + (self._data_dir / "db.sqlite").write_bytes(migration) + # also write the metadata file with default fields + random_id = str(uuid.uuid4()).replace("-", "")[:7] + file_id = str(uuid.uuid4()) + (self._data_dir / "metadata.json").write_text( + json.dumps( + { + "id": f"My-Finances-{random_id}", + "budgetName": budget_name, + "userId": self._token, + "cloudFileId": file_id, + "resetClock": True, + } + ) ) - request.raise_for_status() - return request.json()["status"] == "ok" + self._file = RemoteFileListDTO(name=budget_name, fileId=file_id, groupId=None, deleted=0, encryptKeyId=None) + # create engine for downloaded database and run migrations + self.run_migrations(migration_files[1:]) + + def upload_budget(self): + """Uploads the current file to the Actual server.""" + if not self._data_dir: + raise UnknownFileId("No current file loaded.") + binary_data = io.BytesIO() + with zipfile.ZipFile(binary_data, "a", zipfile.ZIP_DEFLATED, False) as z: + z.write(self._data_dir / "db.sqlite", "db.sqlite") + z.write(self._data_dir / "metadata.json", "metadata.json") + binary_data.seek(0) + return self.upload_user_file(binary_data.getvalue(), self._file.file_id, self._file.name) def apply_changes(self, messages: list[Message]): + """Applies a list of sync changes, based on what the sync method returned on the remote.""" if not self._session_maker: - raise UnknownFileId("No valid file available, download one with download_budget") - with self._session_maker() as s: - s: sqlalchemy.orm.Session + raise UnknownFileId("No valid file available, download one with download_budget()") + with self.with_session() as s: for message in messages: + if message.dataset == "prefs": + # write it to metadata.json instead + config = json.loads((self._data_dir / "metadata.json").read_text() or "{}") + config[message.row] = message.get_value() + (self._data_dir / "metadata.json").write_text(json.dumps(config)) + continue table = get_class_by_table_name(message.dataset) + column = get_attribute_by_table_name(message.dataset, message.column) entry = s.query(table).get(message.row) if not entry: entry = table(id=message.row) - setattr(entry, message.column, message.get_value()) + setattr(entry, column, message.get_value()) s.add(entry) s.commit() def download_budget(self): - db = requests.get(f"{self.api_url}/{Endpoints.DOWNLOAD_USER_FILE}", headers=self.headers()) - db.raise_for_status() - f = io.BytesIO(db.content) + """Downloads the budget file from the remote. After the file is downloaded, the sync endpoint is queries + for the list of pending changes. The changes are individual row updates, that are then applied on by one to + the downloaded database state.""" + file_bytes = self.download_user_file(self._file.file_id) + f = io.BytesIO(file_bytes) try: zip_file = zipfile.ZipFile(f) except zipfile.BadZipfile as e: @@ -166,31 +195,33 @@ def download_budget(self): zip_file.extractall(self._data_dir) engine = sqlalchemy.create_engine(f"sqlite:///{self._data_dir}/db.sqlite") self._session_maker = sqlalchemy.orm.sessionmaker(engine) + # actual js always calls validation + self.validate() # after downloading the budget, some pending transactions still need to be retrieved using sync request = SyncRequest({"messages": [], "fileId": self._file.file_id, "groupId": self._file.group_id}) request.set_null_timestamp() # using 0 timestamp to retrieve all changes changes = self.sync(request) self.apply_changes(changes.get_messages()) - def sync(self, request: SyncRequest) -> SyncResponse: - response = requests.post( - f"{self.api_url}/{Endpoints.SYNC}", - headers=self.headers(extra_headers={"Content-Type": "application/actual-sync"}), - data=SyncRequest.serialize(request), - ) - response.raise_for_status() - parsed_response = SyncResponse.deserialize(response.content) - return parsed_response # noqa + def load_clock(self): + """See implementation at: + https://github.com/actualbudget/actual/blob/5bcfc71be67c6e7b7c8b444e4c4f60da9ea9fdaa/packages/loot-core/src/server/db/index.ts#L81-L98 + """ + pass def get_transactions(self) -> List[Transactions]: with self._session_maker() as s: query = ( s.query(Transactions) - .options(joinedload(Transactions.account), joinedload(Transactions.category_)) + .options( + joinedload(Transactions.account), + joinedload(Transactions.category), + joinedload(Transactions.payee), + ) .filter( Transactions.date.isnot(None), Transactions.acct.isnot(None), - sqlalchemy.or_(Transactions.isChild == 0, Transactions.parent_id.isnot(None)), + sqlalchemy.or_(Transactions.is_child == 0, Transactions.parent_id.isnot(None)), sqlalchemy.func.coalesce(Transactions.tombstone, 0) == 0, ) .order_by( @@ -202,12 +233,34 @@ def get_transactions(self) -> List[Transactions]: ) return query.all() + def add(self, model: BaseModel): + """Adds a new entry to the local database, sends a sync request to the remote server to synchronize + the local changes, then commits the change on the local database. It's important to note that this process + is not atomic, so if the process is interrupted before it completes successfully, the files would end up in + a unknown state.""" + with self.with_session() as s: + # add to database and see if all works well + s.add(model) + # generate a sync request and sync it to the server + req = SyncRequest({"fileId": self._file.file_id, "groupId": self._file.group_id}) + req.set_timestamp() + req.set_messages(model.convert()) + self.sync(req) + s.commit() + if not self._session: + s.close() + def get_categories(self) -> List[Categories]: - with self._session_maker() as s: + with self.with_session() as s: query = s.query(Categories) return query.all() def get_accounts(self) -> List[Accounts]: - with self._session_maker() as s: + with self.with_session() as s: query = s.query(Accounts) return query.all() + + def get_payees(self) -> List[Payees]: + with self.with_session() as s: + query = s.query(Payees) + return query.all() diff --git a/actual/api.py b/actual/api.py new file mode 100644 index 0000000..cebcb4a --- /dev/null +++ b/actual/api.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import enum +from typing import List, Optional + +import requests +from pydantic import BaseModel, Field + +from actual.exceptions import AuthorizationError, UnknownFileId +from actual.protobuf_models import SyncRequest, SyncResponse + + +class Endpoints(enum.Enum): + LOGIN = "account/login" + INFO = "info" + ACCOUNT_VALIDATE = "account/validate" + NEEDS_BOOTSTRAP = "account/needs-bootstrap" + SYNC = "sync/sync" + LIST_USER_FILES = "sync/list-user-files" + GET_USER_FILE_INFO = "sync/get-user-file-info" + UPDATE_USER_FILE_NAME = "sync/update-user-file-name" + DOWNLOAD_USER_FILE = "sync/download-user-file" + UPLOAD_USER_FILE = "sync/upload-user-file" + RESET_USER_FILE = "sync/reset-user-file" + # encryption related + USER_GET_KEY = "sync/user-get-key" + USER_CREATE_KEY = "sync/user-create-key" + # data related + DATA_FILE_INDEX = "data-file-index.txt" + DEFAULT_DB = "data/default-db.sqlite" + MIGRATIONS = "data/migrations" + + def __str__(self): + return self.value + + +class StatusCode(enum.Enum): + OK = "ok" + + +class StatusDTO(BaseModel): + status: StatusCode + + +class FileDTO(BaseModel): + deleted: Optional[int] + file_id: Optional[str] = Field(..., alias="fileId") + group_id: Optional[str] = Field(..., alias="groupId") + name: Optional[str] + + +class RemoteFileListDTO(FileDTO): + encrypt_key_id: Optional[str] = Field(..., alias="encryptKeyId") + + +class RemoteFileDTO(FileDTO): + encrypt_meta: Optional[EncryptMetaDTO] = Field(..., alias="encryptMeta") + + +class TokenDTO(BaseModel): + token: Optional[str] + + +class LoginDTO(StatusDTO): + data: TokenDTO + + +class UploadUserFileDTO(StatusDTO): + group_id: str = Field(..., alias="groupId") + + +class ListUserFilesDTO(StatusDTO): + data: List[RemoteFileListDTO] + + +class IsValidatedDTO(BaseModel): + validated: Optional[bool] + + +class ValidateDTO(StatusDTO): + data: IsValidatedDTO + + +class EncryptionDTO(BaseModel): + id: Optional[str] + salt: Optional[str] + test: Optional[str] + + +class EncryptMetaDTO(BaseModel): + key_id: Optional[str] = Field(..., alias="keyId") + algorithm: Optional[str] + iv: Optional[str] + auth_tag: Optional[str] = Field(..., alias="authTag") + + +class UserGetKeyDTO(StatusDTO): + data: EncryptionDTO + + +class BuildDTO(BaseModel): + name: str + description: Optional[str] + version: Optional[str] + + +class InfoDTO(BaseModel): + build: BuildDTO + + +class IsBootstrapedDTO(BaseModel): + bootstrapped: bool + + +class BootstrapInfoDTO(StatusDTO): + data: IsBootstrapedDTO + + +class ActualServer: + def __init__( + self, + base_url: str = "http://localhost:5006", + token: str = None, + password: str = None, + ): + self.api_url = base_url + self._token = token + if token is None and password is None: + raise ValueError("Either provide a valid token or a password.") + # already try to login if password was provided + if password: + self.login(password) + + def login(self, password: str) -> LoginDTO: + """Logs in on the Actual server using the password provided. Raises `AuthorizationError` if it fails to + authenticate the user.""" + if not password: + raise AuthorizationError("Trying to login but not password was provided.") + response = requests.post(f"{self.api_url}/{Endpoints.LOGIN}", json={"password": password}) + response.raise_for_status() + login_response = LoginDTO.parse_obj(response.json()) + if login_response.data.token is None: + raise AuthorizationError("Could not validate password on login.") + self._token = login_response.data.token + return login_response + + def headers(self, file_id: str = None, extra_headers: dict = None) -> dict: + """Generates headers by retrieving a token, if one is not provided, and auto-filling the file id.""" + if not self._token: + raise AuthorizationError("Token not available for requests. Use the login() method or provide a token.") + headers = {"X-ACTUAL-TOKEN": self._token} + if file_id: + headers["X-ACTUAL-FILE-ID"] = file_id + if extra_headers: + headers = headers | extra_headers + return headers + + def info(self) -> InfoDTO: + """Gets the information from the Actual server, like the name and version.""" + response = requests.get(f"{self.api_url}/{Endpoints.INFO}") + response.raise_for_status() + return InfoDTO.parse_obj(response.json()) + + def validate(self) -> ValidateDTO: + """Validates""" + response = requests.get(f"{self.api_url}/{Endpoints.ACCOUNT_VALIDATE}", headers=self.headers()) + response.raise_for_status() + return ValidateDTO.parse_obj(response.json()) + + def needs_bootstrap(self) -> BootstrapInfoDTO: + """Checks if the Actual needs bootstrap, in other words, if it needs a master password for the server.""" + response = requests.get(f"{self.api_url}/{Endpoints.NEEDS_BOOTSTRAP}") + response.raise_for_status() + return BootstrapInfoDTO.parse_obj(response.json()) + + def data_file_index(self) -> List[str]: + """Gets all the migration file references for the actual server.""" + response = requests.get(f"{self.api_url}/{Endpoints.DATA_FILE_INDEX}") + response.raise_for_status() + return response.content.decode().splitlines() + + def data_file(self, file_path: str) -> bytes: + """Gets the content of the individual migration file from server.""" + response = requests.get(f"{self.api_url}/data/{file_path}") + response.raise_for_status() + return response.content + + def reset_user_file(self, file_id: str) -> StatusDTO: + """Resets the file. If the file_id is not provided, the current file set is reset. Usually used together with + the upload_user_file() method.""" + if file_id is None: + raise UnknownFileId("Could not reset the file without a valid 'file_id'") + request = requests.post( + f"{self.api_url}/{Endpoints.RESET_USER_FILE}", json={"fileId": file_id, "token": self._token} + ) + request.raise_for_status() + return StatusDTO.parse_obj(request.json()) + + def download_user_file(self, file_id: str) -> bytes: + """Downloads the user file based on the file_id provided. Returns the `bytes` from the response, which is a + zipped folder of the database `db.sqlite` and the `metadata.json`. If the database is encrypted, the key id + has to be retrieved additionally using user_get_key().""" + db = requests.get(f"{self.api_url}/{Endpoints.DOWNLOAD_USER_FILE}", headers=self.headers(file_id)) + db.raise_for_status() + return db.content + + def upload_user_file(self, binary_data: bytes, file_id: str, file_name: str = "My Finances") -> UploadUserFileDTO: + """Uploads the binary data, which is a zip folder containing the `db.sqlite` and the `metadata.json`.""" + request = requests.post( + f"{self.api_url}/{Endpoints.UPLOAD_USER_FILE}", + data=binary_data, + headers=self.headers( + extra_headers={ + "X-ACTUAL-FORMAT": "2", + "X-ACTUAL-FILE-ID": file_id, + "X-ACTUAL-NAME": file_name, + "Content-Type": "application/encrypted-file", + } + ), + ) + request.raise_for_status() + return UploadUserFileDTO.parse_obj(request.json()) + + def list_user_files(self) -> ListUserFilesDTO: + """Lists the user files. If the response item contains `encrypt_key_id` different from `None`, then the + file must be decrypted on retrieval.""" + response = requests.get(f"{self.api_url}/{Endpoints.LIST_USER_FILES}", headers=self.headers()) + response.raise_for_status() + return ListUserFilesDTO.parse_obj(response.json()) + + def get_user_file_info(self, file_id: str) -> RemoteFileDTO: + """Gets the user file information, including the encryption metadata.""" + response = requests.get(f"{self.api_url}/{Endpoints.GET_USER_FILE_INFO}", headers=self.headers(file_id)) + response.raise_for_status() + return RemoteFileDTO.parse_obj(response.json()) + + def update_user_file_name(self, file_id: str, file_name: str) -> StatusDTO: + """Updates the file name for the budget on the remote server.""" + response = requests.post( + f"{self.api_url}/{Endpoints.UPDATE_USER_FILE_NAME}", + json={"fileId": file_id, "name": file_name, "token": self._token}, + headers=self.headers(), + ) + response.raise_for_status() + return StatusDTO.parse_obj(response.json()) + + def user_get_key(self, file_id: str) -> UserGetKeyDTO: + """Gets the key information associated with a user file, including the algorithm, key, salt and iv.""" + response = requests.get( + f"{self.api_url}/{Endpoints.USER_GET_KEY}", + json={ + "file_id": file_id, + "token": self._token, + }, + headers=self.headers(file_id), + ) + response.raise_for_status() + return UserGetKeyDTO.parse_obj(response.json()) + + def user_create_key(self, file_id: str, key_id: str, key_salt: str) -> StatusDTO: + """Creates a new key for the user file. The key has to be used then to encrypt the local file, and this file + still needs to be uploaded.""" + test_content = "" # todo: see how this is generated + response = requests.post( + f"{self.api_url}/{Endpoints.USER_CREATE_KEY}", + headers=self.headers(), + json={ + "fileId": file_id, + "keyId": key_id, + "keySalt": key_salt, + "testContent": test_content, + "token": self._token, + }, + ) + return StatusDTO.parse_obj(response.json()) + + def sync(self, request: SyncRequest) -> SyncResponse: + """Calls the sync endpoint with a request and returns the response. Both the request and response are + protobuf models. The request and response are not standard REST, but rather protobuf binary serialized data. + The server stores this serialized data to allow the user to replay all changes to the database and construct + a local copy.""" + response = requests.post( + f"{self.api_url}/{Endpoints.SYNC}", + headers=self.headers(request.fileId, extra_headers={"Content-Type": "application/actual-sync"}), + data=SyncRequest.serialize(request), + ) + response.raise_for_status() + parsed_response = SyncResponse.deserialize(response.content) + return parsed_response # noqa diff --git a/actual/database.py b/actual/database.py index 3ef5101..491f2bb 100644 --- a/actual/database.py +++ b/actual/database.py @@ -1,4 +1,7 @@ -from typing import List, Optional +import datetime +import decimal +import uuid +from typing import List, Optional, Union from sqlalchemy import ( Boolean, @@ -13,15 +16,63 @@ ) from sqlmodel import Field, Relationship, SQLModel +from actual.protobuf_models import Message -def get_class_by_table_name(table_name: str) -> SQLModel | None: +""" +This variable contains the internal model mappings for all databases. It solves a couple of issues, namely having the +mapping from __tablename__ to the actual SQLAlchemy class, and later mapping the SQL column into the Pydantic field, +which could be different and follows the Python naming convention. An example is the field `Transactions.is_parent`, +that converts into the SQL equivalent `transactions.isParent`. In this case, we would have the following entries: + + __TABLE_COLUMNS_MAP__ = { + "transactions": { + "entity": , + "columns": { + "isParent": "is_parent" + } + } + } +""" +__TABLE_COLUMNS_MAP__ = dict() + + +def get_class_by_table_name(table_name: str) -> Union[SQLModel, None]: """ Returns, based on the defined tables __tablename__ the corresponding SQLModel object. If not found, returns None. """ - for entry in SQLModel._sa_registry.mappers: - if entry.entity.__tablename__ == table_name: - return entry.entity - return None + return __TABLE_COLUMNS_MAP__.get(table_name, {}).get("entity", None) + + +def get_attribute_by_table_name(table_name: str, column_name: str) -> Union[str, None]: + """ + Returns, based, on the defined tables __tablename__ and the sacolumn name, the correct pydantic attribute. + If not found, returns None. + + :param table_name: SQL table name. + :param column_name: SQL column name. + :return: Pydantic attribute name. + """ + return __TABLE_COLUMNS_MAP__.get(table_name, {}).get("columns", {}).get(column_name, None) + + +class BaseModel(SQLModel): + def convert(self) -> List[Message]: + """Convert the object into distinct entries for sync method. Based on the original implementation: + + https://github.com/actualbudget/actual/blob/98c17bd5e0f13e27a09a7f6ac176510530572be7/packages/loot-core/src/server/aql/schema-helpers.ts#L146 + """ + dataset = self.__tablename__ + changes = [] + row = getattr(self, "id", None) # also helps lazy loading the instance + if row is None: + raise AttributeError(f"Cannot convert model {self.__name__} because it misses the 'id' attribute.") + for column, value in self.model_dump().items(): + if value is None or column == "id": + continue + m = Message(dict(dataset=dataset, row=row, column=column)) + m.set_value(value) + changes.append(m) + return changes class Meta(SQLModel, table=True): @@ -75,7 +126,7 @@ class Categories(SQLModel, table=True): tombstone: Optional[int] = Field(default=None, sa_column=Column("tombstone", Integer, server_default=text("0"))) goal_def: Optional[str] = Field(default=None, sa_column=Column("goal_def", Text, server_default=text("null"))) - transactions: List["Transactions"] = Relationship(back_populates="category_") + transactions: List["Transactions"] = Relationship(back_populates="category") class CategoryGroups(SQLModel, table=True): @@ -93,7 +144,7 @@ class CategoryMapping(SQLModel, table=True): __tablename__ = "category_mapping" id: Optional[str] = Field(default=None, sa_column=Column("id", Text, primary_key=True)) - transferId: Optional[str] = Field(default=None, sa_column=Column("transferId", Text)) + transfer_id: Optional[str] = Field(default=None, sa_column=Column("transferId", Text)) class CreatedBudgets(SQLModel, table=True): @@ -142,7 +193,7 @@ class PayeeMapping(SQLModel, table=True): __tablename__ = "payee_mapping" id: Optional[str] = Field(default=None, sa_column=Column("id", Text, primary_key=True)) - targetId: Optional[str] = Field(default=None, sa_column=Column("targetId", Text)) + target_id: Optional[str] = Field(default=None, sa_column=Column("targetId", Text)) class Payees(SQLModel, table=True): @@ -152,6 +203,8 @@ class Payees(SQLModel, table=True): tombstone: Optional[int] = Field(default=None, sa_column=Column("tombstone", Integer, server_default=text("0"))) transfer_acct: Optional[str] = Field(default=None, sa_column=Column("transfer_acct", Text)) + transactions: List["Transactions"] = Relationship(back_populates="payee") + class ReflectBudgets(SQLModel, table=True): __tablename__ = "reflect_budgets" @@ -224,7 +277,7 @@ class TransactionFilters(SQLModel, table=True): tombstone: Optional[int] = Field(default=None, sa_column=Column("tombstone", Integer, server_default=text("0"))) -class Transactions(SQLModel, table=True): +class Transactions(BaseModel, table=True): __table_args__ = ( Index("trans_category", "category"), Index("trans_category_date", "category", "date"), @@ -234,12 +287,12 @@ class Transactions(SQLModel, table=True): ) id: Optional[str] = Field(default=None, sa_column=Column("id", Text, primary_key=True)) - isParent: Optional[int] = Field(default=None, sa_column=Column("isParent", Integer, server_default=text("0"))) - isChild: Optional[int] = Field(default=None, sa_column=Column("isChild", Integer, server_default=text("0"))) + is_parent: Optional[int] = Field(default=None, sa_column=Column("isParent", Integer, server_default=text("0"))) + is_child: Optional[int] = Field(default=None, sa_column=Column("isChild", Integer, server_default=text("0"))) acct: Optional[str] = Field(default=None, sa_column=Column("acct", Text, ForeignKey("accounts.id"))) - category: Optional[str] = Field(default=None, sa_column=Column("category", Text, ForeignKey("categories.id"))) + category_id: Optional[str] = Field(default=None, sa_column=Column("category", Text, ForeignKey("categories.id"))) amount: Optional[int] = Field(default=None, sa_column=Column("amount", Integer)) - description: Optional[str] = Field(default=None, sa_column=Column("description", Text)) + description: Optional[str] = Field(default=None, sa_column=Column("description", Text, ForeignKey("payees.id"))) notes: Optional[str] = Field(default=None, sa_column=Column("notes", Text)) date: Optional[int] = Field(default=None, sa_column=Column("date", Integer)) financial_id: Optional[str] = Field(default=None, sa_column=Column("financial_id", Text)) @@ -261,7 +314,32 @@ class Transactions(SQLModel, table=True): reconciled: Optional[int] = Field(default=None, sa_column=Column("reconciled", Integer, server_default=text("0"))) account: Optional["Accounts"] = Relationship(back_populates="transactions") - category_: Optional["Categories"] = Relationship(back_populates="transactions") + category: Optional["Categories"] = Relationship(back_populates="transactions") + payee: Optional["Payees"] = Relationship(back_populates="transactions") + + @classmethod + def new( + cls, + account_id: str, + amount: decimal.Decimal, + date: datetime.date, + notes: Optional[str] = None, + category: Optional[Categories] = None, + payee: Optional[Payees] = None, + ): + date_int = int(datetime.date.strftime(date, "%Y%m%d")) + return cls( + id=str(uuid.uuid4()), + acct=account_id, + date=date_int, + amount=int(amount * 100), + category=category, + payee=payee, + notes=notes, + reconciled=0, + cleared=0, + sort_order=datetime.datetime.utcnow().timestamp(), + ) class ZeroBudgetMonths(SQLModel, table=True): @@ -292,3 +370,14 @@ class PendingTransactions(SQLModel, table=True): date: Optional[str] = Field(default=None, sa_column=Column("date", Text)) account: Optional["Accounts"] = Relationship(back_populates="pending_transactions") + + +for entry in SQLModel._sa_registry.mappers: + t_name = entry.entity.__tablename__ + if t_name not in __TABLE_COLUMNS_MAP__: + __TABLE_COLUMNS_MAP__[t_name] = {"entity": entry.entity, "columns": {}} + table_columns = list(c.name for c in entry.columns) + # the name and property name of the pydantic property and database column can be different + for key, column in dict(entry.entity.__dict__).items(): + if hasattr(column, "name") and getattr(column, "name") in table_columns: + __TABLE_COLUMNS_MAP__[t_name]["columns"][column.name] = key diff --git a/actual/exceptions.py b/actual/exceptions.py new file mode 100644 index 0000000..d268bb1 --- /dev/null +++ b/actual/exceptions.py @@ -0,0 +1,14 @@ +class ActualError(Exception): + pass + + +class AuthorizationError(ActualError): + pass + + +class UnknownFileId(ActualError): + pass + + +class InvalidZipFile(ActualError): + pass diff --git a/actual/models.py b/actual/models.py deleted file mode 100644 index e034e4a..0000000 --- a/actual/models.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any, Optional - -from pydantic import BaseModel, Field - - -class RemoteFile(BaseModel): - deleted: Optional[int] = None - file_id: Optional[str] = Field(None, alias="fileId") - group_id: Optional[str] = Field(None, alias="groupId") - name: Optional[str] = None - encrypt_meta: Optional[Any] = Field(None, alias="encryptMeta") diff --git a/actual/protobuf_models.py b/actual/protobuf_models.py index faef8d5..69f0c7d 100644 --- a/actual/protobuf_models.py +++ b/actual/protobuf_models.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import uuid @@ -14,6 +16,31 @@ """ +def timestamp(client_id: str = None, now: datetime.datetime = None) -> str: + """Actual uses Hybrid Unique Logical Clock (HULC) timestamp generator. + + Timestamps serialize into a 46-character collatable string + * example: 2015-04-24T22:23:42.123Z-1000-0123456789ABCDEF + * example: 2015-04-24T22:23:42.123Z-1000-A219E7A71CC18912 + + See https://github.com/actualbudget/actual/blob/a9362cc6f9b974140a760ad05816cac51c849769/packages/crdt/src/crdt/timestamp.ts + for reference. + """ + if not now: + now = datetime.datetime.utcnow() + if not client_id: + client_id = get_client_id() + return f"{now.isoformat(timespec='milliseconds')}Z-0000-{client_id}" + + +def get_client_id(): + """Creates a client id for the HULC request. Copied implementation from: + + https://github.com/actualbudget/actual/blob/a9362cc6f9b974140a760ad05816cac51c849769/packages/crdt/src/crdt/timestamp.ts#L80 + """ + return str(uuid.uuid4()).replace("-", "")[-16:] + + class EncryptedData(proto.Message): iv = proto.Field(proto.BYTES, number=1) authTag = proto.Field(proto.BYTES, number=2) @@ -26,21 +53,43 @@ class Message(proto.Message): column = proto.Field(proto.STRING, number=3) value = proto.Field(proto.STRING, number=4) - def get_value(self) -> str | int: + def get_value(self) -> str | int | float | None: + """Serialization types from Actual. Source: + + https://github.com/actualbudget/actual/blob/998efb9447da6f8ce97956cbe83d6e8a3c18cf53/packages/loot-core/src/server/sync/index.ts#L154-L160 + """ datatype, _, value = self.value.partition(":") if datatype == "S": return value elif datatype == "N": return int(value) + elif datatype == "0": + return None else: raise ValueError(f"Conversion not supported for datatype '{datatype}'") + def set_value(self, value: str | int | float | None) -> str: + if isinstance(value, str): + datatype = "S" + elif isinstance(value, int): + datatype = "N" + elif value is None: + datatype = "0" + else: + raise ValueError(f"Conversion not supported for datatype '{type(value)}'") + self.value = f"{datatype}:{value}" + return self.value + class MessageEnvelope(proto.Message): timestamp = proto.Field(proto.STRING, number=1) isEncrypted = proto.Field(proto.BOOL, number=2) content = proto.Field(proto.BYTES, number=3) + def set_timestamp(self, client_id: str = None, now: datetime.datetime = None) -> str: + self.timestamp = timestamp(client_id, now) + return self.timestamp + class SyncRequest(proto.Message): messages = proto.RepeatedField(MessageEnvelope, number=1) @@ -50,31 +99,19 @@ class SyncRequest(proto.Message): since = proto.Field(proto.STRING, number=6) def set_timestamp(self, client_id: str = None, now: datetime.datetime = None) -> str: - """Actual uses Hybrid Unique Logical Clock (HULC) timestamp generator. - - Timestamps serialize into a 46-character collatable string - * example: 2015-04-24T22:23:42.123Z-1000-0123456789ABCDEF - * example: 2015-04-24T22:23:42.123Z-1000-A219E7A71CC18912 - - See https://github.com/actualbudget/actual/blob/a9362cc6f9b974140a760ad05816cac51c849769/packages/crdt/src/crdt/timestamp.ts - for reference. - """ - if not now: - now = datetime.datetime.utcnow() - if not client_id: - client_id = self.client_id() - self.since = f"{now.isoformat(timespec='milliseconds')}Z-0000-{client_id}" + self.since = timestamp(client_id, now) return self.since def set_null_timestamp(self) -> str: return self.set_timestamp(None, datetime.datetime(1970, 1, 1, 0, 0, 0, 0)) - def client_id(self): - """Creates a client id for the HULC request. Copied implementation from: - - https://github.com/actualbudget/actual/blob/a9362cc6f9b974140a760ad05816cac51c849769/packages/crdt/src/crdt/timestamp.ts#L80 - """ - return str(uuid.uuid4()).replace("-", "")[-16:] + def set_messages(self, messages: list[Message]): + _messages = [] + for message in messages: + m = MessageEnvelope({"content": Message.serialize(message), "isEncrypted": False}) + m.set_timestamp() + _messages.append(m) + self.messages = _messages class SyncResponse(proto.Message): diff --git a/requirements.txt b/requirements.txt index 139bf5f..942334d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ requests>=2 -sqlmodel==0.0.14 +sqlmodel==0.0.16 pydantic>=2 sqlalchemy>=1.4 proto-plus>=1 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..6902c79 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,17 @@ +from actual.database import ( + Transactions, + get_attribute_by_table_name, + get_class_by_table_name, +) + + +def test_get_class_by_table_name(): + assert get_class_by_table_name("transactions") == Transactions + assert get_class_by_table_name("foo") is None + + +def test_get_attribute_by_table_name(): + assert get_attribute_by_table_name("transactions", "isParent") == "is_parent" + assert get_attribute_by_table_name("transactions", "category") == "category_id" + assert get_attribute_by_table_name("transactions", "foo") is None + assert get_attribute_by_table_name("foo", "bar") is None