...

Source file src/nhooyr.io/websocket/netconn.go

Documentation: nhooyr.io/websocket

     1  package websocket
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  	"net"
     9  	"sync/atomic"
    10  	"time"
    11  )
    12  
    13  // NetConn converts a *websocket.Conn into a net.Conn.
    14  //
    15  // It's for tunneling arbitrary protocols over WebSockets.
    16  // Few users of the library will need this but it's tricky to implement
    17  // correctly and so provided in the library.
    18  // See https://github.com/nhooyr/websocket/issues/100.
    19  //
    20  // Every Write to the net.Conn will correspond to a message write of
    21  // the given type on *websocket.Conn.
    22  //
    23  // The passed ctx bounds the lifetime of the net.Conn. If cancelled,
    24  // all reads and writes on the net.Conn will be cancelled.
    25  //
    26  // If a message is read that is not of the correct type, the connection
    27  // will be closed with StatusUnsupportedData and an error will be returned.
    28  //
    29  // Close will close the *websocket.Conn with StatusNormalClosure.
    30  //
    31  // When a deadline is hit and there is an active read or write goroutine, the
    32  // connection will be closed. This is different from most net.Conn implementations
    33  // where only the reading/writing goroutines are interrupted but the connection
    34  // is kept alive.
    35  //
    36  // The Addr methods will return the real addresses for connections obtained
    37  // from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr
    38  // will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for
    39  // String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the
    40  // full net.Conn to us.
    41  //
    42  // When running as WASM, the Addr methods will always return the mock address described above.
    43  //
    44  // A received StatusNormalClosure or StatusGoingAway close frame will be translated to
    45  // io.EOF when reading.
    46  //
    47  // Furthermore, the ReadLimit is set to -1 to disable it.
    48  func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
    49  	c.SetReadLimit(-1)
    50  
    51  	nc := &netConn{
    52  		c:       c,
    53  		msgType: msgType,
    54  		readMu:  newMu(c),
    55  		writeMu: newMu(c),
    56  	}
    57  
    58  	nc.writeCtx, nc.writeCancel = context.WithCancel(ctx)
    59  	nc.readCtx, nc.readCancel = context.WithCancel(ctx)
    60  
    61  	nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
    62  		if !nc.writeMu.tryLock() {
    63  			// If the lock cannot be acquired, then there is an
    64  			// active write goroutine and so we should cancel the context.
    65  			nc.writeCancel()
    66  			return
    67  		}
    68  		defer nc.writeMu.unlock()
    69  
    70  		// Prevents future writes from writing until the deadline is reset.
    71  		atomic.StoreInt64(&nc.writeExpired, 1)
    72  	})
    73  	if !nc.writeTimer.Stop() {
    74  		<-nc.writeTimer.C
    75  	}
    76  
    77  	nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
    78  		if !nc.readMu.tryLock() {
    79  			// If the lock cannot be acquired, then there is an
    80  			// active read goroutine and so we should cancel the context.
    81  			nc.readCancel()
    82  			return
    83  		}
    84  		defer nc.readMu.unlock()
    85  
    86  		// Prevents future reads from reading until the deadline is reset.
    87  		atomic.StoreInt64(&nc.readExpired, 1)
    88  	})
    89  	if !nc.readTimer.Stop() {
    90  		<-nc.readTimer.C
    91  	}
    92  
    93  	return nc
    94  }
    95  
    96  type netConn struct {
    97  	c       *Conn
    98  	msgType MessageType
    99  
   100  	writeTimer   *time.Timer
   101  	writeMu      *mu
   102  	writeExpired int64
   103  	writeCtx     context.Context
   104  	writeCancel  context.CancelFunc
   105  
   106  	readTimer   *time.Timer
   107  	readMu      *mu
   108  	readExpired int64
   109  	readCtx     context.Context
   110  	readCancel  context.CancelFunc
   111  	readEOFed   bool
   112  	reader      io.Reader
   113  }
   114  
   115  var _ net.Conn = &netConn{}
   116  
   117  func (nc *netConn) Close() error {
   118  	nc.writeTimer.Stop()
   119  	nc.writeCancel()
   120  	nc.readTimer.Stop()
   121  	nc.readCancel()
   122  	return nc.c.Close(StatusNormalClosure, "")
   123  }
   124  
   125  func (nc *netConn) Write(p []byte) (int, error) {
   126  	nc.writeMu.forceLock()
   127  	defer nc.writeMu.unlock()
   128  
   129  	if atomic.LoadInt64(&nc.writeExpired) == 1 {
   130  		return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
   131  	}
   132  
   133  	err := nc.c.Write(nc.writeCtx, nc.msgType, p)
   134  	if err != nil {
   135  		return 0, err
   136  	}
   137  	return len(p), nil
   138  }
   139  
   140  func (nc *netConn) Read(p []byte) (int, error) {
   141  	nc.readMu.forceLock()
   142  	defer nc.readMu.unlock()
   143  
   144  	for {
   145  		n, err := nc.read(p)
   146  		if err != nil {
   147  			return n, err
   148  		}
   149  		if n == 0 {
   150  			continue
   151  		}
   152  		return n, nil
   153  	}
   154  }
   155  
   156  func (nc *netConn) read(p []byte) (int, error) {
   157  	if atomic.LoadInt64(&nc.readExpired) == 1 {
   158  		return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
   159  	}
   160  
   161  	if nc.readEOFed {
   162  		return 0, io.EOF
   163  	}
   164  
   165  	if nc.reader == nil {
   166  		typ, r, err := nc.c.Reader(nc.readCtx)
   167  		if err != nil {
   168  			switch CloseStatus(err) {
   169  			case StatusNormalClosure, StatusGoingAway:
   170  				nc.readEOFed = true
   171  				return 0, io.EOF
   172  			}
   173  			return 0, err
   174  		}
   175  		if typ != nc.msgType {
   176  			err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
   177  			nc.c.Close(StatusUnsupportedData, err.Error())
   178  			return 0, err
   179  		}
   180  		nc.reader = r
   181  	}
   182  
   183  	n, err := nc.reader.Read(p)
   184  	if err == io.EOF {
   185  		nc.reader = nil
   186  		err = nil
   187  	}
   188  	return n, err
   189  }
   190  
   191  type websocketAddr struct {
   192  }
   193  
   194  func (a websocketAddr) Network() string {
   195  	return "websocket"
   196  }
   197  
   198  func (a websocketAddr) String() string {
   199  	return "websocket/unknown-addr"
   200  }
   201  
   202  func (nc *netConn) SetDeadline(t time.Time) error {
   203  	nc.SetWriteDeadline(t)
   204  	nc.SetReadDeadline(t)
   205  	return nil
   206  }
   207  
   208  func (nc *netConn) SetWriteDeadline(t time.Time) error {
   209  	atomic.StoreInt64(&nc.writeExpired, 0)
   210  	if t.IsZero() {
   211  		nc.writeTimer.Stop()
   212  	} else {
   213  		dur := time.Until(t)
   214  		if dur <= 0 {
   215  			dur = 1
   216  		}
   217  		nc.writeTimer.Reset(dur)
   218  	}
   219  	return nil
   220  }
   221  
   222  func (nc *netConn) SetReadDeadline(t time.Time) error {
   223  	atomic.StoreInt64(&nc.readExpired, 0)
   224  	if t.IsZero() {
   225  		nc.readTimer.Stop()
   226  	} else {
   227  		dur := time.Until(t)
   228  		if dur <= 0 {
   229  			dur = 1
   230  		}
   231  		nc.readTimer.Reset(dur)
   232  	}
   233  	return nil
   234  }
   235  

View as plain text