1package git
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "path/filepath"
9 "strings"
10
11 "github.com/charmbracelet/log"
12 "github.com/charmbracelet/soft-serve/git"
13 "github.com/go-git/go-git/v5/plumbing/format/pktline"
14)
15
// Sentinel errors returned by the git service helpers in this package.
var (
	// ErrUnauthorized is returned when the caller lacks permission for the
	// requested operation.
	ErrUnauthorized = errors.New("you are not authorized to do this")

	// ErrSystemMalfunction is a generic, client-facing error used when an
	// internal failure occurs.
	ErrSystemMalfunction = errors.New("something went wrong")

	// ErrNotExist is returned when the requested repository does not exist.
	ErrNotExist = errors.New("repo does not exist")

	// ErrInvalidRequest is returned when a request is malformed or otherwise
	// not acceptable.
	ErrInvalidRequest = errors.New("invalid request")

	// ErrMaxConnections is returned when the connection limit has been
	// reached.
	ErrMaxConnections = errors.New("too many connections, try again later")

	// ErrTimeout is returned when the maximum read timeout is exceeded.
	ErrTimeout = errors.New("I/O timeout reached")
)
36
37// WritePktline encodes and writes a pktline to the given writer.
38func WritePktline(w io.Writer, v ...interface{}) error {
39 msg := fmt.Sprintln(v...)
40 pkt := pktline.NewEncoder(w)
41 if err := pkt.EncodeString(msg); err != nil {
42 return fmt.Errorf("git: error writing pkt-line message: %w", err)
43 }
44 if err := pkt.Flush(); err != nil {
45 return fmt.Errorf("git: error flushing pkt-line message: %w", err)
46 }
47
48 return nil
49}
50
51// WritePktlineErr writes an error pktline to the given writer.
52func WritePktlineErr(w io.Writer, err error) error {
53 return WritePktline(w, "ERR ", err.Error())
54}
55
56// EnsureWithin ensures the given repo is within the repos directory.
57func EnsureWithin(reposDir string, repo string) error {
58 repoDir := filepath.Join(reposDir, repo)
59 absRepos, err := filepath.Abs(reposDir)
60 if err != nil {
61 log.Debugf("failed to get absolute path for repo: %s", err)
62 return ErrSystemMalfunction
63 }
64 absRepo, err := filepath.Abs(repoDir)
65 if err != nil {
66 log.Debugf("failed to get absolute path for repos: %s", err)
67 return ErrSystemMalfunction
68 }
69
70 // ensure the repo is within the repos directory
71 if !strings.HasPrefix(absRepo, absRepos) {
72 log.Debugf("repo path is outside of repos directory: %s", absRepo)
73 return ErrNotExist
74 }
75
76 return nil
77}
78
79// EnsureDefaultBranch ensures the repo has a default branch.
80// It will prefer choosing "main" or "master" if available.
81func EnsureDefaultBranch(ctx context.Context, scmd ServiceCommand) error {
82 r, err := git.Open(scmd.Dir)
83 if err != nil {
84 return err
85 }
86 brs, err := r.Branches()
87 if err != nil {
88 return err
89 }
90 if len(brs) == 0 {
91 return fmt.Errorf("no branches found")
92 }
93 // Rename the default branch to the first branch available
94 _, err = r.HEAD()
95 if err == git.ErrReferenceNotExist {
96 branch := brs[0]
97 // Prefer "main" or "master" as the default branch
98 for _, b := range brs {
99 if b == "main" || b == "master" {
100 branch = b
101 break
102 }
103 }
104
105 cmd := git.NewCommand("branch", "-M", branch).WithContext(ctx)
106 if err := cmd.RunInDirWithOptions(scmd.Dir, git.RunInDirOptions{
107 Stdin: scmd.Stdin,
108 Stdout: scmd.Stdout,
109 Stderr: scmd.Stderr,
110 }); err != nil {
111 return err
112 }
113 }
114 if err != nil && err != git.ErrReferenceNotExist {
115 return err
116 }
117 return nil
118}