git.go

  1package git
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"os"
  9	"os/exec"
 10	"path/filepath"
 11	"strings"
 12
 13	"github.com/charmbracelet/log"
 14	"github.com/charmbracelet/soft-serve/git"
 15	"github.com/charmbracelet/soft-serve/server/config"
 16	"github.com/go-git/go-git/v5/plumbing/format/pktline"
 17	"golang.org/x/sync/errgroup"
 18)
 19
 20var (
 21
 22	// ErrNotAuthed represents unauthorized access.
 23	ErrNotAuthed = errors.New("you are not authorized to do this")
 24
 25	// ErrSystemMalfunction represents a general system error returned to clients.
 26	ErrSystemMalfunction = errors.New("something went wrong")
 27
 28	// ErrInvalidRepo represents an attempt to access a non-existent repo.
 29	ErrInvalidRepo = errors.New("invalid repo")
 30
 31	// ErrInvalidRequest represents an invalid request.
 32	ErrInvalidRequest = errors.New("invalid request")
 33
 34	// ErrMaxConnections represents a maximum connection limit being reached.
 35	ErrMaxConnections = errors.New("too many connections, try again later")
 36
 37	// ErrTimeout is returned when the maximum read timeout is exceeded.
 38	ErrTimeout = errors.New("I/O timeout reached")
 39)
 40
 41// Git protocol commands.
 42const (
 43	ReceivePackBin   = "git-receive-pack"
 44	UploadPackBin    = "git-upload-pack"
 45	UploadArchiveBin = "git-upload-archive"
 46)
 47
 48// UploadPack runs the git upload-pack protocol against the provided repo.
 49func UploadPack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
 50	exists, err := fileExists(repoDir)
 51	if !exists {
 52		return ErrInvalidRepo
 53	}
 54	if err != nil {
 55		return err
 56	}
 57	return RunGit(ctx, in, out, er, "", envs, UploadPackBin[4:], repoDir)
 58}
 59
 60// UploadArchive runs the git upload-archive protocol against the provided repo.
 61func UploadArchive(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
 62	exists, err := fileExists(repoDir)
 63	if !exists {
 64		return ErrInvalidRepo
 65	}
 66	if err != nil {
 67		return err
 68	}
 69	return RunGit(ctx, in, out, er, "", envs, UploadArchiveBin[4:], repoDir)
 70}
 71
 72// ReceivePack runs the git receive-pack protocol against the provided repo.
 73func ReceivePack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
 74	if err := RunGit(ctx, in, out, er, "", envs, ReceivePackBin[4:], repoDir); err != nil {
 75		return err
 76	}
 77	return EnsureDefaultBranch(ctx, in, out, er, repoDir)
 78}
 79
 80// RunGit runs a git command in the given repo.
 81func RunGit(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, dir string, envs []string, args ...string) error {
 82	cfg := config.FromContext(ctx)
 83	logger := log.FromContext(ctx).WithPrefix("rungit")
 84	c := exec.CommandContext(ctx, "git", args...)
 85	c.Dir = dir
 86	c.Env = append(os.Environ(), envs...)
 87	c.Env = append(c.Env, "PATH="+os.Getenv("PATH"))
 88	c.Env = append(c.Env, "SOFT_SERVE_DEBUG="+os.Getenv("SOFT_SERVE_DEBUG"))
 89	if cfg != nil {
 90		c.Env = append(c.Env, "SOFT_SERVE_LOG_FORMAT="+cfg.Log.Format)
 91		c.Env = append(c.Env, "SOFT_SERVE_LOG_TIME_FORMAT="+cfg.Log.TimeFormat)
 92	}
 93
 94	stdin, err := c.StdinPipe()
 95	if err != nil {
 96		logger.Error("failed to get stdin pipe", "err", err)
 97		return err
 98	}
 99
100	stdout, err := c.StdoutPipe()
101	if err != nil {
102		logger.Error("failed to get stdout pipe", "err", err)
103		return err
104	}
105
106	stderr, err := c.StderrPipe()
107	if err != nil {
108		logger.Error("failed to get stderr pipe", "err", err)
109		return err
110	}
111
112	if err := c.Start(); err != nil {
113		logger.Error("failed to start command", "err", err)
114		return err
115	}
116
117	errg, ctx := errgroup.WithContext(ctx)
118
119	// stdin
120	errg.Go(func() error {
121		defer stdin.Close()
122
123		_, err := io.Copy(stdin, in)
124		return err
125	})
126
127	// stdout
128	errg.Go(func() error {
129		_, err := io.Copy(out, stdout)
130		return err
131	})
132
133	// stderr
134	errg.Go(func() error {
135		_, err := io.Copy(er, stderr)
136		return err
137	})
138
139	if err := errg.Wait(); err != nil {
140		logger.Error("while copying output", "err", err)
141	}
142
143	// Wait for the command to finish
144	return c.Wait()
145}
146
147// WritePktline encodes and writes a pktline to the given writer.
148func WritePktline(w io.Writer, v ...interface{}) {
149	msg := fmt.Sprintln(v...)
150	pkt := pktline.NewEncoder(w)
151	if err := pkt.EncodeString(msg); err != nil {
152		log.Debugf("git: error writing pkt-line message: %s", err)
153	}
154	if err := pkt.Flush(); err != nil {
155		log.Debugf("git: error flushing pkt-line message: %s", err)
156	}
157}
158
159// EnsureWithin ensures the given repo is within the repos directory.
160func EnsureWithin(reposDir string, repo string) error {
161	repoDir := filepath.Join(reposDir, repo)
162	absRepos, err := filepath.Abs(reposDir)
163	if err != nil {
164		log.Debugf("failed to get absolute path for repo: %s", err)
165		return ErrSystemMalfunction
166	}
167	absRepo, err := filepath.Abs(repoDir)
168	if err != nil {
169		log.Debugf("failed to get absolute path for repos: %s", err)
170		return ErrSystemMalfunction
171	}
172
173	// ensure the repo is within the repos directory
174	if !strings.HasPrefix(absRepo, absRepos) {
175		log.Debugf("repo path is outside of repos directory: %s", absRepo)
176		return ErrInvalidRepo
177	}
178
179	return nil
180}
181
// fileExists reports whether path can be stat'ed.
//
// A path that does not exist yields (false, nil). Any other os.Stat failure
// yields (true, err): existence is unknown, so callers should check err
// before trusting the boolean.
func fileExists(path string) (bool, error) {
	_, statErr := os.Stat(path)
	switch {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return true, statErr
	}
}
192
193func EnsureDefaultBranch(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoPath string) error {
194	r, err := git.Open(repoPath)
195	if err != nil {
196		return err
197	}
198	brs, err := r.Branches()
199	if err != nil {
200		return err
201	}
202	if len(brs) == 0 {
203		return fmt.Errorf("no branches found")
204	}
205	// Rename the default branch to the first branch available
206	_, err = r.HEAD()
207	if err == git.ErrReferenceNotExist {
208		if _, err := r.SymbolicRef(git.HEAD, git.RefsHeads+brs[0]); err != nil {
209			return err
210		}
211	}
212	if err != nil && err != git.ErrReferenceNotExist {
213		return err
214	}
215	return nil
216}