1package server
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "net/url"
8 "os"
9 "path/filepath"
10 "strings"
11 "text/template"
12 "time"
13
14 "github.com/charmbracelet/soft-serve/server/backend"
15 "github.com/charmbracelet/soft-serve/server/config"
16 "github.com/charmbracelet/soft-serve/server/utils"
17 "github.com/dustin/go-humanize"
18 "goji.io"
19 "goji.io/pat"
20 "goji.io/pattern"
21)
22
// logWriter is a wrapper around http.ResponseWriter that allows us to capture
// the HTTP status code and bytes written to the response.
//
// The creator should initialize code to http.StatusOK: a handler that never
// calls WriteHeader gets an implicit 200 from net/http on its first Write.
type logWriter struct {
	http.ResponseWriter
	code, bytes int

	// wroteHeader reports whether the final (non-1xx) status has been
	// committed, either explicitly via WriteHeader or implicitly by Write.
	wroteHeader bool
}

// Write forwards p to the wrapped writer and adds the number of bytes it
// accepted to r.bytes. As in net/http, the first Write commits the status,
// so later WriteHeader calls no longer change the recorded code.
func (r *logWriter) Write(p []byte) (int, error) {
	r.wroteHeader = true
	written, err := r.ResponseWriter.Write(p)
	r.bytes += written
	return written, err
}

// WriteHeader forwards code to the wrapped writer and records it.
//
// Only the first final status is recorded. net/http ignores any WriteHeader
// call made after the status has been committed, so recording later codes
// would make the access log disagree with what the client actually received
// (e.g. a 500 from http.Error issued after a 200 body was already streamed).
// 1xx informational codes other than 101 do not commit the status, mirroring
// net/http.
func (r *logWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !r.wroteHeader && !informational {
		r.code = code
		r.wroteHeader = true
	}
	r.ResponseWriter.WriteHeader(code)
}

