1
2
3 package dbus
4
5 import (
6 "bytes"
7 "encoding/binary"
8 "errors"
9 "io"
10 "net"
11 "syscall"
12 )
13
// oobReader is an io.Reader over a unix socket that additionally collects
// the ancillary (out-of-band) data delivered alongside the stream bytes.
type oobReader struct {
	conn *net.UnixConn // socket the stream data is read from
	oob  []byte        // control data appended by every successful Read
	buf  [1 << 12]byte // scratch space for one ReadMsgUnix call's control data
}
19
20 func (o *oobReader) Read(b []byte) (n int, err error) {
21 n, oobn, flags, _, err := o.conn.ReadMsgUnix(b, o.buf[:])
22 if err != nil {
23 return n, err
24 }
25 if flags&syscall.MSG_CTRUNC != 0 {
26 return n, errors.New("dbus: control data truncated (too many fds received)")
27 }
28 o.oob = append(o.oob, o.buf[:oobn]...)
29 return n, nil
30 }
31
32 type unixTransport struct {
33 *net.UnixConn
34 rdr *oobReader
35 hasUnixFDs bool
36 }
37
38 func newUnixTransport(keys string) (transport, error) {
39 var err error
40
41 t := new(unixTransport)
42 abstract := getKey(keys, "abstract")
43 path := getKey(keys, "path")
44 switch {
45 case abstract == "" && path == "":
46 return nil, errors.New("dbus: invalid address (neither path nor abstract set)")
47 case abstract != "" && path == "":
48 t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"})
49 if err != nil {
50 return nil, err
51 }
52 return t, nil
53 case abstract == "" && path != "":
54 t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
55 if err != nil {
56 return nil, err
57 }
58 return t, nil
59 default:
60 return nil, errors.New("dbus: invalid address (both path and abstract set)")
61 }
62 }
63
64 func init() {
65 transports["unix"] = newUnixTransport
66 }
67
68 func (t *unixTransport) EnableUnixFDs() {
69 t.hasUnixFDs = true
70 }
71
72 func (t *unixTransport) ReadMessage() (*Message, error) {
73 var (
74 blen, hlen uint32
75 csheader [16]byte
76 headers []header
77 order binary.ByteOrder
78 unixfds uint32
79 )
80
81
82
83 if t.rdr == nil {
84 t.rdr = &oobReader{conn: t.UnixConn}
85 } else {
86 t.rdr.oob = nil
87 }
88
89
90
91 if _, err := io.ReadFull(t.rdr, csheader[:]); err != nil {
92 return nil, err
93 }
94 switch csheader[0] {
95 case 'l':
96 order = binary.LittleEndian
97 case 'B':
98 order = binary.BigEndian
99 default:
100 return nil, InvalidMessageError("invalid byte order")
101 }
102
103
104 binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
105 binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
106 if hlen%8 != 0 {
107 hlen += 8 - (hlen % 8)
108 }
109
110
111 headerdata := make([]byte, hlen+4)
112 copy(headerdata, csheader[12:])
113 if _, err := io.ReadFull(t.rdr, headerdata[4:]); err != nil {
114 return nil, err
115 }
116 dec := newDecoder(bytes.NewBuffer(headerdata), order, make([]int, 0))
117 dec.pos = 12
118 vs, err := dec.Decode(Signature{"a(yv)"})
119 if err != nil {
120 return nil, err
121 }
122 Store(vs, &headers)
123 for _, v := range headers {
124 if v.Field == byte(FieldUnixFDs) {
125 unixfds, _ = v.Variant.value.(uint32)
126 }
127 }
128 all := make([]byte, 16+hlen+blen)
129 copy(all, csheader[:])
130 copy(all[16:], headerdata[4:])
131 if _, err := io.ReadFull(t.rdr, all[16+hlen:]); err != nil {
132 return nil, err
133 }
134 if unixfds != 0 {
135 if !t.hasUnixFDs {
136 return nil, errors.New("dbus: got unix fds on unsupported transport")
137 }
138
139 scms, err := syscall.ParseSocketControlMessage(t.rdr.oob)
140 if err != nil {
141 return nil, err
142 }
143 if len(scms) != 1 {
144 return nil, errors.New("dbus: received more than one socket control message")
145 }
146 fds, err := syscall.ParseUnixRights(&scms[0])
147 if err != nil {
148 return nil, err
149 }
150 msg, err := DecodeMessageWithFDs(bytes.NewBuffer(all), fds)
151 if err != nil {
152 return nil, err
153 }
154
155
156 for i, v := range msg.Body {
157 switch index := v.(type) {
158 case UnixFDIndex:
159 if uint32(index) >= unixfds {
160 return nil, InvalidMessageError("invalid index for unix fd")
161 }
162 msg.Body[i] = UnixFD(fds[index])
163 case []UnixFDIndex:
164 fdArray := make([]UnixFD, len(index))
165 for k, j := range index {
166 if uint32(j) >= unixfds {
167 return nil, InvalidMessageError("invalid index for unix fd")
168 }
169 fdArray[k] = UnixFD(fds[j])
170 }
171 msg.Body[i] = fdArray
172 }
173 }
174 return msg, nil
175 }
176 return DecodeMessage(bytes.NewBuffer(all))
177 }
178
179 func (t *unixTransport) SendMessage(msg *Message) error {
180 fdcnt, err := msg.CountFds()
181 if err != nil {
182 return err
183 }
184 if fdcnt != 0 {
185 if !t.hasUnixFDs {
186 return errors.New("dbus: unix fd passing not enabled")
187 }
188 msg.Headers[FieldUnixFDs] = MakeVariant(uint32(fdcnt))
189 buf := new(bytes.Buffer)
190 fds, err := msg.EncodeToWithFDs(buf, nativeEndian)
191 if err != nil {
192 return err
193 }
194 oob := syscall.UnixRights(fds...)
195 n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil)
196 if err != nil {
197 return err
198 }
199 if n != buf.Len() || oobn != len(oob) {
200 return io.ErrShortWrite
201 }
202 } else {
203 if err := msg.EncodeTo(t, nativeEndian); err != nil {
204 return err
205 }
206 }
207 return nil
208 }
209
210 func (t *unixTransport) SupportsUnixFDs() bool {
211 return true
212 }
213
View as plain text