Skip to content

Commit

Permalink
Initial implementation of reduced splitting area (#151)
Browse files Browse the repository at this point in the history
* half done adding left and right limits for splits

* implemented change in rust base package as suggested

* python intf and change in segmentation for sbs

* R intf

* typo in control.R

* reverted all changes apart from fix to sbs search which I think was a bug fix

* control.rs with forbidden segments and integrity check

* split candidates updated with mlondschien suggestion

* py intf

* two step search heuristic

* some testing

* more testing

* Update src/optimizer/optimizer.rs

Co-authored-by: Malte Londschien <[email protected]>

* Update src/optimizer/two_step_search.rs

Co-authored-by: Malte Londschien <[email protected]>

* pre-commit run -a.

* Add test.

* Changelog.

* Adjust definition of segments to (a, b].

* Prepare for 1.1.0 release.

* Thanks @enzbus.

---------

Co-authored-by: enzo.busseti <[email protected]>
Co-authored-by: Malte Londschien <[email protected]>
Co-authored-by: Malte Londschien <[email protected]>
  • Loading branch information
4 people authored Aug 1, 2023
1 parent 8810436 commit a53e571
Show file tree
Hide file tree
Showing 15 changed files with 200 additions and 12 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@

# Changelog

## 1.1.0 - (2023-08-01)

**New features**:

- New argument `forbidden_segments` (list or vector of 2-tuple) or `None` to `Control`. If not `None`, `changeforest` will not split on split points contained in segments `(a, b]` in `forbidden_segments` (rust and Python only). Thanks @enzbus!

## 1.0.1 - (2022-06-01)

**Bug fixes:**
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "changeforest"
description = "Random Forests for Change Point Detection"
authors = ["Malte Londschien <[email protected]>"]
repository = "https://github.com/mlondschien/changeforest/"
version = "1.0.1"
version = "1.1.0"
edition = "2021"
readme = "README.md"
license = "BSD-3-Clause"
Expand Down
2 changes: 1 addition & 1 deletion changeforest-py/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "changeforest_py"
version = "1.0.1"
version = "1.1.0"
edition = "2021"

[lib]
Expand Down
15 changes: 15 additions & 0 deletions changeforest-py/changeforest/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
random_forest_max_depth="default",
random_forest_max_features="default",
random_forest_n_jobs="default",
forbidden_segments="default",
):
self.minimal_relative_segment_length = _to_float(
minimal_relative_segment_length
Expand All @@ -32,6 +33,7 @@ def __init__(
self.random_forest_max_depth = _to_int(random_forest_max_depth)
self.random_forest_max_features = _to_int(random_forest_max_features)
self.random_forest_n_jobs = _to_int(random_forest_n_jobs)
self.forbidden_segments = _to_segments(forbidden_segments)


def _to_float(value):
Expand All @@ -50,3 +52,16 @@ def _to_int(value):
return value
else:
return int(value)


def _to_segments(value):
if (value is None) or isinstance(value, str):
return value
else:
try:
return [(int(el1), int(el2)) for (el1, el2) in value]
except Exception:
raise SyntaxError(
"forbidden_segments must be provided as [(a,b), ...] where a and b are "
"integers."
)
2 changes: 1 addition & 1 deletion changeforest-py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "changeforest"
description = "Random Forests for Change Point Detection"
readme = "README.md"
version = "1.0.1"
version = "1.1.0"
requires-python = ">=3.7"
author = "Malte Londschien <[email protected]>"
urls = {homepage = "https://github.com/mlondschien/changeforest/"}
Expand Down
6 changes: 6 additions & 0 deletions changeforest-py/src/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ pub fn control_from_pyobj(py: Python, obj: Option<PyObject>) -> PyResult<Control
control.random_forest_parameters.with_n_jobs(value);
}
};

if let Ok(pyvalue) = obj.getattr(py, "forbidden_segments") {
if let Ok(value) = pyvalue.extract::<Option<Vec<(usize, usize)>>>(py) {
control = control.with_forbidden_segments(value);
}
};
}

