1 package agent
2
import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/url"
	"sync"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/metadata"

	"github.com/datawire/ambassador/v2/pkg/api/agent"
	"github.com/datawire/dlib/dlog"
)
19
// APIKeyMetadataKey is the gRPC metadata key under which RPCComm sends the
// agent's API key as outgoing metadata on its calls to the director.
const APIKeyMetadataKey = "x-ambassador-api-key"
21
22 type RPCComm struct {
23 conn io.Closer
24 client agent.DirectorClient
25 rptWake chan struct{}
26 retCancel context.CancelFunc
27 agentID *agent.Identity
28 directives chan *agent.Directive
29 metricsStreamWriterMutex sync.Mutex
30 extraHeaders []string
31 }
32
33 const (
34 defaultHostname = "app.getambassador.io"
35 defaultPort = "443"
36 )
37
38 type ConnInfo struct {
39 hostname string
40 port string
41 secure bool
42 }
43
44 func connInfoFromAddress(address string) (*ConnInfo, error) {
45 endpoint, err := url.Parse(address)
46 if err != nil {
47 return nil, err
48 }
49
50 hostname := endpoint.Hostname()
51 if hostname == "" {
52 hostname = defaultHostname
53 }
54
55 port := endpoint.Port()
56 if port == "" {
57 port = defaultPort
58 }
59
60 secure := true
61 if endpoint.Scheme == "http" {
62 secure = false
63 }
64
65 return &ConnInfo{hostname, port, secure}, nil
66 }
67
// NewComm dials the director described by connInfo and starts a background
// goroutine (retrieveLoop) that receives directives for agentID; they are
// delivered on the channel returned by Directives.
//
// apiKey and extraHeaders are attached as outgoing gRPC metadata to the
// retrieve stream. extraHeaders is a flat list of key/value pairs and must
// have an even length; otherwise an error is returned before anything is
// dialed. Call Close to cancel the retrieve loop and close the connection.
func NewComm(
	ctx context.Context,
	connInfo *ConnInfo,
	agentID *agent.Identity,
	apiKey string,
	extraHeaders []string,
) (*RPCComm, error) {
	ctx = dlog.WithField(ctx, "agent", "comm")

	// metadata.AppendToOutgoingContext panics on an odd number of strings;
	// reject bad input up front instead of panicking after dialing.
	if len(extraHeaders)%2 != 0 {
		return nil, fmt.Errorf("extraHeaders must be key/value pairs, got %d elements", len(extraHeaders))
	}

	// JoinHostPort re-brackets IPv6 literals (url.Hostname strips the brackets).
	address := net.JoinHostPort(connInfo.hostname, connInfo.port)

	opts := make([]grpc.DialOption, 0, 1)
	if connInfo.secure {
		config := &tls.Config{ServerName: connInfo.hostname}
		creds := credentials.NewTLS(config)
		opts = append(opts, grpc.WithTransportCredentials(creds))
	} else {
		opts = append(opts, grpc.WithInsecure())
	}

	dlog.Debugf(ctx, "Dialing server at %s (secure=%t)", address, connInfo.secure)

	conn, err := grpc.Dial(address, opts...)
	if err != nil {
		return nil, err
	}

	client := agent.NewDirectorClient(conn)
	retCtx, retCancel := context.WithCancel(ctx)

	c := &RPCComm{
		conn:         conn,
		client:       client,
		retCancel:    retCancel,
		agentID:      agentID,
		directives:   make(chan *agent.Directive, 1),
		rptWake:      make(chan struct{}, 1),
		extraHeaders: extraHeaders,
	}
	// Derive from retCtx (not ctx) so that Close, via retCancel, actually stops
	// retrieveLoop and its in-flight Retrieve stream.
	retCtx = metadata.AppendToOutgoingContext(retCtx, c.getHeaders(apiKey)...)

	go c.retrieveLoop(retCtx)

	return c, nil
}
112
113 func (c *RPCComm) getHeaders(apiKey string) []string {
114 return append([]string{
115 APIKeyMetadataKey, apiKey}, c.extraHeaders...)
116 }
117
// retrieveLoop repeatedly runs a Retrieve stream (via retrieve) until ctx is
// done. When a stream ends, its error is only logged at debug level, and the
// loop then waits for either ctx to finish or a wake-up on rptWake (sent by
// Report) before opening a new stream.
func (c *RPCComm) retrieveLoop(ctx context.Context) {
	ctx = dlog.WithField(ctx, "agent", "retriever")

	for {
		err := c.retrieve(ctx)
		if err != nil {
			dlog.Debugf(ctx, "exited: %+v", err)
		}

		select {
		case <-ctx.Done():
			return
		case <-c.rptWake:
			dlog.Debug(ctx, "restarting")
		}
	}
}
134
// retrieve opens one Retrieve stream for c.agentID and forwards every received
// directive to c.directives. It returns the error from opening the stream or
// from Recv (which includes the stream's normal end), or nil if ctx is done
// while it is blocked handing a directive to c.directives.
func (c *RPCComm) retrieve(ctx context.Context) error {
	stream, err := c.client.Retrieve(ctx, c.agentID)
	if err != nil {
		return err
	}

	for {
		d, recvErr := stream.Recv()
		if recvErr != nil {
			return recvErr
		}

		select {
		case <-ctx.Done():
			return nil
		case c.directives <- d:
		}
	}
}
155
156 func (c *RPCComm) Close() error {
157 c.retCancel()
158 return c.conn.Close()
159 }
160
// ReportCommandResult sends result to the director in a single unary call,
// authenticated with apiKey (plus any extra headers) as outgoing metadata.
// Any RPC error is wrapped and returned.
func (c *RPCComm) ReportCommandResult(ctx context.Context, result *agent.CommandResult, apiKey string) error {
	outCtx := metadata.AppendToOutgoingContext(ctx, c.getHeaders(apiKey)...)
	if _, err := c.client.ReportCommandResult(outCtx, result); err != nil {
		return fmt.Errorf("ReportCommandResult error: %w", err)
	}
	return nil
}
169
// Report JSON-encodes report and sends it to the director over a ReportStream
// stream, split into RawSnapshotChunk messages of at most chunkSize bytes, then
// waits for the server's reply via CloseAndRecv.
//
// As a side effect it does a non-blocking send on rptWake, so a retrieve loop
// waiting to reconnect starts a new Retrieve stream. Errors from encoding,
// opening, sending or closing the stream are wrapped and returned.
func (c *RPCComm) Report(ctx context.Context, report *agent.Snapshot, apiKey string) error {
	// Non-blocking: if a wake-up is already pending there is nothing to add.
	select {
	case c.rptWake <- struct{}{}:
	default:
	}
	ctx = metadata.AppendToOutgoingContext(ctx, c.getHeaders(apiKey)...)

	data, err := json.Marshal(report)
	if err != nil {
		return fmt.Errorf("json.Marshal: %w", err)
	}

	// NOTE(review): the -4 presumably leaves room for the protobuf field tag and
	// length prefix so each encoded chunk fits in 64 KiB — confirm.
	const chunkSize = (64 * 1024) - 4
	dlog.Debugf(ctx, "Report is %dB; will take %d chunks",
		len(data),
		(len(data)+chunkSize-1)/chunkSize)

	// Canceling the stream's context on return releases its resources on every
	// exit path, including early returns after a failed Send.
	streamCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	stream, err := c.client.ReportStream(streamCtx)
	if err != nil {
		return fmt.Errorf("ReportStream.Open: %w", err)
	}

	for start := 0; start < len(data); start += chunkSize {
		end := start + chunkSize
		if end > len(data) {
			end = len(data)
		}
		// A fresh message per Send: gRPC does not allow modifying a message
		// after SendMsg, since stats handlers may read it lazily.
		chunk := &agent.RawSnapshotChunk{Chunk: data[start:end]}
		if err := stream.Send(chunk); err != nil {
			if errors.Is(err, io.EOF) {
				// The server ended the stream; its real status is only
				// available by receiving.
				if _, recvErr := stream.CloseAndRecv(); recvErr != nil {
					return fmt.Errorf("ReportStream.Send: %w", recvErr)
				}
			}
			return fmt.Errorf("ReportStream.Send: %w", err)
		}
	}

	if _, err := stream.CloseAndRecv(); err != nil {
		return fmt.Errorf("ReportStream.Close: %w", err)
	}

	return nil
}
216
// StreamMetrics opens a new StreamMetrics stream, sends metrics as its single
// message and half-closes the send side. Calls are serialized by
// metricsStreamWriterMutex. Errors from opening, sending or CloseSend are
// returned unwrapped.
//
// NOTE(review): the server's response is never received, so server-side errors
// go unobserved. Per the grpc-go ClientStream docs, the stream's resources are
// only released once ctx is canceled or a receive returns an error. If this
// RPC is client-streaming (check the generated client type), CloseAndRecv is
// likely the intended call.
func (c *RPCComm) StreamMetrics(ctx context.Context, metrics *agent.StreamMetricsMessage, apiKey string) error {
	ctx = dlog.WithField(ctx, "agent", "streammetrics")

	c.metricsStreamWriterMutex.Lock()
	defer c.metricsStreamWriterMutex.Unlock()

	outCtx := metadata.AppendToOutgoingContext(ctx, c.getHeaders(apiKey)...)
	stream, err := c.client.StreamMetrics(outCtx)
	if err != nil {
		return err
	}

	if err = stream.Send(metrics); err != nil {
		return err
	}
	return stream.CloseSend()
}
235
236 func (c *RPCComm) Directives() <-chan *agent.Directive {
237 return c.directives
238 }
239
View as plain text