client.go

  1// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package websocket
  6
  7import (
  8	"bufio"
  9	"bytes"
 10	"crypto/tls"
 11	"encoding/base64"
 12	"errors"
 13	"io"
 14	"io/ioutil"
 15	"net"
 16	"net/http"
 17	"net/url"
 18	"strings"
 19	"time"
 20)
 21
 22// ErrBadHandshake is returned when the server response to opening handshake is
 23// invalid.
 24var ErrBadHandshake = errors.New("websocket: bad handshake")
 25
 26var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
 27
 28// NewClient creates a new client connection using the given net connection.
 29// The URL u specifies the host and request URI. Use requestHeader to specify
 30// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
 31// (Cookie). Use the response.Header to get the selected subprotocol
 32// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
 33//
 34// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
 35// non-nil *http.Response so that callers can handle redirects, authentication,
 36// etc.
 37//
 38// Deprecated: Use Dialer instead.
 39func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
 40	d := Dialer{
 41		ReadBufferSize:  readBufSize,
 42		WriteBufferSize: writeBufSize,
 43		NetDial: func(net, addr string) (net.Conn, error) {
 44			return netConn, nil
 45		},
 46	}
 47	return d.Dial(u.String(), requestHeader)
 48}
 49
 50// A Dialer contains options for connecting to WebSocket server.
 51type Dialer struct {
 52	// NetDial specifies the dial function for creating TCP connections. If
 53	// NetDial is nil, net.Dial is used.
 54	NetDial func(network, addr string) (net.Conn, error)
 55
 56	// Proxy specifies a function to return a proxy for a given
 57	// Request. If the function returns a non-nil error, the
 58	// request is aborted with the provided error.
 59	// If Proxy is nil or returns a nil *URL, no proxy is used.
 60	Proxy func(*http.Request) (*url.URL, error)
 61
 62	// TLSClientConfig specifies the TLS configuration to use with tls.Client.
 63	// If nil, the default configuration is used.
 64	TLSClientConfig *tls.Config
 65
 66	// HandshakeTimeout specifies the duration for the handshake to complete.
 67	HandshakeTimeout time.Duration
 68
 69	// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
 70	// size is zero, then a useful default size is used. The I/O buffer sizes
 71	// do not limit the size of the messages that can be sent or received.
 72	ReadBufferSize, WriteBufferSize int
 73
 74	// Subprotocols specifies the client's requested subprotocols.
 75	Subprotocols []string
 76
 77	// EnableCompression specifies if the client should attempt to negotiate
 78	// per message compression (RFC 7692). Setting this value to true does not
 79	// guarantee that compression will be supported. Currently only "no context
 80	// takeover" modes are supported.
 81	EnableCompression bool
 82
 83	// Jar specifies the cookie jar.
 84	// If Jar is nil, cookies are not sent in requests and ignored
 85	// in responses.
 86	Jar http.CookieJar
 87}
 88
 89var errMalformedURL = errors.New("malformed ws or wss URL")
 90
 91// parseURL parses the URL.
 92//
 93// This function is a replacement for the standard library url.Parse function.
 94// In Go 1.4 and earlier, url.Parse loses information from the path.
 95func parseURL(s string) (*url.URL, error) {
 96	// From the RFC:
 97	//
 98	// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
 99	// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
100	var u url.URL
101	switch {
102	case strings.HasPrefix(s, "ws://"):
103		u.Scheme = "ws"
104		s = s[len("ws://"):]
105	case strings.HasPrefix(s, "wss://"):
106		u.Scheme = "wss"
107		s = s[len("wss://"):]
108	default:
109		return nil, errMalformedURL
110	}
111
112	if i := strings.Index(s, "?"); i >= 0 {
113		u.RawQuery = s[i+1:]
114		s = s[:i]
115	}
116
117	if i := strings.Index(s, "/"); i >= 0 {
118		u.Opaque = s[i:]
119		s = s[:i]
120	} else {
121		u.Opaque = "/"
122	}
123
124	u.Host = s
125
126	if strings.Contains(u.Host, "@") {
127		// Don't bother parsing user information because user information is
128		// not allowed in websocket URIs.
129		return nil, errMalformedURL
130	}
131
132	return &u, nil
133}
134
135func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
136	hostPort = u.Host
137	hostNoPort = u.Host
138	if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
139		hostNoPort = hostNoPort[:i]
140	} else {
141		switch u.Scheme {
142		case "wss":
143			hostPort += ":443"
144		case "https":
145			hostPort += ":443"
146		default:
147			hostPort += ":80"
148		}
149	}
150	return hostPort, hostNoPort
151}
152
153// DefaultDialer is a dialer with all fields set to the default zero values.
154var DefaultDialer = &Dialer{
155	Proxy: http.ProxyFromEnvironment,
156}
157
158// Dial creates a new client connection. Use requestHeader to specify the
159// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
160// Use the response.Header to get the selected subprotocol
161// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
162//
163// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
164// non-nil *http.Response so that callers can handle redirects, authentication,
165// etcetera. The response body may not contain the entire response and does not
166// need to be closed by the application.
167func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
168
169	if d == nil {
170		d = &Dialer{
171			Proxy: http.ProxyFromEnvironment,
172		}
173	}
174
175	challengeKey, err := generateChallengeKey()
176	if err != nil {
177		return nil, nil, err
178	}
179
180	u, err := parseURL(urlStr)
181	if err != nil {
182		return nil, nil, err
183	}
184
185	switch u.Scheme {
186	case "ws":
187		u.Scheme = "http"
188	case "wss":
189		u.Scheme = "https"
190	default:
191		return nil, nil, errMalformedURL
192	}
193
194	if u.User != nil {
195		// User name and password are not allowed in websocket URIs.
196		return nil, nil, errMalformedURL
197	}
198
199	req := &http.Request{
200		Method:     "GET",
201		URL:        u,
202		Proto:      "HTTP/1.1",
203		ProtoMajor: 1,
204		ProtoMinor: 1,
205		Header:     make(http.Header),
206		Host:       u.Host,
207	}
208
209	// Set the cookies present in the cookie jar of the dialer
210	if d.Jar != nil {
211		for _, cookie := range d.Jar.Cookies(u) {
212			req.AddCookie(cookie)
213		}
214	}
215
216	// Set the request headers using the capitalization for names and values in
217	// RFC examples. Although the capitalization shouldn't matter, there are
218	// servers that depend on it. The Header.Set method is not used because the
219	// method canonicalizes the header names.
220	req.Header["Upgrade"] = []string{"websocket"}
221	req.Header["Connection"] = []string{"Upgrade"}
222	req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
223	req.Header["Sec-WebSocket-Version"] = []string{"13"}
224	if len(d.Subprotocols) > 0 {
225		req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
226	}
227	for k, vs := range requestHeader {
228		switch {
229		case k == "Host":
230			if len(vs) > 0 {
231				req.Host = vs[0]
232			}
233		case k == "Upgrade" ||
234			k == "Connection" ||
235			k == "Sec-Websocket-Key" ||
236			k == "Sec-Websocket-Version" ||
237			k == "Sec-Websocket-Extensions" ||
238			(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
239			return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
240		default:
241			req.Header[k] = vs
242		}
243	}
244
245	if d.EnableCompression {
246		req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
247	}
248
249	hostPort, hostNoPort := hostPortNoPort(u)
250
251	var proxyURL *url.URL
252	// Check wether the proxy method has been configured
253	if d.Proxy != nil {
254		proxyURL, err = d.Proxy(req)
255	}
256	if err != nil {
257		return nil, nil, err
258	}
259
260	var targetHostPort string
261	if proxyURL != nil {
262		targetHostPort, _ = hostPortNoPort(proxyURL)
263	} else {
264		targetHostPort = hostPort
265	}
266
267	var deadline time.Time
268	if d.HandshakeTimeout != 0 {
269		deadline = time.Now().Add(d.HandshakeTimeout)
270	}
271
272	netDial := d.NetDial
273	if netDial == nil {
274		netDialer := &net.Dialer{Deadline: deadline}
275		netDial = netDialer.Dial
276	}
277
278	netConn, err := netDial("tcp", targetHostPort)
279	if err != nil {
280		return nil, nil, err
281	}
282
283	defer func() {
284		if netConn != nil {
285			netConn.Close()
286		}
287	}()
288
289	if err := netConn.SetDeadline(deadline); err != nil {
290		return nil, nil, err
291	}
292
293	if proxyURL != nil {
294		connectHeader := make(http.Header)
295		if user := proxyURL.User; user != nil {
296			proxyUser := user.Username()
297			if proxyPassword, passwordSet := user.Password(); passwordSet {
298				credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
299				connectHeader.Set("Proxy-Authorization", "Basic "+credential)
300			}
301		}
302		connectReq := &http.Request{
303			Method: "CONNECT",
304			URL:    &url.URL{Opaque: hostPort},
305			Host:   hostPort,
306			Header: connectHeader,
307		}
308
309		connectReq.Write(netConn)
310
311		// Read response.
312		// Okay to use and discard buffered reader here, because
313		// TLS server will not speak until spoken to.
314		br := bufio.NewReader(netConn)
315		resp, err := http.ReadResponse(br, connectReq)
316		if err != nil {
317			return nil, nil, err
318		}
319		if resp.StatusCode != 200 {
320			f := strings.SplitN(resp.Status, " ", 2)
321			return nil, nil, errors.New(f[1])
322		}
323	}
324
325	if u.Scheme == "https" {
326		cfg := cloneTLSConfig(d.TLSClientConfig)
327		if cfg.ServerName == "" {
328			cfg.ServerName = hostNoPort
329		}
330		tlsConn := tls.Client(netConn, cfg)
331		netConn = tlsConn
332		if err := tlsConn.Handshake(); err != nil {
333			return nil, nil, err
334		}
335		if !cfg.InsecureSkipVerify {
336			if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
337				return nil, nil, err
338			}
339		}
340	}
341
342	conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
343
344	if err := req.Write(netConn); err != nil {
345		return nil, nil, err
346	}
347
348	resp, err := http.ReadResponse(conn.br, req)
349	if err != nil {
350		return nil, nil, err
351	}
352
353	if d.Jar != nil {
354		if rc := resp.Cookies(); len(rc) > 0 {
355			d.Jar.SetCookies(u, rc)
356		}
357	}
358
359	if resp.StatusCode != 101 ||
360		!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
361		!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
362		resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
363		// Before closing the network connection on return from this
364		// function, slurp up some of the response to aid application
365		// debugging.
366		buf := make([]byte, 1024)
367		n, _ := io.ReadFull(resp.Body, buf)
368		resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
369		return nil, resp, ErrBadHandshake
370	}
371
372	for _, ext := range parseExtensions(resp.Header) {
373		if ext[""] != "permessage-deflate" {
374			continue
375		}
376		_, snct := ext["server_no_context_takeover"]
377		_, cnct := ext["client_no_context_takeover"]
378		if !snct || !cnct {
379			return nil, resp, errInvalidCompression
380		}
381		conn.newCompressionWriter = compressNoContextTakeover
382		conn.newDecompressionReader = decompressNoContextTakeover
383		break
384	}
385
386	resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
387	conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
388
389	netConn.SetDeadline(time.Time{})
390	netConn = nil // to avoid close in defer.
391	return conn, resp, nil
392}