convolution added to numpy backend #517

Open · wants to merge 2 commits into develop
3 changes: 1 addition & 2 deletions mrmustard/math/backend_manager.py
@@ -1527,8 +1527,7 @@ def all_diagonals(self, rho: Tensor, real: bool) -> Tensor:

def poisson(self, max_k: int, rate: Tensor) -> Tensor:
"""Poisson distribution up to ``max_k``."""
k = self.arange(max_k)
rate = self.cast(rate, k.dtype)
k = self.arange(max_k, dtype=rate.dtype)
return self.exp(k * self.log(rate + 1e-9) - rate - self.lgamma(k + 1.0))

def binomial_conditional_prob(self, success_prob: Tensor, dim_out: int, dim_in: int):
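The updated `poisson` helper builds the grid `k` directly in the dtype of `rate` and evaluates the Poisson PMF, P(k) = rate**k * exp(-rate) / k!, in log space for numerical stability (the 1e-9 offset guards against log(0)). A minimal plain-numpy sketch of the same computation, outside the backend API (the function name and values here are illustrative):

```python
import numpy as np
from scipy.special import gammaln  # log(k!), same role as the backend's lgamma

def poisson_pmf(max_k: int, rate: float) -> np.ndarray:
    """Poisson probabilities P(0), ..., P(max_k - 1), computed in log space."""
    k = np.arange(max_k, dtype=np.float64)
    # exp(k*log(rate) - rate - log(k!)) == rate**k * exp(-rate) / k!
    return np.exp(k * np.log(rate + 1e-9) - rate - gammaln(k + 1.0))

probs = poisson_pmf(10, rate=2.0)
assert np.isclose(probs.sum(), 1.0, atol=1e-3)  # most of the Poisson(2) mass lies below k = 10
```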
54 changes: 54 additions & 0 deletions mrmustard/math/backend_numpy.py
@@ -23,6 +23,7 @@

import numpy as np
import scipy as sp
from scipy.signal import convolve2d as scipy_convolve2d
from scipy.linalg import expm as scipy_expm
from scipy.linalg import sqrtm as scipy_sqrtm
from scipy.special import xlogy as scipy_xlogy
@@ -136,6 +137,59 @@ def concat(self, values: list[np.ndarray], axis: int) -> np.ndarray:
def conj(self, array: np.ndarray) -> np.ndarray:
return np.conj(array)

def convolution(
    self,
    array: np.ndarray,  # shape: (batch, height, width, in_channels)
    filters: np.ndarray,  # shape: (kernel_height, kernel_width, in_channels, out_channels)
    padding: str = "VALID",
    data_format: str | None = None,
) -> np.ndarray:  # returns: (batch, new_height, new_width, out_channels)
    """Performs a 2D convolution similar to ``tf.nn.convolution`` using numpy.

    Only the first input and output channel are convolved; any remaining
    output channels are left as zeros.

    Args:
        array: Input array of shape (batch, height, width, in_channels)
        filters: Filter kernel of shape (kernel_height, kernel_width, in_channels, out_channels)
        padding: String indicating the padding type ('VALID' or 'SAME')
        data_format: Unused, kept for API compatibility

    Returns:
        np.ndarray: Result of the convolution with shape (batch, new_height, new_width, out_channels)
    """
    # Extract shapes
    batch, in_height, in_width, in_channels = array.shape
    kernel_h, kernel_w, _, out_channels = filters.shape

    # Take the first input/output channel of the filter as a 2D kernel
    filter_2d = filters[:, :, 0, 0]

    # For SAME padding, calculate padding sizes
    if padding == "SAME":
        pad_h = (kernel_h - 1) // 2
        pad_w = (kernel_w - 1) // 2
        array = np.pad(
            array[:, :, :, 0], ((0, 0), (pad_h, pad_h), (pad_w, pad_w)), mode="constant"
        )
    else:
        array = array[:, :, :, 0]

    # Calculate output dimensions ('valid' convolution on the possibly padded array)
    out_height = array.shape[1] - kernel_h + 1
    out_width = array.shape[2] - kernel_w + 1

    # Initialize output array
    output = np.zeros((batch, out_height, out_width, out_channels))

    # Perform the convolution for each batch element
    for b in range(batch):
        # scipy's convolve2d handles 2D arrays directly (np.convolve is 1-D only)
        output[b, :, :, 0] = scipy_convolve2d(
            array[b],
            # Pre-flip the kernel: convolve2d flips it internally, so the
            # result matches the cross-correlation convention of tf.nn.convolution
            np.flip(np.flip(filter_2d, 0), 1),
            mode="valid",
        )

    return output

def cos(self, array: np.ndarray) -> np.ndarray:
return np.cos(array)

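A note on the double flip inside `convolution`: `scipy.signal.convolve2d` computes a true convolution (it flips the kernel internally), whereas `tf.nn.convolution` actually computes a cross-correlation (no flip). Pre-flipping the kernel along both axes cancels scipy's internal flip, so the numpy backend reproduces the TF convention. A minimal stand-alone sketch of that equivalence (plain numpy/scipy, not the backend API; names and sizes here are illustrative):

```python
import numpy as np
from scipy.signal import convolve2d

rng = np.random.default_rng(0)
image = rng.standard_normal((6, 6))
kernel = rng.standard_normal((3, 3))

# Cross-correlation the TF way: slide the *un-flipped* kernel over the image.
out_h, out_w = image.shape[0] - 3 + 1, image.shape[1] - 3 + 1
direct = np.empty((out_h, out_w))
for i in range(out_h):
    for j in range(out_w):
        direct[i, j] = np.sum(image[i:i + 3, j:j + 3] * kernel)

# Same result via convolve2d: pre-flip the kernel to undo scipy's internal flip.
via_scipy = convolve2d(image, np.flip(np.flip(kernel, 0), 1), mode="valid")

assert np.allclose(direct, via_scipy)
```

Restricting the implementation to `filters[:, :, 0, 0]` keeps the loop to a single `convolve2d` call per batch element; if multi-channel kernels are ever needed, the loop would also have to sum over input channels and iterate over output channels.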