1 package services
2
3 import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "io"
8 "io/ioutil"
9 "net/http"
10 "os"
11 "os/signal"
12 "strconv"
13 "sync"
14 "syscall"
15 "time"
16
17 "google.golang.org/grpc"
18 "google.golang.org/grpc/codes"
19 "google.golang.org/grpc/metadata"
20 "google.golang.org/grpc/status"
21
22 "github.com/datawire/ambassador/v2/pkg/api/agent"
23 "github.com/datawire/dlib/dhttp"
24 "github.com/datawire/dlib/dlog"
25 )
26
// GRPCAgent hosts the agent Director gRPC service; call Start to run it.
type GRPCAgent struct {
	// Port is the TCP port the gRPC listener binds to (on all interfaces).
	Port int16
}
30
31 func (a *GRPCAgent) Start(ctx context.Context) <-chan bool {
32 wg := &sync.WaitGroup{}
33 var opts []grpc.ServerOption
34 if sizeStr := os.Getenv("KAT_GRPC_MAX_RECV_MSG_SIZE"); sizeStr != "" {
35 size, err := strconv.Atoi(sizeStr)
36 if err == nil {
37 dlog.Printf(ctx, "setting gRPC MaxRecvMsgSize to %d bytes", size)
38 opts = append(opts, grpc.MaxRecvMsgSize(size))
39 }
40 }
41 grpcHandler := grpc.NewServer(opts...)
42 dir := &director{}
43 agent.RegisterDirectorServer(grpcHandler, dir)
44 sc := &dhttp.ServerConfig{
45 Handler: grpcHandler,
46 }
47 grpcErrChan := make(chan error)
48 httpErrChan := make(chan error)
49 ctx, cancel := context.WithCancel(ctx)
50
51 wg.Add(2)
52 go func() {
53 defer wg.Done()
54 dlog.Print(ctx, "starting GRPC agentcom...")
55 if err := sc.ListenAndServe(ctx, fmt.Sprintf(":%d", a.Port)); err != nil {
56 select {
57 case grpcErrChan <- err:
58 default:
59 }
60 }
61 }()
62 srv := &http.Server{Addr: ":3001"}
63
64 http.HandleFunc("/lastSnapshot", func(w http.ResponseWriter, r *http.Request) {
65 lastSnap := dir.GetLastSnapshot()
66 if lastSnap == nil {
67 w.WriteHeader(http.StatusNotFound)
68 return
69 }
70 ret, err := json.Marshal(lastSnap)
71 if err != nil {
72 w.WriteHeader(http.StatusInternalServerError)
73 return
74 }
75
76 w.WriteHeader(http.StatusOK)
77 _, _ = w.Write(ret)
78 })
79
80 go func() {
81 defer wg.Done()
82
83 dlog.Print(ctx, "Starting http server")
84 if err := srv.ListenAndServe(); err != http.ErrServerClosed {
85 select {
86 case httpErrChan <- err:
87 default:
88 }
89 }
90 }()
91
92 exited := make(chan bool)
93 go func() {
94
95 c := make(chan os.Signal, 1)
96 signal.Notify(c, os.Interrupt, syscall.SIGTERM)
97
98 select {
99 case err := <-grpcErrChan:
100 dlog.Errorf(ctx, "GRPC service died: %+v", err)
101 panic(err)
102 case err := <-httpErrChan:
103 dlog.Errorf(ctx, "http service died: %+v", err)
104 panic(err)
105 case <-c:
106 dlog.Print(ctx, "Received shutdown")
107 }
108
109 ctx, timeout := context.WithTimeout(ctx, time.Second*30)
110 defer timeout()
111 cancel()
112
113 grpcHandler.GracefulStop()
114 _ = srv.Shutdown(ctx)
115 wg.Wait()
116 close(exited)
117 }()
118 return exited
119 }
120
121 type director struct {
122 agent.UnimplementedDirectorServer
123 lastSnapshot *agent.Snapshot
124 }
125
126 func (d *director) GetLastSnapshot() *agent.Snapshot {
127 return d.lastSnapshot
128 }
129
130
131 func (d *director) Report(ctx context.Context, snapshot *agent.Snapshot) (*agent.SnapshotResponse, error) {
132 err := checkContext(ctx)
133 if err != nil {
134 return nil, err
135 }
136
137 dlog.Print(ctx, "Received snapshot")
138
139 err = writeSnapshot(snapshot)
140 if err != nil {
141 return nil, err
142 }
143
144 d.lastSnapshot = snapshot
145 return &agent.SnapshotResponse{}, nil
146 }
147
148 func (d *director) Retrieve(agentID *agent.Identity, stream agent.Director_RetrieveServer) error {
149 return nil
150 }
151
152 func checkContext(ctx context.Context) error {
153 md, ok := metadata.FromIncomingContext(ctx)
154 if !ok {
155 dlog.Print(ctx, "No metadata found, not allowing request")
156 err := status.Error(codes.PermissionDenied, "Missing grpc metadata")
157
158 return err
159 }
160
161 apiKeyValues := md.Get("x-ambassador-api-key")
162 if len(apiKeyValues) == 0 || apiKeyValues[0] == "" {
163 dlog.Print(ctx, "api key found, not allowing request")
164 err := status.Error(codes.PermissionDenied, "Missing api key")
165 return err
166 }
167 return nil
168 }
169
170 func writeSnapshot(snapshot *agent.Snapshot) error {
171 snapBytes, err := json.Marshal(snapshot)
172 if err != nil {
173 return err
174 }
175 err = ioutil.WriteFile("/tmp/snapshot.json", snapBytes, 0644)
176 if err != nil {
177 return err
178 }
179 return nil
180 }
181
182 func (d *director) ReportStream(server agent.Director_ReportStreamServer) error {
183 err := checkContext(server.Context())
184 if err != nil {
185 return err
186 }
187
188 var data []byte
189 for {
190 msg, err := server.Recv()
191 data = append(data, msg.GetChunk()...)
192 if err != nil {
193 if err == io.EOF {
194 break
195 } else {
196 return err
197 }
198 }
199 }
200
201 var snapshot agent.Snapshot
202 err = json.Unmarshal(data, &snapshot)
203 if err != nil {
204 return err
205 }
206
207 dlog.Print(server.Context(), "Received snapshot")
208
209 err = writeSnapshot(&snapshot)
210 if err != nil {
211 return err
212 }
213
214 response := &agent.SnapshotResponse{}
215 err = server.SendMsg(response)
216 return err
217 }
218
View as plain text