Skip to content

Commit

Permalink
Merge pull request #12 from inetrg/pr/location_improvements
Browse files Browse the repository at this point in the history
Some fixes for location and add location cache
  • Loading branch information
MrKevinWeiss authored Feb 20, 2024
2 parents 2c5bde6 + 24b3fc9 commit fa46abd
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 33 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ console_scripts =
inet-nm-update-from-os = inet_nm.cli_update_from_os:main
inet-nm-set-location = inet_nm.cli_set_location:main
inet-nm-show-location = inet_nm.cli_show_location:main
inet-nm-update-cache = inet_nm.cli_update_cache:main


[tool:pytest]
Expand Down
5 changes: 4 additions & 1 deletion src/inet_nm/cli_fake_usb.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def add_board(id=None, **kwargs):
hash_object = hashlib.md5(id.encode())
# Convert the hash to a hexadecimal string
hex_hash = hash_object.hexdigest()

DEVPATH = kwargs.get(
"DEVPATH", f"/devices/pci0000:00/0000:00:00.0/usb1/1-{board_counter}"
)
ID_PATH = kwargs.get("ID_PATH", f"pci-0000:00:00.0-usb-0:{board_counter}")
device_node = kwargs.get("device_node", f"/dev/ttyUSB{board_counter + 100}")
ID_VENDOR_ID = kwargs.get("ID_VENDOR_ID", hex_hash[0:4])
Expand Down Expand Up @@ -73,6 +75,7 @@ def add_board(id=None, **kwargs):
"ID_VENDOR_FROM_DATABASE": "QinHeng Electronics",
"DRIVER": "ch341",
"ID_PATH": ID_PATH,
"DEVPATH": DEVPATH,
},
},
]
Expand Down
32 changes: 27 additions & 5 deletions src/inet_nm/cli_set_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
try_to_inc_map_name,
)
from inet_nm.data_types import NmNode
from inet_nm.usb_ctrl import get_connected_id_paths, get_id_path_from_node
from inet_nm.usb_ctrl import (
get_connected_id_paths,
get_id_path_from_node,
get_usb_info_from_node,
)


def select_available_node(nodes: List[NmNode], mapped_locations: List[str]) -> NmNode:
Expand Down Expand Up @@ -44,6 +48,12 @@ def _main():
parser.add_argument(
"-l", "--locate", action="store_true", help="Use usb hub location"
)
parser.add_argument(
"-p",
"--power-control",
action="store_true",
help="Flag that the port has power control option",
)

args = parser.parse_args()
loc_cfg = cfg.LocationConfig(config_dir=args.config)
Expand All @@ -57,17 +67,29 @@ def _main():
else:
available = chk.get_filtered_nodes(config=args.config)
sel_node = select_available_node(available, list(loc_mapping.keys()))
location = get_id_path_from_node(sel_node)
def_name = try_to_inc_map_name(list(loc_mapping.values()))
location, hub, port = get_usb_info_from_node(sel_node)
# use list comprehension to get the names from the loc_mapping dict values
names = [usb_info["name"] for usb_info in loc_mapping.values()]
def_name = try_to_inc_map_name(names)
name = nm_prompt_default_input("Enter a name for the location", default=def_name)
if location in loc_mapping:
res = nm_prompt_confirm(
f"Overwrite {location} currently " f"{loc_mapping[location]}?", default=True
)
if res:
loc_mapping[location] = name
loc_mapping[location] = {
"name": name,
"power_control": args.power_control,
"hub": hub,
"port": port,
}
else:
loc_mapping[location] = name
loc_mapping[location] = {
"name": name,
"power_control": args.power_control,
"hub": hub,
"port": port,
}
loc_cfg.save(loc_mapping)
nm_print(f"{name} mapped to {location}")
nm_print(f"Updated {loc_cfg.file_path}")
Expand Down
3 changes: 2 additions & 1 deletion src/inet_nm/cli_show_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def _main():
if len(matching_locs) and isinstance(matching_locs[0], list):
matching_locs = sorted(matching_locs)
if graph:
nm_print(parse_locations(matching_locs))
names = [usb_info["name"] for usb_info in matching_locs]
nm_print(parse_locations(names))
else:
nm_print(json.dumps(matching_locs, indent=2, sort_keys=True))

