1package git
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "path/filepath"
9 "strings"
10
11 "github.com/charmbracelet/log"
12 "github.com/charmbracelet/soft-serve/git"
13 "github.com/go-git/go-git/v5/plumbing/format/pktline"
14 gitm "github.com/gogs/git-module"
15)
16
// Sentinel errors returned by this package. Compare against them with
// errors.Is; their Error() strings are what gets reported.
var (
	// ErrNotAuthed signals that the caller is not permitted to perform the
	// requested operation.
	ErrNotAuthed = errors.New("you are not authorized to do this")
	// ErrSystemMalfunction is a deliberately generic error meant to be
	// returned to clients in place of internal failure details.
	ErrSystemMalfunction = errors.New("something went wrong")
	// ErrInvalidRepo signals an attempt to access a repository that does not
	// exist or is not a valid target.
	ErrInvalidRepo = errors.New("invalid repo")
	// ErrInvalidRequest signals a malformed or unsupported request.
	ErrInvalidRequest = errors.New("invalid request")
	// ErrMaxConnections signals that the connection limit has been reached.
	ErrMaxConnections = errors.New("too many connections, try again later")
	// ErrTimeout signals that the maximum read timeout was exceeded.
	ErrTimeout = errors.New("I/O timeout reached")
)
37
38// WritePktline encodes and writes a pktline to the given writer.
39func WritePktline(w io.Writer, v ...interface{}) error {
40 msg := fmt.Sprintln(v...)
41 pkt := pktline.NewEncoder(w)
42 if err := pkt.EncodeString(msg); err != nil {
43 return fmt.Errorf("git: error writing pkt-line message: %w", err)
44 }
45 if err := pkt.Flush(); err != nil {
46 return fmt.Errorf("git: error flushing pkt-line message: %w", err)
47 }
48
49 return nil
50}
51
52// WritePktlineErr writes an error pktline to the given writer.
53func WritePktlineErr(w io.Writer, err error) error {
54 return WritePktline(w, "ERR", err.Error())
55}
56
57// EnsureWithin ensures the given repo is within the repos directory.
58func EnsureWithin(reposDir string, repo string) error {
59 repoDir := filepath.Join(reposDir, repo)
60 absRepos, err := filepath.Abs(reposDir)
61 if err != nil {
62 log.Debugf("failed to get absolute path for repo: %s", err)
63 return ErrSystemMalfunction
64 }
65 absRepo, err := filepath.Abs(repoDir)
66 if err != nil {
67 log.Debugf("failed to get absolute path for repos: %s", err)
68 return ErrSystemMalfunction
69 }
70
71 // ensure the repo is within the repos directory
72 if !strings.HasPrefix(absRepo, absRepos) {
73 log.Debugf("repo path is outside of repos directory: %s", absRepo)
74 return ErrInvalidRepo
75 }
76
77 return nil
78}
79
80// EnsureDefaultBranch ensures the repo has a default branch.
81// It will prefer choosing "main" or "master" if available.
82func EnsureDefaultBranch(ctx context.Context, scmd ServiceCommand) error {
83 r, err := git.Open(scmd.Dir)
84 if err != nil {
85 return err
86 }
87 brs, err := r.Branches()
88 if err != nil {
89 return err
90 }
91 if len(brs) == 0 {
92 return fmt.Errorf("no branches found")
93 }
94 // Rename the default branch to the first branch available
95 _, err = r.HEAD()
96 if err == git.ErrReferenceNotExist {
97 branch := brs[0]
98 // Prefer "main" or "master" as the default branch
99 for _, b := range brs {
100 if b == "main" || b == "master" {
101 branch = b
102 break
103 }
104 }
105
106 if _, err := r.SymbolicRef(git.HEAD, git.RefsHeads+branch, gitm.SymbolicRefOptions{
107 CommandOptions: gitm.CommandOptions{
108 Context: ctx,
109 },
110 }); err != nil {
111 return err
112 }
113 }
114 if err != nil && err != git.ErrReferenceNotExist {
115 return err
116 }
117 return nil
118}