1package main
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log"
8 "math/rand"
9 "net/http"
10 "os"
11 "time"
12
13 "github.com/go-chi/chi"
14 "github.com/go-chi/chi/middleware"
15 "github.com/jessevdk/go-flags"
16 "github.com/posener/ctxutil"
17 "github.com/prometheus/client_golang/prometheus/promhttp"
18 "github.com/tomwright/queryparam/v4"
19 "github.com/zikaeroh/codies/internal/protocol"
20 "github.com/zikaeroh/codies/internal/server"
21 "github.com/zikaeroh/codies/internal/version"
22 "golang.org/x/sync/errgroup"
23 "nhooyr.io/websocket"
24)
25
// args holds the server's command-line options. main fills it via
// flags.Parse, which reads each option from its long flag or the listed
// environment variable.
//
//   - Addr: address of the public HTTP listener (default ":5000").
//   - Origins: additional origin patterns accepted for WebSocket
//     connections; CODIES_ORIGINS is split on ",".
//   - Prod / Debug: run mode; main requires exactly one of the two.
var args = struct {
	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
}{
	Addr: ":5000",
}

// wsOpts are the options passed to websocket.Accept by the /api/ws route and
// by checkVersion. It is nil until main initializes it after flag parsing
// (origin patterns from args.Origins, context-takeover compression, and
// InsecureSkipVerify in debug mode).
var wsOpts *websocket.AcceptOptions
36
37func main() {
38 rand.Seed(time.Now().Unix())
39 log.SetFlags(log.LstdFlags | log.Lshortfile)
40
41 if _, err := flags.Parse(&args); err != nil {
42 // Default flag parser prints messages, so just exit.
43 os.Exit(1)
44 }
45
46 if !args.Prod && !args.Debug {
47 log.Fatal("missing required option --prod or --debug")
48 } else if args.Prod && args.Debug {
49 log.Fatal("must specify either --prod or --debug")
50 }
51
52 log.Printf("starting codies server, version %s", version.Version())
53
54 wsOpts = &websocket.AcceptOptions{
55 OriginPatterns: args.Origins,
56 CompressionMode: websocket.CompressionContextTakeover,
57 }
58
59 if args.Debug {
60 log.Println("starting in debug mode, allowing any WebSocket origin host")
61 wsOpts.InsecureSkipVerify = true
62 } else {
63 if !version.VersionSet() {
64 log.Fatal("running production build without version set")
65 }
66 }
67
68 g, ctx := errgroup.WithContext(ctxutil.Interrupt())
69
70 srv := server.NewServer()
71
72 r := chi.NewMux()
73
74 r.Use(func(next http.Handler) http.Handler {
75 return promhttp.InstrumentHandlerCounter(metricRequest, next)
76 })
77
78 r.Use(middleware.Heartbeat("/ping"))
79 r.Use(middleware.Recoverer)
80 r.NotFound(staticHandler().ServeHTTP)
81
82 r.Group(func(r chi.Router) {
83 r.Use(middleware.NoCache)
84
85 r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
86 w.Header().Add("Content-Type", "application/json")
87 _ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
88 })
89
90 r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
91 rooms, clients := srv.Stats()
92
93 enc := json.NewEncoder(w)
94 enc.SetIndent("", " ")
95 _ = enc.Encode(&protocol.StatsResponse{
96 Rooms: rooms,
97 Clients: clients,
98 })
99 })
100
101 r.Group(func(r chi.Router) {
102 if !args.Debug {
103 r.Use(checkVersion)
104 }
105
106 r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
107 query := &protocol.ExistsQuery{}
108 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
109 httpErr(w, http.StatusBadRequest)
110 return
111 }
112
113 room := srv.FindRoomByID(query.RoomID)
114 if room == nil {
115 w.WriteHeader(http.StatusNotFound)
116 } else {
117 w.WriteHeader(http.StatusOK)
118 }
119
120 _, _ = w.Write([]byte("."))
121 })
122
123 r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
124 defer r.Body.Close()
125
126 req := &protocol.RoomRequest{}
127 if err := json.NewDecoder(r.Body).Decode(req); err != nil {
128 httpErr(w, http.StatusBadRequest)
129 return
130 }
131
132 w.Header().Add("Content-Type", "application/json")
133
134 if msg, valid := req.Valid(); !valid {
135 resp := &protocol.RoomResponse{
136 Error: stringPtr(msg),
137 }
138 w.WriteHeader(http.StatusBadRequest)
139 _ = json.NewEncoder(w).Encode(resp)
140 return
141 }
142
143 resp := &protocol.RoomResponse{}
144
145 if req.Create {
146 room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
147 if err != nil {
148 switch err {
149 case server.ErrRoomExists:
150 resp.Error = stringPtr("Room already exists.")
151 w.WriteHeader(http.StatusBadRequest)
152 case server.ErrTooManyRooms:
153 resp.Error = stringPtr("Too many rooms.")
154 w.WriteHeader(http.StatusServiceUnavailable)
155 default:
156 resp.Error = stringPtr("An unknown error occurred.")
157 w.WriteHeader(http.StatusInternalServerError)
158 }
159 } else {
160 resp.ID = &room.ID
161 w.WriteHeader(http.StatusOK)
162 }
163 } else {
164 room := srv.FindRoom(req.RoomName)
165 if room == nil || room.Password != req.RoomPass {
166 resp.Error = stringPtr("Room not found or password does not match.")
167 w.WriteHeader(http.StatusNotFound)
168 } else {
169 resp.ID = &room.ID
170 w.WriteHeader(http.StatusOK)
171 }
172 }
173
174 _ = json.NewEncoder(w).Encode(resp)
175 })
176
177 r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
178 query := &protocol.WSQuery{}
179 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
180 httpErr(w, http.StatusBadRequest)
181 return
182 }
183
184 if _, valid := query.Valid(); !valid {
185 httpErr(w, http.StatusBadRequest)
186 return
187 }
188
189 room := srv.FindRoomByID(query.RoomID)
190 if room == nil {
191 httpErr(w, http.StatusNotFound)
192 return
193 }
194
195 c, err := websocket.Accept(w, r, wsOpts)
196 if err != nil {
197 log.Println(err)
198 return
199 }
200
201 g.Go(func() error {
202 room.HandleConn(query.PlayerID, query.Nickname, c)
203 return nil
204 })
205 })
206 })
207 })
208
209 g.Go(func() error {
210 return srv.Run(ctx)
211 })
212
213 runServer(ctx, g, args.Addr, r)
214
215 if args.Prod {
216 runServer(ctx, g, ":2112", prometheusHandler())
217 }
218
219 log.Fatal(g.Wait())
220}
221
222func staticHandler() http.Handler {
223 fs := http.Dir("./frontend/build")
224 fsh := http.FileServer(fs)
225
226 r := chi.NewMux()
227 r.Use(middleware.Compress(5))
228
229 r.Handle("/static/*", fsh)
230 r.Handle("/favicon/*", fsh)
231
232 r.Group(func(r chi.Router) {
233 r.Use(middleware.NoCache)
234 r.Handle("/*", fsh)
235 })
236
237 return r
238}
239
240func checkVersion(next http.Handler) http.Handler {
241 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
242 want := version.Version()
243
244 toCheck := []string{
245 r.Header.Get("X-CODIES-VERSION"),
246 r.URL.Query().Get("codiesVersion"),
247 }
248
249 for _, got := range toCheck {
250 if got == want {
251 next.ServeHTTP(w, r)
252 return
253 }
254 }
255
256 reason := fmt.Sprintf("client version too old, please reload to get %s", want)
257
258 if r.Header.Get("Upgrade") == "websocket" {
259 c, err := websocket.Accept(w, r, wsOpts)
260 if err != nil {
261 log.Println(err)
262 return
263 }
264 c.Close(4418, reason)
265 return
266 }
267
268 w.WriteHeader(http.StatusTeapot)
269 fmt.Fprint(w, reason)
270 })
271}
272
273func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
274 httpSrv := http.Server{Addr: addr, Handler: handler}
275
276 g.Go(func() error {
277 <-ctx.Done()
278
279 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
280 defer cancel()
281
282 return httpSrv.Shutdown(ctx)
283 })
284
285 g.Go(func() error {
286 return httpSrv.ListenAndServe()
287 })
288}
289
290func prometheusHandler() http.Handler {
291 mux := http.NewServeMux()
292 mux.Handle("/metrics", promhttp.Handler())
293 return mux
294}