conn.go

   1// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
   2// Use of this source code is governed by a BSD-style
   3// license that can be found in the LICENSE file.
   4
   5package websocket
   6
   7import (
   8	"bufio"
   9	"encoding/binary"
  10	"errors"
  11	"io"
  12	"io/ioutil"
  13	"math/rand"
  14	"net"
  15	"strconv"
  16	"sync"
  17	"time"
  18	"unicode/utf8"
  19)
  20
  21const (
  22	// Frame header byte 0 bits from Section 5.2 of RFC 6455
  23	finalBit = 1 << 7
  24	rsv1Bit  = 1 << 6
  25	rsv2Bit  = 1 << 5
  26	rsv3Bit  = 1 << 4
  27
  28	// Frame header byte 1 bits from Section 5.2 of RFC 6455
  29	maskBit = 1 << 7
  30
  31	maxFrameHeaderSize         = 2 + 8 + 4 // Fixed header + length + mask
  32	maxControlFramePayloadSize = 125
  33
  34	writeWait = time.Second
  35
  36	defaultReadBufferSize  = 4096
  37	defaultWriteBufferSize = 4096
  38
  39	continuationFrame = 0
  40	noFrame           = -1
  41)
  42
  43// Close codes defined in RFC 6455, section 11.7.
  44const (
  45	CloseNormalClosure           = 1000
  46	CloseGoingAway               = 1001
  47	CloseProtocolError           = 1002
  48	CloseUnsupportedData         = 1003
  49	CloseNoStatusReceived        = 1005
  50	CloseAbnormalClosure         = 1006
  51	CloseInvalidFramePayloadData = 1007
  52	ClosePolicyViolation         = 1008
  53	CloseMessageTooBig           = 1009
  54	CloseMandatoryExtension      = 1010
  55	CloseInternalServerErr       = 1011
  56	CloseServiceRestart          = 1012
  57	CloseTryAgainLater           = 1013
  58	CloseTLSHandshake            = 1015
  59)
  60
