1 package agent
2
3 import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "sync"
9 "sync/atomic"
10 "testing"
11
12 "github.com/stretchr/testify/assert"
13 "google.golang.org/grpc"
14 "google.golang.org/grpc/metadata"
15
16 "github.com/datawire/ambassador/v2/pkg/api/agent"
17 "github.com/datawire/dlib/dlog"
18 )
19
// MockClient is a hand-written test double. In TestComm it is installed as
// both the connection (conn) and the Director client (client) of an RPCComm.
// It records snapshots passed to Report and metrics sent through
// StreamMetrics. It also acts as the directive stream returned by Retrieve.
type MockClient struct {
	// Counter counts Recv calls; it is read and written via sync/atomic.
	Counter int64
	// Embedded so that MockClient provides the grpc.ClientStream methods
	// needed when Retrieve returns it as a stream. It is nil in TestComm, so
	// calling those promoted methods would panic.
	grpc.ClientStream
	// SentMetrics collects every message passed to a stream obtained from
	// StreamMetrics.
	SentMetrics []*agent.StreamMetricsMessage
	// SentSnapshots collects every snapshot passed to Report; guarded by snapMux.
	SentSnapshots []*agent.Snapshot
	// snapMux guards SentSnapshots and LastMetadata.
	snapMux sync.Mutex
	// reportFunc, when non-nil, supplies Report's return values.
	reportFunc func(context.Context, *agent.Snapshot) (*agent.SnapshotResponse, error)
	// LastMetadata holds the outgoing gRPC metadata seen by the most recent
	// Report call; guarded by snapMux.
	LastMetadata metadata.MD
}
29
// ReportCommandResult is deliberately unimplemented in this mock; any call
// panics so an unexpected use is caught immediately.
func (*MockClient) ReportCommandResult(_ context.Context, _ *agent.CommandResult, _ ...grpc.CallOption) (*agent.CommandResultResponse, error) {
	panic("implement me")
}
33
34 func (m *MockClient) Close() error {
35 return nil
36 }
37
38 func (m *MockClient) GetLastMetadata() metadata.MD {
39 m.snapMux.Lock()
40 defer m.snapMux.Unlock()
41 meta := m.LastMetadata
42 return meta
43 }
44
45 func (m *MockClient) GetSnapshots() []*agent.Snapshot {
46 m.snapMux.Lock()
47 defer m.snapMux.Unlock()
48 snaps := m.SentSnapshots
49 return snaps
50 }
51
// Report records the snapshot and the outgoing gRPC metadata carried by ctx.
// It then delegates to reportFunc when one is set; otherwise it returns
// (nil, nil).
//
// The recording happens under snapMux, but the callback is invoked after the
// lock is released. A reportFunc may therefore call GetSnapshots or
// GetLastMetadata without self-deadlocking on the non-reentrant mutex.
func (m *MockClient) Report(ctx context.Context, in *agent.Snapshot, opts ...grpc.CallOption) (*agent.SnapshotResponse, error) {
	// A context without outgoing metadata yields a nil MD, which is recorded
	// as-is; the "ok" flag carries no extra information for the mock.
	md, _ := metadata.FromOutgoingContext(ctx)

	m.snapMux.Lock()
	// append on a nil slice allocates, so no explicit initialisation is needed.
	m.SentSnapshots = append(m.SentSnapshots, in)
	m.LastMetadata = md
	report := m.reportFunc
	m.snapMux.Unlock()

	if report != nil {
		return report(ctx, in)
	}
	return nil, nil
}
66
// StreamMetrics opens a fake client stream whose Send appends every message
// to m.SentMetrics. Opening the stream never fails.
func (m *MockClient) StreamMetrics(ctx context.Context, opts ...grpc.CallOption) (agent.Director_StreamMetricsClient, error) {
	stream := &mockStreamMetricsClient{parent: m, ctx: ctx, opts: opts}
	return stream, nil
}
74
// mockStreamMetricsClient is the fake metrics stream returned by
// MockClient.StreamMetrics; messages sent on it are recorded on parent.
type mockStreamMetricsClient struct {
	// ctx is the context the stream was opened with; returned by Context.
	ctx context.Context
	// opts holds the call options passed to StreamMetrics (kept, not used).
	opts []grpc.CallOption
	// parent is the MockClient that receives the sent messages.
	parent *MockClient
}
80
81 func (s *mockStreamMetricsClient) Send(msg *agent.StreamMetricsMessage) error {
82 s.parent.SentMetrics = append(s.parent.SentMetrics, msg)
83 return nil
84 }
85 func (s *mockStreamMetricsClient) CloseAndRecv() (*agent.StreamMetricsResponse, error) {
86 return nil, nil
87 }
88
// The grpc.ClientStream methods below are no-op stubs. Only Context returns
// real data: the context the stream was opened with.

