1package server
2
3import (
4 "errors"
5 "fmt"
6 "io"
7 "os"
8 "path/filepath"
9 "strings"
10
11 "github.com/charmbracelet/log"
12 "github.com/charmbracelet/soft-serve/git"
13 "github.com/go-git/go-git/v5/plumbing/format/pktline"
14)
15
// Errors reported to git clients by the transport handlers in this package.
var (
	// ErrNotAuthed is returned when the caller lacks permission for the
	// requested operation.
	ErrNotAuthed = errors.New("you are not authorized to do this")

	// ErrSystemMalfunction is a deliberately vague error sent to clients when
	// an internal failure occurs, so internal details are not leaked.
	ErrSystemMalfunction = errors.New("something went wrong")

	// ErrInvalidRepo is returned when the requested repository does not exist
	// or resolves to a disallowed location.
	ErrInvalidRepo = errors.New("invalid repo")

	// ErrInvalidRequest is returned for malformed client requests.
	ErrInvalidRequest = errors.New("invalid request")

	// ErrMaxConnections is returned when the connection limit has been hit.
	ErrMaxConnections = errors.New("too many connections, try again later")

	// ErrTimeout is returned once the maximum read timeout has elapsed.
	ErrTimeout = errors.New("I/O timeout reached")
)
36
// Names of the git server-side protocol programs. Callers strip the leading
// "git-" to obtain the subcommand passed to runGit (e.g. "upload-pack").
const (
	uploadPackBin    = "git-upload-pack"
	uploadArchiveBin = "git-upload-archive"
	receivePackBin   = "git-receive-pack"
)
43
44// uploadPack runs the git upload-pack protocol against the provided repo.
45func uploadPack(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
46 exists, err := fileExists(repoDir)
47 if !exists {
48 return ErrInvalidRepo
49 }
50 if err != nil {
51 return err
52 }
53 return runGit(in, out, er, "", uploadPackBin[4:], repoDir)
54}
55
56// uploadArchive runs the git upload-archive protocol against the provided repo.
57func uploadArchive(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
58 exists, err := fileExists(repoDir)
59 if !exists {
60 return ErrInvalidRepo
61 }
62 if err != nil {
63 return err
64 }
65 return runGit(in, out, er, "", uploadArchiveBin[4:], repoDir)
66}
67
68// receivePack runs the git receive-pack protocol against the provided repo.
69func receivePack(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
70 if err := runGit(in, out, er, "", receivePackBin[4:], repoDir); err != nil {
71 return err
72 }
73 return ensureDefaultBranch(in, out, er, repoDir)
74}
75
76// runGit runs a git command in the given repo.
77func runGit(in io.Reader, out io.Writer, err io.Writer, dir string, args ...string) error {
78 c := git.NewCommand(args...)
79 return c.RunInDirWithOptions(dir, git.RunInDirOptions{
80 Stdin: in,
81 Stdout: out,
82 Stderr: err,
83 })
84}
85
86// writePktline encodes and writes a pktline to the given writer.
87func writePktline(w io.Writer, v ...interface{}) {
88 msg := fmt.Sprintln(v...)
89 pkt := pktline.NewEncoder(w)
90 if err := pkt.EncodeString(msg); err != nil {
91 log.Debugf("git: error writing pkt-line message: %s", err)
92 }
93 if err := pkt.Flush(); err != nil {
94 log.Debugf("git: error flushing pkt-line message: %s", err)
95 }
96}
97
98// ensureWithin ensures the given repo is within the repos directory.
99func ensureWithin(reposDir string, repo string) error {
100 repoDir := filepath.Join(reposDir, repo)
101 absRepos, err := filepath.Abs(reposDir)
102 if err != nil {
103 log.Debugf("failed to get absolute path for repo: %s", err)
104 return ErrSystemMalfunction
105 }
106 absRepo, err := filepath.Abs(repoDir)
107 if err != nil {
108 log.Debugf("failed to get absolute path for repos: %s", err)
109 return ErrSystemMalfunction
110 }
111
112 // ensure the repo is within the repos directory
113 if !strings.HasPrefix(absRepo, absRepos) {
114 log.Debugf("repo path is outside of repos directory: %s", absRepo)
115 return ErrInvalidRepo
116 }
117
118 return nil
119}
120
// fileExists stats path and reports whether it exists.
//
// A path that does not exist yields (false, nil). Any other Stat failure
// yields (true, err): existence could not be ruled out, so callers should
// check err before trusting the bool.
func fileExists(path string) (bool, error) {
	switch _, statErr := os.Stat(path); {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return true, statErr
	}
}
131
132func ensureDefaultBranch(in io.Reader, out io.Writer, er io.Writer, repoPath string) error {
133 r, err := git.Open(repoPath)
134 if err != nil {
135 return err
136 }
137 brs, err := r.Branches()
138 if err != nil {
139 return err
140 }
141 if len(brs) == 0 {
142 return fmt.Errorf("no branches found")
143 }
144 // Rename the default branch to the first branch available
145 _, err = r.HEAD()
146 if err == git.ErrReferenceNotExist {
147 err = runGit(in, out, er, repoPath, "branch", "-M", brs[0])
148 if err != nil {
149 return err
150 }
151 }
152 if err != nil && err != git.ErrReferenceNotExist {
153 return err
154 }
155 return nil
156}