...

Source file src/cloud.google.com/go/rpcreplay/rpcreplay.go

Documentation: cloud.google.com/go/rpcreplay

     1  // Copyright 2017 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rpcreplay
    16  
    17  import (
    18  	"bufio"
    19  	"bytes"
    20  	"context"
    21  	"encoding/binary"
    22  	"errors"
    23  	"fmt"
    24  	"io"
    25  	"log"
    26  	"net"
    27  	"os"
    28  	"sync"
    29  
    30  	pb "cloud.google.com/go/rpcreplay/proto/rpcreplay"
    31  	spb "google.golang.org/genproto/googleapis/rpc/status"
    32  	"google.golang.org/grpc"
    33  	"google.golang.org/grpc/metadata"
    34  	"google.golang.org/grpc/status"
    35  	"google.golang.org/protobuf/encoding/prototext"
    36  	"google.golang.org/protobuf/proto"
    37  	"google.golang.org/protobuf/types/known/anypb"
    38  )
    39  
// A Recorder records RPCs for later playback.
//
// Obtain one from NewRecorder or NewRecorderWriter, pass the result of
// DialOptions to grpc.Dial, and call Close when recording is finished.
type Recorder struct {
	mu   sync.Mutex    // guards w, next and err
	w    *bufio.Writer // buffered output; receives the header and every entry
	f    *os.File      // set only by NewRecorder; closed by Close
	next int           // index assigned to the next entry written (starts at 1)
	err  error         // sticky recording error; once set, writeEntry refuses to write
	// BeforeFunc defines a function that can inspect and modify requests and responses
	// written to the replay file. It does not modify messages sent to the service.
	// It is run once before a request is written to the replay file, and once before a response
	// is written to the replay file.
	// The function is called with the method name and the message that triggered the callback.
	// If the function returns an error, the error will be returned to the client.
	// This is only executed for unary RPCs; streaming RPCs are not supported.
	BeforeFunc func(string, proto.Message) error
}
    56  
    57  // NewRecorder creates a recorder that writes to filename. The file will
    58  // also store the initial bytes for retrieval during replay.
    59  //
    60  // You must call Close on the Recorder to ensure that all data is written.
    61  func NewRecorder(filename string, initial []byte) (*Recorder, error) {
    62  	f, err := os.Create(filename)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	rec, err := NewRecorderWriter(f, initial)
    67  	if err != nil {
    68  		_ = f.Close()
    69  		return nil, err
    70  	}
    71  	rec.f = f
    72  	return rec, nil
    73  }
    74  
    75  // NewRecorderWriter creates a recorder that writes to w. The initial
    76  // bytes will also be written to w for retrieval during replay.
    77  //
    78  // You must call Close on the Recorder to ensure that all data is written.
    79  func NewRecorderWriter(w io.Writer, initial []byte) (*Recorder, error) {
    80  	bw := bufio.NewWriter(w)
    81  	if err := writeHeader(bw, initial); err != nil {
    82  		return nil, err
    83  	}
    84  	return &Recorder{w: bw, next: 1}, nil
    85  }
    86  
    87  // DialOptions returns the options that must be passed to grpc.Dial
    88  // to enable recording.
    89  func (r *Recorder) DialOptions() []grpc.DialOption {
    90  	return []grpc.DialOption{
    91  		grpc.WithUnaryInterceptor(r.interceptUnary),
    92  		grpc.WithStreamInterceptor(r.interceptStream),
    93  	}
    94  }
    95  
    96  // Close saves any unwritten information.
    97  func (r *Recorder) Close() error {
    98  	r.mu.Lock()
    99  	defer r.mu.Unlock()
   100  	if r.err != nil {
   101  		return r.err
   102  	}
   103  	err := r.w.Flush()
   104  	if r.f != nil {
   105  		if err2 := r.f.Close(); err == nil {
   106  			err = err2
   107  		}
   108  	}
   109  	return err
   110  }
   111  
   112  // Intercepts all unary (non-stream) RPCs.
   113  func (r *Recorder) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   114  	ereq := &entry{
   115  		kind:   pb.Entry_REQUEST,
   116  		method: method,
   117  		msg:    message{msg: proto.Clone(req.(proto.Message))},
   118  	}
   119  
   120  	if r.BeforeFunc != nil {
   121  		if err := r.BeforeFunc(method, ereq.msg.msg); err != nil {
   122  			return err
   123  		}
   124  	}
   125  	refIndex, err := r.writeEntry(ereq)
   126  	if err != nil {
   127  		return err
   128  	}
   129  	ierr := invoker(ctx, method, req, res, cc, opts...)
   130  	eres := &entry{
   131  		kind:     pb.Entry_RESPONSE,
   132  		refIndex: refIndex,
   133  	}
   134  	// If the error is not a gRPC status, then something more
   135  	// serious is wrong. More significantly, we have no way
   136  	// of serializing an arbitrary error. So just return it
   137  	// without recording the response.
   138  	if _, ok := status.FromError(ierr); !ok {
   139  		r.mu.Lock()
   140  		r.err = fmt.Errorf("saw non-status error in %s response: %w (%T)", method, ierr, ierr)
   141  		r.mu.Unlock()
   142  		return ierr
   143  	}
   144  	eres.msg.set(proto.Clone(res.(proto.Message)), ierr)
   145  	if r.BeforeFunc != nil {
   146  		if err := r.BeforeFunc(method, eres.msg.msg); err != nil {
   147  			return err
   148  		}
   149  	}
   150  	if _, err := r.writeEntry(eres); err != nil {
   151  		return err
   152  	}
   153  	return ierr
   154  }
   155  
   156  func (r *Recorder) writeEntry(e *entry) (int, error) {
   157  	r.mu.Lock()
   158  	defer r.mu.Unlock()
   159  	if r.err != nil {
   160  		return 0, r.err
   161  	}
   162  	err := writeEntry(r.w, e)
   163  	if err != nil {
   164  		r.err = err
   165  		return 0, err
   166  	}
   167  	n := r.next
   168  	r.next++
   169  	return n, nil
   170  }
   171  
   172  func (r *Recorder) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
   173  	cstream, serr := streamer(ctx, desc, cc, method, opts...)
   174  	e := &entry{
   175  		kind:   pb.Entry_CREATE_STREAM,
   176  		method: method,
   177  	}
   178  	e.msg.set(nil, serr)
   179  	refIndex, err := r.writeEntry(e)
   180  	if err != nil {
   181  		return nil, err
   182  	}
   183  	return &recClientStream{
   184  		ctx:      ctx,
   185  		rec:      r,
   186  		cstream:  cstream,
   187  		refIndex: refIndex,
   188  	}, serr
   189  }
   190  
   191  // A recClientStream implements the gprc.ClientStream interface.
   192  // It behaves exactly like the default ClientStream, but also
   193  // records all messages sent and received.
   194  type recClientStream struct {
   195  	ctx      context.Context
   196  	rec      *Recorder
   197  	cstream  grpc.ClientStream
   198  	refIndex int
   199  }
   200  
   201  func (rcs *recClientStream) Context() context.Context { return rcs.ctx }
   202  
   203  func (rcs *recClientStream) SendMsg(m interface{}) error {
   204  	serr := rcs.cstream.SendMsg(m)
   205  	e := &entry{
   206  		kind:     pb.Entry_SEND,
   207  		refIndex: rcs.refIndex,
   208  	}
   209  	e.msg.set(m, serr)
   210  	if _, err := rcs.rec.writeEntry(e); err != nil {
   211  		return err
   212  	}
   213  	return serr
   214  }
   215  
   216  func (rcs *recClientStream) RecvMsg(m interface{}) error {
   217  	serr := rcs.cstream.RecvMsg(m)
   218  	e := &entry{
   219  		kind:     pb.Entry_RECV,
   220  		refIndex: rcs.refIndex,
   221  	}
   222  	e.msg.set(m, serr)
   223  	if _, err := rcs.rec.writeEntry(e); err != nil {
   224  		return err
   225  	}
   226  	return serr
   227  }
   228  
   229  func (rcs *recClientStream) Header() (metadata.MD, error) {
   230  	// TODO(jba): record.
   231  	return rcs.cstream.Header()
   232  }
   233  
   234  func (rcs *recClientStream) Trailer() metadata.MD {
   235  	// TODO(jba): record.
   236  	return rcs.cstream.Trailer()
   237  }
   238  
   239  func (rcs *recClientStream) CloseSend() error {
   240  	// TODO(jba): record.
   241  	return rcs.cstream.CloseSend()
   242  }
   243  
