fix: wrap db errors

Ayman Bagabas created

Change summary

go.mod                       |  2 +-
go.sum                       |  4 ++--
server/backend/lfs.go        |  6 +++---
server/git/lfs.go            | 37 +++++++++++++++++++++++--------------
server/git/service.go        |  1 +
server/lfs/common.go         | 20 ++++++++++++++++++--
server/lfs/http_client.go    | 16 ++++++++--------
server/lfs/pointer.go        |  6 ++++--
server/lfs/transfer.go       |  3 +++
server/ssh/git.go            |  3 ++-
server/store/database/lfs.go |  7 ++++---
server/store/lfs.go          |  2 +-
12 files changed, 70 insertions(+), 37 deletions(-)

Detailed changes

go.mod 🔗

@@ -19,7 +19,7 @@ require (
 
 require (
 	github.com/caarlos0/env/v8 v8.0.0
-	github.com/charmbracelet/git-lfs-transfer v0.1.1-0.20230720173103-0db2d71ab8d2
+	github.com/charmbracelet/git-lfs-transfer v0.1.1-0.20230721203144-64d90e7a36a1
 	github.com/charmbracelet/keygen v0.4.3
 	github.com/charmbracelet/log v0.2.3-0.20230713155356-557335e40e35
 	github.com/charmbracelet/ssh v0.0.0-20230720143903-5bdd92839155

go.sum 🔗

@@ -21,8 +21,8 @@ github.com/charmbracelet/bubbles v0.16.1 h1:6uzpAAaT9ZqKssntbvZMlksWHruQLNxg49H5
 github.com/charmbracelet/bubbles v0.16.1/go.mod h1:2QCp9LFlEsBQMvIYERr7Ww2H2bA7xen1idUDIzm/+Xc=
 github.com/charmbracelet/bubbletea v0.24.2 h1:uaQIKx9Ai6Gdh5zpTbGiWpytMU+CfsPp06RaW2cx/SY=
 github.com/charmbracelet/bubbletea v0.24.2/go.mod h1:XdrNrV4J8GiyshTtx3DNuYkR1FDaJmO3l2nejekbsgg=
-github.com/charmbracelet/git-lfs-transfer v0.1.1-0.20230720173103-0db2d71ab8d2 h1:a3iaZ53uBHjCN2mnrKARVTXiOmEdcDIqUzBRbCdB3Bk=
-github.com/charmbracelet/git-lfs-transfer v0.1.1-0.20230720173103-0db2d71ab8d2/go.mod h1:eXJuVicxnjRgRMokmutZdistxoMRjBjjfqvrYq7bCIU=
+github.com/charmbracelet/git-lfs-transfer v0.1.1-0.20230721203144-64d90e7a36a1 h1:/QzZzTDdlDYGZeC2O2y/Qw+AiHqh3vCsO4yrKDWXtqs=
+github.com/charmbracelet/git-lfs-transfer v0.1.1-0.20230721203144-64d90e7a36a1/go.mod h1:eXJuVicxnjRgRMokmutZdistxoMRjBjjfqvrYq7bCIU=
 github.com/charmbracelet/glamour v0.6.0 h1:wi8fse3Y7nfcabbbDuwolqTqMQPMnVPeZhDM273bISc=
 github.com/charmbracelet/glamour v0.6.0/go.mod h1:taqWV4swIMMbWALc0m7AfE9JkPSU8om2538k9ITBxOc=
 github.com/charmbracelet/keygen v0.4.3 h1:ywOZRwkDlpmkawl0BgLTxaYWDSqp6Y4nfVVmgyyO1Mg=

server/backend/lfs.go 🔗

@@ -40,7 +40,7 @@ func StoreRepoMissingLFSObjects(ctx context.Context, repo proto.Repository, dbx
 			defer content.Close() // nolint: errcheck
 			return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
 				if err := store.CreateLFSObject(ctx, tx, repo.ID(), p.Oid, p.Size); err != nil {
-					return err
+					return db.WrapError(err)
 				}
 
 				return strg.Put(path.Join("objects", p.RelativePath()), content)
@@ -52,7 +52,7 @@ func StoreRepoMissingLFSObjects(ctx context.Context, repo proto.Repository, dbx
 	for pointer := range pointerChan {
 		obj, err := store.GetLFSObjectByOid(ctx, dbx, repo.ID(), pointer.Oid)
 		if err != nil && !errors.Is(err, db.ErrRecordNotFound) {
-			return err
+			return db.WrapError(err)
 		}
 
 		exist, err := strg.Exists(path.Join("objects", pointer.RelativePath()))
@@ -62,7 +62,7 @@ func StoreRepoMissingLFSObjects(ctx context.Context, repo proto.Repository, dbx
 
 		if exist && obj.ID == 0 {
 			if err := store.CreateLFSObject(ctx, dbx, repo.ID(), pointer.Oid, pointer.Size); err != nil {
-				return err
+				return db.WrapError(err)
 			}
 		} else {
 			batch = append(batch, pointer.Pointer)

server/git/lfs.go 🔗

@@ -112,7 +112,7 @@ func (t *lfsTransfer) Batch(_ string, pointers []transfer.Pointer) ([]transfer.B
 	for _, p := range pointers {
 		obj, err := t.store.GetLFSObjectByOid(t.ctx, t.dbx, repo.ID(), p.Oid)
 		if err != nil && !errors.Is(err, db.ErrRecordNotFound) {
-			return items, err
+			return items, db.WrapError(err)
 		}
 
 		exist, err := t.storage.Exists(path.Join("objects", p.RelativePath()))
@@ -122,7 +122,7 @@ func (t *lfsTransfer) Batch(_ string, pointers []transfer.Pointer) ([]transfer.B
 
 		if exist && obj.ID == 0 {
 			if err := t.store.CreateLFSObject(t.ctx, t.dbx, repo.ID(), p.Oid, p.Size); err != nil {
-				return items, err
+				return items, db.WrapError(err)
 			}
 		}
 
@@ -256,22 +256,27 @@ func (t *lfsTransfer) LockBackend() transfer.LockBackend {
 }
 
 // Create implements transfer.LockBackend.
-func (l *lfsLockBackend) Create(path string) (transfer.Lock, error) {
+func (l *lfsLockBackend) Create(path string, refname string) (transfer.Lock, error) {
 	var lock LFSLock
 	if err := l.dbx.TransactionContext(l.ctx, func(tx *db.Tx) error {
-		if err := l.store.CreateLFSLockForUser(l.ctx, tx, l.repo.ID(), l.user.ID(), path); err != nil {
-			return err
+		if err := l.store.CreateLFSLockForUser(l.ctx, tx, l.repo.ID(), l.user.ID(), path, refname); err != nil {
+			return db.WrapError(err)
 		}
 
 		var err error
 		lock.lock, err = l.store.GetLFSLockForUserPath(l.ctx, tx, l.repo.ID(), l.user.ID(), path)
 		if err != nil {
-			return err
+			return db.WrapError(err)
 		}
 
 		lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID)
-		return err
+		return db.WrapError(err)
 	}); err != nil {
+		// Return conflict (409) if the lock already exists.
+		if errors.Is(err, db.ErrDuplicateKey) {
+			return nil, transfer.ErrConflict
+		}
+		l.logger.Errorf("error creating lock: %v", err)
 		return nil, err
 	}
 
@@ -292,12 +297,13 @@ func (l *lfsLockBackend) FromID(id string) (transfer.Lock, error) {
 		var err error
 		lock.lock, err = l.store.GetLFSLockForUserByID(l.ctx, tx, user.ID(), id)
 		if err != nil {
-			return err
+			return db.WrapError(err)
 		}
 
 		lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID)
-		return err
+		return db.WrapError(err)
 	}); err != nil {
+		l.logger.Errorf("error getting lock: %v", err)
 		return nil, err
 	}
 
@@ -314,12 +320,13 @@ func (l *lfsLockBackend) FromPath(path string) (transfer.Lock, error) {
 		var err error
 		lock.lock, err = l.store.GetLFSLockForUserPath(l.ctx, tx, l.repo.ID(), l.user.ID(), path)
 		if err != nil {
-			return err
+			return db.WrapError(err)
 		}
 
 		lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID)
-		return err
+		return db.WrapError(err)
 	}); err != nil {
+		l.logger.Errorf("error getting lock: %v", err)
 		return nil, err
 	}
 
@@ -335,7 +342,7 @@ func (l *lfsLockBackend) Range(fn func(transfer.Lock) error) error {
 	if err := l.dbx.TransactionContext(l.ctx, func(tx *db.Tx) error {
 		mlocks, err := l.store.GetLFSLocks(l.ctx, tx, l.repo.ID())
 		if err != nil {
-			return err
+			return db.WrapError(err)
 		}
 
 		users := make(map[int64]models.User, 0)
@@ -344,7 +351,7 @@ func (l *lfsLockBackend) Range(fn func(transfer.Lock) error) error {
 			if !ok {
 				owner, err = l.store.GetUserByID(l.ctx, tx, mlock.UserID)
 				if err != nil {
-					return err
+					return db.WrapError(err)
 				}
 
 				users[mlock.UserID] = owner
@@ -370,7 +377,9 @@ func (l *lfsLockBackend) Range(fn func(transfer.Lock) error) error {
 // Unlock implements transfer.LockBackend.
 func (l *lfsLockBackend) Unlock(lock transfer.Lock) error {
 	return l.dbx.TransactionContext(l.ctx, func(tx *db.Tx) error {
-		return l.store.DeleteLFSLockForUserByID(l.ctx, tx, l.user.ID(), lock.ID())
+		return db.WrapError(
+			l.store.DeleteLFSLockForUserByID(l.ctx, tx, l.user.ID(), lock.ID()),
+		)
 	})
 }
 

server/git/service.go 🔗

@@ -25,6 +25,7 @@ const (
 	ReceivePackService Service = "git-receive-pack"
 	// LFSTransferService is the LFS transfer service.
 	LFSTransferService Service = "git-lfs-transfer"
+	// TODO: add support for git-lfs-authenticate
 )
 
 // String returns the string representation of the service.

server/lfs/common.go 🔗

@@ -3,8 +3,23 @@ package lfs
 import "time"
 
 const (
-	// MediaType contains the media type for LFS server requests
+	// MediaType contains the media type for LFS server requests.
 	MediaType = "application/vnd.git-lfs+json"
+
+	// OperationDownload is the operation name for a download request.
+	OperationDownload = "download"
+
+	// OperationUpload is the operation name for an upload request.
+	OperationUpload = "upload"
+
+	// ActionDownload is the action name for a download request.
+	ActionDownload = OperationDownload
+
+	// ActionUpload is the action name for an upload request.
+	ActionUpload = OperationUpload
+
+	// ActionVerify is the action name for a verify request.
+	ActionVerify = "verify"
 )
 
 // Pointer contains LFS pointer data
@@ -21,7 +36,7 @@ type PointerBlob struct {
 
 // ErrorResponse describes the error to the client.
 type ErrorResponse struct {
-	Message          string
+	Message          string `json:"message,omitempty"`
 	DocumentationURL string `json:"documentation_url,omitempty"`
 	RequestID        string `json:"request_id,omitempty"`
 }
@@ -32,6 +47,7 @@ type ErrorResponse struct {
 type BatchResponse struct {
 	Transfer string            `json:"transfer,omitempty"`
 	Objects  []*ObjectResponse `json:"objects"`
+	HashAlgo string            `json:"hash_algo,omitempty"`
 }
 
 // ObjectResponse is object metadata as seen by clients of the LFS server.

server/lfs/http_client.go 🔗

@@ -26,7 +26,7 @@ func newHTTPClient(endpoint Endpoint) *httpClient {
 		client:   http.DefaultClient,
 		endpoint: endpoint,
 		transfers: map[string]TransferAdapter{
-			"basic": &BasicTransferAdapter{http.DefaultClient},
+			TransferBasic: &BasicTransferAdapter{http.DefaultClient},
 		},
 	}
 }
@@ -57,7 +57,7 @@ func (c *httpClient) batch(ctx context.Context, operation string, objects []Poin
 	url := fmt.Sprintf("%s/objects/batch", c.endpoint.String())
 
 	// TODO: support ref
-	request := &BatchRequest{operation, c.transferNames(), nil, objects, hashAlgo}
+	request := &BatchRequest{operation, c.transferNames(), nil, objects, HashAlgorithmSHA256}
 
 	payload := new(bytes.Buffer)
 	err := json.NewEncoder(payload).Encode(request)
@@ -100,7 +100,7 @@ func (c *httpClient) batch(ctx context.Context, operation string, objects []Poin
 	}
 
 	if len(response.Transfer) == 0 {
-		response.Transfer = "basic"
+		response.Transfer = TransferBasic
 	}
 
 	return &response, nil
@@ -112,9 +112,9 @@ func (c *httpClient) performOperation(ctx context.Context, objects []Pointer, dc
 		return nil
 	}
 
-	operation := "download"
+	operation := OperationDownload
 	if uc != nil {
-		operation = "upload"
+		operation = OperationUpload
 	}
 
 	result, err := c.batch(ctx, operation, objects)
@@ -149,7 +149,7 @@ func (c *httpClient) performOperation(ctx context.Context, objects []Pointer, dc
 				continue
 			}
 
-			link, ok := object.Actions["upload"]
+			link, ok := object.Actions[ActionUpload]
 			if !ok {
 				logger.Debugf("%+v", object)
 				return errors.New("Missing action 'upload'")
@@ -168,14 +168,14 @@ func (c *httpClient) performOperation(ctx context.Context, objects []Pointer, dc
 				return err
 			}
 
-			link, ok = object.Actions["verify"]
+			link, ok = object.Actions[ActionVerify]
 			if ok {
 				if err := transferAdapter.Verify(ctx, object.Pointer, link); err != nil {
 					return err
 				}
 			}
 		} else {
-			link, ok := object.Actions["download"]
+			link, ok := object.Actions[ActionDownload]
 			if !ok {
 				logger.Debugf("%+v", object)
 				return errors.New("Missing action 'download'")

server/lfs/pointer.go 🔗

@@ -14,14 +14,16 @@ import (
 
 const (
 	blobSizeCutoff = 1024
-	hashAlgo       = "sha256"
+
+	// HashAlgorithmSHA256 is the hash algorithm used for Git LFS.
+	HashAlgorithmSHA256 = "sha256"
 
 	// MetaFileIdentifier is the string appearing at the first line of LFS pointer files.
 	// https://github.com/git-lfs/git-lfs/blob/master/docs/spec.md
 	MetaFileIdentifier = "version https://git-lfs.github.com/spec/v1"
 
 	// MetaFileOidPrefix appears in LFS pointer files on a line before the sha256 hash.
-	MetaFileOidPrefix = "oid " + hashAlgo + ":"
+	MetaFileOidPrefix = "oid " + HashAlgorithmSHA256 + ":"
 )
 
 var (

server/lfs/transfer.go 🔗

@@ -5,6 +5,9 @@ import (
 	"io"
 )
 
+// TransferBasic is the name of the Git LFS basic transfer protocol.
+const TransferBasic = "basic"
+
 // TransferAdapter represents an adapter for downloading/uploading LFS objects
 type TransferAdapter interface {
 	Name() string

server/ssh/git.go 🔗

@@ -10,6 +10,7 @@ import (
 	"github.com/charmbracelet/soft-serve/server/backend"
 	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/charmbracelet/soft-serve/server/git"
+	"github.com/charmbracelet/soft-serve/server/lfs"
 	"github.com/charmbracelet/soft-serve/server/proto"
 	"github.com/charmbracelet/soft-serve/server/sshutils"
 	"github.com/charmbracelet/soft-serve/server/utils"
@@ -133,7 +134,7 @@ func handleGit(s ssh.Session) {
 		}
 
 		if len(cmdLine) != 3 ||
-			(cmdLine[2] != "download" && cmdLine[2] != "upload") {
+			(cmdLine[2] != lfs.OperationDownload && cmdLine[2] != lfs.OperationUpload) {
 			sshFatal(s, git.ErrInvalidRequest)
 			return
 		}

server/store/database/lfs.go 🔗

@@ -21,17 +21,18 @@ func sanitizePath(path string) string {
 }
 
 // CreateLFSLockForUser implements store.LFSStore.
-func (*lfsStore) CreateLFSLockForUser(ctx context.Context, tx db.Handler, repoID int64, userID int64, path string) error {
+func (*lfsStore) CreateLFSLockForUser(ctx context.Context, tx db.Handler, repoID int64, userID int64, path string, refname string) error {
 	path = sanitizePath(path)
-	query := tx.Rebind(`INSERT INTO lfs_locks (repo_id, user_id, path, updated_at)
+	query := tx.Rebind(`INSERT INTO lfs_locks (repo_id, user_id, path, refname, updated_at)
 		VALUES (
 			?,
 			?,
 			?,
+			?,
 			CURRENT_TIMESTAMP
 		);
 	`)
-	_, err := tx.ExecContext(ctx, query, repoID, userID, path)
+	_, err := tx.ExecContext(ctx, query, repoID, userID, path, refname)
 	return db.WrapError(err)
 }
 

server/store/lfs.go 🔗

@@ -15,7 +15,7 @@ type LFSStore interface {
 	GetLFSObjectsByName(ctx context.Context, h db.Handler, name string) ([]models.LFSObject, error)
 	DeleteLFSObjectByOid(ctx context.Context, h db.Handler, repoID int64, oid string) error
 
-	CreateLFSLockForUser(ctx context.Context, h db.Handler, repoID int64, userID int64, path string) error
+	CreateLFSLockForUser(ctx context.Context, h db.Handler, repoID int64, userID int64, path string, refname string) error
 	GetLFSLocks(ctx context.Context, h db.Handler, repoID int64) ([]models.LFSLock, error)
 	GetLFSLocksForUser(ctx context.Context, h db.Handler, repoID int64, userID int64) ([]models.LFSLock, error)
 	GetLFSLocksForPath(ctx context.Context, h db.Handler, repoID int64, path string) ([]models.LFSLock, error)