...

Source file src/edge-infra.dev/pkg/edge/auth-proxy/store/store.go

Documentation: edge-infra.dev/pkg/edge/auth-proxy/store

     1  package store
     2  
     3  import (
     4  	"database/sql"
     5  	"encoding/base32"
     6  	"fmt"
     7  	"net/http"
     8  	"strings"
     9  	"time"
    10  
    11  	"errors"
    12  
    13  	"github.com/go-logr/logr"
    14  	"github.com/gorilla/securecookie"
    15  	"github.com/gorilla/sessions"
    16  
    17  	// Include the pq postgres driver.
    18  	_ "github.com/lib/pq"
    19  
    20  	"edge-infra.dev/pkg/edge/audit"
    21  )
    22  
// PGStore represents the currently configured session store.
// It is passed as the sessions.Store to sessions.NewSession and the request
// registry (see Get/New/Save) and persists sessions in the http_sessions table.
type PGStore struct {
	// Codecs encode/decode both the session-ID cookie value and the session
	// data stored in the database.
	Codecs   []securecookie.Codec
	// Options is the template copied into every new session's Options.
	Options  *sessions.Options
	// Path is not read by any code in this file.
	Path     string
	// DbPool is the database/sql pool used for all session queries.
	DbPool   *sql.DB
	// log receives error logs, e.g. failed session deletions in destroy.
	log      logr.Logger
	// auditLog is the audit sink supplied at construction; it is not used
	// directly in this file (presumably consumed by db.audit, defined
	// elsewhere in the package — confirm).
	auditLog *audit.Sink
}
    32  
