...

Source file src/k8s.io/kubernetes/pkg/client/tests/remotecommand_test.go

Documentation: k8s.io/kubernetes/pkg/client/tests

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package tests
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"io"
    25  	"io/ioutil"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"net/url"
    29  	"strings"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/stretchr/testify/require"
    34  
    35  	"k8s.io/apimachinery/pkg/runtime"
    36  	"k8s.io/apimachinery/pkg/runtime/schema"
    37  	"k8s.io/apimachinery/pkg/types"
    38  	"k8s.io/apimachinery/pkg/util/httpstream"
    39  	remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand"
    40  	restclient "k8s.io/client-go/rest"
    41  	remoteclient "k8s.io/client-go/tools/remotecommand"
    42  	"k8s.io/client-go/transport/spdy"
    43  	"k8s.io/kubelet/pkg/cri/streaming/remotecommand"
    44  	"k8s.io/kubernetes/pkg/api/legacyscheme"
    45  	api "k8s.io/kubernetes/pkg/apis/core"
    46  )
    47  
    48  type fakeExecutor struct {
    49  	t             *testing.T
    50  	testName      string
    51  	errorData     string
    52  	stdoutData    string
    53  	stderrData    string
    54  	expectStdin   bool
    55  	stdinReceived bytes.Buffer
    56  	tty           bool
    57  	messageCount  int
    58  	command       []string
    59  	exec          bool
    60  }
    61  
    62  func (ex *fakeExecutor) ExecInContainer(_ context.Context, name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remoteclient.TerminalSize, timeout time.Duration) error {
    63  	return ex.run(name, uid, container, cmd, in, out, err, tty)
    64  }
    65  
    66  func (ex *fakeExecutor) AttachContainer(_ context.Context, name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remoteclient.TerminalSize) error {
    67  	return ex.run(name, uid, container, nil, in, out, err, tty)
    68  }
    69  
    70  func (ex *fakeExecutor) run(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
    71  	ex.command = cmd
    72  	ex.tty = tty
    73  
    74  	if e, a := "pod", name; e != a {
    75  		ex.t.Errorf("%s: pod: expected %q, got %q", ex.testName, e, a)
    76  	}
    77  	if e, a := "uid", uid; e != string(a) {
    78  		ex.t.Errorf("%s: uid: expected %q, got %q", ex.testName, e, a)
    79  	}
    80  	if ex.exec {
    81  		if e, a := "ls /", strings.Join(ex.command, " "); e != a {
    82  			ex.t.Errorf("%s: command: expected %q, got %q", ex.testName, e, a)
    83  		}
    84  	} else {
    85  		if len(ex.command) > 0 {
    86  			ex.t.Errorf("%s: command: expected nothing, got %v", ex.testName, ex.command)
    87  		}
    88  	}
    89  
    90  	if len(ex.errorData) > 0 {
    91  		return errors.New(ex.errorData)
    92  	}
    93  
    94  	if len(ex.stdoutData) > 0 {
    95  		for i := 0; i < ex.messageCount; i++ {
    96  			fmt.Fprint(out, ex.stdoutData)
    97  		}
    98  	}
    99  
   100  	if len(ex.stderrData) > 0 {
   101  		for i := 0; i < ex.messageCount; i++ {
   102  			fmt.Fprint(err, ex.stderrData)
   103  		}
   104  	}
   105  
   106  	if ex.expectStdin {
   107  		io.Copy(&ex.stdinReceived, in)
   108  	}
   109  
   110  	return nil
   111  }
   112  
   113  func fakeServer(t *testing.T, requestReceived chan struct{}, testName string, exec bool, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int, serverProtocols []string) http.HandlerFunc {
   114  	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   115  		executor := &fakeExecutor{
   116  			t:            t,
   117  			testName:     testName,
   118  			errorData:    errorData,
   119  			stdoutData:   stdoutData,
   120  			stderrData:   stderrData,
   121  			expectStdin:  len(stdinData) > 0,
   122  			tty:          tty,
   123  			messageCount: messageCount,
   124  			exec:         exec,
   125  		}
   126  
   127  		opts, err := remotecommand.NewOptions(req)
   128  		require.NoError(t, err)
   129  		if exec {
   130  			cmd := req.URL.Query()[api.ExecCommandParam]
   131  			remotecommand.ServeExec(w, req, executor, "pod", "uid", "container", cmd, opts, 0, 10*time.Second, serverProtocols)
   132  		} else {
   133  			remotecommand.ServeAttach(w, req, executor, "pod", "uid", "container", opts, 0, 10*time.Second, serverProtocols)
   134  		}
   135  
   136  		if e, a := strings.Repeat(stdinData, messageCount), executor.stdinReceived.String(); e != a {
   137  			t.Errorf("%s: stdin: expected %q, got %q", testName, e, a)
   138  		}
   139  		close(requestReceived)
   140  	})
   141  }
   142  
   143  func TestStream(t *testing.T) {
   144  	testCases := []struct {
   145  		TestName        string
   146  		Stdin           string
   147  		Stdout          string
   148  		Stderr          string
   149  		Error           string
   150  		Tty             bool
   151  		MessageCount    int
   152  		ClientProtocols []string
   153  		ServerProtocols []string
   154  	}{
   155  		{
   156  			TestName:        "error",
   157  			Error:           "bail",
   158  			Stdout:          "a",
   159  			ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
   160  			ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
   161  		},
   162  		{
   163  			TestName:        "in/out/err",
   164  			Stdin:           "a",
   165  			Stdout:          "b",
   166  			Stderr:          "c",
   167  			MessageCount:    100,
   168  			ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
   169  			ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
   170  		},
   171  		{
   172  			TestName:        "oversized stdin",
   173  			Stdin:           strings.Repeat("a", 20*1024*1024),
   174  			Stdout:          "b",
   175  			Stderr:          "",
   176  			MessageCount:    1,
   177  			ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
   178  			ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
   179  		},
   180  		{
   181  			TestName:        "in/out/tty",
   182  			Stdin:           "a",
   183  			Stdout:          "b",
   184  			Tty:             true,
   185  			MessageCount:    100,
   186  			ClientProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
   187  			ServerProtocols: []string{remotecommandconsts.StreamProtocolV2Name},
   188  		},
   189  	}
   190  
   191  	for _, testCase := range testCases {
   192  		for _, exec := range []bool{true, false} {
   193  			var name string
   194  			if exec {
   195  				name = testCase.TestName + " (exec)"
   196  			} else {
   197  				name = testCase.TestName + " (attach)"
   198  			}
   199  
   200  			t.Run(name, func(t *testing.T) {
   201  				var (
   202  					streamIn             io.Reader
   203  					streamOut, streamErr io.Writer
   204  				)
   205  				localOut := &bytes.Buffer{}
   206  				localErr := &bytes.Buffer{}
   207  
   208  				requestReceived := make(chan struct{})
   209  				server := httptest.NewServer(fakeServer(t, requestReceived, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))
   210  				defer server.Close()
   211  
   212  				url, _ := url.ParseRequestURI(server.URL)
   213  				config := restclient.ClientContentConfig{
   214  					GroupVersion: schema.GroupVersion{Group: "x"},
   215  					Negotiator:   runtime.NewClientNegotiator(legacyscheme.Codecs.WithoutConversion(), schema.GroupVersion{Group: "x"}),
   216  				}
   217  				c, err := restclient.NewRESTClient(url, "", config, nil, nil)
   218  				if err != nil {
   219  					t.Fatalf("failed to create a client: %v", err)
   220  				}
   221  				req := c.Post().Resource("testing")
   222  
   223  				if exec {
   224  					req.Param("command", "ls")
   225  					req.Param("command", "/")
   226  				}
   227  
   228  				if len(testCase.Stdin) > 0 {
   229  					req.Param(api.ExecStdinParam, "1")
   230  					streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
   231  				}
   232  
   233  				if len(testCase.Stdout) > 0 {
   234  					req.Param(api.ExecStdoutParam, "1")
   235  					streamOut = localOut
   236  				}
   237  
   238  				if testCase.Tty {
   239  					req.Param(api.ExecTTYParam, "1")
   240  				} else if len(testCase.Stderr) > 0 {
   241  					req.Param(api.ExecStderrParam, "1")
   242  					streamErr = localErr
   243  				}
   244  
   245  				conf := &restclient.Config{
   246  					Host: server.URL,
   247  				}
   248  				transport, upgradeTransport, err := spdy.RoundTripperFor(conf)
   249  				if err != nil {
   250  					t.Fatalf("%s: unexpected error: %v", name, err)
   251  				}
   252  				e, err := remoteclient.NewSPDYExecutorForProtocols(transport, upgradeTransport, "POST", req.URL(), testCase.ClientProtocols...)
   253  				if err != nil {
   254  					t.Fatalf("%s: unexpected error: %v", name, err)
   255  				}
   256  				err = e.StreamWithContext(context.Background(), remoteclient.StreamOptions{
   257  					Stdin:  streamIn,
   258  					Stdout: streamOut,
   259  					Stderr: streamErr,
   260  					Tty:    testCase.Tty,
   261  				})
   262  				hasErr := err != nil
   263  
   264  				if len(testCase.Error) > 0 {
   265  					if !hasErr {
   266  						t.Errorf("%s: expected an error", name)
   267  					} else {
   268  						if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
   269  							t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
   270  						}
   271  					}
   272  					return
   273  				}
   274  
   275  				if hasErr {
   276  					t.Fatalf("%s: unexpected error: %v", name, err)
   277  				}
   278  
   279  				if len(testCase.Stdout) > 0 {
   280  					if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
   281  						t.Fatalf("%s: expected stdout data %q, got %q", name, e, a)
   282  					}
   283  				}
   284  
   285  				if testCase.Stderr != "" {
   286  					if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
   287  						t.Fatalf("%s: expected stderr data %q, got %q", name, e, a)
   288  					}
   289  				}
   290  
   291  				select {
   292  				case <-requestReceived:
   293  				case <-time.After(time.Minute):
   294  					t.Errorf("%s: expected fakeServerInstance to receive request", name)
   295  				}
   296  			})
   297  		}
   298  	}
   299  }
   300  
   301  type fakeUpgrader struct {
   302  	req           *http.Request
   303  	resp          *http.Response
   304  	conn          httpstream.Connection
   305  	err, connErr  error
   306  	checkResponse bool
   307  	called        bool
   308  
   309  	t *testing.T
   310  }
   311  
   312  func (u *fakeUpgrader) RoundTrip(req *http.Request) (*http.Response, error) {
   313  	u.called = true
   314  	u.req = req
   315  	return u.resp, u.err
   316  }
   317  
   318  func (u *fakeUpgrader) NewConnection(resp *http.Response) (httpstream.Connection, error) {
   319  	if u.checkResponse && u.resp != resp {
   320  		u.t.Errorf("response objects passed did not match: %#v", resp)
   321  	}
   322  	return u.conn, u.connErr
   323  }
   324  
   325  type fakeConnection struct {
   326  	httpstream.Connection
   327  }
   328  
   329  // Dial is the common functionality between any stream based upgrader, regardless of protocol.
   330  // This method ensures that someone can use a generic stream executor without being dependent
   331  // on the core Kube client config behavior.
   332  func TestDial(t *testing.T) {
   333  	upgrader := &fakeUpgrader{
   334  		t:             t,
   335  		checkResponse: true,
   336  		conn:          &fakeConnection{},
   337  		resp: &http.Response{
   338  			StatusCode: http.StatusSwitchingProtocols,
   339  			Body:       ioutil.NopCloser(&bytes.Buffer{}),
   340  		},
   341  	}
   342  	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: upgrader}, "POST", &url.URL{Host: "something.com", Scheme: "https"})
   343  	conn, protocol, err := dialer.Dial("protocol1")
   344  	if err != nil {
   345  		t.Fatal(err)
   346  	}
   347  	if conn != upgrader.conn {
   348  		t.Errorf("unexpected connection: %#v", conn)
   349  	}
   350  	if !upgrader.called {
   351  		t.Errorf("request not called")
   352  	}
   353  	_ = protocol
   354  }
   355  

View as plain text