...

Source file src/go.mongodb.org/mongo-driver/bson/bsoncodec/registry_test.go

Documentation: go.mongodb.org/mongo-driver/bson/bsoncodec

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package bsoncodec
     8  
     9  import (
    10  	"errors"
    11  	"reflect"
    12  	"testing"
    13  
    14  	"github.com/google/go-cmp/cmp"
    15  	"go.mongodb.org/mongo-driver/bson/bsonrw"
    16  	"go.mongodb.org/mongo-driver/bson/bsontype"
    17  	"go.mongodb.org/mongo-driver/internal/assert"
    18  )
    19  
    20  func TestRegistryBuilder(t *testing.T) {
    21  	t.Run("Register", func(t *testing.T) {
    22  		fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec)
    23  		t.Run("interface", func(t *testing.T) {
    24  			var t1f *testInterface1
    25  			var t2f *testInterface2
    26  			var t4f *testInterface4
    27  			ips := []interfaceValueEncoder{
    28  				{i: reflect.TypeOf(t1f).Elem(), ve: fc1},
    29  				{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
    30  				{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
    31  				{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
    32  			}
    33  			want := []interfaceValueEncoder{
    34  				{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
    35  				{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
    36  				{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
    37  			}
    38  			rb := NewRegistryBuilder()
    39  			for _, ip := range ips {
    40  				rb.RegisterHookEncoder(ip.i, ip.ve)
    41  			}
    42  
    43  			reg := rb.Build()
    44  			got := reg.interfaceEncoders
    45  			if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) {
    46  				t.Errorf("the registered interfaces are not correct: got %#v, want %#v", got, want)
    47  			}
    48  		})
    49  		t.Run("type", func(t *testing.T) {
    50  			ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{}
    51  			rb := NewRegistryBuilder().
    52  				RegisterTypeEncoder(reflect.TypeOf(ft1), fc1).
    53  				RegisterTypeEncoder(reflect.TypeOf(ft2), fc2).
    54  				RegisterTypeEncoder(reflect.TypeOf(ft1), fc3).
    55  				RegisterTypeEncoder(reflect.TypeOf(ft4), fc4)
    56  			want := []struct {
    57  				t reflect.Type
    58  				c ValueEncoder
    59  			}{
    60  				{reflect.TypeOf(ft1), fc3},
    61  				{reflect.TypeOf(ft2), fc2},
    62  				{reflect.TypeOf(ft4), fc4},
    63  			}
    64  
    65  			reg := rb.Build()
    66  			got := reg.typeEncoders
    67  			for _, s := range want {
    68  				wantT, wantC := s.t, s.c
    69  				gotC, exists := got.Load(wantT)
    70  				if !exists {
    71  					t.Errorf("Did not find type in the type registry: %v", wantT)
    72  				}
    73  				if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
    74  					t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC)
    75  				}
    76  			}
    77  		})
    78  		t.Run("kind", func(t *testing.T) {
    79  			k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map
    80  			rb := NewRegistryBuilder().
    81  				RegisterDefaultEncoder(k1, fc1).
    82  				RegisterDefaultEncoder(k2, fc2).
    83  				RegisterDefaultEncoder(k1, fc3).
    84  				RegisterDefaultEncoder(k4, fc4)
    85  			want := []struct {
    86  				k reflect.Kind
    87  				c ValueEncoder
    88  			}{
    89  				{k1, fc3},
    90  				{k2, fc2},
    91  				{k4, fc4},
    92  			}
    93  
    94  			reg := rb.Build()
    95  			got := reg.kindEncoders
    96  			for _, s := range want {
    97  				wantK, wantC := s.k, s.c
    98  				gotC, exists := got.Load(wantK)
    99  				if !exists {
   100  					t.Errorf("Did not find kind in the kind registry: %v", wantK)
   101  				}
   102  				if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
   103  					t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC)
   104  				}
   105  			}
   106  		})
   107  		t.Run("RegisterDefault", func(t *testing.T) {
   108  			t.Run("MapCodec", func(t *testing.T) {
   109  				codec := &fakeCodec{num: 1}
   110  				codec2 := &fakeCodec{num: 2}
   111  				rb := NewRegistryBuilder()
   112  
   113  				rb.RegisterDefaultEncoder(reflect.Map, codec)
   114  				reg := rb.Build()
   115  				if reg.kindEncoders.get(reflect.Map) != codec {
   116  					t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec)
   117  				}
   118  
   119  				rb.RegisterDefaultEncoder(reflect.Map, codec2)
   120  				reg = rb.Build()
   121  				if reg.kindEncoders.get(reflect.Map) != codec2 {
   122  					t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2)
   123  				}
   124  			})
   125  			t.Run("StructCodec", func(t *testing.T) {
   126  				codec := &fakeCodec{num: 1}
   127  				codec2 := &fakeCodec{num: 2}
   128  				rb := NewRegistryBuilder()
   129  
   130  				rb.RegisterDefaultEncoder(reflect.Struct, codec)
   131  				reg := rb.Build()
   132  				if reg.kindEncoders.get(reflect.Struct) != codec {
   133  					t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec)
   134  				}
   135  
   136  				rb.RegisterDefaultEncoder(reflect.Struct, codec2)
   137  				reg = rb.Build()
   138  				if reg.kindEncoders.get(reflect.Struct) != codec2 {
   139  					t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2)
   140  				}
   141  			})
   142  			t.Run("SliceCodec", func(t *testing.T) {
   143  				codec := &fakeCodec{num: 1}
   144  				codec2 := &fakeCodec{num: 2}
   145  				rb := NewRegistryBuilder()
   146  
   147  				rb.RegisterDefaultEncoder(reflect.Slice, codec)
   148  				reg := rb.Build()
   149  				if reg.kindEncoders.get(reflect.Slice) != codec {
   150  					t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec)
   151  				}
   152  
   153  				rb.RegisterDefaultEncoder(reflect.Slice, codec2)
   154  				reg = rb.Build()
   155  				if reg.kindEncoders.get(reflect.Slice) != codec2 {
   156  					t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2)
   157  				}
   158  			})
   159  			t.Run("ArrayCodec", func(t *testing.T) {
   160  				codec := &fakeCodec{num: 1}
   161  				codec2 := &fakeCodec{num: 2}
   162  				rb := NewRegistryBuilder()
   163  
   164  				rb.RegisterDefaultEncoder(reflect.Array, codec)
   165  				reg := rb.Build()
   166  				if reg.kindEncoders.get(reflect.Array) != codec {
   167  					t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec)
   168  				}
   169  
   170  				rb.RegisterDefaultEncoder(reflect.Array, codec2)
   171  				reg = rb.Build()
   172  				if reg.kindEncoders.get(reflect.Array) != codec2 {
   173  					t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2)
   174  				}
   175  			})
   176  		})
   177  		t.Run("Lookup", func(t *testing.T) {
   178  			type Codec interface {
   179  				ValueEncoder
   180  				ValueDecoder
   181  			}
   182  
   183  			var (
   184  				arrinstance     [12]int
   185  				arr             = reflect.TypeOf(arrinstance)
   186  				slc             = reflect.TypeOf(make([]int, 12))
   187  				m               = reflect.TypeOf(make(map[string]int))
   188  				strct           = reflect.TypeOf(struct{ Foo string }{})
   189  				ft1             = reflect.PtrTo(reflect.TypeOf(fakeType1{}))
   190  				ft2             = reflect.TypeOf(fakeType2{})
   191  				ft3             = reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" }))
   192  				ti1             = reflect.TypeOf((*testInterface1)(nil)).Elem()
   193  				ti2             = reflect.TypeOf((*testInterface2)(nil)).Elem()
   194  				ti1Impl         = reflect.TypeOf(testInterface1Impl{})
   195  				ti2Impl         = reflect.TypeOf(testInterface2Impl{})
   196  				ti3             = reflect.TypeOf((*testInterface3)(nil)).Elem()
   197  				ti3Impl         = reflect.TypeOf(testInterface3Impl{})
   198  				ti3ImplPtr      = reflect.TypeOf((*testInterface3Impl)(nil))
   199  				fc1, fc2        = &fakeCodec{num: 1}, &fakeCodec{num: 2}
   200  				fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec)
   201  				pc              = NewPointerCodec()
   202  			)
   203  
   204  			reg := NewRegistryBuilder().
   205  				RegisterTypeEncoder(ft1, fc1).
   206  				RegisterTypeEncoder(ft2, fc2).
   207  				RegisterTypeEncoder(ti1, fc1).
   208  				RegisterDefaultEncoder(reflect.Struct, fsc).
   209  				RegisterDefaultEncoder(reflect.Slice, fslcc).
   210  				RegisterDefaultEncoder(reflect.Array, fslcc).
   211  				RegisterDefaultEncoder(reflect.Map, fmc).
   212  				RegisterDefaultEncoder(reflect.Ptr, pc).
   213  				RegisterTypeDecoder(ft1, fc1).
   214  				RegisterTypeDecoder(ft2, fc2).
   215  				RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder
   216  				RegisterDefaultDecoder(reflect.Struct, fsc).
   217  				RegisterDefaultDecoder(reflect.Slice, fslcc).
   218  				RegisterDefaultDecoder(reflect.Array, fslcc).
   219  				RegisterDefaultDecoder(reflect.Map, fmc).
   220  				RegisterDefaultDecoder(reflect.Ptr, pc).
   221  				RegisterHookEncoder(ti2, fc2).
   222  				RegisterHookDecoder(ti2, fc2).
   223  				RegisterHookEncoder(ti3, fc3).
   224  				RegisterHookDecoder(ti3, fc3).
   225  				Build()
   226  
   227  			testCases := []struct {
   228  				name      string
   229  				t         reflect.Type
   230  				wantcodec Codec
   231  				wanterr   error
   232  				testcache bool
   233  			}{
   234  				{
   235  					"type registry (pointer)",
   236  					ft1,
   237  					fc1,
   238  					nil,
   239  					false,
   240  				},
   241  				{
   242  					"type registry (non-pointer)",
   243  					ft2,
   244  					fc2,
   245  					nil,
   246  					false,
   247  				},
   248  				{
   249  					// lookup an interface type and expect that the registered encoder is returned
   250  					"interface with type encoder",
   251  					ti1,
   252  					fc1,
   253  					nil,
   254  					true,
   255  				},
   256  				{
   257  					// lookup a type that implements an interface and expect that the default struct codec is returned
   258  					"interface implementation with type encoder",
   259  					ti1Impl,
   260  					fsc,
   261  					nil,
   262  					false,
   263  				},
   264  				{
   265  					// lookup an interface type and expect that the registered hook is returned
   266  					"interface with hook",
   267  					ti2,
   268  					fc2,
   269  					nil,
   270  					false,
   271  				},
   272  				{
   273  					// lookup a type that implements an interface and expect that the registered hook is returned
   274  					"interface implementation with hook",
   275  					ti2Impl,
   276  					fc2,
   277  					nil,
   278  					false,
   279  				},
   280  				{
   281  					// lookup a pointer to a type where the pointer implements an interface and expect that the
   282  					// registered hook is returned
   283  					"interface pointer to implementation with hook (pointer)",
   284  					ti3ImplPtr,
   285  					fc3,
   286  					nil,
   287  					false,
   288  				},
   289  				{
   290  					"default struct codec (pointer)",
   291  					reflect.PtrTo(strct),
   292  					pc,
   293  					nil,
   294  					false,
   295  				},
   296  				{
   297  					"default struct codec (non-pointer)",
   298  					strct,
   299  					fsc,
   300  					nil,
   301  					false,
   302  				},
   303  				{
   304  					"default array codec",
   305  					arr,
   306  					fslcc,
   307  					nil,
   308  					false,
   309  				},
   310  				{
   311  					"default slice codec",
   312  					slc,
   313  					fslcc,
   314  					nil,
   315  					false,
   316  				},
   317  				{
   318  					"default map",
   319  					m,
   320  					fmc,
   321  					nil,
   322  					false,
   323  				},
   324  				{
   325  					"map non-string key",
   326  					reflect.TypeOf(map[int]int{}),
   327  					fmc,
   328  					nil,
   329  					false,
   330  				},
   331  				{
   332  					"No Codec Registered",
   333  					ft3,
   334  					nil,
   335  					ErrNoEncoder{Type: ft3},
   336  					false,
   337  				},
   338  			}
   339  
   340  			allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{})
   341  			comparepc := func(pc1, pc2 *PointerCodec) bool { return true }
   342  			for _, tc := range testCases {
   343  				t.Run(tc.name, func(t *testing.T) {
   344  					t.Run("Encoder", func(t *testing.T) {
   345  						gotcodec, goterr := reg.LookupEncoder(tc.t)
   346  						if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) {
   347  							t.Errorf("errors did not match: got %#v, want %#v", goterr, tc.wanterr)
   348  						}
   349  						if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
   350  							t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec)
   351  						}
   352  					})
   353  					t.Run("Decoder", func(t *testing.T) {
   354  						wanterr := tc.wanterr
   355  						if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
   356  							wanterr = ErrNoDecoder(ene)
   357  						}
   358  
   359  						gotcodec, goterr := reg.LookupDecoder(tc.t)
   360  						if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
   361  							t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
   362  						}
   363  						if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
   364  							t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec)
   365  						}
   366  					})
   367  				})
   368  			}
   369  			// lookup a type whose pointer implements an interface and expect that the registered hook is
   370  			// returned
   371  			t.Run("interface implementation with hook (pointer)", func(t *testing.T) {
   372  				t.Run("Encoder", func(t *testing.T) {
   373  					gotEnc, err := reg.LookupEncoder(ti3Impl)
   374  					assert.Nil(t, err, "LookupEncoder error: %v", err)
   375  
   376  					cae, ok := gotEnc.(*condAddrEncoder)
   377  					assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc)
   378  					if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) {
   379  						t.Errorf("expected canAddrEnc %#v, got %#v", cae.canAddrEnc, fc3)
   380  					}
   381  					if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) {
   382  						t.Errorf("expected elseEnc %#v, got %#v", cae.elseEnc, fsc)
   383  					}
   384  				})
   385  				t.Run("Decoder", func(t *testing.T) {
   386  					gotDec, err := reg.LookupDecoder(ti3Impl)
   387  					assert.Nil(t, err, "LookupDecoder error: %v", err)
   388  
   389  					cad, ok := gotDec.(*condAddrDecoder)
   390  					assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec)
   391  					if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) {
   392  						t.Errorf("expected canAddrDec %#v, got %#v", cad.canAddrDec, fc3)
   393  					}
   394  					if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) {
   395  						t.Errorf("expected elseDec %#v, got %#v", cad.elseDec, fsc)
   396  					}
   397  				})
   398  			})
   399  		})
   400  	})
   401  	t.Run("Type Map", func(t *testing.T) {
   402  		reg := NewRegistryBuilder().
   403  			RegisterTypeMapEntry(bsontype.String, reflect.TypeOf("")).
   404  			RegisterTypeMapEntry(bsontype.Int32, reflect.TypeOf(int(0))).
   405  			Build()
   406  
   407  		var got, want reflect.Type
   408  
   409  		want = reflect.TypeOf("")
   410  		got, err := reg.LookupTypeMapEntry(bsontype.String)
   411  		noerr(t, err)
   412  		if got != want {
   413  			t.Errorf("unexpected type: got %#v, want %#v", got, want)
   414  		}
   415  
   416  		want = reflect.TypeOf(int(0))
   417  		got, err = reg.LookupTypeMapEntry(bsontype.Int32)
   418  		noerr(t, err)
   419  		if got != want {
   420  			t.Errorf("unexpected type: got %#v, want %#v", got, want)
   421  		}
   422  
   423  		want = nil
   424  		wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID}
   425  		got, err = reg.LookupTypeMapEntry(bsontype.ObjectID)
   426  		if !errors.Is(err, wanterr) {
   427  			t.Errorf("did not get expected error: got %#v, want %#v", err, wanterr)
   428  		}
   429  		if got != want {
   430  			t.Errorf("unexpected type: got %#v, want %#v", got, want)
   431  		}
   432  	})
   433  }
   434  
   435  func TestRegistry(t *testing.T) {
   436  	t.Parallel()
   437  
   438  	t.Run("Register", func(t *testing.T) {
   439  		t.Parallel()
   440  
   441  		fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec)
   442  		t.Run("interface", func(t *testing.T) {
   443  			t.Parallel()
   444  
   445  			var t1f *testInterface1
   446  			var t2f *testInterface2
   447  			var t4f *testInterface4
   448  			ips := []interfaceValueEncoder{
   449  				{i: reflect.TypeOf(t1f).Elem(), ve: fc1},
   450  				{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
   451  				{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
   452  				{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
   453  			}
   454  			want := []interfaceValueEncoder{
   455  				{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
   456  				{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
   457  				{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
   458  			}
   459  			reg := NewRegistry()
   460  			for _, ip := range ips {
   461  				reg.RegisterInterfaceEncoder(ip.i, ip.ve)
   462  			}
   463  			got := reg.interfaceEncoders
   464  			if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) {
   465  				t.Errorf("registered interfaces are not correct: got %#v, want %#v", got, want)
   466  			}
   467  		})
   468  		t.Run("type", func(t *testing.T) {
   469  			t.Parallel()
   470  
   471  			ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{}
   472  			reg := NewRegistry()
   473  			reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1)
   474  			reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2)
   475  			reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3)
   476  			reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4)
   477  
   478  			want := []struct {
   479  				t reflect.Type
   480  				c ValueEncoder
   481  			}{
   482  				{reflect.TypeOf(ft1), fc3},
   483  				{reflect.TypeOf(ft2), fc2},
   484  				{reflect.TypeOf(ft4), fc4},
   485  			}
   486  			got := reg.typeEncoders
   487  			for _, s := range want {
   488  				wantT, wantC := s.t, s.c
   489  				gotC, exists := got.Load(wantT)
   490  				if !exists {
   491  					t.Errorf("type missing in registry: %v", wantT)
   492  				}
   493  				if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
   494  					t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC)
   495  				}
   496  			}
   497  		})
   498  		t.Run("kind", func(t *testing.T) {
   499  			t.Parallel()
   500  
   501  			k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map
   502  			reg := NewRegistry()
   503  			reg.RegisterKindEncoder(k1, fc1)
   504  			reg.RegisterKindEncoder(k2, fc2)
   505  			reg.RegisterKindEncoder(k1, fc3)
   506  			reg.RegisterKindEncoder(k4, fc4)
   507  
   508  			want := []struct {
   509  				k reflect.Kind
   510  				c ValueEncoder
   511  			}{
   512  				{k1, fc3},
   513  				{k2, fc2},
   514  				{k4, fc4},
   515  			}
   516  			got := reg.kindEncoders
   517  			for _, s := range want {
   518  				wantK, wantC := s.k, s.c
   519  				gotC, exists := got.Load(wantK)
   520  				if !exists {
   521  					t.Errorf("type missing in registry: %v", wantK)
   522  				}
   523  				if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
   524  					t.Errorf("codecs did not match: got %#v, want %#v", gotC, wantC)
   525  				}
   526  			}
   527  		})
   528  		t.Run("RegisterDefault", func(t *testing.T) {
   529  			t.Parallel()
   530  
   531  			t.Run("MapCodec", func(t *testing.T) {
   532  				t.Parallel()
   533  
   534  				codec := &fakeCodec{num: 1}
   535  				codec2 := &fakeCodec{num: 2}
   536  				reg := NewRegistry()
   537  				reg.RegisterKindEncoder(reflect.Map, codec)
   538  				if reg.kindEncoders.get(reflect.Map) != codec {
   539  					t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec)
   540  				}
   541  				reg.RegisterKindEncoder(reflect.Map, codec2)
   542  				if reg.kindEncoders.get(reflect.Map) != codec2 {
   543  					t.Errorf("map codec properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2)
   544  				}
   545  			})
   546  			t.Run("StructCodec", func(t *testing.T) {
   547  				t.Parallel()
   548  
   549  				codec := &fakeCodec{num: 1}
   550  				codec2 := &fakeCodec{num: 2}
   551  				reg := NewRegistry()
   552  				reg.RegisterKindEncoder(reflect.Struct, codec)
   553  				if reg.kindEncoders.get(reflect.Struct) != codec {
   554  					t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec)
   555  				}
   556  				reg.RegisterKindEncoder(reflect.Struct, codec2)
   557  				if reg.kindEncoders.get(reflect.Struct) != codec2 {
   558  					t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2)
   559  				}
   560  			})
   561  			t.Run("SliceCodec", func(t *testing.T) {
   562  				t.Parallel()
   563  
   564  				codec := &fakeCodec{num: 1}
   565  				codec2 := &fakeCodec{num: 2}
   566  				reg := NewRegistry()
   567  				reg.RegisterKindEncoder(reflect.Slice, codec)
   568  				if reg.kindEncoders.get(reflect.Slice) != codec {
   569  					t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec)
   570  				}
   571  				reg.RegisterKindEncoder(reflect.Slice, codec2)
   572  				if reg.kindEncoders.get(reflect.Slice) != codec2 {
   573  					t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2)
   574  				}
   575  			})
   576  			t.Run("ArrayCodec", func(t *testing.T) {
   577  				t.Parallel()
   578  
   579  				codec := &fakeCodec{num: 1}
   580  				codec2 := &fakeCodec{num: 2}
   581  				reg := NewRegistry()
   582  				reg.RegisterKindEncoder(reflect.Array, codec)
   583  				if reg.kindEncoders.get(reflect.Array) != codec {
   584  					t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec)
   585  				}
   586  				reg.RegisterKindEncoder(reflect.Array, codec2)
   587  				if reg.kindEncoders.get(reflect.Array) != codec2 {
   588  					t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2)
   589  				}
   590  			})
   591  		})
   592  		t.Run("Lookup", func(t *testing.T) {
   593  			t.Parallel()
   594  
   595  			type Codec interface {
   596  				ValueEncoder
   597  				ValueDecoder
   598  			}
   599  
   600  			var (
   601  				arrinstance     [12]int
   602  				arr             = reflect.TypeOf(arrinstance)
   603  				slc             = reflect.TypeOf(make([]int, 12))
   604  				m               = reflect.TypeOf(make(map[string]int))
   605  				strct           = reflect.TypeOf(struct{ Foo string }{})
   606  				ft1             = reflect.PtrTo(reflect.TypeOf(fakeType1{}))
   607  				ft2             = reflect.TypeOf(fakeType2{})
   608  				ft3             = reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" }))
   609  				ti1             = reflect.TypeOf((*testInterface1)(nil)).Elem()
   610  				ti2             = reflect.TypeOf((*testInterface2)(nil)).Elem()
   611  				ti1Impl         = reflect.TypeOf(testInterface1Impl{})
   612  				ti2Impl         = reflect.TypeOf(testInterface2Impl{})
   613  				ti3             = reflect.TypeOf((*testInterface3)(nil)).Elem()
   614  				ti3Impl         = reflect.TypeOf(testInterface3Impl{})
   615  				ti3ImplPtr      = reflect.TypeOf((*testInterface3Impl)(nil))
   616  				fc1, fc2        = &fakeCodec{num: 1}, &fakeCodec{num: 2}
   617  				fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec)
   618  				pc              = NewPointerCodec()
   619  			)
   620  
   621  			reg := NewRegistry()
   622  			reg.RegisterTypeEncoder(ft1, fc1)
   623  			reg.RegisterTypeEncoder(ft2, fc2)
   624  			reg.RegisterTypeEncoder(ti1, fc1)
   625  			reg.RegisterKindEncoder(reflect.Struct, fsc)
   626  			reg.RegisterKindEncoder(reflect.Slice, fslcc)
   627  			reg.RegisterKindEncoder(reflect.Array, fslcc)
   628  			reg.RegisterKindEncoder(reflect.Map, fmc)
   629  			reg.RegisterKindEncoder(reflect.Ptr, pc)
   630  			reg.RegisterTypeDecoder(ft1, fc1)
   631  			reg.RegisterTypeDecoder(ft2, fc2)
   632  			reg.RegisterTypeDecoder(ti1, fc1) // values whose exact type is testInterface1 will use fc1 encoder
   633  			reg.RegisterKindDecoder(reflect.Struct, fsc)
   634  			reg.RegisterKindDecoder(reflect.Slice, fslcc)
   635  			reg.RegisterKindDecoder(reflect.Array, fslcc)
   636  			reg.RegisterKindDecoder(reflect.Map, fmc)
   637  			reg.RegisterKindDecoder(reflect.Ptr, pc)
   638  			reg.RegisterInterfaceEncoder(ti2, fc2)
   639  			reg.RegisterInterfaceDecoder(ti2, fc2)
   640  			reg.RegisterInterfaceEncoder(ti3, fc3)
   641  			reg.RegisterInterfaceDecoder(ti3, fc3)
   642  
   643  			testCases := []struct {
   644  				name      string
   645  				t         reflect.Type
   646  				wantcodec Codec
   647  				wanterr   error
   648  				testcache bool
   649  			}{
   650  				{
   651  					"type registry (pointer)",
   652  					ft1,
   653  					fc1,
   654  					nil,
   655  					false,
   656  				},
   657  				{
   658  					"type registry (non-pointer)",
   659  					ft2,
   660  					fc2,
   661  					nil,
   662  					false,
   663  				},
   664  				{
   665  					// lookup an interface type and expect that the registered encoder is returned
   666  					"interface with type encoder",
   667  					ti1,
   668  					fc1,
   669  					nil,
   670  					true,
   671  				},
   672  				{
   673  					// lookup a type that implements an interface and expect that the default struct codec is returned
   674  					"interface implementation with type encoder",
   675  					ti1Impl,
   676  					fsc,
   677  					nil,
   678  					false,
   679  				},
   680  				{
   681  					// lookup an interface type and expect that the registered hook is returned
   682  					"interface with hook",
   683  					ti2,
   684  					fc2,
   685  					nil,
   686  					false,
   687  				},
   688  				{
   689  					// lookup a type that implements an interface and expect that the registered hook is returned
   690  					"interface implementation with hook",
   691  					ti2Impl,
   692  					fc2,
   693  					nil,
   694  					false,
   695  				},
   696  				{
   697  					// lookup a pointer to a type where the pointer implements an interface and expect that the
   698  					// registered hook is returned
   699  					"interface pointer to implementation with hook (pointer)",
   700  					ti3ImplPtr,
   701  					fc3,
   702  					nil,
   703  					false,
   704  				},
   705  				{
   706  					"default struct codec (pointer)",
   707  					reflect.PtrTo(strct),
   708  					pc,
   709  					nil,
   710  					false,
   711  				},
   712  				{
   713  					"default struct codec (non-pointer)",
   714  					strct,
   715  					fsc,
   716  					nil,
   717  					false,
   718  				},
   719  				{
   720  					"default array codec",
   721  					arr,
   722  					fslcc,
   723  					nil,
   724  					false,
   725  				},
   726  				{
   727  					"default slice codec",
   728  					slc,
   729  					fslcc,
   730  					nil,
   731  					false,
   732  				},
   733  				{
   734  					"default map",
   735  					m,
   736  					fmc,
   737  					nil,
   738  					false,
   739  				},
   740  				{
   741  					"map non-string key",
   742  					reflect.TypeOf(map[int]int{}),
   743  					fmc,
   744  					nil,
   745  					false,
   746  				},
   747  				{
   748  					"No Codec Registered",
   749  					ft3,
   750  					nil,
   751  					ErrNoEncoder{Type: ft3},
   752  					false,
   753  				},
   754  			}
   755  
   756  			allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{})
   757  			comparepc := func(pc1, pc2 *PointerCodec) bool { return true }
   758  			for _, tc := range testCases {
   759  				tc := tc
   760  
   761  				t.Run(tc.name, func(t *testing.T) {
   762  					t.Parallel()
   763  
   764  					t.Run("Encoder", func(t *testing.T) {
   765  						t.Parallel()
   766  
   767  						gotcodec, goterr := reg.LookupEncoder(tc.t)
   768  						if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) {
   769  							t.Errorf("errors did not match: got %#v, want %#v", goterr, tc.wanterr)
   770  						}
   771  						if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
   772  							t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec)
   773  						}
   774  					})
   775  					t.Run("Decoder", func(t *testing.T) {
   776  						t.Parallel()
   777  
   778  						wanterr := tc.wanterr
   779  						if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
   780  							wanterr = ErrNoDecoder(ene)
   781  						}
   782  
   783  						gotcodec, goterr := reg.LookupDecoder(tc.t)
   784  						if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
   785  							t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
   786  						}
   787  						if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
   788  							t.Errorf("codecs did not match: got %v: want %v", gotcodec, tc.wantcodec)
   789  						}
   790  					})
   791  				})
   792  			}
   793  			t.Run("nil type", func(t *testing.T) {
   794  				t.Parallel()
   795  
   796  				t.Run("Encoder", func(t *testing.T) {
   797  					t.Parallel()
   798  
   799  					wanterr := ErrNoEncoder{Type: reflect.TypeOf(nil)}
   800  
   801  					gotcodec, goterr := reg.LookupEncoder(nil)
   802  					if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
   803  						t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
   804  					}
   805  					if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) {
   806  						t.Errorf("codecs did not match: got %#v, want nil", gotcodec)
   807  					}
   808  				})
   809  				t.Run("Decoder", func(t *testing.T) {
   810  					t.Parallel()
   811  
   812  					wanterr := ErrNilType
   813  
   814  					gotcodec, goterr := reg.LookupDecoder(nil)
   815  					if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
   816  						t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
   817  					}
   818  					if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) {
   819  						t.Errorf("codecs did not match: got %v: want nil", gotcodec)
   820  					}
   821  				})
   822  			})
   823  			// lookup a type whose pointer implements an interface and expect that the registered hook is
   824  			// returned
   825  			t.Run("interface implementation with hook (pointer)", func(t *testing.T) {
   826  				t.Parallel()
   827  
   828  				t.Run("Encoder", func(t *testing.T) {
   829  					t.Parallel()
   830  					gotEnc, err := reg.LookupEncoder(ti3Impl)
   831  					assert.Nil(t, err, "LookupEncoder error: %v", err)
   832  
   833  					cae, ok := gotEnc.(*condAddrEncoder)
   834  					assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc)
   835  					if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) {
   836  						t.Errorf("expected canAddrEnc %#v, got %#v", cae.canAddrEnc, fc3)
   837  					}
   838  					if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) {
   839  						t.Errorf("expected elseEnc %#v, got %#v", cae.elseEnc, fsc)
   840  					}
   841  				})
   842  				t.Run("Decoder", func(t *testing.T) {
   843  					t.Parallel()
   844  
   845  					gotDec, err := reg.LookupDecoder(ti3Impl)
   846  					assert.Nil(t, err, "LookupDecoder error: %v", err)
   847  
   848  					cad, ok := gotDec.(*condAddrDecoder)
   849  					assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec)
   850  					if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) {
   851  						t.Errorf("expected canAddrDec %#v, got %#v", cad.canAddrDec, fc3)
   852  					}
   853  					if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) {
   854  						t.Errorf("expected elseDec %#v, got %#v", cad.elseDec, fsc)
   855  					}
   856  				})
   857  			})
   858  		})
   859  	})
   860  	t.Run("Type Map", func(t *testing.T) {
   861  		t.Parallel()
   862  		reg := NewRegistry()
   863  		reg.RegisterTypeMapEntry(bsontype.String, reflect.TypeOf(""))
   864  		reg.RegisterTypeMapEntry(bsontype.Int32, reflect.TypeOf(int(0)))
   865  
   866  		var got, want reflect.Type
   867  
   868  		want = reflect.TypeOf("")
   869  		got, err := reg.LookupTypeMapEntry(bsontype.String)
   870  		noerr(t, err)
   871  		if got != want {
   872  			t.Errorf("unexpected type: got %#v, want %#v", got, want)
   873  		}
   874  
   875  		want = reflect.TypeOf(int(0))
   876  		got, err = reg.LookupTypeMapEntry(bsontype.Int32)
   877  		noerr(t, err)
   878  		if got != want {
   879  			t.Errorf("unexpected type: got %#v, want %#v", got, want)
   880  		}
   881  
   882  		want = nil
   883  		wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID}
   884  		got, err = reg.LookupTypeMapEntry(bsontype.ObjectID)
   885  		if !errors.Is(err, wanterr) {
   886  			t.Errorf("unexpected error: got %#v, want %#v", err, wanterr)
   887  		}
   888  		if got != want {
   889  			t.Errorf("unexpected error: got %#v, want %#v", got, want)
   890  		}
   891  	})
   892  }
   893  
   894  // get is only for testing as it does return if the value was found
   895  func (c *kindEncoderCache) get(rt reflect.Kind) ValueEncoder {
   896  	e, _ := c.Load(rt)
   897  	return e
   898  }
   899  
   900  func BenchmarkLookupEncoder(b *testing.B) {
   901  	type childStruct struct {
   902  		V1, V2, V3, V4 int
   903  	}
   904  	type nestedStruct struct {
   905  		childStruct
   906  		A struct{ C1, C2, C3, C4 childStruct }
   907  		B struct{ C1, C2, C3, C4 childStruct }
   908  		C struct{ M1, M2, M3, M4 map[int]int }
   909  	}
   910  	types := [...]reflect.Type{
   911  		reflect.TypeOf(int64(1)),
   912  		reflect.TypeOf(&fakeCodec{}),
   913  		reflect.TypeOf(&testInterface1Impl{}),
   914  		reflect.TypeOf(&nestedStruct{}),
   915  	}
   916  	r := NewRegistry()
   917  	for _, typ := range types {
   918  		r.RegisterTypeEncoder(typ, &fakeCodec{})
   919  	}
   920  	b.Run("Serial", func(b *testing.B) {
   921  		for i := 0; i < b.N; i++ {
   922  			_, err := r.LookupEncoder(types[i%len(types)])
   923  			if err != nil {
   924  				b.Fatal(err)
   925  			}
   926  		}
   927  	})
   928  	b.Run("Parallel", func(b *testing.B) {
   929  		b.RunParallel(func(pb *testing.PB) {
   930  			for i := 0; pb.Next(); i++ {
   931  				_, err := r.LookupEncoder(types[i%len(types)])
   932  				if err != nil {
   933  					b.Fatal(err)
   934  				}
   935  			}
   936  		})
   937  	})
   938  }
   939  
   940  type fakeType1 struct{}
   941  type fakeType2 struct{}
   942  type fakeType4 struct{}
   943  type fakeType5 func(string, string) string
   944  type fakeStructCodec struct{ *fakeCodec }
   945  type fakeSliceCodec struct{ *fakeCodec }
   946  type fakeMapCodec struct{ *fakeCodec }
   947  
   948  type fakeCodec struct {
   949  	// num is used to differentiate fakeCodec instances and to force Go to allocate a new value in
   950  	// memory for every fakeCodec. If fakeCodec were an empty struct, Go may use the same pointer
   951  	// for every instance of fakeCodec, making comparisons between pointers to instances of
   952  	// fakeCodec sometimes meaningless.
   953  	num int
   954  }
   955  
   956  func (*fakeCodec) EncodeValue(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
   957  	return nil
   958  }
   959  func (*fakeCodec) DecodeValue(DecodeContext, bsonrw.ValueReader, reflect.Value) error {
   960  	return nil
   961  }
   962  
   963  type testInterface1 interface{ test1() }
   964  type testInterface2 interface{ test2() }
   965  type testInterface3 interface{ test3() }
   966  type testInterface4 interface{ test4() }
   967  
   968  type testInterface1Impl struct{}
   969  
   970  var _ testInterface1 = testInterface1Impl{}
   971  
   972  func (testInterface1Impl) test1() {}
   973  
   974  type testInterface2Impl struct{}
   975  
   976  var _ testInterface2 = testInterface2Impl{}
   977  
   978  func (testInterface2Impl) test2() {}
   979  
   980  type testInterface3Impl struct{}
   981  
   982  var _ testInterface3 = (*testInterface3Impl)(nil)
   983  
   984  func (*testInterface3Impl) test3() {}
   985  
   986  func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 }
   987  

View as plain text