...
1 package viper
2
3 import (
4 "testing"
5
6 "github.com/spf13/pflag"
7 "github.com/stretchr/testify/assert"
8 "github.com/stretchr/testify/require"
9 )
10
11 func TestBindFlagValueSet(t *testing.T) {
12 Reset()
13 flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError)
14
15 testValues := map[string]*string{
16 "host": nil,
17 "port": nil,
18 "endpoint": nil,
19 }
20
21 mutatedTestValues := map[string]string{
22 "host": "localhost",
23 "port": "6060",
24 "endpoint": "/public",
25 }
26
27 for name := range testValues {
28 testValues[name] = flagSet.String(name, "", "test")
29 }
30
31 flagValueSet := pflagValueSet{flagSet}
32
33 err := BindFlagValues(flagValueSet)
34 require.NoError(t, err, "error binding flag set")
35
36 flagSet.VisitAll(func(flag *pflag.Flag) {
37 flag.Value.Set(mutatedTestValues[flag.Name])
38 flag.Changed = true
39 })
40
41 for name, expected := range mutatedTestValues {
42 assert.Equal(t, expected, Get(name))
43 }
44 }
45
46 func TestBindFlagValue(t *testing.T) {
47 testString := "testing"
48 testValue := newStringValue(testString, &testString)
49
50 flag := &pflag.Flag{
51 Name: "testflag",
52 Value: testValue,
53 Changed: false,
54 }
55
56 flagValue := pflagValue{flag}
57 BindFlagValue("testvalue", flagValue)
58
59 assert.Equal(t, testString, Get("testvalue"))
60
61 flag.Value.Set("testing_mutate")
62 flag.Changed = true
63
64 assert.Equal(t, "testing_mutate", Get("testvalue"))
65 }
66
View as plain text