1package main
  2
import (
	"context"
	"encoding/json"
	"errors"
	"log"
	"math/rand"
	"net/http"
	"os"
	"reflect"
	"time"

	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/gofrs/uuid"
	"github.com/jessevdk/go-flags"
	"github.com/posener/ctxutil"
	"github.com/tomwright/queryparam/v4"
	"github.com/zikaeroh/codies/internal/protocol"
	"github.com/zikaeroh/codies/internal/server"
	"github.com/zikaeroh/codies/internal/version"
	"golang.org/x/sync/errgroup"
	"nhooyr.io/websocket"
)
 25
 26var args = struct {
 27	Addr    string   `long:"addr" env:"CODIES_ADDR" description:"Address to listen at"`
 28	Origins []string `long:"origins" env:"CODIES_ORIGINS" env-delim:"," description:"Additional valid origins for WebSocket connections"`
 29	Debug   bool     `long:"debug" env:"CODIES_DEBUG" description:"Enables debug mode"`
 30}{
 31	Addr: ":5000",
 32}
 33
 34func main() {
 35	rand.Seed(time.Now().Unix())
 36	log.SetFlags(log.LstdFlags | log.Lshortfile)
 37
 38	if _, err := flags.Parse(&args); err != nil {
 39		// Default flag parser prints messages, so just exit.
 40		os.Exit(1)
 41	}
 42
 43	log.Printf("starting codies server, version %s", version.Version())
 44
 45	wsOpts := &websocket.AcceptOptions{
 46		OriginPatterns: args.Origins,
 47	}
 48
 49	if args.Debug {
 50		log.Println("starting in debug mode, allowing any WebSocket origin host")
 51		wsOpts.OriginPatterns = []string{"*"}
 52	}
 53
 54	g, ctx := errgroup.WithContext(ctxutil.Interrupt())
 55
 56	srv := server.NewServer()
 57
 58	r := chi.NewMux()
 59	r.Use(middleware.Heartbeat("/ping"))
 60	r.Use(middleware.Recoverer)
 61
 62	r.Group(func(r chi.Router) {
 63		r.Use(middleware.Compress(5))
 64		fs := http.Dir("./frontend/build")
 65		r.NotFound(http.FileServer(fs).ServeHTTP)
 66	})
 67
 68	r.Group(func(r chi.Router) {
 69		r.Use(middleware.NoCache)
 70
 71		r.Get("/api/time", func(w http.ResponseWriter, r *http.Request) {
 72			w.Header().Add("Content-Type", "application/json")
 73			_ = json.NewEncoder(w).Encode(&protocol.TimeResponse{Time: time.Now()})
 74		})
 75
 76		r.Get("/api/exists", func(w http.ResponseWriter, r *http.Request) {
 77			query := &protocol.ExistsQuery{}
 78			if err := queryparam.Parse(r.URL.Query(), query); err != nil {
 79				httpErr(w, http.StatusBadRequest)
 80				return
 81			}
 82
 83			room := srv.FindRoomByID(query.RoomID)
 84			if room == nil {
 85				w.WriteHeader(http.StatusNotFound)
 86			} else {
 87				w.WriteHeader(http.StatusOK)
 88			}
 89
 90			_, _ = w.Write([]byte("."))
 91		})
 92
 93		r.Post("/api/room", func(w http.ResponseWriter, r *http.Request) {
 94			defer r.Body.Close()
 95
 96			req := &protocol.RoomRequest{}
 97			if err := json.NewDecoder(r.Body).Decode(req); err != nil {
 98				httpErr(w, http.StatusBadRequest)
 99				return
100			}
101
102			if !req.Valid() {
103				httpErr(w, http.StatusBadRequest)
104				return
105			}
106
107			resp := &protocol.RoomResponse{}
108
109			w.Header().Add("Content-Type", "application/json")
110
111			if req.Create {
112				room, err := srv.CreateRoom(req.RoomName, req.RoomPass)
113				if err != nil {
114					switch err {
115					case server.ErrRoomExists:
116						resp.Error = stringPtr("Room already exists.")
117						w.WriteHeader(http.StatusBadRequest)
118					case server.ErrTooManyRooms:
119						resp.Error = stringPtr("Too many rooms.")
120						w.WriteHeader(http.StatusServiceUnavailable)
121					default:
122						resp.Error = stringPtr("An unknown error occurred.")
123						w.WriteHeader(http.StatusInternalServerError)
124					}
125				} else {
126					resp.ID = &room.ID
127					w.WriteHeader(http.StatusOK)
128				}
129			} else {
130				room := srv.FindRoom(req.RoomName)
131				if room == nil || room.Password != req.RoomPass {
132					resp.Error = stringPtr("Room not found or password does not match.")
133					w.WriteHeader(http.StatusNotFound)
134				} else {
135					resp.ID = &room.ID
136					w.WriteHeader(http.StatusOK)
137				}
138			}
139
140			_ = json.NewEncoder(w).Encode(resp)
141		})
142
143		r.Get("/api/ws", func(w http.ResponseWriter, r *http.Request) {
144			query := &protocol.WSQuery{}
145			if err := queryparam.Parse(r.URL.Query(), query); err != nil {
146				httpErr(w, http.StatusBadRequest)
147				return
148			}
149
150			if !query.Valid() {
151				httpErr(w, http.StatusBadRequest)
152				return
153			}
154
155			room := srv.FindRoomByID(query.RoomID)
156			if room == nil {
157				httpErr(w, http.StatusNotFound)
158				return
159			}
160
161			c, err := websocket.Accept(w, r, wsOpts)
162			if err != nil {
163				log.Println(err)
164				return
165			}
166
167			g.Go(func() error {
168				room.HandleConn(query.PlayerID, query.Nickname, c)
169				return nil
170			})
171		})
172
173		r.Get("/api/stats", func(w http.ResponseWriter, r *http.Request) {
174			rooms, clients := srv.Stats()
175
176			enc := json.NewEncoder(w)
177			enc.SetIndent("", "    ")
178			_ = enc.Encode(&protocol.StatsResponse{
179				Rooms:   rooms,
180				Clients: clients,
181			})
182		})
183	})
184
185	g.Go(func() error {
186		return srv.Run(ctx)
187	})
188
189	httpSrv := http.Server{Addr: args.Addr, Handler: r}
190
191	g.Go(func() error {
192		<-ctx.Done()
193
194		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
195		defer cancel()
196
197		return httpSrv.Shutdown(ctx)
198	})
199
200	g.Go(func() error {
201		return httpSrv.ListenAndServe()
202	})
203
204	log.Fatal(g.Wait())
205}
206
// httpErr replies to the request with the given HTTP status code. The body
// is the standard status text for that code, written via http.Error.
func httpErr(w http.ResponseWriter, code int) {
	msg := http.StatusText(code)
	http.Error(w, msg, code)
}
210
// stringPtr returns a pointer to a new string holding a copy of s. It is
// used to fill optional *string fields such as a response's Error.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
214
215func init() {
216	queryparam.DefaultParser.ValueParsers[reflect.TypeOf(uuid.UUID{})] = func(value string, _ string) (reflect.Value, error) {
217		id, err := uuid.FromString(value)
218		return reflect.ValueOf(id), err
219	}
220}