Skip to content

Commit 6d80006

Browse files
committed
Add deprecation path for the hierachical schema specification in sklearn
1 parent 83b3bfb commit 6d80006

File tree

2 files changed

+170
-8
lines changed

2 files changed

+170
-8
lines changed

khiops/sklearn/dataset.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
import khiops.core as kh
2222
import khiops.core.internals.filesystems as fs
2323
from khiops.core.dictionary import VariableBlock
24-
from khiops.core.internals.common import is_dict_like, is_list_like, type_error_message
24+
from khiops.core.internals.common import (
25+
deprecation_message,
26+
is_dict_like,
27+
is_list_like,
28+
type_error_message,
29+
)
2530

2631
# Disable PEP8 variable names because of scikit-learn X,y conventions
2732
# To capture invalid-names other than X,y run:
@@ -171,6 +176,54 @@ def _check_multitable_spec(ds_spec):
171176
)
172177

173178

179+
def _table_name_of_path(table_path):
180+
return table_path.split("/")[-1]
181+
182+
183+
def _upgrade_mapping_spec(ds_spec):
184+
assert is_dict_like(ds_spec)
185+
new_ds_spec = {}
186+
new_ds_spec["additional_data_tables"] = {}
187+
for table_name, table_data in ds_spec["tables"].items():
188+
table_df, table_key = table_data
189+
if not is_list_like(table_key):
190+
table_key = [table_key]
191+
if table_name == ds_spec["main_table"]:
192+
new_ds_spec["main_table"] = (table_df, table_key)
193+
else:
194+
table_path = [table_name]
195+
is_entity = False
196+
197+
# Cycle 4 times on the relations to get all transitive relation, like:
198+
# - current table name N
199+
# - main table name N1
200+
# - and relations: (N1, N2), (N2, N3), (N3, N)
201+
# the data-path must be N2/N3/N
202+
# Note: this is a heuristic that should be replaced with a graph
203+
# traversal procedure
204+
# If no "relations" key exists, then one has a star schema and
205+
# the data-paths are the names of the secondary tables themselves
206+
# (with respect to the main table)
207+
if "relations" in ds_spec:
208+
for relation in list(ds_spec["relations"]) * 4:
209+
left, right = relation[:2]
210+
if len(relation) == 3 and right == table_name:
211+
is_entity = relation[2]
212+
if (
213+
left != ds_spec["main_table"]
214+
and left not in table_path
215+
and right in table_path
216+
):
217+
table_path.insert(0, left)
218+
table_path = "/".join(table_path)
219+
if is_entity:
220+
table_data = (table_df, table_key, is_entity)
221+
else:
222+
table_data = (table_df, table_key)
223+
new_ds_spec["additional_data_tables"][table_path] = table_data
224+
return new_ds_spec
225+
226+
174227
def get_khiops_type(numpy_type):
175228
"""Translates a numpy dtype to a Khiops dictionary type
176229
@@ -426,14 +479,26 @@ def _check_input_sequence(self, X, key=None):
426479
# Check the key for the main_table (it is the same for the others)
427480
_check_table_key("main_table", key)
428481

429-
def _table_name_of_path(self, table_path):
430-
# TODO: Add >= 128-character truncation and indexing scheme
431-
return table_path.split("/")[-1]
432-
433482
def _init_tables_from_mapping(self, X):
434483
"""Initializes the table spec from a dict-like 'X'"""
435484
assert is_dict_like(X), "'X' must be dict-like"
436485

486+
# Detect if deprecated mapping specification syntax is used;
487+
# if so, issue deprecation warning and transform it to the new syntax
488+
if "tables" in X.keys() and isinstance(X.get("main_table"), str):
489+
warnings.warn(
490+
deprecation_message(
491+
"This multi-table dataset specification format",
492+
"11.0.1",
493+
replacement=(
494+
"the new data-path-based format, as documented in "
495+
":doc:`multi_table_primer`."
496+
),
497+
quote=False,
498+
)
499+
)
500+
X = _upgrade_mapping_spec(X)
501+
437502
# Check the input mapping
438503
check_dataset_spec(X)
439504

@@ -452,7 +517,7 @@ def _init_tables_from_mapping(self, X):
452517
if "additional_data_tables" in X:
453518
for table_path, table_spec in X["additional_data_tables"].items():
454519
table_source, table_key = table_spec[:2]
455-
table_name = self._table_name_of_path(table_path)
520+
table_name = _table_name_of_path(table_path)
456521
table = PandasTable(
457522
table_name,
458523
table_source,
@@ -469,7 +534,7 @@ def _init_tables_from_mapping(self, X):
469534
parent_table_name = self.main_table.name
470535
else:
471536
table_path_fragments = table_path.split("/")
472-
parent_table_name = self._table_name_of_path(
537+
parent_table_name = _table_name_of_path(
473538
"/".join(table_path_fragments[:-1])
474539
)
475540
self.relations.append(

tests/test_dataset_class.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import shutil
99
import unittest
10+
import warnings
1011

1112
import numpy as np
1213
import pandas as pd
@@ -15,7 +16,7 @@
1516
from pandas.testing import assert_frame_equal
1617
from sklearn import datasets
1718

18-
from khiops.sklearn.dataset import Dataset
19+
from khiops.sklearn.dataset import Dataset, _upgrade_mapping_spec
1920

2021

2122
class DatasetInputOutputConsistencyTests(unittest.TestCase):
@@ -352,6 +353,102 @@ def get_ref_var_types(self, multitable, schema=None):
352353

353354
return ref_var_types
354355

356+
def test_dataset_of_deprecated_mt_mapping(self):
357+
"""Test deprecated multi-table specification handling"""
358+
(
359+
ref_main_table,
360+
ref_secondary_table_1,
361+
ref_secondary_table_2,
362+
ref_tertiary_table,
363+
ref_quaternary_table,
364+
) = self.create_multitable_snowflake_dataframes()
365+
366+
features_ref_main_table = ref_main_table.drop("class", axis=1)
367+
expected_ds_spec = {
368+
"main_table": (features_ref_main_table, ["User_ID"]),
369+
"additional_data_tables": {
370+
"B": (ref_secondary_table_1, ["User_ID", "VAR_1"], False),
371+
"B/D": (ref_tertiary_table, ["User_ID", "VAR_1", "VAR_2"], False),
372+
"B/D/E": (
373+
ref_quaternary_table,
374+
["User_ID", "VAR_1", "VAR_2", "VAR_3"],
375+
),
376+
"C": (ref_secondary_table_2, ["User_ID"], True),
377+
},
378+
}
379+
deprecated_ds_spec = {
380+
"main_table": "A",
381+
"tables": {
382+
"A": (features_ref_main_table, "User_ID"),
383+
"B": (ref_secondary_table_1, ["User_ID", "VAR_1"]),
384+
"C": (ref_secondary_table_2, "User_ID"),
385+
"D": (ref_tertiary_table, ["User_ID", "VAR_1", "VAR_2"]),
386+
"E": (
387+
ref_quaternary_table,
388+
["User_ID", "VAR_1", "VAR_2", "VAR_3"],
389+
),
390+
},
391+
"relations": {
392+
("A", "B", False),
393+
("B", "D", False),
394+
("D", "E"),
395+
("A", "C", True),
396+
},
397+
}
398+
399+
label = ref_main_table["class"]
400+
401+
# Test that deprecation warning is issued when creating a dataset
402+
# according to the deprecated spec
403+
with warnings.catch_warnings(record=True) as warning_list:
404+
_ = Dataset(deprecated_ds_spec, label)
405+
self.assertTrue(len(warning_list) > 0)
406+
deprecation_warning_found = False
407+
for warning in warning_list:
408+
warning_message = warning.message
409+
if (
410+
issubclass(warning.category, UserWarning)
411+
and len(warning_message.args) == 1
412+
and "multi-table dataset specification format"
413+
in warning_message.args[0]
414+
and "deprecated" in warning_message.args[0]
415+
):
416+
deprecation_warning_found = True
417+
break
418+
self.assertTrue(deprecation_warning_found)
419+
420+
# Test that a deprecated dataset spec is upgraded to the new format
421+
ds_spec = _upgrade_mapping_spec(deprecated_ds_spec)
422+
self.assertEqual(ds_spec.keys(), expected_ds_spec.keys())
423+
main_table = ds_spec["main_table"]
424+
expected_main_table = expected_ds_spec["main_table"]
425+
426+
# Test that main table keys are identical
427+
self.assertEqual(main_table[1], expected_main_table[1])
428+
429+
# Test that main table data frame are equal
430+
assert_frame_equal(main_table[0], expected_main_table[0])
431+
432+
# Test that additional data tables keys are identical
433+
additional_data_tables = ds_spec["additional_data_tables"]
434+
expected_additional_data_tables = expected_ds_spec["additional_data_tables"]
435+
self.assertEqual(
436+
additional_data_tables.keys(), expected_additional_data_tables.keys()
437+
)
438+
439+
for table_path, expected_table_data in expected_additional_data_tables.items():
440+
table_data = additional_data_tables[table_path]
441+
442+
# Test that secondary table keys are identical
443+
self.assertEqual(table_data[1], expected_table_data[1])
444+
445+
# Test that the secondary table data frames are identical
446+
assert_frame_equal(table_data[0], expected_table_data[0])
447+
448+
# Test that the secondary table entity statuses are identical if True
449+
if len(expected_table_data) > 2 and expected_table_data[2] is True:
450+
self.assertEqual(table_data[2], expected_table_data[2])
451+
355452
def test_dataset_is_correctly_built(self):
356453
"""Test that the dataset structure is consistent with the input spec"""
357454
ds_spec, label = self.create_fixture_ds_spec(

0 commit comments

Comments
 (0)