1package main
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
25
// args holds the server configuration, filled in by flags.Parse in main.
// Each field can be set with its long flag or the listed environment
// variable. main requires exactly one of Prod and Debug to be set.
var args = struct {
	// Addr is the listen address for the public HTTP server.
	Addr string `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
	// Origins are extra allowed WebSocket origin patterns; the environment
	// variable form is comma-separated.
	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
}{
	// Default: port 5000 with an empty host, i.e. all interfaces.
	Addr: ":5000",
}

// wsOpts are the options passed to websocket.Accept. main assigns it once at
// startup, before any server is started. In debug mode it allows any origin;
// otherwise it allows args.Origins.
var wsOpts *websocket.AcceptOptions
36
37func main() {
38 rand.Seed(time.Now().Unix())
39 log.SetFlags(log.LstdFlags | log.Lshortfile)
40
41 if _, err := flags.Parse(&args); err != nil {
42 // Default flag parser prints messages, so just exit.
43 os.Exit(1)
44 }
45
46 if !args.Prod && !args.Debug {
47 log.Fatal("missing required option --prod or --debug")
48 } else if args.Prod && args.Debug {
49 log.Fatal("must specify either --prod or --debug")
50 }
51
52 log.Printf("starting codies server, version %s", version.Version())
53
54 wsOpts = &websocket.AcceptOptions{
55 OriginPatterns: args.Origins,
56 }
57
58 if args.Debug {
59 log.Println("starting in debug mode, allowing any WebSocket origin host")
60 wsOpts.OriginPatterns = []string{"*"}
61 } else {
62 if !version.VersionSet() {
63 log.Fatal("running production build without version set")
64 }
65 }
66
67 g, ctx := errgroup.WithContext(ctxutil.Interrupt())
68
69 srv := server.NewServer()
70
71 r := chi.NewMux()
72
73 r.Use(func(next http.Handler) http.Handler {
74 return promhttp.InstrumentHandlerCounter(metricRequest, next)
75 })
76
77 r.Use(middleware.Heartbeat("/ping"))
78 r.Use(middleware.Recoverer)
79 r.NotFound(staticHandler().ServeHTTP)
80
81 r.Group(func(r chi.Router) {
82 r.Use(middleware.NoCache)
83
84 r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
85 w.Header().Add("Content-Type", "application/json")
86 _ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
87 })
88
89 r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
90 rooms, clients := srv.Stats()
91
92 enc := json.NewEncoder(w)
93 enc.SetIndent("", " ")
94 _ = enc.Encode(&protocol.StatsResponse{
95 Rooms: rooms,
96 Clients: clients,
97 })
98 })
99
100 r.Group(func(r chi.Router) {
101 if !args.Debug {
102 r.Use(checkVersion)
103 }
104
105 r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
106 query := &protocol.ExistsQuery{}
107 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
108 httpErr(w, http.StatusBadRequest)
109 return
110 }
111
112 room := srv.FindRoomByID(query.RoomID)
113 if room == nil {
114 w.WriteHeader(http.StatusNotFound)
115 } else {
116 w.WriteHeader(http.StatusOK)
117 }
118
119 _, _ = w.Write([]byte("."))
120 })
121
122 r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
123 defer r.Body.Close()
124
125 req := &protocol.RoomRequest{}
126 if err := json.NewDecoder(r.Body).Decode(req); err != nil {
127 httpErr(w, http.StatusBadRequest)
128 return
129 }
130
131 w.Header().Add("Content-Type", "application/json")
132
133 if msg, valid := req.Valid(); !valid {
134 resp := &protocol.RoomResponse{
135 Error: stringPtr(msg),
136 }
137 w.WriteHeader(http.StatusBadRequest)
138 _ = json.NewEncoder(w).Encode(resp)
139 return
140 }
141
142 resp := &protocol.RoomResponse{}
143
144 if req.Create {
145 room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
146 if err != nil {
147 switch err {
148 case server.ErrRoomExists:
149 resp.Error = stringPtr("Room already exists.")
150 w.WriteHeader(http.StatusBadRequest)
151 case server.ErrTooManyRooms:
152 resp.Error = stringPtr("Too many rooms.")
153 w.WriteHeader(http.StatusServiceUnavailable)
154 default:
155 resp.Error = stringPtr("An unknown error occurred.")
156 w.WriteHeader(http.StatusInternalServerError)
157 }
158 } else {
159 resp.ID = &room.ID
160 w.WriteHeader(http.StatusOK)
161 }
162 } else {
163 room := srv.FindRoom(req.RoomName)
164 if room == nil || room.Password != req.RoomPass {
165 resp.Error = stringPtr("Room not found or password does not match.")
166 w.WriteHeader(http.StatusNotFound)
167 } else {
168 resp.ID = &room.ID
169 w.WriteHeader(http.StatusOK)
170 }
171 }
172
173 _ = json.NewEncoder(w).Encode(resp)
174 })
175
176 r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
177 query := &protocol.WSQuery{}
178 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
179 httpErr(w, http.StatusBadRequest)
180 return
181 }
182
183 if _, valid := query.Valid(); !valid {
184 httpErr(w, http.StatusBadRequest)
185 return
186 }
187
188 room := srv.FindRoomByID(query.RoomID)
189 if room == nil {
190 httpErr(w, http.StatusNotFound)
191 return
192 }
193
194 c, err := websocket.Accept(w, r, wsOpts)
195 if err != nil {
196 log.Println(err)
197 return
198 }
199
200 g.Go(func() error {
201 room.HandleConn(query.PlayerID, query.Nickname, c)
202 return nil
203 })
204 })
205 })
206 })
207
208 g.Go(func() error {
209 return srv.Run(ctx)
210 })
211
212 runServer(ctx, g, args.Addr, r)
213
214 if args.Prod {
215 runServer(ctx, g, ":2112", prometheusHandler())
216 }
217
218 log.Fatal(g.Wait())
219}
220
221func staticHandler() http.Handler {
222 fs := http.Dir("./frontend/build")
223 fsh := http.FileServer(fs)
224
225 r := chi.NewMux()
226 r.Use(middleware.Compress(5))
227
228 r.Handle("/static/*", fsh)
229 r.Handle("/favicon/*", fsh)
230
231 r.Group(func(r chi.Router) {
232 r.Use(middleware.NoCache)
233 r.Handle("/*", fsh)
234 })
235
236 return r
237}
238
239func checkVersion(next http.Handler) http.Handler {
240 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
241 want := version.Version()
242
243 toCheck := []string{
244 r.Header.Get("X-CODIES-VERSION"),
245 r.URL.Query().Get("codiesVersion"),
246 }
247
248 for _, got := range toCheck {
249 if got == want {
250 next.ServeHTTP(w, r)
251 return
252 }
253 }
254
255 reason := fmt.Sprintf("client version too old, please reload to get %s", want)
256
257 if r.Header.Get("Upgrade") == "websocket" {
258 c, err := websocket.Accept(w, r, wsOpts)
259 if err != nil {
260 log.Println(err)
261 return
262 }
263 c.Close(4418, reason)
264 return
265 }
266
267 w.WriteHeader(http.StatusTeapot)
268 fmt.Fprint(w, reason)
269 })
270}
271
272func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
273 httpSrv := http.Server{Addr: addr, Handler: handler}
274
275 g.Go(func() error {
276 <-ctx.Done()
277
278 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
279 defer cancel()
280
281 return httpSrv.Shutdown(ctx)
282 })
283
284 g.Go(func() error {
285 return httpSrv.ListenAndServe()
286 })
287}
288
289func prometheusHandler() http.Handler {
290 mux := http.NewServeMux()
291 mux.Handle("/metrics", promhttp.Handler())
292 return mux
293}