Use slots for large models #2673

Merged · 7 commits · Dec 19, 2024
Changes from 1 commit
1 change: 1 addition & 0 deletions warehouse/metrics_tools/compute/app.py
@@ -91,6 +91,7 @@ async def initialize_app(app: FastAPI):
     cluster_spec = make_new_cluster_with_defaults(config)
     cluster_factory = KubeClusterFactory(
         config.cluster_namespace,
+        config.worker_resources,
         cluster_spec=cluster_spec,
         shutdown_on_close=not config.debug_cluster_no_shutdown,
     )
9 changes: 8 additions & 1 deletion warehouse/metrics_tools/compute/client.py
@@ -130,6 +130,8 @@ def calculate_metrics(
         cluster_min_size: int = 6,
         cluster_max_size: int = 6,
         job_retries: int = 3,
+        slots: int = 1,
+        execution_time: t.Optional[datetime] = None,
     ):
         """Calculate metrics for a given period and write the results to a gcs
         folder. This method is a high level method that triggers all of the
@@ -151,6 +153,7 @@ def calculate_metrics(
             locals (t.Dict[str, t.Any]): The local variables to use
             dependent_tables_map (t.Dict[str, str]): The dependent tables map
             job_retries (int): The number of retries for a given job in the worker queue. Defaults to 3.
+            execution_time (t.Optional[datetime]): The execution time for the job

         Returns:
             ExportReference: The export reference for the resulting calculation
@@ -172,6 +175,8 @@ def calculate_metrics(
             locals,
             dependent_tables_map,
             job_retries,
+            slots=slots,
+            execution_time=execution_time,
         )
         job_id = job_response.job_id
         export_reference = job_response.export_reference
@@ -240,6 +245,8 @@ def submit_job(
         locals: t.Dict[str, t.Any],
         dependent_tables_map: t.Dict[str, str],
         job_retries: t.Optional[int] = None,
+        slots: int = 1,
+        execution_time: t.Optional[datetime] = None,
     ):
         """Submit a job to the metrics calculation service

@@ -269,7 +276,7 @@ def submit_job(
             locals=locals,
             dependent_tables_map=dependent_tables_map,
             retries=job_retries,
-            execution_time=datetime.now(),
+            execution_time=execution_time or datetime.now(),
         )
         job_response = self.service_post_with_input(
             JobSubmitResponse, "/job/submit", request
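
Note: the behavioral change in this file is that submit_job used to stamp every request with datetime.now(); it now honors a caller-supplied execution_time and only falls back to the current time. A minimal, runnable sketch of that fallback pattern (the helper name is illustrative, not from the codebase):

    from datetime import datetime
    from typing import Optional

    # Illustrative helper: an explicit execution_time wins; otherwise fall
    # back to "now", matching `execution_time or datetime.now()` in submit_job.
    def resolve_execution_time(execution_time: Optional[datetime] = None) -> datetime:
        return execution_time or datetime.now()

    print(resolve_execution_time())                        # old behavior: stamps "now"
    print(resolve_execution_time(datetime(2024, 12, 19)))  # new: pinned, reproducible runs
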
19 changes: 13 additions & 6 deletions warehouse/metrics_tools/compute/cluster.py
@@ -46,6 +46,7 @@ def start_duckdb_cluster(

 async def start_duckdb_cluster_async(
     namespace: str,
+    resources: t.Dict[str, str],
     cluster_spec: t.Optional[dict] = None,
     min_size: int = 6,
     max_size: int = 6,
@@ -55,7 +56,10 @@ async def start_duckdb_cluster_async(
     a thread. The "async" version of dask's KubeCluster doesn't work as
     expected. So for now we do this."""

-    options: t.Dict[str, t.Any] = {"namespace": namespace}
+    options: t.Dict[str, t.Any] = {
+        "namespace": namespace,
+        "resources": resources,
+    }
     options.update(kwargs)
     if cluster_spec:
         options["custom_cluster_spec"] = cluster_spec
@@ -69,10 +73,6 @@ async def start_duckdb_cluster_async(
     await adapt_response
     return cluster

-    # return await asyncio.to_thread(
-    #     start_duckdb_cluster, namespace, cluster_spec, min_size, max_size
-    # )
-

 class ClusterProxy(abc.ABC):
     async def client(self) -> Client:
@@ -163,18 +163,25 @@ class KubeClusterFactory(ClusterFactory):
     def __init__(
         self,
         namespace: str,
+        resources: t.Dict[str, str],
         cluster_spec: t.Optional[dict] = None,
         log_override: t.Optional[logging.Logger] = None,
         **kwargs: t.Any,
     ):
         self._namespace = namespace
         self.logger = log_override or logger
         self._cluster_spec = cluster_spec
+        self._resources = resources
         self.kwargs = kwargs

     async def create_cluster(self, min_size: int, max_size: int):
         cluster = await start_duckdb_cluster_async(
-            self._namespace, self._cluster_spec, min_size, max_size, **self.kwargs
+            self._namespace,
+            self._resources,
+            self._cluster_spec,
+            min_size,
+            max_size,
+            **self.kwargs,
         )
         return KubeClusterProxy(cluster)
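
Note: the docstring (and the dead commented-out code deleted above) describe creating the cluster from a thread because the async KubeCluster misbehaves. A minimal sketch of that general pattern, with a placeholder standing in for the blocking KubeCluster(**options) constructor:

    import asyncio

    # Placeholder for the blocking KubeCluster(**options) call; the real
    # options dict carries the namespace, resources, and cluster spec.
    def make_cluster_blocking(options: dict) -> str:
        return f"cluster in {options['namespace']}"

    async def main() -> None:
        # Run the blocking constructor in a worker thread so the event
        # loop stays responsive while the cluster objects are created.
        cluster = await asyncio.to_thread(make_cluster_blocking, {"namespace": "metrics"})
        print(cluster)

    asyncio.run(main())
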
1 change: 1 addition & 0 deletions warehouse/metrics_tools/compute/debug.py
@@ -17,6 +17,7 @@ def async_test_setup_cluster(config: AppConfig):

     cluster_factory = KubeClusterFactory(
         config.cluster_namespace,
+        config.worker_resources,
         cluster_spec=cluster_spec,
         log_override=logger,
     )
3 changes: 3 additions & 0 deletions warehouse/metrics_tools/compute/service.py
@@ -235,6 +235,7 @@ async def _batch_query_to_scheduler(
                 task_id,
                 result_path,
                 batch,
+                input.slots,
                 exported_dependent_tables_map,
                 retries=3,
             )
@@ -251,6 +252,7 @@ async def _submit_query_task_to_scheduler(
         task_id: str,
         result_path: str,
         batch: t.List[str],
+        slots: int,
         exported_dependent_tables_map: t.Dict[str, ExportReference],
         retries: int,
     ):
@@ -266,6 +268,7 @@ async def _submit_query_task_to_scheduler(
             exported_dependent_tables_map,
             retries=retries,
             key=task_id,
+            resources={"slots": slots},
         )

     try:
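
Note: `resources={"slots": slots}` hands the slot count to Dask's abstract worker-resources mechanism: a task is only scheduled on a worker that has that many unclaimed slots. A runnable local sketch of the mechanism (the LocalCluster setup is illustrative; the real service runs on KubeCluster):

    from dask.distributed import Client, LocalCluster

    # Each worker advertises 32 abstract "slots", mirroring the
    # worker_resources default added to ClusterConfig below.
    cluster = LocalCluster(n_workers=2, resources={"slots": 32})
    client = Client(cluster)

    def heavy_query(batch: list) -> int:
        return len(batch)

    # As in _submit_query_task_to_scheduler: this task is only placed on
    # a worker with 8 free slots, so at most four such tasks can share
    # one 32-slot worker at a time.
    future = client.submit(heavy_query, ["q1", "q2"], resources={"slots": 8})
    print(future.result())
    client.close()
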
2 changes: 2 additions & 0 deletions warehouse/metrics_tools/compute/types.py
@@ -182,6 +182,7 @@ class JobSubmitRequest(BaseModel):
     locals: t.Dict[str, t.Any]
     dependent_tables_map: t.Dict[str, str]
     retries: t.Optional[int] = None
+    slots: int = 1
     execution_time: datetime

     def query_as(self, dialect: str) -> str:
@@ -420,6 +421,7 @@ class ClusterConfig(BaseSettings):
     scheduler_memory_request: str = "85000Mi"
     scheduler_pool_type: str = "sqlmesh-scheduler"

+    worker_resources: t.Dict[str, str] = Field(default_factory=lambda: {"slots": "32"})
     worker_threads: int = 16
     worker_memory_limit: str = "90000Mi"
     worker_memory_request: str = "85000Mi"
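
Note: with the default worker_resources of {"slots": "32"}, the arithmetic is straightforward: a job submitted with slots=8 can run at most 32 // 8 = 4 tasks concurrently per worker, while slots=32 dedicates a whole worker to each task. The slots: int = 1 default on JobSubmitRequest leaves small jobs effectively unconstrained.
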
10 changes: 10 additions & 0 deletions warehouse/metrics_tools/definition.py
@@ -23,6 +23,14 @@ class RollingConfig(t.TypedDict):
     unit: str
     cron: RollingCronOptions

+    # How many days do we process at once. This is useful to set for very large
+    # datasets but will default to a year if not set.
+    model_batch_size: t.NotRequired[int]
+
+    # The number of required slots for a given model. This is also very useful
+    # for large datasets
+    slots: t.NotRequired[int]
+

 class TimeseriesBucket(Enum):
     HOUR = "hour"
@@ -78,6 +86,8 @@ class PeerMetricDependencyRef(t.TypedDict):
     unit: t.NotRequired[t.Optional[str]]
     time_aggregation: t.NotRequired[t.Optional[str]]
     cron: t.NotRequired[RollingCronOptions]
+    batch_size: t.NotRequired[int]
+    slots: t.NotRequired[int]

 class MetricModelRef(t.TypedDict):
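
Note: a hypothetical model configuration using the two new fields. The TypedDict is restated in reduced, runnable form; only model_batch_size and slots come from this diff, the remaining fields and values are assumed for illustration:

    import typing as t

    # Reduced restatement of RollingConfig for illustration (requires
    # Python 3.11+ for t.NotRequired).
    class RollingConfig(t.TypedDict):
        unit: str
        cron: str  # stand-in for RollingCronOptions
        model_batch_size: t.NotRequired[int]
        slots: t.NotRequired[int]

    # A large-dataset model: smaller batches, more slots per task.
    config: RollingConfig = {
        "unit": "day",
        "cron": "@daily",
        "model_batch_size": 90,  # process 90 days per batch instead of the 365-day default
        "slots": 16,             # reserve half of a default 32-slot worker per task
    }
    print(config)
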
17 changes: 13 additions & 4 deletions warehouse/metrics_tools/factory/factory.py
@@ -415,7 +415,12 @@ def generate_rolling_python_model_for_rendered_query(

     columns = METRICS_COLUMNS_BY_ENTITY[ref["entity_type"]]

-    kind_common = {"batch_size": 365, "batch_concurrency": 1}
+    kind_common = {
+        "batch_size": ref.get("batch_size", 365),
+        "batch_concurrency": 1,
+        "lookback": 10,
+        "forward_only": True,
+    }
     partitioned_by = ("day(metrics_sample_date)",)
     window = ref.get("window")
     assert window is not None
@@ -468,12 +473,15 @@ def generate_time_aggregation_model_for_rendered_query(
     time_aggregation = ref.get("time_aggregation")
     assert time_aggregation is not None

-    kind_common = {"batch_concurrency": 1}
-    kind_options = {"lookback": 7, **kind_common}
+    kind_common = {
+        "batch_concurrency": 1,
+        "forward_only": True,
+    }
+    kind_options = {"lookback": 10, **kind_common}
     partitioned_by = ("day(metrics_sample_date)",)

     if time_aggregation == "weekly":
-        kind_options = {"lookback": 7, **kind_common}
+        kind_options = {"lookback": 10, **kind_common}
     if time_aggregation == "monthly":
         kind_options = {"lookback": 1, **kind_common}
         partitioned_by = ("month(metrics_sample_date)",)
@@ -680,6 +688,7 @@ def generated_rolling_query(
         batch_size=env.ensure_int("SQLMESH_MCS_BATCH_SIZE", 10),
         columns=columns,
         ref=ref,
+        slots=ref.get("slots", 1),
         locals=sqlmesh_vars,
         dependent_tables_map=create_dependent_tables_map(
             context, rendered_query_str
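
Note: the expanded kind_common dicts map onto SQLMesh incremental-model options. An annotated gloss of the rolling-model variant (values mirror the diff; the comments state standard SQLMesh semantics, and the wiring into the model kind is elided here):

    kind_common = {
        "batch_size": 365,       # intervals per backfill batch; now overridable via ref["batch_size"]
        "batch_concurrency": 1,  # run one batch at a time
        "lookback": 10,          # also reprocess the trailing 10 intervals to absorb late-arriving data
        "forward_only": True,    # model edits apply going forward instead of triggering full rebuilds
    }
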