Add support for instances ids (#180)
* fixes

* update pre-commit

* changelog
gsheni authored Jan 2, 2024
1 parent 41539cb commit ed6a3ca
Showing 7 changed files with 44 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,5 +1,5 @@
 exclude: ^LICENSE/|\.(html|csv|svg|md)$
-default_stages: [commit]
+default_stages: [pre-commit, commit, pre-push]
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
   rev: v4.5.0
4 changes: 3 additions & 1 deletion docs/changelog.md
@@ -8,10 +8,12 @@ v0.8.0 (December X, X)
 * Update LLM helper to support updated GPT-4 models [#174][#174]
 * Update ruff to latest and remove black as a development dependency [#174][#174]
 * Add Python 3.11 markers and CI testing [#174][#174]
+* Add support for instance IDs when generating target values [#180][#180]
 * Fixes
-  *
+  * Fix verbose print out during target value generation [#180][#180]

 [#174]: <https://github.com/trane-dev/Trane/pull/174>
+[#180]: <https://github.com/trane-dev/Trane/pull/180>

 v0.7.0 (October 21, 2023)
 =========================
8 changes: 4 additions & 4 deletions tests/test_problem_generator.py
@@ -36,15 +36,15 @@ def test_problem_generator_single_table():
     # 3. Generate target values for each problem
     for p in problems:
         if p.has_parameters_set() is True:
-            labels = p.create_target_values(dataframe)
+            labels = p.create_target_values(dataframe, verbose=False)
             if labels.empty:
                 raise ValueError("labels should not be empty")
             check_problem_type(labels, p.get_problem_type())
         else:
             thresholds = p.get_recommended_thresholds(dataframe)
             for threshold in thresholds:
                 p.set_parameters(threshold)
-                labels = p.create_target_values(dataframe)
+                labels = p.create_target_values(dataframe, verbose=False)
                 check_problem_type(labels, p.get_problem_type())


@@ -84,13 +84,13 @@ def test_problem_generator_multi(tables, target_table):
         string_repr = p.__repr__()
         assert "2 days" in string_repr
         if p.has_parameters_set() is True:
-            labels = p.create_target_values(dataframes)
+            labels = p.create_target_values(dataframes, verbose=False)
             check_problem_type(labels, p.get_problem_type())
         else:
             thresholds = p.get_recommended_thresholds(dataframes)
             for threshold in thresholds:
                 p.set_parameters(threshold)
-                labels = p.create_target_values(dataframes)
+                labels = p.create_target_values(dataframes, verbose=False)
                 check_problem_type(labels, p.get_problem_type())


15 changes: 14 additions & 1 deletion trane/core/problem.py
@@ -127,7 +127,13 @@ def get_recommended_thresholds(self, dataframes, n_quantiles=10):
         )
         return thresholds

-    def create_target_values(self, dataframes, verbose=False):
+    def create_target_values(
+        self,
+        dataframes,
+        verbose=False,
+        nrows=None,
+        instance_ids=None,
+    ):
         # Won't this always be normalized?
         normalized_dataframe = self.get_normalized_dataframe(dataframes)
         if self.has_parameters_set() is False:
@@ -141,6 +147,12 @@ def create_target_values(self, dataframes, verbose=False):
             # create a fake index with all rows to generate prediction problems ("Predict X")
             normalized_dataframe["__identity__"] = 0
             target_dataframe_index = "__identity__"
+        if instance_ids and len(instance_ids) > 0:
+            if verbose:
+                print("Only selecting given instance IDs")
+            normalized_dataframe = normalized_dataframe[
+                normalized_dataframe[self.entity_column].isin(instance_ids)
+            ]

         lt = calculate_target_values(
             df=normalized_dataframe,
@@ -149,6 +161,7 @@ def create_target_values(self, dataframes, verbose=False):
             time_index=self.metadata.time_index,
             window_size=self.window_size,
             verbose=verbose,
+            nrows=nrows,
         )
         if "__identity__" in normalized_dataframe.columns:
             normalized_dataframe.drop(columns=["__identity__"], inplace=True)
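The two new parameters compose: `instance_ids` restricts label generation to specific entities, and `nrows` caps how many rows are used. A minimal usage sketch, assuming a `problem` object from Trane's problem generation step and a `transactions` DataFrame whose entity column is `customer_id` (both names are illustrative):

# Sketch only: `problem` and `transactions` are assumed to exist,
# e.g. from an earlier problem-generation step.

# Previous behavior (unchanged): labels for every entity.
labels = problem.create_target_values(transactions)

# New in #180: only compute labels for the given instance IDs;
# rows are filtered through the problem's entity column.
labels = problem.create_target_values(
    transactions,
    instance_ids=[101, 102, 103],
    verbose=True,  # prints "Only selecting given instance IDs"
)

# Also new: `nrows` is forwarded to calculate_target_values, which
# samples that many rows before slicing (see trane/core/utils.py below).
labels = problem.create_target_values(transactions, nrows=1000)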
17 changes: 13 additions & 4 deletions trane/core/utils.py
@@ -1,13 +1,15 @@
 import pandas as pd


-def set_dataframe_index(df, index):
+def set_dataframe_index(df, index, verbose=False):
     if df.index.name != index:
+        if verbose:
+            print(f"setting dataframe to index: {index}")
         df = df.set_index(index, inplace=False)
     return df


-def generate_data_slices(df, window_size, gap, drop_empty=True):
+def generate_data_slices(df, window_size, gap, drop_empty=True, verbose=False):
     # valid for a specific group of IDs,
     # so we need to group by ID before calling this function
     window_size = pd.to_timedelta(window_size)
@@ -36,18 +38,25 @@ def calculate_target_values(
     window_size,
     drop_empty=True,
     verbose=False,
+    nrows=None,
 ):
-    df = set_dataframe_index(df, time_index)
+    df = set_dataframe_index(df, time_index, verbose=verbose)
     if str(df.index.dtype) == "timestamp[ns][pyarrow]":
         df.index = df.index.astype("datetime64[ns]")
+    if nrows and nrows > 0 and nrows < len(df):
+        if verbose:
+            print(f"sampling {nrows} rows")
+        df = df.sample(n=nrows)
     records = []
     label_name = labeling_function.__name__

     for group_key, df_by_index in df.groupby(target_dataframe_index, observed=True):
         # TODO: support gap
         for dataslice, _ in generate_data_slices(
             df=df_by_index,
             window_size=window_size,
             gap=window_size,
             drop_empty=drop_empty,
+            verbose=verbose,
         ):
             record = labeling_function(dataslice)
             records.append(
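For reference, a sketch of how the updated `calculate_target_values` might be called on its own. The keyword names match those visible in this diff (`df`, `target_dataframe_index`, `labeling_function`, `time_index`, `window_size`, `verbose`, `nrows`); the import path, the data, and the labeling function are invented for the example:

import pandas as pd

from trane.core.utils import calculate_target_values  # assumed import path

def total_amount(dataslice):
    # Trivial labeling function: sum one column over a data slice.
    return dataslice["amount"].sum()

df = pd.DataFrame(
    {
        "customer_id": [1, 1, 2, 2],
        "amount": [10, 20, 30, 40],
        "time": pd.date_range("2024-01-01", periods=4, freq="D"),
    },
)

lt = calculate_target_values(
    df=df,
    target_dataframe_index="customer_id",
    labeling_function=total_amount,
    time_index="time",
    window_size="2d",  # passed through pd.to_timedelta
    nrows=None,        # default: keep all rows; a positive int < len(df) samples
    verbose=True,      # now also threads into set_dataframe_index and slicing
)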
7 changes: 0 additions & 7 deletions trane/metadata/metadata.py
@@ -164,13 +164,6 @@ def reset_primary_key(self, table):
         self.check_if_table_exists(table)
         self.primary_keys.pop(table, None)

-    def obi(self, table):
-        self.check_if_table_exists(table)
-        if self.primary_keys:
-            primary_key = self.primary_keys[table]
-            self.ml_types[table][primary_key].remove_tag("primary_key")
-            self.primary_keys.pop(table)
-
     def add_table(self, table, ml_types):
         if table in self.ml_types:
             raise ValueError("Table already exists")
10 changes: 9 additions & 1 deletion trane/parsing/denormalize.py
@@ -47,6 +47,7 @@ def denormalize_dataframes(
     ml_types: Dict[str, Dict[str, str]],
     target_table: str,
 ) -> pd.DataFrame:
+    keys_to_ml_type = {}
     merged_dataframes = {}
     for relationship in relationships:
         parent_table_name, parent_key, child_table_name, child_key = relationship
@@ -58,6 +59,8 @@ def denormalize_dataframes(
             raise ValueError(
                 f"{child_key} not in table: {child_table_name}",
             )
+        keys_to_ml_type[parent_key] = ml_types.get(parent_table_name).get(parent_key)
+        keys_to_ml_type[child_key] = ml_types.get(child_table_name).get(child_key)
     check_target_table(target_table, relationships, list(dataframes.keys()))
     relationship_order = child_relationships(target_table, relationships)
     if len(relationship_order) == 0:
@@ -110,7 +113,12 @@ def denormalize_dataframes(
     # TODO: set primary key to be the index
     # TODO: pass information to table meta (primary key, foreign keys)? maybe? technically relationships has this info
     valid_columns = list(merged_dataframes[target_table].columns)
-    col_to_ml_type = {col: column_to_ml_type[col] for col in valid_columns}
+    col_to_ml_type = {}
+    for col in valid_columns:
+        if col in column_to_ml_type:
+            col_to_ml_type[col] = column_to_ml_type[col]
+        else:
+            col_to_ml_type[col] = keys_to_ml_type[col]
     return merged_dataframes, col_to_ml_type


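The new `keys_to_ml_type` dictionary acts as a fallback: join-key columns recorded from relationships can be absent from `column_to_ml_type`, and the old dict comprehension would raise a KeyError for any such column. Stripped of the surrounding merge logic, the lookup reduces to the following (all values invented for illustration):

column_to_ml_type = {"amount": "Double"}      # per-column types from the tables
keys_to_ml_type = {"customer_id": "Integer"}  # types recorded from relationships
valid_columns = ["amount", "customer_id"]

col_to_ml_type = {}
for col in valid_columns:
    if col in column_to_ml_type:
        col_to_ml_type[col] = column_to_ml_type[col]
    else:
        # Join keys may be missing from column_to_ml_type; fall back.
        col_to_ml_type[col] = keys_to_ml_type[col]

assert col_to_ml_type == {"amount": "Double", "customer_id": "Integer"}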
