1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package app
17
18 import (
19 "encoding/base64"
20 "errors"
21 "fmt"
22 "log"
23 "os"
24 "path/filepath"
25 "strconv"
26 "strings"
27 "time"
28
29 "github.com/sigstore/rekor/pkg/pki"
30 "github.com/sigstore/rekor/pkg/sharding"
31
32 "github.com/spf13/pflag"
33
34 validator "github.com/asaskevich/govalidator"
35 )
36
// FlagType names a category of CLI flag value. Its string form is used as the
// key into pflagValueFuncMap and as the type name reported by flag values.
type FlagType string

// The flag categories understood by this package.
const (
	uuidFlag           = FlagType("uuid")
	shaFlag            = FlagType("sha")
	emailFlag          = FlagType("email")
	operatorFlag       = FlagType("operator")
	logIndexFlag       = FlagType("logIndex")
	pkiFormatFlag      = FlagType("pkiFormat")
	typeFlag           = FlagType("type")
	fileFlag           = FlagType("file")
	urlFlag            = FlagType("url")
	fileOrURLFlag      = FlagType("fileOrURL")
	multiFileOrURLFlag = FlagType("multiFileOrURL")
	oidFlag            = FlagType("oid")
	formatFlag         = FlagType("format")
	timeoutFlag        = FlagType("timeout")
	base64Flag         = FlagType("base64")
	uintFlag           = FlagType("uint")
)
57
58 type newPFlagValueFunc func() pflag.Value
59
60 var pflagValueFuncMap map[FlagType]newPFlagValueFunc
61
62
63 func initializePFlagMap() {
64 pflagValueFuncMap = map[FlagType]newPFlagValueFunc{
65 uuidFlag: func() pflag.Value {
66
67
68
69 return valueFactory(uuidFlag, validateID, "")
70 },
71 shaFlag: func() pflag.Value {
72
73 return valueFactory(shaFlag, validateSHAValue, "")
74 },
75 operatorFlag: func() pflag.Value {
76
77 operatorFlagValidator := func(val string) error {
78 o := struct {
79 Value string `valid:"in(and|or)"`
80 }{val}
81 _, err := validator.ValidateStruct(o)
82 return err
83 }
84 return valueFactory(operatorFlag, operatorFlagValidator, "")
85 },
86 emailFlag: func() pflag.Value {
87
88 emailValidator := func(val string) error {
89 if !validator.IsEmail(val) {
90 return fmt.Errorf("'%v' is not a valid email address", val)
91 }
92 return nil
93 }
94 return valueFactory(emailFlag, emailValidator, "")
95 },
96 logIndexFlag: func() pflag.Value {
97
98 return valueFactory(logIndexFlag, validateUint, "")
99 },
100 pkiFormatFlag: func() pflag.Value {
101
102 pkiFormatValidator := func(val string) error {
103 if !validator.IsIn(val, pki.SupportedFormats()...) {
104 return fmt.Errorf("'%v' is not a valid pki format", val)
105 }
106 return nil
107 }
108 return valueFactory(pkiFormatFlag, pkiFormatValidator, "pgp")
109 },
110 typeFlag: func() pflag.Value {
111
112 return valueFactory(typeFlag, validateTypeFlag, "rekord")
113 },
114 fileFlag: func() pflag.Value {
115
116 return valueFactory(fileFlag, validateFile, "")
117 },
118 urlFlag: func() pflag.Value {
119
120 httpHTTPSValidator := func(val string) error {
121 if !validator.IsURL(val) {
122 return fmt.Errorf("'%v' is not a valid url", val)
123 }
124 if !(strings.HasPrefix(val, "http") || strings.HasPrefix(val, "https")) {
125 return errors.New("URL must be for http or https scheme")
126 }
127 return nil
128 }
129 return valueFactory(urlFlag, httpHTTPSValidator, "")
130 },
131 fileOrURLFlag: func() pflag.Value {
132
133 return valueFactory(fileOrURLFlag, validateFileOrURL, "")
134 },
135 multiFileOrURLFlag: func() pflag.Value {
136
137 return multiValueFactory(multiFileOrURLFlag, validateFileOrURL, []string{})
138 },
139 oidFlag: func() pflag.Value {
140
141 return valueFactory(oidFlag, validateOID, "")
142 },
143 formatFlag: func() pflag.Value {
144
145 formatValidator := func(val string) error {
146 if !validator.IsIn(val, "json", "default", "tle") {
147 return fmt.Errorf("'%v' is not a valid output format", val)
148 }
149 return nil
150 }
151 return valueFactory(formatFlag, formatValidator, "")
152 },
153 timeoutFlag: func() pflag.Value {
154
155 return valueFactory(formatFlag, validateTimeout, "")
156 },
157 base64Flag: func() pflag.Value {
158
159 return valueFactory(base64Flag, validateBase64, "")
160 },
161 uintFlag: func() pflag.Value {
162
163 return valueFactory(uintFlag, validateUint, "")
164 },
165 }
166 }
167
168
169
170 func NewFlagValue(flagType FlagType, defaultVal string) pflag.Value {
171 valFunc := pflagValueFuncMap[flagType]
172 val := valFunc()
173 if defaultVal != "" {
174 if err := val.Set(defaultVal); err != nil {
175 log.Fatal(fmt.Errorf("initializing flag: %w", err))
176 }
177 }
178 return val
179 }
180
181 type validationFunc func(string) error
182
183 func valueFactory(flagType FlagType, v validationFunc, defaultVal string) pflag.Value {
184 return &baseValue{
185 flagType: flagType,
186 validationFunc: v,
187 value: defaultVal,
188 }
189 }
190
191 func multiValueFactory(flagType FlagType, v validationFunc, defaultVal []string) pflag.Value {
192 return &multiBaseValue{
193 flagType: flagType,
194 validationFunc: v,
195 value: defaultVal,
196 }
197 }
198
199
200 type multiBaseValue struct {
201 flagType FlagType
202 value []string
203 validationFunc validationFunc
204 }
205
206 func (b *multiBaseValue) String() string {
207 return strings.Join(b.value, ",")
208 }
209
210
211 func (b multiBaseValue) Type() string {
212 return string(b.flagType)
213 }
214
215 func (b *multiBaseValue) Set(value string) error {
216 if err := b.validationFunc(value); err != nil {
217 return err
218 }
219 b.value = append(b.value, value)
220 return nil
221 }
222
223
224 type baseValue struct {
225 flagType FlagType
226 value string
227 validationFunc validationFunc
228 }
229
230
231 func (b baseValue) Type() string {
232 return string(b.flagType)
233 }
234
235
236 func (b baseValue) String() string {
237 return b.value
238 }
239
240
241
242
243 func (b *baseValue) Set(s string) error {
244 if err := b.validationFunc(s); err != nil {
245 return err
246 }
247 b.value = s
248 return nil
249 }
250
251
252 func isURL(v string) bool {
253 valGen := pflagValueFuncMap[urlFlag]
254 return valGen().Set(v) == nil
255 }
256
257
258
259
260
261
262 func validateSHAValue(v string) error {
263 err := validateSHA1Value(v)
264 if err == nil {
265 return nil
266 }
267
268 err = validateSHA256Value(v)
269 if err == nil {
270 return nil
271 }
272
273 err = validateSHA512Value(v)
274 if err == nil {
275 return nil
276 }
277
278 return fmt.Errorf("error parsing %v flag: %w", shaFlag, err)
279 }
280
281
282 func validateFileOrURL(v string) error {
283 valGen := pflagValueFuncMap[fileFlag]
284 if valGen().Set(v) == nil {
285 return nil
286 }
287 valGen = pflagValueFuncMap[urlFlag]
288 return valGen().Set(v)
289 }
290
291
292 func validateID(v string) error {
293 if len(v) != sharding.EntryIDHexStringLen && len(v) != sharding.UUIDHexStringLen {
294 return fmt.Errorf("ID len error, expected %v (EntryID) or %v (UUID) but got len %v for ID %v", sharding.EntryIDHexStringLen, sharding.UUIDHexStringLen, len(v), v)
295 }
296
297 if !validator.IsHexadecimal(v) {
298 return fmt.Errorf("invalid uuid: %v", v)
299 }
300
301 return nil
302 }
303
304
// validateOID checks that v is a dotted object identifier such as
// "1.2.840.113549". Every dot-separated arc must be a non-empty run of ASCII
// digits.
//
// govalidator.IsNumeric, used previously, treats the empty string as valid.
// That let "", "1..2", ".1" and "1." through as OIDs. Empty arcs are now
// rejected explicitly. For non-empty arcs the check matches IsNumeric's
// ^[0-9]+$ pattern.
func validateOID(v string) error {
	for _, arc := range strings.Split(v, ".") {
		if arc == "" || strings.TrimLeft(arc, "0123456789") != "" {
			return fmt.Errorf("field '%v' is not a valid number", arc)
		}
	}
	return nil
}
315
316
// validateTimeout checks that v parses with time.ParseDuration (e.g. "30s",
// "1m30s") and is not negative. Zero is accepted; only negative durations are
// rejected, despite the wording of the error message.
func validateTimeout(v string) error {
	d, err := time.ParseDuration(v)
	if err == nil && d < 0 {
		err = errors.New("timeout must be a positive value")
	}
	return err
}
327
328
// validateBase64 checks that v is valid standard, padded base64 (RFC 4648
// StdEncoding). It returns the decoder's error unchanged.
func validateBase64(v string) error {
	if _, err := base64.StdEncoding.DecodeString(v); err != nil {
		return err
	}
	return nil
}
334
335
336
337 func validateTypeFlag(v string) error {
338 _, _, err := ParseTypeFlag(v)
339 return err
340 }
341
342
// validateUint checks that v is a base-10 integer (strconv.Atoi syntax, so a
// leading '+' is allowed) whose value is not negative. Parse errors are
// returned unchanged from strconv.
func validateUint(v string) error {
	n, err := strconv.Atoi(v)
	switch {
	case err != nil:
		return err
	case n < 0:
		return fmt.Errorf("invalid unsigned int: %v", v)
	}
	return nil
}
353
354
// validateFile checks that v (after filepath.Clean) can be stat'ed and is not
// a directory. The os.Stat error is returned unchanged.
func validateFile(v string) error {
	info, err := os.Stat(filepath.Clean(v))
	if err == nil && info.IsDir() {
		err = errors.New("path to a directory was provided")
	}
	return err
}
365
View as plain text