...

Source file src/github.com/sigstore/rekor/cmd/rekor-cli/app/pflags.go

Documentation: github.com/sigstore/rekor/cmd/rekor-cli/app

     1  //
     2  // Copyright 2021 The Sigstore Authors.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //     http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  
    16  package app
    17  
    18  import (
    19  	"encoding/base64"
    20  	"errors"
    21  	"fmt"
    22  	"log"
    23  	"os"
    24  	"path/filepath"
    25  	"strconv"
    26  	"strings"
    27  	"time"
    28  
    29  	"github.com/sigstore/rekor/pkg/pki"
    30  	"github.com/sigstore/rekor/pkg/sharding"
    31  
    32  	"github.com/spf13/pflag"
    33  
    34  	validator "github.com/asaskevich/govalidator"
    35  )
    36  
// FlagType names a kind of validated command-line flag value. The same string
// is what a flag value created for that kind reports from its Type method.
type FlagType string

// The flag types for which initializePFlagMap registers a value constructor.
const (
	uuidFlag           FlagType = "uuid"
	shaFlag            FlagType = "sha"
	emailFlag          FlagType = "email"
	operatorFlag       FlagType = "operator"
	logIndexFlag       FlagType = "logIndex"
	pkiFormatFlag      FlagType = "pkiFormat"
	typeFlag           FlagType = "type"
	fileFlag           FlagType = "file"
	urlFlag            FlagType = "url"
	fileOrURLFlag      FlagType = "fileOrURL"
	multiFileOrURLFlag FlagType = "multiFileOrURL"
	oidFlag            FlagType = "oid"
	formatFlag         FlagType = "format"
	timeoutFlag        FlagType = "timeout"
	base64Flag         FlagType = "base64"
	uintFlag           FlagType = "uint"
)
    57  
// newPFlagValueFunc constructs a fresh pflag.Value.
type newPFlagValueFunc func() pflag.Value

