git.go

  1package git
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"path/filepath"
  9	"strings"
 10
 11	"github.com/charmbracelet/log"
 12	"github.com/charmbracelet/soft-serve/git"
 13	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 14)
 15
 16var (
 17
 18	// ErrUnauthorized represents unauthorized access.
 19	ErrUnauthorized = errors.New("you are not authorized to do this")
 20
 21	// ErrSystemMalfunction represents a general system error returned to clients.
 22	ErrSystemMalfunction = errors.New("something went wrong")
 23
 24	// ErrNotExist represents an attempt to access a non-existent repo.
 25	ErrNotExist = errors.New("repo does not exist")
 26
 27	// ErrInvalidRequest represents an invalid request.
 28	ErrInvalidRequest = errors.New("invalid request")
 29
 30	// ErrMaxConnections represents a maximum connection limit being reached.
 31	ErrMaxConnections = errors.New("too many connections, try again later")
 32
 33	// ErrTimeout is returned when the maximum read timeout is exceeded.
 34	ErrTimeout = errors.New("I/O timeout reached")
 35)
 36
 37// WritePktline encodes and writes a pktline to the given writer.
 38func WritePktline(w io.Writer, v ...interface{}) error {
 39	msg := fmt.Sprintln(v...)
 40	pkt := pktline.NewEncoder(w)
 41	if err := pkt.EncodeString(msg); err != nil {
 42		return fmt.Errorf("git: error writing pkt-line message: %w", err)
 43	}
 44	if err := pkt.Flush(); err != nil {
 45		return fmt.Errorf("git: error flushing pkt-line message: %w", err)
 46	}
 47
 48	return nil
 49}
 50
 51// WritePktlineErr writes an error pktline to the given writer.
 52func WritePktlineErr(w io.Writer, err error) error {
 53	return WritePktline(w, "ERR ", err.Error())
 54}
 55
 56// EnsureWithin ensures the given repo is within the repos directory.
 57func EnsureWithin(reposDir string, repo string) error {
 58	repoDir := filepath.Join(reposDir, repo)
 59	absRepos, err := filepath.Abs(reposDir)
 60	if err != nil {
 61		log.Debugf("failed to get absolute path for repo: %s", err)
 62		return ErrSystemMalfunction
 63	}
 64	absRepo, err := filepath.Abs(repoDir)
 65	if err != nil {
 66		log.Debugf("failed to get absolute path for repos: %s", err)
 67		return ErrSystemMalfunction
 68	}
 69
 70	// ensure the repo is within the repos directory
 71	if !strings.HasPrefix(absRepo, absRepos) {
 72		log.Debugf("repo path is outside of repos directory: %s", absRepo)
 73		return ErrNotExist
 74	}
 75
 76	return nil
 77}
 78
 79// EnsureDefaultBranch ensures the repo has a default branch.
 80// It will prefer choosing "main" or "master" if available.
 81func EnsureDefaultBranch(ctx context.Context, scmd ServiceCommand) error {
 82	r, err := git.Open(scmd.Dir)
 83	if err != nil {
 84		return err
 85	}
 86	brs, err := r.Branches()
 87	if err != nil {
 88		return err
 89	}
 90	if len(brs) == 0 {
 91		return fmt.Errorf("no branches found")
 92	}
 93	// Rename the default branch to the first branch available
 94	_, err = r.HEAD()
 95	if err == git.ErrReferenceNotExist {
 96		branch := brs[0]
 97		// Prefer "main" or "master" as the default branch
 98		for _, b := range brs {
 99			if b == "main" || b == "master" {
100				branch = b
101				break
102			}
103		}
104
105		cmd := git.NewCommand("branch", "-M", branch).WithContext(ctx)
106		if err := cmd.RunInDirWithOptions(scmd.Dir, git.RunInDirOptions{
107			Stdin:  scmd.Stdin,
108			Stdout: scmd.Stdout,
109			Stderr: scmd.Stderr,
110		}); err != nil {
111			return err
112		}
113	}
114	if err != nil && err != git.ErrReferenceNotExist {
115		return err
116	}
117	return nil
118}