...

Source file src/golang.org/x/net/quic/endpoint_test.go

Documentation: golang.org/x/net/quic

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  package quic
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"io"
    14  	"log/slog"
    15  	"net/netip"
    16  	"runtime"
    17  	"testing"
    18  	"time"
    19  
    20  	"golang.org/x/net/quic/qlog"
    21  )
    22  
    23  func TestConnect(t *testing.T) {
    24  	newLocalConnPair(t, &Config{}, &Config{})
    25  }
    26  
    27  func TestConnectDefaultTLSConfig(t *testing.T) {
    28  	serverConfig := newTestTLSConfigWithMoreDefaults(serverSide)
    29  	clientConfig := newTestTLSConfigWithMoreDefaults(clientSide)
    30  	newLocalConnPair(t, &Config{TLSConfig: serverConfig}, &Config{TLSConfig: clientConfig})
    31  }
    32  
    33  func TestStreamTransfer(t *testing.T) {
    34  	ctx := context.Background()
    35  	cli, srv := newLocalConnPair(t, &Config{}, &Config{})
    36  	data := makeTestData(1 << 20)
    37  
    38  	srvdone := make(chan struct{})
    39  	go func() {
    40  		defer close(srvdone)
    41  		s, err := srv.AcceptStream(ctx)
    42  		if err != nil {
    43  			t.Errorf("AcceptStream: %v", err)
    44  			return
    45  		}
    46  		b, err := io.ReadAll(s)
    47  		if err != nil {
    48  			t.Errorf("io.ReadAll(s): %v", err)
    49  			return
    50  		}
    51  		if !bytes.Equal(b, data) {
    52  			t.Errorf("read data mismatch (got %v bytes, want %v", len(b), len(data))
    53  		}
    54  		if err := s.Close(); err != nil {
    55  			t.Errorf("s.Close() = %v", err)
    56  		}
    57  	}()
    58  
    59  	s, err := cli.NewSendOnlyStream(ctx)
    60  	if err != nil {
    61  		t.Fatalf("NewStream: %v", err)
    62  	}
    63  	n, err := io.Copy(s, bytes.NewBuffer(data))
    64  	if n != int64(len(data)) || err != nil {
    65  		t.Fatalf("io.Copy(s, data) = %v, %v; want %v, nil", n, err, len(data))
    66  	}
    67  	if err := s.Close(); err != nil {
    68  		t.Fatalf("s.Close() = %v", err)
    69  	}
    70  }
    71  
    72  func newLocalConnPair(t testing.TB, conf1, conf2 *Config) (clientConn, serverConn *Conn) {
    73  	switch runtime.GOOS {
    74  	case "plan9":
    75  		t.Skipf("ReadMsgUDP not supported on %s", runtime.GOOS)
    76  	}
    77  	t.Helper()
    78  	ctx := context.Background()
    79  	e1 := newLocalEndpoint(t, serverSide, conf1)
    80  	e2 := newLocalEndpoint(t, clientSide, conf2)
    81  	conf2 = makeTestConfig(conf2, clientSide)
    82  	c2, err := e2.Dial(ctx, "udp", e1.LocalAddr().String(), conf2)
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  	c1, err := e1.Accept(ctx)
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  	return c2, c1
    91  }
    92  
    93  func newLocalEndpoint(t testing.TB, side connSide, conf *Config) *Endpoint {
    94  	t.Helper()
    95  	conf = makeTestConfig(conf, side)
    96  	e, err := Listen("udp", "127.0.0.1:0", conf)
    97  	if err != nil {
    98  		t.Fatal(err)
    99  	}
   100  	t.Cleanup(func() {
   101  		e.Close(canceledContext())
   102  	})
   103  	return e
   104  }
   105  
   106  func makeTestConfig(conf *Config, side connSide) *Config {
   107  	if conf == nil {
   108  		return nil
   109  	}
   110  	newConf := *conf
   111  	conf = &newConf
   112  	if conf.TLSConfig == nil {
   113  		conf.TLSConfig = newTestTLSConfig(side)
   114  	}
   115  	if conf.QLogLogger == nil {
   116  		conf.QLogLogger = slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
   117  			Level: QLogLevelFrame,
   118  			Dir:   *qlogdir,
   119  		}))
   120  	}
   121  	return conf
   122  }
   123  
