gollvm: materialize correct control flow for no-return calls

Update the bridge to generate the correct control flow for no-return
calls, e.g. insure that a no-return call ends the current basic block
by inserting an "unreachable" op. Prior to this the bridge tagged
calls as no-return correctly but did not take any other actions.

Change-Id: I4aeffa87e0f4a7e1ab30b817600dafc4708d85df
Reviewed-on: https://go-review.googlesource.com/c/140119
Reviewed-by: Cherry Zhang <cherryyz@google.com>
diff --git a/bridge/go-llvm-bexpression.cpp b/bridge/go-llvm-bexpression.cpp
index e54be63..1b97168 100644
--- a/bridge/go-llvm-bexpression.cpp
+++ b/bridge/go-llvm-bexpression.cpp
@@ -31,6 +31,17 @@
   return true;
 }
 
+std::vector<llvm::Instruction *>
+Binstructions::extractInstsAfter(llvm::Instruction *inst)
+{
+  auto it = std::find(instructions_.begin(), instructions_.end(), inst);
+  assert(it != instructions_.end());
+  std::vector<llvm::Instruction *> rval(it, instructions_.end());
+  if (it != instructions_.end())
+    instructions_.erase(it, instructions_.end());
+  return rval;
+}
+
 Bexpression::Bexpression(NodeFlavor fl, const std::vector<Bnode *> &kids,
                          llvm::Value *val, Btype *typ, Location loc)
     : Bnode(fl, kids, loc)
diff --git a/bridge/go-llvm-bexpression.h b/bridge/go-llvm-bexpression.h
index 47fad06..767ac47 100644
--- a/bridge/go-llvm-bexpression.h
+++ b/bridge/go-llvm-bexpression.h
@@ -54,6 +54,11 @@
 
   void clear() { instructions_.clear(); }
 
+  // Locate 'inst' within the instructions vector, then remove 'inst' and all
+  // subsequent instructions from the list and return them as a vector. Will
+  // assert if 'inst' is not found in the list.
+  std::vector<llvm::Instruction *> extractInstsAfter(llvm::Instruction *inst);
+
 private:
   std::vector<llvm::Instruction *> instructions_;
 
diff --git a/bridge/go-llvm-bnode.cpp b/bridge/go-llvm-bnode.cpp
index 1311c99..e0e6290 100644
--- a/bridge/go-llvm-bnode.cpp
+++ b/bridge/go-llvm-bnode.cpp
@@ -1008,16 +1008,33 @@
 void BnodeBuilder::updateInstructions(Bexpression *expr,
                                       std::vector<llvm::Instruction*> newinsts)
 {
-  assert(expr->instructions().size() == newinsts.size());
+  assert(expr->instructions().size() >= newinsts.size());
+  llvm::Instruction *deleteAfter = nullptr;
+  unsigned newSize = newinsts.size();
   unsigned idx = 0;
+
   for (auto originst : expr->instructions()) {
-    llvm::Instruction *newinst = newinsts[idx];
-    if (originst != newinst) {
+    if (idx >= newSize) {
+      if (idx == newSize)
+        deleteAfter = originst;
       integrityVisitor_->unsetParent(originst, expr, idx);
-      integrityVisitor_->setParent(newinst, expr, idx);
-      if (expr->value() == originst)
-        expr->setValue(newinst);
+    } else {
+      llvm::Instruction *newinst = newinsts[idx];
+      if (originst != newinst) {
+        integrityVisitor_->unsetParent(originst, expr, idx);
+        integrityVisitor_->setParent(newinst, expr, idx);
+        if (expr->value() == originst)
+          expr->setValue(newinst);
+      }
     }
     idx++;
   }
+  if (deleteAfter != nullptr) {
+    std::vector<llvm::Instruction *> todel =
+        expr->extractInstsAfter(deleteAfter);
+    for (auto it = todel.rbegin(); it != todel.rend(); ++it) {
+      auto victim = *it;
+      victim->deleteValue();
+    }
+  }
 }
diff --git a/bridge/go-llvm-bnode.h b/bridge/go-llvm-bnode.h
index 8e7b159..87a72be 100644
--- a/bridge/go-llvm-bnode.h
+++ b/bridge/go-llvm-bnode.h
@@ -351,7 +351,8 @@
   // Bnodes but not instructions), or DelBoth (gets rid of nodes and
   // instructions).
   // If recursive is true, also delete its children recursively.
-  void destroy(Bnode *node, WhichDel which = DelWrappers, bool recursive = true);
+  void destroy(Bnode *node, WhichDel which = DelWrappers,
+               bool recursive = true);
 
   // Clone an expression subtree.
   Bexpression *cloneSubtree(Bexpression *expr);
@@ -370,9 +371,13 @@
   // children).
   std::vector<Bnode *> extractChildNodesAndDestroy(Bnode *node);
 
-  // Update the instructions of an expression node, mostly for updating
-  // references in the integrity checker. Used when the instructions
-  // are post-processed.
+  // Update the instructions of an expression node after we've finished
+  // walking its instruction list. During the walk it's possible that
+  // A) one or more existing instructions were converted into new
+  // instructions, or B) some instructions were thrown away because
+  // they appeared after a no-return call in the list. This helper
+  // fixes up the instruction list and updates ownership info
+  // in the integrity checker.
   void updateInstructions(Bexpression *expr,
                           std::vector<llvm::Instruction*> newinsts);
 
diff --git a/bridge/go-llvm.cpp b/bridge/go-llvm.cpp
index 7d8e62f..ba49ebe 100644
--- a/bridge/go-llvm.cpp
+++ b/bridge/go-llvm.cpp
@@ -2968,6 +2968,21 @@
   return std::make_pair(inst, curblock);
 }
 
