1package git
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"path/filepath"
  9	"strings"
 10
 11	"github.com/charmbracelet/log"
 12	"github.com/charmbracelet/soft-serve/git"
 13	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 14	gitm "github.com/gogs/git-module"
 15)
 16
 17var (
 18
 19	// ErrNotAuthed represents unauthorized access.
 20	ErrNotAuthed = errors.New("you are not authorized to do this")
 21
 22	// ErrSystemMalfunction represents a general system error returned to clients.
 23	ErrSystemMalfunction = errors.New("something went wrong")
 24
 25	// ErrInvalidRepo represents an attempt to access a non-existent repo.
 26	ErrInvalidRepo = errors.New("invalid repo")
 27
 28	// ErrInvalidRequest represents an invalid request.
 29	ErrInvalidRequest = errors.New("invalid request")
 30
 31	// ErrMaxConnections represents a maximum connection limit being reached.
 32	ErrMaxConnections = errors.New("too many connections, try again later")
 33
 34	// ErrTimeout is returned when the maximum read timeout is exceeded.
 35	ErrTimeout = errors.New("I/O timeout reached")
 36)
 37
 38// WritePktline encodes and writes a pktline to the given writer.
 39func WritePktline(w io.Writer, v ...interface{}) error {
 40	msg := fmt.Sprintln(v...)
 41	pkt := pktline.NewEncoder(w)
 42	if err := pkt.EncodeString(msg); err != nil {
 43		return fmt.Errorf("git: error writing pkt-line message: %w", err)
 44	}
 45	if err := pkt.Flush(); err != nil {
 46		return fmt.Errorf("git: error flushing pkt-line message: %w", err)
 47	}
 48
 49	return nil
 50}
 51
 52// WritePktlineErr writes an error pktline to the given writer.
 53func WritePktlineErr(w io.Writer, err error) error {
 54	return WritePktline(w, "ERR", err.Error())
 55}
 56
 57// EnsureWithin ensures the given repo is within the repos directory.
 58func EnsureWithin(reposDir string, repo string) error {
 59	repoDir := filepath.Join(reposDir, repo)
 60	absRepos, err := filepath.Abs(reposDir)
 61	if err != nil {
 62		log.Debugf("failed to get absolute path for repo: %s", err)
 63		return ErrSystemMalfunction
 64	}
 65	absRepo, err := filepath.Abs(repoDir)
 66	if err != nil {
 67		log.Debugf("failed to get absolute path for repos: %s", err)
 68		return ErrSystemMalfunction
 69	}
 70
 71	// ensure the repo is within the repos directory
 72	if !strings.HasPrefix(absRepo, absRepos) {
 73		log.Debugf("repo path is outside of repos directory: %s", absRepo)
 74		return ErrInvalidRepo
 75	}
 76
 77	return nil
 78}
 79
 80// EnsureDefaultBranch ensures the repo has a default branch.
 81// It will prefer choosing "main" or "master" if available.
 82func EnsureDefaultBranch(ctx context.Context, scmd ServiceCommand) error {
 83	r, err := git.Open(scmd.Dir)
 84	if err != nil {
 85		return err
 86	}
 87	brs, err := r.Branches()
 88	if err != nil {
 89		return err
 90	}
 91	if len(brs) == 0 {
 92		return fmt.Errorf("no branches found")
 93	}
 94	// Rename the default branch to the first branch available
 95	_, err = r.HEAD()
 96	if err == git.ErrReferenceNotExist {
 97		branch := brs[0]
 98		// Prefer "main" or "master" as the default branch
 99		for _, b := range brs {
100			if b == "main" || b == "master" {
101				branch = b
102				break
103			}
104		}
105
106		if _, err := r.SymbolicRef(git.HEAD, git.RefsHeads+branch, gitm.SymbolicRefOptions{
107			CommandOptions: gitm.CommandOptions{
108				Context: ctx,
109			},
110		}); err != nil {
111			return err
112		}
113	}
114	if err != nil && err != git.ErrReferenceNotExist {
115		return err
116	}
117	return nil
118}