diff --git a/kneed/knee_locator.py b/kneed/knee_locator.py index a83ed22..a952b3b 100644 --- a/kneed/knee_locator.py +++ b/kneed/knee_locator.py @@ -188,6 +188,11 @@ def __init__( self.y_normalized = self.transform_y( self.y_normalized, self.direction, self.curve ) + + self.x_normalized = self.transform_x( + self.x_normalized, self.direction, self.curve + ) + # normalized difference curve # normalized difference curve self.y_difference = self.y_normalized - self.x_normalized self.x_difference = self.x_normalized.copy() @@ -228,16 +233,21 @@ def __normalize(a: Iterable[float]) -> Iterable[float]: def transform_y(y: Iterable[float], direction: str, curve: str) -> float: """transform y to concave, increasing based on given direction and curve""" # convert elbows to knees - if direction == "decreasing": - if curve == "concave": - y = np.flip(y) - elif curve == "convex": - y = y.max() - y - elif direction == "increasing" and curve == "convex": - y = np.flip(y.max() - y) + if curve == "convex": + y = y.max() - y return y + @staticmethod + def transform_x(x: Iterable[float], direction: str, curve: str) -> float: + """transform x to concave, increasing based on given direction and curve""" + # convert elbows to knees + if (direction == "decreasing" and curve == "concave") or ( + direction == "increasing" and curve == "convex"): + x = x.max() - x + + return x + def find_knee( self, ):