diff --git a/.github/ISSUE_TEMPLATE/07_bug_windows_alpha.yml b/.github/ISSUE_TEMPLATE/07_bug_windows_beta.yml similarity index 86% rename from .github/ISSUE_TEMPLATE/07_bug_windows_alpha.yml rename to .github/ISSUE_TEMPLATE/07_bug_windows_beta.yml index 826c2b8027144d4b658108e09c79e40490c3005d..b2b2a0f9dfcd5ddaa0dda41650864b053c5bb933 100644 --- a/.github/ISSUE_TEMPLATE/07_bug_windows_alpha.yml +++ b/.github/ISSUE_TEMPLATE/07_bug_windows_beta.yml @@ -1,8 +1,8 @@ -name: Bug Report (Windows Alpha) -description: Zed Windows Alpha Related Bugs +name: Bug Report (Windows Beta) +description: Zed Windows Beta Related Bugs type: "Bug" labels: ["windows"] -title: "Windows Alpha: " +title: "Windows Beta: " body: - type: textarea attributes: diff --git a/.github/workflows/community_release_actions.yml b/.github/workflows/community_release_actions.yml index 37ade90574e76cd95755aad6b5601a43946a271c..0f7a73649e9e1180c78a66ddf54055bf66f243f9 100644 --- a/.github/workflows/community_release_actions.yml +++ b/.github/workflows/community_release_actions.yml @@ -1,3 +1,6 @@ +# IF YOU UPDATE THE NAME OF ANY GITHUB SECRET, YOU MUST CHERRY PICK THE COMMIT +# TO BOTH STABLE AND PREVIEW CHANNELS + name: Release Actions on: @@ -13,9 +16,9 @@ jobs: id: get-release-url run: | if [ "${{ github.event.release.prerelease }}" == "true" ]; then - URL="https://zed.dev/releases/preview/latest" + URL="https://zed.dev/releases/preview" else - URL="https://zed.dev/releases/stable/latest" + URL="https://zed.dev/releases/stable" fi echo "URL=$URL" >> "$GITHUB_OUTPUT" diff --git a/.rules b/.rules index 2f2b9cd705d95775bedf092bc4e6254136da6117..82d15eb9e88299ee7c7fe6c717b2da2646e676a7 100644 --- a/.rules +++ b/.rules @@ -59,7 +59,7 @@ Trying to update an entity while it's already being updated must be avoided as t When `read_with`, `update`, or `update_in` are used with an async context, the closure's return value is wrapped in an `anyhow::Result`. -`WeakEntity` is a weak handle. It has `read_with`, `update`, and `update_in` methods that work the same, but always return an `anyhow::Result` so that they can fail if the entity no longer exists. This can be useful to avoid memory leaks - if entities have mutually recursive handles to eachother they will never be dropped. +`WeakEntity` is a weak handle. It has `read_with`, `update`, and `update_in` methods that work the same, but always return an `anyhow::Result` so that they can fail if the entity no longer exists. This can be useful to avoid memory leaks - if entities have mutually recursive handles to each other they will never be dropped. ## Concurrency diff --git a/Cargo.lock b/Cargo.lock index f7ef8f5ccc67f504da11872b7171817a72b6a116..da94c746e7fc528e12dce31b547c13d248055d66 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -39,7 +39,6 @@ dependencies = [ "util", "uuid", "watch", - "which 6.0.3", "workspace-hack", ] @@ -196,12 +195,13 @@ dependencies = [ [[package]] name = "agent-client-protocol" -version = "0.2.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "003fb91bf1b8d6e15f72c45fb9171839af8241e81e3839fbb73536af113b7a79" +checksum = "cc2526e80463b9742afed4829aedd6ae5632d6db778c6cc1fecb80c960c3521b" dependencies = [ "anyhow", "async-broadcast", + "async-trait", "futures 0.3.31", "log", "parking_lot", @@ -294,6 +294,7 @@ dependencies = [ "agent-client-protocol", "agent_settings", "anyhow", + "async-trait", "client", "collections", "env_logger 0.11.8", @@ -301,6 +302,7 @@ dependencies = [ "futures 0.3.31", "gpui", "gpui_tokio", + "http_client", "indoc", "language", "language_model", @@ -416,7 +418,6 @@ dependencies = [ "serde_json", "serde_json_lenient", "settings", - "shlex", "smol", "streaming_diff", "task", @@ -690,6 +691,9 @@ name = "arbitrary" version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" +dependencies = [ + "derive_arbitrary", +] [[package]] name = "arc-swap" @@ -898,7 +902,6 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_slash_command", - "cargo_toml", "chrono", "collections", "context_server", @@ -921,7 +924,6 @@ dependencies = [ "settings", "smol", "text", - "toml 0.8.20", "ui", "util", "workspace", @@ -1025,7 +1027,6 @@ dependencies = [ "util", "watch", "web_search", - "which 6.0.3", "workspace", "workspace-hack", "zlog", @@ -2688,6 +2689,53 @@ dependencies = [ "serde", ] +[[package]] +name = "candle-core" +version = "0.9.1" +source = "git+https://github.com/zed-industries/candle?branch=9.1-patched#724d75eb3deebefe83f2a7381a45d4fac6eda383" +dependencies = [ + "byteorder", + "float8", + "gemm 0.17.1", + "half", + "memmap2", + "num-traits", + "num_cpus", + "rand 0.9.1", + "rand_distr", + "rayon", + "safetensors", + "thiserror 1.0.69", + "ug", + "yoke", + "zip 1.1.4", +] + +[[package]] +name = "candle-nn" +version = "0.9.1" +source = "git+https://github.com/zed-industries/candle?branch=9.1-patched#724d75eb3deebefe83f2a7381a45d4fac6eda383" +dependencies = [ + "candle-core", + "half", + "libc", + "num-traits", + "rayon", + "safetensors", + "serde", + "thiserror 1.0.69", +] + +[[package]] +name = "candle-onnx" +version = "0.9.1" +source = "git+https://github.com/zed-industries/candle?branch=9.1-patched#724d75eb3deebefe83f2a7381a45d4fac6eda383" +dependencies = [ + "candle-core", + "candle-nn", + "prost 0.12.6", +] + [[package]] name = "cap-fs-ext" version = "3.4.4" @@ -2930,7 +2978,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link", + "windows-link 0.1.1", ] [[package]] @@ -4674,6 +4722,20 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da692b8d1080ea3045efaab14434d40468c3d8657e42abddfffca87b428f4c1b" +[[package]] +name = "denoise" +version = "0.1.0" +dependencies = [ + "candle-core", + "candle-onnx", + "log", + "realfft", + "rodio", + "rustfft", + "thiserror 2.0.12", + "workspace-hack", +] + [[package]] name = "der" version = "0.6.1" @@ -4705,6 +4767,17 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "derive_more" version = "0.99.19" @@ -5018,6 +5091,25 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" +[[package]] +name = "dyn-stack" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" +dependencies = [ + "bytemuck", + "reborrow", +] + +[[package]] +name = "dyn-stack" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490bd48eb68fffcfed519b4edbfd82c69cbe741d175b84f0e0cbe8c57cbe0bdd" +dependencies = [ + "bytemuck", +] + [[package]] name = "ec4rs" version = "1.2.0" @@ -5079,6 +5171,30 @@ dependencies = [ "zeta", ] +[[package]] +name = "edit_prediction_context" +version = "0.1.0" +dependencies = [ + "anyhow", + "arrayvec", + "collections", + "futures 0.3.31", + "gpui", + "indoc", + "language", + "log", + "pretty_assertions", + "project", + "serde_json", + "settings", + "slotmap", + "text", + "tree-sitter", + "util", + "workspace-hack", + "zlog", +] + [[package]] name = "editor" version = "0.1.0" @@ -5262,6 +5378,18 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3d8a32ae18130a3c84dd492d4215c3d913c3b07c6b63c2eb3eb7ff1101ab7bf" +[[package]] +name = "enum-as-inner" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "enumflags2" version = "0.7.11" @@ -5892,6 +6020,18 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" +[[package]] +name = "float8" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4203231de188ebbdfb85c11f3c20ca2b063945710de04e7b59268731e728b462" +dependencies = [ + "half", + "num-traits", + "rand 0.9.1", + "rand_distr", +] + [[package]] name = "float_next_after" version = "1.0.0" @@ -6346,6 +6486,243 @@ dependencies = [ "thread_local", ] +[[package]] +name = "gemm" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32" +dependencies = [ + "dyn-stack 0.10.0", + "gemm-c32 0.17.1", + "gemm-c64 0.17.1", + "gemm-common 0.17.1", + "gemm-f16 0.17.1", + "gemm-f32 0.17.1", + "gemm-f64 0.17.1", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 10.7.0", + "seq-macro", +] + +[[package]] +name = "gemm" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" +dependencies = [ + "dyn-stack 0.13.0", + "gemm-c32 0.18.2", + "gemm-c64 0.18.2", + "gemm-common 0.18.2", + "gemm-f16 0.18.2", + "gemm-f32 0.18.2", + "gemm-f64 0.18.2", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 11.6.0", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0" +dependencies = [ + "dyn-stack 0.10.0", + "gemm-common 0.17.1", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 10.7.0", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847" +dependencies = [ + "dyn-stack 0.13.0", + "gemm-common 0.18.2", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 11.6.0", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a" +dependencies = [ + "dyn-stack 0.10.0", + "gemm-common 0.17.1", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 10.7.0", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfcad8a3d35a43758330b635d02edad980c1e143dc2f21e6fd25f9e4eada8edf" +dependencies = [ + "dyn-stack 0.13.0", + "gemm-common 0.18.2", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 11.6.0", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8" +dependencies = [ + "bytemuck", + "dyn-stack 0.10.0", + "half", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp 0.18.22", + "raw-cpuid 10.7.0", + "rayon", + "seq-macro", + "sysctl 0.5.5", +] + +[[package]] +name = "gemm-common" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a352d4a69cbe938b9e2a9cb7a3a63b7e72f9349174a2752a558a8a563510d0f3" +dependencies = [ + "bytemuck", + "dyn-stack 0.13.0", + "half", + "libm", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp 0.21.5", + "raw-cpuid 11.6.0", + "rayon", + "seq-macro", + "sysctl 0.6.0", +] + +[[package]] +name = "gemm-f16" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4" +dependencies = [ + "dyn-stack 0.10.0", + "gemm-common 0.17.1", + "gemm-f32 0.17.1", + "half", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 10.7.0", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f16" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff95ae3259432f3c3410eaa919033cd03791d81cebd18018393dc147952e109" +dependencies = [ + "dyn-stack 0.13.0", + "gemm-common 0.18.2", + "gemm-f32 0.18.2", + "half", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 11.6.0", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113" +dependencies = [ + "dyn-stack 0.10.0", + "gemm-common 0.17.1", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 10.7.0", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc8d3d4385393304f407392f754cd2dc4b315d05063f62cf09f47b58de276864" +dependencies = [ + "dyn-stack 0.13.0", + "gemm-common 0.18.2", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 11.6.0", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0" +dependencies = [ + "dyn-stack 0.10.0", + "gemm-common 0.17.1", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 10.7.0", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35b2a4f76ce4b8b16eadc11ccf2e083252d8237c1b589558a49b0183545015bd" +dependencies = [ + "dyn-stack 0.13.0", + "gemm-common 0.18.2", + "num-complex", + "num-traits", + "paste", + "raw-cpuid 11.6.0", + "seq-macro", +] + [[package]] name = "generator" version = "0.8.5" @@ -7621,9 +7998,12 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ + "bytemuck", "cfg-if", "crunchy", "num-traits", + "rand 0.9.1", + "rand_distr", ] [[package]] @@ -9236,6 +9616,7 @@ dependencies = [ "credentials_provider", "deepseek", "editor", + "fs", "futures 0.3.31", "google_ai", "gpui", @@ -9269,6 +9650,7 @@ dependencies = [ "vercel", "workspace-hack", "x_ai", + "zed_env_vars", ] [[package]] @@ -9362,6 +9744,7 @@ dependencies = [ "pet-fs", "pet-poetry", "pet-reporter", + "pet-virtualenv", "pretty_assertions", "project", "regex", @@ -10231,6 +10614,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" dependencies = [ "libc", + "stable_deref_trait", ] [[package]] @@ -10497,12 +10881,6 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" -[[package]] -name = "multimap" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" - [[package]] name = "naga" version = "25.0.1" @@ -10876,6 +11254,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ + "bytemuck", "num-traits", ] @@ -12572,6 +12951,15 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "primal-check" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08" +dependencies = [ + "num-integer", +] + [[package]] name = "proc-macro-crate" version = "3.3.0" @@ -12880,7 +13268,7 @@ dependencies = [ "itertools 0.10.5", "lazy_static", "log", - "multimap 0.8.3", + "multimap", "petgraph", "prost 0.9.0", "prost-types 0.9.0", @@ -12899,7 +13287,7 @@ dependencies = [ "heck 0.5.0", "itertools 0.12.1", "log", - "multimap 0.10.0", + "multimap", "once_cell", "petgraph", "prettyplease", @@ -13071,6 +13459,32 @@ dependencies = [ "wasmtime-math", ] +[[package]] +name = "pulp" +version = "0.18.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6" +dependencies = [ + "bytemuck", + "libm", + "num-complex", + "reborrow", +] + +[[package]] +name = "pulp" +version = "0.21.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b86df24f0a7ddd5e4b95c94fc9ed8a98f1ca94d3b01bdce2824097e7835907" +dependencies = [ + "bytemuck", + "cfg-if", + "libm", + "num-complex", + "reborrow", + "version_check", +] + [[package]] name = "qoi" version = "0.4.1" @@ -13247,6 +13661,16 @@ dependencies = [ "getrandom 0.3.2", ] +[[package]] +name = "rand_distr" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" +dependencies = [ + "num-traits", + "rand 0.9.1", +] + [[package]] name = "range-map" version = "0.2.0" @@ -13312,6 +13736,24 @@ dependencies = [ "rgb", ] +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags 2.9.0", +] + [[package]] name = "raw-window-handle" version = "0.6.2" @@ -13360,6 +13802,21 @@ dependencies = [ "font-types", ] +[[package]] +name = "realfft" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f821338fddb99d089116342c46e9f1fbf3828dba077674613e734e01d6ea8677" +dependencies = [ + "rustfft", +] + +[[package]] +name = "reborrow" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + [[package]] name = "recent_projects" version = "0.1.0" @@ -13373,6 +13830,7 @@ dependencies = [ "futures 0.3.31", "fuzzy", "gpui", + "indoc", "language", "log", "markdown", @@ -13393,6 +13851,7 @@ dependencies = [ "theme", "ui", "util", + "windows-registry 0.6.0", "workspace", "workspace-hack", "zed_actions", @@ -14177,6 +14636,20 @@ dependencies = [ "semver", ] +[[package]] +name = "rustfft" +version = "6.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6f140db74548f7c9d7cce60912c9ac414e74df5e718dc947d514b051b42f3f4" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", +] + [[package]] name = "rustix" version = "0.38.44" @@ -14401,6 +14874,16 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "safetensors" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "salsa20" version = "0.10.2" @@ -14794,6 +15277,12 @@ dependencies = [ "serde", ] +[[package]] +name = "seq-macro" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" + [[package]] name = "serde" version = "1.0.221" @@ -15760,6 +16249,12 @@ dependencies = [ "workspace-hack", ] +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "strict-num" version = "0.1.1" @@ -16250,6 +16745,34 @@ dependencies = [ "libc", ] +[[package]] +name = "sysctl" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" +dependencies = [ + "bitflags 2.9.0", + "byteorder", + "enum-as-inner", + "libc", + "thiserror 1.0.69", + "walkdir", +] + +[[package]] +name = "sysctl" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc" +dependencies = [ + "bitflags 2.9.0", + "byteorder", + "enum-as-inner", + "libc", + "thiserror 1.0.69", + "walkdir", +] + [[package]] name = "sysinfo" version = "0.31.4" @@ -17345,6 +17868,16 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "tree-sitter" version = "0.25.6" @@ -17706,6 +18239,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "ug" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90b70b37e9074642bc5f60bb23247fd072a84314ca9e71cdf8527593406a0dd3" +dependencies = [ + "gemm 0.18.2", + "half", + "libloading", + "memmap2", + "num", + "num-traits", + "num_cpus", + "rayon", + "safetensors", + "serde", + "thiserror 1.0.69", + "tracing", + "yoke", +] + [[package]] name = "ui" version = "0.1.0" @@ -18972,7 +19526,7 @@ dependencies = [ "reqwest 0.11.27", "scratch", "semver", - "zip", + "zip 0.6.6", ] [[package]] @@ -19145,7 +19699,7 @@ dependencies = [ "windows-collections", "windows-core 0.61.0", "windows-future", - "windows-link", + "windows-link 0.1.1", "windows-numerics", ] @@ -19215,7 +19769,7 @@ checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980" dependencies = [ "windows-implement 0.60.0", "windows-interface 0.59.1", - "windows-link", + "windows-link 0.1.1", "windows-result 0.3.2", "windows-strings 0.4.0", ] @@ -19227,7 +19781,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a1d6bbefcb7b60acd19828e1bc965da6fcf18a7e39490c5f8be71e54a19ba32" dependencies = [ "windows-core 0.61.0", - "windows-link", + "windows-link 0.1.1", ] [[package]] @@ -19302,6 +19856,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" +[[package]] +name = "windows-link" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" + [[package]] name = "windows-numerics" version = "0.2.0" @@ -19309,7 +19869,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" dependencies = [ "windows-core 0.61.0", - "windows-link", + "windows-link 0.1.1", ] [[package]] @@ -19329,11 +19889,22 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad1da3e436dc7653dfdf3da67332e22bff09bb0e28b0239e1624499c7830842e" dependencies = [ - "windows-link", + "windows-link 0.1.1", "windows-result 0.3.2", "windows-strings 0.4.0", ] +[[package]] +name = "windows-registry" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f91f87ce112ffb7275000ea98eb1940912c21c1567c9312fde20261f3eadd29" +dependencies = [ + "windows-link 0.2.0", + "windows-result 0.4.0", + "windows-strings 0.5.0", +] + [[package]] name = "windows-result" version = "0.1.2" @@ -19358,7 +19929,16 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" dependencies = [ - "windows-link", + "windows-link 0.1.1", +] + +[[package]] +name = "windows-result" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7084dcc306f89883455a206237404d3eaf961e5bd7e0f312f7c91f57eb44167f" +dependencies = [ + "windows-link 0.2.0", ] [[package]] @@ -19377,7 +19957,7 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319" dependencies = [ - "windows-link", + "windows-link 0.1.1", ] [[package]] @@ -19386,7 +19966,16 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2ba9642430ee452d5a7aa78d72907ebe8cfda358e8cb7918a2050581322f97" dependencies = [ - "windows-link", + "windows-link 0.1.1", +] + +[[package]] +name = "windows-strings" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7218c655a553b0bed4426cf54b20d7ba363ef543b52d515b3e48d7fd55318dda" +dependencies = [ + "windows-link 0.2.0", ] [[package]] @@ -20131,6 +20720,7 @@ dependencies = [ "lyon_path", "md-5", "memchr", + "memmap2", "mime_guess", "miniz_oxide", "mio 1.0.3", @@ -20139,8 +20729,10 @@ dependencies = [ "nix 0.29.0", "nix 0.30.1", "nom 7.1.3", + "num", "num-bigint", "num-bigint-dig", + "num-complex", "num-integer", "num-iter", "num-rational", @@ -20156,6 +20748,7 @@ dependencies = [ "phf_shared", "prettyplease", "proc-macro2", + "prost 0.12.6", "prost 0.9.0", "prost-types 0.9.0", "quote", @@ -20163,6 +20756,7 @@ dependencies = [ "rand 0.9.1", "rand_chacha 0.3.1", "rand_core 0.6.4", + "rand_distr", "regalloc2", "regex", "regex-automata", @@ -20192,6 +20786,7 @@ dependencies = [ "sqlx-macros-core", "sqlx-postgres", "sqlx-sqlite", + "stable_deref_trait", "strum 0.26.3", "subtle", "syn 1.0.109", @@ -20225,6 +20820,7 @@ dependencies = [ "windows-sys 0.48.0", "windows-sys 0.52.0", "windows-sys 0.59.0", + "windows-sys 0.60.2", "winnow", "zeroize", "zvariant", @@ -20594,7 +21190,7 @@ dependencies = [ [[package]] name = "zed" -version = "0.205.0" +version = "0.206.0" dependencies = [ "acp_tools", "activity_indicator", @@ -20760,6 +21356,7 @@ dependencies = [ name = "zed_env_vars" version = "0.1.0" dependencies = [ + "gpui", "workspace-hack", ] @@ -21073,6 +21670,21 @@ dependencies = [ "zstd", ] +[[package]] +name = "zip" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cc23c04387f4da0374be4533ad1208cbb091d5c11d070dfef13676ad6497164" +dependencies = [ + "arbitrary", + "crc32fast", + "crossbeam-utils", + "displaydoc", + "indexmap 2.9.0", + "num_enum", + "thiserror 1.0.69", +] + [[package]] name = "zlib-rs" version = "0.5.0" diff --git a/Cargo.toml b/Cargo.toml index 846b0e32ee61662efa2026c116b8beee87495bcf..08a9b41315c36d7facd7b9d1751b949a2577395c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,10 +52,12 @@ members = [ "crates/debugger_tools", "crates/debugger_ui", "crates/deepseek", + "crates/denoise", "crates/diagnostics", "crates/docs_preprocessor", "crates/edit_prediction", "crates/edit_prediction_button", + "crates/edit_prediction_context", "crates/editor", "crates/eval", "crates/explorer_command_injector", @@ -311,6 +313,7 @@ icons = { path = "crates/icons" } image_viewer = { path = "crates/image_viewer" } edit_prediction = { path = "crates/edit_prediction" } edit_prediction_button = { path = "crates/edit_prediction_button" } +edit_prediction_context = { path = "crates/edit_prediction_context" } inspector_ui = { path = "crates/inspector_ui" } install_cli = { path = "crates/install_cli" } jj = { path = "crates/jj" } @@ -433,7 +436,7 @@ zlog_settings = { path = "crates/zlog_settings" } # External crates # -agent-client-protocol = { version = "0.2.1", features = ["unstable"] } +agent-client-protocol = { version = "0.4.0", features = ["unstable"] } aho-corasick = "1.1" alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" } any_vec = "0.14" @@ -581,6 +584,7 @@ pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", re pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } +pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" } portable-pty = "0.9.0" postage = { version = "0.5", features = ["futures-traits"] } pretty_assertions = { version = "1.3.0", features = ["unstable"] } @@ -630,6 +634,7 @@ sha2 = "0.10" shellexpand = "2.1.0" shlex = "1.3.0" simplelog = "0.12.2" +slotmap = "1.0.6" smallvec = { version = "1.6", features = ["union"] } smol = "2.0" sqlformat = "0.2" diff --git a/assets/icons/linux.svg b/assets/icons/linux.svg new file mode 100644 index 0000000000000000000000000000000000000000..fc76742a3f236650cb8c514c8263ec2c3b2d4521 --- /dev/null +++ b/assets/icons/linux.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 6b4c4e0fac95cf751c21cfaa0770d1279a35adcc..8ca0a5d42094db8b4b37c7e6919da0f7a6bd41db 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -462,8 +462,8 @@ "ctrl-k ctrl-w": "workspace::CloseAllItemsAndPanes", "back": "pane::GoBack", "ctrl-alt--": "pane::GoBack", - "ctrl-alt-_": "pane::GoForward", "forward": "pane::GoForward", + "ctrl-alt-_": "pane::GoForward", "ctrl-alt-g": "search::SelectNextMatch", "f3": "search::SelectNextMatch", "ctrl-alt-shift-g": "search::SelectPreviousMatch", diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json index e5839964ad545f3994d675da817a5f4571b88db4..78d5e4e698daefee5a57b04d6a8548fb948233b1 100644 --- a/assets/keymaps/default-windows.json +++ b/assets/keymaps/default-windows.json @@ -497,6 +497,8 @@ "shift-alt-down": "editor::DuplicateLineDown", "shift-alt-right": "editor::SelectLargerSyntaxNode", // Expand selection "shift-alt-left": "editor::SelectSmallerSyntaxNode", // Shrink selection + "ctrl-shift-right": "editor::SelectLargerSyntaxNode", // Expand selection (VSCode version) + "ctrl-shift-left": "editor::SelectSmallerSyntaxNode", // Shrink selection (VSCode version) "ctrl-shift-l": "editor::SelectAllMatches", // Select all occurrences of current selection "ctrl-f2": "editor::SelectAllMatches", // Select all occurrences of current word "ctrl-d": ["editor::SelectNext", { "replace_newest": false }], // editor.action.addSelectionToNextFindMatch / find_under_expand diff --git a/assets/settings/default.json b/assets/settings/default.json index 78fdc3d38d2e33febcb186e1edf68c3a40b01d66..e3e81b83b0ac809e7bc557796ee0879ee9d33cf8 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -311,7 +311,7 @@ // bracket, brace, single or double quote characters. // For example, when you select text and type (, Zed will surround the text with (). "use_auto_surround": true, - /// Whether indentation should be adjusted based on the context whilst typing. + // Whether indentation should be adjusted based on the context whilst typing. "auto_indent": true, // Whether indentation of pasted content should be adjusted based on the context. "auto_indent_on_paste": true, @@ -409,18 +409,18 @@ "show_menus": false }, "audio": { - /// Opt into the new audio system. + // Opt into the new audio system. "experimental.rodio_audio": false, - /// Requires 'rodio_audio: true' - /// - /// Use the new audio systems automatic gain control for your microphone. - /// This affects how loud you sound to others. + // Requires 'rodio_audio: true' + // + // Use the new audio systems automatic gain control for your microphone. + // This affects how loud you sound to others. "experimental.control_input_volume": false, - /// Requires 'rodio_audio: true' - /// - /// Use the new audio systems automatic gain control on everyone in the - /// call. This makes call members who are too quite louder and those who are - /// too loud quieter. This only affects how things sound for you. + // Requires 'rodio_audio: true' + // + // Use the new audio systems automatic gain control on everyone in the + // call. This makes call members who are too quite louder and those who are + // too loud quieter. This only affects how things sound for you. "experimental.control_output_volume": false }, // Scrollbar related settings @@ -812,7 +812,7 @@ "agent": { // Whether the agent is enabled. "enabled": true, - /// What completion mode to start new threads in, if available. Can be 'normal' or 'burn'. + // What completion mode to start new threads in, if available. Can be 'normal' or 'burn'. "preferred_completion_mode": "normal", // Whether to show the agent panel button in the status bar. "button": true, @@ -925,18 +925,22 @@ // Default: false "play_sound_when_agent_done": false, - /// Whether to have edit cards in the agent panel expanded, showing a preview of the full diff. - /// - /// Default: true + // Whether to have edit cards in the agent panel expanded, showing a preview of the full diff. + // + // Default: true "expand_edit_card": true, - /// Whether to have terminal cards in the agent panel expanded, showing the whole command output. - /// - /// Default: true + // Whether to have terminal cards in the agent panel expanded, showing the whole command output. + // + // Default: true "expand_terminal_card": true, - /// Whether to always use cmd-enter (or ctrl-enter on Linux or Windows) to send messages in the agent panel. - /// - /// Default: false - "use_modifier_to_send": false + // Whether to always use cmd-enter (or ctrl-enter on Linux or Windows) to send messages in the agent panel. + // + // Default: false + "use_modifier_to_send": false, + // Minimum number of lines to display in the agent message editor. + // + // Default: 4 + "message_editor_min_lines": 4 }, // Whether the screen sharing icon is shown in the os status bar. "show_call_status_icon": true, @@ -1821,12 +1825,12 @@ "zed.dev": {} }, "session": { - /// Whether or not to restore unsaved buffers on restart. - /// - /// If this is true, user won't be prompted whether to save/discard - /// dirty files when closing the application. - /// - /// Default: true + // Whether or not to restore unsaved buffers on restart. + // + // If this is true, user won't be prompted whether to save/discard + // dirty files when closing the application. + // + // Default: true "restore_unsaved_buffers": true }, // Zed's Prettier integration settings. @@ -2012,9 +2016,9 @@ // } "profiles": [], - /// A map of log scopes to the desired log level. - /// Useful for filtering out noisy logs or enabling more verbose logging. - /// - /// Example: {"log": {"client": "warn"}} + // A map of log scopes to the desired log level. + // Useful for filtering out noisy logs or enabling more verbose logging. + // + // Example: {"log": {"client": "warn"}} "log": {} } diff --git a/crates/acp_thread/Cargo.toml b/crates/acp_thread/Cargo.toml index a0bbda848f9ec761aebdf66b644a8b2926685122..ac24a6ed0f41c75d5c4dcd9b9b4122336022ddf3 100644 --- a/crates/acp_thread/Cargo.toml +++ b/crates/acp_thread/Cargo.toml @@ -45,7 +45,6 @@ url.workspace = true util.workspace = true uuid.workspace = true watch.workspace = true -which.workspace = true workspace-hack.workspace = true [dev-dependencies] diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index afbb4781f61d5ccf1ea753df1fd0379e533e8e46..68e5266f06aa8bddfaa252bdc1cf5b21891c7f10 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -7,12 +7,12 @@ use agent_settings::AgentSettings; use collections::HashSet; pub use connection::*; pub use diff::*; -use futures::future::Shared; use language::language_settings::FormatOnSave; pub use mention::*; use project::lsp_store::{FormatTrigger, LspFormatTarget}; use serde::{Deserialize, Serialize}; use settings::Settings as _; +use task::{Shell, ShellBuilder}; pub use terminal::*; use action_log::ActionLog; @@ -34,7 +34,7 @@ use std::rc::Rc; use std::time::{Duration, Instant}; use std::{fmt::Display, mem, path::PathBuf, sync::Arc}; use ui::App; -use util::{ResultExt, get_system_shell}; +use util::{ResultExt, get_default_system_shell}; use uuid::Uuid; #[derive(Debug)] @@ -786,7 +786,6 @@ pub struct AcpThread { token_usage: Option, prompt_capabilities: acp::PromptCapabilities, _observe_prompt_capabilities: Task>, - determine_shell: Shared>, terminals: HashMap>, } @@ -873,20 +872,6 @@ impl AcpThread { } }); - let determine_shell = cx - .background_spawn(async move { - if cfg!(windows) { - return get_system_shell(); - } - - if which::which("bash").is_ok() { - "bash".into() - } else { - get_system_shell() - } - }) - .shared(); - Self { action_log, shared_buffers: Default::default(), @@ -901,7 +886,6 @@ impl AcpThread { prompt_capabilities, _observe_prompt_capabilities: task, terminals: HashMap::default(), - determine_shell, } } @@ -1127,9 +1111,33 @@ impl AcpThread { let update = update.into(); let languages = self.project.read(cx).languages().clone(); - let ix = self - .index_for_tool_call(update.id()) - .context("Tool call not found")?; + let ix = match self.index_for_tool_call(update.id()) { + Some(ix) => ix, + None => { + // Tool call not found - create a failed tool call entry + let failed_tool_call = ToolCall { + id: update.id().clone(), + label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)), + kind: acp::ToolKind::Fetch, + content: vec![ToolCallContent::ContentBlock(ContentBlock::new( + acp::ContentBlock::Text(acp::TextContent { + text: "Tool call not found".to_string(), + annotations: None, + meta: None, + }), + &languages, + cx, + ))], + status: ToolCallStatus::Failed, + locations: Vec::new(), + resolved_locations: Vec::new(), + raw_input: None, + raw_output: None, + }; + self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx); + return Ok(()); + } + }; let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else { unreachable!() }; @@ -1940,28 +1948,13 @@ impl AcpThread { pub fn create_terminal( &self, - mut command: String, + command: String, args: Vec, extra_env: Vec, cwd: Option, output_byte_limit: Option, cx: &mut Context, ) -> Task>> { - for arg in args { - command.push(' '); - command.push_str(&arg); - } - - let shell_command = if cfg!(windows) { - format!("$null | & {{{}}}", command.replace("\"", "'")) - } else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) { - // Make sure once we're *inside* the shell, we cd into `cwd` - format!("(cd {cwd}; {}) self.project.update(cx, |project, cx| { project.directory_environment(dir.as_path().into(), cx) @@ -1982,20 +1975,30 @@ impl AcpThread { let project = self.project.clone(); let language_registry = project.read(cx).languages().clone(); - let determine_shell = self.determine_shell.clone(); let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into()); let terminal_task = cx.spawn({ let terminal_id = terminal_id.clone(); async move |_this, cx| { - let program = determine_shell.await; let env = env.await; + let (command, args) = ShellBuilder::new( + project + .update(cx, |project, cx| { + project + .remote_client() + .and_then(|r| r.read(cx).default_system_shell()) + })? + .as_deref(), + &Shell::Program(get_default_system_shell()), + ) + .redirect_stdin_to_dev_null() + .build(Some(command), &args); let terminal = project .update(cx, |project, cx| { project.create_terminal_task( task::SpawnInTerminal { - command: Some(program), - args, + command: Some(command.clone()), + args: args.clone(), cwd: cwd.clone(), env, ..Default::default() @@ -2008,7 +2011,7 @@ impl AcpThread { cx.new(|cx| { Terminal::new( terminal_id, - command, + &format!("{} {}", command, args.join(" ")), cwd, output_byte_limit.map(|l| l as usize), terminal, @@ -3181,4 +3184,65 @@ mod tests { Task::ready(Ok(())) } } + + #[gpui::test] + async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let connection = Rc::new(FakeAgentConnection::new()); + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .await + .unwrap(); + + // Try to update a tool call that doesn't exist + let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into()); + thread.update(cx, |thread, cx| { + let result = thread.handle_session_update( + acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate { + id: nonexistent_id.clone(), + fields: acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::Completed), + ..Default::default() + }, + meta: None, + }), + cx, + ); + + // The update should succeed (not return an error) + assert!(result.is_ok()); + + // There should now be exactly one entry in the thread + assert_eq!(thread.entries.len(), 1); + + // The entry should be a failed tool call + if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] { + assert_eq!(tool_call.id, nonexistent_id); + assert!(matches!(tool_call.status, ToolCallStatus::Failed)); + assert_eq!(tool_call.kind, acp::ToolKind::Fetch); + + // Check that the content contains the error message + assert_eq!(tool_call.content.len(), 1); + if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] { + match content_block { + ContentBlock::Markdown { markdown } => { + let markdown_text = markdown.read(cx).source(); + assert!(markdown_text.contains("Tool call not found")); + } + ContentBlock::Empty => panic!("Expected markdown content, got empty"), + ContentBlock::ResourceLink { .. } => { + panic!("Expected markdown content, got resource link") + } + } + } else { + panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]); + } + } else { + panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]); + } + }); + } } diff --git a/crates/acp_thread/src/terminal.rs b/crates/acp_thread/src/terminal.rs index a927083b0bd576f1580ba261d4028407fcea7a5c..888c7698c3d2270769f3afbe712ecba7d08b055f 100644 --- a/crates/acp_thread/src/terminal.rs +++ b/crates/acp_thread/src/terminal.rs @@ -28,7 +28,7 @@ pub struct TerminalOutput { impl Terminal { pub fn new( id: acp::TerminalId, - command: String, + command_label: &str, working_dir: Option, output_byte_limit: Option, terminal: Entity, @@ -40,7 +40,7 @@ impl Terminal { id, command: cx.new(|cx| { Markdown::new( - format!("```\n{}\n```", command).into(), + format!("```\n{}\n```", command_label).into(), Some(language_registry.clone()), None, cx, diff --git a/crates/activity_indicator/src/activity_indicator.rs b/crates/activity_indicator/src/activity_indicator.rs index 1870ab74db214b518bb0b543166067e636f14965..f35b2ad17879c57b15ac8579e6b50a26110ff21d 100644 --- a/crates/activity_indicator/src/activity_indicator.rs +++ b/crates/activity_indicator/src/activity_indicator.rs @@ -1,4 +1,4 @@ -use auto_update::{AutoUpdateStatus, AutoUpdater, DismissErrorMessage, VersionCheckType}; +use auto_update::{AutoUpdateStatus, AutoUpdater, DismissMessage, VersionCheckType}; use editor::Editor; use extension_host::{ExtensionOperation, ExtensionStore}; use futures::StreamExt; @@ -280,18 +280,13 @@ impl ActivityIndicator { }); } - fn dismiss_error_message( - &mut self, - _: &DismissErrorMessage, - _: &mut Window, - cx: &mut Context, - ) { - let error_dismissed = if let Some(updater) = &self.auto_updater { - updater.update(cx, |updater, cx| updater.dismiss_error(cx)) + fn dismiss_message(&mut self, _: &DismissMessage, _: &mut Window, cx: &mut Context) { + let dismissed = if let Some(updater) = &self.auto_updater { + updater.update(cx, |updater, cx| updater.dismiss(cx)) } else { false }; - if error_dismissed { + if dismissed { return; } @@ -513,7 +508,7 @@ impl ActivityIndicator { on_click: Some(Arc::new(move |this, window, cx| { this.statuses .retain(|status| !downloading.contains(&status.name)); - this.dismiss_error_message(&DismissErrorMessage, window, cx) + this.dismiss_message(&DismissMessage, window, cx) })), tooltip_message: None, }); @@ -542,7 +537,7 @@ impl ActivityIndicator { on_click: Some(Arc::new(move |this, window, cx| { this.statuses .retain(|status| !checking_for_update.contains(&status.name)); - this.dismiss_error_message(&DismissErrorMessage, window, cx) + this.dismiss_message(&DismissMessage, window, cx) })), tooltip_message: None, }); @@ -650,13 +645,14 @@ impl ActivityIndicator { .and_then(|updater| match &updater.read(cx).status() { AutoUpdateStatus::Checking => Some(Content { icon: Some( - Icon::new(IconName::Download) + Icon::new(IconName::LoadCircle) .size(IconSize::Small) + .with_rotate_animation(3) .into_any_element(), ), message: "Checking for Zed updates…".to_string(), on_click: Some(Arc::new(|this, window, cx| { - this.dismiss_error_message(&DismissErrorMessage, window, cx) + this.dismiss_message(&DismissMessage, window, cx) })), tooltip_message: None, }), @@ -668,19 +664,20 @@ impl ActivityIndicator { ), message: "Downloading Zed update…".to_string(), on_click: Some(Arc::new(|this, window, cx| { - this.dismiss_error_message(&DismissErrorMessage, window, cx) + this.dismiss_message(&DismissMessage, window, cx) })), tooltip_message: Some(Self::version_tooltip_message(version)), }), AutoUpdateStatus::Installing { version } => Some(Content { icon: Some( - Icon::new(IconName::Download) + Icon::new(IconName::LoadCircle) .size(IconSize::Small) + .with_rotate_animation(3) .into_any_element(), ), message: "Installing Zed update…".to_string(), on_click: Some(Arc::new(|this, window, cx| { - this.dismiss_error_message(&DismissErrorMessage, window, cx) + this.dismiss_message(&DismissMessage, window, cx) })), tooltip_message: Some(Self::version_tooltip_message(version)), }), @@ -690,17 +687,18 @@ impl ActivityIndicator { on_click: Some(Arc::new(move |_, _, cx| workspace::reload(cx))), tooltip_message: Some(Self::version_tooltip_message(version)), }), - AutoUpdateStatus::Errored => Some(Content { + AutoUpdateStatus::Errored { error } => Some(Content { icon: Some( Icon::new(IconName::Warning) .size(IconSize::Small) .into_any_element(), ), - message: "Auto update failed".to_string(), + message: "Failed to update Zed".to_string(), on_click: Some(Arc::new(|this, window, cx| { - this.dismiss_error_message(&DismissErrorMessage, window, cx) + window.dispatch_action(Box::new(workspace::OpenLog), cx); + this.dismiss_message(&DismissMessage, window, cx); })), - tooltip_message: None, + tooltip_message: Some(format!("{error}")), }), AutoUpdateStatus::Idle => None, }) @@ -738,7 +736,7 @@ impl ActivityIndicator { })), message, on_click: Some(Arc::new(|this, window, cx| { - this.dismiss_error_message(&Default::default(), window, cx) + this.dismiss_message(&Default::default(), window, cx) })), tooltip_message: None, }) @@ -777,7 +775,7 @@ impl Render for ActivityIndicator { let result = h_flex() .id("activity-indicator") .on_action(cx.listener(Self::show_error_message)) - .on_action(cx.listener(Self::dismiss_error_message)); + .on_action(cx.listener(Self::dismiss_message)); let Some(content) = self.content_to_render(cx) else { return result; }; diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index bb3fe6ff9078535b500e28f4beeab957929546a5..ca6db6c663ddb2132c05d716e5b935c5855bccdb 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -23,6 +23,7 @@ action_log.workspace = true agent-client-protocol.workspace = true agent_settings.workspace = true anyhow.workspace = true +async-trait.workspace = true client.workspace = true collections.workspace = true env_logger = { workspace = true, optional = true } @@ -30,6 +31,7 @@ fs.workspace = true futures.workspace = true gpui.workspace = true gpui_tokio = { workspace = true, optional = true } +http_client.workspace = true indoc.workspace = true language.workspace = true language_model.workspace = true diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index cc897d85e7b4de149a0dca84df84d2b8c2c5bc98..b8c75a01a2e2965c255e32bd3c0746b26d78ecab 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -13,7 +13,7 @@ use util::ResultExt as _; use std::path::PathBuf; use std::{any::Any, cell::RefCell}; -use std::{path::Path, rc::Rc, sync::Arc}; +use std::{path::Path, rc::Rc}; use thiserror::Error; use anyhow::{Context as _, Result}; @@ -505,6 +505,7 @@ struct ClientDelegate { cx: AsyncApp, } +#[async_trait::async_trait(?Send)] impl acp::Client for ClientDelegate { async fn request_permission( &self, @@ -638,19 +639,11 @@ impl acp::Client for ClientDelegate { Ok(Default::default()) } - async fn ext_method( - &self, - _name: Arc, - _params: Arc, - ) -> Result, acp::Error> { + async fn ext_method(&self, _args: acp::ExtRequest) -> Result { Err(acp::Error::method_not_found()) } - async fn ext_notification( - &self, - _name: Arc, - _params: Arc, - ) -> Result<(), acp::Error> { + async fn ext_notification(&self, _args: acp::ExtNotification) -> Result<(), acp::Error> { Err(acp::Error::method_not_found()) } diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 2c2900cb79328249355704606652c54d08f072e5..b9751d7f63053bf073bcc8181f0cc2f8211d5c9f 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -7,15 +7,19 @@ mod gemini; pub mod e2e_tests; pub use claude::*; +use client::ProxySettings; +use collections::HashMap; pub use custom::*; use fs::Fs; pub use gemini::*; +use http_client::read_no_proxy_from_env; use project::agent_server_store::AgentServerStore; use acp_thread::AgentConnection; use anyhow::Result; -use gpui::{App, Entity, SharedString, Task}; +use gpui::{App, AppContext, Entity, SharedString, Task}; use project::Project; +use settings::SettingsStore; use std::{any::Any, path::Path, rc::Rc, sync::Arc}; pub use acp::AcpConnection; @@ -77,3 +81,25 @@ impl dyn AgentServer { self.into_any().downcast().ok() } } + +/// Load the default proxy environment variables to pass through to the agent +pub fn load_proxy_env(cx: &mut App) -> HashMap { + let proxy_url = cx + .read_global(|settings: &SettingsStore, _| settings.get::(None).proxy_url()); + let mut env = HashMap::default(); + + if let Some(proxy_url) = &proxy_url { + let env_var = if proxy_url.scheme() == "https" { + "HTTPS_PROXY" + } else { + "HTTP_PROXY" + }; + env.insert(env_var.to_owned(), proxy_url.to_string()); + } + + if let Some(no_proxy) = read_no_proxy_from_env() { + env.insert("NO_PROXY".to_owned(), no_proxy); + } + + env +} diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 489839d82244fe76f6e9d1e9ea025a7b7c4a3bf7..4646b2e8259fa2cd63c0daa67b47f66b5e78af05 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -10,7 +10,7 @@ use anyhow::{Context as _, Result}; use gpui::{App, AppContext as _, SharedString, Task}; use project::agent_server_store::{AllAgentServersSettings, CLAUDE_CODE_NAME}; -use crate::{AgentServer, AgentServerDelegate}; +use crate::{AgentServer, AgentServerDelegate, load_proxy_env}; use acp_thread::AgentConnection; #[derive(Clone)] @@ -65,6 +65,7 @@ impl AgentServer for ClaudeCode { let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string()); let is_remote = delegate.project.read(cx).is_via_remote_server(); let store = delegate.store.downgrade(); + let extra_env = load_proxy_env(cx); let default_mode = self.default_mode(cx); cx.spawn(async move |cx| { @@ -75,7 +76,7 @@ impl AgentServer for ClaudeCode { .context("Claude Code is not registered")?; anyhow::Ok(agent.get_command( root_dir.as_deref(), - Default::default(), + extra_env, delegate.status_tx, delegate.new_version_available, &mut cx.to_async(), diff --git a/crates/agent_servers/src/custom.rs b/crates/agent_servers/src/custom.rs index aa2bbc0868dc64c5b415c445d19a357eb4b2ea85..cb9a6dba3c6376fa5030c21523c86853c9b6d761 100644 --- a/crates/agent_servers/src/custom.rs +++ b/crates/agent_servers/src/custom.rs @@ -1,4 +1,4 @@ -use crate::AgentServerDelegate; +use crate::{AgentServerDelegate, load_proxy_env}; use acp_thread::AgentConnection; use agent_client_protocol as acp; use anyhow::{Context as _, Result}; @@ -71,6 +71,7 @@ impl crate::AgentServer for CustomAgentServer { let is_remote = delegate.project.read(cx).is_via_remote_server(); let default_mode = self.default_mode(cx); let store = delegate.store.downgrade(); + let extra_env = load_proxy_env(cx); cx.spawn(async move |cx| { let (command, root_dir, login) = store @@ -82,7 +83,7 @@ impl crate::AgentServer for CustomAgentServer { })?; anyhow::Ok(agent.get_command( root_dir.as_deref(), - Default::default(), + extra_env, delegate.status_tx, delegate.new_version_available, &mut cx.to_async(), diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 01f15557899e1c7826e91d1555320996eccd0f45..9407a42e68d34e38e78f2103b29f980f874fb3db 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,15 +1,12 @@ use std::rc::Rc; use std::{any::Any, path::Path}; -use crate::{AgentServer, AgentServerDelegate}; +use crate::{AgentServer, AgentServerDelegate, load_proxy_env}; use acp_thread::AgentConnection; use anyhow::{Context as _, Result}; -use client::ProxySettings; -use collections::HashMap; -use gpui::{App, AppContext, SharedString, Task}; +use gpui::{App, SharedString, Task}; use language_models::provider::google::GoogleLanguageModelProvider; use project::agent_server_store::GEMINI_NAME; -use settings::SettingsStore; #[derive(Clone)] pub struct Gemini; @@ -37,17 +34,20 @@ impl AgentServer for Gemini { let root_dir = root_dir.map(|root_dir| root_dir.to_string_lossy().to_string()); let is_remote = delegate.project.read(cx).is_via_remote_server(); let store = delegate.store.downgrade(); - let proxy_url = cx.read_global(|settings: &SettingsStore, _| { - settings.get::(None).proxy.clone() - }); + let mut extra_env = load_proxy_env(cx); let default_mode = self.default_mode(cx); cx.spawn(async move |cx| { - let mut extra_env = HashMap::default(); - if let Some(api_key) = cx.update(GoogleLanguageModelProvider::api_key)?.await.ok() { - extra_env.insert("GEMINI_API_KEY".into(), api_key.key); + extra_env.insert("SURFACE".to_owned(), "zed".to_owned()); + + if let Some(api_key) = cx + .update(GoogleLanguageModelProvider::api_key_for_gemini_cli)? + .await + .ok() + { + extra_env.insert("GEMINI_API_KEY".into(), api_key); } - let (mut command, root_dir, login) = store + let (command, root_dir, login) = store .update(cx, |store, cx| { let agent = store .get_external_agent(&GEMINI_NAME.into()) @@ -62,14 +62,6 @@ impl AgentServer for Gemini { })?? .await?; - // Add proxy flag if proxy settings are configured in Zed and not in the args - if let Some(proxy_url_value) = &proxy_url - && !command.args.iter().any(|arg| arg.contains("--proxy")) - { - command.args.push("--proxy".into()); - command.args.push(proxy_url_value.clone()); - } - let connection = crate::acp::connect( name, command, diff --git a/crates/agent_servers/src/settings.rs b/crates/agent_servers/src/settings.rs deleted file mode 100644 index 9a610465be5516664dafd9cd4cb46be96ad89c8b..0000000000000000000000000000000000000000 --- a/crates/agent_servers/src/settings.rs +++ /dev/null @@ -1,125 +0,0 @@ -use agent_client_protocol as acp; -use std::path::PathBuf; - -use crate::AgentServerCommand; -use anyhow::Result; -use collections::HashMap; -use gpui::{App, SharedString}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsKey, SettingsSources, SettingsUi}; - -pub fn init(cx: &mut App) { - AllAgentServersSettings::register(cx); -} - -#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, SettingsUi, SettingsKey)] -#[settings_key(key = "agent_servers")] -pub struct AllAgentServersSettings { - pub gemini: Option, - pub claude: Option, - - /// Custom agent servers configured by the user - #[serde(flatten)] - pub custom: HashMap, -} - -#[derive(Default, Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)] -pub struct BuiltinAgentServerSettings { - /// Absolute path to a binary to be used when launching this agent. - /// - /// This can be used to run a specific binary without automatic downloads or searching `$PATH`. - #[serde(rename = "command")] - pub path: Option, - /// If a binary is specified in `command`, it will be passed these arguments. - pub args: Option>, - /// If a binary is specified in `command`, it will be passed these environment variables. - pub env: Option>, - /// Whether to skip searching `$PATH` for an agent server binary when - /// launching this agent. - /// - /// This has no effect if a `command` is specified. Otherwise, when this is - /// `false`, Zed will search `$PATH` for an agent server binary and, if one - /// is found, use it for threads with this agent. If no agent binary is - /// found on `$PATH`, Zed will automatically install and use its own binary. - /// When this is `true`, Zed will not search `$PATH`, and will always use - /// its own binary. - /// - /// Default: true - pub ignore_system_version: Option, - /// The default mode for new threads. - /// - /// Note: Not all agents support modes. - /// - /// Default: None - #[serde(skip_serializing_if = "Option::is_none")] - pub default_mode: Option, -} - -impl BuiltinAgentServerSettings { - pub(crate) fn custom_command(self) -> Option { - self.path.map(|path| AgentServerCommand { - path, - args: self.args.unwrap_or_default(), - env: self.env, - }) - } -} - -impl From for BuiltinAgentServerSettings { - fn from(value: AgentServerCommand) -> Self { - BuiltinAgentServerSettings { - path: Some(value.path), - args: Some(value.args), - env: value.env, - ..Default::default() - } - } -} - -#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug, PartialEq)] -pub struct CustomAgentServerSettings { - #[serde(flatten)] - pub command: AgentServerCommand, - /// The default mode for new threads. - /// - /// Note: Not all agents support modes. - /// - /// Default: None - #[serde(skip_serializing_if = "Option::is_none")] - pub default_mode: Option, -} - -impl settings::Settings for AllAgentServersSettings { - type FileContent = Self; - - fn load(sources: SettingsSources, _: &mut App) -> Result { - let mut settings = AllAgentServersSettings::default(); - - for AllAgentServersSettings { - gemini, - claude, - custom, - } in sources.defaults_and_customizations() - { - if gemini.is_some() { - settings.gemini = gemini.clone(); - } - if claude.is_some() { - settings.claude = claude.clone(); - } - - // Merge custom agents - for (name, config) in custom { - // Skip built-in agent names to avoid conflicts - if name != "gemini" && name != "claude" { - settings.custom.insert(name.clone(), config.clone()); - } - } - } - - Ok(settings) - } - - fn import_from_vscode(_vscode: &settings::VsCodeSettings, _current: &mut Self::FileContent) {} -} diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index e5744458601b7fd208bd97c11da7d136cb329f05..e0389a47ce015f0644f7ebfe0025b8c0d74fdcd0 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -50,6 +50,7 @@ pub struct AgentSettings { pub expand_edit_card: bool, pub expand_terminal_card: bool, pub use_modifier_to_send: bool, + pub message_editor_min_lines: usize, } impl AgentSettings { @@ -91,6 +92,10 @@ impl AgentSettings { model, }); } + + pub fn set_message_editor_max_lines(&self) -> usize { + self.message_editor_min_lines * 2 + } } #[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Default)] @@ -175,6 +180,7 @@ impl Settings for AgentSettings { expand_edit_card: agent.expand_edit_card.unwrap(), expand_terminal_card: agent.expand_terminal_card.unwrap(), use_modifier_to_send: agent.use_modifier_to_send.unwrap(), + message_editor_min_lines: agent.message_editor_min_lines.unwrap(), } } @@ -224,6 +230,8 @@ impl Settings for AgentSettings { self.model_parameters .extend_from_slice(&value.model_parameters); + self.message_editor_min_lines + .merge_from(&value.message_editor_min_lines); if let Some(profiles) = value.profiles.as_ref() { self.profiles.extend( diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 028db95c10a8c7a319bb05927dcabd0564a14683..47d9f6d6a27a2ad5102e831094912208e66a9b43 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -80,7 +80,6 @@ serde.workspace = true serde_json.workspace = true serde_json_lenient.workspace = true settings.workspace = true -shlex.workspace = true smol.workspace = true streaming_diff.workspace = true task.workspace = true diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index ab4e8d680925c96e64fdb2e7707bea9c1e177b5c..2734726ddbe1608356bcc34af5c42f479d5a8e8a 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -1099,11 +1099,16 @@ impl MessageEditor { } pub fn insert_selections(&mut self, window: &mut Window, cx: &mut Context) { - let buffer = self.editor.read(cx).buffer().clone(); - let Some(buffer) = buffer.read(cx).as_singleton() else { + let editor = self.editor.read(cx); + let editor_buffer = editor.buffer().read(cx); + let Some(buffer) = editor_buffer.as_singleton() else { return; }; - let anchor = buffer.update(cx, |buffer, _cx| buffer.anchor_before(buffer.len())); + let cursor_anchor = editor.selections.newest_anchor().head(); + let cursor_offset = cursor_anchor.to_offset(&editor_buffer.snapshot(cx)); + let anchor = buffer.update(cx, |buffer, _cx| { + buffer.anchor_before(cursor_offset.min(buffer.len())) + }); let Some(workspace) = self.workspace.upgrade() else { return; }; @@ -1117,13 +1122,7 @@ impl MessageEditor { return; }; self.editor.update(cx, |message_editor, cx| { - message_editor.edit( - [( - multi_buffer::Anchor::max()..multi_buffer::Anchor::max(), - completion.new_text, - )], - cx, - ); + message_editor.edit([(cursor_anchor..cursor_anchor, completion.new_text)], cx); }); if let Some(confirm) = completion.confirm { confirm(CompletionIntent::Complete, window, cx); diff --git a/crates/agent_ui/src/acp/mode_selector.rs b/crates/agent_ui/src/acp/mode_selector.rs index d4d424f41a36652881a50b65bc3dbe00c18fdcde..410874126665b7d622c7cf45e81596dce7f96823 100644 --- a/crates/agent_ui/src/acp/mode_selector.rs +++ b/crates/agent_ui/src/acp/mode_selector.rs @@ -107,13 +107,15 @@ impl ModeSelector { .text_sm() .text_color(Color::Muted.color(cx)) .child("Hold") - .child(div().pt_0p5().children(ui::render_modifiers( - &gpui::Modifiers::secondary_key(), - PlatformStyle::platform(), - None, - Some(ui::TextSize::Default.rems(cx).into()), - true, - ))) + .child(h_flex().flex_shrink_0().children( + ui::render_modifiers( + &gpui::Modifiers::secondary_key(), + PlatformStyle::platform(), + None, + Some(ui::TextSize::Default.rems(cx).into()), + true, + ), + )) .child(div().map(|this| { if is_default { this.child("to also unset as default") diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index ed508ea18da7df3426fc13b137b97f37267ed283..cd696f33fa44976e0784c79d1945b548feb20a50 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -500,20 +500,24 @@ impl Render for AcpThreadHistory { ), ) } else { - view.pr_5() - .child( - uniform_list( - "thread-history", - self.visible_items.len(), - cx.processor(|this, range: Range, window, cx| { - this.render_list_items(range, window, cx) - }), - ) - .p_1() - .track_scroll(self.scroll_handle.clone()) - .flex_grow(), + view.child( + uniform_list( + "thread-history", + self.visible_items.len(), + cx.processor(|this, range: Range, window, cx| { + this.render_list_items(range, window, cx) + }), ) - .vertical_scrollbar_for(self.scroll_handle.clone(), window, cx) + .p_1() + .pr_4() + .track_scroll(self.scroll_handle.clone()) + .flex_grow(), + ) + .vertical_scrollbar_for( + self.scroll_handle.clone(), + window, + cx, + ) } }) } diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index fd848d2c42ab6ae4bea9255876793dba22022760..ac84fd36f24a850330a5f20b979bcc0f8fc442ad 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -9,7 +9,7 @@ use agent_client_protocol::{self as acp, PromptCapabilities}; use agent_servers::{AgentServer, AgentServerDelegate}; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use agent2::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer}; -use anyhow::{Context as _, Result, anyhow, bail}; +use anyhow::{Result, anyhow, bail}; use arrayvec::ArrayVec; use audio::{Audio, Sound}; use buffer_diff::BufferDiff; @@ -71,9 +71,6 @@ use crate::{ RejectOnce, ToggleBurnMode, ToggleProfileSelector, }; -pub const MIN_EDITOR_LINES: usize = 4; -pub const MAX_EDITOR_LINES: usize = 8; - #[derive(Copy, Clone, Debug, PartialEq, Eq)] enum ThreadFeedback { Positive, @@ -357,8 +354,8 @@ impl AcpThreadView { agent.name(), &placeholder, editor::EditorMode::AutoHeight { - min_lines: MIN_EDITOR_LINES, - max_lines: Some(MAX_EDITOR_LINES), + min_lines: AgentSettings::get_global(cx).message_editor_min_lines, + max_lines: Some(AgentSettings::get_global(cx).set_message_editor_max_lines()), }, window, cx, @@ -857,10 +854,11 @@ impl AcpThreadView { cx, ) } else { + let agent_settings = AgentSettings::get_global(cx); editor.set_mode( EditorMode::AutoHeight { - min_lines: MIN_EDITOR_LINES, - max_lines: Some(MAX_EDITOR_LINES), + min_lines: agent_settings.message_editor_min_lines, + max_lines: Some(agent_settings.set_message_editor_max_lines()), }, cx, ) @@ -1584,19 +1582,6 @@ impl AcpThreadView { window.spawn(cx, async move |cx| { let mut task = login.clone(); - task.command = task - .command - .map(|command| anyhow::Ok(shlex::try_quote(&command)?.to_string())) - .transpose()?; - task.args = task - .args - .iter() - .map(|arg| { - Ok(shlex::try_quote(arg) - .context("Failed to quote argument")? - .to_string()) - }) - .collect::>>()?; task.full_label = task.label.clone(); task.id = task::TaskId(format!("external-agent-{}-login", task.label)); task.command_label = task.label.clone(); @@ -3197,10 +3182,14 @@ impl AcpThreadView { }; Button::new(SharedString::from(method_id.clone()), name) - .when(ix == 0, |el| { - el.style(ButtonStyle::Tinted(ui::TintColor::Warning)) - }) .label_size(LabelSize::Small) + .map(|this| { + if ix == 0 { + this.style(ButtonStyle::Tinted(TintColor::Warning)) + } else { + this.style(ButtonStyle::Outlined) + } + }) .on_click({ cx.listener(move |this, _, window, cx| { telemetry::event!( @@ -5680,6 +5669,23 @@ pub(crate) mod tests { }); } + #[gpui::test] + async fn test_spawn_external_agent_login_handles_spaces(cx: &mut TestAppContext) { + init_test(cx); + + // Verify paths with spaces aren't pre-quoted + let path_with_spaces = "/Users/test/Library/Application Support/Zed/cli.js"; + let login_task = task::SpawnInTerminal { + command: Some("node".to_string()), + args: vec![path_with_spaces.to_string(), "/login".to_string()], + ..Default::default() + }; + + // Args should be passed as-is, not pre-quoted + assert!(!login_task.args[0].starts_with('"')); + assert!(!login_task.args[0].starts_with('\'')); + } + #[gpui::test] async fn test_notification_for_tool_authorization(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index ea1ecb10e55241368bacc636b75d182b64210bc7..382a9db2573a21bcb74e75d15a6a87c0aa412588 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -272,13 +272,28 @@ impl AgentConfiguration { *is_expanded = !*is_expanded; } })), - ) - .when(provider.is_authenticated(cx), |parent| { + ), + ) + .child( + v_flex() + .w_full() + .px_2() + .gap_1() + .when(is_expanded, |parent| match configuration_view { + Some(configuration_view) => parent.child(configuration_view), + None => parent.child(Label::new(format!( + "No configuration view for {provider_name}", + ))), + }) + .when(is_expanded && provider.is_authenticated(cx), |parent| { parent.child( Button::new( SharedString::from(format!("new-thread-{provider_id}")), "Start New Thread", ) + .full_width() + .style(ButtonStyle::Filled) + .layer(ElevationIndex::ModalSurface) .icon_position(IconPosition::Start) .icon(IconName::Thread) .icon_size(IconSize::Small) @@ -295,17 +310,6 @@ impl AgentConfiguration { ) }), ) - .child( - div() - .w_full() - .px_2() - .when(is_expanded, |parent| match configuration_view { - Some(configuration_view) => parent.child(configuration_view), - None => parent.child(Label::new(format!( - "No configuration view for {provider_name}", - ))), - }), - ) } fn render_provider_configuration_section( @@ -562,11 +566,28 @@ impl AgentConfiguration { .color(Color::Muted), ), ) - .children( - context_server_ids.into_iter().map(|context_server_id| { - self.render_context_server(context_server_id, window, cx) - }), - ) + .map(|parent| { + if context_server_ids.is_empty() { + parent.child( + h_flex() + .p_4() + .justify_center() + .border_1() + .border_dashed() + .border_color(cx.theme().colors().border.opacity(0.6)) + .rounded_sm() + .child( + Label::new("No MCP servers added yet.") + .color(Color::Muted) + .size(LabelSize::Small), + ), + ) + } else { + parent.children(context_server_ids.into_iter().map(|context_server_id| { + self.render_context_server(context_server_id, window, cx) + })) + } + }) .child( h_flex() .justify_between() @@ -819,6 +840,8 @@ impl AgentConfiguration { ) .child( h_flex() + .flex_1() + .min_w_0() .child( Disclosure::new( "tool-list-disclosure", @@ -842,17 +865,19 @@ impl AgentConfiguration { .id(SharedString::from(format!("tooltip-{}", item_id))) .h_full() .w_3() - .mx_1() + .ml_1() + .mr_1p5() .justify_center() .tooltip(Tooltip::text(tooltip_text)) .child(status_indicator), ) - .child(Label::new(item_id).ml_0p5()) + .child(Label::new(item_id).truncate()) .child( div() .id("extension-source") .mt_0p5() .mx_1() + .flex_none() .tooltip(Tooltip::text(source_tooltip)) .child( Icon::new(source_icon) @@ -874,7 +899,8 @@ impl AgentConfiguration { ) .child( h_flex() - .gap_1() + .gap_0p5() + .flex_none() .child(context_server_configuration_menu) .child( Switch::new("context-server-switch", is_running.into()) @@ -1110,6 +1136,7 @@ impl AgentConfiguration { SharedString::from(format!("start_acp_thread-{name}")), "Start New Thread", ) + .layer(ElevationIndex::ModalSurface) .label_size(LabelSize::Small) .icon(IconName::Thread) .icon_position(IconPosition::Start) diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index f4c3fe1069eb42ad53c5771d4a511f88ff780664..e34789d62d2c95b06f5c4f03b93b60f01c6dbf6a 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -14,7 +14,6 @@ mod message_editor; mod profile_selector; mod slash_command; mod slash_command_picker; -mod slash_command_settings; mod terminal_codegen; mod terminal_inline_assistant; mod text_thread_editor; @@ -46,7 +45,6 @@ use std::any::TypeId; use crate::agent_configuration::{ConfigureContextServerModal, ManageProfilesModal}; pub use crate::agent_panel::{AgentPanel, ConcreteAssistantPanelDelegate}; pub use crate::inline_assistant::InlineAssistant; -use crate::slash_command_settings::SlashCommandSettings; pub use agent_diff::{AgentDiffPane, AgentDiffToolbar}; pub use text_thread_editor::{AgentPanelDelegate, TextThreadEditor}; use zed_actions; @@ -257,7 +255,6 @@ pub fn init( cx: &mut App, ) { AgentSettings::register(cx); - SlashCommandSettings::register(cx); assistant_context::init(client.clone(), cx); rules_library::init(cx); @@ -413,8 +410,6 @@ fn register_slash_commands(cx: &mut App) { slash_command_registry.register_command(assistant_slash_commands::DeltaSlashCommand, true); slash_command_registry.register_command(assistant_slash_commands::OutlineSlashCommand, true); slash_command_registry.register_command(assistant_slash_commands::TabSlashCommand, true); - slash_command_registry - .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true); slash_command_registry.register_command(assistant_slash_commands::PromptSlashCommand, true); slash_command_registry.register_command(assistant_slash_commands::SelectionCommand, true); slash_command_registry.register_command(assistant_slash_commands::DefaultSlashCommand, false); @@ -434,21 +429,4 @@ fn register_slash_commands(cx: &mut App) { } }) .detach(); - - update_slash_commands_from_settings(cx); - cx.observe_global::(update_slash_commands_from_settings) - .detach(); -} - -fn update_slash_commands_from_settings(cx: &mut App) { - let slash_command_registry = SlashCommandRegistry::global(cx); - let settings = SlashCommandSettings::get_global(cx); - - if settings.cargo_workspace.enabled { - slash_command_registry - .register_command(assistant_slash_commands::CargoWorkspaceSlashCommand, true); - } else { - slash_command_registry - .unregister_command(assistant_slash_commands::CargoWorkspaceSlashCommand); - } } diff --git a/crates/agent_ui/src/context_picker/completion_provider.rs b/crates/agent_ui/src/context_picker/completion_provider.rs index c9cd69bf8e49b2e4f20148640cd029caea51264f..01a7a51316eee4709eaf9c17c8840e3cd637a62b 100644 --- a/crates/agent_ui/src/context_picker/completion_provider.rs +++ b/crates/agent_ui/src/context_picker/completion_provider.rs @@ -743,15 +743,15 @@ impl CompletionProvider for ContextPickerCompletionProvider { _window: &mut Window, cx: &mut Context, ) -> Task>> { - let state = buffer.update(cx, |buffer, _cx| { - let position = buffer_position.to_point(buffer); - let line_start = Point::new(position.row, 0); - let offset_to_line = buffer.point_to_offset(line_start); - let mut lines = buffer.text_for_range(line_start..position).lines(); - let line = lines.next()?; - MentionCompletion::try_parse(line, offset_to_line) - }); - let Some(state) = state else { + let snapshot = buffer.read(cx).snapshot(); + let position = buffer_position.to_point(&snapshot); + let line_start = Point::new(position.row, 0); + let offset_to_line = snapshot.point_to_offset(line_start); + let mut lines = snapshot.text_for_range(line_start..position).lines(); + let Some(line) = lines.next() else { + return Task::ready(Ok(Vec::new())); + }; + let Some(state) = MentionCompletion::try_parse(line, offset_to_line) else { return Task::ready(Ok(Vec::new())); }; @@ -761,7 +761,6 @@ impl CompletionProvider for ContextPickerCompletionProvider { return Task::ready(Ok(Vec::new())); }; - let snapshot = buffer.read(cx).snapshot(); let source_range = snapshot.anchor_before(state.source_range.start) ..snapshot.anchor_after(state.source_range.end); diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 4ac88e6daa3d3623580e206c2759f27b218d1bac..79e092b709dd2778c89a79e1d6ce36802c853eb6 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -744,19 +744,14 @@ impl InlineAssistant { .update(cx, |editor, cx| { let scroll_top = editor.scroll_position(cx).y; let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.); - let prompt_row = editor + editor_assists.scroll_lock = editor .row_for_block(decorations.prompt_block_id, cx) - .unwrap() - .0 as f32; - - if (scroll_top..scroll_bottom).contains(&prompt_row) { - editor_assists.scroll_lock = Some(InlineAssistScrollLock { + .map(|row| row.0 as f32) + .filter(|prompt_row| (scroll_top..scroll_bottom).contains(&prompt_row)) + .map(|prompt_row| InlineAssistScrollLock { assist_id, distance_from_top: prompt_row - scroll_top, }); - } else { - editor_assists.scroll_lock = None; - } }) .ok(); } @@ -917,14 +912,12 @@ impl InlineAssistant { editor.update(cx, |editor, cx| { let scroll_position = editor.scroll_position(cx); - let target_scroll_top = editor - .row_for_block(decorations.prompt_block_id, cx) - .unwrap() - .0 as f32 + let target_scroll_top = editor.row_for_block(decorations.prompt_block_id, cx)?.0 as f32 - scroll_lock.distance_from_top; if target_scroll_top != scroll_position.y { editor.set_scroll_position(point(scroll_position.x, target_scroll_top), window, cx); } + Some(()) }); } @@ -968,14 +961,14 @@ impl InlineAssistant { if let Some(decorations) = assist.decorations.as_ref() { let distance_from_top = editor.update(cx, |editor, cx| { let scroll_top = editor.scroll_position(cx).y; - let prompt_row = editor - .row_for_block(decorations.prompt_block_id, cx) - .unwrap() - .0 as f32; - prompt_row - scroll_top + let prompt_row = + editor.row_for_block(decorations.prompt_block_id, cx)?.0 as f32; + Some(prompt_row - scroll_top) }); - if distance_from_top != scroll_lock.distance_from_top { + if distance_from_top.is_none_or(|distance_from_top| { + distance_from_top != scroll_lock.distance_from_top + }) { editor_assists.scroll_lock = None; } } diff --git a/crates/agent_ui/src/slash_command_settings.rs b/crates/agent_ui/src/slash_command_settings.rs deleted file mode 100644 index f0a04c6b49984ae94d629f1bbfa96c6de4e01606..0000000000000000000000000000000000000000 --- a/crates/agent_ui/src/slash_command_settings.rs +++ /dev/null @@ -1,27 +0,0 @@ -use gpui::App; -use settings::Settings; - -/// Settings for slash commands. -#[derive(Debug, Default, Clone)] -pub struct SlashCommandSettings { - /// Settings for the `/cargo-workspace` slash command. - pub cargo_workspace: CargoWorkspaceCommandSettings, -} - -/// Settings for the `/cargo-workspace` slash command. -#[derive(Debug, Default, Clone)] -pub struct CargoWorkspaceCommandSettings { - /// Whether `/cargo-workspace` is enabled. - pub enabled: bool, -} - -// todo!() I think this setting is bogus... default.json has "slash_commands": {"project"} -impl Settings for SlashCommandSettings { - fn from_defaults(_content: &settings::SettingsContent, _cx: &mut App) -> Self { - Self { - cargo_workspace: CargoWorkspaceCommandSettings { enabled: false }, - } - } - - fn refine(&mut self, _content: &settings::SettingsContent, _cx: &mut App) {} -} diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index b40b996ae74f93220d88c97e8ae7d99dd8576cf1..dadb6263c765a10fedc01b25ae5dd3dded19b877 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -485,7 +485,7 @@ impl TextThreadEditor { return; } - let selections = self.editor.read(cx).selections.disjoint_anchors(); + let selections = self.editor.read(cx).selections.disjoint_anchors_arc(); let mut commands_by_range = HashMap::default(); let workspace = self.workspace.clone(); self.context.update(cx, |context, cx| { @@ -1831,7 +1831,7 @@ impl TextThreadEditor { fn split(&mut self, _: &Split, _window: &mut Window, cx: &mut Context) { self.context.update(cx, |context, cx| { - let selections = self.editor.read(cx).selections.disjoint_anchors(); + let selections = self.editor.read(cx).selections.disjoint_anchors_arc(); for selection in selections.as_ref() { let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx); let range = selection diff --git a/crates/assistant_slash_commands/Cargo.toml b/crates/assistant_slash_commands/Cargo.toml index c054c3ced84825bcd131bdd76644c00595c4c4a9..f151515d4235b7ecb150539aceb1c5478960517b 100644 --- a/crates/assistant_slash_commands/Cargo.toml +++ b/crates/assistant_slash_commands/Cargo.toml @@ -14,7 +14,6 @@ path = "src/assistant_slash_commands.rs" [dependencies] anyhow.workspace = true assistant_slash_command.workspace = true -cargo_toml.workspace = true chrono.workspace = true collections.workspace = true context_server.workspace = true @@ -35,7 +34,6 @@ serde.workspace = true serde_json.workspace = true smol.workspace = true text.workspace = true -toml.workspace = true ui.workspace = true util.workspace = true workspace.workspace = true diff --git a/crates/assistant_slash_commands/src/assistant_slash_commands.rs b/crates/assistant_slash_commands/src/assistant_slash_commands.rs index fb00a912197e07942a67ad92418b85c4920ad66b..2bf2573e99d7a5a0140c1972967ec68523b0b56a 100644 --- a/crates/assistant_slash_commands/src/assistant_slash_commands.rs +++ b/crates/assistant_slash_commands/src/assistant_slash_commands.rs @@ -1,4 +1,3 @@ -mod cargo_workspace_command; mod context_server_command; mod default_command; mod delta_command; @@ -12,7 +11,6 @@ mod streaming_example_command; mod symbols_command; mod tab_command; -pub use crate::cargo_workspace_command::*; pub use crate::context_server_command::*; pub use crate::default_command::*; pub use crate::delta_command::*; diff --git a/crates/assistant_slash_commands/src/cargo_workspace_command.rs b/crates/assistant_slash_commands/src/cargo_workspace_command.rs deleted file mode 100644 index d58b2edc4c3dffd799dd9eb1c104686dc6488687..0000000000000000000000000000000000000000 --- a/crates/assistant_slash_commands/src/cargo_workspace_command.rs +++ /dev/null @@ -1,158 +0,0 @@ -use anyhow::{Context as _, Result, anyhow}; -use assistant_slash_command::{ - ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection, - SlashCommandResult, -}; -use fs::Fs; -use gpui::{App, Entity, Task, WeakEntity}; -use language::{BufferSnapshot, LspAdapterDelegate}; -use project::{Project, ProjectPath}; -use std::{ - fmt::Write, - path::Path, - sync::{Arc, atomic::AtomicBool}, -}; -use ui::prelude::*; -use workspace::Workspace; - -pub struct CargoWorkspaceSlashCommand; - -impl CargoWorkspaceSlashCommand { - async fn build_message(fs: Arc, path_to_cargo_toml: &Path) -> Result { - let buffer = fs.load(path_to_cargo_toml).await?; - let cargo_toml: cargo_toml::Manifest = toml::from_str(&buffer)?; - - let mut message = String::new(); - writeln!(message, "You are in a Rust project.")?; - - if let Some(workspace) = cargo_toml.workspace { - writeln!( - message, - "The project is a Cargo workspace with the following members:" - )?; - for member in workspace.members { - writeln!(message, "- {member}")?; - } - - if !workspace.default_members.is_empty() { - writeln!(message, "The default members are:")?; - for member in workspace.default_members { - writeln!(message, "- {member}")?; - } - } - - if !workspace.dependencies.is_empty() { - writeln!( - message, - "The following workspace dependencies are installed:" - )?; - for dependency in workspace.dependencies.keys() { - writeln!(message, "- {dependency}")?; - } - } - } else if let Some(package) = cargo_toml.package { - writeln!( - message, - "The project name is \"{name}\".", - name = package.name - )?; - - let description = package - .description - .as_ref() - .and_then(|description| description.get().ok().cloned()); - if let Some(description) = description.as_ref() { - writeln!(message, "It describes itself as \"{description}\".")?; - } - - if !cargo_toml.dependencies.is_empty() { - writeln!(message, "The following dependencies are installed:")?; - for dependency in cargo_toml.dependencies.keys() { - writeln!(message, "- {dependency}")?; - } - } - } - - Ok(message) - } - - fn path_to_cargo_toml(project: Entity, cx: &mut App) -> Option> { - let worktree = project.read(cx).worktrees(cx).next()?; - let worktree = worktree.read(cx); - let entry = worktree.entry_for_path("Cargo.toml")?; - let path = ProjectPath { - worktree_id: worktree.id(), - path: entry.path.clone(), - }; - Some(Arc::from( - project.read(cx).absolute_path(&path, cx)?.as_path(), - )) - } -} - -impl SlashCommand for CargoWorkspaceSlashCommand { - fn name(&self) -> String { - "cargo-workspace".into() - } - - fn description(&self) -> String { - "insert project workspace metadata".into() - } - - fn menu_text(&self) -> String { - "Insert Project Workspace Metadata".into() - } - - fn complete_argument( - self: Arc, - _arguments: &[String], - _cancel: Arc, - _workspace: Option>, - _window: &mut Window, - _cx: &mut App, - ) -> Task>> { - Task::ready(Err(anyhow!("this command does not require argument"))) - } - - fn requires_argument(&self) -> bool { - false - } - - fn run( - self: Arc, - _arguments: &[String], - _context_slash_command_output_sections: &[SlashCommandOutputSection], - _context_buffer: BufferSnapshot, - workspace: WeakEntity, - _delegate: Option>, - _window: &mut Window, - cx: &mut App, - ) -> Task { - let output = workspace.update(cx, |workspace, cx| { - let project = workspace.project().clone(); - let fs = workspace.project().read(cx).fs().clone(); - let path = Self::path_to_cargo_toml(project, cx); - let output = cx.background_spawn(async move { - let path = path.with_context(|| "Cargo.toml not found")?; - Self::build_message(fs, &path).await - }); - - cx.foreground_executor().spawn(async move { - let text = output.await?; - let range = 0..text.len(); - Ok(SlashCommandOutput { - text, - sections: vec![SlashCommandOutputSection { - range, - icon: IconName::FileTree, - label: "Project".into(), - metadata: None, - }], - run_commands_in_text: false, - } - .into_event_stream()) - }) - }); - output.unwrap_or_else(|error| Task::ready(Err(error))) - } -} diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 5a8ca8a5e995fd2c738eb3b309f2bb4ebe9595a1..9b9b8196d1c342c536d605306a1a062e73768c56 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -63,7 +63,6 @@ ui.workspace = true util.workspace = true watch.workspace = true web_search.workspace = true -which.workspace = true workspace-hack.workspace = true workspace.workspace = true diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index ce3b639cb2c46d3f736490c0b2153260f970963c..17e2ba12f706387859ca3393aa44f5c05570e50a 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -52,7 +52,7 @@ pub fn init(http_client: Arc, cx: &mut App) { assistant_tool::init(cx); let registry = ToolRegistry::global(cx); - registry.register_tool(TerminalTool::new(cx)); + registry.register_tool(TerminalTool); registry.register_tool(CreateDirectoryTool); registry.register_tool(CopyPathTool); registry.register_tool(DeletePathTool); diff --git a/crates/assistant_tools/src/edit_agent/create_file_parser.rs b/crates/assistant_tools/src/edit_agent/create_file_parser.rs index 5126f9c6b1fe4ee5cc600ae93b7300b7af09451f..2272434d796a92e53b741f8ed5f4303d94f88489 100644 --- a/crates/assistant_tools/src/edit_agent/create_file_parser.rs +++ b/crates/assistant_tools/src/edit_agent/create_file_parser.rs @@ -160,7 +160,7 @@ mod tests { &mut parser, &mut rng ), - // This output is marlformed, so we're doing our best effort + // This output is malformed, so we're doing our best effort "Hello world\n```\n\nThe end\n".to_string() ); } @@ -182,7 +182,7 @@ mod tests { &mut parser, &mut rng ), - // This output is marlformed, so we're doing our best effort + // This output is malformed, so we're doing our best effort "```\nHello world\n```\n".to_string() ); } diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs b/crates/assistant_tools/src/edit_agent/evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs index b51c74c798d88b3f84303ffe41f4ac2590e7f236..cfa28fe1ad6091c9adda22f610e1cf13166f8dfb 100644 --- a/crates/assistant_tools/src/edit_agent/evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs +++ b/crates/assistant_tools/src/edit_agent/evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs @@ -916,7 +916,7 @@ impl Loader { if !found_non_static { found_non_static = true; eprintln!( - "Warning: Found non-static non-tree-sitter functions in the external scannner" + "Warning: Found non-static non-tree-sitter functions in the external scanner" ); } eprintln!(" `{function_name}`"); diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 1605003671621b90e58a5f62e521c0aba2c990c6..8014a39e23137ad71b91e5c24d5d79699b530e5d 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -6,7 +6,7 @@ use action_log::ActionLog; use agent_settings; use anyhow::{Context as _, Result, anyhow}; use assistant_tool::{Tool, ToolCard, ToolResult, ToolUseStatus}; -use futures::{FutureExt as _, future::Shared}; +use futures::FutureExt as _; use gpui::{ AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, TextStyleRefinement, WeakEntity, Window, @@ -26,11 +26,12 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; +use task::{Shell, ShellBuilder}; use terminal_view::TerminalView; use theme::ThemeSettings; use ui::{CommonAnimationExt, Disclosure, Tooltip, prelude::*}; use util::{ - ResultExt, get_system_shell, markdown::MarkdownInlineCode, size::format_file_size, + ResultExt, get_default_system_shell, markdown::MarkdownInlineCode, size::format_file_size, time::duration_alt_display, }; use workspace::Workspace; @@ -45,29 +46,10 @@ pub struct TerminalToolInput { cd: String, } -pub struct TerminalTool { - determine_shell: Shared>, -} +pub struct TerminalTool; impl TerminalTool { pub const NAME: &str = "terminal"; - - pub(crate) fn new(cx: &mut App) -> Self { - let determine_shell = cx.background_spawn(async move { - if cfg!(windows) { - return get_system_shell(); - } - - if which::which("bash").is_ok() { - "bash".into() - } else { - get_system_shell() - } - }); - Self { - determine_shell: determine_shell.shared(), - } - } } impl Tool for TerminalTool { @@ -135,19 +117,6 @@ impl Tool for TerminalTool { Ok(dir) => dir, Err(err) => return Task::ready(Err(err)).into(), }; - let program = self.determine_shell.clone(); - let command = if cfg!(windows) { - format!("$null | & {{{}}}", input.command.replace("\"", "'")) - } else if let Some(cwd) = working_dir - .as_ref() - .and_then(|cwd| cwd.as_os_str().to_str()) - { - // Make sure once we're *inside* the shell, we cd into `cwd` - format!("(cd {cwd}; {}) Task::ready(None).shared(), }; + let remote_shell = project.update(cx, |project, cx| { + project + .remote_client() + .and_then(|r| r.read(cx).default_system_shell()) + }); let env = cx.spawn(async move |_| { let mut env = env.await.unwrap_or_default(); @@ -171,8 +145,13 @@ impl Tool for TerminalTool { let task = cx.background_spawn(async move { let env = env.await; let pty_system = native_pty_system(); - let program = program.await; - let mut cmd = CommandBuilder::new(program); + let (command, args) = ShellBuilder::new( + remote_shell.as_deref(), + &Shell::Program(get_default_system_shell()), + ) + .redirect_stdin_to_dev_null() + .build(Some(input.command.clone()), &[]); + let mut cmd = CommandBuilder::new(command); cmd.args(args); for (k, v) in env { cmd.env(k, v); @@ -208,16 +187,22 @@ impl Tool for TerminalTool { }; }; + let command = input.command.clone(); let terminal = cx.spawn({ let project = project.downgrade(); async move |cx| { - let program = program.await; + let (command, args) = ShellBuilder::new( + remote_shell.as_deref(), + &Shell::Program(get_default_system_shell()), + ) + .redirect_stdin_to_dev_null() + .build(Some(input.command), &[]); let env = env.await; project .update(cx, |project, cx| { project.create_terminal_task( task::SpawnInTerminal { - command: Some(program), + command: Some(command), args, cwd, env, @@ -230,14 +215,8 @@ impl Tool for TerminalTool { } }); - let command_markdown = cx.new(|cx| { - Markdown::new( - format!("```bash\n{}\n```", input.command).into(), - None, - None, - cx, - ) - }); + let command_markdown = + cx.new(|cx| Markdown::new(format!("```bash\n{}\n```", command).into(), None, None, cx)); let card = cx.new(|cx| { TerminalToolCard::new( @@ -288,7 +267,7 @@ impl Tool for TerminalTool { let previous_len = content.len(); let (processed_content, finished_with_empty_output) = process_content( &content, - &input.command, + &command, exit_status.map(portable_pty::ExitStatus::from), ); @@ -740,7 +719,6 @@ mod tests { if cfg!(windows) { return; } - init_test(&executor, cx); let fs = Arc::new(RealFs::new(None, executor)); @@ -763,7 +741,7 @@ mod tests { }; let result = cx.update(|cx| { TerminalTool::run( - Arc::new(TerminalTool::new(cx)), + Arc::new(TerminalTool), serde_json::to_value(input).unwrap(), Arc::default(), project.clone(), @@ -783,7 +761,6 @@ mod tests { if cfg!(windows) { return; } - init_test(&executor, cx); let fs = Arc::new(RealFs::new(None, executor)); @@ -798,7 +775,7 @@ mod tests { let check = |input, expected, cx: &mut App| { let headless_result = TerminalTool::run( - Arc::new(TerminalTool::new(cx)), + Arc::new(TerminalTool), serde_json::to_value(input).unwrap(), Arc::default(), project.clone(), diff --git a/crates/audio/src/audio.rs b/crates/audio/src/audio.rs index ab8d85cdaa6bab7ed1be3fdab8a66b42f883533b..f60ddb87b9615d2da9c2be248ab397c19a463616 100644 --- a/crates/audio/src/audio.rs +++ b/crates/audio/src/audio.rs @@ -211,7 +211,7 @@ impl Audio { agc_source.set_enabled(LIVE_SETTINGS.control_input_volume.load(Ordering::Relaxed)); }) .replayable(REPLAY_DURATION) - .expect("REPLAY_DURATION is longer then 100ms"); + .expect("REPLAY_DURATION is longer than 100ms"); cx.update_default_global(|this: &mut Self, _cx| { let output_mixer = this diff --git a/crates/audio/src/rodio_ext.rs b/crates/audio/src/rodio_ext.rs index ba4e4ff0554dd3c9bc2a7e2691de270c0d00908b..e80b00e15a8fdbd3fc438b78a9ca45d0902dcef1 100644 --- a/crates/audio/src/rodio_ext.rs +++ b/crates/audio/src/rodio_ext.rs @@ -57,7 +57,7 @@ impl RodioExt for S { /// replay is being read /// /// # Errors - /// If duration is smaller then 100ms + /// If duration is smaller than 100ms fn replayable( self, duration: Duration, @@ -151,7 +151,7 @@ impl Source for TakeSamples { struct ReplayQueue { inner: ArrayQueue>, normal_chunk_len: usize, - /// The last chunk in the queue may be smaller then + /// The last chunk in the queue may be smaller than /// the normal chunk size. This is always equal to the /// size of the last element in the queue. /// (so normally chunk_size) @@ -535,7 +535,7 @@ mod tests { let (mut replay, mut source) = input .replayable(Duration::from_secs(3)) - .expect("longer then 100ms"); + .expect("longer than 100ms"); source.by_ref().take(3).count(); let yielded: Vec = replay.by_ref().take(3).collect(); @@ -552,7 +552,7 @@ mod tests { let (mut replay, mut source) = input .replayable(Duration::from_secs(2)) - .expect("longer then 100ms"); + .expect("longer than 100ms"); source.by_ref().take(5).count(); // get all items but do not end the source let yielded: Vec = replay.by_ref().take(2).collect(); @@ -567,7 +567,7 @@ mod tests { let (replay, mut source) = input .replayable(Duration::from_secs(2)) - .expect("longer then 100ms"); + .expect("longer than 100ms"); // exhaust but do not yet end source source.by_ref().take(40_000).count(); @@ -586,7 +586,7 @@ mod tests { let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]); let (mut replay, source) = input .replayable(Duration::from_secs(2)) - .expect("longer then 100ms"); + .expect("longer than 100ms"); assert_eq!(replay.by_ref().samples_ready(), 0); source.take(8000).count(); // half a second diff --git a/crates/auto_update/Cargo.toml b/crates/auto_update/Cargo.toml index 1a772710c98f8437932d6e8918df65d003d7962e..35cef84f0366b16d9e31b2416e7a6a10173ff5ef 100644 --- a/crates/auto_update/Cargo.toml +++ b/crates/auto_update/Cargo.toml @@ -32,3 +32,6 @@ workspace-hack.workspace = true [target.'cfg(not(target_os = "windows"))'.dependencies] which.workspace = true + +[dev-dependencies] +gpui = { workspace = true, "features" = ["test-support"] } diff --git a/crates/auto_update/src/auto_update.rs b/crates/auto_update/src/auto_update.rs index f5b211bf8f7099c6e29f5a5e68c49e335426b668..4e0348575e687c2b4e36fcde7df83b8f329733d0 100644 --- a/crates/auto_update/src/auto_update.rs +++ b/crates/auto_update/src/auto_update.rs @@ -33,7 +33,7 @@ actions!( /// Checks for available updates. Check, /// Dismisses the update error message. - DismissErrorMessage, + DismissMessage, /// Opens the release notes for the current version in a browser. ViewReleaseNotes, ] @@ -54,14 +54,14 @@ pub enum VersionCheckType { Semantic(SemanticVersion), } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone)] pub enum AutoUpdateStatus { Idle, Checking, Downloading { version: VersionCheckType }, Installing { version: VersionCheckType }, Updated { version: VersionCheckType }, - Errored, + Errored { error: Arc }, } impl AutoUpdateStatus { @@ -362,7 +362,9 @@ impl AutoUpdater { } UpdateCheckType::Manual => { log::error!("auto-update failed: error:{:?}", error); - AutoUpdateStatus::Errored + AutoUpdateStatus::Errored { + error: Arc::new(error), + } } }; @@ -381,8 +383,8 @@ impl AutoUpdater { self.status.clone() } - pub fn dismiss_error(&mut self, cx: &mut Context) -> bool { - if self.status == AutoUpdateStatus::Idle { + pub fn dismiss(&mut self, cx: &mut Context) -> bool { + if let AutoUpdateStatus::Idle = self.status { return false; } self.status = AutoUpdateStatus::Idle; @@ -971,8 +973,27 @@ pub fn finalize_auto_update_on_quit() { #[cfg(test)] mod tests { + use gpui::TestAppContext; + use settings::default_settings; + use super::*; + #[gpui::test] + fn test_auto_update_defaults_to_true(cx: &mut TestAppContext) { + cx.update(|cx| { + let mut store = SettingsStore::new(cx, &settings::default_settings()); + store + .set_default_settings(&default_settings(), cx) + .expect("Unable to set default settings"); + store + .set_user_settings("{}", cx) + .expect("Unable to set user settings"); + cx.set_global(store); + AutoUpdateSetting::register(cx); + assert!(AutoUpdateSetting::get_global(cx).0); + }); + } + #[test] fn test_stable_does_not_update_when_fetched_version_is_not_higher() { let release_channel = ReleaseChannel::Stable; diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index d4b4a350f61b5bd1249b33ff3925dd281e9d529c..d0d30f72d7aea7d7f6cf0355caf12b1f2a36eedb 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -339,59 +339,70 @@ fn main() -> Result<()> { "Dev servers were removed in v0.157.x please upgrade to SSH remoting: https://zed.dev/docs/remote-development" ); - let sender: JoinHandle> = thread::spawn({ - let exit_status = exit_status.clone(); - let user_data_dir_for_thread = user_data_dir.clone(); - move || { - let (_, handshake) = server.accept().context("Handshake after Zed spawn")?; - let (tx, rx) = (handshake.requests, handshake.responses); - - #[cfg(target_os = "windows")] - let wsl = args.wsl; - #[cfg(not(target_os = "windows"))] - let wsl = None; - - tx.send(CliRequest::Open { - paths, - urls, - diff_paths, - wsl, - wait: args.wait, - open_new_workspace, - env, - user_data_dir: user_data_dir_for_thread, - })?; - - while let Ok(response) = rx.recv() { - match response { - CliResponse::Ping => {} - CliResponse::Stdout { message } => println!("{message}"), - CliResponse::Stderr { message } => eprintln!("{message}"), - CliResponse::Exit { status } => { - exit_status.lock().replace(status); - return Ok(()); + let sender: JoinHandle> = thread::Builder::new() + .name("CliReceiver".to_string()) + .spawn({ + let exit_status = exit_status.clone(); + let user_data_dir_for_thread = user_data_dir.clone(); + move || { + let (_, handshake) = server.accept().context("Handshake after Zed spawn")?; + let (tx, rx) = (handshake.requests, handshake.responses); + + #[cfg(target_os = "windows")] + let wsl = args.wsl; + #[cfg(not(target_os = "windows"))] + let wsl = None; + + tx.send(CliRequest::Open { + paths, + urls, + diff_paths, + wsl, + wait: args.wait, + open_new_workspace, + env, + user_data_dir: user_data_dir_for_thread, + })?; + + while let Ok(response) = rx.recv() { + match response { + CliResponse::Ping => {} + CliResponse::Stdout { message } => println!("{message}"), + CliResponse::Stderr { message } => eprintln!("{message}"), + CliResponse::Exit { status } => { + exit_status.lock().replace(status); + return Ok(()); + } } } - } - Ok(()) - } - }); + Ok(()) + } + }) + .unwrap(); let stdin_pipe_handle: Option>> = stdin_tmp_file.map(|mut tmp_file| { - thread::spawn(move || { - let mut stdin = std::io::stdin().lock(); - if !io::IsTerminal::is_terminal(&stdin) { - io::copy(&mut stdin, &mut tmp_file)?; - } - Ok(()) - }) + thread::Builder::new() + .name("CliStdin".to_string()) + .spawn(move || { + let mut stdin = std::io::stdin().lock(); + if !io::IsTerminal::is_terminal(&stdin) { + io::copy(&mut stdin, &mut tmp_file)?; + } + Ok(()) + }) + .unwrap() }); let anonymous_fd_pipe_handles: Vec<_> = anonymous_fd_tmp_files .into_iter() - .map(|(mut file, mut tmp_file)| thread::spawn(move || io::copy(&mut file, &mut tmp_file))) + .map(|(mut file, mut tmp_file)| { + thread::Builder::new() + .name("CliAnonymousFd".to_string()) + .spawn(move || io::copy(&mut file, &mut tmp_file)) + .unwrap() + }) .collect(); if args.foreground { diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index a1ac476bbba40d97d611c3016c0f06a6cb08f2ae..237eaa11d954db7c95eaa513e8e921a1000faac6 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -22,7 +22,7 @@ use futures::{ channel::oneshot, future::BoxFuture, }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; -use http_client::{HttpClient, HttpClientWithUrl, http}; +use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env}; use parking_lot::RwLock; use postage::watch; use proxy::connect_proxy_stream; @@ -136,6 +136,20 @@ pub struct ProxySettings { pub proxy: Option, } +impl ProxySettings { + pub fn proxy_url(&self) -> Option { + self.proxy + .as_ref() + .and_then(|input| { + input + .parse::() + .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e)) + .ok() + }) + .or_else(read_proxy_from_env) + } +} + impl Settings for ProxySettings { fn from_defaults(content: &settings::SettingsContent, _cx: &mut App) -> Self { Self { diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 63626e8ce1f3b25c742f227a56556545762367c3..de0668b406c512eabfc70f4702466f013eb8c515 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -754,6 +754,10 @@ impl UserStore { } pub fn model_request_usage(&self) -> Option { + if self.plan().is_some_and(|plan| plan.is_v2()) { + return None; + } + self.model_request_usage } diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 16267d86d806387140016dc0a25021ad92607ff2..24923a318441afeaa2521064b4f433ab9ee1e55f 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -39,7 +39,7 @@ pub const EDIT_PREDICTIONS_RESOURCE_HEADER_VALUE: &str = "edit_predictions"; /// The name of the header used to indicate that the maximum number of consecutive tool uses has been reached. pub const TOOL_USE_LIMIT_REACHED_HEADER_NAME: &str = "x-zed-tool-use-limit-reached"; -/// The name of the header used to indicate the the minimum required Zed version. +/// The name of the header used to indicate the minimum required Zed version. /// /// This can be used to force a Zed upgrade in order to continue communicating /// with the LLM service. @@ -321,8 +321,8 @@ pub struct LanguageModel { #[derive(Debug, Serialize, Deserialize)] pub struct ListModelsResponse { pub models: Vec, - pub default_model: LanguageModelId, - pub default_fast_model: LanguageModelId, + pub default_model: Option, + pub default_fast_model: Option, pub recommended_models: Vec, } diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 214b550ac20499b8b03cfafeefab9b45d51fcc24..1476e5890283c62cee3563a327fcdd5ee84842e7 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -226,12 +226,6 @@ spec: secretKeyRef: name: supermaven key: api_key - - name: USER_BACKFILLER_GITHUB_ACCESS_TOKEN - valueFrom: - secretKeyRef: - name: user-backfiller - key: github_access_token - optional: true - name: INVITE_LINK_PREFIX value: ${INVITE_LINK_PREFIX} - name: RUST_BACKTRACE diff --git a/crates/collab/src/db/queries/users.rs b/crates/collab/src/db/queries/users.rs index 4b0f66fcbe09d23af58b0a30ffebf68455651fd8..89211130b88c69d4bf524bba25ae116790321d3e 100644 --- a/crates/collab/src/db/queries/users.rs +++ b/crates/collab/src/db/queries/users.rs @@ -342,79 +342,6 @@ impl Database { result } - /// Returns all feature flags. - pub async fn list_feature_flags(&self) -> Result> { - self.transaction(|tx| async move { Ok(feature_flag::Entity::find().all(&*tx).await?) }) - .await - } - - /// Creates a new feature flag. - pub async fn create_user_flag(&self, flag: &str, enabled_for_all: bool) -> Result { - self.transaction(|tx| async move { - let flag = feature_flag::Entity::insert(feature_flag::ActiveModel { - flag: ActiveValue::set(flag.to_string()), - enabled_for_all: ActiveValue::set(enabled_for_all), - ..Default::default() - }) - .exec(&*tx) - .await? - .last_insert_id; - - Ok(flag) - }) - .await - } - - /// Add the given user to the feature flag - pub async fn add_user_flag(&self, user: UserId, flag: FlagId) -> Result<()> { - self.transaction(|tx| async move { - user_feature::Entity::insert(user_feature::ActiveModel { - user_id: ActiveValue::set(user), - feature_id: ActiveValue::set(flag), - }) - .exec(&*tx) - .await?; - - Ok(()) - }) - .await - } - - /// Returns the active flags for the user. - pub async fn get_user_flags(&self, user: UserId) -> Result> { - self.transaction(|tx| async move { - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryAs { - Flag, - } - - let flags_enabled_for_all = feature_flag::Entity::find() - .filter(feature_flag::Column::EnabledForAll.eq(true)) - .select_only() - .column(feature_flag::Column::Flag) - .into_values::<_, QueryAs>() - .all(&*tx) - .await?; - - let flags_enabled_for_user = user::Model { - id: user, - ..Default::default() - } - .find_linked(user::UserFlags) - .select_only() - .column(feature_flag::Column::Flag) - .into_values::<_, QueryAs>() - .all(&*tx) - .await?; - - let mut all_flags = HashSet::from_iter(flags_enabled_for_all); - all_flags.extend(flags_enabled_for_user); - - Ok(all_flags.into_iter().collect()) - }) - .await - } - pub async fn get_users_missing_github_user_created_at(&self) -> Result> { self.transaction(|tx| async move { Ok(user::Entity::find() diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index 0082a9fb030a27e4be13af725f08ea9c82217377..32c4570af5893b503f0fcfdaa1759616cf9be387 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -13,7 +13,6 @@ pub mod contributor; pub mod embedding; pub mod extension; pub mod extension_version; -pub mod feature_flag; pub mod follower; pub mod language_server; pub mod notification; @@ -29,7 +28,6 @@ pub mod room_participant; pub mod server; pub mod signup; pub mod user; -pub mod user_feature; pub mod worktree; pub mod worktree_diagnostic_summary; pub mod worktree_entry; diff --git a/crates/collab/src/db/tables/feature_flag.rs b/crates/collab/src/db/tables/feature_flag.rs deleted file mode 100644 index 5bbfedd71e70b7f1cc58219475c49c28bc62ff3d..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/tables/feature_flag.rs +++ /dev/null @@ -1,41 +0,0 @@ -use sea_orm::entity::prelude::*; - -use crate::db::FlagId; - -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "feature_flags")] -pub struct Model { - #[sea_orm(primary_key)] - pub id: FlagId, - pub flag: String, - pub enabled_for_all: bool, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm(has_many = "super::user_feature::Entity")] - UserFeature, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::UserFeature.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} - -pub struct FlaggedUsers; - -impl Linked for FlaggedUsers { - type FromEntity = Entity; - - type ToEntity = super::user::Entity; - - fn link(&self) -> Vec { - vec![ - super::user_feature::Relation::Flag.def().rev(), - super::user_feature::Relation::User.def(), - ] - } -} diff --git a/crates/collab/src/db/tables/user.rs b/crates/collab/src/db/tables/user.rs index af43fe300a6cc1224487541ca72af9d887a6fae3..8e8c03fafc92127f8754f473e04dfab39592ea14 100644 --- a/crates/collab/src/db/tables/user.rs +++ b/crates/collab/src/db/tables/user.rs @@ -35,8 +35,6 @@ pub enum Relation { HostedProjects, #[sea_orm(has_many = "super::channel_member::Entity")] ChannelMemberships, - #[sea_orm(has_many = "super::user_feature::Entity")] - UserFeatures, #[sea_orm(has_one = "super::contributor::Entity")] Contributor, } @@ -84,25 +82,4 @@ impl Related for Entity { } } -impl Related for Entity { - fn to() -> RelationDef { - Relation::UserFeatures.def() - } -} - impl ActiveModelBehavior for ActiveModel {} - -pub struct UserFlags; - -impl Linked for UserFlags { - type FromEntity = Entity; - - type ToEntity = super::feature_flag::Entity; - - fn link(&self) -> Vec { - vec![ - super::user_feature::Relation::User.def().rev(), - super::user_feature::Relation::Flag.def(), - ] - } -} diff --git a/crates/collab/src/db/tables/user_feature.rs b/crates/collab/src/db/tables/user_feature.rs deleted file mode 100644 index cc24b5e796342f7733f59933362d46a0df2be112..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/tables/user_feature.rs +++ /dev/null @@ -1,42 +0,0 @@ -use sea_orm::entity::prelude::*; - -use crate::db::{FlagId, UserId}; - -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "user_features")] -pub struct Model { - #[sea_orm(primary_key)] - pub user_id: UserId, - #[sea_orm(primary_key)] - pub feature_id: FlagId, -} - -#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation { - #[sea_orm( - belongs_to = "super::feature_flag::Entity", - from = "Column::FeatureId", - to = "super::feature_flag::Column::Id" - )] - Flag, - #[sea_orm( - belongs_to = "super::user::Entity", - from = "Column::UserId", - to = "super::user::Column::Id" - )] - User, -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::Flag.def() - } -} - -impl Related for Entity { - fn to() -> RelationDef { - Relation::User.def() - } -} - -impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 25e03f1320a25455ede347b43477761d591fbd57..141262d5e94a4bf1d4d897e78f6281ab9ee3ccfc 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -6,7 +6,6 @@ mod db_tests; #[cfg(target_os = "macos")] mod embedding_tests; mod extension_tests; -mod feature_flag_tests; mod user_tests; use crate::migrations::run_database_migrations; diff --git a/crates/collab/src/db/tests/feature_flag_tests.rs b/crates/collab/src/db/tests/feature_flag_tests.rs deleted file mode 100644 index 0e68dcc941cdb2488c3822548dada56746667bc2..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/tests/feature_flag_tests.rs +++ /dev/null @@ -1,66 +0,0 @@ -use crate::{ - db::{Database, NewUserParams}, - test_both_dbs, -}; -use pretty_assertions::assert_eq; -use std::sync::Arc; - -test_both_dbs!( - test_get_user_flags, - test_get_user_flags_postgres, - test_get_user_flags_sqlite -); - -async fn test_get_user_flags(db: &Arc) { - let user_1 = db - .create_user( - "user1@example.com", - None, - false, - NewUserParams { - github_login: "user1".to_string(), - github_user_id: 1, - }, - ) - .await - .unwrap() - .user_id; - - let user_2 = db - .create_user( - "user2@example.com", - None, - false, - NewUserParams { - github_login: "user2".to_string(), - github_user_id: 2, - }, - ) - .await - .unwrap() - .user_id; - - const FEATURE_FLAG_ONE: &str = "brand-new-ux"; - const FEATURE_FLAG_TWO: &str = "cool-feature"; - const FEATURE_FLAG_THREE: &str = "feature-enabled-for-everyone"; - - let feature_flag_one = db.create_user_flag(FEATURE_FLAG_ONE, false).await.unwrap(); - let feature_flag_two = db.create_user_flag(FEATURE_FLAG_TWO, false).await.unwrap(); - db.create_user_flag(FEATURE_FLAG_THREE, true).await.unwrap(); - - db.add_user_flag(user_1, feature_flag_one).await.unwrap(); - db.add_user_flag(user_1, feature_flag_two).await.unwrap(); - - db.add_user_flag(user_2, feature_flag_one).await.unwrap(); - - let mut user_1_flags = db.get_user_flags(user_1).await.unwrap(); - user_1_flags.sort(); - assert_eq!( - user_1_flags, - &[FEATURE_FLAG_ONE, FEATURE_FLAG_TWO, FEATURE_FLAG_THREE] - ); - - let mut user_2_flags = db.get_user_flags(user_2).await.unwrap(); - user_2_flags.sort(); - assert_eq!(user_2_flags, &[FEATURE_FLAG_ONE, FEATURE_FLAG_THREE]); -} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 191025df3770db78df3a12bc16d5c8f32d54571c..f1de0cdc7ff79cd25c8ef7b0b2b21d9e0b45d332 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -7,7 +7,6 @@ pub mod llm; pub mod migrations; pub mod rpc; pub mod seed; -pub mod user_backfiller; #[cfg(test)] mod tests; @@ -157,7 +156,6 @@ pub struct Config { pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, pub supermaven_admin_api_key: Option>, - pub user_backfiller_github_access_token: Option>, } impl Config { @@ -211,7 +209,6 @@ impl Config { migrations_path: None, seed_path: None, supermaven_admin_api_key: None, - user_backfiller_github_access_token: None, kinesis_region: None, kinesis_access_key: None, kinesis_secret_key: None, diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index cb6f6cad1dd483c463bcda5d8a4ff914f4bf10aa..6b94459910647c1e48ee69f2b0dd38afd3723821 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -11,7 +11,6 @@ use collab::ServiceMode; use collab::api::CloudflareIpCountryHeader; use collab::llm::db::LlmDatabase; use collab::migrations::run_database_migrations; -use collab::user_backfiller::spawn_user_backfiller; use collab::{ AppState, Config, Result, api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, rpc::ResultExt, @@ -114,7 +113,6 @@ async fn main() -> Result<()> { if mode.is_api() { fetch_extensions_from_blob_store_periodically(state.clone()); - spawn_user_backfiller(state.clone()); app = app .merge(collab::api::events::router()) diff --git a/crates/collab/src/seed.rs b/crates/collab/src/seed.rs index 2d070b30abada79dc177b2b600d9ecc40aa472e1..5f5779e1e4990d1a03461bb3ec2222e82d9f544e 100644 --- a/crates/collab/src/seed.rs +++ b/crates/collab/src/seed.rs @@ -46,27 +46,6 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result let mut first_user = None; let mut others = vec![]; - let flag_names = ["language-models"]; - let mut flags = Vec::new(); - - let existing_feature_flags = db.list_feature_flags().await?; - - for flag_name in flag_names { - if existing_feature_flags - .iter() - .any(|flag| flag.flag == flag_name) - { - log::info!("Flag {flag_name:?} already exists"); - continue; - } - - let flag = db - .create_user_flag(flag_name, false) - .await - .unwrap_or_else(|err| panic!("failed to create flag: '{flag_name}': {err}")); - flags.push(flag); - } - for admin_login in seed_config.admins { let user = fetch_github::( &client, @@ -90,15 +69,6 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result } else { others.push(user.user_id) } - - for flag in &flags { - db.add_user_flag(user.user_id, *flag) - .await - .context(format!( - "Unable to enable flag '{}' for user '{}'", - flag, user.user_id - ))?; - } } for channel in seed_config.channels { @@ -126,24 +96,16 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result for github_user in github_users { log::info!("Seeding {:?} from GitHub", github_user.login); - let user = db - .update_or_create_user_by_github_account( - &github_user.login, - github_user.id, - github_user.email.as_deref(), - github_user.name.as_deref(), - github_user.created_at, - None, - ) - .await - .expect("failed to insert user"); - - for flag in &flags { - db.add_user_flag(user.id, *flag).await.context(format!( - "Unable to enable flag '{}' for user '{}'", - flag, user.id - ))?; - } + db.update_or_create_user_by_github_account( + &github_user.login, + github_user.id, + github_user.email.as_deref(), + github_user.name.as_deref(), + github_user.created_at, + None, + ) + .await + .expect("failed to insert user"); } Ok(()) diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index eb7df28478158a10a0c2d52c3560cad391937383..5e99cc192ad080c1a79913c79fbbaae9d8b6d951 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -604,7 +604,6 @@ impl TestServer { migrations_path: None, seed_path: None, supermaven_admin_api_key: None, - user_backfiller_github_access_token: None, kinesis_region: None, kinesis_stream: None, kinesis_access_key: None, diff --git a/crates/collab/src/user_backfiller.rs b/crates/collab/src/user_backfiller.rs deleted file mode 100644 index fdb9ef67c2f1d04bf0a1919045f91d75a14ef834..0000000000000000000000000000000000000000 --- a/crates/collab/src/user_backfiller.rs +++ /dev/null @@ -1,165 +0,0 @@ -use std::sync::Arc; - -use anyhow::{Context as _, Result}; -use chrono::{DateTime, Utc}; -use util::ResultExt; - -use crate::db::Database; -use crate::executor::Executor; -use crate::{AppState, Config}; - -pub fn spawn_user_backfiller(app_state: Arc) { - let Some(user_backfiller_github_access_token) = - app_state.config.user_backfiller_github_access_token.clone() - else { - log::info!("no USER_BACKFILLER_GITHUB_ACCESS_TOKEN set; not spawning user backfiller"); - return; - }; - - let executor = app_state.executor.clone(); - executor.spawn_detached({ - let executor = executor.clone(); - async move { - let user_backfiller = UserBackfiller::new( - app_state.config.clone(), - user_backfiller_github_access_token, - app_state.db.clone(), - executor, - ); - - log::info!("backfilling users"); - - user_backfiller - .backfill_github_user_created_at() - .await - .log_err(); - } - }); -} - -const GITHUB_REQUESTS_PER_HOUR_LIMIT: usize = 5_000; -const SLEEP_DURATION_BETWEEN_USERS: std::time::Duration = std::time::Duration::from_millis( - (GITHUB_REQUESTS_PER_HOUR_LIMIT as f64 / 60. / 60. * 1000.) as u64, -); - -struct UserBackfiller { - config: Config, - github_access_token: Arc, - db: Arc, - http_client: reqwest::Client, - executor: Executor, -} - -impl UserBackfiller { - fn new( - config: Config, - github_access_token: Arc, - db: Arc, - executor: Executor, - ) -> Self { - Self { - config, - github_access_token, - db, - http_client: reqwest::Client::new(), - executor, - } - } - - async fn backfill_github_user_created_at(&self) -> Result<()> { - let initial_channel_id = self.config.auto_join_channel_id; - - let users_missing_github_user_created_at = - self.db.get_users_missing_github_user_created_at().await?; - - for user in users_missing_github_user_created_at { - match self - .fetch_github_user(&format!( - "https://api.github.com/user/{}", - user.github_user_id - )) - .await - { - Ok(github_user) => { - self.db - .update_or_create_user_by_github_account( - &user.github_login, - github_user.id, - user.email_address.as_deref(), - user.name.as_deref(), - github_user.created_at, - initial_channel_id, - ) - .await?; - - log::info!("backfilled user: {}", user.github_login); - } - Err(err) => { - log::error!("failed to fetch GitHub user {}: {err}", user.github_login); - } - } - - self.executor.sleep(SLEEP_DURATION_BETWEEN_USERS).await; - } - - Ok(()) - } - - async fn fetch_github_user(&self, url: &str) -> Result { - let response = self - .http_client - .get(url) - .header( - "authorization", - format!("Bearer {}", self.github_access_token), - ) - .header("user-agent", "zed") - .send() - .await - .with_context(|| format!("failed to fetch '{url}'"))?; - - let rate_limit_remaining = response - .headers() - .get("x-ratelimit-remaining") - .and_then(|value| value.to_str().ok()) - .and_then(|value| value.parse::().ok()); - let rate_limit_reset = response - .headers() - .get("x-ratelimit-reset") - .and_then(|value| value.to_str().ok()) - .and_then(|value| value.parse::().ok()) - .and_then(|value| DateTime::from_timestamp(value, 0)); - - if rate_limit_remaining == Some(0) - && let Some(reset_at) = rate_limit_reset - { - let now = Utc::now(); - if reset_at > now { - let sleep_duration = reset_at - now; - log::info!( - "rate limit reached. Sleeping for {} seconds", - sleep_duration.num_seconds() - ); - self.executor.sleep(sleep_duration.to_std().unwrap()).await; - } - } - - response - .error_for_status() - .context("fetching GitHub user")? - .json() - .await - .with_context(|| format!("failed to deserialize GitHub user from '{url}'")) - } -} - -#[derive(serde::Deserialize)] -struct GithubUser { - id: i32, - created_at: DateTime, - #[expect( - unused, - reason = "This field was found to be unused with serde library bump; it's left as is due to insufficient context on PO's side, but it *may* be fine to remove" - )] - name: Option, -} diff --git a/crates/crashes/src/crashes.rs b/crates/crashes/src/crashes.rs index 98db4bfc73f157994f3f7286c0764cfb0778e4a4..8312638e2a811767ee245f53c356eca15ef852f1 100644 --- a/crates/crashes/src/crashes.rs +++ b/crates/crashes/src/crashes.rs @@ -321,16 +321,19 @@ pub fn crash_server(socket: &Path) { let shutdown = Arc::new(AtomicBool::new(false)); let has_connection = Arc::new(AtomicBool::new(false)); - std::thread::spawn({ - let shutdown = shutdown.clone(); - let has_connection = has_connection.clone(); - move || { - std::thread::sleep(CRASH_HANDLER_CONNECT_TIMEOUT); - if !has_connection.load(Ordering::SeqCst) { - shutdown.store(true, Ordering::SeqCst); + thread::Builder::new() + .name("CrashServerTimeout".to_owned()) + .spawn({ + let shutdown = shutdown.clone(); + let has_connection = has_connection.clone(); + move || { + std::thread::sleep(CRASH_HANDLER_CONNECT_TIMEOUT); + if !has_connection.load(Ordering::SeqCst) { + shutdown.store(true, Ordering::SeqCst); + } } - } - }); + }) + .unwrap(); server .run( diff --git a/crates/denoise/Cargo.toml b/crates/denoise/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..a2f43cdfee72a64fbd7e6e60b9414c691c3adfcd --- /dev/null +++ b/crates/denoise/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "denoise" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[dependencies] +candle-core = { version = "0.9.1", git ="https://github.com/zed-industries/candle", branch = "9.1-patched" } +candle-onnx = { version = "0.9.1", git ="https://github.com/zed-industries/candle", branch = "9.1-patched" } +log.workspace = true + +rodio = { workspace = true, features = ["wav_output"] } + +rustfft = { version = "6.2.0", features = ["avx"] } +realfft = "3.4.0" +thiserror.workspace = true +workspace-hack.workspace = true diff --git a/crates/denoise/LICENSE-GPL b/crates/denoise/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/denoise/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/denoise/README.md b/crates/denoise/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d7486da36e9078f2f99c2a6a8226dbce499cae8b --- /dev/null +++ b/crates/denoise/README.md @@ -0,0 +1,20 @@ +Real time streaming audio denoising using a [Dual-Signal Transformation LSTM Network for Real-Time Noise Suppression](https://arxiv.org/abs/2005.07551). + +Trivial to build as it uses the native rust Candle crate for inference. Easy to integrate into any Rodio pipeline. + +```rust + # use rodio::{nz, source::UniformSourceIterator, wav_to_file}; + let file = std::fs::File::open("clips_airconditioning.wav")?; + let decoder = rodio::Decoder::try_from(file)?; + let resampled = UniformSourceIterator::new(decoder, nz!(1), nz!(16_000)); + + let mut denoised = denoise::Denoiser::try_new(resampled)?; + wav_to_file(&mut denoised, "denoised.wav")?; + Result::Ok<(), Box> +``` + +## Acknowledgements & License + +The trained models in this repo are optimized versions of the models in the [breizhn/DTLN](https://github.com/breizhn/DTLN?tab=readme-ov-file#model-conversion-and-real-time-processing-with-onnx). These are licensed under MIT. + +The FFT code was adapted from Datadog's [dtln-rs Repo](https://github.com/DataDog/dtln-rs/tree/main) also licensed under MIT. diff --git a/crates/denoise/examples/denoise.rs b/crates/denoise/examples/denoise.rs new file mode 100644 index 0000000000000000000000000000000000000000..a4d89d7e517e7b35d0f87adbd218cd34b75a4789 --- /dev/null +++ b/crates/denoise/examples/denoise.rs @@ -0,0 +1,11 @@ +use rodio::{nz, source::UniformSourceIterator, wav_to_file}; + +fn main() -> Result<(), Box> { + let file = std::fs::File::open("airconditioning.wav")?; + let decoder = rodio::Decoder::try_from(file)?; + let resampled = UniformSourceIterator::new(decoder, nz!(1), nz!(16_000)); + + let mut denoised = denoise::Denoiser::try_new(resampled)?; + wav_to_file(&mut denoised, "denoised.wav")?; + Ok(()) +} diff --git a/crates/denoise/examples/enable_disable.rs b/crates/denoise/examples/enable_disable.rs new file mode 100644 index 0000000000000000000000000000000000000000..1cffadbce2b0e58cdf56b291cb68a13fc6556b22 --- /dev/null +++ b/crates/denoise/examples/enable_disable.rs @@ -0,0 +1,23 @@ +use std::time::Duration; + +use rodio::Source; +use rodio::wav_to_file; +use rodio::{nz, source::UniformSourceIterator}; + +fn main() -> Result<(), Box> { + let file = std::fs::File::open("clips_airconditioning.wav")?; + let decoder = rodio::Decoder::try_from(file)?; + let resampled = UniformSourceIterator::new(decoder, nz!(1), nz!(16_000)); + + let mut enabled = true; + let denoised = denoise::Denoiser::try_new(resampled)?.periodic_access( + Duration::from_secs(2), + |denoised| { + enabled = !enabled; + denoised.set_enabled(enabled); + }, + ); + + wav_to_file(denoised, "processed.wav")?; + Ok(()) +} diff --git a/crates/denoise/models/model_1_converted_simplified.onnx b/crates/denoise/models/model_1_converted_simplified.onnx new file mode 100644 index 0000000000000000000000000000000000000000..821cb73bd76b1470c0ee814d07bd03c47a613643 Binary files /dev/null and b/crates/denoise/models/model_1_converted_simplified.onnx differ diff --git a/crates/denoise/models/model_2_converted_simplified.onnx b/crates/denoise/models/model_2_converted_simplified.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a83023ab22748fb60f8186c3c5b0161531337082 Binary files /dev/null and b/crates/denoise/models/model_2_converted_simplified.onnx differ diff --git a/crates/denoise/src/engine.rs b/crates/denoise/src/engine.rs new file mode 100644 index 0000000000000000000000000000000000000000..5196b70b5ba02f665385c022a0dfa9cd22c1db9c --- /dev/null +++ b/crates/denoise/src/engine.rs @@ -0,0 +1,204 @@ +/// use something like https://netron.app/ to inspect the models and understand +/// the flow +use std::collections::HashMap; + +use candle_core::{Device, IndexOp, Tensor}; +use candle_onnx::onnx::ModelProto; +use candle_onnx::prost::Message; +use realfft::RealFftPlanner; +use rustfft::num_complex::Complex; + +pub struct Engine { + spectral_model: ModelProto, + signal_model: ModelProto, + + fft_planner: RealFftPlanner, + fft_scratch: Vec>, + spectrum: [Complex; FFT_OUT_SIZE], + signal: [f32; BLOCK_LEN], + + in_magnitude: [f32; FFT_OUT_SIZE], + in_phase: [f32; FFT_OUT_SIZE], + + spectral_memory: Tensor, + signal_memory: Tensor, + + in_buffer: [f32; BLOCK_LEN], + out_buffer: [f32; BLOCK_LEN], +} + +// 32 ms @ 16khz per DTLN docs: https://github.com/breizhn/DTLN +pub const BLOCK_LEN: usize = 512; +// 8 ms @ 16khz per DTLN docs. +pub const BLOCK_SHIFT: usize = 128; +pub const FFT_OUT_SIZE: usize = BLOCK_LEN / 2 + 1; + +impl Engine { + pub fn new() -> Self { + let mut fft_planner = RealFftPlanner::new(); + let fft_planned = fft_planner.plan_fft_forward(BLOCK_LEN); + let scratch_len = fft_planned.get_scratch_len(); + Self { + // Models are 1.5MB and 2.5MB respectively. Its worth the binary + // size increase not to have to distribute the models separately. + spectral_model: ModelProto::decode( + include_bytes!("../models/model_1_converted_simplified.onnx").as_slice(), + ) + .expect("The model should decode"), + signal_model: ModelProto::decode( + include_bytes!("../models/model_2_converted_simplified.onnx").as_slice(), + ) + .expect("The model should decode"), + fft_planner, + fft_scratch: vec![Complex::ZERO; scratch_len], + spectrum: [Complex::ZERO; FFT_OUT_SIZE], + signal: [0f32; BLOCK_LEN], + + in_magnitude: [0f32; FFT_OUT_SIZE], + in_phase: [0f32; FFT_OUT_SIZE], + + spectral_memory: Tensor::from_slice::<_, f32>( + &[0f32; 512], + (1, 2, BLOCK_SHIFT, 2), + &Device::Cpu, + ) + .expect("Tensor has the correct dimensions"), + signal_memory: Tensor::from_slice::<_, f32>( + &[0f32; 512], + (1, 2, BLOCK_SHIFT, 2), + &Device::Cpu, + ) + .expect("Tensor has the correct dimensions"), + out_buffer: [0f32; BLOCK_LEN], + in_buffer: [0f32; BLOCK_LEN], + } + } + + /// Add a clunk of samples and get the denoised chunk 4 feeds later + pub fn feed(&mut self, samples: &[f32]) -> [f32; BLOCK_SHIFT] { + /// The name of the output node of the onnx network + /// [Dual-Signal Transformation LSTM Network for Real-Time Noise Suppression](https://arxiv.org/abs/2005.07551). + const MEMORY_OUTPUT: &'static str = "Identity_1"; + + debug_assert_eq!(samples.len(), BLOCK_SHIFT); + + // place new samples at the end of the `in_buffer` + self.in_buffer.copy_within(BLOCK_SHIFT.., 0); + self.in_buffer[(BLOCK_LEN - BLOCK_SHIFT)..].copy_from_slice(&samples); + + // run inference + let inputs = self.spectral_inputs(); + let mut spectral_outputs = candle_onnx::simple_eval(&self.spectral_model, inputs) + .expect("The embedded file must be valid"); + self.spectral_memory = spectral_outputs + .remove(MEMORY_OUTPUT) + .expect("The model has an output named Identity_1"); + let inputs = self.signal_inputs(spectral_outputs); + let mut signal_outputs = candle_onnx::simple_eval(&self.signal_model, inputs) + .expect("The embedded file must be valid"); + self.signal_memory = signal_outputs + .remove(MEMORY_OUTPUT) + .expect("The model has an output named Identity_1"); + let model_output = model_outputs(signal_outputs); + + // place processed samples at the start of the `out_buffer` + // shift the rest left, fill the end with zeros. Zeros are needed as + // the out buffer is part of the input of the network + self.out_buffer.copy_within(BLOCK_SHIFT.., 0); + self.out_buffer[BLOCK_LEN - BLOCK_SHIFT..].fill(0f32); + for (a, b) in self.out_buffer.iter_mut().zip(model_output) { + *a += b; + } + + // samples at the front of the `out_buffer` are now denoised + self.out_buffer[..BLOCK_SHIFT] + .try_into() + .expect("len is correct") + } + + fn spectral_inputs(&mut self) -> HashMap { + // Prepare FFT input + let fft = self.fft_planner.plan_fft_forward(BLOCK_LEN); + + // Perform real-to-complex FFT + let mut fft_in = self.in_buffer; + fft.process_with_scratch(&mut fft_in, &mut self.spectrum, &mut self.fft_scratch) + .expect("The fft should run, there is enough scratch space"); + + // Generate magnitude and phase + for ((magnitude, phase), complex) in self + .in_magnitude + .iter_mut() + .zip(self.in_phase.iter_mut()) + .zip(self.spectrum) + { + *magnitude = complex.norm(); + *phase = complex.arg(); + } + + const SPECTRUM_INPUT: &str = "input_2"; + const MEMORY_INPUT: &str = "input_3"; + let memory_input = + Tensor::from_slice::<_, f32>(&self.in_magnitude, (1, 1, FFT_OUT_SIZE), &Device::Cpu) + .expect("the in magnitude has enough elements to fill the Tensor"); + + let inputs = HashMap::from([ + (MEMORY_INPUT.to_string(), memory_input), + (SPECTRUM_INPUT.to_string(), self.spectral_memory.clone()), + ]); + inputs + } + + fn signal_inputs(&mut self, outputs: HashMap) -> HashMap { + let magnitude_weight = model_outputs(outputs); + + // Apply mask and reconstruct complex spectrum + let mut spectrum = [Complex::I; FFT_OUT_SIZE]; + for i in 0..FFT_OUT_SIZE { + let magnitude = self.in_magnitude[i] * magnitude_weight[i]; + let phase = self.in_phase[i]; + let real = magnitude * phase.cos(); + let imag = magnitude * phase.sin(); + spectrum[i] = Complex::new(real, imag); + } + + // Handle DC component (i = 0) + let magnitude = self.in_magnitude[0] * magnitude_weight[0]; + spectrum[0] = Complex::new(magnitude, 0.0); + + // Handle Nyquist component (i = N/2) + let magnitude = self.in_magnitude[FFT_OUT_SIZE - 1] * magnitude_weight[FFT_OUT_SIZE - 1]; + spectrum[FFT_OUT_SIZE - 1] = Complex::new(magnitude, 0.0); + + // Perform complex-to-real IFFT + let ifft = self.fft_planner.plan_fft_inverse(BLOCK_LEN); + ifft.process_with_scratch(&mut spectrum, &mut self.signal, &mut self.fft_scratch) + .expect("The fft should run, there is enough scratch space"); + + // Normalize the IFFT output + for real in &mut self.signal { + *real /= BLOCK_LEN as f32; + } + + const SIGNAL_INPUT: &str = "input_4"; + const SIGNAL_MEMORY: &str = "input_5"; + let signal_input = + Tensor::from_slice::<_, f32>(&self.signal, (1, 1, BLOCK_LEN), &Device::Cpu).unwrap(); + + HashMap::from([ + (SIGNAL_INPUT.to_string(), signal_input), + (SIGNAL_MEMORY.to_string(), self.signal_memory.clone()), + ]) + } +} + +// Both models put their outputs in the same location +fn model_outputs(mut outputs: HashMap) -> Vec { + const NON_MEMORY_OUTPUT: &str = "Identity"; + outputs + .remove(NON_MEMORY_OUTPUT) + .expect("The model has this output") + .i((0, 0)) + .and_then(|tensor| tensor.to_vec1()) + .expect("The tensor has the correct dimensions") +} diff --git a/crates/denoise/src/lib.rs b/crates/denoise/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..1422c81a4b915d571d35585447165c04d3695b73 --- /dev/null +++ b/crates/denoise/src/lib.rs @@ -0,0 +1,273 @@ +mod engine; + +use core::fmt; +use std::{collections::VecDeque, sync::mpsc, thread}; + +pub use engine::Engine; +use rodio::{ChannelCount, Sample, SampleRate, Source, nz}; + +use crate::engine::BLOCK_SHIFT; + +const SUPPORTED_SAMPLE_RATE: SampleRate = nz!(16_000); +const SUPPORTED_CHANNEL_COUNT: ChannelCount = nz!(1); + +pub struct Denoiser { + inner: S, + input_tx: mpsc::Sender<[Sample; BLOCK_SHIFT]>, + denoised_rx: mpsc::Receiver<[Sample; BLOCK_SHIFT]>, + ready: [Sample; BLOCK_SHIFT], + next: usize, + state: IterState, + // When disabled instead of reading denoised sub-blocks from the engine through + // `denoised_rx` we read unprocessed from this queue. This maintains the same + // latency so we can 'trivially' re-enable + queued: Queue, +} + +impl fmt::Debug for Denoiser { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Denoiser") + .field("state", &self.state) + .finish_non_exhaustive() + } +} + +struct Queue(VecDeque<[Sample; BLOCK_SHIFT]>); + +impl Queue { + fn new() -> Self { + Self(VecDeque::new()) + } + fn push(&mut self, block: [Sample; BLOCK_SHIFT]) { + self.0.push_back(block); + self.0.resize(4, [0f32; BLOCK_SHIFT]); + } + fn pop(&mut self) -> [Sample; BLOCK_SHIFT] { + debug_assert!(self.0.len() == 4); + self.0.pop_front().expect( + "There is no State where the queue is popped while there are less then 4 entries", + ) + } +} + +#[derive(Debug, Clone, Copy)] +pub enum IterState { + Enabled, + StartingMidAudio { fed_to_denoiser: usize }, + Disabled, + Startup { enabled: bool }, +} + +#[derive(Debug, thiserror::Error)] +pub enum DenoiserError { + #[error("This denoiser only works on sources with samplerate 16000")] + UnsupportedSampleRate, + #[error("This denoiser only works on mono sources (1 channel)")] + UnsupportedChannelCount, +} + +// todo dvdsk needs constant source upstream in rodio +impl Denoiser { + pub fn try_new(source: S) -> Result { + if source.sample_rate() != SUPPORTED_SAMPLE_RATE { + return Err(DenoiserError::UnsupportedSampleRate); + } + if source.channels() != SUPPORTED_CHANNEL_COUNT { + return Err(DenoiserError::UnsupportedChannelCount); + } + + let (input_tx, input_rx) = mpsc::channel(); + let (denoised_tx, denoised_rx) = mpsc::channel(); + + thread::Builder::new() + .name("NeuralDenoiser".to_owned()) + .spawn(move || { + run_neural_denoiser(denoised_tx, input_rx); + }) + .unwrap(); + + Ok(Self { + inner: source, + input_tx, + denoised_rx, + ready: [0.0; BLOCK_SHIFT], + state: IterState::Startup { enabled: true }, + next: BLOCK_SHIFT, + queued: Queue::new(), + }) + } + + pub fn set_enabled(&mut self, enabled: bool) { + self.state = match (enabled, self.state) { + (false, IterState::StartingMidAudio { .. }) | (false, IterState::Enabled) => { + IterState::Disabled + } + (false, IterState::Startup { enabled: true }) => IterState::Startup { enabled: false }, + (true, IterState::Disabled) => IterState::StartingMidAudio { fed_to_denoiser: 0 }, + (_, state) => state, + }; + } + + fn feed(&self, sub_block: [f32; BLOCK_SHIFT]) { + self.input_tx.send(sub_block).unwrap(); + } +} + +fn run_neural_denoiser( + denoised_tx: mpsc::Sender<[f32; BLOCK_SHIFT]>, + input_rx: mpsc::Receiver<[f32; BLOCK_SHIFT]>, +) { + let mut engine = Engine::new(); + loop { + let Ok(sub_block) = input_rx.recv() else { + // tx must have dropped, stop thread + break; + }; + + let denoised_sub_block = engine.feed(&sub_block); + if denoised_tx.send(denoised_sub_block).is_err() { + break; + } + } +} + +impl Source for Denoiser { + fn current_span_len(&self) -> Option { + self.inner.current_span_len() + } + + fn channels(&self) -> rodio::ChannelCount { + self.inner.channels() + } + + fn sample_rate(&self) -> rodio::SampleRate { + self.inner.sample_rate() + } + + fn total_duration(&self) -> Option { + self.inner.total_duration() + } +} + +impl Iterator for Denoiser { + type Item = Sample; + + #[inline] + fn next(&mut self) -> Option { + self.next += 1; + if self.next < self.ready.len() { + let sample = self.ready[self.next]; + return Some(sample); + } + + // This is a separate function to prevent it from being inlined + // as this code only runs once every 128 samples + self.prepare_next_ready() + .inspect_err(|_| { + log::error!("Denoise engine crashed"); + }) + .ok() + .flatten() + } +} + +#[derive(Debug, thiserror::Error)] +#[error("Could not send or receive from denoise thread. It must have crashed")] +struct DenoiseEngineCrashed; + +impl Denoiser { + #[cold] + fn prepare_next_ready(&mut self) -> Result, DenoiseEngineCrashed> { + self.state = match self.state { + IterState::Startup { enabled } => { + // guaranteed to be coming from silence + for _ in 0..3 { + let Some(sub_block) = read_sub_block(&mut self.inner) else { + return Ok(None); + }; + self.queued.push(sub_block); + self.input_tx + .send(sub_block) + .map_err(|_| DenoiseEngineCrashed)?; + } + let Some(sub_block) = read_sub_block(&mut self.inner) else { + return Ok(None); + }; + self.queued.push(sub_block); + self.input_tx + .send(sub_block) + .map_err(|_| DenoiseEngineCrashed)?; + // throw out old blocks that are denoised silence + let _ = self.denoised_rx.iter().take(3).count(); + self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?; + + let Some(sub_block) = read_sub_block(&mut self.inner) else { + return Ok(None); + }; + self.queued.push(sub_block); + self.feed(sub_block); + + if enabled { + IterState::Enabled + } else { + IterState::Disabled + } + } + IterState::Enabled => { + self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?; + let Some(sub_block) = read_sub_block(&mut self.inner) else { + return Ok(None); + }; + self.queued.push(sub_block); + self.input_tx + .send(sub_block) + .map_err(|_| DenoiseEngineCrashed)?; + IterState::Enabled + } + IterState::Disabled => { + // Need to maintain the same 512 samples delay such that + // we can re-enable at any point. + self.ready = self.queued.pop(); + let Some(sub_block) = read_sub_block(&mut self.inner) else { + return Ok(None); + }; + self.queued.push(sub_block); + IterState::Disabled + } + IterState::StartingMidAudio { + fed_to_denoiser: mut sub_blocks_fed, + } => { + self.ready = self.queued.pop(); + let Some(sub_block) = read_sub_block(&mut self.inner) else { + return Ok(None); + }; + self.queued.push(sub_block); + self.input_tx + .send(sub_block) + .map_err(|_| DenoiseEngineCrashed)?; + sub_blocks_fed += 1; + if sub_blocks_fed > 4 { + // throw out partially denoised blocks, + // next will be correctly denoised + let _ = self.denoised_rx.iter().take(3).count(); + IterState::Enabled + } else { + IterState::StartingMidAudio { + fed_to_denoiser: sub_blocks_fed, + } + } + } + }; + + self.next = 0; + Ok(Some(self.ready[0])) + } +} + +fn read_sub_block(s: &mut impl Source) -> Option<[f32; BLOCK_SHIFT]> { + let mut res = [0f32; BLOCK_SHIFT]; + for sample in &mut res { + *sample = s.next()?; + } + Some(res) +} diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..ad455b0a4ecb1746debafd23f0503b4365f9a0cf --- /dev/null +++ b/crates/edit_prediction_context/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "edit_prediction_context" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/edit_prediction_context.rs" + +[dependencies] +anyhow.workspace = true +arrayvec.workspace = true +collections.workspace = true +gpui.workspace = true +language.workspace = true +log.workspace = true +project.workspace = true +slotmap.workspace = true +text.workspace = true +tree-sitter.workspace = true +util.workspace = true +workspace-hack.workspace = true + +[dev-dependencies] +futures.workspace = true +gpui = { workspace = true, features = ["test-support"] } +indoc.workspace = true +language = { workspace = true, features = ["test-support"] } +pretty_assertions.workspace = true +project = {workspace= true, features = ["test-support"]} +serde_json.workspace = true +settings = {workspace= true, features = ["test-support"]} +text = { workspace = true, features = ["test-support"] } +util = { workspace = true, features = ["test-support"] } +zlog.workspace = true diff --git a/crates/edit_prediction_context/LICENSE-GPL b/crates/edit_prediction_context/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/edit_prediction_context/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..acfb89880c3ed9e7b1ebcacd4b5fa313830165ba --- /dev/null +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -0,0 +1,8 @@ +mod excerpt; +mod outline; +mod reference; +mod tree_sitter_index; + +pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; +pub use reference::references_in_excerpt; +pub use tree_sitter_index::{BufferDeclaration, Declaration, FileDeclaration, TreeSitterIndex}; diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs new file mode 100644 index 0000000000000000000000000000000000000000..c6caa6a1b7b4076cf739c1ac198656b9fba431a6 --- /dev/null +++ b/crates/edit_prediction_context/src/excerpt.rs @@ -0,0 +1,616 @@ +use language::BufferSnapshot; +use std::ops::Range; +use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _}; +use tree_sitter::{Node, TreeCursor}; +use util::RangeExt; + +// TODO: +// +// - Test parent signatures +// +// - Decide whether to count signatures against the excerpt size. Could instead defer this to prompt +// planning. +// +// - Still return an excerpt even if the line around the cursor doesn't fit (e.g. for a markdown +// paragraph). +// +// - Truncation of long lines. +// +// - Filter outer syntax layers that don't support edit prediction. + +#[derive(Debug, Clone)] +pub struct EditPredictionExcerptOptions { + /// Limit for the number of bytes in the window around the cursor. + pub max_bytes: usize, + /// Minimum number of bytes in the window around the cursor. When syntax tree selection results + /// in an excerpt smaller than this, it will fall back on line-based selection. + pub min_bytes: usize, + /// Target ratio of bytes before the cursor divided by total bytes in the window. + pub target_before_cursor_over_total_bytes: f32, + /// Whether to include parent signatures + pub include_parent_signatures: bool, +} + +#[derive(Clone)] +pub struct EditPredictionExcerpt { + pub range: Range, + pub parent_signature_ranges: Vec>, + pub size: usize, +} + +#[derive(Clone)] +pub struct EditPredictionExcerptText { + pub body: String, + pub parent_signatures: Vec, +} + +impl EditPredictionExcerpt { + pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText { + let body = buffer + .text_for_range(self.range.clone()) + .collect::(); + let parent_signatures = self + .parent_signature_ranges + .iter() + .map(|range| buffer.text_for_range(range.clone()).collect::()) + .collect(); + EditPredictionExcerptText { + body, + parent_signatures, + } + } + + /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based + /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the + /// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures + /// of parent outline items. + /// + /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based + /// expansion. + /// + /// Returns `None` if the line around the cursor doesn't fit. + pub fn select_from_buffer( + query_point: Point, + buffer: &BufferSnapshot, + options: &EditPredictionExcerptOptions, + ) -> Option { + if buffer.len() <= options.max_bytes { + log::debug!( + "using entire file for excerpt since source length ({}) <= window max bytes ({})", + buffer.len(), + options.max_bytes + ); + return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new())); + } + + let query_offset = query_point.to_offset(buffer); + let query_range = Point::new(query_point.row, 0).to_offset(buffer) + ..Point::new(query_point.row + 1, 0).to_offset(buffer); + if query_range.len() >= options.max_bytes { + return None; + } + + // TODO: Don't compute text / annotation_range / skip converting to and from anchors. + let outline_items = if options.include_parent_signatures { + buffer + .outline_items_containing(query_range.clone(), false, None) + .into_iter() + .flat_map(|item| { + Some(ExcerptOutlineItem { + item_range: item.range.to_offset(&buffer), + signature_range: item.signature_range?.to_offset(&buffer), + }) + }) + .collect() + } else { + Vec::new() + }; + + let excerpt_selector = ExcerptSelector { + query_offset, + query_range, + outline_items: &outline_items, + buffer, + options, + }; + + if let Some(excerpt_ranges) = excerpt_selector.select_tree_sitter_nodes() { + if excerpt_ranges.size >= options.min_bytes { + return Some(excerpt_ranges); + } + log::debug!( + "tree-sitter excerpt was {} bytes, smaller than min of {}, falling back on line-based selection", + excerpt_ranges.size, + options.min_bytes + ); + } else { + log::debug!( + "couldn't find excerpt via tree-sitter, falling back on line-based selection" + ); + } + + excerpt_selector.select_lines() + } + + fn new(range: Range, parent_signature_ranges: Vec>) -> Self { + let size = range.len() + + parent_signature_ranges + .iter() + .map(|r| r.len()) + .sum::(); + Self { + range, + parent_signature_ranges, + size, + } + } + + fn with_expanded_range(&self, new_range: Range) -> Self { + if !new_range.contains_inclusive(&self.range) { + // this is an issue because parent_signature_ranges may be incorrect + log::error!("bug: with_expanded_range called with disjoint range"); + } + let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len()); + let mut size = new_range.len(); + for range in &self.parent_signature_ranges { + if range.contains_inclusive(&new_range) { + break; + } + parent_signature_ranges.push(range.clone()); + size += range.len(); + } + Self { + range: new_range, + parent_signature_ranges, + size, + } + } + + fn parent_signatures_size(&self) -> usize { + self.size - self.range.len() + } +} + +struct ExcerptSelector<'a> { + query_offset: usize, + query_range: Range, + outline_items: &'a [ExcerptOutlineItem], + buffer: &'a BufferSnapshot, + options: &'a EditPredictionExcerptOptions, +} + +struct ExcerptOutlineItem { + item_range: Range, + signature_range: Range, +} + +impl<'a> ExcerptSelector<'a> { + /// Finds the largest node that is smaller than the window size and contains `query_range`. + fn select_tree_sitter_nodes(&self) -> Option { + let selected_layer_root = self.select_syntax_layer()?; + let mut cursor = selected_layer_root.walk(); + + loop { + let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer) + ..node_line_end(cursor.node()).to_offset(&self.buffer); + if excerpt_range.contains_inclusive(&self.query_range) { + let excerpt = self.make_excerpt(excerpt_range); + if excerpt.size <= self.options.max_bytes { + return Some(self.expand_to_siblings(&mut cursor, excerpt)); + } + } else { + // TODO: Should still be able to handle this case via AST nodes. For example, this + // can happen if the cursor is between two methods in a large class file. + return None; + } + + if cursor + .goto_first_child_for_byte(self.query_range.start) + .is_none() + { + return None; + } + } + } + + /// Select the smallest syntax layer that exceeds max_len, or the largest if none exceed max_len. + fn select_syntax_layer(&self) -> Option> { + let mut smallest_exceeding_max_len: Option> = None; + let mut largest: Option> = None; + for layer in self + .buffer + .syntax_layers_for_range(self.query_range.start..self.query_range.start, true) + { + let layer_range = layer.node().byte_range(); + if !layer_range.contains_inclusive(&self.query_range) { + continue; + } + + if layer_range.len() > self.options.max_bytes { + match &smallest_exceeding_max_len { + None => smallest_exceeding_max_len = Some(layer.node()), + Some(existing) => { + if layer_range.len() < existing.byte_range().len() { + smallest_exceeding_max_len = Some(layer.node()); + } + } + } + } else { + match &largest { + None => largest = Some(layer.node()), + Some(existing) if layer_range.len() > existing.byte_range().len() => { + largest = Some(layer.node()) + } + _ => {} + } + } + } + + smallest_exceeding_max_len.or(largest) + } + + // motivation for this and `goto_previous_named_sibling` is to avoid including things like + // trailing unnamed "}" in body nodes + fn goto_next_named_sibling(cursor: &mut TreeCursor) -> bool { + while cursor.goto_next_sibling() { + if cursor.node().is_named() { + return true; + } + } + false + } + + fn goto_previous_named_sibling(cursor: &mut TreeCursor) -> bool { + while cursor.goto_previous_sibling() { + if cursor.node().is_named() { + return true; + } + } + false + } + + fn expand_to_siblings( + &self, + cursor: &mut TreeCursor, + mut excerpt: EditPredictionExcerpt, + ) -> EditPredictionExcerpt { + let mut forward_cursor = cursor.clone(); + let backward_cursor = cursor; + let mut forward_done = !Self::goto_next_named_sibling(&mut forward_cursor); + let mut backward_done = !Self::goto_previous_named_sibling(backward_cursor); + loop { + if backward_done && forward_done { + break; + } + + let mut forward = None; + while !forward_done { + let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer); + if new_end > excerpt.range.end { + let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end); + if new_excerpt.size <= self.options.max_bytes { + forward = Some(new_excerpt); + break; + } else { + log::debug!("halting forward expansion, as it doesn't fit"); + forward_done = true; + break; + } + } + forward_done = !Self::goto_next_named_sibling(&mut forward_cursor); + } + + let mut backward = None; + while !backward_done { + let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer); + if new_start < excerpt.range.start { + let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end); + if new_excerpt.size <= self.options.max_bytes { + backward = Some(new_excerpt); + break; + } else { + log::debug!("halting backward expansion, as it doesn't fit"); + backward_done = true; + break; + } + } + backward_done = !Self::goto_previous_named_sibling(backward_cursor); + } + + let go_forward = match (forward, backward) { + (Some(forward), Some(backward)) => { + let go_forward = self.is_better_excerpt(&forward, &backward); + if go_forward { + excerpt = forward; + } else { + excerpt = backward; + } + go_forward + } + (Some(forward), None) => { + log::debug!("expanding forward, since backward expansion has halted"); + excerpt = forward; + true + } + (None, Some(backward)) => { + log::debug!("expanding backward, since forward expansion has halted"); + excerpt = backward; + false + } + (None, None) => break, + }; + + if go_forward { + forward_done = !Self::goto_next_named_sibling(&mut forward_cursor); + } else { + backward_done = !Self::goto_previous_named_sibling(backward_cursor); + } + } + + excerpt + } + + fn select_lines(&self) -> Option { + // early return if line containing query_offset is already too large + let excerpt = self.make_excerpt(self.query_range.clone()); + if excerpt.size > self.options.max_bytes { + log::debug!( + "excerpt for cursor line is {} bytes, which exceeds the window", + excerpt.size + ); + return None; + } + let signatures_size = excerpt.parent_signatures_size(); + let bytes_remaining = self.options.max_bytes.saturating_sub(signatures_size); + + let before_bytes = + (self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize; + + let start_point = { + let offset = self.query_offset.saturating_sub(before_bytes); + let point = offset.to_point(self.buffer); + Point::new(point.row + 1, 0) + }; + let start_offset = start_point.to_offset(&self.buffer); + let end_point = { + let offset = start_offset + bytes_remaining; + let point = offset.to_point(self.buffer); + Point::new(point.row, 0) + }; + let end_offset = end_point.to_offset(&self.buffer); + + // this could be expanded further since recalculated `signature_size` may be smaller, but + // skipping that for now for simplicity + // + // TODO: could also consider checking if lines immediately before / after fit. + let excerpt = self.make_excerpt(start_offset..end_offset); + if excerpt.size > self.options.max_bytes { + log::error!( + "bug: line-based excerpt selection has size {}, \ + which is {} bytes larger than the max size", + excerpt.size, + excerpt.size - self.options.max_bytes + ); + } + return Some(excerpt); + } + + fn make_excerpt(&self, range: Range) -> EditPredictionExcerpt { + let parent_signature_ranges = self + .outline_items + .iter() + .filter(|item| item.item_range.contains_inclusive(&range)) + .map(|item| item.signature_range.clone()) + .collect(); + EditPredictionExcerpt::new(range, parent_signature_ranges) + } + + /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt. + fn is_better_excerpt( + &self, + forward: &EditPredictionExcerpt, + backward: &EditPredictionExcerpt, + ) -> bool { + let forward_ratio = self.excerpt_range_ratio(forward); + let backward_ratio = self.excerpt_range_ratio(backward); + let forward_delta = + (forward_ratio - self.options.target_before_cursor_over_total_bytes).abs(); + let backward_delta = + (backward_ratio - self.options.target_before_cursor_over_total_bytes).abs(); + let forward_is_better = forward_delta <= backward_delta; + if forward_is_better { + log::debug!( + "expanding forward since {} is closer than {} to {}", + forward_ratio, + backward_ratio, + self.options.target_before_cursor_over_total_bytes + ); + } else { + log::debug!( + "expanding backward since {} is closer than {} to {}", + backward_ratio, + forward_ratio, + self.options.target_before_cursor_over_total_bytes + ); + } + forward_is_better + } + + /// Returns the ratio of bytes before the cursor over bytes within the range. + fn excerpt_range_ratio(&self, excerpt: &EditPredictionExcerpt) -> f32 { + let Some(bytes_before_cursor) = self.query_offset.checked_sub(excerpt.range.start) else { + log::error!("bug: edit prediction cursor offset is not outside the excerpt"); + return 0.0; + }; + bytes_before_cursor as f32 / excerpt.range.len() as f32 + } +} + +fn node_line_start(node: Node) -> Point { + Point::new(node.start_position().row as u32, 0) +} + +fn node_line_end(node: Node) -> Point { + Point::new(node.end_position().row as u32 + 1, 0) +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{AppContext, TestAppContext}; + use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; + use util::test::{generate_marked_text, marked_text_offsets_by}; + + fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot { + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx)); + buffer.read_with(cx, |buffer, _| buffer.snapshot()) + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } + + fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range) { + let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']); + (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0]) + } + + fn check_example(options: EditPredictionExcerptOptions, text: &str, cx: &mut TestAppContext) { + let (text, cursor, expected_excerpt) = cursor_and_excerpt_range(text); + + let buffer = create_buffer(&text, cx); + let cursor_point = cursor.to_point(&buffer); + + let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options) + .expect("Should select an excerpt"); + pretty_assertions::assert_eq!( + generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false), + generate_marked_text(&text, &[expected_excerpt], false) + ); + assert!(excerpt.size <= options.max_bytes); + assert!(excerpt.range.contains(&cursor)); + } + + #[gpui::test] + fn test_ast_based_selection_current_node(cx: &mut TestAppContext) { + zlog::init_test(); + let text = r#" +fn main() { + let x = 1; +« let ˇy = 2; +» let z = 3; +}"#; + + let options = EditPredictionExcerptOptions { + max_bytes: 20, + min_bytes: 10, + target_before_cursor_over_total_bytes: 0.5, + include_parent_signatures: false, + }; + + check_example(options, text, cx); + } + + #[gpui::test] + fn test_ast_based_selection_parent_node(cx: &mut TestAppContext) { + zlog::init_test(); + let text = r#" +fn foo() {} + +«fn main() { + let x = 1; + let ˇy = 2; + let z = 3; +} +» +fn bar() {}"#; + + let options = EditPredictionExcerptOptions { + max_bytes: 65, + min_bytes: 10, + target_before_cursor_over_total_bytes: 0.5, + include_parent_signatures: false, + }; + + check_example(options, text, cx); + } + + #[gpui::test] + fn test_ast_based_selection_expands_to_siblings(cx: &mut TestAppContext) { + zlog::init_test(); + let text = r#" +fn main() { +« let x = 1; + let ˇy = 2; + let z = 3; +»}"#; + + let options = EditPredictionExcerptOptions { + max_bytes: 50, + min_bytes: 10, + target_before_cursor_over_total_bytes: 0.5, + include_parent_signatures: false, + }; + + check_example(options, text, cx); + } + + #[gpui::test] + fn test_line_based_selection(cx: &mut TestAppContext) { + zlog::init_test(); + let text = r#" +fn main() { + let x = 1; +« if true { + let ˇy = 2; + } + let z = 3; +»}"#; + + let options = EditPredictionExcerptOptions { + max_bytes: 60, + min_bytes: 45, + target_before_cursor_over_total_bytes: 0.5, + include_parent_signatures: false, + }; + + check_example(options, text, cx); + } + + #[gpui::test] + fn test_line_based_selection_with_before_cursor_ratio(cx: &mut TestAppContext) { + zlog::init_test(); + let text = r#" + fn main() { +« let a = 1; + let b = 2; + let c = 3; + let ˇd = 4; + let e = 5; + let f = 6; +» + let g = 7; + }"#; + + let options = EditPredictionExcerptOptions { + max_bytes: 120, + min_bytes: 10, + target_before_cursor_over_total_bytes: 0.6, + include_parent_signatures: false, + }; + + check_example(options, text, cx); + } +} diff --git a/crates/edit_prediction_context/src/outline.rs b/crates/edit_prediction_context/src/outline.rs new file mode 100644 index 0000000000000000000000000000000000000000..492352add1fd4c666eab3b12989f9b801d03570f --- /dev/null +++ b/crates/edit_prediction_context/src/outline.rs @@ -0,0 +1,130 @@ +use language::{BufferSnapshot, LanguageId, SyntaxMapMatches}; +use std::{cmp::Reverse, ops::Range, sync::Arc}; + +// TODO: +// +// * how to handle multiple name captures? for now last one wins +// +// * annotation ranges +// +// * new "signature" capture for outline queries +// +// * Check parent behavior of "int x, y = 0" declarations in a test + +pub struct OutlineDeclaration { + pub parent_index: Option, + pub identifier: Identifier, + pub item_range: Range, + pub signature_range: Range, +} + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct Identifier { + pub name: Arc, + pub language_id: LanguageId, +} + +pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec { + declarations_overlapping_range(0..buffer.len(), buffer) +} + +pub fn declarations_overlapping_range( + range: Range, + buffer: &BufferSnapshot, +) -> Vec { + let mut declarations = OutlineIterator::new(range, buffer).collect::>(); + declarations.sort_unstable_by_key(|item| (item.item_range.start, Reverse(item.item_range.end))); + + let mut parent_stack: Vec<(usize, Range)> = Vec::new(); + for (index, declaration) in declarations.iter_mut().enumerate() { + while let Some((top_parent_index, top_parent_range)) = parent_stack.last() { + if declaration.item_range.start >= top_parent_range.end { + parent_stack.pop(); + } else { + declaration.parent_index = Some(*top_parent_index); + break; + } + } + parent_stack.push((index, declaration.item_range.clone())); + } + declarations +} + +/// Iterates outline items without being ordered w.r.t. nested items and without populating +/// `parent`. +pub struct OutlineIterator<'a> { + buffer: &'a BufferSnapshot, + matches: SyntaxMapMatches<'a>, +} + +impl<'a> OutlineIterator<'a> { + pub fn new(range: Range, buffer: &'a BufferSnapshot) -> Self { + let matches = buffer.syntax.matches(range, &buffer.text, |grammar| { + grammar.outline_config.as_ref().map(|c| &c.query) + }); + + Self { buffer, matches } + } +} + +impl<'a> Iterator for OutlineIterator<'a> { + type Item = OutlineDeclaration; + + fn next(&mut self) -> Option { + while let Some(mat) = self.matches.peek() { + let config = self.matches.grammars()[mat.grammar_index] + .outline_config + .as_ref() + .unwrap(); + + let mut name_range = None; + let mut item_range = None; + let mut signature_start = None; + let mut signature_end = None; + + let mut add_to_signature = |range: Range| { + if signature_start.is_none() { + signature_start = Some(range.start); + } + signature_end = Some(range.end); + }; + + for capture in mat.captures { + let range = capture.node.byte_range(); + if capture.index == config.name_capture_ix { + name_range = Some(range.clone()); + add_to_signature(range); + } else if Some(capture.index) == config.context_capture_ix + || Some(capture.index) == config.extra_context_capture_ix + { + add_to_signature(range); + } else if capture.index == config.item_capture_ix { + item_range = Some(range.clone()); + } + } + + let language_id = mat.language.id(); + self.matches.advance(); + + if let Some(name_range) = name_range + && let Some(item_range) = item_range + && let Some(signature_start) = signature_start + && let Some(signature_end) = signature_end + { + let name = self + .buffer + .text_for_range(name_range) + .collect::() + .into(); + + return Some(OutlineDeclaration { + identifier: Identifier { name, language_id }, + item_range: item_range, + signature_range: signature_start..signature_end, + parent_index: None, + }); + } + } + None + } +} diff --git a/crates/edit_prediction_context/src/reference.rs b/crates/edit_prediction_context/src/reference.rs new file mode 100644 index 0000000000000000000000000000000000000000..65d34e73bf20f62b24ac2a654af43fc3b83041a9 --- /dev/null +++ b/crates/edit_prediction_context/src/reference.rs @@ -0,0 +1,109 @@ +use language::BufferSnapshot; +use std::collections::HashMap; +use std::ops::Range; + +use crate::{ + excerpt::{EditPredictionExcerpt, EditPredictionExcerptText}, + outline::Identifier, +}; + +#[derive(Debug)] +pub struct Reference { + pub identifier: Identifier, + pub range: Range, + pub region: ReferenceRegion, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum ReferenceRegion { + Breadcrumb, + Nearby, +} + +pub fn references_in_excerpt( + excerpt: &EditPredictionExcerpt, + excerpt_text: &EditPredictionExcerptText, + snapshot: &BufferSnapshot, +) -> HashMap> { + let mut references = identifiers_in_range( + excerpt.range.clone(), + excerpt_text.body.as_str(), + ReferenceRegion::Nearby, + snapshot, + ); + + for (range, text) in excerpt + .parent_signature_ranges + .iter() + .zip(excerpt_text.parent_signatures.iter()) + { + references.extend(identifiers_in_range( + range.clone(), + text.as_str(), + ReferenceRegion::Breadcrumb, + snapshot, + )); + } + + let mut identifier_to_references: HashMap> = HashMap::new(); + for reference in references { + identifier_to_references + .entry(reference.identifier.clone()) + .or_insert_with(Vec::new) + .push(reference); + } + identifier_to_references +} + +/// Finds all nodes which have a "variable" match from the highlights query within the offset range. +pub fn identifiers_in_range( + range: Range, + range_text: &str, + reference_region: ReferenceRegion, + buffer: &BufferSnapshot, +) -> Vec { + let mut matches = buffer + .syntax + .matches(range.clone(), &buffer.text, |grammar| { + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) + }); + + let mut references = Vec::new(); + let mut last_added_range = None; + while let Some(mat) = matches.peek() { + let config = matches.grammars()[mat.grammar_index] + .highlights_config + .as_ref(); + + for capture in mat.captures { + if let Some(config) = config { + if config.identifier_capture_indices.contains(&capture.index) { + let node_range = capture.node.byte_range(); + + // sometimes multiple highlight queries match - this deduplicates them + if Some(node_range.clone()) == last_added_range { + continue; + } + + let identifier_text = + &range_text[node_range.start - range.start..node_range.end - range.start]; + references.push(Reference { + identifier: Identifier { + name: identifier_text.into(), + language_id: mat.language.id(), + }, + range: node_range.clone(), + region: reference_region, + }); + last_added_range = Some(node_range); + } + } + } + + matches.advance(); + } + references +} diff --git a/crates/edit_prediction_context/src/tree_sitter_index.rs b/crates/edit_prediction_context/src/tree_sitter_index.rs new file mode 100644 index 0000000000000000000000000000000000000000..f905aa7a01f29d26083d219bc8d2bd600847036a --- /dev/null +++ b/crates/edit_prediction_context/src/tree_sitter_index.rs @@ -0,0 +1,825 @@ +use collections::{HashMap, HashSet}; +use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity}; +use language::{Buffer, BufferEvent, BufferSnapshot}; +use project::buffer_store::{BufferStore, BufferStoreEvent}; +use project::worktree_store::{WorktreeStore, WorktreeStoreEvent}; +use project::{PathChange, Project, ProjectEntryId, ProjectPath}; +use slotmap::SlotMap; +use std::ops::Range; +use std::sync::Arc; +use text::Anchor; +use util::{ResultExt as _, debug_panic, some_or_debug_panic}; + +use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer}; + +// TODO: +// +// * Skip for remote projects + +// Potential future improvements: +// +// * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which +// references are present and their scores. + +// Potential future optimizations: +// +// * Cache of buffers for files +// +// * Parse files directly instead of loading into a Rope. Make SyntaxMap generic to handle embedded +// languages? Will also need to find line boundaries, but that can be done by scanning characters in +// the flat representation. +// +// * Use something similar to slotmap without key versions. +// +// * Concurrent slotmap +// +// * Use queue for parsing + +slotmap::new_key_type! { + pub struct DeclarationId; +} + +pub struct TreeSitterIndex { + declarations: SlotMap, + identifiers: HashMap>, + files: HashMap, + buffers: HashMap, BufferState>, + project: WeakEntity, +} + +#[derive(Debug, Default)] +struct FileState { + declarations: Vec, + task: Option>, +} + +#[derive(Default)] +struct BufferState { + declarations: Vec, + task: Option>, +} + +#[derive(Debug, Clone)] +pub enum Declaration { + File { + project_entry_id: ProjectEntryId, + declaration: FileDeclaration, + }, + Buffer { + buffer: WeakEntity, + declaration: BufferDeclaration, + }, +} + +impl Declaration { + fn identifier(&self) -> &Identifier { + match self { + Declaration::File { declaration, .. } => &declaration.identifier, + Declaration::Buffer { declaration, .. } => &declaration.identifier, + } + } +} + +#[derive(Debug, Clone)] +pub struct FileDeclaration { + pub parent: Option, + pub identifier: Identifier, + pub item_range: Range, + pub signature_range: Range, + pub signature_text: Arc, +} + +#[derive(Debug, Clone)] +pub struct BufferDeclaration { + pub parent: Option, + pub identifier: Identifier, + pub item_range: Range, + pub signature_range: Range, +} + +impl TreeSitterIndex { + pub fn new(project: &Entity, cx: &mut Context) -> Self { + let mut this = Self { + declarations: SlotMap::with_key(), + identifiers: HashMap::default(), + project: project.downgrade(), + files: HashMap::default(), + buffers: HashMap::default(), + }; + + let worktree_store = project.read(cx).worktree_store(); + cx.subscribe(&worktree_store, Self::handle_worktree_store_event) + .detach(); + + for worktree in worktree_store + .read(cx) + .worktrees() + .map(|w| w.read(cx).snapshot()) + .collect::>() + { + for entry in worktree.files(false, 0) { + this.update_file( + entry.id, + ProjectPath { + worktree_id: worktree.id(), + path: entry.path.clone(), + }, + cx, + ); + } + } + + let buffer_store = project.read(cx).buffer_store().clone(); + for buffer in buffer_store.read(cx).buffers().collect::>() { + this.register_buffer(&buffer, cx); + } + cx.subscribe(&buffer_store, Self::handle_buffer_store_event) + .detach(); + + this + } + + pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> { + self.declarations.get(id) + } + + pub fn declarations_for_identifier( + &self, + identifier: Identifier, + cx: &App, + ) -> Vec { + // make sure to not have a large stack allocation + assert!(N < 32); + + let Some(declaration_ids) = self.identifiers.get(&identifier) else { + return vec![]; + }; + + let mut result = Vec::with_capacity(N); + let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new(); + let mut file_declarations = Vec::new(); + + for declaration_id in declaration_ids { + let declaration = self.declarations.get(*declaration_id); + let Some(declaration) = some_or_debug_panic(declaration) else { + continue; + }; + match declaration { + Declaration::Buffer { buffer, .. } => { + if let Ok(Some(entry_id)) = buffer.read_with(cx, |buffer, cx| { + project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx)) + }) { + included_buffer_entry_ids.push(entry_id); + result.push(declaration.clone()); + if result.len() == N { + return result; + } + } + } + Declaration::File { + project_entry_id, .. + } => { + if !included_buffer_entry_ids.contains(project_entry_id) { + file_declarations.push(declaration.clone()); + } + } + } + } + + for declaration in file_declarations { + match declaration { + Declaration::File { + project_entry_id, .. + } => { + if !included_buffer_entry_ids.contains(&project_entry_id) { + result.push(declaration); + + if result.len() == N { + return result; + } + } + } + Declaration::Buffer { .. } => {} + } + } + + result + } + + fn handle_worktree_store_event( + &mut self, + _worktree_store: Entity, + event: &WorktreeStoreEvent, + cx: &mut Context, + ) { + use WorktreeStoreEvent::*; + match event { + WorktreeUpdatedEntries(worktree_id, updated_entries_set) => { + for (path, entry_id, path_change) in updated_entries_set.iter() { + if let PathChange::Removed = path_change { + self.files.remove(entry_id); + } else { + let project_path = ProjectPath { + worktree_id: *worktree_id, + path: path.clone(), + }; + self.update_file(*entry_id, project_path, cx); + } + } + } + WorktreeDeletedEntry(_worktree_id, project_entry_id) => { + // TODO: Is this needed? + self.files.remove(project_entry_id); + } + _ => {} + } + } + + fn handle_buffer_store_event( + &mut self, + _buffer_store: Entity, + event: &BufferStoreEvent, + cx: &mut Context, + ) { + use BufferStoreEvent::*; + match event { + BufferAdded(buffer) => self.register_buffer(buffer, cx), + BufferOpened { .. } + | BufferChangedFilePath { .. } + | BufferDropped { .. } + | SharedBufferClosed { .. } => {} + } + } + + fn register_buffer(&mut self, buffer: &Entity, cx: &mut Context) { + self.buffers + .insert(buffer.downgrade(), BufferState::default()); + let weak_buf = buffer.downgrade(); + cx.observe_release(buffer, move |this, _buffer, _cx| { + this.buffers.remove(&weak_buf); + }) + .detach(); + cx.subscribe(buffer, Self::handle_buffer_event).detach(); + self.update_buffer(buffer.clone(), cx); + } + + fn handle_buffer_event( + &mut self, + buffer: Entity, + event: &BufferEvent, + cx: &mut Context, + ) { + match event { + BufferEvent::Edited => self.update_buffer(buffer, cx), + _ => {} + } + } + + fn update_buffer(&mut self, buffer: Entity, cx: &Context) { + let mut parse_status = buffer.read(cx).parse_status(); + let snapshot_task = cx.spawn({ + let weak_buffer = buffer.downgrade(); + async move |_, cx| { + while *parse_status.borrow() != language::ParseStatus::Idle { + parse_status.changed().await?; + } + weak_buffer.read_with(cx, |buffer, _cx| buffer.snapshot()) + } + }); + + let parse_task = cx.background_spawn(async move { + let snapshot = snapshot_task.await?; + + anyhow::Ok( + declarations_in_buffer(&snapshot) + .into_iter() + .map(|item| { + ( + item.parent_index, + BufferDeclaration::from_outline(item, &snapshot), + ) + }) + .collect::>(), + ) + }); + + let task = cx.spawn({ + let weak_buffer = buffer.downgrade(); + async move |this, cx| { + let Ok(declarations) = parse_task.await else { + return; + }; + + this.update(cx, |this, _cx| { + let buffer_state = this + .buffers + .entry(weak_buffer.clone()) + .or_insert_with(Default::default); + + for old_declaration_id in &buffer_state.declarations { + let Some(declaration) = this.declarations.remove(*old_declaration_id) + else { + debug_panic!("declaration not found"); + continue; + }; + if let Some(identifier_declarations) = + this.identifiers.get_mut(declaration.identifier()) + { + identifier_declarations.remove(old_declaration_id); + } + } + + let mut new_ids = Vec::with_capacity(declarations.len()); + this.declarations.reserve(declarations.len()); + for (parent_index, mut declaration) in declarations { + declaration.parent = parent_index + .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); + + let identifier = declaration.identifier.clone(); + let declaration_id = this.declarations.insert(Declaration::Buffer { + buffer: weak_buffer.clone(), + declaration, + }); + new_ids.push(declaration_id); + + this.identifiers + .entry(identifier) + .or_default() + .insert(declaration_id); + } + + buffer_state.declarations = new_ids; + }) + .ok(); + } + }); + + self.buffers + .entry(buffer.downgrade()) + .or_insert_with(Default::default) + .task = Some(task); + } + + fn update_file( + &mut self, + entry_id: ProjectEntryId, + project_path: ProjectPath, + cx: &mut Context, + ) { + let Some(project) = self.project.upgrade() else { + return; + }; + let project = project.read(cx); + let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) else { + return; + }; + let language_registry = project.languages().clone(); + + let snapshot_task = worktree.update(cx, |worktree, cx| { + let load_task = worktree.load_file(&project_path.path, cx); + cx.spawn(async move |_this, cx| { + let loaded_file = load_task.await?; + let language = language_registry + .language_for_file_path(&project_path.path) + .await + .log_err(); + + let buffer = cx.new(|cx| { + let mut buffer = Buffer::local(loaded_file.text, cx); + buffer.set_language(language, cx); + buffer + })?; + + let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?; + while *parse_status.borrow() != language::ParseStatus::Idle { + parse_status.changed().await?; + } + + buffer.read_with(cx, |buffer, _cx| buffer.snapshot()) + }) + }); + + let parse_task = cx.background_spawn(async move { + let snapshot = snapshot_task.await?; + let declarations = declarations_in_buffer(&snapshot) + .into_iter() + .map(|item| { + ( + item.parent_index, + FileDeclaration::from_outline(item, &snapshot), + ) + }) + .collect::>(); + anyhow::Ok(declarations) + }); + + let task = cx.spawn({ + async move |this, cx| { + // TODO: how to handle errors? + let Ok(declarations) = parse_task.await else { + return; + }; + this.update(cx, |this, _cx| { + let file_state = this.files.entry(entry_id).or_insert_with(Default::default); + + for old_declaration_id in &file_state.declarations { + let Some(declaration) = this.declarations.remove(*old_declaration_id) + else { + debug_panic!("declaration not found"); + continue; + }; + if let Some(identifier_declarations) = + this.identifiers.get_mut(declaration.identifier()) + { + identifier_declarations.remove(old_declaration_id); + } + } + + let mut new_ids = Vec::with_capacity(declarations.len()); + this.declarations.reserve(declarations.len()); + + for (parent_index, mut declaration) in declarations { + declaration.parent = parent_index + .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); + + let identifier = declaration.identifier.clone(); + let declaration_id = this.declarations.insert(Declaration::File { + project_entry_id: entry_id, + declaration, + }); + new_ids.push(declaration_id); + + this.identifiers + .entry(identifier) + .or_default() + .insert(declaration_id); + } + + file_state.declarations = new_ids; + }) + .ok(); + } + }); + + self.files + .entry(entry_id) + .or_insert_with(Default::default) + .task = Some(task); + } +} + +impl BufferDeclaration { + pub fn from_outline(declaration: OutlineDeclaration, snapshot: &BufferSnapshot) -> Self { + // use of anchor_before is a guess that the proper behavior is to expand to include + // insertions immediately before the declaration, but not for insertions immediately after + Self { + parent: None, + identifier: declaration.identifier, + item_range: snapshot.anchor_before(declaration.item_range.start) + ..snapshot.anchor_before(declaration.item_range.end), + signature_range: snapshot.anchor_before(declaration.signature_range.start) + ..snapshot.anchor_before(declaration.signature_range.end), + } + } +} + +impl FileDeclaration { + pub fn from_outline( + declaration: OutlineDeclaration, + snapshot: &BufferSnapshot, + ) -> FileDeclaration { + FileDeclaration { + parent: None, + identifier: declaration.identifier, + item_range: declaration.item_range, + signature_text: snapshot + .text_for_range(declaration.signature_range.clone()) + .collect::() + .into(), + signature_range: declaration.signature_range, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{path::Path, sync::Arc}; + + use gpui::TestAppContext; + use indoc::indoc; + use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; + use project::{FakeFs, Project, ProjectItem}; + use serde_json::json; + use settings::SettingsStore; + use text::OffsetRangeExt as _; + use util::path; + + use crate::tree_sitter_index::TreeSitterIndex; + + #[gpui::test] + async fn test_unopen_indexed_files(cx: &mut TestAppContext) { + let (project, index, rust_lang_id) = init_test(cx).await; + let main = Identifier { + name: "main".into(), + language_id: rust_lang_id, + }; + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<8>(main.clone(), cx); + assert_eq!(decls.len(), 2); + + let decl = expect_file_decl("c.rs", &decls[0], &project, cx); + assert_eq!(decl.identifier, main.clone()); + assert_eq!(decl.item_range, 32..279); + + let decl = expect_file_decl("a.rs", &decls[1], &project, cx); + assert_eq!(decl.identifier, main); + assert_eq!(decl.item_range, 0..97); + }); + } + + #[gpui::test] + async fn test_parents_in_file(cx: &mut TestAppContext) { + let (project, index, rust_lang_id) = init_test(cx).await; + let test_process_data = Identifier { + name: "test_process_data".into(), + language_id: rust_lang_id, + }; + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx); + assert_eq!(decls.len(), 1); + + let decl = expect_file_decl("c.rs", &decls[0], &project, cx); + assert_eq!(decl.identifier, test_process_data); + + let parent_id = decl.parent.unwrap(); + let parent = index.declaration(parent_id).unwrap(); + let parent_decl = expect_file_decl("c.rs", &parent, &project, cx); + assert_eq!( + parent_decl.identifier, + Identifier { + name: "tests".into(), + language_id: rust_lang_id + } + ); + assert_eq!(parent_decl.parent, None); + }); + } + + #[gpui::test] + async fn test_parents_in_buffer(cx: &mut TestAppContext) { + let (project, index, rust_lang_id) = init_test(cx).await; + let test_process_data = Identifier { + name: "test_process_data".into(), + language_id: rust_lang_id, + }; + + let buffer = project + .update(cx, |project, cx| { + let project_path = project.find_project_path("c.rs", cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + cx.run_until_parked(); + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx); + assert_eq!(decls.len(), 1); + + let decl = expect_buffer_decl("c.rs", &decls[0], cx); + assert_eq!(decl.identifier, test_process_data); + + let parent_id = decl.parent.unwrap(); + let parent = index.declaration(parent_id).unwrap(); + let parent_decl = expect_buffer_decl("c.rs", &parent, cx); + assert_eq!( + parent_decl.identifier, + Identifier { + name: "tests".into(), + language_id: rust_lang_id + } + ); + assert_eq!(parent_decl.parent, None); + }); + + drop(buffer); + } + + #[gpui::test] + async fn test_declarations_limt(cx: &mut TestAppContext) { + let (_, index, rust_lang_id) = init_test(cx).await; + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<1>( + Identifier { + name: "main".into(), + language_id: rust_lang_id, + }, + cx, + ); + assert_eq!(decls.len(), 1); + }); + } + + #[gpui::test] + async fn test_buffer_shadow(cx: &mut TestAppContext) { + let (project, index, rust_lang_id) = init_test(cx).await; + + let main = Identifier { + name: "main".into(), + language_id: rust_lang_id, + }; + + let buffer = project + .update(cx, |project, cx| { + let project_path = project.find_project_path("c.rs", cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + cx.run_until_parked(); + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<8>(main.clone(), cx); + assert_eq!(decls.len(), 2); + let decl = expect_buffer_decl("c.rs", &decls[0], cx); + assert_eq!(decl.identifier, main); + assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279); + + expect_file_decl("a.rs", &decls[1], &project, cx); + }); + + // Need to trigger flush_effects so that the observe_release handler will run. + cx.update(|_cx| { + drop(buffer); + }); + cx.run_until_parked(); + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<8>(main, cx); + assert_eq!(decls.len(), 2); + expect_file_decl("c.rs", &decls[0], &project, cx); + expect_file_decl("a.rs", &decls[1], &project, cx); + }); + } + + fn expect_buffer_decl<'a>( + path: &str, + declaration: &'a Declaration, + cx: &App, + ) -> &'a BufferDeclaration { + if let Declaration::Buffer { + declaration, + buffer, + } = declaration + { + assert_eq!( + buffer + .upgrade() + .unwrap() + .read(cx) + .project_path(cx) + .unwrap() + .path + .as_ref(), + Path::new(path), + ); + declaration + } else { + panic!("Expected a buffer declaration, found {:?}", declaration); + } + } + + fn expect_file_decl<'a>( + path: &str, + declaration: &'a Declaration, + project: &Entity, + cx: &App, + ) -> &'a FileDeclaration { + if let Declaration::File { + declaration, + project_entry_id: file, + } = declaration + { + assert_eq!( + project + .read(cx) + .path_for_entry(*file, cx) + .unwrap() + .path + .as_ref(), + Path::new(path), + ); + declaration + } else { + panic!("Expected a file declaration, found {:?}", declaration); + } + } + + async fn init_test( + cx: &mut TestAppContext, + ) -> (Entity, Entity, LanguageId) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "a.rs": indoc! {r#" + fn main() { + let x = 1; + let y = 2; + let z = add(x, y); + println!("Result: {}", z); + } + + fn add(a: i32, b: i32) -> i32 { + a + b + } + "#}, + "b.rs": indoc! {" + pub struct Config { + pub name: String, + pub value: i32, + } + + impl Config { + pub fn new(name: String, value: i32) -> Self { + Config { name, value } + } + } + "}, + "c.rs": indoc! {r#" + use std::collections::HashMap; + + fn main() { + let args: Vec = std::env::args().collect(); + let data: Vec = args[1..] + .iter() + .filter_map(|s| s.parse().ok()) + .collect(); + let result = process_data(data); + println!("{:?}", result); + } + + fn process_data(data: Vec) -> HashMap { + let mut counts = HashMap::new(); + for value in data { + *counts.entry(value).or_insert(0) += 1; + } + counts + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_process_data() { + let data = vec![1, 2, 2, 3]; + let result = process_data(data); + assert_eq!(result.get(&2), Some(&2)); + } + } + "#} + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + let lang = rust_lang(); + let lang_id = lang.id(); + language_registry.add(Arc::new(lang)); + + let index = cx.new(|cx| TreeSitterIndex::new(&project, cx)); + cx.run_until_parked(); + + (project, index, lang_id) + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } +} diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index 03d04e7010248293604d10c2f3e553430e74c9c6..2d16e6af8b469ff6e94b1b9fc7d11f7186e7b3c3 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -1264,36 +1264,30 @@ impl BlockMapWriter<'_> { range: Range, inclusive: bool, ) -> &[Arc] { + if range.is_empty() && !inclusive { + return &[]; + } let wrap_snapshot = self.0.wrap_snapshot.borrow(); let buffer = wrap_snapshot.buffer_snapshot(); let start_block_ix = match self.0.custom_blocks.binary_search_by(|block| { let block_end = block.end().to_offset(buffer); - block_end.cmp(&range.start).then_with(|| { - if inclusive || (range.is_empty() && block.start().to_offset(buffer) == block_end) { - Ordering::Greater - } else { - Ordering::Less - } - }) + block_end.cmp(&range.start).then(Ordering::Greater) }) { Ok(ix) | Err(ix) => ix, }; - let end_block_ix = match self.0.custom_blocks.binary_search_by(|block| { - block - .start() - .to_offset(buffer) - .cmp(&range.end) - .then(if inclusive { - Ordering::Less - } else { - Ordering::Greater - }) + let end_block_ix = match self.0.custom_blocks[start_block_ix..].binary_search_by(|block| { + let block_start = block.start().to_offset(buffer); + block_start.cmp(&range.end).then(if inclusive { + Ordering::Less + } else { + Ordering::Greater + }) }) { Ok(ix) | Err(ix) => ix, }; - &self.0.custom_blocks[start_block_ix..end_block_ix] + &self.0.custom_blocks[start_block_ix..][..end_block_ix] } } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 2bea1fc72b5586554219bb4967f44edc18424adf..38bbcecc16deadf8543f9aa280c607604882461b 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -2412,14 +2412,10 @@ impl Editor { pub fn is_range_selected(&mut self, range: &Range, cx: &mut Context) -> bool { if self .selections - .pending - .as_ref() + .pending_anchor() .is_some_and(|pending_selection| { let snapshot = self.buffer().read(cx).snapshot(cx); - pending_selection - .selection - .range() - .includes(range, &snapshot) + pending_selection.range().includes(range, &snapshot) }) { return true; @@ -3052,7 +3048,7 @@ impl Editor { } } - let selection_anchors = self.selections.disjoint_anchors(); + let selection_anchors = self.selections.disjoint_anchors_arc(); if self.focus_handle.is_focused(window) && self.leader_id.is_none() { self.buffer.update(cx, |buffer, cx| { @@ -3168,7 +3164,7 @@ impl Editor { self.blink_manager.update(cx, BlinkManager::pause_blinking); cx.emit(EditorEvent::SelectionsChanged { local }); - let selections = &self.selections.disjoint; + let selections = &self.selections.disjoint_anchors_arc(); if selections.len() == 1 { cx.emit(SearchEvent::ActiveMatchChanged) } @@ -3280,14 +3276,14 @@ impl Editor { other: Entity, cx: &mut Context, ) -> gpui::Subscription { - let other_selections = other.read(cx).selections.disjoint.to_vec(); + let other_selections = other.read(cx).selections.disjoint_anchors().to_vec(); self.selections.change_with(cx, |selections| { selections.select_anchors(other_selections); }); let other_subscription = cx.subscribe(&other, |this, other, other_evt, cx| { if let EditorEvent::SelectionsChanged { local: true } = other_evt { - let other_selections = other.read(cx).selections.disjoint.to_vec(); + let other_selections = other.read(cx).selections.disjoint_anchors().to_vec(); if other_selections.is_empty() { return; } @@ -3299,7 +3295,7 @@ impl Editor { let this_subscription = cx.subscribe_self::(move |this, this_evt, cx| { if let EditorEvent::SelectionsChanged { local: true } = this_evt { - let these_selections = this.selections.disjoint.to_vec(); + let these_selections = this.selections.disjoint_anchors().to_vec(); if these_selections.is_empty() { return; } @@ -3337,7 +3333,7 @@ impl Editor { effects, old_cursor_position: self.selections.newest_anchor().head(), history_entry: SelectionHistoryEntry { - selections: self.selections.disjoint_anchors(), + selections: self.selections.disjoint_anchors_arc(), select_next_state: self.select_next_state.clone(), select_prev_state: self.select_prev_state.clone(), add_selections_state: self.add_selections_state.clone(), @@ -3497,6 +3493,7 @@ impl Editor { let mut pending_selection = self .selections .pending_anchor() + .cloned() .expect("extend_selection not called with pending selection"); if position >= tail { pending_selection.start = tail_anchor; @@ -3518,7 +3515,7 @@ impl Editor { }; self.change_selections(effects, window, cx, |s| { - s.set_pending(pending_selection, pending_mode) + s.set_pending(pending_selection.clone(), pending_mode) }); } @@ -3593,7 +3590,7 @@ impl Editor { Some(selected_points[0].id) } else { let clicked_point_already_selected = - self.selections.disjoint.iter().find(|selection| { + self.selections.disjoint_anchors().iter().find(|selection| { selection.start.to_point(buffer) == start.to_point(buffer) || selection.end.to_point(buffer) == end.to_point(buffer) }); @@ -3698,7 +3695,7 @@ impl Editor { if self.columnar_selection_state.is_some() { self.select_columns(position, goal_column, &display_map, window, cx); - } else if let Some(mut pending) = self.selections.pending_anchor() { + } else if let Some(mut pending) = self.selections.pending_anchor().cloned() { let buffer = &display_map.buffer_snapshot; let head; let tail; @@ -3774,7 +3771,7 @@ impl Editor { } self.change_selections(SelectionEffects::no_scroll(), window, cx, |s| { - s.set_pending(pending, mode); + s.set_pending(pending.clone(), mode); }); } else { log::error!("update_selection dispatched with no pending selection"); @@ -3883,7 +3880,8 @@ impl Editor { }; pending_nonempty_selection - || (self.columnar_selection_state.is_some() && self.selections.disjoint.len() > 1) + || (self.columnar_selection_state.is_some() + && self.selections.disjoint_anchors().len() > 1) } pub fn has_pending_selection(&self) -> bool { @@ -5471,19 +5469,33 @@ impl Editor { if position.diff_base_anchor.is_some() { return; } - let (buffer, buffer_position) = - if let Some(output) = self.buffer.read(cx).text_anchor_for_position(position, cx) { - output - } else { - return; - }; + let buffer_position = multibuffer_snapshot.anchor_before(position); + let Some(buffer) = buffer_position + .buffer_id + .and_then(|buffer_id| self.buffer.read(cx).buffer(buffer_id)) + else { + return; + }; let buffer_snapshot = buffer.read(cx).snapshot(); let query: Option> = - Self::completion_query(&multibuffer_snapshot, position).map(|query| query.into()); + Self::completion_query(&multibuffer_snapshot, buffer_position) + .map(|query| query.into()); drop(multibuffer_snapshot); + // Hide the current completions menu when query is empty. Without this, cached + // completions from before the trigger char may be reused (#32774). + if query.is_none() { + let menu_is_open = matches!( + self.context_menu.borrow().as_ref(), + Some(CodeContextMenu::Completions(_)) + ); + if menu_is_open { + self.hide_context_menu(window, cx); + } + } + let mut ignore_word_threshold = false; let provider = match requested_source { Some(CompletionsMenuSource::Normal) | None => self.completion_provider.clone(), @@ -5505,37 +5517,6 @@ impl Editor { .as_ref() .is_none_or(|provider| provider.filter_completions()); - let trigger_kind = match trigger { - Some(trigger) if buffer.read(cx).completion_triggers().contains(trigger) => { - CompletionTriggerKind::TRIGGER_CHARACTER - } - _ => CompletionTriggerKind::INVOKED, - }; - let completion_context = CompletionContext { - trigger_character: trigger.and_then(|trigger| { - if trigger_kind == CompletionTriggerKind::TRIGGER_CHARACTER { - Some(String::from(trigger)) - } else { - None - } - }), - trigger_kind, - }; - - // Hide the current completions menu when a trigger char is typed. Without this, cached - // completions from before the trigger char may be reused (#32774). Snippet choices could - // involve trigger chars, so this is skipped in that case. - if trigger_kind == CompletionTriggerKind::TRIGGER_CHARACTER && self.snippet_stack.is_empty() - { - let menu_is_open = matches!( - self.context_menu.borrow().as_ref(), - Some(CodeContextMenu::Completions(_)) - ); - if menu_is_open { - self.hide_context_menu(window, cx); - } - } - if let Some(CodeContextMenu::Completions(menu)) = self.context_menu.borrow_mut().as_mut() { if filter_completions { menu.filter(query.clone(), provider.clone(), window, cx); @@ -5566,6 +5547,29 @@ impl Editor { } }; + let trigger_kind = match trigger { + Some(trigger) if buffer.read(cx).completion_triggers().contains(trigger) => { + CompletionTriggerKind::TRIGGER_CHARACTER + } + _ => CompletionTriggerKind::INVOKED, + }; + let completion_context = CompletionContext { + trigger_character: trigger.and_then(|trigger| { + if trigger_kind == CompletionTriggerKind::TRIGGER_CHARACTER { + Some(String::from(trigger)) + } else { + None + } + }), + trigger_kind, + }; + + let Anchor { + excerpt_id: buffer_excerpt_id, + text_anchor: buffer_position, + .. + } = buffer_position; + let (word_replace_range, word_to_exclude) = if let (word_range, Some(CharKind::Word)) = buffer_snapshot.surrounding_word(buffer_position, false) { @@ -5622,7 +5626,7 @@ impl Editor { let (mut words, provider_responses) = match &provider { Some(provider) => { let provider_responses = provider.completions( - position.excerpt_id, + buffer_excerpt_id, &buffer, buffer_position, completion_context, @@ -6058,7 +6062,7 @@ impl Editor { editor.refresh_edit_prediction(true, false, window, cx); }); - self.invalidate_autoclose_regions(&self.selections.disjoint_anchors(), &snapshot); + self.invalidate_autoclose_regions(&self.selections.disjoint_anchors_arc(), &snapshot); let show_new_completions_on_confirm = completion .confirm @@ -7465,7 +7469,7 @@ impl Editor { s.select_anchor_ranges([last_edit_end..last_edit_end]); }); - let selections = self.selections.disjoint_anchors(); + let selections = self.selections.disjoint_anchors_arc(); if let Some(transaction_id_now) = self.buffer.read(cx).last_transaction_id(cx) { let has_new_transaction = transaction_id_prev != Some(transaction_id_now); if has_new_transaction { @@ -7710,7 +7714,7 @@ impl Editor { let Some(mode) = Self::columnar_selection_mode(modifiers, cx) else { return; }; - if self.selections.pending.is_none() { + if self.selections.pending_anchor().is_none() { return; } @@ -10510,7 +10514,7 @@ impl Editor { fn enable_wrap_selections_in_tag(&self, cx: &App) -> bool { let snapshot = self.buffer.read(cx).snapshot(cx); - for selection in self.selections.disjoint_anchors().iter() { + for selection in self.selections.disjoint_anchors_arc().iter() { if snapshot .language_at(selection.start) .and_then(|lang| lang.config().wrap_characters.as_ref()) @@ -10844,7 +10848,7 @@ impl Editor { let snapshot = self.snapshot(window, cx); let cursors = self .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .map(|selection| { let cursor_position: Point = selection.head().to_point(&snapshot.buffer_snapshot); @@ -14535,7 +14539,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) -> Result<()> { - let selections = self.selections.disjoint_anchors(); + let selections = self.selections.disjoint_anchors_arc(); match selections.first() { Some(first) if selections.len() >= 2 => { self.change_selections(Default::default(), window, cx, |s| { @@ -14559,7 +14563,7 @@ impl Editor { window: &mut Window, cx: &mut Context, ) -> Result<()> { - let selections = self.selections.disjoint_anchors(); + let selections = self.selections.disjoint_anchors_arc(); match selections.last() { Some(last) if selections.len() >= 2 => { self.change_selections(Default::default(), window, cx, |s| { @@ -14987,15 +14991,13 @@ impl Editor { let mut new_range = old_range.clone(); while let Some((node, containing_range)) = buffer.syntax_ancestor(new_range.clone()) { - if !node.is_named() { - new_range = node.start_byte()..node.end_byte(); - continue; - } - new_range = match containing_range { MultiOrSingleBufferOffsetRange::Single(_) => break, MultiOrSingleBufferOffsetRange::Multi(range) => range, }; + if !node.is_named() { + continue; + } if !display_map.intersects_fold(new_range.start) && !display_map.intersects_fold(new_range.end) { @@ -15118,11 +15120,9 @@ impl Editor { let full_edits = selections .into_iter() .filter_map(|selection| { - // Only requires two branches once if-let-chains stabilize (#53667) - let child = if !selection.is_empty() { - selection.range() - } else if let Some((_, ancestor_range)) = - buffer.syntax_ancestor(selection.start..selection.end) + let child = if selection.is_empty() + && let Some((_, ancestor_range)) = + buffer.syntax_ancestor(selection.start..selection.end) { match ancestor_range { MultiOrSingleBufferOffsetRange::Single(range) => range, @@ -15150,6 +15150,9 @@ impl Editor { Some((selection.id, parent, text)) }) .collect::>(); + if full_edits.is_empty() { + return; + } self.transact(window, cx, |this, window, cx| { this.buffer.update(cx, |buffer, cx| { @@ -15658,7 +15661,7 @@ impl Editor { cx: &mut Context, ) { - let selections = self.selections.disjoint_anchors(); + let selections = self.selections.disjoint_anchors_arc(); let lines = if lines == 0 { EditorSettings::get_global(cx).expand_excerpt_lines @@ -17117,7 +17120,7 @@ impl Editor { .transaction(transaction_id_prev) .map(|t| t.0.clone()) }) - .unwrap_or_else(|| self.selections.disjoint_anchors()); + .unwrap_or_else(|| self.selections.disjoint_anchors_arc()); let mut timeout = cx.background_executor().timer(FORMAT_TIMEOUT).fuse(); let format = project.update(cx, |project, cx| { @@ -17657,7 +17660,7 @@ impl Editor { .update(cx, |buffer, cx| buffer.start_transaction_at(now, cx)) { self.selection_history - .insert_transaction(tx_id, self.selections.disjoint_anchors()); + .insert_transaction(tx_id, self.selections.disjoint_anchors_arc()); cx.emit(EditorEvent::TransactionBegun { transaction_id: tx_id, }); @@ -17679,7 +17682,7 @@ impl Editor { if let Some((_, end_selections)) = self.selection_history.transaction_mut(transaction_id) { - *end_selections = Some(self.selections.disjoint_anchors()); + *end_selections = Some(self.selections.disjoint_anchors_arc()); } else { log::error!("unexpectedly ended a transaction that wasn't started by this editor"); } @@ -18349,7 +18352,12 @@ impl Editor { _window: &mut Window, cx: &mut Context, ) { - let ranges: Vec<_> = self.selections.disjoint.iter().map(|s| s.range()).collect(); + let ranges: Vec<_> = self + .selections + .disjoint_anchors() + .iter() + .map(|s| s.range()) + .collect(); self.toggle_diff_hunks_in_ranges(ranges, cx); } @@ -18387,7 +18395,12 @@ impl Editor { cx: &mut Context, ) { let snapshot = self.buffer.read(cx).snapshot(cx); - let ranges: Vec<_> = self.selections.disjoint.iter().map(|s| s.range()).collect(); + let ranges: Vec<_> = self + .selections + .disjoint_anchors() + .iter() + .map(|s| s.range()) + .collect(); let stage = self.has_stageable_diff_hunks_in_ranges(&ranges, &snapshot); self.stage_or_unstage_diff_hunks(stage, ranges, cx); } @@ -18551,7 +18564,12 @@ impl Editor { } pub fn expand_selected_diff_hunks(&mut self, cx: &mut Context) { - let ranges: Vec<_> = self.selections.disjoint.iter().map(|s| s.range()).collect(); + let ranges: Vec<_> = self + .selections + .disjoint_anchors() + .iter() + .map(|s| s.range()) + .collect(); self.buffer .update(cx, |buffer, cx| buffer.expand_diff_hunks(ranges, cx)) } @@ -20548,7 +20566,9 @@ impl Editor { ) .detach(); } - self.update_lsp_data(false, Some(buffer_id), window, cx); + if self.active_diagnostics != ActiveDiagnostic::All { + self.update_lsp_data(false, Some(buffer_id), window, cx); + } cx.emit(EditorEvent::ExcerptsAdded { buffer: buffer.clone(), predecessor: *predecessor, @@ -21269,7 +21289,7 @@ impl Editor { buffer.finalize_last_transaction(cx); if self.leader_id.is_none() { buffer.set_active_selections( - &self.selections.disjoint_anchors(), + &self.selections.disjoint_anchors_arc(), self.selections.line_mode, self.cursor_shape, cx, @@ -23566,7 +23586,7 @@ impl EntityInputHandler for Editor { let marked_ranges = { let snapshot = this.buffer.read(cx).read(cx); this.selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .map(|selection| { selection.start.bias_left(&snapshot)..selection.end.bias_right(&snapshot) diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 4973e0d9b98ac1e362518d68f081912ac13144a4..9b2ec8a4edeee2329b284cd80ab5059439691f12 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -19257,7 +19257,7 @@ async fn test_expand_diff_hunk_at_excerpt_boundary(cx: &mut TestAppContext) { cx.executor().run_until_parked(); // When the start of a hunk coincides with the start of its excerpt, - // the hunk is expanded. When the start of a a hunk is earlier than + // the hunk is expanded. When the start of a hunk is earlier than // the start of its excerpt, the hunk is not expanded. cx.assert_state_with_diff( " diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index ed35d853865adba0aa4a18218153e93afaff89a2..30dfc7989a3640c73b55683e776c08b243e64d9c 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -1061,7 +1061,7 @@ impl EditorElement { ); if mouse_down_time.elapsed() >= drag_and_drop_delay { let drop_cursor = Selection { - id: post_inc(&mut editor.selections.next_selection_id), + id: post_inc(&mut editor.selections.next_selection_id()), start: drop_anchor, end: drop_anchor, reversed: false, @@ -1548,9 +1548,13 @@ impl EditorElement { // Local cursors if !skip_local { let color = cx.theme().players().local().cursor; - editor.selections.disjoint.iter().for_each(|selection| { - add_cursor(selection.head(), color); - }); + editor + .selections + .disjoint_anchors() + .iter() + .for_each(|selection| { + add_cursor(selection.head(), color); + }); if let Some(ref selection) = editor.selections.pending_anchor() { add_cursor(selection.head(), color); } @@ -3007,6 +3011,12 @@ impl EditorElement { .ilog10() + 1; + let git_gutter_width = Self::gutter_strip_width(line_height) + + gutter_dimensions + .git_blame_entries_width + .unwrap_or_default(); + let available_width = gutter_dimensions.left_padding - git_gutter_width; + buffer_rows .iter() .enumerate() @@ -3022,9 +3032,6 @@ impl EditorElement { ExpandExcerptDirection::UpAndDown => IconName::ExpandVertical, }; - let git_gutter_width = Self::gutter_strip_width(line_height); - let available_width = gutter_dimensions.left_padding - git_gutter_width; - let editor = self.editor.clone(); let is_wide = max_line_number_length >= EditorSettings::get_global(cx).gutter.min_line_number_digits as u32 @@ -9682,7 +9689,7 @@ impl EditorScrollbars { editor_bounds.bottom_left(), size( // The horizontal viewport size differs from the space available for the - // horizontal scrollbar, so we have to manually stich it together here. + // horizontal scrollbar, so we have to manually stitch it together here. editor_bounds.size.width - right_margin, scrollbar_width, ), diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index 253d0c27518107dc1cad3733cefbfef5bc12b807..bf21d6b461e6fdc082fdd1431f13b8daae730824 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -187,7 +187,7 @@ impl FollowableItem for Editor { } else if self.focus_handle.is_focused(window) { self.buffer.update(cx, |buffer, cx| { buffer.set_active_selections( - &self.selections.disjoint_anchors(), + &self.selections.disjoint_anchors_arc(), self.selections.line_mode, self.cursor_shape, cx, @@ -231,7 +231,7 @@ impl FollowableItem for Editor { scroll_y: scroll_anchor.offset.y, selections: self .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .map(|s| serialize_selection(s, &snapshot)) .collect(), @@ -310,7 +310,7 @@ impl FollowableItem for Editor { let snapshot = self.buffer.read(cx).snapshot(cx); update.selections = self .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .map(|s| serialize_selection(s, &snapshot)) .collect(); @@ -1675,7 +1675,7 @@ impl SearchableItem for Editor { cx: &mut Context, ) -> usize { let buffer = self.buffer().read(cx).snapshot(cx); - let current_index_position = if self.selections.disjoint_anchors().len() == 1 { + let current_index_position = if self.selections.disjoint_anchors_arc().len() == 1 { self.selections.newest_anchor().head() } else { matches[current_index].start diff --git a/crates/editor/src/jsx_tag_auto_close.rs b/crates/editor/src/jsx_tag_auto_close.rs index f2bdb717efe4b4d7521d4d6906b3c0f77dcb14f6..9d78ac543db6fb1d14c98d6c675d4658db8f6f03 100644 --- a/crates/editor/src/jsx_tag_auto_close.rs +++ b/crates/editor/src/jsx_tag_auto_close.rs @@ -507,7 +507,7 @@ pub(crate) fn handle_from( { let selections = this - .read_with(cx, |this, _| this.selections.disjoint_anchors()) + .read_with(cx, |this, _| this.selections.disjoint_anchors_arc()) .ok()?; for selection in selections.iter() { let Some(selection_buffer_offset_head) = diff --git a/crates/editor/src/lsp_ext.rs b/crates/editor/src/lsp_ext.rs index 18ad2d71c835e5ec7e3bbd540de21f7e38425c39..0c4760f5684acf450b793a1deac54be983dcafd0 100644 --- a/crates/editor/src/lsp_ext.rs +++ b/crates/editor/src/lsp_ext.rs @@ -35,7 +35,7 @@ where let project = editor.project.clone()?; editor .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .filter_map(|selection| Some((selection.head(), selection.head().buffer_id?))) .unique_by(|(_, buffer_id)| *buffer_id) diff --git a/crates/editor/src/mouse_context_menu.rs b/crates/editor/src/mouse_context_menu.rs index 3bc334c54c2f58e6dda2b404039369907c275422..78b12945afd1c2fcd359181afb030fc235c60a18 100644 --- a/crates/editor/src/mouse_context_menu.rs +++ b/crates/editor/src/mouse_context_menu.rs @@ -130,12 +130,9 @@ fn display_ranges<'a>( display_map: &'a DisplaySnapshot, selections: &'a SelectionsCollection, ) -> impl Iterator> + 'a { - let pending = selections - .pending - .as_ref() - .map(|pending| &pending.selection); + let pending = selections.pending_anchor(); selections - .disjoint + .disjoint_anchors() .iter() .chain(pending) .map(move |s| s.start.to_display_point(display_map)..s.end.to_display_point(display_map)) diff --git a/crates/editor/src/rust_analyzer_ext.rs b/crates/editor/src/rust_analyzer_ext.rs index f4059ca03d2ad70106aa958b4fe0c545cb4988ea..ffa0c017c0eb157df776cc49e0dba51e617e3379 100644 --- a/crates/editor/src/rust_analyzer_ext.rs +++ b/crates/editor/src/rust_analyzer_ext.rs @@ -319,7 +319,7 @@ fn cancel_flycheck_action( }; let buffer_id = editor .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .find_map(|selection| { let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?; @@ -344,7 +344,7 @@ fn run_flycheck_action( }; let buffer_id = editor .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .find_map(|selection| { let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?; @@ -369,7 +369,7 @@ fn clear_flycheck_action( }; let buffer_id = editor .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .find_map(|selection| { let buffer_id = selection.start.buffer_id.or(selection.end.buffer_id)?; diff --git a/crates/editor/src/selections_collection.rs b/crates/editor/src/selections_collection.rs index 0a02390b641e1020aff8d9cf0167b44485baf489..e562be10e92344c1c892878ab674cba39beb74c2 100644 --- a/crates/editor/src/selections_collection.rs +++ b/crates/editor/src/selections_collection.rs @@ -28,13 +28,13 @@ pub struct PendingSelection { pub struct SelectionsCollection { display_map: Entity, buffer: Entity, - pub next_selection_id: usize, + next_selection_id: usize, pub line_mode: bool, /// The non-pending, non-overlapping selections. /// The [SelectionsCollection::pending] selection could possibly overlap these - pub disjoint: Arc<[Selection]>, + disjoint: Arc<[Selection]>, /// A pending selection, such as when the mouse is being dragged - pub pending: Option, + pending: Option, } impl SelectionsCollection { @@ -84,20 +84,27 @@ impl SelectionsCollection { /// The non-pending, non-overlapping selections. There could be a pending selection that /// overlaps these if the mouse is being dragged, etc. This could also be empty if there is a /// pending selection. Returned as selections over Anchors. - pub fn disjoint_anchors(&self) -> Arc<[Selection]> { + pub fn disjoint_anchors_arc(&self) -> Arc<[Selection]> { self.disjoint.clone() } + /// The non-pending, non-overlapping selections. There could be a pending selection that + /// overlaps these if the mouse is being dragged, etc. This could also be empty if there is a + /// pending selection. Returned as selections over Anchors. + pub fn disjoint_anchors(&self) -> &[Selection] { + &self.disjoint + } + pub fn disjoint_anchor_ranges(&self) -> impl Iterator> { // Mapping the Arc slice would borrow it, whereas indexing captures it. - let disjoint = self.disjoint_anchors(); + let disjoint = self.disjoint_anchors_arc(); (0..disjoint.len()).map(move |ix| disjoint[ix].range()) } /// Non-overlapping selections using anchors, including the pending selection. pub fn all_anchors(&self, cx: &mut App) -> Arc<[Selection]> { if self.pending.is_none() { - self.disjoint_anchors() + self.disjoint_anchors_arc() } else { let all_offset_selections = self.all::(cx); let buffer = self.buffer(cx); @@ -108,10 +115,12 @@ impl SelectionsCollection { } } - pub fn pending_anchor(&self) -> Option> { - self.pending - .as_ref() - .map(|pending| pending.selection.clone()) + pub fn pending_anchor(&self) -> Option<&Selection> { + self.pending.as_ref().map(|pending| &pending.selection) + } + + pub fn pending_anchor_mut(&mut self) -> Option<&mut Selection> { + self.pending.as_mut().map(|pending| &mut pending.selection) } pub fn pending>( @@ -120,7 +129,7 @@ impl SelectionsCollection { ) -> Option> { let map = self.display_map(cx); - resolve_selections(self.pending_anchor().as_ref(), &map).next() + resolve_selections(self.pending_anchor(), &map).next() } pub(crate) fn pending_mode(&self) -> Option { @@ -234,8 +243,7 @@ impl SelectionsCollection { let map = self.display_map(cx); let disjoint_anchors = &self.disjoint; let mut disjoint = resolve_selections_display(disjoint_anchors.iter(), &map).peekable(); - let mut pending_opt = - resolve_selections_display(self.pending_anchor().as_ref(), &map).next(); + let mut pending_opt = resolve_selections_display(self.pending_anchor(), &map).next(); let selections = iter::from_fn(move || { if let Some(pending) = pending_opt.as_mut() { while let Some(next_selection) = disjoint.peek() { @@ -343,9 +351,9 @@ impl SelectionsCollection { #[cfg(any(test, feature = "test-support"))] pub fn display_ranges(&self, cx: &mut App) -> Vec> { let display_map = self.display_map(cx); - self.disjoint_anchors() + self.disjoint_anchors_arc() .iter() - .chain(self.pending_anchor().as_ref()) + .chain(self.pending_anchor()) .map(|s| { if s.reversed { s.end.to_display_point(&display_map)..s.start.to_display_point(&display_map) @@ -412,6 +420,10 @@ impl SelectionsCollection { ); (mutable_collection.selections_changed, result) } + + pub fn next_selection_id(&self) -> usize { + self.next_selection_id + } } pub struct MutableSelectionsCollection<'a> { diff --git a/crates/editor/src/test/editor_test_context.rs b/crates/editor/src/test/editor_test_context.rs index 8c54c265edf7a19af9d17e982a5f4cb6a0079cc3..fbf7a312fe56600ad78e13c278c85e29b8ca5aa5 100644 --- a/crates/editor/src/test/editor_test_context.rs +++ b/crates/editor/src/test/editor_test_context.rs @@ -396,7 +396,7 @@ impl EditorTestContext { let (multibuffer_snapshot, selections, excerpts) = self.update_editor(|editor, _, cx| { let multibuffer_snapshot = editor.buffer.read(cx).snapshot(cx); - let selections = editor.selections.disjoint_anchors(); + let selections = editor.selections.disjoint_anchors_arc(); let excerpts = multibuffer_snapshot .excerpts() .map(|(e_id, snapshot, range)| (e_id, snapshot.clone(), range)) diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs index 84794d5386eda1517808d181eb259a3264f7b82d..e879ed0cb01f70f24a9b2b52438e1ff7d405f2d6 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs @@ -309,7 +309,14 @@ impl TryFrom for ResolvedTask { command: value.command.context("missing command")?, args: value.args, env: value.env.into_iter().collect(), - cwd: value.cwd.map(|s| s.to_string_lossy().into_owned()), + cwd: value.cwd.map(|s| { + let s = s.to_string_lossy(); + if cfg!(target_os = "windows") { + s.replace('\\', "/") + } else { + s.into_owned() + } + }), }) } } diff --git a/crates/fs/src/fs.rs b/crates/fs/src/fs.rs index bd7c94c1d71dd64b5c6caec6f2ffaa4517ac2db7..198299617619363fa9d486042d1b803c3ede6f88 100644 --- a/crates/fs/src/fs.rs +++ b/crates/fs/src/fs.rs @@ -693,7 +693,7 @@ impl Fs for RealFs { Ok(Some(Metadata { inode, - mtime: MTime(metadata.modified().unwrap()), + mtime: MTime(metadata.modified().unwrap_or(SystemTime::UNIX_EPOCH)), len: metadata.len(), is_symlink, is_dir: metadata.file_type().is_dir(), diff --git a/crates/fs/src/mac_watcher.rs b/crates/fs/src/mac_watcher.rs index 7bd176639f1dccef2da4c4ae8dcb317d0be602cb..698014de9716f6505ccd23cd344a62815d9ba0f7 100644 --- a/crates/fs/src/mac_watcher.rs +++ b/crates/fs/src/mac_watcher.rs @@ -6,6 +6,7 @@ use parking_lot::Mutex; use std::{ path::{Path, PathBuf}, sync::Weak, + thread, time::Duration, }; @@ -48,9 +49,12 @@ impl Watcher for MacWatcher { let (stream, handle) = EventStream::new(&[path], self.latency); let tx = self.events_tx.clone(); - std::thread::spawn(move || { - stream.run(move |events| smol::block_on(tx.send(events)).is_ok()); - }); + thread::Builder::new() + .name("MacWatcher".to_owned()) + .spawn(move || { + stream.run(move |events| smol::block_on(tx.send(events)).is_ok()); + }) + .unwrap(); handles.insert(path.into(), handle); Ok(()) diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index f425af646d6c38227bb82b3185ab7b0192fdea6c..b9a8dfea9ea167bf7ee807ee2b459444f4fa4f4d 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -521,6 +521,14 @@ impl PickerDelegate for BranchListDelegate { .inset(true) .spacing(ListItemSpacing::Sparse) .toggle_state(selected) + .tooltip({ + let branch_name = entry.branch.name().to_string(); + if entry.is_new { + Tooltip::text(format!("Create branch \"{}\"", branch_name)) + } else { + Tooltip::text(branch_name) + } + }) .child( v_flex() .w_full() diff --git a/crates/git_ui/src/git_panel.rs b/crates/git_ui/src/git_panel.rs index 6fb940588a83bfd910b52543489996b023267c00..76671eba7b577e86d5049add743e965d11acd6c4 100644 --- a/crates/git_ui/src/git_panel.rs +++ b/crates/git_ui/src/git_panel.rs @@ -3748,7 +3748,10 @@ impl GitPanel { .custom_scrollbars( Scrollbars::for_settings::() .tracked_scroll_handle(self.scroll_handle.clone()) - .with_track_along(ScrollAxes::Horizontal), + .with_track_along( + ScrollAxes::Horizontal, + cx.theme().colors().panel_background, + ), window, cx, ), diff --git a/crates/git_ui/src/text_diff_view.rs b/crates/git_ui/src/text_diff_view.rs index ebf32d1b994814fa277201176b555efed5e85e66..bd46a067dc8e6c3aeec4de878709024f66a819f2 100644 --- a/crates/git_ui/src/text_diff_view.rs +++ b/crates/git_ui/src/text_diff_view.rs @@ -416,7 +416,7 @@ impl Item for TextDiffView { pub fn selection_location_text(editor: &Editor, cx: &App) -> Option { let buffer = editor.buffer().read(cx); let buffer_snapshot = buffer.snapshot(cx); - let first_selection = editor.selections.disjoint.first()?; + let first_selection = editor.selections.disjoint_anchors().first()?; let selection_start = first_selection.start.to_point(&buffer_snapshot); let selection_end = first_selection.end.to_point(&buffer_snapshot); diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 44f819c135298dc991ad6036ad9948b5eaf609a4..ac1bdf85cb478064db42b3dccde8e44adee72fdd 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -115,7 +115,7 @@ seahash = "4.1" semantic_version.workspace = true serde.workspace = true serde_json.workspace = true -slotmap = "1.0.6" +slotmap.workspace = true smallvec.workspace = true smol.workspace = true stacksafe.workspace = true diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index 8b0b404d1dffbf8a27de1f29437ce9cc2ba63f0f..e6c3e3b8deea9b82514b5ac932c4f204fa081e14 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -958,6 +958,14 @@ impl App { cx.window_update_stack.pop(); window.root.replace(root_view.into()); window.defer(cx, |window: &mut Window, cx| window.appearance_changed(cx)); + + // allow a window to draw at least once before returning + // this didn't cause any issues on non windows platforms as it seems we always won the race to on_request_frame + // on windows we quite frequently lose the race and return a window that has never rendered, which leads to a crash + // where DispatchTree::root_node_id asserts on empty nodes + let clear = window.draw(cx); + clear.clear(); + cx.window_handles.insert(id, window.handle); cx.windows.get_mut(id).unwrap().replace(window); Ok(handle) diff --git a/crates/gpui/src/color.rs b/crates/gpui/src/color.rs index c6c4cdc77b133105a7c233417efae06d5c7aa220..b84f8699e38f015232313e49dd748a2b049c146d 100644 --- a/crates/gpui/src/color.rs +++ b/crates/gpui/src/color.rs @@ -151,9 +151,9 @@ impl From for Rgba { }; Rgba { - r, - g, - b, + r: r.clamp(0., 1.), + g: g.clamp(0., 1.), + b: b.clamp(0., 1.), a: color.a, } } diff --git a/crates/gpui/src/platform/linux/dispatcher.rs b/crates/gpui/src/platform/linux/dispatcher.rs index 3d32dbd2fdece5259f48e52550f6983b6a8c5b1d..2f6cd83756054bdbca2c764b046b0c37f51d3515 100644 --- a/crates/gpui/src/platform/linux/dispatcher.rs +++ b/crates/gpui/src/platform/linux/dispatcher.rs @@ -37,51 +37,57 @@ impl LinuxDispatcher { let mut background_threads = (0..thread_count) .map(|i| { let receiver = background_receiver.clone(); - std::thread::spawn(move || { - for runnable in receiver { - let start = Instant::now(); - - runnable.run(); - - log::trace!( - "background thread {}: ran runnable. took: {:?}", - i, - start.elapsed() - ); - } - }) + std::thread::Builder::new() + .name(format!("Worker-{i}")) + .spawn(move || { + for runnable in receiver { + let start = Instant::now(); + + runnable.run(); + + log::trace!( + "background thread {}: ran runnable. took: {:?}", + i, + start.elapsed() + ); + } + }) + .unwrap() }) .collect::>(); let (timer_sender, timer_channel) = calloop::channel::channel::(); - let timer_thread = std::thread::spawn(|| { - let mut event_loop: EventLoop<()> = - EventLoop::try_new().expect("Failed to initialize timer loop!"); - - let handle = event_loop.handle(); - let timer_handle = event_loop.handle(); - handle - .insert_source(timer_channel, move |e, _, _| { - if let channel::Event::Msg(timer) = e { - // This has to be in an option to satisfy the borrow checker. The callback below should only be scheduled once. - let mut runnable = Some(timer.runnable); - timer_handle - .insert_source( - calloop::timer::Timer::from_duration(timer.duration), - move |_, _, _| { - if let Some(runnable) = runnable.take() { - runnable.run(); - } - TimeoutAction::Drop - }, - ) - .expect("Failed to start timer"); - } - }) - .expect("Failed to start timer thread"); - - event_loop.run(None, &mut (), |_| {}).log_err(); - }); + let timer_thread = std::thread::Builder::new() + .name("Timer".to_owned()) + .spawn(|| { + let mut event_loop: EventLoop<()> = + EventLoop::try_new().expect("Failed to initialize timer loop!"); + + let handle = event_loop.handle(); + let timer_handle = event_loop.handle(); + handle + .insert_source(timer_channel, move |e, _, _| { + if let channel::Event::Msg(timer) = e { + // This has to be in an option to satisfy the borrow checker. The callback below should only be scheduled once. + let mut runnable = Some(timer.runnable); + timer_handle + .insert_source( + calloop::timer::Timer::from_duration(timer.duration), + move |_, _, _| { + if let Some(runnable) = runnable.take() { + runnable.run(); + } + TimeoutAction::Drop + }, + ) + .expect("Failed to start timer"); + } + }) + .expect("Failed to start timer thread"); + + event_loop.run(None, &mut (), |_| {}).log_err(); + }) + .unwrap(); background_threads.push(timer_thread); diff --git a/crates/gpui/src/platform/linux/x11/clipboard.rs b/crates/gpui/src/platform/linux/x11/clipboard.rs index a6f96d38c4254da5a2f92261700126962c16e91c..65ad16e82bf103c4ef08e79c692196d3fae58777 100644 --- a/crates/gpui/src/platform/linux/x11/clipboard.rs +++ b/crates/gpui/src/platform/linux/x11/clipboard.rs @@ -957,15 +957,17 @@ impl Clipboard { } // At this point we know that the clipboard does not exist. let ctx = Arc::new(Inner::new()?); - let join_handle; - { - let ctx = Arc::clone(&ctx); - join_handle = std::thread::spawn(move || { - if let Err(error) = serve_requests(ctx) { - log::error!("Worker thread errored with: {}", error); + let join_handle = std::thread::Builder::new() + .name("Clipboard".to_owned()) + .spawn({ + let ctx = Arc::clone(&ctx); + move || { + if let Err(error) = serve_requests(ctx) { + log::error!("Worker thread errored with: {}", error); + } } - }); - } + }) + .unwrap(); *global_cb = Some(GlobalClipboard { inner: Arc::clone(&ctx), server_handle: join_handle, diff --git a/crates/gpui/src/platform/mac/metal_atlas.rs b/crates/gpui/src/platform/mac/metal_atlas.rs index 5d2d8e63e06a1ea6251c1fd2edf461eeeedec612..8282530c5efdc13ca95a1f04c0f6ef1a23c8366c 100644 --- a/crates/gpui/src/platform/mac/metal_atlas.rs +++ b/crates/gpui/src/platform/mac/metal_atlas.rs @@ -167,11 +167,14 @@ impl MetalAtlasState { if let Some(ix) = index { texture_list.textures[ix] = Some(atlas_texture); - texture_list.textures.get_mut(ix).unwrap().as_mut().unwrap() + texture_list.textures.get_mut(ix) } else { texture_list.textures.push(Some(atlas_texture)); - texture_list.textures.last_mut().unwrap().as_mut().unwrap() + texture_list.textures.last_mut() } + .unwrap() + .as_mut() + .unwrap() } fn texture(&self, id: AtlasTextureId) -> &MetalAtlasTexture { diff --git a/crates/gpui/src/platform/mac/platform.rs b/crates/gpui/src/platform/mac/platform.rs index dea04d89a06acac526a8b033681829fdc1e148fd..9909c78c472a17e683380be71b65484800c0fa76 100644 --- a/crates/gpui/src/platform/mac/platform.rs +++ b/crates/gpui/src/platform/mac/platform.rs @@ -82,6 +82,10 @@ unsafe fn build_classes() { APP_DELEGATE_CLASS = unsafe { let mut decl = ClassDecl::new("GPUIApplicationDelegate", class!(NSResponder)).unwrap(); decl.add_ivar::<*mut c_void>(MAC_PLATFORM_IVAR); + decl.add_method( + sel!(applicationWillFinishLaunching:), + will_finish_launching as extern "C" fn(&mut Object, Sel, id), + ); decl.add_method( sel!(applicationDidFinishLaunching:), did_finish_launching as extern "C" fn(&mut Object, Sel, id), @@ -1356,6 +1360,23 @@ unsafe fn get_mac_platform(object: &mut Object) -> &MacPlatform { } } +extern "C" fn will_finish_launching(_this: &mut Object, _: Sel, _: id) { + unsafe { + let user_defaults: id = msg_send![class!(NSUserDefaults), standardUserDefaults]; + + // The autofill heuristic controller causes slowdown and high CPU usage. + // We don't know exactly why. This disables the full heuristic controller. + // + // Adapted from: https://github.com/ghostty-org/ghostty/pull/8625 + let name = ns_string("NSAutoFillHeuristicControllerEnabled"); + let existing_value: id = msg_send![user_defaults, objectForKey: name]; + if existing_value == nil { + let false_value: id = msg_send![class!(NSNumber), numberWithBool:false]; + let _: () = msg_send![user_defaults, setObject: false_value forKey: name]; + } + } +} + extern "C" fn did_finish_launching(this: &mut Object, _: Sel, _: id) { unsafe { let app: id = msg_send![APP_CLASS, sharedApplication]; diff --git a/crates/gpui/src/platform/windows/platform.rs b/crates/gpui/src/platform/windows/platform.rs index 4d0e6ea56f7d90f303f6634de1239a6a4542429a..2eb1862f36a26592e18dc2e44875e08319361cc8 100644 --- a/crates/gpui/src/platform/windows/platform.rs +++ b/crates/gpui/src/platform/windows/platform.rs @@ -243,29 +243,32 @@ impl WindowsPlatform { let validation_number = self.inner.validation_number; let all_windows = Arc::downgrade(&self.raw_window_handles); let text_system = Arc::downgrade(&self.text_system); - std::thread::spawn(move || { - let vsync_provider = VSyncProvider::new(); - loop { - vsync_provider.wait_for_vsync(); - if check_device_lost(&directx_device.device) { - handle_gpu_device_lost( - &mut directx_device, - platform_window.as_raw(), - validation_number, - &all_windows, - &text_system, - ); - } - let Some(all_windows) = all_windows.upgrade() else { - break; - }; - for hwnd in all_windows.read().iter() { - unsafe { - let _ = RedrawWindow(Some(hwnd.as_raw()), None, None, RDW_INVALIDATE); + std::thread::Builder::new() + .name("VSyncProvider".to_owned()) + .spawn(move || { + let vsync_provider = VSyncProvider::new(); + loop { + vsync_provider.wait_for_vsync(); + if check_device_lost(&directx_device.device) { + handle_gpu_device_lost( + &mut directx_device, + platform_window.as_raw(), + validation_number, + &all_windows, + &text_system, + ); + } + let Some(all_windows) = all_windows.upgrade() else { + break; + }; + for hwnd in all_windows.read().iter() { + unsafe { + let _ = RedrawWindow(Some(hwnd.as_raw()), None, None, RDW_INVALIDATE); + } } } - } - }); + }) + .unwrap(); } } @@ -1016,7 +1019,7 @@ fn handle_gpu_device_lost( all_windows: &std::sync::Weak>>, text_system: &std::sync::Weak, ) { - // Here we wait a bit to ensure the the system has time to recover from the device lost state. + // Here we wait a bit to ensure the system has time to recover from the device lost state. // If we don't wait, the final drawing result will be blank. std::thread::sleep(std::time::Duration::from_millis(350)); diff --git a/crates/gpui/src/platform/windows/window.rs b/crates/gpui/src/platform/windows/window.rs index aa907c8d734973fc4fc795b6d8ebf7654d1b40de..7abb4ee21a1a28356e15d09be3c22c688bb7e033 100644 --- a/crates/gpui/src/platform/windows/window.rs +++ b/crates/gpui/src/platform/windows/window.rs @@ -684,8 +684,16 @@ impl PlatformWindow for WindowsWindow { .executor .spawn(async move { this.set_window_placement().log_err(); - unsafe { SetActiveWindow(hwnd).log_err() }; - unsafe { SetFocus(Some(hwnd)).log_err() }; + + unsafe { + // If the window is minimized, restore it. + if IsIconic(hwnd).as_bool() { + ShowWindowAsync(hwnd, SW_RESTORE).ok().log_err(); + } + + SetActiveWindow(hwnd).log_err(); + SetFocus(Some(hwnd)).log_err(); + } // premium ragebait by windows, this is needed because the window // must have received an input event to be able to set itself to foreground diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index 62468573ed29687c0436e98a0174baa515b0ee3d..1429b7bf941fab5b1b508b977e898b8e153942d1 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -318,6 +318,12 @@ pub fn read_proxy_from_env() -> Option { .and_then(|env| env.parse().ok()) } +pub fn read_no_proxy_from_env() -> Option { + const ENV_VARS: &[&str] = &["NO_PROXY", "no_proxy"]; + + ENV_VARS.iter().find_map(|var| std::env::var(var).ok()) +} + pub struct BlockedHttpClient; impl BlockedHttpClient { diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index f3609f7ea8706f33eb07eaaf456731e14c85555a..0f05e58c27c48c37043fe90f64b4f03968b22752 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -263,6 +263,7 @@ pub enum IconName { ZedPredictError, ZedPredictUp, ZedXCopilot, + Linux, } impl IconName { diff --git a/crates/inspector_ui/README.md b/crates/inspector_ui/README.md index 5c720dfea2df3ff2ddf75112fec8793ba1851ed1..74886e611108fd4fd3f5f5015746f913e2697cae 100644 --- a/crates/inspector_ui/README.md +++ b/crates/inspector_ui/README.md @@ -68,7 +68,7 @@ With both approaches, would need to record the buffer version and use that when * Mode to navigate to source code on every element change while picking. -* Tracking of more source locations - currently the source location is often in a ui compoenent. Ideally this would have a way for the components to indicate that they are probably not the source location the user is looking for. +* Tracking of more source locations - currently the source location is often in a ui component. Ideally this would have a way for the components to indicate that they are probably not the source location the user is looking for. - Could have `InspectorElementId` be `Vec<(ElementId, Option)>`, but if there are multiple code paths that construct the same element this would cause them to be considered different. diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index e55d1f2d2385d41fa30643987cfd026958b9b803..1a7fca79f64c2c253117a3acde8c4d7519a9c282 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -144,7 +144,7 @@ struct BufferBranchState { /// state of a buffer. pub struct BufferSnapshot { pub text: text::BufferSnapshot, - pub(crate) syntax: SyntaxSnapshot, + pub syntax: SyntaxSnapshot, file: Option>, diagnostics: SmallVec<[(LanguageServerId, DiagnosticSet); 2]>, remote_selections: TreeMap, @@ -667,7 +667,10 @@ impl HighlightedTextBuilder { syntax_snapshot: &'a SyntaxSnapshot, ) -> BufferChunks<'a> { let captures = syntax_snapshot.captures(range.clone(), snapshot, |grammar| { - grammar.highlights_query.as_ref() + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) }); let highlight_maps = captures @@ -3253,7 +3256,10 @@ impl BufferSnapshot { fn get_highlights(&self, range: Range) -> (SyntaxMapCaptures<'_>, Vec) { let captures = self.syntax.captures(range, &self.text, |grammar| { - grammar.highlights_query.as_ref() + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) }); let highlight_maps = captures .grammars() @@ -3317,18 +3323,25 @@ impl BufferSnapshot { /// Iterates over every [`SyntaxLayer`] in the buffer. pub fn syntax_layers(&self) -> impl Iterator> + '_ { - self.syntax - .layers_for_range(0..self.len(), &self.text, true) + self.syntax_layers_for_range(0..self.len(), true) } pub fn syntax_layer_at(&self, position: D) -> Option> { let offset = position.to_offset(self); - self.syntax - .layers_for_range(offset..offset, &self.text, false) + self.syntax_layers_for_range(offset..offset, false) .filter(|l| l.node().end_byte() > offset) .last() } + pub fn syntax_layers_for_range( + &self, + range: Range, + include_hidden: bool, + ) -> impl Iterator> + '_ { + self.syntax + .layers_for_range(range, &self.text, include_hidden) + } + pub fn smallest_syntax_layer_containing( &self, range: Range, @@ -3866,9 +3879,12 @@ impl BufferSnapshot { text: item.text, highlight_ranges: item.highlight_ranges, name_ranges: item.name_ranges, - body_range: item.body_range.map(|body_range| { - self.anchor_after(body_range.start)..self.anchor_before(body_range.end) - }), + signature_range: item + .signature_range + .map(|r| self.anchor_after(r.start)..self.anchor_before(r.end)), + body_range: item + .body_range + .map(|r| self.anchor_after(r.start)..self.anchor_before(r.end)), annotation_range: annotation_row_range.map(|annotation_range| { self.anchor_after(Point::new(annotation_range.start, 0)) ..self.anchor_before(Point::new( @@ -3908,38 +3924,51 @@ impl BufferSnapshot { let mut open_point = None; let mut close_point = None; + + let mut signature_start = None; + let mut signature_end = None; + let mut extend_signature_range = |node: tree_sitter::Node| { + if signature_start.is_none() { + signature_start = Some(Point::from_ts_point(node.start_position())); + } + signature_end = Some(Point::from_ts_point(node.end_position())); + }; + let mut buffer_ranges = Vec::new(); + let mut add_to_buffer_ranges = |node: tree_sitter::Node, node_is_name| { + let mut range = node.start_byte()..node.end_byte(); + let start = node.start_position(); + if node.end_position().row > start.row { + range.end = range.start + self.line_len(start.row as u32) as usize - start.column; + } + + if !range.is_empty() { + buffer_ranges.push((range, node_is_name)); + } + }; + for capture in mat.captures { - let node_is_name; if capture.index == config.name_capture_ix { - node_is_name = true; + add_to_buffer_ranges(capture.node, true); + extend_signature_range(capture.node); } else if Some(capture.index) == config.context_capture_ix || (Some(capture.index) == config.extra_context_capture_ix && include_extra_context) { - node_is_name = false; + add_to_buffer_ranges(capture.node, false); + extend_signature_range(capture.node); } else { if Some(capture.index) == config.open_capture_ix { open_point = Some(Point::from_ts_point(capture.node.end_position())); } else if Some(capture.index) == config.close_capture_ix { close_point = Some(Point::from_ts_point(capture.node.start_position())); } - - continue; - } - - let mut range = capture.node.start_byte()..capture.node.end_byte(); - let start = capture.node.start_position(); - if capture.node.end_position().row > start.row { - range.end = range.start + self.line_len(start.row as u32) as usize - start.column; - } - - if !range.is_empty() { - buffer_ranges.push((range, node_is_name)); } } + if buffer_ranges.is_empty() { return None; } + let mut text = String::new(); let mut highlight_ranges = Vec::new(); let mut name_ranges = Vec::new(); @@ -3948,7 +3977,6 @@ impl BufferSnapshot { true, ); let mut last_buffer_range_end = 0; - for (buffer_range, is_name) in buffer_ranges { let space_added = !text.is_empty() && buffer_range.start > last_buffer_range_end; if space_added { @@ -3990,12 +4018,17 @@ impl BufferSnapshot { last_buffer_range_end = buffer_range.end; } + let signature_range = signature_start + .zip(signature_end) + .map(|(start, end)| start..end); + Some(OutlineItem { depth: 0, // We'll calculate the depth later range: item_point_range, text, highlight_ranges, name_ranges, + signature_range, body_range: open_point.zip(close_point).map(|(start, end)| start..end), annotation_range: None, }) diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 77e8ee0232819a830e48d1d12a778ef19026d7b6..3c951e50ff231a72e284da743bb3e5d409eb9c5e 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -81,7 +81,9 @@ pub use language_registry::{ }; pub use lsp::{LanguageServerId, LanguageServerName}; pub use outline::*; -pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer, ToTreeSitterPoint, TreeSitterOptions}; +pub use syntax_map::{ + OwnedSyntaxLayer, SyntaxLayer, SyntaxMapMatches, ToTreeSitterPoint, TreeSitterOptions, +}; pub use text::{AnchorRangeExt, LineEnding}; pub use tree_sitter::{Node, Parser, Tree, TreeCursor}; @@ -1154,7 +1156,7 @@ pub struct Grammar { id: GrammarId, pub ts_language: tree_sitter::Language, pub(crate) error_query: Option, - pub(crate) highlights_query: Option, + pub highlights_config: Option, pub(crate) brackets_config: Option, pub(crate) redactions_config: Option, pub(crate) runnable_config: Option, @@ -1168,6 +1170,11 @@ pub struct Grammar { pub(crate) highlight_map: Mutex, } +pub struct HighlightsConfig { + pub query: Query, + pub identifier_capture_indices: Vec, +} + struct IndentConfig { query: Query, indent_capture_ix: u32, @@ -1332,7 +1339,7 @@ impl Language { grammar: ts_language.map(|ts_language| { Arc::new(Grammar { id: GrammarId::new(), - highlights_query: None, + highlights_config: None, brackets_config: None, outline_config: None, text_object_config: None, @@ -1430,7 +1437,29 @@ impl Language { pub fn with_highlights_query(mut self, source: &str) -> Result { let grammar = self.grammar_mut()?; - grammar.highlights_query = Some(Query::new(&grammar.ts_language, source)?); + let query = Query::new(&grammar.ts_language, source)?; + + let mut identifier_capture_indices = Vec::new(); + for name in [ + "variable", + "constant", + "constructor", + "function", + "function.method", + "function.method.call", + "function.special", + "property", + "type", + "type.interface", + ] { + identifier_capture_indices.extend(query.capture_index_for_name(name)); + } + + grammar.highlights_config = Some(HighlightsConfig { + query, + identifier_capture_indices, + }); + Ok(self) } @@ -1856,7 +1885,10 @@ impl Language { let tree = grammar.parse_text(text, None); let captures = SyntaxSnapshot::single_tree_captures(range.clone(), text, &tree, self, |grammar| { - grammar.highlights_query.as_ref() + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) }); let highlight_maps = vec![grammar.highlight_map()]; let mut offset = 0; @@ -1885,10 +1917,10 @@ impl Language { pub fn set_theme(&self, theme: &SyntaxTheme) { if let Some(grammar) = self.grammar.as_ref() - && let Some(highlights_query) = &grammar.highlights_query + && let Some(highlights_config) = &grammar.highlights_config { *grammar.highlight_map.lock() = - HighlightMap::new(highlights_query.capture_names(), theme); + HighlightMap::new(highlights_config.query.capture_names(), theme); } } @@ -2103,8 +2135,9 @@ impl Grammar { pub fn highlight_id_for_name(&self, name: &str) -> Option { let capture_id = self - .highlights_query + .highlights_config .as_ref()? + .query .capture_index_for_name(name)?; Some(self.highlight_map.lock().get(capture_id)) } diff --git a/crates/language/src/outline.rs b/crates/language/src/outline.rs index d96cd90e03142c6498ae17bc63e1787d99e8557a..09c556cf98f58ea26925e1df8bde9d43ec72e6c7 100644 --- a/crates/language/src/outline.rs +++ b/crates/language/src/outline.rs @@ -19,6 +19,7 @@ pub struct OutlineItem { pub text: String, pub highlight_ranges: Vec<(Range, HighlightStyle)>, pub name_ranges: Vec>, + pub signature_range: Option>, pub body_range: Option>, pub annotation_range: Option>, } @@ -35,6 +36,10 @@ impl OutlineItem { text: self.text.clone(), highlight_ranges: self.highlight_ranges.clone(), name_ranges: self.name_ranges.clone(), + signature_range: self + .signature_range + .as_ref() + .map(|r| r.start.to_point(buffer)..r.end.to_point(buffer)), body_range: self .body_range .as_ref() @@ -208,6 +213,7 @@ mod tests { text: "class Foo".to_string(), highlight_ranges: vec![], name_ranges: vec![6..9], + signature_range: None, body_range: None, annotation_range: None, }, @@ -217,6 +223,7 @@ mod tests { text: "private".to_string(), highlight_ranges: vec![], name_ranges: vec![], + signature_range: None, body_range: None, annotation_range: None, }, @@ -241,6 +248,7 @@ mod tests { text: "fn process".to_string(), highlight_ranges: vec![], name_ranges: vec![3..10], + signature_range: None, body_range: None, annotation_range: None, }, @@ -250,6 +258,7 @@ mod tests { text: "struct DataProcessor".to_string(), highlight_ranges: vec![], name_ranges: vec![7..20], + signature_range: None, body_range: None, annotation_range: None, }, diff --git a/crates/language/src/syntax_map/syntax_map_tests.rs b/crates/language/src/syntax_map/syntax_map_tests.rs index 622731b7814ce16bfcc026b6723e80d5ba4dda7a..6b19d651e241ad71229c6c7fc429883a44367304 100644 --- a/crates/language/src/syntax_map/syntax_map_tests.rs +++ b/crates/language/src/syntax_map/syntax_map_tests.rs @@ -1409,12 +1409,15 @@ fn assert_capture_ranges( ) { let mut actual_ranges = Vec::>::new(); let captures = syntax_map.captures(0..buffer.len(), buffer, |grammar| { - grammar.highlights_query.as_ref() + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) }); let queries = captures .grammars() .iter() - .map(|grammar| grammar.highlights_query.as_ref().unwrap()) + .map(|grammar| &grammar.highlights_config.as_ref().unwrap().query) .collect::>(); for capture in captures { let name = &queries[capture.grammar_index].capture_names()[capture.index as usize]; diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index b5bfb870f643452bd5be248c9910d99f16a8101e..8a2a681c26ede21ce948b6667b4aaea589724dcf 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -29,6 +29,7 @@ copilot.workspace = true credentials_provider.workspace = true deepseek = { workspace = true, features = ["schemars"] } editor.workspace = true +fs.workspace = true futures.workspace = true google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true @@ -61,6 +62,7 @@ util.workspace = true vercel = { workspace = true, features = ["schemars"] } workspace-hack.workspace = true x_ai = { workspace = true, features = ["schemars"] } +zed_env_vars.workspace = true [dev-dependencies] editor = { workspace = true, features = ["test-support"] } diff --git a/crates/language_models/src/api_key.rs b/crates/language_models/src/api_key.rs new file mode 100644 index 0000000000000000000000000000000000000000..122234b6ced6d0bf1b7a0d684683c841824ccd2d --- /dev/null +++ b/crates/language_models/src/api_key.rs @@ -0,0 +1,295 @@ +use anyhow::{Result, anyhow}; +use credentials_provider::CredentialsProvider; +use futures::{FutureExt, future}; +use gpui::{AsyncApp, Context, SharedString, Task}; +use language_model::AuthenticateError; +use std::{ + fmt::{Display, Formatter}, + sync::Arc, +}; +use util::ResultExt as _; +use zed_env_vars::EnvVar; + +/// Manages a single API key for a language model provider. API keys either come from environment +/// variables or the system keychain. +/// +/// Keys from the system keychain are associated with a provider URL, and this ensures that they are +/// only used with that URL. +pub struct ApiKeyState { + url: SharedString, + load_status: LoadStatus, + load_task: Option>>, +} + +#[derive(Debug, Clone)] +pub enum LoadStatus { + NotPresent, + Error(String), + Loaded(ApiKey), +} + +#[derive(Debug, Clone)] +pub struct ApiKey { + source: ApiKeySource, + key: Arc, +} + +impl ApiKeyState { + pub fn new(url: SharedString) -> Self { + Self { + url, + load_status: LoadStatus::NotPresent, + load_task: None, + } + } + + pub fn has_key(&self) -> bool { + matches!(self.load_status, LoadStatus::Loaded { .. }) + } + + pub fn is_from_env_var(&self) -> bool { + match &self.load_status { + LoadStatus::Loaded(ApiKey { + source: ApiKeySource::EnvVar { .. }, + .. + }) => true, + _ => false, + } + } + + /// Get the stored API key, verifying that it is associated with the URL. Returns `None` if + /// there is no key or for URL mismatches, and the mismatch case is logged. + /// + /// To avoid URL mismatches, expects that `load_if_needed` or `handle_url_change` has been + /// called with this URL. + pub fn key(&self, url: &str) -> Option> { + let api_key = match &self.load_status { + LoadStatus::Loaded(api_key) => api_key, + _ => return None, + }; + if url == self.url.as_str() { + Some(api_key.key.clone()) + } else if let ApiKeySource::EnvVar(var_name) = &api_key.source { + log::warn!( + "{} is now being used with URL {}, when initially it was used with URL {}", + var_name, + url, + self.url + ); + Some(api_key.key.clone()) + } else { + // bug case because load_if_needed should be called whenever the url may have changed + log::error!( + "bug: Attempted to use API key associated with URL {} instead with URL {}", + self.url, + url + ); + None + } + } + + /// Set or delete the API key in the system keychain. + pub fn store( + &mut self, + url: SharedString, + key: Option, + get_this: impl Fn(&mut Ent) -> &mut Self + 'static, + cx: &Context, + ) -> Task> { + if self.is_from_env_var() { + return Task::ready(Err(anyhow!( + "bug: attempted to store API key in system keychain when API key is from env var", + ))); + } + let credentials_provider = ::global(cx); + cx.spawn(async move |ent, cx| { + if let Some(key) = &key { + credentials_provider + .write_credentials(&url, "Bearer", key.as_bytes(), cx) + .await + .log_err(); + } else { + credentials_provider + .delete_credentials(&url, cx) + .await + .log_err(); + } + ent.update(cx, |ent, cx| { + let this = get_this(ent); + this.url = url; + this.load_status = match &key { + Some(key) => LoadStatus::Loaded(ApiKey { + source: ApiKeySource::SystemKeychain, + key: key.as_str().into(), + }), + None => LoadStatus::NotPresent, + }; + cx.notify(); + }) + }) + } + + /// Reloads the API key if the current API key is associated with a different URL. + /// + /// Note that it is not efficient to use this or `load_if_needed` with multiple URLs + /// interchangeably - URL change should correspond to some user initiated change. + pub fn handle_url_change( + &mut self, + url: SharedString, + env_var: &EnvVar, + get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static, + cx: &mut Context, + ) { + if url != self.url { + if !self.is_from_env_var() { + // loading will continue even though this result task is dropped + let _task = self.load_if_needed(url, env_var, get_this, cx); + } + } + } + + /// If needed, loads the API key associated with the given URL from the system keychain. When a + /// non-empty environment variable is provided, it will be used instead. If called when an API + /// key was already loaded for a different URL, that key will be cleared before loading. + /// + /// Dropping the returned Task does not cancel key loading. + pub fn load_if_needed( + &mut self, + url: SharedString, + env_var: &EnvVar, + get_this: impl Fn(&mut Ent) -> &mut Self + Clone + 'static, + cx: &mut Context, + ) -> Task> { + if let LoadStatus::Loaded { .. } = &self.load_status + && self.url == url + { + return Task::ready(Ok(())); + } + + if let Some(key) = &env_var.value + && !key.is_empty() + { + let api_key = ApiKey::from_env(env_var.name.clone(), key); + self.url = url; + self.load_status = LoadStatus::Loaded(api_key); + self.load_task = None; + cx.notify(); + return Task::ready(Ok(())); + } + + let task = if let Some(load_task) = &self.load_task { + load_task.clone() + } else { + let load_task = Self::load(url.clone(), get_this.clone(), cx).shared(); + self.url = url; + self.load_status = LoadStatus::NotPresent; + self.load_task = Some(load_task.clone()); + cx.notify(); + load_task + }; + + cx.spawn(async move |ent, cx| { + task.await; + ent.update(cx, |ent, _cx| { + get_this(ent).load_status.clone().into_authenticate_result() + }) + .ok(); + Ok(()) + }) + } + + fn load( + url: SharedString, + get_this: impl Fn(&mut Ent) -> &mut Self + 'static, + cx: &Context, + ) -> Task<()> { + let credentials_provider = ::global(cx); + cx.spawn({ + async move |ent, cx| { + let load_status = + ApiKey::load_from_system_keychain_impl(&url, credentials_provider.as_ref(), cx) + .await; + ent.update(cx, |ent, cx| { + let this = get_this(ent); + this.url = url; + this.load_status = load_status; + this.load_task = None; + cx.notify(); + }) + .ok(); + } + }) + } +} + +impl ApiKey { + pub fn key(&self) -> &str { + &self.key + } + + pub fn from_env(env_var_name: SharedString, key: &str) -> Self { + Self { + source: ApiKeySource::EnvVar(env_var_name), + key: key.into(), + } + } + + pub async fn load_from_system_keychain( + url: &str, + credentials_provider: &dyn CredentialsProvider, + cx: &AsyncApp, + ) -> Result { + Self::load_from_system_keychain_impl(url, credentials_provider, cx) + .await + .into_authenticate_result() + } + + async fn load_from_system_keychain_impl( + url: &str, + credentials_provider: &dyn CredentialsProvider, + cx: &AsyncApp, + ) -> LoadStatus { + if url.is_empty() { + return LoadStatus::NotPresent; + } + let read_result = credentials_provider.read_credentials(&url, cx).await; + let api_key = match read_result { + Ok(Some((_, api_key))) => api_key, + Ok(None) => return LoadStatus::NotPresent, + Err(err) => return LoadStatus::Error(err.to_string()), + }; + let key = match str::from_utf8(&api_key) { + Ok(key) => key, + Err(_) => return LoadStatus::Error(format!("API key for URL {url} is not utf8")), + }; + LoadStatus::Loaded(Self { + source: ApiKeySource::SystemKeychain, + key: key.into(), + }) + } +} + +impl LoadStatus { + fn into_authenticate_result(self) -> Result { + match self { + LoadStatus::Loaded(api_key) => Ok(api_key), + LoadStatus::NotPresent => Err(AuthenticateError::CredentialsNotFound), + LoadStatus::Error(err) => Err(AuthenticateError::Other(anyhow!(err))), + } + } +} + +#[derive(Debug, Clone)] +enum ApiKeySource { + EnvVar(SharedString), + SystemKeychain, +} + +impl Display for ApiKeySource { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ApiKeySource::EnvVar(var) => write!(f, "environment variable {}", var), + ApiKeySource::SystemKeychain => write!(f, "system keychain"), + } + } +} diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index 738b72b0c9a6dbb7c9606cc72707b27e66abf09c..61e1a794695310421397469515a43a4d5bf5deb8 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -7,6 +7,7 @@ use gpui::{App, Context, Entity}; use language_model::{LanguageModelProviderId, LanguageModelRegistry}; use provider::deepseek::DeepSeekLanguageModelProvider; +mod api_key; pub mod provider; mod settings; pub mod ui; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index dd122159fda1dbf8f13ebec2a01b37795d18fe75..e1ca862a2fe12e7c9d0e1fffcef2960dac005b36 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -1,18 +1,14 @@ -use crate::AllLanguageModelSettings; +use crate::api_key::ApiKeyState; use crate::ui::InstructionListItem; use anthropic::{ - AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent, - ToolResultPart, Usage, + ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, + ToolResultContent, ToolResultPart, Usage, }; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; -use futures::Stream; -use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, -}; +use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, FontStyle, Task, TextStyle, WhiteSpace}; use http_client::HttpClient; use language_model::{ AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, @@ -25,11 +21,12 @@ use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopRea use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; -use util::ResultExt; +use util::{ResultExt, truncate_and_trailoff}; +use zed_env_vars::{EnvVar, env_var}; pub use settings::AnthropicAvailableModel as AvailableModel; @@ -48,91 +45,52 @@ pub struct AnthropicLanguageModelProvider { state: gpui::Entity, } -const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY"; +const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } impl State { - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .ok(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) - } - - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .ok(); - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let key = AnthropicLanguageModelProvider::api_key(cx); - - cx.spawn(async move |this, cx| { - let key = key.await?; - - this.update(cx, |this, cx| { - this.api_key = Some(key.key); - this.api_key_from_env = key.from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } -} -pub struct ApiKey { - pub key: String, - pub from_env: bool, + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) + } } impl AnthropicLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -148,30 +106,16 @@ impl AnthropicLanguageModelProvider { }) } - pub fn api_key(cx: &mut App) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .anthropic - .api_url - .clone(); - - if let Ok(key) = std::env::var(ANTHROPIC_API_KEY_VAR) { - Task::ready(Ok(ApiKey { - key, - from_env: true, - })) + fn settings(cx: &App) -> &AnthropicSettings { + &crate::AllLanguageModelSettings::get_global(cx).anthropic + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + ANTHROPIC_API_URL.into() } else { - cx.spawn(async move |cx| { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - - Ok(ApiKey { - key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - from_env: false, - }) - }) + SharedString::new(api_url.as_str()) } } } @@ -226,11 +170,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { } // Override with available models from settings - for model in AllLanguageModelSettings::get_global(cx) - .anthropic - .available_models - .iter() - { + for model in &AnthropicLanguageModelProvider::settings(cx).available_models { models.insert( model.name.clone(), anthropic::Model::Custom { @@ -278,7 +218,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -368,11 +309,11 @@ impl AnthropicModel { > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).anthropic; - (state.api_key.clone(), settings.api_url.clone()) + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) }) else { - return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed(); + return future::ready(Err(anyhow!("App state dropped").into())).boxed(); }; let beta_headers = self.model.beta_headers(); @@ -434,7 +375,10 @@ impl LanguageModel for AnthropicModel { } fn api_key(&self, cx: &App) -> Option { - self.state.read(cx).api_key.clone() + self.state.read_with(cx, |state, cx| { + let api_url = AnthropicLanguageModelProvider::api_url(cx); + state.api_key_state.key(&api_url).map(|key| key.to_string()) + }) } fn max_token_count(&self) -> u64 { @@ -935,15 +879,17 @@ impl ConfigurationView { return; } + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -952,11 +898,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -991,7 +937,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -1030,7 +976,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small) .color(Color::Muted), @@ -1050,9 +996,14 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") } else { - "API key configured.".to_string() + let api_url = AnthropicLanguageModelProvider::api_url(cx); + if api_url == ANTHROPIC_API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", truncate_and_trailoff(&api_url, 32)) + } })), ) .child( @@ -1063,7 +1014,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index 292512c78687c72c5237eaa694b44e953fad0b5e..c62a6989501a71e444b07992bff0cbe1a1bbd6d6 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -181,11 +181,21 @@ impl State { self.default_model = models .iter() - .find(|model| model.id == response.default_model) + .find(|model| { + response + .default_model + .as_ref() + .is_some_and(|default_model_id| &model.id == default_model_id) + }) .cloned(); self.default_fast_model = models .iter() - .find(|model| model.id == response.default_fast_model) + .find(|model| { + response + .default_fast_model + .as_ref() + .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id) + }) .cloned(); self.recommended_models = response .recommended_models @@ -507,29 +517,36 @@ where impl From for LanguageModelCompletionError { fn from(error: ApiError) -> Self { - if let Ok(cloud_error) = serde_json::from_str::(&error.body) - && cloud_error.code.starts_with("upstream_http_") - { - let status = if let Some(status) = cloud_error.upstream_status { - status - } else if cloud_error.code.ends_with("_error") { - error.status - } else { - // If there's a status code in the code string (e.g. "upstream_http_429") - // then use that; otherwise, see if the JSON contains a status code. - cloud_error - .code - .strip_prefix("upstream_http_") - .and_then(|code_str| code_str.parse::().ok()) - .and_then(|code| StatusCode::from_u16(code).ok()) - .unwrap_or(error.status) - }; + if let Ok(cloud_error) = serde_json::from_str::(&error.body) { + if cloud_error.code.starts_with("upstream_http_") { + let status = if let Some(status) = cloud_error.upstream_status { + status + } else if cloud_error.code.ends_with("_error") { + error.status + } else { + // If there's a status code in the code string (e.g. "upstream_http_429") + // then use that; otherwise, see if the JSON contains a status code. + cloud_error + .code + .strip_prefix("upstream_http_") + .and_then(|code_str| code_str.parse::().ok()) + .and_then(|code| StatusCode::from_u16(code).ok()) + .unwrap_or(error.status) + }; - return LanguageModelCompletionError::UpstreamProviderError { - message: cloud_error.message, - status, - retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), - }; + return LanguageModelCompletionError::UpstreamProviderError { + message: cloud_error.message, + status, + retry_after: cloud_error.retry_after.map(Duration::from_secs_f64), + }; + } + + return LanguageModelCompletionError::from_http_status( + PROVIDER_NAME, + error.status, + cloud_error.message, + None, + ); } let retry_after = None; diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index cd8d3f91ac9a6ddca5aeb49b20bed05286cb59a6..a8f08a420664b10e4478df0b566382621ffeb760 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -1,12 +1,12 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; -use credentials_provider::CredentialsProvider; +use deepseek::DEEPSEEK_API_URL; use editor::{Editor, EditorElement, EditorStyle}; use futures::Stream; -use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; +use futures::{FutureExt, StreamExt, future, future::BoxFuture, stream::BoxStream}; use gpui::{ - AnyView, AppContext as _, AsyncApp, Entity, FontStyle, Subscription, Task, TextStyle, - WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, + Window, }; use http_client::HttpClient; use language_model::{ @@ -20,16 +20,19 @@ pub use settings::DeepseekAvailableModel as AvailableModel; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use theme::ThemeSettings; use ui::{Icon, IconName, List, prelude::*}; -use util::ResultExt; +use util::{ResultExt, truncate_and_trailoff}; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek"); -const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY"; + +const API_KEY_ENV_VAR_NAME: &str = "DEEPSEEK_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); #[derive(Default)] struct RawToolCall { @@ -49,95 +52,48 @@ pub struct DeepSeekLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .deepseek - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .deepseek - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await?; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = DeepSeekLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .deepseek - .api_url - .clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(DEEPSEEK_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = DeepSeekLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl DeepSeekLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -150,7 +106,20 @@ impl DeepSeekLanguageModelProvider { state: self.state.clone(), http_client: self.http_client.clone(), request_limiter: RateLimiter::new(4), - }) as Arc + }) + } + + fn settings(cx: &App) -> &DeepSeekSettings { + &crate::AllLanguageModelSettings::get_global(cx).deepseek + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + DEEPSEEK_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } } } @@ -189,11 +158,7 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { models.insert("deepseek-chat", deepseek::Model::Chat); models.insert("deepseek-reasoner", deepseek::Model::Reasoner); - for available_model in AllLanguageModelSettings::get_global(cx) - .deepseek - .available_models - .iter() - { + for available_model in &Self::settings(cx).available_models { models.insert( &available_model.name, deepseek::Model::Custom { @@ -230,7 +195,8 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -249,15 +215,20 @@ impl DeepSeekLanguageModel { cx: &AsyncApp, ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).deepseek; - (state.api_key.clone(), settings.api_url.clone()) + + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = DeepSeekLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + return future::ready(Err(anyhow!("App state dropped"))).boxed(); }; let future = self.request_limiter.stream(async move { - let api_key = api_key.context("Missing DeepSeek API Key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request); let response = request.await?; @@ -600,7 +571,7 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } @@ -608,12 +579,10 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn(async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -621,10 +590,12 @@ impl ConfigurationView { .update(cx, |editor, cx| editor.set_text("", window, cx)); let state = self.state.clone(); - cx.spawn(async move |_, cx| state.update(cx, |state, cx| state.reset_api_key(cx))?.await) - .detach_and_log_err(cx); - - cx.notify(); + cx.spawn(async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await + }) + .detach_and_log_err(cx); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -662,7 +633,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -696,8 +667,7 @@ impl Render for ConfigurationView { ) .child( Label::new(format!( - "Or set the {} environment variable.", - DEEPSEEK_API_KEY_VAR + "Or set the {API_KEY_ENV_VAR_NAME} environment variable." )) .size(LabelSize::Small) .color(Color::Muted), @@ -717,9 +687,17 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {}", DEEPSEEK_API_KEY_VAR) + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") } else { - "API key configured".to_string() + let api_url = DeepSeekLanguageModelProvider::api_url(cx); + if api_url == DEEPSEEK_API_URL { + "API key configured".to_string() + } else { + format!( + "API key configured for {}", + truncate_and_trailoff(&api_url, 32) + ) + } })), ) .child( diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index ad87ceac438b922838416f07e761f8a29f4567a2..fafb2258e606d712991a9e8a2febd07735fae997 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -2,13 +2,14 @@ use anyhow::{Context as _, Result, anyhow}; use collections::BTreeMap; use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; +use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture}; use google_ai::{ FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction, ThinkingConfig, UsageMetadata, }; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, + Window, }; use http_client::HttpClient; use language_model::{ @@ -27,19 +28,19 @@ pub use settings::GoogleAvailableModel as AvailableModel; use settings::{Settings, SettingsStore}; use std::pin::Pin; use std::sync::{ - Arc, + Arc, LazyLock, atomic::{self, AtomicU64}, }; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; -use util::ResultExt; +use util::{ResultExt, truncate_and_trailoff}; +use zed_env_vars::EnvVar; -use crate::AllLanguageModelSettings; +use crate::api_key::ApiKey; +use crate::api_key::ApiKeyState; use crate::ui::InstructionListItem; -use super::anthropic::ApiKey; - const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME; @@ -66,101 +67,56 @@ pub struct GoogleLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY"; -const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY"; +const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY"; +const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY"; + +static API_KEY_ENV_VAR: LazyLock = LazyLock::new(|| { + // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY + EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into())) +}); impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = GoogleLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await?; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) { - (api_key, true) - } else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = GoogleLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl GoogleLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -176,30 +132,32 @@ impl GoogleLanguageModelProvider { }) } - pub fn api_key(cx: &mut App) -> Task> { + pub fn api_key_for_gemini_cli(cx: &mut App) -> Task> { + if let Some(key) = API_KEY_ENV_VAR.value.clone() { + return Task::ready(Ok(key)); + } let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .google - .api_url - .clone(); - - if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) { - Task::ready(Ok(ApiKey { - key, - from_env: true, - })) - } else { - cx.spawn(async move |cx| { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) + let api_url = Self::api_url(cx).to_string(); + cx.spawn(async move |cx| { + Ok( + ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx) .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; + .key() + .to_string(), + ) + }) + } - Ok(ApiKey { - key: String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - from_env: false, - }) - }) + fn settings(cx: &App) -> &GoogleSettings { + &crate::AllLanguageModelSettings::get_global(cx).google + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + google_ai::API_URL.into() + } else { + SharedString::new(api_url.as_str()) } } } @@ -244,10 +202,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { } // Override with available models from settings - for model in &AllLanguageModelSettings::get_global(cx) - .google - .available_models - { + for model in &GoogleLanguageModelProvider::settings(cx).available_models { models.insert( model.name.clone(), google_ai::Model::Custom { @@ -292,7 +247,8 @@ impl LanguageModelProvider for GoogleLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -315,11 +271,11 @@ impl GoogleLanguageModel { > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).google; - (state.api_key.clone(), settings.api_url.clone()) + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = GoogleLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + return future::ready(Err(anyhow!("App state dropped"))).boxed(); }; async move { @@ -393,13 +349,16 @@ impl LanguageModel for GoogleLanguageModel { let model_id = self.model.request_id().to_string(); let request = into_google(request, model_id, self.model.mode()); let http_client = self.http_client.clone(); - let api_key = self.state.read(cx).api_key.clone(); - - let settings = &AllLanguageModelSettings::get_global(cx).google; - let api_url = settings.api_url.clone(); + let api_url = GoogleLanguageModelProvider::api_url(cx); + let api_key = self.state.read(cx).api_key_state.key(&api_url); async move { - let api_key = api_key.context("Missing Google API key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + } + .into()); + }; let response = google_ai::count_tokens( http_client.as_ref(), &api_url, @@ -827,20 +786,22 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -849,11 +810,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -888,7 +849,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -925,7 +886,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) @@ -944,9 +905,14 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {GEMINI_API_KEY_VAR} environment variable.") + format!("API key set in {} environment variable", API_KEY_ENV_VAR.name) } else { - "API key configured.".to_string() + let api_url = GoogleLanguageModelProvider::api_url(cx); + if api_url == google_ai::API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", truncate_and_trailoff(&api_url, 32)) + } })), ) .child( @@ -957,7 +923,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset."))) + this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 803049579c7f7fd68bda8da2d67fa618b7d413f5..b3375e528c29052f3e5451e20c40abf0fe8a10ad 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -1,10 +1,10 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::BTreeMap; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; +use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream}; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, + Window, }; use http_client::HttpClient; use language_model::{ @@ -14,23 +14,27 @@ use language_model::{ LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, }; -use mistral::StreamResponse; +use mistral::{MISTRAL_API_URL, StreamResponse}; pub use settings::MistralAvailableModel as AvailableModel; use settings::{Settings, SettingsStore}; use std::collections::HashMap; use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; -use util::ResultExt; +use util::{ResultExt, truncate_and_trailoff}; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral"); +const API_KEY_ENV_VAR_NAME: &str = "MISTRAL_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct MistralSettings { pub api_url: String, @@ -43,96 +47,48 @@ pub struct MistralLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const MISTRAL_API_KEY_VAR: &str = "MISTRAL_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .mistral - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .mistral - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await?; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = MistralLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .mistral - .api_url - .clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(MISTRAL_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = MistralLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl MistralLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -147,6 +103,19 @@ impl MistralLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &MistralSettings { + &crate::AllLanguageModelSettings::get_global(cx).mistral + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + mistral::MISTRAL_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for MistralLanguageModelProvider { @@ -189,10 +158,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider { } // Override with available models from settings - for model in &AllLanguageModelSettings::get_global(cx) - .mistral - .available_models - { + for model in &Self::settings(cx).available_models { models.insert( model.name.clone(), mistral::Model::Custom { @@ -241,7 +207,8 @@ impl LanguageModelProvider for MistralLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -263,15 +230,20 @@ impl MistralLanguageModel { Result>>, > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).mistral; - (state.api_key.clone(), settings.api_url.clone()) + + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = MistralLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + return future::ready(Err(anyhow!("App state dropped"))).boxed(); }; let future = self.request_limiter.stream(async move { - let api_key = api_key.context("Missing Mistral API Key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request); let response = request.await?; @@ -767,20 +739,22 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -789,11 +763,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -828,7 +802,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -865,7 +839,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {MISTRAL_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) @@ -884,9 +858,14 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {MISTRAL_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") } else { - "API key configured.".to_string() + let api_url = MistralLanguageModelProvider::api_url(cx); + if api_url == MISTRAL_API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", truncate_and_trailoff(&api_url, 32)) + } })), ) .child( @@ -897,7 +876,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {MISTRAL_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 2e377071789b6008965e40a45b37795f0c649acf..ff4a8d6c2c3b7b3cb11f6f597e129ca769459bc0 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -1,7 +1,8 @@ use anyhow::{Result, anyhow}; +use fs::Fs; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use futures::{Stream, TryFutureExt, stream}; -use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; +use gpui::{AnyView, App, AsyncApp, Context, Task}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -10,19 +11,23 @@ use language_model::{ LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage, }; +use menu; use ollama::{ - ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, OllamaFunctionCall, + ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, OLLAMA_API_URL, OllamaFunctionCall, OllamaFunctionTool, OllamaToolCall, get_models, show_model, stream_chat_completion, }; pub use settings::OllamaAvailableModel as AvailableModel; -use settings::{Settings, SettingsStore}; +use settings::{Settings, SettingsStore, update_settings_file}; use std::pin::Pin; +use std::sync::LazyLock; use std::sync::atomic::{AtomicU64, Ordering}; use std::{collections::HashMap, sync::Arc}; -use ui::{ButtonLike, Indicator, List, prelude::*}; -use util::ResultExt; +use ui::{ButtonLike, ElevationIndex, List, Tooltip, prelude::*}; +use ui_input::SingleLineInput; +use zed_env_vars::{EnvVar, env_var}; use crate::AllLanguageModelSettings; +use crate::api_key::ApiKeyState; use crate::ui::InstructionListItem; const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download"; @@ -32,6 +37,9 @@ const OLLAMA_SITE: &str = "https://ollama.com/"; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama"); +const API_KEY_ENV_VAR_NAME: &str = "OLLAMA_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Debug, Clone, PartialEq)] pub struct OllamaSettings { pub api_url: String, @@ -44,25 +52,61 @@ pub struct OllamaLanguageModelProvider { } pub struct State { + api_key_state: ApiKeyState, http_client: Arc, - available_models: Vec, + fetched_models: Vec, fetch_model_task: Option>>, - _subscription: Subscription, } impl State { fn is_authenticated(&self) -> bool { - !self.available_models.is_empty() + !self.fetched_models.is_empty() + } + + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = OllamaLanguageModelProvider::api_url(cx); + let task = self + .api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx); + + self.fetched_models.clear(); + cx.spawn(async move |this, cx| { + let result = task.await; + this.update(cx, |this, cx| this.restart_fetch_models_task(cx)) + .ok(); + result + }) + } + + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = OllamaLanguageModelProvider::api_url(cx); + let task = self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); + + // Always try to fetch models - if no API key is needed (local Ollama), it will work + // If API key is needed and provided, it will work + // If API key is needed and not provided, it will fail gracefully + cx.spawn(async move |this, cx| { + let result = task.await; + this.update(cx, |this, cx| this.restart_fetch_models_task(cx)) + .ok(); + result + }) } fn fetch_models(&mut self, cx: &mut Context) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).ollama; let http_client = Arc::clone(&self.http_client); - let api_url = settings.api_url.clone(); + let api_url = OllamaLanguageModelProvider::api_url(cx); + let api_key = self.api_key_state.key(&api_url); // As a proxy for the server being "authenticated", we'll check if its up by fetching the models cx.spawn(async move |this, cx| { - let models = get_models(http_client.as_ref(), &api_url, None).await?; + let models = + get_models(http_client.as_ref(), &api_url, api_key.as_deref(), None).await?; let tasks = models .into_iter() @@ -73,9 +117,12 @@ impl State { .map(|model| { let http_client = Arc::clone(&http_client); let api_url = api_url.clone(); + let api_key = api_key.clone(); async move { let name = model.name.as_str(); - let capabilities = show_model(http_client.as_ref(), &api_url, name).await?; + let capabilities = + show_model(http_client.as_ref(), &api_url, api_key.as_deref(), name) + .await?; let ollama_model = ollama::Model::new( name, None, @@ -100,7 +147,7 @@ impl State { ollama_models.sort_by(|a, b| a.name.cmp(&b.name)); this.update(cx, |this, cx| { - this.available_models = ollama_models; + this.fetched_models = ollama_models; cx.notify(); }) }) @@ -110,15 +157,6 @@ impl State { let task = self.fetch_models(cx); self.fetch_model_task.replace(task); } - - fn authenticate(&mut self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let fetch_models_task = self.fetch_models(cx); - cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?)) - } } impl OllamaLanguageModelProvider { @@ -126,30 +164,47 @@ impl OllamaLanguageModelProvider { let this = Self { http_client: http_client.clone(), state: cx.new(|cx| { - let subscription = cx.observe_global::({ - let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone(); + cx.observe_global::({ + let mut last_settings = OllamaLanguageModelProvider::settings(cx).clone(); move |this: &mut State, cx| { - let new_settings = &AllLanguageModelSettings::get_global(cx).ollama; - if &settings != new_settings { - settings = new_settings.clone(); - this.restart_fetch_models_task(cx); + let current_settings = OllamaLanguageModelProvider::settings(cx); + let settings_changed = current_settings != &last_settings; + if settings_changed { + let url_changed = last_settings.api_url != current_settings.api_url; + last_settings = current_settings.clone(); + if url_changed { + this.fetched_models.clear(); + this.authenticate(cx).detach(); + } cx.notify(); } } - }); + }) + .detach(); State { http_client, - available_models: Default::default(), + fetched_models: Default::default(), fetch_model_task: None, - _subscription: subscription, + api_key_state: ApiKeyState::new(Self::api_url(cx)), } }), }; - this.state - .update(cx, |state, cx| state.restart_fetch_models_task(cx)); this } + + fn settings(cx: &App) -> &OllamaSettings { + &AllLanguageModelSettings::get_global(cx).ollama + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + OLLAMA_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for OllamaLanguageModelProvider { @@ -189,16 +244,12 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { let mut models: HashMap = HashMap::new(); // Add models from the Ollama API - for model in self.state.read(cx).available_models.iter() { + for model in self.state.read(cx).fetched_models.iter() { models.insert(model.name.clone(), model.clone()); } // Override with available models from settings - for model in AllLanguageModelSettings::get_global(cx) - .ollama - .available_models - .iter() - { + for model in &OllamaLanguageModelProvider::settings(cx).available_models { models.insert( model.name.clone(), ollama::Model { @@ -221,6 +272,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { model, http_client: self.http_client.clone(), request_limiter: RateLimiter::new(4), + state: self.state.clone(), }) as Arc }) .collect::>(); @@ -248,7 +300,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.fetch_models(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -257,6 +310,7 @@ pub struct OllamaLanguageModel { model: ollama::Model, http_client: Arc, request_limiter: RateLimiter, + state: gpui::Entity, } impl OllamaLanguageModel { @@ -435,15 +489,17 @@ impl LanguageModel for OllamaLanguageModel { let request = self.to_ollama_request(request); let http_client = self.http_client.clone(); - let Ok(api_url) = cx.update(|cx| { - let settings = &AllLanguageModelSettings::get_global(cx).ollama; - settings.api_url.clone() + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = OllamaLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) }) else { return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed(); }; let future = self.request_limiter.stream(async move { - let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?; + let stream = + stream_chat_completion(http_client.as_ref(), &api_url, api_key.as_deref(), request) + .await?; let stream = map_to_language_model_completion_events(stream); Ok(stream) }); @@ -555,138 +611,307 @@ fn map_to_language_model_completion_events( } struct ConfigurationView { + api_key_editor: gpui::Entity, + api_url_editor: gpui::Entity, state: gpui::Entity, - loading_models_task: Option>, } impl ConfigurationView { pub fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { - let loading_models_task = Some(cx.spawn_in(window, { - let state = state.clone(); - async move |this, cx| { - if let Some(task) = state - .update(cx, |state, cx| state.authenticate(cx)) - .log_err() - { - task.await.log_err(); - } - this.update(cx, |this, cx| { - this.loading_models_task = None; - cx.notify(); - }) - .log_err(); - } - })); + let api_key_editor = + cx.new(|cx| SingleLineInput::new(window, cx, "63e02e...").label("API key")); + + let api_url_editor = cx.new(|cx| { + let input = SingleLineInput::new(window, cx, OLLAMA_API_URL).label("API URL"); + input.set_text(OllamaLanguageModelProvider::api_url(cx), window, cx); + input + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); Self { + api_key_editor, + api_url_editor, state, - loading_models_task, } } fn retry_connection(&self, cx: &mut App) { self.state - .update(cx, |state, cx| state.fetch_models(cx)) - .detach_and_log_err(cx); + .update(cx, |state, cx| state.restart_fetch_models_task(cx)); } -} -impl Render for ConfigurationView { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let is_authenticated = self.state.read(cx).is_authenticated(); + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { + return; + } - let ollama_intro = - "Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama."; + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |input, cx| input.set_text("", window, cx)); - if self.loading_models_task.is_some() { - div().child(Label::new("Loading models...")).into_any() - } else { + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? + .await + }) + .detach_and_log_err(cx); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor + .update(cx, |input, cx| input.set_text("", window, cx)); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn save_api_url(&mut self, cx: &mut Context) { + let api_url = self.api_url_editor.read(cx).text(cx).trim().to_string(); + let current_url = OllamaLanguageModelProvider::api_url(cx); + if !api_url.is_empty() && &api_url != ¤t_url { + let fs = ::global(cx); + update_settings_file(fs, cx, move |settings, _| { + settings + .language_models + .get_or_insert_default() + .ollama + .get_or_insert_default() + .api_url = Some(api_url); + }); + } + } + + fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context) { + self.api_url_editor + .update(cx, |input, cx| input.set_text("", window, cx)); + let fs = ::global(cx); + update_settings_file(fs, cx, |settings, _cx| { + if let Some(settings) = settings + .language_models + .as_mut() + .and_then(|models| models.ollama.as_mut()) + { + settings.api_url = Some(OLLAMA_API_URL.into()); + } + }); + cx.notify(); + } + + fn render_instructions() -> Div { + v_flex() + .gap_2() + .child(Label::new( + "Run LLMs locally on your machine with Ollama, or connect to an Ollama server. \ + Can provide access to Llama, Mistral, Gemma, and hundreds of other models.", + )) + .child(Label::new("To use local Ollama:")) + .child( + List::new() + .child(InstructionListItem::new( + "Download and install Ollama from", + Some("ollama.com"), + Some("https://ollama.com/download"), + )) + .child(InstructionListItem::text_only( + "Start Ollama and download a model: `ollama run gpt-oss:20b`", + )) + .child(InstructionListItem::text_only( + "Click 'Connect' below to start using Ollama in Zed", + )), + ) + .child(Label::new( + "Alternatively, you can connect to an Ollama server by specifying its \ + URL and API key (may not be required):", + )) + } + + fn render_api_key_editor(&self, cx: &Context) -> Div { + let state = self.state.read(cx); + let env_var_set = state.api_key_state.is_from_env_var(); + + if !state.api_key_state.has_key() { v_flex() - .gap_2() + .on_action(cx.listener(Self::save_api_key)) + .child(self.api_key_editor.clone()) + .child( + Label::new( + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed.") + ) + .size(LabelSize::Small) + .color(Color::Muted), + ) + } else { + h_flex() + .p_3() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().elevated_surface_background) .child( - v_flex().gap_1().child(Label::new(ollama_intro)).child( - List::new() - .child(InstructionListItem::text_only("Ollama must be running with at least one model installed to use it in the assistant.")) - .child(InstructionListItem::text_only( - "Once installed, try `ollama run llama3.2`", - )), - ), + h_flex() + .gap_2() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child( + Label::new( + if env_var_set { + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") + } else { + "API key configured".to_string() + } + ) + ) ) + .child( + Button::new("reset-api-key", "Reset API Key") + .label_size(LabelSize::Small) + .icon(IconName::Undo) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .layer(ElevationIndex::ModalSurface) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) + }) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), + ) + } + } + + fn render_api_url_editor(&self, cx: &Context) -> Div { + let api_url = OllamaLanguageModelProvider::api_url(cx); + let custom_api_url_set = api_url != OLLAMA_API_URL; + + if custom_api_url_set { + h_flex() + .p_3() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().elevated_surface_background) .child( h_flex() - .w_full() - .justify_between() .gap_2() - .child( - h_flex() - .w_full() - .gap_2() - .map(|this| { - if is_authenticated { - this.child( - Button::new("ollama-site", "Ollama") - .style(ButtonStyle::Subtle) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE)) - .into_any_element(), - ) - } else { - this.child( - Button::new( - "download_ollama_button", - "Download Ollama", - ) + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(v_flex().gap_1().child(Label::new(api_url))), + ) + .child( + Button::new("reset-api-url", "Reset API URL") + .label_size(LabelSize::Small) + .icon(IconName::Undo) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .layer(ElevationIndex::ModalSurface) + .on_click( + cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)), + ), + ) + } else { + v_flex() + .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| { + this.save_api_url(cx); + cx.notify(); + })) + .gap_2() + .child(self.api_url_editor.clone()) + } + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + let is_authenticated = self.state.read(cx).is_authenticated(); + + v_flex() + .gap_2() + .child(Self::render_instructions()) + .child(self.render_api_url_editor(cx)) + .child(self.render_api_key_editor(cx)) + .child( + h_flex() + .w_full() + .justify_between() + .gap_2() + .child( + h_flex() + .w_full() + .gap_2() + .map(|this| { + if is_authenticated { + this.child( + Button::new("ollama-site", "Ollama") .style(ButtonStyle::Subtle) .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Small) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE)) + .into_any_element(), + ) + } else { + this.child( + Button::new("download_ollama_button", "Download Ollama") + .style(ButtonStyle::Subtle) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) .icon_color(Color::Muted) .on_click(move |_, _, cx| { cx.open_url(OLLAMA_DOWNLOAD_URL) }) .into_any_element(), - ) - } - }) - .child( - Button::new("view-models", "View All Models") - .style(ButtonStyle::Subtle) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)), - ), - ) - .map(|this| { - if is_authenticated { - this.child( - ButtonLike::new("connected") - .disabled(true) - .cursor_style(gpui::CursorStyle::Arrow) - .child( - h_flex() - .gap_2() - .child(Indicator::dot().color(Color::Success)) - .child(Label::new("Connected")) - .into_any_element(), - ), - ) - } else { - this.child( - Button::new("retry_ollama_models", "Connect") - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon(IconName::PlayFilled) - .on_click(cx.listener(move |this, _, _, cx| { + ) + } + }) + .child( + Button::new("view-models", "View All Models") + .style(ButtonStyle::Subtle) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)), + ), + ) + .map(|this| { + if is_authenticated { + this.child( + ButtonLike::new("connected") + .disabled(true) + .cursor_style(gpui::CursorStyle::Arrow) + .child( + h_flex() + .gap_2() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new("Connected")) + .into_any_element(), + ), + ) + } else { + this.child( + Button::new("retry_ollama_models", "Connect") + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon(IconName::PlayOutlined) + .on_click( + cx.listener(move |this, _, _, cx| { this.retry_connection(cx) - })), - ) - } - }) - ) - .into_any() - } + }), + ), + ) + } + }), + ) } } diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 4c0de86a945de01272aef36598cf12513fa2b70d..ade2e47ca39ebfdc48806380b541ed7d97d4d3da 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1,10 +1,8 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::{BTreeMap, HashMap}; -use credentials_provider::CredentialsProvider; - use futures::Stream; -use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use futures::{FutureExt, StreamExt, future, future::BoxFuture}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -14,23 +12,27 @@ use language_model::{ RateLimiter, Role, StopReason, TokenUsage, }; use menu; -use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion}; -use settings::OpenAiAvailableModel as AvailableModel; -use settings::{Settings, SettingsStore}; +use open_ai::{ + ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, stream_completion, +}; +use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr as _; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; - use ui::{ElevationIndex, List, Tooltip, prelude::*}; use ui_input::SingleLineInput; -use util::ResultExt; +use util::{ResultExt, truncate_and_trailoff}; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME; +const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenAiSettings { pub api_url: String, @@ -43,132 +45,48 @@ pub struct OpenAiLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - last_api_url: String, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn get_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - self.get_api_key(cx) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl OpenAiLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let initial_api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - last_api_url: initial_api_url.clone(), - _subscription: cx.observe_global::(|this: &mut State, cx| { - let current_api_url = AllLanguageModelSettings::get_global(cx) - .openai - .api_url - .clone(); - - if this.last_api_url != current_api_url { - this.last_api_url = current_api_url; - if !this.api_key_from_env { - this.api_key = None; - let spawn_task = cx.spawn(async move |handle, cx| { - if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) { - if let Err(_) = task.await { - handle - .update(cx, |this, _| { - this.api_key = None; - this.api_key_from_env = false; - }) - .ok(); - } - } - }); - spawn_task.detach(); - } - } + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -183,6 +101,19 @@ impl OpenAiLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &OpenAiSettings { + &crate::AllLanguageModelSettings::get_global(cx).openai + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + open_ai::OPEN_AI_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for OpenAiLanguageModelProvider { @@ -225,10 +156,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { } // Override with available models from settings - for model in &AllLanguageModelSettings::get_global(cx) - .openai - .available_models - { + for model in &OpenAiLanguageModelProvider::settings(cx).available_models { models.insert( model.name.clone(), open_ai::Model::Custom { @@ -267,7 +195,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -287,11 +216,12 @@ impl OpenAiLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).openai; - (state.api_key.clone(), settings.api_url.clone()) + + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = OpenAiLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + return future::ready(Err(anyhow!("App state dropped"))).boxed(); }; let future = self.request_limiter.stream(async move { @@ -791,45 +721,35 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { - self.api_key_editor.update(cx, |input, cx| { - input.editor.update(cx, |editor, cx| { - editor.set_text("", window, cx); - }); - }); + self.api_key_editor + .update(cx, |input, cx| input.set_text("", window, cx)); let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn should_render_editor(&self, cx: &mut Context) -> bool { @@ -839,7 +759,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); let api_key_section = if self.should_render_editor(cx) { v_flex() @@ -861,10 +781,11 @@ impl Render for ConfigurationView { ) .child(self.api_key_editor.clone()) .child( - Label::new( - format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."), - ) - .size(LabelSize::Small).color(Color::Muted), + Label::new(format!( + "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." + )) + .size(LabelSize::Small) + .color(Color::Muted), ) .child( Label::new( @@ -887,9 +808,14 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {OPENAI_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") } else { - "API key configured.".to_string() + let api_url = OpenAiLanguageModelProvider::api_url(cx); + if api_url == OPEN_AI_API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", truncate_and_trailoff(&api_url, 32)) + } })), ) .child( @@ -900,7 +826,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .layer(ElevationIndex::ModalSurface) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 9407376f301c04c10b198bbb3155a76dc1537abd..788a412a8232d43e92ec9195132efe21cf73bf00 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -1,9 +1,7 @@ -use anyhow::{Context as _, Result, anyhow}; -use credentials_provider::CredentialsProvider; - +use anyhow::{Result, anyhow}; use convert_case::{Case, Casing}; -use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use futures::{FutureExt, StreamExt, future, future::BoxFuture}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -15,12 +13,12 @@ use menu; use open_ai::{ResponseStreamEvent, stream_completion}; use settings::{Settings, SettingsStore}; use std::sync::Arc; - use ui::{ElevationIndex, Tooltip, prelude::*}; use ui_input::SingleLineInput; -use util::ResultExt; +use util::{ResultExt, truncate_and_trailoff}; +use zed_env_vars::EnvVar; -use crate::AllLanguageModelSettings; +use crate::api_key::ApiKeyState; use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai}; pub use settings::OpenAiCompatibleAvailableModel as AvailableModel; pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities; @@ -40,124 +38,67 @@ pub struct OpenAiCompatibleLanguageModelProvider { pub struct State { id: Arc, - env_var_name: Arc, - api_key: Option, - api_key_from_env: bool, + api_key_env_var: EnvVar, + api_key_state: ApiKeyState, settings: OpenAiCompatibleSettings, - _subscription: Subscription, } impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() - } - - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = self.settings.api_url.clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + self.api_key_state.has_key() } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = self.settings.api_url.clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn get_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let env_var_name = self.env_var_name.clone(); - let api_url = self.settings.api_url.clone(); - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = SharedString::new(self.settings.api_url.as_str()); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - self.get_api_key(cx) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = SharedString::new(self.settings.api_url.clone()); + self.api_key_state.load_if_needed( + api_url, + &self.api_key_env_var, + |this| &mut this.api_key_state, + cx, + ) } } impl OpenAiCompatibleLanguageModelProvider { pub fn new(id: Arc, http_client: Arc, cx: &mut App) -> Self { fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> { - AllLanguageModelSettings::get_global(cx) + crate::AllLanguageModelSettings::get_global(cx) .openai_compatible .get(id) } - let state = cx.new(|cx| State { - id: id.clone(), - env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(), - settings: resolve_settings(&id, cx).cloned().unwrap_or_default(), - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|this: &mut State, cx| { + let api_key_env_var_name = format!("{}_API_KEY", id).to_case(Case::UpperSnake).into(); + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { let Some(settings) = resolve_settings(&this.id, cx).cloned() else { return; }; if &this.settings != &settings { - if settings.api_url != this.settings.api_url && !this.api_key_from_env { - let spawn_task = cx.spawn(async move |handle, cx| { - if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) { - if let Err(_) = task.await { - handle - .update(cx, |this, _| { - this.api_key = None; - this.api_key_from_env = false; - }) - .ok(); - } - } - }); - spawn_task.detach(); - } - + let api_url = SharedString::new(settings.api_url.as_str()); + this.api_key_state.handle_url_change( + api_url, + &this.api_key_env_var, + |this| &mut this.api_key_state, + cx, + ); this.settings = settings; cx.notify(); } - }), + }) + .detach(); + let settings = resolve_settings(&id, cx).cloned().unwrap_or_default(); + State { + id: id.clone(), + api_key_env_var: EnvVar::new(api_key_env_var_name), + api_key_state: ApiKeyState::new(SharedString::new(settings.api_url.as_str())), + settings, + } }); Self { @@ -244,7 +185,8 @@ impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -266,10 +208,15 @@ impl OpenAiCompatibleLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| { - (state.api_key.clone(), state.settings.api_url.clone()) + + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, _cx| { + let api_url = &state.settings.api_url; + ( + state.api_key_state.key(api_url), + state.settings.api_url.clone(), + ) }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + return future::ready(Err(anyhow!("App state dropped"))).boxed(); }; let provider = self.provider_name.clone(); @@ -439,56 +386,47 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |input, cx| input.set_text("", window, cx)); + let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { - self.api_key_editor.update(cx, |input, cx| { - input.editor.update(cx, |editor, cx| { - editor.set_text("", window, cx); - }); - }); + self.api_key_editor + .update(cx, |input, cx| input.set_text("", window, cx)); let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } - fn should_render_editor(&self, cx: &mut Context) -> bool { + fn should_render_editor(&self, cx: &Context) -> bool { !self.state.read(cx).is_authenticated() } } impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; - let env_var_name = self.state.read(cx).env_var_name.clone(); + let state = self.state.read(cx); + let env_var_set = state.api_key_state.is_from_env_var(); + let env_var_name = &state.api_key_env_var.name; let api_key_section = if self.should_render_editor(cx) { v_flex() @@ -520,9 +458,9 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {env_var_name} environment variable.") + format!("API key set in {env_var_name} environment variable") } else { - "API key configured.".to_string() + format!("API key configured for {}", truncate_and_trailoff(&state.settings.api_url, 32)) })), ) .child( diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index ff51c4a5f8eb009d3a7043ddc99357ab24e4b661..a69041737f2850e8209d923de617716ad258b8ad 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -1,10 +1,9 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::HashMap; -use credentials_provider::CredentialsProvider; use editor::{Editor, EditorElement, EditorStyle}; -use futures::{FutureExt, Stream, StreamExt, future::BoxFuture}; +use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture}; use gpui::{ - AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace, + AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace, }; use http_client::HttpClient; use language_model::{ @@ -15,21 +14,25 @@ use language_model::{ LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage, }; use open_router::{ - Model, ModelMode as OpenRouterModelMode, ResponseStreamEvent, list_models, stream_completion, + Model, ModelMode as OpenRouterModelMode, OPEN_ROUTER_API_URL, ResponseStreamEvent, list_models, }; use settings::{OpenRouterAvailableModel as AvailableModel, Settings, SettingsStore}; use std::pin::Pin; use std::str::FromStr as _; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; -use util::ResultExt; +use util::{ResultExt, truncate_and_trailoff}; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter"); +const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct OpenRouterSettings { pub api_url: String, @@ -42,93 +45,37 @@ pub struct OpenRouterLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, + api_key_state: ApiKeyState, http_client: Arc, available_models: Vec, fetch_models_task: Option>>, - settings: OpenRouterSettings, - _subscription: Subscription, } -const OPENROUTER_API_KEY_VAR: &str = "OPENROUTER_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .open_router - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .open_router - .api_url - .clone(); - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.restart_fetch_models_task(cx); - cx.notify(); - }) - }) - } - - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let api_url = AllLanguageModelSettings::get_global(cx) - .open_router - .api_url - .clone(); + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + let task = self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - this.restart_fetch_models_task(cx); - cx.notify(); - })?; - - Ok(()) + let result = task.await; + this.update(cx, |this, cx| this.restart_fetch_models_task(cx)) + .ok(); + result }) } @@ -136,10 +83,9 @@ impl State { &mut self, cx: &mut Context, ) -> Task> { - let settings = &AllLanguageModelSettings::get_global(cx).open_router; let http_client = self.http_client.clone(); - let api_url = settings.api_url.clone(); - let Some(api_key) = self.api_key.clone() else { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + let Some(api_key) = self.api_key_state.key(&api_url) else { return Task::ready(Err(LanguageModelCompletionError::NoApiKey { provider: PROVIDER_NAME, })); @@ -168,33 +114,52 @@ impl State { if self.is_authenticated() { let task = self.fetch_models(cx); self.fetch_models_task.replace(task); + } else { + self.available_models = Vec::new(); } } } impl OpenRouterLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - http_client: http_client.clone(), - available_models: Vec::new(), - fetch_models_task: None, - settings: OpenRouterSettings::default(), - _subscription: cx.observe_global::(|this: &mut State, cx| { - let current_settings = &AllLanguageModelSettings::get_global(cx).open_router; - let settings_changed = current_settings != &this.settings; - if settings_changed { - this.settings = current_settings.clone(); - this.restart_fetch_models_task(cx); + let state = cx.new(|cx| { + cx.observe_global::({ + let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone(); + move |this: &mut State, cx| { + let current_settings = OpenRouterLanguageModelProvider::settings(cx); + let settings_changed = current_settings != &last_settings; + if settings_changed { + last_settings = current_settings.clone(); + this.authenticate(cx).detach(); + cx.notify(); + } } - cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + http_client: http_client.clone(), + available_models: Vec::new(), + fetch_models_task: None, + } }); Self { http_client, state } } + fn settings(cx: &App) -> &OpenRouterSettings { + &crate::AllLanguageModelSettings::get_global(cx).open_router + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + OPEN_ROUTER_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } + fn create_language_model(&self, model: open_router::Model) -> Arc { Arc::new(OpenRouterLanguageModel { id: LanguageModelId::from(model.id().to_string()), @@ -239,10 +204,7 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { let mut models_from_api = self.state.read(cx).available_models.clone(); let mut settings_models = Vec::new(); - for model in &AllLanguageModelSettings::get_global(cx) - .open_router - .available_models - { + for model in &Self::settings(cx).available_models { settings_models.push(open_router::Model { name: model.name.clone(), display_name: model.display_name.clone(), @@ -290,7 +252,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -318,14 +281,11 @@ impl OpenRouterLanguageModel { >, > { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).open_router; - (state.api_key.clone(), settings.api_url.clone()) + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) }) else { - return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!( - "App state dropped" - )))) - .boxed(); + return future::ready(Err(anyhow!("App state dropped").into())).boxed(); }; async move { @@ -334,7 +294,8 @@ impl OpenRouterLanguageModel { provider: PROVIDER_NAME, }); }; - let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let request = + open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request); request.await.map_err(Into::into) } .boxed() @@ -782,20 +743,22 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); if api_key.is_empty() { return; } + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { @@ -804,11 +767,11 @@ impl ConfigurationView { let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn render_api_key_editor(&self, cx: &mut Context) -> impl IntoElement { @@ -843,7 +806,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); if self.load_credentials_task.is_some() { div().child(Label::new("Loading credentials...")).into_any() @@ -880,7 +843,7 @@ impl Render for ConfigurationView { ) .child( Label::new( - format!("You can also assign the {OPENROUTER_API_KEY_VAR} environment variable and restart Zed."), + format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), ) .size(LabelSize::Small).color(Color::Muted), ) @@ -899,9 +862,14 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {OPENROUTER_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") } else { - "API key configured.".to_string() + let api_url = OpenRouterLanguageModelProvider::api_url(cx); + if api_url == OPEN_ROUTER_API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", truncate_and_trailoff(&api_url, 32)) + } })), ) .child( @@ -912,7 +880,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .disabled(env_var_set) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENROUTER_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 2daa74306c92d758c5d293dc5f49cd665de586e2..86f3dc6a1672e19716afefcdaf32ca71fc43ae88 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -1,8 +1,7 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::BTreeMap; -use credentials_provider::CredentialsProvider; -use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use futures::{FutureExt, StreamExt, future, future::BoxFuture}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -10,24 +9,26 @@ use language_model::{ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, RateLimiter, Role, }; -use menu; use open_ai::ResponseStreamEvent; +pub use settings::VercelAvailableModel as AvailableModel; use settings::{Settings, SettingsStore}; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; -use vercel::Model; - -pub use settings::VercelAvailableModel as AvailableModel; use ui::{ElevationIndex, List, Tooltip, prelude::*}; use ui_input::SingleLineInput; -use util::ResultExt; +use util::{ResultExt, truncate_and_trailoff}; +use vercel::{Model, VERCEL_API_URL}; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel"); // todo!() -> Remove default implementation +const API_KEY_ENV_VAR_NAME: &str = "VERCEL_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Clone, Debug, PartialEq)] pub struct VercelSettings { pub api_url: String, @@ -40,103 +41,48 @@ pub struct VercelLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const VERCEL_API_KEY_VAR: &str = "VERCEL_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = VercelLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(VERCEL_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = VercelLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl VercelLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -151,6 +97,19 @@ impl VercelLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &VercelSettings { + &crate::AllLanguageModelSettings::get_global(cx).vercel + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + VERCEL_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for VercelLanguageModelProvider { @@ -191,10 +150,7 @@ impl LanguageModelProvider for VercelLanguageModelProvider { } } - for model in &AllLanguageModelSettings::get_global(cx) - .vercel - .available_models - { + for model in &Self::settings(cx).available_models { models.insert( model.name.clone(), vercel::Model::Custom { @@ -232,7 +188,8 @@ impl LanguageModelProvider for VercelLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -252,16 +209,12 @@ impl VercelLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).vercel; - let api_url = if settings.api_url.is_empty() { - vercel::VERCEL_API_URL.to_string() - } else { - settings.api_url.clone() - }; - (state.api_key.clone(), api_url) + + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = VercelLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + return future::ready(Err(anyhow!("App state dropped"))).boxed(); }; let future = self.request_limiter.stream(async move { @@ -457,45 +410,35 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { - self.api_key_editor.update(cx, |input, cx| { - input.editor.update(cx, |editor, cx| { - editor.set_text("", window, cx); - }); - }); + self.api_key_editor + .update(cx, |input, cx| input.set_text("", window, cx)); let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn should_render_editor(&self, cx: &mut Context) -> bool { @@ -505,7 +448,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); let api_key_section = if self.should_render_editor(cx) { v_flex() @@ -525,7 +468,7 @@ impl Render for ConfigurationView { .child(self.api_key_editor.clone()) .child( Label::new(format!( - "You can also assign the {VERCEL_API_KEY_VAR} environment variable and restart Zed." + "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." )) .size(LabelSize::Small) .color(Color::Muted), @@ -550,9 +493,14 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {VERCEL_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") } else { - "API key configured.".to_string() + let api_url = VercelLanguageModelProvider::api_url(cx); + if api_url == VERCEL_API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", truncate_and_trailoff(&api_url, 32)) + } })), ) .child( @@ -563,7 +511,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .layer(ElevationIndex::ModalSurface) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {VERCEL_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 748d77d53e4f1d708035ac8507ff32d6c52081a7..d75c8ce78c6c9ee86bf838739047470999a85bfe 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -1,8 +1,7 @@ -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Result, anyhow}; use collections::BTreeMap; -use credentials_provider::CredentialsProvider; -use futures::{FutureExt, StreamExt, future::BoxFuture}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window}; +use futures::{FutureExt, StreamExt, future, future::BoxFuture}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, Task, Window}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -10,22 +9,24 @@ use language_model::{ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter, Role, }; -use menu; use open_ai::ResponseStreamEvent; +pub use settings::XaiAvailableModel as AvailableModel; use settings::{Settings, SettingsStore}; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use strum::IntoEnumIterator; -use x_ai::Model; - -pub use settings::XaiAvailableModel as AvailableModel; use ui::{ElevationIndex, List, Tooltip, prelude::*}; use ui_input::SingleLineInput; -use util::ResultExt; +use util::{ResultExt, truncate_and_trailoff}; +use x_ai::{Model, XAI_API_URL}; +use zed_env_vars::{EnvVar, env_var}; -use crate::{AllLanguageModelSettings, ui::InstructionListItem}; +use crate::{api_key::ApiKeyState, ui::InstructionListItem}; -const PROVIDER_ID: &str = "x_ai"; -const PROVIDER_NAME: &str = "xAI"; +const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai"); +const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI"); + +const API_KEY_ENV_VAR_NAME: &str = "XAI_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); #[derive(Default, Clone, Debug, PartialEq)] pub struct XAiSettings { @@ -39,103 +40,48 @@ pub struct XAiLanguageModelProvider { } pub struct State { - api_key: Option, - api_key_from_env: bool, - _subscription: Subscription, + api_key_state: ApiKeyState, } -const XAI_API_KEY_VAR: &str = "XAI_API_KEY"; - impl State { fn is_authenticated(&self) -> bool { - self.api_key.is_some() + self.api_key_state.has_key() } - fn reset_api_key(&self, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .delete_credentials(&api_url, cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = None; - this.api_key_from_env = false; - cx.notify(); - }) - }) + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = XAiLanguageModelProvider::api_url(cx); + self.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx) } - fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - credentials_provider - .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) - .await - .log_err(); - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - cx.notify(); - }) - }) - } - - fn authenticate(&self, cx: &mut Context) -> Task> { - if self.is_authenticated() { - return Task::ready(Ok(())); - } - - let credentials_provider = ::global(cx); - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - cx.spawn(async move |this, cx| { - let (api_key, from_env) = if let Ok(api_key) = std::env::var(XAI_API_KEY_VAR) { - (api_key, true) - } else { - let (_, api_key) = credentials_provider - .read_credentials(&api_url, cx) - .await? - .ok_or(AuthenticateError::CredentialsNotFound)?; - ( - String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?, - false, - ) - }; - this.update(cx, |this, cx| { - this.api_key = Some(api_key); - this.api_key_from_env = from_env; - cx.notify(); - })?; - - Ok(()) - }) + fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = XAiLanguageModelProvider::api_url(cx); + self.api_key_state.load_if_needed( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ) } } impl XAiLanguageModelProvider { pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| State { - api_key: None, - api_key_from_env: false, - _subscription: cx.observe_global::(|_this: &mut State, cx| { + let state = cx.new(|cx| { + cx.observe_global::(|this: &mut State, cx| { + let api_url = Self::api_url(cx); + this.api_key_state.handle_url_change( + api_url, + &API_KEY_ENV_VAR, + |this| &mut this.api_key_state, + cx, + ); cx.notify(); - }), + }) + .detach(); + State { + api_key_state: ApiKeyState::new(Self::api_url(cx)), + } }); Self { http_client, state } @@ -150,6 +96,19 @@ impl XAiLanguageModelProvider { request_limiter: RateLimiter::new(4), }) } + + fn settings(cx: &App) -> &XAiSettings { + &crate::AllLanguageModelSettings::get_global(cx).x_ai + } + + fn api_url(cx: &App) -> SharedString { + let api_url = &Self::settings(cx).api_url; + if api_url.is_empty() { + XAI_API_URL.into() + } else { + SharedString::new(api_url.as_str()) + } + } } impl LanguageModelProviderState for XAiLanguageModelProvider { @@ -162,11 +121,11 @@ impl LanguageModelProviderState for XAiLanguageModelProvider { impl LanguageModelProvider for XAiLanguageModelProvider { fn id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn icon(&self) -> IconName { @@ -190,10 +149,7 @@ impl LanguageModelProvider for XAiLanguageModelProvider { } } - for model in &AllLanguageModelSettings::get_global(cx) - .x_ai - .available_models - { + for model in &Self::settings(cx).available_models { models.insert( model.name.clone(), x_ai::Model::Custom { @@ -231,7 +187,8 @@ impl LanguageModelProvider for XAiLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.reset_api_key(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -251,20 +208,20 @@ impl XAiLanguageModel { ) -> BoxFuture<'static, Result>>> { let http_client = self.http_client.clone(); - let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| { - let settings = &AllLanguageModelSettings::get_global(cx).x_ai; - let api_url = if settings.api_url.is_empty() { - x_ai::XAI_API_URL.to_string() - } else { - settings.api_url.clone() - }; - (state.api_key.clone(), api_url) + + let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { + let api_url = XAiLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) }) else { - return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + return future::ready(Err(anyhow!("App state dropped"))).boxed(); }; let future = self.request_limiter.stream(async move { - let api_key = api_key.context("Missing xAI API Key")?; + let Some(api_key) = api_key else { + return Err(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + }); + }; let request = open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request); let response = request.await?; @@ -285,11 +242,11 @@ impl LanguageModel for XAiLanguageModel { } fn provider_id(&self) -> LanguageModelProviderId { - LanguageModelProviderId(PROVIDER_ID.into()) + PROVIDER_ID } fn provider_name(&self) -> LanguageModelProviderName { - LanguageModelProviderName(PROVIDER_NAME.into()) + PROVIDER_NAME } fn supports_tools(&self) -> bool { @@ -446,45 +403,35 @@ impl ConfigurationView { } fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self - .api_key_editor - .read(cx) - .editor() - .read(cx) - .text(cx) - .trim() - .to_string(); - - // Don't proceed if no API key is provided and we're not authenticated - if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { return; } + // url changes can cause the editor to be displayed again + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { state - .update(cx, |state, cx| state.set_api_key(api_key, cx))? + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? .await }) .detach_and_log_err(cx); - - cx.notify(); } fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { - self.api_key_editor.update(cx, |input, cx| { - input.editor.update(cx, |editor, cx| { - editor.set_text("", window, cx); - }); - }); + self.api_key_editor + .update(cx, |input, cx| input.set_text("", window, cx)); let state = self.state.clone(); cx.spawn_in(window, async move |_, cx| { - state.update(cx, |state, cx| state.reset_api_key(cx))?.await + state + .update(cx, |state, cx| state.set_api_key(None, cx))? + .await }) .detach_and_log_err(cx); - - cx.notify(); } fn should_render_editor(&self, cx: &mut Context) -> bool { @@ -494,7 +441,7 @@ impl ConfigurationView { impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_from_env; + let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); let api_key_section = if self.should_render_editor(cx) { v_flex() @@ -514,7 +461,7 @@ impl Render for ConfigurationView { .child(self.api_key_editor.clone()) .child( Label::new(format!( - "You can also assign the {XAI_API_KEY_VAR} environment variable and restart Zed." + "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." )) .size(LabelSize::Small) .color(Color::Muted), @@ -539,9 +486,14 @@ impl Render for ConfigurationView { .gap_1() .child(Icon::new(IconName::Check).color(Color::Success)) .child(Label::new(if env_var_set { - format!("API key set in {XAI_API_KEY_VAR} environment variable.") + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") } else { - "API key configured.".to_string() + let api_url = XAiLanguageModelProvider::api_url(cx); + if api_url == XAI_API_URL { + "API key configured".to_string() + } else { + format!("API key configured for {}", truncate_and_trailoff(&api_url, 32)) + } })), ) .child( @@ -552,7 +504,7 @@ impl Render for ConfigurationView { .icon_position(IconPosition::Start) .layer(ElevationIndex::ModalSurface) .when(env_var_set, |this| { - this.tooltip(Tooltip::text(format!("To reset your API key, unset the {XAI_API_KEY_VAR} environment variable."))) + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))) }) .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), ) diff --git a/crates/language_onboarding/src/python.rs b/crates/language_onboarding/src/python.rs index 6b83b841e0488d67014cc090b6c741035e544e04..e715cb7c806f417980a93a62210c72ca8529fcb5 100644 --- a/crates/language_onboarding/src/python.rs +++ b/crates/language_onboarding/src/python.rs @@ -30,6 +30,10 @@ impl BasedPyrightBanner { _subscriptions: [subscription], } } + + fn onboarding_banner_enabled(&self) -> bool { + !self.dismissed && self.have_basedpyright + } } impl EventEmitter for BasedPyrightBanner {} @@ -38,7 +42,7 @@ impl Render for BasedPyrightBanner { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { div() .id("basedpyright-banner") - .when(!self.dismissed && self.have_basedpyright, |el| { + .when(self.onboarding_banner_enabled(), |el| { el.child( Banner::new() .child( @@ -81,6 +85,9 @@ impl ToolbarItemView for BasedPyrightBanner { _window: &mut ui::Window, cx: &mut Context, ) -> ToolbarItemLocation { + if !self.onboarding_banner_enabled() { + return ToolbarItemLocation::Hidden; + } if let Some(item) = active_pane_item && let Some(editor) = item.act_as::(cx) && let Some(path) = editor.update(cx, |editor, cx| editor.target_file_abs_path(cx)) diff --git a/crates/language_tools/src/syntax_tree_view.rs b/crates/language_tools/src/syntax_tree_view.rs index 5700d8d487e990937597295fb5bab761a46f2ba3..5a110019ddb229c6b4111663b46b64895befb8ff 100644 --- a/crates/language_tools/src/syntax_tree_view.rs +++ b/crates/language_tools/src/syntax_tree_view.rs @@ -12,7 +12,8 @@ use theme::ActiveTheme; use tree_sitter::{Node, TreeCursor}; use ui::{ ButtonCommon, ButtonLike, Clickable, Color, ContextMenu, FluentBuilder as _, IconButton, - IconName, Label, LabelCommon, LabelSize, PopoverMenu, StyledExt, Tooltip, h_flex, v_flex, + IconName, Label, LabelCommon, LabelSize, PopoverMenu, StyledExt, Tooltip, WithScrollbar, + h_flex, v_flex, }; use workspace::{ Event as WorkspaceEvent, SplitDirection, ToolbarItemEvent, ToolbarItemLocation, @@ -487,7 +488,7 @@ impl SyntaxTreeView { } impl Render for SyntaxTreeView { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { div() .flex_1() .bg(cx.theme().colors().editor_background) @@ -512,6 +513,8 @@ impl Render for SyntaxTreeView { .text_bg(cx.theme().colors().background) .into_any_element(), ) + .vertical_scrollbar_for(self.list_scroll_handle.clone(), window, cx) + .into_any_element() } else { let inner_content = v_flex() .items_center() @@ -540,6 +543,7 @@ impl Render for SyntaxTreeView { .size_full() .justify_center() .child(inner_content) + .into_any_element() } }) } diff --git a/crates/languages/Cargo.toml b/crates/languages/Cargo.toml index 7ebafd8fdd9e2310c207d47a7d911516a498d00c..f08c548ddcb10a311e1d5b29a9bf50a7b5bc4fb1 100644 --- a/crates/languages/Cargo.toml +++ b/crates/languages/Cargo.toml @@ -57,6 +57,7 @@ pet-core.workspace = true pet-fs.workspace = true pet-poetry.workspace = true pet-reporter.workspace = true +pet-virtualenv.workspace = true pet.workspace = true project.workspace = true regex.workspace = true diff --git a/crates/languages/src/go/injections.scm b/crates/languages/src/go/injections.scm index 2be0844d97b7f9f16e8832b5b8aebbb0d0043e5d..7bb68d760e1a556ef93a9477dc578c88d9350dcb 100644 --- a/crates/languages/src/go/injections.scm +++ b/crates/languages/src/go/injections.scm @@ -10,4 +10,365 @@ (raw_string_literal) (interpreted_string_literal) ] @injection.content - (#set! injection.language "regex"))) + (#set! injection.language "regex") + )) + +; INJECT SQL +( + [ + ; var, const or short declaration of raw or interpreted string literal + ((comment) @comment + . + (expression_list + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a literal element (to struct field eg.) + ((comment) @comment + . + (literal_element + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a function parameter + ((comment) @comment + . + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content) + ] + + (#match? @comment "^\\/\\*\\s*sql\\s*\\*\\/") ; /* sql */ or /*sql*/ + (#set! injection.language "sql") +) + +; INJECT JSON +( + [ + ; var, const or short declaration of raw or interpreted string literal + ((comment) @comment + . + (expression_list + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a literal element (to struct field eg.) + ((comment) @comment + . + (literal_element + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a function parameter + ((comment) @comment + . + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content) + ] + + (#match? @comment "^\\/\\*\\s*json\\s*\\*\\/") ; /* json */ or /*json*/ + (#set! injection.language "json") +) + +; INJECT YAML +( + [ + ; var, const or short declaration of raw or interpreted string literal + ((comment) @comment + . + (expression_list + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a literal element (to struct field eg.) + ((comment) @comment + . + (literal_element + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a function parameter + ((comment) @comment + . + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content) + ] + + (#match? @comment "^\\/\\*\\s*yaml\\s*\\*\\/") ; /* yaml */ or /*yaml*/ + (#set! injection.language "yaml") +) + +; INJECT XML +( + [ + ; var, const or short declaration of raw or interpreted string literal + ((comment) @comment + . + (expression_list + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a literal element (to struct field eg.) + ((comment) @comment + . + (literal_element + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a function parameter + ((comment) @comment + . + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content) + ] + + (#match? @comment "^\\/\\*\\s*xml\\s*\\*\\/") ; /* xml */ or /*xml*/ + (#set! injection.language "xml") +) + +; INJECT HTML +( + [ + ; var, const or short declaration of raw or interpreted string literal + ((comment) @comment + . + (expression_list + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a literal element (to struct field eg.) + ((comment) @comment + . + (literal_element + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a function parameter + ((comment) @comment + . + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content) + ] + + (#match? @comment "^\\/\\*\\s*html\\s*\\*\\/") ; /* html */ or /*html*/ + (#set! injection.language "html") +) + +; INJECT JS +( + [ + ; var, const or short declaration of raw or interpreted string literal + ((comment) @comment + . + (expression_list + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a literal element (to struct field eg.) + ((comment) @comment + . + (literal_element + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a function parameter + ((comment) @comment + . + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content) + ] + + (#match? @comment "^\\/\\*\\s*js\\s*\\*\\/") ; /* js */ or /*js*/ + (#set! injection.language "javascript") +) + +; INJECT CSS +( + [ + ; var, const or short declaration of raw or interpreted string literal + ((comment) @comment + . + (expression_list + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a literal element (to struct field eg.) + ((comment) @comment + . + (literal_element + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a function parameter + ((comment) @comment + . + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content) + ] + + (#match? @comment "^\\/\\*\\s*css\\s*\\*\\/") ; /* css */ or /*css*/ + (#set! injection.language "css") +) + +; INJECT LUA +( + [ + ; var, const or short declaration of raw or interpreted string literal + ((comment) @comment + . + (expression_list + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a literal element (to struct field eg.) + ((comment) @comment + . + (literal_element + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a function parameter + ((comment) @comment + . + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content) + ] + + (#match? @comment "^\\/\\*\\s*lua\\s*\\*\\/") ; /* lua */ or /*lua*/ + (#set! injection.language "lua") +) + +; INJECT BASH +( + [ + ; var, const or short declaration of raw or interpreted string literal + ((comment) @comment + . + (expression_list + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a literal element (to struct field eg.) + ((comment) @comment + . + (literal_element + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a function parameter + ((comment) @comment + . + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content) + ] + + (#match? @comment "^\\/\\*\\s*bash\\s*\\*\\/") ; /* bash */ or /*bash*/ + (#set! injection.language "bash") +) + +; INJECT CSV +( + [ + ; var, const or short declaration of raw or interpreted string literal + ((comment) @comment + . + (expression_list + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a literal element (to struct field eg.) + ((comment) @comment + . + (literal_element + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content + )) + + ; when passing as a function parameter + ((comment) @comment + . + [ + (interpreted_string_literal) + (raw_string_literal) + ] @injection.content) + ] + + (#match? @comment "^\\/\\*\\s*csv\\s*\\*\\/") ; /* csv */ or /*csv*/ + (#set! injection.language "csv") +) diff --git a/crates/languages/src/lib.rs b/crates/languages/src/lib.rs index 8ab84ed835a448649c732251fb4715fa8a776a85..f5a4a8c6f6480de7589f0a418157fafbf5fbe2ed 100644 --- a/crates/languages/src/lib.rs +++ b/crates/languages/src/lib.rs @@ -286,6 +286,7 @@ pub fn init(languages: Arc, fs: Arc, node: NodeRuntime "HEEX", "HTML", "JavaScript", + "TypeScript", "PHP", "Svelte", "TSX", diff --git a/crates/languages/src/python.rs b/crates/languages/src/python.rs index b893a44d523c7f34821056fea243124ded3e8ece..4905cd22cd64488a0f0da9ede6834602fc5834bc 100644 --- a/crates/languages/src/python.rs +++ b/crates/languages/src/python.rs @@ -16,6 +16,7 @@ use node_runtime::{NodeRuntime, VersionStrategy}; use pet_core::Configuration; use pet_core::os_environment::Environment; use pet_core::python_environment::{PythonEnvironment, PythonEnvironmentKind}; +use pet_virtualenv::is_virtualenv_dir; use project::Fs; use project::lsp_store::language_server_settings; use serde_json::{Value, json}; @@ -460,7 +461,7 @@ impl LspAdapter for PyrightLspAdapter { pet_core::python_environment::PythonEnvironment, >(toolchain.as_json.clone()) { - if user_settings.is_null() { + if !user_settings.is_object() { user_settings = Value::Object(serde_json::Map::default()); } let object = user_settings.as_object_mut().unwrap(); @@ -491,9 +492,13 @@ impl LspAdapter for PyrightLspAdapter { // Get or create the python section let python = object .entry("python") - .or_insert(Value::Object(serde_json::Map::default())) - .as_object_mut() - .unwrap(); + .and_modify(|v| { + if !v.is_object() { + *v = Value::Object(serde_json::Map::default()); + } + }) + .or_insert(Value::Object(serde_json::Map::default())); + let python = python.as_object_mut().unwrap(); // Set both pythonPath and defaultInterpreterPath for compatibility python.insert( @@ -900,6 +905,21 @@ fn python_module_name_from_relative_path(relative_path: &str) -> String { .to_string() } +fn is_python_env_global(k: &PythonEnvironmentKind) -> bool { + matches!( + k, + PythonEnvironmentKind::Homebrew + | PythonEnvironmentKind::Pyenv + | PythonEnvironmentKind::GlobalPaths + | PythonEnvironmentKind::MacPythonOrg + | PythonEnvironmentKind::MacCommandLineTools + | PythonEnvironmentKind::LinuxGlobal + | PythonEnvironmentKind::MacXCode + | PythonEnvironmentKind::WindowsStore + | PythonEnvironmentKind::WindowsRegistry + ) +} + fn python_env_kind_display(k: &PythonEnvironmentKind) -> &'static str { match k { PythonEnvironmentKind::Conda => "Conda", @@ -966,6 +986,26 @@ async fn get_worktree_venv_declaration(worktree_root: &Path) -> Option { Some(venv_name.trim().to_string()) } +fn get_venv_parent_dir(env: &PythonEnvironment) -> Option { + // If global, we aren't a virtual environment + if let Some(kind) = env.kind + && is_python_env_global(&kind) + { + return None; + } + + // Check to be sure we are a virtual environment using pet's most generic + // virtual environment type, VirtualEnv + let venv = env + .executable + .as_ref() + .and_then(|p| p.parent()) + .and_then(|p| p.parent()) + .filter(|p| is_virtualenv_dir(p))?; + + venv.parent().map(|parent| parent.to_path_buf()) +} + #[async_trait] impl ToolchainLister for PythonToolchainProvider { async fn list( @@ -1025,11 +1065,15 @@ impl ToolchainLister for PythonToolchainProvider { }); // Compare project paths against worktree root - let proj_ordering = || match (&lhs.project, &rhs.project) { - (Some(l), Some(r)) => (r == &wr).cmp(&(l == &wr)), - (Some(l), None) if l == &wr => Ordering::Less, - (None, Some(r)) if r == &wr => Ordering::Greater, - _ => Ordering::Equal, + let proj_ordering = || { + let lhs_project = lhs.project.clone().or_else(|| get_venv_parent_dir(lhs)); + let rhs_project = rhs.project.clone().or_else(|| get_venv_parent_dir(rhs)); + match (&lhs_project, &rhs_project) { + (Some(l), Some(r)) => (r == &wr).cmp(&(l == &wr)), + (Some(l), None) if l == &wr => Ordering::Less, + (None, Some(r)) if r == &wr => Ordering::Greater, + _ => Ordering::Equal, + } }; // Compare environment priorities @@ -1131,7 +1175,7 @@ impl ToolchainLister for PythonToolchainProvider { let activate_keyword = match shell { ShellKind::Cmd => ".", ShellKind::Nushell => "overlay use", - ShellKind::Powershell => ".", + ShellKind::PowerShell => ".", ShellKind::Fish => "source", ShellKind::Csh => "source", ShellKind::Posix => "source", @@ -1141,7 +1185,7 @@ impl ToolchainLister for PythonToolchainProvider { ShellKind::Csh => "activate.csh", ShellKind::Fish => "activate.fish", ShellKind::Nushell => "activate.nu", - ShellKind::Powershell => "activate.ps1", + ShellKind::PowerShell => "activate.ps1", ShellKind::Cmd => "activate.bat", }; let path = prefix.join(BINARY_DIR).join(activate_script_name); @@ -1165,7 +1209,7 @@ impl ToolchainLister for PythonToolchainProvider { ShellKind::Fish => Some(format!("\"{pyenv}\" shell - fish {version}")), ShellKind::Posix => Some(format!("\"{pyenv}\" shell - sh {version}")), ShellKind::Nushell => Some(format!("\"{pyenv}\" shell - nu {version}")), - ShellKind::Powershell => None, + ShellKind::PowerShell => None, ShellKind::Csh => None, ShellKind::Cmd => None, }) @@ -1425,7 +1469,7 @@ impl LspAdapter for PyLspAdapter { // If user did not explicitly modify their python venv, use one from picker. if let Some(toolchain) = toolchain { - if user_settings.is_null() { + if !user_settings.is_object() { user_settings = Value::Object(serde_json::Map::default()); } let object = user_settings.as_object_mut().unwrap(); @@ -1747,7 +1791,7 @@ impl LspAdapter for BasedPyrightLspAdapter { pet_core::python_environment::PythonEnvironment, >(toolchain.as_json.clone()) { - if user_settings.is_null() { + if !user_settings.is_object() { user_settings = Value::Object(serde_json::Map::default()); } let object = user_settings.as_object_mut().unwrap(); diff --git a/crates/languages/src/tailwind.rs b/crates/languages/src/tailwind.rs index db539edabb6c10253f2fd0d688eafa46082d427a..e1b50a5ccaabb7770d13abc79fbac1da5fa4cbbe 100644 --- a/crates/languages/src/tailwind.rs +++ b/crates/languages/src/tailwind.rs @@ -146,6 +146,7 @@ impl LspAdapter for TailwindLspAdapter { "html": "html", "css": "css", "javascript": "javascript", + "typescript": "typescript", "typescriptreact": "typescriptreact", }, }))) @@ -178,6 +179,7 @@ impl LspAdapter for TailwindLspAdapter { (LanguageName::new("HTML"), "html".to_string()), (LanguageName::new("CSS"), "css".to_string()), (LanguageName::new("JavaScript"), "javascript".to_string()), + (LanguageName::new("TypeScript"), "typescript".to_string()), (LanguageName::new("TSX"), "typescriptreact".to_string()), (LanguageName::new("Svelte"), "svelte".to_string()), (LanguageName::new("Elixir"), "phoenix-heex".to_string()), diff --git a/crates/languages/src/typescript/config.toml b/crates/languages/src/typescript/config.toml index 2344f6209da7756049438669ee55d5376fdb47f8..fe56e496ec717895e72f37dda9146fbb30b50e88 100644 --- a/crates/languages/src/typescript/config.toml +++ b/crates/languages/src/typescript/config.toml @@ -21,9 +21,11 @@ word_characters = ["#", "$"] prettier_parser_name = "typescript" tab_size = 2 debuggers = ["JavaScript"] +scope_opt_in_language_servers = ["tailwindcss-language-server"] [overrides.string] -completion_query_characters = ["."] +completion_query_characters = ["-", "."] +opt_into_language_servers = ["tailwindcss-language-server"] prefer_label_for_snippet = true [overrides.function_name_before_type_arguments] diff --git a/crates/livekit_client/src/livekit_client/playback.rs b/crates/livekit_client/src/livekit_client/playback.rs index 7c866113103a883e7e7a2d9d3f5651d833d7e637..df8b5ea54fb1ce11bf871faa912757bbff1fd7f9 100644 --- a/crates/livekit_client/src/livekit_client/playback.rs +++ b/crates/livekit_client/src/livekit_client/playback.rs @@ -188,12 +188,15 @@ impl AudioStack { let voip_parts = audio::VoipParts::new(cx)?; // Audio needs to run real-time and should never be paused. That is why we are using a // normal std::thread and not a background task - thread::spawn(move || { - // microphone is non send on mac - let microphone = audio::Audio::open_microphone(voip_parts)?; - send_to_livekit(frame_tx, microphone); - Ok::<(), anyhow::Error>(()) - }); + thread::Builder::new() + .name("AudioCapture".to_string()) + .spawn(move || { + // microphone is non send on mac + let microphone = audio::Audio::open_microphone(voip_parts)?; + send_to_livekit(frame_tx, microphone); + Ok::<(), anyhow::Error>(()) + }) + .unwrap(); Task::ready(Ok(())) } else { self.executor.spawn(async move { @@ -229,57 +232,60 @@ impl AudioStack { let mut resampler = audio_resampler::AudioResampler::default(); let mut buf = Vec::new(); - thread::spawn(move || { - let output_stream = output_device.build_output_stream( - &output_config.config(), - { - move |mut data, _info| { - while data.len() > 0 { - if data.len() <= buf.len() { - let rest = buf.split_off(data.len()); - data.copy_from_slice(&buf); - buf = rest; - return; - } - if buf.len() > 0 { - let (prefix, suffix) = data.split_at_mut(buf.len()); - prefix.copy_from_slice(&buf); - data = suffix; - } + thread::Builder::new() + .name("AudioPlayback".to_owned()) + .spawn(move || { + let output_stream = output_device.build_output_stream( + &output_config.config(), + { + move |mut data, _info| { + while data.len() > 0 { + if data.len() <= buf.len() { + let rest = buf.split_off(data.len()); + data.copy_from_slice(&buf); + buf = rest; + return; + } + if buf.len() > 0 { + let (prefix, suffix) = data.split_at_mut(buf.len()); + prefix.copy_from_slice(&buf); + data = suffix; + } - let mut mixer = mixer.lock(); - let mixed = mixer.mix(output_config.channels() as usize); - let sampled = resampler.remix_and_resample( - mixed, - sample_rate / 100, - num_channels, - sample_rate, - output_config.channels() as u32, - output_config.sample_rate().0, - ); - buf = sampled.to_vec(); - apm.lock() - .process_reverse_stream( - &mut buf, - output_config.sample_rate().0 as i32, - output_config.channels() as i32, - ) - .ok(); + let mut mixer = mixer.lock(); + let mixed = mixer.mix(output_config.channels() as usize); + let sampled = resampler.remix_and_resample( + mixed, + sample_rate / 100, + num_channels, + sample_rate, + output_config.channels() as u32, + output_config.sample_rate().0, + ); + buf = sampled.to_vec(); + apm.lock() + .process_reverse_stream( + &mut buf, + output_config.sample_rate().0 as i32, + output_config.channels() as i32, + ) + .ok(); + } } - } - }, - |error| log::error!("error playing audio track: {:?}", error), - Some(Duration::from_millis(100)), - ); + }, + |error| log::error!("error playing audio track: {:?}", error), + Some(Duration::from_millis(100)), + ); - let Some(output_stream) = output_stream.log_err() else { - return; - }; + let Some(output_stream) = output_stream.log_err() else { + return; + }; - output_stream.play().log_err(); - // Block forever to keep the output stream alive - end_on_drop_rx.recv().ok(); - }); + output_stream.play().log_err(); + // Block forever to keep the output stream alive + end_on_drop_rx.recv().ok(); + }) + .unwrap(); device_change_listener.next().await; drop(end_on_drop_tx) @@ -300,77 +306,81 @@ impl AudioStack { let frame_tx = frame_tx.clone(); let mut resampler = audio_resampler::AudioResampler::default(); - thread::spawn(move || { - maybe!({ - if let Some(name) = device.name().ok() { - log::info!("Using microphone: {}", name) - } else { - log::info!("Using microphone: "); - } - - let ten_ms_buffer_size = - (config.channels() as u32 * config.sample_rate().0 / 100) as usize; - let mut buf: Vec = Vec::with_capacity(ten_ms_buffer_size); - - let stream = device - .build_input_stream_raw( - &config.config(), - config.sample_format(), - move |data, _: &_| { - let data = - crate::get_sample_data(config.sample_format(), data).log_err(); - let Some(data) = data else { - return; - }; - let mut data = data.as_slice(); + thread::Builder::new() + .name("AudioCapture".to_owned()) + .spawn(move || { + maybe!({ + if let Some(name) = device.name().ok() { + log::info!("Using microphone: {}", name) + } else { + log::info!("Using microphone: "); + } - while data.len() > 0 { - let remainder = (buf.capacity() - buf.len()).min(data.len()); - buf.extend_from_slice(&data[..remainder]); - data = &data[remainder..]; - - if buf.capacity() == buf.len() { - let mut sampled = resampler - .remix_and_resample( - buf.as_slice(), - config.sample_rate().0 / 100, - config.channels() as u32, - config.sample_rate().0, - num_channels, - sample_rate, - ) - .to_owned(); - apm.lock() - .process_stream( - &mut sampled, - sample_rate as i32, - num_channels as i32, - ) - .log_err(); - buf.clear(); - frame_tx - .unbounded_send(AudioFrame { - data: Cow::Owned(sampled), - sample_rate, - num_channels, - samples_per_channel: sample_rate / 100, - }) - .ok(); + let ten_ms_buffer_size = + (config.channels() as u32 * config.sample_rate().0 / 100) as usize; + let mut buf: Vec = Vec::with_capacity(ten_ms_buffer_size); + + let stream = device + .build_input_stream_raw( + &config.config(), + config.sample_format(), + move |data, _: &_| { + let data = crate::get_sample_data(config.sample_format(), data) + .log_err(); + let Some(data) = data else { + return; + }; + let mut data = data.as_slice(); + + while data.len() > 0 { + let remainder = + (buf.capacity() - buf.len()).min(data.len()); + buf.extend_from_slice(&data[..remainder]); + data = &data[remainder..]; + + if buf.capacity() == buf.len() { + let mut sampled = resampler + .remix_and_resample( + buf.as_slice(), + config.sample_rate().0 / 100, + config.channels() as u32, + config.sample_rate().0, + num_channels, + sample_rate, + ) + .to_owned(); + apm.lock() + .process_stream( + &mut sampled, + sample_rate as i32, + num_channels as i32, + ) + .log_err(); + buf.clear(); + frame_tx + .unbounded_send(AudioFrame { + data: Cow::Owned(sampled), + sample_rate, + num_channels, + samples_per_channel: sample_rate / 100, + }) + .ok(); + } } - } - }, - |err| log::error!("error capturing audio track: {:?}", err), - Some(Duration::from_millis(100)), - ) - .context("failed to build input stream")?; - - stream.play()?; - // Keep the thread alive and holding onto the `stream` - end_on_drop_rx.recv().ok(); - anyhow::Ok(Some(())) + }, + |err| log::error!("error capturing audio track: {:?}", err), + Some(Duration::from_millis(100)), + ) + .context("failed to build input stream")?; + + stream.play()?; + // Keep the thread alive and holding onto the `stream` + end_on_drop_rx.recv().ok(); + anyhow::Ok(Some(())) + }) + .log_err(); }) - .log_err(); - }); + .unwrap(); device_change_listener.next().await; drop(end_on_drop_tx) diff --git a/crates/markdown/src/markdown.rs b/crates/markdown/src/markdown.rs index 4e1d3ac51e148439e57a4a1c305dabc31cbc2046..c2f8025e32d70cdd9500afdf0a4fc02a334a8521 100644 --- a/crates/markdown/src/markdown.rs +++ b/crates/markdown/src/markdown.rs @@ -1079,7 +1079,7 @@ impl Element for MarkdownElement { { builder.modify_current_div(|el| { let content_range = parser::extract_code_block_content_range( - parsed_markdown.source()[range.clone()].trim(), + &parsed_markdown.source()[range.clone()], ); let content_range = content_range.start + range.start ..content_range.end + range.start; @@ -1110,7 +1110,7 @@ impl Element for MarkdownElement { { builder.modify_current_div(|el| { let content_range = parser::extract_code_block_content_range( - parsed_markdown.source()[range.clone()].trim(), + &parsed_markdown.source()[range.clone()], ); let content_range = content_range.start + range.start ..content_range.end + range.start; diff --git a/crates/markdown/src/parser.rs b/crates/markdown/src/parser.rs index 3720e5b1ef5f61f0a209ac5617119de61ed05517..d60d34b41e7efc99970f72b15a8ea9c4c79eb6f9 100644 --- a/crates/markdown/src/parser.rs +++ b/crates/markdown/src/parser.rs @@ -67,7 +67,7 @@ pub fn parse_markdown( MarkdownTag::CodeBlock { kind: CodeBlockKind::Indented, metadata: CodeBlockMetadata { - content_range: range.start + 1..range.end + 1, + content_range: range.clone(), line_count: 1, }, } @@ -698,7 +698,28 @@ mod tests { HashSet::from(["rust".into()]), HashSet::new() ) - ) + ); + assert_eq!( + parse_markdown(" fn main() {}"), + ( + vec![ + ( + 4..16, + Start(CodeBlock { + kind: CodeBlockKind::Indented, + metadata: CodeBlockMetadata { + content_range: 4..16, + line_count: 1 + } + }) + ), + (4..16, Text), + (4..16, End(MarkdownTagEnd::CodeBlock)) + ], + HashSet::new(), + HashSet::new() + ) + ); } #[test] diff --git a/crates/markdown_preview/src/markdown_preview_view.rs b/crates/markdown_preview/src/markdown_preview_view.rs index 1121d64655f6c7e02f0b0d621605c9ba1aae7cde..d20ed40b7928186e2caf564be5ff66b0bd04f0d1 100644 --- a/crates/markdown_preview/src/markdown_preview_view.rs +++ b/crates/markdown_preview/src/markdown_preview_view.rs @@ -13,7 +13,7 @@ use gpui::{ use language::LanguageRegistry; use settings::Settings; use theme::ThemeSettings; -use ui::prelude::*; +use ui::{WithScrollbar, prelude::*}; use workspace::item::{Item, ItemHandle}; use workspace::{Pane, Workspace}; @@ -481,7 +481,7 @@ impl Item for MarkdownPreviewView { } impl Render for MarkdownPreviewView { - fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let buffer_size = ThemeSettings::get_global(cx).buffer_font_size(cx); let buffer_line_height = ThemeSettings::get_global(cx).buffer_line_height; @@ -598,5 +598,6 @@ impl Render for MarkdownPreviewView { .size_full(), ) })) + .vertical_scrollbar_for(self.list_state.clone(), window, cx) } } diff --git a/crates/multi_buffer/src/multi_buffer.rs b/crates/multi_buffer/src/multi_buffer.rs index 5273007a9fc44700eea5078c35636fe9282d486f..272caf28c8b8e6d1516292e49fb34817f1a1d062 100644 --- a/crates/multi_buffer/src/multi_buffer.rs +++ b/crates/multi_buffer/src/multi_buffer.rs @@ -6129,6 +6129,12 @@ impl MultiBufferSnapshot { text: item.text, highlight_ranges: item.highlight_ranges, name_ranges: item.name_ranges, + signature_range: item.signature_range.and_then(|signature_range| { + Some( + self.anchor_in_excerpt(*excerpt_id, signature_range.start)? + ..self.anchor_in_excerpt(*excerpt_id, signature_range.end)?, + ) + }), body_range: item.body_range.and_then(|body_range| { Some( self.anchor_in_excerpt(*excerpt_id, body_range.start)? @@ -6169,6 +6175,12 @@ impl MultiBufferSnapshot { text: item.text, highlight_ranges: item.highlight_ranges, name_ranges: item.name_ranges, + signature_range: item.signature_range.and_then(|signature_range| { + Some( + self.anchor_in_excerpt(excerpt_id, signature_range.start)? + ..self.anchor_in_excerpt(excerpt_id, signature_range.end)?, + ) + }), body_range: item.body_range.and_then(|body_range| { Some( self.anchor_in_excerpt(excerpt_id, body_range.start)? diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index 18cd5811b29dd73fed01c1e4d30c5e0ae802b76a..dced37e0fc1e19e61bba5e14010812f08fe3a1e5 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -257,14 +257,19 @@ pub async fn complete( pub async fn stream_chat_completion( client: &dyn HttpClient, api_url: &str, + api_key: Option<&str>, request: ChatRequest, ) -> Result>> { let uri = format!("{api_url}/api/chat"); - let request_builder = http::Request::builder() + let mut request_builder = http::Request::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json"); + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")) + } + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { @@ -291,14 +296,19 @@ pub async fn stream_chat_completion( pub async fn get_models( client: &dyn HttpClient, api_url: &str, + api_key: Option<&str>, _: Option, ) -> Result> { let uri = format!("{api_url}/api/tags"); - let request_builder = HttpRequest::builder() + let mut request_builder = HttpRequest::builder() .method(Method::GET) .uri(uri) .header("Accept", "application/json"); + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")); + } + let request = request_builder.body(AsyncBody::default())?; let mut response = client.send(request).await?; @@ -318,15 +328,25 @@ pub async fn get_models( } /// Fetch details of a model, used to determine model capabilities -pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result { +pub async fn show_model( + client: &dyn HttpClient, + api_url: &str, + api_key: Option<&str>, + model: &str, +) -> Result { let uri = format!("{api_url}/api/show"); - let request = HttpRequest::builder() + let mut request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) - .header("Content-Type", "application/json") - .body(AsyncBody::from( - serde_json::json!({ "model": model }).to_string(), - ))?; + .header("Content-Type", "application/json"); + + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")) + } + + let request = request_builder.body(AsyncBody::from( + serde_json::json!({ "model": model }).to_string(), + ))?; let mut response = client.send(request).await?; let mut body = String::new(); diff --git a/crates/onboarding/src/onboarding.rs b/crates/onboarding/src/onboarding.rs index 9dcf27c7cbebf6621bbeb558619944c768e63fb6..835db1734d986b33d3f58c8d9f1a4883458cba31 100644 --- a/crates/onboarding/src/onboarding.rs +++ b/crates/onboarding/src/onboarding.rs @@ -5,8 +5,8 @@ use db::kvp::KEY_VALUE_STORE; use fs::Fs; use gpui::{ Action, AnyElement, App, AppContext, AsyncWindowContext, Context, Entity, EventEmitter, - FocusHandle, Focusable, Global, IntoElement, KeyContext, Render, SharedString, Subscription, - Task, WeakEntity, Window, actions, + FocusHandle, Focusable, Global, IntoElement, KeyContext, Render, ScrollHandle, SharedString, + Subscription, Task, WeakEntity, Window, actions, }; use notifications::status_toast::{StatusToast, ToastIcon}; use schemars::JsonSchema; @@ -15,7 +15,7 @@ use settings::{SettingsStore, VsCodeSettingsSource}; use std::sync::Arc; use ui::{ Avatar, ButtonLike, FluentBuilder, Headline, KeyBinding, ParentElement as _, - StatefulInteractiveElement, Vector, VectorName, prelude::*, rems_from_px, + StatefulInteractiveElement, Vector, VectorName, WithScrollbar, prelude::*, rems_from_px, }; use workspace::{ AppState, Workspace, WorkspaceId, @@ -237,6 +237,7 @@ struct Onboarding { focus_handle: FocusHandle, selected_page: SelectedPage, user_store: Entity, + scroll_handle: ScrollHandle, _settings_subscription: Subscription, } @@ -256,6 +257,7 @@ impl Onboarding { Self { workspace: workspace.weak_handle(), focus_handle: cx.focus_handle(), + scroll_handle: ScrollHandle::new(), selected_page: SelectedPage::Basics, user_store: workspace.user_store().clone(), _settings_subscription: cx @@ -280,6 +282,7 @@ impl Onboarding { } self.selected_page = page; + self.scroll_handle.set_offset(Default::default()); cx.notify(); cx.emit(ItemEvent::UpdateTab); } @@ -584,16 +587,23 @@ impl Render for Onboarding { .gap_12() .child(self.render_nav(window, cx)) .child( - v_flex() - .id("page-content") + div() .size_full() - .max_w_full() - .min_w_0() - .pl_12() - .border_l_1() - .border_color(cx.theme().colors().border_variant.opacity(0.5)) - .overflow_y_scroll() - .child(self.render_page(window, cx)), + .pr_6() + .child( + v_flex() + .id("page-content") + .size_full() + .max_w_full() + .min_w_0() + .pl_12() + .border_l_1() + .border_color(cx.theme().colors().border_variant.opacity(0.5)) + .overflow_y_scroll() + .child(self.render_page(window, cx)) + .track_scroll(&self.scroll_handle), + ) + .vertical_scrollbar_for(self.scroll_handle.clone(), window, cx), ), ) } @@ -632,6 +642,7 @@ impl Item for Onboarding { workspace: self.workspace.clone(), user_store: self.user_store.clone(), selected_page: self.selected_page, + scroll_handle: ScrollHandle::new(), focus_handle: cx.focus_handle(), _settings_subscription: cx.observe_global::(move |_, cx| cx.notify()), })) diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index 76ee4beec2315ed1db55bd3002af1d49d83095e0..5fd3103dff842bfbacc612536218ac7367b27ec4 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -2481,6 +2481,7 @@ impl OutlinePanel { &OutlineItem { depth, annotation_range: None, + signature_range: None, range: search_data.context_range.clone(), text: search_data.context_text.clone(), highlight_ranges: search_data @@ -4692,7 +4693,10 @@ impl OutlinePanel { .custom_scrollbars( Scrollbars::for_settings::() .tracked_scroll_handle(self.scroll_handle.clone()) - .with_track_along(ScrollAxes::Horizontal) + .with_track_along( + ScrollAxes::Horizontal, + cx.theme().colors().panel_background, + ) .tracked_entity(cx.entity_id()), window, cx, diff --git a/crates/project/src/agent_server_store.rs b/crates/project/src/agent_server_store.rs index 79ff8badaf87734cfc4786fcbf54bfb2d6b39b5a..9357bd773466645d6e7dad666b92dc5b5ed9ce4c 100644 --- a/crates/project/src/agent_server_store.rs +++ b/crates/project/src/agent_server_store.rs @@ -234,7 +234,7 @@ impl AgentServerStore { let subscription = cx.observe_global::(|this, cx| { this.agent_servers_settings_changed(cx); }); - let this = Self { + let mut this = Self { state: AgentServerStoreState::Local { node_runtime, fs, @@ -245,14 +245,7 @@ impl AgentServerStore { }, external_agents: Default::default(), }; - cx.spawn(async move |this, cx| { - cx.background_executor().timer(Duration::from_secs(1)).await; - this.update(cx, |this, cx| { - this.agent_servers_settings_changed(cx); - }) - .ok(); - }) - .detach(); + this.agent_servers_settings_changed(cx); this } @@ -305,22 +298,29 @@ impl AgentServerStore { } } - pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) { + pub fn shared(&mut self, project_id: u64, client: AnyProtoClient, cx: &mut Context) { match &mut self.state { AgentServerStoreState::Local { downstream_client, .. } => { - client - .send(proto::ExternalAgentsUpdated { - project_id, - names: self - .external_agents + *downstream_client = Some((project_id, client.clone())); + // Send the current list of external agents downstream, but only after a delay, + // to avoid having the message arrive before the downstream project's agent server store + // sets up its handlers. + cx.spawn(async move |this, cx| { + cx.background_executor().timer(Duration::from_secs(1)).await; + let names = this.update(cx, |this, _| { + this.external_agents .keys() .map(|name| name.to_string()) - .collect(), - }) - .log_err(); - *downstream_client = Some((project_id, client)); + .collect() + })?; + client + .send(proto::ExternalAgentsUpdated { project_id, names }) + .log_err(); + anyhow::Ok(()) + }) + .detach(); } AgentServerStoreState::Remote { .. } => { debug_panic!( @@ -721,11 +721,6 @@ struct RemoteExternalAgentServer { new_version_available_tx: Option>>, } -// new method: status_updated -// does nothing in the all-local case -// for RemoteExternalAgentServer, sends on the stored tx -// etc. - impl ExternalAgentServer for RemoteExternalAgentServer { fn get_command( &mut self, diff --git a/crates/project/src/git_store/conflict_set.rs b/crates/project/src/git_store/conflict_set.rs index 313a1e90adc2fde8a62dbe6aa60b4d3a366af22c..2bcfc75b32da3c5a4860cc72f3266bff38f022e3 100644 --- a/crates/project/src/git_store/conflict_set.rs +++ b/crates/project/src/git_store/conflict_set.rs @@ -257,7 +257,7 @@ impl EventEmitter for ConflictSet {} mod tests { use std::{path::Path, sync::mpsc}; - use crate::{Project, project_settings::ProjectSettings}; + use crate::Project; use super::*; use fs::FakeFs; @@ -484,7 +484,7 @@ mod tests { cx.update(|cx| { settings::init(cx); WorktreeSettings::register(cx); - ProjectSettings::register(cx); + Project::init_settings(cx); AllLanguageSettings::register(cx); }); let initial_text = " @@ -585,7 +585,7 @@ mod tests { cx.update(|cx| { settings::init(cx); WorktreeSettings::register(cx); - ProjectSettings::register(cx); + Project::init_settings(cx); AllLanguageSettings::register(cx); }); diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index 09b0c7b5351dd0b9bd4b03ff89a46f61e30a72ad..14a3f1921c04a6572fb5f4e4535ba4895c556d94 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -10541,7 +10541,10 @@ impl LspStore { for (worktree_id, servers) in &local.lsp_tree.instances { if *worktree_id != key.worktree_id { for server_map in servers.roots.values() { - if server_map.contains_key(&key.name) { + if server_map + .values() + .any(|(node, _)| node.id() == Some(server_id)) + { worktrees_using_server.push(*worktree_id); } } @@ -10551,6 +10554,7 @@ impl LspStore { let mut buffer_paths_registered = Vec::new(); self.buffer_store.clone().update(cx, |buffer_store, cx| { + let mut lsp_adapters = HashMap::default(); for buffer_handle in buffer_store.buffers() { let buffer = buffer_handle.read(cx); let file = match File::from_dyn(buffer.file()) { @@ -10563,9 +10567,9 @@ impl LspStore { }; if !worktrees_using_server.contains(&file.worktree.read(cx).id()) - || !self - .languages - .lsp_adapters(&language.name()) + || !lsp_adapters + .entry(language.name()) + .or_insert_with(|| self.languages.lsp_adapters(&language.name())) .iter() .any(|a| a.name == key.name) { diff --git a/crates/project/src/manifest_tree/server_tree.rs b/crates/project/src/manifest_tree/server_tree.rs index 48e2007d47f1ebd3c950f0d03e80dcccad515389..17a183e5b08dcab8ba6b6da861d2ec12559092b5 100644 --- a/crates/project/src/manifest_tree/server_tree.rs +++ b/crates/project/src/manifest_tree/server_tree.rs @@ -114,6 +114,10 @@ impl InnerTreeNode { }), } } + + pub(crate) fn id(&self) -> Option { + self.id.get().copied() + } } impl LanguageServerTree { diff --git a/crates/project/src/search.rs b/crates/project/src/search.rs index f2c6091e0cb00b8da1a752e3d25afe3389e8c818..953fa4f1aafdca87ccd1e8dbcfec145505660642 100644 --- a/crates/project/src/search.rs +++ b/crates/project/src/search.rs @@ -64,7 +64,6 @@ pub enum SearchQuery { include_ignored: bool, inner: SearchInputs, }, - Regex { regex: Regex, replacement: Option, diff --git a/crates/project/src/terminals.rs b/crates/project/src/terminals.rs index 04f98d6dba6794116be9a6dcf4d2cbb32cfb85b2..94e9999e1344efbc391476e22d107f10052d7694 100644 --- a/crates/project/src/terminals.rs +++ b/crates/project/src/terminals.rs @@ -179,7 +179,7 @@ impl Project { } }; - let shell = { + let (shell, env) = { env.extend(spawn_task.env); match remote_client { Some(remote_client) => match activation_script.clone() { @@ -189,8 +189,14 @@ impl Project { let args = vec!["-c".to_owned(), format!("{activation_script}; {to_run}")]; create_remote_shell( - Some((&shell, &args)), - &mut env, + Some(( + &remote_client + .read(cx) + .shell() + .unwrap_or_else(get_default_system_shell), + &args, + )), + env, path, remote_client, cx, @@ -201,7 +207,7 @@ impl Project { .command .as_ref() .map(|command| (command, &spawn_task.args)), - &mut env, + env, path, remote_client, cx, @@ -220,13 +226,16 @@ impl Project { #[cfg(not(windows))] let arg = format!("{activation_script}; {to_run}"); - Shell::WithArguments { - program: shell, - args: vec!["-c".to_owned(), arg], - title_override: None, - } + ( + Shell::WithArguments { + program: shell, + args: vec!["-c".to_owned(), arg], + title_override: None, + }, + env, + ) } - _ => { + _ => ( if let Some(program) = spawn_task.command { Shell::WithArguments { program, @@ -235,8 +244,9 @@ impl Project { } } else { Shell::System - } - } + }, + env, + ), }, } }; @@ -330,7 +340,7 @@ impl Project { .map(|p| self.active_toolchain(p, LanguageName::new("Python"), cx)) .collect::>(); let remote_client = self.remote_client.clone(); - let shell = match &remote_client { + let shell_kind = ShellKind::new(&match &remote_client { Some(remote_client) => remote_client .read(cx) .shell() @@ -344,7 +354,7 @@ impl Project { } => program.clone(), Shell::System => get_system_shell(), }, - }; + }); let lang_registry = self.languages.clone(); let fs = self.fs.clone(); @@ -361,7 +371,7 @@ impl Project { let lister = language?.toolchain_lister(); return Some( lister? - .activation_script(&toolchain, ShellKind::new(&shell), fs.as_ref()) + .activation_script(&toolchain, shell_kind, fs.as_ref()) .await, ); } @@ -370,12 +380,12 @@ impl Project { .await .unwrap_or_default(); project.update(cx, move |this, cx| { - let shell = { + let (shell, env) = { match remote_client { Some(remote_client) => { - create_remote_shell(None, &mut env, path, remote_client, cx)? + create_remote_shell(None, env, path, remote_client, cx)? } - None => settings.shell, + None => (settings.shell, env), } }; TerminalBuilder::new( @@ -545,11 +555,11 @@ fn quote_arg(argument: &str, quote: bool) -> String { fn create_remote_shell( spawn_command: Option<(&String, &Vec)>, - env: &mut HashMap, + mut env: HashMap, working_directory: Option>, remote_client: Entity, cx: &mut App, -) -> Result { +) -> Result<(Shell, HashMap)> { // Alacritty sets its terminfo to `alacritty`, this requiring hosts to have it installed // to properly display colors. // We do not have the luxury of assuming the host has it installed, @@ -565,18 +575,20 @@ fn create_remote_shell( let command = remote_client.read(cx).build_command( program, args.as_slice(), - env, + &env, working_directory.map(|path| path.display().to_string()), None, )?; - *env = command.env; log::debug!("Connecting to a remote server: {:?}", command.program); let host = remote_client.read(cx).connection_options().display_name(); - Ok(Shell::WithArguments { - program: command.program, - args: command.args, - title_override: Some(format!("{} — Terminal", host).into()), - }) + Ok(( + Shell::WithArguments { + program: command.program, + args: command.args, + title_override: Some(format!("{} — Terminal", host).into()), + }, + command.env, + )) } diff --git a/crates/project_panel/src/project_panel.rs b/crates/project_panel/src/project_panel.rs index 90c02be50d7cbe723c1ea16a8fd19bef21238ba6..debe3fc32a002b7a78806e3d6979f028cdc665b0 100644 --- a/crates/project_panel/src/project_panel.rs +++ b/crates/project_panel/src/project_panel.rs @@ -1154,8 +1154,32 @@ impl ProjectPanel { ) { // By keeping entries for fully collapsed worktrees, we avoid expanding them within update_visible_entries // (which is it's default behavior when there's no entry for a worktree in expanded_dir_ids). + let multiple_worktrees = self.project.read(cx).worktrees(cx).count() > 1; + let project = self.project.read(cx); + self.expanded_dir_ids - .retain(|_, expanded_entries| expanded_entries.is_empty()); + .iter_mut() + .for_each(|(worktree_id, expanded_entries)| { + if multiple_worktrees { + *expanded_entries = Default::default(); + return; + } + + let root_entry_id = project + .worktree_for_id(*worktree_id, cx) + .map(|worktree| worktree.read(cx).snapshot()) + .and_then(|worktree_snapshot| { + worktree_snapshot.root_entry().map(|entry| entry.id) + }); + + match root_entry_id { + Some(id) => { + expanded_entries.retain(|entry_id| entry_id == &id); + } + None => *expanded_entries = Default::default(), + }; + }); + self.update_visible_entries(None, cx); cx.notify(); } @@ -5608,7 +5632,10 @@ impl Render for ProjectPanel { .custom_scrollbars( Scrollbars::for_settings::() .tracked_scroll_handle(self.scroll_handle.clone()) - .with_track_along(ScrollAxes::Horizontal) + .with_track_along( + ScrollAxes::Horizontal, + cx.theme().colors().panel_background, + ) .notify_content(), window, cx, diff --git a/crates/project_panel/src/project_panel_tests.rs b/crates/project_panel/src/project_panel_tests.rs index 61684929a5e61e62c08d2f0e9d91def408448d8f..59061d6188721d6fefb2f539325f2058f54f1ee7 100644 --- a/crates/project_panel/src/project_panel_tests.rs +++ b/crates/project_panel/src/project_panel_tests.rs @@ -2747,6 +2747,111 @@ async fn test_collapse_all_entries(cx: &mut gpui::TestAppContext) { ); } +#[gpui::test] +async fn test_collapse_all_entries_multiple_worktrees(cx: &mut gpui::TestAppContext) { + init_test_with_editor(cx); + + let fs = FakeFs::new(cx.executor()); + let worktree_content = json!({ + "dir_1": { + "file_1.py": "# File contents", + }, + "dir_2": { + "file_1.py": "# File contents", + } + }); + + fs.insert_tree("/project_root_1", worktree_content.clone()) + .await; + fs.insert_tree("/project_root_2", worktree_content).await; + + let project = Project::test( + fs.clone(), + ["/project_root_1".as_ref(), "/project_root_2".as_ref()], + cx, + ) + .await; + let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*workspace, cx); + let panel = workspace.update(cx, ProjectPanel::new).unwrap(); + + panel.update_in(cx, |panel, window, cx| { + panel.collapse_all_entries(&CollapseAllEntries, window, cx) + }); + cx.executor().run_until_parked(); + assert_eq!( + visible_entries_as_strings(&panel, 0..10, cx), + &["> project_root_1", "> project_root_2",] + ); +} + +#[gpui::test] +async fn test_collapse_all_entries_with_collapsed_root(cx: &mut gpui::TestAppContext) { + init_test_with_editor(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/project_root", + json!({ + "dir_1": { + "nested_dir": { + "file_a.py": "# File contents", + "file_b.py": "# File contents", + "file_c.py": "# File contents", + }, + "file_1.py": "# File contents", + "file_2.py": "# File contents", + "file_3.py": "# File contents", + }, + "dir_2": { + "file_1.py": "# File contents", + "file_2.py": "# File contents", + "file_3.py": "# File contents", + } + }), + ) + .await; + + let project = Project::test(fs.clone(), ["/project_root".as_ref()], cx).await; + let workspace = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx)); + let cx = &mut VisualTestContext::from_window(*workspace, cx); + let panel = workspace.update(cx, ProjectPanel::new).unwrap(); + + // Open project_root/dir_1 to ensure that a nested directory is expanded + toggle_expand_dir(&panel, "project_root/dir_1", cx); + cx.executor().run_until_parked(); + assert_eq!( + visible_entries_as_strings(&panel, 0..10, cx), + &[ + "v project_root", + " v dir_1 <== selected", + " > nested_dir", + " file_1.py", + " file_2.py", + " file_3.py", + " > dir_2", + ] + ); + + // Close root directory + toggle_expand_dir(&panel, "project_root", cx); + cx.executor().run_until_parked(); + assert_eq!( + visible_entries_as_strings(&panel, 0..10, cx), + &["> project_root <== selected"] + ); + + // Run collapse_all_entries and make sure root is not expanded + panel.update_in(cx, |panel, window, cx| { + panel.collapse_all_entries(&CollapseAllEntries, window, cx) + }); + cx.executor().run_until_parked(); + assert_eq!( + visible_entries_as_strings(&panel, 0..10, cx), + &["> project_root <== selected"] + ); +} + #[gpui::test] async fn test_new_file_move(cx: &mut gpui::TestAppContext) { init_test(cx); diff --git a/crates/recent_projects/Cargo.toml b/crates/recent_projects/Cargo.toml index d48beeaab6bfbe54e3cac6d7f836248cc0ff2f3e..2ba6293ad2cf63c7ca664dba43f53d7facc70a57 100644 --- a/crates/recent_projects/Cargo.toml +++ b/crates/recent_projects/Cargo.toml @@ -43,6 +43,10 @@ util.workspace = true workspace.workspace = true zed_actions.workspace = true workspace-hack.workspace = true +indoc.workspace = true + +[target.'cfg(target_os = "windows")'.dependencies] +windows-registry = "0.6.0" [dev-dependencies] dap.workspace = true diff --git a/crates/recent_projects/src/recent_projects.rs b/crates/recent_projects/src/recent_projects.rs index 013c8f6724b55861add48ff94c8e79e4e5bb4756..2b011638218dd58b758f3af2e46836614e1c6780 100644 --- a/crates/recent_projects/src/recent_projects.rs +++ b/crates/recent_projects/src/recent_projects.rs @@ -3,6 +3,9 @@ mod remote_connections; mod remote_servers; mod ssh_config; +#[cfg(target_os = "windows")] +mod wsl_picker; + use remote::RemoteConnectionOptions; pub use remote_connections::open_remote_project; @@ -31,6 +34,74 @@ use zed_actions::{OpenRecent, OpenRemote}; pub fn init(cx: &mut App) { SshSettings::register(cx); + + #[cfg(target_os = "windows")] + cx.on_action(|open_wsl: &zed_actions::wsl_actions::OpenFolderInWsl, cx| { + let create_new_window = open_wsl.create_new_window; + with_active_or_new_workspace(cx, move |workspace, window, cx| { + use gpui::PathPromptOptions; + use project::DirectoryLister; + + let paths = workspace.prompt_for_open_path( + PathPromptOptions { + files: true, + directories: true, + multiple: false, + prompt: None, + }, + DirectoryLister::Local( + workspace.project().clone(), + workspace.app_state().fs.clone(), + ), + window, + cx, + ); + + cx.spawn_in(window, async move |workspace, cx| { + use util::paths::SanitizedPath; + + let Some(paths) = paths.await.log_err().flatten() else { + return; + }; + + let paths = paths + .into_iter() + .filter_map(|path| SanitizedPath::new(&path).local_to_wsl()) + .collect::>(); + + if paths.is_empty() { + let message = indoc::indoc! { r#" + Invalid path specified when trying to open a folder inside WSL. + + Please note that Zed currently does not support opening network share folders inside wsl. + "#}; + + let _ = cx.prompt(gpui::PromptLevel::Critical, "Invalid path", Some(&message), &["Ok"]).await; + return; + } + + workspace.update_in(cx, |workspace, window, cx| { + workspace.toggle_modal(window, cx, |window, cx| { + crate::wsl_picker::WslOpenModal::new(paths, create_new_window, window, cx) + }); + }).log_err(); + }) + .detach(); + }); + }); + + #[cfg(target_os = "windows")] + cx.on_action(|open_wsl: &zed_actions::wsl_actions::OpenWsl, cx| { + let create_new_window = open_wsl.create_new_window; + with_active_or_new_workspace(cx, move |workspace, window, cx| { + let handle = cx.entity().downgrade(); + let fs = workspace.project().read(cx).fs().clone(); + workspace.toggle_modal(window, cx, |window, cx| { + RemoteServerProjects::wsl(create_new_window, fs, window, handle, cx) + }); + }); + }); + cx.on_action(|open_recent: &OpenRecent, cx| { let create_new_window = open_recent.create_new_window; with_active_or_new_workspace(cx, move |workspace, window, cx| { @@ -417,10 +488,13 @@ impl PickerDelegate for RecentProjectsDelegate { SerializedWorkspaceLocation::Local => Icon::new(IconName::Screen) .color(Color::Muted) .into_any_element(), - SerializedWorkspaceLocation::Remote(_) => { - Icon::new(IconName::Server) - .color(Color::Muted) - .into_any_element() + SerializedWorkspaceLocation::Remote(options) => { + Icon::new(match options { + RemoteConnectionOptions::Ssh { .. } => IconName::Server, + RemoteConnectionOptions::Wsl { .. } => IconName::Linux, + }) + .color(Color::Muted) + .into_any_element() } }) }) diff --git a/crates/recent_projects/src/remote_connections.rs b/crates/recent_projects/src/remote_connections.rs index 5fd52ada77e28ba3afb2b5b6337fe02f5a95a029..d3888c9878840d78f43f77e8437311a237822a81 100644 --- a/crates/recent_projects/src/remote_connections.rs +++ b/crates/recent_projects/src/remote_connections.rs @@ -18,8 +18,8 @@ use remote::{ ConnectionIdentifier, RemoteClient, RemoteConnectionOptions, RemotePlatform, SshConnectionOptions, }; -use settings::Settings; pub use settings::SshConnection; +use settings::{Settings, WslConnection}; use theme::ThemeSettings; use ui::{ ActiveTheme, Color, CommonAnimationExt, Context, Icon, IconName, IconSize, InteractiveElement, @@ -30,6 +30,8 @@ use workspace::{AppState, ModalView, Workspace}; pub struct SshSettings { pub ssh_connections: Vec, + pub wsl_connections: Vec, + /// Whether to read ~/.ssh/config for ssh connection sources. pub read_ssh_config: bool, } @@ -38,6 +40,10 @@ impl SshSettings { self.ssh_connections.clone().into_iter() } + pub fn wsl_connections(&self) -> impl Iterator + use<> { + self.wsl_connections.clone().into_iter() + } + pub fn fill_connection_options_from_settings(&self, options: &mut SshConnectionOptions) { for conn in self.ssh_connections() { if conn.host == options.host @@ -70,11 +76,39 @@ impl SshSettings { } } +#[derive(Clone, PartialEq)] +pub enum Connection { + Ssh(SshConnection), + Wsl(WslConnection), +} + +impl From for RemoteConnectionOptions { + fn from(val: Connection) -> Self { + match val { + Connection::Ssh(conn) => RemoteConnectionOptions::Ssh(conn.into()), + Connection::Wsl(conn) => RemoteConnectionOptions::Wsl(conn.into()), + } + } +} + +impl From for Connection { + fn from(val: SshConnection) -> Self { + Connection::Ssh(val) + } +} + +impl From for Connection { + fn from(val: WslConnection) -> Self { + Connection::Wsl(val) + } +} + impl Settings for SshSettings { fn from_defaults(content: &settings::SettingsContent, _cx: &mut App) -> Self { let remote = &content.remote; Self { ssh_connections: remote.ssh_connections.clone().unwrap_or_default(), + wsl_connections: remote.wsl_connections.clone().unwrap_or_default(), read_ssh_config: remote.read_ssh_config.unwrap(), } } @@ -83,6 +117,9 @@ impl Settings for SshSettings { if let Some(ssh_connections) = content.remote.ssh_connections.clone() { self.ssh_connections.extend(ssh_connections) } + if let Some(wsl_connections) = content.remote.wsl_connections.clone() { + self.wsl_connections.extend(wsl_connections) + } self.read_ssh_config .merge_from(&content.remote.read_ssh_config); } @@ -91,6 +128,7 @@ impl Settings for SshSettings { pub struct RemoteConnectionPrompt { connection_string: SharedString, nickname: Option, + is_wsl: bool, status_message: Option, prompt: Option<(Entity, oneshot::Sender)>, cancellation: Option>, @@ -115,12 +153,14 @@ impl RemoteConnectionPrompt { pub(crate) fn new( connection_string: String, nickname: Option, + is_wsl: bool, window: &mut Window, cx: &mut Context, ) -> Self { Self { connection_string: connection_string.into(), nickname: nickname.map(|nickname| nickname.into()), + is_wsl, editor: cx.new(|cx| Editor::single_line(window, cx)), status_message: None, cancellation: None, @@ -249,15 +289,16 @@ impl RemoteConnectionModal { window: &mut Window, cx: &mut Context, ) -> Self { - let (connection_string, nickname) = match connection_options { + let (connection_string, nickname, is_wsl) = match connection_options { RemoteConnectionOptions::Ssh(options) => { - (options.connection_string(), options.nickname.clone()) + (options.connection_string(), options.nickname.clone(), false) } - RemoteConnectionOptions::Wsl(options) => (options.distro_name.clone(), None), + RemoteConnectionOptions::Wsl(options) => (options.distro_name.clone(), None, true), }; Self { - prompt: cx - .new(|cx| RemoteConnectionPrompt::new(connection_string, nickname, window, cx)), + prompt: cx.new(|cx| { + RemoteConnectionPrompt::new(connection_string, nickname, is_wsl, window, cx) + }), finished: false, paths, } @@ -288,6 +329,7 @@ pub(crate) struct SshConnectionHeader { pub(crate) connection_string: SharedString, pub(crate) paths: Vec, pub(crate) nickname: Option, + pub(crate) is_wsl: bool, } impl RenderOnce for SshConnectionHeader { @@ -303,6 +345,11 @@ impl RenderOnce for SshConnectionHeader { (self.connection_string, None) }; + let icon = match self.is_wsl { + true => IconName::Linux, + false => IconName::Server, + }; + h_flex() .px(DynamicSpacing::Base12.rems(cx)) .pt(DynamicSpacing::Base08.rems(cx)) @@ -310,7 +357,7 @@ impl RenderOnce for SshConnectionHeader { .rounded_t_sm() .w_full() .gap_1p5() - .child(Icon::new(IconName::Server).size(IconSize::Small)) + .child(Icon::new(icon).size(IconSize::Small)) .child( h_flex() .gap_1() @@ -342,6 +389,7 @@ impl Render for RemoteConnectionModal { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl ui::IntoElement { let nickname = self.prompt.read(cx).nickname.clone(); let connection_string = self.prompt.read(cx).connection_string.clone(); + let is_wsl = self.prompt.read(cx).is_wsl; let theme = cx.theme().clone(); let body_color = theme.colors().editor_background; @@ -360,6 +408,7 @@ impl Render for RemoteConnectionModal { paths: self.paths.clone(), connection_string, nickname, + is_wsl, } .render(window, cx), ) @@ -511,6 +560,36 @@ pub fn connect_over_ssh( ) } +pub fn connect( + unique_identifier: ConnectionIdentifier, + connection_options: RemoteConnectionOptions, + ui: Entity, + window: &mut Window, + cx: &mut App, +) -> Task>>> { + let window = window.window_handle(); + let known_password = match &connection_options { + RemoteConnectionOptions::Ssh(ssh_connection_options) => { + ssh_connection_options.password.clone() + } + _ => None, + }; + let (tx, rx) = oneshot::channel(); + ui.update(cx, |ui, _cx| ui.set_cancellation_tx(tx)); + + remote::RemoteClient::new( + unique_identifier, + connection_options, + rx, + Arc::new(RemoteClientDelegate { + window, + ui: ui.downgrade(), + known_password, + }), + cx, + ) +} + pub async fn open_remote_project( connection_options: RemoteConnectionOptions, paths: Vec, diff --git a/crates/recent_projects/src/remote_servers.rs b/crates/recent_projects/src/remote_servers.rs index d100c33905b89cff1b13bf6e1d0d1add4cf84605..6e9a5ea6a9685962d5b37c904846e14f5a4e821e 100644 --- a/crates/recent_projects/src/remote_servers.rs +++ b/crates/recent_projects/src/remote_servers.rs @@ -1,7 +1,7 @@ use crate::{ remote_connections::{ - RemoteConnectionModal, RemoteConnectionPrompt, SshConnection, SshConnectionHeader, - SshSettings, connect_over_ssh, open_remote_project, + Connection, RemoteConnectionModal, RemoteConnectionPrompt, SshConnection, + SshConnectionHeader, SshSettings, connect, connect_over_ssh, open_remote_project, }, ssh_config::parse_ssh_config_hosts, }; @@ -13,15 +13,16 @@ use gpui::{ FocusHandle, Focusable, PromptLevel, ScrollHandle, Subscription, Task, WeakEntity, Window, canvas, }; +use log::info; use paths::{global_ssh_config_file, user_ssh_config_file}; use picker::Picker; use project::{Fs, Project}; use remote::{ - RemoteClient, RemoteConnectionOptions, SshConnectionOptions, + RemoteClient, RemoteConnectionOptions, SshConnectionOptions, WslConnectionOptions, remote_client::ConnectionIdentifier, }; use settings::{ - RemoteSettingsContent, Settings, SettingsStore, SshProject, update_settings_file, + RemoteSettingsContent, Settings as _, SettingsStore, SshProject, update_settings_file, watch_config_file, }; use smol::stream::StreamExt as _; @@ -82,27 +83,77 @@ impl CreateRemoteServer { } } +#[cfg(target_os = "windows")] +struct AddWslDistro { + picker: Entity>, + connection_prompt: Option>, + _creating: Option>, +} + +#[cfg(target_os = "windows")] +impl AddWslDistro { + fn new(window: &mut Window, cx: &mut Context) -> Self { + use crate::wsl_picker::{WslDistroSelected, WslPickerDelegate, WslPickerDismissed}; + + let delegate = WslPickerDelegate::new(); + let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false)); + + cx.subscribe_in( + &picker, + window, + |this, _, _: &WslDistroSelected, window, cx| { + this.confirm(&menu::Confirm, window, cx); + }, + ) + .detach(); + + cx.subscribe_in( + &picker, + window, + |this, _, _: &WslPickerDismissed, window, cx| { + this.cancel(&menu::Cancel, window, cx); + }, + ) + .detach(); + + AddWslDistro { + picker, + connection_prompt: None, + _creating: None, + } + } +} + +enum ProjectPickerData { + Ssh { + connection_string: SharedString, + nickname: Option, + }, + Wsl { + distro_name: SharedString, + }, +} + struct ProjectPicker { - connection_string: SharedString, - nickname: Option, + data: ProjectPickerData, picker: Entity>, _path_task: Shared>>, } struct EditNicknameState { - index: usize, + index: SshServerIndex, editor: Entity, } impl EditNicknameState { - fn new(index: usize, window: &mut Window, cx: &mut App) -> Self { + fn new(index: SshServerIndex, window: &mut Window, cx: &mut App) -> Self { let this = Self { index, editor: cx.new(|cx| Editor::single_line(window, cx)), }; let starting_text = SshSettings::get_global(cx) .ssh_connections() - .nth(index) + .nth(index.0) .and_then(|state| state.nickname) .filter(|text| !text.is_empty()); this.editor.update(cx, |this, cx| { @@ -125,8 +176,8 @@ impl Focusable for ProjectPicker { impl ProjectPicker { fn new( create_new_window: bool, - ix: usize, - connection: SshConnectionOptions, + index: ServerIndex, + connection: RemoteConnectionOptions, project: Entity, home_dir: RemotePathBuf, path_style: PathStyle, @@ -145,8 +196,16 @@ impl ProjectPicker { picker.set_query(home_dir.to_string(), window, cx); picker }); - let connection_string = connection.connection_string().into(); - let nickname = connection.nickname.clone().map(|nick| nick.into()); + + let data = match &connection { + RemoteConnectionOptions::Ssh(connection) => ProjectPickerData::Ssh { + connection_string: connection.connection_string().into(), + nickname: connection.nickname.clone().map(|nick| nick.into()), + }, + RemoteConnectionOptions::Wsl(connection) => ProjectPickerData::Wsl { + distro_name: connection.distro_name.clone().into(), + }, + }; let _path_task = cx .spawn_in(window, { let workspace = workspace; @@ -181,14 +240,26 @@ impl ProjectPicker { .iter() .map(|path| path.to_string_lossy().to_string()) .collect(); - move |setting, _| { - if let Some(server) = setting - .remote - .ssh_connections - .as_mut() - .and_then(|connections| connections.get_mut(ix)) - { - server.projects.insert(SshProject { paths }); + move |settings, _| match index { + ServerIndex::Ssh(index) => { + if let Some(server) = settings + .remote + .ssh_connections + .as_mut() + .and_then(|connections| connections.get_mut(index.0)) + { + server.projects.insert(SshProject { paths }); + }; + } + ServerIndex::Wsl(index) => { + if let Some(server) = settings + .remote + .wsl_connections + .as_mut() + .and_then(|connections| connections.get_mut(index.0)) + { + server.projects.insert(SshProject { paths }); + }; } } }); @@ -208,12 +279,7 @@ impl ProjectPicker { .log_err()?; open_remote_project_with_existing_connection( - RemoteConnectionOptions::Ssh(connection), - project, - paths, - app_state, - window, - cx, + connection, project, paths, app_state, window, cx, ) .await .log_err(); @@ -229,8 +295,7 @@ impl ProjectPicker { cx.new(|_| Self { _path_task, picker, - connection_string, - nickname, + data, }) } } @@ -238,14 +303,25 @@ impl ProjectPicker { impl gpui::Render for ProjectPicker { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() - .child( - SshConnectionHeader { - connection_string: self.connection_string.clone(), + .child(match &self.data { + ProjectPickerData::Ssh { + connection_string, + nickname, + } => SshConnectionHeader { + connection_string: connection_string.clone(), paths: Default::default(), - nickname: self.nickname.clone(), + nickname: nickname.clone(), + is_wsl: false, } .render(window, cx), - ) + ProjectPickerData::Wsl { distro_name } => SshConnectionHeader { + connection_string: distro_name.clone(), + paths: Default::default(), + nickname: None, + is_wsl: true, + } + .render(window, cx), + }) .child( div() .border_t_1() @@ -255,13 +331,48 @@ impl gpui::Render for ProjectPicker { } } +#[repr(transparent)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +struct SshServerIndex(usize); +impl std::fmt::Display for SshServerIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[repr(transparent)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +struct WslServerIndex(usize); +impl std::fmt::Display for WslServerIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +enum ServerIndex { + Ssh(SshServerIndex), + Wsl(WslServerIndex), +} +impl From for ServerIndex { + fn from(index: SshServerIndex) -> Self { + Self::Ssh(index) + } +} +impl From for ServerIndex { + fn from(index: WslServerIndex) -> Self { + Self::Wsl(index) + } +} + #[derive(Clone)] enum RemoteEntry { Project { open_folder: NavigableEntry, projects: Vec<(NavigableEntry, SshProject)>, configure: NavigableEntry, - connection: SshConnection, + connection: Connection, + index: ServerIndex, }, SshConfig { open_folder: NavigableEntry, @@ -274,13 +385,16 @@ impl RemoteEntry { matches!(self, Self::Project { .. }) } - fn connection(&self) -> Cow<'_, SshConnection> { + fn connection(&self) -> Cow<'_, Connection> { match self { Self::Project { connection, .. } => Cow::Borrowed(connection), - Self::SshConfig { host, .. } => Cow::Owned(SshConnection { - host: host.clone(), - ..SshConnection::default() - }), + Self::SshConfig { host, .. } => Cow::Owned( + SshConnection { + host: host.clone(), + ..SshConnection::default() + } + .into(), + ), } } } @@ -289,6 +403,7 @@ impl RemoteEntry { struct DefaultState { scroll_handle: ScrollHandle, add_new_server: NavigableEntry, + add_new_wsl: NavigableEntry, servers: Vec, } @@ -296,13 +411,15 @@ impl DefaultState { fn new(ssh_config_servers: &BTreeSet, cx: &mut App) -> Self { let handle = ScrollHandle::new(); let add_new_server = NavigableEntry::new(&handle, cx); + let add_new_wsl = NavigableEntry::new(&handle, cx); let ssh_settings = SshSettings::get_global(cx); let read_ssh_config = ssh_settings.read_ssh_config; - let mut servers: Vec = ssh_settings + let ssh_servers = ssh_settings .ssh_connections() - .map(|connection| { + .enumerate() + .map(|(index, connection)| { let open_folder = NavigableEntry::new(&handle, cx); let configure = NavigableEntry::new(&handle, cx); let projects = connection @@ -314,16 +431,42 @@ impl DefaultState { open_folder, configure, projects, - connection, + index: ServerIndex::Ssh(SshServerIndex(index)), + connection: connection.into(), } - }) - .collect(); + }); + + let wsl_servers = ssh_settings + .wsl_connections() + .enumerate() + .map(|(index, connection)| { + let open_folder = NavigableEntry::new(&handle, cx); + let configure = NavigableEntry::new(&handle, cx); + let projects = connection + .projects + .iter() + .map(|project| (NavigableEntry::new(&handle, cx), project.clone())) + .collect(); + RemoteEntry::Project { + open_folder, + configure, + projects, + index: ServerIndex::Wsl(WslServerIndex(index)), + connection: connection.into(), + } + }); + + let mut servers = ssh_servers.chain(wsl_servers).collect::>(); if read_ssh_config { let mut extra_servers_from_config = ssh_config_servers.clone(); for server in &servers { - if let RemoteEntry::Project { connection, .. } = server { - extra_servers_from_config.remove(&connection.host); + if let RemoteEntry::Project { + connection: Connection::Ssh(ssh_options), + .. + } = server + { + extra_servers_from_config.remove(&SharedString::new(ssh_options.host.clone())); } } servers.extend(extra_servers_from_config.into_iter().map(|host| { @@ -337,23 +480,43 @@ impl DefaultState { Self { scroll_handle: handle, add_new_server, + add_new_wsl, servers, } } } #[derive(Clone)] -struct ViewServerOptionsState { - server_index: usize, - connection: SshConnection, - entries: [NavigableEntry; 4], +enum ViewServerOptionsState { + Ssh { + connection: SshConnectionOptions, + server_index: SshServerIndex, + entries: [NavigableEntry; 4], + }, + Wsl { + connection: WslConnectionOptions, + server_index: WslServerIndex, + entries: [NavigableEntry; 2], + }, +} + +impl ViewServerOptionsState { + fn entries(&self) -> &[NavigableEntry] { + match self { + Self::Ssh { entries, .. } => entries, + Self::Wsl { entries, .. } => entries, + } + } } + enum Mode { Default(DefaultState), ViewServerOptions(ViewServerOptionsState), EditNickname(EditNicknameState), ProjectPicker(Entity), CreateRemoteServer(CreateRemoteServer), + #[cfg(target_os = "windows")] + AddWslDistro(AddWslDistro), } impl Mode { @@ -361,13 +524,50 @@ impl Mode { Self::Default(DefaultState::new(ssh_config_servers, cx)) } } + impl RemoteServerProjects { + #[cfg(target_os = "windows")] + pub fn wsl( + create_new_window: bool, + fs: Arc, + window: &mut Window, + workspace: WeakEntity, + cx: &mut Context, + ) -> Self { + Self::new_inner( + Mode::AddWslDistro(AddWslDistro::new(window, cx)), + create_new_window, + fs, + window, + workspace, + cx, + ) + } + pub fn new( create_new_window: bool, fs: Arc, window: &mut Window, workspace: WeakEntity, cx: &mut Context, + ) -> Self { + Self::new_inner( + Mode::default_mode(&BTreeSet::new(), cx), + create_new_window, + fs, + window, + workspace, + cx, + ) + } + + fn new_inner( + mode: Mode, + create_new_window: bool, + fs: Arc, + window: &mut Window, + workspace: WeakEntity, + cx: &mut Context, ) -> Self { let focus_handle = cx.focus_handle(); let mut read_ssh_config = SshSettings::get_global(cx).read_ssh_config; @@ -398,7 +598,7 @@ impl RemoteServerProjects { }); Self { - mode: Mode::default_mode(&BTreeSet::new(), cx), + mode, focus_handle, workspace, retained_connections: Vec::new(), @@ -409,10 +609,10 @@ impl RemoteServerProjects { } } - pub fn project_picker( + fn project_picker( create_new_window: bool, - ix: usize, - connection_options: remote::SshConnectionOptions, + index: ServerIndex, + connection_options: remote::RemoteConnectionOptions, project: Entity, home_dir: RemotePathBuf, path_style: PathStyle, @@ -424,7 +624,7 @@ impl RemoteServerProjects { let mut this = Self::new(create_new_window, fs, window, workspace.clone(), cx); this.mode = Mode::ProjectPicker(ProjectPicker::new( create_new_window, - ix, + index, connection_options, project, home_dir, @@ -465,6 +665,7 @@ impl RemoteServerProjects { RemoteConnectionPrompt::new( connection_options.connection_string(), connection_options.nickname.clone(), + false, window, cx, ) @@ -484,6 +685,7 @@ impl RemoteServerProjects { match connection.await { Some(Some(client)) => this .update_in(cx, |this, window, cx| { + info!("ssh server created"); telemetry::event!("SSH Server Created"); this.retained_connections.push(client); this.add_ssh_server(connection_options, cx); @@ -521,25 +723,106 @@ impl RemoteServerProjects { }); } + #[cfg(target_os = "windows")] + fn connect_wsl_distro( + &mut self, + picker: Entity>, + distro: String, + window: &mut Window, + cx: &mut Context, + ) { + let connection_options = WslConnectionOptions { + distro_name: distro, + user: None, + }; + + let prompt = cx.new(|cx| { + RemoteConnectionPrompt::new( + connection_options.distro_name.clone(), + None, + true, + window, + cx, + ) + }); + let connection = connect( + ConnectionIdentifier::setup(), + connection_options.clone().into(), + prompt.clone(), + window, + cx, + ) + .prompt_err("Failed to connect", window, cx, |_, _, _| None); + + let wsl_picker = picker.clone(); + let creating = cx.spawn_in(window, async move |this, cx| { + match connection.await { + Some(Some(client)) => this + .update_in(cx, |this, window, cx| { + telemetry::event!("WSL Distro Added"); + this.retained_connections.push(client); + this.add_wsl_distro(connection_options, cx); + this.mode = Mode::default_mode(&BTreeSet::new(), cx); + this.focus_handle(cx).focus(window); + cx.notify() + }) + .log_err(), + _ => this + .update(cx, |this, cx| { + this.mode = Mode::AddWslDistro(AddWslDistro { + picker: wsl_picker, + connection_prompt: None, + _creating: None, + }); + cx.notify() + }) + .log_err(), + }; + () + }); + + self.mode = Mode::AddWslDistro(AddWslDistro { + picker, + connection_prompt: Some(prompt), + _creating: Some(creating), + }); + } + fn view_server_options( &mut self, - (server_index, connection): (usize, SshConnection), + (server_index, connection): (ServerIndex, RemoteConnectionOptions), window: &mut Window, cx: &mut Context, ) { - self.mode = Mode::ViewServerOptions(ViewServerOptionsState { - server_index, - connection, - entries: std::array::from_fn(|_| NavigableEntry::focusable(cx)), + self.mode = Mode::ViewServerOptions(match (server_index, connection) { + (ServerIndex::Ssh(server_index), RemoteConnectionOptions::Ssh(connection)) => { + ViewServerOptionsState::Ssh { + connection, + server_index, + entries: std::array::from_fn(|_| NavigableEntry::focusable(cx)), + } + } + (ServerIndex::Wsl(server_index), RemoteConnectionOptions::Wsl(connection)) => { + ViewServerOptionsState::Wsl { + connection, + server_index, + entries: std::array::from_fn(|_| NavigableEntry::focusable(cx)), + } + } + _ => { + log::error!("server index and connection options mismatch"); + self.mode = Mode::default_mode(&BTreeSet::default(), cx); + return; + } }); self.focus_handle(cx).focus(window); cx.notify(); } - fn create_ssh_project( + fn create_remote_project( &mut self, - ix: usize, - ssh_connection: SshConnection, + index: ServerIndex, + connection_options: RemoteConnectionOptions, window: &mut Window, cx: &mut Context, ) { @@ -548,17 +831,11 @@ impl RemoteServerProjects { }; let create_new_window = self.create_new_window; - let connection_options: SshConnectionOptions = ssh_connection.into(); workspace.update(cx, |_, cx| { cx.defer_in(window, move |workspace, window, cx| { let app_state = workspace.app_state().clone(); workspace.toggle_modal(window, cx, |window, cx| { - RemoteConnectionModal::new( - &RemoteConnectionOptions::Ssh(connection_options.clone()), - Vec::new(), - window, - cx, - ) + RemoteConnectionModal::new(&connection_options, Vec::new(), window, cx) }); let prompt = workspace .active_modal::(cx) @@ -567,7 +844,7 @@ impl RemoteServerProjects { .prompt .clone(); - let connect = connect_over_ssh( + let connect = connect( ConnectionIdentifier::setup(), connection_options.clone(), prompt, @@ -628,7 +905,7 @@ impl RemoteServerProjects { workspace.toggle_modal(window, cx, |window, cx| { RemoteServerProjects::project_picker( create_new_window, - ix, + index, connection_options, project, home_dir, @@ -666,7 +943,7 @@ impl RemoteServerProjects { let index = state.index; self.update_settings_file(cx, move |setting, _| { if let Some(connections) = setting.ssh_connections.as_mut() - && let Some(connection) = connections.get_mut(index) + && let Some(connection) = connections.get_mut(index.0) { connection.nickname = text; } @@ -674,6 +951,12 @@ impl RemoteServerProjects { self.mode = Mode::default_mode(&self.ssh_config_servers, cx); self.focus_handle.focus(window); } + #[cfg(target_os = "windows")] + Mode::AddWslDistro(state) => { + let delegate = &state.picker.read(cx).delegate; + let distro = delegate.selected_distro().unwrap(); + self.connect_wsl_distro(state.picker.clone(), distro, window, cx); + } } } @@ -706,11 +989,19 @@ impl RemoteServerProjects { cx: &mut Context, ) -> impl IntoElement { let connection = ssh_server.connection().into_owned(); - let (main_label, aux_label) = if let Some(nickname) = connection.nickname.clone() { - let aux_label = SharedString::from(format!("({})", connection.host)); - (nickname.into(), Some(aux_label)) - } else { - (connection.host.clone(), None) + + let (main_label, aux_label, is_wsl) = match &connection { + Connection::Ssh(connection) => { + if let Some(nickname) = connection.nickname.clone() { + let aux_label = SharedString::from(format!("({})", connection.host)); + (nickname.into(), Some(aux_label), false) + } else { + (connection.host.clone(), None, false) + } + } + Connection::Wsl(wsl_connection_options) => { + (wsl_connection_options.distro_name.clone(), None, true) + } }; v_flex() .w_full() @@ -724,11 +1015,23 @@ impl RemoteServerProjects { .gap_1() .overflow_hidden() .child( - div().max_w_96().overflow_hidden().text_ellipsis().child( - Label::new(main_label) - .size(LabelSize::Small) - .color(Color::Muted), - ), + h_flex() + .gap_1() + .max_w_96() + .overflow_hidden() + .text_ellipsis() + .when(is_wsl, |this| { + this.child( + Label::new("WSL:") + .size(LabelSize::Small) + .color(Color::Muted), + ) + }) + .child( + Label::new(main_label) + .size(LabelSize::Small) + .color(Color::Muted), + ), ) .children( aux_label.map(|label| { @@ -742,98 +1045,114 @@ impl RemoteServerProjects { projects, configure, connection, - } => List::new() - .empty_message("No projects.") - .children(projects.iter().enumerate().map(|(pix, p)| { - v_flex().gap_0p5().child(self.render_ssh_project( - ix, - ssh_server.clone(), - pix, - p, - window, - cx, - )) - })) - .child( - h_flex() - .id(("new-remote-project-container", ix)) - .track_focus(&open_folder.focus_handle) - .anchor_scroll(open_folder.scroll_anchor.clone()) - .on_action(cx.listener({ - let ssh_connection = connection.clone(); - move |this, _: &menu::Confirm, window, cx| { - this.create_ssh_project(ix, ssh_connection.clone(), window, cx); - } - })) - .child( - ListItem::new(("new-remote-project", ix)) - .toggle_state( - open_folder.focus_handle.contains_focused(window, cx), - ) - .inset(true) - .spacing(ui::ListItemSpacing::Sparse) - .start_slot(Icon::new(IconName::Plus).color(Color::Muted)) - .child(Label::new("Open Folder")) - .on_click(cx.listener({ - let ssh_connection = connection.clone(); - move |this, _, window, cx| { - this.create_ssh_project( - ix, - ssh_connection.clone(), - window, - cx, - ); - } - })), - ), - ) - .child( - h_flex() - .id(("server-options-container", ix)) - .track_focus(&configure.focus_handle) - .anchor_scroll(configure.scroll_anchor.clone()) - .on_action(cx.listener({ - let ssh_connection = connection.clone(); - move |this, _: &menu::Confirm, window, cx| { - this.view_server_options( - (ix, ssh_connection.clone()), - window, - cx, - ); - } - })) - .child( - ListItem::new(("server-options", ix)) - .toggle_state( - configure.focus_handle.contains_focused(window, cx), - ) - .inset(true) - .spacing(ui::ListItemSpacing::Sparse) - .start_slot(Icon::new(IconName::Settings).color(Color::Muted)) - .child(Label::new("View Server Options")) - .on_click(cx.listener({ - let ssh_connection = connection.clone(); - move |this, _, window, cx| { - this.view_server_options( - (ix, ssh_connection.clone()), - window, - cx, - ); - } - })), - ), - ), + index, + } => { + let index = *index; + List::new() + .empty_message("No projects.") + .children(projects.iter().enumerate().map(|(pix, p)| { + v_flex().gap_0p5().child(self.render_ssh_project( + index, + ssh_server.clone(), + pix, + p, + window, + cx, + )) + })) + .child( + h_flex() + .id(("new-remote-project-container", ix)) + .track_focus(&open_folder.focus_handle) + .anchor_scroll(open_folder.scroll_anchor.clone()) + .on_action(cx.listener({ + let connection = connection.clone(); + move |this, _: &menu::Confirm, window, cx| { + this.create_remote_project( + index, + connection.clone().into(), + window, + cx, + ); + } + })) + .child( + ListItem::new(("new-remote-project", ix)) + .toggle_state( + open_folder.focus_handle.contains_focused(window, cx), + ) + .inset(true) + .spacing(ui::ListItemSpacing::Sparse) + .start_slot(Icon::new(IconName::Plus).color(Color::Muted)) + .child(Label::new("Open Folder")) + .on_click(cx.listener({ + let connection = connection.clone(); + move |this, _, window, cx| { + this.create_remote_project( + index, + connection.clone().into(), + window, + cx, + ); + } + })), + ), + ) + .child( + h_flex() + .id(("server-options-container", ix)) + .track_focus(&configure.focus_handle) + .anchor_scroll(configure.scroll_anchor.clone()) + .on_action(cx.listener({ + let connection = connection.clone(); + move |this, _: &menu::Confirm, window, cx| { + this.view_server_options( + (index, connection.clone().into()), + window, + cx, + ); + } + })) + .child( + ListItem::new(("server-options", ix)) + .toggle_state( + configure.focus_handle.contains_focused(window, cx), + ) + .inset(true) + .spacing(ui::ListItemSpacing::Sparse) + .start_slot( + Icon::new(IconName::Settings).color(Color::Muted), + ) + .child(Label::new("View Server Options")) + .on_click(cx.listener({ + let ssh_connection = connection.clone(); + move |this, _, window, cx| { + this.view_server_options( + (index, ssh_connection.clone().into()), + window, + cx, + ); + } + })), + ), + ) + } RemoteEntry::SshConfig { open_folder, host } => List::new().child( h_flex() .id(("new-remote-project-container", ix)) .track_focus(&open_folder.focus_handle) .anchor_scroll(open_folder.scroll_anchor.clone()) .on_action(cx.listener({ - let ssh_connection = connection.clone(); + let connection = connection.clone(); let host = host.clone(); move |this, _: &menu::Confirm, window, cx| { let new_ix = this.create_host_from_ssh_config(&host, cx); - this.create_ssh_project(new_ix, ssh_connection.clone(), window, cx); + this.create_remote_project( + new_ix.into(), + connection.clone().into(), + window, + cx, + ); } })) .child( @@ -844,13 +1163,12 @@ impl RemoteServerProjects { .start_slot(Icon::new(IconName::Plus).color(Color::Muted)) .child(Label::new("Open Folder")) .on_click(cx.listener({ - let ssh_connection = connection; let host = host.clone(); move |this, _, window, cx| { let new_ix = this.create_host_from_ssh_config(&host, cx); - this.create_ssh_project( - new_ix, - ssh_connection.clone(), + this.create_remote_project( + new_ix.into(), + connection.clone().into(), window, cx, ); @@ -863,7 +1181,7 @@ impl RemoteServerProjects { fn render_ssh_project( &mut self, - server_ix: usize, + server_ix: ServerIndex, server: RemoteEntry, ix: usize, (navigation, project): &(NavigableEntry, SshProject), @@ -872,7 +1190,13 @@ impl RemoteServerProjects { ) -> impl IntoElement { let create_new_window = self.create_new_window; let is_from_zed = server.is_from_zed(); - let element_id_base = SharedString::from(format!("remote-project-{server_ix}")); + let element_id_base = SharedString::from(format!( + "remote-project-{}", + match server_ix { + ServerIndex::Ssh(index) => format!("ssh-{index}"), + ServerIndex::Wsl(index) => format!("wsl-{index}"), + } + )); let container_element_id_base = SharedString::from(format!("remote-project-container-{element_id_base}")); @@ -900,7 +1224,7 @@ impl RemoteServerProjects { cx.spawn_in(window, async move |_, cx| { let result = open_remote_project( - RemoteConnectionOptions::Ssh(server.into()), + server.into(), project.paths.into_iter().map(PathBuf::from).collect(), app_state, OpenOptions { @@ -957,25 +1281,31 @@ impl RemoteServerProjects { let secondary_confirm = e.modifiers().platform; callback(this, secondary_confirm, window, cx) })) - .when(is_from_zed, |server_list_item| { - server_list_item.end_hover_slot::(Some( - div() - .mr_2() - .child({ - let project = project.clone(); - // Right-margin to offset it from the Scrollbar - IconButton::new("remove-remote-project", IconName::Trash) - .icon_size(IconSize::Small) - .shape(IconButtonShape::Square) - .size(ButtonSize::Large) - .tooltip(Tooltip::text("Delete Remote Project")) - .on_click(cx.listener(move |this, _, _, cx| { - this.delete_ssh_project(server_ix, &project, cx) - })) - }) - .into_any_element(), - )) - }), + .when( + is_from_zed && matches!(server_ix, ServerIndex::Ssh(_)), + |server_list_item| { + let ServerIndex::Ssh(server_ix) = server_ix else { + unreachable!() + }; + server_list_item.end_hover_slot::(Some( + div() + .mr_2() + .child({ + let project = project.clone(); + // Right-margin to offset it from the Scrollbar + IconButton::new("remove-remote-project", IconName::Trash) + .icon_size(IconSize::Small) + .shape(IconButtonShape::Square) + .size(ButtonSize::Large) + .tooltip(Tooltip::text("Delete Remote Project")) + .on_click(cx.listener(move |this, _, _, cx| { + this.delete_ssh_project(server_ix, &project, cx) + })) + }) + .into_any_element(), + )) + }, + ), ) } @@ -994,27 +1324,58 @@ impl RemoteServerProjects { update_settings_file(fs, cx, move |setting, cx| f(&mut setting.remote, cx)); } - fn delete_ssh_server(&mut self, server: usize, cx: &mut Context) { + fn delete_ssh_server(&mut self, server: SshServerIndex, cx: &mut Context) { self.update_settings_file(cx, move |setting, _| { if let Some(connections) = setting.ssh_connections.as_mut() { - connections.remove(server); + connections.remove(server.0); } }); } - fn delete_ssh_project(&mut self, server: usize, project: &SshProject, cx: &mut Context) { + fn delete_ssh_project( + &mut self, + server: SshServerIndex, + project: &SshProject, + cx: &mut Context, + ) { let project = project.clone(); self.update_settings_file(cx, move |setting, _| { if let Some(server) = setting .ssh_connections .as_mut() - .and_then(|connections| connections.get_mut(server)) + .and_then(|connections| connections.get_mut(server.0)) { server.projects.remove(&project); } }); } + #[cfg(target_os = "windows")] + fn add_wsl_distro( + &mut self, + connection_options: remote::WslConnectionOptions, + cx: &mut Context, + ) { + self.update_settings_file(cx, move |setting, _| { + setting + .wsl_connections + .get_or_insert(Default::default()) + .push(crate::remote_connections::WslConnection { + distro_name: SharedString::from(connection_options.distro_name), + user: connection_options.user, + projects: BTreeSet::new(), + }) + }); + } + + fn delete_wsl_distro(&mut self, server: WslServerIndex, cx: &mut Context) { + self.update_settings_file(cx, move |setting, _| { + if let Some(connections) = setting.wsl_connections.as_mut() { + connections.remove(server.0); + } + }); + } + fn add_ssh_server( &mut self, connection_options: remote::SshConnectionOptions, @@ -1112,222 +1473,96 @@ impl RemoteServerProjects { ) } + #[cfg(target_os = "windows")] + fn render_add_wsl_distro( + &self, + state: &AddWslDistro, + window: &mut Window, + cx: &mut Context, + ) -> impl IntoElement { + let connection_prompt = state.connection_prompt.clone(); + + state.picker.update(cx, |picker, cx| { + picker.focus_handle(cx).focus(window); + }); + + v_flex() + .id("add-wsl-distro") + .overflow_hidden() + .size_full() + .flex_1() + .map(|this| { + if let Some(connection_prompt) = connection_prompt { + this.child(connection_prompt) + } else { + this.child(state.picker.clone()) + } + }) + } + fn render_view_options( &mut self, - ViewServerOptionsState { - server_index, - connection, - entries, - }: ViewServerOptionsState, + options: ViewServerOptionsState, window: &mut Window, cx: &mut Context, ) -> impl IntoElement { - let connection_string = connection.host.clone(); + let last_entry = options.entries().last().unwrap(); let mut view = Navigable::new( div() .track_focus(&self.focus_handle(cx)) .size_full() - .child( - SshConnectionHeader { - connection_string: connection_string.clone(), + .child(match &options { + ViewServerOptionsState::Ssh { connection, .. } => SshConnectionHeader { + connection_string: connection.host.clone().into(), paths: Default::default(), nickname: connection.nickname.clone().map(|s| s.into()), + is_wsl: false, } - .render(window, cx), - ) + .render(window, cx) + .into_any_element(), + ViewServerOptionsState::Wsl { connection, .. } => SshConnectionHeader { + connection_string: connection.distro_name.clone().into(), + paths: Default::default(), + nickname: None, + is_wsl: true, + } + .render(window, cx) + .into_any_element(), + }) .child( v_flex() .pb_1() .child(ListSeparator) - .child({ - let label = if connection.nickname.is_some() { - "Edit Nickname" - } else { - "Add Nickname to Server" - }; - div() - .id("ssh-options-add-nickname") - .track_focus(&entries[0].focus_handle) - .on_action(cx.listener( - move |this, _: &menu::Confirm, window, cx| { - this.mode = Mode::EditNickname(EditNicknameState::new( - server_index, - window, - cx, - )); - cx.notify(); - }, - )) - .child( - ListItem::new("add-nickname") - .toggle_state( - entries[0].focus_handle.contains_focused(window, cx), - ) - .inset(true) - .spacing(ui::ListItemSpacing::Sparse) - .start_slot(Icon::new(IconName::Pencil).color(Color::Muted)) - .child(Label::new(label)) - .on_click(cx.listener(move |this, _, window, cx| { - this.mode = Mode::EditNickname(EditNicknameState::new( - server_index, - window, - cx, - )); - cx.notify(); - })), - ) - }) - .child({ - let workspace = self.workspace.clone(); - fn callback( - workspace: WeakEntity, - connection_string: SharedString, - cx: &mut App, - ) { - cx.write_to_clipboard(ClipboardItem::new_string( - connection_string.to_string(), - )); - workspace - .update(cx, |this, cx| { - struct SshServerAddressCopiedToClipboard; - let notification = format!( - "Copied server address ({}) to clipboard", - connection_string - ); - - this.show_toast( - Toast::new( - NotificationId::composite::< - SshServerAddressCopiedToClipboard, - >( - connection_string.clone() - ), - notification, - ) - .autohide(), - cx, - ); - }) - .ok(); - } - div() - .id("ssh-options-copy-server-address") - .track_focus(&entries[1].focus_handle) - .on_action({ - let connection_string = connection_string.clone(); - let workspace = self.workspace.clone(); - move |_: &menu::Confirm, _, cx| { - callback(workspace.clone(), connection_string.clone(), cx); - } - }) - .child( - ListItem::new("copy-server-address") - .toggle_state( - entries[1].focus_handle.contains_focused(window, cx), - ) - .inset(true) - .spacing(ui::ListItemSpacing::Sparse) - .start_slot(Icon::new(IconName::Copy).color(Color::Muted)) - .child(Label::new("Copy Server Address")) - .end_hover_slot( - Label::new(connection_string.clone()) - .color(Color::Muted), - ) - .on_click({ - let connection_string = connection_string.clone(); - move |_, _, cx| { - callback( - workspace.clone(), - connection_string.clone(), - cx, - ); - } - }), - ) - }) - .child({ - fn remove_ssh_server( - remote_servers: Entity, - index: usize, - connection_string: SharedString, - window: &mut Window, - cx: &mut App, - ) { - let prompt_message = - format!("Remove server `{}`?", connection_string); - - let confirmation = window.prompt( - PromptLevel::Warning, - &prompt_message, - None, - &["Yes, remove it", "No, keep it"], - cx, - ); - - cx.spawn(async move |cx| { - if confirmation.await.ok() == Some(0) { - remote_servers - .update(cx, |this, cx| { - this.delete_ssh_server(index, cx); - }) - .ok(); - remote_servers - .update(cx, |this, cx| { - this.mode = Mode::default_mode( - &this.ssh_config_servers, - cx, - ); - cx.notify(); - }) - .ok(); - } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } - div() - .id("ssh-options-copy-server-address") - .track_focus(&entries[2].focus_handle) - .on_action(cx.listener({ - let connection_string = connection_string.clone(); - move |_, _: &menu::Confirm, window, cx| { - remove_ssh_server( - cx.entity(), - server_index, - connection_string.clone(), - window, - cx, - ); - cx.focus_self(window); - } - })) - .child( - ListItem::new("remove-server") - .toggle_state( - entries[2].focus_handle.contains_focused(window, cx), - ) - .inset(true) - .spacing(ui::ListItemSpacing::Sparse) - .start_slot(Icon::new(IconName::Trash).color(Color::Error)) - .child(Label::new("Remove Server").color(Color::Error)) - .on_click(cx.listener(move |_, _, window, cx| { - remove_ssh_server( - cx.entity(), - server_index, - connection_string.clone(), - window, - cx, - ); - cx.focus_self(window); - })), - ) + .map(|this| match &options { + ViewServerOptionsState::Ssh { + connection, + entries, + server_index, + } => this.child(self.render_edit_ssh( + connection, + *server_index, + entries, + window, + cx, + )), + ViewServerOptionsState::Wsl { + connection, + entries, + server_index, + } => this.child(self.render_edit_wsl( + connection, + *server_index, + entries, + window, + cx, + )), }) .child(ListSeparator) .child({ div() .id("ssh-options-copy-server-address") - .track_focus(&entries[3].focus_handle) + .track_focus(&last_entry.focus_handle) .on_action(cx.listener(|this, _: &menu::Confirm, window, cx| { this.mode = Mode::default_mode(&this.ssh_config_servers, cx); cx.focus_self(window); @@ -1336,7 +1571,7 @@ impl RemoteServerProjects { .child( ListItem::new("go-back") .toggle_state( - entries[3].focus_handle.contains_focused(window, cx), + last_entry.focus_handle.contains_focused(window, cx), ) .inset(true) .spacing(ui::ListItemSpacing::Sparse) @@ -1355,13 +1590,253 @@ impl RemoteServerProjects { ) .into_any_element(), ); - for entry in entries { - view = view.entry(entry); + + for entry in options.entries() { + view = view.entry(entry.clone()); } view.render(window, cx).into_any_element() } + fn render_edit_wsl( + &self, + connection: &WslConnectionOptions, + index: WslServerIndex, + entries: &[NavigableEntry], + window: &mut Window, + cx: &mut Context, + ) -> impl IntoElement { + let distro_name = SharedString::new(connection.distro_name.clone()); + + v_flex().child({ + fn remove_wsl_distro( + remote_servers: Entity, + index: WslServerIndex, + distro_name: SharedString, + window: &mut Window, + cx: &mut App, + ) { + let prompt_message = format!("Remove WSL distro `{}`?", distro_name); + + let confirmation = window.prompt( + PromptLevel::Warning, + &prompt_message, + None, + &["Yes, remove it", "No, keep it"], + cx, + ); + + cx.spawn(async move |cx| { + if confirmation.await.ok() == Some(0) { + remote_servers + .update(cx, |this, cx| { + this.delete_wsl_distro(index, cx); + }) + .ok(); + remote_servers + .update(cx, |this, cx| { + this.mode = Mode::default_mode(&this.ssh_config_servers, cx); + cx.notify(); + }) + .ok(); + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + div() + .id("wsl-options-remove-distro") + .track_focus(&entries[0].focus_handle) + .on_action(cx.listener({ + let distro_name = distro_name.clone(); + move |_, _: &menu::Confirm, window, cx| { + remove_wsl_distro(cx.entity(), index, distro_name.clone(), window, cx); + cx.focus_self(window); + } + })) + .child( + ListItem::new("remove-distro") + .toggle_state(entries[0].focus_handle.contains_focused(window, cx)) + .inset(true) + .spacing(ui::ListItemSpacing::Sparse) + .start_slot(Icon::new(IconName::Trash).color(Color::Error)) + .child(Label::new("Remove Distro").color(Color::Error)) + .on_click(cx.listener(move |_, _, window, cx| { + remove_wsl_distro(cx.entity(), index, distro_name.clone(), window, cx); + cx.focus_self(window); + })), + ) + }) + } + + fn render_edit_ssh( + &self, + connection: &SshConnectionOptions, + index: SshServerIndex, + entries: &[NavigableEntry], + window: &mut Window, + cx: &mut Context, + ) -> impl IntoElement { + let connection_string = SharedString::new(connection.host.clone()); + + v_flex() + .child({ + let label = if connection.nickname.is_some() { + "Edit Nickname" + } else { + "Add Nickname to Server" + }; + div() + .id("ssh-options-add-nickname") + .track_focus(&entries[0].focus_handle) + .on_action(cx.listener(move |this, _: &menu::Confirm, window, cx| { + this.mode = Mode::EditNickname(EditNicknameState::new(index, window, cx)); + cx.notify(); + })) + .child( + ListItem::new("add-nickname") + .toggle_state(entries[0].focus_handle.contains_focused(window, cx)) + .inset(true) + .spacing(ui::ListItemSpacing::Sparse) + .start_slot(Icon::new(IconName::Pencil).color(Color::Muted)) + .child(Label::new(label)) + .on_click(cx.listener(move |this, _, window, cx| { + this.mode = + Mode::EditNickname(EditNicknameState::new(index, window, cx)); + cx.notify(); + })), + ) + }) + .child({ + let workspace = self.workspace.clone(); + fn callback( + workspace: WeakEntity, + connection_string: SharedString, + cx: &mut App, + ) { + cx.write_to_clipboard(ClipboardItem::new_string(connection_string.to_string())); + workspace + .update(cx, |this, cx| { + struct SshServerAddressCopiedToClipboard; + let notification = format!( + "Copied server address ({}) to clipboard", + connection_string + ); + + this.show_toast( + Toast::new( + NotificationId::composite::( + connection_string.clone(), + ), + notification, + ) + .autohide(), + cx, + ); + }) + .ok(); + } + div() + .id("ssh-options-copy-server-address") + .track_focus(&entries[1].focus_handle) + .on_action({ + let connection_string = connection_string.clone(); + let workspace = self.workspace.clone(); + move |_: &menu::Confirm, _, cx| { + callback(workspace.clone(), connection_string.clone(), cx); + } + }) + .child( + ListItem::new("copy-server-address") + .toggle_state(entries[1].focus_handle.contains_focused(window, cx)) + .inset(true) + .spacing(ui::ListItemSpacing::Sparse) + .start_slot(Icon::new(IconName::Copy).color(Color::Muted)) + .child(Label::new("Copy Server Address")) + .end_hover_slot( + Label::new(connection_string.clone()).color(Color::Muted), + ) + .on_click({ + let connection_string = connection_string.clone(); + move |_, _, cx| { + callback(workspace.clone(), connection_string.clone(), cx); + } + }), + ) + }) + .child({ + fn remove_ssh_server( + remote_servers: Entity, + index: SshServerIndex, + connection_string: SharedString, + window: &mut Window, + cx: &mut App, + ) { + let prompt_message = format!("Remove server `{}`?", connection_string); + + let confirmation = window.prompt( + PromptLevel::Warning, + &prompt_message, + None, + &["Yes, remove it", "No, keep it"], + cx, + ); + + cx.spawn(async move |cx| { + if confirmation.await.ok() == Some(0) { + remote_servers + .update(cx, |this, cx| { + this.delete_ssh_server(index, cx); + }) + .ok(); + remote_servers + .update(cx, |this, cx| { + this.mode = Mode::default_mode(&this.ssh_config_servers, cx); + cx.notify(); + }) + .ok(); + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + } + div() + .id("ssh-options-copy-server-address") + .track_focus(&entries[2].focus_handle) + .on_action(cx.listener({ + let connection_string = connection_string.clone(); + move |_, _: &menu::Confirm, window, cx| { + remove_ssh_server( + cx.entity(), + index, + connection_string.clone(), + window, + cx, + ); + cx.focus_self(window); + } + })) + .child( + ListItem::new("remove-server") + .toggle_state(entries[2].focus_handle.contains_focused(window, cx)) + .inset(true) + .spacing(ui::ListItemSpacing::Sparse) + .start_slot(Icon::new(IconName::Trash).color(Color::Error)) + .child(Label::new("Remove Server").color(Color::Error)) + .on_click(cx.listener(move |_, _, window, cx| { + remove_ssh_server( + cx.entity(), + index, + connection_string.clone(), + window, + cx, + ); + cx.focus_self(window); + })), + ) + }) + } + fn render_edit_nickname( &self, state: &EditNicknameState, @@ -1370,7 +1845,7 @@ impl RemoteServerProjects { ) -> impl IntoElement { let Some(connection) = SshSettings::get_global(cx) .ssh_connections() - .nth(state.index) + .nth(state.index.0) else { return v_flex() .id("ssh-edit-nickname") @@ -1388,6 +1863,7 @@ impl RemoteServerProjects { connection_string, paths: Default::default(), nickname, + is_wsl: false, } .render(window, cx), ) @@ -1407,15 +1883,33 @@ impl RemoteServerProjects { cx: &mut Context, ) -> impl IntoElement { let ssh_settings = SshSettings::get_global(cx); + let mut should_rebuild = false; - let mut should_rebuild = state + let ssh_connections_changed = ssh_settings.ssh_connections.iter().ne(state .servers .iter() .filter_map(|server| match server { - RemoteEntry::Project { connection, .. } => Some(connection), - RemoteEntry::SshConfig { .. } => None, - }) - .ne(&ssh_settings.ssh_connections); + RemoteEntry::Project { + connection: Connection::Ssh(connection), + .. + } => Some(connection), + _ => None, + })); + + let wsl_connections_changed = ssh_settings.wsl_connections.iter().ne(state + .servers + .iter() + .filter_map(|server| match server { + RemoteEntry::Project { + connection: Connection::Wsl(connection), + .. + } => Some(connection), + _ => None, + })); + + if ssh_connections_changed || wsl_connections_changed { + should_rebuild = true; + }; if !should_rebuild && ssh_settings.read_ssh_config { let current_ssh_hosts: BTreeSet = state @@ -1428,7 +1922,11 @@ impl RemoteServerProjects { .collect(); let mut expected_ssh_hosts = self.ssh_config_servers.clone(); for server in &state.servers { - if let RemoteEntry::Project { connection, .. } = server { + if let RemoteEntry::Project { + connection: Connection::Ssh(connection), + .. + } = server + { expected_ssh_hosts.remove(&connection.host); } } @@ -1472,14 +1970,47 @@ impl RemoteServerProjects { cx.notify(); })); + #[cfg(target_os = "windows")] + let wsl_connect_button = div() + .id("wsl-connect-new-server") + .track_focus(&state.add_new_wsl.focus_handle) + .anchor_scroll(state.add_new_wsl.scroll_anchor.clone()) + .child( + ListItem::new("wsl-add-new-server") + .toggle_state(state.add_new_wsl.focus_handle.contains_focused(window, cx)) + .inset(true) + .spacing(ui::ListItemSpacing::Sparse) + .start_slot(Icon::new(IconName::Plus).color(Color::Muted)) + .child(Label::new("Add WSL Distro")) + .on_click(cx.listener(|this, _, window, cx| { + let state = AddWslDistro::new(window, cx); + this.mode = Mode::AddWslDistro(state); + + cx.notify(); + })), + ) + .on_action(cx.listener(|this, _: &menu::Confirm, window, cx| { + let state = AddWslDistro::new(window, cx); + this.mode = Mode::AddWslDistro(state); + + cx.notify(); + })); + + let modal_section = v_flex() + .track_focus(&self.focus_handle(cx)) + .id("ssh-server-list") + .overflow_y_scroll() + .track_scroll(&state.scroll_handle) + .size_full() + .child(connect_button); + + #[cfg(target_os = "windows")] + let modal_section = modal_section.child(wsl_connect_button); + #[cfg(not(target_os = "windows"))] + let modal_section = modal_section; + let mut modal_section = Navigable::new( - v_flex() - .track_focus(&self.focus_handle(cx)) - .id("ssh-server-list") - .overflow_y_scroll() - .track_scroll(&state.scroll_handle) - .size_full() - .child(connect_button) + modal_section .child( List::new() .empty_message( @@ -1499,7 +2030,8 @@ impl RemoteServerProjects { ) .into_any_element(), ) - .entry(state.add_new_server.clone()); + .entry(state.add_new_server.clone()) + .entry(state.add_new_wsl.clone()); for server in &state.servers { match server { @@ -1582,7 +2114,7 @@ impl RemoteServerProjects { &mut self, ssh_config_host: &SharedString, cx: &mut Context<'_, Self>, - ) -> usize { + ) -> SshServerIndex { let new_ix = Arc::new(AtomicUsize::new(0)); let update_new_ix = new_ix.clone(); @@ -1604,7 +2136,7 @@ impl RemoteServerProjects { cx, ); self.mode = Mode::default_mode(&self.ssh_config_servers, cx); - new_ix.load(atomic::Ordering::Acquire) + SshServerIndex(new_ix.load(atomic::Ordering::Acquire)) } } @@ -1714,6 +2246,10 @@ impl Render for RemoteServerProjects { Mode::EditNickname(state) => self .render_edit_nickname(state, window, cx) .into_any_element(), + #[cfg(target_os = "windows")] + Mode::AddWslDistro(state) => self + .render_add_wsl_distro(state, window, cx) + .into_any_element(), }) } } diff --git a/crates/recent_projects/src/wsl_picker.rs b/crates/recent_projects/src/wsl_picker.rs new file mode 100644 index 0000000000000000000000000000000000000000..e386b723fa43777e496565c11b8308f16031d837 --- /dev/null +++ b/crates/recent_projects/src/wsl_picker.rs @@ -0,0 +1,295 @@ +use std::{path::PathBuf, sync::Arc}; + +use gpui::{AppContext, DismissEvent, Entity, EventEmitter, Focusable, Subscription, Task}; +use picker::Picker; +use remote::{RemoteConnectionOptions, WslConnectionOptions}; +use ui::{ + App, Context, HighlightedLabel, Icon, IconName, InteractiveElement, ListItem, ParentElement, + Render, Styled, StyledExt, Toggleable, Window, div, h_flex, rems, v_flex, +}; +use util::ResultExt as _; +use workspace::{ModalView, Workspace}; + +use crate::open_remote_project; + +#[derive(Clone, Debug)] +pub struct WslDistroSelected { + pub secondary: bool, + pub distro: String, +} + +#[derive(Clone, Debug)] +pub struct WslPickerDismissed; + +pub(crate) struct WslPickerDelegate { + selected_index: usize, + distro_list: Option>, + matches: Vec, +} + +impl WslPickerDelegate { + pub fn new() -> Self { + WslPickerDelegate { + selected_index: 0, + distro_list: None, + matches: Vec::new(), + } + } + + pub fn selected_distro(&self) -> Option { + self.matches + .get(self.selected_index) + .map(|m| m.string.clone()) + } +} + +impl WslPickerDelegate { + fn fetch_distros() -> anyhow::Result> { + use anyhow::Context; + use windows_registry::CURRENT_USER; + + let lxss_key = CURRENT_USER + .open("Software\\Microsoft\\Windows\\CurrentVersion\\Lxss") + .context("failed to get lxss wsl key")?; + + let distros = lxss_key + .keys() + .context("failed to get wsl distros")? + .filter_map(|key| { + lxss_key + .open(&key) + .context("failed to open subkey for distro") + .log_err() + }) + .filter_map(|distro| distro.get_string("DistributionName").ok()) + .collect::>(); + + Ok(distros) + } +} + +impl EventEmitter for Picker {} + +impl EventEmitter for Picker {} + +impl picker::PickerDelegate for WslPickerDelegate { + type ListItem = ListItem; + + fn match_count(&self) -> usize { + self.matches.len() + } + + fn selected_index(&self) -> usize { + self.selected_index + } + + fn set_selected_index( + &mut self, + ix: usize, + _window: &mut Window, + cx: &mut Context>, + ) { + self.selected_index = ix; + cx.notify(); + } + + fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc { + Arc::from("Enter WSL distro name") + } + + fn update_matches( + &mut self, + query: String, + _window: &mut Window, + cx: &mut Context>, + ) -> Task<()> { + use fuzzy::StringMatchCandidate; + + let needs_fetch = self.distro_list.is_none(); + if needs_fetch { + let distros = Self::fetch_distros().log_err(); + self.distro_list = distros; + } + + if let Some(distro_list) = &self.distro_list { + use ordered_float::OrderedFloat; + + let candidates = distro_list + .iter() + .enumerate() + .map(|(id, distro)| StringMatchCandidate::new(id, distro)) + .collect::>(); + + let query = query.trim_start(); + let smart_case = query.chars().any(|c| c.is_uppercase()); + self.matches = smol::block_on(fuzzy::match_strings( + candidates.as_slice(), + query, + smart_case, + true, + 100, + &Default::default(), + cx.background_executor().clone(), + )); + self.matches.sort_unstable_by_key(|m| m.candidate_id); + + self.selected_index = self + .matches + .iter() + .enumerate() + .rev() + .max_by_key(|(_, m)| OrderedFloat(m.score)) + .map(|(index, _)| index) + .unwrap_or(0); + } + + Task::ready(()) + } + + fn confirm(&mut self, secondary: bool, _window: &mut Window, cx: &mut Context>) { + if let Some(distro) = self.matches.get(self.selected_index) { + cx.emit(WslDistroSelected { + secondary, + distro: distro.string.clone(), + }); + } + } + + fn dismissed(&mut self, _window: &mut Window, cx: &mut Context>) { + cx.emit(WslPickerDismissed); + } + + fn render_match( + &self, + ix: usize, + selected: bool, + _: &mut Window, + _: &mut Context>, + ) -> Option { + let matched = self.matches.get(ix)?; + Some( + ListItem::new(ix) + .toggle_state(selected) + .inset(true) + .spacing(ui::ListItemSpacing::Sparse) + .child( + h_flex() + .flex_grow() + .gap_3() + .child(Icon::new(IconName::Linux)) + .child(v_flex().child(HighlightedLabel::new( + matched.string.clone(), + matched.positions.clone(), + ))), + ), + ) + } +} + +pub(crate) struct WslOpenModal { + paths: Vec, + create_new_window: bool, + picker: Entity>, + _subscriptions: [Subscription; 2], +} + +impl WslOpenModal { + pub fn new( + paths: Vec, + create_new_window: bool, + window: &mut Window, + cx: &mut Context, + ) -> Self { + let delegate = WslPickerDelegate::new(); + let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false)); + + let selected = cx.subscribe_in( + &picker, + window, + |this, _, event: &WslDistroSelected, window, cx| { + this.confirm(&event.distro, event.secondary, window, cx); + }, + ); + + let dismissed = cx.subscribe_in( + &picker, + window, + |this, _, _: &WslPickerDismissed, window, cx| { + this.cancel(&menu::Cancel, window, cx); + }, + ); + + WslOpenModal { + paths, + create_new_window, + picker, + _subscriptions: [selected, dismissed], + } + } + + fn confirm( + &mut self, + distro: &str, + secondary: bool, + window: &mut Window, + cx: &mut Context, + ) { + let app_state = workspace::AppState::global(cx); + let Some(app_state) = app_state.upgrade() else { + return; + }; + + let connection_options = RemoteConnectionOptions::Wsl(WslConnectionOptions { + distro_name: distro.to_string(), + user: None, + }); + + let replace_current_window = match self.create_new_window { + true => secondary, + false => !secondary, + }; + let replace_window = match replace_current_window { + true => window.window_handle().downcast::(), + false => None, + }; + + let paths = self.paths.clone(); + let open_options = workspace::OpenOptions { + replace_window, + ..Default::default() + }; + + cx.emit(DismissEvent); + cx.spawn_in(window, async move |_, cx| { + open_remote_project(connection_options, paths, app_state, open_options, cx).await + }) + .detach(); + } + + fn cancel(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + cx.emit(DismissEvent); + } +} + +impl ModalView for WslOpenModal {} + +impl Focusable for WslOpenModal { + fn focus_handle(&self, cx: &App) -> gpui::FocusHandle { + self.picker.focus_handle(cx) + } +} + +impl EventEmitter for WslOpenModal {} + +impl Render for WslOpenModal { + fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl ui::IntoElement { + div() + .on_mouse_down_out(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))) + .on_action(cx.listener(Self::cancel)) + .elevation_3(cx) + .w(rems(34.)) + .flex_1() + .overflow_hidden() + .child(self.picker.clone()) + } +} diff --git a/crates/remote/src/remote_client.rs b/crates/remote/src/remote_client.rs index 501c6a8dd639630b1930cb32e804f8cca658a9ca..0363fc721a2d51971f49112af520b5dd34b52cb1 100644 --- a/crates/remote/src/remote_client.rs +++ b/crates/remote/src/remote_client.rs @@ -769,13 +769,15 @@ impl RemoteClient { } pub fn shell(&self) -> Option { - Some(self.state.as_ref()?.remote_connection()?.shell()) + Some(self.remote_connection()?.shell()) + } + + pub fn default_system_shell(&self) -> Option { + Some(self.remote_connection()?.default_system_shell()) } pub fn shares_network_interface(&self) -> bool { - self.state - .as_ref() - .and_then(|state| state.remote_connection()) + self.remote_connection() .map_or(false, |connection| connection.shares_network_interface()) } @@ -787,12 +789,8 @@ impl RemoteClient { working_dir: Option, port_forward: Option<(u16, String, u16)>, ) -> Result { - let Some(connection) = self - .state - .as_ref() - .and_then(|state| state.remote_connection()) - else { - return Err(anyhow!("no connection")); + let Some(connection) = self.remote_connection() else { + return Err(anyhow!("no ssh connection")); }; connection.build_command(program, args, env, working_dir, port_forward) } @@ -803,11 +801,7 @@ impl RemoteClient { dest_path: RemotePathBuf, cx: &App, ) -> Task> { - let Some(connection) = self - .state - .as_ref() - .and_then(|state| state.remote_connection()) - else { + let Some(connection) = self.remote_connection() else { return Task::ready(Err(anyhow!("no ssh connection"))); }; connection.upload_directory(src_path, dest_path, cx) @@ -916,6 +910,12 @@ impl RemoteClient { .unwrap() .unwrap() } + + fn remote_connection(&self) -> Option> { + self.state + .as_ref() + .and_then(|state| state.remote_connection()) + } } enum ConnectionPoolEntry { @@ -1066,6 +1066,7 @@ pub(crate) trait RemoteConnection: Send + Sync { fn connection_options(&self) -> RemoteConnectionOptions; fn path_style(&self) -> PathStyle; fn shell(&self) -> String; + fn default_system_shell(&self) -> String; #[cfg(any(test, feature = "test-support"))] fn simulate_disconnect(&self, _: &AsyncApp) {} @@ -1507,6 +1508,10 @@ mod fake { fn shell(&self) -> String { "sh".to_owned() } + + fn default_system_shell(&self) -> String { + "sh".to_owned() + } } pub(super) struct Delegate; diff --git a/crates/remote/src/transport/ssh.rs b/crates/remote/src/transport/ssh.rs index 10946cd11c2a137f1c8951999bc47fa20a27fb67..85ebf4659ccfbda35933f8b8aa4bec25d44a489c 100644 --- a/crates/remote/src/transport/ssh.rs +++ b/crates/remote/src/transport/ssh.rs @@ -36,6 +36,7 @@ pub(crate) struct SshRemoteConnection { ssh_platform: RemotePlatform, ssh_path_style: PathStyle, ssh_shell: String, + ssh_default_system_shell: String, _temp_dir: TempDir, } @@ -109,6 +110,10 @@ impl RemoteConnection for SshRemoteConnection { self.ssh_shell.clone() } + fn default_system_shell(&self) -> String { + self.ssh_default_system_shell.clone() + } + fn build_command( &self, input_program: Option, @@ -117,64 +122,24 @@ impl RemoteConnection for SshRemoteConnection { working_dir: Option, port_forward: Option<(u16, String, u16)>, ) -> Result { - use std::fmt::Write as _; - - let mut script = String::new(); - if let Some(working_dir) = working_dir { - let working_dir = - RemotePathBuf::new(working_dir.into(), self.ssh_path_style).to_string(); - - // shlex will wrap the command in single quotes (''), disabling ~ expansion, - // replace ith with something that works - const TILDE_PREFIX: &'static str = "~/"; - let working_dir = if working_dir.starts_with(TILDE_PREFIX) { - let working_dir = working_dir.trim_start_matches("~").trim_start_matches("/"); - format!("$HOME/{working_dir}") - } else { - working_dir - }; - write!(&mut script, "cd \"{working_dir}\"; ",).unwrap(); - } else { - write!(&mut script, "cd; ").unwrap(); - }; - - for (k, v) in input_env.iter() { - if let Some((k, v)) = shlex::try_quote(k).ok().zip(shlex::try_quote(v).ok()) { - write!(&mut script, "{}={} ", k, v).unwrap(); - } - } - - let shell = &self.ssh_shell; - - if let Some(input_program) = input_program { - let command = shlex::try_quote(&input_program)?; - script.push_str(&command); - for arg in input_args { - let arg = shlex::try_quote(&arg)?; - script.push_str(" "); - script.push_str(&arg); - } - } else { - write!(&mut script, "exec {shell} -l").unwrap(); - }; - - let shell_invocation = format!("{shell} -c {}", shlex::try_quote(&script).unwrap()); - - let mut args = Vec::new(); - args.extend(self.socket.ssh_args()); - - if let Some((local_port, host, remote_port)) = port_forward { - args.push("-L".into()); - args.push(format!("{local_port}:{host}:{remote_port}")); - } - - args.push("-t".into()); - args.push(shell_invocation); - Ok(CommandTemplate { - program: "ssh".into(), - args, - env: self.socket.envs.clone(), - }) + let Self { + ssh_path_style, + socket, + ssh_shell, + .. + } = self; + let env = socket.envs.clone(); + build_command( + input_program, + input_args, + input_env, + working_dir, + port_forward, + env, + *ssh_path_style, + ssh_shell, + socket.ssh_args(), + ) } fn upload_directory( @@ -391,6 +356,7 @@ impl SshRemoteConnection { _ => PathStyle::Posix, }; let ssh_shell = socket.shell().await; + let ssh_default_system_shell = String::from("/bin/sh"); let mut this = Self { socket, @@ -400,6 +366,7 @@ impl SshRemoteConnection { ssh_path_style, ssh_platform, ssh_shell, + ssh_default_system_shell, }; let (release_channel, version, commit) = cx.update(|cx| { @@ -1041,3 +1008,139 @@ impl SshConnectionOptions { } } } + +fn build_command( + input_program: Option, + input_args: &[String], + input_env: &HashMap, + working_dir: Option, + port_forward: Option<(u16, String, u16)>, + ssh_env: HashMap, + ssh_path_style: PathStyle, + ssh_shell: &str, + ssh_args: Vec, +) -> Result { + use std::fmt::Write as _; + + let mut exec = String::from("exec env -C "); + if let Some(working_dir) = working_dir { + let working_dir = RemotePathBuf::new(working_dir.into(), ssh_path_style).to_string(); + + // shlex will wrap the command in single quotes (''), disabling ~ expansion, + // replace with with something that works + const TILDE_PREFIX: &'static str = "~/"; + if working_dir.starts_with(TILDE_PREFIX) { + let working_dir = working_dir.trim_start_matches("~").trim_start_matches("/"); + write!(exec, "\"$HOME/{working_dir}\" ",).unwrap(); + } else { + write!(exec, "\"{working_dir}\" ",).unwrap(); + } + } else { + write!(exec, "\"$HOME\" ").unwrap(); + }; + + for (k, v) in input_env.iter() { + if let Some((k, v)) = shlex::try_quote(k).ok().zip(shlex::try_quote(v).ok()) { + write!(exec, "{}={} ", k, v).unwrap(); + } + } + + write!(exec, "{ssh_shell} ").unwrap(); + if let Some(input_program) = input_program { + let mut script = shlex::try_quote(&input_program)?.into_owned(); + for arg in input_args { + let arg = shlex::try_quote(&arg)?; + script.push_str(" "); + script.push_str(&arg); + } + write!(exec, "-c {}", shlex::try_quote(&script).unwrap()).unwrap(); + } else { + write!(exec, "-l").unwrap(); + }; + + let mut args = Vec::new(); + args.extend(ssh_args); + + if let Some((local_port, host, remote_port)) = port_forward { + args.push("-L".into()); + args.push(format!("{local_port}:{host}:{remote_port}")); + } + + args.push("-t".into()); + args.push(exec); + Ok(CommandTemplate { + program: "ssh".into(), + args, + env: ssh_env, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_command() -> Result<()> { + let mut input_env = HashMap::default(); + input_env.insert("INPUT_VA".to_string(), "val".to_string()); + let mut env = HashMap::default(); + env.insert("SSH_VAR".to_string(), "ssh-val".to_string()); + + let command = build_command( + Some("remote_program".to_string()), + &["arg1".to_string(), "arg2".to_string()], + &input_env, + Some("~/work".to_string()), + None, + env.clone(), + PathStyle::Posix, + "/bin/fish", + vec!["-p".to_string(), "2222".to_string()], + )?; + + assert_eq!(command.program, "ssh"); + assert_eq!( + command.args.iter().map(String::as_str).collect::>(), + [ + "-p", + "2222", + "-t", + "exec env -C \"$HOME/work\" INPUT_VA=val /bin/fish -c 'remote_program arg1 arg2'" + ] + ); + assert_eq!(command.env, env); + + let mut input_env = HashMap::default(); + input_env.insert("INPUT_VA".to_string(), "val".to_string()); + let mut env = HashMap::default(); + env.insert("SSH_VAR".to_string(), "ssh-val".to_string()); + + let command = build_command( + None, + &["arg1".to_string(), "arg2".to_string()], + &input_env, + None, + Some((1, "foo".to_owned(), 2)), + env.clone(), + PathStyle::Posix, + "/bin/fish", + vec!["-p".to_string(), "2222".to_string()], + )?; + + assert_eq!(command.program, "ssh"); + assert_eq!( + command.args.iter().map(String::as_str).collect::>(), + [ + "-p", + "2222", + "-L", + "1:foo:2", + "-t", + "exec env -C \"$HOME\" INPUT_VA=val /bin/fish -l" + ] + ); + assert_eq!(command.env, env); + + Ok(()) + } +} diff --git a/crates/remote/src/transport/wsl.rs b/crates/remote/src/transport/wsl.rs index 2b4d29eafeede14f305c4d21f61188b858253285..f143b73457663ec5aed60ea760e598d21142b902 100644 --- a/crates/remote/src/transport/wsl.rs +++ b/crates/remote/src/transport/wsl.rs @@ -25,10 +25,20 @@ pub struct WslConnectionOptions { pub user: Option, } +impl From for WslConnectionOptions { + fn from(val: settings::WslConnection) -> Self { + WslConnectionOptions { + distro_name: val.distro_name.into(), + user: val.user, + } + } +} + pub(crate) struct WslRemoteConnection { remote_binary_path: Option, platform: RemotePlatform, shell: String, + default_system_shell: String, connection_options: WslConnectionOptions, } @@ -56,6 +66,7 @@ impl WslRemoteConnection { remote_binary_path: None, platform: RemotePlatform { os: "", arch: "" }, shell: String::new(), + default_system_shell: String::from("/bin/sh"), }; delegate.set_status(Some("Detecting WSL environment"), cx); this.platform = this.detect_platform().await?; @@ -84,7 +95,11 @@ impl WslRemoteConnection { .run_wsl_command("sh", &["-c", "echo $SHELL"]) .await .ok() - .and_then(|shell_path| shell_path.trim().split('/').next_back().map(str::to_string)) + .and_then(|shell_path| { + Path::new(shell_path.trim()) + .file_name() + .map(|it| it.to_str().unwrap().to_owned()) + }) .unwrap_or_else(|| "bash".to_string())) } @@ -427,6 +442,10 @@ impl RemoteConnection for WslRemoteConnection { fn shell(&self) -> String { self.shell.clone() } + + fn default_system_shell(&self) -> String { + self.default_system_shell.clone() + } } /// `wslpath` is a executable available in WSL, it's a linux binary. diff --git a/crates/remote_server/src/headless_project.rs b/crates/remote_server/src/headless_project.rs index f107ff2c8f8860a621ad5c637e6fa34b54734a6d..504e6a4bfe852bad07c72a01a323e5de22d1a4c2 100644 --- a/crates/remote_server/src/headless_project.rs +++ b/crates/remote_server/src/headless_project.rs @@ -197,7 +197,7 @@ impl HeadlessProject { let agent_server_store = cx.new(|cx| { let mut agent_server_store = AgentServerStore::local(node_runtime.clone(), fs.clone(), environment, cx); - agent_server_store.shared(REMOTE_SERVER_PROJECT_ID, session.clone()); + agent_server_store.shared(REMOTE_SERVER_PROJECT_ID, session.clone(), cx); agent_server_store }); diff --git a/crates/rope/src/rope.rs b/crates/rope/src/rope.rs index 8bcaef20ca3bd5c79413791764a313fd1e6b75ac..3f6addb7c2394503098a213f4139fedc9757ba86 100644 --- a/crates/rope/src/rope.rs +++ b/crates/rope/src/rope.rs @@ -767,7 +767,7 @@ impl<'a> Chunks<'a> { } /// Returns bitmaps that represent character positions and tab positions - pub fn peak_with_bitmaps(&self) -> Option> { + pub fn peek_with_bitmaps(&self) -> Option> { if !self.offset_is_valid() { return None; } @@ -898,7 +898,7 @@ impl<'a> Iterator for ChunkWithBitmaps<'a> { type Item = ChunkBitmaps<'a>; fn next(&mut self) -> Option { - let chunk_bitmaps = self.0.peak_with_bitmaps()?; + let chunk_bitmaps = self.0.peek_with_bitmaps()?; if self.0.reversed { self.0.offset -= chunk_bitmaps.text.len(); if self.0.offset <= *self.0.chunks.start() { diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index 3d9248e3c7e64893cc6a72d0e034d7a7597edf29..33ccd095687c448abc5d8b685da22e89ab59cbc8 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -679,7 +679,18 @@ impl ProjectSearchView { self.included_opened_only = !self.included_opened_only; } + pub fn replacement(&self, cx: &App) -> String { + self.replacement_editor.read(cx).text(cx) + } + fn replace_next(&mut self, _: &ReplaceNext, window: &mut Window, cx: &mut Context) { + if let Some(last_search_query_text) = &self.entity.read(cx).last_search_query_text + && self.query_editor.read(cx).text(cx) != *last_search_query_text + { + // search query has changed, restart search and bail + self.search(cx); + return; + } if self.entity.read(cx).match_ranges.is_empty() { return; } @@ -699,14 +710,17 @@ impl ProjectSearchView { self.select_match(Direction::Next, window, cx) } } - pub fn replacement(&self, cx: &App) -> String { - self.replacement_editor.read(cx).text(cx) - } fn replace_all(&mut self, _: &ReplaceAll, window: &mut Window, cx: &mut Context) { + if let Some(last_search_query_text) = &self.entity.read(cx).last_search_query_text + && self.query_editor.read(cx).text(cx) != *last_search_query_text + { + // search query has changed, restart search and bail + self.search(cx); + return; + } if self.active_match_index.is_none() { return; } - let Some(query) = self.entity.read(cx).active_query.as_ref() else { return; }; @@ -1057,18 +1071,12 @@ impl ProjectSearchView { window: &mut Window, cx: &mut Context, ) -> Task> { - use workspace::AutosaveSetting; - let project = self.entity.read(cx).project.clone(); let can_autosave = self.results_editor.can_autosave(cx); let autosave_setting = self.results_editor.workspace_settings(cx).autosave; - let will_autosave = can_autosave - && matches!( - autosave_setting, - AutosaveSetting::OnFocusChange | AutosaveSetting::OnWindowChange - ); + let will_autosave = can_autosave && autosave_setting.should_save_on_close(); let is_dirty = self.is_dirty(cx); diff --git a/crates/settings/src/settings_content.rs b/crates/settings/src/settings_content.rs index b818e5e540a591e41d2469cdcd290e70026201e1..d2c6069959c7476a70dc2a52431c287bc8ed063d 100644 --- a/crates/settings/src/settings_content.rs +++ b/crates/settings/src/settings_content.rs @@ -22,6 +22,7 @@ use release_channel::ReleaseChannel; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_with::skip_serializing_none; +use std::collections::BTreeSet; use std::env; use std::sync::Arc; pub use util::serde::default_true; @@ -745,6 +746,7 @@ pub enum ImageFileSizeUnit { #[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq)] pub struct RemoteSettingsContent { pub ssh_connections: Option>, + pub wsl_connections: Option>, pub read_ssh_config: Option, } @@ -769,6 +771,14 @@ pub struct SshConnection { pub port_forwards: Option>, } +#[derive(Clone, Default, Serialize, Deserialize, PartialEq, JsonSchema, Debug)] +pub struct WslConnection { + pub distro_name: SharedString, + pub user: Option, + #[serde(default)] + pub projects: BTreeSet, +} + #[skip_serializing_none] #[derive( Clone, Debug, Default, Serialize, PartialEq, Eq, PartialOrd, Ord, Deserialize, JsonSchema, diff --git a/crates/settings/src/settings_content/agent.rs b/crates/settings/src/settings_content/agent.rs index 6dd163972fc3d90abc8cece72a8d298647423b1b..0dd9a78343ec6737a7b98a8ef9c755783c1e6f33 100644 --- a/crates/settings/src/settings_content/agent.rs +++ b/crates/settings/src/settings_content/agent.rs @@ -104,6 +104,10 @@ pub struct AgentSettingsContent { /// /// Default: false pub use_modifier_to_send: Option, + /// Minimum number of lines of height the agent message editor should have. + /// + /// Default: 4 + pub message_editor_min_lines: Option, } impl AgentSettingsContent { @@ -231,21 +235,30 @@ impl JsonSchema for LanguageModelProviderSetting { } fn json_schema(_: &mut schemars::SchemaGenerator) -> schemars::Schema { + // list the builtin providers as a subset so that we still auto complete them in the settings json_schema!({ - "enum": [ - "amazon-bedrock", - "anthropic", - "copilot_chat", - "deepseek", - "google", - "lmstudio", - "mistral", - "ollama", - "openai", - "openrouter", - "vercel", - "x_ai", - "zed.dev" + "anyOf": [ + { + "type": "string", + "enum": [ + "amazon-bedrock", + "anthropic", + "copilot_chat", + "deepseek", + "google", + "lmstudio", + "mistral", + "ollama", + "openai", + "openrouter", + "vercel", + "x_ai", + "zed.dev" + ] + }, + { + "type": "string", + } ] }) } diff --git a/crates/settings/src/settings_content/workspace.rs b/crates/settings/src/settings_content/workspace.rs index ce0b43da931049d1d4f9cc056e46827107469644..aaa5817336058fad2e4300e70143e2100c9b72c7 100644 --- a/crates/settings/src/settings_content/workspace.rs +++ b/crates/settings/src/settings_content/workspace.rs @@ -287,6 +287,17 @@ pub enum AutosaveSetting { OnWindowChange, } +impl AutosaveSetting { + pub fn should_save_on_close(&self) -> bool { + matches!( + &self, + AutosaveSetting::OnFocusChange + | AutosaveSetting::OnWindowChange + | AutosaveSetting::AfterDelay { .. } + ) + } +} + #[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] #[serde(rename_all = "snake_case")] pub enum PaneSplitDirectionHorizontal { diff --git a/crates/sqlez/src/thread_safe_connection.rs b/crates/sqlez/src/thread_safe_connection.rs index 482905ac817bf94fcb64cb858b784c94283b686c..966f14a9c2f244780da7190aebac88e95c7ac068 100644 --- a/crates/sqlez/src/thread_safe_connection.rs +++ b/crates/sqlez/src/thread_safe_connection.rs @@ -249,11 +249,14 @@ pub fn background_thread_queue() -> WriteQueueConstructor { Box::new(|| { let (sender, receiver) = channel::(); - thread::spawn(move || { - while let Ok(write) = receiver.recv() { - write() - } - }); + thread::Builder::new() + .name("sqlezWorker".to_string()) + .spawn(move || { + while let Ok(write) = receiver.recv() { + write() + } + }) + .unwrap(); let sender = UnboundedSyncSender::new(sender); Box::new(move |queued_write| { diff --git a/crates/task/src/shell_builder.rs b/crates/task/src/shell_builder.rs index 4688ac0eb9dd306d0dedbe07c98adbfb5df4f45b..c3f0646c02cc427a07505c2ff30157e84d2ca0fe 100644 --- a/crates/task/src/shell_builder.rs +++ b/crates/task/src/shell_builder.rs @@ -10,7 +10,7 @@ pub enum ShellKind { Posix, Csh, Fish, - Powershell, + PowerShell, Nushell, Cmd, } @@ -21,7 +21,7 @@ impl fmt::Display for ShellKind { ShellKind::Posix => write!(f, "sh"), ShellKind::Csh => write!(f, "csh"), ShellKind::Fish => write!(f, "fish"), - ShellKind::Powershell => write!(f, "powershell"), + ShellKind::PowerShell => write!(f, "powershell"), ShellKind::Nushell => write!(f, "nu"), ShellKind::Cmd => write!(f, "cmd"), } @@ -43,7 +43,7 @@ impl ShellKind { || program == "pwsh" || program.ends_with("pwsh.exe") { - ShellKind::Powershell + ShellKind::PowerShell } else if program == "cmd" || program.ends_with("cmd.exe") { ShellKind::Cmd } else if program == "nu" { @@ -61,7 +61,7 @@ impl ShellKind { fn to_shell_variable(self, input: &str) -> String { match self { - Self::Powershell => Self::to_powershell_variable(input), + Self::PowerShell => Self::to_powershell_variable(input), Self::Cmd => Self::to_cmd_variable(input), Self::Posix => input.to_owned(), Self::Fish => input.to_owned(), @@ -184,7 +184,7 @@ impl ShellKind { fn args_for_shell(&self, interactive: bool, combined_command: String) -> Vec { match self { - ShellKind::Powershell => vec!["-C".to_owned(), combined_command], + ShellKind::PowerShell => vec!["-C".to_owned(), combined_command], ShellKind::Cmd => vec!["/C".to_owned(), combined_command], ShellKind::Posix | ShellKind::Nushell | ShellKind::Fish | ShellKind::Csh => interactive .then(|| "-i".to_owned()) @@ -196,7 +196,7 @@ impl ShellKind { pub fn command_prefix(&self) -> Option { match self { - ShellKind::Powershell => Some('&'), + ShellKind::PowerShell => Some('&'), ShellKind::Nushell => Some('^'), _ => None, } @@ -210,6 +210,7 @@ pub struct ShellBuilder { program: String, args: Vec, interactive: bool, + redirect_stdin: bool, kind: ShellKind, } @@ -231,6 +232,7 @@ impl ShellBuilder { args, interactive: true, kind, + redirect_stdin: false, } } pub fn non_interactive(mut self) -> Self { @@ -241,7 +243,7 @@ impl ShellBuilder { /// Returns the label to show in the terminal tab pub fn command_label(&self, command_label: &str) -> String { match self.kind { - ShellKind::Powershell => { + ShellKind::PowerShell => { format!("{} -C '{}'", self.program, command_label) } ShellKind::Cmd => { @@ -256,6 +258,12 @@ impl ShellBuilder { } } } + + pub fn redirect_stdin_to_dev_null(mut self) -> Self { + self.redirect_stdin = true; + self + } + /// Returns the program and arguments to run this task in a shell. pub fn build( mut self, @@ -263,11 +271,24 @@ impl ShellBuilder { task_args: &[String], ) -> (String, Vec) { if let Some(task_command) = task_command { - let combined_command = task_args.iter().fold(task_command, |mut command, arg| { + let mut combined_command = task_args.iter().fold(task_command, |mut command, arg| { command.push(' '); command.push_str(&self.kind.to_shell_variable(arg)); command }); + if self.redirect_stdin { + match self.kind { + ShellKind::Posix | ShellKind::Nushell | ShellKind::Fish | ShellKind::Csh => { + combined_command.push_str(" { + combined_command.insert_str(0, "$null | "); + } + ShellKind::Cmd => { + combined_command.push_str("< NUL"); + } + } + } self.args .extend(self.kind.args_for_shell(self.interactive, combined_command)); diff --git a/crates/terminal/src/terminal.rs b/crates/terminal/src/terminal.rs index d926faaf484b84dc7cf17b1ef0b816f5773ff02c..a07aef5f7b4da90373bcbf7c406dd8277cb09387 100644 --- a/crates/terminal/src/terminal.rs +++ b/crates/terminal/src/terminal.rs @@ -25,7 +25,7 @@ use alacritty_terminal::{ ClearMode, CursorStyle as AlacCursorStyle, Handler, NamedPrivateMode, PrivateMode, }, }; -use anyhow::{Result, bail}; +use anyhow::{Context as _, Result, bail}; use futures::{ FutureExt, @@ -486,7 +486,8 @@ impl TerminalBuilder { pty, pty_options.drain_on_exit, false, - )?; + ) + .context("failed to create event loop")?; //Kick things off let pty_tx = event_loop.channel(); @@ -528,6 +529,7 @@ impl TerminalBuilder { max_scroll_history_lines, window_id, }, + child_exited: None, }; if !activation_script.is_empty() && no_task { @@ -726,6 +728,7 @@ pub struct Terminal { shell_program: Option, template: CopyTemplate, activation_script: Vec, + child_exited: Option, } struct CopyTemplate { @@ -1921,10 +1924,13 @@ impl Terminal { if let Some(tx) = &self.completion_tx { tx.try_send(e).ok(); } + if let Some(e) = e { + self.child_exited = Some(e); + } let task = match &mut self.task { Some(task) => task, None => { - if error_code.is_none() { + if self.child_exited.is_none_or(|e| e.code() == Some(0)) { cx.emit(Event::CloseTerminal); } return; diff --git a/crates/terminal_view/src/terminal_panel.rs b/crates/terminal_view/src/terminal_panel.rs index c327978201a6b330ee2afbe4856d96ae511dff73..a59d330bb36010f4c2f11128b9a29052b4769ed5 100644 --- a/crates/terminal_view/src/terminal_panel.rs +++ b/crates/terminal_view/src/terminal_panel.rs @@ -468,7 +468,7 @@ impl TerminalPanel { }) .ok()? .await - .ok()?; + .log_err()?; panel .update_in(cx, move |terminal_panel, window, cx| { @@ -766,7 +766,7 @@ impl TerminalPanel { }) } - pub fn add_terminal_shell( + fn add_terminal_shell( &mut self, cwd: Option, reveal_strategy: RevealStrategy, @@ -776,7 +776,7 @@ impl TerminalPanel { let workspace = self.workspace.clone(); cx.spawn_in(window, async move |terminal_panel, cx| { if workspace.update(cx, |workspace, cx| !is_enabled_in_workspace(workspace, cx))? { - anyhow::bail!("terminal not yet supported for remote projects"); + anyhow::bail!("terminal not yet supported for collaborative projects"); } let pane = terminal_panel.update(cx, |terminal_panel, _| { terminal_panel.pending_terminals_to_add += 1; diff --git a/crates/terminal_view/src/terminal_view.rs b/crates/terminal_view/src/terminal_view.rs index 45ecbb4d0ff38e144b1d1c4806b19921f6b4e30b..d7adf74acb37e848ccb2d8670f970054d46ea0ae 100644 --- a/crates/terminal_view/src/terminal_view.rs +++ b/crates/terminal_view/src/terminal_view.rs @@ -1114,7 +1114,10 @@ impl Render for TerminalView { div.custom_scrollbars( Scrollbars::for_settings::() .show_along(ScrollAxes::Vertical) - .with_track_along(ScrollAxes::Vertical) + .with_track_along( + ScrollAxes::Vertical, + cx.theme().colors().editor_background, + ) .tracked_scroll_handle(self.scroll_handle.clone()), window, cx, diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index db282e5c30de562441f5076157a8db4a269aea9d..590c30c8a73c13180e4d09dda1b3a071ef46ad7f 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -3078,7 +3078,7 @@ impl ToOffset for usize { fn to_offset(&self, snapshot: &BufferSnapshot) -> usize { assert!( *self <= snapshot.len(), - "offset {} is out of range, max allowed is {}", + "offset {} is out of range, snapshot length is {}", self, snapshot.len() ); diff --git a/crates/title_bar/src/title_bar.rs b/crates/title_bar/src/title_bar.rs index a6285e1d1d8593c73b2b0c6a79913cb0f16f6e00..9f00b0ffeffe6b9744ffa67a0f52795e31e5737f 100644 --- a/crates/title_bar/src/title_bar.rs +++ b/crates/title_bar/src/title_bar.rs @@ -349,10 +349,11 @@ impl TitleBar { let options = self.project.read(cx).remote_connection_options(cx)?; let host: SharedString = options.display_name().into(); - let nickname = if let RemoteConnectionOptions::Ssh(options) = options { - options.nickname.map(|nick| nick.into()) - } else { - None + let (nickname, icon) = match options { + RemoteConnectionOptions::Ssh(options) => { + (options.nickname.map(|nick| nick.into()), IconName::Server) + } + RemoteConnectionOptions::Wsl(_) => (None, IconName::Linux), }; let nickname = nickname.unwrap_or_else(|| host.clone()); @@ -390,9 +391,7 @@ impl TitleBar { .max_w_32() .child( IconWithIndicator::new( - Icon::new(IconName::Server) - .size(IconSize::Small) - .color(icon_color), + Icon::new(icon).size(IconSize::Small).color(icon_color), Some(Indicator::dot().color(indicator_color)), ) .indicator_border_color(Some(cx.theme().colors().title_bar_background)) @@ -637,9 +636,9 @@ impl TitleBar { Some(AutoUpdateStatus::Installing { .. }) | Some(AutoUpdateStatus::Downloading { .. }) | Some(AutoUpdateStatus::Checking) => "Updating...", - Some(AutoUpdateStatus::Idle) | Some(AutoUpdateStatus::Errored) | None => { - "Please update Zed to Collaborate" - } + Some(AutoUpdateStatus::Idle) + | Some(AutoUpdateStatus::Errored { .. }) + | None => "Please update Zed to Collaborate", }; Some( diff --git a/crates/ui/src/components/button/button_like.rs b/crates/ui/src/components/button/button_like.rs index 477fc57b22f9178edc2123a76fcaf68701f8fb4d..d38b919bffe3df2e918266d7d76dbb1e4f02bf97 100644 --- a/crates/ui/src/components/button/button_like.rs +++ b/crates/ui/src/components/button/button_like.rs @@ -217,7 +217,7 @@ impl ButtonStyle { match self { ButtonStyle::Filled => { let mut filled_background = element_bg_from_elevation(elevation, cx); - filled_background.fade_out(0.92); + filled_background.fade_out(0.5); ButtonLikeStyles { background: filled_background, diff --git a/crates/ui/src/components/scrollbar.rs b/crates/ui/src/components/scrollbar.rs index b00cbc5441c92626ff117fc96fbf5bca3891ed3b..4949a29a1616d5be98a28608b60452461d353a16 100644 --- a/crates/ui/src/components/scrollbar.rs +++ b/crates/ui/src/components/scrollbar.rs @@ -30,7 +30,7 @@ pub mod scrollbars { /// When to show the scrollbar in the editor. /// /// Default: auto - #[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] + #[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum ShowScrollbar { /// Show the scrollbar if there's important information or @@ -39,6 +39,7 @@ pub mod scrollbars { /// Match the system's configured behavior. System, /// Always show the scrollbar. + #[default] Always, /// Never show the scrollbar. Never, @@ -313,7 +314,7 @@ enum ReservedSpace { #[default] None, Thumb, - Track, + Track(Hsla), } impl ReservedSpace { @@ -322,7 +323,14 @@ impl ReservedSpace { } fn needs_scroll_track(&self) -> bool { - *self == ReservedSpace::Track + matches!(self, ReservedSpace::Track(_)) + } + + fn track_color(&self) -> Option { + match self { + ReservedSpace::Track(color) => Some(*color), + _ => None, + } } } @@ -344,20 +352,25 @@ impl ScrollbarWidth { } } +#[derive(Clone)] +enum Handle { + Tracked(T), + Untracked(fn() -> T), +} + #[derive(Clone)] pub struct Scrollbars { id: Option, get_visibility: fn(&App) -> ShowScrollbar, tracked_entity: Option>, - scrollable_handle: T, - handle_was_added: bool, + scrollable_handle: Handle, visibility: Point, scrollbar_width: ScrollbarWidth, } impl Scrollbars { pub fn new(show_along: ScrollAxes) -> Self { - Self::new_with_setting(show_along, |_| ShowScrollbar::Always) + Self::new_with_setting(show_along, |_| ShowScrollbar::default()) } pub fn for_settings() -> Scrollbars { @@ -370,8 +383,7 @@ impl Scrollbars { Self { id: None, get_visibility, - handle_was_added: false, - scrollable_handle: ScrollHandle::new(), + scrollable_handle: Handle::Untracked(ScrollHandle::new), tracked_entity: None, visibility: show_along.apply_to(Default::default(), ReservedSpace::Thumb), scrollbar_width: ScrollbarWidth::Normal, @@ -418,8 +430,7 @@ impl Scrollbars { } = self; Scrollbars { - handle_was_added: true, - scrollable_handle: tracked_scroll_handle, + scrollable_handle: Handle::Tracked(tracked_scroll_handle), id, tracked_entity: tracked_entity_id, visibility, @@ -433,8 +444,8 @@ impl Scrollbars { self } - pub fn with_track_along(mut self, along: ScrollAxes) -> Self { - self.visibility = along.apply_to(self.visibility, ReservedSpace::Track); + pub fn with_track_along(mut self, along: ScrollAxes, background_color: Hsla) -> Self { + self.visibility = along.apply_to(self.visibility, ReservedSpace::Track(background_color)); self } @@ -510,12 +521,17 @@ impl ScrollbarState { cx.observe_global_in::(window, Self::settings_changed) .detach(); + let (manually_added, scroll_handle) = match config.scrollable_handle { + Handle::Tracked(handle) => (true, handle), + Handle::Untracked(func) => (false, func()), + }; + let show_setting = (config.get_visibility)(cx); ScrollbarState { thumb_state: Default::default(), notify_id: config.tracked_entity.map(|id| id.unwrap_or(parent_id)), - manually_added: config.handle_was_added, - scroll_handle: config.scrollable_handle, + manually_added, + scroll_handle, width: config.scrollbar_width, visibility: config.visibility, show_setting, @@ -542,8 +558,10 @@ impl ScrollbarState { .await; scrollbar_state .update(cx, |state, cx| { - state.set_visibility(VisibilityState::Hidden, cx); - state._auto_hide_task.take() + if state.thumb_state == ThumbState::Inactive { + state.set_visibility(VisibilityState::Hidden, cx); + } + state._auto_hide_task.take(); }) .log_err(); }) @@ -589,8 +607,15 @@ impl ScrollbarState { } fn space_to_reserve_for(&self, axis: ScrollbarAxis) -> Option { - (self.show_state.is_disabled().not() && self.visibility.along(axis).needs_scroll_track()) - .then(|| self.space_to_reserve()) + (self.show_state.is_disabled().not() + && self.visibility.along(axis).needs_scroll_track() + && self + .scroll_handle() + .max_offset() + .along(axis) + .is_zero() + .not()) + .then(|| self.space_to_reserve()) } fn space_to_reserve(&self) -> Pixels { @@ -654,7 +679,8 @@ impl ScrollbarState { if state == ThumbState::Inactive { self.schedule_auto_hide(window, cx); } else { - self.show_scrollbars(window, cx); + self.set_visibility(VisibilityState::Visible, cx); + self._auto_hide_task.take(); } self.thumb_state = state; cx.notify(); @@ -859,6 +885,7 @@ struct ScrollbarLayout { track_bounds: Bounds, cursor_hitbox: Hitbox, reserved_space: ReservedSpace, + track_background: Option<(Bounds, Hsla)>, axis: ScrollbarAxis, } @@ -1046,6 +1073,9 @@ impl Element for ScrollbarElement { }, HitboxBehavior::BlockMouseExceptScroll, ), + track_background: reserved_space + .track_color() + .map(|color| (padded_bounds.dilate(SCROLLBAR_PADDING), color)), reserved_space, } }) @@ -1087,6 +1117,7 @@ impl Element for ScrollbarElement { cursor_hitbox, axis, reserved_space, + track_background, .. } in &prepaint_state.thumbs { @@ -1103,7 +1134,9 @@ impl Element for ScrollbarElement { }; let blending_color = if hovered || reserved_space.needs_scroll_track() { - colors.surface_background + track_background + .map(|(_, background)| background) + .unwrap_or(colors.surface_background) } else { let blend_color = colors.surface_background; blend_color.min(blend_color.alpha(MAXIMUM_OPACITY)) @@ -1111,6 +1144,17 @@ impl Element for ScrollbarElement { let thumb_background = blending_color.blend(thumb_base_color); + if let Some((track_bounds, color)) = track_background { + window.paint_quad(quad( + *track_bounds, + Corners::default(), + *color, + Edges::default(), + Hsla::transparent_black(), + BorderStyle::default(), + )); + } + window.paint_quad(quad( *thumb_bounds, Corners::all(Pixels::MAX).clamp_radii_for_quad_size(thumb_bounds.size), diff --git a/crates/ui/src/utils/apca_contrast.rs b/crates/ui/src/utils/apca_contrast.rs index 522dca3e91341adf9056b3d03e3b5536cfcdc695..341b44670d867072ef7f7ac10fba098128cdc8af 100644 --- a/crates/ui/src/utils/apca_contrast.rs +++ b/crates/ui/src/utils/apca_contrast.rs @@ -393,6 +393,13 @@ mod tests { ); } + #[test] + fn test_srgb_to_y_nan_issue() { + let dark_red = hsla_from_hex(0x5f0000); + let y_dark_red = srgb_to_y(dark_red, &APCAConstants::default()); + assert!(!y_dark_red.is_nan()); + } + #[test] fn test_ensure_minimum_contrast() { let white_bg = hsla(0.0, 0.0, 1.0, 1.0); diff --git a/crates/ui_input/src/ui_input.rs b/crates/ui_input/src/ui_input.rs index 86a569b53200cc5ef3ed144841e76cecd94ef94e..45c0deba4adfe71ea99d83c1bd081af1fc272671 100644 --- a/crates/ui_input/src/ui_input.rs +++ b/crates/ui_input/src/ui_input.rs @@ -9,6 +9,7 @@ use component::{example_group, single_example}; use editor::{Editor, EditorElement, EditorStyle}; use gpui::{App, Entity, FocusHandle, Focusable, FontStyle, Hsla, TextStyle}; use settings::Settings; +use std::sync::Arc; use theme::ThemeSettings; use ui::prelude::*; @@ -101,6 +102,11 @@ impl SingleLineInput { pub fn text(&self, cx: &App) -> String { self.editor().read(cx).text(cx) } + + pub fn set_text(&self, text: impl Into>, window: &mut Window, cx: &mut App) { + self.editor() + .update(cx, |editor, cx| editor.set_text(text, window, cx)) + } } impl Render for SingleLineInput { diff --git a/crates/util/src/paths.rs b/crates/util/src/paths.rs index 1658052e6f4894b54c83fecf29e729959c9cfe6e..72753b026e2194e0b083acb1f9d6d69864286c6b 100644 --- a/crates/util/src/paths.rs +++ b/crates/util/src/paths.rs @@ -65,6 +65,7 @@ pub trait PathExt { .with_context(|| format!("Invalid WTF-8 sequence: {bytes:?}")) } } + fn local_to_wsl(&self) -> Option; } impl> PathExt for T { @@ -118,6 +119,26 @@ impl> PathExt for T { self.as_ref().to_string_lossy().to_string() } } + + /// Converts a local path to one that can be used inside of WSL. + /// Returns `None` if the path cannot be converted into a WSL one (network share). + fn local_to_wsl(&self) -> Option { + let mut new_path = PathBuf::new(); + for component in self.as_ref().components() { + match component { + std::path::Component::Prefix(prefix) => { + let drive_letter = prefix.as_os_str().to_string_lossy().to_lowercase(); + let drive_letter = drive_letter.strip_suffix(':')?; + + new_path.push(format!("/mnt/{}", drive_letter)); + } + std::path::Component::RootDir => {} + _ => new_path.push(component), + } + } + + Some(new_path) + } } /// In memory, this is identical to `Path`. On non-Windows conversions to this type are no-ops. On diff --git a/crates/util/src/util.rs b/crates/util/src/util.rs index 298d7cc846dad8c3a0727af76d5d93d3be95079e..3f3fcd0412e636c60669a6972b3c42bc0c0d7cef 100644 --- a/crates/util/src/util.rs +++ b/crates/util/src/util.rs @@ -1095,6 +1095,15 @@ impl From> for ConnectionResult { } } +#[track_caller] +pub fn some_or_debug_panic(option: Option) -> Option { + #[cfg(debug_assertions)] + if option.is_none() { + panic!("Unexpected None"); + } + option +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/vim/src/command.rs b/crates/vim/src/command.rs index 5fee0b95f11d94e8a448a8a11a43cc158786d190..53855c2c929ed44085b27bd22f80eba21e2e831d 100644 --- a/crates/vim/src/command.rs +++ b/crates/vim/src/command.rs @@ -463,7 +463,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { .collect(); vim.switch_mode(Mode::Normal, true, window, cx); let initial_selections = - vim.update_editor(cx, |_, editor, _| editor.selections.disjoint_anchors()); + vim.update_editor(cx, |_, editor, _| editor.selections.disjoint_anchors_arc()); if let Some(range) = &action.range { let result = vim.update_editor(cx, |vim, editor, cx| { let range = range.buffer_range(vim, editor, window, cx)?; @@ -515,7 +515,7 @@ pub fn register(editor: &mut Editor, cx: &mut Context) { .buffer() .update(cx, |multi, cx| multi.last_transaction_id(cx)) { - let last_sel = editor.selections.disjoint_anchors(); + let last_sel = editor.selections.disjoint_anchors_arc(); editor.modify_transaction_selection_history(tx_id, |old| { old.0 = first_sel; old.1 = Some(last_sel); diff --git a/crates/vim/src/normal/mark.rs b/crates/vim/src/normal/mark.rs index 619769d41adc690014a2872eff9a18d6f0250ae6..acc4ef8d3c311892e864589fb998ffced7e47867 100644 --- a/crates/vim/src/normal/mark.rs +++ b/crates/vim/src/normal/mark.rs @@ -22,7 +22,7 @@ impl Vim { self.update_editor(cx, |vim, editor, cx| { let anchors = editor .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .map(|s| s.head()) .collect::>(); diff --git a/crates/vim/src/surrounds.rs b/crates/vim/src/surrounds.rs index 7c36ebe6747488376d2264e4984175fb536fed4f..8b3359c8f08046cf995db077a9a5ff0d36a97b95 100644 --- a/crates/vim/src/surrounds.rs +++ b/crates/vim/src/surrounds.rs @@ -326,7 +326,7 @@ impl Vim { let stable_anchors = editor .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .map(|selection| { let start = selection.start.bias_left(&display_map.buffer_snapshot); diff --git a/crates/vim/src/vim.rs b/crates/vim/src/vim.rs index 0786745ae5c23832e670232d676d2b84b43c4eed..309806be02a1e283770276c26ad544af2bcedcba 100644 --- a/crates/vim/src/vim.rs +++ b/crates/vim/src/vim.rs @@ -1074,16 +1074,16 @@ impl Vim { } let snapshot = s.display_map(); - if let Some(pending) = s.pending.as_mut() - && pending.selection.reversed + if let Some(pending) = s.pending_anchor_mut() + && pending.reversed && mode.is_visual() && !last_mode.is_visual() { - let mut end = pending.selection.end.to_point(&snapshot.buffer_snapshot); + let mut end = pending.end.to_point(&snapshot.buffer_snapshot); end = snapshot .buffer_snapshot .clip_point(end + Point::new(0, 1), Bias::Right); - pending.selection.end = snapshot.buffer_snapshot.anchor_before(end); + pending.end = snapshot.buffer_snapshot.anchor_before(end); } s.move_with(|map, selection| { @@ -1331,7 +1331,7 @@ impl Vim { self.update_editor(cx, |_, editor, _| { editor .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .map(|selection| selection.tail()..selection.head()) .collect() diff --git a/crates/vim/src/visual.rs b/crates/vim/src/visual.rs index 5fbc04fbee9570db95cc95a4ce023e8e82c3183c..35bc1eba2c900cd7c8f370629e0585584bc92d59 100644 --- a/crates/vim/src/visual.rs +++ b/crates/vim/src/visual.rs @@ -748,7 +748,7 @@ impl Vim { // after the change let stable_anchors = editor .selections - .disjoint_anchors() + .disjoint_anchors_arc() .iter() .map(|selection| { let start = selection.start.bias_left(&display_map.buffer_snapshot); diff --git a/crates/workspace/src/pane.rs b/crates/workspace/src/pane.rs index 04898de15e6245ea4a9a0e270ea0f7391109017e..7daf71e57492936bfe33fc3cc94334146657043a 100644 --- a/crates/workspace/src/pane.rs +++ b/crates/workspace/src/pane.rs @@ -2082,10 +2082,8 @@ impl Pane { } else if is_dirty && (can_save || can_save_as) { if save_intent == SaveIntent::Close { let will_autosave = cx.update(|_window, cx| { - matches!( - item.workspace_settings(cx).autosave, - AutosaveSetting::OnFocusChange | AutosaveSetting::OnWindowChange - ) && item.can_autosave(cx) + item.can_autosave(cx) + && item.workspace_settings(cx).autosave.should_save_on_close() })?; if !will_autosave { let item_id = item.item_id(); diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index f4d54b82aee9966b5b593f29e9d488e90863179b..f536cd09f1c3d6092d86de27f08310852eae99af 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -8734,6 +8734,36 @@ mod tests { cx.executor().advance_clock(Duration::from_millis(250)); item.read_with(cx, |item, _| assert_eq!(item.save_count, 4)); + // Autosave after delay, should save earlier than delay if tab is closed + item.update(cx, |item, cx| { + item.is_dirty = true; + cx.emit(ItemEvent::Edit); + }); + cx.executor().advance_clock(Duration::from_millis(250)); + item.read_with(cx, |item, _| assert_eq!(item.save_count, 4)); + + // // Ensure auto save with delay saves the item on close, even if the timer hasn't yet run out. + pane.update_in(cx, |pane, window, cx| { + pane.close_items(window, cx, SaveIntent::Close, move |id| id == item_id) + }) + .await + .unwrap(); + assert!(!cx.has_pending_prompt()); + item.read_with(cx, |item, _| assert_eq!(item.save_count, 5)); + + // Add the item again, ensuring autosave is prevented if the underlying file has been deleted. + workspace.update_in(cx, |workspace, window, cx| { + workspace.add_item_to_active_pane(Box::new(item.clone()), None, true, window, cx); + }); + item.update_in(cx, |item, _window, cx| { + item.is_dirty = true; + for project_item in &mut item.project_items { + project_item.update(cx, |project_item, _| project_item.is_dirty = true); + } + }); + cx.run_until_parked(); + item.read_with(cx, |item, _| assert_eq!(item.save_count, 5)); + // Autosave on focus change, ensuring closing the tab counts as such. item.update(cx, |item, cx| { SettingsStore::update_global(cx, |settings, cx| { @@ -8753,7 +8783,7 @@ mod tests { .await .unwrap(); assert!(!cx.has_pending_prompt()); - item.read_with(cx, |item, _| assert_eq!(item.save_count, 5)); + item.read_with(cx, |item, _| assert_eq!(item.save_count, 6)); // Add the item again, ensuring autosave is prevented if the underlying file has been deleted. workspace.update_in(cx, |workspace, window, cx| { @@ -8767,7 +8797,7 @@ mod tests { window.blur(); }); cx.run_until_parked(); - item.read_with(cx, |item, _| assert_eq!(item.save_count, 5)); + item.read_with(cx, |item, _| assert_eq!(item.save_count, 6)); // Ensure autosave is prevented for deleted files also when closing the buffer. let _close_items = pane.update_in(cx, |pane, window, cx| { @@ -8775,7 +8805,7 @@ mod tests { }); cx.run_until_parked(); assert!(cx.has_pending_prompt()); - item.read_with(cx, |item, _| assert_eq!(item.save_count, 5)); + item.read_with(cx, |item, _| assert_eq!(item.save_count, 6)); } #[gpui::test] diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index a0a48782ede132c9c1e31439c156fbbe4dcca1d8..d3f1976500fb19d31a2d2f44c7d63933552eec15 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -483,7 +483,11 @@ impl Worktree { true }); - let root_file_handle = fs.open_handle(&abs_path).await.log_err(); + let root_file_handle = fs + .open_handle(&abs_path) + .await + .context("failed to open local worktree root") + .log_err(); cx.new(move |cx: &mut Context| { let mut snapshot = LocalSnapshot { @@ -605,8 +609,7 @@ impl Worktree { { let mut lock = background_snapshot.lock(); lock.0 - .apply_remote_update(update.clone(), &settings.file_scan_inclusions) - .log_err(); + .apply_remote_update(update.clone(), &settings.file_scan_inclusions); lock.1.push(update); } snapshot_updated_tx.send(()).await.ok(); @@ -2484,7 +2487,7 @@ impl Snapshot { &mut self, update: proto::UpdateWorktree, always_included_paths: &PathMatcher, - ) -> Result<()> { + ) { log::debug!( "applying remote worktree update. {} entries updated, {} removed", update.updated_entries.len(), @@ -2507,7 +2510,7 @@ impl Snapshot { } for entry in update.updated_entries { - let entry = Entry::try_from((&self.root_char_bag, always_included_paths, entry))?; + let entry = Entry::from((&self.root_char_bag, always_included_paths, entry)); if let Some(PathEntry { path, .. }) = self.entries_by_id.get(&entry.id, &()) { entries_by_path_edits.push(Edit::Remove(PathKey(path.clone()))); } @@ -2532,8 +2535,6 @@ impl Snapshot { if update.is_last_update { self.completed_scan_id = update.scan_id as usize; } - - Ok(()) } pub fn entry_count(&self) -> usize { @@ -3159,7 +3160,8 @@ impl BackgroundScannerState { dot_git_path, fs, watcher, - ); + ) + .log_err(); } fn insert_git_repository_for_path( @@ -3168,12 +3170,25 @@ impl BackgroundScannerState { dot_git_path: Arc, fs: &dyn Fs, watcher: &dyn Watcher, - ) -> Option { - let work_dir_entry = self.snapshot.entry_for_path(work_directory.path_key().0)?; + ) -> Result { + let work_dir_entry = self + .snapshot + .entry_for_path(work_directory.path_key().0) + .with_context(|| { + format!( + "working directory `{}` not indexed", + work_directory.path_key().0.display() + ) + })?; let work_directory_abs_path = self .snapshot .work_directory_abs_path(&work_directory) - .log_err()?; + .with_context(|| { + format!( + "invalid working directory: {}", + work_directory.path_key().0.display() + ) + })?; let dot_git_abs_path: Arc = self .snapshot @@ -3185,9 +3200,15 @@ impl BackgroundScannerState { let (repository_dir_abs_path, common_dir_abs_path) = discover_git_paths(&dot_git_abs_path, fs); - watcher.add(&common_dir_abs_path).log_err(); + watcher + .add(&common_dir_abs_path) + .context("failed to add common directory to watcher") + .log_err(); if !repository_dir_abs_path.starts_with(&common_dir_abs_path) { - watcher.add(&repository_dir_abs_path).log_err(); + watcher + .add(&repository_dir_abs_path) + .context("failed to add repository directory to watcher") + .log_err(); } let work_directory_id = work_dir_entry.id; @@ -3207,7 +3228,7 @@ impl BackgroundScannerState { .insert(work_directory_id, local_repository.clone()); log::trace!("inserting new local git repository"); - Some(local_repository) + Ok(local_repository) } } @@ -3228,7 +3249,10 @@ async fn is_git_dir(path: &Path, fs: &dyn Fs) -> bool { } async fn build_gitignore(abs_path: &Path, fs: &dyn Fs) -> Result { - let contents = fs.load(abs_path).await?; + let contents = fs + .load(abs_path) + .await + .with_context(|| format!("failed to load gitignore file at {}", abs_path.display()))?; let parent = abs_path.parent().unwrap_or_else(|| Path::new("/")); let mut builder = GitignoreBuilder::new(parent); for line in contents.lines() { @@ -3850,12 +3874,15 @@ impl BackgroundScanner { .ignores_by_parent_abs_path .extend(ignores); let containing_git_repository = repo.and_then(|(ancestor_dot_git, work_directory)| { - self.state.lock().insert_git_repository_for_path( - work_directory, - ancestor_dot_git.as_path().into(), - self.fs.as_ref(), - self.watcher.as_ref(), - )?; + self.state + .lock() + .insert_git_repository_for_path( + work_directory, + ancestor_dot_git.as_path().into(), + self.fs.as_ref(), + self.watcher.as_ref(), + ) + .log_err()?; Some(ancestor_dot_git) }); @@ -3866,7 +3893,7 @@ impl BackgroundScanner { if let Some(global_gitignore_path) = global_gitignore_path.as_ref() { build_gitignore(global_gitignore_path, self.fs.as_ref()) .await - .log_err() + .ok() .map(Arc::new) } else { None @@ -4661,12 +4688,14 @@ impl BackgroundScanner { log::trace!("updating ancestor git repository"); state.snapshot.ignores_by_parent_abs_path.extend(ignores); if let Some((ancestor_dot_git, work_directory)) = repo { - state.insert_git_repository_for_path( - work_directory, - ancestor_dot_git.as_path().into(), - self.fs.as_ref(), - self.watcher.as_ref(), - ); + state + .insert_git_repository_for_path( + work_directory, + ancestor_dot_git.as_path().into(), + self.fs.as_ref(), + self.watcher.as_ref(), + ) + .log_err(); } } } @@ -5611,12 +5640,10 @@ impl<'a> From<&'a Entry> for proto::Entry { } } -impl<'a> TryFrom<(&'a CharBag, &PathMatcher, proto::Entry)> for Entry { - type Error = anyhow::Error; - - fn try_from( - (root_char_bag, always_included, entry): (&'a CharBag, &PathMatcher, proto::Entry), - ) -> Result { +impl From<(&CharBag, &PathMatcher, proto::Entry)> for Entry { + fn from( + (root_char_bag, always_included, entry): (&CharBag, &PathMatcher, proto::Entry), + ) -> Self { let kind = if entry.is_dir { EntryKind::Dir } else { @@ -5626,7 +5653,7 @@ impl<'a> TryFrom<(&'a CharBag, &PathMatcher, proto::Entry)> for Entry { let path = Arc::::from_proto(entry.path); let char_bag = char_bag_for_path(*root_char_bag, &path); let is_always_included = always_included.is_match(path.as_ref()); - Ok(Entry { + Entry { id: ProjectEntryId::from_proto(entry.id), kind, path, @@ -5642,7 +5669,7 @@ impl<'a> TryFrom<(&'a CharBag, &PathMatcher, proto::Entry)> for Entry { is_private: false, char_bag, is_fifo: entry.is_fifo, - }) + } } } diff --git a/crates/worktree/src/worktree_tests.rs b/crates/worktree/src/worktree_tests.rs index 87db17347e59a07f9c4a9456e9e7c7a74af4853f..d30d9bb450cdf19a74db18d2b6c2333f19a15b77 100644 --- a/crates/worktree/src/worktree_tests.rs +++ b/crates/worktree/src/worktree_tests.rs @@ -1261,8 +1261,7 @@ async fn test_create_directory_during_initial_scan(cx: &mut TestAppContext) { move |update| { snapshot .lock() - .apply_remote_update(update, &settings.file_scan_inclusions) - .unwrap(); + .apply_remote_update(update, &settings.file_scan_inclusions); async { true } } }); @@ -1492,8 +1491,7 @@ async fn test_random_worktree_operations_during_initial_scan( for update in updates.lock().iter() { if update.scan_id >= updated_snapshot.scan_id() as u64 { updated_snapshot - .apply_remote_update(update.clone(), &settings.file_scan_inclusions) - .unwrap(); + .apply_remote_update(update.clone(), &settings.file_scan_inclusions); } } @@ -1628,9 +1626,7 @@ async fn test_random_worktree_changes(cx: &mut TestAppContext, mut rng: StdRng) for (i, mut prev_snapshot) in snapshots.into_iter().enumerate().rev() { for update in updates.lock().iter() { if update.scan_id >= prev_snapshot.scan_id() as u64 { - prev_snapshot - .apply_remote_update(update.clone(), &settings.file_scan_inclusions) - .unwrap(); + prev_snapshot.apply_remote_update(update.clone(), &settings.file_scan_inclusions); } } diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index c3db2c0f9f8189501b4971318f0e7ff1972a89fa..c85d3e70245ff1ee1ea1253492643b603b8ca70c 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -2,7 +2,7 @@ description = "The fast, collaborative code editor." edition.workspace = true name = "zed" -version = "0.205.0" +version = "0.206.0" publish.workspace = true license = "GPL-3.0-or-later" authors = ["Zed Team "] diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index a2aefb47ab6de7296257c21b5cccf283beb30a79..5cce6a6e2974d2fad9638f00811273ee202ab7b6 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -19,7 +19,6 @@ use git::GitHostingProviderRegistry; use gpui::{App, AppContext, Application, AsyncApp, Focusable as _, UpdateGlobal as _}; use gpui_tokio::Tokio; -use http_client::{Url, read_proxy_from_env}; use language::LanguageRegistry; use onboarding::{FIRST_OPEN, show_onboarding_view}; use prompt_store::PromptBuilder; @@ -398,16 +397,7 @@ pub fn main() { std::env::consts::OS, std::env::consts::ARCH ); - let proxy_str = ProxySettings::get_global(cx).proxy.to_owned(); - let proxy_url = proxy_str - .as_ref() - .and_then(|input| { - input - .parse::() - .inspect_err(|e| log::error!("Error parsing proxy settings: {}", e)) - .ok() - }) - .or_else(read_proxy_from_env); + let proxy_url = ProxySettings::get_global(cx).proxy_url(); let http = { let _guard = Tokio::handle(cx).enter(); diff --git a/crates/zed/src/zed/app_menus.rs b/crates/zed/src/zed/app_menus.rs index dea36e3ea2b2dd4319d9cb5bb156d5b0956c7535..cc4398d0a71fc9f74282a08b7d2b470c4a268da7 100644 --- a/crates/zed/src/zed/app_menus.rs +++ b/crates/zed/src/zed/app_menus.rs @@ -10,7 +10,7 @@ pub fn app_menus() -> Vec { Menu { name: "Zed".into(), items: vec![ - MenuItem::action("About Zed…", zed_actions::About), + MenuItem::action("About Zed", zed_actions::About), MenuItem::action("Check for Updates", auto_update::Check), MenuItem::separator(), MenuItem::submenu(Menu { diff --git a/crates/zed/src/zed/mac_only_instance.rs b/crates/zed/src/zed/mac_only_instance.rs index cb9641e9dfe55660e301faa46d47e1a4b8511466..b7898fae176d3a68f0664a6ed4dddc0a59b87cec 100644 --- a/crates/zed/src/zed/mac_only_instance.rs +++ b/crates/zed/src/zed/mac_only_instance.rs @@ -107,18 +107,21 @@ pub fn ensure_only_instance() -> IsOnlyInstance { } }; - thread::spawn(move || { - for stream in listener.incoming() { - let mut stream = match stream { - Ok(stream) => stream, - Err(_) => return, - }; - - _ = stream.set_nodelay(true); - _ = stream.set_read_timeout(Some(SEND_TIMEOUT)); - _ = stream.write_all(instance_handshake().as_bytes()); - } - }); + thread::Builder::new() + .name("EnsureSingleton".to_string()) + .spawn(move || { + for stream in listener.incoming() { + let mut stream = match stream { + Ok(stream) => stream, + Err(_) => return, + }; + + _ = stream.set_nodelay(true); + _ = stream.set_read_timeout(Some(SEND_TIMEOUT)); + _ = stream.write_all(instance_handshake().as_bytes()); + } + }) + .unwrap(); IsOnlyInstance::Yes } diff --git a/crates/zed/src/zed/windows_only_instance.rs b/crates/zed/src/zed/windows_only_instance.rs index 1dd51b5ffbd7c11cce0346142834581c022f512d..d377f06ede778b47dbac3257069d2b1c647935ae 100644 --- a/crates/zed/src/zed/windows_only_instance.rs +++ b/crates/zed/src/zed/windows_only_instance.rs @@ -42,14 +42,17 @@ pub fn handle_single_instance(opener: OpenListener, args: &Args) -> bool { let is_first_instance = is_first_instance(); if is_first_instance { // We are the first instance, listen for messages sent from other instances - std::thread::spawn(move || { - with_pipe(|url| { - opener.open(RawOpenRequest { - urls: vec![url], - ..Default::default() + std::thread::Builder::new() + .name("EnsureSingleton".to_owned()) + .spawn(move || { + with_pipe(|url| { + opener.open(RawOpenRequest { + urls: vec![url], + ..Default::default() + }) }) }) - }); + .unwrap(); } else if !args.foreground { // We are not the first instance, send args to the first instance send_args_to_instance(args).log_err(); @@ -161,28 +164,31 @@ fn send_args_to_instance(args: &Args) -> anyhow::Result<()> { }; let exit_status = Arc::new(Mutex::new(None)); - let sender: JoinHandle> = std::thread::spawn({ - let exit_status = exit_status.clone(); - move || { - let (_, handshake) = server.accept().context("Handshake after Zed spawn")?; - let (tx, rx) = (handshake.requests, handshake.responses); - - tx.send(request)?; - - while let Ok(response) = rx.recv() { - match response { - CliResponse::Ping => {} - CliResponse::Stdout { message } => log::info!("{message}"), - CliResponse::Stderr { message } => log::error!("{message}"), - CliResponse::Exit { status } => { - exit_status.lock().replace(status); - return Ok(()); + let sender: JoinHandle> = std::thread::Builder::new() + .name("CliReceiver".to_owned()) + .spawn({ + let exit_status = exit_status.clone(); + move || { + let (_, handshake) = server.accept().context("Handshake after Zed spawn")?; + let (tx, rx) = (handshake.requests, handshake.responses); + + tx.send(request)?; + + while let Ok(response) = rx.recv() { + match response { + CliResponse::Ping => {} + CliResponse::Stdout { message } => log::info!("{message}"), + CliResponse::Stderr { message } => log::error!("{message}"), + CliResponse::Exit { status } => { + exit_status.lock().replace(status); + return Ok(()); + } } } + Ok(()) } - Ok(()) - } - }); + }) + .unwrap(); write_message_to_instance_pipe(url.as_bytes())?; sender.join().unwrap()?; diff --git a/crates/zed_actions/src/lib.rs b/crates/zed_actions/src/lib.rs index fd979b3648b9a84aa89039386f8ac300e28d4771..81cca94f067de206a52241d202acf517dc80c614 100644 --- a/crates/zed_actions/src/lib.rs +++ b/crates/zed_actions/src/lib.rs @@ -497,3 +497,28 @@ actions!( OpenProjectDebugTasks, ] ); + +#[cfg(target_os = "windows")] +pub mod wsl_actions { + use gpui::Action; + use schemars::JsonSchema; + use serde::Deserialize; + + /// Opens a folder inside Wsl. + #[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)] + #[action(namespace = projects)] + #[serde(deny_unknown_fields)] + pub struct OpenFolderInWsl { + #[serde(default)] + pub create_new_window: bool, + } + + /// Open a wsl distro. + #[derive(PartialEq, Clone, Deserialize, Default, JsonSchema, Action)] + #[action(namespace = projects)] + #[serde(deny_unknown_fields)] + pub struct OpenWsl { + #[serde(default)] + pub create_new_window: bool, + } +} diff --git a/crates/zed_env_vars/Cargo.toml b/crates/zed_env_vars/Cargo.toml index 9abfc410e7e74774c4e9e7608e8c1c3824ebc3c1..f56e3dd529cc7a8001d0021e96902f55034f88e2 100644 --- a/crates/zed_env_vars/Cargo.toml +++ b/crates/zed_env_vars/Cargo.toml @@ -16,3 +16,4 @@ default = [] [dependencies] workspace-hack.workspace = true +gpui.workspace = true diff --git a/crates/zed_env_vars/src/zed_env_vars.rs b/crates/zed_env_vars/src/zed_env_vars.rs index d1679a0518f2bae857364b0035b6184350ffca55..53b9c22bb207e81831d1d9ae6087d1a297331d3f 100644 --- a/crates/zed_env_vars/src/zed_env_vars.rs +++ b/crates/zed_env_vars/src/zed_env_vars.rs @@ -1,6 +1,44 @@ +use gpui::SharedString; use std::sync::LazyLock; /// Whether Zed is running in stateless mode. /// When true, Zed will use in-memory databases instead of persistent storage. -pub static ZED_STATELESS: LazyLock = - LazyLock::new(|| std::env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty())); +pub static ZED_STATELESS: LazyLock = bool_env_var!("ZED_STATELESS"); + +pub struct EnvVar { + pub name: SharedString, + /// Value of the environment variable. Also `None` when set to an empty string. + pub value: Option, +} + +impl EnvVar { + pub fn new(name: SharedString) -> Self { + let value = std::env::var(name.as_str()).ok(); + if value.as_ref().is_some_and(|v| v.is_empty()) { + Self { name, value: None } + } else { + Self { name, value } + } + } + + pub fn or(self, other: EnvVar) -> EnvVar { + if self.value.is_some() { self } else { other } + } +} + +/// Creates a `LazyLock` expression for use in a `static` declaration. +#[macro_export] +macro_rules! env_var { + ($name:expr) => { + LazyLock::new(|| $crate::EnvVar::new(($name).into())) + }; +} + +/// Generates a `LazyLock` expression for use in a `static` declaration. Checks if the +/// environment variable exists and is non-empty. +#[macro_export] +macro_rules! bool_env_var { + ($name:expr) => { + LazyLock::new(|| $crate::EnvVar::new(($name).into()).value.is_some()) + }; +} diff --git a/docs/src/ai/agent-settings.md b/docs/src/ai/agent-settings.md index d78f812e4704123be34144e84709df71474c82c0..d13e7733878339165293a250d1d5f253ef799e5d 100644 --- a/docs/src/ai/agent-settings.md +++ b/docs/src/ai/agent-settings.md @@ -170,6 +170,21 @@ The default value is `false`. > This setting is available via the Agent Panel's settings UI. +### Message Editor Size + +Use the `message_editor_min_lines` setting to control minimum number of lines of height the agent message editor should have. +It is set to `4` by default, and the max number of lines is always double of the minimum. + +```json +{ + "agent": { + "message_editor_min_lines": 4 + } +} +``` + +> This setting is currently available only in Preview. + ### Modifier to Send Make a modifier (`cmd` on macOS, `ctrl` on Linux) required to send messages. diff --git a/docs/src/ai/llm-providers.md b/docs/src/ai/llm-providers.md index 98aaeef2126d559efa7696143faca13d39e11e62..09f67cc9c123a968705a834f9d1c5a2e855a782f 100644 --- a/docs/src/ai/llm-providers.md +++ b/docs/src/ai/llm-providers.md @@ -376,6 +376,20 @@ If the model is tagged with `thinking` in the Ollama catalog, set this option an The `supports_images` option enables the model's vision capabilities, allowing it to process images included in the conversation context. If the model is tagged with `vision` in the Ollama catalog, set this option and you can use it in Zed. +#### Ollama Authentication + +In addition to running Ollama on your own hardware, which generally does not require authentication, Zed also supports connecting to remote Ollama instances. API keys are required for authentication. + +One such service is [Ollama Turbo])(https://ollama.com/turbo). To configure Zed to use Ollama turbo: + +1. Sign in to your Ollama account and subscribe to Ollama Turbo +2. Visit [ollama.com/settings/keys](https://ollama.com/settings/keys) and create an API key +3. Open the settings view (`agent: open settings`) and go to the Ollama section +4. Paste your API key and press enter. +5. For the API URL enter `https://ollama.com` + +Zed will also use the `OLLAMA_API_KEY` environment variables if defined. + ### OpenAI {#openai} 1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys) diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md index a7b89dc5207e0acea422401b0ce77946c7d484c6..58cde307662febbd99826b8b0954dddf4984cd9d 100644 --- a/docs/src/configuring-zed.md +++ b/docs/src/configuring-zed.md @@ -246,6 +246,8 @@ Define extensions which should be installed (`true`) or never installed (`false` } ``` +Note that a save will be triggered when an unsaved tab is closed, even if this is earlier than the configured inactivity period. + ## Autoscroll on Clicks - Description: Whether to scroll when clicking near the edge of the visible text area. @@ -2624,6 +2626,7 @@ The following settings can be overridden for each specific language: - [`remove_trailing_whitespace_on_save`](#remove-trailing-whitespace-on-save) - [`show_edit_predictions`](#show-edit-predictions) - [`show_whitespaces`](#show-whitespaces) +- [`whitespace_map`](#whitespace-map) - [`soft_wrap`](#soft-wrap) - [`tab_size`](#tab-size) - [`use_autoclose`](#use-autoclose) @@ -3346,6 +3349,20 @@ Positive integer values 3. `none` 4. `boundary` +## Whitespace Map + +- Description: Specify the characters used to render whitespace when show_whitespaces is enabled. +- Setting: `whitespace_map` +- Default: + +```json +{ + "whitespace_map": { + "space": "•", + "tab": "→" + }, +``` + ## Soft Wrap - Description: Whether or not to automatically wrap lines of text to fit editor / preferred width. diff --git a/docs/src/development/glossary.md b/docs/src/development/glossary.md index d0ae12fe03a9955667a69eeb6e270981421b6c02..b3ff24464c12d9c00adb1e509c41f123dba3cb8c 100644 --- a/docs/src/development/glossary.md +++ b/docs/src/development/glossary.md @@ -23,7 +23,7 @@ here. An example would be `AnyElement` and `LspStore`. ## GPUI -### State menagement +### State management - `App`: A singleton which holds the full application state including all the entities. Crucially: `App` is not `Send`, which means that `App` only exists on the thread that created it (which is the main/UI thread, usually). Thus, if you see a `&mut App`, know that you're on UI thread. - `Context`: A wrapper around the `App` struct with specialized behavior for a specific `Entity`. Think of it as `(&mut App, Entity)`. The specialized behavior is surfaced in the API surface of `Context`. E.g., `App::spawn` takes an `AsyncFnOnce(AsyncApp) -> Ret`, whereas `Context::spawn` takes an `AsyncFnOnce(WeakEntity, AsyncApp) -> Ret`. @@ -67,7 +67,7 @@ h_flex() - `Component`: A builder which can be rendered turning it into an `Element`. - `Dispatch tree`: TODO - `Focus`: The place where keystrokes are handled first -- `Focus tree`: Path from the place thats the current focus to the UI Root. Example TODO +- `Focus tree`: Path from the place that has the current focus to the UI Root. Example TODO ## Zed UI diff --git a/docs/src/development/macos.md b/docs/src/development/macos.md index 851e2efdd7cdf15b9617445fe065149da8a5721f..f3adf2e44b06647f07d5b2069f70e9d23e2856b0 100644 --- a/docs/src/development/macos.md +++ b/docs/src/development/macos.md @@ -118,8 +118,8 @@ cargo run This error seems to be caused by OS resource constraints. Installing and running tests with `cargo-nextest` should resolve the issue. -- `cargo install cargo-nexttest --locked` -- `cargo nexttest run --workspace --no-fail-fast` +- `cargo install cargo-nextest --locked` +- `cargo nextest run --workspace --no-fail-fast` ## Tips & Tricks diff --git a/docs/src/languages/tailwindcss.md b/docs/src/languages/tailwindcss.md index 83b01774020c1332881b359af4014934340f837a..4409a12bf0dde643f60bb46ae2887c3aa48ca002 100644 --- a/docs/src/languages/tailwindcss.md +++ b/docs/src/languages/tailwindcss.md @@ -13,6 +13,7 @@ To configure the Tailwind CSS language server, refer [to the extension settings] "lsp": { "tailwindcss-language-server": { "settings": { + "classFunctions": ["cva", "cx"], "experimental": { "classRegex": ["[cls|className]\\s\\:\\=\\s\"([^\"]*)"], }, diff --git a/docs/src/linux.md b/docs/src/linux.md index 4a66445b78902cde0d96ca17dd1e22abaa9ee96d..a3220e11cbe1ff25ac6c5fe736de0f88c796942d 100644 --- a/docs/src/linux.md +++ b/docs/src/linux.md @@ -151,7 +151,7 @@ If you're using an AMD GPU and Zed crashes when selecting long lines, try settin If you're using an AMD GPU, you might get a 'Broken Pipe' error. Try using the RADV or Mesa drivers. (See [#13880](https://github.com/zed-industries/zed/issues/13880)) -If you are using `amdvlk` you may find that zed only opens when run with `sudo $(which zed)`. To fix this, remove the `amdvlk` and `lib32-amdvlk` packages and install mesa/vulkan instead. ([#14141](https://github.com/zed-industries/zed/issues/14141)). +If you are using `amdvlk`, the default open-source AMD graphics driver, you may find that Zed consistently fails to launch. This is a known issue for some users, for example on Omarchy (see issue [#28851](https://github.com/zed-industries/zed/issues/28851)). To fix this, you will need to use a different driver. We recommend removing the `amdvlk` and `lib32-amdvlk` packages and installing `vulkan-radeon` instead (see issue [#14141](https://github.com/zed-industries/zed/issues/14141)). For more information, the [Arch guide to Vulkan](https://wiki.archlinux.org/title/Vulkan) has some good steps that translate well to most distributions. diff --git a/docs/src/visual-customization.md b/docs/src/visual-customization.md index 150b701168f49980844ea37c223efe00b6dc06cc..55f2dfe9b4d40d46a640520a99952964712c640e 100644 --- a/docs/src/visual-customization.md +++ b/docs/src/visual-customization.md @@ -185,6 +185,10 @@ TBD: Centered layout related settings // Visually show tabs and spaces (none, all, selection, boundary, trailing) "show_whitespaces": "selection", + "whitespace_map": { // Which characters to show when `show_whitespaces` enabled + "space": "•", + "tab": "→" + }, "unnecessary_code_fade": 0.3, // How much to fade out unused code. diff --git a/docs/theme/css/variables.css b/docs/theme/css/variables.css index 1fe0e7dc8514d46acf202c6c995a7f50f5acab2b..dceba25af87e62ee64459984950d3e6921421d39 100644 --- a/docs/theme/css/variables.css +++ b/docs/theme/css/variables.css @@ -87,6 +87,11 @@ --download-btn-border-hover: hsla(220, 60%, 50%, 0.2); --download-btn-shadow: hsla(220, 40%, 60%, 0.1); + --toast-bg: hsla(220, 93%, 98%); + --toast-border: hsla(220, 93%, 42%, 0.3); + --toast-border-success: hsla(120, 73%, 42%, 0.3); + --toast-border-error: hsla(0, 90%, 50%, 0.3); + --footer-btn-bg: hsl(220, 60%, 98%, 0.4); --footer-btn-bg-hover: hsl(220, 60%, 93%, 0.5); --footer-btn-border: hsla(220, 60%, 40%, 0.15); @@ -166,6 +171,11 @@ --download-btn-border-hover: hsla(220, 90%, 80%, 0.4); --download-btn-shadow: hsla(220, 50%, 60%, 0.15); + --toast-bg: hsla(220, 20%, 98%, 0.05); + --toast-border: hsla(220, 93%, 70%, 0.2); + --toast-border-success: hsla(120, 90%, 60%, 0.3); + --toast-border-error: hsla(0, 90%, 80%, 0.3); + --footer-btn-bg: hsl(220, 90%, 95%, 0.01); --footer-btn-bg-hover: hsl(220, 90%, 50%, 0.05); --footer-btn-border: hsla(220, 90%, 90%, 0.05); diff --git a/docs/theme/index.hbs b/docs/theme/index.hbs index 4339a02d1722d0d64e67b35de66889d9a849e9a4..86008ed690a9644b32227f68fbec148b94907640 100644 --- a/docs/theme/index.hbs +++ b/docs/theme/index.hbs @@ -131,7 +131,7 @@ - + + {{#if search_enabled}} - {{/if}} diff --git a/docs/theme/plugins.css b/docs/theme/plugins.css index 9d5d09fe736a96eeb0f26f6bc3c7a20a55664f31..8c9f0c438e8e1ecd43cd770183d0a6a3bbfe0a4f 100644 --- a/docs/theme/plugins.css +++ b/docs/theme/plugins.css @@ -6,3 +6,40 @@ kbd.keybinding { display: inline-block; margin: 0 2px; } + +#copy-markdown-toggle i { + font-weight: 500 !important; + -webkit-text-stroke: 0.5px currentColor; +} + +.copy-toast { + position: fixed; + top: 72px; + right: 16px; + padding: 12px 16px; + border-radius: 4px; + font-size: 14px; + font-weight: 500; + color: var(--fg); + background: var(--toast-bg); + border: 1px solid var(--toast-border); + z-index: 1000; + opacity: 0; + transform: translateY(-10px); + transition: all 0.1s ease-in-out; + box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05); + max-width: 280px; +} + +.copy-toast.success { + border-color: var(--toast-border-success); +} + +.copy-toast.error { + border-color: var(--toast-border-error); +} + +.copy-toast.show { + opacity: 1; + transform: translateY(0); +} diff --git a/docs/theme/plugins.js b/docs/theme/plugins.js index 76a295353f7abc391ba4d84998a636a8ed6dab36..44c4c59978d31bd24ed0f0a28266868bb9951e51 100644 --- a/docs/theme/plugins.js +++ b/docs/theme/plugins.js @@ -110,3 +110,122 @@ function darkModeToggle() { } }); } + +const copyMarkdown = () => { + const copyButton = document.getElementById("copy-markdown-toggle"); + if (!copyButton) return; + + // Store the original icon class, loading state, and timeout reference + const originalIconClass = "fa fa-copy"; + let isLoading = false; + let iconTimeoutId = null; + + const getCurrentPagePath = () => { + const pathname = window.location.pathname; + + // Handle root docs path + if (pathname === "/docs/" || pathname === "/docs") { + return "getting-started.md"; + } + + // Remove /docs/ prefix and .html suffix, then add .md + const cleanPath = pathname + .replace(/^\/docs\//, "") + .replace(/\.html$/, "") + .replace(/\/$/, ""); + + return cleanPath ? cleanPath + ".md" : "getting-started.md"; + }; + + const showToast = (message, isSuccess = true) => { + // Remove existing toast if any + const existingToast = document.getElementById("copy-toast"); + existingToast?.remove(); + + const toast = document.createElement("div"); + toast.id = "copy-toast"; + toast.className = `copy-toast ${isSuccess ? "success" : "error"}`; + toast.textContent = message; + + document.body.appendChild(toast); + + // Show toast with animation + setTimeout(() => { + toast.classList.add("show"); + }, 10); + + // Hide and remove toast after 2 seconds + setTimeout(() => { + toast.classList.remove("show"); + setTimeout(() => { + toast.parentNode?.removeChild(toast); + }, 300); + }, 2000); + }; + + const changeButtonIcon = (iconClass, duration = 1000) => { + const icon = copyButton.querySelector("i"); + if (!icon) return; + + // Clear any existing timeout + if (iconTimeoutId) { + clearTimeout(iconTimeoutId); + iconTimeoutId = null; + } + + icon.className = iconClass; + + if (duration > 0) { + iconTimeoutId = setTimeout(() => { + icon.className = originalIconClass; + iconTimeoutId = null; + }, duration); + } + }; + + const fetchAndCopyMarkdown = async () => { + // Prevent multiple simultaneous requests + if (isLoading) return; + + try { + isLoading = true; + changeButtonIcon("fa fa-spinner fa-spin", 0); // Don't auto-restore spinner + + const pagePath = getCurrentPagePath(); + const rawUrl = `https://raw.githubusercontent.com/zed-industries/zed/main/docs/src/${pagePath}`; + + const response = await fetch(rawUrl); + if (!response.ok) { + throw new Error( + `Failed to fetch markdown: ${response.status} ${response.statusText}`, + ); + } + + const markdownContent = await response.text(); + + // Copy to clipboard using modern API + if (navigator.clipboard?.writeText) { + await navigator.clipboard.writeText(markdownContent); + } else { + // Fallback: throw error if clipboard API isn't available + throw new Error("Clipboard API not supported in this browser"); + } + + changeButtonIcon("fa fa-check", 1000); + showToast("Page content copied to clipboard!"); + } catch (error) { + console.error("Error copying markdown:", error); + changeButtonIcon("fa fa-exclamation-triangle", 2000); + showToast("Failed to copy markdown. Please try again.", false); + } finally { + isLoading = false; + } + }; + + copyButton.addEventListener("click", fetchAndCopyMarkdown); +}; + +// Initialize functionality when DOM is loaded +document.addEventListener("DOMContentLoaded", () => { + copyMarkdown(); +}); diff --git a/tooling/workspace-hack/Cargo.toml b/tooling/workspace-hack/Cargo.toml index 7201dedb451c3db27998695b75b62847223e6b72..ec9629685d8366864b92a6160ece623450f72b0c 100644 --- a/tooling/workspace-hack/Cargo.toml +++ b/tooling/workspace-hack/Cargo.toml @@ -67,7 +67,7 @@ futures-sink = { version = "0.3" } futures-task = { version = "0.3", default-features = false, features = ["std"] } futures-util = { version = "0.3", features = ["channel", "io-compat", "sink"] } getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["std"] } -half = { version = "2", features = ["num-traits"] } +half = { version = "2", features = ["bytemuck", "num-traits", "rand_distr", "use-intrinsics"] } handlebars = { version = "4", features = ["rust-embed"] } hashbrown-3575ec1268b04181 = { package = "hashbrown", version = "0.15", features = ["serde"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] } @@ -75,6 +75,7 @@ hmac = { version = "0.12", default-features = false, features = ["reset"] } hyper = { version = "0.14", features = ["client", "http1", "http2", "runtime", "server", "stream"] } idna = { version = "1" } indexmap = { version = "2", features = ["serde"] } +itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } jiff = { version = "0.2" } lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } libc = { version = "0.2", features = ["extra_traits"] } @@ -84,10 +85,12 @@ lyon = { version = "1", default-features = false, features = ["extra"] } lyon_path = { version = "1" } md-5 = { version = "0.10" } memchr = { version = "2" } +memmap2 = { version = "0.9", default-features = false, features = ["stable_deref_trait"] } mime_guess = { version = "2" } miniz_oxide = { version = "0.8", features = ["simd"] } nom = { version = "7" } num-bigint = { version = "0.4" } +num-complex = { version = "0.4", features = ["bytemuck"] } num-integer = { version = "0.1", features = ["i128"] } num-iter = { version = "0.1", default-features = false, features = ["i128", "std"] } num-rational = { version = "0.4", features = ["num-bigint-std"] } @@ -96,11 +99,12 @@ once_cell = { version = "1" } percent-encoding = { version = "2" } phf = { version = "0.11", features = ["macros"] } phf_shared = { version = "0.11" } -prost = { version = "0.9" } +prost-274715c4dabd11b0 = { package = "prost", version = "0.9" } prost-types = { version = "0.9" } rand-c38e5c1d305a1b54 = { package = "rand", version = "0.8", features = ["small_rng"] } rand_chacha = { version = "0.3" } rand_core = { version = "0.6", default-features = false, features = ["std"] } +rand_distr = { version = "0.5" } regalloc2 = { version = "0.11", features = ["checker", "enable-serde"] } regex = { version = "1" } regex-automata = { version = "0.4" } @@ -123,6 +127,7 @@ spin = { version = "0.9" } sqlx = { version = "0.8", features = ["bigdecimal", "chrono", "postgres", "runtime-tokio-rustls", "rust_decimal", "sqlite", "time", "uuid"] } sqlx-postgres = { version = "0.8", default-features = false, features = ["any", "bigdecimal", "chrono", "json", "migrate", "offline", "rust_decimal", "time", "uuid"] } sqlx-sqlite = { version = "0.8", default-features = false, features = ["any", "bundled", "chrono", "json", "migrate", "offline", "time", "uuid"] } +stable_deref_trait = { version = "1" } strum = { version = "0.26", features = ["derive"] } subtle = { version = "2" } thiserror = { version = "2" } @@ -130,6 +135,7 @@ time = { version = "0.3", features = ["local-offset", "macros", "serde-well-know tokio = { version = "1", features = ["full"] } tokio-rustls = { version = "0.26", default-features = false, features = ["tls12"] } tokio-util = { version = "0.7", features = ["codec", "compat", "io"] } +toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } toml_edit = { version = "0.22", features = ["serde"] } tracing = { version = "0.1", features = ["log"] } tracing-core = { version = "0.1" } @@ -198,7 +204,7 @@ futures-sink = { version = "0.3" } futures-task = { version = "0.3", default-features = false, features = ["std"] } futures-util = { version = "0.3", features = ["channel", "io-compat", "sink"] } getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["std"] } -half = { version = "2", features = ["num-traits"] } +half = { version = "2", features = ["bytemuck", "num-traits", "rand_distr", "use-intrinsics"] } handlebars = { version = "4", features = ["rust-embed"] } hashbrown-3575ec1268b04181 = { package = "hashbrown", version = "0.15", features = ["serde"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] } @@ -208,6 +214,7 @@ hyper = { version = "0.14", features = ["client", "http1", "http2", "runtime", " idna = { version = "1" } indexmap = { version = "2", features = ["serde"] } itertools-594e8ee84c453af0 = { package = "itertools", version = "0.13" } +itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } jiff = { version = "0.2" } lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } libc = { version = "0.2", features = ["extra_traits"] } @@ -217,10 +224,12 @@ lyon = { version = "1", default-features = false, features = ["extra"] } lyon_path = { version = "1" } md-5 = { version = "0.10" } memchr = { version = "2" } +memmap2 = { version = "0.9", default-features = false, features = ["stable_deref_trait"] } mime_guess = { version = "2" } miniz_oxide = { version = "0.8", features = ["simd"] } nom = { version = "7" } num-bigint = { version = "0.4" } +num-complex = { version = "0.4", features = ["bytemuck"] } num-integer = { version = "0.1", features = ["i128"] } num-iter = { version = "0.1", default-features = false, features = ["i128", "std"] } num-rational = { version = "0.4", features = ["num-bigint-std"] } @@ -231,12 +240,13 @@ phf = { version = "0.11", features = ["macros"] } phf_shared = { version = "0.11" } prettyplease = { version = "0.2", default-features = false, features = ["verbatim"] } proc-macro2 = { version = "1" } -prost = { version = "0.9" } +prost-274715c4dabd11b0 = { package = "prost", version = "0.9" } prost-types = { version = "0.9" } quote = { version = "1" } rand-c38e5c1d305a1b54 = { package = "rand", version = "0.8", features = ["small_rng"] } rand_chacha = { version = "0.3" } rand_core = { version = "0.6", default-features = false, features = ["std"] } +rand_distr = { version = "0.5" } regalloc2 = { version = "0.11", features = ["checker", "enable-serde"] } regex = { version = "1" } regex-automata = { version = "0.4" } @@ -261,6 +271,7 @@ sqlx-macros = { version = "0.8", features = ["_rt-tokio", "_tls-rustls-ring-webp sqlx-macros-core = { version = "0.8", features = ["_rt-tokio", "_tls-rustls-ring-webpki", "bigdecimal", "chrono", "derive", "json", "macros", "migrate", "postgres", "rust_decimal", "sqlite", "time", "uuid"] } sqlx-postgres = { version = "0.8", default-features = false, features = ["any", "bigdecimal", "chrono", "json", "migrate", "offline", "rust_decimal", "time", "uuid"] } sqlx-sqlite = { version = "0.8", default-features = false, features = ["any", "bundled", "chrono", "json", "migrate", "offline", "time", "uuid"] } +stable_deref_trait = { version = "1" } strum = { version = "0.26", features = ["derive"] } subtle = { version = "2" } syn-dff4ba8e3ae991db = { package = "syn", version = "1", features = ["extra-traits", "full"] } @@ -271,6 +282,7 @@ time-macros = { version = "0.2", default-features = false, features = ["formatti tokio = { version = "1", features = ["full"] } tokio-rustls = { version = "0.26", default-features = false, features = ["tls12"] } tokio-util = { version = "0.7", features = ["codec", "compat", "io"] } +toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } toml_edit = { version = "0.22", features = ["serde"] } tracing = { version = "0.1", features = ["log"] } tracing-core = { version = "0.1" } @@ -293,15 +305,16 @@ foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } +num = { version = "0.4" } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } objc2-metal = { version = "0.3" } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event", "mm", "net", "param", "pipe", "process", "termios", "time"] } rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", default-features = false, features = ["process", "termios", "time"] } @@ -322,16 +335,17 @@ foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } +num = { version = "0.4" } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } objc2-metal = { version = "0.3" } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event", "mm", "net", "param", "pipe", "process", "termios", "time"] } rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", default-features = false, features = ["process", "termios", "time"] } @@ -352,15 +366,16 @@ foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } +num = { version = "0.4" } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } objc2-metal = { version = "0.3" } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event", "mm", "net", "param", "pipe", "process", "termios", "time"] } rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", default-features = false, features = ["process", "termios", "time"] } @@ -381,16 +396,17 @@ foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } naga = { version = "25", features = ["msl-out", "wgsl-in"] } nix-b73a96c0a5f6a7d9 = { package = "nix", version = "0.29", features = ["fs", "pthread", "signal", "user"] } +num = { version = "0.4" } objc2 = { version = "0.6" } objc2-core-foundation = { version = "0.3", default-features = false, features = ["CFArray", "CFCGTypes", "CFData", "CFDate", "CFDictionary", "CFRunLoop", "CFString", "CFURL", "objc2", "std"] } objc2-foundation = { version = "0.3", default-features = false, features = ["NSArray", "NSAttributedString", "NSBundle", "NSCoder", "NSData", "NSDate", "NSDictionary", "NSEnumerator", "NSError", "NSGeometry", "NSNotification", "NSNull", "NSObjCRuntime", "NSObject", "NSProcessInfo", "NSRange", "NSRunLoop", "NSString", "NSURL", "NSUndoManager", "NSValue", "objc2-core-foundation", "std"] } objc2-metal = { version = "0.3" } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event", "mm", "net", "param", "pipe", "process", "termios", "time"] } rustix-dff4ba8e3ae991db = { package = "rustix", version = "1", default-features = false, features = ["process", "termios", "time"] } @@ -417,7 +433,6 @@ getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-f gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } inout = { version = "0.1", default-features = false, features = ["block-padding"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "xdp"] } linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } @@ -429,6 +444,7 @@ nix-fa1f6196edfd7249 = { package = "nix", version = "0.30", features = ["fs", "s num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } quote = { version = "1" } rand-274715c4dabd11b0 = { package = "rand", version = "0.9" } ring = { version = "0.17", features = ["std"] } @@ -440,7 +456,6 @@ sync_wrapper = { version = "1", default-features = false, features = ["futures"] tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } -toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } zeroize = { version = "1", features = ["zeroize_derive"] } zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] } @@ -459,7 +474,6 @@ getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-f gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } inout = { version = "0.1", default-features = false, features = ["block-padding"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "xdp"] } linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } @@ -471,6 +485,7 @@ nix-fa1f6196edfd7249 = { package = "nix", version = "0.30", features = ["fs", "s num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } rand-274715c4dabd11b0 = { package = "rand", version = "0.9" } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event", "mm", "net", "param", "pipe", "process", "pty", "shm", "stdio", "system", "termios", "time"] } @@ -480,7 +495,6 @@ sync_wrapper = { version = "1", default-features = false, features = ["futures"] tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } -toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } zeroize = { version = "1", features = ["zeroize_derive"] } zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] } @@ -499,7 +513,6 @@ getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-f gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } inout = { version = "0.1", default-features = false, features = ["block-padding"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "xdp"] } linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } @@ -511,6 +524,7 @@ nix-fa1f6196edfd7249 = { package = "nix", version = "0.30", features = ["fs", "s num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } quote = { version = "1" } rand-274715c4dabd11b0 = { package = "rand", version = "0.9" } ring = { version = "0.17", features = ["std"] } @@ -522,7 +536,6 @@ sync_wrapper = { version = "1", default-features = false, features = ["futures"] tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } -toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } zeroize = { version = "1", features = ["zeroize_derive"] } zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] } @@ -541,7 +554,6 @@ getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-f gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } inout = { version = "0.1", default-features = false, features = ["block-padding"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "xdp"] } linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } @@ -553,6 +565,7 @@ nix-fa1f6196edfd7249 = { package = "nix", version = "0.30", features = ["fs", "s num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } rand-274715c4dabd11b0 = { package = "rand", version = "0.9" } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event", "mm", "net", "param", "pipe", "process", "pty", "shm", "stdio", "system", "termios", "time"] } @@ -562,7 +575,6 @@ sync_wrapper = { version = "1", default-features = false, features = ["futures"] tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } -toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } zeroize = { version = "1", features = ["zeroize_derive"] } zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] } @@ -574,8 +586,9 @@ foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } +num = { version = "0.4" } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event", "fs", "net"] } scopeguard = { version = "1" } @@ -587,7 +600,8 @@ tower = { version = "0.5", default-features = false, features = ["timeout", "uti winapi = { version = "0.3", default-features = false, features = ["cfg", "commapi", "consoleapi", "evntrace", "fileapi", "handleapi", "impl-debug", "impl-default", "in6addr", "inaddr", "ioapiset", "knownfolders", "minwinbase", "minwindef", "namedpipeapi", "ntsecapi", "objbase", "processenv", "processthreadsapi", "shlobj", "std", "synchapi", "sysinfoapi", "timezoneapi", "winbase", "windef", "winerror", "winioctl", "winnt", "winreg", "winsock2", "winuser"] } windows-core = { version = "0.61" } windows-numerics = { version = "0.2" } -windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } +windows-sys-4db8c43aad08e7ae = { package = "windows-sys", version = "0.60", features = ["Win32_Globalization", "Win32_System_Com", "Win32_UI_Shell"] } +windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } @@ -598,9 +612,10 @@ foldhash = { version = "0.1", default-features = false, features = ["std"] } getrandom-468e82937335b1c9 = { package = "getrandom", version = "0.3", default-features = false, features = ["std"] } getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-features = false, features = ["js", "rdrand"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } +num = { version = "0.4" } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event", "fs", "net"] } scopeguard = { version = "1" } @@ -612,7 +627,8 @@ tower = { version = "0.5", default-features = false, features = ["timeout", "uti winapi = { version = "0.3", default-features = false, features = ["cfg", "commapi", "consoleapi", "evntrace", "fileapi", "handleapi", "impl-debug", "impl-default", "in6addr", "inaddr", "ioapiset", "knownfolders", "minwinbase", "minwindef", "namedpipeapi", "ntsecapi", "objbase", "processenv", "processthreadsapi", "shlobj", "std", "synchapi", "sysinfoapi", "timezoneapi", "winbase", "windef", "winerror", "winioctl", "winnt", "winreg", "winsock2", "winuser"] } windows-core = { version = "0.61" } windows-numerics = { version = "0.2" } -windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_Globalization", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } +windows-sys-4db8c43aad08e7ae = { package = "windows-sys", version = "0.60", features = ["Win32_Globalization", "Win32_System_Com", "Win32_UI_Shell"] } +windows-sys-73dcd821b1037cfd = { package = "windows-sys", version = "0.59", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Win32_NetworkManagement_IpHelper", "Win32_Networking_WinSock", "Win32_Security_Authentication_Identity", "Win32_Security_Credentials", "Win32_Security_Cryptography", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Console", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Ioctl", "Win32_System_Kernel", "Win32_System_LibraryLoader", "Win32_System_Memory", "Win32_System_Performance", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_SystemInformation", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Input_KeyboardAndMouse", "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging"] } windows-sys-b21d60becc0929df = { package = "windows-sys", version = "0.52", features = ["Wdk_Foundation", "Wdk_Storage_FileSystem", "Wdk_System_IO", "Win32_Foundation", "Win32_Networking_WinSock", "Win32_Security_Authorization", "Win32_Storage_FileSystem", "Win32_System_Console", "Win32_System_IO", "Win32_System_Memory", "Win32_System_Pipes", "Win32_System_SystemServices", "Win32_System_Threading", "Win32_System_WindowsProgramming"] } windows-sys-c8eced492e86ede7 = { package = "windows-sys", version = "0.48", features = ["Win32_Foundation", "Win32_Globalization", "Win32_Networking_WinSock", "Win32_Security", "Win32_Storage_FileSystem", "Win32_System_Com", "Win32_System_Diagnostics_Debug", "Win32_System_IO", "Win32_System_Pipes", "Win32_System_Registry", "Win32_System_Threading", "Win32_System_Time", "Win32_System_WindowsProgramming", "Win32_UI_Shell"] } @@ -630,7 +646,6 @@ getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-f gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } inout = { version = "0.1", default-features = false, features = ["block-padding"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "xdp"] } linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } @@ -642,6 +657,7 @@ nix-fa1f6196edfd7249 = { package = "nix", version = "0.30", features = ["fs", "s num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", features = ["span-locations"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } quote = { version = "1" } rand-274715c4dabd11b0 = { package = "rand", version = "0.9" } ring = { version = "0.17", features = ["std"] } @@ -653,7 +669,6 @@ sync_wrapper = { version = "1", default-features = false, features = ["futures"] tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } -toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } zeroize = { version = "1", features = ["zeroize_derive"] } zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] } @@ -672,7 +687,6 @@ getrandom-6f8ce4dd05d13bba = { package = "getrandom", version = "0.2", default-f gimli = { version = "0.31", default-features = false, features = ["read", "std", "write"] } hyper-rustls = { version = "0.27", default-features = false, features = ["http1", "http2", "native-tokio", "ring", "tls12"] } inout = { version = "0.1", default-features = false, features = ["block-padding"] } -itertools-5ef9efb8ec2df382 = { package = "itertools", version = "0.12" } linux-raw-sys-274715c4dabd11b0 = { package = "linux-raw-sys", version = "0.9", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "xdp"] } linux-raw-sys-9fbad63c4bcf4a8f = { package = "linux-raw-sys", version = "0.4", default-features = false, features = ["elf", "errno", "general", "if_ether", "ioctl", "net", "netlink", "no_std", "prctl", "system", "xdp"] } livekit-runtime = { git = "https://github.com/zed-industries/livekit-rust-sdks", rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d" } @@ -684,6 +698,7 @@ nix-fa1f6196edfd7249 = { package = "nix", version = "0.30", features = ["fs", "s num-bigint-dig = { version = "0.8", features = ["i128", "prime", "zeroize"] } object = { version = "0.36", default-features = false, features = ["archive", "read_core", "unaligned", "write"] } proc-macro2 = { version = "1", default-features = false, features = ["span-locations"] } +prost-5ef9efb8ec2df382 = { package = "prost", version = "0.12", features = ["prost-derive"] } rand-274715c4dabd11b0 = { package = "rand", version = "0.9" } ring = { version = "0.17", features = ["std"] } rustix-d585fab2519d2d1 = { package = "rustix", version = "0.38", features = ["event", "mm", "net", "param", "pipe", "process", "pty", "shm", "stdio", "system", "termios", "time"] } @@ -693,7 +708,6 @@ sync_wrapper = { version = "1", default-features = false, features = ["futures"] tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring"] } tokio-socks = { version = "0.5", features = ["futures-io"] } tokio-stream = { version = "0.1", features = ["fs"] } -toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } tower = { version = "0.5", default-features = false, features = ["timeout", "util"] } zeroize = { version = "1", features = ["zeroize_derive"] } zvariant = { version = "5", features = ["enumflags2", "gvariant", "url"] }