.golangci-soft.yml 🔗
@@ -1,5 +1,6 @@
run:
tests: false
+ timeout: 5m
issues:
include:
Ayman Bagabas created
* refactor(server): abstract database from backend
Prepare for multi database driver support
* feat(server): add db models
* feat(db): add database migrations
* feat(db): add support to postgres
* feat: implement database store logic
* refactor: use db module and abstract backend logic
* fix(db): postgres migrate sql
* feat(db): add database query tracing
* refactor: move internal packages to server
* fix(config): normalize sqlite database path
* fix,feat: support custom log path and fix logging leak in hooks
* fix(test): race condition
* refactor: tidy up files and use middlewares
* chore: add test for repo commit command
Reference: https://github.com/charmbracelet/soft-serve/pull/331
* fix: lint errors
* fix: use utc time and fix git packp error format
* fix: lint errors
* fix: testscript on windows
* chore: format sql files
* fix: lint issues
re-enable revive linter
* refactor: clean up server/config
* feat: add admin command to manage server
* fix(db): use migration versions
* chore: add deprecation warning.
* refactor: move shared interfaces and errors to proto
* fix: increase golangci lint timeout
* feat: add move tests
.golangci-soft.yml | 1
.golangci.yml | 8
cmd/soft/admin.go | 75 +
cmd/soft/hook.go | 74
cmd/soft/man.go | 34
cmd/soft/migrate_config.go | 473 ++++----
cmd/soft/root.go | 111 +
cmd/soft/serve.go | 63 +
examples/setuid/main.go | 74 -
git/commit.go | 5
git/patch.go | 2
go.mod | 5
go.sum | 14
internal/log/log.go | 53
server/access/access.go | 4
server/access/access_test.go | 24
server/backend/backend.go | 64
server/backend/cache.go | 35
server/backend/collab.go | 78 +
server/backend/context.go | 11
server/backend/hooks.go | 82 +
server/backend/repo.go | 552 ++++++++-
server/backend/settings.go | 65 +
server/backend/sqlite/db.go | 141 --
server/backend/sqlite/error.go | 20
server/backend/sqlite/hooks.go | 76 -
server/backend/sqlite/repo.go | 202 ---
server/backend/sqlite/sql.go | 61 -
server/backend/sqlite/sqlite.go | 649 ------------
server/backend/sqlite/user.go | 365 ------
server/backend/user.go | 330 +++++
server/backend/utils.go | 5
server/cmd/set_username.go | 22
server/config/config.go | 272 ++--
server/config/config_test.go | 29
server/config/context.go | 20
server/config/file.go | 11
server/cron/cron.go | 26
server/daemon/conn.go | 6
server/daemon/daemon.go | 22
server/daemon/daemon_test.go | 23
server/db/context.go | 18
server/db/db.go | 88 +
server/db/errors.go | 48
server/db/logger.go | 135 ++
server/db/migrate/0001_create_tables.go | 134 ++
server/db/migrate/0001_create_tables_postgres.down.sql | 5
server/db/migrate/0001_create_tables_postgres.up.sql | 59 +
server/db/migrate/0001_create_tables_sqlite.down.sql | 5
server/db/migrate/0001_create_tables_sqlite.up.sql | 58 +
server/db/migrate/migrate.go | 142 ++
server/db/migrate/migrations.go | 62 +
server/db/models/collab.go | 12
server/db/models/public_key.go | 10
server/db/models/repo.go | 16
server/db/models/settings.go | 10
server/db/models/user.go | 12
server/errors/errors.go | 12
server/git/git.go | 13
server/git/service.go | 57
server/hooks/gen.go | 140 ++
server/hooks/hooks.go | 161 --
server/jobs.go | 9
server/proto/errors.go | 16
server/proto/repo.go | 37
server/proto/user.go | 21
server/server.go | 73
server/server_test.go | 58 -
server/ssh/cmd.go | 17
server/ssh/cmd/blob.go | 9
server/ssh/cmd/branch.go | 24
server/ssh/cmd/cmd.go | 127 -
server/ssh/cmd/collab.go | 16
server/ssh/cmd/commit.go | 18
server/ssh/cmd/create.go | 6
server/ssh/cmd/delete.go | 14
server/ssh/cmd/description.go | 8
server/ssh/cmd/hidden.go | 12
server/ssh/cmd/import.go | 6
server/ssh/cmd/info.go | 9
server/ssh/cmd/list.go | 10
server/ssh/cmd/mirror.go | 6
server/ssh/cmd/private.go | 8
server/ssh/cmd/project_name.go | 8
server/ssh/cmd/pubkey.go | 29
server/ssh/cmd/rename.go | 14
server/ssh/cmd/repo.go | 6
server/ssh/cmd/set_username.go | 28
server/ssh/cmd/settings.go | 19
server/ssh/cmd/tag.go | 11
server/ssh/cmd/tree.go | 10
server/ssh/cmd/user.go | 54
server/ssh/git.go | 124 ++
server/ssh/logger.go | 25
server/ssh/middleware.go | 44
server/ssh/session.go | 82
server/ssh/session_test.go | 33
server/ssh/ssh.go | 159 --
server/sshutils/utils.go | 40
server/sshutils/utils_test.go | 124 ++
server/stats/stats.go | 2
server/store/database/collab.go | 124 ++
server/store/database/database.go | 42
server/store/database/repo.go | 135 ++
server/store/database/settings.go | 47
server/store/database/user.go | 208 +++
server/store/store.go | 69 +
server/sync/workqueue.go | 0
server/sync/workqueue_test.go | 35
server/ui/common/common.go | 17
server/ui/common/utils.go | 3
server/ui/components/code/code.go | 4
server/ui/components/footer/footer.go | 2
server/ui/components/header/header.go | 2
server/ui/components/statusbar/statusbar.go | 2
server/ui/pages/repo/files.go | 5
server/ui/pages/repo/log.go | 4
server/ui/pages/repo/logitem.go | 1
server/ui/pages/repo/readme.go | 3
server/ui/pages/repo/refs.go | 11
server/ui/pages/repo/repo.go | 6
server/ui/pages/selection/item.go | 6
server/ui/pages/selection/selection.go | 22
server/ui/ui.go | 9
server/web/context.go | 28
server/web/git.go | 156 +-
server/web/goget.go | 15
server/web/http.go | 3
server/web/logging.go | 37
server/web/server.go | 14
testscript/script_test.go | 75
testscript/testdata/repo-commit.txtar | 30
testscript/testdata/repo-import.txtar | 0
133 files changed, 4,558 insertions(+), 3,257 deletions(-)
@@ -1,5 +1,6 @@
run:
tests: false
+ timeout: 5m
issues:
include:
@@ -1,5 +1,6 @@
run:
tests: false
+ timeout: 5m
issues:
include:
@@ -27,3 +28,10 @@ linters:
- unconvert
- unparam
- whitespace
+
+severity:
+ default-severity: error
+ rules:
+ - linters:
+ - revive
+ severity: info
@@ -0,0 +1,75 @@
+package main
+
+import (
+ "fmt"
+
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/migrate"
+ "github.com/spf13/cobra"
+)
+
+var (
+ adminCmd = &cobra.Command{
+ Use: "admin",
+ Short: "Administrate the server",
+ }
+
+ migrateCmd = &cobra.Command{
+ Use: "migrate",
+ Short: "Migrate the database to the latest version",
+ PersistentPreRunE: initBackendContext,
+ PersistentPostRunE: closeDBContext,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ ctx := cmd.Context()
+ db := db.FromContext(ctx)
+ if err := migrate.Migrate(ctx, db); err != nil {
+ return fmt.Errorf("migration: %w", err)
+ }
+
+ return nil
+ },
+ }
+
+ rollbackCmd = &cobra.Command{
+ Use: "rollback",
+ Short: "Rollback the database to the previous version",
+ PersistentPreRunE: initBackendContext,
+ PersistentPostRunE: closeDBContext,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ ctx := cmd.Context()
+ db := db.FromContext(ctx)
+ if err := migrate.Rollback(ctx, db); err != nil {
+ return fmt.Errorf("rollback: %w", err)
+ }
+
+ return nil
+ },
+ }
+
+ syncHooksCmd = &cobra.Command{
+ Use: "sync-hooks",
+ Short: "Update repository hooks",
+ PersistentPreRunE: initBackendContext,
+ PersistentPostRunE: closeDBContext,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ ctx := cmd.Context()
+ cfg := config.FromContext(ctx)
+ be := backend.FromContext(ctx)
+ if err := initializeHooks(ctx, cfg, be); err != nil {
+ return fmt.Errorf("initialize hooks: %w", err)
+ }
+
+ return nil
+ },
+ }
+)
+
+func init() {
+ adminCmd.AddCommand(
+ syncHooksCmd,
+ migrateCmd,
+ rollbackCmd,
+ )
+}
@@ -11,66 +11,33 @@ import (
"path/filepath"
"strings"
- "github.com/charmbracelet/log"
"github.com/charmbracelet/soft-serve/server/backend"
- "github.com/charmbracelet/soft-serve/server/backend/sqlite"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/hooks"
"github.com/spf13/cobra"
)
var (
+ // Deprecated: this flag is ignored.
configPath string
- logFileCtxKey = struct{}{}
-
hookCmd = &cobra.Command{
- Use: "hook",
- Short: "Run git server hooks",
- Long: "Handles Soft Serve git server hooks.",
- Hidden: true,
- PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
- ctx := cmd.Context()
- cfg, err := config.ParseConfig(configPath)
- if err != nil {
- return fmt.Errorf("could not parse config: %w", err)
- }
-
- ctx = config.WithContext(ctx, cfg)
-
- logPath := filepath.Join(cfg.DataPath, "log", "hooks.log")
- f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
- if err != nil {
- return fmt.Errorf("opening file: %w", err)
- }
-
- ctx = context.WithValue(ctx, logFileCtxKey, f)
- logger := log.FromContext(ctx)
- logger.SetOutput(f)
- ctx = log.WithContext(ctx, logger)
- cmd.SetContext(ctx)
-
- // Set up the backend
- // TODO: support other backends
- sb, err := sqlite.NewSqliteBackend(ctx)
- if err != nil {
- return fmt.Errorf("failed to create sqlite backend: %w", err)
- }
-
- cfg = cfg.WithBackend(sb)
-
- return nil
- },
- PersistentPostRunE: func(cmd *cobra.Command, _ []string) error {
- f := cmd.Context().Value(logFileCtxKey).(*os.File)
- return f.Close()
- },
+ Use: "hook",
+ Short: "Run git server hooks",
+ Long: "Handles Soft Serve git server hooks.",
+ Hidden: true,
+ PersistentPreRunE: initBackendContext,
+ PersistentPostRunE: closeDBContext,
}
+ // Git hooks read the config from the environment, based on
+ // $SOFT_SERVE_DATA_PATH. We already parse the config when the binary
+ // starts, so we don't need to do it again.
+ // The --config flag is now deprecated.
hooksRunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
+ hks := backend.FromContext(ctx)
cfg := config.FromContext(ctx)
- hks := cfg.Backend.(backend.Hooks)
// This is set in the server before invoking git-receive-pack/git-upload-pack
repoName := os.Getenv("SOFT_SERVE_REPO_NAME")
@@ -80,10 +47,10 @@ var (
stderr := cmd.ErrOrStderr()
cmdName := cmd.Name()
- customHookPath := filepath.Join(filepath.Dir(configPath), "hooks", cmdName)
+ customHookPath := filepath.Join(cfg.DataPath, "hooks", cmdName)
var buf bytes.Buffer
- opts := make([]backend.HookArg, 0)
+ opts := make([]hooks.HookArg, 0)
switch cmdName {
case hooks.PreReceiveHook, hooks.PostReceiveHook:
@@ -94,7 +61,7 @@ var (
if len(fields) != 3 {
return fmt.Errorf("invalid hook input: %s", scanner.Text())
}
- opts = append(opts, backend.HookArg{
+ opts = append(opts, hooks.HookArg{
OldSha: fields[0],
NewSha: fields[1],
RefName: fields[2],
@@ -103,22 +70,22 @@ var (
switch cmdName {
case hooks.PreReceiveHook:
- hks.PreReceive(stdout, stderr, repoName, opts)
+ hks.PreReceive(ctx, stdout, stderr, repoName, opts)
case hooks.PostReceiveHook:
- hks.PostReceive(stdout, stderr, repoName, opts)
+ hks.PostReceive(ctx, stdout, stderr, repoName, opts)
}
case hooks.UpdateHook:
if len(args) != 3 {
return fmt.Errorf("invalid update hook input: %s", args)
}
- hks.Update(stdout, stderr, repoName, backend.HookArg{
+ hks.Update(ctx, stdout, stderr, repoName, hooks.HookArg{
OldSha: args[0],
NewSha: args[1],
RefName: args[2],
})
case hooks.PostUpdateHook:
- hks.PostUpdate(stdout, stderr, repoName, args...)
+ hks.PostUpdate(ctx, stdout, stderr, repoName, args...)
}
// Custom hooks
@@ -159,7 +126,7 @@ var (
)
func init() {
- hookCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to config file")
+ hookCmd.PersistentFlags().StringVar(&configPath, "config", "", "path to config file (deprecated)")
hookCmd.AddCommand(
preReceiveCmd,
updateCmd,
@@ -211,6 +178,7 @@ fi
echo "Hi from Soft Serve update hook!"
echo
+echo "Repository: $SOFT_SERVE_REPO_NAME"
echo "RefName: $refname"
echo "Change Type: $newrev_type"
echo "Old SHA1: $oldrev"
@@ -8,22 +8,20 @@ import (
"github.com/spf13/cobra"
)
-var (
- manCmd = &cobra.Command{
- Use: "man",
- Short: "Generate man pages",
- Args: cobra.NoArgs,
- Hidden: true,
- RunE: func(cmd *cobra.Command, args []string) error {
- manPage, err := mcobra.NewManPage(1, rootCmd) //.
- if err != nil {
- return err
- }
+var manCmd = &cobra.Command{
+ Use: "man",
+ Short: "Generate man pages",
+ Args: cobra.NoArgs,
+ Hidden: true,
+ RunE: func(_ *cobra.Command, _ []string) error {
+ manPage, err := mcobra.NewManPage(1, rootCmd) //.
+ if err != nil {
+ return err
+ }
- manPage = manPage.WithSection("Copyright", "(C) 2021-2023 Charmbracelet, Inc.\n"+
- "Released under MIT license.")
- fmt.Println(manPage.Build(roff.NewDocument()))
- return nil
- },
- }
-)
+ manPage = manPage.WithSection("Copyright", "(C) 2021-2023 Charmbracelet, Inc.\n"+
+ "Released under MIT license.")
+ fmt.Println(manPage.Build(roff.NewDocument()))
+ return nil
+ },
+}
@@ -12,9 +12,12 @@ import (
"github.com/charmbracelet/log"
"github.com/charmbracelet/soft-serve/git"
+ "github.com/charmbracelet/soft-serve/server/access"
"github.com/charmbracelet/soft-serve/server/backend"
- "github.com/charmbracelet/soft-serve/server/backend/sqlite"
"github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/proto"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
"github.com/charmbracelet/soft-serve/server/utils"
gitm "github.com/gogs/git-module"
"github.com/spf13/cobra"
@@ -22,296 +25,300 @@ import (
"gopkg.in/yaml.v3"
)
-var (
- migrateConfig = &cobra.Command{
- Use: "migrate-config",
- Short: "Migrate config to new format",
- Hidden: true,
- RunE: func(cmd *cobra.Command, _ []string) error {
- ctx := cmd.Context()
-
- logger := log.FromContext(ctx)
- // Disable logging timestamp
- logger.SetReportTimestamp(false)
-
- keyPath := os.Getenv("SOFT_SERVE_KEY_PATH")
- reposPath := os.Getenv("SOFT_SERVE_REPO_PATH")
- bindAddr := os.Getenv("SOFT_SERVE_BIND_ADDRESS")
- cfg := config.DefaultConfig()
- ctx = config.WithContext(ctx, cfg)
- sb, err := sqlite.NewSqliteBackend(ctx)
- if err != nil {
- return fmt.Errorf("failed to create sqlite backend: %w", err)
- }
+// Deprecated: will be removed in a future release.
+var migrateConfig = &cobra.Command{
+ Use: "migrate-config",
+ Short: "Migrate config to new format",
+ Hidden: true,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ ctx := cmd.Context()
+
+ logger := log.FromContext(ctx)
+ // Disable logging timestamp
+ logger.SetReportTimestamp(false)
+
+ keyPath := os.Getenv("SOFT_SERVE_KEY_PATH")
+ reposPath := os.Getenv("SOFT_SERVE_REPO_PATH")
+ bindAddr := os.Getenv("SOFT_SERVE_BIND_ADDRESS")
+ cfg := config.DefaultConfig()
+ if err := cfg.ParseEnv(); err != nil {
+ return fmt.Errorf("parse env: %w", err)
+ }
- // FIXME: Admin user gets created when the database is created.
- sb.DeleteUser("admin") // nolint: errcheck
+ ctx = config.WithContext(ctx, cfg)
+ db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
+ if err != nil {
+ return fmt.Errorf("open database: %w", err)
+ }
- cfg = cfg.WithBackend(sb)
+ defer db.Close() // nolint: errcheck
+ sb := backend.New(ctx, cfg, db)
- // Set SSH listen address
- logger.Info("Setting SSH listen address...")
- if bindAddr != "" {
- cfg.SSH.ListenAddr = bindAddr
- }
+ // FIXME: Admin user gets created when the database is created.
+ sb.DeleteUser(ctx, "admin") // nolint: errcheck
- // Copy SSH host key
- logger.Info("Copying SSH host key...")
- if keyPath != "" {
- if err := os.MkdirAll(filepath.Join(cfg.DataPath, "ssh"), os.ModePerm); err != nil {
- return fmt.Errorf("failed to create ssh directory: %w", err)
- }
+ // Set SSH listen address
+ logger.Info("Setting SSH listen address...")
+ if bindAddr != "" {
+ cfg.SSH.ListenAddr = bindAddr
+ }
- if err := copyFile(keyPath, filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath))); err != nil {
- return fmt.Errorf("failed to copy ssh key: %w", err)
- }
+ // Copy SSH host key
+ logger.Info("Copying SSH host key...")
+ if keyPath != "" {
+ if err := os.MkdirAll(filepath.Join(cfg.DataPath, "ssh"), os.ModePerm); err != nil {
+ return fmt.Errorf("failed to create ssh directory: %w", err)
+ }
- if err := copyFile(keyPath+".pub", filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath))+".pub"); err != nil {
- logger.Errorf("failed to copy ssh key: %s", err)
- }
+ if err := copyFile(keyPath, filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath))); err != nil {
+ return fmt.Errorf("failed to copy ssh key: %w", err)
+ }
- cfg.SSH.KeyPath = filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath))
+ if err := copyFile(keyPath+".pub", filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath))+".pub"); err != nil {
+ logger.Errorf("failed to copy ssh key: %s", err)
}
- // Read config
- logger.Info("Reading config repository...")
- r, err := git.Open(filepath.Join(reposPath, "config"))
+ cfg.SSH.KeyPath = filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath))
+ }
+
+ // Read config
+ logger.Info("Reading config repository...")
+ r, err := git.Open(filepath.Join(reposPath, "config"))
+ if err != nil {
+ return fmt.Errorf("failed to open config repo: %w", err)
+ }
+
+ head, err := r.HEAD()
+ if err != nil {
+ return fmt.Errorf("failed to get head: %w", err)
+ }
+
+ tree, err := r.TreePath(head, "")
+ if err != nil {
+ return fmt.Errorf("failed to get tree: %w", err)
+ }
+
+ isJson := false // nolint: revive
+ te, err := tree.TreeEntry("config.yaml")
+ if err != nil {
+ te, err = tree.TreeEntry("config.json")
if err != nil {
- return fmt.Errorf("failed to open config repo: %w", err)
+ return fmt.Errorf("failed to get config file: %w", err)
}
+ isJson = true
+ }
- head, err := r.HEAD()
- if err != nil {
- return fmt.Errorf("failed to get head: %w", err)
+ cc, err := te.Contents()
+ if err != nil {
+ return fmt.Errorf("failed to get config contents: %w", err)
+ }
+
+ var ocfg Config
+ if isJson {
+ if err := json.Unmarshal(cc, &ocfg); err != nil {
+ return fmt.Errorf("failed to unmarshal config: %w", err)
}
+ } else {
+ if err := yaml.Unmarshal(cc, &ocfg); err != nil {
+ return fmt.Errorf("failed to unmarshal config: %w", err)
+ }
+ }
- tree, err := r.TreePath(head, "")
- if err != nil {
- return fmt.Errorf("failed to get tree: %w", err)
+ readme, readmePath, err := git.LatestFile(r, "README*")
+ hasReadme := err == nil
+
+ // Set server name
+ cfg.Name = ocfg.Name
+
+ // Set server public url
+ cfg.SSH.PublicURL = fmt.Sprintf("ssh://%s:%d", ocfg.Host, ocfg.Port)
+
+ // Set server settings
+ logger.Info("Setting server settings...")
+ if sb.SetAllowKeyless(ctx, ocfg.AllowKeyless) != nil {
+ fmt.Fprintf(os.Stderr, "failed to set allow keyless\n")
+ }
+ anon := access.ParseAccessLevel(ocfg.AnonAccess)
+ if anon >= 0 {
+ if err := sb.SetAnonAccess(ctx, anon); err != nil {
+ fmt.Fprintf(os.Stderr, "failed to set anon access: %s\n", err)
}
+ }
- isJson := false // nolint: revive
- te, err := tree.TreeEntry("config.yaml")
- if err != nil {
- te, err = tree.TreeEntry("config.json")
- if err != nil {
- return fmt.Errorf("failed to get config file: %w", err)
- }
- isJson = true
+ // Copy repos
+ if reposPath != "" {
+ logger.Info("Copying repos...")
+ if err := os.MkdirAll(filepath.Join(cfg.DataPath, "repos"), os.ModePerm); err != nil {
+ return fmt.Errorf("failed to create repos directory: %w", err)
}
- cc, err := te.Contents()
+ dirs, err := os.ReadDir(reposPath)
if err != nil {
- return fmt.Errorf("failed to get config contents: %w", err)
+ return fmt.Errorf("failed to read repos directory: %w", err)
}
- var ocfg Config
- if isJson {
- if err := json.Unmarshal(cc, &ocfg); err != nil {
- return fmt.Errorf("failed to unmarshal config: %w", err)
+ for _, dir := range dirs {
+ if !dir.IsDir() || dir.Name() == "config" {
+ continue
+ }
+
+ if !isGitDir(filepath.Join(reposPath, dir.Name())) {
+ continue
}
- } else {
- if err := yaml.Unmarshal(cc, &ocfg); err != nil {
- return fmt.Errorf("failed to unmarshal config: %w", err)
+
+ logger.Infof(" Copying repo %s", dir.Name())
+ src := filepath.Join(reposPath, utils.SanitizeRepo(dir.Name()))
+ dst := filepath.Join(cfg.DataPath, "repos", utils.SanitizeRepo(dir.Name())) + ".git"
+ if err := os.MkdirAll(dst, os.ModePerm); err != nil {
+ return fmt.Errorf("failed to create repo directory: %w", err)
+ }
+
+ if err := copyDir(src, dst); err != nil {
+ return fmt.Errorf("failed to copy repo: %w", err)
+ }
+
+ if _, err := sb.CreateRepository(ctx, dir.Name(), proto.RepositoryOptions{}); err != nil {
+ fmt.Fprintf(os.Stderr, "failed to create repository: %s\n", err)
}
}
- readme, readmePath, err := git.LatestFile(r, "README*")
- hasReadme := err == nil
+ if hasReadme {
+ logger.Infof(" Copying readme from \"config\" to \".soft-serve\"")
- // Set server name
- cfg.Name = ocfg.Name
+ // Switch to main branch
+ bcmd := git.NewCommand("branch", "-M", "main")
- // Set server public url
- cfg.SSH.PublicURL = fmt.Sprintf("ssh://%s:%d", ocfg.Host, ocfg.Port)
+ rp := filepath.Join(cfg.DataPath, "repos", ".soft-serve.git")
+ nr, err := git.Init(rp, true)
+ if err != nil {
+ return fmt.Errorf("failed to init repo: %w", err)
+ }
- // Set server settings
- logger.Info("Setting server settings...")
- if cfg.Backend.SetAllowKeyless(ocfg.AllowKeyless) != nil {
- fmt.Fprintf(os.Stderr, "failed to set allow keyless\n")
- }
- anon := backend.ParseAccessLevel(ocfg.AnonAccess)
- if anon >= 0 {
- if err := sb.SetAnonAccess(anon); err != nil {
- fmt.Fprintf(os.Stderr, "failed to set anon access: %s\n", err)
+ if _, err := nr.SymbolicRef("HEAD", gitm.RefsHeads+"main"); err != nil {
+ return fmt.Errorf("failed to set HEAD: %w", err)
}
- }
- // Copy repos
- if reposPath != "" {
- logger.Info("Copying repos...")
- if err := os.MkdirAll(filepath.Join(cfg.DataPath, "repos"), os.ModePerm); err != nil {
- return fmt.Errorf("failed to create repos directory: %w", err)
+ tmpDir, err := os.MkdirTemp("", "soft-serve")
+ if err != nil {
+ return fmt.Errorf("failed to create temp dir: %w", err)
}
- dirs, err := os.ReadDir(reposPath)
+ r, err := git.Init(tmpDir, false)
if err != nil {
- return fmt.Errorf("failed to read repos directory: %w", err)
+ return fmt.Errorf("failed to clone repo: %w", err)
}
- for _, dir := range dirs {
- if !dir.IsDir() || dir.Name() == "config" {
- continue
- }
-
- if !isGitDir(filepath.Join(reposPath, dir.Name())) {
- continue
- }
-
- logger.Infof(" Copying repo %s", dir.Name())
- src := filepath.Join(reposPath, utils.SanitizeRepo(dir.Name()))
- dst := filepath.Join(cfg.DataPath, "repos", utils.SanitizeRepo(dir.Name())) + ".git"
- if err := os.MkdirAll(dst, os.ModePerm); err != nil {
- return fmt.Errorf("failed to create repo directory: %w", err)
- }
-
- if err := copyDir(src, dst); err != nil {
- return fmt.Errorf("failed to copy repo: %w", err)
- }
-
- if _, err := sb.CreateRepository(dir.Name(), backend.RepositoryOptions{}); err != nil {
- fmt.Fprintf(os.Stderr, "failed to create repository: %s\n", err)
- }
+ if _, err := bcmd.RunInDir(tmpDir); err != nil {
+ return fmt.Errorf("failed to create main branch: %w", err)
}
- if hasReadme {
- logger.Infof(" Copying readme from \"config\" to \".soft-serve\"")
-
- // Switch to main branch
- bcmd := git.NewCommand("branch", "-M", "main")
-
- rp := filepath.Join(cfg.DataPath, "repos", ".soft-serve.git")
- nr, err := git.Init(rp, true)
- if err != nil {
- return fmt.Errorf("failed to init repo: %w", err)
- }
-
- if _, err := nr.SymbolicRef("HEAD", gitm.RefsHeads+"main"); err != nil {
- return fmt.Errorf("failed to set HEAD: %w", err)
- }
-
- tmpDir, err := os.MkdirTemp("", "soft-serve")
- if err != nil {
- return fmt.Errorf("failed to create temp dir: %w", err)
- }
-
- r, err := git.Init(tmpDir, false)
- if err != nil {
- return fmt.Errorf("failed to clone repo: %w", err)
- }
-
- if _, err := bcmd.RunInDir(tmpDir); err != nil {
- return fmt.Errorf("failed to create main branch: %w", err)
- }
-
- if err := os.WriteFile(filepath.Join(tmpDir, readmePath), []byte(readme), 0o644); err != nil {
- return fmt.Errorf("failed to write readme: %w", err)
- }
-
- if err := r.Add(gitm.AddOptions{
- All: true,
- }); err != nil {
- return fmt.Errorf("failed to add readme: %w", err)
- }
-
- if err := r.Commit(&gitm.Signature{
- Name: "Soft Serve",
- Email: "vt100@charm.sh",
- When: time.Now(),
- }, "Add readme"); err != nil {
- return fmt.Errorf("failed to commit readme: %w", err)
- }
-
- if err := r.RemoteAdd("origin", "file://"+rp); err != nil {
- return fmt.Errorf("failed to add remote: %w", err)
- }
-
- if err := r.Push("origin", "main"); err != nil {
- return fmt.Errorf("failed to push readme: %w", err)
- }
-
- // Create `.soft-serve` repository and add readme
- if _, err := sb.CreateRepository(".soft-serve", backend.RepositoryOptions{
- ProjectName: "Home",
- Description: "Soft Serve home repository",
- Hidden: true,
- Private: false,
- }); err != nil {
- fmt.Fprintf(os.Stderr, "failed to create repository: %s\n", err)
- }
+ if err := os.WriteFile(filepath.Join(tmpDir, readmePath), []byte(readme), 0o644); err != nil { // nolint: gosec
+ return fmt.Errorf("failed to write readme: %w", err)
}
- }
- // Set repos metadata & collabs
- logger.Info("Setting repos metadata & collabs...")
- for _, r := range ocfg.Repos {
- repo, name := r.Repo, r.Name
- // Special case for config repo
- if repo == "config" {
- repo = ".soft-serve"
- r.Private = false
+ if err := r.Add(gitm.AddOptions{
+ All: true,
+ }); err != nil {
+ return fmt.Errorf("failed to add readme: %w", err)
}
- if err := sb.SetProjectName(repo, name); err != nil {
- logger.Errorf("failed to set repo name to %s: %s", repo, err)
+ if err := r.Commit(&gitm.Signature{
+ Name: "Soft Serve",
+ Email: "vt100@charm.sh",
+ When: time.Now(),
+ }, "Add readme"); err != nil {
+ return fmt.Errorf("failed to commit readme: %w", err)
}
- if err := sb.SetDescription(repo, r.Note); err != nil {
- logger.Errorf("failed to set repo description to %s: %s", repo, err)
+ if err := r.RemoteAdd("origin", "file://"+rp); err != nil {
+ return fmt.Errorf("failed to add remote: %w", err)
}
- if err := sb.SetPrivate(repo, r.Private); err != nil {
- logger.Errorf("failed to set repo private to %s: %s", repo, err)
+ if err := r.Push("origin", "main"); err != nil {
+ return fmt.Errorf("failed to push readme: %w", err)
}
- for _, collab := range r.Collabs {
- if err := sb.AddCollaborator(repo, collab); err != nil {
- logger.Errorf("failed to add repo collab to %s: %s", repo, err)
- }
+ // Create `.soft-serve` repository and add readme
+ if _, err := sb.CreateRepository(ctx, ".soft-serve", proto.RepositoryOptions{
+ ProjectName: "Home",
+ Description: "Soft Serve home repository",
+ Hidden: true,
+ Private: false,
+ }); err != nil {
+ fmt.Fprintf(os.Stderr, "failed to create repository: %s\n", err)
}
}
+ }
- // Create users & collabs
- logger.Info("Creating users & collabs...")
- for _, user := range ocfg.Users {
- keys := make(map[string]ssh.PublicKey)
- for _, key := range user.PublicKeys {
- pk, _, err := backend.ParseAuthorizedKey(key)
- if err != nil {
- continue
- }
- ak := backend.MarshalAuthorizedKey(pk)
- keys[ak] = pk
- }
+ // Set repos metadata & collabs
+ logger.Info("Setting repos metadata & collabs...")
+ for _, r := range ocfg.Repos {
+ repo, name := r.Repo, r.Name
+ // Special case for config repo
+ if repo == "config" {
+ repo = ".soft-serve"
+ r.Private = false
+ }
+
+ if err := sb.SetProjectName(ctx, repo, name); err != nil {
+ logger.Errorf("failed to set repo name to %s: %s", repo, err)
+ }
+
+ if err := sb.SetDescription(ctx, repo, r.Note); err != nil {
+ logger.Errorf("failed to set repo description to %s: %s", repo, err)
+ }
+
+ if err := sb.SetPrivate(ctx, repo, r.Private); err != nil {
+ logger.Errorf("failed to set repo private to %s: %s", repo, err)
+ }
- pubkeys := make([]ssh.PublicKey, 0)
- for _, pk := range keys {
- pubkeys = append(pubkeys, pk)
+ for _, collab := range r.Collabs {
+ if err := sb.AddCollaborator(ctx, repo, collab); err != nil {
+ logger.Errorf("failed to add repo collab to %s: %s", repo, err)
}
+ }
+ }
- username := strings.ToLower(user.Name)
- username = strings.ReplaceAll(username, " ", "-")
- logger.Infof("Creating user %q", username)
- if _, err := sb.CreateUser(username, backend.UserOptions{
- Admin: user.Admin,
- PublicKeys: pubkeys,
- }); err != nil {
- logger.Errorf("failed to create user: %s", err)
+ // Create users & collabs
+ logger.Info("Creating users & collabs...")
+ for _, user := range ocfg.Users {
+ keys := make(map[string]ssh.PublicKey)
+ for _, key := range user.PublicKeys {
+ pk, _, err := sshutils.ParseAuthorizedKey(key)
+ if err != nil {
+ continue
}
+ ak := sshutils.MarshalAuthorizedKey(pk)
+ keys[ak] = pk
+ }
+
+ pubkeys := make([]ssh.PublicKey, 0)
+ for _, pk := range keys {
+ pubkeys = append(pubkeys, pk)
+ }
+
+ username := strings.ToLower(user.Name)
+ username = strings.ReplaceAll(username, " ", "-")
+ logger.Infof("Creating user %q", username)
+ if _, err := sb.CreateUser(ctx, username, proto.UserOptions{
+ Admin: user.Admin,
+ PublicKeys: pubkeys,
+ }); err != nil {
+ logger.Errorf("failed to create user: %s", err)
+ }
- for _, repo := range user.CollabRepos {
- if err := sb.AddCollaborator(repo, username); err != nil {
- logger.Errorf("failed to add user collab to %s: %s\n", repo, err)
- }
+ for _, repo := range user.CollabRepos {
+ if err := sb.AddCollaborator(ctx, repo, username); err != nil {
+ logger.Errorf("failed to add user collab to %s: %s\n", repo, err)
}
}
+ }
- logger.Info("Writing config...")
- defer logger.Info("Done!")
- return config.WriteConfig(filepath.Join(cfg.DataPath, "config.yaml"), cfg)
- },
- }
-)
+ logger.Info("Writing config...")
+ defer logger.Info("Done!")
+ return cfg.WriteConfig()
+ },
+}
// Returns true if path is a directory containing an `objects` directory and a
// `HEAD` file.
@@ -2,13 +2,21 @@ package main
import (
"context"
+ "fmt"
"os"
"runtime/debug"
+ "strings"
+ "time"
"github.com/charmbracelet/log"
- . "github.com/charmbracelet/soft-serve/internal/log"
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db"
+ _ "github.com/lib/pq" // postgres driver
"github.com/spf13/cobra"
"go.uber.org/automaxprocs/maxprocs"
+
+ _ "modernc.org/sqlite" // sqlite driver
)
var (
@@ -34,6 +42,7 @@ func init() {
manCmd,
hookCmd,
migrateConfig,
+ adminCmd,
)
rootCmd.CompletionOptions.HiddenDefaultCmd = true
@@ -52,19 +61,111 @@ func init() {
}
func main() {
- logger := NewDefaultLogger()
+ ctx := context.Background()
+ cfg := config.DefaultConfig()
+ if cfg.Exist() {
+ if err := cfg.Parse(); err != nil {
+ log.Fatal(err)
+ }
+ }
+
+ if err := cfg.ParseEnv(); err != nil {
+ log.Fatal(err)
+ }
+
+ ctx = config.WithContext(ctx, cfg)
+ logger, f, err := newDefaultLogger(cfg)
+ if err != nil {
+ log.Errorf("failed to create logger: %v", err)
+ }
+
+ ctx = log.WithContext(ctx, logger)
+ if f != nil {
+ defer f.Close() // nolint: errcheck
+ }
// Set global logger
log.SetDefault(logger)
+ var opts []maxprocs.Option
+ if config.IsVerbose() {
+ opts = append(opts, maxprocs.Logger(log.Debugf))
+ }
+
// Set the max number of processes to the number of CPUs
// This is useful when running soft serve in a container
- if _, err := maxprocs.Set(maxprocs.Logger(logger.Debugf)); err != nil {
- logger.Warn("couldn't set automaxprocs", "error", err)
+ if _, err := maxprocs.Set(opts...); err != nil {
+ log.Warn("couldn't set automaxprocs", "error", err)
}
- ctx := log.WithContext(context.Background(), logger)
if err := rootCmd.ExecuteContext(ctx); err != nil {
os.Exit(1)
}
}
+
+// newDefaultLogger returns a new logger with default settings.
+func newDefaultLogger(cfg *config.Config) (*log.Logger, *os.File, error) {
+ logger := log.NewWithOptions(os.Stderr, log.Options{
+ ReportTimestamp: true,
+ TimeFormat: time.DateOnly,
+ })
+
+ switch {
+ case config.IsVerbose():
+ logger.SetReportCaller(true)
+ fallthrough
+ case config.IsDebug():
+ logger.SetLevel(log.DebugLevel)
+ }
+
+ logger.SetTimeFormat(cfg.Log.TimeFormat)
+
+ switch strings.ToLower(cfg.Log.Format) {
+ case "json":
+ logger.SetFormatter(log.JSONFormatter)
+ case "logfmt":
+ logger.SetFormatter(log.LogfmtFormatter)
+ case "text":
+ logger.SetFormatter(log.TextFormatter)
+ }
+
+ var f *os.File
+ if cfg.Log.Path != "" {
+ f, err := os.OpenFile(cfg.Log.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
+ if err != nil {
+ return nil, nil, err
+ }
+ logger.SetOutput(f)
+ }
+
+ return logger, f, nil
+}
+
+func initBackendContext(cmd *cobra.Command, _ []string) error {
+ ctx := cmd.Context()
+ cfg := config.FromContext(ctx)
+ dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
+ if err != nil {
+ return fmt.Errorf("open database: %w", err)
+ }
+
+ ctx = db.WithContext(ctx, dbx)
+ be := backend.New(ctx, cfg, dbx)
+ ctx = backend.WithContext(ctx, be)
+
+ cmd.SetContext(ctx)
+
+ return nil
+}
+
+func closeDBContext(cmd *cobra.Command, _ []string) error {
+ ctx := cmd.Context()
+ dbx := db.FromContext(ctx)
+ if dbx != nil {
+ if err := dbx.Close(); err != nil {
+ return fmt.Errorf("close database: %w", err)
+ }
+ }
+
+ return nil
+}
@@ -10,21 +10,39 @@ import (
"time"
"github.com/charmbracelet/soft-serve/server"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/migrate"
+ "github.com/charmbracelet/soft-serve/server/hooks"
"github.com/spf13/cobra"
)
var (
+ syncHooks bool
+
serveCmd = &cobra.Command{
- Use: "serve",
- Short: "Start the server",
- Long: "Start the server",
- Args: cobra.NoArgs,
+ Use: "serve",
+ Short: "Start the server",
+ Args: cobra.NoArgs,
+ PersistentPreRunE: initBackendContext,
+ PersistentPostRunE: closeDBContext,
RunE: func(cmd *cobra.Command, _ []string) error {
ctx := cmd.Context()
cfg := config.DefaultConfig()
- ctx = config.WithContext(ctx, cfg)
- cmd.SetContext(ctx)
+ if cfg.Exist() {
+ if err := cfg.ParseFile(); err != nil {
+ return fmt.Errorf("parse config file: %w", err)
+ }
+ } else {
+ if err := cfg.WriteConfig(); err != nil {
+ return fmt.Errorf("write config file: %w", err)
+ }
+ }
+
+ if err := cfg.ParseEnv(); err != nil {
+ return fmt.Errorf("parse environment variables: %w", err)
+ }
// Create custom hooks directory if it doesn't exist
customHooksPath := filepath.Join(cfg.DataPath, "hooks")
@@ -33,7 +51,7 @@ var (
// Generate update hook example without executable permissions
hookPath := filepath.Join(customHooksPath, "update.sample")
// nolint: gosec
- if err := os.WriteFile(hookPath, []byte(updateHookExample), 0744); err != nil {
+ if err := os.WriteFile(hookPath, []byte(updateHookExample), 0o744); err != nil {
return fmt.Errorf("failed to generate update hook example: %w", err)
}
}
@@ -44,11 +62,23 @@ var (
os.MkdirAll(logPath, os.ModePerm) // nolint: errcheck
}
+ db := db.FromContext(ctx)
+ if err := migrate.Migrate(ctx, db); err != nil {
+ return fmt.Errorf("migration error: %w", err)
+ }
+
s, err := server.NewServer(ctx)
if err != nil {
return fmt.Errorf("start server: %w", err)
}
+ if syncHooks {
+ be := backend.FromContext(ctx)
+ if err := initializeHooks(ctx, cfg, be); err != nil {
+ return fmt.Errorf("initialize hooks: %w", err)
+ }
+ }
+
done := make(chan os.Signal, 1)
lch := make(chan error, 1)
go func() {
@@ -71,3 +101,22 @@ var (
},
}
)
+
+func init() {
+ serveCmd.Flags().BoolVarP(&syncHooks, "sync-hooks", "", false, "synchronize hooks for all repositories before running the server")
+}
+
+func initializeHooks(ctx context.Context, cfg *config.Config, be *backend.Backend) error {
+ repos, err := be.Repositories(ctx)
+ if err != nil {
+ return err
+ }
+
+ for _, repo := range repos {
+ if err := hooks.GenerateHooks(ctx, cfg, repo.Name()); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
@@ -1,74 +0,0 @@
-//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
-// +build darwin dragonfly freebsd linux netbsd openbsd solaris
-
-// This is an example of binding soft-serve ssh port to a restricted port (<1024) and
-// then droping root privileges to a different user to run the server.
-// Make sure you run this as root.
-
-package main
-
-import (
- "context"
- "flag"
- "fmt"
- "net"
- "os"
- "os/signal"
- "syscall"
- "time"
-
- "github.com/charmbracelet/log"
-
- "github.com/charmbracelet/soft-serve/server"
- "github.com/charmbracelet/soft-serve/server/config"
-)
-
-var (
- port = flag.Int("port", 22, "port to listen on")
- gid = flag.Int("gid", 1000, "group id to run as")
- uid = flag.Int("uid", 1000, "user id to run as")
-)
-
-func main() {
- flag.Parse()
- addr := fmt.Sprintf(":%d", *port)
- // To listen on port 22 we need root privileges
- ls, err := net.Listen("tcp", addr)
- if err != nil {
- log.Fatal("Can't listen", "err", err)
- }
- // We don't need root privileges any more
- if err := syscall.Setgid(*gid); err != nil {
- log.Fatal("Setgid error", "err", err)
- }
- if err := syscall.Setuid(*uid); err != nil {
- log.Fatal("Setuid error", "err", err)
- }
- ctx := context.Background()
- cfg := config.DefaultConfig()
- ctx = config.WithContext(ctx, cfg)
- cfg.SSH.ListenAddr = fmt.Sprintf(":%d", *port)
- s, err := server.NewServer(ctx)
- if err != nil {
- log.Fatal(err)
- }
-
- done := make(chan os.Signal, 1)
- signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
-
- log.Print("Starting SSH server", "addr", cfg.SSH.ListenAddr)
- go func() {
- if err := s.SSHServer.Serve(ls); err != nil {
- log.Fatal(err)
- }
- }()
-
- <-done
-
- log.Print("Stopping SSH server", "addr", cfg.SSH.ListenAddr)
- ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
- defer func() { cancel() }()
- if err := s.Shutdown(ctx); err != nil {
- log.Fatal(err)
- }
-}
@@ -4,9 +4,8 @@ import (
"github.com/gogs/git-module"
)
-var (
- ZeroHash Hash = git.EmptyID
-)
+// ZeroHash is the zero hash.
+var ZeroHash Hash = git.EmptyID
// Hash represents a git hash.
type Hash string
@@ -82,7 +82,7 @@ func diffsToString(diffs []diffmatchpatch.Diff, lineType git.DiffLineType) strin
}
}
- return string(buf.Bytes())
+ return buf.String()
}
// DiffFile is a wrapper to git.DiffFile with helper methods.
@@ -20,12 +20,13 @@ require (
require (
github.com/caarlos0/env/v8 v8.0.0
github.com/charmbracelet/keygen v0.4.3
- github.com/charmbracelet/log v0.2.2
- github.com/charmbracelet/ssh v0.0.0-20221117183211-483d43d97103
+ github.com/charmbracelet/log v0.2.3-0.20230713155356-557335e40e35
+ github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc
github.com/gobwas/glob v0.2.3
github.com/gogs/git-module v1.8.2
github.com/hashicorp/golang-lru/v2 v2.0.4
github.com/jmoiron/sqlx v1.3.5
+ github.com/lib/pq v1.2.0
github.com/lrstanley/bubblezone v0.0.0-20220716194435-3cb8c52f6a8f
github.com/muesli/mango-cobra v1.2.0
github.com/muesli/roff v0.1.0
@@ -27,10 +27,10 @@ github.com/charmbracelet/keygen v0.4.3 h1:ywOZRwkDlpmkawl0BgLTxaYWDSqp6Y4nfVVmgy
github.com/charmbracelet/keygen v0.4.3/go.mod h1:4e4FT3HSdLU/u83RfJWvzJIaVb8aX4MxtDlfXwpDJaI=
github.com/charmbracelet/lipgloss v0.7.1 h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZpc5Kc1E=
github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c=
-github.com/charmbracelet/log v0.2.2 h1:CaXgos+ikGn5tcws5Cw3paQuk9e/8bIwuYGhnkqQFjo=
-github.com/charmbracelet/log v0.2.2/go.mod h1:Zs11hKpb8l+UyX4y1srwZIGW+MPCXJHIty3MB9l/sno=
-github.com/charmbracelet/ssh v0.0.0-20221117183211-483d43d97103 h1:wpHMERIN0pQZE635jWwT1dISgfjbpUcEma+fbPKSMCU=
-github.com/charmbracelet/ssh v0.0.0-20221117183211-483d43d97103/go.mod h1:0Vm2/8yBljiLDnGJHU8ehswfawrEybGk33j5ssqKQVM=
+github.com/charmbracelet/log v0.2.3-0.20230713155356-557335e40e35 h1:VXEaJ1iM2L5N8T2WVbv4y631pzCD3O9s75dONqK+87g=
+github.com/charmbracelet/log v0.2.3-0.20230713155356-557335e40e35/go.mod h1:ZApwwzDbbETVTIRTk7724yQRJAXIktt98yGVMMaa3y8=
+github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc h1:JUm+5HigAM5utFiThwIDX9iU0BaheKpuNVr+umi3sFg=
+github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc/go.mod h1:F1vgddWsb/Yr/OZilFeRZEh5sE/qU0Dt1mKkmke6Zvg=
github.com/charmbracelet/wish v1.1.1 h1:KdICASKd2oh2JPvk1Z4CJtAi97cFErXF7NKienPICO4=
github.com/charmbracelet/wish v1.1.1/go.mod h1:xh4KZpSULw+Xqb9bcbhw92QAinVB75CVLWrFuyY6IVs=
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
@@ -161,7 +161,7 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
-github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.5.2 h1:ALmeCk/px5FSm1MAcFBAsVKZjDuMVj8Tm7FFIlMJnqU=
github.com/yuin/goldmark v1.5.2/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
@@ -189,14 +189,14 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
-golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -1,53 +0,0 @@
-package log
-
-import (
- "os"
- "path/filepath"
- "strconv"
- "strings"
- "time"
-
- "github.com/charmbracelet/log"
- "github.com/charmbracelet/soft-serve/server/config"
-)
-
-var contextKey = &struct{ string }{"logger"}
-
-// NewDefaultLogger returns a new logger with default settings.
-func NewDefaultLogger() *log.Logger {
- dp := os.Getenv("SOFT_SERVE_DATA_PATH")
- if dp == "" {
- dp = "data"
- }
-
- cfg, err := config.ParseConfig(filepath.Join(dp, "config.yaml"))
- if err != nil {
- log.Errorf("failed to parse config: %v", err)
- }
-
- logger := log.NewWithOptions(os.Stderr, log.Options{
- ReportTimestamp: true,
- TimeFormat: time.DateOnly,
- })
-
- if debug, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_DEBUG")); debug {
- logger.SetLevel(log.DebugLevel)
-
- if verbose, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_VERBOSE")); verbose {
- logger.SetReportCaller(true)
- }
- }
-
- logger.SetTimeFormat(cfg.Log.TimeFormat)
-
- switch strings.ToLower(cfg.Log.Format) {
- case "json":
- logger.SetFormatter(log.JSONFormatter)
- case "logfmt":
- logger.SetFormatter(log.LogfmtFormatter)
- case "text":
- logger.SetFormatter(log.TextFormatter)
- }
-
- return logger
-}
@@ -1,7 +1,7 @@
-package backend
+package access
// AccessLevel is the level of access allowed to a repo.
-type AccessLevel int
+type AccessLevel int // nolint: revive
const (
// NoAccess does not allow access to the repo.
@@ -0,0 +1,24 @@
+package access
+
+import "testing"
+
+func TestParseAccessLevel(t *testing.T) {
+ cases := []struct {
+ in string
+ out AccessLevel
+ }{
+ {"", -1},
+ {"foo", -1},
+ {AdminAccess.String(), AdminAccess},
+ {ReadOnlyAccess.String(), ReadOnlyAccess},
+ {ReadWriteAccess.String(), ReadWriteAccess},
+ {NoAccess.String(), NoAccess},
+ }
+
+ for _, c := range cases {
+ out := ParseAccessLevel(c.in)
+ if out != c.out {
+ t.Errorf("ParseAccessLevel(%q) => %d, want %d", c.in, out, c.out)
+ }
+ }
+}
@@ -1,47 +1,41 @@
package backend
import (
- "bytes"
"context"
- "github.com/charmbracelet/ssh"
- gossh "golang.org/x/crypto/ssh"
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/store"
+ "github.com/charmbracelet/soft-serve/server/store/database"
)
-// Backend is an interface that handles repositories management and any
-// non-Git related operations.
-type Backend interface {
- SettingsBackend
- RepositoryStore
- RepositoryMetadata
- RepositoryAccess
- UserStore
- UserAccess
- Hooks
-
- // WithContext returns a copy Backend with the given context.
- WithContext(ctx context.Context) Backend
-}
-
-// ParseAuthorizedKey parses an authorized key string into a public key.
-func ParseAuthorizedKey(ak string) (gossh.PublicKey, string, error) {
- pk, c, _, _, err := gossh.ParseAuthorizedKey([]byte(ak))
- return pk, c, err
+// Backend is the Soft Serve backend that handles users, repositories, and
+// server settings management and operations.
+type Backend struct {
+ ctx context.Context
+ cfg *config.Config
+ db *db.DB
+ store store.Store
+ logger *log.Logger
+ cache *cache
}
-// MarshalAuthorizedKey marshals a public key into an authorized key string.
-//
-// This is the inverse of ParseAuthorizedKey.
-// This function is a copy of ssh.MarshalAuthorizedKey, but without the trailing newline.
-// It returns an empty string if pk is nil.
-func MarshalAuthorizedKey(pk gossh.PublicKey) string {
- if pk == nil {
- return ""
+// New returns a new Soft Serve backend.
+func New(ctx context.Context, cfg *config.Config, db *db.DB) *Backend {
+ dbstore := database.New(ctx, db)
+ logger := log.FromContext(ctx).WithPrefix("backend")
+ b := &Backend{
+ ctx: ctx,
+ cfg: cfg,
+ db: db,
+ store: dbstore,
+ logger: logger,
}
- return string(bytes.TrimSuffix(gossh.MarshalAuthorizedKey(pk), []byte("\n")))
-}
-// KeysEqual returns whether the two public keys are equal.
-func KeysEqual(a, b gossh.PublicKey) bool {
- return ssh.KeysEqual(a, b)
+ // TODO: implement a proper caching interface
+ cache := newCache(b, 1000)
+ b.cache = cache
+
+ return b
}
@@ -0,0 +1,35 @@
+package backend
+
+import lru "github.com/hashicorp/golang-lru/v2"
+
+// TODO: implement a caching interface.
+type cache struct {
+ b *Backend
+ repos *lru.Cache[string, *repo]
+}
+
+func newCache(b *Backend, size int) *cache {
+ if size <= 0 {
+ size = 1
+ }
+ c := &cache{b: b}
+ cache, _ := lru.New[string, *repo](size)
+ c.repos = cache
+ return c
+}
+
+func (c *cache) Get(repo string) (*repo, bool) {
+ return c.repos.Get(repo)
+}
+
+func (c *cache) Set(repo string, r *repo) {
+ c.repos.Add(repo, r)
+}
+
+func (c *cache) Delete(repo string) {
+ c.repos.Remove(repo)
+}
+
+func (c *cache) Len() int {
+ return c.repos.Len()
+}
@@ -0,0 +1,78 @@
+package backend
+
+import (
+ "context"
+ "strings"
+
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/models"
+ "github.com/charmbracelet/soft-serve/server/utils"
+)
+
+// AddCollaborator adds a collaborator to a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) AddCollaborator(ctx context.Context, repo string, username string) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ repo = utils.SanitizeRepo(repo)
+ return db.WrapError(
+ d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.AddCollabByUsernameAndRepo(ctx, tx, username, repo)
+ }),
+ )
+}
+
+// Collaborators returns a list of collaborators for a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) Collaborators(ctx context.Context, repo string) ([]string, error) {
+ repo = utils.SanitizeRepo(repo)
+ var users []models.User
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ users, err = d.store.ListCollabsByRepoAsUsers(ctx, tx, repo)
+ return err
+ }); err != nil {
+ return nil, db.WrapError(err)
+ }
+
+ var usernames []string
+ for _, u := range users {
+ usernames = append(usernames, u.Username)
+ }
+
+ return usernames, nil
+}
+
+// IsCollaborator returns true if the user is a collaborator of the repository.
+//
+// It implements backend.Backend.
+func (d *Backend) IsCollaborator(ctx context.Context, repo string, username string) (bool, error) {
+ repo = utils.SanitizeRepo(repo)
+ var m models.Collab
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ m, err = d.store.GetCollabByUsernameAndRepo(ctx, tx, username, repo)
+ return err
+ }); err != nil {
+ return false, db.WrapError(err)
+ }
+
+ return m.ID > 0, nil
+}
+
+// RemoveCollaborator removes a collaborator from a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) RemoveCollaborator(ctx context.Context, repo string, username string) error {
+ repo = utils.SanitizeRepo(repo)
+ return db.WrapError(
+ d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.RemoveCollabByUsernameAndRepo(ctx, tx, username, repo)
+ }),
+ )
+}
@@ -2,11 +2,12 @@ package backend
import "context"
-var contextKey = &struct{ string }{"backend"}
+// ContextKey is the key for the backend in the context.
+var ContextKey = &struct{ string }{"backend"}
// FromContext returns the backend from a context.
-func FromContext(ctx context.Context) Backend {
- if b, ok := ctx.Value(contextKey).(Backend); ok {
+func FromContext(ctx context.Context) *Backend {
+ if b, ok := ctx.Value(ContextKey).(*Backend); ok {
return b
}
@@ -14,6 +15,6 @@ func FromContext(ctx context.Context) Backend {
}
// WithContext returns a new context with the backend attached.
-func WithContext(ctx context.Context, b Backend) context.Context {
- return context.WithValue(ctx, contextKey, b)
+func WithContext(ctx context.Context, b *Backend) context.Context {
+ return context.WithValue(ctx, ContextKey, b)
}
@@ -1,20 +1,80 @@
package backend
import (
+ "context"
"io"
+ "sync"
+
+ "github.com/charmbracelet/soft-serve/server/hooks"
+ "github.com/charmbracelet/soft-serve/server/proto"
)
-// HookArg is an argument to a git hook.
-type HookArg struct {
- OldSha string
- NewSha string
- RefName string
+var _ hooks.Hooks = (*Backend)(nil)
+
+// PostReceive is called by the git post-receive hook.
+//
+// It implements Hooks.
+func (d *Backend) PostReceive(_ context.Context, _ io.Writer, _ io.Writer, repo string, args []hooks.HookArg) {
+ d.logger.Debug("post-receive hook called", "repo", repo, "args", args)
+}
+
+// PreReceive is called by the git pre-receive hook.
+//
+// It implements Hooks.
+func (d *Backend) PreReceive(_ context.Context, _ io.Writer, _ io.Writer, repo string, args []hooks.HookArg) {
+ d.logger.Debug("pre-receive hook called", "repo", repo, "args", args)
}
-// Hooks provides an interface for git server-side hooks.
-type Hooks interface {
- PreReceive(stdout io.Writer, stderr io.Writer, repo string, args []HookArg)
- Update(stdout io.Writer, stderr io.Writer, repo string, arg HookArg)
- PostReceive(stdout io.Writer, stderr io.Writer, repo string, args []HookArg)
- PostUpdate(stdout io.Writer, stderr io.Writer, repo string, args ...string)
+// Update is called by the git update hook.
+//
+// It implements Hooks.
+func (d *Backend) Update(_ context.Context, _ io.Writer, _ io.Writer, repo string, arg hooks.HookArg) {
+ d.logger.Debug("update hook called", "repo", repo, "arg", arg)
+}
+
+// PostUpdate is called by the git post-update hook.
+//
+// It implements Hooks.
+func (d *Backend) PostUpdate(ctx context.Context, _ io.Writer, _ io.Writer, repo string, args ...string) {
+ d.logger.Debug("post-update hook called", "repo", repo, "args", args)
+
+ var wg sync.WaitGroup
+
+ // Populate last-modified file.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := populateLastModified(ctx, d, repo); err != nil {
+ d.logger.Error("error populating last-modified", "repo", repo, "err", err)
+ return
+ }
+ }()
+
+ wg.Wait()
+}
+
+func populateLastModified(ctx context.Context, d *Backend, name string) error {
+ var rr *repo
+ _rr, err := d.Repository(ctx, name)
+ if err != nil {
+ return err
+ }
+
+ if r, ok := _rr.(*repo); ok {
+ rr = r
+ } else {
+ return proto.ErrRepoNotExist
+ }
+
+ r, err := rr.Open()
+ if err != nil {
+ return err
+ }
+
+ c, err := r.LatestCommitTime()
+ if err != nil {
+ return err
+ }
+
+ return rr.writeLastModified(c)
}
@@ -1,86 +1,484 @@
package backend
import (
+ "bufio"
+ "context"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
"time"
"github.com/charmbracelet/soft-serve/git"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/models"
+ "github.com/charmbracelet/soft-serve/server/hooks"
+ "github.com/charmbracelet/soft-serve/server/proto"
+ "github.com/charmbracelet/soft-serve/server/utils"
)
-// RepositoryOptions are options for creating a new repository.
-type RepositoryOptions struct {
- Private bool
- Description string
- ProjectName string
- Mirror bool
- Hidden bool
-}
-
-// RepositoryStore is an interface for managing repositories.
-type RepositoryStore interface {
- // Repository finds the given repository.
- Repository(repo string) (Repository, error)
- // Repositories returns a list of all repositories.
- Repositories() ([]Repository, error)
- // CreateRepository creates a new repository.
- CreateRepository(name string, opts RepositoryOptions) (Repository, error)
- // ImportRepository creates a new repository from a Git repository.
- ImportRepository(name string, remote string, opts RepositoryOptions) (Repository, error)
- // DeleteRepository deletes a repository.
- DeleteRepository(name string) error
- // RenameRepository renames a repository.
- RenameRepository(oldName, newName string) error
-}
-
-// RepositoryMetadata is an interface for managing repository metadata.
-type RepositoryMetadata interface {
- // ProjectName returns the repository's project name.
- ProjectName(repo string) (string, error)
- // SetProjectName sets the repository's project name.
- SetProjectName(repo, name string) error
- // Description returns the repository's description.
- Description(repo string) (string, error)
- // SetDescription sets the repository's description.
- SetDescription(repo, desc string) error
- // IsPrivate returns whether the repository is private.
- IsPrivate(repo string) (bool, error)
- // SetPrivate sets whether the repository is private.
- SetPrivate(repo string, private bool) error
- // IsMirror returns whether the repository is a mirror.
- IsMirror(repo string) (bool, error)
- // IsHidden returns whether the repository is hidden.
- IsHidden(repo string) (bool, error)
- // SetHidden sets whether the repository is hidden.
- SetHidden(repo string, hidden bool) error
-}
-
-// RepositoryAccess is an interface for managing repository access.
-type RepositoryAccess interface {
- IsCollaborator(repo string, username string) (bool, error)
- // AddCollaborator adds the authorized key as a collaborator on the repository.
- AddCollaborator(repo string, username string) error
- // RemoveCollaborator removes the authorized key as a collaborator on the repository.
- RemoveCollaborator(repo string, username string) error
- // Collaborators returns a list of all collaborators on the repository.
- Collaborators(repo string) ([]string, error)
-}
-
-// Repository is a Git repository interface.
-type Repository interface {
- // Name returns the repository's name.
- Name() string
- // ProjectName returns the repository's project name.
- ProjectName() string
- // Description returns the repository's description.
- Description() string
- // IsPrivate returns whether the repository is private.
- IsPrivate() bool
- // IsMirror returns whether the repository is a mirror.
- IsMirror() bool
- // IsHidden returns whether the repository is hidden.
- IsHidden() bool
- // UpdatedAt returns the time the repository was last updated.
- // If the repository has never been updated, it returns the time it was created.
- UpdatedAt() time.Time
- // Open returns the underlying git.Repository.
- Open() (*git.Repository, error)
+func (d *Backend) reposPath() string {
+ return filepath.Join(d.cfg.DataPath, "repos")
+}
+
+// CreateRepository creates a new repository.
+//
+// It implements backend.Backend.
+func (d *Backend) CreateRepository(ctx context.Context, name string, opts proto.RepositoryOptions) (proto.Repository, error) {
+ name = utils.SanitizeRepo(name)
+ if err := utils.ValidateRepo(name); err != nil {
+ return nil, err
+ }
+
+ repo := name + ".git"
+ rp := filepath.Join(d.reposPath(), repo)
+
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ if err := d.store.CreateRepo(
+ ctx,
+ tx,
+ name,
+ opts.ProjectName,
+ opts.Description,
+ opts.Private,
+ opts.Hidden,
+ opts.Mirror,
+ ); err != nil {
+ return err
+ }
+
+ _, err := git.Init(rp, true)
+ if err != nil {
+ d.logger.Debug("failed to create repository", "err", err)
+ return err
+ }
+
+ return hooks.GenerateHooks(ctx, d.cfg, repo)
+ }); err != nil {
+ d.logger.Debug("failed to create repository in database", "err", err)
+ return nil, db.WrapError(err)
+ }
+
+ return d.Repository(ctx, name)
+}
+
+// ImportRepository imports a repository from remote.
+func (d *Backend) ImportRepository(ctx context.Context, name string, remote string, opts proto.RepositoryOptions) (proto.Repository, error) {
+ name = utils.SanitizeRepo(name)
+ if err := utils.ValidateRepo(name); err != nil {
+ return nil, err
+ }
+
+ repo := name + ".git"
+ rp := filepath.Join(d.reposPath(), repo)
+
+ if _, err := os.Stat(rp); err == nil || os.IsExist(err) {
+ return nil, proto.ErrRepoExist
+ }
+
+ copts := git.CloneOptions{
+ Bare: true,
+ Mirror: opts.Mirror,
+ Quiet: true,
+ CommandOptions: git.CommandOptions{
+ Timeout: -1,
+ Context: ctx,
+ Envs: []string{
+ fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
+ filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"),
+ d.cfg.SSH.ClientKeyPath,
+ ),
+ },
+ },
+ // Timeout: time.Hour,
+ }
+
+ if err := git.Clone(remote, rp, copts); err != nil {
+ d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp)
+ // Cleanup the mess!
+ if rerr := os.RemoveAll(rp); rerr != nil {
+ err = errors.Join(err, rerr)
+ }
+ return nil, err
+ }
+
+ return d.CreateRepository(ctx, name, opts)
+}
+
+// DeleteRepository deletes a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) DeleteRepository(ctx context.Context, name string) error {
+ name = utils.SanitizeRepo(name)
+ repo := name + ".git"
+ rp := filepath.Join(d.reposPath(), repo)
+
+ return d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ // Delete repo from cache
+ defer d.cache.Delete(name)
+
+ if err := d.store.DeleteRepoByName(ctx, tx, name); err != nil {
+ return err
+ }
+
+ return os.RemoveAll(rp)
+ })
+}
+
+// RenameRepository renames a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) RenameRepository(ctx context.Context, oldName string, newName string) error {
+ oldName = utils.SanitizeRepo(oldName)
+ if err := utils.ValidateRepo(oldName); err != nil {
+ return err
+ }
+
+ newName = utils.SanitizeRepo(newName)
+ if err := utils.ValidateRepo(newName); err != nil {
+ return err
+ }
+ oldRepo := oldName + ".git"
+ newRepo := newName + ".git"
+ op := filepath.Join(d.reposPath(), oldRepo)
+ np := filepath.Join(d.reposPath(), newRepo)
+ if _, err := os.Stat(op); err != nil {
+ return proto.ErrRepoNotExist
+ }
+
+ if _, err := os.Stat(np); err == nil {
+ return proto.ErrRepoExist
+ }
+
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ // Delete cache
+ defer d.cache.Delete(oldName)
+
+ if err := d.store.SetRepoNameByName(ctx, tx, oldName, newName); err != nil {
+ return err
+ }
+
+ // Make sure the new repository parent directory exists.
+ if err := os.MkdirAll(filepath.Dir(np), os.ModePerm); err != nil {
+ return err
+ }
+
+ return os.Rename(op, np)
+ }); err != nil {
+ return db.WrapError(err)
+ }
+
+ return nil
+}
+
+// Repositories returns a list of repositories per page.
+//
+// It implements backend.Backend.
+func (d *Backend) Repositories(ctx context.Context) ([]proto.Repository, error) {
+ repos := make([]proto.Repository, 0)
+
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ ms, err := d.store.GetAllRepos(ctx, tx)
+ if err != nil {
+ return err
+ }
+
+ for _, m := range ms {
+ r := &repo{
+ name: m.Name,
+ path: filepath.Join(d.reposPath(), m.Name+".git"),
+ repo: m,
+ }
+
+ // Cache repositories
+ d.cache.Set(m.Name, r)
+
+ repos = append(repos, r)
+ }
+
+ return nil
+ }); err != nil {
+ return nil, db.WrapError(err)
+ }
+
+ return repos, nil
+}
+
+// Repository returns a repository by name.
+//
+// It implements backend.Backend.
+func (d *Backend) Repository(ctx context.Context, name string) (proto.Repository, error) {
+ var m models.Repo
+ name = utils.SanitizeRepo(name)
+
+ if r, ok := d.cache.Get(name); ok && r != nil {
+ return r, nil
+ }
+
+ rp := filepath.Join(d.reposPath(), name+".git")
+ if _, err := os.Stat(rp); err != nil {
+ return nil, os.ErrNotExist
+ }
+
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ m, err = d.store.GetRepoByName(ctx, tx, name)
+ return err
+ }); err != nil {
+ return nil, db.WrapError(err)
+ }
+
+ r := &repo{
+ name: name,
+ path: rp,
+ repo: m,
+ }
+
+ // Add to cache
+ d.cache.Set(name, r)
+
+ return r, nil
+}
+
+// Description returns the description of a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) Description(ctx context.Context, name string) (string, error) {
+ name = utils.SanitizeRepo(name)
+ var desc string
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ desc, err = d.store.GetRepoDescriptionByName(ctx, tx, name)
+ return err
+ }); err != nil {
+ return "", db.WrapError(err)
+ }
+
+ return desc, nil
+}
+
+// IsMirror returns true if the repository is a mirror.
+//
+// It implements backend.Backend.
+func (d *Backend) IsMirror(ctx context.Context, name string) (bool, error) {
+ name = utils.SanitizeRepo(name)
+ var mirror bool
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ mirror, err = d.store.GetRepoIsMirrorByName(ctx, tx, name)
+ return err
+ }); err != nil {
+ return false, db.WrapError(err)
+ }
+ return mirror, nil
+}
+
+// IsPrivate returns true if the repository is private.
+//
+// It implements backend.Backend.
+func (d *Backend) IsPrivate(ctx context.Context, name string) (bool, error) {
+ name = utils.SanitizeRepo(name)
+ var private bool
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ private, err = d.store.GetRepoIsPrivateByName(ctx, tx, name)
+ return err
+ }); err != nil {
+ return false, db.WrapError(err)
+ }
+
+ return private, nil
+}
+
+// IsHidden returns true if the repository is hidden.
+//
+// It implements backend.Backend.
+func (d *Backend) IsHidden(ctx context.Context, name string) (bool, error) {
+ name = utils.SanitizeRepo(name)
+ var hidden bool
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ hidden, err = d.store.GetRepoIsHiddenByName(ctx, tx, name)
+ return err
+ }); err != nil {
+ return false, db.WrapError(err)
+ }
+
+ return hidden, nil
+}
+
+// ProjectName returns the project name of a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) ProjectName(ctx context.Context, name string) (string, error) {
+ name = utils.SanitizeRepo(name)
+ var pname string
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ pname, err = d.store.GetRepoProjectNameByName(ctx, tx, name)
+ return err
+ }); err != nil {
+ return "", db.WrapError(err)
+ }
+
+ return pname, nil
+}
+
+// SetHidden sets the hidden flag of a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) SetHidden(ctx context.Context, name string, hidden bool) error {
+ name = utils.SanitizeRepo(name)
+
+ // Delete cache
+ d.cache.Delete(name)
+
+ return db.WrapError(d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.SetRepoIsHiddenByName(ctx, tx, name, hidden)
+ }))
+}
+
+// SetDescription sets the description of a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) SetDescription(ctx context.Context, repo string, desc string) error {
+ repo = utils.SanitizeRepo(repo)
+
+ // Delete cache
+ d.cache.Delete(repo)
+
+ return d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.SetRepoDescriptionByName(ctx, tx, repo, desc)
+ })
+}
+
+// SetPrivate sets the private flag of a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) SetPrivate(ctx context.Context, repo string, private bool) error {
+ repo = utils.SanitizeRepo(repo)
+
+ // Delete cache
+ d.cache.Delete(repo)
+
+ return db.WrapError(
+ d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.SetRepoIsPrivateByName(ctx, tx, repo, private)
+ }),
+ )
+}
+
+// SetProjectName sets the project name of a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) SetProjectName(ctx context.Context, repo string, name string) error {
+ repo = utils.SanitizeRepo(repo)
+
+ // Delete cache
+ d.cache.Delete(repo)
+
+ return db.WrapError(
+ d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.SetRepoProjectNameByName(ctx, tx, repo, name)
+ }),
+ )
+}
+
+var _ proto.Repository = (*repo)(nil)
+
+// repo is a Git repository with metadata stored in a SQLite database.
+type repo struct {
+ name string
+ path string
+ repo models.Repo
+}
+
+// Description returns the repository's description.
+//
+// It implements backend.Repository.
+func (r *repo) Description() string {
+ return r.repo.Description
+}
+
+// IsMirror returns whether the repository is a mirror.
+//
+// It implements backend.Repository.
+func (r *repo) IsMirror() bool {
+ return r.repo.Mirror
+}
+
+// IsPrivate returns whether the repository is private.
+//
+// It implements backend.Repository.
+func (r *repo) IsPrivate() bool {
+ return r.repo.Private
+}
+
+// Name returns the repository's name.
+//
+// It implements backend.Repository.
+func (r *repo) Name() string {
+ return r.name
+}
+
+// Open opens the repository.
+//
+// It implements backend.Repository.
+func (r *repo) Open() (*git.Repository, error) {
+ return git.Open(r.path)
+}
+
+// ProjectName returns the repository's project name.
+//
+// It implements backend.Repository.
+func (r *repo) ProjectName() string {
+ return r.repo.ProjectName
+}
+
+// IsHidden returns whether the repository is hidden.
+//
+// It implements backend.Repository.
+func (r *repo) IsHidden() bool {
+ return r.repo.Hidden
+}
+
+// UpdatedAt returns the repository's last update time.
+func (r *repo) UpdatedAt() time.Time {
+ // Try to read the last modified time from the info directory.
+ if t, err := readOneline(filepath.Join(r.path, "info", "last-modified")); err == nil {
+ if t, err := time.Parse(time.RFC3339, t); err == nil {
+ return t
+ }
+ }
+
+ rr, err := git.Open(r.path)
+ if err == nil {
+ t, err := rr.LatestCommitTime()
+ if err == nil {
+ return t
+ }
+ }
+
+ return r.repo.UpdatedAt
+}
+
+func (r *repo) writeLastModified(t time.Time) error {
+ fp := filepath.Join(r.path, "info", "last-modified")
+ if err := os.MkdirAll(filepath.Dir(fp), os.ModePerm); err != nil {
+ return err
+ }
+
+ return os.WriteFile(fp, []byte(t.Format(time.RFC3339)), os.ModePerm)
+}
+
+func readOneline(path string) (string, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ return "", err
+ }
+
+ defer f.Close() // nolint: errcheck
+ s := bufio.NewScanner(f)
+ s.Scan()
+ return s.Text(), s.Err()
}
@@ -1,13 +1,58 @@
package backend
-// SettingsBackend is an interface that handles server configuration.
-type SettingsBackend interface {
- // AnonAccess returns the access level for anonymous users.
- AnonAccess() AccessLevel
- // SetAnonAccess sets the access level for anonymous users.
- SetAnonAccess(level AccessLevel) error
- // AllowKeyless returns true if keyless access is allowed.
- AllowKeyless() bool
- // SetAllowKeyless sets whether or not keyless access is allowed.
- SetAllowKeyless(allow bool) error
+import (
+ "context"
+
+ "github.com/charmbracelet/soft-serve/server/access"
+ "github.com/charmbracelet/soft-serve/server/db"
+)
+
+// AllowKeyless returns whether or not keyless access is allowed.
+//
+// It implements backend.Backend.
+func (b *Backend) AllowKeyless(ctx context.Context) bool {
+ var allow bool
+ if err := b.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ allow, err = b.store.GetAllowKeylessAccess(ctx, tx)
+ return err
+ }); err != nil {
+ return false
+ }
+
+ return allow
+}
+
+// SetAllowKeyless sets whether or not keyless access is allowed.
+//
+// It implements backend.Backend.
+func (b *Backend) SetAllowKeyless(ctx context.Context, allow bool) error {
+ return b.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return b.store.SetAllowKeylessAccess(ctx, tx, allow)
+ })
+}
+
+// AnonAccess returns the level of anonymous access.
+//
+// It implements backend.Backend.
+func (b *Backend) AnonAccess(ctx context.Context) access.AccessLevel {
+ var level access.AccessLevel
+ if err := b.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ level, err = b.store.GetAnonAccess(ctx, tx)
+ return err
+ }); err != nil {
+ return access.NoAccess
+ }
+
+ return level
+}
+
+// SetAnonAccess sets the level of anonymous access.
+//
+// It implements backend.Backend.
+func (b *Backend) SetAnonAccess(ctx context.Context, level access.AccessLevel) error {
+ return b.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return b.store.SetAnonAccess(ctx, tx, level)
+ })
}
@@ -1,141 +0,0 @@
-package sqlite
-
-import (
- "context"
- "database/sql"
- "errors"
- "fmt"
-
- "github.com/charmbracelet/soft-serve/server/backend"
- "github.com/jmoiron/sqlx"
- "modernc.org/sqlite"
- sqlite3 "modernc.org/sqlite/lib"
-)
-
-// Close closes the database.
-func (d *SqliteBackend) Close() error {
- return d.db.Close()
-}
-
-// init creates the database.
-func (d *SqliteBackend) init() error {
- return wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- if _, err := tx.Exec(sqlCreateSettingsTable); err != nil {
- return err
- }
- if _, err := tx.Exec(sqlCreateUserTable); err != nil {
- return err
- }
- if _, err := tx.Exec(sqlCreatePublicKeyTable); err != nil {
- return err
- }
- if _, err := tx.Exec(sqlCreateRepoTable); err != nil {
- return err
- }
- if _, err := tx.Exec(sqlCreateCollabTable); err != nil {
- return err
- }
-
- // Set default settings.
- if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "allow_keyless", true); err != nil {
- return err
- }
- if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "anon_access", backend.ReadOnlyAccess.String()); err != nil {
- return err
- }
-
- var init bool
- if err := tx.Get(&init, "SELECT value FROM settings WHERE key = 'init'"); err != nil && !errors.Is(err, sql.ErrNoRows) {
- return err
- }
-
- // Create default user.
- if !init {
- r, err := tx.Exec("INSERT OR IGNORE INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);", "admin", true)
- if err != nil {
- return err
- }
- userID, err := r.LastInsertId()
- if err != nil {
- return err
- }
-
- // Add initial keys
- // Don't use cfg.AdminKeys since it also includes the internal key
- // used for internal api access.
- for _, k := range d.cfg.InitialAdminKeys {
- pk, _, err := backend.ParseAuthorizedKey(k)
- if err != nil {
- d.logger.Error("error parsing initial admin key, skipping", "key", k, "err", err)
- continue
- }
-
- stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at)
- VALUES (?, ?, CURRENT_TIMESTAMP);`)
- if err != nil {
- return err
- }
-
- defer stmt.Close() // nolint: errcheck
- if _, err := stmt.Exec(userID, backend.MarshalAuthorizedKey(pk)); err != nil {
- return err
- }
- }
- }
-
- // set init flag
- if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "init", true); err != nil {
- return err
- }
-
- return nil
- })
-}
-
-func wrapDbErr(err error) error {
- if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return ErrNoRecord
- }
- if liteErr, ok := err.(*sqlite.Error); ok {
- code := liteErr.Code()
- if code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY ||
- code == sqlite3.SQLITE_CONSTRAINT_UNIQUE {
- return ErrDuplicateKey
- }
- }
- }
- return err
-}
-
-func wrapTx(db *sqlx.DB, ctx context.Context, fn func(tx *sqlx.Tx) error) error {
- tx, err := db.BeginTxx(ctx, nil)
- if err != nil {
- return fmt.Errorf("failed to begin transaction: %w", err)
- }
-
- if err := fn(tx); err != nil {
- return rollback(tx, err)
- }
-
- if err := tx.Commit(); err != nil {
- if errors.Is(err, sql.ErrTxDone) {
- // this is ok because whoever did finish the tx should have also written the error already.
- return nil
- }
- return fmt.Errorf("failed to commit transaction: %w", err)
- }
-
- return nil
-}
-
-func rollback(tx *sqlx.Tx, err error) error {
- if rerr := tx.Rollback(); rerr != nil {
- if errors.Is(rerr, sql.ErrTxDone) {
- return err
- }
- return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr)
- }
-
- return err
-}
@@ -1,20 +0,0 @@
-package sqlite
-
-import (
- "errors"
- "fmt"
-)
-
-var (
- // ErrDuplicateKey is returned when a unique constraint is violated.
- ErrDuplicateKey = errors.New("record already exists")
-
- // ErrNoRecord is returned when a record is not found.
- ErrNoRecord = errors.New("record not found")
-
- // ErrRepoNotExist is returned when a repository does not exist.
- ErrRepoNotExist = fmt.Errorf("repository does not exist")
-
- // ErrRepoExist is returned when a repository already exists.
- ErrRepoExist = fmt.Errorf("repository already exists")
-)
@@ -1,76 +0,0 @@
-package sqlite
-
-import (
- "io"
- "sync"
-
- "github.com/charmbracelet/soft-serve/server/backend"
-)
-
-// PostReceive is called by the git post-receive hook.
-//
-// It implements Hooks.
-func (d *SqliteBackend) PostReceive(stdout io.Writer, stderr io.Writer, repo string, args []backend.HookArg) {
- d.logger.Debug("post-receive hook called", "repo", repo, "args", args)
-}
-
-// PreReceive is called by the git pre-receive hook.
-//
-// It implements Hooks.
-func (d *SqliteBackend) PreReceive(stdout io.Writer, stderr io.Writer, repo string, args []backend.HookArg) {
- d.logger.Debug("pre-receive hook called", "repo", repo, "args", args)
-}
-
-// Update is called by the git update hook.
-//
-// It implements Hooks.
-func (d *SqliteBackend) Update(stdout io.Writer, stderr io.Writer, repo string, arg backend.HookArg) {
- d.logger.Debug("update hook called", "repo", repo, "arg", arg)
-}
-
-// PostUpdate is called by the git post-update hook.
-//
-// It implements Hooks.
-func (d *SqliteBackend) PostUpdate(stdout io.Writer, stderr io.Writer, repo string, args ...string) {
- d.logger.Debug("post-update hook called", "repo", repo, "args", args)
-
- var wg sync.WaitGroup
-
- // Populate last-modified file.
- wg.Add(1)
- go func() {
- defer wg.Done()
- if err := populateLastModified(d, repo); err != nil {
- d.logger.Error("error populating last-modified", "repo", repo, "err", err)
- return
- }
- }()
-
- wg.Wait()
-}
-
-func populateLastModified(d *SqliteBackend, repo string) error {
- var rr *Repo
- _rr, err := d.Repository(repo)
- if err != nil {
- return err
- }
-
- if r, ok := _rr.(*Repo); ok {
- rr = r
- } else {
- return ErrRepoNotExist
- }
-
- r, err := rr.Open()
- if err != nil {
- return err
- }
-
- c, err := r.LatestCommitTime()
- if err != nil {
- return err
- }
-
- return rr.writeLastModified(c)
-}
@@ -1,202 +0,0 @@
-package sqlite
-
-import (
- "bufio"
- "context"
- "os"
- "path/filepath"
- "sync"
- "time"
-
- "github.com/charmbracelet/soft-serve/git"
- "github.com/charmbracelet/soft-serve/server/backend"
- "github.com/jmoiron/sqlx"
-)
-
-var _ backend.Repository = (*Repo)(nil)
-
-// Repo is a Git repository with metadata stored in a SQLite database.
-type Repo struct {
- name string
- path string
- db *sqlx.DB
-
- // cache
- // updatedAt is cached in "last-modified" file.
- mu sync.Mutex
- desc *string
- projectName *string
- isMirror *bool
- isPrivate *bool
- isHidden *bool
-}
-
-// Description returns the repository's description.
-//
-// It implements backend.Repository.
-func (r *Repo) Description() string {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.desc != nil {
- return *r.desc
- }
-
- var desc string
- if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
- return tx.Get(&desc, "SELECT description FROM repo WHERE name = ?", r.name)
- }); err != nil {
- return ""
- }
-
- r.desc = &desc
- return desc
-}
-
-// IsMirror returns whether the repository is a mirror.
-//
-// It implements backend.Repository.
-func (r *Repo) IsMirror() bool {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.isMirror != nil {
- return *r.isMirror
- }
-
- var mirror bool
- if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
- return tx.Get(&mirror, "SELECT mirror FROM repo WHERE name = ?", r.name)
- }); err != nil {
- return false
- }
-
- r.isMirror = &mirror
- return mirror
-}
-
-// IsPrivate returns whether the repository is private.
-//
-// It implements backend.Repository.
-func (r *Repo) IsPrivate() bool {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.isPrivate != nil {
- return *r.isPrivate
- }
-
- var private bool
- if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
- return tx.Get(&private, "SELECT private FROM repo WHERE name = ?", r.name)
- }); err != nil {
- return false
- }
-
- r.isPrivate = &private
- return private
-}
-
-// Name returns the repository's name.
-//
-// It implements backend.Repository.
-func (r *Repo) Name() string {
- return r.name
-}
-
-// Open opens the repository.
-//
-// It implements backend.Repository.
-func (r *Repo) Open() (*git.Repository, error) {
- return git.Open(r.path)
-}
-
-// ProjectName returns the repository's project name.
-//
-// It implements backend.Repository.
-func (r *Repo) ProjectName() string {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.projectName != nil {
- return *r.projectName
- }
-
- var name string
- if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
- return tx.Get(&name, "SELECT project_name FROM repo WHERE name = ?", r.name)
- }); err != nil {
- return ""
- }
-
- r.projectName = &name
- return name
-}
-
-// IsHidden returns whether the repository is hidden.
-//
-// It implements backend.Repository.
-func (r *Repo) IsHidden() bool {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.isHidden != nil {
- return *r.isHidden
- }
-
- var hidden bool
- if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
- return tx.Get(&hidden, "SELECT hidden FROM repo WHERE name = ?", r.name)
- }); err != nil {
- return false
- }
-
- r.isHidden = &hidden
- return hidden
-}
-
-// UpdatedAt returns the repository's last update time.
-func (r *Repo) UpdatedAt() time.Time {
- var updatedAt time.Time
-
- // Try to read the last modified time from the info directory.
- if t, err := readOneline(filepath.Join(r.path, "info", "last-modified")); err == nil {
- if t, err := time.Parse(time.RFC3339, t); err == nil {
- return t
- }
- }
-
- rr, err := git.Open(r.path)
- if err == nil {
- t, err := rr.LatestCommitTime()
- if err == nil {
- updatedAt = t
- }
- }
-
- if updatedAt.IsZero() {
- if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error {
- return tx.Get(&updatedAt, "SELECT updated_at FROM repo WHERE name = ?", r.name)
- }); err != nil {
- return time.Time{}
- }
- }
-
- return updatedAt
-}
-
-func (r *Repo) writeLastModified(t time.Time) error {
- fp := filepath.Join(r.path, "info", "last-modified")
- if err := os.MkdirAll(filepath.Dir(fp), os.ModePerm); err != nil {
- return err
- }
-
- return os.WriteFile(fp, []byte(t.Format(time.RFC3339)), os.ModePerm)
-}
-
-func readOneline(path string) (string, error) {
- f, err := os.Open(path)
- if err != nil {
- return "", err
- }
-
- defer f.Close() // nolint: errcheck
- s := bufio.NewScanner(f)
- s.Scan()
- return s.Text(), s.Err()
-}
@@ -1,61 +0,0 @@
-package sqlite
-
-var (
- sqlCreateSettingsTable = `CREATE TABLE IF NOT EXISTS settings (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- key TEXT NOT NULL UNIQUE,
- value TEXT NOT NULL,
- created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
- updated_at DATETIME NOT NULL
- );`
-
- sqlCreateUserTable = `CREATE TABLE IF NOT EXISTS user (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- username TEXT UNIQUE,
- admin BOOLEAN NOT NULL,
- created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
- updated_at DATETIME NOT NULL
- );`
-
- sqlCreatePublicKeyTable = `CREATE TABLE IF NOT EXISTS public_key (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- user_id INTEGER NOT NULL,
- public_key TEXT NOT NULL UNIQUE,
- created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
- updated_at DATETIME NOT NULL,
- UNIQUE (user_id, public_key),
- CONSTRAINT user_id_fk
- FOREIGN KEY(user_id) REFERENCES user(id)
- ON DELETE CASCADE
- ON UPDATE CASCADE
- );`
-
- sqlCreateRepoTable = `CREATE TABLE IF NOT EXISTS repo (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT NOT NULL UNIQUE,
- project_name TEXT NOT NULL,
- description TEXT NOT NULL,
- private BOOLEAN NOT NULL,
- mirror BOOLEAN NOT NULL,
- hidden BOOLEAN NOT NULL,
- created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
- updated_at DATETIME NOT NULL
- );`
-
- sqlCreateCollabTable = `CREATE TABLE IF NOT EXISTS collab (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- user_id INTEGER NOT NULL,
- repo_id INTEGER NOT NULL,
- created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
- updated_at DATETIME NOT NULL,
- UNIQUE (user_id, repo_id),
- CONSTRAINT user_id_fk
- FOREIGN KEY(user_id) REFERENCES user(id)
- ON DELETE CASCADE
- ON UPDATE CASCADE,
- CONSTRAINT repo_id_fk
- FOREIGN KEY(repo_id) REFERENCES repo(id)
- ON DELETE CASCADE
- ON UPDATE CASCADE
- );`
-)
@@ -1,649 +0,0 @@
-package sqlite
-
-import (
- "context"
- "errors"
- "fmt"
- "os"
- "path/filepath"
- "strings"
-
- "github.com/charmbracelet/log"
- "github.com/charmbracelet/soft-serve/git"
- "github.com/charmbracelet/soft-serve/server/backend"
- "github.com/charmbracelet/soft-serve/server/config"
- "github.com/charmbracelet/soft-serve/server/hooks"
- "github.com/charmbracelet/soft-serve/server/utils"
- lru "github.com/hashicorp/golang-lru/v2"
- "github.com/jmoiron/sqlx"
- _ "modernc.org/sqlite" // sqlite driver
-)
-
-// SqliteBackend is a backend that uses a SQLite database as a Soft Serve
-// backend.
-type SqliteBackend struct { //nolint: revive
- cfg *config.Config
- ctx context.Context
- dp string
- db *sqlx.DB
- logger *log.Logger
-
- // Repositories cache
- cache *cache
-}
-
-var _ backend.Backend = (*SqliteBackend)(nil)
-
-func (d *SqliteBackend) reposPath() string {
- return filepath.Join(d.dp, "repos")
-}
-
-// NewSqliteBackend creates a new SqliteBackend.
-func NewSqliteBackend(ctx context.Context) (*SqliteBackend, error) {
- cfg := config.FromContext(ctx)
- dataPath := cfg.DataPath
- if err := os.MkdirAll(dataPath, os.ModePerm); err != nil {
- return nil, err
- }
-
- db, err := sqlx.Connect("sqlite", filepath.Join(dataPath, "soft-serve.db"+
- "?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)"))
- if err != nil {
- return nil, err
- }
-
- d := &SqliteBackend{
- cfg: cfg,
- ctx: ctx,
- dp: dataPath,
- db: db,
- logger: log.FromContext(ctx).WithPrefix("sqlite"),
- }
-
- // Set up LRU cache with size 1000
- d.cache = newCache(d, 1000)
-
- if err := d.init(); err != nil {
- return nil, err
- }
-
- if err := d.db.Ping(); err != nil {
- return nil, err
- }
-
- return d, d.initRepos()
-}
-
-// WithContext returns a copy of SqliteBackend with the given context.
-func (d SqliteBackend) WithContext(ctx context.Context) backend.Backend {
- d.ctx = ctx
- return &d
-}
-
-// AllowKeyless returns whether or not keyless access is allowed.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) AllowKeyless() bool {
- var allow bool
- if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- return tx.Get(&allow, "SELECT value FROM settings WHERE key = ?;", "allow_keyless")
- }); err != nil {
- return false
- }
-
- return allow
-}
-
-// AnonAccess returns the level of anonymous access.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) AnonAccess() backend.AccessLevel {
- var level string
- if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- return tx.Get(&level, "SELECT value FROM settings WHERE key = ?;", "anon_access")
- }); err != nil {
- return backend.NoAccess
- }
-
- return backend.ParseAccessLevel(level)
-}
-
-// SetAllowKeyless sets whether or not keyless access is allowed.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) SetAllowKeyless(allow bool) error {
- return wrapDbErr(
- wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- _, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", allow, "allow_keyless")
- return err
- }),
- )
-}
-
-// SetAnonAccess sets the level of anonymous access.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) SetAnonAccess(level backend.AccessLevel) error {
- return wrapDbErr(
- wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- _, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", level.String(), "anon_access")
- return err
- }),
- )
-}
-
-// CreateRepository creates a new repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOptions) (backend.Repository, error) {
- name = utils.SanitizeRepo(name)
- if err := utils.ValidateRepo(name); err != nil {
- return nil, err
- }
-
- repo := name + ".git"
- rp := filepath.Join(d.reposPath(), repo)
-
- if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- if _, err := tx.Exec(`INSERT INTO repo (name, project_name, description, private, mirror, hidden, updated_at)
- VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP);`,
- name, opts.ProjectName, opts.Description, opts.Private, opts.Mirror, opts.Hidden); err != nil {
- return err
- }
-
- _, err := git.Init(rp, true)
- if err != nil {
- d.logger.Debug("failed to create repository", "err", err)
- return err
- }
-
- return nil
- }); err != nil {
- d.logger.Debug("failed to create repository in database", "err", err)
- return nil, wrapDbErr(err)
- }
-
- r := &Repo{
- name: name,
- path: rp,
- db: d.db,
- }
-
- // Set cache
- d.cache.Set(name, r)
-
- return r, d.initRepo(name)
-}
-
-// ImportRepository imports a repository from remote.
-func (d *SqliteBackend) ImportRepository(name string, remote string, opts backend.RepositoryOptions) (backend.Repository, error) {
- name = utils.SanitizeRepo(name)
- if err := utils.ValidateRepo(name); err != nil {
- return nil, err
- }
-
- repo := name + ".git"
- rp := filepath.Join(d.reposPath(), repo)
-
- if _, err := os.Stat(rp); err == nil || os.IsExist(err) {
- return nil, ErrRepoExist
- }
-
- copts := git.CloneOptions{
- Bare: true,
- Mirror: opts.Mirror,
- Quiet: true,
- CommandOptions: git.CommandOptions{
- Timeout: -1,
- Context: d.ctx,
- Envs: []string{
- fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
- filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"),
- d.cfg.SSH.ClientKeyPath,
- ),
- },
- },
- // Timeout: time.Hour,
- }
-
- if err := git.Clone(remote, rp, copts); err != nil {
- d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp)
- // Cleanup the mess!
- if rerr := os.RemoveAll(rp); rerr != nil {
- err = errors.Join(err, rerr)
- }
- return nil, err
- }
-
- return d.CreateRepository(name, opts)
-}
-
-// DeleteRepository deletes a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) DeleteRepository(name string) error {
- name = utils.SanitizeRepo(name)
- repo := name + ".git"
- rp := filepath.Join(d.reposPath(), repo)
-
- return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- // Delete repo from cache
- defer d.cache.Delete(name)
-
- if _, err := tx.Exec("DELETE FROM repo WHERE name = ?;", name); err != nil {
- return err
- }
-
- return os.RemoveAll(rp)
- })
-}
-
-// RenameRepository renames a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) RenameRepository(oldName string, newName string) error {
- oldName = utils.SanitizeRepo(oldName)
- if err := utils.ValidateRepo(oldName); err != nil {
- return err
- }
-
- newName = utils.SanitizeRepo(newName)
- if err := utils.ValidateRepo(newName); err != nil {
- return err
- }
- oldRepo := oldName + ".git"
- newRepo := newName + ".git"
- op := filepath.Join(d.reposPath(), oldRepo)
- np := filepath.Join(d.reposPath(), newRepo)
- if _, err := os.Stat(op); err != nil {
- return ErrRepoNotExist
- }
-
- if _, err := os.Stat(np); err == nil {
- return ErrRepoExist
- }
-
- if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- // Delete cache
- defer d.cache.Delete(oldName)
-
- _, err := tx.Exec("UPDATE repo SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", newName, oldName)
- if err != nil {
- return err
- }
-
- // Make sure the new repository parent directory exists.
- if err := os.MkdirAll(filepath.Dir(np), os.ModePerm); err != nil {
- return err
- }
-
- if err := os.Rename(op, np); err != nil {
- return err
- }
-
- return nil
- }); err != nil {
- return wrapDbErr(err)
- }
-
- return nil
-}
-
-// Repositories returns a list of all repositories.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) Repositories() ([]backend.Repository, error) {
- repos := make([]backend.Repository, 0)
-
- if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- rows, err := tx.Query("SELECT name FROM repo")
- if err != nil {
- return err
- }
-
- defer rows.Close() // nolint: errcheck
- for rows.Next() {
- var name string
- if err := rows.Scan(&name); err != nil {
- return err
- }
-
- if r, ok := d.cache.Get(name); ok && r != nil {
- repos = append(repos, r)
- continue
- }
-
- r := &Repo{
- name: name,
- path: filepath.Join(d.reposPath(), name+".git"),
- db: d.db,
- }
-
- // Cache repositories
- d.cache.Set(name, r)
-
- repos = append(repos, r)
- }
-
- return nil
- }); err != nil {
- return nil, wrapDbErr(err)
- }
-
- return repos, nil
-}
-
-// Repository returns a repository by name.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) {
- repo = utils.SanitizeRepo(repo)
-
- if r, ok := d.cache.Get(repo); ok && r != nil {
- return r, nil
- }
-
- rp := filepath.Join(d.reposPath(), repo+".git")
- if _, err := os.Stat(rp); err != nil {
- return nil, os.ErrNotExist
- }
-
- var count int
- if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- return tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo)
- }); err != nil {
- return nil, wrapDbErr(err)
- }
-
- if count == 0 {
- d.logger.Warn("repository exists but not found in database", "repo", repo)
- return nil, ErrRepoNotExist
- }
-
- r := &Repo{
- name: repo,
- path: rp,
- db: d.db,
- }
-
- // Add to cache
- d.cache.Set(repo, r)
-
- return r, nil
-}
-
-// Description returns the description of a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) Description(repo string) (string, error) {
- r, err := d.Repository(repo)
- if err != nil {
- return "", err
- }
-
- return r.Description(), nil
-}
-
-// IsMirror returns true if the repository is a mirror.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) IsMirror(repo string) (bool, error) {
- r, err := d.Repository(repo)
- if err != nil {
- return false, err
- }
-
- return r.IsMirror(), nil
-}
-
-// IsPrivate returns true if the repository is private.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) IsPrivate(repo string) (bool, error) {
- r, err := d.Repository(repo)
- if err != nil {
- return false, err
- }
-
- return r.IsPrivate(), nil
-}
-
-// IsHidden returns true if the repository is hidden.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) IsHidden(repo string) (bool, error) {
- r, err := d.Repository(repo)
- if err != nil {
- return false, err
- }
-
- return r.IsHidden(), nil
-}
-
-// SetHidden sets the hidden flag of a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) SetHidden(repo string, hidden bool) error {
- repo = utils.SanitizeRepo(repo)
-
- // Delete cache
- d.cache.Delete(repo)
-
- return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- var count int
- if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil {
- return err
- }
- if count == 0 {
- return ErrRepoNotExist
- }
- _, err := tx.Exec("UPDATE repo SET hidden = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", hidden, repo)
- return err
- }))
-}
-
-// ProjectName returns the project name of a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) ProjectName(repo string) (string, error) {
- r, err := d.Repository(repo)
- if err != nil {
- return "", err
- }
-
- return r.ProjectName(), nil
-}
-
-// SetDescription sets the description of a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) SetDescription(repo string, desc string) error {
- repo = utils.SanitizeRepo(repo)
-
- // Delete cache
- d.cache.Delete(repo)
-
- return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- var count int
- if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil {
- return err
- }
- if count == 0 {
- return ErrRepoNotExist
- }
- _, err := tx.Exec("UPDATE repo SET description = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?", desc, repo)
- return err
- })
-}
-
-// SetPrivate sets the private flag of a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) SetPrivate(repo string, private bool) error {
- repo = utils.SanitizeRepo(repo)
-
- // Delete cache
- d.cache.Delete(repo)
-
- return wrapDbErr(
- wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- var count int
- if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil {
- return err
- }
- if count == 0 {
- return ErrRepoNotExist
- }
- _, err := tx.Exec("UPDATE repo SET private = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?", private, repo)
- return err
- }),
- )
-}
-
-// SetProjectName sets the project name of a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) SetProjectName(repo string, name string) error {
- repo = utils.SanitizeRepo(repo)
-
- // Delete cache
- d.cache.Delete(repo)
-
- return wrapDbErr(
- wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- var count int
- if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil {
- return err
- }
- if count == 0 {
- return ErrRepoNotExist
- }
- _, err := tx.Exec("UPDATE repo SET project_name = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?", name, repo)
- return err
- }),
- )
-}
-
-// AddCollaborator adds a collaborator to a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) AddCollaborator(repo string, username string) error {
- username = strings.ToLower(username)
- if err := utils.ValidateUsername(username); err != nil {
- return err
- }
-
- repo = utils.SanitizeRepo(repo)
- return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- _, err := tx.Exec(`INSERT INTO collab (user_id, repo_id, updated_at)
- VALUES (
- (SELECT id FROM user WHERE username = ?),
- (SELECT id FROM repo WHERE name = ?),
- CURRENT_TIMESTAMP
- );`, username, repo)
- return err
- }),
- )
-}
-
-// Collaborators returns a list of collaborators for a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) Collaborators(repo string) ([]string, error) {
- repo = utils.SanitizeRepo(repo)
- var users []string
- if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- return tx.Select(&users, `SELECT user.username FROM user
- INNER JOIN collab ON user.id = collab.user_id
- INNER JOIN repo ON repo.id = collab.repo_id
- WHERE repo.name = ?`, repo)
- }); err != nil {
- return nil, wrapDbErr(err)
- }
-
- return users, nil
-}
-
-// IsCollaborator returns true if the user is a collaborator of the repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) IsCollaborator(repo string, username string) (bool, error) {
- repo = utils.SanitizeRepo(repo)
- var count int
- if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- return tx.Get(&count, `SELECT COUNT(*) FROM user
- INNER JOIN collab ON user.id = collab.user_id
- INNER JOIN repo ON repo.id = collab.repo_id
- WHERE repo.name = ? AND user.username = ?`, repo, username)
- }); err != nil {
- return false, wrapDbErr(err)
- }
-
- return count > 0, nil
-}
-
-// RemoveCollaborator removes a collaborator from a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) RemoveCollaborator(repo string, username string) error {
- repo = utils.SanitizeRepo(repo)
- return wrapDbErr(
- wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error {
- _, err := tx.Exec(`DELETE FROM collab
- WHERE user_id = (SELECT id FROM user WHERE username = ?)
- AND repo_id = (SELECT id FROM repo WHERE name = ?)`, username, repo)
- return err
- }),
- )
-}
-
-func (d *SqliteBackend) initRepo(repo string) error {
- return hooks.GenerateHooks(d.ctx, d.cfg, repo)
-}
-
-func (d *SqliteBackend) initRepos() error {
- repos, err := d.Repositories()
- if err != nil {
- return err
- }
-
- for _, repo := range repos {
- if err := d.initRepo(repo.Name()); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-// TODO: implement a caching interface.
-type cache struct {
- b *SqliteBackend
- repos *lru.Cache[string, *Repo]
-}
-
-func newCache(b *SqliteBackend, size int) *cache {
- if size <= 0 {
- size = 1
- }
- c := &cache{b: b}
- cache, _ := lru.New[string, *Repo](size)
- c.repos = cache
- return c
-}
-
-func (c *cache) Get(repo string) (*Repo, bool) {
- return c.repos.Get(repo)
-}
-
-func (c *cache) Set(repo string, r *Repo) {
- c.repos.Add(repo, r)
-}
-
-func (c *cache) Delete(repo string) {
- c.repos.Remove(repo)
-}
-
-func (c *cache) Len() int {
- return c.repos.Len()
-}
@@ -1,365 +0,0 @@
-package sqlite
-
-import (
- "context"
- "strings"
-
- "github.com/charmbracelet/soft-serve/server/backend"
- "github.com/charmbracelet/soft-serve/server/utils"
- "github.com/jmoiron/sqlx"
- "golang.org/x/crypto/ssh"
-)
-
-// User represents a user.
-type User struct {
- username string
- db *sqlx.DB
-}
-
-var _ backend.User = (*User)(nil)
-
-// IsAdmin returns whether the user is an admin.
-//
-// It implements backend.User.
-func (u *User) IsAdmin() bool {
- var admin bool
- if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error {
- return tx.Get(&admin, "SELECT admin FROM user WHERE username = ?", u.username)
- }); err != nil {
- return false
- }
-
- return admin
-}
-
-// PublicKeys returns the user's public keys.
-//
-// It implements backend.User.
-func (u *User) PublicKeys() []ssh.PublicKey {
- var keys []ssh.PublicKey
- if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error {
- var keyStrings []string
- if err := tx.Select(&keyStrings, `SELECT public_key
- FROM public_key
- INNER JOIN user ON user.id = public_key.user_id
- WHERE user.username = ?
- ORDER BY public_key.id asc;`, u.username); err != nil {
- return err
- }
-
- for _, keyString := range keyStrings {
- key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
- if err != nil {
- return err
- }
- keys = append(keys, key)
- }
-
- return nil
- }); err != nil {
- return nil
- }
-
- return keys
-}
-
-// Username returns the user's username.
-//
-// It implements backend.User.
-func (u *User) Username() string {
- return u.username
-}
-
-// AccessLevel returns the access level of a user for a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) AccessLevel(repo string, username string) backend.AccessLevel {
- anon := d.AnonAccess()
- user, _ := d.User(username)
- // If the user is an admin, they have admin access.
- if user != nil && user.IsAdmin() {
- return backend.AdminAccess
- }
-
- // If the repository exists, check if the user is a collaborator.
- r, _ := d.Repository(repo)
- if r != nil {
- // If the user is a collaborator, they have read/write access.
- isCollab, _ := d.IsCollaborator(repo, username)
- if isCollab {
- if anon > backend.ReadWriteAccess {
- return anon
- }
- return backend.ReadWriteAccess
- }
-
- // If the repository is private, the user has no access.
- if r.IsPrivate() {
- return backend.NoAccess
- }
-
- // Otherwise, the user has read-only access.
- return backend.ReadOnlyAccess
- }
-
- if user != nil {
- // If the repository doesn't exist, the user has read/write access.
- if anon > backend.ReadWriteAccess {
- return anon
- }
-
- return backend.ReadWriteAccess
- }
-
- // If the user doesn't exist, give them the anonymous access level.
- return anon
-}
-
-// AccessLevelByPublicKey returns the access level of a user's public key for a repository.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) AccessLevelByPublicKey(repo string, pk ssh.PublicKey) backend.AccessLevel {
- for _, k := range d.cfg.AdminKeys() {
- if backend.KeysEqual(pk, k) {
- return backend.AdminAccess
- }
- }
-
- user, _ := d.UserByPublicKey(pk)
- if user != nil {
- return d.AccessLevel(repo, user.Username())
- }
-
- return d.AccessLevel(repo, "")
-}
-
-// AddPublicKey adds a public key to a user.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) AddPublicKey(username string, pk ssh.PublicKey) error {
- username = strings.ToLower(username)
- if err := utils.ValidateUsername(username); err != nil {
- return err
- }
-
- return wrapDbErr(
- wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- var userID int
- if err := tx.Get(&userID, "SELECT id FROM user WHERE username = ?", username); err != nil {
- return err
- }
-
- _, err := tx.Exec(`INSERT INTO public_key (user_id, public_key, updated_at)
- VALUES (?, ?, CURRENT_TIMESTAMP);`, userID, backend.MarshalAuthorizedKey(pk))
- return err
- }),
- )
-}
-
-// CreateUser creates a new user.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) CreateUser(username string, opts backend.UserOptions) (backend.User, error) {
- username = strings.ToLower(username)
- if err := utils.ValidateUsername(username); err != nil {
- return nil, err
- }
-
- var user *User
- if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- stmt, err := tx.Prepare("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);")
- if err != nil {
- return err
- }
-
- defer stmt.Close() // nolint: errcheck
- r, err := stmt.Exec(username, opts.Admin)
- if err != nil {
- return err
- }
-
- if len(opts.PublicKeys) > 0 {
- userID, err := r.LastInsertId()
- if err != nil {
- d.logger.Error("error getting last insert id")
- return err
- }
-
- for _, pk := range opts.PublicKeys {
- stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at)
- VALUES (?, ?, CURRENT_TIMESTAMP);`)
- if err != nil {
- return err
- }
-
- defer stmt.Close() // nolint: errcheck
- if _, err := stmt.Exec(userID, backend.MarshalAuthorizedKey(pk)); err != nil {
- return err
- }
- }
- }
-
- user = &User{
- db: d.db,
- username: username,
- }
- return nil
- }); err != nil {
- return nil, wrapDbErr(err)
- }
-
- return user, nil
-}
-
-// DeleteUser deletes a user.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) DeleteUser(username string) error {
- username = strings.ToLower(username)
- if err := utils.ValidateUsername(username); err != nil {
- return err
- }
-
- return wrapDbErr(
- wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- _, err := tx.Exec("DELETE FROM user WHERE username = ?", username)
- return err
- }),
- )
-}
-
-// RemovePublicKey removes a public key from a user.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) RemovePublicKey(username string, pk ssh.PublicKey) error {
- return wrapDbErr(
- wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- _, err := tx.Exec(`DELETE FROM public_key
- WHERE user_id = (SELECT id FROM user WHERE username = ?)
- AND public_key = ?;`, username, backend.MarshalAuthorizedKey(pk))
- return err
- }),
- )
-}
-
-// ListPublicKeys lists the public keys of a user.
-func (d *SqliteBackend) ListPublicKeys(username string) ([]ssh.PublicKey, error) {
- username = strings.ToLower(username)
- if err := utils.ValidateUsername(username); err != nil {
- return nil, err
- }
-
- keys := make([]ssh.PublicKey, 0)
- if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- var keyStrings []string
- if err := tx.Select(&keyStrings, `SELECT public_key
- FROM public_key
- INNER JOIN user ON user.id = public_key.user_id
- WHERE user.username = ?;`, username); err != nil {
- return err
- }
-
- for _, keyString := range keyStrings {
- key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
- if err != nil {
- return err
- }
- keys = append(keys, key)
- }
-
- return nil
- }); err != nil {
- return nil, wrapDbErr(err)
- }
-
- return keys, nil
-}
-
-// SetUsername sets the username of a user.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) SetUsername(username string, newUsername string) error {
- username = strings.ToLower(username)
- if err := utils.ValidateUsername(username); err != nil {
- return err
- }
-
- return wrapDbErr(
- wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- _, err := tx.Exec("UPDATE user SET username = ? WHERE username = ?", newUsername, username)
- return err
- }),
- )
-}
-
-// SetAdmin sets the admin flag of a user.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) SetAdmin(username string, admin bool) error {
- username = strings.ToLower(username)
- if err := utils.ValidateUsername(username); err != nil {
- return err
- }
-
- return wrapDbErr(
- wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- _, err := tx.Exec("UPDATE user SET admin = ? WHERE username = ?", admin, username)
- return err
- }),
- )
-}
-
-// User finds a user by username.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) User(username string) (backend.User, error) {
- username = strings.ToLower(username)
- if err := utils.ValidateUsername(username); err != nil {
- return nil, err
- }
-
- if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- return tx.Get(&username, "SELECT username FROM user WHERE username = ?", username)
- }); err != nil {
- return nil, wrapDbErr(err)
- }
-
- return &User{
- db: d.db,
- username: username,
- }, nil
-}
-
-// UserByPublicKey finds a user by public key.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) UserByPublicKey(pk ssh.PublicKey) (backend.User, error) {
- var username string
- if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- return tx.Get(&username, `SELECT user.username
- FROM public_key
- INNER JOIN user ON user.id = public_key.user_id
- WHERE public_key.public_key = ?;`, backend.MarshalAuthorizedKey(pk))
- }); err != nil {
- return nil, wrapDbErr(err)
- }
-
- return &User{
- db: d.db,
- username: username,
- }, nil
-}
-
-// Users returns all users.
-//
-// It implements backend.Backend.
-func (d *SqliteBackend) Users() ([]string, error) {
- var users []string
- if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
- return tx.Select(&users, "SELECT username FROM user")
- }); err != nil {
- return nil, wrapDbErr(err)
- }
-
- return users, nil
-}
@@ -1,55 +1,289 @@
package backend
import (
+ "context"
+ "strings"
+
+ "github.com/charmbracelet/soft-serve/server/access"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/models"
+ "github.com/charmbracelet/soft-serve/server/proto"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
+ "github.com/charmbracelet/soft-serve/server/utils"
"golang.org/x/crypto/ssh"
)
-// User is an interface representing a user.
-type User interface {
- // Username returns the user's username.
- Username() string
- // IsAdmin returns whether the user is an admin.
- IsAdmin() bool
- // PublicKeys returns the user's public keys.
- PublicKeys() []ssh.PublicKey
-}
-
-// UserAccess is an interface that handles user access to repositories.
-type UserAccess interface {
- // AccessLevel returns the access level of the username to the repository.
- AccessLevel(repo string, username string) AccessLevel
- // AccessLevelByPublicKey returns the access level of the public key to the repository.
- AccessLevelByPublicKey(repo string, pk ssh.PublicKey) AccessLevel
-}
-
-// UserStore is an interface for managing users.
-type UserStore interface {
- // User finds the given user.
- User(username string) (User, error)
- // UserByPublicKey finds the user with the given public key.
- UserByPublicKey(pk ssh.PublicKey) (User, error)
- // Users returns a list of all users.
- Users() ([]string, error)
- // CreateUser creates a new user.
- CreateUser(username string, opts UserOptions) (User, error)
- // DeleteUser deletes a user.
- DeleteUser(username string) error
- // SetUsername sets the username of the user.
- SetUsername(oldUsername string, newUsername string) error
- // SetAdmin sets whether the user is an admin.
- SetAdmin(username string, admin bool) error
- // AddPublicKey adds a public key to the user.
- AddPublicKey(username string, pk ssh.PublicKey) error
- // RemovePublicKey removes a public key from the user.
- RemovePublicKey(username string, pk ssh.PublicKey) error
- // ListPublicKeys lists the public keys of the user.
- ListPublicKeys(username string) ([]ssh.PublicKey, error)
-}
-
-// UserOptions are options for creating a user.
-type UserOptions struct {
- // Admin is whether the user is an admin.
- Admin bool
- // PublicKeys are the user's public keys.
- PublicKeys []ssh.PublicKey
+// AccessLevel returns the access level of a user for a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) AccessLevel(ctx context.Context, repo string, username string) access.AccessLevel {
+ anon := d.AnonAccess(ctx)
+ user, _ := d.User(ctx, username)
+ // If the user is an admin, they have admin access.
+ if user != nil && user.IsAdmin() {
+ return access.AdminAccess
+ }
+
+ // If the repository exists, check if the user is a collaborator.
+ r, _ := d.Repository(ctx, repo)
+ if r != nil {
+ // If the user is a collaborator, they have read/write access.
+ isCollab, _ := d.IsCollaborator(ctx, repo, username)
+ if isCollab {
+ if anon > access.ReadWriteAccess {
+ return anon
+ }
+ return access.ReadWriteAccess
+ }
+
+ // If the repository is private, the user has no access.
+ if r.IsPrivate() {
+ return access.NoAccess
+ }
+
+ // Otherwise, the user has read-only access.
+ return access.ReadOnlyAccess
+ }
+
+ if user != nil {
+ // If the repository doesn't exist, the user has read/write access.
+ if anon > access.ReadWriteAccess {
+ return anon
+ }
+
+ return access.ReadWriteAccess
+ }
+
+ // If the user doesn't exist, give them the anonymous access level.
+ return anon
+}
+
+// AccessLevelByPublicKey returns the access level of a user's public key for a repository.
+//
+// It implements backend.Backend.
+func (d *Backend) AccessLevelByPublicKey(ctx context.Context, repo string, pk ssh.PublicKey) access.AccessLevel {
+ for _, k := range d.cfg.AdminKeys() {
+ if sshutils.KeysEqual(pk, k) {
+ return access.AdminAccess
+ }
+ }
+
+ user, _ := d.UserByPublicKey(ctx, pk)
+ if user != nil {
+ return d.AccessLevel(ctx, repo, user.Username())
+ }
+
+ return d.AccessLevel(ctx, repo, "")
+}
+
+// User finds a user by username.
+//
+// It implements backend.Backend.
+func (d *Backend) User(ctx context.Context, username string) (proto.User, error) {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return nil, err
+ }
+
+ var m models.User
+ var pks []ssh.PublicKey
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ m, err = d.store.FindUserByUsername(ctx, tx, username)
+ if err != nil {
+ return err
+ }
+
+ pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID)
+ return err
+ }); err != nil {
+ return nil, db.WrapError(err)
+ }
+
+ return &user{
+ user: m,
+ publicKeys: pks,
+ }, nil
+}
+
+// UserByPublicKey finds a user by public key.
+//
+// It implements backend.Backend.
+func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto.User, error) {
+ var m models.User
+ var pks []ssh.PublicKey
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ m, err = d.store.FindUserByPublicKey(ctx, tx, pk)
+ if err != nil {
+ return err
+ }
+
+ pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID)
+ return err
+ }); err != nil {
+ return nil, db.WrapError(err)
+ }
+
+ return &user{
+ user: m,
+ publicKeys: pks,
+ }, nil
+}
+
+// Users returns all users.
+//
+// It implements backend.Backend.
+func (d *Backend) Users(ctx context.Context) ([]string, error) {
+ var users []string
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ ms, err := d.store.GetAllUsers(ctx, tx)
+ if err != nil {
+ return err
+ }
+
+ for _, m := range ms {
+ users = append(users, m.Username)
+ }
+
+ return nil
+ }); err != nil {
+ return nil, db.WrapError(err)
+ }
+
+ return users, nil
+}
+
+// AddPublicKey adds a public key to a user.
+//
+// It implements backend.Backend.
+func (d *Backend) AddPublicKey(ctx context.Context, username string, pk ssh.PublicKey) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ return db.WrapError(
+ d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.AddPublicKeyByUsername(ctx, tx, username, pk)
+ }),
+ )
+}
+
+// CreateUser creates a new user.
+//
+// It implements backend.Backend.
+func (d *Backend) CreateUser(ctx context.Context, username string, opts proto.UserOptions) (proto.User, error) {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return nil, err
+ }
+
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.CreateUser(ctx, tx, username, opts.Admin, opts.PublicKeys)
+ }); err != nil {
+ return nil, db.WrapError(err)
+ }
+
+ return d.User(ctx, username)
+}
+
+// DeleteUser deletes a user.
+//
+// It implements backend.Backend.
+func (d *Backend) DeleteUser(ctx context.Context, username string) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ return db.WrapError(
+ d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.DeleteUserByUsername(ctx, tx, username)
+ }),
+ )
+}
+
+// RemovePublicKey removes a public key from a user.
+//
+// It implements backend.Backend.
+func (d *Backend) RemovePublicKey(ctx context.Context, username string, pk ssh.PublicKey) error {
+ return db.WrapError(
+ d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.RemovePublicKeyByUsername(ctx, tx, username, pk)
+ }),
+ )
+}
+
+// ListPublicKeys lists the public keys of a user.
+func (d *Backend) ListPublicKeys(ctx context.Context, username string) ([]ssh.PublicKey, error) {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return nil, err
+ }
+
+ var keys []ssh.PublicKey
+ if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ var err error
+ keys, err = d.store.ListPublicKeysByUsername(ctx, tx, username)
+ return err
+ }); err != nil {
+ return nil, db.WrapError(err)
+ }
+
+ return keys, nil
+}
+
+// SetUsername sets the username of a user.
+//
+// It implements backend.Backend.
+func (d *Backend) SetUsername(ctx context.Context, username string, newUsername string) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ return db.WrapError(
+ d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.SetUsernameByUsername(ctx, tx, username, newUsername)
+ }),
+ )
+}
+
+// SetAdmin sets the admin flag of a user.
+//
+// It implements backend.Backend.
+func (d *Backend) SetAdmin(ctx context.Context, username string, admin bool) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ return db.WrapError(
+ d.db.TransactionContext(ctx, func(tx *db.Tx) error {
+ return d.store.SetAdminByUsername(ctx, tx, username, admin)
+ }),
+ )
+}
+
+type user struct {
+ user models.User
+ publicKeys []ssh.PublicKey
+}
+
+var _ proto.User = (*user)(nil)
+
+// IsAdmin implements store.User
+func (u *user) IsAdmin() bool {
+ return u.user.Admin
+}
+
+// PublicKeys implements store.User
+func (u *user) PublicKeys() []ssh.PublicKey {
+ return u.publicKeys
+}
+
+// Username implements store.User
+func (u *user) Username() string {
+ return u.user.Username
}
@@ -2,11 +2,12 @@ package backend
import (
"github.com/charmbracelet/soft-serve/git"
+ "github.com/charmbracelet/soft-serve/server/proto"
)
// LatestFile returns the contents of the latest file at the specified path in
// the repository and its file path.
-func LatestFile(r Repository, pattern string) (string, string, error) {
+func LatestFile(r proto.Repository, pattern string) (string, string, error) {
repo, err := r.Open()
if err != nil {
return "", "", err
@@ -15,7 +16,7 @@ func LatestFile(r Repository, pattern string) (string, string, error) {
}
// Readme returns the repository's README.
-func Readme(r Repository) (readme string, path string, err error) {
+func Readme(r proto.Repository) (readme string, path string, err error) {
pattern := "[rR][eE][aA][dD][mM][eE]*"
readme, path, err = LatestFile(r, pattern)
return
@@ -1,22 +0,0 @@
-package cmd
-
-import "github.com/spf13/cobra"
-
-func setUsernameCommand() *cobra.Command {
- cmd := &cobra.Command{
- Use: "set-username USERNAME",
- Short: "Set your username",
- Args: cobra.ExactArgs(1),
- RunE: func(cmd *cobra.Command, args []string) error {
- cfg, s := fromContext(cmd)
- user, err := cfg.Backend.UserByPublicKey(s.PublicKey())
- if err != nil {
- return err
- }
-
- return cfg.Backend.SetUsername(user.Username(), args[0])
- },
- }
-
- return cmd
-}
@@ -1,17 +1,15 @@
package config
import (
- "context"
- "errors"
"fmt"
"os"
"path/filepath"
+ "strconv"
"strings"
"time"
"github.com/caarlos0/env/v8"
- "github.com/charmbracelet/log"
- "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
"golang.org/x/crypto/ssh"
"gopkg.in/yaml.v3"
)
@@ -82,6 +80,19 @@ type LogConfig struct {
// Time format for the log `ts` field.
// Format must be described in Golang's time format.
TimeFormat string `env:"TIME_FORMAT" yaml:"time_format"`
+
+ // Path to a file to write logs to.
+ // If not set, logs will be written to stderr.
+ Path string `env:"PATH" yaml:"path"`
+}
+
+// DBConfig is the database connection configuration.
+type DBConfig struct {
+ // Driver is the driver for the database.
+ Driver string `env:"DRIVER" yaml:"driver"`
+
+ // DataSource is the database data source name.
+ DataSource string `env:"DATA_SOURCE" yaml:"data_source"`
}
// Config is the configuration for Soft Serve.
@@ -104,14 +115,14 @@ type Config struct {
// Log is the logger configuration.
Log LogConfig `envPrefix:"LOG_" yaml:"log"`
+ // DB is the database configuration.
+ DB DBConfig `envPrefix:"DB_" yaml:"db"`
+
// InitialAdminKeys is a list of public keys that will be added to the list of admins.
InitialAdminKeys []string `env:"INITIAL_ADMIN_KEYS" envSeparator:"\n" yaml:"initial_admin_keys"`
// DataPath is the path to the directory where Soft Serve will store its data.
DataPath string `env:"DATA_PATH" yaml:"-"`
-
- // Backend is the Git backend to use.
- Backend backend.Backend `yaml:"-"`
}
// Environ returns the config as a list of environment variables.
@@ -123,8 +134,8 @@ func (c *Config) Environ() []string {
// TODO: do this dynamically
envs = append(envs, []string{
- fmt.Sprintf("SOFT_SERVE_NAME=%s", c.Name),
fmt.Sprintf("SOFT_SERVE_DATA_PATH=%s", c.DataPath),
+ fmt.Sprintf("SOFT_SERVE_NAME=%s", c.Name),
fmt.Sprintf("SOFT_SERVE_INITIAL_ADMIN_KEYS=%s", strings.Join(c.InitialAdminKeys, "\n")),
fmt.Sprintf("SOFT_SERVE_SSH_LISTEN_ADDR=%s", c.SSH.ListenAddr),
fmt.Sprintf("SOFT_SERVE_SSH_PUBLIC_URL=%s", c.SSH.PublicURL),
@@ -143,51 +154,50 @@ func (c *Config) Environ() []string {
fmt.Sprintf("SOFT_SERVE_STATS_LISTEN_ADDR=%s", c.Stats.ListenAddr),
fmt.Sprintf("SOFT_SERVE_LOG_FORMAT=%s", c.Log.Format),
fmt.Sprintf("SOFT_SERVE_LOG_TIME_FORMAT=%s", c.Log.TimeFormat),
+ fmt.Sprintf("SOFT_SERVE_DB_DRIVER=%s", c.DB.Driver),
+ fmt.Sprintf("SOFT_SERVE_DB_DATA_SOURCE=%s", c.DB.DataSource),
}...)
return envs
}
-func parseConfig(path string) (*Config, error) {
- dataPath := filepath.Dir(path)
- cfg := &Config{
- Name: "Soft Serve",
- DataPath: dataPath,
- SSH: SSHConfig{
- ListenAddr: ":23231",
- PublicURL: "ssh://localhost:23231",
- KeyPath: filepath.Join("ssh", "soft_serve_host_ed25519"),
- ClientKeyPath: filepath.Join("ssh", "soft_serve_client_ed25519"),
- MaxTimeout: 0,
- IdleTimeout: 0,
- },
- Git: GitConfig{
- ListenAddr: ":9418",
- MaxTimeout: 0,
- IdleTimeout: 3,
- MaxConnections: 32,
- },
- HTTP: HTTPConfig{
- ListenAddr: ":23232",
- PublicURL: "http://localhost:23232",
- },
- Stats: StatsConfig{
- ListenAddr: "localhost:23233",
- },
- Log: LogConfig{
- Format: "text",
- TimeFormat: time.DateTime,
- },
- }
+// IsDebug returns true if the server is running in debug mode.
+func IsDebug() bool {
+ debug, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_DEBUG"))
+ return debug
+}
+// IsVerbose returns true if the server is running in verbose mode.
+// Verbose mode is only enabled if debug mode is enabled.
+func IsVerbose() bool {
+ verbose, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_VERBOSE"))
+ return IsDebug() && verbose
+}
+
+// parseFile parses the given file as a configuration file.
+// The file must be in YAML format.
+func parseFile(cfg *Config, path string) error {
f, err := os.Open(path)
- if err == nil {
- defer f.Close() // nolint: errcheck
- if err := yaml.NewDecoder(f).Decode(cfg); err != nil {
- return cfg, fmt.Errorf("decode config: %w", err)
- }
+ if err != nil {
+ return err
+ }
+
+ defer f.Close() // nolint: errcheck
+ if err := yaml.NewDecoder(f).Decode(cfg); err != nil {
+ return fmt.Errorf("decode config: %w", err)
}
+ return cfg.Validate()
+}
+
+// ParseFile parses the config from the default file path.
+// This also calls Validate() on the config.
+func (c *Config) ParseFile() error {
+ return parseFile(c, c.ConfigPath())
+}
+
+// parseEnv parses the environment variables as a configuration file.
+func parseEnv(cfg *Config) error {
// Merge initial admin keys from both config file and environment variables.
initialAdminKeys := append([]string{}, cfg.InitialAdminKeys...)
@@ -195,7 +205,7 @@ func parseConfig(path string) (*Config, error) {
if err := env.ParseWithOptions(cfg, env.Options{
Prefix: "SOFT_SERVE_",
}); err != nil {
- return cfg, fmt.Errorf("parse environment variables: %w", err)
+ return fmt.Errorf("parse environment variables: %w", err)
}
// Merge initial admin keys from environment variables.
@@ -203,80 +213,108 @@ func parseConfig(path string) (*Config, error) {
cfg.InitialAdminKeys = append(cfg.InitialAdminKeys, initialAdminKeys...)
}
- // Validate keys
- pks := make([]string, 0)
- for _, key := range parseAuthKeys(cfg.InitialAdminKeys) {
- ak := backend.MarshalAuthorizedKey(key)
- pks = append(pks, ak)
- }
-
- cfg.InitialAdminKeys = pks
-
- // Reset datapath to config dir.
- // This is necessary because the environment variable may be set to
- // a different directory.
- cfg.DataPath = dataPath
-
- return cfg, nil
+ return cfg.Validate()
}
-// ParseConfig parses the configuration from the given file.
-func ParseConfig(path string) (*Config, error) {
- cfg, err := parseConfig(path)
- if err != nil {
- return cfg, err
- }
+// ParseEnv parses the config from the environment variables.
+// This also calls Validate() on the config.
+func (c *Config) ParseEnv() error {
+ return parseEnv(c)
+}
- if err := cfg.validate(); err != nil {
- return cfg, err
+// Parse parses the config from the default file path and environment variables.
+// This also calls Validate() on the config.
+func (c *Config) Parse() error {
+ if err := c.ParseFile(); err != nil {
+ return err
}
- return cfg, nil
+ return c.ParseEnv()
}
-// WriteConfig writes the configuration to the given file.
-func WriteConfig(path string, cfg *Config) error {
+// writeConfig writes the configuration to the given file.
+func writeConfig(cfg *Config, path string) error {
if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
return err
}
- return os.WriteFile(path, []byte(newConfigFile(cfg)), 0o644) // nolint: errcheck
+ return os.WriteFile(path, []byte(newConfigFile(cfg)), 0o644) // nolint: errcheck, gosec
}
-// DefaultConfig returns a Config with the values populated with the defaults
-// or specified environment variables.
-func DefaultConfig() *Config {
- dataPath := os.Getenv("SOFT_SERVE_DATA_PATH")
- if dataPath == "" {
- dataPath = "data"
- }
+// WriteConfig writes the configuration to the default file.
+func (c *Config) WriteConfig() error {
+ return writeConfig(c, c.ConfigPath())
+}
- cp := filepath.Join(dataPath, "config.yaml")
- cfg, err := parseConfig(cp)
- if err != nil && !errors.Is(err, os.ErrNotExist) {
- log.Errorf("failed to parse config: %v", err)
+// DefaultDataPath returns the path to the data directory.
+// It uses the SOFT_SERVE_DATA_PATH environment variable if set, otherwise it
+// uses "data".
+func DefaultDataPath() string {
+ dp := os.Getenv("SOFT_SERVE_DATA_PATH")
+ if dp == "" {
+ dp = "data"
}
- // Write config if it doesn't exist
- if _, err := os.Stat(cp); os.IsNotExist(err) {
- if err := WriteConfig(cp, cfg); err != nil {
- log.Fatal("failed to write config", "err", err)
- }
- }
+ return dp
+}
- if err := cfg.validate(); err != nil {
- log.Fatal(err)
- }
+// ConfigPath returns the path to the config file.
+func (c *Config) ConfigPath() string { // nolint:revive
+ return filepath.Join(c.DataPath, "config.yaml")
+}
+
+func exist(path string) bool {
+ _, err := os.Stat(path)
+ return err == nil
+}
- return cfg
+// Exist returns true if the config file exists.
+func (c *Config) Exist() bool {
+ return exist(filepath.Join(c.DataPath, "config.yaml"))
}
-// WithBackend sets the backend for the configuration.
-func (c *Config) WithBackend(backend backend.Backend) *Config {
- c.Backend = backend
- return c
+// DefaultConfig returns the default Config. All the path values are relative
+// to the data directory.
+// Use Validate() to validate the config and ensure absolute paths.
+func DefaultConfig() *Config {
+ return &Config{
+ Name: "Soft Serve",
+ DataPath: DefaultDataPath(),
+ SSH: SSHConfig{
+ ListenAddr: ":23231",
+ PublicURL: "ssh://localhost:23231",
+ KeyPath: filepath.Join("ssh", "soft_serve_host_ed25519"),
+ ClientKeyPath: filepath.Join("ssh", "soft_serve_client_ed25519"),
+ MaxTimeout: 0,
+ IdleTimeout: 0,
+ },
+ Git: GitConfig{
+ ListenAddr: ":9418",
+ MaxTimeout: 0,
+ IdleTimeout: 3,
+ MaxConnections: 32,
+ },
+ HTTP: HTTPConfig{
+ ListenAddr: ":23232",
+ PublicURL: "http://localhost:23232",
+ },
+ Stats: StatsConfig{
+ ListenAddr: "localhost:23233",
+ },
+ Log: LogConfig{
+ Format: "text",
+ TimeFormat: time.DateTime,
+ },
+ DB: DBConfig{
+ Driver: "sqlite",
+ DataSource: "soft-serve.db" +
+ "?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)",
+ },
+ }
}
-func (c *Config) validate() error {
+// Validate validates the configuration.
+// It updates the configuration with absolute paths.
+func (c *Config) Validate() error {
// Use absolute paths
if !filepath.IsAbs(c.DataPath) {
dp, err := filepath.Abs(c.DataPath)
@@ -305,19 +343,37 @@ func (c *Config) validate() error {
c.HTTP.TLSCertPath = filepath.Join(c.DataPath, c.HTTP.TLSCertPath)
}
+ if strings.HasPrefix(c.DB.Driver, "sqlite") && !filepath.IsAbs(c.DB.DataSource) {
+ c.DB.DataSource = filepath.Join(c.DataPath, c.DB.DataSource)
+ }
+
+ // Validate keys
+ pks := make([]string, 0)
+ for _, key := range parseAuthKeys(c.InitialAdminKeys) {
+ ak := sshutils.MarshalAuthorizedKey(key)
+ pks = append(pks, ak)
+ }
+
+ c.InitialAdminKeys = pks
+
return nil
}
// parseAuthKeys parses authorized keys from either file paths or string authorized_keys.
func parseAuthKeys(aks []string) []ssh.PublicKey {
+ exist := make(map[string]struct{}, 0)
pks := make([]ssh.PublicKey, 0)
for _, key := range aks {
if bts, err := os.ReadFile(key); err == nil {
// key is a file
key = strings.TrimSpace(string(bts))
}
- if pk, _, err := backend.ParseAuthorizedKey(key); err == nil {
- pks = append(pks, pk)
+
+ if pk, _, err := sshutils.ParseAuthorizedKey(key); err == nil {
+ if _, ok := exist[key]; !ok {
+ pks = append(pks, pk)
+ exist[key] = struct{}{}
+ }
}
}
return pks
@@ -327,19 +383,3 @@ func parseAuthKeys(aks []string) []ssh.PublicKey {
func (c *Config) AdminKeys() []ssh.PublicKey {
return parseAuthKeys(c.InitialAdminKeys)
}
-
-var configCtxKey = struct{ string }{"config"}
-
-// WithContext returns a new context with the configuration attached.
-func WithContext(ctx context.Context, cfg *Config) context.Context {
- return context.WithValue(ctx, configCtxKey, cfg)
-}
-
-// FromContext returns the configuration from the context.
-func FromContext(ctx context.Context) *Config {
- if c, ok := ctx.Value(configCtxKey).(*Config); ok {
- return c
- }
-
- return DefaultConfig()
-}
@@ -2,11 +2,9 @@ package config
import (
"os"
- "path/filepath"
"testing"
"github.com/matryer/is"
- "gopkg.in/yaml.v3"
)
func TestParseMultipleKeys(t *testing.T) {
@@ -19,6 +17,7 @@ func TestParseMultipleKeys(t *testing.T) {
is.NoErr(os.Unsetenv("SOFT_SERVE_DATA_PATH"))
})
cfg := DefaultConfig()
+ is.NoErr(cfg.ParseEnv())
is.Equal(cfg.InitialAdminKeys, []string{
"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINMwLvyV3ouVrTysUYGoJdl5Vgn5BACKov+n9PlzfPwH",
"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFxIobhwtfdwN7m1TFt9wx3PsfvcAkISGPxmbmbauST8",
@@ -29,15 +28,12 @@ func TestMergeInitAdminKeys(t *testing.T) {
is := is.New(t)
is.NoErr(os.Setenv("SOFT_SERVE_INITIAL_ADMIN_KEYS", "testdata/k1.pub"))
t.Cleanup(func() { is.NoErr(os.Unsetenv("SOFT_SERVE_INITIAL_ADMIN_KEYS")) })
- bts, err := yaml.Marshal(&Config{
+ cfg := &Config{
+ DataPath: t.TempDir(),
InitialAdminKeys: []string{"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFxIobhwtfdwN7m1TFt9wx3PsfvcAkISGPxmbmbauST8 a@b"},
- })
- is.NoErr(err)
- fp := filepath.Join(t.TempDir(), "config.yaml")
- err = os.WriteFile(fp, bts, 0o644)
- is.NoErr(err)
- cfg, err := ParseConfig(fp)
- is.NoErr(err)
+ }
+ is.NoErr(cfg.WriteConfig())
+ is.NoErr(cfg.Parse())
is.Equal(cfg.InitialAdminKeys, []string{
"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINMwLvyV3ouVrTysUYGoJdl5Vgn5BACKov+n9PlzfPwH",
"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFxIobhwtfdwN7m1TFt9wx3PsfvcAkISGPxmbmbauST8",
@@ -46,19 +42,16 @@ func TestMergeInitAdminKeys(t *testing.T) {
func TestValidateInitAdminKeys(t *testing.T) {
is := is.New(t)
- bts, err := yaml.Marshal(&Config{
+ cfg := &Config{
+ DataPath: t.TempDir(),
InitialAdminKeys: []string{
"testdata/k1.pub",
"abc",
"",
},
- })
- is.NoErr(err)
- fp := filepath.Join(t.TempDir(), "config.yaml")
- err = os.WriteFile(fp, bts, 0o644)
- is.NoErr(err)
- cfg, err := ParseConfig(fp)
- is.NoErr(err)
+ }
+ is.NoErr(cfg.WriteConfig())
+ is.NoErr(cfg.Parse())
is.Equal(cfg.InitialAdminKeys, []string{
"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINMwLvyV3ouVrTysUYGoJdl5Vgn5BACKov+n9PlzfPwH",
})
@@ -0,0 +1,20 @@
+package config
+
+import "context"
+
+// ContextKey is the context key for the config.
+var ContextKey = struct{ string }{"config"}
+
+// WithContext returns a new context with the configuration attached.
+func WithContext(ctx context.Context, cfg *Config) context.Context {
+ return context.WithValue(ctx, ContextKey, cfg)
+}
+
+// FromContext returns the configuration from the context.
+func FromContext(ctx context.Context) *Config {
+ if c, ok := ctx.Value(ContextKey).(*Config); ok {
+ return c
+ }
+
+ return DefaultConfig()
+}
@@ -18,6 +18,8 @@ log:
# Time format for the log "timestamp" field.
# Should be described in Golang's time format.
time_format: "{{ .Log.TimeFormat }}"
+ # Path to the log file. Leave empty to write to stderr.
+ #path: "{{ .Log.Path }}"
# The SSH server configuration.
ssh:
@@ -79,6 +81,15 @@ stats:
# The address on which the stats server will listen.
listen_addr: "{{ .Stats.ListenAddr }}"
+# The database configuration.
+db:
+ # The database driver to use.
+ # Valid values are "sqlite" and "postgres".
+ driver: "{{ .DB.Driver }}"
+ # The database data source name.
+ # This is driver specific and can be a file path or connection string.
+ data_source: "{{ .DB.DataSource }}"
+
# Additional admin keys.
#initial_admin_keys:
# - "ssh-rsa AAAAB3NzaC1yc2..."
@@ -8,17 +8,9 @@ import (
"github.com/robfig/cron/v3"
)
-// CronScheduler is a cron-like job scheduler.
-type CronScheduler struct {
+// Scheduler is a cron-like job scheduler.
+type Scheduler struct {
*cron.Cron
- logger cron.Logger
-}
-
-// Entry is a cron job.
-type Entry struct {
- ID cron.EntryID
- Desc string
- Spec string
}
// cronLogger is a wrapper around the logger to make it compatible with the
@@ -37,22 +29,22 @@ func (l cronLogger) Error(err error, msg string, keysAndValues ...interface{}) {
l.logger.Error(msg, append(keysAndValues, "err", err)...)
}
-// NewCronScheduler returns a new Cron.
-func NewCronScheduler(ctx context.Context) *CronScheduler {
+// NewScheduler returns a new Cron.
+func NewScheduler(ctx context.Context) *Scheduler {
logger := cronLogger{log.FromContext(ctx).WithPrefix("cron")}
- return &CronScheduler{
+ return &Scheduler{
Cron: cron.New(cron.WithLogger(logger)),
}
}
-// Shutdonw gracefully shuts down the CronServer.
-func (s *CronScheduler) Shutdown() {
+// Shutdonw gracefully shuts down the Scheduler.
+func (s *Scheduler) Shutdown() {
ctx, cancel := context.WithTimeout(s.Cron.Stop(), 30*time.Second)
defer func() { cancel() }()
<-ctx.Done()
}
-// Start starts the CronServer.
-func (s *CronScheduler) Start() {
+// Start starts the Scheduler.
+func (s *Scheduler) Start() {
s.Cron.Start()
}
@@ -91,15 +91,15 @@ func (c *serverConn) updateDeadline() {
initTimeout := time.Now().Add(c.initTimeout)
c.initTimeout = 0
if initTimeout.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
- c.Conn.SetDeadline(initTimeout)
+ c.Conn.SetDeadline(initTimeout) // nolint: errcheck
return
}
case c.idleTimeout > 0:
idleDeadline := time.Now().Add(c.idleTimeout)
if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() {
- c.Conn.SetDeadline(idleDeadline)
+ c.Conn.SetDeadline(idleDeadline) // nolint: errcheck
return
}
}
- c.Conn.SetDeadline(c.maxDeadline)
+ c.Conn.SetDeadline(c.maxDeadline) // nolint: errcheck
}
@@ -11,6 +11,7 @@ import (
"time"
"github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/access"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/git"
@@ -50,7 +51,7 @@ type GitDaemon struct {
finished chan struct{}
conns connections
cfg *config.Config
- be backend.Backend
+ be *backend.Backend
wg sync.WaitGroup
once sync.Once
logger *log.Logger
@@ -94,7 +95,7 @@ func (d *GitDaemon) Start() error {
default:
d.logger.Debugf("git: error accepting connection: %v", err)
}
- if ne, ok := err.(net.Error); ok && ne.Temporary() {
+ if ne, ok := err.(net.Error); ok && ne.Temporary() { // nolint: staticcheck
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else {
@@ -125,7 +126,7 @@ func (d *GitDaemon) Start() error {
}
func (d *GitDaemon) fatal(c net.Conn, err error) {
- git.WritePktline(c, err)
+ git.WritePktlineErr(c, err) // nolint: errcheck
if err := c.Close(); err != nil {
d.logger.Debugf("git: error closing connection: %v", err)
}
@@ -146,7 +147,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
}
d.conns.Add(c)
defer func() {
- d.conns.Close(c)
+ d.conns.Close(c) // nolint: errcheck
}()
readc := make(chan struct{}, 1)
@@ -227,8 +228,8 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
}
}
- be := d.be.WithContext(ctx)
- if !be.AllowKeyless() {
+ be := d.be
+ if !be.AllowKeyless(ctx) {
d.fatal(c, git.ErrNotAuthed)
return
}
@@ -247,13 +248,13 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
return
}
- if _, err := d.be.Repository(repo); err != nil {
+ if _, err := d.be.Repository(ctx, repo); err != nil {
d.fatal(c, git.ErrInvalidRepo)
return
}
- auth := be.AccessLevel(name, "")
- if auth < backend.ReadOnlyAccess {
+ auth := be.AccessLevel(ctx, name, "")
+ if auth < access.ReadOnlyAccess {
d.fatal(c, git.ErrNotAuthed)
return
}
@@ -263,6 +264,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
"SOFT_SERVE_REPO_NAME=" + name,
"SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
"SOFT_SERVE_HOST=" + host,
+ "SOFT_SERVE_LOG_PATH=" + filepath.Join(d.cfg.DataPath, "log", "hooks.log"),
}
// Add git protocol environment variable.
@@ -301,7 +303,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
func (d *GitDaemon) Close() error {
d.once.Do(func() { close(d.finished) })
err := d.listener.Close()
- d.conns.CloseAll()
+ d.conns.CloseAll() // nolint: errcheck
return err
}
@@ -13,11 +13,13 @@ import (
"testing"
"github.com/charmbracelet/soft-serve/server/backend"
- "github.com/charmbracelet/soft-serve/server/backend/sqlite"
"github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/migrate"
"github.com/charmbracelet/soft-serve/server/git"
"github.com/charmbracelet/soft-serve/server/test"
"github.com/go-git/go-git/v5/plumbing/format/pktline"
+ _ "modernc.org/sqlite" // sqlite driver
)
var testDaemon *GitDaemon
@@ -35,13 +37,20 @@ func TestMain(m *testing.M) {
os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
ctx := context.TODO()
cfg := config.DefaultConfig()
+ if err := cfg.Validate(); err != nil {
+ log.Fatal(err)
+ }
ctx = config.WithContext(ctx, cfg)
- fb, err := sqlite.NewSqliteBackend(ctx)
+ db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
if err != nil {
log.Fatal(err)
}
- cfg = cfg.WithBackend(fb)
- ctx = backend.WithContext(ctx, fb)
+ defer db.Close() // nolint: errcheck
+ if err := migrate.Migrate(ctx, db); err != nil {
+ log.Fatal(err)
+ }
+ be := backend.New(ctx, cfg, db)
+ ctx = backend.WithContext(ctx, be)
d, err := NewGitDaemon(ctx)
if err != nil {
log.Fatal(err)
@@ -59,7 +68,7 @@ func TestMain(m *testing.M) {
os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT")
os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR")
_ = d.Close()
- _ = fb.Close()
+ _ = db.Close()
os.Exit(code)
}
@@ -72,7 +81,7 @@ func TestIdleTimeout(t *testing.T) {
if err != nil && !errors.Is(err, io.EOF) {
t.Fatalf("expected nil, got error: %v", err)
}
- if out != git.ErrTimeout.Error() && out != "" {
+ if out != "ERR "+git.ErrTimeout.Error() && out != "" {
t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
}
}
@@ -89,7 +98,7 @@ func TestInvalidRepo(t *testing.T) {
if err != nil {
t.Fatalf("expected nil, got error: %v", err)
}
- if out != git.ErrInvalidRepo.Error() {
+ if out != "ERR "+git.ErrInvalidRepo.Error() {
t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
}
}
@@ -0,0 +1,18 @@
+package db
+
+import "context"
+
+var contextKey = struct{ string }{"db"}
+
+// FromContext returns the database from the context.
+func FromContext(ctx context.Context) *DB {
+ if db, ok := ctx.Value(contextKey).(*DB); ok {
+ return db
+ }
+ return nil
+}
+
+// WithContext returns a new context with the database.
+func WithContext(ctx context.Context, db *DB) context.Context {
+ return context.WithValue(ctx, contextKey, db)
+}
@@ -0,0 +1,88 @@
+package db
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/jmoiron/sqlx"
+ _ "modernc.org/sqlite" // sqlite driver
+)
+
+// DB is the interface for a Soft Serve database.
+type DB struct {
+ *sqlx.DB
+ logger *log.Logger
+}
+
+// Open opens a database connection.
+func Open(ctx context.Context, driverName string, dsn string) (*DB, error) {
+ db, err := sqlx.ConnectContext(ctx, driverName, dsn)
+ if err != nil {
+ return nil, err
+ }
+
+ d := &DB{
+ DB: db,
+ }
+
+ if config.IsVerbose() {
+ logger := log.FromContext(ctx).WithPrefix("db")
+ d.logger = logger
+ }
+
+ return d, nil
+}
+
+// Close implements db.DB.
+func (d *DB) Close() error {
+ return d.DB.Close()
+}
+
+// Tx is a database transaction.
+type Tx struct {
+ *sqlx.Tx
+ logger *log.Logger
+}
+
+// Transaction implements db.DB.
+func (d *DB) Transaction(fn func(tx *Tx) error) error {
+ return d.TransactionContext(context.Background(), fn)
+}
+
+// TransactionContext implements db.DB.
+func (d *DB) TransactionContext(ctx context.Context, fn func(tx *Tx) error) error {
+ txx, err := d.DB.BeginTxx(ctx, nil)
+ if err != nil {
+ return fmt.Errorf("failed to begin transaction: %w", err)
+ }
+
+ tx := &Tx{txx, d.logger}
+ if err := fn(tx); err != nil {
+ return rollback(tx, err)
+ }
+
+ if err := tx.Commit(); err != nil {
+ if errors.Is(err, sql.ErrTxDone) {
+ // this is ok because whoever did finish the tx should have also written the error already.
+ return nil
+ }
+ return fmt.Errorf("failed to commit transaction: %w", err)
+ }
+
+ return nil
+}
+
+func rollback(tx *Tx, err error) error {
+ if rerr := tx.Rollback(); rerr != nil {
+ if errors.Is(rerr, sql.ErrTxDone) {
+ return err
+ }
+ return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr)
+ }
+
+ return err
+}
@@ -0,0 +1,48 @@
+package db
+
+import (
+ "database/sql"
+ "errors"
+
+ "github.com/lib/pq"
+ sqlite "modernc.org/sqlite"
+ sqlite3 "modernc.org/sqlite/lib"
+)
+
+var (
+ // ErrDuplicateKey is a constraint violation error.
+ ErrDuplicateKey = errors.New("duplicate key value violates table constraint")
+
+ // ErrRecordNotFound is returned when a record is not found.
+ ErrRecordNotFound = errors.New("record not found")
+)
+
+// WrapError is a convenient function that unite various database driver
+// errors to consistent errors.
+func WrapError(err error) error {
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return ErrRecordNotFound
+ }
+
+ // Handle sqlite constraint error.
+ if liteErr, ok := err.(*sqlite.Error); ok {
+ code := liteErr.Code()
+ if code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY ||
+ code == sqlite3.SQLITE_CONSTRAINT_FOREIGNKEY ||
+ code == sqlite3.SQLITE_CONSTRAINT_UNIQUE {
+ return ErrDuplicateKey
+ }
+ }
+
+ // Handle postgres constraint error.
+ if pgErr, ok := err.(*pq.Error); ok {
+ if pgErr.Code == "23505" ||
+ pgErr.Code == "23503" ||
+ pgErr.Code == "23514" {
+ return ErrDuplicateKey
+ }
+ }
+ }
+ return err
+}
@@ -0,0 +1,135 @@
+package db
+
+import (
+ "context"
+ "database/sql"
+
+ "github.com/charmbracelet/log"
+ "github.com/jmoiron/sqlx"
+)
+
+func trace(l *log.Logger, query string, args ...interface{}) {
+ if l != nil {
+ l.Debug("trace", "query", query, "args", args)
+ }
+}
+
+// Select is a wrapper around sqlx.Select that logs the query and arguments.
+func (d *DB) Select(dest interface{}, query string, args ...interface{}) error {
+ trace(d.logger, query, args...)
+ return d.DB.Select(dest, query, args...)
+}
+
+// Get is a wrapper around sqlx.Get that logs the query and arguments.
+func (d *DB) Get(dest interface{}, query string, args ...interface{}) error {
+ trace(d.logger, query, args...)
+ return d.DB.Get(dest, query, args...)
+}
+
+// Queryx is a wrapper around sqlx.Queryx that logs the query and arguments.
+func (d *DB) Queryx(query string, args ...interface{}) (*sqlx.Rows, error) {
+ trace(d.logger, query, args...)
+ return d.DB.Queryx(query, args...)
+}
+
+// QueryRowx is a wrapper around sqlx.QueryRowx that logs the query and arguments.
+func (d *DB) QueryRowx(query string, args ...interface{}) *sqlx.Row {
+ trace(d.logger, query, args...)
+ return d.DB.QueryRowx(query, args...)
+}
+
+// Exec is a wrapper around sqlx.Exec that logs the query and arguments.
+func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
+ trace(d.logger, query, args...)
+ return d.DB.Exec(query, args...)
+}
+
+// SelectContext is a wrapper around sqlx.SelectContext that logs the query and arguments.
+func (d *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
+ trace(d.logger, query, args...)
+ return d.DB.SelectContext(ctx, dest, query, args...)
+}
+
+// GetContext is a wrapper around sqlx.GetContext that logs the query and arguments.
+func (d *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
+ trace(d.logger, query, args...)
+ return d.DB.GetContext(ctx, dest, query, args...)
+}
+
+// QueryxContext is a wrapper around sqlx.QueryxContext that logs the query and arguments.
+func (d *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
+ trace(d.logger, query, args...)
+ return d.DB.QueryxContext(ctx, query, args...)
+}
+
+// QueryRowxContext is a wrapper around sqlx.QueryRowxContext that logs the query and arguments.
+func (d *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
+ trace(d.logger, query, args...)
+ return d.DB.QueryRowxContext(ctx, query, args...)
+}
+
+// ExecContext is a wrapper around sqlx.ExecContext that logs the query and arguments.
+func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
+ trace(d.logger, query, args...)
+ return d.DB.ExecContext(ctx, query, args...)
+}
+
+// Select is a wrapper around sqlx.Select that logs the query and arguments.
+func (t *Tx) Select(dest interface{}, query string, args ...interface{}) error {
+ trace(t.logger, query, args...)
+ return t.Tx.Select(dest, query, args...)
+}
+
+// Get is a wrapper around sqlx.Get that logs the query and arguments.
+func (t *Tx) Get(dest interface{}, query string, args ...interface{}) error {
+ trace(t.logger, query, args...)
+ return t.Tx.Get(dest, query, args...)
+}
+
+// Queryx is a wrapper around sqlx.Queryx that logs the query and arguments.
+func (t *Tx) Queryx(query string, args ...interface{}) (*sqlx.Rows, error) {
+ trace(t.logger, query, args...)
+ return t.Tx.Queryx(query, args...)
+}
+
+// QueryRowx is a wrapper around sqlx.QueryRowx that logs the query and arguments.
+func (t *Tx) QueryRowx(query string, args ...interface{}) *sqlx.Row {
+ trace(t.logger, query, args...)
+ return t.Tx.QueryRowx(query, args...)
+}
+
+// Exec is a wrapper around sqlx.Exec that logs the query and arguments.
+func (t *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
+ trace(t.logger, query, args...)
+ return t.Tx.Exec(query, args...)
+}
+
+// SelectContext is a wrapper around sqlx.SelectContext that logs the query and arguments.
+func (t *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
+ trace(t.logger, query, args...)
+ return t.Tx.SelectContext(ctx, dest, query, args...)
+}
+
+// GetContext is a wrapper around sqlx.GetContext that logs the query and arguments.
+func (t *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
+ trace(t.logger, query, args...)
+ return t.Tx.GetContext(ctx, dest, query, args...)
+}
+
+// QueryxContext is a wrapper around sqlx.QueryxContext that logs the query and arguments.
+func (t *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
+ trace(t.logger, query, args...)
+ return t.Tx.QueryxContext(ctx, query, args...)
+}
+
+// QueryRowxContext is a wrapper around sqlx.QueryRowxContext that logs the query and arguments.
+func (t *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
+ trace(t.logger, query, args...)
+ return t.Tx.QueryRowxContext(ctx, query, args...)
+}
+
+// ExecContext is a wrapper around sqlx.ExecContext that logs the query and arguments.
+func (t *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
+ trace(t.logger, query, args...)
+ return t.Tx.ExecContext(ctx, query, args...)
+}
@@ -0,0 +1,134 @@
+package migrate
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/charmbracelet/soft-serve/server/access"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
+)
+
+const (
+ createTablesName = "create tables"
+ createTablesVersion = 1
+)
+
+var createTables = Migration{
+ Version: createTablesVersion,
+ Name: createTablesName,
+ Migrate: func(ctx context.Context, tx *db.Tx) error {
+ cfg := config.FromContext(ctx)
+
+ insert := "INSERT "
+
+ // Alter old tables (if exist)
+ // This is to support prior versions of Soft Serve
+ switch tx.DriverName() {
+ case "sqlite3", "sqlite":
+ insert += "OR IGNORE "
+
+ hasUserTable := hasTable(tx, "user")
+ if hasUserTable {
+ if _, err := tx.ExecContext(ctx, "ALTER TABLE user RENAME TO users"); err != nil {
+ return err
+ }
+ }
+
+ if hasTable(tx, "public_key") {
+ if _, err := tx.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_keys"); err != nil {
+ return err
+ }
+ }
+
+ if hasTable(tx, "collab") {
+ if _, err := tx.ExecContext(ctx, "ALTER TABLE collab RENAME TO collabs"); err != nil {
+ return err
+ }
+ }
+
+ if hasTable(tx, "repo") {
+ if _, err := tx.ExecContext(ctx, "ALTER TABLE repo RENAME TO repos"); err != nil {
+ return err
+ }
+ }
+
+ // Fix username being nullable
+ if hasUserTable {
+ sqlm := `
+ PRAGMA foreign_keys = OFF;
+
+ CREATE TABLE users_new (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ username TEXT NOT NULL UNIQUE,
+ admin BOOLEAN NOT NULL,
+ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at DATETIME NOT NULL
+ );
+
+ INSERT INTO users_new (username, admin, updated_at)
+ SELECT username, admin, updated_at FROM users;
+
+ DROP TABLE users;
+ ALTER TABLE users_new RENAME TO users;
+
+ PRAGMA foreign_keys = ON;
+ `
+ if _, err := tx.ExecContext(ctx, sqlm); err != nil {
+ return err
+ }
+ }
+ }
+
+ if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil {
+ return err
+ }
+
+ // Insert default user
+ insertUser := tx.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)")
+ if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil {
+ return err
+ }
+
+ for _, k := range cfg.AdminKeys() {
+ query := insert + "INTO public_keys (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
+ if tx.DriverName() == "postgres" {
+ query += " ON CONFLICT DO NOTHING"
+ }
+
+ query = tx.Rebind(query)
+ ak := sshutils.MarshalAuthorizedKey(k)
+ if _, err := tx.ExecContext(ctx, query, 1, ak); err != nil {
+ if errors.Is(db.WrapError(err), db.ErrDuplicateKey) {
+ continue
+ }
+ return err
+ }
+ }
+
+ // Insert default settings
+ insertSettings := insert + "INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
+ insertSettings = tx.Rebind(insertSettings)
+ settings := []struct {
+ Key string
+ Value string
+ }{
+ {"allow_keyless", "true"},
+ {"anon_access", access.ReadOnlyAccess.String()},
+ {"init", "true"},
+ }
+
+ for _, s := range settings {
+ if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil {
+ return fmt.Errorf("inserting default settings %q: %w", s.Key, err)
+ }
+ }
+
+ return nil
+ },
+ Rollback: func(ctx context.Context, tx *db.Tx) error {
+ return migrateDown(ctx, tx, createTablesVersion, createTablesName)
+ },
+}
@@ -0,0 +1,5 @@
+DROP TABLE IF EXISTS collabs;
+DROP TABLE IF EXISTS repos;
+DROP TABLE IF EXISTS public_keys;
+DROP TABLE IF EXISTS users;
+DROP TABLE IF EXISTS settings;
@@ -0,0 +1,59 @@
+CREATE TABLE IF NOT EXISTS settings (
+ id SERIAL PRIMARY KEY,
+ key TEXT NOT NULL UNIQUE,
+ value TEXT NOT NULL,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS users (
+ id SERIAL PRIMARY KEY,
+ username TEXT NOT NULL UNIQUE,
+ admin BOOLEAN NOT NULL,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS public_keys (
+ id SERIAL PRIMARY KEY,
+ user_id INTEGER NOT NULL,
+ public_key TEXT NOT NULL UNIQUE,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP NOT NULL,
+ UNIQUE (user_id, public_key),
+ CONSTRAINT user_id_fk
+ FOREIGN KEY(user_id) REFERENCES users(id)
+ ON DELETE CASCADE
+ ON UPDATE CASCADE
+);
+
+CREATE TABLE IF NOT EXISTS repos (
+ id SERIAL PRIMARY KEY,
+ name TEXT NOT NULL UNIQUE,
+ project_name TEXT NOT NULL,
+ description TEXT NOT NULL,
+ private BOOLEAN NOT NULL,
+ mirror BOOLEAN NOT NULL,
+ hidden BOOLEAN NOT NULL,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS collabs (
+ id SERIAL PRIMARY KEY,
+ user_id INTEGER NOT NULL,
+ repo_id INTEGER NOT NULL,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at TIMESTAMP NOT NULL,
+ UNIQUE (user_id, repo_id),
+ CONSTRAINT user_id_fk
+ FOREIGN KEY(user_id) REFERENCES users(id)
+ ON DELETE CASCADE
+ ON UPDATE CASCADE,
+ CONSTRAINT repo_id_fk
+ FOREIGN KEY(repo_id) REFERENCES repos(id)
+ ON DELETE CASCADE
+ ON UPDATE CASCADE
+);
+
+
@@ -0,0 +1,5 @@
+DROP TABLE IF EXISTS collabs;
+DROP TABLE IF EXISTS repos;
+DROP TABLE IF EXISTS public_keys;
+DROP TABLE IF EXISTS users;
+DROP TABLE IF EXISTS settings;
@@ -0,0 +1,58 @@
+CREATE TABLE IF NOT EXISTS settings (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ key TEXT NOT NULL UNIQUE,
+ value TEXT NOT NULL,
+ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at DATETIME NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS users (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ username TEXT NOT NULL UNIQUE,
+ admin BOOLEAN NOT NULL,
+ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at DATETIME NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS public_keys (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ public_key TEXT NOT NULL UNIQUE,
+ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at DATETIME NOT NULL,
+ UNIQUE (user_id, public_key),
+ CONSTRAINT user_id_fk
+ FOREIGN KEY(user_id) REFERENCES users(id)
+ ON DELETE CASCADE
+ ON UPDATE CASCADE
+);
+
+CREATE TABLE IF NOT EXISTS repos (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL UNIQUE,
+ project_name TEXT NOT NULL,
+ description TEXT NOT NULL,
+ private BOOLEAN NOT NULL,
+ mirror BOOLEAN NOT NULL,
+ hidden BOOLEAN NOT NULL,
+ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at DATETIME NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS collabs (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ repo_id INTEGER NOT NULL,
+ created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ updated_at DATETIME NOT NULL,
+ UNIQUE (user_id, repo_id),
+ CONSTRAINT user_id_fk
+ FOREIGN KEY(user_id) REFERENCES users(id)
+ ON DELETE CASCADE
+ ON UPDATE CASCADE,
+ CONSTRAINT repo_id_fk
+ FOREIGN KEY(repo_id) REFERENCES repos(id)
+ ON DELETE CASCADE
+ ON UPDATE CASCADE
+);
+
@@ -0,0 +1,142 @@
+package migrate
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/db"
+)
+
+// MigrateFunc is a function that executes a migration.
+type MigrateFunc func(ctx context.Context, tx *db.Tx) error // nolint:revive
+
+// Migration is a struct that contains the name of the migration and the
+// function to execute it.
+type Migration struct {
+ Version int64
+ Name string
+ Migrate MigrateFunc
+ Rollback MigrateFunc
+}
+
+// Migrations is a database model to store migrations.
+type Migrations struct {
+ ID int64 `db:"id"`
+ Name string `db:"name"`
+ Version int64 `db:"version"`
+}
+
+func (Migrations) schema(driverName string) string {
+ switch driverName {
+ case "sqlite3", "sqlite":
+ return `CREATE TABLE IF NOT EXISTS migrations (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT NOT NULL,
+ version INTEGER NOT NULL UNIQUE
+ );
+ `
+ case "postgres":
+ return `CREATE TABLE IF NOT EXISTS migrations (
+ id SERIAL PRIMARY KEY,
+ name TEXT NOT NULL,
+ version INTEGER NOT NULL UNIQUE
+ );
+ `
+ case "mysql":
+ return `CREATE TABLE IF NOT EXISTS migrations (
+ id INT NOT NULL AUTO_INCREMENT,
+ name TEXT NOT NULL,
+ version INT NOT NULL,
+ UNIQUE (version),
+ PRIMARY KEY (id)
+ );
+ `
+ default:
+ panic("unknown driver")
+ }
+}
+
+// Migrate runs the migrations.
+func Migrate(ctx context.Context, dbx *db.DB) error {
+ logger := log.FromContext(ctx).WithPrefix("migrate")
+ return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
+ if !hasTable(tx, "migrations") {
+ if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil {
+ return err
+ }
+ }
+
+ var migrs Migrations
+ if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
+ if !errors.Is(err, sql.ErrNoRows) {
+ return err
+ }
+ }
+
+ for _, m := range migrations {
+ if m.Version <= migrs.Version {
+ continue
+ }
+
+ logger.Infof("running migration %d. %s", m.Version, m.Name)
+ if err := m.Migrate(ctx, tx); err != nil {
+ return err
+ }
+
+ if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+}
+
+// Rollback rolls back a migration.
+func Rollback(ctx context.Context, dbx *db.DB) error {
+ logger := log.FromContext(ctx).WithPrefix("migrate")
+ return dbx.TransactionContext(ctx, func(tx *db.Tx) error {
+ var migrs Migrations
+ if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil {
+ if !errors.Is(err, sql.ErrNoRows) {
+ return fmt.Errorf("there are no migrations to rollback: %w", err)
+ }
+ }
+
+ if len(migrations) < int(migrs.Version) {
+ return fmt.Errorf("there are no migrations to rollback")
+ }
+
+ m := migrations[migrs.Version-1]
+ logger.Infof("rolling back migration %d. %s", m.Version, m.Name)
+ if err := m.Rollback(ctx, tx); err != nil {
+ return err
+ }
+
+ if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE version = ?"), migrs.Version); err != nil {
+ return err
+ }
+
+ return nil
+ })
+}
+
+func hasTable(tx *db.Tx, tableName string) bool {
+ var query string
+ switch tx.DriverName() {
+ case "sqlite3", "sqlite":
+ query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
+ case "postgres":
+ fallthrough
+ case "mysql":
+ query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?"
+ }
+
+ query = tx.Rebind(query)
+ var name string
+ err := tx.Get(&name, query, tableName)
+ return err == nil
+}
@@ -0,0 +1,62 @@
+package migrate
+
+import (
+ "context"
+ "embed"
+ "fmt"
+ "regexp"
+ "strings"
+
+ "github.com/charmbracelet/soft-serve/server/db"
+)
+
+//go:embed *.sql
+var sqls embed.FS
+
+// Keep this in order of execution, oldest to newest.
+var migrations = []Migration{
+ createTables,
+}
+
+func execMigration(ctx context.Context, tx *db.Tx, version int, name string, down bool) error {
+ direction := "up"
+ if down {
+ direction = "down"
+ }
+
+ driverName := tx.DriverName()
+ if driverName == "sqlite3" {
+ driverName = "sqlite"
+ }
+
+ fn := fmt.Sprintf("%04d_%s_%s.%s.sql", version, toSnakeCase(name), driverName, direction)
+ sqlstr, err := sqls.ReadFile(fn)
+ if err != nil {
+ return err
+ }
+
+ if _, err := tx.ExecContext(ctx, string(sqlstr)); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func migrateUp(ctx context.Context, tx *db.Tx, version int, name string) error {
+ return execMigration(ctx, tx, version, name, false)
+}
+
+func migrateDown(ctx context.Context, tx *db.Tx, version int, name string) error {
+ return execMigration(ctx, tx, version, name, true)
+}
+
+var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
+var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
+
+func toSnakeCase(str string) string {
+ str = strings.ReplaceAll(str, "-", "_")
+ str = strings.ReplaceAll(str, " ", "_")
+ snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}")
+ snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}")
+ return strings.ToLower(snake)
+}
@@ -0,0 +1,12 @@
+package models
+
+import "time"
+
+// Collab represents a repository collaborator.
+type Collab struct {
+ ID int64 `db:"id"`
+ RepoID int64 `db:"repo_id"`
+ UserID int64 `db:"user_id"`
+ CreatedAt time.Time `db:"created_at"`
+ UpdatedAt time.Time `db:"updated_at"`
+}
@@ -0,0 +1,10 @@
+package models
+
+// PublicKey represents a public key.
+type PublicKey struct {
+ ID int64 `db:"id"`
+ UserID int64 `db:"user_id"`
+ PublicKey string `db:"public_key"`
+ CreatedAt string `db:"created_at"`
+ UpdatedAt string `db:"updated_at"`
+}
@@ -0,0 +1,16 @@
+package models
+
+import "time"
+
+// Repo is a database model for a repository.
+type Repo struct {
+ ID int64 `db:"id"`
+ Name string `db:"name"`
+ ProjectName string `db:"project_name"`
+ Description string `db:"description"`
+ Private bool `db:"private"`
+ Mirror bool `db:"mirror"`
+ Hidden bool `db:"hidden"`
+ CreatedAt time.Time `db:"created_at"`
+ UpdatedAt time.Time `db:"updated_at"`
+}
@@ -0,0 +1,10 @@
+package models
+
+// Settings represents a settings record.
+type Settings struct {
+ ID int64 `db:"id"`
+ Key string `db:"key"`
+ Value string `db:"value"`
+ CreatedAt string `db:"created_at"`
+ UpdatedAt string `db:"updated_at"`
+}
@@ -0,0 +1,12 @@
+package models
+
+import "time"
+
+// User represents a user.
+type User struct {
+ ID int64 `db:"id"`
+ Username string `db:"username"`
+ Admin bool `db:"admin"`
+ CreatedAt time.Time `db:"created_at"`
+ UpdatedAt time.Time `db:"updated_at"`
+}
@@ -1,12 +0,0 @@
-package errors
-
-import "fmt"
-
-var (
- // ErrUnauthorized is returned when the user is not authorized to perform action.
- ErrUnauthorized = fmt.Errorf("Unauthorized")
- // ErrRepoNotFound is returned when the repo is not found.
- ErrRepoNotFound = fmt.Errorf("Repository not found")
- // ErrFileNotFound is returned when the file is not found.
- ErrFileNotFound = fmt.Errorf("File not found")
-)
@@ -35,15 +35,22 @@ var (
)
// WritePktline encodes and writes a pktline to the given writer.
-func WritePktline(w io.Writer, v ...interface{}) {
+func WritePktline(w io.Writer, v ...interface{}) error {
msg := fmt.Sprintln(v...)
pkt := pktline.NewEncoder(w)
if err := pkt.EncodeString(msg); err != nil {
- log.Debugf("git: error writing pkt-line message: %s", err)
+ return fmt.Errorf("git: error writing pkt-line message: %w", err)
}
if err := pkt.Flush(); err != nil {
- log.Debugf("git: error flushing pkt-line message: %s", err)
+ return fmt.Errorf("git: error flushing pkt-line message: %w", err)
}
+
+ return nil
+}
+
+// WritePktlineErr writes an error pktline to the given writer.
+func WritePktlineErr(w io.Writer, err error) error {
+ return WritePktline(w, "ERR", err.Error())
}
// EnsureWithin ensures the given repo is within the repos directory.
@@ -97,45 +97,47 @@ func gitServiceHandler(ctx context.Context, svc Service, scmd ServiceCommand) er
log.Debugf("git service command in %q: %s", cmd.Dir, cmd.String())
if err := cmd.Start(); err != nil {
+ if errors.Is(err, os.ErrNotExist) {
+ return ErrInvalidRepo
+ }
return err
}
- errg, ctx := errgroup.WithContext(ctx)
+ errg, _ := errgroup.WithContext(ctx)
// stdin
if scmd.Stdin != nil {
errg.Go(func() error {
- if scmd.StdinHandler != nil {
- return scmd.StdinHandler(scmd.Stdin, stdin)
- } else {
- return defaultStdinHandler(scmd.Stdin, stdin)
- }
+ defer stdin.Close() // nolint: errcheck
+ _, err := io.Copy(stdin, scmd.Stdin)
+ return err
})
}
// stdout
if scmd.Stdout != nil {
errg.Go(func() error {
- if scmd.StdoutHandler != nil {
- return scmd.StdoutHandler(scmd.Stdout, stdout)
- } else {
- return defaultStdoutHandler(scmd.Stdout, stdout)
- }
+ _, err := io.Copy(scmd.Stdout, stdout)
+ return err
})
}
// stderr
if scmd.Stderr != nil {
errg.Go(func() error {
- if scmd.StderrHandler != nil {
- return scmd.StderrHandler(scmd.Stderr, stderr)
- } else {
- return defaultStderrHandler(scmd.Stderr, stderr)
- }
+ _, erro := io.Copy(scmd.Stderr, stderr)
+ return erro
})
}
- return errors.Join(errg.Wait(), cmd.Wait())
+ err = errors.Join(errg.Wait(), cmd.Wait())
+ if err != nil && errors.Is(err, os.ErrNotExist) {
+ return ErrInvalidRepo
+ } else if err != nil {
+ return err
+ }
+
+ return nil
}
// ServiceCommand is used to run a git service command.
@@ -148,26 +150,7 @@ type ServiceCommand struct {
Args []string
// Modifier functions
- CmdFunc func(*exec.Cmd)
- StdinHandler func(io.Reader, io.WriteCloser) error
- StdoutHandler func(io.Writer, io.ReadCloser) error
- StderrHandler func(io.Writer, io.ReadCloser) error
-}
-
-func defaultStdinHandler(in io.Reader, stdin io.WriteCloser) error {
- defer stdin.Close() // nolint: errcheck
- _, err := io.Copy(stdin, in)
- return err
-}
-
-func defaultStdoutHandler(out io.Writer, stdout io.ReadCloser) error {
- _, err := io.Copy(out, stdout)
- return err
-}
-
-func defaultStderrHandler(err io.Writer, stderr io.ReadCloser) error {
- _, erro := io.Copy(err, stderr)
- return erro
+ CmdFunc func(*exec.Cmd)
}
// UploadPack runs the git upload-pack protocol against the provided repo.
@@ -0,0 +1,140 @@
+package hooks
+
+import (
+ "bytes"
+ "context"
+ "flag"
+ "os"
+ "path/filepath"
+ "text/template"
+
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/utils"
+)
+
+// The names of git server-side hooks.
+const (
+ PreReceiveHook = "pre-receive"
+ UpdateHook = "update"
+ PostReceiveHook = "post-receive"
+ PostUpdateHook = "post-update"
+)
+
+// GenerateHooks generates git server-side hooks for a repository. Currently, it supports the following hooks:
+// - pre-receive
+// - update
+// - post-receive
+// - post-update
+//
+// This function should be called by the backend when a repository is created.
+// TODO: support context.
+func GenerateHooks(_ context.Context, cfg *config.Config, repo string) error {
+ // TODO: support git hook tests.
+ if flag.Lookup("test.v") != nil {
+ log.WithPrefix("backend.hooks").Warn("refusing to set up hooks when in test")
+ return nil
+ }
+ repo = utils.SanitizeRepo(repo) + ".git"
+ hooksPath := filepath.Join(cfg.DataPath, "repos", repo, "hooks")
+ if err := os.MkdirAll(hooksPath, os.ModePerm); err != nil {
+ return err
+ }
+
+ ex, err := os.Executable()
+ if err != nil {
+ return err
+ }
+
+ for _, hook := range []string{
+ PreReceiveHook,
+ UpdateHook,
+ PostReceiveHook,
+ PostUpdateHook,
+ } {
+ var data bytes.Buffer
+ var args string
+
+ // Hooks script/directory path
+ hp := filepath.Join(hooksPath, hook)
+
+ // Write the hooks primary script
+ if err := os.WriteFile(hp, []byte(hookTemplate), os.ModePerm); err != nil {
+ return err
+ }
+
+ // Create ${hook}.d directory.
+ hp += ".d"
+ if err := os.MkdirAll(hp, os.ModePerm); err != nil {
+ return err
+ }
+
+ switch hook {
+ case UpdateHook:
+ args = "$1 $2 $3"
+ case PostUpdateHook:
+ args = "$@"
+ }
+
+ if err := hooksTmpl.Execute(&data, struct {
+ Executable string
+ Hook string
+ Args string
+ }{
+ Executable: ex,
+ Hook: hook,
+ Args: args,
+ }); err != nil {
+ log.WithPrefix("hooks").Error("failed to execute hook template", "err", err)
+ continue
+ }
+
+ // Write the soft-serve hook inside ${hook}.d directory.
+ hp = filepath.Join(hp, "soft-serve")
+ err = os.WriteFile(hp, data.Bytes(), os.ModePerm) //nolint:gosec
+ if err != nil {
+ log.WithPrefix("hooks").Error("failed to write hook", "err", err)
+ continue
+ }
+ }
+
+ return nil
+}
+
+const (
+ // hookTemplate allows us to run multiple hooks from a directory. It should
+ // support every type of git hook, as it proxies both stdin and arguments.
+ hookTemplate = `#!/usr/bin/env bash
+# AUTO GENERATED BY SOFT SERVE, DO NOT MODIFY
+data=$(cat)
+exitcodes=""
+hookname=$(basename $0)
+GIT_DIR=${GIT_DIR:-$(dirname $0)/..}
+for hook in ${GIT_DIR}/hooks/${hookname}.d/*; do
+ # Avoid running non-executable hooks
+ test -x "${hook}" && test -f "${hook}" || continue
+
+ # Run the actual hook
+ echo "${data}" | "${hook}" "$@"
+
+ # Store the exit code for later use
+ exitcodes="${exitcodes} $?"
+done
+
+# Exit on the first non-zero exit code.
+for i in ${exitcodes}; do
+ [ ${i} -eq 0 ] || exit ${i}
+done
+`
+)
+
+// hooksTmpl is the soft-serve hook that will be run by the git hooks
+// inside the hooks directory.
+var hooksTmpl = template.Must(template.New("hooks").Parse(`#!/usr/bin/env bash
+# AUTO GENERATED BY SOFT SERVE, DO NOT MODIFY
+if [ -z "$SOFT_SERVE_REPO_NAME" ]; then
+ echo "Warning: SOFT_SERVE_REPO_NAME not defined. Skipping hooks."
+ exit 0
+fi
+{{ .Executable }} hook {{ .Hook }} {{ .Args }}
+`))
@@ -1,156 +1,21 @@
package hooks
import (
- "bytes"
"context"
- "flag"
- "fmt"
- "os"
- "path/filepath"
- "text/template"
-
- "github.com/charmbracelet/log"
- "github.com/charmbracelet/soft-serve/server/config"
- "github.com/charmbracelet/soft-serve/server/utils"
-)
-
-// The names of git server-side hooks.
-const (
- PreReceiveHook = "pre-receive"
- UpdateHook = "update"
- PostReceiveHook = "post-receive"
- PostUpdateHook = "post-update"
+ "io"
)
-// GenerateHooks generates git server-side hooks for a repository. Currently, it supports the following hooks:
-// - pre-receive
-// - update
-// - post-receive
-// - post-update
-//
-// This function should be called by the backend when a repository is created.
-// TODO: support context.
-func GenerateHooks(_ context.Context, cfg *config.Config, repo string) error {
- // TODO: support git hook tests.
- if flag.Lookup("test.v") != nil {
- log.WithPrefix("backend.hooks").Warn("refusing to set up hooks when in test")
- return nil
- }
- repo = utils.SanitizeRepo(repo) + ".git"
- hooksPath := filepath.Join(cfg.DataPath, "repos", repo, "hooks")
- if err := os.MkdirAll(hooksPath, os.ModePerm); err != nil {
- return err
- }
-
- ex, err := os.Executable()
- if err != nil {
- return err
- }
-
- dp, err := filepath.Abs(cfg.DataPath)
- if err != nil {
- return fmt.Errorf("failed to get absolute path for data path: %w", err)
- }
-
- cp := filepath.Join(dp, "config.yaml")
- // Add extra environment variables to the hooks here.
- envs := []string{}
-
- for _, hook := range []string{
- PreReceiveHook,
- UpdateHook,
- PostReceiveHook,
- PostUpdateHook,
- } {
- var data bytes.Buffer
- var args string
-
- // Hooks script/directory path
- hp := filepath.Join(hooksPath, hook)
-
- // Write the hooks primary script
- if err := os.WriteFile(hp, []byte(hookTemplate), os.ModePerm); err != nil {
- return err
- }
-
- // Create ${hook}.d directory.
- hp += ".d"
- if err := os.MkdirAll(hp, os.ModePerm); err != nil {
- return err
- }
-
- switch hook {
- case UpdateHook:
- args = "$1 $2 $3"
- case PostUpdateHook:
- args = "$@"
- }
-
- if err := hooksTmpl.Execute(&data, struct {
- Executable string
- Config string
- Envs []string
- Hook string
- Args string
- }{
- Executable: ex,
- Config: cp,
- Envs: envs,
- Hook: hook,
- Args: args,
- }); err != nil {
- log.WithPrefix("backend.hooks").Error("failed to execute hook template", "err", err)
- continue
- }
-
- // Write the soft-serve hook inside ${hook}.d directory.
- hp = filepath.Join(hp, "soft-serve")
- err = os.WriteFile(hp, data.Bytes(), os.ModePerm) //nolint:gosec
- if err != nil {
- log.WithPrefix("backend.hooks").Error("failed to write hook", "err", err)
- continue
- }
- }
-
- return nil
+// HookArg is an argument to a git hook.
+type HookArg struct {
+ OldSha string
+ NewSha string
+ RefName string
}
-const (
- // hookTemplate allows us to run multiple hooks from a directory. It should
- // support every type of git hook, as it proxies both stdin and arguments.
- hookTemplate = `#!/usr/bin/env bash
-# AUTO GENERATED BY SOFT SERVE, DO NOT MODIFY
-data=$(cat)
-exitcodes=""
-hookname=$(basename $0)
-GIT_DIR=${GIT_DIR:-$(dirname $0)/..}
-for hook in ${GIT_DIR}/hooks/${hookname}.d/*; do
- # Avoid running non-executable hooks
- test -x "${hook}" && test -f "${hook}" || continue
-
- # Run the actual hook
- echo "${data}" | "${hook}" "$@"
-
- # Store the exit code for later use
- exitcodes="${exitcodes} $?"
-done
-
-# Exit on the first non-zero exit code.
-for i in ${exitcodes}; do
- [ ${i} -eq 0 ] || exit ${i}
-done
-`
-)
-
-// hooksTmpl is the soft-serve hook that will be run by the git hooks
-// inside the hooks directory.
-var hooksTmpl = template.Must(template.New("hooks").Parse(`#!/usr/bin/env bash
-# AUTO GENERATED BY SOFT SERVE, DO NOT MODIFY
-if [ -z "$SOFT_SERVE_REPO_NAME" ]; then
- echo "Warning: SOFT_SERVE_REPO_NAME not defined. Skipping hooks."
- exit 0
-fi
-{{ range $_, $env := .Envs }}
-{{ $env }} \{{ end }}
-{{ .Executable }} hook --config "{{ .Config }}" {{ .Hook }} {{ .Args }}
-`))
+// Hooks provides an interface for git server-side hooks.
+type Hooks interface {
+ PreReceive(ctx context.Context, stdout io.Writer, stderr io.Writer, repo string, args []HookArg)
+ Update(ctx context.Context, stdout io.Writer, stderr io.Writer, repo string, arg HookArg)
+ PostReceive(ctx context.Context, stdout io.Writer, stderr io.Writer, repo string, args []HookArg)
+ PostUpdate(ctx context.Context, stdout io.Writer, stderr io.Writer, repo string, args ...string)
+}
@@ -6,7 +6,8 @@ import (
"runtime"
"github.com/charmbracelet/soft-serve/git"
- "github.com/charmbracelet/soft-serve/internal/sync"
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/sync"
)
var jobSpecs = map[string]string{
@@ -14,12 +15,11 @@ var jobSpecs = map[string]string{
}
// mirrorJob runs the (pull) mirror job task.
-func (s *Server) mirrorJob() func() {
+func (s *Server) mirrorJob(b *backend.Backend) func() {
cfg := s.Config
- b := cfg.Backend
logger := s.logger
return func() {
- repos, err := b.Repositories()
+ repos, err := b.Repositories(s.ctx)
if err != nil {
logger.Error("error getting repositories", "err", err)
return
@@ -48,6 +48,7 @@ func (s *Server) mirrorJob() func() {
cfg.SSH.ClientKeyPath,
),
)
+
if _, err := cmd.RunInDir(r.Path); err != nil {
logger.Error("error running git remote update", "repo", name, "err", err)
}
@@ -0,0 +1,16 @@
+package proto
+
+import (
+ "errors"
+)
+
+var (
+ // ErrUnauthorized is returned when the user is not authorized to perform action.
+ ErrUnauthorized = errors.New("Unauthorized")
+ // ErrFileNotFound is returned when the file is not found.
+ ErrFileNotFound = errors.New("File not found")
+ // ErrRepoNotExist is returned when a repository does not exist.
+ ErrRepoNotExist = errors.New("repository does not exist")
+ // ErrRepoExist is returned when a repository already exists.
+ ErrRepoExist = errors.New("repository already exists")
+)
@@ -0,0 +1,37 @@
+package proto
+
+import (
+ "time"
+
+ "github.com/charmbracelet/soft-serve/git"
+)
+
+// Repository is a Git repository interface.
+type Repository interface {
+ // Name returns the repository's name.
+ Name() string
+ // ProjectName returns the repository's project name.
+ ProjectName() string
+ // Description returns the repository's description.
+ Description() string
+ // IsPrivate returns whether the repository is private.
+ IsPrivate() bool
+ // IsMirror returns whether the repository is a mirror.
+ IsMirror() bool
+ // IsHidden returns whether the repository is hidden.
+ IsHidden() bool
+ // UpdatedAt returns the time the repository was last updated.
+ // If the repository has never been updated, it returns the time it was created.
+ UpdatedAt() time.Time
+ // Open returns the underlying git.Repository.
+ Open() (*git.Repository, error)
+}
+
+// RepositoryOptions are options for creating a new repository.
+type RepositoryOptions struct {
+ Private bool
+ Description string
+ ProjectName string
+ Mirror bool
+ Hidden bool
+}
@@ -0,0 +1,21 @@
+package proto
+
+import "golang.org/x/crypto/ssh"
+
+// User is an interface representing a user.
+type User interface {
+ // Username returns the user's username.
+ Username() string
+ // IsAdmin returns whether the user is an admin.
+ IsAdmin() bool
+ // PublicKeys returns the user's public keys.
+ PublicKeys() []ssh.PublicKey
+}
+
+// UserOptions are options for creating a user.
+type UserOptions struct {
+ // Admin is whether the user is an admin.
+ Admin bool
+ // PublicKeys are the user's public keys.
+ PublicKeys []ssh.PublicKey
+}
@@ -4,16 +4,15 @@ import (
"context"
"errors"
"fmt"
- "io"
"net/http"
"github.com/charmbracelet/log"
"github.com/charmbracelet/soft-serve/server/backend"
- "github.com/charmbracelet/soft-serve/server/backend/sqlite"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/cron"
"github.com/charmbracelet/soft-serve/server/daemon"
+ "github.com/charmbracelet/soft-serve/server/db"
sshsrv "github.com/charmbracelet/soft-serve/server/ssh"
"github.com/charmbracelet/soft-serve/server/stats"
"github.com/charmbracelet/soft-serve/server/web"
@@ -27,43 +26,35 @@ type Server struct {
GitDaemon *daemon.GitDaemon
HTTPServer *web.HTTPServer
StatsServer *stats.StatsServer
- Cron *cron.CronScheduler
+ Cron *cron.Scheduler
Config *config.Config
- Backend backend.Backend
+ Backend *backend.Backend
+ DB *db.DB
logger *log.Logger
ctx context.Context
}
-// NewServer returns a new *ssh.Server configured to serve Soft Serve. The SSH
-// server key-pair will be created if none exists. An initial admin SSH public
-// key can be provided with authKey. If authKey is provided, access will be
-// restricted to that key. If authKey is not provided, the server will be
-// publicly writable until configured otherwise by cloning the `config` repo.
+// NewServer returns a new *Server configured to serve Soft Serve. The SSH
+// server key-pair will be created if none exists.
+// It expects a context with *backend.Backend, *db.DB, *log.Logger, and
+// *config.Config attached.
func NewServer(ctx context.Context) (*Server, error) {
- cfg := config.FromContext(ctx)
-
var err error
- if cfg.Backend == nil {
- sb, err := sqlite.NewSqliteBackend(ctx)
- if err != nil {
- return nil, fmt.Errorf("create backend: %w", err)
- }
-
- cfg = cfg.WithBackend(sb)
- ctx = backend.WithContext(ctx, sb)
- }
-
+ cfg := config.FromContext(ctx)
+ be := backend.FromContext(ctx)
+ db := db.FromContext(ctx)
srv := &Server{
- Cron: cron.NewCronScheduler(ctx),
+ Cron: cron.NewScheduler(ctx),
Config: cfg,
- Backend: cfg.Backend,
+ Backend: be,
+ DB: db,
logger: log.FromContext(ctx).WithPrefix("server"),
ctx: ctx,
}
// Add cron jobs.
- _, _ = srv.Cron.AddFunc(jobSpecs["mirror"], srv.mirrorJob())
+ _, _ = srv.Cron.AddFunc(jobSpecs["mirror"], srv.mirrorJob(be))
srv.SSHServer, err = sshsrv.NewSSHServer(ctx)
if err != nil {
@@ -88,47 +79,33 @@ func NewServer(ctx context.Context) (*Server, error) {
return srv, nil
}
-func start(ctx context.Context, fn func() error) error {
- errc := make(chan error, 1)
- go func() {
- errc <- fn()
- }()
-
- select {
- case err := <-errc:
- return err
- case <-ctx.Done():
- return ctx.Err()
- }
-}
-
// Start starts the SSH server.
func (s *Server) Start() error {
- errg, ctx := errgroup.WithContext(s.ctx)
+ errg, _ := errgroup.WithContext(s.ctx)
errg.Go(func() error {
s.logger.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr)
- if err := start(ctx, s.GitDaemon.Start); !errors.Is(err, daemon.ErrServerClosed) {
+ if err := s.GitDaemon.Start(); !errors.Is(err, daemon.ErrServerClosed) {
return err
}
return nil
})
errg.Go(func() error {
s.logger.Print("Starting HTTP server", "addr", s.Config.HTTP.ListenAddr)
- if err := start(ctx, s.HTTPServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) {
+ if err := s.HTTPServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
})
errg.Go(func() error {
s.logger.Print("Starting SSH server", "addr", s.Config.SSH.ListenAddr)
- if err := start(ctx, s.SSHServer.ListenAndServe); !errors.Is(err, ssh.ErrServerClosed) {
+ if err := s.SSHServer.ListenAndServe(); !errors.Is(err, ssh.ErrServerClosed) {
return err
}
return nil
})
errg.Go(func() error {
s.logger.Print("Starting Stats server", "addr", s.Config.Stats.ListenAddr)
- if err := start(ctx, s.StatsServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) {
+ if err := s.StatsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
@@ -142,7 +119,7 @@ func (s *Server) Start() error {
// Shutdown lets the server gracefully shutdown.
func (s *Server) Shutdown(ctx context.Context) error {
- var errg errgroup.Group
+ errg, ctx := errgroup.WithContext(ctx)
errg.Go(func() error {
return s.GitDaemon.Shutdown(ctx)
})
@@ -159,9 +136,7 @@ func (s *Server) Shutdown(ctx context.Context) error {
s.Cron.Stop()
return nil
})
- if closer, ok := s.Backend.(io.Closer); ok {
- defer closer.Close() // nolint: errcheck
- }
+ // defer s.DB.Close() // nolint: errcheck
return errg.Wait()
}
@@ -176,8 +151,6 @@ func (s *Server) Close() error {
s.Cron.Stop()
return nil
})
- if closer, ok := s.Backend.(io.Closer); ok {
- defer closer.Close() // nolint: errcheck
- }
+ // defer s.DB.Close() // nolint: errcheck
return errg.Wait()
}
@@ -1,58 +0,0 @@
-package server
-
-import (
- "context"
- "fmt"
- "path/filepath"
- "strings"
- "testing"
-
- "github.com/charmbracelet/keygen"
- "github.com/charmbracelet/soft-serve/server/config"
- "github.com/charmbracelet/soft-serve/server/test"
- "github.com/charmbracelet/ssh"
- "github.com/matryer/is"
- gossh "golang.org/x/crypto/ssh"
-)
-
-func setupServer(tb testing.TB) (*Server, *config.Config, string) {
- tb.Helper()
- tb.Log("creating keypair")
- pub, pkPath := createKeyPair(tb)
- dp := tb.TempDir()
- sshPort := fmt.Sprintf(":%d", test.RandomPort())
- tb.Setenv("SOFT_SERVE_DATA_PATH", dp)
- tb.Setenv("SOFT_SERVE_INITIAL_ADMIN_KEY", authorizedKey(pub))
- tb.Setenv("SOFT_SERVE_SSH_LISTEN_ADDR", sshPort)
- tb.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
- ctx := context.TODO()
- cfg := config.DefaultConfig()
- ctx = config.WithContext(ctx, cfg)
- tb.Log("configuring server")
- s, err := NewServer(ctx)
- if err != nil {
- tb.Fatal(err)
- }
- go func() {
- tb.Log("starting server")
- s.Start()
- }()
- tb.Cleanup(func() {
- s.Close()
- })
- return s, cfg, pkPath
-}
-
-func createKeyPair(tb testing.TB) (ssh.PublicKey, string) {
- tb.Helper()
- is := is.New(tb)
- keyDir := tb.TempDir()
- fp := filepath.Join(keyDir, "id_ed25519")
- kp, err := keygen.New(fp, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite())
- is.NoErr(err)
- return kp.PublicKey(), fp
-}
-
-func authorizedKey(pk ssh.PublicKey) string {
- return strings.TrimSpace(string(gossh.MarshalAuthorizedKey(pk)))
-}
@@ -0,0 +1,17 @@
+package ssh
+
+import (
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/ssh/cmd"
+ "github.com/charmbracelet/ssh"
+)
+
+func handleCli(s ssh.Session) {
+ ctx := s.Context()
+ logger := log.FromContext(ctx)
+ rootCmd := cmd.RootCommand(s)
+ if err := rootCmd.ExecuteContext(ctx); err != nil {
+ logger.Error("error executing command", "err", err)
+ _ = s.Exit(1)
+ }
+}
@@ -8,6 +8,7 @@ import (
gansi "github.com/charmbracelet/glamour/ansi"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/soft-serve/git"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/muesli/termenv"
"github.com/spf13/cobra"
@@ -16,9 +17,6 @@ import (
var (
lineDigitStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("239"))
lineBarStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("236"))
- dirnameStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#00AAFF"))
- filenameStyle = lipgloss.NewStyle()
- filemodeStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#777777"))
)
// blobCommand returns a command that prints the contents of a file.
@@ -34,7 +32,8 @@ func blobCommand() *cobra.Command {
Args: cobra.RangeArgs(1, 3),
PersistentPreRunE: checkIfReadable,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := args[0]
ref := ""
fp := ""
@@ -46,7 +45,7 @@ func blobCommand() *cobra.Command {
fp = args[2]
}
- repo, err := cfg.Backend.Repository(rn)
+ repo, err := be.Repository(ctx, rn)
if err != nil {
return err
}
@@ -5,6 +5,7 @@ import (
"strings"
"github.com/charmbracelet/soft-serve/git"
+ "github.com/charmbracelet/soft-serve/server/backend"
gitm "github.com/gogs/git-module"
"github.com/spf13/cobra"
)
@@ -31,9 +32,10 @@ func branchListCommand() *cobra.Command {
Args: cobra.ExactArgs(1),
PersistentPreRunE: checkIfReadable,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := strings.TrimSuffix(args[0], ".git")
- rr, err := cfg.Backend.Repository(rn)
+ rr, err := be.Repository(ctx, rn)
if err != nil {
return err
}
@@ -61,14 +63,15 @@ func branchDefaultCommand() *cobra.Command {
Short: "Set or get the default branch",
Args: cobra.RangeArgs(1, 2),
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := strings.TrimSuffix(args[0], ".git")
switch len(args) {
case 1:
if err := checkIfReadable(cmd, args); err != nil {
return err
}
- rr, err := cfg.Backend.Repository(rn)
+ rr, err := be.Repository(ctx, rn)
if err != nil {
return err
}
@@ -89,7 +92,7 @@ func branchDefaultCommand() *cobra.Command {
return err
}
- rr, err := cfg.Backend.Repository(rn)
+ rr, err := be.Repository(ctx, rn)
if err != nil {
return err
}
@@ -132,9 +135,10 @@ func branchDeleteCommand() *cobra.Command {
Short: "Delete a branch",
PersistentPreRunE: checkIfCollab,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := strings.TrimSuffix(args[0], ".git")
- rr, err := cfg.Backend.Repository(rn)
+ rr, err := be.Repository(ctx, rn)
if err != nil {
return err
}
@@ -167,11 +171,7 @@ func branchDeleteCommand() *cobra.Command {
return fmt.Errorf("cannot delete the default branch")
}
- if err := r.DeleteBranch(branch, gitm.DeleteBranchOptions{Force: true}); err != nil {
- return err
- }
-
- return nil
+ return r.DeleteBranch(branch, gitm.DeleteBranchOptions{Force: true})
},
}
@@ -1,28 +1,24 @@
package cmd
import (
- "context"
"fmt"
"net/url"
"strings"
"text/template"
"unicode"
- "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/access"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
- "github.com/charmbracelet/soft-serve/server/errors"
+ "github.com/charmbracelet/soft-serve/server/proto"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
"github.com/charmbracelet/soft-serve/server/utils"
"github.com/charmbracelet/ssh"
- "github.com/charmbracelet/wish"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/spf13/cobra"
)
-// sessionCtxKey is the key for the session in the context.
-var sessionCtxKey = &struct{ string }{"session"}
-
var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "soft_serve",
Subsystem: "cli",
@@ -89,9 +85,14 @@ func cmdName(args []string) string {
return args[0]
}
-// rootCommand is the root command for the server.
-func rootCommand(cfg *config.Config, s ssh.Session) *cobra.Command {
- cliCommandCounter.WithLabelValues(cmdName(s.Command())).Inc()
+// RootCommand returns a new cli root command.
+func RootCommand(s ssh.Session) *cobra.Command {
+ ctx := s.Context()
+ cfg := config.FromContext(ctx)
+ be := backend.FromContext(ctx)
+
+ args := s.Command()
+ cliCommandCounter.WithLabelValues(cmdName(args)).Inc()
rootCmd := &cobra.Command{
Short: "Soft Serve is a self-hostable Git server for the command line.",
SilenceUsage: true,
@@ -129,7 +130,17 @@ func rootCommand(cfg *config.Config, s ssh.Session) *cobra.Command {
repoCommand(),
)
- user, _ := cfg.Backend.UserByPublicKey(s.PublicKey())
+ rootCmd.SetArgs(args)
+ if len(args) == 0 {
+ // otherwise it'll default to os.Args, which is not what we want.
+ rootCmd.SetArgs([]string{"--help"})
+ }
+ rootCmd.SetIn(s)
+ rootCmd.SetOut(s)
+ rootCmd.CompletionOptions.DisableDefaultCmd = true
+ rootCmd.SetErr(s.Stderr())
+
+ user, _ := be.UserByPublicKey(s.Context(), s.PublicKey())
isAdmin := isPublicKeyAdmin(cfg, s.PublicKey()) || (user != nil && user.IsAdmin())
if user != nil || isAdmin {
if isAdmin {
@@ -149,30 +160,26 @@ func rootCommand(cfg *config.Config, s ssh.Session) *cobra.Command {
return rootCmd
}
-func fromContext(cmd *cobra.Command) (*config.Config, ssh.Session) {
- ctx := cmd.Context()
- cfg := config.FromContext(ctx)
- s := ctx.Value(sessionCtxKey).(ssh.Session)
- return cfg, s
-}
-
func checkIfReadable(cmd *cobra.Command, args []string) error {
var repo string
if len(args) > 0 {
repo = args[0]
}
- cfg, s := fromContext(cmd)
+
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := utils.SanitizeRepo(repo)
- auth := cfg.Backend.AccessLevelByPublicKey(rn, s.PublicKey())
- if auth < backend.ReadOnlyAccess {
- return errors.ErrUnauthorized
+ pk := sshutils.PublicKeyFromContext(ctx)
+ auth := be.AccessLevelByPublicKey(cmd.Context(), rn, pk)
+ if auth < access.ReadOnlyAccess {
+ return proto.ErrUnauthorized
}
return nil
}
func isPublicKeyAdmin(cfg *config.Config, pk ssh.PublicKey) bool {
for _, k := range cfg.AdminKeys() {
- if backend.KeysEqual(pk, k) {
+ if sshutils.KeysEqual(pk, k) {
return true
}
}
@@ -180,18 +187,21 @@ func isPublicKeyAdmin(cfg *config.Config, pk ssh.PublicKey) bool {
}
func checkIfAdmin(cmd *cobra.Command, _ []string) error {
- cfg, s := fromContext(cmd)
- if isPublicKeyAdmin(cfg, s.PublicKey()) {
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
+ cfg := config.FromContext(ctx)
+ pk := sshutils.PublicKeyFromContext(ctx)
+ if isPublicKeyAdmin(cfg, pk) {
return nil
}
- user, _ := cfg.Backend.UserByPublicKey(s.PublicKey())
+ user, _ := be.UserByPublicKey(ctx, pk)
if user == nil {
- return errors.ErrUnauthorized
+ return proto.ErrUnauthorized
}
if !user.IsAdmin() {
- return errors.ErrUnauthorized
+ return proto.ErrUnauthorized
}
return nil
@@ -202,61 +212,14 @@ func checkIfCollab(cmd *cobra.Command, args []string) error {
if len(args) > 0 {
repo = args[0]
}
- cfg, s := fromContext(cmd)
+
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
+ pk := sshutils.PublicKeyFromContext(ctx)
rn := utils.SanitizeRepo(repo)
- auth := cfg.Backend.AccessLevelByPublicKey(rn, s.PublicKey())
- if auth < backend.ReadWriteAccess {
- return errors.ErrUnauthorized
+ auth := be.AccessLevelByPublicKey(ctx, rn, pk)
+ if auth < access.ReadWriteAccess {
+ return proto.ErrUnauthorized
}
return nil
}
-
-// Middleware is the Soft Serve middleware that handles SSH commands.
-func Middleware(cfg *config.Config, logger *log.Logger) wish.Middleware {
- return func(sh ssh.Handler) ssh.Handler {
- return func(s ssh.Session) {
- func() {
- _, _, active := s.Pty()
- if active {
- return
- }
-
- // Ignore git server commands.
- args := s.Command()
- if len(args) > 0 {
- if args[0] == "git-receive-pack" ||
- args[0] == "git-upload-pack" ||
- args[0] == "git-upload-archive" {
- return
- }
- }
-
- // Here we copy the server's config and replace the backend
- // with a new one that uses the session's context.
- var ctx context.Context = s.Context()
- scfg := *cfg
- cfg = &scfg
- be := cfg.Backend.WithContext(ctx)
- cfg.Backend = be
- ctx = config.WithContext(ctx, cfg)
- ctx = backend.WithContext(ctx, be)
- ctx = context.WithValue(ctx, sessionCtxKey, s)
-
- rootCmd := rootCommand(cfg, s)
- rootCmd.SetArgs(args)
- if len(args) == 0 {
- // otherwise it'll default to os.Args, which is not what we want.
- rootCmd.SetArgs([]string{"--help"})
- }
- rootCmd.SetIn(s)
- rootCmd.SetOut(s)
- rootCmd.CompletionOptions.DisableDefaultCmd = true
- rootCmd.SetErr(s.Stderr())
- if err := rootCmd.ExecuteContext(ctx); err != nil {
- _ = s.Exit(1)
- }
- }()
- sh(s)
- }
- }
-}
@@ -1,6 +1,7 @@
package cmd
import (
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/spf13/cobra"
)
@@ -27,11 +28,12 @@ func collabAddCommand() *cobra.Command {
Args: cobra.ExactArgs(2),
PersistentPreRunE: checkIfCollab,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
repo := args[0]
username := args[1]
- return cfg.Backend.AddCollaborator(repo, username)
+ return be.AddCollaborator(ctx, repo, username)
},
}
@@ -45,11 +47,12 @@ func collabRemoveCommand() *cobra.Command {
Short: "Remove a collaborator from a repo",
PersistentPreRunE: checkIfCollab,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
repo := args[0]
username := args[1]
- return cfg.Backend.RemoveCollaborator(repo, username)
+ return be.RemoveCollaborator(ctx, repo, username)
},
}
@@ -63,9 +66,10 @@ func collabListCommand() *cobra.Command {
Args: cobra.ExactArgs(1),
PersistentPreRunE: checkIfCollab,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
repo := args[0]
- collabs, err := cfg.Backend.Collaborators(repo)
+ collabs, err := be.Collaborators(ctx, repo)
if err != nil {
return err
}
@@ -7,6 +7,7 @@ import (
gansi "github.com/charmbracelet/glamour/ansi"
"github.com/charmbracelet/soft-serve/git"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/charmbracelet/soft-serve/server/ui/styles"
"github.com/muesli/termenv"
@@ -24,11 +25,12 @@ func commitCommand() *cobra.Command {
Args: cobra.ExactArgs(2),
PersistentPreRunE: checkIfReadable,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
repoName := args[0]
commitSHA := args[1]
- rr, err := cfg.Backend.Repository(repoName)
+ rr, err := be.Repository(ctx, repoName)
if err != nil {
return err
}
@@ -38,10 +40,13 @@ func commitCommand() *cobra.Command {
return err
}
- raw_commit, err := r.CommitByRevision(commitSHA)
+ rawCommit, err := r.CommitByRevision(commitSHA)
+ if err != nil {
+ return err
+ }
commit := &git.Commit{
- Commit: raw_commit,
+ Commit: rawCommit,
Hash: git.Hash(commitSHA),
}
@@ -61,7 +66,7 @@ func commitCommand() *cobra.Command {
s := strings.Builder{}
commitLine := "commit " + commitSHA
authorLine := "Author: " + commit.Author.Name
- dateLine := "Date: " + commit.Committer.When.Format(time.UnixDate)
+ dateLine := "Date: " + commit.Committer.When.UTC().Format(time.UnixDate)
msgLine := strings.ReplaceAll(commit.Message, "\r\n", "\n")
statsLine := renderStats(diff, commonStyle, color)
diffLine := renderDiff(patch, color)
@@ -116,8 +121,7 @@ func renderCtx() gansi.RenderContext {
}
func renderDiff(patch string, color bool) string {
-
- c := string(patch)
+ c := patch
if color {
var s strings.Builder
@@ -2,6 +2,7 @@ package cmd
import (
"github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/spf13/cobra"
)
@@ -18,9 +19,10 @@ func createCommand() *cobra.Command {
Args: cobra.ExactArgs(1),
PersistentPreRunE: checkIfCollab,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
name := args[0]
- if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{
+ if _, err := be.CreateRepository(ctx, name, proto.RepositoryOptions{
Private: private,
Description: description,
ProjectName: projectName,
@@ -1,6 +1,9 @@
package cmd
-import "github.com/spf13/cobra"
+import (
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/spf13/cobra"
+)
func deleteCommand() *cobra.Command {
cmd := &cobra.Command{
@@ -10,12 +13,11 @@ func deleteCommand() *cobra.Command {
Args: cobra.ExactArgs(1),
PersistentPreRunE: checkIfCollab,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
name := args[0]
- if err := cfg.Backend.DeleteRepository(name); err != nil {
- return err
- }
- return nil
+
+ return be.DeleteRepository(ctx, name)
},
}
return cmd
@@ -3,6 +3,7 @@ package cmd
import (
"strings"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/spf13/cobra"
)
@@ -13,7 +14,8 @@ func descriptionCommand() *cobra.Command {
Short: "Set or get the description for a repository",
Args: cobra.MinimumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := strings.TrimSuffix(args[0], ".git")
switch len(args) {
case 1:
@@ -21,7 +23,7 @@ func descriptionCommand() *cobra.Command {
return err
}
- desc, err := cfg.Backend.Description(rn)
+ desc, err := be.Description(ctx, rn)
if err != nil {
return err
}
@@ -31,7 +33,7 @@ func descriptionCommand() *cobra.Command {
if err := checkIfCollab(cmd, args); err != nil {
return err
}
- if err := cfg.Backend.SetDescription(rn, strings.Join(args[1:], " ")); err != nil {
+ if err := be.SetDescription(ctx, rn, strings.Join(args[1:], " ")); err != nil {
return err
}
}
@@ -1,6 +1,9 @@
package cmd
-import "github.com/spf13/cobra"
+import (
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/spf13/cobra"
+)
func hiddenCommand() *cobra.Command {
cmd := &cobra.Command{
@@ -9,7 +12,8 @@ func hiddenCommand() *cobra.Command {
Aliases: []string{"hide"},
Args: cobra.MinimumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
repo := args[0]
switch len(args) {
case 1:
@@ -17,7 +21,7 @@ func hiddenCommand() *cobra.Command {
return err
}
- hidden, err := cfg.Backend.IsHidden(repo)
+ hidden, err := be.IsHidden(ctx, repo)
if err != nil {
return err
}
@@ -29,7 +33,7 @@ func hiddenCommand() *cobra.Command {
}
hidden := args[1] == "true"
- if err := cfg.Backend.SetHidden(repo, hidden); err != nil {
+ if err := be.SetHidden(ctx, repo, hidden); err != nil {
return err
}
}
@@ -2,6 +2,7 @@ package cmd
import (
"github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/spf13/cobra"
)
@@ -19,10 +20,11 @@ func importCommand() *cobra.Command {
Args: cobra.ExactArgs(2),
PersistentPreRunE: checkIfCollab,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
name := args[0]
remote := args[1]
- if _, err := cfg.Backend.ImportRepository(name, remote, backend.RepositoryOptions{
+ if _, err := be.ImportRepository(ctx, name, remote, proto.RepositoryOptions{
Private: private,
Description: description,
ProjectName: projectName,
@@ -2,6 +2,7 @@ package cmd
import (
"github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
"github.com/spf13/cobra"
)
@@ -11,8 +12,10 @@ func infoCommand() *cobra.Command {
Short: "Show your info",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, s := fromContext(cmd)
- user, err := cfg.Backend.UserByPublicKey(s.PublicKey())
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
+ pk := sshutils.PublicKeyFromContext(ctx)
+ user, err := be.UserByPublicKey(ctx, pk)
if err != nil {
return err
}
@@ -21,7 +24,7 @@ func infoCommand() *cobra.Command {
cmd.Printf("Admin: %t\n", user.IsAdmin())
cmd.Printf("Public keys:\n")
for _, pk := range user.PublicKeys() {
- cmd.Printf(" %s\n", backend.MarshalAuthorizedKey(pk))
+ cmd.Printf(" %s\n", sshutils.MarshalAuthorizedKey(pk))
}
return nil
},
@@ -1,7 +1,9 @@
package cmd
import (
+ "github.com/charmbracelet/soft-serve/server/access"
"github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
"github.com/spf13/cobra"
)
@@ -15,13 +17,15 @@ func listCommand() *cobra.Command {
Short: "List repositories",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, s := fromContext(cmd)
- repos, err := cfg.Backend.Repositories()
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
+ pk := sshutils.PublicKeyFromContext(ctx)
+ repos, err := be.Repositories(ctx)
if err != nil {
return err
}
for _, r := range repos {
- if cfg.Backend.AccessLevelByPublicKey(r.Name(), s.PublicKey()) >= backend.ReadOnlyAccess {
+ if be.AccessLevelByPublicKey(ctx, r.Name(), pk) >= access.ReadOnlyAccess {
if !r.IsHidden() || all {
cmd.Println(r.Name())
}
@@ -1,6 +1,7 @@
package cmd
import (
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/spf13/cobra"
)
@@ -11,9 +12,10 @@ func mirrorCommand() *cobra.Command {
Args: cobra.ExactArgs(1),
PersistentPreRunE: checkIfReadable,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := args[0]
- rr, err := cfg.Backend.Repository(rn)
+ rr, err := be.Repository(ctx, rn)
if err != nil {
return err
}
@@ -4,6 +4,7 @@ import (
"strconv"
"strings"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/spf13/cobra"
)
@@ -13,7 +14,8 @@ func privateCommand() *cobra.Command {
Short: "Set or get a repository private property",
Args: cobra.RangeArgs(1, 2),
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := strings.TrimSuffix(args[0], ".git")
switch len(args) {
@@ -22,7 +24,7 @@ func privateCommand() *cobra.Command {
return err
}
- isPrivate, err := cfg.Backend.IsPrivate(rn)
+ isPrivate, err := be.IsPrivate(ctx, rn)
if err != nil {
return err
}
@@ -36,7 +38,7 @@ func privateCommand() *cobra.Command {
if err := checkIfCollab(cmd, args); err != nil {
return err
}
- if err := cfg.Backend.SetPrivate(rn, isPrivate); err != nil {
+ if err := be.SetPrivate(ctx, rn, isPrivate); err != nil {
return err
}
}
@@ -3,6 +3,7 @@ package cmd
import (
"strings"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/spf13/cobra"
)
@@ -13,7 +14,8 @@ func projectName() *cobra.Command {
Short: "Set or get the project name for a repository",
Args: cobra.MinimumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := strings.TrimSuffix(args[0], ".git")
switch len(args) {
case 1:
@@ -21,7 +23,7 @@ func projectName() *cobra.Command {
return err
}
- pn, err := cfg.Backend.ProjectName(rn)
+ pn, err := be.ProjectName(ctx, rn)
if err != nil {
return err
}
@@ -31,7 +33,7 @@ func projectName() *cobra.Command {
if err := checkIfCollab(cmd, args); err != nil {
return err
}
- if err := cfg.Backend.SetProjectName(rn, strings.Join(args[1:], " ")); err != nil {
+ if err := be.SetProjectName(ctx, rn, strings.Join(args[1:], " ")); err != nil {
return err
}
}
@@ -4,6 +4,7 @@ import (
"strings"
"github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
"github.com/spf13/cobra"
)
@@ -19,18 +20,20 @@ func pubkeyCommand() *cobra.Command {
Short: "Add a public key",
Args: cobra.MinimumNArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, s := fromContext(cmd)
- user, err := cfg.Backend.UserByPublicKey(s.PublicKey())
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
+ pk := sshutils.PublicKeyFromContext(ctx)
+ user, err := be.UserByPublicKey(ctx, pk)
if err != nil {
return err
}
- pk, _, err := backend.ParseAuthorizedKey(strings.Join(args, " "))
+ apk, _, err := sshutils.ParseAuthorizedKey(strings.Join(args, " "))
if err != nil {
return err
}
- return cfg.Backend.AddPublicKey(user.Username(), pk)
+ return be.AddPublicKey(ctx, user.Username(), apk)
},
}
@@ -39,18 +42,20 @@ func pubkeyCommand() *cobra.Command {
Args: cobra.MinimumNArgs(1),
Short: "Remove a public key",
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, s := fromContext(cmd)
- user, err := cfg.Backend.UserByPublicKey(s.PublicKey())
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
+ pk := sshutils.PublicKeyFromContext(ctx)
+ user, err := be.UserByPublicKey(ctx, pk)
if err != nil {
return err
}
- pk, _, err := backend.ParseAuthorizedKey(strings.Join(args, " "))
+ apk, _, err := sshutils.ParseAuthorizedKey(strings.Join(args, " "))
if err != nil {
return err
}
- return cfg.Backend.RemovePublicKey(user.Username(), pk)
+ return be.RemovePublicKey(ctx, user.Username(), apk)
},
}
@@ -60,15 +65,17 @@ func pubkeyCommand() *cobra.Command {
Short: "List public keys",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, s := fromContext(cmd)
- user, err := cfg.Backend.UserByPublicKey(s.PublicKey())
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
+ pk := sshutils.PublicKeyFromContext(ctx)
+ user, err := be.UserByPublicKey(ctx, pk)
if err != nil {
return err
}
pks := user.PublicKeys()
for _, pk := range pks {
- cmd.Println(backend.MarshalAuthorizedKey(pk))
+ cmd.Println(sshutils.MarshalAuthorizedKey(pk))
}
return nil
@@ -1,6 +1,9 @@
package cmd
-import "github.com/spf13/cobra"
+import (
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/spf13/cobra"
+)
func renameCommand() *cobra.Command {
cmd := &cobra.Command{
@@ -10,13 +13,12 @@ func renameCommand() *cobra.Command {
Args: cobra.ExactArgs(2),
PersistentPreRunE: checkIfCollab,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
oldName := args[0]
newName := args[1]
- if err := cfg.Backend.RenameRepository(oldName, newName); err != nil {
- return err
- }
- return nil
+
+ return be.RenameRepository(ctx, oldName, newName)
},
}
@@ -4,6 +4,7 @@ import (
"fmt"
"strings"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/spf13/cobra"
)
@@ -40,9 +41,10 @@ func repoCommand() *cobra.Command {
Args: cobra.ExactArgs(1),
PersistentPreRunE: checkIfReadable,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := args[0]
- rr, err := cfg.Backend.Repository(rn)
+ rr, err := be.Repository(ctx, rn)
if err != nil {
return err
}
@@ -0,0 +1,28 @@
+package cmd
+
+import (
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
+ "github.com/spf13/cobra"
+)
+
+func setUsernameCommand() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "set-username USERNAME",
+ Short: "Set your username",
+ Args: cobra.ExactArgs(1),
+ RunE: func(cmd *cobra.Command, args []string) error {
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
+ pk := sshutils.PublicKeyFromContext(ctx)
+ user, err := be.UserByPublicKey(ctx, pk)
+ if err != nil {
+ return err
+ }
+
+ return be.SetUsername(ctx, user.Username(), args[0])
+ },
+ }
+
+ return cmd
+}
@@ -4,6 +4,7 @@ import (
"fmt"
"strconv"
+ "github.com/charmbracelet/soft-serve/server/access"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/spf13/cobra"
)
@@ -21,13 +22,14 @@ func settingsCommand() *cobra.Command {
Args: cobra.RangeArgs(0, 1),
PersistentPreRunE: checkIfAdmin,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
switch len(args) {
case 0:
- cmd.Println(cfg.Backend.AllowKeyless())
+ cmd.Println(be.AllowKeyless(ctx))
case 1:
v, _ := strconv.ParseBool(args[0])
- if err := cfg.Backend.SetAllowKeyless(v); err != nil {
+ if err := be.SetAllowKeyless(ctx, v); err != nil {
return err
}
}
@@ -37,7 +39,7 @@ func settingsCommand() *cobra.Command {
},
)
- als := []string{backend.NoAccess.String(), backend.ReadOnlyAccess.String(), backend.ReadWriteAccess.String(), backend.AdminAccess.String()}
+ als := []string{access.NoAccess.String(), access.ReadOnlyAccess.String(), access.ReadWriteAccess.String(), access.AdminAccess.String()}
cmd.AddCommand(
&cobra.Command{
Use: "anon-access [ACCESS_LEVEL]",
@@ -46,16 +48,17 @@ func settingsCommand() *cobra.Command {
ValidArgs: als,
PersistentPreRunE: checkIfAdmin,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
switch len(args) {
case 0:
- cmd.Println(cfg.Backend.AnonAccess())
+ cmd.Println(be.AnonAccess(ctx))
case 1:
- al := backend.ParseAccessLevel(args[0])
+ al := access.ParseAccessLevel(args[0])
if al < 0 {
return fmt.Errorf("invalid access level: %s. Please choose one of the following: %s", args[0], als)
}
- if err := cfg.Backend.SetAnonAccess(al); err != nil {
+ if err := be.SetAnonAccess(ctx, al); err != nil {
return err
}
}
@@ -3,6 +3,7 @@ package cmd
import (
"strings"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/spf13/cobra"
)
@@ -28,9 +29,10 @@ func tagListCommand() *cobra.Command {
Args: cobra.ExactArgs(1),
PersistentPreRunE: checkIfReadable,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := strings.TrimSuffix(args[0], ".git")
- rr, err := cfg.Backend.Repository(rn)
+ rr, err := be.Repository(ctx, rn)
if err != nil {
return err
}
@@ -60,9 +62,10 @@ func tagDeleteCommand() *cobra.Command {
Args: cobra.ExactArgs(2),
PersistentPreRunE: checkIfCollab,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := strings.TrimSuffix(args[0], ".git")
- rr, err := cfg.Backend.Repository(rn)
+ rr, err := be.Repository(ctx, rn)
if err != nil {
return err
}
@@ -4,7 +4,8 @@ import (
"fmt"
"github.com/charmbracelet/soft-serve/git"
- "github.com/charmbracelet/soft-serve/server/errors"
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/dustin/go-humanize"
"github.com/spf13/cobra"
)
@@ -17,7 +18,8 @@ func treeCommand() *cobra.Command {
Args: cobra.RangeArgs(1, 3),
PersistentPreRunE: checkIfReadable,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
rn := args[0]
path := ""
ref := ""
@@ -28,7 +30,7 @@ func treeCommand() *cobra.Command {
ref = args[1]
path = args[2]
}
- rr, err := cfg.Backend.Repository(rn)
+ rr, err := be.Repository(ctx, rn)
if err != nil {
return err
}
@@ -59,7 +61,7 @@ func treeCommand() *cobra.Command {
if path != "" && path != "/" {
te, err := tree.TreeEntry(path)
if err == git.ErrRevisionNotExist {
- return errors.ErrFileNotFound
+ return proto.ErrFileNotFound
}
if err != nil {
return err
@@ -4,8 +4,9 @@ import (
"sort"
"strings"
- "github.com/charmbracelet/log"
"github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/proto"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
)
@@ -26,10 +27,11 @@ func userCommand() *cobra.Command {
PersistentPreRunE: checkIfAdmin,
RunE: func(cmd *cobra.Command, args []string) error {
var pubkeys []ssh.PublicKey
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
username := args[0]
if key != "" {
- pk, _, err := backend.ParseAuthorizedKey(key)
+ pk, _, err := sshutils.ParseAuthorizedKey(key)
if err != nil {
return err
}
@@ -37,12 +39,12 @@ func userCommand() *cobra.Command {
pubkeys = []ssh.PublicKey{pk}
}
- opts := backend.UserOptions{
+ opts := proto.UserOptions{
Admin: admin,
PublicKeys: pubkeys,
}
- _, err := cfg.Backend.CreateUser(username, opts)
+ _, err := be.CreateUser(ctx, username, opts)
return err
},
}
@@ -56,10 +58,11 @@ func userCommand() *cobra.Command {
Args: cobra.ExactArgs(1),
PersistentPreRunE: checkIfAdmin,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
username := args[0]
- return cfg.Backend.DeleteUser(username)
+ return be.DeleteUser(ctx, username)
},
}
@@ -70,8 +73,9 @@ func userCommand() *cobra.Command {
Args: cobra.NoArgs,
PersistentPreRunE: checkIfAdmin,
RunE: func(cmd *cobra.Command, _ []string) error {
- cfg, _ := fromContext(cmd)
- users, err := cfg.Backend.Users()
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
+ users, err := be.Users(ctx)
if err != nil {
return err
}
@@ -91,15 +95,16 @@ func userCommand() *cobra.Command {
Args: cobra.MinimumNArgs(2),
PersistentPreRunE: checkIfAdmin,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
username := args[0]
pubkey := strings.Join(args[1:], " ")
- pk, _, err := backend.ParseAuthorizedKey(pubkey)
+ pk, _, err := sshutils.ParseAuthorizedKey(pubkey)
if err != nil {
return err
}
- return cfg.Backend.AddPublicKey(username, pk)
+ return be.AddPublicKey(ctx, username, pk)
},
}
@@ -109,16 +114,16 @@ func userCommand() *cobra.Command {
Args: cobra.MinimumNArgs(2),
PersistentPreRunE: checkIfAdmin,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
username := args[0]
pubkey := strings.Join(args[1:], " ")
- log.Debugf("key is %q", pubkey)
- pk, _, err := backend.ParseAuthorizedKey(pubkey)
+ pk, _, err := sshutils.ParseAuthorizedKey(pubkey)
if err != nil {
return err
}
- return cfg.Backend.RemovePublicKey(username, pk)
+ return be.RemovePublicKey(ctx, username, pk)
},
}
@@ -128,10 +133,11 @@ func userCommand() *cobra.Command {
Args: cobra.ExactArgs(2),
PersistentPreRunE: checkIfAdmin,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
username := args[0]
- return cfg.Backend.SetAdmin(username, args[1] == "true")
+ return be.SetAdmin(ctx, username, args[1] == "true")
},
}
@@ -141,10 +147,11 @@ func userCommand() *cobra.Command {
Args: cobra.ExactArgs(1),
PersistentPreRunE: checkIfAdmin,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
username := args[0]
- user, err := cfg.Backend.User(username)
+ user, err := be.User(ctx, username)
if err != nil {
return err
}
@@ -155,7 +162,7 @@ func userCommand() *cobra.Command {
cmd.Printf("Admin: %t\n", isAdmin)
cmd.Printf("Public keys:\n")
for _, pk := range user.PublicKeys() {
- cmd.Printf(" %s\n", backend.MarshalAuthorizedKey(pk))
+ cmd.Printf(" %s\n", sshutils.MarshalAuthorizedKey(pk))
}
return nil
@@ -168,11 +175,12 @@ func userCommand() *cobra.Command {
Args: cobra.ExactArgs(2),
PersistentPreRunE: checkIfAdmin,
RunE: func(cmd *cobra.Command, args []string) error {
- cfg, _ := fromContext(cmd)
+ ctx := cmd.Context()
+ be := backend.FromContext(ctx)
username := args[0]
newUsername := args[1]
- return cfg.Backend.SetUsername(username, newUsername)
+ return be.SetUsername(ctx, username, newUsername)
},
}
@@ -0,0 +1,124 @@
+package ssh
+
+import (
+ "errors"
+ "path/filepath"
+ "time"
+
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/access"
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/git"
+ "github.com/charmbracelet/soft-serve/server/proto"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
+ "github.com/charmbracelet/soft-serve/server/utils"
+ "github.com/charmbracelet/ssh"
+)
+
+func handleGit(s ssh.Session) {
+ ctx := s.Context()
+ cfg := config.FromContext(ctx)
+ be := backend.FromContext(ctx)
+ logger := log.FromContext(ctx)
+ cmdLine := s.Command()
+ start := time.Now()
+
+ // repo should be in the form of "repo.git"
+ name := utils.SanitizeRepo(cmdLine[1])
+ pk := s.PublicKey()
+ ak := sshutils.MarshalAuthorizedKey(pk)
+ accessLevel := be.AccessLevelByPublicKey(ctx, name, pk)
+ // git bare repositories should end in ".git"
+ // https://git-scm.com/docs/gitrepository-layout
+ repo := name + ".git"
+ reposDir := filepath.Join(cfg.DataPath, "repos")
+ if err := git.EnsureWithin(reposDir, repo); err != nil {
+ sshFatal(s, err)
+ return
+ }
+
+ // Environment variables to pass down to git hooks.
+ envs := []string{
+ "SOFT_SERVE_REPO_NAME=" + name,
+ "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
+ "SOFT_SERVE_PUBLIC_KEY=" + ak,
+ "SOFT_SERVE_USERNAME=" + s.User(),
+ "SOFT_SERVE_LOG_PATH=" + filepath.Join(cfg.DataPath, "log", "hooks.log"),
+ }
+
+ // Add ssh session & config environ
+ envs = append(envs, s.Environ()...)
+ envs = append(envs, cfg.Environ()...)
+
+ repoDir := filepath.Join(reposDir, repo)
+ service := git.Service(cmdLine[0])
+ cmd := git.ServiceCommand{
+ Stdin: s,
+ Stdout: s,
+ Stderr: s.Stderr(),
+ Env: envs,
+ Dir: repoDir,
+ }
+
+ logger.Debug("git middleware", "cmd", service, "access", accessLevel.String())
+
+ switch service {
+ case git.ReceivePackService:
+ receivePackCounter.WithLabelValues(name).Inc()
+ defer func() {
+ receivePackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
+ }()
+ if accessLevel < access.ReadWriteAccess {
+ sshFatal(s, git.ErrNotAuthed)
+ return
+ }
+ if _, err := be.Repository(ctx, name); err != nil {
+ if _, err := be.CreateRepository(ctx, name, proto.RepositoryOptions{Private: false}); err != nil {
+ log.Errorf("failed to create repo: %s", err)
+ sshFatal(s, err)
+ return
+ }
+ createRepoCounter.WithLabelValues(name).Inc()
+ }
+
+ if err := git.ReceivePack(ctx, cmd); err != nil {
+ sshFatal(s, git.ErrSystemMalfunction)
+ }
+
+ if err := git.EnsureDefaultBranch(ctx, cmd); err != nil {
+ sshFatal(s, git.ErrSystemMalfunction)
+ }
+
+ receivePackCounter.WithLabelValues(name).Inc()
+ return
+ case git.UploadPackService, git.UploadArchiveService:
+ if accessLevel < access.ReadOnlyAccess {
+ sshFatal(s, git.ErrNotAuthed)
+ return
+ }
+
+ handler := git.UploadPack
+ switch service {
+ case git.UploadArchiveService:
+ handler = git.UploadArchive
+ uploadArchiveCounter.WithLabelValues(name).Inc()
+ defer func() {
+ uploadArchiveSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
+ }()
+ default:
+ uploadPackCounter.WithLabelValues(name).Inc()
+ defer func() {
+ uploadPackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
+ }()
+ }
+
+ err := handler(ctx, cmd)
+ if errors.Is(err, git.ErrInvalidRepo) {
+ sshFatal(s, git.ErrInvalidRepo)
+ } else if err != nil {
+ logger.Error("git middleware", "err", err)
+ sshFatal(s, git.ErrSystemMalfunction)
+ }
+ }
+}
@@ -0,0 +1,25 @@
+package ssh
+
+import "github.com/charmbracelet/log"
+
+type loggerAdapter struct {
+ *log.Logger
+ log.Level
+}
+
+func (l *loggerAdapter) Printf(format string, args ...interface{}) {
+ switch l.Level {
+ case log.DebugLevel:
+ l.Logger.Debugf(format, args...)
+ case log.InfoLevel:
+ l.Logger.Infof(format, args...)
+ case log.WarnLevel:
+ l.Logger.Warnf(format, args...)
+ case log.ErrorLevel:
+ l.Logger.Errorf(format, args...)
+ case log.FatalLevel:
+ l.Logger.Fatalf(format, args...)
+ default:
+ l.Logger.Printf(format, args...)
+ }
+}
@@ -0,0 +1,44 @@
+package ssh
+
+import (
+ "strings"
+
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/ssh"
+)
+
+// ContextMiddleware adds the config, backend, and logger to the session context.
+func ContextMiddleware(cfg *config.Config, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
+ return func(sh ssh.Handler) ssh.Handler {
+ return func(s ssh.Session) {
+ s.Context().SetValue(config.ContextKey, cfg)
+ s.Context().SetValue(backend.ContextKey, be)
+ s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh"))
+ sh(s)
+ }
+ }
+}
+
+// CommandMiddleware handles git commands and CLI commands.
+// This middleware must be run after the ContextMiddleware.
+func CommandMiddleware(sh ssh.Handler) ssh.Handler {
+ return func(s ssh.Session) {
+ func() {
+ cmdLine := s.Command()
+ _, _, ptyReq := s.Pty()
+ if ptyReq {
+ return
+ }
+
+ switch {
+ case len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git-"):
+ handleGit(s)
+ default:
+ handleCli(s)
+ }
+ }()
+ sh(s)
+ }
+}
@@ -5,16 +5,14 @@ import (
"time"
tea "github.com/charmbracelet/bubbletea"
- "github.com/charmbracelet/log"
- . "github.com/charmbracelet/soft-serve/internal/log"
+ "github.com/charmbracelet/soft-serve/server/access"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
- "github.com/charmbracelet/soft-serve/server/errors"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/ui"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
- bm "github.com/charmbracelet/wish/bubbletea"
"github.com/muesli/termenv"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
@@ -35,50 +33,50 @@ var tuiSessionDuration = promauto.NewCounterVec(prometheus.CounterOpts{
}, []string{"repo", "term"})
// SessionHandler is the soft-serve bubbletea ssh session handler.
-func SessionHandler(cfg *config.Config) bm.ProgramHandler {
- return func(s ssh.Session) *tea.Program {
- pty, _, active := s.Pty()
- if !active {
- return nil
- }
+// This middleware must be run after the ContextMiddleware.
+func SessionHandler(s ssh.Session) *tea.Program {
+ pty, _, active := s.Pty()
+ if !active {
+ return nil
+ }
- cmd := s.Command()
- initialRepo := ""
- if len(cmd) == 1 {
- initialRepo = cmd[0]
- auth := cfg.Backend.AccessLevelByPublicKey(initialRepo, s.PublicKey())
- if auth < backend.ReadOnlyAccess {
- wish.Fatalln(s, errors.ErrUnauthorized)
- return nil
- }
+ ctx := s.Context()
+ be := backend.FromContext(ctx)
+ cfg := config.FromContext(ctx)
+ cmd := s.Command()
+ initialRepo := ""
+ if len(cmd) == 1 {
+ initialRepo = cmd[0]
+ auth := be.AccessLevelByPublicKey(ctx, initialRepo, s.PublicKey())
+ if auth < access.ReadOnlyAccess {
+ wish.Fatalln(s, proto.ErrUnauthorized)
+ return nil
}
+ }
- envs := &sessionEnv{s}
- output := termenv.NewOutput(s, termenv.WithColorCache(true), termenv.WithEnvironment(envs))
- logger := NewDefaultLogger()
- ctx := log.WithContext(s.Context(), logger)
- c := common.NewCommon(ctx, output, pty.Window.Width, pty.Window.Height)
- c.SetValue(common.ConfigKey, cfg)
- m := ui.New(c, initialRepo)
- p := tea.NewProgram(m,
- tea.WithInput(s),
- tea.WithOutput(s),
- tea.WithAltScreen(),
- tea.WithoutCatchPanics(),
- tea.WithMouseCellMotion(),
- tea.WithContext(ctx),
- )
+ envs := &sessionEnv{s}
+ output := termenv.NewOutput(s, termenv.WithColorCache(true), termenv.WithEnvironment(envs))
+ c := common.NewCommon(ctx, output, pty.Window.Width, pty.Window.Height)
+ c.SetValue(common.ConfigKey, cfg)
+ m := ui.New(c, initialRepo)
+ p := tea.NewProgram(m,
+ tea.WithInput(s),
+ tea.WithOutput(s),
+ tea.WithAltScreen(),
+ tea.WithoutCatchPanics(),
+ tea.WithMouseCellMotion(),
+ tea.WithContext(ctx),
+ )
- tuiSessionCounter.WithLabelValues(initialRepo, pty.Term).Inc()
+ tuiSessionCounter.WithLabelValues(initialRepo, pty.Term).Inc()
- start := time.Now()
- go func() {
- <-ctx.Done()
- tuiSessionDuration.WithLabelValues(initialRepo, pty.Term).Add(time.Since(start).Seconds())
- }()
+ start := time.Now()
+ go func() {
+ <-ctx.Done()
+ tuiSessionDuration.WithLabelValues(initialRepo, pty.Term).Add(time.Since(start).Seconds())
+ }()
- return p
- }
+ return p
}
var _ termenv.Environ = &sessionEnv{}
@@ -4,13 +4,15 @@ import (
"context"
"errors"
"fmt"
- "log"
"os"
"testing"
"time"
- "github.com/charmbracelet/soft-serve/server/backend/sqlite"
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/migrate"
"github.com/charmbracelet/soft-serve/server/test"
"github.com/charmbracelet/ssh"
bm "github.com/charmbracelet/wish/bubbletea"
@@ -18,6 +20,7 @@ import (
"github.com/matryer/is"
"github.com/muesli/termenv"
gossh "golang.org/x/crypto/ssh"
+ _ "modernc.org/sqlite" // sqlite driver
)
func TestSession(t *testing.T) {
@@ -31,16 +34,15 @@ func TestSession(t *testing.T) {
is.NoErr(err)
go func() {
time.Sleep(1 * time.Second)
- s.Signal(gossh.SIGTERM)
- // FIXME: exit with code 0 instead of forcibly closing the session
- s.Close()
+ // s.Signal(gossh.SIGTERM)
+ s.Close() // nolint: errcheck
}()
t.Log("waiting for session to exit")
_, err = s.Output("test")
var ee *gossh.ExitMissingError
is.True(errors.As(err, &ee))
t.Log("session exited")
- _ = close()
+ is.NoErr(close())
})
}
@@ -59,19 +61,26 @@ func setup(tb testing.TB) (*gossh.Session, func() error) {
})
ctx := context.TODO()
cfg := config.DefaultConfig()
+ if err := cfg.Validate(); err != nil {
+ log.Fatal(err)
+ }
ctx = config.WithContext(ctx, cfg)
- fb, err := sqlite.NewSqliteBackend(ctx)
+ db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
if err != nil {
- log.Fatal(err)
+ tb.Fatal(err)
+ }
+ if err := migrate.Migrate(ctx, db); err != nil {
+ tb.Fatal(err)
}
- cfg = cfg.WithBackend(fb)
+ be := backend.New(ctx, cfg, db)
+ ctx = backend.WithContext(ctx, be)
return testsession.New(tb, &ssh.Server{
- Handler: bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256)(func(s ssh.Session) {
+ Handler: ContextMiddleware(cfg, be, log.Default())(bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256)(func(s ssh.Session) {
_, _, active := s.Pty()
if !active {
os.Exit(1)
}
s.Exit(0)
- }),
- }, nil), fb.Close
+ })),
+ }, nil), db.Close
}
@@ -2,22 +2,19 @@ package ssh
import (
"context"
- "errors"
"fmt"
"net"
"os"
- "path/filepath"
"strconv"
- "strings"
"time"
"github.com/charmbracelet/keygen"
"github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/access"
"github.com/charmbracelet/soft-serve/server/backend"
- cm "github.com/charmbracelet/soft-serve/server/cmd"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/git"
- "github.com/charmbracelet/soft-serve/server/utils"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
bm "github.com/charmbracelet/wish/bubbletea"
@@ -95,10 +92,10 @@ var (
)
// SSHServer is a SSH server that implements the git protocol.
-type SSHServer struct {
+type SSHServer struct { // nolint: revive
srv *ssh.Server
cfg *config.Config
- be backend.Backend
+ be *backend.Backend
ctx context.Context
logger *log.Logger
}
@@ -107,12 +104,13 @@ type SSHServer struct {
func NewSSHServer(ctx context.Context) (*SSHServer, error) {
cfg := config.FromContext(ctx)
logger := log.FromContext(ctx).WithPrefix("ssh")
+ be := backend.FromContext(ctx)
var err error
s := &SSHServer{
cfg: cfg,
ctx: ctx,
- be: backend.FromContext(ctx),
+ be: be,
logger: logger,
}
@@ -120,14 +118,15 @@ func NewSSHServer(ctx context.Context) (*SSHServer, error) {
rm.MiddlewareWithLogger(
logger,
// BubbleTea middleware.
- bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256),
+ bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256),
// CLI middleware.
- cm.Middleware(cfg, logger),
- // Git middleware.
- s.Middleware(cfg),
+ CommandMiddleware,
+ // Context middleware.
+ ContextMiddleware(cfg, be, logger),
// Logging middleware.
- lm.MiddlewareWithLogger(logger.
- StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})),
+ lm.MiddlewareWithLogger(
+ &loggerAdapter{logger, log.DebugLevel},
+ ),
),
}
@@ -187,145 +186,27 @@ func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed
return false
}
- ak := backend.MarshalAuthorizedKey(pk)
+ ak := sshutils.MarshalAuthorizedKey(pk)
defer func(allowed *bool) {
publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc()
}(&allowed)
- ac := s.cfg.Backend.AccessLevelByPublicKey("", pk)
+ ac := s.be.AccessLevelByPublicKey(ctx, "", pk)
s.logger.Debugf("access level for %q: %s", ak, ac)
- allowed = ac >= backend.ReadWriteAccess
+ allowed = ac >= access.ReadWriteAccess
return
}
// KeyboardInteractiveHandler handles keyboard interactive authentication.
// This is used after all public key authentication has failed.
func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
- ac := s.cfg.Backend.AllowKeyless()
+ ac := s.be.AllowKeyless(ctx)
keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()
return ac
}
-// Middleware adds Git server functionality to the ssh.Server. Repos are stored
-// in the specified repo directory. The provided Hooks implementation will be
-// checked for access on a per repo basis for a ssh.Session public key.
-// Hooks.Push and Hooks.Fetch will be called on successful completion of
-// their commands.
-func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware {
- return func(sh ssh.Handler) ssh.Handler {
- return func(s ssh.Session) {
- func() {
- start := time.Now()
- cmdLine := s.Command()
- ctx := s.Context()
- be := ss.be.WithContext(ctx)
-
- if len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git") {
- // repo should be in the form of "repo.git"
- name := utils.SanitizeRepo(cmdLine[1])
- pk := s.PublicKey()
- ak := backend.MarshalAuthorizedKey(pk)
- access := cfg.Backend.AccessLevelByPublicKey(name, pk)
- // git bare repositories should end in ".git"
- // https://git-scm.com/docs/gitrepository-layout
- repo := name + ".git"
- reposDir := filepath.Join(cfg.DataPath, "repos")
- if err := git.EnsureWithin(reposDir, repo); err != nil {
- sshFatal(s, err)
- return
- }
-
- // Environment variables to pass down to git hooks.
- envs := []string{
- "SOFT_SERVE_REPO_NAME=" + name,
- "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo),
- "SOFT_SERVE_PUBLIC_KEY=" + ak,
- "SOFT_SERVE_USERNAME=" + ctx.User(),
- }
-
- // Add ssh session & config environ
- envs = append(envs, s.Environ()...)
- envs = append(envs, cfg.Environ()...)
-
- repoDir := filepath.Join(reposDir, repo)
- service := git.Service(cmdLine[0])
- cmd := git.ServiceCommand{
- Stdin: s,
- Stdout: s,
- Stderr: s.Stderr(),
- Env: envs,
- Dir: repoDir,
- }
-
- ss.logger.Debug("git middleware", "cmd", service, "access", access.String())
-
- switch service {
- case git.ReceivePackService:
- receivePackCounter.WithLabelValues(name).Inc()
- defer func() {
- receivePackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
- }()
- if access < backend.ReadWriteAccess {
- sshFatal(s, git.ErrNotAuthed)
- return
- }
- if _, err := be.Repository(name); err != nil {
- if _, err := be.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil {
- log.Errorf("failed to create repo: %s", err)
- sshFatal(s, err)
- return
- }
- createRepoCounter.WithLabelValues(name).Inc()
- }
-
- if err := git.ReceivePack(ctx, cmd); err != nil {
- sshFatal(s, git.ErrSystemMalfunction)
- }
-
- if err := git.EnsureDefaultBranch(ctx, cmd); err != nil {
- sshFatal(s, git.ErrSystemMalfunction)
- }
-
- receivePackCounter.WithLabelValues(name).Inc()
- return
- case git.UploadPackService, git.UploadArchiveService:
- if access < backend.ReadOnlyAccess {
- sshFatal(s, git.ErrNotAuthed)
- return
- }
-
- handler := git.UploadPack
- switch service {
- case git.UploadArchiveService:
- handler = git.UploadArchive
- uploadArchiveCounter.WithLabelValues(name).Inc()
- defer func() {
- uploadArchiveSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
- }()
- default:
- uploadPackCounter.WithLabelValues(name).Inc()
- defer func() {
- uploadPackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds())
- }()
- }
-
- err := handler(ctx, cmd)
- if errors.Is(err, git.ErrInvalidRepo) {
- sshFatal(s, git.ErrInvalidRepo)
- } else if err != nil {
- sshFatal(s, git.ErrSystemMalfunction)
- }
-
- }
- }
- }()
- sh(s)
- }
- }
-}
-
// sshFatal prints to the session's STDOUT as a git response and exit 1.
-func sshFatal(s ssh.Session, v ...interface{}) {
- git.WritePktline(s, v...)
- s.Exit(1) // nolint: errcheck
+func sshFatal(s ssh.Session, err error) {
+ git.WritePktlineErr(s, err) // nolint: errcheck
+ s.Exit(1) // nolint: errcheck
}
@@ -0,0 +1,40 @@
+package sshutils
+
+import (
+ "bytes"
+ "context"
+
+ "github.com/charmbracelet/ssh"
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// ParseAuthorizedKey parses an authorized key string into a public key.
+func ParseAuthorizedKey(ak string) (gossh.PublicKey, string, error) {
+ pk, c, _, _, err := gossh.ParseAuthorizedKey([]byte(ak))
+ return pk, c, err
+}
+
+// MarshalAuthorizedKey marshals a public key into an authorized key string.
+//
+// This is the inverse of ParseAuthorizedKey.
+// This function is a copy of ssh.MarshalAuthorizedKey, but without the trailing newline.
+// It returns an empty string if pk is nil.
+func MarshalAuthorizedKey(pk gossh.PublicKey) string {
+ if pk == nil {
+ return ""
+ }
+ return string(bytes.TrimSuffix(gossh.MarshalAuthorizedKey(pk), []byte("\n")))
+}
+
+// KeysEqual returns whether the two public keys are equal.
+func KeysEqual(a, b gossh.PublicKey) bool {
+ return ssh.KeysEqual(a, b)
+}
+
+// PublicKeyFromContext returns the public key from the context.
+func PublicKeyFromContext(ctx context.Context) gossh.PublicKey {
+ if pk, ok := ctx.Value(ssh.ContextKeyPublicKey).(gossh.PublicKey); ok {
+ return pk
+ }
+ return nil
+}
@@ -0,0 +1,124 @@
+package sshutils
+
+import (
+ "testing"
+
+ "github.com/charmbracelet/keygen"
+ "golang.org/x/crypto/ssh"
+)
+
+func generateKeys(tb testing.TB) (*keygen.SSHKeyPair, *keygen.SSHKeyPair) {
+ goodKey1, err := keygen.New("", keygen.WithKeyType(keygen.Ed25519))
+ if err != nil {
+ tb.Fatal(err)
+ }
+ goodKey2, err := keygen.New("", keygen.WithKeyType(keygen.RSA))
+ if err != nil {
+ tb.Fatal(err)
+ }
+
+ return goodKey1, goodKey2
+}
+
+func TestParseAuthorizedKey(t *testing.T) {
+ goodKey1, goodKey2 := generateKeys(t)
+ cases := []struct {
+ in string
+ good bool
+ }{
+ {
+ goodKey1.AuthorizedKey(),
+ true,
+ },
+ {
+ goodKey2.AuthorizedKey(),
+ true,
+ },
+ {
+ goodKey1.AuthorizedKey() + "test",
+ false,
+ },
+ {
+ goodKey2.AuthorizedKey() + "bad",
+ false,
+ },
+ }
+ for _, c := range cases {
+ _, _, err := ParseAuthorizedKey(c.in)
+ if c.good && err != nil {
+ t.Errorf("ParseAuthorizedKey(%q) returned error: %v", c.in, err)
+ }
+ if !c.good && err == nil {
+ t.Errorf("ParseAuthorizedKey(%q) did not return error", c.in)
+ }
+ }
+}
+
+func TestMarshalAuthorizedKey(t *testing.T) {
+ goodKey1, goodKey2 := generateKeys(t)
+ cases := []struct {
+ in ssh.PublicKey
+ expected string
+ }{
+ {
+ goodKey1.PublicKey(),
+ goodKey1.AuthorizedKey(),
+ },
+ {
+ goodKey2.PublicKey(),
+ goodKey2.AuthorizedKey(),
+ },
+ {
+ nil,
+ "",
+ },
+ }
+ for _, c := range cases {
+ out := MarshalAuthorizedKey(c.in)
+ if out != c.expected {
+ t.Errorf("MarshalAuthorizedKey(%v) returned %q, expected %q", c.in, out, c.expected)
+ }
+ }
+}
+
+func TestKeysEqual(t *testing.T) {
+ goodKey1, goodKey2 := generateKeys(t)
+ cases := []struct {
+ in1 ssh.PublicKey
+ in2 ssh.PublicKey
+ expected bool
+ }{
+ {
+ goodKey1.PublicKey(),
+ goodKey1.PublicKey(),
+ true,
+ },
+ {
+ goodKey2.PublicKey(),
+ goodKey2.PublicKey(),
+ true,
+ },
+ {
+ goodKey1.PublicKey(),
+ goodKey2.PublicKey(),
+ false,
+ },
+ {
+ nil,
+ nil,
+ false,
+ },
+ {
+ nil,
+ goodKey1.PublicKey(),
+ false,
+ },
+ }
+
+ for _, c := range cases {
+ out := KeysEqual(c.in1, c.in2)
+ if out != c.expected {
+ t.Errorf("KeysEqual(%v, %v) returned %v, expected %v", c.in1, c.in2, out, c.expected)
+ }
+ }
+}
@@ -10,7 +10,7 @@ import (
)
// StatsServer is a server for collecting and reporting statistics.
-type StatsServer struct {
+type StatsServer struct { //nolint:revive
ctx context.Context
cfg *config.Config
server *http.Server
@@ -0,0 +1,124 @@
+package database
+
+import (
+ "context"
+ "strings"
+
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/models"
+ "github.com/charmbracelet/soft-serve/server/store"
+ "github.com/charmbracelet/soft-serve/server/utils"
+)
+
+type collabStore struct{}
+
+var _ store.CollaboratorStore = (*collabStore)(nil)
+
+// AddCollabByUsernameAndRepo implements store.CollaboratorStore.
+func (*collabStore) AddCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ repo = utils.SanitizeRepo(repo)
+
+ query := tx.Rebind(`INSERT INTO collabs (user_id, repo_id, updated_at)
+ VALUES (
+ (
+ SELECT id FROM users WHERE username = ?
+ ),
+ (
+ SELECT id FROM repos WHERE name = ?
+ ),
+ CURRENT_TIMESTAMP
+ );`)
+ _, err := tx.ExecContext(ctx, query, username, repo)
+ return err
+}
+
+// GetCollabByUsernameAndRepo implements store.CollaboratorStore.
+func (*collabStore) GetCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) (models.Collab, error) {
+ var m models.Collab
+
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return models.Collab{}, err
+ }
+
+ repo = utils.SanitizeRepo(repo)
+
+ err := tx.GetContext(ctx, &m, tx.Rebind(`
+ SELECT
+ collabs.*
+ FROM
+ collabs
+ INNER JOIN users ON users.id = collabs.user_id
+ INNER JOIN repos ON repos.id = collabs.repo_id
+ WHERE
+ users.username = ? AND repos.name = ?
+ `), username, repo)
+
+ return m, err
+}
+
+// ListCollabsByRepo implements store.CollaboratorStore.
+func (*collabStore) ListCollabsByRepo(ctx context.Context, tx *db.Tx, repo string) ([]models.Collab, error) {
+ var m []models.Collab
+
+ repo = utils.SanitizeRepo(repo)
+ query := tx.Rebind(`
+ SELECT
+ collabs.*
+ FROM
+ collabs
+ INNER JOIN repos ON repos.id = collabs.repo_id
+ WHERE
+ repos.name = ?
+ `)
+
+ err := tx.SelectContext(ctx, &m, query, repo)
+ return m, err
+}
+
+// ListCollabsByRepoAsUsers implements store.CollaboratorStore.
+func (*collabStore) ListCollabsByRepoAsUsers(ctx context.Context, tx *db.Tx, repo string) ([]models.User, error) {
+ var m []models.User
+
+ repo = utils.SanitizeRepo(repo)
+ query := tx.Rebind(`
+ SELECT
+ users.*
+ FROM
+ users
+ INNER JOIN repos ON repos.id = collabs.repo_id
+ INNER JOIN collabs ON collabs.user_id = users.id
+ WHERE
+ repos.name = ?
+ `)
+
+ err := tx.SelectContext(ctx, &m, query, repo)
+ return m, err
+}
+
+// RemoveCollabByUsernameAndRepo implements store.CollaboratorStore.
+func (*collabStore) RemoveCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ repo = utils.SanitizeRepo(repo)
+ query := tx.Rebind(`
+ DELETE FROM
+ collabs
+ WHERE
+ user_id = (
+ SELECT id FROM users WHERE username = ?
+ ) AND repo_id = (
+ SELECT id FROM repos WHERE name = ?
+ )
+ `)
+ _, err := tx.ExecContext(ctx, query, username, repo)
+ return err
+}
@@ -0,0 +1,42 @@
+package database
+
+import (
+ "context"
+
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/store"
+)
+
+type datastore struct {
+ ctx context.Context
+ cfg *config.Config
+ db *db.DB
+ logger *log.Logger
+
+ *settingsStore
+ *repoStore
+ *userStore
+ *collabStore
+}
+
+// New returns a new store.Store database.
+func New(ctx context.Context, db *db.DB) store.Store {
+ cfg := config.FromContext(ctx)
+ logger := log.FromContext(ctx).WithPrefix("store")
+
+ s := &datastore{
+ ctx: ctx,
+ cfg: cfg,
+ db: db,
+ logger: logger,
+
+ settingsStore: &settingsStore{},
+ repoStore: &repoStore{},
+ userStore: &userStore{},
+ collabStore: &collabStore{},
+ }
+
+ return s
+}
@@ -0,0 +1,135 @@
+package database
+
+import (
+ "context"
+
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/models"
+ "github.com/charmbracelet/soft-serve/server/store"
+ "github.com/charmbracelet/soft-serve/server/utils"
+)
+
+type repoStore struct{}
+
+var _ store.RepositoryStore = (*repoStore)(nil)
+
+// CreateRepo implements store.RepositoryStore.
+func (*repoStore) CreateRepo(ctx context.Context, tx *db.Tx, name string, projectName string, description string, isPrivate bool, isHidden bool, isMirror bool) error {
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind(`INSERT INTO repos (name, project_name, description, private, mirror, hidden, updated_at)
+ VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP);`)
+ _, err := tx.ExecContext(ctx, query,
+ name, projectName, description, isPrivate, isMirror, isHidden)
+ return db.WrapError(err)
+}
+
+// DeleteRepoByName implements store.RepositoryStore.
+func (*repoStore) DeleteRepoByName(ctx context.Context, tx *db.Tx, name string) error {
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("DELETE FROM repos WHERE name = ?;")
+ _, err := tx.ExecContext(ctx, query, name)
+ return db.WrapError(err)
+}
+
+// GetAllRepos implements store.RepositoryStore.
+func (*repoStore) GetAllRepos(ctx context.Context, tx *db.Tx) ([]models.Repo, error) {
+ var repos []models.Repo
+ query := tx.Rebind("SELECT * FROM repos;")
+ err := tx.SelectContext(ctx, &repos, query)
+ return repos, db.WrapError(err)
+}
+
+// GetRepoByName implements store.RepositoryStore.
+func (*repoStore) GetRepoByName(ctx context.Context, tx *db.Tx, name string) (models.Repo, error) {
+ var repo models.Repo
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("SELECT * FROM repos WHERE name = ?;")
+ err := tx.GetContext(ctx, &repo, query, name)
+ return repo, db.WrapError(err)
+}
+
+// GetRepoDescriptionByName implements store.RepositoryStore.
+func (*repoStore) GetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string) (string, error) {
+ var description string
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("SELECT description FROM repos WHERE name = ?;")
+ err := tx.GetContext(ctx, &description, query, name)
+ return description, db.WrapError(err)
+}
+
+// GetRepoIsHiddenByName implements store.RepositoryStore.
+func (*repoStore) GetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string) (bool, error) {
+ var isHidden bool
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("SELECT hidden FROM repos WHERE name = ?;")
+ err := tx.GetContext(ctx, &isHidden, query, name)
+ return isHidden, db.WrapError(err)
+}
+
+// GetRepoIsMirrorByName implements store.RepositoryStore.
+func (*repoStore) GetRepoIsMirrorByName(ctx context.Context, tx *db.Tx, name string) (bool, error) {
+ var isMirror bool
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("SELECT mirror FROM repos WHERE name = ?;")
+ err := tx.GetContext(ctx, &isMirror, query, name)
+ return isMirror, db.WrapError(err)
+}
+
+// GetRepoIsPrivateByName implements store.RepositoryStore.
+func (*repoStore) GetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string) (bool, error) {
+ var isPrivate bool
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("SELECT private FROM repos WHERE name = ?;")
+ err := tx.GetContext(ctx, &isPrivate, query, name)
+ return isPrivate, db.WrapError(err)
+}
+
+// GetRepoProjectNameByName implements store.RepositoryStore.
+func (*repoStore) GetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string) (string, error) {
+ var pname string
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("SELECT project_name FROM repos WHERE name = ?;")
+ err := tx.GetContext(ctx, &pname, query, name)
+ return pname, db.WrapError(err)
+}
+
+// SetRepoDescriptionByName implements store.RepositoryStore.
+func (*repoStore) SetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string, description string) error {
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("UPDATE repos SET description = ? WHERE name = ?;")
+ _, err := tx.ExecContext(ctx, query, description, name)
+ return db.WrapError(err)
+}
+
+// SetRepoIsHiddenByName implements store.RepositoryStore.
+func (*repoStore) SetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string, isHidden bool) error {
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("UPDATE repos SET hidden = ? WHERE name = ?;")
+ _, err := tx.ExecContext(ctx, query, isHidden, name)
+ return db.WrapError(err)
+}
+
+// SetRepoIsPrivateByName implements store.RepositoryStore.
+func (*repoStore) SetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string, isPrivate bool) error {
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("UPDATE repos SET private = ? WHERE name = ?;")
+ _, err := tx.ExecContext(ctx, query, isPrivate, name)
+ return db.WrapError(err)
+}
+
+// SetRepoNameByName implements store.RepositoryStore.
+func (*repoStore) SetRepoNameByName(ctx context.Context, tx *db.Tx, name string, newName string) error {
+ name = utils.SanitizeRepo(name)
+ newName = utils.SanitizeRepo(newName)
+ query := tx.Rebind("UPDATE repos SET name = ? WHERE name = ?;")
+ _, err := tx.ExecContext(ctx, query, newName, name)
+ return db.WrapError(err)
+}
+
+// SetRepoProjectNameByName implements store.RepositoryStore.
+func (*repoStore) SetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string, projectName string) error {
+ name = utils.SanitizeRepo(name)
+ query := tx.Rebind("UPDATE repos SET project_name = ? WHERE name = ?;")
+ _, err := tx.ExecContext(ctx, query, projectName, name)
+ return db.WrapError(err)
+}
@@ -0,0 +1,47 @@
+package database
+
+import (
+ "context"
+
+ "github.com/charmbracelet/soft-serve/server/access"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/store"
+)
+
+type settingsStore struct{}
+
+var _ store.SettingStore = (*settingsStore)(nil)
+
+// GetAllowKeylessAccess implements store.SettingStore.
+func (*settingsStore) GetAllowKeylessAccess(ctx context.Context, tx *db.Tx) (bool, error) {
+ var allow bool
+ query := tx.Rebind(`SELECT value FROM settings WHERE key = "allow_keyless"`)
+ if err := tx.GetContext(ctx, &allow, query); err != nil {
+ return false, db.WrapError(err)
+ }
+ return allow, nil
+}
+
+// GetAnonAccess implements store.SettingStore.
+func (*settingsStore) GetAnonAccess(ctx context.Context, tx *db.Tx) (access.AccessLevel, error) {
+ var level string
+ query := tx.Rebind(`SELECT value FROM settings WHERE key = "anon_access"`)
+ if err := tx.GetContext(ctx, &level, query); err != nil {
+ return access.NoAccess, db.WrapError(err)
+ }
+ return access.ParseAccessLevel(level), nil
+}
+
+// SetAllowKeylessAccess implements store.SettingStore.
+func (*settingsStore) SetAllowKeylessAccess(ctx context.Context, tx *db.Tx, allow bool) error {
+ query := tx.Rebind(`UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = "allow_keyless"`)
+ _, err := tx.ExecContext(ctx, query, allow)
+ return db.WrapError(err)
+}
+
+// SetAnonAccess implements store.SettingStore.
+func (*settingsStore) SetAnonAccess(ctx context.Context, tx *db.Tx, level access.AccessLevel) error {
+ query := tx.Rebind(`UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = "anon_access"`)
+ _, err := tx.ExecContext(ctx, query, level.String())
+ return db.WrapError(err)
+}
@@ -0,0 +1,208 @@
+package database
+
+import (
+ "context"
+ "strings"
+
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/models"
+ "github.com/charmbracelet/soft-serve/server/sshutils"
+ "github.com/charmbracelet/soft-serve/server/store"
+ "github.com/charmbracelet/soft-serve/server/utils"
+ "golang.org/x/crypto/ssh"
+)
+
+type userStore struct{}
+
+var _ store.UserStore = (*userStore)(nil)
+
+// AddPublicKeyByUsername implements store.UserStore.
+func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ var userID int64
+ if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil {
+ return err
+ }
+
+ query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
+ VALUES (?, ?, CURRENT_TIMESTAMP);`)
+ ak := sshutils.MarshalAuthorizedKey(pk)
+ _, err := tx.ExecContext(ctx, query, userID, ak)
+
+ return err
+}
+
+// CreateUser implements store.UserStore.
+func (*userStore) CreateUser(ctx context.Context, tx *db.Tx, username string, isAdmin bool, pks []ssh.PublicKey) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ query := tx.Rebind(`INSERT INTO users (username, admin, updated_at)
+ VALUES (?, ?, CURRENT_TIMESTAMP);`)
+ result, err := tx.ExecContext(ctx, query, username, isAdmin)
+ if err != nil {
+ return err
+ }
+
+ userID, err := result.LastInsertId()
+ if err != nil {
+ return err
+ }
+
+ for _, pk := range pks {
+ query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
+ VALUES (?, ?, CURRENT_TIMESTAMP);`)
+ ak := sshutils.MarshalAuthorizedKey(pk)
+ _, err := tx.ExecContext(ctx, query, userID, ak)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// DeleteUserByUsername implements store.UserStore.
+func (*userStore) DeleteUserByUsername(ctx context.Context, tx *db.Tx, username string) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ query := tx.Rebind(`DELETE FROM users WHERE username = ?;`)
+ _, err := tx.ExecContext(ctx, query, username)
+ return err
+}
+
+// FindUserByPublicKey implements store.UserStore.
+func (*userStore) FindUserByPublicKey(ctx context.Context, tx *db.Tx, pk ssh.PublicKey) (models.User, error) {
+ var m models.User
+ query := tx.Rebind(`SELECT users.*
+ FROM users
+ INNER JOIN public_keys ON users.id = public_keys.user_id
+ WHERE public_keys.public_key = ?;`)
+ err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
+ return m, err
+}
+
+// FindUserByUsername implements store.UserStore.
+func (*userStore) FindUserByUsername(ctx context.Context, tx *db.Tx, username string) (models.User, error) {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return models.User{}, err
+ }
+
+ var m models.User
+ query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`)
+ err := tx.GetContext(ctx, &m, query, username)
+ return m, err
+}
+
+// GetAllUsers implements store.UserStore.
+func (*userStore) GetAllUsers(ctx context.Context, tx *db.Tx) ([]models.User, error) {
+ var ms []models.User
+ query := tx.Rebind(`SELECT * FROM users;`)
+ err := tx.SelectContext(ctx, &ms, query)
+ return ms, err
+}
+
+// ListPublicKeysByUserID implements store.UserStore..
+func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx *db.Tx, id int64) ([]ssh.PublicKey, error) {
+ var aks []string
+ query := tx.Rebind(`SELECT public_key FROM public_keys
+ WHERE user_id = ?
+ ORDER BY public_keys.id ASC;`)
+ err := tx.SelectContext(ctx, &aks, query, id)
+ if err != nil {
+ return nil, err
+ }
+
+ pks := make([]ssh.PublicKey, len(aks))
+ for i, ak := range aks {
+ pk, _, err := sshutils.ParseAuthorizedKey(ak)
+ if err != nil {
+ return nil, err
+ }
+ pks[i] = pk
+ }
+
+ return pks, nil
+}
+
+// ListPublicKeysByUsername implements store.UserStore.
+func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx *db.Tx, username string) ([]ssh.PublicKey, error) {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return nil, err
+ }
+
+ var aks []string
+ query := tx.Rebind(`SELECT public_key FROM public_keys
+ INNER JOIN users ON users.id = public_keys.user_id
+ WHERE users.username = ?
+ ORDER BY public_keys.id ASC;`)
+ err := tx.SelectContext(ctx, &aks, query, username)
+ if err != nil {
+ return nil, err
+ }
+
+ pks := make([]ssh.PublicKey, len(aks))
+ for i, ak := range aks {
+ pk, _, err := sshutils.ParseAuthorizedKey(ak)
+ if err != nil {
+ return nil, err
+ }
+ pks[i] = pk
+ }
+
+ return pks, nil
+}
+
+// RemovePublicKeyByUsername implements store.UserStore.
+func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ query := tx.Rebind(`DELETE FROM public_keys
+ WHERE user_id = (SELECT id FROM users WHERE username = ?)
+ AND public_key = ?;`)
+ _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
+ return err
+}
+
+// SetAdminByUsername implements store.UserStore.
+func (*userStore) SetAdminByUsername(ctx context.Context, tx *db.Tx, username string, isAdmin bool) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`)
+ _, err := tx.ExecContext(ctx, query, isAdmin, username)
+ return err
+}
+
+// SetUsernameByUsername implements store.UserStore.
+func (*userStore) SetUsernameByUsername(ctx context.Context, tx *db.Tx, username string, newUsername string) error {
+ username = strings.ToLower(username)
+ if err := utils.ValidateUsername(username); err != nil {
+ return err
+ }
+
+ newUsername = strings.ToLower(newUsername)
+ if err := utils.ValidateUsername(newUsername); err != nil {
+ return err
+ }
+
+ query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`)
+ _, err := tx.ExecContext(ctx, query, newUsername, username)
+ return err
+}
@@ -0,0 +1,69 @@
+package store
+
+import (
+ "context"
+
+ "github.com/charmbracelet/soft-serve/server/access"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/models"
+ "golang.org/x/crypto/ssh"
+)
+
+// SettingStore is an interface for managing settings.
+type SettingStore interface {
+ GetAnonAccess(ctx context.Context, tx *db.Tx) (access.AccessLevel, error)
+ SetAnonAccess(ctx context.Context, tx *db.Tx, level access.AccessLevel) error
+ GetAllowKeylessAccess(ctx context.Context, tx *db.Tx) (bool, error)
+ SetAllowKeylessAccess(ctx context.Context, tx *db.Tx, allow bool) error
+}
+
+// RepositoryStore is an interface for managing repositories.
+type RepositoryStore interface {
+ GetRepoByName(ctx context.Context, tx *db.Tx, name string) (models.Repo, error)
+ GetAllRepos(ctx context.Context, tx *db.Tx) ([]models.Repo, error)
+ CreateRepo(ctx context.Context, tx *db.Tx, name string, projectName string, description string, isPrivate bool, isHidden bool, isMirror bool) error
+ DeleteRepoByName(ctx context.Context, tx *db.Tx, name string) error
+ SetRepoNameByName(ctx context.Context, tx *db.Tx, name string, newName string) error
+
+ GetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string) (string, error)
+ SetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string, projectName string) error
+ GetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string) (string, error)
+ SetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string, description string) error
+ GetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string) (bool, error)
+ SetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string, isPrivate bool) error
+ GetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string) (bool, error)
+ SetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string, isHidden bool) error
+ GetRepoIsMirrorByName(ctx context.Context, tx *db.Tx, name string) (bool, error)
+}
+
+// UserStore is an interface for managing users.
+type UserStore interface {
+ FindUserByUsername(ctx context.Context, tx *db.Tx, username string) (models.User, error)
+ FindUserByPublicKey(ctx context.Context, tx *db.Tx, pk ssh.PublicKey) (models.User, error)
+ GetAllUsers(ctx context.Context, tx *db.Tx) ([]models.User, error)
+ CreateUser(ctx context.Context, tx *db.Tx, username string, isAdmin bool, pks []ssh.PublicKey) error
+ DeleteUserByUsername(ctx context.Context, tx *db.Tx, username string) error
+ SetUsernameByUsername(ctx context.Context, tx *db.Tx, username string, newUsername string) error
+ SetAdminByUsername(ctx context.Context, tx *db.Tx, username string, isAdmin bool) error
+ AddPublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error
+ RemovePublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error
+ ListPublicKeysByUserID(ctx context.Context, tx *db.Tx, id int64) ([]ssh.PublicKey, error)
+ ListPublicKeysByUsername(ctx context.Context, tx *db.Tx, username string) ([]ssh.PublicKey, error)
+}
+
+// CollaboratorStore is an interface for managing collaborators.
+type CollaboratorStore interface {
+ GetCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) (models.Collab, error)
+ AddCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error
+ RemoveCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error
+ ListCollabsByRepo(ctx context.Context, tx *db.Tx, repo string) ([]models.Collab, error)
+ ListCollabsByRepoAsUsers(ctx context.Context, tx *db.Tx, repo string) ([]models.User, error)
+}
+
+// Store is an interface for managing repositories, users, and settings.
+type Store interface {
+ RepositoryStore
+ UserStore
+ CollaboratorStore
+ SettingStore
+}
@@ -0,0 +1,35 @@
+package sync
+
+import (
+ "context"
+ "strconv"
+ "sync"
+ "testing"
+)
+
+func TestWorkPool(t *testing.T) {
+ mtx := &sync.Mutex{}
+ values := make([]int, 0)
+ wp := NewWorkPool(context.Background(), 3)
+ for i := 0; i < 10; i++ {
+ id := strconv.Itoa(i)
+ i := i
+ wp.Add(id, func() {
+ mtx.Lock()
+ values = append(values, i)
+ mtx.Unlock()
+ })
+ }
+ wp.Run()
+
+ if len(values) != 10 {
+ t.Errorf("expected 10 values, got %d, %v", len(values), values)
+ }
+
+ for i := range values {
+ id := strconv.Itoa(i)
+ if wp.Status(id) {
+ t.Errorf("expected %s to be false", id)
+ }
+ }
+}
@@ -5,6 +5,7 @@ import (
"github.com/charmbracelet/log"
"github.com/charmbracelet/soft-serve/git"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/ui/keymap"
"github.com/charmbracelet/soft-serve/server/ui/styles"
@@ -62,13 +63,19 @@ func (c *Common) SetSize(width, height int) {
c.Height = height
}
+// Context returns the context.
+func (c *Common) Context() context.Context {
+ return c.ctx
+}
+
// Config returns the server config.
func (c *Common) Config() *config.Config {
- v := c.ctx.Value(ConfigKey)
- if cfg, ok := v.(*config.Config); ok {
- return cfg
- }
- return nil
+ return config.FromContext(c.ctx)
+}
+
+// Backend returns the Soft Serve backend.
+func (c *Common) Backend() *backend.Backend {
+ return backend.FromContext(c.ctx)
}
// Repo returns the repository.
@@ -26,9 +26,8 @@ func RepoURL(publicURL, name string) string {
port := url.Port()
if port == "" || port == "22" {
return fmt.Sprintf("git@%s:%s", url.Hostname(), name)
- } else {
- return fmt.Sprintf("ssh://%s:%s/%s", url.Hostname(), url.Port(), name)
}
+ return fmt.Sprintf("ssh://%s:%s/%s", url.Hostname(), url.Port(), name)
}
}
@@ -187,11 +187,9 @@ func (r *Code) renderFile(path, content string, width int) (string, error) {
// width depends on the terminal. This is a workaround to replace tabs with
// 4-spaces.
content = strings.ReplaceAll(content, "\t", strings.Repeat(" ", tabWidth))
- lexer := lexers.Fallback
+ lexer := lexers.Match(path)
if path == "" {
lexer = lexers.Analyse(content)
- } else {
- lexer = lexers.Match(path)
}
lang := ""
if lexer != nil && lexer.Config() != nil {
@@ -47,7 +47,7 @@ func (f *Footer) Init() tea.Cmd {
}
// Update implements tea.Model.
-func (f *Footer) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+func (f *Footer) Update(_ tea.Msg) (tea.Model, tea.Cmd) {
return f, nil
}
@@ -32,7 +32,7 @@ func (h *Header) Init() tea.Cmd {
}
// Update implements tea.Model.
-func (h *Header) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+func (h *Header) Update(_ tea.Msg) (tea.Model, tea.Cmd) {
return h, nil
}
@@ -8,7 +8,7 @@ import (
)
// StatusBarMsg is a message sent to the status bar.
-type StatusBarMsg struct {
+type StatusBarMsg struct { //nolint:revive
Key string
Value string
Info string
@@ -10,7 +10,7 @@ import (
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/soft-serve/git"
- "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/charmbracelet/soft-serve/server/ui/components/code"
"github.com/charmbracelet/soft-serve/server/ui/components/selector"
@@ -26,7 +26,6 @@ const (
var (
errNoFileSelected = errors.New("no file selected")
errBinaryFile = errors.New("binary file")
- errFileTooLarge = errors.New("file is too large")
errInvalidFile = errors.New("invalid file")
)
@@ -52,7 +51,7 @@ type Files struct {
selector *selector.Selector
ref *git.Reference
activeView filesView
- repo backend.Repository
+ repo proto.Repository
code *code.Code
path string
currentItem *FileItem
@@ -11,7 +11,7 @@ import (
gansi "github.com/charmbracelet/glamour/ansi"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/soft-serve/git"
- "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/charmbracelet/soft-serve/server/ui/components/footer"
"github.com/charmbracelet/soft-serve/server/ui/components/selector"
@@ -47,7 +47,7 @@ type Log struct {
selector *selector.Selector
vp *viewport.Viewport
activeView logView
- repo backend.Repository
+ repo proto.Repository
ref *git.Reference
count int64
nextPage int
@@ -25,6 +25,7 @@ func (i LogItem) ID() string {
return i.Hash()
}
+// Hash returns the commit hash.
func (i LogItem) Hash() string {
return i.Commit.ID.String()
}
@@ -7,6 +7,7 @@ import (
"github.com/charmbracelet/bubbles/key"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/charmbracelet/soft-serve/server/ui/components/code"
)
@@ -21,7 +22,7 @@ type Readme struct {
common common.Common
code *code.Code
ref RefMsg
- repo backend.Repository
+ repo proto.Repository
readmePath string
}
@@ -1,7 +1,6 @@
package repo
import (
- "errors"
"fmt"
"sort"
"strings"
@@ -10,16 +9,12 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/soft-serve/git"
ggit "github.com/charmbracelet/soft-serve/git"
- "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/charmbracelet/soft-serve/server/ui/components/selector"
"github.com/charmbracelet/soft-serve/server/ui/components/tabs"
)
-var (
- errNoRef = errors.New("no reference specified")
-)
-
// RefMsg is a message that contains a git.Reference.
type RefMsg *ggit.Reference
@@ -33,7 +28,7 @@ type RefItemsMsg struct {
type Refs struct {
common common.Common
selector *selector.Selector
- repo backend.Repository
+ repo proto.Repository
ref *git.Reference
activeRef *git.Reference
refPrefix string
@@ -216,7 +211,7 @@ func switchRefCmd(ref *ggit.Reference) tea.Cmd {
}
// UpdateRefCmd gets the repository's HEAD reference and sends a RefMsg.
-func UpdateRefCmd(repo backend.Repository) tea.Cmd {
+func UpdateRefCmd(repo proto.Repository) tea.Cmd {
return func() tea.Msg {
r, err := repo.Open()
if err != nil {
@@ -9,7 +9,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/soft-serve/git"
- "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/charmbracelet/soft-serve/server/ui/components/footer"
"github.com/charmbracelet/soft-serve/server/ui/components/statusbar"
@@ -54,7 +54,7 @@ type CopyURLMsg struct{}
type UpdateStatusBarMsg struct{}
// RepoMsg is a message that contains a git.Repository.
-type RepoMsg backend.Repository
+type RepoMsg proto.Repository // nolint:revive
// BackMsg is a message to go back to the previous view.
type BackMsg struct{}
@@ -68,7 +68,7 @@ type CopyMsg struct {
// Repo is a view for a git repository.
type Repo struct {
common common.Common
- selectedRepo backend.Repository
+ selectedRepo proto.Repository
activeTab tab
tabs *tabs.Tabs
statusbar *statusbar.StatusBar
@@ -11,8 +11,8 @@ import (
"github.com/charmbracelet/bubbles/list"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
- "github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/dustin/go-humanize"
)
@@ -48,13 +48,13 @@ func (it Items) Swap(i int, j int) {
// Item represents a single item in the selector.
type Item struct {
- repo backend.Repository
+ repo proto.Repository
lastUpdate *time.Time
cmd string
}
// New creates a new Item.
-func NewItem(repo backend.Repository, cfg *config.Config) (Item, error) {
+func NewItem(repo proto.Repository, cfg *config.Config) (Item, error) {
var lastUpdate *time.Time
lu := repo.UpdatedAt()
if !lu.IsZero() {
@@ -8,6 +8,7 @@ import (
"github.com/charmbracelet/bubbles/list"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
+ "github.com/charmbracelet/soft-serve/server/access"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/charmbracelet/soft-serve/server/ui/components/code"
@@ -36,12 +37,11 @@ func (p pane) String() string {
// Selection is the model for the selection screen/page.
type Selection struct {
- common common.Common
- readme *code.Code
- readmeHeight int
- selector *selector.Selector
- activePane pane
- tabs *tabs.Tabs
+ common common.Common
+ readme *code.Code
+ selector *selector.Selector
+ activePane pane
+ tabs *tabs.Tabs
}
// New creates a new selection model.
@@ -187,12 +187,14 @@ func (s *Selection) Init() tea.Cmd {
return nil
}
+ ctx := s.common.Context()
+ be := s.common.Backend()
pk := s.common.PublicKey()
- if pk == nil && !cfg.Backend.AllowKeyless() {
+ if pk == nil && !be.AllowKeyless(ctx) {
return nil
}
- repos, err := cfg.Backend.Repositories()
+ repos, err := be.Repositories(ctx)
if err != nil {
return common.ErrorCmd(err)
}
@@ -210,8 +212,8 @@ func (s *Selection) Init() tea.Cmd {
if r.IsHidden() {
continue
}
- al := cfg.Backend.AccessLevelByPublicKey(r.Name(), pk)
- if al >= backend.ReadOnlyAccess {
+ al := be.AccessLevelByPublicKey(ctx, r.Name(), pk)
+ if al >= access.ReadOnlyAccess {
item, err := NewItem(r, cfg)
if err != nil {
s.common.Logger.Debugf("ui: failed to create item for %s: %v", r.Name(), err)
@@ -7,7 +7,7 @@ import (
"github.com/charmbracelet/bubbles/list"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
- "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/ui/common"
"github.com/charmbracelet/soft-serve/server/ui/components/footer"
"github.com/charmbracelet/soft-serve/server/ui/components/header"
@@ -283,12 +283,15 @@ func (ui *UI) View() string {
)
}
-func (ui *UI) openRepo(rn string) (backend.Repository, error) {
+func (ui *UI) openRepo(rn string) (proto.Repository, error) {
cfg := ui.common.Config()
if cfg == nil {
return nil, errors.New("config is nil")
}
- repos, err := cfg.Backend.Repositories()
+
+ ctx := ui.common.Context()
+ be := ui.common.Backend()
+ repos, err := be.Repositories(ctx)
if err != nil {
ui.common.Logger.Debugf("ui: failed to list repos: %v", err)
return nil, err
@@ -0,0 +1,28 @@
+package web
+
+import (
+ "context"
+ "net/http"
+
+ "github.com/charmbracelet/log"
+ "github.com/charmbracelet/soft-serve/server/backend"
+ "github.com/charmbracelet/soft-serve/server/config"
+)
+
+// NewContextMiddleware returns a new context middleware.
+// This middleware adds the config, backend, and logger to the request context.
+func NewContextMiddleware(ctx context.Context) func(http.Handler) http.Handler {
+ cfg := config.FromContext(ctx)
+ be := backend.FromContext(ctx)
+ logger := log.FromContext(ctx).WithPrefix("http")
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ ctx = config.WithContext(ctx, cfg)
+ ctx = backend.WithContext(ctx, be)
+ ctx = log.WithContext(ctx, logger)
+ r = r.WithContext(ctx)
+ next.ServeHTTP(w, r)
+ })
+ }
+}
@@ -4,6 +4,7 @@ import (
"bytes"
"compress/gzip"
"context"
+ "errors"
"fmt"
"io"
"net/http"
@@ -15,9 +16,11 @@ import (
"github.com/charmbracelet/log"
gitb "github.com/charmbracelet/soft-serve/git"
+ "github.com/charmbracelet/soft-serve/server/access"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/git"
+ "github.com/charmbracelet/soft-serve/server/proto"
"github.com/charmbracelet/soft-serve/server/utils"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
@@ -30,22 +33,15 @@ type GitRoute struct {
method string
pattern *regexp.Regexp
handler http.HandlerFunc
-
- cfg *config.Config
- be backend.Backend
- logger *log.Logger
}
var _ Route = GitRoute{}
// Match implements goji.Pattern.
func (g GitRoute) Match(r *http.Request) *http.Request {
- if g.method != r.Method {
- return nil
- }
-
re := g.pattern
ctx := r.Context()
+ cfg := config.FromContext(ctx)
if m := re.FindStringSubmatch(r.URL.Path); m != nil {
file := strings.Replace(r.URL.Path, m[1]+"/", "", 1)
repo := utils.SanitizeRepo(m[1]) + ".git"
@@ -59,22 +55,10 @@ func (g GitRoute) Match(r *http.Request) *http.Request {
}
ctx = context.WithValue(ctx, pattern.Variable("service"), service.String())
- ctx = context.WithValue(ctx, pattern.Variable("dir"), filepath.Join(g.cfg.DataPath, "repos", repo))
+ ctx = context.WithValue(ctx, pattern.Variable("dir"), filepath.Join(cfg.DataPath, "repos", repo))
ctx = context.WithValue(ctx, pattern.Variable("repo"), repo)
ctx = context.WithValue(ctx, pattern.Variable("file"), file)
- if g.cfg != nil {
- ctx = config.WithContext(ctx, g.cfg)
- }
-
- if g.be != nil {
- ctx = backend.WithContext(ctx, g.be.WithContext(ctx))
- }
-
- if g.logger != nil {
- ctx = log.WithContext(ctx, g.logger)
- }
-
return r.WithContext(ctx)
}
@@ -83,10 +67,16 @@ func (g GitRoute) Match(r *http.Request) *http.Request {
// ServeHTTP implements http.Handler.
func (g GitRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if r.Method != g.method {
+ renderMethodNotAllowed(w, r)
+ return
+ }
+
g.handler(w, r)
}
var (
+ //nolint:revive
gitHttpReceiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "soft_serve",
Subsystem: "http",
@@ -94,6 +84,7 @@ var (
Help: "The total number of git push requests",
}, []string{"repo"})
+ //nolint:revive
gitHttpUploadCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "soft_serve",
Subsystem: "http",
@@ -102,10 +93,8 @@ var (
}, []string{"repo", "file"})
)
-func gitRoutes(ctx context.Context, logger *log.Logger) []Route {
+func gitRoutes() []Route {
routes := make([]Route, 0)
- cfg := config.FromContext(ctx)
- be := backend.FromContext(ctx)
// Git services
// These routes don't handle authentication/authorization.
@@ -159,19 +148,16 @@ func gitRoutes(ctx context.Context, logger *log.Logger) []Route {
handler: getLooseObject,
},
{
- pattern: regexp.MustCompile("(.*?)/objects/pack/pack-[0-9a-f]{40}\\.pack$"),
+ pattern: regexp.MustCompile(`(.*?)/objects/pack/pack-[0-9a-f]{40}\.pack$`),
method: http.MethodGet,
handler: getPackFile,
},
{
- pattern: regexp.MustCompile("(.*?)/objects/pack/pack-[0-9a-f]{40}\\.idx$"),
+ pattern: regexp.MustCompile(`(.*?)/objects/pack/pack-[0-9a-f]{40}\.idx$`),
method: http.MethodGet,
handler: getIdxFile,
},
} {
- route.cfg = cfg
- route.be = be
- route.logger = logger
route.handler = withAccess(route.handler)
routes = append(routes, route)
}
@@ -186,32 +172,32 @@ func withAccess(fn http.HandlerFunc) http.HandlerFunc {
be := backend.FromContext(ctx)
logger := log.FromContext(ctx)
- if !be.AllowKeyless() {
+ if !be.AllowKeyless(ctx) {
renderForbidden(w)
return
}
repo := pat.Param(r, "repo")
service := git.Service(pat.Param(r, "service"))
- access := be.AccessLevel(repo, "")
+ accessLevel := be.AccessLevel(ctx, repo, "")
switch service {
case git.ReceivePackService:
- if access < backend.ReadWriteAccess {
+ if accessLevel < access.ReadWriteAccess {
renderUnauthorized(w)
return
}
// Create the repo if it doesn't exist.
- if _, err := be.Repository(repo); err != nil {
- if _, err := be.CreateRepository(repo, backend.RepositoryOptions{}); err != nil {
+ if _, err := be.Repository(ctx, repo); err != nil {
+ if _, err := be.CreateRepository(ctx, repo, proto.RepositoryOptions{}); err != nil {
logger.Error("failed to create repository", "repo", repo, "err", err)
renderInternalServerError(w)
return
}
}
default:
- if access < backend.ReadOnlyAccess {
+ if accessLevel < access.ReadOnlyAccess {
renderUnauthorized(w)
return
}
@@ -221,8 +207,10 @@ func withAccess(fn http.HandlerFunc) http.HandlerFunc {
}
}
+//nolint:revive
func serviceRpc(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
+ cfg := config.FromContext(ctx)
logger := log.FromContext(ctx)
service, dir, repo := git.Service(pat.Param(r, "service")), pat.Param(r, "dir"), pat.Param(r, "repo")
@@ -243,71 +231,77 @@ func serviceRpc(w http.ResponseWriter, r *http.Request) {
version := r.Header.Get("Git-Protocol")
+ var stdout bytes.Buffer
cmd := git.ServiceCommand{
- Stdin: r.Body,
- Stdout: w,
+ Stdout: &stdout,
Dir: dir,
Args: []string{"--stateless-rpc"},
}
if len(version) != 0 {
- cmd.Env = append(cmd.Env, fmt.Sprintf("GIT_PROTOCOL=%s", version))
+ cmd.Env = append(cmd.Env, []string{
+ // TODO: add the rest of env vars when we support pushing using http
+ "SOFT_SERVE_LOG_PATH=" + filepath.Join(cfg.DataPath, "log", "hooks.log"),
+ fmt.Sprintf("GIT_PROTOCOL=%s", version),
+ }...)
}
// Handle gzip encoding
- cmd.StdinHandler = func(in io.Reader, stdin io.WriteCloser) (err error) {
- // We know that `in` is an `io.ReadCloser` because it's `r.Body`.
- reader := in.(io.ReadCloser)
- defer reader.Close() // nolint: errcheck
- switch r.Header.Get("Content-Encoding") {
- case "gzip":
- reader, err = gzip.NewReader(reader)
- if err != nil {
- return err
- }
- defer reader.Close() // nolint: errcheck
+ reader := r.Body
+ defer reader.Close() // nolint: errcheck
+ switch r.Header.Get("Content-Encoding") {
+ case "gzip":
+ reader, err := gzip.NewReader(reader)
+ if err != nil {
+ logger.Errorf("failed to create gzip reader: %v", err)
+ renderInternalServerError(w)
+ return
}
+ defer reader.Close() // nolint: errcheck
+ }
+
+ cmd.Stdin = reader
- _, err = io.Copy(stdin, reader)
- return err
+ if err := service.Handler(ctx, cmd); err != nil {
+ if errors.Is(err, git.ErrInvalidRepo) {
+ renderNotFound(w)
+ return
+ }
+ renderInternalServerError(w)
+ return
}
// Handle buffered output
// Useful when using proxies
- cmd.StdoutHandler = func(out io.Writer, stdout io.ReadCloser) error {
- // We know that `out` is an `http.ResponseWriter`.
- flusher, ok := out.(http.Flusher)
- if !ok {
- return fmt.Errorf("expected http.ResponseWriter to be an http.Flusher, got %T", out)
- }
-
- p := make([]byte, 1024)
- for {
- nRead, err := stdout.Read(p)
- if err == io.EOF {
- break
- }
- nWrite, err := out.Write(p[:nRead])
- if err != nil {
- return err
- }
- if nRead != nWrite {
- return fmt.Errorf("failed to write data: %d read, %d written", nRead, nWrite)
- }
- flusher.Flush()
- }
- return nil
+ // We know that `w` is an `http.ResponseWriter`.
+ flusher, ok := w.(http.Flusher)
+ if !ok {
+ logger.Errorf("expected http.ResponseWriter to be an http.Flusher, got %T", w)
+ return
}
- if err := service.Handler(ctx, cmd); err != nil {
- logger.Errorf("error executing service: %s", err)
+ p := make([]byte, 1024)
+ for {
+ nRead, err := stdout.Read(p)
+ if err == io.EOF {
+ break
+ }
+ nWrite, err := w.Write(p[:nRead])
+ if err != nil {
+ logger.Errorf("failed to write data: %v", err)
+ return
+ }
+ if nRead != nWrite {
+ logger.Errorf("failed to write data: %d read, %d written", nRead, nWrite)
+ return
+ }
+ flusher.Flush()
}
}
func getInfoRefs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
- logger := log.FromContext(ctx)
dir, repo, file := pat.Param(r, "dir"), pat.Param(r, "repo"), pat.Param(r, "file")
service := getServiceType(r)
version := r.Header.Get("Git-Protocol")
@@ -328,7 +322,6 @@ func getInfoRefs(w http.ResponseWriter, r *http.Request) {
}
if err := service.Handler(ctx, cmd); err != nil {
- logger.Errorf("error executing service: %s", err)
renderNotFound(w)
return
}
@@ -337,7 +330,7 @@ func getInfoRefs(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", service))
w.WriteHeader(http.StatusOK)
if len(version) == 0 {
- git.WritePktline(w, "# service="+service.String())
+ git.WritePktline(w, "# service="+service.String()) // nolint: errcheck
}
w.Write(refs.Bytes()) // nolint: errcheck
@@ -400,10 +393,7 @@ func getServiceType(r *http.Request) git.Service {
}
func isSmart(r *http.Request, service git.Service) bool {
- if r.Header.Get("Content-Type") == fmt.Sprintf("application/x-%s-request", service) {
- return true
- }
- return false
+ return r.Header.Get("Content-Type") == fmt.Sprintf("application/x-%s-request", service)
}
func updateServerInfo(ctx context.Context, dir string) error {
@@ -34,17 +34,16 @@ Redirecting to docs at <a href="https://godoc.org/{{ .ImportRoot }}/{{ .Repo }}"
</html>`))
// GoGetHandler handles go get requests.
-type GoGetHandler struct {
- cfg *config.Config
- be backend.Backend
-}
+type GoGetHandler struct{}
var _ http.Handler = (*GoGetHandler)(nil)
func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
repo := pattern.Path(r.Context())
repo = utils.SanitizeRepo(repo)
- be := g.be.WithContext(r.Context())
+ ctx := r.Context()
+ cfg := config.FromContext(ctx)
+ be := backend.FromContext(ctx)
// Handle go get requests.
//
@@ -54,7 +53,7 @@ func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// https://go.dev/ref/mod#vcs-branch
if r.URL.Query().Get("go-get") == "1" {
repo := repo
- importRoot, err := url.Parse(g.cfg.HTTP.PublicURL)
+ importRoot, err := url.Parse(cfg.HTTP.PublicURL)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@@ -62,7 +61,7 @@ func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// find the repo
for {
- if _, err := be.Repository(repo); err == nil {
+ if _, err := be.Repository(ctx, repo); err == nil {
break
}
@@ -79,7 +78,7 @@ func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ImportRoot string
}{
Repo: url.PathEscape(repo),
- Config: g.cfg,
+ Config: cfg,
ImportRoot: importRoot.Host,
}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -5,7 +5,6 @@ import (
"net/http"
"time"
- "github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
)
@@ -13,7 +12,6 @@ import (
type HTTPServer struct {
ctx context.Context
cfg *config.Config
- be backend.Backend
server *http.Server
}
@@ -23,7 +21,6 @@ func NewHTTPServer(ctx context.Context) (*HTTPServer, error) {
s := &HTTPServer{
ctx: ctx,
cfg: cfg,
- be: backend.FromContext(ctx),
server: &http.Server{
Addr: cfg.HTTP.ListenAddr,
Handler: NewRouter(ctx),
@@ -24,7 +24,7 @@ var _ http.Flusher = (*logWriter)(nil)
var _ http.Hijacker = (*logWriter)(nil)
-var _ http.CloseNotifier = (*logWriter)(nil)
+var _ http.CloseNotifier = (*logWriter)(nil) // nolint: staticcheck
// Write implements http.ResponseWriter.
func (r *logWriter) Write(p []byte) (int, error) {
@@ -49,7 +49,7 @@ func (r *logWriter) Flush() {
// CloseNotify implements http.CloseNotifier.
func (r *logWriter) CloseNotify() <-chan bool {
- if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok {
+ if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { // nolint: staticcheck
return cn.CloseNotify()
}
return nil
@@ -64,21 +64,20 @@ func (r *logWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
}
// NewLoggingMiddleware returns a new logging middleware.
-func NewLoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- start := time.Now()
- writer := &logWriter{code: http.StatusOK, ResponseWriter: w}
- logger.Debug("request",
- "method", r.Method,
- "uri", r.RequestURI,
- "addr", r.RemoteAddr)
- next.ServeHTTP(writer, r)
- elapsed := time.Since(start)
- logger.Debug("response",
- "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)),
- "bytes", humanize.Bytes(uint64(writer.bytes)),
- "time", elapsed)
- })
- }
+func NewLoggingMiddleware(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ logger := log.FromContext(r.Context())
+ start := time.Now()
+ writer := &logWriter{code: http.StatusOK, ResponseWriter: w}
+ logger.Debug("request",
+ "method", r.Method,
+ "uri", r.RequestURI,
+ "addr", r.RemoteAddr)
+ next.ServeHTTP(writer, r)
+ elapsed := time.Since(start)
+ logger.Debug("response",
+ "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)),
+ "bytes", humanize.Bytes(uint64(writer.bytes)),
+ "time", elapsed)
+ })
}
@@ -1,13 +1,9 @@
-// Package server is the reusable server
package web
import (
"context"
"net/http"
- "github.com/charmbracelet/log"
- "github.com/charmbracelet/soft-serve/server/backend"
- "github.com/charmbracelet/soft-serve/server/config"
"goji.io"
"goji.io/pat"
)
@@ -21,20 +17,18 @@ type Route interface {
// NewRouter returns a new HTTP router.
func NewRouter(ctx context.Context) *goji.Mux {
mux := goji.NewMux()
- cfg := config.FromContext(ctx)
- be := backend.FromContext(ctx)
- logger := log.FromContext(ctx).WithPrefix("http")
// Middlewares
- mux.Use(NewLoggingMiddleware(logger))
+ mux.Use(NewContextMiddleware(ctx))
+ mux.Use(NewLoggingMiddleware)
// Git routes
- for _, service := range gitRoutes(ctx, logger) {
+ for _, service := range gitRoutes() {
mux.Handle(service, service)
}
// go-get handler
- mux.Handle(pat.Get("/*"), GoGetHandler{cfg, be})
+ mux.Handle(pat.Get("/*"), GoGetHandler{})
return mux
}
@@ -15,10 +15,14 @@ import (
"github.com/charmbracelet/keygen"
"github.com/charmbracelet/soft-serve/server"
+ "github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
+ "github.com/charmbracelet/soft-serve/server/db"
+ "github.com/charmbracelet/soft-serve/server/db/migrate"
"github.com/charmbracelet/soft-serve/server/test"
"github.com/rogpeppe/go-internal/testscript"
"golang.org/x/crypto/ssh"
+ _ "modernc.org/sqlite" // sqlite Driver
)
var update = flag.Bool("update", false, "update script files")
@@ -51,42 +55,56 @@ func TestScript(t *testing.T) {
"dos2unix": cmdDos2Unix,
},
Setup: func(e *testscript.Env) error {
+ data := t.TempDir()
+
sshPort := test.RandomPort()
+ sshListen := fmt.Sprintf("localhost:%d", sshPort)
+ gitPort := test.RandomPort()
+ gitListen := fmt.Sprintf("localhost:%d", gitPort)
+ httpPort := test.RandomPort()
+ httpListen := fmt.Sprintf("localhost:%d", httpPort)
+ statsPort := test.RandomPort()
+ statsListen := fmt.Sprintf("localhost:%d", statsPort)
+ serverName := "Test Soft Serve"
+
e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort))
e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey())
e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey())
e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey())
e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts"))
e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config"))
- data := t.TempDir()
- cfg := config.Config{
- Name: "Test Soft Serve",
- DataPath: data,
- InitialAdminKeys: []string{admin1.AuthorizedKey()},
- SSH: config.SSHConfig{
- ListenAddr: fmt.Sprintf("localhost:%d", sshPort),
- PublicURL: fmt.Sprintf("ssh://localhost:%d", sshPort),
- KeyPath: filepath.Join(data, "ssh", "soft_serve_host_ed25519"),
- ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"),
- },
- Git: config.GitConfig{
- ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
- IdleTimeout: 3,
- MaxConnections: 32,
- },
- HTTP: config.HTTPConfig{
- ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
- PublicURL: fmt.Sprintf("http://localhost:%d", test.RandomPort()),
- },
- Stats: config.StatsConfig{
- ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()),
- },
- Log: config.LogConfig{
- Format: "text",
- TimeFormat: time.DateTime,
- },
+
+ cfg := config.DefaultConfig()
+ cfg.DataPath = data
+ cfg.Name = serverName
+ cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()}
+ cfg.SSH.ListenAddr = sshListen
+ cfg.SSH.PublicURL = "ssh://" + sshListen
+ cfg.Git.ListenAddr = gitListen
+ cfg.HTTP.ListenAddr = httpListen
+ cfg.HTTP.PublicURL = "http://" + httpListen
+ cfg.Stats.ListenAddr = statsListen
+ cfg.DB.Driver = "sqlite"
+
+ if err := cfg.Validate(); err != nil {
+ return err
}
- ctx := config.WithContext(context.Background(), &cfg)
+
+ ctx := config.WithContext(context.Background(), cfg)
+
+ // TODO: test postgres
+ dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
+ if err != nil {
+ return fmt.Errorf("open database: %w", err)
+ }
+
+ if err := migrate.Migrate(ctx, dbx); err != nil {
+ return fmt.Errorf("migrate database: %w", err)
+ }
+
+ ctx = db.WithContext(ctx, dbx)
+ be := backend.New(ctx, cfg, dbx)
+ ctx = backend.WithContext(ctx, be)
// prevent race condition in lipgloss...
// this will probably be autofixed when we start using the colors
@@ -106,6 +124,7 @@ func TestScript(t *testing.T) {
}()
e.Defer(func() {
+ defer dbx.Close() // nolint: errcheck
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
@@ -0,0 +1,30 @@
+# vi: set ft=conf
+
+# convert crlf to lf on windows
+[windows] dos2unix commit1.txt
+
+# create a repo
+soft repo import basic1 https://github.com/git-fixtures/basic
+
+# print commit
+soft repo commit basic1 b8e471f58bcbca63b07bda20e428190409c2db47
+cmp stdout commit1.txt
+
+-- commit1.txt --
+commit b8e471f58bcbca63b07bda20e428190409c2db47
+Author: Daniel Ripolles
+Date: Tue Mar 31 11:44:52 UTC 2015
+Creating changelog
+
+
+CHANGELOG | 1 +
+1 file changed, 1 insertion(+)
+
+diff --git a/CHANGELOG b/CHANGELOG
+new file mode 100644
+index 0000000000000000000000000000000000000000..d3ff53e0564a9f87d8e84b6e28e5060e517008aa
+--- /dev/null
++++ b/CHANGELOG
+@@ -0,0 +1 @@
++Initial changelog
+