compression.go

  1// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package websocket
  6
  7import (
  8	"compress/flate"
  9	"errors"
 10	"io"
 11	"strings"
 12	"sync"
 13)
 14
 15const (
 16	minCompressionLevel     = -2 // flate.HuffmanOnly not defined in Go < 1.6
 17	maxCompressionLevel     = flate.BestCompression
 18	defaultCompressionLevel = 1
 19)
 20
 21var (
 22	flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
 23	flateReaderPool  = sync.Pool{New: func() interface{} {
 24		return flate.NewReader(nil)
 25	}}
 26)
 27
 28func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
 29	const tail =
 30	// Add four bytes as specified in RFC
 31	"\x00\x00\xff\xff" +
 32		// Add final block to squelch unexpected EOF error from flate reader.
 33		"\x01\x00\x00\xff\xff"
 34
 35	fr, _ := flateReaderPool.Get().(io.ReadCloser)
 36	fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
 37	return &flateReadWrapper{fr}
 38}
 39
 40func isValidCompressionLevel(level int) bool {
 41	return minCompressionLevel <= level && level <= maxCompressionLevel
 42}
 43
 44func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
 45	p := &flateWriterPools[level-minCompressionLevel]
 46	tw := &truncWriter{w: w}
 47	fw, _ := p.Get().(*flate.Writer)
 48	if fw == nil {
 49		fw, _ = flate.NewWriter(tw, level)
 50	} else {
 51		fw.Reset(tw)
 52	}
 53	return &flateWriteWrapper{fw: fw, tw: tw, p: p}
 54}
 55
 56// truncWriter is an io.Writer that writes all but the last four bytes of the
 57// stream to another io.Writer.
 58type truncWriter struct {
 59	w io.WriteCloser
 60	n int
 61	p [4]byte
 62}
 63
 64func (w *truncWriter) Write(p []byte) (int, error) {
 65	n := 0
 66
 67	// fill buffer first for simplicity.
 68	if w.n < len(w.p) {
 69		n = copy(w.p[w.n:], p)
 70		p = p[n:]
 71		w.n += n
 72		if len(p) == 0 {
 73			return n, nil
 74		}
 75	}
 76
 77	m := len(p)
 78	if m > len(w.p) {
 79		m = len(w.p)
 80	}
 81
 82	if nn, err := w.w.Write(w.p[:m]); err != nil {
 83		return n + nn, err
 84	}
 85
 86	copy(w.p[:], w.p[m:])
 87	copy(w.p[len(w.p)-m:], p[len(p)-m:])
 88	nn, err := w.w.Write(p[:len(p)-m])
 89	return n + nn, err
 90}
 91
 92type flateWriteWrapper struct {
 93	fw *flate.Writer
 94	tw *truncWriter
 95	p  *sync.Pool
 96}
 97
 98func (w *flateWriteWrapper) Write(p []byte) (int, error) {
 99	if w.fw == nil {
100		return 0, errWriteClosed
101	}
102	return w.fw.Write(p)
103}
104
105func (w *flateWriteWrapper) Close() error {
106	if w.fw == nil {
107		return errWriteClosed
108	}
109	err1 := w.fw.Flush()
110	w.p.Put(w.fw)
111	w.fw = nil
112	if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
113		return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
114	}
115	err2 := w.tw.w.Close()
116	if err1 != nil {
117		return err1
118	}
119	return err2
120}
121
122type flateReadWrapper struct {
123	fr io.ReadCloser
124}
125
126func (r *flateReadWrapper) Read(p []byte) (int, error) {
127	if r.fr == nil {
128		return 0, io.ErrClosedPipe
129	}
130	n, err := r.fr.Read(p)
131	if err == io.EOF {
132		// Preemptively place the reader back in the pool. This helps with
133		// scenarios where the application does not call NextReader() soon after
134		// this final read.
135		r.Close()
136	}
137	return n, err
138}
139
140func (r *flateReadWrapper) Close() error {
141	if r.fr == nil {
142		return io.ErrClosedPipe
143	}
144	err := r.fr.Close()
145	flateReaderPool.Put(r.fr)
146	r.fr = nil
147	return err
148}