main.go

  1package main
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"reflect"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/gofrs/uuid"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
 26
 27var args = struct {
 28	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
 29	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
 30	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
 31	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
 32}{
 33	Addr: ":5000",
 34}
 35
// wsOpts are the options passed to websocket.Accept for every WebSocket
// upgrade (both /api/ws and the version-mismatch path in checkVersion).
// It is nil until main assigns it: main starts from the configured
// --origins, and debug mode replaces them with "*" (any origin).
var wsOpts *websocket.AcceptOptions
 37
 38func main() {
 39	rand.Seed(time.Now().Unix())
 40	log.SetFlags(log.LstdFlags | log.Lshortfile)
 41
 42	if _, err := flags.Parse(&args); err != nil {
 43		// Default flag parser prints messages, so just exit.
 44		os.Exit(1)
 45	}
 46
 47	if !args.Prod && !args.Debug {
 48		log.Fatal("missing required option --prod or --debug")
 49	} else if args.Prod && args.Debug {
 50		log.Fatal("must specify either --prod or --debug")
 51	}
 52
 53	log.Printf("starting codies server, version %s", version.Version())
 54
 55	wsOpts = &websocket.AcceptOptions{
 56		OriginPatterns: args.Origins,
 57	}
 58
 59	if args.Debug {
 60		log.Println("starting in debug mode, allowing any WebSocket origin host")
 61		wsOpts.OriginPatterns = []string{"*"}
 62	} else {
 63		if !version.VersionSet() {
 64			log.Fatal("running production build without version set")
 65		}
 66	}
 67
 68	g, ctx := errgroup.WithContext(ctxutil.Interrupt())
 69
 70	srv := server.NewServer()
 71
 72	r := chi.NewMux()
 73	r.Use(middleware.Heartbeat("/ping"))
 74	r.Use(middleware.Recoverer)
 75	r.NotFound(staticRouter().ServeHTTP)
 76
 77	r.Group(func(r chi.Router) {
 78		r.Use(middleware.NoCache)
 79
 80		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
 81			w.Header().Add("Content-Type", "application/json")
 82			_ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
 83		})
 84
 85		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
 86			rooms, clients := srv.Stats()
 87
 88			enc := json.NewEncoder(w)
 89			enc.SetIndent("", "    ")
 90			_ = enc.Encode(&protocol.StatsResponse{
 91				Rooms:   rooms,
 92				Clients: clients,
 93			})
 94		})
 95
 96		r.Group(func(r chi.Router) {
 97			if !args.Debug {
 98				r.Use(checkVersion)
 99			}
100
101			r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
102				query := &protocol.ExistsQuery{}
103				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
104					httpErr(w, http.StatusBadRequest)
105					return
106				}
107
108				room := srv.FindRoomByID(query.RoomID)
109				if room == nil {
110					w.WriteHeader(http.StatusNotFound)
111				} else {
112					w.WriteHeader(http.StatusOK)
113				}
114
115				_, _ = w.Write([]byte("."))
116			})
117
118			r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
119				defer r.Body.Close()
120
121				req := &protocol.RoomRequest{}
122				if err := json.NewDecoder(r.Body).Decode(req); err != nil {
123					httpErr(w, http.StatusBadRequest)
124					return
125				}
126
127				if !req.Valid() {
128					httpErr(w, http.StatusBadRequest)
129					return
130				}
131
132				resp := &protocol.RoomResponse{}
133
134				w.Header().Add("Content-Type", "application/json")
135
136				if req.Create {
137					room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
138					if err != nil {
139						switch err {
140						case server.ErrRoomExists:
141							resp.Error = stringPtr("Room already exists.")
142							w.WriteHeader(http.StatusBadRequest)
143						case server.ErrTooManyRooms:
144							resp.Error = stringPtr("Too many rooms.")
145							w.WriteHeader(http.StatusServiceUnavailable)
146						default:
147							resp.Error = stringPtr("An unknown error occurred.")
148							w.WriteHeader(http.StatusInternalServerError)
149						}
150					} else {
151						resp.ID = &room.ID
152						w.WriteHeader(http.StatusOK)
153					}
154				} else {
155					room := srv.FindRoom(req.RoomName)
156					if room == nil || room.Password != req.RoomPass {
157						resp.Error = stringPtr("Room not found or password does not match.")
158						w.WriteHeader(http.StatusNotFound)
159					} else {
160						resp.ID = &room.ID
161						w.WriteHeader(http.StatusOK)
162					}
163				}
164
165				_ = json.NewEncoder(w).Encode(resp)
166			})
167
168			r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
169				query := &protocol.WSQuery{}
170				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
171					httpErr(w, http.StatusBadRequest)
172					return
173				}
174
175				if !query.Valid() {
176					httpErr(w, http.StatusBadRequest)
177					return
178				}
179
180				room := srv.FindRoomByID(query.RoomID)
181				if room == nil {
182					httpErr(w, http.StatusNotFound)
183					return
184				}
185
186				c, err := websocket.Accept(w, r, wsOpts)
187				if err != nil {
188					log.Println(err)
189					return
190				}
191
192				g.Go(func() error {
193					room.HandleConn(query.PlayerID, query.Nickname, c)
194					return nil
195				})
196			})
197		})
198	})
199
200	g.Go(func() error {
201		return srv.Run(ctx)
202	})
203
204	httpSrv := http.Server{Addr: args.Addr, Handler: r}
205
206	g.Go(func() error {
207		<-ctx.Done()
208
209		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
210		defer cancel()
211
212		return httpSrv.Shutdown(ctx)
213	})
214
215	g.Go(func() error {
216		return httpSrv.ListenAndServe()
217	})
218
219	log.Fatal(g.Wait())
220}
221
222func staticRouter() http.Handler {
223	fs := http.Dir("./frontend/build")
224	fsh := http.FileServer(fs)
225
226	r := chi.NewMux()
227	r.Use(middleware.Compress(5))
228
229	r.Handle("/static/*", fsh)
230	r.Handle("/favicon/*", fsh)
231
232	r.Group(func(r chi.Router) {
233		r.Use(middleware.NoCache)
234		r.Handle("/*", fsh)
235	})
236
237	return r
238}
239
240func checkVersion(next http.Handler) http.Handler {
241	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
242		want := version.Version()
243
244		toCheck := []string{
245			r.Header.Get("X-CODIES-VERSION"),
246			r.URL.Query().Get("codiesVersion"),
247		}
248
249		for _, got := range toCheck {
250			if got == want {
251				next.ServeHTTP(w, r)
252				return
253			}
254		}
255
256		reason := fmt.Sprintf("client version too old, please reload to get %s", want)
257
258		if r.Header.Get("Upgrade") == "websocket" {
259			c, err := websocket.Accept(w, r, wsOpts)
260			if err != nil {
261				log.Println(err)
262				return
263			}
264			c.Close(4418, reason)
265			return
266		}
267
268		w.WriteHeader(http.StatusTeapot)
269		fmt.Fprint(w, reason)
270	})
271}
272
// httpErr replies with the given HTTP status code and that status's standard
// text (for example "Bad Request") as a plain-text body.
func httpErr(w http.ResponseWriter, code int) {
	text := http.StatusText(code)
	http.Error(w, text, code)
}
276
// stringPtr returns a pointer to a new string holding a copy of s. It is
// used to fill optional pointer-typed string fields such as
// protocol.RoomResponse.Error.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
280
281func init() {
282	queryparam.DefaultParser.ValueParsers[reflect.TypeOf(uuid.UUID{})] = func(value string, _ string) (reflect.Value, error) {
283		id, err := uuid.FromString(value)
284		return reflect.ValueOf(id), err
285	}
286}