1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package submission
16
import (
	"context"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	ct "github.com/google/certificate-transparency-go"
	"github.com/google/certificate-transparency-go/ctpolicy"
)
27
// PostBatchInterval is the spacing between successive submission batches in
// a group race. Within a group, the log at position i waits
// postInterval(i, parallelStart, PostBatchInterval) before it is tried, so the
// first parallelStart logs start at once and each later log follows one
// PostBatchInterval after the previous one.
const PostBatchInterval = time.Second
34
35
36 type Submitter interface {
37 SubmitToLog(ctx context.Context, logURL string, chain []ct.ASN1Cert, asPreChain bool) (*ct.SignedCertificateTimestamp, error)
38 }
39
40
41 type submissionResult struct {
42 sct *ct.SignedCertificateTimestamp
43 err error
44 }
45
// groupState is the verdict that groupRace reports for one policy group.
type groupState struct {
	Name    string // name of the policy group
	Success bool   // true when the group collected enough SCTs
}
50
51
52
53
54 type safeSubmissionState struct {
55 mu sync.Mutex
56 logToGroups map[string]ctpolicy.GroupSet
57 groupNeeds map[string]int
58
59 results map[string]*submissionResult
60 cancels map[string]context.CancelFunc
61 }
62
63 func newSafeSubmissionState(groups ctpolicy.LogPolicyData) *safeSubmissionState {
64 var s safeSubmissionState
65 s.logToGroups = ctpolicy.GroupByLogs(groups)
66 s.groupNeeds = make(map[string]int)
67 for _, g := range groups {
68 s.groupNeeds[g.Name] = g.MinInclusions
69 }
70 s.results = make(map[string]*submissionResult)
71 s.cancels = make(map[string]context.CancelFunc)
72 return &s
73 }
74
75
76
77 func (sub *safeSubmissionState) request(logURL string, cancel context.CancelFunc) bool {
78 sub.mu.Lock()
79 defer sub.mu.Unlock()
80 if sub.results[logURL] != nil {
81
82 return false
83 }
84 sub.results[logURL] = &submissionResult{}
85 isAwaited := false
86 for g := range sub.logToGroups[logURL] {
87 if sub.groupNeeds[g] > 0 {
88 isAwaited = true
89 break
90 }
91 }
92 if !isAwaited {
93
94 return false
95 }
96 sub.cancels[logURL] = cancel
97 return true
98 }
99
100
101
102
// setResult records the outcome of a submission to logURL and updates the
// remaining SCT needs of the groups that contain the log.
//
// A failed submission (sct == nil) is stored as-is and counts toward no group.
// A successful SCT is credited in two steps:
//   - Non-base groups: every non-base group containing logURL has its need
//     decremented. The SCT is stored if at least one of those groups still
//     needed it.
//   - Base group: if logURL also belongs to the base group and the SCT was
//     stored in the first step, it counts toward base as well. Otherwise it is
//     stored (and counted for base) only while the base need exceeds the
//     combined outstanding needs of all non-base groups.
//
// Finally, every stored cancel func whose log is no longer awaited by any
// group is invoked and cleared.
//
// NOTE(review): sub.results[logURL] is dereferenced in the base-group branch.
// It is populated by request(), so this assumes request(logURL, ...) was called
// first. Calling setResult for an unrequested log in the base group would panic.
func (sub *safeSubmissionState) setResult(logURL string, sct *ct.SignedCertificateTimestamp, err error) {
	sub.mu.Lock()
	defer sub.mu.Unlock()
	if sct == nil {
		// Record the failure; it does not satisfy any group.
		sub.results[logURL] = &submissionResult{sct: sct, err: err}
		return
	}
	// Credit the SCT to every non-base group this log belongs to.
	for groupName := range sub.logToGroups[logURL] {
		// The base group is handled separately below.
		if groupName == ctpolicy.BaseName {
			continue
		}
		if sub.groupNeeds[groupName] > 0 {
			sub.results[logURL] = &submissionResult{sct: sct, err: err}
		}
		// May go negative once the group is satisfied; elsewhere only positive
		// counts are treated as outstanding.
		sub.groupNeeds[groupName]--
	}

	// Base-group accounting.
	if sub.logToGroups[logURL][ctpolicy.BaseName] {
		if sub.results[logURL].sct != nil {
			// Already kept for a non-base group, so it also counts toward base.
			sub.groupNeeds[ctpolicy.BaseName]--
		} else if sub.groupNeeds[ctpolicy.BaseName] > 0 {
			minInclusionsForOtherGroup := 0
			for g, cnt := range sub.groupNeeds {
				if g != ctpolicy.BaseName && cnt > 0 {
					minInclusionsForOtherGroup += cnt
				}
			}
			// Keep a base-only SCT only if base needs more than the other
			// groups still have to supply. NOTE(review): presumably because
			// SCTs kept for non-base groups also count toward base (see the
			// branch above); confirm against ctpolicy's group semantics.
			if sub.groupNeeds[ctpolicy.BaseName] > minInclusionsForOtherGroup {
				sub.results[logURL] = &submissionResult{sct: sct, err: err}
				sub.groupNeeds[ctpolicy.BaseName]--
			}
		}
	}

	// Cancel stored submission contexts for logs that no group still awaits.
	for logURL, groupSet := range sub.logToGroups {
		isAwaited := false
		for g := range groupSet {
			if sub.groupNeeds[g] > 0 {
				isAwaited = true
				break
			}
		}
		if !isAwaited && sub.cancels[logURL] != nil {
			sub.cancels[logURL]()
			sub.cancels[logURL] = nil
		}
	}
}
159
160
161 func (sub *safeSubmissionState) groupComplete(groupName string) bool {
162 sub.mu.Lock()
163 defer sub.mu.Unlock()
164 needs, ok := sub.groupNeeds[groupName]
165 if !ok {
166 return true
167 }
168 return needs <= 0
169 }
170
171 func (sub *safeSubmissionState) collectSCTs() []*AssignedSCT {
172 sub.mu.Lock()
173 defer sub.mu.Unlock()
174 scts := []*AssignedSCT{}
175 for logURL, r := range sub.results {
176 if r != nil && r.sct != nil {
177 scts = append(scts, &AssignedSCT{LogURL: logURL, SCT: r.sct})
178 }
179 }
180 return scts
181 }
182
183
184
185
// postInterval returns how long the log at position idx waits before it is
// submitted to.
//
// Positions below parallelStart start immediately. Position parallelStart
// waits one dur, the next position waits two, and so on.
func postInterval(idx int, parallelStart int, dur time.Duration) time.Duration {
	batch := idx - parallelStart + 1
	if batch <= 0 {
		return 0
	}
	return time.Duration(batch) * dur
}
192
193
194
195 func groupRace(ctx context.Context, chain []ct.ASN1Cert, asPreChain bool,
196 group *ctpolicy.LogGroupInfo, parallelStart int,
197 state *safeSubmissionState, submitter Submitter) groupState {
198
199
200 session := group.GetSubmissionSession()
201 type count struct{}
202 counter := make(chan count, len(session))
203
204 countCall := func() {
205 counter <- count{}
206 }
207
208 for i, logURL := range session {
209 subCtx, cancel := context.WithCancel(ctx)
210 go func(i int, logURL string) {
211 defer countCall()
212 timeoutchan := time.After(postInterval(i, parallelStart, PostBatchInterval))
213
214 select {
215 case <-subCtx.Done():
216 return
217 case <-timeoutchan:
218 }
219 if state.groupComplete(group.Name) {
220 cancel()
221 return
222 }
223 if firstRequested := state.request(logURL, cancel); !firstRequested {
224 return
225 }
226 sct, err := submitter.SubmitToLog(subCtx, logURL, chain, asPreChain)
227
228 state.setResult(logURL, sct, err)
229 }(i, logURL)
230 }
231
232
233 for range session {
234 select {
235 case <-ctx.Done():
236 return groupState{Name: group.Name, Success: state.groupComplete(group.Name)}
237 case <-counter:
238 if state.groupComplete(group.Name) {
239 return groupState{Name: group.Name, Success: true}
240 }
241 }
242 }
243 return groupState{Name: group.Name, Success: state.groupComplete(group.Name)}
244 }
245
246 func parallelNums(groups ctpolicy.LogPolicyData) map[string]int {
247 nums := make(map[string]int)
248 var subsetSum int
249 for _, g := range groups {
250 nums[g.Name] = g.MinInclusions
251 if !g.IsBase {
252 subsetSum += g.MinInclusions
253 }
254 }
255 if _, hasBase := nums[ctpolicy.BaseName]; hasBase {
256 if nums[ctpolicy.BaseName] >= subsetSum {
257 nums[ctpolicy.BaseName] -= subsetSum
258 } else {
259 nums[ctpolicy.BaseName] = 0
260 }
261 }
262 return nums
263 }
264
265
266 type AssignedSCT struct {
267 LogURL string
268 SCT *ct.SignedCertificateTimestamp
269 }
270
// completenessError reports which groups failed to collect enough SCTs.
//
// groupComplete maps each group name to whether that group succeeded. The
// function returns nil when every group succeeded or the map is empty.
// Otherwise it returns an error naming the failed groups. The names are
// sorted so the message is deterministic despite Go's randomized map
// iteration order.
func completenessError(groupComplete map[string]bool) error {
	var failedGroups []string
	for name, success := range groupComplete {
		if !success {
			failedGroups = append(failedGroups, name)
		}
	}
	if len(failedGroups) == 0 {
		return nil
	}
	sort.Strings(failedGroups)
	return fmt.Errorf("log-group(s) %s didn't receive enough SCTs", strings.Join(failedGroups, ", "))
}
283
284
285
286
287 func GetSCTs(ctx context.Context, submitter Submitter, chain []ct.ASN1Cert, asPreChain bool, groups ctpolicy.LogPolicyData) ([]*AssignedSCT, error) {
288 groupComplete := make(map[string]bool)
289 for _, g := range groups {
290 groupComplete[g.Name] = false
291 }
292
293 parallelNums := parallelNums(groups)
294
295 groupEvents := make(chan groupState, len(groups))
296 submissions := newSafeSubmissionState(groups)
297 for _, g := range groups {
298 go func(g *ctpolicy.LogGroupInfo) {
299 groupEvents <- groupRace(ctx, chain, asPreChain, g, parallelNums[g.Name], submissions, submitter)
300 }(g)
301 }
302
303
304
305 for i := 0; i < len(groups); i++ {
306 select {
307 case <-ctx.Done():
308 return submissions.collectSCTs(), completenessError(groupComplete)
309 case g := <-groupEvents:
310 groupComplete[g.Name] = g.Success
311 }
312 }
313 return submissions.collectSCTs(), completenessError(groupComplete)
314 }
315
View as plain text