package util import ( "bytes" "compress/zlib" "context" "crypto/rand" "encoding/base64" "encoding/json" "errors" "fmt" "io" "net" "net/http" "sort" "strings" "time" "github.com/gin-gonic/gin" "edge-infra.dev/pkg/edge/iam/apperror" "edge-infra.dev/pkg/edge/iam/config" "edge-infra.dev/pkg/edge/iam/log" "edge-infra.dev/pkg/edge/iam/types" "edge-infra.dev/pkg/lib/fog" ) type RoleClaimPart struct { Indent int Part string } // Serialize: Accept Roles Array and func Serialize(roles []string) (string, error) { // Sorting the roles in assending order sort.Strings(roles) serializedParts := []RoleClaimPart{{Indent: 0, Part: "0"}} for _, role := range roles { parts := strings.Split(role, "_") for p, part := range parts { lastPartMatchingIndent := RoleClaimPart{} for c := len(serializedParts) - 1; c >= 0; c-- { if serializedParts[c].Indent == p { lastPartMatchingIndent = serializedParts[c] break } else if serializedParts[c].Indent < p { break } } if lastPartMatchingIndent.Part == part { continue } serializedParts = append(serializedParts, RoleClaimPart{Indent: p, Part: part}) } } var serialized []string for _, part := range serializedParts { serialized = append(serialized, strings.Repeat("\t", part.Indent)+part.Part) } return encodeRolesClaim(strings.Join(serialized, "\n")) } func encodeRolesClaim(bslRls string) (string, error) { var b bytes.Buffer writer := zlib.NewWriter(&b) _, err := writer.Write([]byte(bslRls)) if err != nil { return "", err } writer.Close() return base64.StdEncoding.EncodeToString(b.Bytes()), nil } // decodeRolesClaim: decode base64 roleClaim string and decompress rls claim to UTF-8 String func decodeRolesClaim(claim string) ([]byte, error) { compressedRolesClaim, err := base64.StdEncoding.DecodeString(claim) if err != nil { return nil, err } var buf bytes.Buffer buf.Write(compressedRolesClaim) // Decompress decoded rls. r, err := zlib.NewReader(&buf) if err != nil { return nil, err } defer r.Close() return io.ReadAll(r) } func isAvailable(url string) bool { client := &http.Client{ Timeout: config.CloudIDPTimeout(), } if !config.IsProduction() { return !config.IsWANDisrupted() } _, err := client.Get(url) // we timed-out, we're OK with skipping verification if _, ok := err.(net.Error); ok { return false } return true } // IsCloudLoginAvailable calls a service endpoint and returns if it succeeded or not. // If device login is enabled - it will be called, // Otherwise - based on which service is enabled between BSL or Okta - it will be called. func IsCloudLoginAvailable() bool { if config.DeviceLoginEnabled() { return IsDeviceLoginAvailable() } return IsCloudIDPAvailable() } // IsDeviceLoginAvailable checks if device login services are available. func IsDeviceLoginAvailable() bool { return isAvailable(config.DeviceBaseURL()) } // IsCloudIDPAvailable checks if Okta or BSL services are available func IsCloudIDPAvailable() bool { if config.OktaEnabled() { return isAvailable(config.OktaIssuer()) } return isAvailable(config.BslSecurityURL()) } // Deserialize: func Deserialize(rolesClaim string) ([]string, error) { extractedRolesClaim, err := decodeRolesClaim(rolesClaim) if err != nil { return nil, err } lines := strings.Split(string(extractedRolesClaim), "\n") var extractedRoles, path []string for i, line := range lines { lineParts := strings.Split(line, "\t") level := len(lineParts) - 1 path = path[:level] path = append(path, lineParts[level]) nextLine := "" if i < len(lines)-1 { nextLine = lines[i+1] } if nextLine != "" { nextLineParts := strings.Split(nextLine, "\t") nextLevel := len(nextLineParts) - 1 if nextLevel > level { continue } } // _ is the role delimiter defined by BSL name := strings.Join(path, "_") extractedRoles = append(extractedRoles, name) } return extractedRoles, nil } func RandomStringGenerator(length int) (string, error) { charset := config.GetBarcodeCharset() bytes := make([]byte, length) _, err := rand.Read(bytes) if err != nil { return "", nil } // map characters in our charset due to the limitations of barcode. for i, v := range bytes { bytes[i] = charset[v%byte(len(charset))] } return string(bytes), nil } func IsElementExist(s []string, str string) bool { for _, v := range s { if strings.EqualFold(v, str) { return true } } return false } func ShortOperationID(c context.Context) string { opID := fog.OperationID(c) return opID[:8] } const ( SourceLocationKey = "caller" ) func MakeHandlerFunc(f types.APIFunc) gin.HandlerFunc { return func(c *gin.Context) { if err := f(c); err != nil { // nolint:nestif log := log.Get(c.Request.Context()) if redirectErr, ok := err.(apperror.Redirecter); ok { log.Error(err, "redirect error", "short_operation_id", ShortOperationID(c.Request.Context())) code, location, message := redirectErr.Redirect() keysToMask := []string{"client_id"} maskedLocation, _ := MaskURLQuery(location, keysToMask...) msg := fmt.Sprintf("[%v] - (%d) redirecting to `%v`", ShortOperationID(c.Request.Context()), code, maskedLocation) log.Info(fmt.Sprintf("%v. %v", msg, message)) c.Redirect(code, location) return } if abortErr, ok := err.(apperror.ErrorAborter); ok { code, e := abortErr.AbortError() msg := fmt.Sprintf("[%v] - (%d) aborting with error", ShortOperationID(c.Request.Context()), code) location := abortErr.SourceLocation() if location.File != "" { log.Error(abortErr, msg, SourceLocationKey, location.ToMap()) } else { log.Error(abortErr, msg) } c.AbortWithError(code, e) //nolint:errcheck return } if jsonErr, ok := err.(apperror.JSONResponder); ok { code, jsonObj := jsonErr.JSONResponse() msg := fmt.Sprintf("[%v] - (%d) aborting with json", ShortOperationID(c.Request.Context()), code) location := jsonErr.SourceLocation() if location.File != "" { log.Error(jsonErr, msg, "details", jsonErr.JSONDetails(), SourceLocationKey, location.ToMap()) } else { log.Error(jsonErr, msg, "details", jsonErr.JSONDetails()) } c.AbortWithStatusJSON(code, jsonObj) return } if statusErr, ok := err.(apperror.StatusCoder); ok { code := statusErr.StatusCode() msg := fmt.Sprintf("[%v] - (%d) - aborting with status", ShortOperationID(c.Request.Context()), code) location := statusErr.SourceLocation() if location.File != "" { log.Error(statusErr, msg, SourceLocationKey, location.ToMap()) } else { log.Error(statusErr, msg) } c.AbortWithStatus(code) return } // handle non AppError log.Error(err, "unexpected error occurred") c.AbortWithError(500, errors.New("unexpected error occurred")) //nolint:errcheck return } } } // WriteJSON sets the status and writes json into the stream func WriteJSON(w http.ResponseWriter, status int, v any) error { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) return json.NewEncoder(w).Encode(v) } func CalculateAge(dob string, nowFunc func() time.Time, timezone string) (int, error) { // Parse the date of birth, the date below is the constant layout in Go: https://pkg.go.dev/time#pkg-constants birthDate, err := time.Parse("2006-01-02", dob) if err != nil { return 0, err } // use the time.Now() passed in, to get the current date. This is for testability :) location, locationErr := time.LoadLocation(timezone) var currentDate time.Time if locationErr != nil { currentDate = nowFunc() } else { currentDate = nowFunc().In(location) } // Calculate the age based on year age := currentDate.Year() - birthDate.Year() // Adjust the age if the birthday hasn't occurred yet this year if currentDate.YearDay() < birthDate.YearDay() { age-- } if age < 0 { return 0, fmt.Errorf("date is in future") } return age, nil }