1package main
  2
  3import (
  4	"bufio"
  5	"fmt"
  6	"os"
  7	"path/filepath"
  8	"strings"
  9
 10	"github.com/charmbracelet/keygen"
 11	"github.com/charmbracelet/soft-serve/server/config"
 12	"github.com/spf13/cobra"
 13	gossh "golang.org/x/crypto/ssh"
 14)
 15
 16var (
 17	configPath string
 18
 19	hookCmd = &cobra.Command{
 20		Use:    "hook",
 21		Short:  "Run git server hooks",
 22		Long:   "Handles git server hooks. This includes pre-receive, update, and post-receive.",
 23		Hidden: true,
 24	}
 25
 26	preReceiveCmd = &cobra.Command{
 27		Use:   "pre-receive",
 28		Short: "Run git pre-receive hook",
 29		RunE: func(cmd *cobra.Command, args []string) error {
 30			c, s, err := commonInit()
 31			if err != nil {
 32				return err
 33			}
 34			defer c.Close() //nolint:errcheck
 35			defer s.Close() //nolint:errcheck
 36			in, err := s.StdinPipe()
 37			if err != nil {
 38				return err
 39			}
 40			scanner := bufio.NewScanner(os.Stdin)
 41			for scanner.Scan() {
 42				in.Write([]byte(scanner.Text()))
 43				in.Write([]byte("\n"))
 44			}
 45			in.Close() //nolint:errcheck
 46			b, err := s.Output("hook pre-receive")
 47			if err != nil {
 48				return err
 49			}
 50			cmd.Print(string(b))
 51			return nil
 52		},
 53	}
 54
 55	updateCmd = &cobra.Command{
 56		Use:   "update",
 57		Short: "Run git update hook",
 58		Args:  cobra.ExactArgs(3),
 59		RunE: func(cmd *cobra.Command, args []string) error {
 60			refName := args[0]
 61			oldSha := args[1]
 62			newSha := args[2]
 63			c, s, err := commonInit()
 64			if err != nil {
 65				return err
 66			}
 67			defer c.Close() //nolint:errcheck
 68			defer s.Close() //nolint:errcheck
 69			b, err := s.Output(fmt.Sprintf("hook update %s %s %s", refName, oldSha, newSha))
 70			if err != nil {
 71				return err
 72			}
 73			cmd.Print(string(b))
 74			return nil
 75		},
 76	}
 77
 78	postReceiveCmd = &cobra.Command{
 79		Use:   "post-receive",
 80		Short: "Run git post-receive hook",
 81		RunE: func(cmd *cobra.Command, args []string) error {
 82			c, s, err := commonInit()
 83			if err != nil {
 84				return err
 85			}
 86			defer c.Close() //nolint:errcheck
 87			defer s.Close() //nolint:errcheck
 88			in, err := s.StdinPipe()
 89			if err != nil {
 90				return err
 91			}
 92			scanner := bufio.NewScanner(os.Stdin)
 93			for scanner.Scan() {
 94				in.Write([]byte(scanner.Text()))
 95				in.Write([]byte("\n"))
 96			}
 97			in.Close() //nolint:errcheck
 98			b, err := s.Output("hook post-receive")
 99			if err != nil {
100				return err
101			}
102			cmd.Print(string(b))
103			return nil
104		},
105	}
106
107	postUpdateCmd = &cobra.Command{
108		Use:   "post-update",
109		Short: "Run git post-update hook",
110		RunE: func(cmd *cobra.Command, args []string) error {
111			c, s, err := commonInit()
112			if err != nil {
113				return err
114			}
115			defer c.Close() //nolint:errcheck
116			defer s.Close() //nolint:errcheck
117			b, err := s.Output(fmt.Sprintf("hook post-update %s", strings.Join(args, " ")))
118			if err != nil {
119				return err
120			}
121			cmd.Print(string(b))
122			return nil
123		},
124	}
125)
126
127func init() {
128	hookCmd.AddCommand(
129		preReceiveCmd,
130		updateCmd,
131		postReceiveCmd,
132		postUpdateCmd,
133	)
134
135	hookCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to config file")
136}
137
138func commonInit() (c *gossh.Client, s *gossh.Session, err error) {
139	cfg, err := config.ParseConfig(configPath)
140	if err != nil {
141		return
142	}
143
144	// Use absolute path.
145	cfg.DataPath = filepath.Dir(configPath)
146
147	// Git runs the hook within the repository's directory.
148	// Get the working directory to determine the repository name.
149	wd, err := os.Getwd()
150	if err != nil {
151		return
152	}
153
154	rs, err := filepath.Abs(filepath.Join(cfg.DataPath, "repos"))
155	if err != nil {
156		return
157	}
158
159	if !strings.HasPrefix(wd, rs) {
160		err = fmt.Errorf("hook must be run from within repository directory")
161		return
162	}
163	repoName := strings.TrimPrefix(wd, rs)
164	repoName = strings.TrimPrefix(repoName, fmt.Sprintf("%c", os.PathSeparator))
165	c, err = newClient(cfg)
166	if err != nil {
167		return
168	}
169	s, err = newSession(c)
170	if err != nil {
171		return
172	}
173	s.Setenv("SOFT_SERVE_REPO_NAME", repoName)
174	return
175}
176
177func newClient(cfg *config.Config) (*gossh.Client, error) {
178	// Only accept the server's host key.
179	pk, err := keygen.New(cfg.SSH.KeyPath, nil, keygen.Ed25519)
180	if err != nil {
181		return nil, err
182	}
183	hostKey, err := gossh.ParsePrivateKey(pk.PrivateKeyPEM())
184	if err != nil {
185		return nil, err
186	}
187	ik, err := keygen.New(cfg.SSH.InternalKeyPath, nil, keygen.Ed25519)
188	if err != nil {
189		return nil, err
190	}
191	k, err := gossh.ParsePrivateKey(ik.PrivateKeyPEM())
192	if err != nil {
193		return nil, err
194	}
195	cc := &gossh.ClientConfig{
196		User: "internal",
197		Auth: []gossh.AuthMethod{
198			gossh.PublicKeys(k),
199		},
200		HostKeyCallback: gossh.FixedHostKey(hostKey.PublicKey()),
201	}
202	c, err := gossh.Dial("tcp", cfg.SSH.ListenAddr, cc)
203	if err != nil {
204		return nil, err
205	}
206	return c, nil
207}
208
209func newSession(c *gossh.Client) (*gossh.Session, error) {
210	s, err := c.NewSession()
211	if err != nil {
212		return nil, err
213	}
214	return s, nil
215}