Skip to content

Commit

Permalink
one less teeny line
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed Sep 23, 2023
1 parent 2da66f2 commit cdd850c
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions teenygrad/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
from teenygrad.ops import UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps
import numpy as np

def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)

class LazyBuffer:
device = "CPU"
dtype = dtypes.float32
Expand Down Expand Up @@ -49,8 +45,10 @@ def e(self, op, *srcs):
else: raise NotImplementedError(op)

def r(self, op, new_shape):
if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(shape_to_axis(self.shape, new_shape), keepdims=True))
elif op == ReduceOps.MAX: return LazyBuffer(self._np.max(shape_to_axis(self.shape, new_shape), keepdims=True))
assert len(self.shape) == len(new_shape), "reduce shapes must have same dimensions"
axis = tuple(i for i,(a,b) in enumerate(zip(self.shape, new_shape)) if a != b)
if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(axis, keepdims=True))
elif op == ReduceOps.MAX: return LazyBuffer(self._np.max(axis, keepdims=True))
else: raise NotImplementedError(op)

# MovementOps
Expand Down

0 comments on commit cdd850c

Please sign in to comment.