Skip to content

Commit

Permalink
add strong ref to pending tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
synodriver committed Mar 21, 2023
1 parent 72b5840 commit 7f0a073
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 55 deletions.
10 changes: 7 additions & 3 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,15 @@ def callback2(trigger, data:dict):

### v1.3.5rc2

* parser for aria2files
* add parser for aria2 files

```python
from pprint import pprint
from aioaria2 import DHTFile

pprint(DHTFile.from_file("dht.dat"))
```
pprint(DHTFile.from_file2("dht.dat"))
```

### v1.3.5rc3

* add strong ref to pending tasks
14 changes: 14 additions & 0 deletions README_zh.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -198,5 +198,19 @@ def callback2(trigger, data:dict):

* 更好的关闭

### v1.3.5rc2

* aria2二进制文件解析器

```python
from pprint import pprint
from aioaria2 import DHTFile

pprint(DHTFile.from_file2("dht.dat"))
```

### v1.3.5rc3

* 对没完成的task使用强引用以防止被gc

![title](https://konachan.com/sample/c7f565c0cd96e58908bc852dd754f61a/Konachan.com%20-%20302356%20sample.jpg)
2 changes: 1 addition & 1 deletion aioaria2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from aioaria2.server import Aria2Server, AsyncAria2Server
from aioaria2.utils import add_async_callback, run_sync

__version__ = "1.3.5rc2"
__version__ = "1.3.5rc3"

__author__ = "synodriver"
__all__ = [
Expand Down
9 changes: 8 additions & 1 deletion aioaria2/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@ def __init__(
list
) # 存放各个notice的回调
self._listen_task = None # type: asyncio.Task
self._pending_tasks = set()

@classmethod
async def new(
Expand Down Expand Up @@ -916,6 +917,10 @@ def closed(self) -> bool:
async def close(self) -> None:
if self._listen_task and not self._listen_task.cancelled():
self._listen_task.cancel()
try:
await self._listen_task
except asyncio.CancelledError:
pass
await super().close()
await self._client_session.close()

Expand All @@ -931,7 +936,9 @@ async def listen(self) -> None:
continue
if not data or not isinstance(data, dict):
continue
asyncio.create_task(self.handle_event(data))
task = asyncio.create_task(self.handle_event(data))
self._pending_tasks.add(task) # add a strong ref
task.add_done_callback(self._pending_tasks.discard)
except asyncio.CancelledError:
pass

Expand Down
83 changes: 39 additions & 44 deletions aioaria2/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ class InFlightPiece:
piece_bitfield: bytes

@classmethod
def from_reader(cls, reader: IO[bytes], version: int) -> "InFlightPiece":
index = int.from_bytes(reader.read(4), "big" if version == 1 else "little")
length = int.from_bytes(reader.read(4), "big" if version == 1 else "little")
def from_file(cls, file: IO[bytes], version: int) -> "InFlightPiece":
index = int.from_bytes(file.read(4), "big" if version == 1 else "little")
length = int.from_bytes(file.read(4), "big" if version == 1 else "little")
piece_bitfield_length = int.from_bytes(
reader.read(4), "big" if version == 1 else "little"
file.read(4), "big" if version == 1 else "little"
)
piece_bitfield = reader.read(piece_bitfield_length)
piece_bitfield = file.read(piece_bitfield_length)
return cls(
index=index,
length=length,
Expand Down Expand Up @@ -58,41 +58,36 @@ class ControlFile:
inflight_pieces: List[InFlightPiece]

@classmethod
def from_file(cls, file: Union[str, Path]) -> "ControlFile":
def from_file2(cls, file: Union[str, Path]) -> "ControlFile":
with open(file, "rb") as f:
return cls.from_reader(f)

@classmethod
def from_reader(cls, reader: IO[bytes]) -> "ControlFile":
version = int.from_bytes(reader.read(2), "big")
ext = reader.read(4)
def from_file(cls, file: IO[bytes]) -> "ControlFile":
version = int.from_bytes(file.read(2), "big")
ext = file.read(4)
info_hash_length = int.from_bytes(
reader.read(4), "big" if version == 1 else "little"
file.read(4), "big" if version == 1 else "little"
)
if info_hash_length == 0 and ext[3] & 1 == 1:
raise ValueError(
'"infoHashCheck" extension is enabled but info hash length is 0'
)
info_hash = reader.read(info_hash_length)
piece_length = int.from_bytes(
reader.read(4), "big" if version == 1 else "little"
)
total_length = int.from_bytes(
reader.read(8), "big" if version == 1 else "little"
)
info_hash = file.read(info_hash_length)
piece_length = int.from_bytes(file.read(4), "big" if version == 1 else "little")
total_length = int.from_bytes(file.read(8), "big" if version == 1 else "little")
upload_length = int.from_bytes(
reader.read(8), "big" if version == 1 else "little"
file.read(8), "big" if version == 1 else "little"
)
bitfield_length = int.from_bytes(
reader.read(4), "big" if version == 1 else "little"
file.read(4), "big" if version == 1 else "little"
)
bitfield = reader.read(bitfield_length)
bitfield = file.read(bitfield_length)
num_inflight_piece = int.from_bytes(
reader.read(4), "big" if version == 1 else "little"
file.read(4), "big" if version == 1 else "little"
)
inflight_pieces = [
InFlightPiece.from_reader(reader, version)
for _ in range(num_inflight_piece)
InFlightPiece.from_file(file, version) for _ in range(num_inflight_piece)
]

return cls(
Expand Down Expand Up @@ -145,15 +140,15 @@ class NodeInfo:
node_id: bytes

@classmethod
def from_reader(cls, reader: IO[bytes]) -> "NodeInfo":
plen = int.from_bytes(reader.read(1), "big")
reader.read(7)
def from_file(cls, file: IO[bytes]) -> "NodeInfo":
plen = int.from_bytes(file.read(1), "big")
file.read(7)
class_ = IPv4Address if plen == 6 else IPv6Address
temp = reader.read(plen)
temp = file.read(plen)
compact_peer_info = (class_(temp[:-2]), int.from_bytes(temp[-2:], "big"))
reader.read(24 - plen)
node_id = reader.read(20)
reader.read(4)
file.read(24 - plen)
node_id = file.read(20)
file.read(4)
return cls(plen=plen, compact_peer_info=compact_peer_info, node_id=node_id)

def save(self, file: IO[bytes]) -> None:
Expand Down Expand Up @@ -183,26 +178,26 @@ class DHTFile:
nodes: List[NodeInfo]

@classmethod
def from_file(cls, file: Union[str, Path]) -> "DHTFile":
def from_file2(cls, file: Union[str, Path]) -> "DHTFile":
with open(file, "rb") as f:
return cls.from_reader(f)
return cls.from_file(f)

@classmethod
def from_reader(cls, reader: IO[bytes]) -> "DHTFile":
mgc = reader.read(2)
def from_file(cls, file: IO[bytes]) -> "DHTFile":
mgc = file.read(2)
assert mgc == b"\xa1\xa2", "wrong magic number"
fmt = reader.read(1)
fmt = file.read(1)
assert fmt == b"\x02", "wrong format idr"
ver = reader.read(2)
ver = file.read(2)
# assert ver == b'\x00\x03', "wrong version number"
reader.read(3)
mtime = int.from_bytes(reader.read(8), "big")
reader.read(8)
localnode_id = reader.read(20)
reader.read(4)
num_node = int.from_bytes(reader.read(4), "big")
reader.read(4)
nodes = [NodeInfo.from_reader(reader) for _ in range(num_node)]
file.read(3)
mtime = int.from_bytes(file.read(8), "big")
file.read(8)
localnode_id = file.read(20)
file.read(4)
num_node = int.from_bytes(file.read(4), "big")
file.read(4)
nodes = [NodeInfo.from_file(file) for _ in range(num_node)]
return cls(
mgc=mgc,
fmt=fmt,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[metadata]
# replace with your username:
name = aioaria2
version = 1.3.4
version = 1.3.5rc3
keywords =
asyncio
Aria2
Expand Down
17 changes: 12 additions & 5 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,34 @@
# -*- coding: utf-8 -*-
from io import BytesIO
from unittest import TestCase
from os.path import join, dirname

from aioaria2 import ControlFile, DHTFile


class Testarser(TestCase):
def test_ControlFile(self):
s = BytesIO()
data = ControlFile.from_file("180P_225K_242958531.webm.aria2")
data = ControlFile.from_file2("180P_225K_242958531.webm.aria2")
data.save(s)
self.assertEqual(
s.getvalue(), open("180P_225K_242958531.webm.aria2", "rb").read()
)

def test_DHTFile(self):
s = BytesIO()
data = DHTFile.from_file("dht.dat")
data = DHTFile.from_file2(join(dirname(__file__), "dht.dat"))
data.save(s)
self.assertEqual(len(s.getvalue()), len(open("dht.dat", "rb").read()))
self.assertEqual(
len(s.getvalue()),
len(open(join(dirname(__file__), "dht.dat"), "rb").read()),
)

def test_DHTFilev6(self):
s = BytesIO()
data = DHTFile.from_file("dht6.dat")
data = DHTFile.from_file2(join(dirname(__file__), "dht6.dat"))
data.save(s)
self.assertEqual(len(s.getvalue()), len(open("dht6.dat", "rb").read()))
self.assertEqual(
len(s.getvalue()),
len(open(join(dirname(__file__), "dht6.dat"), "rb").read()),
)

0 comments on commit 7f0a073

Please sign in to comment.