...

Source file src/go.mongodb.org/mongo-driver/bson/marshal_test.go

Documentation: go.mongodb.org/mongo-driver/bson

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package bson
     8  
     9  import (
    10  	"bytes"
    11  	"errors"
    12  	"fmt"
    13  	"reflect"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/google/go-cmp/cmp"
    19  	"go.mongodb.org/mongo-driver/bson/bsoncodec"
    20  	"go.mongodb.org/mongo-driver/bson/bsonrw"
    21  	"go.mongodb.org/mongo-driver/bson/primitive"
    22  	"go.mongodb.org/mongo-driver/internal/assert"
    23  	"go.mongodb.org/mongo-driver/internal/require"
    24  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    25  )
    26  
    27  var tInt32 = reflect.TypeOf(int32(0))
    28  
    29  func TestMarshalAppendWithRegistry(t *testing.T) {
    30  	for _, tc := range marshalingTestCases {
    31  		t.Run(tc.name, func(t *testing.T) {
    32  			dst := make([]byte, 0, 1024)
    33  			var reg *bsoncodec.Registry
    34  			if tc.reg != nil {
    35  				reg = tc.reg
    36  			} else {
    37  				reg = DefaultRegistry
    38  			}
    39  			got, err := MarshalAppendWithRegistry(reg, dst, tc.val)
    40  			noerr(t, err)
    41  
    42  			if !bytes.Equal(got, tc.want) {
    43  				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
    44  				t.Errorf("Bytes:\n%v\n%v", got, tc.want)
    45  			}
    46  		})
    47  	}
    48  }
    49  
    50  func TestMarshalAppendWithContext(t *testing.T) {
    51  	for _, tc := range marshalingTestCases {
    52  		t.Run(tc.name, func(t *testing.T) {
    53  			dst := make([]byte, 0, 1024)
    54  			var reg *bsoncodec.Registry
    55  			if tc.reg != nil {
    56  				reg = tc.reg
    57  			} else {
    58  				reg = DefaultRegistry
    59  			}
    60  			ec := bsoncodec.EncodeContext{Registry: reg}
    61  			got, err := MarshalAppendWithContext(ec, dst, tc.val)
    62  			noerr(t, err)
    63  
    64  			if !bytes.Equal(got, tc.want) {
    65  				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
    66  				t.Errorf("Bytes:\n%v\n%v", got, tc.want)
    67  			}
    68  		})
    69  	}
    70  }
    71  
    72  func TestMarshalWithRegistry(t *testing.T) {
    73  	for _, tc := range marshalingTestCases {
    74  		t.Run(tc.name, func(t *testing.T) {
    75  			var reg *bsoncodec.Registry
    76  			if tc.reg != nil {
    77  				reg = tc.reg
    78  			} else {
    79  				reg = DefaultRegistry
    80  			}
    81  			got, err := MarshalWithRegistry(reg, tc.val)
    82  			noerr(t, err)
    83  
    84  			if !bytes.Equal(got, tc.want) {
    85  				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
    86  				t.Errorf("Bytes:\n%v\n%v", got, tc.want)
    87  			}
    88  		})
    89  	}
    90  }
    91  
    92  func TestMarshalWithContext(t *testing.T) {
    93  	for _, tc := range marshalingTestCases {
    94  		t.Run(tc.name, func(t *testing.T) {
    95  			var reg *bsoncodec.Registry
    96  			if tc.reg != nil {
    97  				reg = tc.reg
    98  			} else {
    99  				reg = DefaultRegistry
   100  			}
   101  			ec := bsoncodec.EncodeContext{Registry: reg}
   102  			got, err := MarshalWithContext(ec, tc.val)
   103  			noerr(t, err)
   104  
   105  			if !bytes.Equal(got, tc.want) {
   106  				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
   107  				t.Errorf("Bytes:\n%v\n%v", got, tc.want)
   108  			}
   109  		})
   110  	}
   111  }
   112  
   113  func TestMarshalAppend(t *testing.T) {
   114  	for _, tc := range marshalingTestCases {
   115  		t.Run(tc.name, func(t *testing.T) {
   116  			if tc.reg != nil {
   117  				t.Skip() // test requires custom registry
   118  			}
   119  			dst := make([]byte, 0, 1024)
   120  			got, err := MarshalAppend(dst, tc.val)
   121  			noerr(t, err)
   122  
   123  			if !bytes.Equal(got, tc.want) {
   124  				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want)
   125  				t.Errorf("Bytes:\n%v\n%v", got, tc.want)
   126  			}
   127  		})
   128  	}
   129  }
   130  
   131  func TestMarshalExtJSONAppendWithContext(t *testing.T) {
   132  	t.Run("MarshalExtJSONAppendWithContext", func(t *testing.T) {
   133  		dst := make([]byte, 0, 1024)
   134  		type teststruct struct{ Foo int }
   135  		val := teststruct{1}
   136  		ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
   137  		got, err := MarshalExtJSONAppendWithContext(ec, dst, val, true, false)
   138  		noerr(t, err)
   139  		want := []byte(`{"foo":{"$numberInt":"1"}}`)
   140  		if !bytes.Equal(got, want) {
   141  			t.Errorf("Bytes are not equal. got %v; want %v", got, want)
   142  			t.Errorf("Bytes:\n%s\n%s", got, want)
   143  		}
   144  	})
   145  }
   146  
   147  func TestMarshalExtJSONWithContext(t *testing.T) {
   148  	t.Run("MarshalExtJSONWithContext", func(t *testing.T) {
   149  		type teststruct struct{ Foo int }
   150  		val := teststruct{1}
   151  		ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
   152  		got, err := MarshalExtJSONWithContext(ec, val, true, false)
   153  		noerr(t, err)
   154  		want := []byte(`{"foo":{"$numberInt":"1"}}`)
   155  		if !bytes.Equal(got, want) {
   156  			t.Errorf("Bytes are not equal. got %v; want %v", got, want)
   157  			t.Errorf("Bytes:\n%s\n%s", got, want)
   158  		}
   159  	})
   160  }
   161  
   162  func TestMarshal_roundtripFromBytes(t *testing.T) {
   163  	before := []byte{
   164  		// length
   165  		0x1c, 0x0, 0x0, 0x0,
   166  
   167  		// --- begin array ---
   168  
   169  		// type - document
   170  		0x3,
   171  		// key - "foo"
   172  		0x66, 0x6f, 0x6f, 0x0,
   173  
   174  		// length
   175  		0x12, 0x0, 0x0, 0x0,
   176  		// type - string
   177  		0x2,
   178  		// key - "bar"
   179  		0x62, 0x61, 0x72, 0x0,
   180  		// value - string length
   181  		0x4, 0x0, 0x0, 0x0,
   182  		// value - "baz"
   183  		0x62, 0x61, 0x7a, 0x0,
   184  
   185  		// null terminator
   186  		0x0,
   187  
   188  		// --- end array ---
   189  
   190  		// null terminator
   191  		0x0,
   192  	}
   193  
   194  	var doc D
   195  	require.NoError(t, Unmarshal(before, &doc))
   196  
   197  	after, err := Marshal(doc)
   198  	require.NoError(t, err)
   199  
   200  	require.True(t, bytes.Equal(before, after))
   201  }
   202  
   203  func TestMarshal_roundtripFromDoc(t *testing.T) {
   204  	before := D{
   205  		{"foo", "bar"},
   206  		{"baz", int64(-27)},
   207  		{"bing", A{nil, primitive.Regex{Pattern: "word", Options: "i"}}},
   208  	}
   209  
   210  	b, err := Marshal(before)
   211  	require.NoError(t, err)
   212  
   213  	var after D
   214  	require.NoError(t, Unmarshal(b, &after))
   215  
   216  	if !cmp.Equal(after, before) {
   217  		t.Errorf("Documents to not match. got %v; want %v", after, before)
   218  	}
   219  }
   220  
   221  func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) {
   222  	// Encoders that have caches for recursive encoder lookup should not be shared across Registry instances. Otherwise,
   223  	// the first EncodeValue call would cache an encoder and a subsequent call would see that encoder even if a
   224  	// different Registry is used.
   225  
   226  	// Create a custom Registry that negates int32 values when encoding.
   227  	var encodeInt32 bsoncodec.ValueEncoderFunc = func(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
   228  		if val.Kind() != reflect.Int32 {
   229  			return fmt.Errorf("expected kind to be int32, got %v", val.Kind())
   230  		}
   231  
   232  		return vw.WriteInt32(int32(val.Int()) * -1)
   233  	}
   234  	customReg := NewRegistryBuilder().
   235  		RegisterTypeEncoder(tInt32, encodeInt32).
   236  		Build()
   237  
   238  	// Helper function to run the test and make assertions. The provided original value should result in the document
   239  	// {"x": {$numberInt: 1}} when marshalled with the default registry.
   240  	verifyResults := func(t *testing.T, original interface{}) {
   241  		// Marshal using the default and custom registries. Assert that the result is {x: 1} and {x: -1}, respectively.
   242  
   243  		first, err := Marshal(original)
   244  		assert.Nil(t, err, "Marshal error: %v", err)
   245  		expectedFirst := Raw(bsoncore.BuildDocumentFromElements(
   246  			nil,
   247  			bsoncore.AppendInt32Element(nil, "x", 1),
   248  		))
   249  		assert.Equal(t, expectedFirst, Raw(first), "expected document %v, got %v", expectedFirst, Raw(first))
   250  
   251  		second, err := MarshalWithRegistry(customReg, original)
   252  		assert.Nil(t, err, "Marshal error: %v", err)
   253  		expectedSecond := Raw(bsoncore.BuildDocumentFromElements(
   254  			nil,
   255  			bsoncore.AppendInt32Element(nil, "x", -1),
   256  		))
   257  		assert.Equal(t, expectedSecond, Raw(second), "expected document %v, got %v", expectedSecond, Raw(second))
   258  	}
   259  
   260  	t.Run("struct", func(t *testing.T) {
   261  		type Struct struct {
   262  			X int32
   263  		}
   264  		verifyResults(t, Struct{
   265  			X: 1,
   266  		})
   267  	})
   268  	t.Run("pointer", func(t *testing.T) {
   269  		i32 := int32(1)
   270  		verifyResults(t, M{
   271  			"x": &i32,
   272  		})
   273  	})
   274  }
   275  
   276  func TestNullBytes(t *testing.T) {
   277  	t.Run("element keys", func(t *testing.T) {
   278  		doc := D{{"a\x00", "foobar"}}
   279  		res, err := Marshal(doc)
   280  		want := errors.New("BSON element key cannot contain null bytes")
   281  		assert.Equal(t, want, err, "expected Marshal error %v, got error %v with result %q", want, err, Raw(res))
   282  	})
   283  
   284  	t.Run("regex values", func(t *testing.T) {
   285  		wantErr := errors.New("BSON regex values cannot contain null bytes")
   286  
   287  		testCases := []struct {
   288  			name    string
   289  			pattern string
   290  			options string
   291  		}{
   292  			{"null bytes in pattern", "a\x00", "i"},
   293  			{"null bytes in options", "pattern", "i\x00"},
   294  		}
   295  		for _, tc := range testCases {
   296  			t.Run(tc.name, func(t *testing.T) {
   297  				regex := primitive.Regex{
   298  					Pattern: tc.pattern,
   299  					Options: tc.options,
   300  				}
   301  				res, err := Marshal(D{{"foo", regex}})
   302  				assert.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res))
   303  			})
   304  		}
   305  	})
   306  
   307  	t.Run("sub document field name", func(t *testing.T) {
   308  		doc := D{{"foo", D{{"foobar", D{{"a\x00", "foobar"}}}}}}
   309  		res, err := Marshal(doc)
   310  		wantErr := errors.New("BSON element key cannot contain null bytes")
   311  		assert.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res))
   312  	})
   313  }
   314  
   315  func TestMarshalExtJSONIndent(t *testing.T) {
   316  	type indentTestCase struct {
   317  		name            string
   318  		val             interface{}
   319  		expectedExtJSON string
   320  	}
   321  
   322  	// expectedExtJSON must be written as below because single-quoted
   323  	// literal strings capture undesired code formatting tabs
   324  	testCases := []indentTestCase{
   325  		{
   326  			"empty val",
   327  			struct{}{},
   328  			`{}`,
   329  		},
   330  		{
   331  			"embedded struct",
   332  			struct {
   333  				Embedded interface{} `json:"embedded"`
   334  				Foo      string      `json:"foo"`
   335  			}{
   336  				Embedded: struct {
   337  					Name string `json:"name"`
   338  					Word string `json:"word"`
   339  				}{
   340  					Name: "test",
   341  					Word: "word",
   342  				},
   343  				Foo: "bar",
   344  			},
   345  			"{\n\t\"embedded\": {\n\t\t\"name\": \"test\",\n\t\t\"word\": \"word\"\n\t},\n\t\"foo\": \"bar\"\n}",
   346  		},
   347  		{
   348  			"date struct",
   349  			struct {
   350  				Foo  string    `json:"foo"`
   351  				Date time.Time `json:"date"`
   352  			}{
   353  				Foo:  "bar",
   354  				Date: time.Date(2000, time.January, 1, 12, 0, 0, 0, time.UTC),
   355  			},
   356  			"{\n\t\"foo\": \"bar\",\n\t\"date\": {\n\t\t\"$date\": {\n\t\t\t\"$numberLong\": \"946728000000\"\n\t\t}\n\t}\n}",
   357  		},
   358  		{
   359  			"float struct",
   360  			struct {
   361  				Foo   string  `json:"foo"`
   362  				Float float32 `json:"float"`
   363  			}{
   364  				Foo:   "bar",
   365  				Float: 3.14,
   366  			},
   367  			"{\n\t\"foo\": \"bar\",\n\t\"float\": {\n\t\t\"$numberDouble\": \"3.140000104904175\"\n\t}\n}",
   368  		},
   369  	}
   370  
   371  	for _, tc := range testCases {
   372  		tc := tc
   373  		t.Run(tc.name, func(t *testing.T) {
   374  			t.Parallel()
   375  			extJSONBytes, err := MarshalExtJSONIndent(tc.val, true, false, "", "\t")
   376  			assert.Nil(t, err, "Marshal indent error: %v", err)
   377  
   378  			expectedExtJSONBytes := []byte(tc.expectedExtJSON)
   379  
   380  			assert.Equal(t, expectedExtJSONBytes, extJSONBytes, "expected:\n%s\ngot:\n%s", expectedExtJSONBytes, extJSONBytes)
   381  		})
   382  	}
   383  }
   384  
   385  func TestMarshalConcurrently(t *testing.T) {
   386  	t.Parallel()
   387  
   388  	const size = 10_000
   389  
   390  	wg := sync.WaitGroup{}
   391  	wg.Add(size)
   392  	for i := 0; i < size; i++ {
   393  		go func() {
   394  			defer wg.Done()
   395  			_, _ = Marshal(struct{ LastError error }{})
   396  		}()
   397  	}
   398  	wg.Wait()
   399  }
   400  

View as plain text