Detailed changes
@@ -17,24 +17,28 @@ import (
var bmStrict = bluemonday.StrictPolicy()
-// createUser is a CLI that creates a new user with the specified username
+// createUser is a CLI that creates a new user with the specified username.
func createUser(dbConn *sql.DB, username string) {
fmt.Println("Creating user", username)
fmt.Print("Enter password: ")
+
password, err := term.ReadPassword(syscall.Stdin)
if err != nil {
fmt.Println("Error reading password:", err)
os.Exit(1)
}
+
fmt.Println()
fmt.Print("Confirm password: ")
+
passwordConfirmation, err := term.ReadPassword(syscall.Stdin)
if err != nil {
fmt.Println("Error reading password confirmation:", err)
os.Exit(1)
}
+
fmt.Println()
if string(password) != string(passwordConfirmation) {
@@ -50,6 +54,7 @@ func createUser(dbConn *sql.DB, username string) {
//
// TODO: Abstract this
sanitisedPassword := bmStrict.Sanitize(string(password))
+
err = users.Register(dbConn, username, sanitisedPassword)
if err != nil {
fmt.Println("Error creating user:", err)
@@ -60,9 +65,10 @@ func createUser(dbConn *sql.DB, username string) {
os.Exit(0)
}
-// deleteUser is a CLI that deletes a user with the specified username
+// deleteUser is a CLI that deletes a user with the specified username.
func deleteUser(dbConn *sql.DB, username string) {
fmt.Println("Deleting user", username)
+
err := users.Delete(dbConn, username)
if err != nil {
fmt.Println("Error deleting user:", err)
@@ -73,7 +79,7 @@ func deleteUser(dbConn *sql.DB, username string) {
os.Exit(0)
}
-// listUsers is a CLI that lists all users in the database
+// listUsers is a CLI that lists all users in the database.
func listUsers(dbConn *sql.DB) {
fmt.Println("Listing all users")
@@ -90,6 +96,7 @@ func listUsers(dbConn *sql.DB) {
fmt.Println("-", u)
}
}
+
os.Exit(0)
}
@@ -99,15 +106,18 @@ func checkAuthorised(dbConn *sql.DB, username string) {
fmt.Printf("Checking whether password for user %s is correct\n", username)
fmt.Print("Enter password: ")
+
password, err := term.ReadPassword(syscall.Stdin)
if err != nil {
fmt.Println("Error reading password:", err)
os.Exit(1)
}
+
fmt.Println()
// TODO: Abstract this, refer to note in createUser()
sanitisedPassword := bmStrict.Sanitize(string(password))
+
authorised, err := users.UserAuthorised(dbConn, username, sanitisedPassword)
if err != nil {
fmt.Println("Error checking authorisation:", err)
@@ -119,5 +129,6 @@ func checkAuthorised(dbConn *sql.DB, username string) {
} else {
fmt.Println("User is not authorised")
}
+
os.Exit(0)
}
@@ -16,7 +16,6 @@ import (
"git.sr.ht/~amolith/willow/db"
"git.sr.ht/~amolith/willow/project"
"git.sr.ht/~amolith/willow/ws"
-
"github.com/BurntSushi/toml"
flag "github.com/spf13/pflag"
)
@@ -71,12 +70,15 @@ func main() {
}
fmt.Println("Checking whether database needs initialising")
+
err = db.InitialiseDatabase(dbConn)
if err != nil {
fmt.Println("Error initialising database:", err)
os.Exit(1)
}
+
fmt.Println("Checking whether there are pending migrations")
+
err = db.Migrate(dbConn)
if err != nil {
fmt.Println("Error migrating database schema:", err)
@@ -100,6 +102,7 @@ func main() {
mu := sync.Mutex{}
fmt.Println("Starting refresh loop")
+
go project.RefreshLoop(dbConn, &mu, config.FetchInterval, &manualRefresh, &req, &res)
wsHandler := ws.Handler{
@@ -124,6 +127,7 @@ func main() {
}
fmt.Println("Starting web server on", config.Server.Listen)
+
if err := httpServer.ListenAndServe(); errors.Is(err, http.ErrServerClosed) {
fmt.Println("Web server closed")
os.Exit(0)
@@ -18,23 +18,27 @@ var schema string
var mutex = &sync.Mutex{}
-// Open opens a connection to the SQLite database
+// Open opens a connection to the SQLite database.
func Open(dbPath string) (*sql.DB, error) {
return sql.Open("sqlite", "file:"+dbPath+"?_pragma=journal_mode%3DWAL")
}
// VerifySchema checks whether the schema has been initialised and initialises it
-// if not
+// if not.
func InitialiseDatabase(dbConn *sql.DB) error {
var name string
+
err := dbConn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='users'").Scan(&name)
if err != nil && errors.Is(err, sql.ErrNoRows) {
mutex.Lock()
defer mutex.Unlock()
+
if _, err := dbConn.Exec(schema); err != nil {
return err
}
+
return nil
}
+
return err
}
@@ -48,17 +48,19 @@ var migrations = [...]migration{
},
}
-// Migrate runs all pending migrations
+// Migrate runs all pending migrations.
func Migrate(db *sql.DB) error {
version := getSchemaVersion(db)
for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ {
if err := runMigration(db, nextMigration); err != nil {
return fmt.Errorf("migrations failed: %w", err)
}
+
if version := getSchemaVersion(db); version != nextMigration {
return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version)
}
}
+
return nil
}
@@ -67,30 +69,36 @@ func Migrate(db *sql.DB) error {
// transaction if unsuccessful.
func runMigration(db *sql.DB, migrationIdx int) (err error) {
current := migrations[migrationIdx]
+
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err)
}
+
defer func() {
if err == nil {
err = tx.Commit()
}
+
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
}
}
}()
+
if len(current.upQuery) > 0 {
if _, err := tx.Exec(current.upQuery); err != nil {
return fmt.Errorf("failed running migration %d: %w", migrationIdx, err)
}
}
+
if current.postHook != nil {
if err := current.postHook(tx); err != nil {
return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err)
}
}
+
return updateSchemaVersion(tx, migrationIdx)
}
@@ -101,44 +109,53 @@ func runMigration(db *sql.DB, migrationIdx int) (err error) {
//lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34)
func undoMigration(db *sql.DB, migrationIdx int) (err error) {
current := migrations[migrationIdx]
+
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err)
}
+
defer func() {
if err == nil {
err = tx.Commit()
}
+
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err)
}
}
}()
+
if len(current.downQuery) > 0 {
if _, err := tx.Exec(current.downQuery); err != nil {
return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err)
}
}
+
return updateSchemaVersion(tx, migrationIdx-1)
}
-// getSchemaVersion returns the schema version from the database
+// getSchemaVersion returns the schema version from the database.
func getSchemaVersion(db *sql.DB) int {
row := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations LIMIT 1;`)
+
var version int
if err := row.Scan(&version); err != nil {
version = -1
}
+
return version
}
-// updateSchemaVersion sets the version to the provided int
+// updateSchemaVersion sets the version to the provided int.
func updateSchemaVersion(tx *sql.Tx, version int) error {
if version < 0 {
// Do not try to use the schema_migrations table in a schema version where it doesn't exist
return nil
}
+
_, err := tx.Exec(`UPDATE schema_migrations SET version = @version;`, sql.Named("version", version))
+
return err
}
@@ -34,7 +34,9 @@ func generateAndInsertProjectIDs(tx *sql.Tx) error {
if err := rows.Scan(&url, &name, &forge, &version, &created_at); err != nil {
return fmt.Errorf("failed to scan row from projects_tmp: %w", err)
}
+
id := fmt.Sprintf("%x", sha256.Sum256([]byte(url+name+forge+created_at)))
+
_, err = tx.Exec(
"INSERT INTO projects (id, url, name, forge, version, created_at) VALUES (@id, @url, @name, @forge, @version, @created_at)",
sql.Named("id", id),
@@ -74,7 +76,9 @@ func correctProjectIDs(tx *sql.Tx) error {
if err := rows.Scan(&old_id, &url, &name, &forge); err != nil {
return fmt.Errorf("failed to scan row from projects_tmp: %w", err)
}
+
id := fmt.Sprintf("%x", sha256.Sum256([]byte(url+name+forge)))
+
_, err = tx.Exec(
"UPDATE projects SET id = @id WHERE id = @old_id",
sql.Named("id", id),
@@ -9,25 +9,30 @@ import (
"sync"
)
-// DeleteProject deletes a project from the database
+// DeleteProject deletes a project from the database.
func DeleteProject(db *sql.DB, mu *sync.Mutex, id string) error {
mu.Lock()
defer mu.Unlock()
+
_, err := db.Exec("DELETE FROM projects WHERE id = ?", id)
if err != nil {
return err
}
+
_, err = db.Exec("DELETE FROM releases WHERE project_id = ?", id)
+
return err
}
-// GetProject returns a project from the database
+// GetProject returns a project from the database.
func GetProject(db *sql.DB, id string) (map[string]string, error) {
var name, forge, url, version string
+
err := db.QueryRow("SELECT name, forge, url, version FROM projects WHERE id = ?", id).Scan(&name, &forge, &url, &version)
if err != nil {
return nil, err
}
+
project := map[string]string{
"id": id,
"name": name,
@@ -35,13 +40,15 @@ func GetProject(db *sql.DB, id string) (map[string]string, error) {
"forge": forge,
"version": version,
}
+
return project, nil
}
-// UpsertProject adds or updates a project in the database
+// UpsertProject adds or updates a project in the database.
func UpsertProject(db *sql.DB, mu *sync.Mutex, id, url, name, forge, running string) error {
mu.Lock()
defer mu.Unlock()
+
_, err := db.Exec(`INSERT INTO projects (id, url, name, forge, version)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(id) DO
@@ -49,10 +56,11 @@ func UpsertProject(db *sql.DB, mu *sync.Mutex, id, url, name, forge, running str
name = excluded.name,
forge = excluded.forge,
version = excluded.version;`, id, url, name, forge, running)
+
return err
}
-// GetProjects returns a list of all projects in the database
+// GetProjects returns a list of all projects in the database.
func GetProjects(db *sql.DB) ([]map[string]string, error) {
rows, err := db.Query("SELECT id, name, url, forge, version FROM projects")
if err != nil {
@@ -63,10 +71,12 @@ func GetProjects(db *sql.DB) ([]map[string]string, error) {
var projects []map[string]string
for rows.Next() {
var id, name, url, forge, version string
+
err = rows.Scan(&id, &name, &url, &forge, &version)
if err != nil {
return nil, err
}
+
project := map[string]string{
"id": id,
"name": name,
@@ -76,5 +86,6 @@ func GetProjects(db *sql.DB) ([]map[string]string, error) {
}
projects = append(projects, project)
}
+
return projects, nil
}
@@ -10,10 +10,11 @@ import (
)
// UpsertRelease adds or updates a release for a project with a given ID in the
-// database
+// database.
func UpsertRelease(db *sql.DB, mu *sync.Mutex, id, projectID, url, tag, content, date string) error {
mu.Lock()
defer mu.Unlock()
+
_, err := db.Exec(`INSERT INTO releases (id, project_id, url, tag, content, date)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO
@@ -23,10 +24,11 @@ func UpsertRelease(db *sql.DB, mu *sync.Mutex, id, projectID, url, tag, content,
tag = excluded.tag,
content = excluded.content,
date = excluded.date;`, id, projectID, url, tag, content, date)
+
return err
}
-// GetReleases returns all releases for a project with a given id from the database
+// GetReleases returns all releases for a project with a given id from the database.
func GetReleases(db *sql.DB, projectID string) ([]map[string]string, error) {
rows, err := db.Query(`SELECT id, url, tag, content, date FROM releases WHERE project_id = ?`, projectID)
if err != nil {
@@ -43,10 +45,12 @@ func GetReleases(db *sql.DB, projectID string) ([]map[string]string, error) {
content string
date string
)
+
err := rows.Scan(&id, &url, &tag, &content, &date)
if err != nil {
return nil, err
}
+
releases = append(releases, map[string]string{
"id": id,
"project_id": projectID,
@@ -56,5 +60,6 @@ func GetReleases(db *sql.DB, projectID string) ([]map[string]string, error) {
"date": date,
})
}
+
return releases, nil
}
@@ -10,32 +10,38 @@ import (
)
// DeleteUser deletes specific user from the database and returns an error if it
-// fails
+// fails.
func DeleteUser(db *sql.DB, user string) error {
mutex.Lock()
defer mutex.Unlock()
+
_, err := db.Exec("DELETE FROM users WHERE username = ?", user)
+
return err
}
-// CreateUser creates a new user in the database and returns an error if it fails
+// CreateUser creates a new user in the database and returns an error if it fails.
func CreateUser(db *sql.DB, username, hash, salt string) error {
mutex.Lock()
defer mutex.Unlock()
+
_, err := db.Exec("INSERT INTO users (username, hash, salt) VALUES (?, ?, ?)", username, hash, salt)
+
return err
}
// GetUser returns a user's hash and salt from the database as strings and
-// returns an error if it fails
+// returns an error if it fails.
func GetUser(db *sql.DB, username string) (string, string, error) {
var hash, salt string
+
err := db.QueryRow("SELECT hash, salt FROM users WHERE username = ?", username).Scan(&hash, &salt)
+
return hash, salt, err
}
// GetUsers returns a list of all users in the database as a slice of strings
-// and returns an error if it fails
+// and returns an error if it fails.
func GetUsers(db *sql.DB) ([]string, error) {
rows, err := db.Query("SELECT username FROM users")
if err != nil {
@@ -46,10 +52,12 @@ func GetUsers(db *sql.DB) ([]string, error) {
var users []string
for rows.Next() {
var user string
+
err = rows.Scan(&user)
if err != nil {
return nil, err
}
+
users = append(users, user)
}
@@ -57,10 +65,13 @@ func GetUsers(db *sql.DB) ([]string, error) {
}
// GetSession accepts a session ID and returns the username associated with it
-// and an error
+// and an error.
func GetSession(db *sql.DB, session string) (string, time.Time, error) {
- var username string
- var expiresString string
+ var (
+ username string
+ expiresString string
+ )
+
err := db.QueryRow("SELECT username, expires FROM sessions WHERE token = ?", session).Scan(&username, &expiresString)
if err != nil {
return "", time.Time{}, err
@@ -70,6 +81,7 @@ func GetSession(db *sql.DB, session string) (string, time.Time, error) {
if err != nil {
return "", time.Time{}, err
}
+
return username, expires, nil
}
@@ -78,15 +90,19 @@ func GetSession(db *sql.DB, session string) (string, time.Time, error) {
func InvalidateSession(db *sql.DB, session string, expiry time.Time) error {
mutex.Lock()
defer mutex.Unlock()
+
_, err := db.Exec("UPDATE sessions SET expires = ? WHERE token = ?", expiry.Format(time.RFC3339), session)
+
return err
}
// CreateSession creates a new session in the database and returns an error if
-// it fails
+// it fails.
func CreateSession(db *sql.DB, username, token string, expiry time.Time) error {
mutex.Lock()
defer mutex.Unlock()
+
_, err := db.Exec("INSERT INTO sessions (token, username, expires) VALUES (?, ?, ?)", token, username, expiry.Format(time.RFC3339))
+
return err
}
@@ -12,11 +12,10 @@ import (
"strings"
"time"
- "github.com/microcosm-cc/bluemonday"
-
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/transport"
+ "github.com/microcosm-cc/bluemonday"
)
type Release struct {
@@ -45,6 +44,7 @@ func GetReleases(gitURI, forge string) ([]Release, error) {
if err != nil {
return nil, err
}
+
tagRefs, err := r.Tags()
if err != nil {
return nil, err
@@ -65,13 +65,17 @@ func GetReleases(gitURI, forge string) ([]Release, error) {
err = tagRefs.ForEach(func(tagRef *plumbing.Reference) error {
tagObj, err := r.TagObject(tagRef.Hash())
- var message string
- var date time.Time
+ var (
+ message string
+ date time.Time
+ )
+
if errors.Is(err, plumbing.ErrObjectNotFound) {
commitTag, err := r.CommitObject(tagRef.Hash())
if err != nil {
return err
}
+
message = commitTag.Message
date = commitTag.Committer.When
} else {
@@ -81,6 +85,7 @@ func GetReleases(gitURI, forge string) ([]Release, error) {
tagURL := ""
tagName := bmStrict.Sanitize(tagRef.Name().Short())
+
switch forge {
case "sourcehut":
tagURL = "https://" + httpURI + "/refs/" + tagName
@@ -96,6 +101,7 @@ func GetReleases(gitURI, forge string) ([]Release, error) {
URL: tagURL,
Date: date,
})
+
return nil
})
if err != nil {
@@ -122,6 +128,7 @@ func minimalClone(url string) (r *git.Repository, err error) {
if errors.Is(err, git.NoErrAlreadyUpToDate) {
return r, nil
}
+
return r, err
} else if !errors.Is(err, git.ErrRepositoryNotExists) {
return nil, err
@@ -133,6 +140,7 @@ func minimalClone(url string) (r *git.Repository, err error) {
NoCheckout: true,
Depth: 1,
})
+
return r, err
}
@@ -142,6 +150,7 @@ func RemoveRepo(url string) (err error) {
if err != nil {
return err
}
+
err = os.RemoveAll(path)
if err != nil {
return err
@@ -154,12 +163,14 @@ func RemoveRepo(url string) (err error) {
if path == "data" {
break
}
+
err = os.Remove(path)
if err != nil {
// This folder likely has data, so might as well save some time by
// not checking the parents we can't delete anyway.
break
}
+
path = path[:strings.LastIndex(path, "/")]
}
@@ -177,11 +188,12 @@ func stringifyRepo(url string) (path string, err error) {
return "", err
}
- if ep.Protocol == "http" || ep.Protocol == "https" {
+ switch ep.Protocol {
+ case "http", "https":
return "data/" + strings.Split(url, "://")[1], nil
- } else if ep.Protocol == "ssh" {
+ case "ssh":
return "data/" + ep.Host + "/" + ep.Path, nil
- } else {
+ default:
return "", errors.New("unsupported protocol")
}
}
@@ -55,6 +55,7 @@ func TestStringifyRepo(t *testing.T) {
if err != nil {
t.Errorf("stringifyRepo(%s) returned error: %v", test.input, err)
}
+
if got != test.want {
t.Errorf("stringifyRepo(%s) = %s, want %s", test.input, got, test.want)
}
@@ -15,11 +15,10 @@ import (
"sync"
"time"
- "github.com/unascribed/FlexVer/go/flexver"
-
"git.sr.ht/~amolith/willow/db"
"git.sr.ht/~amolith/willow/git"
"git.sr.ht/~amolith/willow/rss"
+ "github.com/unascribed/FlexVer/go/flexver"
)
type Project struct {
@@ -40,7 +39,7 @@ type Release struct {
Date time.Time
}
-// GetReleases returns a list of all releases for a project from the database
+// GetReleases returns a list of all releases for a project from the database.
func GetReleases(dbConn *sql.DB, mu *sync.Mutex, proj Project) (Project, error) {
proj.ID = GenProjectID(proj.URL, proj.Name, proj.Forge)
@@ -63,13 +62,16 @@ func GetReleases(dbConn *sql.DB, mu *sync.Mutex, proj Project) (Project, error)
Date: time.Time{},
})
}
+
proj.Releases = SortReleases(proj.Releases)
+
return proj, nil
}
-// fetchReleases fetches releases from a project's forge given its URI
+// fetchReleases fetches releases from a project's forge given its URI.
func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) {
var err error
+
switch p.Forge {
case "github", "gitea", "forgejo":
rssReleases, err := rss.GetReleases(p.URL)
@@ -77,6 +79,7 @@ func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) {
fmt.Println("Error getting RSS releases:", err)
return p, err
}
+
for _, release := range rssReleases {
p.Releases = append(p.Releases, Release{
ID: GenReleaseID(p.URL, release.URL, release.Tag),
@@ -85,6 +88,7 @@ func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) {
URL: release.URL,
Date: release.Date,
})
+
err = upsertReleases(dbConn, mu, p.ID, p.Releases)
if err != nil {
log.Printf("Error upserting release: %v", err)
@@ -96,6 +100,7 @@ func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) {
if err != nil {
return p, err
}
+
for _, release := range gitReleases {
p.Releases = append(p.Releases, Release{
ID: GenReleaseID(p.URL, release.URL, release.Tag),
@@ -104,6 +109,7 @@ func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) {
URL: release.URL,
Date: release.Date,
})
+
err = upsertReleases(dbConn, mu, p.ID, p.Releases)
if err != nil {
log.Printf("Error upserting release: %v", err)
@@ -111,7 +117,9 @@ func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) {
}
}
}
+
p.Releases = SortReleases(p.Releases)
+
return p, err
}
@@ -119,6 +127,7 @@ func SortReleases(releases []Release) []Release {
sort.Slice(releases, func(i, j int) bool {
return !flexver.Less(releases[i].Tag, releases[j].Tag)
})
+
return releases
}
@@ -126,29 +135,32 @@ func SortProjects(projects []Project) []Project {
sort.Slice(projects, func(i, j int) bool {
return strings.ToLower(projects[i].Name) < strings.ToLower(projects[j].Name)
})
+
return projects
}
-// upsertReleases updates or inserts a release in the database
+// upsertReleases updates or inserts a release in the database.
func upsertReleases(dbConn *sql.DB, mu *sync.Mutex, projID string, releases []Release) error {
for _, release := range releases {
date := release.Date.Format("2006-01-02 15:04:05")
+
err := db.UpsertRelease(dbConn, mu, release.ID, projID, release.URL, release.Tag, release.Content, date)
if err != nil {
log.Printf("Error upserting release: %v", err)
return err
}
}
+
return nil
}
-// GenReleaseID generates a likely-unique ID from its project's URL, its release's URL, and its tag
+// GenReleaseID generates a likely-unique ID from its project's URL, its release's URL, and its tag.
func GenReleaseID(projectURL, releaseURL, tag string) string {
idByte := sha256.Sum256([]byte(projectURL + releaseURL + tag))
return fmt.Sprintf("%x", idByte)
}
-// GenProjectID generates a likely-unique ID from a project's URI, name, and forge
+// GenProjectID generates a likely-unique ID from a project's URI, name, and forge.
func GenProjectID(url, name, forge string) string {
idByte := sha256.Sum256([]byte(url + name + forge))
return fmt.Sprintf("%x", idByte)
@@ -156,10 +168,12 @@ func GenProjectID(url, name, forge string) string {
func Track(dbConn *sql.DB, mu *sync.Mutex, manualRefresh *chan struct{}, name, url, forge, release string) {
id := GenProjectID(url, name, forge)
+
err := db.UpsertProject(dbConn, mu, id, url, name, forge, release)
if err != nil {
fmt.Println("Error upserting project:", err)
}
+
*manualRefresh <- struct{}{}
}
@@ -190,17 +204,21 @@ func RefreshLoop(dbConn *sql.DB, mu *sync.Mutex, interval int, manualRefresh, re
if err != nil {
fmt.Println("Error getting projects:", err)
}
+
for i, p := range projectsList {
p, err := fetchReleases(dbConn, mu, p)
if err != nil {
fmt.Println(err)
continue
}
+
projectsList[i] = p
}
+
sort.Slice(projectsList, func(i, j int) bool {
return strings.ToLower(projectsList[i].Name) < strings.ToLower(projectsList[j].Name)
})
+
for i := range projectsList {
err = upsertReleases(dbConn, mu, projectsList[i].ID, projectsList[i].Releases)
if err != nil {
@@ -208,6 +226,7 @@ func RefreshLoop(dbConn *sql.DB, mu *sync.Mutex, interval int, manualRefresh, re
continue
}
}
+
return projectsList
}
@@ -219,16 +238,18 @@ func RefreshLoop(dbConn *sql.DB, mu *sync.Mutex, interval int, manualRefresh, re
projects = fetch()
case <-*manualRefresh:
ticker.Reset(time.Second * 3600)
+
projects = fetch()
case <-*req:
projectsCopy := make([]Project, len(projects))
copy(projectsCopy, projects)
+
*res <- projectsCopy
}
}
}
-// GetProject returns a project from the database
+// GetProject returns a project from the database.
func GetProject(dbConn *sql.DB, proj Project) (Project, error) {
projectDB, err := db.GetProject(dbConn, proj.ID)
if err != nil && errors.Is(err, sql.ErrNoRows) {
@@ -236,6 +257,7 @@ func GetProject(dbConn *sql.DB, proj Project) (Project, error) {
} else if err != nil {
return proj, err
}
+
p := Project{
ID: proj.ID,
URL: proj.URL,
@@ -243,10 +265,11 @@ func GetProject(dbConn *sql.DB, proj Project) (Project, error) {
Forge: proj.Forge,
Running: projectDB["version"],
}
+
return p, err
}
-// GetProjectWithReleases returns a single project from the database along with its releases
+// GetProjectWithReleases returns a single project from the database along with its releases.
func GetProjectWithReleases(dbConn *sql.DB, mu *sync.Mutex, proj Project) (Project, error) {
project, err := GetProject(dbConn, proj)
if err != nil {
@@ -256,7 +279,7 @@ func GetProjectWithReleases(dbConn *sql.DB, mu *sync.Mutex, proj Project) (Proje
return GetReleases(dbConn, mu, project)
}
-// GetProjects returns a list of all projects from the database
+// GetProjects returns a list of all projects from the database.
func GetProjects(dbConn *sql.DB) ([]Project, error) {
projectsDB, err := db.GetProjects(dbConn)
if err != nil {
@@ -278,7 +301,7 @@ func GetProjects(dbConn *sql.DB) ([]Project, error) {
}
// GetProjectsWithReleases returns a list of all projects and all their releases
-// from the database
+// from the database.
func GetProjectsWithReleases(dbConn *sql.DB, mu *sync.Mutex) ([]Project, error) {
projects, err := GetProjects(dbConn)
if err != nil {
@@ -290,6 +313,7 @@ func GetProjectsWithReleases(dbConn *sql.DB, mu *sync.Mutex) ([]Project, error)
if err != nil {
return nil, err
}
+
projects[i].Releases = SortReleases(projects[i].Releases)
}
@@ -10,7 +10,6 @@ import (
"time"
"github.com/microcosm-cc/bluemonday"
-
"github.com/mmcdole/gofeed"
)
@@ -22,6 +22,7 @@ func argonHash(password, salt string) (string, error) {
if err != nil {
return "", err
}
+
return base64.StdEncoding.EncodeToString(argon2.IDKey([]byte(password), decodedSalt, 2, 64*1024, 4, 64)), nil
}
@@ -29,10 +30,12 @@ func argonHash(password, salt string) (string, error) {
// string.
func generateSalt() (string, error) {
salt := make([]byte, 16)
+
_, err := rand.Read(salt)
if err != nil {
return "", err
}
+
return base64.StdEncoding.EncodeToString(salt), nil
}
@@ -88,7 +91,7 @@ func InvalidateSession(dbConn *sql.DB, session string) error {
}
// CreateSession accepts a username, generates a token, stores it in the
-// database, and returns it
+// database, and returns it.
func CreateSession(dbConn *sql.DB, username string) (string, time.Time, error) {
token, err := generateSalt()
if err != nil {
@@ -7,6 +7,7 @@ package ws
import (
"database/sql"
"embed"
+ "errors"
"fmt"
"io"
"net/http"
@@ -33,7 +34,7 @@ type Handler struct {
//go:embed static
var fs embed.FS
-// bmUGC = bluemonday.UGCPolicy()
+// bmUGC = bluemonday.UGCPolicy().
var bmStrict = bluemonday.StrictPolicy()
func (h Handler) RootHandler(w http.ResponseWriter, r *http.Request) {
@@ -41,14 +42,17 @@ func (h Handler) RootHandler(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
+
projectsWithReleases, err := project.GetProjectsWithReleases(h.DbConn, h.Mu)
if err != nil {
fmt.Println(err)
w.WriteHeader(http.StatusInternalServerError)
+
_, err := w.Write([]byte("Internal Server Error"))
if err != nil {
fmt.Println(err)
}
+
return
}
@@ -73,11 +77,14 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
+
params := r.URL.Query()
+
action := bmStrict.Sanitize(params.Get("action"))
if r.Method == http.MethodGet {
if action == "" {
data := struct{ Version string }{Version: *h.Version}
+
tmpl := template.Must(template.ParseFS(fs, "static/new.html.tmpl", "static/head.html.tmpl", "static/header.html.tmpl", "static/footer.html.tmpl"))
if err := tmpl.Execute(w, data); err != nil {
fmt.Println(err)
@@ -86,30 +93,36 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) {
submittedURL := bmStrict.Sanitize(params.Get("url"))
if submittedURL == "" {
w.WriteHeader(http.StatusBadRequest)
+
_, err := w.Write([]byte("No URL provided"))
if err != nil {
fmt.Println(err)
}
+
return
}
forge := bmStrict.Sanitize(params.Get("forge"))
if forge == "" {
w.WriteHeader(http.StatusBadRequest)
+
_, err := w.Write([]byte("No forge provided"))
if err != nil {
fmt.Println(err)
}
+
return
}
name := bmStrict.Sanitize(params.Get("name"))
if name == "" {
w.WriteHeader(http.StatusBadRequest)
+
_, err := w.Write([]byte("No name provided"))
if err != nil {
fmt.Println(err)
}
+
return
}
@@ -121,22 +134,26 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) {
}
proj, err := project.GetProject(h.DbConn, proj)
- if err != nil && err != sql.ErrNoRows {
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
w.WriteHeader(http.StatusBadRequest)
- _, err := w.Write([]byte(fmt.Sprintf("Error getting project: %s", err)))
+
+ _, err := fmt.Fprintf(w, "Error getting project: %s", err)
if err != nil {
fmt.Println(err)
}
+
return
}
proj, err = project.GetReleases(h.DbConn, h.Mu, proj)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
- _, err := w.Write([]byte(fmt.Sprintf("Error getting releases: %s", err)))
+
+ _, err := fmt.Fprintf(w, "Error getting releases: %s", err)
if err != nil {
fmt.Println(err)
}
+
return
}
@@ -156,10 +173,12 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) {
submittedID := params.Get("id")
if submittedID == "" {
w.WriteHeader(http.StatusBadRequest)
+
_, err := w.Write([]byte("No URL provided"))
if err != nil {
fmt.Println(err)
}
+
return
}
@@ -173,6 +192,7 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) {
if err != nil {
fmt.Println(err)
}
+
idValue := bmStrict.Sanitize(r.FormValue("id"))
nameValue := bmStrict.Sanitize(r.FormValue("name"))
urlValue := bmStrict.Sanitize(r.FormValue("url"))
@@ -183,6 +203,7 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) {
if idValue != "" && nameValue != "" && urlValue != "" && forgeValue != "" && releaseValue != "" {
project.Track(h.DbConn, h.Mu, h.ManualRefresh, nameValue, urlValue, forgeValue, releaseValue)
http.Redirect(w, r, "/", http.StatusSeeOther)
+
return
}
@@ -193,6 +214,7 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) {
}
w.WriteHeader(http.StatusBadRequest)
+
_, err = w.Write([]byte("No data provided"))
if err != nil {
fmt.Println(err)
@@ -212,6 +234,7 @@ func (h Handler) LoginHandler(w http.ResponseWriter, r *http.Request) {
}{
Version: *h.Version,
}
+
tmpl := template.Must(template.ParseFS(fs, "static/login.html.tmpl", "static/head.html.tmpl", "static/footer.html.tmpl"))
if err := tmpl.Execute(w, data); err != nil {
fmt.Println(err)
@@ -223,44 +246,53 @@ func (h Handler) LoginHandler(w http.ResponseWriter, r *http.Request) {
if err != nil {
fmt.Println(err)
}
+
username := bmStrict.Sanitize(r.FormValue("username"))
password := bmStrict.Sanitize(r.FormValue("password"))
if username == "" || password == "" {
w.WriteHeader(http.StatusBadRequest)
+
_, err := w.Write([]byte("No data provided"))
if err != nil {
fmt.Println(err)
}
+
return
}
authorised, err := users.UserAuthorised(h.DbConn, username, password)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
- _, err := w.Write([]byte(fmt.Sprintf("Error logging in: %s", err)))
+
+ _, err := fmt.Fprintf(w, "Error logging in: %s", err)
if err != nil {
fmt.Println(err)
}
+
return
}
if !authorised {
w.WriteHeader(http.StatusUnauthorized)
+
_, err := w.Write([]byte("Incorrect username or password"))
if err != nil {
fmt.Println(err)
}
+
return
}
session, expiry, err := users.CreateSession(h.DbConn, username)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
- _, err := w.Write([]byte(fmt.Sprintf("Error creating session: %s", err)))
+
+ _, err := fmt.Fprintf(w, "Error creating session: %s", err)
if err != nil {
fmt.Println(err)
}
+
return
}
@@ -289,12 +321,15 @@ func (h Handler) LogoutHandler(w http.ResponseWriter, r *http.Request) {
err = users.InvalidateSession(h.DbConn, cookie.Value)
if err != nil {
fmt.Println(err)
- _, err = w.Write([]byte(fmt.Sprintf("Error logging out: %s", err)))
+
+ _, err = fmt.Fprintf(w, "Error logging out: %s", err)
if err != nil {
fmt.Println(err)
}
+
return
}
+
cookie.MaxAge = -1
http.SetCookie(w, cookie)
http.Redirect(w, r, "/login", http.StatusSeeOther)
@@ -324,10 +359,12 @@ func StaticHandler(writer http.ResponseWriter, request *http.Request) {
} else if strings.HasSuffix(resource, ".js") {
writer.Header().Set("Content-Type", "text/javascript")
}
+
home, err := fs.ReadFile(resource)
if err != nil {
fmt.Println(err)
}
+
if _, err = io.Writer.Write(writer, home); err != nil {
fmt.Println(err)
}