Some fix for invalid broadcasting #1637

Draft pull request; wants to merge 8 commits into base branch master.
2 changes: 2 additions & 0 deletions .github/workflows/main.yml
@@ -350,6 +350,7 @@ jobs:
 - TEST=rf_array
 - TEST=rf_attention
 - TEST=rf_base
+- TEST=rf_base RETURNN_FRONTEND_NATIVE=0
 - TEST=rf_cond
 - TEST=rf_const
 - TEST=rf_container
@@ -360,6 +361,7 @@
 - TEST=rf_label_smoothing
 - TEST=rf_loop
 - TEST=rf_math
+- TEST=rf_math RETURNN_FRONTEND_NATIVE=0
 - TEST=rf_normalization
 - TEST=rf_piecewise_linear
 - TEST=rf_rec
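The two new matrix entries rerun the rf_base and rf_math test jobs with the native frontend extensions disabled. A rough local equivalent, assuming TEST=rf_math maps to tests/test_rf_math.py and that pytest can act as the runner (both are assumptions; the actual runner is defined elsewhere in this workflow):

    import os
    import subprocess

    # Run the rf_math tests with the native (C++) frontend code paths turned off.
    env = dict(os.environ, RETURNN_FRONTEND_NATIVE="0")
    subprocess.check_call(["python", "-m", "pytest", "tests/test_rf_math.py"], env=env)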
9 changes: 6 additions & 3 deletions returnn/frontend/_backend.py
@@ -3,6 +3,8 @@
"""

from __future__ import annotations

import os
from typing import Optional, Any, Union, TypeVar, Generic, Type, Callable, Sequence, Dict, Tuple, List
import contextlib
import numpy
@@ -1446,10 +1448,11 @@ def select_backend_torch():
     global_backend.__class__ = backend
     BehaviorVersion.set_min_behavior_version(16)

-    from returnn.frontend import _native
+    if os.environ.get("RETURNN_FRONTEND_NATIVE", "").strip() in ("", "1"):
+        from returnn.frontend import _native

-    _native.setup()
-    _native.setup_torch()
+        _native.setup()
+        _native.setup_torch()


 def get_backend_by_tensor(tensor: Tensor, *, fallback: Optional[T2] = None) -> Union[Type[Backend[T]], T2]:
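With this change the native extension setup can be skipped via an environment variable. A minimal sketch of exercising the pure-Python path; the variable name and the accepted values ("" and "1" mean enabled) come from the check above, the rest is just an illustrative session:

    import os

    # Any other value makes select_backend_torch() skip _native.setup() / _native.setup_torch().
    os.environ["RETURNN_FRONTEND_NATIVE"] = "0"

    import returnn.frontend as rf

    rf.select_backend_torch()  # now runs without the native tensor ops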
2 changes: 1 addition & 1 deletion returnn/frontend/_native/py_utils.hpp
@@ -31,7 +31,7 @@ class PyObjectScopedRef {
 Sequence interface contract:

 - Holds borrowed reference to PyObject*.
-- Copy object itself is supposed to be fast, small object.
+- Copy PyTupleOrListStaticRef/PyTupleOrListRef itself is supposed to be fast, small object.

 Methods:

51 changes: 7 additions & 44 deletions returnn/frontend/_native/tensor_ops.cpp
@@ -253,41 +253,6 @@ static bool _isSeqSubsetReorderFast(ASeqT subset, BSeqT superset, std::vector<in
     return true;
 }
-
-static bool _isTupleSubsetReorderList(PyObject* subsetTuple, PyObject* supersetList, bool& error) {
-    int superSize = PyList_GET_SIZE(supersetList);
-    if(superSize < 0) { error = true; return false; }
-    int subSize = PyTuple_GET_SIZE(subsetTuple);
-    if(subSize < 0) { error = true; return false; }
-    if(subSize > superSize)
-        return false;
-    std::vector<bool> subsetTaken(subSize, false);
-    for(int j = 0; j < superSize; ++j) {
-        PyObject* b_ = PyList_GET_ITEM(supersetList, j);
-        int i = 0;
-        for(; i < subSize; ++i) {
-            if(subsetTaken[i]) continue;
-            PyObject* a_ = PyTuple_GET_ITEM(subsetTuple, i);
-            if(a_ == b_) break;
-        }
-        if(i == subSize) { // not found, try again using rich compare
-            for(; i < subSize; ++i) {
-                if(subsetTaken[i]) continue;
-                PyObject* a_ = PyTuple_GET_ITEM(subsetTuple, i);
-                int eq = PyObject_RichCompareBool(a_, b_, Py_EQ);
-                if(eq < 0) { error = true; return false; }
-                if(eq) break;
-            }
-        }
-        if(i < subSize)
-            subsetTaken[i] = true;
-    }
-    for(int i = 0; i < subSize; ++i) {
-        if(!subsetTaken[i])
-            return false;
-    }
-    return true;
-}

 PyObject* pyTensorCopy(PyObject *self, PyObject *args, PyObject *kwargs) {
     static const char *kwlist[] = { "tensor", "name", NULL };
     PyObject* tensor;
@@ -1203,6 +1168,8 @@ static PyObject* compareOrCombine(
     // collect all dims
     PyObjectScopedRef allDims = PyList_New(0);
     if(!allDims) return NULL;
+    bool aDimsHaveAll = true;
+    bool bDimsHaveAll = true;
     for(int i = 0; i < aDimsSeq.size() + bDimsSeq.size(); ++i) {
         PyObject* dim =
             i < aDimsSeq.size() ?
@@ -1221,8 +1188,10 @@
         // and this allows for a faster path.
         int aDimsCount = PySequence_Count(aDims, dim);
         if(aDimsCount < 0) return NULL;
+        if(aDimsCount == 0) aDimsHaveAll = false;
         int bDimsCount = PySequence_Count(bDims, dim);
         if(bDimsCount < 0) return NULL;
+        if(bDimsCount == 0) bDimsHaveAll = false;
         if(aDimsCount <= 1 && bDimsCount <= 1) {
             if(PyList_Append(allDims, dim) < 0) return NULL;
             continue;
@@ -1249,17 +1218,11 @@
     }
     PyTupleOrListStaticRef<false> allDimsSeq(allDims);

-    // check if all dims are in a and b, or whether we need allowBroadcastAllSources
-    bool error = false;
-    bool aDimsIsSubset = _isTupleSubsetReorderList(aDims, allDims, error);
-    if(error) return NULL;
-    bool bDimsIsSubset = _isTupleSubsetReorderList(bDims, allDims, error);
-    if(error) return NULL;
-    if(!aDimsIsSubset && !bDimsIsSubset) {
-        if(!allowBroadcastAllSources) {
+    if(!allowBroadcastAllSources) {
+        if(!aDimsHaveAll && !bDimsHaveAll) {
             PyErr_Format(
                 PyExc_ValueError,
-                "compareOrCombine: sources %R %R not allowed with allow_broadcast_all_sources=False",
+                "compareOrCombine: sources %R %R not allowed, require explicit allow_broadcast_all_sources=True",
                 a, b);
             return NULL;
         }
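The removed _isTupleSubsetReorderList helper performed a quadratic subset-with-reordering check against allDims; the replacement simply records, while allDims is being collected, whether a's dims and b's dims each already cover every resulting dim. A small Python illustration of the resulting condition (the helper name and the set representation are made up for this sketch, only the logic mirrors the C++ flags above):

    def _needs_broadcast_all_sources(a_dims, b_dims, all_dims):
        """True if neither source covers all result dims, i.e. both sources would need broadcasting."""
        a_dims_have_all = all(d in a_dims for d in all_dims)  # aDimsHaveAll
        b_dims_have_all = all(d in b_dims for d in all_dims)  # bDimsHaveAll
        return not a_dims_have_all and not b_dims_have_all

    # Two sources with disjoint extra dims: both would be broadcast, so this case is rejected
    # unless allow_broadcast_all_sources=True is passed explicitly.
    assert _needs_broadcast_all_sources({"B", "T1"}, {"B", "T2"}, {"B", "T1", "T2"})
    assert not _needs_broadcast_all_sources({"B", "T", "F"}, {"B", "F"}, {"B", "T", "F"})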
5 changes: 4 additions & 1 deletion returnn/frontend/array_.py
@@ -615,7 +615,10 @@ def sequence_mask(dims: Union[Dim, Sequence[Dim]], *, device: Optional[str] = No
         return rf.constant(True, dims=())
     mask = True
     for dim in dyn_dims:
-        mask = rf.opt_logical_and(mask, dim.get_mask(dim_order=dims, device=device))
+        mask = rf.opt_logical_and(
+            mask, dim.get_mask(dim_order=dims, device=device), allow_broadcast_all_sources=True, dim_order=dims
+        )
+    assert isinstance(mask, Tensor)
     return mask


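Each per-dim mask from dim.get_mask() typically covers only batch plus its own axis, so when several dynamic dims are combined neither operand of the logical-and contains all result dims; that is why the call now passes allow_broadcast_all_sources=True together with dim_order=dims. A hypothetical test in the style of tests/test_rf_math.py below (the imports, including run_model from the shared test helpers, are assumed to match that file, and rf.sequence_mask being the public name of this function is also an assumption):

    def test_sequence_mask():
        time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
        extern_data = TensorDict({"a": Tensor("a", [batch_dim, time_dim], dtype="int32")})

        # noinspection PyShadowingNames
        def _forward_step(*, extern_data: TensorDict, **_kwargs):
            mask = rf.sequence_mask(extern_data["a"].dims)  # combines the per-dim masks as shown above
            mask.mark_as_default_output(shape=(batch_dim, time_dim))

        run_model(extern_data, lambda **_kwargs: rf.Module(), _forward_step)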
18 changes: 15 additions & 3 deletions returnn/frontend/math_.py
@@ -347,11 +347,23 @@ def opt_logical_or(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tens


 @overload
-def opt_logical_and(a: bool, b: bool) -> bool:
+def opt_logical_and(
+    a: bool,
+    b: bool,
+    *,
+    allow_broadcast_all_sources: Optional[bool] = None,
+    dim_order: Optional[Sequence[Dim]] = None,
+) -> bool:
     """logical and"""


-def opt_logical_and(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Tensor, bool]:
+def opt_logical_and(
+    a: Union[Tensor, bool],
+    b: Union[Tensor, bool],
+    *,
+    allow_broadcast_all_sources: Optional[bool] = None,
+    dim_order: Optional[Sequence[Dim]] = None,
+) -> Union[Tensor, bool]:
     """logical and"""
     if isinstance(a, bool):
         if not a:
@@ -361,7 +373,7 @@ def opt_logical_and(a: Union[Tensor, bool], b: Union[Tensor, bool]) -> Union[Ten
         if not b:
             return False
         return a
-    return combine(a, "logical_and", b)
+    return combine(a, "logical_and", b, allow_broadcast_all_sources=allow_broadcast_all_sources, dim_order=dim_order)


def is_finite(a: Tensor) -> Tensor:
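opt_logical_and keeps its bool short-circuits and only forwards the new keyword arguments to combine when both inputs are tensors. A brief usage sketch (the helper below is hypothetical; only rf.opt_logical_and and its keyword arguments come from the diff above):

    import returnn.frontend as rf

    # Plain bools still short-circuit without ever reaching combine():
    assert rf.opt_logical_and(True, False) is False
    assert rf.opt_logical_and(False, True) is False

    def and_masks(mask_a, mask_b, dims):
        """Hypothetical helper: logical-and of two masks whose dims only partially overlap."""
        return rf.opt_logical_and(mask_a, mask_b, allow_broadcast_all_sources=True, dim_order=dims)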
4 changes: 2 additions & 2 deletions returnn/tensor/_tensor_extra.py
@@ -2969,10 +2969,10 @@ def copy_masked(

         import returnn.frontend as rf

-        mask = None
+        mask = True
         for axis in axes:
             mask_ = self._dims[axis].get_mask(dim_order=self.dims, device=self.device)
-            mask = rf.logical_and(mask, mask_) if mask is not None else mask_
+            mask = rf.opt_logical_and(mask, mask_, allow_broadcast_all_sources=True, dim_order=self.dims)
         assert isinstance(mask, _t.Tensor)
         res = rf.where(mask, self, mask_value)
         if use_padding_info:
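Starting the accumulation from the bool True removes the old None special case: the first iteration of the loop just returns the first per-axis mask (bool short-circuit in opt_logical_and), and later iterations combine with explicit broadcasting over self.dims. A rough stand-alone sketch of the same pattern (the function name is hypothetical; x is any Tensor, axes are axis indices as in copy_masked):

    import returnn.frontend as rf

    def masked_fill(x, axes, mask_value):
        """Hypothetical re-implementation of the core of copy_masked, following the loop above."""
        mask = True
        for axis in axes:
            mask_ = x.dims[axis].get_mask(dim_order=x.dims, device=x.device)
            # First iteration: opt_logical_and(True, mask_) returns mask_ unchanged.
            mask = rf.opt_logical_and(mask, mask_, allow_broadcast_all_sources=True, dim_order=x.dims)
        return rf.where(mask, x, mask_value)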
4 changes: 4 additions & 0 deletions tests/test_rf_base.py
@@ -4,6 +4,7 @@

 from __future__ import annotations
 from typing import Tuple
+import os
 from unittest import SkipTest
 import _setup_test_env # noqa
 import returnn.frontend as rf
@@ -493,6 +494,9 @@ def test_build_from_dict_func():


 def test_build_from_dict_func_native():
+    if os.environ.get("RETURNN_FRONTEND_NATIVE", "").strip() not in ("", "1"):
+        raise SkipTest("RETURNN_FRONTEND_NATIVE not enabled")
+
     from types import BuiltinFunctionType

     rf.select_backend_torch() # enables some of the native optimizations
63 changes: 63 additions & 0 deletions tests/test_rf_math.py
@@ -30,6 +30,69 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):
     run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)


+def test_eq():
+    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
+    in_dim = Dim(7, name="in")
+    extern_data = TensorDict(
+        {
+            "a": Tensor("a", [batch_dim, time_dim, in_dim], dtype="int32"),
+            "b": Tensor("b", [batch_dim, in_dim], dtype="int32"),
+        }
+    )
+
+    # noinspection PyShadowingNames
+    def _forward_step(*, extern_data: TensorDict, **_kwargs):
+        out = extern_data["a"] == extern_data["b"]
+        out.mark_as_default_output(shape=(batch_dim, time_dim, in_dim))
+
+    run_model(extern_data, lambda **_kwargs: rf.Module(), _forward_step)
+
+
+def test_neq():
+    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
+    in_dim = Dim(7, name="in")
+    extern_data = TensorDict(
+        {
+            "a": Tensor("a", [batch_dim, time_dim, in_dim], dtype="int32"),
+            "b": Tensor("b", [batch_dim, in_dim], dtype="int32"),
+        }
+    )
+
+    # noinspection PyShadowingNames
+    def _forward_step(*, extern_data: TensorDict, **_kwargs):
+        out = extern_data["a"] != extern_data["b"]
+        out.mark_as_default_output(shape=(batch_dim, time_dim, in_dim))
+
+    run_model(extern_data, lambda **_kwargs: rf.Module(), _forward_step)
+
+
+def test_neq_broadcast_exception():
+    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
+    other_time_dim = Dim(Tensor("other_time", [batch_dim], dtype="int32"))
+    extern_data = TensorDict(
+        {
+            "a": Tensor("a", [batch_dim, time_dim], dtype="int32"),
+            "b": Tensor("b", [batch_dim, other_time_dim], dtype="int32"),
+        }
+    )
+
+    # noinspection PyShadowingNames
+    def _forward_step(*, extern_data: TensorDict, **_kwargs):
+        a, b = extern_data["a"], extern_data["b"]
+        try:
+            _ = a != b
+        except Exception as e:
+            print("Got exception:", e)
+            assert "require explicit allow_broadcast_all_sources=True" in str(e) or "require broadcasting to" in str(
+                e
+            ), f"exception unexpected: {e}"
+        else:
+            raise Exception("Expected exception for invalid broadcasting")
+        a.mark_as_default_output(shape=(batch_dim, time_dim))
+
+    run_model(extern_data, lambda **_kwargs: rf.Module(), _forward_step)
+
+
 def test_squared_difference():
     time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
     in_dim = Dim(7, name="in")