1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package saml
16
17 import (
18 "encoding/xml"
19 "net/http"
20 "net/url"
21
22 "github.com/crewjam/saml"
23 "github.com/pkg/errors"
24 "github.com/rs/zerolog/hlog"
25 )
26
// Error is the value passed to an ErrorCallback. It pairs the underlying
// cause of a failure with the HTTP status code the client should receive.
type Error struct {
	// Err is the underlying cause. It may be nil for hand-built values.
	Err error

	// ResponseCode is the HTTP status code to reply with.
	ResponseCode int
}

// Error implements the error interface by returning the message of the
// underlying cause. When no cause is set it returns a generic message
// instead of panicking on the nil Err.
func (s Error) Error() string {
	if s.Err == nil {
		return "saml error"
	}
	return s.Err.Error()
}

// Unwrap returns the underlying cause so errors.Is and errors.As can
// inspect the error chain.
func (s Error) Unwrap() error {
	return s.Err
}

// newError builds an Error from a cause and the HTTP status code that
// should accompany it.
func newError(err error, status int) Error {
	return Error{Err: err, ResponseCode: status}
}
44
45
46
47
48 type ErrorCallback func(http.ResponseWriter, *http.Request, Error)
49
50
51
52 type LoginCallback func(http.ResponseWriter, *http.Request, *saml.Assertion)
53
54
55
56
57
58 type ServiceProvider struct {
59 sp *saml.ServiceProvider
60
61 acsPath string
62 metadataPath string
63 logoutPath string
64
65 forceTLS bool
66 disableEncryption bool
67
68 onError ErrorCallback
69 onLogin LoginCallback
70 idStore IDStore
71 }
72
73 type Param func(sp *ServiceProvider) error
74
75
76
77 func NewServiceProvider(params ...Param) (*ServiceProvider, error) {
78
79 sp := &ServiceProvider{
80 sp: &saml.ServiceProvider{},
81 }
82
83 for _, p := range params {
84 if err := p(sp); err != nil {
85 return nil, err
86 }
87 }
88
89 if sp.sp.Certificate == nil || sp.sp.Key == nil {
90 return nil, errors.New("a certificate and key must be provided")
91 }
92
93 if sp.sp.IDPMetadata == nil {
94 return nil, errors.New("the IDP Metadata must be provided")
95 }
96
97 if sp.acsPath == "" || sp.metadataPath == "" {
98 return nil, errors.New("ACS Path and Metadatda path must be provided")
99 }
100
101 if sp.onError == nil {
102 sp.onError = DefaultErrorCallback
103 }
104
105 if sp.onLogin == nil {
106 sp.onLogin = DefaultLoginCallback
107 }
108
109 if sp.idStore == nil {
110 sp.idStore = cookieIDStore{}
111 }
112
113 return sp, nil
114 }
115
116 func DefaultErrorCallback(w http.ResponseWriter, r *http.Request, err Error) {
117 hlog.FromRequest(r).Error().Err(err.Err).Msg("saml error")
118 http.Error(w, http.StatusText(err.ResponseCode), err.ResponseCode)
119 }
120
121 func DefaultLoginCallback(w http.ResponseWriter, r *http.Request, resp *saml.Assertion) {
122 w.WriteHeader(http.StatusOK)
123 }
124
125 func (s *ServiceProvider) getSAMLSettingsForRequest(r *http.Request) *saml.ServiceProvider {
126
127 newSP := *s.sp
128
129 u := url.URL{
130 Host: r.Host,
131 Scheme: "http",
132 }
133
134 if s.forceTLS || r.TLS != nil {
135 u.Scheme = "https"
136 }
137
138 u.Path = s.metadataPath
139 newSP.MetadataURL = u
140
141 u.Path = s.acsPath
142 newSP.AcsURL = u
143
144 u.Path = s.logoutPath
145 newSP.SloURL = u
146
147 return &newSP
148 }
149
150
151
152 func (s *ServiceProvider) DoAuth(w http.ResponseWriter, r *http.Request) {
153 sp := s.getSAMLSettingsForRequest(r)
154
155 request, err := sp.MakeAuthenticationRequest(sp.GetSSOBindingLocation(saml.HTTPRedirectBinding))
156 if err != nil {
157 s.onError(w, r, newError(errors.Wrap(err, "failed to create authentication request"), http.StatusInternalServerError))
158 return
159 }
160
161 if err := s.idStore.StoreID(w, r, request.ID); err != nil {
162 s.onError(w, r, newError(errors.Wrap(err, "failed to store SAML request id"), http.StatusInternalServerError))
163 return
164 }
165
166 target := request.Redirect("")
167
168 http.Redirect(w, r, target.String(), http.StatusFound)
169 }
170
171
172 func (s *ServiceProvider) ACSHandler() http.Handler {
173 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
174 sp := s.getSAMLSettingsForRequest(r)
175 if err := r.ParseForm(); err != nil {
176 s.onError(w, r, newError(errors.Wrap(err, "could not parse ACS form"), http.StatusForbidden))
177 return
178 }
179 id, err := s.idStore.GetID(r)
180 if err != nil {
181 s.onError(w, r, newError(errors.Wrap(err, "could not retrieve id"), http.StatusForbidden))
182 return
183 }
184 assertion, err := sp.ParseResponse(r, []string{id})
185
186 if err != nil {
187 if parseErr, ok := err.(*saml.InvalidResponseError); ok {
188 err = parseErr.PrivateErr
189 }
190 s.onError(w, r, newError(errors.Wrap(err, "failed to validate SAML assertion"), http.StatusForbidden))
191 return
192 }
193
194 s.onLogin(w, r, assertion)
195 })
196
197 }
198
199
200 func (s *ServiceProvider) MetadataHandler() http.Handler {
201 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
202 metadata := s.getSAMLSettingsForRequest(r).Metadata()
203
204
205
206
207 if s.logoutPath == "" {
208
209 metadata.SPSSODescriptors[0].SSODescriptor.SingleLogoutServices = nil
210 }
211 if s.disableEncryption {
212
213 role := &(metadata.SPSSODescriptors[0].SSODescriptor.RoleDescriptor)
214 for i, k := range role.KeyDescriptors {
215 if k.Use == "encryption" {
216 role.KeyDescriptors = append(role.KeyDescriptors[:i], role.KeyDescriptors[i+1:]...)
217 }
218 }
219 }
220
221 md, err := xml.Marshal(metadata)
222 if err != nil {
223 s.onError(w, r, newError(errors.Wrap(err, "failed to generate service provider metadata"), http.StatusInternalServerError))
224 return
225 }
226
227 w.Header().Set("Content-Type", "application/xml")
228
229
230 _, _ = w.Write(md)
231 })
232 }
233
View as plain text