Refactor LLVM JIT (#1613)

Refactor LLVM JIT for some purposes:
- To simplify the source code of JIT compilation
- To simplify the JIT modes
- To align with LLVM latest changes
- To prepare for the Multi-tier JIT compilation, refer to #1302

The changes mainly include:
- Remove the MCJIT mode, replace it with ORC JIT eager mode
- Remove the LLVM legacy pass manager (only keep the LLVM new pass manager)
- Change the lazy mode's LLVM module/function binding:
  change each function in an individual LLVM module into all functions in a single LLVM module
- Upgraded ORC JIT to ORCv2 JIT to enable lazy compilation

Refer to #1468
This commit is contained in:
Wenyong Huang
2022-10-18 20:17:34 +08:00
committed by GitHub
parent 84b1a6c10e
commit e87a554616
22 changed files with 1058 additions and 1104 deletions

View File

@ -3,6 +3,8 @@
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/
#include <llvm/Passes/StandardInstrumentations.h>
#include <llvm/Support/Error.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/Twine.h>
#include <llvm/ADT/Triple.h>
@ -42,23 +44,15 @@
#if LLVM_VERSION_MAJOR >= 12
#include <llvm/Analysis/AliasAnalysis.h>
#endif
#include <cstring>
#if WASM_ENABLE_LAZY_JIT != 0
#include "../aot/aot_runtime.h"
#endif
#include <cstring>
#include "../aot/aot_runtime.h"
#include "aot_llvm.h"
using namespace llvm;
using namespace llvm::orc;
extern "C" {
LLVMBool
WAMRCreateMCJITCompilerForModule(LLVMExecutionEngineRef *OutJIT,
LLVMModuleRef M,
LLVMMCJITCompilerOptions *PassedOptions,
size_t SizeOfPassedOptions, char **OutError);
LLVM_C_EXTERN_C_BEGIN
bool
aot_check_simd_compatibility(const char *arch_c_str, const char *cpu_c_str);
@ -70,93 +64,11 @@ void
aot_add_simple_loop_unswitch_pass(LLVMPassManagerRef pass);
void
aot_func_disable_tce(LLVMValueRef func);
aot_apply_llvm_new_pass_manager(AOTCompContext *comp_ctx, LLVMModuleRef module);
void
aot_apply_llvm_new_pass_manager(AOTCompContext *comp_ctx);
}
LLVM_C_EXTERN_C_END
static TargetMachine *
unwrap(LLVMTargetMachineRef P)
{
return reinterpret_cast<TargetMachine *>(P);
}
LLVMBool
WAMRCreateMCJITCompilerForModule(LLVMExecutionEngineRef *OutJIT,
LLVMModuleRef M,
LLVMMCJITCompilerOptions *PassedOptions,
size_t SizeOfPassedOptions, char **OutError)
{
LLVMMCJITCompilerOptions options;
// If the user passed a larger sized options struct, then they were compiled
// against a newer LLVM. Tell them that something is wrong.
if (SizeOfPassedOptions > sizeof(options)) {
*OutError = strdup("Refusing to use options struct that is larger than "
"my own; assuming LLVM library mismatch.");
return 1;
}
// Defend against the user having an old version of the API by ensuring that
// any fields they didn't see are cleared. We must defend against fields
// being set to the bitwise equivalent of zero, and assume that this means
// "do the default" as if that option hadn't been available.
LLVMInitializeMCJITCompilerOptions(&options, sizeof(options));
memcpy(&options, PassedOptions, SizeOfPassedOptions);
TargetOptions targetOptions;
targetOptions.EnableFastISel = options.EnableFastISel;
std::unique_ptr<Module> Mod(unwrap(M));
if (Mod) {
// Set function attribute "frame-pointer" based on
// NoFramePointerElim.
for (auto &F : *Mod) {
auto Attrs = F.getAttributes();
StringRef Value = options.NoFramePointerElim ? "all" : "none";
#if LLVM_VERSION_MAJOR <= 13
Attrs =
Attrs.addAttribute(F.getContext(), AttributeList::FunctionIndex,
"frame-pointer", Value);
#else
Attrs = Attrs.addAttributeAtIndex(F.getContext(),
AttributeList::FunctionIndex,
"frame-pointer", Value);
#endif
F.setAttributes(Attrs);
}
}
std::string Error;
bool JIT;
char *host_cpu = LLVMGetHostCPUName();
if (!host_cpu) {
*OutError = NULL;
return false;
}
std::string mcpu(host_cpu);
LLVMDisposeMessage(host_cpu);
EngineBuilder builder(std::move(Mod));
builder.setEngineKind(EngineKind::JIT)
.setErrorStr(&Error)
.setMCPU(mcpu)
.setOptLevel((CodeGenOpt::Level)options.OptLevel)
.setTargetOptions(targetOptions);
if (Optional<CodeModel::Model> CM = unwrap(options.CodeModel, JIT))
builder.setCodeModel(*CM);
if (options.MCJMM)
builder.setMCJITMemoryManager(
std::unique_ptr<RTDyldMemoryManager>(unwrap(options.MCJMM)));
if (ExecutionEngine *JIT = builder.create()) {
*OutJIT = wrap(JIT);
return 0;
}
*OutError = strdup(Error.c_str());
return 1;
}
ExitOnError ExitOnErr;
class ExpandMemoryOpPass : public llvm::ModulePass
{
@ -258,13 +170,15 @@ ExpandMemoryOpPass::runOnModule(Module &M)
void
aot_add_expand_memory_op_pass(LLVMPassManagerRef pass)
{
unwrap(pass)->add(new ExpandMemoryOpPass());
reinterpret_cast<legacy::PassManager *>(pass)->add(
new ExpandMemoryOpPass());
}
void
aot_add_simple_loop_unswitch_pass(LLVMPassManagerRef pass)
{
unwrap(pass)->add(createSimpleLoopUnswitchLegacyPass());
reinterpret_cast<legacy::PassManager *>(pass)->add(
createSimpleLoopUnswitchLegacyPass());
}
bool
@ -308,123 +222,54 @@ aot_check_simd_compatibility(const char *arch_c_str, const char *cpu_c_str)
#endif /* WASM_ENABLE_SIMD */
}
#if WASM_ENABLE_LAZY_JIT != 0
#if LLVM_VERSION_MAJOR < 12
LLVMOrcJITTargetMachineBuilderRef
LLVMOrcJITTargetMachineBuilderFromTargetMachine(LLVMTargetMachineRef TM);
LLVMOrcJITTargetMachineBuilderRef
LLVMOrcJITTargetMachineBuilderCreateFromTargetMachine(LLVMTargetMachineRef TM)
{
return LLVMOrcJITTargetMachineBuilderFromTargetMachine(TM);
}
#endif
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(LLJITBuilder, LLVMOrcLLJITBuilderRef)
void
LLVMOrcLLJITBuilderSetNumCompileThreads(LLVMOrcLLJITBuilderRef orcjit_builder,
unsigned num_compile_threads)
aot_apply_llvm_new_pass_manager(AOTCompContext *comp_ctx, LLVMModuleRef module)
{
unwrap(orcjit_builder)->setNumCompileThreads(num_compile_threads);
}
void *
aot_lookup_orcjit_func(LLVMOrcLLJITRef orc_lazyjit, void *module_inst,
uint32 func_idx)
{
char func_name[32], buf[128], *err_msg = NULL;
LLVMErrorRef error;
LLVMOrcJITTargetAddress func_addr = 0;
AOTModuleInstance *aot_inst = (AOTModuleInstance *)module_inst;
AOTModule *aot_module = (AOTModule *)aot_inst->module;
void **func_ptrs = aot_inst->func_ptrs;
/**
* No need to lock the func_ptr[func_idx] here as it is basic
* data type, the load/store for it can be finished by one cpu
* instruction, and there can be only one cpu instruction
* loading/storing at the same time.
*/
if (func_ptrs[func_idx])
return func_ptrs[func_idx];
snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX,
func_idx - aot_module->import_func_count);
if ((error = LLVMOrcLLJITLookup(orc_lazyjit, &func_addr, func_name))) {
err_msg = LLVMGetErrorMessage(error);
snprintf(buf, sizeof(buf), "failed to lookup orcjit function: %s",
err_msg);
aot_set_exception(aot_inst, buf);
LLVMDisposeErrorMessage(err_msg);
return NULL;
}
func_ptrs[func_idx] = (void *)func_addr;
return (void *)func_addr;
}
#endif /* end of WASM_ENABLE_LAZY_JIT != 0 */
void
aot_func_disable_tce(LLVMValueRef func)
{
Function *F = unwrap<Function>(func);
auto Attrs = F->getAttributes();
#if LLVM_VERSION_MAJOR <= 13
Attrs = Attrs.addAttribute(F->getContext(), AttributeList::FunctionIndex,
"disable-tail-calls", "true");
#else
Attrs =
Attrs.addAttributeAtIndex(F->getContext(), AttributeList::FunctionIndex,
"disable-tail-calls", "true");
#endif
F->setAttributes(Attrs);
}
#if LLVM_VERSION_MAJOR >= 12
void
aot_apply_llvm_new_pass_manager(AOTCompContext *comp_ctx)
{
Module *M;
TargetMachine *TM = unwrap(comp_ctx->target_machine);
bool disable_llvm_lto = false;
LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;
TargetMachine *TM =
reinterpret_cast<TargetMachine *>(comp_ctx->target_machine);
PipelineTuningOptions PTO;
PTO.LoopVectorization = true;
PTO.SLPVectorization = true;
PTO.LoopUnrolling = true;
#ifdef DEBUG_PASS
PassInstrumentationCallbacks PIC;
PassBuilder PB(TM, PTO, None, &PIC);
#else
#if LLVM_VERSION_MAJOR == 12
PassBuilder PB(false, TM, PTO);
#else
PassBuilder PB(TM, PTO);
#endif
#endif
// Register the target library analysis directly and give it a
// customized preset TLI.
/* Register all the basic analyses with the managers */
LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;
/* Register the target library analysis directly and give it a
customized preset TLI */
std::unique_ptr<TargetLibraryInfoImpl> TLII(
new TargetLibraryInfoImpl(Triple(TM->getTargetTriple())));
FAM.registerPass([&] { return TargetLibraryAnalysis(*TLII); });
// Register the AA manager first so that our version is the one used.
/* Register the AA manager first so that our version is the one used */
AAManager AA = PB.buildDefaultAAPipeline();
FAM.registerPass([&] { return std::move(AA); });
// Register all the basic analyses with the managers.
PB.registerModuleAnalyses(MAM);
PB.registerCGSCCAnalyses(CGAM);
#ifdef DEBUG_PASS
StandardInstrumentations SI(true, false);
SI.registerCallbacks(PIC, &FAM);
#endif
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
PB.registerModuleAnalyses(MAM);
PB.registerCGSCCAnalyses(CGAM);
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
ModulePassManager MPM;
#if LLVM_VERSION_MAJOR <= 13
PassBuilder::OptimizationLevel OL;
@ -463,25 +308,23 @@ aot_apply_llvm_new_pass_manager(AOTCompContext *comp_ctx)
}
#endif /* end of LLVM_VERSION_MAJOR */
if (comp_ctx->disable_llvm_lto) {
disable_llvm_lto = true;
}
bool disable_llvm_lto = comp_ctx->disable_llvm_lto;
#if WASM_ENABLE_SPEC_TEST != 0
disable_llvm_lto = true;
#endif
Module *M = reinterpret_cast<Module *>(module);
if (disable_llvm_lto) {
uint32 i;
for (i = 0; i < comp_ctx->func_ctx_count; i++) {
aot_func_disable_tce(comp_ctx->func_ctxes[i]->func);
for (Function &F : *M) {
F.addFnAttr("disable-tail-calls", "true");
}
}
ModulePassManager MPM;
if (comp_ctx->is_jit_mode) {
/* Apply normal pipeline for JIT mode, without
Vectorize related passes, without LTO */
MPM.addPass(PB.buildPerModuleDefaultPipeline(OL));
const char *Passes =
"mem2reg,instcombine,simplifycfg,jump-threading,indvars";
ExitOnErr(PB.parsePassPipeline(MPM, Passes));
}
else {
FunctionPassManager FPM;
@ -512,16 +355,5 @@ aot_apply_llvm_new_pass_manager(AOTCompContext *comp_ctx)
}
}
#if WASM_ENABLE_LAZY_JIT == 0
M = unwrap(comp_ctx->module);
MPM.run(*M, MAM);
#else
uint32 i;
for (i = 0; i < comp_ctx->func_ctx_count; i++) {
M = unwrap(comp_ctx->modules[i]);
MPM.run(*M, MAM);
}
#endif
}
#endif /* end of LLVM_VERSION_MAJOR >= 12 */