...
1 package ctxwatch_test
2
3 import (
4 "context"
5 "sync/atomic"
6 "testing"
7 "time"
8
9 "github.com/jackc/pgx/v5/pgconn/internal/ctxwatch"
10 "github.com/stretchr/testify/require"
11 )
12
13 func TestContextWatcherContextCancelled(t *testing.T) {
14 canceledChan := make(chan struct{})
15 cleanupCalled := false
16 cw := ctxwatch.NewContextWatcher(func() {
17 canceledChan <- struct{}{}
18 }, func() {
19 cleanupCalled = true
20 })
21
22 ctx, cancel := context.WithCancel(context.Background())
23 cw.Watch(ctx)
24 cancel()
25
26 select {
27 case <-canceledChan:
28 case <-time.NewTimer(time.Second).C:
29 t.Fatal("Timed out waiting for cancel func to be called")
30 }
31
32 cw.Unwatch()
33
34 require.True(t, cleanupCalled, "Cleanup func was not called")
35 }
36
37 func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
38 cw := ctxwatch.NewContextWatcher(func() {
39 t.Error("cancel func should not have been called")
40 }, func() {
41 t.Error("cleanup func should not have been called")
42 })
43
44 ctx, cancel := context.WithCancel(context.Background())
45 cw.Watch(ctx)
46 cw.Unwatch()
47 cancel()
48 }
49
50 func TestContextWatcherMultipleWatchPanics(t *testing.T) {
51 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
52
53 ctx, cancel := context.WithCancel(context.Background())
54 defer cancel()
55 cw.Watch(ctx)
56 defer cw.Unwatch()
57
58 ctx2, cancel2 := context.WithCancel(context.Background())
59 defer cancel2()
60 require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times")
61 }
62
63 func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
64 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
65 cw.Unwatch()
66
67 ctx, cancel := context.WithCancel(context.Background())
68 defer cancel()
69 cw.Watch(ctx)
70 cw.Unwatch()
71 cw.Unwatch()
72 }
73
74 func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
75 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
76
77 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
78 defer cancel()
79 cw.Watch(ctx)
80
81 go cw.Unwatch()
82 go cw.Unwatch()
83
84 <-ctx.Done()
85 }
86
87 func TestContextWatcherStress(t *testing.T) {
88 var cancelFuncCalls int64
89 var cleanupFuncCalls int64
90
91 cw := ctxwatch.NewContextWatcher(func() {
92 atomic.AddInt64(&cancelFuncCalls, 1)
93 }, func() {
94 atomic.AddInt64(&cleanupFuncCalls, 1)
95 })
96
97 cycleCount := 100000
98
99 for i := 0; i < cycleCount; i++ {
100 ctx, cancel := context.WithCancel(context.Background())
101 cw.Watch(ctx)
102 if i%2 == 0 {
103 cancel()
104 }
105
106
107 if i%333 == 0 {
108
109
110 time.Sleep(time.Nanosecond)
111 }
112
113 cw.Unwatch()
114 if i%2 == 1 {
115 cancel()
116 }
117 }
118
119 actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls)
120 actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls)
121
122 if actualCancelFuncCalls == 0 {
123 t.Fatal("actualCancelFuncCalls == 0")
124 }
125
126 maxCancelFuncCalls := int64(cycleCount) / 2
127 if actualCancelFuncCalls > maxCancelFuncCalls {
128 t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls)
129 }
130
131 if actualCancelFuncCalls != actualCleanupFuncCalls {
132 t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls)
133 }
134 }
135
136 func BenchmarkContextWatcherUncancellable(b *testing.B) {
137 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
138
139 for i := 0; i < b.N; i++ {
140 cw.Watch(context.Background())
141 cw.Unwatch()
142 }
143 }
144
145 func BenchmarkContextWatcherCancelled(b *testing.B) {
146 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
147
148 for i := 0; i < b.N; i++ {
149 ctx, cancel := context.WithCancel(context.Background())
150 cw.Watch(ctx)
151 cancel()
152 cw.Unwatch()
153 }
154 }
155
156 func BenchmarkContextWatcherCancellable(b *testing.B) {
157 cw := ctxwatch.NewContextWatcher(func() {}, func() {})
158
159 ctx, cancel := context.WithCancel(context.Background())
160 defer cancel()
161
162 for i := 0; i < b.N; i++ {
163 cw.Watch(ctx)
164 cw.Unwatch()
165 }
166 }
167
View as plain text