Skip to content

Commit

Permalink
docstring + log erros
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jun 26, 2024
1 parent 7f6f4bd commit d5f67fd
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 27 deletions.
4 changes: 4 additions & 0 deletions src/anemoi/graphs/edges/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .connections import CutOffEdgeBuilder
from .connections import KNNEdgeBuilder

__all__ = ["KNNEdgeBuilder", "CutOffEdgeBuilder"]
4 changes: 4 additions & 0 deletions src/anemoi/graphs/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .nodes import NPZNodes
from .nodes import ZarrNodes

__all__ = ["ZarrNodes", "NPZNodes"]
23 changes: 1 addition & 22 deletions src/anemoi/graphs/nodes/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,10 @@
from torch_geometric.data import HeteroData

logger = logging.getLogger(__name__)
earth_radius = 6371.0 # km


def latlon_to_radians(coords: np.ndarray) -> np.ndarray:
return np.deg2rad(coords)


def rad_to_latlon(coords: np.ndarray) -> np.ndarray:
"""Converts coordinates from radians to degrees.
Parameters
----------
coords : np.ndarray
Coordinates in radians.
Returns
-------
np.ndarray
_description_
"""
return np.rad2deg(coords)


class BaseNodeBuilder(ABC):
"""Base class for node builders."""

def register_nodes(self, graph: HeteroData, name: str) -> None:
graph[name].x = self.get_coordinates()
Expand All @@ -52,7 +32,6 @@ def get_coordinates(self) -> np.ndarray: ...
def reshape_coords(self, latitudes: np.ndarray, longitudes: np.ndarray) -> np.ndarray:
coords = np.stack([latitudes, longitudes], axis=-1).reshape((-1, 2))
coords = np.deg2rad(coords)
# TODO: type needs to be variable?
return torch.tensor(coords, dtype=torch.float32)

def transform(self, graph: HeteroData, name: str, attr_config: DotDict) -> HeteroData:
Expand Down
7 changes: 4 additions & 3 deletions src/anemoi/graphs/nodes/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,19 @@ def __init__(self, norm: Optional[str] = None):
@abstractmethod
def compute(self, nodes: NodeStorage, *args, **kwargs): ...

def get_weights(self, *args, **kwargs):
def get_weights(self, *args, **kwargs) -> torch.Tensor:
weights = self.compute(*args, **kwargs)
if weights.ndim == 1:
weights = weights[:, np.newaxis]
return self.normalize(weights)
norm_weights = self.normalize(weights)
return torch.tensor(norm_weights, dtype=torch.float32)


class UniformWeights(BaseWeights):
"""Implements a uniform weight for the nodes."""

def compute(self, nodes: NodeStorage) -> np.ndarray:
return torch.ones(nodes.num_nodes)
return np.ones(nodes.num_nodes)


class AreaWeights(BaseWeights):
Expand Down
10 changes: 8 additions & 2 deletions src/anemoi/graphs/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,11 @@ def normalize(self, values: np.ndarray) -> np.ndarray:
if self.norm == "unit-sum":
return values / np.sum(values)
if self.norm == "unit-std":
return values / np.std(values)
raise ValueError("Weight normalization must be 'l1', 'l2', 'unit-max' 'unit-sum' or 'unit-std'.")
std = np.std(values)
if std == 0:
logger.warning(f"Std. dev. of the {self.__class__.__name__} is 0. Cannot normalize.")
return values
return values / std
raise ValueError(
f"Weight normalization \"{values}\" is not valid. Options are: 'l1', 'l2', 'unit-max' 'unit-sum' or 'unit-std'."
)

0 comments on commit d5f67fd

Please sign in to comment.