Detailed changes
@@ -54,6 +54,9 @@ func init() {
func main() {
logger := NewDefaultLogger()
+ // Set global logger
+ log.SetDefault(logger)
+
// Set the max number of processes to the number of CPUs
// This is useful when running soft serve in a container
if _, err := maxprocs.Set(maxprocs.Logger(logger.Debugf)); err != nil {
@@ -200,13 +200,6 @@ func (r *Repository) CommitsByPage(ref *Reference, page, size int) (Commits, err
return commits, nil
}
-// UpdateServerInfo updates the repository server info.
-func (r *Repository) UpdateServerInfo() error {
- cmd := git.NewCommand("update-server-info")
- _, err := cmd.RunInDir(r.Path)
- return err
-}
-
// Config returns the config value for the given key.
func (r *Repository) Config(key string, opts ...ConfigOptions) (string, error) {
dir, err := gitDir(r.Repository)
@@ -0,0 +1,18 @@
+package git
+
+import (
+ "context"
+
+ "github.com/gogs/git-module"
+)
+
+// UpdateServerInfo updates the server info file for the given repo path.
+func UpdateServerInfo(ctx context.Context, path string) error {
+ if !isGitDir(path) {
+ return ErrNotAGitRepository
+ }
+
+ cmd := git.NewCommand("update-server-info").WithContext(ctx).WithTimeout(-1)
+ _, err := cmd.RunInDir(path)
+ return err
+}
@@ -1,6 +1,7 @@
package git
import (
+ "os"
"path/filepath"
"github.com/gobwas/glob"
@@ -49,3 +50,25 @@ func LatestFile(repo *Repository, pattern string) (string, string, error) {
}
return "", "", ErrFileNotFound
}
+
+// Returns true if path is a directory containing an `objects` directory and a
+// `HEAD` file.
+func isGitDir(path string) bool {
+ stat, err := os.Stat(filepath.Join(path, "objects"))
+ if err != nil {
+ return false
+ }
+ if !stat.IsDir() {
+ return false
+ }
+
+ stat, err = os.Stat(filepath.Join(path, "HEAD"))
+ if err != nil {
+ return false
+ }
+ if stat.IsDir() {
+ return false
+ }
+
+ return true
+}
@@ -32,6 +32,10 @@ func NewDefaultLogger() *log.Logger {
if debug, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_DEBUG")); debug {
logger.SetLevel(log.DebugLevel)
+
+ if verbose, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_VERBOSE")); verbose {
+ logger.SetReportCaller(true)
+ }
}
logger.SetTimeFormat(cfg.Log.TimeFormat)
@@ -36,16 +36,6 @@ func (d *SqliteBackend) PostUpdate(stdout io.Writer, stderr io.Writer, repo stri
var wg sync.WaitGroup
- // Update server info
- wg.Add(1)
- go func() {
- defer wg.Done()
- if err := updateServerInfo(d, repo); err != nil {
- d.logger.Error("error updating server-info", "repo", repo, "err", err)
- return
- }
- }()
-
// Populate last-modified file.
wg.Add(1)
go func() {
@@ -59,20 +49,6 @@ func (d *SqliteBackend) PostUpdate(stdout io.Writer, stderr io.Writer, repo stri
wg.Wait()
}
-func updateServerInfo(d *SqliteBackend, repo string) error {
- rr, err := d.Repository(repo)
- if err != nil {
- return err
- }
-
- r, err := rr.Open()
- if err != nil {
- return err
- }
-
- return r.UpdateServerInfo()
-}
-
func populateLastModified(d *SqliteBackend, repo string) error {
var rr *Repo
_rr, err := d.Repository(repo)
@@ -151,17 +151,12 @@ func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOpt
return err
}
- rr, err := git.Init(rp, true)
+ _, err := git.Init(rp, true)
if err != nil {
d.logger.Debug("failed to create repository", "err", err)
return err
}
- if err := rr.UpdateServerInfo(); err != nil {
- d.logger.Debug("failed to update server info", "err", err)
- return err
- }
-
return nil
}); err != nil {
d.logger.Debug("failed to create repository in database", "err", err)
@@ -114,6 +114,40 @@ type Config struct {
Backend backend.Backend `yaml:"-"`
}
+// Environ returns the config as a list of environment variables.
+func (c *Config) Environ() []string {
+ envs := []string{}
+ if c == nil {
+ return envs
+ }
+
+ // TODO: do this dynamically
+ envs = append(envs, []string{
+ fmt.Sprintf("SOFT_SERVE_NAME=%s", c.Name),
+ fmt.Sprintf("SOFT_SERVE_DATA_PATH=%s", c.DataPath),
+ fmt.Sprintf("SOFT_SERVE_INITIAL_ADMIN_KEYS=%s", strings.Join(c.InitialAdminKeys, "\n")),
+ fmt.Sprintf("SOFT_SERVE_SSH_LISTEN_ADDR=%s", c.SSH.ListenAddr),
+ fmt.Sprintf("SOFT_SERVE_SSH_PUBLIC_URL=%s", c.SSH.PublicURL),
+ fmt.Sprintf("SOFT_SERVE_SSH_KEY_PATH=%s", c.SSH.KeyPath),
+ fmt.Sprintf("SOFT_SERVE_SSH_CLIENT_KEY_PATH=%s", c.SSH.ClientKeyPath),
+ fmt.Sprintf("SOFT_SERVE_SSH_MAX_TIMEOUT=%d", c.SSH.MaxTimeout),
+ fmt.Sprintf("SOFT_SERVE_SSH_IDLE_TIMEOUT=%d", c.SSH.IdleTimeout),
+ fmt.Sprintf("SOFT_SERVE_GIT_LISTEN_ADDR=%s", c.Git.ListenAddr),
+ fmt.Sprintf("SOFT_SERVE_GIT_MAX_TIMEOUT=%d", c.Git.MaxTimeout),
+ fmt.Sprintf("SOFT_SERVE_GIT_IDLE_TIMEOUT=%d", c.Git.IdleTimeout),
+ fmt.Sprintf("SOFT_SERVE_GIT_MAX_CONNECTIONS=%d", c.Git.MaxConnections),
+ fmt.Sprintf("SOFT_SERVE_HTTP_LISTEN_ADDR=%s", c.HTTP.ListenAddr),
+ fmt.Sprintf("SOFT_SERVE_HTTP_TLS_KEY_PATH=%s", c.HTTP.TLSKeyPath),
+ fmt.Sprintf("SOFT_SERVE_HTTP_TLS_CERT_PATH=%s", c.HTTP.TLSCertPath),
+ fmt.Sprintf("SOFT_SERVE_HTTP_PUBLIC_URL=%s", c.HTTP.PublicURL),
+ fmt.Sprintf("SOFT_SERVE_STATS_LISTEN_ADDR=%s", c.Stats.ListenAddr),
+ fmt.Sprintf("SOFT_SERVE_LOG_FORMAT=%s", c.Log.Format),
+ fmt.Sprintf("SOFT_SERVE_LOG_TIME_FORMAT=%s", c.Log.TimeFormat),
+ }...)
+
+ return envs
+}
+
func parseConfig(path string) (*Config, error) {
dataPath := filepath.Dir(path)
cfg := &Config{
@@ -0,0 +1,105 @@
+package daemon
+
+import (
+ "context"
+ "errors"
+ "net"
+ "sync"
+ "time"
+)
+
+// connections is a synchronizes access to to a net.Conn pool.
+type connections struct {
+ m map[net.Conn]struct{}
+ mu sync.Mutex
+}
+
+func (m *connections) Add(c net.Conn) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.m[c] = struct{}{}
+}
+
+func (m *connections) Close(c net.Conn) error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ err := c.Close()
+ delete(m.m, c)
+ return err
+}
+
+func (m *connections) Size() int {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return len(m.m)
+}
+
+func (m *connections) CloseAll() error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ var err error
+ for c := range m.m {
+ err = errors.Join(err, c.Close())
+ delete(m.m, c)
+ }
+
+ return err
+}
+
+// serverConn is a wrapper around a net.Conn that closes the connection when
+// the one of the timeouts is reached.
+type serverConn struct {
+ net.Conn
+
+ initTimeout time.Duration
+ idleTimeout time.Duration
+ maxDeadline time.Time
+ closeCanceler context.CancelFunc
+}
+
+var _ net.Conn = (*serverConn)(nil)
+
+func (c *serverConn) Write(p []byte) (n int, err error) {
+ c.updateDeadline()
+ n, err = c.Conn.Write(p)
+ if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
+ c.closeCanceler()
+ }
+ return
+}
+
+func (c *serverConn) Read(b []byte) (n int, err error) {
+ c.updateDeadline()
+ n, err = c.Conn.Read(b)
+ if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
+ c.closeCanceler()
+ }
+ return
+}
+
+func (c *serverConn) Close() (err error) {
+ err = c.Conn.Close()
+ if c.closeCanceler != nil {
+ c.closeCanceler()
+ }
+ return
+}
+
+func (c *serverConn) updateDeadline() {
+ switch {
+ case c.initTimeout > 0:
+ initTimeout := time.Now().Add(c.initTimeout)
+ c.initTimeout = 0
+ if initTimeout.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
+ c.Conn.SetDeadline(initTimeout)
+ return
+ }
+ case c.idleTimeout > 0:
+ idleDeadline := time.Now().Add(c.idleTimeout)
+ if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
+ c.Conn.SetDeadline(idleDeadline)
+ return
+ }
+ }
+ c.Conn.SetDeadline(c.maxDeadline)
+}
@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"path/filepath"
+ "strings"
"sync"
"time"
@@ -41,40 +42,6 @@ var (
ErrServerClosed = fmt.Errorf("git: %w", net.ErrClosed)
)
-// connections synchronizes access to to a net.Conn pool.
-type connections struct {
- m map[net.Conn]struct{}
- mu sync.Mutex
-}
-
-func (m *connections) Add(c net.Conn) {
- m.mu.Lock()
- defer m.mu.Unlock()
- m.m[c] = struct{}{}
-}
-
-func (m *connections) Close(c net.Conn) {
- m.mu.Lock()
- defer m.mu.Unlock()
- _ = c.Close()
- delete(m.m, c)
-}
-
-func (m *connections) Size() int {
- m.mu.Lock()
- defer m.mu.Unlock()
- return len(m.m)
-}
-
-func (m *connections) CloseAll() {
- m.mu.Lock()
- defer m.mu.Unlock()
- for c := range m.m {
- _ = c.Close()
- delete(m.m, c)
- }
-}
-
// GitDaemon represents a Git daemon.
type GitDaemon struct {
ctx context.Context
@@ -213,26 +180,53 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
return
}
- gitPack := git.UploadPack
- counter := uploadPackGitCounter
- cmd := string(split[0])
- switch cmd {
- case git.UploadPackBin:
- gitPack = git.UploadPack
- case git.UploadArchiveBin:
- gitPack = git.UploadArchive
+ var handler git.ServiceHandler
+ var counter *prometheus.CounterVec
+ service := git.Service(split[0])
+ switch service {
+ case git.UploadPackService:
+ handler = git.UploadPack
+ counter = uploadPackGitCounter
+ case git.UploadArchiveService:
+ handler = git.UploadArchive
counter = uploadArchiveGitCounter
default:
d.fatal(c, git.ErrInvalidRequest)
return
}
- opts := bytes.Split(split[1], []byte{'\x00'})
- if len(opts) == 0 {
- d.fatal(c, git.ErrInvalidRequest)
+ opts := bytes.SplitN(split[1], []byte{0}, 3)
+ if len(opts) < 2 {
+ d.fatal(c, git.ErrInvalidRequest) // nolint: errcheck
return
}
+ host := strings.TrimPrefix(string(opts[1]), "host=")
+ extraParams := map[string]string{}
+
+ if len(opts) > 2 {
+ buf := bytes.TrimPrefix(opts[2], []byte{0})
+ for _, o := range bytes.Split(buf, []byte{0}) {
+ opt := string(o)
+ if opt == "" {
+ continue
+ }
+
+ kv := strings.SplitN(opt, "=", 2)
+ if len(kv) != 2 {
+ d.logger.Errorf("git: invalid option %q", opt)
+ continue
+ }
+
+ extraParams[kv[0]] = kv[1]
+ }
+
+ version := extraParams["version"]
+ if version != "" {
+ d.logger.Debugf("git: protocol version %s", version)
+ }
+ }
+
be := d.be.WithContext(ctx)
if !be.AllowKeyless() {
d.fatal(c, git.ErrNotAuthed)
@@ -240,14 +234,21 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
}
name := utils.SanitizeRepo(string(opts[0]))
- d.logger.Debugf("git: connect %s %s %s", c.RemoteAddr(), cmd, name)
- defer d.logger.Debugf("git: disconnect %s %s %s", c.RemoteAddr(), cmd, name)
+ d.logger.Debugf("git: connect %s %s %s", c.RemoteAddr(), service, name)
+ defer d.logger.Debugf("git: disconnect %s %s %s", c.RemoteAddr(), service, name)
+
// git bare repositories should end in ".git"
// https://git-scm.com/docs/gitrepository-layout
repo := name + ".git"
reposDir := filepath.Join(d.cfg.DataPath, "repos")
if err := git.EnsureWithin(reposDir, repo); err != nil {
- d.fatal(c, err)
+ d.logger.Debugf("git: error ensuring repo path: %v", err)
+ d.fatal(c, git.ErrInvalidRepo)
+ return
+ }
+
+ if _, err := d.be.Repository(repo); err != nil {
+ d.fatal(c, git.ErrInvalidRepo)
return
}
@@ -261,9 +262,33 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
envs := []string{
"SOFT_SERVE_REPO_NAME=" + name,
"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
+ "SOFT_SERVE_HOST=" + host,
}
- if err := gitPack(ctx, c, c, c, filepath.Join(reposDir, repo), envs...); err != nil {
+ // Add git protocol environment variable.
+ if len(extraParams) > 0 {
+ var gitProto string
+ for k, v := range extraParams {
+ if len(gitProto) > 0 {
+ gitProto += ":"
+ }
+ gitProto += k + "=" + v
+ }
+ envs = append(envs, "GIT_PROTOCOL="+gitProto)
+ }
+
+ envs = append(envs, d.cfg.Environ()...)
+
+ cmd := git.ServiceCommand{
+ Stdin: c,
+ Stdout: c,
+ Stderr: c,
+ Env: envs,
+ Dir: filepath.Join(reposDir, repo),
+ }
+
+ if err := handler(ctx, cmd); err != nil {
+ d.logger.Debugf("git: error handling request: %v", err)
d.fatal(c, err)
return
}
@@ -296,51 +321,3 @@ func (d *GitDaemon) Shutdown(ctx context.Context) error {
return err
}
}
-
-type serverConn struct {
- net.Conn
-
- idleTimeout time.Duration
- maxDeadline time.Time
- closeCanceler context.CancelFunc
-}
-
-func (c *serverConn) Write(p []byte) (n int, err error) {
- c.updateDeadline()
- n, err = c.Conn.Write(p)
- if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
- c.closeCanceler()
- }
- return
-}
-
-func (c *serverConn) Read(b []byte) (n int, err error) {
- c.updateDeadline()
- n, err = c.Conn.Read(b)
- if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil {
- c.closeCanceler()
- }
- return
-}
-
-func (c *serverConn) Close() (err error) {
- err = c.Conn.Close()
- if c.closeCanceler != nil {
- c.closeCanceler()
- }
- return
-}
-
-func (c *serverConn) updateDeadline() {
- switch {
- case c.idleTimeout > 0:
- idleDeadline := time.Now().Add(c.idleTimeout)
- if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
- c.Conn.SetDeadline(idleDeadline)
- return
- }
- fallthrough
- default:
- c.Conn.SetDeadline(c.maxDeadline)
- }
-}
@@ -5,16 +5,12 @@ import (
"errors"
"fmt"
"io"
- "os"
- "os/exec"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
"github.com/charmbracelet/soft-serve/git"
- "github.com/charmbracelet/soft-serve/server/config"
"github.com/go-git/go-git/v5/plumbing/format/pktline"
- "golang.org/x/sync/errgroup"
)
var (
@@ -38,112 +34,6 @@ var (
ErrTimeout = errors.New("I/O timeout reached")
)
-// Git protocol commands.
-const (
- ReceivePackBin = "git-receive-pack"
- UploadPackBin = "git-upload-pack"
- UploadArchiveBin = "git-upload-archive"
-)
-
-// UploadPack runs the git upload-pack protocol against the provided repo.
-func UploadPack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
- exists, err := fileExists(repoDir)
- if !exists {
- return ErrInvalidRepo
- }
- if err != nil {
- return err
- }
- return RunGit(ctx, in, out, er, "", envs, UploadPackBin[4:], repoDir)
-}
-
-// UploadArchive runs the git upload-archive protocol against the provided repo.
-func UploadArchive(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
- exists, err := fileExists(repoDir)
- if !exists {
- return ErrInvalidRepo
- }
- if err != nil {
- return err
- }
- return RunGit(ctx, in, out, er, "", envs, UploadArchiveBin[4:], repoDir)
-}
-
-// ReceivePack runs the git receive-pack protocol against the provided repo.
-func ReceivePack(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoDir string, envs ...string) error {
- if err := RunGit(ctx, in, out, er, "", envs, ReceivePackBin[4:], repoDir); err != nil {
- return err
- }
- return EnsureDefaultBranch(ctx, in, out, er, repoDir)
-}
-
-// RunGit runs a git command in the given repo.
-func RunGit(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, dir string, envs []string, args ...string) error {
- cfg := config.FromContext(ctx)
- logger := log.FromContext(ctx).WithPrefix("rungit")
- c := exec.CommandContext(ctx, "git", args...)
- c.Dir = dir
- c.Env = append(os.Environ(), envs...)
- c.Env = append(c.Env, "PATH="+os.Getenv("PATH"))
- c.Env = append(c.Env, "SOFT_SERVE_DEBUG="+os.Getenv("SOFT_SERVE_DEBUG"))
- if cfg != nil {
- c.Env = append(c.Env, "SOFT_SERVE_LOG_FORMAT="+cfg.Log.Format)
- c.Env = append(c.Env, "SOFT_SERVE_LOG_TIME_FORMAT="+cfg.Log.TimeFormat)
- }
-
- stdin, err := c.StdinPipe()
- if err != nil {
- logger.Error("failed to get stdin pipe", "err", err)
- return err
- }
-
- stdout, err := c.StdoutPipe()
- if err != nil {
- logger.Error("failed to get stdout pipe", "err", err)
- return err
- }
-
- stderr, err := c.StderrPipe()
- if err != nil {
- logger.Error("failed to get stderr pipe", "err", err)
- return err
- }
-
- if err := c.Start(); err != nil {
- logger.Error("failed to start command", "err", err)
- return err
- }
-
- errg, ctx := errgroup.WithContext(ctx)
-
- // stdin
- errg.Go(func() error {
- defer stdin.Close()
-
- _, err := io.Copy(stdin, in)
- return err
- })
-
- // stdout
- errg.Go(func() error {
- _, err := io.Copy(out, stdout)
- return err
- })
-
- // stderr
- errg.Go(func() error {
- _, err := io.Copy(er, stderr)
- return err
- })
-
- if err := errg.Wait(); err != nil {
- logger.Error("while copying output", "err", err)
- }
-
- // Wait for the command to finish
- return c.Wait()
-}
-
// WritePktline encodes and writes a pktline to the given writer.
func WritePktline(w io.Writer, v ...interface{}) {
msg := fmt.Sprintln(v...)
@@ -179,19 +69,10 @@ func EnsureWithin(reposDir string, repo string) error {
return nil
}
-func fileExists(path string) (bool, error) {
- _, err := os.Stat(path)
- if err == nil {
- return true, nil
- }
- if os.IsNotExist(err) {
- return false, nil
- }
- return true, err
-}
-
-func EnsureDefaultBranch(ctx context.Context, in io.Reader, out io.Writer, er io.Writer, repoPath string) error {
- r, err := git.Open(repoPath)
+// EnsureDefaultBranch ensures the repo has a default branch.
+// It will prefer choosing "main" or "master" if available.
+func EnsureDefaultBranch(ctx context.Context, scmd ServiceCommand) error {
+ r, err := git.Open(scmd.Dir)
if err != nil {
return err
}
@@ -205,8 +86,21 @@ func EnsureDefaultBranch(ctx context.Context, in io.Reader, out io.Writer, er io
// Rename the default branch to the first branch available
_, err = r.HEAD()
if err == git.ErrReferenceNotExist {
- err = RunGit(ctx, in, out, er, repoPath, []string{}, "branch", "-M", brs[0])
- if err != nil {
+ branch := brs[0]
+ // Prefer "main" or "master" as the default branch
+ for _, b := range brs {
+ if b == "main" || b == "master" {
+ branch = b
+ break
+ }
+ }
+
+ cmd := git.NewCommand("branch", "-M", branch).WithContext(ctx)
+ if err := cmd.RunInDirWithOptions(scmd.Dir, git.RunInDirOptions{
+ Stdin: scmd.Stdin,
+ Stdout: scmd.Stdout,
+ Stderr: scmd.Stderr,
+ }); err != nil {
return err
}
}
@@ -0,0 +1,186 @@
+package git
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "strings"
+
+ "github.com/charmbracelet/log"
+ "golang.org/x/sync/errgroup"
+)
+
+// Service is a Git daemon service.
+type Service string
+
+const (
+ // UploadPackService is the upload-pack service.
+ UploadPackService Service = "git-upload-pack"
+ // UploadArchiveService is the upload-archive service.
+ UploadArchiveService Service = "git-upload-archive"
+ // ReceivePackService is the receive-pack service.
+ ReceivePackService Service = "git-receive-pack"
+)
+
+// String returns the string representation of the service.
+func (s Service) String() string {
+ return string(s)
+}
+
+// Name returns the name of the service.
+func (s Service) Name() string {
+ return strings.TrimPrefix(s.String(), "git-")
+}
+
+// Handler is the service handler.
+func (s Service) Handler(ctx context.Context, cmd ServiceCommand) error {
+ switch s {
+ case UploadPackService, UploadArchiveService, ReceivePackService:
+ return gitServiceHandler(ctx, s, cmd)
+ default:
+ return fmt.Errorf("unsupported service: %s", s)
+ }
+}
+
+// ServiceHandler is a git service command handler.
+type ServiceHandler func(ctx context.Context, cmd ServiceCommand) error
+
+// gitServiceHandler is the default service handler using the git binary.
+func gitServiceHandler(ctx context.Context, svc Service, scmd ServiceCommand) error {
+ cmd := exec.CommandContext(ctx, "git", "-c", "uploadpack.allowFilter=true", svc.Name()) // nolint: gosec
+ cmd.Dir = scmd.Dir
+ if len(scmd.Args) > 0 {
+ cmd.Args = append(cmd.Args, scmd.Args...)
+ }
+
+ cmd.Args = append(cmd.Args, ".")
+
+ cmd.Env = os.Environ()
+ if len(scmd.Env) > 0 {
+ cmd.Env = append(cmd.Env, scmd.Env...)
+ }
+
+ if scmd.CmdFunc != nil {
+ scmd.CmdFunc(cmd)
+ }
+
+ var (
+ err error
+ stdin io.WriteCloser
+ stdout io.ReadCloser
+ stderr io.ReadCloser
+ )
+
+ if scmd.Stdin != nil {
+ stdin, err = cmd.StdinPipe()
+ if err != nil {
+ return err
+ }
+ }
+
+ if scmd.Stdout != nil {
+ stdout, err = cmd.StdoutPipe()
+ if err != nil {
+ return err
+ }
+ }
+
+ if scmd.Stderr != nil {
+ stderr, err = cmd.StderrPipe()
+ if err != nil {
+ return err
+ }
+ }
+
+ log.Debugf("git service command in %q: %s", cmd.Dir, cmd.String())
+ if err := cmd.Start(); err != nil {
+ return err
+ }
+
+ errg, ctx := errgroup.WithContext(ctx)
+
+ // stdin
+ if scmd.Stdin != nil {
+ errg.Go(func() error {
+ if scmd.StdinHandler != nil {
+ return scmd.StdinHandler(scmd.Stdin, stdin)
+ } else {
+ return defaultStdinHandler(scmd.Stdin, stdin)
+ }
+ })
+ }
+
+ // stdout
+ if scmd.Stdout != nil {
+ errg.Go(func() error {
+ if scmd.StdoutHandler != nil {
+ return scmd.StdoutHandler(scmd.Stdout, stdout)
+ } else {
+ return defaultStdoutHandler(scmd.Stdout, stdout)
+ }
+ })
+ }
+
+ // stderr
+ if scmd.Stderr != nil {
+ errg.Go(func() error {
+ if scmd.StderrHandler != nil {
+ return scmd.StderrHandler(scmd.Stderr, stderr)
+ } else {
+ return defaultStderrHandler(scmd.Stderr, stderr)
+ }
+ })
+ }
+
+ return errors.Join(errg.Wait(), cmd.Wait())
+}
+
+// ServiceCommand is used to run a git service command.
+type ServiceCommand struct {
+ Stdin io.Reader
+ Stdout io.Writer
+ Stderr io.Writer
+ Dir string
+ Env []string
+ Args []string
+
+ // Modifier functions
+ CmdFunc func(*exec.Cmd)
+ StdinHandler func(io.Reader, io.WriteCloser) error
+ StdoutHandler func(io.Writer, io.ReadCloser) error
+ StderrHandler func(io.Writer, io.ReadCloser) error
+}
+
+func defaultStdinHandler(in io.Reader, stdin io.WriteCloser) error {
+ defer stdin.Close() // nolint: errcheck
+ _, err := io.Copy(stdin, in)
+ return err
+}
+
+func defaultStdoutHandler(out io.Writer, stdout io.ReadCloser) error {
+ _, err := io.Copy(out, stdout)
+ return err
+}
+
+func defaultStderrHandler(err io.Writer, stderr io.ReadCloser) error {
+ _, erro := io.Copy(err, stderr)
+ return erro
+}
+
+// UploadPack runs the git upload-pack protocol against the provided repo.
+func UploadPack(ctx context.Context, cmd ServiceCommand) error {
+ return gitServiceHandler(ctx, UploadPackService, cmd)
+}
+
+// UploadArchive runs the git upload-archive protocol against the provided repo.
+func UploadArchive(ctx context.Context, cmd ServiceCommand) error {
+ return gitServiceHandler(ctx, UploadArchiveService, cmd)
+}
+
+// ReceivePack runs the git receive-pack protocol against the provided repo.
+func ReceivePack(ctx context.Context, cmd ServiceCommand) error {
+ return gitServiceHandler(ctx, ReceivePackService, cmd)
+}
@@ -216,13 +216,13 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
return func(s ssh.Session) {
func() {
start := time.Now()
- cmd := s.Command()
+ cmdLine := s.Command()
ctx := s.Context()
be := ss.be.WithContext(ctx)
- if len(cmd) >= 2 && strings.HasPrefix(cmd[0], "git") {
- gc := cmd[0]
+
+ if len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git") {
// repo should be in the form of "repo.git"
- name := utils.SanitizeRepo(cmd[1])
+ name := utils.SanitizeRepo(cmdLine[1])
pk := s.PublicKey()
ak := backend.MarshalAuthorizedKey(pk)
access := cfg.Backend.AccessLevelByPublicKey(name, pk)
@@ -240,12 +240,27 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
"SOFT_SERVE_REPO_NAME=" + name,
"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
"SOFT_SERVE_PUBLIC_KEY=" + ak,
+ "SOFT_SERVE_USERNAME=" + ctx.User(),
}
- ss.logger.Debug("git middleware", "cmd", gc, "access", access.String())
+ // Add ssh session & config environ
+ envs = append(envs, s.Environ()...)
+ envs = append(envs, cfg.Environ()...)
+
repoDir := filepath.Join(reposDir, repo)
- switch gc {
- case git.ReceivePackBin:
+ service := git.Service(cmdLine[0])
+ cmd := git.ServiceCommand{
+ Stdin: s,
+ Stdout: s,
+ Stderr: s.Stderr(),
+ Env: envs,
+ Dir: repoDir,
+ }
+
+ ss.logger.Debug("git middleware", "cmd", service, "access", access.String())
+
+ switch service {
+ case git.ReceivePackService:
receivePackCounter.WithLabelValues(name).Inc()
defer func() {
receivePackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
@@ -262,20 +277,27 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
}
createRepoCounter.WithLabelValues(name).Inc()
}
- if err := git.ReceivePack(s.Context(), s, s, s.Stderr(), repoDir, envs...); err != nil {
+
+ if err := git.ReceivePack(ctx, cmd); err != nil {
+ sshFatal(s, git.ErrSystemMalfunction)
+ }
+
+ if err := git.EnsureDefaultBranch(ctx, cmd); err != nil {
sshFatal(s, git.ErrSystemMalfunction)
}
+
+ receivePackCounter.WithLabelValues(name).Inc()
return
- case git.UploadPackBin, git.UploadArchiveBin:
+ case git.UploadPackService, git.UploadArchiveService:
if access < backend.ReadOnlyAccess {
sshFatal(s, git.ErrNotAuthed)
return
}
- gitPack := git.UploadPack
- switch gc {
- case git.UploadArchiveBin:
- gitPack = git.UploadArchive
+ handler := git.UploadPack
+ switch service {
+ case git.UploadArchiveService:
+ handler = git.UploadArchive
uploadArchiveCounter.WithLabelValues(name).Inc()
defer func() {
uploadArchiveSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
@@ -285,10 +307,9 @@ func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
defer func() {
uploadPackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
}()
-
}
- err := gitPack(ctx, s, s, s.Stderr(), repoDir, envs...)
+ err := handler(ctx, cmd)
if errors.Is(err, git.ErrInvalidRepo) {
sshFatal(s, git.ErrInvalidRepo)
} else if err != nil {
@@ -0,0 +1,459 @@
+package web
+
+import (
+ "bytes"
+ "compress/gzip"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+ "time"
+
+ "github.com/charmbracelet/log"
+ gitb "github.com/charmbracelet/soft-serve/git"
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/git"
+ "github.com/charmbracelet/soft-serve/server/utils"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "goji.io/pat"
+ "goji.io/pattern"
+)
+
+// GitRoute is a route for git services.
+type GitRoute struct {
+ method string
+ pattern *regexp.Regexp
+ handler http.HandlerFunc
+
+ cfg *config.Config
+ be backend.Backend
+ logger *log.Logger
+}
+
+var _ Route = GitRoute{}
+
+// Match implements goji.Pattern.
+func (g GitRoute) Match(r *http.Request) *http.Request {
+ if g.method != r.Method {
+ return nil
+ }
+
+ re := g.pattern
+ ctx := r.Context()
+ if m := re.FindStringSubmatch(r.URL.Path); m != nil {
+ file := strings.Replace(r.URL.Path, m[1]+"/", "", 1)
+ repo := utils.SanitizeRepo(m[1]) + ".git"
+
+ var service git.Service
+ switch {
+ case strings.HasSuffix(r.URL.Path, git.UploadPackService.String()):
+ service = git.UploadPackService
+ case strings.HasSuffix(r.URL.Path, git.ReceivePackService.String()):
+ service = git.ReceivePackService
+ }
+
+ ctx = context.WithValue(ctx, pattern.Variable("service"), service.String())
+ ctx = context.WithValue(ctx, pattern.Variable("dir"), filepath.Join(g.cfg.DataPath, "repos", repo))
+ ctx = context.WithValue(ctx, pattern.Variable("repo"), repo)
+ ctx = context.WithValue(ctx, pattern.Variable("file"), file)
+
+ if g.cfg != nil {
+ ctx = config.WithContext(ctx, g.cfg)
+ }
+
+ if g.be != nil {
+ ctx = backend.WithContext(ctx, g.be.WithContext(ctx))
+ }
+
+ if g.logger != nil {
+ ctx = log.WithContext(ctx, g.logger)
+ }
+
+ return r.WithContext(ctx)
+ }
+
+ return nil
+}
+
+// ServeHTTP implements http.Handler.
+func (g GitRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ g.handler(w, r)
+}
+
+var (
+ gitHttpReceiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
+ Namespace: "soft_serve",
+ Subsystem: "http",
+ Name: "git_receive_pack_total",
+ Help: "The total number of git push requests",
+ }, []string{"repo"})
+
+ gitHttpUploadCounter = promauto.NewCounterVec(prometheus.CounterOpts{
+ Namespace: "soft_serve",
+ Subsystem: "http",
+ Name: "git_upload_pack_total",
+ Help: "The total number of git fetch/pull requests",
+ }, []string{"repo", "file"})
+)
+
+func gitRoutes(ctx context.Context, logger *log.Logger) []Route {
+ routes := make([]Route, 0)
+ cfg := config.FromContext(ctx)
+ be := backend.FromContext(ctx)
+
+ // Git services
+ // These routes don't handle authentication/authorization.
+ // This is handled through wrapping the handlers for each route.
+ // See below (withAccess).
+ // TODO: add lfs support
+ for _, route := range []GitRoute{
+ {
+ pattern: regexp.MustCompile("(.*?)/git-upload-pack$"),
+ method: http.MethodPost,
+ handler: serviceRpc,
+ },
+ {
+ pattern: regexp.MustCompile("(.*?)/git-receive-pack$"),
+ method: http.MethodPost,
+ handler: serviceRpc,
+ },
+ {
+ pattern: regexp.MustCompile("(.*?)/info/refs$"),
+ method: http.MethodGet,
+ handler: getInfoRefs,
+ },
+ {
+ pattern: regexp.MustCompile("(.*?)/HEAD$"),
+ method: http.MethodGet,
+ handler: getTextFile,
+ },
+ {
+ pattern: regexp.MustCompile("(.*?)/objects/info/alternates$"),
+ method: http.MethodGet,
+ handler: getTextFile,
+ },
+ {
+ pattern: regexp.MustCompile("(.*?)/objects/info/http-alternates$"),
+ method: http.MethodGet,
+ handler: getTextFile,
+ },
+ {
+ pattern: regexp.MustCompile("(.*?)/objects/info/packs$"),
+ method: http.MethodGet,
+ handler: getInfoPacks,
+ },
+ {
+ pattern: regexp.MustCompile("(.*?)/objects/info/[^/]*$"),
+ method: http.MethodGet,
+ handler: getTextFile,
+ },
+ {
+ pattern: regexp.MustCompile("(.*?)/objects/[0-9a-f]{2}/[0-9a-f]{38}$"),
+ method: http.MethodGet,
+ handler: getLooseObject,
+ },
+ {
+ pattern: regexp.MustCompile("(.*?)/objects/pack/pack-[0-9a-f]{40}\\.pack$"),
+ method: http.MethodGet,
+ handler: getPackFile,
+ },
+ {
+ pattern: regexp.MustCompile("(.*?)/objects/pack/pack-[0-9a-f]{40}\\.idx$"),
+ method: http.MethodGet,
+ handler: getIdxFile,
+ },
+ } {
+ route.cfg = cfg
+ route.be = be
+ route.logger = logger
+ route.handler = withAccess(route.handler)
+ routes = append(routes, route)
+ }
+
+ return routes
+}
+
+// withAccess handles auth.
+func withAccess(fn http.HandlerFunc) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ be := backend.FromContext(ctx)
+ logger := log.FromContext(ctx)
+
+ if !be.AllowKeyless() {
+ renderForbidden(w)
+ return
+ }
+
+ repo := pat.Param(r, "repo")
+ service := git.Service(pat.Param(r, "service"))
+ access := be.AccessLevel(repo, "")
+
+ switch service {
+ case git.ReceivePackService:
+ if access < backend.ReadWriteAccess {
+ renderUnauthorized(w)
+ return
+ }
+
+ // Create the repo if it doesn't exist.
+ if _, err := be.Repository(repo); err != nil {
+ if _, err := be.CreateRepository(repo, backend.RepositoryOptions{}); err != nil {
+ logger.Error("failed to create repository", "repo", repo, "err", err)
+ renderInternalServerError(w)
+ return
+ }
+ }
+ default:
+ if access < backend.ReadOnlyAccess {
+ renderUnauthorized(w)
+ return
+ }
+ }
+
+ fn(w, r)
+ }
+}
+
+func serviceRpc(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ logger := log.FromContext(ctx)
+ service, dir, repo := git.Service(pat.Param(r, "service")), pat.Param(r, "dir"), pat.Param(r, "repo")
+
+ if !isSmart(r, service) {
+ renderForbidden(w)
+ return
+ }
+
+ if service == git.ReceivePackService {
+ gitHttpReceiveCounter.WithLabelValues(repo)
+ }
+
+ w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-result", service))
+ w.Header().Set("Connection", "Keep-Alive")
+ w.Header().Set("Transfer-Encoding", "chunked")
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.WriteHeader(http.StatusOK)
+
+ version := r.Header.Get("Git-Protocol")
+
+ cmd := git.ServiceCommand{
+ Stdin: r.Body,
+ Stdout: w,
+ Dir: dir,
+ Args: []string{"--stateless-rpc"},
+ }
+
+ if len(version) != 0 {
+ cmd.Env = append(cmd.Env, fmt.Sprintf("GIT_PROTOCOL=%s", version))
+ }
+
+ // Handle gzip encoding
+ cmd.StdinHandler = func(in io.Reader, stdin io.WriteCloser) (err error) {
+ // We know that `in` is an `io.ReadCloser` because it's `r.Body`.
+ reader := in.(io.ReadCloser)
+ defer reader.Close() // nolint: errcheck
+ switch r.Header.Get("Content-Encoding") {
+ case "gzip":
+ reader, err = gzip.NewReader(reader)
+ if err != nil {
+ return err
+ }
+ defer reader.Close() // nolint: errcheck
+ }
+
+ _, err = io.Copy(stdin, reader)
+ return err
+ }
+
+ // Handle buffered output
+ // Useful when using proxies
+ cmd.StdoutHandler = func(out io.Writer, stdout io.ReadCloser) error {
+ // We know that `out` is an `http.ResponseWriter`.
+ flusher, ok := out.(http.Flusher)
+ if !ok {
+ return fmt.Errorf("expected http.ResponseWriter to be an http.Flusher, got %T", out)
+ }
+
+ p := make([]byte, 1024)
+ for {
+ nRead, err := stdout.Read(p)
+ if err == io.EOF {
+ break
+ }
+ nWrite, err := out.Write(p[:nRead])
+ if err != nil {
+ return err
+ }
+ if nRead != nWrite {
+ return fmt.Errorf("failed to write data: %d read, %d written", nRead, nWrite)
+ }
+ flusher.Flush()
+ }
+
+ return nil
+ }
+
+ if err := service.Handler(ctx, cmd); err != nil {
+ logger.Errorf("error executing service: %s", err)
+ }
+}
+
+func getInfoRefs(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ logger := log.FromContext(ctx)
+ dir, repo, file := pat.Param(r, "dir"), pat.Param(r, "repo"), pat.Param(r, "file")
+ service := getServiceType(r)
+ version := r.Header.Get("Git-Protocol")
+
+ gitHttpUploadCounter.WithLabelValues(repo, file).Inc()
+
+ if service != "" && (service == git.UploadPackService || service == git.ReceivePackService) {
+ // Smart HTTP
+ var refs bytes.Buffer
+ cmd := git.ServiceCommand{
+ Stdout: &refs,
+ Dir: dir,
+ Args: []string{"--stateless-rpc", "--advertise-refs"},
+ }
+
+ if len(version) != 0 {
+ cmd.Env = append(cmd.Env, fmt.Sprintf("GIT_PROTOCOL=%s", version))
+ }
+
+ if err := service.Handler(ctx, cmd); err != nil {
+ logger.Errorf("error executing service: %s", err)
+ renderNotFound(w)
+ return
+ }
+
+ hdrNocache(w)
+ w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", service))
+ w.WriteHeader(http.StatusOK)
+ if len(version) == 0 {
+ git.WritePktline(w, "# service="+service.String())
+ }
+
+ w.Write(refs.Bytes()) // nolint: errcheck
+ } else {
+ // Dumb HTTP
+ updateServerInfo(ctx, dir) // nolint: errcheck
+ hdrNocache(w)
+ sendFile("text/plain; charset=utf-8", w, r)
+ }
+}
+
+func getInfoPacks(w http.ResponseWriter, r *http.Request) {
+ hdrCacheForever(w)
+ sendFile("text/plain; charset=utf-8", w, r)
+}
+
+func getLooseObject(w http.ResponseWriter, r *http.Request) {
+ hdrCacheForever(w)
+ sendFile("application/x-git-loose-object", w, r)
+}
+
+func getPackFile(w http.ResponseWriter, r *http.Request) {
+ hdrCacheForever(w)
+ sendFile("application/x-git-packed-objects", w, r)
+}
+
+func getIdxFile(w http.ResponseWriter, r *http.Request) {
+ hdrCacheForever(w)
+ sendFile("application/x-git-packed-objects-toc", w, r)
+}
+
+func getTextFile(w http.ResponseWriter, r *http.Request) {
+ hdrNocache(w)
+ sendFile("text/plain", w, r)
+}
+
+func sendFile(contentType string, w http.ResponseWriter, r *http.Request) {
+ dir, file := pat.Param(r, "dir"), pat.Param(r, "file")
+ reqFile := filepath.Join(dir, file)
+
+ f, err := os.Stat(reqFile)
+ if os.IsNotExist(err) {
+ renderNotFound(w)
+ return
+ }
+
+ w.Header().Set("Content-Type", contentType)
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", f.Size()))
+ w.Header().Set("Last-Modified", f.ModTime().Format(http.TimeFormat))
+ http.ServeFile(w, r, reqFile)
+}
+
+func getServiceType(r *http.Request) git.Service {
+ service := r.FormValue("service")
+ if !strings.HasPrefix(service, "git-") {
+ return ""
+ }
+
+ return git.Service(service)
+}
+
+func isSmart(r *http.Request, service git.Service) bool {
+ if r.Header.Get("Content-Type") == fmt.Sprintf("application/x-%s-request", service) {
+ return true
+ }
+ return false
+}
+
+func updateServerInfo(ctx context.Context, dir string) error {
+ return gitb.UpdateServerInfo(ctx, dir)
+}
+
+// HTTP error response handling functions
+
+func renderMethodNotAllowed(w http.ResponseWriter, r *http.Request) {
+ if r.Proto == "HTTP/1.1" {
+ w.WriteHeader(http.StatusMethodNotAllowed)
+ w.Write([]byte("Method Not Allowed")) // nolint: errcheck
+ } else {
+ w.WriteHeader(http.StatusBadRequest)
+ w.Write([]byte("Bad Request")) // nolint: errcheck
+ }
+}
+
+func renderNotFound(w http.ResponseWriter) {
+ w.WriteHeader(http.StatusNotFound)
+ w.Write([]byte("Not Found")) // nolint: errcheck
+}
+
+func renderUnauthorized(w http.ResponseWriter) {
+ w.WriteHeader(http.StatusUnauthorized)
+ w.Write([]byte("Unauthorized")) // nolint: errcheck
+}
+
+func renderForbidden(w http.ResponseWriter) {
+ w.WriteHeader(http.StatusForbidden)
+ w.Write([]byte("Forbidden")) // nolint: errcheck
+}
+
+func renderInternalServerError(w http.ResponseWriter) {
+ w.WriteHeader(http.StatusInternalServerError)
+ w.Write([]byte("Internal Server Error")) // nolint: errcheck
+}
+
+// Header writing functions
+
+func hdrNocache(w http.ResponseWriter) {
+ w.Header().Set("Expires", "Fri, 01 Jan 1980 00:00:00 GMT")
+ w.Header().Set("Pragma", "no-cache")
+ w.Header().Set("Cache-Control", "no-cache, max-age=0, must-revalidate")
+}
+
+func hdrCacheForever(w http.ResponseWriter) {
+ now := time.Now().Unix()
+ expires := now + 31536000
+ w.Header().Set("Date", fmt.Sprintf("%d", now))
+ w.Header().Set("Expires", fmt.Sprintf("%d", expires))
+ w.Header().Set("Cache-Control", "public, max-age=31536000")
+}
@@ -0,0 +1,94 @@
+package web
+
+import (
+ "net/http"
+ "net/url"
+ "path"
+ "text/template"
+
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/utils"
+ "github.com/prometheus/client_golang/prometheus"
+ "github.com/prometheus/client_golang/prometheus/promauto"
+ "goji.io/pattern"
+)
+
+var goGetCounter = promauto.NewCounterVec(prometheus.CounterOpts{
+ Namespace: "soft_serve",
+ Subsystem: "http",
+ Name: "go_get_total",
+ Help: "The total number of go get requests",
+}, []string{"repo"})
+
+var repoIndexHTMLTpl = template.Must(template.New("index").Parse(`<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+ <meta http-equiv="refresh" content="0; url=https://godoc.org/{{ .ImportRoot }}/{{.Repo}}">
+ <meta name="go-import" content="{{ .ImportRoot }}/{{ .Repo }} git {{ .Config.HTTP.PublicURL }}/{{ .Repo }}">
+</head>
+<body>
+Redirecting to docs at <a href="https://godoc.org/{{ .ImportRoot }}/{{ .Repo }}">godoc.org/{{ .ImportRoot }}/{{ .Repo }}</a>...
+</body>
+</html>`))
+
+// GoGetHandler handles go get requests.
+type GoGetHandler struct {
+ cfg *config.Config
+ be backend.Backend
+}
+
+var _ http.Handler = (*GoGetHandler)(nil)
+
+func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ repo := pattern.Path(r.Context())
+ repo = utils.SanitizeRepo(repo)
+ be := g.be.WithContext(r.Context())
+
+ // Handle go get requests.
+ //
+ // Always return a 200 status code, even if the repo doesn't exist.
+ //
+ // https://golang.org/cmd/go/#hdr-Remote_import_paths
+ // https://go.dev/ref/mod#vcs-branch
+ if r.URL.Query().Get("go-get") == "1" {
+ repo := repo
+ importRoot, err := url.Parse(g.cfg.HTTP.PublicURL)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ // find the repo
+ for {
+ if _, err := be.Repository(repo); err == nil {
+ break
+ }
+
+ if repo == "" || repo == "." || repo == "/" {
+ return
+ }
+
+ repo = path.Dir(repo)
+ }
+
+ if err := repoIndexHTMLTpl.Execute(w, struct {
+ Repo string
+ Config *config.Config
+ ImportRoot string
+ }{
+ Repo: url.PathEscape(repo),
+ Config: g.cfg,
+ ImportRoot: importRoot.Host,
+ }); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ goGetCounter.WithLabelValues(repo).Inc()
+ return
+ }
+
+ http.NotFound(w, r)
+}
@@ -2,103 +2,31 @@ package web
import (
"context"
- "fmt"
"net/http"
- "net/url"
- "path"
- "path/filepath"
- "regexp"
- "strings"
- "text/template"
"time"
- "github.com/charmbracelet/log"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
- "github.com/charmbracelet/soft-serve/server/utils"
- "github.com/dustin/go-humanize"
- "github.com/prometheus/client_golang/prometheus"
- "github.com/prometheus/client_golang/prometheus/promauto"
- "goji.io"
- "goji.io/pat"
- "goji.io/pattern"
)
-var (
- gitHttpCounter = promauto.NewCounterVec(prometheus.CounterOpts{
- Namespace: "soft_serve",
- Subsystem: "http",
- Name: "git_fetch_pull_total",
- Help: "The total number of git fetch/pull requests",
- }, []string{"repo", "file"})
-
- goGetCounter = promauto.NewCounterVec(prometheus.CounterOpts{
- Namespace: "soft_serve",
- Subsystem: "http",
- Name: "go_get_total",
- Help: "The total number of go get requests",
- }, []string{"repo"})
-)
-
-// logWriter is a wrapper around http.ResponseWriter that allows us to capture
-// the HTTP status code and bytes written to the response.
-type logWriter struct {
- http.ResponseWriter
- code, bytes int
-}
-
-func (r *logWriter) Write(p []byte) (int, error) {
- written, err := r.ResponseWriter.Write(p)
- r.bytes += written
- return written, err
-}
-
-// Note this is generally only called when sending an HTTP error, so it's
-// important to set the `code` value to 200 as a default
-func (r *logWriter) WriteHeader(code int) {
- r.code = code
- r.ResponseWriter.WriteHeader(code)
-}
-
-func (s *HTTPServer) loggingMiddleware(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- start := time.Now()
- writer := &logWriter{code: http.StatusOK, ResponseWriter: w}
- s.logger.Debug("request",
- "method", r.Method,
- "uri", r.RequestURI,
- "addr", r.RemoteAddr)
- next.ServeHTTP(writer, r)
- elapsed := time.Since(start)
- s.logger.Debug("response",
- "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)),
- "bytes", humanize.Bytes(uint64(writer.bytes)),
- "time", elapsed)
- })
-}
-
// HTTPServer is an http server.
type HTTPServer struct {
- ctx context.Context
- cfg *config.Config
- be backend.Backend
- server *http.Server
- dirHandler http.Handler
- logger *log.Logger
+ ctx context.Context
+ cfg *config.Config
+ be backend.Backend
+ server *http.Server
}
+// NewHTTPServer creates a new HTTP server.
func NewHTTPServer(ctx context.Context) (*HTTPServer, error) {
cfg := config.FromContext(ctx)
- mux := goji.NewMux()
s := &HTTPServer{
- ctx: ctx,
- cfg: cfg,
- be: backend.FromContext(ctx),
- logger: log.FromContext(ctx).WithPrefix("http"),
- dirHandler: http.FileServer(http.Dir(filepath.Join(cfg.DataPath, "repos"))),
+ ctx: ctx,
+ cfg: cfg,
+ be: backend.FromContext(ctx),
server: &http.Server{
Addr: cfg.HTTP.ListenAddr,
- Handler: mux,
+ Handler: NewRouter(ctx),
ReadHeaderTimeout: time.Second * 10,
ReadTimeout: time.Second * 10,
WriteTimeout: time.Second * 10,
@@ -106,21 +34,6 @@ func NewHTTPServer(ctx context.Context) (*HTTPServer, error) {
},
}
- mux.Use(s.loggingMiddleware)
- for _, m := range []Matcher{
- getInfoRefs,
- getHead,
- getAlternates,
- getHTTPAlternates,
- getInfoPacks,
- getInfoFile,
- getLooseObject,
- getPackFile,
- getIdxFile,
- } {
- mux.HandleFunc(NewPattern(m), s.handleGit)
- }
- mux.HandleFunc(pat.Get("/*"), s.handleIndex)
return s, nil
}
@@ -141,193 +54,3 @@ func (s *HTTPServer) ListenAndServe() error {
func (s *HTTPServer) Shutdown(ctx context.Context) error {
return s.server.Shutdown(ctx)
}
-
-// Pattern is a pattern for matching a URL.
-// It matches against GET requests.
-type Pattern struct {
- match func(*url.URL) *match
-}
-
-// NewPattern returns a new Pattern with the given matcher.
-func NewPattern(m Matcher) *Pattern {
- return &Pattern{
- match: m,
- }
-}
-
-// Match is a match for a URL.
-//
-// It implements goji.Pattern.
-func (p *Pattern) Match(r *http.Request) *http.Request {
- if r.Method != "GET" {
- return nil
- }
-
- if m := p.match(r.URL); m != nil {
- ctx := context.WithValue(r.Context(), pattern.Variable("repo"), m.RepoPath)
- ctx = context.WithValue(ctx, pattern.Variable("file"), m.FilePath)
- return r.WithContext(ctx)
- }
- return nil
-}
-
-// Matcher finds a match in a *url.URL.
-type Matcher = func(*url.URL) *match
-
-var (
- getInfoRefs = func(u *url.URL) *match {
- return matchSuffix(u.Path, "/info/refs")
- }
-
- getHead = func(u *url.URL) *match {
- return matchSuffix(u.Path, "/HEAD")
- }
-
- getAlternates = func(u *url.URL) *match {
- return matchSuffix(u.Path, "/objects/info/alternates")
- }
-
- getHTTPAlternates = func(u *url.URL) *match {
- return matchSuffix(u.Path, "/objects/info/http-alternates")
- }
-
- getInfoPacks = func(u *url.URL) *match {
- return matchSuffix(u.Path, "/objects/info/packs")
- }
-
- getInfoFileRegexp = regexp.MustCompile(".*?(/objects/info/[^/]*)$")
- getInfoFile = func(u *url.URL) *match {
- return findStringSubmatch(u.Path, getInfoFileRegexp)
- }
-
- getLooseObjectRegexp = regexp.MustCompile(".*?(/objects/[0-9a-f]{2}/[0-9a-f]{38})$")
- getLooseObject = func(u *url.URL) *match {
- return findStringSubmatch(u.Path, getLooseObjectRegexp)
- }
-
- getPackFileRegexp = regexp.MustCompile(`.*?(/objects/pack/pack-[0-9a-f]{40}\.pack)$`)
- getPackFile = func(u *url.URL) *match {
- return findStringSubmatch(u.Path, getPackFileRegexp)
- }
-
- getIdxFileRegexp = regexp.MustCompile(`.*?(/objects/pack/pack-[0-9a-f]{40}\.idx)$`)
- getIdxFile = func(u *url.URL) *match {
- return findStringSubmatch(u.Path, getIdxFileRegexp)
- }
-)
-
-// match represents a match for a URL.
-type match struct {
- RepoPath, FilePath string
-}
-
-func matchSuffix(path, suffix string) *match {
- if !strings.HasSuffix(path, suffix) {
- return nil
- }
- repoPath := strings.Replace(path, suffix, "", 1)
- filePath := strings.Replace(path, repoPath+"/", "", 1)
- return &match{repoPath, filePath}
-}
-
-func findStringSubmatch(path string, prefix *regexp.Regexp) *match {
- m := prefix.FindStringSubmatch(path)
- if m == nil {
- return nil
- }
- suffix := m[1]
- repoPath := strings.Replace(path, suffix, "", 1)
- filePath := strings.Replace(path, repoPath+"/", "", 1)
- return &match{repoPath, filePath}
-}
-
-var repoIndexHTMLTpl = template.Must(template.New("index").Parse(`<!DOCTYPE html>
-<html lang="en">
-<head>
- <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
- <meta http-equiv="refresh" content="0; url=https://godoc.org/{{ .ImportRoot }}/{{.Repo}}">
- <meta name="go-import" content="{{ .ImportRoot }}/{{ .Repo }} git {{ .Config.HTTP.PublicURL }}/{{ .Repo }}">
-</head>
-<body>
-Redirecting to docs at <a href="https://godoc.org/{{ .ImportRoot }}/{{ .Repo }}">godoc.org/{{ .ImportRoot }}/{{ .Repo }}</a>...
-</body>
-</html>`))
-
-func (s *HTTPServer) handleIndex(w http.ResponseWriter, r *http.Request) {
- repo := pattern.Path(r.Context())
- repo = utils.SanitizeRepo(repo)
- be := s.be.WithContext(r.Context())
-
- // Handle go get requests.
- //
- // Always return a 200 status code, even if the repo doesn't exist.
- //
- // https://golang.org/cmd/go/#hdr-Remote_import_paths
- // https://go.dev/ref/mod#vcs-branch
- if r.URL.Query().Get("go-get") == "1" {
- repo := repo
- importRoot, err := url.Parse(s.cfg.HTTP.PublicURL)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- // find the repo
- for {
- if _, err := be.Repository(repo); err == nil {
- break
- }
-
- if repo == "" || repo == "." || repo == "/" {
- return
- }
-
- repo = path.Dir(repo)
- }
-
- if err := repoIndexHTMLTpl.Execute(w, struct {
- Repo string
- Config *config.Config
- ImportRoot string
- }{
- Repo: url.PathEscape(repo),
- Config: s.cfg,
- ImportRoot: importRoot.Host,
- }); err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- goGetCounter.WithLabelValues(repo).Inc()
- return
- }
-
- http.NotFound(w, r)
-}
-
-func (s *HTTPServer) handleGit(w http.ResponseWriter, r *http.Request) {
- repo := pat.Param(r, "repo")
- repo = utils.SanitizeRepo(repo) + ".git"
- be := s.be.WithContext(r.Context())
- if _, err := be.Repository(repo); err != nil {
- s.logger.Debug("repository not found", "repo", repo, "err", err)
- http.NotFound(w, r)
- return
- }
-
- if !s.cfg.Backend.AllowKeyless() {
- http.Error(w, "Forbidden", http.StatusForbidden)
- return
- }
-
- access := s.cfg.Backend.AccessLevel(repo, "")
- if access < backend.ReadOnlyAccess {
- http.Error(w, "Unauthorized", http.StatusUnauthorized)
- return
- }
-
- file := pat.Param(r, "file")
- gitHttpCounter.WithLabelValues(repo, file).Inc()
- r.URL.Path = fmt.Sprintf("/%s/%s", repo, file)
- s.dirHandler.ServeHTTP(w, r)
-}
@@ -0,0 +1,84 @@
+package web
+
+import (
+ "bufio"
+ "fmt"
+ "net"
+ "net/http"
+ "time"
+
+ "github.com/charmbracelet/log"
+ "github.com/dustin/go-humanize"
+)
+
+// logWriter is a wrapper around http.ResponseWriter that allows us to capture
+// the HTTP status code and bytes written to the response.
+type logWriter struct {
+ http.ResponseWriter
+ code, bytes int
+}
+
+var _ http.ResponseWriter = (*logWriter)(nil)
+
+var _ http.Flusher = (*logWriter)(nil)
+
+var _ http.Hijacker = (*logWriter)(nil)
+
+var _ http.CloseNotifier = (*logWriter)(nil)
+
+// Write implements http.ResponseWriter.
+func (r *logWriter) Write(p []byte) (int, error) {
+ written, err := r.ResponseWriter.Write(p)
+ r.bytes += written
+ return written, err
+}
+
+// Note this is generally only called when sending an HTTP error, so it's
+// important to set the `code` value to 200 as a default.
+func (r *logWriter) WriteHeader(code int) {
+ r.code = code
+ r.ResponseWriter.WriteHeader(code)
+}
+
+// Flush implements http.Flusher.
+func (r *logWriter) Flush() {
+ if f, ok := r.ResponseWriter.(http.Flusher); ok {
+ f.Flush()
+ }
+}
+
+// CloseNotify implements http.CloseNotifier.
+func (r *logWriter) CloseNotify() <-chan bool {
+ if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok {
+ return cn.CloseNotify()
+ }
+ return nil
+}
+
+// Hijack implements http.Hijacker.
+func (r *logWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ if h, ok := r.ResponseWriter.(http.Hijacker); ok {
+ return h.Hijack()
+ }
+ return nil, nil, fmt.Errorf("http.Hijacker not implemented")
+}
+
+// NewLoggingMiddleware returns a new logging middleware.
+func NewLoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ start := time.Now()
+ writer := &logWriter{code: http.StatusOK, ResponseWriter: w}
+ logger.Debug("request",
+ "method", r.Method,
+ "uri", r.RequestURI,
+ "addr", r.RemoteAddr)
+ next.ServeHTTP(writer, r)
+ elapsed := time.Since(start)
+ logger.Debug("response",
+ "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)),
+ "bytes", humanize.Bytes(uint64(writer.bytes)),
+ "time", elapsed)
+ })
+ }
+}
@@ -0,0 +1,40 @@
+// Package server is the reusable server
+package web
+
+import (
+ "context"
+ "net/http"
+
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "goji.io"
+ "goji.io/pat"
+)
+
+// Route is an interface for a route.
+type Route interface {
+ http.Handler
+ goji.Pattern
+}
+
+// NewRouter returns a new HTTP router.
+func NewRouter(ctx context.Context) *goji.Mux {
+ mux := goji.NewMux()
+ cfg := config.FromContext(ctx)
+ be := backend.FromContext(ctx)
+ logger := log.FromContext(ctx).WithPrefix("http")
+
+ // Middlewares
+ mux.Use(NewLoggingMiddleware(logger))
+
+ // Git routes
+ for _, service := range gitRoutes(ctx, logger) {
+ mux.Handle(service, service)
+ }
+
+ // go-get handler
+ mux.Handle(pat.Get("/*"), GoGetHandler{cfg, be})
+
+ return mux
+}