From eddd483314d667ab521e7f686cef441215fcaa69 Mon Sep 17 00:00:00 2001 From: MrKevinWeiss Date: Thu, 15 Feb 2024 10:23:17 +0100 Subject: [PATCH 1/3] feat(locate): Improve location structure with more info --- src/inet_nm/cli_fake_usb.py | 5 +- src/inet_nm/cli_set_location.py | 32 ++++++++-- src/inet_nm/cli_show_location.py | 3 +- src/inet_nm/usb_ctrl.py | 101 +++++++++++++++++++++++++++++-- tests/test_example_location.py | 6 +- tests/test_example_usage.py | 2 + 6 files changed, 133 insertions(+), 16 deletions(-) diff --git a/src/inet_nm/cli_fake_usb.py b/src/inet_nm/cli_fake_usb.py index 0733bd7..a1ff40b 100644 --- a/src/inet_nm/cli_fake_usb.py +++ b/src/inet_nm/cli_fake_usb.py @@ -44,7 +44,9 @@ def add_board(id=None, **kwargs): hash_object = hashlib.md5(id.encode()) # Convert the hash to a hexadecimal string hex_hash = hash_object.hexdigest() - + DEVPATH = kwargs.get( + "DEVPATH", f"/devices/pci0000:00/0000:00:00.0/usb1/1-{board_counter}" + ) ID_PATH = kwargs.get("ID_PATH", f"pci-0000:00:00.0-usb-0:{board_counter}") device_node = kwargs.get("device_node", f"/dev/ttyUSB{board_counter + 100}") ID_VENDOR_ID = kwargs.get("ID_VENDOR_ID", hex_hash[0:4]) @@ -73,6 +75,7 @@ def add_board(id=None, **kwargs): "ID_VENDOR_FROM_DATABASE": "QinHeng Electronics", "DRIVER": "ch341", "ID_PATH": ID_PATH, + "DEVPATH": DEVPATH, }, }, ] diff --git a/src/inet_nm/cli_set_location.py b/src/inet_nm/cli_set_location.py index 52b476b..b9ecb6c 100644 --- a/src/inet_nm/cli_set_location.py +++ b/src/inet_nm/cli_set_location.py @@ -11,7 +11,11 @@ try_to_inc_map_name, ) from inet_nm.data_types import NmNode -from inet_nm.usb_ctrl import get_connected_id_paths, get_id_path_from_node +from inet_nm.usb_ctrl import ( + get_connected_id_paths, + get_id_path_from_node, + get_usb_info_from_node, +) def select_available_node(nodes: List[NmNode], mapped_locations: List[str]) -> NmNode: @@ -44,6 +48,12 @@ def _main(): parser.add_argument( "-l", "--locate", action="store_true", help="Use usb hub location" ) + parser.add_argument( + "-p", + "--power-control", + action="store_true", + help="Flag that the port has power control option", + ) args = parser.parse_args() loc_cfg = cfg.LocationConfig(config_dir=args.config) @@ -57,17 +67,29 @@ def _main(): else: available = chk.get_filtered_nodes(config=args.config) sel_node = select_available_node(available, list(loc_mapping.keys())) - location = get_id_path_from_node(sel_node) - def_name = try_to_inc_map_name(list(loc_mapping.values())) + location, hub, port = get_usb_info_from_node(sel_node) + # use list comprehension to get the names from the loc_mapping dict values + names = [usb_info["name"] for usb_info in loc_mapping.values()] + def_name = try_to_inc_map_name(names) name = nm_prompt_default_input("Enter a name for the location", default=def_name) if location in loc_mapping: res = nm_prompt_confirm( f"Overwrite {location} currently " f"{loc_mapping[location]}?", default=True ) if res: - loc_mapping[location] = name + loc_mapping[location] = { + "name": name, + "power_control": args.power_control, + "hub": hub, + "port": port, + } else: - loc_mapping[location] = name + loc_mapping[location] = { + "name": name, + "power_control": args.power_control, + "hub": hub, + "port": port, + } loc_cfg.save(loc_mapping) nm_print(f"{name} mapped to {location}") nm_print(f"Updated {loc_cfg.file_path}") diff --git a/src/inet_nm/cli_show_location.py b/src/inet_nm/cli_show_location.py index 9d04b47..e3d75c3 100644 --- a/src/inet_nm/cli_show_location.py +++ b/src/inet_nm/cli_show_location.py @@ -62,7 +62,8 @@ def _main(): if len(matching_locs) and isinstance(matching_locs[0], list): matching_locs = sorted(matching_locs) if graph: - nm_print(parse_locations(matching_locs)) + names = [usb_info["name"] for usb_info in matching_locs] + nm_print(parse_locations(names)) else: nm_print(json.dumps(matching_locs, indent=2, sort_keys=True)) diff --git a/src/inet_nm/usb_ctrl.py b/src/inet_nm/usb_ctrl.py index b5c1df9..6358103 100644 --- a/src/inet_nm/usb_ctrl.py +++ b/src/inet_nm/usb_ctrl.py @@ -1,5 +1,5 @@ import os -from typing import List, Optional +from typing import List, Optional, Set if os.getenv("INET_NM_FAKE_USB_PATH"): from inet_nm.fake_usb import Context @@ -33,7 +33,7 @@ def get_connected_uids() -> List[str]: return uids -def get_connected_id_paths() -> List[str]: +def get_connected_id_paths() -> Set[str]: """ Get the ID_PATHs of all connected USB devices. @@ -41,15 +41,104 @@ def get_connected_id_paths() -> List[str]: A list of ID_PATHs of all connected USB devices. """ context = Context() - locations = [] - for device in context.list_devices(subsystem="usb", DEVTYPE="usb_device"): + locations = set() + for device in context.list_devices(subsystem="tty"): parent = device.find_parent("usb", "usb_device") if parent is None: continue - locations.append(parent.get("ID_PATH")) + locations.add(parent.get("ID_PATH")) return locations +def _split_devpath(devpath): + """Split the devpath into hub and port. + + Args: + devpath: The devpath to split. + + Returns: + Tuple containing the hub and port. + + Example: + >>> split_devpath("usb1/1-1/1-1.3") + ("1-1", "3") + >>> split_devpath("usb1/1-1/1-1.3.2-4") + ("1-1.3.2", "4") + >>> split_devpath("usb5/5-2") + ("5", "2") + """ + end_devpath = devpath.split("/")[-1] + # check if a "-" or a "." is the later in the string + + # get the last index of "-" in the string + if "-" in end_devpath: + port = end_devpath.split("-")[-1] + if "." in port: + port = port.split(".")[-1] + else: + port = end_devpath.split(".")[-1] + + # remove the port string from the end_devpath string + hub = end_devpath[: -len(port) - 1] + return (hub, port) + + +def get_uid_from_id_path(id_path: str) -> str: + """Get the UID of a connected USB device. + + Args: + id_path: The ID_PATH of the connected USB device. + + Returns: + The UID of the connected USB device. + + Raises: + Exception: If the node is not found, maybe not connected. + """ + context = Context() + + for device in context.list_devices(subsystem="tty"): + parent = device.find_parent("usb", "usb_device") + if parent is None: + continue + if parent.get("ID_PATH") == id_path: + vendor_id = parent.get("ID_VENDOR_ID") + model_id = parent.get("ID_MODEL_ID") + serial_short = parent.get("ID_SERIAL_SHORT") + return NmNode.calculate_uid(model_id, vendor_id, serial_short) + + +def get_usb_info_from_node(node: NmNode): + """Read the USB information from a node. + + Args: + node: The node to read the USB information from. + + Returns: + A tuple containing the ID_PATH, hub, and port of the connected USB device. + + Raises: + Exception: If the node is not found, maybe not connected. + """ + context = Context() + vendor_id = node.vendor_id + model_id = node.product_id + serial_short = node.serial + + for device in context.list_devices(subsystem="tty"): + parent = device.find_parent("usb", "usb_device") + if parent is None: + continue + if ( + parent.get("ID_VENDOR_ID") == vendor_id + and parent.get("ID_MODEL_ID") == model_id + and parent.get("ID_SERIAL_SHORT") == serial_short + ): + hub, port = _split_devpath(parent.get("DEVPATH")) + return (parent.get("ID_PATH"), hub, port) + raise Exception("Node not found, maybe not connected") + + def get_id_path_from_node(node: NmNode) -> str: """ Get the ID_PATH of a connected USB device. @@ -75,7 +164,7 @@ def get_id_path_from_node(node: NmNode) -> str: and parent.get("ID_SERIAL_SHORT") == serial_short ): return parent.get("ID_PATH") - raise Exception("Node not found, maybe not connected") + raise Exception(f"Node {node} not found, maybe not connected") def get_devices_from_tty(saved_nodes: Optional[List[NmNode]] = None) -> List[NmNode]: diff --git a/tests/test_example_location.py b/tests/test_example_location.py index 09bd40b..6fd888c 100644 --- a/tests/test_example_location.py +++ b/tests/test_example_location.py @@ -147,9 +147,9 @@ def test_cli_example_locate(tmpdir, cli_readme_mock): ) try: ret = json.loads(ret) - assert ret[0]["location"] == "1.1.1" - assert ret[1]["location"] == "2.3.4" - assert ret[2]["location"] == "in the garbage" + assert ret[0]["location"]["name"] == "1.1.1" + assert ret[1]["location"]["name"] == "2.3.4" + assert ret[2]["location"]["name"] == "in the garbage" except json.JSONDecodeError: assert False, "Could not decode json\n" + ret diff --git a/tests/test_example_usage.py b/tests/test_example_usage.py index edc61cf..4865d6c 100644 --- a/tests/test_example_usage.py +++ b/tests/test_example_usage.py @@ -31,6 +31,8 @@ def test_cli_example(tmpdir, cli_readme_mock): ct.footer = "That was the example to show off and test most of the features." os.environ["NM_CONFIG_DIR"] = str(tmpdir) + if "INET_NM_FAKE_USB_PATH" in os.environ: + del os.environ["INET_NM_FAKE_USB_PATH"] ct.run_step( description="Let's just create a `board_info` list with some features...\n" From 62db7502276c5bc574d7a185889e584db641b10c Mon Sep 17 00:00:00 2001 From: MrKevinWeiss Date: Thu, 15 Feb 2024 10:24:33 +0100 Subject: [PATCH 2/3] cleanup(graph): Remove commented test --- src/inet_nm/graph.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/inet_nm/graph.py b/src/inet_nm/graph.py index 3d44957..37d35d3 100644 --- a/src/inet_nm/graph.py +++ b/src/inet_nm/graph.py @@ -125,20 +125,3 @@ def parse_locations(locations): _overlay_locations(grid, valid_locs) grid = _parse_grid(grid) return grid + "\n".join(invalid_locs) - - -# locs = [ -# ["1.1.1", "3.1.2", "1.3.4", "garbage"], -# [], -# ["1.1.1", "3.2.2", "1.3.4", "1.2.3.4"], -# ["1.1.1", "3.1.2", "1.3.4", "1.2.6"], -# ["1.1.1", "3.1.2", "1.3.4", "a.3.4"], -# ["1.1.1"], -# ["3.3.1"], -# ["3.3.3"], -# ["2.3.4"], -# ] -# for loc in locs: -# print(loc) -# print(parse_locations(loc)) -# print("==================") From 24b3fc95f2e64faefaedd2baa211068f7855c4f9 Mon Sep 17 00:00:00 2001 From: MrKevinWeiss Date: Tue, 20 Feb 2024 15:05:50 +0100 Subject: [PATCH 3/3] feat(update_cache): Initial location cache --- setup.cfg | 1 + src/inet_nm/cli_update_cache.py | 33 ++++++++++++++++++++++++++ src/inet_nm/location.py | 42 +++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+) create mode 100644 src/inet_nm/cli_update_cache.py create mode 100644 src/inet_nm/location.py diff --git a/setup.cfg b/setup.cfg index de7f56a..0c402a3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -76,6 +76,7 @@ console_scripts = inet-nm-update-from-os = inet_nm.cli_update_from_os:main inet-nm-set-location = inet_nm.cli_set_location:main inet-nm-show-location = inet_nm.cli_show_location:main + inet-nm-update-cache = inet_nm.cli_update_cache:main [tool:pytest] diff --git a/src/inet_nm/cli_update_cache.py b/src/inet_nm/cli_update_cache.py new file mode 100644 index 0000000..0114d2d --- /dev/null +++ b/src/inet_nm/cli_update_cache.py @@ -0,0 +1,33 @@ +import argparse + +import inet_nm.config as cfg +import inet_nm.location as loc +from inet_nm._helpers import nm_print + + +def _main(): + parser = argparse.ArgumentParser(description="Update the location cache") + cfg.config_arg(parser) + + args = parser.parse_args() + loc_mapping = cfg.LocationConfig(config_dir=args.config).load() + nodes = cfg.NodesConfig(config_dir=args.config).load() + loc_cache = cfg.LocationCache(config_dir=args.config) + loc_cache.check_file(writable=True) + + cache = loc.get_location_cache(nodes, loc_mapping) + + loc_cache.save(cache) + nm_print(f"Updated {loc_cache.file_path}") + + +def main(): + """Updates the current state of board locations.""" + try: + _main() + except KeyboardInterrupt: + nm_print("\nUser aborted...") + + +if __name__ == "__main__": + main() diff --git a/src/inet_nm/location.py b/src/inet_nm/location.py new file mode 100644 index 0000000..da61eb8 --- /dev/null +++ b/src/inet_nm/location.py @@ -0,0 +1,42 @@ +from typing import Dict, List + +import inet_nm.usb_ctrl as ucl +from inet_nm.data_types import NmNode + + +def get_location_cache(nodes: List[NmNode], id_paths: Dict): + """ + Get the location cache for a list of NmNode objects. + + Args: + nodes: List of NmNode objects. + id_paths: List of id_paths to check. + + Returns: + The location cache. + """ + processed_id_paths = set() + cache = [] + node_uids = {node.uid for node in nodes if not node.ignore} + for id_path in id_paths: + node_uid = ucl.get_uid_from_id_path(id_path) + if node_uid is not None and node_uid in node_uids: + cache.append( + {"id_path": id_path, "node_uid": node_uid, "state": "attached"} + ) + else: + cache.append({"id_path": id_path, "node_uid": node_uid, "state": "missing"}) + processed_id_paths.add(id_path) + + for node in nodes: + try: + id_path = ucl.get_id_path_from_node(node) + if id_path not in processed_id_paths: + cache.append( + {"id_path": id_path, "node_uid": node.uid, "state": "unassigned"} + ) + processed_id_paths.add(id_path) + except Exception: + pass + cache.sort(key=lambda x: x["id_path"]) + return cache