main.go

  1package main
  2
import (
	"context"
	"encoding/json"
	"errors"
	"log"
	"math/rand"
	"net/http"
	"os"
	"reflect"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/gofrs/uuid"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/server"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
 24
 25var args = struct {
 26	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
 27	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
 28	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
 29}{
 30	Addr: ":5000",
 31}
 32
 33func main() {
 34	if _, err := flags.Parse(&args); err != nil {
 35		// Default flag parser prints messages, so just exit.
 36		os.Exit(1)
 37	}
 38
 39	wsOpts := &websocket.AcceptOptions{
 40		OriginPatterns: args.Origins,
 41	}
 42
 43	if args.Debug {
 44		log.Println("starting in debug mode, allowing any WebSocket origin host")
 45		wsOpts.OriginPatterns = []string{"*"}
 46	}
 47
 48	rand.Seed(time.Now().Unix())
 49
 50	log.SetFlags(log.LstdFlags | log.Lshortfile)
 51
 52	g, ctx := errgroup.WithContext(ctxutil.Interrupt())
 53
 54	srv := server.NewServer()
 55
 56	r := chi.NewMux()
 57	r.Use(middleware.Heartbeat("/ping"))
 58	r.Use(middleware.Recoverer)
 59
 60	r.Group(func(r chi.Router) {
 61		r.Use(middleware.Compress(5))
 62		fs := http.Dir("./frontend/build")
 63		r.NotFound(http.FileServer(fs).ServeHTTP)
 64	})
 65
 66	r.Group(func(r chi.Router) {
 67		r.Use(middleware.NoCache)
 68
 69		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
 70			w.Header().Add("Content-Type", "application/json")
 71			_ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
 72		})
 73
 74		r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
 75			query := &protocol.ExistsQuery{}
 76			if err := queryparam.Parse(r.URL.Query(), query); err != nil {
 77				httpErr(w, http.StatusBadRequest)
 78				return
 79			}
 80
 81			room := srv.FindRoomByID(query.RoomID)
 82			if room == nil {
 83				w.WriteHeader(http.StatusNotFound)
 84			} else {
 85				w.WriteHeader(http.StatusOK)
 86			}
 87
 88			_, _ = w.Write([]byte("."))
 89		})
 90
 91		r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
 92			defer r.Body.Close()
 93
 94			req := &protocol.RoomRequest{}
 95			if err := json.NewDecoder(r.Body).Decode(req); err != nil {
 96				httpErr(w, http.StatusBadRequest)
 97				return
 98			}
 99
100			if !req.Valid() {
101				httpErr(w, http.StatusBadRequest)
102				return
103			}
104
105			resp := &protocol.RoomResponse{}
106
107			w.Header().Add("Content-Type", "application/json")
108
109			if req.Create {
110				room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
111				if err != nil {
112					switch err {
113					case server.ErrRoomExists:
114						resp.Error = stringPtr("Room already exists.")
115						w.WriteHeader(http.StatusBadRequest)
116					case server.ErrTooManyRooms:
117						resp.Error = stringPtr("Too many rooms.")
118						w.WriteHeader(http.StatusServiceUnavailable)
119					default:
120						resp.Error = stringPtr("An unknown error occurred.")
121						w.WriteHeader(http.StatusInternalServerError)
122					}
123				} else {
124					resp.ID = &room.ID
125					w.WriteHeader(http.StatusOK)
126				}
127			} else {
128				room := srv.FindRoom(req.RoomName)
129				if room == nil || room.Password != req.RoomPass {
130					resp.Error = stringPtr("Room not found or password does not match.")
131					w.WriteHeader(http.StatusNotFound)
132				} else {
133					resp.ID = &room.ID
134					w.WriteHeader(http.StatusOK)
135				}
136			}
137
138			_ = json.NewEncoder(w).Encode(resp)
139		})
140
141		r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
142			query := &protocol.WSQuery{}
143			if err := queryparam.Parse(r.URL.Query(), query); err != nil {
144				httpErr(w, http.StatusBadRequest)
145				return
146			}
147
148			if !query.Valid() {
149				httpErr(w, http.StatusBadRequest)
150				return
151			}
152
153			room := srv.FindRoomByID(query.RoomID)
154			if room == nil {
155				httpErr(w, http.StatusNotFound)
156				return
157			}
158
159			c, err := websocket.Accept(w, r, wsOpts)
160			if err != nil {
161				log.Println(err)
162				return
163			}
164
165			g.Go(func() error {
166				room.HandleConn(query.PlayerID, query.Nickname, c)
167				return nil
168			})
169		})
170
171		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
172			rooms, clients := srv.Stats()
173
174			enc := json.NewEncoder(w)
175			enc.SetIndent("", "    ")
176			_ = enc.Encode(&protocol.StatsResponse{
177				Rooms:   rooms,
178				Clients: clients,
179			})
180		})
181	})
182
183	g.Go(func() error {
184		return srv.Run(ctx)
185	})
186
187	httpSrv := http.Server{Addr: args.Addr, Handler: r}
188
189	g.Go(func() error {
190		<-ctx.Done()
191
192		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
193		defer cancel()
194
195		return httpSrv.Shutdown(ctx)
196	})
197
198	g.Go(func() error {
199		return httpSrv.ListenAndServe()
200	})
201
202	log.Fatal(g.Wait())
203}
204
// httpErr replies with the given status code, using the standard status text
// for that code as the response body.
func httpErr(w http.ResponseWriter, code int) {
	text := http.StatusText(code)
	http.Error(w, text, code)
}
208
// stringPtr returns a pointer to a fresh copy of s, for filling optional
// string fields that are represented as *string.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
212
213func init() {
214	queryparam.DefaultParser.ValueParsers[reflect.TypeOf(uuid.UUID{})] = func(value string, _ string) (reflect.Value, error) {
215		id, err := uuid.FromString(value)
216		return reflect.ValueOf(id), err
217	}
218}