1package database
2
3import (
4 "context"
5 "time"
6
7 "github.com/charmbracelet/soft-serve/pkg/db"
8 "github.com/charmbracelet/soft-serve/pkg/db/models"
9 "github.com/charmbracelet/soft-serve/pkg/store"
10)
11
// accessTokenStore is a field-less type whose methods implement
// store.AccessTokenStore. It holds no state of its own: every method receives
// the db.Handler it should run its query against as an argument.
type accessTokenStore struct{}

// Compile-time check that *accessTokenStore satisfies store.AccessTokenStore.
var _ store.AccessTokenStore = (*accessTokenStore)(nil)
15
16// CreateAccessToken implements store.AccessTokenStore.
17func (s *accessTokenStore) CreateAccessToken(ctx context.Context, h db.Handler, name string, userID int64, token string, expiresAt time.Time) (models.AccessToken, error) {
18 queryWithoutExpires := `INSERT INTO access_tokens (name, user_id, token, created_at, updated_at)
19 VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) RETURNING id`
20 queryWithExpires := `INSERT INTO access_tokens (name, user_id, token, expires_at, created_at, updated_at)
21 VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) RETURNING id`
22
23 query := queryWithoutExpires
24 values := []interface{}{name, userID, token}
25 if !expiresAt.IsZero() {
26 query = queryWithExpires
27 values = append(values, expiresAt.UTC())
28 }
29
30 var id int64
31 if err := h.GetContext(ctx, &id, h.Rebind(query), values...); err != nil {
32 return models.AccessToken{}, err
33 }
34
35 return s.GetAccessToken(ctx, h, id)
36}
37
38// DeleteAccessToken implements store.AccessTokenStore.
39func (*accessTokenStore) DeleteAccessToken(ctx context.Context, h db.Handler, id int64) error {
40 query := h.Rebind(`DELETE FROM access_tokens WHERE id = ?`)
41 _, err := h.ExecContext(ctx, query, id)
42 return err
43}
44
45// DeleteAccessTokenForUser implements store.AccessTokenStore.
46func (*accessTokenStore) DeleteAccessTokenForUser(ctx context.Context, h db.Handler, userID int64, id int64) error {
47 query := h.Rebind(`DELETE FROM access_tokens WHERE user_id = ? AND id = ?`)
48 _, err := h.ExecContext(ctx, query, userID, id)
49 return err
50}
51
52// GetAccessToken implements store.AccessTokenStore.
53func (*accessTokenStore) GetAccessToken(ctx context.Context, h db.Handler, id int64) (models.AccessToken, error) {
54 query := h.Rebind(`SELECT * FROM access_tokens WHERE id = ?`)
55 var m models.AccessToken
56 err := h.GetContext(ctx, &m, query, id)
57 return m, err
58}
59
60// GetAccessTokensByUserID implements store.AccessTokenStore.
61func (*accessTokenStore) GetAccessTokensByUserID(ctx context.Context, h db.Handler, userID int64) ([]models.AccessToken, error) {
62 query := h.Rebind(`SELECT * FROM access_tokens WHERE user_id = ?`)
63 var m []models.AccessToken
64 err := h.SelectContext(ctx, &m, query, userID)
65 return m, err
66}
67
68// GetAccessTokenByToken implements store.AccessTokenStore.
69func (*accessTokenStore) GetAccessTokenByToken(ctx context.Context, h db.Handler, token string) (models.AccessToken, error) {
70 query := h.Rebind(`SELECT * FROM access_tokens WHERE token = ?`)
71 var m models.AccessToken
72 err := h.GetContext(ctx, &m, query, token)
73 return m, err
74}