Skip to content

Commit 99cd99e

Browse files
committed
feat: Support parameters
This commit allows users to pass some types of Python values as parameters. Supported types are as follows. * Basic types * string * integer * `float` * `bool`(`True`/`False`) * `None` * Collection * `dict` -- only basic types can be key * `list`/`tuple`
1 parent 64c8b45 commit 99cd99e

File tree

4 files changed

+235
-2
lines changed

4 files changed

+235
-2
lines changed

agensgraph/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from agensgraph._vertex import Vertex, cast_vertex as _cast_vertex
2222
from agensgraph._edge import Edge, cast_edge as _cast_edge
2323
from agensgraph._graphpath import Path, cast_graphpath as _cast_graphpath
24+
from agensgraph._property import Property
2425

2526
_GRAPHID_OID = 7002
2627
_VERTEX_OID = 7012
@@ -40,5 +41,5 @@
4041
PATH = _ext.new_type((_GRAPHPATH_OID,), 'PATH', _cast_graphpath)
4142
_ext.register_type(PATH)
4243

43-
__all__ = ['GraphId', 'Vertex', 'Edge', 'Path',
44+
__all__ = ['GraphId', 'Vertex', 'Edge', 'Path', 'Property',
4445
'GRAPHID', 'VERTEX', 'EDGE', 'PATH']

agensgraph/_property.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
'''
2+
Copyright (c) 2014-2018, Bitnine Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
'''
16+
17+
import sys
18+
19+
from psycopg2.extensions import ISQLQuote
20+
from psycopg2.extras import json
21+
22+
# borrowed from simplejson's compat.py
23+
if sys.version_info[0] < 3:
24+
string_types = (basestring,)
25+
integer_types = (int, long)
26+
def dict_items(o):
27+
return o.iteritems()
28+
else:
29+
string_types = (str,)
30+
integer_types = (int,)
31+
def dict_items(o):
32+
return o.items()
33+
34+
def quote_string(s):
35+
s = s[1:-1]
36+
s = "'" + s.replace("'", "''") + "'"
37+
return s
38+
39+
class PropertyEncoder(object):
40+
def encode(self, o):
41+
chunks = self.iterencode(o)
42+
if not isinstance(chunks, (list, tuple)):
43+
chunks = list(chunks)
44+
return ''.join(chunks)
45+
46+
def iterencode(self, o):
47+
markers = {}
48+
_iterencode = _make_iterencode(markers, json.dumps, quote_string)
49+
return _iterencode(o)
50+
51+
def _make_iterencode(markers, _encoder, _quote_string,
52+
dict=dict,
53+
float=float,
54+
id=id,
55+
isinstance=isinstance,
56+
list=list,
57+
tuple=tuple,
58+
string_types=string_types,
59+
integer_types=integer_types,
60+
dict_items=dict_items):
61+
def _iterencode_list(o):
62+
if not o:
63+
yield '[]'
64+
return
65+
66+
markerid = id(o)
67+
if markerid in markers:
68+
raise ValueError('Circular reference detected')
69+
markers[markerid] = o
70+
71+
yield '['
72+
first = True
73+
for e in o:
74+
if first:
75+
first = False
76+
else:
77+
yield ','
78+
79+
for chunk in _iterencode(e):
80+
yield chunk
81+
yield ']'
82+
83+
del markers[markerid]
84+
85+
def _iterencode_dict(o):
86+
if not o:
87+
yield '{}'
88+
return
89+
90+
markerid = id(o)
91+
if markerid in markers:
92+
raise ValueError('Circular reference detected')
93+
markers[markerid] = o
94+
95+
yield '{'
96+
first = True
97+
for k, v in dict_items(o):
98+
if isinstance(k, string_types):
99+
pass
100+
elif (k is True or k is False or k is None or
101+
isinstance(k, integer_types) or isinstance(k, float)):
102+
k = _encoder(k)
103+
else:
104+
raise TypeError('keys must be str, int, float, bool or None, '
105+
'not %s' % k.__class__.__name__)
106+
107+
if first:
108+
first = False
109+
else:
110+
yield ','
111+
112+
yield _quote_string(_encoder(k))
113+
yield ':'
114+
for chunk in _iterencode(v):
115+
yield chunk
116+
yield '}'
117+
118+
del markers[markerid]
119+
120+
def _iterencode(o):
121+
if isinstance(o, string_types):
122+
yield _quote_string(_encoder(o))
123+
elif isinstance(o, (list, tuple)):
124+
for chunk in _iterencode_list(o):
125+
yield chunk
126+
elif isinstance(o, dict):
127+
for chunk in _iterencode_dict(o):
128+
yield chunk
129+
else:
130+
yield _encoder(o)
131+
132+
return _iterencode
133+
134+
_default_encoder = PropertyEncoder()
135+
136+
class Property(object):
137+
def __init__(self, value):
138+
self.value = value
139+
140+
def __conform__(self, proto):
141+
if proto is ISQLQuote:
142+
return self
143+
144+
def prepare(self, conn):
145+
self._conn = conn
146+
147+
def getquoted(self):
148+
return _default_encoder.encode(self.value)

tests/test_agensgraph.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
'''
2-
Copyright (c) 2014-2017, Bitnine Inc.
2+
Copyright (c) 2014-2018, Bitnine Inc.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -124,5 +124,36 @@ def test_path(self):
124124
(e.eid,))
125125
self.assertEqual(1, self.cur.fetchone()[0])
126126

127+
class TestParam(TestConnection):
128+
def setUp(self):
129+
super(TestParam, self).setUp()
130+
self.name = "'Agens\"Graph'"
131+
132+
def test_param_dict(self):
133+
d = {'name': self.name, 'since': 2016}
134+
p = agensgraph.Property(d)
135+
self.cur.execute('CREATE (n %s) RETURN n', (p,))
136+
self.conn.commit()
137+
138+
v = self.cur.fetchone()[0]
139+
self.assertEqual(self.name, v.props['name'])
140+
self.assertEqual(2016, v.props['since'])
141+
142+
def test_param_list_and_tuple(self):
143+
a = [self.name, 2016]
144+
t = (self.name, 2016)
145+
pa = agensgraph.Property(a)
146+
pt = agensgraph.Property(t)
147+
self.cur.execute('CREATE (n {a: %s, t: %s}) RETURN n', (pa, pt))
148+
self.conn.commit()
149+
150+
v = self.cur.fetchone()[0]
151+
va = v.props['a']
152+
self.assertEqual(self.name, va[0])
153+
self.assertEqual(2016, va[1])
154+
vt = v.props['t']
155+
self.assertEqual(self.name, vt[0])
156+
self.assertEqual(2016, vt[1])
157+
127158
if __name__ == '__main__':
128159
unittest.main()

tests/test_property.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
'''
2+
Copyright (c) 2014-2018, Bitnine Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
'''
16+
17+
import unittest
18+
19+
from agensgraph._property import Property
20+
21+
from psycopg2.extensions import QuotedString, adapt
22+
23+
class TestProperty(unittest.TestCase):
24+
def test_string(self):
25+
self.assertEqual(r"'\"'", Property('"').getquoted())
26+
self.assertEqual(r"''''", Property("'").getquoted())
27+
28+
def test_number(self):
29+
self.assertEqual('0', Property(0).getquoted())
30+
self.assertEqual('-1', Property(-1).getquoted())
31+
self.assertEqual('3.14159', Property(3.14159).getquoted())
32+
33+
def test_boolean(self):
34+
self.assertEqual('true', Property(True).getquoted())
35+
self.assertEqual('false', Property(False).getquoted())
36+
37+
def test_null(self):
38+
self.assertEqual('null', Property(None).getquoted())
39+
40+
def test_array(self):
41+
a = ["'\\\"'", 3.14159, True, None, (), {}]
42+
e = "['''\\\\\\\"''',3.14159,true,null,[],{}]"
43+
self.assertEqual(e, Property(a).getquoted())
44+
45+
def test_object(self):
46+
self.assertEqual("{'\\\"':'\\\"'}", Property({'"': '"'}).getquoted())
47+
self.assertEqual("{'3.14159':3.14159}",
48+
Property({3.14159: 3.14159}).getquoted())
49+
self.assertEqual("{'true':false}", Property({True: False}).getquoted())
50+
self.assertEqual("{'null':null}", Property({None: None}).getquoted())
51+
self.assertEqual("{'a':[]}", Property({'a': []}).getquoted())
52+
self.assertEqual("{'o':{}}", Property({'o': {}}).getquoted())
53+
self.assertRaises(TypeError, Property({(): None}).getquoted)

0 commit comments

Comments
 (0)