windows: add AddDllDirectory and RemoveDllDirectory Per https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-adddlldirectory and https://learn.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-removedlldirectory. Change-Id: If44a3758720345d1bbd9af96ec2481fbe9398a08 Reviewed-on: https://go-review.googlesource.com/c/sys/+/537755 Reviewed-by: Tatiana Bradley <tatianabradley@google.com> Auto-Submit: Roland Shoemaker <roland@golang.org> Reviewed-by: Alex Brainman <alex.brainman@gmail.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
diff --git a/windows/syscall_windows.go b/windows/syscall_windows.go index fb6cfd0..47dc579 100644 --- a/windows/syscall_windows.go +++ b/windows/syscall_windows.go
@@ -155,6 +155,8 @@ //sys GetModuleFileName(module Handle, filename *uint16, size uint32) (n uint32, err error) = kernel32.GetModuleFileNameW //sys GetModuleHandleEx(flags uint32, moduleName *uint16, module *Handle) (err error) = kernel32.GetModuleHandleExW //sys SetDefaultDllDirectories(directoryFlags uint32) (err error) +//sys AddDllDirectory(path *uint16) (cookie uintptr, err error) = kernel32.AddDllDirectory +//sys RemoveDllDirectory(cookie uintptr) (err error) = kernel32.RemoveDllDirectory //sys SetDllDirectory(path string) (err error) = kernel32.SetDllDirectoryW //sys GetVersion() (ver uint32, err error) //sys FormatMessage(flags uint32, msgsrc uintptr, msgid uint32, langid uint32, buf []uint16, args *byte) (n uint32, err error) = FormatMessageW
diff --git a/windows/syscall_windows_test.go b/windows/syscall_windows_test.go index dcc706d..6658379 100644 --- a/windows/syscall_windows_test.go +++ b/windows/syscall_windows_test.go
@@ -11,6 +11,7 @@ "errors" "fmt" "os" + "os/exec" "path/filepath" "runtime" "strconv" @@ -1222,3 +1223,55 @@ t.Fatalf("GetStartupInfo: got error %v, want nil", err) } } + +func TestAddRemoveDllDirectory(t *testing.T) { + if _, err := exec.LookPath("gcc"); err != nil { + t.Skip("skipping test: gcc is missing") + } + dllSrc := `#include <stdint.h> +#include <windows.h> + +uintptr_t beep(void) { + return 5; +}` + tmpdir := t.TempDir() + srcname := "beep.c" + err := os.WriteFile(filepath.Join(tmpdir, srcname), []byte(dllSrc), 0) + if err != nil { + t.Fatal(err) + } + name := "beep.dll" + cmd := exec.Command("gcc", "-shared", "-s", "-Werror", "-o", name, srcname) + cmd.Dir = tmpdir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("failed to build dll: %v - %v", err, string(out)) + } + + if _, err := windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS); err == nil { + t.Fatal("LoadLibraryEx unexpectedly found beep.dll") + } + + dllCookie, err := windows.AddDllDirectory(windows.StringToUTF16Ptr(tmpdir)) + if err != nil { + t.Fatalf("AddDllDirectory failed: %s", err) + } + + handle, err := windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS) + if err != nil { + t.Fatalf("LoadLibraryEx failed: %s", err) + } + + if err := windows.FreeLibrary(handle); err != nil { + t.Fatalf("FreeLibrary failed: %s", err) + } + + if err := windows.RemoveDllDirectory(dllCookie); err != nil { + t.Fatalf("RemoveDllDirectory failed: %s", err) + } + + _, err = windows.LoadLibraryEx("beep.dll", 0, windows.LOAD_LIBRARY_SEARCH_USER_DIRS) + if err == nil { + t.Fatal("LoadLibraryEx unexpectedly found beep.dll") + } +}
diff --git a/windows/zsyscall_windows.go b/windows/zsyscall_windows.go index db6282e..146a1f0 100644 --- a/windows/zsyscall_windows.go +++ b/windows/zsyscall_windows.go
@@ -184,6 +184,7 @@ procGetAdaptersInfo = modiphlpapi.NewProc("GetAdaptersInfo") procGetBestInterfaceEx = modiphlpapi.NewProc("GetBestInterfaceEx") procGetIfEntry = modiphlpapi.NewProc("GetIfEntry") + procAddDllDirectory = modkernel32.NewProc("AddDllDirectory") procAssignProcessToJobObject = modkernel32.NewProc("AssignProcessToJobObject") procCancelIo = modkernel32.NewProc("CancelIo") procCancelIoEx = modkernel32.NewProc("CancelIoEx") @@ -330,6 +331,7 @@ procReadProcessMemory = modkernel32.NewProc("ReadProcessMemory") procReleaseMutex = modkernel32.NewProc("ReleaseMutex") procRemoveDirectoryW = modkernel32.NewProc("RemoveDirectoryW") + procRemoveDllDirectory = modkernel32.NewProc("RemoveDllDirectory") procResetEvent = modkernel32.NewProc("ResetEvent") procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole") procResumeThread = modkernel32.NewProc("ResumeThread") @@ -1605,6 +1607,15 @@ return } +func AddDllDirectory(path *uint16) (cookie uintptr, err error) { + r0, _, e1 := syscall.Syscall(procAddDllDirectory.Addr(), 1, uintptr(unsafe.Pointer(path)), 0, 0) + cookie = uintptr(r0) + if cookie == 0 { + err = errnoErr(e1) + } + return +} + func AssignProcessToJobObject(job Handle, process Handle) (err error) { r1, _, e1 := syscall.Syscall(procAssignProcessToJobObject.Addr(), 2, uintptr(job), uintptr(process), 0) if r1 == 0 { @@ -2879,6 +2890,14 @@ return } +func RemoveDllDirectory(cookie uintptr) (err error) { + r1, _, e1 := syscall.Syscall(procRemoveDllDirectory.Addr(), 1, uintptr(cookie), 0, 0) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + func ResetEvent(event Handle) (err error) { r1, _, e1 := syscall.Syscall(procResetEvent.Addr(), 1, uintptr(event), 0, 0) if r1 == 0 {