05 Apr

Optimizing LLVM IR from the C API

After my previous post on how to read & write LLVM bitcode, I thought I’d follow it up with a post on actually modifying LLVM bitcode files after you’ve read them. LLVM comes with extensive built-in optimization passes, but there is also plenty of scope to write your own optimizations.

First off, remember that when we parse an LLVM bitcode file, we get an LLVM module. So what exactly is an LLVM module?

An LLVM module is basically a collection of global variables and functions. Functions contain basic-blocks. The simplest way to think about a basic-block is as a scope in a C/C++ function:

int foo(int a, int b, bool c) {
  // basic-block 0 is here
  if (c) {
    // basic-block 1 is here
    return a;
  } else {
    // basic-block 2 is here
    return b;
  }
}

In the above example, we have three basic-blocks. There is always an entry basic-block, one that is tied to the function itself. Then that basic-block can branch to one or more other basic-blocks. Each basic-block contains one or more instructions.
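To make the nesting concrete, here is a rough sketch of that hierarchy as plain C structs. Everything in it (ToyModule, ToyFunction, ToyBasicBlock, ToyInstruction and the helper functions) is a made-up stand-in for illustration, not an LLVM type, and the instruction strings are only descriptive labels for what foo() might look like.

```c
#include <stddef.h>

/* Toy stand-ins for illustration only; these are NOT LLVM's real types. */
typedef struct {
  const char *text; /* descriptive label, never parsed */
} ToyInstruction;

typedef struct {
  const ToyInstruction *instructions;
  size_t instructionCount;
} ToyBasicBlock;

typedef struct {
  const char *name;
  const ToyBasicBlock *blocks; /* blocks[0] plays the entry basic-block */
  size_t blockCount;
} ToyFunction;

typedef struct {
  const ToyFunction *functions;
  size_t functionCount;
} ToyModule;

/* Count every instruction, of every basic-block, of every function. */
static size_t toy_count_instructions(const ToyModule *module) {
  size_t total = 0;
  for (size_t f = 0; f < module->functionCount; f++) {
    const ToyFunction *function = &module->functions[f];
    for (size_t b = 0; b < function->blockCount; b++) {
      total += function->blocks[b].instructionCount;
    }
  }
  return total;
}

/* Model the shape of foo(): an entry block that branches on c, plus one
   block for each return. */
static size_t toy_count_foo(void) {
  static const ToyInstruction entry[] = {{"branch on c"}};
  static const ToyInstruction thenBlock[] = {{"return a"}};
  static const ToyInstruction elseBlock[] = {{"return b"}};
  static const ToyBasicBlock blocks[] = {
      {entry, 1}, {thenBlock, 1}, {elseBlock, 1}};
  static const ToyFunction functions[] = {{"foo", blocks, 3}};
  const ToyModule module = {functions, 1};
  return toy_count_instructions(&module);
}
```

Calling toy_count_foo() walks one function with three basic-blocks of one instruction each, so it returns 3.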

Now that we know the basics of how LLVM is put together, let’s look at actually doing an optimization. To keep things trivial, let’s do constant folding. Constant folding is when you have an instruction that takes only constant arguments, so you can replace the instruction with a single constant value instead:

int foo() {
  return 13 + 42;
}

In the above example you really don’t want the compiler to actually do 13 + 42 at runtime; you want it to use the constant 55 instead. LLVM already has this sort of optimization, but let’s do it ourselves to work through the process.
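Stripped of any compiler machinery, the idea looks something like the sketch below. ToyOperand, toy_constant, toy_unknown and toy_fold_add are hypothetical names invented for this illustration and have nothing to do with the LLVM API: an add folds only when both operands are known constants.

```c
#include <stdbool.h>
#include <stdint.h>

/* Toy operand for illustration: either a known constant or a value that
   is only known at runtime. This is not an LLVM type. */
typedef struct {
  bool isConstant;
  uint32_t value; /* only meaningful when isConstant is true */
} ToyOperand;

static ToyOperand toy_constant(uint32_t value) {
  ToyOperand operand = {true, value};
  return operand;
}

static ToyOperand toy_unknown(void) {
  ToyOperand operand = {false, 0};
  return operand;
}

/* Try to fold `lhs + rhs`. Returns true and writes the folded constant
   when both operands are constant; returns false when the add has to
   stay in the program. */
static bool toy_fold_add(ToyOperand lhs, ToyOperand rhs, uint32_t *result) {
  if (!lhs.isConstant || !rhs.isConstant) {
    return false;
  }
  /* Unsigned addition wraps on overflow, which keeps this well defined. */
  *result = lhs.value + rhs.value;
  return true;
}
```

With toy_constant(13) and toy_constant(42) the function succeeds and writes 55; if either operand is toy_unknown() it reports that nothing can be folded.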

We’ll integrate our new code into the parser I used in my last blog post. Looking back at how we got our LLVM module:

LLVMModuleRef module;
if (0 != LLVMParseBitcode2(memoryBuffer, &module)) {
  fprintf(stderr, "Invalid bitcode detected!\n");
  LLVMDisposeMemoryBuffer(memoryBuffer);
  return 1;
}

Now that we have our LLVM module, let’s look for instructions that use constant arguments. Given an LLVM module, we first have to walk the functions within that module:

for (LLVMValueRef function =
  LLVMGetFirstFunction(module);
  function;
  function =
  LLVMGetNextFunction(function)) {
  // ...
}

Then, we walk the basic-blocks of each function:

for (LLVMBasicBlockRef basicBlock =
  LLVMGetFirstBasicBlock(function);
  basicBlock;
  basicBlock =
  LLVMGetNextBasicBlock(basicBlock)) {
  // ...
}

And finally we walk the instructions in each basic-block:

for (LLVMValueRef instruction =
  LLVMGetFirstInstruction(basicBlock);
  instruction;
  instruction =
  LLVMGetNextInstruction(instruction)) {
  // ...
}

So now we are walking every instruction, of every basic-block, of every function, in our module. Next we need to identify the instructions that we want to investigate. Just to keep things simple, we’ll only try to fold binary operators, things like + - * /.

if (LLVMIsABinaryOperator(instruction)) {

Once we know we’ve got a binary operator, we know we’ve got exactly two operands:

LLVMValueRef x = LLVMGetOperand(instruction, 0);
LLVMValueRef y = LLVMGetOperand(instruction, 1);

And to check if we have constant operands or not we simply do:

const int allConstant = LLVMIsAConstant(x) && LLVMIsAConstant(y);

So if allConstant is true, we know we can fold the binary operator. To fold the operation, we look at the opcode of the binary operator – this identifies which binary operator it actually is. Then, we turn the binary operator into a constant expression which does the same thing. LLVM will do all the hard work for us when we create a constant expression with known constant values – it’ll do the constant folding we want for us.

LLVMValueRef replacementValue = 0;

if (allConstant) {
  switch (LLVMGetInstructionOpcode(instruction)) {
  default:
    break;
  case LLVMAdd:
    replacementValue = LLVMConstAdd(x, y);
    break;
  case LLVMFAdd:
    replacementValue = LLVMConstFAdd(x, y);
    break;
  case LLVMSub:
    replacementValue = LLVMConstSub(x, y);
    break;
  case LLVMFSub:
    replacementValue = LLVMConstFSub(x, y);
    break;
  case LLVMMul:
    replacementValue = LLVMConstMul(x, y);
    break;
  case LLVMFMul:
    replacementValue = LLVMConstFMul(x, y);
    break;
  case LLVMUDiv:
    replacementValue = LLVMConstUDiv(x, y);
    break;
  case LLVMSDiv:
    replacementValue = LLVMConstSDiv(x, y);
    break;
  case LLVMFDiv:
    replacementValue = LLVMConstFDiv(x, y);
    break;
  case LLVMURem:
    replacementValue = LLVMConstURem(x, y);
    break;
  case LLVMSRem:
    replacementValue = LLVMConstSRem(x, y);
    break;
  case LLVMFRem:
    replacementValue = LLVMConstFRem(x, y);
    break;
  case LLVMShl:
    replacementValue = LLVMConstShl(x, y);
    break;
  case LLVMLShr:
    replacementValue = LLVMConstLShr(x, y);
    break;
  case LLVMAShr:
    replacementValue = LLVMConstAShr(x, y);
    break;
  case LLVMAnd:
    replacementValue = LLVMConstAnd(x, y);
    break;
  case LLVMOr:
    replacementValue = LLVMConstOr(x, y);
    break;
  case LLVMXor:
    replacementValue = LLVMConstXor(x, y);
    break;
  }
}
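One reason the switch has separate cases such as UDiv and SDiv, or LShr and AShr, is that LLVM integers carry no signedness of their own: the opcode decides how the bits are interpreted. The sketch below is plain C with made-up helper names (toy_udiv, toy_sdiv, toy_lshr, toy_ashr), written purely to show the same 32 bits producing different answers. It assumes a non-zero divisor and a shift amount below 32, and it is not how LLVM implements these operations.

```c
#include <stdint.h>

/* udiv: both operands are read as unsigned. Requires y != 0. */
static uint32_t toy_udiv(uint32_t x, uint32_t y) { return x / y; }

/* sdiv: both operands are read as two's complement. Done on magnitudes
   so it only uses unsigned arithmetic. Requires y != 0. */
static uint32_t toy_sdiv(uint32_t x, uint32_t y) {
  const uint32_t xNegative = x & 0x80000000u;
  const uint32_t yNegative = y & 0x80000000u;
  const uint32_t xMagnitude = xNegative ? 0u - x : x;
  const uint32_t yMagnitude = yNegative ? 0u - y : y;
  const uint32_t quotient = xMagnitude / yMagnitude;
  return (xNegative != yNegative) ? 0u - quotient : quotient;
}

/* lshr: zeros are shifted in from the top. Requires amount < 32. */
static uint32_t toy_lshr(uint32_t x, uint32_t amount) { return x >> amount; }

/* ashr: copies of the sign bit are shifted in from the top.
   Requires amount < 32. */
static uint32_t toy_ashr(uint32_t x, uint32_t amount) {
  uint32_t shifted = x >> amount;
  if ((x & 0x80000000u) && amount > 0) {
    shifted |= ~(0xFFFFFFFFu >> amount); /* fill the vacated top bits */
  }
  return shifted;
}
```

Taking the bit pattern 0xFFFFFFF8 (which is -8 when read as signed), toy_udiv and toy_lshr by 2 and 1 respectively give 0x7FFFFFFC, while toy_sdiv and toy_ashr give 0xFFFFFFFC (that is, -4).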

Now that we have our replacement value for the binary operator, we can replace the original instruction with the new constant:

// if we managed to find a more optimal replacement
if (replacementValue) {
  // replace all uses of the old instruction with the new one
  LLVMReplaceAllUsesWith(instruction, replacementValue);

  // erase the instruction that we've replaced
  LLVMInstructionEraseFromParent(instruction);
}

And that’s it! Let’s run this on a simple example I’ve knocked together:

define i32 @foo() {
  %a = mul i32 13, 46
  %1 = add i32 4, %a
  %2 = sub i32 %1, 400
  %3 = shl i32 %2, 2
  ret i32 %3
}

We can run our code and everything… explodes? You’ll get some horrific segfault deep in LLVM and, if you don’t know the cause, run away screaming. The problem is that our loop still calls LLVMGetNextInstruction on an instruction we have just erased. When hacking on an LLVM module, you’ve got to be super careful about deleting instructions, basic-blocks or functions while you are still iterating through the lists that hold them. You either need to store all the values and do all the replacements after you’ve finished with the iterators, or handle the replacements very carefully.

The way I handle it when iterating using the C API is to remember the previous instruction in the instruction stream (lastInstruction in the code), meaning the most recent one I did not erase. If I replace and erase the current instruction, I can ask for the instruction after lastInstruction and pick up the remainder of the instruction stream from there.

LLVMValueRef lastInstruction = 0;

// loop through all the instructions in the basic block
for (LLVMValueRef instruction =
  LLVMGetFirstInstruction(basicBlock);
  instruction;) {
  LLVMValueRef replacementValue = 0;

  // ...

  // if we managed to find a more optimal replacement
  if (replacementValue) {
    // replace all uses of the old instruction with the new one
    LLVMReplaceAllUsesWith(instruction, replacementValue);

    // erase the instruction that we've replaced
    LLVMInstructionEraseFromParent(instruction);

    // if we don't have a previous instruction, get the first
    // one from the basic block again
    if (!lastInstruction) {
      instruction = LLVMGetFirstInstruction(basicBlock);
    } else {
      instruction = LLVMGetNextInstruction(lastInstruction);
    }
  } else {
    lastInstruction = instruction;
    instruction = LLVMGetNextInstruction(instruction);
  }
}
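Another common way to stay safe, as long as only the current element is ever erased, is to read the successor before erasing. The sketch below shows that pattern on a toy doubly linked list rather than on LLVM itself. ToyNode, ToyList and the toy_* functions are invented for this illustration, and erasing odd values stands in for erasing folded instructions.

```c
#include <stdlib.h>

/* Toy doubly linked list for illustration; these are not LLVM types. */
typedef struct ToyNode {
  int value;
  struct ToyNode *prev;
  struct ToyNode *next;
} ToyNode;

typedef struct {
  ToyNode *first;
  ToyNode *last;
} ToyList;

/* Append a node holding `value`. Returns 0 on success, -1 on failure. */
static int toy_append(ToyList *list, int value) {
  ToyNode *node = malloc(sizeof *node);
  if (!node) {
    return -1;
  }
  node->value = value;
  node->prev = list->last;
  node->next = NULL;
  if (list->last) {
    list->last->next = node;
  } else {
    list->first = node;
  }
  list->last = node;
  return 0;
}

/* Unlink and free a node, like erasing an instruction from its parent. */
static void toy_erase(ToyList *list, ToyNode *node) {
  if (node->prev) {
    node->prev->next = node->next;
  } else {
    list->first = node->next;
  }
  if (node->next) {
    node->next->prev = node->prev;
  } else {
    list->last = node->prev;
  }
  free(node);
}

/* Erase every odd value and return how many nodes were erased. */
static int toy_erase_odd(ToyList *list) {
  int erased = 0;
  for (ToyNode *node = list->first; node;) {
    ToyNode *next = node->next; /* read before `node` may be freed */
    if (node->value % 2 != 0) {
      toy_erase(list, node);
      erased++;
    }
    node = next;
  }
  return erased;
}

/* Build 1..5, erase the odd values, and return the sum of what is left
   (or -1 if an allocation failed). Frees every node before returning. */
static int toy_demo(void) {
  ToyList list = {NULL, NULL};
  int sum = 0;
  int failed = 0;
  for (int i = 1; i <= 5 && !failed; i++) {
    failed = toy_append(&list, i) != 0;
  }
  if (!failed) {
    toy_erase_odd(&list);
    for (ToyNode *node = list.first; node; node = node->next) {
      sum += node->value;
    }
  }
  while (list.first) {
    toy_erase(&list, list.first);
  }
  return failed ? -1 : sum;
}
```

toy_demo() leaves only 2 and 4 in the list, so it returns 6. The previous-instruction approach shown above has the advantage that it never holds a pointer to anything other than instructions it has already decided to keep.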

Now we correctly handle removing instructions within the instruction stream. If we compile and run everything again, what do we get? For the example I showed above, after I’ve run my bitcode read -> constant fold -> bitcode write, I get the following LLVM IR:

define i32 @foo() {
  ret i32 808
}

Nice! That looks like much more optimal code.
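It is worth seeing why a single forward pass is enough here. Only the mul starts out with two constant operands; folding it and replacing its uses turns the add into an all-constant instruction, and so on down the chain: 13 * 46 = 598, 4 + 598 = 602, 602 - 400 = 202, and 202 << 2 = 808. The sketch below replays that cascade on a toy instruction array. ToyInst, ToyValue and the toy_* helpers are hypothetical names for this illustration only, and masking the shift amount is a choice made for the toy, not LLVM's rule.

```c
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* Toy IR for illustration only; none of these are LLVM types. */
typedef enum { TOY_MUL, TOY_ADD, TOY_SUB, TOY_SHL } ToyOpcode;

/* An operand is either a constant or the result of an earlier
   instruction, identified by that instruction's index. */
typedef struct {
  bool isConstant;
  uint32_t constant; /* valid when isConstant is true */
  size_t source;     /* valid when isConstant is false */
} ToyValue;

typedef struct {
  ToyOpcode opcode;
  ToyValue x;
  ToyValue y;
} ToyInst;

static ToyValue toy_const(uint32_t constant) {
  ToyValue value = {true, constant, 0};
  return value;
}

static ToyValue toy_result_of(size_t source) {
  ToyValue value = {false, 0, source};
  return value;
}

static ToyInst toy_inst(ToyOpcode opcode, ToyValue x, ToyValue y) {
  ToyInst inst;
  inst.opcode = opcode;
  inst.x = x;
  inst.y = y;
  return inst;
}

/* Evaluate one 32-bit operation with wrapping unsigned arithmetic. */
static uint32_t toy_eval(ToyOpcode opcode, uint32_t x, uint32_t y) {
  switch (opcode) {
  case TOY_MUL: return x * y;
  case TOY_ADD: return x + y;
  case TOY_SUB: return x - y;
  case TOY_SHL: return x << (y & 31u); /* toy choice: mask the amount */
  }
  return 0;
}

/* Toy replace-all-uses-with: later operands that referred to
   instruction `index` now hold the folded constant instead. */
static void toy_replace_uses(ToyInst *insts, size_t count, size_t index,
                             uint32_t constant) {
  for (size_t i = index + 1; i < count; i++) {
    if (!insts[i].x.isConstant && insts[i].x.source == index) {
      insts[i].x = toy_const(constant);
    }
    if (!insts[i].y.isConstant && insts[i].y.source == index) {
      insts[i].y = toy_const(constant);
    }
  }
}

/* Fold the body of foo() in one forward pass and return the value of
   the last instruction that was folded. */
static uint32_t toy_fold_foo(void) {
  ToyInst insts[4];
  insts[0] = toy_inst(TOY_MUL, toy_const(13), toy_const(46));     /* %a */
  insts[1] = toy_inst(TOY_ADD, toy_const(4), toy_result_of(0));   /* %1 */
  insts[2] = toy_inst(TOY_SUB, toy_result_of(1), toy_const(400)); /* %2 */
  insts[3] = toy_inst(TOY_SHL, toy_result_of(2), toy_const(2));   /* %3 */

  const size_t count = sizeof insts / sizeof insts[0];
  uint32_t lastFolded = 0;
  for (size_t i = 0; i < count; i++) {
    if (insts[i].x.isConstant && insts[i].y.isConstant) {
      lastFolded = toy_eval(insts[i].opcode, insts[i].x.constant,
                            insts[i].y.constant);
      toy_replace_uses(insts, count, i, lastFolded);
    }
  }
  return lastFolded;
}
```

toy_fold_foo() returns 808, matching the constant in the ret above.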

I will note that this is not how real LLVM passes work (they are full C++ with classes and templates), but it does allow you to easily work with LLVM IR yourself.

I’ve updated my llvm_bc_parse_example to include this optimization code.

I hope this proves useful to any budding compiler engineers who want to start tinkering with LLVM!
