1package main
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log"
  8	"math/rand"
  9	"net/http"
 10	"os"
 11	"time"
 12
 13	"github.com/go-chi/chi"
 14	"github.com/go-chi/chi/middleware"
 15	"github.com/jessevdk/go-flags"
 16	"github.com/posener/ctxutil"
 17	"github.com/prometheus/client_golang/prometheus/promhttp"
 18	"github.com/tomwright/queryparam/v4"
 19	"github.com/zikaeroh/codies/internal/pkger"
 20	"github.com/zikaeroh/codies/internal/protocol"
 21	"github.com/zikaeroh/codies/internal/responder"
 22	"github.com/zikaeroh/codies/internal/server"
 23	"github.com/zikaeroh/codies/internal/version"
 24	"github.com/zikaeroh/ctxlog"
 25	"go.uber.org/zap"
 26	"golang.org/x/sync/errgroup"
 27	"nhooyr.io/websocket"
 28)
 29
// args holds the command-line options. main fills it in with flags.Parse;
// each option can also be supplied through the environment variable named in
// its `env` tag.
var args = struct {
	// Addr is the listen address of the main HTTP server.
	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
	// Origins is passed to the WebSocket accept options as the list of
	// additional origin patterns; in the environment it is comma-separated.
	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
	// Prod and Debug select the run mode; main refuses to start unless
	// exactly one of them is set.
	Prod    bool     `long:"prod" env:"CODIES_PROD" description:"Enables production mode"`
	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
}{
	// Default listen address, used when neither the flag nor the
	// environment variable is given.
	Addr: ":5000",
}
 38
