main.go

  1package main
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/pkger"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/responder"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
 27
 28var args = struct {
 29	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
 30	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
 31	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
 32	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
 33}{
 34	Addr: ":5000",
 35}
 36
 37var wsOpts *websocket.AcceptOptions
 38
 39func main() {
 40	rand.Seed(time.Now().Unix())
 41	log.SetFlags(log.LstdFlags | log.Lshortfile)
 42
 43	if _, err := flags.Parse(&args); err != nil {
 44		// Default flag parser prints messages, so just exit.
 45		os.Exit(1)
 46	}
 47
 48	if !args.Prod && !args.Debug {
 49		log.Fatal("missing required option --prod or --debug")
 50	} else if args.Prod && args.Debug {
 51		log.Fatal("must specify either --prod or --debug")
 52	}
 53
 54	log.Printf("starting codies server, version %s", version.Version())
 55
 56	wsOpts = &websocket.AcceptOptions{
 57		OriginPatterns:  args.Origins,
 58		CompressionMode: websocket.CompressionContextTakeover,
 59	}
 60
 61	if args.Debug {
 62		log.Println("starting in debug mode, allowing any WebSocket origin host")
 63		wsOpts.InsecureSkipVerify = true
 64	} else {
 65		if !version.VersionSet() {
 66			log.Fatal("running production build without version set")
 67		}
 68	}
 69
 70	g, ctx := errgroup.WithContext(ctxutil.Interrupt())
 71
 72	srv := server.NewServer()
 73
 74	r := chi.NewMux()
 75
 76	r.Use(func(next http.Handler) http.Handler {
 77		return promhttp.InstrumentHandlerCounter(metricRequest, next)
 78	})
 79
 80	r.Use(middleware.Heartbeat("/ping"))
 81	r.Use(middleware.Recoverer)
 82	r.NotFound(staticHandler().ServeHTTP)
 83
 84	r.Group(func(r chi.Router) {
 85		r.Use(middleware.NoCache)
 86
 87		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
 88			responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()}))
 89		})
 90
 91		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
 92			rooms, clients := srv.Stats()
 93			responder.Respond(w,
 94				responder.Body(&protocol.StatsResponse{
 95					Rooms:   rooms,
 96					Clients: clients,
 97				}),
 98				responder.Pretty(true),
 99			)
