diff --git a/.gitignore b/.gitignore index f104e9ca..07dda88e 100644 --- a/.gitignore +++ b/.gitignore @@ -15,8 +15,6 @@ *.dylib *.pyc build/ -build_*/ -build-*/ install/ install_*/ install-*/ @@ -70,4 +68,3 @@ jit/codon/version.py temp/ playground/ scratch*.* -_* diff --git a/codon/cir/llvm/llvisitor.cpp b/codon/cir/llvm/llvisitor.cpp index 99b543d6..40703aeb 100644 --- a/codon/cir/llvm/llvisitor.cpp +++ b/codon/cir/llvm/llvisitor.cpp @@ -29,6 +29,10 @@ const std::string GPU_KERNEL_ATTR = "std.gpu.kernel"; const std::string MAIN_UNCLASH = ".main.unclash"; const std::string MAIN_CTOR = ".main.ctor"; + +llvm::cl::opt DisableExceptions("disable-exceptions", + llvm::cl::desc("Disable exception handling"), + llvm::cl::init(false)); } // namespace llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) { @@ -1696,6 +1700,7 @@ void LLVMVisitor::visit(const ExternalFunc *x) { coro = {}; seqassertn(func, "{} not inserted", *x); func->setDoesNotThrow(); + func->setWillReturn(); } namespace { @@ -1939,7 +1944,9 @@ void LLVMVisitor::visit(const BodiedFunc *x) { func->addFnAttr(llvm::Attribute::get(*context, "kernel")); func->setLinkage(llvm::GlobalValue::ExternalLinkage); } - func->setPersonalityFn(llvm::cast(makePersonalityFunc().getCallee())); + if (!DisableExceptions) + func->setPersonalityFn( + llvm::cast(makePersonalityFunc().getCallee())); auto *funcType = cast(x->getType()); seqassertn(funcType, "{} is not a function type", *x->getType()); @@ -3362,6 +3369,13 @@ void LLVMVisitor::visit(const YieldInstr *x) { } void LLVMVisitor::visit(const ThrowInstr *x) { + if (DisableExceptions) { + B->SetInsertPoint(block); + B->CreateUnreachable(); + block = llvm::BasicBlock::Create(*context, "throw_unreachable.new", func); + return; + } + // note: exception header should be set in the frontend auto excAllocFunc = makeExcAllocFunc(); auto throwFunc = makeThrowFunc(); diff --git a/codon/parser/cache.cpp b/codon/parser/cache.cpp index 65006779..10c49f46 100644 --- a/codon/parser/cache.cpp +++ b/codon/parser/cache.cpp @@ -38,7 +38,7 @@ std::string Cache::rev(const std::string &s) { void Cache::addGlobal(const std::string &name, ir::Var *var) { if (!in(globals, name)) { // LOG("[global] {}", name); - globals[name] = var; + globals[name] = {false, var}; } } diff --git a/codon/parser/cache.h b/codon/parser/cache.h index f9bdbbf5..f6f47e59 100644 --- a/codon/parser/cache.h +++ b/codon/parser/cache.h @@ -95,7 +95,7 @@ struct Cache : public std::enable_shared_from_this { /// Set of unique (canonical) global identifiers for marking such variables as global /// in code-generation step and in JIT. - std::map globals; + std::map> globals; /// Stores class data for each class (type) in the source code. struct Class { diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index be8d8073..3102f3e3 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -47,12 +47,12 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) { cache->codegenCtx->bases = {main}; cache->codegenCtx->series = {block}; - for (auto &g : cache->globals) - if (!g.second) { - g.second = g.first == VAR_ARGV ? cache->codegenCtx->getModule()->getArgVar() - : cache->codegenCtx->getModule()->N( - SrcInfo(), nullptr, true, false, g.first); - cache->codegenCtx->add(TranslateItem::Var, g.first, g.second); + for (auto &[name, p] : cache->globals) + if (p.first && !p.second) { + p.second = name == VAR_ARGV ? cache->codegenCtx->getModule()->getArgVar() + : cache->codegenCtx->getModule()->N( + SrcInfo(), nullptr, true, false, name); + cache->codegenCtx->add(TranslateItem::Var, name, p.second); } auto tv = TranslateVisitor(cache->codegenCtx); diff --git a/codon/parser/visitors/typecheck/assign.cpp b/codon/parser/visitors/typecheck/assign.cpp index 010a4346..ed76be73 100644 --- a/codon/parser/visitors/typecheck/assign.cpp +++ b/codon/parser/visitors/typecheck/assign.cpp @@ -72,6 +72,8 @@ void TypecheckVisitor::visit(AssignStmt *stmt) { ctx->instantiate(stmt->type->getSrcInfo(), stmt->type->getType())); } ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type); + if (in(ctx->cache->globals, lhs)) + ctx->cache->globals[lhs].first = true; if (realize(stmt->lhs->type) || !stmt->type) stmt->setDone(); } else if (stmt->type && stmt->type->getType()->isStaticType()) { @@ -84,6 +86,7 @@ void TypecheckVisitor::visit(AssignStmt *stmt) { auto val = ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type); if (in(ctx->cache->globals, lhs)) { // Make globals always visible! + ctx->cache->globals[lhs].first = true; ctx->addToplevel(lhs, val); } if (realize(stmt->lhs->type)) @@ -112,6 +115,7 @@ void TypecheckVisitor::visit(AssignStmt *stmt) { if (in(ctx->cache->globals, lhs)) { // Make globals always visible! + ctx->cache->globals[lhs].first = true; ctx->addToplevel(lhs, val); if (kind != TypecheckItem::Var) ctx->cache->globals.erase(lhs); diff --git a/codon/runtime/lib.h b/codon/runtime/lib.h index 3f653c32..7192bf57 100644 --- a/codon/runtime/lib.h +++ b/codon/runtime/lib.h @@ -130,5 +130,6 @@ std::string makeBacktraceFrameString(uintptr_t pc, const std::string &func = "", std::string getCapturedOutput(); void setJITErrorCallback(std::function callback); + } // namespace runtime } // namespace codon diff --git a/jit/codon/__init__.py b/jit/codon/__init__.py index 458b4678..bf510968 100644 --- a/jit/codon/__init__.py +++ b/jit/codon/__init__.py @@ -2,4 +2,4 @@ __all__ = ["jit", "convert", "JITError"] -from .decorator import jit, convert, JITError +from .decorator import jit, convert, execute, JITError diff --git a/jit/codon/decorator.py b/jit/codon/decorator.py index 54209dcf..dcff9100 100644 --- a/jit/codon/decorator.py +++ b/jit/codon/decorator.py @@ -8,6 +8,7 @@ import functools import itertools import ast +import textwrap import astunparse from pathlib import Path @@ -125,10 +126,13 @@ def _obj_to_str(obj, **kwargs) -> str: lines = inspect.getsourcelines(obj)[0] extra_spaces = lines[0].find("class") obj_str = "".join(l[extra_spaces:] for l in lines) - elif callable(obj): - lines = inspect.getsourcelines(obj)[0] - extra_spaces = lines[0].find("@") - obj_str = "".join(l[extra_spaces:] for l in lines[1:]) + obj_name = obj.__name__ + elif callable(obj) or isinstance(obj, str): + is_str = isinstance(obj, str) + lines = [i + '\n' for i in obj.split('\n')] if is_str else inspect.getsourcelines(obj)[0] + if not is_str: lines = lines[1:] + obj_str = textwrap.dedent(''.join(lines)) + pyvars = kwargs.get("pyvars", None) if pyvars: for i in pyvars: @@ -138,20 +142,20 @@ def _obj_to_str(obj, **kwargs) -> str: RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str)) ) obj_str = astunparse.unparse(node) + if is_str: + try: + obj_name = ast.parse(obj_str).body[0].name + except: + raise ValueError("cannot infer function name!") + else: + obj_name = obj.__name__ else: raise TypeError("Function or class expected, got " + type(obj).__name__) - return obj_str.replace("_@par", "@par") - - -def _obj_name(obj) -> str: - if inspect.isclass(obj) or callable(obj): - return obj.__name__ - else: - raise TypeError("Function or class expected, got " + type(obj).__name__) + return obj_name, obj_str.replace("_@par", "@par") def _parse_decorated(obj, **kwargs): - return _obj_name(obj), _obj_to_str(obj, **kwargs) + return _obj_to_str(obj, **kwargs) def convert(t): @@ -186,44 +190,59 @@ def convert(t): return t +def _jit_register_fn(f, pyvars, debug): + try: + obj_name, obj_str = _parse_decorated(f, pyvars=pyvars) + fn, fl = "", 1 + if hasattr(f, "__code__"): + fn, fl = f.__code__.co_filename, f.__code__.co_firstlineno + _jit.execute(obj_str, fn, fl, 1 if debug else 0) + return obj_name + except JITError: + _reset_jit() + raise + +def _jit_callback_fn(obj_name, module, debug=None, sample_size=5, pyvars=None, *args, **kwargs): + try: + args = (*args, *kwargs.values()) + types = _codon_types(args, debug=debug, sample_size=sample_size) + if debug: + print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr) + return _jit.run_wrapper( + obj_name, list(types), module, list(pyvars), args, 1 if debug else 0 + ) + except JITError: + _reset_jit() + raise + +def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None): + obj_name = _jit_register_fn(fstr, pyvars, debug) + def wrapped(*args, **kwargs): + return _jit_callback_fn(obj_name, "__main__", debug, sample_size, pyvars, *args, **kwargs) + return wrapped + + def jit(fn=None, debug=None, sample_size=5, pyvars=None): if not pyvars: pyvars = [] if not isinstance(pyvars, list): raise ArgumentError("pyvars must be a list") - def _decorate(f): - try: - obj_name, obj_str = _parse_decorated(f, pyvars=pyvars) - _jit.execute( - obj_str, - f.__code__.co_filename, - f.__code__.co_firstlineno, - 1 if debug else 0, - ) - except JITError: - _reset_jit() - raise + if fn and isinstance(fn, str): + return _jit_str_fn(fn, debug, sample_size, pyvars) + def _decorate(f): + obj_name = _jit_register_fn(f, pyvars, debug) @functools.wraps(f) def wrapped(*args, **kwargs): - try: - args = (*args, *kwargs.values()) - types = _codon_types(args, debug=debug, sample_size=sample_size) - if debug: - print( - "[python] {}({})".format(f.__name__, list(types)), - file=sys.stderr, - ) - return _jit.run_wrapper( - obj_name, list(types), f.__module__, list(pyvars), args, 1 if debug else 0 - ) - except JITError: - _reset_jit() - raise - + return _jit_callback_fn(obj_name, f.__module__, debug, sample_size, pyvars, *args, **kwargs) return wrapped + return _decorate(fn) if fn else _decorate + - if fn: - return _decorate(fn) - return _decorate +def execute(code, debug=False): + try: + _jit.execute(code, "", 0, int(debug)) + except JITError: + _reset_jit() + raise diff --git a/scripts/deps.sh b/scripts/deps.sh index a32ca9fc..25c212b3 100755 --- a/scripts/deps.sh +++ b/scripts/deps.sh @@ -21,6 +21,7 @@ if [ ! -f "${INSTALLDIR}/bin/llvm-config" ]; then -DLLVM_INCLUDE_TESTS=OFF \ -DLLVM_ENABLE_RTTI=ON \ -DLLVM_ENABLE_ZLIB=OFF \ + -DLLVM_ENABLE_ZSTD=OFF \ -DLLVM_ENABLE_TERMINFO=OFF \ -DLLVM_TARGETS_TO_BUILD=all \ -DCMAKE_INSTALL_PREFIX="${INSTALLDIR}" diff --git a/stdlib/internal/types/complex.codon b/stdlib/internal/types/complex.codon index e73621d1..30617f9a 100644 --- a/stdlib/internal/types/complex.codon +++ b/stdlib/internal/types/complex.codon @@ -49,40 +49,73 @@ class complex: def __hash__(self) -> int: return self.real.__hash__() + self.imag.__hash__() * 1000003 - def __add__(self, other) -> complex: + def __add__(self, other: float) -> complex: return self + complex(other) - def __sub__(self, other) -> complex: + def __sub__(self, other: float) -> complex: return self - complex(other) - def __mul__(self, other) -> complex: + def __mul__(self, other: float) -> complex: return self * complex(other) - def __truediv__(self, other) -> complex: + def __truediv__(self, other: float) -> complex: return self / complex(other) - def __eq__(self, other) -> bool: + def __eq__(self, other: float) -> bool: return self == complex(other) - def __ne__(self, other) -> bool: + def __ne__(self, other: float) -> bool: return self != complex(other) - def __pow__(self, other) -> complex: + def __pow__(self, other: float) -> complex: return self ** complex(other) - def __radd__(self, other) -> complex: + def __add__(self, other: int) -> complex: + return self + complex(other) + + def __sub__(self, other: int) -> complex: + return self - complex(other) + + def __mul__(self, other: int) -> complex: + return self * complex(other) + + def __truediv__(self, other: int) -> complex: + return self / complex(other) + + def __eq__(self, other: int) -> bool: + return self == complex(other) + + def __ne__(self, other: int) -> bool: + return self != complex(other) + + def __radd__(self, other: float) -> complex: return complex(other) + self - def __rsub__(self, other) -> complex: + def __rsub__(self, other: float) -> complex: return complex(other) - self - def __rmul__(self, other) -> complex: + def __rmul__(self, other: float) -> complex: return complex(other) * self - def __rtruediv__(self, other) -> complex: + def __rtruediv__(self, other: float) -> complex: return complex(other) / self - def __rpow__(self, other) -> complex: + def __rpow__(self, other: float) -> complex: + return complex(other) ** self + + def __radd__(self, other: int) -> complex: + return complex(other) + self + + def __rsub__(self, other: int) -> complex: + return complex(other) - self + + def __rmul__(self, other: int) -> complex: + return complex(other) * self + + def __rtruediv__(self, other: int) -> complex: + return complex(other) / self + + def __rpow__(self, other: int) -> complex: return complex(other) ** self def __add__(self, other: complex) -> complex: @@ -194,7 +227,23 @@ class complex: %y = call double @llvm.cos.f64(double %x) ret double %y - if other.real == 0.0 and other.imag == 0.0: + @pure + @llvm + def floor(x: float) -> float: + declare double @llvm.floor.f64(double) + %y = call double @llvm.floor.f64(double %x) + ret double %y + + @pure + @llvm + def fabs(x: float) -> float: + declare double @llvm.fabs.f64(double) + %y = call double @llvm.fabs.f64(double %x) + ret double %y + + if other.imag == 0.0 and other.real == floor(other.real) and fabs(other.real) <= 100.0: + return self ** int(other.real) + elif other.real == 0.0 and other.imag == 0.0: return complex(1.0, 0.0) elif self.real == 0.0 and self.imag == 0.0: # if other.imag != 0. or other.real < 0.: errno = EDOM @@ -342,42 +391,150 @@ class complex64: def __hash__(self) -> int: return self.real.__hash__() + self.imag.__hash__() * 1000003 - def __add__(self, other) -> complex64: + def __add__(self, other: complex) -> complex: + return complex(self) + other + + def __sub__(self, other: complex) -> complex: + return complex(self) - other + + def __mul__(self, other: complex) -> complex: + return complex(self) * other + + def __truediv__(self, other: complex) -> complex: + return complex(self) / other + + def __pow__(self, other: complex) -> complex: + return complex(self) ** other + + def __eq__(self, other: complex) -> bool: + return complex(self) == other + + def __ne__(self, other: complex) -> bool: + return complex(self) != other + + def __radd__(self, other: complex) -> complex: + return other + complex(self) + + def __rsub__(self, other: complex) -> complex: + return other - complex(self) + + def __rmul__(self, other: complex) -> complex: + return other * complex(self) + + def __rtruediv__(self, other: complex) -> complex: + return other / complex(self) + + def __rpow__(self, other: complex) -> complex: + return other ** complex(self) + + def __add__(self, other: float32) -> complex64: return self + complex64(other) - def __sub__(self, other) -> complex64: + def __sub__(self, other: float32) -> complex64: return self - complex64(other) - def __mul__(self, other) -> complex64: + def __mul__(self, other: float32) -> complex64: return self * complex64(other) - def __truediv__(self, other) -> complex64: + def __truediv__(self, other: float32) -> complex64: return self / complex64(other) - def __eq__(self, other) -> bool: + def __pow__(self, other: float32) -> complex64: + return self ** complex64(other) + + def __eq__(self, other: float32) -> bool: return self == complex64(other) - def __ne__(self, other) -> bool: + def __ne__(self, other: float32) -> bool: return self != complex64(other) - def __pow__(self, other) -> complex64: - return self ** complex64(other) - - def __radd__(self, other) -> complex64: + def __radd__(self, other: float32) -> complex64: return complex64(other) + self - def __rsub__(self, other) -> complex64: + def __rsub__(self, other: float32) -> complex64: return complex64(other) - self - def __rmul__(self, other) -> complex64: + def __rmul__(self, other: float32) -> complex64: return complex64(other) * self - def __rtruediv__(self, other) -> complex64: + def __rtruediv__(self, other: float32) -> complex64: return complex64(other) / self - def __rpow__(self, other) -> complex64: + def __rpow__(self, other: float32) -> complex64: return complex64(other) ** self + def __add__(self, other: float) -> complex: + return complex(self) + other + + def __sub__(self, other: float) -> complex: + return complex(self) - other + + def __mul__(self, other: float) -> complex: + return complex(self) * other + + def __truediv__(self, other: float) -> complex: + return complex(self) / other + + def __pow__(self, other: float) -> complex: + return complex(self) ** other + + def __eq__(self, other: float) -> bool: + return complex(self) == other + + def __ne__(self, other: float) -> bool: + return complex(self) != other + + def __radd__(self, other: float) -> complex: + return other + complex(self) + + def __rsub__(self, other: float) -> complex: + return other - complex(self) + + def __rmul__(self, other: float) -> complex: + return other * complex(self) + + def __rtruediv__(self, other: float) -> complex: + return other / complex(self) + + def __rpow__(self, other: float) -> complex: + return other ** complex(self) + + def __add__(self, other: int) -> complex: + return complex(self) + other + + def __sub__(self, other: int) -> complex: + return complex(self) - other + + def __mul__(self, other: int) -> complex: + return complex(self) * other + + def __truediv__(self, other: int) -> complex: + return complex(self) / other + + # def __pow__(self, other: int) -> complex: + # return complex(self) ** other + + def __eq__(self, other: int) -> bool: + return complex(self) == other + + def __ne__(self, other: int) -> bool: + return complex(self) != other + + def __radd__(self, other: int) -> complex: + return other + complex(self) + + def __rsub__(self, other: int) -> complex: + return other - complex(self) + + def __rmul__(self, other: int) -> complex: + return other * complex(self) + + def __rtruediv__(self, other: int) -> complex: + return other / complex(self) + + def __rpow__(self, other: int) -> complex: + return other ** complex(self) + def __add__(self, other: complex64) -> complex64: return complex64(self.real + other.real, self.imag + other.imag) @@ -487,7 +644,23 @@ class complex64: %y = call float @llvm.cos.f32(float %x) ret float %y - if other.real == f32(0.0) and other.imag == f32(0.0): + @pure + @llvm + def floor(x: f32) -> f32: + declare float @llvm.floor.f32(float) + %y = call float @llvm.floor.f32(float %x) + ret float %y + + @pure + @llvm + def fabs(x: f32) -> f32: + declare float @llvm.fabs.f64(float) + %y = call float @llvm.fabs.f64(float %x) + ret float %y + + if other.imag == f32(0.0) and other.real == floor(other.real) and fabs(other.real) <= f32(100.0): + return self ** int(other.real) + elif other.real == f32(0.0) and other.imag == f32(0.0): return complex64(1.0, 0.0) elif self.real == f32(0.0) and self.imag == f32(0.0): # if other.imag != 0. or other.real < 0.: errno = EDOM diff --git a/stdlib/time.codon b/stdlib/time.codon index 8ed0a713..b5dcc5f2 100644 --- a/stdlib/time.codon +++ b/stdlib/time.codon @@ -1,7 +1,5 @@ # Copyright (C) 2022-2024 Exaloop Inc. -from sys import stderr - def time() -> float: return _C.seq_time() / 1e9 @@ -41,6 +39,7 @@ class TimeInterval: self.start = _C.seq_time() def __exit__(self): + from sys import stderr print(self.report(self.msg), file=stderr) def report(self, msg="", memory=False) -> str: diff --git a/test/stdlib/cmath_test.codon b/test/stdlib/cmath_test.codon index d2e8926e..f397919b 100644 --- a/test/stdlib/cmath_test.codon +++ b/test/stdlib/cmath_test.codon @@ -815,27 +815,37 @@ test_complex_bool() def test_complex64(): c64 = complex64 z = c64(.5 + .5j) - assert c64() == z * 0 - assert z + 1 == c64(1.5, .5) + assert c64() == z * float32(0.) + assert z + float32(1) == c64(1.5, .5) assert bool(z) == True assert bool(0 * z) == False assert +z == z assert -z == c64(-.5 - .5j) assert abs(z) == float32(0.7071067811865476) - assert z + 1 == c64(1.5 + .5j) - assert 1j + z == c64(.5 + 1.5j) - assert z * 2 == c64(1 + 1j) - assert 2j * z == c64(-1 + 1j) - assert z / .5 == c64(1 + 1j) - assert 1j / z == c64(1 + 1j) - assert z ** 2 == c64(.5j) - y = 1j ** z - assert math.isclose(float(y.real), 0.32239694194483454) - assert math.isclose(float(y.imag), 0.32239694194483454) + + assert z + c64(1) == c64(1.5 + .5j) + assert c64(1j) + z == c64(.5 + 1.5j) + assert z * c64(2) == c64(1 + 1j) + assert c64(2j) * z == c64(-1 + 1j) + assert z / c64(.5) == c64(1 + 1j) + assert c64(1j) / z == c64(1 + 1j) + assert z ** c64(2) == c64(.5j) + + assert z + 1 == 1.5 + .5j + assert 1j + z == .5 + 1.5j + assert z * 2 == 1 + 1j + assert 2j * z == -1 + 1j + assert z / .5 == 1 + 1j + assert 1j / z == 1 + 1j + assert z ** 2 == .5j + + y = c64(1j) ** z + assert math.isclose(float(y.real), 0.32239694194483454, rel_tol=1e-7) + assert math.isclose(float(y.imag), 0.32239694194483454, rel_tol=1e-7) assert z != -z assert z != 0 assert z.real == float32(.5) - assert (z + 1j).imag == float32(1.5) + assert (z + c64(1j)).imag == float32(1.5) assert z.conjugate() == c64(.5 - .5j) assert z.__copy__() == z assert hash(z)