windows/svc: use NtQuerySystemInformation in IsWindowsService
This brings the algorithm more exactly in line with what .NET does for
the identically named function. Specifically, instead of using
OpenProcess, which requires rights that restricted services might not
have, we use NtQuerySystemInformation(SYSTEM_PROCESS_INFORMATION) to
find the parent process image name and session ID.
Fixes golang/go#44921.
Change-Id: Ie2ad7521cf4c530037d086e61dbc2413e4e7777c
Reviewed-on: https://go-review.googlesource.com/c/sys/+/372554
Trust: Jason Donenfeld <Jason@zx2c4.com>
Run-TryBot: Jason Donenfeld <Jason@zx2c4.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Patrik Nyblom <pnyb@google.com>
Trust: Patrik Nyblom <pnyb@google.com>
Run-TryBot: Patrik Nyblom <pnyb@google.com>
diff --git a/windows/svc/security.go b/windows/svc/security.go
index 351d286..1c51006 100644
--- a/windows/svc/security.go
+++ b/windows/svc/security.go
@@ -8,7 +8,6 @@
package svc
import (
- "path/filepath"
"strings"
"unsafe"
@@ -74,36 +73,29 @@
// Specifically, it looks up whether the parent process has session ID zero
// and is called "services".
- var pbi windows.PROCESS_BASIC_INFORMATION
- pbiLen := uint32(unsafe.Sizeof(pbi))
- err := windows.NtQueryInformationProcess(windows.CurrentProcess(), windows.ProcessBasicInformation, unsafe.Pointer(&pbi), pbiLen, &pbiLen)
+ var currentProcess windows.PROCESS_BASIC_INFORMATION
+ infoSize := uint32(unsafe.Sizeof(currentProcess))
+ err := windows.NtQueryInformationProcess(windows.CurrentProcess(), windows.ProcessBasicInformation, unsafe.Pointer(¤tProcess), infoSize, &infoSize)
if err != nil {
return false, err
}
- var psid uint32
- err = windows.ProcessIdToSessionId(uint32(pbi.InheritedFromUniqueProcessId), &psid)
- if err != nil || psid != 0 {
- return false, nil
+ var parentProcess *windows.SYSTEM_PROCESS_INFORMATION
+ for infoSize = uint32((unsafe.Sizeof(*parentProcess) + unsafe.Sizeof(uintptr(0))) * 1024); ; {
+ parentProcess = (*windows.SYSTEM_PROCESS_INFORMATION)(unsafe.Pointer(&make([]byte, infoSize)[0]))
+ err = windows.NtQuerySystemInformation(windows.SystemProcessInformation, unsafe.Pointer(parentProcess), infoSize, &infoSize)
+ if err == nil {
+ break
+ } else if err != windows.STATUS_INFO_LENGTH_MISMATCH {
+ return false, err
+ }
}
- pproc, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, uint32(pbi.InheritedFromUniqueProcessId))
- if err != nil {
- return false, err
+ for ; ; parentProcess = (*windows.SYSTEM_PROCESS_INFORMATION)(unsafe.Pointer(uintptr(unsafe.Pointer(parentProcess)) + uintptr(parentProcess.NextEntryOffset))) {
+ if parentProcess.UniqueProcessID == currentProcess.InheritedFromUniqueProcessId {
+ return parentProcess.SessionID == 0 && strings.EqualFold("services.exe", parentProcess.ImageName.String()), nil
+ }
+ if parentProcess.NextEntryOffset == 0 {
+ break
+ }
}
- defer windows.CloseHandle(pproc)
- var exeNameBuf [261]uint16
- exeNameLen := uint32(len(exeNameBuf) - 1)
- err = windows.QueryFullProcessImageName(pproc, 0, &exeNameBuf[0], &exeNameLen)
- if err != nil {
- return false, err
- }
- exeName := windows.UTF16ToString(exeNameBuf[:exeNameLen])
- if !strings.EqualFold(filepath.Base(exeName), "services.exe") {
- return false, nil
- }
- system32, err := windows.GetSystemDirectory()
- if err != nil {
- return false, err
- }
- targetExeName := filepath.Join(system32, "services.exe")
- return strings.EqualFold(exeName, targetExeName), nil
+ return false, nil
}
diff --git a/windows/types_windows.go b/windows/types_windows.go
index 655447e..bb31abd 100644
--- a/windows/types_windows.go
+++ b/windows/types_windows.go
@@ -2749,6 +2749,43 @@
InheritedFromUniqueProcessId uintptr
}
+type SYSTEM_PROCESS_INFORMATION struct {
+ NextEntryOffset uint32
+ NumberOfThreads uint32
+ WorkingSetPrivateSize int64
+ HardFaultCount uint32
+ NumberOfThreadsHighWatermark uint32
+ CycleTime uint64
+ CreateTime int64
+ UserTime int64
+ KernelTime int64
+ ImageName NTUnicodeString
+ BasePriority int32
+ UniqueProcessID uintptr
+ InheritedFromUniqueProcessID uintptr
+ HandleCount uint32
+ SessionID uint32
+ UniqueProcessKey *uint32
+ PeakVirtualSize uintptr
+ VirtualSize uintptr
+ PageFaultCount uint32
+ PeakWorkingSetSize uintptr
+ WorkingSetSize uintptr
+ QuotaPeakPagedPoolUsage uintptr
+ QuotaPagedPoolUsage uintptr
+ QuotaPeakNonPagedPoolUsage uintptr
+ QuotaNonPagedPoolUsage uintptr
+ PagefileUsage uintptr
+ PeakPagefileUsage uintptr
+ PrivatePageCount uintptr
+ ReadOperationCount int64
+ WriteOperationCount int64
+ OtherOperationCount int64
+ ReadTransferCount int64
+ WriteTransferCount int64
+ OtherTransferCount int64
+}
+
// SystemInformationClasses for NtQuerySystemInformation and NtSetSystemInformation
const (
SystemBasicInformation = iota