...

Source file src/github.com/spf13/afero/sftpfs/sftp_test.go

Documentation: github.com/spf13/afero/sftpfs

     1  // Copyright © 2015 Jerry Jacobs <jerry.jacobs@xor-gate.org>.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package sftpfs
    15  
    16  import (
    17  	_rand "crypto/rand"
    18  	"crypto/rsa"
    19  	"crypto/x509"
    20  	"encoding/pem"
    21  	"flag"
    22  	"fmt"
    23  	"io"
    24  	"log"
    25  	"net"
    26  	"os"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/pkg/sftp"
    31  	"golang.org/x/crypto/ssh"
    32  )
    33  
    34  type SftpFsContext struct {
    35  	sshc   *ssh.Client
    36  	sshcfg *ssh.ClientConfig
    37  	sftpc  *sftp.Client
    38  }
    39  
    40  // TODO we only connect with hardcoded user+pass for now
    41  // it should be possible to use $HOME/.ssh/id_rsa to login into the stub sftp server
    42  func SftpConnect(user, password, host string) (*SftpFsContext, error) {
    43  	/*
    44  		pemBytes, err := ioutil.ReadFile(os.Getenv("HOME") + "/.ssh/id_rsa")
    45  		if err != nil {
    46  			return nil,err
    47  		}
    48  
    49  		signer, err := ssh.ParsePrivateKey(pemBytes)
    50  		if err != nil {
    51  			return nil,err
    52  		}
    53  
    54  		sshcfg := &ssh.ClientConfig{
    55  			User: user,
    56  			Auth: []ssh.AuthMethod{
    57  				ssh.Password(password),
    58  				ssh.PublicKeys(signer),
    59  			},
    60  		}
    61  	*/
    62  
    63  	sshcfg := &ssh.ClientConfig{
    64  		User: user,
    65  		Auth: []ssh.AuthMethod{
    66  			ssh.Password(password),
    67  		},
    68  		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
    69  	}
    70  
    71  	sshc, err := ssh.Dial("tcp", host, sshcfg)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	sftpc, err := sftp.NewClient(sshc)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	ctx := &SftpFsContext{
    82  		sshc:   sshc,
    83  		sshcfg: sshcfg,
    84  		sftpc:  sftpc,
    85  	}
    86  
    87  	return ctx, nil
    88  }
    89  
    90  func (ctx *SftpFsContext) Disconnect() error {
    91  	ctx.sftpc.Close()
    92  	ctx.sshc.Close()
    93  	return nil
    94  }
    95  
    96  // TODO for such a weird reason rootpath is "." when writing "file1" with afero sftp backend
    97  func RunSftpServer(rootpath string) {
    98  	var (
    99  		readOnly      bool
   100  		debugLevelStr string
   101  		debugStderr   bool
   102  		rootDir       string
   103  	)
   104  
   105  	flag.BoolVar(&readOnly, "R", false, "read-only server")
   106  	flag.BoolVar(&debugStderr, "e", true, "debug to stderr")
   107  	flag.StringVar(&debugLevelStr, "l", "none", "debug level")
   108  	flag.StringVar(&rootDir, "root", rootpath, "root directory")
   109  	flag.Parse()
   110  
   111  	debugStream := io.Discard
   112  
   113  	// An SSH server is represented by a ServerConfig, which holds
   114  	// certificate details and handles authentication of ServerConns.
   115  	config := &ssh.ServerConfig{
   116  		PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
   117  			// Should use constant-time compare (or better, salt+hash) in
   118  			// a production setting.
   119  			fmt.Fprintf(debugStream, "Login: %s\n", c.User())
   120  			if c.User() == "test" && string(pass) == "test" {
   121  				return nil, nil
   122  			}
   123  			return nil, fmt.Errorf("password rejected for %q", c.User())
   124  		},
   125  	}
   126  
   127  	privateBytes, err := os.ReadFile("./test/id_rsa")
   128  	if err != nil {
   129  		log.Fatal("Failed to load private key", err)
   130  	}
   131  
   132  	private, err := ssh.ParsePrivateKey(privateBytes)
   133  	if err != nil {
   134  		log.Fatal("Failed to parse private key", err)
   135  	}
   136  
   137  	config.AddHostKey(private)
   138  
   139  	// Once a ServerConfig has been configured, connections can be
   140  	// accepted.
   141  	listener, err := net.Listen("tcp", "0.0.0.0:2022")
   142  	if err != nil {
   143  		log.Fatal("failed to listen for connection", err)
   144  	}
   145  
   146  	nConn, err := listener.Accept()
   147  	if err != nil {
   148  		log.Fatal("failed to accept incoming connection", err)
   149  	}
   150  
   151  	// Before use, a handshake must be performed on the incoming
   152  	// net.Conn.
   153  	conn, chans, reqs, err := ssh.NewServerConn(nConn, config)
   154  	if err != nil {
   155  		log.Fatal("failed to handshake", err)
   156  	}
   157  	defer conn.Close()
   158  
   159  	// The incoming Request channel must be serviced.
   160  	go ssh.DiscardRequests(reqs)
   161  
   162  	// Service the incoming Channel channel.
   163  	for newChannel := range chans {
   164  		// Channels have a type, depending on the application level
   165  		// protocol intended. In the case of an SFTP session, this is "subsystem"
   166  		// with a payload string of "<length=4>sftp"
   167  		fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType())
   168  		if newChannel.ChannelType() != "session" {
   169  			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
   170  			fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType())
   171  			continue
   172  		}
   173  		channel, requests, err := newChannel.Accept()
   174  		if err != nil {
   175  			log.Fatal("could not accept channel.", err)
   176  		}
   177  		fmt.Fprintf(debugStream, "Channel accepted\n")
   178  
   179  		// Sessions have out-of-band requests such as "shell",
   180  		// "pty-req" and "env".  Here we handle only the
   181  		// "subsystem" request.
   182  		go func(in <-chan *ssh.Request) {
   183  			for req := range in {
   184  				fmt.Fprintf(debugStream, "Request: %v\n", req.Type)
   185  				ok := false
   186  				switch req.Type {
   187  				case "subsystem":
   188  					fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:])
   189  					if string(req.Payload[4:]) == "sftp" {
   190  						ok = true
   191  					}
   192  				}
   193  				fmt.Fprintf(debugStream, " - accepted: %v\n", ok)
   194  				req.Reply(ok, nil)
   195  			}
   196  		}(requests)
   197  
   198  		server, err := sftp.NewServer(channel, sftp.WithDebug(debugStream))
   199  		if err != nil {
   200  			log.Fatal(err)
   201  		}
   202  		_ = server.Serve()
   203  		return
   204  	}
   205  }
   206  
   207  // MakeSSHKeyPair make a pair of public and private keys for SSH access.
   208  // Public key is encoded in the format for inclusion in an OpenSSH authorized_keys file.
   209  // Private Key generated is PEM encoded
   210  func MakeSSHKeyPair(bits int, pubKeyPath, privateKeyPath string) error {
   211  	privateKey, err := rsa.GenerateKey(_rand.Reader, bits)
   212  	if err != nil {
   213  		return err
   214  	}
   215  
   216  	// generate and write private key as PEM
   217  	privateKeyFile, err := os.Create(privateKeyPath)
   218  	if err != nil {
   219  		return err
   220  	}
   221  	defer privateKeyFile.Close()
   222  	if err != nil {
   223  		return err
   224  	}
   225  
   226  	privateKeyPEM := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}
   227  	if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil {
   228  		return err
   229  	}
   230  
   231  	// generate and write public key
   232  	pub, err := ssh.NewPublicKey(&privateKey.PublicKey)
   233  	if err != nil {
   234  		return err
   235  	}
   236  
   237  	return os.WriteFile(pubKeyPath, ssh.MarshalAuthorizedKey(pub), 0o655)
   238  }
   239  
   240  func TestSftpCreate(t *testing.T) {
   241  	os.Mkdir("./test", 0o777)
   242  	MakeSSHKeyPair(1024, "./test/id_rsa.pub", "./test/id_rsa")
   243  
   244  	go RunSftpServer("./test/")
   245  	time.Sleep(5 * time.Second)
   246  
   247  	ctx, err := SftpConnect("test", "test", "localhost:2022")
   248  	if err != nil {
   249  		t.Fatal(err)
   250  	}
   251  	defer ctx.Disconnect()
   252  
   253  	fs := New(ctx.sftpc)
   254  
   255  	fs.MkdirAll("test/dir1/dir2/dir3", os.FileMode(0o777))
   256  	fs.Mkdir("test/foo", os.FileMode(0o000))
   257  	fs.Chmod("test/foo", os.FileMode(0o700))
   258  	fs.Mkdir("test/bar", os.FileMode(0o777))
   259  
   260  	file, err := fs.Create("file1")
   261  	if err != nil {
   262  		t.Error(err)
   263  	}
   264  	defer file.Close()
   265  
   266  	file.Write([]byte("hello "))
   267  	file.WriteString("world!\n")
   268  
   269  	f1, err := fs.Open("file1")
   270  	if err != nil {
   271  		log.Fatalf("open: %v", err)
   272  	}
   273  	defer f1.Close()
   274  
   275  	b := make([]byte, 100)
   276  
   277  	_, _ = f1.Read(b)
   278  	fmt.Println(string(b))
   279  
   280  	fmt.Println("done")
   281  	// TODO check here if "hello\tworld\n" is in buffer b
   282  }
   283  

View as plain text