Skip to content

Commit

Permalink
Bugfixes 2023-08 (#440)
Browse files Browse the repository at this point in the history
* Fix type argument overload issue; Fix Cython version for CI

* Add __contains__ for kwargs

* Add get() for kwargs

* Add static <<, >> and unary ~

* Fix CI

* Fix OpenMP "ordered" clause

* Fix static ~

* Fix Cython 3 issues

* Fix Python MANIFEST.in

---------

Co-authored-by: A. R. Shajii <[email protected]>
  • Loading branch information
inumanag and arshajii authored Aug 12, 2023
1 parent 7198a09 commit 750bb28
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 17 deletions.
24 changes: 20 additions & 4 deletions codon/parser/visitors/typecheck/class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,35 @@ std::string TypecheckVisitor::generateTuple(size_t len, const std::string &name,
StmtPtr stmt = N<ClassStmt>(ctx->cache->generateSrcInfo(), typeName, args, nullptr,
std::vector<ExprPtr>{N<IdExpr>("tuple")});

// Add getItem for KwArgs:
// Add helpers for KwArgs:
// `def __getitem__(self, key: Static[str]): return getattr(self, key)`
// `def __contains__(self, key: Static[str]): return hasattr(self, key)`
auto getItem = N<FunctionStmt>(
"__getitem__", nullptr,
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
N<IdExpr>("str"))}},
N<SuiteStmt>(N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("getattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
auto contains = N<FunctionStmt>(
"__contains__", nullptr,
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
N<IdExpr>("str"))}},
N<SuiteStmt>(N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("hasattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
auto getDef = N<FunctionStmt>(
"get", nullptr,
std::vector<Param>{
Param{"self"},
Param{"key", N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>("str"))},
Param{"default", nullptr, N<CallExpr>(N<IdExpr>("NoneType"))}},
N<SuiteStmt>(N<ReturnStmt>(
N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "kwargs_get"),
N<IdExpr>("self"), N<IdExpr>("key"), N<IdExpr>("default")))));
if (startswith(typeName, TYPE_KWTUPLE))
stmt->getClass()->suite = getItem;
stmt->getClass()->suite = N<SuiteStmt>(getItem, contains, getDef);

