1package ssocreds
  2
  3import (
  4	"crypto/sha1"
  5	"encoding/hex"
  6	"encoding/json"
  7	"fmt"
  8	"io/ioutil"
  9	"os"
 10	"path/filepath"
 11	"strconv"
 12	"strings"
 13	"time"
 14
 15	"github.com/aws/aws-sdk-go-v2/internal/sdk"
 16	"github.com/aws/aws-sdk-go-v2/internal/shareddefaults"
 17)
 18
// osUserHomeDur returns the user's home directory; StandardCachedTokenFilepath
// treats an empty result as "home directory unknown". It is held in a package
// variable rather than called directly, presumably so tests can substitute it
// — confirm against callers before renaming.
var osUserHomeDur = shareddefaults.UserHomeDir
 20
 21// StandardCachedTokenFilepath returns the filepath for the cached SSO token file, or
 22// error if unable get derive the path. Key that will be used to compute a SHA1
 23// value that is hex encoded.
 24//
 25// Derives the filepath using the Key as:
 26//
 27//	~/.aws/sso/cache/<sha1-hex-encoded-key>.json
 28func StandardCachedTokenFilepath(key string) (string, error) {
 29	homeDir := osUserHomeDur()
 30	if len(homeDir) == 0 {
 31		return "", fmt.Errorf("unable to get USER's home directory for cached token")
 32	}
 33	hash := sha1.New()
 34	if _, err := hash.Write([]byte(key)); err != nil {
 35		return "", fmt.Errorf("unable to compute cached token filepath key SHA1 hash, %w", err)
 36	}
 37
 38	cacheFilename := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json"
 39
 40	return filepath.Join(homeDir, ".aws", "sso", "cache", cacheFilename), nil
 41}
 42
// tokenKnownFields holds the cached SSO token attributes this package reads
// and writes by name. The json tags name the on-disk key of each field; the
// token type's MarshalJSON/UnmarshalJSON map these same keys explicitly.
type tokenKnownFields struct {
	// AccessToken and ExpiresAt are required by loadCachedToken; a nil
	// ExpiresAt means the key was absent.
	AccessToken string   `json:"accessToken,omitempty"`
	ExpiresAt   *rfc3339 `json:"expiresAt,omitempty"`

	// Optional fields; an empty string means the key was absent.
	RefreshToken string `json:"refreshToken,omitempty"`
	ClientID     string `json:"clientId,omitempty"`
	ClientSecret string `json:"clientSecret,omitempty"`
}
 51
