package database import ( "bytes" "context" "crypto/rand" "encoding/base32" "encoding/gob" "errors" "io" "net/http" "strings" "time" "edge-infra.dev/pkg/edge/iam/config" "edge-infra.dev/pkg/edge/iam/crypto" "github.com/go-redis/redis" "github.com/gorilla/sessions" ) // SessionStore stores gorilla sessions in Redis type SessionStore struct { // client to connect to redis client redis.UniversalClient // default options to use when a new session is created options sessions.Options // key prefix with which the session will be stored keyPrefix string // key generator keyGen KeyGenFunc // session serializer serializer SessionSerializer } // KeyGenFunc defines a function used by store to generate a key type KeyGenFunc func() (string, error) // NewRedisSessionStore returns a new SessionStore with default configuration func NewRedisSessionStore(_ context.Context, client redis.UniversalClient) (*SessionStore, error) { rs := &SessionStore{ options: sessions.Options{ Path: "/", MaxAge: config.SessionCookieMaxAge(), HttpOnly: true, //The request must be a top-level navigation. You can think of this as equivalent to when the URL shown in the URL bar changes, e.g. a user clicking on a link to go to another site. // The request method must be safe (e.g. GET or HEAD, but not POST). SameSite: http.SameSiteLaxMode, }, client: client, keyPrefix: "session:", keyGen: generateRandomKey, serializer: GobSerializer{}, } return rs, rs.client.Ping().Err() } // Get returns a session for the given name after adding it to the registry. func (s *SessionStore) Get(r *http.Request, name string) (*sessions.Session, error) { return s.New(r, name) } // New returns a session for the given name without adding it to the registry. func (s *SessionStore) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(s, name) opts := s.options session.Options = &opts session.IsNew = true c, err := r.Cookie(name) if err != nil { return session, nil } session.ID = c.Value err = s.load(session) if err == nil { session.IsNew = false } else if err == redis.Nil { err = nil // no data stored } return session, err } // Save adds a single session to the response. // // If the Options.MaxAge of the session is <= 0 then the session file will be // deleted from the store. With this process it enforces the properly // session cookie handling so no need to trust in the cookie management in the // web browser. func (s *SessionStore) Save(_ *http.Request, w http.ResponseWriter, session *sessions.Session) error { // Delete if max-age is <= 0 if session.Options.MaxAge <= 0 { if err := s.delete(session); err != nil { return err } http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) return nil } if session.ID == "" { id, err := s.keyGen() if err != nil { return errors.New("SessionStore: failed to generate session id") } session.ID = id } if err := s.save(session); err != nil { return err } http.SetCookie(w, sessions.NewCookie(session.Name(), session.ID, session.Options)) return nil } // Options set options to use when a new session is created func (s *SessionStore) Options(opts sessions.Options) { s.options = opts } // KeyPrefix sets the key prefix to store session in Redis func (s *SessionStore) KeyPrefix(keyPrefix string) { s.keyPrefix = keyPrefix } // KeyGen sets the key generator function func (s *SessionStore) KeyGen(f KeyGenFunc) { s.keyGen = f } // Serializer sets the session serializer to store session func (s *SessionStore) Serializer(ss SessionSerializer) { s.serializer = ss } // Close closes the Redis store func (s *SessionStore) Close() error { return s.client.Close() } // save writes session in Redis func (s *SessionStore) save(session *sessions.Session) error { b, err := s.serializer.Serialize(session) if err != nil { return err } if config.EncryptionEnabled() { encryptedValue, err := crypto.EncryptRedis(b, config.EncryptionKey()) if err != nil { return err } return s.client.Set(s.keyPrefix+session.ID, encryptedValue, time.Duration(session.Options.MaxAge)*time.Second).Err() } return s.client.Set(s.keyPrefix+session.ID, b, time.Duration(session.Options.MaxAge)*time.Second).Err() } // load reads session from Redis func (s *SessionStore) load(session *sessions.Session) error { cmd := s.client.Get(s.keyPrefix + session.ID) if cmd.Err() != nil { return cmd.Err() } b, err := cmd.Bytes() if err != nil { return err } if config.EncryptionEnabled() { // nolint:nestif // if data is not encrypted, encrypt + save, then return unencrypted for use if !isRedisDataEncrypted(b) { ttl := s.client.TTL(s.keyPrefix + session.ID) encryptedVal, err := crypto.EncryptRedis(b, config.EncryptionKey()) if err != nil { return err } // update with encrypted value err = s.client.Set(s.keyPrefix+session.ID, encryptedVal, ttl.Val()).Err() if err != nil { return err } // return to use return s.serializer.Deserialize(b, session) } decryptedValue, err := crypto.DecryptRedis(string(b), config.EncryptionKey()) if err != nil { return err } return s.serializer.Deserialize(decryptedValue, session) } return s.serializer.Deserialize(b, session) } // delete deletes session in Redis func (s *SessionStore) delete(session *sessions.Session) error { return s.client.Del(s.keyPrefix + session.ID).Err() } // SessionSerializer provides an interface for serialize/deserialize a session type SessionSerializer interface { Serialize(s *sessions.Session) ([]byte, error) Deserialize(b []byte, s *sessions.Session) error } // Gob serializer type GobSerializer struct{} func (gs GobSerializer) Serialize(s *sessions.Session) ([]byte, error) { buf := new(bytes.Buffer) enc := gob.NewEncoder(buf) err := enc.Encode(s.Values) if err == nil { return buf.Bytes(), nil } return nil, err } func (gs GobSerializer) Deserialize(d []byte, s *sessions.Session) error { dec := gob.NewDecoder(bytes.NewBuffer(d)) return dec.Decode(&s.Values) } // generateRandomKey returns a new random key func generateRandomKey() (string, error) { k := make([]byte, 64) if _, err := io.ReadFull(rand.Reader, k); err != nil { return "", err } return strings.TrimRight(base32.StdEncoding.EncodeToString(k), "="), nil } // does the data contain the EncryptedData prefix? func isRedisDataEncrypted(data []byte) bool { searchStr := []byte("EncryptedData:") return bytes.Contains(data, searchStr) }