1
2
3
4
5
6
7 package unified
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "strings"
14
15 "go.mongodb.org/mongo-driver/bson"
16 "go.mongodb.org/mongo-driver/mongo"
17 )
18
19
20
21 type expectedError struct {
22 IsClientError *bool `bson:"isClientError"`
23 IsTimeoutError *bool `bson:"isTimeoutError"`
24 ErrorSubstring *string `bson:"errorContains"`
25 Code *int32 `bson:"errorCode"`
26 CodeName *string `bson:"errorCodeName"`
27 IncludedLabels []string `bson:"errorLabelsContain"`
28 OmittedLabels []string `bson:"errorLabelsOmit"`
29 ExpectedResult *bson.RawValue `bson:"expectResult"`
30 ErrorResponse *bson.Raw `bson:"errorResponse"`
31 }
32
33
34
35
36 func verifyOperationError(ctx context.Context, expected *expectedError, result *operationResult) error {
37
38
39 if errors.Is(result.Err, mongo.ErrUnacknowledgedWrite) {
40 result.Err = nil
41 }
42
43 if expected == nil {
44 if result.Err != nil {
45 return fmt.Errorf("expected no error, but got %w", result.Err)
46 }
47 return nil
48 }
49
50 if result.Err == nil {
51 return fmt.Errorf("expected error, got nil")
52 }
53
54
55 if expected.ErrorSubstring != nil {
56
57
58 expectedErrMsg := strings.ToLower(*expected.ErrorSubstring)
59 actualErrMsg := strings.ToLower(result.Err.Error())
60 if !strings.Contains(actualErrMsg, expectedErrMsg) {
61 return fmt.Errorf("expected error %w to contain substring %s", result.Err, *expected.ErrorSubstring)
62 }
63 }
64
65
66
67 details, serverError := extractErrorDetails(result.Err)
68 if expected.IsClientError != nil {
69
70 isClientError := !serverError || mongo.IsNetworkError(result.Err)
71 if *expected.IsClientError != isClientError {
72 return fmt.Errorf("expected error %w to be a client error: %v, is client error: %v", result.Err,
73 *expected.IsClientError, isClientError)
74 }
75 }
76 if expected.IsTimeoutError != nil {
77 isTimeoutError := mongo.IsTimeout(result.Err)
78 if *expected.IsTimeoutError != isTimeoutError {
79 return fmt.Errorf("expected error %w to be a timeout error: %v, is timeout error: %v", result.Err,
80 *expected.IsTimeoutError, isTimeoutError)
81 }
82 }
83 if !serverError {
84
85 if expected.Code != nil || expected.CodeName != nil || expected.IncludedLabels != nil || expected.OmittedLabels != nil {
86 return fmt.Errorf("failed to extract details from error %v of type %T", result.Err, result.Err)
87 }
88 }
89
90 if expected.Code != nil {
91 var found bool
92 for _, code := range details.codes {
93 if code == *expected.Code {
94 found = true
95 break
96 }
97 }
98 if !found {
99 return fmt.Errorf("expected error %w to have code %d", result.Err, *expected.Code)
100 }
101 }
102 if expected.CodeName != nil {
103 var found bool
104 for _, codeName := range details.codeNames {
105 if codeName == *expected.CodeName {
106 found = true
107 break
108 }
109 }
110 if !found {
111 return fmt.Errorf("expected error %w to have code name %q", result.Err, *expected.CodeName)
112 }
113 }
114 for _, label := range expected.IncludedLabels {
115 if !stringSliceContains(details.labels, label) {
116 return fmt.Errorf("expected error %w to contain label %q", result.Err, label)
117 }
118 }
119 for _, label := range expected.OmittedLabels {
120 if stringSliceContains(details.labels, label) {
121 return fmt.Errorf("expected error %w to not contain label %q", result.Err, label)
122 }
123 }
124
125 if expected.ExpectedResult != nil {
126 if err := verifyOperationResult(ctx, *expected.ExpectedResult, result); err != nil {
127 return fmt.Errorf("result comparison error: %w", err)
128 }
129 }
130
131 if expected.ErrorResponse != nil {
132 if details.raw == nil {
133 return fmt.Errorf("expected error response from the server, got none")
134 }
135
136
137 gotValue := documentToRawValue(details.raw)
138 expectedValue := documentToRawValue(*expected.ErrorResponse)
139 if err := verifyValuesMatch(ctx, expectedValue, gotValue, true); err != nil {
140 return fmt.Errorf("error response comparison error: %w", err)
141 }
142 }
143 return nil
144 }
145
146
147 type errorDetails struct {
148 codes []int32
149 codeNames []string
150 labels []string
151 raw bson.Raw
152 }
153
154
155
156 func extractErrorDetails(err error) (errorDetails, bool) {
157 var details errorDetails
158
159 switch converted := err.(type) {
160 case mongo.CommandError:
161 details.codes = []int32{converted.Code}
162 details.codeNames = []string{converted.Name}
163 details.labels = converted.Labels
164 details.raw = converted.Raw
165 case mongo.WriteException:
166 if converted.WriteConcernError != nil {
167 details.codes = append(details.codes, int32(converted.WriteConcernError.Code))
168 details.codeNames = append(details.codeNames, converted.WriteConcernError.Name)
169 }
170 for _, we := range converted.WriteErrors {
171 details.codes = append(details.codes, int32(we.Code))
172 }
173 details.labels = converted.Labels
174 details.raw = converted.Raw
175 case mongo.BulkWriteException:
176 if converted.WriteConcernError != nil {
177 details.codes = append(details.codes, int32(converted.WriteConcernError.Code))
178 details.codeNames = append(details.codeNames, converted.WriteConcernError.Name)
179 }
180 for _, we := range converted.WriteErrors {
181 details.codes = append(details.codes, int32(we.Code))
182 details.raw = we.Raw
183 }
184 details.labels = converted.Labels
185 default:
186 return errorDetails{}, false
187 }
188
189 return details, true
190 }
191
// stringSliceContains reports whether target occurs in arr, using exact,
// case-sensitive comparison. A nil or empty arr contains nothing.
func stringSliceContains(arr []string, target string) bool {
	for i := range arr {
		if arr[i] == target {
			return true
		}
	}
	return false
}
200
View as plain text