main.go

  1package main
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/pkger"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/responder"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"github.com/zikaeroh/ctxlog"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
 29
 30var args = struct {
 31	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
 32	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
 33	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
 34	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
 35}{
 36	Addr: ":5000",
 37}
 38
 39var wsOpts *websocket.AcceptOptions
 40
 41func main() {
 42	rand.Seed(time.Now().Unix())
 43
 44	if _, err := flags.Parse(&args); err != nil {
 45		// Default flag parser prints messages, so just exit.
 46		os.Exit(1)
 47	}
 48
 49	if !args.Prod && !args.Debug {
 50		log.Fatal("missing required option --prod or --debug")
 51	} else if args.Prod && args.Debug {
 52		log.Fatal("must specify either --prod or --debug")
 53	}
 54
 55	ctx := ctxutil.Interrupt()
 56
 57	logger := ctxlog.New(args.Debug)
 58	defer zap.RedirectStdLog(logger)()
 59	ctx = ctxlog.WithLogger(ctx, logger)
 60
 61	ctxlog.Info(ctx, "starting", zap.String("version", version.Version()))
 62
 63	wsOpts = &websocket.AcceptOptions{
 64		OriginPatterns:  args.Origins,
 65		CompressionMode: websocket.CompressionContextTakeover,
 66	}
 67
 68	if args.Debug {
 69		ctxlog.Info(ctx, "starting in debug mode, allowing any WebSocket origin host")
 70		wsOpts.InsecureSkipVerify = true
 71	} else {
 72		if !version.VersionSet() {
 73			ctxlog.Fatal(ctx, "running production build without version set")
 74		}
 75	}
 76
 77	g, ctx := errgroup.WithContext(ctx)
 78
 79	srv := server.NewServer()
 80
 81	r := chi.NewMux()
 82
 83	r.Use(func(next http.Handler) http.Handler {
 84		return promhttp.InstrumentHandlerCounter(metricRequest, next)
 85	})
 86
 87	r.Use(middleware.Heartbeat("/ping"))
 88	r.Use(middleware.Recoverer)
 89	r.NotFound(staticHandler().ServeHTTP)
 90
 91	r.Group(func(r chi.Router) {
 92		r.Use(middleware.NoCache)
 93
 94		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
 95			responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()}))
 96		})
 97
 98		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
 99			rooms, clients := srv.Stats()
