1 package configuration
2
3 import (
4 "fmt"
5 "os"
6 "reflect"
7 "sort"
8 "strconv"
9 "strings"
10
11 "github.com/sirupsen/logrus"
12 "gopkg.in/yaml.v2"
13 )
14
15
16
17
// Version is a configuration format version written as "major.minor",
// for example "0.1".
type Version string

// MajorMinorVersion constructs a Version from its major and minor parts.
func MajorMinorVersion(major, minor uint) Version {
	return Version(fmt.Sprintf("%d.%d", major, minor))
}

// major parses the component before the first "." of version as a base-10
// unsigned integer. strings.Split always returns at least one element, so
// indexing [0] is safe even for an empty version.
func (version Version) major() (uint, error) {
	majorPart := strings.Split(string(version), ".")[0]
	major, err := strconv.ParseUint(majorPart, 10, 0)
	return uint(major), err
}

// Major returns the major component of version. Parse errors are discarded:
// a non-numeric major component yields 0.
func (version Version) Major() uint {
	major, _ := version.major()
	return major
}

// minor parses the component after the first "." of version as a base-10
// unsigned integer. A version without a "." returns an error rather than
// panicking with an index out of range.
func (version Version) minor() (uint, error) {
	parts := strings.Split(string(version), ".")
	if len(parts) < 2 {
		return 0, fmt.Errorf("version %q has no minor component", string(version))
	}
	minor, err := strconv.ParseUint(parts[1], 10, 0)
	return uint(minor), err
}

// Minor returns the minor component of version. Parse errors are discarded:
// a missing or non-numeric minor component yields 0.
func (version Version) Minor() uint {
	minor, _ := version.minor()
	return minor
}

// VersionedParseInfo describes how to parse one Version of the configuration
// format.
type VersionedParseInfo struct {
	// Version is the configuration format version this entry applies to.
	Version Version

	// ParseAs is the type the YAML document is unmarshaled into before
	// environment overrides are applied; Parse allocates a new value of this
	// type with reflect.New.
	ParseAs reflect.Type

	// ConversionFunc receives the parsed value (a pointer to a ParseAs value)
	// and returns the value, or a pointer to the value, that Parse copies into
	// its destination.
	ConversionFunc func(interface{}) (interface{}, error)
}
62
// envVar is a single NAME=value pair captured from the process environment.
type envVar struct {
	name, value string
}

// envVars is a list of environment variables implementing sort.Interface,
// ordered lexically (byte-wise) by variable name.
type envVars []envVar

// Len reports the number of variables in the list.
func (vars envVars) Len() int {
	return len(vars)
}

// Swap exchanges the entries at positions i and j.
func (vars envVars) Swap(i, j int) {
	vars[j], vars[i] = vars[i], vars[j]
}