Ok(control)
Expand Down
40 changes: 40 additions & 0 deletions changeforest-py/tests/test_changeforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,43 @@ def test_changeforest_repr(iris_dataset):
°--(100, 150] 136 -2.398 0.875\
"""
)


def test_changeforest_repr_segments(iris_dataset):
result = changeforest(
iris_dataset,
"random_forest",
"bs",
control=Control(forbidden_segments=[(0, 49), (101, 120)]),
)
assert (
result.__repr__()
== """\
best_split max_gain p_value
(0, 150] 50 95.1 0.005
¦--(0, 50]
°--(50, 150] 100 52.799 0.005
¦--(50, 100] 53 6.892 0.315
°--(100, 150] 136 -3.516 0.68\
""" # noqa: W291
)


def test_changeforest_repr_segments2(iris_dataset):
result = changeforest(
iris_dataset,
"random_forest",
"bs",
control=Control(forbidden_segments=[(49, 101)]),
)
assert (
result.__repr__()
== """\
best_split max_gain p_value
(0, 150] 49 87.462 0.005
¦--(0, 49] 2 -8.889 0.995
°--(49, 150] 102 41.237 0.005
¦--(49, 102]
°--(102, 150] 136 1.114 0.36\
""" # noqa: W291
)
19 changes: 19 additions & 0 deletions changeforest-py/tests/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,22 @@ def test_control_defaults(iris_dataset, key, default_value, another_value):

assert str(result) == str(default_result)
assert str(result) != str(another_result)


def test_control_segments():
with pytest.raises(SyntaxError):
Control(
forbidden_segments=[
(2),
]
)

with pytest.raises(SyntaxError):
Control(
forbidden_segments=[
(2, 3, 4),
]
)

with pytest.raises(SyntaxError):
Control(forbidden_segments=[2, 3])
2 changes: 1 addition & 1 deletion changeforest-r/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: changeforest
Type: Package
Title: Random Forests for Change Point Detection
Version: 1.0.1
Version: 1.1.0
Author: Malte Londschien
Maintainer: Malte Londschien <[email protected]>
Description:
Expand Down
2 changes: 1 addition & 1 deletion changeforest-r/src/rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = 'changeforestr'
version = '1.0.1'
version = '1.1.0'
edition = '2021'

[lib]
Expand Down
19 changes: 19 additions & 0 deletions src/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub struct Control {
pub seed: u64,
/// Hyperparameters for random forests.
pub random_forest_parameters: RandomForestParameters,
/// Segments of indexes were no segmentation is allowed.
pub forbidden_segments: Option<Vec<(usize, usize)>>,
}

impl Control {
Expand All @@ -45,6 +47,7 @@ impl Control {
.with_max_depth(Some(8))
.with_max_features(MaxFeatures::Sqrt)
.with_n_jobs(Some(-1)),
forbidden_segments: None,
}
}

Expand Down Expand Up @@ -111,4 +114,20 @@ impl Control {
self.random_forest_parameters = random_forest_parameters;
self
}

pub fn with_forbidden_segments(
mut self,
forbidden_segments: Option<Vec<(usize, usize)>>,
) -> Self {
// check that segments are well specified
if let Some(ref _forbidden_segments) = forbidden_segments {
for el in _forbidden_segments.iter() {
if el.0 > el.1 {
panic!("Forbidden segments must be specified as [(a,b), ...] where a <= b!");
}
}
}
self.forbidden_segments = forbidden_segments;
self
}
}
33 changes: 33 additions & 0 deletions src/optimizer/grid_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,37 @@ mod tests {
expected
);
}

#[rstest]
#[case(0, 10, Some(vec![(0, 3)]), 0.09, vec![4, 5, 6, 7, 8])]
#[case(1, 10, Some(vec![(6, 10)]), 0.15, vec![3, 4, 5, 6])]
#[case(0, 10, Some(vec![(2, 4), (5, 7)]), 0.09, vec![1, 2, 5, 8])]
#[case(1, 7, Some(vec![(2, 4), (5, 7)]), 0.09, vec![2, 5])]
fn test_split_candidates(
#[case] start: usize,
#[case] stop: usize,
#[case] forbidden_segments: Option<Vec<(usize, usize)>>,
#[case] delta: f64,
#[case] expected: Vec<usize>,
) {
let X = ndarray::array![
[0.0],
[0.0],
[0.0],
[0.0],
[-0.0],
[-0.0],
[-0.0],
[-0.0],
[-0.0],
[-0.0]
];
let X_view = X.view();
let control = Control::default()
.with_minimal_relative_segment_length(delta)
.with_forbidden_segments(forbidden_segments);
let gain = testing::ChangeInMean::new(&X_view, &control);
let grid_search = GridSearch { gain };
assert_eq!(grid_search.split_candidates(start, stop).unwrap(), expected);
}
}
17 changes: 16 additions & 1 deletion src/optimizer/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,22 @@ pub trait Optimizer {
if 2 * minimal_segment_length >= (stop - start) {
Err("Segment too small.")
} else {
Ok(((start + minimal_segment_length)..(stop - minimal_segment_length)).collect())
let mut split_candidates: Vec<usize> =
((start + minimal_segment_length)..(stop - minimal_segment_length)).collect();

if let Some(forbidden_segments) = &self.control().forbidden_segments {
split_candidates.retain(|x| {
forbidden_segments
.iter()
.all(|segment| x <= &segment.0 || x > &segment.1)
});
}

if split_candidates.is_empty() {
Err("No split_candidates left after filtering out forbidden_segments.")
} else {
Ok(split_candidates)
}
}
}
}
43 changes: 38 additions & 5 deletions src/optimizer/two_step_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,46 @@ where
fn find_best_split(&self, start: usize, stop: usize) -> Result<OptimizerResult, &str> {
let split_candidates = self.split_candidates(start, stop)?;

let guesses = vec![
(3 * start + stop) / 4,
(start + stop) / 2,
(start + 3 * stop) / 4,
];
let mut guesses = vec![];
let mut results: Vec<GainResult> = vec![];

// if there are forbidden segments change the heuristics
// pick middle element of split_candidates, 1/4th and 3/4th
if let Some(_forbidden_segments) = &self.control().forbidden_segments {
// there is at least one element in split_candidates
guesses.push(
split_candidates
.clone()
.into_iter()
.nth(split_candidates.len() / 4)
.unwrap(),
);

// we add this if it is not equal to last
let cand = split_candidates
.clone()
.into_iter()
.nth(split_candidates.len() / 2)
.unwrap();
if cand > guesses[guesses.len() - 1] {
guesses.push(cand)
};

// same
let cand = split_candidates
.clone()
.into_iter()
.nth(3 * split_candidates.len() / 4)
.unwrap();
if cand > guesses[guesses.len() - 1] {
guesses.push(cand)
};
} else {
guesses.push((3 * start + stop) / 4);
guesses.push((start + stop) / 2);
guesses.push((start + 3 * stop) / 4);
}

// Don't use first and last guess if stop - start / 4 < delta.
for guess in guesses.iter().filter(|x| split_candidates.contains(x)) {
results.push(self._single_find_best_split(start, stop, *guess, &split_candidates));
Expand Down
4 changes: 3 additions & 1 deletion src/segmentation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ impl<'a> Segmentation<'a> {
// start + segment_length > n through floating point errors in
// n_segments, e.g. for n = 20'000, alpha_k = 1/sqrt(2), k=6
stop = (start + (segment_length as f32).ceil() as usize).min(optimizer.n());
segments.push(optimizer.find_best_split(start, stop).unwrap());
if let Ok(optimizer_result) = optimizer.find_best_split(start, stop) {
segments.push(optimizer_result)
}
}
}
}
Expand Down

0 comments on commit a53e571

Please sign in to comment.