windows: add DLL directory search path manipulation functions

A whole class of DLL hijacking attacks can be avoided with a dance like:

    SetDllDirectory("")
    SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_SYSTEM32)

For applications who want to opt into this better secure posture, this
commit adds the function definitions to do so.

Reference:
https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-setdlldirectorya
https://docs.microsoft.com/en-us/windows/win32/api/libloaderapi/nf-libloaderapi-setdefaultdlldirectories

Change-Id: I9b6d4e414a80a689b31b9b43a2d5c72de4813c39
Reviewed-on: https://go-review.googlesource.com/c/sys/+/273606
Run-TryBot: Jason A. Donenfeld <Jason@zx2c4.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Jason A. Donenfeld <Jason@zx2c4.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/windows/dll_windows.go b/windows/dll_windows.go
index 9cd147b..115341f 100644
--- a/windows/dll_windows.go
+++ b/windows/dll_windows.go
@@ -391,7 +391,6 @@
 	var flags uintptr
 	if system {
 		if canDoSearchSystem32() {
-			const LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
 			flags = LOAD_LIBRARY_SEARCH_SYSTEM32
 		} else if isBaseName(name) {
 			// WindowsXP or unpatched Windows machine
diff --git a/windows/syscall_windows.go b/windows/syscall_windows.go
index 86a46f7..6c2ff27 100644
--- a/windows/syscall_windows.go
+++ b/windows/syscall_windows.go
@@ -170,6 +170,8 @@
 //sys	GetProcAddress(module Handle, procname string) (proc uintptr, err error)
 //sys	GetModuleFileName(module Handle, filename *uint16, size uint32) (n uint32, err error) = kernel32.GetModuleFileNameW
 //sys	GetModuleHandleEx(flags uint32, moduleName *uint16, module *Handle) (err error) = kernel32.GetModuleHandleExW
+//sys	SetDefaultDllDirectories(directoryFlags uint32) (err error)
+//sys	SetDllDirectory(path string) (err error) = kernel32.SetDllDirectoryW
 //sys	GetVersion() (ver uint32, err error)
 //sys	FormatMessage(flags uint32, msgsrc uintptr, msgid uint32, langid uint32, buf []uint16, args *byte) (n uint32, err error) = FormatMessageW
 //sys	ExitProcess(exitcode uint32)
diff --git a/windows/types_windows.go b/windows/types_windows.go
index e7ae37f..265d797 100644
--- a/windows/types_windows.go
+++ b/windows/types_windows.go
@@ -1801,3 +1801,22 @@
 	FileCaseSensitiveInfo          = 23
 	FileNormalizedNameInfo         = 24
 )
+
+// LoadLibrary flags for determining from where to search for a DLL
+const (
+	DONT_RESOLVE_DLL_REFERENCES               = 0x1
+	LOAD_LIBRARY_AS_DATAFILE                  = 0x2
+	LOAD_WITH_ALTERED_SEARCH_PATH             = 0x8
+	LOAD_IGNORE_CODE_AUTHZ_LEVEL              = 0x10
+	LOAD_LIBRARY_AS_IMAGE_RESOURCE            = 0x20
+	LOAD_LIBRARY_AS_DATAFILE_EXCLUSIVE        = 0x40
+	LOAD_LIBRARY_REQUIRE_SIGNED_TARGET        = 0x80
+	LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR          = 0x100
+	LOAD_LIBRARY_SEARCH_APPLICATION_DIR       = 0x200
+	LOAD_LIBRARY_SEARCH_USER_DIRS             = 0x400
+	LOAD_LIBRARY_SEARCH_SYSTEM32              = 0x800
+	LOAD_LIBRARY_SEARCH_DEFAULT_DIRS          = 0x1000
+	LOAD_LIBRARY_SAFE_CURRENT_DIRS            = 0x00002000
+	LOAD_LIBRARY_SEARCH_SYSTEM32_NO_FORWARDER = 0x00004000
+	LOAD_LIBRARY_OS_INTEGRITY_CONTINUITY      = 0x00008000
+)
diff --git a/windows/zsyscall_windows.go b/windows/zsyscall_windows.go
index dadd5a8..204b1db 100644
--- a/windows/zsyscall_windows.go
+++ b/windows/zsyscall_windows.go
@@ -280,6 +280,8 @@
 	procSetConsoleCursorPosition                             = modkernel32.NewProc("SetConsoleCursorPosition")
 	procSetConsoleMode                                       = modkernel32.NewProc("SetConsoleMode")
 	procSetCurrentDirectoryW                                 = modkernel32.NewProc("SetCurrentDirectoryW")
+	procSetDefaultDllDirectories                             = modkernel32.NewProc("SetDefaultDllDirectories")
+	procSetDllDirectoryW                                     = modkernel32.NewProc("SetDllDirectoryW")
 	procSetEndOfFile                                         = modkernel32.NewProc("SetEndOfFile")
 	procSetEnvironmentVariableW                              = modkernel32.NewProc("SetEnvironmentVariableW")
 	procSetErrorMode                                         = modkernel32.NewProc("SetErrorMode")
@@ -2376,6 +2378,31 @@
 	return
 }
 
+func SetDefaultDllDirectories(directoryFlags uint32) (err error) {
+	r1, _, e1 := syscall.Syscall(procSetDefaultDllDirectories.Addr(), 1, uintptr(directoryFlags), 0, 0)
+	if r1 == 0 {
+		err = errnoErr(e1)
+	}
+	return
+}
+
+func SetDllDirectory(path string) (err error) {
+	var _p0 *uint16
+	_p0, err = syscall.UTF16PtrFromString(path)
+	if err != nil {
+		return
+	}
+	return _SetDllDirectory(_p0)
+}
+
+func _SetDllDirectory(path *uint16) (err error) {
+	r1, _, e1 := syscall.Syscall(procSetDllDirectoryW.Addr(), 1, uintptr(unsafe.Pointer(path)), 0, 0)
+	if r1 == 0 {
+		err = errnoErr(e1)
+	}
+	return
+}
+
 func SetEndOfFile(handle Handle) (err error) {
 	r1, _, e1 := syscall.Syscall(procSetEndOfFile.Addr(), 1, uintptr(handle), 0, 0)
 	if r1 == 0 {