diff --git a/src/ref_geo/commands.py b/src/ref_geo/commands.py index c59426b..005d9c6 100644 --- a/src/ref_geo/commands.py +++ b/src/ref_geo/commands.py @@ -1,6 +1,6 @@ import click from flask.cli import with_appcontext -from sqlalchemy import func, column +from sqlalchemy import func, select from ref_geo.env import db from ref_geo.models import BibAreasTypes, LAreas @@ -16,10 +16,10 @@ def ref_geo(): def info(): click.echo("RefGeo : nombre de zones par type") q = ( - db.session.query(BibAreasTypes, func.count(LAreas.id_area).label("count")) + select(BibAreasTypes, func.count(LAreas.id_area).label("count")) .join(LAreas) .group_by(BibAreasTypes.id_type) .order_by(BibAreasTypes.id_type) ) - for area_type, count in q.all(): + for area_type, count in db.session.scalars(q).unique().all(): click.echo("\t{}: {}".format(area_type.type_name, count)) diff --git a/src/ref_geo/routes.py b/src/ref_geo/routes.py index b32e349..73dec15 100644 --- a/src/ref_geo/routes.py +++ b/src/ref_geo/routes.py @@ -4,7 +4,7 @@ from flask import Blueprint, request, current_app from flask.json import jsonify import sqlalchemy as sa -from sqlalchemy import func, distinct, asc, desc +from sqlalchemy import func, select, asc, desc from sqlalchemy.sql import text from sqlalchemy.orm import joinedload, undefer from werkzeug.exceptions import BadRequest @@ -58,8 +58,10 @@ def getGeoInfo(): raise BadRequest("Missing 'geometry' in request payload") geojson = json.dumps(geojson) - areas = LAreas.query.filter_by(enable=True).filter( - geojson_intersect_filter.params(geojson=geojson) + areas = ( + select(LAreas) + .filter_by(enable=True) + .filter(geojson_intersect_filter.params(geojson=geojson)) ) if "area_type" in request.json: areas = areas.join(BibAreasTypes).filter_by(type_code=request.json["area_type"]) @@ -75,7 +77,7 @@ def getGeoInfo(): return jsonify( { "areas": AreaSchema(only=["id_area", "id_type", "area_code", "area_name"]).dump( - areas.all(), many=True + db.session.scalars(areas).unique().all(), many=True ), "altitude": altitude, } @@ -116,8 +118,10 @@ def getAreasIntersection(): raise BadRequest("Missing 'geometry' in request payload") geojson = json.dumps(geojson) - areas = LAreas.query.filter_by(enable=True).filter( - geojson_intersect_filter.params(geojson=geojson) + areas = ( + select(LAreas) + .filter_by(enable=True) + .filter(geojson_intersect_filter.params(geojson=geojson)) ) if "area_type" in request.json: areas = areas.join(BibAreasTypes).filter_by(type_code=request.json["area_type"]) @@ -130,7 +134,9 @@ def getAreasIntersection(): areas = areas.order_by(LAreas.id_type) response = {} - for id_type, _areas in groupby(areas.all(), key=lambda area: area.id_type): + for id_type, _areas in groupby( + db.session.scalars(areas).unique().all(), key=lambda area: area.id_type + ): _areas = list(_areas) response[id_type] = _areas[0].area_type.as_dict(fields=["type_code", "type_name"]) response[id_type].update( @@ -157,13 +163,13 @@ def get_municipalities(): """ parameters = request.args - q = db.session.query(LiMunicipalities).order_by(LiMunicipalities.nom_com.asc()) + q = select(LiMunicipalities).order_by(LiMunicipalities.nom_com.asc()) if "nom_com" in parameters: - q = q.filter(LiMunicipalities.nom_com.ilike("{}%".format(parameters.get("nom_com")))) + q = q.where(LiMunicipalities.nom_com.ilike("{}%".format(parameters.get("nom_com")))) limit = int(parameters.get("limit")) if parameters.get("limit") else 100 - municipalities = q.limit(limit) + municipalities = db.session.scalars(q.limit(limit)).all() return jsonify(MunicipalitySchema().dump(municipalities, many=True)) @@ -176,8 +182,8 @@ def get_areas(): # change all args in a list of value params = {key: request.args.getlist(key) for key, value in request.args.items()} - q = ( - db.session.query(LAreas) + query = ( + select(LAreas) .options(joinedload("area_type").load_only("type_code")) .order_by(LAreas.area_name.asc()) ) @@ -192,33 +198,34 @@ def get_areas(): } return response, 400 if enable_param == "true": - q = q.filter(LAreas.enable == True) + query = query.where(LAreas.enable == True) elif enable_param == "false": - q = q.filter(LAreas.enable == False) + query = query.where(LAreas.enable == False) else: - q = q.filter(LAreas.enable == True) + query = query.where(LAreas.enable == True) if "id_type" in params: - q = q.filter(LAreas.id_type.in_(params["id_type"])) + query = query.where(LAreas.id_type.in_(params["id_type"])) if "type_code" in params: - q = q.filter(LAreas.area_type.has(BibAreasTypes.type_code.in_(params["type_code"]))) + query = query.where(LAreas.area_type.has(BibAreasTypes.type_code.in_(params["type_code"]))) if "area_name" in params: - q = q.filter(LAreas.area_name.ilike("%{}%".format(params.get("area_name")[0]))) + query = query.where(LAreas.area_name.ilike("%{}%".format(params.get("area_name")[0]))) limit = int(params.get("limit")[0]) if params.get("limit") else 100 - areas = q.limit(limit) - # allow to format response format = request.args.get("format", default="", type=str) fields = {"area_type.type_code"} if format == "geojson": fields |= {"+geom_4326"} - areas = areas.options(undefer("geom_4326")) - response = AreaSchema(only=fields, as_geojson=format == "geojson").dump(areas.all(), many=True) + query = query.options(undefer("geom_4326")) + + areas = db.session.scalars(query.limit(limit)).unique().all() + + response = AreaSchema(only=fields, as_geojson=format == "geojson").dump(areas, many=True) if format == "geojson": # retro-compat: return a list of Features instead of the FeatureCollection response = response["features"] @@ -242,7 +249,7 @@ def get_area_size(): raise BadRequest("Missing 'geometry' in request payload") geojson = json.dumps(geojson) - query = db.session.query(area_size_func.params(geojson=geojson)) + query = select(area_size_func.params(geojson=geojson)) return jsonify(db.session.execute(query).scalar()) @@ -261,20 +268,18 @@ def get_area_types(): type_code = request.args.get("code") type_name = request.args.get("name") sort = request.args.get("sort") - query = db.session.query(BibAreasTypes) + query = select(BibAreasTypes) # GET ONLY INFO FOR A SPECIFIC CODE if type_code: - code_exists = ( - db.session.query(BibAreasTypes) - .filter(BibAreasTypes.type_code == type_code) - .one_or_none() - ) + code_exists = db.session.scalars( + select(BibAreasTypes).where(BibAreasTypes.type_code == type_code) + ).scalar_one_or_none() if not code_exists: raise BadRequest("This area type code does not exist") - query = query.filter(BibAreasTypes.type_code == type_code) + query = query.where(BibAreasTypes.type_code == type_code) # FILTER BY NAME if type_name: - query = query.filter(BibAreasTypes.type_name.ilike("%{}%".format(type_name))) + query = query.where(BibAreasTypes.type_name.ilike("%{}%".format(type_name))) # SORT if sort == "asc": query = query.order_by(asc("type_name")) @@ -282,4 +287,4 @@ def get_area_types(): query = query.order_by(desc("type_name")) # FIELDS fields = ["type_name", "type_code", "id_type"] - return jsonify(AreaTypeSchema(only=fields).dump(query.all(), many=True)) + return jsonify(AreaTypeSchema(only=fields).dump(db.session.scalars(query).all(), many=True)) diff --git a/src/ref_geo/tests/test_ref_geo.py b/src/ref_geo/tests/test_ref_geo.py index a0c7b76..e0ac3da 100644 --- a/src/ref_geo/tests/test_ref_geo.py +++ b/src/ref_geo/tests/test_ref_geo.py @@ -10,6 +10,7 @@ from ref_geo.env import db from ref_geo.models import BibAreasTypes, LAreas +from sqlalchemy import select polygon = { @@ -42,7 +43,7 @@ def has_french_dem(): @pytest.fixture(scope="function") def area_commune(): - return BibAreasTypes.query.filter_by(type_code="COM").one() + return db.session.execute(select(BibAreasTypes).filter_by(type_code="COM")).scalar_one() @pytest.mark.usefixtures("client_class", "temporary_transaction") @@ -302,7 +303,7 @@ def test_get_areas_as_geojson(self, area_commune): """ type_code = area_commune.type_code id_type = area_commune.id_type - first_comm = LAreas.query.filter(LAreas.id_type == id_type).first() + first_comm = db.session.scalars(db.select(LAreas).where(LAreas.id_type == id_type)).first() # will test many responses are return response = self.client.get( url_for("ref_geo.get_areas"),