Skip to content

Commit

Permalink
Merge branch '1.1.x' into dev/1.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
j-ittner committed Mar 28, 2021
2 parents 18aaf94 + f2aeeff commit 19ec93b
Showing 1 changed file with 39 additions and 22 deletions.
61 changes: 39 additions & 22 deletions src/facet/crossfit/_crossfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ class _FitScoreParameters(NamedTuple):
train_weight: Optional[pd.Series]

# score parameters
scorer: Optional[Scorer]
score_train_split: bool
test_features: Optional[pd.DataFrame]
test_target: Union[pd.Series, pd.DataFrame, None]
test_weight: Optional[pd.Series]
scorer: Optional[Scorer] = None
score_train_split: bool = False
test_features: Optional[pd.DataFrame] = None
test_target: Union[pd.Series, pd.DataFrame, None] = None
test_weight: Optional[pd.Series] = None


@inheritdoc(match="[see superclass]")
Expand Down Expand Up @@ -389,22 +389,15 @@ def _fit_score_queue(

global_fit: Optional[Job[FitResult]]
if do_fit:

class _FitModelOnFullData(Job[FitResult]):
# noinspection PyMissingOrEmptyDocstring
def run(self) -> FitResult:
if sample_weight is None:
pipeline.fit(X=features, y=target, **fit_params)
else:
pipeline.fit(
X=features,
y=target,
sample_weight=sample_weight,
**fit_params,
)
return (pipeline, None)

global_fit = _FitModelOnFullData()
global_fit = _FitModelOnFullData(
parameters=_FitScoreParameters(
pipeline=pipeline,
train_features=features,
train_target=target,
train_weight=sample_weight,
),
fit_params=fit_params,
)
else:
global_fit = None

Expand Down Expand Up @@ -519,13 +512,15 @@ def __len__(self) -> int:
return self.n_splits_


class _FitAndScoreModelForSplit(Job[FitResult]):
class _BaseFitAndScore(Job[FitResult], metaclass=ABCMeta):
def __init__(
self, parameters: _FitScoreParameters, fit_params: Dict[str, Any]
) -> None:
self.parameters = parameters
self.fit_params = fit_params


class _FitAndScoreModelForSplit(_BaseFitAndScore):
def run(self) -> FitResult:
"""
Fit and/or score a learner pipeline.
Expand Down Expand Up @@ -578,4 +573,26 @@ def run(self) -> FitResult:
return pipeline if do_fit else None, score


class _FitModelOnFullData(_BaseFitAndScore):
# noinspection PyMissingOrEmptyDocstring
def run(self) -> FitResult:
parameters = self.parameters
pipeline = parameters.pipeline

if parameters.train_target is None:
pipeline.fit(
X=parameters.train_features,
y=parameters.train_target,
**self.fit_params,
)
else:
pipeline.fit(
X=parameters.train_features,
y=parameters.train_target,
sample_weight=parameters.train_weight,
**self.fit_params,
)
return (pipeline, None)


__tracker.validate()

0 comments on commit 19ec93b

Please sign in to comment.