...

Source file src/edge-infra.dev/pkg/edge/iam/storage/database/redis_session.go

Documentation: edge-infra.dev/pkg/edge/iam/storage/database

     1  package database
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/rand"
     7  	"encoding/base32"
     8  	"encoding/gob"
     9  	"errors"
    10  	"io"
    11  	"net/http"
    12  	"strings"
    13  	"time"
    14  
    15  	"edge-infra.dev/pkg/edge/iam/config"
    16  	"edge-infra.dev/pkg/edge/iam/crypto"
    17  
    18  	"github.com/go-redis/redis"
    19  	"github.com/gorilla/sessions"
    20  )
    21  
    22  // SessionStore stores gorilla sessions in Redis
    23  type SessionStore struct {
    24  	// client to connect to redis
    25  	client redis.UniversalClient
    26  	// default options to use when a new session is created
    27  	options sessions.Options
    28  	// key prefix with which the session will be stored
    29  	keyPrefix string
    30  	// key generator
    31  	keyGen KeyGenFunc
    32  	// session serializer
    33  	serializer SessionSerializer
    34  }
    35  
    36  // KeyGenFunc defines a function used by store to generate a key
    37  type KeyGenFunc func() (string, error)
    38  
    39  // NewRedisSessionStore returns a new SessionStore with default configuration
    40  func NewRedisSessionStore(_ context.Context, client redis.UniversalClient) (*SessionStore, error) {
    41  	rs := &SessionStore{
    42  		options: sessions.Options{
    43  			Path:     "/",
    44  			MaxAge:   config.SessionCookieMaxAge(),
    45  			HttpOnly: true,
    46  			//The request must be a top-level navigation. You can think of this as equivalent to when the URL shown in the URL bar changes, e.g. a user clicking on a link to go to another site.
    47  			// The request method must be safe (e.g. GET or HEAD, but not POST).
    48  			SameSite: http.SameSiteLaxMode,
    49  		},
    50  		client:     client,
    51  		keyPrefix:  "session:",
    52  		keyGen:     generateRandomKey,
    53  		serializer: GobSerializer{},
    54  	}
    55  
    56  	return rs, rs.client.Ping().Err()
    57  }
    58  
    59  // Get returns a session for the given name after adding it to the registry.
    60  func (s *SessionStore) Get(r *http.Request, name string) (*sessions.Session, error) {
    61  	return s.New(r, name)
    62  }
    63  
    64  // New returns a session for the given name without adding it to the registry.
    65  func (s *SessionStore) New(r *http.Request, name string) (*sessions.Session, error) {
    66  	session := sessions.NewSession(s, name)
    67  	opts := s.options
    68  	session.Options = &opts
    69  	session.IsNew = true
    70  
    71  	c, err := r.Cookie(name)
    72  	if err != nil {
    73  		return session, nil
    74  	}
    75  	session.ID = c.Value
    76  
    77  	err = s.load(session)
    78  	if err == nil {
    79  		session.IsNew = false
    80  	} else if err == redis.Nil {
    81  		err = nil // no data stored
    82  	}
    83  	return session, err
    84  }
    85  
    86  // Save adds a single session to the response.
    87  //
    88  // If the Options.MaxAge of the session is <= 0 then the session file will be
    89  // deleted from the store. With this process it enforces the properly
    90  // session cookie handling so no need to trust in the cookie management in the
    91  // web browser.
    92  func (s *SessionStore) Save(_ *http.Request, w http.ResponseWriter, session *sessions.Session) error {
    93  	// Delete if max-age is <= 0
    94  	if session.Options.MaxAge <= 0 {
    95  		if err := s.delete(session); err != nil {
    96  			return err
    97  		}
    98  		http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
    99  		return nil
   100  	}
   101  
   102  	if session.ID == "" {
   103  		id, err := s.keyGen()
   104  		if err != nil {
   105  			return errors.New("SessionStore: failed to generate session id")
   106  		}
   107  		session.ID = id
   108  	}
   109  	if err := s.save(session); err != nil {
   110  		return err
   111  	}
   112  
   113  	http.SetCookie(w, sessions.NewCookie(session.Name(), session.ID, session.Options))
   114  	return nil
   115  }
   116  
   117  // Options set options to use when a new session is created
   118  func (s *SessionStore) Options(opts sessions.Options) {
   119  	s.options = opts
   120  }
   121  
   122  // KeyPrefix sets the key prefix to store session in Redis
   123  func (s *SessionStore) KeyPrefix(keyPrefix string) {
   124  	s.keyPrefix = keyPrefix
   125  }
   126  
   127  // KeyGen sets the key generator function
   128  func (s *SessionStore) KeyGen(f KeyGenFunc) {
   129  	s.keyGen = f
   130  }
   131  
   132  // Serializer sets the session serializer to store session
   133  func (s *SessionStore) Serializer(ss SessionSerializer) {
   134  	s.serializer = ss
   135  }
   136  
   137  // Close closes the Redis store
   138  func (s *SessionStore) Close() error {
   139  	return s.client.Close()
   140  }
   141  
   142  // save writes session in Redis
   143  func (s *SessionStore) save(session *sessions.Session) error {
   144  	b, err := s.serializer.Serialize(session)
   145  	if err != nil {
   146  		return err
   147  	}
   148  
   149  	if config.EncryptionEnabled() {
   150  		encryptedValue, err := crypto.EncryptRedis(b, config.EncryptionKey())
   151  		if err != nil {
   152  			return err
   153  		}
   154  		return s.client.Set(s.keyPrefix+session.ID, encryptedValue, time.Duration(session.Options.MaxAge)*time.Second).Err()
   155  	}
   156  
   157  	return s.client.Set(s.keyPrefix+session.ID, b, time.Duration(session.Options.MaxAge)*time.Second).Err()
   158  }
   159  
   160  // load reads session from Redis
   161  func (s *SessionStore) load(session *sessions.Session) error {
   162  	cmd := s.client.Get(s.keyPrefix + session.ID)
   163  	if cmd.Err() != nil {
   164  		return cmd.Err()
   165  	}
   166  
   167  	b, err := cmd.Bytes()
   168  	if err != nil {
   169  		return err
   170  	}
   171  
   172  	if config.EncryptionEnabled() { // nolint:nestif
   173  		// if data is not encrypted, encrypt + save, then return unencrypted for use
   174  		if !isRedisDataEncrypted(b) {
   175  			ttl := s.client.TTL(s.keyPrefix + session.ID)
   176  			encryptedVal, err := crypto.EncryptRedis(b, config.EncryptionKey())
   177  			if err != nil {
   178  				return err
   179  			}
   180  			// update with encrypted value
   181  			err = s.client.Set(s.keyPrefix+session.ID, encryptedVal, ttl.Val()).Err()
   182  			if err != nil {
   183  				return err
   184  			}
   185  			// return to use
   186  			return s.serializer.Deserialize(b, session)
   187  		}
   188  
   189  		decryptedValue, err := crypto.DecryptRedis(string(b), config.EncryptionKey())
   190  		if err != nil {
   191  			return err
   192  		}
   193  		return s.serializer.Deserialize(decryptedValue, session)
   194  	}
   195  	return s.serializer.Deserialize(b, session)
   196  }
   197  
   198  // delete deletes session in Redis
   199  func (s *SessionStore) delete(session *sessions.Session) error {
   200  	return s.client.Del(s.keyPrefix + session.ID).Err()
   201  }
   202  
   203  // SessionSerializer provides an interface for serialize/deserialize a session
   204  type SessionSerializer interface {
   205  	Serialize(s *sessions.Session) ([]byte, error)
   206  	Deserialize(b []byte, s *sessions.Session) error
   207  }
   208  
   209  // Gob serializer
   210  type GobSerializer struct{}
   211  
   212  func (gs GobSerializer) Serialize(s *sessions.Session) ([]byte, error) {
   213  	buf := new(bytes.Buffer)
   214  	enc := gob.NewEncoder(buf)
   215  	err := enc.Encode(s.Values)
   216  	if err == nil {
   217  		return buf.Bytes(), nil
   218  	}
   219  	return nil, err
   220  }
   221  
   222  func (gs GobSerializer) Deserialize(d []byte, s *sessions.Session) error {
   223  	dec := gob.NewDecoder(bytes.NewBuffer(d))
   224  	return dec.Decode(&s.Values)
   225  }
   226  
   227  // generateRandomKey returns a new random key
   228  func generateRandomKey() (string, error) {
   229  	k := make([]byte, 64)
   230  	if _, err := io.ReadFull(rand.Reader, k); err != nil {
   231  		return "", err
   232  	}
   233  	return strings.TrimRight(base32.StdEncoding.EncodeToString(k), "="), nil
   234  }
   235  
   236  // does the data contain the EncryptedData prefix?
   237  func isRedisDataEncrypted(data []byte) bool {
   238  	searchStr := []byte("EncryptedData:")
   239  	return bytes.Contains(data, searchStr)
   240  }
   241  

View as plain text