1package main
  2
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/pkger"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/responder"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"github.com/zikaeroh/ctxlog"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
 29
 30var args = struct {
 31	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
 32	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
 33	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
 34	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
 35}{
 36	Addr: ":5000",
 37}
 38
 39var wsOpts *websocket.AcceptOptions
 40
 41func main() {
 42	if argv := os.Args[1:]; len(argv) > 0 && argv[0] == "version" {
 43		fmt.Println(version.Version())
 44		return
 45	}
 46
 47	rand.Seed(time.Now().Unix())
 48
 49	if _, err := flags.Parse(&args); err != nil {
 50		// Default flag parser prints messages, so just exit.
 51		os.Exit(1)
 52	}
 53
 54	if !args.Prod && !args.Debug {
 55		log.Fatal("missing required option --prod or --debug")
 56	} else if args.Prod && args.Debug {
 57		log.Fatal("must specify either --prod or --debug")
 58	}
 59
 60	ctx := ctxutil.Interrupt()
 61
 62	logger := ctxlog.New(args.Debug)
 63	defer zap.RedirectStdLog(logger)()
 64	ctx = ctxlog.WithLogger(ctx, logger)
 65
 66	ctxlog.Info(ctx, "starting", zap.String("version", version.Version()))
 67
 68	wsOpts = &websocket.AcceptOptions{
 69		OriginPatterns:  args.Origins,
 70		CompressionMode: websocket.CompressionContextTakeover,
 71	}
 72
 73	if args.Debug {
 74		ctxlog.Info(ctx, "starting in debug mode, allowing any WebSocket origin host")
 75		wsOpts.InsecureSkipVerify = true
 76	} else if !version.IsSet() {
 77		ctxlog.Fatal(ctx, "running production build without version set")
 78	}
 79
 80	g, ctx := errgroup.WithContext(ctx)
 81
 82	srv := server.NewServer()
 83
 84	r := chi.NewMux()
 85
 86	r.Use(func(next http.Handler) http.Handler {
 87		return promhttp.InstrumentHandlerCounter(metricRequest, next)
 88	})
 89
 90	r.Use(middleware.Heartbeat("/ping"))
 91	r.Use(middleware.Recoverer)
 92	r.NotFound(staticHandler().ServeHTTP)
 93
 94	r.Group(func(r chi.Router) {
 95		r.Use(middleware.NoCache)
 96
 97		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
 98			responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()}))
 99		})
