Skip to content

Commit

Permalink
drop .query
Browse files Browse the repository at this point in the history
  • Loading branch information
jacquesfize authored and bouttier committed Dec 19, 2023
1 parent 2f57360 commit ee68a41
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 37 deletions.
6 changes: 3 additions & 3 deletions src/ref_geo/commands.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import click
from flask.cli import with_appcontext
from sqlalchemy import func, column
from sqlalchemy import func, select

from ref_geo.env import db
from ref_geo.models import BibAreasTypes, LAreas
Expand All @@ -16,10 +16,10 @@ def ref_geo():
def info():
click.echo("RefGeo : nombre de zones par type")
q = (
db.session.query(BibAreasTypes, func.count(LAreas.id_area).label("count"))
select(BibAreasTypes, func.count(LAreas.id_area).label("count"))
.join(LAreas)
.group_by(BibAreasTypes.id_type)
.order_by(BibAreasTypes.id_type)
)
for area_type, count in q.all():
for area_type, count in db.session.scalars(q).unique().all():
click.echo("\t{}: {}".format(area_type.type_name, count))
69 changes: 37 additions & 32 deletions src/ref_geo/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from flask import Blueprint, request, current_app
from flask.json import jsonify
import sqlalchemy as sa
from sqlalchemy import func, distinct, asc, desc
from sqlalchemy import func, select, asc, desc
from sqlalchemy.sql import text
from sqlalchemy.orm import joinedload, undefer
from werkzeug.exceptions import BadRequest
Expand Down Expand Up @@ -58,8 +58,10 @@ def getGeoInfo():
raise BadRequest("Missing 'geometry' in request payload")
geojson = json.dumps(geojson)

