...

Source file src/golang.org/x/net/http2/sync_test.go

Documentation: golang.org/x/net/http2

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http2
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"runtime"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  )
    17  
// A synctestGroup synchronizes between a set of cooperating goroutines.
//
// It records which goroutines belong to the group and owns a synthetic
// clock together with the fake timers scheduled against that clock.
type synctestGroup struct {
	mu     sync.Mutex              // locked by most methods before touching gids, now, or timers
	gids   map[int]bool            // IDs of goroutines registered via newSynctest or Join
	now    time.Time               // current synthetic time
	timers map[*fakeTimer]struct{} // scheduled timers that have not fired yet; allocated lazily by addTimer
}
    25  
// A goroutine describes one goroutine as parsed from a runtime.Stack dump.
type goroutine struct {
	id     int    // ID from the "goroutine N [" header
	parent int    // ID from the "created by ... in goroutine N" line; 0 when that suffix is absent
	state  string // text between "[" and "]" in the header
}
    31  
    32  // newSynctest creates a new group with the synthetic clock set the provided time.
    33  func newSynctest(now time.Time) *synctestGroup {
    34  	return &synctestGroup{
    35  		gids: map[int]bool{
    36  			currentGoroutine(): true,
    37  		},
    38  		now: now,
    39  	}
    40  }
    41  
    42  // Join adds the current goroutine to the group.
    43  func (g *synctestGroup) Join() {
    44  	g.mu.Lock()
    45  	defer g.mu.Unlock()
    46  	g.gids[currentGoroutine()] = true
    47  }
    48  
    49  // Count returns the number of goroutines in the group.
    50  func (g *synctestGroup) Count() int {
    51  	gs := stacks(true)
    52  	count := 0
    53  	for _, gr := range gs {
    54  		if !g.gids[gr.id] && !g.gids[gr.parent] {
    55  			continue
    56  		}
    57  		count++
    58  	}
    59  	return count
    60  }
    61  
    62  // Close calls t.Fatal if the group contains any running goroutines.
    63  func (g *synctestGroup) Close(t testing.TB) {
    64  	if count := g.Count(); count != 1 {
    65  		buf := make([]byte, 16*1024)
    66  		n := runtime.Stack(buf, true)
    67  		t.Logf("stacks:\n%s", buf[:n])
    68  		t.Fatalf("%v goroutines still running after test completed, expect 1", count)
    69  	}
    70  }
    71  
    72  // Wait blocks until every goroutine in the group and their direct children are idle.
    73  func (g *synctestGroup) Wait() {
    74  	for i := 0; ; i++ {
    75  		if g.idle() {
    76  			return
    77  		}
    78  		runtime.Gosched()
    79  	}
    80  }
    81  
    82  func (g *synctestGroup) idle() bool {
    83  	gs := stacks(true)
    84  	g.mu.Lock()
    85  	defer g.mu.Unlock()
    86  	for _, gr := range gs[1:] {
    87  		if !g.gids[gr.id] && !g.gids[gr.parent] {
    88  			continue
    89  		}
    90  		// From runtime/runtime2.go.
    91  		switch gr.state {
    92  		case "IO wait":
    93  		case "chan receive (nil chan)":
    94  		case "chan send (nil chan)":
    95  		case "select":
    96  		case "select (no cases)":
    97  		case "chan receive":
    98  		case "chan send":
    99  		case "sync.Cond.Wait":
   100  		case "sync.Mutex.Lock":
   101  		case "sync.RWMutex.RLock":
   102  		case "sync.RWMutex.Lock":
   103  		default:
   104  			return false
   105  		}
   106  	}
   107  	return true
   108  }
   109  
   110  func currentGoroutine() int {
   111  	s := stacks(false)
   112  	return s[0].id
   113  }
   114  
   115  func stacks(all bool) []goroutine {
   116  	buf := make([]byte, 16*1024)
   117  	for {
   118  		n := runtime.Stack(buf, all)
   119  		if n < len(buf) {
   120  			buf = buf[:n]
   121  			break
   122  		}
   123  		buf = make([]byte, len(buf)*2)
   124  	}
   125  
   126  	var goroutines []goroutine
   127  	for _, gs := range strings.Split(string(buf), "\n\n") {
   128  		skip, rest, ok := strings.Cut(gs, "goroutine ")
   129  		if skip != "" || !ok {
   130  			panic(fmt.Errorf("1 unparsable goroutine stack:\n%s", gs))
   131  		}
   132  		ids, rest, ok := strings.Cut(rest, " [")
   133  		if !ok {
   134  			panic(fmt.Errorf("2 unparsable goroutine stack:\n%s", gs))
   135  		}
   136  		id, err := strconv.Atoi(ids)
   137  		if err != nil {
   138  			panic(fmt.Errorf("3 unparsable goroutine stack:\n%s", gs))
   139  		}
   140  		state, rest, ok := strings.Cut(rest, "]")
   141  		var parent int
   142  		_, rest, ok = strings.Cut(rest, "\ncreated by ")
   143  		if ok && strings.Contains(rest, " in goroutine ") {
   144  			_, rest, ok := strings.Cut(rest, " in goroutine ")
   145  			if !ok {
   146  				panic(fmt.Errorf("4 unparsable goroutine stack:\n%s", gs))
   147  			}
   148  			parents, rest, ok := strings.Cut(rest, "\n")
   149  			if !ok {
   150  				panic(fmt.Errorf("5 unparsable goroutine stack:\n%s", gs))
   151  			}
   152  			parent, err = strconv.Atoi(parents)
   153  			if err != nil {
   154  				panic(fmt.Errorf("6 unparsable goroutine stack:\n%s", gs))
   155  			}
   156  		}
   157  		goroutines = append(goroutines, goroutine{
   158  			id:     id,
   159  			parent: parent,
   160  			state:  state,
   161  		})
   162  	}
   163  	return goroutines
   164  }
   165  
   166  // AdvanceTime advances the synthetic clock by d.
   167  func (g *synctestGroup) AdvanceTime(d time.Duration) {
   168  	defer g.Wait()
   169  	g.mu.Lock()
   170  	defer g.mu.Unlock()
   171  	g.now = g.now.Add(d)
   172  	for tm := range g.timers {
   173  		if tm.when.After(g.now) {
   174  			continue
   175  		}
   176  		tm.run()
   177  		delete(g.timers, tm)
   178  	}
   179  }
   180  
   181  // Now returns the current synthetic time.
   182  func (g *synctestGroup) Now() time.Time {
   183  	g.mu.Lock()
   184  	defer g.mu.Unlock()
   185  	return g.now
   186  }
   187  
   188  // TimeUntilEvent returns the amount of time until the next scheduled timer.
   189  func (g *synctestGroup) TimeUntilEvent() (d time.Duration, scheduled bool) {
   190  	g.mu.Lock()
   191  	defer g.mu.Unlock()
   192  	for tm := range g.timers {
   193  		if dd := tm.when.Sub(g.now); !scheduled || dd < d {
   194  			d = dd
   195  			scheduled = true
   196  		}
   197  	}
   198  	return d, scheduled
   199  }
   200  
   201  // Sleep is time.Sleep, but using synthetic time.
   202  func (g *synctestGroup) Sleep(d time.Duration) {
   203  	tm := g.NewTimer(d)
   204  	<-tm.C()
   205  }
   206  
   207  // NewTimer is time.NewTimer, but using synthetic time.
   208  func (g *synctestGroup) NewTimer(d time.Duration) Timer {
   209  	return g.addTimer(d, &fakeTimer{
   210  		ch: make(chan time.Time),
   211  	})
   212  }
   213  
   214  // AfterFunc is time.AfterFunc, but using synthetic time.
   215  func (g *synctestGroup) AfterFunc(d time.Duration, f func()) Timer {
   216  	return g.addTimer(d, &fakeTimer{
   217  		f: f,
   218  	})
   219  }
   220  
   221  // ContextWithTimeout is context.WithTimeout, but using synthetic time.
   222  func (g *synctestGroup) ContextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
   223  	ctx, cancel := context.WithCancel(ctx)
   224  	tm := g.AfterFunc(d, cancel)
   225  	return ctx, func() {
   226  		tm.Stop()
   227  		cancel()
   228  	}
   229  }
   230  
   231  func (g *synctestGroup) addTimer(d time.Duration, tm *fakeTimer) *fakeTimer {
   232  	g.mu.Lock()
   233  	defer g.mu.Unlock()
   234  	tm.g = g
   235  	tm.when = g.now.Add(d)
   236  	if g.timers == nil {
   237  		g.timers = make(map[*fakeTimer]struct{})
   238  	}
   239  	if tm.when.After(g.now) {
   240  		g.timers[tm] = struct{}{}
   241  	} else {
   242  		tm.run()
   243  	}
   244  	return tm
   245  }
   246  
