Skip to content

Commit

Permalink
Merge pull request #40 from delta-mpc/zk
Browse files Browse the repository at this point in the history
HLR Zk
  • Loading branch information
mh739025250 authored Sep 29, 2022
2 parents 8984cc3 + 9e88c63 commit cd515d2
Show file tree
Hide file tree
Showing 108 changed files with 7,578 additions and 2,187 deletions.
16 changes: 15 additions & 1 deletion build_proto.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,18 @@ set -e

. ./.venv/bin/activate

exec python3 -m grpc_tools.protoc -Idelta_node/chain --python_out=delta_node/chain --grpclib_python_out=delta_node/chain --mypy_out=delta_node/chain delta_node/chain/chain.proto
mkdir -p delta_node/chain/identity
mkdir -p delta_node/chain/horizontal
mkdir -p delta_node/chain/datahub
mkdir -p delta_node/chain/hlr
mkdir -p delta_node/chain/subscribe
mkdir -p delta_node/chain/transaction
mkdir -p delta_node/zh

python3 -m grpc_tools.protoc -Idelta_node/proto --python_out=delta_node/chain/identity --grpclib_python_out=delta_node/chain/identity --mypy_out=delta_node/chain/identity delta_node/proto/identity.proto
python3 -m grpc_tools.protoc -Idelta_node/proto --python_out=delta_node/chain/horizontal --grpclib_python_out=delta_node/chain/horizontal --mypy_out=delta_node/chain/horizontal delta_node/proto/horizontal.proto
python3 -m grpc_tools.protoc -Idelta_node/proto --python_out=delta_node/chain/datahub --grpclib_python_out=delta_node/chain/datahub --mypy_out=delta_node/chain/datahub delta_node/proto/datahub.proto
python3 -m grpc_tools.protoc -Idelta_node/proto --python_out=delta_node/chain/hlr --grpclib_python_out=delta_node/chain/hlr --mypy_out=delta_node/chain/hlr delta_node/proto/hlr.proto
python3 -m grpc_tools.protoc -Idelta_node/proto --python_out=delta_node/chain/subscribe --grpclib_python_out=delta_node/chain/subscribe --mypy_out=delta_node/chain/subscribe delta_node/proto/subscribe.proto
python3 -m grpc_tools.protoc -Idelta_node/proto --python_out=delta_node/chain/transaction --grpclib_python_out=delta_node/chain/transaction --mypy_out=delta_node/chain/transaction delta_node/proto/transaction.proto
python3 -m grpc_tools.protoc -Idelta_node/proto --python_out=delta_node/zk --grpclib_python_out=delta_node/zk --mypy_out=delta_node/zk delta_node/proto/delta-zk.proto
76 changes: 54 additions & 22 deletions delta_node/app/v1/coord.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import logging
import os
Expand Down Expand Up @@ -101,23 +103,38 @@ async def upload_secret_share(
shares: SecretShares, session: AsyncSession = Depends(db.get_session)
):
q = (
sa.select(entity.TaskRound)
.where(entity.TaskRound.task_id == shares.task_id)
.where(entity.TaskRound.round == shares.round)
sa.select(entity.Task)
.where(entity.Task.task_id == shares.task_id)
)
task: entity.Task | None = (await session.execute(q)).scalars().one_or_none()
if task is None:
raise HTTPException(400, "task does not exist")

if task.type == "horizontal":
te = entity.horizontal
elif task.type == "hlr":
te = entity.hlr
else:
raise HTTPException(400, f"unknown task type {task.type}")

