1package main
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log"
8 "math/rand"
9 "net/http"
10 "os"
11 "time"
12
13 "github.com/go-chi/chi"
14 "github.com/go-chi/chi/middleware"
15 "github.com/jessevdk/go-flags"
16 "github.com/posener/ctxutil"
17 "github.com/prometheus/client_golang/prometheus/promhttp"
18 "github.com/tomwright/queryparam/v4"
19 "github.com/zikaeroh/codies/internal/pkger"
20 "github.com/zikaeroh/codies/internal/protocol"
21 "github.com/zikaeroh/codies/internal/responder"
22 "github.com/zikaeroh/codies/internal/server"
23 "github.com/zikaeroh/codies/internal/version"
24 "github.com/zikaeroh/ctxlog"
25 "go.uber.org/zap"
26 "golang.org/x/sync/errgroup"
27 "nhooyr.io/websocket"
28)
29
// args holds the server's command-line options. main fills it with
// flags.Parse; each option may also be supplied via the CODIES_* environment
// variable named in its env tag. Addr defaults to ":5000", and main rejects
// any run that does not set exactly one of Prod or Debug.
var args = struct {
	// Addr is the listen address for the main HTTP server.
	Addr string `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
	// Origins are extra WebSocket origin patterns; comma-separated in the env var.
	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
	// Prod enables production mode.
	Prod bool `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
	// Debug enables debug mode.
	Debug bool `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
}{Addr: ":5000"}
38
// wsOpts holds the options passed to every websocket.Accept call in this file
// (the /api/ws handler and the rejection path in checkVersion). It is nil until
// main assigns it after flag parsing. main sets OriginPatterns from --origins,
// enables context-takeover compression, and in debug mode sets
// InsecureSkipVerify so any origin is accepted.
var wsOpts *websocket.AcceptOptions
40
// main is the server entry point. It parses options (flags or CODIES_*
// environment variables), sets up logging, configures WebSocket acceptance,
// registers the HTTP routes, and runs the game server plus the HTTP
// listener(s) in a single errgroup. It blocks until the group finishes and
// then reports the group's error through ctxlog.Fatal.
func main() {
	rand.Seed(time.Now().Unix())

	if _, err := flags.Parse(&args); err != nil {
		// Default flag parser prints messages, so just exit.
		os.Exit(1)
	}

	// Exactly one of --prod and --debug must be given.
	switch {
	case args.Prod && args.Debug:
		log.Fatal("cannot specify both --prod and --debug")
	case !args.Prod && !args.Debug:
		log.Fatal("missing required option --prod or --debug")
	}

	// NOTE(review): ctxutil.Interrupt presumably returns a context cancelled on
	// SIGINT; its cancellation is what starts graceful shutdown in runServer.
	ctx := ctxutil.Interrupt()

	logger := ctxlog.New(args.Debug)
	defer zap.RedirectStdLog(logger)()
	ctx = ctxlog.WithLogger(ctx, logger)

	ctxlog.Info(ctx, "starting", zap.String("version", version.Version()))

	wsOpts = &websocket.AcceptOptions{
		OriginPatterns:  args.Origins,
		CompressionMode: websocket.CompressionContextTakeover,
	}

	if args.Debug {
		ctxlog.Info(ctx, "starting in debug mode, allowing any WebSocket origin host")
		wsOpts.InsecureSkipVerify = true
	} else if !version.VersionSet() {
		// Production mode version-checks clients against version.Version()
		// (checkVersion, registered below), so a real version is required.
		ctxlog.Fatal(ctx, "running production build without version set")
	}

	// The group's context is cancelled when the parent context is, or when any
	// goroutine in the group returns a non-nil error.
	g, ctx := errgroup.WithContext(ctx)

	srv := server.NewServer()

	r := chi.NewMux()

	// Count requests handled by this mux in the metricRequest counter.
	r.Use(func(next http.Handler) http.Handler {
		return promhttp.InstrumentHandlerCounter(metricRequest, next)
	})

	r.Use(middleware.Heartbeat("/ping"))
	r.Use(middleware.Recoverer)
	r.NotFound(staticHandler().ServeHTTP)

	r.Group(func(r chi.Router) {
		// API responses are marked uncacheable.
		r.Use(middleware.NoCache)

		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
			responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()}))
		})

		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
			rooms, clients := srv.Stats()
			responder.Respond(w,
				responder.Body(&protocol.StatsResponse{
					Rooms:   rooms,
					Clients: clients,
				}),
				responder.Pretty(true),
			)
		})

		r.Group(func(r chi.Router) {
			// Outside debug mode, the endpoints below reject clients whose
			// version does not match the server's.
			if !args.Debug {
				r.Use(checkVersion)
			}

			r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
				query := &protocol.ExistsQuery{}
				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
					responder.Respond(w, responder.Status(http.StatusBadRequest))
					return
				}

				room := srv.FindRoomByID(query.RoomID)
				if room == nil {
					responder.Respond(w, responder.Status(http.StatusNotFound))
				} else {
					responder.Respond(w, responder.Status(http.StatusOK))
				}
			})

			r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
				// Upper bound on the JSON request body. The fields read below
				// (Create, RoomName, RoomPass) are small, so this only rejects
				// abusive clients streaming oversized payloads.
				const maxRoomRequestBytes = 1 << 20

				defer r.Body.Close()

				body := http.MaxBytesReader(w, r.Body, maxRoomRequestBytes)

				req := &protocol.RoomRequest{}
				if err := json.NewDecoder(body).Decode(req); err != nil {
					responder.Respond(w, responder.Status(http.StatusBadRequest))
					return
				}

				if msg, valid := req.Valid(); !valid {
					responder.Respond(w,
						responder.Status(http.StatusBadRequest),
						responder.Body(&protocol.RoomResponse{
							Error: stringPtr(msg),
						}),
					)
					return
				}

				var room *server.Room
				if req.Create {
					var err error
					room, err = srv.CreateRoom(ctx, req.RoomName, req.RoomPass)
					if err != nil {
						switch err {
						case server.ErrRoomExists:
							responder.Respond(w,
								responder.Status(http.StatusBadRequest),
								responder.Body(&protocol.RoomResponse{
									Error: stringPtr("Room already exists."),
								}),
							)
						case server.ErrTooManyRooms:
							responder.Respond(w,
								responder.Status(http.StatusServiceUnavailable),
								responder.Body(&protocol.RoomResponse{
									Error: stringPtr("Too many rooms."),
								}),
							)
						default:
							responder.Respond(w,
								responder.Status(http.StatusInternalServerError),
								responder.Body(&protocol.RoomResponse{
									Error: stringPtr("An unknown error occurred."),
								}),
							)
						}
						return
					}
				} else {
					room = srv.FindRoom(req.RoomName)
					// A missing room and a wrong password get the same
					// response, so callers cannot probe which rooms exist.
					if room == nil || room.Password != req.RoomPass {
						responder.Respond(w,
							responder.Status(http.StatusNotFound),
							responder.Body(&protocol.RoomResponse{
								Error: stringPtr("Room not found or password does not match."),
							}),
						)
						return
					}
				}

				responder.Respond(w, responder.Body(&protocol.RoomResponse{
					ID: &room.ID,
				}))
			})

			r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
				query := &protocol.WSQuery{}
				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
					responder.Respond(w, responder.Status(http.StatusBadRequest))
					return
				}

				if _, valid := query.Valid(); !valid {
					responder.Respond(w, responder.Status(http.StatusBadRequest))
					return
				}

				room := srv.FindRoomByID(query.RoomID)
				if room == nil {
					responder.Respond(w, responder.Status(http.StatusBadRequest))
					return
				}

				c, err := websocket.Accept(w, r, wsOpts)
				if err != nil {
					// Accept has already written an error response to w.
					return
				}

				// HandleConn keeps running after this handler returns. It uses
				// the group's context rather than r.Context(), and is tracked
				// by g so that g.Wait accounts for it.
				g.Go(func() error {
					room.HandleConn(ctx, query.Nickname, c)
					return nil
				})
			})
		})
	})

	g.Go(func() error {
		return srv.Run(ctx)
	})

	runServer(ctx, g, args.Addr, r)

	// The Prometheus metrics listener runs on its own port, in production only.
	if args.Prod {
		runServer(ctx, g, ":2112", prometheusHandler())
	}

	exitErr := g.Wait()
	ctxlog.Fatal(ctx, "exited", zap.Error(exitErr))
}
240
// staticHandler serves the built frontend from the pkger-packaged
// /frontend/build directory, compressing responses at level 5.
//
// Paths under /static/ and /favicon/ are served as-is. Every other path
// (e.g. the index page) is wrapped in middleware.NoCache so browsers do not
// cache it.
// NOTE(review): presumably /static/ assets are content-hashed by the frontend
// build, which is what makes leaving them cacheable safe. Confirm.
func staticHandler() http.Handler {
	files := http.FileServer(pkger.Dir("/frontend/build"))

	mux := chi.NewMux()
	mux.Use(middleware.Compress(5))

	for _, cacheable := range []string{"/static/*", "/favicon/*"} {
		mux.Handle(cacheable, files)
	}

	// With returns an inline router carrying the extra middleware, the same
	// as Group followed by Use.
	mux.With(middleware.NoCache).Handle("/*", files)

	return mux
}
258
// checkVersion is middleware that only lets through requests from clients
// whose version equals version.Version(). The client version is read from
// the X-CODIES-VERSION header or, failing that, the codiesVersion query
// parameter.
//
// Mismatched clients are rejected with the same reason text:
//   - WebSocket upgrade requests are accepted and then immediately closed
//     with close code 4418.
//   - Other requests get HTTP 418 (I'm a teapot).
func checkVersion(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		want := version.Version()

		headerMatches := r.Header.Get("X-CODIES-VERSION") == want
		if headerMatches || r.URL.Query().Get("codiesVersion") == want {
			next.ServeHTTP(w, r)
			return
		}

		reason := fmt.Sprintf("client version too old, please reload to get %s", want)

		// NOTE(review): RFC 6455 treats the Upgrade value as case-insensitive.
		// This exact comparison relies on clients sending lowercase "websocket".
		if r.Header.Get("Upgrade") != "websocket" {
			w.WriteHeader(http.StatusTeapot)
			fmt.Fprint(w, reason)
			return
		}

		// Complete the handshake so the browser receives the reason as a
		// close frame instead of a generic upgrade failure.
		conn, err := websocket.Accept(w, r, wsOpts)
		if err != nil {
			return
		}
		conn.Close(4418, reason)
	})
}
290
// runServer starts an HTTP server for handler on addr and registers two
// goroutines on g:
//   - one runs ListenAndServe until the server stops;
//   - one waits for ctx to be cancelled, then calls Shutdown with a 10-second
//     grace period for in-flight requests.
//
// ListenAndServe returns http.ErrServerClosed (unwrapped) once Shutdown has
// been called. That is the expected, clean outcome, so it is not reported as
// the group's error. Any other error, such as failing to bind addr, is
// returned and cancels the group.
func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
	httpSrv := &http.Server{
		Addr:    addr,
		Handler: handler,
		// Limit how long a client may take to send request headers. This
		// mitigates slow-header connection exhaustion. Only the header phase is
		// bounded: with ReadTimeout unset, net/http clears the read deadline
		// once headers are read. ReadTimeout and WriteTimeout stay unset
		// because they could interfere with long-lived WebSocket connections.
		ReadHeaderTimeout: 10 * time.Second,
	}

	g.Go(func() error {
		<-ctx.Done()

		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		return httpSrv.Shutdown(shutdownCtx)
	})

	g.Go(func() error {
		if err := httpSrv.ListenAndServe(); err != http.ErrServerClosed {
			return err
		}
		return nil
	})
}
307
// prometheusHandler returns the handler for the metrics listener. It serves
// the Prometheus exposition endpoint at exactly /metrics. Every other path
// falls through to the ServeMux's default 404 response.
func prometheusHandler() http.Handler {
	exporter := promhttp.Handler()

	metricsMux := http.NewServeMux()
	metricsMux.Handle("/metrics", exporter)

	return metricsMux
}