Skip to content

Commit

Permalink
Bugfix plot() for change_in_mean and sbs, wbs. (#119)
Browse files Browse the repository at this point in the history
* Fix plotting of change_in_mean.

* Fix plotting for sbs, wbs.

* 0.7.2.

* Add matplotlib to dependencies.

* Correct red line.

* Install matplotlib in pyproject.toml

* Also install matplotlib for python tests.
  • Loading branch information
mlondschien authored May 9, 2022
1 parent f391f46 commit b5ea904
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
# https://github.com/docker/setup-qemu-action/issues/22 and
# https://github.com/pypa/cibuildwheel/issues/598
- name: Set up QEMU
if: runner.os == 'Linux'
if: runner.os == 'Linux' && matrix.vers == 'aarch64'
uses: docker/setup-qemu-action@v2
with:
platforms: arm64
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ jobs:
- name: Install wheel
run: |
pip install numpy pytest
pip install numpy matplotlib pytest
pip install --force-reinstall --no-index --find-links changeforest-py/target/wheels/ changeforest
- name: Run tests
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@

# Changelog

## 0.7.2 - (2022-05-09)

**Bug fixes:**

- Fixed bugs when plotting results created with `method="change_in_mean"` or `segmentation="sbs"` or `"wbs"` (Python).

## 0.7.1 - (2022-05-02)

**Bug fixes:**
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "changeforest"
description = "Classifier based non-parametric change point detection."
authors = ["Malte Londschien <[email protected]>"]
repository = "https://github.com/mlondschien/changeforest/"
version = "0.7.1"
version = "0.7.2"
edition = "2021"
readme = "README.md"
license = "BSD-3-Clause"
Expand Down
2 changes: 1 addition & 1 deletion changeforest-py/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "changeforest_py"
version = "0.7.1"
version = "0.7.2"
edition = "2021"

[lib]
Expand Down
11 changes: 8 additions & 3 deletions changeforest-py/changeforest/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,14 @@ def _plot_gain_result(gain_result):
fig, axes = plt.subplots()

axes.plot(range(gain_result.start, gain_result.stop), gain_result.gain, color="k")
ymin, ymax = axes.get_ylim()
axes.vlines(
np.nanmax(gain_result.gain) + gain_result.start,
np.nanargmax(gain_result.gain) + gain_result.start,
linestyles="dotted",
color="#EE6677", # red
linewidth=2,
ymax=ymax,
ymin=ymin,
)

axes.set_xlabel("split")
Expand Down Expand Up @@ -146,8 +149,10 @@ def _plot_binary_segmentation_result(binary_segmentation_result, max_depth=5):
if node.optimizer_result is not None:
result = node.optimizer_result.gain_results[-1]
gains[-1].append(np.full(n, np.nan))
gains[-1][-1][node.start : node.stop] = result.gain # noqa: E203
guesses[-1].append(result.guess)
gains[-1][-1][result.start : result.stop] = result.gain # noqa: E203

if result.guess is not None: # For change_in_mean
guesses[-1].append(result.guess)

if node.model_selection_result.is_significant:
found_changepoints[-1].append(node.best_split)
Expand Down
3 changes: 2 additions & 1 deletion changeforest-py/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ dependencies:
- python >=3.7
- maturin
- pytest
- numpy >=1.19.0
- numpy >=1.19.0
- matplotlib
4 changes: 2 additions & 2 deletions changeforest-py/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "changeforest"
description = "Classifier based non-parametric change point detection"
readme = "README.md"
version = "0.7.1"
version = "0.7.2"
requires-python = ">=3.7"
author = "Malte Londschien <[email protected]>"
urls = {homepage = "https://github.com/mlondschien/changeforest/"}
Expand All @@ -25,7 +25,7 @@ skip_glob = '\.eggs/*,\.git/*,\.venv/*,build/*,dist/*'
default_section = 'THIRDPARTY'

[tool.cibuildwheel]
test-requires = "numpy pytest"
test-requires = "numpy matplotlib pytest"
test-command = "pytest {project}/changeforest-py/tests"

# macos arm64 wheels can be built on X86_64, but cannot be tested.
Expand Down
27 changes: 27 additions & 0 deletions changeforest-py/tests/test_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest

from changeforest import Control, changeforest


@pytest.mark.parametrize("method", ["knn", "change_in_mean", "random_forest"])
@pytest.mark.parametrize("segmentation_type", ["bs", "sbs", "wbs"])
def test_plot_binary_segmentation_result(iris_dataset, method, segmentation_type):
result = changeforest(
iris_dataset,
method,
segmentation_type,
control=Control(minimal_relative_segment_length=0.1),
)
result.plot().show()


@pytest.mark.parametrize("method", ["knn", "change_in_mean", "random_forest"])
@pytest.mark.parametrize("segmentation_type", ["bs", "sbs", "wbs"])
def test_plot_optimizer_result(iris_dataset, method, segmentation_type):
result = changeforest(
iris_dataset,
method,
segmentation_type,
control=Control(minimal_relative_segment_length=0.1),
)
result.optimizer_result.plot().show()
2 changes: 1 addition & 1 deletion changeforest-r/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: changeforest
Type: Package
Title: Classifier Based Non-Parametric Change Point Detection
Version: 0.7.1
Version: 0.7.2
Author: Malte Londschien
Maintainer: Malte Londschien <[email protected]>
Description: Perform classifier based multivariate, non-parametric change point detection.
Expand Down
2 changes: 1 addition & 1 deletion changeforest-r/src/rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = 'changeforestr'
version = '0.7.1'
version = '0.7.2'
edition = '2021'

[lib]
Expand Down

0 comments on commit b5ea904

Please sign in to comment.