...

Source file src/edge-infra.dev/pkg/edge/iam/util/util.go

Documentation: edge-infra.dev/pkg/edge/iam/util

     1  package util
     2  
     3  import (
     4  	"bytes"
     5  	"compress/zlib"
     6  	"context"
     7  	"crypto/rand"
     8  	"encoding/base64"
     9  	"encoding/json"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"net/http"
    15  	"sort"
    16  	"strings"
    17  	"time"
    18  
    19  	"github.com/gin-gonic/gin"
    20  
    21  	"edge-infra.dev/pkg/edge/iam/apperror"
    22  	"edge-infra.dev/pkg/edge/iam/config"
    23  	"edge-infra.dev/pkg/edge/iam/log"
    24  	"edge-infra.dev/pkg/edge/iam/types"
    25  	"edge-infra.dev/pkg/lib/fog"
    26  )
    27  
    28  type RoleClaimPart struct {
    29  	Indent int
    30  	Part   string
    31  }
    32  
    33  // Serialize: Accept Roles Array and
    34  func Serialize(roles []string) (string, error) {
    35  	// Sorting the roles in assending order
    36  	sort.Strings(roles)
    37  
    38  	serializedParts := []RoleClaimPart{{Indent: 0, Part: "0"}}
    39  
    40  	for _, role := range roles {
    41  		parts := strings.Split(role, "_")
    42  
    43  		for p, part := range parts {
    44  			lastPartMatchingIndent := RoleClaimPart{}
    45  
    46  			for c := len(serializedParts) - 1; c >= 0; c-- {
    47  				if serializedParts[c].Indent == p {
    48  					lastPartMatchingIndent = serializedParts[c]
    49  					break
    50  				} else if serializedParts[c].Indent < p {
    51  					break
    52  				}
    53  			}
    54  
    55  			if lastPartMatchingIndent.Part == part {
    56  				continue
    57  			}
    58  
    59  			serializedParts = append(serializedParts, RoleClaimPart{Indent: p, Part: part})
    60  		}
    61  	}
    62  
    63  	var serialized []string
    64  	for _, part := range serializedParts {
    65  		serialized = append(serialized, strings.Repeat("\t", part.Indent)+part.Part)
    66  	}
    67  
    68  	return encodeRolesClaim(strings.Join(serialized, "\n"))
    69  }
    70  
    71  func encodeRolesClaim(bslRls string) (string, error) {
    72  	var b bytes.Buffer
    73  	writer := zlib.NewWriter(&b)
    74  	_, err := writer.Write([]byte(bslRls))
    75  	if err != nil {
    76  		return "", err
    77  	}
    78  	writer.Close()
    79  	return base64.StdEncoding.EncodeToString(b.Bytes()), nil
    80  }
    81  
    82  // decodeRolesClaim: decode base64 roleClaim string and decompress rls claim to UTF-8 String
    83  func decodeRolesClaim(claim string) ([]byte, error) {
    84  	compressedRolesClaim, err := base64.StdEncoding.DecodeString(claim)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	var buf bytes.Buffer
    90  	buf.Write(compressedRolesClaim)
    91  	// Decompress decoded rls.
    92  	r, err := zlib.NewReader(&buf)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	defer r.Close()
    97  	return io.ReadAll(r)
    98  }
    99  
   100  func isAvailable(url string) bool {
   101  	client := &http.Client{
   102  		Timeout: config.CloudIDPTimeout(),
   103  	}
   104  	if !config.IsProduction() {
   105  		return !config.IsWANDisrupted()
   106  	}
   107  	_, err := client.Get(url)
   108  
   109  	// we timed-out, we're OK with skipping verification
   110  	if _, ok := err.(net.Error); ok {
   111  		return false
   112  	}
   113  	return true
   114  }
   115  
   116  // IsCloudLoginAvailable calls a service endpoint and returns if it succeeded or not.
   117  // If device login is enabled - it will be called,
   118  // Otherwise - based on which service is enabled between BSL or Okta - it will be called.
   119  func IsCloudLoginAvailable() bool {
   120  	if config.DeviceLoginEnabled() {
   121  		return IsDeviceLoginAvailable()
   122  	}
   123  
   124  	return IsCloudIDPAvailable()
   125  }
   126  
   127  // IsDeviceLoginAvailable checks if device login services are available.
   128  func IsDeviceLoginAvailable() bool {
   129  	return isAvailable(config.DeviceBaseURL())
   130  }
   131  
   132  // IsCloudIDPAvailable checks if Okta or BSL services are available
   133  func IsCloudIDPAvailable() bool {
   134  	if config.OktaEnabled() {
   135  		return isAvailable(config.OktaIssuer())
   136  	}
   137  
   138  	return isAvailable(config.BslSecurityURL())
   139  }
   140  
   141  // Deserialize:
   142  func Deserialize(rolesClaim string) ([]string, error) {
   143  	extractedRolesClaim, err := decodeRolesClaim(rolesClaim)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	lines := strings.Split(string(extractedRolesClaim), "\n")
   149  	var extractedRoles, path []string
   150  
   151  	for i, line := range lines {
   152  		lineParts := strings.Split(line, "\t")
   153  		level := len(lineParts) - 1
   154  		path = path[:level]
   155  		path = append(path, lineParts[level])
   156  
   157  		nextLine := ""
   158  		if i < len(lines)-1 {
   159  			nextLine = lines[i+1]
   160  		}
   161  
   162  		if nextLine != "" {
   163  			nextLineParts := strings.Split(nextLine, "\t")
   164  			nextLevel := len(nextLineParts) - 1
   165  			if nextLevel > level {
   166  				continue
   167  			}
   168  		}
   169  
   170  		// _ is the role delimiter defined by BSL
   171  		name := strings.Join(path, "_")
   172  		extractedRoles = append(extractedRoles, name)
   173  	}
   174  
   175  	return extractedRoles, nil
   176  }
   177  func RandomStringGenerator(length int) (string, error) {
   178  	charset := config.GetBarcodeCharset()
   179  	bytes := make([]byte, length)
   180  	_, err := rand.Read(bytes)
   181  	if err != nil {
   182  		return "", nil
   183  	}
   184  
   185  	// map characters in our charset due to the limitations of barcode.
   186  	for i, v := range bytes {
   187  		bytes[i] = charset[v%byte(len(charset))]
   188  	}
   189  	return string(bytes), nil
   190  }
   191  func IsElementExist(s []string, str string) bool {
   192  	for _, v := range s {
   193  		if strings.EqualFold(v, str) {
   194  			return true
   195  		}
   196  	}
   197  	return false
   198  }
   199  
   200  func ShortOperationID(c context.Context) string {
   201  	opID := fog.OperationID(c)
   202  	return opID[:8]
   203  }
   204  
   205  const (
   206  	SourceLocationKey = "caller"
   207  )
   208  
   209  func MakeHandlerFunc(f types.APIFunc) gin.HandlerFunc {
   210  	return func(c *gin.Context) {
   211  		if err := f(c); err != nil { //   nolint:nestif
   212  			log := log.Get(c.Request.Context())
   213  
   214  			if redirectErr, ok := err.(apperror.Redirecter); ok {
   215  				log.Error(err, "redirect error", "short_operation_id", ShortOperationID(c.Request.Context()))
   216  
   217  				code, location, message := redirectErr.Redirect()
   218  				keysToMask := []string{"client_id"}
   219  				maskedLocation, _ := MaskURLQuery(location, keysToMask...)
   220  				msg := fmt.Sprintf("[%v] - (%d) redirecting to `%v`", ShortOperationID(c.Request.Context()), code, maskedLocation)
   221  				log.Info(fmt.Sprintf("%v. %v", msg, message))
   222  
   223  				c.Redirect(code, location)
   224  				return
   225  			}
   226  
   227  			if abortErr, ok := err.(apperror.ErrorAborter); ok {
   228  				code, e := abortErr.AbortError()
   229  				msg := fmt.Sprintf("[%v] - (%d) aborting with error", ShortOperationID(c.Request.Context()), code)
   230  
   231  				location := abortErr.SourceLocation()
   232  				if location.File != "" {
   233  					log.Error(abortErr, msg, SourceLocationKey, location.ToMap())
   234  				} else {
   235  					log.Error(abortErr, msg)
   236  				}
   237  
   238  				c.AbortWithError(code, e) //nolint:errcheck
   239  				return
   240  			}
   241  
   242  			if jsonErr, ok := err.(apperror.JSONResponder); ok {
   243  				code, jsonObj := jsonErr.JSONResponse()
   244  				msg := fmt.Sprintf("[%v] - (%d) aborting with json", ShortOperationID(c.Request.Context()), code)
   245  
   246  				location := jsonErr.SourceLocation()
   247  				if location.File != "" {
   248  					log.Error(jsonErr, msg, "details", jsonErr.JSONDetails(), SourceLocationKey, location.ToMap())
   249  				} else {
   250  					log.Error(jsonErr, msg, "details", jsonErr.JSONDetails())
   251  				}
   252  
   253  				c.AbortWithStatusJSON(code, jsonObj)
   254  				return
   255  			}
   256  
   257  			if statusErr, ok := err.(apperror.StatusCoder); ok {
   258  				code := statusErr.StatusCode()
   259  				msg := fmt.Sprintf("[%v] - (%d) - aborting with status", ShortOperationID(c.Request.Context()), code)
   260  
   261  				location := statusErr.SourceLocation()
   262  				if location.File != "" {
   263  					log.Error(statusErr, msg, SourceLocationKey, location.ToMap())
   264  				} else {
   265  					log.Error(statusErr, msg)
   266  				}
   267  
   268  				c.AbortWithStatus(code)
   269  				return
   270  			}
   271  
   272  			// handle non AppError
   273  			log.Error(err, "unexpected error occurred")
   274  
   275  			c.AbortWithError(500, errors.New("unexpected error occurred")) //nolint:errcheck
   276  			return
   277  		}
   278  	}
   279  }
   280  
   281  // WriteJSON sets the status and writes json into the stream
   282  func WriteJSON(w http.ResponseWriter, status int, v any) error {
   283  	w.Header().Set("Content-Type", "application/json")
   284  	w.WriteHeader(status)
   285  
   286  	return json.NewEncoder(w).Encode(v)
   287  }
   288  
   289  func CalculateAge(dob string, nowFunc func() time.Time, timezone string) (int, error) {
   290  	// Parse the date of birth, the date below is the constant layout in Go: https://pkg.go.dev/time#pkg-constants
   291  	birthDate, err := time.Parse("2006-01-02", dob)
   292  	if err != nil {
   293  		return 0, err
   294  	}
   295  	// use the time.Now() passed in, to get the current date. This is for testability :)
   296  	location, locationErr := time.LoadLocation(timezone)
   297  	var currentDate time.Time
   298  	if locationErr != nil {
   299  		currentDate = nowFunc()
   300  	} else {
   301  		currentDate = nowFunc().In(location)
   302  	}
   303  	// Calculate the age based on year
   304  	age := currentDate.Year() - birthDate.Year()
   305  	// Adjust the age if the birthday hasn't occurred yet this year
   306  	if currentDate.YearDay() < birthDate.YearDay() {
   307  		age--
   308  	}
   309  	if age < 0 {
   310  		return 0, fmt.Errorf("date is in future")
   311  	}
   312  	return age, nil
   313  }
   314  

View as plain text