1 package edgeencrypt
2
3 import (
4 "crypto/rsa"
5 "fmt"
6 "slices"
7 "time"
8
9 "github.com/golang-jwt/jwt"
10 )
11
// Token is the JSON envelope that carries a bearer token string together
// with a version string.
type Token struct {
	// BearerToken is serialized under the JSON key "token".
	BearerToken string `json:"token"`
	// Version is serialized under the JSON key "version".
	Version string `json:"version"`
}
16
// Durations used for token lifetimes.
const (
	// Week is seven 24-hour days. (It was previously 8760 hours, which is
	// the length of a 365-day year and made Year roughly 52 years long.)
	Week = 7 * 24 * time.Hour
	// Year is 52 weeks, i.e. 364 days.
	Year = 52 * Week
	// DefaultDuration is the lifetime CreateToken applies when it is called
	// with a zero duration.
	DefaultDuration = Year
)
22
// Role names the operation a token holder is allowed to perform.
type Role string

// The roles this package recognizes.
const (
	Encryption Role = "encryption"
	Decryption Role = "decryption"
)

// AllRoles lists every role that Role.Valid accepts.
var AllRoles = []Role{Encryption, Decryption}

// Valid reports whether r is non-empty and present in AllRoles.
func (r Role) Valid() bool {
	return r != "" && slices.Contains(AllRoles, r)
}
38
39 type EncryptionClaims struct {
40 jwt.StandardClaims
41 BannerEdgeID string `json:"bannerEdgeID,omitempty"`
42 ChannelID string `json:"channelId"`
43 Channel string `json:"channel"`
44 Role Role `json:"role"`
45 }
46
47 func (c *EncryptionClaims) Valid() error {
48 if len(c.ChannelID) == 0 {
49 return fmt.Errorf("channelID is required")
50 }
51 if len(c.Channel) == 0 {
52 return fmt.Errorf("channel is required")
53 }
54 if !c.Role.Valid() {
55 return fmt.Errorf("invalid role: %v", c.Role)
56 }
57 if c.Role == Encryption && len(c.BannerEdgeID) == 0 {
58 return fmt.Errorf("invalid bannerEdgeID is required for encryption role")
59 }
60 return c.StandardClaims.Valid()
61 }
62
63 func (c *EncryptionClaims) ValidChannel(channelID string) bool {
64 return c.ChannelID == channelID
65 }
66
67 func (c *EncryptionClaims) ValidChannelName(channelName string) bool {
68 return c.Channel == channelName
69 }
70
71 func (c *EncryptionClaims) HasRole(role Role) bool {
72 return c.Role == role
73 }
74
75
76 func CreateToken(method jwt.SigningMethod, key interface{}, duration time.Duration, channelID, channelName string, role Role, banner ...string) (string, error) {
77 if method == nil {
78 return "", fmt.Errorf("signing method is required to create token")
79 }
80 if !validSigningMethod(method) {
81 return "", fmt.Errorf("unsupported signing method: %T, must be RSA256", method)
82 }
83 if key == nil {
84 return "", fmt.Errorf("signing key is required to create token")
85 }
86 if duration == 0 {
87 duration = DefaultDuration
88 }
89 if len(channelID) == 0 {
90 return "", fmt.Errorf("channelID is required")
91 }
92 if len(channelName) == 0 {
93 return "", fmt.Errorf("channelName is required")
94 }
95 if !role.Valid() {
96 return "", fmt.Errorf("invalid role: %v", role)
97 }
98 var bannerID string
99 if role == Encryption {
100 if len(banner) == 0 {
101 return "", fmt.Errorf("bannerEdgeID is required for encryption role")
102 }
103 bannerID = banner[0]
104 }
105 c := &EncryptionClaims{
106 BannerEdgeID: bannerID,
107 ChannelID: channelID,
108 Channel: channelName,
109 Role: role,
110 StandardClaims: jwt.StandardClaims{
111 ExpiresAt: time.Now().Add(duration).Unix(),
112 },
113 }
114
115 claims := jwt.NewWithClaims(method, c)
116 token, err := claims.SignedString(key)
117 if err != nil {
118 return "", fmt.Errorf("failed to sign token: %v", err)
119 }
120 return token, nil
121 }
122
123
124 func FromToken(publicKey *rsa.PublicKey, bearerToken string) (*EncryptionClaims, error) {
125 c := &EncryptionClaims{}
126 _, err := jwt.ParseWithClaims(bearerToken, c,
127 func(token *jwt.Token) (interface{}, error) {
128 if !validSigningMethod(token.Method) {
129 return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
130 }
131 return publicKey, nil
132 })
133 if err != nil {
134 return nil, fmt.Errorf("failed to parse token: %v", err)
135 }
136 return c, nil
137 }
138
139 func validSigningMethod(method jwt.SigningMethod) bool {
140 if method == nil {
141 return false
142 }
143 _, isRSA := method.(*jwt.SigningMethodRSA)
144 _, isKMS := method.(*SigningMethodKMS)
145 return isRSA || isKMS
146 }
147
View as plain text