1
16
17
18
19 package openidmetadata
20
21 import (
22 "context"
23 "fmt"
24 "log"
25 "net"
26 "net/http"
27 "net/url"
28 "os"
29 "runtime"
30 "time"
31
32 "github.com/coreos/go-oidc"
33 "github.com/spf13/cobra"
34 "golang.org/x/oauth2"
35 "gopkg.in/square/go-jose.v2/jwt"
36 "k8s.io/apimachinery/pkg/util/wait"
37 "k8s.io/client-go/rest"
38 )
39
40
41 var CmdTestServiceAccountIssuerDiscovery = &cobra.Command{
42 Use: "test-service-account-issuer-discovery",
43 Short: "Tests the ServiceAccountIssuerDiscovery feature",
44 Long: "Reads in a mounted token and attempts to verify it against the API server's " +
45 "OIDC endpoints, using a third-party OIDC implementation.",
46 Args: cobra.MaximumNArgs(0),
47 Run: main,
48 }
49
// Flag-backed settings for the command, bound in init:
// tokenPath is the file the service account token is read from, and
// audience is the value the verified token's audience must match.
// Both are empty until the flags are parsed.
var tokenPath, audience string
54
55 func init() {
56 fs := CmdTestServiceAccountIssuerDiscovery.Flags()
57 fs.StringVar(&tokenPath, "token-path", "", "Path to read service account token from.")
58 fs.StringVar(&audience, "audience", "", "Audience to check on received token.")
59 }
60
61 func main(cmd *cobra.Command, args []string) {
62 raw, err := gettoken()
63 if err != nil {
64 log.Fatal(err)
65 }
66 log.Print("OK: Got token")
67
68
80
81 log.Print("validating with in-cluster discovery")
82 inClusterCtx, err := withInClusterOauth2Client(context.Background())
83 if err != nil {
84 log.Fatal(err)
85 }
86 if err := validate(inClusterCtx, raw); err == nil {
87 os.Exit(0)
88 } else {
89 log.Print("failed to validate with in-cluster discovery: ", err)
90 }
91
92 log.Print("falling back to validating with external discovery")
93 externalCtx, err := withExternalOAuth2Client(context.Background())
94 if err != nil {
95 log.Fatal(err)
96 }
97 if err := validate(externalCtx, raw); err != nil {
98 log.Fatal(err)
99 }
100 }
101
102 func validate(ctx context.Context, raw string) error {
103 tok, err := jwt.ParseSigned(raw)
104 if err != nil {
105 log.Fatal(err)
106 }
107 var unsafeClaims claims
108 if err := tok.UnsafeClaimsWithoutVerification(&unsafeClaims); err != nil {
109 log.Fatal(err)
110 }
111 log.Printf("OK: got issuer %s", unsafeClaims.Issuer)
112 log.Printf("Full, not-validated claims: \n%#v", unsafeClaims)
113
114 if runtime.GOOS == "windows" {
115 if err := ensureWindowsDNSAvailability(unsafeClaims.Issuer); err != nil {
116 log.Fatal(err)
117 }
118 }
119
120 iss, err := oidc.NewProvider(ctx, unsafeClaims.Issuer)
121 if err != nil {
122 return err
123 }
124 log.Printf("OK: Constructed OIDC provider for issuer %v", unsafeClaims.Issuer)
125
126 validTok, err := iss.Verifier(&oidc.Config{
127 ClientID: audience,
128 SupportedSigningAlgs: []string{oidc.RS256, oidc.ES256},
129 }).Verify(ctx, raw)
130 if err != nil {
131 return err
132 }
133 log.Print("OK: Validated signature on JWT")
134
135 var safeClaims claims
136 if err := validTok.Claims(&safeClaims); err != nil {
137 return err
138 }
139 log.Print("OK: Got valid claims from token!")
140 log.Printf("Full, validated claims: \n%#v", &safeClaims)
141 return nil
142 }
143
// kubeName identifies a Kubernetes object by name and UID, decoded from the
// JSON "name" and "uid" fields. Field order is kept as-is because values of
// the enclosing claims are logged with %#v.
type kubeName struct {
	// Name is decoded from the "name" field.
	Name string `json:"name"`
	// UID is decoded from the "uid" field.
	UID string `json:"uid"`
}
148
149 type kubeClaims struct {
150 Namespace string `json:"namespace"`
151 ServiceAccount kubeName `json:"serviceaccount"`
152 }
153
154 type claims struct {
155 jwt.Claims
156
157 Kubernetes kubeClaims `json:"kubernetes.io"`
158 }
159
160 func (k *claims) String() string {
161 return fmt.Sprintf("%s/%s for %s", k.Kubernetes.Namespace, k.Kubernetes.ServiceAccount.Name, k.Audience)
162 }
163
164 func gettoken() (string, error) {
165 b, err := os.ReadFile(tokenPath)
166 return string(b), err
167 }
168
169 func withExternalOAuth2Client(ctx context.Context) (context.Context, error) {
170
171
172 return context.WithValue(ctx,
173
174
175 oauth2.HTTPClient, &http.Client{
176 Transport: http.DefaultTransport,
177 }), nil
178 }
179
180 func withInClusterOauth2Client(ctx context.Context) (context.Context, error) {
181
182 cfg, err := rest.InClusterConfig()
183 if err != nil {
184 return nil, err
185 }
186
187 rt, err := rest.TransportFor(cfg)
188 if err != nil {
189 return nil, fmt.Errorf("could not get roundtripper: %v", err)
190 }
191
192 return context.WithValue(ctx,
193
194
195 oauth2.HTTPClient, &http.Client{
196 Transport: rt,
197 }), nil
198 }
199
200
201
202
203
204
205 func ensureWindowsDNSAvailability(issuer string) error {
206 log.Println("Ensuring Windows DNS availability")
207
208 u, err := url.Parse(issuer)
209 if err != nil {
210 return err
211 }
212
213 return wait.PollImmediate(5*time.Second, 20*time.Second, func() (bool, error) {
214 ips, err := net.LookupHost(u.Host)
215 if err != nil {
216 log.Println(err)
217 return false, nil
218 }
219 log.Printf("OK: Resolved host %s: %v", u.Host, ips)
220 return true, nil
221 })
222 }
223
View as plain text