1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5package main
6
import (
	"errors"
	"fmt"
	"log"
	"net/http"
	"os"
	"strconv"
	"sync"
	"time"

	"git.sr.ht/~amolith/willow/db"
	"git.sr.ht/~amolith/willow/project"
	"git.sr.ht/~amolith/willow/ws"
	"github.com/BurntSushi/toml"
	flag "github.com/spf13/pflag"
)
22
// Config is the top-level structure the TOML configuration file is decoded
// into by checkConfig.
type Config struct {
	// Server holds the settings from the file's [Server] table.
	Server server
	// DBConn is the SQLite database path handed to db.Open.
	DBConn string
	// TODO: Make cache location configurable
	// CacheLocation string

	// FetchInterval is how often, in seconds, new releases are fetched.
	// checkConfig raises values below its minimum.
	FetchInterval int
}

// server describes the web server section of the configuration.
type server struct {
	// Listen is the address the HTTP server binds to.
	Listen string
}
36
37var (
38 flagConfig = flag.StringP("config", "c", "config.toml", "Path to config file")
39 flagAddUser = flag.StringP("add", "a", "", "Username of account to add")
40 flagDeleteUser = flag.StringP("deleteuser", "d", "", "Username of account to delete")
41 flagCheckAuthorised = flag.StringP("validatecredentials", "V", "", "Username of account to check")
42 flagListUsers = flag.BoolP("listusers", "l", false, "List all users")
43 flagShowVersion = flag.BoolP("version", "v", false, "Print Willow's version")
44 version = ""
45 config Config
46 req = make(chan struct{})
47 res = make(chan []project.Project)
48 manualRefresh = make(chan struct{})
49)
50
51func main() {
52 flag.Parse()
53
54 if *flagShowVersion {
55 fmt.Println(version)
56 os.Exit(0)
57 }
58
59 err := checkConfig()
60 if err != nil {
61 log.Fatalln(err)
62 }
63
64 fmt.Println("Opening database at", config.DBConn)
65
66 dbConn, err := db.Open(config.DBConn)
67 if err != nil {
68 fmt.Println("Error opening database:", err)
69 os.Exit(1)
70 }
71
72 fmt.Println("Checking whether database needs initialising")
73
74 err = db.InitialiseDatabase(dbConn)
75 if err != nil {
76 fmt.Println("Error initialising database:", err)
77 os.Exit(1)
78 }
79
80 fmt.Println("Checking whether there are pending migrations")
81
82 err = db.Migrate(dbConn)
83 if err != nil {
84 fmt.Println("Error migrating database schema:", err)
85 os.Exit(1)
86 }
87
88 if len(*flagAddUser) > 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 {
89 createUser(dbConn, *flagAddUser)
90 os.Exit(0)
91 } else if len(*flagAddUser) == 0 && len(*flagDeleteUser) > 0 && !*flagListUsers && len(*flagCheckAuthorised) == 0 {
92 deleteUser(dbConn, *flagDeleteUser)
93 os.Exit(0)
94 } else if len(*flagAddUser) == 0 && len(*flagDeleteUser) == 0 && *flagListUsers && len(*flagCheckAuthorised) == 0 {
95 listUsers(dbConn)
96 os.Exit(0)
97 } else if len(*flagAddUser) == 0 && len(*flagDeleteUser) == 0 && !*flagListUsers && len(*flagCheckAuthorised) > 0 {
98 checkAuthorised(dbConn, *flagCheckAuthorised)
99 os.Exit(0)
100 }
101
102 mu := sync.Mutex{}
103
104 fmt.Println("Starting refresh loop")
105
106 go project.RefreshLoop(dbConn, &mu, config.FetchInterval, &manualRefresh, &req, &res)
107
108 wsHandler := ws.Handler{
109 DbConn: dbConn,
110 Req: &req,
111 Res: &res,
112 ManualRefresh: &manualRefresh,
113 Mu: &mu,
114 Version: &version,
115 }
116
117 mux := http.NewServeMux()
118 mux.HandleFunc("/static/", ws.StaticHandler)
119 mux.HandleFunc("/new", wsHandler.NewHandler)
120 mux.HandleFunc("/login", wsHandler.LoginHandler)
121 mux.HandleFunc("/logout", wsHandler.LogoutHandler)
122 mux.HandleFunc("/", wsHandler.RootHandler)
123
124 httpServer := &http.Server{
125 Addr: config.Server.Listen,
126 Handler: mux,
127 DisableGeneralOptionsHandler: false,
128 TLSConfig: nil,
129 ReadTimeout: 0,
130 ReadHeaderTimeout: 0,
131 WriteTimeout: 0,
132 IdleTimeout: 0,
133 MaxHeaderBytes: 0,
134 TLSNextProto: nil,
135 ConnState: nil,
136 ErrorLog: nil,
137 BaseContext: nil,
138 ConnContext: nil,
139 }
140
141 fmt.Println("Starting web server on", config.Server.Listen)
142
143 if err := httpServer.ListenAndServe(); errors.Is(err, http.ErrServerClosed) {
144 fmt.Println("Web server closed")
145 os.Exit(0)
146 } else {
147 fmt.Println(err)
148 os.Exit(1)
149 }
150}
151
152func checkConfig() error {
153 defaultDBConn := "willow.sqlite"
154 defaultFetchInterval := 3600
155 defaultListen := "127.0.0.1:1313"
156
157 defaultConfig := fmt.Sprintf(`# Path to SQLite database
158DBConn = "%s"
159# How often to fetch new releases in seconds
160## Minimum is %ds to avoid rate limits and unintentional abuse
161FetchInterval = %d
162
163[Server]
164# Address to listen on
165Listen = "%s"`, defaultDBConn, defaultFetchInterval, defaultFetchInterval, defaultListen)
166
167 file, err := os.Open(*flagConfig)
168 if err != nil {
169 if os.IsNotExist(err) {
170 file, err = os.Create(*flagConfig)
171 if err != nil {
172 return err
173 }
174 defer file.Close()
175
176 _, err = file.WriteString(defaultConfig)
177 if err != nil {
178 return err
179 }
180
181 fmt.Println("Config file created at", *flagConfig)
182 fmt.Println("Please edit it and restart the server")
183 os.Exit(0)
184 } else {
185 return err
186 }
187 }
188 defer file.Close()
189
190 _, err = toml.DecodeFile(*flagConfig, &config)
191 if err != nil {
192 return err
193 }
194
195 if config.FetchInterval < defaultFetchInterval {
196 fmt.Println("Fetch interval is set to", strconv.Itoa(config.FetchInterval), "seconds, but the minimum is", defaultFetchInterval, "seconds, using", strconv.Itoa(defaultFetchInterval)+"s")
197 config.FetchInterval = defaultFetchInterval
198 }
199
200 if config.Server.Listen == "" {
201 fmt.Println("No listen address specified, using", defaultListen)
202 config.Server.Listen = defaultListen
203 }
204
205 if config.DBConn == "" {
206 fmt.Println("No SQLite path specified, using \"" + defaultDBConn + "\"")
207 config.DBConn = defaultDBConn
208 }
209
210 return nil
211}