1package git
  2
  3import (
  4	"errors"
  5	"fmt"
  6	"io"
  7	"os"
  8	"path/filepath"
  9	"strings"
 10
 11	"github.com/charmbracelet/log"
 12	"github.com/charmbracelet/soft-serve/git"
 13	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 14)
 15
 16var (
 17
 18	// ErrNotAuthed represents unauthorized access.
 19	ErrNotAuthed = errors.New("you are not authorized to do this")
 20
 21	// ErrSystemMalfunction represents a general system error returned to clients.
 22	ErrSystemMalfunction = errors.New("something went wrong")
 23
 24	// ErrInvalidRepo represents an attempt to access a non-existent repo.
 25	ErrInvalidRepo = errors.New("invalid repo")
 26
 27	// ErrInvalidRequest represents an invalid request.
 28	ErrInvalidRequest = errors.New("invalid request")
 29
 30	// ErrMaxConnections represents a maximum connection limit being reached.
 31	ErrMaxConnections = errors.New("too many connections, try again later")
 32
 33	// ErrTimeout is returned when the maximum read timeout is exceeded.
 34	ErrTimeout = errors.New("I/O timeout reached")
 35)
 36
 37// Git protocol commands.
 38const (
 39	ReceivePackBin   = "git-receive-pack"
 40	UploadPackBin    = "git-upload-pack"
 41	UploadArchiveBin = "git-upload-archive"
 42)
 43
 44// UploadPack runs the git upload-pack protocol against the provided repo.
 45func UploadPack(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
 46	exists, err := fileExists(repoDir)
 47	if !exists {
 48		return ErrInvalidRepo
 49	}
 50	if err != nil {
 51		return err
 52	}
 53	return RunGit(in, out, er, "", UploadPackBin[4:], repoDir)
 54}
 55
 56// UploadArchive runs the git upload-archive protocol against the provided repo.
 57func UploadArchive(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
 58	exists, err := fileExists(repoDir)
 59	if !exists {
 60		return ErrInvalidRepo
 61	}
 62	if err != nil {
 63		return err
 64	}
 65	return RunGit(in, out, er, "", UploadArchiveBin[4:], repoDir)
 66}
 67
 68// ReceivePack runs the git receive-pack protocol against the provided repo.
 69func ReceivePack(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
 70	if err := RunGit(in, out, er, "", ReceivePackBin[4:], repoDir); err != nil {
 71		return err
 72	}
 73	return EnsureDefaultBranch(in, out, er, repoDir)
 74}
 75
 76// RunGit runs a git command in the given repo.
 77func RunGit(in io.Reader, out io.Writer, err io.Writer, dir string, args ...string) error {
 78	c := git.NewCommand(args...)
 79	return c.RunInDirWithOptions(dir, git.RunInDirOptions{
 80		Stdin:  in,
 81		Stdout: out,
 82		Stderr: err,
 83	})
 84}
 85
 86// WritePktline encodes and writes a pktline to the given writer.
 87func WritePktline(w io.Writer, v ...interface{}) {
 88	msg := fmt.Sprintln(v...)
 89	pkt := pktline.NewEncoder(w)
 90	if err := pkt.EncodeString(msg); err != nil {
 91		log.Debugf("git: error writing pkt-line message: %s", err)
 92	}
 93	if err := pkt.Flush(); err != nil {
 94		log.Debugf("git: error flushing pkt-line message: %s", err)
 95	}
 96}
 97
 98// EnsureWithin ensures the given repo is within the repos directory.
 99func EnsureWithin(reposDir string, repo string) error {
100	repoDir := filepath.Join(reposDir, repo)
101	absRepos, err := filepath.Abs(reposDir)
102	if err != nil {
103		log.Debugf("failed to get absolute path for repo: %s", err)
104		return ErrSystemMalfunction
105	}
106	absRepo, err := filepath.Abs(repoDir)
107	if err != nil {
108		log.Debugf("failed to get absolute path for repos: %s", err)
109		return ErrSystemMalfunction
110	}
111
112	// ensure the repo is within the repos directory
113	if !strings.HasPrefix(absRepo, absRepos) {
114		log.Debugf("repo path is outside of repos directory: %s", absRepo)
115		return ErrInvalidRepo
116	}
117
118	return nil
119}
120
// fileExists reports whether path exists, according to os.Stat.
//
// It returns (true, nil) when Stat succeeds and (false, nil) when the path
// does not exist. For any other Stat failure (a permission error, for
// instance) existence cannot be ruled out, so it returns (true, err). Callers
// should inspect err before trusting the boolean.
func fileExists(path string) (bool, error) {
	_, err := os.Stat(path)
	switch {
	case err == nil:
		return true, nil
	case errors.Is(err, os.ErrNotExist):
		return false, nil
	default:
		return true, err
	}
}
131
132func EnsureDefaultBranch(in io.Reader, out io.Writer, er io.Writer, repoPath string) error {
133	r, err := git.Open(repoPath)
134	if err != nil {
135		return err
136	}
137	brs, err := r.Branches()
138	if err != nil {
139		return err
140	}
141	if len(brs) == 0 {
142		return fmt.Errorf("no branches found")
143	}
144	// Rename the default branch to the first branch available
145	_, err = r.HEAD()
146	if err == git.ErrReferenceNotExist {
147		err = RunGit(in, out, er, repoPath, "branch", "-M", brs[0])
148		if err != nil {
149			return err
150		}
151	}
152	if err != nil && err != git.ErrReferenceNotExist {
153		return err
154	}
155	return nil
156}