willow.go

  1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2//
  3// SPDX-License-Identifier: Apache-2.0
  4
  5package main
  6
import (
	"errors"
	"fmt"
	"log"
	"net/http"
	"os"
	"strconv"
	"sync"
	"time"

	"git.sr.ht/~amolith/willow/db"
	"git.sr.ht/~amolith/willow/project"
	"git.sr.ht/~amolith/willow/ws"
	"github.com/BurntSushi/toml"
	flag "github.com/spf13/pflag"
)
 22
 23type (
 24	Config struct {
 25		Server server `toml:"Server"`
 26		DBConn string `toml:"DBConn"`
 27		// TODO: Make cache location configurable
 28		// CacheLocation string
 29		FetchInterval int `toml:"FetchInterval"`
 30	}
 31
 32	server struct {
 33		Listen string `toml:"Listen"`
 34	}
 35)
 36
 37var (
 38	flagConfig          = flag.StringP("config", "c", "config.toml", "Path to config file")
 39	flagAddUser         = flag.StringP("add", "a", "", "Username of account to add")
 40	flagDeleteUser      = flag.StringP("deleteuser", "d", "", "Username of account to delete")
 41	flagCheckAuthorised = flag.StringP("validatecredentials", "V", "", "Username of account to check")
 42	flagListUsers       = flag.BoolP("listusers", "l", false, "List all users")
 43	flagShowVersion     = flag.BoolP("version", "v", false, "Print Willow's version")
 44	version             = ""
 45	config              Config
 46	req                 = make(chan struct{})
 47	res                 = make(chan []project.Project)
 48	manualRefresh       = make(chan struct{})
 49)
 50
 51func main() {
 52	flag.Parse()
 53
 54	if *flagShowVersion {
 55		fmt.Println(version)
 56		os.Exit(0)
 57	}
 58
 59	err := checkConfig()
 60	if err != nil {
 61		log.Fatalln(err)
 62	}
 63
 64	fmt.Println("Opening database at", config.DBConn)
 65
 66	dbConn, err := db.Open(config.DBConn)
 67	if err != nil {
 68		fmt.Println("Error opening database:", err)
 69		os.Exit(1)
 70	}
 71
 72	fmt.Println("Checking whether database needs initialising")
 73
 74	err = db.InitialiseDatabase(dbConn)
 75	if err != nil {
 76		fmt.Println("Error initialising database:", err)
 77		os.Exit(1)
 78	}
 79
 80	fmt.Println("Checking whether there are pending migrations")
 81
 82	err = db.Migrate(dbConn)
 83	if err != nil {
 84		fmt.Println("Error migrating database schema:", err)
 85		os.Exit(1)
 86	}
 87
 88	if len(*flagAddUser) > 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 {
 89		createUser(dbConn, *flagAddUser)
 90		os.Exit(0)
 91	} else if len(*flagAddUser) == 0 && len(*flagDeleteUser) > 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 {
 92		deleteUser(dbConn, *flagDeleteUser)
 93		os.Exit(0)
 94	} else if len(*flagAddUser) == 0 && len(*flagDeleteUser) == 0 && *flagListUsers && len(*flagCheckAuthorised) == 0 {
 95		listUsers(dbConn)
 96		os.Exit(0)
 97	} else if len(*flagAddUser) == 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) > 0 {
 98		checkAuthorised(dbConn, *flagCheckAuthorised)
 99		os.Exit(0)
100	}
101
102	mu := sync.Mutex{}
103
104	fmt.Println("Starting refresh loop")
105
106	go project.RefreshLoop(dbConn, &mu, config.FetchInterval, &manualRefresh, &req, &res)
107
108	wsHandler := ws.Handler{
109		DbConn:        dbConn,
110		Req:           &req,
111		Res:           &res,
112		ManualRefresh: &manualRefresh,
113		Mu:            &mu,
114		Version:       &version,
115	}
116
117	mux := http.NewServeMux()
118	mux.HandleFunc("/static/", ws.StaticHandler)
119	mux.HandleFunc("/new", wsHandler.NewHandler)
120	mux.HandleFunc("/login", wsHandler.LoginHandler)
121	mux.HandleFunc("/logout", wsHandler.LogoutHandler)
122	mux.HandleFunc("/", wsHandler.RootHandler)
123
124	httpServer := &http.Server{
125		Addr:                         config.Server.Listen,
126		Handler:                      mux,
127		DisableGeneralOptionsHandler: false,
128		TLSConfig:                    nil,
129		ReadTimeout:                  0,
130		ReadHeaderTimeout:            0,
131		WriteTimeout:                 0,
132		IdleTimeout:                  0,
133		MaxHeaderBytes:               0,
134		TLSNextProto:                 nil,
135		ConnState:                    nil,
136		ErrorLog:                     nil,
137		BaseContext:                  nil,
138		ConnContext:                  nil,
139	}
140
141	fmt.Println("Starting web server on", config.Server.Listen)
142
143	if err := httpServer.ListenAndServe(); errors.Is(err, http.ErrServerClosed) {
144		fmt.Println("Web server closed")
145		os.Exit(0)
146	}
147
148	fmt.Println(err)
149	os.Exit(1)
150}
151
152func checkConfig() error {
153	defaultDBConn := "willow.sqlite"
154	defaultFetchInterval := 3600
155	defaultListen := "127.0.0.1:1313"
156
157	defaultConfig := fmt.Sprintf(`# Path to SQLite database
158DBConn = "%s"
159# How often to fetch new releases in seconds
160## Minimum is %ds to avoid rate limits and unintentional abuse
161FetchInterval = %d
162
163[Server]
164# Address to listen on
165Listen = "%s"`, defaultDBConn, defaultFetchInterval, defaultFetchInterval, defaultListen)
166
167	file, err := os.Open(*flagConfig)
168	if err != nil {
169		if os.IsNotExist(err) {
170			file, err = os.Create(*flagConfig)
171			if err != nil {
172				return fmt.Errorf("failed to create file: %w", err)
173			}
174			defer file.Close()
175
176			_, err = file.WriteString(defaultConfig)
177			if err != nil {
178				return fmt.Errorf("failed to write to file: %w", err)
179			}
180
181			fmt.Println("Config file created at", *flagConfig)
182			fmt.Println("Please edit it and restart the server")
183			os.Exit(0)
184		}
185
186		return fmt.Errorf("failed to open config file: %w", err)
187	}
188	defer file.Close()
189
190	_, err = toml.DecodeFile(*flagConfig, &config)
191	if err != nil {
192		return fmt.Errorf("failed to decode TOML file: %w", err)
193	}
194
195	if config.FetchInterval < defaultFetchInterval {
196		fmt.Println("Fetch interval is set to", strconv.Itoa(config.FetchInterval), "seconds, but the minimum is", defaultFetchInterval, "seconds, using", strconv.Itoa(defaultFetchInterval)+"s")
197		config.FetchInterval = defaultFetchInterval
198	}
199
200	if config.Server.Listen == "" {
201		fmt.Println("No listen address specified, using", defaultListen)
202		config.Server.Listen = defaultListen
203	}
204
205	if config.DBConn == "" {
206		fmt.Println("No SQLite path specified, using \"" + defaultDBConn + "\"")
207		config.DBConn = defaultDBConn
208	}
209
210	return nil
211}