...

Source file src/k8s.io/client-go/tools/remotecommand/v2_test.go

Documentation: k8s.io/client-go/tools/remotecommand

     1  /*
     2  Copyright 2016 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package remotecommand
    18  
    19  import (
    20  	"errors"
    21  	"io"
    22  	"net/http"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  
    27  	"k8s.io/api/core/v1"
    28  	"k8s.io/apimachinery/pkg/util/httpstream"
    29  	"k8s.io/apimachinery/pkg/util/wait"
    30  )
    31  
    32  type fakeReader struct {
    33  	err error
    34  }
    35  
    36  func (r *fakeReader) Read([]byte) (int, error) { return 0, r.err }
    37  
    38  type fakeWriter struct{}
    39  
    40  func (*fakeWriter) Write([]byte) (int, error) { return 0, nil }
    41  
    42  type fakeStreamCreator struct {
    43  	created map[string]bool
    44  	errors  map[string]error
    45  }
    46  
    47  var _ streamCreator = &fakeStreamCreator{}
    48  
    49  func (f *fakeStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, error) {
    50  	streamType := headers.Get(v1.StreamType)
    51  	f.created[streamType] = true
    52  	return nil, f.errors[streamType]
    53  }
    54  
    55  func TestV2CreateStreams(t *testing.T) {
    56  	tests := []struct {
    57  		name        string
    58  		stdin       bool
    59  		stdinError  error
    60  		stdout      bool
    61  		stdoutError error
    62  		stderr      bool
    63  		stderrError error
    64  		errorError  error
    65  		tty         bool
    66  		expectError bool
    67  	}{
    68  		{
    69  			name:        "stdin error",
    70  			stdin:       true,
    71  			stdinError:  errors.New("stdin error"),
    72  			expectError: true,
    73  		},
    74  		{
    75  			name:        "stdout error",
    76  			stdout:      true,
    77  			stdoutError: errors.New("stdout error"),
    78  			expectError: true,
    79  		},
    80  		{
    81  			name:        "stderr error",
    82  			stderr:      true,
    83  			stderrError: errors.New("stderr error"),
    84  			expectError: true,
    85  		},
    86  		{
    87  			name:        "error stream error",
    88  			stdin:       true,
    89  			stdout:      true,
    90  			stderr:      true,
    91  			errorError:  errors.New("error stream error"),
    92  			expectError: true,
    93  		},
    94  		{
    95  			name:        "no errors",
    96  			stdin:       true,
    97  			stdout:      true,
    98  			stderr:      true,
    99  			expectError: false,
   100  		},
   101  		{
   102  			name:        "no errors, stderr & tty set, don't expect stderr",
   103  			stdin:       true,
   104  			stdout:      true,
   105  			stderr:      true,
   106  			tty:         true,
   107  			expectError: false,
   108  		},
   109  	}
   110  	for _, test := range tests {
   111  		conn := &fakeStreamCreator{
   112  			created: make(map[string]bool),
   113  			errors: map[string]error{
   114  				v1.StreamTypeStdin:  test.stdinError,
   115  				v1.StreamTypeStdout: test.stdoutError,
   116  				v1.StreamTypeStderr: test.stderrError,
   117  				v1.StreamTypeError:  test.errorError,
   118  			},
   119  		}
   120  
   121  		opts := StreamOptions{Tty: test.tty}
   122  		if test.stdin {
   123  			opts.Stdin = &fakeReader{}
   124  		}
   125  		if test.stdout {
   126  			opts.Stdout = &fakeWriter{}
   127  		}
   128  		if test.stderr {
   129  			opts.Stderr = &fakeWriter{}
   130  		}
   131  
   132  		h := newStreamProtocolV2(opts).(*streamProtocolV2)
   133  		err := h.createStreams(conn)
   134  
   135  		if test.expectError {
   136  			if err == nil {
   137  				t.Errorf("%s: expected error", test.name)
   138  				continue
   139  			}
   140  			if e, a := test.stdinError, err; test.stdinError != nil && e != a {
   141  				t.Errorf("%s: expected %v, got %v", test.name, e, a)
   142  			}
   143  			if e, a := test.stdoutError, err; test.stdoutError != nil && e != a {
   144  				t.Errorf("%s: expected %v, got %v", test.name, e, a)
   145  			}
   146  			if e, a := test.stderrError, err; test.stderrError != nil && e != a {
   147  				t.Errorf("%s: expected %v, got %v", test.name, e, a)
   148  			}
   149  			if e, a := test.errorError, err; test.errorError != nil && e != a {
   150  				t.Errorf("%s: expected %v, got %v", test.name, e, a)
   151  			}
   152  			continue
   153  		}
   154  
   155  		if !test.expectError && err != nil {
   156  			t.Errorf("%s: unexpected error: %v", test.name, err)
   157  			continue
   158  		}
   159  
   160  		if test.stdin && !conn.created[v1.StreamTypeStdin] {
   161  			t.Errorf("%s: expected stdin stream", test.name)
   162  		}
   163  		if test.stdout && !conn.created[v1.StreamTypeStdout] {
   164  			t.Errorf("%s: expected stdout stream", test.name)
   165  		}
   166  		if test.stderr {
   167  			if test.tty && conn.created[v1.StreamTypeStderr] {
   168  				t.Errorf("%s: unexpected stderr stream because tty is set", test.name)
   169  			} else if !test.tty && !conn.created[v1.StreamTypeStderr] {
   170  				t.Errorf("%s: expected stderr stream", test.name)
   171  			}
   172  		}
   173  		if !conn.created[v1.StreamTypeError] {
   174  			t.Errorf("%s: expected error stream", test.name)
   175  		}
   176  
   177  	}
   178  }
   179  
   180  func TestV2ErrorStreamReading(t *testing.T) {
   181  	tests := []struct {
   182  		name          string
   183  		stream        io.Reader
   184  		expectedError error
   185  	}{
   186  		{
   187  			name:          "error reading from stream",
   188  			stream:        &fakeReader{errors.New("foo")},
   189  			expectedError: errors.New("error reading from error stream: foo"),
   190  		},
   191  		{
   192  			name:          "stream returns an error",
   193  			stream:        strings.NewReader("some error"),
   194  			expectedError: errors.New("error executing remote command: some error"),
   195  		},
   196  	}
   197  
   198  	for _, test := range tests {
   199  		h := newStreamProtocolV2(StreamOptions{}).(*streamProtocolV2)
   200  		h.errorStream = test.stream
   201  
   202  		ch := watchErrorStream(h.errorStream, &errorDecoderV2{})
   203  		if ch == nil {
   204  			t.Fatalf("%s: unexpected nil channel", test.name)
   205  		}
   206  
   207  		var err error
   208  		select {
   209  		case err = <-ch:
   210  		case <-time.After(wait.ForeverTestTimeout):
   211  			t.Fatalf("%s: timed out", test.name)
   212  		}
   213  
   214  		if test.expectedError != nil {
   215  			if err == nil {
   216  				t.Errorf("%s: expected an error", test.name)
   217  			} else if e, a := test.expectedError, err; e.Error() != a.Error() {
   218  				t.Errorf("%s: expected %q, got %q", test.name, e, a)
   219  			}
   220  			continue
   221  		}
   222  
   223  		if test.expectedError == nil && err != nil {
   224  			t.Errorf("%s: unexpected error: %v", test.name, err)
   225  			continue
   226  		}
   227  	}
   228  }
   229  

View as plain text