gollvm: intrinsify runtime.getg

Currently, getg is implemented in C, which loads the thread-local
g variable. Calling getg from Go, which is somewhat frequent in
the runtime, is not inlineable. This CL lets the compiler
generate an inlined load for getg. We do this in the middle-end
instead of the frontend because currently we can only do this
safely in gollvm, not gccgo.

The concern is that the backend may choose to cache the TLS
address in a register or on the stack. If a thread switch happens,
the cache will become invalid. Currently, there seems to be no way
to tell the backend to disable or invalidate the cache, for either
the GCC or the LLVM backend.

However, in LLVM, the caching happens in SelectionDAG, where it
CSEs the TLS address for multiple loads in the same basic block.
It does not CSE the TLS address across multiple basic blocks. So
we are safe if we can ensure there is at most one load of g in a
block. If there is more than one and a call intervenes, we replace
the later load with a call of runtime.getg (i.e. undoing the
inlining). We introduce a backend pass to do this.

For TLS access in PIC mode, currently LLVM generates assembly
code that contains detached data16 prefixes, which causes the
assembler to emit warnings. This doesn't seem to cause any
actual problems, besides cluttering outputs.

Change-Id: Ib3969470cc838c84f993dece3716399839b69959
Reviewed-on: https://go-review.googlesource.com/c/gollvm/+/186257
Reviewed-by: Than McIntosh <thanm@google.com>
diff --git a/bridge/go-llvm-materialize.cpp b/bridge/go-llvm-materialize.cpp
index 284b390..4afc5fe 100644
--- a/bridge/go-llvm-materialize.cpp
+++ b/bridge/go-llvm-materialize.cpp
@@ -29,6 +29,12 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
+#include "llvm/Support/CommandLine.h"
+
+static llvm::cl::opt<bool> DisableInlineGetg("disable-inline-getg",
+                                             llvm::cl::desc("Disable inlining getg"),
+                                             llvm::cl::init(false),
+                                             llvm::cl::Hidden);
 
 Bexpression *Llvm_backend::materializeIndirect(Bexpression *indExpr, bool isLHS)
 {
@@ -1360,6 +1366,25 @@
   }
 }
 
+// Inline runtime.getg by emitting a load of the thread-local g,
+// declaring runtime.g on first use. Not done as a builtin because,
+// unlike other builtins, we need the FE to tell us the result type.
+static llvm::Value *makeGetg(Btype *resType,
+                             BinstructionsLIRBuilder *builder,
+                             Llvm_backend *be)
+{
+  llvm::GlobalValue* g = be->module().getGlobalVariable("runtime.g");
+  if (!g) { // not in this module yet: declare it as an external global
+    bool is_external = true, is_hidden = false, in_unique_section = false;
+    Location location; // dummy
+    Bvariable* bv = be->global_variable("runtime.g", "runtime.g", resType, is_external,
+                                        is_hidden, in_unique_section, location);
+    g = llvm::cast<llvm::GlobalValue>(bv->value());
+    g->setThreadLocal(true); // g is per-thread, so mark the global as TLS
+  }
+  return builder->CreateLoad(g);
+}
+
 Bexpression *Llvm_backend::materializeCall(Bexpression *callExpr)
 {
   Location location = callExpr->location();
@@ -1461,7 +1486,8 @@
       BuiltinExprMaker makerfn = be->exprMaker();
       if (makerfn)
         callValue = makerfn(state.llargs, &state.builder, this);
-    }
+    } else if (fcn->getName() == "runtime.getg" && !DisableInlineGetg)
+      callValue = makeGetg(rbtype, &state.builder, this);
   }
   if (!callValue) {
     llvm::FunctionType *llft =
diff --git a/driver/CompileGo.cpp b/driver/CompileGo.cpp
index 14ff429..94b4916 100644
--- a/driver/CompileGo.cpp
+++ b/driver/CompileGo.cpp
@@ -914,6 +914,8 @@
       createTargetTransformInfoWrapperPass(target_->getTargetIRAnalysis()));
   createPasses(modulePasses, functionPasses);
 