func (*mockStreamMetricsClient) Header() (metadata.MD, error) {
	return nil, nil
}

func (*mockStreamMetricsClient) Trailer() metadata.MD {
	return nil
}

func (*mockStreamMetricsClient) CloseSend() error {
	return nil
}

func (s *mockStreamMetricsClient) Context() context.Context {
	return s.ctx
}

func (*mockStreamMetricsClient) SendMsg(interface{}) error {
	return nil
}

func (*mockStreamMetricsClient) RecvMsg(interface{}) error {
	return nil
}
95
// mockReportStreamClient is the fake chunked snapshot stream returned by
// MockClient.ReportStream. Chunks are buffered in content and decoded on
// CloseAndRecv.
type mockReportStreamClient struct {
	// ctx is the context the stream was opened with; forwarded to Report.
	ctx context.Context
	// opts holds the call options passed to ReportStream; forwarded to Report.
	opts []grpc.CallOption
	// parent is the MockClient whose Report receives the decoded snapshot.
	parent *MockClient
	// content accumulates the raw bytes of every chunk passed to Send.
	content []byte
}
102
103 func (s *mockReportStreamClient) Send(chunk *agent.RawSnapshotChunk) error {
104 s.content = append(s.content, chunk.Chunk...)
105 return nil
106 }
107 func (s *mockReportStreamClient) CloseAndRecv() (*agent.SnapshotResponse, error) {
108 var snapshot agent.Snapshot
109 if err := json.Unmarshal(s.content, &snapshot); err != nil {
110 return nil, err
111 }
112 return s.parent.Report(s.ctx, &snapshot, s.opts...)
113 }
114
// The grpc.ClientStream methods below are no-op stubs. Only Context returns
// real data: the context the stream was opened with.

func (*mockReportStreamClient) Header() (metadata.MD, error) {
	return nil, nil
}

func (*mockReportStreamClient) Trailer() metadata.MD {
	return nil
}

func (*mockReportStreamClient) CloseSend() error {
	return nil
}

func (s *mockReportStreamClient) Context() context.Context {
	return s.ctx
}

func (*mockReportStreamClient) SendMsg(interface{}) error {
	return nil
}

