Skip to content

Commit

Permalink
Updates from enterprise version (#571)
Browse files Browse the repository at this point in the history
  • Loading branch information
arshajii authored Jul 30, 2024
1 parent c750ae6 commit 11d281d
Show file tree
Hide file tree
Showing 13 changed files with 317 additions and 99 deletions.
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
*.dylib
*.pyc
build/
build_*/
build-*/
install/
install_*/
install-*/
Expand Down Expand Up @@ -70,4 +68,3 @@ jit/codon/version.py
temp/
playground/
scratch*.*
_*
16 changes: 15 additions & 1 deletion codon/cir/llvm/llvisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ const std::string GPU_KERNEL_ATTR = "std.gpu.kernel";

const std::string MAIN_UNCLASH = ".main.unclash";
const std::string MAIN_CTOR = ".main.ctor";

llvm::cl::opt<bool> DisableExceptions("disable-exceptions",
llvm::cl::desc("Disable exception handling"),
llvm::cl::init(false));
} // namespace

llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
Expand Down Expand Up @@ -1696,6 +1700,7 @@ void LLVMVisitor::visit(const ExternalFunc *x) {
coro = {};
seqassertn(func, "{} not inserted", *x);
func->setDoesNotThrow();
func->setWillReturn();
}

namespace {
Expand Down Expand Up @@ -1939,7 +1944,9 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
func->addFnAttr(llvm::Attribute::get(*context, "kernel"));
func->setLinkage(llvm::GlobalValue::ExternalLinkage);
}
func->setPersonalityFn(llvm::cast<llvm::Constant>(makePersonalityFunc().getCallee()));
if (!DisableExceptions)
func->setPersonalityFn(
llvm::cast<llvm::Constant>(makePersonalityFunc().getCallee()));

auto *funcType = cast<types::FuncType>(x->getType());
seqassertn(funcType, "{} is not a function type", *x->getType());
Expand Down Expand Up @@ -3362,6 +3369,13 @@ void LLVMVisitor::visit(const YieldInstr *x) {
}