// pflagValueFuncMap maps each FlagType to the constructor of its validated
// value. It is nil until initializePFlagMap has been called.
var pflagValueFuncMap map[FlagType]newPFlagValueFunc
    61  
    62  // TODO: unit tests for all of this
    63  func initializePFlagMap() {
    64  	pflagValueFuncMap = map[FlagType]newPFlagValueFunc{
    65  		uuidFlag: func() pflag.Value {
    66  			// this validates a UUID with or without a prepended TreeID;
    67  			// the UUID corresponds to the merkle leaf hash of entries,
    68  			// which is represented by a 64 character hexadecimal string
    69  			return valueFactory(uuidFlag, validateID, "")
    70  		},
    71  		shaFlag: func() pflag.Value {
    72  			// this validates a valid sha256 checksum which is optionally prefixed with 'sha256:'
    73  			return valueFactory(shaFlag, validateSHAValue, "")
    74  		},
    75  		operatorFlag: func() pflag.Value {
    76  			// this validates a valid operator name
    77  			operatorFlagValidator := func(val string) error {
    78  				o := struct {
    79  					Value string `valid:"in(and|or)"`
    80  				}{val}
    81  				_, err := validator.ValidateStruct(o)
    82  				return err
    83  			}
    84  			return valueFactory(operatorFlag, operatorFlagValidator, "")
    85  		},
    86  		emailFlag: func() pflag.Value {
    87  			// this validates an email address
    88  			emailValidator := func(val string) error {
    89  				if !validator.IsEmail(val) {
    90  					return fmt.Errorf("'%v' is not a valid email address", val)
    91  				}
    92  				return nil
    93  			}
    94  			return valueFactory(emailFlag, emailValidator, "")
    95  		},
    96  		logIndexFlag: func() pflag.Value {
    97  			// this checks for a valid integer >= 0
    98  			return valueFactory(logIndexFlag, validateUint, "")
    99  		},
   100  		pkiFormatFlag: func() pflag.Value {
   101  			// this ensures a PKI implementation exists for the requested format
   102  			pkiFormatValidator := func(val string) error {
   103  				if !validator.IsIn(val, pki.SupportedFormats()...) {
   104  					return fmt.Errorf("'%v' is not a valid pki format", val)
   105  				}
   106  				return nil
   107  			}
   108  			return valueFactory(pkiFormatFlag, pkiFormatValidator, "pgp")
   109  		},
   110  		typeFlag: func() pflag.Value {
   111  			// this ensures the type of the log entry matches a type supported in the CLI
   112  			return valueFactory(typeFlag, validateTypeFlag, "rekord")
   113  		},
   114  		fileFlag: func() pflag.Value {
   115  			// this validates that the file exists and can be opened by the current uid
   116  			return valueFactory(fileFlag, validateFile, "")
   117  		},
   118  		urlFlag: func() pflag.Value {
   119  			// this validates that the string is a valid http/https URL
   120  			httpHTTPSValidator := func(val string) error {
   121  				if !validator.IsURL(val) {
   122  					return fmt.Errorf("'%v' is not a valid url", val)
   123  				}
   124  				if !(strings.HasPrefix(val, "http") || strings.HasPrefix(val, "https")) {
   125  					return errors.New("URL must be for http or https scheme")
   126  				}
   127  				return nil
   128  			}
   129  			return valueFactory(urlFlag, httpHTTPSValidator, "")
   130  		},
   131  		fileOrURLFlag: func() pflag.Value {
   132  			// applies logic of fileFlag OR urlFlag validators from above
   133  			return valueFactory(fileOrURLFlag, validateFileOrURL, "")
   134  		},
   135  		multiFileOrURLFlag: func() pflag.Value {
   136  			// applies logic of fileFlag OR urlFlag validators from above for multi file and URL
   137  			return multiValueFactory(multiFileOrURLFlag, validateFileOrURL, []string{})
   138  		},
   139  		oidFlag: func() pflag.Value {
   140  			// this validates for an OID, which is a sequence of positive integers separated by periods
   141  			return valueFactory(oidFlag, validateOID, "")
   142  		},
   143  		formatFlag: func() pflag.Value {
   144  			// this validates the output format requested
   145  			formatValidator := func(val string) error {
   146  				if !validator.IsIn(val, "json", "default", "tle") {
   147  					return fmt.Errorf("'%v' is not a valid output format", val)
   148  				}
   149  				return nil
   150  			}
   151  			return valueFactory(formatFlag, formatValidator, "")
   152  		},
   153  		timeoutFlag: func() pflag.Value {
   154  			// this validates the timeout is >= 0
   155  			return valueFactory(formatFlag, validateTimeout, "")
   156  		},
   157  		base64Flag: func() pflag.Value {
   158  			// This validates the string is in base64 format
   159  			return valueFactory(base64Flag, validateBase64, "")
   160  		},
   161  		uintFlag: func() pflag.Value {
   162  			// This validates the string is in base64 format
   163  			return valueFactory(uintFlag, validateUint, "")
   164  		},
   165  	}
   166  }
   167  
   168  // NewFlagValue creates a new pflag.Value for the specified type with the specified default value.
   169  // If a default value is not desired, pass "" for defaultVal.
   170  func NewFlagValue(flagType FlagType, defaultVal string) pflag.Value {
   171  	valFunc := pflagValueFuncMap[flagType]
   172  	val := valFunc()
   173  	if defaultVal != "" {
   174  		if err := val.Set(defaultVal); err != nil {
   175  			log.Fatal(fmt.Errorf("initializing flag: %w", err))
   176  		}
   177  	}
   178  	return val
   179  }
   180  
