update to latest tinygrad #9

Open
wants to merge 4 commits into main
8 changes: 4 additions & 4 deletions .github/workflows/test.yml
@@ -36,13 +36,13 @@ jobs:
with:
python-version: 3.11
- name: Install deps with testing deps
run: pip install numpy tqdm mypy torch pytest tabulate --extra-index-url https://download.pytorch.org/whl/cpu
- name: Get code size
run: PYTHONPATH="." python sz.py
run: pip install numpy tqdm mypy torch pytest tabulate hypothesis --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test ops dtype optim
run: |
PYTHONPATH="." python test/test_ops.py
PYTHONPATH="." python test/test_dtype.py
PYTHONPATH="." python test/test_optim.py
- name: Check types with mypy
run: mypy
run: mypy
- name: Get code size
run: PYTHONPATH="." python sz.py
3 changes: 2 additions & 1 deletion .gitignore
@@ -1 +1,2 @@
*.pyc
*.pyc
.hypothesis
3 changes: 2 additions & 1 deletion import_from_tinygrad.py
@@ -1,13 +1,14 @@
#!/usr/bin/env python3
import pathlib

FILES = ["tensor.py", "mlops.py", "nn/optim.py", "../test/test_ops.py", "../test/test_dtype.py", "../test/test_optim.py"]
FILES = ["tensor.py", "mlops.py", "dtype.py", "nn/optim.py", "../test/test_ops.py", "../test/test_dtype.py", "../test/test_optim.py"]
src = pathlib.Path("../tinygrad/tinygrad")
dest = pathlib.Path("teenygrad")

for f in FILES:
print("importing", f)
rd = open(src/f).read()
rd = rd.replace("from tinygrad ", "from teenygrad ")
rd = rd.replace("from tinygrad.", "from teenygrad.")
rd = rd.replace("import tinygrad.", "import teenygrad.")
(dest/f).parent.mkdir(parents=True, exist_ok=True)
36 changes: 36 additions & 0 deletions ruff.toml
@@ -0,0 +1,36 @@
indent-width = 2

lint.select = [
"F",
"W6",
"E71",
"E72",
"E112", # no-indented-block
"E113", # unexpected-indentation
# "E124",
"E203", # whitespace-before-punctuation
"E272", # multiple-spaces-before-keyword
# "E303",
# "E304",
"E501", # line-too-long
# "E502",
"E702", # multiple-statements-on-one-line-semicolon
"E703", # useless-semicolon
"E731", # lambda-assignment
"W191", # tab-indentation
"W291", # trailing-whitespace
"W293", # blank-line-with-whitespace
"UP039", # unnecessary-class-parentheses
]

line-length = 150

exclude = [
"disassemblers/",
"docs/",
"examples/",
"extra/",
"openpilot/",
"tinygrad/runtime/autogen",
"test/external/mlperf_resnet",
]
4 changes: 3 additions & 1 deletion teenygrad/__init__.py
@@ -1 +1,3 @@
from teenygrad.tensor import Tensor # noqa: F401
from teenygrad.tensor import Tensor # noqa: F401
from teenygrad.dtype import dtypes # noqa: F401
from teenygrad.device import Device # noqa: F401
16 changes: 16 additions & 0 deletions teenygrad/device.py
@@ -0,0 +1,16 @@
from typing import Optional, Any
from teenygrad.dtype import DType
import numpy as np

class Device:
DEFAULT = "CPU"
_devices = ["CPU"]
@staticmethod
def canonicalize(device:Optional[str]) -> str: return "CPU"

class Buffer:
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options=None):
self.device, self.size, self.dtype = device, size, dtype
self._buf = opaque[1] if isinstance(opaque, tuple) else opaque
def copyin(self, buf): self._buf = np.frombuffer(buf, dtype=self.dtype.np)
def as_buffer(self): return np.require(self._buf, requirements=["C"]).data
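Note (not part of the diff): a minimal sketch of how the new Buffer round-trips host data, assuming teenygrad.device and teenygrad.dtype are importable as added in this PR.

import numpy as np
from teenygrad.device import Buffer
from teenygrad.dtype import dtypes

src = np.arange(4, dtype=np.float32)
buf = Buffer("CPU", src.size, dtypes.float32)            # opaque=None, so _buf starts empty
buf.copyin(src.tobytes())                                 # bytes -> numpy view using the buffer's dtype
out = np.frombuffer(buf.as_buffer(), dtype=np.float32)    # as_buffer returns a C-contiguous memoryview
assert (out == src).all()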
103 changes: 103 additions & 0 deletions teenygrad/dtype.py
@@ -0,0 +1,103 @@
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
from dataclasses import dataclass
import numpy as np # TODO: remove numpy
import functools

Scalar = Union[float, int, bool]

@dataclass(frozen=True, order=True)
class DType:
priority: int # this determines when things get upcasted
itemsize: int
name: str
fmt: Optional[str]
count: int
def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}"
def vec(self, sz:int):
assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
# TODO: someday this will be removed with the "remove numpy" project
@property
def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None

# dependent typing?
@dataclass(frozen=True, repr=False)
class ImageDType(DType):
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
base: DType
def scalar(self): return self.base
def vec(self, sz:int): return self.base.vec(sz)
def __repr__(self): return f"dtypes.{self.name}({self.shape})"

# @dataclass(frozen=True, init=False, repr=False, eq=False)
class PtrDType(DType):
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
def __repr__(self): return f"ptr.{super().__repr__()}"
def __hash__(self): return super().__hash__()
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
def __ne__(self, dt): return not (self == dt)

def cast_scalar(scalar: Scalar, dtype:DType):
return int(scalar) if dtypes.is_int(dtype) else float(scalar) if dtypes.is_float(dtype) else bool(scalar)

class dtypes:
@staticmethod
def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
@staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
@staticmethod
def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod # NOTE: isinstance(True, int) is True in python
def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
@staticmethod
def fields() -> Dict[str, DType]: return DTYPES_DICT
bool: Final[DType] = DType(0, 1, "bool", '?', 1)
int8: Final[DType] = DType(1, 1, "char", 'b', 1)
uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
int16: Final[DType] = DType(3, 2, "short", 'h', 1)
uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
int32: Final[DType] = DType(5, 4, "int", 'i', 1)
uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
int64: Final[DType] = DType(7, 8, "long", 'l', 1)
uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1)
float16: Final[DType] = DType(9, 2, "half", 'e', 1)
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
float32: Final[DType] = DType(11, 4, "float", 'f', 1)
float64: Final[DType] = DType(12, 8, "double", 'd', 1)

# dtype aliases
half = float16; float = float32; double = float64 # noqa: E702
uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64 # noqa: E702
char = int8; short = int16; int = int32; long = int64 # noqa: E702

# NOTE: these are image dtypes
@staticmethod
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32)
@staticmethod
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32)

default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32

# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
# we don't support weak type and complex type
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }

@functools.lru_cache(None)
def _get_recursive_parents(dtype:DType) -> Set[DType]:
return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
@functools.lru_cache(None)
def least_upper_dtype(*ds:DType) -> DType:
return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)

# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)}
INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
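Note (not part of the diff): a quick illustration of the JAX-style promotion lattice defined above, shown as a few asserts; a sketch assuming the new module is importable as teenygrad.dtype.

from teenygrad.dtype import dtypes, least_upper_dtype, least_upper_float

assert least_upper_dtype(dtypes.uint8, dtypes.int8) == dtypes.int16      # first type holding both signed and unsigned 8-bit
assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float16  # per the comment above, float16 wins over bfloat16
assert least_upper_float(dtypes.int32) == dtypes.float32                 # non-floats are lifted to at least float32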
13 changes: 13 additions & 0 deletions teenygrad/features/multi.py
@@ -0,0 +1,13 @@
# NOTE: this abstraction is wrong in tensor.py so we have to stub this
from typing import Tuple
from teenygrad.dtype import DType
class MultiLazyBuffer:
device: Tuple[str, ...]
dtype: DType
shape: Tuple[int, ...]
def __init__(self, lbs, axis, real=None):
self.lbs, self.axis = lbs, axis
raise NotImplementedError("no multibuffer support")
@staticmethod
def from_sharded(lb, devices, axis=None): raise NotImplementedError("no multibuffer support")
def copy_to_device(self, device): pass
62 changes: 13 additions & 49 deletions teenygrad/helpers.py
@@ -1,61 +1,25 @@
from typing import Union, Tuple, Iterator, Optional, Final, Any
import os, functools, platform
import numpy as np
from math import prod # noqa: F401 # pylint:disable=unused-import
from dataclasses import dataclass
from typing import Union, Tuple, Sequence, Any, Iterable, Dict, TypeVar
import os, functools, platform, operator

T = TypeVar("T")
U = TypeVar("U")
OSX = platform.system() == "Darwin"
def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)
def dedup(x): return list(dict.fromkeys(x)) # retains list order
def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_int(t: Tuple[Any, ...]) -> bool: return all(isinstance(s, int) for s in t)
def all_int(t: Sequence[Any]) -> bool: return all(isinstance(s, int) for s in t)
def round_up(num, amt:int): return (num+amt-1)//amt * amt
def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
return {k:v for d in ds for k,v in d.items()}
def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))

@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))

DEBUG = getenv("DEBUG")
DEBUG, WINO, IMAGE = getenv("DEBUG"), getenv("WINO"), 0
CI = os.getenv("CI", "") != ""

@dataclass(frozen=True, order=True)
class DType:
priority: int # this determines when things get upcasted
itemsize: int
name: str
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
sz: int = 1
def __repr__(self): return f"dtypes.{self.name}"

class dtypes:
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64)
@staticmethod
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
float16: Final[DType] = DType(9, 2, "half", np.float16)
half = float16
float32: Final[DType] = DType(10, 4, "float", np.float32)
float = float32
float64: Final[DType] = DType(11, 8, "double", np.float64)
double = float64
int8: Final[DType] = DType(1, 1, "char", np.int8)
int16: Final[DType] = DType(3, 2, "short", np.int16)
int32: Final[DType] = DType(5, 4, "int", np.int32)
int64: Final[DType] = DType(7, 8, "long", np.int64)
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)

# NOTE: bfloat16 isn't supported in numpy
bfloat16: Final[DType] = DType(9, 2, "__bf16", None)

DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}

PtrDType, ImageDType, IMAGE = None, None, 0 # junk to remove
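Note (not part of the diff): quick checks of the reworked helpers, assuming they land in teenygrad.helpers as shown in the hunk above.

from teenygrad.helpers import prod, dedup, argsort, merge_dicts

assert prod([2, 3, 4]) == 24
assert dedup([1, 1, 2, 1]) == [1, 2]                          # keeps first-seen order
assert argsort((3, 1, 2)) == (1, 2, 0)                        # returns the same container type as its input
assert merge_dicts([{"a": 1}, {"b": 2}]) == {"a": 1, "b": 2}  # asserts if a key maps to conflicting values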
33 changes: 19 additions & 14 deletions teenygrad/lazy.py
@@ -1,5 +1,7 @@
from __future__ import annotations
from teenygrad.helpers import DType, dtypes, DEBUG
from teenygrad.helpers import DEBUG, prod
from teenygrad.dtype import DType, dtypes, least_upper_dtype
from teenygrad.device import Buffer
from teenygrad.ops import UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps
import numpy as np

@@ -9,31 +11,33 @@ def toCPU(self): return self.x

class LazyBuffer:
device = "CPU"

def __init__(self, buf: np.ndarray): self._np = buf
def __init__(self, buf: np.ndarray): self.realized = Buffer("CPU", buf.size, dtypes.from_np(buf.dtype), buf)

@property
def base(self): return self
@property
def dtype(self): return dtypes.from_np(self._np.dtype)
def dtype(self): return self.realized.dtype
@property
def realized(self): return RawCPUBuffer(self._np)
def _np(self):
if self.realized._buf is None: return np.array([], dtype=self.realized.dtype.np).reshape((0,))
return self.realized._buf
@property
def shape(self): return self._np.shape
def __repr__(self): return f"<LB {self.shape} {self.dtype}>"

def schedule(self, seen=None): return []
def is_unrealized_contiguous_const(self): return False
def copy_to_device(self, device:str) -> LazyBuffer: return self

@staticmethod
def fromCPU(x): return LazyBuffer(x)

@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
if op == LoadOps.RAND: return LazyBuffer(np.random.default_rng(arg).random(size=shape, dtype=dtype.np))
def loadop(op, shape, dtype, device, arg=None, src=None, _buf=None) -> LazyBuffer:
if op == LoadOps.CUSTOM:
arg(ret := Buffer(device, prod(shape), dtype))
return ret._buf.reshape(shape)
elif op == LoadOps.CONST: return LazyBuffer(np.full(shape, arg, dtype=dtype.np))
elif op == LoadOps.EMPTY: return LazyBuffer(np.empty(shape, dtype=dtype.np))
elif op == LoadOps.EMPTY: return LazyBuffer(_buf._buf if device == "EXT" and prod(shape) != 0 else np.empty(shape, dtype=dtype.np))
else: raise NotImplementedError(op)

def contiguous(x): return x
Expand All @@ -52,16 +56,17 @@ def e(self, op, *srcs:LazyBuffer):
elif op == BinaryOps.SUB: ret = self._np - srcs[0]._np
elif op == BinaryOps.MUL: ret = self._np * srcs[0]._np
elif op == BinaryOps.DIV: ret = self._np / srcs[0]._np
elif op == BinaryOps.XOR: ret = self._np ^ srcs[0]._np
elif op == BinaryOps.MAX: ret = np.maximum(self._np, srcs[0]._np)
elif op == BinaryOps.CMPLT: ret = self._np < srcs[0]._np
elif op == BinaryOps.CMPEQ: ret = self._np == srcs[0]._np
elif op == TernaryOps.WHERE: ret = np.where(self._np, srcs[0]._np, srcs[1]._np)
else: raise NotImplementedError(op)
return LazyBuffer(ret.astype(self.dtype.np if len(srcs) == 0 else max(self.dtype, *[x.dtype for x in srcs]).np, copy=False))
new_type = least_upper_dtype(self.dtype, *[x.dtype for x in srcs]) if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
return LazyBuffer(ret.astype(new_type.np, copy=False))

def r(self, op, new_shape):
if DEBUG >= 1: print(op, self, new_shape)
assert len(self.shape) == len(new_shape), "reduce shapes must have same dimensions"
axis = tuple(i for i,(a,b) in enumerate(zip(self.shape, new_shape)) if a != b)
def r(self, op, axis):
if DEBUG >= 1: print(op, self, axis)
if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(axis, dtype=self._np.dtype, keepdims=True))
elif op == ReduceOps.MAX: return LazyBuffer(self._np.max(axis, keepdims=True))
else: raise NotImplementedError(op)
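Note (not part of the diff): r() now takes an axis tuple instead of a target shape. A minimal sketch, assuming teenygrad.lazy and teenygrad.ops expose LazyBuffer and ReduceOps as imported at the top of this file.

import numpy as np
from teenygrad.lazy import LazyBuffer
from teenygrad.ops import ReduceOps

lb = LazyBuffer.fromCPU(np.ones((2, 3), dtype=np.float32))
summed = lb.r(ReduceOps.SUM, (1,))   # reduce over axis 1; keepdims=True, so the result stays 2-D
assert summed.shape == (2, 1)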