1package main
2
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/pkger"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/responder"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"github.com/zikaeroh/ctxlog"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
29
// args holds the server configuration. main fills it via flags.Parse from
// command-line flags or the matching CODIES_* environment variables; fields
// not supplied keep the defaults given in the composite literal below.
var args = struct {
	// Addr is the listen address of the main HTTP server.
	Addr string `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`

	// Origins lists extra origin patterns accepted for WebSocket connections;
	// in the environment variable they are comma-separated.
	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`

	// Prod selects production mode. main requires exactly one of Prod/Debug.
	Prod bool `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`

	// Debug selects debug mode. main requires exactly one of Prod/Debug.
	Debug bool `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
}{Addr: ":5000"}
38
// wsOpts are the options passed to websocket.Accept, both by the /api/ws
// handler and by checkVersion when it rejects an outdated WebSocket client.
// It is nil until main assigns it, which happens before the HTTP servers are
// started.
var wsOpts *websocket.AcceptOptions
40
// main is the codies server entry point.
//
// With "version" as the first argument it prints the build version and exits.
// Otherwise it parses args, requires exactly one of --prod or --debug, sets up
// logging and the WebSocket accept options, registers the HTTP routes, and
// runs the game server, the main HTTP listener on args.Addr and (in production
// only) a Prometheus metrics listener on :2112 in one errgroup. main finishes
// by logging "exited" at fatal level with the group's error once every
// goroutine in the group has returned.
func main() {
	if argv := os.Args[1:]; len(argv) > 0 && argv[0] == "version" {
		fmt.Println(version.Version())
		return
	}

	// Seeds the global math/rand source; presumably used elsewhere in the
	// program (e.g. by the server package) — confirm.
	rand.Seed(time.Now().Unix())

	if _, err := flags.Parse(&args); err != nil {
		// Default flag parser prints messages, so just exit.
		os.Exit(1)
	}

	// Exactly one mode must be selected.
	if !args.Prod && !args.Debug {
		log.Fatal("missing required option --prod or --debug")
	} else if args.Prod && args.Debug {
		log.Fatal("must specify either --prod or --debug")
	}

	// NOTE(review): ctxutil.Interrupt presumably returns a context that is
	// canceled on an interrupt signal — confirm against the ctxutil docs.
	ctx := ctxutil.Interrupt()

	logger := ctxlog.New(args.Debug)
	defer zap.RedirectStdLog(logger)()
	ctx = ctxlog.WithLogger(ctx, logger)

	ctxlog.Info(ctx, "starting", zap.String("version", version.Version()))

	wsOpts = &websocket.AcceptOptions{
		OriginPatterns:  args.Origins,
		CompressionMode: websocket.CompressionContextTakeover,
	}

	// Debug mode disables the WebSocket origin check; production mode
	// refuses to run a build whose version was not set.
	if args.Debug {
		ctxlog.Info(ctx, "starting in debug mode, allowing any WebSocket origin host")
		wsOpts.InsecureSkipVerify = true
	} else if !version.IsSet() {
		ctxlog.Fatal(ctx, "running production build without version set")
	}

	// From here on ctx is the errgroup's context: it is canceled when the
	// parent context is canceled or when any goroutine in g returns an error.
	g, ctx := errgroup.WithContext(ctx)

	srv := server.NewServer()

	r := chi.NewMux()

	// Instrument every request handled by r with the metricRequest metric
	// (declared outside this view).
	r.Use(func(next http.Handler) http.Handler {
		return promhttp.InstrumentHandlerCounter(metricRequest, next)
	})

	r.Use(middleware.Heartbeat("/ping"))
	r.Use(middleware.Recoverer)
	// Paths not matched by any API route are served by the static frontend.
	r.NotFound(staticHandler().ServeHTTP)

	// API routes; all responses are marked non-cacheable.
	r.Group(func(r chi.Router) {
		r.Use(middleware.NoCache)

		// Reports the server's current time.
		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
			responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()}))
		})

		// Reports the current room and client counts.
		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
			rooms, clients := srv.Stats()
			responder.Respond(w,
				responder.Body(&protocol.StatsResponse{
					Rooms:   rooms,
					Clients: clients,
				}),
				responder.Pretty(true),
			)
		})

		// Routes in this group additionally require a matching client
		// version, except in debug mode.
		r.Group(func(r chi.Router) {
			if !args.Debug {
				r.Use(checkVersion)
			}

			// 200 if a room with the queried ID exists, 404 if not,
			// 400 if the query string cannot be parsed.
			r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
				query := &protocol.ExistsQuery{}
				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
					responder.Respond(w, responder.Status(http.StatusBadRequest))
					return
				}

				room := srv.FindRoomByID(query.RoomID)
				if room == nil {
					responder.Respond(w, responder.Status(http.StatusNotFound))
				} else {
					responder.Respond(w, responder.Status(http.StatusOK))
				}
			})

			// Creates a room (req.Create) or joins an existing one by name and
			// password, and responds with the room's ID. Failures respond with
			// an Error message and a 4xx/5xx status.
			r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
				defer r.Body.Close()

				req := &protocol.RoomRequest{}
				if err := json.NewDecoder(r.Body).Decode(req); err != nil {
					responder.Respond(w, responder.Status(http.StatusBadRequest))
					return
				}

				if msg, valid := req.Valid(); !valid {
					responder.Respond(w,
						responder.Status(http.StatusBadRequest),
						responder.Body(&protocol.RoomResponse{
							Error: stringPtr(msg),
						}),
					)
					return
				}

				var room *server.Room
				if req.Create {
					var err error
					// Uses the server-lifetime ctx rather than r.Context(),
					// presumably because the room outlives this request.
					room, err = srv.CreateRoom(ctx, req.RoomName, req.RoomPass)
					if err != nil {
						switch err {
						case server.ErrRoomExists:
							responder.Respond(w,
								responder.Status(http.StatusBadRequest),
								responder.Body(&protocol.RoomResponse{
									Error: stringPtr("Room already exists."),
								}),
							)
						case server.ErrTooManyRooms:
							responder.Respond(w,
								responder.Status(http.StatusServiceUnavailable),
								responder.Body(&protocol.RoomResponse{
									Error: stringPtr("Too many rooms."),
								}),
							)
						default:
							responder.Respond(w,
								responder.Status(http.StatusInternalServerError),
								responder.Body(&protocol.RoomResponse{
									Error: stringPtr("An unknown error occurred."),
								}),
							)
						}
						return
					}
				} else {
					room = srv.FindRoom(req.RoomName)
					// A missing room and a wrong password get the same response,
					// so callers cannot tell which one occurred.
					// NOTE(review): != is not a constant-time comparison; consider
					// crypto/subtle if password timing matters here.
					if room == nil || room.Password != req.RoomPass {
						responder.Respond(w,
							responder.Status(http.StatusNotFound),
							responder.Body(&protocol.RoomResponse{
								Error: stringPtr("Room not found or password does not match."),
							}),
						)
						return
					}
				}

				responder.Respond(w, responder.Body(&protocol.RoomResponse{
					ID: &room.ID,
				}))
			})

			// Upgrades to a WebSocket and hands the connection to the room
			// identified by the query.
			r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
				query := &protocol.WSQuery{}
				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
					responder.Respond(w, responder.Status(http.StatusBadRequest))
					return
				}

				if _, valid := query.Valid(); !valid {
					responder.Respond(w, responder.Status(http.StatusBadRequest))
					return
				}

				room := srv.FindRoomByID(query.RoomID)
				if room == nil {
					responder.Respond(w, responder.Status(http.StatusBadRequest))
					return
				}

				// On failure websocket.Accept writes its own error response.
				c, err := websocket.Accept(w, r, wsOpts)
				if err != nil {
					return
				}

				// The connection is served in g under the server-lifetime ctx,
				// not the request's context, and this handler returns right away.
				g.Go(func() error {
					room.HandleConn(ctx, query.Nickname, c)
					return nil
				})
			})
		})
	})

	g.Go(func() error {
		return srv.Run(ctx)
	})

	runServer(ctx, g, args.Addr, r)

	if args.Prod {
		runServer(ctx, g, ":2112", prometheusHandler())
	}

	// Wait blocks until every goroutine in g has returned and yields the
	// first non-nil error, if any.
	exitErr := g.Wait()
	ctxlog.Fatal(ctx, "exited", zap.Error(exitErr))
}
243
// staticHandler serves the bundled frontend from the pkger directory
// "/frontend/build".
//
// Every response passes through middleware.Compress(5). Paths under /static/
// and /favicon/ are served without cache headers being added; every other
// path is served through middleware.NoCache.
func staticHandler() http.Handler {
	files := http.FileServer(pkger.Dir("/frontend/build"))

	mux := chi.NewMux()
	mux.Use(middleware.Compress(5))

	for _, cached := range []string{"/static/*", "/favicon/*"} {
		mux.Handle(cached, files)
	}

	// With creates an inline sub-router carrying the extra middleware;
	// chi's Group is implemented the same way.
	mux.With(middleware.NoCache).Handle("/*", files)

	return mux
}
261
// checkVersion is middleware that only passes requests from clients built for
// exactly this server's version (version.Version()) on to next.
//
// The client's version is read from the X-CODIES-VERSION header and from the
// codiesVersion query parameter; a match in either is sufficient. The query
// parameter presumably exists for browser WebSocket handshakes, which cannot
// carry custom headers.
//
// On a mismatch, WebSocket upgrade requests are accepted and immediately
// closed with status 4418, an application-defined code in the 4000-4999
// private range, with the reason text. This lets the client read the reason
// over the WebSocket. All other requests get HTTP 418 with the reason as the
// body.
func checkVersion(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		want := version.Version()

		if r.Header.Get("X-CODIES-VERSION") == want || r.URL.Query().Get("codiesVersion") == want {
			next.ServeHTTP(w, r)
			return
		}

		reason := fmt.Sprintf("client version too old, please reload to get %s", want)

		if isWebSocketUpgrade(r) {
			c, err := websocket.Accept(w, r, wsOpts)
			if err != nil {
				// websocket.Accept writes its own error response on failure.
				return
			}
			// Best effort: the connection is being rejected anyway, so a
			// failure to deliver the close frame is not actionable.
			_ = c.Close(4418, reason)
			return
		}

		w.WriteHeader(http.StatusTeapot)
		fmt.Fprint(w, reason)
	})
}

// isWebSocketUpgrade reports whether r requests an upgrade to the WebSocket
// protocol. Per RFC 6455 the "websocket" token is matched case-insensitively,
// and the Upgrade header may list several comma-separated protocols, so an
// exact string comparison would miss valid handshakes.
func isWebSocketUpgrade(r *http.Request) bool {
	for _, value := range r.Header["Upgrade"] {
		for _, token := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(token), "websocket") {
				return true
			}
		}
	}
	return false
}
293
294func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
295 httpSrv := http.Server{Addr: addr, Handler: handler}
296
297 g.Go(func() error {
298 <-ctx.Done()
299
300 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
301 defer cancel()
302
303 return httpSrv.Shutdown(ctx)
304 })
305
306 g.Go(httpSrv.ListenAndServe)
307}
308
// prometheusHandler returns a handler that exposes Prometheus metrics at
// /metrics only; every other path gets the ServeMux's default 404.
func prometheusHandler() http.Handler {
	metrics := promhttp.Handler()

	router := http.NewServeMux()
	router.Handle("/metrics", metrics)

	return router
}