server.go

  1// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package websocket
  6
  7import (
  8	"bufio"
  9	"errors"
 10	"io"
 11	"net/http"
 12	"net/url"
 13	"strings"
 14	"time"
 15)
 16
 17// HandshakeError describes an error with the handshake from the peer.
 18type HandshakeError struct {
 19	message string
 20}
 21
 22func (e HandshakeError) Error() string { return e.message }
 23
 24// Upgrader specifies parameters for upgrading an HTTP connection to a
 25// WebSocket connection.
 26//
 27// It is safe to call Upgrader's methods concurrently.
 28type Upgrader struct {
 29	// HandshakeTimeout specifies the duration for the handshake to complete.
 30	HandshakeTimeout time.Duration
 31
 32	// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
 33	// size is zero, then buffers allocated by the HTTP server are used. The
 34	// I/O buffer sizes do not limit the size of the messages that can be sent
 35	// or received.
 36	ReadBufferSize, WriteBufferSize int
 37
 38	// WriteBufferPool is a pool of buffers for write operations. If the value
 39	// is not set, then write buffers are allocated to the connection for the
 40	// lifetime of the connection.
 41	//
 42	// A pool is most useful when the application has a modest volume of writes
 43	// across a large number of connections.
 44	//
 45	// Applications should use a single pool for each unique value of
 46	// WriteBufferSize.
 47	WriteBufferPool BufferPool
 48
 49	// Subprotocols specifies the server's supported protocols in order of
 50	// preference. If this field is not nil, then the Upgrade method negotiates a
 51	// subprotocol by selecting the first match in this list with a protocol
 52	// requested by the client. If there's no match, then no protocol is
 53	// negotiated (the Sec-Websocket-Protocol header is not included in the
 54	// handshake response).
 55	Subprotocols []string
 56
 57	// Error specifies the function for generating HTTP error responses. If Error
 58	// is nil, then http.Error is used to generate the HTTP response.
 59	Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
 60
 61	// CheckOrigin returns true if the request Origin header is acceptable. If
 62	// CheckOrigin is nil, then a safe default is used: return false if the
 63	// Origin request header is present and the origin host is not equal to
 64	// request Host header.
 65	//
 66	// A CheckOrigin function should carefully validate the request origin to
 67	// prevent cross-site request forgery.
 68	CheckOrigin func(r *http.Request) bool
 69
 70	// EnableCompression specify if the server should attempt to negotiate per
 71	// message compression (RFC 7692). Setting this value to true does not
 72	// guarantee that compression will be supported. Currently only "no context
 73	// takeover" modes are supported.
 74	EnableCompression bool
 75}
 76
 77func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
 78	err := HandshakeError{reason}
 79	if u.Error != nil {
 80		u.Error(w, r, status, err)
 81	} else {
 82		w.Header().Set("Sec-Websocket-Version", "13")
 83		http.Error(w, http.StatusText(status), status)
 84	}
 85	return nil, err
 86}
 87
 88// checkSameOrigin returns true if the origin is not set or is equal to the request host.
 89func checkSameOrigin(r *http.Request) bool {
 90	origin := r.Header["Origin"]
 91	if len(origin) == 0 {
 92		return true
 93	}
 94	u, err := url.Parse(origin[0])
 95	if err != nil {
 96		return false
 97	}
 98	return equalASCIIFold(u.Host, r.Host)
 99}
