...

Source file src/k8s.io/client-go/util/connrotation/connrotation.go

Documentation: k8s.io/client-go/util/connrotation

     1  /*
     2  Copyright 2018 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package connrotation implements a connection dialer that tracks and can close
    18  // all created connections.
    19  //
    20  // This is used for credential rotation of long-lived connections, when there's
    21  // no way to re-authenticate on a live connection.
    22  package connrotation
    23  
    24  import (
    25  	"context"
    26  	"net"
    27  	"sync"
    28  )
    29  
    30  // DialFunc is a shorthand for signature of net.DialContext.
    31  type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
    32  
    33  // Dialer opens connections through Dial and tracks them.
    34  type Dialer struct {
    35  	dial DialFunc
    36  	*ConnectionTracker
    37  }
    38  
    39  // NewDialer creates a new Dialer instance.
    40  // Equivalent to NewDialerWithTracker(dial, nil).
    41  func NewDialer(dial DialFunc) *Dialer {
    42  	return NewDialerWithTracker(dial, nil)
    43  }
    44  
    45  // NewDialerWithTracker creates a new Dialer instance.
    46  //
    47  // If dial is not nil, it will be used to create new underlying connections.
    48  // Otherwise net.DialContext is used.
    49  // If tracker is not nil, it will be used to track new underlying connections.
    50  // Otherwise NewConnectionTracker() is used.
    51  func NewDialerWithTracker(dial DialFunc, tracker *ConnectionTracker) *Dialer {
    52  	if tracker == nil {
    53  		tracker = NewConnectionTracker()
    54  	}
    55  	return &Dialer{
    56  		dial:              dial,
    57  		ConnectionTracker: tracker,
    58  	}
    59  }
    60  
    61  // ConnectionTracker keeps track of opened connections
    62  type ConnectionTracker struct {
    63  	mu    sync.Mutex
    64  	conns map[*closableConn]struct{}
    65  }
    66  
    67  // NewConnectionTracker returns a connection tracker for use with NewDialerWithTracker
    68  func NewConnectionTracker() *ConnectionTracker {
    69  	return &ConnectionTracker{
    70  		conns: make(map[*closableConn]struct{}),
    71  	}
    72  }
    73  
    74  // CloseAll forcibly closes all tracked connections.
    75  //
    76  // Note: new connections may get created before CloseAll returns.
    77  func (c *ConnectionTracker) CloseAll() {
    78  	c.mu.Lock()
    79  	conns := c.conns
    80  	c.conns = make(map[*closableConn]struct{})
    81  	c.mu.Unlock()
    82  
    83  	for conn := range conns {
    84  		conn.Close()
    85  	}
    86  }
    87  
    88  // Track adds the connection to the list of tracked connections,
    89  // and returns a wrapped copy of the connection that stops tracking the connection
    90  // when it is closed.
    91  func (c *ConnectionTracker) Track(conn net.Conn) net.Conn {
    92  	closable := &closableConn{Conn: conn}
    93  
    94  	// When the connection is closed, remove it from the map. This will
    95  	// be no-op if the connection isn't in the map, e.g. if CloseAll()
    96  	// is called.
    97  	closable.onClose = func() {
    98  		c.mu.Lock()
    99  		delete(c.conns, closable)
   100  		c.mu.Unlock()
   101  	}
   102  
   103  	// Start tracking the connection
   104  	c.mu.Lock()
   105  	c.conns[closable] = struct{}{}
   106  	c.mu.Unlock()
   107  
   108  	return closable
   109  }
   110  
   111  // Dial creates a new tracked connection.
   112  func (d *Dialer) Dial(network, address string) (net.Conn, error) {
   113  	return d.DialContext(context.Background(), network, address)
   114  }
   115  
   116  // DialContext creates a new tracked connection.
   117  func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
   118  	conn, err := d.dial(ctx, network, address)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	return d.ConnectionTracker.Track(conn), nil
   123  }
   124  
   125  type closableConn struct {
   126  	onClose func()
   127  	net.Conn
   128  }
   129  
   130  func (c *closableConn) Close() error {
   131  	go c.onClose()
   132  	return c.Conn.Close()
   133  }
   134  

View as plain text