// wsOpts holds the options handed to every websocket.Accept call in this
// file. It is nil until main assigns it, which happens before any route is
// registered.
var wsOpts *websocket.AcceptOptions
 40
 41func main() {
 42	if argv := os.Args[1:]; len(argv) > 0 && argv[0] == "version" {
 43		fmt.Println(version.Version())
 44		return
 45	}
 46
 47	rand.Seed(time.Now().Unix())
 48
 49	if _, err := flags.Parse(&args); err != nil {
 50		// Default flag parser prints messages, so just exit.
 51		os.Exit(1)
 52	}
 53
 54	if !args.Prod && !args.Debug {
 55		log.Fatal("missing required option --prod or --debug")
 56	} else if args.Prod && args.Debug {
 57		log.Fatal("must specify either --prod or --debug")
 58	}
 59
 60	ctx := ctxutil.Interrupt()
 61
 62	logger := ctxlog.New(args.Debug)
 63	defer zap.RedirectStdLog(logger)()
 64	ctx = ctxlog.WithLogger(ctx, logger)
 65
 66	ctxlog.Info(ctx, "starting", zap.String("version", version.Version()))
 67
 68	wsOpts = &websocket.AcceptOptions{
 69		OriginPatterns:  args.Origins,
 70		CompressionMode: websocket.CompressionContextTakeover,
 71	}
 72
 73	if args.Debug {
 74		ctxlog.Info(ctx, "starting in debug mode, allowing any WebSocket origin host")
 75		wsOpts.InsecureSkipVerify = true
 76	} else {
 77		if !version.VersionSet() {
 78			ctxlog.Fatal(ctx, "running production build without version set")
 79		}
 80	}
 81
 82	g, ctx := errgroup.WithContext(ctx)
 83
 84	srv := server.NewServer()
 85
 86	r := chi.NewMux()
 87
 88	r.Use(func(next http.Handler) http.Handler {
 89		return promhttp.InstrumentHandlerCounter(metricRequest, next)
 90	})
 91
 92	r.Use(middleware.Heartbeat("/ping"))
 93	r.Use(middleware.Recoverer)
 94	r.NotFound(staticHandler().ServeHTTP)
 95
 96	r.Group(func(r chi.Router) {
 97		r.Use(middleware.NoCache)
 98
 99		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
100			responder.Respond(w, responder.Body(&protocol.TimeResponse{Time: time.Now()}))
101		})
102
103		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
104			rooms, clients := srv.Stats()
105			responder.Respond(w,
106				responder.Body(&protocol.StatsResponse{
107					Rooms:   rooms,
108					Clients: clients,
109				}),
110				responder.Pretty(true),
111			)
112		})
113
114		r.Group(func(r chi.Router) {
115			if !args.Debug {
116				r.Use(checkVersion)
117			}
118
119			r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
120				query := &protocol.ExistsQuery{}
121				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
122					responder.Respond(w, responder.Status(http.StatusBadRequest))
123					return
124				}
125
126				room := srv.FindRoomByID(query.RoomID)
127				if room == nil {
128					responder.Respond(w, responder.Status(http.StatusNotFound))
129				} else {
130					responder.Respond(w, responder.Status(http.StatusOK))
131				}
132			})
133
134			r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
135				defer r.Body.Close()
136
137				req := &protocol.RoomRequest{}
138				if err := json.NewDecoder(r.Body).Decode(req); err != nil {
139					responder.Respond(w, responder.Status(http.StatusBadRequest))
140					return
141				}
142
143				if msg, valid := req.Valid(); !valid {
144					responder.Respond(w,
145						responder.Status(http.StatusBadRequest),
146						responder.Body(&protocol.RoomResponse{
147							Error: stringPtr(msg),
148						}),
149					)
150					return
151				}
152
153				var room *server.Room
154				if req.Create {
155					var err error
156					room, err = srv.CreateRoom(ctx, req.RoomName, req.RoomPass)
157					if err != nil {
158						switch err {
159						case server.ErrRoomExists:
160							responder.Respond(w,
161								responder.Status(http.StatusBadRequest),
162								responder.Body(&protocol.RoomResponse{
163									Error: stringPtr("Room already exists."),
164								}),
165							)
166						case server.ErrTooManyRooms:
167							responder.Respond(w,
168								responder.Status(http.StatusServiceUnavailable),
169								responder.Body(&protocol.RoomResponse{
170									Error: stringPtr("Too many rooms."),
171								}),
172							)
173						default:
174							responder.Respond(w,
175								responder.Status(http.StatusInternalServerError),
176								responder.Body(&protocol.RoomResponse{
177									Error: stringPtr("An unknown error occurred."),
178								}),
179							)
180						}
181						return
182					}
183				} else {
184					room = srv.FindRoom(req.RoomName)
185					if room == nil || room.Password != req.RoomPass {
186						responder.Respond(w,
187							responder.Status(http.StatusNotFound),
188							responder.Body(&protocol.RoomResponse{
189								Error: stringPtr("Room not found or password does not match."),
190							}),
191						)
192						return
193					}
194				}
195
196				responder.Respond(w, responder.Body(&protocol.RoomResponse{
197					ID: &room.ID,
198				}))
199			})
200
201			r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
202				query := &protocol.WSQuery{}
203				if err := queryparam.Parse(r.URL.Query(), query); err != nil {
204					responder.Respond(w, responder.Status(http.StatusBadRequest))
205					return
206				}
207
208				if _, valid := query.Valid(); !valid {
209					responder.Respond(w, responder.Status(http.StatusBadRequest))
210					return
211				}
212
213				room := srv.FindRoomByID(query.RoomID)
214				if room == nil {
215					responder.Respond(w, responder.Status(http.StatusBadRequest))
216					return
217				}
218
219				c, err := websocket.Accept(w, r, wsOpts)
220				if err != nil {
221					return
222				}
223
224				g.Go(func() error {
225					room.HandleConn(ctx, query.Nickname, c)
226					return nil
227				})
228			})
229		})
230	})
231
232	g.Go(func() error {
233		return srv.Run(ctx)
234	})
235
236	runServer(ctx, g, args.Addr, r)
237
238	if args.Prod {
239		runServer(ctx, g, ":2112", prometheusHandler())
240	}
241
242	exitErr := g.Wait()
243	ctxlog.Fatal(ctx, "exited", zap.Error(exitErr))
244}
245
246func staticHandler() http.Handler {
247	fs := pkger.Dir("/frontend/build")
248	fsh := http.FileServer(fs)
249
250	r := chi.NewMux()
251	r.Use(middleware.Compress(5))
252
253	r.Handle("/static/*", fsh)
254	r.Handle("/favicon/*", fsh)
255
256	r.Group(func(r chi.Router) {
257		r.Use(middleware.NoCache)
258		r.Handle("/*", fsh)
259	})
260
261	return r
262}
263
264func checkVersion(next http.Handler) http.Handler {
265	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
266		want := version.Version()
267
268		toCheck := []string{
269			r.Header.Get("X-CODIES-VERSION"),
270			r.URL.Query().Get("codiesVersion"),
271		}
272
273		for _, got := range toCheck {
274			if got == want {
275				next.ServeHTTP(w, r)
276				return
277			}
278		}
279
280		reason := fmt.Sprintf("client version too old, please reload to get %s", want)
281
282		if r.Header.Get("Upgrade") == "websocket" {
283			c, err := websocket.Accept(w, r, wsOpts)
284			if err != nil {
285				return
286			}
287			c.Close(4418, reason)
288			return
289		}
290
291		w.WriteHeader(http.StatusTeapot)
292		fmt.Fprint(w, reason)
293	})
294}
295
296func runServer(ctx context.Context, g *errgroup.Group, addr string, handler http.Handler) {
297	httpSrv := http.Server{Addr: addr, Handler: handler}
298
299	g.Go(func() error {
300		<-ctx.Done()
301
302		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
303		defer cancel()
304
305		return httpSrv.Shutdown(ctx)
306	})
307
308	g.Go(func() error {
309		return httpSrv.ListenAndServe()
310	})
311}
312
313func prometheusHandler() http.Handler {
314	mux := http.NewServeMux()
315	mux.Handle("/metrics", promhttp.Handler())
316	return mux
317}