Skip to content

Commit

Permalink
Merge pull request #6586 from hotosm/fastapi-refactor
Browse files Browse the repository at this point in the history
Project tasks gpx and xml export
  • Loading branch information
prabinoid authored Oct 2, 2024
2 parents c54c64a + 980194e commit eda14ea
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 55 deletions.
61 changes: 29 additions & 32 deletions backend/api/tasks/resources.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
import io
from distutils.util import strtobool

from backend.services.mapping_service import MappingService
from backend.models.dtos.grid_dto import GridDTO

from backend.services.users.authentication_service import tm
from backend.services.users.user_service import UserService
from backend.services.validator_service import ValidatorService

from backend.services.project_service import ProjectService, ProjectServiceError
from backend.services.grid.grid_service import GridService
from backend.models.postgis.statuses import UserRole
from backend.models.postgis.utils import InvalidGeoJson
from databases import Database
from fastapi import APIRouter, Depends, Request
from starlette.authentication import requires
from fastapi.responses import Response, StreamingResponse
from loguru import logger
from starlette.authentication import requires

from backend.db import get_db
from databases import Database

from backend.models.dtos.grid_dto import GridDTO
from backend.models.postgis.statuses import UserRole
from backend.models.postgis.utils import InvalidGeoJson
from backend.services.grid.grid_service import GridService
from backend.services.mapping_service import MappingService
from backend.services.project_service import ProjectService, ProjectServiceError
from backend.services.users.authentication_service import tm
from backend.services.users.user_service import UserService
from backend.services.validator_service import ValidatorService

