...

Source file src/go.mongodb.org/mongo-driver/examples/_example_customdns_test.go

Documentation: go.mongodb.org/mongo-driver/examples

     1  // Copyright (C) MongoDB, Inc. 2022-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package examples
     8  
     9  import (
    10  	"context"
    11  	"log"
    12  	"net"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/miekg/dns"
    18  	"go.mongodb.org/mongo-driver/bson"
    19  	"go.mongodb.org/mongo-driver/mongo"
    20  	"go.mongodb.org/mongo-driver/mongo/options"
    21  )
    22  
    23  func resolve(ctx context.Context, cache *dnsCache, in *dns.Conn, out *dns.Conn) {
    24  	for ctx.Err() == nil {
    25  		q, err := in.ReadMsg()
    26  		if err != nil {
    27  			// TODO: Handle error.
    28  			log.Fatalf("Unhandled error in ReadMsg: %v", err)
    29  		}
    30  		if len(q.Question) != 1 {
    31  			// Multiple questions in a single query is not actually used in real life.
    32  			continue
    33  		}
    34  
    35  		a, err := func() (*dns.Msg, error) {
    36  			cache.lock.Lock()
    37  			defer cache.lock.Unlock()
    38  
    39  			now := time.Now()
    40  			if rr, ok := cache.records[q.Question[0]]; ok && rr.exp.After(now) {
    41  				a := new(dns.Msg)
    42  				a.SetReply(q)
    43  				a.Compress = false
    44  				a.Answer = append(a.Answer, rr.record)
    45  				return a, nil
    46  			}
    47  
    48  			err := out.WriteMsg(q)
    49  			if err != nil {
    50  				return nil, err
    51  			}
    52  
    53  			m, err := out.ReadMsg()
    54  			if err != nil {
    55  				return nil, err
    56  			}
    57  
    58  			l := len(m.Answer)
    59  			for i, q := range m.Question {
    60  				if i >= l {
    61  					break
    62  				}
    63  				a := m.Answer[i]
    64  				cache.records[q] = &RR{
    65  					a,
    66  					now.Add(time.Second * time.Duration(a.Header().Ttl)),
    67  				}
    68  			}
    69  			return m, nil
    70  		}()
    71  		if err != nil {
    72  			// TODO: Handle error.
    73  			log.Fatalf("Unhandled error in record retrieval: %v", err)
    74  		}
    75  
    76  		if err := in.WriteMsg(a); err != nil {
    77  			// TODO: Handle error.
    78  			log.Fatalf("Unhandled error in WriteMsg: %v", err)
    79  		}
    80  	}
    81  }
    82  
    83  type RR struct {
    84  	record dns.RR
    85  	exp    time.Time
    86  }
    87  
    88  type dnsCache struct {
    89  	records map[dns.Question]*RR
    90  	lock    sync.Mutex
    91  }
    92  
    93  type dialer struct {
    94  	*net.Dialer
    95  	cache *dnsCache
    96  }
    97  
    98  func NewDialer() dialer {
    99  	cache := &dnsCache{
   100  		records: make(map[dns.Question]*RR),
   101  		lock:    sync.Mutex{},
   102  	}
   103  	return dialer{
   104  		Dialer: &net.Dialer{
   105  			Resolver: &net.Resolver{
   106  				PreferGo: true,
   107  				Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
   108  					var d net.Dialer
   109  					outConn, err := d.DialContext(ctx, network, address)
   110  					conn, inConn := net.Pipe()
   111  					if err == nil {
   112  						go resolve(ctx, cache, &dns.Conn{Conn: inConn}, &dns.Conn{Conn: outConn})
   113  					}
   114  					return conn, err
   115  				},
   116  			},
   117  		},
   118  		cache: cache,
   119  	}
   120  }
   121  
   122  func TestCustomDialer(t *testing.T) {
   123  	client, err := mongo.NewClient(options.Client().ApplyURI("mongodb://testurl:27017").SetDialer(NewDialer()))
   124  	if err != nil {
   125  		t.Fatalf("error creating client: %v", err)
   126  	}
   127  	ctx := context.Background()
   128  	err = client.Connect(ctx)
   129  	if err != nil {
   130  		t.Fatalf("error connecting: %v", err)
   131  	}
   132  	defer client.Disconnect(context.Background())
   133  	coll := client.Database("test").Collection("test")
   134  	_, err = coll.InsertOne(context.Background(), bson.D{{"text", "text"}})
   135  	if err != nil {
   136  		t.Fatalf("error inserting: %v", err)
   137  	}
   138  }
   139  

View as plain text