1package transport
2
import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"net"
	"net/url"
	"strings"
)
10
// GenerateRandomString returns a cryptographically random string of exactly
// length characters drawn from the URL-safe base64 alphabet
// (A-Z, a-z, 0-9, '-', '_'), so each character carries 6 bits of entropy.
// A length of 0 yields "". It returns an error if length is negative or if
// the system's secure random source fails.
func GenerateRandomString(length int) (string, error) {
	if length < 0 {
		return "", fmt.Errorf("random string length must be non-negative, got %d", length)
	}
	// Each base64 character encodes 6 bits, so ceil(length*6/8) random bytes
	// are enough to fill the first `length` output characters entirely from
	// random bits; any extra encoded characters are discarded by the slice.
	buf := make([]byte, (length*6+7)/8)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(buf)[:length], nil
}
19
20// GenerateCodeVerifier generates a code verifier for PKCE
21func GenerateCodeVerifier() (string, error) {
22 // According to RFC 7636, the code verifier should be between 43 and 128 characters
23 return GenerateRandomString(64)
24}
25
// GenerateCodeChallenge derives the PKCE "S256" code challenge for
// codeVerifier: the SHA-256 digest of the verifier, encoded as unpadded
// base64url (RFC 7636 §4.2).
func GenerateCodeChallenge(codeVerifier string) string {
	hasher := sha256.New()
	// hash.Hash.Write is documented never to return an error.
	hasher.Write([]byte(codeVerifier))
	return base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
}
33
34// GenerateState generates a state parameter for OAuth
35func GenerateState() (string, error) {
36 return GenerateRandomString(32)
37}
38
// ValidateRedirectURI checks that redirectURI is an acceptable OAuth redirect
// URI. It returns nil only when the URI is either
//   - an HTTPS URL that names a host, or
//   - a plain-HTTP URL whose host is loopback ("localhost" in any letter
//     case, or a loopback IP such as 127.0.0.1 or ::1).
//
// It returns an error for an empty string, a string url.Parse rejects, a URI
// containing a fragment ('#'), or any other scheme.
func ValidateRedirectURI(redirectURI string) error {
	if redirectURI == "" {
		return fmt.Errorf("redirect URI cannot be empty")
	}

	parsedURL, err := url.Parse(redirectURI)
	if err != nil {
		return fmt.Errorf("invalid redirect URI: %w", err)
	}

	// RFC 6749 §3.1.2: the redirection endpoint URI MUST NOT include a
	// fragment. url.Parse splits on the first '#', so its presence anywhere
	// (even "…#" with an empty fragment) means a fragment component exists.
	if strings.ContainsRune(redirectURI, '#') {
		return fmt.Errorf("redirect URI must not contain a fragment")
	}

	switch parsedURL.Scheme { // url.Parse lower-cases the scheme
	case "http":
		// Plain HTTP is only tolerated for loopback redirects. Hostname()
		// strips IPv6 brackets and the port, so "[::1]:8080" yields "::1".
		host := parsedURL.Hostname()
		if strings.EqualFold(host, "localhost") {
			return nil
		}
		if ip := net.ParseIP(host); ip != nil && ip.IsLoopback() {
			return nil
		}
		return fmt.Errorf("HTTP redirect URI must use localhost or 127.0.0.1")
	case "https":
		// url.Parse accepts hostless forms such as "https:foo" and
		// "https:///cb"; those are not usable redirect targets.
		if parsedURL.Hostname() == "" {
			return fmt.Errorf("HTTPS redirect URI must include a host")
		}
		return nil
	default:
		return fmt.Errorf("redirect URI must use either HTTP with localhost or HTTPS")
	}
}