diff --git a/docs/custom_json_encoder.md b/docs/custom_json_encoder.md new file mode 100644 index 00000000..e382fb07 --- /dev/null +++ b/docs/custom_json_encoder.md @@ -0,0 +1,26 @@ +# Custom Json Encoder + +flask-mongoengine have option to add custom encoder for flask +By this way you can handle encoding special object + +Examples: + +```python +from flask_mongoengine.json import MongoEngineJSONProvider +class CustomJSONEncoder(MongoEngineJSONProvider): + @staticmethod + def default(obj): + if isinstance(obj, set): + return list(obj) + if isinstance(obj, Decimal128): + return str(obj) + return MongoEngineJSONProvider.default(obj) + + +# Tell your flask app to use your customised JSON encoder + + +app.json_provider_class = CustomJSONEncoder +app.json = app.json_provider_class(app) + +``` diff --git a/docs/custom_queryset.md b/docs/custom_queryset.md index f011b79c..0abcbb1f 100644 --- a/docs/custom_queryset.md +++ b/docs/custom_queryset.md @@ -6,7 +6,8 @@ flask-mongoengine attaches the following methods to Mongoengine's default QueryS Optional arguments: *message* - custom message to display. * **first_or_404**: same as above, except for .first(). Optional arguments: *message* - custom message to display. -* **paginate**: paginates the QuerySet. Takes two arguments, *page* and *per_page*. +* **paginate**: paginates the QuerySet. Takes two required arguments, *page* and *per_page*. + And one optional arguments *max_depth*. * **paginate_field**: paginates a field from one document in the QuerySet. Arguments: *field_name*, *doc_id*, *page*, *per_page*. diff --git a/docs/index.rst b/docs/index.rst index 625bdb50..c305b72f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -15,6 +15,7 @@ You can also use `WTForms `_ as model forms for forms migration_to_v2 custom_queryset + custom_json_encoder wtf_forms session_interface debug_toolbar diff --git a/flask_mongoengine/json.py b/flask_mongoengine/json.py index 83637632..514bac6b 100644 --- a/flask_mongoengine/json.py +++ b/flask_mongoengine/json.py @@ -78,7 +78,7 @@ def default(obj): (BaseDocument, QuerySet, CommandCursor, DBRef, ObjectId), ): return _convert_mongo_objects(obj) - return super().default(obj) + return superclass.default(obj) return MongoEngineJSONProvider diff --git a/flask_mongoengine/pagination.py b/flask_mongoengine/pagination.py index 01b8b2ea..22be2a1a 100644 --- a/flask_mongoengine/pagination.py +++ b/flask_mongoengine/pagination.py @@ -8,7 +8,15 @@ class Pagination(object): - def __init__(self, iterable, page, per_page): + def __init__(self, iterable, page: int, per_page: int, max_depth: int = None): + """ + :param iterable: iterable object . + :param page: Required page number start from 1. + :param per_page: Required number of documents per page. + :param max_depth: Option for limit number of dereference documents. + + + """ if page < 1: abort(404) @@ -19,11 +27,11 @@ def __init__(self, iterable, page, per_page): if isinstance(self.iterable, QuerySet): self.total = iterable.count() - self.items = ( - self.iterable.skip(self.per_page * (self.page - 1)) - .limit(self.per_page) - .select_related() + self.items = self.iterable.skip(self.per_page * (self.page - 1)).limit( + self.per_page ) + if max_depth is not None: + self.items = self.items.select_related(max_depth) else: start_index = (page - 1) * per_page end_index = page * per_page diff --git a/flask_mongoengine/sessions.py b/flask_mongoengine/sessions.py index 756359ee..2a8df49d 100644 --- a/flask_mongoengine/sessions.py +++ b/flask_mongoengine/sessions.py @@ -56,7 +56,7 @@ def get_expiration_time(self, app, session) -> timedelta: return timedelta(**app.config.get("SESSION_TTL", {"days": 1})) def open_session(self, app, request): - sid = request.cookies.get(app.session_cookie_name) + sid = request.cookies.get(app.config["SESSION_COOKIE_NAME"]) if sid: stored_session = self.cls.objects(sid=sid).first() @@ -81,7 +81,7 @@ def save_session(self, app, session, response): # If the session is empty, return without setting the cookie. if not session: if session.modified: - response.delete_cookie(app.session_cookie_name, domain=domain) + response.delete_cookie(app.config["SESSION_COOKIE_NAME"], domain=domain) return expiration = datetime.utcnow().replace(tzinfo=utc) + self.get_expiration_time( @@ -92,7 +92,7 @@ def save_session(self, app, session, response): self.cls(sid=session.sid, data=session, expiration=expiration).save() response.set_cookie( - app.session_cookie_name, + app.config["SESSION_COOKIE_NAME"], session.sid, expires=expiration, httponly=httponly, diff --git a/tests/test_json.py b/tests/test_json.py index 73b7de28..75623cfd 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -41,7 +41,7 @@ def extended_db(app): test_db.connection["default"].drop_database(db_name) -class DummyEncoder(flask.json.JSONEncoder): +class DummyEncoder(flask.json._json.JSONEncoder): """ An example encoder which a user may create and override the apps json_encoder with. diff --git a/tests/test_pagination.py b/tests/test_pagination.py index e9d3f4fe..668093b5 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,14 +1,29 @@ +import flask import pytest from werkzeug.exceptions import NotFound - from flask_mongoengine import ListFieldPagination, Pagination -def test_queryset_paginator(app, todo): +@pytest.fixture(autouse=True) +def setup_endpoints(app, todo): Todo = todo for i in range(42): Todo(title=f"post: {i}").save() + @app.route("/") + def index(): + page = int(flask.request.form.get("page")) + per_page = int(flask.request.form.get("per_page")) + query_set = Todo.objects().paginate(page=page, per_page=per_page) + return {'data': [_ for _ in query_set.items], + 'total': query_set.total, + 'has_next': query_set.has_next, + } + + +def test_queryset_paginator(app, todo): + Todo = todo + with pytest.raises(NotFound): Pagination(iterable=Todo.objects, page=0, per_page=10) @@ -90,3 +105,25 @@ def _test_paginator(paginator): # Paginate to the next page if i < 5: paginator = paginator.next() + + +def test_flask_pagination(app, todo): + client = app.test_client() + response = client.get(f"/", data={"page": 0, "per_page": 10}) + print(response.status_code) + assert response.status_code == 404 + + response = client.get(f"/", data={"page": 6, "per_page": 10}) + print(response.status_code) + assert response.status_code == 404 + + +def test_flask_pagination_next(app, todo): + client = app.test_client() + has_next = True + page = 1 + while has_next: + response = client.get(f"/", data={"page": page, "per_page": 10}) + assert response.status_code == 200 + has_next = response.json['has_next'] + page += 1