1 package viper
2
3 import (
4 "fmt"
5 "strings"
6 "testing"
7
8 "github.com/spf13/cast"
9 "github.com/stretchr/testify/assert"
10 )
11
// layer identifies which of Viper's internal configuration maps a test
// helper writes to (SetDefault vs. Set) or inspects.
type layer int

// Layers understood by overrideFromLayer and deepCheckValue. Numbering
// starts at 1, so the zero value of layer matches neither constant.
const (
	defaultLayer  layer = 1 // values registered with SetDefault (v.defaults)
	overrideLayer layer = 2 // values registered with Set (v.override)
)
18
19 func TestNestedOverrides(t *testing.T) {
20 assert := assert.New(t)
21 var v *Viper
22
23
24 overrideDefault(assert, "tom", 10, "tom", 20)
25 override(assert, "tom", 10, "tom", 20)
26 overrideDefault(assert, "tom.age", 10, "tom.age", 20)
27 override(assert, "tom.age", 10, "tom.age", 20)
28 overrideDefault(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
29 override(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
30
31
32 v = overrideDefault(assert, "tom.age", 10, "tom", "boy")
33 assert.Nil(v.Get("tom.age"))
34 v = override(assert, "tom.age", 10, "tom", "boy")
35 assert.Nil(v.Get("tom.age"))
36
37
38 overrideDefault(assert, "tom", "boy", "tom.age", 10)
39 override(assert, "tom.age", 10, "tom", "boy")
40
41
42 v = overrideDefault(assert, "tom.size", 4, "tom.age", 10)
43 assert.Equal(4, v.Get("tom.size"))
44 v = override(assert, "tom.size", 4, "tom.age", 10)
45 assert.Equal(4, v.Get("tom.size"))
46 deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
47
48
49 v = overrideDefault(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
50 assert.Equal(4, v.Get("tom.size"))
51 assert.Equal(10, v.Get("tom.age"))
52 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)
53 v = override(assert, "tom.size", 4, "tom", map[string]interface{}{"age": 10})
54 assert.Nil(v.Get("tom.size"))
55 assert.Equal(10, v.Get("tom.age"))
56 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)
57
58
59 overrideDefault(assert, "tom", []int{10, 20}, "tom", 30)
60 override(assert, "tom", []int{10, 20}, "tom", 30)
61 overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", 30)
62 override(assert, "tom.age", []int{10, 20}, "tom.age", 30)
63
64
65 overrideDefault(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
66 override(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
67 overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
68 v = override(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
69
70 s, ok := v.Get("tom.age").([]int)
71 if assert.True(ok, "tom[\"age\"] is not a slice") {
72 v.Set("tom.age", append(s, []int{50, 60}...))
73 assert.Equal([]int{30, 40, 50, 60}, v.Get("tom.age"))
74 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, []int{30, 40, 50, 60})
75 }
76 }
77
78 func overrideDefault(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
79 return overrideFromLayer(defaultLayer, assert, firstPath, firstValue, secondPath, secondValue)
80 }
81 func override(assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
82 return overrideFromLayer(overrideLayer, assert, firstPath, firstValue, secondPath, secondValue)
83 }
84
85
86
87
88
89
90
91
92
93
94
95
96 func overrideFromLayer(l layer, assert *assert.Assertions, firstPath string, firstValue interface{}, secondPath string, secondValue interface{}) *Viper {
97 v := New()
98 firstKeys := strings.Split(firstPath, v.keyDelim)
99 if assert == nil ||
100 len(firstKeys) == 0 || len(firstKeys[0]) == 0 {
101 return v
102 }
103
104
105 switch l {
106 case defaultLayer:
107 v.SetDefault(firstPath, firstValue)
108 case overrideLayer:
109 v.Set(firstPath, firstValue)
110 default:
111 return v
112 }
113 assert.Equal(firstValue, v.Get(firstPath))
114 deepCheckValue(assert, v, l, firstKeys, firstValue)
115
116
117 secondKeys := strings.Split(secondPath, v.keyDelim)
118 if len(secondKeys) == 0 || len(secondKeys[0]) == 0 {
119 return v
120 }
121 v.Set(secondPath, secondValue)
122 assert.Equal(secondValue, v.Get(secondPath))
123 deepCheckValue(assert, v, overrideLayer, secondKeys, secondValue)
124
125 return v
126 }
127
128
129
130 func deepCheckValue(assert *assert.Assertions, v *Viper, l layer, keys []string, value interface{}) {
131 if assert == nil || v == nil ||
132 len(keys) == 0 || len(keys[0]) == 0 {
133 return
134 }
135
136
137 var val interface{}
138 var ms string
139 switch l {
140 case defaultLayer:
141 val = v.defaults
142 ms = "v.defaults"
143 case overrideLayer:
144 val = v.override
145 ms = "v.override"
146 }
147
148
149 var m map[string]interface{}
150 err := false
151 for _, k := range keys {
152 if val == nil {
153 assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms))
154 return
155 }
156
157
158 switch val.(type) {
159 case map[interface{}]interface{}:
160 m = cast.ToStringMap(val)
161 case map[string]interface{}:
162 m = val.(map[string]interface{})
163 default:
164 assert.Fail(fmt.Sprintf("%s is not a map[string]interface{}", ms))
165 return
166 }
167 ms = ms + "[\"" + k + "\"]"
168 val = m[k]
169 }
170 if !err {
171 assert.Equal(value, val)
172 }
173 }
174
View as plain text