Skip to content

Commit

Permalink
major code generation refactoring;
Browse files Browse the repository at this point in the history
  • Loading branch information
NateSeymour committed Nov 29, 2024
1 parent c7087cc commit e258fbd
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 304 deletions.
16 changes: 9 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ add_library(unlogic STATIC
src/util/io.cpp
src/util/io.h
src/Error.h
src/parser/Node.cpp
src/parser/Node.h
src/parser/Parser.cpp
src/parser/Parser.h
src/compiler/transformer/IRGenerator.cpp
Expand Down Expand Up @@ -105,23 +107,23 @@ qt_add_shaders(unlogic-calculator "shaders"
PREFIX "/shaders"
BASE src/calculator/resource/shaders
FILES
src/calculator/resource/shaders/plot.frag
src/calculator/resource/shaders/plot.vert
src/calculator/resource/shaders/grid.frag
src/calculator/resource/shaders/grid.vert
src/calculator/resource/shaders/plot.frag
src/calculator/resource/shaders/plot.vert
src/calculator/resource/shaders/grid.frag
src/calculator/resource/shaders/grid.vert
)
qt_add_resources(unlogic-calculator "fonts"
PREFIX "/fonts"
BASE src/calculator/resource
FILES
src/calculator/resource/Roboto-Medium.ttf
src/calculator/resource/SourceCodePro.ttf
src/calculator/resource/Roboto-Medium.ttf
src/calculator/resource/SourceCodePro.ttf
)
qt_add_resources(unlogic-calculator "styles"
PREFIX "/styles"
BASE src/calculator/resource
FILES
src/calculator/resource/stylesheet.qss
src/calculator/resource/stylesheet.qss
)
target_link_libraries(unlogic-calculator PRIVATE Qt6::Core Qt6::Widgets Qt6::Gui Qt6::ShaderTools Qt6::GuiPrivate unlogic)
set_target_properties(unlogic-calculator PROPERTIES
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ std::expected<Program, CompilationError> Compiler::Compile(std::string_view prog
// Build program
try
{
ast_body->Accept(generator);
std::visit(generator, *ast_body);
}
catch (std::runtime_error &e)
{
Expand Down
176 changes: 75 additions & 101 deletions src/compiler/transformer/IRGenerator.cpp
Original file line number Diff line number Diff line change
@@ -1,134 +1,114 @@
#include <format>
#include "IRGenerator.h"
#include <format>
#include <llvm/IR/Verifier.h>

void unlogic::IRGenerator::Visit(unlogic::NumericLiteralNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::NumericLiteralNode &node)
{
llvm::Value *value = llvm::ConstantFP::get(this->ctx.llvm_ctx, llvm::APFloat(node->value));
this->values.push(value);
return llvm::ConstantFP::get(this->ctx.llvm_ctx, llvm::APFloat(node.value));
}

void unlogic::IRGenerator::Visit(unlogic::StringLiteralNode const *node) {}

void unlogic::IRGenerator::Visit(unlogic::VariableNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::StringLiteralNode &node)
{
llvm::Value *value = *this->ctx.scope.Lookup(node->identifier_);
this->values.push(value);
return this->builder.CreateGlobalStringPtr(node.value);
}

void unlogic::IRGenerator::Visit(unlogic::CallNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::DivisionNode &node)
{
llvm::Function *function = ctx.module->getFunction(node->function_name_);
llvm::Value *lhs = std::visit(*this, *node.lhs);
llvm::Value *rhs = std::visit(*this, *node.rhs);

if (function->arg_size() < node->arguments_.size())
{
throw std::runtime_error("Aaaaaahhhhhhh");
}
return this->builder.CreateFDiv(lhs, rhs, "divtmp");
}

std::vector<llvm::Value *> argument_values;
argument_values.reserve(node->arguments_.size());
for (auto const &argument: node->arguments_)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::ScopedBlockNode &node)
{
for (auto &statement: node.statements)
{
argument->Accept(*this);
argument_values.push_back(this->values.top());
this->values.pop();
std::visit(*this, *statement);
}

llvm::Value *value = this->builder.CreateCall(function, argument_values, "calltmp");
this->values.push(value);
return nullptr;
}

void unlogic::IRGenerator::Visit(unlogic::AdditionNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::VariableNode &node)
{
node->lhs_->Accept(*this);
llvm::Value *lhs = this->values.top();
this->values.pop();

node->rhs_->Accept(*this);
llvm::Value *rhs = this->values.top();
this->values.pop();

llvm::Value *value = this->builder.CreateFAdd(lhs, rhs, "addtmp");
this->values.push(value);
return *this->ctx.scope.Lookup(node.identifier);
}

void unlogic::IRGenerator::Visit(unlogic::SubtractionNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::CallNode &node)
{
node->lhs_->Accept(*this);
llvm::Value *lhs = this->values.top();
this->values.pop();
llvm::Function *function = ctx.module->getFunction(node.function_name);

node->rhs_->Accept(*this);
llvm::Value *rhs = this->values.top();
this->values.pop();
if (function->arg_size() < node.arguments.size())
{
throw std::runtime_error("Aaaaaahhhhhhh");
}

std::vector<llvm::Value *> argument_values;
argument_values.reserve(node.arguments.size());
for (auto &argument: node.arguments)
{
llvm::Value *arg_value = std::visit(*this, *argument);
argument_values.push_back(arg_value);
}

llvm::Value *value = this->builder.CreateFSub(lhs, rhs, "subtmp");
this->values.push(value);
return this->builder.CreateCall(function, argument_values, "calltmp");
}

void unlogic::IRGenerator::Visit(unlogic::MultiplicationNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::AdditionNode &node)
{
node->lhs_->Accept(*this);
llvm::Value *lhs = this->values.top();
this->values.pop();
llvm::Value *lhs = std::visit(*this, *node.lhs);
llvm::Value *rhs = std::visit(*this, *node.rhs);

node->rhs_->Accept(*this);
llvm::Value *rhs = this->values.top();
this->values.pop();

llvm::Value *value = this->builder.CreateFMul(lhs, rhs, "multmp");
this->values.push(value);
return this->builder.CreateFAdd(lhs, rhs, "addtmp");
}

void unlogic::IRGenerator::Visit(unlogic::DivisionNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::SubtractionNode &node)
{
node->lhs_->Accept(*this);
llvm::Value *lhs = this->values.top();
this->values.pop();

node->rhs_->Accept(*this);
llvm::Value *rhs = this->values.top();
this->values.pop();
llvm::Value *lhs = std::visit(*this, *node.lhs);
llvm::Value *rhs = std::visit(*this, *node.rhs);

llvm::Value *value = this->builder.CreateFDiv(lhs, rhs, "divtmp");
this->values.push(value);
return this->builder.CreateFSub(lhs, rhs, "subtmp");
}

void unlogic::IRGenerator::Visit(unlogic::PotentiationNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::MultiplicationNode &node)
{
node->lhs_->Accept(*this);
llvm::Value *lhs = this->values.top();
this->values.pop();
llvm::Value *lhs = std::visit(*this, *node.lhs);
llvm::Value *rhs = std::visit(*this, *node.rhs);

node->rhs_->Accept(*this);
llvm::Value *rhs = this->values.top();
this->values.pop();
return this->builder.CreateFMul(lhs, rhs, "multmp");
}

llvm::Value *unlogic::IRGenerator::operator()(unlogic::PotentiationNode &node)
{
llvm::Value *lhs = std::visit(*this, *node.lhs);
llvm::Value *rhs = std::visit(*this, *node.rhs);

llvm::Function *std_pow = this->ctx.module->getFunction("pow");

llvm::Value *value = this->builder.CreateCall(std_pow, {lhs, rhs}, "powtmp");
this->values.push(value);
return this->builder.CreateCall(std_pow, {lhs, rhs}, "powtmp");
}

void unlogic::IRGenerator::Visit(unlogic::FunctionDefinitionNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::FunctionDefinitionNode &node)
{
// Save entry
llvm::BasicBlock *parent = this->builder.GetInsertBlock();

// Generate function information
std::vector<llvm::Type *> argument_types(node->args_.size(), llvm::Type::getDoubleTy(ctx.llvm_ctx));
std::vector<llvm::Type *> argument_types(node.args.size(), llvm::Type::getDoubleTy(ctx.llvm_ctx));
llvm::FunctionType *function_type = llvm::FunctionType::get(llvm::Type::getDoubleTy(ctx.llvm_ctx), argument_types, false);
llvm::Function *function = llvm::Function::Create(function_type, llvm::Function::ExternalLinkage, node->name_, *ctx.module);
llvm::Function *function = llvm::Function::Create(function_type, llvm::Function::ExternalLinkage, node.name, *ctx.module);

this->ctx.scope.Insert(node->name_, function);
this->ctx.scope.Insert(node.name, function);

unsigned idx = 0;
for (auto &arg: function->args())
{
arg.setName(node->args_[idx++]);
arg.setName(node.args[idx++]);
}

// Generate function body
llvm::BasicBlock *block = llvm::BasicBlock::Create(ctx.llvm_ctx, node->name_, function);
llvm::BasicBlock *block = llvm::BasicBlock::Create(ctx.llvm_ctx, node.name, function);
this->builder.SetInsertPoint(block);

ctx.scope.PushLayer();
Expand All @@ -137,9 +117,7 @@ void unlogic::IRGenerator::Visit(unlogic::FunctionDefinitionNode const *node)
ctx.scope.Insert(std::string(arg.getName()), &arg);
}

node->body_->Accept(*this);
llvm::Value *return_value = this->values.top();
this->values.pop();
llvm::Value *return_value = std::visit(*this, *node.body);

this->builder.CreateRet(return_value);

Expand All @@ -150,42 +128,31 @@ void unlogic::IRGenerator::Visit(unlogic::FunctionDefinitionNode const *node)
throw std::runtime_error("function has errors");
}

this->values.push(function);

// Return to parent block
this->builder.SetInsertPoint(parent);

return function;
}

void unlogic::IRGenerator::Visit(unlogic::PlotCommandNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::PlotCommandNode &node)
{
llvm::Value *scene = *this->ctx.scope.Lookup("__scene");
llvm::Value *name = this->builder.CreateGlobalStringPtr(node->function_name);
llvm::Value *name = this->builder.CreateGlobalStringPtr(node.function_name);

auto function = this->ctx.module->getFunction(node->function_name);
auto function = this->ctx.module->getFunction(node.function_name);
if (!function)
{
throw std::runtime_error(std::format("Function \"{}\" could not be found!", node->function_name));
throw std::runtime_error(std::format("Function \"{}\" could not be found!", node.function_name));
}

auto scene_add_plot = this->ctx.module->getFunction("unlogic_scene_add_plot");

std::array<llvm::Value *, 3> args = {scene, name, function};

llvm::Value *ret = this->builder.CreateCall(scene_add_plot, args);

this->values.push(ret);
return this->builder.CreateCall(scene_add_plot, args);
}

void unlogic::IRGenerator::Visit(unlogic::ScopedBlockNode const *node)
{
for (auto &statement: node->statements_)
{
statement->Accept(*this);
this->values.pop();
}
}

void unlogic::IRGenerator::Visit(unlogic::ProgramEntryNode const *node)
llvm::Value *unlogic::IRGenerator::operator()(unlogic::ProgramEntryNode &node)
{
std::array<llvm::Type *, 1> args = {
llvm::PointerType::getUnqual(this->ctx.llvm_ctx),
Expand All @@ -201,7 +168,7 @@ void unlogic::IRGenerator::Visit(unlogic::ProgramEntryNode const *node)

this->builder.SetInsertPoint(block);

node->body->Accept(*this);
std::visit(*this, *node.body);

this->builder.CreateRetVoid();
this->ctx.scope.PopLayer();
Expand All @@ -210,4 +177,11 @@ void unlogic::IRGenerator::Visit(unlogic::ProgramEntryNode const *node)
{
throw std::runtime_error("function has errors");
}

return nullptr;
}

llvm::Value *unlogic::IRGenerator::operator()(std::monostate &node)
{
throw std::runtime_error("Invalid Node!");
}
36 changes: 19 additions & 17 deletions src/compiler/transformer/IRGenerator.h
Original file line number Diff line number Diff line change
@@ -1,34 +1,36 @@
#ifndef UNLOGIC_IRGENERATOR_H
#define UNLOGIC_IRGENERATOR_H

#include <llvm/IR/IRBuilder.h>
#include <stack>
#include "parser/Node.h"
#include "IRGenerationContext.h"
#include "parser/Node.h"

namespace unlogic
{
struct IRGenerator : public INodeVisitor
struct IRGenerator
{
IRGenerationContext &ctx;
llvm::IRBuilder<> builder;
std::stack<llvm::Value *> values;

void Visit(const NumericLiteralNode *node) override;
void Visit(const StringLiteralNode *node) override;
void Visit(const DivisionNode *node) override;
void Visit(const ScopedBlockNode *node) override;
void Visit(const VariableNode *node) override;
void Visit(const CallNode *node) override;
void Visit(const AdditionNode *node) override;
void Visit(const SubtractionNode *node) override;
void Visit(const MultiplicationNode *node) override;
void Visit(const PotentiationNode *node) override;
void Visit(const FunctionDefinitionNode *node) override;
void Visit(const PlotCommandNode *node) override;
void Visit(const ProgramEntryNode *node) override;
llvm::Value *operator()(std::monostate &node);
llvm::Value *operator()(unlogic::NumericLiteralNode &node);
llvm::Value *operator()(StringLiteralNode &node);
llvm::Value *operator()(DivisionNode &node);
llvm::Value *operator()(ScopedBlockNode &node);
llvm::Value *operator()(VariableNode &node);
llvm::Value *operator()(CallNode &node);
llvm::Value *operator()(AdditionNode &node);
llvm::Value *operator()(SubtractionNode &node);
llvm::Value *operator()(MultiplicationNode &node);
llvm::Value *operator()(PotentiationNode &node);
llvm::Value *operator()(FunctionDefinitionNode &node);
llvm::Value *operator()(PlotCommandNode &node);
llvm::Value *operator()(ProgramEntryNode &node);

IRGenerator(IRGenerationContext &ctx) : ctx(ctx), builder(ctx.llvm_ctx) {}
};
}
} // namespace unlogic

#endif //UNLOGIC_IRGENERATOR_H
#endif // UNLOGIC_IRGENERATOR_H
Loading

0 comments on commit e258fbd

Please sign in to comment.