+static bool isNoReturnCall(llvm::Instruction *inst)
+{
+  llvm::Function *func = nullptr;
+  if (llvm::isa<llvm::CallInst>(inst)) {
+    llvm::CallInst *call = llvm::cast<llvm::CallInst>(inst);
+    func = call->getCalledFunction();
+  } else if (llvm::isa<llvm::InvokeInst>(inst)) {
+    llvm::InvokeInst *invoke = llvm::cast<llvm::InvokeInst>(inst);
+    func = invoke->getCalledFunction();
+  }
+  if (func != nullptr && func->hasFnAttribute(llvm::Attribute::NoReturn))
+    return true;
+  return false;
+}
+
 llvm::BasicBlock *GenBlocks::walkExpr(llvm::BasicBlock *curblock,
                                       Bstatement *containingStmt,
                                       Bexpression *expr)
@@ -2987,14 +3002,13 @@
   if (!curblock)
     be_->nodeBuilder().destroy(expr, DelInstructions, false);
 
-  // Now visit instructions for this expr
-  // TODO: currently the control flow won't change from
-  // live to dead in this loop. Handle it, especially
-  // deallocate part of the instruction list, if it
-  // becomes necessary.
+  // Now visit instructions for this expr. Note: if as part of this loop a
+  // no-return call is encountered, we'll wind up changing from live code to
+  // dead code; handle this case appropriately.
   bool changed = false;
   std::vector<llvm::Instruction*> newinsts;
   for (auto originst : expr->instructions()) {
+    llvm::BasicBlock *origblock = curblock;
     auto pair = postProcessInst(originst, curblock);
     auto inst = pair.first;
     if (inst != originst)
@@ -3004,6 +3018,18 @@
     curblock->getInstList().push_back(inst);
     curblock = pair.second;
     newinsts.push_back(inst);
+
+    // Check for no-return call
+    if (isNoReturnCall(inst)) {
+      // Insert 'unreachable' inst into current block, then end
+      // current block.
+      LIRBuilder builder(context_, llvm::ConstantFolder());
+      llvm::Instruction *unreachable = builder.CreateUnreachable();
+      curblock->getInstList().push_back(unreachable);
+      curblock = nullptr;
+      changed = true;
+      break;
+    }
   }
   if (changed)
     be_->nodeBuilder().updateInstructions(expr, newinsts);
diff --git a/unittests/BackendCore/BackendCallTests.cpp b/unittests/BackendCore/BackendCallTests.cpp
index 687dde1..9cd6557 100644
--- a/unittests/BackendCore/BackendCallTests.cpp
+++ b/unittests/BackendCore/BackendCallTests.cpp
@@ -166,4 +166,62 @@
   EXPECT_FALSE(broken && "Module failed to verify.");
 }
 
