diff --git a/.gitignore b/.gitignore
index 8a548b5..961e75c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -44,3 +44,5 @@ data*/*.npz
 db/*.db
 *.db
 nodes/
+.python-version
+main.py
diff --git a/delta_node/app/v1/coord.py b/delta_node/app/v1/coord.py
index 92a8163..534fb80 100644
--- a/delta_node/app/v1/coord.py
+++ b/delta_node/app/v1/coord.py
@@ -1,15 +1,15 @@
-import asyncio
 import json
 import logging
 import os
 import shutil
-from typing import IO, Dict, List, Optional
+import time
+from typing import Dict, List, Optional
 
 import sqlalchemy as sa
-from delta_node import chain, coord, db, entity, pool, registry
+from delta_node import coord, db, entity
 from delta_node.serialize import bytes_to_hex, hex_to_bytes
-from fastapi import (APIRouter, BackgroundTasks, Depends, File, Form,
-                     HTTPException, Query, UploadFile)
+from fastapi import (APIRouter, Depends, File, Form, HTTPException, Query,
+                     UploadFile)
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -27,8 +27,14 @@ class CommonResp(BaseModel):
 @router.get("/config")
 def get_task_config(task_id: str = Query(..., regex=r"0x[0-9a-fA-F]+")):
     config_filename = coord.task_config_file(task_id)
-    if not os.path.exists(config_filename):
-        raise HTTPException(400, f"task {task_id} does not exist")
+    # The config file is written by a background task after task creation,
+    # so check up to three times (2 s apart) before reporting it missing.
+    retry = 3
+    while not os.path.exists(config_filename):
+        retry -= 1
+        if retry == 0:
+            raise HTTPException(400, f"task {task_id} does not exist")
+        time.sleep(2)
 
     def file_iter():
         chunk_size = 1024 * 1024
diff --git a/delta_node/app/v1/task.py b/delta_node/app/v1/task.py
index 43fefc8..37d35ba 100644
--- a/delta_node/app/v1/task.py
+++ b/delta_node/app/v1/task.py
@@ -3,19 +3,13 @@
 import math
 import os
 import shutil
+from tempfile import TemporaryFile
 from typing import IO, List, Optional
 
 import sqlalchemy as sa
 from delta_node import chain, coord, db, entity, pool, registry
-from fastapi import (
-    APIRouter,
-    BackgroundTasks,
-    Depends,
-    File,
-    HTTPException,
-    Query,
-    UploadFile,
-)
+from fastapi import (APIRouter, BackgroundTasks, Depends, File, HTTPException,
+                     Query, UploadFile)
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -29,11 +23,41 @@ class CreateTaskResp(BaseModel):
     task_id: str
 
 
-def dump_task(task_id: str, task_file: IO[bytes]):
+def create_task_file(task_file: IO[bytes]):
+    f = TemporaryFile(mode="w+b")
+    shutil.copyfileobj(task_file, f)
+    return f
+
+
+def move_task_file(task_file: IO[bytes], task_id: str):
     task_file.seek(0)
     with open(coord.task_config_file(task_id), mode="wb") as f:
         shutil.copyfileobj(task_file, f)
-    _logger.debug(f"save task config file of {task_id}")
+    task_file.close()
+
+
+async def run_task(id: int, task_file: IO[bytes]):
+    node_address = await registry.get_node_address()
+
+    async with db.session_scope() as sess:
+        q = sa.select(entity.Task).where(entity.Task.id == id)
+        task_item: entity.Task = (await sess.execute(q)).scalar_one()
+
+        tx_hash, task_id = await chain.get_client().create_task(
+            node_address, task_item.dataset, task_item.commitment, task_item.type
+        )
+        task_item.task_id = task_id
+        task_item.creator = node_address
+        sess.add(task_item)
+        await sess.commit()
+
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(pool.IO_POOL, move_task_file, task_file, task_id)
+    _logger.info(
+        f"create task {task_id}", extra={"task_id": task_id, "tx_hash": tx_hash}
+    )
+
+    await coord.run_task(task_id)
 
 
 @task_router.post("", response_model=CreateTaskResp)
@@ -44,24 +68,14 @@ async def create_task(
     background: BackgroundTasks,
 ):
     loop = asyncio.get_running_loop()
-    task_item = await loop.run_in_executor(pool.IO_POOL, coord.create_task, file.file)
-    node_address = await registry.get_node_address()
-    tx_hash, task_id = await chain.get_client().create_task(
-        node_address, task_item.dataset, task_item.commitment, task_item.type
-    )
-    task_item.creator = node_address
-    task_item.task_id = task_id
-    task_item.status = entity.TaskStatus.RUNNING
+    f = await loop.run_in_executor(pool.IO_POOL, create_task_file, file.file)
+    task_item = await loop.run_in_executor(pool.IO_POOL, coord.create_task, f)
     session.add(task_item)
     await session.commit()
+    await session.refresh(task_item)
 
-    await loop.run_in_executor(pool.IO_POOL, dump_task, task_id, file.file)
-    _logger.info(
-        f"create task {task_id}", extra={"task_id": task_id, "tx_hash": tx_hash}
-    )
-
-    background.add_task(coord.run_task, task_id)
-    return CreateTaskResp(task_id=task_id)
+    background.add_task(run_task, task_item.id, f)
+    return CreateTaskResp(task_id=str(task_item.id))
 
 
 class Task(BaseModel):
diff --git a/delta_node/coord/create.py b/delta_node/coord/create.py
index 5bc199b..082eaf3 100644
--- a/delta_node/coord/create.py
+++ b/delta_node/coord/create.py
@@ -1,6 +1,6 @@
 import shutil
 import tempfile
-from typing import IO
+from typing import IO, Tuple
 
 import delta
 import delta.serialize
@@ -8,23 +8,21 @@
 
 
 def create_task(task_file: IO[bytes]) -> entity.Task:
-    with tempfile.TemporaryFile(mode="w+b") as f:
-        shutil.copyfileobj(task_file, f)
-        f.seek(0)
+    task_file.seek(0)
 
-        task = delta.serialize.load_task(f)
-        dataset = task.dataset
+    task = delta.serialize.load_task(task_file)
+    dataset = task.dataset
 
-        f.seek(0)
-        commitment = utils.calc_commitment(f)
+    task_file.seek(0)
+    commitment = utils.calc_commitment(task_file)
 
-        task_item = entity.Task(
-            creator="",
-            task_id="",
-            dataset=dataset,
-            commitment=commitment,
-            status=entity.TaskStatus.PENDING,
-            name=task.name,
-            type=task.type,
-        )
-        return task_item
+    task_item = entity.Task(
+        creator="",
+        task_id="",
+        dataset=dataset,
+        commitment=commitment,
+        status=entity.TaskStatus.PENDING,
+        name=task.name,
+        type=task.type,
+    )
+    return task_item
diff --git a/delta_node/main.py b/delta_node/main.py
index 5a85cd9..fac614c 100644
--- a/delta_node/main.py
+++ b/delta_node/main.py
@@ -28,7 +28,7 @@ async def _run():
     await commu.init()
 
     fut = asyncio.gather(
-        app.run("0.0.0.0", config.api_port), runner.run(), coord.run_unfinished_tasks()
+        app.run("0.0.0.0", config.api_port), runner.run()
     )
     try:
         await fut