...

Source file src/github.com/palantir/go-baseapp/baseapp/auth/saml/serviceprovider.go

Documentation: github.com/palantir/go-baseapp/baseapp/auth/saml

     1  // Copyright 2019 Palantir Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package saml
    16  
    17  import (
    18  	"encoding/xml"
    19  	"net/http"
    20  	"net/url"
    21  
    22  	"github.com/crewjam/saml"
    23  	"github.com/pkg/errors"
    24  	"github.com/rs/zerolog/hlog"
    25  )
    26  
    27  type Error struct {
    28  	Err error
    29  
    30  	// The suggested HTTP response code for this error
    31  	ResponseCode int
    32  }
    33  
    34  func (s Error) Error() string {
    35  	return s.Err.Error()
    36  }
    37  
    38  func newError(err error, status int) Error {
    39  	return Error{
    40  		Err:          err,
    41  		ResponseCode: status,
    42  	}
    43  }
    44  
    45  // ErrorCallback is called whenever an error occurs in the saml package.
    46  // The callback is expected to send a response to the request. The http.ResponseWriter
    47  // will not have been written to, allowing the callback to send headers if desired.
    48  type ErrorCallback func(http.ResponseWriter, *http.Request, Error)
    49  
    50  // LoginCallback is called whenever an auth flow is successfully completed.
    51  // The callback is responsible preserving the login state.
    52  type LoginCallback func(http.ResponseWriter, *http.Request, *saml.Assertion)
    53  
