...

Source file src/golang.org/x/net/http2/h2c/h2c_test.go

Documentation: golang.org/x/net/http2/h2c

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package h2c
     6  
     7  import (
     8  	"context"
     9  	"crypto/tls"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"net/http"
    15  	"net/http/httptest"
    16  	"strings"
    17  	"testing"
    18  
    19  	"golang.org/x/net/http2"
    20  )
    21  
    22  func ExampleNewHandler() {
    23  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    24  		fmt.Fprint(w, "Hello world")
    25  	})
    26  	h2s := &http2.Server{
    27  		// ...
    28  	}
    29  	h1s := &http.Server{
    30  		Addr:    ":8080",
    31  		Handler: NewHandler(handler, h2s),
    32  	}
    33  	log.Fatal(h1s.ListenAndServe())
    34  }
    35  
    36  func TestContext(t *testing.T) {
    37  	baseCtx := context.WithValue(context.Background(), "testkey", "testvalue")
    38  
    39  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    40  		if r.ProtoMajor != 2 {
    41  			t.Errorf("Request wasn't handled by h2c.  Got ProtoMajor=%v", r.ProtoMajor)
    42  		}
    43  		if r.Context().Value("testkey") != "testvalue" {
    44  			t.Errorf("Request doesn't have expected base context: %v", r.Context())
    45  		}
    46  		fmt.Fprint(w, "Hello world")
    47  	})
    48  
    49  	h2s := &http2.Server{}
    50  	h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
    51  	h1s.Config.BaseContext = func(_ net.Listener) context.Context {
    52  		return baseCtx
    53  	}
    54  	h1s.Start()
    55  	defer h1s.Close()
    56  
    57  	client := &http.Client{
    58  		Transport: &http2.Transport{
    59  			AllowHTTP: true,
    60  			DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
    61  				return net.Dial(network, addr)
    62  			},
    63  		},
    64  	}
    65  
    66  	resp, err := client.Get(h1s.URL)
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	_, err = io.ReadAll(resp.Body)
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  	if err := resp.Body.Close(); err != nil {
    75  		t.Fatal(err)
    76  	}
    77  }
    78  
    79  func TestPropagation(t *testing.T) {
    80  	var (
    81  		server *http.Server
    82  		// double the limit because http2 will compress header
    83  		headerSize  = 1 << 11
    84  		headerLimit = 1 << 10
    85  	)
    86  
    87  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    88  		if r.ProtoMajor != 2 {
    89  			t.Errorf("Request wasn't handled by h2c.  Got ProtoMajor=%v", r.ProtoMajor)
    90  		}
    91  		if r.Context().Value(http.ServerContextKey).(*http.Server) != server {
    92  			t.Errorf("Request doesn't have expected http server: %v", r.Context())
    93  		}
    94  		if len(r.Header.Get("Long-Header")) != headerSize {
    95  			t.Errorf("Request doesn't have expected http header length: %v", len(r.Header.Get("Long-Header")))
    96  		}
    97  		fmt.Fprint(w, "Hello world")
    98  	})
    99  
   100  	h2s := &http2.Server{}
   101  	h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
   102  
   103  	server = h1s.Config
   104  	server.MaxHeaderBytes = headerLimit
   105  	server.ConnState = func(conn net.Conn, state http.ConnState) {
   106  		t.Logf("server conn state: conn %s -> %s, status changed to %s", conn.RemoteAddr(), conn.LocalAddr(), state)
   107  	}
   108  
   109  	h1s.Start()
   110  	defer h1s.Close()
   111  
   112  	client := &http.Client{
   113  		Transport: &http2.Transport{
   114  			AllowHTTP: true,
   115  			DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
   116  				conn, err := net.Dial(network, addr)
   117  				if conn != nil {
   118  					t.Logf("client dial tls: %s -> %s", conn.RemoteAddr(), conn.LocalAddr())
   119  				}
   120  				return conn, err
   121  			},
   122  		},
   123  	}
   124  
   125  	req, err := http.NewRequest("GET", h1s.URL, nil)
   126  	if err != nil {
   127  		t.Fatal(err)
   128  	}
   129  
   130  	req.Header.Set("Long-Header", strings.Repeat("A", headerSize))
   131  
   132  	_, err = client.Do(req)
   133  	if err == nil {
   134  		t.Fatal("expected server err, got nil")
   135  	}
   136  }
   137  
   138  func TestMaxBytesHandler(t *testing.T) {
   139  	const bodyLimit = 10
   140  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   141  		t.Errorf("got request, expected to be blocked by body limit")
   142  	})
   143  
   144  	h2s := &http2.Server{}
   145  	h1s := httptest.NewUnstartedServer(http.MaxBytesHandler(NewHandler(handler, h2s), bodyLimit))
   146  	h1s.Start()
   147  	defer h1s.Close()
   148  
   149  	// Wrap the body in a struct{io.Reader} to prevent it being rewound and resent.
   150  	body := "0123456789abcdef"
   151  	req, err := http.NewRequest("POST", h1s.URL, struct{ io.Reader }{strings.NewReader(body)})
   152  	if err != nil {
   153  		t.Fatal(err)
   154  	}
   155  	req.Header.Set("Http2-Settings", "")
   156  	req.Header.Set("Upgrade", "h2c")
   157  	req.Header.Set("Connection", "Upgrade, HTTP2-Settings")
   158  
   159  	resp, err := h1s.Client().Do(req)
   160  	if err != nil {
   161  		t.Fatal(err)
   162  	}
   163  	defer resp.Body.Close()
   164  	_, err = io.ReadAll(resp.Body)
   165  	if err != nil {
   166  		t.Fatal(err)
   167  	}
   168  	if got, want := resp.StatusCode, http.StatusInternalServerError; got != want {
   169  		t.Errorf("resp.StatusCode = %v, want %v", got, want)
   170  	}
   171  }
   172  

View as plain text