func (*mockReportStreamClient) RecvMsg(interface{}) error {
	return nil
}
121
// ReportStream opens a fake chunked-snapshot stream. Its CloseAndRecv
// reassembles the chunks and passes the result to m.Report. Opening the
// stream never fails.
func (m *MockClient) ReportStream(ctx context.Context, opts ...grpc.CallOption) (agent.Director_ReportStreamClient, error) {
	stream := &mockReportStreamClient{parent: m, ctx: ctx, opts: opts}
	return stream, nil
}
129
// Recv lets MockClient serve as the directive stream returned by Retrieve.
// Each call atomically increments Counter. While the new value is below 3,
// Recv yields a directive with a single "test command N" command. Once the
// value reaches 3, it returns io.EOF.
func (m *MockClient) Recv() (*agent.Directive, error) {
	n := atomic.AddInt64(&m.Counter, 1)
	if n >= 3 {
		return nil, io.EOF
	}

	cmd := &agent.Command{Message: fmt.Sprintf("test command %d", n)}
	return &agent.Directive{Commands: []*agent.Command{cmd}}, nil
}
143
144 func (m *MockClient) Retrieve(ctx context.Context, in *agent.Identity, opts ...grpc.CallOption) (agent.Director_RetrieveClient, error) {
145 fmt.Println("Retrieve called")
146 return m, nil
147 }
148
// retrvsnapshotclient is the stub stream returned by
// MockClient.RetrieveSnapshot.
type retrvsnapshotclient struct {
	// Embedded (and left nil by RetrieveSnapshot) to supply the remaining
	// grpc.ClientStream methods; calling any of them would panic.
	grpc.ClientStream
}
152
153 func (r *retrvsnapshotclient) Recv() (*agent.RawSnapshotChunk, error) {
154 return nil, nil
155 }
156
// RetrieveSnapshot returns a stub stream whose Recv yields (nil, nil).
// It never fails.
func (*MockClient) RetrieveSnapshot(context.Context, *agent.Identity, ...grpc.CallOption) (agent.Director_RetrieveSnapshotClient, error) {
	return new(retrvsnapshotclient), nil
}
160
// TestComm wires an RPCComm to a MockClient and runs its retrieve loop. It
// expects two directives, then rewinds the mock's Recv counter, reports a
// snapshot with the same identity and expects two more directives. Finally
// it reports a snapshot with an equal-but-distinct identity and closes the
// comm.
func TestComm(t *testing.T) {
	ctx, cancel := context.WithCancel(dlog.NewTestContext(t, false))

	mock := &MockClient{}
	id := &agent.Identity{}
	comm := &RPCComm{
		conn:       mock,
		client:     mock,
		rptWake:    make(chan struct{}, 1),
		retCancel:  cancel,
		agentID:    id,
		directives: make(chan *agent.Directive, 1),
	}

	// logDirectives blocks until n directives have arrived and logs each one.
	logDirectives := func(n int) {
		t.Helper()
		for i := 0; i < n; i++ {
			t.Logf("got: %v", <-comm.directives)
		}
	}

	go comm.retrieveLoop(ctx)
	logDirectives(2)

	// Reset the mock's Recv counter so that later Recv calls yield
	// directives again; two more are expected after this Report.
	atomic.StoreInt64(&mock.Counter, 0)

	sameID := &agent.Snapshot{Identity: id, Message: "hello same ID"}
	if err := comm.Report(ctx, sameID, "apikey"); err != nil {
		t.Errorf("Comm.Report() error = %v", err)
	}
	logDirectives(2)

	equivalentID := &agent.Identity{}
	equivalent := &agent.Snapshot{Identity: equivalentID, Message: "hello equivalent ID"}
	if err := comm.Report(ctx, equivalent, "apikey"); err != nil {
		t.Errorf("Comm.Report() error = %v", err)
	}

	if err := comm.Close(); err != nil {
		t.Errorf("Comm.Close() error = %v", err)
	}
}
205
// TestConnInfo checks connInfoFromAddress against three inputs: the empty
// string, the explicit default URL and a bogus path-like value. Each must
// yield the default hostname and port with secure set. The test also checks
// that ":a bad value" is rejected with an error.
func TestConnInfo(t *testing.T) {
	assert := assert.New(t)

	defaults := []string{
		"",
		fmt.Sprintf("https://%s:%s/", defaultHostname, defaultPort),
		"a bogus value that looks like a path",
	}

	for _, addr := range defaults {
		ci, err := connInfoFromAddress(addr)
		// Only inspect ci when no error was returned: on error ci may be nil,
		// and dereferencing it would panic and abort the whole test.
		if !assert.NoError(err, "address %q", addr) {
			continue
		}
		assert.Equal(defaultHostname, ci.hostname, "address %q", addr)
		assert.Equal(defaultPort, ci.port, "address %q", addr)
		assert.True(ci.secure, "address %q", addr)
	}

	const badAddr = ":a bad value"
	ci, err := connInfoFromAddress(badAddr)
	assert.Error(err, "address %q unexpectedly returned %+v", badAddr, ci)
}
232
View as plain text