1 package util
2
3 import (
4 "bytes"
5 "compress/zlib"
6 "context"
7 "crypto/rand"
8 "encoding/base64"
9 "encoding/json"
10 "errors"
11 "fmt"
12 "io"
13 "net"
14 "net/http"
15 "sort"
16 "strings"
17 "time"
18
19 "github.com/gin-gonic/gin"
20
21 "edge-infra.dev/pkg/edge/iam/apperror"
22 "edge-infra.dev/pkg/edge/iam/config"
23 "edge-infra.dev/pkg/edge/iam/log"
24 "edge-infra.dev/pkg/edge/iam/types"
25 "edge-infra.dev/pkg/lib/fog"
26 )
27
// RoleClaimPart is a single line of the roles-claim tree that Serialize
// builds: one "_"-separated role segment plus the depth it sits at.
type RoleClaimPart struct {
	// Indent is the depth of the segment in the tree. Serialize renders it
	// as that many leading tab characters.
	Indent int

	// Part is the text of the segment.
	Part string
}
32
33
34 func Serialize(roles []string) (string, error) {
35
36 sort.Strings(roles)
37
38 serializedParts := []RoleClaimPart{{Indent: 0, Part: "0"}}
39
40 for _, role := range roles {
41 parts := strings.Split(role, "_")
42
43 for p, part := range parts {
44 lastPartMatchingIndent := RoleClaimPart{}
45
46 for c := len(serializedParts) - 1; c >= 0; c-- {
47 if serializedParts[c].Indent == p {
48 lastPartMatchingIndent = serializedParts[c]
49 break
50 } else if serializedParts[c].Indent < p {
51 break
52 }
53 }
54
55 if lastPartMatchingIndent.Part == part {
56 continue
57 }
58
59 serializedParts = append(serializedParts, RoleClaimPart{Indent: p, Part: part})
60 }
61 }
62
63 var serialized []string
64 for _, part := range serializedParts {
65 serialized = append(serialized, strings.Repeat("\t", part.Indent)+part.Part)
66 }
67
68 return encodeRolesClaim(strings.Join(serialized, "\n"))
69 }
70
// encodeRolesClaim zlib-compresses bslRls and returns the compressed bytes
// as standard (padded) base64 text.
func encodeRolesClaim(bslRls string) (string, error) {
	var b bytes.Buffer
	writer := zlib.NewWriter(&b)
	if _, err := writer.Write([]byte(bslRls)); err != nil {
		// Best-effort cleanup; the Write error is the one worth reporting.
		_ = writer.Close()
		return "", err
	}
	// Close flushes any unwritten compressed data into b. If its error were
	// ignored, a truncated, undecodable claim could be returned as success.
	if err := writer.Close(); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(b.Bytes()), nil
}
81
82
83 func decodeRolesClaim(claim string) ([]byte, error) {
84 compressedRolesClaim, err := base64.StdEncoding.DecodeString(claim)
85 if err != nil {
86 return nil, err
87 }
88
89 var buf bytes.Buffer
90 buf.Write(compressedRolesClaim)
91
92 r, err := zlib.NewReader(&buf)
93 if err != nil {
94 return nil, err
95 }
96 defer r.Close()
97 return io.ReadAll(r)
98 }
99
100 func isAvailable(url string) bool {
101 client := &http.Client{
102 Timeout: config.CloudIDPTimeout(),
103 }
104 if !config.IsProduction() {
105 return !config.IsWANDisrupted()
106 }
107 _, err := client.Get(url)
108
109
110 if _, ok := err.(net.Error); ok {
111 return false
112 }
113 return true
114 }
115
116
117
118
119 func IsCloudLoginAvailable() bool {
120 if config.DeviceLoginEnabled() {
121 return IsDeviceLoginAvailable()
122 }
123
124 return IsCloudIDPAvailable()
125 }
126
127
128 func IsDeviceLoginAvailable() bool {
129 return isAvailable(config.DeviceBaseURL())
130 }
131
132
133 func IsCloudIDPAvailable() bool {
134 if config.OktaEnabled() {
135 return isAvailable(config.OktaIssuer())
136 }
137
138 return isAvailable(config.BslSecurityURL())
139 }
140
141
// Deserialize decodes a roles claim back into role names.
//
// The claim is unpacked by decodeRolesClaim and then read as a tab-indented
// tree. Each line holds one segment, and its number of tab characters is its
// depth. A name is emitted for every line that is not followed by a deeper
// line, formed by joining the segments from the root down to that line
// with "_".
//
// NOTE(review): only leaf paths are returned. A role that is a prefix of
// another role (e.g. "A_B" next to "A_B_C") cannot be recovered, and the
// leading "0" line that Serialize writes is normally returned as an element
// too. Confirm callers expect this.
//
// An error is returned if the claim cannot be decoded, or if a line is nested
// more than one level deeper than the line before it.
func Deserialize(rolesClaim string) ([]string, error) {
	extractedRolesClaim, err := decodeRolesClaim(rolesClaim)
	if err != nil {
		return nil, err
	}

	lines := strings.Split(string(extractedRolesClaim), "\n")
	var extractedRoles, path []string

	for i, line := range lines {
		lineParts := strings.Split(line, "\t")
		level := len(lineParts) - 1

		// path holds the previous line's segments, so a well-formed line is at
		// most one level deeper. A deeper line has no parent: reslicing path
		// to it would panic or resurrect stale segments.
		if level > len(path) {
			return nil, fmt.Errorf("malformed roles claim: line %d is nested %d levels deep, at most %d allowed", i+1, level, len(path))
		}
		path = append(path[:level], lineParts[level])

		// A line followed by a deeper one is an interior node, not a role.
		if i+1 < len(lines) && lines[i+1] != "" {
			if nextLevel := strings.Count(lines[i+1], "\t"); nextLevel > level {
				continue
			}
		}

		extractedRoles = append(extractedRoles, strings.Join(path, "_"))
	}

	return extractedRoles, nil
}
// RandomStringGenerator returns a string of length bytes drawn from
// config.GetBarcodeCharset() using crypto/rand.
//
// Every charset entry is equally likely. Random bytes that would bias a plain
// modulo toward the start of the charset are discarded and redrawn. Indexing
// is per byte, so this assumes a single-byte (e.g. ASCII) charset. TODO:
// confirm.
//
// It returns an error if length is negative, if the charset is empty or
// longer than 256 bytes (one random byte can only address 256 entries), or if
// the system random source fails.
func RandomStringGenerator(length int) (string, error) {
	if length < 0 {
		return "", fmt.Errorf("random string length must not be negative, got %d", length)
	}

	charset := config.GetBarcodeCharset()
	size := len(charset)
	if size == 0 || size > 256 {
		return "", fmt.Errorf("barcode charset must contain 1 to 256 characters, got %d", size)
	}

	// Largest multiple of size within a byte's range. Bytes at or above it are
	// rejected so each index in [0, size) is hit equally often.
	limit := 256 - 256%size

	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			// Propagate the failure: an empty string with a nil error would
			// look like success to the caller.
			return "", err
		}
		for _, v := range buf {
			if int(v) >= limit {
				continue
			}
			out = append(out, charset[int(v)%size])
			if len(out) == length {
				break
			}
		}
	}

	return string(out), nil
}
// IsElementExist reports whether any element of s equals str, ignoring case
// as defined by strings.EqualFold.
func IsElementExist(s []string, str string) bool {
	found := false
	for i := 0; i < len(s) && !found; i++ {
		found = strings.EqualFold(s[i], str)
	}
	return found
}
199
200 func ShortOperationID(c context.Context) string {
201 opID := fog.OperationID(c)
202 return opID[:8]
203 }
204
// SourceLocationKey is the structured-log key under which MakeHandlerFunc
// attaches an error's source location.
const SourceLocationKey = "caller"
208
209 func MakeHandlerFunc(f types.APIFunc) gin.HandlerFunc {
210 return func(c *gin.Context) {
211 if err := f(c); err != nil {
212 log := log.Get(c.Request.Context())
213
214 if redirectErr, ok := err.(apperror.Redirecter); ok {
215 log.Error(err, "redirect error", "short_operation_id", ShortOperationID(c.Request.Context()))
216
217 code, location, message := redirectErr.Redirect()
218 keysToMask := []string{"client_id"}
219 maskedLocation, _ := MaskURLQuery(location, keysToMask...)
220 msg := fmt.Sprintf("[%v] - (%d) redirecting to `%v`", ShortOperationID(c.Request.Context()), code, maskedLocation)
221 log.Info(fmt.Sprintf("%v. %v", msg, message))
222
223 c.Redirect(code, location)
224 return
225 }
226
227 if abortErr, ok := err.(apperror.ErrorAborter); ok {
228 code, e := abortErr.AbortError()
229 msg := fmt.Sprintf("[%v] - (%d) aborting with error", ShortOperationID(c.Request.Context()), code)
230
231 location := abortErr.SourceLocation()
232 if location.File != "" {
233 log.Error(abortErr, msg, SourceLocationKey, location.ToMap())
234 } else {
235 log.Error(abortErr, msg)
236 }
237
238 c.AbortWithError(code, e)
239 return
240 }
241
242 if jsonErr, ok := err.(apperror.JSONResponder); ok {
243 code, jsonObj := jsonErr.JSONResponse()
244 msg := fmt.Sprintf("[%v] - (%d) aborting with json", ShortOperationID(c.Request.Context()), code)
245
246 location := jsonErr.SourceLocation()
247 if location.File != "" {
248 log.Error(jsonErr, msg, "details", jsonErr.JSONDetails(), SourceLocationKey, location.ToMap())
249 } else {
250 log.Error(jsonErr, msg, "details", jsonErr.JSONDetails())
251 }
252
253 c.AbortWithStatusJSON(code, jsonObj)
254 return
255 }
256
257 if statusErr, ok := err.(apperror.StatusCoder); ok {
258 code := statusErr.StatusCode()
259 msg := fmt.Sprintf("[%v] - (%d) - aborting with status", ShortOperationID(c.Request.Context()), code)
260
261 location := statusErr.SourceLocation()
262 if location.File != "" {
263 log.Error(statusErr, msg, SourceLocationKey, location.ToMap())
264 } else {
265 log.Error(statusErr, msg)
266 }
267
268 c.AbortWithStatus(code)
269 return
270 }
271
272
273 log.Error(err, "unexpected error occurred")
274
275 c.AbortWithError(500, errors.New("unexpected error occurred"))
276 return
277 }
278 }
279 }
280
281
// WriteJSON sends v to w as JSON with the given status code and a
// "Content-Type: application/json" header. The body is terminated by a
// newline.
//
// v is marshalled before anything is written. If it cannot be encoded, the
// error is returned and w is left untouched, so the caller can still send its
// own error response. An error from writing the body is also returned.
func WriteJSON(w http.ResponseWriter, status int, v any) error {
	body, err := json.Marshal(v)
	if err != nil {
		return err
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	// json.Encoder.Encode, used previously, appended a newline; keep the
	// wire format identical.
	_, err = w.Write(append(body, '\n'))
	return err
}
288
289 func CalculateAge(dob string, nowFunc func() time.Time, timezone string) (int, error) {
290
291 birthDate, err := time.Parse("2006-01-02", dob)
292 if err != nil {
293 return 0, err
294 }
295
296 location, locationErr := time.LoadLocation(timezone)
297 var currentDate time.Time
298 if locationErr != nil {
299 currentDate = nowFunc()
300 } else {
301 currentDate = nowFunc().In(location)
302 }
303
304 age := currentDate.Year() - birthDate.Year()
305
306 if currentDate.YearDay() < birthDate.YearDay() {
307 age--
308 }
309 if age < 0 {
310 return 0, fmt.Errorf("date is in future")
311 }
312 return age, nil
313 }
314
View as plain text