1package database
 2
 3import (
 4	"context"
 5	"time"
 6
 7	"github.com/charmbracelet/soft-serve/server/db"
 8	"github.com/charmbracelet/soft-serve/server/db/models"
 9	"github.com/charmbracelet/soft-serve/server/store"
10)
11
12type accessTokenStore struct{}
13
14var _ store.AccessTokenStore = (*accessTokenStore)(nil)
15
16// CreateAccessToken implements store.AccessTokenStore.
17func (s *accessTokenStore) CreateAccessToken(ctx context.Context, h db.Handler, name string, userID int64, token string, expiresAt time.Time) (models.AccessToken, error) {
18	queryWithoutExpires := `INSERT INTO access_tokens (name, user_id, token, created_at, updated_at)
19	VALUES (?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)`
20	queryWithExpires := `INSERT INTO access_tokens (name, user_id, token, expires_at, created_at, updated_at)
21	VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)`
22
23	query := queryWithoutExpires
24	values := []interface{}{name, userID, token}
25	if !expiresAt.IsZero() {
26		query = queryWithExpires
27		values = append(values, expiresAt)
28	}
29
30	result, err := h.ExecContext(ctx, query, values...)
31	if err != nil {
32		return models.AccessToken{}, err
33	}
34
35	id, err := result.LastInsertId()
36	if err != nil {
37		return models.AccessToken{}, err
38	}
39
40	return s.GetAccessToken(ctx, h, id)
41}
42
43// DeleteAccessToken implements store.AccessTokenStore.
44func (*accessTokenStore) DeleteAccessToken(ctx context.Context, h db.Handler, id int64) error {
45	query := h.Rebind(`DELETE FROM access_tokens WHERE id = ?`)
46	_, err := h.ExecContext(ctx, query, id)
47	return err
48}
49
50// DeleteAccessTokenForUser implements store.AccessTokenStore.
51func (*accessTokenStore) DeleteAccessTokenForUser(ctx context.Context, h db.Handler, userID int64, id int64) error {
52	query := h.Rebind(`DELETE FROM access_tokens WHERE user_id = ? AND id = ?`)
53	_, err := h.ExecContext(ctx, query, userID, id)
54	return err
55}
56
57// GetAccessToken implements store.AccessTokenStore.
58func (*accessTokenStore) GetAccessToken(ctx context.Context, h db.Handler, id int64) (models.AccessToken, error) {
59	query := h.Rebind(`SELECT * FROM access_tokens WHERE id = ?`)
60	var m models.AccessToken
61	err := h.GetContext(ctx, &m, query, id)
62	return m, err
63}
64
65// GetAccessTokensByUserID implements store.AccessTokenStore.
66func (*accessTokenStore) GetAccessTokensByUserID(ctx context.Context, h db.Handler, userID int64) ([]models.AccessToken, error) {
67	query := h.Rebind(`SELECT * FROM access_tokens WHERE user_id = ?`)
68	var m []models.AccessToken
69	err := h.SelectContext(ctx, &m, query, userID)
70	return m, err
71}
72
73// GetAccessTokenByToken implements store.AccessTokenStore.
74func (*accessTokenStore) GetAccessTokenByToken(ctx context.Context, h db.Handler, token string) (models.AccessToken, error) {
75	query := h.Rebind(`SELECT * FROM access_tokens WHERE token = ?`)
76	var m models.AccessToken
77	err := h.GetContext(ctx, &m, query, token)
78	return m, err
79}