Merge pull request #10 from delta-mpc/0.3.0
0.3.0-rc2
mh739025250 authored Jan 11, 2022
2 parents c676162 + f524a47 commit 55012af
Showing 5 changed files with 70 additions and 51 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -44,3 +44,5 @@ data*/*.npz
db/*.db
*.db
nodes/
.python-version
main.py
17 changes: 11 additions & 6 deletions delta_node/app/v1/coord.py
@@ -1,15 +1,15 @@
import asyncio
import json
import logging
import os
import shutil
from typing import IO, Dict, List, Optional
import time
from typing import Dict, List, Optional

import sqlalchemy as sa
from delta_node import chain, coord, db, entity, pool, registry
from delta_node import coord, db, entity
from delta_node.serialize import bytes_to_hex, hex_to_bytes
from fastapi import (APIRouter, BackgroundTasks, Depends, File, Form,
HTTPException, Query, UploadFile)
from fastapi import (APIRouter, Depends, File, Form, HTTPException, Query,
UploadFile)
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
@@ -27,8 +27,13 @@ class CommonResp(BaseModel):
@router.get("/config")
def get_task_config(task_id: str = Query(..., regex=r"0x[0-9a-fA-F]+")):
config_filename = coord.task_config_file(task_id)
retry = 3
if not os.path.exists(config_filename):
raise HTTPException(400, f"task {task_id} does not exist")
retry -= 1
if retry == 0:
raise HTTPException(400, f"task {task_id} does not exist")
else:
time.sleep(2)

def file_iter():
chunk_size = 1024 * 1024
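
For readability, here is the body of get_task_config as assembled from the visible lines of the hunk above. This is a sketch only: the function name and the config_filename parameter are simplifications, and the collapsed streaming-response code that follows is omitted.

# Sketch assembled from the visible hunk; not a verbatim copy of the file.
import os
import time

from fastapi import HTTPException


def get_task_config_sketch(task_id: str, config_filename: str) -> None:
    retry = 3
    if not os.path.exists(config_filename):
        retry -= 1
        if retry == 0:
            raise HTTPException(400, f"task {task_id} does not exist")
        else:
            time.sleep(2)
    # As assembled from the visible lines, the existence check is not wrapped
    # in a loop: a missing file leads to one decrement and one 2-second sleep
    # before the handler proceeds. A `while` over the check would be needed
    # for the counter to reach 0 and the 400 to be raised.
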
66 changes: 40 additions & 26 deletions delta_node/app/v1/task.py
@@ -3,19 +3,13 @@
import math
import os
import shutil
from tempfile import TemporaryFile
from typing import IO, List, Optional

import sqlalchemy as sa
from delta_node import chain, coord, db, entity, pool, registry
from fastapi import (
APIRouter,
BackgroundTasks,
Depends,
File,
HTTPException,
Query,
UploadFile,
)
from fastapi import (APIRouter, BackgroundTasks, Depends, File, HTTPException,
Query, UploadFile)
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
@@ -29,11 +29,41 @@ class CreateTaskResp(BaseModel):
task_id: str


def dump_task(task_id: str, task_file: IO[bytes]):
def create_task_file(task_file: IO[bytes]):
f = TemporaryFile(mode="w+b")
shutil.copyfileobj(task_file, f)
return f


def move_task_file(task_file: IO[bytes], task_id: str):
task_file.seek(0)
with open(coord.task_config_file(task_id), mode="wb") as f:
shutil.copyfileobj(task_file, f)
_logger.debug(f"save task config file of {task_id}")
task_file.close()


async def run_task(id: int, task_file: IO[bytes]):
node_address = await registry.get_node_address()

async with db.session_scope() as sess:
q = sa.select(entity.Task).where(entity.Task.id == id)
task_item: entity.Task = (await sess.execute(q)).scalar_one()

tx_hash, task_id = await chain.get_client().create_task(
node_address, task_item.dataset, task_item.commitment, task_item.type
)
task_item.task_id = task_id
task_item.creator = node_address
sess.add(task_item)
await sess.commit()

loop = asyncio.get_running_loop()
await loop.run_in_executor(pool.IO_POOL, move_task_file, task_file, task_id)
_logger.info(
f"create task {task_id}", extra={"task_id": task_id, "tx_hash": tx_hash}
)

await coord.run_task(task_id)


@task_router.post("", response_model=CreateTaskResp)
@@ -44,24 +44,68 @@ async def create_task(
background: BackgroundTasks,
):
loop = asyncio.get_running_loop()
task_item = await loop.run_in_executor(pool.IO_POOL, coord.create_task, file.file)
node_address = await registry.get_node_address()
tx_hash, task_id = await chain.get_client().create_task(
node_address, task_item.dataset, task_item.commitment, task_item.type
)
task_item.creator = node_address
task_item.task_id = task_id
task_item.status = entity.TaskStatus.RUNNING
f = await loop.run_in_executor(pool.IO_POOL, create_task_file, file.file)
task_item = await loop.run_in_executor(pool.IO_POOL, coord.create_task, f)
session.add(task_item)
await session.commit()
await session.refresh(task_item)

await loop.run_in_executor(pool.IO_POOL, dump_task, task_id, file.file)
_logger.info(
f"create task {task_id}", extra={"task_id": task_id, "tx_hash": tx_hash}
)

background.add_task(coord.run_task, task_id)
return CreateTaskResp(task_id=task_id)
background.add_task(run_task, task_item.id, f)
return CreateTaskResp(task_id=str(task_item.id))


class Task(BaseModel):
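
Most of this file's change moves chain registration and file persistence out of the request handler: create_task_file copies the upload into a TemporaryFile, the endpoint commits a PENDING task row, and run_task finishes the work as a FastAPI background task. A minimal, self-contained illustration of that handler/background-task pattern follows; the /upload path, function names, and response shape are illustrative stand-ins, not the actual delta_node API.

import shutil
from tempfile import TemporaryFile
from typing import IO

from fastapi import BackgroundTasks, FastAPI, File, UploadFile

app = FastAPI()


def copy_upload(src: IO[bytes]) -> IO[bytes]:
    # Analogous to create_task_file above: keep a seekable copy that outlives
    # the request, so the background task does not depend on the
    # request-scoped upload object.
    f = TemporaryFile(mode="w+b")
    shutil.copyfileobj(src, f)
    f.seek(0)
    return f


def consume(f: IO[bytes]) -> None:
    # Stands in for run_task: runs after the response has been sent.
    data = f.read()
    print(f"received {len(data)} bytes")
    f.close()


@app.post("/upload")
async def upload(background: BackgroundTasks, file: UploadFile = File(...)):
    f = copy_upload(file.file)
    background.add_task(consume, f)
    return {"status": "accepted"}
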
34 changes: 16 additions & 18 deletions delta_node/coord/create.py
@@ -1,30 +1,28 @@
import shutil
import tempfile
from typing import IO
from typing import IO, Tuple

import delta
import delta.serialize
from delta_node import entity, utils


def create_task(task_file: IO[bytes]) -> entity.Task:
with tempfile.TemporaryFile(mode="w+b") as f:
shutil.copyfileobj(task_file, f)
f.seek(0)
task_file.seek(0)

task = delta.serialize.load_task(f)
dataset = task.dataset
task = delta.serialize.load_task(task_file)
dataset = task.dataset

f.seek(0)
commitment = utils.calc_commitment(f)
task_file.seek(0)
commitment = utils.calc_commitment(task_file)

task_item = entity.Task(
creator="",
task_id="",
dataset=dataset,
commitment=commitment,
status=entity.TaskStatus.PENDING,
name=task.name,
type=task.type,
)
return task_item
task_item = entity.Task(
creator="",
task_id="",
dataset=dataset,
commitment=commitment,
status=entity.TaskStatus.PENDING,
name=task.name,
type=task.type,
)
return task_item
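
create_task now reads the caller's file object in place instead of copying it into its own TemporaryFile, rewinding the handle before each pass: once for delta.serialize.load_task and once for utils.calc_commitment. A small stand-alone demonstration of why both seek(0) calls matter; hashlib and the byte string below are purely illustrative, the real commitment calculation lives in delta_node.utils.

import hashlib
import io

buf = io.BytesIO(b"serialized task bytes")  # stands in for the uploaded task file

buf.seek(0)
parsed = buf.read()          # first pass: stands in for delta.serialize.load_task

buf.seek(0)                  # without rewinding, the second pass would read b""
commitment = hashlib.sha256(buf.read()).hexdigest()
print(len(parsed), commitment)
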
2 changes: 1 addition & 1 deletion delta_node/main.py
@@ -28,7 +28,7 @@ async def _run():
await commu.init()

fut = asyncio.gather(
app.run("0.0.0.0", config.api_port), runner.run(), coord.run_unfinished_tasks()
app.run("0.0.0.0", config.api_port), runner.run()
)
try:
await fut
