package handlers

import (
	"errors"
	"net/http"
	"time"

	"edge-infra.dev/pkg/edge/auth-proxy/types"

	"github.com/gin-contrib/sessions"
)

// Header names used to pass authenticated user details to upstream servers.
// All names are in canonical MIME header form.
const (
	// headerKeyUsername carries the username read from the session.
	headerKeyUsername = "X-Auth-Username"

	// headerKeyEmail carries the email address read from the session.
	headerKeyEmail = "X-Auth-Email"

	// headerKeyRoles is added once per role held by the session user.
	headerKeyRoles = "X-Auth-Roles"

	// headerKeyBannerEdgeIDs is added once per banner edge ID assigned to the
	// session user.
	headerKeyBannerEdgeIDs = "X-Auth-Banners"
)

// validSession reports whether session holds an expiration timestamp that
// has not yet passed and a non-nil session ID.
//
// A nil session, a missing expiration, or an expiration value that is not a
// time.Time is treated as invalid. The function never panics on malformed
// session contents, because callers rely on it to gate auth headers.
func validSession(session sessions.Session) bool {
	if session == nil {
		return false
	}
	// Comma-ok assertion: an unexpected stored type must not crash the proxy.
	expiresAt, ok := session.Get(types.SessionExpirationField).(time.Time)
	if !ok {
		return false
	}
	// time.Time.Before compares instants, so the location of either value
	// does not affect the result.
	if expiresAt.Before(time.Now().UTC()) {
		return false
	}
	return session.Get(types.SessionIDField) != nil
}

// stripInjectedAuthHeaders removes from hdr every header that this package
// uses to convey authenticated user details, so a client cannot supply them.
//
// Header names are compared after replacing '_' with '-' and canonicalizing.
// Spoofing variants such as "X-Auth_Roles" or "x_auth_username" are therefore
// removed too. A plain Header.Del only matches the exact hyphenated name, but
// some upstream servers (e.g. CGI/WSGI-style gateways that map headers to
// HTTP_* variables) treat '_' and '-' in header names as equivalent.
func stripInjectedAuthHeaders(hdr http.Header) {
	for key := range hdr {
		normalized := []byte(key)
		for i, c := range normalized {
			if c == '_' {
				normalized[i] = '-'
			}
		}
		switch http.CanonicalHeaderKey(string(normalized)) {
		case headerKeyUsername, headerKeyEmail, headerKeyRoles, headerKeyBannerEdgeIDs:
			// Deleting the current key while ranging over a map is permitted.
			delete(hdr, key)
		}
	}
}

// SessionUserDetails handler that adds user details to the incoming request header.
//
// If any error occurs this handler MUST NOT return the error. If the error is
// returned the default incoming request is used in the proxy to all upstream
// servers. This incoming request might contain auth headers that an attacker
// can use to gain unauthorized access to the emergencyaccess solution.
// If an error occurs the handler MUST log the error, and return a nil error. It
// also MUST make sure the returned http.Request does not contain any auth headers.
//
// Client-supplied auth headers, including '_'-spelled variants, are always
// removed first. The username, email, roles and banner edge IDs are then
// added only when the session is valid and every field has the expected type.
// Otherwise the request is returned with no auth headers. body is passed
// through unchanged.
func (h ProxyHandler) SessionUserDetails(req *http.Request, body []byte) (*http.Request, []byte, error) {
	// Make sure any connecting user cannot inject any auth headers into the
	// request.
	stripInjectedAuthHeaders(req.Header)

	if !validSession(h.session) {
		// MUST NOT return non-nil error
		h.log.Error(errors.New("invalid session"), "invalid session")
		return req, body, nil
	}

	// Collect every extraction failure so the log shows all missing fields.
	var err error
	username, ok := h.session.Get(types.SessionUsernameField).(string)
	if !ok {
		err = errors.Join(err, errors.New("unable to get username from session"))
	}
	email, ok := h.session.Get(types.SessionEmailField).(string)
	if !ok {
		err = errors.Join(err, errors.New("unable to get email from session"))
	}
	roles, ok := h.session.Get(types.SessionRolesField).([]string)
	if !ok {
		err = errors.Join(err, errors.New("unable to get roles from session"))
	}
	bannerEdgeIDs, ok := h.session.Get(types.SessionBannerEdgeIDs).([]string)
	if !ok {
		err = errors.Join(err, errors.New("unable to get assigned banners from session"))
	}
	if err != nil {
		// MUST NOT return non-nil error
		h.log.Error(err, "failed to extract user details")
		return req, body, nil
	}

	req.Header.Set(headerKeyUsername, username)
	req.Header.Set(headerKeyEmail, email)
	// Multi-valued details are sent as repeated headers, one value each.
	for _, role := range roles {
		req.Header.Add(headerKeyRoles, role)
	}
	for _, bannerEdgeID := range bannerEdgeIDs {
		req.Header.Add(headerKeyBannerEdgeIDs, bannerEdgeID)
	}
	return req, body, nil
}