diff --git a/cmd/cli.go b/cmd/cli.go index c920a5beb2bc227ff9bf13db0a2c822996250c00..b1e5fa3ced00304e575d79ccf85c7e7d320d594d 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -17,24 +17,28 @@ import ( var bmStrict = bluemonday.StrictPolicy() -// createUser is a CLI that creates a new user with the specified username +// createUser is a CLI that creates a new user with the specified username. func createUser(dbConn *sql.DB, username string) { fmt.Println("Creating user", username) fmt.Print("Enter password: ") + password, err := term.ReadPassword(syscall.Stdin) if err != nil { fmt.Println("Error reading password:", err) os.Exit(1) } + fmt.Println() fmt.Print("Confirm password: ") + passwordConfirmation, err := term.ReadPassword(syscall.Stdin) if err != nil { fmt.Println("Error reading password confirmation:", err) os.Exit(1) } + fmt.Println() if string(password) != string(passwordConfirmation) { @@ -50,6 +54,7 @@ func createUser(dbConn *sql.DB, username string) { // // TODO: Abstract this sanitisedPassword := bmStrict.Sanitize(string(password)) + err = users.Register(dbConn, username, sanitisedPassword) if err != nil { fmt.Println("Error creating user:", err) @@ -60,9 +65,10 @@ func createUser(dbConn *sql.DB, username string) { os.Exit(0) } -// deleteUser is a CLI that deletes a user with the specified username +// deleteUser is a CLI that deletes a user with the specified username. func deleteUser(dbConn *sql.DB, username string) { fmt.Println("Deleting user", username) + err := users.Delete(dbConn, username) if err != nil { fmt.Println("Error deleting user:", err) @@ -73,7 +79,7 @@ func deleteUser(dbConn *sql.DB, username string) { os.Exit(0) } -// listUsers is a CLI that lists all users in the database +// listUsers is a CLI that lists all users in the database. func listUsers(dbConn *sql.DB) { fmt.Println("Listing all users") @@ -90,6 +96,7 @@ func listUsers(dbConn *sql.DB) { fmt.Println("-", u) } } + os.Exit(0) } @@ -99,15 +106,18 @@ func checkAuthorised(dbConn *sql.DB, username string) { fmt.Printf("Checking whether password for user %s is correct\n", username) fmt.Print("Enter password: ") + password, err := term.ReadPassword(syscall.Stdin) if err != nil { fmt.Println("Error reading password:", err) os.Exit(1) } + fmt.Println() // TODO: Abstract this, refer to note in createUser() sanitisedPassword := bmStrict.Sanitize(string(password)) + authorised, err := users.UserAuthorised(dbConn, username, sanitisedPassword) if err != nil { fmt.Println("Error checking authorisation:", err) @@ -119,5 +129,6 @@ func checkAuthorised(dbConn *sql.DB, username string) { } else { fmt.Println("User is not authorised") } + os.Exit(0) } diff --git a/cmd/willow.go b/cmd/willow.go index 337fa29258c88891a8dc941b192c3965bbcf72d8..cf75e540f571dda6e65314db52614dd4f4c72ef0 100644 --- a/cmd/willow.go +++ b/cmd/willow.go @@ -16,7 +16,6 @@ import ( "git.sr.ht/~amolith/willow/db" "git.sr.ht/~amolith/willow/project" "git.sr.ht/~amolith/willow/ws" - "github.com/BurntSushi/toml" flag "github.com/spf13/pflag" ) @@ -71,12 +70,15 @@ func main() { } fmt.Println("Checking whether database needs initialising") + err = db.InitialiseDatabase(dbConn) if err != nil { fmt.Println("Error initialising database:", err) os.Exit(1) } + fmt.Println("Checking whether there are pending migrations") + err = db.Migrate(dbConn) if err != nil { fmt.Println("Error migrating database schema:", err) @@ -100,6 +102,7 @@ func main() { mu := sync.Mutex{} fmt.Println("Starting refresh loop") + go project.RefreshLoop(dbConn, &mu, config.FetchInterval, &manualRefresh, &req, &res) wsHandler := ws.Handler{ @@ -124,6 +127,7 @@ func main() { } fmt.Println("Starting web server on", config.Server.Listen) + if err := httpServer.ListenAndServe(); errors.Is(err, http.ErrServerClosed) { fmt.Println("Web server closed") os.Exit(0) diff --git a/db/db.go b/db/db.go index 8e160dfa29c32e4e605d80a8a65008da0ad1cc7d..b409f515ef62368681a766aaf09aeca33fd0daea 100644 --- a/db/db.go +++ b/db/db.go @@ -18,23 +18,27 @@ var schema string var mutex = &sync.Mutex{} -// Open opens a connection to the SQLite database +// Open opens a connection to the SQLite database. func Open(dbPath string) (*sql.DB, error) { return sql.Open("sqlite", "file:"+dbPath+"?_pragma=journal_mode%3DWAL") } // VerifySchema checks whether the schema has been initialised and initialises it -// if not +// if not. func InitialiseDatabase(dbConn *sql.DB) error { var name string + err := dbConn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='users'").Scan(&name) if err != nil && errors.Is(err, sql.ErrNoRows) { mutex.Lock() defer mutex.Unlock() + if _, err := dbConn.Exec(schema); err != nil { return err } + return nil } + return err } diff --git a/db/migrations.go b/db/migrations.go index 70d45c1bea284ad5f66472237670fe9126db7471..243fa715b6662eb18bcb575657068f1bcd30575c 100644 --- a/db/migrations.go +++ b/db/migrations.go @@ -48,17 +48,19 @@ var migrations = [...]migration{ }, } -// Migrate runs all pending migrations +// Migrate runs all pending migrations. func Migrate(db *sql.DB) error { version := getSchemaVersion(db) for nextMigration := version + 1; nextMigration < len(migrations); nextMigration++ { if err := runMigration(db, nextMigration); err != nil { return fmt.Errorf("migrations failed: %w", err) } + if version := getSchemaVersion(db); version != nextMigration { return fmt.Errorf("migration did not update version (expected %d, got %d)", nextMigration, version) } } + return nil } @@ -67,30 +69,36 @@ func Migrate(db *sql.DB) error { // transaction if unsuccessful. func runMigration(db *sql.DB, migrationIdx int) (err error) { current := migrations[migrationIdx] + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return fmt.Errorf("failed opening transaction for migration %d: %w", migrationIdx, err) } + defer func() { if err == nil { err = tx.Commit() } + if err != nil { if rbErr := tx.Rollback(); rbErr != nil { err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err) } } }() + if len(current.upQuery) > 0 { if _, err := tx.Exec(current.upQuery); err != nil { return fmt.Errorf("failed running migration %d: %w", migrationIdx, err) } } + if current.postHook != nil { if err := current.postHook(tx); err != nil { return fmt.Errorf("failed running posthook for migration %d: %w", migrationIdx, err) } } + return updateSchemaVersion(tx, migrationIdx) } @@ -101,44 +109,53 @@ func runMigration(db *sql.DB, migrationIdx int) (err error) { //lint:ignore U1000 Will be used when #34 is implemented (https://todo.sr.ht/~amolith/willow/34) func undoMigration(db *sql.DB, migrationIdx int) (err error) { current := migrations[migrationIdx] + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return fmt.Errorf("failed opening undo transaction for migration %d: %w", migrationIdx, err) } + defer func() { if err == nil { err = tx.Commit() } + if err != nil { if rbErr := tx.Rollback(); rbErr != nil { err = fmt.Errorf("failed rolling back: %w due to: %w", rbErr, err) } } }() + if len(current.downQuery) > 0 { if _, err := tx.Exec(current.downQuery); err != nil { return fmt.Errorf("failed undoing migration %d: %w", migrationIdx, err) } } + return updateSchemaVersion(tx, migrationIdx-1) } -// getSchemaVersion returns the schema version from the database +// getSchemaVersion returns the schema version from the database. func getSchemaVersion(db *sql.DB) int { row := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations LIMIT 1;`) + var version int if err := row.Scan(&version); err != nil { version = -1 } + return version } -// updateSchemaVersion sets the version to the provided int +// updateSchemaVersion sets the version to the provided int. func updateSchemaVersion(tx *sql.Tx, version int) error { if version < 0 { // Do not try to use the schema_migrations table in a schema version where it doesn't exist return nil } + _, err := tx.Exec(`UPDATE schema_migrations SET version = @version;`, sql.Named("version", version)) + return err } diff --git a/db/posthooks.go b/db/posthooks.go index a25497f0f81bfcbbbc87fe8582d44dd53203adf8..3b6b1e3ffafda9916fa06eb5da9ab81ca008af7a 100644 --- a/db/posthooks.go +++ b/db/posthooks.go @@ -34,7 +34,9 @@ func generateAndInsertProjectIDs(tx *sql.Tx) error { if err := rows.Scan(&url, &name, &forge, &version, &created_at); err != nil { return fmt.Errorf("failed to scan row from projects_tmp: %w", err) } + id := fmt.Sprintf("%x", sha256.Sum256([]byte(url+name+forge+created_at))) + _, err = tx.Exec( "INSERT INTO projects (id, url, name, forge, version, created_at) VALUES (@id, @url, @name, @forge, @version, @created_at)", sql.Named("id", id), @@ -74,7 +76,9 @@ func correctProjectIDs(tx *sql.Tx) error { if err := rows.Scan(&old_id, &url, &name, &forge); err != nil { return fmt.Errorf("failed to scan row from projects_tmp: %w", err) } + id := fmt.Sprintf("%x", sha256.Sum256([]byte(url+name+forge))) + _, err = tx.Exec( "UPDATE projects SET id = @id WHERE id = @old_id", sql.Named("id", id), diff --git a/db/project.go b/db/project.go index 02933a1be19d469950bbcf2268b905c13f90f4ad..c28f921e437b2af214e5c1cf4aa360624314c928 100644 --- a/db/project.go +++ b/db/project.go @@ -9,25 +9,30 @@ import ( "sync" ) -// DeleteProject deletes a project from the database +// DeleteProject deletes a project from the database. func DeleteProject(db *sql.DB, mu *sync.Mutex, id string) error { mu.Lock() defer mu.Unlock() + _, err := db.Exec("DELETE FROM projects WHERE id = ?", id) if err != nil { return err } + _, err = db.Exec("DELETE FROM releases WHERE project_id = ?", id) + return err } -// GetProject returns a project from the database +// GetProject returns a project from the database. func GetProject(db *sql.DB, id string) (map[string]string, error) { var name, forge, url, version string + err := db.QueryRow("SELECT name, forge, url, version FROM projects WHERE id = ?", id).Scan(&name, &forge, &url, &version) if err != nil { return nil, err } + project := map[string]string{ "id": id, "name": name, @@ -35,13 +40,15 @@ func GetProject(db *sql.DB, id string) (map[string]string, error) { "forge": forge, "version": version, } + return project, nil } -// UpsertProject adds or updates a project in the database +// UpsertProject adds or updates a project in the database. func UpsertProject(db *sql.DB, mu *sync.Mutex, id, url, name, forge, running string) error { mu.Lock() defer mu.Unlock() + _, err := db.Exec(`INSERT INTO projects (id, url, name, forge, version) VALUES (?, ?, ?, ?, ?) ON CONFLICT(id) DO @@ -49,10 +56,11 @@ func UpsertProject(db *sql.DB, mu *sync.Mutex, id, url, name, forge, running str name = excluded.name, forge = excluded.forge, version = excluded.version;`, id, url, name, forge, running) + return err } -// GetProjects returns a list of all projects in the database +// GetProjects returns a list of all projects in the database. func GetProjects(db *sql.DB) ([]map[string]string, error) { rows, err := db.Query("SELECT id, name, url, forge, version FROM projects") if err != nil { @@ -63,10 +71,12 @@ func GetProjects(db *sql.DB) ([]map[string]string, error) { var projects []map[string]string for rows.Next() { var id, name, url, forge, version string + err = rows.Scan(&id, &name, &url, &forge, &version) if err != nil { return nil, err } + project := map[string]string{ "id": id, "name": name, @@ -76,5 +86,6 @@ func GetProjects(db *sql.DB) ([]map[string]string, error) { } projects = append(projects, project) } + return projects, nil } diff --git a/db/release.go b/db/release.go index e6ae4c0f528f1e77fb55ac15f0494c84cb8739e9..d88ba3a638e9b32cf0dfbf1a2e88367a9276dbd7 100644 --- a/db/release.go +++ b/db/release.go @@ -10,10 +10,11 @@ import ( ) // UpsertRelease adds or updates a release for a project with a given ID in the -// database +// database. func UpsertRelease(db *sql.DB, mu *sync.Mutex, id, projectID, url, tag, content, date string) error { mu.Lock() defer mu.Unlock() + _, err := db.Exec(`INSERT INTO releases (id, project_id, url, tag, content, date) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO @@ -23,10 +24,11 @@ func UpsertRelease(db *sql.DB, mu *sync.Mutex, id, projectID, url, tag, content, tag = excluded.tag, content = excluded.content, date = excluded.date;`, id, projectID, url, tag, content, date) + return err } -// GetReleases returns all releases for a project with a given id from the database +// GetReleases returns all releases for a project with a given id from the database. func GetReleases(db *sql.DB, projectID string) ([]map[string]string, error) { rows, err := db.Query(`SELECT id, url, tag, content, date FROM releases WHERE project_id = ?`, projectID) if err != nil { @@ -43,10 +45,12 @@ func GetReleases(db *sql.DB, projectID string) ([]map[string]string, error) { content string date string ) + err := rows.Scan(&id, &url, &tag, &content, &date) if err != nil { return nil, err } + releases = append(releases, map[string]string{ "id": id, "project_id": projectID, @@ -56,5 +60,6 @@ func GetReleases(db *sql.DB, projectID string) ([]map[string]string, error) { "date": date, }) } + return releases, nil } diff --git a/db/users.go b/db/users.go index 5abaf4a6ae9585c0ad2780316e160f35c51b0c51..7d4753601b79da076a9021f025d4e7d607d01c20 100644 --- a/db/users.go +++ b/db/users.go @@ -10,32 +10,38 @@ import ( ) // DeleteUser deletes specific user from the database and returns an error if it -// fails +// fails. func DeleteUser(db *sql.DB, user string) error { mutex.Lock() defer mutex.Unlock() + _, err := db.Exec("DELETE FROM users WHERE username = ?", user) + return err } -// CreateUser creates a new user in the database and returns an error if it fails +// CreateUser creates a new user in the database and returns an error if it fails. func CreateUser(db *sql.DB, username, hash, salt string) error { mutex.Lock() defer mutex.Unlock() + _, err := db.Exec("INSERT INTO users (username, hash, salt) VALUES (?, ?, ?)", username, hash, salt) + return err } // GetUser returns a user's hash and salt from the database as strings and -// returns an error if it fails +// returns an error if it fails. func GetUser(db *sql.DB, username string) (string, string, error) { var hash, salt string + err := db.QueryRow("SELECT hash, salt FROM users WHERE username = ?", username).Scan(&hash, &salt) + return hash, salt, err } // GetUsers returns a list of all users in the database as a slice of strings -// and returns an error if it fails +// and returns an error if it fails. func GetUsers(db *sql.DB) ([]string, error) { rows, err := db.Query("SELECT username FROM users") if err != nil { @@ -46,10 +52,12 @@ func GetUsers(db *sql.DB) ([]string, error) { var users []string for rows.Next() { var user string + err = rows.Scan(&user) if err != nil { return nil, err } + users = append(users, user) } @@ -57,10 +65,13 @@ func GetUsers(db *sql.DB) ([]string, error) { } // GetSession accepts a session ID and returns the username associated with it -// and an error +// and an error. func GetSession(db *sql.DB, session string) (string, time.Time, error) { - var username string - var expiresString string + var ( + username string + expiresString string + ) + err := db.QueryRow("SELECT username, expires FROM sessions WHERE token = ?", session).Scan(&username, &expiresString) if err != nil { return "", time.Time{}, err @@ -70,6 +81,7 @@ func GetSession(db *sql.DB, session string) (string, time.Time, error) { if err != nil { return "", time.Time{}, err } + return username, expires, nil } @@ -78,15 +90,19 @@ func GetSession(db *sql.DB, session string) (string, time.Time, error) { func InvalidateSession(db *sql.DB, session string, expiry time.Time) error { mutex.Lock() defer mutex.Unlock() + _, err := db.Exec("UPDATE sessions SET expires = ? WHERE token = ?", expiry.Format(time.RFC3339), session) + return err } // CreateSession creates a new session in the database and returns an error if -// it fails +// it fails. func CreateSession(db *sql.DB, username, token string, expiry time.Time) error { mutex.Lock() defer mutex.Unlock() + _, err := db.Exec("INSERT INTO sessions (token, username, expires) VALUES (?, ?, ?)", token, username, expiry.Format(time.RFC3339)) + return err } diff --git a/git/git.go b/git/git.go index 2919203b2f53f611a429b6a740e7c0621cf24387..1c5e130081527b993611628ec2b97d8b53d29e8e 100644 --- a/git/git.go +++ b/git/git.go @@ -12,11 +12,10 @@ import ( "strings" "time" - "github.com/microcosm-cc/bluemonday" - "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing" "github.com/go-git/go-git/v5/plumbing/transport" + "github.com/microcosm-cc/bluemonday" ) type Release struct { @@ -45,6 +44,7 @@ func GetReleases(gitURI, forge string) ([]Release, error) { if err != nil { return nil, err } + tagRefs, err := r.Tags() if err != nil { return nil, err @@ -65,13 +65,17 @@ func GetReleases(gitURI, forge string) ([]Release, error) { err = tagRefs.ForEach(func(tagRef *plumbing.Reference) error { tagObj, err := r.TagObject(tagRef.Hash()) - var message string - var date time.Time + var ( + message string + date time.Time + ) + if errors.Is(err, plumbing.ErrObjectNotFound) { commitTag, err := r.CommitObject(tagRef.Hash()) if err != nil { return err } + message = commitTag.Message date = commitTag.Committer.When } else { @@ -81,6 +85,7 @@ func GetReleases(gitURI, forge string) ([]Release, error) { tagURL := "" tagName := bmStrict.Sanitize(tagRef.Name().Short()) + switch forge { case "sourcehut": tagURL = "https://" + httpURI + "/refs/" + tagName @@ -96,6 +101,7 @@ func GetReleases(gitURI, forge string) ([]Release, error) { URL: tagURL, Date: date, }) + return nil }) if err != nil { @@ -122,6 +128,7 @@ func minimalClone(url string) (r *git.Repository, err error) { if errors.Is(err, git.NoErrAlreadyUpToDate) { return r, nil } + return r, err } else if !errors.Is(err, git.ErrRepositoryNotExists) { return nil, err @@ -133,6 +140,7 @@ func minimalClone(url string) (r *git.Repository, err error) { NoCheckout: true, Depth: 1, }) + return r, err } @@ -142,6 +150,7 @@ func RemoveRepo(url string) (err error) { if err != nil { return err } + err = os.RemoveAll(path) if err != nil { return err @@ -154,12 +163,14 @@ func RemoveRepo(url string) (err error) { if path == "data" { break } + err = os.Remove(path) if err != nil { // This folder likely has data, so might as well save some time by // not checking the parents we can't delete anyway. break } + path = path[:strings.LastIndex(path, "/")] } @@ -177,11 +188,12 @@ func stringifyRepo(url string) (path string, err error) { return "", err } - if ep.Protocol == "http" || ep.Protocol == "https" { + switch ep.Protocol { + case "http", "https": return "data/" + strings.Split(url, "://")[1], nil - } else if ep.Protocol == "ssh" { + case "ssh": return "data/" + ep.Host + "/" + ep.Path, nil - } else { + default: return "", errors.New("unsupported protocol") } } diff --git a/git/git_test.go b/git/git_test.go index 7f9a1ce7908eeaecdccc1f1dda7c48969eff2e8d..515e2d30479d2de029a2837f0938c40402e90d2c 100644 --- a/git/git_test.go +++ b/git/git_test.go @@ -55,6 +55,7 @@ func TestStringifyRepo(t *testing.T) { if err != nil { t.Errorf("stringifyRepo(%s) returned error: %v", test.input, err) } + if got != test.want { t.Errorf("stringifyRepo(%s) = %s, want %s", test.input, got, test.want) } diff --git a/project/project.go b/project/project.go index 96301092253a1c74a87ad86dd335fc6669cd19ec..4aa2704eb134b2777010d50e47acf9b9bd8de6da 100644 --- a/project/project.go +++ b/project/project.go @@ -15,11 +15,10 @@ import ( "sync" "time" - "github.com/unascribed/FlexVer/go/flexver" - "git.sr.ht/~amolith/willow/db" "git.sr.ht/~amolith/willow/git" "git.sr.ht/~amolith/willow/rss" + "github.com/unascribed/FlexVer/go/flexver" ) type Project struct { @@ -40,7 +39,7 @@ type Release struct { Date time.Time } -// GetReleases returns a list of all releases for a project from the database +// GetReleases returns a list of all releases for a project from the database. func GetReleases(dbConn *sql.DB, mu *sync.Mutex, proj Project) (Project, error) { proj.ID = GenProjectID(proj.URL, proj.Name, proj.Forge) @@ -63,13 +62,16 @@ func GetReleases(dbConn *sql.DB, mu *sync.Mutex, proj Project) (Project, error) Date: time.Time{}, }) } + proj.Releases = SortReleases(proj.Releases) + return proj, nil } -// fetchReleases fetches releases from a project's forge given its URI +// fetchReleases fetches releases from a project's forge given its URI. func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) { var err error + switch p.Forge { case "github", "gitea", "forgejo": rssReleases, err := rss.GetReleases(p.URL) @@ -77,6 +79,7 @@ func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) { fmt.Println("Error getting RSS releases:", err) return p, err } + for _, release := range rssReleases { p.Releases = append(p.Releases, Release{ ID: GenReleaseID(p.URL, release.URL, release.Tag), @@ -85,6 +88,7 @@ func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) { URL: release.URL, Date: release.Date, }) + err = upsertReleases(dbConn, mu, p.ID, p.Releases) if err != nil { log.Printf("Error upserting release: %v", err) @@ -96,6 +100,7 @@ func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) { if err != nil { return p, err } + for _, release := range gitReleases { p.Releases = append(p.Releases, Release{ ID: GenReleaseID(p.URL, release.URL, release.Tag), @@ -104,6 +109,7 @@ func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) { URL: release.URL, Date: release.Date, }) + err = upsertReleases(dbConn, mu, p.ID, p.Releases) if err != nil { log.Printf("Error upserting release: %v", err) @@ -111,7 +117,9 @@ func fetchReleases(dbConn *sql.DB, mu *sync.Mutex, p Project) (Project, error) { } } } + p.Releases = SortReleases(p.Releases) + return p, err } @@ -119,6 +127,7 @@ func SortReleases(releases []Release) []Release { sort.Slice(releases, func(i, j int) bool { return !flexver.Less(releases[i].Tag, releases[j].Tag) }) + return releases } @@ -126,29 +135,32 @@ func SortProjects(projects []Project) []Project { sort.Slice(projects, func(i, j int) bool { return strings.ToLower(projects[i].Name) < strings.ToLower(projects[j].Name) }) + return projects } -// upsertReleases updates or inserts a release in the database +// upsertReleases updates or inserts a release in the database. func upsertReleases(dbConn *sql.DB, mu *sync.Mutex, projID string, releases []Release) error { for _, release := range releases { date := release.Date.Format("2006-01-02 15:04:05") + err := db.UpsertRelease(dbConn, mu, release.ID, projID, release.URL, release.Tag, release.Content, date) if err != nil { log.Printf("Error upserting release: %v", err) return err } } + return nil } -// GenReleaseID generates a likely-unique ID from its project's URL, its release's URL, and its tag +// GenReleaseID generates a likely-unique ID from its project's URL, its release's URL, and its tag. func GenReleaseID(projectURL, releaseURL, tag string) string { idByte := sha256.Sum256([]byte(projectURL + releaseURL + tag)) return fmt.Sprintf("%x", idByte) } -// GenProjectID generates a likely-unique ID from a project's URI, name, and forge +// GenProjectID generates a likely-unique ID from a project's URI, name, and forge. func GenProjectID(url, name, forge string) string { idByte := sha256.Sum256([]byte(url + name + forge)) return fmt.Sprintf("%x", idByte) @@ -156,10 +168,12 @@ func GenProjectID(url, name, forge string) string { func Track(dbConn *sql.DB, mu *sync.Mutex, manualRefresh *chan struct{}, name, url, forge, release string) { id := GenProjectID(url, name, forge) + err := db.UpsertProject(dbConn, mu, id, url, name, forge, release) if err != nil { fmt.Println("Error upserting project:", err) } + *manualRefresh <- struct{}{} } @@ -190,17 +204,21 @@ func RefreshLoop(dbConn *sql.DB, mu *sync.Mutex, interval int, manualRefresh, re if err != nil { fmt.Println("Error getting projects:", err) } + for i, p := range projectsList { p, err := fetchReleases(dbConn, mu, p) if err != nil { fmt.Println(err) continue } + projectsList[i] = p } + sort.Slice(projectsList, func(i, j int) bool { return strings.ToLower(projectsList[i].Name) < strings.ToLower(projectsList[j].Name) }) + for i := range projectsList { err = upsertReleases(dbConn, mu, projectsList[i].ID, projectsList[i].Releases) if err != nil { @@ -208,6 +226,7 @@ func RefreshLoop(dbConn *sql.DB, mu *sync.Mutex, interval int, manualRefresh, re continue } } + return projectsList } @@ -219,16 +238,18 @@ func RefreshLoop(dbConn *sql.DB, mu *sync.Mutex, interval int, manualRefresh, re projects = fetch() case <-*manualRefresh: ticker.Reset(time.Second * 3600) + projects = fetch() case <-*req: projectsCopy := make([]Project, len(projects)) copy(projectsCopy, projects) + *res <- projectsCopy } } } -// GetProject returns a project from the database +// GetProject returns a project from the database. func GetProject(dbConn *sql.DB, proj Project) (Project, error) { projectDB, err := db.GetProject(dbConn, proj.ID) if err != nil && errors.Is(err, sql.ErrNoRows) { @@ -236,6 +257,7 @@ func GetProject(dbConn *sql.DB, proj Project) (Project, error) { } else if err != nil { return proj, err } + p := Project{ ID: proj.ID, URL: proj.URL, @@ -243,10 +265,11 @@ func GetProject(dbConn *sql.DB, proj Project) (Project, error) { Forge: proj.Forge, Running: projectDB["version"], } + return p, err } -// GetProjectWithReleases returns a single project from the database along with its releases +// GetProjectWithReleases returns a single project from the database along with its releases. func GetProjectWithReleases(dbConn *sql.DB, mu *sync.Mutex, proj Project) (Project, error) { project, err := GetProject(dbConn, proj) if err != nil { @@ -256,7 +279,7 @@ func GetProjectWithReleases(dbConn *sql.DB, mu *sync.Mutex, proj Project) (Proje return GetReleases(dbConn, mu, project) } -// GetProjects returns a list of all projects from the database +// GetProjects returns a list of all projects from the database. func GetProjects(dbConn *sql.DB) ([]Project, error) { projectsDB, err := db.GetProjects(dbConn) if err != nil { @@ -278,7 +301,7 @@ func GetProjects(dbConn *sql.DB) ([]Project, error) { } // GetProjectsWithReleases returns a list of all projects and all their releases -// from the database +// from the database. func GetProjectsWithReleases(dbConn *sql.DB, mu *sync.Mutex) ([]Project, error) { projects, err := GetProjects(dbConn) if err != nil { @@ -290,6 +313,7 @@ func GetProjectsWithReleases(dbConn *sql.DB, mu *sync.Mutex) ([]Project, error) if err != nil { return nil, err } + projects[i].Releases = SortReleases(projects[i].Releases) } diff --git a/rss/rss.go b/rss/rss.go index 2ef412000bf62229477663d8e843d5d8797aaf30..b7cc8037eceb600c01e08c2ef97a7a5940411174 100644 --- a/rss/rss.go +++ b/rss/rss.go @@ -10,7 +10,6 @@ import ( "time" "github.com/microcosm-cc/bluemonday" - "github.com/mmcdole/gofeed" ) diff --git a/users/users.go b/users/users.go index ed96cb64a6b02d76ecb5f9c96c1ab1328a231d98..1945689f4e0278adfb0d85c005224f8ce5dc20a1 100644 --- a/users/users.go +++ b/users/users.go @@ -22,6 +22,7 @@ func argonHash(password, salt string) (string, error) { if err != nil { return "", err } + return base64.StdEncoding.EncodeToString(argon2.IDKey([]byte(password), decodedSalt, 2, 64*1024, 4, 64)), nil } @@ -29,10 +30,12 @@ func argonHash(password, salt string) (string, error) { // string. func generateSalt() (string, error) { salt := make([]byte, 16) + _, err := rand.Read(salt) if err != nil { return "", err } + return base64.StdEncoding.EncodeToString(salt), nil } @@ -88,7 +91,7 @@ func InvalidateSession(dbConn *sql.DB, session string) error { } // CreateSession accepts a username, generates a token, stores it in the -// database, and returns it +// database, and returns it. func CreateSession(dbConn *sql.DB, username string) (string, time.Time, error) { token, err := generateSalt() if err != nil { diff --git a/ws/ws.go b/ws/ws.go index 101c3081630483ab9afee5824f0bab6142bd62b1..22e695ce37c2cfc000afc31ff24be832d34c58ef 100644 --- a/ws/ws.go +++ b/ws/ws.go @@ -7,6 +7,7 @@ package ws import ( "database/sql" "embed" + "errors" "fmt" "io" "net/http" @@ -33,7 +34,7 @@ type Handler struct { //go:embed static var fs embed.FS -// bmUGC = bluemonday.UGCPolicy() +// bmUGC = bluemonday.UGCPolicy(). var bmStrict = bluemonday.StrictPolicy() func (h Handler) RootHandler(w http.ResponseWriter, r *http.Request) { @@ -41,14 +42,17 @@ func (h Handler) RootHandler(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/login", http.StatusSeeOther) return } + projectsWithReleases, err := project.GetProjectsWithReleases(h.DbConn, h.Mu) if err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) + _, err := w.Write([]byte("Internal Server Error")) if err != nil { fmt.Println(err) } + return } @@ -73,11 +77,14 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/login", http.StatusSeeOther) return } + params := r.URL.Query() + action := bmStrict.Sanitize(params.Get("action")) if r.Method == http.MethodGet { if action == "" { data := struct{ Version string }{Version: *h.Version} + tmpl := template.Must(template.ParseFS(fs, "static/new.html.tmpl", "static/head.html.tmpl", "static/header.html.tmpl", "static/footer.html.tmpl")) if err := tmpl.Execute(w, data); err != nil { fmt.Println(err) @@ -86,30 +93,36 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) { submittedURL := bmStrict.Sanitize(params.Get("url")) if submittedURL == "" { w.WriteHeader(http.StatusBadRequest) + _, err := w.Write([]byte("No URL provided")) if err != nil { fmt.Println(err) } + return } forge := bmStrict.Sanitize(params.Get("forge")) if forge == "" { w.WriteHeader(http.StatusBadRequest) + _, err := w.Write([]byte("No forge provided")) if err != nil { fmt.Println(err) } + return } name := bmStrict.Sanitize(params.Get("name")) if name == "" { w.WriteHeader(http.StatusBadRequest) + _, err := w.Write([]byte("No name provided")) if err != nil { fmt.Println(err) } + return } @@ -121,22 +134,26 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) { } proj, err := project.GetProject(h.DbConn, proj) - if err != nil && err != sql.ErrNoRows { + if err != nil && !errors.Is(err, sql.ErrNoRows) { w.WriteHeader(http.StatusBadRequest) - _, err := w.Write([]byte(fmt.Sprintf("Error getting project: %s", err))) + + _, err := fmt.Fprintf(w, "Error getting project: %s", err) if err != nil { fmt.Println(err) } + return } proj, err = project.GetReleases(h.DbConn, h.Mu, proj) if err != nil { w.WriteHeader(http.StatusBadRequest) - _, err := w.Write([]byte(fmt.Sprintf("Error getting releases: %s", err))) + + _, err := fmt.Fprintf(w, "Error getting releases: %s", err) if err != nil { fmt.Println(err) } + return } @@ -156,10 +173,12 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) { submittedID := params.Get("id") if submittedID == "" { w.WriteHeader(http.StatusBadRequest) + _, err := w.Write([]byte("No URL provided")) if err != nil { fmt.Println(err) } + return } @@ -173,6 +192,7 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) { if err != nil { fmt.Println(err) } + idValue := bmStrict.Sanitize(r.FormValue("id")) nameValue := bmStrict.Sanitize(r.FormValue("name")) urlValue := bmStrict.Sanitize(r.FormValue("url")) @@ -183,6 +203,7 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) { if idValue != "" && nameValue != "" && urlValue != "" && forgeValue != "" && releaseValue != "" { project.Track(h.DbConn, h.Mu, h.ManualRefresh, nameValue, urlValue, forgeValue, releaseValue) http.Redirect(w, r, "/", http.StatusSeeOther) + return } @@ -193,6 +214,7 @@ func (h Handler) NewHandler(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusBadRequest) + _, err = w.Write([]byte("No data provided")) if err != nil { fmt.Println(err) @@ -212,6 +234,7 @@ func (h Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { }{ Version: *h.Version, } + tmpl := template.Must(template.ParseFS(fs, "static/login.html.tmpl", "static/head.html.tmpl", "static/footer.html.tmpl")) if err := tmpl.Execute(w, data); err != nil { fmt.Println(err) @@ -223,44 +246,53 @@ func (h Handler) LoginHandler(w http.ResponseWriter, r *http.Request) { if err != nil { fmt.Println(err) } + username := bmStrict.Sanitize(r.FormValue("username")) password := bmStrict.Sanitize(r.FormValue("password")) if username == "" || password == "" { w.WriteHeader(http.StatusBadRequest) + _, err := w.Write([]byte("No data provided")) if err != nil { fmt.Println(err) } + return } authorised, err := users.UserAuthorised(h.DbConn, username, password) if err != nil { w.WriteHeader(http.StatusBadRequest) - _, err := w.Write([]byte(fmt.Sprintf("Error logging in: %s", err))) + + _, err := fmt.Fprintf(w, "Error logging in: %s", err) if err != nil { fmt.Println(err) } + return } if !authorised { w.WriteHeader(http.StatusUnauthorized) + _, err := w.Write([]byte("Incorrect username or password")) if err != nil { fmt.Println(err) } + return } session, expiry, err := users.CreateSession(h.DbConn, username) if err != nil { w.WriteHeader(http.StatusBadRequest) - _, err := w.Write([]byte(fmt.Sprintf("Error creating session: %s", err))) + + _, err := fmt.Fprintf(w, "Error creating session: %s", err) if err != nil { fmt.Println(err) } + return } @@ -289,12 +321,15 @@ func (h Handler) LogoutHandler(w http.ResponseWriter, r *http.Request) { err = users.InvalidateSession(h.DbConn, cookie.Value) if err != nil { fmt.Println(err) - _, err = w.Write([]byte(fmt.Sprintf("Error logging out: %s", err))) + + _, err = fmt.Fprintf(w, "Error logging out: %s", err) if err != nil { fmt.Println(err) } + return } + cookie.MaxAge = -1 http.SetCookie(w, cookie) http.Redirect(w, r, "/login", http.StatusSeeOther) @@ -324,10 +359,12 @@ func StaticHandler(writer http.ResponseWriter, request *http.Request) { } else if strings.HasSuffix(resource, ".js") { writer.Header().Set("Content-Type", "text/javascript") } + home, err := fs.ReadFile(resource) if err != nil { fmt.Println(err) } + if _, err = io.Writer.Write(writer, home); err != nil { fmt.Println(err) }