// Unwrap returns the wrapped http.ResponseWriter. http.ResponseController
// uses it to reach optional interfaces (flushing, hijacking, deadlines) that
// the embedded interface field does not expose.
func (r *logWriter) Unwrap() http.ResponseWriter {
	return r.ResponseWriter
}
42
43func loggingMiddleware(next http.Handler) http.Handler {
44 logger := logger.WithPrefix("server.http")
45 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
46 start := time.Now()
47 writer := &logWriter{code: http.StatusOK, ResponseWriter: w}
48 logger.Debug("request",
49 "method", r.Method,
50 "uri", r.RequestURI,
51 "addr", r.RemoteAddr)
52 next.ServeHTTP(writer, r)
53 elapsed := time.Since(start)
54 logger.Debug("response",
55 "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)),
56 "bytes", humanize.Bytes(uint64(writer.bytes)),
57 "time", elapsed)
58 })
59}
60
// HTTPServer is an http server.
//
// It answers go-get discovery requests and serves raw repository files to
// git's dumb HTTP protocol for repositories stored under <cfg.DataPath>/repos.
// Construct it with NewHTTPServer.
type HTTPServer struct {
	cfg        *config.Config
	server     *http.Server // underlying net/http server; Close, ListenAndServe and Shutdown delegate to it
	dirHandler http.Handler // static file handler rooted at <cfg.DataPath>/repos
}
67
68func NewHTTPServer(cfg *config.Config) (*HTTPServer, error) {
69 mux := goji.NewMux()
70 s := &HTTPServer{
71 cfg: cfg,
72 dirHandler: http.FileServer(http.Dir(filepath.Join(cfg.DataPath, "repos"))),
73 server: &http.Server{
74 Addr: cfg.HTTP.ListenAddr,
75 Handler: mux,
76 ReadHeaderTimeout: time.Second * 10,
77 ReadTimeout: time.Second * 10,
78 WriteTimeout: time.Second * 10,
79 MaxHeaderBytes: http.DefaultMaxHeaderBytes,
80 },
81 }
82
83 mux.Use(loggingMiddleware)
84 mux.HandleFunc(pat.Get("/:repo"), s.repoIndexHandler)
85 mux.HandleFunc(pat.Get("/:repo/*"), s.dumbGitHandler)
86 return s, nil
87}
88
89// Close closes the HTTP server.
90func (s *HTTPServer) Close() error {
91 return s.server.Close()
92}
93
94// ListenAndServe starts the HTTP server.
95func (s *HTTPServer) ListenAndServe() error {
96 return s.server.ListenAndServe()
97}
98
99// Shutdown gracefully shuts down the HTTP server.
100func (s *HTTPServer) Shutdown(ctx context.Context) error {
101 return s.server.Shutdown(ctx)
102}
103
// repoIndexHTMLTpl renders the page answered to `go get` discovery requests.
// It carries the go-import meta tag that maps <ImportRoot>/<Repo> to the
// repository's SSH URL, plus a refresh redirect to the package docs for
// browsers. Expected data fields: Repo, ImportRoot and Config.SSH.PublicURL.
//
// text/template performs no contextual escaping, so every interpolated value
// is piped through the builtin html function. Repo is derived from the
// request URL and must not be able to break out of the attribute values.
var repoIndexHTMLTpl = template.Must(template.New("index").Parse(`<!DOCTYPE html>
<html lang="en">
<head>
	<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
	<meta http-equiv="refresh" content="0; url=https://godoc.org/{{html .ImportRoot}}/{{html .Repo}}">
	<meta name="go-import" content="{{html .ImportRoot}}/{{html .Repo}} git {{html .Config.SSH.PublicURL}}/{{html .Repo}}">
</head>
<body>
Redirecting to docs at <a href="https://godoc.org/{{html .ImportRoot}}/{{html .Repo}}">godoc.org/{{html .ImportRoot}}/{{html .Repo}}</a>...
</body>
</html>`))
115
116func (s *HTTPServer) repoIndexHandler(w http.ResponseWriter, r *http.Request) {
117 repo := pat.Param(r, "repo")
118 repo = utils.SanitizeRepo(repo)
119
120 // Only respond to go-get requests
121 if r.URL.Query().Get("go-get") != "1" {
122 http.NotFound(w, r)
123 return
124 }
125
126 access := s.cfg.Backend.AccessLevel(repo, nil)
127 if access < backend.ReadOnlyAccess {
128 http.NotFound(w, r)
129 return
130 }
131
132 importRoot, err := url.Parse(s.cfg.HTTP.PublicURL)
133 if err != nil {
134 http.Error(w, err.Error(), http.StatusInternalServerError)
135 }
136
137 if err := repoIndexHTMLTpl.Execute(w, struct {
138 Repo string
139 Config *config.Config
140 ImportRoot string
141 }{
142 Repo: repo,
143 Config: s.cfg,
144 ImportRoot: importRoot.Host,
145 }); err != nil {
146 http.Error(w, err.Error(), http.StatusInternalServerError)
147 return
148 }
149}
150
151func (s *HTTPServer) dumbGitHandler(w http.ResponseWriter, r *http.Request) {
152 repo := pat.Param(r, "repo")
153 repo = utils.SanitizeRepo(repo) + ".git"
154
155 access := s.cfg.Backend.AccessLevel(repo, nil)
156 if access < backend.ReadOnlyAccess || !s.cfg.Backend.AllowKeyless() {
157 httpStatusError(w, http.StatusUnauthorized)
158 return
159 }
160
161 path := pattern.Path(r.Context())
162 stat, err := os.Stat(filepath.Join(s.cfg.DataPath, "repos", repo, path))
163 // Restrict access to files
164 if err != nil || stat.IsDir() {
165 http.NotFound(w, r)
166 return
167 }
168
169 // Don't allow access to non-git clients
170 ua := r.Header.Get("User-Agent")
171 if !strings.HasPrefix(strings.ToLower(ua), "git") {
172 httpStatusError(w, http.StatusBadRequest)
173 return
174 }
175
176 w.Header().Set("Content-Type", "text/plain; charset=utf-8")
177 w.Header().Set("X-Content-Type-Options", "nosniff")
178 r.URL.Path = fmt.Sprintf("/%s/%s", repo, path)
179 s.dirHandler.ServeHTTP(w, r)
180}
181
// httpStatusError replies with the given status code and a plain-text body of
// the form "<code> <reason>", e.g. "401 Unauthorized".
func httpStatusError(w http.ResponseWriter, status int) {
	body := fmt.Sprint(status) + " " + http.StatusText(status)
	http.Error(w, body, status)
}