...
1
16
17 package ttrpc
18
19 import (
20 "context"
21 "errors"
22 "fmt"
23 "net"
24 "os"
25 "syscall"
26
27 "golang.org/x/sys/unix"
28 )
29
30 type UnixCredentialsFunc func(*unix.Ucred) error
31
32 func (fn UnixCredentialsFunc) Handshake(_ context.Context, conn net.Conn) (net.Conn, interface{}, error) {
33 uc, err := requireUnixSocket(conn)
34 if err != nil {
35 return nil, nil, fmt.Errorf("ttrpc.UnixCredentialsFunc: require unix socket: %w", err)
36 }
37
38 rs, err := uc.SyscallConn()
39 if err != nil {
40 return nil, nil, fmt.Errorf("ttrpc.UnixCredentialsFunc: (net.UnixConn).SyscallConn failed: %w", err)
41 }
42 var (
43 ucred *unix.Ucred
44 ucredErr error
45 )
46 if err := rs.Control(func(fd uintptr) {
47 ucred, ucredErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
48 }); err != nil {
49 return nil, nil, fmt.Errorf("ttrpc.UnixCredentialsFunc: (*syscall.RawConn).Control failed: %w", err)
50 }
51
52 if ucredErr != nil {
53 return nil, nil, fmt.Errorf("ttrpc.UnixCredentialsFunc: failed to retrieve socket peer credentials: %w", ucredErr)
54 }
55
56 if err := fn(ucred); err != nil {
57 return nil, nil, fmt.Errorf("ttrpc.UnixCredentialsFunc: credential check failed: %w", err)
58 }
59
60 return uc, ucred, nil
61 }
62
63
64
65
66
67
68
69
70
71 func UnixSocketRequireUidGid(uid, gid int) UnixCredentialsFunc {
72 return func(ucred *unix.Ucred) error {
73 return requireUidGid(ucred, uid, gid)
74 }
75 }
76
77 func UnixSocketRequireRoot() UnixCredentialsFunc {
78 return UnixSocketRequireUidGid(0, 0)
79 }
80
81
82
83
84
85
86 func UnixSocketRequireSameUser() UnixCredentialsFunc {
87 euid, egid := os.Geteuid(), os.Getegid()
88 return UnixSocketRequireUidGid(euid, egid)
89 }
90
91 func requireUidGid(ucred *unix.Ucred, uid, gid int) error {
92 if (uid != -1 && uint32(uid) != ucred.Uid) || (gid != -1 && uint32(gid) != ucred.Gid) {
93 return fmt.Errorf("ttrpc: invalid credentials: %v", syscall.EPERM)
94 }
95 return nil
96 }
97
// requireUnixSocket returns conn as a *net.UnixConn. It returns an error when
// conn has any other dynamic type, including a nil interface.
func requireUnixSocket(conn net.Conn) (*net.UnixConn, error) {
	if unixConn, isUnix := conn.(*net.UnixConn); isUnix {
		return unixConn, nil
	}
	return nil, errors.New("a unix socket connection is required")
}
106
View as plain text