Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Case insensitivity #113

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
language: python
python:
- "2.7"
install:
- "pip install -r requirements.txt --use-mirrors"
- "pip install coverage"
- "pip install coveralls"
script:
- "coverage run --source=sandman setup.py test"
install:
- python setup.py develop
- pip install coverage
- pip install coveralls
script:
- coverage run --source=sandman setup.py test
after_success:
coveralls
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
recursive-include sandman/static *
recursive-include sandman/templates *
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ itsdangerous==0.24
wsgiref==0.1.2
click==0.7
sphinx-rtd-theme==0.1.6
six==1.9.0
15 changes: 10 additions & 5 deletions sandman/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from flask import current_app, g
from flask.ext.admin import Admin
from flask.ext.admin.contrib.sqla import ModelView
from flask.ext.sqlalchemy import SQLAlchemy
from sqlalchemy.engine import reflection
from sqlalchemy.ext.declarative import declarative_base, DeferredReflection
from sqlalchemy.orm import relationship
Expand All @@ -22,7 +23,7 @@ def _get_session():
return session


def generate_endpoint_classes(db, generate_pks=False):
def generate_endpoint_classes(db, generate_pks=False, base=None):
"""Return a list of model classes generated for each reflected database
table."""
seen_classes = set()
Expand All @@ -38,7 +39,7 @@ def generate_endpoint_classes(db, generate_pks=False):
else:
cls = type(
str(name),
(sandman_model, db.Model),
(base or sandman_model, db.Model),
{'__tablename__': name})
register(cls)

Expand Down Expand Up @@ -160,7 +161,7 @@ def register_classes_for_admin(db_session, show_pks=True, name='admin'):
admin_view.add_view(admin_view_class(cls, db_session))


def activate(admin=True, browser=True, name='admin', reflect_all=False):
def activate(admin=True, browser=True, name='admin', reflect_all=False, base=None):
"""Activate each pre-registered model or generate the model classes and
(possibly) register them for the admin.

Expand All @@ -170,13 +171,14 @@ def activate(admin=True, browser=True, name='admin', reflect_all=False):
this to avoid naming conflicts with other blueprints (if
trying to use sandman to connect to multiple databases
simultaneously)
:param base: Optional base model class; defaults to `model.Model`

"""
with app.app_context():
generate_pks = app.config.get('SANDMAN_GENERATE_PKS', None) or False
if getattr(app, 'class_references', None) is None or reflect_all:
app.class_references = collections.OrderedDict()
generate_endpoint_classes(db, generate_pks)
generate_endpoint_classes(db, generate_pks, base=base)
else:
Model.prepare(db.engine)
prepare_relationships(db, current_app.class_references)
Expand All @@ -199,4 +201,7 @@ def activate(admin=True, browser=True, name='admin', reflect_all=False):
# actually the same thing.

sandman_model = Model
Model = declarative_base(cls=(Model, DeferredReflection))
# Model = declarative_base(cls=(Model, DeferredReflection))

class Model(Model, DeferredReflection, db.Model):
__abstract__ = True
92 changes: 51 additions & 41 deletions sandman/sandman.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Sandman REST API creator for Flask and SQLAlchemy"""

import six
from flask import (
jsonify,
request,
current_app,
Response,
render_template,
make_response)
from sqlalchemy.exc import IntegrityError
import sqlalchemy as sa
from . import app
from .decorators import etag, no_cache
from .exception import InvalidAPIUsage
Expand Down Expand Up @@ -54,6 +55,20 @@ def _get_acceptable_response_type():
raise InvalidAPIUsage(406)


def _get_column(model, key):
try:
return getattr(model, key)
except AttributeError:
raise InvalidAPIUsage(422)


def _column_type(attribute):
columns = attribute.property.columns
if len(columns) == 1:
return columns[0].type.python_type
return None


@app.errorhandler(InvalidAPIUsage)
def handle_exception(error):
"""Return a response with the appropriate status code, message, and content
Expand Down Expand Up @@ -135,34 +150,28 @@ def _single_resource_html_response(resource):
tablename=tablename))


def _collection_json_response(cls, resources, start, stop, depth=0):
def _collection_json_response(cls, resources, depth=0):
"""Return the JSON representation of the collection *resources*.

:param list resources: list of :class:`sandman.model.Model`s to render
:rtype: :class:`flask.Response`

"""

top_level_json_name = None
if cls.__top_level_json_name__ is not None:
top_level_json_name = cls.__top_level_json_name__
else:
top_level_json_name = 'resources'

result_list = []
for resource in resources:
result_list.append(resource.as_dict(depth))

payload = {}
if start is not None:
payload[top_level_json_name] = result_list[start:stop]
else:
payload[top_level_json_name] = result_list
resource_key = cls.__top_level_json_name__ or 'resources'

payload = {
resource_key: [each.as_dict(depth) for each in resources.items],
'pagination': {
'page': resources.page,
'per_page': resources.per_page,
'count': resources.total,
}
}
return jsonify(payload)