// The message types are defined in RFC 6455, section 11.8. The values are the
// frame opcodes carried in the low four bits of the first header byte.
const (
	// TextMessage denotes a text data message. The text message payload is
	// interpreted as UTF-8 encoded text data.
	TextMessage = 1

	// BinaryMessage denotes a binary data message.
	BinaryMessage = 2

	// CloseMessage denotes a close control message. The optional message
	// payload contains a numeric code and text. Use the FormatCloseMessage
	// function to format a close message payload.
	CloseMessage = 8

	// PingMessage denotes a ping control message. The optional message
	// payload is application data, passed to the ping handler as a string.
	PingMessage = 9

	// PongMessage denotes a pong control message. The optional message
	// payload is application data, passed to the pong handler as a string.
	PongMessage = 10
)
  83
  84// ErrCloseSent is returned when the application writes a message to the
  85// connection after sending a close message.
  86var ErrCloseSent = errors.New("websocket: close sent")
  87
  88// ErrReadLimit is returned when reading a message that is larger than the
  89// read limit set for the connection.
  90var ErrReadLimit = errors.New("websocket: read limit exceeded")
  91
  92// netError satisfies the net Error interface.
  93type netError struct {
  94	msg       string
  95	temporary bool
  96	timeout   bool
  97}
  98
  99func (e *netError) Error() string   { return e.msg }
 100func (e *netError) Temporary() bool { return e.temporary }
 101func (e *netError) Timeout() bool   { return e.timeout }
 102
 103// CloseError represents close frame.
 104type CloseError struct {
 105
 106	// Code is defined in RFC 6455, section 11.7.
 107	Code int
 108
 109	// Text is the optional text payload.
 110	Text string
 111}
 112
 113func (e *CloseError) Error() string {
 114	s := []byte("websocket: close ")
 115	s = strconv.AppendInt(s, int64(e.Code), 10)
 116	switch e.Code {
 117	case CloseNormalClosure:
 118		s = append(s, " (normal)"...)
 119	case CloseGoingAway:
 120		s = append(s, " (going away)"...)
 121	case CloseProtocolError:
 122		s = append(s, " (protocol error)"...)
 123	case CloseUnsupportedData:
 124		s = append(s, " (unsupported data)"...)
 125	case CloseNoStatusReceived:
 126		s = append(s, " (no status)"...)
 127	case CloseAbnormalClosure:
 128		s = append(s, " (abnormal closure)"...)
 129	case CloseInvalidFramePayloadData:
 130		s = append(s, " (invalid payload data)"...)
 131	case ClosePolicyViolation:
 132		s = append(s, " (policy violation)"...)
 133	case CloseMessageTooBig:
 134		s = append(s, " (message too big)"...)
 135	case CloseMandatoryExtension:
 136		s = append(s, " (mandatory extension missing)"...)
 137	case CloseInternalServerErr:
 138		s = append(s, " (internal server error)"...)
 139	case CloseTLSHandshake:
 140		s = append(s, " (TLS handshake error)"...)
 141	}
 142	if e.Text != "" {
 143		s = append(s, ": "...)
 144		s = append(s, e.Text...)
 145	}
 146	return string(s)
 147}
 148
 149// IsCloseError returns boolean indicating whether the error is a *CloseError
 150// with one of the specified codes.
 151func IsCloseError(err error, codes ...int) bool {
 152	if e, ok := err.(*CloseError); ok {
 153		for _, code := range codes {
 154			if e.Code == code {
 155				return true
 156			}
 157		}
 158	}
 159	return false
 160}
 161
 162// IsUnexpectedCloseError returns boolean indicating whether the error is a
 163// *CloseError with a code not in the list of expected codes.
 164func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
 165	if e, ok := err.(*CloseError); ok {
 166		for _, code := range expectedCodes {
 167			if e.Code == code {
 168				return false
 169			}
 170		}
 171		return true
 172	}
 173	return false
 174}
 175
 176var (
 177	errWriteTimeout        = &netError{msg: "websocket: write timeout", timeout: true, temporary: true}
 178	errUnexpectedEOF       = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
 179	errBadWriteOpCode      = errors.New("websocket: bad write message type")
 180	errWriteClosed         = errors.New("websocket: write closed")
 181	errInvalidControlFrame = errors.New("websocket: invalid control frame")
 182)
 183
 184func newMaskKey() [4]byte {
 185	n := rand.Uint32()
 186	return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
 187}
 188
 189func hideTempErr(err error) error {
 190	if e, ok := err.(net.Error); ok && e.Temporary() {
 191		err = &netError{msg: e.Error(), timeout: e.Timeout()}
 192	}
 193	return err
 194}
 195
 196func isControl(frameType int) bool {
 197	return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
 198}
 199
 200func isData(frameType int) bool {
 201	return frameType == TextMessage || frameType == BinaryMessage
 202}
 203
 204var validReceivedCloseCodes = map[int]bool{
 205	// see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
 206
 207	CloseNormalClosure:           true,
 208	CloseGoingAway:               true,
 209	CloseProtocolError:           true,
 210	CloseUnsupportedData:         true,
 211	CloseNoStatusReceived:        false,
 212	CloseAbnormalClosure:         false,
 213	CloseInvalidFramePayloadData: true,
 214	ClosePolicyViolation:         true,
 215	CloseMessageTooBig:           true,
 216	CloseMandatoryExtension:      true,
 217	CloseInternalServerErr:       true,
 218	CloseServiceRestart:          true,
 219	CloseTryAgainLater:           true,
 220	CloseTLSHandshake:            false,
 221}
 222
 223func isValidReceivedCloseCode(code int) bool {
 224	return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
 225}
 226
// The Conn type represents a WebSocket connection.
type Conn struct {
	conn        net.Conn
	isServer    bool
	subprotocol string

	// Write fields.
	//
	// mu is a channel holding a single token and is used as a mutex: a
	// writer receives the token to acquire the lock and sends it back to
	// release it. Unlike sync.Mutex, this lets WriteControl stop waiting
	// when its deadline expires. writeBuf reserves its first
	// maxFrameHeaderSize bytes for the frame header. writeDeadline is the
	// deadline applied when data frames are flushed.
	mu            chan bool // used as mutex to protect write to conn
	writeBuf      []byte    // frame is constructed in this buffer.
	writeDeadline time.Time
	writer        io.WriteCloser // the current writer returned to the application
	isWriting     bool           // for best-effort concurrent write detection

	// writeErr records the first fatal write error; once it is set, every
	// later write returns it. It is guarded by writeErrMu.
	writeErrMu sync.Mutex
	writeErr   error

	enableWriteCompression bool
	compressionLevel       int
	newCompressionWriter   func(io.WriteCloser, int) io.WriteCloser

	// Read fields.
	//
	// readErr is sticky: once set, NextReader and messageReader.Read keep
	// returning it. readMaskPos and readMaskKey hold the masking state of
	// the current frame (frames are masked only when this is the server
	// end). readErrCount counts NextReader calls that returned an error and
	// is used to panic on a read loop that ignores errors.
	reader        io.ReadCloser // the current reader returned to the application
	readErr       error
	br            *bufio.Reader
	readRemaining int64 // bytes remaining in current frame.
	readFinal     bool  // true when the last data frame read was the final frame of its message.
	readLength    int64 // Sum of the frame lengths of the current message.
	readLimit     int64 // Maximum message size; <= 0 means no limit.
	readMaskPos   int
	readMaskKey   [4]byte
	handlePong    func(string) error
	handlePing    func(string) error
	handleClose   func(int, string) error
	readErrCount  int
	messageReader *messageReader // the current low-level reader

	readDecompress         bool // whether last read frame had RSV1 set
	newDecompressionReader func(io.Reader) io.ReadCloser
}
 266
 267func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
 268	return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
 269}
 270
 271type writeHook struct {
 272	p []byte
 273}
 274
 275func (wh *writeHook) Write(p []byte) (int, error) {
 276	wh.p = p
 277	return len(p), nil
 278}
 279
