1// Copyright (c) Microsoft Corporation.
  2// Licensed under the MIT license.
  3
  4// Package local contains a local HTTP server used with interactive authentication.
  5package local
  6
import (
	"context"
	"fmt"
	"html"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"
)
 16
// okPage is the HTML page written to the browser when the redirect carries the
// expected OAuth state and a non-empty authorization code.
var okPage = []byte(`
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8" />
    <title>Authentication Complete</title>
</head>
<body>
    <p>Authentication complete. You can return to the application. Feel free to close this browser tab.</p>
</body>
</html>
`)
 29
// failPage is a fmt format string for the HTML page written to the browser when
// the redirect carries an "error" query parameter. The two %s verbs receive, in
// order, the "error" and "error_description" values. Those values come straight
// from the request URL, so they should be HTML-escaped before substitution.
const failPage = `
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8" />
    <title>Authentication Failed</title>
</head>
<body>
	<p>Authentication failed. You can return to the application. Feel free to close this browser tab.</p>
	<p>Error details: error %s error_description: %s</p>
</body>
</html>
`
 43
// Result is the result from the redirect: either the authorization code sent by
// the authority, or the error that ended the flow.
type Result struct {
	// Code is the code sent by the authority server (the "code" query
	// parameter of the redirect). It is empty when Err is set.
	Code string
	// Err is set if there was an error: the authority redirected with an
	// "error" parameter, the redirect was malformed (missing or mismatched
	// state, missing code), the HTTP server stopped serving, or the context
	// passed to Server.Result ended first.
	Err error
}
 51
// Server is a local HTTP server that receives the authority's redirect and
// hands its outcome to the caller through the Result method.
type Server struct {
	// Addr is the address the server is listening on, in the form
	// "http://localhost:<port>".
	Addr     string
	// resultCh carries the redirect outcome to Result. New creates it with a
	// buffer of one; writers use non-blocking sends, so extra results are
	// dropped while one is still pending.
	resultCh chan Result
	// s is the underlying net/http server that runs handler.
	s        *http.Server
	// reqState is the OAuth state value the redirect must echo back.
	reqState string
}
 60
 61// New creates a local HTTP server and starts it.
 62func New(reqState string, port int) (*Server, error) {
 63	var l net.Listener
 64	var err error
 65	var portStr string
 66	if port > 0 {
 67		// use port provided by caller
 68		l, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
 69		portStr = strconv.FormatInt(int64(port), 10)
 70	} else {
 71		// find a free port
 72		for i := 0; i < 10; i++ {
 73			l, err = net.Listen("tcp", "localhost:0")
 74			if err != nil {
 75				continue
 76			}
 77			addr := l.Addr().String()
 78			portStr = addr[strings.LastIndex(addr, ":")+1:]
 79			break
 80		}
 81	}
 82	if err != nil {
 83		return nil, err
 84	}
 85
 86	serv := &Server{
 87		Addr:     fmt.Sprintf("http://localhost:%s", portStr),
 88		s:        &http.Server{Addr: "localhost:0", ReadHeaderTimeout: time.Second},
 89		reqState: reqState,
 90		resultCh: make(chan Result, 1),
 91	}
 92	serv.s.Handler = http.HandlerFunc(serv.handler)
 93
 94	if err := serv.start(l); err != nil {
 95		return nil, err
 96	}
 97
 98	return serv, nil
 99}
100
101func (s *Server) start(l net.Listener) error {
102	go func() {
103		err := s.s.Serve(l)
104		if err != nil {
105			select {
106			case s.resultCh <- Result{Err: err}:
107			default:
108			}
109		}
110	}()
111
112	return nil
113}
114
115// Result gets the result of the redirect operation. Once a single result is returned, the server
116// is shutdown. ctx deadline will be honored.
117func (s *Server) Result(ctx context.Context) Result {
118	select {
119	case <-ctx.Done():
120		return Result{Err: ctx.Err()}
121	case r := <-s.resultCh:
122		return r
123	}
124}
125
126// Shutdown shuts down the server.
127func (s *Server) Shutdown() {
128	// Note: You might get clever and think you can do this in handler() as a defer, you can't.
129	_ = s.s.Shutdown(context.Background())
130}
131
132func (s *Server) putResult(r Result) {
133	select {
134	case s.resultCh <- r:
135	default:
136	}
137}
138
139func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
140	q := r.URL.Query()
141
142	headerErr := q.Get("error")
143	if headerErr != "" {
144		desc := q.Get("error_description")
145		// Note: It is a little weird we handle some errors by not going to the failPage. If they all should,
146		// change this to s.error() and make s.error() write the failPage instead of an error code.
147		_, _ = w.Write([]byte(fmt.Sprintf(failPage, headerErr, desc)))
148		s.putResult(Result{Err: fmt.Errorf(desc)})
149		return
150	}
151
152	respState := q.Get("state")
153	switch respState {
154	case s.reqState:
155	case "":
156		s.error(w, http.StatusInternalServerError, "server didn't send OAuth state")
157		return
158	default:
159		s.error(w, http.StatusInternalServerError, "mismatched OAuth state, req(%s), resp(%s)", s.reqState, respState)
160		return
161	}
162
163	code := q.Get("code")
164	if code == "" {
165		s.error(w, http.StatusInternalServerError, "authorization code missing in query string")
166		return
167	}
168
169	_, _ = w.Write(okPage)
170	s.putResult(Result{Code: code})
171}
172
173func (s *Server) error(w http.ResponseWriter, code int, str string, i ...interface{}) {
174	err := fmt.Errorf(str, i...)
175	http.Error(w, err.Error(), code)
176	s.putResult(Result{Err: err})
177}