...

Source file src/github.com/letsencrypt/boulder/test/asserts.go

Documentation: github.com/letsencrypt/boulder/test

     1  package test
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"errors"
     8  	"reflect"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/prometheus/client_golang/prometheus"
    14  	io_prometheus_client "github.com/prometheus/client_model/go"
    15  )
    16  
    17  // Assert a boolean
    18  func Assert(t *testing.T, result bool, message string) {
    19  	t.Helper()
    20  	if !result {
    21  		t.Fatal(message)
    22  	}
    23  }
    24  
    25  // AssertNil checks that an object is nil. Being a "boxed nil" (a nil value
    26  // wrapped in a non-nil interface type) is not good enough.
    27  func AssertNil(t *testing.T, obj interface{}, message string) {
    28  	t.Helper()
    29  	if obj != nil {
    30  		t.Fatal(message)
    31  	}
    32  }
    33  
    34  // AssertNotNil checks an object to be non-nil. Being a "boxed nil" (a nil value
    35  // wrapped in a non-nil interface type) is not good enough.
    36  // Note that there is a gap between AssertNil and AssertNotNil. Both fail when
    37  // called with a boxed nil. This is intentional: we want to avoid boxed nils.
    38  func AssertNotNil(t *testing.T, obj interface{}, message string) {
    39  	t.Helper()
    40  	if obj == nil {
    41  		t.Fatal(message)
    42  	}
    43  	switch reflect.TypeOf(obj).Kind() {
    44  	// .IsNil() only works on chan, func, interface, map, pointer, and slice.
    45  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
    46  		if reflect.ValueOf(obj).IsNil() {
    47  			t.Fatal(message)
    48  		}
    49  	}
    50  }
    51  
    52  // AssertBoxedNil checks that an inner object is nil. This is intentional for
    53  // testing purposes only.
    54  func AssertBoxedNil(t *testing.T, obj interface{}, message string) {
    55  	t.Helper()
    56  	typ := reflect.TypeOf(obj).Kind()
    57  	switch typ {
    58  	// .IsNil() only works on chan, func, interface, map, pointer, and slice.
    59  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
    60  		if !reflect.ValueOf(obj).IsNil() {
    61  			t.Fatal(message)
    62  		}
    63  	default:
    64  		t.Fatalf("Cannot check type \"%s\". Needs to be of type chan, func, interface, map, pointer, or slice.", typ)
    65  	}
    66  }
    67  
    68  // AssertNotError checks that err is nil
    69  func AssertNotError(t *testing.T, err error, message string) {
    70  	t.Helper()
    71  	if err != nil {
    72  		t.Fatalf("%s: %s", message, err)
    73  	}
    74  }
    75  
    76  // AssertError checks that err is non-nil
    77  func AssertError(t *testing.T, err error, message string) {
    78  	t.Helper()
    79  	if err == nil {
    80  		t.Fatalf("%s: expected error but received none", message)
    81  	}
    82  }
    83  
    84  // AssertErrorWraps checks that err can be unwrapped into the given target.
    85  // NOTE: Has the side effect of actually performing that unwrapping.
    86  func AssertErrorWraps(t *testing.T, err error, target interface{}) {
    87  	t.Helper()
    88  	if !errors.As(err, target) {
    89  		t.Fatalf("error does not wrap an error of the expected type: %q !> %+T", err.Error(), target)
    90  	}
    91  }
    92  
    93  // AssertErrorIs checks that err wraps the given error
    94  func AssertErrorIs(t *testing.T, err error, target error) {
    95  	t.Helper()
    96  
    97  	if err == nil {
    98  		t.Fatal("err was unexpectedly nil and should not have been")
    99  	}
   100  
   101  	if !errors.Is(err, target) {
   102  		t.Fatalf("error does not wrap expected error: %q !> %q", err.Error(), target.Error())
   103  	}
   104  }
   105  
   106  // AssertEquals uses the equality operator (==) to measure one and two
   107  func AssertEquals(t *testing.T, one interface{}, two interface{}) {
   108  	t.Helper()
   109  	if reflect.TypeOf(one) != reflect.TypeOf(two) {
   110  		t.Fatalf("cannot test equality of different types: %T != %T", one, two)
   111  	}
   112  	if one != two {
   113  		t.Fatalf("%#v != %#v", one, two)
   114  	}
   115  }
   116  
   117  // AssertDeepEquals uses the reflect.DeepEqual method to measure one and two
   118  func AssertDeepEquals(t *testing.T, one interface{}, two interface{}) {
   119  	t.Helper()
   120  	if !reflect.DeepEqual(one, two) {
   121  		t.Fatalf("[%#v] !(deep)= [%#v]", one, two)
   122  	}
   123  }
   124  
   125  // AssertMarshaledEquals marshals one and two to JSON, and then uses
   126  // the equality operator to measure them
   127  func AssertMarshaledEquals(t *testing.T, one interface{}, two interface{}) {
   128  	t.Helper()
   129  	oneJSON, err := json.Marshal(one)
   130  	AssertNotError(t, err, "Could not marshal 1st argument")
   131  	twoJSON, err := json.Marshal(two)
   132  	AssertNotError(t, err, "Could not marshal 2nd argument")
   133  
   134  	if !bytes.Equal(oneJSON, twoJSON) {
   135  		t.Fatalf("[%s] !(json)= [%s]", oneJSON, twoJSON)
   136  	}
   137  }
   138  
   139  // AssertUnmarshaledEquals unmarshals two JSON strings (got and expected) to
   140  // a map[string]interface{} and then uses reflect.DeepEqual to check they are
   141  // the same
   142  func AssertUnmarshaledEquals(t *testing.T, got, expected string) {
   143  	t.Helper()
   144  	var gotMap, expectedMap map[string]interface{}
   145  	err := json.Unmarshal([]byte(got), &gotMap)
   146  	AssertNotError(t, err, "Could not unmarshal 'got'")
   147  	err = json.Unmarshal([]byte(expected), &expectedMap)
   148  	AssertNotError(t, err, "Could not unmarshal 'expected'")
   149  	if len(gotMap) != len(expectedMap) {
   150  		t.Errorf("Expected had %d keys, got had %d", len(gotMap), len(expectedMap))
   151  	}
   152  	for k, v := range expectedMap {
   153  		if !reflect.DeepEqual(v, gotMap[k]) {
   154  			t.Errorf("Field %q: Expected \"%v\", got \"%v\"", k, v, gotMap[k])
   155  		}
   156  	}
   157  }
   158  
   159  // AssertNotEquals uses the equality operator to measure that one and two
   160  // are different
   161  func AssertNotEquals(t *testing.T, one interface{}, two interface{}) {
   162  	t.Helper()
   163  	if one == two {
   164  		t.Fatalf("%#v == %#v", one, two)
   165  	}
   166  }
   167  
   168  // AssertByteEquals uses bytes.Equal to measure one and two for equality.
   169  func AssertByteEquals(t *testing.T, one []byte, two []byte) {
   170  	t.Helper()
   171  	if !bytes.Equal(one, two) {
   172  		t.Fatalf("Byte [%s] != [%s]",
   173  			base64.StdEncoding.EncodeToString(one),
   174  			base64.StdEncoding.EncodeToString(two))
   175  	}
   176  }
   177  
   178  // AssertContains determines whether needle can be found in haystack
   179  func AssertContains(t *testing.T, haystack string, needle string) {
   180  	t.Helper()
   181  	if !strings.Contains(haystack, needle) {
   182  		t.Fatalf("String [%s] does not contain [%s]", haystack, needle)
   183  	}
   184  }
   185  
   186  // AssertNotContains determines if needle is not found in haystack
   187  func AssertNotContains(t *testing.T, haystack string, needle string) {
   188  	t.Helper()
   189  	if strings.Contains(haystack, needle) {
   190  		t.Fatalf("String [%s] contains [%s]", haystack, needle)
   191  	}
   192  }
   193  
   194  // AssertSliceContains determines if needle can be found in haystack
   195  func AssertSliceContains[T comparable](t *testing.T, haystack []T, needle T) {
   196  	t.Helper()
   197  	for _, item := range haystack {
   198  		if item == needle {
   199  			return
   200  		}
   201  	}
   202  	t.Fatalf("Slice %v does not contain %v", haystack, needle)
   203  }
   204  
   205  // AssertMetricWithLabelsEquals determines whether the value held by a prometheus Collector
   206  // (e.g. Gauge, Counter, CounterVec, etc) is equal to the expected float64.
   207  // In order to make useful assertions about just a subset of labels (e.g. for a
   208  // CounterVec with fields "host" and "valid", being able to assert that two
   209  // "valid": "true" increments occurred, without caring which host was tagged in
   210  // each), takes a set of labels and ignores any metrics which have different
   211  // label values.
   212  // Only works for simple metrics (Counters and Gauges), or for the *count*
   213  // (not value) of data points in a Histogram.
   214  func AssertMetricWithLabelsEquals(t *testing.T, c prometheus.Collector, l prometheus.Labels, expected float64) {
   215  	t.Helper()
   216  	ch := make(chan prometheus.Metric)
   217  	done := make(chan struct{})
   218  	go func() {
   219  		c.Collect(ch)
   220  		close(done)
   221  	}()
   222  	var total float64
   223  	timeout := time.After(time.Second)
   224  loop:
   225  	for {
   226  	metric:
   227  		select {
   228  		case <-timeout:
   229  			t.Fatal("timed out collecting metrics")
   230  		case <-done:
   231  			break loop
   232  		case m := <-ch:
   233  			var iom io_prometheus_client.Metric
   234  			_ = m.Write(&iom)
   235  			for _, lp := range iom.Label {
   236  				// If any of the labels on this metric have the same name as but
   237  				// different value than a label in `l`, skip this metric.
   238  				val, ok := l[lp.GetName()]
   239  				if ok && lp.GetValue() != val {
   240  					break metric
   241  				}
   242  			}
   243  			// Exactly one of the Counter, Gauge, or Histogram values will be set by
   244  			// the .Write() operation, so add them all because the others will be 0.
   245  			total += iom.Counter.GetValue()
   246  			total += iom.Gauge.GetValue()
   247  			total += float64(iom.Histogram.GetSampleCount())
   248  		}
   249  	}
   250  	AssertEquals(t, total, expected)
   251  }
   252  
   253  // AssertImplementsGRPCServer guarantees that impl, which must be a pointer to one of our
   254  // gRPC service implementation types, implements all of the methods expected by
   255  // unimpl, which must be the auto-generated gRPC type which is embedded by impl.
   256  // This function incidentally also guarantees that impl does not have any
   257  // methods with non-pointer receivers.
   258  func AssertImplementsGRPCServer(t *testing.T, impl any, unimpl any) {
   259  	// This type is auto-generated by grpc-go, and has methods which implement
   260  	// the auto-generated gRPC interface. These methods all have value receivers,
   261  	// not pointer receivers, so we're not using a pointer type here.
   262  	unimplType := reflect.TypeOf(unimpl)
   263  
   264  	// This is the type we use to manually implement the same auto-generated gRPC
   265  	// interface. We implement all of our methods with pointer receivers. This
   266  	// type is required to embed the auto-generated "unimplemented" type above,
   267  	// so it inherits all of the methods implemented by that type as well. But
   268  	// we can use the fact that we use pointer receivers while the auto-generated
   269  	// type uses value receivers to our advantage.
   270  	implType := reflect.TypeOf(impl).Elem()
   271  
   272  	// Iterate over all of the methods which are provided by a *non-pointer*
   273  	// receiver of our type. These will be only the methods which we *don't*
   274  	// manually implement ourselves, because when we implement the method ourself
   275  	// we use a pointer receiver. So this loop will only iterate over those
   276  	// methods which "fall through" to the embedded "unimplemented" type. Ideally,
   277  	// this loop executes zero times. If there are any methods at all on the
   278  	// non-pointer receiver, something has gone wrong.
   279  	for i := 0; i < implType.NumMethod(); i++ {
   280  		method := implType.Method(i)
   281  		_, ok := unimplType.MethodByName(method.Name)
   282  		if ok {
   283  			// If the lookup worked, then we know this is a method which we were
   284  			// supposed to implement, but didn't. Oops.
   285  			t.Errorf("%s does not implement method %s", implType.Name(), method.Name)
   286  		} else {
   287  			// If the lookup failed, then we have accidentally implemented some other
   288  			// method with a non-pointer receiver. We probably didn't mean to do that.
   289  			t.Errorf("%s.%s has non-pointer receiver", implType.Name(), method.Name)
   290  		}
   291  	}
   292  }
   293  

View as plain text