...
1
2
3
4
5
6
7 package examples
8
9 import (
10 "context"
11 "log"
12 "net"
13 "sync"
14 "testing"
15 "time"
16
17 "github.com/miekg/dns"
18 "go.mongodb.org/mongo-driver/bson"
19 "go.mongodb.org/mongo-driver/mongo"
20 "go.mongodb.org/mongo-driver/mongo/options"
21 )
22
23 func resolve(ctx context.Context, cache *dnsCache, in *dns.Conn, out *dns.Conn) {
24 for ctx.Err() == nil {
25 q, err := in.ReadMsg()
26 if err != nil {
27
28 log.Fatalf("Unhandled error in ReadMsg: %v", err)
29 }
30 if len(q.Question) != 1 {
31
32 continue
33 }
34
35 a, err := func() (*dns.Msg, error) {
36 cache.lock.Lock()
37 defer cache.lock.Unlock()
38
39 now := time.Now()
40 if rr, ok := cache.records[q.Question[0]]; ok && rr.exp.After(now) {
41 a := new(dns.Msg)
42 a.SetReply(q)
43 a.Compress = false
44 a.Answer = append(a.Answer, rr.record)
45 return a, nil
46 }
47
48 err := out.WriteMsg(q)
49 if err != nil {
50 return nil, err
51 }
52
53 m, err := out.ReadMsg()
54 if err != nil {
55 return nil, err
56 }
57
58 l := len(m.Answer)
59 for i, q := range m.Question {
60 if i >= l {
61 break
62 }
63 a := m.Answer[i]
64 cache.records[q] = &RR{
65 a,
66 now.Add(time.Second * time.Duration(a.Header().Ttl)),
67 }
68 }
69 return m, nil
70 }()
71 if err != nil {
72
73 log.Fatalf("Unhandled error in record retrieval: %v", err)
74 }
75
76 if err := in.WriteMsg(a); err != nil {
77
78 log.Fatalf("Unhandled error in WriteMsg: %v", err)
79 }
80 }
81 }
82
83 type RR struct {
84 record dns.RR
85 exp time.Time
86 }
87
88 type dnsCache struct {
89 records map[dns.Question]*RR
90 lock sync.Mutex
91 }
92
93 type dialer struct {
94 *net.Dialer
95 cache *dnsCache
96 }
97
98 func NewDialer() dialer {
99 cache := &dnsCache{
100 records: make(map[dns.Question]*RR),
101 lock: sync.Mutex{},
102 }
103 return dialer{
104 Dialer: &net.Dialer{
105 Resolver: &net.Resolver{
106 PreferGo: true,
107 Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
108 var d net.Dialer
109 outConn, err := d.DialContext(ctx, network, address)
110 conn, inConn := net.Pipe()
111 if err == nil {
112 go resolve(ctx, cache, &dns.Conn{Conn: inConn}, &dns.Conn{Conn: outConn})
113 }
114 return conn, err
115 },
116 },
117 },
118 cache: cache,
119 }
120 }
121
122 func TestCustomDialer(t *testing.T) {
123 client, err := mongo.NewClient(options.Client().ApplyURI("mongodb://testurl:27017").SetDialer(NewDialer()))
124 if err != nil {
125 t.Fatalf("error creating client: %v", err)
126 }
127 ctx := context.Background()
128 err = client.Connect(ctx)
129 if err != nil {
130 t.Fatalf("error connecting: %v", err)
131 }
132 defer client.Disconnect(context.Background())
133 coll := client.Database("test").Collection("test")
134 _, err = coll.InsertOne(context.Background(), bson.D{{"text", "text"}})
135 if err != nil {
136 t.Fatalf("error inserting: %v", err)
137 }
138 }
139
View as plain text