diff --git a/Cargo.lock b/Cargo.lock index 20d455e4bc26e24f81ea8106954739f46f8b0be3..3aca27106c7c4ddce004d3c308ff5827ec0e1cc6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,6 +108,33 @@ dependencies = [ "util", ] +[[package]] +name = "ai2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "bincode", + "futures 0.3.28", + "gpui2", + "isahc", + "language2", + "lazy_static", + "log", + "matrixmultiply", + "ordered-float 2.10.0", + "parking_lot 0.11.2", + "parse_duration", + "postage", + "rand 0.8.5", + "regex", + "rusqlite", + "serde", + "serde_json", + "tiktoken-rs", + "util", +] + [[package]] name = "alacritty_config" version = "0.1.2-dev" @@ -659,6 +686,20 @@ dependencies = [ "util", ] +[[package]] +name = "audio2" +version = "0.1.0" +dependencies = [ + "anyhow", + "collections", + "futures 0.3.28", + "gpui2", + "log", + "parking_lot 0.11.2", + "rodio", + "util", +] + [[package]] name = "auto_update" version = "0.1.0" @@ -776,6 +817,17 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "backtrace-on-stack-overflow" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fd2d70527f3737a1ad17355e260706c1badebabd1fa06a7a053407380df841b" +dependencies = [ + "backtrace", + "libc", + "nix 0.23.2", +] + [[package]] name = "base64" version = "0.13.1" @@ -1104,6 +1156,32 @@ dependencies = [ "util", ] +[[package]] +name = "call2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-broadcast", + "audio2", + "client2", + "collections", + "fs2", + "futures 0.3.28", + "gpui2", + "language2", + "live_kit_client", + "log", + "media", + "postage", + "project2", + "schemars", + "serde", + "serde_derive", + "serde_json", + "settings2", + "util", +] + [[package]] name = "cap-fs-ext" version = "0.24.4" @@ -1176,6 +1254,25 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2698f953def977c68f935bb0dfa959375ad4638570e969e2f1e9f433cbf1af6" +[[package]] +name = 
"cbindgen" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da6bc11b07529f16944307272d5bd9b22530bc7d05751717c9d416586cedab49" +dependencies = [ + "clap 3.2.25", + "heck 0.4.1", + "indexmap 1.9.3", + "log", + "proc-macro2", + "quote", + "serde", + "serde_json", + "syn 1.0.109", + "tempfile", + "toml 0.5.11", +] + [[package]] name = "cc" version = "1.0.83" @@ -1422,6 +1519,43 @@ dependencies = [ "uuid 1.4.1", ] +[[package]] +name = "client2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-recursion 0.3.2", + "async-tungstenite", + "collections", + "db2", + "feature_flags2", + "futures 0.3.28", + "gpui2", + "image", + "lazy_static", + "log", + "parking_lot 0.11.2", + "postage", + "rand 0.8.5", + "rpc2", + "schemars", + "serde", + "serde_derive", + "settings", + "settings2", + "smol", + "sum_tree", + "sysinfo", + "tempfile", + "text", + "thiserror", + "time", + "tiny_http", + "url", + "util", + "uuid 1.4.1", +] + [[package]] name = "clock" version = "0.1.0" @@ -1724,6 +1858,33 @@ dependencies = [ "util", ] +[[package]] +name = "copilot2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-compression", + "async-tar", + "clock", + "collections", + "context_menu", + "fs", + "futures 0.3.28", + "gpui2", + "language2", + "log", + "lsp2", + "node_runtime", + "parking_lot 0.11.2", + "rpc", + "serde", + "serde_derive", + "settings2", + "smol", + "theme", + "util", +] + [[package]] name = "copilot_button" version = "0.1.0" @@ -2154,6 +2315,28 @@ dependencies = [ "util", ] +[[package]] +name = "db2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "collections", + "env_logger 0.9.3", + "gpui2", + "indoc", + "lazy_static", + "log", + "parking_lot 0.11.2", + "serde", + "serde_derive", + "smol", + "sqlez", + "sqlez_macros", + "tempdir", + "util", +] + [[package]] name = "deflate" version = "0.8.6" @@ -2621,6 +2804,14 @@ dependencies = [ "gpui", ] +[[package]] +name = "feature_flags2" +version = 
"0.1.0" +dependencies = [ + "anyhow", + "gpui2", +] + [[package]] name = "feedback" version = "0.1.0" @@ -2866,6 +3057,34 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "fs2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "collections", + "fsevent", + "futures 0.3.28", + "git2", + "gpui2", + "lazy_static", + "libc", + "log", + "parking_lot 0.11.2", + "regex", + "rope", + "serde", + "serde_derive", + "serde_json", + "smol", + "sum_tree", + "tempfile", + "text", + "time", + "util", +] + [[package]] name = "fsevent" version = "2.0.2" @@ -3044,6 +3263,14 @@ dependencies = [ "util", ] +[[package]] +name = "fuzzy2" +version = "0.1.0" +dependencies = [ + "gpui2", + "util", +] + [[package]] name = "fxhash" version = "0.2.1" @@ -3258,20 +3485,64 @@ name = "gpui2" version = "0.1.0" dependencies = [ "anyhow", + "async-task", + "backtrace", + "bindgen 0.65.1", + "bitflags 2.4.0", + "block", + "cbindgen", + "cocoa", + "collections", + "core-foundation", + "core-graphics", + "core-text", + "ctor", "derive_more", + "dhat", + "env_logger 0.9.3", + "etagere", + "font-kit", + "foreign-types", "futures 0.3.28", - "gpui", "gpui2_macros", + "gpui_macros", + "image", + "itertools 0.10.5", + "lazy_static", "log", + "media", + "metal", + "num_cpus", + "objc", + "ordered-float 2.10.0", + "parking", "parking_lot 0.11.2", + "pathfinder_geometry", + "plane-split", + "png", + "postage", + "rand 0.8.5", "refineable", - "rust-embed", + "resvg", + "schemars", + "seahash", "serde", - "settings", + "serde_derive", + "serde_json", "simplelog", + "slotmap", "smallvec", - "theme", + "smol", + "sqlez", + "sum_tree", + "taffy", + "thiserror", + "time", + "tiny-skia", + "usvg", "util", + "uuid 1.4.1", + "waker-fn", ] [[package]] @@ -3698,6 +3969,17 @@ dependencies = [ "util", ] +[[package]] +name = "install_cli2" +version = "0.1.0" +dependencies = [ + "anyhow", + "gpui2", + "log", + "smol", + "util", +] + [[package]] name = "instant" version = "0.1.12" @@ -3916,6 +4198,24 
@@ dependencies = [ "workspace", ] +[[package]] +name = "journal2" +version = "0.1.0" +dependencies = [ + "anyhow", + "chrono", + "dirs 4.0.0", + "editor", + "gpui2", + "log", + "schemars", + "serde", + "settings2", + "shellexpand", + "util", + "workspace", +] + [[package]] name = "jpeg-decoder" version = "0.1.22" @@ -4032,6 +4332,59 @@ dependencies = [ "util", ] +[[package]] +name = "language2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-broadcast", + "async-trait", + "client2", + "clock", + "collections", + "ctor", + "env_logger 0.9.3", + "futures 0.3.28", + "fuzzy2", + "git", + "globset", + "gpui2", + "indoc", + "lazy_static", + "log", + "lsp2", + "parking_lot 0.11.2", + "postage", + "rand 0.8.5", + "regex", + "rpc2", + "schemars", + "serde", + "serde_derive", + "serde_json", + "settings2", + "similar", + "smallvec", + "smol", + "sum_tree", + "text", + "theme2", + "tree-sitter", + "tree-sitter-elixir", + "tree-sitter-embedded-template", + "tree-sitter-heex", + "tree-sitter-html", + "tree-sitter-json 0.20.0", + "tree-sitter-markdown", + "tree-sitter-python", + "tree-sitter-ruby", + "tree-sitter-rust", + "tree-sitter-typescript", + "unicase", + "unindent", + "util", +] + [[package]] name = "language_selector" version = "0.1.0" @@ -4310,6 +4663,29 @@ dependencies = [ "url", ] +[[package]] +name = "lsp2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-pipe", + "collections", + "ctor", + "env_logger 0.9.3", + "futures 0.3.28", + "gpui2", + "log", + "lsp-types", + "parking_lot 0.11.2", + "postage", + "serde", + "serde_derive", + "serde_json", + "smol", + "unindent", + "util", +] + [[package]] name = "mach" version = "0.3.2" @@ -4446,6 +4822,13 @@ dependencies = [ "gpui", ] +[[package]] +name = "menu2" +version = "0.1.0" +dependencies = [ + "gpui2", +] + [[package]] name = "metal" version = "0.21.0" @@ -4738,6 +5121,19 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "nix" +version = "0.23.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f3790c00a0150112de0f4cd161e3d7fc4b2d8a5542ffc35f099a2562aecb35c" +dependencies = [ + "bitflags 1.3.2", + "cc", + "cfg-if 1.0.0", + "libc", + "memoffset 0.6.5", +] + [[package]] name = "nix" version = "0.24.3" @@ -4769,7 +5165,6 @@ dependencies = [ "async-tar", "async-trait", "futures 0.3.28", - "gpui", "log", "parking_lot 0.11.2", "serde", @@ -5491,6 +5886,17 @@ version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +[[package]] +name = "plane-split" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c1f7d82649829ecdef8e258790b0587acf0a8403f0ce963473d8e918acc1643" +dependencies = [ + "euclid", + "log", + "smallvec", +] + [[package]] name = "plist" version = "1.5.0" @@ -5621,6 +6027,27 @@ dependencies = [ "util", ] +[[package]] +name = "prettier2" +version = "0.1.0" +dependencies = [ + "anyhow", + "client2", + "collections", + "fs2", + "futures 0.3.28", + "gpui2", + "language2", + "log", + "lsp2", + "node_runtime", + "parking_lot 0.11.2", + "serde", + "serde_derive", + "serde_json", + "util", +] + [[package]] name = "pretty_assertions" version = "1.4.0" @@ -5700,58 +6127,113 @@ source = "git+https://github.com/zed-industries/wezterm?rev=5cd757e5f2eb039ed0c6 dependencies = [ "libc", "log", - "ntapi 0.3.7", - "winapi 0.3.9", + "ntapi 0.3.7", + "winapi 0.3.9", +] + +[[package]] +name = "project" +version = "0.1.0" +dependencies = [ + "aho-corasick", + "anyhow", + "async-trait", + "backtrace", + "client", + "clock", + "collections", + "copilot", + "ctor", + "db", + "env_logger 0.9.3", + "fs", + "fsevent", + "futures 0.3.28", + "fuzzy", + "git", + "git2", + "globset", + "gpui", + "ignore", + "itertools 0.10.5", + "language", + "lazy_static", + "log", + "lsp", + "node_runtime", + "parking_lot 0.11.2", + "postage", + "prettier", + "pretty_assertions", 
+ "rand 0.8.5", + "regex", + "rpc", + "schemars", + "serde", + "serde_derive", + "serde_json", + "settings", + "sha2 0.10.7", + "similar", + "smol", + "sum_tree", + "tempdir", + "terminal", + "text", + "thiserror", + "toml 0.5.11", + "unindent", + "util", ] [[package]] -name = "project" +name = "project2" version = "0.1.0" dependencies = [ "aho-corasick", "anyhow", "async-trait", "backtrace", - "client", + "client2", "clock", "collections", - "copilot", + "copilot2", "ctor", - "db", + "db2", "env_logger 0.9.3", - "fs", + "fs2", "fsevent", "futures 0.3.28", - "fuzzy", + "fuzzy2", "git", "git2", "globset", - "gpui", + "gpui2", "ignore", "itertools 0.10.5", - "language", + "language2", "lazy_static", "log", - "lsp", + "lsp2", "node_runtime", "parking_lot 0.11.2", "postage", - "prettier", + "prettier2", "pretty_assertions", "rand 0.8.5", "regex", - "rpc", + "rpc2", "schemars", "serde", "serde_derive", "serde_json", - "settings", + "settings2", "sha2 0.10.7", "similar", "smol", "sum_tree", "tempdir", - "terminal", + "terminal2", "text", "thiserror", "toml 0.5.11", @@ -6494,6 +6976,35 @@ dependencies = [ "zstd", ] +[[package]] +name = "rpc2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-lock", + "async-tungstenite", + "base64 0.13.1", + "clock", + "collections", + "ctor", + "env_logger 0.9.3", + "futures 0.3.28", + "gpui2", + "parking_lot 0.11.2", + "prost 0.8.0", + "prost-build", + "rand 0.8.5", + "rsa 0.4.0", + "serde", + "serde_derive", + "smol", + "smol-timeout", + "tempdir", + "tracing", + "util", + "zstd", +] + [[package]] name = "rsa" version = "0.4.0" @@ -7198,6 +7709,36 @@ dependencies = [ "util", ] +[[package]] +name = "settings2" +version = "0.1.0" +dependencies = [ + "anyhow", + "collections", + "feature_flags2", + "fs", + "fs2", + "futures 0.3.28", + "gpui2", + "indoc", + "lazy_static", + "postage", + "pretty_assertions", + "rust-embed", + "schemars", + "serde", + "serde_derive", + "serde_json", + "serde_json_lenient", + "smallvec", + "sqlez", + 
"toml 0.5.11", + "tree-sitter", + "tree-sitter-json 0.19.0", + "unindent", + "util", +] + [[package]] name = "sha-1" version = "0.9.8" @@ -7766,6 +8307,29 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "storybook2" +version = "0.1.0" +dependencies = [ + "anyhow", + "backtrace-on-stack-overflow", + "chrono", + "clap 4.4.4", + "gpui2", + "itertools 0.11.0", + "log", + "rust-embed", + "serde", + "settings2", + "simplelog", + "smallvec", + "strum", + "theme", + "theme2", + "ui2", + "util", +] + [[package]] name = "stringprep" version = "0.1.4" @@ -8075,6 +8639,35 @@ dependencies = [ "util", ] +[[package]] +name = "terminal2" +version = "0.1.0" +dependencies = [ + "alacritty_terminal", + "anyhow", + "db2", + "dirs 4.0.0", + "futures 0.3.28", + "gpui2", + "itertools 0.10.5", + "lazy_static", + "libc", + "mio-extras", + "ordered-float 2.10.0", + "procinfo", + "rand 0.8.5", + "schemars", + "serde", + "serde_derive", + "settings2", + "shellexpand", + "smallvec", + "smol", + "theme2", + "thiserror", + "util", +] + [[package]] name = "terminal_view" version = "0.1.0" @@ -8157,6 +8750,39 @@ dependencies = [ "util", ] +[[package]] +name = "theme2" +version = "0.1.0" +dependencies = [ + "anyhow", + "fs", + "gpui2", + "indexmap 1.9.3", + "parking_lot 0.11.2", + "schemars", + "serde", + "serde_derive", + "serde_json", + "settings2", + "toml 0.5.11", + "util", +] + +[[package]] +name = "theme_converter" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap 4.4.4", + "convert_case 0.6.0", + "gpui2", + "log", + "rust-embed", + "serde", + "simplelog", + "theme2", +] + [[package]] name = "theme_selector" version = "0.1.0" @@ -8964,6 +9590,21 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "ui2" +version = 
"0.1.0" +dependencies = [ + "anyhow", + "chrono", + "gpui2", + "itertools 0.11.0", + "rand 0.8.5", + "serde", + "smallvec", + "strum", + "theme2", +] + [[package]] name = "unicase" version = "2.7.0" @@ -10285,6 +10926,105 @@ dependencies = [ "serde", ] +[[package]] +name = "zed2" +version = "0.109.0" +dependencies = [ + "ai2", + "anyhow", + "async-compression", + "async-recursion 0.3.2", + "async-tar", + "async-trait", + "backtrace", + "call2", + "chrono", + "cli", + "client2", + "collections", + "copilot2", + "ctor", + "db2", + "env_logger 0.9.3", + "feature_flags2", + "fs2", + "fsevent", + "futures 0.3.28", + "fuzzy", + "gpui2", + "ignore", + "image", + "indexmap 1.9.3", + "install_cli", + "isahc", + "journal2", + "language2", + "language_tools", + "lazy_static", + "libc", + "log", + "lsp2", + "node_runtime", + "num_cpus", + "parking_lot 0.11.2", + "postage", + "project2", + "rand 0.8.5", + "regex", + "rpc2", + "rsa 0.4.0", + "rust-embed", + "schemars", + "serde", + "serde_derive", + "serde_json", + "settings2", + "shellexpand", + "simplelog", + "smallvec", + "smol", + "sum_tree", + "tempdir", + "text", + "theme2", + "thiserror", + "tiny_http", + "toml 0.5.11", + "tree-sitter", + "tree-sitter-bash", + "tree-sitter-c", + "tree-sitter-cpp", + "tree-sitter-css", + "tree-sitter-elixir", + "tree-sitter-elm", + "tree-sitter-embedded-template", + "tree-sitter-glsl", + "tree-sitter-go", + "tree-sitter-heex", + "tree-sitter-html", + "tree-sitter-json 0.20.0", + "tree-sitter-lua", + "tree-sitter-markdown", + "tree-sitter-nix", + "tree-sitter-nu", + "tree-sitter-php", + "tree-sitter-python", + "tree-sitter-racket", + "tree-sitter-ruby", + "tree-sitter-rust", + "tree-sitter-scheme", + "tree-sitter-svelte", + "tree-sitter-toml", + "tree-sitter-typescript", + "tree-sitter-vue", + "tree-sitter-yaml", + "unindent", + "url", + "urlencoding", + "util", + "uuid 1.4.1", +] + [[package]] name = "zeroize" version = "1.6.0" diff --git a/Cargo.toml b/Cargo.toml index 
1d9da19605efb506e67f705806b762ab36f55bb1..ac490ce935eebdd2462714e0a21b9feff6d9eca7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,12 +4,15 @@ members = [ "crates/ai", "crates/assistant", "crates/audio", + "crates/audio2", "crates/auto_update", "crates/breadcrumbs", "crates/call", + "crates/call2", "crates/channel", "crates/cli", "crates/client", + "crates/client2", "crates/clock", "crates/collab", "crates/collab_ui", @@ -18,18 +21,24 @@ members = [ "crates/component_test", "crates/context_menu", "crates/copilot", + "crates/copilot2", "crates/copilot_button", "crates/db", + "crates/db2", "crates/refineable", "crates/refineable/derive_refineable", "crates/diagnostics", "crates/drag_and_drop", "crates/editor", + "crates/feature_flags", + "crates/feature_flags2", "crates/feedback", "crates/file_finder", "crates/fs", + "crates/fs2", "crates/fsevent", "crates/fuzzy", + "crates/fuzzy2", "crates/git", "crates/go_to_line", "crates/gpui", @@ -37,15 +46,20 @@ members = [ "crates/gpui2", "crates/gpui2_macros", "crates/install_cli", + "crates/install_cli2", "crates/journal", + "crates/journal2", "crates/language", + "crates/language2", "crates/language_selector", "crates/language_tools", "crates/live_kit_client", "crates/live_kit_server", "crates/lsp", + "crates/lsp2", "crates/media", "crates/menu", + "crates/menu2", "crates/multi_buffer", "crates/node_runtime", "crates/notifications", @@ -55,24 +69,32 @@ members = [ "crates/plugin_macros", "crates/plugin_runtime", "crates/prettier", + "crates/prettier2", "crates/project", + "crates/project2", "crates/project_panel", "crates/project_symbols", "crates/recent_projects", "crates/rope", "crates/rpc", + "crates/rpc2", "crates/search", "crates/settings", + "crates/settings2", "crates/snippet", "crates/sqlez", "crates/sqlez_macros", - "crates/feature_flags", "crates/rich_text", + "crates/storybook2", "crates/sum_tree", "crates/terminal", + "crates/terminal2", "crates/text", "crates/theme", + "crates/theme2", + 
"crates/theme_converter", "crates/theme_selector", + "crates/ui2", "crates/util", "crates/semantic_index", "crates/vim", @@ -81,6 +103,7 @@ members = [ "crates/welcome", "crates/xtask", "crates/zed", + "crates/zed2", "crates/zed-actions" ] default-members = ["crates/zed"] diff --git a/crates/Cargo.toml b/crates/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..fb49a4b515540836a757610db5c268321f9f068b --- /dev/null +++ b/crates/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "ai" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +gpui = { path = "../gpui" } +util = { path = "../util" } +language = { path = "../language" } +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +lazy_static.workspace = true +ordered-float.workspace = true +parking_lot.workspace = true +isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true +postage.workspace = true +rand.workspace = true +log.workspace = true +parse_duration = "2.1.1" +tiktoken-rs = "0.5.0" +matrixmultiply = "0.3.7" +rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } +bincode = "1.3.3" + +[dev-dependencies] +gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/ai2/Cargo.toml b/crates/ai2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..4f06840e8e53bbcb06c377a6304fb8be13b85946 --- /dev/null +++ b/crates/ai2/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "ai2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai2.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +gpui2 = { path = "../gpui2" } +util = { path = "../util" } +language2 = { path = "../language2" } +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +lazy_static.workspace = true +ordered-float.workspace 
= true +parking_lot.workspace = true +isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true +postage.workspace = true +rand.workspace = true +log.workspace = true +parse_duration = "2.1.1" +tiktoken-rs = "0.5.0" +matrixmultiply = "0.3.7" +rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } +bincode = "1.3.3" + +[dev-dependencies] +gpui2 = { path = "../gpui2", features = ["test-support"] } diff --git a/crates/ai2/src/ai2.rs b/crates/ai2/src/ai2.rs new file mode 100644 index 0000000000000000000000000000000000000000..dda22d2a1d04dd6083fb1ae9879f49e74c8b4627 --- /dev/null +++ b/crates/ai2/src/ai2.rs @@ -0,0 +1,8 @@ +pub mod auth; +pub mod completion; +pub mod embedding; +pub mod models; +pub mod prompts; +pub mod providers; +#[cfg(any(test, feature = "test-support"))] +pub mod test; diff --git a/crates/ai2/src/auth.rs b/crates/ai2/src/auth.rs new file mode 100644 index 0000000000000000000000000000000000000000..e4670bb449025d5ecc5f0cabe65ad6ff4727c10c --- /dev/null +++ b/crates/ai2/src/auth.rs @@ -0,0 +1,17 @@ +use async_trait::async_trait; +use gpui2::AppContext; + +#[derive(Clone, Debug)] +pub enum ProviderCredential { + Credentials { api_key: String }, + NoCredentials, + NotNeeded, +} + +#[async_trait] +pub trait CredentialProvider: Send + Sync { + fn has_credentials(&self) -> bool; + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential; + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential); + async fn delete_credentials(&self, cx: &mut AppContext); +} diff --git a/crates/ai2/src/completion.rs b/crates/ai2/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..30a60fcf1d5c5dc66717773968e432e510d6421f --- /dev/null +++ b/crates/ai2/src/completion.rs @@ -0,0 +1,23 @@ +use anyhow::Result; +use futures::{future::BoxFuture, stream::BoxStream}; + +use crate::{auth::CredentialProvider, models::LanguageModel}; 
+ +pub trait CompletionRequest: Send + Sync { + fn data(&self) -> serde_json::Result; +} + +pub trait CompletionProvider: CredentialProvider { + fn base_model(&self) -> Box; + fn complete( + &self, + prompt: Box, + ) -> BoxFuture<'static, Result>>>; + fn box_clone(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.box_clone() + } +} diff --git a/crates/ai2/src/embedding.rs b/crates/ai2/src/embedding.rs new file mode 100644 index 0000000000000000000000000000000000000000..7ea47861782cf9002796a8b6e655989b871e0191 --- /dev/null +++ b/crates/ai2/src/embedding.rs @@ -0,0 +1,123 @@ +use std::time::Instant; + +use anyhow::Result; +use async_trait::async_trait; +use ordered_float::OrderedFloat; +use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; +use rusqlite::ToSql; + +use crate::auth::CredentialProvider; +use crate::models::LanguageModel; + +#[derive(Debug, PartialEq, Clone)] +pub struct Embedding(pub Vec); + +// This is needed for semantic index functionality +// Unfortunately it has to live wherever the "Embedding" struct is created. 
+// Keeping this in here though, introduces a 'rusqlite' dependency into AI +// which is less than ideal +impl FromSql for Embedding { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob()?; + let embedding: Result, Box> = bincode::deserialize(bytes); + if embedding.is_err() { + return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); + } + Ok(Embedding(embedding.unwrap())) + } +} + +impl ToSql for Embedding { + fn to_sql(&self) -> rusqlite::Result { + let bytes = bincode::serialize(&self.0) + .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; + Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) + } +} +impl From> for Embedding { + fn from(value: Vec) -> Self { + Embedding(value) + } +} + +impl Embedding { + pub fn similarity(&self, other: &Self) -> OrderedFloat { + let len = self.0.len(); + assert_eq!(len, other.0.len()); + + let mut result = 0.0; + unsafe { + matrixmultiply::sgemm( + 1, + len, + 1, + 1.0, + self.0.as_ptr(), + len as isize, + 1, + other.0.as_ptr(), + 1, + len as isize, + 0.0, + &mut result as *mut f32, + 1, + 1, + ); + } + OrderedFloat(result) + } +} + +#[async_trait] +pub trait EmbeddingProvider: CredentialProvider { + fn base_model(&self) -> Box; + async fn embed_batch(&self, spans: Vec) -> Result>; + fn max_tokens_per_batch(&self) -> usize; + fn rate_limit_expiration(&self) -> Option; +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[gpui2::test] + fn test_similarity(mut rng: StdRng) { + assert_eq!( + Embedding::from(vec![1., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])), + 0. + ); + assert_eq!( + Embedding::from(vec![2., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])), + 6. 
+ ); + + for _ in 0..100 { + let size = 1536; + let mut a = vec![0.; size]; + let mut b = vec![0.; size]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + let a = Embedding::from(a); + let b = Embedding::from(b); + + assert_eq!( + round_to_decimals(a.similarity(&b), 1), + round_to_decimals(reference_dot(&a.0, &b.0), 1) + ); + } + + fn round_to_decimals(n: OrderedFloat, decimal_places: i32) -> f32 { + let factor = (10.0 as f32).powi(decimal_places); + (n * factor).round() / factor + } + + fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat { + OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()) + } + } +} diff --git a/crates/ai2/src/models.rs b/crates/ai2/src/models.rs new file mode 100644 index 0000000000000000000000000000000000000000..1db3d58c6f54ad613cb98fc3f425df3d47e5e97f --- /dev/null +++ b/crates/ai2/src/models.rs @@ -0,0 +1,16 @@ +pub enum TruncationDirection { + Start, + End, +} + +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai2/src/prompts/base.rs b/crates/ai2/src/prompts/base.rs new file mode 100644 index 0000000000000000000000000000000000000000..29091d0f5b435b556a0c2cae60aa4526370832ab --- /dev/null +++ b/crates/ai2/src/prompts/base.rs @@ -0,0 +1,330 @@ +use std::cmp::Reverse; +use std::ops::Range; +use std::sync::Arc; + +use language2::BufferSnapshot; +use util::ResultExt; + +use crate::models::LanguageModel; +use crate::prompts::repository_context::PromptCodeSnippet; + +pub(crate) enum PromptFileType { + Text, + Code, +} + +// TODO: Set this up to manage for defaults well +pub struct PromptArguments { + pub model: Arc, + pub user_prompt: Option, + pub language_name: Option, + pub project_name: Option, + pub snippets: Vec, + pub reserved_tokens: usize, + 
pub buffer: Option, + pub selected_range: Option>, +} + +impl PromptArguments { + pub(crate) fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +pub trait PromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)>; +} + +#[repr(i8)] +#[derive(PartialEq, Eq, Ord)] +pub enum PromptPriority { + Mandatory, // Ignores truncation + Ordered { order: usize }, // Truncates based on priority +} + +impl PartialOrd for PromptPriority { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal), + (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater), + (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less), + (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a), + } + } +} + +pub struct PromptChain { + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, +} + +impl PromptChain { + pub fn new( + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, + ) -> Self { + PromptChain { args, templates } + } + + pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> { + // Argsort based on Prompt Priority + let seperator = "\n"; + let seperator_tokens = self.args.model.count_tokens(seperator)?; + let mut sorted_indices = (0..self.templates.len()).collect::>(); + sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); + + // If Truncate + let mut tokens_outstanding = if truncate { + Some(self.args.model.capacity()? 
- self.args.reserved_tokens) + } else { + None + }; + + let mut prompts = vec!["".to_string(); sorted_indices.len()]; + for idx in sorted_indices { + let (_, template) = &self.templates[idx]; + + if let Some((template_prompt, prompt_token_count)) = + template.generate(&self.args, tokens_outstanding).log_err() + { + if template_prompt != "" { + prompts[idx] = template_prompt; + + if let Some(remaining_tokens) = tokens_outstanding { + let new_tokens = prompt_token_count + seperator_tokens; + tokens_outstanding = if remaining_tokens > new_tokens { + Some(remaining_tokens - new_tokens) + } else { + Some(0) + }; + } + } + } + } + + prompts.retain(|x| x != ""); + + let full_prompt = prompts.join(seperator); + let total_token_count = self.args.model.count_tokens(&full_prompt)?; + anyhow::Ok((prompts.join(seperator), total_token_count)) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use crate::models::TruncationDirection; + use crate::test::FakeLanguageModel; + + use super::*; + + #[test] + pub fn test_prompt_chain() { + struct TestPromptTemplate {} + impl PromptTemplate for TestPromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + struct TestLowPriorityTemplate {} + impl PromptTemplate for TestLowPriorityTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a low priority test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let 
Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + let model: Arc = Arc::new(FakeLanguageModel { capacity: 100 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let model: Arc = Arc::new(FakeLanguageModel { capacity: 20 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // 
Testing with Truncation Off + // Should ignore capacity and return all prompts + let capacity = 20; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 2 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!(prompt, "This is a test promp".to_string()); + assert_eq!(token_count, capacity); + + // Change Ordering of Prompts Based on Priority + let capacity = 120; + let reserved_tokens = 10; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens, + buffer: None, + selected_range: None, + user_prompt: None, + }; + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Mandatory, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!( + prompt, + "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt " + .to_string() + ); + assert_eq!(token_count, capacity - reserved_tokens); + } +} diff --git a/crates/ai2/src/prompts/file_context.rs 
b/crates/ai2/src/prompts/file_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..4a741beb24984c5c038ef10439d5b182438d2866 --- /dev/null +++ b/crates/ai2/src/prompts/file_context.rs @@ -0,0 +1,164 @@ +use anyhow::anyhow; +use language2::BufferSnapshot; +use language2::ToOffset; + +use crate::models::LanguageModel; +use crate::models::TruncationDirection; +use crate::prompts::base::PromptArguments; +use crate::prompts::base::PromptTemplate; +use std::fmt::Write; +use std::ops::Range; +use std::sync::Arc; + +fn retrieve_context( + buffer: &BufferSnapshot, + selected_range: &Option>, + model: Arc, + max_token_count: Option, +) -> anyhow::Result<(String, usize, bool)> { + let mut prompt = String::new(); + let mut truncated = false; + if let Some(selected_range) = selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + let start_window = buffer.text_for_range(0..start).collect::(); + + let mut selected_window = String::new(); + if start == end { + write!(selected_window, "<|START|>").unwrap(); + } else { + write!(selected_window, "<|START|").unwrap(); + } + + write!( + selected_window, + "{}", + buffer.text_for_range(start..end).collect::() + ) + .unwrap(); + + if start != end { + write!(selected_window, "|END|>").unwrap(); + } + + let end_window = buffer.text_for_range(end..buffer.len()).collect::(); + + if let Some(max_token_count) = max_token_count { + let selected_tokens = model.count_tokens(&selected_window)?; + if selected_tokens > max_token_count { + return Err(anyhow!( + "selected range is greater than model context window, truncation not possible" + )); + }; + + let mut remaining_tokens = max_token_count - selected_tokens; + let start_window_tokens = model.count_tokens(&start_window)?; + let end_window_tokens = model.count_tokens(&end_window)?; + let outside_tokens = start_window_tokens + end_window_tokens; + if outside_tokens > remaining_tokens { + let 
(start_goal_tokens, end_goal_tokens) = + if start_window_tokens < end_window_tokens { + let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens); + remaining_tokens -= start_goal_tokens; + let end_goal_tokens = remaining_tokens.min(end_window_tokens); + (start_goal_tokens, end_goal_tokens) + } else { + let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens); + remaining_tokens -= end_goal_tokens; + let start_goal_tokens = remaining_tokens.min(start_window_tokens); + (start_goal_tokens, end_goal_tokens) + }; + + let truncated_start_window = + model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?; + let truncated_end_window = + model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?; + writeln!( + prompt, + "{truncated_start_window}{selected_window}{truncated_end_window}" + ) + .unwrap(); + truncated = true; + } else { + writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap(); + } + } else { + // If we dont have a selected range, include entire file. + writeln!(prompt, "{}", &buffer.text()).unwrap(); + + // Dumb truncation strategy + if let Some(max_token_count) = max_token_count { + if model.count_tokens(&prompt)? > max_token_count { + truncated = true; + prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?; + } + } + } + } + + let token_count = model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count, truncated)) +} + +pub struct FileContext {} + +impl PromptTemplate for FileContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + if let Some(buffer) = &args.buffer { + let mut prompt = String::new(); + // Add Initial Preamble + // TODO: Do we want to add the path in here? 
+ writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); + + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + + let (context, _, truncated) = retrieve_context( + buffer, + &args.selected_range, + args.model.clone(), + max_token_length, + )?; + writeln!(prompt, "```{language_name}\n{context}\n```").unwrap(); + + if truncated { + writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap(); + } + + if let Some(selected_range) = &args.selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + if start == end { + writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + } + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args + .model + .truncate(&prompt, max_tokens, TruncationDirection::End)?; + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } else { + Err(anyhow!("no buffer provided to retrieve file context from")) + } + } +} diff --git a/crates/ai2/src/prompts/generate.rs b/crates/ai2/src/prompts/generate.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7be620107ee4d6daca06a8cb38019aceedc40a4 --- /dev/null +++ b/crates/ai2/src/prompts/generate.rs @@ -0,0 +1,99 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use anyhow::anyhow; +use std::fmt::Write; + +pub fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +pub struct GenerateInlineContent {} + +impl 
PromptTemplate for GenerateInlineContent { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let Some(user_prompt) = &args.user_prompt else { + return Err(anyhow!("user prompt not provided")); + }; + + let file_type = args.get_file_type(); + let content_type = match &file_type { + PromptFileType::Code => "code", + PromptFileType::Text => "text", + }; + + let mut prompt = String::new(); + + if let Some(selected_range) = &args.selected_range { + if selected_range.start == selected_range.end { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." + ) + .unwrap(); + writeln!( + prompt, + "{} can't be replaced, so assume your answer will be inserted at the cursor.", + capitalize(content_type) + ) + .unwrap(); + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}", + ) + .unwrap(); + } else { + writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap(); + writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap(); + writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap(); + } + } else { + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}" + ) + .unwrap(); + } + + if let Some(language_name) = &args.language_name { + writeln!( + prompt, + "Your answer MUST always and only be valid {}.", + language_name + ) + .unwrap(); + } + writeln!(prompt, "Never make remarks about the output.").unwrap(); + writeln!( + prompt, + "Do not return anything else, except the generated {content_type}." 
+ ) + .unwrap(); + + match file_type { + PromptFileType::Code => { + // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap(); + } + _ => {} + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate( + &prompt, + max_tokens, + crate::models::TruncationDirection::End, + )?; + } + + let token_count = args.model.count_tokens(&prompt)?; + + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai2/src/prompts/mod.rs b/crates/ai2/src/prompts/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..0025269a440d1e6ead6a81615a64a3c28da62bb8 --- /dev/null +++ b/crates/ai2/src/prompts/mod.rs @@ -0,0 +1,5 @@ +pub mod base; +pub mod file_context; +pub mod generate; +pub mod preamble; +pub mod repository_context; diff --git a/crates/ai2/src/prompts/preamble.rs b/crates/ai2/src/prompts/preamble.rs new file mode 100644 index 0000000000000000000000000000000000000000..92e0edeb78b48169379aae2e88e81f62463a1057 --- /dev/null +++ b/crates/ai2/src/prompts/preamble.rs @@ -0,0 +1,52 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use std::fmt::Write; + +pub struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut prompts = Vec::new(); + + match args.get_file_type() { + PromptFileType::Code => { + prompts.push(format!( + "You are an expert {}engineer.", + args.language_name.clone().unwrap_or("".to_string()) + " " + )); + } + PromptFileType::Text => { + prompts.push("You are an expert engineer.".to_string()); + } + } + + if let Some(project_name) = args.project_name.clone() { + prompts.push(format!( + "You are currently working inside the '{project_name}' project in code editor Zed." 
+ )); + } + + if let Some(mut remaining_tokens) = max_token_length { + let mut prompt = String::new(); + let mut total_count = 0; + for prompt_piece in prompts { + let prompt_token_count = + args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?; + if remaining_tokens > prompt_token_count { + writeln!(prompt, "{prompt_piece}").unwrap(); + remaining_tokens -= prompt_token_count; + total_count += prompt_token_count; + } + } + + anyhow::Ok((prompt, total_count)) + } else { + let prompt = prompts.join("\n"); + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } + } +} diff --git a/crates/ai2/src/prompts/repository_context.rs b/crates/ai2/src/prompts/repository_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..1bb75de7d242d34315238f2c43baeb0c016dbdfb --- /dev/null +++ b/crates/ai2/src/prompts/repository_context.rs @@ -0,0 +1,98 @@ +use crate::prompts::base::{PromptArguments, PromptTemplate}; +use std::fmt::Write; +use std::{ops::Range, path::PathBuf}; + +use gpui2::{AsyncAppContext, Model}; +use language2::{Anchor, Buffer}; + +#[derive(Clone)] +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new( + buffer: Model, + range: Range, + cx: &mut AsyncAppContext, + ) -> anyhow::Result { + let (content, language_name, file_path) = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string().to_lowercase())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + })?; + + anyhow::Ok(PromptCodeSnippet { + path: file_path, + language_name, + content, + }) + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + 
.as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +} + +pub struct RepositoryContext {} + +impl PromptTemplate for RepositoryContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let template = "You are working inside a large repository, here are a few code snippets that may be useful."; + let mut prompt = String::new(); + + let mut remaining_tokens = max_token_length.clone(); + let seperator_token_length = args.model.count_tokens("\n")?; + for snippet in &args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = args.model.count_tokens(&snippet_prompt)?; + if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT { + if let Some(tokens_left) = remaining_tokens { + if tokens_left >= token_count { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_tokens = if tokens_left >= (token_count + seperator_token_length) + { + Some(tokens_left - token_count - seperator_token_length) + } else { + Some(0) + }; + } + } else { + writeln!(prompt, "{snippet_prompt}").unwrap(); + } + } + } + + let total_token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, total_token_count)) + } +} diff --git a/crates/ai2/src/providers/mod.rs b/crates/ai2/src/providers/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..acd0f9d91053869e3999ef0c1a23326480a7cbdd --- /dev/null +++ b/crates/ai2/src/providers/mod.rs @@ -0,0 +1 @@ +pub mod open_ai; diff --git a/crates/ai2/src/providers/open_ai/completion.rs b/crates/ai2/src/providers/open_ai/completion.rs 
new file mode 100644 index 0000000000000000000000000000000000000000..eca56110271a3be407a1c8b9f82cbb63c41bef23 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/completion.rs @@ -0,0 +1,306 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::{ + future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, + Stream, StreamExt, +}; +use gpui2::{AppContext, Executor}; +use isahc::{http::StatusCode, Request, RequestExt}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::{ + env, + fmt::{self, Display}, + io, + sync::Arc, +}; +use util::ResultExt; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + models::LanguageModel, +}; + +use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL}; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl Role { + pub fn cycle(&mut self) { + *self = match self { + Role::User => Role::Assistant, + Role::Assistant => Role::System, + Role::System => Role::User, + } + } +} + +impl Display for Role { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Role::User => write!(f, "User"), + Role::Assistant => write!(f, "Assistant"), + Role::System => write!(f, "System"), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + pub stop: Vec, + pub temperature: f32, +} + +impl CompletionRequest for OpenAIRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessage { + pub role: Option, + pub content: Option, +} + 
+#[derive(Deserialize, Debug)] +pub struct OpenAIUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct ChatChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIResponseStreamEvent { + pub id: Option, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +pub async fn stream_completion( + credential: ProviderCredential, + executor: Arc, + request: Box, +) -> Result>> { + let api_key = match credential { + ProviderCredential::Credentials { api_key } => api_key, + _ => { + return Err(anyhow!("no credentials provider for completion")); + } + }; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = request.data()?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? 
+ .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + let done = event.as_ref().map_or(false, |event| { + event + .choices + .last() + .map_or(false, |choice| choice.finish_reason.is_some()) + }); + if tx.unbounded_send(event).is_err() { + break; + } + + if done { + break; + } + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAIResponse { + error: OpenAIError, + } + + #[derive(Deserialize)] + struct OpenAIError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} + +#[derive(Clone)] +pub struct OpenAICompletionProvider { + model: OpenAILanguageModel, + credential: Arc>, + executor: Arc, +} + +impl OpenAICompletionProvider { + pub fn new(model_name: &str, executor: Arc) -> Self { + let model = OpenAILanguageModel::load(model_name); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + Self { + model, + credential, + executor, + } + } +} + +#[async_trait] +impl CredentialProvider for OpenAICompletionProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. 
} => true, + _ => false, + } + } + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential { + let existing_credential = self.credential.read().clone(); + + let retrieved_credential = cx + .run_on_main(move |cx| match existing_credential { + ProviderCredential::Credentials { .. } => { + return existing_credential.clone(); + } + _ => { + if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() { + return ProviderCredential::Credentials { api_key }; + } + + if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + return ProviderCredential::Credentials { api_key }; + } else { + return ProviderCredential::NoCredentials; + } + } else { + return ProviderCredential::NoCredentials; + } + } + }) + .await; + + *self.credential.write() = retrieved_credential.clone(); + retrieved_credential + } + + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) { + *self.credential.write() = credential.clone(); + let credential = credential.clone(); + cx.run_on_main(move |cx| match credential { + ProviderCredential::Credentials { api_key } => { + cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + }) + .await; + } + async fn delete_credentials(&self, cx: &mut AppContext) { + cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err()) + .await; + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +impl CompletionProvider for OpenAICompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + fn complete( + &self, + prompt: Box, + ) -> BoxFuture<'static, Result>>> { + // Currently the CompletionRequest for OpenAI, includes a 'model' parameter + // This means that the model is determined by the CompletionRequest and not the CompletionProvider, + // which is currently model based, due to the langauge model. 
+ // At some point in the future we should rectify this. + let credential = self.credential.read().clone(); + let request = stream_completion(credential, self.executor.clone(), prompt); + async move { + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/ai2/src/providers/open_ai/embedding.rs b/crates/ai2/src/providers/open_ai/embedding.rs new file mode 100644 index 0000000000000000000000000000000000000000..fc49c15134d0aba787968acbd412daff60ce6106 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/embedding.rs @@ -0,0 +1,313 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::AsyncReadExt; +use gpui2::Executor; +use gpui2::{serde_json, AppContext}; +use isahc::http::StatusCode; +use isahc::prelude::Configurable; +use isahc::{AsyncBody, Response}; +use lazy_static::lazy_static; +use parking_lot::{Mutex, RwLock}; +use parse_duration::parse; +use postage::watch; +use serde::{Deserialize, Serialize}; +use std::env; +use std::ops::Add; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tiktoken_rs::{cl100k_base, CoreBPE}; +use util::http::{HttpClient, Request}; +use util::ResultExt; + +use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::embedding::{Embedding, EmbeddingProvider}; +use crate::models::LanguageModel; +use crate::providers::open_ai::OpenAILanguageModel; + +use crate::providers::open_ai::OPENAI_API_URL; + +lazy_static! 
{ + static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); +} + +#[derive(Clone)] +pub struct OpenAIEmbeddingProvider { + model: OpenAILanguageModel, + credential: Arc>, + pub client: Arc, + pub executor: Arc, + rate_limit_count_rx: watch::Receiver>, + rate_limit_count_tx: Arc>>>, +} + +#[derive(Serialize)] +struct OpenAIEmbeddingRequest<'a> { + model: &'static str, + input: Vec<&'a str>, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingResponse { + data: Vec, + usage: OpenAIEmbeddingUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIEmbedding { + embedding: Vec, + index: usize, + object: String, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingUsage { + prompt_tokens: usize, + total_tokens: usize, +} + +impl OpenAIEmbeddingProvider { + pub fn new(client: Arc, executor: Arc) -> Self { + let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); + let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); + + let model = OpenAILanguageModel::load("text-embedding-ada-002"); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + + OpenAIEmbeddingProvider { + model, + credential, + client, + executor, + rate_limit_count_rx, + rate_limit_count_tx, + } + } + + fn get_api_key(&self) -> Result { + match self.credential.read().clone() { + ProviderCredential::Credentials { api_key } => Ok(api_key), + _ => Err(anyhow!("api credentials not provided")), + } + } + + fn resolve_rate_limit(&self) { + let reset_time = *self.rate_limit_count_tx.lock().borrow(); + + if let Some(reset_time) = reset_time { + if Instant::now() >= reset_time { + *self.rate_limit_count_tx.lock().borrow_mut() = None + } + } + + log::trace!( + "resolving reset time: {:?}", + *self.rate_limit_count_tx.lock().borrow() + ); + } + + fn update_reset_time(&self, reset_time: Instant) { + let original_time = *self.rate_limit_count_tx.lock().borrow(); + + let updated_time = if let Some(original_time) = original_time { + if reset_time < 
original_time { + Some(reset_time) + } else { + Some(original_time) + } + } else { + Some(reset_time) + }; + + log::trace!("updating rate limit time: {:?}", updated_time); + + *self.rate_limit_count_tx.lock().borrow_mut() = updated_time; + } + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { + let request = Request::post("https://api.openai.com/v1/embeddings") + .redirect_policy(isahc::config::RedirectPolicy::Follow) + .timeout(Duration::from_secs(request_timeout)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body( + serde_json::to_string(&OpenAIEmbeddingRequest { + input: spans.clone(), + model: "text-embedding-ada-002", + }) + .unwrap() + .into(), + )?; + + Ok(self.client.send(request).await?) + } +} + +#[async_trait] +impl CredentialProvider for OpenAIEmbeddingProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential { + let existing_credential = self.credential.read().clone(); + + let retrieved_credential = cx + .run_on_main(move |cx| match existing_credential { + ProviderCredential::Credentials { .. 
} => { + return existing_credential.clone(); + } + _ => { + if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() { + return ProviderCredential::Credentials { api_key }; + } + + if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + return ProviderCredential::Credentials { api_key }; + } else { + return ProviderCredential::NoCredentials; + } + } else { + return ProviderCredential::NoCredentials; + } + } + }) + .await; + + *self.credential.write() = retrieved_credential.clone(); + retrieved_credential + } + + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) { + *self.credential.write() = credential.clone(); + let credential = credential.clone(); + cx.run_on_main(move |cx| match credential { + ProviderCredential::Credentials { api_key } => { + cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + }) + .await; + } + async fn delete_credentials(&self, cx: &mut AppContext) { + cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err()) + .await; + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddingProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + + fn max_tokens_per_batch(&self) -> usize { + 50000 + } + + fn rate_limit_expiration(&self) -> Option { + *self.rate_limit_count_rx.borrow() + } + + async fn embed_batch(&self, spans: Vec) -> Result> { + const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; + const MAX_RETRIES: usize = 4; + + let api_key = self.get_api_key()?; + + let mut request_number = 0; + let mut rate_limiting = false; + let mut request_timeout: u64 = 15; + let mut response: Response; + while request_number < MAX_RETRIES { + response = self + .send_request( + &api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) + 
.await?; + + request_number += 1; + + match response.status() { + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; + } + StatusCode::OK => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::trace!( + "openai embedding completed. tokens: {:?}", + response.usage.total_tokens + ); + + // If we complete a request successfully that was previously rate_limited + // resolve the rate limit + if rate_limiting { + self.resolve_rate_limit() + } + + return Ok(response + .data + .into_iter() + .map(|embedding| Embedding::from(embedding.embedding)) + .collect()); + } + StatusCode::TOO_MANY_REQUESTS => { + rate_limiting = true; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + // If we've previously rate limited, increment the duration but not the count + let reset_time = Instant::now().add(delay_duration); + self.update_reset_time(reset_time); + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } + _ => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); + } + } + } + Err(anyhow!("openai max retries")) + } +} diff --git a/crates/ai2/src/providers/open_ai/mod.rs b/crates/ai2/src/providers/open_ai/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..7d2f86045d9e1049c55111aef175ac9b56dc7e16 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/mod.rs @@ 
-0,0 +1,9 @@ +pub mod completion; +pub mod embedding; +pub mod model; + +pub use completion::*; +pub use embedding::*; +pub use model::OpenAILanguageModel; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; diff --git a/crates/ai2/src/providers/open_ai/model.rs b/crates/ai2/src/providers/open_ai/model.rs new file mode 100644 index 0000000000000000000000000000000000000000..6e306c80b905865c011c9064934827085ca126d6 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/model.rs @@ -0,0 +1,57 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +use crate::models::{LanguageModel, TruncationDirection}; + +#[derive(Clone)] +pub struct OpenAILanguageModel { + name: String, + bpe: Option, +} + +impl OpenAILanguageModel { + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); + OpenAILanguageModel { + name: model_name.to_string(), + bpe, + } + } +} + +impl LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + match direction { + TruncationDirection::End => bpe.decode(tokens[..length].to_vec()), + TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()), + } + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git a/crates/ai2/src/providers/open_ai/new.rs b/crates/ai2/src/providers/open_ai/new.rs new 
file mode 100644 index 0000000000000000000000000000000000000000..c7d67f2ba1d252a6865124d8ffdfb79130a8c3a0 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/new.rs @@ -0,0 +1,11 @@ +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai2/src/test.rs b/crates/ai2/src/test.rs new file mode 100644 index 0000000000000000000000000000000000000000..ee88529aecb004ce3b725fb61abd679359673404 --- /dev/null +++ b/crates/ai2/src/test.rs @@ -0,0 +1,193 @@ +use std::{ + sync::atomic::{self, AtomicUsize, Ordering}, + time::Instant, +}; + +use async_trait::async_trait; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui2::AppContext; +use parking_lot::Mutex; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, +}; + +#[derive(Clone)] +pub struct FakeLanguageModel { + pub capacity: usize, +} + +impl LanguageModel for FakeLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + println!("TRYING TO TRUNCATE: {:?}", length.clone()); + + if length > self.count_tokens(content)? { + println!("NOT TRUNCATING"); + return anyhow::Ok(content.to_string()); + } + + anyhow::Ok(match direction { + TruncationDirection::End => content.chars().collect::>()[..length] + .into_iter() + .collect::(), + TruncationDirection::Start => content.chars().collect::>()[length..] 
+ .into_iter() + .collect::(), + }) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } +} + +pub struct FakeEmbeddingProvider { + pub embedding_count: AtomicUsize, +} + +impl Clone for FakeEmbeddingProvider { + fn clone(&self) -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), + } + } +} + +impl Default for FakeEmbeddingProvider { + fn default() -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::default(), + } + } +} + +impl FakeEmbeddingProvider { + pub fn embedding_count(&self) -> usize { + self.embedding_count.load(atomic::Ordering::SeqCst) + } + + pub fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } +} + +#[async_trait] +impl CredentialProvider for FakeEmbeddingProvider { + fn has_credentials(&self) -> bool { + true + } + async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {} + async fn delete_credentials(&self, _cx: &mut AppContext) {} +} + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(FakeLanguageModel { capacity: 1000 }) + } + fn max_tokens_per_batch(&self) -> usize { + 1000 + } + + fn rate_limit_expiration(&self) -> Option { + None + } + + async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { + self.embedding_count + .fetch_add(spans.len(), atomic::Ordering::SeqCst); + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + } +} + +pub struct 
FakeCompletionProvider { + last_completion_tx: Mutex>>, +} + +impl Clone for FakeCompletionProvider { + fn clone(&self) -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } +} + +impl FakeCompletionProvider { + pub fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + pub fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + pub fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } +} + +#[async_trait] +impl CredentialProvider for FakeCompletionProvider { + fn has_credentials(&self) -> bool { + true + } + async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {} + async fn delete_credentials(&self, _cx: &mut AppContext) {} +} + +impl CompletionProvider for FakeCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } + fn complete( + &self, + _prompt: Box, + ) -> BoxFuture<'static, anyhow::Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/audio2/Cargo.toml b/crates/audio2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..298142dbefafc9a26b00ddf9f9555dd398cdb470 --- /dev/null +++ b/crates/audio2/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "audio2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/audio2.rs" +doctest = false + +[dependencies] +gpui2 = { path = "../gpui2" } +collections = { path = "../collections" } +util = { path = "../util" } + + +rodio ={version = "0.17.1", default-features=false, features = ["wav"]} 
+ +log.workspace = true +futures.workspace = true +anyhow.workspace = true +parking_lot.workspace = true + +[dev-dependencies] diff --git a/crates/audio2/audio/Cargo.toml b/crates/audio2/audio/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..36135a1e76b8a98119c64d0f4359d5fa0815cf9d --- /dev/null +++ b/crates/audio2/audio/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "audio" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/audio.rs" +doctest = false + +[dependencies] +gpui = { path = "../gpui" } +collections = { path = "../collections" } +util = { path = "../util" } + +rodio ={version = "0.17.1", default-features=false, features = ["wav"]} + +log.workspace = true + +anyhow.workspace = true +parking_lot.workspace = true + +[dev-dependencies] diff --git a/crates/audio2/audio/src/assets.rs b/crates/audio2/audio/src/assets.rs new file mode 100644 index 0000000000000000000000000000000000000000..b58e1f6aee89683a4b6300e107c66db016b882cb --- /dev/null +++ b/crates/audio2/audio/src/assets.rs @@ -0,0 +1,44 @@ +use std::{io::Cursor, sync::Arc}; + +use anyhow::Result; +use collections::HashMap; +use gpui::{AppContext, AssetSource}; +use rodio::{ + source::{Buffered, SamplesConverter}, + Decoder, Source, +}; + +type Sound = Buffered>>, f32>>; + +pub struct SoundRegistry { + cache: Arc>>, + assets: Box, +} + +impl SoundRegistry { + pub fn new(source: impl AssetSource) -> Arc { + Arc::new(Self { + cache: Default::default(), + assets: Box::new(source), + }) + } + + pub fn global(cx: &AppContext) -> Arc { + cx.global::>().clone() + } + + pub fn get(&self, name: &str) -> Result> { + if let Some(wav) = self.cache.lock().get(name) { + return Ok(wav.clone()); + } + + let path = format!("sounds/{}.wav", name); + let bytes = self.assets.load(&path)?.into_owned(); + let cursor = Cursor::new(bytes); + let source = Decoder::new(cursor)?.convert_samples::().buffered(); + + self.cache.lock().insert(name.to_string(), source.clone()); + 
+ Ok(source) + } +} diff --git a/crates/audio2/audio/src/audio.rs b/crates/audio2/audio/src/audio.rs new file mode 100644 index 0000000000000000000000000000000000000000..d80fb6738f69891a9199f230af832ee335071496 --- /dev/null +++ b/crates/audio2/audio/src/audio.rs @@ -0,0 +1,81 @@ +use assets::SoundRegistry; +use gpui::{AppContext, AssetSource}; +use rodio::{OutputStream, OutputStreamHandle}; +use util::ResultExt; + +mod assets; + +pub fn init(source: impl AssetSource, cx: &mut AppContext) { + cx.set_global(SoundRegistry::new(source)); + cx.set_global(Audio::new()); +} + +pub enum Sound { + Joined, + Leave, + Mute, + Unmute, + StartScreenshare, + StopScreenshare, +} + +impl Sound { + fn file(&self) -> &'static str { + match self { + Self::Joined => "joined_call", + Self::Leave => "leave_call", + Self::Mute => "mute", + Self::Unmute => "unmute", + Self::StartScreenshare => "start_screenshare", + Self::StopScreenshare => "stop_screenshare", + } + } +} + +pub struct Audio { + _output_stream: Option, + output_handle: Option, +} + +impl Audio { + pub fn new() -> Self { + Self { + _output_stream: None, + output_handle: None, + } + } + + fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> { + if self.output_handle.is_none() { + let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip(); + self.output_handle = output_handle; + self._output_stream = _output_stream; + } + + self.output_handle.as_ref() + } + + pub fn play_sound(sound: Sound, cx: &mut AppContext) { + if !cx.has_global::() { + return; + } + + cx.update_global::(|this, cx| { + let output_handle = this.ensure_output_exists()?; + let source = SoundRegistry::global(cx).get(sound.file()).log_err()?; + output_handle.play_raw(source).log_err()?; + Some(()) + }); + } + + pub fn end_call(cx: &mut AppContext) { + if !cx.has_global::() { + return; + } + + cx.update_global::(|this, _| { + this._output_stream.take(); + this.output_handle.take(); + }); + } +} diff --git 
a/crates/audio2/src/assets.rs b/crates/audio2/src/assets.rs new file mode 100644 index 0000000000000000000000000000000000000000..66e0bf5aa563b720509c080cad964cb7cfb7f24c --- /dev/null +++ b/crates/audio2/src/assets.rs @@ -0,0 +1,44 @@ +use std::{io::Cursor, sync::Arc}; + +use anyhow::Result; +use collections::HashMap; +use gpui2::{AppContext, AssetSource}; +use rodio::{ + source::{Buffered, SamplesConverter}, + Decoder, Source, +}; + +type Sound = Buffered>>, f32>>; + +pub struct SoundRegistry { + cache: Arc>>, + assets: Box, +} + +impl SoundRegistry { + pub fn new(source: impl AssetSource) -> Arc { + Arc::new(Self { + cache: Default::default(), + assets: Box::new(source), + }) + } + + pub fn global(cx: &AppContext) -> Arc { + cx.global::>().clone() + } + + pub fn get(&self, name: &str) -> Result> { + if let Some(wav) = self.cache.lock().get(name) { + return Ok(wav.clone()); + } + + let path = format!("sounds/{}.wav", name); + let bytes = self.assets.load(&path)?.into_owned(); + let cursor = Cursor::new(bytes); + let source = Decoder::new(cursor)?.convert_samples::().buffered(); + + self.cache.lock().insert(name.to_string(), source.clone()); + + Ok(source) + } +} diff --git a/crates/audio2/src/audio2.rs b/crates/audio2/src/audio2.rs new file mode 100644 index 0000000000000000000000000000000000000000..d04587d74e545bc2579502cd567b65de15dfceb1 --- /dev/null +++ b/crates/audio2/src/audio2.rs @@ -0,0 +1,111 @@ +use assets::SoundRegistry; +use futures::{channel::mpsc, StreamExt}; +use gpui2::{AppContext, AssetSource, Executor}; +use rodio::{OutputStream, OutputStreamHandle}; +use util::ResultExt; + +mod assets; + +pub fn init(source: impl AssetSource, cx: &mut AppContext) { + cx.set_global(Audio::new(cx.executor())); + cx.set_global(SoundRegistry::new(source)); +} + +pub enum Sound { + Joined, + Leave, + Mute, + Unmute, + StartScreenshare, + StopScreenshare, +} + +impl Sound { + fn file(&self) -> &'static str { + match self { + Self::Joined => "joined_call", + 
Self::Leave => "leave_call", + Self::Mute => "mute", + Self::Unmute => "unmute", + Self::StartScreenshare => "start_screenshare", + Self::StopScreenshare => "stop_screenshare", + } + } +} + +pub struct Audio { + tx: mpsc::UnboundedSender>, +} + +struct AudioState { + _output_stream: Option, + output_handle: Option, +} + +impl AudioState { + fn ensure_output_exists(&mut self) -> Option<&OutputStreamHandle> { + if self.output_handle.is_none() { + let (_output_stream, output_handle) = OutputStream::try_default().log_err().unzip(); + self.output_handle = output_handle; + self._output_stream = _output_stream; + } + + self.output_handle.as_ref() + } + + fn take(&mut self) { + self._output_stream.take(); + self.output_handle.take(); + } +} + +impl Audio { + pub fn new(executor: &Executor) -> Self { + let (tx, mut rx) = mpsc::unbounded::>(); + executor + .spawn_on_main(|| async move { + let mut audio = AudioState { + _output_stream: None, + output_handle: None, + }; + + while let Some(f) = rx.next().await { + (f)(&mut audio); + } + }) + .detach(); + + Self { tx } + } + + pub fn play_sound(sound: Sound, cx: &mut AppContext) { + if !cx.has_global::() { + return; + } + + let Some(source) = SoundRegistry::global(cx).get(sound.file()).log_err() else { + return; + }; + + let this = cx.global::(); + this.tx + .unbounded_send(Box::new(move |state| { + if let Some(output_handle) = state.ensure_output_exists() { + output_handle.play_raw(source).log_err(); + } + })) + .ok(); + } + + pub fn end_call(cx: &AppContext) { + if !cx.has_global::() { + return; + } + + let this = cx.global::(); + + this.tx + .unbounded_send(Box::new(move |state| state.take())) + .ok(); + } +} diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 3ff5a3490161ce12958eaa688a921e5094bcd76f..8d37194f3a929148e9bce24ed5084ef5b81c5000 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -1252,7 +1252,7 @@ impl Room { .read_with(&cx, |this, _| { this.live_kit .as_ref() - .map(|live_kit| 
live_kit.room.publish_audio_track(&track)) + .map(|live_kit| live_kit.room.publish_audio_track(track)) }) .ok_or_else(|| anyhow!("live-kit was not initialized"))? .await @@ -1338,7 +1338,7 @@ impl Room { .read_with(&cx, |this, _| { this.live_kit .as_ref() - .map(|live_kit| live_kit.room.publish_video_track(&track)) + .map(|live_kit| live_kit.room.publish_video_track(track)) }) .ok_or_else(|| anyhow!("live-kit was not initialized"))? .await diff --git a/crates/call2/Cargo.toml b/crates/call2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..f0e47832ed34b1584bbf1ef58afd4f7c39ebd913 --- /dev/null +++ b/crates/call2/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "call2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/call2.rs" +doctest = false + +[features] +test-support = [ + "client2/test-support", + "collections/test-support", + "gpui2/test-support", + "live_kit_client/test-support", + "project2/test-support", + "util/test-support" +] + +[dependencies] +audio2 = { path = "../audio2" } +client2 = { path = "../client2" } +collections = { path = "../collections" } +gpui2 = { path = "../gpui2" } +log.workspace = true +live_kit_client = { path = "../live_kit_client" } +fs2 = { path = "../fs2" } +language2 = { path = "../language2" } +media = { path = "../media" } +project2 = { path = "../project2" } +settings2 = { path = "../settings2" } +util = { path = "../util" } + +anyhow.workspace = true +async-broadcast = "0.4" +futures.workspace = true +postage.workspace = true +schemars.workspace = true +serde.workspace = true +serde_json.workspace = true +serde_derive.workspace = true + +[dev-dependencies] +client2 = { path = "../client2", features = ["test-support"] } +fs2 = { path = "../fs2", features = ["test-support"] } +language2 = { path = "../language2", features = ["test-support"] } +collections = { path = "../collections", features = ["test-support"] } +gpui2 = { path = "../gpui2", features = ["test-support"] } 
+live_kit_client = { path = "../live_kit_client", features = ["test-support"] } +project2 = { path = "../project2", features = ["test-support"] } +util = { path = "../util", features = ["test-support"] } diff --git a/crates/call2/src/call2.rs b/crates/call2/src/call2.rs new file mode 100644 index 0000000000000000000000000000000000000000..fd09dc31803c63081dea1ebcbe8dc07442d82eba --- /dev/null +++ b/crates/call2/src/call2.rs @@ -0,0 +1,461 @@ +pub mod call_settings; +pub mod participant; +pub mod room; + +use anyhow::{anyhow, Result}; +use audio2::Audio; +use call_settings::CallSettings; +use client2::{ + proto, ClickhouseEvent, Client, TelemetrySettings, TypedEnvelope, User, UserStore, + ZED_ALWAYS_ACTIVE, +}; +use collections::HashSet; +use futures::{future::Shared, FutureExt}; +use gpui2::{ + AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task, + WeakModel, +}; +use postage::watch; +use project2::Project; +use settings2::Settings; +use std::sync::Arc; + +pub use participant::ParticipantLocation; +pub use room::Room; + +pub fn init(client: Arc, user_store: Model, cx: &mut AppContext) { + CallSettings::register(cx); + + let active_call = cx.build_model(|cx| ActiveCall::new(client, user_store, cx)); + cx.set_global(active_call); +} + +#[derive(Clone)] +pub struct IncomingCall { + pub room_id: u64, + pub calling_user: Arc, + pub participants: Vec>, + pub initial_project: Option, +} + +/// Singleton global maintaining the user's participation in a room across workspaces. 
+pub struct ActiveCall { + room: Option<(Model, Vec)>, + pending_room_creation: Option, Arc>>>>, + location: Option>, + pending_invites: HashSet, + incoming_call: ( + watch::Sender>, + watch::Receiver>, + ), + client: Arc, + user_store: Model, + _subscriptions: Vec, +} + +impl EventEmitter for ActiveCall { + type Event = room::Event; +} + +impl ActiveCall { + fn new(client: Arc, user_store: Model, cx: &mut ModelContext) -> Self { + Self { + room: None, + pending_room_creation: None, + location: None, + pending_invites: Default::default(), + incoming_call: watch::channel(), + + _subscriptions: vec![ + client.add_request_handler(cx.weak_model(), Self::handle_incoming_call), + client.add_message_handler(cx.weak_model(), Self::handle_call_canceled), + ], + client, + user_store, + } + } + + pub fn channel_id(&self, cx: &AppContext) -> Option { + self.room()?.read(cx).channel_id() + } + + async fn handle_incoming_call( + this: Model, + envelope: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result { + let user_store = this.update(&mut cx, |this, _| this.user_store.clone())?; + let call = IncomingCall { + room_id: envelope.payload.room_id, + participants: user_store + .update(&mut cx, |user_store, cx| { + user_store.get_users(envelope.payload.participant_user_ids, cx) + })? + .await?, + calling_user: user_store + .update(&mut cx, |user_store, cx| { + user_store.get_user(envelope.payload.calling_user_id, cx) + })? 
+ .await?, + initial_project: envelope.payload.initial_project, + }; + this.update(&mut cx, |this, _| { + *this.incoming_call.0.borrow_mut() = Some(call); + })?; + + Ok(proto::Ack {}) + } + + async fn handle_call_canceled( + this: Model, + envelope: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |this, _| { + let mut incoming_call = this.incoming_call.0.borrow_mut(); + if incoming_call + .as_ref() + .map_or(false, |call| call.room_id == envelope.payload.room_id) + { + incoming_call.take(); + } + })?; + Ok(()) + } + + pub fn global(cx: &AppContext) -> Model { + cx.global::>().clone() + } + + pub fn invite( + &mut self, + called_user_id: u64, + initial_project: Option>, + cx: &mut ModelContext, + ) -> Task> { + if !self.pending_invites.insert(called_user_id) { + return Task::ready(Err(anyhow!("user was already invited"))); + } + cx.notify(); + + let room = if let Some(room) = self.room().cloned() { + Some(Task::ready(Ok(room)).shared()) + } else { + self.pending_room_creation.clone() + }; + + let invite = if let Some(room) = room { + cx.spawn(move |_, mut cx| async move { + let room = room.await.map_err(|err| anyhow!("{:?}", err))?; + + let initial_project_id = if let Some(initial_project) = initial_project { + Some( + room.update(&mut cx, |room, cx| room.share_project(initial_project, cx))? + .await?, + ) + } else { + None + }; + + room.update(&mut cx, move |room, cx| { + room.call(called_user_id, initial_project_id, cx) + })? + .await?; + + anyhow::Ok(()) + }) + } else { + let client = self.client.clone(); + let user_store = self.user_store.clone(); + let room = cx + .spawn(move |this, mut cx| async move { + let create_room = async { + let room = cx + .update(|cx| { + Room::create( + called_user_id, + initial_project, + client, + user_store, + cx, + ) + })? + .await?; + + this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))? 
+ .await?; + + anyhow::Ok(room) + }; + + let room = create_room.await; + this.update(&mut cx, |this, _| this.pending_room_creation = None)?; + room.map_err(Arc::new) + }) + .shared(); + self.pending_room_creation = Some(room.clone()); + cx.executor().spawn(async move { + room.await.map_err(|err| anyhow!("{:?}", err))?; + anyhow::Ok(()) + }) + }; + + cx.spawn(move |this, mut cx| async move { + let result = invite.await; + if result.is_ok() { + this.update(&mut cx, |this, cx| this.report_call_event("invite", cx))?; + } else { + // TODO: Resport collaboration error + } + + this.update(&mut cx, |this, cx| { + this.pending_invites.remove(&called_user_id); + cx.notify(); + })?; + result + }) + } + + pub fn cancel_invite( + &mut self, + called_user_id: u64, + cx: &mut ModelContext, + ) -> Task> { + let room_id = if let Some(room) = self.room() { + room.read(cx).id() + } else { + return Task::ready(Err(anyhow!("no active call"))); + }; + + let client = self.client.clone(); + cx.executor().spawn(async move { + client + .request(proto::CancelCall { + room_id, + called_user_id, + }) + .await?; + anyhow::Ok(()) + }) + } + + pub fn incoming(&self) -> watch::Receiver> { + self.incoming_call.1.clone() + } + + pub fn accept_incoming(&mut self, cx: &mut ModelContext) -> Task> { + if self.room.is_some() { + return Task::ready(Err(anyhow!("cannot join while on another call"))); + } + + let call = if let Some(call) = self.incoming_call.1.borrow().clone() { + call + } else { + return Task::ready(Err(anyhow!("no incoming call"))); + }; + + let join = Room::join(&call, self.client.clone(), self.user_store.clone(), cx); + + cx.spawn(|this, mut cx| async move { + let room = join.await?; + this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))? 
+ .await?; + this.update(&mut cx, |this, cx| { + this.report_call_event("accept incoming", cx) + })?; + Ok(()) + }) + } + + pub fn decline_incoming(&mut self, cx: &mut ModelContext) -> Result<()> { + let call = self + .incoming_call + .0 + .borrow_mut() + .take() + .ok_or_else(|| anyhow!("no incoming call"))?; + report_call_event_for_room("decline incoming", call.room_id, None, &self.client, cx); + self.client.send(proto::DeclineCall { + room_id: call.room_id, + })?; + Ok(()) + } + + pub fn join_channel( + &mut self, + channel_id: u64, + cx: &mut ModelContext, + ) -> Task>> { + if let Some(room) = self.room().cloned() { + if room.read(cx).channel_id() == Some(channel_id) { + return Task::ready(Ok(room)); + } else { + room.update(cx, |room, cx| room.clear_state(cx)); + } + } + + let join = Room::join_channel(channel_id, self.client.clone(), self.user_store.clone(), cx); + + cx.spawn(|this, mut cx| async move { + let room = join.await?; + this.update(&mut cx, |this, cx| this.set_room(Some(room.clone()), cx))? 
+ .await?; + this.update(&mut cx, |this, cx| { + this.report_call_event("join channel", cx) + })?; + Ok(room) + }) + } + + pub fn hang_up(&mut self, cx: &mut ModelContext) -> Task> { + cx.notify(); + self.report_call_event("hang up", cx); + + Audio::end_call(cx); + if let Some((room, _)) = self.room.take() { + room.update(cx, |room, cx| room.leave(cx)) + } else { + Task::ready(Ok(())) + } + } + + pub fn share_project( + &mut self, + project: Model, + cx: &mut ModelContext, + ) -> Task> { + if let Some((room, _)) = self.room.as_ref() { + self.report_call_event("share project", cx); + room.update(cx, |room, cx| room.share_project(project, cx)) + } else { + Task::ready(Err(anyhow!("no active call"))) + } + } + + pub fn unshare_project( + &mut self, + project: Model, + cx: &mut ModelContext, + ) -> Result<()> { + if let Some((room, _)) = self.room.as_ref() { + self.report_call_event("unshare project", cx); + room.update(cx, |room, cx| room.unshare_project(project, cx)) + } else { + Err(anyhow!("no active call")) + } + } + + pub fn location(&self) -> Option<&WeakModel> { + self.location.as_ref() + } + + pub fn set_location( + &mut self, + project: Option<&Model>, + cx: &mut ModelContext, + ) -> Task> { + if project.is_some() || !*ZED_ALWAYS_ACTIVE { + self.location = project.map(|project| project.downgrade()); + if let Some((room, _)) = self.room.as_ref() { + return room.update(cx, |room, cx| room.set_location(project, cx)); + } + } + Task::ready(Ok(())) + } + + fn set_room( + &mut self, + room: Option>, + cx: &mut ModelContext, + ) -> Task> { + if room.as_ref() != self.room.as_ref().map(|room| &room.0) { + cx.notify(); + if let Some(room) = room { + if room.read(cx).status().is_offline() { + self.room = None; + Task::ready(Ok(())) + } else { + let subscriptions = vec![ + cx.observe(&room, |this, room, cx| { + if room.read(cx).status().is_offline() { + this.set_room(None, cx).detach_and_log_err(cx); + } + + cx.notify(); + }), + cx.subscribe(&room, |_, _, event, cx| 
cx.emit(event.clone())), + ]; + self.room = Some((room.clone(), subscriptions)); + let location = self + .location + .as_ref() + .and_then(|location| location.upgrade()); + room.update(cx, |room, cx| room.set_location(location.as_ref(), cx)) + } + } else { + self.room = None; + Task::ready(Ok(())) + } + } else { + Task::ready(Ok(())) + } + } + + pub fn room(&self) -> Option<&Model> { + self.room.as_ref().map(|(room, _)| room) + } + + pub fn client(&self) -> Arc { + self.client.clone() + } + + pub fn pending_invites(&self) -> &HashSet { + &self.pending_invites + } + + pub fn report_call_event(&self, operation: &'static str, cx: &AppContext) { + if let Some(room) = self.room() { + let room = room.read(cx); + report_call_event_for_room(operation, room.id(), room.channel_id(), &self.client, cx); + } + } +} + +pub fn report_call_event_for_room( + operation: &'static str, + room_id: u64, + channel_id: Option, + client: &Arc, + cx: &AppContext, +) { + let telemetry = client.telemetry(); + let telemetry_settings = *TelemetrySettings::get_global(cx); + let event = ClickhouseEvent::Call { + operation, + room_id: Some(room_id), + channel_id, + }; + telemetry.report_clickhouse_event(event, telemetry_settings); +} + +pub fn report_call_event_for_channel( + operation: &'static str, + channel_id: u64, + client: &Arc, + cx: &AppContext, +) { + let room = ActiveCall::global(cx).read(cx).room(); + + let telemetry = client.telemetry(); + + let telemetry_settings = *TelemetrySettings::get_global(cx); + + let event = ClickhouseEvent::Call { + operation, + room_id: room.map(|r| r.read(cx).id()), + channel_id: Some(channel_id), + }; + telemetry.report_clickhouse_event(event, telemetry_settings); +} diff --git a/crates/call2/src/call_settings.rs b/crates/call2/src/call_settings.rs new file mode 100644 index 0000000000000000000000000000000000000000..c83ed739805cd7e2cdec8c5a394779c214bf1b92 --- /dev/null +++ b/crates/call2/src/call_settings.rs @@ -0,0 +1,32 @@ +use anyhow::Result; +use 
gpui2::AppContext; +use schemars::JsonSchema; +use serde_derive::{Deserialize, Serialize}; +use settings2::Settings; + +#[derive(Deserialize, Debug)] +pub struct CallSettings { + pub mute_on_join: bool, +} + +#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] +pub struct CallSettingsContent { + pub mute_on_join: Option, +} + +impl Settings for CallSettings { + const KEY: Option<&'static str> = Some("calls"); + + type FileContent = CallSettingsContent; + + fn load( + default_value: &Self::FileContent, + user_values: &[&Self::FileContent], + _cx: &mut AppContext, + ) -> Result + where + Self: Sized, + { + Self::load_via_json_merge(default_value, user_values) + } +} diff --git a/crates/call2/src/participant.rs b/crates/call2/src/participant.rs new file mode 100644 index 0000000000000000000000000000000000000000..7f3e91dbba0116a7b7f7ef5b1c471fb1a768529f --- /dev/null +++ b/crates/call2/src/participant.rs @@ -0,0 +1,71 @@ +use anyhow::{anyhow, Result}; +use client2::ParticipantIndex; +use client2::{proto, User}; +use gpui2::WeakModel; +pub use live_kit_client::Frame; +use project2::Project; +use std::{fmt, sync::Arc}; + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum ParticipantLocation { + SharedProject { project_id: u64 }, + UnsharedProject, + External, +} + +impl ParticipantLocation { + pub fn from_proto(location: Option) -> Result { + match location.and_then(|l| l.variant) { + Some(proto::participant_location::Variant::SharedProject(project)) => { + Ok(Self::SharedProject { + project_id: project.id, + }) + } + Some(proto::participant_location::Variant::UnsharedProject(_)) => { + Ok(Self::UnsharedProject) + } + Some(proto::participant_location::Variant::External(_)) => Ok(Self::External), + None => Err(anyhow!("participant location was not provided")), + } + } +} + +#[derive(Clone, Default)] +pub struct LocalParticipant { + pub projects: Vec, + pub active_project: Option>, +} + +#[derive(Clone, Debug)] +pub struct RemoteParticipant { + pub 
user: Arc, + pub peer_id: proto::PeerId, + pub projects: Vec, + pub location: ParticipantLocation, + pub participant_index: ParticipantIndex, + pub muted: bool, + pub speaking: bool, + // pub video_tracks: HashMap>, + // pub audio_tracks: HashMap>, +} + +#[derive(Clone)] +pub struct RemoteVideoTrack { + pub(crate) live_kit_track: Arc, +} + +unsafe impl Send for RemoteVideoTrack {} +// todo!("remove this sync because it's not legit") +unsafe impl Sync for RemoteVideoTrack {} + +impl fmt::Debug for RemoteVideoTrack { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RemoteVideoTrack").finish() + } +} + +impl RemoteVideoTrack { + pub fn frames(&self) -> async_broadcast::Receiver { + self.live_kit_track.frames() + } +} diff --git a/crates/call2/src/room.rs b/crates/call2/src/room.rs new file mode 100644 index 0000000000000000000000000000000000000000..b7bac52a8bf9c6a812e1d1d8c3ef1bd14dcb0db7 --- /dev/null +++ b/crates/call2/src/room.rs @@ -0,0 +1,1622 @@ +#![allow(dead_code, unused)] +// todo!() + +use crate::{ + call_settings::CallSettings, + participant::{LocalParticipant, ParticipantLocation, RemoteParticipant, RemoteVideoTrack}, + IncomingCall, +}; +use anyhow::{anyhow, Result}; +use audio2::{Audio, Sound}; +use client2::{ + proto::{self, PeerId}, + Client, ParticipantIndex, TypedEnvelope, User, UserStore, +}; +use collections::{BTreeMap, HashMap, HashSet}; +use fs2::Fs; +use futures::{FutureExt, StreamExt}; +use gpui2::{ + AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel, +}; +use language2::LanguageRegistry; +use live_kit_client::{LocalTrackPublication, RemoteAudioTrackUpdate, RemoteVideoTrackUpdate}; +use postage::{sink::Sink, stream::Stream, watch}; +use project2::Project; +use settings2::Settings; +use std::{future::Future, sync::Arc, time::Duration}; +use util::{ResultExt, TryFutureExt}; + +pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); + +#[derive(Clone, Debug, PartialEq, 
Eq)] +pub enum Event { + ParticipantLocationChanged { + participant_id: proto::PeerId, + }, + RemoteVideoTracksChanged { + participant_id: proto::PeerId, + }, + RemoteAudioTracksChanged { + participant_id: proto::PeerId, + }, + RemoteProjectShared { + owner: Arc, + project_id: u64, + worktree_root_names: Vec, + }, + RemoteProjectUnshared { + project_id: u64, + }, + RemoteProjectJoined { + project_id: u64, + }, + RemoteProjectInvitationDiscarded { + project_id: u64, + }, + Left, +} + +pub struct Room { + id: u64, + channel_id: Option, + // live_kit: Option, + status: RoomStatus, + shared_projects: HashSet>, + joined_projects: HashSet>, + local_participant: LocalParticipant, + remote_participants: BTreeMap, + pending_participants: Vec>, + participant_user_ids: HashSet, + pending_call_count: usize, + leave_when_empty: bool, + client: Arc, + user_store: Model, + follows_by_leader_id_project_id: HashMap<(PeerId, u64), Vec>, + client_subscriptions: Vec, + _subscriptions: Vec, + room_update_completed_tx: watch::Sender>, + room_update_completed_rx: watch::Receiver>, + pending_room_update: Option>, + maintain_connection: Option>>, +} + +impl EventEmitter for Room { + type Event = Event; +} + +impl Room { + pub fn channel_id(&self) -> Option { + self.channel_id + } + + pub fn is_sharing_project(&self) -> bool { + !self.shared_projects.is_empty() + } + + #[cfg(any(test, feature = "test-support"))] + pub fn is_connected(&self) -> bool { + false + // if let Some(live_kit) = self.live_kit.as_ref() { + // matches!( + // *live_kit.room.status().borrow(), + // live_kit_client::ConnectionState::Connected { .. 
} + // ) + // } else { + // false + // } + } + + fn new( + id: u64, + channel_id: Option, + live_kit_connection_info: Option, + client: Arc, + user_store: Model, + cx: &mut ModelContext, + ) -> Self { + todo!() + // let _live_kit_room = if let Some(connection_info) = live_kit_connection_info { + // let room = live_kit_client::Room::new(); + // let mut status = room.status(); + // // Consume the initial status of the room. + // let _ = status.try_recv(); + // let _maintain_room = cx.spawn(|this, mut cx| async move { + // while let Some(status) = status.next().await { + // let this = if let Some(this) = this.upgrade() { + // this + // } else { + // break; + // }; + + // if status == live_kit_client::ConnectionState::Disconnected { + // this.update(&mut cx, |this, cx| this.leave(cx).log_err()) + // .ok(); + // break; + // } + // } + // }); + + // let mut track_video_changes = room.remote_video_track_updates(); + // let _maintain_video_tracks = cx.spawn(|this, mut cx| async move { + // while let Some(track_change) = track_video_changes.next().await { + // let this = if let Some(this) = this.upgrade() { + // this + // } else { + // break; + // }; + + // this.update(&mut cx, |this, cx| { + // this.remote_video_track_updated(track_change, cx).log_err() + // }) + // .ok(); + // } + // }); + + // let mut track_audio_changes = room.remote_audio_track_updates(); + // let _maintain_audio_tracks = cx.spawn(|this, mut cx| async move { + // while let Some(track_change) = track_audio_changes.next().await { + // let this = if let Some(this) = this.upgrade() { + // this + // } else { + // break; + // }; + + // this.update(&mut cx, |this, cx| { + // this.remote_audio_track_updated(track_change, cx).log_err() + // }) + // .ok(); + // } + // }); + + // let connect = room.connect(&connection_info.server_url, &connection_info.token); + // cx.spawn(|this, mut cx| async move { + // connect.await?; + + // if !cx.update(|cx| Self::mute_on_join(cx))? 
{ + // this.update(&mut cx, |this, cx| this.share_microphone(cx))? + // .await?; + // } + + // anyhow::Ok(()) + // }) + // .detach_and_log_err(cx); + + // Some(LiveKitRoom { + // room, + // screen_track: LocalTrack::None, + // microphone_track: LocalTrack::None, + // next_publish_id: 0, + // muted_by_user: false, + // deafened: false, + // speaking: false, + // _maintain_room, + // _maintain_tracks: [_maintain_video_tracks, _maintain_audio_tracks], + // }) + // } else { + // None + // }; + + // let maintain_connection = cx.spawn({ + // let client = client.clone(); + // move |this, cx| Self::maintain_connection(this, client.clone(), cx).log_err() + // }); + + // Audio::play_sound(Sound::Joined, cx); + + // let (room_update_completed_tx, room_update_completed_rx) = watch::channel(); + + // Self { + // id, + // channel_id, + // // live_kit: live_kit_room, + // status: RoomStatus::Online, + // shared_projects: Default::default(), + // joined_projects: Default::default(), + // participant_user_ids: Default::default(), + // local_participant: Default::default(), + // remote_participants: Default::default(), + // pending_participants: Default::default(), + // pending_call_count: 0, + // client_subscriptions: vec![ + // client.add_message_handler(cx.weak_handle(), Self::handle_room_updated) + // ], + // _subscriptions: vec![ + // cx.on_release(Self::released), + // cx.on_app_quit(Self::app_will_quit), + // ], + // leave_when_empty: false, + // pending_room_update: None, + // client, + // user_store, + // follows_by_leader_id_project_id: Default::default(), + // maintain_connection: Some(maintain_connection), + // room_update_completed_tx, + // room_update_completed_rx, + // } + } + + pub(crate) fn create( + called_user_id: u64, + initial_project: Option>, + client: Arc, + user_store: Model, + cx: &mut AppContext, + ) -> Task>> { + cx.spawn(move |mut cx| async move { + let response = client.request(proto::CreateRoom {}).await?; + let room_proto = response.room.ok_or_else(|| 
anyhow!("invalid room"))?; + let room = cx.build_model(|cx| { + Self::new( + room_proto.id, + None, + response.live_kit_connection_info, + client, + user_store, + cx, + ) + })?; + + let initial_project_id = if let Some(initial_project) = initial_project { + let initial_project_id = room + .update(&mut cx, |room, cx| { + room.share_project(initial_project.clone(), cx) + })? + .await?; + Some(initial_project_id) + } else { + None + }; + + match room + .update(&mut cx, |room, cx| { + room.leave_when_empty = true; + room.call(called_user_id, initial_project_id, cx) + })? + .await + { + Ok(()) => Ok(room), + Err(error) => Err(anyhow!("room creation failed: {:?}", error)), + } + }) + } + + pub(crate) fn join_channel( + channel_id: u64, + client: Arc, + user_store: Model, + cx: &mut AppContext, + ) -> Task>> { + cx.spawn(move |cx| async move { + Self::from_join_response( + client.request(proto::JoinChannel { channel_id }).await?, + client, + user_store, + cx, + ) + }) + } + + pub(crate) fn join( + call: &IncomingCall, + client: Arc, + user_store: Model, + cx: &mut AppContext, + ) -> Task>> { + let id = call.room_id; + cx.spawn(move |cx| async move { + Self::from_join_response( + client.request(proto::JoinRoom { id }).await?, + client, + user_store, + cx, + ) + }) + } + + fn released(&mut self, cx: &mut AppContext) { + if self.status.is_online() { + self.leave_internal(cx).detach_and_log_err(cx); + } + } + + fn app_will_quit(&mut self, cx: &mut ModelContext) -> impl Future { + let task = if self.status.is_online() { + let leave = self.leave_internal(cx); + Some(cx.executor().spawn(async move { + leave.await.log_err(); + })) + } else { + None + }; + + async move { + if let Some(task) = task { + task.await; + } + } + } + + pub fn mute_on_join(cx: &AppContext) -> bool { + CallSettings::get_global(cx).mute_on_join || client2::IMPERSONATE_LOGIN.is_some() + } + + fn from_join_response( + response: proto::JoinRoomResponse, + client: Arc, + user_store: Model, + mut cx: 
AsyncAppContext, + ) -> Result> { + let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?; + let room = cx.build_model(|cx| { + Self::new( + room_proto.id, + response.channel_id, + response.live_kit_connection_info, + client, + user_store, + cx, + ) + })?; + room.update(&mut cx, |room, cx| { + room.leave_when_empty = room.channel_id.is_none(); + room.apply_room_update(room_proto, cx)?; + anyhow::Ok(()) + })??; + Ok(room) + } + + fn should_leave(&self) -> bool { + self.leave_when_empty + && self.pending_room_update.is_none() + && self.pending_participants.is_empty() + && self.remote_participants.is_empty() + && self.pending_call_count == 0 + } + + pub(crate) fn leave(&mut self, cx: &mut ModelContext) -> Task> { + cx.notify(); + cx.emit(Event::Left); + self.leave_internal(cx) + } + + fn leave_internal(&mut self, cx: &mut AppContext) -> Task> { + if self.status.is_offline() { + return Task::ready(Err(anyhow!("room is offline"))); + } + + log::info!("leaving room"); + Audio::play_sound(Sound::Leave, cx); + + self.clear_state(cx); + + let leave_room = self.client.request(proto::LeaveRoom {}); + cx.executor().spawn(async move { + leave_room.await?; + anyhow::Ok(()) + }) + } + + pub(crate) fn clear_state(&mut self, cx: &mut AppContext) { + for project in self.shared_projects.drain() { + if let Some(project) = project.upgrade() { + project.update(cx, |project, cx| { + project.unshare(cx).log_err(); + }); + } + } + for project in self.joined_projects.drain() { + if let Some(project) = project.upgrade() { + project.update(cx, |project, cx| { + project.disconnected_from_host(cx); + project.close(cx); + }); + } + } + + self.status = RoomStatus::Offline; + self.remote_participants.clear(); + self.pending_participants.clear(); + self.participant_user_ids.clear(); + self.client_subscriptions.clear(); + // self.live_kit.take(); + self.pending_room_update.take(); + self.maintain_connection.take(); + } + + async fn maintain_connection( + this: WeakModel, + client: 
Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + let mut client_status = client.status(); + loop { + let _ = client_status.try_recv(); + let is_connected = client_status.borrow().is_connected(); + // Even if we're initially connected, any future change of the status means we momentarily disconnected. + if !is_connected || client_status.next().await.is_some() { + log::info!("detected client disconnection"); + + this.upgrade() + .ok_or_else(|| anyhow!("room was dropped"))? + .update(&mut cx, |this, cx| { + this.status = RoomStatus::Rejoining; + cx.notify(); + })?; + + // Wait for client to re-establish a connection to the server. + { + let mut reconnection_timeout = cx.executor().timer(RECONNECT_TIMEOUT).fuse(); + let client_reconnection = async { + let mut remaining_attempts = 3; + while remaining_attempts > 0 { + if client_status.borrow().is_connected() { + log::info!("client reconnected, attempting to rejoin room"); + + let Some(this) = this.upgrade() else { break }; + match this.update(&mut cx, |this, cx| this.rejoin(cx)) { + Ok(task) => { + if task.await.log_err().is_some() { + return true; + } else { + remaining_attempts -= 1; + } + } + Err(_app_dropped) => return false, + } + } else if client_status.borrow().is_signed_out() { + return false; + } + + log::info!( + "waiting for client status change, remaining attempts {}", + remaining_attempts + ); + client_status.next().await; + } + false + } + .fuse(); + futures::pin_mut!(client_reconnection); + + futures::select_biased! { + reconnected = client_reconnection => { + if reconnected { + log::info!("successfully reconnected to room"); + // If we successfully joined the room, go back around the loop + // waiting for future connection status changes. + continue; + } + } + _ = reconnection_timeout => { + log::info!("room reconnection timeout expired"); + } + } + } + + break; + } + } + + // The client failed to re-establish a connection to the server + // or an error occurred while trying to re-join the room. 
Either way + // we leave the room and return an error. + if let Some(this) = this.upgrade() { + log::info!("reconnection failed, leaving room"); + let _ = this.update(&mut cx, |this, cx| this.leave(cx))?; + } + Err(anyhow!( + "can't reconnect to room: client failed to re-establish connection" + )) + } + + fn rejoin(&mut self, cx: &mut ModelContext) -> Task> { + let mut projects = HashMap::default(); + let mut reshared_projects = Vec::new(); + let mut rejoined_projects = Vec::new(); + self.shared_projects.retain(|project| { + if let Some(handle) = project.upgrade() { + let project = handle.read(cx); + if let Some(project_id) = project.remote_id() { + projects.insert(project_id, handle.clone()); + reshared_projects.push(proto::UpdateProject { + project_id, + worktrees: project.worktree_metadata_protos(cx), + }); + return true; + } + } + false + }); + self.joined_projects.retain(|project| { + if let Some(handle) = project.upgrade() { + let project = handle.read(cx); + if let Some(project_id) = project.remote_id() { + projects.insert(project_id, handle.clone()); + rejoined_projects.push(proto::RejoinProject { + id: project_id, + worktrees: project + .worktrees() + .map(|worktree| { + let worktree = worktree.read(cx); + proto::RejoinWorktree { + id: worktree.id().to_proto(), + scan_id: worktree.completed_scan_id() as u64, + } + }) + .collect(), + }); + } + return true; + } + false + }); + + let response = self.client.request_envelope(proto::RejoinRoom { + id: self.id, + reshared_projects, + rejoined_projects, + }); + + cx.spawn(|this, mut cx| async move { + let response = response.await?; + let message_id = response.message_id; + let response = response.payload; + let room_proto = response.room.ok_or_else(|| anyhow!("invalid room"))?; + this.update(&mut cx, |this, cx| { + this.status = RoomStatus::Online; + this.apply_room_update(room_proto, cx)?; + + for reshared_project in response.reshared_projects { + if let Some(project) = projects.get(&reshared_project.id) { + 
project.update(cx, |project, cx| { + project.reshared(reshared_project, cx).log_err(); + }); + } + } + + for rejoined_project in response.rejoined_projects { + if let Some(project) = projects.get(&rejoined_project.id) { + project.update(cx, |project, cx| { + project.rejoined(rejoined_project, message_id, cx).log_err(); + }); + } + } + + anyhow::Ok(()) + })? + }) + } + + pub fn id(&self) -> u64 { + self.id + } + + pub fn status(&self) -> RoomStatus { + self.status + } + + pub fn local_participant(&self) -> &LocalParticipant { + &self.local_participant + } + + pub fn remote_participants(&self) -> &BTreeMap { + &self.remote_participants + } + + pub fn remote_participant_for_peer_id(&self, peer_id: PeerId) -> Option<&RemoteParticipant> { + self.remote_participants + .values() + .find(|p| p.peer_id == peer_id) + } + + pub fn pending_participants(&self) -> &[Arc] { + &self.pending_participants + } + + pub fn contains_participant(&self, user_id: u64) -> bool { + self.participant_user_ids.contains(&user_id) + } + + pub fn followers_for(&self, leader_id: PeerId, project_id: u64) -> &[PeerId] { + self.follows_by_leader_id_project_id + .get(&(leader_id, project_id)) + .map_or(&[], |v| v.as_slice()) + } + + /// Returns the most 'active' projects, defined as most people in the project + pub fn most_active_project(&self, cx: &AppContext) -> Option<(u64, u64)> { + let mut project_hosts_and_guest_counts = HashMap::, u32)>::default(); + for participant in self.remote_participants.values() { + match participant.location { + ParticipantLocation::SharedProject { project_id } => { + project_hosts_and_guest_counts + .entry(project_id) + .or_default() + .1 += 1; + } + ParticipantLocation::External | ParticipantLocation::UnsharedProject => {} + } + for project in &participant.projects { + project_hosts_and_guest_counts + .entry(project.id) + .or_default() + .0 = Some(participant.user.id); + } + } + + if let Some(user) = self.user_store.read(cx).current_user() { + for project in 
&self.local_participant.projects { + project_hosts_and_guest_counts + .entry(project.id) + .or_default() + .0 = Some(user.id); + } + } + + project_hosts_and_guest_counts + .into_iter() + .filter_map(|(id, (host, guest_count))| Some((id, host?, guest_count))) + .max_by_key(|(_, _, guest_count)| *guest_count) + .map(|(id, host, _)| (id, host)) + } + + async fn handle_room_updated( + this: Model, + envelope: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + let room = envelope + .payload + .room + .ok_or_else(|| anyhow!("invalid room"))?; + this.update(&mut cx, |this, cx| this.apply_room_update(room, cx))? + } + + fn apply_room_update( + &mut self, + mut room: proto::Room, + cx: &mut ModelContext, + ) -> Result<()> { + // Filter ourselves out from the room's participants. + let local_participant_ix = room + .participants + .iter() + .position(|participant| Some(participant.user_id) == self.client.user_id()); + let local_participant = local_participant_ix.map(|ix| room.participants.swap_remove(ix)); + + let pending_participant_user_ids = room + .pending_participants + .iter() + .map(|p| p.user_id) + .collect::>(); + + let remote_participant_user_ids = room + .participants + .iter() + .map(|p| p.user_id) + .collect::>(); + + let (remote_participants, pending_participants) = + self.user_store.update(cx, move |user_store, cx| { + ( + user_store.get_users(remote_participant_user_ids, cx), + user_store.get_users(pending_participant_user_ids, cx), + ) + }); + + self.pending_room_update = Some(cx.spawn(|this, mut cx| async move { + let (remote_participants, pending_participants) = + futures::join!(remote_participants, pending_participants); + + this.update(&mut cx, |this, cx| { + this.participant_user_ids.clear(); + + if let Some(participant) = local_participant { + this.local_participant.projects = participant.projects; + } else { + this.local_participant.projects.clear(); + } + + if let Some(participants) = remote_participants.log_err() { + for 
(participant, user) in room.participants.into_iter().zip(participants) { + let Some(peer_id) = participant.peer_id else { + continue; + }; + let participant_index = ParticipantIndex(participant.participant_index); + this.participant_user_ids.insert(participant.user_id); + + let old_projects = this + .remote_participants + .get(&participant.user_id) + .into_iter() + .flat_map(|existing| &existing.projects) + .map(|project| project.id) + .collect::>(); + let new_projects = participant + .projects + .iter() + .map(|project| project.id) + .collect::>(); + + for project in &participant.projects { + if !old_projects.contains(&project.id) { + cx.emit(Event::RemoteProjectShared { + owner: user.clone(), + project_id: project.id, + worktree_root_names: project.worktree_root_names.clone(), + }); + } + } + + for unshared_project_id in old_projects.difference(&new_projects) { + this.joined_projects.retain(|project| { + if let Some(project) = project.upgrade() { + project.update(cx, |project, cx| { + if project.remote_id() == Some(*unshared_project_id) { + project.disconnected_from_host(cx); + false + } else { + true + } + }) + } else { + false + } + }); + cx.emit(Event::RemoteProjectUnshared { + project_id: *unshared_project_id, + }); + } + + let location = ParticipantLocation::from_proto(participant.location) + .unwrap_or(ParticipantLocation::External); + if let Some(remote_participant) = + this.remote_participants.get_mut(&participant.user_id) + { + remote_participant.peer_id = peer_id; + remote_participant.projects = participant.projects; + remote_participant.participant_index = participant_index; + if location != remote_participant.location { + remote_participant.location = location; + cx.emit(Event::ParticipantLocationChanged { + participant_id: peer_id, + }); + } + } else { + this.remote_participants.insert( + participant.user_id, + RemoteParticipant { + user: user.clone(), + participant_index, + peer_id, + projects: participant.projects, + location, + muted: true, + 
speaking: false, + // video_tracks: Default::default(), + // audio_tracks: Default::default(), + }, + ); + + Audio::play_sound(Sound::Joined, cx); + + // if let Some(live_kit) = this.live_kit.as_ref() { + // let video_tracks = + // live_kit.room.remote_video_tracks(&user.id.to_string()); + // let audio_tracks = + // live_kit.room.remote_audio_tracks(&user.id.to_string()); + // let publications = live_kit + // .room + // .remote_audio_track_publications(&user.id.to_string()); + + // for track in video_tracks { + // this.remote_video_track_updated( + // RemoteVideoTrackUpdate::Subscribed(track), + // cx, + // ) + // .log_err(); + // } + + // for (track, publication) in + // audio_tracks.iter().zip(publications.iter()) + // { + // this.remote_audio_track_updated( + // RemoteAudioTrackUpdate::Subscribed( + // track.clone(), + // publication.clone(), + // ), + // cx, + // ) + // .log_err(); + // } + // } + } + } + + this.remote_participants.retain(|user_id, participant| { + if this.participant_user_ids.contains(user_id) { + true + } else { + for project in &participant.projects { + cx.emit(Event::RemoteProjectUnshared { + project_id: project.id, + }); + } + false + } + }); + } + + if let Some(pending_participants) = pending_participants.log_err() { + this.pending_participants = pending_participants; + for participant in &this.pending_participants { + this.participant_user_ids.insert(participant.id); + } + } + + this.follows_by_leader_id_project_id.clear(); + for follower in room.followers { + let project_id = follower.project_id; + let (leader, follower) = match (follower.leader_id, follower.follower_id) { + (Some(leader), Some(follower)) => (leader, follower), + + _ => { + log::error!("Follower message {follower:?} missing some state"); + continue; + } + }; + + let list = this + .follows_by_leader_id_project_id + .entry((leader, project_id)) + .or_insert(Vec::new()); + if !list.contains(&follower) { + list.push(follower); + } + } + + this.pending_room_update.take(); + 
if this.should_leave() { + log::info!("room is empty, leaving"); + let _ = this.leave(cx); + } + + this.user_store.update(cx, |user_store, cx| { + let participant_indices_by_user_id = this + .remote_participants + .iter() + .map(|(user_id, participant)| (*user_id, participant.participant_index)) + .collect(); + user_store.set_participant_indices(participant_indices_by_user_id, cx); + }); + + this.check_invariants(); + this.room_update_completed_tx.try_send(Some(())).ok(); + cx.notify(); + }) + .ok(); + })); + + cx.notify(); + Ok(()) + } + + pub fn room_update_completed(&mut self) -> impl Future { + let mut done_rx = self.room_update_completed_rx.clone(); + async move { + while let Some(result) = done_rx.next().await { + if result.is_some() { + break; + } + } + } + } + + fn remote_video_track_updated( + &mut self, + change: RemoteVideoTrackUpdate, + cx: &mut ModelContext, + ) -> Result<()> { + todo!(); + match change { + RemoteVideoTrackUpdate::Subscribed(track) => { + let user_id = track.publisher_id().parse()?; + let track_id = track.sid().to_string(); + let participant = self + .remote_participants + .get_mut(&user_id) + .ok_or_else(|| anyhow!("subscribed to track by unknown participant"))?; + // participant.video_tracks.insert( + // track_id.clone(), + // Arc::new(RemoteVideoTrack { + // live_kit_track: track, + // }), + // ); + cx.emit(Event::RemoteVideoTracksChanged { + participant_id: participant.peer_id, + }); + } + RemoteVideoTrackUpdate::Unsubscribed { + publisher_id, + track_id, + } => { + let user_id = publisher_id.parse()?; + let participant = self + .remote_participants + .get_mut(&user_id) + .ok_or_else(|| anyhow!("unsubscribed from track by unknown participant"))?; + // participant.video_tracks.remove(&track_id); + cx.emit(Event::RemoteVideoTracksChanged { + participant_id: participant.peer_id, + }); + } + } + + cx.notify(); + Ok(()) + } + + fn remote_audio_track_updated( + &mut self, + change: RemoteAudioTrackUpdate, + cx: &mut ModelContext, + ) -> 
Result<()> { + match change { + RemoteAudioTrackUpdate::ActiveSpeakersChanged { speakers } => { + let mut speaker_ids = speakers + .into_iter() + .filter_map(|speaker_sid| speaker_sid.parse().ok()) + .collect::>(); + speaker_ids.sort_unstable(); + for (sid, participant) in &mut self.remote_participants { + if let Ok(_) = speaker_ids.binary_search(sid) { + participant.speaking = true; + } else { + participant.speaking = false; + } + } + // todo!() + // if let Some(id) = self.client.user_id() { + // if let Some(room) = &mut self.live_kit { + // if let Ok(_) = speaker_ids.binary_search(&id) { + // room.speaking = true; + // } else { + // room.speaking = false; + // } + // } + // } + cx.notify(); + } + RemoteAudioTrackUpdate::MuteChanged { track_id, muted } => { + // todo!() + // let mut found = false; + // for participant in &mut self.remote_participants.values_mut() { + // for track in participant.audio_tracks.values() { + // if track.sid() == track_id { + // found = true; + // break; + // } + // } + // if found { + // participant.muted = muted; + // break; + // } + // } + + cx.notify(); + } + RemoteAudioTrackUpdate::Subscribed(track, publication) => { + // todo!() + // let user_id = track.publisher_id().parse()?; + // let track_id = track.sid().to_string(); + // let participant = self + // .remote_participants + // .get_mut(&user_id) + // .ok_or_else(|| anyhow!("subscribed to track by unknown participant"))?; + // // participant.audio_tracks.insert(track_id.clone(), track); + // participant.muted = publication.is_muted(); + + // cx.emit(Event::RemoteAudioTracksChanged { + // participant_id: participant.peer_id, + // }); + } + RemoteAudioTrackUpdate::Unsubscribed { + publisher_id, + track_id, + } => { + // todo!() + // let user_id = publisher_id.parse()?; + // let participant = self + // .remote_participants + // .get_mut(&user_id) + // .ok_or_else(|| anyhow!("unsubscribed from track by unknown participant"))?; + // participant.audio_tracks.remove(&track_id); + // 
cx.emit(Event::RemoteAudioTracksChanged { + // participant_id: participant.peer_id, + // }); + } + } + + cx.notify(); + Ok(()) + } + + fn check_invariants(&self) { + #[cfg(any(test, feature = "test-support"))] + { + for participant in self.remote_participants.values() { + assert!(self.participant_user_ids.contains(&participant.user.id)); + assert_ne!(participant.user.id, self.client.user_id().unwrap()); + } + + for participant in &self.pending_participants { + assert!(self.participant_user_ids.contains(&participant.id)); + assert_ne!(participant.id, self.client.user_id().unwrap()); + } + + assert_eq!( + self.participant_user_ids.len(), + self.remote_participants.len() + self.pending_participants.len() + ); + } + } + + pub(crate) fn call( + &mut self, + called_user_id: u64, + initial_project_id: Option, + cx: &mut ModelContext, + ) -> Task> { + if self.status.is_offline() { + return Task::ready(Err(anyhow!("room is offline"))); + } + + cx.notify(); + let client = self.client.clone(); + let room_id = self.id; + self.pending_call_count += 1; + cx.spawn(move |this, mut cx| async move { + let result = client + .request(proto::Call { + room_id, + called_user_id, + initial_project_id, + }) + .await; + this.update(&mut cx, |this, cx| { + this.pending_call_count -= 1; + if this.should_leave() { + this.leave(cx).detach_and_log_err(cx); + } + })?; + result?; + Ok(()) + }) + } + + pub fn join_project( + &mut self, + id: u64, + language_registry: Arc, + fs: Arc, + cx: &mut ModelContext, + ) -> Task>> { + let client = self.client.clone(); + let user_store = self.user_store.clone(); + cx.emit(Event::RemoteProjectJoined { project_id: id }); + cx.spawn(move |this, mut cx| async move { + let project = + Project::remote(id, client, user_store, language_registry, fs, cx.clone()).await?; + + this.update(&mut cx, |this, cx| { + this.joined_projects.retain(|project| { + if let Some(project) = project.upgrade() { + !project.read(cx).is_read_only() + } else { + false + } + }); + 
this.joined_projects.insert(project.downgrade()); + })?; + Ok(project) + }) + } + + pub(crate) fn share_project( + &mut self, + project: Model, + cx: &mut ModelContext, + ) -> Task> { + if let Some(project_id) = project.read(cx).remote_id() { + return Task::ready(Ok(project_id)); + } + + let request = self.client.request(proto::ShareProject { + room_id: self.id(), + worktrees: project.read(cx).worktree_metadata_protos(cx), + }); + cx.spawn(|this, mut cx| async move { + let response = request.await?; + + project.update(&mut cx, |project, cx| { + project.shared(response.project_id, cx) + })??; + + // If the user's location is in this project, it changes from UnsharedProject to SharedProject. + this.update(&mut cx, |this, cx| { + this.shared_projects.insert(project.downgrade()); + let active_project = this.local_participant.active_project.as_ref(); + if active_project.map_or(false, |location| *location == project) { + this.set_location(Some(&project), cx) + } else { + Task::ready(Ok(())) + } + })? 
+ .await?; + + Ok(response.project_id) + }) + } + + pub(crate) fn unshare_project( + &mut self, + project: Model, + cx: &mut ModelContext, + ) -> Result<()> { + let project_id = match project.read(cx).remote_id() { + Some(project_id) => project_id, + None => return Ok(()), + }; + + self.client.send(proto::UnshareProject { project_id })?; + project.update(cx, |this, cx| this.unshare(cx)) + } + + pub(crate) fn set_location( + &mut self, + project: Option<&Model>, + cx: &mut ModelContext, + ) -> Task> { + if self.status.is_offline() { + return Task::ready(Err(anyhow!("room is offline"))); + } + + let client = self.client.clone(); + let room_id = self.id; + let location = if let Some(project) = project { + self.local_participant.active_project = Some(project.downgrade()); + if let Some(project_id) = project.read(cx).remote_id() { + proto::participant_location::Variant::SharedProject( + proto::participant_location::SharedProject { id: project_id }, + ) + } else { + proto::participant_location::Variant::UnsharedProject( + proto::participant_location::UnsharedProject {}, + ) + } + } else { + self.local_participant.active_project = None; + proto::participant_location::Variant::External(proto::participant_location::External {}) + }; + + cx.notify(); + cx.executor().spawn_on_main(move || async move { + client + .request(proto::UpdateParticipantLocation { + room_id, + location: Some(proto::ParticipantLocation { + variant: Some(location), + }), + }) + .await?; + Ok(()) + }) + } + + pub fn is_screen_sharing(&self) -> bool { + todo!() + // self.live_kit.as_ref().map_or(false, |live_kit| { + // !matches!(live_kit.screen_track, LocalTrack::None) + // }) + } + + pub fn is_sharing_mic(&self) -> bool { + todo!() + // self.live_kit.as_ref().map_or(false, |live_kit| { + // !matches!(live_kit.microphone_track, LocalTrack::None) + // }) + } + + pub fn is_muted(&self, cx: &AppContext) -> bool { + todo!() + // self.live_kit + // .as_ref() + // .and_then(|live_kit| match 
&live_kit.microphone_track { + // LocalTrack::None => Some(Self::mute_on_join(cx)), + // LocalTrack::Pending { muted, .. } => Some(*muted), + // LocalTrack::Published { muted, .. } => Some(*muted), + // }) + // .unwrap_or(false) + } + + pub fn is_speaking(&self) -> bool { + todo!() + // self.live_kit + // .as_ref() + // .map_or(false, |live_kit| live_kit.speaking) + } + + pub fn is_deafened(&self) -> Option { + // self.live_kit.as_ref().map(|live_kit| live_kit.deafened) + todo!() + } + + #[track_caller] + pub fn share_microphone(&mut self, cx: &mut ModelContext) -> Task> { + todo!() + // if self.status.is_offline() { + // return Task::ready(Err(anyhow!("room is offline"))); + // } else if self.is_sharing_mic() { + // return Task::ready(Err(anyhow!("microphone was already shared"))); + // } + + // let publish_id = if let Some(live_kit) = self.live_kit.as_mut() { + // let publish_id = post_inc(&mut live_kit.next_publish_id); + // live_kit.microphone_track = LocalTrack::Pending { + // publish_id, + // muted: false, + // }; + // cx.notify(); + // publish_id + // } else { + // return Task::ready(Err(anyhow!("live-kit was not initialized"))); + // }; + + // cx.spawn(move |this, mut cx| async move { + // let publish_track = async { + // let track = LocalAudioTrack::create(); + // this.upgrade() + // .ok_or_else(|| anyhow!("room was dropped"))? + // .update(&mut cx, |this, _| { + // this.live_kit + // .as_ref() + // .map(|live_kit| live_kit.room.publish_audio_track(track)) + // })? + // .ok_or_else(|| anyhow!("live-kit was not initialized"))? + // .await + // }; + + // let publication = publish_track.await; + // this.upgrade() + // .ok_or_else(|| anyhow!("room was dropped"))? 
+ // .update(&mut cx, |this, cx| { + // let live_kit = this + // .live_kit + // .as_mut() + // .ok_or_else(|| anyhow!("live-kit was not initialized"))?; + + // let (canceled, muted) = if let LocalTrack::Pending { + // publish_id: cur_publish_id, + // muted, + // } = &live_kit.microphone_track + // { + // (*cur_publish_id != publish_id, *muted) + // } else { + // (true, false) + // }; + + // match publication { + // Ok(publication) => { + // if canceled { + // live_kit.room.unpublish_track(publication); + // } else { + // if muted { + // cx.executor().spawn(publication.set_mute(muted)).detach(); + // } + // live_kit.microphone_track = LocalTrack::Published { + // track_publication: publication, + // muted, + // }; + // cx.notify(); + // } + // Ok(()) + // } + // Err(error) => { + // if canceled { + // Ok(()) + // } else { + // live_kit.microphone_track = LocalTrack::None; + // cx.notify(); + // Err(error) + // } + // } + // } + // })? + // }) + } + + pub fn share_screen(&mut self, cx: &mut ModelContext) -> Task> { + todo!() + // if self.status.is_offline() { + // return Task::ready(Err(anyhow!("room is offline"))); + // } else if self.is_screen_sharing() { + // return Task::ready(Err(anyhow!("screen was already shared"))); + // } + + // let (displays, publish_id) = if let Some(live_kit) = self.live_kit.as_mut() { + // let publish_id = post_inc(&mut live_kit.next_publish_id); + // live_kit.screen_track = LocalTrack::Pending { + // publish_id, + // muted: false, + // }; + // cx.notify(); + // (live_kit.room.display_sources(), publish_id) + // } else { + // return Task::ready(Err(anyhow!("live-kit was not initialized"))); + // }; + + // cx.spawn(move |this, mut cx| async move { + // let publish_track = async { + // let displays = displays.await?; + // let display = displays + // .first() + // .ok_or_else(|| anyhow!("no display found"))?; + // let track = LocalVideoTrack::screen_share_for_display(&display); + // this.upgrade() + // .ok_or_else(|| anyhow!("room was 
dropped"))? + // .update(&mut cx, |this, _| { + // this.live_kit + // .as_ref() + // .map(|live_kit| live_kit.room.publish_video_track(track)) + // })? + // .ok_or_else(|| anyhow!("live-kit was not initialized"))? + // .await + // }; + + // let publication = publish_track.await; + // this.upgrade() + // .ok_or_else(|| anyhow!("room was dropped"))? + // .update(&mut cx, |this, cx| { + // let live_kit = this + // .live_kit + // .as_mut() + // .ok_or_else(|| anyhow!("live-kit was not initialized"))?; + + // let (canceled, muted) = if let LocalTrack::Pending { + // publish_id: cur_publish_id, + // muted, + // } = &live_kit.screen_track + // { + // (*cur_publish_id != publish_id, *muted) + // } else { + // (true, false) + // }; + + // match publication { + // Ok(publication) => { + // if canceled { + // live_kit.room.unpublish_track(publication); + // } else { + // if muted { + // cx.executor().spawn(publication.set_mute(muted)).detach(); + // } + // live_kit.screen_track = LocalTrack::Published { + // track_publication: publication, + // muted, + // }; + // cx.notify(); + // } + + // Audio::play_sound(Sound::StartScreenshare, cx); + + // Ok(()) + // } + // Err(error) => { + // if canceled { + // Ok(()) + // } else { + // live_kit.screen_track = LocalTrack::None; + // cx.notify(); + // Err(error) + // } + // } + // } + // })? 
+ // }) + } + + pub fn toggle_mute(&mut self, cx: &mut ModelContext) -> Result>> { + todo!() + // let should_mute = !self.is_muted(cx); + // if let Some(live_kit) = self.live_kit.as_mut() { + // if matches!(live_kit.microphone_track, LocalTrack::None) { + // return Ok(self.share_microphone(cx)); + // } + + // let (ret_task, old_muted) = live_kit.set_mute(should_mute, cx)?; + // live_kit.muted_by_user = should_mute; + + // if old_muted == true && live_kit.deafened == true { + // if let Some(task) = self.toggle_deafen(cx).ok() { + // task.detach(); + // } + // } + + // Ok(ret_task) + // } else { + // Err(anyhow!("LiveKit not started")) + // } + } + + pub fn toggle_deafen(&mut self, cx: &mut ModelContext) -> Result>> { + todo!() + // if let Some(live_kit) = self.live_kit.as_mut() { + // (*live_kit).deafened = !live_kit.deafened; + + // let mut tasks = Vec::with_capacity(self.remote_participants.len()); + // // Context notification is sent within set_mute itself. + // let mut mute_task = None; + // // When deafening, mute user's mic as well. + // // When undeafening, unmute user's mic unless it was manually muted prior to deafening. 
+ // if live_kit.deafened || !live_kit.muted_by_user { + // mute_task = Some(live_kit.set_mute(live_kit.deafened, cx)?.0); + // }; + // for participant in self.remote_participants.values() { + // for track in live_kit + // .room + // .remote_audio_track_publications(&participant.user.id.to_string()) + // { + // let deafened = live_kit.deafened; + // tasks.push( + // cx.executor() + // .spawn_on_main(move || track.set_enabled(!deafened)), + // ); + // } + // } + + // Ok(cx.executor().spawn_on_main(|| async { + // if let Some(mute_task) = mute_task { + // mute_task.await?; + // } + // for task in tasks { + // task.await?; + // } + // Ok(()) + // })) + // } else { + // Err(anyhow!("LiveKit not started")) + // } + } + + pub fn unshare_screen(&mut self, cx: &mut ModelContext) -> Result<()> { + if self.status.is_offline() { + return Err(anyhow!("room is offline")); + } + + todo!() + // let live_kit = self + // .live_kit + // .as_mut() + // .ok_or_else(|| anyhow!("live-kit was not initialized"))?; + // match mem::take(&mut live_kit.screen_track) { + // LocalTrack::None => Err(anyhow!("screen was not shared")), + // LocalTrack::Pending { .. } => { + // cx.notify(); + // Ok(()) + // } + // LocalTrack::Published { + // track_publication, .. + // } => { + // live_kit.room.unpublish_track(track_publication); + // cx.notify(); + + // Audio::play_sound(Sound::StopScreenshare, cx); + // Ok(()) + // } + // } + } + + #[cfg(any(test, feature = "test-support"))] + pub fn set_display_sources(&self, sources: Vec) { + todo!() + // self.live_kit + // .as_ref() + // .unwrap() + // .room + // .set_display_sources(sources); + } +} + +struct LiveKitRoom { + room: Arc, + screen_track: LocalTrack, + microphone_track: LocalTrack, + /// Tracks whether we're currently in a muted state due to auto-mute from deafening or manual mute performed by user. 
+ muted_by_user: bool, + deafened: bool, + speaking: bool, + next_publish_id: usize, + _maintain_room: Task<()>, + _maintain_tracks: [Task<()>; 2], +} + +impl LiveKitRoom { + fn set_mute( + self: &mut LiveKitRoom, + should_mute: bool, + cx: &mut ModelContext, + ) -> Result<(Task>, bool)> { + if !should_mute { + // clear user muting state. + self.muted_by_user = false; + } + + let (result, old_muted) = match &mut self.microphone_track { + LocalTrack::None => Err(anyhow!("microphone was not shared")), + LocalTrack::Pending { muted, .. } => { + let old_muted = *muted; + *muted = should_mute; + cx.notify(); + Ok((Task::Ready(Some(Ok(()))), old_muted)) + } + LocalTrack::Published { + track_publication, + muted, + } => { + let old_muted = *muted; + *muted = should_mute; + cx.notify(); + Ok(( + cx.executor().spawn(track_publication.set_mute(*muted)), + old_muted, + )) + } + }?; + + if old_muted != should_mute { + if should_mute { + Audio::play_sound(Sound::Mute, cx); + } else { + Audio::play_sound(Sound::Unmute, cx); + } + } + + Ok((result, old_muted)) + } +} + +enum LocalTrack { + None, + Pending { + publish_id: usize, + muted: bool, + }, + Published { + track_publication: LocalTrackPublication, + muted: bool, + }, +} + +impl Default for LocalTrack { + fn default() -> Self { + Self::None + } +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum RoomStatus { + Online, + Rejoining, + Offline, +} + +impl RoomStatus { + pub fn is_offline(&self) -> bool { + matches!(self, RoomStatus::Offline) + } + + pub fn is_online(&self) -> bool { + matches!(self, RoomStatus::Online) + } +} diff --git a/crates/client2/Cargo.toml b/crates/client2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..8a6edbb4284eb92283f14194e7bbec5850fc12e7 --- /dev/null +++ b/crates/client2/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "client2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/client2.rs" +doctest = false + +[features] +test-support = 
["collections/test-support", "gpui2/test-support", "rpc2/test-support"] + +[dependencies] +collections = { path = "../collections" } +db2 = { path = "../db2" } +gpui2 = { path = "../gpui2" } +util = { path = "../util" } +rpc2 = { path = "../rpc2" } +text = { path = "../text" } +settings2 = { path = "../settings2" } +feature_flags2 = { path = "../feature_flags2" } +sum_tree = { path = "../sum_tree" } + +anyhow.workspace = true +async-recursion = "0.3" +async-tungstenite = { version = "0.16", features = ["async-tls"] } +futures.workspace = true +image = "0.23" +lazy_static.workspace = true +log.workspace = true +parking_lot.workspace = true +postage.workspace = true +rand.workspace = true +schemars.workspace = true +serde.workspace = true +serde_derive.workspace = true +smol.workspace = true +sysinfo.workspace = true +tempfile = "3" +thiserror.workspace = true +time.workspace = true +tiny_http = "0.8" +uuid.workspace = true +url = "2.2" + +[dev-dependencies] +collections = { path = "../collections", features = ["test-support"] } +gpui2 = { path = "../gpui2", features = ["test-support"] } +rpc2 = { path = "../rpc2", features = ["test-support"] } +settings = { path = "../settings", features = ["test-support"] } +util = { path = "../util", features = ["test-support"] } diff --git a/crates/client2/src/client2.rs b/crates/client2/src/client2.rs new file mode 100644 index 0000000000000000000000000000000000000000..19e8685c28cd55094b064ea0af60bbd6744fa475 --- /dev/null +++ b/crates/client2/src/client2.rs @@ -0,0 +1,1651 @@ +#[cfg(any(test, feature = "test-support"))] +pub mod test; + +pub mod telemetry; +pub mod user; + +use anyhow::{anyhow, Context as _, Result}; +use async_recursion::async_recursion; +use async_tungstenite::tungstenite::{ + error::Error as WebsocketError, + http::{Request, StatusCode}, +}; +use futures::{ + future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt, +}; +use gpui2::{ + serde_json, AnyModel, 
AnyWeakModel, AppContext, AsyncAppContext, Model, SemanticVersion, Task, + WeakModel, +}; +use lazy_static::lazy_static; +use parking_lot::RwLock; +use postage::watch; +use rand::prelude::*; +use rpc2::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, PeerId, RequestMessage}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use settings2::Settings; +use std::{ + any::TypeId, + collections::HashMap, + convert::TryFrom, + fmt::Write as _, + future::Future, + marker::PhantomData, + path::PathBuf, + sync::{atomic::AtomicU64, Arc, Weak}, + time::{Duration, Instant}, +}; +use telemetry::Telemetry; +use thiserror::Error; +use url::Url; +use util::channel::ReleaseChannel; +use util::http::HttpClient; +use util::{ResultExt, TryFutureExt}; + +pub use rpc2::*; +pub use telemetry::ClickhouseEvent; +pub use user::*; + +lazy_static! { + pub static ref ZED_SERVER_URL: String = + std::env::var("ZED_SERVER_URL").unwrap_or_else(|_| "https://zed.dev".to_string()); + pub static ref IMPERSONATE_LOGIN: Option = std::env::var("ZED_IMPERSONATE") + .ok() + .and_then(|s| if s.is_empty() { None } else { Some(s) }); + pub static ref ADMIN_API_TOKEN: Option = std::env::var("ZED_ADMIN_API_TOKEN") + .ok() + .and_then(|s| if s.is_empty() { None } else { Some(s) }); + pub static ref ZED_APP_VERSION: Option = std::env::var("ZED_APP_VERSION") + .ok() + .and_then(|v| v.parse().ok()); + pub static ref ZED_APP_PATH: Option = + std::env::var("ZED_APP_PATH").ok().map(PathBuf::from); + pub static ref ZED_ALWAYS_ACTIVE: bool = + std::env::var("ZED_ALWAYS_ACTIVE").map_or(false, |e| e.len() > 0); +} + +pub const ZED_SECRET_CLIENT_TOKEN: &str = "618033988749894"; +pub const INITIAL_RECONNECTION_DELAY: Duration = Duration::from_millis(100); +pub const CONNECTION_TIMEOUT: Duration = Duration::from_secs(5); + +#[derive(Clone, Default, PartialEq, Deserialize)] +pub struct SignIn; + +#[derive(Clone, Default, PartialEq, Deserialize)] +pub struct SignOut; + +#[derive(Clone, Default, 
PartialEq, Deserialize)] +pub struct Reconnect; + +pub fn init_settings(cx: &mut AppContext) { + TelemetrySettings::register(cx); +} + +pub fn init(client: &Arc, cx: &mut AppContext) { + init_settings(cx); + + let client = Arc::downgrade(client); + cx.register_action_type::(); + cx.on_action({ + let client = client.clone(); + move |_: &SignIn, cx| { + if let Some(client) = client.upgrade() { + cx.spawn( + |cx| async move { client.authenticate_and_connect(true, &cx).log_err().await }, + ) + .detach(); + } + } + }); + + cx.register_action_type::(); + cx.on_action({ + let client = client.clone(); + move |_: &SignOut, cx| { + if let Some(client) = client.upgrade() { + cx.spawn(|cx| async move { + client.disconnect(&cx); + }) + .detach(); + } + } + }); + + cx.register_action_type::(); + cx.on_action({ + let client = client.clone(); + move |_: &Reconnect, cx| { + if let Some(client) = client.upgrade() { + cx.spawn(|cx| async move { + client.reconnect(&cx); + }) + .detach(); + } + } + }); +} + +pub struct Client { + id: AtomicU64, + peer: Arc, + http: Arc, + telemetry: Arc, + state: RwLock, + + #[allow(clippy::type_complexity)] + #[cfg(any(test, feature = "test-support"))] + authenticate: RwLock< + Option Task>>>, + >, + + #[allow(clippy::type_complexity)] + #[cfg(any(test, feature = "test-support"))] + establish_connection: RwLock< + Option< + Box< + dyn 'static + + Send + + Sync + + Fn( + &Credentials, + &AsyncAppContext, + ) -> Task>, + >, + >, + >, +} + +#[derive(Error, Debug)] +pub enum EstablishConnectionError { + #[error("upgrade required")] + UpgradeRequired, + #[error("unauthorized")] + Unauthorized, + #[error("{0}")] + Other(#[from] anyhow::Error), + #[error("{0}")] + Http(#[from] util::http::Error), + #[error("{0}")] + Io(#[from] std::io::Error), + #[error("{0}")] + Websocket(#[from] async_tungstenite::tungstenite::http::Error), +} + +impl From for EstablishConnectionError { + fn from(error: WebsocketError) -> Self { + if let WebsocketError::Http(response) = 
&error { + match response.status() { + StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized, + StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired, + _ => {} + } + } + EstablishConnectionError::Other(error.into()) + } +} + +impl EstablishConnectionError { + pub fn other(error: impl Into + Send + Sync) -> Self { + Self::Other(error.into()) + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Status { + SignedOut, + UpgradeRequired, + Authenticating, + Connecting, + ConnectionError, + Connected { + peer_id: PeerId, + connection_id: ConnectionId, + }, + ConnectionLost, + Reauthenticating, + Reconnecting, + ReconnectionError { + next_reconnection: Instant, + }, +} + +impl Status { + pub fn is_connected(&self) -> bool { + matches!(self, Self::Connected { .. }) + } + + pub fn is_signed_out(&self) -> bool { + matches!(self, Self::SignedOut | Self::UpgradeRequired) + } +} + +struct ClientState { + credentials: Option, + status: (watch::Sender, watch::Receiver), + entity_id_extractors: HashMap u64>, + _reconnect_task: Option>, + reconnect_interval: Duration, + entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>, + models_by_message_type: HashMap, + entity_types_by_message_type: HashMap, + #[allow(clippy::type_complexity)] + message_handlers: HashMap< + TypeId, + Arc< + dyn Send + + Sync + + Fn( + AnyModel, + Box, + &Arc, + AsyncAppContext, + ) -> BoxFuture<'static, Result<()>>, + >, + >, +} + +enum WeakSubscriber { + Entity { handle: AnyWeakModel }, + Pending(Vec>), +} + +#[derive(Clone, Debug)] +pub struct Credentials { + pub user_id: u64, + pub access_token: String, +} + +impl Default for ClientState { + fn default() -> Self { + Self { + credentials: None, + status: watch::channel_with(Status::SignedOut), + entity_id_extractors: Default::default(), + _reconnect_task: None, + reconnect_interval: Duration::from_secs(5), + models_by_message_type: Default::default(), + entities_by_type_and_remote_id: 
Default::default(), + entity_types_by_message_type: Default::default(), + message_handlers: Default::default(), + } + } +} + +pub enum Subscription { + Entity { + client: Weak, + id: (TypeId, u64), + }, + Message { + client: Weak, + id: TypeId, + }, +} + +impl Drop for Subscription { + fn drop(&mut self) { + match self { + Subscription::Entity { client, id } => { + if let Some(client) = client.upgrade() { + let mut state = client.state.write(); + let _ = state.entities_by_type_and_remote_id.remove(id); + } + } + Subscription::Message { client, id } => { + if let Some(client) = client.upgrade() { + let mut state = client.state.write(); + let _ = state.entity_types_by_message_type.remove(id); + let _ = state.message_handlers.remove(id); + } + } + } + } +} + +pub struct PendingEntitySubscription { + client: Arc, + remote_id: u64, + _entity_type: PhantomData, + consumed: bool, +} + +impl PendingEntitySubscription +where + T: 'static + Send, +{ + pub fn set_model(mut self, model: &Model, cx: &mut AsyncAppContext) -> Subscription { + self.consumed = true; + let mut state = self.client.state.write(); + let id = (TypeId::of::(), self.remote_id); + let Some(WeakSubscriber::Pending(messages)) = + state.entities_by_type_and_remote_id.remove(&id) + else { + unreachable!() + }; + + state.entities_by_type_and_remote_id.insert( + id, + WeakSubscriber::Entity { + handle: model.downgrade().into(), + }, + ); + drop(state); + for message in messages { + self.client.handle_message(message, cx); + } + Subscription::Entity { + client: Arc::downgrade(&self.client), + id, + } + } +} + +impl Drop for PendingEntitySubscription +where + T: 'static, +{ + fn drop(&mut self) { + if !self.consumed { + let mut state = self.client.state.write(); + if let Some(WeakSubscriber::Pending(messages)) = state + .entities_by_type_and_remote_id + .remove(&(TypeId::of::(), self.remote_id)) + { + for message in messages { + log::info!("unhandled message {}", message.payload_type_name()); + } + } + } + } +} + 
+#[derive(Copy, Clone)] +pub struct TelemetrySettings { + pub diagnostics: bool, + pub metrics: bool, +} + +#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)] +pub struct TelemetrySettingsContent { + pub diagnostics: Option, + pub metrics: Option, +} + +impl settings2::Settings for TelemetrySettings { + const KEY: Option<&'static str> = Some("telemetry"); + + type FileContent = TelemetrySettingsContent; + + fn load( + default_value: &Self::FileContent, + user_values: &[&Self::FileContent], + _: &mut AppContext, + ) -> Result { + Ok(Self { + diagnostics: user_values.first().and_then(|v| v.diagnostics).unwrap_or( + default_value + .diagnostics + .ok_or_else(Self::missing_default)?, + ), + metrics: user_values + .first() + .and_then(|v| v.metrics) + .unwrap_or(default_value.metrics.ok_or_else(Self::missing_default)?), + }) + } +} + +impl Client { + pub fn new(http: Arc, cx: &AppContext) -> Arc { + Arc::new(Self { + id: AtomicU64::new(0), + peer: Peer::new(0), + telemetry: Telemetry::new(http.clone(), cx), + http, + state: Default::default(), + + #[cfg(any(test, feature = "test-support"))] + authenticate: Default::default(), + #[cfg(any(test, feature = "test-support"))] + establish_connection: Default::default(), + }) + } + + pub fn id(&self) -> u64 { + self.id.load(std::sync::atomic::Ordering::SeqCst) + } + + pub fn http_client(&self) -> Arc { + self.http.clone() + } + + pub fn set_id(&self, id: u64) -> &Self { + self.id.store(id, std::sync::atomic::Ordering::SeqCst); + self + } + + #[cfg(any(test, feature = "test-support"))] + pub fn teardown(&self) { + let mut state = self.state.write(); + state._reconnect_task.take(); + state.message_handlers.clear(); + state.models_by_message_type.clear(); + state.entities_by_type_and_remote_id.clear(); + state.entity_id_extractors.clear(); + self.peer.teardown(); + } + + #[cfg(any(test, feature = "test-support"))] + pub fn override_authenticate(&self, authenticate: F) -> &Self + where + F: 'static + Send + Sync + 
Fn(&AsyncAppContext) -> Task>, + { + *self.authenticate.write() = Some(Box::new(authenticate)); + self + } + + #[cfg(any(test, feature = "test-support"))] + pub fn override_establish_connection(&self, connect: F) -> &Self + where + F: 'static + + Send + + Sync + + Fn(&Credentials, &AsyncAppContext) -> Task>, + { + *self.establish_connection.write() = Some(Box::new(connect)); + self + } + + pub fn user_id(&self) -> Option { + self.state + .read() + .credentials + .as_ref() + .map(|credentials| credentials.user_id) + } + + pub fn peer_id(&self) -> Option { + if let Status::Connected { peer_id, .. } = &*self.status().borrow() { + Some(*peer_id) + } else { + None + } + } + + pub fn status(&self) -> watch::Receiver { + self.state.read().status.1.clone() + } + + fn set_status(self: &Arc, status: Status, cx: &AsyncAppContext) { + log::info!("set status on client {}: {:?}", self.id(), status); + let mut state = self.state.write(); + *state.status.0.borrow_mut() = status; + + match status { + Status::Connected { .. 
} => { + state._reconnect_task = None; + } + Status::ConnectionLost => { + let this = self.clone(); + let reconnect_interval = state.reconnect_interval; + state._reconnect_task = Some(cx.spawn(move |cx| async move { + #[cfg(any(test, feature = "test-support"))] + let mut rng = StdRng::seed_from_u64(0); + #[cfg(not(any(test, feature = "test-support")))] + let mut rng = StdRng::from_entropy(); + + let mut delay = INITIAL_RECONNECTION_DELAY; + while let Err(error) = this.authenticate_and_connect(true, &cx).await { + log::error!("failed to connect {}", error); + if matches!(*this.status().borrow(), Status::ConnectionError) { + this.set_status( + Status::ReconnectionError { + next_reconnection: Instant::now() + delay, + }, + &cx, + ); + cx.executor().timer(delay).await; + delay = delay + .mul_f32(rng.gen_range(1.0..=2.0)) + .min(reconnect_interval); + } else { + break; + } + } + })); + } + Status::SignedOut | Status::UpgradeRequired => { + cx.update(|cx| self.telemetry.set_authenticated_user_info(None, false, cx)) + .log_err(); + state._reconnect_task.take(); + } + _ => {} + } + } + + pub fn subscribe_to_entity( + self: &Arc, + remote_id: u64, + ) -> Result> + where + T: 'static + Send, + { + let id = (TypeId::of::(), remote_id); + + let mut state = self.state.write(); + if state.entities_by_type_and_remote_id.contains_key(&id) { + return Err(anyhow!("already subscribed to entity")); + } else { + state + .entities_by_type_and_remote_id + .insert(id, WeakSubscriber::Pending(Default::default())); + Ok(PendingEntitySubscription { + client: self.clone(), + remote_id, + consumed: false, + _entity_type: PhantomData, + }) + } + } + + #[track_caller] + pub fn add_message_handler( + self: &Arc, + entity: WeakModel, + handler: H, + ) -> Subscription + where + M: EnvelopedMessage, + E: 'static + Send, + H: 'static + Send + Sync + Fn(Model, TypedEnvelope, Arc, AsyncAppContext) -> F, + F: 'static + Future> + Send, + { + let message_type_id = TypeId::of::(); + + let mut state = 
self.state.write(); + state + .models_by_message_type + .insert(message_type_id, entity.into()); + + let prev_handler = state.message_handlers.insert( + message_type_id, + Arc::new(move |subscriber, envelope, client, cx| { + let subscriber = subscriber.downcast::().unwrap(); + let envelope = envelope.into_any().downcast::>().unwrap(); + handler(subscriber, *envelope, client.clone(), cx).boxed() + }), + ); + if prev_handler.is_some() { + let location = std::panic::Location::caller(); + panic!( + "{}:{} registered handler for the same message {} twice", + location.file(), + location.line(), + std::any::type_name::() + ); + } + + Subscription::Message { + client: Arc::downgrade(self), + id: message_type_id, + } + } + + pub fn add_request_handler( + self: &Arc, + model: WeakModel, + handler: H, + ) -> Subscription + where + M: RequestMessage, + E: 'static + Send, + H: 'static + Send + Sync + Fn(Model, TypedEnvelope, Arc, AsyncAppContext) -> F, + F: 'static + Future> + Send, + { + self.add_message_handler(model, move |handle, envelope, this, cx| { + Self::respond_to_request( + envelope.receipt(), + handler(handle, envelope, this.clone(), cx), + this, + ) + }) + } + + pub fn add_model_message_handler(self: &Arc, handler: H) + where + M: EntityMessage, + E: 'static + Send, + H: 'static + Send + Sync + Fn(Model, TypedEnvelope, Arc, AsyncAppContext) -> F, + F: 'static + Future> + Send, + { + self.add_entity_message_handler::(move |subscriber, message, client, cx| { + handler(subscriber.downcast::().unwrap(), message, client, cx) + }) + } + + fn add_entity_message_handler(self: &Arc, handler: H) + where + M: EntityMessage, + E: 'static + Send, + H: 'static + Send + Sync + Fn(AnyModel, TypedEnvelope, Arc, AsyncAppContext) -> F, + F: 'static + Future> + Send, + { + let model_type_id = TypeId::of::(); + let message_type_id = TypeId::of::(); + + let mut state = self.state.write(); + state + .entity_types_by_message_type + .insert(message_type_id, model_type_id); + state + 
.entity_id_extractors + .entry(message_type_id) + .or_insert_with(|| { + |envelope| { + envelope + .as_any() + .downcast_ref::>() + .unwrap() + .payload + .remote_entity_id() + } + }); + let prev_handler = state.message_handlers.insert( + message_type_id, + Arc::new(move |handle, envelope, client, cx| { + let envelope = envelope.into_any().downcast::>().unwrap(); + handler(handle, *envelope, client.clone(), cx).boxed() + }), + ); + if prev_handler.is_some() { + panic!("registered handler for the same message twice"); + } + } + + pub fn add_model_request_handler(self: &Arc, handler: H) + where + M: EntityMessage + RequestMessage, + E: 'static + Send, + H: 'static + Send + Sync + Fn(Model, TypedEnvelope, Arc, AsyncAppContext) -> F, + F: 'static + Future> + Send, + { + self.add_model_message_handler(move |entity, envelope, client, cx| { + Self::respond_to_request::( + envelope.receipt(), + handler(entity, envelope, client.clone(), cx), + client, + ) + }) + } + + async fn respond_to_request>>( + receipt: Receipt, + response: F, + client: Arc, + ) -> Result<()> { + match response.await { + Ok(response) => { + client.respond(receipt, response)?; + Ok(()) + } + Err(error) => { + client.respond_with_error( + receipt, + proto::Error { + message: format!("{:?}", error), + }, + )?; + Err(error) + } + } + } + + pub async fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool { + read_credentials_from_keychain(cx).await.is_some() + } + + #[async_recursion] + pub async fn authenticate_and_connect( + self: &Arc, + try_keychain: bool, + cx: &AsyncAppContext, + ) -> anyhow::Result<()> { + let was_disconnected = match *self.status().borrow() { + Status::SignedOut => true, + Status::ConnectionError + | Status::ConnectionLost + | Status::Authenticating { .. } + | Status::Reauthenticating { .. } + | Status::ReconnectionError { .. } => false, + Status::Connected { .. } | Status::Connecting { .. } | Status::Reconnecting { .. 
} => { + return Ok(()) + } + Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?, + }; + + if was_disconnected { + self.set_status(Status::Authenticating, cx); + } else { + self.set_status(Status::Reauthenticating, cx) + } + + let mut read_from_keychain = false; + let mut credentials = self.state.read().credentials.clone(); + if credentials.is_none() && try_keychain { + credentials = read_credentials_from_keychain(cx).await; + read_from_keychain = credentials.is_some(); + } + if credentials.is_none() { + let mut status_rx = self.status(); + let _ = status_rx.next().await; + futures::select_biased! { + authenticate = self.authenticate(cx).fuse() => { + match authenticate { + Ok(creds) => credentials = Some(creds), + Err(err) => { + self.set_status(Status::ConnectionError, cx); + return Err(err); + } + } + } + _ = status_rx.next().fuse() => { + return Err(anyhow!("authentication canceled")); + } + } + } + let credentials = credentials.unwrap(); + self.set_id(credentials.user_id); + + if was_disconnected { + self.set_status(Status::Connecting, cx); + } else { + self.set_status(Status::Reconnecting, cx); + } + + let mut timeout = futures::FutureExt::fuse(cx.executor().timer(CONNECTION_TIMEOUT)); + futures::select_biased! { + connection = self.establish_connection(&credentials, cx).fuse() => { + match connection { + Ok(conn) => { + self.state.write().credentials = Some(credentials.clone()); + if !read_from_keychain && IMPERSONATE_LOGIN.is_none() { + write_credentials_to_keychain(credentials, cx).log_err(); + } + + futures::select_biased! 
{ + result = self.set_connection(conn, cx).fuse() => result, + _ = timeout => { + self.set_status(Status::ConnectionError, cx); + Err(anyhow!("timed out waiting on hello message from server")) + } + } + } + Err(EstablishConnectionError::Unauthorized) => { + self.state.write().credentials.take(); + if read_from_keychain { + delete_credentials_from_keychain(cx).log_err(); + self.set_status(Status::SignedOut, cx); + self.authenticate_and_connect(false, cx).await + } else { + self.set_status(Status::ConnectionError, cx); + Err(EstablishConnectionError::Unauthorized)? + } + } + Err(EstablishConnectionError::UpgradeRequired) => { + self.set_status(Status::UpgradeRequired, cx); + Err(EstablishConnectionError::UpgradeRequired)? + } + Err(error) => { + self.set_status(Status::ConnectionError, cx); + Err(error)? + } + } + } + _ = &mut timeout => { + self.set_status(Status::ConnectionError, cx); + Err(anyhow!("timed out trying to establish connection")) + } + } + } + + async fn set_connection( + self: &Arc, + conn: Connection, + cx: &AsyncAppContext, + ) -> Result<()> { + let executor = cx.executor(); + log::info!("add connection to peer"); + let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn, { + let executor = executor.clone(); + move |duration| executor.timer(duration) + }); + let handle_io = executor.spawn(handle_io); + + let peer_id = async { + log::info!("waiting for server hello"); + let message = incoming + .next() + .await + .ok_or_else(|| anyhow!("no hello message received"))?; + log::info!("got server hello"); + let hello_message_type_name = message.payload_type_name().to_string(); + let hello = message + .into_any() + .downcast::>() + .map_err(|_| { + anyhow!( + "invalid hello message received: {:?}", + hello_message_type_name + ) + })?; + let peer_id = hello + .payload + .peer_id + .ok_or_else(|| anyhow!("invalid peer id"))?; + Ok(peer_id) + }; + + let peer_id = match peer_id.await { + Ok(peer_id) => peer_id, + Err(error) => { + 
self.peer.disconnect(connection_id); + return Err(error); + } + }; + + log::info!( + "set status to connected (connection id: {:?}, peer id: {:?})", + connection_id, + peer_id + ); + self.set_status( + Status::Connected { + peer_id, + connection_id, + }, + cx, + ); + + cx.spawn({ + let this = self.clone(); + |cx| { + async move { + while let Some(message) = incoming.next().await { + this.handle_message(message, &cx); + // Don't starve the main thread when receiving lots of messages at once. + smol::future::yield_now().await; + } + } + } + }) + .detach(); + + cx.spawn({ + let this = self.clone(); + move |cx| async move { + match handle_io.await { + Ok(()) => { + if this.status().borrow().clone() + == (Status::Connected { + connection_id, + peer_id, + }) + { + this.set_status(Status::SignedOut, &cx); + } + } + Err(err) => { + log::error!("connection error: {:?}", err); + this.set_status(Status::ConnectionLost, &cx); + } + } + } + }) + .detach(); + + Ok(()) + } + + fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { + #[cfg(any(test, feature = "test-support"))] + if let Some(callback) = self.authenticate.read().as_ref() { + return callback(cx); + } + + self.authenticate_with_browser(cx) + } + + fn establish_connection( + self: &Arc, + credentials: &Credentials, + cx: &AsyncAppContext, + ) -> Task> { + #[cfg(any(test, feature = "test-support"))] + if let Some(callback) = self.establish_connection.read().as_ref() { + return callback(credentials, cx); + } + + self.establish_websocket_connection(credentials, cx) + } + + async fn get_rpc_url(http: Arc, is_preview: bool) -> Result { + let preview_param = if is_preview { "?preview=1" } else { "" }; + let url = format!("{}/rpc{preview_param}", *ZED_SERVER_URL); + let response = http.get(&url, Default::default(), false).await?; + + // Normally, ZED_SERVER_URL is set to the URL of zed.dev website. 
// For testing purposes, ZED_SERVER_URL can also be set to the direct URL +        // of a collab server.
"http" => { + rpc_url.set_scheme("ws").unwrap(); + let request = request.uri(rpc_url.as_str()).body(())?; + let (stream, _) = async_tungstenite::client_async(request, stream).await?; + Ok(Connection::new( + stream + .map_err(|error| anyhow!(error)) + .sink_map_err(|error| anyhow!(error)), + )) + } + _ => Err(anyhow!("invalid rpc url: {}", rpc_url))?, + } + }) + } + + pub fn authenticate_with_browser( + self: &Arc, + cx: &AsyncAppContext, + ) -> Task> { + let http = self.http.clone(); + cx.spawn(|cx| async move { + // Generate a pair of asymmetric encryption keys. The public key will be used by the + // zed server to encrypt the user's access token, so that it can'be intercepted by + // any other app running on the user's device. + let (public_key, private_key) = + rpc2::auth::keypair().expect("failed to generate keypair for auth"); + let public_key_string = + String::try_from(public_key).expect("failed to serialize public key for auth"); + + if let Some((login, token)) = IMPERSONATE_LOGIN.as_ref().zip(ADMIN_API_TOKEN.as_ref()) { + return Self::authenticate_as_admin(http, login.clone(), token.clone()).await; + } + + // Start an HTTP server to receive the redirect from Zed's sign-in page. + let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port"); + let port = server.server_addr().port(); + + // Open the Zed sign-in page in the user's browser, with query parameters that indicate + // that the user is signing in from a Zed app running on the same device. + let mut url = format!( + "{}/native_app_signin?native_app_port={}&native_app_public_key={}", + *ZED_SERVER_URL, port, public_key_string + ); + + if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() { + log::info!("impersonating user @{}", impersonate_login); + write!(&mut url, "&impersonate={}", impersonate_login).unwrap(); + } + + cx.run_on_main(move |cx| cx.open_url(&url))?.await; + + // Receive the HTTP request from the user's browser. 
Retrieve the user id and encrypted + // access token from the query params. + // + // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a + // custom URL scheme instead of this local HTTP server. + let (user_id, access_token) = cx + .spawn(|_| async move { + for _ in 0..100 { + if let Some(req) = server.recv_timeout(Duration::from_secs(1))? { + let path = req.url(); + let mut user_id = None; + let mut access_token = None; + let url = Url::parse(&format!("http://example.com{}", path)) + .context("failed to parse login notification url")?; + for (key, value) in url.query_pairs() { + if key == "access_token" { + access_token = Some(value.to_string()); + } else if key == "user_id" { + user_id = Some(value.to_string()); + } + } + + let post_auth_url = + format!("{}/native_app_signin_succeeded", *ZED_SERVER_URL); + req.respond( + tiny_http::Response::empty(302).with_header( + tiny_http::Header::from_bytes( + &b"Location"[..], + post_auth_url.as_bytes(), + ) + .unwrap(), + ), + ) + .context("failed to respond to login http request")?; + return Ok(( + user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?, + access_token + .ok_or_else(|| anyhow!("missing access_token parameter"))?, + )); + } + } + + Err(anyhow!("didn't receive login redirect")) + }) + .await?; + + let access_token = private_key + .decrypt_string(&access_token) + .context("failed to decrypt access token")?; + cx.run_on_main(|cx| cx.activate(true))?.await; + + Ok(Credentials { + user_id: user_id.parse()?, + access_token, + }) + }) + } + + async fn authenticate_as_admin( + http: Arc, + login: String, + mut api_token: String, + ) -> Result { + #[derive(Deserialize)] + struct AuthenticatedUserResponse { + user: User, + } + + #[derive(Deserialize)] + struct User { + id: u64, + } + + // Use the collab server's admin API to retrieve the id + // of the impersonated user. 
+ let mut url = Self::get_rpc_url(http.clone(), false).await?; + url.set_path("/user"); + url.set_query(Some(&format!("github_login={login}"))); + let request = Request::get(url.as_str()) + .header("Authorization", format!("token {api_token}")) + .body("".into())?; + + let mut response = http.send(request).await?; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + if !response.status().is_success() { + Err(anyhow!( + "admin user request failed {} - {}", + response.status().as_u16(), + body, + ))?; + } + let response: AuthenticatedUserResponse = serde_json::from_str(&body)?; + + // Use the admin API token to authenticate as the impersonated user. + api_token.insert_str(0, "ADMIN_TOKEN:"); + Ok(Credentials { + user_id: response.user.id, + access_token: api_token, + }) + } + + pub fn disconnect(self: &Arc, cx: &AsyncAppContext) { + self.peer.teardown(); + self.set_status(Status::SignedOut, cx); + } + + pub fn reconnect(self: &Arc, cx: &AsyncAppContext) { + self.peer.teardown(); + self.set_status(Status::ConnectionLost, cx); + } + + fn connection_id(&self) -> Result { + if let Status::Connected { connection_id, .. } = *self.status().borrow() { + Ok(connection_id) + } else { + Err(anyhow!("not connected")) + } + } + + pub fn send(&self, message: T) -> Result<()> { + log::debug!("rpc send. client_id:{}, name:{}", self.id(), T::NAME); + self.peer.send(self.connection_id()?, message) + } + + pub fn request( + &self, + request: T, + ) -> impl Future> { + self.request_envelope(request) + .map_ok(|envelope| envelope.payload) + } + + pub fn request_envelope( + &self, + request: T, + ) -> impl Future>> { + let client_id = self.id(); + log::debug!( + "rpc request start. client_id:{}. name:{}", + client_id, + T::NAME + ); + let response = self + .connection_id() + .map(|conn_id| self.peer.request_envelope(conn_id, request)); + async move { + let response = response?.await; + log::debug!( + "rpc request finish. client_id:{}. 
name:{}", + client_id, + T::NAME + ); + response + } + } + + fn respond(&self, receipt: Receipt, response: T::Response) -> Result<()> { + log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME); + self.peer.respond(receipt, response) + } + + fn respond_with_error( + &self, + receipt: Receipt, + error: proto::Error, + ) -> Result<()> { + log::debug!("rpc respond. client_id:{}. name:{}", self.id(), T::NAME); + self.peer.respond_with_error(receipt, error) + } + + fn handle_message( + self: &Arc, + message: Box, + cx: &AsyncAppContext, + ) { + let mut state = self.state.write(); + let type_name = message.payload_type_name(); + let payload_type_id = message.payload_type_id(); + let sender_id = message.original_sender_id(); + + let mut subscriber = None; + + if let Some(handle) = state + .models_by_message_type + .get(&payload_type_id) + .and_then(|handle| handle.upgrade()) + { + subscriber = Some(handle); + } else if let Some((extract_entity_id, entity_type_id)) = + state.entity_id_extractors.get(&payload_type_id).zip( + state + .entity_types_by_message_type + .get(&payload_type_id) + .copied(), + ) + { + let entity_id = (extract_entity_id)(message.as_ref()); + + match state + .entities_by_type_and_remote_id + .get_mut(&(entity_type_id, entity_id)) + { + Some(WeakSubscriber::Pending(pending)) => { + pending.push(message); + return; + } + Some(weak_subscriber @ _) => match weak_subscriber { + WeakSubscriber::Entity { handle } => { + subscriber = handle.upgrade(); + } + + WeakSubscriber::Pending(_) => {} + }, + _ => {} + } + } + + let subscriber = if let Some(subscriber) = subscriber { + subscriber + } else { + log::info!("unhandled message {}", type_name); + self.peer.respond_with_unhandled_message(message).log_err(); + return; + }; + + let handler = state.message_handlers.get(&payload_type_id).cloned(); + // Dropping the state prevents deadlocks if the handler interacts with rpc::Client. 
+ // It also ensures we don't hold the lock while yielding back to the executor, as + // that might cause the executor thread driving this future to block indefinitely. + drop(state); + + if let Some(handler) = handler { + let future = handler(subscriber, message, &self, cx.clone()); + let client_id = self.id(); + log::debug!( + "rpc message received. client_id:{}, sender_id:{:?}, type:{}", + client_id, + sender_id, + type_name + ); + cx.spawn_on_main(move |_| async move { + match future.await { + Ok(()) => { + log::debug!( + "rpc message handled. client_id:{}, sender_id:{:?}, type:{}", + client_id, + sender_id, + type_name + ); + } + Err(error) => { + log::error!( + "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}", + client_id, + sender_id, + type_name, + error + ); + } + } + }) + .detach(); + } else { + log::info!("unhandled message {}", type_name); + self.peer.respond_with_unhandled_message(message).log_err(); + } + } + + pub fn telemetry(&self) -> &Arc { + &self.telemetry + } +} + +async fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option { + if IMPERSONATE_LOGIN.is_some() { + return None; + } + + let (user_id, access_token) = cx + .run_on_main(|cx| cx.read_credentials(&ZED_SERVER_URL).log_err().flatten()) + .ok()? + .await?; + + Some(Credentials { + user_id: user_id.parse().ok()?, + access_token: String::from_utf8(access_token).ok()?, + }) +} + +async fn write_credentials_to_keychain( + credentials: Credentials, + cx: &AsyncAppContext, +) -> Result<()> { + cx.run_on_main(move |cx| { + cx.write_credentials( + &ZED_SERVER_URL, + &credentials.user_id.to_string(), + credentials.access_token.as_bytes(), + ) + })? + .await +} + +async fn delete_credentials_from_keychain(cx: &AsyncAppContext) -> Result<()> { + cx.run_on_main(move |cx| cx.delete_credentials(&ZED_SERVER_URL))? 
+ .await +} + +const WORKTREE_URL_PREFIX: &str = "zed://worktrees/"; + +pub fn encode_worktree_url(id: u64, access_token: &str) -> String { + format!("{}{}/{}", WORKTREE_URL_PREFIX, id, access_token) +} + +pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> { + let path = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?; + let mut parts = path.split('/'); + let id = parts.next()?.parse::().ok()?; + let access_token = parts.next()?; + if access_token.is_empty() { + return None; + } + Some((id, access_token.to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::FakeServer; + + use gpui2::{Context, Executor, TestAppContext}; + use parking_lot::Mutex; + use std::future; + use util::http::FakeHttpClient; + + #[gpui2::test(iterations = 10)] + async fn test_reconnection(cx: &mut TestAppContext) { + let user_id = 5; + let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx)); + let server = FakeServer::for_client(user_id, &client, cx).await; + let mut status = client.status(); + assert!(matches!( + status.next().await, + Some(Status::Connected { .. }) + )); + assert_eq!(server.auth_count(), 1); + + server.forbid_connections(); + server.disconnect(); + while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} + + server.allow_connections(); + cx.executor().advance_clock(Duration::from_secs(10)); + while !matches!(status.next().await, Some(Status::Connected { .. })) {} + assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting + + server.forbid_connections(); + server.disconnect(); + while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} + + // Clear cached credentials after authentication fails + server.roll_access_token(); + server.allow_connections(); + cx.executor().run_until_parked(); + cx.executor().advance_clock(Duration::from_secs(10)); + while !matches!(status.next().await, Some(Status::Connected { .. 
})) {} + assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token + } + + #[gpui2::test(iterations = 10)] + async fn test_connection_timeout(executor: Executor, cx: &mut TestAppContext) { + let user_id = 5; + let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx)); + let mut status = client.status(); + + // Time out when client tries to connect. + client.override_authenticate(move |cx| { + cx.executor().spawn(async move { + Ok(Credentials { + user_id, + access_token: "token".into(), + }) + }) + }); + client.override_establish_connection(|_, cx| { + cx.executor().spawn(async move { + future::pending::<()>().await; + unreachable!() + }) + }); + let auth_and_connect = cx.spawn({ + let client = client.clone(); + |cx| async move { client.authenticate_and_connect(false, &cx).await } + }); + executor.run_until_parked(); + assert!(matches!(status.next().await, Some(Status::Connecting))); + + executor.advance_clock(CONNECTION_TIMEOUT); + assert!(matches!( + status.next().await, + Some(Status::ConnectionError { .. }) + )); + auth_and_connect.await.unwrap_err(); + + // Allow the connection to be established. + let server = FakeServer::for_client(user_id, &client, cx).await; + assert!(matches!( + status.next().await, + Some(Status::Connected { .. }) + )); + + // Disconnect client. + server.forbid_connections(); + server.disconnect(); + while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} + + // Time out when re-establishing the connection. + server.allow_connections(); + client.override_establish_connection(|_, cx| { + cx.executor().spawn(async move { + future::pending::<()>().await; + unreachable!() + }) + }); + executor.advance_clock(2 * INITIAL_RECONNECTION_DELAY); + assert!(matches!( + status.next().await, + Some(Status::Reconnecting { .. }) + )); + + executor.advance_clock(CONNECTION_TIMEOUT); + assert!(matches!( + status.next().await, + Some(Status::ReconnectionError { .. 
}) + )); + } + + #[gpui2::test(iterations = 10)] + async fn test_authenticating_more_than_once(cx: &mut TestAppContext, executor: Executor) { + let auth_count = Arc::new(Mutex::new(0)); + let dropped_auth_count = Arc::new(Mutex::new(0)); + let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx)); + client.override_authenticate({ + let auth_count = auth_count.clone(); + let dropped_auth_count = dropped_auth_count.clone(); + move |cx| { + let auth_count = auth_count.clone(); + let dropped_auth_count = dropped_auth_count.clone(); + cx.executor().spawn(async move { + *auth_count.lock() += 1; + let _drop = util::defer(move || *dropped_auth_count.lock() += 1); + future::pending::<()>().await; + unreachable!() + }) + } + }); + + let _authenticate = cx.spawn({ + let client = client.clone(); + move |cx| async move { client.authenticate_and_connect(false, &cx).await } + }); + executor.run_until_parked(); + assert_eq!(*auth_count.lock(), 1); + assert_eq!(*dropped_auth_count.lock(), 0); + + let _authenticate = cx.spawn({ + let client = client.clone(); + |cx| async move { client.authenticate_and_connect(false, &cx).await } + }); + executor.run_until_parked(); + assert_eq!(*auth_count.lock(), 2); + assert_eq!(*dropped_auth_count.lock(), 1); + } + + #[test] + fn test_encode_and_decode_worktree_url() { + let url = encode_worktree_url(5, "deadbeef"); + assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string()))); + assert_eq!( + decode_worktree_url(&format!("\n {}\t", url)), + Some((5, "deadbeef".to_string())) + ); + assert_eq!(decode_worktree_url("not://the-right-format"), None); + } + + #[gpui2::test] + async fn test_subscribing_to_entity(cx: &mut TestAppContext) { + let user_id = 5; + let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx)); + let server = FakeServer::for_client(user_id, &client, cx).await; + + let (done_tx1, mut done_rx1) = smol::channel::unbounded(); + let (done_tx2, mut done_rx2) = 
smol::channel::unbounded(); + client.add_model_message_handler( + move |model: Model, _: TypedEnvelope, _, mut cx| { + match model.update(&mut cx, |model, _| model.id).unwrap() { + 1 => done_tx1.try_send(()).unwrap(), + 2 => done_tx2.try_send(()).unwrap(), + _ => unreachable!(), + } + async { Ok(()) } + }, + ); + let model1 = cx.build_model(|_| TestModel { + id: 1, + subscription: None, + }); + let model2 = cx.build_model(|_| TestModel { + id: 2, + subscription: None, + }); + let model3 = cx.build_model(|_| TestModel { + id: 3, + subscription: None, + }); + + let _subscription1 = client + .subscribe_to_entity(1) + .unwrap() + .set_model(&model1, &mut cx.to_async()); + let _subscription2 = client + .subscribe_to_entity(2) + .unwrap() + .set_model(&model2, &mut cx.to_async()); + // Ensure dropping a subscription for the same entity type still allows receiving of + // messages for other entity IDs of the same type. + let subscription3 = client + .subscribe_to_entity(3) + .unwrap() + .set_model(&model3, &mut cx.to_async()); + drop(subscription3); + + server.send(proto::JoinProject { project_id: 1 }); + server.send(proto::JoinProject { project_id: 2 }); + done_rx1.next().await.unwrap(); + done_rx2.next().await.unwrap(); + } + + #[gpui2::test] + async fn test_subscribing_after_dropping_subscription(cx: &mut TestAppContext) { + let user_id = 5; + let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx)); + let server = FakeServer::for_client(user_id, &client, cx).await; + + let model = cx.build_model(|_| TestModel::default()); + let (done_tx1, _done_rx1) = smol::channel::unbounded(); + let (done_tx2, mut done_rx2) = smol::channel::unbounded(); + let subscription1 = client.add_message_handler( + model.downgrade(), + move |_, _: TypedEnvelope, _, _| { + done_tx1.try_send(()).unwrap(); + async { Ok(()) } + }, + ); + drop(subscription1); + let _subscription2 = client.add_message_handler( + model.downgrade(), + move |_, _: TypedEnvelope, _, _| { + 
done_tx2.try_send(()).unwrap(); + async { Ok(()) } + }, + ); + server.send(proto::Ping {}); + done_rx2.next().await.unwrap(); + } + + #[gpui2::test] + async fn test_dropping_subscription_in_handler(cx: &mut TestAppContext) { + let user_id = 5; + let client = cx.update(|cx| Client::new(FakeHttpClient::with_404_response(), cx)); + let server = FakeServer::for_client(user_id, &client, cx).await; + + let model = cx.build_model(|_| TestModel::default()); + let (done_tx, mut done_rx) = smol::channel::unbounded(); + let subscription = client.add_message_handler( + model.clone().downgrade(), + move |model: Model, _: TypedEnvelope, _, mut cx| { + model + .update(&mut cx, |model, _| model.subscription.take()) + .unwrap(); + done_tx.try_send(()).unwrap(); + async { Ok(()) } + }, + ); + model.update(cx, |model, _| { + model.subscription = Some(subscription); + }); + server.send(proto::Ping {}); + done_rx.next().await.unwrap(); + } + + #[derive(Default)] + struct TestModel { + id: usize, + subscription: Option, + } +} diff --git a/crates/client2/src/telemetry.rs b/crates/client2/src/telemetry.rs new file mode 100644 index 0000000000000000000000000000000000000000..47d1c143e120b1c4fa3b8f180e00ab240e76f1d9 --- /dev/null +++ b/crates/client2/src/telemetry.rs @@ -0,0 +1,333 @@ +use crate::{TelemetrySettings, ZED_SECRET_CLIENT_TOKEN, ZED_SERVER_URL}; +use gpui2::{serde_json, AppContext, AppMetadata, Executor, Task}; +use lazy_static::lazy_static; +use parking_lot::Mutex; +use serde::Serialize; +use settings2::Settings; +use std::{env, io::Write, mem, path::PathBuf, sync::Arc, time::Duration}; +use sysinfo::{ + CpuRefreshKind, Pid, PidExt, ProcessExt, ProcessRefreshKind, RefreshKind, System, SystemExt, +}; +use tempfile::NamedTempFile; +use util::http::HttpClient; +use util::{channel::ReleaseChannel, TryFutureExt}; + +pub struct Telemetry { + http_client: Arc, + executor: Executor, + state: Mutex, +} + +struct TelemetryState { + metrics_id: Option>, // Per logged-in user + 
installation_id: Option>, // Per app installation (different for dev, preview, and stable) + session_id: Option>, // Per app launch + release_channel: Option<&'static str>, + app_metadata: AppMetadata, + architecture: &'static str, + clickhouse_events_queue: Vec, + flush_clickhouse_events_task: Option>, + log_file: Option, + is_staff: Option, +} + +const CLICKHOUSE_EVENTS_URL_PATH: &'static str = "/api/events"; + +lazy_static! { + static ref CLICKHOUSE_EVENTS_URL: String = + format!("{}{}", *ZED_SERVER_URL, CLICKHOUSE_EVENTS_URL_PATH); +} + +#[derive(Serialize, Debug)] +struct ClickhouseEventRequestBody { + token: &'static str, + installation_id: Option>, + session_id: Option>, + is_staff: Option, + app_version: Option, + os_name: &'static str, + os_version: Option, + architecture: &'static str, + release_channel: Option<&'static str>, + events: Vec, +} + +#[derive(Serialize, Debug)] +struct ClickhouseEventWrapper { + signed_in: bool, + #[serde(flatten)] + event: ClickhouseEvent, +} + +#[derive(Serialize, Debug)] +#[serde(rename_all = "snake_case")] +pub enum AssistantKind { + Panel, + Inline, +} + +#[derive(Serialize, Debug)] +#[serde(tag = "type")] +pub enum ClickhouseEvent { + Editor { + operation: &'static str, + file_extension: Option, + vim_mode: bool, + copilot_enabled: bool, + copilot_enabled_for_language: bool, + }, + Copilot { + suggestion_id: Option, + suggestion_accepted: bool, + file_extension: Option, + }, + Call { + operation: &'static str, + room_id: Option, + channel_id: Option, + }, + Assistant { + conversation_id: Option, + kind: AssistantKind, + model: &'static str, + }, + Cpu { + usage_as_percentage: f32, + core_count: u32, + }, + Memory { + memory_in_bytes: u64, + virtual_memory_in_bytes: u64, + }, +} + +#[cfg(debug_assertions)] +const MAX_QUEUE_LEN: usize = 1; + +#[cfg(not(debug_assertions))] +const MAX_QUEUE_LEN: usize = 10; + +#[cfg(debug_assertions)] +const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(1); + 
+#[cfg(not(debug_assertions))] +const DEBOUNCE_INTERVAL: Duration = Duration::from_secs(30); + +impl Telemetry { + pub fn new(client: Arc, cx: &AppContext) -> Arc { + let release_channel = if cx.has_global::() { + Some(cx.global::().display_name()) + } else { + None + }; + // TODO: Replace all hardware stuff with nested SystemSpecs json + let this = Arc::new(Self { + http_client: client, + executor: cx.executor().clone(), + state: Mutex::new(TelemetryState { + app_metadata: cx.app_metadata(), + architecture: env::consts::ARCH, + release_channel, + installation_id: None, + metrics_id: None, + session_id: None, + clickhouse_events_queue: Default::default(), + flush_clickhouse_events_task: Default::default(), + log_file: None, + is_staff: None, + }), + }); + + this + } + + pub fn log_file_path(&self) -> Option { + Some(self.state.lock().log_file.as_ref()?.path().to_path_buf()) + } + + pub fn start( + self: &Arc, + installation_id: Option, + session_id: String, + cx: &mut AppContext, + ) { + let mut state = self.state.lock(); + state.installation_id = installation_id.map(|id| id.into()); + state.session_id = Some(session_id.into()); + let has_clickhouse_events = !state.clickhouse_events_queue.is_empty(); + drop(state); + + if has_clickhouse_events { + self.flush_clickhouse_events(); + } + + let this = self.clone(); + cx.spawn(|cx| async move { + // Avoiding calling `System::new_all()`, as there have been crashes related to it + let refresh_kind = RefreshKind::new() + .with_memory() // For memory usage + .with_processes(ProcessRefreshKind::everything()) // For process usage + .with_cpu(CpuRefreshKind::everything()); // For core count + + let mut system = System::new_with_specifics(refresh_kind); + + // Avoiding calling `refresh_all()`, just update what we need + system.refresh_specifics(refresh_kind); + + loop { + // Waiting some amount of time before the first query is important to get a reasonable value + // 
https://docs.rs/sysinfo/0.29.10/sysinfo/trait.ProcessExt.html#tymethod.cpu_usage + const DURATION_BETWEEN_SYSTEM_EVENTS: Duration = Duration::from_secs(60); + smol::Timer::after(DURATION_BETWEEN_SYSTEM_EVENTS).await; + + system.refresh_specifics(refresh_kind); + + let current_process = Pid::from_u32(std::process::id()); + let Some(process) = system.processes().get(¤t_process) else { + let process = current_process; + log::error!("Failed to find own process {process:?} in system process table"); + // TODO: Fire an error telemetry event + return; + }; + + let memory_event = ClickhouseEvent::Memory { + memory_in_bytes: process.memory(), + virtual_memory_in_bytes: process.virtual_memory(), + }; + + let cpu_event = ClickhouseEvent::Cpu { + usage_as_percentage: process.cpu_usage(), + core_count: system.cpus().len() as u32, + }; + + let telemetry_settings = if let Ok(telemetry_settings) = + cx.update(|cx| *TelemetrySettings::get_global(cx)) + { + telemetry_settings + } else { + break; + }; + + this.report_clickhouse_event(memory_event, telemetry_settings); + this.report_clickhouse_event(cpu_event, telemetry_settings); + } + }) + .detach(); + } + + pub fn set_authenticated_user_info( + self: &Arc, + metrics_id: Option, + is_staff: bool, + cx: &AppContext, + ) { + if !TelemetrySettings::get_global(cx).metrics { + return; + } + + let mut state = self.state.lock(); + let metrics_id: Option> = metrics_id.map(|id| id.into()); + state.metrics_id = metrics_id.clone(); + state.is_staff = Some(is_staff); + drop(state); + } + + pub fn report_clickhouse_event( + self: &Arc, + event: ClickhouseEvent, + telemetry_settings: TelemetrySettings, + ) { + if !telemetry_settings.metrics { + return; + } + + let mut state = self.state.lock(); + let signed_in = state.metrics_id.is_some(); + state + .clickhouse_events_queue + .push(ClickhouseEventWrapper { signed_in, event }); + + if state.installation_id.is_some() { + if state.clickhouse_events_queue.len() >= MAX_QUEUE_LEN { + drop(state); + 
self.flush_clickhouse_events(); + } else { + let this = self.clone(); + let executor = self.executor.clone(); + state.flush_clickhouse_events_task = Some(self.executor.spawn(async move { + executor.timer(DEBOUNCE_INTERVAL).await; + this.flush_clickhouse_events(); + })); + } + } + } + + pub fn metrics_id(self: &Arc) -> Option> { + self.state.lock().metrics_id.clone() + } + + pub fn installation_id(self: &Arc) -> Option> { + self.state.lock().installation_id.clone() + } + + pub fn is_staff(self: &Arc) -> Option { + self.state.lock().is_staff + } + + fn flush_clickhouse_events(self: &Arc) { + let mut state = self.state.lock(); + let mut events = mem::take(&mut state.clickhouse_events_queue); + state.flush_clickhouse_events_task.take(); + drop(state); + + let this = self.clone(); + self.executor + .spawn( + async move { + let mut json_bytes = Vec::new(); + + if let Some(file) = &mut this.state.lock().log_file { + let file = file.as_file_mut(); + for event in &mut events { + json_bytes.clear(); + serde_json::to_writer(&mut json_bytes, event)?; + file.write_all(&json_bytes)?; + file.write(b"\n")?; + } + } + + { + let state = this.state.lock(); + let request_body = ClickhouseEventRequestBody { + token: ZED_SECRET_CLIENT_TOKEN, + installation_id: state.installation_id.clone(), + session_id: state.session_id.clone(), + is_staff: state.is_staff.clone(), + app_version: state + .app_metadata + .app_version + .map(|version| version.to_string()), + os_name: state.app_metadata.os_name, + os_version: state + .app_metadata + .os_version + .map(|version| version.to_string()), + architecture: state.architecture, + + release_channel: state.release_channel, + events, + }; + json_bytes.clear(); + serde_json::to_writer(&mut json_bytes, &request_body)?; + } + + this.http_client + .post_json(CLICKHOUSE_EVENTS_URL.as_str(), json_bytes.into()) + .await?; + anyhow::Ok(()) + } + .log_err(), + ) + .detach(); + } +} diff --git a/crates/client2/src/test.rs b/crates/client2/src/test.rs new file 
mode 100644 index 0000000000000000000000000000000000000000..f30547dcfc67b000ec73b1ddb0d122cdbb50160c --- /dev/null +++ b/crates/client2/src/test.rs @@ -0,0 +1,216 @@ +use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore}; +use anyhow::{anyhow, Result}; +use futures::{stream::BoxStream, StreamExt}; +use gpui2::{Context, Executor, Model, TestAppContext}; +use parking_lot::Mutex; +use rpc2::{ + proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse}, + ConnectionId, Peer, Receipt, TypedEnvelope, +}; +use std::sync::Arc; +use util::http::FakeHttpClient; + +pub struct FakeServer { + peer: Arc, + state: Arc>, + user_id: u64, + executor: Executor, +} + +#[derive(Default)] +struct FakeServerState { + incoming: Option>>, + connection_id: Option, + forbid_connections: bool, + auth_count: usize, + access_token: usize, +} + +impl FakeServer { + pub async fn for_client( + client_user_id: u64, + client: &Arc, + cx: &TestAppContext, + ) -> Self { + let server = Self { + peer: Peer::new(0), + state: Default::default(), + user_id: client_user_id, + executor: cx.executor().clone(), + }; + + client + .override_authenticate({ + let state = Arc::downgrade(&server.state); + move |cx| { + let state = state.clone(); + cx.spawn(move |_| async move { + let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?; + let mut state = state.lock(); + state.auth_count += 1; + let access_token = state.access_token.to_string(); + Ok(Credentials { + user_id: client_user_id, + access_token, + }) + }) + } + }) + .override_establish_connection({ + let peer = Arc::downgrade(&server.peer); + let state = Arc::downgrade(&server.state); + move |credentials, cx| { + let peer = peer.clone(); + let state = state.clone(); + let credentials = credentials.clone(); + cx.spawn(move |cx| async move { + let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?; + let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?; + if 
state.lock().forbid_connections { + Err(EstablishConnectionError::Other(anyhow!( + "server is forbidding connections" + )))? + } + + assert_eq!(credentials.user_id, client_user_id); + + if credentials.access_token != state.lock().access_token.to_string() { + Err(EstablishConnectionError::Unauthorized)? + } + + let (client_conn, server_conn, _) = + Connection::in_memory(cx.executor().clone()); + let (connection_id, io, incoming) = + peer.add_test_connection(server_conn, cx.executor().clone()); + cx.executor().spawn(io).detach(); + { + let mut state = state.lock(); + state.connection_id = Some(connection_id); + state.incoming = Some(incoming); + } + peer.send( + connection_id, + proto::Hello { + peer_id: Some(connection_id.into()), + }, + ) + .unwrap(); + + Ok(client_conn) + }) + } + }); + + client + .authenticate_and_connect(false, &cx.to_async()) + .await + .unwrap(); + + server + } + + pub fn disconnect(&self) { + if self.state.lock().connection_id.is_some() { + self.peer.disconnect(self.connection_id()); + let mut state = self.state.lock(); + state.connection_id.take(); + state.incoming.take(); + } + } + + pub fn auth_count(&self) -> usize { + self.state.lock().auth_count + } + + pub fn roll_access_token(&self) { + self.state.lock().access_token += 1; + } + + pub fn forbid_connections(&self) { + self.state.lock().forbid_connections = true; + } + + pub fn allow_connections(&self) { + self.state.lock().forbid_connections = false; + } + + pub fn send(&self, message: T) { + self.peer.send(self.connection_id(), message).unwrap(); + } + + #[allow(clippy::await_holding_lock)] + pub async fn receive(&self) -> Result> { + self.executor.start_waiting(); + + loop { + let message = self + .state + .lock() + .incoming + .as_mut() + .expect("not connected") + .next() + .await + .ok_or_else(|| anyhow!("other half hung up"))?; + self.executor.finish_waiting(); + let type_name = message.payload_type_name(); + let message = message.into_any(); + + if message.is::>() { + return 
Ok(*message.downcast().unwrap()); + } + + if message.is::>() { + self.respond( + message + .downcast::>() + .unwrap() + .receipt(), + GetPrivateUserInfoResponse { + metrics_id: "the-metrics-id".into(), + staff: false, + flags: Default::default(), + }, + ); + continue; + } + + panic!( + "fake server received unexpected message type: {:?}", + type_name + ); + } + } + + pub fn respond(&self, receipt: Receipt, response: T::Response) { + self.peer.respond(receipt, response).unwrap() + } + + fn connection_id(&self) -> ConnectionId { + self.state.lock().connection_id.expect("not connected") + } + + pub async fn build_user_store( + &self, + client: Arc, + cx: &mut TestAppContext, + ) -> Model { + let http_client = FakeHttpClient::with_404_response(); + let user_store = cx.build_model(|cx| UserStore::new(client, http_client, cx)); + assert_eq!( + self.receive::() + .await + .unwrap() + .payload + .user_ids, + &[self.user_id] + ); + user_store + } +} + +impl Drop for FakeServer { + fn drop(&mut self) { + self.disconnect(); + } +} diff --git a/crates/client2/src/user.rs b/crates/client2/src/user.rs new file mode 100644 index 0000000000000000000000000000000000000000..2a8cf34af4aea28ee00d23cb448a5d860a0c682d --- /dev/null +++ b/crates/client2/src/user.rs @@ -0,0 +1,739 @@ +use super::{proto, Client, Status, TypedEnvelope}; +use anyhow::{anyhow, Context, Result}; +use collections::{hash_map::Entry, HashMap, HashSet}; +use feature_flags2::FeatureFlagAppExt; +use futures::{channel::mpsc, future, AsyncReadExt, Future, StreamExt}; +use gpui2::{AsyncAppContext, EventEmitter, ImageData, Model, ModelContext, Task}; +use postage::{sink::Sink, watch}; +use rpc2::proto::{RequestMessage, UsersResponse}; +use std::sync::{Arc, Weak}; +use text::ReplicaId; +use util::http::HttpClient; +use util::TryFutureExt as _; + +pub type UserId = u64; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ParticipantIndex(pub u32); + +#[derive(Default, Debug)] +pub struct User { + pub id: UserId, + 
pub github_login: String, + pub avatar: Option>, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Collaborator { + pub peer_id: proto::PeerId, + pub replica_id: ReplicaId, + pub user_id: UserId, +} + +impl PartialOrd for User { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for User { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.github_login.cmp(&other.github_login) + } +} + +impl PartialEq for User { + fn eq(&self, other: &Self) -> bool { + self.id == other.id && self.github_login == other.github_login + } +} + +impl Eq for User {} + +#[derive(Debug, PartialEq)] +pub struct Contact { + pub user: Arc, + pub online: bool, + pub busy: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ContactRequestStatus { + None, + RequestSent, + RequestReceived, + RequestAccepted, +} + +pub struct UserStore { + users: HashMap>, + participant_indices: HashMap, + update_contacts_tx: mpsc::UnboundedSender, + current_user: watch::Receiver>>, + contacts: Vec>, + incoming_contact_requests: Vec>, + outgoing_contact_requests: Vec>, + pending_contact_requests: HashMap, + invite_info: Option, + client: Weak, + http: Arc, + _maintain_contacts: Task<()>, + _maintain_current_user: Task>, +} + +#[derive(Clone)] +pub struct InviteInfo { + pub count: u32, + pub url: Arc, +} + +pub enum Event { + Contact { + user: Arc, + kind: ContactEventKind, + }, + ShowContacts, + ParticipantIndicesChanged, +} + +#[derive(Clone, Copy)] +pub enum ContactEventKind { + Requested, + Accepted, + Cancelled, +} + +impl EventEmitter for UserStore { + type Event = Event; +} + +enum UpdateContacts { + Update(proto::UpdateContacts), + Wait(postage::barrier::Sender), + Clear(postage::barrier::Sender), +} + +impl UserStore { + pub fn new( + client: Arc, + http: Arc, + cx: &mut ModelContext, + ) -> Self { + let (mut current_user_tx, current_user_rx) = watch::channel(); + let (update_contacts_tx, mut update_contacts_rx) = mpsc::unbounded(); 
+ let rpc_subscriptions = vec![ + client.add_message_handler(cx.weak_model(), Self::handle_update_contacts), + client.add_message_handler(cx.weak_model(), Self::handle_update_invite_info), + client.add_message_handler(cx.weak_model(), Self::handle_show_contacts), + ]; + Self { + users: Default::default(), + current_user: current_user_rx, + contacts: Default::default(), + incoming_contact_requests: Default::default(), + participant_indices: Default::default(), + outgoing_contact_requests: Default::default(), + invite_info: None, + client: Arc::downgrade(&client), + update_contacts_tx, + http, + _maintain_contacts: cx.spawn(|this, mut cx| async move { + let _subscriptions = rpc_subscriptions; + while let Some(message) = update_contacts_rx.next().await { + if let Ok(task) = + this.update(&mut cx, |this, cx| this.update_contacts(message, cx)) + { + task.log_err().await; + } else { + break; + } + } + }), + _maintain_current_user: cx.spawn(|this, mut cx| async move { + let mut status = client.status(); + while let Some(status) = status.next().await { + match status { + Status::Connected { .. } => { + if let Some(user_id) = client.user_id() { + let fetch_user = if let Ok(fetch_user) = this + .update(&mut cx, |this, cx| { + this.get_user(user_id, cx).log_err() + }) { + fetch_user + } else { + break; + }; + let fetch_metrics_id = + client.request(proto::GetPrivateUserInfo {}).log_err(); + let (user, info) = futures::join!(fetch_user, fetch_metrics_id); + + cx.update(|cx| { + if let Some(info) = info { + cx.update_flags(info.staff, info.flags); + client.telemetry.set_authenticated_user_info( + Some(info.metrics_id.clone()), + info.staff, + cx, + ) + } + })?; + + current_user_tx.send(user).await.ok(); + + this.update(&mut cx, |_, cx| cx.notify())?; + } + } + Status::SignedOut => { + current_user_tx.send(None).await.ok(); + this.update(&mut cx, |this, cx| { + cx.notify(); + this.clear_contacts() + })? 
+ .await; + } + Status::ConnectionLost => { + this.update(&mut cx, |this, cx| { + cx.notify(); + this.clear_contacts() + })? + .await; + } + _ => {} + } + } + Ok(()) + }), + pending_contact_requests: Default::default(), + } + } + + #[cfg(feature = "test-support")] + pub fn clear_cache(&mut self) { + self.users.clear(); + } + + async fn handle_update_invite_info( + this: Model, + message: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |this, cx| { + this.invite_info = Some(InviteInfo { + url: Arc::from(message.payload.url), + count: message.payload.count, + }); + cx.notify(); + })?; + Ok(()) + } + + async fn handle_show_contacts( + this: Model, + _: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |_, cx| cx.emit(Event::ShowContacts))?; + Ok(()) + } + + pub fn invite_info(&self) -> Option<&InviteInfo> { + self.invite_info.as_ref() + } + + async fn handle_update_contacts( + this: Model, + message: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |this, _| { + this.update_contacts_tx + .unbounded_send(UpdateContacts::Update(message.payload)) + .unwrap(); + })?; + Ok(()) + } + + fn update_contacts( + &mut self, + message: UpdateContacts, + cx: &mut ModelContext, + ) -> Task> { + match message { + UpdateContacts::Wait(barrier) => { + drop(barrier); + Task::ready(Ok(())) + } + UpdateContacts::Clear(barrier) => { + self.contacts.clear(); + self.incoming_contact_requests.clear(); + self.outgoing_contact_requests.clear(); + drop(barrier); + Task::ready(Ok(())) + } + UpdateContacts::Update(message) => { + let mut user_ids = HashSet::default(); + for contact in &message.contacts { + user_ids.insert(contact.user_id); + } + user_ids.extend(message.incoming_requests.iter().map(|req| req.requester_id)); + user_ids.extend(message.outgoing_requests.iter()); + + let load_users = self.get_users(user_ids.into_iter().collect(), cx); + 
cx.spawn(|this, mut cx| async move { + load_users.await?; + + // Users are fetched in parallel above and cached in call to get_users + // No need to paralellize here + let mut updated_contacts = Vec::new(); + let this = this + .upgrade() + .ok_or_else(|| anyhow!("can't upgrade user store handle"))?; + for contact in message.contacts { + let should_notify = contact.should_notify; + updated_contacts.push(( + Arc::new(Contact::from_proto(contact, &this, &mut cx).await?), + should_notify, + )); + } + + let mut incoming_requests = Vec::new(); + for request in message.incoming_requests { + incoming_requests.push({ + let user = this + .update(&mut cx, |this, cx| { + this.get_user(request.requester_id, cx) + })? + .await?; + (user, request.should_notify) + }); + } + + let mut outgoing_requests = Vec::new(); + for requested_user_id in message.outgoing_requests { + outgoing_requests.push( + this.update(&mut cx, |this, cx| this.get_user(requested_user_id, cx))? + .await?, + ); + } + + let removed_contacts = + HashSet::::from_iter(message.remove_contacts.iter().copied()); + let removed_incoming_requests = + HashSet::::from_iter(message.remove_incoming_requests.iter().copied()); + let removed_outgoing_requests = + HashSet::::from_iter(message.remove_outgoing_requests.iter().copied()); + + this.update(&mut cx, |this, cx| { + // Remove contacts + this.contacts + .retain(|contact| !removed_contacts.contains(&contact.user.id)); + // Update existing contacts and insert new ones + for (updated_contact, should_notify) in updated_contacts { + if should_notify { + cx.emit(Event::Contact { + user: updated_contact.user.clone(), + kind: ContactEventKind::Accepted, + }); + } + match this.contacts.binary_search_by_key( + &&updated_contact.user.github_login, + |contact| &contact.user.github_login, + ) { + Ok(ix) => this.contacts[ix] = updated_contact, + Err(ix) => this.contacts.insert(ix, updated_contact), + } + } + + // Remove incoming contact requests + 
this.incoming_contact_requests.retain(|user| { + if removed_incoming_requests.contains(&user.id) { + cx.emit(Event::Contact { + user: user.clone(), + kind: ContactEventKind::Cancelled, + }); + false + } else { + true + } + }); + // Update existing incoming requests and insert new ones + for (user, should_notify) in incoming_requests { + if should_notify { + cx.emit(Event::Contact { + user: user.clone(), + kind: ContactEventKind::Requested, + }); + } + + match this + .incoming_contact_requests + .binary_search_by_key(&&user.github_login, |contact| { + &contact.github_login + }) { + Ok(ix) => this.incoming_contact_requests[ix] = user, + Err(ix) => this.incoming_contact_requests.insert(ix, user), + } + } + + // Remove outgoing contact requests + this.outgoing_contact_requests + .retain(|user| !removed_outgoing_requests.contains(&user.id)); + // Update existing incoming requests and insert new ones + for request in outgoing_requests { + match this + .outgoing_contact_requests + .binary_search_by_key(&&request.github_login, |contact| { + &contact.github_login + }) { + Ok(ix) => this.outgoing_contact_requests[ix] = request, + Err(ix) => this.outgoing_contact_requests.insert(ix, request), + } + } + + cx.notify(); + })?; + + Ok(()) + }) + } + } + } + + pub fn contacts(&self) -> &[Arc] { + &self.contacts + } + + pub fn has_contact(&self, user: &Arc) -> bool { + self.contacts + .binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login) + .is_ok() + } + + pub fn incoming_contact_requests(&self) -> &[Arc] { + &self.incoming_contact_requests + } + + pub fn outgoing_contact_requests(&self) -> &[Arc] { + &self.outgoing_contact_requests + } + + pub fn is_contact_request_pending(&self, user: &User) -> bool { + self.pending_contact_requests.contains_key(&user.id) + } + + pub fn contact_request_status(&self, user: &User) -> ContactRequestStatus { + if self + .contacts + .binary_search_by_key(&&user.github_login, |contact| &contact.user.github_login) + .is_ok() + 
{ + ContactRequestStatus::RequestAccepted + } else if self + .outgoing_contact_requests + .binary_search_by_key(&&user.github_login, |user| &user.github_login) + .is_ok() + { + ContactRequestStatus::RequestSent + } else if self + .incoming_contact_requests + .binary_search_by_key(&&user.github_login, |user| &user.github_login) + .is_ok() + { + ContactRequestStatus::RequestReceived + } else { + ContactRequestStatus::None + } + } + + pub fn request_contact( + &mut self, + responder_id: u64, + cx: &mut ModelContext, + ) -> Task> { + self.perform_contact_request(responder_id, proto::RequestContact { responder_id }, cx) + } + + pub fn remove_contact( + &mut self, + user_id: u64, + cx: &mut ModelContext, + ) -> Task> { + self.perform_contact_request(user_id, proto::RemoveContact { user_id }, cx) + } + + pub fn respond_to_contact_request( + &mut self, + requester_id: u64, + accept: bool, + cx: &mut ModelContext, + ) -> Task> { + self.perform_contact_request( + requester_id, + proto::RespondToContactRequest { + requester_id, + response: if accept { + proto::ContactRequestResponse::Accept + } else { + proto::ContactRequestResponse::Decline + } as i32, + }, + cx, + ) + } + + pub fn dismiss_contact_request( + &mut self, + requester_id: u64, + cx: &mut ModelContext, + ) -> Task> { + let client = self.client.upgrade(); + cx.spawn(move |_, _| async move { + client + .ok_or_else(|| anyhow!("can't upgrade client reference"))? + .request(proto::RespondToContactRequest { + requester_id, + response: proto::ContactRequestResponse::Dismiss as i32, + }) + .await?; + Ok(()) + }) + } + + fn perform_contact_request( + &mut self, + user_id: u64, + request: T, + cx: &mut ModelContext, + ) -> Task> { + let client = self.client.upgrade(); + *self.pending_contact_requests.entry(user_id).or_insert(0) += 1; + cx.notify(); + + cx.spawn(move |this, mut cx| async move { + let response = client + .ok_or_else(|| anyhow!("can't upgrade client reference"))? 
+ .request(request) + .await; + this.update(&mut cx, |this, cx| { + if let Entry::Occupied(mut request_count) = + this.pending_contact_requests.entry(user_id) + { + *request_count.get_mut() -= 1; + if *request_count.get() == 0 { + request_count.remove(); + } + } + cx.notify(); + })?; + response?; + Ok(()) + }) + } + + pub fn clear_contacts(&mut self) -> impl Future { + let (tx, mut rx) = postage::barrier::channel(); + self.update_contacts_tx + .unbounded_send(UpdateContacts::Clear(tx)) + .unwrap(); + async move { + rx.next().await; + } + } + + pub fn contact_updates_done(&mut self) -> impl Future { + let (tx, mut rx) = postage::barrier::channel(); + self.update_contacts_tx + .unbounded_send(UpdateContacts::Wait(tx)) + .unwrap(); + async move { + rx.next().await; + } + } + + pub fn get_users( + &mut self, + user_ids: Vec, + cx: &mut ModelContext, + ) -> Task>>> { + let mut user_ids_to_fetch = user_ids.clone(); + user_ids_to_fetch.retain(|id| !self.users.contains_key(id)); + + cx.spawn(|this, mut cx| async move { + if !user_ids_to_fetch.is_empty() { + this.update(&mut cx, |this, cx| { + this.load_users( + proto::GetUsers { + user_ids: user_ids_to_fetch, + }, + cx, + ) + })? + .await?; + } + + this.update(&mut cx, |this, _| { + user_ids + .iter() + .map(|user_id| { + this.users + .get(user_id) + .cloned() + .ok_or_else(|| anyhow!("user {} not found", user_id)) + }) + .collect() + })? 
+ }) + } + + pub fn fuzzy_search_users( + &mut self, + query: String, + cx: &mut ModelContext, + ) -> Task>>> { + self.load_users(proto::FuzzySearchUsers { query }, cx) + } + + pub fn get_cached_user(&self, user_id: u64) -> Option> { + self.users.get(&user_id).cloned() + } + + pub fn get_user( + &mut self, + user_id: u64, + cx: &mut ModelContext, + ) -> Task>> { + if let Some(user) = self.users.get(&user_id).cloned() { + return Task::ready(Ok(user)); + } + + let load_users = self.get_users(vec![user_id], cx); + cx.spawn(move |this, mut cx| async move { + load_users.await?; + this.update(&mut cx, |this, _| { + this.users + .get(&user_id) + .cloned() + .ok_or_else(|| anyhow!("server responded with no users")) + })? + }) + } + + pub fn current_user(&self) -> Option> { + self.current_user.borrow().clone() + } + + pub fn watch_current_user(&self) -> watch::Receiver>> { + self.current_user.clone() + } + + fn load_users( + &mut self, + request: impl RequestMessage, + cx: &mut ModelContext, + ) -> Task>>> { + let client = self.client.clone(); + let http = self.http.clone(); + cx.spawn(|this, mut cx| async move { + if let Some(rpc) = client.upgrade() { + let response = rpc.request(request).await.context("error loading users")?; + let users = future::join_all( + response + .users + .into_iter() + .map(|user| User::new(user, http.as_ref())), + ) + .await; + + this.update(&mut cx, |this, _| { + for user in &users { + this.users.insert(user.id, user.clone()); + } + }) + .ok(); + + Ok(users) + } else { + Ok(Vec::new()) + } + }) + } + + pub fn set_participant_indices( + &mut self, + participant_indices: HashMap, + cx: &mut ModelContext, + ) { + if participant_indices != self.participant_indices { + self.participant_indices = participant_indices; + cx.emit(Event::ParticipantIndicesChanged); + } + } + + pub fn participant_indices(&self) -> &HashMap { + &self.participant_indices + } +} + +impl User { + async fn new(message: proto::User, http: &dyn HttpClient) -> Arc { + 
Arc::new(User { + id: message.id, + github_login: message.github_login, + avatar: fetch_avatar(http, &message.avatar_url).warn_on_err().await, + }) + } +} + +impl Contact { + async fn from_proto( + contact: proto::Contact, + user_store: &Model, + cx: &mut AsyncAppContext, + ) -> Result { + let user = user_store + .update(cx, |user_store, cx| { + user_store.get_user(contact.user_id, cx) + })? + .await?; + Ok(Self { + user, + online: contact.online, + busy: contact.busy, + }) + } +} + +impl Collaborator { + pub fn from_proto(message: proto::Collaborator) -> Result { + Ok(Self { + peer_id: message.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?, + replica_id: message.replica_id as ReplicaId, + user_id: message.user_id as UserId, + }) + } +} + +// todo!("we probably don't need this now that we fetch") +async fn fetch_avatar(http: &dyn HttpClient, url: &str) -> Result> { + let mut response = http + .get(url, Default::default(), true) + .await + .map_err(|e| anyhow!("failed to send user avatar request: {}", e))?; + + if !response.status().is_success() { + return Err(anyhow!("avatar request failed {:?}", response.status())); + } + + let mut body = Vec::new(); + response + .body_mut() + .read_to_end(&mut body) + .await + .map_err(|e| anyhow!("failed to read user avatar response body: {}", e))?; + let format = image::guess_format(&body)?; + let image = image::load_from_memory_with_format(&body, format)?.into_bgra8(); + Ok(Arc::new(ImageData::new(image))) +} diff --git a/crates/copilot2/Cargo.toml b/crates/copilot2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..f83824d80824bae7144802067e34923f009312ed --- /dev/null +++ b/crates/copilot2/Cargo.toml @@ -0,0 +1,50 @@ +[package] +name = "copilot2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/copilot2.rs" +doctest = false + +[features] +test-support = [ + "collections/test-support", + "gpui2/test-support", + "language2/test-support", + "lsp2/test-support", + 
"settings2/test-support", + "util/test-support", +] + +[dependencies] +collections = { path = "../collections" } +context_menu = { path = "../context_menu" } +gpui2 = { path = "../gpui2" } +language2 = { path = "../language2" } +settings2 = { path = "../settings2" } +theme = { path = "../theme" } +lsp2 = { path = "../lsp2" } +node_runtime = { path = "../node_runtime"} +util = { path = "../util" } +async-compression = { version = "0.3", features = ["gzip", "futures-bufread"] } +async-tar = "0.4.2" +anyhow.workspace = true +log.workspace = true +serde.workspace = true +serde_derive.workspace = true +smol.workspace = true +futures.workspace = true +parking_lot.workspace = true + +[dev-dependencies] +clock = { path = "../clock" } +collections = { path = "../collections", features = ["test-support"] } +fs = { path = "../fs", features = ["test-support"] } +gpui2 = { path = "../gpui2", features = ["test-support"] } +language2 = { path = "../language2", features = ["test-support"] } +lsp2 = { path = "../lsp2", features = ["test-support"] } +rpc = { path = "../rpc", features = ["test-support"] } +settings2 = { path = "../settings2", features = ["test-support"] } +util = { path = "../util", features = ["test-support"] } diff --git a/crates/copilot2/src/copilot2.rs b/crates/copilot2/src/copilot2.rs new file mode 100644 index 0000000000000000000000000000000000000000..083c491656579efe1b9b0cb1df96c0052b8bcd74 --- /dev/null +++ b/crates/copilot2/src/copilot2.rs @@ -0,0 +1,1234 @@ +pub mod request; +mod sign_in; + +use anyhow::{anyhow, Context as _, Result}; +use async_compression::futures::bufread::GzipDecoder; +use async_tar::Archive; +use collections::{HashMap, HashSet}; +use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt}; +use gpui2::{ + AppContext, AsyncAppContext, Context, Entity, EntityId, EventEmitter, Model, ModelContext, + Task, WeakModel, +}; +use language2::{ + language_settings::{all_language_settings, language_settings}, + 
point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, + LanguageServerName, PointUtf16, ToPointUtf16, +}; +use lsp2::{LanguageServer, LanguageServerBinary, LanguageServerId}; +use node_runtime::NodeRuntime; +use parking_lot::Mutex; +use request::StatusNotification; +use settings2::SettingsStore; +use smol::{fs, io::BufReader, stream::StreamExt}; +use std::{ + ffi::OsString, + mem, + ops::Range, + path::{Path, PathBuf}, + sync::Arc, +}; +use util::{ + fs::remove_matching, github::latest_github_release, http::HttpClient, paths, ResultExt, +}; + +// todo!() +// const COPILOT_AUTH_NAMESPACE: &'static str = "copilot_auth"; +// actions!(copilot_auth, [SignIn, SignOut]); + +// todo!() +// const COPILOT_NAMESPACE: &'static str = "copilot"; +// actions!( +// copilot, +// [Suggest, NextSuggestion, PreviousSuggestion, Reinstall] +// ); + +pub fn init( + new_server_id: LanguageServerId, + http: Arc, + node_runtime: Arc, + cx: &mut AppContext, +) { + let copilot = cx.build_model({ + let node_runtime = node_runtime.clone(); + move |cx| Copilot::start(new_server_id, http, node_runtime, cx) + }); + cx.set_global(copilot.clone()); + + // TODO + // cx.observe(&copilot, |handle, cx| { + // let status = handle.read(cx).status(); + // cx.update_default_global::(move |filter, _cx| { + // match status { + // Status::Disabled => { + // filter.filtered_namespaces.insert(COPILOT_NAMESPACE); + // filter.filtered_namespaces.insert(COPILOT_AUTH_NAMESPACE); + // } + // Status::Authorized => { + // filter.filtered_namespaces.remove(COPILOT_NAMESPACE); + // filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE); + // } + // _ => { + // filter.filtered_namespaces.insert(COPILOT_NAMESPACE); + // filter.filtered_namespaces.remove(COPILOT_AUTH_NAMESPACE); + // } + // } + // }); + // }) + // .detach(); + + // sign_in::init(cx); + // cx.add_global_action(|_: &SignIn, cx| { + // if let Some(copilot) = Copilot::global(cx) { + // copilot + // .update(cx, |copilot, cx| 
copilot.sign_in(cx)) + // .detach_and_log_err(cx); + // } + // }); + // cx.add_global_action(|_: &SignOut, cx| { + // if let Some(copilot) = Copilot::global(cx) { + // copilot + // .update(cx, |copilot, cx| copilot.sign_out(cx)) + // .detach_and_log_err(cx); + // } + // }); + + // cx.add_global_action(|_: &Reinstall, cx| { + // if let Some(copilot) = Copilot::global(cx) { + // copilot + // .update(cx, |copilot, cx| copilot.reinstall(cx)) + // .detach(); + // } + // }); +} + +enum CopilotServer { + Disabled, + Starting { task: Shared> }, + Error(Arc), + Running(RunningCopilotServer), +} + +impl CopilotServer { + fn as_authenticated(&mut self) -> Result<&mut RunningCopilotServer> { + let server = self.as_running()?; + if matches!(server.sign_in_status, SignInStatus::Authorized { .. }) { + Ok(server) + } else { + Err(anyhow!("must sign in before using copilot")) + } + } + + fn as_running(&mut self) -> Result<&mut RunningCopilotServer> { + match self { + CopilotServer::Starting { .. } => Err(anyhow!("copilot is still starting")), + CopilotServer::Disabled => Err(anyhow!("copilot is disabled")), + CopilotServer::Error(error) => Err(anyhow!( + "copilot was not started because of an error: {}", + error + )), + CopilotServer::Running(server) => Ok(server), + } + } +} + +struct RunningCopilotServer { + name: LanguageServerName, + lsp: Arc, + sign_in_status: SignInStatus, + registered_buffers: HashMap, +} + +#[derive(Clone, Debug)] +enum SignInStatus { + Authorized, + Unauthorized, + SigningIn { + prompt: Option, + task: Shared>>>, + }, + SignedOut, +} + +#[derive(Debug, Clone)] +pub enum Status { + Starting { + task: Shared>, + }, + Error(Arc), + Disabled, + SignedOut, + SigningIn { + prompt: Option, + }, + Unauthorized, + Authorized, +} + +impl Status { + pub fn is_authorized(&self) -> bool { + matches!(self, Status::Authorized) + } +} + +struct RegisteredBuffer { + uri: lsp2::Url, + language_id: String, + snapshot: BufferSnapshot, + snapshot_version: i32, + 
_subscriptions: [gpui2::Subscription; 2], + pending_buffer_change: Task>, +} + +impl RegisteredBuffer { + fn report_changes( + &mut self, + buffer: &Model, + cx: &mut ModelContext, + ) -> oneshot::Receiver<(i32, BufferSnapshot)> { + let (done_tx, done_rx) = oneshot::channel(); + + if buffer.read(cx).version() == self.snapshot.version { + let _ = done_tx.send((self.snapshot_version, self.snapshot.clone())); + } else { + let buffer = buffer.downgrade(); + let id = buffer.entity_id(); + let prev_pending_change = + mem::replace(&mut self.pending_buffer_change, Task::ready(None)); + self.pending_buffer_change = cx.spawn(move |copilot, mut cx| async move { + prev_pending_change.await; + + let old_version = copilot + .update(&mut cx, |copilot, _| { + let server = copilot.server.as_authenticated().log_err()?; + let buffer = server.registered_buffers.get_mut(&id)?; + Some(buffer.snapshot.version.clone()) + }) + .ok()??; + let new_snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot()).ok()?; + + let content_changes = cx + .executor() + .spawn({ + let new_snapshot = new_snapshot.clone(); + async move { + new_snapshot + .edits_since::<(PointUtf16, usize)>(&old_version) + .map(|edit| { + let edit_start = edit.new.start.0; + let edit_end = edit_start + (edit.old.end.0 - edit.old.start.0); + let new_text = new_snapshot + .text_for_range(edit.new.start.1..edit.new.end.1) + .collect(); + lsp2::TextDocumentContentChangeEvent { + range: Some(lsp2::Range::new( + point_to_lsp(edit_start), + point_to_lsp(edit_end), + )), + range_length: None, + text: new_text, + } + }) + .collect::>() + } + }) + .await; + + copilot + .update(&mut cx, |copilot, _| { + let server = copilot.server.as_authenticated().log_err()?; + let buffer = server.registered_buffers.get_mut(&id)?; + if !content_changes.is_empty() { + buffer.snapshot_version += 1; + buffer.snapshot = new_snapshot; + server + .lsp + .notify::( + lsp2::DidChangeTextDocumentParams { + text_document: 
lsp2::VersionedTextDocumentIdentifier::new( + buffer.uri.clone(), + buffer.snapshot_version, + ), + content_changes, + }, + ) + .log_err(); + } + let _ = done_tx.send((buffer.snapshot_version, buffer.snapshot.clone())); + Some(()) + }) + .ok()?; + + Some(()) + }); + } + + done_rx + } +} + +#[derive(Debug)] +pub struct Completion { + pub uuid: String, + pub range: Range, + pub text: String, +} + +pub struct Copilot { + http: Arc, + node_runtime: Arc, + server: CopilotServer, + buffers: HashSet>, + server_id: LanguageServerId, + _subscription: gpui2::Subscription, +} + +pub enum Event { + CopilotLanguageServerStarted, +} + +impl EventEmitter for Copilot { + type Event = Event; +} + +impl Copilot { + pub fn global(cx: &AppContext) -> Option> { + if cx.has_global::>() { + Some(cx.global::>().clone()) + } else { + None + } + } + + fn start( + new_server_id: LanguageServerId, + http: Arc, + node_runtime: Arc, + cx: &mut ModelContext, + ) -> Self { + let mut this = Self { + server_id: new_server_id, + http, + node_runtime, + server: CopilotServer::Disabled, + buffers: Default::default(), + _subscription: cx.on_app_quit(Self::shutdown_language_server), + }; + this.enable_or_disable_copilot(cx); + cx.observe_global::(move |this, cx| this.enable_or_disable_copilot(cx)) + .detach(); + this + } + + fn shutdown_language_server( + &mut self, + _cx: &mut ModelContext, + ) -> impl Future { + let shutdown = match mem::replace(&mut self.server, CopilotServer::Disabled) { + CopilotServer::Running(server) => Some(Box::pin(async move { server.lsp.shutdown() })), + _ => None, + }; + + async move { + if let Some(shutdown) = shutdown { + shutdown.await; + } + } + } + + fn enable_or_disable_copilot(&mut self, cx: &mut ModelContext) { + let server_id = self.server_id; + let http = self.http.clone(); + let node_runtime = self.node_runtime.clone(); + if all_language_settings(None, cx).copilot_enabled(None, None) { + if matches!(self.server, CopilotServer::Disabled) { + let start_task = cx + 
.spawn(move |this, cx| { + Self::start_language_server(server_id, http, node_runtime, this, cx) + }) + .shared(); + self.server = CopilotServer::Starting { task: start_task }; + cx.notify(); + } + } else { + self.server = CopilotServer::Disabled; + cx.notify(); + } + } + + // #[cfg(any(test, feature = "test-support"))] + // pub fn fake(cx: &mut gpui::TestAppContext) -> (ModelHandle, lsp::FakeLanguageServer) { + // use node_runtime::FakeNodeRuntime; + + // let (server, fake_server) = + // LanguageServer::fake("copilot".into(), Default::default(), cx.to_async()); + // let http = util::http::FakeHttpClient::create(|_| async { unreachable!() }); + // let node_runtime = FakeNodeRuntime::new(); + // let this = cx.add_model(|_| Self { + // server_id: LanguageServerId(0), + // http: http.clone(), + // node_runtime, + // server: CopilotServer::Running(RunningCopilotServer { + // name: LanguageServerName(Arc::from("copilot")), + // lsp: Arc::new(server), + // sign_in_status: SignInStatus::Authorized, + // registered_buffers: Default::default(), + // }), + // buffers: Default::default(), + // }); + // (this, fake_server) + // } + + fn start_language_server( + new_server_id: LanguageServerId, + http: Arc, + node_runtime: Arc, + this: WeakModel, + mut cx: AsyncAppContext, + ) -> impl Future { + async move { + let start_language_server = async { + let server_path = get_copilot_lsp(http).await?; + let node_path = node_runtime.binary_path().await?; + let arguments: Vec = vec![server_path.into(), "--stdio".into()]; + let binary = LanguageServerBinary { + path: node_path, + arguments, + }; + + let server = LanguageServer::new( + Arc::new(Mutex::new(None)), + new_server_id, + binary, + Path::new("/"), + None, + cx.clone(), + )?; + + server + .on_notification::( + |_, _| { /* Silence the notification */ }, + ) + .detach(); + + let server = server.initialize(Default::default()).await?; + + let status = server + .request::(request::CheckStatusParams { + local_checks_only: false, + }) + 
.await?; + + server + .request::(request::SetEditorInfoParams { + editor_info: request::EditorInfo { + name: "zed".into(), + version: env!("CARGO_PKG_VERSION").into(), + }, + editor_plugin_info: request::EditorPluginInfo { + name: "zed-copilot".into(), + version: "0.0.1".into(), + }, + }) + .await?; + + anyhow::Ok((server, status)) + }; + + let server = start_language_server.await; + this.update(&mut cx, |this, cx| { + cx.notify(); + match server { + Ok((server, status)) => { + this.server = CopilotServer::Running(RunningCopilotServer { + name: LanguageServerName(Arc::from("copilot")), + lsp: server, + sign_in_status: SignInStatus::SignedOut, + registered_buffers: Default::default(), + }); + cx.emit(Event::CopilotLanguageServerStarted); + this.update_sign_in_status(status, cx); + } + Err(error) => { + this.server = CopilotServer::Error(error.to_string().into()); + cx.notify() + } + } + }) + .ok(); + } + } + + pub fn sign_in(&mut self, cx: &mut ModelContext) -> Task> { + if let CopilotServer::Running(server) = &mut self.server { + let task = match &server.sign_in_status { + SignInStatus::Authorized { .. } => Task::ready(Ok(())).shared(), + SignInStatus::SigningIn { task, .. } => { + cx.notify(); + task.clone() + } + SignInStatus::SignedOut | SignInStatus::Unauthorized { .. } => { + let lsp = server.lsp.clone(); + let task = cx + .spawn(|this, mut cx| async move { + let sign_in = async { + let sign_in = lsp + .request::( + request::SignInInitiateParams {}, + ) + .await?; + match sign_in { + request::SignInInitiateResult::AlreadySignedIn { user } => { + Ok(request::SignInStatus::Ok { user }) + } + request::SignInInitiateResult::PromptUserDeviceFlow(flow) => { + this.update(&mut cx, |this, cx| { + if let CopilotServer::Running(RunningCopilotServer { + sign_in_status: status, + .. + }) = &mut this.server + { + if let SignInStatus::SigningIn { + prompt: prompt_flow, + .. 
+ } = status + { + *prompt_flow = Some(flow.clone()); + cx.notify(); + } + } + })?; + let response = lsp + .request::( + request::SignInConfirmParams { + user_code: flow.user_code, + }, + ) + .await?; + Ok(response) + } + } + }; + + let sign_in = sign_in.await; + this.update(&mut cx, |this, cx| match sign_in { + Ok(status) => { + this.update_sign_in_status(status, cx); + Ok(()) + } + Err(error) => { + this.update_sign_in_status( + request::SignInStatus::NotSignedIn, + cx, + ); + Err(Arc::new(error)) + } + })? + }) + .shared(); + server.sign_in_status = SignInStatus::SigningIn { + prompt: None, + task: task.clone(), + }; + cx.notify(); + task + } + }; + + cx.executor() + .spawn(task.map_err(|err| anyhow!("{:?}", err))) + } else { + // If we're downloading, wait until download is finished + // If we're in a stuck state, display to the user + Task::ready(Err(anyhow!("copilot hasn't started yet"))) + } + } + + #[allow(dead_code)] // todo!() + fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { + self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx); + if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. 
}) = &self.server { + let server = server.clone(); + cx.executor().spawn(async move { + server + .request::(request::SignOutParams {}) + .await?; + anyhow::Ok(()) + }) + } else { + Task::ready(Err(anyhow!("copilot hasn't started yet"))) + } + } + + pub fn reinstall(&mut self, cx: &mut ModelContext) -> Task<()> { + let start_task = cx + .spawn({ + let http = self.http.clone(); + let node_runtime = self.node_runtime.clone(); + let server_id = self.server_id; + move |this, cx| async move { + clear_copilot_dir().await; + Self::start_language_server(server_id, http, node_runtime, this, cx).await + } + }) + .shared(); + + self.server = CopilotServer::Starting { + task: start_task.clone(), + }; + + cx.notify(); + + cx.executor().spawn(start_task) + } + + pub fn language_server(&self) -> Option<(&LanguageServerName, &Arc)> { + if let CopilotServer::Running(server) = &self.server { + Some((&server.name, &server.lsp)) + } else { + None + } + } + + pub fn register_buffer(&mut self, buffer: &Model, cx: &mut ModelContext) { + let weak_buffer = buffer.downgrade(); + self.buffers.insert(weak_buffer.clone()); + + if let CopilotServer::Running(RunningCopilotServer { + lsp: server, + sign_in_status: status, + registered_buffers, + .. + }) = &mut self.server + { + if !matches!(status, SignInStatus::Authorized { .. 
}) { + return; + } + + registered_buffers + .entry(buffer.entity_id()) + .or_insert_with(|| { + let uri: lsp2::Url = uri_for_buffer(buffer, cx); + let language_id = id_for_language(buffer.read(cx).language()); + let snapshot = buffer.read(cx).snapshot(); + server + .notify::( + lsp2::DidOpenTextDocumentParams { + text_document: lsp2::TextDocumentItem { + uri: uri.clone(), + language_id: language_id.clone(), + version: 0, + text: snapshot.text(), + }, + }, + ) + .log_err(); + + RegisteredBuffer { + uri, + language_id, + snapshot, + snapshot_version: 0, + pending_buffer_change: Task::ready(Some(())), + _subscriptions: [ + cx.subscribe(buffer, |this, buffer, event, cx| { + this.handle_buffer_event(buffer, event, cx).log_err(); + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + this.buffers.remove(&weak_buffer); + this.unregister_buffer(&weak_buffer); + }), + ], + } + }); + } + } + + fn handle_buffer_event( + &mut self, + buffer: Model, + event: &language2::Event, + cx: &mut ModelContext, + ) -> Result<()> { + if let Ok(server) = self.server.as_running() { + if let Some(registered_buffer) = server.registered_buffers.get_mut(&buffer.entity_id()) + { + match event { + language2::Event::Edited => { + let _ = registered_buffer.report_changes(&buffer, cx); + } + language2::Event::Saved => { + server + .lsp + .notify::( + lsp2::DidSaveTextDocumentParams { + text_document: lsp2::TextDocumentIdentifier::new( + registered_buffer.uri.clone(), + ), + text: None, + }, + )?; + } + language2::Event::FileHandleChanged | language2::Event::LanguageChanged => { + let new_language_id = id_for_language(buffer.read(cx).language()); + let new_uri = uri_for_buffer(&buffer, cx); + if new_uri != registered_buffer.uri + || new_language_id != registered_buffer.language_id + { + let old_uri = mem::replace(&mut registered_buffer.uri, new_uri); + registered_buffer.language_id = new_language_id; + server + .lsp + .notify::( + lsp2::DidCloseTextDocumentParams { + text_document: 
lsp2::TextDocumentIdentifier::new(old_uri), + }, + )?; + server + .lsp + .notify::( + lsp2::DidOpenTextDocumentParams { + text_document: lsp2::TextDocumentItem::new( + registered_buffer.uri.clone(), + registered_buffer.language_id.clone(), + registered_buffer.snapshot_version, + registered_buffer.snapshot.text(), + ), + }, + )?; + } + } + _ => {} + } + } + } + + Ok(()) + } + + fn unregister_buffer(&mut self, buffer: &WeakModel) { + if let Ok(server) = self.server.as_running() { + if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) { + server + .lsp + .notify::( + lsp2::DidCloseTextDocumentParams { + text_document: lsp2::TextDocumentIdentifier::new(buffer.uri), + }, + ) + .log_err(); + } + } + } + + pub fn completions( + &mut self, + buffer: &Model, + position: T, + cx: &mut ModelContext, + ) -> Task>> + where + T: ToPointUtf16, + { + self.request_completions::(buffer, position, cx) + } + + pub fn completions_cycling( + &mut self, + buffer: &Model, + position: T, + cx: &mut ModelContext, + ) -> Task>> + where + T: ToPointUtf16, + { + self.request_completions::(buffer, position, cx) + } + + pub fn accept_completion( + &mut self, + completion: &Completion, + cx: &mut ModelContext, + ) -> Task> { + let server = match self.server.as_authenticated() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), + }; + let request = + server + .lsp + .request::(request::NotifyAcceptedParams { + uuid: completion.uuid.clone(), + }); + cx.executor().spawn(async move { + request.await?; + Ok(()) + }) + } + + pub fn discard_completions( + &mut self, + completions: &[Completion], + cx: &mut ModelContext, + ) -> Task> { + let server = match self.server.as_authenticated() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), + }; + let request = + server + .lsp + .request::(request::NotifyRejectedParams { + uuids: completions + .iter() + .map(|completion| completion.uuid.clone()) + .collect(), + }); + 
cx.executor().spawn(async move { + request.await?; + Ok(()) + }) + } + + fn request_completions( + &mut self, + buffer: &Model, + position: T, + cx: &mut ModelContext, + ) -> Task>> + where + R: 'static + + lsp2::request::Request< + Params = request::GetCompletionsParams, + Result = request::GetCompletionsResult, + >, + T: ToPointUtf16, + { + self.register_buffer(buffer, cx); + + let server = match self.server.as_authenticated() { + Ok(server) => server, + Err(error) => return Task::ready(Err(error)), + }; + let lsp = server.lsp.clone(); + let registered_buffer = server + .registered_buffers + .get_mut(&buffer.entity_id()) + .unwrap(); + let snapshot = registered_buffer.report_changes(buffer, cx); + let buffer = buffer.read(cx); + let uri = registered_buffer.uri.clone(); + let position = position.to_point_utf16(buffer); + let settings = language_settings(buffer.language_at(position).as_ref(), buffer.file(), cx); + let tab_size = settings.tab_size; + let hard_tabs = settings.hard_tabs; + let relative_path = buffer + .file() + .map(|file| file.path().to_path_buf()) + .unwrap_or_default(); + + cx.executor().spawn(async move { + let (version, snapshot) = snapshot.await?; + let result = lsp + .request::(request::GetCompletionsParams { + doc: request::GetCompletionsDocument { + uri, + tab_size: tab_size.into(), + indent_size: 1, + insert_spaces: !hard_tabs, + relative_path: relative_path.to_string_lossy().into(), + position: point_to_lsp(position), + version: version.try_into().unwrap(), + }, + }) + .await?; + let completions = result + .completions + .into_iter() + .map(|completion| { + let start = snapshot + .clip_point_utf16(point_from_lsp(completion.range.start), Bias::Left); + let end = + snapshot.clip_point_utf16(point_from_lsp(completion.range.end), Bias::Left); + Completion { + uuid: completion.uuid, + range: snapshot.anchor_before(start)..snapshot.anchor_after(end), + text: completion.text, + } + }) + .collect(); + anyhow::Ok(completions) + }) + } + + pub fn 
status(&self) -> Status { + match &self.server { + CopilotServer::Starting { task } => Status::Starting { task: task.clone() }, + CopilotServer::Disabled => Status::Disabled, + CopilotServer::Error(error) => Status::Error(error.clone()), + CopilotServer::Running(RunningCopilotServer { sign_in_status, .. }) => { + match sign_in_status { + SignInStatus::Authorized { .. } => Status::Authorized, + SignInStatus::Unauthorized { .. } => Status::Unauthorized, + SignInStatus::SigningIn { prompt, .. } => Status::SigningIn { + prompt: prompt.clone(), + }, + SignInStatus::SignedOut => Status::SignedOut, + } + } + } + } + + fn update_sign_in_status( + &mut self, + lsp_status: request::SignInStatus, + cx: &mut ModelContext, + ) { + self.buffers.retain(|buffer| buffer.is_upgradable()); + + if let Ok(server) = self.server.as_running() { + match lsp_status { + request::SignInStatus::Ok { .. } + | request::SignInStatus::MaybeOk { .. } + | request::SignInStatus::AlreadySignedIn { .. } => { + server.sign_in_status = SignInStatus::Authorized; + for buffer in self.buffers.iter().cloned().collect::>() { + if let Some(buffer) = buffer.upgrade() { + self.register_buffer(&buffer, cx); + } + } + } + request::SignInStatus::NotAuthorized { .. 
} => { + server.sign_in_status = SignInStatus::Unauthorized; + for buffer in self.buffers.iter().cloned().collect::>() { + self.unregister_buffer(&buffer); + } + } + request::SignInStatus::NotSignedIn => { + server.sign_in_status = SignInStatus::SignedOut; + for buffer in self.buffers.iter().cloned().collect::>() { + self.unregister_buffer(&buffer); + } + } + } + + cx.notify(); + } + } +} + +fn id_for_language(language: Option<&Arc>) -> String { + let language_name = language.map(|language| language.name()); + match language_name.as_deref() { + Some("Plain Text") => "plaintext".to_string(), + Some(language_name) => language_name.to_lowercase(), + None => "plaintext".to_string(), + } +} + +fn uri_for_buffer(buffer: &Model, cx: &AppContext) -> lsp2::Url { + if let Some(file) = buffer.read(cx).file().and_then(|file| file.as_local()) { + lsp2::Url::from_file_path(file.abs_path(cx)).unwrap() + } else { + format!("buffer://{}", buffer.entity_id()).parse().unwrap() + } +} + +async fn clear_copilot_dir() { + remove_matching(&paths::COPILOT_DIR, |_| true).await +} + +async fn get_copilot_lsp(http: Arc) -> anyhow::Result { + const SERVER_PATH: &'static str = "dist/agent.js"; + + ///Check for the latest copilot language server and download it if we haven't already + async fn fetch_latest(http: Arc) -> anyhow::Result { + let release = latest_github_release("zed-industries/copilot", false, http.clone()).await?; + + let version_dir = &*paths::COPILOT_DIR.join(format!("copilot-{}", release.name)); + + fs::create_dir_all(version_dir).await?; + let server_path = version_dir.join(SERVER_PATH); + + if fs::metadata(&server_path).await.is_err() { + // Copilot LSP looks for this dist dir specifcially, so lets add it in. + let dist_dir = version_dir.join("dist"); + fs::create_dir_all(dist_dir.as_path()).await?; + + let url = &release + .assets + .get(0) + .context("Github release for copilot contained no assets")? 
+ .browser_download_url; + + let mut response = http + .get(&url, Default::default(), true) + .await + .map_err(|err| anyhow!("error downloading copilot release: {}", err))?; + let decompressed_bytes = GzipDecoder::new(BufReader::new(response.body_mut())); + let archive = Archive::new(decompressed_bytes); + archive.unpack(dist_dir).await?; + + remove_matching(&paths::COPILOT_DIR, |entry| entry != version_dir).await; + } + + Ok(server_path) + } + + match fetch_latest(http).await { + ok @ Result::Ok(..) => ok, + e @ Err(..) => { + e.log_err(); + // Fetch a cached binary, if it exists + (|| async move { + let mut last_version_dir = None; + let mut entries = fs::read_dir(paths::COPILOT_DIR.as_path()).await?; + while let Some(entry) = entries.next().await { + let entry = entry?; + if entry.file_type().await?.is_dir() { + last_version_dir = Some(entry.path()); + } + } + let last_version_dir = + last_version_dir.ok_or_else(|| anyhow!("no cached binary"))?; + let server_path = last_version_dir.join(SERVER_PATH); + if server_path.exists() { + Ok(server_path) + } else { + Err(anyhow!( + "missing executable in directory {:?}", + last_version_dir + )) + } + })() + .await + } + } +} + +// #[cfg(test)] +// mod tests { +// use super::*; +// use gpui::{executor::Deterministic, TestAppContext}; + +// #[gpui::test(iterations = 10)] +// async fn test_buffer_management(deterministic: Arc, cx: &mut TestAppContext) { +// deterministic.forbid_parking(); +// let (copilot, mut lsp) = Copilot::fake(cx); + +// let buffer_1 = cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "Hello")); +// let buffer_1_uri: lsp::Url = format!("buffer://{}", buffer_1.id()).parse().unwrap(); +// copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_1, cx)); +// assert_eq!( +// lsp.receive_notification::() +// .await, +// lsp::DidOpenTextDocumentParams { +// text_document: lsp::TextDocumentItem::new( +// buffer_1_uri.clone(), +// "plaintext".into(), +// 0, +// "Hello".into() +// ), +// } +// ); + 
+// let buffer_2 = cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "Goodbye")); +// let buffer_2_uri: lsp::Url = format!("buffer://{}", buffer_2.id()).parse().unwrap(); +// copilot.update(cx, |copilot, cx| copilot.register_buffer(&buffer_2, cx)); +// assert_eq!( +// lsp.receive_notification::() +// .await, +// lsp::DidOpenTextDocumentParams { +// text_document: lsp::TextDocumentItem::new( +// buffer_2_uri.clone(), +// "plaintext".into(), +// 0, +// "Goodbye".into() +// ), +// } +// ); + +// buffer_1.update(cx, |buffer, cx| buffer.edit([(5..5, " world")], None, cx)); +// assert_eq!( +// lsp.receive_notification::() +// .await, +// lsp::DidChangeTextDocumentParams { +// text_document: lsp::VersionedTextDocumentIdentifier::new(buffer_1_uri.clone(), 1), +// content_changes: vec![lsp::TextDocumentContentChangeEvent { +// range: Some(lsp::Range::new( +// lsp::Position::new(0, 5), +// lsp::Position::new(0, 5) +// )), +// range_length: None, +// text: " world".into(), +// }], +// } +// ); + +// // Ensure updates to the file are reflected in the LSP. +// buffer_1 +// .update(cx, |buffer, cx| { +// buffer.file_updated( +// Arc::new(File { +// abs_path: "/root/child/buffer-1".into(), +// path: Path::new("child/buffer-1").into(), +// }), +// cx, +// ) +// }) +// .await; +// assert_eq!( +// lsp.receive_notification::() +// .await, +// lsp::DidCloseTextDocumentParams { +// text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri), +// } +// ); +// let buffer_1_uri = lsp::Url::from_file_path("/root/child/buffer-1").unwrap(); +// assert_eq!( +// lsp.receive_notification::() +// .await, +// lsp::DidOpenTextDocumentParams { +// text_document: lsp::TextDocumentItem::new( +// buffer_1_uri.clone(), +// "plaintext".into(), +// 1, +// "Hello world".into() +// ), +// } +// ); + +// // Ensure all previously-registered buffers are closed when signing out. 
+// lsp.handle_request::(|_, _| async { +// Ok(request::SignOutResult {}) +// }); +// copilot +// .update(cx, |copilot, cx| copilot.sign_out(cx)) +// .await +// .unwrap(); +// assert_eq!( +// lsp.receive_notification::() +// .await, +// lsp::DidCloseTextDocumentParams { +// text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri.clone()), +// } +// ); +// assert_eq!( +// lsp.receive_notification::() +// .await, +// lsp::DidCloseTextDocumentParams { +// text_document: lsp::TextDocumentIdentifier::new(buffer_1_uri.clone()), +// } +// ); + +// // Ensure all previously-registered buffers are re-opened when signing in. +// lsp.handle_request::(|_, _| async { +// Ok(request::SignInInitiateResult::AlreadySignedIn { +// user: "user-1".into(), +// }) +// }); +// copilot +// .update(cx, |copilot, cx| copilot.sign_in(cx)) +// .await +// .unwrap(); +// assert_eq!( +// lsp.receive_notification::() +// .await, +// lsp::DidOpenTextDocumentParams { +// text_document: lsp::TextDocumentItem::new( +// buffer_2_uri.clone(), +// "plaintext".into(), +// 0, +// "Goodbye".into() +// ), +// } +// ); +// assert_eq!( +// lsp.receive_notification::() +// .await, +// lsp::DidOpenTextDocumentParams { +// text_document: lsp::TextDocumentItem::new( +// buffer_1_uri.clone(), +// "plaintext".into(), +// 0, +// "Hello world".into() +// ), +// } +// ); + +// // Dropping a buffer causes it to be closed on the LSP side as well. 
+// cx.update(|_| drop(buffer_2)); +// assert_eq!( +// lsp.receive_notification::() +// .await, +// lsp::DidCloseTextDocumentParams { +// text_document: lsp::TextDocumentIdentifier::new(buffer_2_uri), +// } +// ); +// } + +// struct File { +// abs_path: PathBuf, +// path: Arc, +// } + +// impl language2::File for File { +// fn as_local(&self) -> Option<&dyn language2::LocalFile> { +// Some(self) +// } + +// fn mtime(&self) -> std::time::SystemTime { +// unimplemented!() +// } + +// fn path(&self) -> &Arc { +// &self.path +// } + +// fn full_path(&self, _: &AppContext) -> PathBuf { +// unimplemented!() +// } + +// fn file_name<'a>(&'a self, _: &'a AppContext) -> &'a std::ffi::OsStr { +// unimplemented!() +// } + +// fn is_deleted(&self) -> bool { +// unimplemented!() +// } + +// fn as_any(&self) -> &dyn std::any::Any { +// unimplemented!() +// } + +// fn to_proto(&self) -> rpc::proto::File { +// unimplemented!() +// } + +// fn worktree_id(&self) -> usize { +// 0 +// } +// } + +// impl language::LocalFile for File { +// fn abs_path(&self, _: &AppContext) -> PathBuf { +// self.abs_path.clone() +// } + +// fn load(&self, _: &AppContext) -> Task> { +// unimplemented!() +// } + +// fn buffer_reloaded( +// &self, +// _: u64, +// _: &clock::Global, +// _: language::RopeFingerprint, +// _: language::LineEnding, +// _: std::time::SystemTime, +// _: &mut AppContext, +// ) { +// unimplemented!() +// } +// } +// } diff --git a/crates/copilot2/src/request.rs b/crates/copilot2/src/request.rs new file mode 100644 index 0000000000000000000000000000000000000000..fee92051dcad54d5328a8e55e104f76f71f8f452 --- /dev/null +++ b/crates/copilot2/src/request.rs @@ -0,0 +1,225 @@ +use serde::{Deserialize, Serialize}; + +pub enum CheckStatus {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CheckStatusParams { + pub local_checks_only: bool, +} + +impl lsp2::request::Request for CheckStatus { + type Params = CheckStatusParams; + type Result = 
SignInStatus; + const METHOD: &'static str = "checkStatus"; +} + +pub enum SignInInitiate {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SignInInitiateParams {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "status")] +pub enum SignInInitiateResult { + AlreadySignedIn { user: String }, + PromptUserDeviceFlow(PromptUserDeviceFlow), +} + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptUserDeviceFlow { + pub user_code: String, + pub verification_uri: String, +} + +impl lsp2::request::Request for SignInInitiate { + type Params = SignInInitiateParams; + type Result = SignInInitiateResult; + const METHOD: &'static str = "signInInitiate"; +} + +pub enum SignInConfirm {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SignInConfirmParams { + pub user_code: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "status")] +pub enum SignInStatus { + #[serde(rename = "OK")] + Ok { + user: String, + }, + MaybeOk { + user: String, + }, + AlreadySignedIn { + user: String, + }, + NotAuthorized { + user: String, + }, + NotSignedIn, +} + +impl lsp2::request::Request for SignInConfirm { + type Params = SignInConfirmParams; + type Result = SignInStatus; + const METHOD: &'static str = "signInConfirm"; +} + +pub enum SignOut {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SignOutParams {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SignOutResult {} + +impl lsp2::request::Request for SignOut { + type Params = SignOutParams; + type Result = SignOutResult; + const METHOD: &'static str = "signOut"; +} + +pub enum GetCompletions {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetCompletionsParams { + pub doc: GetCompletionsDocument, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = 
"camelCase")] +pub struct GetCompletionsDocument { + pub tab_size: u32, + pub indent_size: u32, + pub insert_spaces: bool, + pub uri: lsp2::Url, + pub relative_path: String, + pub position: lsp2::Position, + pub version: usize, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetCompletionsResult { + pub completions: Vec, +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Completion { + pub text: String, + pub position: lsp2::Position, + pub uuid: String, + pub range: lsp2::Range, + pub display_text: String, +} + +impl lsp2::request::Request for GetCompletions { + type Params = GetCompletionsParams; + type Result = GetCompletionsResult; + const METHOD: &'static str = "getCompletions"; +} + +pub enum GetCompletionsCycling {} + +impl lsp2::request::Request for GetCompletionsCycling { + type Params = GetCompletionsParams; + type Result = GetCompletionsResult; + const METHOD: &'static str = "getCompletionsCycling"; +} + +pub enum LogMessage {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LogMessageParams { + pub level: u8, + pub message: String, + pub metadata_str: String, + pub extra: Vec, +} + +impl lsp2::notification::Notification for LogMessage { + type Params = LogMessageParams; + const METHOD: &'static str = "LogMessage"; +} + +pub enum StatusNotification {} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StatusNotificationParams { + pub message: String, + pub status: String, // One of Normal/InProgress +} + +impl lsp2::notification::Notification for StatusNotification { + type Params = StatusNotificationParams; + const METHOD: &'static str = "statusNotification"; +} + +pub enum SetEditorInfo {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SetEditorInfoParams { + pub editor_info: EditorInfo, + pub editor_plugin_info: EditorPluginInfo, +} + +impl 
lsp2::request::Request for SetEditorInfo { + type Params = SetEditorInfoParams; + type Result = String; + const METHOD: &'static str = "setEditorInfo"; +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EditorInfo { + pub name: String, + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EditorPluginInfo { + pub name: String, + pub version: String, +} + +pub enum NotifyAccepted {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NotifyAcceptedParams { + pub uuid: String, +} + +impl lsp2::request::Request for NotifyAccepted { + type Params = NotifyAcceptedParams; + type Result = String; + const METHOD: &'static str = "notifyAccepted"; +} + +pub enum NotifyRejected {} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NotifyRejectedParams { + pub uuids: Vec, +} + +impl lsp2::request::Request for NotifyRejected { + type Params = NotifyRejectedParams; + type Result = String; + const METHOD: &'static str = "notifyRejected"; +} diff --git a/crates/copilot2/src/sign_in.rs b/crates/copilot2/src/sign_in.rs new file mode 100644 index 0000000000000000000000000000000000000000..57f248aa52486d8c04672eeef8d33e5acda2a52c --- /dev/null +++ b/crates/copilot2/src/sign_in.rs @@ -0,0 +1,376 @@ +// TODO add logging in +// use crate::{request::PromptUserDeviceFlow, Copilot, Status}; +// use gpui::{ +// elements::*, +// geometry::rect::RectF, +// platform::{WindowBounds, WindowKind, WindowOptions}, +// AnyElement, AnyViewHandle, AppContext, ClipboardItem, Element, Entity, View, ViewContext, +// WindowHandle, +// }; +// use theme::ui::modal; + +// #[derive(PartialEq, Eq, Debug, Clone)] +// struct CopyUserCode; + +// #[derive(PartialEq, Eq, Debug, Clone)] +// struct OpenGithub; + +// const COPILOT_SIGN_UP_URL: &'static str = "https://github.com/features/copilot"; + +// pub fn init(cx: &mut AppContext) 
{ +// if let Some(copilot) = Copilot::global(cx) { +// let mut verification_window: Option> = None; +// cx.observe(&copilot, move |copilot, cx| { +// let status = copilot.read(cx).status(); + +// match &status { +// crate::Status::SigningIn { prompt } => { +// if let Some(window) = verification_window.as_mut() { +// let updated = window +// .root(cx) +// .map(|root| { +// root.update(cx, |verification, cx| { +// verification.set_status(status.clone(), cx); +// cx.activate_window(); +// }) +// }) +// .is_some(); +// if !updated { +// verification_window = Some(create_copilot_auth_window(cx, &status)); +// } +// } else if let Some(_prompt) = prompt { +// verification_window = Some(create_copilot_auth_window(cx, &status)); +// } +// } +// Status::Authorized | Status::Unauthorized => { +// if let Some(window) = verification_window.as_ref() { +// if let Some(verification) = window.root(cx) { +// verification.update(cx, |verification, cx| { +// verification.set_status(status, cx); +// cx.platform().activate(true); +// cx.activate_window(); +// }); +// } +// } +// } +// _ => { +// if let Some(code_verification) = verification_window.take() { +// code_verification.update(cx, |cx| cx.remove_window()); +// } +// } +// } +// }) +// .detach(); +// } +// } + +// fn create_copilot_auth_window( +// cx: &mut AppContext, +// status: &Status, +// ) -> WindowHandle { +// let window_size = theme::current(cx).copilot.modal.dimensions(); +// let window_options = WindowOptions { +// bounds: WindowBounds::Fixed(RectF::new(Default::default(), window_size)), +// titlebar: None, +// center: true, +// focus: true, +// show: true, +// kind: WindowKind::Normal, +// is_movable: true, +// screen: None, +// }; +// cx.add_window(window_options, |_cx| { +// CopilotCodeVerification::new(status.clone()) +// }) +// } + +// pub struct CopilotCodeVerification { +// status: Status, +// connect_clicked: bool, +// } + +// impl CopilotCodeVerification { +// pub fn new(status: Status) -> Self { +// Self { +// 
status, +// connect_clicked: false, +// } +// } + +// pub fn set_status(&mut self, status: Status, cx: &mut ViewContext) { +// self.status = status; +// cx.notify(); +// } + +// fn render_device_code( +// data: &PromptUserDeviceFlow, +// style: &theme::Copilot, +// cx: &mut ViewContext, +// ) -> impl IntoAnyElement { +// let copied = cx +// .read_from_clipboard() +// .map(|item| item.text() == &data.user_code) +// .unwrap_or(false); + +// let device_code_style = &style.auth.prompting.device_code; + +// MouseEventHandler::new::(0, cx, |state, _cx| { +// Flex::row() +// .with_child( +// Label::new(data.user_code.clone(), device_code_style.text.clone()) +// .aligned() +// .contained() +// .with_style(device_code_style.left_container) +// .constrained() +// .with_width(device_code_style.left), +// ) +// .with_child( +// Label::new( +// if copied { "Copied!" } else { "Copy" }, +// device_code_style.cta.style_for(state).text.clone(), +// ) +// .aligned() +// .contained() +// .with_style(*device_code_style.right_container.style_for(state)) +// .constrained() +// .with_width(device_code_style.right), +// ) +// .contained() +// .with_style(device_code_style.cta.style_for(state).container) +// }) +// .on_click(gpui::platform::MouseButton::Left, { +// let user_code = data.user_code.clone(); +// move |_, _, cx| { +// cx.platform() +// .write_to_clipboard(ClipboardItem::new(user_code.clone())); +// cx.notify(); +// } +// }) +// .with_cursor_style(gpui::platform::CursorStyle::PointingHand) +// } + +// fn render_prompting_modal( +// connect_clicked: bool, +// data: &PromptUserDeviceFlow, +// style: &theme::Copilot, +// cx: &mut ViewContext, +// ) -> AnyElement { +// enum ConnectButton {} + +// Flex::column() +// .with_child( +// Flex::column() +// .with_children([ +// Label::new( +// "Enable Copilot by connecting", +// style.auth.prompting.subheading.text.clone(), +// ) +// .aligned(), +// Label::new( +// "your existing license.", +// style.auth.prompting.subheading.text.clone(), 
+// ) +// .aligned(), +// ]) +// .align_children_center() +// .contained() +// .with_style(style.auth.prompting.subheading.container), +// ) +// .with_child(Self::render_device_code(data, &style, cx)) +// .with_child( +// Flex::column() +// .with_children([ +// Label::new( +// "Paste this code into GitHub after", +// style.auth.prompting.hint.text.clone(), +// ) +// .aligned(), +// Label::new( +// "clicking the button below.", +// style.auth.prompting.hint.text.clone(), +// ) +// .aligned(), +// ]) +// .align_children_center() +// .contained() +// .with_style(style.auth.prompting.hint.container.clone()), +// ) +// .with_child(theme::ui::cta_button::( +// if connect_clicked { +// "Waiting for connection..." +// } else { +// "Connect to GitHub" +// }, +// style.auth.content_width, +// &style.auth.cta_button, +// cx, +// { +// let verification_uri = data.verification_uri.clone(); +// move |_, verification, cx| { +// cx.platform().open_url(&verification_uri); +// verification.connect_clicked = true; +// } +// }, +// )) +// .align_children_center() +// .into_any() +// } + +// fn render_enabled_modal( +// style: &theme::Copilot, +// cx: &mut ViewContext, +// ) -> AnyElement { +// enum DoneButton {} + +// let enabled_style = &style.auth.authorized; +// Flex::column() +// .with_child( +// Label::new("Copilot Enabled!", enabled_style.subheading.text.clone()) +// .contained() +// .with_style(enabled_style.subheading.container) +// .aligned(), +// ) +// .with_child( +// Flex::column() +// .with_children([ +// Label::new( +// "You can update your settings or", +// enabled_style.hint.text.clone(), +// ) +// .aligned(), +// Label::new( +// "sign out from the Copilot menu in", +// enabled_style.hint.text.clone(), +// ) +// .aligned(), +// Label::new("the status bar.", enabled_style.hint.text.clone()).aligned(), +// ]) +// .align_children_center() +// .contained() +// .with_style(enabled_style.hint.container), +// ) +// .with_child(theme::ui::cta_button::( +// "Done", +// 
style.auth.content_width, +// &style.auth.cta_button, +// cx, +// |_, _, cx| cx.remove_window(), +// )) +// .align_children_center() +// .into_any() +// } + +// fn render_unauthorized_modal( +// style: &theme::Copilot, +// cx: &mut ViewContext, +// ) -> AnyElement { +// let unauthorized_style = &style.auth.not_authorized; + +// Flex::column() +// .with_child( +// Flex::column() +// .with_children([ +// Label::new( +// "Enable Copilot by connecting", +// unauthorized_style.subheading.text.clone(), +// ) +// .aligned(), +// Label::new( +// "your existing license.", +// unauthorized_style.subheading.text.clone(), +// ) +// .aligned(), +// ]) +// .align_children_center() +// .contained() +// .with_style(unauthorized_style.subheading.container), +// ) +// .with_child( +// Flex::column() +// .with_children([ +// Label::new( +// "You must have an active copilot", +// unauthorized_style.warning.text.clone(), +// ) +// .aligned(), +// Label::new( +// "license to use it in Zed.", +// unauthorized_style.warning.text.clone(), +// ) +// .aligned(), +// ]) +// .align_children_center() +// .contained() +// .with_style(unauthorized_style.warning.container), +// ) +// .with_child(theme::ui::cta_button::( +// "Subscribe on GitHub", +// style.auth.content_width, +// &style.auth.cta_button, +// cx, +// |_, _, cx| { +// cx.remove_window(); +// cx.platform().open_url(COPILOT_SIGN_UP_URL) +// }, +// )) +// .align_children_center() +// .into_any() +// } +// } + +// impl Entity for CopilotCodeVerification { +// type Event = (); +// } + +// impl View for CopilotCodeVerification { +// fn ui_name() -> &'static str { +// "CopilotCodeVerification" +// } + +// fn focus_in(&mut self, _: AnyViewHandle, cx: &mut ViewContext) { +// cx.notify() +// } + +// fn focus_out(&mut self, _: AnyViewHandle, cx: &mut ViewContext) { +// cx.notify() +// } + +// fn render(&mut self, cx: &mut ViewContext) -> AnyElement { +// enum ConnectModal {} + +// let style = theme::current(cx).clone(); + +// modal::( +// 
"Connect Copilot to Zed", +// &style.copilot.modal, +// cx, +// |cx| { +// Flex::column() +// .with_children([ +// theme::ui::icon(&style.copilot.auth.header).into_any(), +// match &self.status { +// Status::SigningIn { +// prompt: Some(prompt), +// } => Self::render_prompting_modal( +// self.connect_clicked, +// &prompt, +// &style.copilot, +// cx, +// ), +// Status::Unauthorized => { +// self.connect_clicked = false; +// Self::render_unauthorized_modal(&style.copilot, cx) +// } +// Status::Authorized => { +// self.connect_clicked = false; +// Self::render_enabled_modal(&style.copilot, cx) +// } +// _ => Empty::new().into_any(), +// }, +// ]) +// .align_children_center() +// }, +// ) +// .into_any() +// } +// } diff --git a/crates/db2/Cargo.toml b/crates/db2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..6ef8ec0874c422b42f2be177e0a47b847b2fe9f6 --- /dev/null +++ b/crates/db2/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "db2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/db2.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +collections = { path = "../collections" } +gpui2 = { path = "../gpui2" } +sqlez = { path = "../sqlez" } +sqlez_macros = { path = "../sqlez_macros" } +util = { path = "../util" } +anyhow.workspace = true +indoc.workspace = true +async-trait.workspace = true +lazy_static.workspace = true +log.workspace = true +parking_lot.workspace = true +serde.workspace = true +serde_derive.workspace = true +smol.workspace = true + +[dev-dependencies] +gpui2 = { path = "../gpui2", features = ["test-support"] } +env_logger.workspace = true +tempdir.workspace = true diff --git a/crates/db2/README.md b/crates/db2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d4ea2fee399edd6842ffd8e48d8d93aa4d7d84d8 --- /dev/null +++ b/crates/db2/README.md @@ -0,0 +1,5 @@ +# Building Queries + +First, craft your test data. 
The examples folder shows a template for building a test-db, and can be run with `cargo run --example [your-example]`. + +To actually use and test your queries, import the generated DB file into https://sqliteonline.com/ \ No newline at end of file diff --git a/crates/db2/src/db2.rs b/crates/db2/src/db2.rs new file mode 100644 index 0000000000000000000000000000000000000000..e2e1ae9eaad72df37a3cc0b6f16035960bb3c772 --- /dev/null +++ b/crates/db2/src/db2.rs @@ -0,0 +1,327 @@ +pub mod kvp; +pub mod query; + +// Re-export +pub use anyhow; +use anyhow::Context; +use gpui2::AppContext; +pub use indoc::indoc; +pub use lazy_static; +pub use smol; +pub use sqlez; +pub use sqlez_macros; +pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME}; +pub use util::paths::DB_DIR; + +use sqlez::domain::Migrator; +use sqlez::thread_safe_connection::ThreadSafeConnection; +use sqlez_macros::sql; +use std::future::Future; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; +use util::channel::ReleaseChannel; +use util::{async_maybe, ResultExt}; + +const CONNECTION_INITIALIZE_QUERY: &'static str = sql!( + PRAGMA foreign_keys=TRUE; +); + +const DB_INITIALIZE_QUERY: &'static str = sql!( + PRAGMA journal_mode=WAL; + PRAGMA busy_timeout=1; + PRAGMA case_sensitive_like=TRUE; + PRAGMA synchronous=NORMAL; +); + +const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB"; + +const DB_FILE_NAME: &'static str = "db.sqlite"; + +lazy_static::lazy_static! { + pub static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()); + pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false); +} + +/// Open or create a database at the given directory path. +/// This will retry a couple times if there are failures. If opening fails once, the db directory +/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
+/// In either case, static variables are set so that the user can be notified. +pub async fn open_db( + db_dir: &Path, + release_channel: &ReleaseChannel, +) -> ThreadSafeConnection { + if *ZED_STATELESS { + return open_fallback_db().await; + } + + let release_channel_name = release_channel.dev_name(); + let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name))); + + let connection = async_maybe!({ + smol::fs::create_dir_all(&main_db_dir) + .await + .context("Could not create db directory") + .log_err()?; + let db_path = main_db_dir.join(Path::new(DB_FILE_NAME)); + open_main_db(&db_path).await + }) + .await; + + if let Some(connection) = connection { + return connection; + } + + // Set another static ref so that we can escalate the notification + ALL_FILE_DB_FAILED.store(true, Ordering::Release); + + // If still failed, create an in memory db with a known name + open_fallback_db().await +} + +async fn open_main_db(db_path: &PathBuf) -> Option> { + log::info!("Opening main db"); + ThreadSafeConnection::::builder(db_path.to_string_lossy().as_ref(), true) + .with_db_initialization_query(DB_INITIALIZE_QUERY) + .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) + .build() + .await + .log_err() +} + +async fn open_fallback_db() -> ThreadSafeConnection { + log::info!("Opening fallback db"); + ThreadSafeConnection::::builder(FALLBACK_DB_NAME, false) + .with_db_initialization_query(DB_INITIALIZE_QUERY) + .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) + .build() + .await + .expect( + "Fallback in memory database failed. 
Likely initialization queries or migrations have fundamental errors", + ) +} + +#[cfg(any(test, feature = "test-support"))] +pub async fn open_test_db(db_name: &str) -> ThreadSafeConnection { + use sqlez::thread_safe_connection::locking_queue; + + ThreadSafeConnection::::builder(db_name, false) + .with_db_initialization_query(DB_INITIALIZE_QUERY) + .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY) + // Serialize queued writes via a mutex and run them synchronously + .with_write_queue_constructor(locking_queue()) + .build() + .await + .unwrap() +} + +/// Implements a basic DB wrapper for a given domain +#[macro_export] +macro_rules! define_connection { + (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => { + pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>); + + impl ::std::ops::Deref for $t { + type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl $crate::sqlez::domain::Domain for $t { + fn name() -> &'static str { + stringify!($t) + } + + fn migrations() -> &'static [&'static str] { + $migrations + } + } + + #[cfg(any(test, feature = "test-support"))] + $crate::lazy_static::lazy_static! { + pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id)))); + } + + #[cfg(not(any(test, feature = "test-support")))] + $crate::lazy_static::lazy_static! 
{ + pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL))); + } + }; + (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => { + pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>); + + impl ::std::ops::Deref for $t { + type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl $crate::sqlez::domain::Domain for $t { + fn name() -> &'static str { + stringify!($t) + } + + fn migrations() -> &'static [&'static str] { + $migrations + } + } + + #[cfg(any(test, feature = "test-support"))] + $crate::lazy_static::lazy_static! { + pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id)))); + } + + #[cfg(not(any(test, feature = "test-support")))] + $crate::lazy_static::lazy_static! { + pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL))); + } + }; +} + +pub fn write_and_log(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static) +where + F: Future> + Send, +{ + cx.executor() + .spawn(async move { db_write().await.log_err() }) + .detach() +} + +// #[cfg(test)] +// mod tests { +// use std::thread; + +// use sqlez::domain::Domain; +// use sqlez_macros::sql; +// use tempdir::TempDir; + +// use crate::open_db; + +// // Test bad migration panics +// #[gpui::test] +// #[should_panic] +// async fn test_bad_migration_panics() { +// enum BadDB {} + +// impl Domain for BadDB { +// fn name() -> &'static str { +// "db_tests" +// } + +// fn migrations() -> &'static [&'static str] { +// &[ +// sql!(CREATE TABLE test(value);), +// // failure because test already exists +// sql!(CREATE TABLE test(value);), +// ] +// } +// } + +// let tempdir = TempDir::new("DbTests").unwrap(); +// let _bad_db = open_db::(tempdir.path(), &util::channel::ReleaseChannel::Dev).await; +// } + +// /// 
Test that DB exists but corrupted (causing recreate) +// #[gpui::test] +// async fn test_db_corruption() { +// enum CorruptedDB {} + +// impl Domain for CorruptedDB { +// fn name() -> &'static str { +// "db_tests" +// } + +// fn migrations() -> &'static [&'static str] { +// &[sql!(CREATE TABLE test(value);)] +// } +// } + +// enum GoodDB {} + +// impl Domain for GoodDB { +// fn name() -> &'static str { +// "db_tests" //Notice same name +// } + +// fn migrations() -> &'static [&'static str] { +// &[sql!(CREATE TABLE test2(value);)] //But different migration +// } +// } + +// let tempdir = TempDir::new("DbTests").unwrap(); +// { +// let corrupt_db = +// open_db::(tempdir.path(), &util::channel::ReleaseChannel::Dev).await; +// assert!(corrupt_db.persistent()); +// } + +// let good_db = open_db::(tempdir.path(), &util::channel::ReleaseChannel::Dev).await; +// assert!( +// good_db.select_row::("SELECT * FROM test2").unwrap()() +// .unwrap() +// .is_none() +// ); +// } + +// /// Test that DB exists but corrupted (causing recreate) +// #[gpui::test(iterations = 30)] +// async fn test_simultaneous_db_corruption() { +// enum CorruptedDB {} + +// impl Domain for CorruptedDB { +// fn name() -> &'static str { +// "db_tests" +// } + +// fn migrations() -> &'static [&'static str] { +// &[sql!(CREATE TABLE test(value);)] +// } +// } + +// enum GoodDB {} + +// impl Domain for GoodDB { +// fn name() -> &'static str { +// "db_tests" //Notice same name +// } + +// fn migrations() -> &'static [&'static str] { +// &[sql!(CREATE TABLE test2(value);)] //But different migration +// } +// } + +// let tempdir = TempDir::new("DbTests").unwrap(); +// { +// // Setup the bad database +// let corrupt_db = +// open_db::(tempdir.path(), &util::channel::ReleaseChannel::Dev).await; +// assert!(corrupt_db.persistent()); +// } + +// // Try to connect to it a bunch of times at once +// let mut guards = vec![]; +// for _ in 0..10 { +// let tmp_path = tempdir.path().to_path_buf(); +// let guard = 
thread::spawn(move || { +// let good_db = smol::block_on(open_db::( +// tmp_path.as_path(), +// &util::channel::ReleaseChannel::Dev, +// )); +// assert!( +// good_db.select_row::("SELECT * FROM test2").unwrap()() +// .unwrap() +// .is_none() +// ); +// }); + +// guards.push(guard); +// } + +// for guard in guards.into_iter() { +// assert!(guard.join().is_ok()); +// } +// } +// } diff --git a/crates/db2/src/kvp.rs b/crates/db2/src/kvp.rs new file mode 100644 index 0000000000000000000000000000000000000000..254d91689d607ad8d9b2b5f844d73f50d7919ea7 --- /dev/null +++ b/crates/db2/src/kvp.rs @@ -0,0 +1,62 @@ +use sqlez_macros::sql; + +use crate::{define_connection, query}; + +define_connection!(pub static ref KEY_VALUE_STORE: KeyValueStore<()> = + &[sql!( + CREATE TABLE IF NOT EXISTS kv_store( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) STRICT; + )]; +); + +impl KeyValueStore { + query! { + pub fn read_kvp(key: &str) -> Result> { + SELECT value FROM kv_store WHERE key = (?) + } + } + + query! { + pub async fn write_kvp(key: String, value: String) -> Result<()> { + INSERT OR REPLACE INTO kv_store(key, value) VALUES ((?), (?)) + } + } + + query! { + pub async fn delete_kvp(key: String) -> Result<()> { + DELETE FROM kv_store WHERE key = (?) 
+ } + } +} + +// #[cfg(test)] +// mod tests { +// use crate::kvp::KeyValueStore; + +// #[gpui::test] +// async fn test_kvp() { +// let db = KeyValueStore(crate::open_test_db("test_kvp").await); + +// assert_eq!(db.read_kvp("key-1").unwrap(), None); + +// db.write_kvp("key-1".to_string(), "one".to_string()) +// .await +// .unwrap(); +// assert_eq!(db.read_kvp("key-1").unwrap(), Some("one".to_string())); + +// db.write_kvp("key-1".to_string(), "one-2".to_string()) +// .await +// .unwrap(); +// assert_eq!(db.read_kvp("key-1").unwrap(), Some("one-2".to_string())); + +// db.write_kvp("key-2".to_string(), "two".to_string()) +// .await +// .unwrap(); +// assert_eq!(db.read_kvp("key-2").unwrap(), Some("two".to_string())); + +// db.delete_kvp("key-1".to_string()).await.unwrap(); +// assert_eq!(db.read_kvp("key-1").unwrap(), None); +// } +// } diff --git a/crates/db2/src/query.rs b/crates/db2/src/query.rs new file mode 100644 index 0000000000000000000000000000000000000000..27d94ade9eefc29a5f1850b7705ad4d6921942ff --- /dev/null +++ b/crates/db2/src/query.rs @@ -0,0 +1,314 @@ +#[macro_export] +macro_rules! 
query { + ($vis:vis fn $id:ident() -> Result<()> { $($sql:tt)+ }) => { + $vis fn $id(&self) -> $crate::anyhow::Result<()> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.exec(sql_stmt)?().context(::std::format!( + "Error in {}, exec failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt, + )) + } + }; + ($vis:vis async fn $id:ident() -> Result<()> { $($sql:tt)+ }) => { + $vis async fn $id(&self) -> $crate::anyhow::Result<()> { + use $crate::anyhow::Context; + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.exec(sql_stmt)?().context(::std::format!( + "Error in {}, exec failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => { + $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, exec_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis async fn $id:ident($arg:ident: $arg_type:ty) -> Result<()> { $($sql:tt)+ }) => { + $vis async fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<()> { + use $crate::anyhow::Context; + + self.write(move |connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.exec_bound::<$arg_type>(sql_stmt)?($arg) + .context(::std::format!( + "Error in {}, exec_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<()> { $($sql:tt)+ }) => { + $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<()> { + use $crate::anyhow::Context; + + self.write(move |connection| { 
+ let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.exec_bound::<($($arg_type),+)>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, exec_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident() -> Result> { $($sql:tt)+ }) => { + $vis fn $id(&self) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select::<$return_type>(sql_stmt)?() + .context(::std::format!( + "Error in {}, select_row failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis async fn $id:ident() -> Result> { $($sql:tt)+ }) => { + pub async fn $id(&self) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select::<$return_type>(sql_stmt)?() + .context(::std::format!( + "Error in {}, select_row failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { + $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, exec_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { + $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, 
exec_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident() -> Result> { $($sql:tt)+ }) => { + $vis fn $id(&self) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row::<$return_type>(sql_stmt)?() + .context(::std::format!( + "Error in {}, select_row failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis async fn $id:ident() -> Result> { $($sql:tt)+ }) => { + $vis async fn $id(&self) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select_row::<$return_type>(sql_stmt)?() + .context(::std::format!( + "Error in {}, select_row failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result> { $($sql:tt)+ }) => { + $vis fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + + } + }; + ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { + $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + + } + }; + ($vis:vis async fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result> { $($sql:tt)+ }) => { + $vis async fn 
$id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result> { + use $crate::anyhow::Context; + + + self.write(move |connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => { + $vis fn $id(&self) -> $crate::anyhow::Result<$return_type> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row::<$return_type>(indoc! { $sql })?() + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + ))? + .context(::std::format!( + "Error in {}, select_row_bound expected single row result but found none for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis async fn $id:ident() -> Result<$return_type:ty> { $($sql:tt)+ }) => { + $vis async fn $id(&self) -> $crate::anyhow::Result<$return_type> { + use $crate::anyhow::Context; + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select_row::<$return_type>(sql_stmt)?() + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + ))? 
+ .context(::std::format!( + "Error in {}, select_row_bound expected single row result but found none for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; + ($vis:vis fn $id:ident($arg:ident: $arg_type:ty) -> Result<$return_type:ty> { $($sql:tt)+ }) => { + pub fn $id(&self, $arg: $arg_type) -> $crate::anyhow::Result<$return_type> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row_bound::<$arg_type, $return_type>(sql_stmt)?($arg) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + ))? + .context(::std::format!( + "Error in {}, select_row_bound expected single row result but found none for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis fn $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => { + $vis fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> { + use $crate::anyhow::Context; + + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + self.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + ))? 
+ .context(::std::format!( + "Error in {}, select_row_bound expected single row result but found none for: {}", + ::std::stringify!($id), + sql_stmt + )) + } + }; + ($vis:vis fn async $id:ident($($arg:ident: $arg_type:ty),+) -> Result<$return_type:ty> { $($sql:tt)+ }) => { + $vis async fn $id(&self, $($arg: $arg_type),+) -> $crate::anyhow::Result<$return_type> { + use $crate::anyhow::Context; + + + self.write(|connection| { + let sql_stmt = $crate::sqlez_macros::sql!($($sql)+); + + connection.select_row_bound::<($($arg_type),+), $return_type>(sql_stmt)?(($($arg),+)) + .context(::std::format!( + "Error in {}, select_row_bound failed to execute or parse for: {}", + ::std::stringify!($id), + sql_stmt + ))? + .context(::std::format!( + "Error in {}, select_row_bound expected single row result but found none for: {}", + ::std::stringify!($id), + sql_stmt + )) + }).await + } + }; +} diff --git a/crates/feature_flags2/Cargo.toml b/crates/feature_flags2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..ad77330ac3f4396f96d8a4f8f34052a46d721663 --- /dev/null +++ b/crates/feature_flags2/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "feature_flags2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/feature_flags2.rs" + +[dependencies] +gpui2 = { path = "../gpui2" } +anyhow.workspace = true diff --git a/crates/feature_flags2/src/feature_flags2.rs b/crates/feature_flags2/src/feature_flags2.rs new file mode 100644 index 0000000000000000000000000000000000000000..7b1c0dd4d71de9d1c867aba541381f4928847530 --- /dev/null +++ b/crates/feature_flags2/src/feature_flags2.rs @@ -0,0 +1,80 @@ +use gpui2::{AppContext, Subscription, ViewContext}; + +#[derive(Default)] +struct FeatureFlags { + flags: Vec, + staff: bool, +} + +impl FeatureFlags { + fn has_flag(&self, flag: &str) -> bool { + self.staff || self.flags.iter().find(|f| f.as_str() == flag).is_some() + } +} + +pub trait FeatureFlag { + const NAME: &'static str; +} + +pub enum 
ChannelsAlpha {} + +impl FeatureFlag for ChannelsAlpha { + const NAME: &'static str = "channels_alpha"; +} + +pub trait FeatureFlagViewExt { + fn observe_flag(&mut self, callback: F) -> Subscription + where + F: Fn(bool, &mut V, &mut ViewContext) + Send + Sync + 'static; +} + +impl FeatureFlagViewExt for ViewContext<'_, '_, V> +where + V: 'static + Send + Sync, +{ + fn observe_flag(&mut self, callback: F) -> Subscription + where + F: Fn(bool, &mut V, &mut ViewContext) + Send + Sync + 'static, + { + self.observe_global::(move |v, cx| { + let feature_flags = cx.global::(); + callback(feature_flags.has_flag(::NAME), v, cx); + }) + } +} + +pub trait FeatureFlagAppExt { + fn update_flags(&mut self, staff: bool, flags: Vec); + fn set_staff(&mut self, staff: bool); + fn has_flag(&self) -> bool; + fn is_staff(&self) -> bool; +} + +impl FeatureFlagAppExt for AppContext { + fn update_flags(&mut self, staff: bool, flags: Vec) { + let feature_flags = self.default_global::(); + feature_flags.staff = staff; + feature_flags.flags = flags; + } + + fn set_staff(&mut self, staff: bool) { + let feature_flags = self.default_global::(); + feature_flags.staff = staff; + } + + fn has_flag(&self) -> bool { + if self.has_global::() { + self.global::().has_flag(T::NAME) + } else { + false + } + } + + fn is_staff(&self) -> bool { + if self.has_global::() { + return self.global::().staff; + } else { + false + } + } +} diff --git a/crates/fs2/Cargo.toml b/crates/fs2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..36f4e9c9c92a6616d1a9956435b10ff09c476c30 --- /dev/null +++ b/crates/fs2/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "fs2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/fs2.rs" + +[dependencies] +collections = { path = "../collections" } +rope = { path = "../rope" } +text = { path = "../text" } +util = { path = "../util" } +sum_tree = { path = "../sum_tree" } + +anyhow.workspace = true +async-trait.workspace = true 
+futures.workspace = true +tempfile = "3" +fsevent = { path = "../fsevent" } +lazy_static.workspace = true +parking_lot.workspace = true +smol.workspace = true +regex.workspace = true +git2.workspace = true +serde.workspace = true +serde_derive.workspace = true +serde_json.workspace = true +log.workspace = true +libc = "0.2" +time.workspace = true + +gpui2 = { path = "../gpui2", optional = true} + +[dev-dependencies] +gpui2 = { path = "../gpui2", features = ["test-support"] } + +[features] +test-support = ["gpui2/test-support"] diff --git a/crates/fs2/src/fs2.rs b/crates/fs2/src/fs2.rs new file mode 100644 index 0000000000000000000000000000000000000000..6ff8676473889301215cd8f26910d34365f49205 --- /dev/null +++ b/crates/fs2/src/fs2.rs @@ -0,0 +1,1278 @@ +pub mod repository; + +use anyhow::{anyhow, Result}; +use fsevent::EventStream; +use futures::{future::BoxFuture, Stream, StreamExt}; +use git2::Repository as LibGitRepository; +use parking_lot::Mutex; +use repository::GitRepository; +use rope::Rope; +use smol::io::{AsyncReadExt, AsyncWriteExt}; +use std::io::Write; +use std::sync::Arc; +use std::{ + io, + os::unix::fs::MetadataExt, + path::{Component, Path, PathBuf}, + pin::Pin, + time::{Duration, SystemTime}, +}; +use tempfile::NamedTempFile; +use text::LineEnding; +use util::ResultExt; + +#[cfg(any(test, feature = "test-support"))] +use collections::{btree_map, BTreeMap}; +#[cfg(any(test, feature = "test-support"))] +use repository::{FakeGitRepositoryState, GitFileStatus}; +#[cfg(any(test, feature = "test-support"))] +use std::ffi::OsStr; + +#[async_trait::async_trait] +pub trait Fs: Send + Sync { + async fn create_dir(&self, path: &Path) -> Result<()>; + async fn create_file(&self, path: &Path, options: CreateOptions) -> Result<()>; + async fn copy_file(&self, source: &Path, target: &Path, options: CopyOptions) -> Result<()>; + async fn rename(&self, source: &Path, target: &Path, options: RenameOptions) -> Result<()>; + async fn remove_dir(&self, path: &Path, 
options: RemoveOptions) -> Result<()>; + async fn remove_file(&self, path: &Path, options: RemoveOptions) -> Result<()>; + async fn open_sync(&self, path: &Path) -> Result>; + async fn load(&self, path: &Path) -> Result; + async fn atomic_write(&self, path: PathBuf, text: String) -> Result<()>; + async fn save(&self, path: &Path, text: &Rope, line_ending: LineEnding) -> Result<()>; + async fn canonicalize(&self, path: &Path) -> Result; + async fn is_file(&self, path: &Path) -> bool; + async fn metadata(&self, path: &Path) -> Result>; + async fn read_link(&self, path: &Path) -> Result; + async fn read_dir( + &self, + path: &Path, + ) -> Result>>>>; + async fn watch( + &self, + path: &Path, + latency: Duration, + ) -> Pin>>>; + fn open_repo(&self, abs_dot_git: &Path) -> Option>>; + fn is_fake(&self) -> bool; + #[cfg(any(test, feature = "test-support"))] + fn as_fake(&self) -> &FakeFs; +} + +#[derive(Copy, Clone, Default)] +pub struct CreateOptions { + pub overwrite: bool, + pub ignore_if_exists: bool, +} + +#[derive(Copy, Clone, Default)] +pub struct CopyOptions { + pub overwrite: bool, + pub ignore_if_exists: bool, +} + +#[derive(Copy, Clone, Default)] +pub struct RenameOptions { + pub overwrite: bool, + pub ignore_if_exists: bool, +} + +#[derive(Copy, Clone, Default)] +pub struct RemoveOptions { + pub recursive: bool, + pub ignore_if_not_exists: bool, +} + +#[derive(Copy, Clone, Debug)] +pub struct Metadata { + pub inode: u64, + pub mtime: SystemTime, + pub is_symlink: bool, + pub is_dir: bool, +} + +pub struct RealFs; + +#[async_trait::async_trait] +impl Fs for RealFs { + async fn create_dir(&self, path: &Path) -> Result<()> { + Ok(smol::fs::create_dir_all(path).await?) 
+ } + + async fn create_file(&self, path: &Path, options: CreateOptions) -> Result<()> { + let mut open_options = smol::fs::OpenOptions::new(); + open_options.write(true).create(true); + if options.overwrite { + open_options.truncate(true); + } else if !options.ignore_if_exists { + open_options.create_new(true); + } + open_options.open(path).await?; + Ok(()) + } + + async fn copy_file(&self, source: &Path, target: &Path, options: CopyOptions) -> Result<()> { + if !options.overwrite && smol::fs::metadata(target).await.is_ok() { + if options.ignore_if_exists { + return Ok(()); + } else { + return Err(anyhow!("{target:?} already exists")); + } + } + + smol::fs::copy(source, target).await?; + Ok(()) + } + + async fn rename(&self, source: &Path, target: &Path, options: RenameOptions) -> Result<()> { + if !options.overwrite && smol::fs::metadata(target).await.is_ok() { + if options.ignore_if_exists { + return Ok(()); + } else { + return Err(anyhow!("{target:?} already exists")); + } + } + + smol::fs::rename(source, target).await?; + Ok(()) + } + + async fn remove_dir(&self, path: &Path, options: RemoveOptions) -> Result<()> { + let result = if options.recursive { + smol::fs::remove_dir_all(path).await + } else { + smol::fs::remove_dir(path).await + }; + match result { + Ok(()) => Ok(()), + Err(err) if err.kind() == io::ErrorKind::NotFound && options.ignore_if_not_exists => { + Ok(()) + } + Err(err) => Err(err)?, + } + } + + async fn remove_file(&self, path: &Path, options: RemoveOptions) -> Result<()> { + match smol::fs::remove_file(path).await { + Ok(()) => Ok(()), + Err(err) if err.kind() == io::ErrorKind::NotFound && options.ignore_if_not_exists => { + Ok(()) + } + Err(err) => Err(err)?, + } + } + + async fn open_sync(&self, path: &Path) -> Result> { + Ok(Box::new(std::fs::File::open(path)?)) + } + + async fn load(&self, path: &Path) -> Result { + let mut file = smol::fs::File::open(path).await?; + let mut text = String::new(); + file.read_to_string(&mut text).await?; 
+ Ok(text) + } + + async fn atomic_write(&self, path: PathBuf, data: String) -> Result<()> { + smol::unblock(move || { + let mut tmp_file = NamedTempFile::new()?; + tmp_file.write_all(data.as_bytes())?; + tmp_file.persist(path)?; + Ok::<(), anyhow::Error>(()) + }) + .await?; + + Ok(()) + } + + async fn save(&self, path: &Path, text: &Rope, line_ending: LineEnding) -> Result<()> { + let buffer_size = text.summary().len.min(10 * 1024); + if let Some(path) = path.parent() { + self.create_dir(path).await?; + } + let file = smol::fs::File::create(path).await?; + let mut writer = smol::io::BufWriter::with_capacity(buffer_size, file); + for chunk in chunks(text, line_ending) { + writer.write_all(chunk.as_bytes()).await?; + } + writer.flush().await?; + Ok(()) + } + + async fn canonicalize(&self, path: &Path) -> Result { + Ok(smol::fs::canonicalize(path).await?) + } + + async fn is_file(&self, path: &Path) -> bool { + smol::fs::metadata(path) + .await + .map_or(false, |metadata| metadata.is_file()) + } + + async fn metadata(&self, path: &Path) -> Result> { + let symlink_metadata = match smol::fs::symlink_metadata(path).await { + Ok(metadata) => metadata, + Err(err) => { + return match (err.kind(), err.raw_os_error()) { + (io::ErrorKind::NotFound, _) => Ok(None), + (io::ErrorKind::Other, Some(libc::ENOTDIR)) => Ok(None), + _ => Err(anyhow::Error::new(err)), + } + } + }; + + let is_symlink = symlink_metadata.file_type().is_symlink(); + let metadata = if is_symlink { + smol::fs::metadata(path).await? 
+ } else { + symlink_metadata + }; + Ok(Some(Metadata { + inode: metadata.ino(), + mtime: metadata.modified().unwrap(), + is_symlink, + is_dir: metadata.file_type().is_dir(), + })) + } + + async fn read_link(&self, path: &Path) -> Result { + let path = smol::fs::read_link(path).await?; + Ok(path) + } + + async fn read_dir( + &self, + path: &Path, + ) -> Result>>>> { + let result = smol::fs::read_dir(path).await?.map(|entry| match entry { + Ok(entry) => Ok(entry.path()), + Err(error) => Err(anyhow!("failed to read dir entry {:?}", error)), + }); + Ok(Box::pin(result)) + } + + async fn watch( + &self, + path: &Path, + latency: Duration, + ) -> Pin>>> { + let (tx, rx) = smol::channel::unbounded(); + let (stream, handle) = EventStream::new(&[path], latency); + std::thread::spawn(move || { + stream.run(move |events| smol::block_on(tx.send(events)).is_ok()); + }); + Box::pin(rx.chain(futures::stream::once(async move { + drop(handle); + vec![] + }))) + } + + fn open_repo(&self, dotgit_path: &Path) -> Option>> { + LibGitRepository::open(&dotgit_path) + .log_err() + .and_then::>, _>(|libgit_repository| { + Some(Arc::new(Mutex::new(libgit_repository))) + }) + } + + fn is_fake(&self) -> bool { + false + } + #[cfg(any(test, feature = "test-support"))] + fn as_fake(&self) -> &FakeFs { + panic!("called `RealFs::as_fake`") + } +} + +#[cfg(any(test, feature = "test-support"))] +pub struct FakeFs { + // Use an unfair lock to ensure tests are deterministic. 
+ state: Mutex, + executor: gpui2::Executor, +} + +#[cfg(any(test, feature = "test-support"))] +struct FakeFsState { + root: Arc>, + next_inode: u64, + next_mtime: SystemTime, + event_txs: Vec>>, + events_paused: bool, + buffered_events: Vec, + metadata_call_count: usize, + read_dir_call_count: usize, +} + +#[cfg(any(test, feature = "test-support"))] +#[derive(Debug)] +enum FakeFsEntry { + File { + inode: u64, + mtime: SystemTime, + content: String, + }, + Dir { + inode: u64, + mtime: SystemTime, + entries: BTreeMap>>, + git_repo_state: Option>>, + }, + Symlink { + target: PathBuf, + }, +} + +#[cfg(any(test, feature = "test-support"))] +impl FakeFsState { + fn read_path<'a>(&'a self, target: &Path) -> Result>> { + Ok(self + .try_read_path(target, true) + .ok_or_else(|| anyhow!("path does not exist: {}", target.display()))? + .0) + } + + fn try_read_path<'a>( + &'a self, + target: &Path, + follow_symlink: bool, + ) -> Option<(Arc>, PathBuf)> { + let mut path = target.to_path_buf(); + let mut canonical_path = PathBuf::new(); + let mut entry_stack = Vec::new(); + 'outer: loop { + let mut path_components = path.components().peekable(); + while let Some(component) = path_components.next() { + match component { + Component::Prefix(_) => panic!("prefix paths aren't supported"), + Component::RootDir => { + entry_stack.clear(); + entry_stack.push(self.root.clone()); + canonical_path.clear(); + canonical_path.push("/"); + } + Component::CurDir => {} + Component::ParentDir => { + entry_stack.pop()?; + canonical_path.pop(); + } + Component::Normal(name) => { + let current_entry = entry_stack.last().cloned()?; + let current_entry = current_entry.lock(); + if let FakeFsEntry::Dir { entries, .. } = &*current_entry { + let entry = entries.get(name.to_str().unwrap()).cloned()?; + if path_components.peek().is_some() || follow_symlink { + let entry = entry.lock(); + if let FakeFsEntry::Symlink { target, .. 
} = &*entry { + let mut target = target.clone(); + target.extend(path_components); + path = target; + continue 'outer; + } + } + entry_stack.push(entry.clone()); + canonical_path.push(name); + } else { + return None; + } + } + } + } + break; + } + Some((entry_stack.pop()?, canonical_path)) + } + + fn write_path(&self, path: &Path, callback: Fn) -> Result + where + Fn: FnOnce(btree_map::Entry>>) -> Result, + { + let path = normalize_path(path); + let filename = path + .file_name() + .ok_or_else(|| anyhow!("cannot overwrite the root"))?; + let parent_path = path.parent().unwrap(); + + let parent = self.read_path(parent_path)?; + let mut parent = parent.lock(); + let new_entry = parent + .dir_entries(parent_path)? + .entry(filename.to_str().unwrap().into()); + callback(new_entry) + } + + fn emit_event(&mut self, paths: I) + where + I: IntoIterator, + T: Into, + { + self.buffered_events + .extend(paths.into_iter().map(|path| fsevent::Event { + event_id: 0, + flags: fsevent::StreamFlags::empty(), + path: path.into(), + })); + + if !self.events_paused { + self.flush_events(self.buffered_events.len()); + } + } + + fn flush_events(&mut self, mut count: usize) { + count = count.min(self.buffered_events.len()); + let events = self.buffered_events.drain(0..count).collect::>(); + self.event_txs.retain(|tx| { + let _ = tx.try_send(events.clone()); + !tx.is_closed() + }); + } +} + +#[cfg(any(test, feature = "test-support"))] +lazy_static::lazy_static! 
{ + pub static ref FS_DOT_GIT: &'static OsStr = OsStr::new(".git"); +} + +#[cfg(any(test, feature = "test-support"))] +impl FakeFs { + pub fn new(executor: gpui2::Executor) -> Arc { + Arc::new(Self { + executor, + state: Mutex::new(FakeFsState { + root: Arc::new(Mutex::new(FakeFsEntry::Dir { + inode: 0, + mtime: SystemTime::UNIX_EPOCH, + entries: Default::default(), + git_repo_state: None, + })), + next_mtime: SystemTime::UNIX_EPOCH, + next_inode: 1, + event_txs: Default::default(), + buffered_events: Vec::new(), + events_paused: false, + read_dir_call_count: 0, + metadata_call_count: 0, + }), + }) + } + + pub async fn insert_file(&self, path: impl AsRef, content: String) { + self.write_file_internal(path, content).unwrap() + } + + pub async fn insert_symlink(&self, path: impl AsRef, target: PathBuf) { + let mut state = self.state.lock(); + let path = path.as_ref(); + let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target })); + state + .write_path(path.as_ref(), move |e| match e { + btree_map::Entry::Vacant(e) => { + e.insert(file); + Ok(()) + } + btree_map::Entry::Occupied(mut e) => { + *e.get_mut() = file; + Ok(()) + } + }) + .unwrap(); + state.emit_event(&[path]); + } + + pub fn write_file_internal(&self, path: impl AsRef, content: String) -> Result<()> { + let mut state = self.state.lock(); + let path = path.as_ref(); + let inode = state.next_inode; + let mtime = state.next_mtime; + state.next_inode += 1; + state.next_mtime += Duration::from_nanos(1); + let file = Arc::new(Mutex::new(FakeFsEntry::File { + inode, + mtime, + content, + })); + state.write_path(path, move |entry| { + match entry { + btree_map::Entry::Vacant(e) => { + e.insert(file); + } + btree_map::Entry::Occupied(mut e) => { + *e.get_mut() = file; + } + } + Ok(()) + })?; + state.emit_event(&[path]); + Ok(()) + } + + pub fn pause_events(&self) { + self.state.lock().events_paused = true; + } + + pub fn buffered_event_count(&self) -> usize { + self.state.lock().buffered_events.len() + } + + 
pub fn flush_events(&self, count: usize) { + self.state.lock().flush_events(count); + } + + #[must_use] + pub fn insert_tree<'a>( + &'a self, + path: impl 'a + AsRef + Send, + tree: serde_json::Value, + ) -> futures::future::BoxFuture<'a, ()> { + use futures::FutureExt as _; + use serde_json::Value::*; + + async move { + let path = path.as_ref(); + + match tree { + Object(map) => { + self.create_dir(path).await.unwrap(); + for (name, contents) in map { + let mut path = PathBuf::from(path); + path.push(name); + self.insert_tree(&path, contents).await; + } + } + Null => { + self.create_dir(path).await.unwrap(); + } + String(contents) => { + self.insert_file(&path, contents).await; + } + _ => { + panic!("JSON object must contain only objects, strings, or null"); + } + } + } + .boxed() + } + + pub fn with_git_state(&self, dot_git: &Path, emit_git_event: bool, f: F) + where + F: FnOnce(&mut FakeGitRepositoryState), + { + let mut state = self.state.lock(); + let entry = state.read_path(dot_git).unwrap(); + let mut entry = entry.lock(); + + if let FakeFsEntry::Dir { git_repo_state, .. 
} = &mut *entry { + let repo_state = git_repo_state.get_or_insert_with(Default::default); + let mut repo_state = repo_state.lock(); + + f(&mut repo_state); + + if emit_git_event { + state.emit_event([dot_git]); + } + } else { + panic!("not a directory"); + } + } + + pub fn set_branch_name(&self, dot_git: &Path, branch: Option>) { + self.with_git_state(dot_git, true, |state| { + state.branch_name = branch.map(Into::into) + }) + } + + pub fn set_index_for_repo(&self, dot_git: &Path, head_state: &[(&Path, String)]) { + self.with_git_state(dot_git, true, |state| { + state.index_contents.clear(); + state.index_contents.extend( + head_state + .iter() + .map(|(path, content)| (path.to_path_buf(), content.clone())), + ); + }); + } + + pub fn set_status_for_repo_via_working_copy_change( + &self, + dot_git: &Path, + statuses: &[(&Path, GitFileStatus)], + ) { + self.with_git_state(dot_git, false, |state| { + state.worktree_statuses.clear(); + state.worktree_statuses.extend( + statuses + .iter() + .map(|(path, content)| ((**path).into(), content.clone())), + ); + }); + self.state.lock().emit_event( + statuses + .iter() + .map(|(path, _)| dot_git.parent().unwrap().join(path)), + ); + } + + pub fn set_status_for_repo_via_git_operation( + &self, + dot_git: &Path, + statuses: &[(&Path, GitFileStatus)], + ) { + self.with_git_state(dot_git, true, |state| { + state.worktree_statuses.clear(); + state.worktree_statuses.extend( + statuses + .iter() + .map(|(path, content)| ((**path).into(), content.clone())), + ); + }); + } + + pub fn paths(&self, include_dot_git: bool) -> Vec { + let mut result = Vec::new(); + let mut queue = collections::VecDeque::new(); + queue.push_back((PathBuf::from("/"), self.state.lock().root.clone())); + while let Some((path, entry)) = queue.pop_front() { + if let FakeFsEntry::Dir { entries, .. 
} = &*entry.lock() { + for (name, entry) in entries { + queue.push_back((path.join(name), entry.clone())); + } + } + if include_dot_git + || !path + .components() + .any(|component| component.as_os_str() == *FS_DOT_GIT) + { + result.push(path); + } + } + result + } + + pub fn directories(&self, include_dot_git: bool) -> Vec { + let mut result = Vec::new(); + let mut queue = collections::VecDeque::new(); + queue.push_back((PathBuf::from("/"), self.state.lock().root.clone())); + while let Some((path, entry)) = queue.pop_front() { + if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() { + for (name, entry) in entries { + queue.push_back((path.join(name), entry.clone())); + } + if include_dot_git + || !path + .components() + .any(|component| component.as_os_str() == *FS_DOT_GIT) + { + result.push(path); + } + } + } + result + } + + pub fn files(&self) -> Vec { + let mut result = Vec::new(); + let mut queue = collections::VecDeque::new(); + queue.push_back((PathBuf::from("/"), self.state.lock().root.clone())); + while let Some((path, entry)) = queue.pop_front() { + let e = entry.lock(); + match &*e { + FakeFsEntry::File { .. } => result.push(path), + FakeFsEntry::Dir { entries, .. } => { + for (name, entry) in entries { + queue.push_back((path.join(name), entry.clone())); + } + } + FakeFsEntry::Symlink { .. } => {} + } + } + result + } + + /// How many `read_dir` calls have been issued. + pub fn read_dir_call_count(&self) -> usize { + self.state.lock().read_dir_call_count + } + + /// How many `metadata` calls have been issued. + pub fn metadata_call_count(&self) -> usize { + self.state.lock().metadata_call_count + } + + fn simulate_random_delay(&self) -> impl futures::Future { + self.executor.simulate_random_delay() + } +} + +#[cfg(any(test, feature = "test-support"))] +impl FakeFsEntry { + fn is_file(&self) -> bool { + matches!(self, Self::File { .. }) + } + + fn is_symlink(&self) -> bool { + matches!(self, Self::Symlink { .. 
}) + } + + fn file_content(&self, path: &Path) -> Result<&String> { + if let Self::File { content, .. } = self { + Ok(content) + } else { + Err(anyhow!("not a file: {}", path.display())) + } + } + + fn set_file_content(&mut self, path: &Path, new_content: String) -> Result<()> { + if let Self::File { content, mtime, .. } = self { + *mtime = SystemTime::now(); + *content = new_content; + Ok(()) + } else { + Err(anyhow!("not a file: {}", path.display())) + } + } + + fn dir_entries( + &mut self, + path: &Path, + ) -> Result<&mut BTreeMap>>> { + if let Self::Dir { entries, .. } = self { + Ok(entries) + } else { + Err(anyhow!("not a directory: {}", path.display())) + } + } +} + +#[cfg(any(test, feature = "test-support"))] +#[async_trait::async_trait] +impl Fs for FakeFs { + async fn create_dir(&self, path: &Path) -> Result<()> { + self.simulate_random_delay().await; + + let mut created_dirs = Vec::new(); + let mut cur_path = PathBuf::new(); + for component in path.components() { + let mut state = self.state.lock(); + cur_path.push(component); + if cur_path == Path::new("/") { + continue; + } + + let inode = state.next_inode; + let mtime = state.next_mtime; + state.next_mtime += Duration::from_nanos(1); + state.next_inode += 1; + state.write_path(&cur_path, |entry| { + entry.or_insert_with(|| { + created_dirs.push(cur_path.clone()); + Arc::new(Mutex::new(FakeFsEntry::Dir { + inode, + mtime, + entries: Default::default(), + git_repo_state: None, + })) + }); + Ok(()) + })? 
+ } + + self.state.lock().emit_event(&created_dirs); + Ok(()) + } + + async fn create_file(&self, path: &Path, options: CreateOptions) -> Result<()> { + self.simulate_random_delay().await; + let mut state = self.state.lock(); + let inode = state.next_inode; + let mtime = state.next_mtime; + state.next_mtime += Duration::from_nanos(1); + state.next_inode += 1; + let file = Arc::new(Mutex::new(FakeFsEntry::File { + inode, + mtime, + content: String::new(), + })); + state.write_path(path, |entry| { + match entry { + btree_map::Entry::Occupied(mut e) => { + if options.overwrite { + *e.get_mut() = file; + } else if !options.ignore_if_exists { + return Err(anyhow!("path already exists: {}", path.display())); + } + } + btree_map::Entry::Vacant(e) => { + e.insert(file); + } + } + Ok(()) + })?; + state.emit_event(&[path]); + Ok(()) + } + + async fn rename(&self, old_path: &Path, new_path: &Path, options: RenameOptions) -> Result<()> { + self.simulate_random_delay().await; + + let old_path = normalize_path(old_path); + let new_path = normalize_path(new_path); + + let mut state = self.state.lock(); + let moved_entry = state.write_path(&old_path, |e| { + if let btree_map::Entry::Occupied(e) = e { + Ok(e.get().clone()) + } else { + Err(anyhow!("path does not exist: {}", &old_path.display())) + } + })?; + + state.write_path(&new_path, |e| { + match e { + btree_map::Entry::Occupied(mut e) => { + if options.overwrite { + *e.get_mut() = moved_entry; + } else if !options.ignore_if_exists { + return Err(anyhow!("path already exists: {}", new_path.display())); + } + } + btree_map::Entry::Vacant(e) => { + e.insert(moved_entry); + } + } + Ok(()) + })?; + + state + .write_path(&old_path, |e| { + if let btree_map::Entry::Occupied(e) = e { + Ok(e.remove()) + } else { + unreachable!() + } + }) + .unwrap(); + + state.emit_event(&[old_path, new_path]); + Ok(()) + } + + async fn copy_file(&self, source: &Path, target: &Path, options: CopyOptions) -> Result<()> { + 
self.simulate_random_delay().await; + + let source = normalize_path(source); + let target = normalize_path(target); + let mut state = self.state.lock(); + let mtime = state.next_mtime; + let inode = util::post_inc(&mut state.next_inode); + state.next_mtime += Duration::from_nanos(1); + let source_entry = state.read_path(&source)?; + let content = source_entry.lock().file_content(&source)?.clone(); + let entry = state.write_path(&target, |e| match e { + btree_map::Entry::Occupied(e) => { + if options.overwrite { + Ok(Some(e.get().clone())) + } else if !options.ignore_if_exists { + return Err(anyhow!("{target:?} already exists")); + } else { + Ok(None) + } + } + btree_map::Entry::Vacant(e) => Ok(Some( + e.insert(Arc::new(Mutex::new(FakeFsEntry::File { + inode, + mtime, + content: String::new(), + }))) + .clone(), + )), + })?; + if let Some(entry) = entry { + entry.lock().set_file_content(&target, content)?; + } + state.emit_event(&[target]); + Ok(()) + } + + async fn remove_dir(&self, path: &Path, options: RemoveOptions) -> Result<()> { + self.simulate_random_delay().await; + + let path = normalize_path(path); + let parent_path = path + .parent() + .ok_or_else(|| anyhow!("cannot remove the root"))?; + let base_name = path.file_name().unwrap(); + + let mut state = self.state.lock(); + let parent_entry = state.read_path(parent_path)?; + let mut parent_entry = parent_entry.lock(); + let entry = parent_entry + .dir_entries(parent_path)? 
+ .entry(base_name.to_str().unwrap().into()); + + match entry { + btree_map::Entry::Vacant(_) => { + if !options.ignore_if_not_exists { + return Err(anyhow!("{path:?} does not exist")); + } + } + btree_map::Entry::Occupied(e) => { + { + let mut entry = e.get().lock(); + let children = entry.dir_entries(&path)?; + if !options.recursive && !children.is_empty() { + return Err(anyhow!("{path:?} is not empty")); + } + } + e.remove(); + } + } + state.emit_event(&[path]); + Ok(()) + } + + async fn remove_file(&self, path: &Path, options: RemoveOptions) -> Result<()> { + self.simulate_random_delay().await; + + let path = normalize_path(path); + let parent_path = path + .parent() + .ok_or_else(|| anyhow!("cannot remove the root"))?; + let base_name = path.file_name().unwrap(); + let mut state = self.state.lock(); + let parent_entry = state.read_path(parent_path)?; + let mut parent_entry = parent_entry.lock(); + let entry = parent_entry + .dir_entries(parent_path)? + .entry(base_name.to_str().unwrap().into()); + match entry { + btree_map::Entry::Vacant(_) => { + if !options.ignore_if_not_exists { + return Err(anyhow!("{path:?} does not exist")); + } + } + btree_map::Entry::Occupied(e) => { + e.get().lock().file_content(&path)?; + e.remove(); + } + } + state.emit_event(&[path]); + Ok(()) + } + + async fn open_sync(&self, path: &Path) -> Result> { + let text = self.load(path).await?; + Ok(Box::new(io::Cursor::new(text))) + } + + async fn load(&self, path: &Path) -> Result { + let path = normalize_path(path); + self.simulate_random_delay().await; + let state = self.state.lock(); + let entry = state.read_path(&path)?; + let entry = entry.lock(); + entry.file_content(&path).cloned() + } + + async fn atomic_write(&self, path: PathBuf, data: String) -> Result<()> { + self.simulate_random_delay().await; + let path = normalize_path(path.as_path()); + self.write_file_internal(path, data.to_string())?; + + Ok(()) + } + + async fn save(&self, path: &Path, text: &Rope, line_ending: 
LineEnding) -> Result<()> { + self.simulate_random_delay().await; + let path = normalize_path(path); + let content = chunks(text, line_ending).collect(); + if let Some(path) = path.parent() { + self.create_dir(path).await?; + } + self.write_file_internal(path, content)?; + Ok(()) + } + + async fn canonicalize(&self, path: &Path) -> Result { + let path = normalize_path(path); + self.simulate_random_delay().await; + let state = self.state.lock(); + if let Some((_, canonical_path)) = state.try_read_path(&path, true) { + Ok(canonical_path) + } else { + Err(anyhow!("path does not exist: {}", path.display())) + } + } + + async fn is_file(&self, path: &Path) -> bool { + let path = normalize_path(path); + self.simulate_random_delay().await; + let state = self.state.lock(); + if let Some((entry, _)) = state.try_read_path(&path, true) { + entry.lock().is_file() + } else { + false + } + } + + async fn metadata(&self, path: &Path) -> Result> { + self.simulate_random_delay().await; + let path = normalize_path(path); + let mut state = self.state.lock(); + state.metadata_call_count += 1; + if let Some((mut entry, _)) = state.try_read_path(&path, false) { + let is_symlink = entry.lock().is_symlink(); + if is_symlink { + if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) { + entry = e; + } else { + return Ok(None); + } + } + + let entry = entry.lock(); + Ok(Some(match &*entry { + FakeFsEntry::File { inode, mtime, .. } => Metadata { + inode: *inode, + mtime: *mtime, + is_dir: false, + is_symlink, + }, + FakeFsEntry::Dir { inode, mtime, .. } => Metadata { + inode: *inode, + mtime: *mtime, + is_dir: true, + is_symlink, + }, + FakeFsEntry::Symlink { .. 
} => unreachable!(), + })) + } else { + Ok(None) + } + } + + async fn read_link(&self, path: &Path) -> Result { + self.simulate_random_delay().await; + let path = normalize_path(path); + let state = self.state.lock(); + if let Some((entry, _)) = state.try_read_path(&path, false) { + let entry = entry.lock(); + if let FakeFsEntry::Symlink { target } = &*entry { + Ok(target.clone()) + } else { + Err(anyhow!("not a symlink: {}", path.display())) + } + } else { + Err(anyhow!("path does not exist: {}", path.display())) + } + } + + async fn read_dir( + &self, + path: &Path, + ) -> Result>>>> { + self.simulate_random_delay().await; + let path = normalize_path(path); + let mut state = self.state.lock(); + state.read_dir_call_count += 1; + let entry = state.read_path(&path)?; + let mut entry = entry.lock(); + let children = entry.dir_entries(&path)?; + let paths = children + .keys() + .map(|file_name| Ok(path.join(file_name))) + .collect::>(); + Ok(Box::pin(futures::stream::iter(paths))) + } + + async fn watch( + &self, + path: &Path, + _: Duration, + ) -> Pin>>> { + self.simulate_random_delay().await; + let (tx, rx) = smol::channel::unbounded(); + self.state.lock().event_txs.push(tx); + let path = path.to_path_buf(); + let executor = self.executor.clone(); + Box::pin(futures::StreamExt::filter(rx, move |events| { + let result = events.iter().any(|event| event.path.starts_with(&path)); + let executor = executor.clone(); + async move { + executor.simulate_random_delay().await; + result + } + })) + } + + fn open_repo(&self, abs_dot_git: &Path) -> Option>> { + let state = self.state.lock(); + let entry = state.read_path(abs_dot_git).unwrap(); + let mut entry = entry.lock(); + if let FakeFsEntry::Dir { git_repo_state, .. 
} = &mut *entry { + let state = git_repo_state + .get_or_insert_with(|| Arc::new(Mutex::new(FakeGitRepositoryState::default()))) + .clone(); + Some(repository::FakeGitRepository::open(state)) + } else { + None + } + } + + fn is_fake(&self) -> bool { + true + } + + #[cfg(any(test, feature = "test-support"))] + fn as_fake(&self) -> &FakeFs { + self + } +} + +fn chunks(rope: &Rope, line_ending: LineEnding) -> impl Iterator { + rope.chunks().flat_map(move |chunk| { + let mut newline = false; + chunk.split('\n').flat_map(move |line| { + let ending = if newline { + Some(line_ending.as_str()) + } else { + None + }; + newline = true; + ending.into_iter().chain([line]) + }) + }) +} + +pub fn normalize_path(path: &Path) -> PathBuf { + let mut components = path.components().peekable(); + let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek().cloned() { + components.next(); + PathBuf::from(c.as_os_str()) + } else { + PathBuf::new() + }; + + for component in components { + match component { + Component::Prefix(..) => unreachable!(), + Component::RootDir => { + ret.push(component.as_os_str()); + } + Component::CurDir => {} + Component::ParentDir => { + ret.pop(); + } + Component::Normal(c) => { + ret.push(c); + } + } + } + ret +} + +pub fn copy_recursive<'a>( + fs: &'a dyn Fs, + source: &'a Path, + target: &'a Path, + options: CopyOptions, +) -> BoxFuture<'a, Result<()>> { + use futures::future::FutureExt; + + async move { + let metadata = fs + .metadata(source) + .await? 
+ .ok_or_else(|| anyhow!("path does not exist: {}", source.display()))?; + if metadata.is_dir { + if !options.overwrite && fs.metadata(target).await.is_ok() { + if options.ignore_if_exists { + return Ok(()); + } else { + return Err(anyhow!("{target:?} already exists")); + } + } + + let _ = fs + .remove_dir( + target, + RemoveOptions { + recursive: true, + ignore_if_not_exists: true, + }, + ) + .await; + fs.create_dir(target).await?; + let mut children = fs.read_dir(source).await?; + while let Some(child_path) = children.next().await { + if let Ok(child_path) = child_path { + if let Some(file_name) = child_path.file_name() { + let child_target_path = target.join(file_name); + copy_recursive(fs, &child_path, &child_target_path, options).await?; + } + } + } + + Ok(()) + } else { + fs.copy_file(source, target, options).await + } + } + .boxed() +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui2::Executor; + use serde_json::json; + + #[gpui2::test] + async fn test_fake_fs(executor: Executor) { + let fs = FakeFs::new(executor.clone()); + fs.insert_tree( + "/root", + json!({ + "dir1": { + "a": "A", + "b": "B" + }, + "dir2": { + "c": "C", + "dir3": { + "d": "D" + } + } + }), + ) + .await; + + assert_eq!( + fs.files(), + vec![ + PathBuf::from("/root/dir1/a"), + PathBuf::from("/root/dir1/b"), + PathBuf::from("/root/dir2/c"), + PathBuf::from("/root/dir2/dir3/d"), + ] + ); + + fs.insert_symlink("/root/dir2/link-to-dir3", "./dir3".into()) + .await; + + assert_eq!( + fs.canonicalize("/root/dir2/link-to-dir3".as_ref()) + .await + .unwrap(), + PathBuf::from("/root/dir2/dir3"), + ); + assert_eq!( + fs.canonicalize("/root/dir2/link-to-dir3/d".as_ref()) + .await + .unwrap(), + PathBuf::from("/root/dir2/dir3/d"), + ); + assert_eq!( + fs.load("/root/dir2/link-to-dir3/d".as_ref()).await.unwrap(), + "D", + ); + } +} diff --git a/crates/fs2/src/repository.rs b/crates/fs2/src/repository.rs new file mode 100644 index 
0000000000000000000000000000000000000000..4637a7f75408c74a4d398b8eb60f21d6ba76ab33 --- /dev/null +++ b/crates/fs2/src/repository.rs @@ -0,0 +1,417 @@ +use anyhow::Result; +use collections::HashMap; +use git2::{BranchType, StatusShow}; +use parking_lot::Mutex; +use serde_derive::{Deserialize, Serialize}; +use std::{ + cmp::Ordering, + ffi::OsStr, + os::unix::prelude::OsStrExt, + path::{Component, Path, PathBuf}, + sync::Arc, + time::SystemTime, +}; +use sum_tree::{MapSeekTarget, TreeMap}; +use util::ResultExt; + +pub use git2::Repository as LibGitRepository; + +#[derive(Clone, Debug, Hash, PartialEq)] +pub struct Branch { + pub name: Box, + /// Timestamp of most recent commit, normalized to Unix Epoch format. + pub unix_timestamp: Option, +} + +#[async_trait::async_trait] +pub trait GitRepository: Send { + fn reload_index(&self); + fn load_index_text(&self, relative_file_path: &Path) -> Option; + fn branch_name(&self) -> Option; + + /// Get the statuses of all of the files in the index that start with the given + /// path and have changes with resepect to the HEAD commit. This is fast because + /// the index stores hashes of trees, so that unchanged directories can be skipped. + fn staged_statuses(&self, path_prefix: &Path) -> TreeMap; + + /// Get the status of a given file in the working directory with respect to + /// the index. In the common case, when there are no changes, this only requires + /// an index lookup. The index stores the mtime of each file when it was added, + /// so there's no work to do if the mtime matches. + fn unstaged_status(&self, path: &RepoPath, mtime: SystemTime) -> Option; + + /// Get the status of a given file in the working directory with respect to + /// the HEAD commit. In the common case, when there are no changes, this only + /// requires an index lookup and blob comparison between the index and the HEAD + /// commit. 
The index stores the mtime of each file when it was added, so there's + /// no need to consider the working directory file if the mtime matches. + fn status(&self, path: &RepoPath, mtime: SystemTime) -> Option; + + fn branches(&self) -> Result>; + fn change_branch(&self, _: &str) -> Result<()>; + fn create_branch(&self, _: &str) -> Result<()>; +} + +impl std::fmt::Debug for dyn GitRepository { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("dyn GitRepository<...>").finish() + } +} + +impl GitRepository for LibGitRepository { + fn reload_index(&self) { + if let Ok(mut index) = self.index() { + _ = index.read(false); + } + } + + fn load_index_text(&self, relative_file_path: &Path) -> Option { + fn logic(repo: &LibGitRepository, relative_file_path: &Path) -> Result> { + const STAGE_NORMAL: i32 = 0; + let index = repo.index()?; + + // This check is required because index.get_path() unwraps internally :( + check_path_to_repo_path_errors(relative_file_path)?; + + let oid = match index.get_path(&relative_file_path, STAGE_NORMAL) { + Some(entry) => entry.id, + None => return Ok(None), + }; + + let content = repo.find_blob(oid)?.content().to_owned(); + Ok(Some(String::from_utf8(content)?)) + } + + match logic(&self, relative_file_path) { + Ok(value) => return value, + Err(err) => log::error!("Error loading head text: {:?}", err), + } + None + } + + fn branch_name(&self) -> Option { + let head = self.head().log_err()?; + let branch = String::from_utf8_lossy(head.shorthand_bytes()); + Some(branch.to_string()) + } + + fn staged_statuses(&self, path_prefix: &Path) -> TreeMap { + let mut map = TreeMap::default(); + + let mut options = git2::StatusOptions::new(); + options.pathspec(path_prefix); + options.show(StatusShow::Index); + + if let Some(statuses) = self.statuses(Some(&mut options)).log_err() { + for status in statuses.iter() { + let path = RepoPath(PathBuf::from(OsStr::from_bytes(status.path_bytes()))); + let status = 
status.status(); + if !status.contains(git2::Status::IGNORED) { + if let Some(status) = read_status(status) { + map.insert(path, status) + } + } + } + } + map + } + + fn unstaged_status(&self, path: &RepoPath, mtime: SystemTime) -> Option { + // If the file has not changed since it was added to the index, then + // there can't be any changes. + if matches_index(self, path, mtime) { + return None; + } + + let mut options = git2::StatusOptions::new(); + options.pathspec(&path.0); + options.disable_pathspec_match(true); + options.include_untracked(true); + options.recurse_untracked_dirs(true); + options.include_unmodified(true); + options.show(StatusShow::Workdir); + + let statuses = self.statuses(Some(&mut options)).log_err()?; + let status = statuses.get(0).and_then(|s| read_status(s.status())); + status + } + + fn status(&self, path: &RepoPath, mtime: SystemTime) -> Option { + let mut options = git2::StatusOptions::new(); + options.pathspec(&path.0); + options.disable_pathspec_match(true); + options.include_untracked(true); + options.recurse_untracked_dirs(true); + options.include_unmodified(true); + + // If the file has not changed since it was added to the index, then + // there's no need to examine the working directory file: just compare + // the blob in the index to the one in the HEAD commit. 
+ if matches_index(self, path, mtime) { + options.show(StatusShow::Index); + } + + let statuses = self.statuses(Some(&mut options)).log_err()?; + let status = statuses.get(0).and_then(|s| read_status(s.status())); + status + } + + fn branches(&self) -> Result> { + let local_branches = self.branches(Some(BranchType::Local))?; + let valid_branches = local_branches + .filter_map(|branch| { + branch.ok().and_then(|(branch, _)| { + let name = branch.name().ok().flatten().map(Box::from)?; + let timestamp = branch.get().peel_to_commit().ok()?.time(); + let unix_timestamp = timestamp.seconds(); + let timezone_offset = timestamp.offset_minutes(); + let utc_offset = + time::UtcOffset::from_whole_seconds(timezone_offset * 60).ok()?; + let unix_timestamp = + time::OffsetDateTime::from_unix_timestamp(unix_timestamp).ok()?; + Some(Branch { + name, + unix_timestamp: Some(unix_timestamp.to_offset(utc_offset).unix_timestamp()), + }) + }) + }) + .collect(); + Ok(valid_branches) + } + fn change_branch(&self, name: &str) -> Result<()> { + let revision = self.find_branch(name, BranchType::Local)?; + let revision = revision.get(); + let as_tree = revision.peel_to_tree()?; + self.checkout_tree(as_tree.as_object(), None)?; + self.set_head( + revision + .name() + .ok_or_else(|| anyhow::anyhow!("Branch name could not be retrieved"))?, + )?; + Ok(()) + } + fn create_branch(&self, name: &str) -> Result<()> { + let current_commit = self.head()?.peel_to_commit()?; + self.branch(name, ¤t_commit, false)?; + + Ok(()) + } +} + +fn matches_index(repo: &LibGitRepository, path: &RepoPath, mtime: SystemTime) -> bool { + if let Some(index) = repo.index().log_err() { + if let Some(entry) = index.get_path(&path, 0) { + if let Some(mtime) = mtime.duration_since(SystemTime::UNIX_EPOCH).log_err() { + if entry.mtime.seconds() == mtime.as_secs() as i32 + && entry.mtime.nanoseconds() == mtime.subsec_nanos() + { + return true; + } + } + } + } + false +} + +fn read_status(status: git2::Status) -> Option { + if 
status.contains(git2::Status::CONFLICTED) { + Some(GitFileStatus::Conflict) + } else if status.intersects( + git2::Status::WT_MODIFIED + | git2::Status::WT_RENAMED + | git2::Status::INDEX_MODIFIED + | git2::Status::INDEX_RENAMED, + ) { + Some(GitFileStatus::Modified) + } else if status.intersects(git2::Status::WT_NEW | git2::Status::INDEX_NEW) { + Some(GitFileStatus::Added) + } else { + None + } +} + +#[derive(Debug, Clone, Default)] +pub struct FakeGitRepository { + state: Arc>, +} + +#[derive(Debug, Clone, Default)] +pub struct FakeGitRepositoryState { + pub index_contents: HashMap, + pub worktree_statuses: HashMap, + pub branch_name: Option, +} + +impl FakeGitRepository { + pub fn open(state: Arc>) -> Arc> { + Arc::new(Mutex::new(FakeGitRepository { state })) + } +} + +#[async_trait::async_trait] +impl GitRepository for FakeGitRepository { + fn reload_index(&self) {} + + fn load_index_text(&self, path: &Path) -> Option { + let state = self.state.lock(); + state.index_contents.get(path).cloned() + } + + fn branch_name(&self) -> Option { + let state = self.state.lock(); + state.branch_name.clone() + } + + fn staged_statuses(&self, path_prefix: &Path) -> TreeMap { + let mut map = TreeMap::default(); + let state = self.state.lock(); + for (repo_path, status) in state.worktree_statuses.iter() { + if repo_path.0.starts_with(path_prefix) { + map.insert(repo_path.to_owned(), status.to_owned()); + } + } + map + } + + fn unstaged_status(&self, _path: &RepoPath, _mtime: SystemTime) -> Option { + None + } + + fn status(&self, path: &RepoPath, _mtime: SystemTime) -> Option { + let state = self.state.lock(); + state.worktree_statuses.get(path).cloned() + } + + fn branches(&self) -> Result> { + Ok(vec![]) + } + + fn change_branch(&self, name: &str) -> Result<()> { + let mut state = self.state.lock(); + state.branch_name = Some(name.to_owned()); + Ok(()) + } + + fn create_branch(&self, name: &str) -> Result<()> { + let mut state = self.state.lock(); + state.branch_name = 
Some(name.to_owned()); + Ok(()) + } +} + +fn check_path_to_repo_path_errors(relative_file_path: &Path) -> Result<()> { + match relative_file_path.components().next() { + None => anyhow::bail!("repo path should not be empty"), + Some(Component::Prefix(_)) => anyhow::bail!( + "repo path `{}` should be relative, not a windows prefix", + relative_file_path.to_string_lossy() + ), + Some(Component::RootDir) => { + anyhow::bail!( + "repo path `{}` should be relative", + relative_file_path.to_string_lossy() + ) + } + Some(Component::CurDir) => { + anyhow::bail!( + "repo path `{}` should not start with `.`", + relative_file_path.to_string_lossy() + ) + } + Some(Component::ParentDir) => { + anyhow::bail!( + "repo path `{}` should not start with `..`", + relative_file_path.to_string_lossy() + ) + } + _ => Ok(()), + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum GitFileStatus { + Added, + Modified, + Conflict, +} + +impl GitFileStatus { + pub fn merge( + this: Option, + other: Option, + prefer_other: bool, + ) -> Option { + if prefer_other { + return other; + } else { + match (this, other) { + (Some(GitFileStatus::Conflict), _) | (_, Some(GitFileStatus::Conflict)) => { + Some(GitFileStatus::Conflict) + } + (Some(GitFileStatus::Modified), _) | (_, Some(GitFileStatus::Modified)) => { + Some(GitFileStatus::Modified) + } + (Some(GitFileStatus::Added), _) | (_, Some(GitFileStatus::Added)) => { + Some(GitFileStatus::Added) + } + _ => None, + } + } + } +} + +#[derive(Clone, Debug, Ord, Hash, PartialOrd, Eq, PartialEq)] +pub struct RepoPath(pub PathBuf); + +impl RepoPath { + pub fn new(path: PathBuf) -> Self { + debug_assert!(path.is_relative(), "Repo paths must be relative"); + + RepoPath(path) + } +} + +impl From<&Path> for RepoPath { + fn from(value: &Path) -> Self { + RepoPath::new(value.to_path_buf()) + } +} + +impl From for RepoPath { + fn from(value: PathBuf) -> Self { + RepoPath::new(value) + } +} + +impl Default for RepoPath { + fn 
default() -> Self { + RepoPath(PathBuf::new()) + } +} + +impl AsRef for RepoPath { + fn as_ref(&self) -> &Path { + self.0.as_ref() + } +} + +impl std::ops::Deref for RepoPath { + type Target = PathBuf; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug)] +pub struct RepoPathDescendants<'a>(pub &'a Path); + +impl<'a> MapSeekTarget for RepoPathDescendants<'a> { + fn cmp_cursor(&self, key: &RepoPath) -> Ordering { + if key.starts_with(&self.0) { + Ordering::Greater + } else { + self.0.cmp(key) + } + } +} diff --git a/crates/fuzzy2/Cargo.toml b/crates/fuzzy2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..5b92a27a27fedaff8a1c05d99714238c29f5267b --- /dev/null +++ b/crates/fuzzy2/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "fuzzy2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/fuzzy2.rs" +doctest = false + +[dependencies] +gpui2 = { path = "../gpui2" } +util = { path = "../util" } diff --git a/crates/fuzzy2/src/char_bag.rs b/crates/fuzzy2/src/char_bag.rs new file mode 100644 index 0000000000000000000000000000000000000000..8fc36368a159b671e8199248cbdef19929961459 --- /dev/null +++ b/crates/fuzzy2/src/char_bag.rs @@ -0,0 +1,63 @@ +use std::iter::FromIterator; + +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +pub struct CharBag(u64); + +impl CharBag { + pub fn is_superset(self, other: CharBag) -> bool { + self.0 & other.0 == other.0 + } + + fn insert(&mut self, c: char) { + let c = c.to_ascii_lowercase(); + if ('a'..='z').contains(&c) { + let mut count = self.0; + let idx = c as u8 - b'a'; + count >>= idx * 2; + count = ((count << 1) | 1) & 3; + count <<= idx * 2; + self.0 |= count; + } else if ('0'..='9').contains(&c) { + let idx = c as u8 - b'0'; + self.0 |= 1 << (idx + 52); + } else if c == '-' { + self.0 |= 1 << 62; + } + } +} + +impl Extend for CharBag { + fn extend>(&mut self, iter: T) { + for c in iter { + self.insert(c); + } + } +} + +impl FromIterator for CharBag { + fn 
from_iter>(iter: T) -> Self { + let mut result = Self::default(); + result.extend(iter); + result + } +} + +impl From<&str> for CharBag { + fn from(s: &str) -> Self { + let mut bag = Self(0); + for c in s.chars() { + bag.insert(c); + } + bag + } +} + +impl From<&[char]> for CharBag { + fn from(chars: &[char]) -> Self { + let mut bag = Self(0); + for c in chars { + bag.insert(*c); + } + bag + } +} diff --git a/crates/fuzzy2/src/fuzzy2.rs b/crates/fuzzy2/src/fuzzy2.rs new file mode 100644 index 0000000000000000000000000000000000000000..b9595df61f2e46432c53705fee9360683039bac5 --- /dev/null +++ b/crates/fuzzy2/src/fuzzy2.rs @@ -0,0 +1,10 @@ +mod char_bag; +mod matcher; +mod paths; +mod strings; + +pub use char_bag::CharBag; +pub use paths::{ + match_fixed_path_set, match_path_sets, PathMatch, PathMatchCandidate, PathMatchCandidateSet, +}; +pub use strings::{match_strings, StringMatch, StringMatchCandidate}; diff --git a/crates/fuzzy2/src/matcher.rs b/crates/fuzzy2/src/matcher.rs new file mode 100644 index 0000000000000000000000000000000000000000..e808a4886f91152894dbaf4686fa51a786926d29 --- /dev/null +++ b/crates/fuzzy2/src/matcher.rs @@ -0,0 +1,464 @@ +use std::{ + borrow::Cow, + sync::atomic::{self, AtomicBool}, +}; + +use crate::CharBag; + +const BASE_DISTANCE_PENALTY: f64 = 0.6; +const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05; +const MIN_DISTANCE_PENALTY: f64 = 0.2; + +pub struct Matcher<'a> { + query: &'a [char], + lowercase_query: &'a [char], + query_char_bag: CharBag, + smart_case: bool, + max_results: usize, + min_score: f64, + match_positions: Vec, + last_positions: Vec, + score_matrix: Vec>, + best_position_matrix: Vec, +} + +pub trait Match: Ord { + fn score(&self) -> f64; + fn set_positions(&mut self, positions: Vec); +} + +pub trait MatchCandidate { + fn has_chars(&self, bag: CharBag) -> bool; + fn to_string(&self) -> Cow<'_, str>; +} + +impl<'a> Matcher<'a> { + pub fn new( + query: &'a [char], + lowercase_query: &'a [char], + query_char_bag: CharBag, + 
smart_case: bool, + max_results: usize, + ) -> Self { + Self { + query, + lowercase_query, + query_char_bag, + min_score: 0.0, + last_positions: vec![0; query.len()], + match_positions: vec![0; query.len()], + score_matrix: Vec::new(), + best_position_matrix: Vec::new(), + smart_case, + max_results, + } + } + + pub fn match_candidates( + &mut self, + prefix: &[char], + lowercase_prefix: &[char], + candidates: impl Iterator, + results: &mut Vec, + cancel_flag: &AtomicBool, + build_match: F, + ) where + R: Match, + F: Fn(&C, f64) -> R, + { + let mut candidate_chars = Vec::new(); + let mut lowercase_candidate_chars = Vec::new(); + + for candidate in candidates { + if !candidate.has_chars(self.query_char_bag) { + continue; + } + + if cancel_flag.load(atomic::Ordering::Relaxed) { + break; + } + + candidate_chars.clear(); + lowercase_candidate_chars.clear(); + for c in candidate.to_string().chars() { + candidate_chars.push(c); + lowercase_candidate_chars.push(c.to_ascii_lowercase()); + } + + if !self.find_last_positions(lowercase_prefix, &lowercase_candidate_chars) { + continue; + } + + let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len()); + self.score_matrix.clear(); + self.score_matrix.resize(matrix_len, None); + self.best_position_matrix.clear(); + self.best_position_matrix.resize(matrix_len, 0); + + let score = self.score_match( + &candidate_chars, + &lowercase_candidate_chars, + prefix, + lowercase_prefix, + ); + + if score > 0.0 { + let mut mat = build_match(&candidate, score); + if let Err(i) = results.binary_search_by(|m| mat.cmp(m)) { + if results.len() < self.max_results { + mat.set_positions(self.match_positions.clone()); + results.insert(i, mat); + } else if i < results.len() { + results.pop(); + mat.set_positions(self.match_positions.clone()); + results.insert(i, mat); + } + if results.len() == self.max_results { + self.min_score = results.last().unwrap().score(); + } + } + } + } + } + + fn find_last_positions( + &mut self, + 
lowercase_prefix: &[char], + lowercase_candidate: &[char], + ) -> bool { + let mut lowercase_prefix = lowercase_prefix.iter(); + let mut lowercase_candidate = lowercase_candidate.iter(); + for (i, char) in self.lowercase_query.iter().enumerate().rev() { + if let Some(j) = lowercase_candidate.rposition(|c| c == char) { + self.last_positions[i] = j + lowercase_prefix.len(); + } else if let Some(j) = lowercase_prefix.rposition(|c| c == char) { + self.last_positions[i] = j; + } else { + return false; + } + } + true + } + + fn score_match( + &mut self, + path: &[char], + path_cased: &[char], + prefix: &[char], + lowercase_prefix: &[char], + ) -> f64 { + let score = self.recursive_score_match( + path, + path_cased, + prefix, + lowercase_prefix, + 0, + 0, + self.query.len() as f64, + ) * self.query.len() as f64; + + if score <= 0.0 { + return 0.0; + } + + let path_len = prefix.len() + path.len(); + let mut cur_start = 0; + let mut byte_ix = 0; + let mut char_ix = 0; + for i in 0..self.query.len() { + let match_char_ix = self.best_position_matrix[i * path_len + cur_start]; + while char_ix < match_char_ix { + let ch = prefix + .get(char_ix) + .or_else(|| path.get(char_ix - prefix.len())) + .unwrap(); + byte_ix += ch.len_utf8(); + char_ix += 1; + } + cur_start = match_char_ix + 1; + self.match_positions[i] = byte_ix; + } + + score + } + + #[allow(clippy::too_many_arguments)] + fn recursive_score_match( + &mut self, + path: &[char], + path_cased: &[char], + prefix: &[char], + lowercase_prefix: &[char], + query_idx: usize, + path_idx: usize, + cur_score: f64, + ) -> f64 { + if query_idx == self.query.len() { + return 1.0; + } + + let path_len = prefix.len() + path.len(); + + if let Some(memoized) = self.score_matrix[query_idx * path_len + path_idx] { + return memoized; + } + + let mut score = 0.0; + let mut best_position = 0; + + let query_char = self.lowercase_query[query_idx]; + let limit = self.last_positions[query_idx]; + + let mut last_slash = 0; + for j in 
path_idx..=limit { + let path_char = if j < prefix.len() { + lowercase_prefix[j] + } else { + path_cased[j - prefix.len()] + }; + let is_path_sep = path_char == '/' || path_char == '\\'; + + if query_idx == 0 && is_path_sep { + last_slash = j; + } + + if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') { + let curr = if j < prefix.len() { + prefix[j] + } else { + path[j - prefix.len()] + }; + + let mut char_score = 1.0; + if j > path_idx { + let last = if j - 1 < prefix.len() { + prefix[j - 1] + } else { + path[j - 1 - prefix.len()] + }; + + if last == '/' { + char_score = 0.9; + } else if (last == '-' || last == '_' || last == ' ' || last.is_numeric()) + || (last.is_lowercase() && curr.is_uppercase()) + { + char_score = 0.8; + } else if last == '.' { + char_score = 0.7; + } else if query_idx == 0 { + char_score = BASE_DISTANCE_PENALTY; + } else { + char_score = MIN_DISTANCE_PENALTY.max( + BASE_DISTANCE_PENALTY + - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY, + ); + } + } + + // Apply a severe penalty if the case doesn't match. + // This will make the exact matches have higher score than the case-insensitive and the + // path insensitive matches. + if (self.smart_case || curr == '/') && self.query[query_idx] != curr { + char_score *= 0.001; + } + + let mut multiplier = char_score; + + // Scale the score based on how deep within the path we found the match. + if query_idx == 0 { + multiplier /= ((prefix.len() + path.len()) - last_slash) as f64; + } + + let mut next_score = 1.0; + if self.min_score > 0.0 { + next_score = cur_score * multiplier; + // Scores only decrease. If we can't pass the previous best, bail + if next_score < self.min_score { + // Ensure that score is non-zero so we use it in the memo table. 
+ if score == 0.0 { + score = 1e-18; + } + continue; + } + } + + let new_score = self.recursive_score_match( + path, + path_cased, + prefix, + lowercase_prefix, + query_idx + 1, + j + 1, + next_score, + ) * multiplier; + + if new_score > score { + score = new_score; + best_position = j; + // Optimization: can't score better than 1. + if new_score == 1.0 { + break; + } + } + } + } + + if best_position != 0 { + self.best_position_matrix[query_idx * path_len + path_idx] = best_position; + } + + self.score_matrix[query_idx * path_len + path_idx] = Some(score); + score + } +} + +#[cfg(test)] +mod tests { + use crate::{PathMatch, PathMatchCandidate}; + + use super::*; + use std::{ + path::{Path, PathBuf}, + sync::Arc, + }; + + #[test] + fn test_get_last_positions() { + let mut query: &[char] = &['d', 'c']; + let mut matcher = Matcher::new(query, query, query.into(), false, 10); + let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']); + assert!(!result); + + query = &['c', 'd']; + let mut matcher = Matcher::new(query, query, query.into(), false, 10); + let result = matcher.find_last_positions(&['a', 'b', 'c'], &['b', 'd', 'e', 'f']); + assert!(result); + assert_eq!(matcher.last_positions, vec![2, 4]); + + query = &['z', '/', 'z', 'f']; + let mut matcher = Matcher::new(query, query, query.into(), false, 10); + let result = matcher.find_last_positions(&['z', 'e', 'd', '/'], &['z', 'e', 'd', '/', 'f']); + assert!(result); + assert_eq!(matcher.last_positions, vec![0, 3, 4, 8]); + } + + #[test] + fn test_match_path_entries() { + let paths = vec![ + "", + "a", + "ab", + "abC", + "abcd", + "alphabravocharlie", + "AlphaBravoCharlie", + "thisisatestdir", + "/////ThisIsATestDir", + "/this/is/a/test/dir", + "/test/tiatd", + ]; + + assert_eq!( + match_single_path_query("abc", false, &paths), + vec![ + ("abC", vec![0, 1, 2]), + ("abcd", vec![0, 1, 2]), + ("AlphaBravoCharlie", vec![0, 5, 10]), + ("alphabravocharlie", vec![4, 5, 10]), + ] + ); + assert_eq!( + 
match_single_path_query("t/i/a/t/d", false, &paths), + vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),] + ); + + assert_eq!( + match_single_path_query("tiatd", false, &paths), + vec![ + ("/test/tiatd", vec![6, 7, 8, 9, 10]), + ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]), + ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]), + ("thisisatestdir", vec![0, 2, 6, 7, 11]), + ] + ); + } + + #[test] + fn test_match_multibyte_path_entries() { + let paths = vec!["aαbβ/cγdδ", "αβγδ/bcde", "c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", "/d/🆒/h"]; + assert_eq!("1️⃣".len(), 7); + assert_eq!( + match_single_path_query("bcd", false, &paths), + vec![ + ("αβγδ/bcde", vec![9, 10, 11]), + ("aαbβ/cγdδ", vec![3, 7, 10]), + ] + ); + assert_eq!( + match_single_path_query("cde", false, &paths), + vec![ + ("αβγδ/bcde", vec![10, 11, 12]), + ("c1️⃣2️⃣3️⃣/d4️⃣5️⃣6️⃣/e7️⃣8️⃣9️⃣/f", vec![0, 23, 46]), + ] + ); + } + + fn match_single_path_query<'a>( + query: &str, + smart_case: bool, + paths: &[&'a str], + ) -> Vec<(&'a str, Vec)> { + let lowercase_query = query.to_lowercase().chars().collect::>(); + let query = query.chars().collect::>(); + let query_chars = CharBag::from(&lowercase_query[..]); + + let path_arcs: Vec> = paths + .iter() + .map(|path| Arc::from(PathBuf::from(path))) + .collect::>(); + let mut path_entries = Vec::new(); + for (i, path) in paths.iter().enumerate() { + let lowercase_path = path.to_lowercase().chars().collect::>(); + let char_bag = CharBag::from(lowercase_path.as_slice()); + path_entries.push(PathMatchCandidate { + char_bag, + path: &path_arcs[i], + }); + } + + let mut matcher = Matcher::new(&query, &lowercase_query, query_chars, smart_case, 100); + + let cancel_flag = AtomicBool::new(false); + let mut results = Vec::new(); + + matcher.match_candidates( + &[], + &[], + path_entries.into_iter(), + &mut results, + &cancel_flag, + |candidate, score| PathMatch { + score, + worktree_id: 0, + positions: Vec::new(), + path: Arc::from(candidate.path), + path_prefix: 
"".into(), + distance_to_relative_ancestor: usize::MAX, + }, + ); + + results + .into_iter() + .map(|result| { + ( + paths + .iter() + .copied() + .find(|p| result.path.as_ref() == Path::new(p)) + .unwrap(), + result.positions, + ) + }) + .collect() + } +} diff --git a/crates/fuzzy2/src/paths.rs b/crates/fuzzy2/src/paths.rs new file mode 100644 index 0000000000000000000000000000000000000000..f6c5fba6c9f641d4b1b132bf2ae932fa5d69bb02 --- /dev/null +++ b/crates/fuzzy2/src/paths.rs @@ -0,0 +1,257 @@ +use gpui2::Executor; +use std::{ + borrow::Cow, + cmp::{self, Ordering}, + path::Path, + sync::{atomic::AtomicBool, Arc}, +}; + +use crate::{ + matcher::{Match, MatchCandidate, Matcher}, + CharBag, +}; + +#[derive(Clone, Debug)] +pub struct PathMatchCandidate<'a> { + pub path: &'a Path, + pub char_bag: CharBag, +} + +#[derive(Clone, Debug)] +pub struct PathMatch { + pub score: f64, + pub positions: Vec, + pub worktree_id: usize, + pub path: Arc, + pub path_prefix: Arc, + /// Number of steps removed from a shared parent with the relative path + /// Used to order closer paths first in the search list + pub distance_to_relative_ancestor: usize, +} + +pub trait PathMatchCandidateSet<'a>: Send + Sync { + type Candidates: Iterator>; + fn id(&self) -> usize; + fn len(&self) -> usize; + fn is_empty(&self) -> bool { + self.len() == 0 + } + fn prefix(&self) -> Arc; + fn candidates(&'a self, start: usize) -> Self::Candidates; +} + +impl Match for PathMatch { + fn score(&self) -> f64 { + self.score + } + + fn set_positions(&mut self, positions: Vec) { + self.positions = positions; + } +} + +impl<'a> MatchCandidate for PathMatchCandidate<'a> { + fn has_chars(&self, bag: CharBag) -> bool { + self.char_bag.is_superset(bag) + } + + fn to_string(&self) -> Cow<'a, str> { + self.path.to_string_lossy() + } +} + +impl PartialEq for PathMatch { + fn eq(&self, other: &Self) -> bool { + self.cmp(other).is_eq() + } +} + +impl Eq for PathMatch {} + +impl PartialOrd for PathMatch { + fn 
partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PathMatch { + fn cmp(&self, other: &Self) -> Ordering { + self.score + .partial_cmp(&other.score) + .unwrap_or(Ordering::Equal) + .then_with(|| self.worktree_id.cmp(&other.worktree_id)) + .then_with(|| { + other + .distance_to_relative_ancestor + .cmp(&self.distance_to_relative_ancestor) + }) + .then_with(|| self.path.cmp(&other.path)) + } +} + +pub fn match_fixed_path_set( + candidates: Vec, + worktree_id: usize, + query: &str, + smart_case: bool, + max_results: usize, +) -> Vec { + let lowercase_query = query.to_lowercase().chars().collect::>(); + let query = query.chars().collect::>(); + let query_char_bag = CharBag::from(&lowercase_query[..]); + + let mut matcher = Matcher::new( + &query, + &lowercase_query, + query_char_bag, + smart_case, + max_results, + ); + + let mut results = Vec::new(); + matcher.match_candidates( + &[], + &[], + candidates.into_iter(), + &mut results, + &AtomicBool::new(false), + |candidate, score| PathMatch { + score, + worktree_id, + positions: Vec::new(), + path: Arc::from(candidate.path), + path_prefix: Arc::from(""), + distance_to_relative_ancestor: usize::MAX, + }, + ); + results +} + +pub async fn match_path_sets<'a, Set: PathMatchCandidateSet<'a>>( + candidate_sets: &'a [Set], + query: &str, + relative_to: Option>, + smart_case: bool, + max_results: usize, + cancel_flag: &AtomicBool, + executor: Executor, +) -> Vec { + let path_count: usize = candidate_sets.iter().map(|s| s.len()).sum(); + if path_count == 0 { + return Vec::new(); + } + + let lowercase_query = query.to_lowercase().chars().collect::>(); + let query = query.chars().collect::>(); + + let lowercase_query = &lowercase_query; + let query = &query; + let query_char_bag = CharBag::from(&lowercase_query[..]); + + let num_cpus = executor.num_cpus().min(path_count); + let segment_size = (path_count + num_cpus - 1) / num_cpus; + let mut segment_results = (0..num_cpus) + .map(|_| 
Vec::with_capacity(max_results)) + .collect::>(); + + executor + .scoped(|scope| { + for (segment_idx, results) in segment_results.iter_mut().enumerate() { + let relative_to = relative_to.clone(); + scope.spawn(async move { + let segment_start = segment_idx * segment_size; + let segment_end = segment_start + segment_size; + let mut matcher = Matcher::new( + query, + lowercase_query, + query_char_bag, + smart_case, + max_results, + ); + + let mut tree_start = 0; + for candidate_set in candidate_sets { + let tree_end = tree_start + candidate_set.len(); + + if tree_start < segment_end && segment_start < tree_end { + let start = cmp::max(tree_start, segment_start) - tree_start; + let end = cmp::min(tree_end, segment_end) - tree_start; + let candidates = candidate_set.candidates(start).take(end - start); + + let worktree_id = candidate_set.id(); + let prefix = candidate_set.prefix().chars().collect::>(); + let lowercase_prefix = prefix + .iter() + .map(|c| c.to_ascii_lowercase()) + .collect::>(); + matcher.match_candidates( + &prefix, + &lowercase_prefix, + candidates, + results, + cancel_flag, + |candidate, score| PathMatch { + score, + worktree_id, + positions: Vec::new(), + path: Arc::from(candidate.path), + path_prefix: candidate_set.prefix(), + distance_to_relative_ancestor: relative_to.as_ref().map_or( + usize::MAX, + |relative_to| { + distance_between_paths( + candidate.path.as_ref(), + relative_to.as_ref(), + ) + }, + ), + }, + ); + } + if tree_end >= segment_end { + break; + } + tree_start = tree_end; + } + }) + } + }) + .await; + + let mut results = Vec::new(); + for segment_result in segment_results { + if results.is_empty() { + results = segment_result; + } else { + util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(a)); + } + } + results +} + +/// Compute the distance from a given path to some other path +/// If there is no shared path, returns usize::MAX +fn distance_between_paths(path: &Path, relative_to: &Path) -> usize { + let 
mut path_components = path.components(); + let mut relative_components = relative_to.components(); + + while path_components + .next() + .zip(relative_components.next()) + .map(|(path_component, relative_component)| path_component == relative_component) + .unwrap_or_default() + {} + path_components.count() + relative_components.count() + 1 +} + +#[cfg(test)] +mod tests { + use std::path::Path; + + use super::distance_between_paths; + + #[test] + fn test_distance_between_paths_empty() { + distance_between_paths(Path::new(""), Path::new("")); + } +} diff --git a/crates/fuzzy2/src/strings.rs b/crates/fuzzy2/src/strings.rs new file mode 100644 index 0000000000000000000000000000000000000000..6f7533ddd0a7b4d0a9e1c4cb190274d233ebdb50 --- /dev/null +++ b/crates/fuzzy2/src/strings.rs @@ -0,0 +1,159 @@ +use crate::{ + matcher::{Match, MatchCandidate, Matcher}, + CharBag, +}; +use gpui2::Executor; +use std::{ + borrow::Cow, + cmp::{self, Ordering}, + sync::atomic::AtomicBool, +}; + +#[derive(Clone, Debug)] +pub struct StringMatchCandidate { + pub id: usize, + pub string: String, + pub char_bag: CharBag, +} + +impl Match for StringMatch { + fn score(&self) -> f64 { + self.score + } + + fn set_positions(&mut self, positions: Vec) { + self.positions = positions; + } +} + +impl StringMatchCandidate { + pub fn new(id: usize, string: String) -> Self { + Self { + id, + char_bag: CharBag::from(string.as_str()), + string, + } + } +} + +impl<'a> MatchCandidate for &'a StringMatchCandidate { + fn has_chars(&self, bag: CharBag) -> bool { + self.char_bag.is_superset(bag) + } + + fn to_string(&self) -> Cow<'a, str> { + self.string.as_str().into() + } +} + +#[derive(Clone, Debug)] +pub struct StringMatch { + pub candidate_id: usize, + pub score: f64, + pub positions: Vec, + pub string: String, +} + +impl PartialEq for StringMatch { + fn eq(&self, other: &Self) -> bool { + self.cmp(other).is_eq() + } +} + +impl Eq for StringMatch {} + +impl PartialOrd for StringMatch { + fn 
partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for StringMatch { + fn cmp(&self, other: &Self) -> Ordering { + self.score + .partial_cmp(&other.score) + .unwrap_or(Ordering::Equal) + .then_with(|| self.candidate_id.cmp(&other.candidate_id)) + } +} + +pub async fn match_strings( + candidates: &[StringMatchCandidate], + query: &str, + smart_case: bool, + max_results: usize, + cancel_flag: &AtomicBool, + executor: Executor, +) -> Vec { + if candidates.is_empty() || max_results == 0 { + return Default::default(); + } + + if query.is_empty() { + return candidates + .iter() + .map(|candidate| StringMatch { + candidate_id: candidate.id, + score: 0., + positions: Default::default(), + string: candidate.string.clone(), + }) + .collect(); + } + + let lowercase_query = query.to_lowercase().chars().collect::>(); + let query = query.chars().collect::>(); + + let lowercase_query = &lowercase_query; + let query = &query; + let query_char_bag = CharBag::from(&lowercase_query[..]); + + let num_cpus = executor.num_cpus().min(candidates.len()); + let segment_size = (candidates.len() + num_cpus - 1) / num_cpus; + let mut segment_results = (0..num_cpus) + .map(|_| Vec::with_capacity(max_results.min(candidates.len()))) + .collect::>(); + + executor + .scoped(|scope| { + for (segment_idx, results) in segment_results.iter_mut().enumerate() { + let cancel_flag = &cancel_flag; + scope.spawn(async move { + let segment_start = cmp::min(segment_idx * segment_size, candidates.len()); + let segment_end = cmp::min(segment_start + segment_size, candidates.len()); + let mut matcher = Matcher::new( + query, + lowercase_query, + query_char_bag, + smart_case, + max_results, + ); + + matcher.match_candidates( + &[], + &[], + candidates[segment_start..segment_end].iter(), + results, + cancel_flag, + |candidate, score| StringMatch { + candidate_id: candidate.id, + score, + positions: Vec::new(), + string: candidate.string.to_string(), + }, + ); + }); + } + }) + 
.await; + + let mut results = Vec::new(); + for segment_result in segment_results { + if results.is_empty() { + results = segment_result; + } else { + util::extend_sorted(&mut results, segment_result, max_results, |a, b| b.cmp(a)); + } + } + results +} diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index edcc7ad6f60fd86fe47beadfebc8fab7ebd32120..4b6b9bea738c393f32e66d5fb84f670dce3fa001 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -3607,7 +3607,7 @@ impl BorrowWindowContext for EventContext<'_, '_, '_, V> { } } -pub(crate) enum Reference<'a, T> { +pub enum Reference<'a, T> { Immutable(&'a T), Mutable(&'a mut T), } diff --git a/crates/gpui/src/fonts.rs b/crates/gpui/src/fonts.rs index 4b46f8eb792c0efa6a0c1d1176d0f85607d031a2..f360ef933f8f1213a50393d608c29592a57963d4 100644 --- a/crates/gpui/src/fonts.rs +++ b/crates/gpui/src/fonts.rs @@ -154,6 +154,11 @@ impl Refineable for TextStyleRefinement { self.underline = refinement.underline; } } + + fn refined(mut self, refinement: Self::Refinement) -> Self { + self.refine(&refinement); + self + } } #[derive(JsonSchema)] diff --git a/crates/gpui/src/image_cache.rs b/crates/gpui/src/image_cache.rs index d7682a43fd201935ae8683e1e73c166091f8efa7..00d62a16a5e17c4301cc8bf9fe042d6ddefe05d4 100644 --- a/crates/gpui/src/image_cache.rs +++ b/crates/gpui/src/image_cache.rs @@ -84,7 +84,6 @@ impl ImageCache { let format = image::guess_format(&body)?; let image = image::load_from_memory_with_format(&body, format)?.into_bgra8(); - Ok(ImageData::new(image)) } } diff --git a/crates/gpui/src/platform/mac/atlas.rs b/crates/gpui/src/platform/mac/atlas.rs index a529513ef5ef38faab25cf983e4820066ce9d28b..57a137479d7e8eaca79effad754862ddb996e366 100644 --- a/crates/gpui/src/platform/mac/atlas.rs +++ b/crates/gpui/src/platform/mac/atlas.rs @@ -109,6 +109,7 @@ impl AtlasAllocator { }; descriptor.set_width(size.x() as u64); descriptor.set_height(size.y() as u64); + self.device.new_texture(&descriptor) } else { 
self.device.new_texture(&self.texture_descriptor) diff --git a/crates/gpui/src/platform/mac/renderer.rs b/crates/gpui/src/platform/mac/renderer.rs index 55ec3e9e9a2f3f72c03103adf045721105d3da83..85f0af1ffddb902f00bcdb15012e474547e68fa7 100644 --- a/crates/gpui/src/platform/mac/renderer.rs +++ b/crates/gpui/src/platform/mac/renderer.rs @@ -632,6 +632,7 @@ impl Renderer { ) { // Snap sprite to pixel grid. let origin = (glyph.origin * scale_factor).floor() + sprite.offset.to_f32(); + sprites_by_atlas .entry(sprite.atlas_id) .or_insert_with(Vec::new) diff --git a/crates/gpui/src/platform/mac/window.rs b/crates/gpui/src/platform/mac/window.rs index ad8275f0ac9de36732e04999cd182e756e14b7a8..670a994d5f18999aa0c987f89a5e3083566a66e5 100644 --- a/crates/gpui/src/platform/mac/window.rs +++ b/crates/gpui/src/platform/mac/window.rs @@ -82,6 +82,7 @@ const NSWindowAnimationBehaviorUtilityWindow: NSInteger = 4; #[ctor] unsafe fn build_classes() { + ::util::gpui1_loaded(); WINDOW_CLASS = build_window_class("GPUIWindow", class!(NSWindow)); PANEL_CLASS = build_window_class("GPUIPanel", class!(NSPanel)); VIEW_CLASS = { diff --git a/crates/gpui/src/text_layout.rs b/crates/gpui/src/text_layout.rs index 7fb87b10df2ce3baf822fbe5a6fddb4955e5f134..3ffdfc52a9428d1ae0bf3a01c3f45949d4204814 100644 --- a/crates/gpui/src/text_layout.rs +++ b/crates/gpui/src/text_layout.rs @@ -22,8 +22,8 @@ use std::{ }; pub struct TextLayoutCache { - prev_frame: Mutex>>, - curr_frame: RwLock>>, + prev_frame: Mutex>>, + curr_frame: RwLock>>, fonts: Arc, } @@ -56,7 +56,7 @@ impl TextLayoutCache { font_size: f32, runs: &'a [(usize, RunStyle)], ) -> Line { - let key = &CacheKeyRef { + let key = &BorrowedCacheKey { text, font_size: OrderedFloat(font_size), runs, @@ -72,7 +72,7 @@ impl TextLayoutCache { Line::new(layout, runs) } else { let layout = Arc::new(self.fonts.layout_line(text, font_size, runs)); - let key = CacheKeyValue { + let key = OwnedCacheKey { text: text.into(), font_size: OrderedFloat(font_size), 
runs: SmallVec::from(runs), @@ -84,7 +84,7 @@ impl TextLayoutCache { } trait CacheKey { - fn key(&self) -> CacheKeyRef; + fn key(&self) -> BorrowedCacheKey; } impl<'a> PartialEq for (dyn CacheKey + 'a) { @@ -102,15 +102,15 @@ impl<'a> Hash for (dyn CacheKey + 'a) { } #[derive(Eq)] -struct CacheKeyValue { +struct OwnedCacheKey { text: String, font_size: OrderedFloat, runs: SmallVec<[(usize, RunStyle); 1]>, } -impl CacheKey for CacheKeyValue { - fn key(&self) -> CacheKeyRef { - CacheKeyRef { +impl CacheKey for OwnedCacheKey { + fn key(&self) -> BorrowedCacheKey { + BorrowedCacheKey { text: self.text.as_str(), font_size: self.font_size, runs: self.runs.as_slice(), @@ -118,38 +118,38 @@ impl CacheKey for CacheKeyValue { } } -impl PartialEq for CacheKeyValue { +impl PartialEq for OwnedCacheKey { fn eq(&self, other: &Self) -> bool { self.key().eq(&other.key()) } } -impl Hash for CacheKeyValue { +impl Hash for OwnedCacheKey { fn hash(&self, state: &mut H) { self.key().hash(state); } } -impl<'a> Borrow for CacheKeyValue { +impl<'a> Borrow for OwnedCacheKey { fn borrow(&self) -> &(dyn CacheKey + 'a) { self as &dyn CacheKey } } #[derive(Copy, Clone)] -struct CacheKeyRef<'a> { +struct BorrowedCacheKey<'a> { text: &'a str, font_size: OrderedFloat, runs: &'a [(usize, RunStyle)], } -impl<'a> CacheKey for CacheKeyRef<'a> { - fn key(&self) -> CacheKeyRef { +impl<'a> CacheKey for BorrowedCacheKey<'a> { + fn key(&self) -> BorrowedCacheKey { *self } } -impl<'a> PartialEq for CacheKeyRef<'a> { +impl<'a> PartialEq for BorrowedCacheKey<'a> { fn eq(&self, other: &Self) -> bool { self.text == other.text && self.font_size == other.font_size @@ -162,7 +162,7 @@ impl<'a> PartialEq for CacheKeyRef<'a> { } } -impl<'a> Hash for CacheKeyRef<'a> { +impl<'a> Hash for BorrowedCacheKey<'a> { fn hash(&self, state: &mut H) { self.text.hash(state); self.font_size.hash(state); diff --git a/crates/gpui2/Cargo.toml b/crates/gpui2/Cargo.toml index 
093ab1e3470b313208645845118ef6cda6957941..fa072dadc3c064a493e135451c67cda7e822c93c 100644 --- a/crates/gpui2/Cargo.toml +++ b/crates/gpui2/Cargo.toml @@ -2,31 +2,86 @@ name = "gpui2" version = "0.1.0" edition = "2021" +authors = ["Nathan Sobo "] +description = "The next version of Zed's GPU-accelerated UI framework" publish = false +[features] +test-support = ["backtrace", "dhat", "env_logger", "collections/test-support", "util/test-support"] + [lib] -name = "gpui2" path = "src/gpui2.rs" - -[features] -test-support = ["gpui/test-support"] +doctest = false [dependencies] -anyhow.workspace = true +collections = { path = "../collections" } +gpui_macros = { path = "../gpui_macros" } +gpui2_macros = { path = "../gpui2_macros" } +util = { path = "../util" } +sum_tree = { path = "../sum_tree" } +sqlez = { path = "../sqlez" } +async-task = "4.0.3" +backtrace = { version = "0.3", optional = true } +ctor.workspace = true derive_more.workspace = true -gpui = { path = "../gpui" } -log.workspace = true +dhat = { version = "0.3", optional = true } +env_logger = { version = "0.9", optional = true } +etagere = "0.2" futures.workspace = true -gpui2_macros = { path = "../gpui2_macros" } +image = "0.23" +itertools = "0.10" +lazy_static.workspace = true +log.workspace = true +num_cpus = "1.13" +ordered-float.workspace = true +parking = "2.0.0" parking_lot.workspace = true +pathfinder_geometry = "0.5" +postage.workspace = true +rand.workspace = true refineable.workspace = true -rust-embed.workspace = true +resvg = "0.14" +seahash = "4.1" serde.workspace = true -settings = { path = "../settings" } -simplelog = "0.9" +serde_derive.workspace = true +serde_json.workspace = true smallvec.workspace = true -theme = { path = "../theme" } -util = { path = "../util" } +smol.workspace = true +taffy = { git = "https://github.com/DioxusLabs/taffy", rev = "4fb530bdd71609bb1d3f76c6a8bde1ba82805d5e" } +thiserror.workspace = true +time.workspace = true +tiny-skia = "0.5" +usvg = { version = "0.14", 
features = [] } +uuid = { version = "1.1.2", features = ["v4"] } +waker-fn = "1.1.0" +slotmap = "1.0.6" +schemars.workspace = true +plane-split = "0.18.0" +bitflags = "2.4.0" [dev-dependencies] -gpui = { path = "../gpui", features = ["test-support"] } +backtrace = "0.3" +collections = { path = "../collections", features = ["test-support"] } +dhat = "0.3" +env_logger.workspace = true +png = "0.16" +simplelog = "0.9" +util = { path = "../util", features = ["test-support"] } + +[build-dependencies] +bindgen = "0.65.1" +cbindgen = "0.26.0" + +[target.'cfg(target_os = "macos")'.dependencies] +media = { path = "../media" } +anyhow.workspace = true +block = "0.1" +cocoa = "0.24" +core-foundation = { version = "0.9.3", features = ["with-uuid"] } +core-graphics = "0.22.3" +core-text = "19.2" +font-kit = { git = "https://github.com/zed-industries/font-kit", rev = "b2f77d56f450338aa4f7dd2f0197d8c9acb0cf18" } +foreign-types = "0.3" +log.workspace = true +metal = "0.21.0" +objc = "0.2" diff --git a/crates/gpui2/build.rs b/crates/gpui2/build.rs new file mode 100644 index 0000000000000000000000000000000000000000..c9abfaa6bb5e79ad07d77b1855c9414c2b0b5b05 --- /dev/null +++ b/crates/gpui2/build.rs @@ -0,0 +1,134 @@ +use std::{ + env, + path::{Path, PathBuf}, + process::{self, Command}, +}; + +use cbindgen::Config; + +fn main() { + generate_dispatch_bindings(); + let header_path = generate_shader_bindings(); + compile_metal_shaders(&header_path); +} + +fn generate_dispatch_bindings() { + println!("cargo:rustc-link-lib=framework=System"); + println!("cargo:rerun-if-changed=src/platform/mac/dispatch.h"); + + let bindings = bindgen::Builder::default() + .header("src/platform/mac/dispatch.h") + .allowlist_var("_dispatch_main_q") + .allowlist_var("DISPATCH_QUEUE_PRIORITY_DEFAULT") + .allowlist_function("dispatch_get_global_queue") + .allowlist_function("dispatch_async_f") + .allowlist_function("dispatch_after_f") + .allowlist_function("dispatch_time") + 
.parse_callbacks(Box::new(bindgen::CargoCallbacks)) + .layout_tests(false) + .generate() + .expect("unable to generate bindings"); + + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + bindings + .write_to_file(out_path.join("dispatch_sys.rs")) + .expect("couldn't write dispatch bindings"); +} + +fn generate_shader_bindings() -> PathBuf { + let output_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("scene.h"); + let crate_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()); + let mut config = Config::default(); + config.include_guard = Some("SCENE_H".into()); + config.language = cbindgen::Language::C; + config.export.include.extend([ + "Bounds".into(), + "Corners".into(), + "Edges".into(), + "Size".into(), + "Pixels".into(), + "PointF".into(), + "Hsla".into(), + "ContentMask".into(), + "Uniforms".into(), + "AtlasTile".into(), + "PathRasterizationInputIndex".into(), + "PathVertex_ScaledPixels".into(), + "ShadowInputIndex".into(), + "Shadow".into(), + "QuadInputIndex".into(), + "Underline".into(), + "UnderlineInputIndex".into(), + "Quad".into(), + "SpriteInputIndex".into(), + "MonochromeSprite".into(), + "PolychromeSprite".into(), + "PathSprite".into(), + ]); + config.no_includes = true; + config.enumeration.prefix_with_name = true; + cbindgen::Builder::new() + .with_src(crate_dir.join("src/scene.rs")) + .with_src(crate_dir.join("src/geometry.rs")) + .with_src(crate_dir.join("src/color.rs")) + .with_src(crate_dir.join("src/window.rs")) + .with_src(crate_dir.join("src/platform.rs")) + .with_src(crate_dir.join("src/platform/mac/metal_renderer.rs")) + .with_config(config) + .generate() + .expect("Unable to generate bindings") + .write_to_file(&output_path); + + output_path +} + +fn compile_metal_shaders(header_path: &Path) { + let shader_path = "./src/platform/mac/shaders.metal"; + let air_output_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("shaders.air"); + let metallib_output_path = 
PathBuf::from(env::var("OUT_DIR").unwrap()).join("shaders.metallib"); + + println!("cargo:rerun-if-changed={}", header_path.display()); + println!("cargo:rerun-if-changed={}", shader_path); + + let output = Command::new("xcrun") + .args([ + "-sdk", + "macosx", + "metal", + "-gline-tables-only", + "-mmacosx-version-min=10.15.7", + "-MO", + "-c", + shader_path, + "-include", + &header_path.to_str().unwrap(), + "-o", + ]) + .arg(&air_output_path) + .output() + .unwrap(); + + if !output.status.success() { + eprintln!( + "metal shader compilation failed:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + process::exit(1); + } + + let output = Command::new("xcrun") + .args(["-sdk", "macosx", "metallib"]) + .arg(air_output_path) + .arg("-o") + .arg(metallib_output_path) + .output() + .unwrap(); + + if !output.status.success() { + eprintln!( + "metallib compilation failed:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + process::exit(1); + } +} diff --git a/crates/gpui2/src/action.rs b/crates/gpui2/src/action.rs new file mode 100644 index 0000000000000000000000000000000000000000..638e5c6ca3c918ccec786778e4d3eb94294934cb --- /dev/null +++ b/crates/gpui2/src/action.rs @@ -0,0 +1,432 @@ +use crate::SharedString; +use anyhow::{anyhow, Context, Result}; +use collections::{HashMap, HashSet}; +use serde::Deserialize; +use std::any::{type_name, Any}; + +pub trait Action: Any + Send { + fn qualified_name() -> SharedString + where + Self: Sized; + fn build(value: Option) -> Result> + where + Self: Sized; + + fn partial_eq(&self, action: &dyn Action) -> bool; + fn boxed_clone(&self) -> Box; + fn as_any(&self) -> &dyn Any; +} + +impl Action for A +where + A: for<'a> Deserialize<'a> + PartialEq + Any + Send + Clone + Default, +{ + fn qualified_name() -> SharedString { + type_name::().into() + } + + fn build(params: Option) -> Result> + where + Self: Sized, + { + let action = if let Some(params) = params { + serde_json::from_value(params).context("failed to deserialize 
action")? + } else { + Self::default() + }; + Ok(Box::new(action)) + } + + fn partial_eq(&self, action: &dyn Action) -> bool { + action + .as_any() + .downcast_ref::() + .map_or(false, |a| self == a) + } + + fn boxed_clone(&self) -> Box { + Box::new(self.clone()) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct DispatchContext { + set: HashSet, + map: HashMap, +} + +impl<'a> TryFrom<&'a str> for DispatchContext { + type Error = anyhow::Error; + + fn try_from(value: &'a str) -> Result { + Self::parse(value) + } +} + +impl DispatchContext { + pub fn parse(source: &str) -> Result { + let mut context = Self::default(); + let source = skip_whitespace(source); + Self::parse_expr(&source, &mut context)?; + Ok(context) + } + + fn parse_expr(mut source: &str, context: &mut Self) -> Result<()> { + if source.is_empty() { + return Ok(()); + } + + let key = source + .chars() + .take_while(|c| is_identifier_char(*c)) + .collect::(); + source = skip_whitespace(&source[key.len()..]); + if let Some(suffix) = source.strip_prefix('=') { + source = skip_whitespace(suffix); + let value = source + .chars() + .take_while(|c| is_identifier_char(*c)) + .collect::(); + source = skip_whitespace(&source[value.len()..]); + context.set(key, value); + } else { + context.insert(key); + } + + Self::parse_expr(source, context) + } + + pub fn is_empty(&self) -> bool { + self.set.is_empty() && self.map.is_empty() + } + + pub fn clear(&mut self) { + self.set.clear(); + self.map.clear(); + } + + pub fn extend(&mut self, other: &Self) { + for v in &other.set { + self.set.insert(v.clone()); + } + for (k, v) in &other.map { + self.map.insert(k.clone(), v.clone()); + } + } + + pub fn insert>(&mut self, identifier: I) { + self.set.insert(identifier.into()); + } + + pub fn set, S2: Into>(&mut self, key: S1, value: S2) { + self.map.insert(key.into(), value.into()); + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub enum 
DispatchContextPredicate { + Identifier(SharedString), + Equal(SharedString, SharedString), + NotEqual(SharedString, SharedString), + Child(Box, Box), + Not(Box), + And(Box, Box), + Or(Box, Box), +} + +impl DispatchContextPredicate { + pub fn parse(source: &str) -> Result { + let source = skip_whitespace(source); + let (predicate, rest) = Self::parse_expr(source, 0)?; + if let Some(next) = rest.chars().next() { + Err(anyhow!("unexpected character {next:?}")) + } else { + Ok(predicate) + } + } + + pub fn eval(&self, contexts: &[&DispatchContext]) -> bool { + let Some(context) = contexts.last() else { + return false; + }; + match self { + Self::Identifier(name) => context.set.contains(name), + Self::Equal(left, right) => context + .map + .get(left) + .map(|value| value == right) + .unwrap_or(false), + Self::NotEqual(left, right) => context + .map + .get(left) + .map(|value| value != right) + .unwrap_or(true), + Self::Not(pred) => !pred.eval(contexts), + Self::Child(parent, child) => { + parent.eval(&contexts[..contexts.len() - 1]) && child.eval(contexts) + } + Self::And(left, right) => left.eval(contexts) && right.eval(contexts), + Self::Or(left, right) => left.eval(contexts) || right.eval(contexts), + } + } + + fn parse_expr(mut source: &str, min_precedence: u32) -> anyhow::Result<(Self, &str)> { + type Op = fn( + DispatchContextPredicate, + DispatchContextPredicate, + ) -> Result; + + let (mut predicate, rest) = Self::parse_primary(source)?; + source = rest; + + 'parse: loop { + for (operator, precedence, constructor) in [ + (">", PRECEDENCE_CHILD, Self::new_child as Op), + ("&&", PRECEDENCE_AND, Self::new_and as Op), + ("||", PRECEDENCE_OR, Self::new_or as Op), + ("==", PRECEDENCE_EQ, Self::new_eq as Op), + ("!=", PRECEDENCE_EQ, Self::new_neq as Op), + ] { + if source.starts_with(operator) && precedence >= min_precedence { + source = skip_whitespace(&source[operator.len()..]); + let (right, rest) = Self::parse_expr(source, precedence + 1)?; + predicate = 
constructor(predicate, right)?; + source = rest; + continue 'parse; + } + } + break; + } + + Ok((predicate, source)) + } + + fn parse_primary(mut source: &str) -> anyhow::Result<(Self, &str)> { + let next = source + .chars() + .next() + .ok_or_else(|| anyhow!("unexpected eof"))?; + match next { + '(' => { + source = skip_whitespace(&source[1..]); + let (predicate, rest) = Self::parse_expr(source, 0)?; + if rest.starts_with(')') { + source = skip_whitespace(&rest[1..]); + Ok((predicate, source)) + } else { + Err(anyhow!("expected a ')'")) + } + } + '!' => { + let source = skip_whitespace(&source[1..]); + let (predicate, source) = Self::parse_expr(&source, PRECEDENCE_NOT)?; + Ok((DispatchContextPredicate::Not(Box::new(predicate)), source)) + } + _ if is_identifier_char(next) => { + let len = source + .find(|c: char| !is_identifier_char(c)) + .unwrap_or(source.len()); + let (identifier, rest) = source.split_at(len); + source = skip_whitespace(rest); + Ok(( + DispatchContextPredicate::Identifier(identifier.to_string().into()), + source, + )) + } + _ => Err(anyhow!("unexpected character {next:?}")), + } + } + + fn new_or(self, other: Self) -> Result { + Ok(Self::Or(Box::new(self), Box::new(other))) + } + + fn new_and(self, other: Self) -> Result { + Ok(Self::And(Box::new(self), Box::new(other))) + } + + fn new_child(self, other: Self) -> Result { + Ok(Self::Child(Box::new(self), Box::new(other))) + } + + fn new_eq(self, other: Self) -> Result { + if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) { + Ok(Self::Equal(left, right)) + } else { + Err(anyhow!("operands must be identifiers")) + } + } + + fn new_neq(self, other: Self) -> Result { + if let (Self::Identifier(left), Self::Identifier(right)) = (self, other) { + Ok(Self::NotEqual(left, right)) + } else { + Err(anyhow!("operands must be identifiers")) + } + } +} + +const PRECEDENCE_CHILD: u32 = 1; +const PRECEDENCE_OR: u32 = 2; +const PRECEDENCE_AND: u32 = 3; +const PRECEDENCE_EQ: u32 = 4; 
+const PRECEDENCE_NOT: u32 = 5; + +fn is_identifier_char(c: char) -> bool { + c.is_alphanumeric() || c == '_' || c == '-' +} + +fn skip_whitespace(source: &str) -> &str { + let len = source + .find(|c: char| !c.is_whitespace()) + .unwrap_or(source.len()); + &source[len..] +} + +#[cfg(test)] +mod tests { + use super::*; + use DispatchContextPredicate::*; + + #[test] + fn test_parse_context() { + let mut expected = DispatchContext::default(); + expected.set("foo", "bar"); + expected.insert("baz"); + assert_eq!(DispatchContext::parse("baz foo=bar").unwrap(), expected); + assert_eq!(DispatchContext::parse("foo = bar baz").unwrap(), expected); + assert_eq!( + DispatchContext::parse(" baz foo = bar baz").unwrap(), + expected + ); + assert_eq!(DispatchContext::parse(" foo = bar baz").unwrap(), expected); + } + + #[test] + fn test_parse_identifiers() { + // Identifiers + assert_eq!( + DispatchContextPredicate::parse("abc12").unwrap(), + Identifier("abc12".into()) + ); + assert_eq!( + DispatchContextPredicate::parse("_1a").unwrap(), + Identifier("_1a".into()) + ); + } + + #[test] + fn test_parse_negations() { + assert_eq!( + DispatchContextPredicate::parse("!abc").unwrap(), + Not(Box::new(Identifier("abc".into()))) + ); + assert_eq!( + DispatchContextPredicate::parse(" ! ! 
abc").unwrap(), + Not(Box::new(Not(Box::new(Identifier("abc".into()))))) + ); + } + + #[test] + fn test_parse_equality_operators() { + assert_eq!( + DispatchContextPredicate::parse("a == b").unwrap(), + Equal("a".into(), "b".into()) + ); + assert_eq!( + DispatchContextPredicate::parse("c!=d").unwrap(), + NotEqual("c".into(), "d".into()) + ); + assert_eq!( + DispatchContextPredicate::parse("c == !d") + .unwrap_err() + .to_string(), + "operands must be identifiers" + ); + } + + #[test] + fn test_parse_boolean_operators() { + assert_eq!( + DispatchContextPredicate::parse("a || b").unwrap(), + Or( + Box::new(Identifier("a".into())), + Box::new(Identifier("b".into())) + ) + ); + assert_eq!( + DispatchContextPredicate::parse("a || !b && c").unwrap(), + Or( + Box::new(Identifier("a".into())), + Box::new(And( + Box::new(Not(Box::new(Identifier("b".into())))), + Box::new(Identifier("c".into())) + )) + ) + ); + assert_eq!( + DispatchContextPredicate::parse("a && b || c&&d").unwrap(), + Or( + Box::new(And( + Box::new(Identifier("a".into())), + Box::new(Identifier("b".into())) + )), + Box::new(And( + Box::new(Identifier("c".into())), + Box::new(Identifier("d".into())) + )) + ) + ); + assert_eq!( + DispatchContextPredicate::parse("a == b && c || d == e && f").unwrap(), + Or( + Box::new(And( + Box::new(Equal("a".into(), "b".into())), + Box::new(Identifier("c".into())) + )), + Box::new(And( + Box::new(Equal("d".into(), "e".into())), + Box::new(Identifier("f".into())) + )) + ) + ); + assert_eq!( + DispatchContextPredicate::parse("a && b && c && d").unwrap(), + And( + Box::new(And( + Box::new(And( + Box::new(Identifier("a".into())), + Box::new(Identifier("b".into())) + )), + Box::new(Identifier("c".into())), + )), + Box::new(Identifier("d".into())) + ), + ); + } + + #[test] + fn test_parse_parenthesized_expressions() { + assert_eq!( + DispatchContextPredicate::parse("a && (b == c || d != e)").unwrap(), + And( + Box::new(Identifier("a".into())), + Box::new(Or( + 
Box::new(Equal("b".into(), "c".into())), + Box::new(NotEqual("d".into(), "e".into())), + )), + ), + ); + assert_eq!( + DispatchContextPredicate::parse(" ( a || b ) ").unwrap(), + Or( + Box::new(Identifier("a".into())), + Box::new(Identifier("b".into())), + ) + ); + } +} diff --git a/crates/gpui2/src/adapter.rs b/crates/gpui2/src/adapter.rs deleted file mode 100644 index c36966d72262495f868900e463502ced21ddd2d1..0000000000000000000000000000000000000000 --- a/crates/gpui2/src/adapter.rs +++ /dev/null @@ -1,76 +0,0 @@ -use crate::ViewContext; -use gpui::{geometry::rect::RectF, LayoutEngine, LayoutId}; -use util::ResultExt; - -/// Makes a new, gpui2-style element into a legacy element. -pub struct AdapterElement(pub(crate) crate::element::AnyElement); - -impl gpui::Element for AdapterElement { - type LayoutState = Option<(LayoutEngine, LayoutId)>; - type PaintState = (); - - fn layout( - &mut self, - constraint: gpui::SizeConstraint, - view: &mut V, - cx: &mut gpui::ViewContext, - ) -> (gpui::geometry::vector::Vector2F, Self::LayoutState) { - cx.push_layout_engine(LayoutEngine::new()); - - let mut cx = ViewContext::new(cx); - let layout_id = self.0.layout(view, &mut cx).log_err(); - if let Some(layout_id) = layout_id { - cx.layout_engine() - .unwrap() - .compute_layout(layout_id, constraint.max) - .log_err(); - } - - let layout_engine = cx.pop_layout_engine(); - debug_assert!(layout_engine.is_some(), - "unexpected layout stack state. is there an unmatched pop_layout_engine in the called code?" 
- ); - - (constraint.max, layout_engine.zip(layout_id)) - } - - fn paint( - &mut self, - bounds: RectF, - _visible_bounds: RectF, - layout_data: &mut Option<(LayoutEngine, LayoutId)>, - view: &mut V, - cx: &mut gpui::ViewContext, - ) -> Self::PaintState { - let (layout_engine, layout_id) = layout_data.take().unwrap(); - cx.push_layout_engine(layout_engine); - self.0 - .paint(view, bounds.origin(), &mut ViewContext::new(cx)); - *layout_data = cx.pop_layout_engine().zip(Some(layout_id)); - debug_assert!(layout_data.is_some()); - } - - fn rect_for_text_range( - &self, - _range_utf16: std::ops::Range, - _bounds: RectF, - _visible_bounds: RectF, - _layout: &Self::LayoutState, - _paint: &Self::PaintState, - _view: &V, - _cx: &gpui::ViewContext, - ) -> Option { - todo!("implement before merging to main") - } - - fn debug( - &self, - _bounds: RectF, - _layout: &Self::LayoutState, - _paint: &Self::PaintState, - _view: &V, - _cx: &gpui::ViewContext, - ) -> gpui::serde_json::Value { - todo!("implement before merging to main") - } -} diff --git a/crates/gpui2/src/app.rs b/crates/gpui2/src/app.rs new file mode 100644 index 0000000000000000000000000000000000000000..a49b2aaa1c5d1c380fadc1fab13e5b116e5d719c --- /dev/null +++ b/crates/gpui2/src/app.rs @@ -0,0 +1,919 @@ +mod async_context; +mod entity_map; +mod model_context; +#[cfg(any(test, feature = "test-support"))] +mod test_context; + +pub use async_context::*; +pub use entity_map::*; +pub use model_context::*; +use refineable::Refineable; +use smallvec::SmallVec; +#[cfg(any(test, feature = "test-support"))] +pub use test_context::*; + +use crate::{ + current_platform, image_cache::ImageCache, Action, AnyBox, AnyView, AppMetadata, AssetSource, + ClipboardItem, Context, DispatchPhase, DisplayId, Executor, FocusEvent, FocusHandle, FocusId, + KeyBinding, Keymap, LayoutId, MainThread, MainThreadOnly, Pixels, Platform, Point, Render, + SharedString, SubscriberSet, Subscription, SvgRenderer, Task, TextStyle, TextStyleRefinement, + 
TextSystem, View, Window, WindowContext, WindowHandle, WindowId, +}; +use anyhow::{anyhow, Result}; +use collections::{HashMap, HashSet, VecDeque}; +use futures::{future::BoxFuture, Future}; +use parking_lot::Mutex; +use slotmap::SlotMap; +use std::{ + any::{type_name, Any, TypeId}, + borrow::Borrow, + marker::PhantomData, + mem, + ops::{Deref, DerefMut}, + path::PathBuf, + sync::{atomic::Ordering::SeqCst, Arc, Weak}, + time::Duration, +}; +use util::http::{self, HttpClient}; + +pub struct App(Arc>); + +/// Represents an application before it is fully launched. Once your app is +/// configured, you'll start the app with `App::run`. +impl App { + /// Builds an app with the given asset source. + pub fn production(asset_source: Arc) -> Self { + Self(AppContext::new( + current_platform(), + asset_source, + http::client(), + )) + } + + /// Start the application. The provided callback will be called once the + /// app is fully launched. + pub fn run(self, on_finish_launching: F) + where + F: 'static + FnOnce(&mut MainThread), + { + let this = self.0.clone(); + let platform = self.0.lock().platform.clone(); + platform.borrow_on_main_thread().run(Box::new(move || { + let cx = &mut *this.lock(); + let cx = unsafe { mem::transmute::<&mut AppContext, &mut MainThread>(cx) }; + on_finish_launching(cx); + })); + } + + /// Register a handler to be invoked when the platform instructs the application + /// to open one or more URLs. 
+ pub fn on_open_urls(&self, mut callback: F) -> &Self + where + F: 'static + FnMut(Vec, &mut AppContext), + { + let this = Arc::downgrade(&self.0); + self.0 + .lock() + .platform + .borrow_on_main_thread() + .on_open_urls(Box::new(move |urls| { + if let Some(app) = this.upgrade() { + callback(urls, &mut app.lock()); + } + })); + self + } + + pub fn on_reopen(&self, mut callback: F) -> &Self + where + F: 'static + FnMut(&mut AppContext), + { + let this = Arc::downgrade(&self.0); + self.0 + .lock() + .platform + .borrow_on_main_thread() + .on_reopen(Box::new(move || { + if let Some(app) = this.upgrade() { + callback(&mut app.lock()); + } + })); + self + } + + pub fn metadata(&self) -> AppMetadata { + self.0.lock().app_metadata.clone() + } + + pub fn executor(&self) -> Executor { + self.0.lock().executor.clone() + } + + pub fn text_system(&self) -> Arc { + self.0.lock().text_system.clone() + } +} + +type ActionBuilder = fn(json: Option) -> anyhow::Result>; +type FrameCallback = Box; +type Handler = Box bool + Send + 'static>; +type Listener = Box bool + Send + 'static>; +type QuitHandler = Box BoxFuture<'static, ()> + Send + 'static>; +type ReleaseListener = Box; + +pub struct AppContext { + this: Weak>, + pub(crate) platform: MainThreadOnly, + app_metadata: AppMetadata, + text_system: Arc, + flushing_effects: bool, + pending_updates: usize, + pub(crate) active_drag: Option, + pub(crate) next_frame_callbacks: HashMap>, + pub(crate) executor: Executor, + pub(crate) svg_renderer: SvgRenderer, + asset_source: Arc, + pub(crate) image_cache: ImageCache, + pub(crate) text_style_stack: Vec, + pub(crate) globals_by_type: HashMap, + pub(crate) entities: EntityMap, + pub(crate) windows: SlotMap>, + pub(crate) keymap: Arc>, + pub(crate) global_action_listeners: + HashMap>>, + action_builders: HashMap, + pending_effects: VecDeque, + pub(crate) pending_notifications: HashSet, + pub(crate) pending_global_notifications: HashSet, + pub(crate) observers: SubscriberSet, + pub(crate) 
event_listeners: SubscriberSet, + pub(crate) release_listeners: SubscriberSet, + pub(crate) global_observers: SubscriberSet, + pub(crate) quit_observers: SubscriberSet<(), QuitHandler>, + pub(crate) layout_id_buffer: Vec, // We recycle this memory across layout requests. + pub(crate) propagate_event: bool, +} + +impl AppContext { + pub(crate) fn new( + platform: Arc, + asset_source: Arc, + http_client: Arc, + ) -> Arc> { + let executor = platform.executor(); + assert!( + executor.is_main_thread(), + "must construct App on main thread" + ); + + let text_system = Arc::new(TextSystem::new(platform.text_system())); + let entities = EntityMap::new(); + + let app_metadata = AppMetadata { + os_name: platform.os_name(), + os_version: platform.os_version().ok(), + app_version: platform.app_version().ok(), + }; + + Arc::new_cyclic(|this| { + Mutex::new(AppContext { + this: this.clone(), + text_system, + platform: MainThreadOnly::new(platform, executor.clone()), + app_metadata, + flushing_effects: false, + pending_updates: 0, + next_frame_callbacks: Default::default(), + executor, + svg_renderer: SvgRenderer::new(asset_source.clone()), + asset_source, + image_cache: ImageCache::new(http_client), + text_style_stack: Vec::new(), + globals_by_type: HashMap::default(), + entities, + windows: SlotMap::with_key(), + keymap: Arc::new(Mutex::new(Keymap::default())), + global_action_listeners: HashMap::default(), + action_builders: HashMap::default(), + pending_effects: VecDeque::new(), + pending_notifications: HashSet::default(), + pending_global_notifications: HashSet::default(), + observers: SubscriberSet::new(), + event_listeners: SubscriberSet::new(), + release_listeners: SubscriberSet::new(), + global_observers: SubscriberSet::new(), + quit_observers: SubscriberSet::new(), + layout_id_buffer: Default::default(), + propagate_event: true, + active_drag: None, + }) + }) + } + + /// Quit the application gracefully. 
Handlers registered with `ModelContext::on_app_quit` + /// will be given 100ms to complete before exiting. + pub fn quit(&mut self) { + let mut futures = Vec::new(); + + self.quit_observers.clone().retain(&(), |observer| { + futures.push(observer(self)); + true + }); + + self.windows.clear(); + self.flush_effects(); + + let futures = futures::future::join_all(futures); + if self + .executor + .block_with_timeout(Duration::from_millis(100), futures) + .is_err() + { + log::error!("timed out waiting on app_will_quit"); + } + + self.globals_by_type.clear(); + } + + pub fn app_metadata(&self) -> AppMetadata { + self.app_metadata.clone() + } + + /// Schedules all windows in the application to be redrawn. This can be called + /// multiple times in an update cycle and still result in a single redraw. + pub fn refresh(&mut self) { + self.pending_effects.push_back(Effect::Refresh); + } + + pub(crate) fn update(&mut self, update: impl FnOnce(&mut Self) -> R) -> R { + self.pending_updates += 1; + let result = update(self); + if !self.flushing_effects && self.pending_updates == 1 { + self.flushing_effects = true; + self.flush_effects(); + self.flushing_effects = false; + } + self.pending_updates -= 1; + result + } + + pub(crate) fn read_window( + &mut self, + id: WindowId, + read: impl FnOnce(&WindowContext) -> R, + ) -> Result { + let window = self + .windows + .get(id) + .ok_or_else(|| anyhow!("window not found"))? + .as_ref() + .unwrap(); + Ok(read(&WindowContext::immutable(self, &window))) + } + + pub(crate) fn update_window( + &mut self, + id: WindowId, + update: impl FnOnce(&mut WindowContext) -> R, + ) -> Result { + self.update(|cx| { + let mut window = cx + .windows + .get_mut(id) + .ok_or_else(|| anyhow!("window not found"))? + .take() + .unwrap(); + + let result = update(&mut WindowContext::mutable(cx, &mut window)); + + cx.windows + .get_mut(id) + .ok_or_else(|| anyhow!("window not found"))? 
+ .replace(window); + + Ok(result) + }) + } + + pub(crate) fn push_effect(&mut self, effect: Effect) { + match &effect { + Effect::Notify { emitter } => { + if !self.pending_notifications.insert(*emitter) { + return; + } + } + Effect::NotifyGlobalObservers { global_type } => { + if !self.pending_global_notifications.insert(*global_type) { + return; + } + } + _ => {} + }; + + self.pending_effects.push_back(effect); + } + + /// Called at the end of AppContext::update to complete any side effects + /// such as notifying observers, emitting events, etc. Effects can themselves + /// cause effects, so we continue looping until all effects are processed. + fn flush_effects(&mut self) { + loop { + self.release_dropped_entities(); + self.release_dropped_focus_handles(); + if let Some(effect) = self.pending_effects.pop_front() { + match effect { + Effect::Notify { emitter } => { + self.apply_notify_effect(emitter); + } + Effect::Emit { emitter, event } => self.apply_emit_effect(emitter, event), + Effect::FocusChanged { window_id, focused } => { + self.apply_focus_changed_effect(window_id, focused); + } + Effect::Refresh => { + self.apply_refresh_effect(); + } + Effect::NotifyGlobalObservers { global_type } => { + self.apply_notify_global_observers_effect(global_type); + } + Effect::Defer { callback } => { + self.apply_defer_effect(callback); + } + } + } else { + break; + } + } + + let dirty_window_ids = self + .windows + .iter() + .filter_map(|(window_id, window)| { + let window = window.as_ref().unwrap(); + if window.dirty { + Some(window_id) + } else { + None + } + }) + .collect::>(); + + for dirty_window_id in dirty_window_ids { + self.update_window(dirty_window_id, |cx| cx.draw()).unwrap(); + } + } + + /// Repeatedly called during `flush_effects` to release any entities whose + /// reference count has become zero. We invoke any release observers before dropping + /// each entity. 
+ fn release_dropped_entities(&mut self) { + loop { + let dropped = self.entities.take_dropped(); + if dropped.is_empty() { + break; + } + + for (entity_id, mut entity) in dropped { + self.observers.remove(&entity_id); + self.event_listeners.remove(&entity_id); + for mut release_callback in self.release_listeners.remove(&entity_id) { + release_callback(&mut entity, self); + } + } + } + } + + /// Repeatedly called during `flush_effects` to handle a focused handle being dropped. + /// For now, we simply blur the window if this happens, but we may want to support invoking + /// a window blur handler to restore focus to some logical element. + fn release_dropped_focus_handles(&mut self) { + let window_ids = self.windows.keys().collect::>(); + for window_id in window_ids { + self.update_window(window_id, |cx| { + let mut blur_window = false; + let focus = cx.window.focus; + cx.window.focus_handles.write().retain(|handle_id, count| { + if count.load(SeqCst) == 0 { + if focus == Some(handle_id) { + blur_window = true; + } + false + } else { + true + } + }); + + if blur_window { + cx.blur(); + } + }) + .unwrap(); + } + } + + fn apply_notify_effect(&mut self, emitter: EntityId) { + self.pending_notifications.remove(&emitter); + self.observers + .clone() + .retain(&emitter, |handler| handler(self)); + } + + fn apply_emit_effect(&mut self, emitter: EntityId, event: Box) { + self.event_listeners + .clone() + .retain(&emitter, |handler| handler(event.as_ref(), self)); + } + + fn apply_focus_changed_effect(&mut self, window_id: WindowId, focused: Option) { + self.update_window(window_id, |cx| { + if cx.window.focus == focused { + let mut listeners = mem::take(&mut cx.window.focus_listeners); + let focused = + focused.map(|id| FocusHandle::for_id(id, &cx.window.focus_handles).unwrap()); + let blurred = cx + .window + .last_blur + .take() + .unwrap() + .and_then(|id| FocusHandle::for_id(id, &cx.window.focus_handles)); + if focused.is_some() || blurred.is_some() { + let event = 
FocusEvent { focused, blurred }; + for listener in &listeners { + listener(&event, cx); + } + } + + listeners.extend(cx.window.focus_listeners.drain(..)); + cx.window.focus_listeners = listeners; + } + }) + .ok(); + } + + fn apply_refresh_effect(&mut self) { + for window in self.windows.values_mut() { + if let Some(window) = window.as_mut() { + window.dirty = true; + } + } + } + + fn apply_notify_global_observers_effect(&mut self, type_id: TypeId) { + self.pending_global_notifications.remove(&type_id); + self.global_observers + .clone() + .retain(&type_id, |observer| observer(self)); + } + + fn apply_defer_effect(&mut self, callback: Box) { + callback(self); + } + + /// Creates an `AsyncAppContext`, which can be cloned and has a static lifetime + /// so it can be held across `await` points. + pub fn to_async(&self) -> AsyncAppContext { + AsyncAppContext { + app: unsafe { mem::transmute(self.this.clone()) }, + executor: self.executor.clone(), + } + } + + /// Obtains a reference to the executor, which can be used to spawn futures. + pub fn executor(&self) -> &Executor { + &self.executor + } + + /// Runs the given closure on the main thread, where interaction with the platform + /// is possible. The given closure will be invoked with a `MainThread`, which + /// has platform-specific methods that aren't present on `AppContext`. + pub fn run_on_main( + &mut self, + f: impl FnOnce(&mut MainThread) -> R + Send + 'static, + ) -> Task + where + R: Send + 'static, + { + if self.executor.is_main_thread() { + Task::ready(f(unsafe { + mem::transmute::<&mut AppContext, &mut MainThread>(self) + })) + } else { + let this = self.this.upgrade().unwrap(); + self.executor.run_on_main(move || { + let cx = &mut *this.lock(); + cx.update(|cx| f(unsafe { mem::transmute::<&mut Self, &mut MainThread>(cx) })) + }) + } + } + + /// Spawns the future returned by the given function on the main thread, where interaction with + /// the platform is possible. 
The given closure will be invoked with a `MainThread`, + /// which has platform-specific methods that aren't present on `AsyncAppContext`. The future will be + /// polled exclusively on the main thread. + // todo!("I think we need somehow to prevent the MainThread from implementing Send") + pub fn spawn_on_main( + &self, + f: impl FnOnce(MainThread) -> F + Send + 'static, + ) -> Task + where + F: Future + 'static, + R: Send + 'static, + { + let cx = self.to_async(); + self.executor.spawn_on_main(move || f(MainThread(cx))) + } + + /// Spawns the future returned by the given function on the thread pool. The closure will be invoked + /// with AsyncAppContext, which allows the application state to be accessed across await points. + pub fn spawn(&self, f: impl FnOnce(AsyncAppContext) -> Fut + Send + 'static) -> Task + where + Fut: Future + Send + 'static, + R: Send + 'static, + { + let cx = self.to_async(); + self.executor.spawn(async move { + let future = f(cx); + future.await + }) + } + + /// Schedules the given function to be run at the end of the current effect cycle, allowing entities + /// that are currently on the stack to be returned to the app. + pub fn defer(&mut self, f: impl FnOnce(&mut AppContext) + 'static + Send) { + self.push_effect(Effect::Defer { + callback: Box::new(f), + }); + } + + /// Accessor for the application's asset source, which is provided when constructing the `App`. + pub fn asset_source(&self) -> &Arc { + &self.asset_source + } + + /// Accessor for the text system. + pub fn text_system(&self) -> &Arc { + &self.text_system + } + + /// The current text style. Which is composed of all the style refinements provided to `with_text_style`. + pub fn text_style(&self) -> TextStyle { + let mut style = TextStyle::default(); + for refinement in &self.text_style_stack { + style.refine(refinement); + } + style + } + + /// Check whether a global of the given type has been assigned. 
+ pub fn has_global(&self) -> bool { + self.globals_by_type.contains_key(&TypeId::of::()) + } + + /// Access the global of the given type. Panics if a global for that type has not been assigned. + pub fn global(&self) -> &G { + self.globals_by_type + .get(&TypeId::of::()) + .map(|any_state| any_state.downcast_ref::().unwrap()) + .ok_or_else(|| anyhow!("no state of type {} exists", type_name::())) + .unwrap() + } + + /// Access the global of the given type if a value has been assigned. + pub fn try_global(&self) -> Option<&G> { + self.globals_by_type + .get(&TypeId::of::()) + .map(|any_state| any_state.downcast_ref::().unwrap()) + } + + /// Access the global of the given type mutably. Panics if a global for that type has not been assigned. + pub fn global_mut(&mut self) -> &mut G { + let global_type = TypeId::of::(); + self.push_effect(Effect::NotifyGlobalObservers { global_type }); + self.globals_by_type + .get_mut(&global_type) + .and_then(|any_state| any_state.downcast_mut::()) + .ok_or_else(|| anyhow!("no state of type {} exists", type_name::())) + .unwrap() + } + + /// Access the global of the given type mutably. A default value is assigned if a global of this type has not + /// yet been assigned. + pub fn default_global(&mut self) -> &mut G { + let global_type = TypeId::of::(); + self.push_effect(Effect::NotifyGlobalObservers { global_type }); + self.globals_by_type + .entry(global_type) + .or_insert_with(|| Box::new(G::default())) + .downcast_mut::() + .unwrap() + } + + /// Set the value of the global of the given type. + pub fn set_global(&mut self, global: G) { + let global_type = TypeId::of::(); + self.push_effect(Effect::NotifyGlobalObservers { global_type }); + self.globals_by_type.insert(global_type, Box::new(global)); + } + + /// Update the global of the given type with a closure. Unlike `global_mut`, this method provides + /// your closure with mutable access to the `AppContext` and the global simultaneously. 
+ pub fn update_global(&mut self, f: impl FnOnce(&mut G, &mut Self) -> R) -> R { + let mut global = self.lease_global::(); + let result = f(&mut global, self); + self.end_global_lease(global); + result + } + + /// Register a callback to be invoked when a global of the given type is updated. + pub fn observe_global( + &mut self, + mut f: impl FnMut(&mut Self) + Send + 'static, + ) -> Subscription { + self.global_observers.insert( + TypeId::of::(), + Box::new(move |cx| { + f(cx); + true + }), + ) + } + + pub fn all_action_names<'a>(&'a self) -> impl Iterator + 'a { + self.action_builders.keys().cloned() + } + + /// Move the global of the given type to the stack. + pub(crate) fn lease_global(&mut self) -> GlobalLease { + GlobalLease::new( + self.globals_by_type + .remove(&TypeId::of::()) + .ok_or_else(|| anyhow!("no global registered of type {}", type_name::())) + .unwrap(), + ) + } + + /// Restore the global of the given type after it is moved to the stack. + pub(crate) fn end_global_lease(&mut self, lease: GlobalLease) { + let global_type = TypeId::of::(); + self.push_effect(Effect::NotifyGlobalObservers { global_type }); + self.globals_by_type.insert(global_type, lease.global); + } + + pub(crate) fn push_text_style(&mut self, text_style: TextStyleRefinement) { + self.text_style_stack.push(text_style); + } + + pub(crate) fn pop_text_style(&mut self) { + self.text_style_stack.pop(); + } + + /// Register key bindings. + pub fn bind_keys(&mut self, bindings: impl IntoIterator) { + self.keymap.lock().add_bindings(bindings); + self.pending_effects.push_back(Effect::Refresh); + } + + /// Register a global listener for actions invoked via the keyboard. 
+ pub fn on_action(&mut self, listener: impl Fn(&A, &mut Self) + Send + 'static) { + self.global_action_listeners + .entry(TypeId::of::()) + .or_default() + .push(Box::new(move |action, phase, cx| { + if phase == DispatchPhase::Bubble { + let action = action.as_any().downcast_ref().unwrap(); + listener(action, cx) + } + })); + } + + /// Register an action type to allow it to be referenced in keymaps. + pub fn register_action_type(&mut self) { + self.action_builders.insert(A::qualified_name(), A::build); + } + + /// Construct an action based on its name and parameters. + pub fn build_action( + &mut self, + name: &str, + params: Option, + ) -> Result> { + let build = self + .action_builders + .get(name) + .ok_or_else(|| anyhow!("no action type registered for {}", name))?; + (build)(params) + } + + /// Halt propagation of a mouse event, keyboard event, or action. This prevents listeners + /// that have not yet been invoked from receiving the event. + pub fn stop_propagation(&mut self) { + self.propagate_event = false; + } +} + +impl Context for AppContext { + type ModelContext<'a, T> = ModelContext<'a, T>; + type Result = T; + + /// Build an entity that is owned by the application. The given function will be invoked with + /// a `ModelContext` and must return an object representing the entity. A `Model` will be returned + /// which can be used to access the entity in a context. + fn build_model( + &mut self, + build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T, + ) -> Model { + self.update(|cx| { + let slot = cx.entities.reserve(); + let entity = build_model(&mut ModelContext::mutable(cx, slot.downgrade())); + cx.entities.insert(slot, entity) + }) + } + + /// Update the entity referenced by the given model. The function is passed a mutable reference to the + /// entity along with a `ModelContext` for the entity. 
+ fn update_model( + &mut self, + model: &Model, + update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R, + ) -> R { + self.update(|cx| { + let mut entity = cx.entities.lease(model); + let result = update( + &mut entity, + &mut ModelContext::mutable(cx, model.downgrade()), + ); + cx.entities.end_lease(entity); + result + }) + } +} + +impl MainThread +where + C: Borrow, +{ + pub(crate) fn platform(&self) -> &dyn Platform { + self.0.borrow().platform.borrow_on_main_thread() + } + + /// Instructs the platform to activate the application by bringing it to the foreground. + pub fn activate(&self, ignoring_other_apps: bool) { + self.platform().activate(ignoring_other_apps); + } + + /// Writes data to the platform clipboard. + pub fn write_to_clipboard(&self, item: ClipboardItem) { + self.platform().write_to_clipboard(item) + } + + /// Reads data from the platform clipboard. + pub fn read_from_clipboard(&self) -> Option { + self.platform().read_from_clipboard() + } + + /// Writes credentials to the platform keychain. + pub fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Result<()> { + self.platform().write_credentials(url, username, password) + } + + /// Reads credentials from the platform keychain. + pub fn read_credentials(&self, url: &str) -> Result)>> { + self.platform().read_credentials(url) + } + + /// Deletes credentials from the platform keychain. + pub fn delete_credentials(&self, url: &str) -> Result<()> { + self.platform().delete_credentials(url) + } + + /// Directs the platform's default browser to open the given URL. 
+ pub fn open_url(&self, url: &str) { + self.platform().open_url(url); + } + + pub fn path_for_auxiliary_executable(&self, name: &str) -> Result { + self.platform().path_for_auxiliary_executable(name) + } +} + +impl MainThread { + fn update(&mut self, update: impl FnOnce(&mut Self) -> R) -> R { + self.0.update(|cx| { + update(unsafe { + std::mem::transmute::<&mut AppContext, &mut MainThread>(cx) + }) + }) + } + + pub(crate) fn update_window( + &mut self, + id: WindowId, + update: impl FnOnce(&mut MainThread) -> R, + ) -> Result { + self.0.update_window(id, |cx| { + update(unsafe { + std::mem::transmute::<&mut WindowContext, &mut MainThread>(cx) + }) + }) + } + + /// Opens a new window with the given option and the root view returned by the given function. + /// The function is invoked with a `WindowContext`, which can be used to interact with window-specific + /// functionality. + pub fn open_window( + &mut self, + options: crate::WindowOptions, + build_root_view: impl FnOnce(&mut WindowContext) -> View + Send + 'static, + ) -> WindowHandle { + self.update(|cx| { + let id = cx.windows.insert(None); + let handle = WindowHandle::new(id); + let mut window = Window::new(handle.into(), options, cx); + let root_view = build_root_view(&mut WindowContext::mutable(cx, &mut window)); + window.root_view.replace(root_view.into()); + cx.windows.get_mut(id).unwrap().replace(window); + handle + }) + } + + /// Update the global of the given type with a closure. Unlike `global_mut`, this method provides + /// your closure with mutable access to the `MainThread` and the global simultaneously. + pub fn update_global( + &mut self, + update: impl FnOnce(&mut G, &mut MainThread) -> R, + ) -> R { + self.0.update_global(|global, cx| { + let cx = unsafe { mem::transmute::<&mut AppContext, &mut MainThread>(cx) }; + update(global, cx) + }) + } +} + +/// These effects are processed at the end of each application update cycle. 
+pub(crate) enum Effect { + Notify { + emitter: EntityId, + }, + Emit { + emitter: EntityId, + event: Box, + }, + FocusChanged { + window_id: WindowId, + focused: Option, + }, + Refresh, + NotifyGlobalObservers { + global_type: TypeId, + }, + Defer { + callback: Box, + }, +} + +/// Wraps a global variable value during `update_global` while the value has been moved to the stack. +pub(crate) struct GlobalLease { + global: AnyBox, + global_type: PhantomData, +} + +impl GlobalLease { + fn new(global: AnyBox) -> Self { + GlobalLease { + global, + global_type: PhantomData, + } + } +} + +impl Deref for GlobalLease { + type Target = G; + + fn deref(&self) -> &Self::Target { + self.global.downcast_ref().unwrap() + } +} + +impl DerefMut for GlobalLease { + fn deref_mut(&mut self) -> &mut Self::Target { + self.global.downcast_mut().unwrap() + } +} + +/// Contains state associated with an active drag operation, started by dragging an element +/// within the window or by dragging into the app from the underlying platform. 
+pub(crate) struct AnyDrag { + pub view: AnyView, + pub cursor_offset: Point, +} + +#[cfg(test)] +mod tests { + use super::AppContext; + + #[test] + fn test_app_context_send_sync() { + // This will not compile if `AppContext` does not implement `Send` + fn assert_send() {} + assert_send::(); + } +} diff --git a/crates/gpui2/src/app/async_context.rs b/crates/gpui2/src/app/async_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..042a75848e467fd452db80447e477a5c42179965 --- /dev/null +++ b/crates/gpui2/src/app/async_context.rs @@ -0,0 +1,252 @@ +use crate::{ + AnyWindowHandle, AppContext, Context, Executor, MainThread, Model, ModelContext, Result, Task, + WindowContext, +}; +use anyhow::anyhow; +use derive_more::{Deref, DerefMut}; +use parking_lot::Mutex; +use std::{future::Future, sync::Weak}; + +#[derive(Clone)] +pub struct AsyncAppContext { + pub(crate) app: Weak>, + pub(crate) executor: Executor, +} + +impl Context for AsyncAppContext { + type ModelContext<'a, T> = ModelContext<'a, T>; + type Result = Result; + + fn build_model( + &mut self, + build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T, + ) -> Self::Result> + where + T: 'static + Send, + { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let mut lock = app.lock(); // Need this to compile + Ok(lock.build_model(build_model)) + } + + fn update_model( + &mut self, + handle: &Model, + update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R, + ) -> Self::Result { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let mut lock = app.lock(); // Need this to compile + Ok(lock.update_model(handle, update)) + } +} + +impl AsyncAppContext { + pub fn refresh(&mut self) -> Result<()> { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let mut lock = app.lock(); // Need this to compile + lock.refresh(); + Ok(()) + } + + pub fn executor(&self) -> &Executor { + 
&self.executor + } + + pub fn update(&self, f: impl FnOnce(&mut AppContext) -> R) -> Result { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let mut lock = app.lock(); + Ok(f(&mut *lock)) + } + + pub fn read_window( + &self, + handle: AnyWindowHandle, + update: impl FnOnce(&WindowContext) -> R, + ) -> Result { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let mut app_context = app.lock(); + app_context.read_window(handle.id, update) + } + + pub fn update_window( + &self, + handle: AnyWindowHandle, + update: impl FnOnce(&mut WindowContext) -> R, + ) -> Result { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let mut app_context = app.lock(); + app_context.update_window(handle.id, update) + } + + pub fn spawn(&self, f: impl FnOnce(AsyncAppContext) -> Fut + Send + 'static) -> Task + where + Fut: Future + Send + 'static, + R: Send + 'static, + { + let this = self.clone(); + self.executor.spawn(async move { f(this).await }) + } + + pub fn spawn_on_main( + &self, + f: impl FnOnce(AsyncAppContext) -> Fut + Send + 'static, + ) -> Task + where + Fut: Future + 'static, + R: Send + 'static, + { + let this = self.clone(); + self.executor.spawn_on_main(|| f(this)) + } + + pub fn run_on_main( + &self, + f: impl FnOnce(&mut MainThread) -> R + Send + 'static, + ) -> Result> + where + R: Send + 'static, + { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let mut app_context = app.lock(); + Ok(app_context.run_on_main(f)) + } + + pub fn has_global(&self) -> Result { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let lock = app.lock(); // Need this to compile + Ok(lock.has_global::()) + } + + pub fn read_global(&self, read: impl FnOnce(&G, &AppContext) -> R) -> Result { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let lock = app.lock(); // Need 
this to compile + Ok(read(lock.global(), &lock)) + } + + pub fn try_read_global( + &self, + read: impl FnOnce(&G, &AppContext) -> R, + ) -> Option { + let app = self.app.upgrade()?; + let lock = app.lock(); // Need this to compile + Some(read(lock.try_global()?, &lock)) + } + + pub fn update_global( + &mut self, + update: impl FnOnce(&mut G, &mut AppContext) -> R, + ) -> Result { + let app = self + .app + .upgrade() + .ok_or_else(|| anyhow!("app was released"))?; + let mut lock = app.lock(); // Need this to compile + Ok(lock.update_global(update)) + } +} + +#[derive(Clone, Deref, DerefMut)] +pub struct AsyncWindowContext { + #[deref] + #[deref_mut] + app: AsyncAppContext, + window: AnyWindowHandle, +} + +impl AsyncWindowContext { + pub(crate) fn new(app: AsyncAppContext, window: AnyWindowHandle) -> Self { + Self { app, window } + } + + pub fn update(&self, update: impl FnOnce(&mut WindowContext) -> R) -> Result { + self.app.update_window(self.window, update) + } + + pub fn on_next_frame(&mut self, f: impl FnOnce(&mut WindowContext) + Send + 'static) { + self.app + .update_window(self.window, |cx| cx.on_next_frame(f)) + .ok(); + } + + pub fn read_global( + &self, + read: impl FnOnce(&G, &WindowContext) -> R, + ) -> Result { + self.app + .read_window(self.window, |cx| read(cx.global(), cx)) + } + + pub fn update_global( + &mut self, + update: impl FnOnce(&mut G, &mut WindowContext) -> R, + ) -> Result + where + G: 'static, + { + self.app + .update_window(self.window, |cx| cx.update_global(update)) + } +} + +impl Context for AsyncWindowContext { + type ModelContext<'a, T> = ModelContext<'a, T>; + type Result = Result; + + fn build_model( + &mut self, + build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T, + ) -> Result> + where + T: 'static + Send, + { + self.app + .update_window(self.window, |cx| cx.build_model(build_model)) + } + + fn update_model( + &mut self, + handle: &Model, + update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R, + ) -> 
Result { + self.app + .update_window(self.window, |cx| cx.update_model(handle, update)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_async_app_context_send_sync() { + fn assert_send_sync() {} + assert_send_sync::(); + } +} diff --git a/crates/gpui2/src/app/entity_map.rs b/crates/gpui2/src/app/entity_map.rs new file mode 100644 index 0000000000000000000000000000000000000000..bbeabd3e4fa3ef115742e3b2bb117fa8661437e3 --- /dev/null +++ b/crates/gpui2/src/app/entity_map.rs @@ -0,0 +1,501 @@ +use crate::{private::Sealed, AnyBox, AppContext, Context, Entity}; +use anyhow::{anyhow, Result}; +use derive_more::{Deref, DerefMut}; +use parking_lot::{RwLock, RwLockUpgradableReadGuard}; +use slotmap::{SecondaryMap, SlotMap}; +use std::{ + any::{type_name, TypeId}, + fmt::{self, Display}, + hash::{Hash, Hasher}, + marker::PhantomData, + mem, + sync::{ + atomic::{AtomicUsize, Ordering::SeqCst}, + Arc, Weak, + }, +}; + +slotmap::new_key_type! { pub struct EntityId; } + +impl EntityId { + pub fn as_u64(self) -> u64 { + self.0.as_ffi() + } +} + +impl Display for EntityId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self) + } +} + +pub(crate) struct EntityMap { + entities: SecondaryMap, + ref_counts: Arc>, +} + +struct EntityRefCounts { + counts: SlotMap, + dropped_entity_ids: Vec, +} + +impl EntityMap { + pub fn new() -> Self { + Self { + entities: SecondaryMap::new(), + ref_counts: Arc::new(RwLock::new(EntityRefCounts { + counts: SlotMap::with_key(), + dropped_entity_ids: Vec::new(), + })), + } + } + + /// Reserve a slot for an entity, which you can subsequently use with `insert`. + pub fn reserve(&self) -> Slot { + let id = self.ref_counts.write().counts.insert(1.into()); + Slot(Model::new(id, Arc::downgrade(&self.ref_counts))) + } + + /// Insert an entity into a slot obtained by calling `reserve`. 
+ pub fn insert(&mut self, slot: Slot, entity: T) -> Model + where + T: 'static + Send, + { + let model = slot.0; + self.entities.insert(model.entity_id, Box::new(entity)); + model + } + + /// Move an entity to the stack. + pub fn lease<'a, T>(&mut self, model: &'a Model) -> Lease<'a, T> { + self.assert_valid_context(model); + let entity = Some( + self.entities + .remove(model.entity_id) + .expect("Circular entity lease. Is the entity already being updated?"), + ); + Lease { + model, + entity, + entity_type: PhantomData, + } + } + + /// Return an entity after moving it to the stack. + pub fn end_lease(&mut self, mut lease: Lease) { + self.entities + .insert(lease.model.entity_id, lease.entity.take().unwrap()); + } + + pub fn read(&self, model: &Model) -> &T { + self.assert_valid_context(model); + self.entities[model.entity_id].downcast_ref().unwrap() + } + + fn assert_valid_context(&self, model: &AnyModel) { + debug_assert!( + Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)), + "used a model with the wrong context" + ); + } + + pub fn take_dropped(&mut self) -> Vec<(EntityId, AnyBox)> { + let mut ref_counts = self.ref_counts.write(); + let dropped_entity_ids = mem::take(&mut ref_counts.dropped_entity_ids); + + dropped_entity_ids + .into_iter() + .map(|entity_id| { + ref_counts.counts.remove(entity_id); + (entity_id, self.entities.remove(entity_id).unwrap()) + }) + .collect() + } +} + +pub struct Lease<'a, T> { + entity: Option, + pub model: &'a Model, + entity_type: PhantomData, +} + +impl<'a, T: 'static> core::ops::Deref for Lease<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.entity.as_ref().unwrap().downcast_ref().unwrap() + } +} + +impl<'a, T: 'static> core::ops::DerefMut for Lease<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.entity.as_mut().unwrap().downcast_mut().unwrap() + } +} + +impl<'a, T> Drop for Lease<'a, T> { + fn drop(&mut self) { + if self.entity.is_some() { + // We don't panic here, 
because other panics can cause us to drop the lease without ending it cleanly. + log::error!("Leases must be ended with EntityMap::end_lease") + } + } +} + +#[derive(Deref, DerefMut)] +pub struct Slot(Model); + +pub struct AnyModel { + pub(crate) entity_id: EntityId, + pub(crate) entity_type: TypeId, + entity_map: Weak>, +} + +impl AnyModel { + fn new(id: EntityId, entity_type: TypeId, entity_map: Weak>) -> Self { + Self { + entity_id: id, + entity_type, + entity_map, + } + } + + pub fn entity_id(&self) -> EntityId { + self.entity_id + } + + pub fn downgrade(&self) -> AnyWeakModel { + AnyWeakModel { + entity_id: self.entity_id, + entity_type: self.entity_type, + entity_ref_counts: self.entity_map.clone(), + } + } + + pub fn downcast(self) -> Result, AnyModel> { + if TypeId::of::() == self.entity_type { + Ok(Model { + any_model: self, + entity_type: PhantomData, + }) + } else { + Err(self) + } + } +} + +impl Clone for AnyModel { + fn clone(&self) -> Self { + if let Some(entity_map) = self.entity_map.upgrade() { + let entity_map = entity_map.read(); + let count = entity_map + .counts + .get(self.entity_id) + .expect("detected over-release of a model"); + let prev_count = count.fetch_add(1, SeqCst); + assert_ne!(prev_count, 0, "Detected over-release of a model."); + } + + Self { + entity_id: self.entity_id, + entity_type: self.entity_type, + entity_map: self.entity_map.clone(), + } + } +} + +impl Drop for AnyModel { + fn drop(&mut self) { + if let Some(entity_map) = self.entity_map.upgrade() { + let entity_map = entity_map.upgradable_read(); + let count = entity_map + .counts + .get(self.entity_id) + .expect("Detected over-release of a model."); + let prev_count = count.fetch_sub(1, SeqCst); + assert_ne!(prev_count, 0, "Detected over-release of a model."); + if prev_count == 1 { + // We were the last reference to this entity, so we can remove it. 
+ let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map); + entity_map.dropped_entity_ids.push(self.entity_id); + } + } + } +} + +impl From> for AnyModel { + fn from(model: Model) -> Self { + model.any_model + } +} + +impl Hash for AnyModel { + fn hash(&self, state: &mut H) { + self.entity_id.hash(state); + } +} + +impl PartialEq for AnyModel { + fn eq(&self, other: &Self) -> bool { + self.entity_id == other.entity_id + } +} + +impl Eq for AnyModel {} + +impl std::fmt::Debug for AnyModel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AnyModel") + .field("entity_id", &self.entity_id.as_u64()) + .finish() + } +} + +#[derive(Deref, DerefMut)] +pub struct Model { + #[deref] + #[deref_mut] + pub(crate) any_model: AnyModel, + pub(crate) entity_type: PhantomData, +} + +unsafe impl Send for Model {} +unsafe impl Sync for Model {} +impl Sealed for Model {} + +impl Entity for Model { + type Weak = WeakModel; + + fn entity_id(&self) -> EntityId { + self.any_model.entity_id + } + + fn downgrade(&self) -> Self::Weak { + WeakModel { + any_model: self.any_model.downgrade(), + entity_type: self.entity_type, + } + } + + fn upgrade_from(weak: &Self::Weak) -> Option + where + Self: Sized, + { + Some(Model { + any_model: weak.any_model.upgrade()?, + entity_type: weak.entity_type, + }) + } +} + +impl Model { + fn new(id: EntityId, entity_map: Weak>) -> Self + where + T: 'static, + { + Self { + any_model: AnyModel::new(id, TypeId::of::(), entity_map), + entity_type: PhantomData, + } + } + + /// Downgrade the this to a weak model reference + pub fn downgrade(&self) -> WeakModel { + // Delegate to the trait implementation to keep behavior in one place. + // This method was included to improve method resolution in the presence of + // the Model's deref + Entity::downgrade(self) + } + + /// Convert this into a dynamically typed model. 
+ pub fn into_any(self) -> AnyModel { + self.any_model + } + + pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T { + cx.entities.read(self) + } + + /// Update the entity referenced by this model with the given function. + /// + /// The update function receives a context appropriate for its environment. + /// When updating in an `AppContext`, it receives a `ModelContext`. + /// When updating an a `WindowContext`, it receives a `ViewContext`. + pub fn update( + &self, + cx: &mut C, + update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R, + ) -> C::Result + where + C: Context, + { + cx.update_model(self, update) + } +} + +impl Clone for Model { + fn clone(&self) -> Self { + Self { + any_model: self.any_model.clone(), + entity_type: self.entity_type, + } + } +} + +impl std::fmt::Debug for Model { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Model {{ entity_id: {:?}, entity_type: {:?} }}", + self.any_model.entity_id, + type_name::() + ) + } +} + +impl Hash for Model { + fn hash(&self, state: &mut H) { + self.any_model.hash(state); + } +} + +impl PartialEq for Model { + fn eq(&self, other: &Self) -> bool { + self.any_model == other.any_model + } +} + +impl Eq for Model {} + +impl PartialEq> for Model { + fn eq(&self, other: &WeakModel) -> bool { + self.any_model.entity_id() == other.entity_id() + } +} + +#[derive(Clone)] +pub struct AnyWeakModel { + pub(crate) entity_id: EntityId, + entity_type: TypeId, + entity_ref_counts: Weak>, +} + +impl AnyWeakModel { + pub fn entity_id(&self) -> EntityId { + self.entity_id + } + + pub fn is_upgradable(&self) -> bool { + let ref_count = self + .entity_ref_counts + .upgrade() + .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst))) + .unwrap_or(0); + ref_count > 0 + } + + pub fn upgrade(&self) -> Option { + let entity_map = self.entity_ref_counts.upgrade()?; + entity_map + .read() + .counts + .get(self.entity_id)? 
+ .fetch_add(1, SeqCst); + Some(AnyModel { + entity_id: self.entity_id, + entity_type: self.entity_type, + entity_map: self.entity_ref_counts.clone(), + }) + } +} + +impl From> for AnyWeakModel { + fn from(model: WeakModel) -> Self { + model.any_model + } +} + +impl Hash for AnyWeakModel { + fn hash(&self, state: &mut H) { + self.entity_id.hash(state); + } +} + +impl PartialEq for AnyWeakModel { + fn eq(&self, other: &Self) -> bool { + self.entity_id == other.entity_id + } +} + +impl Eq for AnyWeakModel {} + +#[derive(Deref, DerefMut)] +pub struct WeakModel { + #[deref] + #[deref_mut] + any_model: AnyWeakModel, + entity_type: PhantomData, +} + +unsafe impl Send for WeakModel {} +unsafe impl Sync for WeakModel {} + +impl Clone for WeakModel { + fn clone(&self) -> Self { + Self { + any_model: self.any_model.clone(), + entity_type: self.entity_type, + } + } +} + +impl WeakModel { + /// Upgrade this weak model reference into a strong model reference + pub fn upgrade(&self) -> Option> { + // Delegate to the trait implementation to keep behavior in one place. + Model::upgrade_from(self) + } + + /// Update the entity referenced by this model with the given function if + /// the referenced entity still exists. Returns an error if the entity has + /// been released. + /// + /// The update function receives a context appropriate for its environment. + /// When updating in an `AppContext`, it receives a `ModelContext`. + /// When updating an a `WindowContext`, it receives a `ViewContext`. 
+ pub fn update( + &self, + cx: &mut C, + update: impl FnOnce(&mut T, &mut C::ModelContext<'_, T>) -> R, + ) -> Result + where + C: Context, + Result>: crate::Flatten, + { + crate::Flatten::flatten( + self.upgrade() + .ok_or_else(|| anyhow!("entity release")) + .map(|this| cx.update_model(&this, update)), + ) + } +} + +impl Hash for WeakModel { + fn hash(&self, state: &mut H) { + self.any_model.hash(state); + } +} + +impl PartialEq for WeakModel { + fn eq(&self, other: &Self) -> bool { + self.any_model == other.any_model + } +} + +impl Eq for WeakModel {} + +impl PartialEq> for WeakModel { + fn eq(&self, other: &Model) -> bool { + self.entity_id() == other.any_model.entity_id() + } +} diff --git a/crates/gpui2/src/app/model_context.rs b/crates/gpui2/src/app/model_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..8a4576c052b6c4355576f2959028b568dae8692d --- /dev/null +++ b/crates/gpui2/src/app/model_context.rs @@ -0,0 +1,266 @@ +use crate::{ + AppContext, AsyncAppContext, Context, Effect, Entity, EntityId, EventEmitter, MainThread, + Model, Reference, Subscription, Task, WeakModel, +}; +use derive_more::{Deref, DerefMut}; +use futures::FutureExt; +use std::{ + any::{Any, TypeId}, + borrow::{Borrow, BorrowMut}, + future::Future, +}; + +#[derive(Deref, DerefMut)] +pub struct ModelContext<'a, T> { + #[deref] + #[deref_mut] + app: Reference<'a, AppContext>, + model_state: WeakModel, +} + +impl<'a, T: 'static> ModelContext<'a, T> { + pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakModel) -> Self { + Self { + app: Reference::Mutable(app), + model_state, + } + } + + pub fn entity_id(&self) -> EntityId { + self.model_state.entity_id + } + + pub fn handle(&self) -> Model { + self.weak_model() + .upgrade() + .expect("The entity must be alive if we have a model context") + } + + pub fn weak_model(&self) -> WeakModel { + self.model_state.clone() + } + + pub fn observe( + &mut self, + entity: &E, + mut on_notify: impl FnMut(&mut T, 
E, &mut ModelContext<'_, T>) + Send + 'static, + ) -> Subscription + where + T: 'static + Send, + T2: 'static, + E: Entity, + { + let this = self.weak_model(); + let entity_id = entity.entity_id(); + let handle = entity.downgrade(); + self.app.observers.insert( + entity_id, + Box::new(move |cx| { + if let Some((this, handle)) = this.upgrade().zip(E::upgrade_from(&handle)) { + this.update(cx, |this, cx| on_notify(this, handle, cx)); + true + } else { + false + } + }), + ) + } + + pub fn subscribe( + &mut self, + entity: &E, + mut on_event: impl FnMut(&mut T, E, &T2::Event, &mut ModelContext<'_, T>) + Send + 'static, + ) -> Subscription + where + T: 'static + Send, + T2: 'static + EventEmitter, + E: Entity, + { + let this = self.weak_model(); + let entity_id = entity.entity_id(); + let entity = entity.downgrade(); + self.app.event_listeners.insert( + entity_id, + Box::new(move |event, cx| { + let event: &T2::Event = event.downcast_ref().expect("invalid event type"); + if let Some((this, handle)) = this.upgrade().zip(E::upgrade_from(&entity)) { + this.update(cx, |this, cx| on_event(this, handle, event, cx)); + true + } else { + false + } + }), + ) + } + + pub fn on_release( + &mut self, + mut on_release: impl FnMut(&mut T, &mut AppContext) + Send + 'static, + ) -> Subscription + where + T: 'static, + { + self.app.release_listeners.insert( + self.model_state.entity_id, + Box::new(move |this, cx| { + let this = this.downcast_mut().expect("invalid entity type"); + on_release(this, cx); + }), + ) + } + + pub fn observe_release( + &mut self, + entity: &E, + mut on_release: impl FnMut(&mut T, &mut T2, &mut ModelContext<'_, T>) + Send + 'static, + ) -> Subscription + where + T: Any + Send, + T2: 'static, + E: Entity, + { + let entity_id = entity.entity_id(); + let this = self.weak_model(); + self.app.release_listeners.insert( + entity_id, + Box::new(move |entity, cx| { + let entity = entity.downcast_mut().expect("invalid entity type"); + if let Some(this) = this.upgrade() { 
+ this.update(cx, |this, cx| on_release(this, entity, cx)); + } + }), + ) + } + + pub fn observe_global( + &mut self, + mut f: impl FnMut(&mut T, &mut ModelContext<'_, T>) + Send + 'static, + ) -> Subscription + where + T: 'static + Send, + { + let handle = self.weak_model(); + self.global_observers.insert( + TypeId::of::(), + Box::new(move |cx| handle.update(cx, |view, cx| f(view, cx)).is_ok()), + ) + } + + pub fn on_app_quit( + &mut self, + mut on_quit: impl FnMut(&mut T, &mut ModelContext) -> Fut + Send + 'static, + ) -> Subscription + where + Fut: 'static + Future + Send, + T: 'static + Send, + { + let handle = self.weak_model(); + self.app.quit_observers.insert( + (), + Box::new(move |cx| { + let future = handle.update(cx, |entity, cx| on_quit(entity, cx)).ok(); + async move { + if let Some(future) = future { + future.await; + } + } + .boxed() + }), + ) + } + + pub fn notify(&mut self) { + if self + .app + .pending_notifications + .insert(self.model_state.entity_id) + { + self.app.pending_effects.push_back(Effect::Notify { + emitter: self.model_state.entity_id, + }); + } + } + + pub fn update_global(&mut self, f: impl FnOnce(&mut G, &mut Self) -> R) -> R + where + G: 'static + Send, + { + let mut global = self.app.lease_global::(); + let result = f(&mut global, self); + self.app.end_global_lease(global); + result + } + + pub fn spawn( + &self, + f: impl FnOnce(WeakModel, AsyncAppContext) -> Fut + Send + 'static, + ) -> Task + where + T: 'static, + Fut: Future + Send + 'static, + R: Send + 'static, + { + let this = self.weak_model(); + self.app.spawn(|cx| f(this, cx)) + } + + pub fn spawn_on_main( + &self, + f: impl FnOnce(WeakModel, MainThread) -> Fut + Send + 'static, + ) -> Task + where + Fut: Future + 'static, + R: Send + 'static, + { + let this = self.weak_model(); + self.app.spawn_on_main(|cx| f(this, cx)) + } +} + +impl<'a, T> ModelContext<'a, T> +where + T: EventEmitter, + T::Event: Send, +{ + pub fn emit(&mut self, event: T::Event) { + 
self.app.pending_effects.push_back(Effect::Emit { + emitter: self.model_state.entity_id, + event: Box::new(event), + }); + } +} + +impl<'a, T> Context for ModelContext<'a, T> { + type ModelContext<'b, U> = ModelContext<'b, U>; + type Result = U; + + fn build_model( + &mut self, + build_model: impl FnOnce(&mut Self::ModelContext<'_, U>) -> U, + ) -> Model + where + U: 'static + Send, + { + self.app.build_model(build_model) + } + + fn update_model( + &mut self, + handle: &Model, + update: impl FnOnce(&mut U, &mut Self::ModelContext<'_, U>) -> R, + ) -> R { + self.app.update_model(handle, update) + } +} + +impl Borrow for ModelContext<'_, T> { + fn borrow(&self) -> &AppContext { + &self.app + } +} + +impl BorrowMut for ModelContext<'_, T> { + fn borrow_mut(&mut self) -> &mut AppContext { + &mut self.app + } +} diff --git a/crates/gpui2/src/app/test_context.rs b/crates/gpui2/src/app/test_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..2b09a95a34b4b60fe5b5792f2dc2bf0333eb7d63 --- /dev/null +++ b/crates/gpui2/src/app/test_context.rs @@ -0,0 +1,152 @@ +use crate::{ + AnyWindowHandle, AppContext, AsyncAppContext, Context, Executor, MainThread, Model, + ModelContext, Result, Task, TestDispatcher, TestPlatform, WindowContext, +}; +use parking_lot::Mutex; +use std::{future::Future, sync::Arc}; + +#[derive(Clone)] +pub struct TestAppContext { + pub app: Arc>, + pub executor: Executor, +} + +impl Context for TestAppContext { + type ModelContext<'a, T> = ModelContext<'a, T>; + type Result = T; + + fn build_model( + &mut self, + build_model: impl FnOnce(&mut Self::ModelContext<'_, T>) -> T, + ) -> Self::Result> + where + T: 'static + Send, + { + let mut lock = self.app.lock(); + lock.build_model(build_model) + } + + fn update_model( + &mut self, + handle: &Model, + update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R, + ) -> Self::Result { + let mut lock = self.app.lock(); + lock.update_model(handle, update) + } +} + +impl 
TestAppContext { + pub fn new(dispatcher: TestDispatcher) -> Self { + let executor = Executor::new(Arc::new(dispatcher)); + let platform = Arc::new(TestPlatform::new(executor.clone())); + let asset_source = Arc::new(()); + let http_client = util::http::FakeHttpClient::with_404_response(); + Self { + app: AppContext::new(platform, asset_source, http_client), + executor, + } + } + + pub fn quit(&self) { + self.app.lock().quit(); + } + + pub fn refresh(&mut self) -> Result<()> { + let mut lock = self.app.lock(); + lock.refresh(); + Ok(()) + } + + pub fn executor(&self) -> &Executor { + &self.executor + } + + pub fn update(&self, f: impl FnOnce(&mut AppContext) -> R) -> R { + let mut lock = self.app.lock(); + f(&mut *lock) + } + + pub fn read_window( + &self, + handle: AnyWindowHandle, + read: impl FnOnce(&WindowContext) -> R, + ) -> R { + let mut app_context = self.app.lock(); + app_context.read_window(handle.id, read).unwrap() + } + + pub fn update_window( + &self, + handle: AnyWindowHandle, + update: impl FnOnce(&mut WindowContext) -> R, + ) -> R { + let mut app = self.app.lock(); + app.update_window(handle.id, update).unwrap() + } + + pub fn spawn(&self, f: impl FnOnce(AsyncAppContext) -> Fut + Send + 'static) -> Task + where + Fut: Future + Send + 'static, + R: Send + 'static, + { + let cx = self.to_async(); + self.executor.spawn(async move { f(cx).await }) + } + + pub fn spawn_on_main( + &self, + f: impl FnOnce(AsyncAppContext) -> Fut + Send + 'static, + ) -> Task + where + Fut: Future + 'static, + R: Send + 'static, + { + let cx = self.to_async(); + self.executor.spawn_on_main(|| f(cx)) + } + + pub fn run_on_main( + &self, + f: impl FnOnce(&mut MainThread) -> R + Send + 'static, + ) -> Task + where + R: Send + 'static, + { + let mut app_context = self.app.lock(); + app_context.run_on_main(f) + } + + pub fn has_global(&self) -> bool { + let lock = self.app.lock(); + lock.has_global::() + } + + pub fn read_global(&self, read: impl FnOnce(&G, &AppContext) -> R) -> 
R { + let lock = self.app.lock(); + read(lock.global(), &lock) + } + + pub fn try_read_global( + &self, + read: impl FnOnce(&G, &AppContext) -> R, + ) -> Option { + let lock = self.app.lock(); + Some(read(lock.try_global()?, &lock)) + } + + pub fn update_global( + &mut self, + update: impl FnOnce(&mut G, &mut AppContext) -> R, + ) -> R { + let mut lock = self.app.lock(); + lock.update_global(update) + } + + pub fn to_async(&self) -> AsyncAppContext { + AsyncAppContext { + app: Arc::downgrade(&self.app), + executor: self.executor.clone(), + } + } +} diff --git a/crates/gpui2/src/assets.rs b/crates/gpui2/src/assets.rs new file mode 100644 index 0000000000000000000000000000000000000000..39c8562b69703a959fcbd3ad75bc8a6601b0a839 --- /dev/null +++ b/crates/gpui2/src/assets.rs @@ -0,0 +1,64 @@ +use crate::{size, DevicePixels, Result, SharedString, Size}; +use anyhow::anyhow; +use image::{Bgra, ImageBuffer}; +use std::{ + borrow::Cow, + fmt, + hash::Hash, + sync::atomic::{AtomicUsize, Ordering::SeqCst}, +}; + +pub trait AssetSource: 'static + Send + Sync { + fn load(&self, path: &str) -> Result>; + fn list(&self, path: &str) -> Result>; +} + +impl AssetSource for () { + fn load(&self, path: &str) -> Result> { + Err(anyhow!( + "get called on empty asset provider with \"{}\"", + path + )) + } + + fn list(&self, _path: &str) -> Result> { + Ok(vec![]) + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct ImageId(usize); + +pub struct ImageData { + pub id: ImageId, + data: ImageBuffer, Vec>, +} + +impl ImageData { + pub fn new(data: ImageBuffer, Vec>) -> Self { + static NEXT_ID: AtomicUsize = AtomicUsize::new(0); + + Self { + id: ImageId(NEXT_ID.fetch_add(1, SeqCst)), + data, + } + } + + pub fn as_bytes(&self) -> &[u8] { + &self.data + } + + pub fn size(&self) -> Size { + let (width, height) = self.data.dimensions(); + size(width.into(), height.into()) + } +} + +impl fmt::Debug for ImageData { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> 
fmt::Result { + f.debug_struct("ImageData") + .field("id", &self.id) + .field("size", &self.data.dimensions()) + .finish() + } +} diff --git a/crates/gpui2/src/color.rs b/crates/gpui2/src/color.rs index 11590f967cdba57edebf52a23fd1721cbf1fc899..db072594760f160a303020af20021e05d74db300 100644 --- a/crates/gpui2/src/color.rs +++ b/crates/gpui2/src/color.rs @@ -1,9 +1,8 @@ #![allow(dead_code)] use serde::de::{self, Deserialize, Deserializer, Visitor}; -use smallvec::SmallVec; use std::fmt; -use std::{num::ParseIntError, ops::Range}; +use std::num::ParseIntError; pub fn rgb>(hex: u32) -> C { let r = ((hex >> 16) & 0xFF) as f32 / 255.0; @@ -12,7 +11,15 @@ pub fn rgb>(hex: u32) -> C { Rgba { r, g, b, a: 1.0 }.into() } -#[derive(Clone, Copy, Default, Debug)] +pub fn rgba(hex: u32) -> Rgba { + let r = ((hex >> 24) & 0xFF) as f32 / 255.0; + let g = ((hex >> 16) & 0xFF) as f32 / 255.0; + let b = ((hex >> 8) & 0xFF) as f32 / 255.0; + let a = (hex & 0xFF) as f32 / 255.0; + Rgba { r, g, b, a } +} + +#[derive(Clone, Copy, Default)] pub struct Rgba { pub r: f32, pub g: f32, @@ -20,6 +27,39 @@ pub struct Rgba { pub a: f32, } +impl fmt::Debug for Rgba { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "rgba({:#010x})", u32::from(*self)) + } +} + +impl Rgba { + pub fn blend(&self, other: Rgba) -> Self { + if other.a >= 1.0 { + return other; + } else if other.a <= 0.0 { + return *self; + } else { + return Rgba { + r: (self.r * (1.0 - other.a)) + (other.r * other.a), + g: (self.g * (1.0 - other.a)) + (other.g * other.a), + b: (self.b * (1.0 - other.a)) + (other.b * other.a), + a: self.a, + }; + } + } +} + +impl From for u32 { + fn from(rgba: Rgba) -> Self { + let r = (rgba.r * 255.0) as u32; + let g = (rgba.g * 255.0) as u32; + let b = (rgba.b * 255.0) as u32; + let a = (rgba.a * 255.0) as u32; + (r << 24) | (g << 16) | (b << 8) | a + } +} + struct RgbaVisitor; impl<'de> Visitor<'de> for RgbaVisitor { @@ -54,33 +94,6 @@ impl<'de> Deserialize<'de> for Rgba { } } 
-pub trait Lerp { - fn lerp(&self, level: f32) -> Hsla; -} - -impl Lerp for Range { - fn lerp(&self, level: f32) -> Hsla { - let level = level.clamp(0., 1.); - Hsla { - h: self.start.h + (level * (self.end.h - self.start.h)), - s: self.start.s + (level * (self.end.s - self.start.s)), - l: self.start.l + (level * (self.end.l - self.start.l)), - a: self.start.a + (level * (self.end.a - self.start.a)), - } - } -} - -impl From for Rgba { - fn from(value: gpui::color::Color) -> Self { - Self { - r: value.0.r as f32 / 255.0, - g: value.0.g as f32 / 255.0, - b: value.0.b as f32 / 255.0, - a: value.0.a as f32 / 255.0, - } - } -} - impl From for Rgba { fn from(color: Hsla) -> Self { let h = color.h; @@ -128,13 +141,8 @@ impl TryFrom<&'_ str> for Rgba { } } -impl Into for Rgba { - fn into(self) -> gpui::color::Color { - gpui::color::rgba(self.r, self.g, self.b, self.a) - } -} - #[derive(Default, Copy, Clone, Debug, PartialEq)] +#[repr(C)] pub struct Hsla { pub h: f32, pub s: f32, @@ -142,6 +150,14 @@ pub struct Hsla { pub a: f32, } +impl Hsla { + pub fn to_rgb(self) -> Rgba { + self.into() + } +} + +impl Eq for Hsla {} + pub fn hsla(h: f32, s: f32, l: f32, a: f32) -> Hsla { Hsla { h: h.clamp(0., 1.), @@ -160,6 +176,73 @@ pub fn black() -> Hsla { } } +pub fn white() -> Hsla { + Hsla { + h: 0., + s: 0., + l: 1., + a: 1., + } +} + +pub fn red() -> Hsla { + Hsla { + h: 0., + s: 1., + l: 0.5, + a: 1., + } +} + +impl Hsla { + /// Returns true if the HSLA color is fully transparent, false otherwise. + pub fn is_transparent(&self) -> bool { + self.a == 0.0 + } + + /// Blends `other` on top of `self` based on `other`'s alpha value. The resulting color is a combination of `self`'s and `other`'s colors. + /// + /// If `other`'s alpha value is 1.0 or greater, `other` color is fully opaque, thus `other` is returned as the output color. + /// If `other`'s alpha value is 0.0 or less, `other` color is fully transparent, thus `self` is returned as the output color. 
+ /// Else, the output color is calculated as a blend of `self` and `other` based on their weighted alpha values. + /// + /// Assumptions: + /// - Alpha values are contained in the range [0, 1], with 1 as fully opaque and 0 as fully transparent. + /// - The relative contributions of `self` and `other` is based on `self`'s alpha value (`self.a`) and `other`'s alpha value (`other.a`), `self` contributing `self.a * (1.0 - other.a)` and `other` contributing it's own alpha value. + /// - RGB color components are contained in the range [0, 1]. + /// - If `self` and `other` colors are out of the valid range, the blend operation's output and behavior is undefined. + pub fn blend(self, other: Hsla) -> Hsla { + let alpha = other.a; + + if alpha >= 1.0 { + return other; + } else if alpha <= 0.0 { + return self; + } else { + let converted_self = Rgba::from(self); + let converted_other = Rgba::from(other); + let blended_rgb = converted_self.blend(converted_other); + return Hsla::from(blended_rgb); + } + } + + /// Fade out the color by a given factor. This factor should be between 0.0 and 1.0. + /// Where 0.0 will leave the color unchanged, and 1.0 will completely fade out the color. + pub fn fade_out(&mut self, factor: f32) { + self.a *= 1.0 - factor.clamp(0., 1.); + } +} + +// impl From for Rgba { +// fn from(value: Hsla) -> Self { +// let h = value.h; +// let s = value.s; +// let l = value.l; + +// let c = (1 - |2L - 1|) X s +// } +// } + impl From for Hsla { fn from(color: Rgba) -> Self { let r = color.r; @@ -198,62 +281,6 @@ impl From for Hsla { } } -impl Hsla { - /// Scales the saturation and lightness by the given values, clamping at 1.0. - pub fn scale_sl(mut self, s: f32, l: f32) -> Self { - self.s = (self.s * s).clamp(0., 1.); - self.l = (self.l * l).clamp(0., 1.); - self - } - - /// Increases the saturation of the color by a certain amount, with a max - /// value of 1.0. 
- pub fn saturate(mut self, amount: f32) -> Self { - self.s += amount; - self.s = self.s.clamp(0.0, 1.0); - self - } - - /// Decreases the saturation of the color by a certain amount, with a min - /// value of 0.0. - pub fn desaturate(mut self, amount: f32) -> Self { - self.s -= amount; - self.s = self.s.max(0.0); - if self.s < 0.0 { - self.s = 0.0; - } - self - } - - /// Brightens the color by increasing the lightness by a certain amount, - /// with a max value of 1.0. - pub fn brighten(mut self, amount: f32) -> Self { - self.l += amount; - self.l = self.l.clamp(0.0, 1.0); - self - } - - /// Darkens the color by decreasing the lightness by a certain amount, - /// with a max value of 0.0. - pub fn darken(mut self, amount: f32) -> Self { - self.l -= amount; - self.l = self.l.clamp(0.0, 1.0); - self - } -} - -impl From for Hsla { - fn from(value: gpui::color::Color) -> Self { - Rgba::from(value).into() - } -} - -impl Into for Hsla { - fn into(self) -> gpui::color::Color { - Rgba::from(self).into() - } -} - impl<'de> Deserialize<'de> for Hsla { fn deserialize(deserializer: D) -> Result where @@ -266,59 +293,3 @@ impl<'de> Deserialize<'de> for Hsla { Ok(Hsla::from(rgba)) } } - -pub struct ColorScale { - colors: SmallVec<[Hsla; 2]>, - positions: SmallVec<[f32; 2]>, -} - -pub fn scale(colors: I) -> ColorScale -where - I: IntoIterator, - C: Into, -{ - let mut scale = ColorScale { - colors: colors.into_iter().map(Into::into).collect(), - positions: SmallVec::new(), - }; - let num_colors: f32 = scale.colors.len() as f32 - 1.0; - scale.positions = (0..scale.colors.len()) - .map(|i| i as f32 / num_colors) - .collect(); - scale -} - -impl ColorScale { - fn at(&self, t: f32) -> Hsla { - // Ensure that the input is within [0.0, 1.0] - debug_assert!( - 0.0 <= t && t <= 1.0, - "t value {} is out of range. 
Expected value in range 0.0 to 1.0", - t - ); - - let position = match self - .positions - .binary_search_by(|a| a.partial_cmp(&t).unwrap()) - { - Ok(index) | Err(index) => index, - }; - let lower_bound = position.saturating_sub(1); - let upper_bound = position.min(self.colors.len() - 1); - let lower_color = &self.colors[lower_bound]; - let upper_color = &self.colors[upper_bound]; - - match upper_bound.checked_sub(lower_bound) { - Some(0) | None => *lower_color, - Some(_) => { - let interval_t = (t - self.positions[lower_bound]) - / (self.positions[upper_bound] - self.positions[lower_bound]); - let h = lower_color.h + interval_t * (upper_color.h - lower_color.h); - let s = lower_color.s + interval_t * (upper_color.s - lower_color.s); - let l = lower_color.l + interval_t * (upper_color.l - lower_color.l); - let a = lower_color.a + interval_t * (upper_color.a - lower_color.a); - Hsla { h, s, l, a } - } - } - } -} diff --git a/crates/gpui2/src/element.rs b/crates/gpui2/src/element.rs index 5fb72885857ea48d52f6aec0353b6112e00831aa..a715ed30ee739cee6f0834be5be497415b0ff8b1 100644 --- a/crates/gpui2/src/element.rs +++ b/crates/gpui2/src/element.rs @@ -1,232 +1,284 @@ -pub use crate::ViewContext; -use anyhow::Result; -use gpui::geometry::vector::Vector2F; -pub use gpui::{Layout, LayoutId}; -use smallvec::SmallVec; +use crate::{BorrowWindow, Bounds, ElementId, LayoutId, Pixels, ViewContext}; +use derive_more::{Deref, DerefMut}; +pub(crate) use smallvec::SmallVec; +use std::{any::Any, mem}; -pub trait Element: 'static + IntoElement { - type PaintState; +pub trait Element { + type ElementState: 'static + Send; + + fn id(&self) -> Option; + + /// Called to initialize this element for the current frame. If this + /// element had state in a previous frame, it will be passed in for the 3rd argument. 
+ fn initialize( + &mut self, + view_state: &mut V, + element_state: Option, + cx: &mut ViewContext, + ) -> Self::ElementState; fn layout( &mut self, - view: &mut V, + view_state: &mut V, + element_state: &mut Self::ElementState, cx: &mut ViewContext, - ) -> Result<(LayoutId, Self::PaintState)> - where - Self: Sized; + ) -> LayoutId; fn paint( &mut self, - view: &mut V, - parent_origin: Vector2F, - layout: &Layout, - state: &mut Self::PaintState, + bounds: Bounds, + view_state: &mut V, + element_state: &mut Self::ElementState, cx: &mut ViewContext, - ) where - Self: Sized; + ); +} + +#[derive(Deref, DerefMut, Default, Clone, Debug, Eq, PartialEq, Hash)] +pub struct GlobalElementId(SmallVec<[ElementId; 32]>); + +pub trait ParentElement { + fn children_mut(&mut self) -> &mut SmallVec<[AnyElement; 2]>; - fn into_any(self) -> AnyElement + fn child(mut self, child: impl Component) -> Self where - Self: 'static + Sized, + Self: Sized, { - AnyElement(Box::new(StatefulElement { - element: self, - phase: ElementPhase::Init, - })) - } - - /// Applies a given function `then` to the current element if `condition` is true. - /// This function is used to conditionally modify the element based on a given condition. - /// If `condition` is false, it just returns the current element as it is. - /// - /// # Parameters - /// - `self`: The current element - /// - `condition`: The boolean condition based on which `then` is applied to the element. - /// - `then`: A function that takes in the current element and returns a possibly modified element. - /// - /// # Return - /// It returns the potentially modified element. 
- fn when(mut self, condition: bool, then: impl FnOnce(Self) -> Self) -> Self + self.children_mut().push(child.render()); + self + } + + fn children(mut self, iter: impl IntoIterator>) -> Self where Self: Sized, { - if condition { - self = then(self); - } + self.children_mut() + .extend(iter.into_iter().map(|item| item.render())); self } } -/// Used to make ElementState into a trait object, so we can wrap it in AnyElement. -trait AnyStatefulElement { - fn layout(&mut self, view: &mut V, cx: &mut ViewContext) -> Result; - fn paint(&mut self, view: &mut V, parent_origin: Vector2F, cx: &mut ViewContext); +trait ElementObject { + fn initialize(&mut self, view_state: &mut V, cx: &mut ViewContext); + fn layout(&mut self, view_state: &mut V, cx: &mut ViewContext) -> LayoutId; + fn paint(&mut self, view_state: &mut V, cx: &mut ViewContext); } -/// A wrapper around an element that stores its layout state. -struct StatefulElement> { +struct RenderedElement> { element: E, - phase: ElementPhase, + phase: ElementRenderPhase, } -enum ElementPhase> { - Init, - PostLayout { - layout_id: LayoutId, - paint_state: E::PaintState, +#[derive(Default)] +enum ElementRenderPhase { + #[default] + Start, + Initialized { + frame_state: Option, }, - #[allow(dead_code)] - PostPaint { - layout: Layout, - paint_state: E::PaintState, + LayoutRequested { + layout_id: LayoutId, + frame_state: Option, }, - Error(String), + Painted, } -impl> std::fmt::Debug for ElementPhase { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ElementPhase::Init => write!(f, "Init"), - ElementPhase::PostLayout { layout_id, .. } => { - write!(f, "PostLayout with layout id: {:?}", layout_id) - } - ElementPhase::PostPaint { layout, .. } => { - write!(f, "PostPaint with layout: {:?}", layout) - } - ElementPhase::Error(err) => write!(f, "Error: {}", err), +/// Internal struct that wraps an element to store Layout and ElementState after the element is rendered. 
+/// It's allocated as a trait object to erase the element type and wrapped in AnyElement for +/// improved usability. +impl> RenderedElement { + fn new(element: E) -> Self { + RenderedElement { + element, + phase: ElementRenderPhase::Start, } } } -impl> Default for ElementPhase { - fn default() -> Self { - Self::Init +impl ElementObject for RenderedElement +where + E: Element, + E::ElementState: 'static + Send, +{ + fn initialize(&mut self, view_state: &mut V, cx: &mut ViewContext) { + let frame_state = if let Some(id) = self.element.id() { + cx.with_element_state(id, |element_state, cx| { + let element_state = self.element.initialize(view_state, element_state, cx); + ((), element_state) + }); + None + } else { + let frame_state = self.element.initialize(view_state, None, cx); + Some(frame_state) + }; + + self.phase = ElementRenderPhase::Initialized { frame_state }; } -} -/// We blanket-implement the object-safe ElementStateObject interface to make ElementStates into trait objects -impl> AnyStatefulElement for StatefulElement { - fn layout(&mut self, view: &mut V, cx: &mut ViewContext) -> Result { - let result; - self.phase = match self.element.layout(view, cx) { - Ok((layout_id, paint_state)) => { - result = Ok(layout_id); - ElementPhase::PostLayout { - layout_id, - paint_state, + fn layout(&mut self, state: &mut V, cx: &mut ViewContext) -> LayoutId { + let layout_id; + let mut frame_state; + match mem::take(&mut self.phase) { + ElementRenderPhase::Initialized { + frame_state: initial_frame_state, + } => { + frame_state = initial_frame_state; + if let Some(id) = self.element.id() { + layout_id = cx.with_element_state(id, |element_state, cx| { + let mut element_state = element_state.unwrap(); + let layout_id = self.element.layout(state, &mut element_state, cx); + (layout_id, element_state) + }); + } else { + layout_id = self + .element + .layout(state, frame_state.as_mut().unwrap(), cx); } } - Err(error) => { - let message = error.to_string(); - result = 
Err(error); - ElementPhase::Error(message) - } + _ => panic!("must call initialize before layout"), + }; + + self.phase = ElementRenderPhase::LayoutRequested { + layout_id, + frame_state, }; - result + layout_id } - fn paint(&mut self, view: &mut V, parent_origin: Vector2F, cx: &mut ViewContext) { - self.phase = match std::mem::take(&mut self.phase) { - ElementPhase::PostLayout { + fn paint(&mut self, view_state: &mut V, cx: &mut ViewContext) { + self.phase = match mem::take(&mut self.phase) { + ElementRenderPhase::LayoutRequested { layout_id, - mut paint_state, - } => match cx.computed_layout(layout_id) { - Ok(layout) => { - self.element - .paint(view, parent_origin, &layout, &mut paint_state, cx); - ElementPhase::PostPaint { - layout, - paint_state, - } - } - Err(error) => ElementPhase::Error(error.to_string()), - }, - ElementPhase::PostPaint { - layout, - mut paint_state, + mut frame_state, } => { - self.element - .paint(view, parent_origin, &layout, &mut paint_state, cx); - ElementPhase::PostPaint { - layout, - paint_state, + let bounds = cx.layout_bounds(layout_id); + if let Some(id) = self.element.id() { + cx.with_element_state(id, |element_state, cx| { + let mut element_state = element_state.unwrap(); + self.element + .paint(bounds, view_state, &mut element_state, cx); + ((), element_state) + }); + } else { + self.element + .paint(bounds, view_state, frame_state.as_mut().unwrap(), cx); } + ElementRenderPhase::Painted } - phase @ ElementPhase::Error(_) => phase, - phase @ _ => { - panic!("invalid element phase to call paint: {:?}", phase); - } + _ => panic!("must call layout before paint"), }; } } -/// A dynamic element. 
-pub struct AnyElement(Box>); +pub struct AnyElement(Box + Send>); + +unsafe impl Send for AnyElement {} impl AnyElement { - pub fn layout(&mut self, view: &mut V, cx: &mut ViewContext) -> Result { - self.0.layout(view, cx) + pub fn new(element: E) -> Self + where + V: 'static, + E: 'static + Element + Send, + E::ElementState: Any + Send, + { + AnyElement(Box::new(RenderedElement::new(element))) + } + + pub fn initialize(&mut self, view_state: &mut V, cx: &mut ViewContext) { + self.0.initialize(view_state, cx); } - pub fn paint(&mut self, view: &mut V, parent_origin: Vector2F, cx: &mut ViewContext) { - self.0.paint(view, parent_origin, cx) + pub fn layout(&mut self, view_state: &mut V, cx: &mut ViewContext) -> LayoutId { + self.0.layout(view_state, cx) + } + + pub fn paint(&mut self, view_state: &mut V, cx: &mut ViewContext) { + self.0.paint(view_state, cx) } } -pub trait ParentElement { - fn children_mut(&mut self) -> &mut SmallVec<[AnyElement; 2]>; +pub trait Component { + fn render(self) -> AnyElement; - fn child(mut self, child: impl IntoElement) -> Self + fn when(mut self, condition: bool, then: impl FnOnce(Self) -> Self) -> Self where Self: Sized, { - self.children_mut().push(child.into_element().into_any()); + if condition { + self = then(self); + } self } +} - fn children(mut self, children: I) -> Self - where - I: IntoIterator, - E: IntoElement, - Self: Sized, - { - self.children_mut().extend( - children - .into_iter() - .map(|child| child.into_element().into_any()), - ); +impl Component for AnyElement { + fn render(self) -> AnyElement { self } +} - // HACK: This is a temporary hack to get children working for the purposes - // of building UI on top of the current version of gpui2. - // - // We'll (hopefully) be moving away from this in the future. 
- fn children_any(mut self, children: I) -> Self - where - I: IntoIterator>, - Self: Sized, - { - self.children_mut().extend(children.into_iter()); - self +impl Element for Option +where + V: 'static, + E: 'static + Component + Send, + F: FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static, +{ + type ElementState = AnyElement; + + fn id(&self) -> Option { + None } - // HACK: This is a temporary hack to get children working for the purposes - // of building UI on top of the current version of gpui2. - // - // We'll (hopefully) be moving away from this in the future. - fn child_any(mut self, children: AnyElement) -> Self - where - Self: Sized, - { - self.children_mut().push(children); - self + fn initialize( + &mut self, + view_state: &mut V, + _rendered_element: Option, + cx: &mut ViewContext, + ) -> Self::ElementState { + let render = self.take().unwrap(); + let mut rendered_element = (render)(view_state, cx).render(); + rendered_element.initialize(view_state, cx); + rendered_element + } + + fn layout( + &mut self, + view_state: &mut V, + rendered_element: &mut Self::ElementState, + cx: &mut ViewContext, + ) -> LayoutId { + rendered_element.layout(view_state, cx) + } + + fn paint( + &mut self, + _bounds: Bounds, + view_state: &mut V, + rendered_element: &mut Self::ElementState, + cx: &mut ViewContext, + ) { + rendered_element.paint(view_state, cx) } } -pub trait IntoElement { - type Element: Element; +impl Component for Option +where + V: 'static, + E: 'static + Component + Send, + F: FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static, +{ + fn render(self) -> AnyElement { + AnyElement::new(self) + } +} - fn into_element(self) -> Self::Element; +impl Component for F +where + V: 'static, + E: 'static + Component + Send, + F: FnOnce(&mut V, &mut ViewContext<'_, '_, V>) -> E + Send + 'static, +{ + fn render(self) -> AnyElement { + AnyElement::new(Some(self)) + } } diff --git a/crates/gpui2/src/elements.rs b/crates/gpui2/src/elements.rs 
index 5b4942fc467cadfcd46de2e472db80bc638047e0..83c27b8a3b1a88e6dee517c485f5f99ea1df7a93 100644 --- a/crates/gpui2/src/elements.rs +++ b/crates/gpui2/src/elements.rs @@ -1,10 +1,9 @@ -pub mod div; -pub mod hoverable; +mod div; mod img; -pub mod pressable; -pub mod svg; -pub mod text; +mod svg; +mod text; -pub use div::div; -pub use img::img; -pub use svg::svg; +pub use div::*; +pub use img::*; +pub use svg::*; +pub use text::*; diff --git a/crates/gpui2/src/elements/div.rs b/crates/gpui2/src/elements/div.rs index 885b14f2dd5e23ecbe0e6b7eff5f27687fb48de1..6fe10d94a31324984df8431c13a0747d704110b4 100644 --- a/crates/gpui2/src/elements/div.rs +++ b/crates/gpui2/src/elements/div.rs @@ -1,320 +1,354 @@ -use std::{cell::Cell, rc::Rc}; - use crate::{ - element::{AnyElement, Element, IntoElement, Layout, ParentElement}, - hsla, - style::{CornerRadii, Overflow, Style, StyleHelpers, Styleable}, - InteractionHandlers, Interactive, ViewContext, -}; -use anyhow::Result; -use gpui::{ - geometry::{rect::RectF, vector::Vector2F, Point}, - platform::{MouseButton, MouseButtonEvent, MouseMovedEvent, ScrollWheelEvent}, - scene::{self}, - LayoutId, + point, AnyElement, BorrowWindow, Bounds, Component, Element, ElementFocus, ElementId, + ElementInteraction, FocusDisabled, FocusEnabled, FocusHandle, FocusListeners, Focusable, + GlobalElementId, GroupBounds, InteractiveElementState, LayoutId, Overflow, ParentElement, + Pixels, Point, SharedString, StatefulInteraction, StatefulInteractive, StatelessInteraction, + StatelessInteractive, Style, StyleRefinement, Styled, ViewContext, }; -use refineable::{Refineable, RefinementCascade}; +use refineable::Refineable; use smallvec::SmallVec; -use util::ResultExt; -pub struct Div { - styles: RefinementCascade