...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/drivertest/channel_conn.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/drivertest

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package drivertest
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  
    13  	"go.mongodb.org/mongo-driver/mongo/address"
    14  	"go.mongodb.org/mongo-driver/mongo/description"
    15  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    16  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    17  )
    18  
    19  // ChannelConn implements the driver.Connection interface by reading and writing wire messages
    20  // to a channel
    21  type ChannelConn struct {
    22  	WriteErr error
    23  	Written  chan []byte
    24  	ReadResp chan []byte
    25  	ReadErr  chan error
    26  	Desc     description.Server
    27  }
    28  
    29  // WriteWireMessage implements the driver.Connection interface.
    30  func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error {
    31  	// Copy wm in case it came from a buffer pool.
    32  	b := make([]byte, len(wm))
    33  	copy(b, wm)
    34  	select {
    35  	case c.Written <- b:
    36  	case <-ctx.Done():
    37  		return ctx.Err()
    38  	default:
    39  		c.WriteErr = errors.New("could not write wiremessage to written channel")
    40  	}
    41  	return c.WriteErr
    42  }
    43  
    44  // ReadWireMessage implements the driver.Connection interface.
    45  func (c *ChannelConn) ReadWireMessage(ctx context.Context) ([]byte, error) {
    46  	var wm []byte
    47  	var err error
    48  	select {
    49  	case wm = <-c.ReadResp:
    50  	case err = <-c.ReadErr:
    51  	case <-ctx.Done():
    52  		err = ctx.Err()
    53  	}
    54  	return wm, err
    55  }
    56  
    57  // Description implements the driver.Connection interface.
    58  func (c *ChannelConn) Description() description.Server { return c.Desc }
    59  
    60  // Close implements the driver.Connection interface.
    61  func (c *ChannelConn) Close() error {
    62  	return nil
    63  }
    64  
    65  // ID implements the driver.Connection interface.
    66  func (c *ChannelConn) ID() string {
    67  	return "faked"
    68  }
    69  
    70  // DriverConnectionID implements the driver.Connection interface.
    71  // TODO(GODRIVER-2824): replace return type with int64.
    72  func (c *ChannelConn) DriverConnectionID() uint64 {
    73  	return 0
    74  }
    75  
    76  // ServerConnectionID implements the driver.Connection interface.
    77  func (c *ChannelConn) ServerConnectionID() *int64 {
    78  	serverConnectionID := int64(42)
    79  	return &serverConnectionID
    80  }
    81  
    82  // Address implements the driver.Connection interface.
    83  func (c *ChannelConn) Address() address.Address { return address.Address("0.0.0.0") }
    84  
    85  // Stale implements the driver.Connection interface.
    86  func (c *ChannelConn) Stale() bool {
    87  	return false
    88  }
    89  
    90  // MakeReply creates an OP_REPLY wiremessage from a BSON document
    91  func MakeReply(doc bsoncore.Document) []byte {
    92  	var dst []byte
    93  	idx, dst := wiremessage.AppendHeaderStart(dst, 10, 9, wiremessage.OpReply)
    94  	dst = wiremessage.AppendReplyFlags(dst, 0)
    95  	dst = wiremessage.AppendReplyCursorID(dst, 0)
    96  	dst = wiremessage.AppendReplyStartingFrom(dst, 0)
    97  	dst = wiremessage.AppendReplyNumberReturned(dst, 1)
    98  	dst = append(dst, doc...)
    99  	return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
   100  }
   101  
   102  // GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message.
   103  func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) {
   104  	var ok bool
   105  	_, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
   106  	if !ok {
   107  		return nil, errors.New("could not read header")
   108  	}
   109  	_, wm, ok = wiremessage.ReadQueryFlags(wm)
   110  	if !ok {
   111  		return nil, errors.New("could not read flags")
   112  	}
   113  	_, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
   114  	if !ok {
   115  		return nil, errors.New("could not read fullCollectionName")
   116  	}
   117  	_, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
   118  	if !ok {
   119  		return nil, errors.New("could not read numberToSkip")
   120  	}
   121  	_, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
   122  	if !ok {
   123  		return nil, errors.New("could not read numberToReturn")
   124  	}
   125  
   126  	var query bsoncore.Document
   127  	query, wm, ok = wiremessage.ReadQueryQuery(wm)
   128  	if !ok {
   129  		return nil, errors.New("could not read query")
   130  	}
   131  	return query, nil
   132  }
   133  
   134  // GetCommandFromMsgWireMessage returns the command document sent in an OP_MSG wire message.
   135  func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) {
   136  	var ok bool
   137  	_, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
   138  	if !ok {
   139  		return nil, errors.New("could not read header")
   140  	}
   141  
   142  	_, wm, ok = wiremessage.ReadMsgFlags(wm)
   143  	if !ok {
   144  		return nil, errors.New("could not read flags")
   145  	}
   146  	_, wm, ok = wiremessage.ReadMsgSectionType(wm)
   147  	if !ok {
   148  		return nil, errors.New("could not read section type")
   149  	}
   150  
   151  	cmdDoc, wm, ok := wiremessage.ReadMsgSectionSingleDocument(wm)
   152  	if !ok {
   153  		return nil, errors.New("could not read command document")
   154  	}
   155  	return cmdDoc, nil
   156  }
   157  

View as plain text