...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package saml
16
import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"encoding/xml"
	"io/ioutil"
	"net/http"
	"time"

	"github.com/crewjam/saml"
	"github.com/pkg/errors"
)
28
29 func WithCertificateFromFile(path string) Param {
30
31 return func(sp *ServiceProvider) error {
32 certBytes, err := ioutil.ReadFile(path)
33 if err != nil {
34 return errors.Wrap(err, "could not read provided certificate file")
35 }
36
37 return WithCertificateFromBytes(certBytes)(sp)
38 }
39
40 }
41
42 func WithCertificateFromBytes(certBytes []byte) Param {
43 return func(sp *ServiceProvider) error {
44 certPem, _ := pem.Decode(certBytes)
45 if certPem == nil {
46 return errors.New("could not PEM decode the provided certificate")
47 }
48
49 cert, err := x509.ParseCertificate(certPem.Bytes)
50 sp.sp.Certificate = cert
51 return errors.Wrap(err, "failed to parse provided certificate")
52 }
53
54 }
55
56 func WithKeyFromFile(path string) Param {
57 return func(sp *ServiceProvider) error {
58 keyBytes, err := ioutil.ReadFile(path)
59 if err != nil {
60 return errors.Wrap(err, "could not read provided key file")
61 }
62
63 return WithKeyFromBytes(keyBytes)(sp)
64 }
65
66 }
67
68 func WithKeyFromBytes(keyBytes []byte) Param {
69
70 return func(sp *ServiceProvider) error {
71 keyPem, _ := pem.Decode(keyBytes)
72 if keyPem == nil {
73 return errors.New("could not PEM decode the provided private key")
74 }
75
76 key, err := x509.ParsePKCS8PrivateKey(keyPem.Bytes)
77 if err != nil {
78 return errors.Wrap(err, "could not parse provided private key")
79 }
80
81 rsaKey, ok := key.(*rsa.PrivateKey)
82 sp.sp.Key = rsaKey
83 if !ok {
84 return errors.New("provided private key was not an RSA key")
85 }
86 return nil
87 }
88
89 }
90
91 func WithEntityFromURL(url string) Param {
92
93 return func(sp *ServiceProvider) error {
94 resp, err := http.Get(url)
95 if err != nil {
96 return errors.Wrap(err, "failed to download IDP metadata")
97 }
98
99 defer func() { _ = resp.Body.Close() }()
100 descriptor, err := ioutil.ReadAll(resp.Body)
101 if err != nil {
102 return errors.Wrap(err, "failed to download IDP metadata")
103 }
104
105 return WithEntityFromBytes(descriptor)(sp)
106 }
107
108 }
109
110 func WithEntityFromBytes(metadata []byte) Param {
111
112 return func(sp *ServiceProvider) error {
113 var entity saml.EntityDescriptor
114
115 if err := xml.Unmarshal(metadata, &entity); err != nil {
116 var entities saml.EntitiesDescriptor
117
118 if err := xml.Unmarshal(metadata, &entities); err != nil {
119 return errors.Wrap(err, "could not parse returned metadata")
120 }
121
122 if len(entities.EntityDescriptors) == 0 {
123 return errors.New("metadata did not contain an entity")
124 }
125
126 entity = entities.EntityDescriptors[0]
127
128 }
129 sp.sp.IDPMetadata = &entity
130 return nil
131 }
132
133 }
134
135
136
137
138 func WithACSPath(path string) Param {
139 return func(sp *ServiceProvider) error {
140 sp.acsPath = path
141 return nil
142 }
143 }
144
145
146
147
148 func WithMetadataPath(path string) Param {
149 return func(sp *ServiceProvider) error {
150 sp.metadataPath = path
151 return nil
152 }
153 }
154
155
156
157 func WithLogoutPath(path string) Param {
158 return func(sp *ServiceProvider) error {
159 sp.logoutPath = path
160 return nil
161 }
162 }
163
164 func WithForceTLS(force bool) Param {
165 return func(sp *ServiceProvider) error {
166 sp.forceTLS = force
167 return nil
168 }
169 }
170
171 func WithLoginCallback(lcb LoginCallback) Param {
172 return func(sp *ServiceProvider) error {
173 sp.onLogin = lcb
174 return nil
175 }
176 }
177
178 func WithErrorCallback(ecb ErrorCallback) Param {
179 return func(sp *ServiceProvider) error {
180 sp.onError = ecb
181 return nil
182 }
183 }
184
185 func WithIDStore(store IDStore) Param {
186 return func(sp *ServiceProvider) error {
187 sp.idStore = store
188 return nil
189 }
190 }
191
192 func WithServiceProvider(s *saml.ServiceProvider) Param {
193 return func(sp *ServiceProvider) error {
194 sp.sp = s
195 return nil
196 }
197 }
198
199 func WithNameIDFormat(n saml.NameIDFormat) Param {
200 return func(sp *ServiceProvider) error {
201 sp.sp.AuthnNameIDFormat = n
202 return nil
203 }
204 }
205
206
207
208
209 func WithEncryptedAssertions(encrypt bool) Param {
210 return func(sp *ServiceProvider) error {
211 sp.disableEncryption = !encrypt
212 return nil
213 }
214 }
215
216 func WithForceAuthn(force bool) Param {
217 return func(sp *ServiceProvider) error {
218 sp.sp.ForceAuthn = &force
219 return nil
220 }
221 }
222
View as plain text