1 package dns
2
3 import (
4 "fmt"
5 "time"
6 )
7
8
// Envelope carries one step of a zone transfer over a channel: the RRs of
// one transfer message, an error, or both. The transfer goroutines started
// by In always send an Envelope with a non-nil Error as their last value
// before closing the channel.
type Envelope struct {
	RR    []RR  // RRs of one transfer message (the answer section, for In)
	Error error // non-nil if something went wrong; RR may still hold the offending message's answer
}
13
14
// A Transfer holds the connection and parameters used for a zone transfer.
// The embedded *Conn is the connection the transfer runs over; In dials a
// TCP connection when it is nil.
type Transfer struct {
	*Conn
	DialTimeout    time.Duration     // timeout used by In when it dials; zero means dnsTimeout
	ReadTimeout    time.Duration     // read deadline set before each response message is read; zero means dnsTimeout
	WriteTimeout   time.Duration     // write timeout for the connection; NOTE(review): confirm which code paths apply it
	TsigProvider   TsigProvider      // if non-nil, used for all TSIG operations and takes precedence over TsigSecret
	TsigSecret     map[string]string // TSIG secrets, used through tsigSecretProvider when TsigProvider is nil
	tsigTimersOnly bool              // passed to TSIG generate/verify; set to true once the first response message is read
}
24
25 func (t *Transfer) tsigProvider() TsigProvider {
26 if t.TsigProvider != nil {
27 return t.TsigProvider
28 }
29 if t.TsigSecret != nil {
30 return tsigSecretProvider(t.TsigSecret)
31 }
32 return nil
33 }
34
35
36
37
38
39
40
41
42
43
44
45
46
47 func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) {
48 switch q.Question[0].Qtype {
49 case TypeAXFR, TypeIXFR:
50 default:
51 return nil, &Error{"unsupported question type"}
52 }
53
54 timeout := dnsTimeout
55 if t.DialTimeout != 0 {
56 timeout = t.DialTimeout
57 }
58
59 if t.Conn == nil {
60 t.Conn, err = DialTimeout("tcp", a, timeout)
61 if err != nil {
62 return nil, err
63 }
64 }
65
66 if err := t.WriteMsg(q); err != nil {
67 return nil, err
68 }
69
70 env = make(chan *Envelope)
71 switch q.Question[0].Qtype {
72 case TypeAXFR:
73 go t.inAxfr(q, env)
74 case TypeIXFR:
75 go t.inIxfr(q, env)
76 }
77
78 return env, nil
79 }
80
81 func (t *Transfer) inAxfr(q *Msg, c chan *Envelope) {
82 first := true
83 defer func() {
84
85
86
87 t.Close()
88 close(c)
89 }()
90 timeout := dnsTimeout
91 if t.ReadTimeout != 0 {
92 timeout = t.ReadTimeout
93 }
94 for {
95 t.Conn.SetReadDeadline(time.Now().Add(timeout))
96 in, err := t.ReadMsg()
97 if err != nil {
98 c <- &Envelope{nil, err}
99 return
100 }
101 if q.Id != in.Id {
102 c <- &Envelope{in.Answer, ErrId}
103 return
104 }
105 if first {
106 if in.Rcode != RcodeSuccess {
107 c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
108 return
109 }
110 if !isSOAFirst(in) {
111 c <- &Envelope{in.Answer, ErrSoa}
112 return
113 }
114 first = !first
115
116 if len(in.Answer) == 1 {
117 t.tsigTimersOnly = true
118 c <- &Envelope{in.Answer, nil}
119 continue
120 }
121 }
122
123 if !first {
124 t.tsigTimersOnly = true
125 if isSOALast(in) {
126 c <- &Envelope{in.Answer, nil}
127 return
128 }
129 c <- &Envelope{in.Answer, nil}
130 }
131 }
132 }
133
134 func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) {
135 var serial uint32
136 axfr := true
137 n := 0
138 qser := q.Ns[0].(*SOA).Serial
139 defer func() {
140
141
142
143 t.Close()
144 close(c)
145 }()
146 timeout := dnsTimeout
147 if t.ReadTimeout != 0 {
148 timeout = t.ReadTimeout
149 }
150 for {
151 t.SetReadDeadline(time.Now().Add(timeout))
152 in, err := t.ReadMsg()
153 if err != nil {
154 c <- &Envelope{nil, err}
155 return
156 }
157 if q.Id != in.Id {
158 c <- &Envelope{in.Answer, ErrId}
159 return
160 }
161 if in.Rcode != RcodeSuccess {
162 c <- &Envelope{in.Answer, &Error{err: fmt.Sprintf(errXFR, in.Rcode)}}
163 return
164 }
165 if n == 0 {
166
167 if !isSOAFirst(in) {
168 c <- &Envelope{in.Answer, ErrSoa}
169 return
170 }
171
172 serial = in.Answer[0].(*SOA).Serial
173
174 if qser >= serial {
175 c <- &Envelope{in.Answer, nil}
176 return
177 }
178 }
179
180 t.tsigTimersOnly = true
181 for _, rr := range in.Answer {
182 if v, ok := rr.(*SOA); ok {
183 if v.Serial == serial {
184 n++
185
186 if axfr && n == 2 || n == 3 {
187 c <- &Envelope{in.Answer, nil}
188 return
189 }
190 } else if axfr {
191
192 axfr = false
193 }
194 }
195 }
196 c <- &Envelope{in.Answer, nil}
197 }
198 }
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216 func (t *Transfer) Out(w ResponseWriter, q *Msg, ch chan *Envelope) error {
217 for x := range ch {
218 r := new(Msg)
219
220 r.SetReply(q)
221 r.Authoritative = true
222
223 r.Answer = append(r.Answer, x.RR...)
224 if tsig := q.IsTsig(); tsig != nil && w.TsigStatus() == nil {
225 r.SetTsig(tsig.Hdr.Name, tsig.Algorithm, tsig.Fudge, time.Now().Unix())
226 }
227 if err := w.WriteMsg(r); err != nil {
228 return err
229 }
230 w.TsigTimersOnly(true)
231 }
232 return nil
233 }
234
235
236 func (t *Transfer) ReadMsg() (*Msg, error) {
237 m := new(Msg)
238 p := make([]byte, MaxMsgSize)
239 n, err := t.Read(p)
240 if err != nil && n == 0 {
241 return nil, err
242 }
243 p = p[:n]
244 if err := m.Unpack(p); err != nil {
245 return nil, err
246 }
247 if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
248
249 err = TsigVerifyWithProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly)
250 t.tsigRequestMAC = ts.MAC
251 }
252 return m, err
253 }
254
255
256 func (t *Transfer) WriteMsg(m *Msg) (err error) {
257 var out []byte
258 if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
259 out, t.tsigRequestMAC, err = TsigGenerateWithProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly)
260 } else {
261 out, err = m.Pack()
262 }
263 if err != nil {
264 return err
265 }
266 _, err = t.Write(out)
267 return err
268 }
269
270 func isSOAFirst(in *Msg) bool {
271 return len(in.Answer) > 0 &&
272 in.Answer[0].Header().Rrtype == TypeSOA
273 }
274
275 func isSOALast(in *Msg) bool {
276 return len(in.Answer) > 0 &&
277 in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
278 }
279
// errXFR is the format for the error reported when a transfer response
// carries a non-success rcode; the verb receives the rcode.
const errXFR = "bad xfr rcode: %d"
281
View as plain text