q = (
sa.select(te.TaskRound)
.where(te.TaskRound.task_id == shares.task_id)
.where(te.TaskRound.round == shares.round)
)
round: Optional[entity.TaskRound] = (
round = (
(await session.execute(q)).scalars().one_or_none()
)
if not round:
raise HTTPException(400, "task round does not exist")
if round.status != entity.RoundStatus.RUNNING:
if round.status != te.RoundStatus.RUNNING:
raise HTTPException(400, "round is not in running phase")

# get members in running phase
q = (
sa.select(entity.RoundMember)
.where(entity.RoundMember.status == entity.RoundStatus.RUNNING)
.where(entity.RoundMember.round_id == round.id)
sa.select(te.RoundMember)
.where(te.RoundMember.status == te.RoundStatus.RUNNING)
.where(te.RoundMember.round_id == round.id)
)
members = (await session.execute(q)).scalars().all()
member_addrs = [member.address for member in members]
Expand All @@ -135,7 +152,7 @@ async def upload_secret_share(

for share in shares.shares:
receiver = member_dict[share.receiver]
ss = entity.SecretShare(
ss = te.SecretShare(
sender.id,
receiver.id,
hex_to_bytes(share.seed_share),
Expand All @@ -160,35 +177,50 @@ async def get_secret_shares(
session: AsyncSession = Depends(db.get_session),
):
q = (
sa.select(entity.TaskRound)
.where(entity.TaskRound.task_id == task_id)
.where(entity.TaskRound.round == round)
sa.select(entity.Task)
.where(entity.Task.task_id == task_id)
)
task: entity.Task | None = (await session.execute(q)).scalars().one_or_none()
if task is None:
raise HTTPException(400, "task does not exist")

if task.type == "horizontal":
te = entity.horizontal
elif task.type == "hlr":
te = entity.hlr
else:
raise HTTPException(400, f"unknown task type {task.type}")

q = (
sa.select(te.TaskRound)
.where(te.TaskRound.task_id == task_id)
.where(te.TaskRound.round == round)
)
round_entity: Optional[entity.TaskRound] = (
round_entity = (
(await session.execute(q)).scalars().one_or_none()
)
if not round_entity:
raise HTTPException(400, "task round does not exist")
if round_entity.status != entity.RoundStatus.CALCULATING:
if round_entity.status != te.RoundStatus.CALCULATING:
raise HTTPException(400, "round is not in calculating phase")

q = (
sa.select(entity.RoundMember)
.where(entity.RoundMember.round_id == round_entity.id)
.where(entity.RoundMember.address == address)
.options(selectinload(entity.RoundMember.received_shares))
sa.select(te.RoundMember)
.where(te.RoundMember.round_id == round_entity.id)
.where(te.RoundMember.address == address)
.options(selectinload(te.RoundMember.received_shares))
)
member: Optional[entity.RoundMember] = (
member = (
(await session.execute(q)).scalars().one_or_none()
)
if not member:
raise HTTPException(400, f"member {address} does not exists")
if member.status != entity.RoundStatus.CALCULATING:
if member.status != te.RoundStatus.CALCULATING:
raise HTTPException(400, f"member {address} is not allowed")

sender_ids = [share.sender_id for share in member.received_shares]
q = sa.select(entity.RoundMember).where(entity.RoundMember.id.in_(sender_ids)) # type: ignore
senders: List[entity.RoundMember] = (await session.execute(q)).scalars().all()
q = sa.select(te.RoundMember).where(te.RoundMember.id.in_(sender_ids)) # type: ignore
senders = (await session.execute(q)).scalars().all()

sender_dict = {sender.id: sender for sender in senders}
shares = []
Expand Down
4 changes: 2 additions & 2 deletions delta_node/app/v1/node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from typing import List

from delta_node import chain
from delta_node.chain import identity
from fastapi import APIRouter, Query
from pydantic import BaseModel

Expand All @@ -21,7 +21,7 @@ class NodesPage(BaseModel):

@router.get("/nodes", response_model=NodesPage)
async def get_nodes(page: int = Query(..., ge=1), page_size: int = Query(20, gt=0)):
nodes, total_count = await chain.get_client().get_nodes(
nodes, total_count = await identity.get_client().get_nodes(
page=page, page_size=page_size
)
total_pages = math.ceil(total_count / page_size)
Expand Down
38 changes: 33 additions & 5 deletions delta_node/app/v1/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from typing import IO, List, Optional

import sqlalchemy as sa
from delta_node import chain, coord, db, entity, pool, registry
from delta_node import coord, db, entity, pool, registry
from delta_node.chain import hlr, horizontal
from fastapi import (
APIRouter,
BackgroundTasks,
Expand Down Expand Up @@ -42,9 +43,30 @@ def move_task_file(task_file: IO[bytes], task_id: str):
async def run_task(task_item: entity.Task, task_file: IO[bytes]):
node_address = await registry.get_node_address()

tx_hash, task_id = await chain.get_client().create_task(
node_address, task_item.dataset, task_item.commitment, task_item.type
)
try:
if task_item.type == "horizontal":
tx_hash, task_id = await horizontal.get_client().create_task(
node_address, task_item.dataset, task_item.commitment, task_item.type
)
elif task_item.type == "hlr":
tx_hash, task_id = await hlr.get_client().create_task(
node_address,
task_item.dataset,
task_item.commitment,
task_item.enable_verify,
task_item.tolerance,
)
else:
raise TypeError(f"unknown task type {task_item.type}")
except Exception as e:
async with db.session_scope() as sess:
task_item.status = entity.TaskStatus.ERROR
task_item = await sess.merge(task_item)
sess.add(task_item)
await sess.commit()
_logger.error(f"create task of id {task_item.id} error: {str(e)}")
raise

task_item.task_id = task_id
task_item.creator = node_address
task_item.status = entity.TaskStatus.RUNNING
Expand All @@ -70,7 +92,10 @@ async def create_task(
background: BackgroundTasks,
):
f = await pool.run_in_io(create_task_file, file.file)
task_item = await pool.run_in_io(coord.create_task, f)
try:
task_item = await pool.run_in_io(coord.create_task, f)
except TypeError as e:
raise HTTPException(400, str(e))
session.add(task_item)
await session.commit()
await session.refresh(task_item)
Expand All @@ -86,6 +111,7 @@ class Task(BaseModel):
type: str
creator: str
status: str
enable_verify: bool


@task_router.get("/list", response_model=List[Task])
Expand All @@ -108,6 +134,7 @@ async def get_task_list(
type=task.type,
creator=task.creator,
status=task.status.name,
enable_verify=task.enable_verify,
)
)
return task_items
Expand All @@ -130,6 +157,7 @@ async def get_task_metadata(
type=task.type,
creator=task.creator,
status=task.status.name,
enable_verify=task.enable_verify,
)


Expand Down
60 changes: 13 additions & 47 deletions delta_node/chain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,13 @@
import threading

from delta_node import config
from grpclib.client import Channel
from grpclib.config import Configuration

from .client import ChainClient

__all__ = ["init", "get_client", "close"]

_local = threading.local()


def init(
host: str = config.chain_host, port: int = config.chain_port, *, ssl: bool = False
):
if hasattr(_local, "ch") or hasattr(_local, "client"):
raise ValueError("chain has been initialized")

config = Configuration(
_keepalive_time=10,
_keepalive_timeout=5,
_keepalive_permit_without_calls=True,
_http2_max_pings_without_data=0,
)
ch = Channel(host, port, ssl=ssl, config=config)
client = ChainClient(ch)
_local.ch = ch
_local.client = client


def get_client() -> ChainClient:
if not hasattr(_local, "client"):
raise ValueError("chain has not been initialized")

client: ChainClient = _local.client
return client


def close():
if (not hasattr(_local, "client")) or (not hasattr(_local, "ch")):
raise ValueError("chain has not been initialized")

ch: Channel = _local.ch
ch.close()
delattr(_local, "ch")
delattr(_local, "client")
from . import datahub, hlr, horizontal, identity, subscribe
from .channel import close, get_channel, init

__all__ = [
"init",
"get_channel",
"close",
"datahub",
"hlr",
"horizontal",
"identity",
"subscribe",
]
Loading

0 comments on commit cd515d2

Please sign in to comment.