1package main
2
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/pkger"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/responder"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"github.com/zikaeroh/ctxlog"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
29
30var args = struct {
31 Addr string `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
32 Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
33 Prod bool `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
34 Debug bool `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
35}{
36 Addr: ":5000",
37}
38
39var wsOpts *websocket.AcceptOptions
40
41func main() {
42 if argv := os.Args[1:]; len(argv) > 0 && argv[0] == "version" {
43 fmt.Println(version.Version())
44 return
45 }
46
47 rand.Seed(time.Now().Unix())
48
49 if _, err := flags.Parse(&args); err != nil {
50 // Default flag parser prints messages, so just exit.
51 os.Exit(1)
52 }
53
54 if !args.Prod && !args.Debug {
55 log.Fatal("missing required option --prod or --debug")
56 } else if args.Prod && args.Debug {
57 log.Fatal("must specify either --prod or --debug")
58 }
59
60 ctx := ctxutil.Interrupt()
61
62 logger := ctxlog.New(args.Debug)
63 defer zap.RedirectStdLog(logger)()
64 ctx = ctxlog.WithLogger(ctx, logger)
65
66 ctxlog.Info(ctx, "starting", zap.String("version", version.Version()))
67
68 wsOpts = &websocket.AcceptOptions{
69 OriginPatterns: args.Origins,
70 CompressionMode: websocket.CompressionContextTakeover,
71 }
72
73 if args.Debug {
74 ctxlog.Info(ctx, "starting in debug mode, allowing any WebSocket origin host")
75 wsOpts.InsecureSkipVerify = true
76 } else {
77 if !version.VersionSet() {
78 ctxlog.Fatal(ctx, "running production build without version set")
79 }
80 }
81
82 g, ctx := errgroup.WithContext(ctx)
83
84 srv := server.NewServer()
85
86 r := chi.NewMux()
87
88 r.Use(func(next http.Handler) http.Handler {
89 return promhttp.InstrumentHandlerCounter(metricRequest, next)
90 })
91
92 r.Use(middleware.Heartbeat("/ping"))
93 r.Use(middleware.Recoverer)
94 r.NotFound(staticHandler().ServeHTTP)
95
96 r.Group(func(r chi.Router) {
97 r.Use(middleware.NoCache)
98
99 r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
100 responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()}))
101 })
102
103 r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
104 rooms, clients := srv.Stats()
105 responder.Respond(w,
106 responder.Body(&protocol.StatsResponse{
107 Rooms: rooms,
108 Clients: clients,
109 }),
110 responder.Pretty(true),
111 )
112 })
113
114 r.Group(func(r chi.Router) {
115 if !args.Debug {
116 r.Use(checkVersion)
117 }
118
119 r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
120 query := &protocol.ExistsQuery{}
121 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
122 responder.Respond(w, responder.Status(http.StatusBadRequest))
123 return
124 }
125
126 room := srv.FindRoomByID(query.RoomID)
127 if room == nil {
128 responder.Respond(w, responder.Status(http.StatusNotFound))
129 } else {
130 responder.Respond(w, responder.Status(http.StatusOK))
131 }
132 })
133
134 r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
135 defer r.Body.Close()
136
137 req := &protocol.RoomRequest{}
138 if err := json.NewDecoder(r.Body).Decode(req); err != nil {
139 responder.Respond(w, responder.Status(http.StatusBadRequest))
140 return
141 }
142
143 if msg, valid := req.Valid(); !valid {
144 responder.Respond(w,
145 responder.Status(http.StatusBadRequest),
146 responder.Body(&protocol.RoomResponse{
147 Error: stringPtr(msg),
148 }),
149 )
150 return
151 }
152
153 var room *server.Room
154 if req.Create {
155 var err error
156 room, err = srv.CreateRoom(ctx, req.RoomName, req.RoomPass)
157 if err != nil {
158 switch err {
159 case server.ErrRoomExists:
160 responder.Respond(w,
161 responder.Status(http.StatusBadRequest),
162 responder.Body(&protocol.RoomResponse{
163 Error: stringPtr("Room already exists."),
164 }),
165 )
166 case server.ErrTooManyRooms:
167 responder.Respond(w,
168 responder.Status(http.StatusServiceUnavailable),
169 responder.Body(&protocol.RoomResponse{
170 Error: stringPtr("Too many rooms."),
171 }),
172 )
173 default:
174 responder.Respond(w,
175 responder.Status(http.StatusInternalServerError),
176 responder.Body(&protocol.RoomResponse{
177 Error: stringPtr("An unknown error occurred."),
178 }),
179 )
180 }
181 return
182 }
183 } else {
184 room = srv.FindRoom(req.RoomName)
185 if room == nil || room.Password != req.RoomPass {
186 responder.Respond(w,
187 responder.Status(http.StatusNotFound),
188 responder.Body(&protocol.RoomResponse{
189 Error: stringPtr("Room not found or password does not match."),
190 }),
191 )
192 return
193 }
194 }
195
196 responder.Respond(w, responder.Body(&protocol.RoomResponse{
197 ID: &room.ID,
198 }))
199 })
200
201 r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
202 query := &protocol.WSQuery{}
203 if err := queryparam.Parse(r.URL.Query(), query); err != nil {
204 responder.Respond(w, responder.Status(http.StatusBadRequest))
205 return
206 }
207
208 if _, valid := query.Valid(); !valid {
209 responder.Respond(w, responder.Status(http.StatusBadRequest))
210 return
211 }
212
213 room := srv.FindRoomByID(query.RoomID)
214 if room == nil {
215 responder.Respond(w, responder.Status(http.StatusBadRequest))
216 return
217 }
218
219 c, err := websocket.Accept(w, r, wsOpts)
220 if err != nil {
221 return
222 }
223
224 g.Go(func() error {
225 room.HandleConn(ctx, query.Nickname, c)
226 return nil
227 })
228 })
229 })
230 })
231
232 g.Go(func() error {
233 return srv.Run(ctx)
234 })
235
236 runServer(ctx, g, args.Addr, r)
237
238 if args.Prod {
239 runServer(ctx, g, ":2112", prometheusHandler())
240 }
241
242 exitErr := g.Wait()
243 ctxlog.Fatal(ctx, "exited", zap.Error(exitErr))
244}
245
246func staticHandler() http.Handler {
247 fs := pkger.Dir("/frontend/build")
248 fsh := http.FileServer(fs)
249
250 r := chi.NewMux()
251 r.Use(middleware.Compress(5))
252
253 r.Handle("/static/*", fsh)
254 r.Handle("/favicon/*", fsh)
255
256 r.Group(func(r chi.Router) {
257 r.Use(middleware.NoCache)
258 r.Handle("/*", fsh)
259 })
260
261 return r
262}
263
264func checkVersion(next http.Handler) http.Handler {
265 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
266 want := version.Version()
267
268 toCheck := []string{
269 r.Header.Get("X-CODIES-VERSION"),
270 r.URL.Query().Get("codiesVersion"),
271 }
272
273 for _, got := range toCheck {
274 if got == want {
275 next.ServeHTTP(w, r)
276 return
277 }
278 }
279
280 reason := fmt.Sprintf("client version too old, please reload to get %s", want)
281
282 if r.Header.Get("Upgrade") == "websocket" {
283 c, err := websocket.Accept(w, r, wsOpts)
284 if err != nil {
285 return
286 }
287 c.Close(4418, reason)
288 return
289 }
290
291 w.WriteHeader(http.StatusTeapot)
292 fmt.Fprint(w, reason)
293 })
294}
295
296func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
297 httpSrv := http.Server{Addr: addr, Handler: handler}
298
299 g.Go(func() error {
300 <-ctx.Done()
301
302 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
303 defer cancel()
304
305 return httpSrv.Shutdown(ctx)
306 })
307
308 g.Go(func() error {
309 return httpSrv.ListenAndServe()
310 })
311}
312
313func prometheusHandler() http.Handler {
314 mux := http.NewServeMux()
315 mux.Handle("/metrics", promhttp.Handler())
316 return mux
317}