void LLVMVisitor::visit(const ThrowInstr *x) {
if (DisableExceptions) {
B->SetInsertPoint(block);
B->CreateUnreachable();
block = llvm::BasicBlock::Create(*context, "throw_unreachable.new", func);
return;
}

// note: exception header should be set in the frontend
auto excAllocFunc = makeExcAllocFunc();
auto throwFunc = makeThrowFunc();
Expand Down
2 changes: 1 addition & 1 deletion codon/parser/cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ std::string Cache::rev(const std::string &s) {
void Cache::addGlobal(const std::string &name, ir::Var *var) {
if (!in(globals, name)) {
// LOG("[global] {}", name);
globals[name] = var;
globals[name] = {false, var};
}
}

Expand Down
2 changes: 1 addition & 1 deletion codon/parser/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct Cache : public std::enable_shared_from_this<Cache> {

/// Set of unique (canonical) global identifiers for marking such variables as global
/// in code-generation step and in JIT.
std::map<std::string, ir::Var *> globals;
std::map<std::string, std::pair<bool, ir::Var *>> globals;

/// Stores class data for each class (type) in the source code.
struct Class {
Expand Down
12 changes: 6 additions & 6 deletions codon/parser/visitors/translate/translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ ir::Func *TranslateVisitor::apply(Cache *cache, const StmtPtr &stmts) {
cache->codegenCtx->bases = {main};
cache->codegenCtx->series = {block};

for (auto &g : cache->globals)
if (!g.second) {
g.second = g.first == VAR_ARGV ? cache->codegenCtx->getModule()->getArgVar()
: cache->codegenCtx->getModule()->N<ir::Var>(
SrcInfo(), nullptr, true, false, g.first);
cache->codegenCtx->add(TranslateItem::Var, g.first, g.second);
for (auto &[name, p] : cache->globals)
if (p.first && !p.second) {
p.second = name == VAR_ARGV ? cache->codegenCtx->getModule()->getArgVar()
: cache->codegenCtx->getModule()->N<ir::Var>(
SrcInfo(), nullptr, true, false, name);
cache->codegenCtx->add(TranslateItem::Var, name, p.second);
}

auto tv = TranslateVisitor(cache->codegenCtx);
Expand Down
4 changes: 4 additions & 0 deletions codon/parser/visitors/typecheck/assign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ void TypecheckVisitor::visit(AssignStmt *stmt) {
ctx->instantiate(stmt->type->getSrcInfo(), stmt->type->getType()));
}
ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type);
if (in(ctx->cache->globals, lhs))
ctx->cache->globals[lhs].first = true;
if (realize(stmt->lhs->type) || !stmt->type)
stmt->setDone();
} else if (stmt->type && stmt->type->getType()->isStaticType()) {
Expand All @@ -84,6 +86,7 @@ void TypecheckVisitor::visit(AssignStmt *stmt) {
auto val = ctx->add(TypecheckItem::Var, lhs, stmt->lhs->type);
if (in(ctx->cache->globals, lhs)) {
// Make globals always visible!
ctx->cache->globals[lhs].first = true;
ctx->addToplevel(lhs, val);
}
if (realize(stmt->lhs->type))
Expand Down Expand Up @@ -112,6 +115,7 @@ void TypecheckVisitor::visit(AssignStmt *stmt) {

if (in(ctx->cache->globals, lhs)) {
// Make globals always visible!
ctx->cache->globals[lhs].first = true;
ctx->addToplevel(lhs, val);
if (kind != TypecheckItem::Var)
ctx->cache->globals.erase(lhs);
Expand Down
1 change: 1 addition & 0 deletions codon/runtime/lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,6 @@ std::string makeBacktraceFrameString(uintptr_t pc, const std::string &func = "",
std::string getCapturedOutput();

void setJITErrorCallback(std::function<void(const JITError &)> callback);

} // namespace runtime
} // namespace codon
2 changes: 1 addition & 1 deletion jit/codon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__all__ = ["jit", "convert", "JITError"]

from .decorator import jit, convert, JITError
from .decorator import jit, convert, execute, JITError
105 changes: 62 additions & 43 deletions jit/codon/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import functools
import itertools
import ast
import textwrap
import astunparse
from pathlib import Path

Expand Down Expand Up @@ -125,10 +126,13 @@ def _obj_to_str(obj, **kwargs) -> str:
lines = inspect.getsourcelines(obj)[0]
extra_spaces = lines[0].find("class")
obj_str = "".join(l[extra_spaces:] for l in lines)
elif callable(obj):
lines = inspect.getsourcelines(obj)[0]
extra_spaces = lines[0].find("@")
obj_str = "".join(l[extra_spaces:] for l in lines[1:])
obj_name = obj.__name__
elif callable(obj) or isinstance(obj, str):
is_str = isinstance(obj, str)
lines = [i + '\n' for i in obj.split('\n')] if is_str else inspect.getsourcelines(obj)[0]
if not is_str: lines = lines[1:]
obj_str = textwrap.dedent(''.join(lines))

pyvars = kwargs.get("pyvars", None)
if pyvars:
for i in pyvars:
Expand All @@ -138,20 +142,20 @@ def _obj_to_str(obj, **kwargs) -> str:
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str))
)
obj_str = astunparse.unparse(node)
if is_str:
try:
obj_name = ast.parse(obj_str).body[0].name
except:
raise ValueError("cannot infer function name!")
else:
obj_name = obj.__name__
else:
raise TypeError("Function or class expected, got " + type(obj).__name__)
return obj_str.replace("_@par", "@par")


def _obj_name(obj) -> str:
if inspect.isclass(obj) or callable(obj):
return obj.__name__
else:
raise TypeError("Function or class expected, got " + type(obj).__name__)
return obj_name, obj_str.replace("_@par", "@par")


def _parse_decorated(obj, **kwargs):
return _obj_name(obj), _obj_to_str(obj, **kwargs)
return _obj_to_str(obj, **kwargs)


def convert(t):
Expand Down Expand Up @@ -186,44 +190,59 @@ def convert(t):
return t


def _jit_register_fn(f, pyvars, debug):
try:
obj_name, obj_str = _parse_decorated(f, pyvars=pyvars)
fn, fl = "<internal>", 1
if hasattr(f, "__code__"):
fn, fl = f.__code__.co_filename, f.__code__.co_firstlineno
_jit.execute(obj_str, fn, fl, 1 if debug else 0)
return obj_name
except JITError:
_reset_jit()
raise

def _jit_callback_fn(obj_name, module, debug=None, sample_size=5, pyvars=None, *args, **kwargs):
try:
args = (*args, *kwargs.values())
types = _codon_types(args, debug=debug, sample_size=sample_size)
if debug:
print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr)
return _jit.run_wrapper(
obj_name, list(types), module, list(pyvars), args, 1 if debug else 0
)
except JITError:
_reset_jit()
raise

def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None):
obj_name = _jit_register_fn(fstr, pyvars, debug)
def wrapped(*args, **kwargs):
return _jit_callback_fn(obj_name, "__main__", debug, sample_size, pyvars, *args, **kwargs)
return wrapped


def jit(fn=None, debug=None, sample_size=5, pyvars=None):
if not pyvars:
pyvars = []
if not isinstance(pyvars, list):
raise ArgumentError("pyvars must be a list")

def _decorate(f):
try:
obj_name, obj_str = _parse_decorated(f, pyvars=pyvars)
_jit.execute(
obj_str,
f.__code__.co_filename,
f.__code__.co_firstlineno,
1 if debug else 0,
)
except JITError:
_reset_jit()
raise
if fn and isinstance(fn, str):
return _jit_str_fn(fn, debug, sample_size, pyvars)

def _decorate(f):
obj_name = _jit_register_fn(f, pyvars, debug)
@functools.wraps(f)
def wrapped(*args, **kwargs):
try:
args = (*args, *kwargs.values())
types = _codon_types(args, debug=debug, sample_size=sample_size)
if debug:
print(
"[python] {}({})".format(f.__name__, list(types)),
file=sys.stderr,
)
return _jit.run_wrapper(
obj_name, list(types), f.__module__, list(pyvars), args, 1 if debug else 0
)
except JITError:
_reset_jit()
raise

return _jit_callback_fn(obj_name, f.__module__, debug, sample_size, pyvars, *args, **kwargs)
return wrapped
return _decorate(fn) if fn else _decorate


if fn:
return _decorate(fn)
return _decorate
def execute(code, debug=False):
try:
_jit.execute(code, "<internal>", 0, int(debug))
except JITError:
_reset_jit()
raise
1 change: 1 addition & 0 deletions scripts/deps.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ if [ ! -f "${INSTALLDIR}/bin/llvm-config" ]; then
-DLLVM_INCLUDE_TESTS=OFF \
-DLLVM_ENABLE_RTTI=ON \
-DLLVM_ENABLE_ZLIB=OFF \
-DLLVM_ENABLE_ZSTD=OFF \
-DLLVM_ENABLE_TERMINFO=OFF \
-DLLVM_TARGETS_TO_BUILD=all \
-DCMAKE_INSTALL_PREFIX="${INSTALLDIR}"
Expand Down
Loading

0 comments on commit 11d281d

Please sign in to comment.