main.go

  1package main
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log"
  8	"math/rand"
  9	"net/http"
 10	"os"
 11	"time"
 12
 13	"github.com/go-chi/chi"
 14	"github.com/go-chi/chi/middleware"
 15	"github.com/jessevdk/go-flags"
 16	"github.com/posener/ctxutil"
 17	"github.com/prometheus/client_golang/prometheus/promhttp"
 18	"github.com/tomwright/queryparam/v4"
 19	"github.com/zikaeroh/codies/internal/protocol"
 20	"github.com/zikaeroh/codies/internal/responder"
 21	"github.com/zikaeroh/codies/internal/server"
 22	"github.com/zikaeroh/codies/internal/version"
 23	"golang.org/x/sync/errgroup"
 24	"nhooyr.io/websocket"
 25)
 26
 27var args = struct {
 28	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
 29	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
 30	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
 31	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
 32}{
 33	Addr: ":5000",
 34}
 35
// wsOpts holds the WebSocket handshake options shared by the /api/ws handler
// and checkVersion. It stays nil until main builds it from the parsed flags,
// so nothing may read it before that point.
var wsOpts *websocket.AcceptOptions
 37
 38func main() {
 39	rand.Seed(time.Now().Unix())
 40	log.SetFlags(log.LstdFlags | log.Lshortfile)
 41
 42	if _, err := flags.Parse(&args); err != nil {
 43		// Default flag parser prints messages, so just exit.
 44		os.Exit(1)
 45	}
 46
 47	if !args.Prod && !args.Debug {
 48		log.Fatal("missing required option --prod or --debug")
 49	} else if args.Prod && args.Debug {
 50		log.Fatal("must specify either --prod or --debug")
 51	}
 52
 53	log.Printf("starting codies server, version %s", version.Version())
 54
 55	wsOpts = &websocket.AcceptOptions{
 56		OriginPatterns:  args.Origins,
 57		CompressionMode: websocket.CompressionContextTakeover,
 58	}
 59
 60	if args.Debug {
 61		log.Println("starting in debug mode, allowing any WebSocket origin host")
 62		wsOpts.InsecureSkipVerify = true
 63	} else {
 64		if !version.VersionSet() {
 65			log.Fatal("running production build without version set")
 66		}
 67	}
 68
 69	g, ctx := errgroup.WithContext(ctxutil.Interrupt())
 70
 71	srv := server.NewServer()
 72
 73	r := chi.NewMux()
 74
 75	r.Use(func(next http.Handler) http.Handler {
 76		return promhttp.InstrumentHandlerCounter(metricRequest, next)
 77	})
 78
 79	r.Use(middleware.Heartbeat("/ping"))
 80	r.Use(middleware.Recoverer)
 81	r.NotFound(staticHandler().ServeHTTP)
 82
 83	r.Group(func(r chi.Router) {
 84		r.Use(middleware.NoCache)
 85
 86		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
 87			responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()}))
 88		})
 89
 90		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
 91			rooms, clients := srv.Stats()
 92			responder.Respond(w,
 93				responder.Body(&protocol.StatsResponse{
 94					Rooms:   rooms,
 95					Clients: clients,
 96				}),
 97				responder.Pretty(true),
 98			)
 99		})
