windows: add DLL directory search path manipulation functions
A whole class of DLL hijacking attacks can be avoided with a dance like:
SetDllDirectory("")
SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_SYSTEM32)
For applications who want to opt into this better secure posture, this
commit adds the function definitions to do so.
Reference:
https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-setdlldirectorya
https://docs.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-setdefaultdlldirectories
Change-Id: I9b6d4e414a80a689b31b9b43a2d5c72de4813c39
Reviewed-on: https://go-review.googlesource.com/c/sys/+/273606
Run-TryBot: Jason A. Donenfeld <Jason@zx2c4.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Jason A. Donenfeld <Jason@zx2c4.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/windows/dll_windows.go b/windows/dll_windows.go
index 9cd147b..115341f 100644
--- a/windows/dll_windows.go
+++ b/windows/dll_windows.go
@@ -391,7 +391,6 @@
var flags uintptr
if system {
if canDoSearchSystem32() {
- const LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
flags = LOAD_LIBRARY_SEARCH_SYSTEM32
} else if isBaseName(name) {
// WindowsXP or unpatched Windows machine
diff --git a/windows/syscall_windows.go b/windows/syscall_windows.go
index 86a46f7..6c2ff27 100644
--- a/windows/syscall_windows.go
+++ b/windows/syscall_windows.go
@@ -170,6 +170,8 @@
//sys GetProcAddress(module Handle, procname string) (proc uintptr, err error)
//sys GetModuleFileName(module Handle, filename *uint16, size uint32) (n uint32, err error) = kernel32.GetModuleFileNameW
//sys GetModuleHandleEx(flags uint32, moduleName *uint16, module *Handle) (err error) = kernel32.GetModuleHandleExW
+//sys SetDefaultDllDirectories(directoryFlags uint32) (err error)
+//sys SetDllDirectory(path string) (err error) = kernel32.SetDllDirectoryW
//sys GetVersion() (ver uint32, err error)
//sys FormatMessage(flags uint32, msgsrc uintptr, msgid uint32, langid uint32, buf []uint16, args *byte) (n uint32, err error) = FormatMessageW
//sys ExitProcess(exitcode uint32)
diff --git a/windows/types_windows.go b/windows/types_windows.go
index e7ae37f..265d797 100644
--- a/windows/types_windows.go
+++ b/windows/types_windows.go
@@ -1801,3 +1801,22 @@
FileCaseSensitiveInfo = 23
FileNormalizedNameInfo = 24
)
+
+// LoadLibrary flags for determining from where to search for a DLL
+const (
+ DONT_RESOLVE_DLL_REFERENCES = 0x1
+ LOAD_LIBRARY_AS_DATAFILE = 0x2
+ LOAD_WITH_ALTERED_SEARCH_PATH = 0x8
+ LOAD_IGNORE_CODE_AUTHZ_LEVEL = 0x10
+ LOAD_LIBRARY_AS_IMAGE_RESOURCE = 0x20
+ LOAD_LIBRARY_AS_DATAFILE_EXCLUSIVE = 0x40
+ LOAD_LIBRARY_REQUIRE_SIGNED_TARGET = 0x80
+ LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x100
+ LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x200
+ LOAD_LIBRARY_SEARCH_USER_DIRS = 0x400
+ LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x800
+ LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x1000
+ LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
+ LOAD_LIBRARY_SEARCH_SYSTEM32_NO_FORWARDER = 0x00004000
+ LOAD_LIBRARY_OS_INTEGRITY_CONTINUITY = 0x00008000
+)
diff --git a/windows/zsyscall_windows.go b/windows/zsyscall_windows.go
index dadd5a8..204b1db 100644
--- a/windows/zsyscall_windows.go
+++ b/windows/zsyscall_windows.go
@@ -280,6 +280,8 @@
procSetConsoleCursorPosition = modkernel32.NewProc("SetConsoleCursorPosition")
procSetConsoleMode = modkernel32.NewProc("SetConsoleMode")
procSetCurrentDirectoryW = modkernel32.NewProc("SetCurrentDirectoryW")
+ procSetDefaultDllDirectories = modkernel32.NewProc("SetDefaultDllDirectories")
+ procSetDllDirectoryW = modkernel32.NewProc("SetDllDirectoryW")
procSetEndOfFile = modkernel32.NewProc("SetEndOfFile")
procSetEnvironmentVariableW = modkernel32.NewProc("SetEnvironmentVariableW")
procSetErrorMode = modkernel32.NewProc("SetErrorMode")
@@ -2376,6 +2378,31 @@
return
}
+func SetDefaultDllDirectories(directoryFlags uint32) (err error) {
+ r1, _, e1 := syscall.Syscall(procSetDefaultDllDirectories.Addr(), 1, uintptr(directoryFlags), 0, 0)
+ if r1 == 0 {
+ err = errnoErr(e1)
+ }
+ return
+}
+
+func SetDllDirectory(path string) (err error) {
+ var _p0 *uint16
+ _p0, err = syscall.UTF16PtrFromString(path)
+ if err != nil {
+ return
+ }
+ return _SetDllDirectory(_p0)
+}
+
+func _SetDllDirectory(path *uint16) (err error) {
+ r1, _, e1 := syscall.Syscall(procSetDllDirectoryW.Addr(), 1, uintptr(unsafe.Pointer(path)), 0, 0)
+ if r1 == 0 {
+ err = errnoErr(e1)
+ }
+ return
+}
+
func SetEndOfFile(handle Handle) (err error) {
r1, _, e1 := syscall.Syscall(procSetEndOfFile.Addr(), 1, uintptr(handle), 0, 0)
if r1 == 0 {