+TEST(BackendCallTests, CallToNoReturnFunction) {
+
+  FcnTestHarness h;
+  Llvm_backend *be = h.be();
+  BFunctionType *befty = mkFuncTyp(be, L_END);
+  Bfunction *func = h.mkFunction("foo", befty);
+  Location loc;
+
+  // Declare a function 'noret' with no args and no return.
+  bool is_decl = true; bool is_inl = false;
+  bool is_vis = true; bool is_split = true;
+  bool is_noret = true; bool is_uniqsec = false;
+  Bfunction *nrfcn = be->function(befty, "noret", "noret",
+                                  is_vis, is_decl, is_inl, is_split,
+                                  is_noret, is_uniqsec, loc);
+
+  // Create a block containing two no-return calls. The intent here is to make
+  // sure that the bridge detects and deletes instructions appearing downstream
+  // of a no-return call.
+  std::vector<Bexpression *> args;
+  Bexpression *fn1 = be->function_code_expression(nrfcn, loc);
+  Bexpression *call1 = be->call_expression(nrfcn, fn1, args, nullptr, loc);
+  Bstatement *cs1 = h.mkExprStmt(call1, FcnTestHarness::NoAppend);
+  Bblock *nrcblock = mkBlockFromStmt(be, func, cs1);
+  Bexpression *fn2 = be->function_code_expression(nrfcn, loc);
+  Bexpression *call2 = be->call_expression(nrfcn, fn2, args, nullptr, loc);
+  Bstatement *cs2 = h.mkExprStmt(call2, FcnTestHarness::NoAppend);
+  addStmtToBlock(be, nrcblock, cs2);
+
+  // Create an "if" statement branching to the block above
+  Bexpression *cond = be->boolean_constant_expression(true);
+  h.mkIf(cond, be->block_statement(nrcblock), nullptr,
+         FcnTestHarness::YesAppend);
+
+  const char *exp = R"RAW_RESULT(
+    define void @foo(i8* nest %nest.0) #0 {
+    entry:
+      br i1 true, label %then.0, label %else.0
+
+    then.0:                                           ; preds = %entry
+      call void @noret(i8* nest undef)
+      unreachable
+
+    fallthrough.0:                                    ; preds = %else.0
+      ret void
+
+    else.0:                                           ; preds = %entry
+      br label %fallthrough.0
+    }
+    )RAW_RESULT";
+
+  bool broken = h.finish(StripDebugInfo);
+  EXPECT_FALSE(broken && "Module failed to verify.");
+
+  bool isOK = h.expectValue(func->function(), exp);
+  EXPECT_TRUE(isOK && "Function does not have expected contents");
+}
+
 }
diff --git a/unittests/BackendCore/BackendStmtTests.cpp b/unittests/BackendCore/BackendStmtTests.cpp
index 89550d3..e08f7ba 100644
--- a/unittests/BackendCore/BackendStmtTests.cpp
+++ b/unittests/BackendCore/BackendStmtTests.cpp
@@ -600,6 +600,10 @@
   Bvariable *loc1 = h.mkLocal("x", bi64t);
   Btype *bi8t = be->integer_type(false, 8);
   Bvariable *loc2 = h.mkLocal("y", bi8t);
