Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for instances ids #180

Merged
merged 3 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
exclude: ^LICENSE/|\.(html|csv|svg|md)$
default_stages: [commit]
default_stages: [pre-commit, commit, pre-push]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
Expand Down
4 changes: 3 additions & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ v0.8.0 (December X, X)
* Update LLM helper to support updated GPT-4 models [#174][#174]
* Update ruff to latest and remove black as a development dependency [#174][#174]
* Add Python 3.11 markers and CI testing [#174][#174]
* Add support for instance IDs when generating target values [#180][#180]
* Fixes
*
* Fix verbose print out during target value generation [#180][#180]

[#174]: <https://github.com/trane-dev/Trane/pull/174>
[#180]: <https://github.com/trane-dev/Trane/pull/180>

v0.7.0 (October 21, 2023)
=========================
Expand Down
8 changes: 4 additions & 4 deletions tests/test_problem_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def test_problem_generator_single_table():
# 3. Generate target values for each problem
for p in problems:
if p.has_parameters_set() is True:
labels = p.create_target_values(dataframe)
labels = p.create_target_values(dataframe, verbose=False)
if labels.empty:
raise ValueError("labels should not be empty")
check_problem_type(labels, p.get_problem_type())
else:
thresholds = p.get_recommended_thresholds(dataframe)
for threshold in thresholds:
p.set_parameters(threshold)
labels = p.create_target_values(dataframe)
labels = p.create_target_values(dataframe, verbose=False)
check_problem_type(labels, p.get_problem_type())


Expand Down Expand Up @@ -84,13 +84,13 @@ def test_problem_generator_multi(tables, target_table):
string_repr = p.__repr__()
assert "2 days" in string_repr
if p.has_parameters_set() is True:
labels = p.create_target_values(dataframes)
labels = p.create_target_values(dataframes, verbose=False)
check_problem_type(labels, p.get_problem_type())
else:
thresholds = p.get_recommended_thresholds(dataframes)
for threshold in thresholds:
p.set_parameters(threshold)
labels = p.create_target_values(dataframes)
labels = p.create_target_values(dataframes, verbose=False)
check_problem_type(labels, p.get_problem_type())


Expand Down
15 changes: 14 additions & 1 deletion trane/core/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,13 @@
)
return thresholds

def create_target_values(self, dataframes, verbose=False):
def create_target_values(
self,
dataframes,
verbose=False,
nrows=None,
instance_ids=None,
):
# Won't this always be normalized?
normalized_dataframe = self.get_normalized_dataframe(dataframes)
if self.has_parameters_set() is False:
Expand All @@ -141,6 +147,12 @@
# create a fake index with all rows to generate predictions problems "Predict X"
normalized_dataframe["__identity__"] = 0
target_dataframe_index = "__identity__"
if instance_ids and len(instance_ids) > 0:
if verbose:
print("Only selecting given instance IDs")
normalized_dataframe = normalized_dataframe[

Check warning on line 153 in trane/core/problem.py

View check run for this annotation

Codecov / codecov/patch

trane/core/problem.py#L153

Added line #L153 was not covered by tests
normalized_dataframe[self.entity_column].isin(instance_ids)
]

lt = calculate_target_values(
df=normalized_dataframe,
Expand All @@ -149,6 +161,7 @@
time_index=self.metadata.time_index,
window_size=self.window_size,
verbose=verbose,
nrows=nrows,
)
if "__identity__" in normalized_dataframe.columns:
normalized_dataframe.drop(columns=["__identity__"], inplace=True)
Expand Down
17 changes: 13 additions & 4 deletions trane/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import pandas as pd


def set_dataframe_index(df, index):
def set_dataframe_index(df, index, verbose=False):
if df.index.name != index:
if verbose:
print(f"setting dataframe to index: {index}")
df = df.set_index(index, inplace=False)
return df


def generate_data_slices(df, window_size, gap, drop_empty=True):
def generate_data_slices(df, window_size, gap, drop_empty=True, verbose=False):
# valid for a specify group of id
# so we need to groupby id (before this function)
window_size = pd.to_timedelta(window_size)
Expand Down Expand Up @@ -36,18 +38,25 @@
window_size,
drop_empty=True,
verbose=False,
nrows=None,
):
df = set_dataframe_index(df, time_index)
df = set_dataframe_index(df, time_index, verbose=verbose)
if str(df.index.dtype) == "timestamp[ns][pyarrow]":
df.index = df.index.astype("datetime64[ns]")

Check warning on line 45 in trane/core/utils.py

View check run for this annotation

Codecov / codecov/patch

trane/core/utils.py#L45

Added line #L45 was not covered by tests
if nrows and nrows > 0 and nrows < len(df):
if verbose:
print("sampling {nrows} rows")
df = df.sample(n=nrows)

Check warning on line 49 in trane/core/utils.py

View check run for this annotation

Codecov / codecov/patch

trane/core/utils.py#L49

Added line #L49 was not covered by tests
records = []
label_name = labeling_function.__name__

for group_key, df_by_index in df.groupby(target_dataframe_index, observed=True):
# TODO: support gap
for dataslice, _ in generate_data_slices(
df=df_by_index,
window_size=window_size,
gap=window_size,
drop_empty=drop_empty,
verbose=verbose,
):
record = labeling_function(dataslice)
records.append(
Expand Down
7 changes: 0 additions & 7 deletions trane/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,6 @@ def reset_primary_key(self, table):
self.check_if_table_exists(table)
self.primary_keys.pop(table, None)

def obi(self, table):
self.check_if_table_exists(table)
if self.primary_keys:
primary_key = self.primary_keys[table]
self.ml_types[table][primary_key].remove_tag("primary_key")
self.primary_keys.pop(table)

def add_table(self, table, ml_types):
if table in self.ml_types:
raise ValueError("Table already exists")
Expand Down
10 changes: 9 additions & 1 deletion trane/parsing/denormalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
ml_types: Dict[str, Dict[str, str]],
target_table: str,
) -> pd.DataFrame:
keys_to_ml_type = {}
merged_dataframes = {}
for relationship in relationships:
parent_table_name, parent_key, child_table_name, child_key = relationship
Expand All @@ -58,6 +59,8 @@
raise ValueError(
f"{child_key} not in table: {child_table_name}",
)
keys_to_ml_type[parent_key] = ml_types.get(parent_table_name).get(parent_key)
keys_to_ml_type[child_key] = ml_types.get(child_table_name).get(child_key)
check_target_table(target_table, relationships, list(dataframes.keys()))
relationship_order = child_relationships(target_table, relationships)
if len(relationship_order) == 0:
Expand Down Expand Up @@ -110,7 +113,12 @@
# TODO: set primary key to be the index
# TODO: pass information to table meta (primary key, foreign keys)? maybe? technically relationships has this info
valid_columns = list(merged_dataframes[target_table].columns)
col_to_ml_type = {col: column_to_ml_type[col] for col in valid_columns}
col_to_ml_type = {}
for col in valid_columns:
if col in column_to_ml_type:
col_to_ml_type[col] = column_to_ml_type[col]
else:
col_to_ml_type[col] = keys_to_ml_type[col]

Check warning on line 121 in trane/parsing/denormalize.py

View check run for this annotation

Codecov / codecov/patch

trane/parsing/denormalize.py#L121

Added line #L121 was not covered by tests
return merged_dataframes, col_to_ml_type


Expand Down