// ServiceProvider is capable of handling a SAML login. It provides an
// http.Handler (via ACSHandler) which can process the HTTP POST from the SAML
// IDP, and it accepts callbacks for both error and success conditions so that
// clients can take action after the auth flow is complete. It also provides a
// handler (via MetadataHandler) for serving the service provider metadata XML.
//
// Construct values with NewServiceProvider; the zero value is not usable.
type ServiceProvider struct {
	// sp holds the base crewjam/saml settings. Handlers never use it
	// directly: getSAMLSettingsForRequest makes a per-request copy with
	// URLs derived from the request's Host header.
	sp *saml.ServiceProvider

	// URL paths, on the request's host, of the ACS, metadata, and single
	// logout endpoints. acsPath and metadataPath are required; logoutPath
	// may be empty, in which case SingleLogoutService entries are removed
	// from the served metadata.
	acsPath      string
	metadataPath string
	logoutPath   string

	// forceTLS makes generated URLs use https even for non-TLS requests;
	// disableEncryption removes "encryption" key descriptors from the
	// served metadata.
	forceTLS          bool
	disableEncryption bool

	// Callbacks and request-ID storage. NewServiceProvider substitutes
	// DefaultErrorCallback, DefaultLoginCallback, and cookieIDStore for any
	// left nil.
	onError ErrorCallback
	onLogin LoginCallback
	idStore IDStore
}
    72  
    73  type Param func(sp *ServiceProvider) error
    74  
    75  // NewServiceProvider returns a ServiceProvider. The configuration of the ServiceProvider
    76  // is a result of combinging settings provided to this method and values parsed from the IDP's metadata.
    77  func NewServiceProvider(params ...Param) (*ServiceProvider, error) {
    78  
    79  	sp := &ServiceProvider{
    80  		sp: &saml.ServiceProvider{},
    81  	}
    82  
    83  	for _, p := range params {
    84  		if err := p(sp); err != nil {
    85  			return nil, err
    86  		}
    87  	}
    88  
    89  	if sp.sp.Certificate == nil || sp.sp.Key == nil {
    90  		return nil, errors.New("a certificate and key must be provided")
    91  	}
    92  
    93  	if sp.sp.IDPMetadata == nil {
    94  		return nil, errors.New("the IDP Metadata must be provided")
    95  	}
    96  
    97  	if sp.acsPath == "" || sp.metadataPath == "" {
    98  		return nil, errors.New("ACS Path and Metadatda path must be provided")
    99  	}
   100  
   101  	if sp.onError == nil {
   102  		sp.onError = DefaultErrorCallback
   103  	}
   104  
   105  	if sp.onLogin == nil {
   106  		sp.onLogin = DefaultLoginCallback
   107  	}
   108  
   109  	if sp.idStore == nil {
   110  		sp.idStore = cookieIDStore{}
   111  	}
   112  
   113  	return sp, nil
   114  }
   115  
   116  func DefaultErrorCallback(w http.ResponseWriter, r *http.Request, err Error) {
   117  	hlog.FromRequest(r).Error().Err(err.Err).Msg("saml error")
   118  	http.Error(w, http.StatusText(err.ResponseCode), err.ResponseCode)
   119  }
   120  
   121  func DefaultLoginCallback(w http.ResponseWriter, r *http.Request, resp *saml.Assertion) {
   122  	w.WriteHeader(http.StatusOK)
   123  }
   124  
   125  func (s *ServiceProvider) getSAMLSettingsForRequest(r *http.Request) *saml.ServiceProvider {
   126  	// make a copy in case different requests have different host headers
   127  	newSP := *s.sp
   128  
   129  	u := url.URL{
   130  		Host:   r.Host,
   131  		Scheme: "http",
   132  	}
   133  
   134  	if s.forceTLS || r.TLS != nil {
   135  		u.Scheme = "https"
   136  	}
   137  
   138  	u.Path = s.metadataPath
   139  	newSP.MetadataURL = u
   140  
   141  	u.Path = s.acsPath
   142  	newSP.AcsURL = u
   143  
   144  	u.Path = s.logoutPath
   145  	newSP.SloURL = u
   146  
   147  	return &newSP
   148  }
   149  
   150  // DoAuth takes an http.ResponseWriter that has not been written to yet, and conducts and SP initiated login
   151  // If the flow proceeds correctly the user should be redirected to the handler provided by ACSHandler().
   152  func (s *ServiceProvider) DoAuth(w http.ResponseWriter, r *http.Request) {
   153  	sp := s.getSAMLSettingsForRequest(r)
   154  
   155  	request, err := sp.MakeAuthenticationRequest(sp.GetSSOBindingLocation(saml.HTTPRedirectBinding))
   156  	if err != nil {
   157  		s.onError(w, r, newError(errors.Wrap(err, "failed to create authentication request"), http.StatusInternalServerError))
   158  		return
   159  	}
   160  
   161  	if err := s.idStore.StoreID(w, r, request.ID); err != nil {
   162  		s.onError(w, r, newError(errors.Wrap(err, "failed to store SAML request id"), http.StatusInternalServerError))
   163  		return
   164  	}
   165  
   166  	target := request.Redirect("")
   167  
   168  	http.Redirect(w, r, target.String(), http.StatusFound)
   169  }
   170  
   171  // ACSHandler returns an http.Handler which is capable of validating and processing SAML Responses.
   172  func (s *ServiceProvider) ACSHandler() http.Handler {
   173  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   174  		sp := s.getSAMLSettingsForRequest(r)
   175  		if err := r.ParseForm(); err != nil {
   176  			s.onError(w, r, newError(errors.Wrap(err, "could not parse ACS form"), http.StatusForbidden))
   177  			return
   178  		}
   179  		id, err := s.idStore.GetID(r)
   180  		if err != nil {
   181  			s.onError(w, r, newError(errors.Wrap(err, "could not retrieve id"), http.StatusForbidden))
   182  			return
   183  		}
   184  		assertion, err := sp.ParseResponse(r, []string{id})
   185  
   186  		if err != nil {
   187  			if parseErr, ok := err.(*saml.InvalidResponseError); ok {
   188  				err = parseErr.PrivateErr
   189  			}
   190  			s.onError(w, r, newError(errors.Wrap(err, "failed to validate SAML assertion"), http.StatusForbidden))
   191  			return
   192  		}
   193  
   194  		s.onLogin(w, r, assertion)
   195  	})
   196  
   197  }
   198  
   199  // MetadataHandler returns an http.Handler which sends the generated metadata XML in response to a request
   200  func (s *ServiceProvider) MetadataHandler() http.Handler {
   201  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   202  		metadata := s.getSAMLSettingsForRequest(r).Metadata()
   203  
   204  		// post-process the metadata to account for issues in crewjam/saml
   205  		// struct navigation is hardcoded for the return value at implementation time
   206  
   207  		if s.logoutPath == "" {
   208  			// remove SingleLogoutService elements if the logout path is not set
   209  			metadata.SPSSODescriptors[0].SSODescriptor.SingleLogoutServices = nil
   210  		}
   211  		if s.disableEncryption {
   212  			// remove encryption keys from metadata
   213  			role := &(metadata.SPSSODescriptors[0].SSODescriptor.RoleDescriptor)
   214  			for i, k := range role.KeyDescriptors {
   215  				if k.Use == "encryption" {
   216  					role.KeyDescriptors = append(role.KeyDescriptors[:i], role.KeyDescriptors[i+1:]...)
   217  				}
   218  			}
   219  		}
   220  
   221  		md, err := xml.Marshal(metadata)
   222  		if err != nil {
   223  			s.onError(w, r, newError(errors.Wrap(err, "failed to generate service provider metadata"), http.StatusInternalServerError))
   224  			return
   225  		}
   226  
   227  		w.Header().Set("Content-Type", "application/xml")
   228  		// The error isn't handlable or recoverable so don't handle it
   229  		// assign to _ to placate errcheck
   230  		_, _ = w.Write(md)
   231  	})
   232  }
   233  

View as plain text