1
18
19 package record
20
21 import (
22 "errors"
23 "fmt"
24 "testing"
25
26 "github.com/google/go-cmp/cmp"
27 "github.com/google/go-cmp/cmp/cmpopts"
28 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
29 s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
30 "github.com/google/s2a-go/internal/tokenmanager"
31 "google.golang.org/grpc/codes"
32 "google.golang.org/protobuf/testing/protocmp"
33 )
34
// testAccessToken is the token value the fake token manager is configured to
// hand out in these tests.
const testAccessToken = "test_access_token"
38
// fakeStream is a test double passed to ticketSender.writeTicketsToStream.
// Send validates the outgoing request; Recv returns a canned result chosen by
// the flags below.
type fakeStream struct {
	// returnInvalid makes Recv answer with an InvalidArgument status.
	// returnRecvErr makes Recv fail with an error; Recv checks it before
	// returnInvalid, so it takes precedence.
	returnInvalid, returnRecvErr bool
}
47
48 func (fs *fakeStream) Send(req *s2apb.SessionReq) error {
49 if len(req.GetResumptionTicket().GetInBytes()) == 0 {
50 return errors.New("fakeStream Send received an empty InBytes")
51 }
52 if req.GetResumptionTicket().GetConnectionId() == 0 {
53 return errors.New("fakeStream Send received a 0 ConnectionId")
54 }
55 if req.GetResumptionTicket().GetLocalIdentity() == nil {
56 return errors.New("fakeStream Send received an empty LocalIdentity")
57 }
58 return nil
59 }
60
61 func (fs *fakeStream) Recv() (*s2apb.SessionResp, error) {
62 if fs.returnRecvErr {
63 return nil, errors.New("fakeStream Recv error")
64 }
65 if fs.returnInvalid {
66 return &s2apb.SessionResp{
67 Status: &s2apb.SessionStatus{Code: uint32(codes.InvalidArgument)},
68 }, nil
69 }
70 return &s2apb.SessionResp{
71 Status: &s2apb.SessionStatus{Code: uint32(codes.OK)},
72 }, nil
73 }
74
// fakeAccessTokenManager is a test double used as a
// tokenmanager.AccessTokenManager in TestGetAuthMechanism.
type fakeAccessTokenManager struct {
	// acceptedIdentity is the only non-empty identity Token returns a token for.
	acceptedIdentity *commonpb.Identity
	// accessToken is the token handed out on success.
	accessToken string
	// allowEmptyIdentity permits returning a token for a nil or empty identity
	// (both in Token and in DefaultToken).
	allowEmptyIdentity bool
}
80
81 func (m *fakeAccessTokenManager) DefaultToken() (string, error) {
82 if !m.allowEmptyIdentity {
83 return "", fmt.Errorf("not allowed to get token for empty identity")
84 }
85 return m.accessToken, nil
86 }
87
88 func (m *fakeAccessTokenManager) Token(identity *commonpb.Identity) (string, error) {
89 if identity == nil || cmp.Equal(identity, &commonpb.Identity{}, protocmp.Transform()) {
90 if !m.allowEmptyIdentity {
91 return "", fmt.Errorf("not allowed to get token for empty identity")
92 }
93 return m.accessToken, nil
94 }
95 if cmp.Equal(identity, m.acceptedIdentity, protocmp.Transform()) {
96 return m.accessToken, nil
97 }
98 return "", fmt.Errorf("unable to get token")
99 }
100
101 func TestWriteTicketsToStream(t *testing.T) {
102 for _, tc := range []struct {
103 returnInvalid bool
104 returnRecvError bool
105 }{
106 {
107
108 },
109 {
110 returnInvalid: true,
111 },
112 {
113 returnRecvError: true,
114 },
115 } {
116 sender := ticketSender{
117 connectionID: 1,
118 localIdentity: &commonpb.Identity{
119 IdentityOneof: &commonpb.Identity_SpiffeId{
120 SpiffeId: "test_spiffe_id",
121 },
122 },
123 }
124 fs := &fakeStream{returnInvalid: tc.returnInvalid, returnRecvErr: tc.returnRecvError}
125 if got, want := sender.writeTicketsToStream(fs, make([][]byte, 1)) == nil, !tc.returnRecvError && !tc.returnInvalid; got != want {
126 t.Errorf("sender.writeTicketsToStream(%v, _) = (err=nil) = %v, want %v", fs, got, want)
127 }
128 }
129 }
130
131 func TestGetAuthMechanism(t *testing.T) {
132 sortProtos := cmpopts.SortSlices(func(m1, m2 *s2apb.AuthenticationMechanism) bool { return m1.String() < m2.String() })
133 for _, tc := range []struct {
134 description string
135 localIdentity *commonpb.Identity
136 tokenManager tokenmanager.AccessTokenManager
137 expectedAuthMechanisms []*s2apb.AuthenticationMechanism
138 }{
139 {
140 description: "token manager is nil",
141 tokenManager: nil,
142 expectedAuthMechanisms: nil,
143 },
144 {
145 description: "token manager expects empty identity",
146 tokenManager: &fakeAccessTokenManager{
147 accessToken: testAccessToken,
148 allowEmptyIdentity: true,
149 },
150 expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
151 {
152 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
153 Token: testAccessToken,
154 },
155 },
156 },
157 },
158 {
159 description: "token manager does not expect empty identity",
160 tokenManager: &fakeAccessTokenManager{
161 allowEmptyIdentity: false,
162 },
163 expectedAuthMechanisms: nil,
164 },
165 {
166 description: "token manager expects SPIFFE ID",
167 localIdentity: &commonpb.Identity{
168 IdentityOneof: &commonpb.Identity_SpiffeId{
169 SpiffeId: "allowed_spiffe_id",
170 },
171 },
172 tokenManager: &fakeAccessTokenManager{
173 accessToken: testAccessToken,
174 acceptedIdentity: &commonpb.Identity{
175 IdentityOneof: &commonpb.Identity_SpiffeId{
176 SpiffeId: "allowed_spiffe_id",
177 },
178 },
179 },
180 expectedAuthMechanisms: []*s2apb.AuthenticationMechanism{
181 {
182 Identity: &commonpb.Identity{
183 IdentityOneof: &commonpb.Identity_SpiffeId{
184 SpiffeId: "allowed_spiffe_id",
185 },
186 },
187 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
188 Token: testAccessToken,
189 },
190 },
191 },
192 },
193 {
194 description: "token manager does not expect hostname",
195
196 localIdentity: &commonpb.Identity{
197 IdentityOneof: &commonpb.Identity_Hostname{
198 Hostname: "disallowed_hostname",
199 },
200 },
201 tokenManager: &fakeAccessTokenManager{},
202 expectedAuthMechanisms: nil,
203 },
204 } {
205 t.Run(tc.description, func(t *testing.T) {
206 ticketSender := &ticketSender{
207 localIdentity: tc.localIdentity,
208 tokenManager: tc.tokenManager,
209 }
210 authMechanisms := ticketSender.getAuthMechanisms()
211 if got, want := (authMechanisms == nil), (tc.expectedAuthMechanisms == nil); got != want {
212 t.Errorf("authMechanisms == nil: %t, tc.expectedAuthMechanisms == nil: %t", got, want)
213 }
214 if authMechanisms != nil && tc.expectedAuthMechanisms != nil {
215 if diff := cmp.Diff(authMechanisms, tc.expectedAuthMechanisms, protocmp.Transform(), sortProtos); diff != "" {
216 t.Errorf("ticketSender.getAuthMechanisms() returned incorrect slice, (-want +got):\n%s", diff)
217 }
218 }
219 })
220 }
221 }
222
View as plain text