windows: add AddDllDirectory and RemoveDllDirectory
Per https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-adddlldirectory
and https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-removedlldirectory.
Change-Id: If44a3758720345d1bbd9af96ec2481fbe9398a08
Reviewed-on: https://go-review.googlesource.com/c/sys/+/537755
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
Auto-Submit: Roland Shoemaker <roland@golang.org>
Reviewed-by: Alex Brainman <alex.brainman@gmail.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/windows/syscall_windows.go b/windows/syscall_windows.go
index fb6cfd0..47dc579 100644
--- a/windows/syscall_windows.go
+++ b/windows/syscall_windows.go
@@ -155,6 +155,8 @@
//sys GetModuleFileName(module Handle, filename *uint16, size uint32) (n uint32, err error) = kernel32.GetModuleFileNameW
//sys GetModuleHandleEx(flags uint32, moduleName *uint16, module *Handle) (err error) = kernel32.GetModuleHandleExW
//sys SetDefaultDllDirectories(directoryFlags uint32) (err error)
+//sys AddDllDirectory(path *uint16) (cookie uintptr, err error) = kernel32.AddDllDirectory
+//sys RemoveDllDirectory(cookie uintptr) (err error) = kernel32.RemoveDllDirectory
//sys SetDllDirectory(path string) (err error) = kernel32.SetDllDirectoryW
//sys GetVersion() (ver uint32, err error)
//sys FormatMessage(flags uint32, msgsrc uintptr, msgid uint32, langid uint32, buf []uint16, args *byte) (n uint32, err error) = FormatMessageW
diff --git a/windows/syscall_windows_test.go b/windows/syscall_windows_test.go
index dcc706d..6658379 100644
--- a/windows/syscall_windows_test.go
+++ b/windows/syscall_windows_test.go
@@ -11,6 +11,7 @@
"errors"
"fmt"
"os"
+ "os/exec"
"path/filepath"
"runtime"
"strconv"
@@ -1222,3 +1223,55 @@
t.Fatalf("GetStartupInfo: got error %v, want nil", err)
}
}
+
+func TestAddRemoveDllDirectory(t *testing.T) {
+ if _, err := exec.LookPath("gcc"); err != nil {
+ t.Skip("skipping test: gcc is missing")
+ }
+ dllSrc := `#include <stdint.h>
+#include <windows.h>
+
+uintptr_t beep(void) {
+ return 5;
+}`
+ tmpdir := t.TempDir()
+ srcname := "beep.c"
+ err := os.WriteFile(filepath.Join(tmpdir, srcname), []byte(dllSrc), 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ name := "beep.dll"
+ cmd := exec.Command("gcc", "-shared", "-s", "-Werror", "-o", name, srcname)
+ cmd.Dir = tmpdir
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Fatalf("failed to build dll: %v - %v", err, string(out))
+ }
+
+ if _, err := windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS); err == nil {
+ t.Fatal("LoadLibraryEx unexpectedly found beep.dll")
+ }
+
+ dllCookie, err := windows.AddDllDirectory(windows.StringToUTF16Ptr(tmpdir))
+ if err != nil {
+ t.Fatalf("AddDllDirectory failed: %s", err)
+ }
+
+ handle, err := windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS)
+ if err != nil {
+ t.Fatalf("LoadLibraryEx failed: %s", err)
+ }
+
+ if err := windows.FreeLibrary(handle); err != nil {
+ t.Fatalf("FreeLibrary failed: %s", err)
+ }
+
+ if err := windows.RemoveDllDirectory(dllCookie); err != nil {
+ t.Fatalf("RemoveDllDirectory failed: %s", err)
+ }
+
+ _, err = windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS)
+ if err == nil {
+ t.Fatal("LoadLibraryEx unexpectedly found beep.dll")
+ }
+}
diff --git a/windows/zsyscall_windows.go b/windows/zsyscall_windows.go
index db6282e..146a1f0 100644
--- a/windows/zsyscall_windows.go
+++ b/windows/zsyscall_windows.go
@@ -184,6 +184,7 @@
procGetAdaptersInfo = modiphlpapi.NewProc("GetAdaptersInfo")
procGetBestInterfaceEx = modiphlpapi.NewProc("GetBestInterfaceEx")
procGetIfEntry = modiphlpapi.NewProc("GetIfEntry")
+ procAddDllDirectory = modkernel32.NewProc("AddDllDirectory")
procAssignProcessToJobObject = modkernel32.NewProc("AssignProcessToJobObject")
procCancelIo = modkernel32.NewProc("CancelIo")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
@@ -330,6 +331,7 @@
procReadProcessMemory = modkernel32.NewProc("ReadProcessMemory")
procReleaseMutex = modkernel32.NewProc("ReleaseMutex")
procRemoveDirectoryW = modkernel32.NewProc("RemoveDirectoryW")
+ procRemoveDllDirectory = modkernel32.NewProc("RemoveDllDirectory")
procResetEvent = modkernel32.NewProc("ResetEvent")
procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole")
procResumeThread = modkernel32.NewProc("ResumeThread")
@@ -1605,6 +1607,15 @@
return
}
+func AddDllDirectory(path *uint16) (cookie uintptr, err error) {
+ r0, _, e1 := syscall.Syscall(procAddDllDirectory.Addr(), 1, uintptr(unsafe.Pointer(path)), 0, 0)
+ cookie = uintptr(r0)
+ if cookie == 0 {
+ err = errnoErr(e1)
+ }
+ return
+}
+
func AssignProcessToJobObject(job Handle, process Handle) (err error) {
r1, _, e1 := syscall.Syscall(procAssignProcessToJobObject.Addr(), 2, uintptr(job), uintptr(process), 0)
if r1 == 0 {
@@ -2879,6 +2890,14 @@
return
}
+func RemoveDllDirectory(cookie uintptr) (err error) {
+ r1, _, e1 := syscall.Syscall(procRemoveDllDirectory.Addr(), 1, uintptr(cookie), 0, 0)
+ if r1 == 0 {
+ err = errnoErr(e1)
+ }
+ return
+}
+
func ResetEvent(event Handle) (err error) {
r1, _, e1 := syscall.Syscall(procResetEvent.Addr(), 1, uintptr(event), 0, 0)
if r1 == 0 {