+  modulePasses.add(createGoSafeGetgPass());
+
   // Add statepoint insertion pass to the end of optimization pipeline,
   // right before lowering to machine IR.
   if (enable_gc_) {
diff --git a/passes/CMakeLists.txt b/passes/CMakeLists.txt
index 22a0a3b..b70279c 100644
--- a/passes/CMakeLists.txt
+++ b/passes/CMakeLists.txt
@@ -10,6 +10,7 @@
   GC.cpp
   GoAnnotation.cpp
   GoNilChecks.cpp
+  GoSafeGetg.cpp
   GoStatepoints.cpp
   GoWrappers.cpp
   RemoveAddrSpace.cpp
diff --git a/passes/GoSafeGetg.cpp b/passes/GoSafeGetg.cpp
new file mode 100644
index 0000000..52b54c2
--- /dev/null
+++ b/passes/GoSafeGetg.cpp
@@ -0,0 +1,134 @@
+//===--- GoSafeGetg.cpp ---------------------------------------------------===//
+//
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+//
+//===----------------------------------------------------------------------===//
+//
+// LLVM backend pass to make sure inlined getg's are
+// safe. Specifically, make sure the TLS address is not
+// cached across a thread switch.
+//
+//===----------------------------------------------------------------------===//
+
+#include "GollvmPasses.h"
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/PassRegistry.h"
+#include "llvm/Support/Debug.h"
+
+using namespace llvm;
+
+namespace {
+
+class GoSafeGetg : public ModulePass { // undoes unsafe inlined getg loads
+ public:
+  static char ID; // pass ID handed to the ModulePass constructor
+
+  GoSafeGetg() : ModulePass(ID) {
+    initializeGoSafeGetgPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnModule(Module &M) override; // returns true if M was changed
+};
+
+}  // namespace
+
+char GoSafeGetg::ID = 0;
+INITIALIZE_PASS(GoSafeGetg, "go-safegetg",
+                "Ensure Go getg's are safe", false,
+                false)
+ModulePass *llvm::createGoSafeGetgPass() { return new GoSafeGetg(); }
+
+// In the runtime g is a thread-local variable. The backend may
+// choose to cache the TLS address in a register or on the stack.
+// If a thread switch happens, the cache will become invalid.
+// Specifically, within a function,
+//
+//   load g
+//   call mcall(...)
+//   load g
+//
+// may be compiled to
+//
+//   leaq    g@TLS, %rdi
+//   call    __tls_get_addr
+//   movq    %rax, %rbx     // cache in a callee-save register %rbx
+//   ... use g in %rax ...
+//   call    mcall
+//   ... use g in %rbx ...
+//
+// This is incorrect if a thread switch happens at the call of mcall.
+// Currently, there seems to be no way to tell the backend to
+// disable or invalidate the cache.
+//
+// In LLVM, this happens in SelectionDAG, where it CSEs the TLS
+// address for multiple loads in the same basic block. It does not
+// CSE the TLS address across multiple basic blocks. So we are safe
+// if a block never loads g again after a call that follows a load
+// of g. This function looks for such a later load of g in each
+// block and, if found, replaces it with a call to runtime.getg
+// (i.e. undoing the inlining). Intrinsic calls are ignored.
+//
+bool
+GoSafeGetg::runOnModule(Module &M) {
+  GlobalVariable *GV = M.getGlobalVariable("runtime.g");
+  if (!GV)
+    return false; // no access of g, nothing to do
+
+  bool Changed = false;
+  for (Function &F : M) {
+    SmallVector<Instruction*, 2> ToDel; // replaced loads, erased after the scan
+
+    for (BasicBlock &BB : F) {
+      bool HasGetg = false; // whether we have seen a load of g in this block
+      bool HasCall = false; // whether we have seen a call after a getg
+      for (Instruction &I : BB) {
+        if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
+          if (LI->getPointerOperand()->stripPointerCasts() == GV) {
+            HasGetg = true;
+            if (!HasCall)
+              continue;
+
+            // An earlier getg was followed by a call, so this load
+            // may reuse a stale TLS address. Replace it with a call.
+            IRBuilder<> Builder(&I);
+            FunctionCallee GetgFn =
+                M.getOrInsertFunction("runtime.getg", I.getType());
+            Instruction *Call = Builder.CreateCall(GetgFn);
+            I.replaceAllUsesWith(Call);
+            ToDel.push_back(&I);
+            Changed = true;
+            continue;
+          }
+        } else {
+          // We should not see a use of g that is not a load.
+          for (Value *O : I.operands())
+            if (O->stripPointerCasts() == GV)
+              report_fatal_error("non-load use of runtime.g: " +
+                                 I.getName() + " ( in function " +
+                                 F.getName() + ")");
+        }
+
+        if (HasGetg && !HasCall)
+          if (CallInst *CI = dyn_cast<CallInst>(&I)) {
+            if (Function *Fn = CI->getCalledFunction())
+              if (Fn->isIntrinsic())
+                continue; // intrinsics are ok
+            HasCall = true;
+          }
+      }
+    }
+
+    for (Instruction* I : ToDel)
+      I->eraseFromParent();
+  }
+
+  return Changed;
+}
diff --git a/passes/GollvmPasses.h b/passes/GollvmPasses.h
index 0eef28c..398df82 100644
--- a/passes/GollvmPasses.h
+++ b/passes/GollvmPasses.h
@@ -22,12 +22,14 @@
 
 void initializeGoAnnotationPass(PassRegistry&);
 void initializeGoNilChecksPass(PassRegistry&);
+void initializeGoSafeGetgPass(PassRegistry&);
 void initializeGoStatepointsLegacyPassPass(PassRegistry&);
 void initializeGoWrappersPass(PassRegistry&);
 void initializeRemoveAddrSpacePassPass(PassRegistry&);
 
 FunctionPass *createGoAnnotationPass();
 FunctionPass *createGoNilChecksPass();
+ModulePass *createGoSafeGetgPass();
 ModulePass *createGoStatepointsLegacyPass();
 FunctionPass *createGoWrappersPass();
 ModulePass *createRemoveAddrSpacePass(const DataLayout&);