1
16
17
18
19 package driver
20
21 import (
22 "context"
23 "encoding/json"
24 "errors"
25 "net"
26 "sync"
27
28 "google.golang.org/grpc/codes"
29 "google.golang.org/grpc/status"
30 "k8s.io/klog/v2"
31
32 "github.com/container-storage-interface/spec/lib/go/csi"
33 "google.golang.org/grpc"
34 )
35
// ErrNoCredentials is returned by credsCheck when a request that requires
// authentication carries an empty secrets map.
var ErrNoCredentials = errors.New("secret must be provided")

// ErrAuthFailed is returned by credsCheck when a request carries secrets but
// the value under the expected key does not match the configured secret.
var ErrAuthFailed = errors.New("authentication failed")
42
43
44
45 type CSIDriverServers struct {
46 Controller csi.ControllerServer
47 Identity csi.IdentityServer
48 Node csi.NodeServer
49 }
50
51
const (
	// secretField is the key that credsCheck looks up in a request's
	// secrets map and compares against the configured secret value.
	secretField = "secretKey"
)
53
54
55
56
// CSICreds holds the expected secret value for each CSI call that the driver
// authenticates. Each value is compared against the entry stored under
// secretField in the corresponding request's secrets map.
type CSICreds struct {
	// CreateVolumeSecret is expected on CreateVolume requests.
	CreateVolumeSecret string
	// DeleteVolumeSecret is expected on DeleteVolume requests.
	DeleteVolumeSecret string
	// ControllerPublishVolumeSecret is expected on ControllerPublishVolume requests.
	ControllerPublishVolumeSecret string
	// ControllerUnpublishVolumeSecret is expected on ControllerUnpublishVolume requests.
	ControllerUnpublishVolumeSecret string
	// NodeStageVolumeSecret is expected on NodeStageVolume requests.
	NodeStageVolumeSecret string
	// NodePublishVolumeSecret is expected on NodePublishVolume requests.
	NodePublishVolumeSecret string
	// CreateSnapshotSecret is expected on CreateSnapshot requests.
	CreateSnapshotSecret string
	// DeleteSnapshotSecret is expected on DeleteSnapshot requests.
	DeleteSnapshotSecret string
	// ControllerValidateVolumeCapabilitiesSecret is expected on
	// ValidateVolumeCapabilities requests.
	ControllerValidateVolumeCapabilitiesSecret string
}
68
69 type CSIDriver struct {
70 listener net.Listener
71 server *grpc.Server
72 servers *CSIDriverServers
73 wg sync.WaitGroup
74 running bool
75 lock sync.Mutex
76 creds *CSICreds
77 logGRPC LogGRPC
78 }
79
// LogGRPC is a callback that observes a gRPC call: the full method name, the
// request, the reply (nil if no reply was produced), and the resulting error.
type LogGRPC func(method string, request interface{}, reply interface{}, err error)
81
82 func NewCSIDriver(servers *CSIDriverServers) *CSIDriver {
83 return &CSIDriver{
84 servers: servers,
85 }
86 }
87
88 func (c *CSIDriver) goServe(started chan<- bool) {
89 goServe(c.server, &c.wg, c.listener, started)
90 }
91
92 func (c *CSIDriver) Address() string {
93 return c.listener.Addr().String()
94 }
95
96
97
98
99
100 func (c *CSIDriver) Start(l net.Listener, interceptor grpc.UnaryServerInterceptor) error {
101 c.lock.Lock()
102 defer c.lock.Unlock()
103
104
105 c.listener = l
106
107
108 if interceptor == nil {
109 interceptor = c.callInterceptor
110 }
111 c.server = grpc.NewServer(grpc.UnaryInterceptor(interceptor))
112
113
114 if c.servers.Controller != nil {
115 csi.RegisterControllerServer(c.server, c.servers.Controller)
116 }
117 if c.servers.Identity != nil {
118 csi.RegisterIdentityServer(c.server, c.servers.Identity)
119 }
120 if c.servers.Node != nil {
121 csi.RegisterNodeServer(c.server, c.servers.Node)
122 }
123
124
125 waitForServer := make(chan bool)
126 c.goServe(waitForServer)
127 <-waitForServer
128 c.running = true
129 return nil
130 }
131
132 func (c *CSIDriver) Stop() {
133 stop(&c.lock, &c.wg, c.server, c.running)
134 }
135
136 func (c *CSIDriver) Close() {
137 c.server.Stop()
138 }
139
140 func (c *CSIDriver) IsRunning() bool {
141 c.lock.Lock()
142 defer c.lock.Unlock()
143
144 return c.running
145 }
146
147
148 func (c *CSIDriver) SetDefaultCreds() {
149 setDefaultCreds(c.creds)
150 }
151
152
153 func goServe(server *grpc.Server, wg *sync.WaitGroup, listener net.Listener, started chan<- bool) {
154 wg.Add(1)
155 go func() {
156 defer wg.Done()
157 started <- true
158 err := server.Serve(listener)
159 if err != nil {
160 klog.Infof("gRPC server for CSI driver stopped: %v", err)
161 }
162 }()
163 }
164
165
166 func stop(lock *sync.Mutex, wg *sync.WaitGroup, server *grpc.Server, running bool) {
167 lock.Lock()
168 defer lock.Unlock()
169
170 if !running {
171 return
172 }
173
174 server.Stop()
175 wg.Wait()
176 }
177
178
179 func setDefaultCreds(creds *CSICreds) {
180 *creds = CSICreds{
181 CreateVolumeSecret: "secretval1",
182 DeleteVolumeSecret: "secretval2",
183 ControllerPublishVolumeSecret: "secretval3",
184 ControllerUnpublishVolumeSecret: "secretval4",
185 NodeStageVolumeSecret: "secretval5",
186 NodePublishVolumeSecret: "secretval6",
187 CreateSnapshotSecret: "secretval7",
188 DeleteSnapshotSecret: "secretval8",
189 ControllerValidateVolumeCapabilitiesSecret: "secretval9",
190 }
191 }
192
193 func (c *CSIDriver) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
194 err := authInterceptor(c.creds, req)
195 if err != nil {
196 logGRPC(info.FullMethod, req, nil, err)
197 return nil, err
198 }
199 rsp, err := handler(ctx, req)
200 logGRPC(info.FullMethod, req, rsp, err)
201 if c.logGRPC != nil {
202 c.logGRPC(info.FullMethod, req, rsp, err)
203 }
204 return rsp, err
205 }
206
207 func authInterceptor(creds *CSICreds, req interface{}) error {
208 if creds != nil {
209 authenticated, authErr := isAuthenticated(req, creds)
210 if !authenticated {
211 if authErr == ErrNoCredentials {
212 return status.Error(codes.InvalidArgument, authErr.Error())
213 }
214 if authErr == ErrAuthFailed {
215 return status.Error(codes.Unauthenticated, authErr.Error())
216 }
217 }
218 }
219 return nil
220 }
221
222 func logGRPC(method string, request, reply interface{}, err error) {
223
224 logMessage := struct {
225 Method string
226 Request interface{}
227 Response interface{}
228
229
230 Error string
231
232 FullError error
233 }{
234 Method: method,
235 Request: request,
236 Response: reply,
237 FullError: err,
238 }
239
240 if err != nil {
241 logMessage.Error = err.Error()
242 }
243
244 msg, _ := json.Marshal(logMessage)
245 klog.V(3).Infof("gRPCCall: %s\n", msg)
246 }
247
248 func isAuthenticated(req interface{}, creds *CSICreds) (bool, error) {
249 switch r := req.(type) {
250 case *csi.CreateVolumeRequest:
251 return authenticateCreateVolume(r, creds)
252 case *csi.DeleteVolumeRequest:
253 return authenticateDeleteVolume(r, creds)
254 case *csi.ControllerPublishVolumeRequest:
255 return authenticateControllerPublishVolume(r, creds)
256 case *csi.ControllerUnpublishVolumeRequest:
257 return authenticateControllerUnpublishVolume(r, creds)
258 case *csi.NodeStageVolumeRequest:
259 return authenticateNodeStageVolume(r, creds)
260 case *csi.NodePublishVolumeRequest:
261 return authenticateNodePublishVolume(r, creds)
262 case *csi.CreateSnapshotRequest:
263 return authenticateCreateSnapshot(r, creds)
264 case *csi.DeleteSnapshotRequest:
265 return authenticateDeleteSnapshot(r, creds)
266 case *csi.ValidateVolumeCapabilitiesRequest:
267 return authenticateControllerValidateVolumeCapabilities(r, creds)
268 default:
269 return true, nil
270 }
271 }
272
273 func authenticateCreateVolume(req *csi.CreateVolumeRequest, creds *CSICreds) (bool, error) {
274 return credsCheck(req.GetSecrets(), creds.CreateVolumeSecret)
275 }
276
277 func authenticateDeleteVolume(req *csi.DeleteVolumeRequest, creds *CSICreds) (bool, error) {
278 return credsCheck(req.GetSecrets(), creds.DeleteVolumeSecret)
279 }
280
281 func authenticateControllerPublishVolume(req *csi.ControllerPublishVolumeRequest, creds *CSICreds) (bool, error) {
282 return credsCheck(req.GetSecrets(), creds.ControllerPublishVolumeSecret)
283 }
284
285 func authenticateControllerUnpublishVolume(req *csi.ControllerUnpublishVolumeRequest, creds *CSICreds) (bool, error) {
286 return credsCheck(req.GetSecrets(), creds.ControllerUnpublishVolumeSecret)
287 }
288
289 func authenticateNodeStageVolume(req *csi.NodeStageVolumeRequest, creds *CSICreds) (bool, error) {
290 return credsCheck(req.GetSecrets(), creds.NodeStageVolumeSecret)
291 }
292
293 func authenticateNodePublishVolume(req *csi.NodePublishVolumeRequest, creds *CSICreds) (bool, error) {
294 return credsCheck(req.GetSecrets(), creds.NodePublishVolumeSecret)
295 }
296
297 func authenticateCreateSnapshot(req *csi.CreateSnapshotRequest, creds *CSICreds) (bool, error) {
298 return credsCheck(req.GetSecrets(), creds.CreateSnapshotSecret)
299 }
300
301 func authenticateDeleteSnapshot(req *csi.DeleteSnapshotRequest, creds *CSICreds) (bool, error) {
302 return credsCheck(req.GetSecrets(), creds.DeleteSnapshotSecret)
303 }
304
305 func authenticateControllerValidateVolumeCapabilities(req *csi.ValidateVolumeCapabilitiesRequest, creds *CSICreds) (bool, error) {
306 return credsCheck(req.GetSecrets(), creds.ControllerValidateVolumeCapabilitiesSecret)
307 }
308
309 func credsCheck(secrets map[string]string, secretVal string) (bool, error) {
310 if len(secrets) == 0 {
311 return false, ErrNoCredentials
312 }
313
314 if secrets[secretField] != secretVal {
315 return false, ErrAuthFailed
316 }
317 return true, nil
318 }
319
View as plain text