...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package app
17
18 import (
19 "fmt"
20 "log"
21 "strings"
22 "time"
23
24 validator "github.com/go-playground/validator/v10"
25 "github.com/pkg/errors"
26 "github.com/spf13/pflag"
27 )
28
// FlagType identifies a kind of command-line flag value handled by this
// package. Its string form is what baseValue.Type reports.
type FlagType string

// The flag kinds this package knows how to construct and validate.
const (
	fileFlag    = FlagType("file")
	urlFlag     = FlagType("url")
	oidFlag     = FlagType("oid")
	formatFlag  = FlagType("format")
	timeoutFlag = FlagType("timeout")
)
38
39 type newPFlagValueFunc func() pflag.Value
40
41 var pflagValueFuncMap map[FlagType]newPFlagValueFunc
42
43
44 func initializePFlagMap() {
45 pflagValueFuncMap = map[FlagType]newPFlagValueFunc{
46 fileFlag: func() pflag.Value {
47
48 return valueFactory(fileFlag, validateString("required,file"), "")
49 },
50 urlFlag: func() pflag.Value {
51
52 return valueFactory(urlFlag, validateString("required,url,startswith=http|startswith=https"), "")
53 },
54 oidFlag: func() pflag.Value {
55
56 return valueFactory(oidFlag, validateOID, "")
57 },
58 formatFlag: func() pflag.Value {
59
60 return valueFactory(formatFlag, validateString("required,oneof=json default"), "")
61 },
62 timeoutFlag: func() pflag.Value {
63
64 return valueFactory(formatFlag, validateTimeout, "")
65 },
66 }
67 }
68
69
70
71 func NewFlagValue(flagType FlagType, defaultVal string) pflag.Value {
72 valFunc := pflagValueFuncMap[flagType]
73 val := valFunc()
74 if defaultVal != "" {
75 if err := val.Set(defaultVal); err != nil {
76 log.Fatal(errors.Wrap(err, "initializing flag"))
77 }
78 }
79 return val
80 }
81
82 type validationFunc func(string) error
83
84 func valueFactory(flagType FlagType, v validationFunc, defaultVal string) pflag.Value {
85 return &baseValue{
86 flagType: flagType,
87 validationFunc: v,
88 value: defaultVal,
89 }
90 }
91
92
93 type baseValue struct {
94 flagType FlagType
95 value string
96 validationFunc validationFunc
97 }
98
99
100 func (b baseValue) Type() string {
101 return string(b.flagType)
102 }
103
104
105 func (b baseValue) String() string {
106 return b.value
107 }
108
109
110
111
112 func (b *baseValue) Set(s string) error {
113 if err := b.validationFunc(s); err != nil {
114 return err
115 }
116 b.value = s
117 return nil
118 }
119
120
121 func validateOID(v string) error {
122 o := struct {
123 Oid []string `validate:"dive,numeric"`
124 }{strings.Split(v, ".")}
125
126 return useValidator(oidFlag, o)
127 }
128
129
130 func validateTimeout(v string) error {
131 duration, err := time.ParseDuration(v)
132 if err != nil {
133 return err
134 }
135 d := struct {
136 Duration time.Duration `validate:"min=0"`
137 }{duration}
138 return useValidator(timeoutFlag, d)
139 }
140
141
142
143 func validateString(tag string) validationFunc {
144 return func(v string) error {
145 validator := validator.New()
146 return validator.Var(v, tag)
147 }
148 }
149
150
151
152 func useValidator(flagType FlagType, s interface{}) error {
153 validate := validator.New()
154 if err := validate.Struct(s); err != nil {
155 return fmt.Errorf("error parsing %v flag: %w", flagType, err)
156 }
157
158 return nil
159 }
160
View as plain text