Skip to content

Commit

Permalink
Fix gradfn.ReduceMean.Backward not respecting operator shape
Browse files Browse the repository at this point in the history
  • Loading branch information
marco-nicola committed Nov 4, 2023
1 parent e547f4b commit 93ae88e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Fixed

- `gradfn.ReduceMean.Backward` was not respecting the operand's shape, causing
incompatible-shape error when reduce-mean was applied to non-vector values.

## [1.1.0] - 2023-10-30

### Changed
Expand Down
2 changes: 1 addition & 1 deletion mat/gradfn/reducemean.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (r *ReduceMean[O]) Backward(gy mat.Tensor) error {
x := r.x.Value().(mat.Matrix)
size := x.Size()
v := gy.Item().F64() / float64(size)
gx := x.NewMatrix(mat.WithShape(size), mat.WithBacking(mat.CreateInitializedSlice(size, v)))
gx := x.NewMatrix(mat.WithShape(x.Shape()...), mat.WithBacking(mat.CreateInitializedSlice(size, v)))
r.x.AccGrad(gx)
}
return nil
Expand Down

0 comments on commit 93ae88e

Please sign in to comment.