...

Source file src/github.com/godbus/dbus/v5/transport_unix.go

Documentation: github.com/godbus/dbus/v5

     1  //+build !windows,!solaris
     2  
     3  package dbus
     4  
     5  import (
     6  	"bytes"
     7  	"encoding/binary"
     8  	"errors"
     9  	"io"
    10  	"net"
    11  	"syscall"
    12  )
    13  
    14  type oobReader struct {
    15  	conn *net.UnixConn
    16  	oob  []byte
    17  	buf  [4096]byte
    18  }
    19  
    20  func (o *oobReader) Read(b []byte) (n int, err error) {
    21  	n, oobn, flags, _, err := o.conn.ReadMsgUnix(b, o.buf[:])
    22  	if err != nil {
    23  		return n, err
    24  	}
    25  	if flags&syscall.MSG_CTRUNC != 0 {
    26  		return n, errors.New("dbus: control data truncated (too many fds received)")
    27  	}
    28  	o.oob = append(o.oob, o.buf[:oobn]...)
    29  	return n, nil
    30  }
    31  
// unixTransport is a D-Bus transport running over a unix domain socket
// connection, which it embeds.
type unixTransport struct {
	*net.UnixConn
	// rdr is created lazily by ReadMessage and reused for later reads; it
	// collects out-of-band data received alongside message bytes.
	rdr        *oobReader
	// hasUnixFDs is set by EnableUnixFDs; while it is false, ReadMessage and
	// SendMessage reject messages that carry unix fds.
	hasUnixFDs bool
}
    37  
    38  func newUnixTransport(keys string) (transport, error) {
    39  	var err error
    40  
    41  	t := new(unixTransport)
    42  	abstract := getKey(keys, "abstract")
    43  	path := getKey(keys, "path")
    44  	switch {
    45  	case abstract == "" && path == "":
    46  		return nil, errors.New("dbus: invalid address (neither path nor abstract set)")
    47  	case abstract != "" && path == "":
    48  		t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: "@" + abstract, Net: "unix"})
    49  		if err != nil {
    50  			return nil, err
    51  		}
    52  		return t, nil
    53  	case abstract == "" && path != "":
    54  		t.UnixConn, err = net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
    55  		if err != nil {
    56  			return nil, err
    57  		}
    58  		return t, nil
    59  	default:
    60  		return nil, errors.New("dbus: invalid address (both path and abstract set)")
    61  	}
    62  }
    63  
// init registers newUnixTransport in the transports table under the "unix"
// transport name.
func init() {
	transports["unix"] = newUnixTransport
}
    67  
