1package cmd
  2
  3import (
  4	"bufio"
  5	"fmt"
  6	"path/filepath"
  7	"strings"
  8
  9	"github.com/charmbracelet/keygen"
 10	"github.com/charmbracelet/soft-serve/server/hooks"
 11	"github.com/charmbracelet/ssh"
 12	"github.com/spf13/cobra"
 13	gossh "golang.org/x/crypto/ssh"
 14)
 15
 16// hookCommand handles Soft Serve internal API git hook requests.
 17func hookCommand() *cobra.Command {
 18	preReceiveCmd := &cobra.Command{
 19		Use:               "pre-receive",
 20		Short:             "Run git pre-receive hook",
 21		PersistentPreRunE: checkIfInternal,
 22		RunE: func(cmd *cobra.Command, args []string) error {
 23			_, s := fromContext(cmd)
 24			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
 25			repoName := getRepoName(s)
 26			opts := make([]hooks.HookArg, 0)
 27			scanner := bufio.NewScanner(s)
 28			for scanner.Scan() {
 29				fields := strings.Fields(scanner.Text())
 30				if len(fields) != 3 {
 31					return fmt.Errorf("invalid pre-receive hook input: %s", scanner.Text())
 32				}
 33				opts = append(opts, hooks.HookArg{
 34					OldSha:  fields[0],
 35					NewSha:  fields[1],
 36					RefName: fields[2],
 37				})
 38			}
 39			hks.PreReceive(s, s.Stderr(), repoName, opts)
 40			return nil
 41		},
 42	}
 43
 44	updateCmd := &cobra.Command{
 45		Use:               "update",
 46		Short:             "Run git update hook",
 47		Args:              cobra.ExactArgs(3),
 48		PersistentPreRunE: checkIfInternal,
 49		RunE: func(cmd *cobra.Command, args []string) error {
 50			_, s := fromContext(cmd)
 51			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
 52			repoName := getRepoName(s)
 53			hks.Update(s, s.Stderr(), repoName, hooks.HookArg{
 54				RefName: args[0],
 55				OldSha:  args[1],
 56				NewSha:  args[2],
 57			})
 58			return nil
 59		},
 60	}
 61
 62	postReceiveCmd := &cobra.Command{
 63		Use:               "post-receive",
 64		Short:             "Run git post-receive hook",
 65		PersistentPreRunE: checkIfInternal,
 66		RunE: func(cmd *cobra.Command, _ []string) error {
 67			_, s := fromContext(cmd)
 68			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
 69			repoName := getRepoName(s)
 70			opts := make([]hooks.HookArg, 0)
 71			scanner := bufio.NewScanner(s)
 72			for scanner.Scan() {
 73				fields := strings.Fields(scanner.Text())
 74				if len(fields) != 3 {
 75					return fmt.Errorf("invalid post-receive hook input: %s", scanner.Text())
 76				}
 77				opts = append(opts, hooks.HookArg{
 78					OldSha:  fields[0],
 79					NewSha:  fields[1],
 80					RefName: fields[2],
 81				})
 82			}
 83			hks.PostReceive(s, s.Stderr(), repoName, opts)
 84			return nil
 85		},
 86	}
 87
 88	postUpdateCmd := &cobra.Command{
 89		Use:               "post-update",
 90		Short:             "Run git post-update hook",
 91		PersistentPreRunE: checkIfInternal,
 92		RunE: func(cmd *cobra.Command, args []string) error {
 93			_, s := fromContext(cmd)
 94			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
 95			repoName := getRepoName(s)
 96			hks.PostUpdate(s, s.Stderr(), repoName, args...)
 97			return nil
 98		},
 99	}
100
101	hookCmd := &cobra.Command{
102		Use:          "hook",
103		Short:        "Run git server hooks",
104		Hidden:       true,
105		SilenceUsage: true,
106	}
107
108	hookCmd.AddCommand(
109		preReceiveCmd,
110		updateCmd,
111		postReceiveCmd,
112		postUpdateCmd,
113	)
114
115	return hookCmd
116}
117
118// Check if the session's public key matches the internal API key.
119func checkIfInternal(cmd *cobra.Command, _ []string) error {
120	cfg, s := fromContext(cmd)
121	pk := s.PublicKey()
122	kp, err := keygen.New(filepath.Join(cfg.DataPath, cfg.SSH.InternalKeyPath), nil, keygen.Ed25519)
123	if err != nil {
124		logger.Errorf("failed to read internal key: %v", err)
125		return err
126	}
127	priv, err := gossh.ParsePrivateKey(kp.PrivateKeyPEM())
128	if err != nil {
129		return err
130	}
131	if !ssh.KeysEqual(pk, priv.PublicKey()) {
132		return ErrUnauthorized
133	}
134	return nil
135}
136
137func getRepoName(s ssh.Session) string {
138	var repoName string
139	for _, env := range s.Environ() {
140		if strings.HasPrefix(env, "SOFT_SERVE_REPO_NAME=") {
141			return strings.TrimPrefix(env, "SOFT_SERVE_REPO_NAME=")
142		}
143	}
144	return repoName
145}