1package backend
2
3import (
4 "context"
5 "errors"
6 "time"
7
8 "github.com/charmbracelet/soft-serve/pkg/db"
9 "github.com/charmbracelet/soft-serve/pkg/proto"
10)
11
12// CreateAccessToken creates an access token for user.
13func (b *Backend) CreateAccessToken(ctx context.Context, user proto.User, name string, expiresAt time.Time) (string, error) {
14 token := GenerateToken()
15 tokenHash := HashToken(token)
16
17 if err := b.db.TransactionContext(ctx, func(tx *db.Tx) error {
18 _, err := b.store.CreateAccessToken(ctx, tx, name, user.ID(), tokenHash, expiresAt)
19 if err != nil {
20 return db.WrapError(err)
21 }
22
23 return nil
24 }); err != nil {
25 return "", err
26 }
27
28 return token, nil
29}
30
31// DeleteAccessToken deletes an access token for a user.
32func (b *Backend) DeleteAccessToken(ctx context.Context, user proto.User, id int64) error {
33 err := b.db.TransactionContext(ctx, func(tx *db.Tx) error {
34 _, err := b.store.GetAccessToken(ctx, tx, id)
35 if err != nil {
36 return db.WrapError(err)
37 }
38
39 if err := b.store.DeleteAccessTokenForUser(ctx, tx, user.ID(), id); err != nil {
40 return db.WrapError(err)
41 }
42 return nil
43 })
44 if err != nil {
45 if errors.Is(err, db.ErrRecordNotFound) {
46 return proto.ErrTokenNotFound
47 }
48 return err
49 }
50
51 return nil
52}
53
54// ListAccessTokens lists access tokens for a user.
55func (b *Backend) ListAccessTokens(ctx context.Context, user proto.User) ([]proto.AccessToken, error) {
56 accessTokens, err := b.store.GetAccessTokensByUserID(ctx, b.db, user.ID())
57 if err != nil {
58 return nil, db.WrapError(err)
59 }
60
61 tokens := make([]proto.AccessToken, 0, len(accessTokens))
62 for _, t := range accessTokens {
63 token := proto.AccessToken{
64 ID: t.ID,
65 Name: t.Name,
66 TokenHash: t.Token,
67 UserID: t.UserID,
68 CreatedAt: t.CreatedAt,
69 }
70 if t.ExpiresAt.Valid {
71 token.ExpiresAt = t.ExpiresAt.Time
72 }
73
74 tokens = append(tokens, token)
75 }
76
77 return tokens, nil
78}