// testEndpoint drives an Endpoint under test through a fake UDP connection
// (testEndpointUDPConn) and fake hooks (testEndpointHooks). A test can use it
// to inject datagrams, inspect the datagrams the endpoint sends, and control
// the endpoint's notion of the current time.
type testEndpoint struct {
	// t is the test that owns this endpoint; failures are reported on it.
	t                     *testing.T
	// e is the endpoint under test, created by newTestEndpoint.
	e                     *Endpoint
	// now is the fake current time returned by testEndpointHooks.timeNow.
	// advance and advanceTo move it forward.
	now                   time.Time
	// recvc carries datagrams from write to the fake UDP conn's Read loop,
	// which passes them to the endpoint. The fake conn's Close closes it.
	recvc                 chan *datagram
	// idlec is sent on by wait. The Read loop only receives from it while it
	// is not inside its callback, so a completed send means any datagram
	// previously passed to write has been handed off.
	idlec                 chan struct{}
	// conns holds a testConn for each Conn, recorded by testEndpointHooks.newConn.
	conns                 map[*Conn]*testConn
	// acceptQueue holds server connections returned, oldest first, by accept.
	// NOTE(review): nothing in this section appends to it; presumably it is
	// filled when a server-side testConn is created — confirm.
	acceptQueue           []*testConn
	// NOTE(review): configTransportParams and configTestConn are not used in
	// this section; presumably they customize each new testConn — confirm
	// against newTestConnForConn.
	configTransportParams []func(*transportParameters)
	configTestConn        []func(*testConn)
	// sentDatagrams holds copies of the datagrams the endpoint has written,
	// oldest first; read removes them from the front.
	sentDatagrams         [][]byte
	// NOTE(review): peerTLSConn is not used in this section; presumably it
	// holds the TLS state of the simulated peer — confirm.
	peerTLSConn           *tls.QUICConn
	// lastInitialDstConnID is the destination connection ID of the most
	// recent Initial packet passed to writeDatagram.
	lastInitialDstConnID  []byte // for parsing Retry packets
}
   138  
   139  func newTestEndpoint(t *testing.T, config *Config) *testEndpoint {
   140  	te := &testEndpoint{
   141  		t:     t,
   142  		now:   time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC),
   143  		recvc: make(chan *datagram),
   144  		idlec: make(chan struct{}),
   145  		conns: make(map[*Conn]*testConn),
   146  	}
   147  	var err error
   148  	te.e, err = newEndpoint((*testEndpointUDPConn)(te), config, (*testEndpointHooks)(te))
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  	t.Cleanup(te.cleanup)
   153  	return te
   154  }
   155  
   156  func (te *testEndpoint) cleanup() {
   157  	te.e.Close(canceledContext())
   158  }
   159  
   160  func (te *testEndpoint) wait() {
   161  	select {
   162  	case te.idlec <- struct{}{}:
   163  	case <-te.e.closec:
   164  	}
   165  	for _, tc := range te.conns {
   166  		tc.wait()
   167  	}
   168  }
   169  
   170  // accept returns a server connection from the endpoint.
   171  // Unlike Endpoint.Accept, connections are available as soon as they are created.
   172  func (te *testEndpoint) accept() *testConn {
   173  	if len(te.acceptQueue) == 0 {
   174  		te.t.Fatalf("accept: expected available conn, but found none")
   175  	}
   176  	tc := te.acceptQueue[0]
   177  	te.acceptQueue = te.acceptQueue[1:]
   178  	return tc
   179  }
   180  
   181  func (te *testEndpoint) write(d *datagram) {
   182  	te.recvc <- d
   183  	te.wait()
   184  }
   185  
   186  var testClientAddr = netip.MustParseAddrPort("10.0.0.1:8000")
   187  
   188  func (te *testEndpoint) writeDatagram(d *testDatagram) {
   189  	te.t.Helper()
   190  	logDatagram(te.t, "<- endpoint under test receives", d)
   191  	var buf []byte
   192  	for _, p := range d.packets {
   193  		tc := te.connForDestination(p.dstConnID)
   194  		if p.ptype != packetTypeRetry && tc != nil {
   195  			space := spaceForPacketType(p.ptype)
   196  			if p.num >= tc.peerNextPacketNum[space] {
   197  				tc.peerNextPacketNum[space] = p.num + 1
   198  			}
   199  		}
   200  		if p.ptype == packetTypeInitial {
   201  			te.lastInitialDstConnID = p.dstConnID
   202  		}
   203  		pad := 0
   204  		if p.ptype == packetType1RTT {
   205  			pad = d.paddedSize - len(buf)
   206  		}
   207  		buf = append(buf, encodeTestPacket(te.t, tc, p, pad)...)
   208  	}
   209  	for len(buf) < d.paddedSize {
   210  		buf = append(buf, 0)
   211  	}
   212  	te.write(&datagram{
   213  		b:        buf,
   214  		peerAddr: d.addr,
   215  	})
   216  }
   217  
   218  func (te *testEndpoint) connForDestination(dstConnID []byte) *testConn {
   219  	for _, tc := range te.conns {
   220  		for _, loc := range tc.conn.connIDState.local {
   221  			if bytes.Equal(loc.cid, dstConnID) {
   222  				return tc
   223  			}
   224  		}
   225  	}
   226  	return nil
   227  }
   228  
   229  func (te *testEndpoint) connForSource(srcConnID []byte) *testConn {
   230  	for _, tc := range te.conns {
   231  		for _, loc := range tc.conn.connIDState.remote {
   232  			if bytes.Equal(loc.cid, srcConnID) {
   233  				return tc
   234  			}
   235  		}
   236  	}
   237  	return nil
   238  }
   239  
   240  func (te *testEndpoint) read() []byte {
   241  	te.t.Helper()
   242  	te.wait()
   243  	if len(te.sentDatagrams) == 0 {
   244  		return nil
   245  	}
   246  	d := te.sentDatagrams[0]
   247  	te.sentDatagrams = te.sentDatagrams[1:]
   248  	return d
   249  }
   250  
   251  func (te *testEndpoint) readDatagram() *testDatagram {
   252  	te.t.Helper()
   253  	buf := te.read()
   254  	if buf == nil {
   255  		return nil
   256  	}
   257  	p, _ := parseGenericLongHeaderPacket(buf)
   258  	tc := te.connForSource(p.dstConnID)
   259  	d := parseTestDatagram(te.t, te, tc, buf)
   260  	logDatagram(te.t, "-> endpoint under test sends", d)
   261  	return d
   262  }
   263  
   264  // wantDatagram indicates that we expect the Endpoint to send a datagram.
   265  func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) {
   266  	te.t.Helper()
   267  	got := te.readDatagram()
   268  	if !datagramEqual(got, want) {
   269  		te.t.Fatalf("%v:\ngot datagram:  %v\nwant datagram: %v", expectation, got, want)
   270  	}
   271  }
   272  
   273  // wantIdle indicates that we expect the Endpoint to not send any more datagrams.
   274  func (te *testEndpoint) wantIdle(expectation string) {
   275  	if got := te.readDatagram(); got != nil {
   276  		te.t.Fatalf("expect: %v\nunexpectedly got: %v", expectation, got)
   277  	}
   278  }
   279  
   280  // advance causes time to pass.
   281  func (te *testEndpoint) advance(d time.Duration) {
   282  	te.t.Helper()
   283  	te.advanceTo(te.now.Add(d))
   284  }
   285  
   286  // advanceTo sets the current time.
   287  func (te *testEndpoint) advanceTo(now time.Time) {
   288  	te.t.Helper()
   289  	if te.now.After(now) {
   290  		te.t.Fatalf("time moved backwards: %v -> %v", te.now, now)
   291  	}
   292  	te.now = now
   293  	for _, tc := range te.conns {
   294  		if !tc.timer.After(te.now) {
   295  			tc.conn.sendMsg(timerEvent{})
   296  			tc.wait()
   297  		}
   298  	}
   299  }
   300  
   301  // testEndpointHooks implements endpointTestHooks.
   302  type testEndpointHooks testEndpoint
   303  
   304  func (te *testEndpointHooks) timeNow() time.Time {
   305  	return te.now
   306  }
   307  
   308  func (te *testEndpointHooks) newConn(c *Conn) {
   309  	tc := newTestConnForConn(te.t, (*testEndpoint)(te), c)
   310  	te.conns[c] = tc
   311  }
   312  
   313  // testEndpointUDPConn implements UDPConn.
   314  type testEndpointUDPConn testEndpoint
   315  
   316  func (te *testEndpointUDPConn) Close() error {
   317  	close(te.recvc)
   318  	return nil
   319  }
   320  
   321  func (te *testEndpointUDPConn) LocalAddr() netip.AddrPort {
   322  	return netip.MustParseAddrPort("127.0.0.1:443")
   323  }
   324  
   325  func (te *testEndpointUDPConn) Read(f func(*datagram)) {
   326  	for {
   327  		select {
   328  		case d, ok := <-te.recvc:
   329  			if !ok {
   330  				return
   331  			}
   332  			f(d)
   333  		case <-te.idlec:
   334  		}
   335  	}
   336  }
   337  
   338  func (te *testEndpointUDPConn) Write(dgram datagram) error {
   339  	te.sentDatagrams = append(te.sentDatagrams, append([]byte(nil), dgram.b...))
   340  	return nil
   341  }
   342  

View as plain text