1package main
2
3import (
4 "context"
5 "encoding/json"
6 "log"
7 "math/rand"
8 "net/http"
9 "os"
10 "reflect"
11 "time"
12
13 "github.com/go-chi/chi"
14 "github.com/go-chi/chi/middleware"
15 "github.com/gofrs/uuid"
16 "github.com/jessevdk/go-flags"
17 "github.com/posener/ctxutil"
18 "github.com/tomwright/queryparam/v4"
19 "github.com/zikaeroh/codies/internal/protocol"
20 "github.com/zikaeroh/codies/internal/server"
21 "github.com/zikaeroh/codies/internal/version"
22 "golang.org/x/sync/errgroup"
23 "nhooyr.io/websocket"
24)
25
// args is the server configuration. main fills it via flags.Parse, taking
// each field from its command-line option or the matching CODIES_* environment
// variable named in the struct tags. Addr defaults to ":5000"; all other
// fields start at their zero values.
var args = struct {
	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
}{Addr: ":5000"}
33
34func main() {
35 rand.Seed(time.Now().Unix())
36 log.SetFlags(log.LstdFlags | log.Lshortfile)
37
38 if _, err := flags.Parse(&args); err != nil {
39 // Default flag parser prints messages, so just exit.
40 os.Exit(1)
41 }
42
43 log.Printf("starting codies server, version %s", version.Version())
44
45 wsOpts := &websocket.AcceptOptions{
46 OriginPatterns: args.Origins,
47 }
48
49 if args.Debug {
50 log.Println("starting in debug mode, allowing any WebSocket origin host")
51 wsOpts.OriginPatterns = []string{"*"}
52 }
53
54 g, ctx := errgroup.WithContext(ctxutil.Interrupt())
55
56 srv := server.NewServer()
57
58 r := chi.NewMux()
59 r.Use(middleware.Heartbeat("/ping"))
60 r.Use(middleware.Recoverer)
61
62 r.Group(func(r chi.Router) {
63 r.Use(middleware.Compress(5))
64 fs := http.Dir("./frontend/build")
65 r.NotFound(http.FileServer(fs).ServeHTTP)
66 })
67
68 r.Group(func(r chi.Router) {
69 r.Use(middleware.NoCache)
70
71 r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
72 w.Header().Add("Content-Type", "application/json")
73 _ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
74 })
75
76 r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
77 query := &protocol.ExistsQuery{}
78 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
79 httpErr(w, http.StatusBadRequest)
80 return
81 }
82
83 room := srv.FindRoomByID(query.RoomID)
84 if room == nil {
85 w.WriteHeader(http.StatusNotFound)
86 } else {
87 w.WriteHeader(http.StatusOK)
88 }
89
90 _, _ = w.Write([]byte("."))
91 })
92
93 r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
94 defer r.Body.Close()
95
96 req := &protocol.RoomRequest{}
97 if err := json.NewDecoder(r.Body).Decode(req); err != nil {
98 httpErr(w, http.StatusBadRequest)
99 return
100 }
101
102 if !req.Valid() {
103 httpErr(w, http.StatusBadRequest)
104 return
105 }
106
107 resp := &protocol.RoomResponse{}
108
109 w.Header().Add("Content-Type", "application/json")
110
111 if req.Create {
112 room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
113 if err != nil {
114 switch err {
115 case server.ErrRoomExists:
116 resp.Error = stringPtr("Room already exists.")
117 w.WriteHeader(http.StatusBadRequest)
118 case server.ErrTooManyRooms:
119 resp.Error = stringPtr("Too many rooms.")
120 w.WriteHeader(http.StatusServiceUnavailable)
121 default:
122 resp.Error = stringPtr("An unknown error occurred.")
123 w.WriteHeader(http.StatusInternalServerError)
124 }
125 } else {
126 resp.ID = &room.ID
127 w.WriteHeader(http.StatusOK)
128 }
129 } else {
130 room := srv.FindRoom(req.RoomName)
131 if room == nil || room.Password != req.RoomPass {
132 resp.Error = stringPtr("Room not found or password does not match.")
133 w.WriteHeader(http.StatusNotFound)
134 } else {
135 resp.ID = &room.ID
136 w.WriteHeader(http.StatusOK)
137 }
138 }
139
140 _ = json.NewEncoder(w).Encode(resp)
141 })
142
143 r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
144 query := &protocol.WSQuery{}
145 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
146 httpErr(w, http.StatusBadRequest)
147 return
148 }
149
150 if !query.Valid() {
151 httpErr(w, http.StatusBadRequest)
152 return
153 }
154
155 room := srv.FindRoomByID(query.RoomID)
156 if room == nil {
157 httpErr(w, http.StatusNotFound)
158 return
159 }
160
161 c, err := websocket.Accept(w, r, wsOpts)
162 if err != nil {
163 log.Println(err)
164 return
165 }
166
167 g.Go(func() error {
168 room.HandleConn(query.PlayerID, query.Nickname, c)
169 return nil
170 })
171 })
172
173 r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
174 rooms, clients := srv.Stats()
175
176 enc := json.NewEncoder(w)
177 enc.SetIndent("", " ")
178 _ = enc.Encode(&protocol.StatsResponse{
179 Rooms: rooms,
180 Clients: clients,
181 })
182 })
183 })
184
185 g.Go(func() error {
186 return srv.Run(ctx)
187 })
188
189 httpSrv := http.Server{Addr: args.Addr, Handler: r}
190
191 g.Go(func() error {
192 <-ctx.Done()
193
194 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
195 defer cancel()
196
197 return httpSrv.Shutdown(ctx)
198 })
199
200 g.Go(func() error {
201 return httpSrv.ListenAndServe()
202 })
203
204 log.Fatal(g.Wait())
205}
206
// httpErr replies with the given HTTP status code. The plain-text body is
// the standard status text for that code (e.g. "Not Found").
func httpErr(w http.ResponseWriter, code int) {
	msg := http.StatusText(code)
	http.Error(w, msg, code)
}
210
// stringPtr returns a pointer to a new copy of s. It is used to fill optional
// *string fields such as protocol.RoomResponse.Error.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
214
215func init() {
216 queryparam.DefaultParser.ValueParsers[reflect.TypeOf(uuid.UUID{})] = func(value string, _ string) (reflect.Value, error) {
217 id, err := uuid.FromString(value)
218 return reflect.ValueOf(id), err
219 }
220}