// token is the in-memory form of a cached SSO token file. JSON keys that are
// not one of the known fields are kept in UnknownFields by UnmarshalJSON and
// written back by MarshalJSON, so they survive a load/store round trip.
type token struct {
	tokenKnownFields
	// UnknownFields holds every JSON key not mapped to tokenKnownFields. It is
	// excluded from default encoding (json:"-") and merged in by MarshalJSON.
	UnknownFields map[string]interface{} `json:"-"`
}
 56
 57func (t token) MarshalJSON() ([]byte, error) {
 58	fields := map[string]interface{}{}
 59
 60	setTokenFieldString(fields, "accessToken", t.AccessToken)
 61	setTokenFieldRFC3339(fields, "expiresAt", t.ExpiresAt)
 62
 63	setTokenFieldString(fields, "refreshToken", t.RefreshToken)
 64	setTokenFieldString(fields, "clientId", t.ClientID)
 65	setTokenFieldString(fields, "clientSecret", t.ClientSecret)
 66
 67	for k, v := range t.UnknownFields {
 68		if _, ok := fields[k]; ok {
 69			return nil, fmt.Errorf("unknown token field %v, duplicates known field", k)
 70		}
 71		fields[k] = v
 72	}
 73
 74	return json.Marshal(fields)
 75}
 76
 77func setTokenFieldString(fields map[string]interface{}, key, value string) {
 78	if value == "" {
 79		return
 80	}
 81	fields[key] = value
 82}
 83func setTokenFieldRFC3339(fields map[string]interface{}, key string, value *rfc3339) {
 84	if value == nil {
 85		return
 86	}
 87	fields[key] = value
 88}
 89
 90func (t *token) UnmarshalJSON(b []byte) error {
 91	var fields map[string]interface{}
 92	if err := json.Unmarshal(b, &fields); err != nil {
 93		return nil
 94	}
 95
 96	t.UnknownFields = map[string]interface{}{}
 97
 98	for k, v := range fields {
 99		var err error
100		switch k {
101		case "accessToken":
102			err = getTokenFieldString(v, &t.AccessToken)
103		case "expiresAt":
104			err = getTokenFieldRFC3339(v, &t.ExpiresAt)
105		case "refreshToken":
106			err = getTokenFieldString(v, &t.RefreshToken)
107		case "clientId":
108			err = getTokenFieldString(v, &t.ClientID)
109		case "clientSecret":
110			err = getTokenFieldString(v, &t.ClientSecret)
111		default:
112			t.UnknownFields[k] = v
113		}
114
115		if err != nil {
116			return fmt.Errorf("field %q, %w", k, err)
117		}
118	}
119
120	return nil
121}
122
// getTokenFieldString stores v into *value when v is a string. Otherwise it
// resets *value to "" and returns an error naming v's dynamic type.
func getTokenFieldString(v interface{}, value *string) error {
	switch s := v.(type) {
	case string:
		*value = s
		return nil
	default:
		*value = ""
		return fmt.Errorf("expect value to be string, got %T", v)
	}
}
131
132func getTokenFieldRFC3339(v interface{}, value **rfc3339) error {
133	var stringValue string
134	if err := getTokenFieldString(v, &stringValue); err != nil {
135		return err
136	}
137
138	timeValue, err := parseRFC3339(stringValue)
139	if err != nil {
140		return err
141	}
142
143	*value = &timeValue
144	return nil
145}
146
147func loadCachedToken(filename string) (token, error) {
148	fileBytes, err := ioutil.ReadFile(filename)
149	if err != nil {
150		return token{}, fmt.Errorf("failed to read cached SSO token file, %w", err)
151	}
152
153	var t token
154	if err := json.Unmarshal(fileBytes, &t); err != nil {
155		return token{}, fmt.Errorf("failed to parse cached SSO token file, %w", err)
156	}
157
158	if len(t.AccessToken) == 0 || t.ExpiresAt == nil || time.Time(*t.ExpiresAt).IsZero() {
159		return token{}, fmt.Errorf(
160			"cached SSO token must contain accessToken and expiresAt fields")
161	}
162
163	return t, nil
164}
165
166func storeCachedToken(filename string, t token, fileMode os.FileMode) (err error) {
167	tmpFilename := filename + ".tmp-" + strconv.FormatInt(sdk.NowTime().UnixNano(), 10)
168	if err := writeCacheFile(tmpFilename, fileMode, t); err != nil {
169		return err
170	}
171
172	if err := os.Rename(tmpFilename, filename); err != nil {
173		return fmt.Errorf("failed to replace old cached SSO token file, %w", err)
174	}
175
176	return nil
177}
178
179func writeCacheFile(filename string, fileMode os.FileMode, t token) (err error) {
180	var f *os.File
181	f, err = os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_RDWR, fileMode)
182	if err != nil {
183		return fmt.Errorf("failed to create cached SSO token file %w", err)
184	}
185
186	defer func() {
187		closeErr := f.Close()
188		if err == nil && closeErr != nil {
189			err = fmt.Errorf("failed to close cached SSO token file, %w", closeErr)
190		}
191	}()
192
193	encoder := json.NewEncoder(f)
194
195	if err = encoder.Encode(t); err != nil {
196		return fmt.Errorf("failed to serialize cached SSO token, %w", err)
197	}
198
199	return nil
200}
201
// rfc3339 is a time.Time that is (un)marshaled as a JSON string using the
// time.RFC3339 layout; see parseRFC3339 and its MarshalJSON/UnmarshalJSON
// methods.
type rfc3339 time.Time
203
204func parseRFC3339(v string) (rfc3339, error) {
205	parsed, err := time.Parse(time.RFC3339, v)
206	if err != nil {
207		return rfc3339{}, fmt.Errorf("expected RFC3339 timestamp: %w", err)
208	}
209
210	return rfc3339(parsed), nil
211}
212
213func (r *rfc3339) UnmarshalJSON(bytes []byte) (err error) {
214	var value string
215
216	// Use JSON unmarshal to unescape the quoted value making use of JSON's
217	// unquoting rules.
218	if err = json.Unmarshal(bytes, &value); err != nil {
219		return err
220	}
221
222	*r, err = parseRFC3339(value)
223
224	return nil
225}
226
227func (r *rfc3339) MarshalJSON() ([]byte, error) {
228	value := time.Time(*r).Format(time.RFC3339)
229
230	// Use JSON unmarshal to unescape the quoted value making use of JSON's
231	// quoting rules.
232	return json.Marshal(value)
233}