Skip to content

Commit

Permalink
Don't use astype for torch.Tensor (#2242)
Browse files Browse the repository at this point in the history
* Broke apart the logic of use_torch and not _process_weights_and_bonds to more clearly read it. Only apply .astype to the np version.

* Add test for weights and bonds using torch.
  • Loading branch information
thewhaleking authored Aug 30, 2024
1 parent 3632f69 commit 71b51b2
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 12 deletions.
32 changes: 20 additions & 12 deletions bittensor/metagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
import bittensor
from os import listdir
from os.path import join
from typing import List, Optional, Union, Tuple
from typing import List, Optional, Union, Tuple, cast

from bittensor.chain_data import AxonInfo
from bittensor.utils.registration import torch, use_torch
from bittensor.utils import weight_utils

METAGRAPH_STATE_DICT_NDARRAY_KEYS = [
"version",
Expand Down Expand Up @@ -648,33 +649,40 @@ def _process_weights_or_bonds(
self.weights = self._process_weights_or_bonds(raw_weights_data, "weights")
"""
data_array = []
data_array: list[Union[NDArray[np.float32], "torch.Tensor"]] = []
for item in data:
if len(item) == 0:
if use_torch():
data_array.append(torch.zeros(len(self.neurons))) # type: ignore
data_array.append(torch.zeros(len(self.neurons)))
else:
data_array.append(np.zeros(len(self.neurons), dtype=np.float32)) # type: ignore
data_array.append(np.zeros(len(self.neurons), dtype=np.float32))
else:
uids, values = zip(*item)
# TODO: Validate and test the conversion of uids and values to tensor
if attribute == "weights":
data_array.append(
bittensor.utils.weight_utils.convert_weight_uids_and_vals_to_tensor(
weight_utils.convert_weight_uids_and_vals_to_tensor(
len(self.neurons),
list(uids),
list(values), # type: ignore
list(values),
)
)
else:
data_array.append(
bittensor.utils.weight_utils.convert_bond_uids_and_vals_to_tensor( # type: ignore
len(self.neurons), list(uids), list(values)
).astype(np.float32)
da_item = weight_utils.convert_bond_uids_and_vals_to_tensor(
len(self.neurons), list(uids), list(values)
)
if use_torch():
data_array.append(cast("torch.LongTensor", da_item))
else:
data_array.append(
cast(NDArray[np.float32], da_item).astype(np.float32)
)
tensor_param: Union["torch.nn.Parameter", NDArray] = (
(
torch.nn.Parameter(torch.stack(data_array), requires_grad=False)
torch.nn.Parameter(
torch.stack(cast(list["torch.Tensor"], data_array)),
requires_grad=False,
)
if len(data_array)
else torch.nn.Parameter()
)
Expand Down Expand Up @@ -730,7 +738,7 @@ def _process_root_weights(
uids, values = zip(*item)
# TODO: Validate and test the conversion of uids and values to tensor
data_array.append(
bittensor.utils.weight_utils.convert_root_weight_uids_and_vals_to_tensor( # type: ignore
weight_utils.convert_root_weight_uids_and_vals_to_tensor( # type: ignore
n_subnets, list(uids), list(values), subnets
)
)
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/test_metagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,37 @@ def test_process_weights_or_bonds(mock_environment):
# TODO: Add more checks to ensure the bonds have been processed correctly


def test_process_weights_or_bonds_torch(
mock_environment, force_legacy_torch_compat_api
):
_, neurons = mock_environment
metagraph = bittensor.metagraph(1, sync=False)
metagraph.neurons = neurons

# Test weights processing
weights = metagraph._process_weights_or_bonds(
data=[neuron.weights for neuron in neurons], attribute="weights"
)
assert weights.shape[0] == len(
neurons
) # Number of rows should be equal to number of neurons
assert weights.shape[1] == len(
neurons
) # Number of columns should be equal to number of neurons
# TODO: Add more checks to ensure the weights have been processed correctly

# Test bonds processing
bonds = metagraph._process_weights_or_bonds(
data=[neuron.bonds for neuron in neurons], attribute="bonds"
)
assert bonds.shape[0] == len(
neurons
) # Number of rows should be equal to number of neurons
assert bonds.shape[1] == len(
neurons
) # Number of columns should be equal to number of neurons


# Mocking the bittensor.subtensor class for testing purposes
@pytest.fixture
def mock_subtensor():
Expand Down

0 comments on commit 71b51b2

Please sign in to comment.