1package cmd
  2
import (
	"bufio"
	"fmt"
	"io"
	"strings"

	"github.com/charmbracelet/keygen"
	"github.com/charmbracelet/soft-serve/server/hooks"
	"github.com/charmbracelet/ssh"
	"github.com/spf13/cobra"
	gossh "golang.org/x/crypto/ssh"
)
 14
 15// hookCommand handles Soft Serve internal API git hook requests.
 16func hookCommand() *cobra.Command {
 17	preReceiveCmd := &cobra.Command{
 18		Use:               "pre-receive",
 19		Short:             "Run git pre-receive hook",
 20		PersistentPreRunE: checkIfInternal,
 21		RunE: func(cmd *cobra.Command, args []string) error {
 22			_, s := fromContext(cmd)
 23			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
 24			repoName := getRepoName(s)
 25			opts := make([]hooks.HookArg, 0)
 26			scanner := bufio.NewScanner(s)
 27			for scanner.Scan() {
 28				fields := strings.Fields(scanner.Text())
 29				if len(fields) != 3 {
 30					return fmt.Errorf("invalid pre-receive hook input: %s", scanner.Text())
 31				}
 32				opts = append(opts, hooks.HookArg{
 33					OldSha:  fields[0],
 34					NewSha:  fields[1],
 35					RefName: fields[2],
 36				})
 37			}
 38			hks.PreReceive(s, s.Stderr(), repoName, opts)
 39			return nil
 40		},
 41	}
 42
 43	updateCmd := &cobra.Command{
 44		Use:               "update",
 45		Short:             "Run git update hook",
 46		Args:              cobra.ExactArgs(3),
 47		PersistentPreRunE: checkIfInternal,
 48		RunE: func(cmd *cobra.Command, args []string) error {
 49			_, s := fromContext(cmd)
 50			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
 51			repoName := getRepoName(s)
 52			hks.Update(s, s.Stderr(), repoName, hooks.HookArg{
 53				RefName: args[0],
 54				OldSha:  args[1],
 55				NewSha:  args[2],
 56			})
 57			return nil
 58		},
 59	}
 60
 61	postReceiveCmd := &cobra.Command{
 62		Use:               "post-receive",
 63		Short:             "Run git post-receive hook",
 64		PersistentPreRunE: checkIfInternal,
 65		RunE: func(cmd *cobra.Command, _ []string) error {
 66			_, s := fromContext(cmd)
 67			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
 68			repoName := getRepoName(s)
 69			opts := make([]hooks.HookArg, 0)
 70			scanner := bufio.NewScanner(s)
 71			for scanner.Scan() {
 72				fields := strings.Fields(scanner.Text())
 73				if len(fields) != 3 {
 74					return fmt.Errorf("invalid post-receive hook input: %s", scanner.Text())
 75				}
 76				opts = append(opts, hooks.HookArg{
 77					OldSha:  fields[0],
 78					NewSha:  fields[1],
 79					RefName: fields[2],
 80				})
 81			}
 82			hks.PostReceive(s, s.Stderr(), repoName, opts)
 83			return nil
 84		},
 85	}
 86
 87	postUpdateCmd := &cobra.Command{
 88		Use:               "post-update",
 89		Short:             "Run git post-update hook",
 90		PersistentPreRunE: checkIfInternal,
 91		RunE: func(cmd *cobra.Command, args []string) error {
 92			_, s := fromContext(cmd)
 93			hks := cmd.Context().Value(HooksCtxKey).(hooks.Hooks)
 94			repoName := getRepoName(s)
 95			hks.PostUpdate(s, s.Stderr(), repoName, args...)
 96			return nil
 97		},
 98	}
 99
100	hookCmd := &cobra.Command{
101		Use:          "hook",
102		Short:        "Run git server hooks",
103		Hidden:       true,
104		SilenceUsage: true,
105	}
106
107	hookCmd.AddCommand(
108		preReceiveCmd,
109		updateCmd,
110		postReceiveCmd,
111		postUpdateCmd,
112	)
113
114	return hookCmd
115}
116
117// Check if the session's public key matches the internal API key.
118func checkIfInternal(cmd *cobra.Command, _ []string) error {
119	cfg, s := fromContext(cmd)
120	pk := s.PublicKey()
121	kp, err := keygen.New(cfg.SSH.InternalKeyPath, nil, keygen.Ed25519)
122	if err != nil {
123		logger.Errorf("failed to read internal key: %v", err)
124		return err
125	}
126	priv, err := gossh.ParsePrivateKey(kp.PrivateKeyPEM())
127	if err != nil {
128		return err
129	}
130	if !ssh.KeysEqual(pk, priv.PublicKey()) {
131		return ErrUnauthorized
132	}
133	return nil
134}
135
136func getRepoName(s ssh.Session) string {
137	var repoName string
138	for _, env := range s.Environ() {
139		if strings.HasPrefix(env, "SOFT_SERVE_REPO_NAME=") {
140			return strings.TrimPrefix(env, "SOFT_SERVE_REPO_NAME=")
141		}
142	}
143	return repoName
144}