Skip to content

Commit e061a74

Browse files
committed
Drop Dataset.relations and make Dataset data-path-aware
1 parent 6d80006 commit e061a74

File tree

4 files changed

+79
-70
lines changed

4 files changed

+79
-70
lines changed

khiops/sklearn/dataset.py

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _check_multitable_spec(ds_spec):
176176
)
177177

178178

179-
def _table_name_of_path(table_path):
179+
def table_name_of_path(table_path):
180180
return table_path.split("/")[-1]
181181

182182

@@ -387,7 +387,6 @@ def __init__(self, X, y=None, categorical_target=True):
387387
# Initialize members
388388
self.main_table = None
389389
self.additional_data_tables = None
390-
self.relations = None
391390
self.categorical_target = categorical_target
392391
self.target_column = None
393392
self.target_column_id = None
@@ -437,7 +436,8 @@ def __init__(self, X, y=None, categorical_target=True):
437436
# Index the tables by name
438437
self._tables_by_name = {
439438
table.name: table
440-
for table in [self.main_table] + self.additional_data_tables
439+
for table in [self.main_table]
440+
+ [table for _, table, _ in self.additional_data_tables]
441441
}
442442

443443
# Post-conditions
@@ -513,32 +513,21 @@ def _init_tables_from_mapping(self, X):
513513
key=main_table_key,
514514
)
515515
self.additional_data_tables = []
516-
self.relations = []
517516
if "additional_data_tables" in X:
518517
for table_path, table_spec in X["additional_data_tables"].items():
519518
table_source, table_key = table_spec[:2]
520-
table_name = _table_name_of_path(table_path)
519+
table_name = table_name_of_path(table_path)
521520
table = PandasTable(
522521
table_name,
523522
table_source,
524-
data_path=table_path,
525523
key=table_key,
526524
)
527-
self.additional_data_tables.append(table)
528525
is_one_to_one_relation = False
529526
if len(table_spec) == 3 and table_spec[2] is True:
530527
is_one_to_one_relation = True
531528