areas = LAreas.query.filter_by(enable=True).filter(
geojson_intersect_filter.params(geojson=geojson)
areas = (
select(LAreas)
.filter_by(enable=True)
.filter(geojson_intersect_filter.params(geojson=geojson))
)
if "area_type" in request.json:
areas = areas.join(BibAreasTypes).filter_by(type_code=request.json["area_type"])
Expand All @@ -75,7 +77,7 @@ def getGeoInfo():
return jsonify(
{
"areas": AreaSchema(only=["id_area", "id_type", "area_code", "area_name"]).dump(
areas.all(), many=True
db.session.scalars(areas).unique().all(), many=True
),
"altitude": altitude,
}
Expand Down Expand Up @@ -116,8 +118,10 @@ def getAreasIntersection():
raise BadRequest("Missing 'geometry' in request payload")
geojson = json.dumps(geojson)

areas = LAreas.query.filter_by(enable=True).filter(
geojson_intersect_filter.params(geojson=geojson)
areas = (
select(LAreas)
.filter_by(enable=True)
.filter(geojson_intersect_filter.params(geojson=geojson))
)
if "area_type" in request.json:
areas = areas.join(BibAreasTypes).filter_by(type_code=request.json["area_type"])
Expand All @@ -130,7 +134,9 @@ def getAreasIntersection():
areas = areas.order_by(LAreas.id_type)

response = {}
for id_type, _areas in groupby(areas.all(), key=lambda area: area.id_type):
for id_type, _areas in groupby(
db.session.scalars(areas).unique().all(), key=lambda area: area.id_type
):
_areas = list(_areas)
response[id_type] = _areas[0].area_type.as_dict(fields=["type_code", "type_name"])
response[id_type].update(
Expand All @@ -157,13 +163,13 @@ def get_municipalities():
"""
parameters = request.args

q = db.session.query(LiMunicipalities).order_by(LiMunicipalities.nom_com.asc())
q = select(LiMunicipalities).order_by(LiMunicipalities.nom_com.asc())

if "nom_com" in parameters:
q = q.filter(LiMunicipalities.nom_com.ilike("{}%".format(parameters.get("nom_com"))))
q = q.where(LiMunicipalities.nom_com.ilike("{}%".format(parameters.get("nom_com"))))
limit = int(parameters.get("limit")) if parameters.get("limit") else 100

municipalities = q.limit(limit)
municipalities = db.session.scalars(q.limit(limit)).all()
return jsonify(MunicipalitySchema().dump(municipalities, many=True))


Expand All @@ -176,8 +182,8 @@ def get_areas():
# change all args in a list of value
params = {key: request.args.getlist(key) for key, value in request.args.items()}

q = (
db.session.query(LAreas)
query = (
select(LAreas)
.options(joinedload("area_type").load_only("type_code"))
.order_by(LAreas.area_name.asc())
)
Expand All @@ -192,33 +198,34 @@ def get_areas():
}
return response, 400
if enable_param == "true":
q = q.filter(LAreas.enable == True)
query = query.where(LAreas.enable == True)
elif enable_param == "false":
q = q.filter(LAreas.enable == False)
query = query.where(LAreas.enable == False)
else:
q = q.filter(LAreas.enable == True)
query = query.where(LAreas.enable == True)

if "id_type" in params:
q = q.filter(LAreas.id_type.in_(params["id_type"]))
query = query.where(LAreas.id_type.in_(params["id_type"]))

if "type_code" in params:
q = q.filter(LAreas.area_type.has(BibAreasTypes.type_code.in_(params["type_code"])))
query = query.where(LAreas.area_type.has(BibAreasTypes.type_code.in_(params["type_code"])))

if "area_name" in params:
q = q.filter(LAreas.area_name.ilike("%{}%".format(params.get("area_name")[0])))
query = query.where(LAreas.area_name.ilike("%{}%".format(params.get("area_name")[0])))

limit = int(params.get("limit")[0]) if params.get("limit") else 100

areas = q.limit(limit)

# allow to format response
format = request.args.get("format", default="", type=str)

fields = {"area_type.type_code"}
if format == "geojson":
fields |= {"+geom_4326"}
areas = areas.options(undefer("geom_4326"))
response = AreaSchema(only=fields, as_geojson=format == "geojson").dump(areas.all(), many=True)
query = query.options(undefer("geom_4326"))

areas = db.session.scalars(query.limit(limit)).unique().all()

response = AreaSchema(only=fields, as_geojson=format == "geojson").dump(areas, many=True)
if format == "geojson":
# retro-compat: return a list of Features instead of the FeatureCollection
response = response["features"]
Expand All @@ -242,7 +249,7 @@ def get_area_size():
raise BadRequest("Missing 'geometry' in request payload")
geojson = json.dumps(geojson)

query = db.session.query(area_size_func.params(geojson=geojson))
query = select(area_size_func.params(geojson=geojson))

return jsonify(db.session.execute(query).scalar())

Expand All @@ -261,25 +268,23 @@ def get_area_types():
type_code = request.args.get("code")
type_name = request.args.get("name")
sort = request.args.get("sort")
query = db.session.query(BibAreasTypes)
query = select(BibAreasTypes)
# GET ONLY INFO FOR A SPECIFIC CODE
if type_code:
code_exists = (
db.session.query(BibAreasTypes)
.filter(BibAreasTypes.type_code == type_code)
.one_or_none()
)
code_exists = db.session.scalars(
select(BibAreasTypes).where(BibAreasTypes.type_code == type_code)
).scalar_one_or_none()
if not code_exists:
raise BadRequest("This area type code does not exist")
query = query.filter(BibAreasTypes.type_code == type_code)
query = query.where(BibAreasTypes.type_code == type_code)
# FILTER BY NAME
if type_name:
query = query.filter(BibAreasTypes.type_name.ilike("%{}%".format(type_name)))
query = query.where(BibAreasTypes.type_name.ilike("%{}%".format(type_name)))
# SORT
if sort == "asc":
query = query.order_by(asc("type_name"))
if sort == "desc":
query = query.order_by(desc("type_name"))
# FIELDS
fields = ["type_name", "type_code", "id_type"]
return jsonify(AreaTypeSchema(only=fields).dump(query.all(), many=True))
return jsonify(AreaTypeSchema(only=fields).dump(db.session.scalars(query).all(), many=True))
5 changes: 3 additions & 2 deletions src/ref_geo/tests/test_ref_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ref_geo.env import db
from ref_geo.models import BibAreasTypes, LAreas
from sqlalchemy import select


polygon = {
Expand Down Expand Up @@ -42,7 +43,7 @@ def has_french_dem():

@pytest.fixture(scope="function")
def area_commune():
return BibAreasTypes.query.filter_by(type_code="COM").one()
return db.session.execute(select(BibAreasTypes).filter_by(type_code="COM")).scalar_one()


@pytest.mark.usefixtures("client_class", "temporary_transaction")
Expand Down Expand Up @@ -302,7 +303,7 @@ def test_get_areas_as_geojson(self, area_commune):
"""
type_code = area_commune.type_code
id_type = area_commune.id_type
first_comm = LAreas.query.filter(LAreas.id_type == id_type).first()
first_comm = db.session.scalars(db.select(LAreas).where(LAreas.id_type == id_type)).first()
# will test many responses are return
response = self.client.get(
url_for("ref_geo.get_areas"),
Expand Down

0 comments on commit ee68a41

Please sign in to comment.