...

Source file src/github.com/godbus/dbus/v5/conn_test.go

Documentation: github.com/godbus/dbus/v5

     1  package dbus
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"log"
    10  	"sync"
    11  	"testing"
    12  	"time"
    13  )
    14  
    15  func TestSessionBus(t *testing.T) {
    16  	oldConn, err := SessionBus()
    17  	if err != nil {
    18  		t.Error(err)
    19  	}
    20  	if err = oldConn.Close(); err != nil {
    21  		t.Fatal(err)
    22  	}
    23  	if oldConn.Connected() {
    24  		t.Fatal("Should be closed")
    25  	}
    26  	newConn, err := SessionBus()
    27  	if err != nil {
    28  		t.Error(err)
    29  	}
    30  	if newConn == oldConn {
    31  		t.Fatal("Should get a new connection")
    32  	}
    33  }
    34  
    35  func TestSystemBus(t *testing.T) {
    36  	oldConn, err := SystemBus()
    37  	if err != nil {
    38  		t.Error(err)
    39  	}
    40  	if err = oldConn.Close(); err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	if oldConn.Connected() {
    44  		t.Fatal("Should be closed")
    45  	}
    46  	newConn, err := SystemBus()
    47  	if err != nil {
    48  		t.Error(err)
    49  	}
    50  	if newConn == oldConn {
    51  		t.Fatal("Should get a new connection")
    52  	}
    53  }
    54  
    55  func TestConnectSessionBus(t *testing.T) {
    56  	conn, err := ConnectSessionBus()
    57  	if err != nil {
    58  		t.Fatal(err)
    59  	}
    60  	if err = conn.Close(); err != nil {
    61  		t.Fatal(err)
    62  	}
    63  	if conn.Connected() {
    64  		t.Fatal("Should be closed")
    65  	}
    66  }
    67  
    68  func TestConnectSystemBus(t *testing.T) {
    69  	conn, err := ConnectSystemBus()
    70  	if err != nil {
    71  		t.Fatal(err)
    72  	}
    73  	if err = conn.Close(); err != nil {
    74  		t.Fatal(err)
    75  	}
    76  	if conn.Connected() {
    77  		t.Fatal("Should be closed")
    78  	}
    79  }
    80  
    81  func TestSend(t *testing.T) {
    82  	bus, err := ConnectSessionBus()
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  	defer bus.Close()
    87  
    88  	ch := make(chan *Call, 1)
    89  	msg := &Message{
    90  		Type:  TypeMethodCall,
    91  		Flags: 0,
    92  		Headers: map[HeaderField]Variant{
    93  			FieldDestination: MakeVariant(bus.Names()[0]),
    94  			FieldPath:        MakeVariant(ObjectPath("/org/freedesktop/DBus")),
    95  			FieldInterface:   MakeVariant("org.freedesktop.DBus.Peer"),
    96  			FieldMember:      MakeVariant("Ping"),
    97  		},
    98  	}
    99  	call := bus.Send(msg, ch)
   100  	<-ch
   101  	if call.Err != nil {
   102  		t.Error(call.Err)
   103  	}
   104  }
   105  
   106  func TestFlagNoReplyExpectedSend(t *testing.T) {
   107  	bus, err := ConnectSessionBus()
   108  	if err != nil {
   109  		t.Fatal(err)
   110  	}
   111  	defer bus.Close()
   112  
   113  	done := make(chan struct{})
   114  	go func() {
   115  		bus.BusObject().Call("org.freedesktop.DBus.ListNames", FlagNoReplyExpected)
   116  		close(done)
   117  	}()
   118  	select {
   119  	case <-done:
   120  	case <-time.After(1 * time.Second):
   121  		t.Error("Failed to announce that the call was done")
   122  	}
   123  }
   124  
   125  func TestRemoveSignal(t *testing.T) {
   126  	bus, err := NewConn(nil)
   127  	if err != nil {
   128  		t.Error(err)
   129  	}
   130  	signals := bus.signalHandler.(*defaultSignalHandler).signals
   131  	ch := make(chan *Signal)
   132  	ch2 := make(chan *Signal)
   133  	for _, ch := range []chan *Signal{ch, ch2, ch, ch2, ch2, ch} {
   134  		bus.Signal(ch)
   135  	}
   136  	signals = bus.signalHandler.(*defaultSignalHandler).signals
   137  	if len(signals) != 6 {
   138  		t.Errorf("remove signal: signals length not equal: got '%d', want '6'", len(signals))
   139  	}
   140  	bus.RemoveSignal(ch)
   141  	signals = bus.signalHandler.(*defaultSignalHandler).signals
   142  	if len(signals) != 3 {
   143  		t.Errorf("remove signal: signals length not equal: got '%d', want '3'", len(signals))
   144  	}
   145  	signals = bus.signalHandler.(*defaultSignalHandler).signals
   146  	for _, scd := range signals {
   147  		if scd.ch != ch2 {
   148  			t.Errorf("remove signal: removed signal present: got '%v', want '%v'", scd.ch, ch2)
   149  		}
   150  	}
   151  }
   152  
   153  type rwc struct {
   154  	io.Reader
   155  	io.Writer
   156  }
   157  
   158  func (rwc) Close() error { return nil }
   159  
   160  type fakeAuth struct {
   161  }
   162  
   163  func (fakeAuth) FirstData() (name, resp []byte, status AuthStatus) {
   164  	return []byte("name"), []byte("resp"), AuthOk
   165  }
   166  
   167  func (fakeAuth) HandleData(data []byte) (resp []byte, status AuthStatus) {
   168  	return nil, AuthOk
   169  }
   170  
   171  func TestCloseBeforeSignal(t *testing.T) {
   172  	reader, pipewriter := io.Pipe()
   173  	defer pipewriter.Close()
   174  	defer reader.Close()
   175  
   176  	bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard})
   177  	if err != nil {
   178  		t.Fatal(err)
   179  	}
   180  	// give ch a buffer so sends won't block
   181  	ch := make(chan *Signal, 1)
   182  	bus.Signal(ch)
   183  
   184  	go func() {
   185  		_, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n"))
   186  		if err != nil {
   187  			t.Errorf("error writing to pipe: %v", err)
   188  		}
   189  	}()
   190  
   191  	err = bus.Auth([]Auth{fakeAuth{}})
   192  	if err != nil {
   193  		t.Fatal(err)
   194  	}
   195  
   196  	err = bus.Close()
   197  	if err != nil {
   198  		t.Fatal(err)
   199  	}
   200  
   201  	msg := &Message{
   202  		Type: TypeSignal,
   203  		Headers: map[HeaderField]Variant{
   204  			FieldInterface: MakeVariant("foo.bar"),
   205  			FieldMember:    MakeVariant("bar"),
   206  			FieldPath:      MakeVariant(ObjectPath("/baz")),
   207  		},
   208  	}
   209  	err = msg.EncodeTo(pipewriter, binary.LittleEndian)
   210  	if err != nil {
   211  		t.Fatal(err)
   212  	}
   213  }
   214  
   215  func TestCloseChannelAfterRemoveSignal(t *testing.T) {
   216  	bus, err := NewConn(nil)
   217  	if err != nil {
   218  		t.Fatal(err)
   219  	}
   220  
   221  	// Add an unbuffered signal channel
   222  	ch := make(chan *Signal)
   223  	bus.Signal(ch)
   224  
   225  	// Send a signal
   226  	msg := &Message{
   227  		Type: TypeSignal,
   228  		Headers: map[HeaderField]Variant{
   229  			FieldInterface: MakeVariant("foo.bar"),
   230  			FieldMember:    MakeVariant("bar"),
   231  			FieldPath:      MakeVariant(ObjectPath("/baz")),
   232  		},
   233  	}
   234  	bus.handleSignal(Sequence(1), msg)
   235  
   236  	// Remove and close the signal channel
   237  	bus.RemoveSignal(ch)
   238  	close(ch)
   239  }
   240  
   241  func TestAddAndRemoveMatchSignalContext(t *testing.T) {
   242  	conn, err := ConnectSessionBus()
   243  	if err != nil {
   244  		t.Fatal(err)
   245  	}
   246  	defer conn.Close()
   247  
   248  	sigc := make(chan *Signal, 1)
   249  	conn.Signal(sigc)
   250  
   251  	ctx, cancel := context.WithCancel(context.Background())
   252  	cancel()
   253  	// try to subscribe to a made up signal with an already canceled context
   254  	if err = conn.AddMatchSignalContext(
   255  		ctx,
   256  		WithMatchInterface("org.test"),
   257  		WithMatchMember("CtxTest"),
   258  	); err == nil {
   259  		t.Fatal("call on canceled context did not fail")
   260  	}
   261  
   262  	// subscribe to the signal with background context
   263  	if err = conn.AddMatchSignalContext(
   264  		context.Background(),
   265  		WithMatchInterface("org.test"),
   266  		WithMatchMember("CtxTest"),
   267  	); err != nil {
   268  		t.Fatal(err)
   269  	}
   270  
   271  	// try to unsubscribe with an already canceled context
   272  	if err = conn.RemoveMatchSignalContext(
   273  		ctx,
   274  		WithMatchInterface("org.test"),
   275  		WithMatchMember("CtxTest"),
   276  	); err == nil {
   277  		t.Fatal("call on canceled context did not fail")
   278  	}
   279  
   280  	// check that signal is still delivered
   281  	if err = conn.Emit("/", "org.test.CtxTest"); err != nil {
   282  		t.Fatal(err)
   283  	}
   284  	if sig := waitSignal(sigc, "org.test.CtxTest", time.Second); sig == nil {
   285  		t.Fatal("signal receive timed out")
   286  	}
   287  
   288  	// unsubscribe from the signal
   289  	if err = conn.RemoveMatchSignalContext(
   290  		context.Background(),
   291  		WithMatchInterface("org.test"),
   292  		WithMatchMember("CtxTest"),
   293  	); err != nil {
   294  		t.Fatal(err)
   295  	}
   296  	if err = conn.Emit("/", "org.test.CtxTest"); err != nil {
   297  		t.Fatal(err)
   298  	}
   299  	if sig := waitSignal(sigc, "org.test.CtxTest", time.Second); sig != nil {
   300  		t.Fatalf("unsubscribed from %q signal, but received %#v", "org.test.CtxTest", sig)
   301  	}
   302  }
   303  
   304  func TestAddAndRemoveMatchSignal(t *testing.T) {
   305  	conn, err := ConnectSessionBus()
   306  	if err != nil {
   307  		t.Fatal(err)
   308  	}
   309  	defer conn.Close()
   310  
   311  	sigc := make(chan *Signal, 1)
   312  	conn.Signal(sigc)
   313  
   314  	// subscribe to a made up signal name and emit one of the type
   315  	if err = conn.AddMatchSignal(
   316  		WithMatchInterface("org.test"),
   317  		WithMatchMember("Test"),
   318  	); err != nil {
   319  		t.Fatal(err)
   320  	}
   321  	if err = conn.Emit("/", "org.test.Test"); err != nil {
   322  		t.Fatal(err)
   323  	}
   324  	if sig := waitSignal(sigc, "org.test.Test", time.Second); sig == nil {
   325  		t.Fatal("signal receive timed out")
   326  	}
   327  
   328  	// unsubscribe from the signal and check that is not delivered anymore
   329  	if err = conn.RemoveMatchSignal(
   330  		WithMatchInterface("org.test"),
   331  		WithMatchMember("Test"),
   332  	); err != nil {
   333  		t.Fatal(err)
   334  	}
   335  	if err = conn.Emit("/", "org.test.Test"); err != nil {
   336  		t.Fatal(err)
   337  	}
   338  	if sig := waitSignal(sigc, "org.test.Test", time.Second); sig != nil {
   339  		t.Fatalf("unsubscribed from %q signal, but received %#v", "org.test.Test", sig)
   340  	}
   341  }
   342  
   343  func waitSignal(sigc <-chan *Signal, name string, timeout time.Duration) *Signal {
   344  	for {
   345  		select {
   346  		case sig := <-sigc:
   347  			if sig.Name == name {
   348  				return sig
   349  			}
   350  		case <-time.After(timeout):
   351  			return nil
   352  		}
   353  	}
   354  }
   355  
   356  const (
   357  	SCPPInterface         = "org.godbus.DBus.StatefulTest"
   358  	SCPPPath              = "/org/godbus/DBus/StatefulTest"
   359  	SCPPChangedSignalName = "Changed"
   360  	SCPPStateMethodName   = "State"
   361  )
   362  
   363  func TestStateCachingProxyPattern(t *testing.T) {
   364  	srv, err := ConnectSessionBus()
   365  	defer srv.Close()
   366  	if err != nil {
   367  		t.Fatal(err)
   368  	}
   369  
   370  	conn, err := ConnectSessionBus(WithSignalHandler(NewSequentialSignalHandler()))
   371  	if err != nil {
   372  		t.Fatal(err)
   373  	}
   374  	defer conn.Close()
   375  
   376  	serviceName := srv.Names()[0]
   377  	// message channel should have at least some buffering, to make sure Eavesdrop does not
   378  	// drop the message if nobody is currently trying to read from the channel.
   379  	messages := make(chan *Message, 1)
   380  	srv.Eavesdrop(messages)
   381  
   382  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   383  	defer cancel()
   384  
   385  	var wg sync.WaitGroup
   386  	wg.Add(2)
   387  	go func() {
   388  		defer wg.Done()
   389  		if err := serverProcess(ctx, srv, messages, t); err != nil {
   390  			t.Errorf("error in server process: %v", err)
   391  			cancel()
   392  		}
   393  	}()
   394  	go func() {
   395  		defer wg.Done()
   396  		if err := clientProcess(ctx, conn, serviceName, t); err != nil {
   397  			t.Errorf("error in client process: %v", err)
   398  		}
   399  		// Cancel the server process.
   400  		cancel()
   401  	}()
   402  	wg.Wait()
   403  }
   404  
   405  func clientProcess(ctx context.Context, conn *Conn, serviceName string, t *testing.T) error {
   406  	// Subscribe to state changes on the remote object
   407  	if err := conn.AddMatchSignal(
   408  		WithMatchInterface(SCPPInterface),
   409  		WithMatchMember(SCPPChangedSignalName),
   410  	); err != nil {
   411  		return err
   412  	}
   413  	channel := make(chan *Signal)
   414  	conn.Signal(channel)
   415  	t.Log("Subscribed to signals")
   416  
   417  	// Simulate unfavourable OS scheduling leading to a delay between subscription
   418  	// and querying for the current state.
   419  	time.Sleep(30 * time.Millisecond)
   420  
   421  	// Call .State() on the remote object to get its current state and store in observedStates[0].
   422  	obj := conn.Object(serviceName, SCPPPath)
   423  	observedStates := make([]uint64, 1)
   424  	call := obj.CallWithContext(ctx, SCPPInterface+"."+SCPPStateMethodName, 0)
   425  	if err := call.Store(&observedStates[0]); err != nil {
   426  		return err
   427  	}
   428  	t.Logf("Queried current state, got %v", observedStates[0])
   429  
   430  	// Populate observedStates[1...49] based on the state change signals,
   431  	// ignoring signals with a sequence number less than call.ResponseSequence so that we ignore past signals.
   432  	signalsProcessed := 0
   433  readSignals:
   434  	for {
   435  		select {
   436  		case signal := <-channel:
   437  			signalsProcessed++
   438  			if signal.Name == SCPPInterface+"."+SCPPChangedSignalName && signal.Sequence > call.ResponseSequence {
   439  				observedState := signal.Body[0].(uint64)
   440  				observedStates = append(observedStates, observedState)
   441  				// Observing at least 50 states gives us low probability that we received a contiguous subsequence of states 'by accident'
   442  				if len(observedStates) >= 50 {
   443  					break readSignals
   444  				}
   445  			}
   446  		case <-ctx.Done():
   447  			t.Logf("Context cancelled, client processed %v signals", signalsProcessed)
   448  			return ctx.Err()
   449  		}
   450  	}
   451  	t.Logf("client processed %v signals", signalsProcessed)
   452  
   453  	// Expect that we begun observing at least a few states in. This ensures the server was already emitting signals
   454  	// and makes it likely we simulated our race condition.
   455  	if observedStates[0] < 10 {
   456  		return fmt.Errorf("expected first state to be at least 10, got %v", observedStates[0])
   457  	}
   458  
   459  	t.Logf("Observed states: %v", observedStates)
   460  
   461  	// The observable states of the remote object were [1 ... (infinity)] during this test.
   462  	// This loop is intended to assert that our observed states are a contiguous subgrange [n ... n+49] for some n, i.e.
   463  	// that we received a contiguous subsequence of the states of the remote object. For each run of the test, n
   464  	// may be slightly different due to scheduling effects.
   465  	for i := 0; i < len(observedStates); i++ {
   466  		expectedState := observedStates[0] + uint64(i)
   467  		if observedStates[i] != expectedState {
   468  			return fmt.Errorf("expected observed state %v to be %v, got %v", i, expectedState, observedStates[i])
   469  		}
   470  	}
   471  	return nil
   472  }
   473  
   474  func serverProcess(ctx context.Context, srv *Conn, messages <-chan *Message, t *testing.T) error {
   475  	state := uint64(0)
   476  
   477  process:
   478  	for {
   479  		select {
   480  		case msg, ok := <-messages:
   481  			if !ok {
   482  				t.Log("Message channel closed")
   483  				// Message channel closed.
   484  				break process
   485  			}
   486  			if msg.IsValid() != nil {
   487  				t.Log("Got invalid message, discarding")
   488  				continue process
   489  			}
   490  			name := msg.Headers[FieldMember].value.(string)
   491  			ifname := msg.Headers[FieldInterface].value.(string)
   492  			if ifname == SCPPInterface && name == SCPPStateMethodName {
   493  				t.Logf("Processing reply to .State(), returning state = %v", state)
   494  				reply := new(Message)
   495  				reply.Type = TypeMethodReply
   496  				reply.Headers = make(map[HeaderField]Variant)
   497  				reply.Headers[FieldDestination] = msg.Headers[FieldSender]
   498  				reply.Headers[FieldReplySerial] = MakeVariant(msg.serial)
   499  				reply.Body = make([]interface{}, 1)
   500  				reply.Body[0] = state
   501  				reply.Headers[FieldSignature] = MakeVariant(SignatureOf(reply.Body...))
   502  				srv.sendMessageAndIfClosed(reply, nil)
   503  			}
   504  		case <-ctx.Done():
   505  			t.Logf("Context cancelled, server emitted %v signals", state)
   506  			return nil
   507  		default:
   508  			state++
   509  			if err := srv.Emit(SCPPPath, SCPPInterface+"."+SCPPChangedSignalName, state); err != nil {
   510  				return err
   511  			}
   512  		}
   513  	}
   514  	return nil
   515  }
   516  
   517  type server struct{}
   518  
   519  func (server) Double(i int64) (int64, *Error) {
   520  	return 2 * i, nil
   521  }
   522  
   523  func BenchmarkCall(b *testing.B) {
   524  	b.StopTimer()
   525  	b.ReportAllocs()
   526  	var s string
   527  	bus, err := ConnectSessionBus()
   528  	if err != nil {
   529  		b.Fatal(err)
   530  	}
   531  	defer bus.Close()
   532  
   533  	name := bus.Names()[0]
   534  	obj := bus.BusObject()
   535  	b.StartTimer()
   536  	for i := 0; i < b.N; i++ {
   537  		err := obj.Call("org.freedesktop.DBus.GetNameOwner", 0, name).Store(&s)
   538  		if err != nil {
   539  			b.Fatal(err)
   540  		}
   541  		if s != name {
   542  			b.Errorf("got %s, wanted %s", s, name)
   543  		}
   544  	}
   545  }
   546  
   547  func BenchmarkCallAsync(b *testing.B) {
   548  	b.StopTimer()
   549  	b.ReportAllocs()
   550  	bus, err := ConnectSessionBus()
   551  	if err != nil {
   552  		b.Fatal(err)
   553  	}
   554  	defer bus.Close()
   555  
   556  	name := bus.Names()[0]
   557  	obj := bus.BusObject()
   558  	c := make(chan *Call, 50)
   559  	done := make(chan struct{})
   560  	go func() {
   561  		for i := 0; i < b.N; i++ {
   562  			v := <-c
   563  			if v.Err != nil {
   564  				b.Error(v.Err)
   565  			}
   566  			s := v.Body[0].(string)
   567  			if s != name {
   568  				b.Errorf("got %s, wanted %s", s, name)
   569  			}
   570  		}
   571  		close(done)
   572  	}()
   573  	b.StartTimer()
   574  	for i := 0; i < b.N; i++ {
   575  		obj.Go("org.freedesktop.DBus.GetNameOwner", 0, c, name)
   576  	}
   577  	<-done
   578  }
   579  
   580  func BenchmarkServe(b *testing.B) {
   581  	b.StopTimer()
   582  	srv, err := ConnectSessionBus()
   583  	if err != nil {
   584  		b.Fatal(err)
   585  	}
   586  	defer srv.Close()
   587  
   588  	cli, err := ConnectSessionBus()
   589  	if err != nil {
   590  		b.Fatal(err)
   591  	}
   592  	defer cli.Close()
   593  
   594  	benchmarkServe(b, srv, cli)
   595  }
   596  
   597  func BenchmarkServeAsync(b *testing.B) {
   598  	b.StopTimer()
   599  	srv, err := ConnectSessionBus()
   600  	if err != nil {
   601  		b.Fatal(err)
   602  	}
   603  	defer srv.Close()
   604  
   605  	cli, err := ConnectSessionBus()
   606  	if err != nil {
   607  		b.Fatal(err)
   608  	}
   609  	defer cli.Close()
   610  
   611  	benchmarkServeAsync(b, srv, cli)
   612  }
   613  
   614  func BenchmarkServeSameConn(b *testing.B) {
   615  	b.StopTimer()
   616  	bus, err := ConnectSessionBus()
   617  	if err != nil {
   618  		b.Fatal(err)
   619  	}
   620  	defer bus.Close()
   621  
   622  	benchmarkServe(b, bus, bus)
   623  }
   624  
   625  func BenchmarkServeSameConnAsync(b *testing.B) {
   626  	b.StopTimer()
   627  	bus, err := ConnectSessionBus()
   628  	if err != nil {
   629  		b.Fatal(err)
   630  	}
   631  	defer bus.Close()
   632  
   633  	benchmarkServeAsync(b, bus, bus)
   634  }
   635  
   636  func benchmarkServe(b *testing.B, srv, cli *Conn) {
   637  	var r int64
   638  	var err error
   639  	dest := srv.Names()[0]
   640  	srv.Export(server{}, "/org/guelfey/DBus/Test", "org.guelfey.DBus.Test")
   641  	obj := cli.Object(dest, "/org/guelfey/DBus/Test")
   642  	b.StartTimer()
   643  	for i := 0; i < b.N; i++ {
   644  		err = obj.Call("org.guelfey.DBus.Test.Double", 0, int64(i)).Store(&r)
   645  		if err != nil {
   646  			b.Fatal(err)
   647  		}
   648  		if r != 2*int64(i) {
   649  			b.Errorf("got %d, wanted %d", r, 2*int64(i))
   650  		}
   651  	}
   652  }
   653  
   654  func benchmarkServeAsync(b *testing.B, srv, cli *Conn) {
   655  	dest := srv.Names()[0]
   656  	srv.Export(server{}, "/org/guelfey/DBus/Test", "org.guelfey.DBus.Test")
   657  	obj := cli.Object(dest, "/org/guelfey/DBus/Test")
   658  	c := make(chan *Call, 50)
   659  	done := make(chan struct{})
   660  	go func() {
   661  		for i := 0; i < b.N; i++ {
   662  			v := <-c
   663  			if v.Err != nil {
   664  				b.Fatal(v.Err)
   665  			}
   666  			i, r := v.Args[0].(int64), v.Body[0].(int64)
   667  			if 2*i != r {
   668  				b.Errorf("got %d, wanted %d", r, 2*i)
   669  			}
   670  		}
   671  		close(done)
   672  	}()
   673  	b.StartTimer()
   674  	for i := 0; i < b.N; i++ {
   675  		obj.Go("org.guelfey.DBus.Test.Double", 0, c, int64(i))
   676  	}
   677  	<-done
   678  }
   679  
   680  func TestGetKey(t *testing.T) {
   681  	keys := "host=1.2.3.4,port=5678,family=ipv4"
   682  	if host := getKey(keys, "host"); host != "1.2.3.4" {
   683  		t.Error(`Expected "1.2.3.4", got`, host)
   684  	}
   685  	if port := getKey(keys, "port"); port != "5678" {
   686  		t.Error(`Expected "5678", got`, port)
   687  	}
   688  	if family := getKey(keys, "family"); family != "ipv4" {
   689  		t.Error(`Expected "ipv4", got`, family)
   690  	}
   691  }
   692  
   693  func TestInterceptors(t *testing.T) {
   694  	conn, err := ConnectSessionBus(
   695  		WithIncomingInterceptor(func(msg *Message) {
   696  			log.Println("<", msg)
   697  		}),
   698  		WithOutgoingInterceptor(func(msg *Message) {
   699  			log.Println(">", msg)
   700  		}),
   701  	)
   702  	if err != nil {
   703  		t.Fatal(err)
   704  	}
   705  	defer conn.Close()
   706  }
   707  
   708  func TestCloseCancelsConnectionContext(t *testing.T) {
   709  	bus, err := ConnectSessionBus()
   710  	if err != nil {
   711  		t.Fatal(err)
   712  	}
   713  	defer bus.Close()
   714  
   715  	// The context is not done at this point
   716  	ctx := bus.Context()
   717  	select {
   718  	case <-ctx.Done():
   719  		t.Fatal("context should not be done")
   720  	default:
   721  	}
   722  
   723  	err = bus.Close()
   724  	if err != nil {
   725  		t.Fatal(err)
   726  	}
   727  	select {
   728  	case <-ctx.Done():
   729  		// expected
   730  	case <-time.After(5 * time.Second):
   731  		t.Fatal("context is not done after connection closed")
   732  	}
   733  }
   734  
   735  func TestDisconnectCancelsConnectionContext(t *testing.T) {
   736  	reader, pipewriter := io.Pipe()
   737  	defer pipewriter.Close()
   738  	defer reader.Close()
   739  
   740  	bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard})
   741  	if err != nil {
   742  		t.Fatal(err)
   743  	}
   744  
   745  	go func() {
   746  		_, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n"))
   747  		if err != nil {
   748  			t.Errorf("error writing to pipe: %v", err)
   749  		}
   750  	}()
   751  	err = bus.Auth([]Auth{fakeAuth{}})
   752  	if err != nil {
   753  		t.Fatal(err)
   754  	}
   755  
   756  	ctx := bus.Context()
   757  
   758  	pipewriter.Close()
   759  	select {
   760  	case <-ctx.Done():
   761  		// expected
   762  	case <-time.After(5 * time.Second):
   763  		t.Fatal("context is not done after connection closed")
   764  	}
   765  }
   766  
   767  func TestCancellingContextClosesConnection(t *testing.T) {
   768  	ctx, cancel := context.WithCancel(context.Background())
   769  	defer cancel()
   770  
   771  	reader, pipewriter := io.Pipe()
   772  	defer pipewriter.Close()
   773  	defer reader.Close()
   774  
   775  	bus, err := NewConn(rwc{Reader: reader, Writer: ioutil.Discard}, WithContext(ctx))
   776  	if err != nil {
   777  		t.Fatal(err)
   778  	}
   779  
   780  	go func() {
   781  		_, err := pipewriter.Write([]byte("REJECTED name\r\nOK myuuid\r\n"))
   782  		if err != nil {
   783  			t.Errorf("error writing to pipe: %v", err)
   784  		}
   785  	}()
   786  	err = bus.Auth([]Auth{fakeAuth{}})
   787  	if err != nil {
   788  		t.Fatal(err)
   789  	}
   790  
   791  	// Cancel the connection's parent context and give time for
   792  	// other goroutines to schedule.
   793  	cancel()
   794  	time.Sleep(50 * time.Millisecond)
   795  
   796  	err = bus.BusObject().Call("org.freedesktop.DBus.Peer.Ping", 0).Store()
   797  	if err != ErrClosed {
   798  		t.Errorf("expected connection to be closed, but got: %v", err)
   799  	}
   800  }
   801  
   802  // TestTimeoutContextClosesConnection checks that a Conn instance is closed after
   803  // the passed context's deadline is missed.
   804  // The test also checks that there's no data race between Conn creation and its
   805  // automatic closing.
   806  func TestTimeoutContextClosesConnection(t *testing.T) {
   807  	ctx, cancel := context.WithTimeout(context.Background(), 0)
   808  	defer cancel()
   809  
   810  	bus, err := NewConn(rwc{}, WithContext(ctx))
   811  	if err != nil {
   812  		t.Fatal(err)
   813  	}
   814  
   815  	// wait until the connection is actually closed
   816  	time.Sleep(50 * time.Millisecond)
   817  
   818  	err = bus.BusObject().Call("org.freedesktop.DBus.Peer.Ping", 0).Store()
   819  	if err != ErrClosed {
   820  		t.Errorf("expected connection to be closed, but got: %v", err)
   821  	}
   822  }
   823  

View as plain text