Expand Down
33 changes: 33 additions & 0 deletions src/inet_nm/cli_update_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import argparse

import inet_nm.config as cfg
import inet_nm.location as loc
from inet_nm._helpers import nm_print


def _main():
parser = argparse.ArgumentParser(description="Update the location cache")
cfg.config_arg(parser)

args = parser.parse_args()
loc_mapping = cfg.LocationConfig(config_dir=args.config).load()
nodes = cfg.NodesConfig(config_dir=args.config).load()
loc_cache = cfg.LocationCache(config_dir=args.config)
loc_cache.check_file(writable=True)

cache = loc.get_location_cache(nodes, loc_mapping)

loc_cache.save(cache)
nm_print(f"Updated {loc_cache.file_path}")


def main():
"""Updates the current state of board locations."""
try:
_main()
except KeyboardInterrupt:
nm_print("\nUser aborted...")


if __name__ == "__main__":
main()
17 changes: 0 additions & 17 deletions src/inet_nm/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,20 +125,3 @@ def parse_locations(locations):
_overlay_locations(grid, valid_locs)
grid = _parse_grid(grid)
return grid + "\n".join(invalid_locs)


# locs = [
# ["1.1.1", "3.1.2", "1.3.4", "garbage"],
# [],
# ["1.1.1", "3.2.2", "1.3.4", "1.2.3.4"],
# ["1.1.1", "3.1.2", "1.3.4", "1.2.6"],
# ["1.1.1", "3.1.2", "1.3.4", "a.3.4"],
# ["1.1.1"],
# ["3.3.1"],
# ["3.3.3"],
# ["2.3.4"],
# ]
# for loc in locs:
# print(loc)
# print(parse_locations(loc))
# print("==================")
42 changes: 42 additions & 0 deletions src/inet_nm/location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Dict, List

import inet_nm.usb_ctrl as ucl
from inet_nm.data_types import NmNode


def get_location_cache(nodes: List[NmNode], id_paths: Dict):
"""
Get the location cache for a list of NmNode objects.
Args:
nodes: List of NmNode objects.
id_paths: List of id_paths to check.
Returns:
The location cache.
"""
processed_id_paths = set()
cache = []
node_uids = {node.uid for node in nodes if not node.ignore}
for id_path in id_paths:
node_uid = ucl.get_uid_from_id_path(id_path)
if node_uid is not None and node_uid in node_uids:
cache.append(
{"id_path": id_path, "node_uid": node_uid, "state": "attached"}
)
else:
cache.append({"id_path": id_path, "node_uid": node_uid, "state": "missing"})
processed_id_paths.add(id_path)

for node in nodes:
try:
id_path = ucl.get_id_path_from_node(node)
if id_path not in processed_id_paths:
cache.append(
{"id_path": id_path, "node_uid": node.uid, "state": "unassigned"}
)
processed_id_paths.add(id_path)
except Exception:
pass
cache.sort(key=lambda x: x["id_path"])
return cache
101 changes: 95 additions & 6 deletions src/inet_nm/usb_ctrl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import List, Optional
from typing import List, Optional, Set

if os.getenv("INET_NM_FAKE_USB_PATH"):
from inet_nm.fake_usb import Context
Expand Down Expand Up @@ -33,23 +33,112 @@ def get_connected_uids() -> List[str]:
return uids


def get_connected_id_paths() -> List[str]:
def get_connected_id_paths() -> Set[str]:
"""
Get the ID_PATHs of all connected USB devices.
Returns:
A list of ID_PATHs of all connected USB devices.
"""
context = Context()
locations = []
for device in context.list_devices(subsystem="usb", DEVTYPE="usb_device"):
locations = set()
for device in context.list_devices(subsystem="tty"):
parent = device.find_parent("usb", "usb_device")
if parent is None:
continue
locations.append(parent.get("ID_PATH"))
locations.add(parent.get("ID_PATH"))
return locations