// Add getItem for KwArgs:
// `def __repr__(self,): return __magic__.repr_partial(self)`
// Add repr for KwArgs:
// `def __repr__(self): return __magic__.repr_partial(self)`
auto repr = N<FunctionStmt>(
"__repr__", nullptr, std::vector<Param>{Param{"self"}},
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(
Expand Down
13 changes: 10 additions & 3 deletions codon/parser/visitors/typecheck/op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ void TypecheckVisitor::visit(UnaryExpr *expr) {
transform(expr->expr);

static std::unordered_map<StaticValue::Type, std::unordered_set<std::string>>
staticOps = {{StaticValue::INT, {"-", "+", "!"}}, {StaticValue::STRING, {"@"}}};
staticOps = {{StaticValue::INT, {"-", "+", "!", "~"}},
{StaticValue::STRING, {"@"}}};
// Handle static expressions
if (expr->expr->isStatic() && in(staticOps[expr->expr->staticValue.type], expr->op)) {
resultExpr = evaluateStaticUnary(expr);
Expand Down Expand Up @@ -62,7 +63,7 @@ void TypecheckVisitor::visit(BinaryExpr *expr) {
static std::unordered_map<StaticValue::Type, std::unordered_set<std::string>>
staticOps = {{StaticValue::INT,
{"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//",
"%", "&", "|", "^"}},
"%", "&", "|", "^", ">>", "<<"}},
{StaticValue::STRING, {"==", "!=", "+"}}};
if (expr->lexpr->isStatic() && expr->rexpr->isStatic() &&
expr->lexpr->staticValue.type == expr->rexpr->staticValue.type &&
Expand Down Expand Up @@ -370,13 +371,15 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
}

// Case: static integers
if (expr->op == "-" || expr->op == "+" || expr->op == "!") {
if (expr->op == "-" || expr->op == "+" || expr->op == "!" || expr->op == "~") {
if (expr->expr->staticValue.evaluated) {
int64_t value = expr->expr->staticValue.getInt();
if (expr->op == "+")
;
else if (expr->op == "-")
value = -value;
else if (expr->op == "~")
value = ~value;
else
value = !bool(value);
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
Expand Down Expand Up @@ -484,6 +487,10 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
lvalue = lvalue & rvalue;
else if (expr->op == "|")
lvalue = lvalue | rvalue;
else if (expr->op == ">>")
lvalue = lvalue >> rvalue;
else if (expr->op == "<<")
lvalue = lvalue << rvalue;
else if (expr->op == "//")
lvalue = divMod(ctx, lvalue, rvalue).first;
else if (expr->op == "%")
Expand Down
7 changes: 6 additions & 1 deletion codon/parser/visitors/typecheck/typecheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
auto score = ctx->reorderNamedArgs(
fn.get(), args,
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
for (int si = 0; si < slots.size(); si++) {
for (int si = 0, gi = 0; si < slots.size(); si++) {
if (fn->ast->args[si].status == Param::Generic) {
if (slots[si].empty()) {
// is this "real" type?
Expand All @@ -263,8 +263,13 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
}
reordered.push_back({nullptr, 0});
} else {
seqassert(gi < fn->funcGenerics.size(), "bad fn");
if (!fn->funcGenerics[gi].type->isStaticType() &&
!args[slots[si][0]].value->isType())
return -1;
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
}
gi++;
} else if (si == s || si == k || slots[si].size() != 1) {
// Ignore *args, *kwargs and default arguments
reordered.push_back({nullptr, 0});
Expand Down
1 change: 1 addition & 0 deletions jit/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include codon/*.pxd
2 changes: 1 addition & 1 deletion jit/codon/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def wrapped(*args, **kwargs):
file=sys.stderr,
)
return _jit.run_wrapper(
obj_name, types, f.__module__, pyvars, args, 1 if debug else 0
obj_name, list(types), f.__module__, list(pyvars), args, 1 if debug else 0
)
except JITError:
_reset_jit()
Expand Down
2 changes: 1 addition & 1 deletion jit/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@

jit_extension = Extension(
"codon.codon_jit",
sources=["codon/jit.pyx", "codon/jit.pxd"],
sources=["codon/jit.pyx"],
libraries=libraries,
language="c++",
extra_compile_args=["-w"],
Expand Down
6 changes: 6 additions & 0 deletions stdlib/internal/internal.codon
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,12 @@ class __internal__:
e.col = col
return e

def kwargs_get(kw, key: Static[str], default):
if hasattr(kw, key):
return getattr(kw, key)
else:
return default


@extend
class __magic__:
Expand Down
14 changes: 7 additions & 7 deletions stdlib/openmp.codon
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def _master_end(loc_ref: Ptr[Ident], gtid: int):
__kmpc_end_master(loc_ref, i32(gtid))

def _ordered_begin(loc_ref: Ptr[Ident], gtid: int):
from C import __kmpc_ordered(Ptr[Ident], i32) -> i32
return int(__kmpc_ordered(loc_ref, i32(gtid)))
from C import __kmpc_ordered(Ptr[Ident], i32)
__kmpc_ordered(loc_ref, i32(gtid))

def _ordered_end(loc_ref: Ptr[Ident], gtid: int):
from C import __kmpc_end_ordered(Ptr[Ident], i32)
Expand Down Expand Up @@ -781,11 +781,11 @@ def ordered(func):
def _wrapper(*args, **kwargs):
gtid = get_thread_num()
loc = _default_loc()
if _ordered_begin(loc, gtid) != 0:
try:
func(*args, **kwargs)
finally:
_ordered_end(loc, gtid)
_ordered_begin(loc, gtid)
try:
func(*args, **kwargs)
finally:
_ordered_end(loc, gtid)

return _wrapper

Expand Down
15 changes: 15 additions & 0 deletions test/parser/types.codon
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,21 @@ def foo(x):
print foo('hi') #: (3, 2)
print foo('hi', 1) #: (2, 'hi_1')


def fox(a: int, b: int, c: int, dtype: type = int):
print('fox 1:', a, b, c)

@overload
def fox(a: int, b: int, dtype: type = int):
print('fox 2:', a, b, dtype.__class__.__name__)

fox(1, 2, float)
#: fox 2: 1 2 float
fox(1, 2)
#: fox 2: 1 2 int
fox(1, 2, 3)
#: fox 1: 1 2 3

#%% fn_shadow,barebones
def foo(x):
return 1, x
Expand Down
15 changes: 15 additions & 0 deletions test/transform/omp.codon
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,20 @@ def test_omp_collapse():

assert A6 == B6

@test
def test_omp_ordered(N: int = 1000):
@omp.ordered
def f(A, i):
A.append(i)

A = []

@par(schedule='dynamic', chunk_size=1, num_threads=2, ordered=True)
for i in range(N):
f(A, i)

assert A == list(range(N))

test_omp_api()
test_omp_schedules()
test_omp_ranges()
Expand All @@ -901,3 +915,4 @@ test_omp_transform(111.1, 222.2, 333.3)
test_omp_nested()
test_omp_corner_cases()
test_omp_collapse()
test_omp_ordered()

0 comments on commit 750bb28

Please sign in to comment.