...
1
2
3
4
5
6
7 package unified
8
9 import (
10 "context"
11 "fmt"
12
13 "go.mongodb.org/mongo-driver/mongo"
14 )
15
16
// ctxKey is the key type for every value this package stores in a
// context.Context. A dedicated unexported type, rather than a plain string,
// keeps these keys from colliding with context keys defined by other packages.
type ctxKey string

// Context keys for per-test state. Except where noted, each key is populated
// by newTestContext and read back through the accessor functions in this file.
const (
	// operationalFailPointKey maps to a bool holding the
	// hasOperationalFailPoint flag that was passed to newTestContext.
	operationalFailPointKey ctxKey = "operational-fail-point"

	// entitiesKey maps to the *EntityMap for the running test.
	entitiesKey ctxKey = "test-entities"

	// failPointsKey maps to a map[string]*mongo.Client from fail point name
	// to the client registered for it through addFailPoint.
	failPointsKey ctxKey = "test-failpoints"

	// targetedFailPointsKey maps to a map[string]string from fail point name
	// to the host registered for it through addTargetedFailPoint.
	targetedFailPointsKey ctxKey = "test-targeted-failpoints"

	// clientLogMessagesKey maps to the []*clientLogMessages passed to
	// newTestContext. Despite the "count" in its string value, it stores the
	// full expectations; counts are derived on demand.
	clientLogMessagesKey ctxKey = "test-expected-log-message-count"

	// ignoreLogMessagesKey is neither written nor read in this file; the
	// ignored messages are taken from the clientLogMessagesKey value instead.
	// NOTE(review): confirm whether other files use it before removing.
	ignoreLogMessagesKey ctxKey = "test-ignore-log-message-count"
)
34
35
36
37 func newTestContext(
38 ctx context.Context,
39 entityMap *EntityMap,
40 clientLogMessages []*clientLogMessages,
41 hasOperationalFailPoint bool,
42 ) context.Context {
43 ctx = context.WithValue(ctx, operationalFailPointKey, hasOperationalFailPoint)
44 ctx = context.WithValue(ctx, entitiesKey, entityMap)
45 ctx = context.WithValue(ctx, failPointsKey, make(map[string]*mongo.Client))
46 ctx = context.WithValue(ctx, targetedFailPointsKey, make(map[string]string))
47 ctx = context.WithValue(ctx, clientLogMessagesKey, clientLogMessages)
48 return ctx
49 }
50
51 func addFailPoint(ctx context.Context, failPoint string, client *mongo.Client) error {
52 failPoints := ctx.Value(failPointsKey).(map[string]*mongo.Client)
53 if _, ok := failPoints[failPoint]; ok {
54 return fmt.Errorf("fail point %q already exists in tracked fail points map", failPoint)
55 }
56
57 failPoints[failPoint] = client
58 return nil
59 }
60
61 func addTargetedFailPoint(ctx context.Context, failPoint string, host string) error {
62 failPoints := ctx.Value(targetedFailPointsKey).(map[string]string)
63 if _, ok := failPoints[failPoint]; ok {
64 return fmt.Errorf("fail point %q already exists in tracked targeted fail points map", failPoint)
65 }
66
67 failPoints[failPoint] = host
68 return nil
69 }
70
71 func failPoints(ctx context.Context) map[string]*mongo.Client {
72 return ctx.Value(failPointsKey).(map[string]*mongo.Client)
73 }
74
75 func hasOperationalFailpoint(ctx context.Context) bool {
76 return ctx.Value(operationalFailPointKey).(bool)
77 }
78
79 func targetedFailPoints(ctx context.Context) map[string]string {
80 return ctx.Value(targetedFailPointsKey).(map[string]string)
81 }
82
83 func entities(ctx context.Context) *EntityMap {
84 return ctx.Value(entitiesKey).(*EntityMap)
85 }
86
87 func expectedLogMessagesCount(ctx context.Context, clientID string) int {
88 messages := ctx.Value(clientLogMessagesKey).([]*clientLogMessages)
89
90 count := 0
91 for _, message := range messages {
92 if message.Client == clientID {
93 count += len(message.LogMessages)
94 }
95 }
96
97 return count
98 }
99
100 func ignoreLogMessages(ctx context.Context, clientID string) []*logMessage {
101 messages := ctx.Value(clientLogMessagesKey).([]*clientLogMessages)
102
103 ignoreMessages := []*logMessage{}
104 for _, message := range messages {
105 if message.Client == clientID {
106 ignoreMessages = append(ignoreMessages, message.IgnoreMessages...)
107 }
108 }
109
110 return ignoreMessages
111 }
112
View as plain text