main.go

  1package main
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
 25
 26var args = struct {
 27	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
 28	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
 29	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
 30	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
 31}{
 32	Addr: ":5000",
 33}
 34
 35var wsOpts *websocket.AcceptOptions
 36
 37func main() {
 38	rand.Seed(time.Now().Unix())
 39	log.SetFlags(log.LstdFlags | log.Lshortfile)
 40
 41	if _, err := flags.Parse(&args); err != nil {
 42		// Default flag parser prints messages, so just exit.
 43		os.Exit(1)
 44	}
 45
 46	if !args.Prod && !args.Debug {
 47		log.Fatal("missing required option --prod or --debug")
 48	} else if args.Prod && args.Debug {
 49		log.Fatal("must specify either --prod or --debug")
 50	}
 51
 52	log.Printf("starting codies server, version %s", version.Version())
 53
 54	wsOpts = &websocket.AcceptOptions{
 55		OriginPatterns: args.Origins,
 56	}
 57
 58	if args.Debug {
 59		log.Println("starting in debug mode, allowing any WebSocket origin host")
 60		wsOpts.OriginPatterns = []string{"*"}
 61	} else {
 62		if !version.VersionSet() {
 63			log.Fatal("running production build without version set")
 64		}
 65	}
 66
 67	g, ctx := errgroup.WithContext(ctxutil.Interrupt())
 68
 69	srv := server.NewServer()
 70
 71	r := chi.NewMux()
 72
 73	r.Use(func(next http.Handler) http.Handler {
 74		return promhttp.InstrumentHandlerCounter(metricRequest, next)
 75	})
 76
 77	r.Use(middleware.Heartbeat("/ping"))
 78	r.Use(middleware.Recoverer)
 79	r.NotFound(staticHandler().ServeHTTP)
 80
 81	r.Group(func(r chi.Router) {
 82		r.Use(middleware.NoCache)
 83
 84		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
 85			w.Header().Add("Content-Type", "application/json")
 86			_ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
 87		})
 88
 89		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
 90			rooms, clients := srv.Stats()
 91
 92			enc := json.NewEncoder(w)
 93			enc.SetIndent("", "    ")
 94			_ = enc.Encode(&protocol.StatsResponse{
 95				Rooms:   rooms,
 96				Clients: clients,
 97			})
 98		})
 99
100		r.Group(func(r chi.Router) {
101			if !args.Debug {
102				r.Use(checkVersion)
103			}
104
105			r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
106				query := &protocol.ExistsQuery{}
107				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
108					httpErr(w, http.StatusBadRequest)
109					return
110				}
111
112				room := srv.FindRoomByID(query.RoomID)
113				if room == nil {
114					w.WriteHeader(http.StatusNotFound)
115				} else {
116					w.WriteHeader(http.StatusOK)
117				}
118
119				_, _ = w.Write([]byte("."))
120			})
121
122			r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
123				defer r.Body.Close()
124
125				req := &protocol.RoomRequest{}
126				if err := json.NewDecoder(r.Body).Decode(req); err != nil {
127					httpErr(w, http.StatusBadRequest)
128					return
129				}
130
131				w.Header().Add("Content-Type", "application/json")
132
133				if msg, valid := req.Valid(); !valid {
134					resp := &protocol.RoomResponse{
135						Error: stringPtr(msg),
136					}
137					w.WriteHeader(http.StatusBadRequest)
138					_ = json.NewEncoder(w).Encode(resp)
139					return
140				}
141
142				resp := &protocol.RoomResponse{}
143
144				if req.Create {
145					room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
146					if err != nil {
147						switch err {
148						case server.ErrRoomExists:
149							resp.Error = stringPtr("Room already exists.")
150							w.WriteHeader(http.StatusBadRequest)
151						case server.ErrTooManyRooms:
152							resp.Error = stringPtr("Too many rooms.")
153							w.WriteHeader(http.StatusServiceUnavailable)
154						default:
155							resp.Error = stringPtr("An unknown error occurred.")
156							w.WriteHeader(http.StatusInternalServerError)
157						}
158					} else {
159						resp.ID = &room.ID
160						w.WriteHeader(http.StatusOK)
161					}
162				} else {
163					room := srv.FindRoom(req.RoomName)
164					if room == nil || room.Password != req.RoomPass {
165						resp.Error = stringPtr("Room not found or password does not match.")
166						w.WriteHeader(http.StatusNotFound)
167					} else {
168						resp.ID = &room.ID
169						w.WriteHeader(http.StatusOK)
170					}
171				}
172
173				_ = json.NewEncoder(w).Encode(resp)
174			})
175
176			r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
177				query := &protocol.WSQuery{}
178				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
179					httpErr(w, http.StatusBadRequest)
180					return
181				}
182
183				if _, valid := query.Valid(); !valid {
184					httpErr(w, http.StatusBadRequest)
185					return
186				}
187
188				room := srv.FindRoomByID(query.RoomID)
189				if room == nil {
190					httpErr(w, http.StatusNotFound)
191					return
192				}
193
194				c, err := websocket.Accept(w, r, wsOpts)
195				if err != nil {
196					log.Println(err)
197					return
198				}
199
200				g.Go(func() error {
201					room.HandleConn(query.PlayerID, query.Nickname, c)
202					return nil
203				})
204			})
205		})
206	})
207
208	g.Go(func() error {
209		return srv.Run(ctx)
210	})
211
212	runServer(ctx, g, args.Addr, r)
213
214	if args.Prod {
215		runServer(ctx, g, ":2112", prometheusHandler())
216	}
217
218	log.Fatal(g.Wait())
219}
220
221func staticHandler() http.Handler {
222	fs := http.Dir("./frontend/build")
223	fsh := http.FileServer(fs)
224
225	r := chi.NewMux()
226	r.Use(middleware.Compress(5))
227
228	r.Handle("/static/*", fsh)
229	r.Handle("/favicon/*", fsh)
230
231	r.Group(func(r chi.Router) {
232		r.Use(middleware.NoCache)
233		r.Handle("/*", fsh)
234	})
235
236	return r
237}
238
239func checkVersion(next http.Handler) http.Handler {
240	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
241		want := version.Version()
242
243		toCheck := []string{
244			r.Header.Get("X-CODIES-VERSION"),
245			r.URL.Query().Get("codiesVersion"),
246		}
247
248		for _, got := range toCheck {
249			if got == want {
250				next.ServeHTTP(w, r)
251				return
252			}
253		}
254
255		reason := fmt.Sprintf("client version too old, please reload to get %s", want)
256
257		if r.Header.Get("Upgrade") == "websocket" {
258			c, err := websocket.Accept(w, r, wsOpts)
259			if err != nil {
260				log.Println(err)
261				return
262			}
263			c.Close(4418, reason)
264			return
265		}
266
267		w.WriteHeader(http.StatusTeapot)
268		fmt.Fprint(w, reason)
269	})
270}
271
272func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
273	httpSrv := http.Server{Addr: addr, Handler: handler}
274
275	g.Go(func() error {
276		<-ctx.Done()
277
278		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
279		defer cancel()
280
281		return httpSrv.Shutdown(ctx)
282	})
283
284	g.Go(func() error {
285		return httpSrv.ListenAndServe()
286	})
287}
288
289func prometheusHandler() http.Handler {
290	mux := http.NewServeMux()
291	mux.Handle("/metrics", promhttp.Handler())
292	return mux
293}