1
2
3
4 package winio
5
6 import (
7 "bytes"
8 "encoding/binary"
9 "fmt"
10 "runtime"
11 "sync"
12 "syscall"
13 "unicode/utf16"
14
15 "golang.org/x/sys/windows"
16 )
17
18
19
20
21
22
23
24
25
26
27 const (
28
29 SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED
30
31
32 ERROR_NOT_ALL_ASSIGNED syscall.Errno = windows.ERROR_NOT_ALL_ASSIGNED
33
34 SeBackupPrivilege = "SeBackupPrivilege"
35 SeRestorePrivilege = "SeRestorePrivilege"
36 SeSecurityPrivilege = "SeSecurityPrivilege"
37 )
38
// Process-wide cache of privilege name -> LUID, filled lazily by
// mapPrivileges. privNames must only be touched while holding privNameMutex.
var (
	privNameMutex sync.Mutex
	privNames     = map[string]uint64{}
)
43
44
45 type PrivilegeError struct {
46 privileges []uint64
47 }
48
49 func (e *PrivilegeError) Error() string {
50 s := "Could not enable privilege "
51 if len(e.privileges) > 1 {
52 s = "Could not enable privileges "
53 }
54 for i, p := range e.privileges {
55 if i != 0 {
56 s += ", "
57 }
58 s += `"`
59 s += getPrivilegeName(p)
60 s += `"`
61 }
62 return s
63 }
64
65
66 func RunWithPrivilege(name string, fn func() error) error {
67 return RunWithPrivileges([]string{name}, fn)
68 }
69
70
71 func RunWithPrivileges(names []string, fn func() error) error {
72 privileges, err := mapPrivileges(names)
73 if err != nil {
74 return err
75 }
76 runtime.LockOSThread()
77 defer runtime.UnlockOSThread()
78 token, err := newThreadToken()
79 if err != nil {
80 return err
81 }
82 defer releaseThreadToken(token)
83 err = adjustPrivileges(token, privileges, SE_PRIVILEGE_ENABLED)
84 if err != nil {
85 return err
86 }
87 return fn()
88 }
89
90 func mapPrivileges(names []string) ([]uint64, error) {
91 privileges := make([]uint64, 0, len(names))
92 privNameMutex.Lock()
93 defer privNameMutex.Unlock()
94 for _, name := range names {
95 p, ok := privNames[name]
96 if !ok {
97 err := lookupPrivilegeValue("", name, &p)
98 if err != nil {
99 return nil, err
100 }
101 privNames[name] = p
102 }
103 privileges = append(privileges, p)
104 }
105 return privileges, nil
106 }
107
108
109 func EnableProcessPrivileges(names []string) error {
110 return enableDisableProcessPrivilege(names, SE_PRIVILEGE_ENABLED)
111 }
112
113
114 func DisableProcessPrivileges(names []string) error {
115 return enableDisableProcessPrivilege(names, 0)
116 }
117
118 func enableDisableProcessPrivilege(names []string, action uint32) error {
119 privileges, err := mapPrivileges(names)
120 if err != nil {
121 return err
122 }
123
124 p := windows.CurrentProcess()
125 var token windows.Token
126 err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token)
127 if err != nil {
128 return err
129 }
130
131 defer token.Close()
132 return adjustPrivileges(token, privileges, action)
133 }
134
135 func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error {
136 var b bytes.Buffer
137 _ = binary.Write(&b, binary.LittleEndian, uint32(len(privileges)))
138 for _, p := range privileges {
139 _ = binary.Write(&b, binary.LittleEndian, p)
140 _ = binary.Write(&b, binary.LittleEndian, action)
141 }
142 prevState := make([]byte, b.Len())
143 reqSize := uint32(0)
144 success, err := adjustTokenPrivileges(token, false, &b.Bytes()[0], uint32(len(prevState)), &prevState[0], &reqSize)
145 if !success {
146 return err
147 }
148 if err == ERROR_NOT_ALL_ASSIGNED {
149 return &PrivilegeError{privileges}
150 }
151 return nil
152 }
153
154 func getPrivilegeName(luid uint64) string {
155 var nameBuffer [256]uint16
156 bufSize := uint32(len(nameBuffer))
157 err := lookupPrivilegeName("", &luid, &nameBuffer[0], &bufSize)
158 if err != nil {
159 return fmt.Sprintf("<unknown privilege %d>", luid)
160 }
161
162 var displayNameBuffer [256]uint16
163 displayBufSize := uint32(len(displayNameBuffer))
164 var langID uint32
165 err = lookupPrivilegeDisplayName("", &nameBuffer[0], &displayNameBuffer[0], &displayBufSize, &langID)
166 if err != nil {
167 return fmt.Sprintf("<unknown privilege %s>", string(utf16.Decode(nameBuffer[:bufSize])))
168 }
169
170 return string(utf16.Decode(displayNameBuffer[:displayBufSize]))
171 }
172
173 func newThreadToken() (windows.Token, error) {
174 err := impersonateSelf(windows.SecurityImpersonation)
175 if err != nil {
176 return 0, err
177 }
178
179 var token windows.Token
180 err = openThreadToken(getCurrentThread(), syscall.TOKEN_ADJUST_PRIVILEGES|syscall.TOKEN_QUERY, false, &token)
181 if err != nil {
182 rerr := revertToSelf()
183 if rerr != nil {
184 panic(rerr)
185 }
186 return 0, err
187 }
188 return token, nil
189 }
190
191 func releaseThreadToken(h windows.Token) {
192 err := revertToSelf()
193 if err != nil {
194 panic(err)
195 }
196 h.Close()
197 }
198
View as plain text