1// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package websocket
6
7import (
8 "bufio"
9 "errors"
10 "io"
11 "net/http"
12 "net/url"
13 "strings"
14 "time"
15)
16
// HandshakeError describes a failure in the opening handshake with the peer.
// The message supplied at construction is reported verbatim by Error.
type HandshakeError struct {
	message string
}

// Error implements the error interface by returning the stored message.
func (e HandshakeError) Error() string {
	return e.message
}
23
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
//
// It is safe to call Upgrader's methods concurrently.
type Upgrader struct {
	// HandshakeTimeout specifies the duration for the handshake to complete.
	// Upgrade applies it as the write deadline for sending the handshake
	// response; a zero or negative value sets no deadline.
	HandshakeTimeout time.Duration

	// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a
	// buffer size is zero, then the buffers allocated by the HTTP server are
	// reused when they are large enough (the server's write buffer is only
	// reused when WriteBufferPool is nil). The I/O buffer sizes do not limit
	// the size of the messages that can be sent or received.
	ReadBufferSize, WriteBufferSize int

	// WriteBufferPool is a pool of buffers for write operations. If the value
	// is not set, then write buffers are allocated to the connection for the
	// lifetime of the connection.
	//
	// A pool is most useful when the application has a modest volume of writes
	// across a large number of connections.
	//
	// Applications should use a single pool for each unique value of
	// WriteBufferSize.
	WriteBufferPool BufferPool

	// Subprotocols specifies the server's supported protocols in order of
	// preference. If this field is not nil, then the Upgrade method negotiates a
	// subprotocol by selecting the first match in this list with a protocol
	// requested by the client. If there's no match, then no protocol is
	// negotiated (the Sec-Websocket-Protocol header is not included in the
	// handshake response). If this field is nil, the Sec-Websocket-Protocol
	// value of the responseHeader passed to Upgrade is used instead.
	Subprotocols []string

	// Error specifies the function for generating HTTP error responses. If Error
	// is nil, then http.Error is used to generate the HTTP response, and the
	// Sec-Websocket-Version response header is set to "13".
	Error func(w http.ResponseWriter, r *http.Request, status int, reason error)

	// CheckOrigin returns true if the request Origin header is acceptable. If
	// CheckOrigin is nil, then a safe default is used: return false if the
	// Origin request header is present and either cannot be parsed as a URL or
	// its host is not equal to the request Host header.
	//
	// A CheckOrigin function should carefully validate the request origin to
	// prevent cross-site request forgery.
	CheckOrigin func(r *http.Request) bool

	// EnableCompression specifies whether the server should attempt to
	// negotiate per-message compression (RFC 7692). Setting this value to true
	// does not guarantee that compression will be supported. Currently only
	// "no context takeover" modes are supported.
	EnableCompression bool
}
76
77func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
78 err := HandshakeError{reason}
79 if u.Error != nil {
80 u.Error(w, r, status, err)
81 } else {
82 w.Header().Set("Sec-Websocket-Version", "13")
83 http.Error(w, http.StatusText(status), status)
84 }
85 return nil, err
86}
87
88// checkSameOrigin returns true if the origin is not set or is equal to the request host.
89func checkSameOrigin(r *http.Request) bool {
90 origin := r.Header["Origin"]
91 if len(origin) == 0 {
92 return true
93 }
94 u, err := url.Parse(origin[0])
95 if err != nil {
96 return false
97 }
98 return equalASCIIFold(u.Host, r.Host)
99}
100
101func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
102 if u.Subprotocols != nil {
103 clientProtocols := Subprotocols(r)
104 for _, serverProtocol := range u.Subprotocols {
105 for _, clientProtocol := range clientProtocols {
106 if clientProtocol == serverProtocol {
107 return clientProtocol
108 }
109 }
110 }
111 } else if responseHeader != nil {
112 return responseHeader.Get("Sec-Websocket-Protocol")
113 }
114 return ""
115}
116
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
// subprotocols supported by the server, set Upgrader.Subprotocols directly.
// A Sec-Websocket-Protocol entry in responseHeader is only honored when
// Upgrader.Subprotocols is nil. A Sec-Websocket-Extensions entry is rejected
// with a 500 error, because extension negotiation is controlled by
// Upgrader.EnableCompression.
//
// If the request fails validation, then Upgrade replies to the client with an
// HTTP error response (via Upgrader.Error when set) and returns a
// HandshakeError. Once the connection has been hijacked, failures (client data
// sent before the handshake completed, or an error writing the response)
// close the network connection and return an error without any HTTP response.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
	const badHandshake = "websocket: the client is not using the websocket protocol: "

	// Validate the client's opening handshake before touching the connection.
	if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
	}

	if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
	}

	if r.Method != http.MethodGet {
		return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
	}

	if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
	}

	// Extensions are negotiated by the Upgrader itself (see the PMCE block
	// below), so an application-supplied value is treated as a server error.
	if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
	}

	checkOrigin := u.CheckOrigin
	if checkOrigin == nil {
		checkOrigin = checkSameOrigin
	}
	if !checkOrigin(r) {
		return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
	}

	challengeKey := r.Header.Get("Sec-Websocket-Key")
	if !isValidChallengeKey(challengeKey) {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
	}

	subprotocol := u.selectSubprotocol(r, responseHeader)

	// Negotiate PMCE
	// Only the extension name is examined. Any parameters offered by the
	// client are ignored, and the fixed no-context-takeover response written
	// below is used whenever permessage-deflate is offered.
	var compress bool
	if u.EnableCompression {
		for _, ext := range parseExtensions(r.Header) {
			if ext[""] != "permessage-deflate" {
				continue
			}
			compress = true
			break
		}
	}

	h, ok := w.(http.Hijacker)
	if !ok {
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
	}
	var brw *bufio.ReadWriter
	netConn, brw, err := h.Hijack()
	if err != nil {
		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
	}

	// From here on the connection belongs to this function: w can no longer
	// be used, so every failure closes netConn and returns without an HTTP
	// response. A client must wait for the 101 response before sending
	// frames, so bytes the HTTP server has already buffered are rejected
	// rather than carried over.
	if brw.Reader.Buffered() > 0 {
		netConn.Close()
		return nil, errors.New("websocket: client sent data before handshake is complete")
	}

	// bufioReaderSize also resets brw.Reader to read from netConn, which is
	// what makes the reader reusable as the connection's reader.
	var br *bufio.Reader
	if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
		// Reuse hijacked buffered reader as connection reader.
		br = brw.Reader
	}

	// buf is the internal buffer of the hijacked bufio.Writer (which
	// bufioWriterBuffer resets to write to netConn).
	buf := bufioWriterBuffer(netConn, brw.Writer)

	var writeBuf []byte
	if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
		// Reuse hijacked write buffer as connection buffer.
		writeBuf = buf
	}

	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
	c.subprotocol = subprotocol

	if compress {
		c.newCompressionWriter = compressNoContextTakeover
		c.newDecompressionReader = decompressNoContextTakeover
	}

	// Use larger of hijacked buffer and connection write buffer for header.
	// The response is built in a buffer that may be shared with c, which is
	// safe because c has not written anything yet. append reallocates if the
	// response does not fit.
	p := buf
	if len(c.writeBuf) > len(p) {
		p = c.writeBuf
	}
	p = p[:0]

	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
	p = append(p, computeAcceptKey(challengeKey)...)
	p = append(p, "\r\n"...)
	if c.subprotocol != "" {
		p = append(p, "Sec-WebSocket-Protocol: "...)
		p = append(p, c.subprotocol...)
		p = append(p, "\r\n"...)
	}
	if compress {
		p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
	}
	// Copy the application's headers. Map iteration means their order in the
	// response is unspecified.
	// NOTE(review): header names (k) are written verbatim and only values are
	// sanitized below. Confirm that callers never place untrusted data in
	// header names.
	for k, vs := range responseHeader {
		// The negotiated subprotocol was already written above.
		if k == "Sec-Websocket-Protocol" {
			continue
		}
		for _, v := range vs {
			p = append(p, k...)
			p = append(p, ": "...)
			for i := 0; i < len(v); i++ {
				b := v[i]
				if b <= 31 {
					// prevent response splitting.
					b = ' '
				}
				p = append(p, b)
			}
			p = append(p, "\r\n"...)
		}
	}
	// Blank line terminating the response headers.
	p = append(p, "\r\n"...)

	// Clear deadlines set by HTTP server.
	netConn.SetDeadline(time.Time{})

	if u.HandshakeTimeout > 0 {
		netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
	}
	if _, err = netConn.Write(p); err != nil {
		netConn.Close()
		return nil, err
	}
	if u.HandshakeTimeout > 0 {
		netConn.SetWriteDeadline(time.Time{})
	}

	return c, nil
}
266
267// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
268//
269// Deprecated: Use websocket.Upgrader instead.
270//
271// Upgrade does not perform origin checking. The application is responsible for
272// checking the Origin header before calling Upgrade. An example implementation
273// of the same origin policy check is:
274//
275// if req.Header.Get("Origin") != "http://"+req.Host {
276// http.Error(w, "Origin not allowed", http.StatusForbidden)
277// return
278// }
279//
280// If the endpoint supports subprotocols, then the application is responsible
281// for negotiating the protocol used on the connection. Use the Subprotocols()
282// function to get the subprotocols requested by the client. Use the
283// Sec-Websocket-Protocol response header to specify the subprotocol selected
284// by the application.
285//
286// The responseHeader is included in the response to the client's upgrade
287// request. Use the responseHeader to specify cookies (Set-Cookie) and the
288// negotiated subprotocol (Sec-Websocket-Protocol).
289//
290// The connection buffers IO to the underlying network connection. The
291// readBufSize and writeBufSize parameters specify the size of the buffers to
292// use. Messages can be larger than the buffers.
293//
294// If the request is not a valid WebSocket handshake, then Upgrade returns an
295// error of type HandshakeError. Applications should handle this error by
296// replying to the client with an HTTP error response.
297func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
298 u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
299 u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
300 // don't return errors to maintain backwards compatibility
301 }
302 u.CheckOrigin = func(r *http.Request) bool {
303 // allow all connections by default
304 return true
305 }
306 return u.Upgrade(w, r, responseHeader)
307}
308
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header, in the order the client listed them.
//
// Per RFC 6455 section 11.3.4, the header may appear on several lines, which
// is equivalent to a single comma-separated list. All lines are therefore
// combined. Surrounding whitespace is trimmed from each entry, and empty list
// elements are skipped, as the RFC 7230 list rule requires. The result is nil
// when the client requested no subprotocols.
func Subprotocols(r *http.Request) []string {
	var protocols []string
	// Index the map directly (canonical key) so that every header line is
	// seen, not just the first one that Header.Get would return.
	for _, line := range r.Header["Sec-Websocket-Protocol"] {
		for _, p := range strings.Split(line, ",") {
			if p = strings.TrimSpace(p); p != "" {
				protocols = append(protocols, p)
			}
		}
	}
	return protocols
}
322
323// IsWebSocketUpgrade returns true if the client requested upgrade to the
324// WebSocket protocol.
325func IsWebSocketUpgrade(r *http.Request) bool {
326 return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
327 tokenListContainsValue(r.Header, "Upgrade", "websocket")
328}
329
// bufioReaderSize returns the size of br's internal buffer. As a side effect,
// br is reset to read from originalReader, which discards any data it had
// buffered.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
	// After a reset, Peek(0) is assumed to return bufio.Reader.buf[:0]
	// without reading from the source. The capacity of that slice is then the
	// buffer size.
	// TODO: Use bufio.Reader.Size() after Go 1.10
	br.Reset(originalReader)
	head, err := br.Peek(0)
	if err != nil {
		return 0
	}
	return cap(head)
}
341
// writeHook is an io.Writer that keeps a reference to the most recent slice
// passed to its Write method, without copying or consuming it.
type writeHook struct {
	p []byte
}

// Write records p and reports it as fully written.
func (h *writeHook) Write(p []byte) (int, error) {
	h.p = p
	return len(p), nil
}

// bufioWriterBuffer returns bw's internal buffer, sliced to its full
// capacity. Before returning, bw is reset to write to originalWriter, so any
// data it had buffered is discarded. The returned slice aliases bw's buffer,
// so later writes through bw overwrite its contents.
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
	// Route a single byte through a hook. This relies on bufio.Writer passing
	// bufio.Writer.buf[:1] to its underlying writer, so the captured slice
	// shares bw's backing array.
	hook := &writeHook{}
	bw.Reset(hook)
	bw.WriteByte(0)
	bw.Flush()

	// Point bw back at the real destination.
	bw.Reset(originalWriter)

	captured := hook.p
	return captured[:cap(captured)]
}