...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package transport
16
17 import (
18 "net/http"
19 "net/http/httptest"
20 "reflect"
21 "testing"
22 )
23
24 func TestValidateSecureEndpoints(t *testing.T) {
25 tlsInfo, certCleanup, err := createSelfCert()
26 if err != nil {
27 t.Fatalf("unable to create cert: %v", err)
28 }
29 defer certCleanup()
30
31 remoteAddr := func(w http.ResponseWriter, r *http.Request) {
32 w.Write([]byte(r.RemoteAddr))
33 }
34 srv := httptest.NewServer(http.HandlerFunc(remoteAddr))
35 defer srv.Close()
36
37 tests := map[string]struct {
38 endPoints []string
39 expectedEndpoints []string
40 expectedErr bool
41 }{
42 "invalidEndPoints": {
43 endPoints: []string{
44 "invalid endpoint",
45 },
46 expectedEndpoints: nil,
47 expectedErr: true,
48 },
49 "insecureEndpoints": {
50 endPoints: []string{
51 "http://127.0.0.1:8000",
52 "http://" + srv.Listener.Addr().String(),
53 },
54 expectedEndpoints: nil,
55 expectedErr: true,
56 },
57 "secureEndPoints": {
58 endPoints: []string{
59 "https://" + srv.Listener.Addr().String(),
60 },
61 expectedEndpoints: []string{
62 "https://" + srv.Listener.Addr().String(),
63 },
64 expectedErr: false,
65 },
66 "mixEndPoints": {
67 endPoints: []string{
68 "https://" + srv.Listener.Addr().String(),
69 "http://" + srv.Listener.Addr().String(),
70 "invalid end points",
71 },
72 expectedEndpoints: []string{
73 "https://" + srv.Listener.Addr().String(),
74 },
75 expectedErr: true,
76 },
77 }
78 for name, test := range tests {
79 t.Run(name, func(t *testing.T) {
80 secureEps, err := ValidateSecureEndpoints(*tlsInfo, test.endPoints)
81 if test.expectedErr != (err != nil) {
82 t.Errorf("Unexpected error, got: %v, want: %v", err, test.expectedErr)
83 }
84
85 if !reflect.DeepEqual(test.expectedEndpoints, secureEps) {
86 t.Errorf("expected endpoints %v, got %v", test.expectedEndpoints, secureEps)
87 }
88 })
89 }
90 }
91
View as plain text