Add dataframe_type property to EntitySet (#1473)
* add dataframe_type property

* remove _get_entityset_type

* update if not pandas entityset checks in tests

* add docstring to dataframe_type

* update release notes

* rework dataframe_type logic

* add test cases

* use dataframe_type in more tests

* remove some unused ks imports

* more test updates

* fix faulty comparison in tests
rwedge authored Jun 11, 2021
1 parent d75a9ef commit d882b8f
Showing 21 changed files with 153 additions and 146 deletions.
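This commit centralizes dataframe-library detection: callers read the new EntitySet.dataframe_type string property instead of running isinstance checks on each entity. A minimal usage sketch, assuming the Library enum in featuretools.utils.gen_utils carries the values 'pandas', 'Dask', and 'Koalas' (verify against your installed version):

import featuretools as ft
from featuretools.utils.gen_utils import Library

# load_mock_customer is an existing featuretools demo helper
es = ft.demo.load_mock_customer(return_entityset=True)

# One property read replaces scanning every entity's dataframe type
if es.dataframe_type == Library.DASK.value:
    print('Dask-backed EntitySet: approximate calculation is unsupported')
elif es.dataframe_type == Library.PANDAS.value:
    print('pandas-backed EntitySet')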
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
@@ -7,6 +7,7 @@ Future Release
==============
* Enhancements
* Add ``get_valid_primitives`` function (:pr:`1462`)
+* Add ``EntitySet.dataframe_type`` attribute (:pr:`1473`)
* Fixes
* Changes
* Upgrade minimum alteryx open source update checker to 2.0.0 (:pr:`1460`)
@@ -28,6 +28,7 @@
from featuretools.feature_base import AggregationFeature, FeatureBase
from featuretools.utils import Trie
from featuretools.utils.gen_utils import (
+Library,
import_or_none,
is_instance,
make_tqdm_iterator
@@ -147,7 +148,7 @@ def calculate_feature_matrix(features, entityset=None, cutoff_time=None, instanc
else:
raise TypeError("No entities or valid EntitySet provided")

-if any(isinstance(es.df, dd.DataFrame) for es in entityset.entities):
+if entityset.dataframe_type == Library.DASK.value:
if approximate:
msg = "Using approximate is not supported with Dask Entities"
raise ValueError(msg)
19 changes: 17 additions & 2 deletions featuretools/entityset/entityset.py
@@ -13,7 +13,7 @@
from featuretools.entityset import deserialize, serialize
from featuretools.entityset.entity import Entity
from featuretools.entityset.relationship import Relationship, RelationshipPath
-from featuretools.utils.gen_utils import import_or_none, is_instance
+from featuretools.utils.gen_utils import Library, import_or_none, is_instance
from featuretools.utils.plot_utils import (
check_graphviz,
get_graphviz_format,
@@ -144,6 +144,21 @@ def __getitem__(self, entity_id):
def entities(self):
return list(self.entity_dict.values())

+@property
+def dataframe_type(self):
+    '''String specifying the library used for the entity dataframes. Null if no entities'''
+    df_type = None
+
+    if self.entities:
+        if isinstance(self.entities[0].df, pd.DataFrame):
+            df_type = Library.PANDAS.value
+        elif isinstance(self.entities[0].df, dd.DataFrame):
+            df_type = Library.DASK.value
+        elif is_instance(self.entities[0].df, ks, 'DataFrame'):
+            df_type = Library.KOALAS.value
+
+    return df_type
+
@property
def metadata(self):
'''Returns the metadata for this EntitySet. The metadata will be recomputed if it does not exist.'''
@@ -196,7 +211,7 @@ def to_csv(self, path, sep=',', encoding='utf-8', engine='python', compression=N
compression (str) : Name of the compression to use. Possible values are: {'gzip', 'bz2', 'zip', 'xz', None}.
profile_name (str) : Name of AWS profile to use, False to use an anonymous profile, or None.
'''
-if is_instance(self.entities[0].df, ks, 'DataFrame'):
+if self.dataframe_type == Library.KOALAS.value:
compression = str(compression)
serialize.write_data_description(self, path, format='csv', index=False, sep=sep, encoding=encoding, engine=engine, compression=compression, profile_name=profile_name)
return self
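As implemented above, the property inspects only entities[0] (all entities in an EntitySet share one dataframe library) and returns None when the set has no entities. A quick sketch of that edge case, assuming a default EntitySet constructor:

import featuretools as ft

empty_es = ft.EntitySet(id='empty')
assert empty_es.dataframe_type is None  # no entities yet, per the docstring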
15 changes: 9 additions & 6 deletions featuretools/synthesis/deep_feature_synthesis.py
@@ -22,7 +22,7 @@
generate_all_primitive_options,
ignore_entity_for_primitive
)
-from featuretools.synthesis.utils import _get_entityset_type
+from featuretools.utils.gen_utils import Library
from featuretools.variable_types import Boolean, Discrete, Id, Numeric

logger = logging.getLogger('featuretools')
@@ -185,10 +185,13 @@ def __init__(self,
self.target_entity_id = target_entity_id
self.es = entityset

-entityset_type = _get_entityset_type(self.es)
+for library in Library:
+    if library.value == self.es.dataframe_type:
+        df_library = library
+        break

if agg_primitives is None:
-agg_primitives = [p for p in primitives.get_default_aggregation_primitives() if entityset_type in p.compatibility]
+agg_primitives = [p for p in primitives.get_default_aggregation_primitives() if df_library in p.compatibility]
self.agg_primitives = []
agg_prim_dict = primitives.get_aggregation_primitives()
for a in agg_primitives:
@@ -206,7 +209,7 @@ def __init__(self,
self.agg_primitives.sort()

if trans_primitives is None:
-trans_primitives = [p for p in primitives.get_default_transform_primitives() if entityset_type in p.compatibility]
+trans_primitives = [p for p in primitives.get_default_transform_primitives() if df_library in p.compatibility]
self.trans_primitives = []
for t in trans_primitives:
t = check_trans_primitive(t)
@@ -240,10 +243,10 @@ def __init__(self,
primitive_options = {}
all_primitives = self.trans_primitives + self.agg_primitives + \
self.where_primitives + self.groupby_trans_primitives
-bad_primitives = [prim.name for prim in all_primitives if entityset_type not in prim.compatibility]
+bad_primitives = [prim.name for prim in all_primitives if df_library not in prim.compatibility]
if bad_primitives:
msg = 'Selected primitives are incompatible with {} EntitySets: {}'
-raise ValueError(msg.format(entityset_type.value, ', '.join(bad_primitives)))
+raise ValueError(msg.format(df_library.value, ', '.join(bad_primitives)))

self.primitive_options, self.ignore_entities, self.ignore_variables =\
generate_all_primitive_options(all_primitives,
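The loop above recovers a Library member from the property's string value. Since Python's Enum supports lookup by value, the same mapping can be written as a single call. A sketch, assuming the member values mirror featuretools.utils.gen_utils.Library:

from enum import Enum

class Library(Enum):
    # values assumed to mirror featuretools.utils.gen_utils.Library
    PANDAS = 'pandas'
    DASK = 'Dask'
    KOALAS = 'Koalas'

df_library = Library('Dask')  # Enum lookup by value
assert df_library is Library.DASK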
14 changes: 9 additions & 5 deletions featuretools/synthesis/get_valid_primitives.py
@@ -6,9 +6,9 @@
from featuretools.synthesis.deep_feature_synthesis import DeepFeatureSynthesis
from featuretools.synthesis.utils import (
_categorize_features,
-_get_entityset_type,
get_unused_primitives
)
+from featuretools.utils.gen_utils import Library


def get_valid_primitives(entityset, target_entity, max_depth=2, selected_primitives=None):
@@ -40,7 +40,11 @@ def get_valid_primitives(entityset, target_entity, max_depth=2, selected_primiti
trans_primitives = []
available_aggs = get_aggregation_primitives()
available_trans = get_transform_primitives()
-entityset_type = _get_entityset_type(entityset)

+for library in Library:
+    if library.value == entityset.dataframe_type:
+        df_library = library
+        break

if selected_primitives:
for prim in selected_primitives:
@@ -60,13 +64,13 @@ def get_valid_primitives(entityset, target_entity, max_depth=2, selected_primiti
prim_list = trans_primitives
else:
raise ValueError(f"'{prim}' is not a recognized primitive name")
-if entityset_type in prim.compatibility:
+if df_library in prim.compatibility:
prim_list.append(prim)
else:
agg_primitives = [agg for agg in available_aggs.values()
-if entityset_type in agg.compatibility]
+if df_library in agg.compatibility]
trans_primitives = [trans for trans in available_trans.values()
-if entityset_type in trans.compatibility]
+if df_library in trans.compatibility]

dfs_object = DeepFeatureSynthesis(target_entity, entityset,
agg_primitives=agg_primitives,
16 changes: 0 additions & 16 deletions featuretools/synthesis/utils.py
@@ -1,14 +1,9 @@
-from dask import dataframe as dd

from featuretools.feature_base import (
AggregationFeature,
FeatureOutputSlice,
GroupByTransformFeature,
TransformFeature
)
-from featuretools.utils.gen_utils import Library, import_or_none, is_instance
-
-ks = import_or_none('databricks.koalas')


def _categorize_features(features):
@@ -59,14 +54,3 @@ def get_unused_primitives(specified, used):
return []
specified = {primitive if isinstance(primitive, str) else primitive.name for primitive in specified}
return sorted(list(specified.difference(used)))


-def _get_entityset_type(entityset):
-    if any(isinstance(entity.df, dd.DataFrame) for entity in entityset.entities):
-        entityset_type = Library.DASK
-    elif any(is_instance(entity.df, ks, 'DataFrame') for entity in entityset.entities):
-        entityset_type = Library.KOALAS
-    else:
-        entityset_type = Library.PANDAS
-
-    return entityset_type
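Note one behavioral difference: the removed helper defaulted to Library.PANDAS when no Dask or Koalas entity was found, including for an empty EntitySet, whereas the new property returns None in that case. Callers that relied on the pandas default could bridge the gap with a hypothetical shim like this (not part of the commit):

from featuretools.utils.gen_utils import Library

def get_es_library(es):
    # Hypothetical shim: map the new string property back to a Library
    # member, preserving the old pandas default for empty EntitySets.
    return Library(es.dataframe_type) if es.dataframe_type else Library.PANDAS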
@@ -46,9 +46,7 @@
get_mock_client_cluster,
to_pandas
)
-from featuretools.utils.gen_utils import Library, import_or_none
-
-ks = import_or_none('databricks.koalas')
+from featuretools.utils.gen_utils import Library


def test_scatter_warning(caplog):
@@ -63,7 +61,7 @@ def test_scatter_warning(caplog):

# TODO: final assert fails w/ Dask
def test_calc_feature_matrix(es):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail('Distributed dataframe result not ordered')
times = list([datetime(2011, 4, 9, 10, 30, i * 6) for i in range(5)] +
[datetime(2011, 4, 9, 10, 31, i * 9) for i in range(4)] +
@@ -166,7 +164,7 @@ def test_cfm_compose(es, lt):


def test_cfm_compose_approximate(es, lt):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail('dask does not support approximate')

property_feature = ft.Feature(es['log']['value']) > 10
@@ -278,7 +276,7 @@ def test_cfm_no_cutoff_time_index(pd_es):
# TODO: fails with dask entitysets
# TODO: fails with koalas entitysets
def test_cfm_duplicated_index_in_cutoff_time(es):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail('Distributed results not ordered, missing duplicates')
times = [datetime(2011, 4, 1), datetime(2011, 5, 1),
datetime(2011, 4, 1), datetime(2011, 5, 1)]
@@ -297,7 +295,7 @@ def test_cfm_duplicated_index_in_cutoff_time(es):

# TODO: fails with Dask, Koalas
def test_saveprogress(es, tmpdir):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail('saveprogress fails with distributed entitysets')
times = list([datetime(2011, 4, 9, 10, 30, i * 6) for i in range(5)] +
[datetime(2011, 4, 9, 10, 31, i * 9) for i in range(4)] +
@@ -1029,7 +1027,7 @@ def test_cutoff_time_naming(es):

# TODO: order doesn't match, but output matches
def test_cutoff_time_extra_columns(es):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail('Distributed result not ordered')
agg_feat = ft.Feature(es['customers']['id'], parent_entity=es[u'régions'], primitive=Count)
dfeat = DirectFeature(agg_feat, es['customers'])
@@ -1068,7 +1066,7 @@ def test_cutoff_time_extra_columns_approximate(pd_es):


def test_cutoff_time_extra_columns_same_name(es):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail('Distributed result not ordered')
agg_feat = ft.Feature(es['customers']['id'], parent_entity=es[u'régions'], primitive=Count)
dfeat = DirectFeature(agg_feat, es['customers'])
@@ -1118,7 +1116,7 @@ def test_instances_after_cutoff_time_removed(es):

# TODO: Dask and Koalas do not keep instance_id after cutoff
def test_instances_with_id_kept_after_cutoff(es):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail('Distributed result not ordered, missing extra instances')
property_feature = ft.Feature(es['log']['id'], parent_entity=es['customers'], primitive=Count)
cutoff_time = datetime(2011, 4, 8)
@@ -1137,7 +1135,7 @@ def test_instances_with_id_kept_after_cutoff(es):
# TODO: Fails with Dask
# TODO: Fails with Koalas
def test_cfm_returns_original_time_indexes(es):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail('Distributed result not ordered, indexes are lost due to not multiindexing')
agg_feat = ft.Feature(es['customers']['id'], parent_entity=es[u'régions'], primitive=Count)
dfeat = DirectFeature(agg_feat, es['customers'])
@@ -1534,7 +1532,7 @@ def test_string_time_values_in_cutoff_time(es):
# TODO: Dask version fails (feature matrix is empty)
# TODO: Koalas version fails (koalas groupby agg doesn't support custom functions)
def test_no_data_for_cutoff_time(mock_customer):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in mock_customer.entities):
+if mock_customer.dataframe_type != Library.PANDAS.value:
pytest.xfail("Dask fails because returned feature matrix is empty; Koalas doesn't support custom agg functions")
es = mock_customer
cutoff_times = pd.DataFrame({"customer_id": [4],
@@ -37,6 +37,7 @@
from featuretools.primitives.base import AggregationPrimitive
from featuretools.tests.testing_utils import backward_path, to_pandas
from featuretools.utils import Trie
+from featuretools.utils.gen_utils import Library
from featuretools.variable_types import Numeric


@@ -165,7 +166,7 @@ def test_make_agg_feat_using_prev_time(es):


def test_make_agg_feat_using_prev_n_events(es):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail('Distrubuted entitysets do not support use_previous')
agg_feat_1 = ft.Feature(es['log']['value'],
parent_entity=es['sessions'],
@@ -204,7 +205,7 @@ def test_make_agg_feat_using_prev_n_events(es):


def test_make_agg_feat_multiple_dtypes(es):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail('Currently no Dask or Koalas compatible agg prims that use multiple dtypes')
compare_prod = IdentityFeature(es['log']['product_id']) == 'coke zero'

@@ -855,7 +856,7 @@ def test_with_features_built_from_es_metadata(es):

# TODO: Fails with Dask and Koalas (conflicting aggregation primitives)
def test_handles_primitive_function_name_uniqueness(es):
-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
pytest.xfail("Fails with Dask and Koalas due conflicting aggregation primitive names")

class SumTimesN(AggregationPrimitive):
@@ -989,7 +990,7 @@ def test_calls_progress_callback(es):
trans_full = ft.Feature(agg, primitive=CumSum)
groupby_trans = ft.Feature(agg, primitive=CumSum, groupby=es["customers"]["cohort"])

-if not all(isinstance(entity.df, pd.DataFrame) for entity in es.entities):
+if es.dataframe_type != Library.PANDAS.value:
all_features = [identity, direct, agg, trans]
else:
all_features = [identity, direct, agg, agg_apply, trans, trans_full, groupby_trans]
5 changes: 5 additions & 0 deletions featuretools/tests/entityset_tests/test_dask_es.py
@@ -4,6 +4,7 @@

import featuretools as ft
from featuretools.entityset import EntitySet, Relationship
+from featuretools.utils.gen_utils import Library


def test_create_entity_from_dask_df(pd_es):
@@ -138,3 +139,7 @@ def test_create_entity_with_make_index():

expected_df = pd.DataFrame({"new_index": range(len(values)), "values": values})
pd.testing.assert_frame_equal(expected_df, dask_es['new_entity'].df.compute())


+def test_dataframe_type_dask(dask_es):
+    assert dask_es.dataframe_type == Library.DASK.value
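The commit message notes added test cases; the Dask test above pins the property to Library.DASK.value. A pandas analogue might look like the following sketch (hypothetical; any actual pandas or Koalas tests live in files not shown here):

def test_dataframe_type_pandas(pd_es):
    assert pd_es.dataframe_type == Library.PANDAS.value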
10 changes: 4 additions & 6 deletions featuretools/tests/entityset_tests/test_entity.py
@@ -10,11 +10,9 @@
make_ecommerce_entityset,
to_pandas
)
-from featuretools.utils.gen_utils import import_or_none
+from featuretools.utils.gen_utils import Library
from featuretools.variable_types import find_variable_types

-ks = import_or_none('databricks.koalas')


def test_is_index_column(es):
assert es['cohorts'].index == 'cohort'
@@ -97,7 +95,7 @@ def test_eq(es):

def test_update_data(es):
df = es['customers'].df.copy()
-if ks and isinstance(df, ks.DataFrame):
+if es.dataframe_type == Library.KOALAS.value:
df['new'] = [1, 2, 3]
else:
df['new'] = pd.Series([1, 2, 3])
@@ -116,7 +114,7 @@ def test_update_data(es):
updated_id.iloc[1] = 2
updated_id.iloc[2] = 1

-if ks and isinstance(df, ks.DataFrame):
+if es.dataframe_type == Library.KOALAS.value:
df["id"] = updated_id.to_list()
df = df.sort_index()
else:
@@ -133,7 +131,7 @@ def test_update_data(es):
updated_signup = to_pandas(df['signup_date'])
updated_signup.iloc[0] = datetime(2011, 4, 11)

-if ks and isinstance(df, ks.DataFrame):
+if es.dataframe_type == Library.KOALAS.value:
df['signup_date'] = updated_signup.to_list()
df = df.sort_index()
else:
(Diffs for the remaining changed files were not loaded.)