def _split_devpath(devpath):
"""Split the devpath into hub and port.
Args:
devpath: The devpath to split.
Returns:
Tuple containing the hub and port.
Example:
>>> split_devpath("usb1/1-1/1-1.3")
("1-1", "3")
>>> split_devpath("usb1/1-1/1-1.3.2-4")
("1-1.3.2", "4")
>>> split_devpath("usb5/5-2")
("5", "2")
"""
end_devpath = devpath.split("/")[-1]
# check if a "-" or a "." is the later in the string

# get the last index of "-" in the string
if "-" in end_devpath:
port = end_devpath.split("-")[-1]
if "." in port:
port = port.split(".")[-1]
else:
port = end_devpath.split(".")[-1]

# remove the port string from the end_devpath string
hub = end_devpath[: -len(port) - 1]
return (hub, port)


def get_uid_from_id_path(id_path: str) -> str:
"""Get the UID of a connected USB device.
Args:
id_path: The ID_PATH of the connected USB device.
Returns:
The UID of the connected USB device.
Raises:
Exception: If the node is not found, maybe not connected.
"""
context = Context()

for device in context.list_devices(subsystem="tty"):
parent = device.find_parent("usb", "usb_device")
if parent is None:
continue
if parent.get("ID_PATH") == id_path:
vendor_id = parent.get("ID_VENDOR_ID")
model_id = parent.get("ID_MODEL_ID")
serial_short = parent.get("ID_SERIAL_SHORT")
return NmNode.calculate_uid(model_id, vendor_id, serial_short)


def get_usb_info_from_node(node: NmNode):
"""Read the USB information from a node.
Args:
node: The node to read the USB information from.
Returns:
A tuple containing the ID_PATH, hub, and port of the connected USB device.
Raises:
Exception: If the node is not found, maybe not connected.
"""
context = Context()
vendor_id = node.vendor_id
model_id = node.product_id
serial_short = node.serial

for device in context.list_devices(subsystem="tty"):
parent = device.find_parent("usb", "usb_device")
if parent is None:
continue
if (
parent.get("ID_VENDOR_ID") == vendor_id
and parent.get("ID_MODEL_ID") == model_id
and parent.get("ID_SERIAL_SHORT") == serial_short
):
hub, port = _split_devpath(parent.get("DEVPATH"))
return (parent.get("ID_PATH"), hub, port)
raise Exception("Node not found, maybe not connected")


def get_id_path_from_node(node: NmNode) -> str:
"""
Get the ID_PATH of a connected USB device.
Expand All @@ -75,7 +164,7 @@ def get_id_path_from_node(node: NmNode) -> str:
and parent.get("ID_SERIAL_SHORT") == serial_short
):
return parent.get("ID_PATH")
raise Exception("Node not found, maybe not connected")
raise Exception(f"Node {node} not found, maybe not connected")


def get_devices_from_tty(saved_nodes: Optional[List[NmNode]] = None) -> List[NmNode]:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_example_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ def test_cli_example_locate(tmpdir, cli_readme_mock):
)
try:
ret = json.loads(ret)
assert ret[0]["location"] == "1.1.1"
assert ret[1]["location"] == "2.3.4"
assert ret[2]["location"] == "in the garbage"
assert ret[0]["location"]["name"] == "1.1.1"
assert ret[1]["location"]["name"] == "2.3.4"
assert ret[2]["location"]["name"] == "in the garbage"
except json.JSONDecodeError:
assert False, "Could not decode json\n" + ret

Expand Down
2 changes: 2 additions & 0 deletions tests/test_example_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def test_cli_example(tmpdir, cli_readme_mock):
ct.footer = "That was the example to show off and test most of the features."

os.environ["NM_CONFIG_DIR"] = str(tmpdir)
if "INET_NM_FAKE_USB_PATH" in os.environ:
del os.environ["INET_NM_FAKE_USB_PATH"]

ct.run_step(
description="Let's just create a `board_info` list with some features...\n"
Expand Down

0 comments on commit fa46abd

Please sign in to comment.