1 package viper
2
3 import (
4 "strings"
5 "testing"
6
7 "github.com/spf13/cast"
8 "github.com/stretchr/testify/assert"
9 )
10
// layer selects which Viper storage layer a test helper writes to or reads
// back from.
type layer int

// The layers are numbered from 1, so the zero value of layer is neither of them.
const (
	defaultLayer  layer = 1 // populated through SetDefault, stored in v.defaults
	overrideLayer layer = 2 // populated through Set, stored in v.override
)
17
18 func TestNestedOverrides(t *testing.T) {
19 assert := assert.New(t)
20 var v *Viper
21
22
23 overrideDefault(assert, "tom", 10, "tom", 20)
24 override(assert, "tom", 10, "tom", 20)
25 overrideDefault(assert, "tom.age", 10, "tom.age", 20)
26 override(assert, "tom.age", 10, "tom.age", 20)
27 overrideDefault(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
28 override(assert, "sawyer.tom.age", 10, "sawyer.tom.age", 20)
29
30
31 v = overrideDefault(assert, "tom.age", 10, "tom", "boy")
32 assert.Nil(v.Get("tom.age"))
33 v = override(assert, "tom.age", 10, "tom", "boy")
34 assert.Nil(v.Get("tom.age"))
35
36
37 overrideDefault(assert, "tom", "boy", "tom.age", 10)
38 override(assert, "tom.age", 10, "tom", "boy")
39
40
41 v = overrideDefault(assert, "tom.size", 4, "tom.age", 10)
42 assert.Equal(4, v.Get("tom.size"))
43 v = override(assert, "tom.size", 4, "tom.age", 10)
44 assert.Equal(4, v.Get("tom.size"))
45 deepCheckValue(assert, v, overrideLayer, []string{"tom", "size"}, 4)
46
47
48 v = overrideDefault(assert, "tom.size", 4, "tom", map[string]any{"age": 10})
49 assert.Equal(4, v.Get("tom.size"))
50 assert.Equal(10, v.Get("tom.age"))
51 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)
52 v = override(assert, "tom.size", 4, "tom", map[string]any{"age": 10})
53 assert.Nil(v.Get("tom.size"))
54 assert.Equal(10, v.Get("tom.age"))
55 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, 10)
56
57
58 overrideDefault(assert, "tom", []int{10, 20}, "tom", 30)
59 override(assert, "tom", []int{10, 20}, "tom", 30)
60 overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", 30)
61 override(assert, "tom.age", []int{10, 20}, "tom.age", 30)
62
63
64 overrideDefault(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
65 override(assert, "tom", []int{10, 20}, "tom", []int{30, 40})
66 overrideDefault(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
67 v = override(assert, "tom.age", []int{10, 20}, "tom.age", []int{30, 40})
68
69 s, ok := v.Get("tom.age").([]int)
70 if assert.True(ok, "tom[\"age\"] is not a slice") {
71 v.Set("tom.age", append(s, []int{50, 60}...))
72 assert.Equal([]int{30, 40, 50, 60}, v.Get("tom.age"))
73 deepCheckValue(assert, v, overrideLayer, []string{"tom", "age"}, []int{30, 40, 50, 60})
74 }
75 }
76
77 func overrideDefault(assert *assert.Assertions, firstPath string, firstValue any, secondPath string, secondValue any) *Viper {
78 return overrideFromLayer(defaultLayer, assert, firstPath, firstValue, secondPath, secondValue)
79 }
80
81 func override(assert *assert.Assertions, firstPath string, firstValue any, secondPath string, secondValue any) *Viper {
82 return overrideFromLayer(overrideLayer, assert, firstPath, firstValue, secondPath, secondValue)
83 }
84
85
86
87
88
89
90
91
92
93
94
95
96 func overrideFromLayer(l layer, assert *assert.Assertions, firstPath string, firstValue any, secondPath string, secondValue any) *Viper {
97 v := New()
98 firstKeys := strings.Split(firstPath, v.keyDelim)
99 if assert == nil ||
100 len(firstKeys) == 0 || firstKeys[0] == "" {
101 return v
102 }
103
104
105 switch l {
106 case defaultLayer:
107 v.SetDefault(firstPath, firstValue)
108 case overrideLayer:
109 v.Set(firstPath, firstValue)
110 default:
111 return v
112 }
113 assert.Equal(firstValue, v.Get(firstPath))
114 deepCheckValue(assert, v, l, firstKeys, firstValue)
115
116
117 secondKeys := strings.Split(secondPath, v.keyDelim)
118 if len(secondKeys) == 0 || secondKeys[0] == "" {
119 return v
120 }
121 v.Set(secondPath, secondValue)
122 assert.Equal(secondValue, v.Get(secondPath))
123 deepCheckValue(assert, v, overrideLayer, secondKeys, secondValue)
124
125 return v
126 }
127
128
129
130 func deepCheckValue(assert *assert.Assertions, v *Viper, l layer, keys []string, value any) {
131 if assert == nil || v == nil ||
132 len(keys) == 0 || keys[0] == "" {
133 return
134 }
135
136
137 var val any
138 var ms string
139 switch l {
140 case defaultLayer:
141 val = v.defaults
142 ms = "v.defaults"
143 case overrideLayer:
144 val = v.override
145 ms = "v.override"
146 }
147
148
149 var m map[string]any
150 for _, k := range keys {
151 if val == nil {
152 assert.Failf("%s is not a map[string]any", ms)
153 return
154 }
155
156
157 switch val := val.(type) {
158 case map[any]any:
159 m = cast.ToStringMap(val)
160 case map[string]any:
161 m = val
162 default:
163 assert.Failf("%s is not a map[string]any", ms)
164 return
165 }
166 ms = ms + "[\"" + k + "\"]"
167 val = m[k]
168 }
169 assert.Equal(value, val)
170 }
171
View as plain text