...
1 package ctxwatch
2
3 import (
4 "context"
5 "sync"
6 )
7
8
9
// ContextWatcher runs a callback when a watched context is done. It watches at
// most one context at a time: Watch starts a watch, Unwatch ends it, and the
// watcher may then be reused for another context.
//
// The callbacks must not call back into the same ContextWatcher. onCancel runs
// on the watch goroutine, which Unwatch waits for. onUnwatchAfterCancel runs
// while Unwatch holds the lock. Re-entering Watch or Unwatch from either
// callback therefore deadlocks or panics.
type ContextWatcher struct {
	// onCancel is run on the watch goroutine once the watched context is done.
	onCancel func()

	// onUnwatchAfterCancel is run by Unwatch when onCancel fired during the
	// watch being ended.
	onUnwatchAfterCancel func()

	// unwatchChan is unbuffered, so a send from Unwatch completes only after
	// the watch goroutine has received it.
	unwatchChan chan struct{}

	// lock serializes Watch and Unwatch.
	lock sync.Mutex

	// watchInProgress reports whether a watch goroutine is running; it is only
	// accessed with lock held.
	watchInProgress bool

	// onCancelWasCalled is set by the watch goroutine after onCancel returns,
	// without taking lock. Unwatch reads it only after its unbuffered send has
	// completed, which orders the read after the write.
	onCancelWasCalled bool
}

// NewContextWatcher returns a ContextWatcher that runs onCancel when a watched
// context is done. It runs onUnwatchAfterCancel when Unwatch ends a watch
// during which onCancel was run.
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
	return &ContextWatcher{
		onCancel:             onCancel,
		onUnwatchAfterCancel: onUnwatchAfterCancel,
		unwatchChan:          make(chan struct{}),
	}
}

// Watch starts watching ctx. If ctx is done before Unwatch is called, onCancel
// is run on a separate goroutine. Watch panics if the previous watch has not
// been ended with Unwatch.
//
// A context whose Done method returns nil can never be canceled. For such a
// context no goroutine is started, and the following Unwatch has nothing to do.
func (cw *ContextWatcher) Watch(ctx context.Context) {
	cw.lock.Lock()
	defer cw.lock.Unlock()

	if cw.watchInProgress {
		panic("Watch already in progress")
	}
	cw.onCancelWasCalled = false

	// Successive calls to Done return the same channel, so one call suffices.
	done := ctx.Done()
	if done == nil {
		// watchInProgress is already false here (guarded above).
		return
	}

	cw.watchInProgress = true
	go func() {
		select {
		case <-done:
			cw.onCancel()
			cw.onCancelWasCalled = true
			// Stay until Unwatch sends. Unwatch therefore cannot return while
			// onCancel is still running, and it observes the flag set above.
			<-cw.unwatchChan
		case <-cw.unwatchChan:
		}
	}()
}

// Unwatch ends the current watch. If onCancel runs during this watch, Unwatch
// waits for it to return and then calls onUnwatchAfterCancel before returning.
// Calling Unwatch when no watch is in progress does nothing.
func (cw *ContextWatcher) Unwatch() {
	cw.lock.Lock()
	defer cw.lock.Unlock()

	if !cw.watchInProgress {
		return
	}

	// Blocks until the watch goroutine receives, which it does only after
	// onCancel (if it ran) has returned.
	cw.unwatchChan <- struct{}{}
	if cw.onCancelWasCalled {
		cw.onUnwatchAfterCancel()
	}
	cw.watchInProgress = false
}
74
View as plain text