// newConnBRW creates a Conn over conn.
//
// When readBufferSize (or writeBufferSize) is 0 and brw supplies a
// bufio.Reader (or bufio.Writer) with a large enough buffer, that buffer is
// reused instead of allocating a new one. Otherwise a buffer of the requested
// size is allocated; a size of 0 means the package default. The default
// close, ping and pong handlers are installed on the returned Conn.
//
// NOTE(review): brw.Reader.Reset discards any data already buffered in it,
// and a borrowed brw.Writer is left writing to a local writeHook. Presumably
// callers check that the reader is empty and do not use brw afterwards;
// confirm against the callers.
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
	// The write lock starts out available: the single token is in the channel.
	mu := make(chan bool, 1)
	mu <- true

	var br *bufio.Reader
	if readBufferSize == 0 && brw != nil && brw.Reader != nil {
		// Reuse the supplied bufio.Reader if the buffer has a useful size.
		// This code assumes that peek on a reader returns
		// bufio.Reader.buf[:0].
		brw.Reader.Reset(conn)
		if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 {
			br = brw.Reader
		}
	}
	if br == nil {
		if readBufferSize == 0 {
			readBufferSize = defaultReadBufferSize
		}
		// Never allocate a read buffer smaller than the largest control
		// frame payload (presumably so c.read can return a whole control
		// payload from the buffer; c.read is defined elsewhere).
		if readBufferSize < maxControlFramePayloadSize {
			readBufferSize = maxControlFramePayloadSize
		}
		br = bufio.NewReaderSize(conn, readBufferSize)
	}

	var writeBuf []byte
	if writeBufferSize == 0 && brw != nil && brw.Writer != nil {
		// Use the bufio.Writer's buffer if the buffer has a useful size. This
		// code assumes that bufio.Writer.buf[:1] is passed to the
		// bufio.Writer's underlying writer.
		//
		// Writing one byte and flushing makes the bufio.Writer hand its
		// internal buffer to wh.Write; extending the slice to its capacity
		// then exposes the whole buffer.
		var wh writeHook
		brw.Writer.Reset(&wh)
		brw.Writer.WriteByte(0)
		brw.Flush()
		// The buffer must fit the largest frame header plus some payload.
		if cap(wh.p) >= maxFrameHeaderSize+256 {
			writeBuf = wh.p[:cap(wh.p)]
		}
	}

	if writeBuf == nil {
		if writeBufferSize == 0 {
			writeBufferSize = defaultWriteBufferSize
		}
		// Room for writeBufferSize payload bytes after the reserved header.
		writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize)
	}

	c := &Conn{
		isServer:               isServer,
		br:                     br,
		conn:                   conn,
		mu:                     mu,
		readFinal:              true,
		writeBuf:               writeBuf,
		enableWriteCompression: true,
		compressionLevel:       defaultCompressionLevel,
	}
	// Passing nil installs the default handlers.
	c.SetCloseHandler(nil)
	c.SetPingHandler(nil)
	c.SetPongHandler(nil)
	return c
}
 340
 341// Subprotocol returns the negotiated protocol for the connection.
 342func (c *Conn) Subprotocol() string {
 343	return c.subprotocol
 344}
 345
 346// Close closes the underlying network connection without sending or waiting for a close frame.
 347func (c *Conn) Close() error {
 348	return c.conn.Close()
 349}
 350
 351// LocalAddr returns the local network address.
 352func (c *Conn) LocalAddr() net.Addr {
 353	return c.conn.LocalAddr()
 354}
 355
 356// RemoteAddr returns the remote network address.
 357func (c *Conn) RemoteAddr() net.Addr {
 358	return c.conn.RemoteAddr()
 359}
 360
 361// Write methods
 362
 363func (c *Conn) writeFatal(err error) error {
 364	err = hideTempErr(err)
 365	c.writeErrMu.Lock()
 366	if c.writeErr == nil {
 367		c.writeErr = err
 368	}
 369	c.writeErrMu.Unlock()
 370	return err
 371}
 372
 373func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
 374	<-c.mu
 375	defer func() { c.mu <- true }()
 376
 377	c.writeErrMu.Lock()
 378	err := c.writeErr
 379	c.writeErrMu.Unlock()
 380	if err != nil {
 381		return err
 382	}
 383
 384	c.conn.SetWriteDeadline(deadline)
 385	for _, buf := range bufs {
 386		if len(buf) > 0 {
 387			_, err := c.conn.Write(buf)
 388			if err != nil {
 389				return c.writeFatal(err)
 390			}
 391		}
 392	}
 393
 394	if frameType == CloseMessage {
 395		c.writeFatal(ErrCloseSent)
 396	}
 397	return nil
 398}
 399
 400// WriteControl writes a control message with the given deadline. The allowed
 401// message types are CloseMessage, PingMessage and PongMessage.
 402func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
 403	if !isControl(messageType) {
 404		return errBadWriteOpCode
 405	}
 406	if len(data) > maxControlFramePayloadSize {
 407		return errInvalidControlFrame
 408	}
 409
 410	b0 := byte(messageType) | finalBit
 411	b1 := byte(len(data))
 412	if !c.isServer {
 413		b1 |= maskBit
 414	}
 415
 416	buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize)
 417	buf = append(buf, b0, b1)
 418
 419	if c.isServer {
 420		buf = append(buf, data...)
 421	} else {
 422		key := newMaskKey()
 423		buf = append(buf, key[:]...)
 424		buf = append(buf, data...)
 425		maskBytes(key, 0, buf[6:])
 426	}
 427
 428	d := time.Hour * 1000
 429	if !deadline.IsZero() {
 430		d = deadline.Sub(time.Now())
 431		if d < 0 {
 432			return errWriteTimeout
 433		}
 434	}
 435
 436	timer := time.NewTimer(d)
 437	select {
 438	case <-c.mu:
 439		timer.Stop()
 440	case <-timer.C:
 441		return errWriteTimeout
 442	}
 443	defer func() { c.mu <- true }()
 444
 445	c.writeErrMu.Lock()
 446	err := c.writeErr
 447	c.writeErrMu.Unlock()
 448	if err != nil {
 449		return err
 450	}
 451
 452	c.conn.SetWriteDeadline(deadline)
 453	_, err = c.conn.Write(buf)
 454	if err != nil {
 455		return c.writeFatal(err)
 456	}
 457	if messageType == CloseMessage {
 458		c.writeFatal(ErrCloseSent)
 459	}
 460	return err
 461}
 462
 463func (c *Conn) prepWrite(messageType int) error {
 464	// Close previous writer if not already closed by the application. It's
 465	// probably better to return an error in this situation, but we cannot
 466	// change this without breaking existing applications.
 467	if c.writer != nil {
 468		c.writer.Close()
 469		c.writer = nil
 470	}
 471
 472	if !isControl(messageType) && !isData(messageType) {
 473		return errBadWriteOpCode
 474	}
 475
 476	c.writeErrMu.Lock()
 477	err := c.writeErr
 478	c.writeErrMu.Unlock()
 479	return err
 480}
 481
 482// NextWriter returns a writer for the next message to send. The writer's Close
 483// method flushes the complete message to the network.
 484//
 485// There can be at most one open writer on a connection. NextWriter closes the
 486// previous writer if the application has not already done so.
 487func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
 488	if err := c.prepWrite(messageType); err != nil {
 489		return nil, err
 490	}
 491
 492	mw := &messageWriter{
 493		c:         c,
 494		frameType: messageType,
 495		pos:       maxFrameHeaderSize,
 496	}
 497	c.writer = mw
 498	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
 499		w := c.newCompressionWriter(c.writer, c.compressionLevel)
 500		mw.compress = true
 501		c.writer = w
 502	}
 503	return c.writer, nil
 504}
 505
