diff --git a/.golangci-soft.yml b/.golangci-soft.yml index ef456e0605919f959a254d6dbe917b5d26f7ffcd..0500f19d30a2e9355aa138c3225fc615c8c0c2d0 100644 --- a/.golangci-soft.yml +++ b/.golangci-soft.yml @@ -1,5 +1,6 @@ run: tests: false + timeout: 5m issues: include: diff --git a/.golangci.yml b/.golangci.yml index a5a91d0d91348545ae21ac105126432198795697..4ae3fa1bbc7a11b759e838a9e205eba6228274ed 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,5 +1,6 @@ run: tests: false + timeout: 5m issues: include: @@ -27,3 +28,10 @@ linters: - unconvert - unparam - whitespace + +severity: + default-severity: error + rules: + - linters: + - revive + severity: info diff --git a/cmd/soft/admin.go b/cmd/soft/admin.go new file mode 100644 index 0000000000000000000000000000000000000000..16d31f2783026e7ff59fb93b7aa271edfed7dbc8 --- /dev/null +++ b/cmd/soft/admin.go @@ -0,0 +1,75 @@ +package main + +import ( + "fmt" + + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/migrate" + "github.com/spf13/cobra" +) + +var ( + adminCmd = &cobra.Command{ + Use: "admin", + Short: "Administrate the server", + } + + migrateCmd = &cobra.Command{ + Use: "migrate", + Short: "Migrate the database to the latest version", + PersistentPreRunE: initBackendContext, + PersistentPostRunE: closeDBContext, + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + db := db.FromContext(ctx) + if err := migrate.Migrate(ctx, db); err != nil { + return fmt.Errorf("migration: %w", err) + } + + return nil + }, + } + + rollbackCmd = &cobra.Command{ + Use: "rollback", + Short: "Rollback the database to the previous version", + PersistentPreRunE: initBackendContext, + PersistentPostRunE: closeDBContext, + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + db := db.FromContext(ctx) + if err := migrate.Rollback(ctx, db); err != nil { + return fmt.Errorf("rollback: %w", err) + } + + return nil + }, + } + + syncHooksCmd = &cobra.Command{ + Use: "sync-hooks", + Short: "Update repository hooks", + PersistentPreRunE: initBackendContext, + PersistentPostRunE: closeDBContext, + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + cfg := config.FromContext(ctx) + be := backend.FromContext(ctx) + if err := initializeHooks(ctx, cfg, be); err != nil { + return fmt.Errorf("initialize hooks: %w", err) + } + + return nil + }, + } +) + +func init() { + adminCmd.AddCommand( + syncHooksCmd, + migrateCmd, + rollbackCmd, + ) +} diff --git a/cmd/soft/hook.go b/cmd/soft/hook.go index c1f9d4233ab56846043584f052c1195bb235cffa..fdfbaed46d1b441a5f89f8412a90b2e2464c7b0a 100644 --- a/cmd/soft/hook.go +++ b/cmd/soft/hook.go @@ -11,66 +11,33 @@ import ( "path/filepath" "strings" - "github.com/charmbracelet/log" "github.com/charmbracelet/soft-serve/server/backend" - "github.com/charmbracelet/soft-serve/server/backend/sqlite" "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/hooks" "github.com/spf13/cobra" ) var ( + // Deprecated: this flag is ignored. configPath string - logFileCtxKey = struct{}{} - hookCmd = &cobra.Command{ - Use: "hook", - Short: "Run git server hooks", - Long: "Handles Soft Serve git server hooks.", - Hidden: true, - PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { - ctx := cmd.Context() - cfg, err := config.ParseConfig(configPath) - if err != nil { - return fmt.Errorf("could not parse config: %w", err) - } - - ctx = config.WithContext(ctx, cfg) - - logPath := filepath.Join(cfg.DataPath, "log", "hooks.log") - f, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return fmt.Errorf("opening file: %w", err) - } - - ctx = context.WithValue(ctx, logFileCtxKey, f) - logger := log.FromContext(ctx) - logger.SetOutput(f) - ctx = log.WithContext(ctx, logger) - cmd.SetContext(ctx) - - // Set up the backend - // TODO: support other backends - sb, err := sqlite.NewSqliteBackend(ctx) - if err != nil { - return fmt.Errorf("failed to create sqlite backend: %w", err) - } - - cfg = cfg.WithBackend(sb) - - return nil - }, - PersistentPostRunE: func(cmd *cobra.Command, _ []string) error { - f := cmd.Context().Value(logFileCtxKey).(*os.File) - return f.Close() - }, + Use: "hook", + Short: "Run git server hooks", + Long: "Handles Soft Serve git server hooks.", + Hidden: true, + PersistentPreRunE: initBackendContext, + PersistentPostRunE: closeDBContext, } + // Git hooks read the config from the environment, based on + // $SOFT_SERVE_DATA_PATH. We already parse the config when the binary + // starts, so we don't need to do it again. + // The --config flag is now deprecated. hooksRunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() + hks := backend.FromContext(ctx) cfg := config.FromContext(ctx) - hks := cfg.Backend.(backend.Hooks) // This is set in the server before invoking git-receive-pack/git-upload-pack repoName := os.Getenv("SOFT_SERVE_REPO_NAME") @@ -80,10 +47,10 @@ var ( stderr := cmd.ErrOrStderr() cmdName := cmd.Name() - customHookPath := filepath.Join(filepath.Dir(configPath), "hooks", cmdName) + customHookPath := filepath.Join(cfg.DataPath, "hooks", cmdName) var buf bytes.Buffer - opts := make([]backend.HookArg, 0) + opts := make([]hooks.HookArg, 0) switch cmdName { case hooks.PreReceiveHook, hooks.PostReceiveHook: @@ -94,7 +61,7 @@ var ( if len(fields) != 3 { return fmt.Errorf("invalid hook input: %s", scanner.Text()) } - opts = append(opts, backend.HookArg{ + opts = append(opts, hooks.HookArg{ OldSha: fields[0], NewSha: fields[1], RefName: fields[2], @@ -103,22 +70,22 @@ var ( switch cmdName { case hooks.PreReceiveHook: - hks.PreReceive(stdout, stderr, repoName, opts) + hks.PreReceive(ctx, stdout, stderr, repoName, opts) case hooks.PostReceiveHook: - hks.PostReceive(stdout, stderr, repoName, opts) + hks.PostReceive(ctx, stdout, stderr, repoName, opts) } case hooks.UpdateHook: if len(args) != 3 { return fmt.Errorf("invalid update hook input: %s", args) } - hks.Update(stdout, stderr, repoName, backend.HookArg{ + hks.Update(ctx, stdout, stderr, repoName, hooks.HookArg{ OldSha: args[0], NewSha: args[1], RefName: args[2], }) case hooks.PostUpdateHook: - hks.PostUpdate(stdout, stderr, repoName, args...) + hks.PostUpdate(ctx, stdout, stderr, repoName, args...) } // Custom hooks @@ -159,7 +126,7 @@ var ( ) func init() { - hookCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to config file") + hookCmd.PersistentFlags().StringVar(&configPath, "config", "", "path to config file (deprecated)") hookCmd.AddCommand( preReceiveCmd, updateCmd, @@ -211,6 +178,7 @@ fi echo "Hi from Soft Serve update hook!" echo +echo "Repository: $SOFT_SERVE_REPO_NAME" echo "RefName: $refname" echo "Change Type: $newrev_type" echo "Old SHA1: $oldrev" diff --git a/cmd/soft/man.go b/cmd/soft/man.go index 71b1dfb2c3534a3e41d57baf34b5c0fadd09dd59..966fe27d349ae11d407ba618dc6d647218ecc692 100644 --- a/cmd/soft/man.go +++ b/cmd/soft/man.go @@ -8,22 +8,20 @@ import ( "github.com/spf13/cobra" ) -var ( - manCmd = &cobra.Command{ - Use: "man", - Short: "Generate man pages", - Args: cobra.NoArgs, - Hidden: true, - RunE: func(cmd *cobra.Command, args []string) error { - manPage, err := mcobra.NewManPage(1, rootCmd) //. - if err != nil { - return err - } +var manCmd = &cobra.Command{ + Use: "man", + Short: "Generate man pages", + Args: cobra.NoArgs, + Hidden: true, + RunE: func(_ *cobra.Command, _ []string) error { + manPage, err := mcobra.NewManPage(1, rootCmd) //. + if err != nil { + return err + } - manPage = manPage.WithSection("Copyright", "(C) 2021-2023 Charmbracelet, Inc.\n"+ - "Released under MIT license.") - fmt.Println(manPage.Build(roff.NewDocument())) - return nil - }, - } -) + manPage = manPage.WithSection("Copyright", "(C) 2021-2023 Charmbracelet, Inc.\n"+ + "Released under MIT license.") + fmt.Println(manPage.Build(roff.NewDocument())) + return nil + }, +} diff --git a/cmd/soft/migrate_config.go b/cmd/soft/migrate_config.go index 6624da972e92d86a861ac65a64eaaaf45d441ec7..4d85c0bd82accb9ca38600b29b1352ea4027a36f 100644 --- a/cmd/soft/migrate_config.go +++ b/cmd/soft/migrate_config.go @@ -12,9 +12,12 @@ import ( "github.com/charmbracelet/log" "github.com/charmbracelet/soft-serve/git" + "github.com/charmbracelet/soft-serve/server/access" "github.com/charmbracelet/soft-serve/server/backend" - "github.com/charmbracelet/soft-serve/server/backend/sqlite" "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/sshutils" "github.com/charmbracelet/soft-serve/server/utils" gitm "github.com/gogs/git-module" "github.com/spf13/cobra" @@ -22,296 +25,300 @@ import ( "gopkg.in/yaml.v3" ) -var ( - migrateConfig = &cobra.Command{ - Use: "migrate-config", - Short: "Migrate config to new format", - Hidden: true, - RunE: func(cmd *cobra.Command, _ []string) error { - ctx := cmd.Context() - - logger := log.FromContext(ctx) - // Disable logging timestamp - logger.SetReportTimestamp(false) - - keyPath := os.Getenv("SOFT_SERVE_KEY_PATH") - reposPath := os.Getenv("SOFT_SERVE_REPO_PATH") - bindAddr := os.Getenv("SOFT_SERVE_BIND_ADDRESS") - cfg := config.DefaultConfig() - ctx = config.WithContext(ctx, cfg) - sb, err := sqlite.NewSqliteBackend(ctx) - if err != nil { - return fmt.Errorf("failed to create sqlite backend: %w", err) - } +// Deprecated: will be removed in a future release. +var migrateConfig = &cobra.Command{ + Use: "migrate-config", + Short: "Migrate config to new format", + Hidden: true, + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + + logger := log.FromContext(ctx) + // Disable logging timestamp + logger.SetReportTimestamp(false) + + keyPath := os.Getenv("SOFT_SERVE_KEY_PATH") + reposPath := os.Getenv("SOFT_SERVE_REPO_PATH") + bindAddr := os.Getenv("SOFT_SERVE_BIND_ADDRESS") + cfg := config.DefaultConfig() + if err := cfg.ParseEnv(); err != nil { + return fmt.Errorf("parse env: %w", err) + } - // FIXME: Admin user gets created when the database is created. - sb.DeleteUser("admin") // nolint: errcheck + ctx = config.WithContext(ctx, cfg) + db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource) + if err != nil { + return fmt.Errorf("open database: %w", err) + } - cfg = cfg.WithBackend(sb) + defer db.Close() // nolint: errcheck + sb := backend.New(ctx, cfg, db) - // Set SSH listen address - logger.Info("Setting SSH listen address...") - if bindAddr != "" { - cfg.SSH.ListenAddr = bindAddr - } + // FIXME: Admin user gets created when the database is created. + sb.DeleteUser(ctx, "admin") // nolint: errcheck - // Copy SSH host key - logger.Info("Copying SSH host key...") - if keyPath != "" { - if err := os.MkdirAll(filepath.Join(cfg.DataPath, "ssh"), os.ModePerm); err != nil { - return fmt.Errorf("failed to create ssh directory: %w", err) - } + // Set SSH listen address + logger.Info("Setting SSH listen address...") + if bindAddr != "" { + cfg.SSH.ListenAddr = bindAddr + } - if err := copyFile(keyPath, filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath))); err != nil { - return fmt.Errorf("failed to copy ssh key: %w", err) - } + // Copy SSH host key + logger.Info("Copying SSH host key...") + if keyPath != "" { + if err := os.MkdirAll(filepath.Join(cfg.DataPath, "ssh"), os.ModePerm); err != nil { + return fmt.Errorf("failed to create ssh directory: %w", err) + } - if err := copyFile(keyPath+".pub", filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath))+".pub"); err != nil { - logger.Errorf("failed to copy ssh key: %s", err) - } + if err := copyFile(keyPath, filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath))); err != nil { + return fmt.Errorf("failed to copy ssh key: %w", err) + } - cfg.SSH.KeyPath = filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath)) + if err := copyFile(keyPath+".pub", filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath))+".pub"); err != nil { + logger.Errorf("failed to copy ssh key: %s", err) } - // Read config - logger.Info("Reading config repository...") - r, err := git.Open(filepath.Join(reposPath, "config")) + cfg.SSH.KeyPath = filepath.Join(cfg.DataPath, "ssh", filepath.Base(keyPath)) + } + + // Read config + logger.Info("Reading config repository...") + r, err := git.Open(filepath.Join(reposPath, "config")) + if err != nil { + return fmt.Errorf("failed to open config repo: %w", err) + } + + head, err := r.HEAD() + if err != nil { + return fmt.Errorf("failed to get head: %w", err) + } + + tree, err := r.TreePath(head, "") + if err != nil { + return fmt.Errorf("failed to get tree: %w", err) + } + + isJson := false // nolint: revive + te, err := tree.TreeEntry("config.yaml") + if err != nil { + te, err = tree.TreeEntry("config.json") if err != nil { - return fmt.Errorf("failed to open config repo: %w", err) + return fmt.Errorf("failed to get config file: %w", err) } + isJson = true + } - head, err := r.HEAD() - if err != nil { - return fmt.Errorf("failed to get head: %w", err) + cc, err := te.Contents() + if err != nil { + return fmt.Errorf("failed to get config contents: %w", err) + } + + var ocfg Config + if isJson { + if err := json.Unmarshal(cc, &ocfg); err != nil { + return fmt.Errorf("failed to unmarshal config: %w", err) } + } else { + if err := yaml.Unmarshal(cc, &ocfg); err != nil { + return fmt.Errorf("failed to unmarshal config: %w", err) + } + } - tree, err := r.TreePath(head, "") - if err != nil { - return fmt.Errorf("failed to get tree: %w", err) + readme, readmePath, err := git.LatestFile(r, "README*") + hasReadme := err == nil + + // Set server name + cfg.Name = ocfg.Name + + // Set server public url + cfg.SSH.PublicURL = fmt.Sprintf("ssh://%s:%d", ocfg.Host, ocfg.Port) + + // Set server settings + logger.Info("Setting server settings...") + if sb.SetAllowKeyless(ctx, ocfg.AllowKeyless) != nil { + fmt.Fprintf(os.Stderr, "failed to set allow keyless\n") + } + anon := access.ParseAccessLevel(ocfg.AnonAccess) + if anon >= 0 { + if err := sb.SetAnonAccess(ctx, anon); err != nil { + fmt.Fprintf(os.Stderr, "failed to set anon access: %s\n", err) } + } - isJson := false // nolint: revive - te, err := tree.TreeEntry("config.yaml") - if err != nil { - te, err = tree.TreeEntry("config.json") - if err != nil { - return fmt.Errorf("failed to get config file: %w", err) - } - isJson = true + // Copy repos + if reposPath != "" { + logger.Info("Copying repos...") + if err := os.MkdirAll(filepath.Join(cfg.DataPath, "repos"), os.ModePerm); err != nil { + return fmt.Errorf("failed to create repos directory: %w", err) } - cc, err := te.Contents() + dirs, err := os.ReadDir(reposPath) if err != nil { - return fmt.Errorf("failed to get config contents: %w", err) + return fmt.Errorf("failed to read repos directory: %w", err) } - var ocfg Config - if isJson { - if err := json.Unmarshal(cc, &ocfg); err != nil { - return fmt.Errorf("failed to unmarshal config: %w", err) + for _, dir := range dirs { + if !dir.IsDir() || dir.Name() == "config" { + continue + } + + if !isGitDir(filepath.Join(reposPath, dir.Name())) { + continue } - } else { - if err := yaml.Unmarshal(cc, &ocfg); err != nil { - return fmt.Errorf("failed to unmarshal config: %w", err) + + logger.Infof(" Copying repo %s", dir.Name()) + src := filepath.Join(reposPath, utils.SanitizeRepo(dir.Name())) + dst := filepath.Join(cfg.DataPath, "repos", utils.SanitizeRepo(dir.Name())) + ".git" + if err := os.MkdirAll(dst, os.ModePerm); err != nil { + return fmt.Errorf("failed to create repo directory: %w", err) + } + + if err := copyDir(src, dst); err != nil { + return fmt.Errorf("failed to copy repo: %w", err) + } + + if _, err := sb.CreateRepository(ctx, dir.Name(), proto.RepositoryOptions{}); err != nil { + fmt.Fprintf(os.Stderr, "failed to create repository: %s\n", err) } } - readme, readmePath, err := git.LatestFile(r, "README*") - hasReadme := err == nil + if hasReadme { + logger.Infof(" Copying readme from \"config\" to \".soft-serve\"") - // Set server name - cfg.Name = ocfg.Name + // Switch to main branch + bcmd := git.NewCommand("branch", "-M", "main") - // Set server public url - cfg.SSH.PublicURL = fmt.Sprintf("ssh://%s:%d", ocfg.Host, ocfg.Port) + rp := filepath.Join(cfg.DataPath, "repos", ".soft-serve.git") + nr, err := git.Init(rp, true) + if err != nil { + return fmt.Errorf("failed to init repo: %w", err) + } - // Set server settings - logger.Info("Setting server settings...") - if cfg.Backend.SetAllowKeyless(ocfg.AllowKeyless) != nil { - fmt.Fprintf(os.Stderr, "failed to set allow keyless\n") - } - anon := backend.ParseAccessLevel(ocfg.AnonAccess) - if anon >= 0 { - if err := sb.SetAnonAccess(anon); err != nil { - fmt.Fprintf(os.Stderr, "failed to set anon access: %s\n", err) + if _, err := nr.SymbolicRef("HEAD", gitm.RefsHeads+"main"); err != nil { + return fmt.Errorf("failed to set HEAD: %w", err) } - } - // Copy repos - if reposPath != "" { - logger.Info("Copying repos...") - if err := os.MkdirAll(filepath.Join(cfg.DataPath, "repos"), os.ModePerm); err != nil { - return fmt.Errorf("failed to create repos directory: %w", err) + tmpDir, err := os.MkdirTemp("", "soft-serve") + if err != nil { + return fmt.Errorf("failed to create temp dir: %w", err) } - dirs, err := os.ReadDir(reposPath) + r, err := git.Init(tmpDir, false) if err != nil { - return fmt.Errorf("failed to read repos directory: %w", err) + return fmt.Errorf("failed to clone repo: %w", err) } - for _, dir := range dirs { - if !dir.IsDir() || dir.Name() == "config" { - continue - } - - if !isGitDir(filepath.Join(reposPath, dir.Name())) { - continue - } - - logger.Infof(" Copying repo %s", dir.Name()) - src := filepath.Join(reposPath, utils.SanitizeRepo(dir.Name())) - dst := filepath.Join(cfg.DataPath, "repos", utils.SanitizeRepo(dir.Name())) + ".git" - if err := os.MkdirAll(dst, os.ModePerm); err != nil { - return fmt.Errorf("failed to create repo directory: %w", err) - } - - if err := copyDir(src, dst); err != nil { - return fmt.Errorf("failed to copy repo: %w", err) - } - - if _, err := sb.CreateRepository(dir.Name(), backend.RepositoryOptions{}); err != nil { - fmt.Fprintf(os.Stderr, "failed to create repository: %s\n", err) - } + if _, err := bcmd.RunInDir(tmpDir); err != nil { + return fmt.Errorf("failed to create main branch: %w", err) } - if hasReadme { - logger.Infof(" Copying readme from \"config\" to \".soft-serve\"") - - // Switch to main branch - bcmd := git.NewCommand("branch", "-M", "main") - - rp := filepath.Join(cfg.DataPath, "repos", ".soft-serve.git") - nr, err := git.Init(rp, true) - if err != nil { - return fmt.Errorf("failed to init repo: %w", err) - } - - if _, err := nr.SymbolicRef("HEAD", gitm.RefsHeads+"main"); err != nil { - return fmt.Errorf("failed to set HEAD: %w", err) - } - - tmpDir, err := os.MkdirTemp("", "soft-serve") - if err != nil { - return fmt.Errorf("failed to create temp dir: %w", err) - } - - r, err := git.Init(tmpDir, false) - if err != nil { - return fmt.Errorf("failed to clone repo: %w", err) - } - - if _, err := bcmd.RunInDir(tmpDir); err != nil { - return fmt.Errorf("failed to create main branch: %w", err) - } - - if err := os.WriteFile(filepath.Join(tmpDir, readmePath), []byte(readme), 0o644); err != nil { - return fmt.Errorf("failed to write readme: %w", err) - } - - if err := r.Add(gitm.AddOptions{ - All: true, - }); err != nil { - return fmt.Errorf("failed to add readme: %w", err) - } - - if err := r.Commit(&gitm.Signature{ - Name: "Soft Serve", - Email: "vt100@charm.sh", - When: time.Now(), - }, "Add readme"); err != nil { - return fmt.Errorf("failed to commit readme: %w", err) - } - - if err := r.RemoteAdd("origin", "file://"+rp); err != nil { - return fmt.Errorf("failed to add remote: %w", err) - } - - if err := r.Push("origin", "main"); err != nil { - return fmt.Errorf("failed to push readme: %w", err) - } - - // Create `.soft-serve` repository and add readme - if _, err := sb.CreateRepository(".soft-serve", backend.RepositoryOptions{ - ProjectName: "Home", - Description: "Soft Serve home repository", - Hidden: true, - Private: false, - }); err != nil { - fmt.Fprintf(os.Stderr, "failed to create repository: %s\n", err) - } + if err := os.WriteFile(filepath.Join(tmpDir, readmePath), []byte(readme), 0o644); err != nil { // nolint: gosec + return fmt.Errorf("failed to write readme: %w", err) } - } - // Set repos metadata & collabs - logger.Info("Setting repos metadata & collabs...") - for _, r := range ocfg.Repos { - repo, name := r.Repo, r.Name - // Special case for config repo - if repo == "config" { - repo = ".soft-serve" - r.Private = false + if err := r.Add(gitm.AddOptions{ + All: true, + }); err != nil { + return fmt.Errorf("failed to add readme: %w", err) } - if err := sb.SetProjectName(repo, name); err != nil { - logger.Errorf("failed to set repo name to %s: %s", repo, err) + if err := r.Commit(&gitm.Signature{ + Name: "Soft Serve", + Email: "vt100@charm.sh", + When: time.Now(), + }, "Add readme"); err != nil { + return fmt.Errorf("failed to commit readme: %w", err) } - if err := sb.SetDescription(repo, r.Note); err != nil { - logger.Errorf("failed to set repo description to %s: %s", repo, err) + if err := r.RemoteAdd("origin", "file://"+rp); err != nil { + return fmt.Errorf("failed to add remote: %w", err) } - if err := sb.SetPrivate(repo, r.Private); err != nil { - logger.Errorf("failed to set repo private to %s: %s", repo, err) + if err := r.Push("origin", "main"); err != nil { + return fmt.Errorf("failed to push readme: %w", err) } - for _, collab := range r.Collabs { - if err := sb.AddCollaborator(repo, collab); err != nil { - logger.Errorf("failed to add repo collab to %s: %s", repo, err) - } + // Create `.soft-serve` repository and add readme + if _, err := sb.CreateRepository(ctx, ".soft-serve", proto.RepositoryOptions{ + ProjectName: "Home", + Description: "Soft Serve home repository", + Hidden: true, + Private: false, + }); err != nil { + fmt.Fprintf(os.Stderr, "failed to create repository: %s\n", err) } } + } - // Create users & collabs - logger.Info("Creating users & collabs...") - for _, user := range ocfg.Users { - keys := make(map[string]ssh.PublicKey) - for _, key := range user.PublicKeys { - pk, _, err := backend.ParseAuthorizedKey(key) - if err != nil { - continue - } - ak := backend.MarshalAuthorizedKey(pk) - keys[ak] = pk - } + // Set repos metadata & collabs + logger.Info("Setting repos metadata & collabs...") + for _, r := range ocfg.Repos { + repo, name := r.Repo, r.Name + // Special case for config repo + if repo == "config" { + repo = ".soft-serve" + r.Private = false + } + + if err := sb.SetProjectName(ctx, repo, name); err != nil { + logger.Errorf("failed to set repo name to %s: %s", repo, err) + } + + if err := sb.SetDescription(ctx, repo, r.Note); err != nil { + logger.Errorf("failed to set repo description to %s: %s", repo, err) + } + + if err := sb.SetPrivate(ctx, repo, r.Private); err != nil { + logger.Errorf("failed to set repo private to %s: %s", repo, err) + } - pubkeys := make([]ssh.PublicKey, 0) - for _, pk := range keys { - pubkeys = append(pubkeys, pk) + for _, collab := range r.Collabs { + if err := sb.AddCollaborator(ctx, repo, collab); err != nil { + logger.Errorf("failed to add repo collab to %s: %s", repo, err) } + } + } - username := strings.ToLower(user.Name) - username = strings.ReplaceAll(username, " ", "-") - logger.Infof("Creating user %q", username) - if _, err := sb.CreateUser(username, backend.UserOptions{ - Admin: user.Admin, - PublicKeys: pubkeys, - }); err != nil { - logger.Errorf("failed to create user: %s", err) + // Create users & collabs + logger.Info("Creating users & collabs...") + for _, user := range ocfg.Users { + keys := make(map[string]ssh.PublicKey) + for _, key := range user.PublicKeys { + pk, _, err := sshutils.ParseAuthorizedKey(key) + if err != nil { + continue } + ak := sshutils.MarshalAuthorizedKey(pk) + keys[ak] = pk + } + + pubkeys := make([]ssh.PublicKey, 0) + for _, pk := range keys { + pubkeys = append(pubkeys, pk) + } + + username := strings.ToLower(user.Name) + username = strings.ReplaceAll(username, " ", "-") + logger.Infof("Creating user %q", username) + if _, err := sb.CreateUser(ctx, username, proto.UserOptions{ + Admin: user.Admin, + PublicKeys: pubkeys, + }); err != nil { + logger.Errorf("failed to create user: %s", err) + } - for _, repo := range user.CollabRepos { - if err := sb.AddCollaborator(repo, username); err != nil { - logger.Errorf("failed to add user collab to %s: %s\n", repo, err) - } + for _, repo := range user.CollabRepos { + if err := sb.AddCollaborator(ctx, repo, username); err != nil { + logger.Errorf("failed to add user collab to %s: %s\n", repo, err) } } + } - logger.Info("Writing config...") - defer logger.Info("Done!") - return config.WriteConfig(filepath.Join(cfg.DataPath, "config.yaml"), cfg) - }, - } -) + logger.Info("Writing config...") + defer logger.Info("Done!") + return cfg.WriteConfig() + }, +} // Returns true if path is a directory containing an `objects` directory and a // `HEAD` file. diff --git a/cmd/soft/root.go b/cmd/soft/root.go index f78c2260645f98a2bd979d70daf43bd0a759556a..b3216162f5601ed44bddfd3fa0bdcc86a8a1825c 100644 --- a/cmd/soft/root.go +++ b/cmd/soft/root.go @@ -2,13 +2,21 @@ package main import ( "context" + "fmt" "os" "runtime/debug" + "strings" + "time" "github.com/charmbracelet/log" - . "github.com/charmbracelet/soft-serve/internal/log" + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + _ "github.com/lib/pq" // postgres driver "github.com/spf13/cobra" "go.uber.org/automaxprocs/maxprocs" + + _ "modernc.org/sqlite" // sqlite driver ) var ( @@ -34,6 +42,7 @@ func init() { manCmd, hookCmd, migrateConfig, + adminCmd, ) rootCmd.CompletionOptions.HiddenDefaultCmd = true @@ -52,19 +61,111 @@ func init() { } func main() { - logger := NewDefaultLogger() + ctx := context.Background() + cfg := config.DefaultConfig() + if cfg.Exist() { + if err := cfg.Parse(); err != nil { + log.Fatal(err) + } + } + + if err := cfg.ParseEnv(); err != nil { + log.Fatal(err) + } + + ctx = config.WithContext(ctx, cfg) + logger, f, err := newDefaultLogger(cfg) + if err != nil { + log.Errorf("failed to create logger: %v", err) + } + + ctx = log.WithContext(ctx, logger) + if f != nil { + defer f.Close() // nolint: errcheck + } // Set global logger log.SetDefault(logger) + var opts []maxprocs.Option + if config.IsVerbose() { + opts = append(opts, maxprocs.Logger(log.Debugf)) + } + // Set the max number of processes to the number of CPUs // This is useful when running soft serve in a container - if _, err := maxprocs.Set(maxprocs.Logger(logger.Debugf)); err != nil { - logger.Warn("couldn't set automaxprocs", "error", err) + if _, err := maxprocs.Set(opts...); err != nil { + log.Warn("couldn't set automaxprocs", "error", err) } - ctx := log.WithContext(context.Background(), logger) if err := rootCmd.ExecuteContext(ctx); err != nil { os.Exit(1) } } + +// newDefaultLogger returns a new logger with default settings. +func newDefaultLogger(cfg *config.Config) (*log.Logger, *os.File, error) { + logger := log.NewWithOptions(os.Stderr, log.Options{ + ReportTimestamp: true, + TimeFormat: time.DateOnly, + }) + + switch { + case config.IsVerbose(): + logger.SetReportCaller(true) + fallthrough + case config.IsDebug(): + logger.SetLevel(log.DebugLevel) + } + + logger.SetTimeFormat(cfg.Log.TimeFormat) + + switch strings.ToLower(cfg.Log.Format) { + case "json": + logger.SetFormatter(log.JSONFormatter) + case "logfmt": + logger.SetFormatter(log.LogfmtFormatter) + case "text": + logger.SetFormatter(log.TextFormatter) + } + + var f *os.File + if cfg.Log.Path != "" { + f, err := os.OpenFile(cfg.Log.Path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return nil, nil, err + } + logger.SetOutput(f) + } + + return logger, f, nil +} + +func initBackendContext(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + cfg := config.FromContext(ctx) + dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource) + if err != nil { + return fmt.Errorf("open database: %w", err) + } + + ctx = db.WithContext(ctx, dbx) + be := backend.New(ctx, cfg, dbx) + ctx = backend.WithContext(ctx, be) + + cmd.SetContext(ctx) + + return nil +} + +func closeDBContext(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + dbx := db.FromContext(ctx) + if dbx != nil { + if err := dbx.Close(); err != nil { + return fmt.Errorf("close database: %w", err) + } + } + + return nil +} diff --git a/cmd/soft/serve.go b/cmd/soft/serve.go index b9a9b2937a1be656acbe38b32325161cee61c572..0111f311a71a1eb3d48004a52831811b02f7c86c 100644 --- a/cmd/soft/serve.go +++ b/cmd/soft/serve.go @@ -10,21 +10,39 @@ import ( "time" "github.com/charmbracelet/soft-serve/server" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/migrate" + "github.com/charmbracelet/soft-serve/server/hooks" "github.com/spf13/cobra" ) var ( + syncHooks bool + serveCmd = &cobra.Command{ - Use: "serve", - Short: "Start the server", - Long: "Start the server", - Args: cobra.NoArgs, + Use: "serve", + Short: "Start the server", + Args: cobra.NoArgs, + PersistentPreRunE: initBackendContext, + PersistentPostRunE: closeDBContext, RunE: func(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() cfg := config.DefaultConfig() - ctx = config.WithContext(ctx, cfg) - cmd.SetContext(ctx) + if cfg.Exist() { + if err := cfg.ParseFile(); err != nil { + return fmt.Errorf("parse config file: %w", err) + } + } else { + if err := cfg.WriteConfig(); err != nil { + return fmt.Errorf("write config file: %w", err) + } + } + + if err := cfg.ParseEnv(); err != nil { + return fmt.Errorf("parse environment variables: %w", err) + } // Create custom hooks directory if it doesn't exist customHooksPath := filepath.Join(cfg.DataPath, "hooks") @@ -33,7 +51,7 @@ var ( // Generate update hook example without executable permissions hookPath := filepath.Join(customHooksPath, "update.sample") // nolint: gosec - if err := os.WriteFile(hookPath, []byte(updateHookExample), 0744); err != nil { + if err := os.WriteFile(hookPath, []byte(updateHookExample), 0o744); err != nil { return fmt.Errorf("failed to generate update hook example: %w", err) } } @@ -44,11 +62,23 @@ var ( os.MkdirAll(logPath, os.ModePerm) // nolint: errcheck } + db := db.FromContext(ctx) + if err := migrate.Migrate(ctx, db); err != nil { + return fmt.Errorf("migration error: %w", err) + } + s, err := server.NewServer(ctx) if err != nil { return fmt.Errorf("start server: %w", err) } + if syncHooks { + be := backend.FromContext(ctx) + if err := initializeHooks(ctx, cfg, be); err != nil { + return fmt.Errorf("initialize hooks: %w", err) + } + } + done := make(chan os.Signal, 1) lch := make(chan error, 1) go func() { @@ -71,3 +101,22 @@ var ( }, } ) + +func init() { + serveCmd.Flags().BoolVarP(&syncHooks, "sync-hooks", "", false, "synchronize hooks for all repositories before running the server") +} + +func initializeHooks(ctx context.Context, cfg *config.Config, be *backend.Backend) error { + repos, err := be.Repositories(ctx) + if err != nil { + return err + } + + for _, repo := range repos { + if err := hooks.GenerateHooks(ctx, cfg, repo.Name()); err != nil { + return err + } + } + + return nil +} diff --git a/examples/setuid/main.go b/examples/setuid/main.go deleted file mode 100644 index 4b6c4770d5d6d2a759b1dd291bc67faf655faa4d..0000000000000000000000000000000000000000 --- a/examples/setuid/main.go +++ /dev/null @@ -1,74 +0,0 @@ -//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris -// +build darwin dragonfly freebsd linux netbsd openbsd solaris - -// This is an example of binding soft-serve ssh port to a restricted port (<1024) and -// then droping root privileges to a different user to run the server. -// Make sure you run this as root. - -package main - -import ( - "context" - "flag" - "fmt" - "net" - "os" - "os/signal" - "syscall" - "time" - - "github.com/charmbracelet/log" - - "github.com/charmbracelet/soft-serve/server" - "github.com/charmbracelet/soft-serve/server/config" -) - -var ( - port = flag.Int("port", 22, "port to listen on") - gid = flag.Int("gid", 1000, "group id to run as") - uid = flag.Int("uid", 1000, "user id to run as") -) - -func main() { - flag.Parse() - addr := fmt.Sprintf(":%d", *port) - // To listen on port 22 we need root privileges - ls, err := net.Listen("tcp", addr) - if err != nil { - log.Fatal("Can't listen", "err", err) - } - // We don't need root privileges any more - if err := syscall.Setgid(*gid); err != nil { - log.Fatal("Setgid error", "err", err) - } - if err := syscall.Setuid(*uid); err != nil { - log.Fatal("Setuid error", "err", err) - } - ctx := context.Background() - cfg := config.DefaultConfig() - ctx = config.WithContext(ctx, cfg) - cfg.SSH.ListenAddr = fmt.Sprintf(":%d", *port) - s, err := server.NewServer(ctx) - if err != nil { - log.Fatal(err) - } - - done := make(chan os.Signal, 1) - signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) - - log.Print("Starting SSH server", "addr", cfg.SSH.ListenAddr) - go func() { - if err := s.SSHServer.Serve(ls); err != nil { - log.Fatal(err) - } - }() - - <-done - - log.Print("Stopping SSH server", "addr", cfg.SSH.ListenAddr) - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer func() { cancel() }() - if err := s.Shutdown(ctx); err != nil { - log.Fatal(err) - } -} diff --git a/git/commit.go b/git/commit.go index 6a7d2778064dd35923e92788eab19f75048a201c..3dd0a8ff377914166b3c66ca3b2748b4ee979605 100644 --- a/git/commit.go +++ b/git/commit.go @@ -4,9 +4,8 @@ import ( "github.com/gogs/git-module" ) -var ( - ZeroHash Hash = git.EmptyID -) +// ZeroHash is the zero hash. +var ZeroHash Hash = git.EmptyID // Hash represents a git hash. type Hash string diff --git a/git/patch.go b/git/patch.go index 114137090bae4584c49989981dadd9d97bf1c791..aaa6b9d78fb2692f530115e0038112bffa1dbdde 100644 --- a/git/patch.go +++ b/git/patch.go @@ -82,7 +82,7 @@ func diffsToString(diffs []diffmatchpatch.Diff, lineType git.DiffLineType) strin } } - return string(buf.Bytes()) + return buf.String() } // DiffFile is a wrapper to git.DiffFile with helper methods. diff --git a/go.mod b/go.mod index 9e30fdb459e96002f5b99b68478e038f3e5cdabe..d1fd82477d6aa2f82a7dd24edc4d0c787a37ddce 100644 --- a/go.mod +++ b/go.mod @@ -20,12 +20,13 @@ require ( require ( github.com/caarlos0/env/v8 v8.0.0 github.com/charmbracelet/keygen v0.4.3 - github.com/charmbracelet/log v0.2.2 - github.com/charmbracelet/ssh v0.0.0-20221117183211-483d43d97103 + github.com/charmbracelet/log v0.2.3-0.20230713155356-557335e40e35 + github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc github.com/gobwas/glob v0.2.3 github.com/gogs/git-module v1.8.2 github.com/hashicorp/golang-lru/v2 v2.0.4 github.com/jmoiron/sqlx v1.3.5 + github.com/lib/pq v1.2.0 github.com/lrstanley/bubblezone v0.0.0-20220716194435-3cb8c52f6a8f github.com/muesli/mango-cobra v1.2.0 github.com/muesli/roff v0.1.0 diff --git a/go.sum b/go.sum index 1744b5f06bed66a4035414d16f7e868ae962c634..68b784413895c021f3185f5335e857e01a21c6de 100644 --- a/go.sum +++ b/go.sum @@ -27,10 +27,10 @@ github.com/charmbracelet/keygen v0.4.3 h1:ywOZRwkDlpmkawl0BgLTxaYWDSqp6Y4nfVVmgy github.com/charmbracelet/keygen v0.4.3/go.mod h1:4e4FT3HSdLU/u83RfJWvzJIaVb8aX4MxtDlfXwpDJaI= github.com/charmbracelet/lipgloss v0.7.1 h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZpc5Kc1E= github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c= -github.com/charmbracelet/log v0.2.2 h1:CaXgos+ikGn5tcws5Cw3paQuk9e/8bIwuYGhnkqQFjo= -github.com/charmbracelet/log v0.2.2/go.mod h1:Zs11hKpb8l+UyX4y1srwZIGW+MPCXJHIty3MB9l/sno= -github.com/charmbracelet/ssh v0.0.0-20221117183211-483d43d97103 h1:wpHMERIN0pQZE635jWwT1dISgfjbpUcEma+fbPKSMCU= -github.com/charmbracelet/ssh v0.0.0-20221117183211-483d43d97103/go.mod h1:0Vm2/8yBljiLDnGJHU8ehswfawrEybGk33j5ssqKQVM= +github.com/charmbracelet/log v0.2.3-0.20230713155356-557335e40e35 h1:VXEaJ1iM2L5N8T2WVbv4y631pzCD3O9s75dONqK+87g= +github.com/charmbracelet/log v0.2.3-0.20230713155356-557335e40e35/go.mod h1:ZApwwzDbbETVTIRTk7724yQRJAXIktt98yGVMMaa3y8= +github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc h1:JUm+5HigAM5utFiThwIDX9iU0BaheKpuNVr+umi3sFg= +github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc/go.mod h1:F1vgddWsb/Yr/OZilFeRZEh5sE/qU0Dt1mKkmke6Zvg= github.com/charmbracelet/wish v1.1.1 h1:KdICASKd2oh2JPvk1Z4CJtAi97cFErXF7NKienPICO4= github.com/charmbracelet/wish v1.1.1/go.mod h1:xh4KZpSULw+Xqb9bcbhw92QAinVB75CVLWrFuyY6IVs= github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY= @@ -161,7 +161,7 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.5.2 h1:ALmeCk/px5FSm1MAcFBAsVKZjDuMVj8Tm7FFIlMJnqU= github.com/yuin/goldmark v1.5.2/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= @@ -189,14 +189,14 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/internal/log/log.go b/internal/log/log.go deleted file mode 100644 index b6c4b1443d19a5bda654a22428efb1fb70cbe16d..0000000000000000000000000000000000000000 --- a/internal/log/log.go +++ /dev/null @@ -1,53 +0,0 @@ -package log - -import ( - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/charmbracelet/log" - "github.com/charmbracelet/soft-serve/server/config" -) - -var contextKey = &struct{ string }{"logger"} - -// NewDefaultLogger returns a new logger with default settings. -func NewDefaultLogger() *log.Logger { - dp := os.Getenv("SOFT_SERVE_DATA_PATH") - if dp == "" { - dp = "data" - } - - cfg, err := config.ParseConfig(filepath.Join(dp, "config.yaml")) - if err != nil { - log.Errorf("failed to parse config: %v", err) - } - - logger := log.NewWithOptions(os.Stderr, log.Options{ - ReportTimestamp: true, - TimeFormat: time.DateOnly, - }) - - if debug, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_DEBUG")); debug { - logger.SetLevel(log.DebugLevel) - - if verbose, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_VERBOSE")); verbose { - logger.SetReportCaller(true) - } - } - - logger.SetTimeFormat(cfg.Log.TimeFormat) - - switch strings.ToLower(cfg.Log.Format) { - case "json": - logger.SetFormatter(log.JSONFormatter) - case "logfmt": - logger.SetFormatter(log.LogfmtFormatter) - case "text": - logger.SetFormatter(log.TextFormatter) - } - - return logger -} diff --git a/server/backend/access.go b/server/access/access.go similarity index 94% rename from server/backend/access.go rename to server/access/access.go index 8ed8dbfc339d489a5464feeaf975bc762f10d2fa..2ddc88b398c8e00eafeaee91e328859ccaaa38ba 100644 --- a/server/backend/access.go +++ b/server/access/access.go @@ -1,7 +1,7 @@ -package backend +package access // AccessLevel is the level of access allowed to a repo. -type AccessLevel int +type AccessLevel int // nolint: revive const ( // NoAccess does not allow access to the repo. diff --git a/server/access/access_test.go b/server/access/access_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6d591b32aa50100b47fb3d142e4bc45b46f391b1 --- /dev/null +++ b/server/access/access_test.go @@ -0,0 +1,24 @@ +package access + +import "testing" + +func TestParseAccessLevel(t *testing.T) { + cases := []struct { + in string + out AccessLevel + }{ + {"", -1}, + {"foo", -1}, + {AdminAccess.String(), AdminAccess}, + {ReadOnlyAccess.String(), ReadOnlyAccess}, + {ReadWriteAccess.String(), ReadWriteAccess}, + {NoAccess.String(), NoAccess}, + } + + for _, c := range cases { + out := ParseAccessLevel(c.in) + if out != c.out { + t.Errorf("ParseAccessLevel(%q) => %d, want %d", c.in, out, c.out) + } + } +} diff --git a/server/backend/backend.go b/server/backend/backend.go index 1aa408b46a5bf92cf54f4d4dc676eceffabf8e6a..2d6104d56f2a32563c7724ccd762312895c804a2 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -1,47 +1,41 @@ package backend import ( - "bytes" "context" - "github.com/charmbracelet/ssh" - gossh "golang.org/x/crypto/ssh" + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/store/database" ) -// Backend is an interface that handles repositories management and any -// non-Git related operations. -type Backend interface { - SettingsBackend - RepositoryStore - RepositoryMetadata - RepositoryAccess - UserStore - UserAccess - Hooks - - // WithContext returns a copy Backend with the given context. - WithContext(ctx context.Context) Backend -} - -// ParseAuthorizedKey parses an authorized key string into a public key. -func ParseAuthorizedKey(ak string) (gossh.PublicKey, string, error) { - pk, c, _, _, err := gossh.ParseAuthorizedKey([]byte(ak)) - return pk, c, err +// Backend is the Soft Serve backend that handles users, repositories, and +// server settings management and operations. +type Backend struct { + ctx context.Context + cfg *config.Config + db *db.DB + store store.Store + logger *log.Logger + cache *cache } -// MarshalAuthorizedKey marshals a public key into an authorized key string. -// -// This is the inverse of ParseAuthorizedKey. -// This function is a copy of ssh.MarshalAuthorizedKey, but without the trailing newline. -// It returns an empty string if pk is nil. -func MarshalAuthorizedKey(pk gossh.PublicKey) string { - if pk == nil { - return "" +// New returns a new Soft Serve backend. +func New(ctx context.Context, cfg *config.Config, db *db.DB) *Backend { + dbstore := database.New(ctx, db) + logger := log.FromContext(ctx).WithPrefix("backend") + b := &Backend{ + ctx: ctx, + cfg: cfg, + db: db, + store: dbstore, + logger: logger, } - return string(bytes.TrimSuffix(gossh.MarshalAuthorizedKey(pk), []byte("\n"))) -} -// KeysEqual returns whether the two public keys are equal. -func KeysEqual(a, b gossh.PublicKey) bool { - return ssh.KeysEqual(a, b) + // TODO: implement a proper caching interface + cache := newCache(b, 1000) + b.cache = cache + + return b } diff --git a/server/backend/cache.go b/server/backend/cache.go new file mode 100644 index 0000000000000000000000000000000000000000..8d99401dd94fee4c02ce145693de8c53ee5bfaae --- /dev/null +++ b/server/backend/cache.go @@ -0,0 +1,35 @@ +package backend + +import lru "github.com/hashicorp/golang-lru/v2" + +// TODO: implement a caching interface. +type cache struct { + b *Backend + repos *lru.Cache[string, *repo] +} + +func newCache(b *Backend, size int) *cache { + if size <= 0 { + size = 1 + } + c := &cache{b: b} + cache, _ := lru.New[string, *repo](size) + c.repos = cache + return c +} + +func (c *cache) Get(repo string) (*repo, bool) { + return c.repos.Get(repo) +} + +func (c *cache) Set(repo string, r *repo) { + c.repos.Add(repo, r) +} + +func (c *cache) Delete(repo string) { + c.repos.Remove(repo) +} + +func (c *cache) Len() int { + return c.repos.Len() +} diff --git a/server/backend/collab.go b/server/backend/collab.go new file mode 100644 index 0000000000000000000000000000000000000000..92bfca821854554fc7e879474da2a21ac2fef8e4 --- /dev/null +++ b/server/backend/collab.go @@ -0,0 +1,78 @@ +package backend + +import ( + "context" + "strings" + + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/models" + "github.com/charmbracelet/soft-serve/server/utils" +) + +// AddCollaborator adds a collaborator to a repository. +// +// It implements backend.Backend. +func (d *Backend) AddCollaborator(ctx context.Context, repo string, username string) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + repo = utils.SanitizeRepo(repo) + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.AddCollabByUsernameAndRepo(ctx, tx, username, repo) + }), + ) +} + +// Collaborators returns a list of collaborators for a repository. +// +// It implements backend.Backend. +func (d *Backend) Collaborators(ctx context.Context, repo string) ([]string, error) { + repo = utils.SanitizeRepo(repo) + var users []models.User + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + users, err = d.store.ListCollabsByRepoAsUsers(ctx, tx, repo) + return err + }); err != nil { + return nil, db.WrapError(err) + } + + var usernames []string + for _, u := range users { + usernames = append(usernames, u.Username) + } + + return usernames, nil +} + +// IsCollaborator returns true if the user is a collaborator of the repository. +// +// It implements backend.Backend. +func (d *Backend) IsCollaborator(ctx context.Context, repo string, username string) (bool, error) { + repo = utils.SanitizeRepo(repo) + var m models.Collab + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + m, err = d.store.GetCollabByUsernameAndRepo(ctx, tx, username, repo) + return err + }); err != nil { + return false, db.WrapError(err) + } + + return m.ID > 0, nil +} + +// RemoveCollaborator removes a collaborator from a repository. +// +// It implements backend.Backend. +func (d *Backend) RemoveCollaborator(ctx context.Context, repo string, username string) error { + repo = utils.SanitizeRepo(repo) + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.RemoveCollabByUsernameAndRepo(ctx, tx, username, repo) + }), + ) +} diff --git a/server/backend/context.go b/server/backend/context.go index 1af19057d57ddf2c61bba91460d996af66884ec0..8971bbf9360856af3cb8a7c9d8bae35feff221f9 100644 --- a/server/backend/context.go +++ b/server/backend/context.go @@ -2,11 +2,12 @@ package backend import "context" -var contextKey = &struct{ string }{"backend"} +// ContextKey is the key for the backend in the context. +var ContextKey = &struct{ string }{"backend"} // FromContext returns the backend from a context. -func FromContext(ctx context.Context) Backend { - if b, ok := ctx.Value(contextKey).(Backend); ok { +func FromContext(ctx context.Context) *Backend { + if b, ok := ctx.Value(ContextKey).(*Backend); ok { return b } @@ -14,6 +15,6 @@ func FromContext(ctx context.Context) Backend { } // WithContext returns a new context with the backend attached. -func WithContext(ctx context.Context, b Backend) context.Context { - return context.WithValue(ctx, contextKey, b) +func WithContext(ctx context.Context, b *Backend) context.Context { + return context.WithValue(ctx, ContextKey, b) } diff --git a/server/backend/hooks.go b/server/backend/hooks.go index ba130a30195a734a8b0626524dd6e9c9c9a0c056..b28c5ecd78bf8a20d30f26cd0b4fe10adda3bae1 100644 --- a/server/backend/hooks.go +++ b/server/backend/hooks.go @@ -1,20 +1,80 @@ package backend import ( + "context" "io" + "sync" + + "github.com/charmbracelet/soft-serve/server/hooks" + "github.com/charmbracelet/soft-serve/server/proto" ) -// HookArg is an argument to a git hook. -type HookArg struct { - OldSha string - NewSha string - RefName string +var _ hooks.Hooks = (*Backend)(nil) + +// PostReceive is called by the git post-receive hook. +// +// It implements Hooks. +func (d *Backend) PostReceive(_ context.Context, _ io.Writer, _ io.Writer, repo string, args []hooks.HookArg) { + d.logger.Debug("post-receive hook called", "repo", repo, "args", args) +} + +// PreReceive is called by the git pre-receive hook. +// +// It implements Hooks. +func (d *Backend) PreReceive(_ context.Context, _ io.Writer, _ io.Writer, repo string, args []hooks.HookArg) { + d.logger.Debug("pre-receive hook called", "repo", repo, "args", args) } -// Hooks provides an interface for git server-side hooks. -type Hooks interface { - PreReceive(stdout io.Writer, stderr io.Writer, repo string, args []HookArg) - Update(stdout io.Writer, stderr io.Writer, repo string, arg HookArg) - PostReceive(stdout io.Writer, stderr io.Writer, repo string, args []HookArg) - PostUpdate(stdout io.Writer, stderr io.Writer, repo string, args ...string) +// Update is called by the git update hook. +// +// It implements Hooks. +func (d *Backend) Update(_ context.Context, _ io.Writer, _ io.Writer, repo string, arg hooks.HookArg) { + d.logger.Debug("update hook called", "repo", repo, "arg", arg) +} + +// PostUpdate is called by the git post-update hook. +// +// It implements Hooks. +func (d *Backend) PostUpdate(ctx context.Context, _ io.Writer, _ io.Writer, repo string, args ...string) { + d.logger.Debug("post-update hook called", "repo", repo, "args", args) + + var wg sync.WaitGroup + + // Populate last-modified file. + wg.Add(1) + go func() { + defer wg.Done() + if err := populateLastModified(ctx, d, repo); err != nil { + d.logger.Error("error populating last-modified", "repo", repo, "err", err) + return + } + }() + + wg.Wait() +} + +func populateLastModified(ctx context.Context, d *Backend, name string) error { + var rr *repo + _rr, err := d.Repository(ctx, name) + if err != nil { + return err + } + + if r, ok := _rr.(*repo); ok { + rr = r + } else { + return proto.ErrRepoNotExist + } + + r, err := rr.Open() + if err != nil { + return err + } + + c, err := r.LatestCommitTime() + if err != nil { + return err + } + + return rr.writeLastModified(c) } diff --git a/server/backend/repo.go b/server/backend/repo.go index 8d7e9cdeffa516b70c276c383b4d78746ead6876..5f07e61f175140795062d954ce67e04897d6e06f 100644 --- a/server/backend/repo.go +++ b/server/backend/repo.go @@ -1,86 +1,484 @@ package backend import ( + "bufio" + "context" + "errors" + "fmt" + "os" + "path/filepath" "time" "github.com/charmbracelet/soft-serve/git" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/models" + "github.com/charmbracelet/soft-serve/server/hooks" + "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/utils" ) -// RepositoryOptions are options for creating a new repository. -type RepositoryOptions struct { - Private bool - Description string - ProjectName string - Mirror bool - Hidden bool -} - -// RepositoryStore is an interface for managing repositories. -type RepositoryStore interface { - // Repository finds the given repository. - Repository(repo string) (Repository, error) - // Repositories returns a list of all repositories. - Repositories() ([]Repository, error) - // CreateRepository creates a new repository. - CreateRepository(name string, opts RepositoryOptions) (Repository, error) - // ImportRepository creates a new repository from a Git repository. - ImportRepository(name string, remote string, opts RepositoryOptions) (Repository, error) - // DeleteRepository deletes a repository. - DeleteRepository(name string) error - // RenameRepository renames a repository. - RenameRepository(oldName, newName string) error -} - -// RepositoryMetadata is an interface for managing repository metadata. -type RepositoryMetadata interface { - // ProjectName returns the repository's project name. - ProjectName(repo string) (string, error) - // SetProjectName sets the repository's project name. - SetProjectName(repo, name string) error - // Description returns the repository's description. - Description(repo string) (string, error) - // SetDescription sets the repository's description. - SetDescription(repo, desc string) error - // IsPrivate returns whether the repository is private. - IsPrivate(repo string) (bool, error) - // SetPrivate sets whether the repository is private. - SetPrivate(repo string, private bool) error - // IsMirror returns whether the repository is a mirror. - IsMirror(repo string) (bool, error) - // IsHidden returns whether the repository is hidden. - IsHidden(repo string) (bool, error) - // SetHidden sets whether the repository is hidden. - SetHidden(repo string, hidden bool) error -} - -// RepositoryAccess is an interface for managing repository access. -type RepositoryAccess interface { - IsCollaborator(repo string, username string) (bool, error) - // AddCollaborator adds the authorized key as a collaborator on the repository. - AddCollaborator(repo string, username string) error - // RemoveCollaborator removes the authorized key as a collaborator on the repository. - RemoveCollaborator(repo string, username string) error - // Collaborators returns a list of all collaborators on the repository. - Collaborators(repo string) ([]string, error) -} - -// Repository is a Git repository interface. -type Repository interface { - // Name returns the repository's name. - Name() string - // ProjectName returns the repository's project name. - ProjectName() string - // Description returns the repository's description. - Description() string - // IsPrivate returns whether the repository is private. - IsPrivate() bool - // IsMirror returns whether the repository is a mirror. - IsMirror() bool - // IsHidden returns whether the repository is hidden. - IsHidden() bool - // UpdatedAt returns the time the repository was last updated. - // If the repository has never been updated, it returns the time it was created. - UpdatedAt() time.Time - // Open returns the underlying git.Repository. - Open() (*git.Repository, error) +func (d *Backend) reposPath() string { + return filepath.Join(d.cfg.DataPath, "repos") +} + +// CreateRepository creates a new repository. +// +// It implements backend.Backend. +func (d *Backend) CreateRepository(ctx context.Context, name string, opts proto.RepositoryOptions) (proto.Repository, error) { + name = utils.SanitizeRepo(name) + if err := utils.ValidateRepo(name); err != nil { + return nil, err + } + + repo := name + ".git" + rp := filepath.Join(d.reposPath(), repo) + + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + if err := d.store.CreateRepo( + ctx, + tx, + name, + opts.ProjectName, + opts.Description, + opts.Private, + opts.Hidden, + opts.Mirror, + ); err != nil { + return err + } + + _, err := git.Init(rp, true) + if err != nil { + d.logger.Debug("failed to create repository", "err", err) + return err + } + + return hooks.GenerateHooks(ctx, d.cfg, repo) + }); err != nil { + d.logger.Debug("failed to create repository in database", "err", err) + return nil, db.WrapError(err) + } + + return d.Repository(ctx, name) +} + +// ImportRepository imports a repository from remote. +func (d *Backend) ImportRepository(ctx context.Context, name string, remote string, opts proto.RepositoryOptions) (proto.Repository, error) { + name = utils.SanitizeRepo(name) + if err := utils.ValidateRepo(name); err != nil { + return nil, err + } + + repo := name + ".git" + rp := filepath.Join(d.reposPath(), repo) + + if _, err := os.Stat(rp); err == nil || os.IsExist(err) { + return nil, proto.ErrRepoExist + } + + copts := git.CloneOptions{ + Bare: true, + Mirror: opts.Mirror, + Quiet: true, + CommandOptions: git.CommandOptions{ + Timeout: -1, + Context: ctx, + Envs: []string{ + fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`, + filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"), + d.cfg.SSH.ClientKeyPath, + ), + }, + }, + // Timeout: time.Hour, + } + + if err := git.Clone(remote, rp, copts); err != nil { + d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp) + // Cleanup the mess! + if rerr := os.RemoveAll(rp); rerr != nil { + err = errors.Join(err, rerr) + } + return nil, err + } + + return d.CreateRepository(ctx, name, opts) +} + +// DeleteRepository deletes a repository. +// +// It implements backend.Backend. +func (d *Backend) DeleteRepository(ctx context.Context, name string) error { + name = utils.SanitizeRepo(name) + repo := name + ".git" + rp := filepath.Join(d.reposPath(), repo) + + return d.db.TransactionContext(ctx, func(tx *db.Tx) error { + // Delete repo from cache + defer d.cache.Delete(name) + + if err := d.store.DeleteRepoByName(ctx, tx, name); err != nil { + return err + } + + return os.RemoveAll(rp) + }) +} + +// RenameRepository renames a repository. +// +// It implements backend.Backend. +func (d *Backend) RenameRepository(ctx context.Context, oldName string, newName string) error { + oldName = utils.SanitizeRepo(oldName) + if err := utils.ValidateRepo(oldName); err != nil { + return err + } + + newName = utils.SanitizeRepo(newName) + if err := utils.ValidateRepo(newName); err != nil { + return err + } + oldRepo := oldName + ".git" + newRepo := newName + ".git" + op := filepath.Join(d.reposPath(), oldRepo) + np := filepath.Join(d.reposPath(), newRepo) + if _, err := os.Stat(op); err != nil { + return proto.ErrRepoNotExist + } + + if _, err := os.Stat(np); err == nil { + return proto.ErrRepoExist + } + + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + // Delete cache + defer d.cache.Delete(oldName) + + if err := d.store.SetRepoNameByName(ctx, tx, oldName, newName); err != nil { + return err + } + + // Make sure the new repository parent directory exists. + if err := os.MkdirAll(filepath.Dir(np), os.ModePerm); err != nil { + return err + } + + return os.Rename(op, np) + }); err != nil { + return db.WrapError(err) + } + + return nil +} + +// Repositories returns a list of repositories per page. +// +// It implements backend.Backend. +func (d *Backend) Repositories(ctx context.Context) ([]proto.Repository, error) { + repos := make([]proto.Repository, 0) + + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + ms, err := d.store.GetAllRepos(ctx, tx) + if err != nil { + return err + } + + for _, m := range ms { + r := &repo{ + name: m.Name, + path: filepath.Join(d.reposPath(), m.Name+".git"), + repo: m, + } + + // Cache repositories + d.cache.Set(m.Name, r) + + repos = append(repos, r) + } + + return nil + }); err != nil { + return nil, db.WrapError(err) + } + + return repos, nil +} + +// Repository returns a repository by name. +// +// It implements backend.Backend. +func (d *Backend) Repository(ctx context.Context, name string) (proto.Repository, error) { + var m models.Repo + name = utils.SanitizeRepo(name) + + if r, ok := d.cache.Get(name); ok && r != nil { + return r, nil + } + + rp := filepath.Join(d.reposPath(), name+".git") + if _, err := os.Stat(rp); err != nil { + return nil, os.ErrNotExist + } + + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + m, err = d.store.GetRepoByName(ctx, tx, name) + return err + }); err != nil { + return nil, db.WrapError(err) + } + + r := &repo{ + name: name, + path: rp, + repo: m, + } + + // Add to cache + d.cache.Set(name, r) + + return r, nil +} + +// Description returns the description of a repository. +// +// It implements backend.Backend. +func (d *Backend) Description(ctx context.Context, name string) (string, error) { + name = utils.SanitizeRepo(name) + var desc string + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + desc, err = d.store.GetRepoDescriptionByName(ctx, tx, name) + return err + }); err != nil { + return "", db.WrapError(err) + } + + return desc, nil +} + +// IsMirror returns true if the repository is a mirror. +// +// It implements backend.Backend. +func (d *Backend) IsMirror(ctx context.Context, name string) (bool, error) { + name = utils.SanitizeRepo(name) + var mirror bool + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + mirror, err = d.store.GetRepoIsMirrorByName(ctx, tx, name) + return err + }); err != nil { + return false, db.WrapError(err) + } + return mirror, nil +} + +// IsPrivate returns true if the repository is private. +// +// It implements backend.Backend. +func (d *Backend) IsPrivate(ctx context.Context, name string) (bool, error) { + name = utils.SanitizeRepo(name) + var private bool + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + private, err = d.store.GetRepoIsPrivateByName(ctx, tx, name) + return err + }); err != nil { + return false, db.WrapError(err) + } + + return private, nil +} + +// IsHidden returns true if the repository is hidden. +// +// It implements backend.Backend. +func (d *Backend) IsHidden(ctx context.Context, name string) (bool, error) { + name = utils.SanitizeRepo(name) + var hidden bool + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + hidden, err = d.store.GetRepoIsHiddenByName(ctx, tx, name) + return err + }); err != nil { + return false, db.WrapError(err) + } + + return hidden, nil +} + +// ProjectName returns the project name of a repository. +// +// It implements backend.Backend. +func (d *Backend) ProjectName(ctx context.Context, name string) (string, error) { + name = utils.SanitizeRepo(name) + var pname string + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + pname, err = d.store.GetRepoProjectNameByName(ctx, tx, name) + return err + }); err != nil { + return "", db.WrapError(err) + } + + return pname, nil +} + +// SetHidden sets the hidden flag of a repository. +// +// It implements backend.Backend. +func (d *Backend) SetHidden(ctx context.Context, name string, hidden bool) error { + name = utils.SanitizeRepo(name) + + // Delete cache + d.cache.Delete(name) + + return db.WrapError(d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.SetRepoIsHiddenByName(ctx, tx, name, hidden) + })) +} + +// SetDescription sets the description of a repository. +// +// It implements backend.Backend. +func (d *Backend) SetDescription(ctx context.Context, repo string, desc string) error { + repo = utils.SanitizeRepo(repo) + + // Delete cache + d.cache.Delete(repo) + + return d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.SetRepoDescriptionByName(ctx, tx, repo, desc) + }) +} + +// SetPrivate sets the private flag of a repository. +// +// It implements backend.Backend. +func (d *Backend) SetPrivate(ctx context.Context, repo string, private bool) error { + repo = utils.SanitizeRepo(repo) + + // Delete cache + d.cache.Delete(repo) + + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.SetRepoIsPrivateByName(ctx, tx, repo, private) + }), + ) +} + +// SetProjectName sets the project name of a repository. +// +// It implements backend.Backend. +func (d *Backend) SetProjectName(ctx context.Context, repo string, name string) error { + repo = utils.SanitizeRepo(repo) + + // Delete cache + d.cache.Delete(repo) + + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.SetRepoProjectNameByName(ctx, tx, repo, name) + }), + ) +} + +var _ proto.Repository = (*repo)(nil) + +// repo is a Git repository with metadata stored in a SQLite database. +type repo struct { + name string + path string + repo models.Repo +} + +// Description returns the repository's description. +// +// It implements backend.Repository. +func (r *repo) Description() string { + return r.repo.Description +} + +// IsMirror returns whether the repository is a mirror. +// +// It implements backend.Repository. +func (r *repo) IsMirror() bool { + return r.repo.Mirror +} + +// IsPrivate returns whether the repository is private. +// +// It implements backend.Repository. +func (r *repo) IsPrivate() bool { + return r.repo.Private +} + +// Name returns the repository's name. +// +// It implements backend.Repository. +func (r *repo) Name() string { + return r.name +} + +// Open opens the repository. +// +// It implements backend.Repository. +func (r *repo) Open() (*git.Repository, error) { + return git.Open(r.path) +} + +// ProjectName returns the repository's project name. +// +// It implements backend.Repository. +func (r *repo) ProjectName() string { + return r.repo.ProjectName +} + +// IsHidden returns whether the repository is hidden. +// +// It implements backend.Repository. +func (r *repo) IsHidden() bool { + return r.repo.Hidden +} + +// UpdatedAt returns the repository's last update time. +func (r *repo) UpdatedAt() time.Time { + // Try to read the last modified time from the info directory. + if t, err := readOneline(filepath.Join(r.path, "info", "last-modified")); err == nil { + if t, err := time.Parse(time.RFC3339, t); err == nil { + return t + } + } + + rr, err := git.Open(r.path) + if err == nil { + t, err := rr.LatestCommitTime() + if err == nil { + return t + } + } + + return r.repo.UpdatedAt +} + +func (r *repo) writeLastModified(t time.Time) error { + fp := filepath.Join(r.path, "info", "last-modified") + if err := os.MkdirAll(filepath.Dir(fp), os.ModePerm); err != nil { + return err + } + + return os.WriteFile(fp, []byte(t.Format(time.RFC3339)), os.ModePerm) +} + +func readOneline(path string) (string, error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + + defer f.Close() // nolint: errcheck + s := bufio.NewScanner(f) + s.Scan() + return s.Text(), s.Err() } diff --git a/server/backend/settings.go b/server/backend/settings.go index c3e8a79023f250aeb7410bb46d76d8dacf261bc5..0879aa38d0b23aec92348ec877a4b69657b71320 100644 --- a/server/backend/settings.go +++ b/server/backend/settings.go @@ -1,13 +1,58 @@ package backend -// SettingsBackend is an interface that handles server configuration. -type SettingsBackend interface { - // AnonAccess returns the access level for anonymous users. - AnonAccess() AccessLevel - // SetAnonAccess sets the access level for anonymous users. - SetAnonAccess(level AccessLevel) error - // AllowKeyless returns true if keyless access is allowed. - AllowKeyless() bool - // SetAllowKeyless sets whether or not keyless access is allowed. - SetAllowKeyless(allow bool) error +import ( + "context" + + "github.com/charmbracelet/soft-serve/server/access" + "github.com/charmbracelet/soft-serve/server/db" +) + +// AllowKeyless returns whether or not keyless access is allowed. +// +// It implements backend.Backend. +func (b *Backend) AllowKeyless(ctx context.Context) bool { + var allow bool + if err := b.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + allow, err = b.store.GetAllowKeylessAccess(ctx, tx) + return err + }); err != nil { + return false + } + + return allow +} + +// SetAllowKeyless sets whether or not keyless access is allowed. +// +// It implements backend.Backend. +func (b *Backend) SetAllowKeyless(ctx context.Context, allow bool) error { + return b.db.TransactionContext(ctx, func(tx *db.Tx) error { + return b.store.SetAllowKeylessAccess(ctx, tx, allow) + }) +} + +// AnonAccess returns the level of anonymous access. +// +// It implements backend.Backend. +func (b *Backend) AnonAccess(ctx context.Context) access.AccessLevel { + var level access.AccessLevel + if err := b.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + level, err = b.store.GetAnonAccess(ctx, tx) + return err + }); err != nil { + return access.NoAccess + } + + return level +} + +// SetAnonAccess sets the level of anonymous access. +// +// It implements backend.Backend. +func (b *Backend) SetAnonAccess(ctx context.Context, level access.AccessLevel) error { + return b.db.TransactionContext(ctx, func(tx *db.Tx) error { + return b.store.SetAnonAccess(ctx, tx, level) + }) } diff --git a/server/backend/sqlite/db.go b/server/backend/sqlite/db.go deleted file mode 100644 index fac0394f4309af33f7c26cc84dcb94904013f512..0000000000000000000000000000000000000000 --- a/server/backend/sqlite/db.go +++ /dev/null @@ -1,141 +0,0 @@ -package sqlite - -import ( - "context" - "database/sql" - "errors" - "fmt" - - "github.com/charmbracelet/soft-serve/server/backend" - "github.com/jmoiron/sqlx" - "modernc.org/sqlite" - sqlite3 "modernc.org/sqlite/lib" -) - -// Close closes the database. -func (d *SqliteBackend) Close() error { - return d.db.Close() -} - -// init creates the database. -func (d *SqliteBackend) init() error { - return wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - if _, err := tx.Exec(sqlCreateSettingsTable); err != nil { - return err - } - if _, err := tx.Exec(sqlCreateUserTable); err != nil { - return err - } - if _, err := tx.Exec(sqlCreatePublicKeyTable); err != nil { - return err - } - if _, err := tx.Exec(sqlCreateRepoTable); err != nil { - return err - } - if _, err := tx.Exec(sqlCreateCollabTable); err != nil { - return err - } - - // Set default settings. - if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "allow_keyless", true); err != nil { - return err - } - if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "anon_access", backend.ReadOnlyAccess.String()); err != nil { - return err - } - - var init bool - if err := tx.Get(&init, "SELECT value FROM settings WHERE key = 'init'"); err != nil && !errors.Is(err, sql.ErrNoRows) { - return err - } - - // Create default user. - if !init { - r, err := tx.Exec("INSERT OR IGNORE INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);", "admin", true) - if err != nil { - return err - } - userID, err := r.LastInsertId() - if err != nil { - return err - } - - // Add initial keys - // Don't use cfg.AdminKeys since it also includes the internal key - // used for internal api access. - for _, k := range d.cfg.InitialAdminKeys { - pk, _, err := backend.ParseAuthorizedKey(k) - if err != nil { - d.logger.Error("error parsing initial admin key, skipping", "key", k, "err", err) - continue - } - - stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP);`) - if err != nil { - return err - } - - defer stmt.Close() // nolint: errcheck - if _, err := stmt.Exec(userID, backend.MarshalAuthorizedKey(pk)); err != nil { - return err - } - } - } - - // set init flag - if _, err := tx.Exec("INSERT OR IGNORE INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)", "init", true); err != nil { - return err - } - - return nil - }) -} - -func wrapDbErr(err error) error { - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return ErrNoRecord - } - if liteErr, ok := err.(*sqlite.Error); ok { - code := liteErr.Code() - if code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY || - code == sqlite3.SQLITE_CONSTRAINT_UNIQUE { - return ErrDuplicateKey - } - } - } - return err -} - -func wrapTx(db *sqlx.DB, ctx context.Context, fn func(tx *sqlx.Tx) error) error { - tx, err := db.BeginTxx(ctx, nil) - if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) - } - - if err := fn(tx); err != nil { - return rollback(tx, err) - } - - if err := tx.Commit(); err != nil { - if errors.Is(err, sql.ErrTxDone) { - // this is ok because whoever did finish the tx should have also written the error already. - return nil - } - return fmt.Errorf("failed to commit transaction: %w", err) - } - - return nil -} - -func rollback(tx *sqlx.Tx, err error) error { - if rerr := tx.Rollback(); rerr != nil { - if errors.Is(rerr, sql.ErrTxDone) { - return err - } - return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr) - } - - return err -} diff --git a/server/backend/sqlite/error.go b/server/backend/sqlite/error.go deleted file mode 100644 index 8476f640ccd3fb0728e5521e1e2db2d1ec268ce0..0000000000000000000000000000000000000000 --- a/server/backend/sqlite/error.go +++ /dev/null @@ -1,20 +0,0 @@ -package sqlite - -import ( - "errors" - "fmt" -) - -var ( - // ErrDuplicateKey is returned when a unique constraint is violated. - ErrDuplicateKey = errors.New("record already exists") - - // ErrNoRecord is returned when a record is not found. - ErrNoRecord = errors.New("record not found") - - // ErrRepoNotExist is returned when a repository does not exist. - ErrRepoNotExist = fmt.Errorf("repository does not exist") - - // ErrRepoExist is returned when a repository already exists. - ErrRepoExist = fmt.Errorf("repository already exists") -) diff --git a/server/backend/sqlite/hooks.go b/server/backend/sqlite/hooks.go deleted file mode 100644 index 972b3f31d9be55b4dd2547c82e4cbc797dd4110f..0000000000000000000000000000000000000000 --- a/server/backend/sqlite/hooks.go +++ /dev/null @@ -1,76 +0,0 @@ -package sqlite - -import ( - "io" - "sync" - - "github.com/charmbracelet/soft-serve/server/backend" -) - -// PostReceive is called by the git post-receive hook. -// -// It implements Hooks. -func (d *SqliteBackend) PostReceive(stdout io.Writer, stderr io.Writer, repo string, args []backend.HookArg) { - d.logger.Debug("post-receive hook called", "repo", repo, "args", args) -} - -// PreReceive is called by the git pre-receive hook. -// -// It implements Hooks. -func (d *SqliteBackend) PreReceive(stdout io.Writer, stderr io.Writer, repo string, args []backend.HookArg) { - d.logger.Debug("pre-receive hook called", "repo", repo, "args", args) -} - -// Update is called by the git update hook. -// -// It implements Hooks. -func (d *SqliteBackend) Update(stdout io.Writer, stderr io.Writer, repo string, arg backend.HookArg) { - d.logger.Debug("update hook called", "repo", repo, "arg", arg) -} - -// PostUpdate is called by the git post-update hook. -// -// It implements Hooks. -func (d *SqliteBackend) PostUpdate(stdout io.Writer, stderr io.Writer, repo string, args ...string) { - d.logger.Debug("post-update hook called", "repo", repo, "args", args) - - var wg sync.WaitGroup - - // Populate last-modified file. - wg.Add(1) - go func() { - defer wg.Done() - if err := populateLastModified(d, repo); err != nil { - d.logger.Error("error populating last-modified", "repo", repo, "err", err) - return - } - }() - - wg.Wait() -} - -func populateLastModified(d *SqliteBackend, repo string) error { - var rr *Repo - _rr, err := d.Repository(repo) - if err != nil { - return err - } - - if r, ok := _rr.(*Repo); ok { - rr = r - } else { - return ErrRepoNotExist - } - - r, err := rr.Open() - if err != nil { - return err - } - - c, err := r.LatestCommitTime() - if err != nil { - return err - } - - return rr.writeLastModified(c) -} diff --git a/server/backend/sqlite/repo.go b/server/backend/sqlite/repo.go deleted file mode 100644 index 34b5e6baa597d7271f61e30189c006105ce8aee0..0000000000000000000000000000000000000000 --- a/server/backend/sqlite/repo.go +++ /dev/null @@ -1,202 +0,0 @@ -package sqlite - -import ( - "bufio" - "context" - "os" - "path/filepath" - "sync" - "time" - - "github.com/charmbracelet/soft-serve/git" - "github.com/charmbracelet/soft-serve/server/backend" - "github.com/jmoiron/sqlx" -) - -var _ backend.Repository = (*Repo)(nil) - -// Repo is a Git repository with metadata stored in a SQLite database. -type Repo struct { - name string - path string - db *sqlx.DB - - // cache - // updatedAt is cached in "last-modified" file. - mu sync.Mutex - desc *string - projectName *string - isMirror *bool - isPrivate *bool - isHidden *bool -} - -// Description returns the repository's description. -// -// It implements backend.Repository. -func (r *Repo) Description() string { - r.mu.Lock() - defer r.mu.Unlock() - if r.desc != nil { - return *r.desc - } - - var desc string - if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error { - return tx.Get(&desc, "SELECT description FROM repo WHERE name = ?", r.name) - }); err != nil { - return "" - } - - r.desc = &desc - return desc -} - -// IsMirror returns whether the repository is a mirror. -// -// It implements backend.Repository. -func (r *Repo) IsMirror() bool { - r.mu.Lock() - defer r.mu.Unlock() - if r.isMirror != nil { - return *r.isMirror - } - - var mirror bool - if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error { - return tx.Get(&mirror, "SELECT mirror FROM repo WHERE name = ?", r.name) - }); err != nil { - return false - } - - r.isMirror = &mirror - return mirror -} - -// IsPrivate returns whether the repository is private. -// -// It implements backend.Repository. -func (r *Repo) IsPrivate() bool { - r.mu.Lock() - defer r.mu.Unlock() - if r.isPrivate != nil { - return *r.isPrivate - } - - var private bool - if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error { - return tx.Get(&private, "SELECT private FROM repo WHERE name = ?", r.name) - }); err != nil { - return false - } - - r.isPrivate = &private - return private -} - -// Name returns the repository's name. -// -// It implements backend.Repository. -func (r *Repo) Name() string { - return r.name -} - -// Open opens the repository. -// -// It implements backend.Repository. -func (r *Repo) Open() (*git.Repository, error) { - return git.Open(r.path) -} - -// ProjectName returns the repository's project name. -// -// It implements backend.Repository. -func (r *Repo) ProjectName() string { - r.mu.Lock() - defer r.mu.Unlock() - if r.projectName != nil { - return *r.projectName - } - - var name string - if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error { - return tx.Get(&name, "SELECT project_name FROM repo WHERE name = ?", r.name) - }); err != nil { - return "" - } - - r.projectName = &name - return name -} - -// IsHidden returns whether the repository is hidden. -// -// It implements backend.Repository. -func (r *Repo) IsHidden() bool { - r.mu.Lock() - defer r.mu.Unlock() - if r.isHidden != nil { - return *r.isHidden - } - - var hidden bool - if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error { - return tx.Get(&hidden, "SELECT hidden FROM repo WHERE name = ?", r.name) - }); err != nil { - return false - } - - r.isHidden = &hidden - return hidden -} - -// UpdatedAt returns the repository's last update time. -func (r *Repo) UpdatedAt() time.Time { - var updatedAt time.Time - - // Try to read the last modified time from the info directory. - if t, err := readOneline(filepath.Join(r.path, "info", "last-modified")); err == nil { - if t, err := time.Parse(time.RFC3339, t); err == nil { - return t - } - } - - rr, err := git.Open(r.path) - if err == nil { - t, err := rr.LatestCommitTime() - if err == nil { - updatedAt = t - } - } - - if updatedAt.IsZero() { - if err := wrapTx(r.db, context.Background(), func(tx *sqlx.Tx) error { - return tx.Get(&updatedAt, "SELECT updated_at FROM repo WHERE name = ?", r.name) - }); err != nil { - return time.Time{} - } - } - - return updatedAt -} - -func (r *Repo) writeLastModified(t time.Time) error { - fp := filepath.Join(r.path, "info", "last-modified") - if err := os.MkdirAll(filepath.Dir(fp), os.ModePerm); err != nil { - return err - } - - return os.WriteFile(fp, []byte(t.Format(time.RFC3339)), os.ModePerm) -} - -func readOneline(path string) (string, error) { - f, err := os.Open(path) - if err != nil { - return "", err - } - - defer f.Close() // nolint: errcheck - s := bufio.NewScanner(f) - s.Scan() - return s.Text(), s.Err() -} diff --git a/server/backend/sqlite/sql.go b/server/backend/sqlite/sql.go deleted file mode 100644 index 34edd17f72485256186c2f7267e14da0661c437b..0000000000000000000000000000000000000000 --- a/server/backend/sqlite/sql.go +++ /dev/null @@ -1,61 +0,0 @@ -package sqlite - -var ( - sqlCreateSettingsTable = `CREATE TABLE IF NOT EXISTS settings ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - key TEXT NOT NULL UNIQUE, - value TEXT NOT NULL, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME NOT NULL - );` - - sqlCreateUserTable = `CREATE TABLE IF NOT EXISTS user ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT UNIQUE, - admin BOOLEAN NOT NULL, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME NOT NULL - );` - - sqlCreatePublicKeyTable = `CREATE TABLE IF NOT EXISTS public_key ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL, - public_key TEXT NOT NULL UNIQUE, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME NOT NULL, - UNIQUE (user_id, public_key), - CONSTRAINT user_id_fk - FOREIGN KEY(user_id) REFERENCES user(id) - ON DELETE CASCADE - ON UPDATE CASCADE - );` - - sqlCreateRepoTable = `CREATE TABLE IF NOT EXISTS repo ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - project_name TEXT NOT NULL, - description TEXT NOT NULL, - private BOOLEAN NOT NULL, - mirror BOOLEAN NOT NULL, - hidden BOOLEAN NOT NULL, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME NOT NULL - );` - - sqlCreateCollabTable = `CREATE TABLE IF NOT EXISTS collab ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL, - repo_id INTEGER NOT NULL, - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at DATETIME NOT NULL, - UNIQUE (user_id, repo_id), - CONSTRAINT user_id_fk - FOREIGN KEY(user_id) REFERENCES user(id) - ON DELETE CASCADE - ON UPDATE CASCADE, - CONSTRAINT repo_id_fk - FOREIGN KEY(repo_id) REFERENCES repo(id) - ON DELETE CASCADE - ON UPDATE CASCADE - );` -) diff --git a/server/backend/sqlite/sqlite.go b/server/backend/sqlite/sqlite.go deleted file mode 100644 index 3273373ea54e180ce3821f9178f2a8d81e5c6031..0000000000000000000000000000000000000000 --- a/server/backend/sqlite/sqlite.go +++ /dev/null @@ -1,649 +0,0 @@ -package sqlite - -import ( - "context" - "errors" - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/charmbracelet/log" - "github.com/charmbracelet/soft-serve/git" - "github.com/charmbracelet/soft-serve/server/backend" - "github.com/charmbracelet/soft-serve/server/config" - "github.com/charmbracelet/soft-serve/server/hooks" - "github.com/charmbracelet/soft-serve/server/utils" - lru "github.com/hashicorp/golang-lru/v2" - "github.com/jmoiron/sqlx" - _ "modernc.org/sqlite" // sqlite driver -) - -// SqliteBackend is a backend that uses a SQLite database as a Soft Serve -// backend. -type SqliteBackend struct { //nolint: revive - cfg *config.Config - ctx context.Context - dp string - db *sqlx.DB - logger *log.Logger - - // Repositories cache - cache *cache -} - -var _ backend.Backend = (*SqliteBackend)(nil) - -func (d *SqliteBackend) reposPath() string { - return filepath.Join(d.dp, "repos") -} - -// NewSqliteBackend creates a new SqliteBackend. -func NewSqliteBackend(ctx context.Context) (*SqliteBackend, error) { - cfg := config.FromContext(ctx) - dataPath := cfg.DataPath - if err := os.MkdirAll(dataPath, os.ModePerm); err != nil { - return nil, err - } - - db, err := sqlx.Connect("sqlite", filepath.Join(dataPath, "soft-serve.db"+ - "?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)")) - if err != nil { - return nil, err - } - - d := &SqliteBackend{ - cfg: cfg, - ctx: ctx, - dp: dataPath, - db: db, - logger: log.FromContext(ctx).WithPrefix("sqlite"), - } - - // Set up LRU cache with size 1000 - d.cache = newCache(d, 1000) - - if err := d.init(); err != nil { - return nil, err - } - - if err := d.db.Ping(); err != nil { - return nil, err - } - - return d, d.initRepos() -} - -// WithContext returns a copy of SqliteBackend with the given context. -func (d SqliteBackend) WithContext(ctx context.Context) backend.Backend { - d.ctx = ctx - return &d -} - -// AllowKeyless returns whether or not keyless access is allowed. -// -// It implements backend.Backend. -func (d *SqliteBackend) AllowKeyless() bool { - var allow bool - if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - return tx.Get(&allow, "SELECT value FROM settings WHERE key = ?;", "allow_keyless") - }); err != nil { - return false - } - - return allow -} - -// AnonAccess returns the level of anonymous access. -// -// It implements backend.Backend. -func (d *SqliteBackend) AnonAccess() backend.AccessLevel { - var level string - if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - return tx.Get(&level, "SELECT value FROM settings WHERE key = ?;", "anon_access") - }); err != nil { - return backend.NoAccess - } - - return backend.ParseAccessLevel(level) -} - -// SetAllowKeyless sets whether or not keyless access is allowed. -// -// It implements backend.Backend. -func (d *SqliteBackend) SetAllowKeyless(allow bool) error { - return wrapDbErr( - wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - _, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", allow, "allow_keyless") - return err - }), - ) -} - -// SetAnonAccess sets the level of anonymous access. -// -// It implements backend.Backend. -func (d *SqliteBackend) SetAnonAccess(level backend.AccessLevel) error { - return wrapDbErr( - wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - _, err := tx.Exec("UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = ?;", level.String(), "anon_access") - return err - }), - ) -} - -// CreateRepository creates a new repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) CreateRepository(name string, opts backend.RepositoryOptions) (backend.Repository, error) { - name = utils.SanitizeRepo(name) - if err := utils.ValidateRepo(name); err != nil { - return nil, err - } - - repo := name + ".git" - rp := filepath.Join(d.reposPath(), repo) - - if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - if _, err := tx.Exec(`INSERT INTO repo (name, project_name, description, private, mirror, hidden, updated_at) - VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP);`, - name, opts.ProjectName, opts.Description, opts.Private, opts.Mirror, opts.Hidden); err != nil { - return err - } - - _, err := git.Init(rp, true) - if err != nil { - d.logger.Debug("failed to create repository", "err", err) - return err - } - - return nil - }); err != nil { - d.logger.Debug("failed to create repository in database", "err", err) - return nil, wrapDbErr(err) - } - - r := &Repo{ - name: name, - path: rp, - db: d.db, - } - - // Set cache - d.cache.Set(name, r) - - return r, d.initRepo(name) -} - -// ImportRepository imports a repository from remote. -func (d *SqliteBackend) ImportRepository(name string, remote string, opts backend.RepositoryOptions) (backend.Repository, error) { - name = utils.SanitizeRepo(name) - if err := utils.ValidateRepo(name); err != nil { - return nil, err - } - - repo := name + ".git" - rp := filepath.Join(d.reposPath(), repo) - - if _, err := os.Stat(rp); err == nil || os.IsExist(err) { - return nil, ErrRepoExist - } - - copts := git.CloneOptions{ - Bare: true, - Mirror: opts.Mirror, - Quiet: true, - CommandOptions: git.CommandOptions{ - Timeout: -1, - Context: d.ctx, - Envs: []string{ - fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`, - filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"), - d.cfg.SSH.ClientKeyPath, - ), - }, - }, - // Timeout: time.Hour, - } - - if err := git.Clone(remote, rp, copts); err != nil { - d.logger.Error("failed to clone repository", "err", err, "mirror", opts.Mirror, "remote", remote, "path", rp) - // Cleanup the mess! - if rerr := os.RemoveAll(rp); rerr != nil { - err = errors.Join(err, rerr) - } - return nil, err - } - - return d.CreateRepository(name, opts) -} - -// DeleteRepository deletes a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) DeleteRepository(name string) error { - name = utils.SanitizeRepo(name) - repo := name + ".git" - rp := filepath.Join(d.reposPath(), repo) - - return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - // Delete repo from cache - defer d.cache.Delete(name) - - if _, err := tx.Exec("DELETE FROM repo WHERE name = ?;", name); err != nil { - return err - } - - return os.RemoveAll(rp) - }) -} - -// RenameRepository renames a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) RenameRepository(oldName string, newName string) error { - oldName = utils.SanitizeRepo(oldName) - if err := utils.ValidateRepo(oldName); err != nil { - return err - } - - newName = utils.SanitizeRepo(newName) - if err := utils.ValidateRepo(newName); err != nil { - return err - } - oldRepo := oldName + ".git" - newRepo := newName + ".git" - op := filepath.Join(d.reposPath(), oldRepo) - np := filepath.Join(d.reposPath(), newRepo) - if _, err := os.Stat(op); err != nil { - return ErrRepoNotExist - } - - if _, err := os.Stat(np); err == nil { - return ErrRepoExist - } - - if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - // Delete cache - defer d.cache.Delete(oldName) - - _, err := tx.Exec("UPDATE repo SET name = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", newName, oldName) - if err != nil { - return err - } - - // Make sure the new repository parent directory exists. - if err := os.MkdirAll(filepath.Dir(np), os.ModePerm); err != nil { - return err - } - - if err := os.Rename(op, np); err != nil { - return err - } - - return nil - }); err != nil { - return wrapDbErr(err) - } - - return nil -} - -// Repositories returns a list of all repositories. -// -// It implements backend.Backend. -func (d *SqliteBackend) Repositories() ([]backend.Repository, error) { - repos := make([]backend.Repository, 0) - - if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - rows, err := tx.Query("SELECT name FROM repo") - if err != nil { - return err - } - - defer rows.Close() // nolint: errcheck - for rows.Next() { - var name string - if err := rows.Scan(&name); err != nil { - return err - } - - if r, ok := d.cache.Get(name); ok && r != nil { - repos = append(repos, r) - continue - } - - r := &Repo{ - name: name, - path: filepath.Join(d.reposPath(), name+".git"), - db: d.db, - } - - // Cache repositories - d.cache.Set(name, r) - - repos = append(repos, r) - } - - return nil - }); err != nil { - return nil, wrapDbErr(err) - } - - return repos, nil -} - -// Repository returns a repository by name. -// -// It implements backend.Backend. -func (d *SqliteBackend) Repository(repo string) (backend.Repository, error) { - repo = utils.SanitizeRepo(repo) - - if r, ok := d.cache.Get(repo); ok && r != nil { - return r, nil - } - - rp := filepath.Join(d.reposPath(), repo+".git") - if _, err := os.Stat(rp); err != nil { - return nil, os.ErrNotExist - } - - var count int - if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - return tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo) - }); err != nil { - return nil, wrapDbErr(err) - } - - if count == 0 { - d.logger.Warn("repository exists but not found in database", "repo", repo) - return nil, ErrRepoNotExist - } - - r := &Repo{ - name: repo, - path: rp, - db: d.db, - } - - // Add to cache - d.cache.Set(repo, r) - - return r, nil -} - -// Description returns the description of a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) Description(repo string) (string, error) { - r, err := d.Repository(repo) - if err != nil { - return "", err - } - - return r.Description(), nil -} - -// IsMirror returns true if the repository is a mirror. -// -// It implements backend.Backend. -func (d *SqliteBackend) IsMirror(repo string) (bool, error) { - r, err := d.Repository(repo) - if err != nil { - return false, err - } - - return r.IsMirror(), nil -} - -// IsPrivate returns true if the repository is private. -// -// It implements backend.Backend. -func (d *SqliteBackend) IsPrivate(repo string) (bool, error) { - r, err := d.Repository(repo) - if err != nil { - return false, err - } - - return r.IsPrivate(), nil -} - -// IsHidden returns true if the repository is hidden. -// -// It implements backend.Backend. -func (d *SqliteBackend) IsHidden(repo string) (bool, error) { - r, err := d.Repository(repo) - if err != nil { - return false, err - } - - return r.IsHidden(), nil -} - -// SetHidden sets the hidden flag of a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) SetHidden(repo string, hidden bool) error { - repo = utils.SanitizeRepo(repo) - - // Delete cache - d.cache.Delete(repo) - - return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - var count int - if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil { - return err - } - if count == 0 { - return ErrRepoNotExist - } - _, err := tx.Exec("UPDATE repo SET hidden = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?;", hidden, repo) - return err - })) -} - -// ProjectName returns the project name of a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) ProjectName(repo string) (string, error) { - r, err := d.Repository(repo) - if err != nil { - return "", err - } - - return r.ProjectName(), nil -} - -// SetDescription sets the description of a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) SetDescription(repo string, desc string) error { - repo = utils.SanitizeRepo(repo) - - // Delete cache - d.cache.Delete(repo) - - return wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - var count int - if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil { - return err - } - if count == 0 { - return ErrRepoNotExist - } - _, err := tx.Exec("UPDATE repo SET description = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?", desc, repo) - return err - }) -} - -// SetPrivate sets the private flag of a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) SetPrivate(repo string, private bool) error { - repo = utils.SanitizeRepo(repo) - - // Delete cache - d.cache.Delete(repo) - - return wrapDbErr( - wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - var count int - if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil { - return err - } - if count == 0 { - return ErrRepoNotExist - } - _, err := tx.Exec("UPDATE repo SET private = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?", private, repo) - return err - }), - ) -} - -// SetProjectName sets the project name of a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) SetProjectName(repo string, name string) error { - repo = utils.SanitizeRepo(repo) - - // Delete cache - d.cache.Delete(repo) - - return wrapDbErr( - wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - var count int - if err := tx.Get(&count, "SELECT COUNT(*) FROM repo WHERE name = ?", repo); err != nil { - return err - } - if count == 0 { - return ErrRepoNotExist - } - _, err := tx.Exec("UPDATE repo SET project_name = ?, updated_at = CURRENT_TIMESTAMP WHERE name = ?", name, repo) - return err - }), - ) -} - -// AddCollaborator adds a collaborator to a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) AddCollaborator(repo string, username string) error { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { - return err - } - - repo = utils.SanitizeRepo(repo) - return wrapDbErr(wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - _, err := tx.Exec(`INSERT INTO collab (user_id, repo_id, updated_at) - VALUES ( - (SELECT id FROM user WHERE username = ?), - (SELECT id FROM repo WHERE name = ?), - CURRENT_TIMESTAMP - );`, username, repo) - return err - }), - ) -} - -// Collaborators returns a list of collaborators for a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) Collaborators(repo string) ([]string, error) { - repo = utils.SanitizeRepo(repo) - var users []string - if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - return tx.Select(&users, `SELECT user.username FROM user - INNER JOIN collab ON user.id = collab.user_id - INNER JOIN repo ON repo.id = collab.repo_id - WHERE repo.name = ?`, repo) - }); err != nil { - return nil, wrapDbErr(err) - } - - return users, nil -} - -// IsCollaborator returns true if the user is a collaborator of the repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) IsCollaborator(repo string, username string) (bool, error) { - repo = utils.SanitizeRepo(repo) - var count int - if err := wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - return tx.Get(&count, `SELECT COUNT(*) FROM user - INNER JOIN collab ON user.id = collab.user_id - INNER JOIN repo ON repo.id = collab.repo_id - WHERE repo.name = ? AND user.username = ?`, repo, username) - }); err != nil { - return false, wrapDbErr(err) - } - - return count > 0, nil -} - -// RemoveCollaborator removes a collaborator from a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) RemoveCollaborator(repo string, username string) error { - repo = utils.SanitizeRepo(repo) - return wrapDbErr( - wrapTx(d.db, d.ctx, func(tx *sqlx.Tx) error { - _, err := tx.Exec(`DELETE FROM collab - WHERE user_id = (SELECT id FROM user WHERE username = ?) - AND repo_id = (SELECT id FROM repo WHERE name = ?)`, username, repo) - return err - }), - ) -} - -func (d *SqliteBackend) initRepo(repo string) error { - return hooks.GenerateHooks(d.ctx, d.cfg, repo) -} - -func (d *SqliteBackend) initRepos() error { - repos, err := d.Repositories() - if err != nil { - return err - } - - for _, repo := range repos { - if err := d.initRepo(repo.Name()); err != nil { - return err - } - } - - return nil -} - -// TODO: implement a caching interface. -type cache struct { - b *SqliteBackend - repos *lru.Cache[string, *Repo] -} - -func newCache(b *SqliteBackend, size int) *cache { - if size <= 0 { - size = 1 - } - c := &cache{b: b} - cache, _ := lru.New[string, *Repo](size) - c.repos = cache - return c -} - -func (c *cache) Get(repo string) (*Repo, bool) { - return c.repos.Get(repo) -} - -func (c *cache) Set(repo string, r *Repo) { - c.repos.Add(repo, r) -} - -func (c *cache) Delete(repo string) { - c.repos.Remove(repo) -} - -func (c *cache) Len() int { - return c.repos.Len() -} diff --git a/server/backend/sqlite/user.go b/server/backend/sqlite/user.go deleted file mode 100644 index 5e977a6a7af677b27618eb4baa20f6904f578e3e..0000000000000000000000000000000000000000 --- a/server/backend/sqlite/user.go +++ /dev/null @@ -1,365 +0,0 @@ -package sqlite - -import ( - "context" - "strings" - - "github.com/charmbracelet/soft-serve/server/backend" - "github.com/charmbracelet/soft-serve/server/utils" - "github.com/jmoiron/sqlx" - "golang.org/x/crypto/ssh" -) - -// User represents a user. -type User struct { - username string - db *sqlx.DB -} - -var _ backend.User = (*User)(nil) - -// IsAdmin returns whether the user is an admin. -// -// It implements backend.User. -func (u *User) IsAdmin() bool { - var admin bool - if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error { - return tx.Get(&admin, "SELECT admin FROM user WHERE username = ?", u.username) - }); err != nil { - return false - } - - return admin -} - -// PublicKeys returns the user's public keys. -// -// It implements backend.User. -func (u *User) PublicKeys() []ssh.PublicKey { - var keys []ssh.PublicKey - if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error { - var keyStrings []string - if err := tx.Select(&keyStrings, `SELECT public_key - FROM public_key - INNER JOIN user ON user.id = public_key.user_id - WHERE user.username = ? - ORDER BY public_key.id asc;`, u.username); err != nil { - return err - } - - for _, keyString := range keyStrings { - key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString)) - if err != nil { - return err - } - keys = append(keys, key) - } - - return nil - }); err != nil { - return nil - } - - return keys -} - -// Username returns the user's username. -// -// It implements backend.User. -func (u *User) Username() string { - return u.username -} - -// AccessLevel returns the access level of a user for a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) AccessLevel(repo string, username string) backend.AccessLevel { - anon := d.AnonAccess() - user, _ := d.User(username) - // If the user is an admin, they have admin access. - if user != nil && user.IsAdmin() { - return backend.AdminAccess - } - - // If the repository exists, check if the user is a collaborator. - r, _ := d.Repository(repo) - if r != nil { - // If the user is a collaborator, they have read/write access. - isCollab, _ := d.IsCollaborator(repo, username) - if isCollab { - if anon > backend.ReadWriteAccess { - return anon - } - return backend.ReadWriteAccess - } - - // If the repository is private, the user has no access. - if r.IsPrivate() { - return backend.NoAccess - } - - // Otherwise, the user has read-only access. - return backend.ReadOnlyAccess - } - - if user != nil { - // If the repository doesn't exist, the user has read/write access. - if anon > backend.ReadWriteAccess { - return anon - } - - return backend.ReadWriteAccess - } - - // If the user doesn't exist, give them the anonymous access level. - return anon -} - -// AccessLevelByPublicKey returns the access level of a user's public key for a repository. -// -// It implements backend.Backend. -func (d *SqliteBackend) AccessLevelByPublicKey(repo string, pk ssh.PublicKey) backend.AccessLevel { - for _, k := range d.cfg.AdminKeys() { - if backend.KeysEqual(pk, k) { - return backend.AdminAccess - } - } - - user, _ := d.UserByPublicKey(pk) - if user != nil { - return d.AccessLevel(repo, user.Username()) - } - - return d.AccessLevel(repo, "") -} - -// AddPublicKey adds a public key to a user. -// -// It implements backend.Backend. -func (d *SqliteBackend) AddPublicKey(username string, pk ssh.PublicKey) error { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { - return err - } - - return wrapDbErr( - wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - var userID int - if err := tx.Get(&userID, "SELECT id FROM user WHERE username = ?", username); err != nil { - return err - } - - _, err := tx.Exec(`INSERT INTO public_key (user_id, public_key, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP);`, userID, backend.MarshalAuthorizedKey(pk)) - return err - }), - ) -} - -// CreateUser creates a new user. -// -// It implements backend.Backend. -func (d *SqliteBackend) CreateUser(username string, opts backend.UserOptions) (backend.User, error) { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { - return nil, err - } - - var user *User - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - stmt, err := tx.Prepare("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);") - if err != nil { - return err - } - - defer stmt.Close() // nolint: errcheck - r, err := stmt.Exec(username, opts.Admin) - if err != nil { - return err - } - - if len(opts.PublicKeys) > 0 { - userID, err := r.LastInsertId() - if err != nil { - d.logger.Error("error getting last insert id") - return err - } - - for _, pk := range opts.PublicKeys { - stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at) - VALUES (?, ?, CURRENT_TIMESTAMP);`) - if err != nil { - return err - } - - defer stmt.Close() // nolint: errcheck - if _, err := stmt.Exec(userID, backend.MarshalAuthorizedKey(pk)); err != nil { - return err - } - } - } - - user = &User{ - db: d.db, - username: username, - } - return nil - }); err != nil { - return nil, wrapDbErr(err) - } - - return user, nil -} - -// DeleteUser deletes a user. -// -// It implements backend.Backend. -func (d *SqliteBackend) DeleteUser(username string) error { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { - return err - } - - return wrapDbErr( - wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - _, err := tx.Exec("DELETE FROM user WHERE username = ?", username) - return err - }), - ) -} - -// RemovePublicKey removes a public key from a user. -// -// It implements backend.Backend. -func (d *SqliteBackend) RemovePublicKey(username string, pk ssh.PublicKey) error { - return wrapDbErr( - wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - _, err := tx.Exec(`DELETE FROM public_key - WHERE user_id = (SELECT id FROM user WHERE username = ?) - AND public_key = ?;`, username, backend.MarshalAuthorizedKey(pk)) - return err - }), - ) -} - -// ListPublicKeys lists the public keys of a user. -func (d *SqliteBackend) ListPublicKeys(username string) ([]ssh.PublicKey, error) { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { - return nil, err - } - - keys := make([]ssh.PublicKey, 0) - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - var keyStrings []string - if err := tx.Select(&keyStrings, `SELECT public_key - FROM public_key - INNER JOIN user ON user.id = public_key.user_id - WHERE user.username = ?;`, username); err != nil { - return err - } - - for _, keyString := range keyStrings { - key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString)) - if err != nil { - return err - } - keys = append(keys, key) - } - - return nil - }); err != nil { - return nil, wrapDbErr(err) - } - - return keys, nil -} - -// SetUsername sets the username of a user. -// -// It implements backend.Backend. -func (d *SqliteBackend) SetUsername(username string, newUsername string) error { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { - return err - } - - return wrapDbErr( - wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - _, err := tx.Exec("UPDATE user SET username = ? WHERE username = ?", newUsername, username) - return err - }), - ) -} - -// SetAdmin sets the admin flag of a user. -// -// It implements backend.Backend. -func (d *SqliteBackend) SetAdmin(username string, admin bool) error { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { - return err - } - - return wrapDbErr( - wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - _, err := tx.Exec("UPDATE user SET admin = ? WHERE username = ?", admin, username) - return err - }), - ) -} - -// User finds a user by username. -// -// It implements backend.Backend. -func (d *SqliteBackend) User(username string) (backend.User, error) { - username = strings.ToLower(username) - if err := utils.ValidateUsername(username); err != nil { - return nil, err - } - - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - return tx.Get(&username, "SELECT username FROM user WHERE username = ?", username) - }); err != nil { - return nil, wrapDbErr(err) - } - - return &User{ - db: d.db, - username: username, - }, nil -} - -// UserByPublicKey finds a user by public key. -// -// It implements backend.Backend. -func (d *SqliteBackend) UserByPublicKey(pk ssh.PublicKey) (backend.User, error) { - var username string - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - return tx.Get(&username, `SELECT user.username - FROM public_key - INNER JOIN user ON user.id = public_key.user_id - WHERE public_key.public_key = ?;`, backend.MarshalAuthorizedKey(pk)) - }); err != nil { - return nil, wrapDbErr(err) - } - - return &User{ - db: d.db, - username: username, - }, nil -} - -// Users returns all users. -// -// It implements backend.Backend. -func (d *SqliteBackend) Users() ([]string, error) { - var users []string - if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error { - return tx.Select(&users, "SELECT username FROM user") - }); err != nil { - return nil, wrapDbErr(err) - } - - return users, nil -} diff --git a/server/backend/user.go b/server/backend/user.go index b0b5f1115dbe3760dbf6972c0eceaeb7f5fbf79a..edefaf686e55eea1b08e43e9fb9a5809ba4d61b7 100644 --- a/server/backend/user.go +++ b/server/backend/user.go @@ -1,55 +1,289 @@ package backend import ( + "context" + "strings" + + "github.com/charmbracelet/soft-serve/server/access" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/models" + "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/sshutils" + "github.com/charmbracelet/soft-serve/server/utils" "golang.org/x/crypto/ssh" ) -// User is an interface representing a user. -type User interface { - // Username returns the user's username. - Username() string - // IsAdmin returns whether the user is an admin. - IsAdmin() bool - // PublicKeys returns the user's public keys. - PublicKeys() []ssh.PublicKey -} - -// UserAccess is an interface that handles user access to repositories. -type UserAccess interface { - // AccessLevel returns the access level of the username to the repository. - AccessLevel(repo string, username string) AccessLevel - // AccessLevelByPublicKey returns the access level of the public key to the repository. - AccessLevelByPublicKey(repo string, pk ssh.PublicKey) AccessLevel -} - -// UserStore is an interface for managing users. -type UserStore interface { - // User finds the given user. - User(username string) (User, error) - // UserByPublicKey finds the user with the given public key. - UserByPublicKey(pk ssh.PublicKey) (User, error) - // Users returns a list of all users. - Users() ([]string, error) - // CreateUser creates a new user. - CreateUser(username string, opts UserOptions) (User, error) - // DeleteUser deletes a user. - DeleteUser(username string) error - // SetUsername sets the username of the user. - SetUsername(oldUsername string, newUsername string) error - // SetAdmin sets whether the user is an admin. - SetAdmin(username string, admin bool) error - // AddPublicKey adds a public key to the user. - AddPublicKey(username string, pk ssh.PublicKey) error - // RemovePublicKey removes a public key from the user. - RemovePublicKey(username string, pk ssh.PublicKey) error - // ListPublicKeys lists the public keys of the user. - ListPublicKeys(username string) ([]ssh.PublicKey, error) -} - -// UserOptions are options for creating a user. -type UserOptions struct { - // Admin is whether the user is an admin. - Admin bool - // PublicKeys are the user's public keys. - PublicKeys []ssh.PublicKey +// AccessLevel returns the access level of a user for a repository. +// +// It implements backend.Backend. +func (d *Backend) AccessLevel(ctx context.Context, repo string, username string) access.AccessLevel { + anon := d.AnonAccess(ctx) + user, _ := d.User(ctx, username) + // If the user is an admin, they have admin access. + if user != nil && user.IsAdmin() { + return access.AdminAccess + } + + // If the repository exists, check if the user is a collaborator. + r, _ := d.Repository(ctx, repo) + if r != nil { + // If the user is a collaborator, they have read/write access. + isCollab, _ := d.IsCollaborator(ctx, repo, username) + if isCollab { + if anon > access.ReadWriteAccess { + return anon + } + return access.ReadWriteAccess + } + + // If the repository is private, the user has no access. + if r.IsPrivate() { + return access.NoAccess + } + + // Otherwise, the user has read-only access. + return access.ReadOnlyAccess + } + + if user != nil { + // If the repository doesn't exist, the user has read/write access. + if anon > access.ReadWriteAccess { + return anon + } + + return access.ReadWriteAccess + } + + // If the user doesn't exist, give them the anonymous access level. + return anon +} + +// AccessLevelByPublicKey returns the access level of a user's public key for a repository. +// +// It implements backend.Backend. +func (d *Backend) AccessLevelByPublicKey(ctx context.Context, repo string, pk ssh.PublicKey) access.AccessLevel { + for _, k := range d.cfg.AdminKeys() { + if sshutils.KeysEqual(pk, k) { + return access.AdminAccess + } + } + + user, _ := d.UserByPublicKey(ctx, pk) + if user != nil { + return d.AccessLevel(ctx, repo, user.Username()) + } + + return d.AccessLevel(ctx, repo, "") +} + +// User finds a user by username. +// +// It implements backend.Backend. +func (d *Backend) User(ctx context.Context, username string) (proto.User, error) { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return nil, err + } + + var m models.User + var pks []ssh.PublicKey + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + m, err = d.store.FindUserByUsername(ctx, tx, username) + if err != nil { + return err + } + + pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID) + return err + }); err != nil { + return nil, db.WrapError(err) + } + + return &user{ + user: m, + publicKeys: pks, + }, nil +} + +// UserByPublicKey finds a user by public key. +// +// It implements backend.Backend. +func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (proto.User, error) { + var m models.User + var pks []ssh.PublicKey + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + m, err = d.store.FindUserByPublicKey(ctx, tx, pk) + if err != nil { + return err + } + + pks, err = d.store.ListPublicKeysByUserID(ctx, tx, m.ID) + return err + }); err != nil { + return nil, db.WrapError(err) + } + + return &user{ + user: m, + publicKeys: pks, + }, nil +} + +// Users returns all users. +// +// It implements backend.Backend. +func (d *Backend) Users(ctx context.Context) ([]string, error) { + var users []string + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + ms, err := d.store.GetAllUsers(ctx, tx) + if err != nil { + return err + } + + for _, m := range ms { + users = append(users, m.Username) + } + + return nil + }); err != nil { + return nil, db.WrapError(err) + } + + return users, nil +} + +// AddPublicKey adds a public key to a user. +// +// It implements backend.Backend. +func (d *Backend) AddPublicKey(ctx context.Context, username string, pk ssh.PublicKey) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.AddPublicKeyByUsername(ctx, tx, username, pk) + }), + ) +} + +// CreateUser creates a new user. +// +// It implements backend.Backend. +func (d *Backend) CreateUser(ctx context.Context, username string, opts proto.UserOptions) (proto.User, error) { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return nil, err + } + + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.CreateUser(ctx, tx, username, opts.Admin, opts.PublicKeys) + }); err != nil { + return nil, db.WrapError(err) + } + + return d.User(ctx, username) +} + +// DeleteUser deletes a user. +// +// It implements backend.Backend. +func (d *Backend) DeleteUser(ctx context.Context, username string) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.DeleteUserByUsername(ctx, tx, username) + }), + ) +} + +// RemovePublicKey removes a public key from a user. +// +// It implements backend.Backend. +func (d *Backend) RemovePublicKey(ctx context.Context, username string, pk ssh.PublicKey) error { + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.RemovePublicKeyByUsername(ctx, tx, username, pk) + }), + ) +} + +// ListPublicKeys lists the public keys of a user. +func (d *Backend) ListPublicKeys(ctx context.Context, username string) ([]ssh.PublicKey, error) { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return nil, err + } + + var keys []ssh.PublicKey + if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error { + var err error + keys, err = d.store.ListPublicKeysByUsername(ctx, tx, username) + return err + }); err != nil { + return nil, db.WrapError(err) + } + + return keys, nil +} + +// SetUsername sets the username of a user. +// +// It implements backend.Backend. +func (d *Backend) SetUsername(ctx context.Context, username string, newUsername string) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.SetUsernameByUsername(ctx, tx, username, newUsername) + }), + ) +} + +// SetAdmin sets the admin flag of a user. +// +// It implements backend.Backend. +func (d *Backend) SetAdmin(ctx context.Context, username string, admin bool) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + return db.WrapError( + d.db.TransactionContext(ctx, func(tx *db.Tx) error { + return d.store.SetAdminByUsername(ctx, tx, username, admin) + }), + ) +} + +type user struct { + user models.User + publicKeys []ssh.PublicKey +} + +var _ proto.User = (*user)(nil) + +// IsAdmin implements store.User +func (u *user) IsAdmin() bool { + return u.user.Admin +} + +// PublicKeys implements store.User +func (u *user) PublicKeys() []ssh.PublicKey { + return u.publicKeys +} + +// Username implements store.User +func (u *user) Username() string { + return u.user.Username } diff --git a/server/backend/utils.go b/server/backend/utils.go index 03d3cccda37c7010e749bd5203d15f20ea2a9f65..024ba8af228a6741d1e957580f90ae722c05584f 100644 --- a/server/backend/utils.go +++ b/server/backend/utils.go @@ -2,11 +2,12 @@ package backend import ( "github.com/charmbracelet/soft-serve/git" + "github.com/charmbracelet/soft-serve/server/proto" ) // LatestFile returns the contents of the latest file at the specified path in // the repository and its file path. -func LatestFile(r Repository, pattern string) (string, string, error) { +func LatestFile(r proto.Repository, pattern string) (string, string, error) { repo, err := r.Open() if err != nil { return "", "", err @@ -15,7 +16,7 @@ func LatestFile(r Repository, pattern string) (string, string, error) { } // Readme returns the repository's README. -func Readme(r Repository) (readme string, path string, err error) { +func Readme(r proto.Repository) (readme string, path string, err error) { pattern := "[rR][eE][aA][dD][mM][eE]*" readme, path, err = LatestFile(r, pattern) return diff --git a/server/cmd/set_username.go b/server/cmd/set_username.go deleted file mode 100644 index 152403a57c68fa18035f7392b2bb67e7b06c977d..0000000000000000000000000000000000000000 --- a/server/cmd/set_username.go +++ /dev/null @@ -1,22 +0,0 @@ -package cmd - -import "github.com/spf13/cobra" - -func setUsernameCommand() *cobra.Command { - cmd := &cobra.Command{ - Use: "set-username USERNAME", - Short: "Set your username", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - cfg, s := fromContext(cmd) - user, err := cfg.Backend.UserByPublicKey(s.PublicKey()) - if err != nil { - return err - } - - return cfg.Backend.SetUsername(user.Username(), args[0]) - }, - } - - return cmd -} diff --git a/server/config/config.go b/server/config/config.go index c8b9fa9bf159e24edcb4ce52f89374a66d6eb925..d815dcc465bdfaa6e0529863cdd6b5d3773a432f 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -1,17 +1,15 @@ package config import ( - "context" - "errors" "fmt" "os" "path/filepath" + "strconv" "strings" "time" "github.com/caarlos0/env/v8" - "github.com/charmbracelet/log" - "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/sshutils" "golang.org/x/crypto/ssh" "gopkg.in/yaml.v3" ) @@ -82,6 +80,19 @@ type LogConfig struct { // Time format for the log `ts` field. // Format must be described in Golang's time format. TimeFormat string `env:"TIME_FORMAT" yaml:"time_format"` + + // Path to a file to write logs to. + // If not set, logs will be written to stderr. + Path string `env:"PATH" yaml:"path"` +} + +// DBConfig is the database connection configuration. +type DBConfig struct { + // Driver is the driver for the database. + Driver string `env:"DRIVER" yaml:"driver"` + + // DataSource is the database data source name. + DataSource string `env:"DATA_SOURCE" yaml:"data_source"` } // Config is the configuration for Soft Serve. @@ -104,14 +115,14 @@ type Config struct { // Log is the logger configuration. Log LogConfig `envPrefix:"LOG_" yaml:"log"` + // DB is the database configuration. + DB DBConfig `envPrefix:"DB_" yaml:"db"` + // InitialAdminKeys is a list of public keys that will be added to the list of admins. InitialAdminKeys []string `env:"INITIAL_ADMIN_KEYS" envSeparator:"\n" yaml:"initial_admin_keys"` // DataPath is the path to the directory where Soft Serve will store its data. DataPath string `env:"DATA_PATH" yaml:"-"` - - // Backend is the Git backend to use. - Backend backend.Backend `yaml:"-"` } // Environ returns the config as a list of environment variables. @@ -123,8 +134,8 @@ func (c *Config) Environ() []string { // TODO: do this dynamically envs = append(envs, []string{ - fmt.Sprintf("SOFT_SERVE_NAME=%s", c.Name), fmt.Sprintf("SOFT_SERVE_DATA_PATH=%s", c.DataPath), + fmt.Sprintf("SOFT_SERVE_NAME=%s", c.Name), fmt.Sprintf("SOFT_SERVE_INITIAL_ADMIN_KEYS=%s", strings.Join(c.InitialAdminKeys, "\n")), fmt.Sprintf("SOFT_SERVE_SSH_LISTEN_ADDR=%s", c.SSH.ListenAddr), fmt.Sprintf("SOFT_SERVE_SSH_PUBLIC_URL=%s", c.SSH.PublicURL), @@ -143,51 +154,50 @@ func (c *Config) Environ() []string { fmt.Sprintf("SOFT_SERVE_STATS_LISTEN_ADDR=%s", c.Stats.ListenAddr), fmt.Sprintf("SOFT_SERVE_LOG_FORMAT=%s", c.Log.Format), fmt.Sprintf("SOFT_SERVE_LOG_TIME_FORMAT=%s", c.Log.TimeFormat), + fmt.Sprintf("SOFT_SERVE_DB_DRIVER=%s", c.DB.Driver), + fmt.Sprintf("SOFT_SERVE_DB_DATA_SOURCE=%s", c.DB.DataSource), }...) return envs } -func parseConfig(path string) (*Config, error) { - dataPath := filepath.Dir(path) - cfg := &Config{ - Name: "Soft Serve", - DataPath: dataPath, - SSH: SSHConfig{ - ListenAddr: ":23231", - PublicURL: "ssh://localhost:23231", - KeyPath: filepath.Join("ssh", "soft_serve_host_ed25519"), - ClientKeyPath: filepath.Join("ssh", "soft_serve_client_ed25519"), - MaxTimeout: 0, - IdleTimeout: 0, - }, - Git: GitConfig{ - ListenAddr: ":9418", - MaxTimeout: 0, - IdleTimeout: 3, - MaxConnections: 32, - }, - HTTP: HTTPConfig{ - ListenAddr: ":23232", - PublicURL: "http://localhost:23232", - }, - Stats: StatsConfig{ - ListenAddr: "localhost:23233", - }, - Log: LogConfig{ - Format: "text", - TimeFormat: time.DateTime, - }, - } +// IsDebug returns true if the server is running in debug mode. +func IsDebug() bool { + debug, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_DEBUG")) + return debug +} +// IsVerbose returns true if the server is running in verbose mode. +// Verbose mode is only enabled if debug mode is enabled. +func IsVerbose() bool { + verbose, _ := strconv.ParseBool(os.Getenv("SOFT_SERVE_VERBOSE")) + return IsDebug() && verbose +} + +// parseFile parses the given file as a configuration file. +// The file must be in YAML format. +func parseFile(cfg *Config, path string) error { f, err := os.Open(path) - if err == nil { - defer f.Close() // nolint: errcheck - if err := yaml.NewDecoder(f).Decode(cfg); err != nil { - return cfg, fmt.Errorf("decode config: %w", err) - } + if err != nil { + return err + } + + defer f.Close() // nolint: errcheck + if err := yaml.NewDecoder(f).Decode(cfg); err != nil { + return fmt.Errorf("decode config: %w", err) } + return cfg.Validate() +} + +// ParseFile parses the config from the default file path. +// This also calls Validate() on the config. +func (c *Config) ParseFile() error { + return parseFile(c, c.ConfigPath()) +} + +// parseEnv parses the environment variables as a configuration file. +func parseEnv(cfg *Config) error { // Merge initial admin keys from both config file and environment variables. initialAdminKeys := append([]string{}, cfg.InitialAdminKeys...) @@ -195,7 +205,7 @@ func parseConfig(path string) (*Config, error) { if err := env.ParseWithOptions(cfg, env.Options{ Prefix: "SOFT_SERVE_", }); err != nil { - return cfg, fmt.Errorf("parse environment variables: %w", err) + return fmt.Errorf("parse environment variables: %w", err) } // Merge initial admin keys from environment variables. @@ -203,80 +213,108 @@ func parseConfig(path string) (*Config, error) { cfg.InitialAdminKeys = append(cfg.InitialAdminKeys, initialAdminKeys...) } - // Validate keys - pks := make([]string, 0) - for _, key := range parseAuthKeys(cfg.InitialAdminKeys) { - ak := backend.MarshalAuthorizedKey(key) - pks = append(pks, ak) - } - - cfg.InitialAdminKeys = pks - - // Reset datapath to config dir. - // This is necessary because the environment variable may be set to - // a different directory. - cfg.DataPath = dataPath - - return cfg, nil + return cfg.Validate() } -// ParseConfig parses the configuration from the given file. -func ParseConfig(path string) (*Config, error) { - cfg, err := parseConfig(path) - if err != nil { - return cfg, err - } +// ParseEnv parses the config from the environment variables. +// This also calls Validate() on the config. +func (c *Config) ParseEnv() error { + return parseEnv(c) +} - if err := cfg.validate(); err != nil { - return cfg, err +// Parse parses the config from the default file path and environment variables. +// This also calls Validate() on the config. +func (c *Config) Parse() error { + if err := c.ParseFile(); err != nil { + return err } - return cfg, nil + return c.ParseEnv() } -// WriteConfig writes the configuration to the given file. -func WriteConfig(path string, cfg *Config) error { +// writeConfig writes the configuration to the given file. +func writeConfig(cfg *Config, path string) error { if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil { return err } - return os.WriteFile(path, []byte(newConfigFile(cfg)), 0o644) // nolint: errcheck + return os.WriteFile(path, []byte(newConfigFile(cfg)), 0o644) // nolint: errcheck, gosec } -// DefaultConfig returns a Config with the values populated with the defaults -// or specified environment variables. -func DefaultConfig() *Config { - dataPath := os.Getenv("SOFT_SERVE_DATA_PATH") - if dataPath == "" { - dataPath = "data" - } +// WriteConfig writes the configuration to the default file. +func (c *Config) WriteConfig() error { + return writeConfig(c, c.ConfigPath()) +} - cp := filepath.Join(dataPath, "config.yaml") - cfg, err := parseConfig(cp) - if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Errorf("failed to parse config: %v", err) +// DefaultDataPath returns the path to the data directory. +// It uses the SOFT_SERVE_DATA_PATH environment variable if set, otherwise it +// uses "data". +func DefaultDataPath() string { + dp := os.Getenv("SOFT_SERVE_DATA_PATH") + if dp == "" { + dp = "data" } - // Write config if it doesn't exist - if _, err := os.Stat(cp); os.IsNotExist(err) { - if err := WriteConfig(cp, cfg); err != nil { - log.Fatal("failed to write config", "err", err) - } - } + return dp +} - if err := cfg.validate(); err != nil { - log.Fatal(err) - } +// ConfigPath returns the path to the config file. +func (c *Config) ConfigPath() string { // nolint:revive + return filepath.Join(c.DataPath, "config.yaml") +} + +func exist(path string) bool { + _, err := os.Stat(path) + return err == nil +} - return cfg +// Exist returns true if the config file exists. +func (c *Config) Exist() bool { + return exist(filepath.Join(c.DataPath, "config.yaml")) } -// WithBackend sets the backend for the configuration. -func (c *Config) WithBackend(backend backend.Backend) *Config { - c.Backend = backend - return c +// DefaultConfig returns the default Config. All the path values are relative +// to the data directory. +// Use Validate() to validate the config and ensure absolute paths. +func DefaultConfig() *Config { + return &Config{ + Name: "Soft Serve", + DataPath: DefaultDataPath(), + SSH: SSHConfig{ + ListenAddr: ":23231", + PublicURL: "ssh://localhost:23231", + KeyPath: filepath.Join("ssh", "soft_serve_host_ed25519"), + ClientKeyPath: filepath.Join("ssh", "soft_serve_client_ed25519"), + MaxTimeout: 0, + IdleTimeout: 0, + }, + Git: GitConfig{ + ListenAddr: ":9418", + MaxTimeout: 0, + IdleTimeout: 3, + MaxConnections: 32, + }, + HTTP: HTTPConfig{ + ListenAddr: ":23232", + PublicURL: "http://localhost:23232", + }, + Stats: StatsConfig{ + ListenAddr: "localhost:23233", + }, + Log: LogConfig{ + Format: "text", + TimeFormat: time.DateTime, + }, + DB: DBConfig{ + Driver: "sqlite", + DataSource: "soft-serve.db" + + "?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)", + }, + } } -func (c *Config) validate() error { +// Validate validates the configuration. +// It updates the configuration with absolute paths. +func (c *Config) Validate() error { // Use absolute paths if !filepath.IsAbs(c.DataPath) { dp, err := filepath.Abs(c.DataPath) @@ -305,19 +343,37 @@ func (c *Config) validate() error { c.HTTP.TLSCertPath = filepath.Join(c.DataPath, c.HTTP.TLSCertPath) } + if strings.HasPrefix(c.DB.Driver, "sqlite") && !filepath.IsAbs(c.DB.DataSource) { + c.DB.DataSource = filepath.Join(c.DataPath, c.DB.DataSource) + } + + // Validate keys + pks := make([]string, 0) + for _, key := range parseAuthKeys(c.InitialAdminKeys) { + ak := sshutils.MarshalAuthorizedKey(key) + pks = append(pks, ak) + } + + c.InitialAdminKeys = pks + return nil } // parseAuthKeys parses authorized keys from either file paths or string authorized_keys. func parseAuthKeys(aks []string) []ssh.PublicKey { + exist := make(map[string]struct{}, 0) pks := make([]ssh.PublicKey, 0) for _, key := range aks { if bts, err := os.ReadFile(key); err == nil { // key is a file key = strings.TrimSpace(string(bts)) } - if pk, _, err := backend.ParseAuthorizedKey(key); err == nil { - pks = append(pks, pk) + + if pk, _, err := sshutils.ParseAuthorizedKey(key); err == nil { + if _, ok := exist[key]; !ok { + pks = append(pks, pk) + exist[key] = struct{}{} + } } } return pks @@ -327,19 +383,3 @@ func parseAuthKeys(aks []string) []ssh.PublicKey { func (c *Config) AdminKeys() []ssh.PublicKey { return parseAuthKeys(c.InitialAdminKeys) } - -var configCtxKey = struct{ string }{"config"} - -// WithContext returns a new context with the configuration attached. -func WithContext(ctx context.Context, cfg *Config) context.Context { - return context.WithValue(ctx, configCtxKey, cfg) -} - -// FromContext returns the configuration from the context. -func FromContext(ctx context.Context) *Config { - if c, ok := ctx.Value(configCtxKey).(*Config); ok { - return c - } - - return DefaultConfig() -} diff --git a/server/config/config_test.go b/server/config/config_test.go index 3812e4f9e72b871f3611b53ee2a1db1a12c803cd..503e4de49e349a1464433c76ed138a069ec9df6b 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -2,11 +2,9 @@ package config import ( "os" - "path/filepath" "testing" "github.com/matryer/is" - "gopkg.in/yaml.v3" ) func TestParseMultipleKeys(t *testing.T) { @@ -19,6 +17,7 @@ func TestParseMultipleKeys(t *testing.T) { is.NoErr(os.Unsetenv("SOFT_SERVE_DATA_PATH")) }) cfg := DefaultConfig() + is.NoErr(cfg.ParseEnv()) is.Equal(cfg.InitialAdminKeys, []string{ "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINMwLvyV3ouVrTysUYGoJdl5Vgn5BACKov+n9PlzfPwH", "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFxIobhwtfdwN7m1TFt9wx3PsfvcAkISGPxmbmbauST8", @@ -29,15 +28,12 @@ func TestMergeInitAdminKeys(t *testing.T) { is := is.New(t) is.NoErr(os.Setenv("SOFT_SERVE_INITIAL_ADMIN_KEYS", "testdata/k1.pub")) t.Cleanup(func() { is.NoErr(os.Unsetenv("SOFT_SERVE_INITIAL_ADMIN_KEYS")) }) - bts, err := yaml.Marshal(&Config{ + cfg := &Config{ + DataPath: t.TempDir(), InitialAdminKeys: []string{"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFxIobhwtfdwN7m1TFt9wx3PsfvcAkISGPxmbmbauST8 a@b"}, - }) - is.NoErr(err) - fp := filepath.Join(t.TempDir(), "config.yaml") - err = os.WriteFile(fp, bts, 0o644) - is.NoErr(err) - cfg, err := ParseConfig(fp) - is.NoErr(err) + } + is.NoErr(cfg.WriteConfig()) + is.NoErr(cfg.Parse()) is.Equal(cfg.InitialAdminKeys, []string{ "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINMwLvyV3ouVrTysUYGoJdl5Vgn5BACKov+n9PlzfPwH", "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFxIobhwtfdwN7m1TFt9wx3PsfvcAkISGPxmbmbauST8", @@ -46,19 +42,16 @@ func TestMergeInitAdminKeys(t *testing.T) { func TestValidateInitAdminKeys(t *testing.T) { is := is.New(t) - bts, err := yaml.Marshal(&Config{ + cfg := &Config{ + DataPath: t.TempDir(), InitialAdminKeys: []string{ "testdata/k1.pub", "abc", "", }, - }) - is.NoErr(err) - fp := filepath.Join(t.TempDir(), "config.yaml") - err = os.WriteFile(fp, bts, 0o644) - is.NoErr(err) - cfg, err := ParseConfig(fp) - is.NoErr(err) + } + is.NoErr(cfg.WriteConfig()) + is.NoErr(cfg.Parse()) is.Equal(cfg.InitialAdminKeys, []string{ "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINMwLvyV3ouVrTysUYGoJdl5Vgn5BACKov+n9PlzfPwH", }) diff --git a/server/config/context.go b/server/config/context.go new file mode 100644 index 0000000000000000000000000000000000000000..2a9c47bcf3433870310ce950e0fd0743e8f84b4c --- /dev/null +++ b/server/config/context.go @@ -0,0 +1,20 @@ +package config + +import "context" + +// ContextKey is the context key for the config. +var ContextKey = struct{ string }{"config"} + +// WithContext returns a new context with the configuration attached. +func WithContext(ctx context.Context, cfg *Config) context.Context { + return context.WithValue(ctx, ContextKey, cfg) +} + +// FromContext returns the configuration from the context. +func FromContext(ctx context.Context) *Config { + if c, ok := ctx.Value(ContextKey).(*Config); ok { + return c + } + + return DefaultConfig() +} diff --git a/server/config/file.go b/server/config/file.go index 09e5ce2e00dec68891f27e63bd5cf14fa1faca2c..a3b78c0160d57909f9f120489ed87a41daf489fc 100644 --- a/server/config/file.go +++ b/server/config/file.go @@ -18,6 +18,8 @@ log: # Time format for the log "timestamp" field. # Should be described in Golang's time format. time_format: "{{ .Log.TimeFormat }}" + # Path to the log file. Leave empty to write to stderr. + #path: "{{ .Log.Path }}" # The SSH server configuration. ssh: @@ -79,6 +81,15 @@ stats: # The address on which the stats server will listen. listen_addr: "{{ .Stats.ListenAddr }}" +# The database configuration. +db: + # The database driver to use. + # Valid values are "sqlite" and "postgres". + driver: "{{ .DB.Driver }}" + # The database data source name. + # This is driver specific and can be a file path or connection string. + data_source: "{{ .DB.DataSource }}" + # Additional admin keys. #initial_admin_keys: # - "ssh-rsa AAAAB3NzaC1yc2..." diff --git a/server/cron/cron.go b/server/cron/cron.go index af4baf0ad5f83d8398e403ffe15c124833b1dbda..aae506ed54d09edff27558ba2664079e8eafab5f 100644 --- a/server/cron/cron.go +++ b/server/cron/cron.go @@ -8,17 +8,9 @@ import ( "github.com/robfig/cron/v3" ) -// CronScheduler is a cron-like job scheduler. -type CronScheduler struct { +// Scheduler is a cron-like job scheduler. +type Scheduler struct { *cron.Cron - logger cron.Logger -} - -// Entry is a cron job. -type Entry struct { - ID cron.EntryID - Desc string - Spec string } // cronLogger is a wrapper around the logger to make it compatible with the @@ -37,22 +29,22 @@ func (l cronLogger) Error(err error, msg string, keysAndValues ...interface{}) { l.logger.Error(msg, append(keysAndValues, "err", err)...) } -// NewCronScheduler returns a new Cron. -func NewCronScheduler(ctx context.Context) *CronScheduler { +// NewScheduler returns a new Cron. +func NewScheduler(ctx context.Context) *Scheduler { logger := cronLogger{log.FromContext(ctx).WithPrefix("cron")} - return &CronScheduler{ + return &Scheduler{ Cron: cron.New(cron.WithLogger(logger)), } } -// Shutdonw gracefully shuts down the CronServer. -func (s *CronScheduler) Shutdown() { +// Shutdonw gracefully shuts down the Scheduler. +func (s *Scheduler) Shutdown() { ctx, cancel := context.WithTimeout(s.Cron.Stop(), 30*time.Second) defer func() { cancel() }() <-ctx.Done() } -// Start starts the CronServer. -func (s *CronScheduler) Start() { +// Start starts the Scheduler. +func (s *Scheduler) Start() { s.Cron.Start() } diff --git a/server/daemon/conn.go b/server/daemon/conn.go index 090d76aeecc3ff3ab847e036753ddfdaa3c3705b..b4a342309904f83d2f9c5c1555fefe7390280d48 100644 --- a/server/daemon/conn.go +++ b/server/daemon/conn.go @@ -91,15 +91,15 @@ func (c *serverConn) updateDeadline() { initTimeout := time.Now().Add(c.initTimeout) c.initTimeout = 0 if initTimeout.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { - c.Conn.SetDeadline(initTimeout) + c.Conn.SetDeadline(initTimeout) // nolint: errcheck return } case c.idleTimeout > 0: idleDeadline := time.Now().Add(c.idleTimeout) if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { - c.Conn.SetDeadline(idleDeadline) + c.Conn.SetDeadline(idleDeadline) // nolint: errcheck return } } - c.Conn.SetDeadline(c.maxDeadline) + c.Conn.SetDeadline(c.maxDeadline) // nolint: errcheck } diff --git a/server/daemon/daemon.go b/server/daemon/daemon.go index 1820c0e3e2c141f79488dac97537e5655a2ad4d5..fcd0ae4f92b1d8a39bcd365f94da2713d308a9bb 100644 --- a/server/daemon/daemon.go +++ b/server/daemon/daemon.go @@ -11,6 +11,7 @@ import ( "time" "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/access" "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/git" @@ -50,7 +51,7 @@ type GitDaemon struct { finished chan struct{} conns connections cfg *config.Config - be backend.Backend + be *backend.Backend wg sync.WaitGroup once sync.Once logger *log.Logger @@ -94,7 +95,7 @@ func (d *GitDaemon) Start() error { default: d.logger.Debugf("git: error accepting connection: %v", err) } - if ne, ok := err.(net.Error); ok && ne.Temporary() { + if ne, ok := err.(net.Error); ok && ne.Temporary() { // nolint: staticcheck if tempDelay == 0 { tempDelay = 5 * time.Millisecond } else { @@ -125,7 +126,7 @@ func (d *GitDaemon) Start() error { } func (d *GitDaemon) fatal(c net.Conn, err error) { - git.WritePktline(c, err) + git.WritePktlineErr(c, err) // nolint: errcheck if err := c.Close(); err != nil { d.logger.Debugf("git: error closing connection: %v", err) } @@ -146,7 +147,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) { } d.conns.Add(c) defer func() { - d.conns.Close(c) + d.conns.Close(c) // nolint: errcheck }() readc := make(chan struct{}, 1) @@ -227,8 +228,8 @@ func (d *GitDaemon) handleClient(conn net.Conn) { } } - be := d.be.WithContext(ctx) - if !be.AllowKeyless() { + be := d.be + if !be.AllowKeyless(ctx) { d.fatal(c, git.ErrNotAuthed) return } @@ -247,13 +248,13 @@ func (d *GitDaemon) handleClient(conn net.Conn) { return } - if _, err := d.be.Repository(repo); err != nil { + if _, err := d.be.Repository(ctx, repo); err != nil { d.fatal(c, git.ErrInvalidRepo) return } - auth := be.AccessLevel(name, "") - if auth < backend.ReadOnlyAccess { + auth := be.AccessLevel(ctx, name, "") + if auth < access.ReadOnlyAccess { d.fatal(c, git.ErrNotAuthed) return } @@ -263,6 +264,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) { "SOFT_SERVE_REPO_NAME=" + name, "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo), "SOFT_SERVE_HOST=" + host, + "SOFT_SERVE_LOG_PATH=" + filepath.Join(d.cfg.DataPath, "log", "hooks.log"), } // Add git protocol environment variable. @@ -301,7 +303,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) { func (d *GitDaemon) Close() error { d.once.Do(func() { close(d.finished) }) err := d.listener.Close() - d.conns.CloseAll() + d.conns.CloseAll() // nolint: errcheck return err } diff --git a/server/daemon/daemon_test.go b/server/daemon/daemon_test.go index a28fb68d4404d5e3eaa9e50535a984874275de7e..c11ddefb5648ad306b0489aa0d24a5d136b08b26 100644 --- a/server/daemon/daemon_test.go +++ b/server/daemon/daemon_test.go @@ -13,11 +13,13 @@ import ( "testing" "github.com/charmbracelet/soft-serve/server/backend" - "github.com/charmbracelet/soft-serve/server/backend/sqlite" "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/migrate" "github.com/charmbracelet/soft-serve/server/git" "github.com/charmbracelet/soft-serve/server/test" "github.com/go-git/go-git/v5/plumbing/format/pktline" + _ "modernc.org/sqlite" // sqlite driver ) var testDaemon *GitDaemon @@ -35,13 +37,20 @@ func TestMain(m *testing.M) { os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort())) ctx := context.TODO() cfg := config.DefaultConfig() + if err := cfg.Validate(); err != nil { + log.Fatal(err) + } ctx = config.WithContext(ctx, cfg) - fb, err := sqlite.NewSqliteBackend(ctx) + db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource) if err != nil { log.Fatal(err) } - cfg = cfg.WithBackend(fb) - ctx = backend.WithContext(ctx, fb) + defer db.Close() // nolint: errcheck + if err := migrate.Migrate(ctx, db); err != nil { + log.Fatal(err) + } + be := backend.New(ctx, cfg, db) + ctx = backend.WithContext(ctx, be) d, err := NewGitDaemon(ctx) if err != nil { log.Fatal(err) @@ -59,7 +68,7 @@ func TestMain(m *testing.M) { os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT") os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR") _ = d.Close() - _ = fb.Close() + _ = db.Close() os.Exit(code) } @@ -72,7 +81,7 @@ func TestIdleTimeout(t *testing.T) { if err != nil && !errors.Is(err, io.EOF) { t.Fatalf("expected nil, got error: %v", err) } - if out != git.ErrTimeout.Error() && out != "" { + if out != "ERR "+git.ErrTimeout.Error() && out != "" { t.Fatalf("expected %q error, got %q", git.ErrTimeout, out) } } @@ -89,7 +98,7 @@ func TestInvalidRepo(t *testing.T) { if err != nil { t.Fatalf("expected nil, got error: %v", err) } - if out != git.ErrInvalidRepo.Error() { + if out != "ERR "+git.ErrInvalidRepo.Error() { t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out) } } diff --git a/server/db/context.go b/server/db/context.go new file mode 100644 index 0000000000000000000000000000000000000000..17c70ee4978d8b3c97e0a05177d9cd30d82c3ae8 --- /dev/null +++ b/server/db/context.go @@ -0,0 +1,18 @@ +package db + +import "context" + +var contextKey = struct{ string }{"db"} + +// FromContext returns the database from the context. +func FromContext(ctx context.Context) *DB { + if db, ok := ctx.Value(contextKey).(*DB); ok { + return db + } + return nil +} + +// WithContext returns a new context with the database. +func WithContext(ctx context.Context, db *DB) context.Context { + return context.WithValue(ctx, contextKey, db) +} diff --git a/server/db/db.go b/server/db/db.go new file mode 100644 index 0000000000000000000000000000000000000000..587880dc6201ff4bfb68071a99d55f9e6b73de47 --- /dev/null +++ b/server/db/db.go @@ -0,0 +1,88 @@ +package db + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/jmoiron/sqlx" + _ "modernc.org/sqlite" // sqlite driver +) + +// DB is the interface for a Soft Serve database. +type DB struct { + *sqlx.DB + logger *log.Logger +} + +// Open opens a database connection. +func Open(ctx context.Context, driverName string, dsn string) (*DB, error) { + db, err := sqlx.ConnectContext(ctx, driverName, dsn) + if err != nil { + return nil, err + } + + d := &DB{ + DB: db, + } + + if config.IsVerbose() { + logger := log.FromContext(ctx).WithPrefix("db") + d.logger = logger + } + + return d, nil +} + +// Close implements db.DB. +func (d *DB) Close() error { + return d.DB.Close() +} + +// Tx is a database transaction. +type Tx struct { + *sqlx.Tx + logger *log.Logger +} + +// Transaction implements db.DB. +func (d *DB) Transaction(fn func(tx *Tx) error) error { + return d.TransactionContext(context.Background(), fn) +} + +// TransactionContext implements db.DB. +func (d *DB) TransactionContext(ctx context.Context, fn func(tx *Tx) error) error { + txx, err := d.DB.BeginTxx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + + tx := &Tx{txx, d.logger} + if err := fn(tx); err != nil { + return rollback(tx, err) + } + + if err := tx.Commit(); err != nil { + if errors.Is(err, sql.ErrTxDone) { + // this is ok because whoever did finish the tx should have also written the error already. + return nil + } + return fmt.Errorf("failed to commit transaction: %w", err) + } + + return nil +} + +func rollback(tx *Tx, err error) error { + if rerr := tx.Rollback(); rerr != nil { + if errors.Is(rerr, sql.ErrTxDone) { + return err + } + return fmt.Errorf("failed to rollback: %s: %w", err.Error(), rerr) + } + + return err +} diff --git a/server/db/errors.go b/server/db/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..f793a0588a94c7c2461da50aaf87060ac724d923 --- /dev/null +++ b/server/db/errors.go @@ -0,0 +1,48 @@ +package db + +import ( + "database/sql" + "errors" + + "github.com/lib/pq" + sqlite "modernc.org/sqlite" + sqlite3 "modernc.org/sqlite/lib" +) + +var ( + // ErrDuplicateKey is a constraint violation error. + ErrDuplicateKey = errors.New("duplicate key value violates table constraint") + + // ErrRecordNotFound is returned when a record is not found. + ErrRecordNotFound = errors.New("record not found") +) + +// WrapError is a convenient function that unite various database driver +// errors to consistent errors. +func WrapError(err error) error { + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return ErrRecordNotFound + } + + // Handle sqlite constraint error. + if liteErr, ok := err.(*sqlite.Error); ok { + code := liteErr.Code() + if code == sqlite3.SQLITE_CONSTRAINT_PRIMARYKEY || + code == sqlite3.SQLITE_CONSTRAINT_FOREIGNKEY || + code == sqlite3.SQLITE_CONSTRAINT_UNIQUE { + return ErrDuplicateKey + } + } + + // Handle postgres constraint error. + if pgErr, ok := err.(*pq.Error); ok { + if pgErr.Code == "23505" || + pgErr.Code == "23503" || + pgErr.Code == "23514" { + return ErrDuplicateKey + } + } + } + return err +} diff --git a/server/db/logger.go b/server/db/logger.go new file mode 100644 index 0000000000000000000000000000000000000000..bf3d932b24c2f57330a3d7d3e7ca0181c8869d24 --- /dev/null +++ b/server/db/logger.go @@ -0,0 +1,135 @@ +package db + +import ( + "context" + "database/sql" + + "github.com/charmbracelet/log" + "github.com/jmoiron/sqlx" +) + +func trace(l *log.Logger, query string, args ...interface{}) { + if l != nil { + l.Debug("trace", "query", query, "args", args) + } +} + +// Select is a wrapper around sqlx.Select that logs the query and arguments. +func (d *DB) Select(dest interface{}, query string, args ...interface{}) error { + trace(d.logger, query, args...) + return d.DB.Select(dest, query, args...) +} + +// Get is a wrapper around sqlx.Get that logs the query and arguments. +func (d *DB) Get(dest interface{}, query string, args ...interface{}) error { + trace(d.logger, query, args...) + return d.DB.Get(dest, query, args...) +} + +// Queryx is a wrapper around sqlx.Queryx that logs the query and arguments. +func (d *DB) Queryx(query string, args ...interface{}) (*sqlx.Rows, error) { + trace(d.logger, query, args...) + return d.DB.Queryx(query, args...) +} + +// QueryRowx is a wrapper around sqlx.QueryRowx that logs the query and arguments. +func (d *DB) QueryRowx(query string, args ...interface{}) *sqlx.Row { + trace(d.logger, query, args...) + return d.DB.QueryRowx(query, args...) +} + +// Exec is a wrapper around sqlx.Exec that logs the query and arguments. +func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { + trace(d.logger, query, args...) + return d.DB.Exec(query, args...) +} + +// SelectContext is a wrapper around sqlx.SelectContext that logs the query and arguments. +func (d *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + trace(d.logger, query, args...) + return d.DB.SelectContext(ctx, dest, query, args...) +} + +// GetContext is a wrapper around sqlx.GetContext that logs the query and arguments. +func (d *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + trace(d.logger, query, args...) + return d.DB.GetContext(ctx, dest, query, args...) +} + +// QueryxContext is a wrapper around sqlx.QueryxContext that logs the query and arguments. +func (d *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + trace(d.logger, query, args...) + return d.DB.QueryxContext(ctx, query, args...) +} + +// QueryRowxContext is a wrapper around sqlx.QueryRowxContext that logs the query and arguments. +func (d *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { + trace(d.logger, query, args...) + return d.DB.QueryRowxContext(ctx, query, args...) +} + +// ExecContext is a wrapper around sqlx.ExecContext that logs the query and arguments. +func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + trace(d.logger, query, args...) + return d.DB.ExecContext(ctx, query, args...) +} + +// Select is a wrapper around sqlx.Select that logs the query and arguments. +func (t *Tx) Select(dest interface{}, query string, args ...interface{}) error { + trace(t.logger, query, args...) + return t.Tx.Select(dest, query, args...) +} + +// Get is a wrapper around sqlx.Get that logs the query and arguments. +func (t *Tx) Get(dest interface{}, query string, args ...interface{}) error { + trace(t.logger, query, args...) + return t.Tx.Get(dest, query, args...) +} + +// Queryx is a wrapper around sqlx.Queryx that logs the query and arguments. +func (t *Tx) Queryx(query string, args ...interface{}) (*sqlx.Rows, error) { + trace(t.logger, query, args...) + return t.Tx.Queryx(query, args...) +} + +// QueryRowx is a wrapper around sqlx.QueryRowx that logs the query and arguments. +func (t *Tx) QueryRowx(query string, args ...interface{}) *sqlx.Row { + trace(t.logger, query, args...) + return t.Tx.QueryRowx(query, args...) +} + +// Exec is a wrapper around sqlx.Exec that logs the query and arguments. +func (t *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { + trace(t.logger, query, args...) + return t.Tx.Exec(query, args...) +} + +// SelectContext is a wrapper around sqlx.SelectContext that logs the query and arguments. +func (t *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + trace(t.logger, query, args...) + return t.Tx.SelectContext(ctx, dest, query, args...) +} + +// GetContext is a wrapper around sqlx.GetContext that logs the query and arguments. +func (t *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + trace(t.logger, query, args...) + return t.Tx.GetContext(ctx, dest, query, args...) +} + +// QueryxContext is a wrapper around sqlx.QueryxContext that logs the query and arguments. +func (t *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + trace(t.logger, query, args...) + return t.Tx.QueryxContext(ctx, query, args...) +} + +// QueryRowxContext is a wrapper around sqlx.QueryRowxContext that logs the query and arguments. +func (t *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { + trace(t.logger, query, args...) + return t.Tx.QueryRowxContext(ctx, query, args...) +} + +// ExecContext is a wrapper around sqlx.ExecContext that logs the query and arguments. +func (t *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + trace(t.logger, query, args...) + return t.Tx.ExecContext(ctx, query, args...) +} diff --git a/server/db/migrate/0001_create_tables.go b/server/db/migrate/0001_create_tables.go new file mode 100644 index 0000000000000000000000000000000000000000..b6cf1c85a95cebbcc52d43f4546f14dc6e150db3 --- /dev/null +++ b/server/db/migrate/0001_create_tables.go @@ -0,0 +1,134 @@ +package migrate + +import ( + "context" + "errors" + "fmt" + + "github.com/charmbracelet/soft-serve/server/access" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/sshutils" +) + +const ( + createTablesName = "create tables" + createTablesVersion = 1 +) + +var createTables = Migration{ + Version: createTablesVersion, + Name: createTablesName, + Migrate: func(ctx context.Context, tx *db.Tx) error { + cfg := config.FromContext(ctx) + + insert := "INSERT " + + // Alter old tables (if exist) + // This is to support prior versions of Soft Serve + switch tx.DriverName() { + case "sqlite3", "sqlite": + insert += "OR IGNORE " + + hasUserTable := hasTable(tx, "user") + if hasUserTable { + if _, err := tx.ExecContext(ctx, "ALTER TABLE user RENAME TO users"); err != nil { + return err + } + } + + if hasTable(tx, "public_key") { + if _, err := tx.ExecContext(ctx, "ALTER TABLE public_key RENAME TO public_keys"); err != nil { + return err + } + } + + if hasTable(tx, "collab") { + if _, err := tx.ExecContext(ctx, "ALTER TABLE collab RENAME TO collabs"); err != nil { + return err + } + } + + if hasTable(tx, "repo") { + if _, err := tx.ExecContext(ctx, "ALTER TABLE repo RENAME TO repos"); err != nil { + return err + } + } + + // Fix username being nullable + if hasUserTable { + sqlm := ` + PRAGMA foreign_keys = OFF; + + CREATE TABLE users_new ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL UNIQUE, + admin BOOLEAN NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL + ); + + INSERT INTO users_new (username, admin, updated_at) + SELECT username, admin, updated_at FROM users; + + DROP TABLE users; + ALTER TABLE users_new RENAME TO users; + + PRAGMA foreign_keys = ON; + ` + if _, err := tx.ExecContext(ctx, sqlm); err != nil { + return err + } + } + } + + if err := migrateUp(ctx, tx, createTablesVersion, createTablesName); err != nil { + return err + } + + // Insert default user + insertUser := tx.Rebind(insert + "INTO users (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)") + if _, err := tx.ExecContext(ctx, insertUser, "admin", true); err != nil { + return err + } + + for _, k := range cfg.AdminKeys() { + query := insert + "INTO public_keys (user_id, public_key, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)" + if tx.DriverName() == "postgres" { + query += " ON CONFLICT DO NOTHING" + } + + query = tx.Rebind(query) + ak := sshutils.MarshalAuthorizedKey(k) + if _, err := tx.ExecContext(ctx, query, 1, ak); err != nil { + if errors.Is(db.WrapError(err), db.ErrDuplicateKey) { + continue + } + return err + } + } + + // Insert default settings + insertSettings := insert + "INTO settings (key, value, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP)" + insertSettings = tx.Rebind(insertSettings) + settings := []struct { + Key string + Value string + }{ + {"allow_keyless", "true"}, + {"anon_access", access.ReadOnlyAccess.String()}, + {"init", "true"}, + } + + for _, s := range settings { + if _, err := tx.ExecContext(ctx, insertSettings, s.Key, s.Value); err != nil { + return fmt.Errorf("inserting default settings %q: %w", s.Key, err) + } + } + + return nil + }, + Rollback: func(ctx context.Context, tx *db.Tx) error { + return migrateDown(ctx, tx, createTablesVersion, createTablesName) + }, +} diff --git a/server/db/migrate/0001_create_tables_postgres.down.sql b/server/db/migrate/0001_create_tables_postgres.down.sql new file mode 100644 index 0000000000000000000000000000000000000000..35eeb70dedeae7d862935e287a9ab3516d56b674 --- /dev/null +++ b/server/db/migrate/0001_create_tables_postgres.down.sql @@ -0,0 +1,5 @@ +DROP TABLE IF EXISTS collabs; +DROP TABLE IF EXISTS repos; +DROP TABLE IF EXISTS public_keys; +DROP TABLE IF EXISTS users; +DROP TABLE IF EXISTS settings; diff --git a/server/db/migrate/0001_create_tables_postgres.up.sql b/server/db/migrate/0001_create_tables_postgres.up.sql new file mode 100644 index 0000000000000000000000000000000000000000..29a8e0bacc3af6cd5419c701a996d2b29c0d762b --- /dev/null +++ b/server/db/migrate/0001_create_tables_postgres.up.sql @@ -0,0 +1,59 @@ +CREATE TABLE IF NOT EXISTS settings ( + id SERIAL PRIMARY KEY, + key TEXT NOT NULL UNIQUE, + value TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + admin BOOLEAN NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS public_keys ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL, + public_key TEXT NOT NULL UNIQUE, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (user_id, public_key), + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS repos ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + project_name TEXT NOT NULL, + description TEXT NOT NULL, + private BOOLEAN NOT NULL, + mirror BOOLEAN NOT NULL, + hidden BOOLEAN NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS collabs ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL, + repo_id INTEGER NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP NOT NULL, + UNIQUE (user_id, repo_id), + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT repo_id_fk + FOREIGN KEY(repo_id) REFERENCES repos(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + + diff --git a/server/db/migrate/0001_create_tables_sqlite.down.sql b/server/db/migrate/0001_create_tables_sqlite.down.sql new file mode 100644 index 0000000000000000000000000000000000000000..35eeb70dedeae7d862935e287a9ab3516d56b674 --- /dev/null +++ b/server/db/migrate/0001_create_tables_sqlite.down.sql @@ -0,0 +1,5 @@ +DROP TABLE IF EXISTS collabs; +DROP TABLE IF EXISTS repos; +DROP TABLE IF EXISTS public_keys; +DROP TABLE IF EXISTS users; +DROP TABLE IF EXISTS settings; diff --git a/server/db/migrate/0001_create_tables_sqlite.up.sql b/server/db/migrate/0001_create_tables_sqlite.up.sql new file mode 100644 index 0000000000000000000000000000000000000000..0880f464fd1a854399ad525cda016e90ea583a4f --- /dev/null +++ b/server/db/migrate/0001_create_tables_sqlite.up.sql @@ -0,0 +1,58 @@ +CREATE TABLE IF NOT EXISTS settings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + key TEXT NOT NULL UNIQUE, + value TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL +); + +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL UNIQUE, + admin BOOLEAN NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL +); + +CREATE TABLE IF NOT EXISTS public_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + public_key TEXT NOT NULL UNIQUE, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + UNIQUE (user_id, public_key), + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS repos ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + project_name TEXT NOT NULL, + description TEXT NOT NULL, + private BOOLEAN NOT NULL, + mirror BOOLEAN NOT NULL, + hidden BOOLEAN NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL +); + +CREATE TABLE IF NOT EXISTS collabs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + repo_id INTEGER NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + UNIQUE (user_id, repo_id), + CONSTRAINT user_id_fk + FOREIGN KEY(user_id) REFERENCES users(id) + ON DELETE CASCADE + ON UPDATE CASCADE, + CONSTRAINT repo_id_fk + FOREIGN KEY(repo_id) REFERENCES repos(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + diff --git a/server/db/migrate/migrate.go b/server/db/migrate/migrate.go new file mode 100644 index 0000000000000000000000000000000000000000..883f3e631c05d2c72960c6e5edeab13be90e1e99 --- /dev/null +++ b/server/db/migrate/migrate.go @@ -0,0 +1,142 @@ +package migrate + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/db" +) + +// MigrateFunc is a function that executes a migration. +type MigrateFunc func(ctx context.Context, tx *db.Tx) error // nolint:revive + +// Migration is a struct that contains the name of the migration and the +// function to execute it. +type Migration struct { + Version int64 + Name string + Migrate MigrateFunc + Rollback MigrateFunc +} + +// Migrations is a database model to store migrations. +type Migrations struct { + ID int64 `db:"id"` + Name string `db:"name"` + Version int64 `db:"version"` +} + +func (Migrations) schema(driverName string) string { + switch driverName { + case "sqlite3", "sqlite": + return `CREATE TABLE IF NOT EXISTS migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + version INTEGER NOT NULL UNIQUE + ); + ` + case "postgres": + return `CREATE TABLE IF NOT EXISTS migrations ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + version INTEGER NOT NULL UNIQUE + ); + ` + case "mysql": + return `CREATE TABLE IF NOT EXISTS migrations ( + id INT NOT NULL AUTO_INCREMENT, + name TEXT NOT NULL, + version INT NOT NULL, + UNIQUE (version), + PRIMARY KEY (id) + ); + ` + default: + panic("unknown driver") + } +} + +// Migrate runs the migrations. +func Migrate(ctx context.Context, dbx *db.DB) error { + logger := log.FromContext(ctx).WithPrefix("migrate") + return dbx.TransactionContext(ctx, func(tx *db.Tx) error { + if !hasTable(tx, "migrations") { + if _, err := tx.Exec(Migrations{}.schema(tx.DriverName())); err != nil { + return err + } + } + + var migrs Migrations + if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return err + } + } + + for _, m := range migrations { + if m.Version <= migrs.Version { + continue + } + + logger.Infof("running migration %d. %s", m.Version, m.Name) + if err := m.Migrate(ctx, tx); err != nil { + return err + } + + if _, err := tx.Exec(tx.Rebind("INSERT INTO migrations (name, version) VALUES (?, ?)"), m.Name, m.Version); err != nil { + return err + } + } + + return nil + }) +} + +// Rollback rolls back a migration. +func Rollback(ctx context.Context, dbx *db.DB) error { + logger := log.FromContext(ctx).WithPrefix("migrate") + return dbx.TransactionContext(ctx, func(tx *db.Tx) error { + var migrs Migrations + if err := tx.Get(&migrs, tx.Rebind("SELECT * FROM migrations ORDER BY version DESC LIMIT 1")); err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("there are no migrations to rollback: %w", err) + } + } + + if len(migrations) < int(migrs.Version) { + return fmt.Errorf("there are no migrations to rollback") + } + + m := migrations[migrs.Version-1] + logger.Infof("rolling back migration %d. %s", m.Version, m.Name) + if err := m.Rollback(ctx, tx); err != nil { + return err + } + + if _, err := tx.Exec(tx.Rebind("DELETE FROM migrations WHERE version = ?"), migrs.Version); err != nil { + return err + } + + return nil + }) +} + +func hasTable(tx *db.Tx, tableName string) bool { + var query string + switch tx.DriverName() { + case "sqlite3", "sqlite": + query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?" + case "postgres": + fallthrough + case "mysql": + query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = ?" + } + + query = tx.Rebind(query) + var name string + err := tx.Get(&name, query, tableName) + return err == nil +} diff --git a/server/db/migrate/migrations.go b/server/db/migrate/migrations.go new file mode 100644 index 0000000000000000000000000000000000000000..88a9e434696ecdfe5ccb58d844d1f244d3bbfce5 --- /dev/null +++ b/server/db/migrate/migrations.go @@ -0,0 +1,62 @@ +package migrate + +import ( + "context" + "embed" + "fmt" + "regexp" + "strings" + + "github.com/charmbracelet/soft-serve/server/db" +) + +//go:embed *.sql +var sqls embed.FS + +// Keep this in order of execution, oldest to newest. +var migrations = []Migration{ + createTables, +} + +func execMigration(ctx context.Context, tx *db.Tx, version int, name string, down bool) error { + direction := "up" + if down { + direction = "down" + } + + driverName := tx.DriverName() + if driverName == "sqlite3" { + driverName = "sqlite" + } + + fn := fmt.Sprintf("%04d_%s_%s.%s.sql", version, toSnakeCase(name), driverName, direction) + sqlstr, err := sqls.ReadFile(fn) + if err != nil { + return err + } + + if _, err := tx.ExecContext(ctx, string(sqlstr)); err != nil { + return err + } + + return nil +} + +func migrateUp(ctx context.Context, tx *db.Tx, version int, name string) error { + return execMigration(ctx, tx, version, name, false) +} + +func migrateDown(ctx context.Context, tx *db.Tx, version int, name string) error { + return execMigration(ctx, tx, version, name, true) +} + +var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") +var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") + +func toSnakeCase(str string) string { + str = strings.ReplaceAll(str, "-", "_") + str = strings.ReplaceAll(str, " ", "_") + snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}") + snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") + return strings.ToLower(snake) +} diff --git a/server/db/models/collab.go b/server/db/models/collab.go new file mode 100644 index 0000000000000000000000000000000000000000..e14660189d85fb325ca348492fb02216a6906780 --- /dev/null +++ b/server/db/models/collab.go @@ -0,0 +1,12 @@ +package models + +import "time" + +// Collab represents a repository collaborator. +type Collab struct { + ID int64 `db:"id"` + RepoID int64 `db:"repo_id"` + UserID int64 `db:"user_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/server/db/models/public_key.go b/server/db/models/public_key.go new file mode 100644 index 0000000000000000000000000000000000000000..a8551f46c0d5a81670595d2479f3043fbfb6e94b --- /dev/null +++ b/server/db/models/public_key.go @@ -0,0 +1,10 @@ +package models + +// PublicKey represents a public key. +type PublicKey struct { + ID int64 `db:"id"` + UserID int64 `db:"user_id"` + PublicKey string `db:"public_key"` + CreatedAt string `db:"created_at"` + UpdatedAt string `db:"updated_at"` +} diff --git a/server/db/models/repo.go b/server/db/models/repo.go new file mode 100644 index 0000000000000000000000000000000000000000..586a4f288941aa8298e1e3a527d6ec6d3cddc07b --- /dev/null +++ b/server/db/models/repo.go @@ -0,0 +1,16 @@ +package models + +import "time" + +// Repo is a database model for a repository. +type Repo struct { + ID int64 `db:"id"` + Name string `db:"name"` + ProjectName string `db:"project_name"` + Description string `db:"description"` + Private bool `db:"private"` + Mirror bool `db:"mirror"` + Hidden bool `db:"hidden"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/server/db/models/settings.go b/server/db/models/settings.go new file mode 100644 index 0000000000000000000000000000000000000000..c30bbb9e1f409daf42b5ed0d400bb4a2d73e04c4 --- /dev/null +++ b/server/db/models/settings.go @@ -0,0 +1,10 @@ +package models + +// Settings represents a settings record. +type Settings struct { + ID int64 `db:"id"` + Key string `db:"key"` + Value string `db:"value"` + CreatedAt string `db:"created_at"` + UpdatedAt string `db:"updated_at"` +} diff --git a/server/db/models/user.go b/server/db/models/user.go new file mode 100644 index 0000000000000000000000000000000000000000..8404a9bb4473738fc1d21e9ce4340a88ae0ff92d --- /dev/null +++ b/server/db/models/user.go @@ -0,0 +1,12 @@ +package models + +import "time" + +// User represents a user. +type User struct { + ID int64 `db:"id"` + Username string `db:"username"` + Admin bool `db:"admin"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/server/errors/errors.go b/server/errors/errors.go deleted file mode 100644 index 20399d09c42b8b13fbaaa69a0ac8df1a1986ad56..0000000000000000000000000000000000000000 --- a/server/errors/errors.go +++ /dev/null @@ -1,12 +0,0 @@ -package errors - -import "fmt" - -var ( - // ErrUnauthorized is returned when the user is not authorized to perform action. - ErrUnauthorized = fmt.Errorf("Unauthorized") - // ErrRepoNotFound is returned when the repo is not found. - ErrRepoNotFound = fmt.Errorf("Repository not found") - // ErrFileNotFound is returned when the file is not found. - ErrFileNotFound = fmt.Errorf("File not found") -) diff --git a/server/git/git.go b/server/git/git.go index 8f8ae3d7c6ecd30e68c55c0d53d5029f18a5445a..0dba095eeb1347c8f689c07472e3ac4daf6efba4 100644 --- a/server/git/git.go +++ b/server/git/git.go @@ -35,15 +35,22 @@ var ( ) // WritePktline encodes and writes a pktline to the given writer. -func WritePktline(w io.Writer, v ...interface{}) { +func WritePktline(w io.Writer, v ...interface{}) error { msg := fmt.Sprintln(v...) pkt := pktline.NewEncoder(w) if err := pkt.EncodeString(msg); err != nil { - log.Debugf("git: error writing pkt-line message: %s", err) + return fmt.Errorf("git: error writing pkt-line message: %w", err) } if err := pkt.Flush(); err != nil { - log.Debugf("git: error flushing pkt-line message: %s", err) + return fmt.Errorf("git: error flushing pkt-line message: %w", err) } + + return nil +} + +// WritePktlineErr writes an error pktline to the given writer. +func WritePktlineErr(w io.Writer, err error) error { + return WritePktline(w, "ERR", err.Error()) } // EnsureWithin ensures the given repo is within the repos directory. diff --git a/server/git/service.go b/server/git/service.go index 0730018403018a31336f01f8ddd6f38d70eccd9b..b5b0f6af402cf525fc31abcaee3416edbb10fc55 100644 --- a/server/git/service.go +++ b/server/git/service.go @@ -97,45 +97,47 @@ func gitServiceHandler(ctx context.Context, svc Service, scmd ServiceCommand) er log.Debugf("git service command in %q: %s", cmd.Dir, cmd.String()) if err := cmd.Start(); err != nil { + if errors.Is(err, os.ErrNotExist) { + return ErrInvalidRepo + } return err } - errg, ctx := errgroup.WithContext(ctx) + errg, _ := errgroup.WithContext(ctx) // stdin if scmd.Stdin != nil { errg.Go(func() error { - if scmd.StdinHandler != nil { - return scmd.StdinHandler(scmd.Stdin, stdin) - } else { - return defaultStdinHandler(scmd.Stdin, stdin) - } + defer stdin.Close() // nolint: errcheck + _, err := io.Copy(stdin, scmd.Stdin) + return err }) } // stdout if scmd.Stdout != nil { errg.Go(func() error { - if scmd.StdoutHandler != nil { - return scmd.StdoutHandler(scmd.Stdout, stdout) - } else { - return defaultStdoutHandler(scmd.Stdout, stdout) - } + _, err := io.Copy(scmd.Stdout, stdout) + return err }) } // stderr if scmd.Stderr != nil { errg.Go(func() error { - if scmd.StderrHandler != nil { - return scmd.StderrHandler(scmd.Stderr, stderr) - } else { - return defaultStderrHandler(scmd.Stderr, stderr) - } + _, erro := io.Copy(scmd.Stderr, stderr) + return erro }) } - return errors.Join(errg.Wait(), cmd.Wait()) + err = errors.Join(errg.Wait(), cmd.Wait()) + if err != nil && errors.Is(err, os.ErrNotExist) { + return ErrInvalidRepo + } else if err != nil { + return err + } + + return nil } // ServiceCommand is used to run a git service command. @@ -148,26 +150,7 @@ type ServiceCommand struct { Args []string // Modifier functions - CmdFunc func(*exec.Cmd) - StdinHandler func(io.Reader, io.WriteCloser) error - StdoutHandler func(io.Writer, io.ReadCloser) error - StderrHandler func(io.Writer, io.ReadCloser) error -} - -func defaultStdinHandler(in io.Reader, stdin io.WriteCloser) error { - defer stdin.Close() // nolint: errcheck - _, err := io.Copy(stdin, in) - return err -} - -func defaultStdoutHandler(out io.Writer, stdout io.ReadCloser) error { - _, err := io.Copy(out, stdout) - return err -} - -func defaultStderrHandler(err io.Writer, stderr io.ReadCloser) error { - _, erro := io.Copy(err, stderr) - return erro + CmdFunc func(*exec.Cmd) } // UploadPack runs the git upload-pack protocol against the provided repo. diff --git a/server/hooks/gen.go b/server/hooks/gen.go new file mode 100644 index 0000000000000000000000000000000000000000..67101d513e246d5e220631b9519fbf6e0761b3cd --- /dev/null +++ b/server/hooks/gen.go @@ -0,0 +1,140 @@ +package hooks + +import ( + "bytes" + "context" + "flag" + "os" + "path/filepath" + "text/template" + + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/utils" +) + +// The names of git server-side hooks. +const ( + PreReceiveHook = "pre-receive" + UpdateHook = "update" + PostReceiveHook = "post-receive" + PostUpdateHook = "post-update" +) + +// GenerateHooks generates git server-side hooks for a repository. Currently, it supports the following hooks: +// - pre-receive +// - update +// - post-receive +// - post-update +// +// This function should be called by the backend when a repository is created. +// TODO: support context. +func GenerateHooks(_ context.Context, cfg *config.Config, repo string) error { + // TODO: support git hook tests. + if flag.Lookup("test.v") != nil { + log.WithPrefix("backend.hooks").Warn("refusing to set up hooks when in test") + return nil + } + repo = utils.SanitizeRepo(repo) + ".git" + hooksPath := filepath.Join(cfg.DataPath, "repos", repo, "hooks") + if err := os.MkdirAll(hooksPath, os.ModePerm); err != nil { + return err + } + + ex, err := os.Executable() + if err != nil { + return err + } + + for _, hook := range []string{ + PreReceiveHook, + UpdateHook, + PostReceiveHook, + PostUpdateHook, + } { + var data bytes.Buffer + var args string + + // Hooks script/directory path + hp := filepath.Join(hooksPath, hook) + + // Write the hooks primary script + if err := os.WriteFile(hp, []byte(hookTemplate), os.ModePerm); err != nil { + return err + } + + // Create ${hook}.d directory. + hp += ".d" + if err := os.MkdirAll(hp, os.ModePerm); err != nil { + return err + } + + switch hook { + case UpdateHook: + args = "$1 $2 $3" + case PostUpdateHook: + args = "$@" + } + + if err := hooksTmpl.Execute(&data, struct { + Executable string + Hook string + Args string + }{ + Executable: ex, + Hook: hook, + Args: args, + }); err != nil { + log.WithPrefix("hooks").Error("failed to execute hook template", "err", err) + continue + } + + // Write the soft-serve hook inside ${hook}.d directory. + hp = filepath.Join(hp, "soft-serve") + err = os.WriteFile(hp, data.Bytes(), os.ModePerm) //nolint:gosec + if err != nil { + log.WithPrefix("hooks").Error("failed to write hook", "err", err) + continue + } + } + + return nil +} + +const ( + // hookTemplate allows us to run multiple hooks from a directory. It should + // support every type of git hook, as it proxies both stdin and arguments. + hookTemplate = `#!/usr/bin/env bash +# AUTO GENERATED BY SOFT SERVE, DO NOT MODIFY +data=$(cat) +exitcodes="" +hookname=$(basename $0) +GIT_DIR=${GIT_DIR:-$(dirname $0)/..} +for hook in ${GIT_DIR}/hooks/${hookname}.d/*; do + # Avoid running non-executable hooks + test -x "${hook}" && test -f "${hook}" || continue + + # Run the actual hook + echo "${data}" | "${hook}" "$@" + + # Store the exit code for later use + exitcodes="${exitcodes} $?" +done + +# Exit on the first non-zero exit code. +for i in ${exitcodes}; do + [ ${i} -eq 0 ] || exit ${i} +done +` +) + +// hooksTmpl is the soft-serve hook that will be run by the git hooks +// inside the hooks directory. +var hooksTmpl = template.Must(template.New("hooks").Parse(`#!/usr/bin/env bash +# AUTO GENERATED BY SOFT SERVE, DO NOT MODIFY +if [ -z "$SOFT_SERVE_REPO_NAME" ]; then + echo "Warning: SOFT_SERVE_REPO_NAME not defined. Skipping hooks." + exit 0 +fi +{{ .Executable }} hook {{ .Hook }} {{ .Args }} +`)) diff --git a/server/hooks/hooks.go b/server/hooks/hooks.go index c625769f27b31039fe3657969b1e37af6791d61b..0278050efcc728f62bdce130b65520e67f9d3cb8 100644 --- a/server/hooks/hooks.go +++ b/server/hooks/hooks.go @@ -1,156 +1,21 @@ package hooks import ( - "bytes" "context" - "flag" - "fmt" - "os" - "path/filepath" - "text/template" - - "github.com/charmbracelet/log" - "github.com/charmbracelet/soft-serve/server/config" - "github.com/charmbracelet/soft-serve/server/utils" -) - -// The names of git server-side hooks. -const ( - PreReceiveHook = "pre-receive" - UpdateHook = "update" - PostReceiveHook = "post-receive" - PostUpdateHook = "post-update" + "io" ) -// GenerateHooks generates git server-side hooks for a repository. Currently, it supports the following hooks: -// - pre-receive -// - update -// - post-receive -// - post-update -// -// This function should be called by the backend when a repository is created. -// TODO: support context. -func GenerateHooks(_ context.Context, cfg *config.Config, repo string) error { - // TODO: support git hook tests. - if flag.Lookup("test.v") != nil { - log.WithPrefix("backend.hooks").Warn("refusing to set up hooks when in test") - return nil - } - repo = utils.SanitizeRepo(repo) + ".git" - hooksPath := filepath.Join(cfg.DataPath, "repos", repo, "hooks") - if err := os.MkdirAll(hooksPath, os.ModePerm); err != nil { - return err - } - - ex, err := os.Executable() - if err != nil { - return err - } - - dp, err := filepath.Abs(cfg.DataPath) - if err != nil { - return fmt.Errorf("failed to get absolute path for data path: %w", err) - } - - cp := filepath.Join(dp, "config.yaml") - // Add extra environment variables to the hooks here. - envs := []string{} - - for _, hook := range []string{ - PreReceiveHook, - UpdateHook, - PostReceiveHook, - PostUpdateHook, - } { - var data bytes.Buffer - var args string - - // Hooks script/directory path - hp := filepath.Join(hooksPath, hook) - - // Write the hooks primary script - if err := os.WriteFile(hp, []byte(hookTemplate), os.ModePerm); err != nil { - return err - } - - // Create ${hook}.d directory. - hp += ".d" - if err := os.MkdirAll(hp, os.ModePerm); err != nil { - return err - } - - switch hook { - case UpdateHook: - args = "$1 $2 $3" - case PostUpdateHook: - args = "$@" - } - - if err := hooksTmpl.Execute(&data, struct { - Executable string - Config string - Envs []string - Hook string - Args string - }{ - Executable: ex, - Config: cp, - Envs: envs, - Hook: hook, - Args: args, - }); err != nil { - log.WithPrefix("backend.hooks").Error("failed to execute hook template", "err", err) - continue - } - - // Write the soft-serve hook inside ${hook}.d directory. - hp = filepath.Join(hp, "soft-serve") - err = os.WriteFile(hp, data.Bytes(), os.ModePerm) //nolint:gosec - if err != nil { - log.WithPrefix("backend.hooks").Error("failed to write hook", "err", err) - continue - } - } - - return nil +// HookArg is an argument to a git hook. +type HookArg struct { + OldSha string + NewSha string + RefName string } -const ( - // hookTemplate allows us to run multiple hooks from a directory. It should - // support every type of git hook, as it proxies both stdin and arguments. - hookTemplate = `#!/usr/bin/env bash -# AUTO GENERATED BY SOFT SERVE, DO NOT MODIFY -data=$(cat) -exitcodes="" -hookname=$(basename $0) -GIT_DIR=${GIT_DIR:-$(dirname $0)/..} -for hook in ${GIT_DIR}/hooks/${hookname}.d/*; do - # Avoid running non-executable hooks - test -x "${hook}" && test -f "${hook}" || continue - - # Run the actual hook - echo "${data}" | "${hook}" "$@" - - # Store the exit code for later use - exitcodes="${exitcodes} $?" -done - -# Exit on the first non-zero exit code. -for i in ${exitcodes}; do - [ ${i} -eq 0 ] || exit ${i} -done -` -) - -// hooksTmpl is the soft-serve hook that will be run by the git hooks -// inside the hooks directory. -var hooksTmpl = template.Must(template.New("hooks").Parse(`#!/usr/bin/env bash -# AUTO GENERATED BY SOFT SERVE, DO NOT MODIFY -if [ -z "$SOFT_SERVE_REPO_NAME" ]; then - echo "Warning: SOFT_SERVE_REPO_NAME not defined. Skipping hooks." - exit 0 -fi -{{ range $_, $env := .Envs }} -{{ $env }} \{{ end }} -{{ .Executable }} hook --config "{{ .Config }}" {{ .Hook }} {{ .Args }} -`)) +// Hooks provides an interface for git server-side hooks. +type Hooks interface { + PreReceive(ctx context.Context, stdout io.Writer, stderr io.Writer, repo string, args []HookArg) + Update(ctx context.Context, stdout io.Writer, stderr io.Writer, repo string, arg HookArg) + PostReceive(ctx context.Context, stdout io.Writer, stderr io.Writer, repo string, args []HookArg) + PostUpdate(ctx context.Context, stdout io.Writer, stderr io.Writer, repo string, args ...string) +} diff --git a/server/jobs.go b/server/jobs.go index 239b08d8bf6f1c89876c6f2b447ba8774c4b7729..2cfbd09844200eb92f4596b8f6eb32911775d315 100644 --- a/server/jobs.go +++ b/server/jobs.go @@ -6,7 +6,8 @@ import ( "runtime" "github.com/charmbracelet/soft-serve/git" - "github.com/charmbracelet/soft-serve/internal/sync" + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/sync" ) var jobSpecs = map[string]string{ @@ -14,12 +15,11 @@ var jobSpecs = map[string]string{ } // mirrorJob runs the (pull) mirror job task. -func (s *Server) mirrorJob() func() { +func (s *Server) mirrorJob(b *backend.Backend) func() { cfg := s.Config - b := cfg.Backend logger := s.logger return func() { - repos, err := b.Repositories() + repos, err := b.Repositories(s.ctx) if err != nil { logger.Error("error getting repositories", "err", err) return @@ -48,6 +48,7 @@ func (s *Server) mirrorJob() func() { cfg.SSH.ClientKeyPath, ), ) + if _, err := cmd.RunInDir(r.Path); err != nil { logger.Error("error running git remote update", "repo", name, "err", err) } diff --git a/server/proto/errors.go b/server/proto/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..b2d12dd07abf85061710e13a6611e3c2f310802f --- /dev/null +++ b/server/proto/errors.go @@ -0,0 +1,16 @@ +package proto + +import ( + "errors" +) + +var ( + // ErrUnauthorized is returned when the user is not authorized to perform action. + ErrUnauthorized = errors.New("Unauthorized") + // ErrFileNotFound is returned when the file is not found. + ErrFileNotFound = errors.New("File not found") + // ErrRepoNotExist is returned when a repository does not exist. + ErrRepoNotExist = errors.New("repository does not exist") + // ErrRepoExist is returned when a repository already exists. + ErrRepoExist = errors.New("repository already exists") +) diff --git a/server/proto/repo.go b/server/proto/repo.go new file mode 100644 index 0000000000000000000000000000000000000000..68d88741bce6cadc07ba2c4b685e364c53d03ff6 --- /dev/null +++ b/server/proto/repo.go @@ -0,0 +1,37 @@ +package proto + +import ( + "time" + + "github.com/charmbracelet/soft-serve/git" +) + +// Repository is a Git repository interface. +type Repository interface { + // Name returns the repository's name. + Name() string + // ProjectName returns the repository's project name. + ProjectName() string + // Description returns the repository's description. + Description() string + // IsPrivate returns whether the repository is private. + IsPrivate() bool + // IsMirror returns whether the repository is a mirror. + IsMirror() bool + // IsHidden returns whether the repository is hidden. + IsHidden() bool + // UpdatedAt returns the time the repository was last updated. + // If the repository has never been updated, it returns the time it was created. + UpdatedAt() time.Time + // Open returns the underlying git.Repository. + Open() (*git.Repository, error) +} + +// RepositoryOptions are options for creating a new repository. +type RepositoryOptions struct { + Private bool + Description string + ProjectName string + Mirror bool + Hidden bool +} diff --git a/server/proto/user.go b/server/proto/user.go new file mode 100644 index 0000000000000000000000000000000000000000..6276a14b7fe13497a3b802dcccca677eebd19c97 --- /dev/null +++ b/server/proto/user.go @@ -0,0 +1,21 @@ +package proto + +import "golang.org/x/crypto/ssh" + +// User is an interface representing a user. +type User interface { + // Username returns the user's username. + Username() string + // IsAdmin returns whether the user is an admin. + IsAdmin() bool + // PublicKeys returns the user's public keys. + PublicKeys() []ssh.PublicKey +} + +// UserOptions are options for creating a user. +type UserOptions struct { + // Admin is whether the user is an admin. + Admin bool + // PublicKeys are the user's public keys. + PublicKeys []ssh.PublicKey +} diff --git a/server/server.go b/server/server.go index 1c0470af365ea9a8ba70158b298262dd26073f7a..0b3f23b0e2b3055597d32e4568442218f2a84f92 100644 --- a/server/server.go +++ b/server/server.go @@ -4,16 +4,15 @@ import ( "context" "errors" "fmt" - "io" "net/http" "github.com/charmbracelet/log" "github.com/charmbracelet/soft-serve/server/backend" - "github.com/charmbracelet/soft-serve/server/backend/sqlite" "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/cron" "github.com/charmbracelet/soft-serve/server/daemon" + "github.com/charmbracelet/soft-serve/server/db" sshsrv "github.com/charmbracelet/soft-serve/server/ssh" "github.com/charmbracelet/soft-serve/server/stats" "github.com/charmbracelet/soft-serve/server/web" @@ -27,43 +26,35 @@ type Server struct { GitDaemon *daemon.GitDaemon HTTPServer *web.HTTPServer StatsServer *stats.StatsServer - Cron *cron.CronScheduler + Cron *cron.Scheduler Config *config.Config - Backend backend.Backend + Backend *backend.Backend + DB *db.DB logger *log.Logger ctx context.Context } -// NewServer returns a new *ssh.Server configured to serve Soft Serve. The SSH -// server key-pair will be created if none exists. An initial admin SSH public -// key can be provided with authKey. If authKey is provided, access will be -// restricted to that key. If authKey is not provided, the server will be -// publicly writable until configured otherwise by cloning the `config` repo. +// NewServer returns a new *Server configured to serve Soft Serve. The SSH +// server key-pair will be created if none exists. +// It expects a context with *backend.Backend, *db.DB, *log.Logger, and +// *config.Config attached. func NewServer(ctx context.Context) (*Server, error) { - cfg := config.FromContext(ctx) - var err error - if cfg.Backend == nil { - sb, err := sqlite.NewSqliteBackend(ctx) - if err != nil { - return nil, fmt.Errorf("create backend: %w", err) - } - - cfg = cfg.WithBackend(sb) - ctx = backend.WithContext(ctx, sb) - } - + cfg := config.FromContext(ctx) + be := backend.FromContext(ctx) + db := db.FromContext(ctx) srv := &Server{ - Cron: cron.NewCronScheduler(ctx), + Cron: cron.NewScheduler(ctx), Config: cfg, - Backend: cfg.Backend, + Backend: be, + DB: db, logger: log.FromContext(ctx).WithPrefix("server"), ctx: ctx, } // Add cron jobs. - _, _ = srv.Cron.AddFunc(jobSpecs["mirror"], srv.mirrorJob()) + _, _ = srv.Cron.AddFunc(jobSpecs["mirror"], srv.mirrorJob(be)) srv.SSHServer, err = sshsrv.NewSSHServer(ctx) if err != nil { @@ -88,47 +79,33 @@ func NewServer(ctx context.Context) (*Server, error) { return srv, nil } -func start(ctx context.Context, fn func() error) error { - errc := make(chan error, 1) - go func() { - errc <- fn() - }() - - select { - case err := <-errc: - return err - case <-ctx.Done(): - return ctx.Err() - } -} - // Start starts the SSH server. func (s *Server) Start() error { - errg, ctx := errgroup.WithContext(s.ctx) + errg, _ := errgroup.WithContext(s.ctx) errg.Go(func() error { s.logger.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr) - if err := start(ctx, s.GitDaemon.Start); !errors.Is(err, daemon.ErrServerClosed) { + if err := s.GitDaemon.Start(); !errors.Is(err, daemon.ErrServerClosed) { return err } return nil }) errg.Go(func() error { s.logger.Print("Starting HTTP server", "addr", s.Config.HTTP.ListenAddr) - if err := start(ctx, s.HTTPServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) { + if err := s.HTTPServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { return err } return nil }) errg.Go(func() error { s.logger.Print("Starting SSH server", "addr", s.Config.SSH.ListenAddr) - if err := start(ctx, s.SSHServer.ListenAndServe); !errors.Is(err, ssh.ErrServerClosed) { + if err := s.SSHServer.ListenAndServe(); !errors.Is(err, ssh.ErrServerClosed) { return err } return nil }) errg.Go(func() error { s.logger.Print("Starting Stats server", "addr", s.Config.Stats.ListenAddr) - if err := start(ctx, s.StatsServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) { + if err := s.StatsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { return err } return nil @@ -142,7 +119,7 @@ func (s *Server) Start() error { // Shutdown lets the server gracefully shutdown. func (s *Server) Shutdown(ctx context.Context) error { - var errg errgroup.Group + errg, ctx := errgroup.WithContext(ctx) errg.Go(func() error { return s.GitDaemon.Shutdown(ctx) }) @@ -159,9 +136,7 @@ func (s *Server) Shutdown(ctx context.Context) error { s.Cron.Stop() return nil }) - if closer, ok := s.Backend.(io.Closer); ok { - defer closer.Close() // nolint: errcheck - } + // defer s.DB.Close() // nolint: errcheck return errg.Wait() } @@ -176,8 +151,6 @@ func (s *Server) Close() error { s.Cron.Stop() return nil }) - if closer, ok := s.Backend.(io.Closer); ok { - defer closer.Close() // nolint: errcheck - } + // defer s.DB.Close() // nolint: errcheck return errg.Wait() } diff --git a/server/server_test.go b/server/server_test.go deleted file mode 100644 index a856f1f4c33aeec35607d0b19420378eb0432776..0000000000000000000000000000000000000000 --- a/server/server_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package server - -import ( - "context" - "fmt" - "path/filepath" - "strings" - "testing" - - "github.com/charmbracelet/keygen" - "github.com/charmbracelet/soft-serve/server/config" - "github.com/charmbracelet/soft-serve/server/test" - "github.com/charmbracelet/ssh" - "github.com/matryer/is" - gossh "golang.org/x/crypto/ssh" -) - -func setupServer(tb testing.TB) (*Server, *config.Config, string) { - tb.Helper() - tb.Log("creating keypair") - pub, pkPath := createKeyPair(tb) - dp := tb.TempDir() - sshPort := fmt.Sprintf(":%d", test.RandomPort()) - tb.Setenv("SOFT_SERVE_DATA_PATH", dp) - tb.Setenv("SOFT_SERVE_INITIAL_ADMIN_KEY", authorizedKey(pub)) - tb.Setenv("SOFT_SERVE_SSH_LISTEN_ADDR", sshPort) - tb.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort())) - ctx := context.TODO() - cfg := config.DefaultConfig() - ctx = config.WithContext(ctx, cfg) - tb.Log("configuring server") - s, err := NewServer(ctx) - if err != nil { - tb.Fatal(err) - } - go func() { - tb.Log("starting server") - s.Start() - }() - tb.Cleanup(func() { - s.Close() - }) - return s, cfg, pkPath -} - -func createKeyPair(tb testing.TB) (ssh.PublicKey, string) { - tb.Helper() - is := is.New(tb) - keyDir := tb.TempDir() - fp := filepath.Join(keyDir, "id_ed25519") - kp, err := keygen.New(fp, keygen.WithKeyType(keygen.Ed25519), keygen.WithWrite()) - is.NoErr(err) - return kp.PublicKey(), fp -} - -func authorizedKey(pk ssh.PublicKey) string { - return strings.TrimSpace(string(gossh.MarshalAuthorizedKey(pk))) -} diff --git a/server/ssh/cmd.go b/server/ssh/cmd.go new file mode 100644 index 0000000000000000000000000000000000000000..f2563f84208aaed4932bc12b40ab89950537758b --- /dev/null +++ b/server/ssh/cmd.go @@ -0,0 +1,17 @@ +package ssh + +import ( + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/ssh/cmd" + "github.com/charmbracelet/ssh" +) + +func handleCli(s ssh.Session) { + ctx := s.Context() + logger := log.FromContext(ctx) + rootCmd := cmd.RootCommand(s) + if err := rootCmd.ExecuteContext(ctx); err != nil { + logger.Error("error executing command", "err", err) + _ = s.Exit(1) + } +} diff --git a/server/cmd/blob.go b/server/ssh/cmd/blob.go similarity index 93% rename from server/cmd/blob.go rename to server/ssh/cmd/blob.go index d1388fb5f9283b32b001464f6bded9968ca123ee..5065bd1c63177cd86cf0c0e9d4af6515eded0f1a 100644 --- a/server/cmd/blob.go +++ b/server/ssh/cmd/blob.go @@ -8,6 +8,7 @@ import ( gansi "github.com/charmbracelet/glamour/ansi" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/soft-serve/git" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/muesli/termenv" "github.com/spf13/cobra" @@ -16,9 +17,6 @@ import ( var ( lineDigitStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("239")) lineBarStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("236")) - dirnameStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#00AAFF")) - filenameStyle = lipgloss.NewStyle() - filemodeStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#777777")) ) // blobCommand returns a command that prints the contents of a file. @@ -34,7 +32,8 @@ func blobCommand() *cobra.Command { Args: cobra.RangeArgs(1, 3), PersistentPreRunE: checkIfReadable, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := args[0] ref := "" fp := "" @@ -46,7 +45,7 @@ func blobCommand() *cobra.Command { fp = args[2] } - repo, err := cfg.Backend.Repository(rn) + repo, err := be.Repository(ctx, rn) if err != nil { return err } diff --git a/server/cmd/branch.go b/server/ssh/cmd/branch.go similarity index 86% rename from server/cmd/branch.go rename to server/ssh/cmd/branch.go index 0b79ecceea8e730d6a00a4f9147761e796ba7dc9..8566a7d4d1945c7390267ad1eab56b4538339dde 100644 --- a/server/cmd/branch.go +++ b/server/ssh/cmd/branch.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/charmbracelet/soft-serve/git" + "github.com/charmbracelet/soft-serve/server/backend" gitm "github.com/gogs/git-module" "github.com/spf13/cobra" ) @@ -31,9 +32,10 @@ func branchListCommand() *cobra.Command { Args: cobra.ExactArgs(1), PersistentPreRunE: checkIfReadable, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := strings.TrimSuffix(args[0], ".git") - rr, err := cfg.Backend.Repository(rn) + rr, err := be.Repository(ctx, rn) if err != nil { return err } @@ -61,14 +63,15 @@ func branchDefaultCommand() *cobra.Command { Short: "Set or get the default branch", Args: cobra.RangeArgs(1, 2), RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := strings.TrimSuffix(args[0], ".git") switch len(args) { case 1: if err := checkIfReadable(cmd, args); err != nil { return err } - rr, err := cfg.Backend.Repository(rn) + rr, err := be.Repository(ctx, rn) if err != nil { return err } @@ -89,7 +92,7 @@ func branchDefaultCommand() *cobra.Command { return err } - rr, err := cfg.Backend.Repository(rn) + rr, err := be.Repository(ctx, rn) if err != nil { return err } @@ -132,9 +135,10 @@ func branchDeleteCommand() *cobra.Command { Short: "Delete a branch", PersistentPreRunE: checkIfCollab, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := strings.TrimSuffix(args[0], ".git") - rr, err := cfg.Backend.Repository(rn) + rr, err := be.Repository(ctx, rn) if err != nil { return err } @@ -167,11 +171,7 @@ func branchDeleteCommand() *cobra.Command { return fmt.Errorf("cannot delete the default branch") } - if err := r.DeleteBranch(branch, gitm.DeleteBranchOptions{Force: true}); err != nil { - return err - } - - return nil + return r.DeleteBranch(branch, gitm.DeleteBranchOptions{Force: true}) }, } diff --git a/server/cmd/cmd.go b/server/ssh/cmd/cmd.go similarity index 64% rename from server/cmd/cmd.go rename to server/ssh/cmd/cmd.go index 16d89d1639fd2153899a90da431e7e7fbe8fa448..d7e3f5ddcc9bfe2c78437a4b0a7dff361c6c094d 100644 --- a/server/cmd/cmd.go +++ b/server/ssh/cmd/cmd.go @@ -1,28 +1,24 @@ package cmd import ( - "context" "fmt" "net/url" "strings" "text/template" "unicode" - "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/access" "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" - "github.com/charmbracelet/soft-serve/server/errors" + "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/sshutils" "github.com/charmbracelet/soft-serve/server/utils" "github.com/charmbracelet/ssh" - "github.com/charmbracelet/wish" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/spf13/cobra" ) -// sessionCtxKey is the key for the session in the context. -var sessionCtxKey = &struct{ string }{"session"} - var cliCommandCounter = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: "soft_serve", Subsystem: "cli", @@ -89,9 +85,14 @@ func cmdName(args []string) string { return args[0] } -// rootCommand is the root command for the server. -func rootCommand(cfg *config.Config, s ssh.Session) *cobra.Command { - cliCommandCounter.WithLabelValues(cmdName(s.Command())).Inc() +// RootCommand returns a new cli root command. +func RootCommand(s ssh.Session) *cobra.Command { + ctx := s.Context() + cfg := config.FromContext(ctx) + be := backend.FromContext(ctx) + + args := s.Command() + cliCommandCounter.WithLabelValues(cmdName(args)).Inc() rootCmd := &cobra.Command{ Short: "Soft Serve is a self-hostable Git server for the command line.", SilenceUsage: true, @@ -129,7 +130,17 @@ func rootCommand(cfg *config.Config, s ssh.Session) *cobra.Command { repoCommand(), ) - user, _ := cfg.Backend.UserByPublicKey(s.PublicKey()) + rootCmd.SetArgs(args) + if len(args) == 0 { + // otherwise it'll default to os.Args, which is not what we want. + rootCmd.SetArgs([]string{"--help"}) + } + rootCmd.SetIn(s) + rootCmd.SetOut(s) + rootCmd.CompletionOptions.DisableDefaultCmd = true + rootCmd.SetErr(s.Stderr()) + + user, _ := be.UserByPublicKey(s.Context(), s.PublicKey()) isAdmin := isPublicKeyAdmin(cfg, s.PublicKey()) || (user != nil && user.IsAdmin()) if user != nil || isAdmin { if isAdmin { @@ -149,30 +160,26 @@ func rootCommand(cfg *config.Config, s ssh.Session) *cobra.Command { return rootCmd } -func fromContext(cmd *cobra.Command) (*config.Config, ssh.Session) { - ctx := cmd.Context() - cfg := config.FromContext(ctx) - s := ctx.Value(sessionCtxKey).(ssh.Session) - return cfg, s -} - func checkIfReadable(cmd *cobra.Command, args []string) error { var repo string if len(args) > 0 { repo = args[0] } - cfg, s := fromContext(cmd) + + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := utils.SanitizeRepo(repo) - auth := cfg.Backend.AccessLevelByPublicKey(rn, s.PublicKey()) - if auth < backend.ReadOnlyAccess { - return errors.ErrUnauthorized + pk := sshutils.PublicKeyFromContext(ctx) + auth := be.AccessLevelByPublicKey(cmd.Context(), rn, pk) + if auth < access.ReadOnlyAccess { + return proto.ErrUnauthorized } return nil } func isPublicKeyAdmin(cfg *config.Config, pk ssh.PublicKey) bool { for _, k := range cfg.AdminKeys() { - if backend.KeysEqual(pk, k) { + if sshutils.KeysEqual(pk, k) { return true } } @@ -180,18 +187,21 @@ func isPublicKeyAdmin(cfg *config.Config, pk ssh.PublicKey) bool { } func checkIfAdmin(cmd *cobra.Command, _ []string) error { - cfg, s := fromContext(cmd) - if isPublicKeyAdmin(cfg, s.PublicKey()) { + ctx := cmd.Context() + be := backend.FromContext(ctx) + cfg := config.FromContext(ctx) + pk := sshutils.PublicKeyFromContext(ctx) + if isPublicKeyAdmin(cfg, pk) { return nil } - user, _ := cfg.Backend.UserByPublicKey(s.PublicKey()) + user, _ := be.UserByPublicKey(ctx, pk) if user == nil { - return errors.ErrUnauthorized + return proto.ErrUnauthorized } if !user.IsAdmin() { - return errors.ErrUnauthorized + return proto.ErrUnauthorized } return nil @@ -202,61 +212,14 @@ func checkIfCollab(cmd *cobra.Command, args []string) error { if len(args) > 0 { repo = args[0] } - cfg, s := fromContext(cmd) + + ctx := cmd.Context() + be := backend.FromContext(ctx) + pk := sshutils.PublicKeyFromContext(ctx) rn := utils.SanitizeRepo(repo) - auth := cfg.Backend.AccessLevelByPublicKey(rn, s.PublicKey()) - if auth < backend.ReadWriteAccess { - return errors.ErrUnauthorized + auth := be.AccessLevelByPublicKey(ctx, rn, pk) + if auth < access.ReadWriteAccess { + return proto.ErrUnauthorized } return nil } - -// Middleware is the Soft Serve middleware that handles SSH commands. -func Middleware(cfg *config.Config, logger *log.Logger) wish.Middleware { - return func(sh ssh.Handler) ssh.Handler { - return func(s ssh.Session) { - func() { - _, _, active := s.Pty() - if active { - return - } - - // Ignore git server commands. - args := s.Command() - if len(args) > 0 { - if args[0] == "git-receive-pack" || - args[0] == "git-upload-pack" || - args[0] == "git-upload-archive" { - return - } - } - - // Here we copy the server's config and replace the backend - // with a new one that uses the session's context. - var ctx context.Context = s.Context() - scfg := *cfg - cfg = &scfg - be := cfg.Backend.WithContext(ctx) - cfg.Backend = be - ctx = config.WithContext(ctx, cfg) - ctx = backend.WithContext(ctx, be) - ctx = context.WithValue(ctx, sessionCtxKey, s) - - rootCmd := rootCommand(cfg, s) - rootCmd.SetArgs(args) - if len(args) == 0 { - // otherwise it'll default to os.Args, which is not what we want. - rootCmd.SetArgs([]string{"--help"}) - } - rootCmd.SetIn(s) - rootCmd.SetOut(s) - rootCmd.CompletionOptions.DisableDefaultCmd = true - rootCmd.SetErr(s.Stderr()) - if err := rootCmd.ExecuteContext(ctx); err != nil { - _ = s.Exit(1) - } - }() - sh(s) - } - } -} diff --git a/server/cmd/collab.go b/server/ssh/cmd/collab.go similarity index 80% rename from server/cmd/collab.go rename to server/ssh/cmd/collab.go index 08baf5a065cec13748c99cbae5c1a8528b4ba9de..92a0829d8b4dc5ba0db8449233db2825b5e60aeb 100644 --- a/server/cmd/collab.go +++ b/server/ssh/cmd/collab.go @@ -1,6 +1,7 @@ package cmd import ( + "github.com/charmbracelet/soft-serve/server/backend" "github.com/spf13/cobra" ) @@ -27,11 +28,12 @@ func collabAddCommand() *cobra.Command { Args: cobra.ExactArgs(2), PersistentPreRunE: checkIfCollab, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) repo := args[0] username := args[1] - return cfg.Backend.AddCollaborator(repo, username) + return be.AddCollaborator(ctx, repo, username) }, } @@ -45,11 +47,12 @@ func collabRemoveCommand() *cobra.Command { Short: "Remove a collaborator from a repo", PersistentPreRunE: checkIfCollab, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) repo := args[0] username := args[1] - return cfg.Backend.RemoveCollaborator(repo, username) + return be.RemoveCollaborator(ctx, repo, username) }, } @@ -63,9 +66,10 @@ func collabListCommand() *cobra.Command { Args: cobra.ExactArgs(1), PersistentPreRunE: checkIfCollab, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) repo := args[0] - collabs, err := cfg.Backend.Collaborators(repo) + collabs, err := be.Collaborators(ctx, repo) if err != nil { return err } diff --git a/server/cmd/commit.go b/server/ssh/cmd/commit.go similarity index 90% rename from server/cmd/commit.go rename to server/ssh/cmd/commit.go index 9b96f9103e6790ab68c133383ba24da416a61c42..f8ffaa1cf18aaaf3fb663de66aeabeb656c52028 100644 --- a/server/cmd/commit.go +++ b/server/ssh/cmd/commit.go @@ -7,6 +7,7 @@ import ( gansi "github.com/charmbracelet/glamour/ansi" "github.com/charmbracelet/soft-serve/git" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/charmbracelet/soft-serve/server/ui/styles" "github.com/muesli/termenv" @@ -24,11 +25,12 @@ func commitCommand() *cobra.Command { Args: cobra.ExactArgs(2), PersistentPreRunE: checkIfReadable, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) repoName := args[0] commitSHA := args[1] - rr, err := cfg.Backend.Repository(repoName) + rr, err := be.Repository(ctx, repoName) if err != nil { return err } @@ -38,10 +40,13 @@ func commitCommand() *cobra.Command { return err } - raw_commit, err := r.CommitByRevision(commitSHA) + rawCommit, err := r.CommitByRevision(commitSHA) + if err != nil { + return err + } commit := &git.Commit{ - Commit: raw_commit, + Commit: rawCommit, Hash: git.Hash(commitSHA), } @@ -61,7 +66,7 @@ func commitCommand() *cobra.Command { s := strings.Builder{} commitLine := "commit " + commitSHA authorLine := "Author: " + commit.Author.Name - dateLine := "Date: " + commit.Committer.When.Format(time.UnixDate) + dateLine := "Date: " + commit.Committer.When.UTC().Format(time.UnixDate) msgLine := strings.ReplaceAll(commit.Message, "\r\n", "\n") statsLine := renderStats(diff, commonStyle, color) diffLine := renderDiff(patch, color) @@ -116,8 +121,7 @@ func renderCtx() gansi.RenderContext { } func renderDiff(patch string, color bool) string { - - c := string(patch) + c := patch if color { var s strings.Builder diff --git a/server/cmd/create.go b/server/ssh/cmd/create.go similarity index 85% rename from server/cmd/create.go rename to server/ssh/cmd/create.go index 0407ce6d9c78ca7cba9b486b788bb83f3f8adfe1..751c18e1438eb03b23d4649499378c61679cd906 100644 --- a/server/cmd/create.go +++ b/server/ssh/cmd/create.go @@ -2,6 +2,7 @@ package cmd import ( "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/spf13/cobra" ) @@ -18,9 +19,10 @@ func createCommand() *cobra.Command { Args: cobra.ExactArgs(1), PersistentPreRunE: checkIfCollab, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) name := args[0] - if _, err := cfg.Backend.CreateRepository(name, backend.RepositoryOptions{ + if _, err := be.CreateRepository(ctx, name, proto.RepositoryOptions{ Private: private, Description: description, ProjectName: projectName, diff --git a/server/cmd/delete.go b/server/ssh/cmd/delete.go similarity index 67% rename from server/cmd/delete.go rename to server/ssh/cmd/delete.go index fb3f1dbdf1a8d976133de9c742447d51b9576975..02dff775d469753744f051c89c2887ce3bb3d367 100644 --- a/server/cmd/delete.go +++ b/server/ssh/cmd/delete.go @@ -1,6 +1,9 @@ package cmd -import "github.com/spf13/cobra" +import ( + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/spf13/cobra" +) func deleteCommand() *cobra.Command { cmd := &cobra.Command{ @@ -10,12 +13,11 @@ func deleteCommand() *cobra.Command { Args: cobra.ExactArgs(1), PersistentPreRunE: checkIfCollab, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) name := args[0] - if err := cfg.Backend.DeleteRepository(name); err != nil { - return err - } - return nil + + return be.DeleteRepository(ctx, name) }, } return cmd diff --git a/server/cmd/description.go b/server/ssh/cmd/description.go similarity index 75% rename from server/cmd/description.go rename to server/ssh/cmd/description.go index 2708002725365adb75ea1e4c146460bc5c49642c..e80f3ff568ecf0c480d0501676218e9dd2454cee 100644 --- a/server/cmd/description.go +++ b/server/ssh/cmd/description.go @@ -3,6 +3,7 @@ package cmd import ( "strings" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/spf13/cobra" ) @@ -13,7 +14,8 @@ func descriptionCommand() *cobra.Command { Short: "Set or get the description for a repository", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := strings.TrimSuffix(args[0], ".git") switch len(args) { case 1: @@ -21,7 +23,7 @@ func descriptionCommand() *cobra.Command { return err } - desc, err := cfg.Backend.Description(rn) + desc, err := be.Description(ctx, rn) if err != nil { return err } @@ -31,7 +33,7 @@ func descriptionCommand() *cobra.Command { if err := checkIfCollab(cmd, args); err != nil { return err } - if err := cfg.Backend.SetDescription(rn, strings.Join(args[1:], " ")); err != nil { + if err := be.SetDescription(ctx, rn, strings.Join(args[1:], " ")); err != nil { return err } } diff --git a/server/cmd/hidden.go b/server/ssh/cmd/hidden.go similarity index 72% rename from server/cmd/hidden.go rename to server/ssh/cmd/hidden.go index 6c10d1a6d5e67f6013b236b0c8948d34761a7381..2e2f4f486130aff37eb2f1b4d7d04e00a8df5845 100644 --- a/server/cmd/hidden.go +++ b/server/ssh/cmd/hidden.go @@ -1,6 +1,9 @@ package cmd -import "github.com/spf13/cobra" +import ( + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/spf13/cobra" +) func hiddenCommand() *cobra.Command { cmd := &cobra.Command{ @@ -9,7 +12,8 @@ func hiddenCommand() *cobra.Command { Aliases: []string{"hide"}, Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) repo := args[0] switch len(args) { case 1: @@ -17,7 +21,7 @@ func hiddenCommand() *cobra.Command { return err } - hidden, err := cfg.Backend.IsHidden(repo) + hidden, err := be.IsHidden(ctx, repo) if err != nil { return err } @@ -29,7 +33,7 @@ func hiddenCommand() *cobra.Command { } hidden := args[1] == "true" - if err := cfg.Backend.SetHidden(repo, hidden); err != nil { + if err := be.SetHidden(ctx, repo, hidden); err != nil { return err } } diff --git a/server/cmd/import.go b/server/ssh/cmd/import.go similarity index 86% rename from server/cmd/import.go rename to server/ssh/cmd/import.go index e14875a2c97852eb669bfa032b58ee43588e66ae..55f73e4f5df78dfb21f41655bba96f6f04a8ba02 100644 --- a/server/cmd/import.go +++ b/server/ssh/cmd/import.go @@ -2,6 +2,7 @@ package cmd import ( "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/spf13/cobra" ) @@ -19,10 +20,11 @@ func importCommand() *cobra.Command { Args: cobra.ExactArgs(2), PersistentPreRunE: checkIfCollab, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) name := args[0] remote := args[1] - if _, err := cfg.Backend.ImportRepository(name, remote, backend.RepositoryOptions{ + if _, err := be.ImportRepository(ctx, name, remote, proto.RepositoryOptions{ Private: private, Description: description, ProjectName: projectName, diff --git a/server/cmd/info.go b/server/ssh/cmd/info.go similarity index 67% rename from server/cmd/info.go rename to server/ssh/cmd/info.go index 60d716d72dbfea017627928b3dff8ecc99338f36..7b2ac0aeff29402323345fdfd284cea919c03cd0 100644 --- a/server/cmd/info.go +++ b/server/ssh/cmd/info.go @@ -2,6 +2,7 @@ package cmd import ( "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/sshutils" "github.com/spf13/cobra" ) @@ -11,8 +12,10 @@ func infoCommand() *cobra.Command { Short: "Show your info", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - cfg, s := fromContext(cmd) - user, err := cfg.Backend.UserByPublicKey(s.PublicKey()) + ctx := cmd.Context() + be := backend.FromContext(ctx) + pk := sshutils.PublicKeyFromContext(ctx) + user, err := be.UserByPublicKey(ctx, pk) if err != nil { return err } @@ -21,7 +24,7 @@ func infoCommand() *cobra.Command { cmd.Printf("Admin: %t\n", user.IsAdmin()) cmd.Printf("Public keys:\n") for _, pk := range user.PublicKeys() { - cmd.Printf(" %s\n", backend.MarshalAuthorizedKey(pk)) + cmd.Printf(" %s\n", sshutils.MarshalAuthorizedKey(pk)) } return nil }, diff --git a/server/cmd/list.go b/server/ssh/cmd/list.go similarity index 67% rename from server/cmd/list.go rename to server/ssh/cmd/list.go index 9cfb936a5f2428c59259e12110456ee200128390..62334c871046f732d741bdb632adc07f0517ed06 100644 --- a/server/cmd/list.go +++ b/server/ssh/cmd/list.go @@ -1,7 +1,9 @@ package cmd import ( + "github.com/charmbracelet/soft-serve/server/access" "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/sshutils" "github.com/spf13/cobra" ) @@ -15,13 +17,15 @@ func listCommand() *cobra.Command { Short: "List repositories", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - cfg, s := fromContext(cmd) - repos, err := cfg.Backend.Repositories() + ctx := cmd.Context() + be := backend.FromContext(ctx) + pk := sshutils.PublicKeyFromContext(ctx) + repos, err := be.Repositories(ctx) if err != nil { return err } for _, r := range repos { - if cfg.Backend.AccessLevelByPublicKey(r.Name(), s.PublicKey()) >= backend.ReadOnlyAccess { + if be.AccessLevelByPublicKey(ctx, r.Name(), pk) >= access.ReadOnlyAccess { if !r.IsHidden() || all { cmd.Println(r.Name()) } diff --git a/server/cmd/mirror.go b/server/ssh/cmd/mirror.go similarity index 76% rename from server/cmd/mirror.go rename to server/ssh/cmd/mirror.go index 34e785ab7037e9badc5e2f25bcfa9d5c5242175f..d2e641c23ddec19deeaaa65ce6ecafcbbdd0a055 100644 --- a/server/cmd/mirror.go +++ b/server/ssh/cmd/mirror.go @@ -1,6 +1,7 @@ package cmd import ( + "github.com/charmbracelet/soft-serve/server/backend" "github.com/spf13/cobra" ) @@ -11,9 +12,10 @@ func mirrorCommand() *cobra.Command { Args: cobra.ExactArgs(1), PersistentPreRunE: checkIfReadable, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := args[0] - rr, err := cfg.Backend.Repository(rn) + rr, err := be.Repository(ctx, rn) if err != nil { return err } diff --git a/server/cmd/private.go b/server/ssh/cmd/private.go similarity index 78% rename from server/cmd/private.go rename to server/ssh/cmd/private.go index 3b5181d48599aa38dbd7db64b2e32efc0cb58835..f835c804a3800eff0ad096e3bf8d806e5832d706 100644 --- a/server/cmd/private.go +++ b/server/ssh/cmd/private.go @@ -4,6 +4,7 @@ import ( "strconv" "strings" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/spf13/cobra" ) @@ -13,7 +14,8 @@ func privateCommand() *cobra.Command { Short: "Set or get a repository private property", Args: cobra.RangeArgs(1, 2), RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := strings.TrimSuffix(args[0], ".git") switch len(args) { @@ -22,7 +24,7 @@ func privateCommand() *cobra.Command { return err } - isPrivate, err := cfg.Backend.IsPrivate(rn) + isPrivate, err := be.IsPrivate(ctx, rn) if err != nil { return err } @@ -36,7 +38,7 @@ func privateCommand() *cobra.Command { if err := checkIfCollab(cmd, args); err != nil { return err } - if err := cfg.Backend.SetPrivate(rn, isPrivate); err != nil { + if err := be.SetPrivate(ctx, rn, isPrivate); err != nil { return err } } diff --git a/server/cmd/project_name.go b/server/ssh/cmd/project_name.go similarity index 75% rename from server/cmd/project_name.go rename to server/ssh/cmd/project_name.go index 62e7f82b15c178b59738763c786986de560decfa..8eb9b05ad9b61c0d79b4fc2dcfcc0f97c54f8904 100644 --- a/server/cmd/project_name.go +++ b/server/ssh/cmd/project_name.go @@ -3,6 +3,7 @@ package cmd import ( "strings" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/spf13/cobra" ) @@ -13,7 +14,8 @@ func projectName() *cobra.Command { Short: "Set or get the project name for a repository", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := strings.TrimSuffix(args[0], ".git") switch len(args) { case 1: @@ -21,7 +23,7 @@ func projectName() *cobra.Command { return err } - pn, err := cfg.Backend.ProjectName(rn) + pn, err := be.ProjectName(ctx, rn) if err != nil { return err } @@ -31,7 +33,7 @@ func projectName() *cobra.Command { if err := checkIfCollab(cmd, args); err != nil { return err } - if err := cfg.Backend.SetProjectName(rn, strings.Join(args[1:], " ")); err != nil { + if err := be.SetProjectName(ctx, rn, strings.Join(args[1:], " ")); err != nil { return err } } diff --git a/server/cmd/pubkey.go b/server/ssh/cmd/pubkey.go similarity index 61% rename from server/cmd/pubkey.go rename to server/ssh/cmd/pubkey.go index 7c1bea9b34b0507f14c2709e9b500e993b697d08..e2200b1dc7d05dd878687c339654b5bfb9b4d2f7 100644 --- a/server/cmd/pubkey.go +++ b/server/ssh/cmd/pubkey.go @@ -4,6 +4,7 @@ import ( "strings" "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/sshutils" "github.com/spf13/cobra" ) @@ -19,18 +20,20 @@ func pubkeyCommand() *cobra.Command { Short: "Add a public key", Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - cfg, s := fromContext(cmd) - user, err := cfg.Backend.UserByPublicKey(s.PublicKey()) + ctx := cmd.Context() + be := backend.FromContext(ctx) + pk := sshutils.PublicKeyFromContext(ctx) + user, err := be.UserByPublicKey(ctx, pk) if err != nil { return err } - pk, _, err := backend.ParseAuthorizedKey(strings.Join(args, " ")) + apk, _, err := sshutils.ParseAuthorizedKey(strings.Join(args, " ")) if err != nil { return err } - return cfg.Backend.AddPublicKey(user.Username(), pk) + return be.AddPublicKey(ctx, user.Username(), apk) }, } @@ -39,18 +42,20 @@ func pubkeyCommand() *cobra.Command { Args: cobra.MinimumNArgs(1), Short: "Remove a public key", RunE: func(cmd *cobra.Command, args []string) error { - cfg, s := fromContext(cmd) - user, err := cfg.Backend.UserByPublicKey(s.PublicKey()) + ctx := cmd.Context() + be := backend.FromContext(ctx) + pk := sshutils.PublicKeyFromContext(ctx) + user, err := be.UserByPublicKey(ctx, pk) if err != nil { return err } - pk, _, err := backend.ParseAuthorizedKey(strings.Join(args, " ")) + apk, _, err := sshutils.ParseAuthorizedKey(strings.Join(args, " ")) if err != nil { return err } - return cfg.Backend.RemovePublicKey(user.Username(), pk) + return be.RemovePublicKey(ctx, user.Username(), apk) }, } @@ -60,15 +65,17 @@ func pubkeyCommand() *cobra.Command { Short: "List public keys", Args: cobra.NoArgs, RunE: func(cmd *cobra.Command, args []string) error { - cfg, s := fromContext(cmd) - user, err := cfg.Backend.UserByPublicKey(s.PublicKey()) + ctx := cmd.Context() + be := backend.FromContext(ctx) + pk := sshutils.PublicKeyFromContext(ctx) + user, err := be.UserByPublicKey(ctx, pk) if err != nil { return err } pks := user.PublicKeys() for _, pk := range pks { - cmd.Println(backend.MarshalAuthorizedKey(pk)) + cmd.Println(sshutils.MarshalAuthorizedKey(pk)) } return nil diff --git a/server/cmd/rename.go b/server/ssh/cmd/rename.go similarity index 67% rename from server/cmd/rename.go rename to server/ssh/cmd/rename.go index d3ab7b0ac8c52379f886b79a3826146550beba82..2e92907b0e0e422a180cdbc34382396ba69b4665 100644 --- a/server/cmd/rename.go +++ b/server/ssh/cmd/rename.go @@ -1,6 +1,9 @@ package cmd -import "github.com/spf13/cobra" +import ( + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/spf13/cobra" +) func renameCommand() *cobra.Command { cmd := &cobra.Command{ @@ -10,13 +13,12 @@ func renameCommand() *cobra.Command { Args: cobra.ExactArgs(2), PersistentPreRunE: checkIfCollab, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) oldName := args[0] newName := args[1] - if err := cfg.Backend.RenameRepository(oldName, newName); err != nil { - return err - } - return nil + + return be.RenameRepository(ctx, oldName, newName) }, } diff --git a/server/cmd/repo.go b/server/ssh/cmd/repo.go similarity index 92% rename from server/cmd/repo.go rename to server/ssh/cmd/repo.go index 6371f167da813f2fe988193a16a631874d9e7bfa..24f4b8c0622421b117346a349c1e2f38fb3de4d3 100644 --- a/server/cmd/repo.go +++ b/server/ssh/cmd/repo.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/spf13/cobra" ) @@ -40,9 +41,10 @@ func repoCommand() *cobra.Command { Args: cobra.ExactArgs(1), PersistentPreRunE: checkIfReadable, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := args[0] - rr, err := cfg.Backend.Repository(rn) + rr, err := be.Repository(ctx, rn) if err != nil { return err } diff --git a/server/ssh/cmd/set_username.go b/server/ssh/cmd/set_username.go new file mode 100644 index 0000000000000000000000000000000000000000..71243ac5f5c39b7f2308473d64adec009639d53a --- /dev/null +++ b/server/ssh/cmd/set_username.go @@ -0,0 +1,28 @@ +package cmd + +import ( + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/sshutils" + "github.com/spf13/cobra" +) + +func setUsernameCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "set-username USERNAME", + Short: "Set your username", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + be := backend.FromContext(ctx) + pk := sshutils.PublicKeyFromContext(ctx) + user, err := be.UserByPublicKey(ctx, pk) + if err != nil { + return err + } + + return be.SetUsername(ctx, user.Username(), args[0]) + }, + } + + return cmd +} diff --git a/server/cmd/settings.go b/server/ssh/cmd/settings.go similarity index 70% rename from server/cmd/settings.go rename to server/ssh/cmd/settings.go index 95ef01a24c9adfbb14f8687e6bc2ccd3b412dee4..0fa4cace84aa42684d595f269706eea82bf39530 100644 --- a/server/cmd/settings.go +++ b/server/ssh/cmd/settings.go @@ -4,6 +4,7 @@ import ( "fmt" "strconv" + "github.com/charmbracelet/soft-serve/server/access" "github.com/charmbracelet/soft-serve/server/backend" "github.com/spf13/cobra" ) @@ -21,13 +22,14 @@ func settingsCommand() *cobra.Command { Args: cobra.RangeArgs(0, 1), PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) switch len(args) { case 0: - cmd.Println(cfg.Backend.AllowKeyless()) + cmd.Println(be.AllowKeyless(ctx)) case 1: v, _ := strconv.ParseBool(args[0]) - if err := cfg.Backend.SetAllowKeyless(v); err != nil { + if err := be.SetAllowKeyless(ctx, v); err != nil { return err } } @@ -37,7 +39,7 @@ func settingsCommand() *cobra.Command { }, ) - als := []string{backend.NoAccess.String(), backend.ReadOnlyAccess.String(), backend.ReadWriteAccess.String(), backend.AdminAccess.String()} + als := []string{access.NoAccess.String(), access.ReadOnlyAccess.String(), access.ReadWriteAccess.String(), access.AdminAccess.String()} cmd.AddCommand( &cobra.Command{ Use: "anon-access [ACCESS_LEVEL]", @@ -46,16 +48,17 @@ func settingsCommand() *cobra.Command { ValidArgs: als, PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) switch len(args) { case 0: - cmd.Println(cfg.Backend.AnonAccess()) + cmd.Println(be.AnonAccess(ctx)) case 1: - al := backend.ParseAccessLevel(args[0]) + al := access.ParseAccessLevel(args[0]) if al < 0 { return fmt.Errorf("invalid access level: %s. Please choose one of the following: %s", args[0], als) } - if err := cfg.Backend.SetAnonAccess(al); err != nil { + if err := be.SetAnonAccess(ctx, al); err != nil { return err } } diff --git a/server/cmd/tag.go b/server/ssh/cmd/tag.go similarity index 84% rename from server/cmd/tag.go rename to server/ssh/cmd/tag.go index 84e6a907dc30b01b462c7622b87a80b0646ad37b..6b72087627c14411f021ed38bb5003f8032a4412 100644 --- a/server/cmd/tag.go +++ b/server/ssh/cmd/tag.go @@ -3,6 +3,7 @@ package cmd import ( "strings" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/spf13/cobra" ) @@ -28,9 +29,10 @@ func tagListCommand() *cobra.Command { Args: cobra.ExactArgs(1), PersistentPreRunE: checkIfReadable, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := strings.TrimSuffix(args[0], ".git") - rr, err := cfg.Backend.Repository(rn) + rr, err := be.Repository(ctx, rn) if err != nil { return err } @@ -60,9 +62,10 @@ func tagDeleteCommand() *cobra.Command { Args: cobra.ExactArgs(2), PersistentPreRunE: checkIfCollab, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := strings.TrimSuffix(args[0], ".git") - rr, err := cfg.Backend.Repository(rn) + rr, err := be.Repository(ctx, rn) if err != nil { return err } diff --git a/server/cmd/tree.go b/server/ssh/cmd/tree.go similarity index 88% rename from server/cmd/tree.go rename to server/ssh/cmd/tree.go index 19ea3720d4d99e10887944b1eff12129b516e927..ccb72a70f890a6bd5800bbb1379b831839ef3ee1 100644 --- a/server/cmd/tree.go +++ b/server/ssh/cmd/tree.go @@ -4,7 +4,8 @@ import ( "fmt" "github.com/charmbracelet/soft-serve/git" - "github.com/charmbracelet/soft-serve/server/errors" + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/dustin/go-humanize" "github.com/spf13/cobra" ) @@ -17,7 +18,8 @@ func treeCommand() *cobra.Command { Args: cobra.RangeArgs(1, 3), PersistentPreRunE: checkIfReadable, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) rn := args[0] path := "" ref := "" @@ -28,7 +30,7 @@ func treeCommand() *cobra.Command { ref = args[1] path = args[2] } - rr, err := cfg.Backend.Repository(rn) + rr, err := be.Repository(ctx, rn) if err != nil { return err } @@ -59,7 +61,7 @@ func treeCommand() *cobra.Command { if path != "" && path != "/" { te, err := tree.TreeEntry(path) if err == git.ErrRevisionNotExist { - return errors.ErrFileNotFound + return proto.ErrFileNotFound } if err != nil { return err diff --git a/server/cmd/user.go b/server/ssh/cmd/user.go similarity index 76% rename from server/cmd/user.go rename to server/ssh/cmd/user.go index 6518b52d51f3e07ef7a133661cc641f0df007ed5..639bf909adb773db202f7b95fae0aa8629b14425 100644 --- a/server/cmd/user.go +++ b/server/ssh/cmd/user.go @@ -4,8 +4,9 @@ import ( "sort" "strings" - "github.com/charmbracelet/log" "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/sshutils" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" ) @@ -26,10 +27,11 @@ func userCommand() *cobra.Command { PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { var pubkeys []ssh.PublicKey - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) username := args[0] if key != "" { - pk, _, err := backend.ParseAuthorizedKey(key) + pk, _, err := sshutils.ParseAuthorizedKey(key) if err != nil { return err } @@ -37,12 +39,12 @@ func userCommand() *cobra.Command { pubkeys = []ssh.PublicKey{pk} } - opts := backend.UserOptions{ + opts := proto.UserOptions{ Admin: admin, PublicKeys: pubkeys, } - _, err := cfg.Backend.CreateUser(username, opts) + _, err := be.CreateUser(ctx, username, opts) return err }, } @@ -56,10 +58,11 @@ func userCommand() *cobra.Command { Args: cobra.ExactArgs(1), PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) username := args[0] - return cfg.Backend.DeleteUser(username) + return be.DeleteUser(ctx, username) }, } @@ -70,8 +73,9 @@ func userCommand() *cobra.Command { Args: cobra.NoArgs, PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, _ []string) error { - cfg, _ := fromContext(cmd) - users, err := cfg.Backend.Users() + ctx := cmd.Context() + be := backend.FromContext(ctx) + users, err := be.Users(ctx) if err != nil { return err } @@ -91,15 +95,16 @@ func userCommand() *cobra.Command { Args: cobra.MinimumNArgs(2), PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) username := args[0] pubkey := strings.Join(args[1:], " ") - pk, _, err := backend.ParseAuthorizedKey(pubkey) + pk, _, err := sshutils.ParseAuthorizedKey(pubkey) if err != nil { return err } - return cfg.Backend.AddPublicKey(username, pk) + return be.AddPublicKey(ctx, username, pk) }, } @@ -109,16 +114,16 @@ func userCommand() *cobra.Command { Args: cobra.MinimumNArgs(2), PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) username := args[0] pubkey := strings.Join(args[1:], " ") - log.Debugf("key is %q", pubkey) - pk, _, err := backend.ParseAuthorizedKey(pubkey) + pk, _, err := sshutils.ParseAuthorizedKey(pubkey) if err != nil { return err } - return cfg.Backend.RemovePublicKey(username, pk) + return be.RemovePublicKey(ctx, username, pk) }, } @@ -128,10 +133,11 @@ func userCommand() *cobra.Command { Args: cobra.ExactArgs(2), PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) username := args[0] - return cfg.Backend.SetAdmin(username, args[1] == "true") + return be.SetAdmin(ctx, username, args[1] == "true") }, } @@ -141,10 +147,11 @@ func userCommand() *cobra.Command { Args: cobra.ExactArgs(1), PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) username := args[0] - user, err := cfg.Backend.User(username) + user, err := be.User(ctx, username) if err != nil { return err } @@ -155,7 +162,7 @@ func userCommand() *cobra.Command { cmd.Printf("Admin: %t\n", isAdmin) cmd.Printf("Public keys:\n") for _, pk := range user.PublicKeys() { - cmd.Printf(" %s\n", backend.MarshalAuthorizedKey(pk)) + cmd.Printf(" %s\n", sshutils.MarshalAuthorizedKey(pk)) } return nil @@ -168,11 +175,12 @@ func userCommand() *cobra.Command { Args: cobra.ExactArgs(2), PersistentPreRunE: checkIfAdmin, RunE: func(cmd *cobra.Command, args []string) error { - cfg, _ := fromContext(cmd) + ctx := cmd.Context() + be := backend.FromContext(ctx) username := args[0] newUsername := args[1] - return cfg.Backend.SetUsername(username, newUsername) + return be.SetUsername(ctx, username, newUsername) }, } diff --git a/server/ssh/git.go b/server/ssh/git.go new file mode 100644 index 0000000000000000000000000000000000000000..051dc25921ad137d89dacdc44501ab8e70637369 --- /dev/null +++ b/server/ssh/git.go @@ -0,0 +1,124 @@ +package ssh + +import ( + "errors" + "path/filepath" + "time" + + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/access" + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/git" + "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/sshutils" + "github.com/charmbracelet/soft-serve/server/utils" + "github.com/charmbracelet/ssh" +) + +func handleGit(s ssh.Session) { + ctx := s.Context() + cfg := config.FromContext(ctx) + be := backend.FromContext(ctx) + logger := log.FromContext(ctx) + cmdLine := s.Command() + start := time.Now() + + // repo should be in the form of "repo.git" + name := utils.SanitizeRepo(cmdLine[1]) + pk := s.PublicKey() + ak := sshutils.MarshalAuthorizedKey(pk) + accessLevel := be.AccessLevelByPublicKey(ctx, name, pk) + // git bare repositories should end in ".git" + // https://git-scm.com/docs/gitrepository-layout + repo := name + ".git" + reposDir := filepath.Join(cfg.DataPath, "repos") + if err := git.EnsureWithin(reposDir, repo); err != nil { + sshFatal(s, err) + return + } + + // Environment variables to pass down to git hooks. + envs := []string{ + "SOFT_SERVE_REPO_NAME=" + name, + "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo), + "SOFT_SERVE_PUBLIC_KEY=" + ak, + "SOFT_SERVE_USERNAME=" + s.User(), + "SOFT_SERVE_LOG_PATH=" + filepath.Join(cfg.DataPath, "log", "hooks.log"), + } + + // Add ssh session & config environ + envs = append(envs, s.Environ()...) + envs = append(envs, cfg.Environ()...) + + repoDir := filepath.Join(reposDir, repo) + service := git.Service(cmdLine[0]) + cmd := git.ServiceCommand{ + Stdin: s, + Stdout: s, + Stderr: s.Stderr(), + Env: envs, + Dir: repoDir, + } + + logger.Debug("git middleware", "cmd", service, "access", accessLevel.String()) + + switch service { + case git.ReceivePackService: + receivePackCounter.WithLabelValues(name).Inc() + defer func() { + receivePackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds()) + }() + if accessLevel < access.ReadWriteAccess { + sshFatal(s, git.ErrNotAuthed) + return + } + if _, err := be.Repository(ctx, name); err != nil { + if _, err := be.CreateRepository(ctx, name, proto.RepositoryOptions{Private: false}); err != nil { + log.Errorf("failed to create repo: %s", err) + sshFatal(s, err) + return + } + createRepoCounter.WithLabelValues(name).Inc() + } + + if err := git.ReceivePack(ctx, cmd); err != nil { + sshFatal(s, git.ErrSystemMalfunction) + } + + if err := git.EnsureDefaultBranch(ctx, cmd); err != nil { + sshFatal(s, git.ErrSystemMalfunction) + } + + receivePackCounter.WithLabelValues(name).Inc() + return + case git.UploadPackService, git.UploadArchiveService: + if accessLevel < access.ReadOnlyAccess { + sshFatal(s, git.ErrNotAuthed) + return + } + + handler := git.UploadPack + switch service { + case git.UploadArchiveService: + handler = git.UploadArchive + uploadArchiveCounter.WithLabelValues(name).Inc() + defer func() { + uploadArchiveSeconds.WithLabelValues(name).Add(time.Since(start).Seconds()) + }() + default: + uploadPackCounter.WithLabelValues(name).Inc() + defer func() { + uploadPackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds()) + }() + } + + err := handler(ctx, cmd) + if errors.Is(err, git.ErrInvalidRepo) { + sshFatal(s, git.ErrInvalidRepo) + } else if err != nil { + logger.Error("git middleware", "err", err) + sshFatal(s, git.ErrSystemMalfunction) + } + } +} diff --git a/server/ssh/logger.go b/server/ssh/logger.go new file mode 100644 index 0000000000000000000000000000000000000000..c5b972f0b9b6237e56cfbddcc2e6fa897f59c13c --- /dev/null +++ b/server/ssh/logger.go @@ -0,0 +1,25 @@ +package ssh + +import "github.com/charmbracelet/log" + +type loggerAdapter struct { + *log.Logger + log.Level +} + +func (l *loggerAdapter) Printf(format string, args ...interface{}) { + switch l.Level { + case log.DebugLevel: + l.Logger.Debugf(format, args...) + case log.InfoLevel: + l.Logger.Infof(format, args...) + case log.WarnLevel: + l.Logger.Warnf(format, args...) + case log.ErrorLevel: + l.Logger.Errorf(format, args...) + case log.FatalLevel: + l.Logger.Fatalf(format, args...) + default: + l.Logger.Printf(format, args...) + } +} diff --git a/server/ssh/middleware.go b/server/ssh/middleware.go new file mode 100644 index 0000000000000000000000000000000000000000..7fe749185220ab637efd089149ccd7fa78170942 --- /dev/null +++ b/server/ssh/middleware.go @@ -0,0 +1,44 @@ +package ssh + +import ( + "strings" + + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/ssh" +) + +// ContextMiddleware adds the config, backend, and logger to the session context. +func ContextMiddleware(cfg *config.Config, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler { + return func(sh ssh.Handler) ssh.Handler { + return func(s ssh.Session) { + s.Context().SetValue(config.ContextKey, cfg) + s.Context().SetValue(backend.ContextKey, be) + s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh")) + sh(s) + } + } +} + +// CommandMiddleware handles git commands and CLI commands. +// This middleware must be run after the ContextMiddleware. +func CommandMiddleware(sh ssh.Handler) ssh.Handler { + return func(s ssh.Session) { + func() { + cmdLine := s.Command() + _, _, ptyReq := s.Pty() + if ptyReq { + return + } + + switch { + case len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git-"): + handleGit(s) + default: + handleCli(s) + } + }() + sh(s) + } +} diff --git a/server/ssh/session.go b/server/ssh/session.go index 26ee3a5032b881c9f8f2670bd3daf7accc850bb1..a5bd4d168d7e71f0d6d4038ce173108cc0967c8f 100644 --- a/server/ssh/session.go +++ b/server/ssh/session.go @@ -5,16 +5,14 @@ import ( "time" tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/log" - . "github.com/charmbracelet/soft-serve/internal/log" + "github.com/charmbracelet/soft-serve/server/access" "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" - "github.com/charmbracelet/soft-serve/server/errors" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/ui" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/charmbracelet/ssh" "github.com/charmbracelet/wish" - bm "github.com/charmbracelet/wish/bubbletea" "github.com/muesli/termenv" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -35,50 +33,50 @@ var tuiSessionDuration = promauto.NewCounterVec(prometheus.CounterOpts{ }, []string{"repo", "term"}) // SessionHandler is the soft-serve bubbletea ssh session handler. -func SessionHandler(cfg *config.Config) bm.ProgramHandler { - return func(s ssh.Session) *tea.Program { - pty, _, active := s.Pty() - if !active { - return nil - } +// This middleware must be run after the ContextMiddleware. +func SessionHandler(s ssh.Session) *tea.Program { + pty, _, active := s.Pty() + if !active { + return nil + } - cmd := s.Command() - initialRepo := "" - if len(cmd) == 1 { - initialRepo = cmd[0] - auth := cfg.Backend.AccessLevelByPublicKey(initialRepo, s.PublicKey()) - if auth < backend.ReadOnlyAccess { - wish.Fatalln(s, errors.ErrUnauthorized) - return nil - } + ctx := s.Context() + be := backend.FromContext(ctx) + cfg := config.FromContext(ctx) + cmd := s.Command() + initialRepo := "" + if len(cmd) == 1 { + initialRepo = cmd[0] + auth := be.AccessLevelByPublicKey(ctx, initialRepo, s.PublicKey()) + if auth < access.ReadOnlyAccess { + wish.Fatalln(s, proto.ErrUnauthorized) + return nil } + } - envs := &sessionEnv{s} - output := termenv.NewOutput(s, termenv.WithColorCache(true), termenv.WithEnvironment(envs)) - logger := NewDefaultLogger() - ctx := log.WithContext(s.Context(), logger) - c := common.NewCommon(ctx, output, pty.Window.Width, pty.Window.Height) - c.SetValue(common.ConfigKey, cfg) - m := ui.New(c, initialRepo) - p := tea.NewProgram(m, - tea.WithInput(s), - tea.WithOutput(s), - tea.WithAltScreen(), - tea.WithoutCatchPanics(), - tea.WithMouseCellMotion(), - tea.WithContext(ctx), - ) + envs := &sessionEnv{s} + output := termenv.NewOutput(s, termenv.WithColorCache(true), termenv.WithEnvironment(envs)) + c := common.NewCommon(ctx, output, pty.Window.Width, pty.Window.Height) + c.SetValue(common.ConfigKey, cfg) + m := ui.New(c, initialRepo) + p := tea.NewProgram(m, + tea.WithInput(s), + tea.WithOutput(s), + tea.WithAltScreen(), + tea.WithoutCatchPanics(), + tea.WithMouseCellMotion(), + tea.WithContext(ctx), + ) - tuiSessionCounter.WithLabelValues(initialRepo, pty.Term).Inc() + tuiSessionCounter.WithLabelValues(initialRepo, pty.Term).Inc() - start := time.Now() - go func() { - <-ctx.Done() - tuiSessionDuration.WithLabelValues(initialRepo, pty.Term).Add(time.Since(start).Seconds()) - }() + start := time.Now() + go func() { + <-ctx.Done() + tuiSessionDuration.WithLabelValues(initialRepo, pty.Term).Add(time.Since(start).Seconds()) + }() - return p - } + return p } var _ termenv.Environ = &sessionEnv{} diff --git a/server/ssh/session_test.go b/server/ssh/session_test.go index 995ab20fd122f060479562dba0139f83fe871902..104c92ad14095a4b2d851c40e959c05137e2d5f8 100644 --- a/server/ssh/session_test.go +++ b/server/ssh/session_test.go @@ -4,13 +4,15 @@ import ( "context" "errors" "fmt" - "log" "os" "testing" "time" - "github.com/charmbracelet/soft-serve/server/backend/sqlite" + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/migrate" "github.com/charmbracelet/soft-serve/server/test" "github.com/charmbracelet/ssh" bm "github.com/charmbracelet/wish/bubbletea" @@ -18,6 +20,7 @@ import ( "github.com/matryer/is" "github.com/muesli/termenv" gossh "golang.org/x/crypto/ssh" + _ "modernc.org/sqlite" // sqlite driver ) func TestSession(t *testing.T) { @@ -31,16 +34,15 @@ func TestSession(t *testing.T) { is.NoErr(err) go func() { time.Sleep(1 * time.Second) - s.Signal(gossh.SIGTERM) - // FIXME: exit with code 0 instead of forcibly closing the session - s.Close() + // s.Signal(gossh.SIGTERM) + s.Close() // nolint: errcheck }() t.Log("waiting for session to exit") _, err = s.Output("test") var ee *gossh.ExitMissingError is.True(errors.As(err, &ee)) t.Log("session exited") - _ = close() + is.NoErr(close()) }) } @@ -59,19 +61,26 @@ func setup(tb testing.TB) (*gossh.Session, func() error) { }) ctx := context.TODO() cfg := config.DefaultConfig() + if err := cfg.Validate(); err != nil { + log.Fatal(err) + } ctx = config.WithContext(ctx, cfg) - fb, err := sqlite.NewSqliteBackend(ctx) + db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource) if err != nil { - log.Fatal(err) + tb.Fatal(err) + } + if err := migrate.Migrate(ctx, db); err != nil { + tb.Fatal(err) } - cfg = cfg.WithBackend(fb) + be := backend.New(ctx, cfg, db) + ctx = backend.WithContext(ctx, be) return testsession.New(tb, &ssh.Server{ - Handler: bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256)(func(s ssh.Session) { + Handler: ContextMiddleware(cfg, be, log.Default())(bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256)(func(s ssh.Session) { _, _, active := s.Pty() if !active { os.Exit(1) } s.Exit(0) - }), - }, nil), fb.Close + })), + }, nil), db.Close } diff --git a/server/ssh/ssh.go b/server/ssh/ssh.go index 8e98d0ebc0938776d031615daf8090ff6244f9cc..39f300e58fff5352243fefdb8b85c434582e0d96 100644 --- a/server/ssh/ssh.go +++ b/server/ssh/ssh.go @@ -2,22 +2,19 @@ package ssh import ( "context" - "errors" "fmt" "net" "os" - "path/filepath" "strconv" - "strings" "time" "github.com/charmbracelet/keygen" "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/access" "github.com/charmbracelet/soft-serve/server/backend" - cm "github.com/charmbracelet/soft-serve/server/cmd" "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/git" - "github.com/charmbracelet/soft-serve/server/utils" + "github.com/charmbracelet/soft-serve/server/sshutils" "github.com/charmbracelet/ssh" "github.com/charmbracelet/wish" bm "github.com/charmbracelet/wish/bubbletea" @@ -95,10 +92,10 @@ var ( ) // SSHServer is a SSH server that implements the git protocol. -type SSHServer struct { +type SSHServer struct { // nolint: revive srv *ssh.Server cfg *config.Config - be backend.Backend + be *backend.Backend ctx context.Context logger *log.Logger } @@ -107,12 +104,13 @@ type SSHServer struct { func NewSSHServer(ctx context.Context) (*SSHServer, error) { cfg := config.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("ssh") + be := backend.FromContext(ctx) var err error s := &SSHServer{ cfg: cfg, ctx: ctx, - be: backend.FromContext(ctx), + be: be, logger: logger, } @@ -120,14 +118,15 @@ func NewSSHServer(ctx context.Context) (*SSHServer, error) { rm.MiddlewareWithLogger( logger, // BubbleTea middleware. - bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256), + bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256), // CLI middleware. - cm.Middleware(cfg, logger), - // Git middleware. - s.Middleware(cfg), + CommandMiddleware, + // Context middleware. + ContextMiddleware(cfg, be, logger), // Logging middleware. - lm.MiddlewareWithLogger(logger. - StandardLog(log.StandardLogOptions{ForceLevel: log.DebugLevel})), + lm.MiddlewareWithLogger( + &loggerAdapter{logger, log.DebugLevel}, + ), ), } @@ -187,145 +186,27 @@ func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed return false } - ak := backend.MarshalAuthorizedKey(pk) + ak := sshutils.MarshalAuthorizedKey(pk) defer func(allowed *bool) { publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc() }(&allowed) - ac := s.cfg.Backend.AccessLevelByPublicKey("", pk) + ac := s.be.AccessLevelByPublicKey(ctx, "", pk) s.logger.Debugf("access level for %q: %s", ak, ac) - allowed = ac >= backend.ReadWriteAccess + allowed = ac >= access.ReadWriteAccess return } // KeyboardInteractiveHandler handles keyboard interactive authentication. // This is used after all public key authentication has failed. func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool { - ac := s.cfg.Backend.AllowKeyless() + ac := s.be.AllowKeyless(ctx) keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc() return ac } -// Middleware adds Git server functionality to the ssh.Server. Repos are stored -// in the specified repo directory. The provided Hooks implementation will be -// checked for access on a per repo basis for a ssh.Session public key. -// Hooks.Push and Hooks.Fetch will be called on successful completion of -// their commands. -func (ss *SSHServer) Middleware(cfg *config.Config) wish.Middleware { - return func(sh ssh.Handler) ssh.Handler { - return func(s ssh.Session) { - func() { - start := time.Now() - cmdLine := s.Command() - ctx := s.Context() - be := ss.be.WithContext(ctx) - - if len(cmdLine) >= 2 && strings.HasPrefix(cmdLine[0], "git") { - // repo should be in the form of "repo.git" - name := utils.SanitizeRepo(cmdLine[1]) - pk := s.PublicKey() - ak := backend.MarshalAuthorizedKey(pk) - access := cfg.Backend.AccessLevelByPublicKey(name, pk) - // git bare repositories should end in ".git" - // https://git-scm.com/docs/gitrepository-layout - repo := name + ".git" - reposDir := filepath.Join(cfg.DataPath, "repos") - if err := git.EnsureWithin(reposDir, repo); err != nil { - sshFatal(s, err) - return - } - - // Environment variables to pass down to git hooks. - envs := []string{ - "SOFT_SERVE_REPO_NAME=" + name, - "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo), - "SOFT_SERVE_PUBLIC_KEY=" + ak, - "SOFT_SERVE_USERNAME=" + ctx.User(), - } - - // Add ssh session & config environ - envs = append(envs, s.Environ()...) - envs = append(envs, cfg.Environ()...) - - repoDir := filepath.Join(reposDir, repo) - service := git.Service(cmdLine[0]) - cmd := git.ServiceCommand{ - Stdin: s, - Stdout: s, - Stderr: s.Stderr(), - Env: envs, - Dir: repoDir, - } - - ss.logger.Debug("git middleware", "cmd", service, "access", access.String()) - - switch service { - case git.ReceivePackService: - receivePackCounter.WithLabelValues(name).Inc() - defer func() { - receivePackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds()) - }() - if access < backend.ReadWriteAccess { - sshFatal(s, git.ErrNotAuthed) - return - } - if _, err := be.Repository(name); err != nil { - if _, err := be.CreateRepository(name, backend.RepositoryOptions{Private: false}); err != nil { - log.Errorf("failed to create repo: %s", err) - sshFatal(s, err) - return - } - createRepoCounter.WithLabelValues(name).Inc() - } - - if err := git.ReceivePack(ctx, cmd); err != nil { - sshFatal(s, git.ErrSystemMalfunction) - } - - if err := git.EnsureDefaultBranch(ctx, cmd); err != nil { - sshFatal(s, git.ErrSystemMalfunction) - } - - receivePackCounter.WithLabelValues(name).Inc() - return - case git.UploadPackService, git.UploadArchiveService: - if access < backend.ReadOnlyAccess { - sshFatal(s, git.ErrNotAuthed) - return - } - - handler := git.UploadPack - switch service { - case git.UploadArchiveService: - handler = git.UploadArchive - uploadArchiveCounter.WithLabelValues(name).Inc() - defer func() { - uploadArchiveSeconds.WithLabelValues(name).Add(time.Since(start).Seconds()) - }() - default: - uploadPackCounter.WithLabelValues(name).Inc() - defer func() { - uploadPackSeconds.WithLabelValues(name).Add(time.Since(start).Seconds()) - }() - } - - err := handler(ctx, cmd) - if errors.Is(err, git.ErrInvalidRepo) { - sshFatal(s, git.ErrInvalidRepo) - } else if err != nil { - sshFatal(s, git.ErrSystemMalfunction) - } - - } - } - }() - sh(s) - } - } -} - // sshFatal prints to the session's STDOUT as a git response and exit 1. -func sshFatal(s ssh.Session, v ...interface{}) { - git.WritePktline(s, v...) - s.Exit(1) // nolint: errcheck +func sshFatal(s ssh.Session, err error) { + git.WritePktlineErr(s, err) // nolint: errcheck + s.Exit(1) // nolint: errcheck } diff --git a/server/sshutils/utils.go b/server/sshutils/utils.go new file mode 100644 index 0000000000000000000000000000000000000000..6bba64578adff81c72bd9f8bb89eee2a3887b6ad --- /dev/null +++ b/server/sshutils/utils.go @@ -0,0 +1,40 @@ +package sshutils + +import ( + "bytes" + "context" + + "github.com/charmbracelet/ssh" + gossh "golang.org/x/crypto/ssh" +) + +// ParseAuthorizedKey parses an authorized key string into a public key. +func ParseAuthorizedKey(ak string) (gossh.PublicKey, string, error) { + pk, c, _, _, err := gossh.ParseAuthorizedKey([]byte(ak)) + return pk, c, err +} + +// MarshalAuthorizedKey marshals a public key into an authorized key string. +// +// This is the inverse of ParseAuthorizedKey. +// This function is a copy of ssh.MarshalAuthorizedKey, but without the trailing newline. +// It returns an empty string if pk is nil. +func MarshalAuthorizedKey(pk gossh.PublicKey) string { + if pk == nil { + return "" + } + return string(bytes.TrimSuffix(gossh.MarshalAuthorizedKey(pk), []byte("\n"))) +} + +// KeysEqual returns whether the two public keys are equal. +func KeysEqual(a, b gossh.PublicKey) bool { + return ssh.KeysEqual(a, b) +} + +// PublicKeyFromContext returns the public key from the context. +func PublicKeyFromContext(ctx context.Context) gossh.PublicKey { + if pk, ok := ctx.Value(ssh.ContextKeyPublicKey).(gossh.PublicKey); ok { + return pk + } + return nil +} diff --git a/server/sshutils/utils_test.go b/server/sshutils/utils_test.go new file mode 100644 index 0000000000000000000000000000000000000000..96300ed08fa80b32532e634a3ce28d4ad339cb07 --- /dev/null +++ b/server/sshutils/utils_test.go @@ -0,0 +1,124 @@ +package sshutils + +import ( + "testing" + + "github.com/charmbracelet/keygen" + "golang.org/x/crypto/ssh" +) + +func generateKeys(tb testing.TB) (*keygen.SSHKeyPair, *keygen.SSHKeyPair) { + goodKey1, err := keygen.New("", keygen.WithKeyType(keygen.Ed25519)) + if err != nil { + tb.Fatal(err) + } + goodKey2, err := keygen.New("", keygen.WithKeyType(keygen.RSA)) + if err != nil { + tb.Fatal(err) + } + + return goodKey1, goodKey2 +} + +func TestParseAuthorizedKey(t *testing.T) { + goodKey1, goodKey2 := generateKeys(t) + cases := []struct { + in string + good bool + }{ + { + goodKey1.AuthorizedKey(), + true, + }, + { + goodKey2.AuthorizedKey(), + true, + }, + { + goodKey1.AuthorizedKey() + "test", + false, + }, + { + goodKey2.AuthorizedKey() + "bad", + false, + }, + } + for _, c := range cases { + _, _, err := ParseAuthorizedKey(c.in) + if c.good && err != nil { + t.Errorf("ParseAuthorizedKey(%q) returned error: %v", c.in, err) + } + if !c.good && err == nil { + t.Errorf("ParseAuthorizedKey(%q) did not return error", c.in) + } + } +} + +func TestMarshalAuthorizedKey(t *testing.T) { + goodKey1, goodKey2 := generateKeys(t) + cases := []struct { + in ssh.PublicKey + expected string + }{ + { + goodKey1.PublicKey(), + goodKey1.AuthorizedKey(), + }, + { + goodKey2.PublicKey(), + goodKey2.AuthorizedKey(), + }, + { + nil, + "", + }, + } + for _, c := range cases { + out := MarshalAuthorizedKey(c.in) + if out != c.expected { + t.Errorf("MarshalAuthorizedKey(%v) returned %q, expected %q", c.in, out, c.expected) + } + } +} + +func TestKeysEqual(t *testing.T) { + goodKey1, goodKey2 := generateKeys(t) + cases := []struct { + in1 ssh.PublicKey + in2 ssh.PublicKey + expected bool + }{ + { + goodKey1.PublicKey(), + goodKey1.PublicKey(), + true, + }, + { + goodKey2.PublicKey(), + goodKey2.PublicKey(), + true, + }, + { + goodKey1.PublicKey(), + goodKey2.PublicKey(), + false, + }, + { + nil, + nil, + false, + }, + { + nil, + goodKey1.PublicKey(), + false, + }, + } + + for _, c := range cases { + out := KeysEqual(c.in1, c.in2) + if out != c.expected { + t.Errorf("KeysEqual(%v, %v) returned %v, expected %v", c.in1, c.in2, out, c.expected) + } + } +} diff --git a/server/stats/stats.go b/server/stats/stats.go index b515f1b52a816fd111b40c7842573cc592c0e8fa..109e5207ce0beaf6123cef2939d39768e55f36d2 100644 --- a/server/stats/stats.go +++ b/server/stats/stats.go @@ -10,7 +10,7 @@ import ( ) // StatsServer is a server for collecting and reporting statistics. -type StatsServer struct { +type StatsServer struct { //nolint:revive ctx context.Context cfg *config.Config server *http.Server diff --git a/server/store/database/collab.go b/server/store/database/collab.go new file mode 100644 index 0000000000000000000000000000000000000000..50424445e13ede0006e81d7e25aca255ff139ff9 --- /dev/null +++ b/server/store/database/collab.go @@ -0,0 +1,124 @@ +package database + +import ( + "context" + "strings" + + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/models" + "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/utils" +) + +type collabStore struct{} + +var _ store.CollaboratorStore = (*collabStore)(nil) + +// AddCollabByUsernameAndRepo implements store.CollaboratorStore. +func (*collabStore) AddCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + repo = utils.SanitizeRepo(repo) + + query := tx.Rebind(`INSERT INTO collabs (user_id, repo_id, updated_at) + VALUES ( + ( + SELECT id FROM users WHERE username = ? + ), + ( + SELECT id FROM repos WHERE name = ? + ), + CURRENT_TIMESTAMP + );`) + _, err := tx.ExecContext(ctx, query, username, repo) + return err +} + +// GetCollabByUsernameAndRepo implements store.CollaboratorStore. +func (*collabStore) GetCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) (models.Collab, error) { + var m models.Collab + + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return models.Collab{}, err + } + + repo = utils.SanitizeRepo(repo) + + err := tx.GetContext(ctx, &m, tx.Rebind(` + SELECT + collabs.* + FROM + collabs + INNER JOIN users ON users.id = collabs.user_id + INNER JOIN repos ON repos.id = collabs.repo_id + WHERE + users.username = ? AND repos.name = ? + `), username, repo) + + return m, err +} + +// ListCollabsByRepo implements store.CollaboratorStore. +func (*collabStore) ListCollabsByRepo(ctx context.Context, tx *db.Tx, repo string) ([]models.Collab, error) { + var m []models.Collab + + repo = utils.SanitizeRepo(repo) + query := tx.Rebind(` + SELECT + collabs.* + FROM + collabs + INNER JOIN repos ON repos.id = collabs.repo_id + WHERE + repos.name = ? + `) + + err := tx.SelectContext(ctx, &m, query, repo) + return m, err +} + +// ListCollabsByRepoAsUsers implements store.CollaboratorStore. +func (*collabStore) ListCollabsByRepoAsUsers(ctx context.Context, tx *db.Tx, repo string) ([]models.User, error) { + var m []models.User + + repo = utils.SanitizeRepo(repo) + query := tx.Rebind(` + SELECT + users.* + FROM + users + INNER JOIN repos ON repos.id = collabs.repo_id + INNER JOIN collabs ON collabs.user_id = users.id + WHERE + repos.name = ? + `) + + err := tx.SelectContext(ctx, &m, query, repo) + return m, err +} + +// RemoveCollabByUsernameAndRepo implements store.CollaboratorStore. +func (*collabStore) RemoveCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + repo = utils.SanitizeRepo(repo) + query := tx.Rebind(` + DELETE FROM + collabs + WHERE + user_id = ( + SELECT id FROM users WHERE username = ? + ) AND repo_id = ( + SELECT id FROM repos WHERE name = ? + ) + `) + _, err := tx.ExecContext(ctx, query, username, repo) + return err +} diff --git a/server/store/database/database.go b/server/store/database/database.go new file mode 100644 index 0000000000000000000000000000000000000000..4960f68537f5d18b6c2e56d636e5d2f634cd8bae --- /dev/null +++ b/server/store/database/database.go @@ -0,0 +1,42 @@ +package database + +import ( + "context" + + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/store" +) + +type datastore struct { + ctx context.Context + cfg *config.Config + db *db.DB + logger *log.Logger + + *settingsStore + *repoStore + *userStore + *collabStore +} + +// New returns a new store.Store database. +func New(ctx context.Context, db *db.DB) store.Store { + cfg := config.FromContext(ctx) + logger := log.FromContext(ctx).WithPrefix("store") + + s := &datastore{ + ctx: ctx, + cfg: cfg, + db: db, + logger: logger, + + settingsStore: &settingsStore{}, + repoStore: &repoStore{}, + userStore: &userStore{}, + collabStore: &collabStore{}, + } + + return s +} diff --git a/server/store/database/repo.go b/server/store/database/repo.go new file mode 100644 index 0000000000000000000000000000000000000000..76436b9bc4393cf4a146559a9efe43417add6feb --- /dev/null +++ b/server/store/database/repo.go @@ -0,0 +1,135 @@ +package database + +import ( + "context" + + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/models" + "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/utils" +) + +type repoStore struct{} + +var _ store.RepositoryStore = (*repoStore)(nil) + +// CreateRepo implements store.RepositoryStore. +func (*repoStore) CreateRepo(ctx context.Context, tx *db.Tx, name string, projectName string, description string, isPrivate bool, isHidden bool, isMirror bool) error { + name = utils.SanitizeRepo(name) + query := tx.Rebind(`INSERT INTO repos (name, project_name, description, private, mirror, hidden, updated_at) + VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP);`) + _, err := tx.ExecContext(ctx, query, + name, projectName, description, isPrivate, isMirror, isHidden) + return db.WrapError(err) +} + +// DeleteRepoByName implements store.RepositoryStore. +func (*repoStore) DeleteRepoByName(ctx context.Context, tx *db.Tx, name string) error { + name = utils.SanitizeRepo(name) + query := tx.Rebind("DELETE FROM repos WHERE name = ?;") + _, err := tx.ExecContext(ctx, query, name) + return db.WrapError(err) +} + +// GetAllRepos implements store.RepositoryStore. +func (*repoStore) GetAllRepos(ctx context.Context, tx *db.Tx) ([]models.Repo, error) { + var repos []models.Repo + query := tx.Rebind("SELECT * FROM repos;") + err := tx.SelectContext(ctx, &repos, query) + return repos, db.WrapError(err) +} + +// GetRepoByName implements store.RepositoryStore. +func (*repoStore) GetRepoByName(ctx context.Context, tx *db.Tx, name string) (models.Repo, error) { + var repo models.Repo + name = utils.SanitizeRepo(name) + query := tx.Rebind("SELECT * FROM repos WHERE name = ?;") + err := tx.GetContext(ctx, &repo, query, name) + return repo, db.WrapError(err) +} + +// GetRepoDescriptionByName implements store.RepositoryStore. +func (*repoStore) GetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string) (string, error) { + var description string + name = utils.SanitizeRepo(name) + query := tx.Rebind("SELECT description FROM repos WHERE name = ?;") + err := tx.GetContext(ctx, &description, query, name) + return description, db.WrapError(err) +} + +// GetRepoIsHiddenByName implements store.RepositoryStore. +func (*repoStore) GetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string) (bool, error) { + var isHidden bool + name = utils.SanitizeRepo(name) + query := tx.Rebind("SELECT hidden FROM repos WHERE name = ?;") + err := tx.GetContext(ctx, &isHidden, query, name) + return isHidden, db.WrapError(err) +} + +// GetRepoIsMirrorByName implements store.RepositoryStore. +func (*repoStore) GetRepoIsMirrorByName(ctx context.Context, tx *db.Tx, name string) (bool, error) { + var isMirror bool + name = utils.SanitizeRepo(name) + query := tx.Rebind("SELECT mirror FROM repos WHERE name = ?;") + err := tx.GetContext(ctx, &isMirror, query, name) + return isMirror, db.WrapError(err) +} + +// GetRepoIsPrivateByName implements store.RepositoryStore. +func (*repoStore) GetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string) (bool, error) { + var isPrivate bool + name = utils.SanitizeRepo(name) + query := tx.Rebind("SELECT private FROM repos WHERE name = ?;") + err := tx.GetContext(ctx, &isPrivate, query, name) + return isPrivate, db.WrapError(err) +} + +// GetRepoProjectNameByName implements store.RepositoryStore. +func (*repoStore) GetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string) (string, error) { + var pname string + name = utils.SanitizeRepo(name) + query := tx.Rebind("SELECT project_name FROM repos WHERE name = ?;") + err := tx.GetContext(ctx, &pname, query, name) + return pname, db.WrapError(err) +} + +// SetRepoDescriptionByName implements store.RepositoryStore. +func (*repoStore) SetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string, description string) error { + name = utils.SanitizeRepo(name) + query := tx.Rebind("UPDATE repos SET description = ? WHERE name = ?;") + _, err := tx.ExecContext(ctx, query, description, name) + return db.WrapError(err) +} + +// SetRepoIsHiddenByName implements store.RepositoryStore. +func (*repoStore) SetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string, isHidden bool) error { + name = utils.SanitizeRepo(name) + query := tx.Rebind("UPDATE repos SET hidden = ? WHERE name = ?;") + _, err := tx.ExecContext(ctx, query, isHidden, name) + return db.WrapError(err) +} + +// SetRepoIsPrivateByName implements store.RepositoryStore. +func (*repoStore) SetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string, isPrivate bool) error { + name = utils.SanitizeRepo(name) + query := tx.Rebind("UPDATE repos SET private = ? WHERE name = ?;") + _, err := tx.ExecContext(ctx, query, isPrivate, name) + return db.WrapError(err) +} + +// SetRepoNameByName implements store.RepositoryStore. +func (*repoStore) SetRepoNameByName(ctx context.Context, tx *db.Tx, name string, newName string) error { + name = utils.SanitizeRepo(name) + newName = utils.SanitizeRepo(newName) + query := tx.Rebind("UPDATE repos SET name = ? WHERE name = ?;") + _, err := tx.ExecContext(ctx, query, newName, name) + return db.WrapError(err) +} + +// SetRepoProjectNameByName implements store.RepositoryStore. +func (*repoStore) SetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string, projectName string) error { + name = utils.SanitizeRepo(name) + query := tx.Rebind("UPDATE repos SET project_name = ? WHERE name = ?;") + _, err := tx.ExecContext(ctx, query, projectName, name) + return db.WrapError(err) +} diff --git a/server/store/database/settings.go b/server/store/database/settings.go new file mode 100644 index 0000000000000000000000000000000000000000..bb653a7ff188f2812e4003c846a3f16b5bab7d22 --- /dev/null +++ b/server/store/database/settings.go @@ -0,0 +1,47 @@ +package database + +import ( + "context" + + "github.com/charmbracelet/soft-serve/server/access" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/store" +) + +type settingsStore struct{} + +var _ store.SettingStore = (*settingsStore)(nil) + +// GetAllowKeylessAccess implements store.SettingStore. +func (*settingsStore) GetAllowKeylessAccess(ctx context.Context, tx *db.Tx) (bool, error) { + var allow bool + query := tx.Rebind(`SELECT value FROM settings WHERE key = "allow_keyless"`) + if err := tx.GetContext(ctx, &allow, query); err != nil { + return false, db.WrapError(err) + } + return allow, nil +} + +// GetAnonAccess implements store.SettingStore. +func (*settingsStore) GetAnonAccess(ctx context.Context, tx *db.Tx) (access.AccessLevel, error) { + var level string + query := tx.Rebind(`SELECT value FROM settings WHERE key = "anon_access"`) + if err := tx.GetContext(ctx, &level, query); err != nil { + return access.NoAccess, db.WrapError(err) + } + return access.ParseAccessLevel(level), nil +} + +// SetAllowKeylessAccess implements store.SettingStore. +func (*settingsStore) SetAllowKeylessAccess(ctx context.Context, tx *db.Tx, allow bool) error { + query := tx.Rebind(`UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = "allow_keyless"`) + _, err := tx.ExecContext(ctx, query, allow) + return db.WrapError(err) +} + +// SetAnonAccess implements store.SettingStore. +func (*settingsStore) SetAnonAccess(ctx context.Context, tx *db.Tx, level access.AccessLevel) error { + query := tx.Rebind(`UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = "anon_access"`) + _, err := tx.ExecContext(ctx, query, level.String()) + return db.WrapError(err) +} diff --git a/server/store/database/user.go b/server/store/database/user.go new file mode 100644 index 0000000000000000000000000000000000000000..2e3a70beb45b9fb3625f98bbcf0d78604e190863 --- /dev/null +++ b/server/store/database/user.go @@ -0,0 +1,208 @@ +package database + +import ( + "context" + "strings" + + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/models" + "github.com/charmbracelet/soft-serve/server/sshutils" + "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/utils" + "golang.org/x/crypto/ssh" +) + +type userStore struct{} + +var _ store.UserStore = (*userStore)(nil) + +// AddPublicKeyByUsername implements store.UserStore. +func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + var userID int64 + if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil { + return err + } + + query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at) + VALUES (?, ?, CURRENT_TIMESTAMP);`) + ak := sshutils.MarshalAuthorizedKey(pk) + _, err := tx.ExecContext(ctx, query, userID, ak) + + return err +} + +// CreateUser implements store.UserStore. +func (*userStore) CreateUser(ctx context.Context, tx *db.Tx, username string, isAdmin bool, pks []ssh.PublicKey) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + query := tx.Rebind(`INSERT INTO users (username, admin, updated_at) + VALUES (?, ?, CURRENT_TIMESTAMP);`) + result, err := tx.ExecContext(ctx, query, username, isAdmin) + if err != nil { + return err + } + + userID, err := result.LastInsertId() + if err != nil { + return err + } + + for _, pk := range pks { + query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at) + VALUES (?, ?, CURRENT_TIMESTAMP);`) + ak := sshutils.MarshalAuthorizedKey(pk) + _, err := tx.ExecContext(ctx, query, userID, ak) + if err != nil { + return err + } + } + + return nil +} + +// DeleteUserByUsername implements store.UserStore. +func (*userStore) DeleteUserByUsername(ctx context.Context, tx *db.Tx, username string) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + query := tx.Rebind(`DELETE FROM users WHERE username = ?;`) + _, err := tx.ExecContext(ctx, query, username) + return err +} + +// FindUserByPublicKey implements store.UserStore. +func (*userStore) FindUserByPublicKey(ctx context.Context, tx *db.Tx, pk ssh.PublicKey) (models.User, error) { + var m models.User + query := tx.Rebind(`SELECT users.* + FROM users + INNER JOIN public_keys ON users.id = public_keys.user_id + WHERE public_keys.public_key = ?;`) + err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk)) + return m, err +} + +// FindUserByUsername implements store.UserStore. +func (*userStore) FindUserByUsername(ctx context.Context, tx *db.Tx, username string) (models.User, error) { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return models.User{}, err + } + + var m models.User + query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`) + err := tx.GetContext(ctx, &m, query, username) + return m, err +} + +// GetAllUsers implements store.UserStore. +func (*userStore) GetAllUsers(ctx context.Context, tx *db.Tx) ([]models.User, error) { + var ms []models.User + query := tx.Rebind(`SELECT * FROM users;`) + err := tx.SelectContext(ctx, &ms, query) + return ms, err +} + +// ListPublicKeysByUserID implements store.UserStore.. +func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx *db.Tx, id int64) ([]ssh.PublicKey, error) { + var aks []string + query := tx.Rebind(`SELECT public_key FROM public_keys + WHERE user_id = ? + ORDER BY public_keys.id ASC;`) + err := tx.SelectContext(ctx, &aks, query, id) + if err != nil { + return nil, err + } + + pks := make([]ssh.PublicKey, len(aks)) + for i, ak := range aks { + pk, _, err := sshutils.ParseAuthorizedKey(ak) + if err != nil { + return nil, err + } + pks[i] = pk + } + + return pks, nil +} + +// ListPublicKeysByUsername implements store.UserStore. +func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx *db.Tx, username string) ([]ssh.PublicKey, error) { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return nil, err + } + + var aks []string + query := tx.Rebind(`SELECT public_key FROM public_keys + INNER JOIN users ON users.id = public_keys.user_id + WHERE users.username = ? + ORDER BY public_keys.id ASC;`) + err := tx.SelectContext(ctx, &aks, query, username) + if err != nil { + return nil, err + } + + pks := make([]ssh.PublicKey, len(aks)) + for i, ak := range aks { + pk, _, err := sshutils.ParseAuthorizedKey(ak) + if err != nil { + return nil, err + } + pks[i] = pk + } + + return pks, nil +} + +// RemovePublicKeyByUsername implements store.UserStore. +func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + query := tx.Rebind(`DELETE FROM public_keys + WHERE user_id = (SELECT id FROM users WHERE username = ?) + AND public_key = ?;`) + _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk)) + return err +} + +// SetAdminByUsername implements store.UserStore. +func (*userStore) SetAdminByUsername(ctx context.Context, tx *db.Tx, username string, isAdmin bool) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`) + _, err := tx.ExecContext(ctx, query, isAdmin, username) + return err +} + +// SetUsernameByUsername implements store.UserStore. +func (*userStore) SetUsernameByUsername(ctx context.Context, tx *db.Tx, username string, newUsername string) error { + username = strings.ToLower(username) + if err := utils.ValidateUsername(username); err != nil { + return err + } + + newUsername = strings.ToLower(newUsername) + if err := utils.ValidateUsername(newUsername); err != nil { + return err + } + + query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`) + _, err := tx.ExecContext(ctx, query, newUsername, username) + return err +} diff --git a/server/store/store.go b/server/store/store.go new file mode 100644 index 0000000000000000000000000000000000000000..d933dfb7d38ad643d472e72808a8aba438ab241b --- /dev/null +++ b/server/store/store.go @@ -0,0 +1,69 @@ +package store + +import ( + "context" + + "github.com/charmbracelet/soft-serve/server/access" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/models" + "golang.org/x/crypto/ssh" +) + +// SettingStore is an interface for managing settings. +type SettingStore interface { + GetAnonAccess(ctx context.Context, tx *db.Tx) (access.AccessLevel, error) + SetAnonAccess(ctx context.Context, tx *db.Tx, level access.AccessLevel) error + GetAllowKeylessAccess(ctx context.Context, tx *db.Tx) (bool, error) + SetAllowKeylessAccess(ctx context.Context, tx *db.Tx, allow bool) error +} + +// RepositoryStore is an interface for managing repositories. +type RepositoryStore interface { + GetRepoByName(ctx context.Context, tx *db.Tx, name string) (models.Repo, error) + GetAllRepos(ctx context.Context, tx *db.Tx) ([]models.Repo, error) + CreateRepo(ctx context.Context, tx *db.Tx, name string, projectName string, description string, isPrivate bool, isHidden bool, isMirror bool) error + DeleteRepoByName(ctx context.Context, tx *db.Tx, name string) error + SetRepoNameByName(ctx context.Context, tx *db.Tx, name string, newName string) error + + GetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string) (string, error) + SetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string, projectName string) error + GetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string) (string, error) + SetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string, description string) error + GetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string) (bool, error) + SetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string, isPrivate bool) error + GetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string) (bool, error) + SetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string, isHidden bool) error + GetRepoIsMirrorByName(ctx context.Context, tx *db.Tx, name string) (bool, error) +} + +// UserStore is an interface for managing users. +type UserStore interface { + FindUserByUsername(ctx context.Context, tx *db.Tx, username string) (models.User, error) + FindUserByPublicKey(ctx context.Context, tx *db.Tx, pk ssh.PublicKey) (models.User, error) + GetAllUsers(ctx context.Context, tx *db.Tx) ([]models.User, error) + CreateUser(ctx context.Context, tx *db.Tx, username string, isAdmin bool, pks []ssh.PublicKey) error + DeleteUserByUsername(ctx context.Context, tx *db.Tx, username string) error + SetUsernameByUsername(ctx context.Context, tx *db.Tx, username string, newUsername string) error + SetAdminByUsername(ctx context.Context, tx *db.Tx, username string, isAdmin bool) error + AddPublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error + RemovePublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error + ListPublicKeysByUserID(ctx context.Context, tx *db.Tx, id int64) ([]ssh.PublicKey, error) + ListPublicKeysByUsername(ctx context.Context, tx *db.Tx, username string) ([]ssh.PublicKey, error) +} + +// CollaboratorStore is an interface for managing collaborators. +type CollaboratorStore interface { + GetCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) (models.Collab, error) + AddCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error + RemoveCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error + ListCollabsByRepo(ctx context.Context, tx *db.Tx, repo string) ([]models.Collab, error) + ListCollabsByRepoAsUsers(ctx context.Context, tx *db.Tx, repo string) ([]models.User, error) +} + +// Store is an interface for managing repositories, users, and settings. +type Store interface { + RepositoryStore + UserStore + CollaboratorStore + SettingStore +} diff --git a/internal/sync/workqueue.go b/server/sync/workqueue.go similarity index 100% rename from internal/sync/workqueue.go rename to server/sync/workqueue.go diff --git a/server/sync/workqueue_test.go b/server/sync/workqueue_test.go new file mode 100644 index 0000000000000000000000000000000000000000..615e95dd7d05e2d10a4fc91d1dc7748b9841f698 --- /dev/null +++ b/server/sync/workqueue_test.go @@ -0,0 +1,35 @@ +package sync + +import ( + "context" + "strconv" + "sync" + "testing" +) + +func TestWorkPool(t *testing.T) { + mtx := &sync.Mutex{} + values := make([]int, 0) + wp := NewWorkPool(context.Background(), 3) + for i := 0; i < 10; i++ { + id := strconv.Itoa(i) + i := i + wp.Add(id, func() { + mtx.Lock() + values = append(values, i) + mtx.Unlock() + }) + } + wp.Run() + + if len(values) != 10 { + t.Errorf("expected 10 values, got %d, %v", len(values), values) + } + + for i := range values { + id := strconv.Itoa(i) + if wp.Status(id) { + t.Errorf("expected %s to be false", id) + } + } +} diff --git a/server/ui/common/common.go b/server/ui/common/common.go index d603906c9b57fdcb13343985e6b8ea974a821700..884b7d200519faf5081be350419094b844fabcdf 100644 --- a/server/ui/common/common.go +++ b/server/ui/common/common.go @@ -5,6 +5,7 @@ import ( "github.com/charmbracelet/log" "github.com/charmbracelet/soft-serve/git" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/ui/keymap" "github.com/charmbracelet/soft-serve/server/ui/styles" @@ -62,13 +63,19 @@ func (c *Common) SetSize(width, height int) { c.Height = height } +// Context returns the context. +func (c *Common) Context() context.Context { + return c.ctx +} + // Config returns the server config. func (c *Common) Config() *config.Config { - v := c.ctx.Value(ConfigKey) - if cfg, ok := v.(*config.Config); ok { - return cfg - } - return nil + return config.FromContext(c.ctx) +} + +// Backend returns the Soft Serve backend. +func (c *Common) Backend() *backend.Backend { + return backend.FromContext(c.ctx) } // Repo returns the repository. diff --git a/server/ui/common/utils.go b/server/ui/common/utils.go index 6ed9b714d445bc602bc6ab004bb15172e4555f78..aac2db6da3bd5f496f0c77409382323d4a0080fe 100644 --- a/server/ui/common/utils.go +++ b/server/ui/common/utils.go @@ -26,9 +26,8 @@ func RepoURL(publicURL, name string) string { port := url.Port() if port == "" || port == "22" { return fmt.Sprintf("git@%s:%s", url.Hostname(), name) - } else { - return fmt.Sprintf("ssh://%s:%s/%s", url.Hostname(), url.Port(), name) } + return fmt.Sprintf("ssh://%s:%s/%s", url.Hostname(), url.Port(), name) } } diff --git a/server/ui/components/code/code.go b/server/ui/components/code/code.go index 8291183349ee1059b6baa703f692014f0c55952c..aff1a526506d0e3eb37a40cf7164f9797f833d78 100644 --- a/server/ui/components/code/code.go +++ b/server/ui/components/code/code.go @@ -187,11 +187,9 @@ func (r *Code) renderFile(path, content string, width int) (string, error) { // width depends on the terminal. This is a workaround to replace tabs with // 4-spaces. content = strings.ReplaceAll(content, "\t", strings.Repeat(" ", tabWidth)) - lexer := lexers.Fallback + lexer := lexers.Match(path) if path == "" { lexer = lexers.Analyse(content) - } else { - lexer = lexers.Match(path) } lang := "" if lexer != nil && lexer.Config() != nil { diff --git a/server/ui/components/footer/footer.go b/server/ui/components/footer/footer.go index 822177594bf92f01d18d5adbd04c410bc9c0de8f..0022d0fe26ec058e94fa80b4d823f3b32697e463 100644 --- a/server/ui/components/footer/footer.go +++ b/server/ui/components/footer/footer.go @@ -47,7 +47,7 @@ func (f *Footer) Init() tea.Cmd { } // Update implements tea.Model. -func (f *Footer) Update(msg tea.Msg) (tea.Model, tea.Cmd) { +func (f *Footer) Update(_ tea.Msg) (tea.Model, tea.Cmd) { return f, nil } diff --git a/server/ui/components/header/header.go b/server/ui/components/header/header.go index 51c639f6d04f45fe9c2c5862e1c5b1f3d7f051ee..66870968a5a623606d755239bfb6c8a82688da99 100644 --- a/server/ui/components/header/header.go +++ b/server/ui/components/header/header.go @@ -32,7 +32,7 @@ func (h *Header) Init() tea.Cmd { } // Update implements tea.Model. -func (h *Header) Update(msg tea.Msg) (tea.Model, tea.Cmd) { +func (h *Header) Update(_ tea.Msg) (tea.Model, tea.Cmd) { return h, nil } diff --git a/server/ui/components/statusbar/statusbar.go b/server/ui/components/statusbar/statusbar.go index a1463a1e0635963995798391795c4c0ed67838fc..e5f0f70a5d7f35bebfebdc23d13003570e02aca8 100644 --- a/server/ui/components/statusbar/statusbar.go +++ b/server/ui/components/statusbar/statusbar.go @@ -8,7 +8,7 @@ import ( ) // StatusBarMsg is a message sent to the status bar. -type StatusBarMsg struct { +type StatusBarMsg struct { //nolint:revive Key string Value string Info string diff --git a/server/ui/pages/repo/files.go b/server/ui/pages/repo/files.go index e4692cb9054efb81c6f84cd5ff584a26d2d9a32e..bded6aa9632e5ef5d98ab1ef74fb49359fc776d3 100644 --- a/server/ui/pages/repo/files.go +++ b/server/ui/pages/repo/files.go @@ -10,7 +10,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/soft-serve/git" - "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/charmbracelet/soft-serve/server/ui/components/code" "github.com/charmbracelet/soft-serve/server/ui/components/selector" @@ -26,7 +26,6 @@ const ( var ( errNoFileSelected = errors.New("no file selected") errBinaryFile = errors.New("binary file") - errFileTooLarge = errors.New("file is too large") errInvalidFile = errors.New("invalid file") ) @@ -52,7 +51,7 @@ type Files struct { selector *selector.Selector ref *git.Reference activeView filesView - repo backend.Repository + repo proto.Repository code *code.Code path string currentItem *FileItem diff --git a/server/ui/pages/repo/log.go b/server/ui/pages/repo/log.go index 21bbcdf9fe46cf8fec7738168fcdbb50d2efd3b1..60659e5a670e93489ec6965b7a5f7bc06faf4a7d 100644 --- a/server/ui/pages/repo/log.go +++ b/server/ui/pages/repo/log.go @@ -11,7 +11,7 @@ import ( gansi "github.com/charmbracelet/glamour/ansi" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/soft-serve/git" - "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/charmbracelet/soft-serve/server/ui/components/footer" "github.com/charmbracelet/soft-serve/server/ui/components/selector" @@ -47,7 +47,7 @@ type Log struct { selector *selector.Selector vp *viewport.Viewport activeView logView - repo backend.Repository + repo proto.Repository ref *git.Reference count int64 nextPage int diff --git a/server/ui/pages/repo/logitem.go b/server/ui/pages/repo/logitem.go index 6330a109db4ee8a1e65f85b62a4f4e91035d2c4c..bd29ec8d0c18dbf9fbf648bd44053f6c23e072b3 100644 --- a/server/ui/pages/repo/logitem.go +++ b/server/ui/pages/repo/logitem.go @@ -25,6 +25,7 @@ func (i LogItem) ID() string { return i.Hash() } +// Hash returns the commit hash. func (i LogItem) Hash() string { return i.Commit.ID.String() } diff --git a/server/ui/pages/repo/readme.go b/server/ui/pages/repo/readme.go index a1623524fc7e3291069e9f117c25860267696949..6c6b610f7c85f290fc0f2657330ff55a1aa8880f 100644 --- a/server/ui/pages/repo/readme.go +++ b/server/ui/pages/repo/readme.go @@ -7,6 +7,7 @@ import ( "github.com/charmbracelet/bubbles/key" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/charmbracelet/soft-serve/server/ui/components/code" ) @@ -21,7 +22,7 @@ type Readme struct { common common.Common code *code.Code ref RefMsg - repo backend.Repository + repo proto.Repository readmePath string } diff --git a/server/ui/pages/repo/refs.go b/server/ui/pages/repo/refs.go index 8669e151c085ca9417545cba19f0a4f9fec379a7..2dac5b38192d5c08c03d4ccc77fcbd4800f6cb69 100644 --- a/server/ui/pages/repo/refs.go +++ b/server/ui/pages/repo/refs.go @@ -1,7 +1,6 @@ package repo import ( - "errors" "fmt" "sort" "strings" @@ -10,16 +9,12 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/soft-serve/git" ggit "github.com/charmbracelet/soft-serve/git" - "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/charmbracelet/soft-serve/server/ui/components/selector" "github.com/charmbracelet/soft-serve/server/ui/components/tabs" ) -var ( - errNoRef = errors.New("no reference specified") -) - // RefMsg is a message that contains a git.Reference. type RefMsg *ggit.Reference @@ -33,7 +28,7 @@ type RefItemsMsg struct { type Refs struct { common common.Common selector *selector.Selector - repo backend.Repository + repo proto.Repository ref *git.Reference activeRef *git.Reference refPrefix string @@ -216,7 +211,7 @@ func switchRefCmd(ref *ggit.Reference) tea.Cmd { } // UpdateRefCmd gets the repository's HEAD reference and sends a RefMsg. -func UpdateRefCmd(repo backend.Repository) tea.Cmd { +func UpdateRefCmd(repo proto.Repository) tea.Cmd { return func() tea.Msg { r, err := repo.Open() if err != nil { diff --git a/server/ui/pages/repo/repo.go b/server/ui/pages/repo/repo.go index 5cb2075db2104a926a0b4c424f2682fc3a92062b..3f1927ba8de10535f6712756c01566040eac5950 100644 --- a/server/ui/pages/repo/repo.go +++ b/server/ui/pages/repo/repo.go @@ -9,7 +9,7 @@ import ( tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" "github.com/charmbracelet/soft-serve/git" - "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/charmbracelet/soft-serve/server/ui/components/footer" "github.com/charmbracelet/soft-serve/server/ui/components/statusbar" @@ -54,7 +54,7 @@ type CopyURLMsg struct{} type UpdateStatusBarMsg struct{} // RepoMsg is a message that contains a git.Repository. -type RepoMsg backend.Repository +type RepoMsg proto.Repository // nolint:revive // BackMsg is a message to go back to the previous view. type BackMsg struct{} @@ -68,7 +68,7 @@ type CopyMsg struct { // Repo is a view for a git repository. type Repo struct { common common.Common - selectedRepo backend.Repository + selectedRepo proto.Repository activeTab tab tabs *tabs.Tabs statusbar *statusbar.StatusBar diff --git a/server/ui/pages/selection/item.go b/server/ui/pages/selection/item.go index 6550497c78b954f03ce5dc9627aefce0d9ce66b4..20a381ff0180371eae61338877767e6125909da1 100644 --- a/server/ui/pages/selection/item.go +++ b/server/ui/pages/selection/item.go @@ -11,8 +11,8 @@ import ( "github.com/charmbracelet/bubbles/list" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/dustin/go-humanize" ) @@ -48,13 +48,13 @@ func (it Items) Swap(i int, j int) { // Item represents a single item in the selector. type Item struct { - repo backend.Repository + repo proto.Repository lastUpdate *time.Time cmd string } // New creates a new Item. -func NewItem(repo backend.Repository, cfg *config.Config) (Item, error) { +func NewItem(repo proto.Repository, cfg *config.Config) (Item, error) { var lastUpdate *time.Time lu := repo.UpdatedAt() if !lu.IsZero() { diff --git a/server/ui/pages/selection/selection.go b/server/ui/pages/selection/selection.go index dc066f32a13889e0da8b647c393ac9d6d158efb7..b85573d3343faa69c42985df000a84d663e2317f 100644 --- a/server/ui/pages/selection/selection.go +++ b/server/ui/pages/selection/selection.go @@ -8,6 +8,7 @@ import ( "github.com/charmbracelet/bubbles/list" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/soft-serve/server/access" "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/charmbracelet/soft-serve/server/ui/components/code" @@ -36,12 +37,11 @@ func (p pane) String() string { // Selection is the model for the selection screen/page. type Selection struct { - common common.Common - readme *code.Code - readmeHeight int - selector *selector.Selector - activePane pane - tabs *tabs.Tabs + common common.Common + readme *code.Code + selector *selector.Selector + activePane pane + tabs *tabs.Tabs } // New creates a new selection model. @@ -187,12 +187,14 @@ func (s *Selection) Init() tea.Cmd { return nil } + ctx := s.common.Context() + be := s.common.Backend() pk := s.common.PublicKey() - if pk == nil && !cfg.Backend.AllowKeyless() { + if pk == nil && !be.AllowKeyless(ctx) { return nil } - repos, err := cfg.Backend.Repositories() + repos, err := be.Repositories(ctx) if err != nil { return common.ErrorCmd(err) } @@ -210,8 +212,8 @@ func (s *Selection) Init() tea.Cmd { if r.IsHidden() { continue } - al := cfg.Backend.AccessLevelByPublicKey(r.Name(), pk) - if al >= backend.ReadOnlyAccess { + al := be.AccessLevelByPublicKey(ctx, r.Name(), pk) + if al >= access.ReadOnlyAccess { item, err := NewItem(r, cfg) if err != nil { s.common.Logger.Debugf("ui: failed to create item for %s: %v", r.Name(), err) diff --git a/server/ui/ui.go b/server/ui/ui.go index 69b526d6ee296101a447a6dd470ddbe9aa4da49e..5e259004d7b91fada6151713c9c6c6adcb939a95 100644 --- a/server/ui/ui.go +++ b/server/ui/ui.go @@ -7,7 +7,7 @@ import ( "github.com/charmbracelet/bubbles/list" tea "github.com/charmbracelet/bubbletea" "github.com/charmbracelet/lipgloss" - "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/ui/common" "github.com/charmbracelet/soft-serve/server/ui/components/footer" "github.com/charmbracelet/soft-serve/server/ui/components/header" @@ -283,12 +283,15 @@ func (ui *UI) View() string { ) } -func (ui *UI) openRepo(rn string) (backend.Repository, error) { +func (ui *UI) openRepo(rn string) (proto.Repository, error) { cfg := ui.common.Config() if cfg == nil { return nil, errors.New("config is nil") } - repos, err := cfg.Backend.Repositories() + + ctx := ui.common.Context() + be := ui.common.Backend() + repos, err := be.Repositories(ctx) if err != nil { ui.common.Logger.Debugf("ui: failed to list repos: %v", err) return nil, err diff --git a/server/web/context.go b/server/web/context.go new file mode 100644 index 0000000000000000000000000000000000000000..d0a7879af48f45a9c074f3378a6006eb30f633f2 --- /dev/null +++ b/server/web/context.go @@ -0,0 +1,28 @@ +package web + +import ( + "context" + "net/http" + + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/config" +) + +// NewContextMiddleware returns a new context middleware. +// This middleware adds the config, backend, and logger to the request context. +func NewContextMiddleware(ctx context.Context) func(http.Handler) http.Handler { + cfg := config.FromContext(ctx) + be := backend.FromContext(ctx) + logger := log.FromContext(ctx).WithPrefix("http") + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = config.WithContext(ctx, cfg) + ctx = backend.WithContext(ctx, be) + ctx = log.WithContext(ctx, logger) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) + } +} diff --git a/server/web/git.go b/server/web/git.go index 2ca9265912dc2af5c94f2d629ec27b230b811afa..36e36d1bee6b6fa8774a70b87d094b8c86f4b6d5 100644 --- a/server/web/git.go +++ b/server/web/git.go @@ -4,6 +4,7 @@ import ( "bytes" "compress/gzip" "context" + "errors" "fmt" "io" "net/http" @@ -15,9 +16,11 @@ import ( "github.com/charmbracelet/log" gitb "github.com/charmbracelet/soft-serve/git" + "github.com/charmbracelet/soft-serve/server/access" "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/git" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/utils" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -30,22 +33,15 @@ type GitRoute struct { method string pattern *regexp.Regexp handler http.HandlerFunc - - cfg *config.Config - be backend.Backend - logger *log.Logger } var _ Route = GitRoute{} // Match implements goji.Pattern. func (g GitRoute) Match(r *http.Request) *http.Request { - if g.method != r.Method { - return nil - } - re := g.pattern ctx := r.Context() + cfg := config.FromContext(ctx) if m := re.FindStringSubmatch(r.URL.Path); m != nil { file := strings.Replace(r.URL.Path, m[1]+"/", "", 1) repo := utils.SanitizeRepo(m[1]) + ".git" @@ -59,22 +55,10 @@ func (g GitRoute) Match(r *http.Request) *http.Request { } ctx = context.WithValue(ctx, pattern.Variable("service"), service.String()) - ctx = context.WithValue(ctx, pattern.Variable("dir"), filepath.Join(g.cfg.DataPath, "repos", repo)) + ctx = context.WithValue(ctx, pattern.Variable("dir"), filepath.Join(cfg.DataPath, "repos", repo)) ctx = context.WithValue(ctx, pattern.Variable("repo"), repo) ctx = context.WithValue(ctx, pattern.Variable("file"), file) - if g.cfg != nil { - ctx = config.WithContext(ctx, g.cfg) - } - - if g.be != nil { - ctx = backend.WithContext(ctx, g.be.WithContext(ctx)) - } - - if g.logger != nil { - ctx = log.WithContext(ctx, g.logger) - } - return r.WithContext(ctx) } @@ -83,10 +67,16 @@ func (g GitRoute) Match(r *http.Request) *http.Request { // ServeHTTP implements http.Handler. func (g GitRoute) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != g.method { + renderMethodNotAllowed(w, r) + return + } + g.handler(w, r) } var ( + //nolint:revive gitHttpReceiveCounter = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: "soft_serve", Subsystem: "http", @@ -94,6 +84,7 @@ var ( Help: "The total number of git push requests", }, []string{"repo"}) + //nolint:revive gitHttpUploadCounter = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: "soft_serve", Subsystem: "http", @@ -102,10 +93,8 @@ var ( }, []string{"repo", "file"}) ) -func gitRoutes(ctx context.Context, logger *log.Logger) []Route { +func gitRoutes() []Route { routes := make([]Route, 0) - cfg := config.FromContext(ctx) - be := backend.FromContext(ctx) // Git services // These routes don't handle authentication/authorization. @@ -159,19 +148,16 @@ func gitRoutes(ctx context.Context, logger *log.Logger) []Route { handler: getLooseObject, }, { - pattern: regexp.MustCompile("(.*?)/objects/pack/pack-[0-9a-f]{40}\\.pack$"), + pattern: regexp.MustCompile(`(.*?)/objects/pack/pack-[0-9a-f]{40}\.pack$`), method: http.MethodGet, handler: getPackFile, }, { - pattern: regexp.MustCompile("(.*?)/objects/pack/pack-[0-9a-f]{40}\\.idx$"), + pattern: regexp.MustCompile(`(.*?)/objects/pack/pack-[0-9a-f]{40}\.idx$`), method: http.MethodGet, handler: getIdxFile, }, } { - route.cfg = cfg - route.be = be - route.logger = logger route.handler = withAccess(route.handler) routes = append(routes, route) } @@ -186,32 +172,32 @@ func withAccess(fn http.HandlerFunc) http.HandlerFunc { be := backend.FromContext(ctx) logger := log.FromContext(ctx) - if !be.AllowKeyless() { + if !be.AllowKeyless(ctx) { renderForbidden(w) return } repo := pat.Param(r, "repo") service := git.Service(pat.Param(r, "service")) - access := be.AccessLevel(repo, "") + accessLevel := be.AccessLevel(ctx, repo, "") switch service { case git.ReceivePackService: - if access < backend.ReadWriteAccess { + if accessLevel < access.ReadWriteAccess { renderUnauthorized(w) return } // Create the repo if it doesn't exist. - if _, err := be.Repository(repo); err != nil { - if _, err := be.CreateRepository(repo, backend.RepositoryOptions{}); err != nil { + if _, err := be.Repository(ctx, repo); err != nil { + if _, err := be.CreateRepository(ctx, repo, proto.RepositoryOptions{}); err != nil { logger.Error("failed to create repository", "repo", repo, "err", err) renderInternalServerError(w) return } } default: - if access < backend.ReadOnlyAccess { + if accessLevel < access.ReadOnlyAccess { renderUnauthorized(w) return } @@ -221,8 +207,10 @@ func withAccess(fn http.HandlerFunc) http.HandlerFunc { } } +//nolint:revive func serviceRpc(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + cfg := config.FromContext(ctx) logger := log.FromContext(ctx) service, dir, repo := git.Service(pat.Param(r, "service")), pat.Param(r, "dir"), pat.Param(r, "repo") @@ -243,71 +231,77 @@ func serviceRpc(w http.ResponseWriter, r *http.Request) { version := r.Header.Get("Git-Protocol") + var stdout bytes.Buffer cmd := git.ServiceCommand{ - Stdin: r.Body, - Stdout: w, + Stdout: &stdout, Dir: dir, Args: []string{"--stateless-rpc"}, } if len(version) != 0 { - cmd.Env = append(cmd.Env, fmt.Sprintf("GIT_PROTOCOL=%s", version)) + cmd.Env = append(cmd.Env, []string{ + // TODO: add the rest of env vars when we support pushing using http + "SOFT_SERVE_LOG_PATH=" + filepath.Join(cfg.DataPath, "log", "hooks.log"), + fmt.Sprintf("GIT_PROTOCOL=%s", version), + }...) } // Handle gzip encoding - cmd.StdinHandler = func(in io.Reader, stdin io.WriteCloser) (err error) { - // We know that `in` is an `io.ReadCloser` because it's `r.Body`. - reader := in.(io.ReadCloser) - defer reader.Close() // nolint: errcheck - switch r.Header.Get("Content-Encoding") { - case "gzip": - reader, err = gzip.NewReader(reader) - if err != nil { - return err - } - defer reader.Close() // nolint: errcheck + reader := r.Body + defer reader.Close() // nolint: errcheck + switch r.Header.Get("Content-Encoding") { + case "gzip": + reader, err := gzip.NewReader(reader) + if err != nil { + logger.Errorf("failed to create gzip reader: %v", err) + renderInternalServerError(w) + return } + defer reader.Close() // nolint: errcheck + } + + cmd.Stdin = reader - _, err = io.Copy(stdin, reader) - return err + if err := service.Handler(ctx, cmd); err != nil { + if errors.Is(err, git.ErrInvalidRepo) { + renderNotFound(w) + return + } + renderInternalServerError(w) + return } // Handle buffered output // Useful when using proxies - cmd.StdoutHandler = func(out io.Writer, stdout io.ReadCloser) error { - // We know that `out` is an `http.ResponseWriter`. - flusher, ok := out.(http.Flusher) - if !ok { - return fmt.Errorf("expected http.ResponseWriter to be an http.Flusher, got %T", out) - } - - p := make([]byte, 1024) - for { - nRead, err := stdout.Read(p) - if err == io.EOF { - break - } - nWrite, err := out.Write(p[:nRead]) - if err != nil { - return err - } - if nRead != nWrite { - return fmt.Errorf("failed to write data: %d read, %d written", nRead, nWrite) - } - flusher.Flush() - } - return nil + // We know that `w` is an `http.ResponseWriter`. + flusher, ok := w.(http.Flusher) + if !ok { + logger.Errorf("expected http.ResponseWriter to be an http.Flusher, got %T", w) + return } - if err := service.Handler(ctx, cmd); err != nil { - logger.Errorf("error executing service: %s", err) + p := make([]byte, 1024) + for { + nRead, err := stdout.Read(p) + if err == io.EOF { + break + } + nWrite, err := w.Write(p[:nRead]) + if err != nil { + logger.Errorf("failed to write data: %v", err) + return + } + if nRead != nWrite { + logger.Errorf("failed to write data: %d read, %d written", nRead, nWrite) + return + } + flusher.Flush() } } func getInfoRefs(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - logger := log.FromContext(ctx) dir, repo, file := pat.Param(r, "dir"), pat.Param(r, "repo"), pat.Param(r, "file") service := getServiceType(r) version := r.Header.Get("Git-Protocol") @@ -328,7 +322,6 @@ func getInfoRefs(w http.ResponseWriter, r *http.Request) { } if err := service.Handler(ctx, cmd); err != nil { - logger.Errorf("error executing service: %s", err) renderNotFound(w) return } @@ -337,7 +330,7 @@ func getInfoRefs(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", service)) w.WriteHeader(http.StatusOK) if len(version) == 0 { - git.WritePktline(w, "# service="+service.String()) + git.WritePktline(w, "# service="+service.String()) // nolint: errcheck } w.Write(refs.Bytes()) // nolint: errcheck @@ -400,10 +393,7 @@ func getServiceType(r *http.Request) git.Service { } func isSmart(r *http.Request, service git.Service) bool { - if r.Header.Get("Content-Type") == fmt.Sprintf("application/x-%s-request", service) { - return true - } - return false + return r.Header.Get("Content-Type") == fmt.Sprintf("application/x-%s-request", service) } func updateServerInfo(ctx context.Context, dir string) error { diff --git a/server/web/goget.go b/server/web/goget.go index 7e7c8c9d61faefd0b8bc629751e1d7328fa1c042..8ee4609f1ca106a151aa67deb48b0406f0693681 100644 --- a/server/web/goget.go +++ b/server/web/goget.go @@ -34,17 +34,16 @@ Redirecting to docs at `)) // GoGetHandler handles go get requests. -type GoGetHandler struct { - cfg *config.Config - be backend.Backend -} +type GoGetHandler struct{} var _ http.Handler = (*GoGetHandler)(nil) func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { repo := pattern.Path(r.Context()) repo = utils.SanitizeRepo(repo) - be := g.be.WithContext(r.Context()) + ctx := r.Context() + cfg := config.FromContext(ctx) + be := backend.FromContext(ctx) // Handle go get requests. // @@ -54,7 +53,7 @@ func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // https://go.dev/ref/mod#vcs-branch if r.URL.Query().Get("go-get") == "1" { repo := repo - importRoot, err := url.Parse(g.cfg.HTTP.PublicURL) + importRoot, err := url.Parse(cfg.HTTP.PublicURL) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -62,7 +61,7 @@ func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // find the repo for { - if _, err := be.Repository(repo); err == nil { + if _, err := be.Repository(ctx, repo); err == nil { break } @@ -79,7 +78,7 @@ func (g GoGetHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ImportRoot string }{ Repo: url.PathEscape(repo), - Config: g.cfg, + Config: cfg, ImportRoot: importRoot.Host, }); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/server/web/http.go b/server/web/http.go index 7ff375cab3132d377d2f92e0ce5184dc94ce43a0..538e9fa3756ccb3203b732e604d62cee5e78201b 100644 --- a/server/web/http.go +++ b/server/web/http.go @@ -5,7 +5,6 @@ import ( "net/http" "time" - "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" ) @@ -13,7 +12,6 @@ import ( type HTTPServer struct { ctx context.Context cfg *config.Config - be backend.Backend server *http.Server } @@ -23,7 +21,6 @@ func NewHTTPServer(ctx context.Context) (*HTTPServer, error) { s := &HTTPServer{ ctx: ctx, cfg: cfg, - be: backend.FromContext(ctx), server: &http.Server{ Addr: cfg.HTTP.ListenAddr, Handler: NewRouter(ctx), diff --git a/server/web/logging.go b/server/web/logging.go index f0f43a05c832cc98b2edc24ede2c83254b6d1b95..40f187e0888defc2e7412ad35c576bfa0b98afbd 100644 --- a/server/web/logging.go +++ b/server/web/logging.go @@ -24,7 +24,7 @@ var _ http.Flusher = (*logWriter)(nil) var _ http.Hijacker = (*logWriter)(nil) -var _ http.CloseNotifier = (*logWriter)(nil) +var _ http.CloseNotifier = (*logWriter)(nil) // nolint: staticcheck // Write implements http.ResponseWriter. func (r *logWriter) Write(p []byte) (int, error) { @@ -49,7 +49,7 @@ func (r *logWriter) Flush() { // CloseNotify implements http.CloseNotifier. func (r *logWriter) CloseNotify() <-chan bool { - if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { + if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { // nolint: staticcheck return cn.CloseNotify() } return nil @@ -64,21 +64,20 @@ func (r *logWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { } // NewLoggingMiddleware returns a new logging middleware. -func NewLoggingMiddleware(logger *log.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() - writer := &logWriter{code: http.StatusOK, ResponseWriter: w} - logger.Debug("request", - "method", r.Method, - "uri", r.RequestURI, - "addr", r.RemoteAddr) - next.ServeHTTP(writer, r) - elapsed := time.Since(start) - logger.Debug("response", - "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)), - "bytes", humanize.Bytes(uint64(writer.bytes)), - "time", elapsed) - }) - } +func NewLoggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + logger := log.FromContext(r.Context()) + start := time.Now() + writer := &logWriter{code: http.StatusOK, ResponseWriter: w} + logger.Debug("request", + "method", r.Method, + "uri", r.RequestURI, + "addr", r.RemoteAddr) + next.ServeHTTP(writer, r) + elapsed := time.Since(start) + logger.Debug("response", + "status", fmt.Sprintf("%d %s", writer.code, http.StatusText(writer.code)), + "bytes", humanize.Bytes(uint64(writer.bytes)), + "time", elapsed) + }) } diff --git a/server/web/server.go b/server/web/server.go index ea15e778fbd2b4d131c75f9bd6f6f37b939e6be7..95a8d7df84f1c9bc80fbab03ea8803d5693cb2f9 100644 --- a/server/web/server.go +++ b/server/web/server.go @@ -1,13 +1,9 @@ -// Package server is the reusable server package web import ( "context" "net/http" - "github.com/charmbracelet/log" - "github.com/charmbracelet/soft-serve/server/backend" - "github.com/charmbracelet/soft-serve/server/config" "goji.io" "goji.io/pat" ) @@ -21,20 +17,18 @@ type Route interface { // NewRouter returns a new HTTP router. func NewRouter(ctx context.Context) *goji.Mux { mux := goji.NewMux() - cfg := config.FromContext(ctx) - be := backend.FromContext(ctx) - logger := log.FromContext(ctx).WithPrefix("http") // Middlewares - mux.Use(NewLoggingMiddleware(logger)) + mux.Use(NewContextMiddleware(ctx)) + mux.Use(NewLoggingMiddleware) // Git routes - for _, service := range gitRoutes(ctx, logger) { + for _, service := range gitRoutes() { mux.Handle(service, service) } // go-get handler - mux.Handle(pat.Get("/*"), GoGetHandler{cfg, be}) + mux.Handle(pat.Get("/*"), GoGetHandler{}) return mux } diff --git a/testscript/script_test.go b/testscript/script_test.go index 76c1866bddb91a7675ebe0e0d93d25447beb905b..032bd47eb2be04bb400925a9f589eac1766e429a 100644 --- a/testscript/script_test.go +++ b/testscript/script_test.go @@ -15,10 +15,14 @@ import ( "github.com/charmbracelet/keygen" "github.com/charmbracelet/soft-serve/server" + "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/migrate" "github.com/charmbracelet/soft-serve/server/test" "github.com/rogpeppe/go-internal/testscript" "golang.org/x/crypto/ssh" + _ "modernc.org/sqlite" // sqlite Driver ) var update = flag.Bool("update", false, "update script files") @@ -51,42 +55,56 @@ func TestScript(t *testing.T) { "dos2unix": cmdDos2Unix, }, Setup: func(e *testscript.Env) error { + data := t.TempDir() + sshPort := test.RandomPort() + sshListen := fmt.Sprintf("localhost:%d", sshPort) + gitPort := test.RandomPort() + gitListen := fmt.Sprintf("localhost:%d", gitPort) + httpPort := test.RandomPort() + httpListen := fmt.Sprintf("localhost:%d", httpPort) + statsPort := test.RandomPort() + statsListen := fmt.Sprintf("localhost:%d", statsPort) + serverName := "Test Soft Serve" + e.Setenv("SSH_PORT", fmt.Sprintf("%d", sshPort)) e.Setenv("ADMIN1_AUTHORIZED_KEY", admin1.AuthorizedKey()) e.Setenv("ADMIN2_AUTHORIZED_KEY", admin2.AuthorizedKey()) e.Setenv("USER1_AUTHORIZED_KEY", user1.AuthorizedKey()) e.Setenv("SSH_KNOWN_HOSTS_FILE", filepath.Join(t.TempDir(), "known_hosts")) e.Setenv("SSH_KNOWN_CONFIG_FILE", filepath.Join(t.TempDir(), "config")) - data := t.TempDir() - cfg := config.Config{ - Name: "Test Soft Serve", - DataPath: data, - InitialAdminKeys: []string{admin1.AuthorizedKey()}, - SSH: config.SSHConfig{ - ListenAddr: fmt.Sprintf("localhost:%d", sshPort), - PublicURL: fmt.Sprintf("ssh://localhost:%d", sshPort), - KeyPath: filepath.Join(data, "ssh", "soft_serve_host_ed25519"), - ClientKeyPath: filepath.Join(data, "ssh", "soft_serve_client_ed25519"), - }, - Git: config.GitConfig{ - ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()), - IdleTimeout: 3, - MaxConnections: 32, - }, - HTTP: config.HTTPConfig{ - ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()), - PublicURL: fmt.Sprintf("http://localhost:%d", test.RandomPort()), - }, - Stats: config.StatsConfig{ - ListenAddr: fmt.Sprintf("localhost:%d", test.RandomPort()), - }, - Log: config.LogConfig{ - Format: "text", - TimeFormat: time.DateTime, - }, + + cfg := config.DefaultConfig() + cfg.DataPath = data + cfg.Name = serverName + cfg.InitialAdminKeys = []string{admin1.AuthorizedKey()} + cfg.SSH.ListenAddr = sshListen + cfg.SSH.PublicURL = "ssh://" + sshListen + cfg.Git.ListenAddr = gitListen + cfg.HTTP.ListenAddr = httpListen + cfg.HTTP.PublicURL = "http://" + httpListen + cfg.Stats.ListenAddr = statsListen + cfg.DB.Driver = "sqlite" + + if err := cfg.Validate(); err != nil { + return err } - ctx := config.WithContext(context.Background(), &cfg) + + ctx := config.WithContext(context.Background(), cfg) + + // TODO: test postgres + dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource) + if err != nil { + return fmt.Errorf("open database: %w", err) + } + + if err := migrate.Migrate(ctx, dbx); err != nil { + return fmt.Errorf("migrate database: %w", err) + } + + ctx = db.WithContext(ctx, dbx) + be := backend.New(ctx, cfg, dbx) + ctx = backend.WithContext(ctx, be) // prevent race condition in lipgloss... // this will probably be autofixed when we start using the colors @@ -106,6 +124,7 @@ func TestScript(t *testing.T) { }() e.Defer(func() { + defer dbx.Close() // nolint: errcheck ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := srv.Shutdown(ctx); err != nil { diff --git a/testscript/testdata/repo-commit.txtar b/testscript/testdata/repo-commit.txtar new file mode 100644 index 0000000000000000000000000000000000000000..dbf5c1bf80f5db690a1c30d4157834ca7a1e2316 --- /dev/null +++ b/testscript/testdata/repo-commit.txtar @@ -0,0 +1,30 @@ +# vi: set ft=conf + +# convert crlf to lf on windows +[windows] dos2unix commit1.txt + +# create a repo +soft repo import basic1 https://github.com/git-fixtures/basic + +# print commit +soft repo commit basic1 b8e471f58bcbca63b07bda20e428190409c2db47 +cmp stdout commit1.txt + +-- commit1.txt -- +commit b8e471f58bcbca63b07bda20e428190409c2db47 +Author: Daniel Ripolles +Date: Tue Mar 31 11:44:52 UTC 2015 +Creating changelog + + +CHANGELOG | 1 + +1 file changed, 1 insertion(+) + +diff --git a/CHANGELOG b/CHANGELOG +new file mode 100644 +index 0000000000000000000000000000000000000000..d3ff53e0564a9f87d8e84b6e28e5060e517008aa +--- /dev/null ++++ b/CHANGELOG +@@ -0,0 +1 @@ ++Initial changelog + diff --git a/testscript/testdata/repo-import.txt b/testscript/testdata/repo-import.txtar similarity index 100% rename from testscript/testdata/repo-import.txt rename to testscript/testdata/repo-import.txtar