...

Source file src/k8s.io/client-go/tools/remotecommand/spdy_test.go

Documentation: k8s.io/client-go/tools/remotecommand

     1  /*
     2  Copyright 2020 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package remotecommand
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/rand"
    23  	"encoding/json"
    24  	"errors"
    25  	"io"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"net/url"
    29  	"strings"
    30  	"testing"
    31  	"time"
    32  
    33  	v1 "k8s.io/api/core/v1"
    34  	apierrors "k8s.io/apimachinery/pkg/api/errors"
    35  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    36  	"k8s.io/apimachinery/pkg/util/httpstream"
    37  	"k8s.io/apimachinery/pkg/util/httpstream/spdy"
    38  	remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand"
    39  	"k8s.io/apimachinery/pkg/util/wait"
    40  	"k8s.io/client-go/rest"
    41  )
    42  
    43  type AttachFunc func(in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan TerminalSize) error
    44  type streamContext struct {
    45  	conn         io.Closer
    46  	stdinStream  io.ReadCloser
    47  	stdoutStream io.WriteCloser
    48  	stderrStream io.WriteCloser
    49  	writeStatus  func(status *apierrors.StatusError) error
    50  }
    51  
    52  type streamAndReply struct {
    53  	httpstream.Stream
    54  	replySent <-chan struct{}
    55  }
    56  
    57  type fakeEmptyDataPty struct {
    58  }
    59  
    60  func (s *fakeEmptyDataPty) Read(p []byte) (int, error) {
    61  	return len(p), nil
    62  }
    63  
    64  func (s *fakeEmptyDataPty) Write(p []byte) (int, error) {
    65  	return len(p), nil
    66  }
    67  
    68  type fakeMassiveDataPty struct{}
    69  
    70  func (s *fakeMassiveDataPty) Read(p []byte) (int, error) {
    71  	time.Sleep(time.Duration(1) * time.Second)
    72  	return copy(p, []byte{}), errors.New("client crashed after 1 second")
    73  }
    74  
    75  func (s *fakeMassiveDataPty) Write(p []byte) (int, error) {
    76  	time.Sleep(time.Duration(1) * time.Second)
    77  	return len(p), errors.New("return err")
    78  }
    79  
    80  func fakeMassiveDataAttacher(stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan TerminalSize) error {
    81  
    82  	copyDone := make(chan struct{}, 3)
    83  
    84  	if stdin == nil {
    85  		return errors.New("stdin is requested") // we need stdin to notice the conn break
    86  	}
    87  
    88  	go func() {
    89  		io.Copy(io.Discard, stdin)
    90  		copyDone <- struct{}{}
    91  	}()
    92  
    93  	go func() {
    94  		if stdout == nil {
    95  			return
    96  		}
    97  		copyDone <- writeMassiveData(stdout)
    98  	}()
    99  
   100  	go func() {
   101  		if stderr == nil {
   102  			return
   103  		}
   104  		copyDone <- writeMassiveData(stderr)
   105  	}()
   106  
   107  	select {
   108  	case <-copyDone:
   109  		return nil
   110  	}
   111  }
   112  
   113  func writeMassiveData(stdStream io.Writer) struct{} { // write to stdin or stdout
   114  	for {
   115  		_, err := io.Copy(stdStream, strings.NewReader("something"))
   116  		if err != nil && err.Error() != "EOF" {
   117  			break
   118  		}
   119  	}
   120  	return struct{}{}
   121  }
   122  
   123  func TestSPDYExecutorStream(t *testing.T) {
   124  	tests := []struct {
   125  		timeout     time.Duration
   126  		name        string
   127  		options     StreamOptions
   128  		expectError string
   129  		attacher    AttachFunc
   130  	}{
   131  		{
   132  			name: "stdoutBlockTest",
   133  			options: StreamOptions{
   134  				Stdin:  &fakeMassiveDataPty{},
   135  				Stdout: &fakeMassiveDataPty{},
   136  			},
   137  			expectError: "",
   138  			attacher:    fakeMassiveDataAttacher,
   139  		},
   140  		{
   141  			name: "stderrBlockTest",
   142  			options: StreamOptions{
   143  				Stdin:  &fakeMassiveDataPty{},
   144  				Stderr: &fakeMassiveDataPty{},
   145  			},
   146  			expectError: "",
   147  			attacher:    fakeMassiveDataAttacher,
   148  		},
   149  		{
   150  			timeout: 500 * time.Millisecond,
   151  			name:    "timeoutTest",
   152  			options: StreamOptions{
   153  				Stdin:  &fakeMassiveDataPty{},
   154  				Stderr: &fakeMassiveDataPty{},
   155  			},
   156  			expectError: context.DeadlineExceeded.Error(),
   157  			attacher:    fakeMassiveDataAttacher,
   158  		},
   159  	}
   160  
   161  	for _, test := range tests {
   162  		t.Run(test.name, func(t *testing.T) {
   163  			server := newTestHTTPServer(test.attacher, &test.options)
   164  			defer server.Close()
   165  
   166  			ctx, cancel := context.Background(), func() {}
   167  			if test.timeout > 0 {
   168  				ctx, cancel = context.WithTimeout(ctx, test.timeout)
   169  			}
   170  			defer cancel()
   171  
   172  			err := attach2Server(ctx, server.URL, test.options)
   173  
   174  			gotError := ""
   175  			if err != nil {
   176  				gotError = err.Error()
   177  			}
   178  			if test.expectError != gotError {
   179  				t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError)
   180  			}
   181  		})
   182  	}
   183  }
   184  
   185  func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server {
   186  	//nolint:errcheck
   187  	server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
   188  		ctx, err := createHTTPStreams(writer, request, options)
   189  		if err != nil {
   190  			return
   191  		}
   192  		defer ctx.conn.Close()
   193  
   194  		// handle input output
   195  		err = f(ctx.stdinStream, ctx.stdoutStream, ctx.stderrStream, false, nil)
   196  		if err != nil {
   197  			ctx.writeStatus(apierrors.NewInternalError(err))
   198  		} else {
   199  			ctx.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{
   200  				Status: metav1.StatusSuccess,
   201  			}})
   202  		}
   203  	}))
   204  	return server
   205  }
   206  
   207  func attach2Server(ctx context.Context, rawURL string, options StreamOptions) error {
   208  	uri, _ := url.Parse(rawURL)
   209  	exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri)
   210  	if err != nil {
   211  		return err
   212  	}
   213  
   214  	e := make(chan error, 1)
   215  	go func(e chan error) {
   216  		e <- exec.StreamWithContext(ctx, options)
   217  	}(e)
   218  	select {
   219  	case err := <-e:
   220  		return err
   221  	case <-time.After(wait.ForeverTestTimeout):
   222  		return errors.New("execute timeout")
   223  	}
   224  }
   225  
   226  // simplify createHttpStreams , only support StreamProtocolV4Name
   227  func createHTTPStreams(w http.ResponseWriter, req *http.Request, opts *StreamOptions) (*streamContext, error) {
   228  	_, err := httpstream.Handshake(req, w, []string{remotecommandconsts.StreamProtocolV4Name})
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  
   233  	upgrader := spdy.NewResponseUpgrader()
   234  	streamCh := make(chan streamAndReply)
   235  	conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error {
   236  		streamCh <- streamAndReply{Stream: stream, replySent: replySent}
   237  		return nil
   238  	})
   239  	ctx := &streamContext{
   240  		conn: conn,
   241  	}
   242  
   243  	// wait for stream
   244  	replyChan := make(chan struct{}, 4)
   245  	defer close(replyChan)
   246  	receivedStreams := 0
   247  	expectedStreams := 1
   248  	if opts.Stdout != nil {
   249  		expectedStreams++
   250  	}
   251  	if opts.Stdin != nil {
   252  		expectedStreams++
   253  	}
   254  	if opts.Stderr != nil {
   255  		expectedStreams++
   256  	}
   257  WaitForStreams:
   258  	for {
   259  		select {
   260  		case stream := <-streamCh:
   261  			streamType := stream.Headers().Get(v1.StreamType)
   262  			switch streamType {
   263  			case v1.StreamTypeError:
   264  				replyChan <- struct{}{}
   265  				ctx.writeStatus = v4WriteStatusFunc(stream)
   266  			case v1.StreamTypeStdout:
   267  				replyChan <- struct{}{}
   268  				ctx.stdoutStream = stream
   269  			case v1.StreamTypeStdin:
   270  				replyChan <- struct{}{}
   271  				ctx.stdinStream = stream
   272  			case v1.StreamTypeStderr:
   273  				replyChan <- struct{}{}
   274  				ctx.stderrStream = stream
   275  			default:
   276  				// add other stream ...
   277  				return nil, errors.New("unimplemented stream type")
   278  			}
   279  		case <-replyChan:
   280  			receivedStreams++
   281  			if receivedStreams == expectedStreams {
   282  				break WaitForStreams
   283  			}
   284  		}
   285  	}
   286  
   287  	return ctx, nil
   288  }
   289  
   290  func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error {
   291  	return func(status *apierrors.StatusError) error {
   292  		bs, err := json.Marshal(status.Status())
   293  		if err != nil {
   294  			return err
   295  		}
   296  		_, err = stream.Write(bs)
   297  		return err
   298  	}
   299  }
   300  
   301  // writeDetector provides a helper method to block until the underlying writer written.
   302  type writeDetector struct {
   303  	written chan bool
   304  	closed  bool
   305  	io.Writer
   306  }
   307  
   308  func newWriterDetector(w io.Writer) *writeDetector {
   309  	return &writeDetector{
   310  		written: make(chan bool),
   311  		Writer:  w,
   312  	}
   313  }
   314  
   315  func (w *writeDetector) BlockUntilWritten() {
   316  	<-w.written
   317  }
   318  
   319  func (w *writeDetector) Write(p []byte) (n int, err error) {
   320  	if !w.closed {
   321  		close(w.written)
   322  		w.closed = true
   323  	}
   324  	return w.Writer.Write(p)
   325  }
   326  
   327  // `Executor.StreamWithContext` starts a goroutine in the background to do the streaming
   328  // and expects the deferred close of the connection leads to the exit of the goroutine on cancellation.
   329  // This test verifies that works.
   330  func TestStreamExitsAfterConnectionIsClosed(t *testing.T) {
   331  	writeDetector := newWriterDetector(&fakeEmptyDataPty{})
   332  	options := StreamOptions{
   333  		Stdin:  &fakeEmptyDataPty{},
   334  		Stdout: writeDetector,
   335  	}
   336  	server := newTestHTTPServer(fakeMassiveDataAttacher, &options)
   337  
   338  	ctx, cancelFn := context.WithTimeout(context.Background(), 500*time.Millisecond)
   339  	defer cancelFn()
   340  
   341  	uri, _ := url.Parse(server.URL)
   342  	exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri)
   343  	if err != nil {
   344  		t.Fatal(err)
   345  	}
   346  	streamExec := exec.(*spdyStreamExecutor)
   347  
   348  	conn, streamer, err := streamExec.newConnectionAndStream(ctx, options)
   349  	if err != nil {
   350  		t.Fatal(err)
   351  	}
   352  
   353  	errorChan := make(chan error)
   354  	go func() {
   355  		errorChan <- streamer.stream(conn)
   356  	}()
   357  
   358  	// Wait until stream goroutine starts.
   359  	writeDetector.BlockUntilWritten()
   360  
   361  	// Close the connection
   362  	conn.Close()
   363  
   364  	select {
   365  	case <-time.After(1 * time.Second):
   366  		t.Fatalf("expect stream to be closed after connection is closed.")
   367  	case <-errorChan:
   368  		return
   369  	}
   370  }
   371  
   372  func TestStreamRandomData(t *testing.T) {
   373  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   374  		var stdin, stdout bytes.Buffer
   375  		ctx, err := createHTTPStreams(w, req, &StreamOptions{
   376  			Stdin:  &stdin,
   377  			Stdout: &stdout,
   378  		})
   379  		if err != nil {
   380  			t.Errorf("error on createHTTPStreams: %v", err)
   381  			return
   382  		}
   383  		defer ctx.conn.Close()
   384  
   385  		io.Copy(ctx.stdoutStream, ctx.stdinStream) //nolint:errcheck
   386  	}))
   387  
   388  	defer server.Close()
   389  
   390  	uri, _ := url.Parse(server.URL)
   391  	exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri)
   392  	if err != nil {
   393  		t.Fatal(err)
   394  	}
   395  
   396  	randomData := make([]byte, 1024*1024)
   397  	if _, err := rand.Read(randomData); err != nil {
   398  		t.Errorf("unexpected error reading random data: %v", err)
   399  	}
   400  	var stdout bytes.Buffer
   401  	options := &StreamOptions{
   402  		Stdin:  bytes.NewReader(randomData),
   403  		Stdout: &stdout,
   404  	}
   405  	errorChan := make(chan error)
   406  	go func() {
   407  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   408  	}()
   409  
   410  	select {
   411  	case <-time.After(wait.ForeverTestTimeout):
   412  		t.Fatalf("expect stream to be closed after connection is closed.")
   413  	case err := <-errorChan:
   414  		if err != nil {
   415  			t.Errorf("unexpected error")
   416  		}
   417  	}
   418  
   419  	data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   420  	if err != nil {
   421  		t.Errorf("error reading the stream: %v", err)
   422  		return
   423  	}
   424  	if !bytes.Equal(randomData, data) {
   425  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   426  	}
   427  
   428  }
   429  

View as plain text