feat(server): create ssh client keypair

Ayman Bagabas created

Change summary

cmd/soft/migrate_config.go      | 121 ++--------------------------------
server/backend/sqlite/sqlite.go |  24 ++----
server/config/config.go         |   9 ++
server/config/file.go           |   5 +
server/daemon_test.go           |   7 +
server/jobs.go                  |  14 +++
server/server.go                |  14 +++
server/session_test.go          |   5 
8 files changed, 60 insertions(+), 139 deletions(-)

Detailed changes

cmd/soft/migrate_config.go 🔗

@@ -23,12 +23,12 @@ var (
 	migrateConfig = &cobra.Command{
 		Use:   "migrate-config",
 		Short: "Migrate config to new format",
-		RunE: func(cmd *cobra.Command, args []string) error {
+		RunE: func(_ *cobra.Command, _ []string) error {
 			keyPath := os.Getenv("SOFT_SERVE_KEY_PATH")
 			reposPath := os.Getenv("SOFT_SERVE_REPO_PATH")
 			bindAddr := os.Getenv("SOFT_SERVE_BIND_ADDRESS")
 			cfg := config.DefaultConfig()
-			sb, err := sqlite.NewSqliteBackend(cfg.DataPath)
+			sb, err := sqlite.NewSqliteBackend(cfg)
 			if err != nil {
 				return fmt.Errorf("failed to create sqlite backend: %w", err)
 			}
@@ -72,7 +72,7 @@ var (
 				return fmt.Errorf("failed to get tree: %w", err)
 			}
 
-			isJson := false
+			isJson := false // nolint: revive
 			te, err := tree.TreeEntry("config.yaml")
 			if err != nil {
 				te, err = tree.TreeEntry("config.json")
@@ -236,7 +236,7 @@ func isGitDir(path string) bool {
 	return true
 }
 
-// copyFile copies a single file from src to dst
+// copyFile copies a single file from src to dst.
 func copyFile(src, dst string) error {
 	var err error
 	var srcfd *os.File
@@ -246,12 +246,12 @@ func copyFile(src, dst string) error {
 	if srcfd, err = os.Open(src); err != nil {
 		return err
 	}
-	defer srcfd.Close()
+	defer srcfd.Close() // nolint: errcheck
 
 	if dstfd, err = os.Create(dst); err != nil {
 		return err
 	}
-	defer dstfd.Close()
+	defer dstfd.Close() // nolint: errcheck
 
 	if _, err = io.Copy(dstfd, srcfd); err != nil {
 		return err
@@ -262,7 +262,7 @@ func copyFile(src, dst string) error {
 	return os.Chmod(dst, srcinfo.Mode())
 }
 
-// copyDir copies a whole directory recursively
+// copyDir copies a whole directory recursively.
 func copyDir(src string, dst string) error {
 	var err error
 	var fds []os.DirEntry
@@ -296,112 +296,7 @@ func copyDir(src string, dst string) error {
 	return nil
 }
 
-// func copyDir(src, dst string) error {
-// 	entries, err := os.ReadDir(src)
-// 	if err != nil {
-// 		return err
-// 	}
-// 	for _, entry := range entries {
-// 		sourcePath := filepath.Join(src, entry.Name())
-// 		destPath := filepath.Join(dst, entry.Name())
-//
-// 		fileInfo, err := os.Stat(sourcePath)
-// 		if err != nil {
-// 			return err
-// 		}
-//
-// 		stat, ok := fileInfo.Sys().(*syscall.Stat_t)
-// 		if !ok {
-// 			return fmt.Errorf("failed to get raw syscall.Stat_t data for '%s'", sourcePath)
-// 		}
-//
-// 		switch fileInfo.Mode() & os.ModeType {
-// 		case os.ModeDir:
-// 			if err := createIfNotExists(destPath, 0755); err != nil {
-// 				return err
-// 			}
-// 			if err := copyDir(sourcePath, destPath); err != nil {
-// 				return err
-// 			}
-// 		case os.ModeSymlink:
-// 			if err := copySymLink(sourcePath, destPath); err != nil {
-// 				return err
-// 			}
-// 		default:
-// 			if err := copyFile(sourcePath, destPath); err != nil {
-// 				return err
-// 			}
-// 		}
-//
-// 		if err := os.Lchown(destPath, int(stat.Uid), int(stat.Gid)); err != nil {
-// 			return err
-// 		}
-//
-// 		fInfo, err := entry.Info()
-// 		if err != nil {
-// 			return err
-// 		}
-//
-// 		isSymlink := fInfo.Mode()&os.ModeSymlink != 0
-// 		if !isSymlink {
-// 			if err := os.Chmod(destPath, fInfo.Mode()); err != nil {
-// 				return err
-// 			}
-// 		}
-// 	}
-// 	return nil
-// }
-//
-// func copyFile(srcFile, dstFile string) error {
-// 	out, err := os.Create(dstFile)
-// 	if err != nil {
-// 		return err
-// 	}
-//
-// 	defer out.Close()
-//
-// 	in, err := os.Open(srcFile)
-// 	defer in.Close()
-// 	if err != nil {
-// 		return err
-// 	}
-//
-// 	_, err = io.Copy(out, in)
-// 	if err != nil {
-// 		return err
-// 	}
-//
-// 	return nil
-// }
-
-func exists(filePath string) bool {
-	if _, err := os.Stat(filePath); os.IsNotExist(err) {
-		return false
-	}
-
-	return true
-}
-
-func createIfNotExists(dir string, perm os.FileMode) error {
-	if exists(dir) {
-		return nil
-	}
-
-	if err := os.MkdirAll(dir, perm); err != nil {
-		return fmt.Errorf("failed to create directory: '%s', error: '%s'", dir, err.Error())
-	}
-
-	return nil
-}
-
-func copySymLink(source, dest string) error {
-	link, err := os.Readlink(source)
-	if err != nil {
-		return err
-	}
-	return os.Symlink(link, dest)
-}
-
+// Config is the configuration for the server.
 type Config struct {
 	Name         string       `yaml:"name" json:"name"`
 	Host         string       `yaml:"host" json:"host"`

server/backend/sqlite/sqlite.go 🔗

@@ -10,10 +10,10 @@ import (
 	"strings"
 	"text/template"
 
-	"github.com/charmbracelet/keygen"
 	"github.com/charmbracelet/log"
 	"github.com/charmbracelet/soft-serve/git"
 	"github.com/charmbracelet/soft-serve/server/backend"
+	"github.com/charmbracelet/soft-serve/server/config"
 	"github.com/charmbracelet/soft-serve/server/utils"
 	"github.com/jmoiron/sqlx"
 	_ "modernc.org/sqlite"
@@ -26,9 +26,9 @@ var (
 // SqliteBackend is a backend that uses a SQLite database as a Soft Serve
 // backend.
 type SqliteBackend struct {
+	cfg              *config.Config
 	dp               string
 	db               *sqlx.DB
-	ckp              string
 	AdditionalAdmins []string
 }
 
@@ -39,22 +39,12 @@ func (d *SqliteBackend) reposPath() string {
 }
 
 // NewSqliteBackend creates a new SqliteBackend.
-func NewSqliteBackend(dataPath string) (*SqliteBackend, error) {
+func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) {
+	dataPath := cfg.DataPath
 	if err := os.MkdirAll(dataPath, 0755); err != nil {
 		return nil, err
 	}
 
-	ckp := filepath.Join(dataPath, "ssh", "soft_serve_client")
-	_, err := keygen.NewWithWrite(ckp, nil, keygen.Ed25519)
-	if err != nil {
-		return nil, err
-	}
-
-	ckp, err = filepath.Abs(ckp)
-	if err != nil {
-		return nil, err
-	}
-
 	db, err := sqlx.Connect("sqlite", filepath.Join(dataPath, "soft-serve.db"+
 		"?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)"))
 	if err != nil {
@@ -62,9 +52,9 @@ func NewSqliteBackend(dataPath string) (*SqliteBackend, error) {
 	}
 
 	d := &SqliteBackend{
+		cfg: cfg,
 		dp:  dataPath,
 		db:  db,
-		ckp: ckp,
 	}
 
 	if err := d.init(); err != nil {
@@ -186,8 +176,8 @@ func (d *SqliteBackend) ImportRepository(name string, remote string, opts backen
 		CommandOptions: git.CommandOptions{
 			Envs: []string{
 				fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
-					filepath.Join(filepath.Dir(d.ckp), "known_hosts"),
-					d.ckp,
+					filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"),
+					filepath.Join(d.cfg.DataPath, d.cfg.SSH.ClientKeyPath),
 				),
 			},
 		},

server/config/config.go 🔗

@@ -21,6 +21,9 @@ type SSHConfig struct {
 	// KeyPath is the path to the SSH server's private key.
 	KeyPath string `env:"KEY_PATH" yaml:"key_path"`
 
+	// ClientKeyPath is the path to the SSH server's client private key.
+	ClientKeyPath string `env:"CLIENT_KEY_PATH" yaml:"client_key_path"`
+
 	// InternalKeyPath is the path to the SSH server's internal private key.
 	InternalKeyPath string `env:"INTERNAL_KEY_PATH" yaml:"internal_key_path"`
 
@@ -122,6 +125,11 @@ func DefaultConfig() *Config {
 		dataPath = "data"
 	}
 
+	dp, _ := filepath.Abs(dataPath)
+	if dp != "" {
+		dataPath = dp
+	}
+
 	cfg := &Config{
 		Name:     "Soft Serve",
 		DataPath: dataPath,
@@ -129,6 +137,7 @@ func DefaultConfig() *Config {
 			ListenAddr:      ":23231",
 			PublicURL:       "ssh://localhost:23231",
 			KeyPath:         filepath.Join("ssh", "soft_serve_host"),
+			ClientKeyPath:   filepath.Join("ssh", "soft_serve_client"),
 			InternalKeyPath: filepath.Join("ssh", "soft_serve_internal"),
 			MaxTimeout:      0,
 			IdleTimeout:     120,

server/config/file.go 🔗

@@ -24,6 +24,11 @@ ssh:
   # The relative path to the SSH server's private key.
   key_path: "{{ .SSH.KeyPath }}"
 
+  # The relative path to the SSH server's client private key.
+  # This key will be used to authenticate the server to make git requests to
+  # ssh remotes.
+  client_key_path: "{{ .SSH.ClientKeyPath }}"
+
   # The relative path to the SSH server's internal api private key.
   internal_key_path: "{{ .SSH.InternalKeyPath }}"
 

server/daemon_test.go 🔗

@@ -30,15 +30,16 @@ func TestMain(m *testing.M) {
 	os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
 	os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
 	os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
-	fb, err := sqlite.NewSqliteBackend(tmp)
+	cfg := config.DefaultConfig()
+	d, err := NewGitDaemon(cfg)
 	if err != nil {
 		log.Fatal(err)
 	}
-	cfg := config.DefaultConfig().WithBackend(fb)
-	d, err := NewGitDaemon(cfg)
+	fb, err := sqlite.NewSqliteBackend(cfg)
 	if err != nil {
 		log.Fatal(err)
 	}
+	cfg = cfg.WithBackend(fb)
 	testDaemon = d
 	go func() {
 		if err := d.Start(); err != ErrServerClosed {

server/jobs.go 🔗

@@ -1,8 +1,11 @@
 package server
 
 import (
+	"fmt"
+	"path/filepath"
+
 	"github.com/charmbracelet/soft-serve/git"
-	"github.com/charmbracelet/soft-serve/server/backend"
+	"github.com/charmbracelet/soft-serve/server/config"
 )
 
 var (
@@ -12,7 +15,8 @@ var (
 )
 
 // mirrorJob runs the (pull) mirror job task.
-func mirrorJob(b backend.Backend) func() {
+func mirrorJob(cfg *config.Config) func() {
+	b := cfg.Backend
 	logger := logger.WithPrefix("server.mirrorJob")
 	return func() {
 		repos, err := b.Repositories()
@@ -31,6 +35,12 @@ func mirrorJob(b backend.Backend) func() {
 				}
 
 				cmd := git.NewCommand("remote", "update", "--prune")
+				cmd.AddEnvs(
+					fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
+						filepath.Join(cfg.DataPath, "ssh", "known_hosts"),
+						filepath.Join(cfg.DataPath, cfg.SSH.ClientKeyPath),
+					),
+				)
 				if _, err := cmd.RunInDir(r.Path); err != nil {
 					logger.Error("error running git remote update", "repo", repo.Name(), "err", err)
 				}

server/server.go 🔗

@@ -39,7 +39,7 @@ type Server struct {
 func NewServer(cfg *config.Config) (*Server, error) {
 	var err error
 	if cfg.Backend == nil {
-		sb, err := sqlite.NewSqliteBackend(cfg.DataPath)
+		sb, err := sqlite.NewSqliteBackend(cfg)
 		if err != nil {
 			logger.Fatal(err)
 		}
@@ -57,6 +57,16 @@ func NewServer(cfg *config.Config) (*Server, error) {
 		if err != nil {
 			return nil, err
 		}
+
+		// Create client key.
+		_, err = keygen.NewWithWrite(
+			filepath.Join(cfg.DataPath, cfg.SSH.ClientKeyPath),
+			nil,
+			keygen.Ed25519,
+		)
+		if err != nil {
+			return nil, err
+		}
 	}
 
 	srv := &Server{
@@ -66,7 +76,7 @@ func NewServer(cfg *config.Config) (*Server, error) {
 	}
 
 	// Add cron jobs.
-	srv.Cron.AddFunc(jobSpecs["mirror"], mirrorJob(cfg.Backend))
+	srv.Cron.AddFunc(jobSpecs["mirror"], mirrorJob(cfg))
 
 	srv.SSHServer, err = NewSSHServer(cfg, srv)
 	if err != nil {

server/session_test.go 🔗

@@ -51,11 +51,12 @@ func setup(tb testing.TB) *gossh.Session {
 		is.NoErr(os.Unsetenv("SOFT_SERVE_SSH_LISTEN_ADDR"))
 		is.NoErr(os.RemoveAll(dp))
 	})
-	fb, err := sqlite.NewSqliteBackend(dp)
+	cfg := config.DefaultConfig()
+	fb, err := sqlite.NewSqliteBackend(cfg)
 	if err != nil {
 		log.Fatal(err)
 	}
-	cfg := config.DefaultConfig().WithBackend(fb)
+	cfg = cfg.WithBackend(fb)
 	return testsession.New(tb, &ssh.Server{
 		Handler: bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256)(func(s ssh.Session) {
 			_, _, active := s.Pty()