diff --git a/delta_node/chain/identity/client.py b/delta_node/chain/identity/client.py index bc1ad65..01009bb 100644 --- a/delta_node/chain/identity/client.py +++ b/delta_node/chain/identity/client.py @@ -32,7 +32,7 @@ async def update_name(self, address: str, name: str) -> str: _logger.error(e) raise - async def updaet_url(self, address: str, url: str) -> str: + async def update_url(self, address: str, url: str) -> str: req = pb.UpdateUrlReq(address=address, url=url) try: resp = await self.stub.UpdateUrl(req) diff --git a/delta_node/main.py b/delta_node/main.py index 02e9430..21805ea 100644 --- a/delta_node/main.py +++ b/delta_node/main.py @@ -24,18 +24,21 @@ async def _run(): await db.init(config.db) chain.init(config.chain_host, config.chain_port, ssl=False) zk.init(config.zk_host, config.zk_port, ssl=False) - await registry.register(config.node_url, config.node_name) - + r = registry.Registry(url=config.node_url, name=config.node_name) + await r.register() + + registry_fut = asyncio.create_task(r.start()) runner_fut = asyncio.create_task(runner.run()) app_fut = asyncio.create_task(app.run("0.0.0.0", config.api_port)) - fut = asyncio.gather(runner_fut, app_fut) + fut = asyncio.gather(registry_fut, runner_fut, app_fut) loop.add_signal_handler(signal.SIGINT, lambda: fut.cancel()) loop.add_signal_handler(signal.SIGTERM, lambda: fut.cancel()) try: await fut finally: - await registry.unregister() + await r.stop() + await r.unregister() chain.close() zk.close() await db.close() @@ -52,7 +55,7 @@ def run(): async def _leave(): - from delta_node import chain, config, db, log, pool, registry + from delta_node import chain, config, db, log, registry if len(config.chain_host) == 0: raise RuntimeError("chain connector host is required") @@ -69,7 +72,10 @@ async def _leave(): await db.init(config.db) chain.init(config.chain_host, config.chain_port, ssl=False) - await registry.unregister() + + r = registry.Registry(url=config.node_url, name=config.node_name) + await r.unregister() + chain.close() await db.close() listener.stop() diff --git a/delta_node/registry/registry.py b/delta_node/registry/registry.py index c5843f2..2e1b171 100644 --- a/delta_node/registry/registry.py +++ b/delta_node/registry/registry.py @@ -3,13 +3,13 @@ from typing import Optional import sqlalchemy as sa -from sqlalchemy.exc import NoResultFound from async_lru import alru_cache from delta_node import config, db -from delta_node.entity.identity import Node from delta_node.chain import identity +from delta_node.entity.identity import Node +from sqlalchemy.exc import NoResultFound -__all__ = ["register", "get_node_address", "unregister"] +__all__ = ["get_node_address", "Registry"] _logger = logging.getLogger(__name__) @@ -27,52 +27,77 @@ async def get_node_address() -> str: raise -async def register( - url: str = config.node_url, - name: str = config.node_name, -): - async with db.session_scope() as sess: - q = sa.select(Node).where(Node.id == 1) - node: Optional[Node] = (await sess.execute(q)).scalars().one_or_none() - - if node: - # join first to avoid address changed when connect to monkey chain connector - _, address = await identity.get_client().join(url, name) - updated = False - if node.address != address: - node.address = address - updated = True - if node.url != url: - await identity.get_client().updaet_url(node.address, url) - node.url = url - updated = True - if node.name != name: - await identity.get_client().update_name(node.address, name) - node.name = name - updated = True - if updated: +class Registry(object): + def __init__( + self, url: str = config.node_url, name: str = config.node_name + ) -> None: + self.url = url + self.name = name + + self.running_task: Optional[asyncio.Task] = None + + async def register(self): + _, address = await identity.get_client().join(self.url, self.name) + + async with db.session_scope() as sess: + q = sa.select(Node).where(Node.id == 1) + node: Optional[Node] = (await sess.execute(q)).scalars().one_or_none() + + if node is not None: + update = False + if node.address != address: + node.address = address + update = True + if node.url != self.url: + node.url = self.url + update = True + if node.name != self.name: + node.name = self.name + update = True + if update: + sess.add(node) + await sess.commit() + _logger.info(f"register new node, node address: {address}") + else: + _logger.info(f"registered node, node address: {address}") + else: + node = Node(url=self.url, name=self.name, address=address) sess.add(node) await sess.commit() - _logger.info(f"registered node, node address: {node.address}") + _logger.info(f"register new node, node address: {address}") - else: - _, address = await identity.get_client().join(url, name) - node = Node(url=url, name=name, address=address) - sess.add(node) - await sess.commit() - await sess.refresh(node) - _logger.info(f"register new node, node address: {node.address}") + async def unregister(self): + address = await get_node_address() + await identity.get_client().leave(address) + + async with db.session_scope() as sess: + q = sa.select(Node).where(Node.id == 1) + node = (await sess.execute(q)).scalar_one() + await sess.delete(node) + await sess.commit() -async def unregister(): - address = await get_node_address() - await identity.get_client().leave(address) + _logger.info(f"node {address} leave") - async with db.session_scope() as sess: - q = sa.select(Node).where(Node.id == 1) - node = (await sess.execute(q)).scalar_one() + async def start(self, interval: int = 60): + async def run(): + while True: + await asyncio.sleep(interval) + await identity.get_client().join(self.url, self.name) - await sess.delete(node) - await sess.commit() + if self.running_task is None: + self.running_task = asyncio.create_task(run()) + try: + await self.running_task + except asyncio.CancelledError: + pass + except Exception as e: + _logger.exception(e) + raise + else: + raise ValueError("registry is already started") - _logger.info(f"node {address} leave") + async def stop(self): + if self.running_task is not None: + self.running_task.cancel() + _logger.info("stop registry") diff --git a/delta_node/utils/precision.py b/delta_node/utils/precision.py index fc007eb..775299f 100644 --- a/delta_node/utils/precision.py +++ b/delta_node/utils/precision.py @@ -3,13 +3,13 @@ def fix_precision(arr: AggValueType, precision: int) -> AggValueType: - arr = arr.astype(np.float64) - arr = arr * (10**precision) - arr = arr.astype(np.int64) - return arr + _arr = arr.astype(np.float64) + _arr = _arr * (10**precision) + _arr = _arr.astype(np.int64) + return _arr def unfix_precision(arr: AggValueType, precision: int) -> AggValueType: - arr = arr.astype(np.float64) - arr = arr / (10**precision) - return arr + _arr = arr.astype(np.float64) + _arr = _arr / (10**precision) + return _arr diff --git a/requirements.txt b/requirements.txt index 1dad2e4..0dbf7e8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ aiosqlite==0.17.0 async_lru==1.0.2 cryptography==3.4.7 -delta-task==0.8.0 +delta-task==0.8.1 fastapi==0.70.1 grpclib==0.4.2 httpx==0.23.0 diff --git a/setup.py b/setup.py index 65dcb0a..adb25a7 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ def run_tests(self): setup( name="delta_node", - version="0.8.0", + version="0.8.1", packages=find_packages(), package_data={"delta_node": ["dataset/examples/*.csv"]}, include_package_data=True, @@ -39,7 +39,7 @@ def run_tests(self): "aiosqlite==0.17.0", "async_lru==1.0.2", "cryptography==3.4.7", - "delta-task==0.8.0", + "delta-task==0.8.1", "fastapi==0.70.1", "grpclib==0.4.2", "httpx==0.23.0", diff --git a/tests/chain/identity_test.py b/tests/chain/identity_test.py index fd8da65..70f3b0e 100644 --- a/tests/chain/identity_test.py +++ b/tests/chain/identity_test.py @@ -12,7 +12,7 @@ async def test_identity(identity_client: identity.Client): assert info.url == url # update url new_url = "http://127.0.0.1:6800" - await identity_client.updaet_url(address, new_url) + await identity_client.update_url(address, new_url) info = await identity_client.get_node_info(address) assert info.url == new_url url = new_url diff --git a/tests/registry_test.py b/tests/registry_test.py index 055e0f1..b853fc4 100644 --- a/tests/registry_test.py +++ b/tests/registry_test.py @@ -1,3 +1,4 @@ +import asyncio import pytest from delta_node import db, registry, chain @@ -10,14 +11,24 @@ async def test_register(): chain.init("127.0.0.1", 4500) url = "http://127.0.0.1:6800" name = "node1" - await registry.register(url, name) + + r = registry.Registry(url, name) + await r.register() + + fut = asyncio.create_task(r.start(interval=1)) + address = await registry.get_node_address() + info = await identity.get_client().get_node_info(address) + assert info.address == address + assert info.name == name + assert info.url == url - node_info = await identity.get_client().get_node_info(address=address) - assert node_info.address == address - assert node_info.name == name - assert node_info.url == url + try: + await asyncio.wait_for(fut, timeout=2) + except asyncio.TimeoutError: + pass - await registry.unregister() + await r.stop() + await r.unregister() chain.close() await db.close() \ No newline at end of file