1package main
2
import (
	"context"
	"encoding/json"
	"errors"
	"log"
	"math/rand"
	"net/http"
	"os"
	"reflect"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/gofrs/uuid"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/server"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
24
// cliArgs describes the options the server accepts. Each field is settable
// via its long command-line flag or the environment variable named in its
// `env` tag; Origins uses "," as its environment-variable delimiter.
type cliArgs struct {
	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
}

// args holds the server options; Addr starts out as ":5000" unless overridden.
var args = cliArgs{Addr: ":5000"}
32
33func main() {
34 if _, err := flags.Parse(&args); err != nil {
35 // Default flag parser prints messages, so just exit.
36 os.Exit(1)
37 }
38
39 wsOpts := &websocket.AcceptOptions{
40 OriginPatterns: args.Origins,
41 }
42
43 if args.Debug {
44 log.Println("starting in debug mode, allowing any WebSocket origin host")
45 wsOpts.OriginPatterns = []string{"*"}
46 }
47
48 rand.Seed(time.Now().Unix())
49
50 log.SetFlags(log.LstdFlags | log.Lshortfile)
51
52 g, ctx := errgroup.WithContext(ctxutil.Interrupt())
53
54 srv := server.NewServer()
55
56 r := chi.NewMux()
57 r.Use(middleware.Heartbeat("/ping"))
58 r.Use(middleware.Recoverer)
59
60 r.Group(func(r chi.Router) {
61 r.Use(middleware.Compress(5))
62 fs := http.Dir("./frontend/build")
63 r.NotFound(http.FileServer(fs).ServeHTTP)
64 })
65
66 r.Group(func(r chi.Router) {
67 r.Use(middleware.NoCache)
68
69 r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
70 w.Header().Add("Content-Type", "application/json")
71 _ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
72 })
73
74 r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
75 query := &protocol.ExistsQuery{}
76 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
77 httpErr(w, http.StatusBadRequest)
78 return
79 }
80
81 room := srv.FindRoomByID(query.RoomID)
82 if room == nil {
83 w.WriteHeader(http.StatusNotFound)
84 } else {
85 w.WriteHeader(http.StatusOK)
86 }
87
88 _, _ = w.Write([]byte("."))
89 })
90
91 r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
92 defer r.Body.Close()
93
94 req := &protocol.RoomRequest{}
95 if err := json.NewDecoder(r.Body).Decode(req); err != nil {
96 httpErr(w, http.StatusBadRequest)
97 return
98 }
99
100 if !req.Valid() {
101 httpErr(w, http.StatusBadRequest)
102 return
103 }
104
105 resp := &protocol.RoomResponse{}
106
107 w.Header().Add("Content-Type", "application/json")
108
109 if req.Create {
110 room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
111 if err != nil {
112 switch err {
113 case server.ErrRoomExists:
114 resp.Error = stringPtr("Room already exists.")
115 w.WriteHeader(http.StatusBadRequest)
116 case server.ErrTooManyRooms:
117 resp.Error = stringPtr("Too many rooms.")
118 w.WriteHeader(http.StatusServiceUnavailable)
119 default:
120 resp.Error = stringPtr("An unknown error occurred.")
121 w.WriteHeader(http.StatusInternalServerError)
122 }
123 } else {
124 resp.ID = &room.ID
125 w.WriteHeader(http.StatusOK)
126 }
127 } else {
128 room := srv.FindRoom(req.RoomName)
129 if room == nil || room.Password != req.RoomPass {
130 resp.Error = stringPtr("Room not found or password does not match.")
131 w.WriteHeader(http.StatusNotFound)
132 } else {
133 resp.ID = &room.ID
134 w.WriteHeader(http.StatusOK)
135 }
136 }
137
138 _ = json.NewEncoder(w).Encode(resp)
139 })
140
141 r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
142 query := &protocol.WSQuery{}
143 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
144 httpErr(w, http.StatusBadRequest)
145 return
146 }
147
148 if !query.Valid() {
149 httpErr(w, http.StatusBadRequest)
150 return
151 }
152
153 room := srv.FindRoomByID(query.RoomID)
154 if room == nil {
155 httpErr(w, http.StatusNotFound)
156 return
157 }
158
159 c, err := websocket.Accept(w, r, wsOpts)
160 if err != nil {
161 log.Println(err)
162 return
163 }
164
165 g.Go(func() error {
166 room.HandleConn(query.PlayerID, query.Nickname, c)
167 return nil
168 })
169 })
170
171 r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
172 rooms, clients := srv.Stats()
173
174 enc := json.NewEncoder(w)
175 enc.SetIndent("", " ")
176 _ = enc.Encode(&protocol.StatsResponse{
177 Rooms: rooms,
178 Clients: clients,
179 })
180 })
181 })
182
183 g.Go(func() error {
184 return srv.Run(ctx)
185 })
186
187 httpSrv := http.Server{Addr: args.Addr, Handler: r}
188
189 g.Go(func() error {
190 <-ctx.Done()
191
192 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
193 defer cancel()
194
195 return httpSrv.Shutdown(ctx)
196 })
197
198 g.Go(func() error {
199 return httpSrv.ListenAndServe()
200 })
201
202 log.Fatal(g.Wait())
203}
204
// httpErr replies with status code and a plain-text body containing that
// status's standard reason phrase (for example "Bad Request" for 400).
func httpErr(w http.ResponseWriter, code int) {
	reason := http.StatusText(code)
	http.Error(w, reason, code)
}
208
// stringPtr returns a pointer to a newly allocated copy of s. It is used to
// fill optional *string fields such as RoomResponse.Error.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
212
213func init() {
214 queryparam.DefaultParser.ValueParsers[reflect.TypeOf(uuid.UUID{})] = func(value string, _ string) (reflect.Value, error) {
215 id, err := uuid.FromString(value)
216 return reflect.ValueOf(id), err
217 }
218}