// messageWriter writes one message as one or more frames, buffering payload
// in the connection's writeBuf after the reserved header space.
type messageWriter struct {
	c         *Conn
	compress  bool  // whether next call to flushFrame should set RSV1
	pos       int   // end of data in writeBuf.
	frameType int   // type of the current frame.
	err       error // sticky error returned by later Write, WriteString, ReadFrom and Close calls
}
 513
 514func (w *messageWriter) fatal(err error) error {
 515	if w.err != nil {
 516		w.err = err
 517		w.c.writer = nil
 518	}
 519	return err
 520}
 521
// flushFrame writes buffered data and extra as a frame to the network. The
// final argument indicates that this is the last frame in the message.
//
// The frame payload is c.writeBuf[maxFrameHeaderSize:w.pos] followed by
// extra. extra is written without being copied or masked, so it is only
// allowed on the server side. After a successful non-final frame, the writer
// is reset to buffer a continuation frame. After the final frame, the
// connection's current writer is cleared.
func (w *messageWriter) flushFrame(final bool, extra []byte) error {
	c := w.c
	length := w.pos - maxFrameHeaderSize + len(extra)

	// Check for invalid control frames.
	if isControl(w.frameType) &&
		(!final || length > maxControlFramePayloadSize) {
		return w.fatal(errInvalidControlFrame)
	}

	b0 := byte(w.frameType)
	if final {
		b0 |= finalBit
	}
	if w.compress {
		b0 |= rsv1Bit
	}
	// Only the first frame of a compressed message carries RSV1.
	w.compress = false

	b1 := byte(0)
	if !c.isServer {
		b1 |= maskBit
	}

	// The header is placed so that it ends exactly where it is needed. For a
	// server, it ends just before the payload at maxFrameHeaderSize. For a
	// client, it ends just before the 4 byte masking key at writeBuf[10:14].
	// framePos becomes the index of the first header byte: +0, +6 or +8 for
	// headers with 8, 2 or 0 extended length bytes.

	// Assume that the frame starts at beginning of c.writeBuf.
	framePos := 0
	if c.isServer {
		// Adjust up if mask not included in the header.
		framePos = 4
	}

	switch {
	case length >= 65536:
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | 127
		binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))
	case length > 125:
		framePos += 6
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | 126
		binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
	default:
		framePos += 8
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | byte(length)
	}

	if !c.isServer {
		// Store a fresh key just before the payload and mask the buffered
		// payload in place.
		key := newMaskKey()
		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
		// NOTE(review): this error is recorded on the connection via
		// writeFatal but not on w itself.
		if len(extra) > 0 {
			return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
		}
	}

	// Write the buffers to the connection with best-effort detection of
	// concurrent writes. See the concurrency section in the package
	// documentation for more info.

	if c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = true

	err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)

	if !c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = false

	if err != nil {
		return w.fatal(err)
	}

	if final {
		c.writer = nil
		return nil
	}

	// Setup for next frame.
	w.pos = maxFrameHeaderSize
	w.frameType = continuationFrame
	return nil
}
 610
 611func (w *messageWriter) ncopy(max int) (int, error) {
 612	n := len(w.c.writeBuf) - w.pos
 613	if n <= 0 {
 614		if err := w.flushFrame(false, nil); err != nil {
 615			return 0, err
 616		}
 617		n = len(w.c.writeBuf) - w.pos
 618	}
 619	if n > max {
 620		n = max
 621	}
 622	return n, nil
 623}
 624
 625func (w *messageWriter) Write(p []byte) (int, error) {
 626	if w.err != nil {
 627		return 0, w.err
 628	}
 629
 630	if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
 631		// Don't buffer large messages.
 632		err := w.flushFrame(false, p)
 633		if err != nil {
 634			return 0, err
 635		}
 636		return len(p), nil
 637	}
 638
 639	nn := len(p)
 640	for len(p) > 0 {
 641		n, err := w.ncopy(len(p))
 642		if err != nil {
 643			return 0, err
 644		}
 645		copy(w.c.writeBuf[w.pos:], p[:n])
 646		w.pos += n
 647		p = p[n:]
 648	}
 649	return nn, nil
 650}
 651
 652func (w *messageWriter) WriteString(p string) (int, error) {
 653	if w.err != nil {
 654		return 0, w.err
 655	}
 656
 657	nn := len(p)
 658	for len(p) > 0 {
 659		n, err := w.ncopy(len(p))
 660		if err != nil {
 661			return 0, err
 662		}
 663		copy(w.c.writeBuf[w.pos:], p[:n])
 664		w.pos += n
 665		p = p[n:]
 666	}
 667	return nn, nil
 668}
 669
 670func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
 671	if w.err != nil {
 672		return 0, w.err
 673	}
 674	for {
 675		if w.pos == len(w.c.writeBuf) {
 676			err = w.flushFrame(false, nil)
 677			if err != nil {
 678				break
 679			}
 680		}
 681		var n int
 682		n, err = r.Read(w.c.writeBuf[w.pos:])
 683		w.pos += n
 684		nn += int64(n)
 685		if err != nil {
 686			if err == io.EOF {
 687				err = nil
 688			}
 689			break
 690		}
 691	}
 692	return nn, err
 693}
 694
 695func (w *messageWriter) Close() error {
 696	if w.err != nil {
 697		return w.err
 698	}
 699	if err := w.flushFrame(true, nil); err != nil {
 700		return err
 701	}
 702	w.err = errWriteClosed
 703	return nil
 704}
 705
 706// WritePreparedMessage writes prepared message into connection.
 707func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
 708	frameType, frameData, err := pm.frame(prepareKey{
 709		isServer:         c.isServer,
 710		compress:         c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
 711		compressionLevel: c.compressionLevel,
 712	})
 713	if err != nil {
 714		return err
 715	}
 716	if c.isWriting {
 717		panic("concurrent write to websocket connection")
 718	}
 719	c.isWriting = true
 720	err = c.write(frameType, c.writeDeadline, frameData, nil)
 721	if !c.isWriting {
 722		panic("concurrent write to websocket connection")
 723	}
 724	c.isWriting = false
 725	return err
 726}
 727
 728// WriteMessage is a helper method for getting a writer using NextWriter,
 729// writing the message and closing the writer.
 730func (c *Conn) WriteMessage(messageType int, data []byte) error {
 731
 732	if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
 733		// Fast path with no allocations and single frame.
 734
 735		if err := c.prepWrite(messageType); err != nil {
 736			return err
 737		}
 738		mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
 739		n := copy(c.writeBuf[mw.pos:], data)
 740		mw.pos += n
 741		data = data[n:]
 742		return mw.flushFrame(true, data)
 743	}
 744
 745	w, err := c.NextWriter(messageType)
 746	if err != nil {
 747		return err
 748	}
 749	if _, err = w.Write(data); err != nil {
 750		return err
 751	}
 752	return w.Close()
 753}
 754
 755// SetWriteDeadline sets the write deadline on the underlying network
 756// connection. After a write has timed out, the websocket state is corrupt and
 757// all future writes will return an error. A zero value for t means writes will
 758// not time out.
 759func (c *Conn) SetWriteDeadline(t time.Time) error {
 760	c.writeDeadline = t
 761	return nil
 762}
 763
 764// Read methods
 765
 766func (c *Conn) advanceFrame() (int, error) {
 767
 768	// 1. Skip remainder of previous frame.
 769
 770	if c.readRemaining > 0 {
 771		if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
 772			return noFrame, err
 773		}
 774	}
 775
 776	// 2. Read and parse first two bytes of frame header.
 777
 778	p, err := c.read(2)
 779	if err != nil {
 780		return noFrame, err
 781	}
 782
 783	final := p[0]&finalBit != 0
 784	frameType := int(p[0] & 0xf)
 785	mask := p[1]&maskBit != 0
 786	c.readRemaining = int64(p[1] & 0x7f)
 787
 788	c.readDecompress = false
 789	if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
 790		c.readDecompress = true
 791		p[0] &^= rsv1Bit
 792	}
 793
 794	if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
 795		return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
 796	}
 797
 798	switch frameType {
 799	case CloseMessage, PingMessage, PongMessage:
 800		if c.readRemaining > maxControlFramePayloadSize {
 801			return noFrame, c.handleProtocolError("control frame length > 125")
 802		}
 803		if !final {
 804			return noFrame, c.handleProtocolError("control frame not final")
 805		}
 806	case TextMessage, BinaryMessage:
 807		if !c.readFinal {
 808			return noFrame, c.handleProtocolError("message start before final message frame")
 809		}
 810		c.readFinal = final
 811	case continuationFrame:
 812		if c.readFinal {
 813			return noFrame, c.handleProtocolError("continuation after final message frame")
 814		}
 815		c.readFinal = final
 816	default:
 817		return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
 818	}
 819
 820	// 3. Read and parse frame length.
 821
 822	switch c.readRemaining {
 823	case 126:
 824		p, err := c.read(2)
 825		if err != nil {
 826			return noFrame, err
 827		}
 828		c.readRemaining = int64(binary.BigEndian.Uint16(p))
 829	case 127:
 830		p, err := c.read(8)
 831		if err != nil {
 832			return noFrame, err
 833		}
 834		c.readRemaining = int64(binary.BigEndian.Uint64(p))
 835	}
 836
 837	// 4. Handle frame masking.
 838
 839	if mask != c.isServer {
 840		return noFrame, c.handleProtocolError("incorrect mask flag")
 841	}
 842
 843	if mask {
 844		c.readMaskPos = 0
 845		p, err := c.read(len(c.readMaskKey))
 846		if err != nil {
 847			return noFrame, err
 848		}
 849		copy(c.readMaskKey[:], p)
 850	}
 851
 852	// 5. For text and binary messages, enforce read limit and return.
 853
 854	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
 855
 856		c.readLength += c.readRemaining
 857		if c.readLimit > 0 && c.readLength > c.readLimit {
 858			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
 859			return noFrame, ErrReadLimit
 860		}
 861
 862		return frameType, nil
 863	}
 864
 865	// 6. Read control frame payload.
 866
 867	var payload []byte
 868	if c.readRemaining > 0 {
 869		payload, err = c.read(int(c.readRemaining))
 870		c.readRemaining = 0
 871		if err != nil {
 872			return noFrame, err
 873		}
 874		if c.isServer {
 875			maskBytes(c.readMaskKey, 0, payload)
 876		}
 877	}
 878
 879	// 7. Process control frame payload.
 880
 881	switch frameType {
 882	case PongMessage:
 883		if err := c.handlePong(string(payload)); err != nil {
 884			return noFrame, err
 885		}
 886	case PingMessage:
 887		if err := c.handlePing(string(payload)); err != nil {
 888			return noFrame, err
 889		}
 890	case CloseMessage:
 891		closeCode := CloseNoStatusReceived
 892		closeText := ""
 893		if len(payload) >= 2 {
 894			closeCode = int(binary.BigEndian.Uint16(payload))
 895			if !isValidReceivedCloseCode(closeCode) {
 896				return noFrame, c.handleProtocolError("invalid close code")
 897			}
 898			closeText = string(payload[2:])
 899			if !utf8.ValidString(closeText) {
 900				return noFrame, c.handleProtocolError("invalid utf8 payload in close frame")
 901			}
 902		}
 903		if err := c.handleClose(closeCode, closeText); err != nil {
 904			return noFrame, err
 905		}
 906		return noFrame, &CloseError{Code: closeCode, Text: closeText}
 907	}
 908
 909	return frameType, nil
 910}
 911
 912func (c *Conn) handleProtocolError(message string) error {
 913	c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
 914	return errors.New("websocket: " + message)
 915}
 916
 917// NextReader returns the next data message received from the peer. The
 918// returned messageType is either TextMessage or BinaryMessage.
 919//
 920// There can be at most one open reader on a connection. NextReader discards
 921// the previous message if the application has not already consumed it.
 922//
 923// Applications must break out of the application's read loop when this method
 924// returns a non-nil error value. Errors returned from this method are
 925// permanent. Once this method returns a non-nil error, all subsequent calls to
 926// this method return the same error.
 927func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 928	// Close previous reader, only relevant for decompression.
 929	if c.reader != nil {
 930		c.reader.Close()
 931		c.reader = nil
 932	}
 933
 934	c.messageReader = nil
 935	c.readLength = 0
 936
 937	for c.readErr == nil {
 938		frameType, err := c.advanceFrame()
 939		if err != nil {
 940			c.readErr = hideTempErr(err)
 941			break
 942		}
 943		if frameType == TextMessage || frameType == BinaryMessage {
 944			c.messageReader = &messageReader{c}
 945			c.reader = c.messageReader
 946			if c.readDecompress {
 947				c.reader = c.newDecompressionReader(c.reader)
 948			}
 949			return frameType, c.reader, nil
 950		}
 951	}
 952
 953	// Applications that do handle the error returned from this method spin in
 954	// tight loop on connection failure. To help application developers detect
 955	// this error, panic on repeated reads to the failed connection.
 956	c.readErrCount++
 957	if c.readErrCount >= 1000 {
 958		panic("repeated read on failed websocket connection")
 959	}
 960
 961	return noFrame, nil, c.readErr
 962}
 963
 964type messageReader struct{ c *Conn }
 965
