util.go

  1// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package websocket
  6
  7import (
  8	"crypto/rand"
  9	"crypto/sha1"
 10	"encoding/base64"
 11	"io"
 12	"net/http"
 13	"strings"
 14	"unicode/utf8"
 15)
 16
 17var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
 18
 19func computeAcceptKey(challengeKey string) string {
 20	h := sha1.New()
 21	h.Write([]byte(challengeKey))
 22	h.Write(keyGUID)
 23	return base64.StdEncoding.EncodeToString(h.Sum(nil))
 24}
 25
 26func generateChallengeKey() (string, error) {
 27	p := make([]byte, 16)
 28	if _, err := io.ReadFull(rand.Reader, p); err != nil {
 29		return "", err
 30	}
 31	return base64.StdEncoding.EncodeToString(p), nil
 32}
 33
 34// Token octets per RFC 2616.
 35var isTokenOctet = [256]bool{
 36	'!':  true,
 37	'#':  true,
 38	'$':  true,
 39	'%':  true,
 40	'&':  true,
 41	'\'': true,
 42	'*':  true,
 43	'+':  true,
 44	'-':  true,
 45	'.':  true,
 46	'0':  true,
 47	'1':  true,
 48	'2':  true,
 49	'3':  true,
 50	'4':  true,
 51	'5':  true,
 52	'6':  true,
 53	'7':  true,
 54	'8':  true,
 55	'9':  true,
 56	'A':  true,
 57	'B':  true,
 58	'C':  true,
 59	'D':  true,
 60	'E':  true,
 61	'F':  true,
 62	'G':  true,
 63	'H':  true,
 64	'I':  true,
 65	'J':  true,
 66	'K':  true,
 67	'L':  true,
 68	'M':  true,
 69	'N':  true,
 70	'O':  true,
 71	'P':  true,
 72	'Q':  true,
 73	'R':  true,
 74	'S':  true,
 75	'T':  true,
 76	'U':  true,
 77	'W':  true,
 78	'V':  true,
 79	'X':  true,
 80	'Y':  true,
 81	'Z':  true,
 82	'^':  true,
 83	'_':  true,
 84	'`':  true,
 85	'a':  true,
 86	'b':  true,
 87	'c':  true,
 88	'd':  true,
 89	'e':  true,
 90	'f':  true,
 91	'g':  true,
 92	'h':  true,
 93	'i':  true,
 94	'j':  true,
 95	'k':  true,
 96	'l':  true,
 97	'm':  true,
 98	'n':  true,
 99	'o':  true,
100	'p':  true,
101	'q':  true,
102	'r':  true,
103	's':  true,
104	't':  true,
105	'u':  true,
106	'v':  true,
107	'w':  true,
108	'x':  true,
109	'y':  true,
110	'z':  true,
111	'|':  true,
112	'~':  true,
113}
114
// skipSpace returns s with any leading spaces and horizontal tabs removed.
// Only SP and HT are skipped; CR/LF (RFC 2616 line folding) are left in place.
func skipSpace(s string) (rest string) {
	return strings.TrimLeft(s, " \t")
}
126
127// nextToken returns the leading RFC 2616 token of s and the string following
128// the token.
129func nextToken(s string) (token, rest string) {
130	i := 0
131	for ; i < len(s); i++ {
132		if !isTokenOctet[s[i]] {
133			break
134		}
135	}
136	return s[:i], s[i:]
137}
138
139// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616
140// and the string following the token or quoted string.
141func nextTokenOrQuoted(s string) (value string, rest string) {
142	if !strings.HasPrefix(s, "\"") {
143		return nextToken(s)
144	}
145	s = s[1:]
146	for i := 0; i < len(s); i++ {
147		switch s[i] {
148		case '"':
149			return s[:i], s[i+1:]
150		case '\\':
151			p := make([]byte, len(s)-1)
152			j := copy(p, s[:i])
153			escape := true
154			for i = i + 1; i < len(s); i++ {
155				b := s[i]
156				switch {
157				case escape:
158					escape = false
159					p[j] = b
160					j++
161				case b == '\\':
162					escape = true
163				case b == '"':
164					return string(p[:j]), s[i+1:]
165				default:
166					p[j] = b
167					j++
168				}
169			}
170			return "", ""
171		}
172	}
173	return "", ""
174}
175
// equalASCIIFold reports whether s and t are equal under ASCII case folding
// (the "i;ascii-casemap" comparator of RFC 4790). Only the letters A-Z are
// folded, to a-z. Every other byte, including all non-ASCII bytes and any
// invalid UTF-8, must match exactly.
//
// The comparison is byte-wise. For valid UTF-8 this gives the same result as
// comparing rune by rune, because multi-byte sequences never contain ASCII
// bytes. Unlike a rune-decoding loop, it does not treat distinct invalid bytes
// (all decoded as U+FFFD) as equal to each other or to a literal U+FFFD.
func equalASCIIFold(s, t string) bool {
	if len(s) != len(t) {
		return false
	}
	for i := 0; i < len(s); i++ {
		a, b := s[i], t[i]
		if a == b {
			continue
		}
		// Bytes at or above utf8.RuneSelf are part of a multi-byte sequence
		// (or invalid); they have no ASCII case, so any difference is final.
		if a >= utf8.RuneSelf || b >= utf8.RuneSelf {
			return false
		}
		if 'A' <= a && a <= 'Z' {
			a += 'a' - 'A'
		}
		if 'A' <= b && b <= 'Z' {
			b += 'a' - 'A'
		}
		if a != b {
			return false
		}
	}
	return true
}
199
200// tokenListContainsValue returns true if the 1#token header with the given
201// name contains a token equal to value with ASCII case folding.
202func tokenListContainsValue(header http.Header, name string, value string) bool {
203headers:
204	for _, s := range header[name] {
205		for {
206			var t string
207			t, s = nextToken(skipSpace(s))
208			if t == "" {
209				continue headers
210			}
211			s = skipSpace(s)
212			if s != "" && s[0] != ',' {
213				continue headers
214			}
215			if equalASCIIFold(t, value) {
216				return true
217			}
218			if s == "" {
219				continue headers
220			}
221			s = s[1:]
222		}
223	}
224	return false
225}
226
// parseExtensions parses WebSocket extensions from a header.
//
// It reads every value stored under the "Sec-Websocket-Extensions" key of
// header. The header is indexed directly, so the key is in canonical MIME
// form. It returns one map per successfully parsed extension, in order of
// appearance. In each map the empty key "" holds the extension token. Every
// other key is a parameter name mapped to its value; a parameter written
// without "=" maps to "". If a parameter name repeats within one extension,
// the last occurrence wins.
//
// Parsing of a single header value stops at its first malformed element.
// Extensions already completed from that value are kept, the one being
// parsed is discarded, and parsing continues with the next header value.
// The result is nil when no extension could be parsed.
func parseExtensions(header http.Header) []map[string]string {
	// From RFC 6455:
	//
	//  Sec-WebSocket-Extensions = extension-list
	//  extension-list = 1#extension
	//  extension = extension-token *( ";" extension-param )
	//  extension-token = registered-token
	//  registered-token = token
	//  extension-param = token [ "=" (token | quoted-string) ]
	//     ;When using the quoted-string syntax variant, the value
	//     ;after quoted-string unescaping MUST conform to the
	//     ;'token' ABNF.
	//
	// NOTE(review): the token-conformance rule for unescaped quoted-string
	// values (quoted above) is not enforced below.

	var result []map[string]string
headers:
	for _, s := range header["Sec-Websocket-Extensions"] {
		for {
			var t string
			// Extension name; an empty one abandons this header value.
			t, s = nextToken(skipSpace(s))
			if t == "" {
				continue headers
			}
			ext := map[string]string{"": t}
			// Zero or more `; name [= value]` parameters.
			for {
				s = skipSpace(s)
				if !strings.HasPrefix(s, ";") {
					break
				}
				var k string
				k, s = nextToken(skipSpace(s[1:]))
				if k == "" {
					continue headers
				}
				s = skipSpace(s)
				var v string
				if strings.HasPrefix(s, "=") {
					// NOTE(review): an empty value after "=", or an
					// unterminated quoted string (nextTokenOrQuoted returns
					// "", ""), appears to be accepted here as "". Confirm
					// whether that leniency is intended.
					v, s = nextTokenOrQuoted(skipSpace(s[1:]))
					s = skipSpace(s)
				}
				// A parameter must be followed by the end of the value,
				// another parameter, or the next extension.
				if s != "" && s[0] != ',' && s[0] != ';' {
					continue headers
				}
				ext[k] = v
			}
			// The extension must end the value or be followed by ",".
			if s != "" && s[0] != ',' {
				continue headers
			}
			result = append(result, ext)
			if s == "" {
				continue headers
			}
			s = s[1:] // consume the "," before the next extension
		}
	}
	return result
}
284
// isValidChallengeKey reports whether s is an acceptable Sec-WebSocket-Key
// value.
//
// From RFC 6455:
//
//	A |Sec-WebSocket-Key| header field with a base64-encoded (see
//	Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
//	length.
//
// The padded standard base64 encoding of 16 bytes is always exactly 24
// characters, so the length is checked before decoding. This rejects
// oversized input without decoding it. It also rejects values with embedded
// CR/LF, which base64.StdEncoding would otherwise silently skip.
func isValidChallengeKey(s string) bool {
	const keyLen = 16
	if len(s) != base64.StdEncoding.EncodedLen(keyLen) {
		return false
	}
	decoded, err := base64.StdEncoding.DecodeString(s)
	return err == nil && len(decoded) == keyLen
}