1
2
3 package regstate
4
import (
	"encoding/json"
	"errors"
	"fmt"
	"net/url"
	"os"
	"path/filepath"
	"reflect"
	"syscall"

	"golang.org/x/sys/windows"
	"golang.org/x/sys/windows/registry"
)
17
18
19
20
21
// Raw values used with regCreateKeyEx in createVolatileKey.
const (
	// _REG_OPTION_VOLATILE is passed as the options argument of
	// regCreateKeyEx so the key is created as a volatile key.
	_REG_OPTION_VOLATILE = 1

	// _REG_OPENED_EXISTING_KEY is the disposition value that
	// createVolatileKey compares against to tell the caller that the key
	// already existed instead of being newly created.
	_REG_OPENED_EXISTING_KEY = 2
)
27
// Key is an open registry key handle together with the path it was opened
// from.
type Key struct {
	// Key is the underlying registry handle; its methods are promoted.
	registry.Key
	// Name is the full path of the key. It is used to build child paths and
	// to annotate the errors returned by this package.
	Name string
}
32
33 var localMachine = &Key{registry.LOCAL_MACHINE, "HKEY_LOCAL_MACHINE"}
34 var localUser = &Key{registry.CURRENT_USER, "HKEY_CURRENT_USER"}
35
36 var rootPath = `SOFTWARE\Microsoft\runhcs`
37
// NotFoundError is returned when no state key exists for the requested ID.
type NotFoundError struct {
	// ID is the identifier that could not be found.
	ID string
}

// Error implements the error interface.
func (err *NotFoundError) Error() string {
	return fmt.Sprintf("ID '%s' was not found", err.ID)
}

// IsNotFoundError reports whether err is a *NotFoundError or wraps one.
//
// errors.As is used instead of a bare type assertion so that callers which
// add context with fmt.Errorf("...: %w", err) are still recognised. A nil
// err yields false.
func IsNotFoundError(err error) bool {
	var nf *NotFoundError
	return errors.As(err, &nf)
}
50
// NoStateError is returned when an ID exists but holds no value under the
// requested state key.
type NoStateError struct {
	// ID is the identifier whose state was queried.
	ID string
	// Key is the name of the state value that was missing.
	Key string
}

