From d23c8c7a759bd5ab8350156ee69d33d198f818c7 Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Fri, 12 Jan 2024 19:27:29 -0500 Subject: [PATCH] Support "key" argument on min() and max() builtins (#505) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Support "key" argument on min() and max() builtins * Delay overload selection when arguments are not known (delayed dispatch) * Delay 'is None' for 'Optional[T]' until type is known * Fix union overload selection * Add static string slicing * Fix itertools.accumulate * Fix list comprehension optimization ( minitech:imports-in-list-comprehensions ) * Fix match or patterns * Fix tests and faulty static tuple issue * Fix OpenMP reductions with new min/max functions * Fix domination of dominated bindings; Fix hasattr overloads; Fix arg=None handling * Fix empty return handling; Mark generators with an attribute * Fix #487 * Fix test * Fix IR pass --------- Co-authored-by: Ibrahim Numanagić --- codon/cir/transform/parallel/openmp.cpp | 23 ++-- codon/cir/transform/pythonic/generator.cpp | 12 +- codon/parser/ast/stmt.cpp | 1 + codon/parser/ast/stmt.h | 1 + codon/parser/visitors/simplify/access.cpp | 2 +- codon/parser/visitors/simplify/function.cpp | 2 + codon/parser/visitors/typecheck/assign.cpp | 32 +++-- codon/parser/visitors/typecheck/call.cpp | 80 +++++++----- codon/parser/visitors/typecheck/function.cpp | 20 +-- codon/parser/visitors/typecheck/op.cpp | 7 +- stdlib/internal/builtin.codon | 63 ++++++---- stdlib/internal/core.codon | 4 - stdlib/internal/internal.codon | 6 +- test/core/bltin.codon | 126 +++++++++++++++++++ test/parser/typecheck_stmt.codon | 37 +++++- test/parser/types.codon | 16 ++- 16 files changed, 334 insertions(+), 98 deletions(-) diff --git a/codon/cir/transform/parallel/openmp.cpp b/codon/cir/transform/parallel/openmp.cpp index d9824631..4b599c00 100644 --- a/codon/cir/transform/parallel/openmp.cpp +++ b/codon/cir/transform/parallel/openmp.cpp @@ -203,18 +203,15 @@ struct Reduction { case Kind::XOR: result = *lhs ^ *arg; break; - case Kind::MIN: { - auto *tup = util::makeTuple({lhs, arg}); - auto *fn = M->getOrRealizeFunc("min", {tup->getType()}, {}, builtinModule); - seqassertn(fn, "min function not found"); - result = util::call(fn, {tup}); - break; - } + case Kind::MIN: case Kind::MAX: { + auto name = (kind == Kind::MIN ? "min" : "max"); auto *tup = util::makeTuple({lhs, arg}); - auto *fn = M->getOrRealizeFunc("max", {tup->getType()}, {}, builtinModule); - seqassertn(fn, "max function not found"); - result = util::call(fn, {tup}); + auto *none = (*M->getNoneType())(); + auto *fn = M->getOrRealizeFunc(name, {tup->getType(), none->getType()}, {}, + builtinModule); + seqassertn(fn, "{} function not found", name); + result = util::call(fn, {tup, none}); break; } default: @@ -432,6 +429,7 @@ struct ReductionIdentifier : public util::Operator { auto *ptrType = cast(shared->getType()); seqassertn(ptrType, "expected shared var to be of pointer type"); auto *type = ptrType->getBase(); + auto *noneType = M->getOptionalType(M->getNoneType()); // double-check the call if (!util::isCallOf(v, Module::SETITEM_MAGIC_NAME, @@ -454,7 +452,8 @@ struct ReductionIdentifier : public util::Operator { if (!util::isCallOf(item, rf.name, {type, type}, type, /*method=*/true)) continue; } else { - if (!util::isCallOf(item, rf.name, {M->getTupleType({type, type})}, type, + if (!util::isCallOf(item, rf.name, {M->getTupleType({type, type}), noneType}, + type, /*method=*/false)) continue; } @@ -1183,9 +1182,7 @@ struct GPULoopBodyStubReplacer : public util::Operator { std::vector newArgs; for (auto *arg : *replacement) { - // std::cout << "A: " << *arg << std::endl; if (getVarFromOutlinedArg(arg)->getId() == loopVar->getId()) { - // std::cout << "(loop var)" << std::endl; newArgs.push_back(idx); } else { newArgs.push_back(util::tupleGet(args, next++)); diff --git a/codon/cir/transform/pythonic/generator.cpp b/codon/cir/transform/pythonic/generator.cpp index b1be1a8d..023bad75 100644 --- a/codon/cir/transform/pythonic/generator.cpp +++ b/codon/cir/transform/pythonic/generator.cpp @@ -60,7 +60,11 @@ struct GeneratorSumTransformer : public util::Operator { auto *M = v->getModule(); auto *newReturn = M->Nr(M->Nr(accumulator)); see(newReturn); - v->replaceAll(util::series(v->getValue(), newReturn)); + if (v->getValue()) { + v->replaceAll(util::series(v->getValue(), newReturn)); + } else { + v->replaceAll(newReturn); + } } void handle(YieldInInstr *v) override { valid = false; } @@ -97,7 +101,11 @@ struct GeneratorAnyAllTransformer : public util::Operator { auto *M = v->getModule(); auto *newReturn = M->Nr(M->getBool(!any)); see(newReturn); - v->replaceAll(util::series(v->getValue(), newReturn)); + if (v->getValue()) { + v->replaceAll(util::series(v->getValue(), newReturn)); + } else { + v->replaceAll(newReturn); + } } void handle(YieldInInstr *v) override { valid = false; } diff --git a/codon/parser/ast/stmt.cpp b/codon/parser/ast/stmt.cpp index 02b86f10..b896c49d 100644 --- a/codon/parser/ast/stmt.cpp +++ b/codon/parser/ast/stmt.cpp @@ -324,6 +324,7 @@ const std::string Attr::CVarArg = ".__vararg__"; const std::string Attr::Method = ".__method__"; const std::string Attr::Capture = ".__capture__"; const std::string Attr::HasSelf = ".__hasself__"; +const std::string Attr::IsGenerator = ".__generator__"; const std::string Attr::Extend = "extend"; const std::string Attr::Tuple = "tuple"; const std::string Attr::Test = "std.internal.attributes.test"; diff --git a/codon/parser/ast/stmt.h b/codon/parser/ast/stmt.h index 45596ce7..a1fa17b0 100644 --- a/codon/parser/ast/stmt.h +++ b/codon/parser/ast/stmt.h @@ -427,6 +427,7 @@ struct Attr { const static std::string Method; const static std::string Capture; const static std::string HasSelf; + const static std::string IsGenerator; // Class attributes const static std::string Extend; const static std::string Tuple; diff --git a/codon/parser/visitors/simplify/access.cpp b/codon/parser/visitors/simplify/access.cpp index 15f08356..28cf6351 100644 --- a/codon/parser/visitors/simplify/access.cpp +++ b/codon/parser/visitors/simplify/access.cpp @@ -79,7 +79,7 @@ void SimplifyVisitor::visit(IdExpr *expr) { if (!checked) { // Prepend access with __internal__.undef([var]__used__, "[var name]") auto checkStmt = N(N( - N("__internal__", "undef"), + N("__internal__.undef"), N(fmt::format("{}.__used__", val->canonicalName)), N(ctx->cache->reverseIdentifierLookup[val->canonicalName]))); if (!ctx->isConditionalExpr) { diff --git a/codon/parser/visitors/simplify/function.cpp b/codon/parser/visitors/simplify/function.cpp index 4fc41be2..2a36dd30 100644 --- a/codon/parser/visitors/simplify/function.cpp +++ b/codon/parser/visitors/simplify/function.cpp @@ -20,6 +20,7 @@ namespace codon::ast { void SimplifyVisitor::visit(YieldExpr *expr) { if (!ctx->inFunction()) E(Error::FN_OUTSIDE_ERROR, expr, "yield"); + ctx->getBase()->attributes->set(Attr::IsGenerator); } /// Transform lambdas. Capture outer expressions. @@ -45,6 +46,7 @@ void SimplifyVisitor::visit(YieldStmt *stmt) { if (!ctx->inFunction()) E(Error::FN_OUTSIDE_ERROR, stmt, "yield"); transform(stmt->expr); + ctx->getBase()->attributes->set(Attr::IsGenerator); } /// Transform `yield from` statements. diff --git a/codon/parser/visitors/typecheck/assign.cpp b/codon/parser/visitors/typecheck/assign.cpp index 2fbb95f1..7b6bac56 100644 --- a/codon/parser/visitors/typecheck/assign.cpp +++ b/codon/parser/visitors/typecheck/assign.cpp @@ -34,19 +34,27 @@ void TypecheckVisitor::visit(AssignStmt *stmt) { if (auto changed = in(ctx->cache->replacements, lhs)) { while (auto s = in(ctx->cache->replacements, lhs)) lhs = changed->first, changed = s; - if (stmt->rhs && changed->second) { - // Mark the dominating binding as used: `var.__used__ = True` - auto u = - N(N(fmt::format("{}.__used__", lhs)), N(true)); - u->setUpdate(); - prependStmts->push_back(transform(u)); - } else if (changed->second && !stmt->rhs) { - // This assignment was a declaration only. Just mark the dominating binding as - // used: `var.__used__ = True` - stmt->lhs = N(fmt::format("{}.__used__", lhs)); - stmt->rhs = N(true); + if (changed->second) { // has __used__ binding + if (stmt->rhs) { + // Mark the dominating binding as used: `var.__used__ = True` + auto u = N(N(fmt::format("{}.__used__", lhs)), + N(true)); + u->setUpdate(); + prependStmts->push_back(transform(u)); + } else { + // This assignment was a declaration only. Just mark the dominating binding as + // used: `var.__used__ = True` + stmt->lhs = N(fmt::format("{}.__used__", lhs)); + stmt->rhs = N(true); + } } - seqassert(stmt->rhs, "bad domination statement: '{}'", stmt->toString()); + + if (endswith(lhs, ".__used__") || !stmt->rhs) { + // unneeded declaration (unnecessary used or binding) + resultStmt = transform(N()); + return; + } + // Change this to the update and follow the update logic stmt->setUpdate(); transformUpdate(stmt); diff --git a/codon/parser/visitors/typecheck/call.cpp b/codon/parser/visitors/typecheck/call.cpp index 2db58451..13562523 100644 --- a/codon/parser/visitors/typecheck/call.cpp +++ b/codon/parser/visitors/typecheck/call.cpp @@ -46,6 +46,21 @@ void TypecheckVisitor::visit(EllipsisExpr *expr) { /// See @c transformCallArgs , @c getCalleeFn , @c callReorderArguments , /// @c typecheckCallArgs , @c transformSpecialCall and @c wrapExpr for more details. void TypecheckVisitor::visit(CallExpr *expr) { + if (expr->expr->isId("__internal__.undef") && expr->args.size() == 2 && + expr->args[0].value->getId()) { + auto val = expr->args[0].value->getId()->value; + val = val.substr(0, val.size() - 9); + if (auto changed = in(ctx->cache->replacements, val)) { + while (auto s = in(ctx->cache->replacements, val)) + val = changed->first, changed = s; + if (!changed->second) { + // TODO: add no-op expr + resultExpr = transform(N(false)); + return; + } + } + } + // Transform and expand arguments. Return early if it cannot be done yet if (!transformCallArgs(expr->args)) return; @@ -319,7 +334,7 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e } ExprPtr e = N(extra); e->setAttr(ExprAttr::StarArgument); - if (!expr->expr->isId("hasattr:0")) + if (!expr->expr->isId("hasattr")) e = transform(e); if (partial) { part.args = e; @@ -373,8 +388,16 @@ ExprPtr TypecheckVisitor::callReorderArguments(FuncTypePtr calleeFn, CallExpr *e E(Error::CALL_RECURSIVE_DEFAULT, expr, ctx->cache->rev(calleeFn->ast->args[si].name)); ctx->defaultCallDepth.insert(es); - args.push_back( - {realName, transform(clone(calleeFn->ast->args[si].defaultValue))}); + + if (calleeFn->ast->args[si].defaultValue->getNone() && + !calleeFn->ast->args[si].type) { + args.push_back( + {realName, transform(N(N( + N("Optional"), N("NoneType"))))}); + } else { + args.push_back( + {realName, transform(clone(calleeFn->ast->args[si].defaultValue))}); + } ctx->defaultCallDepth.erase(es); } } else { @@ -562,7 +585,7 @@ std::pair TypecheckVisitor::transformSpecialCall(CallExpr *expr) return {true, transformIsInstance(expr)}; } else if (val == "staticlen") { return {true, transformStaticLen(expr)}; - } else if (startswith(val, "hasattr:")) { + } else if (val == "hasattr") { return {true, transformHasAttr(expr)}; } else if (val == "getattr") { return {true, transformGetAttr(expr)}; @@ -812,38 +835,35 @@ ExprPtr TypecheckVisitor::transformHasAttr(CallExpr *expr) { auto typ = expr->args[0].value->getType()->getClass(); if (!typ) return nullptr; - auto member = expr->expr->type->getFunc() ->funcGenerics[0] .type->getStatic() ->evaluate() .getString(); std::vector> args{{"", typ}}; - if (expr->expr->isId("hasattr:0")) { - // Case: the first hasattr overload allows passing argument types via *args - auto tup = expr->args[1].value->getTuple(); - seqassert(tup, "not a tuple"); - for (auto &a : tup->items) { - transform(a); - if (!a->getType()->getClass()) - return nullptr; - args.push_back({"", a->getType()}); - } - auto kwtup = expr->args[2].value->origExpr->getCall(); - seqassert(expr->args[2].value->origExpr && expr->args[2].value->origExpr->getCall(), - "expected call: {}", expr->args[2].value->origExpr); - auto kw = expr->args[2].value->origExpr->getCall(); - auto kwCls = - in(ctx->cache->classes, expr->args[2].value->getType()->getClass()->name); - seqassert(kwCls, "cannot find {}", - expr->args[2].value->getType()->getClass()->name); - for (size_t i = 0; i < kw->args.size(); i++) { - auto &a = kw->args[i].value; - transform(a); - if (!a->getType()->getClass()) - return nullptr; - args.push_back({kwCls->fields[i].name, a->getType()}); - } + + // Case: passing argument types via *args + auto tup = expr->args[1].value->getTuple(); + seqassert(tup, "not a tuple"); + for (auto &a : tup->items) { + transform(a); + if (!a->getType()->getClass()) + return nullptr; + args.emplace_back("", a->getType()); + } + auto kwtup = expr->args[2].value->origExpr->getCall(); + seqassert(expr->args[2].value->origExpr && expr->args[2].value->origExpr->getCall(), + "expected call: {}", expr->args[2].value->origExpr); + auto kw = expr->args[2].value->origExpr->getCall(); + auto kwCls = + in(ctx->cache->classes, expr->args[2].value->getType()->getClass()->name); + seqassert(kwCls, "cannot find {}", expr->args[2].value->getType()->getClass()->name); + for (size_t i = 0; i < kw->args.size(); i++) { + auto &a = kw->args[i].value; + transform(a); + if (!a->getType()->getClass()) + return nullptr; + args.emplace_back(kwCls->fields[i].name, a->getType()); } if (typ->getUnion()) { diff --git a/codon/parser/visitors/typecheck/function.cpp b/codon/parser/visitors/typecheck/function.cpp index 9139a73f..e61a22bd 100644 --- a/codon/parser/visitors/typecheck/function.cpp +++ b/codon/parser/visitors/typecheck/function.cpp @@ -29,7 +29,14 @@ void TypecheckVisitor::visit(YieldExpr *expr) { /// Also partialize functions if they are being returned. /// See @c wrapExpr for more details. void TypecheckVisitor::visit(ReturnStmt *stmt) { - if (transform(stmt->expr)) { + if (!stmt->expr && ctx->getRealizationBase()->type && + ctx->getRealizationBase()->type->getFunc()->ast->hasAttr(Attr::IsGenerator)) { + stmt->setDone(); + } else { + if (!stmt->expr) { + stmt->expr = N(N("NoneType")); + } + transform(stmt->expr); // Wrap expression to match the return type if (!ctx->getRealizationBase()->returnType->getUnbound()) if (!wrapExpr(stmt->expr, ctx->getRealizationBase()->returnType)) { @@ -44,10 +51,6 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) { } unify(ctx->getRealizationBase()->returnType, stmt->expr->type); - } else { - // Just set the expr for the translation stage. However, do not unify the return - // type! This might be a `return` in a generator. - stmt->expr = transform(N(N("NoneType"))); } // If we are not within conditional block, ignore later statements in this function. @@ -55,15 +58,16 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) { if (!ctx->blockLevel) ctx->returnEarly = true; - if (stmt->expr->isDone()) + if (!stmt->expr || stmt->expr->isDone()) stmt->setDone(); } /// Typecheck yield statements. Empty yields assume `NoneType`. void TypecheckVisitor::visit(YieldStmt *stmt) { stmt->expr = transform(stmt->expr ? stmt->expr : N(N("NoneType"))); - unify(ctx->getRealizationBase()->returnType, - ctx->instantiateGeneric(ctx->getType("Generator"), {stmt->expr->type})); + + auto t = ctx->instantiateGeneric(ctx->getType("Generator"), {stmt->expr->type}); + unify(ctx->getRealizationBase()->returnType, t); if (stmt->expr->isDone()) stmt->setDone(); diff --git a/codon/parser/visitors/typecheck/op.cpp b/codon/parser/visitors/typecheck/op.cpp index cd3d07fc..331b87da 100644 --- a/codon/parser/visitors/typecheck/op.cpp +++ b/codon/parser/visitors/typecheck/op.cpp @@ -823,9 +823,10 @@ TypecheckVisitor::transformStaticTupleIndex(const ClassTypePtr &tuple, E(Error::TUPLE_RANGE_BOUNDS, index, sz - 1, i); te.push_back(N(clone(var), classFields[i].name)); } - ExprPtr e = transform( - N(std::vector{ass}, - N(N(N(TYPE_TUPLE), "__new__"), te))); + auto s = ctx->generateTuple(te.size()); + ExprPtr e = + transform(N(std::vector{ass}, + N(N(N(s), "__new__"), te))); return {true, e}; } } diff --git a/stdlib/internal/builtin.codon b/stdlib/internal/builtin.codon index f37e93d4..76f7244c 100644 --- a/stdlib/internal/builtin.codon +++ b/stdlib/internal/builtin.codon @@ -39,57 +39,78 @@ class __internal__: def print(*args): print(*args, flush=True, file=_C.seq_stdout()) - -def min(*args): +def min(*args, key=None): if staticlen(args) == 0: - raise ValueError("empty sequence") - elif staticlen(args) == 1 and hasattr(args[0], "__iter__"): + compile_error("min expected at least 1 argument, got 0") + elif staticlen(args) == 1: x = args[0].__iter__() if not x.done(): s = x.next() while not x.done(): i = x.next() - if i < s: - s = i + if key is None: + if i < s: + s = i + else: + if key(i) < key(s): + s = i x.destroy() return s else: x.destroy() - raise ValueError("empty sequence") + raise ValueError("min() arg is an empty sequence") elif staticlen(args) == 2: a, b = args - return a if a <= b else b + if key is None: + return a if a <= b else b + else: + return a if key(a) <= key(b) else b else: m = args[0] - for i in args: - if i < m: - m = i + for i in args[1:]: + if key is None: + if i < m: + m = i + else: + if key(i) < key(m): + m = i return m -def max(*args): +def max(*args, key=None): if staticlen(args) == 0: - raise ValueError("empty sequence") - elif staticlen(args) == 1 and hasattr(args[0], "__iter__"): + compile_error("max expected at least 1 argument, got 0") + elif staticlen(args) == 1: x = args[0].__iter__() if not x.done(): s = x.next() while not x.done(): i = x.next() - if i > s: - s = i + if key is None: + if i > s: + s = i + else: + if key(i) > key(s): + s = i x.destroy() return s else: x.destroy() - raise ValueError("empty sequence") + raise ValueError("max() arg is an empty sequence") elif staticlen(args) == 2: a, b = args - return a if a >= b else b + if key is None: + return a if a >= b else b + else: + return a if key(a) >= key(b) else b else: m = args[0] - for i in args: - if i > m: - m = i + for i in args[1:]: + if key is None: + if i > m: + m = i + else: + if key(i) > key(m): + m = i return m def len(x) -> int: diff --git a/stdlib/internal/core.codon b/stdlib/internal/core.codon index d761f1c7..7717ed1c 100644 --- a/stdlib/internal/core.codon +++ b/stdlib/internal/core.codon @@ -250,10 +250,6 @@ def hasattr(obj, attr: Static[str], *args, **kwargs): """Special handling""" pass -@overload -def hasattr(obj, attr: Static[str]): - pass - def getattr(obj, attr: Static[str]): pass diff --git a/stdlib/internal/internal.codon b/stdlib/internal/internal.codon index 17b8eef4..b7e733dd 100644 --- a/stdlib/internal/internal.codon +++ b/stdlib/internal/internal.codon @@ -69,7 +69,8 @@ class __internal__: """ global __vtables__ sz = __vtable_size__ + 1 - __vtables__ = Ptr[Ptr[cobj]](alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj]))) + p = alloc_atomic_uncollectable(sz * sizeof(Ptr[cobj])) + __vtables__ = Ptr[Ptr[cobj]](p) __internal__.class_populate_vtables() def class_populate_vtables() -> None: @@ -95,7 +96,8 @@ class __internal__: def class_set_rtti_vtable(id: int, sz: int, T: type): if not __has_rtti__(T): compile_error("class is not polymorphic") - __vtables__[id] = Ptr[cobj](sz + 1) + p = alloc_atomic_uncollectable((sz + 1) * sizeof(cobj)) + __vtables__[id] = Ptr[cobj](p) __internal__.class_set_typeinfo(__vtables__[id], id) def class_set_rtti_vtable_fn(id: int, fid: int, f: cobj, T: type): diff --git a/test/core/bltin.codon b/test/core/bltin.codon index 0a2df0a4..5342d6f0 100644 --- a/test/core/bltin.codon +++ b/test/core/bltin.codon @@ -2,6 +2,8 @@ @test def test_min_max(): + neg = lambda x: -x + assert max(2, 1, 1, 1, 1) == 2 assert max(1, 2, 1, 1, 1) == 2 assert max(1, 1, 2, 1, 1) == 2 @@ -17,6 +19,21 @@ def test_min_max(): assert max(2, 1) == 2 assert max(1, 2) == 2 + assert max(2, 1, 1, 1, 1, key=neg) == 1 + assert max(1, 2, 1, 1, 1, key=neg) == 1 + assert max(1, 1, 2, 1, 1, key=neg) == 1 + assert max(1, 1, 1, 2, 1, key=neg) == 1 + assert max(1, 1, 1, 1, 2, key=neg) == 1 + assert max(2, 1, 1, 1, key=neg) == 1 + assert max(1, 2, 1, 1, key=neg) == 1 + assert max(1, 1, 2, 1, key=neg) == 1 + assert max(1, 1, 1, 2, key=neg) == 1 + assert max(2, 1, 1, key=neg) == 1 + assert max(1, 2, 1, key=neg) == 1 + assert max(1, 1, 2, key=neg) == 1 + assert max(2, 1, key=neg) == 1 + assert max(1, 2, key=neg) == 1 + assert min(2, 1, 1, 1, 1) == 1 assert min(1, 2, 1, 1, 1) == 1 assert min(1, 1, 2, 1, 1) == 1 @@ -32,6 +49,21 @@ def test_min_max(): assert min(2, 1) == 1 assert min(1, 2) == 1 + assert min(2, 1, 1, 1, 1, key=neg) == 2 + assert min(1, 2, 1, 1, 1, key=neg) == 2 + assert min(1, 1, 2, 1, 1, key=neg) == 2 + assert min(1, 1, 1, 2, 1, key=neg) == 2 + assert min(1, 1, 1, 1, 2, key=neg) == 2 + assert min(2, 1, 1, 1, key=neg) == 2 + assert min(1, 2, 1, 1, key=neg) == 2 + assert min(1, 1, 2, 1, key=neg) == 2 + assert min(1, 1, 1, 2, key=neg) == 2 + assert min(2, 1, 1, key=neg) == 2 + assert min(1, 2, 1, key=neg) == 2 + assert min(1, 1, 2, key=neg) == 2 + assert min(2, 1, key=neg) == 2 + assert min(1, 2, key=neg) == 2 + assert max(0, 1, 1, 1, 1) == 1 assert max(1, 0, 1, 1, 1) == 1 assert max(1, 1, 0, 1, 1) == 1 @@ -47,6 +79,21 @@ def test_min_max(): assert max(0, 1) == 1 assert max(1, 0) == 1 + assert max(0, 1, 1, 1, 1, key=neg) == 0 + assert max(1, 0, 1, 1, 1, key=neg) == 0 + assert max(1, 1, 0, 1, 1, key=neg) == 0 + assert max(1, 1, 1, 0, 1, key=neg) == 0 + assert max(1, 1, 1, 1, 0, key=neg) == 0 + assert max(0, 1, 1, 1, key=neg) == 0 + assert max(1, 0, 1, 1, key=neg) == 0 + assert max(1, 1, 0, 1, key=neg) == 0 + assert max(1, 1, 1, 0, key=neg) == 0 + assert max(0, 1, 1, key=neg) == 0 + assert max(1, 0, 1, key=neg) == 0 + assert max(1, 1, 0, key=neg) == 0 + assert max(0, 1, key=neg) == 0 + assert max(1, 0, key=neg) == 0 + assert min(0, 1, 1, 1, 1) == 0 assert min(1, 0, 1, 1, 1) == 0 assert min(1, 1, 0, 1, 1) == 0 @@ -62,11 +109,90 @@ def test_min_max(): assert min(0, 1) == 0 assert min(1, 0) == 0 + assert min(0, 1, 1, 1, 1, key=neg) == 1 + assert min(1, 0, 1, 1, 1, key=neg) == 1 + assert min(1, 1, 0, 1, 1, key=neg) == 1 + assert min(1, 1, 1, 0, 1, key=neg) == 1 + assert min(1, 1, 1, 1, 0, key=neg) == 1 + assert min(0, 1, 1, 1, key=neg) == 1 + assert min(1, 0, 1, 1, key=neg) == 1 + assert min(1, 1, 0, 1, key=neg) == 1 + assert min(1, 1, 1, 0, key=neg) == 1 + assert min(0, 1, 1, key=neg) == 1 + assert min(1, 0, 1, key=neg) == 1 + assert min(1, 1, 0, key=neg) == 1 + assert min(0, 1, key=neg) == 1 + assert min(1, 0, key=neg) == 1 + assert min(a*a for a in range(3)) == 0 assert max(a*a for a in range(3)) == 4 assert min([0, 2, -1]) == -1 assert max([0, 2, -1]) == 2 + assert min((a*a for a in range(3)), key=neg) == 4 + assert max((a*a for a in range(3)), key=neg) == 0 + assert min([0, 2, -1], key=neg) == 2 + assert max([0, 2, -1], key=neg) == -1 + + assert min('abcx') == 'a' + assert max('abcx') == 'x' + assert min(['a', 'b', 'c', 'x']) == 'a' + assert max(['a', 'b', 'c', 'x']) == 'x' + + d = {'a': 4, 'b': 1, 'c': -1, 'x': 9} + assert min('abcx', key=d.__getitem__) == 'c' + assert max('abcx', key=d.__getitem__) == 'x' + assert min(['a', 'b', 'c', 'x'], key=d.__getitem__) == 'c' + assert max(['a', 'b', 'c', 'x'], key=d.__getitem__) == 'x' + + try: + max('') + assert False + except ValueError as e: + assert str(e) == 'max() arg is an empty sequence' + + try: + min('') + assert False + except ValueError as e: + assert str(e) == 'min() arg is an empty sequence' + + try: + max(List[float]()) + assert False + except ValueError as e: + assert str(e) == 'max() arg is an empty sequence' + + try: + min(List[float]()) + assert False + except ValueError as e: + assert str(e) == 'min() arg is an empty sequence' + + try: + max('', key=lambda x: x * 2) + assert False + except ValueError as e: + assert str(e) == 'max() arg is an empty sequence' + + try: + min('', key=lambda x: x * 2) + assert False + except ValueError as e: + assert str(e) == 'min() arg is an empty sequence' + + try: + max(List[float](), key=lambda x: x * 2) + assert False + except ValueError as e: + assert str(e) == 'max() arg is an empty sequence' + + try: + min(List[float](), key=lambda x: x * 2) + assert False + except ValueError as e: + assert str(e) == 'min() arg is an empty sequence' + @test def test_map_filter(): assert list(map(lambda i: i+1, (i*2 for i in range(5)))) == [1, 3, 5, 7, 9] diff --git a/test/parser/typecheck_stmt.codon b/test/parser/typecheck_stmt.codon index c5878757..dfa56741 100644 --- a/test/parser/typecheck_stmt.codon +++ b/test/parser/typecheck_stmt.codon @@ -173,6 +173,41 @@ def foo(): yield 2 print list(foo()) #: [1] +def foo(x=0): + yield 1 + if x: + return + yield 2 +print list(foo()) #: [1, 2] +print list(foo(1)) #: [1] + +def foo(x=0): + if x: + return + yield 1 + yield 2 +print list(foo()) #: [1, 2] +print list(foo(1)) #: [] + +#%% return_none_err_1,barebones +def foo(n: int): + if n > 0: + return + else: + return 1 +foo(1) +#! 'NoneType' does not match expected type 'int' +#! during the realization of foo(n: int) + +#%% return_none_err_2,barebones +def foo(n: int): + if n > 0: + return 1 + return +foo(1) +#! 'int' does not match expected type 'NoneType' +#! during the realization of foo(n: int) + #%% while,barebones a = 3 while a: @@ -253,7 +288,7 @@ except MyError: print "my" except OSError as o: print "os", o.typename, len(o.message), o.file[-20:], o.line - #: os OSError 9 typecheck_stmt.codon 249 + #: os OSError 9 typecheck_stmt.codon 284 finally: print "whoa" #: whoa diff --git a/test/parser/types.codon b/test/parser/types.codon index 5ce6da7c..cf645622 100644 --- a/test/parser/types.codon +++ b/test/parser/types.codon @@ -2047,7 +2047,21 @@ def correlate(a, b, mode = 'valid'): else: raise ValueError(f"mode must be one of 'valid', 'same', or 'full' (got {repr(mode)})") return xret -print(correlate([1], [2], 'full')) # 5z +print(correlate([1], [2], 'full')) #: 5z + +def foo(x, y): + a = 5 + if isinstance(a, int): + if staticlen(y) == 0: + a = 0 + elif staticlen(y) == 1: + a = 1 + else: + for i in range(10): + a = 40 + return a + return a +print foo(5, (1, 2, 3)) #: 40 #%% union_hasattr,barebones class A: