windows: use proper system directory path in fallback loader
The %WINDIR% variable is an odd choice and not even entirely reliable.
Since Windows 2000, there has been a specific function for determining
this information, so let's use it. It's also a useful function in its
own right for folks who want to launch system tools in a somewhat safe
way, like netsh.exe.
Updates golang/go#14959
Updates golang/go#30642
Change-Id: Ic24baf37d14f2daced0c1db2771b5a673d2c8852
Reviewed-on: https://go-review.googlesource.com/c/sys/+/165759
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Alex Brainman <alex.brainman@gmail.com>
diff --git a/windows/dll_windows.go b/windows/dll_windows.go
index e92c05b..ba67658 100644
--- a/windows/dll_windows.go
+++ b/windows/dll_windows.go
@@ -359,11 +359,11 @@
// trying to load "foo.dll" out of the system
// folder, but LoadLibraryEx doesn't support
// that yet on their system, so emulate it.
- windir, _ := Getenv("WINDIR") // old var; apparently works on XP
- if windir == "" {
- return nil, errString("%WINDIR% not defined")
+ systemdir, err := GetSystemDirectory()
+ if err != nil {
+ return nil, err
}
- loadDLL = windir + "\\System32\\" + name
+ loadDLL = systemdir + "\\" + name
}
}
h, err := LoadLibraryEx(loadDLL, 0, flags)
diff --git a/windows/security_windows.go b/windows/security_windows.go
index 9f946da..f5f2d8b 100644
--- a/windows/security_windows.go
+++ b/windows/security_windows.go
@@ -372,6 +372,7 @@
//sys OpenProcessToken(h Handle, access uint32, token *Token) (err error) = advapi32.OpenProcessToken
//sys GetTokenInformation(t Token, infoClass uint32, info *byte, infoLen uint32, returnedLen *uint32) (err error) = advapi32.GetTokenInformation
//sys GetUserProfileDirectory(t Token, dir *uint16, dirLen *uint32) (err error) = userenv.GetUserProfileDirectoryW
+//sys getSystemDirectory(dir *uint16, dirLen uint32) (len uint32, err error) = kernel32.GetSystemDirectoryW
// An access token contains the security information for a logon session.
// The system creates an access token when a user logs on, and every
@@ -468,6 +469,23 @@
}
}
+// GetSystemDirectory retrieves path to current location of the system
+// directory, which is typically, though not always, C:\Windows\System32.
+func GetSystemDirectory() (string, error) {
+ n := uint32(MAX_PATH)
+ for {
+ b := make([]uint16, n)
+ l, e := getSystemDirectory(&b[0], n)
+ if e != nil {
+ return "", e
+ }
+ if l <= n {
+ return UTF16ToString(b[:l]), nil
+ }
+ n = l
+ }
+}
+
// IsMember reports whether the access token t is a member of the provided SID.
func (t Token) IsMember(sid *SID) (bool, error) {
var b int32
diff --git a/windows/syscall_test.go b/windows/syscall_test.go
index d7009e4..f09c6dd 100644
--- a/windows/syscall_test.go
+++ b/windows/syscall_test.go
@@ -7,6 +7,7 @@
package windows_test
import (
+ "strings"
"syscall"
"testing"
@@ -51,3 +52,13 @@
t.Error("shlwapi.dll:IsOS(OS_NT) returned 0, expected non-zero value")
}
}
+
+func TestGetSystemDirectory(t *testing.T) {
+ d, err := windows.GetSystemDirectory()
+ if err != nil {
+ t.Fatalf("Failed to get system directory: %s", err)
+ }
+ if !strings.HasSuffix(strings.ToLower(d), "\\system32") {
+ t.Fatalf("System directory does not end in system32: %s", d)
+ }
+}
diff --git a/windows/zsyscall_windows.go b/windows/zsyscall_windows.go
index e4b54e2..308c97e 100644
--- a/windows/zsyscall_windows.go
+++ b/windows/zsyscall_windows.go
@@ -252,6 +252,7 @@
procOpenProcessToken = modadvapi32.NewProc("OpenProcessToken")
procGetTokenInformation = modadvapi32.NewProc("GetTokenInformation")
procGetUserProfileDirectoryW = moduserenv.NewProc("GetUserProfileDirectoryW")
+ procGetSystemDirectoryW = modkernel32.NewProc("GetSystemDirectoryW")
)
func RegisterEventSource(uncServerName *uint16, sourceName *uint16) (handle Handle, err error) {
@@ -2718,3 +2719,16 @@
}
return
}
+
+func getSystemDirectory(dir *uint16, dirLen uint32) (len uint32, err error) {
+ r0, _, e1 := syscall.Syscall(procGetSystemDirectoryW.Addr(), 2, uintptr(unsafe.Pointer(dir)), uintptr(dirLen), 0)
+ len = uint32(r0)
+ if len == 0 {
+ if e1 != 0 {
+ err = errnoErr(e1)
+ } else {
+ err = syscall.EINVAL
+ }
+ }
+ return
+}