1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package client
16
17 import (
18 "context"
19 "crypto/sha256"
20 "errors"
21 "fmt"
22 "net/http"
23 "os"
24 "time"
25
26 ct "github.com/google/certificate-transparency-go"
27 "github.com/google/certificate-transparency-go/client/configpb"
28 "github.com/google/certificate-transparency-go/jsonclient"
29 "github.com/google/certificate-transparency-go/x509"
30 "google.golang.org/protobuf/encoding/prototext"
31 "google.golang.org/protobuf/proto"
32 )
33
// interval is a half-open range of times, [lower, upper). A nil bound means
// the range is unbounded on that side.
type interval struct {
	lower, upper *time.Time
}
38
39
40
41 func TemporalLogConfigFromFile(filename string) (*configpb.TemporalLogConfig, error) {
42 if len(filename) == 0 {
43 return nil, errors.New("log config filename empty")
44 }
45
46 cfgBytes, err := os.ReadFile(filename)
47 if err != nil {
48 return nil, fmt.Errorf("failed to read log config: %v", err)
49 }
50
51 var cfg configpb.TemporalLogConfig
52 if txtErr := prototext.Unmarshal(cfgBytes, &cfg); txtErr != nil {
53 if binErr := proto.Unmarshal(cfgBytes, &cfg); binErr != nil {
54 return nil, fmt.Errorf("failed to parse TemporalLogConfig from %q as text protobuf (%v) or binary protobuf (%v)", filename, txtErr, binErr)
55 }
56 }
57
58 if len(cfg.Shard) == 0 {
59 return nil, errors.New("empty log config found")
60 }
61 return &cfg, nil
62 }
63
64
65
66
67 type AddLogClient interface {
68 AddChain(ctx context.Context, chain []ct.ASN1Cert) (*ct.SignedCertificateTimestamp, error)
69 AddPreChain(ctx context.Context, chain []ct.ASN1Cert) (*ct.SignedCertificateTimestamp, error)
70 GetAcceptedRoots(ctx context.Context) ([]ct.ASN1Cert, error)
71 }
72
73
74 type TemporalLogClient struct {
75 Clients []*LogClient
76 intervals []interval
77 }
78
79
80
81 func NewTemporalLogClient(cfg *configpb.TemporalLogConfig, hc *http.Client) (*TemporalLogClient, error) {
82 if len(cfg.GetShard()) == 0 {
83 return nil, errors.New("empty config")
84 }
85
86 overall, err := shardInterval(cfg.Shard[0])
87 if err != nil {
88 return nil, fmt.Errorf("cfg.Shard[0] invalid: %v", err)
89 }
90 intervals := make([]interval, 0, len(cfg.Shard))
91 intervals = append(intervals, overall)
92 for i := 1; i < len(cfg.Shard); i++ {
93 interval, err := shardInterval(cfg.Shard[i])
94 if err != nil {
95 return nil, fmt.Errorf("cfg.Shard[%d] invalid: %v", i, err)
96 }
97 if overall.upper == nil {
98 return nil, fmt.Errorf("cfg.Shard[%d] extends an interval with no upper bound", i)
99 }
100 if interval.lower == nil {
101 return nil, fmt.Errorf("cfg.Shard[%d] has no lower bound but extends an interval", i)
102 }
103 if !interval.lower.Equal(*overall.upper) {
104 return nil, fmt.Errorf("cfg.Shard[%d] starts at %v but previous interval ended at %v", i, interval.lower, overall.upper)
105 }
106 overall.upper = interval.upper
107 intervals = append(intervals, interval)
108 }
109 clients := make([]*LogClient, 0, len(cfg.Shard))
110 for i, shard := range cfg.Shard {
111 opts := jsonclient.Options{UserAgent: "ct-go-multilog/1.0"}
112 opts.PublicKeyDER = shard.GetPublicKeyDer()
113 c, err := New(shard.Uri, hc, opts)
114 if err != nil {
115 return nil, fmt.Errorf("failed to create client for cfg.Shard[%d]: %v", i, err)
116 }
117 clients = append(clients, c)
118 }
119 tlc := TemporalLogClient{
120 Clients: clients,
121 intervals: intervals,
122 }
123 return &tlc, nil
124 }
125
126
127
128 func (tlc *TemporalLogClient) GetAcceptedRoots(ctx context.Context) ([]ct.ASN1Cert, error) {
129 type result struct {
130 roots []ct.ASN1Cert
131 err error
132 }
133 results := make(chan result, len(tlc.Clients))
134 for _, c := range tlc.Clients {
135 go func(c *LogClient) {
136 var r result
137 r.roots, r.err = c.GetAcceptedRoots(ctx)
138 results <- r
139 }(c)
140 }
141
142 var allRoots []ct.ASN1Cert
143 seen := make(map[[sha256.Size]byte]bool)
144 for range tlc.Clients {
145 r := <-results
146 if r.err != nil {
147 return nil, r.err
148 }
149 for _, root := range r.roots {
150 h := sha256.Sum256(root.Data)
151 if seen[h] {
152 continue
153 }
154 seen[h] = true
155 allRoots = append(allRoots, root)
156 }
157 }
158 return allRoots, nil
159 }
160
161
162 func (tlc *TemporalLogClient) AddChain(ctx context.Context, chain []ct.ASN1Cert) (*ct.SignedCertificateTimestamp, error) {
163 return tlc.addChain(ctx, ct.X509LogEntryType, ct.AddChainPath, chain)
164 }
165
166
167 func (tlc *TemporalLogClient) AddPreChain(ctx context.Context, chain []ct.ASN1Cert) (*ct.SignedCertificateTimestamp, error) {
168 return tlc.addChain(ctx, ct.PrecertLogEntryType, ct.AddPreChainPath, chain)
169 }
170
171 func (tlc *TemporalLogClient) addChain(ctx context.Context, ctype ct.LogEntryType, path string, chain []ct.ASN1Cert) (*ct.SignedCertificateTimestamp, error) {
172
173 if len(chain) == 0 {
174 return nil, errors.New("missing chain")
175 }
176 cert, err := x509.ParseCertificate(chain[0].Data)
177 if err != nil {
178 return nil, fmt.Errorf("failed to parse initial chain entry: %v", err)
179 }
180 cidx, err := tlc.IndexByDate(cert.NotAfter)
181 if err != nil {
182 return nil, fmt.Errorf("failed to find log to process cert: %v", err)
183 }
184 return tlc.Clients[cidx].addChainWithRetry(ctx, ctype, path, chain)
185 }
186
187
188
189 func (tlc *TemporalLogClient) IndexByDate(when time.Time) (int, error) {
190 for i, interval := range tlc.intervals {
191 if (interval.lower != nil) && when.Before(*interval.lower) {
192 continue
193 }
194 if (interval.upper != nil) && !when.Before(*interval.upper) {
195 continue
196 }
197 return i, nil
198 }
199 return -1, fmt.Errorf("no log found encompassing date %v", when)
200 }
201
202 func shardInterval(cfg *configpb.LogShardConfig) (interval, error) {
203 var interval interval
204 if cfg.NotAfterStart != nil {
205 if err := cfg.NotAfterStart.CheckValid(); err != nil {
206 return interval, fmt.Errorf("failed to parse NotAfterStart: %v", err)
207 }
208 t := cfg.NotAfterStart.AsTime()
209 interval.lower = &t
210 }
211 if cfg.NotAfterLimit != nil {
212 if err := cfg.NotAfterLimit.CheckValid(); err != nil {
213 return interval, fmt.Errorf("failed to parse NotAfterLimit: %v", err)
214 }
215 t := cfg.NotAfterLimit.AsTime()
216 interval.upper = &t
217 }
218
219 if interval.lower != nil && interval.upper != nil && !(*interval.lower).Before(*interval.upper) {
220 return interval, errors.New("inverted interval")
221 }
222 return interval, nil
223 }
224
View as plain text