// A Replayer replays a set of RPCs saved by a Recorder.
//
// The whole recording is read when the Replayer is created. Each recorded
// unary call and stream is handed out at most once.
type Replayer struct {
	initial []byte                                // initial state
	log     func(format string, v ...interface{}) // for debugging

	mu      sync.Mutex // guards calls and streams
	calls   []*call    // recorded request/response pairs; used ones are set to nil
	streams []*stream  // recorded streams; used ones are set to nil
	// BeforeFunc defines a function that can inspect and modify requests before they
	// are matched for responses from the replay file.
	// The function is called with the method name and the message that triggered the callback.
	// If the function returns an error, the error will be returned to the client.
	// This is only executed for unary RPCs; streaming RPCs are not supported.
	BeforeFunc func(string, proto.Message) error
}
   259  
// A call represents a unary RPC, with a request and response (or error).
// Calls are built by Replayer.read, which pairs each REQUEST entry with the
// RESPONSE entry that refers to it.
type call struct {
	method   string
	request  proto.Message
	response message // recorded response message, or the recorded error
}

// A stream represents a gRPC stream, with an initial create-stream call, followed by
// zero or more sends and/or receives.
type stream struct {
	method      string
	createIndex int       // entry index of the CREATE_STREAM record; reported in replay errors
	createErr   error     // error from create call
	sends       []message // recorded sends, consumed front to back by repClientStream.SendMsg
	recvs       []message // recorded receives, consumed front to back by repClientStream.RecvMsg
}
   276  
   277  // NewReplayer creates a Replayer that reads from filename.
   278  func NewReplayer(filename string) (*Replayer, error) {
   279  	f, err := os.Open(filename)
   280  	if err != nil {
   281  		return nil, err
   282  	}
   283  	defer f.Close()
   284  	return NewReplayerReader(f)
   285  }
   286  
   287  // NewReplayerReader creates a Replayer that reads from r.
   288  func NewReplayerReader(r io.Reader) (*Replayer, error) {
   289  	rep := &Replayer{
   290  		log: func(string, ...interface{}) {},
   291  	}
   292  	if err := rep.read(r); err != nil {
   293  		return nil, err
   294  	}
   295  	return rep, nil
   296  }
   297  
   298  // read reads the stream of recorded entries.
   299  // It matches requests with responses, with each pair grouped
   300  // into a call struct.
   301  func (rep *Replayer) read(r io.Reader) error {
   302  	r = bufio.NewReader(r)
   303  	bytes, err := readHeader(r)
   304  	if err != nil {
   305  		return err
   306  	}
   307  	rep.initial = bytes
   308  
   309  	callsByIndex := map[int]*call{}
   310  	streamsByIndex := map[int]*stream{}
   311  	for i := 1; ; i++ {
   312  		e, err := readEntry(r)
   313  		if err != nil {
   314  			return err
   315  		}
   316  		if e == nil {
   317  			break
   318  		}
   319  		switch e.kind {
   320  		case pb.Entry_REQUEST:
   321  			callsByIndex[i] = &call{
   322  				method:  e.method,
   323  				request: e.msg.msg,
   324  			}
   325  
   326  		case pb.Entry_RESPONSE:
   327  			call := callsByIndex[e.refIndex]
   328  			if call == nil {
   329  				return fmt.Errorf("replayer: no request for response #%d", i)
   330  			}
   331  			delete(callsByIndex, e.refIndex)
   332  			call.response = e.msg
   333  			rep.calls = append(rep.calls, call)
   334  
   335  		case pb.Entry_CREATE_STREAM:
   336  			s := &stream{method: e.method, createIndex: i}
   337  			s.createErr = e.msg.err
   338  			streamsByIndex[i] = s
   339  			rep.streams = append(rep.streams, s)
   340  
   341  		case pb.Entry_SEND:
   342  			s := streamsByIndex[e.refIndex]
   343  			if s == nil {
   344  				return fmt.Errorf("replayer: no stream for send #%d", i)
   345  			}
   346  			s.sends = append(s.sends, e.msg)
   347  
   348  		case pb.Entry_RECV:
   349  			s := streamsByIndex[e.refIndex]
   350  			if s == nil {
   351  				return fmt.Errorf("replayer: no stream for recv #%d", i)
   352  			}
   353  			s.recvs = append(s.recvs, e.msg)
   354  
   355  		default:
   356  			return fmt.Errorf("replayer: unknown kind %s", e.kind)
   357  		}
   358  	}
   359  	if len(callsByIndex) > 0 {
   360  		return fmt.Errorf("replayer: %d unmatched requests", len(callsByIndex))
   361  	}
   362  	return nil
   363  }
   364  
   365  // DialOptions returns the options that must be passed to grpc.Dial
   366  // to enable replaying.
   367  func (rep *Replayer) DialOptions() []grpc.DialOption {
   368  	return []grpc.DialOption{
   369  		// On replay, we make no RPCs, which means the connection may be closed
   370  		// before the normally async Dial completes. Making the Dial synchronous
   371  		// fixes that.
   372  		grpc.WithBlock(),
   373  		grpc.WithUnaryInterceptor(rep.interceptUnary),
   374  		grpc.WithStreamInterceptor(rep.interceptStream),
   375  	}
   376  }
   377  
   378  // Connection returns a fake gRPC connection suitable for replaying.
   379  func (rep *Replayer) Connection() (*grpc.ClientConn, error) {
   380  	// We don't need an actual connection, not even a loopback one.
   381  	// But we do need something to attach gRPC interceptors to.
   382  	// So we start a local server and connect to it, then close it down.
   383  	srv := grpc.NewServer()
   384  	l, err := net.Listen("tcp", "localhost:0")
   385  	if err != nil {
   386  		return nil, err
   387  	}
   388  	go func() {
   389  		if err := srv.Serve(l); err != nil {
   390  			panic(err) // we should never get an error because we just connect and stop
   391  		}
   392  	}()
   393  	conn, err := grpc.Dial(l.Addr().String(),
   394  		append([]grpc.DialOption{grpc.WithInsecure()}, rep.DialOptions()...)...)
   395  	if err != nil {
   396  		return nil, err
   397  	}
   398  	conn.Close()
   399  	srv.Stop()
   400  	return conn, nil
   401  }
   402  
   403  // Initial returns the initial state saved by the Recorder.
   404  func (rep *Replayer) Initial() []byte { return rep.initial }
   405  
   406  // SetLogFunc sets a function to be used for debug logging. The function
   407  // should be safe to be called from multiple goroutines.
   408  func (rep *Replayer) SetLogFunc(f func(format string, v ...interface{})) {
   409  	rep.log = f
   410  }
   411  
   412  // Close closes the Replayer.
   413  func (rep *Replayer) Close() error {
   414  	return nil
   415  }
   416  
   417  func (rep *Replayer) interceptUnary(_ context.Context, method string, req, res interface{}, _ *grpc.ClientConn, _ grpc.UnaryInvoker, _ ...grpc.CallOption) error {
   418  	mreq := req.(proto.Message)
   419  	if rep.BeforeFunc != nil {
   420  		if err := rep.BeforeFunc(method, mreq); err != nil {
   421  			return err
   422  		}
   423  	}
   424  	rep.log("request %s (%s)", method, req)
   425  	call := rep.extractCall(method, mreq)
   426  	if call == nil {
   427  		return fmt.Errorf("replayer: request not found: %s", mreq)
   428  	}
   429  	rep.log("returning %v", call.response)
   430  	if call.response.err != nil {
   431  		return call.response.err
   432  	}
   433  	proto.Merge(res.(proto.Message), call.response.msg) // copy msg into res
   434  	return nil
   435  }
   436  
   437  func (rep *Replayer) interceptStream(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, method string, _ grpc.Streamer, _ ...grpc.CallOption) (grpc.ClientStream, error) {
   438  	rep.log("create-stream %s", method)
   439  	return &repClientStream{ctx: ctx, rep: rep, method: method}, nil
   440  }
   441  
   442  type repClientStream struct {
   443  	ctx    context.Context
   444  	rep    *Replayer
   445  	method string
   446  	str    *stream
   447  }
   448  
   449  func (rcs *repClientStream) Context() context.Context { return rcs.ctx }
   450  
   451  func (rcs *repClientStream) SendMsg(req interface{}) error {
   452  	if rcs.str == nil {
   453  		if err := rcs.setStream(rcs.method, req.(proto.Message)); err != nil {
   454  			return err
   455  		}
   456  	}
   457  	if len(rcs.str.sends) == 0 {
   458  		return fmt.Errorf("replayer: no more sends for stream %s, created at index %d",
   459  			rcs.str.method, rcs.str.createIndex)
   460  	}
   461  	// TODO(jba): Do not assume that the sends happen in the same order on replay.
   462  	msg := rcs.str.sends[0]
   463  	rcs.str.sends = rcs.str.sends[1:]
   464  	return msg.err
   465  }
   466  
   467  func (rcs *repClientStream) setStream(method string, req proto.Message) error {
   468  	str := rcs.rep.extractStream(method, req)
   469  	if str == nil {
   470  		return fmt.Errorf("replayer: stream not found for method %s and request %v", method, req)
   471  	}
   472  	if str.createErr != nil {
   473  		return str.createErr
   474  	}
   475  	rcs.str = str
   476  	return nil
   477  }
   478  
   479  func (rcs *repClientStream) RecvMsg(m interface{}) error {
   480  	if rcs.str == nil {
   481  		// Receive before send; fall back to matching stream by method only.
   482  		if err := rcs.setStream(rcs.method, nil); err != nil {
   483  			return err
   484  		}
   485  	}
   486  	if len(rcs.str.recvs) == 0 {
   487  		return fmt.Errorf("replayer: no more receives for stream %s, created at index %d",
   488  			rcs.str.method, rcs.str.createIndex)
   489  	}
   490  	msg := rcs.str.recvs[0]
   491  	rcs.str.recvs = rcs.str.recvs[1:]
   492  	if msg.err != nil {
   493  		return msg.err
   494  	}
   495  	proto.Merge(m.(proto.Message), msg.msg) // copy msg into m
   496  	return nil
   497  }
   498  
   499  func (rcs *repClientStream) Header() (metadata.MD, error) {
   500  	log.Printf("replay: stream metadata not supported")
   501  	return nil, nil
   502  }
   503  
   504  func (rcs *repClientStream) Trailer() metadata.MD {
   505  	log.Printf("replay: stream metadata not supported")
   506  	return nil
   507  }
   508  
   509  func (rcs *repClientStream) CloseSend() error {
   510  	return nil
   511  }
   512  
   513  // extractCall finds the first call in the list with the same method
   514  // and request. It returns nil if it can't find such a call.
   515  func (rep *Replayer) extractCall(method string, req proto.Message) *call {
   516  	rep.mu.Lock()
   517  	defer rep.mu.Unlock()
   518  	for i, call := range rep.calls {
   519  		if call == nil {
   520  			continue
   521  		}
   522  		if method == call.method && proto.Equal(req, call.request) {
   523  			rep.calls[i] = nil // nil out this call so we don't reuse it
   524  			return call
   525  		}
   526  	}
   527  	return nil
   528  }
   529  
   530  // extractStream find the first stream in the list with the same method and the same
   531  // first request sent. If req is nil, that means a receive occurred before a send, so
   532  // it matches only on method.
   533  func (rep *Replayer) extractStream(method string, req proto.Message) *stream {
   534  	rep.mu.Lock()
   535  	defer rep.mu.Unlock()
   536  	for i, stream := range rep.streams {
   537  		// Skip stream if it is nil (already extracted) or its method doesn't match.
   538  		if stream == nil || stream.method != method {
   539  			continue
   540  		}
   541  		// If there is a first request, skip stream if it has no requests or its first
   542  		// request doesn't match.
   543  		if req != nil && len(stream.sends) > 0 && !proto.Equal(req, stream.sends[0].msg) {
   544  			continue
   545  		}
   546  		rep.streams[i] = nil // nil out this stream so we don't reuse it
   547  		return stream
   548  	}
   549  	return nil
   550  }
   551  
   552  // Fprint reads the entries from filename and writes them to w in human-readable form.
   553  // It is intended for debugging.
   554  func Fprint(w io.Writer, filename string) error {
   555  	f, err := os.Open(filename)
   556  	if err != nil {
   557  		return err
   558  	}
   559  	defer f.Close()
   560  	return FprintReader(w, f)
   561  }
   562  
   563  // FprintReader reads the entries from r and writes them to w in human-readable form.
   564  // It is intended for debugging.
   565  func FprintReader(w io.Writer, r io.Reader) error {
   566  	initial, err := readHeader(r)
   567  	if err != nil {
   568  		return err
   569  	}
   570  	fmt.Fprintf(w, "initial state: %q\n", string(initial))
   571  	for i := 1; ; i++ {
   572  		e, err := readEntry(r)
   573  		if err != nil {
   574  			return err
   575  		}
   576  		if e == nil {
   577  			return nil
   578  		}
   579  
   580  		fmt.Fprintf(w, "#%d: kind: %s, method: %s, ref index: %d", i, e.kind, e.method, e.refIndex)
   581  		switch {
   582  		case e.msg.msg != nil:
   583  			fmt.Fprintf(w, ", message:\n")
   584  			b, err := prototext.Marshal(e.msg.msg)
   585  			if err != nil {
   586  				return err
   587  			}
   588  			if _, err := io.Copy(w, bytes.NewReader(b)); err != nil {
   589  				return err
   590  			}
   591  		case e.msg.err != nil:
   592  			fmt.Fprintf(w, ", error: %v\n", e.msg.err)
   593  		default:
   594  			fmt.Fprintln(w)
   595  		}
   596  	}
   597  }
   598  
   599  // An entry holds one gRPC action (request, response, etc.).
   600  type entry struct {
   601  	kind     pb.Entry_Kind
   602  	method   string
   603  	msg      message
   604  	refIndex int // index of corresponding request or create-stream
   605  }
   606  
   607  func (e1 *entry) equal(e2 *entry) bool {
   608  	if e1 == nil && e2 == nil {
   609  		return true
   610  	}
   611  	if e1 == nil || e2 == nil {
   612  		return false
   613  	}
   614  	return e1.kind == e2.kind &&
   615  		e1.method == e2.method &&
   616  		proto.Equal(e1.msg.msg, e2.msg.msg) &&
   617  		errEqual(e1.msg.err, e2.msg.err) &&
   618  		e1.refIndex == e2.refIndex
   619  }
   620  
   621  func errEqual(e1, e2 error) bool {
   622  	if e1 == e2 {
   623  		return true
   624  	}
   625  	s1, ok1 := status.FromError(e1)
   626  	s2, ok2 := status.FromError(e2)
   627  	if !ok1 || !ok2 {
   628  		return false
   629  	}
   630  	return proto.Equal(s1.Proto(), s2.Proto())
   631  }
   632  
   633  // message holds either a single proto.Message or an error.
   634  type message struct {
   635  	msg proto.Message
   636  	err error
   637  }
   638  
   639  func (m *message) set(msg interface{}, err error) {
   640  	m.err = err
   641  	if err != io.EOF && msg != nil {
   642  		m.msg = msg.(proto.Message)
   643  	}
   644  }
   645  
