1 package controllers
2
3 import (
4 "context"
5 "crypto/x509"
6 "encoding/base64"
7 "errors"
8 "fmt"
9 "strings"
10 "time"
11
12 "github.com/cert-manager/cert-manager/pkg/util/pki"
13 issuerapi "github.com/cert-manager/issuer-lib/api/v1alpha1"
14 "github.com/cert-manager/issuer-lib/controllers"
15 "github.com/cert-manager/issuer-lib/controllers/signer"
16 ctrl "sigs.k8s.io/controller-runtime"
17 "sigs.k8s.io/controller-runtime/pkg/client"
18
19 edgeissuerapi "edge-infra.dev/pkg/edge/edge-issuer/api/v1alpha1"
20 )
21
// Sentinel errors used by Sign. Sign wraps them with %w, so callers can
// match the failure stage with errors.Is.
var (
	// errSignerBuilder marks a failure to construct a Signer from the cached
	// CA key and certificate.
	errSignerBuilder = errors.New("failed to build the signer")

	// errSignerSign marks a failure while signing the certificate template.
	errSignerSign = errors.New("failed to sign")
)
26
// Signer signs a certificate template with a CA. Sign parses the returned
// bytes as a single PEM certificate chain.
type Signer interface {
	Sign(template *x509.Certificate) ([]byte, error)
}

