1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT license.
3
4// Package local contains a local HTTP server used with interactive authentication.
5package local
6
7import (
8 "context"
9 "fmt"
10 "net"
11 "net/http"
12 "strconv"
13 "strings"
14 "time"
15)
16
17var okPage = []byte(`
18<!DOCTYPE html>
19<html>
20<head>
21 <meta charset="utf-8" />
22 <title>Authentication Complete</title>
23</head>
24<body>
25 <p>Authentication complete. You can return to the application. Feel free to close this browser tab.</p>
26</body>
27</html>
28`)
29
30const failPage = `
31<!DOCTYPE html>
32<html>
33<head>
34 <meta charset="utf-8" />
35 <title>Authentication Failed</title>
36</head>
37<body>
38 <p>Authentication failed. You can return to the application. Feel free to close this browser tab.</p>
39 <p>Error details: error %s error_description: %s</p>
40</body>
41</html>
42`
43
44// Result is the result from the redirect.
45type Result struct {
46 // Code is the code sent by the authority server.
47 Code string
48 // Err is set if there was an error.
49 Err error
50}
51
52// Server is an HTTP server.
53type Server struct {
54 // Addr is the address the server is listening on.
55 Addr string
56 resultCh chan Result
57 s *http.Server
58 reqState string
59}
60
61// New creates a local HTTP server and starts it.
62func New(reqState string, port int) (*Server, error) {
63 var l net.Listener
64 var err error
65 var portStr string
66 if port > 0 {
67 // use port provided by caller
68 l, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
69 portStr = strconv.FormatInt(int64(port), 10)
70 } else {
71 // find a free port
72 for i := 0; i < 10; i++ {
73 l, err = net.Listen("tcp", "localhost:0")
74 if err != nil {
75 continue
76 }
77 addr := l.Addr().String()
78 portStr = addr[strings.LastIndex(addr, ":")+1:]
79 break
80 }
81 }
82 if err != nil {
83 return nil, err
84 }
85
86 serv := &Server{
87 Addr: fmt.Sprintf("http://localhost:%s", portStr),
88 s: &http.Server{Addr: "localhost:0", ReadHeaderTimeout: time.Second},
89 reqState: reqState,
90 resultCh: make(chan Result, 1),
91 }
92 serv.s.Handler = http.HandlerFunc(serv.handler)
93
94 if err := serv.start(l); err != nil {
95 return nil, err
96 }
97
98 return serv, nil
99}
100
101func (s *Server) start(l net.Listener) error {
102 go func() {
103 err := s.s.Serve(l)
104 if err != nil {
105 select {
106 case s.resultCh <- Result{Err: err}:
107 default:
108 }
109 }
110 }()
111
112 return nil
113}
114
115// Result gets the result of the redirect operation. Once a single result is returned, the server
116// is shutdown. ctx deadline will be honored.
117func (s *Server) Result(ctx context.Context) Result {
118 select {
119 case <-ctx.Done():
120 return Result{Err: ctx.Err()}
121 case r := <-s.resultCh:
122 return r
123 }
124}
125
126// Shutdown shuts down the server.
127func (s *Server) Shutdown() {
128 // Note: You might get clever and think you can do this in handler() as a defer, you can't.
129 _ = s.s.Shutdown(context.Background())
130}
131
132func (s *Server) putResult(r Result) {
133 select {
134 case s.resultCh <- r:
135 default:
136 }
137}
138
139func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
140 q := r.URL.Query()
141
142 headerErr := q.Get("error")
143 if headerErr != "" {
144 desc := q.Get("error_description")
145 // Note: It is a little weird we handle some errors by not going to the failPage. If they all should,
146 // change this to s.error() and make s.error() write the failPage instead of an error code.
147 _, _ = w.Write([]byte(fmt.Sprintf(failPage, headerErr, desc)))
148 s.putResult(Result{Err: fmt.Errorf(desc)})
149 return
150 }
151
152 respState := q.Get("state")
153 switch respState {
154 case s.reqState:
155 case "":
156 s.error(w, http.StatusInternalServerError, "server didn't send OAuth state")
157 return
158 default:
159 s.error(w, http.StatusInternalServerError, "mismatched OAuth state, req(%s), resp(%s)", s.reqState, respState)
160 return
161 }
162
163 code := q.Get("code")
164 if code == "" {
165 s.error(w, http.StatusInternalServerError, "authorization code missing in query string")
166 return
167 }
168
169 _, _ = w.Write(okPage)
170 s.putResult(Result{Code: code})
171}
172
173func (s *Server) error(w http.ResponseWriter, code int, str string, i ...interface{}) {
174 err := fmt.Errorf(str, i...)
175 http.Error(w, err.Error(), code)
176 s.putResult(Result{Err: err})
177}