package handlers

import (
	"errors"
	"net/http"
	"time"

	"edge-infra.dev/pkg/edge/auth-proxy/types"
	"github.com/gin-contrib/sessions"
)

// Names of the request headers that SessionUserDetails strips from every
// incoming request and then repopulates from the session.
const (
	headerKeyUsername      = "X-Auth-Username"
	headerKeyEmail         = "X-Auth-Email"
	headerKeyRoles         = "X-Auth-Roles"
	headerKeyBannerEdgeIDs = "X-Auth-Banners"
)

// validSession reports whether session holds an expiration time that has not
// yet passed and a non-nil session ID.
//
// A missing expiration value, or one that is not a time.Time, makes the
// session invalid. The checked assertion is deliberate: an unexpected value
// type must never panic the handler.
func validSession(session sessions.Session) bool {
	expirationTime, ok := session.Get(types.SessionExpirationField).(time.Time)
	if !ok {
		// Absent (nil) or of an unexpected type: treat as not authenticated.
		return false
	}
	if expirationTime.Before(time.Now().UTC()) {
		return false
	}
	return session.Get(types.SessionIDField) != nil
}

// SessionUserDetails is a handler that adds user details to the incoming
// request header.
//
// Any auth headers already present on the request are always removed first.
// The headers are then set from the session only when the session is valid
// and every user detail (username, email, roles, assigned banners) could be
// read from it; otherwise the request is returned without any auth headers.
//
// If any error occurs this handler MUST NOT return the error. If the error is
// returned the default incoming request is used in the proxy to all upstream
// servers. This incoming request might contain auth headers that an attacker
// can use to gain unauthorized access to the emergencyaccess solution.
// If an error occurs the handler MUST log the error, and return a nil error. It
// also MUST make sure the returned http.Request does not contain any auth headers.
func (h ProxyHandler) SessionUserDetails(req *http.Request, body []byte) (*http.Request, []byte, error) {
	// Make sure any connecting user cannot inject any auth headers into the
	// request.
	req.Header.Del(headerKeyUsername)
	req.Header.Del(headerKeyEmail)
	req.Header.Del(headerKeyRoles)
	req.Header.Del(headerKeyBannerEdgeIDs)

	if !validSession(h.session) {
		// MUST NOT return non-nil error
		h.log.Error(errors.New("invalid session"), "invalid session")
		return req, body, nil
	}

	// Collect every extraction failure so a single log entry lists all of
	// the fields that were missing or had an unexpected type.
	var err error
	username, ok := h.session.Get(types.SessionUsernameField).(string)
	if !ok {
		err = errors.Join(err, errors.New("unable to get username from session"))
	}
	email, ok := h.session.Get(types.SessionEmailField).(string)
	if !ok {
		err = errors.Join(err, errors.New("unable to get email from session"))
	}
	roles, ok := h.session.Get(types.SessionRolesField).([]string)
	if !ok {
		err = errors.Join(err, errors.New("unable to get roles from session"))
	}
	bannerEdgeIDs, ok := h.session.Get(types.SessionBannerEdgeIDs).([]string)
	if !ok {
		err = errors.Join(err, errors.New("unable to get assigned banners from session"))
	}
	if err != nil {
		// MUST NOT return non-nil error
		h.log.Error(err, "failed to extract user details")
		return req, body, nil
	}

	// All details are present: only now are the auth headers populated, so a
	// partially readable session never yields a partially authenticated
	// request.
	req.Header.Set(headerKeyUsername, username)
	req.Header.Set(headerKeyEmail, email)
	for _, role := range roles {
		req.Header.Add(headerKeyRoles, role)
	}
	for _, bannerEdgeID := range bannerEdgeIDs {
		req.Header.Add(headerKeyBannerEdgeIDs, bannerEdgeID)
	}
	return req, body, nil
}