package gridfs

import (
	"context"
	"errors"
	"io"
	"math"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
)
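
// ErrWrongIndex is returned when a chunk read from the server does not have the expected index.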
var ErrWrongIndex = errors.New("chunk index does not match expected index")
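
// ErrWrongSize is returned when a chunk read from the server does not have the expected size.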
var ErrWrongSize = errors.New("chunk size does not match expected size")

var errNoMoreChunks = errors.New("no more chunks remaining")
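
// DownloadStream is an io.Reader that can be used to download a file from a GridFS bucket.
//
// A minimal usage sketch, assuming a *Bucket from this package and the ID of an uploaded
// file (error handling elided):
//
//	ds, _ := bucket.OpenDownloadStream(fileID)
//	defer ds.Close()
//	data, _ := io.ReadAll(ds)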
type DownloadStream struct {
	numChunks     int32
	chunkSize     int32
	cursor        *mongo.Cursor
	done          bool
	closed        bool
	buffer        []byte // stores up to one chunk when the caller's buffer is smaller than a chunk
	bufferStart   int
	bufferEnd     int
	expectedChunk int32 // index of the next expected chunk
	readDeadline  time.Time
	fileLen       int64

	// file is the value returned to callers of GetFile.
	file *File
}
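
// File represents a file stored in GridFS. A File describing the current download can be
// accessed through DownloadStream.GetFile.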
type File struct {
	// ID is the file's ID as stored in the files collection.
	ID interface{}

	// Length is the length of the file in bytes.
	Length int64

	// ChunkSize is the size, in bytes, of the chunks used to store the file.
	ChunkSize int32

	// UploadDate is the time the file was added to GridFS.
	UploadDate time.Time

	// Name is the name of the file.
	Name string

	// Metadata is additional data that was specified when the file was uploaded.
	Metadata bson.Raw
}

var _ bson.Unmarshaler = (*File)(nil)
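
// unmarshalFile mirrors File with BSON struct tags; UnmarshalBSON decodes documents from the
// files collection into it before copying the fields to a File.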
type unmarshalFile struct {
	ID         interface{} `bson:"_id"`
	Length     int64       `bson:"length"`
	ChunkSize  int32       `bson:"chunkSize"`
	UploadDate time.Time   `bson:"uploadDate"`
	Name       string      `bson:"filename"`
	Metadata   bson.Raw    `bson:"metadata"`
}
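
// UnmarshalBSON implements the bson.Unmarshaler interface.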
func (f *File) UnmarshalBSON(data []byte) error {
	var temp unmarshalFile
	if err := bson.Unmarshal(data, &temp); err != nil {
		return err
	}

	f.ID = temp.ID
	f.Length = temp.Length
	f.ChunkSize = temp.ChunkSize
	f.UploadDate = temp.UploadDate
	f.Name = temp.Name
	f.Metadata = temp.Metadata
	return nil
}
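
// newDownloadStream creates a DownloadStream that reads the given file's chunks from cursor.
// A nil cursor produces a stream that immediately reports EOF.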
func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, file *File) *DownloadStream {
	numChunks := int32(math.Ceil(float64(file.Length) / float64(chunkSize)))

	return &DownloadStream{
		numChunks: numChunks,
		chunkSize: chunkSize,
		cursor:    cursor,
		buffer:    make([]byte, chunkSize),
		done:      cursor == nil,
		fileLen:   file.Length,
		file:      file,
	}
}
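
// Close closes this DownloadStream and, if one is open, the underlying cursor.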
func (ds *DownloadStream) Close() error {
	if ds.closed {
		return ErrStreamClosed
	}

	ds.closed = true
	if ds.cursor != nil {
		return ds.cursor.Close(context.Background())
	}
	return nil
}
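
// SetReadDeadline sets the deadline used by subsequent Read and Skip calls.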
func (ds *DownloadStream) SetReadDeadline(t time.Time) error {
	if ds.closed {
		return ErrStreamClosed
	}

	ds.readDeadline = t
	return nil
}
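
// Read fills p with file data, fetching chunks from the server as needed. It returns io.EOF
// once all of the file's chunks have been read.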
func (ds *DownloadStream) Read(p []byte) (int, error) {
	if ds.closed {
		return 0, ErrStreamClosed
	}

	if ds.done {
		return 0, io.EOF
	}

	ctx, cancel := deadlineContext(ds.readDeadline)
	if cancel != nil {
		defer cancel()
	}

	bytesCopied := 0
	var err error
	for bytesCopied < len(p) {
		if ds.bufferStart >= ds.bufferEnd {
			// The buffer is empty, so load the next chunk.
			err = ds.fillBuffer(ctx)
			if err != nil {
				if errors.Is(err, errNoMoreChunks) {
					if bytesCopied == 0 {
						ds.done = true
						return 0, io.EOF
					}
					return bytesCopied, nil
				}
				return bytesCopied, err
			}
		}

		copied := copy(p[bytesCopied:], ds.buffer[ds.bufferStart:ds.bufferEnd])

		bytesCopied += copied
		ds.bufferStart += copied
	}

	return len(p), nil
}
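
// Skip skips over the given number of bytes and returns the number of bytes actually skipped,
// which may be less than skip if the end of the file is reached first.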
func (ds *DownloadStream) Skip(skip int64) (int64, error) {
	if ds.closed {
		return 0, ErrStreamClosed
	}

	if ds.done {
		return 0, nil
	}

	ctx, cancel := deadlineContext(ds.readDeadline)
	if cancel != nil {
		defer cancel()
	}

	var skipped int64
	var err error

	for skipped < skip {
		if ds.bufferStart >= ds.bufferEnd {
			// The buffer is empty, so load the next chunk.
			err = ds.fillBuffer(ctx)
			if err != nil {
				if errors.Is(err, errNoMoreChunks) {
					return skipped, nil
				}
				return skipped, err
			}
		}

		toSkip := skip - skipped
		// Skip no more than the number of bytes remaining in the buffer.
		bufferRemaining := ds.bufferEnd - ds.bufferStart
		if toSkip > int64(bufferRemaining) {
			toSkip = int64(bufferRemaining)
		}

		skipped += toSkip
		ds.bufferStart += int(toSkip)
	}

	return skip, nil
}
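
// GetFile returns a File object describing the file being downloaded.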
func (ds *DownloadStream) GetFile() *File {
	return ds.file
}
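
// fillBuffer loads the next chunk from the cursor into ds.buffer, verifying its index and size.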
func (ds *DownloadStream) fillBuffer(ctx context.Context) error {
	if !ds.cursor.Next(ctx) {
		ds.done = true

		// A cursor error takes precedence over running out of chunks.
		if ds.cursor.Err() != nil {
			_ = ds.cursor.Close(ctx)
			return ds.cursor.Err()
		}

		// If the cursor is exhausted before all expected chunks were seen, a chunk is missing.
		if ds.expectedChunk != ds.numChunks {
			return ErrWrongIndex
		}
		return errNoMoreChunks
	}

	chunkIndex, err := ds.cursor.Current.LookupErr("n")
	if err != nil {
		return err
	}

	var chunkIndexInt32 int32
	if chunkIndexInt64, ok := chunkIndex.Int64OK(); ok {
		chunkIndexInt32 = int32(chunkIndexInt64)
	} else {
		chunkIndexInt32 = chunkIndex.Int32()
	}

	if chunkIndexInt32 != ds.expectedChunk {
		return ErrWrongIndex
	}

	ds.expectedChunk++
	data, err := ds.cursor.Current.LookupErr("data")
	if err != nil {
		return err
	}

	_, dataBytes := data.Binary()
	copied := copy(ds.buffer, dataBytes)

	bytesLen := int32(len(dataBytes))
	if ds.expectedChunk == ds.numChunks {
		// The last chunk may be smaller than chunkSize, but it must contain exactly the
		// remaining bytes of the file.
		bytesDownloaded := int64(ds.chunkSize) * (int64(ds.expectedChunk) - int64(1))
		bytesRemaining := ds.fileLen - bytesDownloaded

		if int64(bytesLen) != bytesRemaining {
			return ErrWrongSize
		}
	} else if bytesLen != ds.chunkSize {
		// Every chunk other than the last must be exactly chunkSize bytes.
		return ErrWrongSize
	}

	ds.bufferStart = 0
	ds.bufferEnd = copied

	return nil
}