diff --git a/.github/workflows/macos_xmake.yml b/.github/workflows/macos_xmake.yml index 74b916a8..35e0eb2c 100644 --- a/.github/workflows/macos_xmake.yml +++ b/.github/workflows/macos_xmake.yml @@ -2,9 +2,9 @@ name: MacOS (xmake) on: push: - branches: [ master ] + branches: [ main ] pull_request: - branches: [ master ] + branches: [ main ] jobs: build_with_xmake_on_macos: diff --git a/.github/workflows/windows_xmake.yml b/.github/workflows/windows_xmake.yml index 261254e4..4aacee26 100644 --- a/.github/workflows/windows_xmake.yml +++ b/.github/workflows/windows_xmake.yml @@ -2,9 +2,9 @@ name: Windows (xmake) on: push: - branches: [ master ] + branches: [ main ] pull_request: - branches: [ master ] + branches: [ main ] jobs: build_with_xmake_on_windows: diff --git a/BUILD.bazel b/BUILD.bazel index 30f3cfa1..16ee52d4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -31,9 +31,9 @@ cc_library( [ "src/**/*.cpp", ], - exclude = ["src/codegen/**"], + exclude = ["src/codegen/**", "src/core/**"], ) + select({ - ":mlir_codegen": glob(["src/codegen/**/*.cpp"]), + ":mlir_codegen": glob(["src/codegen/**/*.cpp", "src/core/**/*.cpp", "src/core/**/*.h"]), "//conditions:default": [], }) + [ ":src/version.cpp", diff --git a/CMakeLists.txt b/CMakeLists.txt index ce1d022b..8fa6ab54 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -103,6 +103,26 @@ file(GLOB PSCM_SRCS "src/*.cpp") target_sources(pscm PRIVATE ${PSCM_SRCS}) if (PSCM_ENABLE_MLIR_CODEGEN) message(STATUS "Enable codegen with MLIR") + target_sources(pscm PUBLIC + src/core/Value.cpp + src/core/Value.h + src/core/Parser.cpp + src/core/Parser.h + src/core/Evaluator.cpp + src/core/Evaluator.h + src/core/SymbolTable.cpp + src/core/SymbolTable.h + src/core/Procedure.cpp + src/core/Procedure.h + src/core/JIT.cpp + src/core/JIT.h + src/core/Scheme.cpp + src/core/Scheme.h + src/core/Mangler.cpp + src/core/Mangler.h + src/core/Runtime.cpp + src/core/Runtime.h + ) target_compile_definitions(pscm PRIVATE PSCM_ENABLE_MLIR_CODEGEN) 
file(GLOB CODEGEN_SRCS src/codegen/*.cpp src/codegen/mlir/*.cpp src/codegen/llvm_ir/*.cpp) target_sources(pscm PRIVATE ${CODEGEN_SRCS}) diff --git a/include/pscm/common_def.h b/include/pscm/common_def.h index a64b3baf..758c4629 100644 --- a/include/pscm/common_def.h +++ b/include/pscm/common_def.h @@ -9,6 +9,9 @@ #define PSCM_THROW_EXCEPTION(msg) \ PSCM_ERROR("Exception occurred here: {0}", (msg)); \ throw ::pscm::Exception(msg) + +#define PSCM_UNIMPLEMENTED() PSCM_THROW_EXCEPTION("Unimplemented") + #define PSCM_ASSERT(e) \ if (!(e)) { \ PSCM_ERROR("ASSERT FAILED here: {0}", #e); \ diff --git a/src/core/Evaluator.cpp b/src/core/Evaluator.cpp new file mode 100644 index 00000000..f7f18fd3 --- /dev/null +++ b/src/core/Evaluator.cpp @@ -0,0 +1,310 @@ +#include "Evaluator.h" +#include "Mangler.h" +#include "Procedure.h" +#include "SymbolTable.h" +#include "Value.h" + +#include + +namespace pscm::core { +PSCM_INLINE_LOG_DECLARE("pscm.core.Evaluator"); + +template +struct BinaryOperation { + template + static typename Op::ReturnType *eval(LeftType *lhs, RightType *rhs) { + return Op()(lhs, rhs); + } +}; + +template +struct GetReturnType; + +template <> +struct GetReturnType { + using value = IntegerValue; +}; + +template +class AddOp { +public: + using ReturnType = typename GetReturnType::value; + + ReturnType *operator()(LeftType *lhs, RightType *rhs) { + auto ret = lhs->value() + rhs->value(); + return new ReturnType(ret); + } +}; + +template +class MinusOp { +public: + using ReturnType = typename GetReturnType::value; + + ReturnType *operator()(LeftType *lhs, RightType *rhs) { + auto ret = lhs->value() - rhs->value(); + return new ReturnType(ret); + } +}; + +class EvaluatorImpl { +public: + EvaluatorImpl() + : sym_table_(new SymbolTable()) { + auto sym_car = new SymbolValue("car"); + sym_table_->put(sym_car, new Procedure(sym_car, {}, {}, nullptr)); + } + + ~EvaluatorImpl() { + delete sym_table_; + } + + AST *eval(Value *expr) { + if (auto p = dynamic_cast(expr); p) { 
+ return eval(p); + } + // if (auto p = dynamic_cast(expr); p) { + // return eval(p); + // } + if (auto p = dynamic_cast(expr); p) { + return eval(p); + } + // if (auto p = dynamic_cast(expr); p) { + // return p; + // } + if (auto p = dynamic_cast(expr); p) { + return p; + } + // if (auto p = dynamic_cast(expr); p) { + // return p; + // } + if (expr) { + PSCM_THROW_EXCEPTION("Unsupported type: " + expr->to_string()); + } + PSCM_UNIMPLEMENTED(); + }; + + AST *eval(ListValue *expr) { + PSCM_ASSERT(expr); + auto value_list = expr->value(); + PSCM_ASSERT(!value_list.empty()); + if (auto p = dynamic_cast(value_list[0]); p) { + if (auto value = sym_table_->lookup(p); value) { + if (auto f = dynamic_cast(value); f) { + std::vector args; + std::vector arg_types; + for (int i = 1; i < value_list.size(); ++i) { + auto arg = eval(value_list[i]); + if (auto func_arg = dynamic_cast(arg); func_arg) { + args.push_back(func_arg); + auto type = func_arg->type(); + arg_types.push_back(type); + } + else { + PSCM_UNIMPLEMENTED(); + } + } + auto call = new CallExprAST(f->name()->to_string(), args, arg_types); + return call; + } + } + /* + (cond ((> 3 2) 100) + ((< 3 2) 200)) + */ + if (p->to_string() == "cond") { + return create_cond(value_list); + } + if (p->to_string() == "map") { + return create_map(value_list); + } + if (p->to_string() == "quote") { + auto value = value_list[1]; + if (auto list = dynamic_cast(value); list) { + std::vector array_value_list; + for (int i = 0; i < list->value().size(); ++i) { + auto array_value = eval(list->value()[i]); + if (auto array_value_expr_ast = dynamic_cast(array_value); array_value_expr_ast) { + array_value_list.push_back(array_value_expr_ast); + } + else { + PSCM_UNIMPLEMENTED(); + } + } + return new ArrayExprAST(std::move(array_value_list)); + } + } + + std::vector operands; + operands.reserve(value_list.size() - 1); + for (int i = 1; i < value_list.size(); ++i) { + auto arg = eval(value_list[i]); + if (auto ast = dynamic_cast(arg); ast) { 
+ operands.push_back(ast); + } + else { + PSCM_UNIMPLEMENTED(); + } + } + auto sym = p->to_string(); + if (sym == ">" || sym == "<" || sym == "=") { + if (operands.size() < 2) { + PSCM_THROW_EXCEPTION("Invalid arguments: " + expr->to_string() + ", require at least 2"); + } + auto lhs = operands[0]; + auto rhs = operands[1]; + return new BinaryExprAST(p, lhs, rhs); + } + if (sym == "+" || sym == "-") { + int start_index = 0; + ExprAST *ret = IntegerValue::zero(); + if (operands.empty()) { + return ret; + } + if (operands.size() == 1) { + if (sym == "+") { + if (auto ret_int = dynamic_cast(operands[0]); ret_int) { + return ret_int; + } + else { + PSCM_UNIMPLEMENTED(); + } + } + } + else { + ret = operands[0]; + start_index = 1; + } + if (sym == "-") { + if (auto ret_int = dynamic_cast(operands[0]); ret_int) { + if (operands.size() == 1) { + ret = new BinaryExprAST(p, ret, ret_int); + return ret; + } + else { + ret = ret_int; + } + } + else { + PSCM_UNIMPLEMENTED(); + } + start_index = 1; + } + for (int i = start_index; i < operands.size(); ++i) { + auto arg = operands[i]; + ret = new BinaryExprAST(p, ret, arg); + } + return ret; + } + } + return nullptr; + } + + Value *eval(DottedListValue *expr) { + return nullptr; + } + + ExprAST *eval(SymbolValue *expr) { + return sym_table_->lookup(expr); + } + + IfExprAST *create_cond(const std::vector& value_list) { + std::vector list; + std::vector> cond_stmt_list; + ExprAST *else_stmt = nullptr; + for (int i = 1; i < value_list.size(); ++i) { + if (auto cond_stmt = dynamic_cast(value_list[i]); cond_stmt) { + PSCM_ASSERT(!cond_stmt->value().empty()); + auto cond = eval(cond_stmt->value()[0]); + auto then_value = eval(cond_stmt->value()[1]); + if (!cond) { + if (auto maybe_else = dynamic_cast(cond_stmt->value()[0]); maybe_else) { + if (maybe_else->to_string() == "else") { + auto else_expr_ast = dynamic_cast(then_value); + PSCM_ASSERT(else_expr_ast); + else_stmt = else_expr_ast; + } + else { + PSCM_THROW_EXCEPTION("Invalid 
symbol: " + maybe_else->to_string()); + } + } + } + else if (auto cond_expr_ast = dynamic_cast(cond); cond_expr_ast) { + if (auto else_sym = dynamic_cast(cond); else_sym) { + PSCM_UNIMPLEMENTED(); + } + else { + auto then_value_expr_ast = dynamic_cast(then_value); + PSCM_ASSERT(then_value_expr_ast); + cond_stmt_list.emplace_back(cond_expr_ast, then_value_expr_ast); + } + } + else { + PSCM_THROW_EXCEPTION("Invalid cond statement: " + cond_stmt->to_string()); + } + } + else { + PSCM_UNIMPLEMENTED(); + } + } + IfExprAST *if_expr_ast = nullptr; + for (auto& [cond, then_value] : cond_stmt_list) { + if (if_expr_ast) { + if_expr_ast->add_else_if(cond, then_value); + } + else { + if_expr_ast = new IfExprAST(cond, then_value, else_stmt); + } + } + return if_expr_ast; + } + + // (map abs '(4 -5 6)) + MapExprAST *create_map(const std::vector& value_list) { + PSCM_ASSERT(value_list.size() == 3); + auto f = eval(value_list[1]); + auto args = eval(value_list[2]); + auto proc = dynamic_cast(f); + auto proc_name = proc->name()->to_string(); + if (auto expr_ast = dynamic_cast(args); expr_ast) { + return new MapExprAST(proc_name, expr_ast); + } + PSCM_UNIMPLEMENTED(); + } + + SymbolTable *sym_table_; +}; + +Evaluator::Evaluator() + : impl_(new EvaluatorImpl()) { +} + +Evaluator::~Evaluator() { + delete impl_; +} + +AST *Evaluator::eval(Value *expr) { + return impl_->eval(expr); +} + +void Evaluator::add_proc(SymbolValue *sym, ExprAST *value) { + impl_->sym_table_->put(sym, value); +} + +void Evaluator::push_symbol_table() { + auto table = new SymbolTable(impl_->sym_table_); + impl_->sym_table_ = table; +} + +void Evaluator::pop_symbol_table() { + auto parent = impl_->sym_table_->parent(); + PSCM_ASSERT(parent); + delete impl_->sym_table_; + impl_->sym_table_ = parent; +} + +void Evaluator::add_sym(SymbolValue *sym, ExprAST *value) { + impl_->sym_table_->put(sym, value); +} + +} // namespace pscm::core \ No newline at end of file diff --git a/src/core/Evaluator.h 
b/src/core/Evaluator.h new file mode 100644 index 00000000..a79b5905 --- /dev/null +++ b/src/core/Evaluator.h @@ -0,0 +1,25 @@ +#pragma once + +namespace pscm::core { +class Value; +class SymbolValue; +class AST; +class ExprAST; +class Type; +class EvaluatorImpl; + +class Evaluator { +public: + Evaluator(); + ~Evaluator(); + AST *eval(Value *expr); + void add_proc(SymbolValue *sym, ExprAST *value); + void push_symbol_table(); + void pop_symbol_table(); + void add_sym(SymbolValue *sym, ExprAST *value); + +private: + EvaluatorImpl *impl_; +}; + +} // namespace pscm::core diff --git a/src/core/JIT.cpp b/src/core/JIT.cpp new file mode 100644 index 00000000..5def9ee9 --- /dev/null +++ b/src/core/JIT.cpp @@ -0,0 +1,35 @@ +#include "JIT.h" +#include "Runtime.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include + +namespace pscm::core { +PSCM_INLINE_LOG_DECLARE("pscm.core.JIT"); + +JIT::JIT(std::unique_ptr es, llvm::orc::JITTargetMachineBuilder jtmb, llvm::DataLayout dl) + : es_(std::move(es)) + , dl_(std::move(dl)) + , mangle_(*this->es_, this->dl_) + , object_layer_(*this->es_, + []() { + return std::make_unique(); + }) + , compile_layer_(*this->es_, object_layer_, std::make_unique(std::move(jtmb))) + , main_jd_(this->es_->createBareJITDylib("
")) { + main_jd_.addGenerator( + llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(dl_.getGlobalPrefix()))); + + void *malloc_addr = llvm::sys::DynamicLibrary::SearchForAddressOfSymbol("malloc"); + if (!malloc_addr) { + PSCM_THROW_EXCEPTION("malloc not found"); + } + auto malloc_sym = llvm::orc::ExecutorSymbolDef(llvm::orc::ExecutorAddr(llvm::pointerToJITTargetAddress(malloc_addr)), + llvm::JITSymbolFlags::Exported); + llvm::cantFail(main_jd_.define(llvm::orc::absoluteSymbols({ + { es_->intern("malloc"), malloc_sym }, + { es_->intern("car_array[integer]"), + llvm::orc::ExecutorSymbolDef(llvm::orc::ExecutorAddr(llvm::pointerToJITTargetAddress(car_array)), + llvm::JITSymbolFlags::Exported) } + }))); +} +} // namespace pscm::core \ No newline at end of file diff --git a/src/core/JIT.h b/src/core/JIT.h new file mode 100644 index 00000000..47422e4b --- /dev/null +++ b/src/core/JIT.h @@ -0,0 +1,69 @@ +#pragma once +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include + +namespace pscm::core { + +class JIT { +public: + JIT(std::unique_ptr es, llvm::orc::JITTargetMachineBuilder jtmb, llvm::DataLayout dl); + + ~JIT() { + if (auto err = es_->endSession(); err) { + es_->reportError(std::move(err)); + } + } + + [[nodiscard]] static llvm::Expected> create() { + auto epc = llvm::orc::SelfExecutorProcessControl::Create(); + if (!epc) { + return epc.takeError(); + } + auto es = 
std::make_unique(std::move(*epc)); + llvm::orc::JITTargetMachineBuilder jtmb(es->getExecutorProcessControl().getTargetTriple()); + auto dl = jtmb.getDefaultDataLayoutForTarget(); + if (!dl) { + return dl.takeError(); + } + return std::make_unique(std::move(es), std::move(jtmb), std::move(*dl)); + } + + [[nodiscard]] const llvm::DataLayout& data_layout() const { + return dl_; + } + + [[nodiscard]] llvm::orc::JITDylib& main_jit_dylib() { + return main_jd_; + } + + [[nodiscard]] llvm::Error add_module(llvm::orc::ThreadSafeModule tsm, llvm::orc::ResourceTrackerSP rt = nullptr) { + if (!rt) { + rt = main_jd_.getDefaultResourceTracker(); + } + return compile_layer_.add(rt, std::move(tsm)); + } + + [[nodiscard]] llvm::Expected lookup(llvm::StringRef name) { + return es_->lookup({ &main_jd_ }, mangle_(name.str())); + } + +private: + std::unique_ptr es_; + llvm::DataLayout dl_; + llvm::orc::MangleAndInterner mangle_; + llvm::orc::RTDyldObjectLinkingLayer object_layer_; + llvm::orc::IRCompileLayer compile_layer_; + llvm::orc::JITDylib& main_jd_; +}; +} // namespace pscm::core diff --git a/src/core/Mangler.cpp b/src/core/Mangler.cpp new file mode 100644 index 00000000..c3d00358 --- /dev/null +++ b/src/core/Mangler.cpp @@ -0,0 +1,17 @@ +#include "Mangler.h" +#include "Value.h" +#include + +namespace pscm::core { + +std::string Mangler::mangle(const std::string& callee, const std::vector& arg_type_list) const { + PSCM_INLINE_LOG_DECLARE("pscm.core.mangle_name"); + std::stringstream ss; + ss << callee; + for (auto arg_type : arg_type_list) { + ss << "_"; + ss << arg_type->to_string(); + } + return ss.str(); +} +} // namespace pscm::core diff --git a/src/core/Mangler.h b/src/core/Mangler.h new file mode 100644 index 00000000..ac679157 --- /dev/null +++ b/src/core/Mangler.h @@ -0,0 +1,13 @@ +#pragma once +#include + +namespace pscm::core { +class ExprAST; +class Type; + +class Mangler { +public: + [[nodiscard]] std::string mangle(const std::string& callee, const std::vector& args) 
const; +}; + +} // namespace pscm::core diff --git a/src/core/Parser.cpp b/src/core/Parser.cpp new file mode 100644 index 00000000..01672c39 --- /dev/null +++ b/src/core/Parser.cpp @@ -0,0 +1,149 @@ +#include "Parser.h" +#include "pscm/Parser.h" + +#include "Procedure.h" +#include "Value.h" + +#include +#include +#include +#include + +#include +#include +using namespace std::string_literals; + +namespace pscm::core { +PSCM_INLINE_LOG_DECLARE("pscm.core.Parser"); + +Procedure *create_proc(const std::vector& list) { + PSCM_ASSERT(list.size() == 3); + auto head = list[1]; + std::vector body(list.begin() + 2, list.end()); + if (auto p = dynamic_cast(head); p) { + auto name = dynamic_cast(p->value1()[0]); + PSCM_ASSERT(name); + std::vector args; + args.reserve(p->value1().size() - 1); + for (int i = 1; i < p->value1().size(); ++i) { + auto arg = dynamic_cast(p->value1()[i]); + PSCM_ASSERT(arg); + args.push_back(arg); + } + auto vararg = dynamic_cast(p->value2()); + auto proc = new Procedure(name, args, body, nullptr, (SymbolValue *)vararg); + return proc; + } + else if (auto p = dynamic_cast(head); p) { + auto name = dynamic_cast(p->value()[0]); + PSCM_ASSERT(name); + std::vector args; + args.reserve(p->value().size() - 1); + for (int i = 1; i < p->value().size(); ++i) { + auto arg = dynamic_cast(p->value()[i]); + PSCM_ASSERT(arg); + args.push_back(arg); + } + auto proc = new Procedure(name, args, body, nullptr); + return proc; + } + else { + PSCM_THROW_EXCEPTION("Unsupported type: " + head->to_string()); + } +} + +class ParserImpl { +public: + ParserImpl(std::string code) + : parser_(code.c_str()) { + } + + Value *parse() { + auto cell = parser_.parse(); + auto ret = convert_cell_to_value(cell); + return ret; + } + + Value *convert_cell_to_value(Cell cell) { + if (cell.is_none()) { + return nullptr; + } + if (cell.is_bool()) { + auto value = cell.to_bool(); + return value ? 
(Value *)TrueValue::instance() : (Value *)FalseValue::instance(); + } + else if (cell.is_sym()) { + auto value = cell.to_sym(); + return new SymbolValue(cell.to_std_string()); + } + else if (cell.is_str()) { + auto value = cell.to_str(); + std::string converted; + value->str().toUTF8String(converted); + return new StringValue(std::move(converted)); + } + else if (cell.is_num()) { + if (auto value = cell.to_num(); value->is_int()) { + auto int_value = value->to_int(); + return new IntegerValue(int_value); + } + else { + PSCM_THROW_EXCEPTION("Unsupported number type: " + cell.to_string()); + } + } + else if (cell.is_pair()) { + std::vector value_list; + while (cell.is_pair()) { + auto item = car(cell); + auto value = convert_cell_to_value(item); + PSCM_ASSERT(value); + value_list.push_back(value); + cell = cdr(cell); + } + if (cell.is_nil()) { + if (value_list.size() >= 3 && value_list[0]->to_string() == "define") { + if (auto list = dynamic_cast(value_list[1]); list) { + if (auto sym = dynamic_cast(list->value()[0]); sym) { + auto proc = create_proc(value_list); + return proc; + } + } + } + return new ListValue(std::move(value_list)); + } + else { + auto last_value = convert_cell_to_value(cell); + return new DottedListValue(std::move(value_list), last_value); + } + } + else { + PSCM_THROW_EXCEPTION("Unsupported type: " + cell.to_string()); + } + } + + pscm::Parser parser_; +}; + +Parser::Parser(std::string code) + : impl_(new ParserImpl(std::move(code))) { +} + +Parser::~Parser() { + delete impl_; +} + +std::vector Parser::parse_all() { + std::vector ret; + auto value = parse_one(); + while (value != nullptr) { + ret.push_back(value); + value = parse_one(); + } + return ret; +} + +Value *Parser::parse_one() { + PSCM_ASSERT(impl_); + return impl_->parse(); +} +} // namespace pscm::core diff --git a/src/core/Parser.h b/src/core/Parser.h new file mode 100644 index 00000000..9fa69b63 --- /dev/null +++ b/src/core/Parser.h @@ -0,0 +1,20 @@ +#pragma once +#include 
+#include + +namespace pscm::core { +class Value; +class ParserImpl; + +class Parser { +public: + explicit Parser(std::string code); + ~Parser(); + std::vector parse_all(); + Value *parse_one(); + +private: + ParserImpl *impl_; +}; + +} // namespace pscm::core diff --git a/src/core/Procedure.cpp b/src/core/Procedure.cpp new file mode 100644 index 00000000..ef6da791 --- /dev/null +++ b/src/core/Procedure.cpp @@ -0,0 +1,42 @@ +#include "Procedure.h" +#include + +namespace pscm::core { +std::string Procedure::to_string() const { + std::stringstream ss; + ss << "#"; + ss << "to_string(); + } + else { + ss << "#f"; + } + ss << " "; + ss << "("; + for (auto arg : args_) { + ss << arg->to_string(); + ss << " "; + } + if (vararg_.has_value()) { + ss << "."; + ss << " "; + ss << vararg_.value()->to_string(); + } + else { + ss.seekp(-1, std::ios::cur); + } + ss << ")"; + ss << ">"; + return ss.str(); +} + +llvm::Value *Procedure::codegen(CodegenContext& ctx) { + return nullptr; +} + +const Type *Procedure::type() const { + return nullptr; +} + +} // namespace pscm::core diff --git a/src/core/Procedure.h b/src/core/Procedure.h new file mode 100644 index 00000000..e75750c2 --- /dev/null +++ b/src/core/Procedure.h @@ -0,0 +1,53 @@ +#pragma once + +#include "Value.h" +#include +#include + +namespace pscm::core { +class Value; +class SymbolValue; +class SymbolTable; + +class Procedure + : public Value + , public ExprAST { +public: + Procedure(SymbolValue *name, std::vector args, std::vector body, SymbolTable *env, + std::optional vararg = std::nullopt) + : name_(name) + , args_(std::move(args)) + , body_(std::move(body)) + , env_(env) + , vararg_(std::move(vararg)) { + } + + [[nodiscard]] SymbolValue *name() const { + return name_; + } + + void set_name(SymbolValue *name) { + this->name_ = name; + } + + [[nodiscard]] const std::vector& args() const { + return args_; + } + + [[nodiscard]] const std::vector& body() const { + return body_; + } + + [[nodiscard]] std::string 
to_string() const override; + llvm::Value *codegen(CodegenContext& ctx) override; + const Type *type() const override; + +private: + SymbolValue *name_; + std::vector args_; + std::vector body_; + SymbolTable *env_; + std::optional vararg_; +}; + +} // namespace pscm::core diff --git a/src/core/Runtime.cpp b/src/core/Runtime.cpp new file mode 100644 index 00000000..06e38087 --- /dev/null +++ b/src/core/Runtime.cpp @@ -0,0 +1,14 @@ + +#include "Runtime.h" +#include + +namespace pscm::core { +PSCM_INLINE_LOG_DECLARE("pscm.core.Runtime"); + +int64_t car_array(Array *input) { + PSCM_ASSERT(input); + PSCM_ASSERT(input->size > 0); + return input->data[0]; +} + +} // namespace pscm::core diff --git a/src/core/Runtime.h b/src/core/Runtime.h new file mode 100644 index 00000000..44b9c998 --- /dev/null +++ b/src/core/Runtime.h @@ -0,0 +1,9 @@ +#pragma once +#include "Value.h" + +namespace pscm::core { +int64_t car_array(Array *input); + +class Runtime {}; + +} // namespace pscm::core diff --git a/src/core/Scheme.cpp b/src/core/Scheme.cpp new file mode 100644 index 00000000..d9fbb779 --- /dev/null +++ b/src/core/Scheme.cpp @@ -0,0 +1,138 @@ +#include "Scheme.h" +#include "Evaluator.h" +#include "JIT.h" +#include "Parser.h" +#include "Procedure.h" +#include "Value.h" +#include "pscm/logger/Appender.h" +#include + +namespace pscm::core { +static llvm::ExitOnError exit_on_err; +PSCM_INLINE_LOG_DECLARE("pscm.core.Scheme"); + +class SchemeImpl { +public: + SchemeImpl() { + init(); + } + + void init() { + if (has_init_) { + return; + } + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeAllAsmParsers(); + pscm::logger::Logger::root_logger()->add_appender(new pscm::logger::ConsoleAppender()); + has_init_ = true; + } + + Value *eval(const std::string& code) { + auto llvm_ctx = std::make_unique(); + auto llvm_module = std::make_unique("pscm jit", *llvm_ctx); + auto builder = std::make_unique>(*llvm_ctx); + + Parser parser(code); + Evaluator 
evaluator; + CodegenContext ctx{ + .llvm_ctx = *llvm_ctx, .llvm_module = *llvm_module, .builder = *builder, .evaluator = evaluator + }; + llvm_module->getOrInsertFunction("car_array[integer]", + llvm::FunctionType::get(llvm::Type::getInt64Ty(*llvm_ctx), + { llvm::PointerType::get(ctx.get_array(), 0) }, false)); + auto value = parser.parse_one(); + AST *ast = nullptr; + while (value) { + if (auto proc = dynamic_cast(value); proc) { + evaluator.add_proc(proc->name(), proc); + ctx.proc_map[proc->name()->to_string()] = proc; + value = parser.parse_one(); + continue; + } + ast = evaluator.eval(value); + if (auto expr = dynamic_cast(ast); expr) { + break; + } + else { + PSCM_ASSERT(ast); + [[maybe_unused]] auto v = ast->codegen(ctx); + value = parser.parse_one(); + } + } + + auto proto = new PrototypeAST("__anon_expr", {}, {}); + auto func = new FunctionAST(proto, ast); + + func->codegen(ctx); + llvm::errs() << ctx.llvm_module; + llvm::errs() << "\n"; + llvm::verifyModule(ctx.llvm_module); + auto jit = exit_on_err(JIT::create()); + auto rt = jit->main_jit_dylib().createResourceTracker(); + auto tsm = llvm::orc::ThreadSafeModule(std::move(llvm_module), std::move(llvm_ctx)); + exit_on_err(jit->add_module(std::move(tsm), rt)); + auto expr_sym = exit_on_err(jit->lookup("_anon_expr")); + if (auto map_expr = dynamic_cast(ast); map_expr) { + // Array + Array *(*fp)() = expr_sym.getAddress().toPtr(); + auto eval_ret = fp(); + std::vector list; + list.reserve(eval_ret->size); + for (int i = 0; i < eval_ret->size; ++i) { + list.push_back(new IntegerValue(eval_ret->data[i])); + } + return new ListValue(list); + // return new IntegerValue(eval_ret->size); + } + else if (auto call_expr = dynamic_cast(ast); call_expr) { + if (call_expr->type()) { + if (auto array_type = dynamic_cast(call_expr->type()); array_type) { + Array *(*fp)() = expr_sym.getAddress().toPtr(); + auto eval_ret = fp(); + std::vector list; + list.reserve(eval_ret->size); + for (int i = 0; i < eval_ret->size; ++i) { 
+ list.push_back(new IntegerValue(eval_ret->data[i])); + } + return new ListValue(list); + } + else if (auto integer_type = dynamic_cast(call_expr->type()); integer_type) { + int (*fp)() = expr_sym.getAddress().toPtr(); + auto eval_ret = fp(); + return new IntegerValue(eval_ret); + } + else { + PSCM_UNIMPLEMENTED(); + } + } + else { + PSCM_UNIMPLEMENTED(); + } + } + else { + int (*fp)() = expr_sym.getAddress().toPtr(); + auto eval_ret = fp(); + return new IntegerValue(eval_ret); + } + } + + bool has_init_ = false; +}; + +Scheme::Scheme() + : impl_(new SchemeImpl()) { +} + +Value *Scheme::eval(const char *code) { + return impl_->eval(code); +} + +Scheme::~Scheme() { + delete impl_; +} + +void Scheme::set_logger_level(int level) { + pscm::logger::Logger::get_logger("pscm")->set_level(static_cast(level)); +} +} // namespace pscm::core diff --git a/src/core/Scheme.h b/src/core/Scheme.h new file mode 100644 index 00000000..02fcbe67 --- /dev/null +++ b/src/core/Scheme.h @@ -0,0 +1,19 @@ +#pragma once + +namespace pscm::core { +class Value; +class SchemeImpl; + +class Scheme { +public: + Scheme(); + ~Scheme(); + + Value *eval(const char *code); + void set_logger_level(int level); + +private: + SchemeImpl *impl_; +}; + +} // namespace pscm::core diff --git a/src/core/SymbolTable.cpp b/src/core/SymbolTable.cpp new file mode 100644 index 00000000..53a4d163 --- /dev/null +++ b/src/core/SymbolTable.cpp @@ -0,0 +1,51 @@ + +#include "SymbolTable.h" + +#include + +namespace pscm::core { +class SymbolTableImpl { +public: + explicit SymbolTableImpl(SymbolTable *parent) + : parent_(parent) { + } + + void put(SymbolValue *sym, ExprAST *value) { + sym_table_.insert_or_assign(sym->to_string(), value); + } + + [[nodiscard]] ExprAST *lookup(SymbolValue *sym) { + auto it = sym_table_.find(sym->to_string()); + if (it != sym_table_.end()) { + return it->second; + } + if (parent_) { + return parent_->lookup(sym); + } + return nullptr; + } + + SymbolTable *parent_; + std::unordered_map 
sym_table_; +}; + +SymbolTable::SymbolTable(SymbolTable *parent) + : impl_(new SymbolTableImpl(parent)) { +} + +SymbolTable::~SymbolTable() { + delete impl_; +} + +ExprAST *SymbolTable::lookup(SymbolValue *sym) const { + return impl_->lookup(sym); +} + +void SymbolTable::put(SymbolValue *sym, ExprAST *value) { + impl_->put(sym, value); +} + +SymbolTable *SymbolTable::parent() const { + return impl_->parent_; +} +} // namespace pscm::core diff --git a/src/core/SymbolTable.h b/src/core/SymbolTable.h new file mode 100644 index 00000000..4c8397d0 --- /dev/null +++ b/src/core/SymbolTable.h @@ -0,0 +1,23 @@ +#pragma once +#include "Value.h" +#include + +namespace pscm::core { +class Value; +class ExprAST; +class SymbolValue; +class SymbolTableImpl; + +class SymbolTable { +public: + explicit SymbolTable(SymbolTable *parent = nullptr); + ~SymbolTable(); + [[nodiscard]] ExprAST *lookup(SymbolValue *sym) const; + void put(SymbolValue *sym, ExprAST *value); + [[nodiscard]] SymbolTable *parent() const; + +private: + SymbolTableImpl *impl_; +}; + +} // namespace pscm::core diff --git a/src/core/Value.cpp b/src/core/Value.cpp new file mode 100644 index 00000000..d2e7089a --- /dev/null +++ b/src/core/Value.cpp @@ -0,0 +1,539 @@ +#include "Value.h" +#include "Evaluator.h" +#include "Mangler.h" +#include "Procedure.h" +#include +#include + +namespace pscm::core { + +llvm::Function *CodegenContext::get_function(const std::string& name) { + if (auto f = llvm_module.getFunction(name); f) { + return f; + } + auto it = func_proto_map.find(name); + if (it != func_proto_map.end()) { + return it->second->codegen(*this); + } + return nullptr; +} + +llvm::FunctionCallee CodegenContext::get_malloc() { + auto malloc_func = + llvm_module.getOrInsertFunction("malloc", llvm::FunctionType::get(llvm::Type::getInt8PtrTy(llvm_ctx), + { llvm::Type::getInt64Ty(llvm_ctx) }, false)); + return malloc_func; +} + +llvm::StructType *CodegenContext::get_array() { + auto array_struct = 
llvm::StructType::getTypeByName(llvm_ctx, "Array"); + if (array_struct) { + return array_struct; + } + array_struct = llvm::StructType::create(llvm_ctx, "Array"); + array_struct->setBody({ llvm::Type::getInt64Ty(llvm_ctx), llvm::Type::getInt64PtrTy(llvm_ctx) }); + return array_struct; +} + +std::string ListValue::to_string() const { + if (value_.empty()) { + return "()"; + } + std::stringstream ss; + ss << "("; + for (int i = 0; i < value_.size() - 1; ++i) { + ss << value_[i]->to_string(); + ss << " "; + } + ss << value_.back()->to_string(); + ss << ")"; + return ss.str(); +} + +std::string DottedListValue::to_string() const { + PSCM_INLINE_LOG_DECLARE("pscm.core.DottedListValue"); + PSCM_ASSERT(!value1_.empty()); + std::stringstream ss; + ss << "("; + for (int i = 0; i < value1_.size() - 1; ++i) { + ss << value1_[i]->to_string(); + ss << " "; + } + ss << value1_.back()->to_string(); + + ss << " "; + ss << "."; + ss << " "; + ss << value2_->to_string(); + ss << ")"; + return ss.str(); +} + +llvm::Value *IntegerValue::codegen(CodegenContext& ctx) { + return llvm::ConstantInt::get(ctx.llvm_ctx, llvm::APInt(64, value_)); +} + +const Type *IntegerValue::type() const { + return Type::get_integer_type(); +} + +llvm::Value *TrueValue::codegen(CodegenContext& ctx) { + return llvm::ConstantInt::getTrue(ctx.llvm_ctx); +} + +const Type *TrueValue::type() const { + return Type::get_boolean_type(); +} + +llvm::Value *FalseValue::codegen(CodegenContext& ctx) { + return llvm::ConstantInt::getFalse(ctx.llvm_ctx); +} + +const Type *FalseValue::type() const { + return Type::get_boolean_type(); +} + +llvm::Value *BinaryExprAST::codegen(CodegenContext& ctx) { + PSCM_INLINE_LOG_DECLARE("pscm.core.BinaryExprAST"); + PSCM_ASSERT(lhs_); + PSCM_ASSERT(rhs_); + auto lhs = lhs_->codegen(ctx); + auto rhs = rhs_->codegen(ctx); + if (!lhs || !rhs) { + return nullptr; + } + if (op_->to_string() == "+") { + return ctx.builder.CreateAdd(lhs, rhs, "add_tmp"); + } + else if (op_->to_string() == "-") 
{ + return ctx.builder.CreateSub(lhs, rhs, "sub_tmp"); + } + else if (op_->to_string() == "<") { + auto cmptmp = ctx.builder.CreateICmpSLT(lhs, rhs, "lt_cmp_tmp"); + return cmptmp; + } + else if (op_->to_string() == ">") { + auto cmptmp = ctx.builder.CreateICmpSGT(lhs, rhs, "gt_cmp_tmp"); + return cmptmp; + } + else if (op_->to_string() == "=") { + return ctx.builder.CreateICmpEQ(lhs, rhs, "eq_cmp_tmp"); + } + else { + PSCM_UNIMPLEMENTED(); + } +} + +const Type *BinaryExprAST::type() const { + PSCM_INLINE_LOG_DECLARE("pscm.core.BinaryExprAST"); + auto op = op_->to_string(); + if (op == "<" || op == ">" || op == "=") { + return Type::get_boolean_type(); + } + if (op == "+" || op == "-") { + PSCM_ASSERT(lhs_->type() == rhs_->type()); + return lhs_->type(); + } + PSCM_UNIMPLEMENTED(); +} + +llvm::Type *convert_pscm_type_to_llvm_type(CodegenContext& ctx, const Type *type) { + PSCM_INLINE_LOG_DECLARE("pscm.core.convert_pscm_type_to_llvm_type"); + if (auto array_type = dynamic_cast(type); array_type) { + if (auto element_type = dynamic_cast(array_type->element_type()); element_type) { + return llvm::PointerType::get(ctx.get_array(), 0); + } + else { + PSCM_UNIMPLEMENTED(); + } + } + else if (auto integer_type = dynamic_cast(type); integer_type) { + return llvm::Type::getInt64Ty(ctx.llvm_ctx); + } + else { + PSCM_UNIMPLEMENTED(); + } +} + +llvm::Function *PrototypeAST::codegen(CodegenContext& ctx) { + std::vector func_args; + func_args.reserve(args_.size()); + for (auto arg_type : arg_type_list_) { + func_args.push_back(convert_pscm_type_to_llvm_type(ctx, arg_type)); + } + llvm::Type *func_return_type; + if (return_type_) { + func_return_type = convert_pscm_type_to_llvm_type(ctx, return_type_); + } + else if (!arg_type_list_.empty()) { + func_return_type = convert_pscm_type_to_llvm_type(ctx, arg_type_list_.front()); + } + else { + func_return_type = llvm::Type::getInt64Ty(ctx.llvm_ctx); + } + + llvm::FunctionType *func_type = llvm::FunctionType::get(func_return_type, 
func_args, false); + llvm::Function *func = llvm::Function::Create(func_type, llvm::Function::ExternalLinkage, name_, ctx.llvm_module); + unsigned idx = 0; + for (auto& arg : func->args()) { + arg.setName(args_[idx]); + idx++; + } + return func; +} + +llvm::Function *FunctionAST::codegen(CodegenContext& ctx) { + ctx.func_proto_map[proto_->name()] = proto_; + auto func = ctx.get_function(proto_->name()); + if (!func) { + return nullptr; + } + llvm::BasicBlock *bb = llvm::BasicBlock::Create(ctx.llvm_ctx, "entry", func); + ctx.builder.SetInsertPoint(bb); + ctx.named_values_map.clear(); + for (auto& arg : func->args()) { + ctx.named_values_map[std::string(arg.getName())] = &arg; + } + if (auto ret = body_->codegen(ctx)) { + ctx.builder.CreateRet(ret); + llvm::verifyFunction(*func); + return func; + } + func->eraseFromParent(); + return nullptr; +} + +llvm::Value *VariableExprAST::codegen(CodegenContext& ctx) { + PSCM_INLINE_LOG_DECLARE("pscm.core.VariableExprAST"); + auto value = ctx.named_values_map[this->value_->to_string()]; + if (value) { + return value; + } + PSCM_THROW_EXCEPTION("Unknown symbol: " + this->value_->to_string()); +} + +const Type *VariableExprAST::type() const { + return type_; +} + +llvm::Value *CallExprAST::codegen(CodegenContext& ctx) { + PSCM_INLINE_LOG_DECLARE("pscm.core.CallExprAST"); + std::vector arg_type_list; + for (auto arg : args_) { + arg_type_list.push_back(arg->type()); + } + auto mangled_name = Mangler().mangle(callee_, arg_type_list); + llvm::Function *callee = ctx.llvm_module.getFunction(mangled_name); + if (!callee) { + std::tie(callee, return_type_) = this->instance_function(ctx, callee_, arg_type_list, nullptr); + PSCM_ASSERT(callee); + } + if (!return_type_) { + if (callee->getFunction().getReturnType() == llvm::Type::getInt64Ty(ctx.llvm_ctx)) { + return_type_ = Type::get_integer_type(); + } + else { + PSCM_UNIMPLEMENTED(); + } + } + if (callee->arg_size() != args_.size()) { + PSCM_THROW_EXCEPTION("Incorrect # arguments passed: 
" + callee_); + } + std::vector args; + args.reserve(args_.size()); + for (auto& arg : args_) { + args.push_back(arg->codegen(ctx)); + } + auto calltmp = ctx.builder.CreateCall(callee, args, "calltmp"); + return calltmp; +} + +const Type *CallExprAST::type() const { + PSCM_INLINE_LOG_DECLARE("pscm.core.CallExprAST"); + PSCM_ASSERT(return_type_); + return return_type_; +} + +llvm::Value *IfExprAST::codegen(CodegenContext& ctx) { + auto cond = cond_->codegen(ctx); + if (!cond) { + return nullptr; + } + // llvm::errs() << *cond->getType() << "\n"; + // cond = ctx.builder.CreateICmpEQ(cond, llvm::ConstantInt::get(ctx.llvm_ctx, llvm::APInt(1, 1)), "if_cond"); + auto func = ctx.builder.GetInsertBlock()->getParent(); + auto then_bb = llvm::BasicBlock::Create(ctx.llvm_ctx, "then", func); + auto else_bb = llvm::BasicBlock::Create(ctx.llvm_ctx, "else"); + auto merge_bb = llvm::BasicBlock::Create(ctx.llvm_ctx, "if_cont"); + + std::vector else_if_cond_bb_list; + std::vector else_if_then_bb_list; + std::vector else_if_then_stmt_list; + for (auto& [else_if_cond, then] : else_if_) { + auto else_if_cond_bb = llvm::BasicBlock::Create(ctx.llvm_ctx, "else_if", func); + auto else_if_then_bb = llvm::BasicBlock::Create(ctx.llvm_ctx, "else_if_then", func); + else_if_cond_bb_list.push_back(else_if_cond_bb); + else_if_then_bb_list.push_back(else_if_then_bb); + } + if (else_if_.empty()) { + ctx.builder.CreateCondBr(cond, then_bb, else_bb); + } + else { + ctx.builder.CreateCondBr(cond, then_bb, else_if_cond_bb_list[0]); + } + for (int i = 0; i < else_if_.size(); ++i) { + auto else_if_cond = else_if_[i].first; + auto else_if_then = else_if_[i].second; + auto else_if_cond_bb = else_if_cond_bb_list[i]; + auto else_if_then_bb = else_if_then_bb_list[i]; + ctx.builder.SetInsertPoint(else_if_cond_bb); + auto else_if_cond_stmt = else_if_cond->codegen(ctx); + if (!else_if_cond_stmt) { + return nullptr; + } + ctx.builder.SetInsertPoint(else_if_then_bb); + auto else_if_then_stmt = 
else_if_then->codegen(ctx); + if (!else_if_then_stmt) { + return nullptr; + } + else_if_then_stmt_list.push_back(else_if_then_stmt); + ctx.builder.CreateBr(merge_bb); + ctx.builder.SetInsertPoint(else_if_cond_bb); + if (i == else_if_.size() - 1) { + ctx.builder.CreateCondBr(else_if_cond_stmt, else_if_then_bb, else_bb); + } + else { + ctx.builder.CreateCondBr(else_if_cond_stmt, else_if_then_bb, else_if_cond_bb_list[i + 1]); + } + } + + ctx.builder.SetInsertPoint(then_bb); + auto then_stmt = then_stmt_->codegen(ctx); + if (!then_stmt) { + return nullptr; + } + ctx.builder.CreateBr(merge_bb); + then_bb = ctx.builder.GetInsertBlock(); + + func->insert(func->end(), else_bb); + ctx.builder.SetInsertPoint(else_bb); + + auto else_stmt = else_stmt_ ? else_stmt_->codegen(ctx) : llvm::ConstantInt::get(ctx.llvm_ctx, llvm::APInt(64, 0)); + if (!else_stmt) { + return nullptr; + } + ctx.builder.CreateBr(merge_bb); + + func->insert(func->end(), merge_bb); + ctx.builder.SetInsertPoint(merge_bb); + llvm::PHINode *phi = ctx.builder.CreatePHI(llvm::Type::getInt64Ty(ctx.llvm_ctx), 2 + else_if_.size(), "if_tmp"); + phi->addIncoming(then_stmt, then_bb); + for (int i = 0; i < else_if_.size(); ++i) { + phi->addIncoming(else_if_then_stmt_list[i], else_if_then_bb_list[i]); + } + phi->addIncoming(else_stmt, else_bb); + llvm::verifyFunction(*func); + return phi; +} + +const Type *IfExprAST::type() const { + return then_stmt_->type(); +} + +llvm::Value *ArrayExprAST::codegen(CodegenContext& ctx) { + auto malloc_func = ctx.get_malloc(); + auto array_struct = ctx.get_array(); + auto array_ptr = ctx.builder.CreateCall(malloc_func, { ctx.builder.getInt64(sizeof(Array)) }, "array_ptr"); + auto array_size_ptr = ctx.builder.CreateStructGEP(array_struct, array_ptr, 0, "array_size"); + auto array_data_placeholder = ctx.builder.CreateStructGEP(array_struct, array_ptr, 1, "array_data_placeholder"); + ctx.builder.CreateAlignedStore(llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx.llvm_ctx), 
value_.size()), + array_size_ptr, llvm::MaybeAlign(8)); + + auto array_data_ptr = + ctx.builder.CreateCall(malloc_func, { ctx.builder.getInt64(value_.size() * 8) }, "array_data_ptr"); + for (size_t i = 0; i < value_.size(); ++i) { + llvm::Value *ptr = ctx.builder.CreateGEP(llvm::Type::getInt64Ty(ctx.llvm_ctx), array_data_ptr, + ctx.builder.getInt64(i), "array_data_" + std::to_string(i)); + ctx.builder.CreateAlignedStore(value_[i]->codegen(ctx), ptr, llvm::MaybeAlign(8)); + } + ctx.builder.CreateAlignedStore(array_data_ptr, array_data_placeholder, llvm::MaybeAlign(8)); + return array_ptr; +} + +std::string ArrayExprAST::to_string() const { + return "array ast"; +} + +const Type *ArrayExprAST::type() const { + return Type::get_integer_array_type(); +} + +llvm::Value *MapExprAST::codegen(CodegenContext& ctx) { + PSCM_INLINE_LOG_DECLARE("pscm.core.MapExprAST"); + if (auto array = dynamic_cast(args_); array) { + std::vector value_list; + for (int i = 0; i < array->size(); ++i) { + auto call = new CallExprAST(callee_, { array->value()[i] }, { Type::get_integer_type() }); + value_list.push_back(call); + } + return ArrayExprAST(value_list).codegen(ctx); + } + else if (auto sym = dynamic_cast(args_); sym) { + auto array_ptr = ctx.named_values_map[sym->name()]; + auto array_struct = ctx.get_array(); + + auto malloc_func = ctx.get_malloc(); + auto ret_array_ptr = ctx.builder.CreateCall(malloc_func, { ctx.builder.getInt64(sizeof(Array)) }, "ret_array_ptr"); + auto ret_array_size_ptr = ctx.builder.CreateStructGEP(array_struct, ret_array_ptr, 0, "ret_array_size"); + auto ret_array_data_placeholder = + ctx.builder.CreateStructGEP(array_struct, ret_array_ptr, 1, "ret_array_data_placeholder"); + + auto array_size_ptr = ctx.builder.CreateStructGEP(array_struct, array_ptr, 0, "array_size_ptr"); + auto array_data_placeholder = ctx.builder.CreateStructGEP(array_struct, array_ptr, 1, "array_data_placeholder"); + auto array_data_ptr = 
ctx.builder.CreateAlignedLoad(llvm::Type::getInt64PtrTy(ctx.llvm_ctx), array_data_placeholder, + llvm::MaybeAlign(8), "array_data_ptr"); + auto array_size = ctx.builder.CreateAlignedLoad(llvm::Type::getInt64Ty(ctx.llvm_ctx), array_size_ptr, + llvm::MaybeAlign(8), "array_size"); + ctx.builder.CreateAlignedStore(array_size, ret_array_size_ptr, llvm::MaybeAlign(8)); + auto ret_array_data_memory_size = ctx.builder.CreateMul(array_size, ctx.builder.getInt64(8), "array_mem_size"); + auto ret_array_data_ptr = ctx.builder.CreateCall(malloc_func, { ret_array_data_memory_size }, "ret_array_data_ptr"); + ctx.builder.CreateAlignedStore(ret_array_data_ptr, ret_array_data_placeholder, llvm::MaybeAlign(8)); + + auto func = ctx.builder.GetInsertBlock()->getParent(); + auto loop_cond_bb = llvm::BasicBlock::Create(ctx.llvm_ctx, "loop.cond", func); + auto loop_body_bb = llvm::BasicBlock::Create(ctx.llvm_ctx, "loop.body", func); + auto loop_inc_bb = llvm::BasicBlock::Create(ctx.llvm_ctx, "loop.inc", func); + auto loop_end_bb = llvm::BasicBlock::Create(ctx.llvm_ctx, "loop.end", func); + + auto idx_ptr = ctx.builder.CreateAlloca(llvm::Type::getInt64Ty(ctx.llvm_ctx), nullptr, "idx_ptr"); + ctx.builder.CreateAlignedStore(llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx.llvm_ctx), 0), idx_ptr, + llvm::MaybeAlign(8)); + ctx.builder.CreateBr(loop_cond_bb); + + ctx.builder.SetInsertPoint(loop_cond_bb); + auto idx = ctx.builder.CreateAlignedLoad(llvm::Type::getInt64Ty(ctx.llvm_ctx), idx_ptr, llvm::MaybeAlign(8), "idx"); + + auto cmptmp = ctx.builder.CreateICmpULT(idx, array_size, "cmptmp"); + ctx.builder.CreateCondBr(cmptmp, loop_body_bb, loop_end_bb); + + ctx.builder.SetInsertPoint(loop_body_bb); + auto input_item_ptr = + ctx.builder.CreateGEP(llvm::Type::getInt64PtrTy(ctx.llvm_ctx), array_data_ptr, idx, "input_item_ptr"); + auto input_item = ctx.builder.CreateAlignedLoad(llvm::Type::getInt64Ty(ctx.llvm_ctx), input_item_ptr, + llvm::MaybeAlign(8), "input_item"); + auto ret_output_item_ptr = 
+ ctx.builder.CreateGEP(llvm::Type::getInt64PtrTy(ctx.llvm_ctx), ret_array_data_ptr, idx, "output_item_ptr"); + auto callee = ctx.get_function(callee_); + const Type *return_type = nullptr; + if (!callee) { + auto map_func_input_type = sym->type(); + if (auto array_type = dynamic_cast(map_func_input_type); array_type) { + std::tie(callee, return_type) = this->instance_function(ctx, callee_, { array_type->element_type() }); + } + PSCM_ASSERT(callee); + } + std::vector func_args; + func_args.push_back(input_item); + auto calltmp = ctx.builder.CreateCall(callee, func_args, "output_item"); + ctx.builder.CreateAlignedStore(calltmp, ret_output_item_ptr, llvm::MaybeAlign(8)); + + ctx.builder.CreateBr(loop_inc_bb); + + ctx.builder.SetInsertPoint(loop_inc_bb); + auto new_idx = + ctx.builder.CreateAdd(idx, llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx.llvm_ctx), 1), "new_idx"); + ctx.builder.CreateAlignedStore(new_idx, idx_ptr, llvm::MaybeAlign(8)); + + ctx.builder.CreateBr(loop_cond_bb); + ctx.builder.SetInsertPoint(loop_end_bb); + + return ret_array_ptr; + } + else { + PSCM_UNIMPLEMENTED(); + } +} + +const Type *MapExprAST::type() const { + return Type::get_integer_array_type(); +} + +std::pair AST::instance_function(CodegenContext& ctx, const std::string& callee_, + const std::vector& args_, + const Type *return_type) { + PSCM_INLINE_LOG_DECLARE("pscm.core.instance_function"); + auto mangled_name = Mangler().mangle(callee_, args_); + auto t = ctx.builder.GetInsertBlock(); + + // instance func callee + auto it = ctx.proc_map.find(callee_); + if (it == ctx.proc_map.end()) { + PSCM_THROW_EXCEPTION("Unknown symbol: " + callee_); + } + auto proc = it->second; + + std::vector args; + args.reserve(proc->args().size()); + for (auto& arg : proc->args()) { + args.push_back(arg->to_string()); + } + std::vector value_to_codegen; + value_to_codegen.reserve(proc->body().size()); + ctx.evaluator.push_symbol_table(); + for (int i = 0; i < args_.size(); ++i) { + auto sym = 
proc->args()[i]; + auto arg = args_[i]; + ctx.evaluator.add_sym(sym, new VariableExprAST(sym, arg)); + } + for (auto stmt : proc->body()) { + auto value = ctx.evaluator.eval(stmt); + value_to_codegen.push_back(value); + if (auto expr_ast = dynamic_cast(value); expr_ast) { + return_type = expr_ast->type(); + } + else { + PSCM_UNIMPLEMENTED(); + } + } + ctx.evaluator.pop_symbol_table(); + auto proto = new PrototypeAST(mangled_name, args, args_, return_type); + auto func = new FunctionAST(proto, value_to_codegen[0]); + func->codegen(ctx); + auto callee = ctx.llvm_module.getFunction(mangled_name); + PSCM_ASSERT(callee); + ctx.builder.SetInsertPoint(t); + return { callee, return_type }; +} + +BooleanType *Type::get_boolean_type() { + static BooleanType type; + return &type; +} + +IntegerType *Type::get_integer_type() { + static IntegerType type; + return &type; +} + +ArrayType *Type::get_integer_array_type() { + static ArrayType type(get_integer_type()); + return &type; +} + +std::string BooleanType::to_string() const { + return "boolean"; +} + +std::string IntegerType::to_string() const { + return "integer"; +} + +std::string ArrayType::to_string() const { + std::stringstream ss; + ss << "array"; + ss << "["; + ss << element_type_->to_string(); + ss << "]"; + return ss.str(); +} +} // namespace pscm::core \ No newline at end of file diff --git a/src/core/Value.h b/src/core/Value.h new file mode 100644 index 00000000..c7868ce2 --- /dev/null +++ b/src/core/Value.h @@ -0,0 +1,412 @@ +#pragma once +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/TargetSelect.h" +#include 
"llvm/Target/TargetMachine.h" + +namespace pscm::core { + +class Array { +public: + int64_t size; + int64_t *data; +}; +class BooleanType; +class IntegerType; +class ArrayType; + +class Type { +public: + [[nodiscard]] static BooleanType *get_boolean_type(); + [[nodiscard]] static IntegerType *get_integer_type(); + [[nodiscard]] static ArrayType *get_integer_array_type(); + [[nodiscard]] virtual std::string to_string() const = 0; +}; + +class BooleanType : public Type { +public: + [[nodiscard]] std::string to_string() const override; +}; + +class IntegerType : public Type { + [[nodiscard]] std::string to_string() const override; +}; + +class ArrayType : public Type { +public: + explicit ArrayType(Type *element_type) + : element_type_(element_type) { + } + + [[nodiscard]] const Type *element_type() const { + return element_type_; + } + + [[nodiscard]] std::string to_string() const override; + +private: + Type *element_type_; +}; + +class PrototypeAST; +class FunctionAST; +class Procedure; +class Evaluator; + +struct CodegenContext { + llvm::LLVMContext& llvm_ctx; + llvm::Module& llvm_module; + llvm::IRBuilder<>& builder; + Evaluator& evaluator; + std::unordered_map func_proto_map; + std::unordered_map named_values_map; + std::unordered_map proc_map; + + [[nodiscard]] llvm::Function *get_function(const std::string& name); + [[nodiscard]] llvm::FunctionCallee get_malloc(); + [[nodiscard]] llvm::StructType *get_array(); +}; +class ExprAST; + +class AST { +public: + [[nodiscard]] virtual llvm::Value *codegen(CodegenContext& ctx) = 0; + [[nodiscard]] std::pair + instance_function(CodegenContext& ctx, const std::string& callee, const std::vector& arg_type_list, + const Type *return_type = nullptr); +}; + +class ExprAST : public AST { +public: + [[nodiscard]] virtual const Type *type() const = 0; +}; + +class Value { +public: + [[nodiscard]] virtual std::string to_string() const = 0; + + virtual ~Value() = default; +}; + +class BooleanValue : public Value { +public: +}; + 
+// The boolean constant #t: both a runtime value and an expression node.
+// Only reachable through instance(), which returns one shared object.
+class TrueValue final
+    : public BooleanValue
+    , public ExprAST {
+private:
+  TrueValue() = default;
+
+public:
+  static const TrueValue *instance() {
+    static TrueValue value;
+    return &value;
+  }
+
+  [[nodiscard]] std::string to_string() const override {
+    return "#t";
+  }
+
+  // Emits the LLVM `true` constant.
+  [[nodiscard]] llvm::Value *codegen(CodegenContext& ctx) override;
+
+  // Always the boolean type.
+  const Type *type() const override;
+};
+
+// The boolean constant #f; mirrors TrueValue.
+class FalseValue final
+    : public BooleanValue
+    , public ExprAST {
+private:
+  FalseValue() = default;
+
+public:
+  static const FalseValue *instance() {
+    static FalseValue value;
+    return &value;
+  }
+
+  [[nodiscard]] std::string to_string() const override {
+    return "#f";
+  }
+
+  // Emits the LLVM `false` constant.
+  [[nodiscard]] llvm::Value *codegen(CodegenContext& ctx) override;
+
+  // Always the boolean type.
+  const Type *type() const override;
+};
+
+// A symbol; prints as its bare name.
+class SymbolValue final : public Value {
+public:
+  explicit SymbolValue(std::string value)
+      : value_(std::move(value)) {
+  }
+
+  [[nodiscard]] std::string to_string() const override {
+    return value_;
+  }
+
+private:
+  std::string value_;
+};
+
+// Reference to a named variable of a known type. codegen() looks the name up
+// in ctx.named_values_map and raises when it is not bound.
+class VariableExprAST : public ExprAST {
+public:
+  VariableExprAST(SymbolValue *sym, const Type *type)
+      : value_(sym)
+      , type_(type) {
+  }
+
+  [[nodiscard]] llvm::Value *codegen(CodegenContext& ctx) override;
+
+  std::string name() const {
+    return value_->to_string();
+  }
+
+  const Type *type() const override;
+
+private:
+  SymbolValue *value_;
+  const Type *type_;
+};
+
+// A string; prints wrapped in double quotes with special characters escaped.
+class StringValue final : public Value {
+public:
+  explicit StringValue(std::string value)
+      : value_(std::move(value)) {
+  }
+
+  [[nodiscard]] std::string to_string() const override {
+    std::string quoted;
+    quoted.reserve(value_.size() + 2);
+    quoted.push_back('"');
+    for (const char ch : value_) {
+      // Escape the escape character as well as the delimiter; otherwise a
+      // string containing a backslash prints ambiguously.
+      if (ch == '"' || ch == '\\') {
+        quoted.push_back('\\');
+      }
+      quoted.push_back(ch);
+    }
+    quoted.push_back('"');
+    return quoted;
+  }
+
+private:
+  std::string value_;
+};
+
+// Common base of the numeric values.
+class NumberValue : public Value {};
+
+// A 64-bit signed integer: both a runtime value and an expression node.
+class IntegerValue final
+    : public NumberValue
+    , public ExprAST {
+public:
+  explicit IntegerValue(int64_t value)
+      : value_(value) {
+  }
+
+  // Shared instance holding 0.
+  static IntegerValue *zero() {
+    static IntegerValue zero(0);
+    return &zero;
+  }
+
+  [[nodiscard]] std::string to_string() const override {
+    return std::to_string(value_);
+  }
+
+  [[nodiscard]] int64_t value() const {
+    return value_;
+  }
+
+  // Emits a 64-bit integer constant.
+  [[nodiscard]] llvm::Value *codegen(CodegenContext& ctx) override;
+
+  // Always the integer type.
+  const Type *type() const override;
+
+private:
+  int64_t value_;
+};
+
+// A proper list; prints as "(a b c)" and as "()" when empty.
+// NOTE(review): the element types of the std::vector members and parameters
+// in this hunk are missing in this copy of the patch - restore them from
+// upstream.
+class ListValue : public Value {
+public:
+  explicit ListValue(std::vector value_list)
+      : value_(std::move(value_list)) {
+  }
+
+  [[nodiscard]] std::string to_string() const override;
+
+  // Returned by reference, like DottedListValue::value1(), so reading the
+  // elements does not copy the vector.
+  [[nodiscard]] const std::vector& value() const {
+    return value_;
+  }
+
+private:
+  std::vector value_;
+};
+
+// An improper list; prints as "(a b . c)". to_string() asserts that the
+// leading part is not empty.
+class DottedListValue : public Value {
+public:
+  DottedListValue(std::vector value_list, Value *value)
+      : value1_(std::move(value_list))
+      , value2_(value) {
+  }
+
+  [[nodiscard]] std::string to_string() const override;
+
+  // Elements before the dot.
+  [[nodiscard]] const std::vector& value1() const {
+    return value1_;
+  }
+
+  // Element after the dot.
+  [[nodiscard]] const Value *value2() const {
+    return value2_;
+  }
+
+private:
+  std::vector value1_;
+  Value *value2_;
+};
+
+// Array literal. codegen() mallocs an Array record plus its data buffer and
+// stores each element's generated value; type() is the integer array type.
+class ArrayExprAST
+    : public Value
+    , public ExprAST {
+public:
+  explicit ArrayExprAST(std::vector value)
+      : value_(std::move(value)) {
+  }
+
+  [[nodiscard]] llvm::Value *codegen(CodegenContext& ctx) override;
+
+  [[nodiscard]] const std::vector& value() const {
+    return value_;
+  }
+
+  [[nodiscard]] std::size_t size() const {
+    return value_.size();
+  }
+
+  [[nodiscard]] std::string to_string() const override;
+  const Type *type() const override;
+
+private:
+  std::vector value_;
+};
+
+// Binary operation on two sub-expressions. Value.cpp handles the operators
+// "+", "-", "<", ">" and "="; any other operator raises "Unimplemented".
+class BinaryExprAST : public ExprAST {
+public:
+  BinaryExprAST(SymbolValue *op, ExprAST *lhs, ExprAST *rhs)
+      : op_(op)
+      , lhs_(lhs)
+      , rhs_(rhs) {
+  }
+
+  llvm::Value *codegen(CodegenContext& ctx) override;
+  // Boolean for the comparisons; the operand type for "+" and "-".
+  [[nodiscard]] const Type *type() const override;
+
+private:
+  SymbolValue *op_;
+  ExprAST *lhs_;
+  ExprAST *rhs_;
+};
+
+// Call of the procedure named `callee`. The return type starts out unknown
+// and is resolved inside codegen(); type() asserts that it has been set.
+// NOTE(review): the element types of the std::vector members and parameters
+// in this hunk are missing in this copy of the patch - restore them from
+// upstream.
+class CallExprAST : public ExprAST {
+public:
+  CallExprAST(std::string callee, std::vector args, std::vector types)
+      : callee_(std::move(callee))
+      , args_(std::move(args))
+      , types_(std::move(types))
+      , return_type_(nullptr) {
+  }
+
+  [[nodiscard]] llvm::Value *codegen(CodegenContext& ctx) override;
+  [[nodiscard]] const Type *type() const override;
+
+private:
+  std::string callee_;
+  std::vector args_;
+  std::vector types_;
+  const Type *return_type_;
+};
+
+// Conditional with optional else-if arms. `else_stmt` may be null, in which
+// case codegen() uses the 64-bit constant 0 for that branch. type() reports
+// the type of the `then` branch.
+class IfExprAST : public ExprAST {
+public:
+  IfExprAST(ExprAST *cond, ExprAST *then_stmt, ExprAST *else_stmt)
+      : cond_(cond)
+      , then_stmt_(then_stmt)
+      , else_stmt_(else_stmt) {
+  }
+
+  // Appends one (condition, body) arm, tested after the arms added so far.
+  void add_else_if(ExprAST *cond, ExprAST *then_stmt) {
+    else_if_.emplace_back(cond, then_stmt);
+  }
+
+  llvm::Value *codegen(CodegenContext& ctx) override;
+  const Type *type() const override;
+
+private:
+  ExprAST *cond_;
+  ExprAST *then_stmt_;
+  std::vector> else_if_;
+  ExprAST *else_stmt_;
+};
+
+// Applies the procedure named `callee` to every element of `args`, which
+// must be an array literal or a variable; type() is the integer array type.
+class MapExprAST : public ExprAST {
+public:
+  MapExprAST(std::string callee, ExprAST *args)
+      : callee_(std::move(callee))
+      , args_(args) {
+  }
+
+  [[nodiscard]] llvm::Value *codegen(CodegenContext& ctx) override;
+  [[nodiscard]] const Type *type() const override;
+
+private:
+  std::string callee_;
+  ExprAST *args_;
+};
+
+// Function signature: name, argument names, argument types and an optional
+// return type. When no return type is given, codegen() falls back to the
+// first argument's type, or to a 64-bit integer if there are no arguments.
+class PrototypeAST {
+public:
+  PrototypeAST(std::string name, std::vector args, std::vector arg_type_list,
+               const Type *return_type = nullptr)
+      : name_(std::move(name))
+      , args_(std::move(args))
+      , arg_type_list_(std::move(arg_type_list))
+      , return_type_(return_type) {
+  }
+
+  // Creates the externally linked LLVM function and names its arguments.
+  [[nodiscard]] llvm::Function *codegen(CodegenContext& ctx);
+
+  [[nodiscard]] const std::string& name() const {
+    return name_;
+  }
+
+private:
+  std::string name_;
+  std::vector args_;
+  std::vector arg_type_list_;
+  const Type *return_type_;
+};
+
+// A prototype plus the body that computes the function's result.
+class FunctionAST : public AST {
+public:
+  FunctionAST(PrototypeAST *proto, AST *body)
+      : proto_(proto)
+      , body_(body) {
+  }
+
+  // Overrides AST::codegen with a covariant return type. Returns nullptr when
+  // the function cannot be obtained or the body produces no value; in the
+  // latter case the partially built function is erased.
+  [[nodiscard]] llvm::Function *codegen(CodegenContext& ctx) override;
+
+private:
+  PrototypeAST *proto_;
+  AST *body_;
+};
+} // namespace pscm::core
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index fb183cec..8bb7afcf 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -19,6 +19,10 @@ endforeach ()
 add_pscm_test(sicp_ch1_tests sicp/ch1_tests.cpp)
 add_pscm_test(r4rs_test r4rs/r4rs_tests.cpp)
 add_pscm_test(module_test module/load_path_tests.cpp)
