1package main
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log"
8 "math/rand"
9 "net/http"
10 "os"
11 "reflect"
12 "time"
13
14 "github.com/go-chi/chi"
15 "github.com/go-chi/chi/middleware"
16 "github.com/gofrs/uuid"
17 "github.com/jessevdk/go-flags"
18 "github.com/posener/ctxutil"
19 "github.com/tomwright/queryparam/v4"
20 "github.com/zikaeroh/codies/internal/protocol"
21 "github.com/zikaeroh/codies/internal/server"
22 "github.com/zikaeroh/codies/internal/version"
23 "golang.org/x/sync/errgroup"
24 "nhooyr.io/websocket"
25)
26
// args holds the server's command-line options, parsed with go-flags at the
// start of main. Each option can also be supplied through the environment
// variable named in its env tag. main requires exactly one of Prod or Debug.
//
//   - Addr: address the HTTP server listens on (default ":5000").
//   - Origins: additional allowed origin patterns for WebSocket connections;
//     comma-delimited when given via CODIES_ORIGINS. Replaced with "*" in
//     debug mode.
//   - Prod: production mode; main refuses to start unless the build has its
//     version set.
//   - Debug: debug mode; allows any WebSocket origin and skips the client
//     version check (checkVersion) on the versioned API routes.
var args = struct {
	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
}{
	Addr: ":5000",
}
35
// wsOpts are the options passed to every websocket.Accept call: the /api/ws
// handler and the rejection path in checkVersion. It is nil until main
// assigns it, before the HTTP server starts. OriginPatterns is set from
// args.Origins, or to "*" in debug mode.
var wsOpts *websocket.AcceptOptions
37
38func main() {
39 rand.Seed(time.Now().Unix())
40 log.SetFlags(log.LstdFlags | log.Lshortfile)
41
42 if _, err := flags.Parse(&args); err != nil {
43 // Default flag parser prints messages, so just exit.
44 os.Exit(1)
45 }
46
47 if !args.Prod && !args.Debug {
48 log.Fatal("missing required option --prod or --debug")
49 } else if args.Prod && args.Debug {
50 log.Fatal("must specify either --prod or --debug")
51 }
52
53 log.Printf("starting codies server, version %s", version.Version())
54
55 wsOpts = &websocket.AcceptOptions{
56 OriginPatterns: args.Origins,
57 }
58
59 if args.Debug {
60 log.Println("starting in debug mode, allowing any WebSocket origin host")
61 wsOpts.OriginPatterns = []string{"*"}
62 } else {
63 if !version.VersionSet() {
64 log.Fatal("running production build without version set")
65 }
66 }
67
68 g, ctx := errgroup.WithContext(ctxutil.Interrupt())
69
70 srv := server.NewServer()
71
72 r := chi.NewMux()
73 r.Use(middleware.Heartbeat("/ping"))
74 r.Use(middleware.Recoverer)
75 r.NotFound(staticRouter().ServeHTTP)
76
77 r.Group(func(r chi.Router) {
78 r.Use(middleware.NoCache)
79
80 r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
81 w.Header().Add("Content-Type", "application/json")
82 _ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
83 })
84
85 r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
86 rooms, clients := srv.Stats()
87
88 enc := json.NewEncoder(w)
89 enc.SetIndent("", " ")
90 _ = enc.Encode(&protocol.StatsResponse{
91 Rooms: rooms,
92 Clients: clients,
93 })
94 })
95
96 r.Group(func(r chi.Router) {
97 if !args.Debug {
98 r.Use(checkVersion)
99 }
100
101 r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
102 query := &protocol.ExistsQuery{}
103 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
104 httpErr(w, http.StatusBadRequest)
105 return
106 }
107
108 room := srv.FindRoomByID(query.RoomID)
109 if room == nil {
110 w.WriteHeader(http.StatusNotFound)
111 } else {
112 w.WriteHeader(http.StatusOK)
113 }
114
115 _, _ = w.Write([]byte("."))
116 })
117
118 r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
119 defer r.Body.Close()
120
121 req := &protocol.RoomRequest{}
122 if err := json.NewDecoder(r.Body).Decode(req); err != nil {
123 httpErr(w, http.StatusBadRequest)
124 return
125 }
126
127 if !req.Valid() {
128 httpErr(w, http.StatusBadRequest)
129 return
130 }
131
132 resp := &protocol.RoomResponse{}
133
134 w.Header().Add("Content-Type", "application/json")
135
136 if req.Create {
137 room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
138 if err != nil {
139 switch err {
140 case server.ErrRoomExists:
141 resp.Error = stringPtr("Room already exists.")
142 w.WriteHeader(http.StatusBadRequest)
143 case server.ErrTooManyRooms:
144 resp.Error = stringPtr("Too many rooms.")
145 w.WriteHeader(http.StatusServiceUnavailable)
146 default:
147 resp.Error = stringPtr("An unknown error occurred.")
148 w.WriteHeader(http.StatusInternalServerError)
149 }
150 } else {
151 resp.ID = &room.ID
152 w.WriteHeader(http.StatusOK)
153 }
154 } else {
155 room := srv.FindRoom(req.RoomName)
156 if room == nil || room.Password != req.RoomPass {
157 resp.Error = stringPtr("Room not found or password does not match.")
158 w.WriteHeader(http.StatusNotFound)
159 } else {
160 resp.ID = &room.ID
161 w.WriteHeader(http.StatusOK)
162 }
163 }
164
165 _ = json.NewEncoder(w).Encode(resp)
166 })
167
168 r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
169 query := &protocol.WSQuery{}
170 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
171 httpErr(w, http.StatusBadRequest)
172 return
173 }
174
175 if !query.Valid() {
176 httpErr(w, http.StatusBadRequest)
177 return
178 }
179
180 room := srv.FindRoomByID(query.RoomID)
181 if room == nil {
182 httpErr(w, http.StatusNotFound)
183 return
184 }
185
186 c, err := websocket.Accept(w, r, wsOpts)
187 if err != nil {
188 log.Println(err)
189 return
190 }
191
192 g.Go(func() error {
193 room.HandleConn(query.PlayerID, query.Nickname, c)
194 return nil
195 })
196 })
197 })
198 })
199
200 g.Go(func() error {
201 return srv.Run(ctx)
202 })
203
204 httpSrv := http.Server{Addr: args.Addr, Handler: r}
205
206 g.Go(func() error {
207 <-ctx.Done()
208
209 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
210 defer cancel()
211
212 return httpSrv.Shutdown(ctx)
213 })
214
215 g.Go(func() error {
216 return httpSrv.ListenAndServe()
217 })
218
219 log.Fatal(g.Wait())
220}
221
222func staticRouter() http.Handler {
223 fs := http.Dir("./frontend/build")
224 fsh := http.FileServer(fs)
225
226 r := chi.NewMux()
227 r.Use(middleware.Compress(5))
228
229 r.Handle("/static/*", fsh)
230 r.Handle("/favicon/*", fsh)
231
232 r.Group(func(r chi.Router) {
233 r.Use(middleware.NoCache)
234 r.Handle("/*", fsh)
235 })
236
237 return r
238}
239
240func checkVersion(next http.Handler) http.Handler {
241 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
242 want := version.Version()
243
244 toCheck := []string{
245 r.Header.Get("X-CODIES-VERSION"),
246 r.URL.Query().Get("codiesVersion"),
247 }
248
249 for _, got := range toCheck {
250 if got == want {
251 next.ServeHTTP(w, r)
252 return
253 }
254 }
255
256 reason := fmt.Sprintf("client version too old, please reload to get %s", want)
257
258 if r.Header.Get("Upgrade") == "websocket" {
259 c, err := websocket.Accept(w, r, wsOpts)
260 if err != nil {
261 log.Println(err)
262 return
263 }
264 c.Close(4418, reason)
265 return
266 }
267
268 w.WriteHeader(http.StatusTeapot)
269 fmt.Fprint(w, reason)
270 })
271}
272
// httpErr replies to the request with the given HTTP status code. The body is
// the standard status text for that code (e.g. "Not Found"), written by
// http.Error as plain text.
func httpErr(w http.ResponseWriter, code int) {
	text := http.StatusText(code)
	http.Error(w, text, code)
}
276
// stringPtr returns a pointer to a newly allocated copy of s. It is used to
// fill optional *string fields such as protocol.RoomResponse.Error.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
280
281func init() {
282 queryparam.DefaultParser.ValueParsers[reflect.TypeOf(uuid.UUID{})] = func(value string, _ string) (reflect.Value, error) {
283 id, err := uuid.FromString(value)
284 return reflect.ValueOf(id), err
285 }
286}