diff --git a/data_specifications/specification.proto b/data_specifications/specification.proto index 9f7a6d963..34692f52d 100644 --- a/data_specifications/specification.proto +++ b/data_specifications/specification.proto @@ -305,6 +305,11 @@ message StackFrame { int64 parameter_offset = 5; } +message TypeHint { + uint64 target_addr = 1; + Variable target_var = 2; +} + message Function { uint64 entry_address = 1; FunctionLinkage func_linkage = 3; @@ -319,6 +324,11 @@ message Function { StackFrame frame = 9; repeated Parameter in_scope_vars = 10; + + // An instruction can carry a set of type hints, each saying that a location + // is known to have a given type after that instruction. These are translated + // into a low lifting of that location with spec type metadata. + repeated TypeHint type_hints = 11; } message GlobalVariable { diff --git a/include/anvill/Declarations.h b/include/anvill/Declarations.h index a6bdfe713..7977fff7a 100644 --- a/include/anvill/Declarations.h +++ b/include/anvill/Declarations.h @@ -365,6 +365,11 @@ class SpecBlockContext : public BasicBlockContext { virtual const std::vector &LiveParamsAtExit() const override; }; + +struct TypeHint { + uint64_t target_addr; + ValueDecl hint; +}; // A function decl, as represented at a "near ABI" level. To be specific, // not all C, and most C++ decls, as written would be directly translatable // to this. 
This ought nearly represent how LLVM represents a C/C++ function @@ -415,6 +420,9 @@ struct FunctionDecl : public CallableDecl { std::unordered_map> constant_values_at_exit; + // type hints, sorted by ascending target_addr + std::vector type_hints; + std::uint64_t stack_depth; std::uint64_t maximum_depth; diff --git a/lib/Lifters/BasicBlockLifter.cpp b/lib/Lifters/BasicBlockLifter.cpp index d690a175a..e55748679 100644 --- a/lib/Lifters/BasicBlockLifter.cpp +++ b/lib/Lifters/BasicBlockLifter.cpp @@ -298,6 +298,33 @@ bool BasicBlockLifter::DecodeInstructionInto(const uint64_t addr, } +void BasicBlockLifter::ApplyTypeHint(llvm::IRBuilder<> &bldr, + const ValueDecl &type_hint) { + + auto ty_hint = this->GetTypeHintFunction(); + auto state_ptr_internal = + this->lifted_func->getArg(remill::kStatePointerArgNum); + auto mem_ptr = + remill::LoadMemoryPointer(bldr.GetInsertBlock(), this->intrinsics); + auto curr_value = + anvill::LoadLiftedValue(type_hint, options.TypeDictionary(), intrinsics, + options.arch, bldr, state_ptr_internal, mem_ptr); + + if (curr_value->getType()->isPointerTy()) { + auto call = bldr.CreateCall(ty_hint, {curr_value}); + call->setMetadata("anvill.type", this->type_specifier.EncodeToMetadata( + type_hint.spec_type)); + curr_value = call; + } + + auto new_mem_ptr = + StoreNativeValue(curr_value, type_hint, options.TypeDictionary(), + intrinsics, bldr, state_ptr_internal, mem_ptr); + bldr.CreateStore(new_mem_ptr, + remill::LoadMemoryPointerRef(bldr.GetInsertBlock())); +} + + void BasicBlockLifter::LiftInstructionsIntoLiftedFunction() { auto entry_block = &this->lifted_func->getEntryBlock(); @@ -340,6 +367,22 @@ void BasicBlockLifter::LiftInstructionsIntoLiftedFunction() { inst, bb, this->lifted_func->getArg(remill::kStatePointerArgNum), false /* is_delayed */); + llvm::IRBuilder<> builder(bb); + + auto start = + std::lower_bound(decl.type_hints.begin(), decl.type_hints.end(), + inst.pc, [](const TypeHint &hint_rhs, uint64_t addr) { + return hint_rhs.target_addr < addr; + 
}); + auto end = + std::upper_bound(decl.type_hints.begin(), decl.type_hints.end(), + inst.pc, [](uint64_t addr, const TypeHint &hint_rhs) { + return addr < hint_rhs.target_addr; + }); + for (; start != end; start++) { + this->ApplyTypeHint(builder, start->hint); + } + ended_on_terminal = !this->ApplyInterProceduralControlFlowOverride(inst, bb); DLOG_IF(INFO, ended_on_terminal) diff --git a/lib/Lifters/BasicBlockLifter.h b/lib/Lifters/BasicBlockLifter.h index 565b365b5..e6c065a97 100644 --- a/lib/Lifters/BasicBlockLifter.h +++ b/lib/Lifters/BasicBlockLifter.h @@ -73,6 +73,9 @@ class BasicBlockLifter : public CodeLifter { remill::DecodingContext CreateDecodingContext(const CodeBlock &blk); + + void ApplyTypeHint(llvm::IRBuilder<> &bldr, const ValueDecl &type_hint); + void LiftInstructionsIntoLiftedFunction(); BasicBlockFunction CreateBasicBlockFunction(); diff --git a/lib/Lifters/CodeLifter.cpp b/lib/Lifters/CodeLifter.cpp index 372913cfe..6454142bb 100644 --- a/lib/Lifters/CodeLifter.cpp +++ b/lib/Lifters/CodeLifter.cpp @@ -1,6 +1,9 @@ #include "CodeLifter.h" +#include +#include #include +#include #include #include #include @@ -24,8 +27,6 @@ #include -#include "anvill/Type.h" - namespace anvill { namespace { // Clear out LLVM variable names. They're usually not helpful. 
@@ -170,6 +171,25 @@ void CodeLifter::InitializeStateStructureFromGlobalRegisterVariables( }); } +llvm::Function *CodeLifter::GetTypeHintFunction() { + const auto &func_name = kTypeHintFunctionPrefix; + + auto func = semantics_module->getFunction(func_name); + if (func != nullptr) { + return func; + } + + auto ptr = llvm::PointerType::get(this->semantics_module->getContext(), 0); + llvm::Type *func_parameters[] = {ptr}; + + auto func_type = llvm::FunctionType::get(ptr, func_parameters, false); + + func = llvm::Function::Create(func_type, llvm::GlobalValue::ExternalLinkage, + func_name, this->semantics_module); + + return func; +} + llvm::MDNode *CodeLifter::GetAddrAnnotation(uint64_t addr, llvm::LLVMContext &context) const { auto pc_val = llvm::ConstantInt::get( diff --git a/lib/Lifters/CodeLifter.h b/lib/Lifters/CodeLifter.h index 9f28d1276..cc7f10438 100644 --- a/lib/Lifters/CodeLifter.h +++ b/lib/Lifters/CodeLifter.h @@ -72,6 +72,9 @@ class CodeLifter { unsigned pc_annotation_id; + + llvm::Function *GetTypeHintFunction(); + llvm::MDNode *GetAddrAnnotation(uint64_t addr, llvm::LLVMContext &context) const; diff --git a/lib/Optimize.cpp b/lib/Optimize.cpp index e204405fa..e3b1db9ca 100644 --- a/lib/Optimize.cpp +++ b/lib/Optimize.cpp @@ -262,7 +262,6 @@ void OptimizeModule(const EntityLifter &lifter, llvm::Module &module, //AddRecoverBasicStackFrame(fpm, options.stack_frame_recovery_options); //AddSplitStackFrameAtReturnAddress(fpm, options.stack_frame_recovery_options); fpm.addPass(llvm::SROAPass(llvm::SROAOptions::ModifyCFG)); - //fpm.addPass(anvill::ReplaceStackReferences(contexts, lifter)); fpm.addPass(llvm::VerifierPass()); fpm.addPass(llvm::SROAPass(llvm::SROAOptions::ModifyCFG)); fpm.addPass(llvm::VerifierPass()); diff --git a/lib/Passes/ConvertPointerArithmeticToGEP.cpp b/lib/Passes/ConvertPointerArithmeticToGEP.cpp index 19bec6289..8a0f949fb 100644 --- a/lib/Passes/ConvertPointerArithmeticToGEP.cpp +++ b/lib/Passes/ConvertPointerArithmeticToGEP.cpp @@ -6,6 
+6,7 @@ * the LICENSE file found in the root directory of this source tree. */ +#include #include #include #include @@ -18,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -71,6 +73,8 @@ struct ConvertPointerArithmeticToGEP::Impl { llvm::MDNode *TypeSpecToMD(llvm::LLVMContext &context, UnknownType t); llvm::MDNode *TypeSpecToMD(llvm::LLVMContext &context, TypeSpec type); + + bool ConvertTypeHints(llvm::Function &f); bool ConvertLoadInt(llvm::Function &f); bool FoldPtrAdd(llvm::Function &f); bool FoldScaledIndex(llvm::Function &f); @@ -330,6 +334,26 @@ llvm::StringRef ConvertPointerArithmeticToGEP::name() { return "ConvertPointerArithmeticToGEP"; } +bool ConvertPointerArithmeticToGEP::Impl::ConvertTypeHints(llvm::Function &f) { + std::vector calls; + for (auto &insn : llvm::instructions(f)) { + if (auto *call = llvm::dyn_cast(&insn)) { + if (call->getCalledFunction() && + call->getCalledFunction()->getName() == kTypeHintFunctionPrefix) { + calls.push_back(call); + } + } + } + + for (auto call : calls) { + auto arg = call->getArgOperand(0); + call->replaceAllUsesWith(arg); + call->eraseFromParent(); + } + + return !calls.empty(); +} + // Finds `(load i64, P)` and converts it to `(ptrtoint (load ptr, P))` bool ConvertPointerArithmeticToGEP::Impl::ConvertLoadInt(llvm::Function &f) { using namespace llvm::PatternMatch; @@ -573,6 +597,7 @@ llvm::PreservedAnalyses ConvertPointerArithmeticToGEP::runOnBasicBlockFunction( bool changed = impl->ConvertLoadInt(function); changed |= impl->FoldPtrAdd(function); changed |= impl->FoldScaledIndex(function); + changed |= impl->ConvertTypeHints(function); return changed ? 
llvm::PreservedAnalyses::none() : llvm::PreservedAnalyses::all(); } diff --git a/lib/Protobuf.cpp b/lib/Protobuf.cpp index 2d41d1d32..6c72c0444 100644 --- a/lib/Protobuf.cpp +++ b/lib/Protobuf.cpp @@ -558,6 +558,27 @@ Result ProtobufTranslator::DecodeFunction( this->ParseCFGIntoFunction(function, decl); + + for (auto &ty_hint : function.type_hints()) { + auto maybe_type = DecodeType(ty_hint.target_var().type()); + if (maybe_type.Succeeded()) { + auto maybe_var = + DecodeValueDecl(ty_hint.target_var().values(), maybe_type.TakeValue(), + "attempting to decode type hint value"); + if (maybe_var.Succeeded()) { + decl.type_hints.push_back( + {ty_hint.target_addr(), maybe_var.TakeValue()}); + } + } else { + LOG(ERROR) << "Failed to decode type for type hint"; + } + } + + std::sort(decl.type_hints.begin(), decl.type_hints.end(), + [](const TypeHint &hint_lhs, const TypeHint &hint_rhs) { + return hint_lhs.target_addr < hint_rhs.target_addr; + }); + auto link = function.func_linkage(); if (link == specification::FUNCTION_LINKAGE_DECL) {