...

Source file src/github.com/in-toto/in-toto-golang/internal/test/workload_api.go

Documentation: github.com/in-toto/in-toto-golang/internal/test

     1  package test
     2  
     3  import (
     4  	"context"
     5  	"crypto/x509"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"sync"
    10  	"testing"
    11  
    12  	"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
    13  	"github.com/spiffe/go-spiffe/v2/proto/spiffe/workload"
    14  	"github.com/spiffe/go-spiffe/v2/svid/x509svid"
    15  	"github.com/stretchr/testify/require"
    16  	"google.golang.org/grpc"
    17  	"google.golang.org/grpc/codes"
    18  	"google.golang.org/grpc/metadata"
    19  	"google.golang.org/grpc/status"
    20  )
    21  
    22  var noIdentityError = status.Error(codes.PermissionDenied, "no identity issued")
    23  
    24  type WorkloadAPI struct {
    25  	tb        testing.TB
    26  	wg        sync.WaitGroup
    27  	addr      string
    28  	server    *grpc.Server
    29  	mu        sync.Mutex
    30  	x509Resp  *workload.X509SVIDResponse
    31  	x509Chans map[chan *workload.X509SVIDResponse]struct{}
    32  }
    33  
    34  func NewWorkloadAPI(tb testing.TB) *WorkloadAPI {
    35  	w := &WorkloadAPI{
    36  		x509Chans: make(map[chan *workload.X509SVIDResponse]struct{}),
    37  	}
    38  
    39  	listener, err := net.Listen("tcp", "localhost:0")
    40  	require.NoError(tb, err)
    41  
    42  	server := grpc.NewServer()
    43  	workload.RegisterSpiffeWorkloadAPIServer(server, &workloadAPIWrapper{w: w})
    44  
    45  	w.wg.Add(1)
    46  	go func() {
    47  		defer w.wg.Done()
    48  		_ = server.Serve(listener)
    49  	}()
    50  
    51  	w.addr = fmt.Sprintf("%s://%s", listener.Addr().Network(), listener.Addr().String())
    52  	tb.Logf("WorkloadAPI address: %s", w.addr)
    53  	w.server = server
    54  	return w
    55  }
    56  
    57  func (w *WorkloadAPI) Stop() {
    58  	w.server.Stop()
    59  	w.wg.Wait()
    60  }
    61  
    62  func (w *WorkloadAPI) Addr() string {
    63  	return w.addr
    64  }
    65  
    66  func (w *WorkloadAPI) SetX509SVIDResponse(r *X509SVIDResponse) {
    67  	var resp *workload.X509SVIDResponse
    68  	if r != nil {
    69  		resp = r.ToProto(w.tb)
    70  	}
    71  
    72  	w.mu.Lock()
    73  	defer w.mu.Unlock()
    74  	w.x509Resp = resp
    75  
    76  	for ch := range w.x509Chans {
    77  		select {
    78  		case ch <- resp:
    79  		default:
    80  			<-ch
    81  			ch <- resp
    82  		}
    83  	}
    84  }
    85  
    86  func concatRawCertsFromCerts(certs []*x509.Certificate) []byte {
    87  	var rawCerts []byte
    88  	for _, cert := range certs {
    89  		rawCerts = append(rawCerts, cert.Raw...)
    90  	}
    91  	return rawCerts
    92  }
    93  
    94  func (r *X509SVIDResponse) ToProto(tb testing.TB) *workload.X509SVIDResponse {
    95  	var bundle []byte
    96  	if r.Bundle != nil {
    97  		bundle = concatRawCertsFromCerts(r.Bundle.X509Authorities())
    98  	}
    99  
   100  	pb := &workload.X509SVIDResponse{
   101  		FederatedBundles: make(map[string][]byte),
   102  	}
   103  	for _, svid := range r.SVIDs {
   104  		var keyDER []byte
   105  		if svid.PrivateKey != nil {
   106  			var err error
   107  			keyDER, err = x509.MarshalPKCS8PrivateKey(svid.PrivateKey)
   108  			require.NoError(tb, err)
   109  		}
   110  		pb.Svids = append(pb.Svids, &workload.X509SVID{
   111  			SpiffeId:    svid.ID.String(),
   112  			X509Svid:    concatRawCertsFromCerts(svid.Certificates),
   113  			X509SvidKey: keyDER,
   114  			Bundle:      bundle,
   115  		})
   116  	}
   117  	for _, v := range r.FederatedBundles {
   118  		pb.FederatedBundles[v.TrustDomain().IDString()] = concatRawCertsFromCerts(v.X509Authorities())
   119  	}
   120  
   121  	return pb
   122  }
   123  
   124  type workloadAPIWrapper struct {
   125  	workload.UnimplementedSpiffeWorkloadAPIServer
   126  	w *WorkloadAPI
   127  }
   128  
   129  func (w *workloadAPIWrapper) FetchX509SVID(req *workload.X509SVIDRequest, stream workload.SpiffeWorkloadAPI_FetchX509SVIDServer) error {
   130  	return w.w.fetchX509SVID(req, stream)
   131  }
   132  
   133  type X509SVIDResponse struct {
   134  	SVIDs            []*x509svid.SVID
   135  	Bundle           *x509bundle.Bundle
   136  	FederatedBundles []*x509bundle.Bundle
   137  }
   138  
   139  func (w *WorkloadAPI) fetchX509SVID(_ *workload.X509SVIDRequest, stream workload.SpiffeWorkloadAPI_FetchX509SVIDServer) error {
   140  	if err := checkHeader(stream.Context()); err != nil {
   141  		return err
   142  	}
   143  	ch := make(chan *workload.X509SVIDResponse, 1)
   144  	w.mu.Lock()
   145  	w.x509Chans[ch] = struct{}{}
   146  	resp := w.x509Resp
   147  	w.mu.Unlock()
   148  
   149  	defer func() {
   150  		w.mu.Lock()
   151  		delete(w.x509Chans, ch)
   152  		w.mu.Unlock()
   153  	}()
   154  
   155  	sendResp := func(resp *workload.X509SVIDResponse) error {
   156  		if resp == nil {
   157  			return noIdentityError
   158  		}
   159  		return stream.Send(resp)
   160  	}
   161  
   162  	if err := sendResp(resp); err != nil {
   163  		return err
   164  	}
   165  	for {
   166  		select {
   167  		case resp := <-ch:
   168  			if err := sendResp(resp); err != nil {
   169  				return err
   170  			}
   171  		case <-stream.Context().Done():
   172  			return stream.Context().Err()
   173  		}
   174  	}
   175  }
   176  
   177  func checkHeader(ctx context.Context) error {
   178  	return checkMetadata(ctx, "workload.spiffe.io", "true")
   179  }
   180  
   181  func checkMetadata(ctx context.Context, key, value string) error {
   182  	md, ok := metadata.FromIncomingContext(ctx)
   183  	if !ok {
   184  		return errors.New("request does not contain metadata")
   185  	}
   186  	values := md.Get(key)
   187  	if len(value) == 0 {
   188  		return fmt.Errorf("request metadata does not contain %q value", key)
   189  	}
   190  	if values[0] != value {
   191  		return fmt.Errorf("request metadata %q value is %q; expected %q", key, values[0], value)
   192  	}
   193  	return nil
   194  }
   195  

View as plain text