// Less reports whether entry i's name sorts before entry j's name.
func (vars envVars) Less(i, j int) bool {
	left, right := vars[i].name, vars[j].name
	return left < right
}
73
74
75
76 type Parser struct {
77 prefix string
78 mapping map[Version]VersionedParseInfo
79 env envVars
80 }
81
82
83
84 func NewParser(prefix string, parseInfos []VersionedParseInfo) *Parser {
85 p := Parser{prefix: prefix, mapping: make(map[Version]VersionedParseInfo)}
86
87 for _, parseInfo := range parseInfos {
88 p.mapping[parseInfo.Version] = parseInfo
89 }
90
91 for _, env := range os.Environ() {
92 envParts := strings.SplitN(env, "=", 2)
93 p.env = append(p.env, envVar{envParts[0], envParts[1]})
94 }
95
96
97
98
99
100
101
102 sort.Sort(p.env)
103
104 return &p
105 }
106
107
108
109
110
111
112
113
114 func (p *Parser) Parse(in []byte, v interface{}) error {
115 var versionedStruct struct {
116 Version Version
117 }
118
119 if err := yaml.Unmarshal(in, &versionedStruct); err != nil {
120 return err
121 }
122
123 parseInfo, ok := p.mapping[versionedStruct.Version]
124 if !ok {
125 return fmt.Errorf("unsupported version: %q", versionedStruct.Version)
126 }
127
128 parseAs := reflect.New(parseInfo.ParseAs)
129 err := yaml.Unmarshal(in, parseAs.Interface())
130 if err != nil {
131 return err
132 }
133
134 for _, envVar := range p.env {
135 pathStr := envVar.name
136 if strings.HasPrefix(pathStr, strings.ToUpper(p.prefix)+"_") {
137 path := strings.Split(pathStr, "_")
138
139 err = p.overwriteFields(parseAs, pathStr, path[1:], envVar.value)
140 if err != nil {
141 return err
142 }
143 }
144 }
145
146 c, err := parseInfo.ConversionFunc(parseAs.Interface())
147 if err != nil {
148 return err
149 }
150 reflect.ValueOf(v).Elem().Set(reflect.Indirect(reflect.ValueOf(c)))
151 return nil
152 }
153
154
155
156
157 func (p *Parser) overwriteFields(v reflect.Value, fullpath string, path []string, payload string) error {
158 for v.Kind() == reflect.Ptr {
159 if v.IsNil() {
160 panic("encountered nil pointer while handling environment variable " + fullpath)
161 }
162 v = reflect.Indirect(v)
163 }
164 switch v.Kind() {
165 case reflect.Struct:
166 return p.overwriteStruct(v, fullpath, path, payload)
167 case reflect.Map:
168 return p.overwriteMap(v, fullpath, path, payload)
169 case reflect.Interface:
170 if v.NumMethod() == 0 {
171 if !v.IsNil() {
172 return p.overwriteFields(v.Elem(), fullpath, path, payload)
173 }
174
175 var template map[string]interface{}
176 wrappedV := reflect.MakeMap(reflect.TypeOf(template))
177 v.Set(wrappedV)
178 return p.overwriteMap(wrappedV, fullpath, path, payload)
179 }
180 }
181 return nil
182 }
183
184 func (p *Parser) overwriteStruct(v reflect.Value, fullpath string, path []string, payload string) error {
185
186 byUpperCase := make(map[string]int)
187 for i := 0; i < v.NumField(); i++ {
188 sf := v.Type().Field(i)
189 upper := strings.ToUpper(sf.Name)
190 if _, present := byUpperCase[upper]; present {
191 panic(fmt.Sprintf("field name collision in configuration object: %s", sf.Name))
192 }
193 byUpperCase[upper] = i
194 }
195
196 fieldIndex, present := byUpperCase[path[0]]
197 if !present {
198 logrus.Warnf("Ignoring unrecognized environment variable %s", fullpath)
199 return nil
200 }
201 field := v.Field(fieldIndex)
202 sf := v.Type().Field(fieldIndex)
203
204 if len(path) == 1 {
205
206 fieldVal := reflect.New(sf.Type)
207 err := yaml.Unmarshal([]byte(payload), fieldVal.Interface())
208 if err != nil {
209 return err
210 }
211 field.Set(reflect.Indirect(fieldVal))
212 return nil
213 }
214
215
216 switch sf.Type.Kind() {
217 case reflect.Map:
218 if field.IsNil() {
219 field.Set(reflect.MakeMap(sf.Type))
220 }
221 case reflect.Ptr:
222 if field.IsNil() {
223 field.Set(reflect.New(sf.Type))
224 }
225 }
226
227 err := p.overwriteFields(field, fullpath, path[1:], payload)
228 if err != nil {
229 return err
230 }
231
232 return nil
233 }
234
235 func (p *Parser) overwriteMap(m reflect.Value, fullpath string, path []string, payload string) error {
236 if m.Type().Key().Kind() != reflect.String {
237
238 logrus.Warnf("Ignoring environment variable %s involving map with non-string keys", fullpath)
239 return nil
240 }
241
242 if len(path) > 1 {
243
244
245 for _, k := range m.MapKeys() {
246 if strings.ToUpper(k.String()) == path[0] {
247 mapValue := m.MapIndex(k)
248
249
250 if (mapValue.Kind() == reflect.Ptr ||
251 mapValue.Kind() == reflect.Interface ||
252 mapValue.Kind() == reflect.Map) &&
253 mapValue.IsNil() {
254 break
255 }
256 return p.overwriteFields(mapValue, fullpath, path[1:], payload)
257 }
258 }
259 }
260
261
262 var mapValue reflect.Value
263 if m.Type().Elem().Kind() == reflect.Map {
264 mapValue = reflect.MakeMap(m.Type().Elem())
265 } else {
266 mapValue = reflect.New(m.Type().Elem())
267 }
268 if len(path) > 1 {
269 err := p.overwriteFields(mapValue, fullpath, path[1:], payload)
270 if err != nil {
271 return err
272 }
273 } else {
274 err := yaml.Unmarshal([]byte(payload), mapValue.Interface())
275 if err != nil {
276 return err
277 }
278 }
279
280 m.SetMapIndex(reflect.ValueOf(strings.ToLower(path[0])), reflect.Indirect(mapValue))
281
282 return nil
283 }
284
View as plain text