1package git
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"os"
  9	"os/exec"
 10	"path/filepath"
 11	"strings"
 12
 13	"github.com/charmbracelet/log"
 14	"github.com/charmbracelet/soft-serve/git"
 15	"github.com/charmbracelet/soft-serve/server/config"
 16	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 17	"golang.org/x/sync/errgroup"
 18)
 19
 20var (
 21
 22	// ErrNotAuthed represents unauthorized access.
 23	ErrNotAuthed = errors.New("you are not authorized to do this")
 24
 25	// ErrSystemMalfunction represents a general system error returned to clients.
 26	ErrSystemMalfunction = errors.New("something went wrong")
 27
 28	// ErrInvalidRepo represents an attempt to access a non-existent repo.
 29	ErrInvalidRepo = errors.New("invalid repo")
 30
 31	// ErrInvalidRequest represents an invalid request.
 32	ErrInvalidRequest = errors.New("invalid request")
 33
 34	// ErrMaxConnections represents a maximum connection limit being reached.
 35	ErrMaxConnections = errors.New("too many connections, try again later")
 36
 37	// ErrTimeout is returned when the maximum read timeout is exceeded.
 38	ErrTimeout = errors.New("I/O timeout reached")
 39)
 40
 41// Git protocol commands.
 42const (
 43	ReceivePackBin   = "git-receive-pack"
 44	UploadPackBin    = "git-upload-pack"
 45	UploadArchiveBin = "git-upload-archive"
 46)
 47
 48// UploadPack runs the git upload-pack protocol against the provided repo.
 49func UploadPack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
 50	exists, err := fileExists(repoDir)
 51	if !exists {
 52		return ErrInvalidRepo
 53	}
 54	if err != nil {
 55		return err
 56	}
 57	return RunGit(ctx, in, out, er, "", envs, UploadPackBin[4:], repoDir)
 58}
 59
 60// UploadArchive runs the git upload-archive protocol against the provided repo.
 61func UploadArchive(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
 62	exists, err := fileExists(repoDir)
 63	if !exists {
 64		return ErrInvalidRepo
 65	}
 66	if err != nil {
 67		return err
 68	}
 69	return RunGit(ctx, in, out, er, "", envs, UploadArchiveBin[4:], repoDir)
 70}
 71
 72// ReceivePack runs the git receive-pack protocol against the provided repo.
 73func ReceivePack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
 74	if err := RunGit(ctx, in, out, er, "", envs, ReceivePackBin[4:], repoDir); err != nil {
 75		return err
 76	}
 77	return EnsureDefaultBranch(ctx, in, out, er, repoDir)
 78}
 79
 80// RunGit runs a git command in the given repo.
 81func RunGit(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, dir string, envs []string, args ...string) error {
 82	cfg := config.FromContext(ctx)
 83	logger := log.FromContext(ctx).WithPrefix("rungit")
 84	c := exec.CommandContext(ctx, "git", args...)
 85	c.Dir = dir
 86	c.Env = append(c.Env, envs...)
 87	c.Env = append(c.Env, "PATH="+os.Getenv("PATH"))
 88	c.Env = append(c.Env, "SOFT_SERVE_DEBUG="+os.Getenv("SOFT_SERVE_DEBUG"))
 89	if cfg != nil {
 90		c.Env = append(c.Env, "SOFT_SERVE_LOG_FORMAT="+cfg.LogFormat)
 91	}
 92
 93	stdin, err := c.StdinPipe()
 94	if err != nil {
 95		logger.Error("failed to get stdin pipe", "err", err)
 96		return err
 97	}
 98
 99	stdout, err := c.StdoutPipe()
100	if err != nil {
101		logger.Error("failed to get stdout pipe", "err", err)
102		return err
103	}
104
105	stderr, err := c.StderrPipe()
106	if err != nil {
107		logger.Error("failed to get stderr pipe", "err", err)
108		return err
109	}
110
111	if err := c.Start(); err != nil {
112		logger.Error("failed to start command", "err", err)
113		return err
114	}
115
116	errg, ctx := errgroup.WithContext(ctx)
117
118	// stdin
119	errg.Go(func() error {
120		defer stdin.Close()
121
122		_, err := io.Copy(stdin, in)
123		return err
124	})
125
126	// stdout
127	errg.Go(func() error {
128		_, err := io.Copy(out, stdout)
129		return err
130	})
131
132	// stderr
133	errg.Go(func() error {
134		_, err := io.Copy(er, stderr)
135		return err
136	})
137
138	if err := errg.Wait(); err != nil {
139		logger.Error("while copying output", "err", err)
140	}
141
142	// Wait for the command to finish
143	return c.Wait()
144}
145
146// WritePktline encodes and writes a pktline to the given writer.
147func WritePktline(w io.Writer, v ...interface{}) {
148	msg := fmt.Sprintln(v...)
149	pkt := pktline.NewEncoder(w)
150	if err := pkt.EncodeString(msg); err != nil {
151		log.Debugf("git: error writing pkt-line message: %s", err)
152	}
153	if err := pkt.Flush(); err != nil {
154		log.Debugf("git: error flushing pkt-line message: %s", err)
155	}
156}
157
158// EnsureWithin ensures the given repo is within the repos directory.
159func EnsureWithin(reposDir string, repo string) error {
160	repoDir := filepath.Join(reposDir, repo)
161	absRepos, err := filepath.Abs(reposDir)
162	if err != nil {
163		log.Debugf("failed to get absolute path for repo: %s", err)
164		return ErrSystemMalfunction
165	}
166	absRepo, err := filepath.Abs(repoDir)
167	if err != nil {
168		log.Debugf("failed to get absolute path for repos: %s", err)
169		return ErrSystemMalfunction
170	}
171
172	// ensure the repo is within the repos directory
173	if !strings.HasPrefix(absRepo, absRepos) {
174		log.Debugf("repo path is outside of repos directory: %s", absRepo)
175		return ErrInvalidRepo
176	}
177
178	return nil
179}
180
// fileExists stats path and reports whether it exists.
//
// A missing path yields (false, nil). Any other stat failure yields
// (true, err): the path is reported as present so that callers checking the
// boolean first still reach, and return, the underlying error.
func fileExists(path string) (bool, error) {
	_, statErr := os.Stat(path)
	switch {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return true, statErr
	}
}
191
192func EnsureDefaultBranch(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoPath string) error {
193	r, err := git.Open(repoPath)
194	if err != nil {
195		return err
196	}
197	brs, err := r.Branches()
198	if err != nil {
199		return err
200	}
201	if len(brs) == 0 {
202		return fmt.Errorf("no branches found")
203	}
204	// Rename the default branch to the first branch available
205	_, err = r.HEAD()
206	if err == git.ErrReferenceNotExist {
207		err = RunGit(ctx, in, out, er, repoPath, []string{}, "branch", "-M", brs[0])
208		if err != nil {
209			return err
210		}
211	}
212	if err != nil && err != git.ErrReferenceNotExist {
213		return err
214	}
215	return nil
216}