// Read reads payload bytes of the current message into b, moving on to the
// next frame as each one is exhausted. Control frames that arrive between
// frames are handled inside advanceFrame.
//
// It returns io.EOF at the end of the message, and also when r is no longer
// the connection's current reader. Errors are sticky because they are stored
// in c.readErr. An io.EOF from the network before the message is complete is
// reported as errUnexpectedEOF.
func (r *messageReader) Read(b []byte) (int, error) {
	c := r.c
	if c.messageReader != r {
		return 0, io.EOF
	}

	for c.readErr == nil {

		if c.readRemaining > 0 {
			// Never read past the end of the current frame.
			if int64(len(b)) > c.readRemaining {
				b = b[:c.readRemaining]
			}
			n, err := c.br.Read(b)
			c.readErr = hideTempErr(err)
			if c.isServer {
				// Frames from a client are masked. Unmask the bytes just read,
				// continuing from the key position where the previous Read
				// stopped (maskBytes presumably returns the next position).
				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
			}
			c.readRemaining -= int64(n)
			if c.readRemaining > 0 && c.readErr == io.EOF {
				c.readErr = errUnexpectedEOF
			}
			return n, c.readErr
		}

		// The current frame is exhausted; stop if it ended the message.
		if c.readFinal {
			c.messageReader = nil
			return 0, io.EOF
		}

		frameType, err := c.advanceFrame()
		switch {
		case err != nil:
			c.readErr = hideTempErr(err)
		case frameType == TextMessage || frameType == BinaryMessage:
			c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
		}
	}

	err := c.readErr
	if err == io.EOF && c.messageReader == r {
		err = errUnexpectedEOF
	}
	return 0, err
}
1010
1011func (r *messageReader) Close() error {
1012	return nil
1013}
1014
1015// ReadMessage is a helper method for getting a reader using NextReader and
1016// reading from that reader to a buffer.
1017func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
1018	var r io.Reader
1019	messageType, r, err = c.NextReader()
1020	if err != nil {
1021		return messageType, nil, err
1022	}
1023	p, err = ioutil.ReadAll(r)
1024	return messageType, p, err
1025}
1026
1027// SetReadDeadline sets the read deadline on the underlying network connection.
1028// After a read has timed out, the websocket connection state is corrupt and
1029// all future reads will return an error. A zero value for t means reads will
1030// not time out.
1031func (c *Conn) SetReadDeadline(t time.Time) error {
1032	return c.conn.SetReadDeadline(t)
1033}
1034
1035// SetReadLimit sets the maximum size for a message read from the peer. If a
1036// message exceeds the limit, the connection sends a close frame to the peer
1037// and returns ErrReadLimit to the application.
1038func (c *Conn) SetReadLimit(limit int64) {
1039	c.readLimit = limit
1040}
1041
1042// CloseHandler returns the current close handler
1043func (c *Conn) CloseHandler() func(code int, text string) error {
1044	return c.handleClose
1045}
1046
1047// SetCloseHandler sets the handler for close messages received from the peer.
1048// The code argument to h is the received close code or CloseNoStatusReceived
1049// if the close message is empty. The default close handler sends a close frame
1050// back to the peer.
1051//
1052// The application must read the connection to process close messages as
1053// described in the section on Control Frames above.
1054//
1055// The connection read methods return a CloseError when a close frame is
1056// received. Most applications should handle close messages as part of their
1057// normal error handling. Applications should only set a close handler when the
1058// application must perform some action before sending a close frame back to
1059// the peer.
1060func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
1061	if h == nil {
1062		h = func(code int, text string) error {
1063			message := []byte{}
1064			if code != CloseNoStatusReceived {
1065				message = FormatCloseMessage(code, "")
1066			}
1067			c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
1068			return nil
1069		}
1070	}
1071	c.handleClose = h
1072}
1073
1074// PingHandler returns the current ping handler
1075func (c *Conn) PingHandler() func(appData string) error {
1076	return c.handlePing
1077}
1078
1079// SetPingHandler sets the handler for ping messages received from the peer.
1080// The appData argument to h is the PING frame application data. The default
1081// ping handler sends a pong to the peer.
1082//
1083// The application must read the connection to process ping messages as
1084// described in the section on Control Frames above.
1085func (c *Conn) SetPingHandler(h func(appData string) error) {
1086	if h == nil {
1087		h = func(message string) error {
1088			err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
1089			if err == ErrCloseSent {
1090				return nil
1091			} else if e, ok := err.(net.Error); ok && e.Temporary() {
1092				return nil
1093			}
1094			return err
1095		}
1096	}
1097	c.handlePing = h
1098}
1099
1100// PongHandler returns the current pong handler
1101func (c *Conn) PongHandler() func(appData string) error {
1102	return c.handlePong
1103}
1104
1105// SetPongHandler sets the handler for pong messages received from the peer.
1106// The appData argument to h is the PONG frame application data. The default
1107// pong handler does nothing.
1108//
1109// The application must read the connection to process ping messages as
1110// described in the section on Control Frames above.
1111func (c *Conn) SetPongHandler(h func(appData string) error) {
1112	if h == nil {
1113		h = func(string) error { return nil }
1114	}
1115	c.handlePong = h
1116}
1117
1118// UnderlyingConn returns the internal net.Conn. This can be used to further
1119// modifications to connection specific flags.
1120func (c *Conn) UnderlyingConn() net.Conn {
1121	return c.conn
1122}
1123
1124// EnableWriteCompression enables and disables write compression of
1125// subsequent text and binary messages. This function is a noop if
1126// compression was not negotiated with the peer.
1127func (c *Conn) EnableWriteCompression(enable bool) {
1128	c.enableWriteCompression = enable
1129}
1130
1131// SetCompressionLevel sets the flate compression level for subsequent text and
1132// binary messages. This function is a noop if compression was not negotiated
1133// with the peer. See the compress/flate package for a description of
1134// compression levels.
1135func (c *Conn) SetCompressionLevel(level int) error {
1136	if !isValidCompressionLevel(level) {
1137		return errors.New("websocket: invalid compression level")
1138	}
1139	c.compressionLevel = level
1140	return nil
1141}
1142
// FormatCloseMessage formats closeCode and text as a WebSocket close message.
//
// The payload is closeCode as a 2-byte big-endian integer followed by text.
// For CloseNoStatusReceived (1005), an empty, non-nil payload is returned and
// text is ignored. RFC 6455, Section 7.4.1 reserves that code and forbids
// sending it in a close frame; "no status" is signalled by an empty close
// frame instead.
func FormatCloseMessage(closeCode int, text string) []byte {
	if closeCode == 1005 { // CloseNoStatusReceived
		return []byte{}
	}
	buf := make([]byte, 2+len(text))
	binary.BigEndian.PutUint16(buf, uint16(closeCode))
	copy(buf[2:], text)
	return buf
}