1 package authserver
2
3 import (
4 "context"
5 "fmt"
6 "net/http"
7 "regexp"
8 "slices"
9 "strconv"
10
11 "github.com/gin-contrib/sessions"
12 "github.com/gin-gonic/gin"
13
14 authproxytypes "edge-infra.dev/pkg/edge/auth-proxy/types"
15 "edge-infra.dev/pkg/lib/uuid"
16 vncconst "edge-infra.dev/pkg/sds/vnc/constants"
17 )
18
19
20
// Setting keys consulted by getVNCClusterConfig for read-write VNC
// connections: the auth-required setting and its override flag.
const (
	vncReadWriteAuthRequired         = "vnc_read_write_auth_required"
	vncReadWriteAuthRequiredOverride = "vnc_read_write_auth_required_override"
)

// Setting keys consulted by getVNCClusterConfig for read-only VNC
// connections: the auth-required setting and its override flag.
const (
	vncReadAuthRequired         = "vnc_read_auth_required"
	vncReadAuthRequiredOverride = "vnc_read_auth_required_override"
)
30
31
// vncConnectMode is the kind of VNC connection a request asks for.
type vncConnectMode string

// Known connect modes. Read mode is selected by paths matching novncReadPath
// and is additionally permitted for EDGE_BANNER_VIEWER; every other request is
// treated as write mode.
const (
	vncReadConnectMode  vncConnectMode = "read"
	vncWriteConnectMode vncConnectMode = "write"
)
38
// novncReadPath matches request paths beginning with
// /remoteaccess/<one path segment>/novnc/read/. getVNCConnectMode treats such
// requests as read-mode VNC connections.
var novncReadPath = regexp.MustCompile(`^/remoteaccess/[^/]*/novnc/read/`)
45
46
47
48
49
50
51 func (as *AuthServer) validateVNCRoles(ctx *gin.Context, session sessions.Session) error {
52 roles := session.Get(authproxytypes.SessionRolesField)
53 if roles == nil {
54 return &httpError{
55 statusCode: http.StatusUnauthorized,
56 err: fmt.Errorf("no user roles in session"),
57 }
58 }
59 userRoles := roles.([]string)
60
61 requestedConnectMode := getVNCConnectMode(ctx.Request)
62
63 allowedRoles := []string{
64 "EDGE_ORG_ADMIN",
65 "EDGE_BANNER_ADMIN",
66 "EDGE_BANNER_OPERATOR",
67 }
68 if requestedConnectMode == vncReadConnectMode {
69 allowedRoles = append(allowedRoles, "EDGE_BANNER_VIEWER")
70 }
71
72 allowed := slices.ContainsFunc(userRoles, func(userRole string) bool {
73 return slices.Contains(allowedRoles, userRole)
74 })
75
76 if !allowed {
77 return &httpError{
78 statusCode: http.StatusUnauthorized,
79 err: fmt.Errorf("user roles do not satisfy requested vnc connect mode (%q)", requestedConnectMode),
80 }
81 }
82
83 return nil
84 }
85
86
87
88
89 func (as *AuthServer) injectVNCAuthHeaders(ctx *gin.Context, _ sessions.Session) error {
90
91 val := ctx.GetHeader(vncconst.HeaderKeyRequestID)
92 if val == "" {
93 val = uuid.New().UUID
94 }
95 ctx.Header(vncconst.HeaderKeyRequestID, val)
96
97
98
99
100 if ctx.GetHeader(vncconst.HeaderKeyAuthMode) != "" {
101 ctx.Header(vncconst.HeaderKeyAuthMode, ctx.GetHeader(vncconst.HeaderKeyAuthMode))
102 return nil
103 }
104 return as.setVNCAuthModeHeader(ctx)
105 }
106
107
108
109
110
// vncClusterConfig holds the raw string settings that getVNCClusterConfig
// loads for one banner/cluster pair. A field is left empty when the query
// returned no row of the matching kind; getVNCAuthConfig decides how empty
// values are interpreted.
type vncClusterConfig struct {
	// BannerVNCAuthRequired comes from the query row of kind "banner_vnc_auth_required".
	BannerVNCAuthRequired string
	// ClusterVNCAuthRequired comes from the query row of kind "cluster_vnc_auth_required".
	ClusterVNCAuthRequired string
	// VNCAuthRequiredOverride comes from the query row of kind "banner_vnc_override";
	// when true, ClusterVNCAuthRequired may take precedence over the banner value.
	VNCAuthRequiredOverride string
}
116
117
118
119
120 func (as *AuthServer) setVNCAuthModeHeader(ctx *gin.Context) error {
121 clusterEdgeID := getClusterEdgeIDFromPath(ctx.Request.URL.Path)
122 if clusterEdgeID == "" {
123 return fmt.Errorf("could not find cluster edge ID from path %s", ctx.Request.URL.Path)
124 }
125
126 bannerEdgeID := ctx.GetHeader(bannerHeaderName)
127 if bannerEdgeID == "" {
128 return fmt.Errorf("could not find banner edge ID from header")
129 }
130
131 connectMode := getVNCConnectMode(ctx.Request)
132
133 clusterConfiguration, err := as.getVNCClusterConfig(ctx, bannerEdgeID, clusterEdgeID, connectMode)
134 if err != nil {
135 return err
136 }
137
138 vncAuthMode, err := getVNCAuthConfig(clusterConfiguration)
139 if err != nil {
140 return err
141 }
142
143 setBoolHeader(ctx, vncconst.HeaderKeyAuthMode, vncAuthMode)
144 return nil
145 }
146
147
148
149
150 func getVNCConnectMode(req *http.Request) vncConnectMode {
151 if novncReadPath.MatchString(req.URL.Path) {
152 return vncReadConnectMode
153 }
154 return vncWriteConnectMode
155 }
156
157
158
159 func (as *AuthServer) getVNCClusterConfig(ctx context.Context, bannerEdgeID string, clusterEdgeID string, connectMode vncConnectMode) (vncClusterConfig, error) {
160 res := vncClusterConfig{}
161
162 authRequiredKey := vncReadWriteAuthRequired
163 authOverrideKey := vncReadWriteAuthRequiredOverride
164 if connectMode == vncReadConnectMode {
165 authRequiredKey = vncReadAuthRequired
166 authOverrideKey = vncReadAuthRequiredOverride
167 }
168
169 rows, err := as.db.QueryContext(ctx, vncQuery, authRequiredKey, authOverrideKey, bannerEdgeID, clusterEdgeID)
170 if err != nil {
171 return res, fmt.Errorf("error querying db for vnc auth mode: %w", err)
172 }
173 defer rows.Close()
174
175 for rows.Next() {
176 var kind, value string
177 if err := rows.Scan(&kind, &value); err != nil {
178 return res, fmt.Errorf("error scanning vnc auth mode row: %w", err)
179 }
180
181
182
183
184 switch kind {
185 case "banner_vnc_auth_required":
186 res.BannerVNCAuthRequired = value
187 case "banner_vnc_override":
188 res.VNCAuthRequiredOverride = value
189 case "cluster_vnc_auth_required":
190 res.ClusterVNCAuthRequired = value
191 default:
192 return res, fmt.Errorf("unrecognized result from query: kind %q, value %q", kind, value)
193 }
194 }
195
196 if err := rows.Err(); err != nil {
197 return res, fmt.Errorf("error during vnc auth mode dq query iteration: %w", err)
198 }
199
200 return res, nil
201 }
202
203
204
205
206
207 func getVNCAuthConfig(clusterConfiguration vncClusterConfig) (bool, error) {
208
209
210
211
212 if clusterConfiguration.VNCAuthRequiredOverride == "" {
213
214 clusterConfiguration.VNCAuthRequiredOverride = "false"
215 }
216 if clusterConfiguration.BannerVNCAuthRequired == "" {
217
218 clusterConfiguration.BannerVNCAuthRequired = "false"
219 }
220
221 allowClusterOverride, err := strconv.ParseBool(clusterConfiguration.VNCAuthRequiredOverride)
222 if err != nil {
223 return false, fmt.Errorf("invalid vnc auth required override value: %q", clusterConfiguration.VNCAuthRequiredOverride)
224 }
225
226 var authRequired = clusterConfiguration.BannerVNCAuthRequired
227 if allowClusterOverride && clusterConfiguration.ClusterVNCAuthRequired != "" {
228
229
230
231 authRequired = clusterConfiguration.ClusterVNCAuthRequired
232 }
233
234
235 return strconv.ParseBool(authRequired)
236 }
237
View as plain text