1package main
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/responder"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
26
27var args = struct {
28 Addr string `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
29 Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
30 Prod bool `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
31 Debug bool `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
32}{
33 Addr: ":5000",
34}
35
36var wsOpts *websocket.AcceptOptions
37
38func main() {
39 rand.Seed(time.Now().Unix())
40 log.SetFlags(log.LstdFlags | log.Lshortfile)
41
42 if _, err := flags.Parse(&args); err != nil {
43 // Default flag parser prints messages, so just exit.
44 os.Exit(1)
45 }
46
47 if !args.Prod && !args.Debug {
48 log.Fatal("missing required option --prod or --debug")
49 } else if args.Prod && args.Debug {
50 log.Fatal("must specify either --prod or --debug")
51 }
52
53 log.Printf("starting codies server, version %s", version.Version())
54
55 wsOpts = &websocket.AcceptOptions{
56 OriginPatterns: args.Origins,
57 CompressionMode: websocket.CompressionContextTakeover,
58 }
59
60 if args.Debug {
61 log.Println("starting in debug mode, allowing any WebSocket origin host")
62 wsOpts.InsecureSkipVerify = true
63 } else {
64 if !version.VersionSet() {
65 log.Fatal("running production build without version set")
66 }
67 }
68
69 g, ctx := errgroup.WithContext(ctxutil.Interrupt())
70
71 srv := server.NewServer()
72
73 r := chi.NewMux()
74
75 r.Use(func(next http.Handler) http.Handler {
76 return promhttp.InstrumentHandlerCounter(metricRequest, next)
77 })
78
79 r.Use(middleware.Heartbeat("/ping"))
80 r.Use(middleware.Recoverer)
81 r.NotFound(staticHandler().ServeHTTP)
82
83 r.Group(func(r chi.Router) {
84 r.Use(middleware.NoCache)
85
86 r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
87 responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()}))
88 })
89
90 r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
91 rooms, clients := srv.Stats()
92 responder.Respond(w,
93 responder.Body(&protocol.StatsResponse{
94 Rooms: rooms,
95 Clients: clients,
96 }),
97 responder.Pretty(true),
98 )
99 })
100
101 r.Group(func(r chi.Router) {
102 if !args.Debug {
103 r.Use(checkVersion)
104 }
105
106 r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
107 query := &protocol.ExistsQuery{}
108 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
109 responder.Respond(w, responder.Status(http.StatusBadRequest))
110 return
111 }
112
113 room := srv.FindRoomByID(query.RoomID)
114 if room == nil {
115 responder.Respond(w, responder.Status(http.StatusNotFound))
116 } else {
117 responder.Respond(w, responder.Status(http.StatusOK))
118 }
119 })
120
121 r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
122 defer r.Body.Close()
123
124 req := &protocol.RoomRequest{}
125 if err := json.NewDecoder(r.Body).Decode(req); err != nil {
126 responder.Respond(w, responder.Status(http.StatusBadRequest))
127 return
128 }
129
130 if msg, valid := req.Valid(); !valid {
131 responder.Respond(w,
132 responder.Status(http.StatusBadRequest),
133 responder.Body(&protocol.RoomResponse{
134 Error: stringPtr(msg),
135 }),
136 )
137 return
138 }
139
140 var room *server.Room
141 if req.Create {
142 var err error
143 room, err = srv.CreateRoom(req.RoomName, req.RoomPass)
144 if err != nil {
145 switch err {
146 case server.ErrRoomExists:
147 responder.Respond(w,
148 responder.Status(http.StatusBadRequest),
149 responder.Body(&protocol.RoomResponse{
150 Error: stringPtr("Room already exists."),
151 }),
152 )
153 case server.ErrTooManyRooms:
154 responder.Respond(w,
155 responder.Status(http.StatusServiceUnavailable),
156 responder.Body(&protocol.RoomResponse{
157 Error: stringPtr("Too many rooms."),
158 }),
159 )
160 default:
161 responder.Respond(w,
162 responder.Status(http.StatusInternalServerError),
163 responder.Body(&protocol.RoomResponse{
164 Error: stringPtr("An unknown error occurred."),
165 }),
166 )
167 }
168 return
169 }
170 } else {
171 room = srv.FindRoom(req.RoomName)
172 if room == nil || room.Password != req.RoomPass {
173 responder.Respond(w,
174 responder.Status(http.StatusNotFound),
175 responder.Body(&protocol.RoomResponse{
176 Error: stringPtr("Room not found or password does not match."),
177 }),
178 )
179 return
180 }
181 }
182
183 responder.Respond(w, responder.Body(&protocol.RoomResponse{
184 ID: &room.ID,
185 }))
186 })
187
188 r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
189 query := &protocol.WSQuery{}
190 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
191 responder.Respond(w, responder.Status(http.StatusBadRequest))
192 return
193 }
194
195 if _, valid := query.Valid(); !valid {
196 responder.Respond(w, responder.Status(http.StatusBadRequest))
197 return
198 }
199
200 room := srv.FindRoomByID(query.RoomID)
201 if room == nil {
202 responder.Respond(w, responder.Status(http.StatusBadRequest))
203 return
204 }
205
206 c, err := websocket.Accept(w, r, wsOpts)
207 if err != nil {
208 return
209 }
210
211 g.Go(func() error {
212 room.HandleConn(query.PlayerID, query.Nickname, c)
213 return nil
214 })
215 })
216 })
217 })
218
219 g.Go(func() error {
220 return srv.Run(ctx)
221 })
222
223 runServer(ctx, g, args.Addr, r)
224
225 if args.Prod {
226 runServer(ctx, g, ":2112", prometheusHandler())
227 }
228
229 log.Fatal(g.Wait())
230}
231
232func staticHandler() http.Handler {
233 fs := http.Dir("./frontend/build")
234 fsh := http.FileServer(fs)
235
236 r := chi.NewMux()
237 r.Use(middleware.Compress(5))
238
239 r.Handle("/static/*", fsh)
240 r.Handle("/favicon/*", fsh)
241
242 r.Group(func(r chi.Router) {
243 r.Use(middleware.NoCache)
244 r.Handle("/*", fsh)
245 })
246
247 return r
248}
249
250func checkVersion(next http.Handler) http.Handler {
251 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
252 want := version.Version()
253
254 toCheck := []string{
255 r.Header.Get("X-CODIES-VERSION"),
256 r.URL.Query().Get("codiesVersion"),
257 }
258
259 for _, got := range toCheck {
260 if got == want {
261 next.ServeHTTP(w, r)
262 return
263 }
264 }
265
266 reason := fmt.Sprintf("client version too old, please reload to get %s", want)
267
268 if r.Header.Get("Upgrade") == "websocket" {
269 c, err := websocket.Accept(w, r, wsOpts)
270 if err != nil {
271 return
272 }
273 c.Close(4418, reason)
274 return
275 }
276
277 w.WriteHeader(http.StatusTeapot)
278 fmt.Fprint(w, reason)
279 })
280}
281
282func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
283 httpSrv := http.Server{Addr: addr, Handler: handler}
284
285 g.Go(func() error {
286 <-ctx.Done()
287
288 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
289 defer cancel()
290
291 return httpSrv.Shutdown(ctx)
292 })
293
294 g.Go(func() error {
295 return httpSrv.ListenAndServe()
296 })
297}
298
299func prometheusHandler() http.Handler {
300 mux := http.NewServeMux()
301 mux.Handle("/metrics", promhttp.Handler())
302 return mux
303}