Skip to content

Commit

Permalink
Merge pull request #165 from siapy/fix
Browse files Browse the repository at this point in the history
fix: Update n_jobs default value to use all processors
  • Loading branch information
janezlapajne authored Oct 10, 2024
2 parents 2ea348b + e768b93 commit 226eb9e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion siapy/optimizers/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OptimizeStudyConfig(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
n_trials: int | None = None
timeout: float | None = None
n_jobs: int = 1
n_jobs: int = -1
catch: Iterable[type[Exception]] | type[Exception] = ()
callbacks: (
list[Callable[[optuna.study.Study, optuna.trial.FrozenTrial], None]] | None
Expand Down
8 changes: 6 additions & 2 deletions siapy/optimizers/scorers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Iterable, Literal
from typing import Annotated, Iterable, Literal

import numpy as np
from sklearn import model_selection
Expand Down Expand Up @@ -38,6 +38,10 @@ def init_cross_validator_scorer(
| Iterable
| Literal["RepeatedKFold", "RepeatedStratifiedKFold"]
| None = None,
n_jobs: Annotated[
int | None,
"Number of jobs to run in parallel. `-1` means using all processors.",
] = None,
) -> "Scorer":
if isinstance(cv, str) and cv in ["RepeatedKFold", "RepeatedStratifiedKFold"]:
cv = initialize_object(
Expand All @@ -52,7 +56,7 @@ def init_cross_validator_scorer(
scoring=scoring,
cv=cv, # type: ignore
groups=None,
n_jobs=1,
n_jobs=n_jobs,
verbose=0,
fit_params=None,
pre_dispatch=1,
Expand Down
2 changes: 1 addition & 1 deletion tests/optimizers/test_optimizers_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_optimize_study_config_defaults():
config = OptimizeStudyConfig()
assert config.n_trials is None
assert config.timeout is None
assert config.n_jobs == 1
assert config.n_jobs == -1
assert config.catch == ()
assert config.callbacks is None
assert config.gc_after_trial is False
Expand Down

0 comments on commit 226eb9e

Please sign in to comment.