From 71b51b2096090b650724489e0b521a1b3f813706 Mon Sep 17 00:00:00 2001 From: Benjamin Himes <37844818+thewhaleking@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:16:26 +0200 Subject: [PATCH] Don't use `astype` for torch.Tensor (#2242) * Broke apart the logic of use_torch and not _process_weights_and_bonds to more clearly read it. Only apply .astype to the np version. * Add test for weights and bonds using torch. --- bittensor/metagraph.py | 32 +++++++++++++++++++----------- tests/unit_tests/test_metagraph.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/bittensor/metagraph.py b/bittensor/metagraph.py index 8d7e97bcc0..420c847a09 100644 --- a/bittensor/metagraph.py +++ b/bittensor/metagraph.py @@ -25,10 +25,11 @@ import bittensor from os import listdir from os.path import join -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Union, Tuple, cast from bittensor.chain_data import AxonInfo from bittensor.utils.registration import torch, use_torch +from bittensor.utils import weight_utils METAGRAPH_STATE_DICT_NDARRAY_KEYS = [ "version", @@ -648,33 +649,40 @@ def _process_weights_or_bonds( self.weights = self._process_weights_or_bonds(raw_weights_data, "weights") """ - data_array = [] + data_array: list[Union[NDArray[np.float32], "torch.Tensor"]] = [] for item in data: if len(item) == 0: if use_torch(): - data_array.append(torch.zeros(len(self.neurons))) # type: ignore + data_array.append(torch.zeros(len(self.neurons))) else: - data_array.append(np.zeros(len(self.neurons), dtype=np.float32)) # type: ignore + data_array.append(np.zeros(len(self.neurons), dtype=np.float32)) else: uids, values = zip(*item) # TODO: Validate and test the conversion of uids and values to tensor if attribute == "weights": data_array.append( - bittensor.utils.weight_utils.convert_weight_uids_and_vals_to_tensor( + weight_utils.convert_weight_uids_and_vals_to_tensor( len(self.neurons), list(uids), - list(values), # type: ignore + list(values), ) ) else: - data_array.append( - bittensor.utils.weight_utils.convert_bond_uids_and_vals_to_tensor( # type: ignore - len(self.neurons), list(uids), list(values) - ).astype(np.float32) + da_item = weight_utils.convert_bond_uids_and_vals_to_tensor( + len(self.neurons), list(uids), list(values) ) + if use_torch(): + data_array.append(cast("torch.LongTensor", da_item)) + else: + data_array.append( + cast(NDArray[np.float32], da_item).astype(np.float32) + ) tensor_param: Union["torch.nn.Parameter", NDArray] = ( ( - torch.nn.Parameter(torch.stack(data_array), requires_grad=False) + torch.nn.Parameter( + torch.stack(cast(list["torch.Tensor"], data_array)), + requires_grad=False, + ) if len(data_array) else torch.nn.Parameter() ) @@ -730,7 +738,7 @@ def _process_root_weights( uids, values = zip(*item) # TODO: Validate and test the conversion of uids and values to tensor data_array.append( - bittensor.utils.weight_utils.convert_root_weight_uids_and_vals_to_tensor( # type: ignore + weight_utils.convert_root_weight_uids_and_vals_to_tensor( # type: ignore n_subnets, list(uids), list(values), subnets ) ) diff --git a/tests/unit_tests/test_metagraph.py b/tests/unit_tests/test_metagraph.py index af0dbdba76..40303297a5 100644 --- a/tests/unit_tests/test_metagraph.py +++ b/tests/unit_tests/test_metagraph.py @@ -124,6 +124,37 @@ def test_process_weights_or_bonds(mock_environment): # TODO: Add more checks to ensure the bonds have been processed correctly +def test_process_weights_or_bonds_torch( + mock_environment, force_legacy_torch_compat_api +): + _, neurons = mock_environment + metagraph = bittensor.metagraph(1, sync=False) + metagraph.neurons = neurons + + # Test weights processing + weights = metagraph._process_weights_or_bonds( + data=[neuron.weights for neuron in neurons], attribute="weights" + ) + assert weights.shape[0] == len( + neurons + ) # Number of rows should be equal to number of neurons + assert weights.shape[1] == len( + neurons + ) # Number of columns should be equal to number of neurons + # TODO: Add more checks to ensure the weights have been processed correctly + + # Test bonds processing + bonds = metagraph._process_weights_or_bonds( + data=[neuron.bonds for neuron in neurons], attribute="bonds" + ) + assert bonds.shape[0] == len( + neurons + ) # Number of rows should be equal to number of neurons + assert bonds.shape[1] == len( + neurons + ) # Number of columns should be equal to number of neurons + + # Mocking the bittensor.subtensor class for testing purposes @pytest.fixture def mock_subtensor():