...
1 package ctxwatch_test
2
3 import (
4 "context"
5 "sync/atomic"
6 "testing"
7 "time"
8
9 "github.com/jackc/pgconn/internal/ctxwatch"
10 "github.com/stretchr/testify/require"
11 )
12
13 func TestContextWatcherContextCancelled(t *testing.T) {
14 canceledChan := make(chan struct{})
15 cleanupCalled := false
16 cw := ctxwatch.NewContextWatcher(func() {
17 canceledChan <- struct{}{}
18 }, func() {
19 cleanupCalled = true
20 })
21
22 ctx, cancel := context.WithCancel(context.Background())
23 cw.Watch(ctx)
24 cancel()
25
26 select {
27 case <-canceledChan:
28 case <-time.NewTimer(time.Second).C:
29 t.Fatal("Timed out waiting for cancel func to be called")
30 }
31
32 cw.Unwatch()
33
34 require.True(t, cleanupCalled, "Cleanup func was not called")
35 }
36
37 func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
38 cw := ctxwatch.NewContextWatcher(func() {
39 t.Error("cancel func should not have been called")
40 }, func() {
41 t.Error("cleanup func should not have been called")
42 })
43
44 ctx, cancel := context.WithCancel(context.Background())
45 cw.Watch(ctx)
46 cw.Unwatch()
47 cancel()
48 }
49
50 func TestContextWatcherMultipleWatchPanics(t *testing.T) {
51 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
52
53 ctx, cancel := context.WithCancel(context.Background())
54 defer cancel()
55 cw.Watch(ctx)
56
57 ctx2, cancel2 := context.WithCancel(context.Background())
58 defer cancel2()
59 require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times")
60 }
61
62 func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
63 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
64 cw.Unwatch()
65
66 ctx, cancel := context.WithCancel(context.Background())
67 defer cancel()
68 cw.Watch(ctx)
69 cw.Unwatch()
70 cw.Unwatch()
71 }
72
73 func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
74 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
75
76 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
77 defer cancel()
78 cw.Watch(ctx)
79
80 go cw.Unwatch()
81 go cw.Unwatch()
82
83 <-ctx.Done()
84 }
85
86 func TestContextWatcherStress(t *testing.T) {
87 var cancelFuncCalls int64
88 var cleanupFuncCalls int64
89
90 cw := ctxwatch.NewContextWatcher(func() {
91 atomic.AddInt64(&cancelFuncCalls, 1)
92 }, func() {
93 atomic.AddInt64(&cleanupFuncCalls, 1)
94 })
95
96 cycleCount := 100000
97
98 for i := 0; i < cycleCount; i++ {
99 ctx, cancel := context.WithCancel(context.Background())
100 cw.Watch(ctx)
101 if i%2 == 0 {
102 cancel()
103 }
104
105
106 if i%3 == 0 {
107 time.Sleep(time.Nanosecond)
108 }
109
110 cw.Unwatch()
111 if i%2 == 1 {
112 cancel()
113 }
114 }
115
116 actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls)
117 actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls)
118
119 if actualCancelFuncCalls == 0 {
120 t.Fatal("actualCancelFuncCalls == 0")
121 }
122
123 maxCancelFuncCalls := int64(cycleCount) / 2
124 if actualCancelFuncCalls > maxCancelFuncCalls {
125 t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls)
126 }
127
128 if actualCancelFuncCalls != actualCleanupFuncCalls {
129 t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls)
130 }
131 }
132
133 func BenchmarkContextWatcherUncancellable(b *testing.B) {
134 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
135
136 for i := 0; i < b.N; i++ {
137 cw.Watch(context.Background())
138 cw.Unwatch()
139 }
140 }
141
142 func BenchmarkContextWatcherCancelled(b *testing.B) {
143 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
144
145 for i := 0; i < b.N; i++ {
146 ctx, cancel := context.WithCancel(context.Background())
147 cw.Watch(ctx)
148 cancel()
149 cw.Unwatch()
150 }
151 }
152
153 func BenchmarkContextWatcherCancellable(b *testing.B) {
154 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
155
156 ctx, cancel := context.WithCancel(context.Background())
157 defer cancel()
158
159 for i := 0; i < b.N; i++ {
160 cw.Watch(ctx)
161 cw.Unwatch()
162 }
163 }
164
View as plain text