1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package jwt
16
17import (
18 "bytes"
19 "crypto"
20 "crypto/rand"
21 "crypto/rsa"
22 "crypto/sha256"
23 "encoding/base64"
24 "encoding/json"
25 "errors"
26 "fmt"
27 "strings"
28 "time"
29)
30
// Values recognized for the "alg" and "typ" fields of a JWT [Header].
const (
	// HeaderType is the standard [Header.Type].
	HeaderType = "JWT"

	// HeaderAlgRSA256 is the RS256 [Header.Algorithm]
	// (RSASSA-PKCS1-v1_5 with SHA-256).
	HeaderAlgRSA256 = "RS256"
	// HeaderAlgES256 is the ES256 [Header.Algorithm]
	// (ECDSA using P-256 and SHA-256).
	HeaderAlgES256 = "ES256"
)
39
// Header represents a JWT header.
//
// Fields are serialized in declaration order, and none of them is omitted
// when empty.
type Header struct {
	// Algorithm is the "alg" header parameter, e.g. HeaderAlgRSA256.
	Algorithm string `json:"alg"`
	// Type is the "typ" header parameter, normally HeaderType.
	Type string `json:"typ"`
	// KeyID is the "kid" header parameter; serialized even when empty.
	KeyID string `json:"kid"`
}

// encode returns the JSON form of h as unpadded base64url, i.e. the first
// segment of a compact JWS.
func (h *Header) encode() (string, error) {
	jsonHeader, err := json.Marshal(h)
	if err != nil {
		return "", err
	}
	encoding := base64.RawURLEncoding
	return encoding.EncodeToString(jsonHeader), nil
}
54
// Claims represents the claims set of a JWT.
//
// The standard claims are serialized in declaration order; the entries of
// AdditionalClaims are appended after them by encode.
type Claims struct {
	// Iss is the issuer ("iss") claim. Always serialized.
	Iss string `json:"iss"`
	// Scope is the "scope" claim. Omitted when empty.
	Scope string `json:"scope,omitempty"`
	// Exp is the expiry ("exp") claim in Unix seconds. If zero, encode sets it
	// to one hour after the current time minus a 10-second skew allowance.
	Exp int64 `json:"exp"`
	// Iat is the issued-at ("iat") claim in Unix seconds. If zero, encode sets
	// it to the current time minus a 10-second skew allowance.
	Iat int64 `json:"iat"`
	// Aud is the audience ("aud") claim. Optional, but serialized even when
	// empty.
	Aud string `json:"aud"`
	// Sub is the subject ("sub") claim. Optional; omitted when empty.
	Sub string `json:"sub,omitempty"`
	// AdditionalClaims holds non-standard claims. Optional. They are merged
	// into the top-level claims object by encode. Keys are not checked against
	// the standard claim names, so a collision produces duplicate JSON keys.
	AdditionalClaims map[string]interface{} `json:"-"`
}

// encode returns the claims set as unpadded base64url JSON, i.e. the second
// segment of a compact JWS.
//
// encode mutates c: a zero Iat or Exp is replaced by its default before
// serialization. It fails if Exp ends up earlier than Iat.
func (c *Claims) encode() (string, error) {
	// Back-date the defaults by 10s to compensate for clock skew (presumably a
	// verifier whose clock lags slightly behind this one).
	issued := time.Now().Add(-10 * time.Second)
	if c.Iat == 0 {
		c.Iat = issued.Unix()
	}
	if c.Exp == 0 {
		c.Exp = issued.Add(time.Hour).Unix()
	}
	if c.Iat > c.Exp {
		return "", fmt.Errorf("jwt: invalid Exp = %d; must be later than Iat = %d", c.Exp, c.Iat)
	}

	payload, err := json.Marshal(c)
	if err != nil {
		return "", err
	}

	if len(c.AdditionalClaims) > 0 {
		extra, merr := json.Marshal(c.AdditionalClaims)
		if merr != nil {
			return "", fmt.Errorf("invalid map of additional claims %v: %w", c.AdditionalClaims, merr)
		}
		switch {
		case !bytes.HasSuffix(payload, []byte{'}'}):
			return "", fmt.Errorf("invalid JSON %s", payload)
		case !bytes.HasPrefix(extra, []byte{'{'}):
			return "", fmt.Errorf("invalid JSON %s", extra)
		}
		// Splice the two objects into one: drop the standard object's closing
		// brace and the extra object's opening brace, joining with a comma.
		payload = append(payload[:len(payload)-1], ',')
		payload = append(payload, extra[1:]...)
	}

	return base64.RawURLEncoding.EncodeToString(payload), nil
}
112
113// EncodeJWS encodes the data using the provided key as a JSON web signature.
114func EncodeJWS(header *Header, c *Claims, signer crypto.Signer) (string, error) {
115 head, err := header.encode()
116 if err != nil {
117 return "", err
118 }
119 claims, err := c.encode()
120 if err != nil {
121 return "", err
122 }
123 ss := fmt.Sprintf("%s.%s", head, claims)
124 h := sha256.New()
125 h.Write([]byte(ss))
126 sig, err := signer.Sign(rand.Reader, h.Sum(nil), crypto.SHA256)
127 if err != nil {
128 return "", err
129 }
130 return fmt.Sprintf("%s.%s", ss, base64.RawURLEncoding.EncodeToString(sig)), nil
131}
132
133// DecodeJWS decodes a claim set from a JWS payload.
134func DecodeJWS(payload string) (*Claims, error) {
135 // decode returned id token to get expiry
136 s := strings.Split(payload, ".")
137 if len(s) < 2 {
138 return nil, errors.New("invalid token received")
139 }
140 decoded, err := base64.RawURLEncoding.DecodeString(s[1])
141 if err != nil {
142 return nil, err
143 }
144 c := &Claims{}
145 if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(c); err != nil {
146 return nil, err
147 }
148 if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(&c.AdditionalClaims); err != nil {
149 return nil, err
150 }
151 return c, err
152}
153
// VerifyJWS tests whether the provided JWT token's signature was produced by
// the private key associated with the provided public key.
//
// The token must have exactly three dot-separated segments. The third segment
// (unpadded base64url) is checked as an RSASSA-PKCS1-v1_5 signature over the
// SHA-256 digest of "<header>.<claims>". The header's "alg" field is not
// consulted. A nil error means the signature is valid. A nil key is reported
// as an error rather than causing a panic.
func VerifyJWS(token string, key *rsa.PublicKey) error {
	if key == nil {
		// rsa.VerifyPKCS1v15 dereferences the key and would panic.
		return errors.New("jwt: nil public key")
	}
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return errors.New("jwt: invalid token received, token must have 3 parts")
	}

	signature, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return err
	}

	// The signing input is the encoded header and payload exactly as they
	// appear in the token, joined by the separator.
	digest := sha256.Sum256([]byte(parts[0] + "." + parts[1]))
	return rsa.VerifyPKCS1v15(key, crypto.SHA256, digest[:], signature)
}