1package main
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/pkger"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/responder"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
27
// args holds the command-line options. It is filled in by flags.Parse at the
// start of main; the struct tags give each option's long flag name, the
// environment variable it is associated with, and its help text.
var args = struct {
	// Addr is the listen address of the main HTTP server (default ":5000").
	Addr string `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
	// Origins is passed to the WebSocket accept options as OriginPatterns.
	// The environment variable form is split on ",".
	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
	// Prod selects production mode: main refuses to start without a version
	// set and additionally serves Prometheus metrics on ":2112".
	Prod bool `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
	// Debug selects debug mode: WebSocket origin verification is skipped and
	// the client version check middleware is not installed.
	// main requires exactly one of Prod and Debug to be set.
	Debug bool `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
}{
	Addr: ":5000",
}
36
// wsOpts holds the options passed to every websocket.Accept call in this
// file. It is nil until main assigns it (after option parsing, before any
// server is started) and is afterwards read by the /api/ws handler and by
// checkVersion.
var wsOpts *websocket.AcceptOptions
38
39func main() {
40 rand.Seed(time.Now().Unix())
41 log.SetFlags(log.LstdFlags | log.Lshortfile)
42
43 if _, err := flags.Parse(&args); err != nil {
44 // Default flag parser prints messages, so just exit.
45 os.Exit(1)
46 }
47
48 if !args.Prod && !args.Debug {
49 log.Fatal("missing required option --prod or --debug")
50 } else if args.Prod && args.Debug {
51 log.Fatal("must specify either --prod or --debug")
52 }
53
54 log.Printf("starting codies server, version %s", version.Version())
55
56 wsOpts = &websocket.AcceptOptions{
57 OriginPatterns: args.Origins,
58 CompressionMode: websocket.CompressionContextTakeover,
59 }
60
61 if args.Debug {
62 log.Println("starting in debug mode, allowing any WebSocket origin host")
63 wsOpts.InsecureSkipVerify = true
64 } else {
65 if !version.VersionSet() {
66 log.Fatal("running production build without version set")
67 }
68 }
69
70 g, ctx := errgroup.WithContext(ctxutil.Interrupt())
71
72 srv := server.NewServer()
73
74 r := chi.NewMux()
75
76 r.Use(func(next http.Handler) http.Handler {
77 return promhttp.InstrumentHandlerCounter(metricRequest, next)
78 })
79
80 r.Use(middleware.Heartbeat("/ping"))
81 r.Use(middleware.Recoverer)
82 r.NotFound(staticHandler().ServeHTTP)
83
84 r.Group(func(r chi.Router) {
85 r.Use(middleware.NoCache)
86
87 r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
88 responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()}))
89 })
90
91 r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
92 rooms, clients := srv.Stats()
93 responder.Respond(w,
94 responder.Body(&protocol.StatsResponse{
95 Rooms: rooms,
96 Clients: clients,
97 }),
98 responder.Pretty(true),
99 )
100 })
101
102 r.Group(func(r chi.Router) {
103 if !args.Debug {
104 r.Use(checkVersion)
105 }
106
107 r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
108 query := &protocol.ExistsQuery{}
109 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
110 responder.Respond(w, responder.Status(http.StatusBadRequest))
111 return
112 }
113
114 room := srv.FindRoomByID(query.RoomID)
115 if room == nil {
116 responder.Respond(w, responder.Status(http.StatusNotFound))
117 } else {
118 responder.Respond(w, responder.Status(http.StatusOK))
119 }
120 })
121
122 r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
123 defer r.Body.Close()
124
125 req := &protocol.RoomRequest{}
126 if err := json.NewDecoder(r.Body).Decode(req); err != nil {
127 responder.Respond(w, responder.Status(http.StatusBadRequest))
128 return
129 }
130
131 if msg, valid := req.Valid(); !valid {
132 responder.Respond(w,
133 responder.Status(http.StatusBadRequest),
134 responder.Body(&protocol.RoomResponse{
135 Error: stringPtr(msg),
136 }),
137 )
138 return
139 }
140
141 var room *server.Room
142 if req.Create {
143 var err error
144 room, err = srv.CreateRoom(req.RoomName, req.RoomPass)
145 if err != nil {
146 switch err {
147 case server.ErrRoomExists:
148 responder.Respond(w,
149 responder.Status(http.StatusBadRequest),
150 responder.Body(&protocol.RoomResponse{
151 Error: stringPtr("Room already exists."),
152 }),
153 )
154 case server.ErrTooManyRooms:
155 responder.Respond(w,
156 responder.Status(http.StatusServiceUnavailable),
157 responder.Body(&protocol.RoomResponse{
158 Error: stringPtr("Too many rooms."),
159 }),
160 )
161 default:
162 responder.Respond(w,
163 responder.Status(http.StatusInternalServerError),
164 responder.Body(&protocol.RoomResponse{
165 Error: stringPtr("An unknown error occurred."),
166 }),
167 )
168 }
169 return
170 }
171 } else {
172 room = srv.FindRoom(req.RoomName)
173 if room == nil || room.Password != req.RoomPass {
174 responder.Respond(w,
175 responder.Status(http.StatusNotFound),
176 responder.Body(&protocol.RoomResponse{
177 Error: stringPtr("Room not found or password does not match."),
178 }),
179 )
180 return
181 }
182 }
183
184 responder.Respond(w, responder.Body(&protocol.RoomResponse{
185 ID: &room.ID,
186 }))
187 })
188
189 r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
190 query := &protocol.WSQuery{}
191 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
192 responder.Respond(w, responder.Status(http.StatusBadRequest))
193 return
194 }
195
196 if _, valid := query.Valid(); !valid {
197 responder.Respond(w, responder.Status(http.StatusBadRequest))
198 return
199 }
200
201 room := srv.FindRoomByID(query.RoomID)
202 if room == nil {
203 responder.Respond(w, responder.Status(http.StatusBadRequest))
204 return
205 }
206
207 c, err := websocket.Accept(w, r, wsOpts)
208 if err != nil {
209 return
210 }
211
212 g.Go(func() error {
213 room.HandleConn(ctx, query.PlayerID, query.Nickname, c)
214 return nil
215 })
216 })
217 })
218 })
219
220 g.Go(func() error {
221 return srv.Run(ctx)
222 })
223
224 runServer(ctx, g, args.Addr, r)
225
226 if args.Prod {
227 runServer(ctx, g, ":2112", prometheusHandler())
228 }
229
230 log.Fatal(g.Wait())
231}
232
233func staticHandler() http.Handler {
234 fs := pkger.Dir("/frontend/build")
235 fsh := http.FileServer(fs)
236
237 r := chi.NewMux()
238 r.Use(middleware.Compress(5))
239
240 r.Handle("/static/*", fsh)
241 r.Handle("/favicon/*", fsh)
242
243 r.Group(func(r chi.Router) {
244 r.Use(middleware.NoCache)
245 r.Handle("/*", fsh)
246 })
247
248 return r
249}
250
251func checkVersion(next http.Handler) http.Handler {
252 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
253 want := version.Version()
254
255 toCheck := []string{
256 r.Header.Get("X-CODIES-VERSION"),
257 r.URL.Query().Get("codiesVersion"),
258 }
259
260 for _, got := range toCheck {
261 if got == want {
262 next.ServeHTTP(w, r)
263 return
264 }
265 }
266
267 reason := fmt.Sprintf("client version too old, please reload to get %s", want)
268
269 if r.Header.Get("Upgrade") == "websocket" {
270 c, err := websocket.Accept(w, r, wsOpts)
271 if err != nil {
272 return
273 }
274 c.Close(4418, reason)
275 return
276 }
277
278 w.WriteHeader(http.StatusTeapot)
279 fmt.Fprint(w, reason)
280 })
281}
282
283func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
284 httpSrv := http.Server{Addr: addr, Handler: handler}
285
286 g.Go(func() error {
287 <-ctx.Done()
288
289 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
290 defer cancel()
291
292 return httpSrv.Shutdown(ctx)
293 })
294
295 g.Go(func() error {
296 return httpSrv.ListenAndServe()
297 })
298}
299
300func prometheusHandler() http.Handler {
301 mux := http.NewServeMux()
302 mux.Handle("/metrics", promhttp.Handler())
303 return mux
304}