+  Btype *s2t = mkBackendStruct(be, bi8t, "f1", bi8t, "f2", nullptr);
+  Bvariable *loc3 = h.mkLocal("z", s2t);
+  BFunctionType *beftynr = mkFuncTyp(be, L_RES, s2t, L_END);
+
   BFunctionType *befty2 = mkFuncTyp(be,
                                     L_PARM, bi64t,
                                     L_RES, bi64t,
@@ -609,8 +613,8 @@
   bool is_vis = true; bool is_split = true;
   bool is_noret = false; bool is_uniqsec = false;
   const char *fnames[] = { "plark", "plix" };
-  Bfunction *fcns[3];
-  Bexpression *calls[3];
+  Bfunction *fcns[4];
+  Bexpression *calls[5];
   for (unsigned ii = 0; ii < 2; ++ii)  {
     fcns[ii] = be->function(befty, fnames[ii], fnames[ii],
                             is_vis, is_decl, is_inl, is_split,
@@ -628,16 +632,33 @@
   iargs.push_back(mkInt64Const(be, 99));
   calls[2] = be->call_expression(func, idfn, iargs,
                                  nullptr, h.newloc());
+  fcns[3] = be->function(beftynr, "noret", "noret",
+                         is_vis, is_decl, is_inl, is_split,
+                         true, is_uniqsec, h.newloc());
+  for (unsigned ii = 0; ii < 2; ++ii)  {
+    Bexpression *nrfn = be->function_code_expression(fcns[3], h.newloc());
+    std::vector<Bexpression *> noargs;
+    calls[3+ii] = be->call_expression(func, nrfn, noargs,
+                                      nullptr, h.newloc());
+  }
 
   // body:
   // x = id(99)
   // plark()
+  // if false { y = noret(); y = noret() }
   // x = 123
   Bexpression *ve1 = be->var_expression(loc1, h.newloc());
   Bstatement *as1 =
       be->assignment_statement(func, ve1, calls[2], h.newloc());
   Bblock *bb1 = mkBlockFromStmt(be, func, as1);
   addStmtToBlock(be, bb1, h.mkExprStmt(calls[0], FcnTestHarness::NoAppend));
+  Bstatement *nrcs = h.mkExprStmt(calls[3], FcnTestHarness::NoAppend);
+  Bstatement *nrcs2 = h.mkExprStmt(calls[4], FcnTestHarness::NoAppend);
+  Bblock *bbnr = mkBlockFromStmt(be, func, nrcs);
+  addStmtToBlock(be, bbnr, nrcs2);
+  Bexpression *cond = be->boolean_constant_expression(false);
+  Bstatement *ifst = h.mkIf(cond, bbnr, nullptr, FcnTestHarness::NoAppend);
+  addStmtToBlock(be, bb1, ifst);
   Bexpression *ve2 = be->var_expression(loc1, h.newloc());
   Bstatement *as2 =
       be->assignment_statement(func, ve2, mkInt64Const(be, 123), h.newloc());
@@ -667,20 +688,26 @@
   %ehtmp.0 = alloca { i8*, i32 }
   %x = alloca i64
   %y = alloca i8
+  %z = alloca { i8, i8 }
+  %sret.actual.0 = alloca { i8, i8 }
+  %sret.actual.1 = alloca { i8, i8 }
   %finvar.0 = alloca i8
   store i64 0, i64* %x
   store i8 0, i8* %y
+  %cast.0 = bitcast { i8, i8 }* %z to i8*
+  %cast.1 = bitcast { i8, i8 }* @const.0 to i8*
+  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 %cast.0, i8* align 1 %cast.1, i64 2, i1 false)
   %call.0 = invoke i64 @id(i8* nest undef, i64 99)
           to label %cont.1 unwind label %pad.1
 
-finok.0:                                          ; preds = %cont.3
+finok.0:                                          ; preds = %cont.4
   store i8 1, i8* %finvar.0
   br label %finally.0
 
 finally.0:                                        ; preds = %catchpad.0, %finok.0
   br label %finish.0
 
-pad.0:                                            ; preds = %cont.2, %finish.0
+pad.0:                                            ; preds = %fallthrough.0, %finish.0
   %ex.0 = landingpad { i8*, i32 }
           catch i8* null
   br label %catch.0
@@ -693,19 +720,19 @@
   invoke void @deferreturn(i8* nest undef, i8* %y)
           to label %cont.0 unwind label %pad.0
 
-cont.0:                                           ; preds = %cont.2, %finish.0
+cont.0:                                           ; preds = %fallthrough.0, %finish.0
   %fload.0 = load i8, i8* %finvar.0
   %icmp.0 = icmp eq i8 %fload.0, 1
   br i1 %icmp.0, label %finret.0, label %finres.0
 
-pad.1:                                            ; preds = %cont.1, %entry
+pad.1:                                            ; preds = %then.0, %cont.1, %entry
   %ex.1 = landingpad { i8*, i32 }
           catch i8* null
   br label %catch.1
 
 catch.1:                                          ; preds = %pad.1
   invoke void @plix(i8* nest undef)
-          to label %cont.3 unwind label %catchpad.0
+          to label %cont.4 unwind label %catchpad.0
 
 catchpad.0:                                       ; preds = %catch.1
   %ex2.0 = landingpad { i8*, i32 }
@@ -720,12 +747,25 @@
           to label %cont.2 unwind label %pad.1
 
 cont.2:                                           ; preds = %cont.1
+  br i1 false, label %then.0, label %else.0
+
+then.0:                                           ; preds = %cont.2
+  %call.1 = invoke i16 @noret(i8* nest undef)
+          to label %cont.3 unwind label %pad.1
+
+fallthrough.0:                                    ; preds = %else.0
   store i64 123, i64* %x
   store i8 1, i8* %finvar.0
   invoke void @deferreturn(i8* nest undef, i8* %y)
           to label %cont.0 unwind label %pad.0
 
-cont.3:                                           ; preds = %catch.1
+else.0:                                           ; preds = %cont.2
+  br label %fallthrough.0
+
+cont.3:                                           ; preds = %then.0
+  unreachable
+
+cont.4:                                           ; preds = %catch.1
   br label %finok.0
 
 finres.0:                                         ; preds = %cont.0