router = APIRouter(
prefix="/projects",
Expand Down Expand Up @@ -207,7 +205,7 @@ async def delete(request: Request, project_id):


@router.get("/{project_id}/tasks/queries/xml/")
async def get(request: Request, project_id: int):
async def get(request: Request, project_id: int, db: Database = Depends(get_db)):
"""
Get all tasks for a project as OSM XML
---
Expand Down Expand Up @@ -242,30 +240,29 @@ async def get(request: Request, project_id: int):
500:
description: Internal Server Error
"""
tasks = (
request.query_params.get("tasks") if request.query_params.get("tasks") else None
)
tasks = request.query_params.get("tasks")
as_file = (
strtobool(request.query_params.get("as_file"))
if request.query_params.get("as_file")
else False
)

xml = MappingService.generate_osm_xml(project_id, tasks)
xml = await MappingService.generate_osm_xml(project_id, tasks, db)

if as_file:
return send_file(
return StreamingResponse(
io.BytesIO(xml),
mimetype="text.xml",
as_attachment=True,
download_name=f"HOT-project-{project_id}.osm",
media_type="text/xml",
headers={
"Content-Disposition": f"attachment; filename=HOT-project-{project_id}.osm"
},
)

return Response(xml, mimetype="text/xml", status=200)
return Response(content=xml, media_type="text/xml", status_code=200)


@router.get("/{project_id}/tasks/queries/gpx/")
async def get(request: Request, project_id):
async def get(request: Request, project_id: int, db: Database = Depends(get_db)):
"""
Get all tasks for a project as GPX
---
Expand Down Expand Up @@ -300,25 +297,25 @@ async def get(request: Request, project_id):
500:
description: Internal Server Error
"""
logger.debug("GPX Called")
tasks = request.query_params.get("tasks")
as_file = (
strtobool(request.query_params.get("as_file"))
if request.query_params.get("as_file")
else False
)

xml = MappingService.generate_gpx(project_id, tasks)
xml = await MappingService.generate_gpx(project_id, tasks, db)

if as_file:
return send_file(
return StreamingResponse(
io.BytesIO(xml),
mimetype="text.xml",
as_attachment=True,
download_name=f"HOT-project-{project_id}.gpx",
media_type="text/xml",
headers={
"Content-Disposition": f"attachment; filename=HOT-project-{project_id}.gpx"
},
)

return Response(xml, mimetype="text/xml", status=200)
return Response(content=xml, media_type="text/xml", status_code=200)


@router.put("/{project_id}/tasks/queries/aoi/")
Expand Down
49 changes: 39 additions & 10 deletions backend/models/postgis/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,19 +818,48 @@ async def exists(task_id: int, project_id: int, db: Database) -> bool:
)
return task is not None

# @staticmethod
# def get_tasks(project_id: int, task_ids: List[int]):
# """Get all tasks that match supplied list"""
# return (
# session.query(Task)
# .filter(Task.project_id == project_id, Task.id.in_(task_ids))
# .all()
# )

@staticmethod
def get_tasks(project_id: int, task_ids: List[int]):
"""Get all tasks that match supplied list"""
return (
session.query(Task)
.filter(Task.project_id == project_id, Task.id.in_(task_ids))
.all()
)
async def get_tasks(project_id: int, task_ids: List[int], db: Database):
"""
Get all tasks that match the supplied list of task_ids for a project.
"""
query = """
SELECT id, geometry
FROM tasks
WHERE project_id = :project_id
AND id = ANY(:task_ids)
"""
values = {"project_id": project_id, "task_ids": task_ids}
rows = await db.fetch_all(query=query, values=values)
return rows

# @staticmethod
# def get_all_tasks(project_id: int):
# """Get all tasks for a given project"""
# return session.query(Task).filter(Task.project_id == project_id).all()

@staticmethod
def get_all_tasks(project_id: int):
"""Get all tasks for a given project"""
return session.query(Task).filter(Task.project_id == project_id).all()
async def get_all_tasks(project_id: int, db: Database):
"""
Get all tasks for a given project.
"""
query = """
SELECT id, geometry
FROM tasks
WHERE project_id = :project_id
"""
values = {"project_id": project_id}
rows = await db.fetch_all(query=query, values=values)
return rows

@staticmethod
def get_tasks_by_status(project_id: int, status: str):
Expand Down
32 changes: 19 additions & 13 deletions backend/services/mapping_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import xml.etree.ElementTree as ET

# from flask import current_app
from geoalchemy2 import shape

from geoalchemy2 import WKBElement
from geoalchemy2.shape import to_shape
from backend.exceptions import NotFound
from backend.models.dtos.mapping_dto import (
ExtendLockTimeDTO,
Expand Down Expand Up @@ -276,7 +276,9 @@ async def add_task_comment(task_comment: TaskCommentDTO) -> TaskDTO:
return task.as_dto_with_instructions(task_comment.preferred_locale)

@staticmethod
def generate_gpx(project_id: int, task_ids_str: str, timestamp=None):
async def generate_gpx(
project_id: int, task_ids_str: str, db: Database, timestamp=None
):
"""
Creates a GPX file for supplied tasks. Timestamp is for unit testing only.
You can use the following URL to test locally:
Expand Down Expand Up @@ -316,18 +318,22 @@ def generate_gpx(project_id: int, task_ids_str: str, timestamp=None):
# Construct trkseg elements
if task_ids_str is not None:
task_ids = list(map(int, task_ids_str.split(",")))
tasks = Task.get_tasks(project_id, task_ids)
tasks = await Task.get_tasks(project_id, task_ids, db)
if not tasks or len(tasks) == 0:
raise NotFound(
sub_code="TASKS_NOT_FOUND", project_id=project_id, task_ids=task_ids
)
else:
tasks = Task.get_all_tasks(project_id)
tasks = await Task.get_all_tasks(project_id, db)
if not tasks or len(tasks) == 0:
raise NotFound(sub_code="TASKS_NOT_FOUND", project_id=project_id)

for task in tasks:
task_geom = shape.to_shape(task.geometry)
# task_geom = shape.to_shape(task.geometry)
if isinstance(task["geometry"], (bytes, str)):
task_geom = to_shape(WKBElement(task["geometry"], srid=4326))
else:
raise ValueError("Invalid geometry format")
for poly in task_geom.geoms:
trkseg = ET.SubElement(trk, "trkseg")
for point in poly.exterior.coords:
Expand All @@ -336,8 +342,6 @@ def generate_gpx(project_id: int, task_ids_str: str, timestamp=None):
"trkpt",
attrib=dict(lon=str(point[0]), lat=str(point[1])),
)

# Append wpt elements to end of doc
wpt = ET.Element(
"wpt", attrib=dict(lon=str(point[0]), lat=str(point[1]))
)
Expand All @@ -347,30 +351,32 @@ def generate_gpx(project_id: int, task_ids_str: str, timestamp=None):
return xml_gpx

@staticmethod
def generate_osm_xml(project_id: int, task_ids_str: str) -> str:
async def generate_osm_xml(project_id: int, task_ids_str: str, db: Database) -> str:
"""Generate xml response suitable for loading into JOSM. A sample output file is in
/backend/helpers/testfiles/osm-sample.xml"""
# Note XML created with upload No to ensure it will be rejected by OSM if uploaded by mistake
root = ET.Element(
"osm",
attrib=dict(version="0.6", upload="never", creator="HOT Tasking Manager"),
)

if task_ids_str:
task_ids = list(map(int, task_ids_str.split(",")))
tasks = Task.get_tasks(project_id, task_ids)
tasks = await Task.get_tasks(project_id, task_ids, db)
if not tasks or len(tasks) == 0:
raise NotFound(
sub_code="TASKS_NOT_FOUND", project_id=project_id, task_ids=task_ids
)
else:
tasks = Task.get_all_tasks(project_id)
tasks = await Task.get_all_tasks(project_id, db)
if not tasks or len(tasks) == 0:
raise NotFound(sub_code="TASKS_NOT_FOUND", project_id=project_id)

fake_id = -1 # We use fake-ids to ensure XML will not be validated by OSM
for task in tasks:
task_geom = shape.to_shape(task.geometry)
if isinstance(task["geometry"], (bytes, str)):
task_geom = to_shape(WKBElement(task["geometry"], srid=4326))
else:
raise ValueError("Invalid geometry format")
way = ET.SubElement(
root,
"way",
Expand Down

0 comments on commit eda14ea

Please sign in to comment.