1
2
3
4
5 package envconfig
6
7 import (
8 "encoding"
9 "errors"
10 "fmt"
11 "os"
12 "reflect"
13 "regexp"
14 "strconv"
15 "strings"
16 "time"
17 )
18
19
var (
	// ErrInvalidSpecification is returned by gatherInfo (and therefore by
	// Process and CheckDisallowed) when spec is not a non-nil pointer to a
	// struct.
	ErrInvalidSpecification = errors.New("specification must be a struct pointer")

	// gatherRegexp splits a Go field name into words for the split_words
	// tag: a run of non-capitals, a run of capitals followed by
	// non-capitals, or a trailing run of capitals.
	// For example "AutoSplitVar" yields "Auto", "Split", "Var".
	gatherRegexp = regexp.MustCompile("([^A-Z]+|[A-Z]+[^A-Z]+|[A-Z]+)")

	// acronymRegexp separates a leading acronym from the word that follows
	// it inside one gatherRegexp match, e.g. "MYVar" -> "MY", "Var".
	acronymRegexp = regexp.MustCompile("([A-Z]+)([A-Z][^A-Z]+)")
)
24
25
26
// ParseError reports that the string value of an environment variable could
// not be converted into the type of the struct field it was destined for.
type ParseError struct {
	// KeyName is the environment variable key, FieldName the Go struct
	// field, TypeName the field's Go type, and Value the raw string that
	// failed to convert.
	KeyName, FieldName, TypeName, Value string
	// Err is the error returned by the conversion.
	Err error
}
34
35
36
type (
	// Decoder is implemented by types that parse their own value from the
	// raw environment string. processField consults it before any other
	// conversion mechanism.
	Decoder interface {
		Decode(value string) error
	}

	// Setter is implemented by types that assign themselves from the raw
	// environment string (the same shape as flag.Value's Set). processField
	// consults it when the field is not a Decoder.
	Setter interface {
		Set(value string) error
	}
)
46
47 func (e *ParseError) Error() string {
48 return fmt.Sprintf("envconfig.Process: assigning %[1]s to %[2]s: converting '%[3]s' to type %[4]s. details: %[5]s", e.KeyName, e.FieldName, e.Value, e.TypeName, e.Err)
49 }
50
51
// varInfo describes one struct field that can be populated from the
// environment, as produced by gatherInfo.
type varInfo struct {
	// Name is the Go field name, Alt the upper-cased `envconfig` tag (empty
	// if absent), and Key the fully prefixed, upper-cased environment key.
	Name, Alt, Key string
	// Field is the settable field value (after dereferencing struct
	// pointers).
	Field reflect.Value
	// Tags holds the field's struct tags (default, required, ...).
	Tags reflect.StructTag
}
59
60
61 func gatherInfo(prefix string, spec interface{}) ([]varInfo, error) {
62 s := reflect.ValueOf(spec)
63
64 if s.Kind() != reflect.Ptr {
65 return nil, ErrInvalidSpecification
66 }
67 s = s.Elem()
68 if s.Kind() != reflect.Struct {
69 return nil, ErrInvalidSpecification
70 }
71 typeOfSpec := s.Type()
72
73
74 infos := make([]varInfo, 0, s.NumField())
75 for i := 0; i < s.NumField(); i++ {
76 f := s.Field(i)
77 ftype := typeOfSpec.Field(i)
78 if !f.CanSet() || isTrue(ftype.Tag.Get("ignored")) {
79 continue
80 }
81
82 for f.Kind() == reflect.Ptr {
83 if f.IsNil() {
84 if f.Type().Elem().Kind() != reflect.Struct {
85
86 break
87 }
88
89 f.Set(reflect.New(f.Type().Elem()))
90 }
91 f = f.Elem()
92 }
93
94
95 info := varInfo{
96 Name: ftype.Name,
97 Field: f,
98 Tags: ftype.Tag,
99 Alt: strings.ToUpper(ftype.Tag.Get("envconfig")),
100 }
101
102
103 info.Key = info.Name
104
105
106 if isTrue(ftype.Tag.Get("split_words")) {
107 words := gatherRegexp.FindAllStringSubmatch(ftype.Name, -1)
108 if len(words) > 0 {
109 var name []string
110 for _, words := range words {
111 if m := acronymRegexp.FindStringSubmatch(words[0]); len(m) == 3 {
112 name = append(name, m[1], m[2])
113 } else {
114 name = append(name, words[0])
115 }
116 }
117
118 info.Key = strings.Join(name, "_")
119 }
120 }
121 if info.Alt != "" {
122 info.Key = info.Alt
123 }
124 if prefix != "" {
125 info.Key = fmt.Sprintf("%s_%s", prefix, info.Key)
126 }
127 info.Key = strings.ToUpper(info.Key)
128 infos = append(infos, info)
129
130 if f.Kind() == reflect.Struct {
131
132 if decoderFrom(f) == nil && setterFrom(f) == nil && textUnmarshaler(f) == nil && binaryUnmarshaler(f) == nil {
133 innerPrefix := prefix
134 if !ftype.Anonymous {
135 innerPrefix = info.Key
136 }
137
138 embeddedPtr := f.Addr().Interface()
139 embeddedInfos, err := gatherInfo(innerPrefix, embeddedPtr)
140 if err != nil {
141 return nil, err
142 }
143 infos = append(infos[:len(infos)-1], embeddedInfos...)
144
145 continue
146 }
147 }
148 }
149 return infos, nil
150 }
151
152
153
154
155 func CheckDisallowed(prefix string, spec interface{}) error {
156 infos, err := gatherInfo(prefix, spec)
157 if err != nil {
158 return err
159 }
160
161 vars := make(map[string]struct{})
162 for _, info := range infos {
163 vars[info.Key] = struct{}{}
164 }
165
166 if prefix != "" {
167 prefix = strings.ToUpper(prefix) + "_"
168 }
169
170 for _, env := range os.Environ() {
171 if !strings.HasPrefix(env, prefix) {
172 continue
173 }
174 v := strings.SplitN(env, "=", 2)[0]
175 if _, found := vars[v]; !found {
176 return fmt.Errorf("unknown environment variable %s", v)
177 }
178 }
179
180 return nil
181 }
182
183
184 func Process(prefix string, spec interface{}) error {
185 infos, err := gatherInfo(prefix, spec)
186
187 for _, info := range infos {
188
189
190
191
192
193 value, ok := lookupEnv(info.Key)
194 if !ok && info.Alt != "" {
195 value, ok = lookupEnv(info.Alt)
196 }
197
198 def := info.Tags.Get("default")
199 if def != "" && !ok {
200 value = def
201 }
202
203 req := info.Tags.Get("required")
204 if !ok && def == "" {
205 if isTrue(req) {
206 key := info.Key
207 if info.Alt != "" {
208 key = info.Alt
209 }
210 return fmt.Errorf("required key %s missing value", key)
211 }
212 continue
213 }
214
215 err = processField(value, info.Field)
216 if err != nil {
217 return &ParseError{
218 KeyName: info.Key,
219 FieldName: info.Name,
220 TypeName: info.Field.Type().String(),
221 Value: value,
222 Err: err,
223 }
224 }
225 }
226
227 return err
228 }
229
230
231 func MustProcess(prefix string, spec interface{}) {
232 if err := Process(prefix, spec); err != nil {
233 panic(err)
234 }
235 }
236
237 func processField(value string, field reflect.Value) error {
238 typ := field.Type()
239
240 decoder := decoderFrom(field)
241 if decoder != nil {
242 return decoder.Decode(value)
243 }
244
245 setter := setterFrom(field)
246 if setter != nil {
247 return setter.Set(value)
248 }
249
250 if t := textUnmarshaler(field); t != nil {
251 return t.UnmarshalText([]byte(value))
252 }
253
254 if b := binaryUnmarshaler(field); b != nil {
255 return b.UnmarshalBinary([]byte(value))
256 }
257
258 if typ.Kind() == reflect.Ptr {
259 typ = typ.Elem()
260 if field.IsNil() {
261 field.Set(reflect.New(typ))
262 }
263 field = field.Elem()
264 }
265
266 switch typ.Kind() {
267 case reflect.String:
268 field.SetString(value)
269 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
270 var (
271 val int64
272 err error
273 )
274 if field.Kind() == reflect.Int64 && typ.PkgPath() == "time" && typ.Name() == "Duration" {
275 var d time.Duration
276 d, err = time.ParseDuration(value)
277 val = int64(d)
278 } else {
279 val, err = strconv.ParseInt(value, 0, typ.Bits())
280 }
281 if err != nil {
282 return err
283 }
284
285 field.SetInt(val)
286 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
287 val, err := strconv.ParseUint(value, 0, typ.Bits())
288 if err != nil {
289 return err
290 }
291 field.SetUint(val)
292 case reflect.Bool:
293 val, err := strconv.ParseBool(value)
294 if err != nil {
295 return err
296 }
297 field.SetBool(val)
298 case reflect.Float32, reflect.Float64:
299 val, err := strconv.ParseFloat(value, typ.Bits())
300 if err != nil {
301 return err
302 }
303 field.SetFloat(val)
304 case reflect.Slice:
305 sl := reflect.MakeSlice(typ, 0, 0)
306 if typ.Elem().Kind() == reflect.Uint8 {
307 sl = reflect.ValueOf([]byte(value))
308 } else if len(strings.TrimSpace(value)) != 0 {
309 vals := strings.Split(value, ",")
310 sl = reflect.MakeSlice(typ, len(vals), len(vals))
311 for i, val := range vals {
312 err := processField(val, sl.Index(i))
313 if err != nil {
314 return err
315 }
316 }
317 }
318 field.Set(sl)
319 case reflect.Map:
320 mp := reflect.MakeMap(typ)
321 if len(strings.TrimSpace(value)) != 0 {
322 pairs := strings.Split(value, ",")
323 for _, pair := range pairs {
324 kvpair := strings.Split(pair, ":")
325 if len(kvpair) != 2 {
326 return fmt.Errorf("invalid map item: %q", pair)
327 }
328 k := reflect.New(typ.Key()).Elem()
329 err := processField(kvpair[0], k)
330 if err != nil {
331 return err
332 }
333 v := reflect.New(typ.Elem()).Elem()
334 err = processField(kvpair[1], v)
335 if err != nil {
336 return err
337 }
338 mp.SetMapIndex(k, v)
339 }
340 }
341 field.Set(mp)
342 }
343
344 return nil
345 }
346
// interfaceFrom offers field to fn as an interface{} value. If fn does not
// report a match through its *bool argument and field is addressable, fn is
// called a second time with a pointer to field. This lets pointer-receiver
// implementations be found. Fields that cannot be converted to interface{}
// (e.g. obtained through unexported struct fields) are never offered.
func interfaceFrom(field reflect.Value, fn func(interface{}, *bool)) {
	if !field.CanInterface() {
		return
	}
	var found bool
	fn(field.Interface(), &found)
	if found || !field.CanAddr() {
		return
	}
	fn(field.Addr().Interface(), &found)
}
358
359 func decoderFrom(field reflect.Value) (d Decoder) {
360 interfaceFrom(field, func(v interface{}, ok *bool) { d, *ok = v.(Decoder) })
361 return d
362 }
363
364 func setterFrom(field reflect.Value) (s Setter) {
365 interfaceFrom(field, func(v interface{}, ok *bool) { s, *ok = v.(Setter) })
366 return s
367 }
368
369 func textUnmarshaler(field reflect.Value) (t encoding.TextUnmarshaler) {
370 interfaceFrom(field, func(v interface{}, ok *bool) { t, *ok = v.(encoding.TextUnmarshaler) })
371 return t
372 }
373
374 func binaryUnmarshaler(field reflect.Value) (b encoding.BinaryUnmarshaler) {
375 interfaceFrom(field, func(v interface{}, ok *bool) { b, *ok = v.(encoding.BinaryUnmarshaler) })
376 return b
377 }
378
// isTrue reports whether s is a true value per strconv.ParseBool ("1", "t",
// "T", "true", "TRUE", "True"). Anything unparsable, including "", counts as
// false.
func isTrue(s string) bool {
	if parsed, err := strconv.ParseBool(s); err == nil {
		return parsed
	}
	return false
}
383
View as plain text