100
101func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
102	if u.Subprotocols != nil {
103		clientProtocols := Subprotocols(r)
104		for _, serverProtocol := range u.Subprotocols {
105			for _, clientProtocol := range clientProtocols {
106				if clientProtocol == serverProtocol {
107					return clientProtocol
108				}
109			}
110		}
111	} else if responseHeader != nil {
112		return responseHeader.Get("Sec-Websocket-Protocol")
113	}
114	return ""
115}
116
117// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
118//
119// The responseHeader is included in the response to the client's upgrade
120// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
121// subprotocols supported by the server, set Upgrader.Subprotocols directly.
122//
123// If the upgrade fails, then Upgrade replies to the client with an HTTP error
124// response.
125func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
126	const badHandshake = "websocket: the client is not using the websocket protocol: "
127
128	if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
129		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
130	}
131
132	if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
133		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
134	}
135
136	if r.Method != http.MethodGet {
137		return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
138	}
139
140	if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
141		return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
142	}
143
144	if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
145		return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
146	}
147
148	checkOrigin := u.CheckOrigin
149	if checkOrigin == nil {
150		checkOrigin = checkSameOrigin
151	}
152	if !checkOrigin(r) {
153		return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
154	}
155
156	challengeKey := r.Header.Get("Sec-Websocket-Key")
157	if !isValidChallengeKey(challengeKey) {
158		return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
159	}
160
161	subprotocol := u.selectSubprotocol(r, responseHeader)
162
163	// Negotiate PMCE
164	var compress bool
165	if u.EnableCompression {
166		for _, ext := range parseExtensions(r.Header) {
167			if ext[""] != "permessage-deflate" {
168				continue
169			}
170			compress = true
171			break
172		}
173	}
174
175	h, ok := w.(http.Hijacker)
176	if !ok {
177		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
178	}
179	var brw *bufio.ReadWriter
180	netConn, brw, err := h.Hijack()
181	if err != nil {
182		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
183	}
184
185	if brw.Reader.Buffered() > 0 {
186		netConn.Close()
187		return nil, errors.New("websocket: client sent data before handshake is complete")
188	}
189
190	var br *bufio.Reader
191	if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
192		// Reuse hijacked buffered reader as connection reader.
193		br = brw.Reader
194	}
195
196	buf := bufioWriterBuffer(netConn, brw.Writer)
197
198	var writeBuf []byte
199	if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
200		// Reuse hijacked write buffer as connection buffer.
201		writeBuf = buf
202	}
203
204	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
205	c.subprotocol = subprotocol
206
207	if compress {
208		c.newCompressionWriter = compressNoContextTakeover
209		c.newDecompressionReader = decompressNoContextTakeover
210	}
211
212	// Use larger of hijacked buffer and connection write buffer for header.
213	p := buf
214	if len(c.writeBuf) > len(p) {
215		p = c.writeBuf
216	}
217	p = p[:0]
218
219	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
220	p = append(p, computeAcceptKey(challengeKey)...)
221	p = append(p, "\r\n"...)
222	if c.subprotocol != "" {
223		p = append(p, "Sec-WebSocket-Protocol: "...)
224		p = append(p, c.subprotocol...)
225		p = append(p, "\r\n"...)
226	}
227	if compress {
228		p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
229	}
230	for k, vs := range responseHeader {
231		if k == "Sec-Websocket-Protocol" {
232			continue
233		}
234		for _, v := range vs {
235			p = append(p, k...)
236			p = append(p, ": "...)
237			for i := 0; i < len(v); i++ {
238				b := v[i]
239				if b <= 31 {
240					// prevent response splitting.
241					b = ' '
242				}
243				p = append(p, b)
244			}
245			p = append(p, "\r\n"...)
246		}
247	}
248	p = append(p, "\r\n"...)
249
250	// Clear deadlines set by HTTP server.
251	netConn.SetDeadline(time.Time{})
252
253	if u.HandshakeTimeout > 0 {
254		netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
255	}
256	if _, err = netConn.Write(p); err != nil {
257		netConn.Close()
258		return nil, err
259	}
260	if u.HandshakeTimeout > 0 {
261		netConn.SetWriteDeadline(time.Time{})
262	}
263
264	return c, nil
265}
266
267// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
268//
269// Deprecated: Use websocket.Upgrader instead.
270//
271// Upgrade does not perform origin checking. The application is responsible for
272// checking the Origin header before calling Upgrade. An example implementation
273// of the same origin policy check is:
274//
275//	if req.Header.Get("Origin") != "http://"+req.Host {
276//		http.Error(w, "Origin not allowed", http.StatusForbidden)
277//		return
278//	}
279//
280// If the endpoint supports subprotocols, then the application is responsible
281// for negotiating the protocol used on the connection. Use the Subprotocols()
282// function to get the subprotocols requested by the client. Use the
283// Sec-Websocket-Protocol response header to specify the subprotocol selected
284// by the application.
285//
286// The responseHeader is included in the response to the client's upgrade
287// request. Use the responseHeader to specify cookies (Set-Cookie) and the
288// negotiated subprotocol (Sec-Websocket-Protocol).
289//
290// The connection buffers IO to the underlying network connection. The
291// readBufSize and writeBufSize parameters specify the size of the buffers to
292// use. Messages can be larger than the buffers.
293//
294// If the request is not a valid WebSocket handshake, then Upgrade returns an
295// error of type HandshakeError. Applications should handle this error by
296// replying to the client with an HTTP error response.
297func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
298	u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
299	u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
300		// don't return errors to maintain backwards compatibility
301	}
302	u.CheckOrigin = func(r *http.Request) bool {
303		// allow all connections by default
304		return true
305	}
306	return u.Upgrade(w, r, responseHeader)
307}
308
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
//
// RFC 6455 (section 11.3.4) allows this header to appear multiple times in a
// request, with the same meaning as a single header holding all the values.
// The protocols from every occurrence are therefore returned, in order.
// Each comma-separated entry is trimmed of surrounding whitespace. The result
// is nil when the header is absent or contains only whitespace.
func Subprotocols(r *http.Request) []string {
	var protocols []string
	for _, line := range r.Header["Sec-Websocket-Protocol"] {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		for _, p := range strings.Split(line, ",") {
			protocols = append(protocols, strings.TrimSpace(p))
		}
	}
	return protocols
}
322
323// IsWebSocketUpgrade returns true if the client requested upgrade to the
324// WebSocket protocol.
325func IsWebSocketUpgrade(r *http.Request) bool {
326	return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
327		tokenListContainsValue(r.Header, "Upgrade", "websocket")
328}
329
// bufioReaderSize returns the size of br's internal buffer, or 0 if it cannot
// be determined. As a side effect, br is reset to read from originalReader,
// which discards any data it had buffered.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
	// Relies on Peek(0) on a freshly reset reader returning
	// bufio.Reader.buf[:0], whose capacity equals the buffer size.
	// TODO: Use bufio.Reader.Size() after Go 1.10
	br.Reset(originalReader)
	view, err := br.Peek(0)
	if err != nil {
		return 0
	}
	return cap(view)
}
341
// writeHook is an io.Writer that records the most recent slice passed to its
// Write method. It does not copy the data and never reports an error.
type writeHook struct {
	p []byte
}

// Write retains data itself (aliasing the caller's slice) and reports every
// byte as written.
func (hook *writeHook) Write(data []byte) (int, error) {
	hook.p = data
	return len(data), nil
}
352
353// bufioWriterBuffer grabs the buffer from a bufio.Writer.
354func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
355	// This code assumes that bufio.Writer.buf[:1] is passed to the
356	// bufio.Writer's underlying writer.
357	var wh writeHook
358	bw.Reset(&wh)
359	bw.WriteByte(0)
360	bw.Flush()
361
362	bw.Reset(originalWriter)
363
364	return wh.p[:cap(wh.p)]
365}