...
1 package interrupt_handler
2
3 import (
4 "os"
5 "os/signal"
6 "sync"
7 "syscall"
8 "time"
9
10 "github.com/onsi/ginkgo/v2/internal/parallel_support"
11 )
12
// ABORT_POLLING_INTERVAL is how often the interrupt handler asks its
// parallel_support client whether another Ginkgo process has requested an
// abort. It is a package-level variable, so it can be reassigned; the value
// is read when a handler starts its polling ticker.
var ABORT_POLLING_INTERVAL time.Duration = time.Second / 2
14
// InterruptCause records what triggered the most recent interrupt.
type InterruptCause uint

// Interrupt causes. The zero value, InterruptCauseInvalid, is not a real
// cause; it is what an uninterrupted handler reports.
const (
	InterruptCauseInvalid             InterruptCause = 0
	InterruptCauseSignal              InterruptCause = 1
	InterruptCauseAbortByOtherProcess InterruptCause = 2
)
22
// InterruptLevel measures how far a run has escalated in response to
// repeated interrupts.
type InterruptLevel uint

// Interrupt levels, in escalation order. The handler moves up exactly one
// level per interrupt and never goes past InterruptLevelBailOut.
const (
	InterruptLevelUninterrupted   InterruptLevel = 0
	InterruptLevelCleanupAndReport InterruptLevel = 1
	InterruptLevelReportOnly      InterruptLevel = 2
	InterruptLevelBailOut         InterruptLevel = 3
)
31
32 func (ic InterruptCause) String() string {
33 switch ic {
34 case InterruptCauseSignal:
35 return "Interrupted by User"
36 case InterruptCauseAbortByOtherProcess:
37 return "Interrupted by Other Ginkgo Process"
38 }
39 return "INVALID_INTERRUPT_CAUSE"
40 }
41
42 type InterruptStatus struct {
43 Channel chan interface{}
44 Level InterruptLevel
45 Cause InterruptCause
46 }
47
48 func (s InterruptStatus) Interrupted() bool {
49 return s.Level != InterruptLevelUninterrupted
50 }
51
52 func (s InterruptStatus) Message() string {
53 return s.Cause.String()
54 }
55
56 func (s InterruptStatus) ShouldIncludeProgressReport() bool {
57 return s.Cause != InterruptCauseAbortByOtherProcess
58 }
59
60 type InterruptHandlerInterface interface {
61 Status() InterruptStatus
62 }
63
64 type InterruptHandler struct {
65 c chan interface{}
66 lock *sync.Mutex
67 level InterruptLevel
68 cause InterruptCause
69 client parallel_support.Client
70 stop chan interface{}
71 signals []os.Signal
72 requestAbortCheck chan interface{}
73 }
74
75 func NewInterruptHandler(client parallel_support.Client, signals ...os.Signal) *InterruptHandler {
76 if len(signals) == 0 {
77 signals = []os.Signal{os.Interrupt, syscall.SIGTERM}
78 }
79 handler := &InterruptHandler{
80 c: make(chan interface{}),
81 lock: &sync.Mutex{},
82 stop: make(chan interface{}),
83 requestAbortCheck: make(chan interface{}),
84 client: client,
85 signals: signals,
86 }
87 handler.registerForInterrupts()
88 return handler
89 }
90
91 func (handler *InterruptHandler) Stop() {
92 close(handler.stop)
93 }
94
95 func (handler *InterruptHandler) registerForInterrupts() {
96
97 signalChannel := make(chan os.Signal, 1)
98 signal.Notify(signalChannel, handler.signals...)
99
100
101 var abortChannel chan interface{}
102 if handler.client != nil {
103 abortChannel = make(chan interface{})
104 go func() {
105 pollTicker := time.NewTicker(ABORT_POLLING_INTERVAL)
106 for {
107 select {
108 case <-pollTicker.C:
109 if handler.client.ShouldAbort() {
110 close(abortChannel)
111 pollTicker.Stop()
112 return
113 }
114 case <-handler.requestAbortCheck:
115 if handler.client.ShouldAbort() {
116 close(abortChannel)
117 pollTicker.Stop()
118 return
119 }
120 case <-handler.stop:
121 pollTicker.Stop()
122 return
123 }
124 }
125 }()
126 }
127
128 go func(abortChannel chan interface{}) {
129 var interruptCause InterruptCause
130 for {
131 select {
132 case <-signalChannel:
133 interruptCause = InterruptCauseSignal
134 case <-abortChannel:
135 interruptCause = InterruptCauseAbortByOtherProcess
136 case <-handler.stop:
137 signal.Stop(signalChannel)
138 return
139 }
140 abortChannel = nil
141
142 handler.lock.Lock()
143 oldLevel := handler.level
144 handler.cause = interruptCause
145 if handler.level == InterruptLevelUninterrupted {
146 handler.level = InterruptLevelCleanupAndReport
147 } else if handler.level == InterruptLevelCleanupAndReport {
148 handler.level = InterruptLevelReportOnly
149 } else if handler.level == InterruptLevelReportOnly {
150 handler.level = InterruptLevelBailOut
151 }
152 if handler.level != oldLevel {
153 close(handler.c)
154 handler.c = make(chan interface{})
155 }
156 handler.lock.Unlock()
157 }
158 }(abortChannel)
159 }
160
161 func (handler *InterruptHandler) Status() InterruptStatus {
162 handler.lock.Lock()
163 status := InterruptStatus{
164 Level: handler.level,
165 Channel: handler.c,
166 Cause: handler.cause,
167 }
168 handler.lock.Unlock()
169
170 if handler.client != nil && handler.client.ShouldAbort() && !status.Interrupted() {
171 close(handler.requestAbortCheck)
172 <-status.Channel
173 return handler.Status()
174 }
175
176 return status
177 }
178
View as plain text