Skip to content

Commit

Permalink
Pytest refactor (#3)
Browse files Browse the repository at this point in the history
* tox support

- add tox.ini
- minor changes to support tox
- pyproject clean up

* refactor tests

- remove boiller plate tests
- implement generators to parametrize fixtures
- clean up repetitive tests

* more refactoring

- add xfail marks
- run test for CalendarReport as xfail
- remove duplicate code from conftest
  • Loading branch information
gnzsnz authored May 29, 2024
1 parent 6341636 commit 624d248
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 410 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
- id: pylint
entry: pyenv which pylint
types: [python]
args: [-rn, --rcfile=pyproject.toml]
args: [-rn, --rcfile=pyproject.toml, tests/, ib_fundamental/]
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
Expand Down
2 changes: 2 additions & 0 deletions ib_fundamental/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ class Revenue:
as_of_date: datetime
report_type: str
period: str
currency: str
revenue: float


Expand All @@ -267,6 +268,7 @@ class EarningsPerShare:
as_of_date: datetime
report_type: str
period: str
currency: str
eps: float


Expand Down
6 changes: 6 additions & 0 deletions ib_fundamental/xml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,14 @@ def get_revenue(
"""Revenue"""
fa = "./TotalRevenues"
fs = self.xml_report.fin_summary.find(fa)
curr = fs.attrib["currency"]

_revenue = [
Revenue(
as_of_date=fromisoformat(tr.attrib["asofDate"]),
report_type=tr.attrib["reportType"],
period=tr.attrib["period"],
currency=curr,
revenue=float(tr.text),
)
for tr in fs
Expand All @@ -290,11 +293,14 @@ def get_eps(
"""Earnings per share"""
fa = "./EPSs"
fs = self.xml_report.fin_summary.find(fa)
curr = fs.attrib["currency"]

_eps = [
EarningsPerShare(
as_of_date=fromisoformat(tr.attrib["asofDate"]),
report_type=tr.attrib["reportType"],
period=tr.attrib["period"],
currency=curr,
eps=float(tr.text),
)
for tr in fs
Expand Down
172 changes: 172 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,23 @@
"""IB Fundamental test fixtures"""

import os
import typing
from itertools import chain, product, repeat

import pytest
from ib_async import IB

from ib_fundamental.fundamental import FundamentalData
from ib_fundamental.ib_client import IBClient
from ib_fundamental.objects import (
BalanceSheetStatement,
CashFlowStatement,
DividendPerShare,
EarningsPerShare,
IncomeStatement,
ReportType,
Revenue,
)
from ib_fundamental.xml_parser import XMLParser
from ib_fundamental.xml_report import XMLReport

Expand Down Expand Up @@ -60,6 +71,66 @@
"WMT",
]

# all but calendar
xml_report_attributes = [
"fin_statements",
"fin_summary",
"ownership",
"resc",
"snapshot",
]

statement = [CashFlowStatement, BalanceSheetStatement, IncomeStatement]
stament_period = ["Annual", "Interim"]
statement_period_type = ["annual", "quarter"]

financials = [DividendPerShare, Revenue, EarningsPerShare]
fin_attrs_parser = ["get_div_per_share", "get_revenue", "get_eps"]
statement_attrs_fun_data = [
"cashflow_annual",
"cashflow_quarter",
"balance_annual",
"balance_quarter",
"income_annual",
"income_quarter",
]
fin_attrs_fun = [
"div_ps_ttm",
"div_ps_q",
"revenue_ttm",
"revenue_q",
"eps_ttm",
"eps_q",
]

report_type = ["TTM", "R"]
period_type = ["12M", "3M"]
forward_year_attrs = ["fy_estimates", "fy_actuals"]
forward_year_parser_attrs = ["get_fy_estimates", "get_fy_actuals"]

req_report_type = list(typing.get_args(ReportType))

map_items_statement_type = ["CAS", "BAL", "INC"]


def repeat_each(iterable, n=2):
"""Repeat each element in *iterable* *n* times.
>>> list(repeat_each('ABC', 3))
['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C']
"""
return chain.from_iterable(map(repeat, iterable, repeat(n)))


def ncycles(iterable, n):
"""Returns the sequence elements *n* times
>>> list(ncycles(["a", "b"], 3))
['a', 'b', 'a', 'b', 'a', 'b']
"""
return chain.from_iterable(repeat(tuple(iterable), n))


@pytest.fixture(scope="session")
def tws_client():
Expand All @@ -83,6 +154,13 @@ def ib_client(tws_client, request):
del _ib_client


@pytest.fixture(scope="function", params=req_report_type)
def ib_client_req_fund(ib_client, request):
"""IBClient ib_req_fund combinations"""
_report = request.param
yield ib_client.ib_req_fund(_report)


@pytest.fixture(scope="class")
def xml_report(ib_client):
"""XMLReport fixture"""
Expand All @@ -92,6 +170,13 @@ def xml_report(ib_client):
del _xml_report


@pytest.fixture(scope="function", params=xml_report_attributes)
def xml_report_attrs(xml_report, request):
"""XMLReport get attributes"""
_attr = request.param
yield getattr(xml_report, _attr)


@pytest.fixture(scope="function")
def xml_parser(ib_client):
"""XMLReport fixture"""
Expand All @@ -101,6 +186,54 @@ def xml_parser(ib_client):
del _xml_parser


gen_parser_statement = (
(*_m, *_p)
for _m in zip(map_items_statement_type, statement)
for _p in zip(stament_period, statement_period_type)
)


@pytest.fixture(scope="function", params=gen_parser_statement)
def xml_parser_statement(xml_parser, request):
"""XMLParser statement combinations"""
_statement_type, _statement, _stament_period, _statement_period_type = request.param
_parser_statement = xml_parser.get_fin_statement(
_statement_type, _statement_period_type
)
yield _parser_statement, _statement, _stament_period


@pytest.fixture(scope="function", params=map_items_statement_type)
def xml_parser_map_items(xml_parser, request):
"""XMLParser map_items statement type combinations"""
_statement_type = request.param
_map_items = xml_parser.get_map_items(_statement_type)
yield _map_items, _statement_type


gen_parser_fin = (
(*_f, *_t)
for _f in zip(financials, fin_attrs_parser)
for _t in zip(report_type, period_type)
)


@pytest.fixture(scope="function", params=gen_parser_fin)
def xml_parser_financials(xml_parser, request):
"""XMLParser div ps, rev, eps combinations"""
_financials, _fin_attrs_parser, _report_type, _period_type = request.param
_fin = getattr(xml_parser, _fin_attrs_parser)(_report_type, _period_type)
yield _financials, _fin, _report_type, _period_type


@pytest.fixture(scope="function", params=forward_year_parser_attrs)
def xml_parser_fy(xml_parser, request):
"""XMLParser fy estimates and actuals"""
_attr = request.param
_fy = getattr(xml_parser, _attr)()
yield _fy


@pytest.fixture(scope="class", params=DJIA)
def fundamental_data(tws_client, request):
"""FundamentalData fixture"""
Expand All @@ -109,3 +242,42 @@ def fundamental_data(tws_client, request):
_fund = FundamentalData(symbol=_symbol, ib=_tws_client)
yield _fund
del _fund


gen_fund_statement = (
(a, *b)
for a, b in zip(statement_attrs_fun_data, product(statement, stament_period))
)


@pytest.fixture(scope="function", params=gen_fund_statement)
def fundamental_statement(fundamental_data, request):
"""FundamentalData statements"""
_attr, _class, _period = request.param
_statement = getattr(fundamental_data, _attr)
yield _statement, _class, _period


gen_financials = (
(_class, _attr, *_t)
for _class, _attr, _t in zip(
repeat_each(financials, 2),
fin_attrs_fun,
ncycles(zip(report_type, period_type), 3),
)
)


@pytest.fixture(scope="function", params=gen_financials)
def fundamental_financials(fundamental_data, request):
"""FundamentalData financials"""
_class, _attr, _report_type, _period_type = request.param
yield getattr(fundamental_data, _attr), _class, _report_type, _period_type


@pytest.fixture(scope="function", params=forward_year_attrs)
def fundamental_fy(fundamental_data, request):
"""FundamentalData fy estimates and actuals"""
_attr = request.param
_fy = getattr(fundamental_data, _attr)
yield _fy
Loading

0 comments on commit 624d248

Please sign in to comment.