// SignerBuilder creates a Signer from the CA private key and CA certificate
// (PEM-encoded, per the parameter names). duration is the duration Sign
// obtains from the CertificateRequest.
type SignerBuilder func(keyPEM, certPEM []byte, duration time.Duration) (Signer, error)
32
33 type Issuer struct {
34 SignerBuilder SignerBuilder
35 CAPrivateKey []byte
36 CACert []byte
37 Expiration time.Time
38 Config *Config
39 client client.Client
40 CACertRef string
41 SecretManager secretManager
42 }
43
44 func (o Issuer) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error {
45 o.client = mgr.GetClient()
46
47 return (&controllers.CombinedController{
48 IssuerTypes: []issuerapi.Issuer{&edgeissuerapi.EdgeIssuer{}},
49 ClusterIssuerTypes: []issuerapi.Issuer{&edgeissuerapi.EdgeClusterIssuer{}},
50
51 FieldOwner: "edgeissuer.edge-issuer.edge.ncr.com",
52 MaxRetryDuration: 1 * time.Minute,
53
54 Sign: o.Sign,
55 Check: o.Check,
56 EventRecorder: mgr.GetEventRecorderFor("edgeissuer.edge-issuer.edge.ncr.com"),
57 }).SetupWithManager(ctx, mgr)
58 }
59
60
61
62 func (o *Issuer) Check(ctx context.Context, _ issuerapi.Issuer) error {
63 log := ctrl.Log.WithName("edge-issuer")
64 log.Info("Checking if CA is available")
65
66 if o.SecretManager == nil {
67 o.SecretManager = &gcpSecretManager{}
68 }
69
70
71 ca, key, expiration, version, err := o.getCAInfo(ctx, o.Config.BannerID)
72 if err != nil {
73 return err
74 }
75
76 o.CACert = []byte(ca)
77 o.CAPrivateKey = []byte(key)
78 o.Expiration = expiration
79 log.Info("Successfully checked CA and stored CA info", "version", version)
80 return nil
81 }
82
83
84
85
86
87
88
89 func (o *Issuer) Sign(ctx context.Context, cr signer.CertificateRequestObject, _ issuerapi.Issuer) (signer.PEMBundle, error) {
90 log := ctrl.Log.WithName("edge-issuer")
91
92 caCertRef, _, _, err := o.getCAInfoFromDB(ctx, o.Config.BannerID)
93 if err != nil {
94 return signer.PEMBundle{}, fmt.Errorf("failed to get ca cert ref from database: %w", err)
95 }
96
97
98 if o.Expiration.Before(time.Now()) || o.CACertRef != caCertRef {
99 ca, key, expiration, version, err := o.getCAInfo(ctx, o.Config.BannerID)
100 if err != nil {
101 return signer.PEMBundle{}, err
102 }
103
104 o.CACert = []byte(ca)
105 o.CAPrivateKey = []byte(key)
106 o.Expiration = expiration
107
108 log.Info("Successfully updated CA info to latest CA certificate", "version", version)
109 }
110
111 certTemplate, duration, _, err := cr.GetRequest()
112 if err != nil {
113 return signer.PEMBundle{}, err
114 }
115 signerObj, err := o.SignerBuilder(o.CAPrivateKey, o.CACert, duration)
116 if err != nil {
117 return signer.PEMBundle{}, fmt.Errorf("%w: %v", errSignerBuilder, err)
118 }
119
120 signed, err := signerObj.Sign(certTemplate)
121 if err != nil {
122 return signer.PEMBundle{}, fmt.Errorf("%w: %v", errSignerSign, err)
123 }
124
125 bundle, err := pki.ParseSingleCertificateChainPEM(signed)
126 if err != nil {
127 return signer.PEMBundle{}, err
128 }
129
130 return signer.PEMBundle(bundle), nil
131 }
132
133
134 func (o *Issuer) getCAInfo(ctx context.Context, bannerID string) (string, string, time.Time, string, error) {
135 var err error
136
137
138 certRef, privateKeyRef, expiration, err := o.getCAInfoFromDB(ctx, bannerID)
139 if err != nil {
140 return "", "", time.Time{}, "", fmt.Errorf("failed to get certRef, privateKeyRef, and expiration from db: %w", err)
141 }
142
143 smClient, err := o.SecretManager.NewWithOptions(ctx, o.Config.TopLevelProjectID)
144 if err != nil {
145 return "", "", time.Time{}, "", fmt.Errorf("failed to create secret manager client: %w", err)
146 }
147
148
149 ca, caVersion, err := getValueFromDBRef(ctx, smClient, certRef)
150 if err != nil {
151 return "", "", time.Time{}, "", fmt.Errorf("failed to get ca cert from db ref: %w", err)
152 }
153
154 key, _, err := getValueFromDBRef(ctx, smClient, privateKeyRef)
155 if err != nil {
156 return "", "", time.Time{}, "", fmt.Errorf("failed to get private key from db ref: %w", err)
157 }
158
159 return ca, key, expiration, caVersion, nil
160 }
161
162 func (o *Issuer) getCAInfoFromDB(ctx context.Context, bannerID string) (string, string, time.Time, error) {
163 var caPoolID, certRef, privateKeyRef string
164
165
166 err := o.Config.DB.QueryRowContext(ctx, "SELECT ca_pool_edge_id FROM ca_pools WHERE banner_edge_id = $1", bannerID).Scan(&caPoolID)
167 if err != nil {
168 return "", "", time.Time{}, err
169 }
170
171
172 var expiration time.Time
173 err = o.Config.DB.QueryRowContext(ctx, "SELECT cert_ref, private_key_ref, expiration FROM ca_certificates WHERE ca_pool_edge_id = $1 AND status = 'active'", caPoolID).Scan(&certRef, &privateKeyRef, &expiration)
174 if err != nil {
175 return "", "", time.Time{}, err
176 }
177 return certRef, privateKeyRef, expiration, nil
178 }
179
180
181 func getValueFromDBRef(ctx context.Context, smClient secretManagerClient, secretRef string) (string, string, error) {
182 secretName, secretVersion, err := splitRef(secretRef)
183 if err != nil {
184 return "", "", fmt.Errorf("failed to split secretRef ref into secret name and version: %w", err)
185 }
186
187 secretValue, err := smClient.GetSecretVersionValue(ctx, secretName, secretVersion)
188 if err != nil {
189 return "", "", fmt.Errorf("failed to get latest secret value for key: %w", err)
190 }
191
192 decodedSecretValue, err := base64.StdEncoding.DecodeString(string(secretValue))
193 if err != nil {
194 return "", "", fmt.Errorf("failed to decode secret value: %w", err)
195 }
196
197 return string(decodedSecretValue), secretVersion, nil
198 }
199
// splitRef splits a secret reference of the form "<secret-name>-<version>"
// at its LAST hyphen. This lets secret names themselves contain hyphens,
// e.g. "edge-ca-cert-3" -> ("edge-ca-cert", "3").
//
// It returns an error when ref contains no hyphen, or when either the name
// or the version would be empty. Neither case can address a secret version.
func splitRef(ref string) (string, string, error) {
	i := strings.LastIndexByte(ref, '-')
	if i == -1 {
		return "", "", fmt.Errorf("invalid secret ref format %q: expected <name>-<version>", ref)
	}

	secretName, version := ref[:i], ref[i+1:]
	if secretName == "" || version == "" {
		return "", "", fmt.Errorf("invalid secret ref format %q: name and version must be non-empty", ref)
	}

	return secretName, version, nil
}
213
View as plain text