100
101		r.Group(func(r chi.Router) {
102			if !args.Debug {
103				r.Use(checkVersion)
104			}
105
106			r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
107				query := &protocol.ExistsQuery{}
108				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
109					responder.Respond(w, responder.Status(http.StatusBadRequest))
110					return
111				}
112
113				room := srv.FindRoomByID(query.RoomID)
114				if room == nil {
115					responder.Respond(w, responder.Status(http.StatusNotFound))
116				} else {
117					responder.Respond(w, responder.Status(http.StatusOK))
118				}
119			})
120
121			r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
122				defer r.Body.Close()
123
124				req := &protocol.RoomRequest{}
125				if err := json.NewDecoder(r.Body).Decode(req); err != nil {
126					responder.Respond(w, responder.Status(http.StatusBadRequest))
127					return
128				}
129
130				if msg, valid := req.Valid(); !valid {
131					responder.Respond(w,
132						responder.Status(http.StatusBadRequest),
133						responder.Body(&protocol.RoomResponse{
134							Error: stringPtr(msg),
135						}),
136					)
137					return
138				}
139
140				var room *server.Room
141				if req.Create {
142					var err error
143					room, err = srv.CreateRoom(req.RoomName, req.RoomPass)
144					if err != nil {
145						switch err {
146						case server.ErrRoomExists:
147							responder.Respond(w,
148								responder.Status(http.StatusBadRequest),
149								responder.Body(&protocol.RoomResponse{
150									Error: stringPtr("Room already exists."),
151								}),
152							)
153						case server.ErrTooManyRooms:
154							responder.Respond(w,
155								responder.Status(http.StatusServiceUnavailable),
156								responder.Body(&protocol.RoomResponse{
157									Error: stringPtr("Too many rooms."),
158								}),
159							)
160						default:
161							responder.Respond(w,
162								responder.Status(http.StatusInternalServerError),
163								responder.Body(&protocol.RoomResponse{
164									Error: stringPtr("An unknown error occurred."),
165								}),
166							)
167						}
168						return
169					}
170				} else {
171					room = srv.FindRoom(req.RoomName)
172					if room == nil || room.Password != req.RoomPass {
173						responder.Respond(w,
174							responder.Status(http.StatusNotFound),
175							responder.Body(&protocol.RoomResponse{
176								Error: stringPtr("Room not found or password does not match."),
177							}),
178						)
179						return
180					}
181				}
182
183				responder.Respond(w, responder.Body(&protocol.RoomResponse{
184					ID: &room.ID,
185				}))
186			})
187
188			r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
189				query := &protocol.WSQuery{}
190				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
191					responder.Respond(w, responder.Status(http.StatusBadRequest))
192					return
193				}
194
195				if _, valid := query.Valid(); !valid {
196					responder.Respond(w, responder.Status(http.StatusBadRequest))
197					return
198				}
199
200				room := srv.FindRoomByID(query.RoomID)
201				if room == nil {
202					responder.Respond(w, responder.Status(http.StatusBadRequest))
203					return
204				}
205
206				c, err := websocket.Accept(w, r, wsOpts)
207				if err != nil {
208					return
209				}
210
211				g.Go(func() error {
212					room.HandleConn(query.PlayerID, query.Nickname, c)
213					return nil
214				})
215			})
216		})
217	})
218
219	g.Go(func() error {
220		return srv.Run(ctx)
221	})
222
223	runServer(ctx, g, args.Addr, r)
224
225	if args.Prod {
226		runServer(ctx, g, ":2112", prometheusHandler())
227	}
228
229	log.Fatal(g.Wait())
230}
231
232func staticHandler() http.Handler {
233	fs := http.Dir("./frontend/build")
234	fsh := http.FileServer(fs)
235
236	r := chi.NewMux()
237	r.Use(middleware.Compress(5))
238
239	r.Handle("/static/*", fsh)
240	r.Handle("/favicon/*", fsh)
241
242	r.Group(func(r chi.Router) {
243		r.Use(middleware.NoCache)
244		r.Handle("/*", fsh)
245	})
246
247	return r
248}
249
250func checkVersion(next http.Handler) http.Handler {
251	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
252		want := version.Version()
253
254		toCheck := []string{
255			r.Header.Get("X-CODIES-VERSION"),
256			r.URL.Query().Get("codiesVersion"),
257		}
258
259		for _, got := range toCheck {
260			if got == want {
261				next.ServeHTTP(w, r)
262				return
263			}
264		}
265
266		reason := fmt.Sprintf("client version too old, please reload to get %s", want)
267
268		if r.Header.Get("Upgrade") == "websocket" {
269			c, err := websocket.Accept(w, r, wsOpts)
270			if err != nil {
271				return
272			}
273			c.Close(4418, reason)
274			return
275		}
276
277		w.WriteHeader(http.StatusTeapot)
278		fmt.Fprint(w, reason)
279	})
280}
281
282func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
283	httpSrv := http.Server{Addr: addr, Handler: handler}
284
285	g.Go(func() error {
286		<-ctx.Done()
287
288		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
289		defer cancel()
290
291		return httpSrv.Shutdown(ctx)
292	})
293
294	g.Go(func() error {
295		return httpSrv.ListenAndServe()
296	})
297}
298
299func prometheusHandler() http.Handler {
300	mux := http.NewServeMux()
301	mux.Handle("/metrics", promhttp.Handler())
302	return mux
303}