1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5package users
6
import (
	"crypto/rand"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"fmt"
	"time"

	"git.sr.ht/~amolith/willow/db"
	"golang.org/x/crypto/argon2"
)
17
// Password-hashing parameters. The argon2* values are passed straight to
// argon2.IDKey by argonHash; saltLength is the number of random bytes drawn
// by generateSalt.
const (
	// saltLength is the size, in bytes, of each freshly generated salt.
	saltLength = 16

	// argon2Time is the number of passes argon2.IDKey makes over memory.
	argon2Time = 2
	// argon2Memory is the memory cost handed to argon2.IDKey, which
	// interprets it in KiB (64 << 10 KiB = 64 MiB).
	argon2Memory = 64 << 10
	// argon2Threads is the degree of parallelism used by argon2.IDKey.
	argon2Threads = 4
	// argon2KeyLen is the length, in bytes, of the derived key.
	argon2KeyLen = 64
)
25
26// argonHash accepts two strings for the user's password and a random salt,
27// hashes the password using the salt, and returns the hash as a base64-encoded
28// string.
29func argonHash(password, salt string) (string, error) {
30 decodedSalt, err := base64.StdEncoding.DecodeString(salt)
31 if err != nil {
32 return "", fmt.Errorf("failed to decode base64: %w", err)
33 }
34
35 return base64.StdEncoding.EncodeToString(argon2.IDKey([]byte(password), decodedSalt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)), nil
36}
37
38// generateSalt generates a random salt and returns it as a base64-encoded
39// string.
40func generateSalt() (string, error) {
41 salt := make([]byte, saltLength)
42
43 _, err := rand.Read(salt)
44 if err != nil {
45 return "", fmt.Errorf("failed to generate random bytes: %w", err)
46 }
47
48 return base64.StdEncoding.EncodeToString(salt), nil
49}
50
51// Register accepts a username and password, hashes the password and stores the
52// hash and salt in the database.
53func Register(dbConn *sql.DB, username, password string) error {
54 salt, err := generateSalt()
55 if err != nil {
56 return err
57 }
58
59 hash, err := argonHash(password, salt)
60 if err != nil {
61 return err
62 }
63
64 err = db.CreateUser(dbConn, username, hash, salt)
65 if err != nil {
66 return fmt.Errorf("failed to create user: %w", err)
67 }
68
69 return nil
70}
71
72// Delete removes a user from the database.
73func Delete(dbConn *sql.DB, username string) error {
74 err := db.DeleteUser(dbConn, username)
75 if err != nil {
76 return fmt.Errorf("failed to delete user: %w", err)
77 }
78
79 return nil
80}
81
82// UserAuthorised accepts a username string, a token string, and returns true if the
83// user is authorised, false if not, and an error if one is encountered.
84func UserAuthorised(dbConn *sql.DB, username, token string) (bool, error) {
85 dbHash, dbSalt, err := db.GetUser(dbConn, username)
86 if err != nil {
87 return false, fmt.Errorf("failed to get user: %w", err)
88 }
89
90 providedHash, err := argonHash(token, dbSalt)
91 if err != nil {
92 return false, err
93 }
94
95 return dbHash == providedHash, nil
96}
97
98// SessionAuthorised accepts a session string and returns true if the session is
99// valid and false if not.
100func SessionAuthorised(dbConn *sql.DB, session string) (bool, error) {
101 dbResult, expiry, err := db.GetSession(dbConn, session)
102 if err != nil {
103 return false, fmt.Errorf("failed to get session: %w", err)
104 }
105
106 if dbResult == "" || expiry.Before(time.Now()) {
107 return false, nil
108 }
109
110 return true, nil
111}
112
113// InvalidateSession invalidates a session by setting the expiration date to now.
114func InvalidateSession(dbConn *sql.DB, session string) error {
115 err := db.InvalidateSession(dbConn, session, time.Now())
116 if err != nil {
117 return fmt.Errorf("failed to invalidate session: %w", err)
118 }
119
120 return nil
121}
122
123// CreateSession accepts a username, generates a token, stores it in the
124// database, and returns it.
125func CreateSession(dbConn *sql.DB, username string) (string, time.Time, error) {
126 token, err := generateSalt()
127 if err != nil {
128 return "", time.Time{}, err
129 }
130
131 expiry := time.Now().Add(7 * 24 * time.Hour)
132
133 err = db.CreateSession(dbConn, username, token, expiry)
134 if err != nil {
135 return "", time.Time{}, fmt.Errorf("failed to create session: %w", err)
136 }
137
138 return token, expiry, nil
139}
140
141// GetUsers returns a list of all users in the database as a slice of strings.
142func GetUsers(dbConn *sql.DB) ([]string, error) {
143 users, err := db.GetUsers(dbConn)
144 if err != nil {
145 return nil, fmt.Errorf("failed to get users: %w", err)
146 }
147
148 return users, nil
149}