windows: add AddDllDirectory and RemoveDllDirectory

Per https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-adddlldirectory
and https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-removedlldirectory.

Change-Id: If44a3758720345d1bbd9af96ec2481fbe9398a08
Reviewed-on: https://go-review.googlesource.com/c/sys/+/537755
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
Auto-Submit: Roland Shoemaker <roland@golang.org>
Reviewed-by: Alex Brainman <alex.brainman@gmail.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/windows/syscall_windows.go b/windows/syscall_windows.go
index fb6cfd0..47dc579 100644
--- a/windows/syscall_windows.go
+++ b/windows/syscall_windows.go
@@ -155,6 +155,8 @@
 //sys	GetModuleFileName(module Handle, filename *uint16, size uint32) (n uint32, err error) = kernel32.GetModuleFileNameW
 //sys	GetModuleHandleEx(flags uint32, moduleName *uint16, module *Handle) (err error) = kernel32.GetModuleHandleExW
 //sys	SetDefaultDllDirectories(directoryFlags uint32) (err error)
+//sys	AddDllDirectory(path *uint16) (cookie uintptr, err error) = kernel32.AddDllDirectory
+//sys	RemoveDllDirectory(cookie uintptr) (err error) = kernel32.RemoveDllDirectory
 //sys	SetDllDirectory(path string) (err error) = kernel32.SetDllDirectoryW
 //sys	GetVersion() (ver uint32, err error)
 //sys	FormatMessage(flags uint32, msgsrc uintptr, msgid uint32, langid uint32, buf []uint16, args *byte) (n uint32, err error) = FormatMessageW
diff --git a/windows/syscall_windows_test.go b/windows/syscall_windows_test.go
index dcc706d..6658379 100644
--- a/windows/syscall_windows_test.go
+++ b/windows/syscall_windows_test.go
@@ -11,6 +11,7 @@
 	"errors"
 	"fmt"
 	"os"
+	"os/exec"
 	"path/filepath"
 	"runtime"
 	"strconv"
@@ -1222,3 +1223,55 @@
 		t.Fatalf("GetStartupInfo: got error %v, want nil", err)
 	}
 }
+
+func TestAddRemoveDllDirectory(t *testing.T) {
+	if _, err := exec.LookPath("gcc"); err != nil {
+		t.Skip("skipping test: gcc is missing")
+	}
+	dllSrc := `#include <stdint.h>
+#include <windows.h>
+
+uintptr_t beep(void) {
+   return 5;
+}`
+	tmpdir := t.TempDir()
+	srcname := "beep.c"
+	err := os.WriteFile(filepath.Join(tmpdir, srcname), []byte(dllSrc), 0)
+	if err != nil {
+		t.Fatal(err)
+	}
+	name := "beep.dll"
+	cmd := exec.Command("gcc", "-shared", "-s", "-Werror", "-o", name, srcname)
+	cmd.Dir = tmpdir
+	out, err := cmd.CombinedOutput()
+	if err != nil {
+		t.Fatalf("failed to build dll: %v - %v", err, string(out))
+	}
+
+	if _, err := windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS); err == nil {
+		t.Fatal("LoadLibraryEx unexpectedly found beep.dll")
+	}
+
+	dllCookie, err := windows.AddDllDirectory(windows.StringToUTF16Ptr(tmpdir))
+	if err != nil {
+		t.Fatalf("AddDllDirectory failed: %s", err)
+	}
+
+	handle, err := windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS)
+	if err != nil {
+		t.Fatalf("LoadLibraryEx failed: %s", err)
+	}
+
+	if err := windows.FreeLibrary(handle); err != nil {
+		t.Fatalf("FreeLibrary failed: %s", err)
+	}
+
+	if err := windows.RemoveDllDirectory(dllCookie); err != nil {
+		t.Fatalf("RemoveDllDirectory failed: %s", err)
+	}
+
+	_, err = windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS)
+	if err == nil {
+		t.Fatal("LoadLibraryEx unexpectedly found beep.dll")
+	}
+}
diff --git a/windows/zsyscall_windows.go b/windows/zsyscall_windows.go
index db6282e..146a1f0 100644
--- a/windows/zsyscall_windows.go
+++ b/windows/zsyscall_windows.go
@@ -184,6 +184,7 @@
 	procGetAdaptersInfo                                      = modiphlpapi.NewProc("GetAdaptersInfo")
 	procGetBestInterfaceEx                                   = modiphlpapi.NewProc("GetBestInterfaceEx")
 	procGetIfEntry                                           = modiphlpapi.NewProc("GetIfEntry")
+	procAddDllDirectory                                      = modkernel32.NewProc("AddDllDirectory")
 	procAssignProcessToJobObject                             = modkernel32.NewProc("AssignProcessToJobObject")
 	procCancelIo                                             = modkernel32.NewProc("CancelIo")
 	procCancelIoEx                                           = modkernel32.NewProc("CancelIoEx")
@@ -330,6 +331,7 @@
 	procReadProcessMemory                                    = modkernel32.NewProc("ReadProcessMemory")
 	procReleaseMutex                                         = modkernel32.NewProc("ReleaseMutex")
 	procRemoveDirectoryW                                     = modkernel32.NewProc("RemoveDirectoryW")
+	procRemoveDllDirectory                                   = modkernel32.NewProc("RemoveDllDirectory")
 	procResetEvent                                           = modkernel32.NewProc("ResetEvent")
 	procResizePseudoConsole                                  = modkernel32.NewProc("ResizePseudoConsole")
 	procResumeThread                                         = modkernel32.NewProc("ResumeThread")
@@ -1605,6 +1607,15 @@
 	return
 }
 
+func AddDllDirectory(path *uint16) (cookie uintptr, err error) {
+	r0, _, e1 := syscall.Syscall(procAddDllDirectory.Addr(), 1, uintptr(unsafe.Pointer(path)), 0, 0)
+	cookie = uintptr(r0)
+	if cookie == 0 {
+		err = errnoErr(e1)
+	}
+	return
+}
+
 func AssignProcessToJobObject(job Handle, process Handle) (err error) {
 	r1, _, e1 := syscall.Syscall(procAssignProcessToJobObject.Addr(), 2, uintptr(job), uintptr(process), 0)
 	if r1 == 0 {
@@ -2879,6 +2890,14 @@
 	return
 }
 
+func RemoveDllDirectory(cookie uintptr) (err error) {
+	r1, _, e1 := syscall.Syscall(procRemoveDllDirectory.Addr(), 1, uintptr(cookie), 0, 0)
+	if r1 == 0 {
+		err = errnoErr(e1)
+	}
+	return
+}
+
 func ResetEvent(event Handle) (err error) {
 	r1, _, e1 := syscall.Syscall(procResetEvent.Addr(), 1, uintptr(event), 0, 0)
 	if r1 == 0 {