...
1 package rifs
2
3 import (
4 "io"
5 "time"
6
7 "github.com/dsoprea/go-logging"
8 )
9
10
// ProgressFunc is invoked by the progress wrappers in this file to report
// I/O activity. n is the number of bytes moved by the underlying Read or
// Write call, duration is how long that call took, and isEOF is true only
// when the underlying reader returned io.EOF (writes always pass false).
// A non-nil return value is handed back, wrapped, to the wrapper's caller.
type ProgressFunc func(n int, duration time.Duration, isEOF bool) error
12
13
14
15 type WriteProgressWrapper struct {
16 w io.Writer
17 progressCb ProgressFunc
18 }
19
20
21 func NewWriteProgressWrapper(w io.Writer, progressCb ProgressFunc) io.Writer {
22 return &WriteProgressWrapper{
23 w: w,
24 progressCb: progressCb,
25 }
26 }
27
28
29 func (wpw *WriteProgressWrapper) Write(buffer []byte) (n int, err error) {
30 defer func() {
31 if state := recover(); state != nil {
32 err = log.Wrap(state.(error))
33 }
34 }()
35
36 startAt := time.Now()
37
38 n, err = wpw.w.Write(buffer)
39 log.PanicIf(err)
40
41 duration := time.Since(startAt)
42
43 err = wpw.progressCb(n, duration, false)
44 log.PanicIf(err)
45
46 return n, nil
47 }
48
49
50
51 type ReadProgressWrapper struct {
52 r io.Reader
53 progressCb ProgressFunc
54 }
55
56
57 func NewReadProgressWrapper(r io.Reader, progressCb ProgressFunc) io.Reader {
58 return &ReadProgressWrapper{
59 r: r,
60 progressCb: progressCb,
61 }
62 }
63
64
65 func (rpw *ReadProgressWrapper) Read(buffer []byte) (n int, err error) {
66 defer func() {
67 if state := recover(); state != nil {
68 err = log.Wrap(state.(error))
69 }
70 }()
71
72 startAt := time.Now()
73
74 n, err = rpw.r.Read(buffer)
75
76 duration := time.Since(startAt)
77
78 if err != nil {
79 if err == io.EOF {
80 errInner := rpw.progressCb(n, duration, true)
81 log.PanicIf(errInner)
82
83 return n, err
84 }
85
86 log.Panic(err)
87 }
88
89 err = rpw.progressCb(n, duration, false)
90 log.PanicIf(err)
91
92 return n, nil
93 }
94
View as plain text