// validationFunc checks a raw flag argument and returns a non-nil error when
// the argument is not acceptable.
type validationFunc func(string) error
   182  
   183  func valueFactory(flagType FlagType, v validationFunc, defaultVal string) pflag.Value {
   184  	return &baseValue{
   185  		flagType:       flagType,
   186  		validationFunc: v,
   187  		value:          defaultVal,
   188  	}
   189  }
   190  
   191  func multiValueFactory(flagType FlagType, v validationFunc, defaultVal []string) pflag.Value {
   192  	return &multiBaseValue{
   193  		flagType:       flagType,
   194  		validationFunc: v,
   195  		value:          defaultVal,
   196  	}
   197  }
   198  
   199  // multiBaseValue implements pflag.Value
   200  type multiBaseValue struct {
   201  	flagType       FlagType
   202  	value          []string
   203  	validationFunc validationFunc
   204  }
   205  
   206  func (b *multiBaseValue) String() string {
   207  	return strings.Join(b.value, ",")
   208  }
   209  
   210  // Type returns the type of this Value
   211  func (b multiBaseValue) Type() string {
   212  	return string(b.flagType)
   213  }
   214  
   215  func (b *multiBaseValue) Set(value string) error {
   216  	if err := b.validationFunc(value); err != nil {
   217  		return err
   218  	}
   219  	b.value = append(b.value, value)
   220  	return nil
   221  }
   222  
   223  // baseValue implements pflag.Value
   224  type baseValue struct {
   225  	flagType       FlagType
   226  	value          string
   227  	validationFunc validationFunc
   228  }
   229  
   230  // Type returns the type of this Value
   231  func (b baseValue) Type() string {
   232  	return string(b.flagType)
   233  }
   234  
   235  // String returns the string representation of this Value
   236  func (b baseValue) String() string {
   237  	return b.value
   238  }
   239  
   240  // Set validates the provided string against the appropriate validation rule
   241  // for b.flagType; if the string validates, it is stored in the Value and nil is returned.
   242  // Otherwise the validation error is returned but the state of the Value is not changed.
   243  func (b *baseValue) Set(s string) error {
   244  	if err := b.validationFunc(s); err != nil {
   245  		return err
   246  	}
   247  	b.value = s
   248  	return nil
   249  }
   250  
   251  // isURL returns true if the supplied value is a valid URL and false otherwise
   252  func isURL(v string) bool {
   253  	valGen := pflagValueFuncMap[urlFlag]
   254  	return valGen().Set(v) == nil
   255  }
   256  
   257  // validateSHAValue ensures that the supplied string matches the following formats:
   258  // [sha512:]<128 hexadecimal characters>
   259  // [sha256:]<64 hexadecimal characters>
   260  // [sha1:]<40 hexadecimal characters>
   261  // where [sha256:] and [sha1:] are optional
   262  func validateSHAValue(v string) error {
   263  	err := validateSHA1Value(v)
   264  	if err == nil {
   265  		return nil
   266  	}
   267  
   268  	err = validateSHA256Value(v)
   269  	if err == nil {
   270  		return nil
   271  	}
   272  
   273  	err = validateSHA512Value(v)
   274  	if err == nil {
   275  		return nil
   276  	}
   277  
   278  	return fmt.Errorf("error parsing %v flag: %w", shaFlag, err)
   279  }
   280  
   281  // validateFileOrURL ensures the provided string is either a valid file path that can be opened or a valid URL
   282  func validateFileOrURL(v string) error {
   283  	valGen := pflagValueFuncMap[fileFlag]
   284  	if valGen().Set(v) == nil {
   285  		return nil
   286  	}
   287  	valGen = pflagValueFuncMap[urlFlag]
   288  	return valGen().Set(v)
   289  }
   290  
   291  // validateID ensures the ID is either an EntryID (TreeID + UUID) or a UUID
   292  func validateID(v string) error {
   293  	if len(v) != sharding.EntryIDHexStringLen && len(v) != sharding.UUIDHexStringLen {
   294  		return fmt.Errorf("ID len error, expected %v (EntryID) or %v (UUID) but got len %v for ID %v", sharding.EntryIDHexStringLen, sharding.UUIDHexStringLen, len(v), v)
   295  	}
   296  
   297  	if !validator.IsHexadecimal(v) {
   298  		return fmt.Errorf("invalid uuid: %v", v)
   299  	}
   300  
   301  	return nil
   302  }
   303  
   304  // validateOID ensures that the supplied string is a valid ASN.1 object identifier
   305  func validateOID(v string) error {
   306  	values := strings.Split(v, ".")
   307  	for _, value := range values {
   308  		if !validator.IsNumeric(value) {
   309  			return fmt.Errorf("field '%v' is not a valid number", value)
   310  		}
   311  	}
   312  
   313  	return nil
   314  }
   315  
   316  // validateTimeout ensures that the supplied string is a valid time.Duration value >= 0
   317  func validateTimeout(v string) error {
   318  	duration, err := time.ParseDuration(v)
   319  	if err != nil {
   320  		return err
   321  	}
   322  	if duration < 0 {
   323  		return errors.New("timeout must be a positive value")
   324  	}
   325  	return nil
   326  }
   327  
   328  // validateBase64 ensures that the supplied string is valid base64 encoded data
   329  func validateBase64(v string) error {
   330  	_, err := base64.StdEncoding.DecodeString(v)
   331  
   332  	return err
   333  }
   334  
   335  // validateTypeFlag ensures that the string is in the format type(\.version)? and
   336  // that one of the types requested is implemented
   337  func validateTypeFlag(v string) error {
   338  	_, _, err := ParseTypeFlag(v)
   339  	return err
   340  }
   341  
   342  // validateUint ensures that the supplied string is a valid unsigned integer >= 0
   343  func validateUint(v string) error {
   344  	i, err := strconv.Atoi(v)
   345  	if err != nil {
   346  		return err
   347  	}
   348  	if i < 0 {
   349  		return fmt.Errorf("invalid unsigned int: %v", v)
   350  	}
   351  	return nil
   352  }
   353  
   354  // validateFile ensures that the supplied string is a valid path to a file that exists
   355  func validateFile(v string) error {
   356  	fileInfo, err := os.Stat(filepath.Clean(v))
   357  	if err != nil {
   358  		return err
   359  	}
   360  	if fileInfo.IsDir() {
   361  		return errors.New("path to a directory was provided")
   362  	}
   363  	return nil
   364  }
   365  

View as plain text