100
101		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
102			rooms, clients := srv.Stats()
103			responder.Respond(w,
104				responder.Body(&protocol.StatsResponse{
105					Rooms:   rooms,
106					Clients: clients,
107				}),
108				responder.Pretty(true),
109			)
110		})
111
112		r.Group(func(r chi.Router) {
113			if !args.Debug {
114				r.Use(checkVersion)
115			}
116
117			r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
118				query := &protocol.ExistsQuery{}
119				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
120					responder.Respond(w, responder.Status(http.StatusBadRequest))
121					return
122				}
123
124				room := srv.FindRoomByID(query.RoomID)
125				if room == nil {
126					responder.Respond(w, responder.Status(http.StatusNotFound))
127				} else {
128					responder.Respond(w, responder.Status(http.StatusOK))
129				}
130			})
131
132			r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
133				defer r.Body.Close()
134
135				req := &protocol.RoomRequest{}
136				if err := json.NewDecoder(r.Body).Decode(req); err != nil {
137					responder.Respond(w, responder.Status(http.StatusBadRequest))
138					return
139				}
140
141				if msg, valid := req.Valid(); !valid {
142					responder.Respond(w,
143						responder.Status(http.StatusBadRequest),
144						responder.Body(&protocol.RoomResponse{
145							Error: stringPtr(msg),
146						}),
147					)
148					return
149				}
150
151				var room *server.Room
152				if req.Create {
153					var err error
154					room, err = srv.CreateRoom(ctx, req.RoomName, req.RoomPass)
155					if err != nil {
156						switch err {
157						case server.ErrRoomExists:
158							responder.Respond(w,
159								responder.Status(http.StatusBadRequest),
160								responder.Body(&protocol.RoomResponse{
161									Error: stringPtr("Room already exists."),
162								}),
163							)
164						case server.ErrTooManyRooms:
165							responder.Respond(w,
166								responder.Status(http.StatusServiceUnavailable),
167								responder.Body(&protocol.RoomResponse{
168									Error: stringPtr("Too many rooms."),
169								}),
170							)
171						default:
172							responder.Respond(w,
173								responder.Status(http.StatusInternalServerError),
174								responder.Body(&protocol.RoomResponse{
175									Error: stringPtr("An unknown error occurred."),
176								}),
177							)
178						}
179						return
180					}
181				} else {
182					room = srv.FindRoom(req.RoomName)
183					if room == nil || room.Password != req.RoomPass {
184						responder.Respond(w,
185							responder.Status(http.StatusNotFound),
186							responder.Body(&protocol.RoomResponse{
187								Error: stringPtr("Room not found or password does not match."),
188							}),
189						)
190						return
191					}
192				}
193
194				responder.Respond(w, responder.Body(&protocol.RoomResponse{
195					ID: &room.ID,
196				}))
197			})
198
199			r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
200				query := &protocol.WSQuery{}
201				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
202					responder.Respond(w, responder.Status(http.StatusBadRequest))
203					return
204				}
205
206				if _, valid := query.Valid(); !valid {
207					responder.Respond(w, responder.Status(http.StatusBadRequest))
208					return
209				}
210
211				room := srv.FindRoomByID(query.RoomID)
212				if room == nil {
213					responder.Respond(w, responder.Status(http.StatusBadRequest))
214					return
215				}
216
217				c, err := websocket.Accept(w, r, wsOpts)
218				if err != nil {
219					return
220				}
221
222				g.Go(func() error {
223					room.HandleConn(ctx, query.Nickname, c)
224					return nil
225				})
226			})
227		})
228	})
229
230	g.Go(func() error {
231		return srv.Run(ctx)
232	})
233
234	runServer(ctx, g, args.Addr, r)
235
236	if args.Prod {
237		runServer(ctx, g, ":2112", prometheusHandler())
238	}
239
240	exitErr := g.Wait()
241	ctxlog.Fatal(ctx, "exited", zap.Error(exitErr))
242}
243
// staticHandler serves the bundled frontend from pkger's "/frontend/build"
// directory, with responses passed through chi's Compress middleware at
// level 5.
//
// Paths under /static/ and /favicon/ are served as-is. Every other path
// (such as the index page) additionally goes through NoCache, presumably so
// browsers always pick up the current build.
func staticHandler() http.Handler {
	files := http.FileServer(pkger.Dir("/frontend/build"))

	mux := chi.NewMux()
	mux.Use(middleware.Compress(5))

	for _, pattern := range []string{"/static/*", "/favicon/*"} {
		mux.Handle(pattern, files)
	}

	mux.With(middleware.NoCache).Handle("/*", files)

	return mux
}
261
// checkVersion is middleware that only lets a request through when the
// client's version exactly matches version.Version(). The client version is
// read from the X-CODIES-VERSION header and from the codiesVersion query
// parameter; a match in either is enough.
//
// Mismatched requests are rejected with a "client version too old" reason:
//   - WebSocket upgrade requests are accepted and immediately closed with
//     close code 4418 and the reason, so the client receives it through the
//     WebSocket API.
//   - All other requests get HTTP 418 with the reason as the body.
func checkVersion(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		want := version.Version()

		toCheck := []string{
			r.Header.Get("X-CODIES-VERSION"),
			r.URL.Query().Get("codiesVersion"),
		}

		for _, got := range toCheck {
			if got == want {
				next.ServeHTTP(w, r)
				return
			}
		}

		reason := fmt.Sprintf("client version too old, please reload to get %s", want)

		if isWebSocketUpgrade(r) {
			c, err := websocket.Accept(w, r, wsOpts)
			if err != nil {
				// websocket.Accept responds to the client itself on failure.
				return
			}
			// Best effort: the connection is discarded either way.
			c.Close(4418, reason)
			return
		}

		w.WriteHeader(http.StatusTeapot)
		fmt.Fprint(w, reason)
	})
}

// isWebSocketUpgrade reports whether r asks to upgrade to the WebSocket
// protocol. Upgrade header values are comma-separated tokens compared
// case-insensitively (RFC 6455 §4.2.1), so "WebSocket" and "h2c, websocket"
// both qualify.
func isWebSocketUpgrade(r *http.Request) bool {
	for _, value := range r.Header.Values("Upgrade") {
		for _, token := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(token), "websocket") {
				return true
			}
		}
	}
	return false
}
293
// runServer starts an HTTP server for handler on addr and registers two
// goroutines with g:
//   - one serves requests; a listen/serve failure is returned through g, which
//     cancels the group's context and so shuts down the rest of the group;
//   - one waits for ctx to be cancelled, then gracefully shuts the server
//     down, allowing up to 10 seconds for in-flight requests to finish.
func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
	httpSrv := &http.Server{
		Addr:    addr,
		Handler: handler,
		// Bound how long a client may take to send its request headers, so
		// stalled or malicious connections cannot be held open indefinitely.
		// Only the header phase is limited. ReadTimeout/WriteTimeout stay
		// unset because they could cut off long-lived WebSocket connections.
		ReadHeaderTimeout: 10 * time.Second,
	}

	g.Go(func() error {
		<-ctx.Done()

		// ctx is already cancelled, so Shutdown needs a fresh deadline.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		return httpSrv.Shutdown(shutdownCtx)
	})

	g.Go(func() error {
		// ErrServerClosed is the expected result of the Shutdown above, not a
		// failure. Returning it could make g.Wait report it in place of the
		// error that actually stopped the group.
		if err := httpSrv.ListenAndServe(); err != http.ErrServerClosed {
			return err
		}
		return nil
	})
}
308
// prometheusHandler returns a handler that serves promhttp.Handler() at
// /metrics. Any other path falls through to the ServeMux's default 404.
func prometheusHandler() http.Handler {
	metrics := promhttp.Handler()

	mux := http.NewServeMux()
	mux.Handle("/metrics", metrics)

	return mux
}