Skip to content

Commit

Permalink
Add .all() methods for pulling everything out of the database
Browse files Browse the repository at this point in the history
  • Loading branch information
palewire committed Sep 1, 2024
1 parent 989afae commit 83fb3ca
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/continuous-deployment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ jobs:
- id: build
name: Build release
run: |
pipenv run python cpi/download.py
pipenv run python -c 'import cpi'
pipenv run python setup.py sdist
pipenv run python setup.py bdist_wheel
ls -l dist
Expand Down
6 changes: 6 additions & 0 deletions cpi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@
"CPI data is out of date. To accurately inflate to today's dollars, you must run `cpi.update()`."
)

# Create aliases for accessing the other data tables
areas = models.Area
periods = models.Period
periodicities = models.Periodicity
items = models.Item


def get(
year_or_month,
Expand Down
31 changes: 25 additions & 6 deletions cpi/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,27 @@ def rm(self) -> None:
"""Remove any existing files."""
db_path = self.THIS_DIR / "cpi.db"
if db_path.exists():
logger.debug(f"Deleting {db_path}")
db_path.unlink()
logger.debug("Clearing database")
# Drop all tables in the database
conn = self.get_db_conn()
table_list = [
"areas",
"items",
"periods",
"periodicities",
"series",
"indexes",
]
for t in table_list:
conn.execute(f"DROP TABLE IF EXISTS '{t}';")
conn.close()
self.vaccum()

def vaccum(self) -> None:
    """Vacuum the database.

    Opens a connection with ``self.get_db_conn()``, issues a ``VACUUM;``
    statement and then closes the connection. The connection is closed
    even if the statement raises, so a failed vacuum does not leak it.

    The method name keeps its existing spelling because other methods
    call ``self.vaccum()``.
    """
    conn = self.get_db_conn()
    try:
        conn.execute("VACUUM;")
    finally:
        # Release the connection on both the success and the error path.
        conn.close()

def update(self) -> None:
"""Update the Consumer Price Index dataset that powers this library."""
Expand Down Expand Up @@ -94,7 +113,7 @@ def process_files(self) -> None:
series.to_sql("series", conn, if_exists="replace", index=False)

index = parsers.ParseIndex().get_df()
index.to_sql("index", conn, if_exists="replace", index=False)
index.to_sql("indexes", conn, if_exists="replace", index=False)

conn.close()

Expand Down Expand Up @@ -127,12 +146,12 @@ def drop_file_list(self, file_list: typing.List[str]) -> None:
logger.debug(f"- {name}")
conn.execute(f"DROP TABLE '{name}';")

# Clear space
conn.execute("VACUUM;")

# Close the connection
conn.close()

# Vaccum the database
self.vaccum()

def get_df(self, file: str) -> pd.DataFrame:
"""Download TSV file from the BLS."""
# Download it
Expand Down
21 changes: 17 additions & 4 deletions cpi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
logger.addHandler(logging.NullHandler())


def query(sql: str, params: list | tuple | None) -> list[dict]:
def query(sql: str, params: list | tuple | None = None) -> list[dict]:
"""Query the cpi.db database and return the result.
Args:
Expand Down Expand Up @@ -106,6 +106,12 @@ def get_by_name(cls, value: str):
d = queryone(f"SELECT * from '{cls.table_name}' WHERE name=?", (value,))
return cls(**d)

@classmethod
def all(cls):
    """Return a list with one instance of this class per row in its table.

    Every column of each row returned by the query is passed to the
    class constructor as a keyword argument.
    """
    sql = f"SELECT * FROM '{cls.table_name}'"
    return [cls(**row) for row in query(sql)]


class Area(BaseObject):
"""A geographical area where prices are gathered monthly."""
Expand Down Expand Up @@ -312,7 +318,7 @@ def get_by_id(cls, value: str):
d["items"] = Item.get_by_id(d["items"])

# Get the indexes
dict_list = query("SELECT * FROM 'index' WHERE series=?", (value,))
dict_list = query("SELECT * FROM 'indexes' WHERE series=?", (value,))
d["indexes"] = []
for i in dict_list:
obj = Index(
Expand Down Expand Up @@ -377,6 +383,14 @@ def get_by_id(self, value) -> Series:
# Return it
return obj

def all(self) -> list[Series]:
    """Return every series in the database as a list of Series objects."""
    # Pull just the ids; the full objects are built by get_by_id below.
    id_rows = query("SELECT id FROM 'series';")

    # Route each id through get_by_id so that, per the original author's
    # note, every series ends up loaded in the cache.
    loaded = []
    for row in id_rows:
        loaded.append(self.get_by_id(row["id"]))
    return loaded

def get(
self,
survey=DEFAULTS_SERIES_ATTRS["survey"],
Expand All @@ -385,8 +399,7 @@ def get(
area=DEFAULTS_SERIES_ATTRS["area"],
items=DEFAULTS_SERIES_ATTRS["items"],
) -> Series:
"""
Returns a single CPI Series object based on the input.
"""Returns a single CPI Series object based on the input.
The default series is returned if not configuration is made to the keyword arguments.
"""
Expand Down
11 changes: 10 additions & 1 deletion tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,17 @@ def test_get_by_series_id(self):
def test_series_list(self):
    """Smoke test: fetching a series by a known id must not raise."""
    series_id = "CUSR0000SA0"
    cpi.series.get_by_id(series_id)

def test_metadata_lists(self):
    """Each metadata alias on ``cpi`` must return a non-empty list from ``.all()``."""
    # Checked in the same order as before; the first failure stops the test.
    for alias in ("areas", "periods", "periodicities", "items"):
        records = getattr(cpi, alias).all()
        self.assertTrue(len(records) > 0)

def test_series_indexes(self):
for series in cpi.series:
# Make sure we can lazy load the full database
series_list = cpi.series.all()
self.assertTrue(len(series_list) > 1)
for series in series_list:
self.assertTrue(len(series.indexes) > 0)
series.latest_month
series.latest_year
Expand Down

0 comments on commit 83fb3ca

Please sign in to comment.