1
2
3
4
5
6
7 package bsoncodec
8
9 import (
10 "errors"
11 "reflect"
12 "testing"
13
14 "github.com/google/go-cmp/cmp"
15 "go.mongodb.org/mongo-driver/bson/bsonrw"
16 "go.mongodb.org/mongo-driver/bson/bsontype"
17 "go.mongodb.org/mongo-driver/internal/assert"
18 )
19
20 func TestRegistryBuilder(t *testing.T) {
21 t.Run("Register", func(t *testing.T) {
22 fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec)
23 t.Run("interface", func(t *testing.T) {
24 var t1f *testInterface1
25 var t2f *testInterface2
26 var t4f *testInterface4
27 ips := []interfaceValueEncoder{
28 {i: reflect.TypeOf(t1f).Elem(), ve: fc1},
29 {i: reflect.TypeOf(t2f).Elem(), ve: fc2},
30 {i: reflect.TypeOf(t1f).Elem(), ve: fc3},
31 {i: reflect.TypeOf(t4f).Elem(), ve: fc4},
32 }
33 want := []interfaceValueEncoder{
34 {i: reflect.TypeOf(t1f).Elem(), ve: fc3},
35 {i: reflect.TypeOf(t2f).Elem(), ve: fc2},
36 {i: reflect.TypeOf(t4f).Elem(), ve: fc4},
37 }
38 rb := NewRegistryBuilder()
39 for _, ip := range ips {
40 rb.RegisterHookEncoder(ip.i, ip.ve)
41 }
42
43 reg := rb.Build()
44 got := reg.interfaceEncoders
45 if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) {
46 t.Errorf("the registered interfaces are not correct: got %#v, want %#v", got, want)
47 }
48 })
49 t.Run("type", func(t *testing.T) {
50 ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{}
51 rb := NewRegistryBuilder().
52 RegisterTypeEncoder(reflect.TypeOf(ft1), fc1).
53 RegisterTypeEncoder(reflect.TypeOf(ft2), fc2).
54 RegisterTypeEncoder(reflect.TypeOf(ft1), fc3).
55 RegisterTypeEncoder(reflect.TypeOf(ft4), fc4)
56 want := []struct {
57 t reflect.Type
58 c ValueEncoder
59 }{
60 {reflect.TypeOf(ft1), fc3},
61 {reflect.TypeOf(ft2), fc2},
62 {reflect.TypeOf(ft4), fc4},
63 }
64
65 reg := rb.Build()
66 got := reg.typeEncoders
67 for _, s := range want {
68 wantT, wantC := s.t, s.c
69 gotC, exists := got.Load(wantT)
70 if !exists {
71 t.Errorf("Did not find type in the type registry: %v", wantT)
72 }
73 if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
74 t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC)
75 }
76 }
77 })
78 t.Run("kind", func(t *testing.T) {
79 k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map
80 rb := NewRegistryBuilder().
81 RegisterDefaultEncoder(k1, fc1).
82 RegisterDefaultEncoder(k2, fc2).
83 RegisterDefaultEncoder(k1, fc3).
84 RegisterDefaultEncoder(k4, fc4)
85 want := []struct {
86 k reflect.Kind
87 c ValueEncoder
88 }{
89 {k1, fc3},
90 {k2, fc2},
91 {k4, fc4},
92 }
93
94 reg := rb.Build()
95 got := reg.kindEncoders
96 for _, s := range want {
97 wantK, wantC := s.k, s.c
98 gotC, exists := got.Load(wantK)
99 if !exists {
100 t.Errorf("Did not find kind in the kind registry: %v", wantK)
101 }
102 if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
103 t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC)
104 }
105 }
106 })
107 t.Run("RegisterDefault", func(t *testing.T) {
108 t.Run("MapCodec", func(t *testing.T) {
109 codec := &fakeCodec{num: 1}
110 codec2 := &fakeCodec{num: 2}
111 rb := NewRegistryBuilder()
112
113 rb.RegisterDefaultEncoder(reflect.Map, codec)
114 reg := rb.Build()
115 if reg.kindEncoders.get(reflect.Map) != codec {
116 t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec)
117 }
118
119 rb.RegisterDefaultEncoder(reflect.Map, codec2)
120 reg = rb.Build()
121 if reg.kindEncoders.get(reflect.Map) != codec2 {
122 t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2)
123 }
124 })
125 t.Run("StructCodec", func(t *testing.T) {
126 codec := &fakeCodec{num: 1}
127 codec2 := &fakeCodec{num: 2}
128 rb := NewRegistryBuilder()
129
130 rb.RegisterDefaultEncoder(reflect.Struct, codec)
131 reg := rb.Build()
132 if reg.kindEncoders.get(reflect.Struct) != codec {
133 t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec)
134 }
135
136 rb.RegisterDefaultEncoder(reflect.Struct, codec2)
137 reg = rb.Build()
138 if reg.kindEncoders.get(reflect.Struct) != codec2 {
139 t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2)
140 }
141 })
142 t.Run("SliceCodec", func(t *testing.T) {
143 codec := &fakeCodec{num: 1}
144 codec2 := &fakeCodec{num: 2}
145 rb := NewRegistryBuilder()
146
147 rb.RegisterDefaultEncoder(reflect.Slice, codec)
148 reg := rb.Build()
149 if reg.kindEncoders.get(reflect.Slice) != codec {
150 t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec)
151 }
152
153 rb.RegisterDefaultEncoder(reflect.Slice, codec2)
154 reg = rb.Build()
155 if reg.kindEncoders.get(reflect.Slice) != codec2 {
156 t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2)
157 }
158 })
159 t.Run("ArrayCodec", func(t *testing.T) {
160 codec := &fakeCodec{num: 1}
161 codec2 := &fakeCodec{num: 2}
162 rb := NewRegistryBuilder()
163
164 rb.RegisterDefaultEncoder(reflect.Array, codec)
165 reg := rb.Build()
166 if reg.kindEncoders.get(reflect.Array) != codec {
167 t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec)
168 }
169
170 rb.RegisterDefaultEncoder(reflect.Array, codec2)
171 reg = rb.Build()
172 if reg.kindEncoders.get(reflect.Array) != codec2 {
173 t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2)
174 }
175 })
176 })
177 t.Run("Lookup", func(t *testing.T) {
178 type Codec interface {
179 ValueEncoder
180 ValueDecoder
181 }
182
183 var (
184 arrinstance [12]int
185 arr = reflect.TypeOf(arrinstance)
186 slc = reflect.TypeOf(make([]int, 12))
187 m = reflect.TypeOf(make(map[string]int))
188 strct = reflect.TypeOf(struct{ Foo string }{})
189 ft1 = reflect.PtrTo(reflect.TypeOf(fakeType1{}))
190 ft2 = reflect.TypeOf(fakeType2{})
191 ft3 = reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" }))
192 ti1 = reflect.TypeOf((*testInterface1)(nil)).Elem()
193 ti2 = reflect.TypeOf((*testInterface2)(nil)).Elem()
194 ti1Impl = reflect.TypeOf(testInterface1Impl{})
195 ti2Impl = reflect.TypeOf(testInterface2Impl{})
196 ti3 = reflect.TypeOf((*testInterface3)(nil)).Elem()
197 ti3Impl = reflect.TypeOf(testInterface3Impl{})
198 ti3ImplPtr = reflect.TypeOf((*testInterface3Impl)(nil))
199 fc1, fc2 = &fakeCodec{num: 1}, &fakeCodec{num: 2}
200 fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec)
201 pc = NewPointerCodec()
202 )
203
204 reg := NewRegistryBuilder().
205 RegisterTypeEncoder(ft1, fc1).
206 RegisterTypeEncoder(ft2, fc2).
207 RegisterTypeEncoder(ti1, fc1).
208 RegisterDefaultEncoder(reflect.Struct, fsc).
209 RegisterDefaultEncoder(reflect.Slice, fslcc).
210 RegisterDefaultEncoder(reflect.Array, fslcc).
211 RegisterDefaultEncoder(reflect.Map, fmc).
212 RegisterDefaultEncoder(reflect.Ptr, pc).
213 RegisterTypeDecoder(ft1, fc1).
214 RegisterTypeDecoder(ft2, fc2).
215 RegisterTypeDecoder(ti1, fc1).
216 RegisterDefaultDecoder(reflect.Struct, fsc).
217 RegisterDefaultDecoder(reflect.Slice, fslcc).
218 RegisterDefaultDecoder(reflect.Array, fslcc).
219 RegisterDefaultDecoder(reflect.Map, fmc).
220 RegisterDefaultDecoder(reflect.Ptr, pc).
221 RegisterHookEncoder(ti2, fc2).
222 RegisterHookDecoder(ti2, fc2).
223 RegisterHookEncoder(ti3, fc3).
224 RegisterHookDecoder(ti3, fc3).
225 Build()
226
227 testCases := []struct {
228 name string
229 t reflect.Type
230 wantcodec Codec
231 wanterr error
232 testcache bool
233 }{
234 {
235 "type registry (pointer)",
236 ft1,
237 fc1,
238 nil,
239 false,
240 },
241 {
242 "type registry (non-pointer)",
243 ft2,
244 fc2,
245 nil,
246 false,
247 },
248 {
249
250 "interface with type encoder",
251 ti1,
252 fc1,
253 nil,
254 true,
255 },
256 {
257
258 "interface implementation with type encoder",
259 ti1Impl,
260 fsc,
261 nil,
262 false,
263 },
264 {
265
266 "interface with hook",
267 ti2,
268 fc2,
269 nil,
270 false,
271 },
272 {
273
274 "interface implementation with hook",
275 ti2Impl,
276 fc2,
277 nil,
278 false,
279 },
280 {
281
282
283 "interface pointer to implementation with hook (pointer)",
284 ti3ImplPtr,
285 fc3,
286 nil,
287 false,
288 },
289 {
290 "default struct codec (pointer)",
291 reflect.PtrTo(strct),
292 pc,
293 nil,
294 false,
295 },
296 {
297 "default struct codec (non-pointer)",
298 strct,
299 fsc,
300 nil,
301 false,
302 },
303 {
304 "default array codec",
305 arr,
306 fslcc,
307 nil,
308 false,
309 },
310 {
311 "default slice codec",
312 slc,
313 fslcc,
314 nil,
315 false,
316 },
317 {
318 "default map",
319 m,
320 fmc,
321 nil,
322 false,
323 },
324 {
325 "map non-string key",
326 reflect.TypeOf(map[int]int{}),
327 fmc,
328 nil,
329 false,
330 },
331 {
332 "No Codec Registered",
333 ft3,
334 nil,
335 ErrNoEncoder{Type: ft3},
336 false,
337 },
338 }
339
340 allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{})
341 comparepc := func(pc1, pc2 *PointerCodec) bool { return true }
342 for _, tc := range testCases {
343 t.Run(tc.name, func(t *testing.T) {
344 t.Run("Encoder", func(t *testing.T) {
345 gotcodec, goterr := reg.LookupEncoder(tc.t)
346 if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) {
347 t.Errorf("errors did not match: got %#v, want %#v", goterr, tc.wanterr)
348 }
349 if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
350 t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec)
351 }
352 })
353 t.Run("Decoder", func(t *testing.T) {
354 wanterr := tc.wanterr
355 if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
356 wanterr = ErrNoDecoder(ene)
357 }
358
359 gotcodec, goterr := reg.LookupDecoder(tc.t)
360 if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
361 t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
362 }
363 if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
364 t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec)
365 }
366 })
367 })
368 }
369
370
371 t.Run("interface implementation with hook (pointer)", func(t *testing.T) {
372 t.Run("Encoder", func(t *testing.T) {
373 gotEnc, err := reg.LookupEncoder(ti3Impl)
374 assert.Nil(t, err, "LookupEncoder error: %v", err)
375
376 cae, ok := gotEnc.(*condAddrEncoder)
377 assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc)
378 if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) {
379 t.Errorf("expected canAddrEnc %#v, got %#v", cae.canAddrEnc, fc3)
380 }
381 if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) {
382 t.Errorf("expected elseEnc %#v, got %#v", cae.elseEnc, fsc)
383 }
384 })
385 t.Run("Decoder", func(t *testing.T) {
386 gotDec, err := reg.LookupDecoder(ti3Impl)
387 assert.Nil(t, err, "LookupDecoder error: %v", err)
388
389 cad, ok := gotDec.(*condAddrDecoder)
390 assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec)
391 if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) {
392 t.Errorf("expected canAddrDec %#v, got %#v", cad.canAddrDec, fc3)
393 }
394 if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) {
395 t.Errorf("expected elseDec %#v, got %#v", cad.elseDec, fsc)
396 }
397 })
398 })
399 })
400 })
401 t.Run("Type Map", func(t *testing.T) {
402 reg := NewRegistryBuilder().
403 RegisterTypeMapEntry(bsontype.String, reflect.TypeOf("")).
404 RegisterTypeMapEntry(bsontype.Int32, reflect.TypeOf(int(0))).
405 Build()
406
407 var got, want reflect.Type
408
409 want = reflect.TypeOf("")
410 got, err := reg.LookupTypeMapEntry(bsontype.String)
411 noerr(t, err)
412 if got != want {
413 t.Errorf("unexpected type: got %#v, want %#v", got, want)
414 }
415
416 want = reflect.TypeOf(int(0))
417 got, err = reg.LookupTypeMapEntry(bsontype.Int32)
418 noerr(t, err)
419 if got != want {
420 t.Errorf("unexpected type: got %#v, want %#v", got, want)
421 }
422
423 want = nil
424 wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID}
425 got, err = reg.LookupTypeMapEntry(bsontype.ObjectID)
426 if !errors.Is(err, wanterr) {
427 t.Errorf("did not get expected error: got %#v, want %#v", err, wanterr)
428 }
429 if got != want {
430 t.Errorf("unexpected type: got %#v, want %#v", got, want)
431 }
432 })
433 }
434
435 func TestRegistry(t *testing.T) {
436 t.Parallel()
437
438 t.Run("Register", func(t *testing.T) {
439 t.Parallel()
440
441 fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec)
442 t.Run("interface", func(t *testing.T) {
443 t.Parallel()
444
445 var t1f *testInterface1
446 var t2f *testInterface2
447 var t4f *testInterface4
448 ips := []interfaceValueEncoder{
449 {i: reflect.TypeOf(t1f).Elem(), ve: fc1},
450 {i: reflect.TypeOf(t2f).Elem(), ve: fc2},
451 {i: reflect.TypeOf(t1f).Elem(), ve: fc3},
452 {i: reflect.TypeOf(t4f).Elem(), ve: fc4},
453 }
454 want := []interfaceValueEncoder{
455 {i: reflect.TypeOf(t1f).Elem(), ve: fc3},
456 {i: reflect.TypeOf(t2f).Elem(), ve: fc2},
457 {i: reflect.TypeOf(t4f).Elem(), ve: fc4},
458 }
459 reg := NewRegistry()
460 for _, ip := range ips {
461 reg.RegisterInterfaceEncoder(ip.i, ip.ve)
462 }
463 got := reg.interfaceEncoders
464 if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) {
465 t.Errorf("registered interfaces are not correct: got %#v, want %#v", got, want)
466 }
467 })
468 t.Run("type", func(t *testing.T) {
469 t.Parallel()
470
471 ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{}
472 reg := NewRegistry()
473 reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1)
474 reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2)
475 reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3)
476 reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4)
477
478 want := []struct {
479 t reflect.Type
480 c ValueEncoder
481 }{
482 {reflect.TypeOf(ft1), fc3},
483 {reflect.TypeOf(ft2), fc2},
484 {reflect.TypeOf(ft4), fc4},
485 }
486 got := reg.typeEncoders
487 for _, s := range want {
488 wantT, wantC := s.t, s.c
489 gotC, exists := got.Load(wantT)
490 if !exists {
491 t.Errorf("type missing in registry: %v", wantT)
492 }
493 if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
494 t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC)
495 }
496 }
497 })
498 t.Run("kind", func(t *testing.T) {
499 t.Parallel()
500
501 k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map
502 reg := NewRegistry()
503 reg.RegisterKindEncoder(k1, fc1)
504 reg.RegisterKindEncoder(k2, fc2)
505 reg.RegisterKindEncoder(k1, fc3)
506 reg.RegisterKindEncoder(k4, fc4)
507
508 want := []struct {
509 k reflect.Kind
510 c ValueEncoder
511 }{
512 {k1, fc3},
513 {k2, fc2},
514 {k4, fc4},
515 }
516 got := reg.kindEncoders
517 for _, s := range want {
518 wantK, wantC := s.k, s.c
519 gotC, exists := got.Load(wantK)
520 if !exists {
521 t.Errorf("type missing in registry: %v", wantK)
522 }
523 if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
524 t.Errorf("codecs did not match: got %#v, want %#v", gotC, wantC)
525 }
526 }
527 })
528 t.Run("RegisterDefault", func(t *testing.T) {
529 t.Parallel()
530
531 t.Run("MapCodec", func(t *testing.T) {
532 t.Parallel()
533
534 codec := &fakeCodec{num: 1}
535 codec2 := &fakeCodec{num: 2}
536 reg := NewRegistry()
537 reg.RegisterKindEncoder(reflect.Map, codec)
538 if reg.kindEncoders.get(reflect.Map) != codec {
539 t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec)
540 }
541 reg.RegisterKindEncoder(reflect.Map, codec2)
542 if reg.kindEncoders.get(reflect.Map) != codec2 {
543 t.Errorf("map codec properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2)
544 }
545 })
546 t.Run("StructCodec", func(t *testing.T) {
547 t.Parallel()
548
549 codec := &fakeCodec{num: 1}
550 codec2 := &fakeCodec{num: 2}
551 reg := NewRegistry()
552 reg.RegisterKindEncoder(reflect.Struct, codec)
553 if reg.kindEncoders.get(reflect.Struct) != codec {
554 t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec)
555 }
556 reg.RegisterKindEncoder(reflect.Struct, codec2)
557 if reg.kindEncoders.get(reflect.Struct) != codec2 {
558 t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2)
559 }
560 })
561 t.Run("SliceCodec", func(t *testing.T) {
562 t.Parallel()
563
564 codec := &fakeCodec{num: 1}
565 codec2 := &fakeCodec{num: 2}
566 reg := NewRegistry()
567 reg.RegisterKindEncoder(reflect.Slice, codec)
568 if reg.kindEncoders.get(reflect.Slice) != codec {
569 t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec)
570 }
571 reg.RegisterKindEncoder(reflect.Slice, codec2)
572 if reg.kindEncoders.get(reflect.Slice) != codec2 {
573 t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2)
574 }
575 })
576 t.Run("ArrayCodec", func(t *testing.T) {
577 t.Parallel()
578
579 codec := &fakeCodec{num: 1}
580 codec2 := &fakeCodec{num: 2}
581 reg := NewRegistry()
582 reg.RegisterKindEncoder(reflect.Array, codec)
583 if reg.kindEncoders.get(reflect.Array) != codec {
584 t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec)
585 }
586 reg.RegisterKindEncoder(reflect.Array, codec2)
587 if reg.kindEncoders.get(reflect.Array) != codec2 {
588 t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2)
589 }
590 })
591 })
592 t.Run("Lookup", func(t *testing.T) {
593 t.Parallel()
594
595 type Codec interface {
596 ValueEncoder
597 ValueDecoder
598 }
599
600 var (
601 arrinstance [12]int
602 arr = reflect.TypeOf(arrinstance)
603 slc = reflect.TypeOf(make([]int, 12))
604 m = reflect.TypeOf(make(map[string]int))
605 strct = reflect.TypeOf(struct{ Foo string }{})
606 ft1 = reflect.PtrTo(reflect.TypeOf(fakeType1{}))
607 ft2 = reflect.TypeOf(fakeType2{})
608 ft3 = reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" }))
609 ti1 = reflect.TypeOf((*testInterface1)(nil)).Elem()
610 ti2 = reflect.TypeOf((*testInterface2)(nil)).Elem()
611 ti1Impl = reflect.TypeOf(testInterface1Impl{})
612 ti2Impl = reflect.TypeOf(testInterface2Impl{})
613 ti3 = reflect.TypeOf((*testInterface3)(nil)).Elem()
614 ti3Impl = reflect.TypeOf(testInterface3Impl{})
615 ti3ImplPtr = reflect.TypeOf((*testInterface3Impl)(nil))
616 fc1, fc2 = &fakeCodec{num: 1}, &fakeCodec{num: 2}
617 fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec)
618 pc = NewPointerCodec()
619 )
620
621 reg := NewRegistry()
622 reg.RegisterTypeEncoder(ft1, fc1)
623 reg.RegisterTypeEncoder(ft2, fc2)
624 reg.RegisterTypeEncoder(ti1, fc1)
625 reg.RegisterKindEncoder(reflect.Struct, fsc)
626 reg.RegisterKindEncoder(reflect.Slice, fslcc)
627 reg.RegisterKindEncoder(reflect.Array, fslcc)
628 reg.RegisterKindEncoder(reflect.Map, fmc)
629 reg.RegisterKindEncoder(reflect.Ptr, pc)
630 reg.RegisterTypeDecoder(ft1, fc1)
631 reg.RegisterTypeDecoder(ft2, fc2)
632 reg.RegisterTypeDecoder(ti1, fc1)
633 reg.RegisterKindDecoder(reflect.Struct, fsc)
634 reg.RegisterKindDecoder(reflect.Slice, fslcc)
635 reg.RegisterKindDecoder(reflect.Array, fslcc)
636 reg.RegisterKindDecoder(reflect.Map, fmc)
637 reg.RegisterKindDecoder(reflect.Ptr, pc)
638 reg.RegisterInterfaceEncoder(ti2, fc2)
639 reg.RegisterInterfaceDecoder(ti2, fc2)
640 reg.RegisterInterfaceEncoder(ti3, fc3)
641 reg.RegisterInterfaceDecoder(ti3, fc3)
642
643 testCases := []struct {
644 name string
645 t reflect.Type
646 wantcodec Codec
647 wanterr error
648 testcache bool
649 }{
650 {
651 "type registry (pointer)",
652 ft1,
653 fc1,
654 nil,
655 false,
656 },
657 {
658 "type registry (non-pointer)",
659 ft2,
660 fc2,
661 nil,
662 false,
663 },
664 {
665
666 "interface with type encoder",
667 ti1,
668 fc1,
669 nil,
670 true,
671 },
672 {
673
674 "interface implementation with type encoder",
675 ti1Impl,
676 fsc,
677 nil,
678 false,
679 },
680 {
681
682 "interface with hook",
683 ti2,
684 fc2,
685 nil,
686 false,
687 },
688 {
689
690 "interface implementation with hook",
691 ti2Impl,
692 fc2,
693 nil,
694 false,
695 },
696 {
697
698
699 "interface pointer to implementation with hook (pointer)",
700 ti3ImplPtr,
701 fc3,
702 nil,
703 false,
704 },
705 {
706 "default struct codec (pointer)",
707 reflect.PtrTo(strct),
708 pc,
709 nil,
710 false,
711 },
712 {
713 "default struct codec (non-pointer)",
714 strct,
715 fsc,
716 nil,
717 false,
718 },
719 {
720 "default array codec",
721 arr,
722 fslcc,
723 nil,
724 false,
725 },
726 {
727 "default slice codec",
728 slc,
729 fslcc,
730 nil,
731 false,
732 },
733 {
734 "default map",
735 m,
736 fmc,
737 nil,
738 false,
739 },
740 {
741 "map non-string key",
742 reflect.TypeOf(map[int]int{}),
743 fmc,
744 nil,
745 false,
746 },
747 {
748 "No Codec Registered",
749 ft3,
750 nil,
751 ErrNoEncoder{Type: ft3},
752 false,
753 },
754 }
755
756 allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{})
757 comparepc := func(pc1, pc2 *PointerCodec) bool { return true }
758 for _, tc := range testCases {
759 tc := tc
760
761 t.Run(tc.name, func(t *testing.T) {
762 t.Parallel()
763
764 t.Run("Encoder", func(t *testing.T) {
765 t.Parallel()
766
767 gotcodec, goterr := reg.LookupEncoder(tc.t)
768 if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) {
769 t.Errorf("errors did not match: got %#v, want %#v", goterr, tc.wanterr)
770 }
771 if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
772 t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec)
773 }
774 })
775 t.Run("Decoder", func(t *testing.T) {
776 t.Parallel()
777
778 wanterr := tc.wanterr
779 if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
780 wanterr = ErrNoDecoder(ene)
781 }
782
783 gotcodec, goterr := reg.LookupDecoder(tc.t)
784 if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
785 t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
786 }
787 if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
788 t.Errorf("codecs did not match: got %v: want %v", gotcodec, tc.wantcodec)
789 }
790 })
791 })
792 }
793 t.Run("nil type", func(t *testing.T) {
794 t.Parallel()
795
796 t.Run("Encoder", func(t *testing.T) {
797 t.Parallel()
798
799 wanterr := ErrNoEncoder{Type: reflect.TypeOf(nil)}
800
801 gotcodec, goterr := reg.LookupEncoder(nil)
802 if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
803 t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
804 }
805 if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) {
806 t.Errorf("codecs did not match: got %#v, want nil", gotcodec)
807 }
808 })
809 t.Run("Decoder", func(t *testing.T) {
810 t.Parallel()
811
812 wanterr := ErrNilType
813
814 gotcodec, goterr := reg.LookupDecoder(nil)
815 if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
816 t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
817 }
818 if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) {
819 t.Errorf("codecs did not match: got %v: want nil", gotcodec)
820 }
821 })
822 })
823
824
825 t.Run("interface implementation with hook (pointer)", func(t *testing.T) {
826 t.Parallel()
827
828 t.Run("Encoder", func(t *testing.T) {
829 t.Parallel()
830 gotEnc, err := reg.LookupEncoder(ti3Impl)
831 assert.Nil(t, err, "LookupEncoder error: %v", err)
832
833 cae, ok := gotEnc.(*condAddrEncoder)
834 assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc)
835 if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) {
836 t.Errorf("expected canAddrEnc %#v, got %#v", cae.canAddrEnc, fc3)
837 }
838 if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) {
839 t.Errorf("expected elseEnc %#v, got %#v", cae.elseEnc, fsc)
840 }
841 })
842 t.Run("Decoder", func(t *testing.T) {
843 t.Parallel()
844
845 gotDec, err := reg.LookupDecoder(ti3Impl)
846 assert.Nil(t, err, "LookupDecoder error: %v", err)
847
848 cad, ok := gotDec.(*condAddrDecoder)
849 assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec)
850 if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) {
851 t.Errorf("expected canAddrDec %#v, got %#v", cad.canAddrDec, fc3)
852 }
853 if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) {
854 t.Errorf("expected elseDec %#v, got %#v", cad.elseDec, fsc)
855 }
856 })
857 })
858 })
859 })
860 t.Run("Type Map", func(t *testing.T) {
861 t.Parallel()
862 reg := NewRegistry()
863 reg.RegisterTypeMapEntry(bsontype.String, reflect.TypeOf(""))
864 reg.RegisterTypeMapEntry(bsontype.Int32, reflect.TypeOf(int(0)))
865
866 var got, want reflect.Type
867
868 want = reflect.TypeOf("")
869 got, err := reg.LookupTypeMapEntry(bsontype.String)
870 noerr(t, err)
871 if got != want {
872 t.Errorf("unexpected type: got %#v, want %#v", got, want)
873 }
874
875 want = reflect.TypeOf(int(0))
876 got, err = reg.LookupTypeMapEntry(bsontype.Int32)
877 noerr(t, err)
878 if got != want {
879 t.Errorf("unexpected type: got %#v, want %#v", got, want)
880 }
881
882 want = nil
883 wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID}
884 got, err = reg.LookupTypeMapEntry(bsontype.ObjectID)
885 if !errors.Is(err, wanterr) {
886 t.Errorf("unexpected error: got %#v, want %#v", err, wanterr)
887 }
888 if got != want {
889 t.Errorf("unexpected error: got %#v, want %#v", got, want)
890 }
891 })
892 }
893
894
895 func (c *kindEncoderCache) get(rt reflect.Kind) ValueEncoder {
896 e, _ := c.Load(rt)
897 return e
898 }
899
900 func BenchmarkLookupEncoder(b *testing.B) {
901 type childStruct struct {
902 V1, V2, V3, V4 int
903 }
904 type nestedStruct struct {
905 childStruct
906 A struct{ C1, C2, C3, C4 childStruct }
907 B struct{ C1, C2, C3, C4 childStruct }
908 C struct{ M1, M2, M3, M4 map[int]int }
909 }
910 types := [...]reflect.Type{
911 reflect.TypeOf(int64(1)),
912 reflect.TypeOf(&fakeCodec{}),
913 reflect.TypeOf(&testInterface1Impl{}),
914 reflect.TypeOf(&nestedStruct{}),
915 }
916 r := NewRegistry()
917 for _, typ := range types {
918 r.RegisterTypeEncoder(typ, &fakeCodec{})
919 }
920 b.Run("Serial", func(b *testing.B) {
921 for i := 0; i < b.N; i++ {
922 _, err := r.LookupEncoder(types[i%len(types)])
923 if err != nil {
924 b.Fatal(err)
925 }
926 }
927 })
928 b.Run("Parallel", func(b *testing.B) {
929 b.RunParallel(func(pb *testing.PB) {
930 for i := 0; pb.Next(); i++ {
931 _, err := r.LookupEncoder(types[i%len(types)])
932 if err != nil {
933 b.Fatal(err)
934 }
935 }
936 })
937 })
938 }
939
940 type fakeType1 struct{}
941 type fakeType2 struct{}
942 type fakeType4 struct{}
943 type fakeType5 func(string, string) string
944 type fakeStructCodec struct{ *fakeCodec }
945 type fakeSliceCodec struct{ *fakeCodec }
946 type fakeMapCodec struct{ *fakeCodec }
947
// fakeCodec is a no-op ValueEncoder/ValueDecoder used to populate registries
// in tests.
//
// num distinguishes instances. Tests compare codecs both by pointer and, via
// cmp.Equal, by value. A non-empty field also keeps the struct from being
// zero-size: the Go spec allows pointers to distinct zero-size variables to
// compare equal, which would make pointer comparisons between separately
// allocated fakeCodecs unreliable.
type fakeCodec struct{ num int }
955
956 func (*fakeCodec) EncodeValue(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
957 return nil
958 }
959 func (*fakeCodec) DecodeValue(DecodeContext, bsonrw.ValueReader, reflect.Value) error {
960 return nil
961 }
962
// Marker interfaces used to exercise interface (hook) codec registration.
// Each declares a single unexported method, so only types in this package can
// implement them.
type (
	testInterface1 interface{ test1() }
	testInterface2 interface{ test2() }
	testInterface3 interface{ test3() }
	testInterface4 interface{ test4() }
)

// Implementations of the marker interfaces. testInterface1Impl and
// testInterface2Impl implement their interfaces with value receivers.
// testInterface3 is implemented only by *testInterface3Impl (pointer
// receiver); the lookup tests rely on this to get a conditional-address codec
// for the non-pointer type.
type (
	testInterface1Impl struct{}
	testInterface2Impl struct{}
	testInterface3Impl struct{}
)

// Compile-time checks that each implementation satisfies its interface.
var (
	_ testInterface1 = testInterface1Impl{}
	_ testInterface2 = testInterface2Impl{}
	_ testInterface3 = (*testInterface3Impl)(nil)
)

func (testInterface1Impl) test1() {}

func (testInterface2Impl) test2() {}

func (*testInterface3Impl) test3() {}
985
// typeComparer reports whether two reflect.Type values are identical. Tests
// pass it to cmp.Comparer so that go-cmp compares reflect.Type fields by
// identity instead of walking their internals.
func typeComparer(a, b reflect.Type) bool {
	return a == b
}
987
View as plain text