// EnableUnixFDs allows this transport to send and receive unix file
// descriptors. Until it is called, ReadMessage and SendMessage return an
// error for messages that carry fds.
func (t *unixTransport) EnableUnixFDs() {
	t.hasUnixFDs = true
}
    71  
    72  func (t *unixTransport) ReadMessage() (*Message, error) {
    73  	var (
    74  		blen, hlen uint32
    75  		csheader   [16]byte
    76  		headers    []header
    77  		order      binary.ByteOrder
    78  		unixfds    uint32
    79  	)
    80  	// To be sure that all bytes of out-of-band data are read, we use a special
    81  	// reader that uses ReadUnix on the underlying connection instead of Read
    82  	// and gathers the out-of-band data in a buffer.
    83  	if t.rdr == nil {
    84  		t.rdr = &oobReader{conn: t.UnixConn}
    85  	} else {
    86  		t.rdr.oob = nil
    87  	}
    88  
    89  	// read the first 16 bytes (the part of the header that has a constant size),
    90  	// from which we can figure out the length of the rest of the message
    91  	if _, err := io.ReadFull(t.rdr, csheader[:]); err != nil {
    92  		return nil, err
    93  	}
    94  	switch csheader[0] {
    95  	case 'l':
    96  		order = binary.LittleEndian
    97  	case 'B':
    98  		order = binary.BigEndian
    99  	default:
   100  		return nil, InvalidMessageError("invalid byte order")
   101  	}
   102  	// csheader[4:8] -> length of message body, csheader[12:16] -> length of
   103  	// header fields (without alignment)
   104  	binary.Read(bytes.NewBuffer(csheader[4:8]), order, &blen)
   105  	binary.Read(bytes.NewBuffer(csheader[12:]), order, &hlen)
   106  	if hlen%8 != 0 {
   107  		hlen += 8 - (hlen % 8)
   108  	}
   109  
   110  	// decode headers and look for unix fds
   111  	headerdata := make([]byte, hlen+4)
   112  	copy(headerdata, csheader[12:])
   113  	if _, err := io.ReadFull(t.rdr, headerdata[4:]); err != nil {
   114  		return nil, err
   115  	}
   116  	dec := newDecoder(bytes.NewBuffer(headerdata), order, make([]int, 0))
   117  	dec.pos = 12
   118  	vs, err := dec.Decode(Signature{"a(yv)"})
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	Store(vs, &headers)
   123  	for _, v := range headers {
   124  		if v.Field == byte(FieldUnixFDs) {
   125  			unixfds, _ = v.Variant.value.(uint32)
   126  		}
   127  	}
   128  	all := make([]byte, 16+hlen+blen)
   129  	copy(all, csheader[:])
   130  	copy(all[16:], headerdata[4:])
   131  	if _, err := io.ReadFull(t.rdr, all[16+hlen:]); err != nil {
   132  		return nil, err
   133  	}
   134  	if unixfds != 0 {
   135  		if !t.hasUnixFDs {
   136  			return nil, errors.New("dbus: got unix fds on unsupported transport")
   137  		}
   138  		// read the fds from the OOB data
   139  		scms, err := syscall.ParseSocketControlMessage(t.rdr.oob)
   140  		if err != nil {
   141  			return nil, err
   142  		}
   143  		if len(scms) != 1 {
   144  			return nil, errors.New("dbus: received more than one socket control message")
   145  		}
   146  		fds, err := syscall.ParseUnixRights(&scms[0])
   147  		if err != nil {
   148  			return nil, err
   149  		}
   150  		msg, err := DecodeMessageWithFDs(bytes.NewBuffer(all), fds)
   151  		if err != nil {
   152  			return nil, err
   153  		}
   154  		// substitute the values in the message body (which are indices for the
   155  		// array receiver via OOB) with the actual values
   156  		for i, v := range msg.Body {
   157  			switch index := v.(type) {
   158  			case UnixFDIndex:
   159  				if uint32(index) >= unixfds {
   160  					return nil, InvalidMessageError("invalid index for unix fd")
   161  				}
   162  				msg.Body[i] = UnixFD(fds[index])
   163  			case []UnixFDIndex:
   164  				fdArray := make([]UnixFD, len(index))
   165  				for k, j := range index {
   166  					if uint32(j) >= unixfds {
   167  						return nil, InvalidMessageError("invalid index for unix fd")
   168  					}
   169  					fdArray[k] = UnixFD(fds[j])
   170  				}
   171  				msg.Body[i] = fdArray
   172  			}
   173  		}
   174  		return msg, nil
   175  	}
   176  	return DecodeMessage(bytes.NewBuffer(all))
   177  }
   178  
   179  func (t *unixTransport) SendMessage(msg *Message) error {
   180  	fdcnt, err := msg.CountFds()
   181  	if err != nil {
   182  		return err
   183  	}
   184  	if fdcnt != 0 {
   185  		if !t.hasUnixFDs {
   186  			return errors.New("dbus: unix fd passing not enabled")
   187  		}
   188  		msg.Headers[FieldUnixFDs] = MakeVariant(uint32(fdcnt))
   189  		buf := new(bytes.Buffer)
   190  		fds, err := msg.EncodeToWithFDs(buf, nativeEndian)
   191  		if err != nil {
   192  			return err
   193  		}
   194  		oob := syscall.UnixRights(fds...)
   195  		n, oobn, err := t.UnixConn.WriteMsgUnix(buf.Bytes(), oob, nil)
   196  		if err != nil {
   197  			return err
   198  		}
   199  		if n != buf.Len() || oobn != len(oob) {
   200  			return io.ErrShortWrite
   201  		}
   202  	} else {
   203  		if err := msg.EncodeTo(t, nativeEndian); err != nil {
   204  			return err
   205  		}
   206  	}
   207  	return nil
   208  }
   209  
// SupportsUnixFDs reports whether this transport can carry unix file
// descriptors. For unix sockets it always returns true; actually passing fds
// still requires EnableUnixFDs.
func (t *unixTransport) SupportsUnixFDs() bool {
	return true
}
   213  

View as plain text