def _collection_html_response(resources, start=0, stop=20):
def _collection_html_response(resources):
"""Return the HTML representation of the collection *resources*.

:param list resources: list of :class:`sandman.model.Model`s to render
Expand All @@ -171,7 +180,7 @@ def _collection_html_response(resources, start=0, stop=20):
"""
return make_response(render_template(
'collection.html',
resources=resources[start:stop]))
resources=resources.items))


def _validate(cls, method, resource=None):
Expand All @@ -188,7 +197,7 @@ def _validate(cls, method, resource=None):

"""
if method not in cls.__methods__:
raise InvalidAPIUsage(403, FORBIDDEN_EXCEPTION_MESSAGE.format(
raise InvalidAPIUsage(405, FORBIDDEN_EXCEPTION_MESSAGE.format(
method,
cls.endpoint(), cls.__methods__))

Expand Down Expand Up @@ -244,27 +253,28 @@ def retrieve_collection(collection, query_arguments=None):
:rtype: class:`sandman.model.Model`

"""
session = _get_session()
cls = endpoint_class(collection)
if query_arguments:
filters = []
order = []
limit = None
for key, value in query_arguments.items():
if key == 'page':
if key in ['page', 'limit']:
continue
if value.startswith('%'):
filters.append(getattr(cls, key).like(str(value), escape='/'))
filters.append(_get_column(cls, key).like(str(value), escape='/'))
elif key == 'sort':
order.append(getattr(cls, value))
order.append(_get_column(cls, value))
elif key == 'limit':
limit = value
elif key:
filters.append(getattr(cls, key) == value)
resources = session.query(cls).filter(*filters).order_by(
*order).limit(limit)
column = _get_column(cls, key)
if app.config.get('CASE_INSENSITIVE') and issubclass(_column_type(column), six.string_types):
filters.append(sa.func.upper(column) == value.upper())
else:
filters.append(column == value)
resources = cls.query.filter(*filters).order_by(*order)
else:
resources = session.query(cls).all()
resources = cls.query
return resources


Expand Down Expand Up @@ -303,7 +313,7 @@ def resource_created_response(resource):
return response


def collection_response(cls, resources, start=None, stop=None):
def collection_response(cls, resources):
"""Return a response for the *resources* of the appropriate content type.

:param resources: resources to be returned in request
Expand All @@ -312,9 +322,9 @@ def collection_response(cls, resources, start=None, stop=None):

"""
if _get_acceptable_response_type() == JSON:
return _collection_json_response(cls, resources, start, stop)
return _collection_json_response(cls, resources)
else:
return _collection_html_response(resources, start, stop)
return _collection_html_response(resources)


def resource_response(resource, depth=0):
Expand Down Expand Up @@ -428,7 +438,7 @@ def put_resource(collection, key):
resource.replace(get_resource_data(request))
try:
_perform_database_action('add', resource)
except IntegrityError as exception:
except sa.exc.IntegrityError as exception:
raise InvalidAPIUsage(422, FORWARDED_EXCEPTION_MESSAGE.format(
exception))
return no_content_response()
Expand Down Expand Up @@ -471,7 +481,7 @@ def delete_resource(collection, key):

try:
_perform_database_action('delete', resource)
except IntegrityError as exception:
except sa.exc.IntegrityError as exception:
raise InvalidAPIUsage(422, FORWARDED_EXCEPTION_MESSAGE.format(
exception))
return no_content_response()
Expand Down Expand Up @@ -530,13 +540,13 @@ def get_collection(collection):

_validate(cls, request.method, resources)

start = stop = None

if request.args and 'page' in request.args:
page = int(request.args['page'])
results_per_page = app.config.get('RESULTS_PER_PAGE', 20)
start, stop = page * results_per_page, (page + 1) * results_per_page
return collection_response(cls, resources, start, stop)
try:
page = int(request.args.get('page', 1))
except (TypeError, ValueError):
raise InvalidAPIUsage(422)
per_page = app.config.get('RESULTS_PER_PAGE', 20)
resources = resources.paginate(page, per_page, error_out=False)
return collection_response(cls, resources)


@app.route('/', methods=['GET'])
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def run_tests(self):
'Flask-HTTPAuth>=2.2.1',
'docopt>=0.6.1',
'click',
'six',
#'sphinx-rtd-theme',
],
cmdclass={'test': PyTest},
Expand Down
5 changes: 3 additions & 2 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Models for unit testing sandman"""

from flask.ext.sqlalchemy import BaseQuery
from flask.ext.admin.contrib.sqla import ModelView
from sandman.model import register, Model, activate
from sandman.model.models import db
Expand Down Expand Up @@ -90,11 +91,11 @@ def validate_GET(resource=None):

"""

if isinstance(resource, list):
if isinstance(resource, BaseQuery):
return True
elif resource and resource.GenreId == 1:
return False
return True

register((Artist, Album, Playlist, Track, MediaType, Style, SomeModel))
activate(browser=True)
activate(browser=False)
Loading