main.go

  1package main
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log"
  8	"math/rand"
  9	"net/http"
 10	"os"
 11	"time"
 12
 13	"github.com/go-chi/chi"
 14	"github.com/go-chi/chi/middleware"
 15	"github.com/jessevdk/go-flags"
 16	"github.com/posener/ctxutil"
 17	"github.com/prometheus/client_golang/prometheus/promhttp"
 18	"github.com/tomwright/queryparam/v4"
 19	"github.com/zikaeroh/codies/internal/protocol"
 20	"github.com/zikaeroh/codies/internal/server"
 21	"github.com/zikaeroh/codies/internal/version"
 22	"golang.org/x/sync/errgroup"
 23	"nhooyr.io/websocket"
 24)
 25
 26var args = struct {
 27	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
 28	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
 29	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
 30	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
 31}{
 32	Addr: ":5000",
 33}
 34
 35var wsOpts *websocket.AcceptOptions
 36
 37func main() {
 38	rand.Seed(time.Now().Unix())
 39	log.SetFlags(log.LstdFlags | log.Lshortfile)
 40
 41	if _, err := flags.Parse(&args); err != nil {
 42		// Default flag parser prints messages, so just exit.
 43		os.Exit(1)
 44	}
 45
 46	if !args.Prod && !args.Debug {
 47		log.Fatal("missing required option --prod or --debug")
 48	} else if args.Prod && args.Debug {
 49		log.Fatal("must specify either --prod or --debug")
 50	}
 51
 52	log.Printf("starting codies server, version %s", version.Version())
 53
 54	wsOpts = &websocket.AcceptOptions{
 55		OriginPatterns:  args.Origins,
 56		CompressionMode: websocket.CompressionContextTakeover,
 57	}
 58
 59	if args.Debug {
 60		log.Println("starting in debug mode, allowing any WebSocket origin host")
 61		wsOpts.InsecureSkipVerify = true
 62	} else {
 63		if !version.VersionSet() {
 64			log.Fatal("running production build without version set")
 65		}
 66	}
 67
 68	g, ctx := errgroup.WithContext(ctxutil.Interrupt())
 69
 70	srv := server.NewServer()
 71
 72	r := chi.NewMux()
 73
 74	r.Use(func(next http.Handler) http.Handler {
 75		return promhttp.InstrumentHandlerCounter(metricRequest, next)
 76	})
 77
 78	r.Use(middleware.Heartbeat("/ping"))
 79	r.Use(middleware.Recoverer)
 80	r.NotFound(staticHandler().ServeHTTP)
 81
 82	r.Group(func(r chi.Router) {
 83		r.Use(middleware.NoCache)
 84
 85		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
 86			w.Header().Add("Content-Type", "application/json")
 87			_ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
 88		})
 89
 90		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
 91			rooms, clients := srv.Stats()
 92
 93			enc := json.NewEncoder(w)
 94			enc.SetIndent("", "    ")
 95			_ = enc.Encode(&protocol.StatsResponse{
 96				Rooms:   rooms,
 97				Clients: clients,
 98			})
 99		})
