1 package pgstore
2
3 import (
4 "database/sql"
5 "encoding/base32"
6 "net/http"
7 "strings"
8 "time"
9
10 "github.com/pkg/errors"
11
12 "github.com/gorilla/securecookie"
13 "github.com/gorilla/sessions"
14
15
16 _ "github.com/lib/pq"
17 )
18
19
20 type PGStore struct {
21 Codecs []securecookie.Codec
22 Options *sessions.Options
23 Path string
24 DbPool *sql.DB
25 }
26
27
// PGSession mirrors one row of the http_sessions table, field for column.
type PGSession struct {
	ID         int64     // id: BIGSERIAL primary key assigned by the database
	Key        string    // key: the session ID carried (encoded) in the cookie
	Data       string    // data: session.Values encoded with the store's codecs
	CreatedOn  time.Time // created_on
	ModifiedOn time.Time // modified_on: time of the last save
	ExpiresOn  time.Time // expires_on
}
36
37
38
39 func NewPGStore(dbURL string, keyPairs ...[]byte) (*PGStore, error) {
40 db, err := sql.Open("postgres", dbURL)
41 if err != nil {
42
43 return nil, err
44 }
45 return NewPGStoreFromPool(db, keyPairs...)
46 }
47
48
49
50
51 func NewPGStoreFromPool(db *sql.DB, keyPairs ...[]byte) (*PGStore, error) {
52 dbStore := &PGStore{
53 Codecs: securecookie.CodecsFromPairs(keyPairs...),
54 Options: &sessions.Options{
55 Path: "/",
56 MaxAge: 86400 * 30,
57 },
58 DbPool: db,
59 }
60
61
62 err := dbStore.createSessionsTable()
63 if err != nil {
64 return nil, err
65 }
66
67 return dbStore, nil
68 }
69
70
71 func (db *PGStore) Close() {
72 db.DbPool.Close()
73 }
74
75
76
77 func (db *PGStore) Get(r *http.Request, name string) (*sessions.Session, error) {
78 return sessions.GetRegistry(r).Get(db, name)
79 }
80
81
82 func (db *PGStore) New(r *http.Request, name string) (*sessions.Session, error) {
83 session := sessions.NewSession(db, name)
84 if session == nil {
85 return nil, nil
86 }
87
88 opts := *db.Options
89 session.Options = &(opts)
90 session.IsNew = true
91
92 var err error
93 if c, errCookie := r.Cookie(name); errCookie == nil {
94 err = securecookie.DecodeMulti(name, c.Value, &session.ID, db.Codecs...)
95 if err == nil {
96 err = db.load(session)
97 if err == nil {
98 session.IsNew = false
99 } else if errors.Cause(err) == sql.ErrNoRows {
100 err = nil
101 }
102 }
103 }
104
105 db.MaxAge(db.Options.MaxAge)
106
107 return session, err
108 }
109
110
111 func (db *PGStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
112
113 if session.Options.MaxAge < 0 {
114 if err := db.destroy(session); err != nil {
115 return err
116 }
117 http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
118 return nil
119 }
120
121 if session.ID == "" {
122
123 session.ID = strings.TrimRight(
124 base32.StdEncoding.EncodeToString(
125 securecookie.GenerateRandomKey(32),
126 ), "=")
127 }
128
129 if err := db.save(session); err != nil {
130 return err
131 }
132
133
134 encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, db.Codecs...)
135 if err != nil {
136 return err
137 }
138
139 http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
140 return nil
141 }
142
143
144
145
146
147 func (db *PGStore) MaxLength(l int) {
148 for _, c := range db.Codecs {
149 if codec, ok := c.(*securecookie.SecureCookie); ok {
150 codec.MaxLength(l)
151 }
152 }
153 }
154
155
156
157
158 func (db *PGStore) MaxAge(age int) {
159 db.Options.MaxAge = age
160
161
162 for _, codec := range db.Codecs {
163 if sc, ok := codec.(*securecookie.SecureCookie); ok {
164 sc.MaxAge(age)
165 }
166 }
167 }
168
169
170
171 func (db *PGStore) load(session *sessions.Session) error {
172 var s PGSession
173
174 err := db.selectOne(&s, session.ID)
175 if err != nil {
176 return err
177 }
178
179 return securecookie.DecodeMulti(session.Name(), string(s.Data), &session.Values, db.Codecs...)
180 }
181
182
183
184 func (db *PGStore) save(session *sessions.Session) error {
185 encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, db.Codecs...)
186 if err != nil {
187 return err
188 }
189
190 crOn := session.Values["created_on"]
191 exOn := session.Values["expires_on"]
192
193 var expiresOn time.Time
194
195 createdOn, ok := crOn.(time.Time)
196 if !ok {
197 createdOn = time.Now()
198 }
199
200 if exOn == nil {
201 expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
202 } else {
203 expiresOn = exOn.(time.Time)
204 if expiresOn.Sub(time.Now().Add(time.Second*time.Duration(session.Options.MaxAge))) < 0 {
205 expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
206 }
207 }
208
209 s := PGSession{
210 Key: session.ID,
211 Data: encoded,
212 CreatedOn: createdOn,
213 ExpiresOn: expiresOn,
214 ModifiedOn: time.Now(),
215 }
216
217 if session.IsNew {
218 return db.insert(&s)
219 }
220
221 return db.update(&s)
222 }
223
224
225 func (db *PGStore) destroy(session *sessions.Session) error {
226 _, err := db.DbPool.Exec("DELETE FROM http_sessions WHERE key = $1", session.ID)
227 return err
228 }
229
230 func (db *PGStore) createSessionsTable() error {
231 stmt := `DO $$
232 BEGIN
233 CREATE TABLE IF NOT EXISTS http_sessions (
234 id BIGSERIAL PRIMARY KEY,
235 key BYTEA,
236 data BYTEA,
237 created_on TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
238 modified_on TIMESTAMPTZ,
239 expires_on TIMESTAMPTZ);
240 EXCEPTION WHEN insufficient_privilege THEN
241 IF NOT EXISTS (SELECT FROM pg_catalog.pg_tables WHERE schemaname = current_schema() AND tablename = 'http_sessions') THEN
242 RAISE;
243 END IF;
244 WHEN others THEN RAISE;
245 END;
246 $$;`
247
248 _, err := db.DbPool.Exec(stmt)
249 if err != nil {
250 return errors.Wrapf(err, "Unable to create http_sessions table in the database")
251 }
252
253 return nil
254 }
255
256 func (db *PGStore) selectOne(s *PGSession, key string) error {
257 stmt := "SELECT id, key, data, created_on, modified_on, expires_on FROM http_sessions WHERE key = $1"
258 err := db.DbPool.QueryRow(stmt, key).Scan(&s.ID, &s.Key, &s.Data, &s.CreatedOn, &s.ModifiedOn, &s.ExpiresOn)
259 if err != nil {
260 return errors.Wrapf(err, "Unable to find session in the database")
261 }
262
263 return nil
264 }
265
266 func (db *PGStore) insert(s *PGSession) error {
267 stmt := `INSERT INTO http_sessions (key, data, created_on, modified_on, expires_on)
268 VALUES ($1, $2, $3, $4, $5)`
269 _, err := db.DbPool.Exec(stmt, s.Key, s.Data, s.CreatedOn, s.ModifiedOn, s.ExpiresOn)
270
271 return err
272 }
273
274 func (db *PGStore) update(s *PGSession) error {
275 stmt := `UPDATE http_sessions SET data=$1, modified_on=$2, expires_on=$3 WHERE key=$4`
276 _, err := db.DbPool.Exec(stmt, s.Data, s.ModifiedOn, s.ExpiresOn, s.Key)
277
278 return err
279 }
280
View as plain text