100			responder.Respond(w,
101				responder.Body(&protocol.StatsResponse{
102					Rooms:   rooms,
103					Clients: clients,
104				}),
105				responder.Pretty(true),
106			)
107		})
108
109		r.Group(func(r chi.Router) {
110			if !args.Debug {
111				r.Use(checkVersion)
112			}
113
114			r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
115				query := &protocol.ExistsQuery{}
116				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
117					responder.Respond(w, responder.Status(http.StatusBadRequest))
118					return
119				}
120
121				room := srv.FindRoomByID(query.RoomID)
122				if room == nil {
123					responder.Respond(w, responder.Status(http.StatusNotFound))
124				} else {
125					responder.Respond(w, responder.Status(http.StatusOK))
126				}
127			})
128
129			r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
130				defer r.Body.Close()
131
132				req := &protocol.RoomRequest{}
133				if err := json.NewDecoder(r.Body).Decode(req); err != nil {
134					responder.Respond(w, responder.Status(http.StatusBadRequest))
135					return
136				}
137
138				if msg, valid := req.Valid(); !valid {
139					responder.Respond(w,
140						responder.Status(http.StatusBadRequest),
141						responder.Body(&protocol.RoomResponse{
142							Error: stringPtr(msg),
143						}),
144					)
145					return
146				}
147
148				var room *server.Room
149				if req.Create {
150					var err error
151					room, err = srv.CreateRoom(ctx, req.RoomName, req.RoomPass)
152					if err != nil {
153						switch err {
154						case server.ErrRoomExists:
155							responder.Respond(w,
156								responder.Status(http.StatusBadRequest),
157								responder.Body(&protocol.RoomResponse{
158									Error: stringPtr("Room already exists."),
159								}),
160							)
161						case server.ErrTooManyRooms:
162							responder.Respond(w,
163								responder.Status(http.StatusServiceUnavailable),
164								responder.Body(&protocol.RoomResponse{
165									Error: stringPtr("Too many rooms."),
166								}),
167							)
168						default:
169							responder.Respond(w,
170								responder.Status(http.StatusInternalServerError),
171								responder.Body(&protocol.RoomResponse{
172									Error: stringPtr("An unknown error occurred."),
173								}),
174							)
175						}
176						return
177					}
178				} else {
179					room = srv.FindRoom(req.RoomName)
180					if room == nil || room.Password != req.RoomPass {
181						responder.Respond(w,
182							responder.Status(http.StatusNotFound),
183							responder.Body(&protocol.RoomResponse{
184								Error: stringPtr("Room not found or password does not match."),
185							}),
186						)
187						return
188					}
189				}
190
191				responder.Respond(w, responder.Body(&protocol.RoomResponse{
192					ID: &room.ID,
193				}))
194			})
195
196			r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
197				query := &protocol.WSQuery{}
198				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
199					responder.Respond(w, responder.Status(http.StatusBadRequest))
200					return
201				}
202
203				if _, valid := query.Valid(); !valid {
204					responder.Respond(w, responder.Status(http.StatusBadRequest))
205					return
206				}
207
208				room := srv.FindRoomByID(query.RoomID)
209				if room == nil {
210					responder.Respond(w, responder.Status(http.StatusBadRequest))
211					return
212				}
213
214				c, err := websocket.Accept(w, r, wsOpts)
215				if err != nil {
216					return
217				}
218
219				g.Go(func() error {
220					room.HandleConn(ctx, query.PlayerID, query.Nickname, c)
221					return nil
222				})
223			})
224		})
225	})
226
227	g.Go(func() error {
228		return srv.Run(ctx)
229	})
230
231	runServer(ctx, g, args.Addr, r)
232
233	if args.Prod {
234		runServer(ctx, g, ":2112", prometheusHandler())
235	}
236
237	exitErr := g.Wait()
238	ctxlog.Fatal(ctx, "exited", zap.Error(exitErr))
239}
240
// staticHandler builds the handler that main installs as the router's
// NotFound handler. It serves files from pkger's "/frontend/build"
// directory and compresses responses at level 5. Paths under /static/ and
// /favicon/ are served without a cache override; every other path is
// wrapped in NoCache.
func staticHandler() http.Handler {
	files := http.FileServer(pkger.Dir("/frontend/build"))

	mux := chi.NewMux()
	mux.Use(middleware.Compress(5))

	for _, pattern := range []string{"/static/*", "/favicon/*"} {
		mux.Handle(pattern, files)
	}

	mux.Group(func(uncached chi.Router) {
		uncached.Use(middleware.NoCache)
		uncached.Handle("/*", files)
	})

	return mux
}
258
// checkVersion is middleware that passes a request to next only when the
// client reports exactly version.Version(). The version is read from either
// the X-CODIES-VERSION header or the codiesVersion query parameter.
//
// On a mismatch, a WebSocket upgrade request is accepted and immediately
// closed with close code 4418 and an explanatory reason. Any other request
// gets HTTP 418 (I'm a teapot) with the reason as the body.
func checkVersion(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		want := version.Version()

		toCheck := []string{
			r.Header.Get("X-CODIES-VERSION"),
			r.URL.Query().Get("codiesVersion"),
		}

		for _, got := range toCheck {
			if got == want {
				next.ServeHTTP(w, r)
				return
			}
		}

		reason := fmt.Sprintf("client version too old, please reload to get %s", want)

		// RFC 6455 requires the Upgrade value to be compared
		// case-insensitively; clients are not obliged to send lowercase.
		if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
			c, err := websocket.Accept(w, r, wsOpts)
			if err != nil {
				return
			}
			// Best effort: the connection is being rejected anyway, so a
			// failed close handshake needs no further handling.
			_ = c.Close(4418, reason)
			return
		}

		w.WriteHeader(http.StatusTeapot)
		fmt.Fprint(w, reason)
	})
}
290
// runServer serves handler on addr as part of g.
//
// It adds two goroutines to g. One runs ListenAndServe. The other waits for
// ctx to be done and then gracefully shuts the server down, allowing up to
// 10 seconds for in-flight requests to finish.
func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
	httpSrv := &http.Server{
		Addr:    addr,
		Handler: handler,
		// Limit how long a client may take to send request headers, so
		// slow-header clients cannot hold connections open indefinitely.
		// net/http resets this deadline once the headers are read.
		ReadHeaderTimeout: 10 * time.Second,
	}

	g.Go(func() error {
		<-ctx.Done()

		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		return httpSrv.Shutdown(shutdownCtx)
	})

	g.Go(func() error {
		err := httpSrv.ListenAndServe()
		if err == http.ErrServerClosed {
			// Returned only after Shutdown above; this is the expected
			// outcome of a graceful stop, not a server failure.
			return nil
		}
		return err
	})
}
307
308func prometheusHandler() http.Handler {
309	mux := http.NewServeMux()
310	mux.Handle("/metrics", promhttp.Handler())
311	return mux
312}