access_token.go

 1// Package backend provides backend functionality for soft-serve.
 2package backend
 3
 4import (
 5	"context"
 6	"errors"
 7	"time"
 8
 9	"github.com/charmbracelet/soft-serve/pkg/db"
10	"github.com/charmbracelet/soft-serve/pkg/proto"
11)
12
13// CreateAccessToken creates an access token for user.
14func (b *Backend) CreateAccessToken(ctx context.Context, user proto.User, name string, expiresAt time.Time) (string, error) {
15	token := GenerateToken()
16	tokenHash := HashToken(token)
17
18	if err := b.db.TransactionContext(ctx, func(tx *db.Tx) error {
19		_, err := b.store.CreateAccessToken(ctx, tx, name, user.ID(), tokenHash, expiresAt)
20		if err != nil {
21			return db.WrapError(err)
22		}
23
24		return nil
25	}); err != nil {
26		return "", err //nolint:wrapcheck
27	}
28
29	return token, nil
30}
31
32// DeleteAccessToken deletes an access token for a user.
33func (b *Backend) DeleteAccessToken(ctx context.Context, user proto.User, id int64) error {
34	err := b.db.TransactionContext(ctx, func(tx *db.Tx) error {
35		_, err := b.store.GetAccessToken(ctx, tx, id)
36		if err != nil {
37			return db.WrapError(err)
38		}
39
40		if err := b.store.DeleteAccessTokenForUser(ctx, tx, user.ID(), id); err != nil {
41			return db.WrapError(err)
42		}
43		return nil
44	})
45	if err != nil {
46		if errors.Is(err, db.ErrRecordNotFound) {
47			return proto.ErrTokenNotFound
48		}
49		return err //nolint:wrapcheck
50	}
51
52	return nil
53}
54
55// ListAccessTokens lists access tokens for a user.
56func (b *Backend) ListAccessTokens(ctx context.Context, user proto.User) ([]proto.AccessToken, error) {
57	accessTokens, err := b.store.GetAccessTokensByUserID(ctx, b.db, user.ID())
58	if err != nil {
59		return nil, db.WrapError(err) //nolint:wrapcheck
60	}
61
62	tokens := make([]proto.AccessToken, 0, len(accessTokens))
63	for _, t := range accessTokens {
64		token := proto.AccessToken{
65			ID:        t.ID,
66			Name:      t.Name,
67			TokenHash: t.Token,
68			UserID:    t.UserID,
69			CreatedAt: t.CreatedAt,
70		}
71		if t.ExpiresAt.Valid {
72			token.ExpiresAt = t.ExpiresAt.Time
73		}
74
75		tokens = append(tokens, token)
76	}
77
78	return tokens, nil
79}