From 8e68dba683ab45107d5ab100b77ea1bfccdb1a2c Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 20 Jul 2023 13:08:50 -0400 Subject: [PATCH] feat: implement git-lfs preliminary support - Support Git LFS SSH transfer (server-side) - Implement Git LFS Basic transfer (client) - Import missing LFS objects when importing a repository wip fix: user can be nil for anonymous connections fix: wrap db errors fix: lint errors --- cmd/soft/root.go | 4 + git/config.go | 72 ++- git/repo.go | 28 -- go.mod | 9 +- go.sum | 21 +- server/backend/backend.go | 3 +- server/backend/lfs.go | 85 ++++ server/backend/repo.go | 73 ++- server/backend/user.go | 59 ++- server/daemon/daemon_test.go | 14 +- server/db/context.go | 7 +- server/db/handler.go | 25 + server/db/migrate/0002_create_lfs_tables.go | 23 + .../0002_create_lfs_tables_postgres.down.sql | 2 + .../0002_create_lfs_tables_postgres.up.sql | 28 ++ .../0002_create_lfs_tables_sqlite.down.sql | 2 + .../0002_create_lfs_tables_sqlite.up.sql | 28 ++ server/db/migrate/migrations.go | 1 + server/db/models/lfs.go | 24 + server/git/lfs.go | 451 ++++++++++++++++++ server/git/service.go | 7 + server/lfs/basic_transfer.go | 124 +++++ server/lfs/client.go | 27 ++ server/lfs/common.go | 88 ++++ server/lfs/endpoint.go | 70 +++ server/lfs/http_client.go | 196 ++++++++ server/lfs/pointer.go | 122 +++++ server/lfs/scanner.go | 210 ++++++++ server/lfs/ssh_client.go | 3 + server/lfs/transfer.go | 17 + server/proto/repo.go | 5 + server/proto/user.go | 5 + server/ssh/cmd/delete.go | 7 +- server/ssh/git.go | 33 +- server/ssh/middleware.go | 6 +- server/ssh/session_test.go | 14 +- server/ssh/ssh.go | 12 +- server/storage/local.go | 91 ++++ server/storage/storage.go | 23 + server/store/context.go | 20 + server/store/database/collab.go | 10 +- server/store/database/database.go | 2 + server/store/database/lfs.go | 179 +++++++ server/store/database/repo.go | 28 +- server/store/database/settings.go | 8 +- server/store/database/user.go | 30 +- server/store/lfs.go | 26 + server/store/store.go | 70 +-- testscript/script_test.go | 4 + 49 files changed, 2206 insertions(+), 190 deletions(-) create mode 100644 server/backend/lfs.go create mode 100644 server/db/handler.go create mode 100644 server/db/migrate/0002_create_lfs_tables.go create mode 100644 server/db/migrate/0002_create_lfs_tables_postgres.down.sql create mode 100644 server/db/migrate/0002_create_lfs_tables_postgres.up.sql create mode 100644 server/db/migrate/0002_create_lfs_tables_sqlite.down.sql create mode 100644 server/db/migrate/0002_create_lfs_tables_sqlite.up.sql create mode 100644 server/db/models/lfs.go create mode 100644 server/git/lfs.go create mode 100644 server/lfs/basic_transfer.go create mode 100644 server/lfs/client.go create mode 100644 server/lfs/common.go create mode 100644 server/lfs/endpoint.go create mode 100644 server/lfs/http_client.go create mode 100644 server/lfs/pointer.go create mode 100644 server/lfs/scanner.go create mode 100644 server/lfs/ssh_client.go create mode 100644 server/lfs/transfer.go create mode 100644 server/storage/local.go create mode 100644 server/storage/storage.go create mode 100644 server/store/context.go create mode 100644 server/store/database/lfs.go create mode 100644 server/store/lfs.go diff --git a/cmd/soft/root.go b/cmd/soft/root.go index b3216162f5601ed44bddfd3fa0bdcc86a8a1825c..7f03982d437c09440c5c65e0932bd3c9af294f9f 100644 --- a/cmd/soft/root.go +++ b/cmd/soft/root.go @@ -12,6 +12,8 @@ import ( "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/store/database" _ "github.com/lib/pq" // postgres driver "github.com/spf13/cobra" "go.uber.org/automaxprocs/maxprocs" @@ -150,6 +152,8 @@ func initBackendContext(cmd *cobra.Command, _ []string) error { } ctx = db.WithContext(ctx, dbx) + dbstore := database.New(ctx, dbx) + ctx = store.WithContext(ctx, dbstore) be := backend.New(ctx, cfg, dbx) ctx = backend.WithContext(ctx, be) diff --git a/git/config.go b/git/config.go index 4e9af6ed500302e8d1e7e7afeabf13bbba4bde48..1ebb2470384145be6ff8b2ab4cafd346d41bac64 100644 --- a/git/config.go +++ b/git/config.go @@ -1,51 +1,39 @@ package git -// ConfigOptions are options for Config. -type ConfigOptions struct { - File string - All bool - Add bool - CommandOptions -} +import ( + "os" + "path/filepath" -// Config gets a git configuration. -func Config(key string, opts ...ConfigOptions) (string, error) { - var opt ConfigOptions - if len(opts) > 0 { - opt = opts[0] - } - cmd := NewCommand("config") - if opt.File != "" { - cmd.AddArgs("--file", opt.File) - } - if opt.All { - cmd.AddArgs("--get-all") - } - for _, a := range opt.Args { - cmd.AddArgs(a) - } - cmd.AddArgs(key) - bts, err := cmd.Run() + gcfg "github.com/go-git/go-git/v5/plumbing/format/config" +) + +// Config returns the repository Git configuration. +func (r *Repository) Config() (*gcfg.Config, error) { + cp := filepath.Join(r.Path, "config") + f, err := os.Open(cp) if err != nil { - return "", err + return nil, err } - return string(bts), nil -} -// SetConfig sets a git configuration. -func SetConfig(key string, value string, opts ...ConfigOptions) error { - var opt ConfigOptions - if len(opts) > 0 { - opt = opts[0] + defer f.Close() // nolint: errcheck + d := gcfg.NewDecoder(f) + cfg := gcfg.New() + if err := d.Decode(cfg); err != nil { + return nil, err } - cmd := NewCommand("config") - if opt.File != "" { - cmd.AddArgs("--file", opt.File) - } - for _, a := range opt.Args { - cmd.AddArgs(a) + + return cfg, nil +} + +// SetConfig sets the repository Git configuration. +func (r *Repository) SetConfig(cfg *gcfg.Config) error { + cp := filepath.Join(r.Path, "config") + f, err := os.Create(cp) + if err != nil { + return err } - cmd.AddArgs(key, value) - _, err := cmd.Run() - return err + + defer f.Close() // nolint: errcheck + e := gcfg.NewEncoder(f) + return e.Encode(cfg) } diff --git a/git/repo.go b/git/repo.go index 479a4581a6a0ddf3c2094224f98a37f595da1734..ef3d2759d63411b4d7f297ebb838eb959bd6b00d 100644 --- a/git/repo.go +++ b/git/repo.go @@ -200,34 +200,6 @@ func (r *Repository) CommitsByPage(ref *Reference, page, size int) (Commits, err return commits, nil } -// Config returns the config value for the given key. -func (r *Repository) Config(key string, opts ...ConfigOptions) (string, error) { - dir, err := gitDir(r.Repository) - if err != nil { - return "", err - } - var opt ConfigOptions - if len(opts) > 0 { - opt = opts[0] - } - opt.File = filepath.Join(dir, "config") - return Config(key, opt) -} - -// SetConfig sets the config value for the given key. -func (r *Repository) SetConfig(key, value string, opts ...ConfigOptions) error { - dir, err := gitDir(r.Repository) - if err != nil { - return err - } - var opt ConfigOptions - if len(opts) > 0 { - opt = opts[0] - } - opt.File = filepath.Join(dir, "config") - return SetConfig(key, value, opt) -} - // SymbolicRef returns or updates the symbolic reference for the given name. // Both name and ref can be empty. func (r *Repository) SymbolicRef(name string, ref string, opts ...git.SymbolicRefOptions) (string, error) { diff --git a/go.mod b/go.mod index e84be4d037b67ddf7e1ad1af07739e5f0d883638..0a9d11dcdd3532e6ac2c27621cc67f3a1a04a057 100644 --- a/go.mod +++ b/go.mod @@ -19,9 +19,10 @@ require ( require ( github.com/caarlos0/env/v8 v8.0.0 + github.com/charmbracelet/git-lfs-transfer v0.1.1-0.20230721203144-64d90e7a36a1 github.com/charmbracelet/keygen v0.4.3 github.com/charmbracelet/log v0.2.3-0.20230713155356-557335e40e35 - github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc + github.com/charmbracelet/ssh v0.0.0-20230720143903-5bdd92839155 github.com/gobwas/glob v0.2.3 github.com/gogs/git-module v1.8.2 github.com/hashicorp/golang-lru/v2 v2.0.4 @@ -33,6 +34,7 @@ require ( github.com/prometheus/client_golang v1.16.0 github.com/robfig/cron/v3 v3.0.1 github.com/rogpeppe/go-internal v1.11.0 + github.com/rubyist/tracerx v0.0.0-20170927163412-787959303086 github.com/spf13/cobra v1.7.0 go.uber.org/automaxprocs v1.5.3 goji.io v2.0.2+incompatible @@ -52,6 +54,8 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect github.com/dlclark/regexp2 v1.4.0 // indirect + github.com/git-lfs/pktline v0.0.0-20230103162542-ca444d533ef1 // indirect + github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/uuid v1.3.0 // indirect @@ -80,12 +84,13 @@ require ( github.com/yuin/goldmark v1.5.2 // indirect github.com/yuin/goldmark-emoji v1.0.1 // indirect golang.org/x/mod v0.9.0 // indirect - golang.org/x/net v0.10.0 // indirect + golang.org/x/net v0.12.0 // indirect golang.org/x/sys v0.10.0 // indirect golang.org/x/term v0.10.0 // indirect golang.org/x/text v0.11.0 // indirect golang.org/x/tools v0.6.0 // indirect google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/warnings.v0 v0.1.2 // indirect lukechampine.com/uint128 v1.2.0 // indirect modernc.org/cc/v3 v3.40.0 // indirect modernc.org/ccgo/v3 v3.16.13 // indirect diff --git a/go.sum b/go.sum index 30314107e3eb2ae3642a2291bd1d6968186e03ad..7846d7c62bcc53de401544721db60589731d64ac 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/charmbracelet/bubbles v0.16.1 h1:6uzpAAaT9ZqKssntbvZMlksWHruQLNxg49H5 github.com/charmbracelet/bubbles v0.16.1/go.mod h1:2QCp9LFlEsBQMvIYERr7Ww2H2bA7xen1idUDIzm/+Xc= github.com/charmbracelet/bubbletea v0.24.2 h1:uaQIKx9Ai6Gdh5zpTbGiWpytMU+CfsPp06RaW2cx/SY= github.com/charmbracelet/bubbletea v0.24.2/go.mod h1:XdrNrV4J8GiyshTtx3DNuYkR1FDaJmO3l2nejekbsgg= +github.com/charmbracelet/git-lfs-transfer v0.1.1-0.20230721203144-64d90e7a36a1 h1:/QzZzTDdlDYGZeC2O2y/Qw+AiHqh3vCsO4yrKDWXtqs= +github.com/charmbracelet/git-lfs-transfer v0.1.1-0.20230721203144-64d90e7a36a1/go.mod h1:eXJuVicxnjRgRMokmutZdistxoMRjBjjfqvrYq7bCIU= github.com/charmbracelet/glamour v0.6.0 h1:wi8fse3Y7nfcabbbDuwolqTqMQPMnVPeZhDM273bISc= github.com/charmbracelet/glamour v0.6.0/go.mod h1:taqWV4swIMMbWALc0m7AfE9JkPSU8om2538k9ITBxOc= github.com/charmbracelet/keygen v0.4.3 h1:ywOZRwkDlpmkawl0BgLTxaYWDSqp6Y4nfVVmgyyO1Mg= @@ -29,8 +31,8 @@ github.com/charmbracelet/lipgloss v0.7.1 h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZ github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c= github.com/charmbracelet/log v0.2.3-0.20230713155356-557335e40e35 h1:VXEaJ1iM2L5N8T2WVbv4y631pzCD3O9s75dONqK+87g= github.com/charmbracelet/log v0.2.3-0.20230713155356-557335e40e35/go.mod h1:ZApwwzDbbETVTIRTk7724yQRJAXIktt98yGVMMaa3y8= -github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc h1:JUm+5HigAM5utFiThwIDX9iU0BaheKpuNVr+umi3sFg= -github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc/go.mod h1:F1vgddWsb/Yr/OZilFeRZEh5sE/qU0Dt1mKkmke6Zvg= +github.com/charmbracelet/ssh v0.0.0-20230720143903-5bdd92839155 h1:vJqYhlL0doAWQPz+EX/hK5x/ZYguoua773oRz77zYKo= +github.com/charmbracelet/ssh v0.0.0-20230720143903-5bdd92839155/go.mod h1:F1vgddWsb/Yr/OZilFeRZEh5sE/qU0Dt1mKkmke6Zvg= github.com/charmbracelet/wish v1.1.1 h1:KdICASKd2oh2JPvk1Z4CJtAi97cFErXF7NKienPICO4= github.com/charmbracelet/wish v1.1.1/go.mod h1:xh4KZpSULw+Xqb9bcbhw92QAinVB75CVLWrFuyY6IVs= github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY= @@ -43,6 +45,10 @@ github.com/dlclark/regexp2 v1.4.0 h1:F1rxgk7p4uKjwIQxBs9oAXe5CqrXlCduYEJvrF4u93E github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/git-lfs/pktline v0.0.0-20230103162542-ca444d533ef1 h1:mtDjlmloH7ytdblogrMz1/8Hqua1y8B4ID+bh3rvod0= +github.com/git-lfs/pktline v0.0.0-20230103162542-ca444d533ef1/go.mod h1:fenKRzpXDjNpsIBhuhUzvjCKlDjKam0boRAenTE0Q6A= +github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= +github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= github.com/go-git/go-git/v5 v5.7.0 h1:t9AudWVLmqzlo+4bqdf7GY+46SUuRsx59SboFxkq2aE= github.com/go-git/go-git/v5 v5.7.0/go.mod h1:coJHKEOk5kUClpsNlXrUvPrDxY3w3gjHvhcZd8Fodw8= github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= @@ -124,6 +130,8 @@ github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= @@ -145,6 +153,8 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rubyist/tracerx v0.0.0-20170927163412-787959303086 h1:mncRSDOqYCng7jOD+Y6+IivdRI6Kzv2BLWYkWkdQfu0= +github.com/rubyist/tracerx v0.0.0-20170927163412-787959303086/go.mod h1:YpdgDXpumPB/+EGmGTYHeiW/0QVFRzBYTNFaxWfPDk4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sahilm/fuzzy v0.1.0 h1:FzWGaw2Opqyu+794ZQ9SYifWv2EIXpwP4q8dY1kDAwI= github.com/sahilm/fuzzy v0.1.0/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= @@ -157,6 +167,7 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -179,8 +190,8 @@ golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20221002022538-bcab6841153b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= @@ -215,6 +226,8 @@ google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= +gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/server/backend/backend.go b/server/backend/backend.go index 2d6104d56f2a32563c7724ccd762312895c804a2..586d95132504a852dc8fa339de464d3360aca59d 100644 --- a/server/backend/backend.go +++ b/server/backend/backend.go @@ -7,7 +7,6 @@ import ( "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/db" "github.com/charmbracelet/soft-serve/server/store" - "github.com/charmbracelet/soft-serve/server/store/database" ) // Backend is the Soft Serve backend that handles users, repositories, and @@ -23,7 +22,7 @@ type Backend struct { // New returns a new Soft Serve backend. func New(ctx context.Context, cfg *config.Config, db *db.DB) *Backend { - dbstore := database.New(ctx, db) + dbstore := store.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("backend") b := &Backend{ ctx: ctx, diff --git a/server/backend/lfs.go b/server/backend/lfs.go new file mode 100644 index 0000000000000000000000000000000000000000..dfc21ea69d5eb9193d347b39011591f4fd9c9367 --- /dev/null +++ b/server/backend/lfs.go @@ -0,0 +1,85 @@ +package backend + +import ( + "context" + "errors" + "io" + "path" + "path/filepath" + + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/lfs" + "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/storage" + "github.com/charmbracelet/soft-serve/server/store" +) + +// StoreRepoMissingLFSObjects stores missing LFS objects for a repository. +func StoreRepoMissingLFSObjects(ctx context.Context, repo proto.Repository, dbx *db.DB, store store.Store, lfsClient lfs.Client) error { + cfg := config.FromContext(ctx) + lfsRoot := filepath.Join(cfg.DataPath, "lfs") + + // TODO: support S3 storage + strg := storage.NewLocalStorage(lfsRoot) + pointerChan := make(chan lfs.PointerBlob) + errChan := make(chan error, 1) + r, err := repo.Open() + if err != nil { + return err + } + + go lfs.SearchPointerBlobs(ctx, r, pointerChan, errChan) + + download := func(pointers []lfs.Pointer) error { + return lfsClient.Download(ctx, pointers, func(p lfs.Pointer, content io.ReadCloser, objectError error) error { + if objectError != nil { + return objectError + } + + defer content.Close() // nolint: errcheck + return dbx.TransactionContext(ctx, func(tx *db.Tx) error { + if err := store.CreateLFSObject(ctx, tx, repo.ID(), p.Oid, p.Size); err != nil { + return db.WrapError(err) + } + + return strg.Put(path.Join("objects", p.RelativePath()), content) + }) + }) + } + + var batch []lfs.Pointer + for pointer := range pointerChan { + obj, err := store.GetLFSObjectByOid(ctx, dbx, repo.ID(), pointer.Oid) + if err != nil && !errors.Is(err, db.ErrRecordNotFound) { + return db.WrapError(err) + } + + exist, err := strg.Exists(path.Join("objects", pointer.RelativePath())) + if err != nil { + return err + } + + if exist && obj.ID == 0 { + if err := store.CreateLFSObject(ctx, dbx, repo.ID(), pointer.Oid, pointer.Size); err != nil { + return db.WrapError(err) + } + } else { + batch = append(batch, pointer.Pointer) + // Limit batch requests to 20 objects + if len(batch) >= 20 { + if err := download(batch); err != nil { + return err + } + + batch = nil + } + } + } + + if err, ok := <-errChan; ok { + return err + } + + return nil +} diff --git a/server/backend/repo.go b/server/backend/repo.go index ebd846f4e9328ea71fa10ca1b56d592f5a62b2b4..15eaea6ff72ea919d66ff897633471373a7c6048 100644 --- a/server/backend/repo.go +++ b/server/backend/repo.go @@ -7,6 +7,7 @@ import ( "fmt" "io/fs" "os" + "path" "path/filepath" "time" @@ -14,7 +15,9 @@ import ( "github.com/charmbracelet/soft-serve/server/db" "github.com/charmbracelet/soft-serve/server/db/models" "github.com/charmbracelet/soft-serve/server/hooks" + "github.com/charmbracelet/soft-serve/server/lfs" "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/storage" "github.com/charmbracelet/soft-serve/server/utils" ) @@ -103,7 +106,6 @@ func (d *Backend) ImportRepository(ctx context.Context, name string, remote stri ), }, }, - // Timeout: time.Hour, } if err := git.Clone(remote, rp, copts); err != nil { @@ -115,13 +117,51 @@ func (d *Backend) ImportRepository(ctx context.Context, name string, remote stri return nil, err } - return d.CreateRepository(ctx, name, opts) + r, err := d.CreateRepository(ctx, name, opts) + if err != nil { + d.logger.Error("failed to create repository", "err", err, "name", name) + return nil, err + } + + rr, err := r.Open() + if err != nil { + d.logger.Error("failed to open repository", "err", err, "path", rp) + return nil, err + } + + rcfg, err := rr.Config() + if err != nil { + d.logger.Error("failed to get repository config", "err", err, "path", rp) + return nil, err + } + + rcfg.Section("lfs").SetOption("url", remote) + + if err := rr.SetConfig(rcfg); err != nil { + d.logger.Error("failed to set repository config", "err", err, "path", rp) + return nil, err + } + + endpoint, err := lfs.NewEndpoint(remote) + if err != nil { + d.logger.Error("failed to create lfs endpoint", "err", err, "path", rp) + return nil, err + } + + client := lfs.NewClient(endpoint) + + if err := StoreRepoMissingLFSObjects(ctx, r, d.db, d.store, client); err != nil { + d.logger.Error("failed to store missing lfs objects", "err", err, "path", rp) + return nil, err + } + + return r, nil } // DeleteRepository deletes a repository. // // It implements backend.Backend. -func (d *Backend) DeleteRepository(ctx context.Context, name string) error { +func (d *Backend) DeleteRepository(ctx context.Context, name string, deleteLFS bool) error { name = utils.SanitizeRepo(name) repo := name + ".git" rp := filepath.Join(d.reposPath(), repo) @@ -130,6 +170,26 @@ func (d *Backend) DeleteRepository(ctx context.Context, name string) error { // Delete repo from cache defer d.cache.Delete(name) + if deleteLFS { + strg := storage.NewLocalStorage(filepath.Join(d.cfg.DataPath, "lfs")) + objs, err := d.store.GetLFSObjectsByName(ctx, tx, name) + if err != nil { + return err + } + + for _, obj := range objs { + p := lfs.Pointer{ + Oid: obj.Oid, + Size: obj.Size, + } + + d.logger.Debug("deleting lfs object", "repo", name, "oid", obj.Oid) + if err := strg.Delete(path.Join("objects", p.RelativePath())); err != nil { + d.logger.Error("failed to delete lfs object", "repo", name, "err", err, "oid", obj.Oid) + } + } + } + if err := d.store.DeleteRepoByName(ctx, tx, name); err != nil { return err } @@ -428,6 +488,13 @@ type repo struct { repo models.Repo } +// ID returns the repository's ID. +// +// It implements proto.Repository. +func (r *repo) ID() int64 { + return r.repo.ID +} + // Description returns the repository's description. // // It implements backend.Repository. diff --git a/server/backend/user.go b/server/backend/user.go index edefaf686e55eea1b08e43e9fb9a5809ba4d61b7..8b5e2a2aeae549bec6d1f19dec0a5e0d00685bbd 100644 --- a/server/backend/user.go +++ b/server/backend/user.go @@ -17,8 +17,36 @@ import ( // // It implements backend.Backend. func (d *Backend) AccessLevel(ctx context.Context, repo string, username string) access.AccessLevel { - anon := d.AnonAccess(ctx) user, _ := d.User(ctx, username) + return d.AccessLevelForUser(ctx, repo, user) +} + +// AccessLevelByPublicKey returns the access level of a user's public key for a repository. +// +// It implements backend.Backend. +func (d *Backend) AccessLevelByPublicKey(ctx context.Context, repo string, pk ssh.PublicKey) access.AccessLevel { + for _, k := range d.cfg.AdminKeys() { + if sshutils.KeysEqual(pk, k) { + return access.AdminAccess + } + } + + user, _ := d.UserByPublicKey(ctx, pk) + if user != nil { + return d.AccessLevel(ctx, repo, user.Username()) + } + + return d.AccessLevel(ctx, repo, "") +} + +// AccessLevelForUser returns the access level of a user for a repository. +func (d *Backend) AccessLevelForUser(ctx context.Context, repo string, user proto.User) access.AccessLevel { + var username string + anon := d.AnonAccess(ctx) + if user != nil { + username = user.Username() + } + // If the user is an admin, they have admin access. if user != nil && user.IsAdmin() { return access.AdminAccess @@ -58,24 +86,6 @@ func (d *Backend) AccessLevel(ctx context.Context, repo string, username string) return anon } -// AccessLevelByPublicKey returns the access level of a user's public key for a repository. -// -// It implements backend.Backend. -func (d *Backend) AccessLevelByPublicKey(ctx context.Context, repo string, pk ssh.PublicKey) access.AccessLevel { - for _, k := range d.cfg.AdminKeys() { - if sshutils.KeysEqual(pk, k) { - return access.AdminAccess - } - } - - user, _ := d.UserByPublicKey(ctx, pk) - if user != nil { - return d.AccessLevel(ctx, repo, user.Username()) - } - - return d.AccessLevel(ctx, repo, "") -} - // User finds a user by username. // // It implements backend.Backend. @@ -273,17 +283,22 @@ type user struct { var _ proto.User = (*user)(nil) -// IsAdmin implements store.User +// IsAdmin implements proto.User func (u *user) IsAdmin() bool { return u.user.Admin } -// PublicKeys implements store.User +// PublicKeys implements proto.User func (u *user) PublicKeys() []ssh.PublicKey { return u.publicKeys } -// Username implements store.User +// Username implements proto.User func (u *user) Username() string { return u.user.Username } + +// ID implements proto.User. +func (u *user) ID() int64 { + return u.user.ID +} diff --git a/server/daemon/daemon_test.go b/server/daemon/daemon_test.go index c11ddefb5648ad306b0489aa0d24a5d136b08b26..f1ad44caaa23b0422c1389d4f9833d2dc80a54e7 100644 --- a/server/daemon/daemon_test.go +++ b/server/daemon/daemon_test.go @@ -17,6 +17,8 @@ import ( "github.com/charmbracelet/soft-serve/server/db" "github.com/charmbracelet/soft-serve/server/db/migrate" "github.com/charmbracelet/soft-serve/server/git" + "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/store/database" "github.com/charmbracelet/soft-serve/server/test" "github.com/go-git/go-git/v5/plumbing/format/pktline" _ "modernc.org/sqlite" // sqlite driver @@ -41,15 +43,17 @@ func TestMain(m *testing.M) { log.Fatal(err) } ctx = config.WithContext(ctx, cfg) - db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource) + dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource) if err != nil { log.Fatal(err) } - defer db.Close() // nolint: errcheck - if err := migrate.Migrate(ctx, db); err != nil { + defer dbx.Close() // nolint: errcheck + if err := migrate.Migrate(ctx, dbx); err != nil { log.Fatal(err) } - be := backend.New(ctx, cfg, db) + datastore := database.New(ctx, dbx) + ctx = store.WithContext(ctx, datastore) + be := backend.New(ctx, cfg, dbx) ctx = backend.WithContext(ctx, be) d, err := NewGitDaemon(ctx) if err != nil { @@ -68,7 +72,7 @@ func TestMain(m *testing.M) { os.Unsetenv("SOFT_SERVE_GIT_IDLE_TIMEOUT") os.Unsetenv("SOFT_SERVE_GIT_LISTEN_ADDR") _ = d.Close() - _ = db.Close() + _ = dbx.Close() os.Exit(code) } diff --git a/server/db/context.go b/server/db/context.go index 17c70ee4978d8b3c97e0a05177d9cd30d82c3ae8..5e289d8df96a0941d4d594aaba3494f553465e24 100644 --- a/server/db/context.go +++ b/server/db/context.go @@ -2,11 +2,12 @@ package db import "context" -var contextKey = struct{ string }{"db"} +// ContextKey is the key used to store the database in the context. +var ContextKey = struct{ string }{"db"} // FromContext returns the database from the context. func FromContext(ctx context.Context) *DB { - if db, ok := ctx.Value(contextKey).(*DB); ok { + if db, ok := ctx.Value(ContextKey).(*DB); ok { return db } return nil @@ -14,5 +15,5 @@ func FromContext(ctx context.Context) *DB { // WithContext returns a new context with the database. func WithContext(ctx context.Context, db *DB) context.Context { - return context.WithValue(ctx, contextKey, db) + return context.WithValue(ctx, ContextKey, db) } diff --git a/server/db/handler.go b/server/db/handler.go new file mode 100644 index 0000000000000000000000000000000000000000..981cadf21ef38275d7e27f6a2539e9b80f2e7a49 --- /dev/null +++ b/server/db/handler.go @@ -0,0 +1,25 @@ +package db + +import ( + "context" + "database/sql" + + "github.com/jmoiron/sqlx" +) + +// Handler is a database handler. +type Handler interface { + Rebind(string) string + + Select(interface{}, string, ...interface{}) error + Get(interface{}, string, ...interface{}) error + Queryx(string, ...interface{}) (*sqlx.Rows, error) + QueryRowx(string, ...interface{}) *sqlx.Row + Exec(string, ...interface{}) (sql.Result, error) + + SelectContext(context.Context, interface{}, string, ...interface{}) error + GetContext(context.Context, interface{}, string, ...interface{}) error + QueryxContext(context.Context, string, ...interface{}) (*sqlx.Rows, error) + QueryRowxContext(context.Context, string, ...interface{}) *sqlx.Row + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) +} diff --git a/server/db/migrate/0002_create_lfs_tables.go b/server/db/migrate/0002_create_lfs_tables.go new file mode 100644 index 0000000000000000000000000000000000000000..8d4ace82c1f425cdf77e0a06059485589e1fba29 --- /dev/null +++ b/server/db/migrate/0002_create_lfs_tables.go @@ -0,0 +1,23 @@ +package migrate + +import ( + "context" + + "github.com/charmbracelet/soft-serve/server/db" +) + +const ( + createLFSTablesName = "create lfs tables" + createLFSTablesVersion = 2 +) + +var createLFSTables = Migration{ + Version: createLFSTablesVersion, + Name: createLFSTablesName, + Migrate: func(ctx context.Context, tx *db.Tx) error { + return migrateUp(ctx, tx, createLFSTablesVersion, createLFSTablesName) + }, + Rollback: func(ctx context.Context, tx *db.Tx) error { + return migrateDown(ctx, tx, createLFSTablesVersion, createLFSTablesName) + }, +} diff --git a/server/db/migrate/0002_create_lfs_tables_postgres.down.sql b/server/db/migrate/0002_create_lfs_tables_postgres.down.sql new file mode 100644 index 0000000000000000000000000000000000000000..bae6ea0cd13f5ccf90114fb2e05146bb822a6658 --- /dev/null +++ b/server/db/migrate/0002_create_lfs_tables_postgres.down.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS lfs_locks; +DROP TABLE IF EXISTS lfs_objects; diff --git a/server/db/migrate/0002_create_lfs_tables_postgres.up.sql b/server/db/migrate/0002_create_lfs_tables_postgres.up.sql new file mode 100644 index 0000000000000000000000000000000000000000..fed48900f876e09a5777966c82bf088f41809821 --- /dev/null +++ b/server/db/migrate/0002_create_lfs_tables_postgres.up.sql @@ -0,0 +1,28 @@ +CREATE TABLE IF NOT EXISTS lfs_objects ( + id SERIAL PRIMARY KEY, + oid TEXT NOT NULL, + size INTEGER NOT NULL, + repo_id INTEGER NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + UNIQUE (oid, repo_id), + CONSTRAINT repo_id_fk + FOREIGN KEY(repo_id) REFERENCES repos(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS lfs_locks ( + id SERIAL PRIMARY KEY, + repo_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + path TEXT NOT NULL, + refname TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + UNIQUE (repo_id, path), + CONSTRAINT repo_id_fk + FOREIGN KEY(repo_id) REFERENCES repos(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); diff --git a/server/db/migrate/0002_create_lfs_tables_sqlite.down.sql b/server/db/migrate/0002_create_lfs_tables_sqlite.down.sql new file mode 100644 index 0000000000000000000000000000000000000000..bae6ea0cd13f5ccf90114fb2e05146bb822a6658 --- /dev/null +++ b/server/db/migrate/0002_create_lfs_tables_sqlite.down.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS lfs_locks; +DROP TABLE IF EXISTS lfs_objects; diff --git a/server/db/migrate/0002_create_lfs_tables_sqlite.up.sql b/server/db/migrate/0002_create_lfs_tables_sqlite.up.sql new file mode 100644 index 0000000000000000000000000000000000000000..0a43d6849340b93cb0783b1c560128ef005cb83c --- /dev/null +++ b/server/db/migrate/0002_create_lfs_tables_sqlite.up.sql @@ -0,0 +1,28 @@ +CREATE TABLE IF NOT EXISTS lfs_objects ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + oid TEXT NOT NULL, + size INTEGER NOT NULL, + repo_id INTEGER NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + UNIQUE (oid, repo_id), + CONSTRAINT repo_id_fk + FOREIGN KEY(repo_id) REFERENCES repos(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); + +CREATE TABLE IF NOT EXISTS lfs_locks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repo_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + path TEXT NOT NULL, + refname TEXT, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL, + UNIQUE (repo_id, path), + CONSTRAINT repo_id_fk + FOREIGN KEY(repo_id) REFERENCES repos(id) + ON DELETE CASCADE + ON UPDATE CASCADE +); diff --git a/server/db/migrate/migrations.go b/server/db/migrate/migrations.go index 88a9e434696ecdfe5ccb58d844d1f244d3bbfce5..8935ff05c8bff39197a5514f680a4ad91f85d61b 100644 --- a/server/db/migrate/migrations.go +++ b/server/db/migrate/migrations.go @@ -16,6 +16,7 @@ var sqls embed.FS // Keep this in order of execution, oldest to newest. var migrations = []Migration{ createTables, + createLFSTables, } func execMigration(ctx context.Context, tx *db.Tx, version int, name string, down bool) error { diff --git a/server/db/models/lfs.go b/server/db/models/lfs.go new file mode 100644 index 0000000000000000000000000000000000000000..f93ea55bad91ce1b795184a3e4e86754ecfd89aa --- /dev/null +++ b/server/db/models/lfs.go @@ -0,0 +1,24 @@ +package models + +import "time" + +// LFSObject is a Git LFS object. +type LFSObject struct { + ID int64 `db:"id"` + Oid string `db:"oid"` + Size int64 `db:"size"` + RepoID int64 `db:"repo_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +// LFSLock is a Git LFS lock. +type LFSLock struct { + ID int64 `db:"id"` + Path string `db:"path"` + UserID int64 `db:"user_id"` + RepoID int64 `db:"repo_id"` + Refname string `db:"refname"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} diff --git a/server/git/lfs.go b/server/git/lfs.go new file mode 100644 index 0000000000000000000000000000000000000000..047c59bbe66ee5896be423bf1342e5c407136351 --- /dev/null +++ b/server/git/lfs.go @@ -0,0 +1,451 @@ +package git + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "io" + "io/fs" + "path" + "path/filepath" + "strconv" + "time" + + "github.com/charmbracelet/git-lfs-transfer/transfer" + "github.com/charmbracelet/log" + "github.com/charmbracelet/soft-serve/server/backend" + "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/models" + "github.com/charmbracelet/soft-serve/server/proto" + "github.com/charmbracelet/soft-serve/server/storage" + "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/utils" + "github.com/rubyist/tracerx" +) + +func init() { + // git-lfs-transfer uses tracerx for logging. + // use a custom key to avoid conflicts + // SOFT_SERVE_TRACE=1 to enable tracing git-lfs-transfer in soft-serve + tracerx.DefaultKey = "SOFT_SERVE" + tracerx.Prefix = "trace soft-serve-lfs-transfer: " +} + +// lfsTransfer implements transfer.Backend. +type lfsTransfer struct { + ctx context.Context + cfg *config.Config + dbx *db.DB + store store.Store + logger *log.Logger + storage storage.Storage + repo proto.Repository +} + +var _ transfer.Backend = &lfsTransfer{} + +// LFSTransfer is a Git LFS transfer service handler. +// ctx is expected to have proto.User, *backend.Backend, *log.Logger, +// *config.Config, *db.DB, and store.Store. +// The first arg in cmd.Args should be the repo path. +// The second arg in cmd.Args should be the LFS operation (download or upload). +func LFSTransfer(ctx context.Context, cmd ServiceCommand) error { + if len(cmd.Args) < 2 { + return errors.New("missing args") + } + + logger := log.FromContext(ctx).WithPrefix("lfs-transfer") + handler := transfer.NewPktline(cmd.Stdin, cmd.Stdout) + be := backend.FromContext(ctx) + repoName := cmd.Args[0] + repoName = utils.SanitizeRepo(repoName) + op := cmd.Args[1] + + repo, err := be.Repository(ctx, repoName) + if err != nil { + logger.Errorf("error getting repo: %v", err) + return err + } + + ctx = context.WithValue(ctx, proto.ContextKeyRepository, repo) + + // Advertise capabilities. + for _, cap := range []string{ + "version=1", + "locking", + } { + if err := handler.WritePacketText(cap); err != nil { + logger.Errorf("error sending capability: %s: %v", cap, err) + return err + } + } + + if err := handler.WriteFlush(); err != nil { + logger.Error("error sending flush", "err", err) + return err + } + + cfg := config.FromContext(ctx) + processor := transfer.NewProcessor(handler, &lfsTransfer{ + ctx: ctx, + cfg: cfg, + dbx: db.FromContext(ctx), + store: store.FromContext(ctx), + logger: logger, + storage: storage.NewLocalStorage(filepath.Join(cfg.DataPath, "lfs")), + repo: repo, + }) + + return processor.ProcessCommands(op) +} + +// Batch implements transfer.Backend. +func (t *lfsTransfer) Batch(_ string, pointers []transfer.Pointer) ([]transfer.BatchItem, error) { + repo, ok := t.ctx.Value(proto.ContextKeyRepository).(proto.Repository) + if !ok { + return nil, errors.New("no repository in context") + } + + items := make([]transfer.BatchItem, 0) + for _, p := range pointers { + obj, err := t.store.GetLFSObjectByOid(t.ctx, t.dbx, repo.ID(), p.Oid) + if err != nil && !errors.Is(err, db.ErrRecordNotFound) { + return items, db.WrapError(err) + } + + exist, err := t.storage.Exists(path.Join("objects", p.RelativePath())) + if err != nil { + return items, err + } + + if exist && obj.ID == 0 { + if err := t.store.CreateLFSObject(t.ctx, t.dbx, repo.ID(), p.Oid, p.Size); err != nil { + return items, db.WrapError(err) + } + } + + item := transfer.BatchItem{ + Pointer: p, + Present: exist, + } + items = append(items, item) + } + + return items, nil +} + +// Download implements transfer.Backend. +func (t *lfsTransfer) Download(oid string, _ ...string) (fs.File, error) { + cfg := config.FromContext(t.ctx) + strg := storage.NewLocalStorage(filepath.Join(cfg.DataPath, "lfs")) + pointer := transfer.Pointer{Oid: oid} + return strg.Open(path.Join("objects", pointer.RelativePath())) +} + +type uploadObject struct { + oid string + object storage.Object +} + +// StartUpload implements transfer.Backend. +func (t *lfsTransfer) StartUpload(oid string, r io.Reader, _ ...string) (interface{}, error) { + if r == nil { + return nil, fmt.Errorf("no reader: %w", transfer.ErrMissingData) + } + + tempDir := "incomplete" + randBytes := make([]byte, 12) + if _, err := rand.Read(randBytes); err != nil { + return nil, err + } + + tempName := fmt.Sprintf("%s%x", oid, randBytes) + tempName = path.Join(tempDir, tempName) + + if err := t.storage.Put(tempName, r); err != nil { + t.logger.Errorf("error putting object: %v", err) + return nil, err + } + + obj, err := t.storage.Open(tempName) + if err != nil { + t.logger.Errorf("error opening object: %v", err) + return nil, err + } + + t.logger.Infof("Object name: %s", obj.Name()) + + return uploadObject{ + oid: oid, + object: obj, + }, nil +} + +// FinishUpload implements transfer.Backend. +func (t *lfsTransfer) FinishUpload(state interface{}, _ ...string) error { + upl, ok := state.(uploadObject) + if !ok { + return errors.New("invalid state") + } + + pointer := transfer.Pointer{ + Oid: upl.oid, + } + + expectedPath := path.Join("objects", pointer.RelativePath()) + if err := t.storage.Rename(upl.object.Name(), expectedPath); err != nil { + t.logger.Errorf("error renaming object: %v", err) + return err + } + + return nil +} + +// Verify implements transfer.Backend. +func (t *lfsTransfer) Verify(oid string, args map[string]string) (transfer.Status, error) { + var expectedSize int64 + var err error + size, ok := args[transfer.SizeKey] + if !ok { + return transfer.NewFailureStatus(transfer.StatusBadRequest, "missing size"), nil + } + + expectedSize, err = strconv.ParseInt(size, 10, 64) + if err != nil { + t.logger.Errorf("invalid size argument: %v", err) + return transfer.NewFailureStatus(transfer.StatusBadRequest, "invalid size argument"), nil + } + + pointer := transfer.Pointer{ + Oid: oid, + Size: expectedSize, + } + expectedPath := path.Join("objects", pointer.RelativePath()) + stat, err := t.storage.Stat(expectedPath) + if err != nil { + t.logger.Errorf("error stating object: %v", err) + return nil, err + } + + if stat.Size() != expectedSize { + t.logger.Errorf("size mismatch: %d != %d", stat.Size(), expectedSize) + return transfer.NewFailureStatus(transfer.StatusConflict, "size mismatch"), nil + } + + return transfer.SuccessStatus(), nil +} + +type lfsLockBackend struct { + *lfsTransfer + user proto.User +} + +var _ transfer.LockBackend = (*lfsLockBackend)(nil) + +// LockBackend implements transfer.Backend. +func (t *lfsTransfer) LockBackend() transfer.LockBackend { + user, ok := t.ctx.Value(proto.ContextKeyUser).(proto.User) + if !ok { + t.logger.Errorf("no user in context while creating lock backend, repo %s", t.repo.Name()) + return nil + } + + return &lfsLockBackend{t, user} +} + +// Create implements transfer.LockBackend. +func (l *lfsLockBackend) Create(path string, refname string) (transfer.Lock, error) { + var lock LFSLock + if err := l.dbx.TransactionContext(l.ctx, func(tx *db.Tx) error { + if err := l.store.CreateLFSLockForUser(l.ctx, tx, l.repo.ID(), l.user.ID(), path, refname); err != nil { + return db.WrapError(err) + } + + var err error + lock.lock, err = l.store.GetLFSLockForUserPath(l.ctx, tx, l.repo.ID(), l.user.ID(), path) + if err != nil { + return db.WrapError(err) + } + + lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID) + return db.WrapError(err) + }); err != nil { + // Return conflict (409) if the lock already exists. + if errors.Is(err, db.ErrDuplicateKey) { + return nil, transfer.ErrConflict + } + l.logger.Errorf("error creating lock: %v", err) + return nil, err + } + + lock.backend = l + + return &lock, nil +} + +// FromID implements transfer.LockBackend. +func (l *lfsLockBackend) FromID(id string) (transfer.Lock, error) { + var lock LFSLock + user, ok := l.ctx.Value(proto.ContextKeyUser).(proto.User) + if !ok || user == nil { + return nil, errors.New("no user in context") + } + + if err := l.dbx.TransactionContext(l.ctx, func(tx *db.Tx) error { + var err error + lock.lock, err = l.store.GetLFSLockForUserByID(l.ctx, tx, user.ID(), id) + if err != nil { + return db.WrapError(err) + } + + lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID) + return db.WrapError(err) + }); err != nil { + l.logger.Errorf("error getting lock: %v", err) + return nil, err + } + + lock.backend = l + + return &lock, nil +} + +// FromPath implements transfer.LockBackend. +func (l *lfsLockBackend) FromPath(path string) (transfer.Lock, error) { + var lock LFSLock + + if err := l.dbx.TransactionContext(l.ctx, func(tx *db.Tx) error { + var err error + lock.lock, err = l.store.GetLFSLockForUserPath(l.ctx, tx, l.repo.ID(), l.user.ID(), path) + if err != nil { + return db.WrapError(err) + } + + lock.owner, err = l.store.GetUserByID(l.ctx, tx, lock.lock.UserID) + return db.WrapError(err) + }); err != nil { + l.logger.Errorf("error getting lock: %v", err) + return nil, err + } + + lock.backend = l + + return &lock, nil +} + +// Range implements transfer.LockBackend. +func (l *lfsLockBackend) Range(fn func(transfer.Lock) error) error { + var locks []*LFSLock + + if err := l.dbx.TransactionContext(l.ctx, func(tx *db.Tx) error { + mlocks, err := l.store.GetLFSLocks(l.ctx, tx, l.repo.ID()) + if err != nil { + return db.WrapError(err) + } + + users := make(map[int64]models.User, 0) + for _, mlock := range mlocks { + owner, ok := users[mlock.UserID] + if !ok { + owner, err = l.store.GetUserByID(l.ctx, tx, mlock.UserID) + if err != nil { + return db.WrapError(err) + } + + users[mlock.UserID] = owner + } + + locks = append(locks, &LFSLock{lock: mlock, owner: owner, backend: l}) + } + + return nil + }); err != nil { + return err + } + + for _, lock := range locks { + if err := fn(lock); err != nil { + return err + } + } + + return nil +} + +// Unlock implements transfer.LockBackend. +func (l *lfsLockBackend) Unlock(lock transfer.Lock) error { + return l.dbx.TransactionContext(l.ctx, func(tx *db.Tx) error { + return db.WrapError( + l.store.DeleteLFSLockForUserByID(l.ctx, tx, l.user.ID(), lock.ID()), + ) + }) +} + +// LFSLock is a Git LFS lock object. +// It implements transfer.Lock. +type LFSLock struct { + lock models.LFSLock + owner models.User + backend *lfsLockBackend +} + +var _ transfer.Lock = (*LFSLock)(nil) + +// AsArguments implements transfer.Lock. +func (l *LFSLock) AsArguments() []string { + return []string{ + fmt.Sprintf("id=%s", l.ID()), + fmt.Sprintf("path=%s", l.Path()), + fmt.Sprintf("locked-at=%s", l.FormattedTimestamp()), + fmt.Sprintf("ownername=%s", l.OwnerName()), + } +} + +// AsLockSpec implements transfer.Lock. +func (l *LFSLock) AsLockSpec(ownerID bool) ([]string, error) { + id := l.ID() + spec := []string{ + fmt.Sprintf("lock %s", id), + fmt.Sprintf("path %s %s", id, l.Path()), + fmt.Sprintf("locked-at %s %s", id, l.FormattedTimestamp()), + fmt.Sprintf("ownername %s %s", id, l.OwnerName()), + } + + if ownerID { + who := "theirs" + if l.lock.UserID == l.owner.ID { + who = "ours" + } + + spec = append(spec, fmt.Sprintf("owner %s %s", id, who)) + } + + return spec, nil +} + +// FormattedTimestamp implements transfer.Lock. +func (l *LFSLock) FormattedTimestamp() string { + return l.lock.CreatedAt.Format(time.RFC3339) +} + +// ID implements transfer.Lock. +func (l *LFSLock) ID() string { + return strconv.FormatInt(l.lock.ID, 10) +} + +// OwnerName implements transfer.Lock. +func (l *LFSLock) OwnerName() string { + return l.owner.Username +} + +// Path implements transfer.Lock. +func (l *LFSLock) Path() string { + return l.lock.Path +} + +// Unlock implements transfer.Lock. +func (l *LFSLock) Unlock() error { + return l.backend.Unlock(l) +} diff --git a/server/git/service.go b/server/git/service.go index b2cd6aa26b294a95d38d072e45054358be56bacd..51a53a0c887b56c4596ecaaaa90fed1617d1c161 100644 --- a/server/git/service.go +++ b/server/git/service.go @@ -23,6 +23,9 @@ const ( UploadArchiveService Service = "git-upload-archive" // ReceivePackService is the receive-pack service. ReceivePackService Service = "git-receive-pack" + // LFSTransferService is the LFS transfer service. + LFSTransferService Service = "git-lfs-transfer" + // TODO: add support for git-lfs-authenticate ) // String returns the string representation of the service. @@ -40,6 +43,8 @@ func (s Service) Handler(ctx context.Context, cmd ServiceCommand) error { switch s { case UploadPackService, UploadArchiveService, ReceivePackService: return gitServiceHandler(ctx, s, cmd) + case LFSTransferService: + return LFSTransfer(ctx, cmd) default: return fmt.Errorf("unsupported service: %s", s) } @@ -57,6 +62,8 @@ func gitServiceHandler(ctx context.Context, svc Service, scmd ServiceCommand) er "-c", "uploadpack.allowFilter=true", // Enable push options "-c", "receive.advertisePushOptions=true", + // Disable LFS filters + "-c", "filter.lfs.required=", "-c", "filter.lfs.smudge=", "-c", "filter.lfs.clean=", svc.Name(), }...) if len(scmd.Args) > 0 { diff --git a/server/lfs/basic_transfer.go b/server/lfs/basic_transfer.go new file mode 100644 index 0000000000000000000000000000000000000000..609197c1035c176f6d31495d3a8cf57de7f5830c --- /dev/null +++ b/server/lfs/basic_transfer.go @@ -0,0 +1,124 @@ +package lfs + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + + "github.com/charmbracelet/log" +) + +// BasicTransferAdapter implements the "basic" adapter +type BasicTransferAdapter struct { + client *http.Client +} + +// Name returns the name of the adapter +func (a *BasicTransferAdapter) Name() string { + return "basic" +} + +// Download reads the download location and downloads the data +func (a *BasicTransferAdapter) Download(ctx context.Context, _ Pointer, l *Link) (io.ReadCloser, error) { + resp, err := a.performRequest(ctx, "GET", l, nil, nil) + if err != nil { + return nil, err + } + return resp.Body, nil +} + +// Upload sends the content to the LFS server +func (a *BasicTransferAdapter) Upload(ctx context.Context, p Pointer, r io.Reader, l *Link) error { + res, err := a.performRequest(ctx, "PUT", l, r, func(req *http.Request) { + if len(req.Header.Get("Content-Type")) == 0 { + req.Header.Set("Content-Type", "application/octet-stream") + } + + if req.Header.Get("Transfer-Encoding") == "chunked" { + req.TransferEncoding = []string{"chunked"} + } + + req.ContentLength = p.Size + }) + if err != nil { + return err + } + return res.Body.Close() +} + +// Verify calls the verify handler on the LFS server +func (a *BasicTransferAdapter) Verify(ctx context.Context, p Pointer, l *Link) error { + logger := log.FromContext(ctx).WithPrefix("lfs") + b, err := json.Marshal(p) + if err != nil { + logger.Errorf("Error encoding json: %v", err) + return err + } + + res, err := a.performRequest(ctx, "POST", l, bytes.NewReader(b), func(req *http.Request) { + req.Header.Set("Content-Type", MediaType) + }) + if err != nil { + return err + } + return res.Body.Close() +} + +func (a *BasicTransferAdapter) performRequest(ctx context.Context, method string, l *Link, body io.Reader, callback func(*http.Request)) (*http.Response, error) { + logger := log.FromContext(ctx).WithPrefix("lfs") + logger.Debugf("Calling: %s %s", method, l.Href) + + req, err := http.NewRequestWithContext(ctx, method, l.Href, body) + if err != nil { + logger.Errorf("Error creating request: %v", err) + return nil, err + } + for key, value := range l.Header { + req.Header.Set(key, value) + } + req.Header.Set("Accept", MediaType) + + if callback != nil { + callback(req) + } + + res, err := a.client.Do(req) + if err != nil { + select { + case <-ctx.Done(): + return res, ctx.Err() + default: + } + logger.Errorf("Error while processing request: %v", err) + return res, err + } + + if res.StatusCode != http.StatusOK { + return res, handleErrorResponse(res) + } + + return res, nil +} + +func handleErrorResponse(resp *http.Response) error { + defer resp.Body.Close() // nolint: errcheck + + er, err := decodeResponseError(resp.Body) + if err != nil { + return fmt.Errorf("Request failed with status %s", resp.Status) + } + return errors.New(er.Message) +} + +func decodeResponseError(r io.Reader) (ErrorResponse, error) { + var er ErrorResponse + err := json.NewDecoder(r).Decode(&er) + if err != nil { + log.Error("Error decoding json: %v", err) + } + return er, err +} diff --git a/server/lfs/client.go b/server/lfs/client.go new file mode 100644 index 0000000000000000000000000000000000000000..9cc9da0a15a115f82373fa8da50e8da2a504b1dc --- /dev/null +++ b/server/lfs/client.go @@ -0,0 +1,27 @@ +package lfs + +import ( + "context" + "io" +) + +// DownloadCallback gets called for every requested LFS object to process its content +type DownloadCallback func(p Pointer, content io.ReadCloser, objectError error) error + +// UploadCallback gets called for every requested LFS object to provide its content +type UploadCallback func(p Pointer, objectError error) (io.ReadCloser, error) + +// Client is a Git LFS client to communicate with a LFS source API. +type Client interface { + Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error + Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error +} + +// NewClient returns a new Git LFS client. +func NewClient(e Endpoint) Client { + if e.Scheme == "http" || e.Scheme == "https" { + return newHTTPClient(e) + } + // TODO: support ssh client + return nil +} diff --git a/server/lfs/common.go b/server/lfs/common.go new file mode 100644 index 0000000000000000000000000000000000000000..1bd2473068ab09f69b829ce5b292c50d9dd08097 --- /dev/null +++ b/server/lfs/common.go @@ -0,0 +1,88 @@ +package lfs + +import "time" + +const ( + // MediaType contains the media type for LFS server requests. + MediaType = "application/vnd.git-lfs+json" + + // OperationDownload is the operation name for a download request. + OperationDownload = "download" + + // OperationUpload is the operation name for an upload request. + OperationUpload = "upload" + + // ActionDownload is the action name for a download request. + ActionDownload = OperationDownload + + // ActionUpload is the action name for an upload request. + ActionUpload = OperationUpload + + // ActionVerify is the action name for a verify request. + ActionVerify = "verify" +) + +// Pointer contains LFS pointer data +type Pointer struct { + Oid string `json:"oid"` + Size int64 `json:"size"` +} + +// PointerBlob associates a Git blob with a Pointer. +type PointerBlob struct { + Hash string + Pointer +} + +// ErrorResponse describes the error to the client. +type ErrorResponse struct { + Message string `json:"message,omitempty"` + DocumentationURL string `json:"documentation_url,omitempty"` + RequestID string `json:"request_id,omitempty"` +} + +// BatchResponse contains multiple object metadata Representation structures +// for use with the batch API. +// https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md#successful-responses +type BatchResponse struct { + Transfer string `json:"transfer,omitempty"` + Objects []*ObjectResponse `json:"objects"` + HashAlgo string `json:"hash_algo,omitempty"` +} + +// ObjectResponse is object metadata as seen by clients of the LFS server. +type ObjectResponse struct { + Pointer + Actions map[string]*Link `json:"actions,omitempty"` + Error *ObjectError `json:"error,omitempty"` +} + +// Link provides a structure with information about how to access a object. +type Link struct { + Href string `json:"href"` + Header map[string]string `json:"header,omitempty"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` + ExpiresIn *time.Duration `json:"expires_in,omitempty"` +} + +// ObjectError defines the JSON structure returned to the client in case of an error. +type ObjectError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// BatchRequest contains multiple requests processed in one batch operation. +// https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md#requests +type BatchRequest struct { + Operation string `json:"operation"` + Transfers []string `json:"transfers,omitempty"` + Ref *Reference `json:"ref,omitempty"` + Objects []Pointer `json:"objects"` + HashAlgo string `json:"hash_algo,omitempty"` +} + +// Reference contains a git reference. +// https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md#ref-property +type Reference struct { + Name string `json:"name"` +} diff --git a/server/lfs/endpoint.go b/server/lfs/endpoint.go new file mode 100644 index 0000000000000000000000000000000000000000..53e89e1703649fa137c017c72ddb1273c9b66026 --- /dev/null +++ b/server/lfs/endpoint.go @@ -0,0 +1,70 @@ +package lfs + +import ( + "fmt" + "net/url" + "strings" +) + +// Endpoint is a Git LFS endpoint. +type Endpoint = *url.URL + +// NewEndpoint returns a new Git LFS endpoint. +func NewEndpoint(rawurl string) (Endpoint, error) { + u, err := url.Parse(rawurl) + if err != nil { + e, err := endpointFromBareSSH(rawurl) + if err != nil { + return nil, err + } + u = e + } + + u.Path = strings.TrimSuffix(u.Path, "/") + + switch u.Scheme { + case "git": + // Use https for git:// URLs and strip the port if it exists. + u.Scheme = "https" + if u.Port() != "" { + u.Host = u.Hostname() + } + fallthrough + case "http", "https": + if strings.HasSuffix(u.Path, ".git") { + u.Path += "/info/lfs" + } else { + u.Path += ".git/info/lfs" + } + case "ssh", "git+ssh", "ssh+git": + default: + return nil, fmt.Errorf("unknown url: %s", rawurl) + } + + return u, nil +} + +// endpointFromBareSSH creates a new endpoint from a bare ssh repo. +// +// user@host.com:path/to/repo.git or +// [user@host.com:port]:path/to/repo.git +func endpointFromBareSSH(rawurl string) (*url.URL, error) { + parts := strings.Split(rawurl, ":") + partsLen := len(parts) + if partsLen < 2 { + return url.Parse(rawurl) + } + + // Treat presence of ':' as a bare URL + var newPath string + if len(parts) > 2 { // port included; really should only ever be 3 parts + // Correctly handle [host:port]:path URLs + parts[0] = strings.TrimPrefix(parts[0], "[") + parts[1] = strings.TrimSuffix(parts[1], "]") + newPath = fmt.Sprintf("%v:%v", parts[0], strings.Join(parts[1:], "/")) + } else { + newPath = strings.Join(parts, "/") + } + newrawurl := fmt.Sprintf("ssh://%v", newPath) + return url.Parse(newrawurl) +} diff --git a/server/lfs/http_client.go b/server/lfs/http_client.go new file mode 100644 index 0000000000000000000000000000000000000000..a8b55031f083d1fd5e4b9aaccd896fb50013f41c --- /dev/null +++ b/server/lfs/http_client.go @@ -0,0 +1,196 @@ +package lfs + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/charmbracelet/log" +) + +// httpClient is a Git LFS client to communicate with a LFS source API. +type httpClient struct { + client *http.Client + endpoint Endpoint + transfers map[string]TransferAdapter +} + +var _ Client = (*httpClient)(nil) + +// newHTTPClient returns a new Git LFS client. +func newHTTPClient(endpoint Endpoint) *httpClient { + return &httpClient{ + client: http.DefaultClient, + endpoint: endpoint, + transfers: map[string]TransferAdapter{ + TransferBasic: &BasicTransferAdapter{http.DefaultClient}, + }, + } +} + +// Download implements Client. +func (c *httpClient) Download(ctx context.Context, objects []Pointer, callback DownloadCallback) error { + return c.performOperation(ctx, objects, callback, nil) +} + +// Upload implements Client. +func (c *httpClient) Upload(ctx context.Context, objects []Pointer, callback UploadCallback) error { + return c.performOperation(ctx, objects, nil, callback) +} + +func (c *httpClient) transferNames() []string { + names := make([]string, len(c.transfers)) + i := 0 + for name := range c.transfers { + names[i] = name + i++ + } + return names +} + +// batch performs a batch request to the LFS server. +func (c *httpClient) batch(ctx context.Context, operation string, objects []Pointer) (*BatchResponse, error) { + logger := log.FromContext(ctx).WithPrefix("lfs") + url := fmt.Sprintf("%s/objects/batch", c.endpoint.String()) + + // TODO: support ref + request := &BatchRequest{operation, c.transferNames(), nil, objects, HashAlgorithmSHA256} + + payload := new(bytes.Buffer) + err := json.NewEncoder(payload).Encode(request) + if err != nil { + logger.Errorf("Error encoding json: %v", err) + return nil, err + } + + logger.Debugf("Calling: %s", url) + + req, err := http.NewRequestWithContext(ctx, "POST", url, payload) + if err != nil { + logger.Errorf("Error creating request: %v", err) + return nil, err + } + req.Header.Set("Content-type", MediaType) + req.Header.Set("Accept", MediaType) + + res, err := c.client.Do(req) + if err != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + logger.Errorf("Error while processing request: %v", err) + return nil, err + } + defer res.Body.Close() // nolint: errcheck + + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Unexpected server response: %s", res.Status) + } + + var response BatchResponse + err = json.NewDecoder(res.Body).Decode(&response) + if err != nil { + logger.Errorf("Error decoding json: %v", err) + return nil, err + } + + if len(response.Transfer) == 0 { + response.Transfer = TransferBasic + } + + return &response, nil +} + +func (c *httpClient) performOperation(ctx context.Context, objects []Pointer, dc DownloadCallback, uc UploadCallback) error { + logger := log.FromContext(ctx).WithPrefix("lfs") + if len(objects) == 0 { + return nil + } + + operation := OperationDownload + if uc != nil { + operation = OperationUpload + } + + result, err := c.batch(ctx, operation, objects) + if err != nil { + return err + } + + transferAdapter, ok := c.transfers[result.Transfer] + if !ok { + return fmt.Errorf("TransferAdapter not found: %s", result.Transfer) + } + + for _, object := range result.Objects { + if object.Error != nil { + objectError := errors.New(object.Error.Message) + logger.Debugf("Error on object %v: %v", object.Pointer, objectError) + if uc != nil { + if _, err := uc(object.Pointer, objectError); err != nil { + return err + } + } else { + if err := dc(object.Pointer, nil, objectError); err != nil { + return err + } + } + continue + } + + if uc != nil { + if len(object.Actions) == 0 { + logger.Debugf("%v already present on server", object.Pointer) + continue + } + + link, ok := object.Actions[ActionUpload] + if !ok { + logger.Debugf("%+v", object) + return errors.New("Missing action 'upload'") + } + + content, err := uc(object.Pointer, nil) + if err != nil { + return err + } + + err = transferAdapter.Upload(ctx, object.Pointer, content, link) + + content.Close() // nolint: errcheck + + if err != nil { + return err + } + + link, ok = object.Actions[ActionVerify] + if ok { + if err := transferAdapter.Verify(ctx, object.Pointer, link); err != nil { + return err + } + } + } else { + link, ok := object.Actions[ActionDownload] + if !ok { + logger.Debugf("%+v", object) + return errors.New("Missing action 'download'") + } + + content, err := transferAdapter.Download(ctx, object.Pointer, link) + if err != nil { + return err + } + + if err := dc(object.Pointer, content, nil); err != nil { + return err + } + } + } + + return nil +} diff --git a/server/lfs/pointer.go b/server/lfs/pointer.go new file mode 100644 index 0000000000000000000000000000000000000000..b38d04ce59b67ac7e5021a0d8bd33a5cc04b077e --- /dev/null +++ b/server/lfs/pointer.go @@ -0,0 +1,122 @@ +package lfs + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "path" + "regexp" + "strconv" + "strings" +) + +const ( + blobSizeCutoff = 1024 + + // HashAlgorithmSHA256 is the hash algorithm used for Git LFS. + HashAlgorithmSHA256 = "sha256" + + // MetaFileIdentifier is the string appearing at the first line of LFS pointer files. + // https://github.com/git-lfs/git-lfs/blob/master/docs/spec.md + MetaFileIdentifier = "version https://git-lfs.github.com/spec/v1" + + // MetaFileOidPrefix appears in LFS pointer files on a line before the sha256 hash. + MetaFileOidPrefix = "oid " + HashAlgorithmSHA256 + ":" +) + +var ( + // ErrMissingPrefix occurs if the content lacks the LFS prefix + ErrMissingPrefix = errors.New("Content lacks the LFS prefix") + + // ErrInvalidStructure occurs if the content has an invalid structure + ErrInvalidStructure = errors.New("Content has an invalid structure") + + // ErrInvalidOIDFormat occurs if the oid has an invalid format + ErrInvalidOIDFormat = errors.New("OID has an invalid format") +) + +// ReadPointer tries to read LFS pointer data from the reader +func ReadPointer(reader io.Reader) (Pointer, error) { + buf := make([]byte, blobSizeCutoff) + n, err := io.ReadFull(reader, buf) + if err != nil && err != io.ErrUnexpectedEOF { + return Pointer{}, err + } + buf = buf[:n] + + return ReadPointerFromBuffer(buf) +} + +var oidPattern = regexp.MustCompile(`^[a-f\d]{64}$`) + +// ReadPointerFromBuffer will return a pointer if the provided byte slice is a pointer file or an error otherwise. +func ReadPointerFromBuffer(buf []byte) (Pointer, error) { + var p Pointer + + headString := string(buf) + if !strings.HasPrefix(headString, MetaFileIdentifier) { + return p, ErrMissingPrefix + } + + splitLines := strings.Split(headString, "\n") + if len(splitLines) < 3 { + return p, ErrInvalidStructure + } + + oid := strings.TrimPrefix(splitLines[1], MetaFileOidPrefix) + if len(oid) != 64 || !oidPattern.MatchString(oid) { + return p, ErrInvalidOIDFormat + } + size, err := strconv.ParseInt(strings.TrimPrefix(splitLines[2], "size "), 10, 64) + if err != nil { + return p, err + } + + p.Oid = oid + p.Size = size + + return p, nil +} + +// IsValid checks if the pointer has a valid structure. +// It doesn't check if the pointed-to-content exists. +func (p Pointer) IsValid() bool { + if len(p.Oid) != 64 { + return false + } + if !oidPattern.MatchString(p.Oid) { + return false + } + if p.Size < 0 { + return false + } + return true +} + +// String returns the string representation of the pointer +// https://github.com/git-lfs/git-lfs/blob/main/docs/spec.md#the-pointer +func (p Pointer) String() string { + return fmt.Sprintf("%s\n%s%s\nsize %d\n", MetaFileIdentifier, MetaFileOidPrefix, p.Oid, p.Size) +} + +// RelativePath returns the relative storage path of the pointer +func (p Pointer) RelativePath() string { + if len(p.Oid) < 5 { + return p.Oid + } + + return path.Join(p.Oid[0:2], p.Oid[2:4], p.Oid[4:]) +} + +// GeneratePointer generates a pointer for arbitrary content +func GeneratePointer(content io.Reader) (Pointer, error) { + h := sha256.New() + c, err := io.Copy(h, content) + if err != nil { + return Pointer{}, err + } + sum := h.Sum(nil) + return Pointer{Oid: hex.EncodeToString(sum), Size: c}, nil +} diff --git a/server/lfs/scanner.go b/server/lfs/scanner.go new file mode 100644 index 0000000000000000000000000000000000000000..1eba4a4d31dfc224728c977efa3f7c4533b95f56 --- /dev/null +++ b/server/lfs/scanner.go @@ -0,0 +1,210 @@ +package lfs + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "strconv" + "strings" + "sync" + + "github.com/charmbracelet/soft-serve/git" + gitm "github.com/gogs/git-module" +) + +// SearchPointerBlobs scans the whole repository for LFS pointer files +func SearchPointerBlobs(ctx context.Context, repo *git.Repository, pointerChan chan<- PointerBlob, errChan chan<- error) { + basePath := repo.Path + + catFileCheckReader, catFileCheckWriter := io.Pipe() + shasToBatchReader, shasToBatchWriter := io.Pipe() + catFileBatchReader, catFileBatchWriter := io.Pipe() + + wg := sync.WaitGroup{} + wg.Add(6) + + // Create the go-routines in reverse order. + + // 4. Take the output of cat-file --batch and check if each file in turn + // to see if they're pointers to files in the LFS store + go createPointerResultsFromCatFileBatch(ctx, catFileBatchReader, &wg, pointerChan) + + // 3. Take the shas of the blobs and batch read them + go catFileBatch(ctx, shasToBatchReader, catFileBatchWriter, &wg, basePath) + + // 2. From the provided objects restrict to blobs <=1k + go blobsLessThan1024FromCatFileBatchCheck(catFileCheckReader, shasToBatchWriter, &wg) + + // 1. Run batch-check on all objects in the repository + revListReader, revListWriter := io.Pipe() + shasToCheckReader, shasToCheckWriter := io.Pipe() + go catFileBatchCheck(ctx, shasToCheckReader, catFileCheckWriter, &wg, basePath) + go blobsFromRevListObjects(revListReader, shasToCheckWriter, &wg) + go revListAllObjects(ctx, revListWriter, &wg, basePath, errChan) + wg.Wait() + + close(pointerChan) + close(errChan) +} + +func createPointerResultsFromCatFileBatch(ctx context.Context, catFileBatchReader *io.PipeReader, wg *sync.WaitGroup, pointerChan chan<- PointerBlob) { + defer wg.Done() + defer catFileBatchReader.Close() // nolint: errcheck + + bufferedReader := bufio.NewReader(catFileBatchReader) + buf := make([]byte, 1025) + +loop: + for { + select { + case <-ctx.Done(): + break loop + default: + } + + // File descriptor line: sha + sha, err := bufferedReader.ReadString(' ') + if err != nil { + _ = catFileBatchReader.CloseWithError(err) + break + } + sha = strings.TrimSpace(sha) + // Throw away the blob + if _, err := bufferedReader.ReadString(' '); err != nil { + _ = catFileBatchReader.CloseWithError(err) + break + } + sizeStr, err := bufferedReader.ReadString('\n') + if err != nil { + _ = catFileBatchReader.CloseWithError(err) + break + } + size, err := strconv.Atoi(sizeStr[:len(sizeStr)-1]) + if err != nil { + _ = catFileBatchReader.CloseWithError(err) + break + } + pointerBuf := buf[:size+1] + if _, err := io.ReadFull(bufferedReader, pointerBuf); err != nil { + _ = catFileBatchReader.CloseWithError(err) + break + } + pointerBuf = pointerBuf[:size] + // Now we need to check if the pointerBuf is an LFS pointer + pointer, _ := ReadPointerFromBuffer(pointerBuf) + if !pointer.IsValid() { + continue + } + + pointerChan <- PointerBlob{Hash: sha, Pointer: pointer} + } +} + +func catFileBatch(ctx context.Context, shasToBatchReader *io.PipeReader, catFileBatchWriter *io.PipeWriter, wg *sync.WaitGroup, basePath string) { + defer wg.Done() + defer shasToBatchReader.Close() // nolint: errcheck + defer catFileBatchWriter.Close() // nolint: errcheck + + stderr := new(bytes.Buffer) + var errbuf strings.Builder + if err := gitm.NewCommandWithContext(ctx, "cat-file", "--batch").RunInDirWithOptions(basePath, gitm.RunInDirOptions{ + Stdout: catFileBatchWriter, + Stdin: shasToBatchReader, + Stderr: stderr, + }); err != nil { + _ = shasToBatchReader.CloseWithError(fmt.Errorf("git rev-list [%s]: %w - %s", basePath, err, errbuf.String())) + } +} + +func blobsLessThan1024FromCatFileBatchCheck(catFileCheckReader *io.PipeReader, shasToBatchWriter *io.PipeWriter, wg *sync.WaitGroup) { + defer wg.Done() + defer catFileCheckReader.Close() // nolint: errcheck + scanner := bufio.NewScanner(catFileCheckReader) + defer func() { + _ = shasToBatchWriter.CloseWithError(scanner.Err()) + }() + for scanner.Scan() { + line := scanner.Text() + if len(line) == 0 { + continue + } + fields := strings.Split(line, " ") + if len(fields) < 3 || fields[1] != "blob" { + continue + } + size, _ := strconv.Atoi(fields[2]) + if size > 1024 { + continue + } + toWrite := []byte(fields[0] + "\n") + for len(toWrite) > 0 { + n, err := shasToBatchWriter.Write(toWrite) + if err != nil { + _ = catFileCheckReader.CloseWithError(err) + break + } + toWrite = toWrite[n:] + } + } +} + +func catFileBatchCheck(ctx context.Context, shasToCheckReader *io.PipeReader, catFileCheckWriter *io.PipeWriter, wg *sync.WaitGroup, basePath string) { + defer wg.Done() + defer shasToCheckReader.Close() // nolint: errcheck + defer catFileCheckWriter.Close() // nolint: errcheck + + stderr := new(bytes.Buffer) + var errbuf strings.Builder + if err := gitm.NewCommandWithContext(ctx, "cat-file", "--batch-check").RunInDirWithOptions(basePath, gitm.RunInDirOptions{ + Stdout: catFileCheckWriter, + Stdin: shasToCheckReader, + Stderr: stderr, + }); err != nil { + _ = shasToCheckReader.CloseWithError(fmt.Errorf("git rev-list [%s]: %w - %s", basePath, err, errbuf.String())) + } +} + +func blobsFromRevListObjects(revListReader *io.PipeReader, shasToCheckWriter *io.PipeWriter, wg *sync.WaitGroup) { + defer wg.Done() + defer revListReader.Close() // nolint: errcheck + scanner := bufio.NewScanner(revListReader) + defer func() { + _ = shasToCheckWriter.CloseWithError(scanner.Err()) + }() + + for scanner.Scan() { + line := scanner.Text() + if len(line) == 0 { + continue + } + fields := strings.Split(line, " ") + if len(fields) < 2 || len(fields[1]) == 0 { + continue + } + toWrite := []byte(fields[0] + "\n") + for len(toWrite) > 0 { + n, err := shasToCheckWriter.Write(toWrite) + if err != nil { + _ = revListReader.CloseWithError(err) + break + } + toWrite = toWrite[n:] + } + } +} + +func revListAllObjects(ctx context.Context, revListWriter *io.PipeWriter, wg *sync.WaitGroup, basePath string, errChan chan<- error) { + defer wg.Done() + defer revListWriter.Close() // nolint: errcheck + + stderr := new(bytes.Buffer) + var errbuf strings.Builder + if err := gitm.NewCommandWithContext(ctx, "rev-list", "--objects", "--all").RunInDirWithOptions(basePath, gitm.RunInDirOptions{ + Stdout: revListWriter, + Stderr: stderr, + }); err != nil { + errChan <- fmt.Errorf("git rev-list [%s]: %w - %s", basePath, err, errbuf.String()) + } +} diff --git a/server/lfs/ssh_client.go b/server/lfs/ssh_client.go new file mode 100644 index 0000000000000000000000000000000000000000..ba3e2471078f11c9a38fecb1d4511b5f754e3e3e --- /dev/null +++ b/server/lfs/ssh_client.go @@ -0,0 +1,3 @@ +package lfs + +// TODO: implement Git LFS SSH client. diff --git a/server/lfs/transfer.go b/server/lfs/transfer.go new file mode 100644 index 0000000000000000000000000000000000000000..478568836acf352bf5556532ce9b9e95479c1003 --- /dev/null +++ b/server/lfs/transfer.go @@ -0,0 +1,17 @@ +package lfs + +import ( + "context" + "io" +) + +// TransferBasic is the name of the Git LFS basic transfer protocol. +const TransferBasic = "basic" + +// TransferAdapter represents an adapter for downloading/uploading LFS objects +type TransferAdapter interface { + Name() string + Download(ctx context.Context, p Pointer, l *Link) (io.ReadCloser, error) + Upload(ctx context.Context, p Pointer, r io.Reader, l *Link) error + Verify(ctx context.Context, p Pointer, l *Link) error +} diff --git a/server/proto/repo.go b/server/proto/repo.go index 68d88741bce6cadc07ba2c4b685e364c53d03ff6..e721d44a2b115165d13502992a04b8287a1a7b26 100644 --- a/server/proto/repo.go +++ b/server/proto/repo.go @@ -6,8 +6,13 @@ import ( "github.com/charmbracelet/soft-serve/git" ) +// ContextKeyRepository is the context key for the repository. +var ContextKeyRepository = &struct{ string }{"repository"} + // Repository is a Git repository interface. type Repository interface { + // ID returns the repository's ID. + ID() int64 // Name returns the repository's name. Name() string // ProjectName returns the repository's project name. diff --git a/server/proto/user.go b/server/proto/user.go index 6276a14b7fe13497a3b802dcccca677eebd19c97..f8fd65a1479ccbd9736e9c4af353b303a74402eb 100644 --- a/server/proto/user.go +++ b/server/proto/user.go @@ -2,8 +2,13 @@ package proto import "golang.org/x/crypto/ssh" +// ContextKeyUser is the context key for the user. +var ContextKeyUser = &struct{ string }{"user"} + // User is an interface representing a user. type User interface { + // ID returns the user's ID. + ID() int64 // Username returns the user's username. Username() string // IsAdmin returns whether the user is an admin. diff --git a/server/ssh/cmd/delete.go b/server/ssh/cmd/delete.go index 02dff775d469753744f051c89c2887ce3bb3d367..b719ff53721f6cdd4a4cce0d3a4aab17ff8c99d5 100644 --- a/server/ssh/cmd/delete.go +++ b/server/ssh/cmd/delete.go @@ -6,6 +6,8 @@ import ( ) func deleteCommand() *cobra.Command { + var lfs bool + cmd := &cobra.Command{ Use: "delete REPOSITORY", Aliases: []string{"del", "remove", "rm"}, @@ -17,8 +19,11 @@ func deleteCommand() *cobra.Command { be := backend.FromContext(ctx) name := args[0] - return be.DeleteRepository(ctx, name) + return be.DeleteRepository(ctx, name, lfs) }, } + + cmd.Flags().BoolVarP(&lfs, "lfs", "", false, "Delete LFS objects") + return cmd } diff --git a/server/ssh/git.go b/server/ssh/git.go index 051dc25921ad137d89dacdc44501ab8e70637369..d2a030f9ded5d6ef644a363b1c84fcb1e56bd835 100644 --- a/server/ssh/git.go +++ b/server/ssh/git.go @@ -10,6 +10,7 @@ import ( "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/git" + "github.com/charmbracelet/soft-serve/server/lfs" "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/sshutils" "github.com/charmbracelet/soft-serve/server/utils" @@ -24,11 +25,17 @@ func handleGit(s ssh.Session) { cmdLine := s.Command() start := time.Now() + var username string + user := ctx.Value(proto.ContextKeyUser).(proto.User) + if user != nil { + username = user.Username() + } + // repo should be in the form of "repo.git" name := utils.SanitizeRepo(cmdLine[1]) pk := s.PublicKey() ak := sshutils.MarshalAuthorizedKey(pk) - accessLevel := be.AccessLevelByPublicKey(ctx, name, pk) + accessLevel := be.AccessLevelForUser(ctx, name, user) // git bare repositories should end in ".git" // https://git-scm.com/docs/gitrepository-layout repo := name + ".git" @@ -43,7 +50,7 @@ func handleGit(s ssh.Session) { "SOFT_SERVE_REPO_NAME=" + name, "SOFT_SERVE_REPO_PATH=" + filepath.Join(reposDir, repo), "SOFT_SERVE_PUBLIC_KEY=" + ak, - "SOFT_SERVE_USERNAME=" + s.User(), + "SOFT_SERVE_USERNAME=" + username, "SOFT_SERVE_LOG_PATH=" + filepath.Join(cfg.DataPath, "log", "hooks.log"), } @@ -120,5 +127,27 @@ func handleGit(s ssh.Session) { logger.Error("git middleware", "err", err) sshFatal(s, git.ErrSystemMalfunction) } + case git.LFSTransferService: + if accessLevel < access.ReadWriteAccess { + sshFatal(s, git.ErrNotAuthed) + return + } + + if len(cmdLine) != 3 || + (cmdLine[2] != lfs.OperationDownload && cmdLine[2] != lfs.OperationUpload) { + sshFatal(s, git.ErrInvalidRequest) + return + } + + cmd.Args = []string{ + name, + cmdLine[2], + } + + if err := git.LFSTransfer(ctx, cmd); err != nil { + logger.Error("git middleware", "err", err) + sshFatal(s, git.ErrSystemMalfunction) + return + } } } diff --git a/server/ssh/middleware.go b/server/ssh/middleware.go index 7fe749185220ab637efd089149ccd7fa78170942..909c499d4ddffa5f2454c10185e7057e9283bb74 100644 --- a/server/ssh/middleware.go +++ b/server/ssh/middleware.go @@ -6,14 +6,18 @@ import ( "github.com/charmbracelet/log" "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/store" "github.com/charmbracelet/ssh" ) // ContextMiddleware adds the config, backend, and logger to the session context. -func ContextMiddleware(cfg *config.Config, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler { +func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler { return func(sh ssh.Handler) ssh.Handler { return func(s ssh.Session) { s.Context().SetValue(config.ContextKey, cfg) + s.Context().SetValue(db.ContextKey, dbx) + s.Context().SetValue(store.ContextKey, datastore) s.Context().SetValue(backend.ContextKey, be) s.Context().SetValue(log.ContextKey, logger.WithPrefix("ssh")) sh(s) diff --git a/server/ssh/session_test.go b/server/ssh/session_test.go index 104c92ad14095a4b2d851c40e959c05137e2d5f8..17cbc5c08487e89dc3be91c596c68fda06fccf89 100644 --- a/server/ssh/session_test.go +++ b/server/ssh/session_test.go @@ -13,6 +13,8 @@ import ( "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/db" "github.com/charmbracelet/soft-serve/server/db/migrate" + "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/store/database" "github.com/charmbracelet/soft-serve/server/test" "github.com/charmbracelet/ssh" bm "github.com/charmbracelet/wish/bubbletea" @@ -65,22 +67,24 @@ func setup(tb testing.TB) (*gossh.Session, func() error) { log.Fatal(err) } ctx = config.WithContext(ctx, cfg) - db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource) + dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource) if err != nil { tb.Fatal(err) } - if err := migrate.Migrate(ctx, db); err != nil { + if err := migrate.Migrate(ctx, dbx); err != nil { tb.Fatal(err) } - be := backend.New(ctx, cfg, db) + dbstore := database.New(ctx, dbx) + ctx = store.WithContext(ctx, dbstore) + be := backend.New(ctx, cfg, dbx) ctx = backend.WithContext(ctx, be) return testsession.New(tb, &ssh.Server{ - Handler: ContextMiddleware(cfg, be, log.Default())(bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256)(func(s ssh.Session) { + Handler: ContextMiddleware(cfg, dbx, dbstore, be, log.Default())(bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256)(func(s ssh.Session) { _, _, active := s.Pty() if !active { os.Exit(1) } s.Exit(0) })), - }, nil), db.Close + }, nil), dbx.Close } diff --git a/server/ssh/ssh.go b/server/ssh/ssh.go index 39f300e58fff5352243fefdb8b85c434582e0d96..fe6854507ac8fe6b53bf3c6a6a08b588c660ec8d 100644 --- a/server/ssh/ssh.go +++ b/server/ssh/ssh.go @@ -13,8 +13,11 @@ import ( "github.com/charmbracelet/soft-serve/server/access" "github.com/charmbracelet/soft-serve/server/backend" "github.com/charmbracelet/soft-serve/server/config" + "github.com/charmbracelet/soft-serve/server/db" "github.com/charmbracelet/soft-serve/server/git" + "github.com/charmbracelet/soft-serve/server/proto" "github.com/charmbracelet/soft-serve/server/sshutils" + "github.com/charmbracelet/soft-serve/server/store" "github.com/charmbracelet/ssh" "github.com/charmbracelet/wish" bm "github.com/charmbracelet/wish/bubbletea" @@ -104,6 +107,8 @@ type SSHServer struct { // nolint: revive func NewSSHServer(ctx context.Context) (*SSHServer, error) { cfg := config.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("ssh") + dbx := db.FromContext(ctx) + datastore := store.FromContext(ctx) be := backend.FromContext(ctx) var err error @@ -122,7 +127,7 @@ func NewSSHServer(ctx context.Context) (*SSHServer, error) { // CLI middleware. CommandMiddleware, // Context middleware. - ContextMiddleware(cfg, be, logger), + ContextMiddleware(cfg, dbx, datastore, be, logger), // Logging middleware. lm.MiddlewareWithLogger( &loggerAdapter{logger, log.DebugLevel}, @@ -191,7 +196,10 @@ func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed publicKeyCounter.WithLabelValues(strconv.FormatBool(*allowed)).Inc() }(&allowed) - ac := s.be.AccessLevelByPublicKey(ctx, "", pk) + user, _ := s.be.UserByPublicKey(ctx, pk) + ctx.SetValue(proto.ContextKeyUser, user) + + ac := s.be.AccessLevelForUser(ctx, "", user) s.logger.Debugf("access level for %q: %s", ak, ac) allowed = ac >= access.ReadWriteAccess return diff --git a/server/storage/local.go b/server/storage/local.go new file mode 100644 index 0000000000000000000000000000000000000000..8a51157af6e9a796921225f142083a6243aa1720 --- /dev/null +++ b/server/storage/local.go @@ -0,0 +1,91 @@ +package storage + +import ( + "errors" + "io" + "io/fs" + "os" + "path/filepath" + "strings" +) + +// LocalStorage is a storage implementation that stores objects on the local +// filesystem. +type LocalStorage struct { + root string +} + +var _ Storage = (*LocalStorage)(nil) + +// NewLocalStorage creates a new LocalStorage. +func NewLocalStorage(root string) *LocalStorage { + return &LocalStorage{root: root} +} + +// Delete implements Storage. +func (l *LocalStorage) Delete(name string) error { + name = l.fixPath(name) + return os.Remove(name) +} + +// Open implements Storage. +func (l *LocalStorage) Open(name string) (Object, error) { + name = l.fixPath(name) + return os.Open(name) +} + +// Stat implements Storage. +func (l *LocalStorage) Stat(name string) (fs.FileInfo, error) { + name = l.fixPath(name) + return os.Stat(name) +} + +// Put implements Storage. +func (l *LocalStorage) Put(name string, r io.Reader) error { + name = l.fixPath(name) + if err := os.MkdirAll(filepath.Dir(name), os.ModePerm); err != nil { + return err + } + + f, err := os.Create(name) + if err != nil { + return err + } + defer f.Close() // nolint: errcheck + _, err = io.Copy(f, r) + return err +} + +// Exists implements Storage. +func (l *LocalStorage) Exists(name string) (bool, error) { + name = l.fixPath(name) + _, err := os.Stat(name) + if err == nil { + return true, nil + } + if errors.Is(err, fs.ErrNotExist) { + return false, nil + } + return false, err +} + +// Rename implements Storage. +func (l *LocalStorage) Rename(oldName, newName string) error { + oldName = l.fixPath(oldName) + newName = l.fixPath(newName) + if err := os.MkdirAll(filepath.Dir(newName), os.ModePerm); err != nil { + return err + } + + return os.Rename(oldName, newName) +} + +// Replace all slashes with the OS-specific separator +func (l LocalStorage) fixPath(path string) string { + path = strings.ReplaceAll(path, "/", string(os.PathSeparator)) + if !filepath.IsAbs(path) { + return filepath.Join(l.root, path) + } + + return path +} diff --git a/server/storage/storage.go b/server/storage/storage.go new file mode 100644 index 0000000000000000000000000000000000000000..dc435dbbbd8b39b6a1e0ac69cc1269276fab32e6 --- /dev/null +++ b/server/storage/storage.go @@ -0,0 +1,23 @@ +package storage + +import ( + "io" + "io/fs" +) + +// Object is an interface for objects that can be stored. +type Object interface { + io.Seeker + fs.File + Name() string +} + +// Storage is an interface for storing and retrieving objects. +type Storage interface { + Open(name string) (Object, error) + Stat(name string) (fs.FileInfo, error) + Put(name string, r io.Reader) error + Delete(name string) error + Exists(name string) (bool, error) + Rename(oldName, newName string) error +} diff --git a/server/store/context.go b/server/store/context.go new file mode 100644 index 0000000000000000000000000000000000000000..938c7dca07457ce15448c9d513234b92b2c19082 --- /dev/null +++ b/server/store/context.go @@ -0,0 +1,20 @@ +package store + +import "context" + +// ContextKey is the store context key. +var ContextKey = &struct{ string }{"store"} + +// FromContext returns the store from the given context. +func FromContext(ctx context.Context) Store { + if s, ok := ctx.Value(ContextKey).(Store); ok { + return s + } + + return nil +} + +// WithContext returns a new context with the given store. +func WithContext(ctx context.Context, s Store) context.Context { + return context.WithValue(ctx, ContextKey, s) +} diff --git a/server/store/database/collab.go b/server/store/database/collab.go index 50424445e13ede0006e81d7e25aca255ff139ff9..e290593c2d6f5dbef4461daf147a72b9eaafb924 100644 --- a/server/store/database/collab.go +++ b/server/store/database/collab.go @@ -15,7 +15,7 @@ type collabStore struct{} var _ store.CollaboratorStore = (*collabStore)(nil) // AddCollabByUsernameAndRepo implements store.CollaboratorStore. -func (*collabStore) AddCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error { +func (*collabStore) AddCollabByUsernameAndRepo(ctx context.Context, tx db.Handler, username string, repo string) error { username = strings.ToLower(username) if err := utils.ValidateUsername(username); err != nil { return err @@ -38,7 +38,7 @@ func (*collabStore) AddCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, u } // GetCollabByUsernameAndRepo implements store.CollaboratorStore. -func (*collabStore) GetCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) (models.Collab, error) { +func (*collabStore) GetCollabByUsernameAndRepo(ctx context.Context, tx db.Handler, username string, repo string) (models.Collab, error) { var m models.Collab username = strings.ToLower(username) @@ -63,7 +63,7 @@ func (*collabStore) GetCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, u } // ListCollabsByRepo implements store.CollaboratorStore. -func (*collabStore) ListCollabsByRepo(ctx context.Context, tx *db.Tx, repo string) ([]models.Collab, error) { +func (*collabStore) ListCollabsByRepo(ctx context.Context, tx db.Handler, repo string) ([]models.Collab, error) { var m []models.Collab repo = utils.SanitizeRepo(repo) @@ -82,7 +82,7 @@ func (*collabStore) ListCollabsByRepo(ctx context.Context, tx *db.Tx, repo strin } // ListCollabsByRepoAsUsers implements store.CollaboratorStore. -func (*collabStore) ListCollabsByRepoAsUsers(ctx context.Context, tx *db.Tx, repo string) ([]models.User, error) { +func (*collabStore) ListCollabsByRepoAsUsers(ctx context.Context, tx db.Handler, repo string) ([]models.User, error) { var m []models.User repo = utils.SanitizeRepo(repo) @@ -102,7 +102,7 @@ func (*collabStore) ListCollabsByRepoAsUsers(ctx context.Context, tx *db.Tx, rep } // RemoveCollabByUsernameAndRepo implements store.CollaboratorStore. -func (*collabStore) RemoveCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error { +func (*collabStore) RemoveCollabByUsernameAndRepo(ctx context.Context, tx db.Handler, username string, repo string) error { username = strings.ToLower(username) if err := utils.ValidateUsername(username); err != nil { return err diff --git a/server/store/database/database.go b/server/store/database/database.go index 4960f68537f5d18b6c2e56d636e5d2f634cd8bae..c63d70992eb4665e8dd1ecf6bf007c5eb4653c41 100644 --- a/server/store/database/database.go +++ b/server/store/database/database.go @@ -19,6 +19,7 @@ type datastore struct { *repoStore *userStore *collabStore + *lfsStore } // New returns a new store.Store database. @@ -36,6 +37,7 @@ func New(ctx context.Context, db *db.DB) store.Store { repoStore: &repoStore{}, userStore: &userStore{}, collabStore: &collabStore{}, + lfsStore: &lfsStore{}, } return s diff --git a/server/store/database/lfs.go b/server/store/database/lfs.go new file mode 100644 index 0000000000000000000000000000000000000000..0233dec7a0ffc48da9bca2004938d4bccb987560 --- /dev/null +++ b/server/store/database/lfs.go @@ -0,0 +1,179 @@ +package database + +import ( + "context" + "strconv" + "strings" + + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/models" + "github.com/charmbracelet/soft-serve/server/store" +) + +type lfsStore struct{} + +var _ store.LFSStore = (*lfsStore)(nil) + +func sanitizePath(path string) string { + path = strings.TrimSpace(path) + path = strings.TrimPrefix(path, "/") + return path +} + +// CreateLFSLockForUser implements store.LFSStore. +func (*lfsStore) CreateLFSLockForUser(ctx context.Context, tx db.Handler, repoID int64, userID int64, path string, refname string) error { + path = sanitizePath(path) + query := tx.Rebind(`INSERT INTO lfs_locks (repo_id, user_id, path, refname, updated_at) + VALUES ( + ?, + ?, + ?, + ?, + CURRENT_TIMESTAMP + ); + `) + _, err := tx.ExecContext(ctx, query, repoID, userID, path, refname) + return db.WrapError(err) +} + +// GetLFSLocks implements store.LFSStore. +func (*lfsStore) GetLFSLocks(ctx context.Context, tx db.Handler, repoID int64) ([]models.LFSLock, error) { + var locks []models.LFSLock + query := tx.Rebind(` + SELECT * + FROM lfs_locks + WHERE repo_id = ?; + `) + err := tx.SelectContext(ctx, &locks, query, repoID) + return locks, db.WrapError(err) +} + +// GetLFSLocksForUser implements store.LFSStore. +func (*lfsStore) GetLFSLocksForUser(ctx context.Context, tx db.Handler, repoID int64, userID int64) ([]models.LFSLock, error) { + var locks []models.LFSLock + query := tx.Rebind(` + SELECT * + FROM lfs_locks + WHERE repo_id = ? AND user_id = ?; + `) + err := tx.SelectContext(ctx, &locks, query, repoID, userID) + return locks, db.WrapError(err) +} + +// GetLFSLocksForPath implements store.LFSStore. +func (*lfsStore) GetLFSLocksForPath(ctx context.Context, tx db.Handler, repoID int64, path string) ([]models.LFSLock, error) { + path = sanitizePath(path) + var locks []models.LFSLock + query := tx.Rebind(` + SELECT * + FROM lfs_locks + WHERE repo_id = ? AND path = ?; + `) + err := tx.SelectContext(ctx, &locks, query, repoID, path) + return locks, db.WrapError(err) +} + +// GetLFSLockForUserPath implements store.LFSStore. +func (*lfsStore) GetLFSLockForUserPath(ctx context.Context, tx db.Handler, repoID int64, userID int64, path string) (models.LFSLock, error) { + path = sanitizePath(path) + var lock models.LFSLock + query := tx.Rebind(` + SELECT * + FROM lfs_locks + WHERE repo_id = ? AND user_id = ? AND path = ?; + `) + err := tx.GetContext(ctx, &lock, query, repoID, userID, path) + return lock, db.WrapError(err) +} + +// GetLFSLockByID implements store.LFSStore. +func (*lfsStore) GetLFSLockByID(ctx context.Context, tx db.Handler, id string) (models.LFSLock, error) { + iid, err := strconv.Atoi(id) + if err != nil { + return models.LFSLock{}, err + } + + var lock models.LFSLock + query := tx.Rebind(` + SELECT * + FROM lfs_locks + WHERE lfs_locks.id = ?; + `) + err = tx.GetContext(ctx, &lock, query, iid) + return lock, db.WrapError(err) +} + +// GetLFSLockForUserByID implements store.LFSStore. +func (*lfsStore) GetLFSLockForUserByID(ctx context.Context, tx db.Handler, userID int64, id string) (models.LFSLock, error) { + iid, err := strconv.Atoi(id) + if err != nil { + return models.LFSLock{}, err + } + + var lock models.LFSLock + query := tx.Rebind(` + SELECT * + FROM lfs_locks + WHERE id = ? AND user_id = ?; + `) + err = tx.GetContext(ctx, &lock, query, iid, userID) + return lock, db.WrapError(err) +} + +// DeleteLFSLockForUserByID implements store.LFSStore. +func (*lfsStore) DeleteLFSLockForUserByID(ctx context.Context, tx db.Handler, userID int64, id string) error { + iid, err := strconv.Atoi(id) + if err != nil { + return err + } + + query := tx.Rebind(` + DELETE FROM lfs_locks + WHERE user_id = ? AND id = ?; + `) + _, err = tx.ExecContext(ctx, query, userID, iid) + return db.WrapError(err) +} + +// CreateLFSObject implements store.LFSStore. +func (*lfsStore) CreateLFSObject(ctx context.Context, tx db.Handler, repoID int64, oid string, size int64) error { + query := tx.Rebind(`INSERT INTO lfs_objects (repo_id, oid, size, updated_at) VALUES (?, ?, ?, CURRENT_TIMESTAMP);`) + _, err := tx.ExecContext(ctx, query, repoID, oid, size) + return db.WrapError(err) +} + +// DeleteLFSObjectByOid implements store.LFSStore. +func (*lfsStore) DeleteLFSObjectByOid(ctx context.Context, tx db.Handler, repoID int64, oid string) error { + query := tx.Rebind(`DELETE FROM lfs_objects WHERE repo_id = ? AND oid = ?;`) + _, err := tx.ExecContext(ctx, query, repoID, oid) + return db.WrapError(err) +} + +// GetLFSObjectByOid implements store.LFSStore. +func (*lfsStore) GetLFSObjectByOid(ctx context.Context, tx db.Handler, repoID int64, oid string) (models.LFSObject, error) { + var obj models.LFSObject + query := tx.Rebind(`SELECT * FROM lfs_objects WHERE repo_id = ? AND oid = ?;`) + err := tx.GetContext(ctx, &obj, query, repoID, oid) + return obj, db.WrapError(err) +} + +// GetLFSObjects implements store.LFSStore. +func (*lfsStore) GetLFSObjects(ctx context.Context, tx db.Handler, repoID int64) ([]models.LFSObject, error) { + var objs []models.LFSObject + query := tx.Rebind(`SELECT * FROM lfs_objects WHERE repo_id = ?;`) + err := tx.SelectContext(ctx, &objs, query, repoID) + return objs, db.WrapError(err) +} + +// GetLFSObjectsByName implements store.LFSStore. +func (*lfsStore) GetLFSObjectsByName(ctx context.Context, tx db.Handler, name string) ([]models.LFSObject, error) { + var objs []models.LFSObject + query := tx.Rebind(` + SELECT lfs_objects.* + FROM lfs_objects + INNER JOIN repos ON lfs_objects.repo_id = repos.id + WHERE repos.name = ?; + `) + err := tx.SelectContext(ctx, &objs, query, name) + return objs, db.WrapError(err) +} diff --git a/server/store/database/repo.go b/server/store/database/repo.go index 76436b9bc4393cf4a146559a9efe43417add6feb..06ea92b492c32cfaf4c1020611daebf5ea7a8bbb 100644 --- a/server/store/database/repo.go +++ b/server/store/database/repo.go @@ -14,7 +14,7 @@ type repoStore struct{} var _ store.RepositoryStore = (*repoStore)(nil) // CreateRepo implements store.RepositoryStore. -func (*repoStore) CreateRepo(ctx context.Context, tx *db.Tx, name string, projectName string, description string, isPrivate bool, isHidden bool, isMirror bool) error { +func (*repoStore) CreateRepo(ctx context.Context, tx db.Handler, name string, projectName string, description string, isPrivate bool, isHidden bool, isMirror bool) error { name = utils.SanitizeRepo(name) query := tx.Rebind(`INSERT INTO repos (name, project_name, description, private, mirror, hidden, updated_at) VALUES (?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP);`) @@ -24,7 +24,7 @@ func (*repoStore) CreateRepo(ctx context.Context, tx *db.Tx, name string, projec } // DeleteRepoByName implements store.RepositoryStore. -func (*repoStore) DeleteRepoByName(ctx context.Context, tx *db.Tx, name string) error { +func (*repoStore) DeleteRepoByName(ctx context.Context, tx db.Handler, name string) error { name = utils.SanitizeRepo(name) query := tx.Rebind("DELETE FROM repos WHERE name = ?;") _, err := tx.ExecContext(ctx, query, name) @@ -32,7 +32,7 @@ func (*repoStore) DeleteRepoByName(ctx context.Context, tx *db.Tx, name string) } // GetAllRepos implements store.RepositoryStore. -func (*repoStore) GetAllRepos(ctx context.Context, tx *db.Tx) ([]models.Repo, error) { +func (*repoStore) GetAllRepos(ctx context.Context, tx db.Handler) ([]models.Repo, error) { var repos []models.Repo query := tx.Rebind("SELECT * FROM repos;") err := tx.SelectContext(ctx, &repos, query) @@ -40,7 +40,7 @@ func (*repoStore) GetAllRepos(ctx context.Context, tx *db.Tx) ([]models.Repo, er } // GetRepoByName implements store.RepositoryStore. -func (*repoStore) GetRepoByName(ctx context.Context, tx *db.Tx, name string) (models.Repo, error) { +func (*repoStore) GetRepoByName(ctx context.Context, tx db.Handler, name string) (models.Repo, error) { var repo models.Repo name = utils.SanitizeRepo(name) query := tx.Rebind("SELECT * FROM repos WHERE name = ?;") @@ -49,7 +49,7 @@ func (*repoStore) GetRepoByName(ctx context.Context, tx *db.Tx, name string) (mo } // GetRepoDescriptionByName implements store.RepositoryStore. -func (*repoStore) GetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string) (string, error) { +func (*repoStore) GetRepoDescriptionByName(ctx context.Context, tx db.Handler, name string) (string, error) { var description string name = utils.SanitizeRepo(name) query := tx.Rebind("SELECT description FROM repos WHERE name = ?;") @@ -58,7 +58,7 @@ func (*repoStore) GetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name } // GetRepoIsHiddenByName implements store.RepositoryStore. -func (*repoStore) GetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string) (bool, error) { +func (*repoStore) GetRepoIsHiddenByName(ctx context.Context, tx db.Handler, name string) (bool, error) { var isHidden bool name = utils.SanitizeRepo(name) query := tx.Rebind("SELECT hidden FROM repos WHERE name = ?;") @@ -67,7 +67,7 @@ func (*repoStore) GetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name str } // GetRepoIsMirrorByName implements store.RepositoryStore. -func (*repoStore) GetRepoIsMirrorByName(ctx context.Context, tx *db.Tx, name string) (bool, error) { +func (*repoStore) GetRepoIsMirrorByName(ctx context.Context, tx db.Handler, name string) (bool, error) { var isMirror bool name = utils.SanitizeRepo(name) query := tx.Rebind("SELECT mirror FROM repos WHERE name = ?;") @@ -76,7 +76,7 @@ func (*repoStore) GetRepoIsMirrorByName(ctx context.Context, tx *db.Tx, name str } // GetRepoIsPrivateByName implements store.RepositoryStore. -func (*repoStore) GetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string) (bool, error) { +func (*repoStore) GetRepoIsPrivateByName(ctx context.Context, tx db.Handler, name string) (bool, error) { var isPrivate bool name = utils.SanitizeRepo(name) query := tx.Rebind("SELECT private FROM repos WHERE name = ?;") @@ -85,7 +85,7 @@ func (*repoStore) GetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name st } // GetRepoProjectNameByName implements store.RepositoryStore. -func (*repoStore) GetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string) (string, error) { +func (*repoStore) GetRepoProjectNameByName(ctx context.Context, tx db.Handler, name string) (string, error) { var pname string name = utils.SanitizeRepo(name) query := tx.Rebind("SELECT project_name FROM repos WHERE name = ?;") @@ -94,7 +94,7 @@ func (*repoStore) GetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name } // SetRepoDescriptionByName implements store.RepositoryStore. -func (*repoStore) SetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string, description string) error { +func (*repoStore) SetRepoDescriptionByName(ctx context.Context, tx db.Handler, name string, description string) error { name = utils.SanitizeRepo(name) query := tx.Rebind("UPDATE repos SET description = ? WHERE name = ?;") _, err := tx.ExecContext(ctx, query, description, name) @@ -102,7 +102,7 @@ func (*repoStore) SetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name } // SetRepoIsHiddenByName implements store.RepositoryStore. -func (*repoStore) SetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string, isHidden bool) error { +func (*repoStore) SetRepoIsHiddenByName(ctx context.Context, tx db.Handler, name string, isHidden bool) error { name = utils.SanitizeRepo(name) query := tx.Rebind("UPDATE repos SET hidden = ? WHERE name = ?;") _, err := tx.ExecContext(ctx, query, isHidden, name) @@ -110,7 +110,7 @@ func (*repoStore) SetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name str } // SetRepoIsPrivateByName implements store.RepositoryStore. -func (*repoStore) SetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string, isPrivate bool) error { +func (*repoStore) SetRepoIsPrivateByName(ctx context.Context, tx db.Handler, name string, isPrivate bool) error { name = utils.SanitizeRepo(name) query := tx.Rebind("UPDATE repos SET private = ? WHERE name = ?;") _, err := tx.ExecContext(ctx, query, isPrivate, name) @@ -118,7 +118,7 @@ func (*repoStore) SetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name st } // SetRepoNameByName implements store.RepositoryStore. -func (*repoStore) SetRepoNameByName(ctx context.Context, tx *db.Tx, name string, newName string) error { +func (*repoStore) SetRepoNameByName(ctx context.Context, tx db.Handler, name string, newName string) error { name = utils.SanitizeRepo(name) newName = utils.SanitizeRepo(newName) query := tx.Rebind("UPDATE repos SET name = ? WHERE name = ?;") @@ -127,7 +127,7 @@ func (*repoStore) SetRepoNameByName(ctx context.Context, tx *db.Tx, name string, } // SetRepoProjectNameByName implements store.RepositoryStore. -func (*repoStore) SetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string, projectName string) error { +func (*repoStore) SetRepoProjectNameByName(ctx context.Context, tx db.Handler, name string, projectName string) error { name = utils.SanitizeRepo(name) query := tx.Rebind("UPDATE repos SET project_name = ? WHERE name = ?;") _, err := tx.ExecContext(ctx, query, projectName, name) diff --git a/server/store/database/settings.go b/server/store/database/settings.go index bb653a7ff188f2812e4003c846a3f16b5bab7d22..ec63eae01f9a8a2743a83da4c817b0e03c3771af 100644 --- a/server/store/database/settings.go +++ b/server/store/database/settings.go @@ -13,7 +13,7 @@ type settingsStore struct{} var _ store.SettingStore = (*settingsStore)(nil) // GetAllowKeylessAccess implements store.SettingStore. -func (*settingsStore) GetAllowKeylessAccess(ctx context.Context, tx *db.Tx) (bool, error) { +func (*settingsStore) GetAllowKeylessAccess(ctx context.Context, tx db.Handler) (bool, error) { var allow bool query := tx.Rebind(`SELECT value FROM settings WHERE key = "allow_keyless"`) if err := tx.GetContext(ctx, &allow, query); err != nil { @@ -23,7 +23,7 @@ func (*settingsStore) GetAllowKeylessAccess(ctx context.Context, tx *db.Tx) (boo } // GetAnonAccess implements store.SettingStore. -func (*settingsStore) GetAnonAccess(ctx context.Context, tx *db.Tx) (access.AccessLevel, error) { +func (*settingsStore) GetAnonAccess(ctx context.Context, tx db.Handler) (access.AccessLevel, error) { var level string query := tx.Rebind(`SELECT value FROM settings WHERE key = "anon_access"`) if err := tx.GetContext(ctx, &level, query); err != nil { @@ -33,14 +33,14 @@ func (*settingsStore) GetAnonAccess(ctx context.Context, tx *db.Tx) (access.Acce } // SetAllowKeylessAccess implements store.SettingStore. -func (*settingsStore) SetAllowKeylessAccess(ctx context.Context, tx *db.Tx, allow bool) error { +func (*settingsStore) SetAllowKeylessAccess(ctx context.Context, tx db.Handler, allow bool) error { query := tx.Rebind(`UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = "allow_keyless"`) _, err := tx.ExecContext(ctx, query, allow) return db.WrapError(err) } // SetAnonAccess implements store.SettingStore. -func (*settingsStore) SetAnonAccess(ctx context.Context, tx *db.Tx, level access.AccessLevel) error { +func (*settingsStore) SetAnonAccess(ctx context.Context, tx db.Handler, level access.AccessLevel) error { query := tx.Rebind(`UPDATE settings SET value = ?, updated_at = CURRENT_TIMESTAMP WHERE key = "anon_access"`) _, err := tx.ExecContext(ctx, query, level.String()) return db.WrapError(err) diff --git a/server/store/database/user.go b/server/store/database/user.go index 2e3a70beb45b9fb3625f98bbcf0d78604e190863..9ca824e62bfde8cdb9a2c2d77c6eecd442c59183 100644 --- a/server/store/database/user.go +++ b/server/store/database/user.go @@ -17,7 +17,7 @@ type userStore struct{} var _ store.UserStore = (*userStore)(nil) // AddPublicKeyByUsername implements store.UserStore. -func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error { +func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error { username = strings.ToLower(username) if err := utils.ValidateUsername(username); err != nil { return err @@ -37,7 +37,7 @@ func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx *db.Tx, usernam } // CreateUser implements store.UserStore. -func (*userStore) CreateUser(ctx context.Context, tx *db.Tx, username string, isAdmin bool, pks []ssh.PublicKey) error { +func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error { username = strings.ToLower(username) if err := utils.ValidateUsername(username); err != nil { return err @@ -69,7 +69,7 @@ func (*userStore) CreateUser(ctx context.Context, tx *db.Tx, username string, is } // DeleteUserByUsername implements store.UserStore. -func (*userStore) DeleteUserByUsername(ctx context.Context, tx *db.Tx, username string) error { +func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error { username = strings.ToLower(username) if err := utils.ValidateUsername(username); err != nil { return err @@ -80,8 +80,16 @@ func (*userStore) DeleteUserByUsername(ctx context.Context, tx *db.Tx, username return err } +// GetUserByID implements store.UserStore. +func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) { + var m models.User + query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`) + err := tx.GetContext(ctx, &m, query, id) + return m, err +} + // FindUserByPublicKey implements store.UserStore. -func (*userStore) FindUserByPublicKey(ctx context.Context, tx *db.Tx, pk ssh.PublicKey) (models.User, error) { +func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) { var m models.User query := tx.Rebind(`SELECT users.* FROM users @@ -92,7 +100,7 @@ func (*userStore) FindUserByPublicKey(ctx context.Context, tx *db.Tx, pk ssh.Pub } // FindUserByUsername implements store.UserStore. -func (*userStore) FindUserByUsername(ctx context.Context, tx *db.Tx, username string) (models.User, error) { +func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) { username = strings.ToLower(username) if err := utils.ValidateUsername(username); err != nil { return models.User{}, err @@ -105,7 +113,7 @@ func (*userStore) FindUserByUsername(ctx context.Context, tx *db.Tx, username st } // GetAllUsers implements store.UserStore. -func (*userStore) GetAllUsers(ctx context.Context, tx *db.Tx) ([]models.User, error) { +func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) { var ms []models.User query := tx.Rebind(`SELECT * FROM users;`) err := tx.SelectContext(ctx, &ms, query) @@ -113,7 +121,7 @@ func (*userStore) GetAllUsers(ctx context.Context, tx *db.Tx) ([]models.User, er } // ListPublicKeysByUserID implements store.UserStore.. -func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx *db.Tx, id int64) ([]ssh.PublicKey, error) { +func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) { var aks []string query := tx.Rebind(`SELECT public_key FROM public_keys WHERE user_id = ? @@ -136,7 +144,7 @@ func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx *db.Tx, id int6 } // ListPublicKeysByUsername implements store.UserStore. -func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx *db.Tx, username string) ([]ssh.PublicKey, error) { +func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) { username = strings.ToLower(username) if err := utils.ValidateUsername(username); err != nil { return nil, err @@ -165,7 +173,7 @@ func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx *db.Tx, usern } // RemovePublicKeyByUsername implements store.UserStore. -func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error { +func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error { username = strings.ToLower(username) if err := utils.ValidateUsername(username); err != nil { return err @@ -179,7 +187,7 @@ func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx *db.Tx, user } // SetAdminByUsername implements store.UserStore. -func (*userStore) SetAdminByUsername(ctx context.Context, tx *db.Tx, username string, isAdmin bool) error { +func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error { username = strings.ToLower(username) if err := utils.ValidateUsername(username); err != nil { return err @@ -191,7 +199,7 @@ func (*userStore) SetAdminByUsername(ctx context.Context, tx *db.Tx, username st } // SetUsernameByUsername implements store.UserStore. -func (*userStore) SetUsernameByUsername(ctx context.Context, tx *db.Tx, username string, newUsername string) error { +func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error { username = strings.ToLower(username) if err := utils.ValidateUsername(username); err != nil { return err diff --git a/server/store/lfs.go b/server/store/lfs.go new file mode 100644 index 0000000000000000000000000000000000000000..7632d2472bd7b65c43141e46dab7d49eb0741735 --- /dev/null +++ b/server/store/lfs.go @@ -0,0 +1,26 @@ +package store + +import ( + "context" + + "github.com/charmbracelet/soft-serve/server/db" + "github.com/charmbracelet/soft-serve/server/db/models" +) + +// LFSStore is the interface for the LFS store. +type LFSStore interface { + CreateLFSObject(ctx context.Context, h db.Handler, repoID int64, oid string, size int64) error + GetLFSObjectByOid(ctx context.Context, h db.Handler, repoID int64, oid string) (models.LFSObject, error) + GetLFSObjects(ctx context.Context, h db.Handler, repoID int64) ([]models.LFSObject, error) + GetLFSObjectsByName(ctx context.Context, h db.Handler, name string) ([]models.LFSObject, error) + DeleteLFSObjectByOid(ctx context.Context, h db.Handler, repoID int64, oid string) error + + CreateLFSLockForUser(ctx context.Context, h db.Handler, repoID int64, userID int64, path string, refname string) error + GetLFSLocks(ctx context.Context, h db.Handler, repoID int64) ([]models.LFSLock, error) + GetLFSLocksForUser(ctx context.Context, h db.Handler, repoID int64, userID int64) ([]models.LFSLock, error) + GetLFSLocksForPath(ctx context.Context, h db.Handler, repoID int64, path string) ([]models.LFSLock, error) + GetLFSLockForUserPath(ctx context.Context, h db.Handler, repoID int64, userID int64, path string) (models.LFSLock, error) + GetLFSLockByID(ctx context.Context, h db.Handler, id string) (models.LFSLock, error) + GetLFSLockForUserByID(ctx context.Context, h db.Handler, userID int64, id string) (models.LFSLock, error) + DeleteLFSLockForUserByID(ctx context.Context, h db.Handler, userID int64, id string) error +} diff --git a/server/store/store.go b/server/store/store.go index d933dfb7d38ad643d472e72808a8aba438ab241b..dcaa3165e4cd1821dbf402612388ef4fb4821c27 100644 --- a/server/store/store.go +++ b/server/store/store.go @@ -11,53 +11,54 @@ import ( // SettingStore is an interface for managing settings. type SettingStore interface { - GetAnonAccess(ctx context.Context, tx *db.Tx) (access.AccessLevel, error) - SetAnonAccess(ctx context.Context, tx *db.Tx, level access.AccessLevel) error - GetAllowKeylessAccess(ctx context.Context, tx *db.Tx) (bool, error) - SetAllowKeylessAccess(ctx context.Context, tx *db.Tx, allow bool) error + GetAnonAccess(ctx context.Context, h db.Handler) (access.AccessLevel, error) + SetAnonAccess(ctx context.Context, h db.Handler, level access.AccessLevel) error + GetAllowKeylessAccess(ctx context.Context, h db.Handler) (bool, error) + SetAllowKeylessAccess(ctx context.Context, h db.Handler, allow bool) error } // RepositoryStore is an interface for managing repositories. type RepositoryStore interface { - GetRepoByName(ctx context.Context, tx *db.Tx, name string) (models.Repo, error) - GetAllRepos(ctx context.Context, tx *db.Tx) ([]models.Repo, error) - CreateRepo(ctx context.Context, tx *db.Tx, name string, projectName string, description string, isPrivate bool, isHidden bool, isMirror bool) error - DeleteRepoByName(ctx context.Context, tx *db.Tx, name string) error - SetRepoNameByName(ctx context.Context, tx *db.Tx, name string, newName string) error + GetRepoByName(ctx context.Context, h db.Handler, name string) (models.Repo, error) + GetAllRepos(ctx context.Context, h db.Handler) ([]models.Repo, error) + CreateRepo(ctx context.Context, h db.Handler, name string, projectName string, description string, isPrivate bool, isHidden bool, isMirror bool) error + DeleteRepoByName(ctx context.Context, h db.Handler, name string) error + SetRepoNameByName(ctx context.Context, h db.Handler, name string, newName string) error - GetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string) (string, error) - SetRepoProjectNameByName(ctx context.Context, tx *db.Tx, name string, projectName string) error - GetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string) (string, error) - SetRepoDescriptionByName(ctx context.Context, tx *db.Tx, name string, description string) error - GetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string) (bool, error) - SetRepoIsPrivateByName(ctx context.Context, tx *db.Tx, name string, isPrivate bool) error - GetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string) (bool, error) - SetRepoIsHiddenByName(ctx context.Context, tx *db.Tx, name string, isHidden bool) error - GetRepoIsMirrorByName(ctx context.Context, tx *db.Tx, name string) (bool, error) + GetRepoProjectNameByName(ctx context.Context, h db.Handler, name string) (string, error) + SetRepoProjectNameByName(ctx context.Context, h db.Handler, name string, projectName string) error + GetRepoDescriptionByName(ctx context.Context, h db.Handler, name string) (string, error) + SetRepoDescriptionByName(ctx context.Context, h db.Handler, name string, description string) error + GetRepoIsPrivateByName(ctx context.Context, h db.Handler, name string) (bool, error) + SetRepoIsPrivateByName(ctx context.Context, h db.Handler, name string, isPrivate bool) error + GetRepoIsHiddenByName(ctx context.Context, h db.Handler, name string) (bool, error) + SetRepoIsHiddenByName(ctx context.Context, h db.Handler, name string, isHidden bool) error + GetRepoIsMirrorByName(ctx context.Context, h db.Handler, name string) (bool, error) } // UserStore is an interface for managing users. type UserStore interface { - FindUserByUsername(ctx context.Context, tx *db.Tx, username string) (models.User, error) - FindUserByPublicKey(ctx context.Context, tx *db.Tx, pk ssh.PublicKey) (models.User, error) - GetAllUsers(ctx context.Context, tx *db.Tx) ([]models.User, error) - CreateUser(ctx context.Context, tx *db.Tx, username string, isAdmin bool, pks []ssh.PublicKey) error - DeleteUserByUsername(ctx context.Context, tx *db.Tx, username string) error - SetUsernameByUsername(ctx context.Context, tx *db.Tx, username string, newUsername string) error - SetAdminByUsername(ctx context.Context, tx *db.Tx, username string, isAdmin bool) error - AddPublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error - RemovePublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error - ListPublicKeysByUserID(ctx context.Context, tx *db.Tx, id int64) ([]ssh.PublicKey, error) - ListPublicKeysByUsername(ctx context.Context, tx *db.Tx, username string) ([]ssh.PublicKey, error) + GetUserByID(ctx context.Context, h db.Handler, id int64) (models.User, error) + FindUserByUsername(ctx context.Context, h db.Handler, username string) (models.User, error) + FindUserByPublicKey(ctx context.Context, h db.Handler, pk ssh.PublicKey) (models.User, error) + GetAllUsers(ctx context.Context, h db.Handler) ([]models.User, error) + CreateUser(ctx context.Context, h db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error + DeleteUserByUsername(ctx context.Context, h db.Handler, username string) error + SetUsernameByUsername(ctx context.Context, h db.Handler, username string, newUsername string) error + SetAdminByUsername(ctx context.Context, h db.Handler, username string, isAdmin bool) error + AddPublicKeyByUsername(ctx context.Context, h db.Handler, username string, pk ssh.PublicKey) error + RemovePublicKeyByUsername(ctx context.Context, h db.Handler, username string, pk ssh.PublicKey) error + ListPublicKeysByUserID(ctx context.Context, h db.Handler, id int64) ([]ssh.PublicKey, error) + ListPublicKeysByUsername(ctx context.Context, h db.Handler, username string) ([]ssh.PublicKey, error) } // CollaboratorStore is an interface for managing collaborators. type CollaboratorStore interface { - GetCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) (models.Collab, error) - AddCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error - RemoveCollabByUsernameAndRepo(ctx context.Context, tx *db.Tx, username string, repo string) error - ListCollabsByRepo(ctx context.Context, tx *db.Tx, repo string) ([]models.Collab, error) - ListCollabsByRepoAsUsers(ctx context.Context, tx *db.Tx, repo string) ([]models.User, error) + GetCollabByUsernameAndRepo(ctx context.Context, h db.Handler, username string, repo string) (models.Collab, error) + AddCollabByUsernameAndRepo(ctx context.Context, h db.Handler, username string, repo string) error + RemoveCollabByUsernameAndRepo(ctx context.Context, h db.Handler, username string, repo string) error + ListCollabsByRepo(ctx context.Context, h db.Handler, repo string) ([]models.Collab, error) + ListCollabsByRepoAsUsers(ctx context.Context, h db.Handler, repo string) ([]models.User, error) } // Store is an interface for managing repositories, users, and settings. @@ -66,4 +67,5 @@ type Store interface { UserStore CollaboratorStore SettingStore + LFSStore } diff --git a/testscript/script_test.go b/testscript/script_test.go index dbb8b956f53b173c38df509acfb379bb600bca23..d16c19f3fe304bc64febe5c42fe8008ef7ce8d0a 100644 --- a/testscript/script_test.go +++ b/testscript/script_test.go @@ -19,6 +19,8 @@ import ( "github.com/charmbracelet/soft-serve/server/config" "github.com/charmbracelet/soft-serve/server/db" "github.com/charmbracelet/soft-serve/server/db/migrate" + "github.com/charmbracelet/soft-serve/server/store" + "github.com/charmbracelet/soft-serve/server/store/database" "github.com/charmbracelet/soft-serve/server/test" "github.com/rogpeppe/go-internal/testscript" "golang.org/x/crypto/ssh" @@ -105,6 +107,8 @@ func TestScript(t *testing.T) { } ctx = db.WithContext(ctx, dbx) + datastore := database.New(ctx, dbx) + ctx = store.WithContext(ctx, datastore) be := backend.New(ctx, cfg, dbx) ctx = backend.WithContext(ctx, be)