Skip to content

Commit

Permalink
feat: Continue working on getting all helper methods ready.
Browse files Browse the repository at this point in the history
  • Loading branch information
Brunno Vanelli committed Apr 4, 2024
1 parent f90b6a7 commit 298a41e
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 22 deletions.
11 changes: 9 additions & 2 deletions actual/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,19 @@ def __init__(
# set the correct file
if file:
self.set_file(file)
self._in_context = False

def __enter__(self) -> Actual:
if self._file:
self.download_budget()
self._session = self.session_maker()
self._in_context = True
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self._session:
self._session.close()
self._in_context = False

@property
def session(self) -> sqlalchemy.orm.Session:
Expand Down Expand Up @@ -151,7 +154,9 @@ def create_budget(self, budget_name: str):
self.run_migrations(migration_files[1:])
# generate a session
engine = sqlalchemy.create_engine(f"sqlite:///{self._data_dir}/db.sqlite")
self.session_maker = sqlalchemy.orm.sessionmaker(engine)
self.session_maker = sqlalchemy.orm.sessionmaker(engine, autoflush=False)
if self._in_context:
self._session = self.session_maker()
# create a clock
self.load_clock()

Expand Down Expand Up @@ -189,6 +194,8 @@ def apply_changes(self, messages: list[Message]):
entry = table(id=message.row)
setattr(entry, column, message.get_value())
s.add(entry)
# this seems to be required for sqlmodel, remove if not needed anymore when querying from cache
s.flush()
s.commit()

def download_budget(self):
Expand All @@ -206,7 +213,7 @@ def download_budget(self):
# this should extract 'db.sqlite' and 'metadata.json' to the folder
zip_file.extractall(self._data_dir)
engine = sqlalchemy.create_engine(f"sqlite:///{self._data_dir}/db.sqlite")
self.session_maker = sqlalchemy.orm.sessionmaker(engine)
self.session_maker = sqlalchemy.orm.sessionmaker(engine, autoflush=False)
# actual js always calls validation
self.validate()
# load the client id
Expand Down
18 changes: 15 additions & 3 deletions actual/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class Accounts(BaseModel, table=True):
sort_order: Optional[float] = Field(default=None, sa_column=Column("sort_order", Float))
type: Optional[str] = Field(default=None, sa_column=Column("type", Text))

payee: "Payees" = Relationship(back_populates="account", sa_relationship_kwargs={"uselist": False})
pending_transactions: List["PendingTransactions"] = Relationship(back_populates="account")
transactions: List["Transactions"] = Relationship(back_populates="account")

Expand All @@ -141,12 +142,15 @@ class Categories(SQLModel, table=True):
id: Optional[str] = Field(default=None, sa_column=Column("id", Text, primary_key=True))
name: Optional[str] = Field(default=None, sa_column=Column("name", Text))
is_income: Optional[int] = Field(default=None, sa_column=Column("is_income", Integer, server_default=text("0")))
cat_group: Optional[str] = Field(default=None, sa_column=Column("cat_group", Text))
cat_group: Optional[str] = Field(
default=None, sa_column=Column("cat_group", Text, ForeignKey("category_groups.id"))
)
sort_order: Optional[float] = Field(default=None, sa_column=Column("sort_order", Float))
tombstone: Optional[int] = Field(default=None, sa_column=Column("tombstone", Integer, server_default=text("0")))
goal_def: Optional[str] = Field(default=None, sa_column=Column("goal_def", Text, server_default=text("null")))

transactions: List["Transactions"] = Relationship(back_populates="category")
group: "CategoryGroups" = Relationship(back_populates="categories", sa_relationship_kwargs={"uselist": False})


class CategoryGroups(SQLModel, table=True):
Expand All @@ -159,6 +163,8 @@ class CategoryGroups(SQLModel, table=True):
sort_order: Optional[float] = Field(default=None, sa_column=Column("sort_order", Float))
tombstone: Optional[int] = Field(default=None, sa_column=Column("tombstone", Integer, server_default=text("0")))

categories: List["Categories"] = Relationship(back_populates="group")


class CategoryMapping(SQLModel, table=True):
__tablename__ = "category_mapping"
Expand Down Expand Up @@ -221,8 +227,11 @@ class Payees(SQLModel, table=True):
name: Optional[str] = Field(default=None, sa_column=Column("name", Text))
category: Optional[str] = Field(default=None, sa_column=Column("category", Text))
tombstone: Optional[int] = Field(default=None, sa_column=Column("tombstone", Integer, server_default=text("0")))
transfer_acct: Optional[str] = Field(default=None, sa_column=Column("transfer_acct", Text))
transfer_acct: Optional[str] = Field(
default=None, sa_column=Column("transfer_acct", Text, ForeignKey("accounts.id"))
)

account: Optional["Accounts"] = Relationship(back_populates="payee", sa_relationship_kwargs={"uselist": False})
transactions: List["Transactions"] = Relationship(back_populates="payee")


Expand Down Expand Up @@ -312,7 +321,7 @@ class Transactions(BaseModel, table=True):
acct: Optional[str] = Field(default=None, sa_column=Column("acct", Text, ForeignKey("accounts.id")))
category_id: Optional[str] = Field(default=None, sa_column=Column("category", Text, ForeignKey("categories.id")))
amount: Optional[int] = Field(default=None, sa_column=Column("amount", Integer))
description: Optional[str] = Field(default=None, sa_column=Column("description", Text, ForeignKey("payees.id")))
payee_id: Optional[str] = Field(default=None, sa_column=Column("description", Text, ForeignKey("payees.id")))
notes: Optional[str] = Field(default=None, sa_column=Column("notes", Text))
date: Optional[int] = Field(default=None, sa_column=Column("date", Integer))
financial_id: Optional[str] = Field(default=None, sa_column=Column("financial_id", Text))
Expand Down Expand Up @@ -361,6 +370,9 @@ def new(
sort_order=datetime.datetime.utcnow().timestamp(),
)

def get_date(self) -> datetime.date:
return datetime.datetime.strptime(str(self.date), "%Y%m%d").date()


class ZeroBudgetMonths(SQLModel, table=True):
__tablename__ = "zero_budget_months"
Expand Down
161 changes: 144 additions & 17 deletions actual/queries.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import decimal
import typing
Expand All @@ -6,11 +8,48 @@
import sqlalchemy
from sqlalchemy.orm import Session, joinedload

from actual.database import Accounts, Categories, PayeeMapping, Payees, Transactions
from actual.database import (
Accounts,
Categories,
CategoryGroups,
CategoryMapping,
PayeeMapping,
Payees,
Transactions,
)

T = typing.TypeVar("T")


def is_uuid(text: str, version: int = 4):
"""
Check if uuid_to_test is a valid UUID.
Taken from https://stackoverflow.com/a/54254115/12681470
Parameters
----------
uuid_to_test : str
version : {1, 2, 3, 4}
Returns
-------
`True` if uuid_to_test is a valid UUID, otherwise `False`.
Examples
--------
>>> is_uuid('c9bf9e57-1685-4c89-bafb-ff5af830be8a')
True
>>> is_uuid('c9bf9e58')
False
"""
try:
uuid.UUID(str(text), version=version)
return True
except ValueError:
return False


def get_transactions(s: Session, notes: str = None, include_deleted: bool = False) -> typing.List[Transactions]:
query = (
s.query(Transactions)
Expand Down Expand Up @@ -53,8 +92,8 @@ def create_transaction_from_ids(
acct=account_id,
date=date_int,
amount=int(amount * 100),
category=category_id,
payee=payee_id,
category_id=category_id,
payee_id=payee_id,
notes=notes,
reconciled=0,
cleared=0,
Expand All @@ -68,15 +107,15 @@ def create_transaction(
s: Session,
account_name: str,
date: datetime.date,
payee_name: str,
payee_name: str | Payees,
notes: str,
category_name: str = None,
amount: decimal.Decimal = 0,
):
acct = get_account(s, account_name)
payee = get_or_create_payee(s, payee_name)
if category_name:
category_id = get_or_create_category(s, category_name).id
category_id = get_or_create_category(s, category_name, "").id
else:
category_id = None
return create_transaction_from_ids(s, acct.id, date, payee.id, notes, category_id, amount)
Expand All @@ -91,13 +130,63 @@ def base_query(s: Session, instance: typing.Type[T], name: str, include_deleted:
return query


def create_category_group(s: Session, name: str) -> CategoryGroups:
category_group = CategoryGroups(id=str(uuid.uuid4()), name=name, is_income=0, is_hidden=0, sort_order=0)
s.add(category_group)
return category_group


def get_or_create_category_group(s: Session, name: str) -> CategoryGroups:
category_group = s.query(CategoryGroups).filter(CategoryGroups.name == name).one_or_none()
if not category_group:
category_group = create_category_group(s, name)
return category_group


def get_categories(s: Session, name: str = None, include_deleted: bool = False) -> typing.List[Categories]:
query = base_query(s, Categories, name, include_deleted).options(joinedload(Payees.transactions))
return query.all()


def get_or_create_category(s: Session, name: str) -> Categories:
pass
def create_category(
s: Session,
name: str,
group_name: str,
) -> Categories:
category_group = get_or_create_category_group(s, group_name)
category = Categories(
id=str(uuid.uuid4()), name=name, hidden=0, is_income=0, sort_order=0, cat_group=category_group.id
)
category_mapping = CategoryMapping(id=category.id, transfer_id=category.id)
s.add(category)
s.add(category_mapping)
return category


def get_category(
s: Session, name: str | Categories, group_name: str = None, strict_group: bool = False
) -> typing.Optional[Categories]:
if isinstance(name, Categories):
return name
category = (
s.query(Categories)
.join(CategoryGroups)
.filter(Categories.name == name, CategoryGroups.name == group_name)
.one_or_none()
)
if not category and not strict_group:
# try to find it without the group name
category = s.query(Categories).filter(Categories.name == name).one_or_none()
return category


def get_or_create_category(
s: Session, name: str | Categories, group_name: str, strict_group: bool = False
) -> Categories:
category = get_category(s, name, group_name, strict_group)
if not category:
category = create_category(s, name, group_name)
return category


def get_accounts(s: Session, name: str = None, include_deleted: bool = False) -> typing.List[Accounts]:
Expand All @@ -110,16 +199,22 @@ def get_payees(s: Session, name: str = None, include_deleted: bool = False) -> t
return query.all()


def create_payee(s: Session, name: str) -> Payees:
payee = Payees(id=uuid.uuid4(), name=name)
def get_payee(s: Session, name: str | Payees) -> Payees:
if isinstance(name, Payees):
return name
return s.query(Payees).filter(Payees.name == name).one_or_none()


def create_payee(s: Session, name: str | None) -> Payees:
payee = Payees(id=str(uuid.uuid4()), name=name)
s.add(payee)
# add also the payee mapping
s.add(PayeeMapping(id=payee.id, target_id=payee.id))
return payee


def get_or_create_payee(s: Session, name: str) -> Payees:
payee = s.query(Payees).filter(Payees.name == name).one_or_none()
def get_or_create_payee(s: Session, name: str | Payees | None) -> Payees:
payee = get_payee(s, name)
if not payee:
payee = create_payee(s, name)
return payee
Expand All @@ -128,23 +223,55 @@ def get_or_create_payee(s: Session, name: str) -> Payees:
def create_account(
s: Session, name: str, initial_balance: decimal.Decimal = decimal.Decimal(0), off_budget: bool = False
) -> Accounts:
acct = Accounts(id=uuid.uuid4(), name=name, offbudget=int(off_budget), closed=0)
acct = Accounts(id=str(uuid.uuid4()), name=name, offbudget=int(off_budget), closed=0)
s.add(acct)
# add a blank payee
payee = create_payee(s, "")
payee.target_id = None
payee = create_payee(s, None)
payee.transfer_acct = acct.id
s.add(payee)
# if there is an no initial balance, create it
if initial_balance:
payee_starting = get_or_create_payee(s, "Starting Balance")
category = get_or_create_category(s, "Starting Balances")
category = get_or_create_category(s, "Starting Balances", "Income")
create_transaction_from_ids(
s, acct.id, datetime.date.today(), payee_starting.id, "", category.id, initial_balance
)
return acct


def get_account(s: Session, name: str) -> typing.Optional[Accounts]:
account = s.query(Accounts).filter(Accounts.name == name).one_or_none()
def get_account(s: Session, name: str | Accounts) -> typing.Optional[Accounts]:
if isinstance(name, Accounts):
return name
if is_uuid(name):
account = s.query(Accounts).filter(Accounts.id == name).one_or_none()
else:
account = s.query(Accounts).filter(Accounts.name == name).one_or_none()
return account


def get_or_create_account(s: Session, name: str | Accounts) -> Accounts:
account = get_account(s, name)
if not account:
account = create_account(s, name)
return account


def create_transfer(
s: Session,
source_account: str | Accounts,
dest_account: str | Accounts,
amount: decimal.Decimal,
date: datetime.date,
notes: str = None,
) -> typing.Tuple[Transactions, Transactions]:
source: Accounts = get_account(s, source_account)
dest: Accounts = get_account(s, dest_account)
source_transaction = create_transaction_from_ids(s, source.id, date, dest.payee.id, notes, None, -amount)
dest_transaction = create_transaction_from_ids(s, dest.id, date, source.payee.id, notes, None, amount)
# swap the transferred ids
source_transaction.transferred_id = dest_transaction.id
dest_transaction.transferred_id = source_transaction.id
# add and return objects
s.add(source_transaction)
s.add(dest_transaction)
return source_transaction, dest_transaction

0 comments on commit 298a41e

Please sign in to comment.