Skip to content

Commit

Permalink
feat: Add support for running and inserting rules (#12)
Browse files Browse the repository at this point in the history
- Introduce Rule and RuleSet for creating new rules and running existing rules on the local database.
- Refactor some parts of the code to make the API naming more uniform
- Update database model to include the custom reports
  • Loading branch information
bvanelli authored May 6, 2024
1 parent 86bf55e commit dd7aad2
Show file tree
Hide file tree
Showing 14 changed files with 804 additions and 83 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
python -m pip install --upgrade pip
pip install ruff pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
- name: Lint with ruff
run: |
# stop the build if there are Python syntax errors or undefined names
Expand Down
45 changes: 28 additions & 17 deletions actual/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import tempfile
import uuid
import zipfile
from typing import Union
from os import PathLike
from typing import IO, Union

import sqlalchemy
import sqlalchemy.orm
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(
books using Python.
Parts of the implementation are available at the following file:
https://github.com/actualbudget/actual/blob/master/packages/loot-core/src/server/cloud-storage.ts
https://github.com/actualbudget/actual/blob/2178da0414958064337b2c53efc95ff1d3abf98a/packages/loot-core/src/server/cloud-storage.ts
:param base_url: url of the running Actual server
:param token: the token for authentication, if this is available (optional)
Expand Down Expand Up @@ -129,7 +130,7 @@ def run_migrations(self, migration_files: list[str]):
conn.commit()
conn.close()

def create_budget(self, budget_name: str, encryption_password: str = None):
def create_budget(self, budget_name: str):
"""Creates a budget using the remote server default database and migrations. If password is provided, the
budget will be encrypted."""
migration_files = self.data_file_index()
Expand Down Expand Up @@ -162,12 +163,16 @@ def create_budget(self, budget_name: str, encryption_password: str = None):
# create a clock
self.load_clock()

def _gen_zip(self) -> bytes:
binary_data = io.BytesIO()
with zipfile.ZipFile(binary_data, "a", zipfile.ZIP_DEFLATED, False) as z:
def export_data(self, output_file: str | PathLike[str] | IO[bytes] = None) -> bytes:
"""Export your data as a zip file containing db.sqlite and metadata.json files. It can be imported into another
Actual instance by closing an open file (if any), then clicking the “Import file” button, then choosing
“Actual.” Even though encryption is enabled, the exported zip file will not have any encryption."""
if not output_file:
output_file = io.BytesIO()
with zipfile.ZipFile(output_file, "a", zipfile.ZIP_DEFLATED, False) as z:
z.write(self._data_dir / "db.sqlite", "db.sqlite")
z.write(self._data_dir / "metadata.json", "metadata.json")
return binary_data.getvalue()
return output_file.getvalue()

def encrypt(self, encryption_password: str):
"""Encrypts the local database using a new key, and re-uploads to the server.
Expand All @@ -189,7 +194,7 @@ def encrypt(self, encryption_password: str):
raise ActualError("Budget is encrypted but password was not provided")
self._master_key = create_key_buffer(encryption_password, salt)
# encrypt binary data with
encrypted = encrypt(self._file.encrypt_key_id, self._master_key, self._gen_zip())
encrypted = encrypt(self._file.encrypt_key_id, self._master_key, self.export_data())
binary_data = io.BytesIO(base64.b64decode(encrypted["value"]))
encryption_meta = encrypted["meta"]
self.reset_user_file(self._file.file_id)
Expand Down Expand Up @@ -227,6 +232,10 @@ def apply_changes(self, messages: list[Message]):
self.update_metadata({message.row: message.get_value()})
continue
table = get_class_by_table_name(message.dataset)
if table is None:
raise ActualError(
f"Actual is at a version not supported by the library: '{message.dataset}' not found"
)
column = get_attribute_by_table_name(message.dataset, message.column)
entry = s.query(table).get(message.row)
if not entry:
Expand Down Expand Up @@ -255,18 +264,20 @@ def download_budget(self, encryption_password: str = None):

if self._file.encrypt_key_id and encryption_password is None:
raise ActualError("File is encrypted but no encryption password provided.")
if encryption_password is not None:
if encryption_password is not None and self._file.encrypt_key_id:
file_info = self.get_user_file_info(self._file.file_id)
key_info = self.user_get_key(self._file.file_id)
self._master_key = create_key_buffer(encryption_password, key_info.data.salt)
# decrypt file bytes
file_bytes = decrypt_from_meta(self._master_key, file_bytes, file_info.data.encrypt_meta)
self._load_zip(file_bytes)
self.import_zip(io.BytesIO(file_bytes))
# actual js always calls validation
self.validate()
self.sync()

def _load_zip(self, file_bytes: bytes):
f = io.BytesIO(file_bytes)
def import_zip(self, file_bytes: str | PathLike[str] | IO[bytes]):
try:
zip_file = zipfile.ZipFile(f)
zip_file = zipfile.ZipFile(file_bytes)
except zipfile.BadZipfile as e:
raise InvalidZipFile(f"Invalid zip file: {e}") from None
if not self._data_dir:
Expand All @@ -275,10 +286,10 @@ def _load_zip(self, file_bytes: bytes):
zip_file.extractall(self._data_dir)
engine = sqlalchemy.create_engine(f"sqlite:///{self._data_dir}/db.sqlite")
self.session_maker = sqlalchemy.orm.sessionmaker(engine, autoflush=False)
# actual js always calls validation
self.validate()
# load the client id
self.load_clock()

def sync(self):
# after downloading the budget, some pending transactions still need to be retrieved using sync
request = SyncRequest(
{
Expand All @@ -289,7 +300,7 @@ def _load_zip(self, file_bytes: bytes):
}
)
request.set_null_timestamp(client_id=self._client.client_id) # using 0 timestamp to retrieve all changes
changes = self.sync(request)
changes = self.sync_sync(request)
self.apply_changes(changes.get_messages(self._master_key))
if changes.messages:
self._client = HULC_Client.from_timestamp(changes.messages[-1].timestamp)
Expand Down Expand Up @@ -338,4 +349,4 @@ def commit(self):
# make sure changes are valid before syncing
self._session.commit()
# sync all changes to the server
self.sync(req)
self.sync_sync(req)
2 changes: 1 addition & 1 deletion actual/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def user_create_key(self, file_id: str, key_id: str, password: str, key_salt: st
)
return StatusDTO.parse_obj(response.json())

def sync(self, request: SyncRequest) -> SyncResponse:
def sync_sync(self, request: SyncRequest) -> SyncResponse:
"""Calls the sync endpoint with a request and returns the response. Both the request and response are
protobuf models. The request and response are not standard REST, but rather protobuf binary serialized data.
The server stores this serialized data to allow the user to replay all changes to the database and construct
Expand Down
34 changes: 32 additions & 2 deletions actual/crypto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import base64
import os
import uuid

import cryptography.exceptions
from cryptography.hazmat.primitives import hashes
Expand All @@ -20,8 +23,6 @@ def make_salt(length: int = 32) -> str:


def create_key_buffer(password: str, key_salt: str) -> bytes:
if key_salt is None:
key_salt = make_salt()
kdf = PBKDF2HMAC(algorithm=hashes.SHA512(), length=32, salt=key_salt.encode(), iterations=10_000)
return kdf.derive(password.encode())

Expand Down Expand Up @@ -66,3 +67,32 @@ def make_test_message(key_id: str, key: bytes) -> dict:
binary_message = Message.serialize(m)
# return encrypted binary message
return encrypt(key_id, key, binary_message)


def is_uuid(text: str, version: int = 4):
"""
Check if uuid_to_test is a valid UUID.
Taken from https://stackoverflow.com/a/54254115/12681470
Parameters
----------
uuid_to_test : str
version : {1, 2, 3, 4}
Returns
-------
`True` if uuid_to_test is a valid UUID, otherwise `False`.
Examples
--------
>>> is_uuid('c9bf9e57-1685-4c89-bafb-ff5af830be8a')
True
>>> is_uuid('c9bf9e58')
False
"""
try:
uuid.UUID(str(text), version=version)
return True
except ValueError:
return False
83 changes: 56 additions & 27 deletions actual/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
"""
This file was partially generated using sqlacodegen using the downloaded version of the db.sqlite file export
in order to update this file, you can generate the code with:
> sqlacodegen --generator sqlmodels sqlite:///db.sqlite
and patch the necessary models by merging the results.
"""

import datetime
import decimal
import uuid
from typing import List, Optional, Union

from sqlalchemy import (
Expand Down Expand Up @@ -95,8 +103,8 @@ def convert(self, is_new: bool = True) -> List[Message]:
return changes

def delete(self):
if not getattr(self, "tombstone", None):
raise AttributeError(f"Model {self.__name__} has no tombstone field and cannot be deleted.")
if not hasattr(self, "tombstone"):
raise AttributeError(f"Model {self.__class__.__name__} has no tombstone field and cannot be deleted.")
setattr(self, "tombstone", 1)


Expand Down Expand Up @@ -184,6 +192,42 @@ class CreatedBudgets(SQLModel, table=True):
month: Optional[str] = Field(default=None, sa_column=Column("month", Text, primary_key=True))


class CustomReports(BaseModel, table=True):
__tablename__ = "custom_reports"

id: Optional[str] = Field(default=None, sa_column=Column("id", Text, primary_key=True))
name: Optional[str] = Field(default=None, sa_column=Column("name", Text))
start_date: Optional[str] = Field(default=None, sa_column=Column("start_date", Text))
end_date: Optional[str] = Field(default=None, sa_column=Column("end_date", Text))
date_static: Optional[int] = Field(default=None, sa_column=Column("date_static", Integer, server_default=text("0")))
date_range: Optional[str] = Field(default=None, sa_column=Column("date_range", Text))
mode: Optional[str] = Field(default=None, sa_column=Column("mode", Text, server_default=text("'total'")))
group_by: Optional[str] = Field(default=None, sa_column=Column("group_by", Text, server_default=text("'Category'")))
balance_type: Optional[str] = Field(
default=None, sa_column=Column("balance_type", Text, server_default=text("'Expense'"))
)
show_empty: Optional[int] = Field(default=None, sa_column=Column("show_empty", Integer, server_default=text("0")))
show_offbudget: Optional[int] = Field(
default=None, sa_column=Column("show_offbudget", Integer, server_default=text("0"))
)
show_hidden: Optional[int] = Field(default=None, sa_column=Column("show_hidden", Integer, server_default=text("0")))
show_uncategorized: Optional[int] = Field(
default=None, sa_column=Column("show_uncategorized", Integer, server_default=text("0"))
)
selected_categories: Optional[str] = Field(default=None, sa_column=Column("selected_categories", Text))
graph_type: Optional[str] = Field(
default=None, sa_column=Column("graph_type", Text, server_default=text("'BarGraph'"))
)
conditions: Optional[str] = Field(default=None, sa_column=Column("conditions", Text))
conditions_op: Optional[str] = Field(
default=None, sa_column=Column("conditions_op", Text, server_default=text("'and'"))
)
metadata_: Optional[str] = Field(default=None, sa_column=Column("metadata", Text))
interval: Optional[str] = Field(default=None, sa_column=Column("interval", Text, server_default=text("'Monthly'")))
color_scheme: Optional[str] = Field(default=None, sa_column=Column("color_scheme", Text))
tombstone: Optional[int] = Field(default=None, sa_column=Column("tombstone", Integer, server_default=text("0")))


class Kvcache(SQLModel, table=True):
key: Optional[str] = Field(default=None, sa_column=Column("key", Text, primary_key=True))
value: Optional[str] = Field(default=None, sa_column=Column("value", Text))
Expand Down Expand Up @@ -351,33 +395,18 @@ class Transactions(BaseModel, table=True):
category: Optional["Categories"] = Relationship(back_populates="transactions")
payee: Optional["Payees"] = Relationship(back_populates="transactions")

@classmethod
def new(
cls,
account_id: str,
amount: decimal.Decimal,
date: datetime.date,
notes: Optional[str] = None,
category: Optional[Categories] = None,
payee: Optional[Payees] = None,
):
date_int = int(datetime.date.strftime(date, "%Y%m%d"))
return cls(
id=str(uuid.uuid4()),
acct=account_id,
date=date_int,
amount=int(amount * 100),
category=category,
payee=payee,
notes=notes,
reconciled=0,
cleared=0,
sort_order=datetime.datetime.utcnow().timestamp(),
)

def get_date(self) -> datetime.date:
return datetime.datetime.strptime(str(self.date), "%Y%m%d").date()

def set_date(self, date: datetime.date):
self.date = int(datetime.date.strftime(date, "%Y%m%d"))

def set_amount(self, amount: Union[decimal.Decimal, int, float]):
self.amount = int(amount * 100)

def get_amount(self) -> decimal.Decimal:
return decimal.Decimal(self.amount) / decimal.Decimal(100)


class ZeroBudgetMonths(SQLModel, table=True):
__tablename__ = "zero_budget_months"
Expand Down
3 changes: 2 additions & 1 deletion actual/protobuf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import proto

from actual.crypto import decrypt, encrypt
from actual.exceptions import ActualDecryptionError

"""
Protobuf message definitions taken from:
Expand Down Expand Up @@ -150,7 +151,7 @@ def get_messages(self, master_key: bytes = None) -> list[Message]:
for message in self.messages: # noqa
if message.isEncrypted:
if not master_key:
raise ValueError("Master key not provided and data is encrypted.")
raise ActualDecryptionError("Master key not provided and data is encrypted.")
encrypted = EncryptedData.deserialize(message.content)
content = decrypt(master_key, encrypted.iv, encrypted.data, encrypted.authTag)
else:
Expand Down
Loading

0 comments on commit dd7aad2

Please sign in to comment.