// File format:
//   header
//   sequence of records, each holding one serialized pb.Entry
//
// Header format:
//   magic string
//   a record containing the bytes of the initial state
//
// A record is a little-endian uint32 length followed by that many bytes.

// magic is the string every replay file begins with; readHeader rejects
// input that does not start with it.
const magic = "RPCReplay"
   655  
   656  func writeHeader(w io.Writer, initial []byte) error {
   657  	if _, err := io.WriteString(w, magic); err != nil {
   658  		return err
   659  	}
   660  	return writeRecord(w, initial)
   661  }
   662  
   663  func readHeader(r io.Reader) ([]byte, error) {
   664  	var buf [len(magic)]byte
   665  	if _, err := io.ReadFull(r, buf[:]); err != nil {
   666  		if err == io.EOF {
   667  			err = errors.New("rpcreplay: empty replay file")
   668  		}
   669  		return nil, err
   670  	}
   671  	if string(buf[:]) != magic {
   672  		return nil, errors.New("rpcreplay: not a replay file (does not begin with magic string)")
   673  	}
   674  	bytes, err := readRecord(r)
   675  	if err == io.EOF {
   676  		err = errors.New("rpcreplay: missing initial state")
   677  	}
   678  	return bytes, err
   679  }
   680  
   681  func writeEntry(w io.Writer, e *entry) error {
   682  	var m proto.Message
   683  	if e.msg.err != nil && e.msg.err != io.EOF {
   684  		s, ok := status.FromError(e.msg.err)
   685  		if !ok {
   686  			return fmt.Errorf("rpcreplay: error %w is not a Status", e.msg.err)
   687  		}
   688  		m = s.Proto()
   689  	} else {
   690  		m = e.msg.msg
   691  	}
   692  	var a *anypb.Any
   693  	var err error
   694  	if m != nil {
   695  		a, err = anypb.New(m)
   696  		if err != nil {
   697  			return err
   698  		}
   699  	}
   700  	pe := &pb.Entry{
   701  		Kind:     e.kind,
   702  		Method:   e.method,
   703  		Message:  a,
   704  		IsError:  e.msg.err != nil,
   705  		RefIndex: int32(e.refIndex),
   706  	}
   707  	bytes, err := proto.Marshal(pe)
   708  	if err != nil {
   709  		return err
   710  	}
   711  	return writeRecord(w, bytes)
   712  }
   713  
   714  func readEntry(r io.Reader) (*entry, error) {
   715  	buf, err := readRecord(r)
   716  	if err == io.EOF {
   717  		return nil, nil
   718  	}
   719  	if err != nil {
   720  		return nil, err
   721  	}
   722  	var pe pb.Entry
   723  	if err := proto.Unmarshal(buf, &pe); err != nil {
   724  		return nil, err
   725  	}
   726  	var msg message
   727  	if pe.Message != nil {
   728  		if pe.IsError {
   729  			s := &spb.Status{}
   730  			err := anypb.UnmarshalTo(pe.Message, s, proto.UnmarshalOptions{AllowPartial: true, DiscardUnknown: true})
   731  			if err != nil {
   732  				return nil, err
   733  			}
   734  			msg.err = status.ErrorProto(s)
   735  		} else {
   736  			m, err := anypb.UnmarshalNew(pe.Message, proto.UnmarshalOptions{AllowPartial: true, DiscardUnknown: true})
   737  			if err != nil {
   738  				return nil, err
   739  			}
   740  			msg.msg = m
   741  		}
   742  	} else if pe.IsError {
   743  		msg.err = io.EOF
   744  	} else if pe.Kind != pb.Entry_CREATE_STREAM {
   745  		return nil, errors.New("rpcreplay: entry with nil message and false is_error")
   746  	}
   747  	return &entry{
   748  		kind:     pe.Kind,
   749  		method:   pe.Method,
   750  		msg:      msg,
   751  		refIndex: int(pe.RefIndex),
   752  	}, nil
   753  }
   754  
   755  // A record consists of an unsigned 32-bit little-endian length L followed by L
   756  // bytes.
   757  
   758  func writeRecord(w io.Writer, data []byte) error {
   759  	if err := binary.Write(w, binary.LittleEndian, uint32(len(data))); err != nil {
   760  		return err
   761  	}
   762  	_, err := w.Write(data)
   763  	return err
   764  }
   765  
   766  func readRecord(r io.Reader) ([]byte, error) {
   767  	var size uint32
   768  	if err := binary.Read(r, binary.LittleEndian, &size); err != nil {
   769  		return nil, err
   770  	}
   771  	buf := make([]byte, size)
   772  	if _, err := io.ReadFull(r, buf); err != nil {
   773  		return nil, err
   774  	}
   775  	return buf, nil
   776  }
   777  

View as plain text