...

Source file src/github.com/antonlindstrom/pgstore/pgstore.go

Documentation: github.com/antonlindstrom/pgstore

     1  package pgstore
     2  
     3  import (
     4  	"database/sql"
     5  	"encoding/base32"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/pkg/errors"
    11  
    12  	"github.com/gorilla/securecookie"
    13  	"github.com/gorilla/sessions"
    14  
    15  	// Include the pq postgres driver.
    16  	_ "github.com/lib/pq"
    17  )
    18  
// PGStore represents the currently configured session store. It implements
// the gorilla/sessions store methods (Get, New, Save) and keeps session data
// in the PostgreSQL table http_sessions; the cookie only carries the
// (encoded) session ID.
type PGStore struct {
	// Codecs encode and decode both the cookie value (the session ID) and
	// the session values stored in the database's data column.
	Codecs  []securecookie.Codec
	// Options are the store-wide defaults; New copies them into every
	// session it creates.
	Options *sessions.Options
	// Path is not referenced by any code in this file.
	// NOTE(review): possibly legacy/unused — verify before relying on it.
	Path    string
	// DbPool is the connection pool used for every query. Close closes it.
	DbPool  *sql.DB
}
    26  
// PGSession mirrors one row of the http_sessions table; selectOne scans
// into it, and insert/update write it back.
type PGSession struct {
	// ID is the BIGSERIAL primary key. It is only filled by selectOne;
	// insert and update do not use it.
	ID         int64
	// Key is the session ID (the value carried in the cookie) and is the
	// column every lookup filters on.
	Key        string
	// Data is session.Values encoded with the store's codecs.
	Data       string
	CreatedOn  time.Time
	ModifiedOn time.Time
	ExpiresOn  time.Time
}
    36  
    37  // NewPGStore creates a new PGStore instance and a new database/sql pool.
    38  // This will also create in the database the schema needed by pgstore.
    39  func NewPGStore(dbURL string, keyPairs ...[]byte) (*PGStore, error) {
    40  	db, err := sql.Open("postgres", dbURL)
    41  	if err != nil {
    42  		// Ignore and return nil.
    43  		return nil, err
    44  	}
    45  	return NewPGStoreFromPool(db, keyPairs...)
    46  }
    47  
    48  // NewPGStoreFromPool creates a new PGStore instance from an existing
    49  // database/sql pool.
    50  // This will also create the database schema needed by pgstore.
    51  func NewPGStoreFromPool(db *sql.DB, keyPairs ...[]byte) (*PGStore, error) {
    52  	dbStore := &PGStore{
    53  		Codecs: securecookie.CodecsFromPairs(keyPairs...),
    54  		Options: &sessions.Options{
    55  			Path:   "/",
    56  			MaxAge: 86400 * 30,
    57  		},
    58  		DbPool: db,
    59  	}
    60  
    61  	// Create table if it doesn't exist
    62  	err := dbStore.createSessionsTable()
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	return dbStore, nil
    68  }
    69  
    70  // Close closes the database connection.
    71  func (db *PGStore) Close() {
    72  	db.DbPool.Close()
    73  }
    74  
    75  // Get Fetches a session for a given name after it has been added to the
    76  // registry.
    77  func (db *PGStore) Get(r *http.Request, name string) (*sessions.Session, error) {
    78  	return sessions.GetRegistry(r).Get(db, name)
    79  }
    80  
    81  // New returns a new session for the given name without adding it to the registry.
    82  func (db *PGStore) New(r *http.Request, name string) (*sessions.Session, error) {
    83  	session := sessions.NewSession(db, name)
    84  	if session == nil {
    85  		return nil, nil
    86  	}
    87  
    88  	opts := *db.Options
    89  	session.Options = &(opts)
    90  	session.IsNew = true
    91  
    92  	var err error
    93  	if c, errCookie := r.Cookie(name); errCookie == nil {
    94  		err = securecookie.DecodeMulti(name, c.Value, &session.ID, db.Codecs...)
    95  		if err == nil {
    96  			err = db.load(session)
    97  			if err == nil {
    98  				session.IsNew = false
    99  			} else if errors.Cause(err) == sql.ErrNoRows {
   100  				err = nil
   101  			}
   102  		}
   103  	}
   104  
   105  	db.MaxAge(db.Options.MaxAge)
   106  
   107  	return session, err
   108  }
   109  
   110  // Save saves the given session into the database and deletes cookies if needed
   111  func (db *PGStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
   112  	// Set delete if max-age is < 0
   113  	if session.Options.MaxAge < 0 {
   114  		if err := db.destroy(session); err != nil {
   115  			return err
   116  		}
   117  		http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
   118  		return nil
   119  	}
   120  
   121  	if session.ID == "" {
   122  		// Generate a random session ID key suitable for storage in the DB
   123  		session.ID = strings.TrimRight(
   124  			base32.StdEncoding.EncodeToString(
   125  				securecookie.GenerateRandomKey(32),
   126  			), "=")
   127  	}
   128  
   129  	if err := db.save(session); err != nil {
   130  		return err
   131  	}
   132  
   133  	// Keep the session ID key in a cookie so it can be looked up in DB later.
   134  	encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, db.Codecs...)
   135  	if err != nil {
   136  		return err
   137  	}
   138  
   139  	http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
   140  	return nil
   141  }
   142  
   143  // MaxLength restricts the maximum length of new sessions to l.
   144  // If l is 0 there is no limit to the size of a session, use with caution.
   145  // The default for a new PGStore is 4096. PostgreSQL allows for max
   146  // value sizes of up to 1GB (http://www.postgresql.org/docs/current/interactive/datatype-character.html)
   147  func (db *PGStore) MaxLength(l int) {
   148  	for _, c := range db.Codecs {
   149  		if codec, ok := c.(*securecookie.SecureCookie); ok {
   150  			codec.MaxLength(l)
   151  		}
   152  	}
   153  }
   154  
   155  // MaxAge sets the maximum age for the store and the underlying cookie
   156  // implementation. Individual sessions can be deleted by setting Options.MaxAge
   157  // = -1 for that session.
   158  func (db *PGStore) MaxAge(age int) {
   159  	db.Options.MaxAge = age
   160  
   161  	// Set the maxAge for each securecookie instance.
   162  	for _, codec := range db.Codecs {
   163  		if sc, ok := codec.(*securecookie.SecureCookie); ok {
   164  			sc.MaxAge(age)
   165  		}
   166  	}
   167  }
   168  
   169  // load fetches a session by ID from the database and decodes its content
   170  // into session.Values.
   171  func (db *PGStore) load(session *sessions.Session) error {
   172  	var s PGSession
   173  
   174  	err := db.selectOne(&s, session.ID)
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	return securecookie.DecodeMulti(session.Name(), string(s.Data), &session.Values, db.Codecs...)
   180  }
   181  
   182  // save writes encoded session.Values to a database record.
   183  // writes to http_sessions table by default.
   184  func (db *PGStore) save(session *sessions.Session) error {
   185  	encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, db.Codecs...)
   186  	if err != nil {
   187  		return err
   188  	}
   189  
   190  	crOn := session.Values["created_on"]
   191  	exOn := session.Values["expires_on"]
   192  
   193  	var expiresOn time.Time
   194  
   195  	createdOn, ok := crOn.(time.Time)
   196  	if !ok {
   197  		createdOn = time.Now()
   198  	}
   199  
   200  	if exOn == nil {
   201  		expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
   202  	} else {
   203  		expiresOn = exOn.(time.Time)
   204  		if expiresOn.Sub(time.Now().Add(time.Second*time.Duration(session.Options.MaxAge))) < 0 {
   205  			expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
   206  		}
   207  	}
   208  
   209  	s := PGSession{
   210  		Key:        session.ID,
   211  		Data:       encoded,
   212  		CreatedOn:  createdOn,
   213  		ExpiresOn:  expiresOn,
   214  		ModifiedOn: time.Now(),
   215  	}
   216  
   217  	if session.IsNew {
   218  		return db.insert(&s)
   219  	}
   220  
   221  	return db.update(&s)
   222  }
   223  
   224  // Delete session
   225  func (db *PGStore) destroy(session *sessions.Session) error {
   226  	_, err := db.DbPool.Exec("DELETE FROM http_sessions WHERE key = $1", session.ID)
   227  	return err
   228  }
   229  
   230  func (db *PGStore) createSessionsTable() error {
   231  	stmt := `DO $$
   232                BEGIN
   233                CREATE TABLE IF NOT EXISTS http_sessions (
   234                id BIGSERIAL PRIMARY KEY,
   235                key BYTEA,
   236                data BYTEA,
   237                created_on TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
   238                modified_on TIMESTAMPTZ,
   239                expires_on TIMESTAMPTZ);
   240                EXCEPTION WHEN insufficient_privilege THEN
   241                  IF NOT EXISTS (SELECT FROM pg_catalog.pg_tables WHERE schemaname = current_schema() AND tablename = 'http_sessions') THEN
   242                    RAISE;
   243                  END IF;
   244                WHEN others THEN RAISE;
   245                END;
   246                $$;`
   247  
   248  	_, err := db.DbPool.Exec(stmt)
   249  	if err != nil {
   250  		return errors.Wrapf(err, "Unable to create http_sessions table in the database")
   251  	}
   252  
   253  	return nil
   254  }
   255  
   256  func (db *PGStore) selectOne(s *PGSession, key string) error {
   257  	stmt := "SELECT id, key, data, created_on, modified_on, expires_on FROM http_sessions WHERE key = $1"
   258  	err := db.DbPool.QueryRow(stmt, key).Scan(&s.ID, &s.Key, &s.Data, &s.CreatedOn, &s.ModifiedOn, &s.ExpiresOn)
   259  	if err != nil {
   260  		return errors.Wrapf(err, "Unable to find session in the database")
   261  	}
   262  
   263  	return nil
   264  }
   265  
   266  func (db *PGStore) insert(s *PGSession) error {
   267  	stmt := `INSERT INTO http_sessions (key, data, created_on, modified_on, expires_on)
   268             VALUES ($1, $2, $3, $4, $5)`
   269  	_, err := db.DbPool.Exec(stmt, s.Key, s.Data, s.CreatedOn, s.ModifiedOn, s.ExpiresOn)
   270  
   271  	return err
   272  }
   273  
   274  func (db *PGStore) update(s *PGSession) error {
   275  	stmt := `UPDATE http_sessions SET data=$1, modified_on=$2, expires_on=$3 WHERE key=$4`
   276  	_, err := db.DbPool.Exec(stmt, s.Data, s.ModifiedOn, s.ExpiresOn, s.Key)
   277  
   278  	return err
   279  }
   280  

View as plain text