1 package tls
2
3 import (
4 "bytes"
5 "crypto/ecdsa"
6 "crypto/rand"
7 "crypto/rsa"
8 "crypto/x509"
9 "encoding/pem"
10 "errors"
11 "fmt"
12 "os"
13 "path/filepath"
14 "time"
15 )
16
type (
	// privateKeyEC wraps an ECDSA private key so that it can be used as a
	// GenericPrivateKey.
	privateKeyEC struct {
		*ecdsa.PrivateKey
	}

	// privateKeyRSA wraps an RSA private key so that it can be used as a
	// GenericPrivateKey.
	privateKeyRSA struct {
		*rsa.PrivateKey
	}

	// GenericPrivateKey abstracts over the private key algorithms supported
	// by this package. Its methods are unexported, so it can only be
	// implemented by types declared in this package.
	GenericPrivateKey interface {
		// matchesCertificate reports whether the certificate's public key
		// is the counterpart of this private key.
		matchesCertificate(*x509.Certificate) bool
		// marshal returns the binary encoding of the private key, or an
		// error if the key cannot be encoded.
		marshal() ([]byte, error)
	}

	// Cred bundles a private key with the certificate (and trust chain)
	// it belongs to.
	Cred struct {
		PrivateKey GenericPrivateKey
		Crt
	}

	// Crt bundles a certificate with its trust chain.
	//
	// TrustChain holds the certificates above Certificate ordered so that
	// the direct issuer comes last: SignCrt builds it by appending the
	// issuer's certificate to the issuer's own chain, and the PEM helpers
	// reverse it so that encoded output lists the issuer right after the
	// certificate itself.
	Crt struct {
		Certificate *x509.Certificate
		TrustChain  []*x509.Certificate
	}
)
49
50 func (k privateKeyEC) matchesCertificate(c *x509.Certificate) bool {
51 pub, ok := c.PublicKey.(*ecdsa.PublicKey)
52 return ok && pub.X.Cmp(k.X) == 0 && pub.Y.Cmp(k.Y) == 0
53 }
54
55 func (k privateKeyEC) marshal() ([]byte, error) {
56 return x509.MarshalECPrivateKey(k.PrivateKey)
57 }
58
59 func (k privateKeyRSA) matchesCertificate(c *x509.Certificate) bool {
60 pub, ok := c.PublicKey.(*rsa.PublicKey)
61 return ok && pub.N.Cmp(k.N) == 0 && pub.E == k.E
62 }
63
64 func (k privateKeyRSA) marshal() ([]byte, error) {
65 return x509.MarshalPKCS1PrivateKey(k.PrivateKey), nil
66 }
67
68
69 func validCredOrPanic(ecKey *ecdsa.PrivateKey, crt Crt) Cred {
70 k := privateKeyEC{ecKey}
71 if !k.matchesCertificate(crt.Certificate) {
72 panic("Cert's public key does not match private key")
73 }
74 return Cred{Crt: crt, PrivateKey: k}
75 }
76
77
78 func (crt *Crt) CertPool() *x509.CertPool {
79 p := x509.NewCertPool()
80 p.AddCert(crt.Certificate)
81 for _, c := range crt.TrustChain {
82 p.AddCert(c)
83 }
84 return p
85 }
86
87
88 func (crt *Crt) Verify(roots *x509.CertPool, name string, currentTime time.Time) error {
89 i := x509.NewCertPool()
90 for _, c := range crt.TrustChain {
91 i.AddCert(c)
92 }
93 vo := x509.VerifyOptions{Roots: roots, Intermediates: i, DNSName: name, CurrentTime: currentTime}
94 _, err := crt.Certificate.Verify(vo)
95
96 if currentTime.IsZero() {
97 currentTime = time.Now()
98 }
99
100 if crtExpiryError(err) {
101 return fmt.Errorf("%w - Current Time : %s - Invalid before %s - Invalid After %s", err, currentTime, crt.Certificate.NotBefore, crt.Certificate.NotAfter)
102 }
103 return err
104 }
105
106
107 func (crt *Crt) ExtractRaw() [][]byte {
108 chain := make([][]byte, len(crt.TrustChain)+1)
109 chain[0] = crt.Certificate.Raw
110 for i, c := range crt.TrustChain {
111 chain[len(crt.TrustChain)-i] = c.Raw
112 }
113 return chain
114 }
115
116
117
118 func (crt *Crt) EncodePEM() string {
119 buf := bytes.Buffer{}
120 encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: crt.Certificate.Raw})
121
122
123 n := len(crt.TrustChain)
124 for i := n - 1; i >= 0; i-- {
125 encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: crt.TrustChain[i].Raw})
126 }
127
128 return buf.String()
129 }
130
131
132 func (crt *Crt) EncodeCertificatePEM() string {
133 return EncodeCertificatesPEM(crt.Certificate)
134 }
135
136
137 func (cred *Cred) EncodePrivateKeyPEM() string {
138 b, err := cred.PrivateKey.marshal()
139 if err != nil {
140 panic(fmt.Sprintf("Invalid private key: %s", err))
141 }
142
143 return string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b}))
144 }
145
146
147 func (cred *Cred) EncodePrivateKeyP8() ([]byte, error) {
148 return x509.MarshalPKCS8PrivateKey(cred.PrivateKey)
149 }
150
151
152
153
154 func (cred *Cred) SignCrt(template *x509.Certificate) (Crt, error) {
155 crtb, err := x509.CreateCertificate(
156 rand.Reader,
157 template,
158 cred.Crt.Certificate,
159 template.PublicKey,
160 cred.PrivateKey,
161 )
162 if err != nil {
163 return Crt{}, err
164 }
165
166 c, err := x509.ParseCertificate(crtb)
167 if err != nil {
168 return Crt{}, err
169 }
170
171 crt := Crt{
172 Certificate: c,
173 TrustChain: append(cred.Crt.TrustChain, cred.Crt.Certificate),
174 }
175 return crt, nil
176 }
177
178
179 func ValidateAndCreateCreds(crt, key string) (*Cred, error) {
180 k, err := DecodePEMKey(key)
181 if err != nil {
182 return nil, err
183 }
184
185 c, err := DecodePEMCrt(crt)
186 if err != nil {
187 return nil, err
188 }
189
190 if !k.matchesCertificate(c.Certificate) {
191 return nil, errors.New("tls: Public and private key do not match")
192 }
193 return &Cred{PrivateKey: k, Crt: *c}, nil
194 }
195
196
197 func ReadPEMCreds(keyPath, crtPath string) (*Cred, error) {
198 keyb, err := os.ReadFile(filepath.Clean(keyPath))
199 if err != nil {
200 return nil, err
201 }
202
203 crtb, err := os.ReadFile(filepath.Clean(crtPath))
204 if err != nil {
205 return nil, err
206 }
207
208 return ValidateAndCreateCreds(string(crtb), string(keyb))
209 }
210
211
212 func DecodePEMCrt(txt string) (*Crt, error) {
213 certs, err := DecodePEMCertificates(txt)
214 if err != nil {
215 return nil, err
216 }
217 if len(certs) == 0 {
218 return nil, errors.New("No certificates found")
219 }
220
221 crt := Crt{
222 Certificate: certs[0],
223 TrustChain: make([]*x509.Certificate, len(certs)-1),
224 }
225
226
227 certs = certs[1:]
228 for i, c := range certs {
229 crt.TrustChain[len(certs)-i-1] = c
230 }
231
232 return &crt, nil
233 }
234
// crtExpiryError reports whether err is, or wraps, an
// x509.CertificateInvalidError whose reason is x509.Expired.
func crtExpiryError(err error) bool {
	var invalid x509.CertificateInvalidError
	return errors.As(err, &invalid) && invalid.Reason == x509.Expired
}
242
View as plain text