100		})
101
102		r.Group(func(r chi.Router) {
103			if !args.Debug {
104				r.Use(checkVersion)
105			}
106
107			r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
108				query := &protocol.ExistsQuery{}
109				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
110					responder.Respond(w, responder.Status(http.StatusBadRequest))
111					return
112				}
113
114				room := srv.FindRoomByID(query.RoomID)
115				if room == nil {
116					responder.Respond(w, responder.Status(http.StatusNotFound))
117				} else {
118					responder.Respond(w, responder.Status(http.StatusOK))
119				}
120			})
121
122			r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
123				defer r.Body.Close()
124
125				req := &protocol.RoomRequest{}
126				if err := json.NewDecoder(r.Body).Decode(req); err != nil {
127					responder.Respond(w, responder.Status(http.StatusBadRequest))
128					return
129				}
130
131				if msg, valid := req.Valid(); !valid {
132					responder.Respond(w,
133						responder.Status(http.StatusBadRequest),
134						responder.Body(&protocol.RoomResponse{
135							Error: stringPtr(msg),
136						}),
137					)
138					return
139				}
140
141				var room *server.Room
142				if req.Create {
143					var err error
144					room, err = srv.CreateRoom(req.RoomName, req.RoomPass)
145					if err != nil {
146						switch err {
147						case server.ErrRoomExists:
148							responder.Respond(w,
149								responder.Status(http.StatusBadRequest),
150								responder.Body(&protocol.RoomResponse{
151									Error: stringPtr("Room already exists."),
152								}),
153							)
154						case server.ErrTooManyRooms:
155							responder.Respond(w,
156								responder.Status(http.StatusServiceUnavailable),
157								responder.Body(&protocol.RoomResponse{
158									Error: stringPtr("Too many rooms."),
159								}),
160							)
161						default:
162							responder.Respond(w,
163								responder.Status(http.StatusInternalServerError),
164								responder.Body(&protocol.RoomResponse{
165									Error: stringPtr("An unknown error occurred."),
166								}),
167							)
168						}
169						return
170					}
171				} else {
172					room = srv.FindRoom(req.RoomName)
173					if room == nil || room.Password != req.RoomPass {
174						responder.Respond(w,
175							responder.Status(http.StatusNotFound),
176							responder.Body(&protocol.RoomResponse{
177								Error: stringPtr("Room not found or password does not match."),
178							}),
179						)
180						return
181					}
182				}
183
184				responder.Respond(w, responder.Body(&protocol.RoomResponse{
185					ID: &room.ID,
186				}))
187			})
188
189			r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
190				query := &protocol.WSQuery{}
191				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
192					responder.Respond(w, responder.Status(http.StatusBadRequest))
193					return
194				}
195
196				if _, valid := query.Valid(); !valid {
197					responder.Respond(w, responder.Status(http.StatusBadRequest))
198					return
199				}
200
201				room := srv.FindRoomByID(query.RoomID)
202				if room == nil {
203					responder.Respond(w, responder.Status(http.StatusBadRequest))
204					return
205				}
206
207				c, err := websocket.Accept(w, r, wsOpts)
208				if err != nil {
209					return
210				}
211
212				g.Go(func() error {
213					room.HandleConn(ctx, query.PlayerID, query.Nickname, c)
214					return nil
215				})
216			})
217		})
218	})
219
220	g.Go(func() error {
221		return srv.Run(ctx)
222	})
223
224	runServer(ctx, g, args.Addr, r)
225
226	if args.Prod {
227		runServer(ctx, g, ":2112", prometheusHandler())
228	}
229
230	log.Fatal(g.Wait())
231}
232
233func staticHandler() http.Handler {
234	fs := pkger.Dir("/frontend/build")
235	fsh := http.FileServer(fs)
236
237	r := chi.NewMux()
238	r.Use(middleware.Compress(5))
239
240	r.Handle("/static/*", fsh)
241	r.Handle("/favicon/*", fsh)
242
243	r.Group(func(r chi.Router) {
244		r.Use(middleware.NoCache)
245		r.Handle("/*", fsh)
246	})
247
248	return r
249}
250
251func checkVersion(next http.Handler) http.Handler {
252	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
253		want := version.Version()
254
255		toCheck := []string{
256			r.Header.Get("X-CODIES-VERSION"),
257			r.URL.Query().Get("codiesVersion"),
258		}
259
260		for _, got := range toCheck {
261			if got == want {
262				next.ServeHTTP(w, r)
263				return
264			}
265		}
266
267		reason := fmt.Sprintf("client version too old, please reload to get %s", want)
268
269		if r.Header.Get("Upgrade") == "websocket" {
270			c, err := websocket.Accept(w, r, wsOpts)
271			if err != nil {
272				return
273			}
274			c.Close(4418, reason)
275			return
276		}
277
278		w.WriteHeader(http.StatusTeapot)
279		fmt.Fprint(w, reason)
280	})
281}
282
283func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
284	httpSrv := http.Server{Addr: addr, Handler: handler}
285
286	g.Go(func() error {
287		<-ctx.Done()
288
289		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
290		defer cancel()
291
292		return httpSrv.Shutdown(ctx)
293	})
294
295	g.Go(func() error {
296		return httpSrv.ListenAndServe()
297	})
298}
299
300func prometheusHandler() http.Handler {
301	mux := http.NewServeMux()
302	mux.Handle("/metrics", promhttp.Handler())
303	return mux
304}