Skip to content

Commit 67cb0c4

Browse files
authored
Add bool operators (#194)
## Summary - Add boolean equality operators (`==` and `\!=`) - Add full set of boolean operations (`and`, `or`, `not`) ## Changes This PR implements boolean comparison and operation capabilities, enabling more expressive conditional logic in SPy code. ## Test plan - Run test suite to verify proper boolean operator functionality
2 parents f3848c1 + 233a324 commit 67cb0c4

File tree

7 files changed

+185
-7
lines changed

7 files changed

+185
-7
lines changed

spy/libspy/include/spy/operator.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,40 @@ static inline double spy_operator$f64_floordiv(double x, double y) {
6868
return floor(x / y);
6969
}
7070

71+
static inline bool spy_operator$bool_eq(bool x, bool y) {
72+
return x == y;
73+
}
74+
75+
static inline bool spy_operator$bool_ne(bool x, bool y) {
76+
return x != y;
77+
}
78+
79+
static inline bool spy_operator$bool_and(bool x, bool y) {
80+
return x && y;
81+
}
82+
83+
static inline bool spy_operator$bool_or(bool x, bool y) {
84+
return x || y;
85+
}
86+
87+
static inline bool spy_operator$bool_xor(bool x, bool y) {
88+
return x != y;
89+
}
90+
91+
static inline bool spy_operator$bool_lt(bool x, bool y) {
92+
return !x && y;
93+
}
94+
95+
static inline bool spy_operator$bool_le(bool x, bool y) {
96+
return !x || y;
97+
}
98+
99+
static inline bool spy_operator$bool_gt(bool x, bool y) {
100+
return x && !y;
101+
}
102+
103+
static inline bool spy_operator$bool_ge(bool x, bool y) {
104+
return x || !y;
105+
}
106+
71107
#endif /* SPY_OPERATOR_H */

spy/tests/compiler/test_basic.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,90 @@ def get_False() -> bool:
420420
""")
421421
assert mod.get_True() is True
422422
assert mod.get_False() is False
423+
424+
def test_bool_equality(self):
425+
mod = self.compile("""
426+
def eq_bool(a: bool, b: bool) -> bool:
427+
return a == b
428+
429+
def ne_bool(a: bool, b: bool) -> bool:
430+
return a != b
431+
""")
432+
assert mod.eq_bool(True, True) is True
433+
assert mod.eq_bool(True, False) is False
434+
assert mod.eq_bool(False, True) is False
435+
assert mod.eq_bool(False, False) is True
436+
437+
assert mod.ne_bool(True, True) is False
438+
assert mod.ne_bool(True, False) is True
439+
assert mod.ne_bool(False, True) is True
440+
assert mod.ne_bool(False, False) is False
441+
442+
def test_bool_operations(self):
443+
mod = self.compile("""
444+
def and_bool(a: bool, b: bool) -> bool:
445+
return a & b
446+
447+
def or_bool(a: bool, b: bool) -> bool:
448+
return a | b
449+
450+
def xor_bool(a: bool, b: bool) -> bool:
451+
return a ^ b
452+
453+
def lt_bool(a: bool, b: bool) -> bool:
454+
return a < b
455+
456+
def le_bool(a: bool, b: bool) -> bool:
457+
return a <= b
458+
459+
def gt_bool(a: bool, b: bool) -> bool:
460+
return a > b
461+
462+
def ge_bool(a: bool, b: bool) -> bool:
463+
return a >= b
464+
""")
465+
466+
# Test AND
467+
assert mod.and_bool(True, True) is True
468+
assert mod.and_bool(True, False) is False
469+
assert mod.and_bool(False, True) is False
470+
assert mod.and_bool(False, False) is False
471+
472+
# Test OR
473+
assert mod.or_bool(True, True) is True
474+
assert mod.or_bool(True, False) is True
475+
assert mod.or_bool(False, True) is True
476+
assert mod.or_bool(False, False) is False
477+
478+
# Test XOR
479+
assert mod.xor_bool(True, True) is False
480+
assert mod.xor_bool(True, False) is True
481+
assert mod.xor_bool(False, True) is True
482+
assert mod.xor_bool(False, False) is False
483+
484+
# Test <
485+
assert mod.lt_bool(True, True) is False
486+
assert mod.lt_bool(True, False) is False
487+
assert mod.lt_bool(False, True) is True
488+
assert mod.lt_bool(False, False) is False
489+
490+
# Test <=
491+
assert mod.le_bool(True, True) is True
492+
assert mod.le_bool(True, False) is False
493+
assert mod.le_bool(False, True) is True
494+
assert mod.le_bool(False, False) is True
495+
496+
# Test >
497+
assert mod.gt_bool(True, True) is False
498+
assert mod.gt_bool(True, False) is True
499+
assert mod.gt_bool(False, True) is False
500+
assert mod.gt_bool(False, False) is False
501+
502+
# Test >=
503+
assert mod.ge_bool(True, True) is True
504+
assert mod.ge_bool(True, False) is True
505+
assert mod.ge_bool(False, True) is False
506+
assert mod.ge_bool(False, False) is True
423507

424508
def test_CompareOp_error(self):
425509
src = """

spy/tests/compiler/test_typelift.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,15 @@ def wrong_meth(x: i32) -> i32:
9898
def test_if_inside_classdef(self):
9999
src = """
100100
@blue
101-
def make_Foo(K):
101+
def make_Foo(DOUBLE):
102102
@typelift
103103
class Foo:
104104
__ll__: i32
105105
106106
def __new__(i: i32) -> Foo:
107107
return Foo.__lift__(i)
108108
109-
if K == 2:
109+
if DOUBLE:
110110
def get(self: Foo) -> i32:
111111
return self.__ll__ * 2
112112
else:
@@ -116,13 +116,13 @@ def get(self: Foo) -> i32:
116116
return Foo
117117
118118
def test1(x: i32) -> i32:
119-
a = make_Foo(1)(x)
119+
a = make_Foo(True)(x)
120120
return a.get()
121121
122122
def test2(x: i32) -> i32:
123-
b = make_Foo(2)(x)
123+
b = make_Foo(False)(x)
124124
return b.get()
125125
"""
126126
mod = self.compile(src)
127-
assert mod.test1(10) == 10
128-
assert mod.test2(10) == 20
127+
assert mod.test1(10) == 20
128+
assert mod.test2(10) == 10

spy/tests/wasm_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, vm: SPyVM, ll: LLSPyInstance, c_name: str,
7676
self.w_functype = w_functype
7777

7878
def py2wasm(self, pyval: Any, w_type: W_Type) -> Any:
79-
if w_type in (B.w_i32, B.w_i8, B.w_u8, B.w_f64):
79+
if w_type in (B.w_i32, B.w_i8, B.w_u8, B.w_f64, B.w_bool):
8080
return pyval
8181
elif w_type is B.w_str:
8282
# XXX: with the GC, we need to think how to keep this alive

spy/vm/modules/operator/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def op_fast_call(vm: 'SPyVM', w_func: W_Func,
6565
from . import opimpl_str # noqa: F401 -- side effects
6666
from . import opimpl_object # noqa: F401 -- side effects
6767
from . import opimpl_dynamic # noqa: F401 -- side effects
68+
from . import opimpl_bool # noqa: F401 -- side effects
6869
from . import unaryop # noqa: F401 -- side effects
6970
from . import binop # noqa: F401 -- side effects
7071
from . import attrop # noqa: F401 -- side effects

spy/vm/modules/operator/binop.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,17 @@
111111
MM.register('==', 'str', 'str', OP.w_str_eq)
112112
MM.register('!=', 'str', 'str', OP.w_str_ne)
113113

114+
# bool ops
115+
MM.register('==', 'bool', 'bool', OP.w_bool_eq)
116+
MM.register('!=', 'bool', 'bool', OP.w_bool_ne)
117+
MM.register('&', 'bool', 'bool', OP.w_bool_and)
118+
MM.register('|', 'bool', 'bool', OP.w_bool_or)
119+
MM.register('^', 'bool', 'bool', OP.w_bool_xor)
120+
MM.register('<', 'bool', 'bool', OP.w_bool_lt)
121+
MM.register('<=', 'bool', 'bool', OP.w_bool_le)
122+
MM.register('>', 'bool', 'bool', OP.w_bool_gt)
123+
MM.register('>=', 'bool', 'bool', OP.w_bool_ge)
124+
114125
# dynamic ops
115126
MM.register_partial('+', 'dynamic', OP.w_dynamic_add)
116127
MM.register_partial('*', 'dynamic', OP.w_dynamic_mul)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import TYPE_CHECKING
2+
from spy.vm.primitive import W_Bool
3+
from . import OP
4+
5+
if TYPE_CHECKING:
6+
from spy.vm.vm import SPyVM
7+
8+
@OP.builtin_func('bool_eq')
9+
def w_bool_eq(vm: 'SPyVM', w_a: W_Bool, w_b: W_Bool) -> W_Bool:
10+
return vm.wrap(w_a.value == w_b.value) # type: ignore
11+
12+
@OP.builtin_func('bool_ne')
13+
def w_bool_ne(vm: 'SPyVM', w_a: W_Bool, w_b: W_Bool) -> W_Bool:
14+
return vm.wrap(w_a.value != w_b.value) # type: ignore
15+
16+
@OP.builtin_func('bool_and')
17+
def w_bool_and(vm: 'SPyVM', w_a: W_Bool, w_b: W_Bool) -> W_Bool:
18+
return vm.wrap(w_a.value and w_b.value) # type: ignore
19+
20+
@OP.builtin_func('bool_or')
21+
def w_bool_or(vm: 'SPyVM', w_a: W_Bool, w_b: W_Bool) -> W_Bool:
22+
return vm.wrap(w_a.value or w_b.value) # type: ignore
23+
24+
@OP.builtin_func('bool_xor')
25+
def w_bool_xor(vm: 'SPyVM', w_a: W_Bool, w_b: W_Bool) -> W_Bool:
26+
return vm.wrap(w_a.value != w_b.value) # type: ignore
27+
28+
@OP.builtin_func('bool_lt')
29+
def w_bool_lt(vm: 'SPyVM', w_a: W_Bool, w_b: W_Bool) -> W_Bool:
30+
# False < True but not True < False
31+
return vm.wrap(not w_a.value and w_b.value) # type: ignore
32+
33+
@OP.builtin_func('bool_le')
34+
def w_bool_le(vm: 'SPyVM', w_a: W_Bool, w_b: W_Bool) -> W_Bool:
35+
# False <= True and True <= True and False <= False
36+
return vm.wrap(not w_a.value or w_b.value) # type: ignore
37+
38+
@OP.builtin_func('bool_gt')
39+
def w_bool_gt(vm: 'SPyVM', w_a: W_Bool, w_b: W_Bool) -> W_Bool:
40+
# True > False but not False > True
41+
return vm.wrap(w_a.value and not w_b.value) # type: ignore
42+
43+
@OP.builtin_func('bool_ge')
44+
def w_bool_ge(vm: 'SPyVM', w_a: W_Bool, w_b: W_Bool) -> W_Bool:
45+
# True >= False and True >= True and False >= False
46+
return vm.wrap(w_a.value or not w_b.value) # type: ignore

0 commit comments

Comments
 (0)