transport.go

  1// Copyright 2014 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package oauth2
  6
  7import (
  8	"errors"
  9	"io"
 10	"net/http"
 11	"sync"
 12)
 13
 14// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
 15// wrapping a base RoundTripper and adding an Authorization header
 16// with a token from the supplied Sources.
 17//
 18// Transport is a low-level mechanism. Most code will use the
 19// higher-level Config.Client method instead.
 20type Transport struct {
 21	// Source supplies the token to add to outgoing requests'
 22	// Authorization headers.
 23	Source TokenSource
 24
 25	// Base is the base RoundTripper used to make HTTP requests.
 26	// If nil, http.DefaultTransport is used.
 27	Base http.RoundTripper
 28
 29	mu     sync.Mutex                      // guards modReq
 30	modReq map[*http.Request]*http.Request // original -> modified
 31}
 32
 33// RoundTrip authorizes and authenticates the request with an
 34// access token from Transport's Source.
 35func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
 36	reqBodyClosed := false
 37	if req.Body != nil {
 38		defer func() {
 39			if !reqBodyClosed {
 40				req.Body.Close()
 41			}
 42		}()
 43	}
 44
 45	if t.Source == nil {
 46		return nil, errors.New("oauth2: Transport's Source is nil")
 47	}
 48	token, err := t.Source.Token()
 49	if err != nil {
 50		return nil, err
 51	}
 52
 53	req2 := cloneRequest(req) // per RoundTripper contract
 54	token.SetAuthHeader(req2)
 55	t.setModReq(req, req2)
 56	res, err := t.base().RoundTrip(req2)
 57
 58	// req.Body is assumed to have been closed by the base RoundTripper.
 59	reqBodyClosed = true
 60
 61	if err != nil {
 62		t.setModReq(req, nil)
 63		return nil, err
 64	}
 65	res.Body = &onEOFReader{
 66		rc: res.Body,
 67		fn: func() { t.setModReq(req, nil) },
 68	}
 69	return res, nil
 70}
 71
 72// CancelRequest cancels an in-flight request by closing its connection.
 73func (t *Transport) CancelRequest(req *http.Request) {
 74	type canceler interface {
 75		CancelRequest(*http.Request)
 76	}
 77	if cr, ok := t.base().(canceler); ok {
 78		t.mu.Lock()
 79		modReq := t.modReq[req]
 80		delete(t.modReq, req)
 81		t.mu.Unlock()
 82		cr.CancelRequest(modReq)
 83	}
 84}
 85
 86func (t *Transport) base() http.RoundTripper {
 87	if t.Base != nil {
 88		return t.Base
 89	}
 90	return http.DefaultTransport
 91}
 92
 93func (t *Transport) setModReq(orig, mod *http.Request) {
 94	t.mu.Lock()
 95	defer t.mu.Unlock()
 96	if t.modReq == nil {
 97		t.modReq = make(map[*http.Request]*http.Request)
 98	}
 99	if mod == nil {
100		delete(t.modReq, orig)
101	} else {
102		t.modReq[orig] = mod
103	}
104}
105
// cloneRequest returns a copy of r that can be modified without touching r.
// Only the struct and its Header map are duplicated (each header value
// slice is copied too); every other reference field, such as URL and Body,
// is shared with r. The clone always has a non-nil Header.
func cloneRequest(r *http.Request) *http.Request {
	clone := *r

	hdr := make(http.Header, len(r.Header))
	for name, values := range r.Header {
		hdr[name] = append([]string(nil), values...)
	}
	clone.Header = hdr

	return &clone
}
119
120type onEOFReader struct {
121	rc io.ReadCloser
122	fn func()
123}
124
125func (r *onEOFReader) Read(p []byte) (n int, err error) {
126	n, err = r.rc.Read(p)
127	if err == io.EOF {
128		r.runFunc()
129	}
130	return
131}
132
133func (r *onEOFReader) Close() error {
134	err := r.rc.Close()
135	r.runFunc()
136	return err
137}
138
139func (r *onEOFReader) runFunc() {
140	if fn := r.fn; fn != nil {
141		fn()
142		r.fn = nil
143	}
144}