// Timer is the set of timer operations used with synthetic time.
// In this file it is implemented by *fakeTimer.
type Timer = interface {
	// C returns the channel on which the timer delivers the time at which
	// it fired.
	C() <-chan time.Time
	// Reset reschedules the timer to fire after d and reports whether the
	// timer was pending before the call.
	Reset(d time.Duration) bool
	// Stop cancels the timer and reports whether it was pending before
	// the call.
	Stop() bool
}
   252  
// A fakeTimer is a Timer driven by a synctestGroup's synthetic clock.
// Exactly one of ch and f is set by the constructors in this file.
type fakeTimer struct {
	g    *synctestGroup // owning group; set by addTimer
	when time.Time      // synthetic time at which the timer is due
	ch   chan time.Time // set by NewTimer; receives the group's time when the timer fires
	f    func()         // set by AfterFunc; called in a new goroutine when the timer fires
}
   259  
   260  func (tm *fakeTimer) run() {
   261  	if tm.ch != nil {
   262  		tm.ch <- tm.g.now
   263  	} else {
   264  		go func() {
   265  			tm.g.Join()
   266  			tm.f()
   267  		}()
   268  	}
   269  }
   270  
   271  func (tm *fakeTimer) C() <-chan time.Time { return tm.ch }
   272  
   273  func (tm *fakeTimer) Reset(d time.Duration) bool {
   274  	tm.g.mu.Lock()
   275  	defer tm.g.mu.Unlock()
   276  	_, stopped := tm.g.timers[tm]
   277  	if d <= 0 {
   278  		delete(tm.g.timers, tm)
   279  		tm.run()
   280  	} else {
   281  		tm.when = tm.g.now.Add(d)
   282  		tm.g.timers[tm] = struct{}{}
   283  	}
   284  	return stopped
   285  }
   286  
   287  func (tm *fakeTimer) Stop() bool {
   288  	tm.g.mu.Lock()
   289  	defer tm.g.mu.Unlock()
   290  	_, stopped := tm.g.timers[tm]
   291  	delete(tm.g.timers, tm)
   292  	return stopped
   293  }
   294  

View as plain text