// Error implements the error interface.
func (e *NoStateError) Error() string {
	const format = "state '%s' is not present for ID '%s'"
	return fmt.Sprintf(format, e.Key, e.ID)
}
59
60 func createVolatileKey(k *Key, path string, access uint32) (newk *Key, openedExisting bool, err error) {
61 var (
62 h syscall.Handle
63 d uint32
64 )
65 fullpath := filepath.Join(k.Name, path)
66 pathPtr, _ := windows.UTF16PtrFromString(path)
67 err = regCreateKeyEx(syscall.Handle(k.Key), pathPtr, 0, nil, _REG_OPTION_VOLATILE, access, nil, &h, &d)
68 if err != nil {
69 return nil, false, &os.PathError{Op: "RegCreateKeyEx", Path: fullpath, Err: err}
70 }
71 return &Key{registry.Key(h), fullpath}, d == _REG_OPENED_EXISTING_KEY, nil
72 }
73
74 func hive(perUser bool) *Key {
75 r := localMachine
76 if perUser {
77 r = localUser
78 }
79 return r
80 }
81
82 func Open(root string, perUser bool) (*Key, error) {
83 k, _, err := createVolatileKey(hive(perUser), rootPath, registry.ALL_ACCESS)
84 if err != nil {
85 return nil, err
86 }
87 defer k.Close()
88
89 k2, _, err := createVolatileKey(k, url.PathEscape(root), registry.ALL_ACCESS)
90 if err != nil {
91 return nil, err
92 }
93 return k2, nil
94 }
95
96 func RemoveAll(root string, perUser bool) error {
97 k, err := hive(perUser).open(rootPath)
98 if err != nil {
99 return err
100 }
101 defer k.Close()
102 r, err := k.open(url.PathEscape(root))
103 if err != nil {
104 return err
105 }
106 defer r.Close()
107 ids, err := r.Enumerate()
108 if err != nil {
109 return err
110 }
111 for _, id := range ids {
112 err = r.Remove(id)
113 if err != nil {
114 return err
115 }
116 }
117 r.Close()
118 return k.Remove(root)
119 }
120
121 func (k *Key) Close() error {
122 err := k.Key.Close()
123 k.Key = 0
124 return err
125 }
126
127 func (k *Key) Enumerate() ([]string, error) {
128 escapedIDs, err := k.ReadSubKeyNames(0)
129 if err != nil {
130 return nil, err
131 }
132 var ids []string
133 for _, e := range escapedIDs {
134 id, err := url.PathUnescape(e)
135 if err == nil {
136 ids = append(ids, id)
137 }
138 }
139 return ids, nil
140 }
141
142 func (k *Key) open(name string) (*Key, error) {
143 fullpath := filepath.Join(k.Name, name)
144 nk, err := registry.OpenKey(k.Key, name, registry.ALL_ACCESS)
145 if err != nil {
146 return nil, &os.PathError{Op: "RegOpenKey", Path: fullpath, Err: err}
147 }
148 return &Key{nk, fullpath}, nil
149 }
150
151 func (k *Key) openid(id string) (*Key, error) {
152 escaped := url.PathEscape(id)
153 fullpath := filepath.Join(k.Name, escaped)
154 nk, err := k.open(escaped)
155 if perr, ok := err.(*os.PathError); ok && perr.Err == syscall.ERROR_FILE_NOT_FOUND {
156 return nil, &NotFoundError{id}
157 }
158 if err != nil {
159 return nil, &os.PathError{Op: "RegOpenKey", Path: fullpath, Err: err}
160 }
161 return nk, nil
162 }
163
164 func (k *Key) Remove(id string) error {
165 escaped := url.PathEscape(id)
166 err := registry.DeleteKey(k.Key, escaped)
167 if err != nil {
168 if err == syscall.ERROR_FILE_NOT_FOUND {
169 return &NotFoundError{id}
170 }
171 return &os.PathError{Op: "RegDeleteKey", Path: filepath.Join(k.Name, escaped), Err: err}
172 }
173 return nil
174 }
175
176 func (k *Key) set(id string, create bool, key string, state interface{}) error {
177 var sk *Key
178 var err error
179 if create {
180 var existing bool
181 eid := url.PathEscape(id)
182 sk, existing, err = createVolatileKey(k, eid, registry.ALL_ACCESS)
183 if err != nil {
184 return err
185 }
186 defer sk.Close()
187 if existing {
188 sk.Close()
189 return fmt.Errorf("container %s already exists", id)
190 }
191 } else {
192 sk, err = k.openid(id)
193 if err != nil {
194 return err
195 }
196 defer sk.Close()
197 }
198 switch reflect.TypeOf(state).Kind() {
199 case reflect.Bool:
200 v := uint32(0)
201 if state.(bool) {
202 v = 1
203 }
204 err = sk.SetDWordValue(key, v)
205 case reflect.Int:
206 err = sk.SetQWordValue(key, uint64(state.(int)))
207 case reflect.String:
208 err = sk.SetStringValue(key, state.(string))
209 default:
210 var js []byte
211 js, err = json.Marshal(state)
212 if err != nil {
213 return err
214 }
215 err = sk.SetBinaryValue(key, js)
216 }
217 if err != nil {
218 if err == syscall.ERROR_FILE_NOT_FOUND {
219 return &NoStateError{id, key}
220 }
221 return &os.PathError{Op: "RegSetValueEx", Path: sk.Name + ":" + key, Err: err}
222 }
223 return nil
224 }
225
// Create creates the entry for id under k and stores state as its value
// named key. It delegates to set with create=true, so it fails if an entry
// for id already exists.
func (k *Key) Create(id, key string, state interface{}) error {
	return k.set(id, true, key, state)
}
229
// Set stores state as the value named key on the existing entry for id. It
// delegates to set with create=false, so the entry must already exist.
func (k *Key) Set(id, key string, state interface{}) error {
	return k.set(id, false, key, state)
}
233
234 func (k *Key) Clear(id, key string) error {
235 sk, err := k.openid(id)
236 if err != nil {
237 return err
238 }
239 defer sk.Close()
240 err = sk.DeleteValue(key)
241 if err != nil {
242 if err == syscall.ERROR_FILE_NOT_FOUND {
243 return &NoStateError{id, key}
244 }
245 return &os.PathError{Op: "RegDeleteValue", Path: sk.Name + ":" + key, Err: err}
246 }
247 return nil
248 }
249
250 func (k *Key) Get(id, key string, state interface{}) error {
251 sk, err := k.openid(id)
252 if err != nil {
253 return err
254 }
255 defer sk.Close()
256
257 var js []byte
258 switch reflect.TypeOf(state).Elem().Kind() {
259 case reflect.Bool:
260 var v uint64
261 v, _, err = sk.GetIntegerValue(key)
262 if err == nil {
263 *state.(*bool) = v != 0
264 }
265 case reflect.Int:
266 var v uint64
267 v, _, err = sk.GetIntegerValue(key)
268 if err == nil {
269 *state.(*int) = int(v)
270 }
271 case reflect.String:
272 var v string
273 v, _, err = sk.GetStringValue(key)
274 if err == nil {
275 *state.(*string) = string(v)
276 }
277 default:
278 js, _, err = sk.GetBinaryValue(key)
279 }
280 if err != nil {
281 if err == syscall.ERROR_FILE_NOT_FOUND {
282 return &NoStateError{id, key}
283 }
284 return &os.PathError{Op: "RegQueryValueEx", Path: sk.Name + ":" + key, Err: err}
285 }
286 if js != nil {
287 err = json.Unmarshal(js, state)
288 }
289 return err
290 }
291
View as plain text