...

Source file src/github.com/palantir/go-baseapp/baseapp/auth/saml/params.go

Documentation: github.com/palantir/go-baseapp/baseapp/auth/saml

     1  // Copyright 2019 Palantir Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package saml
    16  
    17  import (
    18  	"crypto/rsa"
    19  	"crypto/x509"
    20  	"encoding/pem"
    21  	"encoding/xml"
    22  	"io/ioutil"
    23  	"net/http"
    24  
    25  	"github.com/crewjam/saml"
    26  	"github.com/pkg/errors"
    27  )
    28  
    29  func WithCertificateFromFile(path string) Param {
    30  
    31  	return func(sp *ServiceProvider) error {
    32  		certBytes, err := ioutil.ReadFile(path)
    33  		if err != nil {
    34  			return errors.Wrap(err, "could not read provided certificate file")
    35  		}
    36  
    37  		return WithCertificateFromBytes(certBytes)(sp)
    38  	}
    39  
    40  }
    41  
    42  func WithCertificateFromBytes(certBytes []byte) Param {
    43  	return func(sp *ServiceProvider) error {
    44  		certPem, _ := pem.Decode(certBytes)
    45  		if certPem == nil {
    46  			return errors.New("could not PEM decode the provided certificate")
    47  		}
    48  
    49  		cert, err := x509.ParseCertificate(certPem.Bytes)
    50  		sp.sp.Certificate = cert
    51  		return errors.Wrap(err, "failed to parse provided certificate")
    52  	}
    53  
    54  }
    55  
    56  func WithKeyFromFile(path string) Param {
    57  	return func(sp *ServiceProvider) error {
    58  		keyBytes, err := ioutil.ReadFile(path)
    59  		if err != nil {
    60  			return errors.Wrap(err, "could not read provided key file")
    61  		}
    62  
    63  		return WithKeyFromBytes(keyBytes)(sp)
    64  	}
    65  
    66  }
    67  
    68  func WithKeyFromBytes(keyBytes []byte) Param {
    69  
    70  	return func(sp *ServiceProvider) error {
    71  		keyPem, _ := pem.Decode(keyBytes)
    72  		if keyPem == nil {
    73  			return errors.New("could not PEM decode the provided private key")
    74  		}
    75  
    76  		key, err := x509.ParsePKCS8PrivateKey(keyPem.Bytes)
    77  		if err != nil {
    78  			return errors.Wrap(err, "could not parse provided private key")
    79  		}
    80  
    81  		rsaKey, ok := key.(*rsa.PrivateKey)
    82  		sp.sp.Key = rsaKey
    83  		if !ok {
    84  			return errors.New("provided private key was not an RSA key")
    85  		}
    86  		return nil
    87  	}
    88  
    89  }
    90  
    91  func WithEntityFromURL(url string) Param {
    92  
    93  	return func(sp *ServiceProvider) error {
    94  		resp, err := http.Get(url)
    95  		if err != nil {
    96  			return errors.Wrap(err, "failed to download IDP metadata")
    97  		}
    98  
    99  		defer func() { _ = resp.Body.Close() }()
   100  		descriptor, err := ioutil.ReadAll(resp.Body)
   101  		if err != nil {
   102  			return errors.Wrap(err, "failed to download IDP metadata")
   103  		}
   104  
   105  		return WithEntityFromBytes(descriptor)(sp)
   106  	}
   107  
   108  }
   109  
   110  func WithEntityFromBytes(metadata []byte) Param {
   111  
   112  	return func(sp *ServiceProvider) error {
   113  		var entity saml.EntityDescriptor
   114  
   115  		if err := xml.Unmarshal(metadata, &entity); err != nil {
   116  			var entities saml.EntitiesDescriptor
   117  
   118  			if err := xml.Unmarshal(metadata, &entities); err != nil {
   119  				return errors.Wrap(err, "could not parse returned metadata")
   120  			}
   121  
   122  			if len(entities.EntityDescriptors) == 0 {
   123  				return errors.New("metadata did not contain an entity")
   124  			}
   125  
   126  			entity = entities.EntityDescriptors[0]
   127  
   128  		}
   129  		sp.sp.IDPMetadata = &entity
   130  		return nil
   131  	}
   132  
   133  }
   134  
   135  // WithACSPath sets the path where the assertion consumer handler for the
   136  // service provider is registered. The path is included in generated metadata.
   137  // This is a required parameter.
   138  func WithACSPath(path string) Param {
   139  	return func(sp *ServiceProvider) error {
   140  		sp.acsPath = path
   141  		return nil
   142  	}
   143  }
   144  
   145  // WithMetadataPath sets the path where the metadata handler for the service
   146  // provider is registered. The path is included in generated metadata. This is
   147  // a required parameter.
   148  func WithMetadataPath(path string) Param {
   149  	return func(sp *ServiceProvider) error {
   150  		sp.metadataPath = path
   151  		return nil
   152  	}
   153  }
   154  
   155  // WithLogoutPath sets the path where the single logout handler for the service
   156  // provider is registered. The path is included in generated metadata.
   157  func WithLogoutPath(path string) Param {
   158  	return func(sp *ServiceProvider) error {
   159  		sp.logoutPath = path
   160  		return nil
   161  	}
   162  }
   163  
   164  func WithForceTLS(force bool) Param {
   165  	return func(sp *ServiceProvider) error {
   166  		sp.forceTLS = force
   167  		return nil
   168  	}
   169  }
   170  
   171  func WithLoginCallback(lcb LoginCallback) Param {
   172  	return func(sp *ServiceProvider) error {
   173  		sp.onLogin = lcb
   174  		return nil
   175  	}
   176  }
   177  
   178  func WithErrorCallback(ecb ErrorCallback) Param {
   179  	return func(sp *ServiceProvider) error {
   180  		sp.onError = ecb
   181  		return nil
   182  	}
   183  }
   184  
   185  func WithIDStore(store IDStore) Param {
   186  	return func(sp *ServiceProvider) error {
   187  		sp.idStore = store
   188  		return nil
   189  	}
   190  }
   191  
   192  func WithServiceProvider(s *saml.ServiceProvider) Param {
   193  	return func(sp *ServiceProvider) error {
   194  		sp.sp = s
   195  		return nil
   196  	}
   197  }
   198  
   199  func WithNameIDFormat(n saml.NameIDFormat) Param {
   200  	return func(sp *ServiceProvider) error {
   201  		sp.sp.AuthnNameIDFormat = n
   202  		return nil
   203  	}
   204  }
   205  
   206  // WithEncryptedAssertions enables or disables assertion encryption. By
   207  // default, encryption is enabled. When set to false, the encryption key is not
   208  // included in generated metadata.
   209  func WithEncryptedAssertions(encrypt bool) Param {
   210  	return func(sp *ServiceProvider) error {
   211  		sp.disableEncryption = !encrypt
   212  		return nil
   213  	}
   214  }
   215  
   216  func WithForceAuthn(force bool) Param {
   217  	return func(sp *ServiceProvider) error {
   218  		sp.sp.ForceAuthn = &force
   219  		return nil
   220  	}
   221  }
   222  

View as plain text