Use polars and do an async write to gcs #2672

Merged: 1 commit, Dec 19, 2024
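At a glance, this PR swaps the worker's result-upload path: instead of collecting pandas DataFrames, serializing them into an in-memory io.BytesIO parquet buffer, and uploading that buffer with the google-cloud-storage client, the worker now collects polars DataFrames and writes the parquet output straight to GCS through gcsfs. A minimal sketch of the two write paths, using placeholder bucket and object names rather than the worker's real configuration:

# Sketch only: "example-bucket" and "results/out.parquet" are placeholders.
import io

import gcsfs
import pandas as pd
import polars as pl
from google.cloud import storage


def old_write(df: pd.DataFrame) -> None:
    # Before: serialize the whole DataFrame into an in-memory parquet buffer,
    # then upload that buffer with the google-cloud-storage client.
    buf = io.BytesIO()
    df.to_parquet(buf)
    buf.seek(0)
    blob = storage.Client().bucket("example-bucket").blob("results/out.parquet")
    blob.upload_from_file(buf)


def new_write(df: pl.DataFrame) -> None:
    # After: open a writable GCS file object via gcsfs and let polars write
    # parquet directly into it, so no separate upload step is needed.
    fs = gcsfs.GCSFileSystem()
    with fs.open("example-bucket/results/out.parquet", "wb") as f:
        df.write_parquet(f)

gcsfs performs the actual GCS calls on its internal asyncio event loop, which is presumably the "async write" the PR title refers to.
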
31 changes: 18 additions & 13 deletions warehouse/metrics_tools/compute/worker.py
@@ -1,5 +1,4 @@
 # The worker initialization
-import io
 import logging
 import os
 import time
@@ -9,7 +8,8 @@
 from threading import Lock
 
 import duckdb
-import pandas as pd
+import gcsfs
+import polars as pl
 from dask.distributed import Worker, WorkerPlugin, get_worker
 from google.cloud import storage
 from metrics_tools.compute.types import ExportReference, ExportType
@@ -72,6 +72,7 @@ def __init__(
         self._gcs_secret = gcs_secret
         self._duckdb_path = duckdb_path
         self._conn = None
+        self._fs = None
         self._cache_status: t.Dict[str, bool] = {}
         self._catalog = None
         self._mode = "duckdb"
@@ -93,6 +94,7 @@ def setup(self, worker: Worker):
         );
         """
         self._conn.sql(sql)
+        self._fs = gcsfs.GCSFileSystem()
 
     def teardown(self, worker: Worker):
         if self._conn:
@@ -170,6 +172,11 @@ def upload_to_gcs_bucket(self, blob_path: str, file: t.IO):
         blob = bucket.blob(blob_path)
         blob.upload_from_file(file)
 
+    @property
+    def fs(self):
+        assert self._fs is not None, "GCSFS not initialized"
+        return self._fs
+
     def handle_query(
         self,
         job_id: str,
@@ -182,30 +189,28 @@ def handle_query(
 
         This executes the query with duckdb and writes the results to a gcs path.
         """
 
         for ref, actual in dependencies.items():
             self.logger.info(
                 f"job[{job_id}][{task_id}] Loading cache for {ref}:{actual}"
             )
             self.get_for_cache(ref, actual)
         conn = self.connection
-        results: t.List[pd.DataFrame] = []
+        results: t.List[pl.DataFrame] = []
         for query in queries:
             self.logger.info(f"job[{job_id}][{task_id}]: Executing query {query}")
-            result = conn.execute(query).df()
+            result = conn.execute(query).pl()
             results.append(result)
         # Concatenate the results
         self.logger.info(f"job[{job_id}][{task_id}]: Concatenating results")
-        results_df = pd.concat(results)
+        results_df = pl.concat(results)
 
-        # Export the results to a parquet file in memory
-        self.logger.info(f"job[{job_id}][{task_id}]: Writing to in memory parquet")
-        inmem_file = io.BytesIO()
-        results_df.to_parquet(inmem_file)
-        inmem_file.seek(0)
 
         # Upload the parquet to gcs
-        self.logger.info(f"job[{job_id}][{task_id}]: Uploading to gcs {result_path}")
-        self.upload_to_gcs_bucket(result_path, inmem_file)
+        self.logger.info(
+            f"job[{job_id}][{task_id}]: Uploading to gcs {result_path} with polars"
+        )
+        with self.fs.open(f"{self._gcs_bucket}/{result_path}", "wb") as f:
+            results_df.write_parquet(f)  # type: ignore
         return task_id
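
The plugin changes above follow a create-once-per-worker pattern: the gcsfs handle is built in setup() and exposed through a property that fails loudly if setup has not run. A stripped-down sketch of that pattern (the class and names are illustrative, not the PR's full plugin):

import typing as t

import gcsfs
from dask.distributed import Worker, WorkerPlugin


class GCSWriterPlugin(WorkerPlugin):
    """Illustrative plugin: one GCS filesystem client per Dask worker."""

    def __init__(self) -> None:
        self._fs: t.Optional[gcsfs.GCSFileSystem] = None

    def setup(self, worker: Worker) -> None:
        # Runs once when the plugin is registered on a worker.
        self._fs = gcsfs.GCSFileSystem()

    def teardown(self, worker: Worker) -> None:
        self._fs = None

    @property
    def fs(self) -> gcsfs.GCSFileSystem:
        # Guard against use before setup() has run.
        assert self._fs is not None, "GCSFS not initialized"
        return self._fs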


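And the rewritten result path in handle_query, reduced to its essentials: each duckdb result is fetched as a polars frame via .pl(), the frames are concatenated with pl.concat, and the parquet output is streamed to GCS rather than staged in an io.BytesIO buffer first. The queries, bucket, and object path below are placeholders, not values from this PR:

import duckdb
import gcsfs
import polars as pl

conn = duckdb.connect()
queries = ["SELECT 1 AS metric", "SELECT 2 AS metric"]  # placeholder queries

# duckdb can hand each result back as a polars DataFrame.
results = [conn.execute(query).pl() for query in queries]
results_df = pl.concat(results)

# Stream the parquet bytes straight to GCS through a gcsfs file handle.
fs = gcsfs.GCSFileSystem()
with fs.open("example-bucket/results/out.parquet", "wb") as f:
    results_df.write_parquet(f)

Compared to the old path, no second serialized copy of the results is held in memory before upload, and the worker's pandas dependency drops out entirely.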