+if (PSCM_ENABLE_MLIR_CODEGEN)
+    add_pscm_test(core_test core/simple_core_tests.cpp)
+    target_include_directories(core_test PRIVATE ${CMAKE_SOURCE_DIR}/src)
+endif ()
 if (EMSCRIPTEN)
 elseif (WIN32)
     # TODO: handle EOF
diff --git a/test/core/simple_core_tests.cpp b/test/core/simple_core_tests.cpp
new file mode 100644
index 00000000..110b89d5
--- /dev/null
+++ b/test/core/simple_core_tests.cpp
@@ -0,0 +1,139 @@
+#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
+#include "core/Scheme.h"
+#include "core/Value.h"
+#include "doctest/doctest.h"
+using namespace doctest;
+using namespace pscm::core;
+using namespace std::string_literals;
+
+// Each case evaluates a Scheme snippet and checks the printed result.
+TEST_CASE("testing add") {
+  std::string code = R"(
+(+ 2 3)
+)";
+  Scheme scm;
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "5"s);
+}
+
+TEST_CASE("testing add 3") {
+  std::string code = R"(
+(+ 2 3 4)
+)";
+  Scheme scm;
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "9"s);
+}
+
+TEST_CASE("testing minus") {
+  std::string code = R"(
+(- 2)
+)";
+  Scheme scm;
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "-2"s);
+}
+
+TEST_CASE("testing minus, 2") {
+  std::string code = R"(
+(- 2 3)
+)";
+  Scheme scm;
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "-1"s);
+}
+
+TEST_CASE("testing function") {
+  std::string code = R"(
+(define (sum a b c)
+  (+ a b c))
+(sum 1 2 3)
+)";
+  Scheme scm;
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "6"s);
+}
+
+TEST_CASE("testing cond") {
+  std::string code = R"(
+(cond ((> 3 2) 100)
+      ((< 3 2) 200))
+)";
+  Scheme scm;
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "100"s);
+}
+
+TEST_CASE("testing recursive function") {
+  // (fib 1) .. (fib 10): 0, 1, 1, 2, 3, 5, 8, 13, 21, 34
+  std::string code = R"(
+(define (fib n)
+  (cond ((< n 2) 0)
+        ((< n 4) 1)
+        (else (+ (fib (- n 1)) (fib (- n 2))))))
+(fib 10)
+)";
+  Scheme scm;
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "34"s);
+}
+
+TEST_CASE("testing loop, map, abs function") {
+  // abs on its own; the following cases combine it with map
+  std::string code = R"(
+(define (abs n)
+  (cond ((< n 0) (- n))
+        (else n)))
+(abs -5)
+)";
+  Scheme scm;
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "5"s);
+}
+
+TEST_CASE("testing loop, map") {
+  // (map abs '(4 -5 6))
+  std::string code = R"(
+(define (abs n)
+  (cond ((< n 0) (- n))
+        (else n)))
+(map abs '(4 -5 6))
+)";
+  Scheme scm;
+  scm.set_logger_level(0);
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "(4 5 6)"s);
+}
+
+TEST_CASE("testing loop, map, in func") {
+  std::string code = R"(
+(define (abs n)
+  (cond ((< n 0) (- n))
+        (else n)))
+(define (map-fn list) (map abs list))
+(map-fn '(4 -5 6))
+)";
+  Scheme scm;
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "(4 5 6)"s);
+}
+
+TEST_CASE("testing list, car") {
+  std::string code = R"(
+(car '(4 -5 6))
+)";
+  Scheme scm;
+  auto ret = scm.eval(code.c_str());
+  CHECK(ret);
+  CHECK(ret->to_string() == "4"s);
+}
\ No newline at end of file