Skip to content

Commit 2d519d1

Browse files
committed
fix: add support for simple casts
1 parent e033e8e commit 2d519d1

File tree

4 files changed

+49
-22
lines changed

4 files changed

+49
-22
lines changed

README.md

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,19 @@ rye add query-farm-sql-scan-planning
3131

3232
```python
3333
from query_farm_sql_scan_planning import Planner, RangeFieldInfo, SetFieldInfo
34+
import pyarrow as pa
3435

3536
# Define file metadata
3637
files = [
3738
(
3839
"data_2023_q1.parquet",
3940
{
40-
"sales_amount": RangeFieldInfo[int](
41-
min_value=100, max_value=50000,
41+
"sales_amount": RangeFieldInfo(
42+
min_value=pa.scalar(100), max_value=pa.scalar(50000),
4243
has_nulls=False, has_non_nulls=True
4344
),
4445
"region": SetFieldInfo[str](
45-
values={"US", "CA", "MX"},
46+
values={pa.scalar("US"), pa.scalar("CA"), pa.scalar("MX")},
4647
has_nulls=False, has_non_nulls=True
4748
),
4849
}
@@ -51,11 +52,11 @@ files = [
5152
"data_2023_q2.parquet",
5253
{
5354
"sales_amount": RangeFieldInfo[int](
54-
min_value=200, max_value=75000,
55+
min_value=pa.scalar(200), max_value=pa.scalar(75000),
5556
has_nulls=False, has_non_nulls=True
5657
),
5758
"region": SetFieldInfo[str](
58-
values={"US", "EU", "UK"},
59+
values={pa.scalar("US"), pa.scalar("EU"), pa.scalar("UK")},
5960
has_nulls=False, has_non_nulls=True
6061
),
6162
}
@@ -81,9 +82,9 @@ print(matching_files) # {'data_2023_q2.parquet'}
8182
For fields with known minimum and maximum values:
8283

8384
```python
84-
RangeFieldInfo[int](
85-
min_value=0,
86-
max_value=100,
85+
RangeFieldInfo(
86+
min_value=pa.scalar(0),
87+
max_value=pa.scalar(100),
8788
has_nulls=False, # Whether the field contains NULL values
8889
has_non_nulls=True # Whether the field contains non-NULL values
8990
)
@@ -94,8 +95,8 @@ RangeFieldInfo[int](
9495
For fields with a known set of possible values (useful for categorical data):
9596

9697
```python
97-
SetFieldInfo[str](
98-
values={"apple", "banana", "cherry"},
98+
SetFieldInfo(
99+
values={pa.scalar("apple"), pa.scalar("banana"), pa.scalar("cherry")},
99100
has_nulls=False,
100101
has_non_nulls=True
101102
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "query-farm-sql-scan-planning"
3-
version = "0.1.4"
3+
version = "0.1.5"
44
description = "A Python library for intelligent file filtering using SQL expressions and metadata-based scan planning. This library enables efficient data lake query optimization by determining which files need to be scanned based on their statistical metadata."
55
authors = [
66
{ name = "Rusty Conover", email = "rusty@conover.me" }

src/query_farm_sql_scan_planning/planner.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from typing import Any, Generator
44
import duckdb
55
import pyarrow as pa
6+
import pyarrow.compute as pc
67
import sqlglot
78
import sqlglot.expressions
8-
import sqlglot.optimizer.simplify
9+
import sqlglot.optimizer
910

1011

1112
@dataclass
@@ -139,9 +140,20 @@ def _eval_predicate(
139140
if isinstance(node, sqlglot.expressions.Is):
140141
return self._evaluate_node_is(node, file_info)
141142

142-
# Handle comparison operations
143-
if not isinstance(node.left, sqlglot.expressions.Column):
143+
# So if the left is just a cast of an column ref we can handle that
144+
# because the right hand side will also be that way.
145+
146+
left_is_column = isinstance(node.left, sqlglot.expressions.Column)
147+
left_is_cast = isinstance(node.left, sqlglot.expressions.Cast)
148+
need_left_value_cast = False
149+
if left_is_cast and isinstance(node.left.this, sqlglot.expressions.Column):
150+
# If the left side is a cast of a column, we can treat it as a column reference.
151+
need_left_value_cast = True
152+
left_column_name = node.left.this.this.this
153+
elif not left_is_column:
144154
return None
155+
else:
156+
left_column_name = node.left.this.this
145157

146158
if node.right.find(sqlglot.expressions.Column) is not None:
147159
# Can't evaluate this since it has a right hand column ref, ideally
@@ -165,14 +177,7 @@ def _eval_predicate(
165177
if type(right_val) is pa.Int32Scalar and right_val.as_py() is None:
166178
right_val = pa.scalar(None, type=pa.null())
167179

168-
left_val = node.left
169-
assert isinstance(left_val, sqlglot.expressions.Column), (
170-
f"Expected a column on left side of {node}, got {left_val}"
171-
)
172-
assert isinstance(left_val.this, sqlglot.expressions.Identifier), (
173-
f"Expected an identifier on left side of {node}, got {left_val.this}"
174-
)
175-
referenced_field_name = left_val.this.this
180+
referenced_field_name = left_column_name
176181

177182
field_info = file_info.get(referenced_field_name)
178183

@@ -181,6 +186,22 @@ def _eval_predicate(
181186
if field_info is None:
182187
return None
183188

189+
if need_left_value_cast:
190+
if not isinstance(field_info, RangeFieldInfo):
191+
# If we need a value cast but the field info is not a range,
192+
# we can't evaluate this expression.
193+
return None
194+
field_info = RangeFieldInfo(
195+
has_nulls=field_info.has_nulls,
196+
has_non_nulls=field_info.has_non_nulls,
197+
min_value=pc.cast(field_info.min_value, right_val.type)
198+
if field_info.min_value is not None
199+
else None,
200+
max_value=pc.cast(field_info.max_value, right_val.type)
201+
if field_info.max_value is not None
202+
else None,
203+
)
204+
184205
if isinstance(field_info, SetFieldInfo):
185206
match type(node):
186207
case sqlglot.expressions.EQ:

src/query_farm_sql_scan_planning/test_planner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,12 @@ def sample_files() -> list[tuple[str, FileFieldInfo]]:
165165
@pytest.mark.parametrize(
166166
"clause, expected_files",
167167
[
168+
("t1::date >= '2030-01-01'", set()),
169+
("t1::date = '2023-08-01'", {"file1"}),
168170
("t1 = DATE '2023-08-01'", {"file1"}),
169171
("t1 > DATE '2023-08-01'", {"file1"}),
172+
("t1::date > '2023-08-01'", {"file1"}),
173+
("cast(t1 as timestamp) > TIMESTAMP '2023-08-01'", {"file1"}),
170174
("t1 <> DATE '2023-08-01'", {"file1"}),
171175
("t1 <> DATE '2023-08-01' - interval '6 days'", {"file1"}),
172176
# This isn't possible, to evaluate, we need to check for additional
@@ -186,6 +190,7 @@ def sample_files() -> list[tuple[str, FileFieldInfo]]:
186190
),
187191
("'apple' in (d1)", ALL_FILES), # could be improved.
188192
("v1 < 100 and d1 = 'apple'", {"file1"}),
193+
("v1::uhugeint * 5 > 400", ALL_FILES),
189194
("v1 > 500 and v1 < 600", {"file4", "file5"}),
190195
("v1 != 500 and v1 < 400", {"file1", "file2", "file3"}),
191196
("v1 >= 300 and v1 <= 500", {"file2", "file3", "file4", "file7"}),

0 commit comments

Comments
 (0)