532-
# Set relation parent: if no "/" in path, main_table is the parent
533-
if not "/" in table_path:
534-
parent_table_name = self.main_table.name
535-
else:
536-
table_path_fragments = table_path.split("/")
537-
parent_table_name = _table_name_of_path(
538-
"/".join(table_path_fragments[:-1])
539-
)
540-
self.relations.append(
541-
(parent_table_name, table_name, is_one_to_one_relation)
529+
self.additional_data_tables.append(
530+
(table_path, table, is_one_to_one_relation)
542531
)
543532
# Initialize a sparse dataset (monotable)
544533
elif isinstance(main_table_source, sp.spmatrix):
@@ -548,7 +537,6 @@ def _init_tables_from_mapping(self, X):
548537
key=main_table_key,
549538
)
550539
self.additional_data_tables = []
551-
self.relations = []
552540
# Initialize a numpyarray dataset (monotable)
553541
elif hasattr(main_table_source, "__array__"):
554542
self.main_table = NumpyTable(
@@ -561,7 +549,6 @@ def _init_tables_from_mapping(self, X):
561549
"with pandas dataframe source tables"
562550
)
563551
self.additional_data_tables = []
564-
self.relations = []
565552
else:
566553
raise TypeError(
567554
type_error_message(
@@ -680,11 +667,12 @@ def to_spec(self):
680667
ds_spec = {}
681668
ds_spec["main_table"] = (self.main_table.data_source, self.main_table.key)
682669
ds_spec["additional_data_tables"] = {}
683-
for table in self.additional_data_tables:
684-
assert table.data_path is not None
685-
ds_spec["additional_data_tables"][table.data_path] = (
670+
for table_path, table, is_one_to_one_relation in self.additional_data_tables:
671+
assert table_path is not None
672+
ds_spec["additional_data_tables"][table_path] = (
686673
table.data_source,
687674
table.key,
675+
is_one_to_one_relation,
688676
)
689677

690678
return ds_spec
@@ -748,31 +736,32 @@ def create_khiops_dictionary_domain(self):
748736
# Note: In general 'name' and 'object_type' fields of Variable can be different
749737
if self.additional_data_tables:
750738
main_dictionary.root = True
751-
table_names = [table.name for table in self.additional_data_tables]
752-
tables_to_visit = [self.main_table.name]
753-
while tables_to_visit:
754-
current_table = tables_to_visit.pop(0)
755-
for relation in self.relations:
756-
parent_table, child_table, is_one_to_one_relation = relation
757-
if parent_table == current_table:
758-
tables_to_visit.append(child_table)
759-
parent_table_name = parent_table
760-
index_table = table_names.index(child_table)
761-
table = self.additional_data_tables[index_table]
762-
parent_table_dictionary = dictionary_domain.get_dictionary(
763-
parent_table_name
764-
)
765-
dictionary = table.create_khiops_dictionary()
766-
dictionary_domain.add_dictionary(dictionary)
767-
table_variable = kh.Variable()
768-
if is_one_to_one_relation:
769-
table_variable.type = "Entity"
770-
else:
771-
table_variable.type = "Table"
772-
table_variable.name = table.name
773-
table_variable.object_type = table.name
774-
parent_table_dictionary.add_variable(table_variable)
739+
for (
740+
table_path,
741+
table,
742+
is_one_to_one_relation,
743+
) in self.additional_data_tables:
744+
if not "/" in table_path:
745+
parent_table_name = self.main_table.name
746+
else:
747+
table_path_fragments = table_path.split("/")
748+
parent_table_name = table_name_of_path(
749+
"/".join(table_path_fragments[:-1])
750+
)
751+
parent_table_dictionary = dictionary_domain.get_dictionary(
752+
parent_table_name
753+
)
775754

755+
dictionary = table.create_khiops_dictionary()
756+
dictionary_domain.add_dictionary(dictionary)
757+
table_variable = kh.Variable()
758+
if is_one_to_one_relation:
759+
table_variable.type = "Entity"
760+
else:
761+
table_variable.type = "Table"
762+
table_variable.name = table.name
763+
table_variable.object_type = table.name
764+
parent_table_dictionary.add_variable(table_variable)
776765
return dictionary_domain
777766

778767
def create_table_files_for_khiops(self, output_dir, sort=True):
@@ -811,9 +800,9 @@ def create_table_files_for_khiops(self, output_dir, sort=True):
811800

812801
# Create a copy of each secondary table
813802
secondary_table_paths = {}
814-
for table in self.additional_data_tables:
815-
assert table.data_path is not None
816-
secondary_table_paths[table.data_path] = table.create_table_file_for_khiops(
803+
for table_path, table, _ in self.additional_data_tables:
804+
assert table_path is not None
805+
secondary_table_paths[table_path] = table.create_table_file_for_khiops(
817806
output_dir, sort=sort
818807
)
819808

@@ -918,13 +907,11 @@ class PandasTable(DatasetTable):
918907
Name for the table.
919908
dataframe : `pandas.DataFrame`
920909
The data frame to be encapsulated. It must be non-empty.
921-
data_path : str, optional
922-
Data path of the table. Unset for main tables.
923910
key : list of str, optional
924911
The names of the columns composing the key.
925912
"""
926913

927-
def __init__(self, name, dataframe, data_path=None, key=None):
914+
def __init__(self, name, dataframe, key=None):
928915
# Call the parent method
929916
super().__init__(name=name, key=key)
930917

@@ -937,7 +924,6 @@ def __init__(self, name, dataframe, data_path=None, key=None):
937924
# Initialize the attributes
938925
self.data_source = dataframe
939926
self.n_samples = len(self.data_source)
940-
self.data_path = data_path
941927

942928
# Initialize feature columns and verify their types
943929
self.column_ids = self.data_source.columns.values

khiops/sklearn/estimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1516,7 +1516,7 @@ def _transform_check_dataset(self, ds):
15161516

15171517
# Multi-table model: Check name and dictionary coherence of secondary tables
15181518
dataset_secondary_tables_by_name = {
1519-
table.name: table for table in ds.additional_data_tables
1519+
table.name: table for _, table, _ in ds.additional_data_tables
15201520
}
15211521
for dictionary in self.model_.dictionaries:
15221522
assert dictionary.name.startswith(self._khiops_model_prefix), (

khiops/sklearn/helpers.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sklearn.model_selection import train_test_split
1212

1313
from khiops.core.internals.common import is_dict_like, type_error_message
14-
from khiops.sklearn.dataset import Dataset
14+
from khiops.sklearn.dataset import Dataset, table_name_of_path
1515

1616
# Note: We build the splits with lists and itertools.chain to avoid pylint warning about
1717
# unbalanced-tuple-unpacking. See issue https://github.com/pylint-dev/pylint/issues/5671
@@ -122,15 +122,31 @@ def _train_test_split_in_memory_dataset(ds, y, test_size, sklearn_split_params=N
122122

123123
# Split the secondary tables
124124
# Note: The tables are traversed in BFS
125-
todo_relations = [
126-
relation for relation in ds.relations if relation[0] == ds.main_table.name
125+
todo_tables = [
126+
(table_path, table)
127+
for table_path, table, _ in ds.additional_data_tables
128+
if "/" not in table_path
127129
]
128-
while todo_relations:
129-
current_parent_table_name, current_child_table_name, _ = todo_relations.pop(0)
130-
for relation in ds.relations:
131-
parent_table_name, _, _ = relation
130+
while todo_tables:
131+
current_table_path, current_table = todo_tables.pop(0)
132+
if "/" not in current_table_path:
133+
current_parent_table_name = ds.main_table.name
134+
else:
135+
table_path_fragments = current_table_path.split("/")
136+
current_parent_table_name = table_name_of_path(
137+
"/".join(table_path_fragments[:-1])
138+
)
139+
current_child_table_name = current_table.name
140+
for secondary_table_path, secondary_table, _ in ds.additional_data_tables:
141+
if "/" not in secondary_table_path:
142+
parent_table_name = ds.main_table.name
143+
else:
144+
table_path_fragments = secondary_table_path.split("/")
145+
parent_table_name = table_name_of_path(
146+
"/".join(table_path_fragments[:-1])
147+
)
132148
if parent_table_name == current_child_table_name:
133-
todo_relations.append(relation)
149+
todo_tables.append((secondary_table_path, secondary_table))
134150

135151
for new_ds in (train_ds, test_ds):
136152
origin_child_table = ds.get_table(current_child_table_name)

tests/test_dataset_class.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -459,19 +459,22 @@ def test_dataset_is_correctly_built(self):
459459
self.assertEqual(dataset.main_table.name, "main_table")
460460
self.assertEqual(len(dataset.additional_data_tables), 4)
461461
dataset_secondary_table_names = {
462-
secondary_table.name for secondary_table in dataset.additional_data_tables
462+
secondary_table.name
463+
for _, secondary_table, _ in dataset.additional_data_tables
463464
}
464465
self.assertEqual(dataset_secondary_table_names, {"B", "C", "D", "E"})
465-
self.assertEqual(len(dataset.relations), 4)
466466

467467
table_specs = ds_spec["additional_data_tables"].items()
468-
for relation, (table_path, table_spec) in zip(dataset.relations, table_specs):
468+
for (ds_table_path, _, ds_is_one_to_one), (
469+
table_path,
470+
table_spec,
471+
) in zip(dataset.additional_data_tables, table_specs):
469472
# Each additional table entry holds the table's data path, as given in the spec
470-
self.assertEqual(relation[1], table_path.split("/")[-1])
473+
self.assertEqual(ds_table_path, table_path)
471474
if len(table_spec) == 3:
472-
self.assertEqual(relation[2], table_spec[2])
475+
self.assertEqual(ds_is_one_to_one, table_spec[2])
473476
else:
474-
self.assertFalse(relation[2])
477+
self.assertFalse(ds_is_one_to_one)
475478

476479
def test_out_file_from_dataframe_monotable(self):
477480
"""Test consistency of the created data file with the input dataframe
@@ -745,7 +748,9 @@ def _test_domain_coherence(self, ds, ref_var_types):
745748

746749
# Check that the domain has the same table names as the reference
747750
ref_table_names = {
748-
table.name for table in [ds.main_table] + ds.additional_data_tables
751+
table.name
752+
for table in [ds.main_table]
753+
+ [table for _, table, _ in ds.additional_data_tables]
749754
}
750755
out_table_names = {dictionary.name for dictionary in out_domain.dictionaries}
751756
self.assertEqual(ref_table_names, out_table_names)
@@ -758,7 +763,9 @@ def _test_domain_coherence(self, ds, ref_var_types):
758763
# Check that:
759764
# - the table keys are the same as the dataset
760765
# - the domain has the same variable names as the reference
761-
for table in [ds.main_table] + ds.additional_data_tables:
766+
for table in [ds.main_table] + [
767+
table for _, table, _ in ds.additional_data_tables
768+
]:
762769
with self.subTest(table=table.name):
763770
self.assertEqual(table.key, out_domain.get_dictionary(table.name).key)
764771
out_dictionary_var_types = {

0 commit comments

Comments
 (0)