100
101		r.Group(func(r chi.Router) {
102			if !args.Debug {
103				r.Use(checkVersion)
104			}
105
106			r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
107				query := &protocol.ExistsQuery{}
108				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
109					httpErr(w, http.StatusBadRequest)
110					return
111				}
112
113				room := srv.FindRoomByID(query.RoomID)
114				if room == nil {
115					w.WriteHeader(http.StatusNotFound)
116				} else {
117					w.WriteHeader(http.StatusOK)
118				}
119
120				_, _ = w.Write([]byte("."))
121			})
122
123			r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
124				defer r.Body.Close()
125
126				req := &protocol.RoomRequest{}
127				if err := json.NewDecoder(r.Body).Decode(req); err != nil {
128					httpErr(w, http.StatusBadRequest)
129					return
130				}
131
132				w.Header().Add("Content-Type", "application/json")
133
134				if msg, valid := req.Valid(); !valid {
135					resp := &protocol.RoomResponse{
136						Error: stringPtr(msg),
137					}
138					w.WriteHeader(http.StatusBadRequest)
139					_ = json.NewEncoder(w).Encode(resp)
140					return
141				}
142
143				resp := &protocol.RoomResponse{}
144
145				if req.Create {
146					room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
147					if err != nil {
148						switch err {
149						case server.ErrRoomExists:
150							resp.Error = stringPtr("Room already exists.")
151							w.WriteHeader(http.StatusBadRequest)
152						case server.ErrTooManyRooms:
153							resp.Error = stringPtr("Too many rooms.")
154							w.WriteHeader(http.StatusServiceUnavailable)
155						default:
156							resp.Error = stringPtr("An unknown error occurred.")
157							w.WriteHeader(http.StatusInternalServerError)
158						}
159					} else {
160						resp.ID = &room.ID
161						w.WriteHeader(http.StatusOK)
162					}
163				} else {
164					room := srv.FindRoom(req.RoomName)
165					if room == nil || room.Password != req.RoomPass {
166						resp.Error = stringPtr("Room not found or password does not match.")
167						w.WriteHeader(http.StatusNotFound)
168					} else {
169						resp.ID = &room.ID
170						w.WriteHeader(http.StatusOK)
171					}
172				}
173
174				_ = json.NewEncoder(w).Encode(resp)
175			})
176
177			r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
178				query := &protocol.WSQuery{}
179				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
180					httpErr(w, http.StatusBadRequest)
181					return
182				}
183
184				if _, valid := query.Valid(); !valid {
185					httpErr(w, http.StatusBadRequest)
186					return
187				}
188
189				room := srv.FindRoomByID(query.RoomID)
190				if room == nil {
191					httpErr(w, http.StatusNotFound)
192					return
193				}
194
195				c, err := websocket.Accept(w, r, wsOpts)
196				if err != nil {
197					log.Println(err)
198					return
199				}
200
201				g.Go(func() error {
202					room.HandleConn(query.PlayerID, query.Nickname, c)
203					return nil
204				})
205			})
206		})
207	})
208
209	g.Go(func() error {
210		return srv.Run(ctx)
211	})
212
213	runServer(ctx, g, args.Addr, r)
214
215	if args.Prod {
216		runServer(ctx, g, ":2112", prometheusHandler())
217	}
218
219	log.Fatal(g.Wait())
220}
221
222func staticHandler() http.Handler {
223	fs := http.Dir("./frontend/build")
224	fsh := http.FileServer(fs)
225
226	r := chi.NewMux()
227	r.Use(middleware.Compress(5))
228
229	r.Handle("/static/*", fsh)
230	r.Handle("/favicon/*", fsh)
231
232	r.Group(func(r chi.Router) {
233		r.Use(middleware.NoCache)
234		r.Handle("/*", fsh)
235	})
236
237	return r
238}
239
240func checkVersion(next http.Handler) http.Handler {
241	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
242		want := version.Version()
243
244		toCheck := []string{
245			r.Header.Get("X-CODIES-VERSION"),
246			r.URL.Query().Get("codiesVersion"),
247		}
248
249		for _, got := range toCheck {
250			if got == want {
251				next.ServeHTTP(w, r)
252				return
253			}
254		}
255
256		reason := fmt.Sprintf("client version too old, please reload to get %s", want)
257
258		if r.Header.Get("Upgrade") == "websocket" {
259			c, err := websocket.Accept(w, r, wsOpts)
260			if err != nil {
261				log.Println(err)
262				return
263			}
264			c.Close(4418, reason)
265			return
266		}
267
268		w.WriteHeader(http.StatusTeapot)
269		fmt.Fprint(w, reason)
270	})
271}
272
273func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
274	httpSrv := http.Server{Addr: addr, Handler: handler}
275
276	g.Go(func() error {
277		<-ctx.Done()
278
279		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
280		defer cancel()
281
282		return httpSrv.Shutdown(ctx)
283	})
284
285	g.Go(func() error {
286		return httpSrv.ListenAndServe()
287	})
288}
289
290func prometheusHandler() http.Handler {
291	mux := http.NewServeMux()
292	mux.Handle("/metrics", promhttp.Handler())
293	return mux
294}