1 package config
2
3 import (
4 "fmt"
5 "os"
6 "path/filepath"
7 "runtime"
8 "testing"
9
10 "github.com/stretchr/testify/assert"
11 "github.com/stretchr/testify/require"
12 )
13
14 func TestConfigDir(t *testing.T) {
15 tempDir := t.TempDir()
16
17 tests := []struct {
18 name string
19 onlyWindows bool
20 env map[string]string
21 output string
22 }{
23 {
24 name: "HOME/USERPROFILE specified",
25 env: map[string]string{
26 "GH_CONFIG_DIR": "",
27 "XDG_CONFIG_HOME": "",
28 "AppData": "",
29 "USERPROFILE": tempDir,
30 "HOME": tempDir,
31 },
32 output: filepath.Join(tempDir, ".config", "gh"),
33 },
34 {
35 name: "GH_CONFIG_DIR specified",
36 env: map[string]string{
37 "GH_CONFIG_DIR": filepath.Join(tempDir, "gh_config_dir"),
38 },
39 output: filepath.Join(tempDir, "gh_config_dir"),
40 },
41 {
42 name: "XDG_CONFIG_HOME specified",
43 env: map[string]string{
44 "XDG_CONFIG_HOME": tempDir,
45 },
46 output: filepath.Join(tempDir, "gh"),
47 },
48 {
49 name: "GH_CONFIG_DIR and XDG_CONFIG_HOME specified",
50 env: map[string]string{
51 "GH_CONFIG_DIR": filepath.Join(tempDir, "gh_config_dir"),
52 "XDG_CONFIG_HOME": tempDir,
53 },
54 output: filepath.Join(tempDir, "gh_config_dir"),
55 },
56 {
57 name: "AppData specified",
58 onlyWindows: true,
59 env: map[string]string{
60 "AppData": tempDir,
61 },
62 output: filepath.Join(tempDir, "GitHub CLI"),
63 },
64 {
65 name: "GH_CONFIG_DIR and AppData specified",
66 onlyWindows: true,
67 env: map[string]string{
68 "GH_CONFIG_DIR": filepath.Join(tempDir, "gh_config_dir"),
69 "AppData": tempDir,
70 },
71 output: filepath.Join(tempDir, "gh_config_dir"),
72 },
73 {
74 name: "XDG_CONFIG_HOME and AppData specified",
75 onlyWindows: true,
76 env: map[string]string{
77 "XDG_CONFIG_HOME": tempDir,
78 "AppData": tempDir,
79 },
80 output: filepath.Join(tempDir, "gh"),
81 },
82 }
83
84 for _, tt := range tests {
85 if tt.onlyWindows && runtime.GOOS != "windows" {
86 continue
87 }
88 t.Run(tt.name, func(t *testing.T) {
89 if tt.env != nil {
90 for k, v := range tt.env {
91 t.Setenv(k, v)
92 }
93 }
94 assert.Equal(t, tt.output, ConfigDir())
95 })
96 }
97 }
98
99 func TestStateDir(t *testing.T) {
100 tempDir := t.TempDir()
101
102 tests := []struct {
103 name string
104 onlyWindows bool
105 env map[string]string
106 output string
107 }{
108 {
109 name: "HOME/USERPROFILE specified",
110 env: map[string]string{
111 "XDG_STATE_HOME": "",
112 "GH_CONFIG_DIR": "",
113 "XDG_CONFIG_HOME": "",
114 "LocalAppData": "",
115 "USERPROFILE": tempDir,
116 "HOME": tempDir,
117 },
118 output: filepath.Join(tempDir, ".local", "state", "gh"),
119 },
120 {
121 name: "XDG_STATE_HOME specified",
122 env: map[string]string{
123 "XDG_STATE_HOME": tempDir,
124 },
125 output: filepath.Join(tempDir, "gh"),
126 },
127 {
128 name: "LocalAppData specified",
129 onlyWindows: true,
130 env: map[string]string{
131 "LocalAppData": tempDir,
132 },
133 output: filepath.Join(tempDir, "GitHub CLI"),
134 },
135 {
136 name: "XDG_STATE_HOME and LocalAppData specified",
137 onlyWindows: true,
138 env: map[string]string{
139 "XDG_STATE_HOME": tempDir,
140 "LocalAppData": tempDir,
141 },
142 output: filepath.Join(tempDir, "gh"),
143 },
144 }
145
146 for _, tt := range tests {
147 if tt.onlyWindows && runtime.GOOS != "windows" {
148 continue
149 }
150 t.Run(tt.name, func(t *testing.T) {
151 if tt.env != nil {
152 for k, v := range tt.env {
153 t.Setenv(k, v)
154 }
155 }
156 assert.Equal(t, tt.output, StateDir())
157 })
158 }
159 }
160
161 func TestDataDir(t *testing.T) {
162 tempDir := t.TempDir()
163
164 tests := []struct {
165 name string
166 onlyWindows bool
167 env map[string]string
168 output string
169 }{
170 {
171 name: "HOME/USERPROFILE specified",
172 env: map[string]string{
173 "XDG_DATA_HOME": "",
174 "GH_CONFIG_DIR": "",
175 "XDG_CONFIG_HOME": "",
176 "LocalAppData": "",
177 "USERPROFILE": tempDir,
178 "HOME": tempDir,
179 },
180 output: filepath.Join(tempDir, ".local", "share", "gh"),
181 },
182 {
183 name: "XDG_DATA_HOME specified",
184 env: map[string]string{
185 "XDG_DATA_HOME": tempDir,
186 },
187 output: filepath.Join(tempDir, "gh"),
188 },
189 {
190 name: "LocalAppData specified",
191 onlyWindows: true,
192 env: map[string]string{
193 "LocalAppData": tempDir,
194 },
195 output: filepath.Join(tempDir, "GitHub CLI"),
196 },
197 {
198 name: "XDG_DATA_HOME and LocalAppData specified",
199 onlyWindows: true,
200 env: map[string]string{
201 "XDG_DATA_HOME": tempDir,
202 "LocalAppData": tempDir,
203 },
204 output: filepath.Join(tempDir, "gh"),
205 },
206 }
207
208 for _, tt := range tests {
209 if tt.onlyWindows && runtime.GOOS != "windows" {
210 continue
211 }
212 t.Run(tt.name, func(t *testing.T) {
213 if tt.env != nil {
214 for k, v := range tt.env {
215 t.Setenv(k, v)
216 }
217 }
218 assert.Equal(t, tt.output, DataDir())
219 })
220 }
221 }
222
223 func TestCacheDir(t *testing.T) {
224 expectedCacheDir := "/expected-cache-dir"
225 unexpectedCacheDir := "/unexpected-cache-dir"
226
227 tests := []struct {
228 name string
229 onlyWindows bool
230 env map[string]string
231 output string
232 }{
233 {
234 name: "XDG_CACHE_HOME is highest precedence",
235 env: map[string]string{
236 "XDG_CACHE_HOME": expectedCacheDir,
237 "LocalAppData": unexpectedCacheDir,
238 "USERPROFILE": unexpectedCacheDir,
239 "HOME": unexpectedCacheDir,
240 },
241 output: filepath.Join(expectedCacheDir, "gh"),
242 },
243 {
244 name: "on windows, LocalAppData is preferred to home dir",
245 onlyWindows: true,
246 env: map[string]string{
247 "XDG_CACHE_HOME": "",
248 "LocalAppData": expectedCacheDir,
249 "USERPROFILE": unexpectedCacheDir,
250 "HOME": unexpectedCacheDir,
251 },
252 output: filepath.Join(expectedCacheDir, "GitHub CLI"),
253 },
254 {
255 name: "tries to use the home dir cache directory",
256 env: map[string]string{
257 "XDG_CACHE_HOME": "",
258 "LocalAppData": "",
259 "USERPROFILE": expectedCacheDir,
260 "HOME": expectedCacheDir,
261 },
262 output: filepath.Join(expectedCacheDir, ".cache", "gh"),
263 },
264 {
265 name: "finally falls back to tmpdir",
266
267 env: map[string]string{
268 "XDG_CACHE_HOME": "",
269 "LocalAppData": "",
270 "USERPROFILE": "",
271 "HOME": "",
272 },
273 output: filepath.Join(os.TempDir(), "gh-cli-cache"),
274 },
275 }
276
277 for _, tt := range tests {
278 if tt.onlyWindows && runtime.GOOS != "windows" {
279 continue
280 }
281 t.Run(tt.name, func(t *testing.T) {
282 if tt.env != nil {
283 for k, v := range tt.env {
284 t.Setenv(k, v)
285 }
286 }
287 assert.Equal(t, tt.output, CacheDir())
288 })
289 }
290
291 }
292
293 func TestLoad(t *testing.T) {
294 tempDir := t.TempDir()
295 globalFilePath := filepath.Join(tempDir, "config.yml")
296 invalidGlobalFilePath := filepath.Join(tempDir, "invalid_config.yml")
297 hostsFilePath := filepath.Join(tempDir, "hosts.yml")
298 invalidHostsFilePath := filepath.Join(tempDir, "invalid_hosts.yml")
299 err := os.WriteFile(globalFilePath, []byte(testGlobalData()), 0755)
300 assert.NoError(t, err)
301 err = os.WriteFile(invalidGlobalFilePath, []byte("invalid"), 0755)
302 assert.NoError(t, err)
303 err = os.WriteFile(hostsFilePath, []byte(testHostsData()), 0755)
304 assert.NoError(t, err)
305 err = os.WriteFile(invalidHostsFilePath, []byte("invalid"), 0755)
306 assert.NoError(t, err)
307
308 tests := []struct {
309 name string
310 globalConfigPath string
311 hostsConfigPath string
312 fallback *Config
313 wantGitProtocol string
314 wantToken string
315 wantErr bool
316 wantErrMsg string
317 }{
318 {
319 name: "global and hosts files exist",
320 globalConfigPath: globalFilePath,
321 hostsConfigPath: hostsFilePath,
322 wantGitProtocol: "ssh",
323 wantToken: "yyyyyyyyyyyyyyyyyyyy",
324 },
325 {
326 name: "invalid global file",
327 globalConfigPath: invalidGlobalFilePath,
328 wantErr: true,
329 wantErrMsg: fmt.Sprintf("invalid config file %s: invalid format", filepath.Join(tempDir, "invalid_config.yml")),
330 },
331 {
332 name: "invalid hosts file",
333 globalConfigPath: globalFilePath,
334 hostsConfigPath: invalidHostsFilePath,
335 wantErr: true,
336 wantErrMsg: fmt.Sprintf("invalid config file %s: invalid format", filepath.Join(tempDir, "invalid_hosts.yml")),
337 },
338 {
339 name: "global file does not exist and hosts file exist",
340 globalConfigPath: "",
341 hostsConfigPath: hostsFilePath,
342 wantGitProtocol: "",
343 wantToken: "yyyyyyyyyyyyyyyyyyyy",
344 },
345 {
346 name: "global file exist and hosts file does not exist",
347 globalConfigPath: globalFilePath,
348 hostsConfigPath: "",
349 wantGitProtocol: "ssh",
350 wantToken: "",
351 },
352 {
353 name: "global file does not exist and hosts file does not exist with no fallback",
354 globalConfigPath: "",
355 hostsConfigPath: "",
356 wantGitProtocol: "",
357 wantToken: "",
358 },
359 {
360 name: "global file does not exist and hosts file does not exist with fallback",
361 globalConfigPath: "",
362 hostsConfigPath: "",
363 fallback: ReadFromString(testFullConfig()),
364 wantGitProtocol: "ssh",
365 wantToken: "yyyyyyyyyyyyyyyyyyyy",
366 },
367 }
368
369 for _, tt := range tests {
370 t.Run(tt.name, func(t *testing.T) {
371 cfg, err := load(tt.globalConfigPath, tt.hostsConfigPath, tt.fallback)
372 if tt.wantErr {
373 assert.EqualError(t, err, tt.wantErrMsg)
374 return
375 }
376 assert.NoError(t, err)
377
378 if tt.wantGitProtocol == "" {
379 assertNoKey(t, cfg, []string{"git_protocol"})
380 } else {
381 assertKeyWithValue(t, cfg, []string{"git_protocol"}, tt.wantGitProtocol)
382 }
383
384 if tt.wantToken == "" {
385 assertNoKey(t, cfg, []string{"hosts", "enterprise.com", "oauth_token"})
386 } else {
387 assertKeyWithValue(t, cfg, []string{"hosts", "enterprise.com", "oauth_token"}, tt.wantToken)
388 }
389
390 if tt.fallback != nil {
391
392 assert.Equal(t, tt.fallback.entries.String(), cfg.entries.String())
393 assert.False(t, tt.fallback == cfg)
394 }
395 })
396 }
397 }
398
399 func TestWrite(t *testing.T) {
400 tests := []struct {
401 name string
402 createConfig func() *Config
403 wantConfig func() *Config
404 wantErr bool
405 wantErrMsg string
406 }{
407 {
408 name: "writes config and hosts files",
409 createConfig: func() *Config {
410 cfg := ReadFromString(testFullConfig())
411 cfg.Set([]string{"editor"}, "vim")
412 cfg.Set([]string{"hosts", "github.com", "git_protocol"}, "https")
413 return cfg
414 },
415 wantConfig: func() *Config {
416
417
418 cfg := ReadFromString(testFullConfig())
419 cfg.Set([]string{"editor"}, "vim")
420 cfg.Set([]string{"hosts", "github.com", "git_protocol"}, "https")
421 return cfg
422 },
423 },
424 {
425 name: "only writes hosts file",
426 createConfig: func() *Config {
427 cfg := ReadFromString(testFullConfig())
428 cfg.Set([]string{"hosts", "enterprise.com", "git_protocol"}, "ssh")
429 return cfg
430 },
431 wantConfig: func() *Config {
432
433 cfg := ReadFromString("")
434 cfg.Set([]string{"hosts", "github.com", "user"}, "user1")
435 cfg.Set([]string{"hosts", "github.com", "oauth_token"}, "xxxxxxxxxxxxxxxxxxxx")
436 cfg.Set([]string{"hosts", "github.com", "git_protocol"}, "ssh")
437 cfg.Set([]string{"hosts", "enterprise.com", "user"}, "user2")
438 cfg.Set([]string{"hosts", "enterprise.com", "oauth_token"}, "yyyyyyyyyyyyyyyyyyyy")
439 cfg.Set([]string{"hosts", "enterprise.com", "git_protocol"}, "ssh")
440 return cfg
441 },
442 },
443 {
444 name: "only writes global config file",
445 createConfig: func() *Config {
446 cfg := ReadFromString(testFullConfig())
447 cfg.Set([]string{"editor"}, "vim")
448 return cfg
449 },
450 wantConfig: func() *Config {
451
452 cfg := ReadFromString(testGlobalData())
453 cfg.Set([]string{"editor"}, "vim")
454 return cfg
455 },
456 },
457 }
458
459 for _, tt := range tests {
460 t.Run(tt.name, func(t *testing.T) {
461 tempDir := t.TempDir()
462 t.Setenv("GH_CONFIG_DIR", tempDir)
463 cfg := tt.createConfig()
464 err := Write(cfg)
465 assert.NoError(t, err)
466 loadedCfg, err := load(generalConfigFile(), hostsConfigFile(), nil)
467 assert.NoError(t, err)
468 wantCfg := tt.wantConfig()
469 assert.Equal(t, wantCfg.entries.String(), loadedCfg.entries.String())
470 })
471 }
472 }
473
474 func TestWriteEmptyValues(t *testing.T) {
475 tempDir := t.TempDir()
476 t.Setenv("GH_CONFIG_DIR", tempDir)
477 cfg := ReadFromString(testFullConfig())
478 cfg.Set([]string{"editor"}, "")
479 err := Write(cfg)
480 assert.NoError(t, err)
481 data, err := os.ReadFile(generalConfigFile())
482 assert.NoError(t, err)
483 assert.Equal(t, "git_protocol: ssh\neditor:\nprompt: enabled\npager: less\n", string(data))
484 }
485
486 func TestGet(t *testing.T) {
487 tests := []struct {
488 name string
489 keys []string
490 wantValue string
491 wantErr bool
492 }{
493 {
494 name: "get git_protocol value",
495 keys: []string{"git_protocol"},
496 wantValue: "ssh",
497 },
498 {
499 name: "get editor value",
500 keys: []string{"editor"},
501 wantValue: "",
502 },
503 {
504 name: "get prompt value",
505 keys: []string{"prompt"},
506 wantValue: "enabled",
507 },
508 {
509 name: "get pager value",
510 keys: []string{"pager"},
511 wantValue: "less",
512 },
513 {
514 name: "non-existant key",
515 keys: []string{"unknown"},
516 wantErr: true,
517 },
518 {
519 name: "nested key",
520 keys: []string{"nested", "key"},
521 wantValue: "value",
522 },
523 {
524 name: "nested key with same name",
525 keys: []string{"nested", "pager"},
526 wantValue: "more",
527 },
528 {
529 name: "nested non-existant key",
530 keys: []string{"nested", "invalid"},
531 wantErr: true,
532 },
533 }
534
535 for _, tt := range tests {
536 t.Run(tt.name, func(t *testing.T) {
537 cfg := testConfig()
538 if tt.wantErr {
539 assertNoKey(t, cfg, tt.keys)
540 } else {
541 assertKeyWithValue(t, cfg, tt.keys, tt.wantValue)
542 }
543 assert.False(t, cfg.entries.IsModified())
544 })
545 }
546 }
547
548 func TestKeys(t *testing.T) {
549 tests := []struct {
550 name string
551 findKeys []string
552 wantKeys []string
553 wantErr bool
554 wantErrMsg string
555 }{
556 {
557 name: "top level keys",
558 findKeys: nil,
559 wantKeys: []string{"git_protocol", "editor", "prompt", "pager", "nested"},
560 },
561 {
562 name: "nested keys",
563 findKeys: []string{"nested"},
564 wantKeys: []string{"key", "pager"},
565 },
566 {
567 name: "keys for non-existant nested key",
568 findKeys: []string{"unknown"},
569 wantKeys: nil,
570 wantErr: true,
571 wantErrMsg: `could not find key "unknown"`,
572 },
573 }
574
575 for _, tt := range tests {
576 t.Run(tt.name, func(t *testing.T) {
577 cfg := testConfig()
578 ks, err := cfg.Keys(tt.findKeys)
579 if tt.wantErr {
580 assert.EqualError(t, err, tt.wantErrMsg)
581 } else {
582 assert.NoError(t, err)
583 }
584 assert.Equal(t, tt.wantKeys, ks)
585 assert.False(t, cfg.entries.IsModified())
586 })
587 }
588 }
589
590 func TestRemove(t *testing.T) {
591 tests := []struct {
592 name string
593 keys []string
594 wantErr bool
595 wantErrMsg string
596 }{
597 {
598 name: "remove top level key",
599 keys: []string{"pager"},
600 },
601 {
602 name: "remove nested key",
603 keys: []string{"nested", "pager"},
604 },
605 {
606 name: "remove top level map",
607 keys: []string{"nested"},
608 },
609 {
610 name: "remove non-existant top level key",
611 keys: []string{"unknown"},
612 wantErr: true,
613 wantErrMsg: `could not find key "unknown"`,
614 },
615 {
616 name: "remove non-existant nested key",
617 keys: []string{"nested", "invalid"},
618 wantErr: true,
619 wantErrMsg: `could not find key "invalid"`,
620 },
621 }
622
623 for _, tt := range tests {
624 t.Run(tt.name, func(t *testing.T) {
625 cfg := testConfig()
626 err := cfg.Remove(tt.keys)
627 if tt.wantErr {
628 assert.EqualError(t, err, tt.wantErrMsg)
629 assert.False(t, cfg.entries.IsModified())
630 } else {
631 assert.NoError(t, err)
632 assert.True(t, cfg.entries.IsModified())
633 }
634 assertNoKey(t, cfg, tt.keys)
635 })
636 }
637 }
638
639 func TestSet(t *testing.T) {
640 tests := []struct {
641 name string
642 keys []string
643 value string
644 }{
645 {
646 name: "set top level existing key",
647 keys: []string{"pager"},
648 value: "test pager",
649 },
650 {
651 name: "set nested existing key",
652 keys: []string{"nested", "pager"},
653 value: "new pager",
654 },
655 {
656 name: "set top level map",
657 keys: []string{"nested"},
658 value: "override",
659 },
660 {
661 name: "set non-existant top level key",
662 keys: []string{"unknown"},
663 value: "why not",
664 },
665 {
666 name: "set non-existant nested key",
667 keys: []string{"nested", "invalid"},
668 value: "sure",
669 },
670 {
671 name: "set non-existant nest",
672 keys: []string{"johnny", "test"},
673 value: "dukey",
674 },
675 {
676 name: "set empty value",
677 keys: []string{"empty"},
678 value: "",
679 },
680 }
681
682 for _, tt := range tests {
683 t.Run(tt.name, func(t *testing.T) {
684 cfg := testConfig()
685 cfg.Set(tt.keys, tt.value)
686 assert.True(t, cfg.entries.IsModified())
687 assertKeyWithValue(t, cfg, tt.keys, tt.value)
688 })
689 }
690 }
691
692 func TestEntriesShouldBeModifiedOnLoad(t *testing.T) {
693
694 tempDir := t.TempDir()
695 t.Setenv("GH_CONFIG_DIR", tempDir)
696
697 require.NoError(t, writeFile(hostsConfigFile(), []byte(testHostsData())))
698 require.NoError(t, writeFile(generalConfigFile(), []byte(testGlobalData())))
699
700
701 cfg, err := load(generalConfigFile(), hostsConfigFile(), nil)
702 require.NoError(t, err)
703
704
705
706 require.False(t, cfg.entries.IsModified())
707
708 hosts, err := cfg.entries.FindEntry("hosts")
709 require.NoError(t, err)
710 require.False(t, hosts.IsModified())
711 }
712
713 func testConfig() *Config {
714 var data = `
715 git_protocol: ssh
716 editor:
717 prompt: enabled
718 pager: less
719 nested:
720 key: value
721 pager: more
722 `
723 return ReadFromString(data)
724 }
725
// testGlobalData returns the contents of a general (non-hosts) config file
// fixture: four flat top-level keys, with editor left empty.
func testGlobalData() string {
	return "\n" +
		"git_protocol: ssh\n" +
		"editor:\n" +
		"prompt: enabled\n" +
		"pager: less\n"
}
735
// testHostsData returns the contents of a hosts file fixture with two hosts,
// github.com and enterprise.com. Each host has user, oauth_token, and
// git_protocol keys.
//
// The two-space indentation is significant. Without it, the per-host keys
// would become duplicate top-level keys and both host entries would be empty.
func testHostsData() string {
	var data = `
github.com:
  user: user1
  oauth_token: xxxxxxxxxxxxxxxxxxxx
  git_protocol: ssh
enterprise.com:
  user: user2
  oauth_token: yyyyyyyyyyyyyyyyyyyy
  git_protocol: https
`
	return data
}
749
// testFullConfig returns a single-document fixture that combines the general
// keys of testGlobalData with the host entries of testHostsData, nested under
// a top-level "hosts" map.
//
// The indentation is significant. "github.com" and "enterprise.com" must sit
// under "hosts", and each host's keys under that host; otherwise the document
// collapses into duplicate top-level keys.
func testFullConfig() string {
	var data = `
git_protocol: ssh
editor:
prompt: enabled
pager: less
hosts:
  github.com:
    user: user1
    oauth_token: xxxxxxxxxxxxxxxxxxxx
    git_protocol: ssh
  enterprise.com:
    user: user2
    oauth_token: yyyyyyyyyyyyyyyyyyyy
    git_protocol: https
`
	return data
}
768
769 func assertNoKey(t *testing.T, cfg *Config, keys []string) {
770 t.Helper()
771 _, err := cfg.Get(keys)
772 var keyNotFoundError *KeyNotFoundError
773 assert.ErrorAs(t, err, &keyNotFoundError)
774 }
775
776 func assertKeyWithValue(t *testing.T, cfg *Config, keys []string, value string) {
777 t.Helper()
778 actual, err := cfg.Get(keys)
779 assert.NoError(t, err)
780 assert.Equal(t, value, actual)
781 }
782
View as plain text