// PGSession is one row of the http_sessions table, as read by selectOne and
// written by insert/update.
type PGSession struct {
	// ID is the BIGSERIAL primary key; it is populated on read only (insert
	// does not write it).
	ID         int64
	// Key is the session ID that Save also encodes into the session cookie.
	Key        string
	// Data is the securecookie-encoded session.Values.
	Data       string
	CreatedOn  time.Time
	ModifiedOn time.Time
	// ExpiresOn is written on every save; load does not check it.
	ExpiresOn  time.Time
}
    42  
    43  // NewPGStore creates a new PGStore instance and a new database/sql pool.
    44  // This will also create in the database the schema needed by pgstore.
    45  func NewPGStore(dbURL string, logger logr.Logger, auditLog *audit.Sink, keyPairs ...[]byte) (*PGStore, error) {
    46  	db, err := sql.Open("postgres", dbURL)
    47  	if err != nil {
    48  		// Ignore and return nil.
    49  		return nil, err
    50  	}
    51  	return NewPGStoreFromPool(db, logger, auditLog, keyPairs...)
    52  }
    53  
    54  // NewPGStoreFromPool creates a new PGStore instance from an existing
    55  // database/sql pool.
    56  // This will also create the database schema needed by pgstore.
    57  func NewPGStoreFromPool(db *sql.DB, logger logr.Logger, auditLog *audit.Sink, keyPairs ...[]byte) (*PGStore, error) {
    58  	dbStore := &PGStore{
    59  		Codecs: securecookie.CodecsFromPairs(keyPairs...),
    60  		Options: &sessions.Options{
    61  			Path:   "/",
    62  			MaxAge: 86400 * 30,
    63  		},
    64  		DbPool:   db,
    65  		log:      logger,
    66  		auditLog: auditLog,
    67  	}
    68  
    69  	// Create table if it doesn't exist
    70  	err := dbStore.createSessionsTable()
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	return dbStore, nil
    76  }
    77  
    78  // Close closes the database connection.
    79  func (db *PGStore) Close() {
    80  	db.DbPool.Close()
    81  }
    82  
    83  // Get Fetches a session for a given name after it has been added to the
    84  // registry.
    85  func (db *PGStore) Get(r *http.Request, name string) (*sessions.Session, error) {
    86  	return sessions.GetRegistry(r).Get(db, name)
    87  }
    88  
    89  // New returns a new session for the given name without adding it to the registry.
    90  func (db *PGStore) New(r *http.Request, name string) (*sessions.Session, error) {
    91  	session := sessions.NewSession(db, name)
    92  	if session == nil {
    93  		return nil, nil
    94  	}
    95  
    96  	opts := *db.Options
    97  	session.Options = &(opts)
    98  	session.IsNew = true
    99  
   100  	var err error
   101  	if c, errCookie := r.Cookie(name); errCookie == nil {
   102  		err = securecookie.DecodeMulti(name, c.Value, &session.ID, db.Codecs...)
   103  		if err == nil {
   104  			err = db.load(session)
   105  			if err == nil {
   106  				session.IsNew = false
   107  			} else if errors.Is(err, sql.ErrNoRows) {
   108  				err = nil
   109  			}
   110  		}
   111  	}
   112  
   113  	db.MaxAge(db.Options.MaxAge)
   114  
   115  	return session, err
   116  }
   117  
   118  // Save saves the given session into the database and deletes cookies if needed
   119  func (db *PGStore) Save(_ *http.Request, w http.ResponseWriter, session *sessions.Session) error {
   120  	// Set delete if max-age is < 0
   121  	if session.Options.MaxAge < 0 {
   122  		if err := db.destroy(session); err != nil {
   123  			return err
   124  		}
   125  		http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
   126  		return nil
   127  	}
   128  
   129  	if session.ID == "" {
   130  		// Generate a random session ID key suitable for storage in the DB
   131  		session.ID = strings.TrimRight(
   132  			base32.StdEncoding.EncodeToString(
   133  				securecookie.GenerateRandomKey(32),
   134  			), "=")
   135  	}
   136  
   137  	if err := db.save(session); err != nil {
   138  		return err
   139  	}
   140  
   141  	// Keep the session ID key in a cookie so it can be looked up in DB later.
   142  	encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, db.Codecs...)
   143  	if err != nil {
   144  		return err
   145  	}
   146  
   147  	http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
   148  	return nil
   149  }
   150  
   151  // MaxLength restricts the maximum length of new sessions to l.
   152  // If l is 0 there is no limit to the size of a session, use with caution.
   153  // The default for a new PGStore is 4096. PostgreSQL allows for max
   154  // value sizes of up to 1GB (http://www.postgresql.org/docs/current/interactive/datatype-character.html)
   155  func (db *PGStore) MaxLength(l int) {
   156  	for _, c := range db.Codecs {
   157  		if codec, ok := c.(*securecookie.SecureCookie); ok {
   158  			codec.MaxLength(l)
   159  		}
   160  	}
   161  }
   162  
   163  // MaxAge sets the maximum age for the store and the underlying cookie
   164  // implementation. Individual sessions can be deleted by setting Options.MaxAge
   165  // = -1 for that session.
   166  func (db *PGStore) MaxAge(age int) {
   167  	db.Options.MaxAge = age
   168  
   169  	// Set the maxAge for each securecookie instance.
   170  	for _, codec := range db.Codecs {
   171  		if sc, ok := codec.(*securecookie.SecureCookie); ok {
   172  			sc.MaxAge(age)
   173  		}
   174  	}
   175  }
   176  
   177  // load fetches a session by ID from the database and decodes its content
   178  // into session.Values.
   179  func (db *PGStore) load(session *sessions.Session) error {
   180  	var s PGSession
   181  	err := db.selectOne(&s, session.ID)
   182  	if err != nil {
   183  		return err
   184  	}
   185  	return securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, db.Codecs...)
   186  }
   187  
   188  // save writes encoded session.Values to a database record.
   189  // writes to http_sessions table by default.
   190  func (db *PGStore) save(session *sessions.Session) error {
   191  	status := "Success"
   192  	defer func() {
   193  		db.audit(SaveSessionOp, status, session)
   194  	}()
   195  	encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, db.Codecs...)
   196  	if err != nil {
   197  		status = Failure
   198  		return err
   199  	}
   200  
   201  	crOn := session.Values["created_on"]
   202  	exOn := session.Values["expires_on"]
   203  
   204  	var expiresOn time.Time
   205  
   206  	createdOn, ok := crOn.(time.Time)
   207  	if !ok {
   208  		createdOn = time.Now()
   209  	}
   210  
   211  	if exOn == nil {
   212  		expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
   213  	} else {
   214  		expiresOn = exOn.(time.Time)
   215  		if expiresOn.Sub(time.Now().Add(time.Second*time.Duration(session.Options.MaxAge))) < 0 {
   216  			expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
   217  		}
   218  	}
   219  
   220  	s := PGSession{
   221  		Key:        session.ID,
   222  		Data:       encoded,
   223  		CreatedOn:  createdOn,
   224  		ExpiresOn:  expiresOn,
   225  		ModifiedOn: time.Now(),
   226  	}
   227  
   228  	if session.IsNew {
   229  		return db.insert(&s)
   230  	}
   231  
   232  	return db.update(&s)
   233  }
   234  
   235  // Delete session
   236  func (db *PGStore) destroy(session *sessions.Session) error {
   237  	// probably pull session details here
   238  	// TODO(pa250194): Intentionally left here
   239  	// Will fix the bug where the session is empty
   240  	//fmt.Println("session vals: ", session.Values)
   241  	//fmt.Println("sessionID destroy: ", session.ID)
   242  	//_ = db.updateExpiration(session.ID, time.Now().Add(5*time.Second))
   243  	//sess := sessions.NewSession(db, SessionIdentifier)
   244  	//sess.ID = session.ID
   245  	//err := db.load(sess)
   246  	//fmt.Println("Destroy save err: ", err)
   247  	status := Success
   248  	defer func() {
   249  		db.audit(DestroySessionOp, status, session)
   250  	}()
   251  	_, err := db.DbPool.Exec("DELETE FROM http_sessions WHERE key = $1", session.ID)
   252  	if err != nil {
   253  		if errors.Is(err, sql.ErrNoRows) {
   254  			return nil
   255  		}
   256  		//status = Failure
   257  		db.log.Error(err, "an error occurred deleting session", enrichLogWithSession(DestroySessionOp, session)...)
   258  		return err
   259  	}
   260  	return nil
   261  }
   262  
   263  func (db *PGStore) createSessionsTable() error {
   264  	stmt := `DO $$
   265                BEGIN
   266                CREATE TABLE IF NOT EXISTS http_sessions (
   267                id BIGSERIAL PRIMARY KEY,
   268                key BYTEA,
   269                data BYTEA,
   270                created_on TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
   271                modified_on TIMESTAMPTZ,
   272                expires_on TIMESTAMPTZ);
   273                CREATE INDEX IF NOT EXISTS http_sessions_expiry_idx ON http_sessions (expires_on);
   274                CREATE INDEX IF NOT EXISTS http_sessions_key_idx ON http_sessions (key);
   275                EXCEPTION WHEN insufficient_privilege THEN
   276                  IF NOT EXISTS (SELECT FROM pg_catalog.pg_tables WHERE schemaname = current_schema() AND tablename = 'http_sessions') THEN
   277                    RAISE;
   278                  END IF;
   279                WHEN others THEN RAISE;
   280                END;
   281                $$;`
   282  
   283  	_, err := db.DbPool.Exec(stmt)
   284  	if err != nil {
   285  		return fmt.Errorf("unable to create http_sessions table in the database, err: %w", err)
   286  	}
   287  
   288  	return nil
   289  }
   290  
   291  func (db *PGStore) selectOne(s *PGSession, key string) error {
   292  	stmt := "SELECT id, key, data, created_on, modified_on, expires_on FROM http_sessions WHERE key = $1"
   293  	err := db.DbPool.QueryRow(stmt, key).Scan(&s.ID, &s.Key, &s.Data, &s.CreatedOn, &s.ModifiedOn, &s.ExpiresOn)
   294  	if err != nil {
   295  		if !errors.Is(err, sql.ErrNoRows) {
   296  			return fmt.Errorf("unable to find session in the database, err: %w", err)
   297  		}
   298  	}
   299  	return nil
   300  }
   301  
   302  func (db *PGStore) insert(s *PGSession) error {
   303  	stmt := `INSERT INTO http_sessions (key, data, created_on, modified_on, expires_on)
   304             VALUES ($1, $2, $3, $4, $5)`
   305  	_, err := db.DbPool.Exec(stmt, s.Key, s.Data, s.CreatedOn, s.ModifiedOn, s.ExpiresOn)
   306  
   307  	return err
   308  }
   309  
   310  func (db *PGStore) update(s *PGSession) error {
   311  	stmt := `UPDATE http_sessions SET data=$1, modified_on=$2, expires_on=$3 WHERE key=$4`
   312  	_, err := db.DbPool.Exec(stmt, s.Data, s.ModifiedOn, s.ExpiresOn, s.Key)
   313  
   314  	return err
   315  }
   316  
   317  //func (db *PGStore) updateExpiration(key string, expiresAt time.Time) error {
   318  //	stmt := `UPDATE http_sessions SET expires_on=$1 WHERE key=$2`
   319  //	_, err := db.DbPool.Exec(stmt, expiresAt, key)
   320  //	return err
   321  //}
   322  

View as plain text