1
2
3
4
5
6
7 package bsoncodec
8
9 import (
10 "fmt"
11 "math"
12 "reflect"
13
14 "go.mongodb.org/mongo-driver/bson/bsonoptions"
15 "go.mongodb.org/mongo-driver/bson/bsonrw"
16 "go.mongodb.org/mongo-driver/bson/bsontype"
17 )
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// UIntCodec is the Codec used for unsigned integer values (uint, uint8,
// uint16, uint32 and uint64).
type UIntCodec struct {
	// EncodeToMinSize, when true, makes EncodeValue write uint and uint32
	// values that fit in an int32 as BSON int32 rather than int64. It has no
	// effect on uint64 values (only EncodeContext.MinSize can shrink those),
	// and uint8/uint16 values are always written as int32 regardless.
	EncodeToMinSize bool
}
44
45 var (
46 defaultUIntCodec = NewUIntCodec()
47
48
49
50 _ typeDecoder = defaultUIntCodec
51 )
52
53
54
55
56
57 func NewUIntCodec(opts ...*bsonoptions.UIntCodecOptions) *UIntCodec {
58 uintOpt := bsonoptions.MergeUIntCodecOptions(opts...)
59
60 codec := UIntCodec{}
61 if uintOpt.EncodeToMinSize != nil {
62 codec.EncodeToMinSize = *uintOpt.EncodeToMinSize
63 }
64 return &codec
65 }
66
67
68 func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
69 switch val.Kind() {
70 case reflect.Uint8, reflect.Uint16:
71 return vw.WriteInt32(int32(val.Uint()))
72 case reflect.Uint, reflect.Uint32, reflect.Uint64:
73 u64 := val.Uint()
74
75
76 useMinSize := ec.MinSize || (uic.EncodeToMinSize && val.Kind() != reflect.Uint64)
77
78 if u64 <= math.MaxInt32 && useMinSize {
79 return vw.WriteInt32(int32(u64))
80 }
81 if u64 > math.MaxInt64 {
82 return fmt.Errorf("%d overflows int64", u64)
83 }
84 return vw.WriteInt64(int64(u64))
85 }
86
87 return ValueEncoderError{
88 Name: "UintEncodeValue",
89 Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
90 Received: val,
91 }
92 }
93
94 func (uic *UIntCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
95 var i64 int64
96 var err error
97 switch vrType := vr.Type(); vrType {
98 case bsontype.Int32:
99 i32, err := vr.ReadInt32()
100 if err != nil {
101 return emptyValue, err
102 }
103 i64 = int64(i32)
104 case bsontype.Int64:
105 i64, err = vr.ReadInt64()
106 if err != nil {
107 return emptyValue, err
108 }
109 case bsontype.Double:
110 f64, err := vr.ReadDouble()
111 if err != nil {
112 return emptyValue, err
113 }
114 if !dc.Truncate && math.Floor(f64) != f64 {
115 return emptyValue, errCannotTruncate
116 }
117 if f64 > float64(math.MaxInt64) {
118 return emptyValue, fmt.Errorf("%g overflows int64", f64)
119 }
120 i64 = int64(f64)
121 case bsontype.Boolean:
122 b, err := vr.ReadBoolean()
123 if err != nil {
124 return emptyValue, err
125 }
126 if b {
127 i64 = 1
128 }
129 case bsontype.Null:
130 if err = vr.ReadNull(); err != nil {
131 return emptyValue, err
132 }
133 case bsontype.Undefined:
134 if err = vr.ReadUndefined(); err != nil {
135 return emptyValue, err
136 }
137 default:
138 return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType)
139 }
140
141 switch t.Kind() {
142 case reflect.Uint8:
143 if i64 < 0 || i64 > math.MaxUint8 {
144 return emptyValue, fmt.Errorf("%d overflows uint8", i64)
145 }
146
147 return reflect.ValueOf(uint8(i64)), nil
148 case reflect.Uint16:
149 if i64 < 0 || i64 > math.MaxUint16 {
150 return emptyValue, fmt.Errorf("%d overflows uint16", i64)
151 }
152
153 return reflect.ValueOf(uint16(i64)), nil
154 case reflect.Uint32:
155 if i64 < 0 || i64 > math.MaxUint32 {
156 return emptyValue, fmt.Errorf("%d overflows uint32", i64)
157 }
158
159 return reflect.ValueOf(uint32(i64)), nil
160 case reflect.Uint64:
161 if i64 < 0 {
162 return emptyValue, fmt.Errorf("%d overflows uint64", i64)
163 }
164
165 return reflect.ValueOf(uint64(i64)), nil
166 case reflect.Uint:
167 if i64 < 0 || int64(uint(i64)) != i64 {
168 return emptyValue, fmt.Errorf("%d overflows uint", i64)
169 }
170
171 return reflect.ValueOf(uint(i64)), nil
172 default:
173 return emptyValue, ValueDecoderError{
174 Name: "UintDecodeValue",
175 Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
176 Received: reflect.Zero(t),
177 }
178 }
179 }
180
181
182 func (uic *UIntCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
183 if !val.CanSet() {
184 return ValueDecoderError{
185 Name: "UintDecodeValue",
186 Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
187 Received: val,
188 }
189 }
190
191 elem, err := uic.decodeType(dc, vr, val.Type())
192 if err != nil {
193 return err
194 }
195
196 val.SetUint(elem.Uint())
197 return nil
198 }
199
View as plain text