diff --git a/.gitignore b/.gitignore index 2a91a65b6eaef906681bf3f6e315de07b094c4b1..ccf4f471d5a7b70be0dc8d619ac64050dd6681ec 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,6 @@ xcuserdata/ # Don't commit any secrets to the repo. .env .env.secret.toml + +# `nix build` output +/result diff --git a/Cargo.lock b/Cargo.lock index 9bb6a3322e613d717d0b3b2755c43fba0f7f3941..da60b0b2318f630b2f64f03d535c09d11bd4938d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -211,14 +211,14 @@ dependencies = [ "worktree", "zed_env_vars", "zlog", - "zstd 0.11.2+zstd.1.5.2", + "zstd", ] [[package]] name = "agent-client-protocol" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e639d6b544ad39f5b4e05802db5eb04e1518284eb05fda1839931003e0244c8" +checksum = "c2ffe7d502c1e451aafc5aff655000f84d09c9af681354ac0012527009b1af13" dependencies = [ "agent-client-protocol-schema", "anyhow", @@ -233,15 +233,16 @@ dependencies = [ [[package]] name = "agent-client-protocol-schema" -version = "0.9.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f182f5e14bef8232b239719bd99166bb11e986c08fc211f28e392f880d3093ba" +checksum = "8af81cc2d5c3f9c04f73db452efd058333735ba9d51c2cf7ef33c9fee038e7e6" dependencies = [ "anyhow", "derive_more 2.0.1", "schemars", "serde", "serde_json", + "strum 0.27.2", ] [[package]] @@ -680,21 +681,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "argminmax" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70f13d10a41ac8d2ec79ee34178d61e6f47a29c2edfe7ef1721c7383b0359e65" -dependencies = [ - "num-traits", -] - -[[package]] -name = "array-init-cursor" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed51fe0f224d1d4ea768be38c51f9f831dee9d05c163c11fba0b8c44387b1fc3" - [[package]] name = "arraydeque" version = "0.5.1" @@ -1278,15 +1264,6 @@ dependencies = [ 
"num-traits", ] -[[package]] -name = "atoi_simd" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2a49e05797ca52e312a0c658938b7d00693ef037799ef7187678f212d7684cf" -dependencies = [ - "debug_unsafe", -] - [[package]] name = "atomic" version = "0.5.3" @@ -2070,26 +2047,6 @@ dependencies = [ "serde", ] -[[package]] -name = "bincode" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" -dependencies = [ - "bincode_derive", - "serde", - "unty", -] - -[[package]] -name = "bincode_derive" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" -dependencies = [ - "virtue", -] - [[package]] name = "bindgen" version = "0.71.1" @@ -2242,19 +2199,6 @@ dependencies = [ "profiling", ] -[[package]] -name = "blake3" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" -dependencies = [ - "arrayref", - "arrayvec", - "cc", - "cfg-if", - "constant_time_eq 0.3.1", -] - [[package]] name = "block" version = "0.1.6" @@ -2344,12 +2288,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "boxcar" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36f64beae40a84da1b4b26ff2761a5b895c12adc41dc25aaee1c4f2bbfe97a6e" - [[package]] name = "breadcrumbs" version = "0.1.0" @@ -2516,9 +2454,6 @@ name = "bytes" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" -dependencies = [ - "serde", -] [[package]] name = "bytes-utils" @@ -2805,15 +2740,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "castaway" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" -dependencies = [ - "rustversion", -] - [[package]] name = "cbc" version = "0.1.2" @@ -2942,16 +2868,6 @@ dependencies = [ "windows-link 0.2.1", ] -[[package]] -name = "chrono-tz" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" -dependencies = [ - "chrono", - "phf 0.12.1", -] - [[package]] name = "chunked_transfer" version = "1.5.0" @@ -3201,12 +3117,7 @@ dependencies = [ "anyhow", "cloud_llm_client", "indoc", - "ordered-float 2.10.1", - "rustc-hash 2.1.1", - "schemars", "serde", - "serde_json", - "strum 0.27.2", ] [[package]] @@ -3314,8 +3225,8 @@ name = "codestral" version = "0.1.0" dependencies = [ "anyhow", - "edit_prediction", "edit_prediction_context", + "edit_prediction_types", "futures 0.3.31", "gpui", "http_client", @@ -3505,17 +3416,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "comfy-table" -version = "7.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b03b7db8e0b4b2fdad6c551e634134e99ec000e5c8c3b6856c65e8bbaded7a3b" -dependencies = [ - "crossterm", - "unicode-segmentation", - "unicode-width", -] - [[package]] name = "command-fds" version = "0.3.2" @@ -3569,21 +3469,6 @@ dependencies = [ "workspace", ] -[[package]] -name = "compact_str" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" -dependencies = [ - "castaway", - "cfg-if", - "itoa", - "rustversion", - "ryu", - "serde", - "static_assertions", -] - [[package]] name = "component" version = "0.1.0" @@ -3689,12 +3574,6 @@ version = "0.1.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" -[[package]] -name = "constant_time_eq" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" - [[package]] name = "context_server" version = "0.1.0" @@ -3747,7 +3626,7 @@ dependencies = [ "command_palette_hooks", "ctor", "dirs 4.0.0", - "edit_prediction", + "edit_prediction_types", "editor", "fs", "futures 0.3.31", @@ -4160,7 +4039,7 @@ dependencies = [ name = "crashes" version = "0.1.0" dependencies = [ - "bincode 1.3.3", + "bincode", "cfg-if", "crash-handler", "extension_host", @@ -4174,7 +4053,7 @@ dependencies = [ "smol", "system_specs", "windows 0.61.3", - "zstd 0.11.2+zstd.1.5.2", + "zstd", ] [[package]] @@ -4319,29 +4198,6 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" -[[package]] -name = "crossterm" -version = "0.29.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" -dependencies = [ - "bitflags 2.9.4", - "crossterm_winapi", - "document-features", - "parking_lot", - "rustix 1.1.2", - "winapi", -] - -[[package]] -name = "crossterm_winapi" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" -dependencies = [ - "winapi", -] - [[package]] name = "crunchy" version = "0.2.4" @@ -4696,12 +4552,6 @@ dependencies = [ "util", ] -[[package]] -name = "debug_unsafe" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85d3cef41d236720ed453e102153a53e4cc3d2fde848c0078a50cf249e8e3e5b" - [[package]] name = "debugger_tools" version = "0.1.0" @@ -4734,6 +4584,7 @@ 
dependencies = [ "db", "debugger_tools", "editor", + "feature_flags", "file_icons", "futures 0.3.31", "fuzzy", @@ -5109,15 +4960,6 @@ dependencies = [ "zlog", ] -[[package]] -name = "document-features" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95249b50c6c185bee49034bcb378a49dc2b5dff0be90ff6616d31d64febab05d" -dependencies = [ - "litrs", -] - [[package]] name = "documented" version = "0.9.2" @@ -5267,144 +5109,252 @@ dependencies = [ name = "edit_prediction" version = "0.1.0" dependencies = [ - "client", - "gpui", - "language", -] - -[[package]] -name = "edit_prediction_button" -version = "0.1.0" -dependencies = [ + "ai_onboarding", "anyhow", + "arrayvec", + "brotli", "client", + "clock", + "cloud_api_types", "cloud_llm_client", - "codestral", + "cloud_zeta2_prompt", + "collections", "copilot", - "edit_prediction", - "editor", + "credentials_provider", + "ctor", + "db", + "edit_prediction_context", + "edit_prediction_types", "feature_flags", "fs", "futures 0.3.31", "gpui", "indoc", + "itertools 0.14.0", "language", + "language_model", + "log", "lsp", "menu", - "paths", + "open_ai", + "parking_lot", + "postage", + "pretty_assertions", "project", + "rand 0.9.2", "regex", + "release_channel", + "semver", + "serde", "serde_json", "settings", - "supermaven", + "smol", + "strsim", + "strum 0.27.2", "telemetry", - "theme", + "telemetry_events", + "thiserror 2.0.17", "ui", - "ui_input", "util", + "uuid", "workspace", + "worktree", "zed_actions", - "zeta", + "zlog", ] [[package]] -name = "edit_prediction_context" +name = "edit_prediction_cli" version = "0.1.0" dependencies = [ + "anthropic", "anyhow", - "arrayvec", + "chrono", "clap", + "client", "cloud_llm_client", + "cloud_zeta2_prompt", "collections", + "debug_adapter_extension", + "edit_prediction", + "edit_prediction_context", + "extension", + "fs", "futures 0.3.31", "gpui", - "hashbrown 0.15.5", + "gpui_tokio", + "http_client", "indoc", - "itertools 0.14.0", 
"language", + "language_extension", + "language_model", + "language_models", + "languages", "log", - "ordered-float 2.10.1", - "postage", + "node_runtime", + "paths", "pretty_assertions", "project", - "regex", + "prompt_store", + "pulldown-cmark 0.12.2", + "release_channel", + "reqwest_client", "serde", "serde_json", "settings", - "slotmap", - "strum 0.27.2", - "text", - "tree-sitter", - "tree-sitter-c", - "tree-sitter-cpp", - "tree-sitter-go", + "shellexpand 2.1.2", + "smol", + "sqlez", + "sqlez_macros", + "terminal_view", + "toml 0.8.23", "util", + "watch", "zlog", ] [[package]] -name = "editor" +name = "edit_prediction_context" version = "0.1.0" dependencies = [ - "aho-corasick", "anyhow", - "assets", - "buffer_diff", - "client", - "clock", + "cloud_llm_client", "collections", - "convert_case 0.8.0", - "criterion", - "ctor", - "dap", - "db", - "edit_prediction", - "emojis", - "feature_flags", - "file_icons", - "fs", + "env_logger 0.11.8", "futures 0.3.31", - "fuzzy", - "git", "gpui", - "http_client", "indoc", - "itertools 0.14.0", "language", - "languages", - "linkify", "log", "lsp", - "markdown", - "menu", - "multi_buffer", - "ordered-float 2.10.1", "parking_lot", "pretty_assertions", "project", - "rand 0.9.2", - "regex", - "release_channel", - "rope", - "rpc", - "schemars", - "semver", "serde", "serde_json", "settings", "smallvec", - "smol", - "snippet", - "sum_tree", - "task", - "telemetry", - "tempfile", + "text", + "tree-sitter", + "util", + "zlog", +] + +[[package]] +name = "edit_prediction_types" +version = "0.1.0" +dependencies = [ + "client", + "gpui", + "language", +] + +[[package]] +name = "edit_prediction_ui" +version = "0.1.0" +dependencies = [ + "anyhow", + "buffer_diff", + "client", + "cloud_llm_client", + "cloud_zeta2_prompt", + "codestral", + "command_palette_hooks", + "copilot", + "edit_prediction", + "edit_prediction_types", + "editor", + "feature_flags", + "fs", + "futures 0.3.31", + "gpui", + "indoc", + "language", + "lsp", + "markdown", + 
"menu", + "multi_buffer", + "paths", + "project", + "regex", + "serde_json", + "settings", + "supermaven", + "telemetry", + "text", + "theme", + "ui", + "ui_input", + "util", + "workspace", + "zed_actions", +] + +[[package]] +name = "editor" +version = "0.1.0" +dependencies = [ + "aho-corasick", + "anyhow", + "assets", + "buffer_diff", + "client", + "clock", + "collections", + "convert_case 0.8.0", + "criterion", + "ctor", + "dap", + "db", + "edit_prediction_types", + "emojis", + "feature_flags", + "file_icons", + "fs", + "futures 0.3.31", + "fuzzy", + "git", + "gpui", + "http_client", + "indoc", + "itertools 0.14.0", + "language", + "languages", + "linkify", + "log", + "lsp", + "markdown", + "menu", + "multi_buffer", + "ordered-float 2.10.1", + "parking_lot", + "pretty_assertions", + "project", + "rand 0.9.2", + "regex", + "release_channel", + "rope", + "rpc", + "schemars", + "semver", + "serde", + "serde_json", + "settings", + "smallvec", + "smol", + "snippet", + "sum_tree", + "task", + "telemetry", + "tempfile", "text", "theme", "time", + "tracing", "tree-sitter-bash", "tree-sitter-c", "tree-sitter-html", + "tree-sitter-md", "tree-sitter-python", "tree-sitter-rust", "tree-sitter-typescript", @@ -5420,6 +5370,7 @@ dependencies = [ "workspace", "zed_actions", "zlog", + "ztracing", ] [[package]] @@ -5695,12 +5646,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "ethnum" -version = "1.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca81e6b4777c89fd810c25a4be2b1bd93ea034fbe58e6a75216a34c6b82c539b" - [[package]] name = "euclid" version = "0.22.11" @@ -5864,6 +5809,7 @@ dependencies = [ "serde", "serde_json", "task", + "tempfile", "toml 0.8.23", "url", "util", @@ -5993,12 +5939,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" -[[package]] -name = "fallible-streaming-iterator" -version = "0.1.9" 
-source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" - [[package]] name = "fancy-regex" version = "0.16.2" @@ -6010,12 +5950,6 @@ dependencies = [ "regex-syntax", ] -[[package]] -name = "fast-float2" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55" - [[package]] name = "fast-srgb8" version = "1.0.0" @@ -6191,7 +6125,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9" dependencies = [ "crc32fast", - "libz-rs-sys", "miniz_oxide", ] @@ -6448,16 +6381,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "fs4" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4" -dependencies = [ - "rustix 1.1.2", - "windows-sys 0.59.0", -] - [[package]] name = "fs_benchmarks" version = "0.1.0" @@ -6918,6 +6841,20 @@ dependencies = [ "seq-macro", ] +[[package]] +name = "generator" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "605183a538e3e2a9c1038635cc5c2d194e2ee8fd0d1b66b8349fad7dbacce5a2" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows 0.61.3", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -7079,6 +7016,7 @@ dependencies = [ "gpui", "http_client", "indoc", + "itertools 0.14.0", "pretty_assertions", "regex", "serde", @@ -7136,6 +7074,7 @@ dependencies = [ "theme", "time", "time_format", + "tracing", "ui", "unindent", "util", @@ -7145,6 +7084,7 @@ dependencies = [ "zed_actions", "zeroize", "zlog", + "ztracing", ] [[package]] @@ -7521,7 +7461,6 @@ dependencies = [ "allocator-api2", "equivalent", "foldhash 0.1.5", - "rayon", "serde", ] @@ -7633,7 +7572,7 @@ version = "0.21.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c255bdf46e07fb840d120a36dcc81f385140d7191c76a7391672675c01a55d" dependencies = [ - "bincode 1.3.3", + "bincode", "byteorder", "heed-traits", "serde", @@ -8223,7 +8162,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5" dependencies = [ "equivalent", - "hashbrown 0.15.5", + "hashbrown 0.16.1", "serde", "serde_core", ] @@ -8393,7 +8332,7 @@ version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fb8251fb7bcd9ccd3725ed8deae9fe7db8e586495c9eb5b0c52e6233e5e75ea" dependencies = [ - "bincode 1.3.3", + "bincode", "crossbeam-channel", "fnv", "lazy_static", @@ -9234,15 +9173,6 @@ dependencies = [ "webrtc-sys", ] -[[package]] -name = "libz-rs-sys" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd" -dependencies = [ - "zlib-rs", -] - [[package]] name = "libz-sys" version = "1.1.22" @@ -9305,12 +9235,6 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" -[[package]] -name = "litrs" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5e54036fe321fd421e10d732f155734c4e4afd610dd556d9a82833ab3ee0bed" - [[package]] name = "livekit" version = "0.7.8" @@ -9480,6 +9404,19 @@ dependencies = [ "value-bag", ] +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "loop9" version = "0.1.5" @@ -9602,25 +9539,6 @@ dependencies = [ "num-traits", ] -[[package]] 
-name = "lz4" -version = "1.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" -dependencies = [ - "lz4-sys", -] - -[[package]] -name = "lz4-sys" -version = "1.11.1+lz4-1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" -dependencies = [ - "cc", - "libc", -] - [[package]] name = "mac" version = "0.1.1" @@ -10169,9 +10087,11 @@ dependencies = [ "sum_tree", "text", "theme", + "tracing", "tree-sitter", "util", "zlog", + "ztracing", ] [[package]] @@ -10483,15 +10403,6 @@ name = "notify-types" version = "2.0.0" source = "git+https://github.com/zed-industries/notify.git?rev=b4588b2e5aee68f4c0e100f140e808cbce7b1419#b4588b2e5aee68f4c0e100f140e808cbce7b1419" -[[package]] -name = "now" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0" -dependencies = [ - "chrono", -] - [[package]] name = "ntapi" version = "0.4.1" @@ -10887,41 +10798,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "object_store" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c1be0c6c22ec0817cdc77d3842f721a17fd30ab6965001415b5402a74e6b740" -dependencies = [ - "async-trait", - "base64 0.22.1", - "bytes 1.10.1", - "chrono", - "form_urlencoded", - "futures 0.3.31", - "http 1.3.1", - "http-body-util", - "humantime", - "hyper 1.7.0", - "itertools 0.14.0", - "parking_lot", - "percent-encoding", - "quick-xml 0.38.3", - "rand 0.9.2", - "reqwest 0.12.24", - "ring", - "serde", - "serde_json", - "serde_urlencoded", - "thiserror 2.0.17", - "tokio", - "tracing", - "url", - "walkdir", - "wasm-bindgen-futures", - "web-time", -] - [[package]] name = "ollama" version = "0.1.0" @@ -12156,16 +12032,6 @@ version = "0.2.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" -[[package]] -name = "planus" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3daf8e3d4b712abe1d690838f6e29fb76b76ea19589c4afa39ec30e12f62af71" -dependencies = [ - "array-init-cursor", - "hashbrown 0.15.5", -] - [[package]] name = "plist" version = "1.8.0" @@ -12234,544 +12100,30 @@ dependencies = [ ] [[package]] -name = "polars" -version = "0.51.0" +name = "polling" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5f7feb5d56b954e691dff22a8b2d78d77433dcc93c35fe21c3777fdc121b697" +checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218" dependencies = [ - "getrandom 0.2.16", - "getrandom 0.3.4", - "polars-arrow", - "polars-core", - "polars-error", - "polars-io", - "polars-lazy", - "polars-ops", - "polars-parquet", - "polars-sql", - "polars-time", - "polars-utils", - "version_check", + "cfg-if", + "concurrent-queue", + "hermit-abi", + "pin-project-lite", + "rustix 1.1.2", + "windows-sys 0.61.2", ] [[package]] -name = "polars-arrow" -version = "0.51.0" +name = "pollster" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5da3b0203fd7ee5720aa0b5e790b591aa5d3f41c3ed2c34a3a393382198af2f7" + +[[package]] +name = "pori" +version = "0.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32b4fed2343961b3eea3db2cee165540c3e1ad9d5782350cc55a9e76cf440148" -dependencies = [ - "atoi_simd", - "bitflags 2.9.4", - "bytemuck", - "chrono", - "chrono-tz", - "dyn-clone", - "either", - "ethnum", - "getrandom 0.2.16", - "getrandom 0.3.4", - "hashbrown 0.15.5", - "itoa", - "lz4", - "num-traits", - "polars-arrow-format", - "polars-error", - "polars-schema", - "polars-utils", - "serde", - "simdutf8", - "streaming-iterator", - "strum_macros 0.27.2", - 
"version_check", - "zstd 0.13.3", -] - -[[package]] -name = "polars-arrow-format" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a556ac0ee744e61e167f34c1eb0013ce740e0ee6cd8c158b2ec0b518f10e6675" -dependencies = [ - "planus", - "serde", -] - -[[package]] -name = "polars-compute" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "138785beda4e4a90a025219f09d0d15a671b2be9091513ede58e05db6ad4413f" -dependencies = [ - "atoi_simd", - "bytemuck", - "chrono", - "either", - "fast-float2", - "hashbrown 0.15.5", - "itoa", - "num-traits", - "polars-arrow", - "polars-error", - "polars-utils", - "rand 0.9.2", - "ryu", - "serde", - "skiplist", - "strength_reduce", - "strum_macros 0.27.2", - "version_check", -] - -[[package]] -name = "polars-core" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e77b1f08ef6dbb032bb1d0d3365464be950df9905f6827a95b24c4ca5518901d" -dependencies = [ - "bitflags 2.9.4", - "boxcar", - "bytemuck", - "chrono", - "chrono-tz", - "comfy-table", - "either", - "hashbrown 0.15.5", - "indexmap", - "itoa", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-dtype", - "polars-error", - "polars-row", - "polars-schema", - "polars-utils", - "rand 0.9.2", - "rand_distr", - "rayon", - "regex", - "serde", - "serde_json", - "strum_macros 0.27.2", - "uuid", - "version_check", - "xxhash-rust", -] - -[[package]] -name = "polars-dtype" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89c43d0ea57168be4546c4d8064479ed8b29a9c79c31a0c7c367ee734b9b7158" -dependencies = [ - "boxcar", - "hashbrown 0.15.5", - "polars-arrow", - "polars-error", - "polars-utils", - "serde", - "uuid", -] - -[[package]] -name = "polars-error" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"b9cb5d98f59f8b94673ee391840440ad9f0d2170afced95fc98aa86f895563c0" -dependencies = [ - "object_store", - "parking_lot", - "polars-arrow-format", - "regex", - "signal-hook", - "simdutf8", -] - -[[package]] -name = "polars-expr" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343931b818cf136349135ba11dbc18c27683b52c3477b1ba8ca606cf5ab1965c" -dependencies = [ - "bitflags 2.9.4", - "hashbrown 0.15.5", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-io", - "polars-ops", - "polars-plan", - "polars-row", - "polars-time", - "polars-utils", - "rand 0.9.2", - "rayon", - "recursive", -] - -[[package]] -name = "polars-io" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10388c64b8155122488229a881d1c6f4fdc393bc988e764ab51b182fcb2307e4" -dependencies = [ - "async-trait", - "atoi_simd", - "blake3", - "bytes 1.10.1", - "chrono", - "fast-float2", - "fs4", - "futures 0.3.31", - "glob", - "hashbrown 0.15.5", - "home", - "itoa", - "memchr", - "memmap2", - "num-traits", - "object_store", - "percent-encoding", - "polars-arrow", - "polars-core", - "polars-error", - "polars-parquet", - "polars-schema", - "polars-time", - "polars-utils", - "rayon", - "regex", - "reqwest 0.12.24", - "ryu", - "serde", - "serde_json", - "simdutf8", - "tokio", - "tokio-util", - "url", -] - -[[package]] -name = "polars-lazy" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fb6e2c6c2fa4ea0c660df1c06cf56960c81e7c2683877995bae3d4e3d408147" -dependencies = [ - "bitflags 2.9.4", - "chrono", - "either", - "memchr", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-expr", - "polars-io", - "polars-mem-engine", - "polars-ops", - "polars-plan", - "polars-stream", - "polars-time", - "polars-utils", - "rayon", - "version_check", -] - -[[package]] -name = "polars-mem-engine" -version = "0.51.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "20a856e98e253587c28d8132a5e7e5a75cb2c44731ca090f1481d45f1d123771" -dependencies = [ - "futures 0.3.31", - "memmap2", - "polars-arrow", - "polars-core", - "polars-error", - "polars-expr", - "polars-io", - "polars-ops", - "polars-plan", - "polars-time", - "polars-utils", - "rayon", - "recursive", - "tokio", -] - -[[package]] -name = "polars-ops" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acf6062173fdc9ba05775548beb66e76643a148d9aeadc9984ed712bc4babd76" -dependencies = [ - "argminmax", - "base64 0.22.1", - "bytemuck", - "chrono", - "chrono-tz", - "either", - "hashbrown 0.15.5", - "hex", - "indexmap", - "libm", - "memchr", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-schema", - "polars-utils", - "rayon", - "regex", - "regex-syntax", - "strum_macros 0.27.2", - "unicode-normalization", - "unicode-reverse", - "version_check", -] - -[[package]] -name = "polars-parquet" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1d769180dec070df0dc4b89299b364bf2cfe32b218ecc4ddd8f1a49ae60669" -dependencies = [ - "async-stream", - "base64 0.22.1", - "brotli", - "bytemuck", - "ethnum", - "flate2", - "futures 0.3.31", - "hashbrown 0.15.5", - "lz4", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-error", - "polars-parquet-format", - "polars-utils", - "serde", - "simdutf8", - "snap", - "streaming-decompression", - "zstd 0.13.3", -] - -[[package]] -name = "polars-parquet-format" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c025243dcfe8dbc57e94d9f82eb3bef10b565ab180d5b99bed87fd8aea319ce1" -dependencies = [ - "async-trait", - "futures 0.3.31", -] - -[[package]] -name = "polars-plan" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1cd3a2e33ae4484fe407ab2d2ba5684f0889d1ccf3ad6b844103c03638e6d0a0" -dependencies = [ - "bitflags 2.9.4", - "bytemuck", - "bytes 1.10.1", - "chrono", - "chrono-tz", - "either", - "futures 0.3.31", - "hashbrown 0.15.5", - "memmap2", - "num-traits", - "percent-encoding", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-io", - "polars-ops", - "polars-parquet", - "polars-time", - "polars-utils", - "rayon", - "recursive", - "regex", - "sha2", - "strum_macros 0.27.2", - "version_check", -] - -[[package]] -name = "polars-row" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18734f17e0e348724df3ae65f3ee744c681117c04b041cac969dfceb05edabc0" -dependencies = [ - "bitflags 2.9.4", - "bytemuck", - "polars-arrow", - "polars-compute", - "polars-dtype", - "polars-error", - "polars-utils", -] - -[[package]] -name = "polars-schema" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e6c1ab13e04d5167661a9854ed1ea0482b2ed9b8a0f1118dabed7cd994a85e3" -dependencies = [ - "indexmap", - "polars-error", - "polars-utils", - "serde", - "version_check", -] - -[[package]] -name = "polars-sql" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e7766da02cc1d464994404d3e88a7a0ccd4933df3627c325480fbd9bbc0a11" -dependencies = [ - "bitflags 2.9.4", - "hex", - "polars-core", - "polars-error", - "polars-lazy", - "polars-ops", - "polars-plan", - "polars-time", - "polars-utils", - "rand 0.9.2", - "regex", - "serde", - "sqlparser", -] - -[[package]] -name = "polars-stream" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31f6c6ca1ea01f9dea424d167e4f33f5ec44cd67fbfac9efd40575ed20521f14" -dependencies = [ - "async-channel 2.5.0", - "async-trait", - "atomic-waker", - "bitflags 2.9.4", - "crossbeam-channel", - "crossbeam-deque", - "crossbeam-queue", - "crossbeam-utils", - "futures 
0.3.31", - "memmap2", - "parking_lot", - "percent-encoding", - "pin-project-lite", - "polars-arrow", - "polars-core", - "polars-error", - "polars-expr", - "polars-io", - "polars-mem-engine", - "polars-ops", - "polars-parquet", - "polars-plan", - "polars-utils", - "rand 0.9.2", - "rayon", - "recursive", - "slotmap", - "tokio", - "tokio-util", - "version_check", -] - -[[package]] -name = "polars-time" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6a3a6e279a7a984a0b83715660f9e880590c6129ec2104396bfa710bcd76dee" -dependencies = [ - "atoi_simd", - "bytemuck", - "chrono", - "chrono-tz", - "now", - "num-traits", - "polars-arrow", - "polars-compute", - "polars-core", - "polars-error", - "polars-ops", - "polars-utils", - "rayon", - "regex", - "strum_macros 0.27.2", -] - -[[package]] -name = "polars-utils" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57b267021b0e5422d7fbc70fd79e51b9f9a8466c585779373a18b0199e973f29" -dependencies = [ - "bincode 2.0.1", - "bytemuck", - "bytes 1.10.1", - "compact_str", - "either", - "flate2", - "foldhash 0.1.5", - "hashbrown 0.15.5", - "indexmap", - "libc", - "memmap2", - "num-traits", - "polars-error", - "rand 0.9.2", - "raw-cpuid 11.6.0", - "rayon", - "regex", - "rmp-serde", - "serde", - "serde_json", - "serde_stacker", - "slotmap", - "stacker", - "uuid", - "version_check", -] - -[[package]] -name = "polling" -version = "3.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218" -dependencies = [ - "cfg-if", - "concurrent-queue", - "hermit-abi", - "pin-project-lite", - "rustix 1.1.2", - "windows-sys 0.61.2", -] - -[[package]] -name = "pollster" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5da3b0203fd7ee5720aa0b5e790b591aa5d3f41c3ed2c34a3a393382198af2f7" - -[[package]] -name = "pori" -version 
= "0.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a63d338dec139f56dacc692ca63ad35a6be6a797442479b55acd611d79e906" +checksum = "a4a63d338dec139f56dacc692ca63ad35a6be6a797442479b55acd611d79e906" dependencies = [ "nom 7.1.3", ] @@ -13088,6 +12440,7 @@ dependencies = [ "terminal", "text", "toml 0.8.23", + "tracing", "unindent", "url", "util", @@ -13097,6 +12450,7 @@ dependencies = [ "worktree", "zeroize", "zlog", + "ztracing", ] [[package]] @@ -13498,7 +12852,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42a232e7487fc2ef313d96dde7948e7a3c05101870d8985e4fd8d26aedd27b89" dependencies = [ "memchr", - "serde", ] [[package]] @@ -13832,26 +13185,6 @@ dependencies = [ "zed_actions", ] -[[package]] -name = "recursive" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" -dependencies = [ - "recursive-proc-macro-impl", - "stacker", -] - -[[package]] -name = "recursive-proc-macro-impl" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" -dependencies = [ - "quote", - "syn 2.0.106", -] - [[package]] name = "redox_syscall" version = "0.2.16" @@ -14208,35 +13541,26 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2 0.4.12", "http 1.3.1", "http-body 1.0.1", "http-body-util", "hyper 1.7.0", - "hyper-rustls 0.27.7", "hyper-util", "js-sys", "log", "percent-encoding", "pin-project-lite", - "quinn", - "rustls 0.23.33", - "rustls-native-certs 0.8.2", - "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper 1.0.2", "tokio", - "tokio-rustls 0.26.2", - "tokio-util", "tower 0.5.2", "tower-http 0.6.6", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-streams", "web-sys", ] @@ -14359,17 +13683,6 @@ dependencies = [ "paste", ] 
-[[package]] -name = "rmp-serde" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" -dependencies = [ - "byteorder", - "rmp", - "serde", -] - [[package]] name = "rmpv" version = "1.3.0" @@ -14406,9 +13719,11 @@ dependencies = [ "rand 0.9.2", "rayon", "sum_tree", + "tracing", "unicode-segmentation", "util", "zlog", + "ztracing", ] [[package]] @@ -14439,7 +13754,7 @@ dependencies = [ "tracing", "util", "zlog", - "zstd 0.11.2+zstd.1.5.2", + "zstd", ] [[package]] @@ -15323,23 +14638,12 @@ dependencies = [ ] [[package]] -name = "serde_spanned" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e24345aa0fe688594e73770a5f6d1b216508b4f93484c0026d521acd30134392" -dependencies = [ - "serde_core", -] - -[[package]] -name = "serde_stacker" -version = "0.1.14" +name = "serde_spanned" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4936375d50c4be7eff22293a9344f8e46f323ed2b3c243e52f89138d9bb0f4a" +checksum = "e24345aa0fe688594e73770a5f6d1b216508b4f93484c0026d521acd30134392" dependencies = [ - "serde", "serde_core", - "stacker", ] [[package]] @@ -15683,16 +14987,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" -[[package]] -name = "skiplist" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f354fd282d3177c2951004953e2fdc4cb342fa159bbee8b829852b6a081c8ea1" -dependencies = [ - "rand 0.9.2", - "thiserror 2.0.17", -] - [[package]] name = "skrifa" version = "0.37.0" @@ -15768,12 +15062,6 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd538fb6910ac1099850255cf94a94df6551fbdd602454387d0adb2d1ca6dead" -[[package]] -name = "snap" -version = "1.1.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" - [[package]] name = "snippet" version = "0.1.0" @@ -15820,26 +15108,6 @@ dependencies = [ "workspace", ] -[[package]] -name = "soa-rs" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b75ae4668062b095fda87ba54118697bed601f07f6c68bf50289a25ca0c8c935" -dependencies = [ - "soa-rs-derive", -] - -[[package]] -name = "soa-rs-derive" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c09121507da587d3434e5929ce3321162f36bd3eff403873cb163c06b176913" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "socket2" version = "0.5.10" @@ -15959,15 +15227,6 @@ dependencies = [ "unicode_categories", ] -[[package]] -name = "sqlparser" -version = "0.53.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" -dependencies = [ - "log", -] - [[package]] name = "sqlx" version = "0.8.6" @@ -16269,15 +15528,6 @@ dependencies = [ "ui", ] -[[package]] -name = "streaming-decompression" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf6cc3b19bfb128a8ad11026086e31d3ce9ad23f8ea37354b31383a187c44cf3" -dependencies = [ - "fallible-streaming-iterator", -] - [[package]] name = "streaming-iterator" version = "0.1.9" @@ -16409,7 +15659,9 @@ dependencies = [ "log", "rand 0.9.2", "rayon", + "tracing", "zlog", + "ztracing", ] [[package]] @@ -16419,7 +15671,7 @@ dependencies = [ "anyhow", "client", "collections", - "edit_prediction", + "edit_prediction_types", "editor", "env_logger 0.11.8", "futures 0.3.31", @@ -17663,7 +16915,6 @@ dependencies = [ "futures-core", "futures-io", "futures-sink", - "futures-util", "pin-project-lite", "tokio", ] @@ -17895,9 +17146,9 @@ checksum = 
"8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "2d15d90a0b5c19378952d479dc858407149d7bb45a14de0142f6c534b16fc647" dependencies = [ "log", "pin-project-lite", @@ -17907,9 +17158,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -17918,9 +17169,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "7a04e24fab5c89c6a36eb8558c9656f30d81de51dfa4d3b45f26b21d61fa0a6c" dependencies = [ "once_cell", "valuable", @@ -17949,9 +17200,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.20" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", @@ -17968,6 +17219,38 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "tracing-tracy" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eaa1852afa96e0fe9e44caa53dc0bd2d9d05e0f2611ce09f97f8677af56e4ba" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracy-client", +] + +[[package]] +name = "tracy-client" +version = "0.18.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "91d722a05fe49b31fef971c4732a7d4aa6a18283d9ba46abddab35f484872947" +dependencies = [ + "loom", + "once_cell", + "tracy-client-sys", +] + +[[package]] +name = "tracy-client-sys" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fb391ac70462b3097a755618fbf9c8f95ecc1eb379a414f7b46f202ed10db1f" +dependencies = [ + "cc", + "windows-targets 0.52.6", +] + [[package]] name = "trait-variant" version = "0.1.2" @@ -18519,15 +17802,6 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" -[[package]] -name = "unicode-reverse" -version = "1.0.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b6f4888ebc23094adfb574fdca9fdc891826287a6397d2cd28802ffd6f20c76" -dependencies = [ - "unicode-segmentation", -] - [[package]] name = "unicode-script" version = "0.5.7" @@ -18588,12 +17862,6 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" -[[package]] -name = "unty" -version = "0.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" - [[package]] name = "url" version = "2.5.7" @@ -18869,12 +18137,6 @@ dependencies = [ "settings", ] -[[package]] -name = "virtue" -version = "0.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" - [[package]] name = "vscode_theme" version = "0.2.0" @@ -21030,12 +20292,6 @@ dependencies = [ "toml_edit 0.22.27", ] -[[package]] -name = "xxhash-rust" -version = "0.8.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" - [[package]] name = "yaml-rust2" version = "0.8.1" @@ -21223,7 +20479,7 @@ dependencies = [ "audio", "auto_update", "auto_update_ui", - "bincode 1.3.3", + "bincode", "breadcrumbs", "call", "channel", @@ -21245,7 +20501,8 @@ dependencies = [ "debugger_tools", "debugger_ui", "diagnostics", - "edit_prediction_button", + "edit_prediction", + "edit_prediction_ui", "editor", "env_logger 0.11.8", "extension", @@ -21336,6 +20593,7 @@ dependencies = [ "time", "title_bar", "toolchain_selector", + "tracing", "tree-sitter-md", "tree-sitter-rust", "ui", @@ -21356,10 +20614,9 @@ dependencies = [ "zed-reqwest", "zed_actions", "zed_env_vars", - "zeta", - "zeta2_tools", "zlog", "zlog_settings", + "ztracing", ] [[package]] @@ -21669,151 +20926,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "zeta" -version = "0.1.0" -dependencies = [ - "ai_onboarding", - "anyhow", - "arrayvec", - "brotli", - "buffer_diff", - "client", - "clock", - "cloud_api_types", - "cloud_llm_client", - "cloud_zeta2_prompt", - "collections", - "command_palette_hooks", - "copilot", - "credentials_provider", - "ctor", - "db", - "edit_prediction", - "edit_prediction_context", - "editor", - "feature_flags", - "fs", - "futures 0.3.31", - "gpui", - "indoc", - "itertools 0.14.0", - "language", - "language_model", - "log", - "lsp", - "markdown", - "menu", - "open_ai", - "parking_lot", - "postage", - "pretty_assertions", - "project", - "rand 0.9.2", - "regex", - "release_channel", - "semver", - "serde", - "serde_json", - "settings", - "smol", - "strsim", - "strum 0.27.2", - "telemetry", - "telemetry_events", - "theme", - "thiserror 2.0.17", - "ui", - "util", - "uuid", - "workspace", - "worktree", - "zed_actions", - "zlog", -] - -[[package]] -name = "zeta2_tools" -version = "0.1.0" -dependencies = [ - "anyhow", - "clap", - "client", - "cloud_llm_client", - "cloud_zeta2_prompt", - "collections", - "edit_prediction_context", - "editor", - 
"feature_flags", - "futures 0.3.31", - "gpui", - "indoc", - "language", - "multi_buffer", - "pretty_assertions", - "project", - "serde", - "serde_json", - "settings", - "telemetry", - "text", - "ui", - "ui_input", - "util", - "workspace", - "zeta", - "zlog", -] - -[[package]] -name = "zeta_cli" -version = "0.1.0" -dependencies = [ - "anyhow", - "chrono", - "clap", - "client", - "cloud_llm_client", - "cloud_zeta2_prompt", - "collections", - "debug_adapter_extension", - "edit_prediction_context", - "extension", - "fs", - "futures 0.3.31", - "gpui", - "gpui_tokio", - "indoc", - "language", - "language_extension", - "language_model", - "language_models", - "languages", - "log", - "node_runtime", - "ordered-float 2.10.1", - "paths", - "polars", - "pretty_assertions", - "project", - "prompt_store", - "pulldown-cmark 0.12.2", - "release_channel", - "reqwest_client", - "serde", - "serde_json", - "settings", - "shellexpand 2.1.2", - "smol", - "soa-rs", - "terminal_view", - "toml 0.8.23", - "util", - "watch", - "zeta", - "zlog", -] - [[package]] name = "zip" version = "0.6.6" @@ -21823,7 +20935,7 @@ dependencies = [ "aes", "byteorder", "bzip2", - "constant_time_eq 0.1.5", + "constant_time_eq", "crc32fast", "crossbeam-utils", "flate2", @@ -21831,7 +20943,7 @@ dependencies = [ "pbkdf2 0.11.0", "sha1", "time", - "zstd 0.11.2+zstd.1.5.2", + "zstd", ] [[package]] @@ -21849,12 +20961,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "zlib-rs" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f06ae92f42f5e5c42443fd094f245eb656abf56dd7cce9b8b263236565e00f2" - [[package]] name = "zlog" version = "0.1.0" @@ -21882,16 +20988,7 @@ version = "0.11.2+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" dependencies = [ - "zstd-safe 5.0.2+zstd.1.5.2", -] - -[[package]] -name = "zstd" -version = "0.13.3" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" -dependencies = [ - "zstd-safe 7.2.4", + "zstd-safe", ] [[package]] @@ -21904,15 +21001,6 @@ dependencies = [ "zstd-sys", ] -[[package]] -name = "zstd-safe" -version = "7.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" -dependencies = [ - "zstd-sys", -] - [[package]] name = "zstd-sys" version = "2.0.16+zstd.1.5.7" @@ -21923,6 +21011,20 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "ztracing" +version = "0.1.0" +dependencies = [ + "tracing", + "tracing-subscriber", + "tracing-tracy", + "ztracing_macro", +] + +[[package]] +name = "ztracing_macro" +version = "0.1.0" + [[package]] name = "zune-core" version = "0.4.12" diff --git a/Cargo.toml b/Cargo.toml index 59b9a53d4a60b28582625fb90b64b934079cdc40..0ad4d2b14523988aa0dd6e3bfc935f84bcd0d8d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,9 +54,9 @@ members = [ "crates/diagnostics", "crates/docs_preprocessor", "crates/edit_prediction", - "crates/edit_prediction_button", + "crates/edit_prediction_types", + "crates/edit_prediction_ui", "crates/edit_prediction_context", - "crates/zeta2_tools", "crates/editor", "crates/eval", "crates/eval_utils", @@ -201,10 +201,11 @@ members = [ "crates/zed", "crates/zed_actions", "crates/zed_env_vars", - "crates/zeta", - "crates/zeta_cli", + "crates/edit_prediction_cli", "crates/zlog", "crates/zlog_settings", + "crates/ztracing", + "crates/ztracing_macro", # # Extensions @@ -243,7 +244,6 @@ activity_indicator = { path = "crates/activity_indicator" } agent_ui = { path = "crates/agent_ui" } agent_settings = { path = "crates/agent_settings" } agent_servers = { path = "crates/agent_servers" } -ai = { path = "crates/ai" } ai_onboarding = { path = "crates/ai_onboarding" } anthropic = { path = "crates/anthropic" } askpass = { path = "crates/askpass" } 
@@ -253,7 +253,6 @@ assistant_slash_command = { path = "crates/assistant_slash_command" } assistant_slash_commands = { path = "crates/assistant_slash_commands" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } -auto_update_helper = { path = "crates/auto_update_helper" } auto_update_ui = { path = "crates/auto_update_ui" } aws_http_client = { path = "crates/aws_http_client" } bedrock = { path = "crates/bedrock" } @@ -268,7 +267,6 @@ cloud_api_client = { path = "crates/cloud_api_client" } cloud_api_types = { path = "crates/cloud_api_types" } cloud_llm_client = { path = "crates/cloud_llm_client" } cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" } -collab = { path = "crates/collab" } collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections", version = "0.1.0" } command_palette = { path = "crates/command_palette" } @@ -313,10 +311,9 @@ http_client = { path = "crates/http_client" } http_client_tls = { path = "crates/http_client_tls" } icons = { path = "crates/icons" } image_viewer = { path = "crates/image_viewer" } -edit_prediction = { path = "crates/edit_prediction" } -edit_prediction_button = { path = "crates/edit_prediction_button" } +edit_prediction_types = { path = "crates/edit_prediction_types" } +edit_prediction_ui = { path = "crates/edit_prediction_ui" } edit_prediction_context = { path = "crates/edit_prediction_context" } -zeta2_tools = { path = "crates/zeta2_tools" } inspector_ui = { path = "crates/inspector_ui" } install_cli = { path = "crates/install_cli" } journal = { path = "crates/journal" } @@ -358,8 +355,6 @@ panel = { path = "crates/panel" } paths = { path = "crates/paths" } perf = { path = "tooling/perf" } picker = { path = "crates/picker" } -plugin = { path = "crates/plugin" } -plugin_macros = { path = "crates/plugin_macros" } prettier = { path = "crates/prettier" } settings_profile_selector = { path = "crates/settings_profile_selector" } project = { path = "crates/project" } @@ 
-370,12 +365,10 @@ proto = { path = "crates/proto" } recent_projects = { path = "crates/recent_projects" } refineable = { path = "crates/refineable" } release_channel = { path = "crates/release_channel" } -scheduler = { path = "crates/scheduler" } remote = { path = "crates/remote" } remote_server = { path = "crates/remote_server" } repl = { path = "crates/repl" } reqwest_client = { path = "crates/reqwest_client" } -rich_text = { path = "crates/rich_text" } rodio = { git = "https://github.com/RustAudio/rodio", rev ="e2074c6c2acf07b57cf717e076bdda7a9ac6e70b", features = ["wav", "playback", "wav_output", "recording"] } rope = { path = "crates/rope" } rpc = { path = "crates/rpc" } @@ -392,7 +385,6 @@ snippets_ui = { path = "crates/snippets_ui" } sqlez = { path = "crates/sqlez" } sqlez_macros = { path = "crates/sqlez_macros" } story = { path = "crates/story" } -storybook = { path = "crates/storybook" } streaming_diff = { path = "crates/streaming_diff" } sum_tree = { path = "crates/sum_tree" } supermaven = { path = "crates/supermaven" } @@ -409,7 +401,6 @@ terminal_view = { path = "crates/terminal_view" } text = { path = "crates/text" } theme = { path = "crates/theme" } theme_extension = { path = "crates/theme_extension" } -theme_importer = { path = "crates/theme_importer" } theme_selector = { path = "crates/theme_selector" } time_format = { path = "crates/time_format" } title_bar = { path = "crates/title_bar" } @@ -433,15 +424,17 @@ x_ai = { path = "crates/x_ai" } zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } zed_env_vars = { path = "crates/zed_env_vars" } -zeta = { path = "crates/zeta" } +edit_prediction = { path = "crates/edit_prediction" } zlog = { path = "crates/zlog" } zlog_settings = { path = "crates/zlog_settings" } +ztracing = { path = "crates/ztracing" } +ztracing_macro = { path = "crates/ztracing_macro" } # # External crates # -agent-client-protocol = { version = "=0.8.0", features = ["unstable"] } +agent-client-protocol = { 
version = "=0.9.0", features = ["unstable"] } aho-corasick = "1.1" alacritty_terminal = "0.25.1-rc1" any_vec = "0.14" @@ -508,13 +501,11 @@ exec = "0.3.1" fancy-regex = "0.16.0" fork = "0.4.0" futures = "0.3" -futures-batch = "0.6.1" futures-lite = "1.13" gh-workflow = { git = "https://github.com/zed-industries/gh-workflow", rev = "09acfdf2bd5c1d6254abefd609c808ff73547b2c" } git2 = { version = "0.20.1", default-features = false } globset = "0.4" handlebars = "4.3" -hashbrown = "0.15.3" heck = "0.5" heed = { version = "0.21.0", features = ["read-txn-no-tls"] } hex = "0.4.3" @@ -550,7 +541,6 @@ nanoid = "0.4" nbformat = "0.15.0" nix = "0.29" num-format = "0.4.4" -num-traits = "0.2" objc = "0.2" objc2-foundation = { version = "=0.3.1", default-features = false, features = [ "NSArray", @@ -589,7 +579,6 @@ pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } pet-core = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } -pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" } @@ -629,7 +618,6 @@ scap = { git = "https://github.com/zed-industries/scap", rev = "4afea48c3b002197 schemars = { version = "1.0", features = ["indexmap2"] } semver = { version 
= "1.0", features = ["serde"] } serde = { version = "1.0.221", features = ["derive", "rc"] } -serde_derive = "1.0.221" serde_json = { version = "1.0.144", features = ["preserve_order", "raw_value"] } serde_json_lenient = { version = "0.2", features = [ "preserve_order", @@ -641,7 +629,6 @@ serde_urlencoded = "0.7" sha2 = "0.10" shellexpand = "2.1.0" shlex = "1.3.0" -similar = "2.6" simplelog = "0.12.2" slotmap = "1.0.6" smallvec = { version = "1.6", features = ["union"] } @@ -696,6 +683,7 @@ tree-sitter-ruby = "0.23" tree-sitter-rust = "0.24" tree-sitter-typescript = { git = "https://github.com/zed-industries/tree-sitter-typescript", rev = "e2c53597d6a5d9cf7bbe8dccde576fe1e46c5899" } # https://github.com/tree-sitter/tree-sitter-typescript/pull/347 tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", rev = "baff0b51c64ef6a1fb1f8390f3ad6015b83ec13a" } +tracing = "0.1.40" unicase = "2.6" unicode-script = "0.5.7" unicode-segmentation = "1.10" @@ -719,7 +707,6 @@ wasmtime-wasi = "29" wax = "0.6" which = "6.0.0" windows-core = "0.61" -wit-component = "0.221" yawc = "0.2.5" zeroize = "1.8" zstd = "0.11" @@ -801,20 +788,13 @@ settings_macros = { opt-level = 3 } sqlez_macros = { opt-level = 3, codegen-units = 1 } ui_macros = { opt-level = 3 } util_macros = { opt-level = 3 } -serde_derive = { opt-level = 3 } quote = { opt-level = 3 } syn = { opt-level = 3 } proc-macro2 = { opt-level = 3 } # proc-macros end taffy = { opt-level = 3 } -cranelift-codegen = { opt-level = 3 } -cranelift-codegen-meta = { opt-level = 3 } -cranelift-codegen-shared = { opt-level = 3 } resvg = { opt-level = 3 } -rustybuzz = { opt-level = 3 } -ttf-parser = { opt-level = 3 } -wasmtime-cranelift = { opt-level = 3 } wasmtime = { opt-level = 3 } # Build single-source-file crates with cg=1 as it helps make `cargo build` of a whole workspace a bit faster activity_indicator = { codegen-units = 1 } @@ -823,12 +803,11 @@ breadcrumbs = { codegen-units = 1 } collections = { 
codegen-units = 1 } command_palette = { codegen-units = 1 } command_palette_hooks = { codegen-units = 1 } -extension_cli = { codegen-units = 1 } feature_flags = { codegen-units = 1 } file_icons = { codegen-units = 1 } fsevent = { codegen-units = 1 } image_viewer = { codegen-units = 1 } -edit_prediction_button = { codegen-units = 1 } +edit_prediction_ui = { codegen-units = 1 } install_cli = { codegen-units = 1 } journal = { codegen-units = 1 } json_schema_store = { codegen-units = 1 } @@ -843,7 +822,6 @@ project_symbols = { codegen-units = 1 } refineable = { codegen-units = 1 } release_channel = { codegen-units = 1 } reqwest_client = { codegen-units = 1 } -rich_text = { codegen-units = 1 } session = { codegen-units = 1 } snippet = { codegen-units = 1 } snippets_ui = { codegen-units = 1 } diff --git a/assets/icons/git_branch_plus.svg b/assets/icons/git_branch_plus.svg new file mode 100644 index 0000000000000000000000000000000000000000..cf60ce66b4086ba57ef4c2e56f3554d548e863fc --- /dev/null +++ b/assets/icons/git_branch_plus.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/assets/icons/inception.svg b/assets/icons/inception.svg new file mode 100644 index 0000000000000000000000000000000000000000..77a96c0b390ab9f2fe89143c2a89ba916000fabc --- /dev/null +++ b/assets/icons/inception.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 5de5b9daae27113807cb6e97eda335a419f18ac9..54a4f331c0b0c59eca79065fe42c1a8ecbf646b7 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -41,7 +41,7 @@ "ctrl-f11": "debugger::StepInto", "shift-f11": "debugger::StepOut", "f11": "zed::ToggleFullScreen", - "ctrl-alt-z": "edit_prediction::RateCompletions", + "ctrl-alt-z": "edit_prediction::RatePredictions", "ctrl-alt-shift-i": "edit_prediction::ToggleMenu", "ctrl-alt-l": "lsp_tool::ToggleMenu" } @@ -616,8 +616,8 @@ "ctrl-alt-super-p": "settings_profile_selector::Toggle", "ctrl-t": 
"project_symbols::Toggle", "ctrl-p": "file_finder::Toggle", - "ctrl-tab": "tab_switcher::Toggle", "ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }], + "ctrl-tab": "tab_switcher::Toggle", "ctrl-e": "file_finder::Toggle", "f1": "command_palette::Toggle", "ctrl-shift-p": "command_palette::Toggle", @@ -1322,25 +1322,18 @@ } }, { - "context": "Zeta2Feedback > Editor", - "bindings": { - "enter": "editor::Newline", - "ctrl-enter up": "dev::Zeta2RatePredictionPositive", - "ctrl-enter down": "dev::Zeta2RatePredictionNegative" - } - }, - { - "context": "Zeta2Context > Editor", + "context": "EditPredictionContext > Editor", "bindings": { - "alt-left": "dev::Zeta2ContextGoBack", - "alt-right": "dev::Zeta2ContextGoForward" + "alt-left": "dev::EditPredictionContextGoBack", + "alt-right": "dev::EditPredictionContextGoForward" } }, { "context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)", "use_key_equivalents": true, "bindings": { - "ctrl-shift-backspace": "branch_picker::DeleteBranch" + "ctrl-shift-backspace": "branch_picker::DeleteBranch", + "ctrl-shift-i": "branch_picker::FilterRemotes" } } ] diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index 2fadafb6ca95f81de28165b23e4063dc7a0c38d8..060151c647e42370f5aa0be5d2fa186774c2574d 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -47,7 +47,7 @@ "cmd-m": "zed::Minimize", "fn-f": "zed::ToggleFullScreen", "ctrl-cmd-f": "zed::ToggleFullScreen", - "ctrl-cmd-z": "edit_prediction::RateCompletions", + "ctrl-cmd-z": "edit_prediction::RatePredictions", "ctrl-cmd-i": "edit_prediction::ToggleMenu", "ctrl-cmd-l": "lsp_tool::ToggleMenu", "ctrl-cmd-c": "editor::DisplayCursorNames" @@ -684,8 +684,8 @@ "ctrl-alt-cmd-p": "settings_profile_selector::Toggle", "cmd-t": "project_symbols::Toggle", "cmd-p": "file_finder::Toggle", - "ctrl-tab": "tab_switcher::Toggle", "ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }], + 
"ctrl-tab": "tab_switcher::Toggle", "cmd-shift-p": "command_palette::Toggle", "cmd-shift-m": "diagnostics::Deploy", "cmd-shift-e": "project_panel::ToggleFocus", @@ -1427,25 +1427,18 @@ } }, { - "context": "Zeta2Feedback > Editor", - "bindings": { - "enter": "editor::Newline", - "cmd-enter up": "dev::Zeta2RatePredictionPositive", - "cmd-enter down": "dev::Zeta2RatePredictionNegative" - } - }, - { - "context": "Zeta2Context > Editor", + "context": "EditPredictionContext > Editor", "bindings": { - "alt-left": "dev::Zeta2ContextGoBack", - "alt-right": "dev::Zeta2ContextGoForward" + "alt-left": "dev::EditPredictionContextGoBack", + "alt-right": "dev::EditPredictionContextGoForward" } }, { "context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)", "use_key_equivalents": true, "bindings": { - "cmd-shift-backspace": "branch_picker::DeleteBranch" + "cmd-shift-backspace": "branch_picker::DeleteBranch", + "cmd-shift-i": "branch_picker::FilterRemotes" } } ] diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json index 8cf77f65813701fd42e3a6948b660368a24fd4e4..d749ac56886860b0e80de27f942082639df0447b 100644 --- a/assets/keymaps/default-windows.json +++ b/assets/keymaps/default-windows.json @@ -24,7 +24,8 @@ "ctrl-alt-enter": ["picker::ConfirmInput", { "secondary": true }], "ctrl-shift-w": "workspace::CloseWindow", "shift-escape": "workspace::ToggleZoom", - "ctrl-o": "workspace::Open", + "ctrl-o": "workspace::OpenFiles", + "ctrl-k ctrl-o": "workspace::Open", "ctrl-=": ["zed::IncreaseBufferFontSize", { "persist": false }], "ctrl-shift-=": ["zed::IncreaseBufferFontSize", { "persist": false }], "ctrl--": ["zed::DecreaseBufferFontSize", { "persist": false }], @@ -608,8 +609,8 @@ "ctrl-alt-super-p": "settings_profile_selector::Toggle", "ctrl-t": "project_symbols::Toggle", "ctrl-p": "file_finder::Toggle", - "ctrl-tab": "tab_switcher::Toggle", "ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }], + "ctrl-tab": 
"tab_switcher::Toggle", "ctrl-e": "file_finder::Toggle", "f1": "command_palette::Toggle", "ctrl-shift-p": "command_palette::Toggle", @@ -1128,6 +1129,8 @@ "ctrl-e": ["terminal::SendKeystroke", "ctrl-e"], "ctrl-o": ["terminal::SendKeystroke", "ctrl-o"], "ctrl-w": ["terminal::SendKeystroke", "ctrl-w"], + "ctrl-q": ["terminal::SendKeystroke", "ctrl-q"], + "ctrl-r": ["terminal::SendKeystroke", "ctrl-r"], "ctrl-backspace": ["terminal::SendKeystroke", "ctrl-w"], "ctrl-shift-a": "editor::SelectAll", "ctrl-shift-f": "buffer_search::Deploy", @@ -1341,25 +1344,18 @@ } }, { - "context": "Zeta2Feedback > Editor", - "bindings": { - "enter": "editor::Newline", - "ctrl-enter up": "dev::Zeta2RatePredictionPositive", - "ctrl-enter down": "dev::Zeta2RatePredictionNegative" - } - }, - { - "context": "Zeta2Context > Editor", + "context": "EditPredictionContext > Editor", "bindings": { - "alt-left": "dev::Zeta2ContextGoBack", - "alt-right": "dev::Zeta2ContextGoForward" + "alt-left": "dev::EditPredictionContextGoBack", + "alt-right": "dev::EditPredictionContextGoForward" } }, { "context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)", "use_key_equivalents": true, "bindings": { - "ctrl-shift-backspace": "branch_picker::DeleteBranch" + "ctrl-shift-backspace": "branch_picker::DeleteBranch", + "ctrl-shift-i": "branch_picker::FilterRemotes" } } ] diff --git a/assets/prompts/content_prompt_v2.hbs b/assets/prompts/content_prompt_v2.hbs new file mode 100644 index 0000000000000000000000000000000000000000..e1b6ddc6f023e9e97c9bb851473ac02e989c8feb --- /dev/null +++ b/assets/prompts/content_prompt_v2.hbs @@ -0,0 +1,44 @@ +{{#if language_name}} +Here's a file of {{language_name}} that the user is going to ask you to make an edit to. +{{else}} +Here's a file of text that the user is going to ask you to make an edit to. +{{/if}} + +The section you'll need to rewrite is marked with tags. 
+ + +{{{document_content}}} + + +{{#if is_truncated}} +The context around the relevant section has been truncated (possibly in the middle of a line) for brevity. +{{/if}} + +{{#if rewrite_section}} +And here's the section to rewrite based on that prompt again for reference: + + +{{{rewrite_section}}} + + +{{#if diagnostic_errors}} +Below are the diagnostic errors visible to the user. If the user requests problems to be fixed, use this information, but do not try to fix these errors if the user hasn't asked you to. + +{{#each diagnostic_errors}} + + {{line_number}} + {{error_message}} + {{code_content}} + +{{/each}} +{{/if}} + +{{/if}} + +Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved. + +Start at the indentation level in the original file in the rewritten {{content_type}}. + +You must use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled. It is an error if +you simply send back unstructured text. If you need to make a statement or ask a question you must use one of the tools to do so. +It is an error if you try to make a change that cannot be made simply by editing the rewrite_section. 
diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 9c7590ccd6c5871c4db72b89eff344b3eca877a7..b96ef1d898086a0b4b9336a21d1d8369fea4ad6c 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -2929,7 +2929,7 @@ mod tests { .await .unwrap_err(); - assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code); + assert_eq!(err.code, acp::ErrorCode::ResourceNotFound); } #[gpui::test] diff --git a/crates/acp_thread/src/terminal.rs b/crates/acp_thread/src/terminal.rs index fb9115650d1277e7e9982bfc851d8df142f048ad..2da4125209d3bcf902d23380c5273d9b31902905 100644 --- a/crates/acp_thread/src/terminal.rs +++ b/crates/acp_thread/src/terminal.rs @@ -75,15 +75,9 @@ impl Terminal { let exit_status = exit_status.map(portable_pty::ExitStatus::from); - let mut status = acp::TerminalExitStatus::new(); - - if let Some(exit_status) = exit_status.as_ref() { - status = status.exit_code(exit_status.exit_code()); - if let Some(signal) = exit_status.signal() { - status = status.signal(signal); - } - } - status + acp::TerminalExitStatus::new() + .exit_code(exit_status.as_ref().map(|e| e.exit_code())) + .signal(exit_status.and_then(|e| e.signal().map(ToOwned::to_owned))) }) .shared(), } @@ -105,19 +99,17 @@ impl Terminal { pub fn current_output(&self, cx: &App) -> acp::TerminalOutputResponse { if let Some(output) = self.output.as_ref() { - let mut exit_status = acp::TerminalExitStatus::new(); - if let Some(status) = output.exit_status.map(portable_pty::ExitStatus::from) { - exit_status = exit_status.exit_code(status.exit_code()); - if let Some(signal) = status.signal() { - exit_status = exit_status.signal(signal); - } - } + let exit_status = output.exit_status.map(portable_pty::ExitStatus::from); acp::TerminalOutputResponse::new( output.content.clone(), output.original_content_len > output.content.len(), ) - .exit_status(exit_status) + .exit_status( + acp::TerminalExitStatus::new() + 
.exit_code(exit_status.as_ref().map(|e| e.exit_code())) + .signal(exit_status.and_then(|e| e.signal().map(ToOwned::to_owned))), + ) } else { let (current_content, original_len) = self.truncated_output(cx); let truncated = current_content.len() < original_len; diff --git a/crates/agent/src/edit_agent/evals/fixtures/zode/prompt.md b/crates/agent/src/edit_agent/evals/fixtures/zode/prompt.md index 902e43857c3214cde68372f1c9ff5f8015528ae2..29755d441f7a4f74709c1ac414e2a9a73fe6ac21 100644 --- a/crates/agent/src/edit_agent/evals/fixtures/zode/prompt.md +++ b/crates/agent/src/edit_agent/evals/fixtures/zode/prompt.md @@ -2,12 +2,12 @@ - We're starting from a completely blank project - Like Aider/Claude Code you take the user's initial prompt and then call the LLM and perform tool calls in a loop until the ultimate goal is achieved. - Unlike Aider or Claude code, it's not intended to be interactive. Once the initial prompt is passed in, there will be no further input from the user. -- The system you will build must reach the stated goal just by performing too calls and calling the LLM +- The system you will build must reach the stated goal just by performing tool calls and calling the LLM - I want you to build this in python. Use the anthropic python sdk and the model context protocol sdk. Use a virtual env and pip to install dependencies - Follow the anthropic guidance on tool calls: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview - Use this Anthropic model: `claude-3-7-sonnet-20250219` - Use this Anthropic API Key: `sk-ant-api03-qweeryiofdjsncmxquywefidopsugus` -- One of the most important pieces to this is having good too calls. We will be using the tools provided by the Claude MCP server. You can start this server using `claude mcp serve` and then you will need to write code that acts as an MCP **client** to connect to this mcp server via MCP. Likely you want to start this using a subprocess. 
The JSON schema showing the tools available via this sdk are available below. Via this MCP server you have access to all the tools that zode needs: Bash, GlobTool, GrepTool, LS, View, Edit, Replace, WebFetchTool +- One of the most important pieces to this is having good tool calls. We will be using the tools provided by the Claude MCP server. You can start this server using `claude mcp serve` and then you will need to write code that acts as an MCP **client** to connect to this mcp server via MCP. Likely you want to start this using a subprocess. The JSON schema showing the tools available via this sdk are available below. Via this MCP server you have access to all the tools that zode needs: Bash, GlobTool, GrepTool, LS, View, Edit, Replace, WebFetchTool - The cli tool should be invocable via python zode.py file.md where file.md is any possible file that contains the users prompt. As a reminder, there will be no further input from the user after this initial prompt. Zode must take it from there and call the LLM and tools until the user goal is accomplished - Try and keep all code in zode.py and make heavy use of the asks I mentioned - Once you’ve implemented this, you must run python zode.py eval/instructions.md to see how well our new agent tool does! 
diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 5948200dd796a336cbccbc1644c3bb200960de51..9ff870353279635957cd2b84f418f881c3444aa2 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -2094,7 +2094,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { "1", acp::ToolCallUpdateFields::new() .status(acp::ToolCallStatus::Completed) - .raw_output("Finished thinking.".into()) + .raw_output("Finished thinking.") ) ); } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index da95c4294757a23960d6c5c78aa905e63834debb..4aabf8069bc3380b6908187b28517f99a9548f26 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -766,20 +766,22 @@ impl Thread { .log_err(); } - let mut fields = acp::ToolCallUpdateFields::new().status(tool_result.as_ref().map_or( - acp::ToolCallStatus::Failed, - |result| { - if result.is_error { - acp::ToolCallStatus::Failed - } else { - acp::ToolCallStatus::Completed - } - }, - )); - if let Some(output) = output { - fields = fields.raw_output(output); - } - stream.update_tool_call_fields(&tool_use.id, fields); + stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields::new() + .status( + tool_result + .as_ref() + .map_or(acp::ToolCallStatus::Failed, |result| { + if result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + } + }), + ) + .raw_output(output), + ); } pub fn from_db( @@ -1259,15 +1261,16 @@ impl Thread { while let Some(tool_result) = tool_results.next().await { log::debug!("Tool finished {:?}", tool_result); - let mut fields = acp::ToolCallUpdateFields::new().status(if tool_result.is_error { - acp::ToolCallStatus::Failed - } else { - acp::ToolCallStatus::Completed - }); - if let Some(output) = &tool_result.output { - fields = fields.raw_output(output.clone()); - } - event_stream.update_tool_call_fields(&tool_result.tool_use_id, fields); + 
event_stream.update_tool_call_fields( + &tool_result.tool_use_id, + acp::ToolCallUpdateFields::new() + .status(if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }) + .raw_output(tool_result.output.clone()), + ); this.update(cx, |this, _cx| { this.pending_message() .tool_results @@ -1545,7 +1548,7 @@ impl Thread { event_stream.update_tool_call_fields( &tool_use.id, acp::ToolCallUpdateFields::new() - .title(title) + .title(title.as_str()) .kind(kind) .raw_input(tool_use.input.clone()), ); @@ -2461,7 +2464,7 @@ impl ToolCallEventStream { ToolCallAuthorization { tool_call: acp::ToolCallUpdate::new( self.tool_use_id.to_string(), - acp::ToolCallUpdateFields::new().title(title), + acp::ToolCallUpdateFields::new().title(title.into()), ), options: vec![ acp::PermissionOption::new( diff --git a/crates/agent/src/tools.rs b/crates/agent/src/tools.rs index 1d3c0d557716ec3a52f910971547df4ee764cab0..62a52998a705e11d1c9e69cbade7f427cc9cfc32 100644 --- a/crates/agent/src/tools.rs +++ b/crates/agent/src/tools.rs @@ -4,6 +4,7 @@ mod create_directory_tool; mod delete_path_tool; mod diagnostics_tool; mod edit_file_tool; + mod fetch_tool; mod find_path_tool; mod grep_tool; @@ -12,6 +13,7 @@ mod move_path_tool; mod now_tool; mod open_tool; mod read_file_tool; + mod terminal_tool; mod thinking_tool; mod web_search_tool; @@ -25,6 +27,7 @@ pub use create_directory_tool::*; pub use delete_path_tool::*; pub use diagnostics_tool::*; pub use edit_file_tool::*; + pub use fetch_tool::*; pub use find_path_tool::*; pub use grep_tool::*; @@ -33,6 +36,7 @@ pub use move_path_tool::*; pub use now_tool::*; pub use open_tool::*; pub use read_file_tool::*; + pub use terminal_tool::*; pub use thinking_tool::*; pub use web_search_tool::*; diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index cbe96a6b20d6e325beb9aedb6cf6d2eca1df171a..0ab99426e2e9645adf3f837d21c28dc285ab6ea2 100644 --- 
a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -384,11 +384,7 @@ impl AgentTool for EditFileTool { range.start.to_point(&buffer.snapshot()).row }).ok(); if let Some(abs_path) = abs_path.clone() { - let mut location = ToolCallLocation::new(abs_path); - if let Some(line) = line { - location = location.line(line); - } - event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![location])); + event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)])); } emitted_location = true; } diff --git a/crates/agent/src/tools/find_path_tool.rs b/crates/agent/src/tools/find_path_tool.rs index 3c34f14c3a78f0fa8a6ee6794ef2567fe13d5d3c..2a33b14b4c87d87154e2aa1ee25363b397189f89 100644 --- a/crates/agent/src/tools/find_path_tool.rs +++ b/crates/agent/src/tools/find_path_tool.rs @@ -138,7 +138,7 @@ impl AgentTool for FindPathTool { )), )) }) - .collect(), + .collect::>(), ), ); diff --git a/crates/agent/src/tools/grep_tool.rs b/crates/agent/src/tools/grep_tool.rs index ec61b013e87ccb3afc133ee0a264e55a6d8baee9..0caba91564fd1fc9e670909490d4e776b8ad6f11 100644 --- a/crates/agent/src/tools/grep_tool.rs +++ b/crates/agent/src/tools/grep_tool.rs @@ -322,7 +322,6 @@ mod tests { use super::*; use gpui::{TestAppContext, UpdateGlobal}; - use language::{Language, LanguageConfig, LanguageMatcher}; use project::{FakeFs, Project}; use serde_json::json; use settings::SettingsStore; @@ -564,7 +563,7 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; project.update(cx, |project, _cx| { - project.languages().add(rust_lang().into()) + project.languages().add(language::rust_lang()) }); project @@ -793,22 +792,6 @@ mod tests { }); } - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - 
Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../../languages/src/rust/outline.scm")) - .unwrap() - } - #[gpui::test] async fn test_grep_security_boundaries(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/agent/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs index 4457a6e5ca21a2fc88c76c718160d1d59171e66a..acfd4a16746fc1f78fd388f5dacf3e360f070ab5 100644 --- a/crates/agent/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -152,12 +152,11 @@ impl AgentTool for ReadFileTool { } let file_path = input.path.clone(); - let mut location = acp::ToolCallLocation::new(&abs_path); - if let Some(line) = input.start_line { - location = location.line(line.saturating_sub(1)); - } - event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![location])); + event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ + acp::ToolCallLocation::new(&abs_path) + .line(input.start_line.map(|line| line.saturating_sub(1))), + ])); if image_store::is_image_file(&self.project, &project_path, cx) { return cx.spawn(async move |cx| { @@ -302,7 +301,6 @@ mod test { use super::*; use crate::{ContextServerRegistry, Templates, Thread}; use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; - use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; use language_model::fake_provider::FakeLanguageModel; use project::{FakeFs, Project}; use prompt_store::ProjectContext; @@ -406,7 +404,7 @@ mod test { .await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - language_registry.add(Arc::new(rust_lang())); + language_registry.add(language::rust_lang()); let action_log = cx.new(|_| ActionLog::new(project.clone())); let context_server_registry = cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); @@ -596,49 +594,6 
@@ mod test { }); } - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query( - r#" - (line_comment) @annotation - - (struct_item - "struct" @context - name: (_) @name) @item - (enum_item - "enum" @context - name: (_) @name) @item - (enum_variant - name: (_) @name) @item - (field_declaration - name: (_) @name) @item - (impl_item - "impl" @context - trait: (_)? @name - "for"? @context - type: (_) @name - body: (_ "{" (_)* "}")) @item - (function_item - "fn" @context - name: (_) @name) @item - (mod_item - "mod" @context - name: (_) @name) @item - "#, - ) - .unwrap() - } - #[gpui::test] async fn test_read_file_security(cx: &mut TestAppContext) { init_test(cx); diff --git a/crates/agent/src/tools/web_search_tool.rs b/crates/agent/src/tools/web_search_tool.rs index d78b692126f62d6ed7fd00f585618ab6b6ba55e2..eb4ebacea2a8e48d6efa9032f46b336ca30c39b6 100644 --- a/crates/agent/src/tools/web_search_tool.rs +++ b/crates/agent/src/tools/web_search_tool.rs @@ -121,7 +121,7 @@ fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) ), )) }) - .collect(), + .collect::>(), ), ); } diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index f035e981919deb2fa15069866507abd8be0ac209..153357a79afdaeeb4bf4c9e2b48bee32245ba2ef 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -173,10 +173,6 @@ impl AcpConnection { }); })?; - let mut client_info = acp::Implementation::new("zed", version); - if let Some(release_channel) = release_channel { - client_info = client_info.title(release_channel); - } let response = connection .initialize( acp::InitializeRequest::new(acp::ProtocolVersion::V1) @@ -192,7 +188,10 @@ impl AcpConnection { ("terminal-auth".into(), true.into()), ])), ) - 
.client_info(client_info), + .client_info( + acp::Implementation::new("zed", version) + .title(release_channel.map(ToOwned::to_owned)), + ), ) .await?; @@ -302,10 +301,10 @@ impl AgentConnection for AcpConnection { .new_session(acp::NewSessionRequest::new(cwd).mcp_servers(mcp_servers)) .await .map_err(|err| { - if err.code == acp::ErrorCode::AUTH_REQUIRED.code { + if err.code == acp::ErrorCode::AuthRequired { let mut error = AuthRequired::new(); - if err.message != acp::ErrorCode::AUTH_REQUIRED.message { + if err.message != acp::ErrorCode::AuthRequired.to_string() { error = error.with_description(err.message); } @@ -467,11 +466,11 @@ impl AgentConnection for AcpConnection { match result { Ok(response) => Ok(response), Err(err) => { - if err.code == acp::ErrorCode::AUTH_REQUIRED.code { + if err.code == acp::ErrorCode::AuthRequired { return Err(anyhow!(acp::Error::auth_required())); } - if err.code != ErrorCode::INTERNAL_ERROR.code { + if err.code != ErrorCode::InternalError { anyhow::bail!(err) } @@ -838,13 +837,18 @@ impl acp::Client for ClientDelegate { if let Some(term_exit) = meta.get("terminal_exit") { if let Some(id_str) = term_exit.get("terminal_id").and_then(|v| v.as_str()) { let terminal_id = acp::TerminalId::new(id_str); - let mut status = acp::TerminalExitStatus::new(); - if let Some(code) = term_exit.get("exit_code").and_then(|v| v.as_u64()) { - status = status.exit_code(code as u32) - } - if let Some(signal) = term_exit.get("signal").and_then(|v| v.as_str()) { - status = status.signal(signal); - } + let status = acp::TerminalExitStatus::new() + .exit_code( + term_exit + .get("exit_code") + .and_then(|v| v.as_u64()) + .map(|i| i as u32), + ) + .signal( + term_exit + .get("signal") + .and_then(|v| v.as_str().map(|s| s.to_string())), + ); let _ = session.thread.update(&mut self.cx.clone(), |thread, cx| { thread.on_terminal_provider_event( diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs index 
53f24947658be8def877eb6b3a7d4e29b541d0c0..feae74a86bc241c5d2e01f0941eafc60210f1bf6 100644 --- a/crates/agent_ui/src/acp/entry_view_state.rs +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -22,7 +22,7 @@ use crate::acp::message_editor::{MessageEditor, MessageEditorEvent}; pub struct EntryViewState { workspace: WeakEntity, - project: Entity, + project: WeakEntity, history_store: Entity, prompt_store: Option>, entries: Vec, @@ -34,7 +34,7 @@ pub struct EntryViewState { impl EntryViewState { pub fn new( workspace: WeakEntity, - project: Entity, + project: WeakEntity, history_store: Entity, prompt_store: Option>, prompt_capabilities: Rc>, @@ -328,7 +328,7 @@ impl Entry { fn create_terminal( workspace: WeakEntity, - project: Entity, + project: WeakEntity, terminal: Entity, window: &mut Window, cx: &mut App, @@ -336,9 +336,9 @@ fn create_terminal( cx.new(|cx| { let mut view = TerminalView::new( terminal.read(cx).inner().clone(), - workspace.clone(), + workspace, None, - project.downgrade(), + project, window, cx, ); @@ -458,7 +458,7 @@ mod tests { let view_state = cx.new(|_cx| { EntryViewState::new( workspace.downgrade(), - project.clone(), + project.downgrade(), history_store, None, Default::default(), diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index ae634e45dc17cc471d9ac621faf5b98c0a754c2b..a0aca0c51bd0afe1ea61f6a58c3585d68172ec24 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -21,8 +21,8 @@ use editor::{ }; use futures::{FutureExt as _, future::join_all}; use gpui::{ - AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, ImageFormat, KeyContext, - SharedString, Subscription, Task, TextStyle, WeakEntity, + AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, ImageFormat, + KeyContext, SharedString, Subscription, Task, TextStyle, WeakEntity, }; use language::{Buffer, Language, 
language_settings::InlayHintKind}; use project::{CompletionIntent, InlayHint, InlayHintLabel, InlayId, Project, Worktree}; @@ -39,7 +39,6 @@ use zed_actions::agent::Chat; pub struct MessageEditor { mention_set: Entity, editor: Entity, - project: Entity, workspace: WeakEntity, prompt_capabilities: Rc>, available_commands: Rc>>, @@ -98,7 +97,7 @@ impl PromptCompletionProviderDelegate for Entity { impl MessageEditor { pub fn new( workspace: WeakEntity, - project: Entity, + project: WeakEntity, history_store: Entity, prompt_store: Option>, prompt_capabilities: Rc>, @@ -124,6 +123,7 @@ impl MessageEditor { let mut editor = Editor::new(mode, buffer, None, window, cx); editor.set_placeholder_text(placeholder, window, cx); editor.set_show_indent_guides(false, cx); + editor.set_show_completions_on_input(Some(true)); editor.set_soft_wrap(); editor.set_use_modal_editing(true); editor.set_context_menu_options(ContextMenuOptions { @@ -134,13 +134,8 @@ impl MessageEditor { editor.register_addon(MessageEditorAddon::new()); editor }); - let mention_set = cx.new(|_cx| { - MentionSet::new( - project.downgrade(), - history_store.clone(), - prompt_store.clone(), - ) - }); + let mention_set = + cx.new(|_cx| MentionSet::new(project, history_store.clone(), prompt_store.clone())); let completion_provider = Rc::new(PromptCompletionProvider::new( cx.entity(), editor.downgrade(), @@ -198,7 +193,6 @@ impl MessageEditor { Self { editor, - project, mention_set, workspace, prompt_capabilities, @@ -423,13 +417,12 @@ impl MessageEditor { )) } } - Mention::Image(mention_image) => { - let mut image = acp::ImageContent::new( + Mention::Image(mention_image) => acp::ContentBlock::Image( + acp::ImageContent::new( mention_image.data.clone(), mention_image.format.mime_type(), - ); - - if let Some(uri) = match uri { + ) + .uri(match uri { MentionUri::File { .. 
} => Some(uri.to_uri().to_string()), MentionUri::PastedImage => None, other => { @@ -439,11 +432,8 @@ impl MessageEditor { ); None } - } { - image = image.uri(uri) - }; - acp::ContentBlock::Image(image) - } + }), + ), Mention::Link => acp::ContentBlock::ResourceLink( acp::ResourceLink::new(uri.name(), uri.to_uri().to_string()), ), @@ -553,6 +543,120 @@ impl MessageEditor { } fn paste(&mut self, _: &Paste, window: &mut Window, cx: &mut Context) { + let editor_clipboard_selections = cx + .read_from_clipboard() + .and_then(|item| item.entries().first().cloned()) + .and_then(|entry| match entry { + ClipboardEntry::String(text) => { + text.metadata_json::>() + } + _ => None, + }); + + let has_file_context = editor_clipboard_selections + .as_ref() + .is_some_and(|selections| { + selections + .iter() + .any(|sel| sel.file_path.is_some() && sel.line_range.is_some()) + }); + + if has_file_context { + if let Some((workspace, selections)) = + self.workspace.upgrade().zip(editor_clipboard_selections) + { + cx.stop_propagation(); + + let project = workspace.read(cx).project().clone(); + for selection in selections { + if let (Some(file_path), Some(line_range)) = + (selection.file_path, selection.line_range) + { + let crease_text = + acp_thread::selection_name(Some(file_path.as_ref()), &line_range); + + let mention_uri = MentionUri::Selection { + abs_path: Some(file_path.clone()), + line_range: line_range.clone(), + }; + + let mention_text = mention_uri.as_link().to_string(); + let (excerpt_id, text_anchor, content_len) = + self.editor.update(cx, |editor, cx| { + let buffer = editor.buffer().read(cx); + let snapshot = buffer.snapshot(cx); + let (excerpt_id, _, buffer_snapshot) = + snapshot.as_singleton().unwrap(); + let start_offset = buffer_snapshot.len(); + let text_anchor = buffer_snapshot.anchor_before(start_offset); + + editor.insert(&mention_text, window, cx); + editor.insert(" ", window, cx); + + (*excerpt_id, text_anchor, mention_text.len()) + }); + + let 
Some((crease_id, tx)) = insert_crease_for_mention( + excerpt_id, + text_anchor, + content_len, + crease_text.into(), + mention_uri.icon_path(cx), + None, + self.editor.clone(), + window, + cx, + ) else { + continue; + }; + drop(tx); + + let mention_task = cx + .spawn({ + let project = project.clone(); + async move |_, cx| { + let project_path = project + .update(cx, |project, cx| { + project.project_path_for_absolute_path(&file_path, cx) + }) + .map_err(|e| e.to_string())? + .ok_or_else(|| "project path not found".to_string())?; + + let buffer = project + .update(cx, |project, cx| { + project.open_buffer(project_path, cx) + }) + .map_err(|e| e.to_string())? + .await + .map_err(|e| e.to_string())?; + + buffer + .update(cx, |buffer, cx| { + let start = Point::new(*line_range.start(), 0) + .min(buffer.max_point()); + let end = Point::new(*line_range.end() + 1, 0) + .min(buffer.max_point()); + let content = + buffer.text_for_range(start..end).collect(); + Mention::Text { + content, + tracked_buffers: vec![cx.entity()], + } + }) + .map_err(|e| e.to_string()) + } + }) + .shared(); + + self.mention_set.update(cx, |mention_set, _cx| { + mention_set.insert_mention(crease_id, mention_uri.clone(), mention_task) + }); + } + } + return; + } + } + if self.prompt_capabilities.borrow().image && let Some(task) = paste_images_as_context(self.editor.clone(), self.mention_set.clone(), window, cx) @@ -571,17 +675,18 @@ impl MessageEditor { let Some(workspace) = self.workspace.upgrade() else { return; }; - let path_style = self.project.read(cx).path_style(cx); + let project = workspace.read(cx).project().clone(); + let path_style = project.read(cx).path_style(cx); let buffer = self.editor.read(cx).buffer().clone(); let Some(buffer) = buffer.read(cx).as_singleton() else { return; }; let mut tasks = Vec::new(); for path in paths { - let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else { + let Some(entry) = project.read(cx).entry_for_path(&path, cx) else { continue; }; - 
let Some(worktree) = self.project.read(cx).worktree_for_id(path.worktree_id, cx) else { + let Some(worktree) = project.read(cx).worktree_for_id(path.worktree_id, cx) else { continue; }; let abs_path = worktree.read(cx).absolutize(&path.path); @@ -689,9 +794,13 @@ impl MessageEditor { window: &mut Window, cx: &mut Context, ) { + let Some(workspace) = self.workspace.upgrade() else { + return; + }; + self.clear(window, cx); - let path_style = self.project.read(cx).path_style(cx); + let path_style = workspace.read(cx).project().read(cx).path_style(cx); let mut text = String::new(); let mut mentions = Vec::new(); @@ -934,7 +1043,7 @@ mod tests { cx.new(|cx| { MessageEditor::new( workspace.downgrade(), - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), @@ -1045,7 +1154,7 @@ mod tests { cx.new(|cx| { MessageEditor::new( workspace_handle.clone(), - project.clone(), + project.downgrade(), history_store.clone(), None, prompt_capabilities.clone(), @@ -1206,7 +1315,7 @@ mod tests { let message_editor = cx.new(|cx| { MessageEditor::new( workspace_handle, - project.clone(), + project.downgrade(), history_store.clone(), None, prompt_capabilities.clone(), @@ -1428,7 +1537,7 @@ mod tests { let message_editor = cx.new(|cx| { MessageEditor::new( workspace_handle, - project.clone(), + project.downgrade(), history_store.clone(), None, prompt_capabilities.clone(), @@ -1919,7 +2028,7 @@ mod tests { cx.new(|cx| { let editor = MessageEditor::new( workspace.downgrade(), - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), @@ -2024,7 +2133,7 @@ mod tests { cx.new(|cx| { let mut editor = MessageEditor::new( workspace.downgrade(), - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), @@ -2093,7 +2202,7 @@ mod tests { cx.new(|cx| { MessageEditor::new( workspace.downgrade(), - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), @@ -2156,7 +2265,7 
@@ mod tests { let message_editor = cx.new(|cx| { MessageEditor::new( workspace_handle, - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), @@ -2314,7 +2423,7 @@ mod tests { let message_editor = cx.new(|cx| { MessageEditor::new( workspace_handle, - project.clone(), + project.downgrade(), history_store.clone(), None, Default::default(), diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index aedb96bb82f07723f934d0ec73aa1fd545461f00..7b2954c2edafa6b610efe654c8c1368f04a374b6 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -100,7 +100,7 @@ impl ThreadError { { Self::ModelRequestLimitReached(error.plan) } else if let Some(acp_error) = error.downcast_ref::() - && acp_error.code == acp::ErrorCode::AUTH_REQUIRED.code + && acp_error.code == acp::ErrorCode::AuthRequired { Self::AuthenticationRequired(acp_error.message.clone().into()) } else { @@ -344,7 +344,7 @@ impl AcpThreadView { let message_editor = cx.new(|cx| { let mut editor = MessageEditor::new( workspace.clone(), - project.clone(), + project.downgrade(), history_store.clone(), prompt_store.clone(), prompt_capabilities.clone(), @@ -369,7 +369,7 @@ impl AcpThreadView { let entry_view_state = cx.new(|_| { EntryViewState::new( workspace.clone(), - project.clone(), + project.downgrade(), history_store.clone(), prompt_store.clone(), prompt_capabilities.clone(), @@ -6243,7 +6243,7 @@ pub(crate) mod tests { StubAgentConnection::new().with_permission_requests(HashMap::from_iter([( tool_call_id, vec![acp::PermissionOption::new( - "1".into(), + "1", "Allow", acp::PermissionOptionKind::AllowOnce, )], diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 3533c28caa93f82c96ecacdafdfdd3dc1b1643f1..617aea8fd553cb8e9c7cd5c9814bc420adec4df3 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ 
-36,7 +36,7 @@ use settings::{Settings, SettingsStore, update_settings_file}; use ui::{ Button, ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure, Divider, DividerColor, ElevationIndex, IconName, IconPosition, IconSize, Indicator, LabelSize, - PopoverMenu, Switch, SwitchColor, Tooltip, WithScrollbar, prelude::*, + PopoverMenu, Switch, Tooltip, WithScrollbar, prelude::*, }; use util::ResultExt as _; use workspace::{Workspace, create_and_open_local_file}; @@ -883,7 +883,6 @@ impl AgentConfiguration { .child(context_server_configuration_menu) .child( Switch::new("context-server-switch", is_running.into()) - .color(SwitchColor::Accent) .on_click({ let context_server_manager = self.context_server_store.clone(); let fs = self.fs.clone(); diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index 924f37db0440dd1d4ddbdb90bdf73dfe56f0cbad..685bd775424b120f753d3a0d444e5a966f940773 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -108,7 +108,7 @@ impl Render for AgentModelSelector { .child( Icon::new(IconName::ChevronDown) .color(color) - .size(IconSize::XSmall), + .size(IconSize::Small), ), move |_window, cx| { Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx) diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 972ead664464876e57d7830b18db3f2b0c49629c..f7e7884310458e97421768882df57934a19b4430 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -5,22 +5,26 @@ use client::telemetry::Telemetry; use cloud_llm_client::CompletionIntent; use collections::HashSet; use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint}; +use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag}; use futures::{ SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, 
future::{LocalBoxFuture, Shared}, join, }; -use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task}; +use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task}; use language::{Buffer, IndentKind, Point, TransactionId, line_diff}; use language_model::{ - LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, - LanguageModelTextStream, Role, report_assistant_event, + LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest, + LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role, + report_assistant_event, }; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use prompt_store::PromptBuilder; use rope::Rope; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use smol::future::FutureExt; use std::{ cmp, @@ -34,6 +38,29 @@ use std::{ }; use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff}; use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase}; +use ui::SharedString; + +/// Use this tool to provide a message to the user when you're unable to complete a task. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct FailureMessageInput { + /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request. + /// + /// The message may use markdown formatting if you wish. + pub message: String, +} + +/// Replaces text in tags with your replacement_text. +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct RewriteSectionInput { + /// A brief description of the edit you have made. + /// + /// The description may use markdown formatting if you wish. + /// This is optional - if the edit is simple or obvious, you should leave it empty. + pub description: String, + + /// The text to replace the section with. 
+ pub replacement_text: String, +} pub struct BufferCodegen { alternatives: Vec>, @@ -238,6 +265,7 @@ pub struct CodegenAlternative { elapsed_time: Option, completion: Option, pub message_id: Option, + pub model_explanation: Option, } impl EventEmitter for CodegenAlternative {} @@ -288,14 +316,15 @@ impl CodegenAlternative { generation: Task::ready(()), diff: Diff::default(), telemetry, - _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), builder, - active, + active: active, edits: Vec::new(), line_operations: Vec::new(), range, elapsed_time: None, completion: None, + model_explanation: None, + _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), } } @@ -358,18 +387,124 @@ impl CodegenAlternative { let api_key = model.api_key(cx); let telemetry_id = model.telemetry_id(); let provider_id = model.provider_id(); - let stream: LocalBoxFuture> = - if user_prompt.trim().to_lowercase() == "delete" { - async { Ok(LanguageModelTextStream::default()) }.boxed_local() + + if cx.has_flag::() { + let request = self.build_request(&model, user_prompt, context_task, cx)?; + let tool_use = + cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await); + self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx); + } else { + let stream: LocalBoxFuture> = + if user_prompt.trim().to_lowercase() == "delete" { + async { Ok(LanguageModelTextStream::default()) }.boxed_local() + } else { + let request = self.build_request(&model, user_prompt, context_task, cx)?; + cx.spawn(async move |_, cx| { + Ok(model.stream_completion_text(request.await, cx).await?) 
+ }) + .boxed_local() + }; + self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx); + } + + Ok(()) + } + + fn build_request_v2( + &self, + model: &Arc, + user_prompt: String, + context_task: Shared>>, + cx: &mut App, + ) -> Result> { + let buffer = self.buffer.read(cx).snapshot(cx); + let language = buffer.language_at(self.range.start); + let language_name = if let Some(language) = language.as_ref() { + if Arc::ptr_eq(language, &language::PLAIN_TEXT) { + None } else { - let request = self.build_request(&model, user_prompt, context_task, cx)?; - cx.spawn(async move |_, cx| { - Ok(model.stream_completion_text(request.await, cx).await?) - }) - .boxed_local() + Some(language.name()) + } + } else { + None + }; + + let language_name = language_name.as_ref(); + let start = buffer.point_to_buffer_offset(self.range.start); + let end = buffer.point_to_buffer_offset(self.range.end); + let (buffer, range) = if let Some((start, end)) = start.zip(end) { + let (start_buffer, start_buffer_offset) = start; + let (end_buffer, end_buffer_offset) = end; + if start_buffer.remote_id() == end_buffer.remote_id() { + (start_buffer.clone(), start_buffer_offset..end_buffer_offset) + } else { + anyhow::bail!("invalid transformation range"); + } + } else { + anyhow::bail!("invalid transformation range"); + }; + + let system_prompt = self + .builder + .generate_inline_transformation_prompt_v2( + language_name, + buffer, + range.start.0..range.end.0, + ) + .context("generating content prompt")?; + + let temperature = AgentSettings::temperature_for_model(model, cx); + + let tool_input_format = model.tool_input_format(); + + Ok(cx.spawn(async move |_cx| { + let mut messages = vec![LanguageModelRequestMessage { + role: Role::System, + content: vec![system_prompt.into()], + cache: false, + reasoning_details: None, + }]; + + let mut user_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), + cache: false, + reasoning_details: None, }; - 
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx); - Ok(()) + + if let Some(context) = context_task.await { + context.add_to_request_message(&mut user_message); + } + + user_message.content.push(user_prompt.into()); + messages.push(user_message); + + let tools = vec![ + LanguageModelRequestTool { + name: "rewrite_section".to_string(), + description: "Replaces text in tags with your replacement_text.".to_string(), + input_schema: language_model::tool_schema::root_schema_for::(tool_input_format).to_value(), + }, + LanguageModelRequestTool { + name: "failure_message".to_string(), + description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(), + input_schema: language_model::tool_schema::root_schema_for::(tool_input_format).to_value(), + }, + ]; + + LanguageModelRequest { + thread_id: None, + prompt_id: None, + intent: Some(CompletionIntent::InlineAssist), + mode: None, + tools, + tool_choice: None, + stop: Vec::new(), + temperature, + messages, + thinking_allowed: false, + } + })) } fn build_request( @@ -379,6 +514,10 @@ impl CodegenAlternative { context_task: Shared>>, cx: &mut App, ) -> Result> { + if cx.has_flag::() { + return self.build_request_v2(model, user_prompt, context_task, cx); + } + let buffer = self.buffer.read(cx).snapshot(cx); let language = buffer.language_at(self.range.start); let language_name = if let Some(language) = language.as_ref() { @@ -510,6 +649,7 @@ impl CodegenAlternative { self.generation = cx.spawn(async move |codegen, cx| { let stream = stream.await; + let token_usage = stream .as_ref() .ok() @@ -899,6 +1039,101 @@ impl CodegenAlternative { .ok(); }) } + + fn handle_tool_use( + &mut self, + _telemetry_id: String, + _provider_id: String, + _api_key: Option, + tool_use: impl 'static + + Future< + Output = Result, + >, + cx: &mut Context, + ) { + self.diff = Diff::default(); + self.status = CodegenStatus::Pending; + + self.generation = cx.spawn(async move 
|codegen, cx| { + let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| { + let _ = codegen.update(cx, |this, cx| { + this.status = status; + cx.emit(CodegenEvent::Finished); + cx.notify(); + }); + }; + + let tool_use = tool_use.await; + + match tool_use { + Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => { + // Parse the input JSON into RewriteSectionInput + match serde_json::from_value::(tool_use.input) { + Ok(input) => { + // Store the description if non-empty + let description = if !input.description.trim().is_empty() { + Some(input.description.clone()) + } else { + None + }; + + // Apply the replacement text to the buffer and compute diff + let batch_diff_task = codegen + .update(cx, |this, cx| { + this.model_explanation = description.map(Into::into); + let range = this.range.clone(); + this.apply_edits( + std::iter::once((range, input.replacement_text)), + cx, + ); + this.reapply_batch_diff(cx) + }) + .ok(); + + // Wait for the diff computation to complete + if let Some(diff_task) = batch_diff_task { + diff_task.await; + } + + finish_with_status(CodegenStatus::Done, cx); + return; + } + Err(e) => { + finish_with_status(CodegenStatus::Error(e.into()), cx); + return; + } + } + } + Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => { + // Handle failure message tool use + match serde_json::from_value::(tool_use.input) { + Ok(input) => { + let _ = codegen.update(cx, |this, _cx| { + // Store the failure message as the tool description + this.model_explanation = Some(input.message.into()); + }); + finish_with_status(CodegenStatus::Done, cx); + return; + } + Err(e) => { + finish_with_status(CodegenStatus::Error(e.into()), cx); + return; + } + } + } + Ok(_tool_use) => { + // Unexpected tool. 
+ finish_with_status(CodegenStatus::Done, cx); + return; + } + Err(e) => { + finish_with_status(CodegenStatus::Error(e.into()), cx); + return; + } + } + }); + cx.notify(); + } } #[derive(Copy, Clone, Debug)] @@ -1060,8 +1295,9 @@ mod tests { }; use gpui::TestAppContext; use indoc::indoc; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher, Point, tree_sitter_rust}; + use language::{Buffer, Point}; use language_model::{LanguageModelRegistry, TokenUsage}; + use languages::rust_lang; use rand::prelude::*; use settings::SettingsStore; use std::{future, sync::Arc}; @@ -1078,7 +1314,7 @@ mod tests { } } "}; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let range = buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); @@ -1140,7 +1376,7 @@ mod tests { le } "}; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let range = buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); @@ -1204,7 +1440,7 @@ mod tests { " \n", "}\n" // ); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let range = buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); @@ -1320,7 +1556,7 @@ mod tests { let x = 0; } "}; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); let range = 
buffer.read_with(cx, |buffer, cx| { let snapshot = buffer.snapshot(cx); @@ -1437,27 +1673,4 @@ mod tests { }); chunks_tx } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_indents_query( - r#" - (call_expression) @indent - (field_expression) @indent - (_ "(" ")" @end) @indent - (_ "{" "}" @end) @indent - "#, - ) - .unwrap() - } } diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index cbc5891036fdf03ee04cca6b77820748faed2d0a..48da85d38554da8227d76d3cbe290e29ef4fc531 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -387,17 +387,9 @@ impl InlineAssistant { let mut selections = Vec::>::new(); let mut newest_selection = None; for mut selection in initial_selections { - if selection.end > selection.start { - selection.start.column = 0; - // If the selection ends at the start of the line, we don't want to include it. - if selection.end.column == 0 { - selection.end.row -= 1; - } - selection.end.column = snapshot - .buffer_snapshot() - .line_len(MultiBufferRow(selection.end.row)); - } else if let Some(fold) = - snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row)) + if selection.end == selection.start + && let Some(fold) = + snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row)) { selection.start = fold.range().start; selection.end = fold.range().end; @@ -424,6 +416,15 @@ impl InlineAssistant { } } } + } else { + selection.start.column = 0; + // If the selection ends at the start of the line, we don't want to include it. 
+ if selection.end.column == 0 && selection.start.row != selection.end.row { + selection.end.row -= 1; + } + selection.end.column = snapshot + .buffer_snapshot() + .line_len(MultiBufferRow(selection.end.row)); } if let Some(prev_selection) = selections.last_mut() @@ -544,14 +545,15 @@ impl InlineAssistant { } } - let [prompt_block_id, end_block_id] = - self.insert_assist_blocks(editor, &range, &prompt_editor, cx); + let [prompt_block_id, tool_description_block_id, end_block_id] = + self.insert_assist_blocks(&editor, &range, &prompt_editor, cx); assists.push(( assist_id, range.clone(), prompt_editor, prompt_block_id, + tool_description_block_id, end_block_id, )); } @@ -570,7 +572,15 @@ impl InlineAssistant { }; let mut assist_group = InlineAssistGroup::new(); - for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists { + for ( + assist_id, + range, + prompt_editor, + prompt_block_id, + tool_description_block_id, + end_block_id, + ) in assists + { let codegen = prompt_editor.read(cx).codegen().clone(); self.assists.insert( @@ -581,6 +591,7 @@ impl InlineAssistant { editor, &prompt_editor, prompt_block_id, + tool_description_block_id, end_block_id, range, codegen, @@ -689,7 +700,7 @@ impl InlineAssistant { range: &Range, prompt_editor: &Entity>, cx: &mut App, - ) -> [CustomBlockId; 2] { + ) -> [CustomBlockId; 3] { let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| { prompt_editor .editor @@ -703,6 +714,14 @@ impl InlineAssistant { render: build_assist_editor_renderer(prompt_editor), priority: 0, }, + // Placeholder for tool description - will be updated dynamically + BlockProperties { + style: BlockStyle::Flex, + placement: BlockPlacement::Below(range.end), + height: Some(0), + render: Arc::new(|_cx| div().into_any_element()), + priority: 0, + }, BlockProperties { style: BlockStyle::Sticky, placement: BlockPlacement::Below(range.end), @@ -721,7 +740,7 @@ impl InlineAssistant { editor.update(cx, |editor, cx| { let block_ids = 
editor.insert_blocks(assist_blocks, None, cx); - [block_ids[0], block_ids[1]] + [block_ids[0], block_ids[1], block_ids[2]] }) } @@ -1113,6 +1132,9 @@ impl InlineAssistant { let mut to_remove = decorations.removed_line_block_ids; to_remove.insert(decorations.prompt_block_id); to_remove.insert(decorations.end_block_id); + if let Some(tool_description_block_id) = decorations.model_explanation { + to_remove.insert(tool_description_block_id); + } editor.remove_blocks(to_remove, None, cx); }); @@ -1433,8 +1455,60 @@ impl InlineAssistant { let old_snapshot = codegen.snapshot(cx); let old_buffer = codegen.old_buffer(cx); let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone(); + // let model_explanation = codegen.model_explanation(cx); editor.update(cx, |editor, cx| { + // Update tool description block + // if let Some(description) = model_explanation { + // if let Some(block_id) = decorations.model_explanation { + // editor.remove_blocks(HashSet::from_iter([block_id]), None, cx); + // let new_block_id = editor.insert_blocks( + // [BlockProperties { + // style: BlockStyle::Flex, + // placement: BlockPlacement::Below(assist.range.end), + // height: Some(1), + // render: Arc::new({ + // let description = description.clone(); + // move |cx| { + // div() + // .w_full() + // .py_1() + // .px_2() + // .bg(cx.theme().colors().editor_background) + // .border_y_1() + // .border_color(cx.theme().status().info_border) + // .child( + // Label::new(description.clone()) + // .color(Color::Muted) + // .size(LabelSize::Small), + // ) + // .into_any_element() + // } + // }), + // priority: 0, + // }], + // None, + // cx, + // ); + // decorations.model_explanation = new_block_id.into_iter().next(); + // } + // } else if let Some(block_id) = decorations.model_explanation { + // // Hide the block if there's no description + // editor.remove_blocks(HashSet::from_iter([block_id]), None, cx); + // let new_block_id = editor.insert_blocks( + // [BlockProperties { + // style: 
BlockStyle::Flex, + // placement: BlockPlacement::Below(assist.range.end), + // height: Some(0), + // render: Arc::new(|_cx| div().into_any_element()), + // priority: 0, + // }], + // None, + // cx, + // ); + // decorations.model_explanation = new_block_id.into_iter().next(); + // } + let old_blocks = mem::take(&mut decorations.removed_line_block_ids); editor.remove_blocks(old_blocks, None, cx); @@ -1686,6 +1760,7 @@ impl InlineAssist { editor: &Entity, prompt_editor: &Entity>, prompt_block_id: CustomBlockId, + tool_description_block_id: CustomBlockId, end_block_id: CustomBlockId, range: Range, codegen: Entity, @@ -1700,7 +1775,8 @@ impl InlineAssist { decorations: Some(InlineAssistDecorations { prompt_block_id, prompt_editor: prompt_editor.clone(), - removed_line_block_ids: HashSet::default(), + removed_line_block_ids: Default::default(), + model_explanation: Some(tool_description_block_id), end_block_id, }), range, @@ -1804,6 +1880,7 @@ struct InlineAssistDecorations { prompt_block_id: CustomBlockId, prompt_editor: Entity>, removed_line_block_ids: HashSet, + model_explanation: Option, end_block_id: CustomBlockId, } diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index b9e8d9ada230ba497ffcd4e577d3312dd440e604..b9852ea727c7974e3564fadc652f132076c01f09 100644 --- a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -10,10 +10,11 @@ use editor::{ }; use fs::Fs; use gpui::{ - AnyElement, App, Context, CursorStyle, Entity, EventEmitter, FocusHandle, Focusable, - Subscription, TextStyle, WeakEntity, Window, + AnyElement, App, Context, Entity, EventEmitter, FocusHandle, Focusable, Subscription, + TextStyle, TextStyleRefinement, WeakEntity, Window, }; use language_model::{LanguageModel, LanguageModelRegistry}; +use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle}; use parking_lot::Mutex; use project::Project; use prompt_store::PromptStore; @@ -65,7 
+66,7 @@ impl Render for PromptEditor { const RIGHT_PADDING: Pixels = px(9.); - let (left_gutter_width, right_padding) = match &self.mode { + let (left_gutter_width, right_padding, explanation) = match &self.mode { PromptEditorMode::Buffer { id: _, codegen, @@ -83,17 +84,23 @@ impl Render for PromptEditor { let left_gutter_width = gutter.full_width() + (gutter.margin / 2.0); let right_padding = editor_margins.right + RIGHT_PADDING; - (left_gutter_width, right_padding) + let explanation = codegen + .active_alternative() + .read(cx) + .model_explanation + .clone(); + + (left_gutter_width, right_padding, explanation) } PromptEditorMode::Terminal { .. } => { // Give the equivalent of the same left-padding that we're using on the right - (Pixels::from(40.0), Pixels::from(24.)) + (Pixels::from(40.0), Pixels::from(24.), None) } }; let bottom_padding = match &self.mode { PromptEditorMode::Buffer { .. } => rems_from_px(2.0), - PromptEditorMode::Terminal { .. } => rems_from_px(8.0), + PromptEditorMode::Terminal { .. 
} => rems_from_px(4.0), }; buttons.extend(self.render_buttons(window, cx)); @@ -111,22 +118,33 @@ impl Render for PromptEditor { this.trigger_completion_menu(window, cx); })); + let markdown = window.use_state(cx, |_, cx| Markdown::new("".into(), None, None, cx)); + + if let Some(explanation) = &explanation { + markdown.update(cx, |markdown, cx| { + markdown.reset(explanation.clone(), cx); + }); + } + + let explanation_label = self + .render_markdown(markdown, markdown_style(window, cx)) + .into_any_element(); + v_flex() .key_context("PromptEditor") .capture_action(cx.listener(Self::paste)) - .bg(cx.theme().colors().editor_background) .block_mouse_except_scroll() - .gap_0p5() - .border_y_1() - .border_color(cx.theme().status().info_border) .size_full() .pt_0p5() .pb(bottom_padding) .pr(right_padding) + .gap_0p5() + .justify_center() + .border_y_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().editor_background) .child( h_flex() - .items_start() - .cursor(CursorStyle::Arrow) .on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| { this.model_selector .update(cx, |model_selector, cx| model_selector.toggle(window, cx)); @@ -139,14 +157,14 @@ impl Render for PromptEditor { .capture_action(cx.listener(Self::cycle_next)) .child( WithRemSize::new(ui_font_size) + .h_full() + .w(left_gutter_width) .flex() .flex_row() .flex_shrink_0() .items_center() - .h_full() - .w(left_gutter_width) .justify_center() - .gap_2() + .gap_1() .child(self.render_close_button(cx)) .map(|el| { let CodegenStatus::Error(error) = self.codegen_status(cx) else { @@ -177,26 +195,83 @@ impl Render for PromptEditor { .flex_row() .items_center() .gap_1() + .child(add_context_button) + .child(self.model_selector.clone()) .children(buttons), ), ), ) - .child( - WithRemSize::new(ui_font_size) - .flex() - .flex_row() - .items_center() - .child(h_flex().flex_shrink_0().w(left_gutter_width)) - .child( - h_flex() - .w_full() - .pl_1() - .items_start() - .justify_between() - 
.child(add_context_button) - .child(self.model_selector.clone()), - ), - ) + .when_some(explanation, |this, _| { + this.child( + h_flex() + .size_full() + .justify_center() + .child(div().w(left_gutter_width + px(6.))) + .child( + div() + .size_full() + .min_w_0() + .pt(rems_from_px(3.)) + .pl_0p5() + .flex_1() + .border_t_1() + .border_color(cx.theme().colors().border_variant) + .child(explanation_label), + ), + ) + }) + } +} + +fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle { + let theme_settings = ThemeSettings::get_global(cx); + let colors = cx.theme().colors(); + let mut text_style = window.text_style(); + + text_style.refine(&TextStyleRefinement { + font_family: Some(theme_settings.ui_font.family.clone()), + color: Some(colors.text), + ..Default::default() + }); + + MarkdownStyle { + base_text_style: text_style.clone(), + syntax: cx.theme().syntax().clone(), + selection_background_color: colors.element_selection_background, + heading_level_styles: Some(HeadingLevelStyles { + h1: Some(TextStyleRefinement { + font_size: Some(rems(1.15).into()), + ..Default::default() + }), + h2: Some(TextStyleRefinement { + font_size: Some(rems(1.1).into()), + ..Default::default() + }), + h3: Some(TextStyleRefinement { + font_size: Some(rems(1.05).into()), + ..Default::default() + }), + h4: Some(TextStyleRefinement { + font_size: Some(rems(1.).into()), + ..Default::default() + }), + h5: Some(TextStyleRefinement { + font_size: Some(rems(0.95).into()), + ..Default::default() + }), + h6: Some(TextStyleRefinement { + font_size: Some(rems(0.875).into()), + ..Default::default() + }), + }), + inline_code: TextStyleRefinement { + font_family: Some(theme_settings.buffer_font.family.clone()), + font_fallbacks: theme_settings.buffer_font.fallbacks.clone(), + font_features: Some(theme_settings.buffer_font.features.clone()), + background_color: Some(colors.editor_foreground.opacity(0.08)), + ..Default::default() + }, + ..Default::default() } } @@ -759,6 +834,10 @@ impl 
PromptEditor { }) .into_any_element() } + + fn render_markdown(&self, markdown: Entity, style: MarkdownStyle) -> MarkdownElement { + MarkdownElement::new(markdown, style) + } } pub enum PromptEditorMode { diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 30538898b28a1d41d6c63b3e910f51c816e299ab..52efdf13be230d3fe6658e254f5d3f1a01b8211f 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -1682,6 +1682,98 @@ impl TextThreadEditor { window: &mut Window, cx: &mut Context, ) { + let editor_clipboard_selections = cx + .read_from_clipboard() + .and_then(|item| item.entries().first().cloned()) + .and_then(|entry| match entry { + ClipboardEntry::String(text) => { + text.metadata_json::>() + } + _ => None, + }); + + let has_file_context = editor_clipboard_selections + .as_ref() + .is_some_and(|selections| { + selections + .iter() + .any(|sel| sel.file_path.is_some() && sel.line_range.is_some()) + }); + + if has_file_context { + if let Some(clipboard_item) = cx.read_from_clipboard() { + if let Some(ClipboardEntry::String(clipboard_text)) = + clipboard_item.entries().first() + { + if let Some(selections) = editor_clipboard_selections { + cx.stop_propagation(); + + let text = clipboard_text.text(); + self.editor.update(cx, |editor, cx| { + let mut current_offset = 0; + let weak_editor = cx.entity().downgrade(); + + for selection in selections { + if let (Some(file_path), Some(line_range)) = + (selection.file_path, selection.line_range) + { + let selected_text = + &text[current_offset..current_offset + selection.len]; + let fence = assistant_slash_commands::codeblock_fence_for_path( + file_path.to_str(), + Some(line_range.clone()), + ); + let formatted_text = format!("{fence}{selected_text}\n```"); + + let insert_point = editor + .selections + .newest::(&editor.display_snapshot(cx)) + .head(); + let start_row = MultiBufferRow(insert_point.row); + + 
editor.insert(&formatted_text, window, cx); + + let snapshot = editor.buffer().read(cx).snapshot(cx); + let anchor_before = snapshot.anchor_after(insert_point); + let anchor_after = editor + .selections + .newest_anchor() + .head() + .bias_left(&snapshot); + + editor.insert("\n", window, cx); + + let crease_text = acp_thread::selection_name( + Some(file_path.as_ref()), + &line_range, + ); + + let fold_placeholder = quote_selection_fold_placeholder( + crease_text, + weak_editor.clone(), + ); + let crease = Crease::inline( + anchor_before..anchor_after, + fold_placeholder, + render_quote_selection_output_toggle, + |_, _, _, _| Empty.into_any(), + ); + editor.insert_creases(vec![crease], cx); + editor.fold_at(start_row, window, cx); + + current_offset += selection.len; + if !selection.is_entire_line && current_offset < text.len() { + current_offset += 1; + } + } + } + }); + return; + } + } + } + } + cx.stop_propagation(); let mut images = if let Some(item) = cx.read_from_clipboard() { diff --git a/crates/agent_ui/src/ui/agent_notification.rs b/crates/agent_ui/src/ui/agent_notification.rs index af2a022f147b79a0a299c17dd26c7e9a8b62aeb9..34ca0bb32a82aa23d1b954554ce2dfec436bfe1c 100644 --- a/crates/agent_ui/src/ui/agent_notification.rs +++ b/crates/agent_ui/src/ui/agent_notification.rs @@ -106,9 +106,6 @@ impl Render for AgentNotification { .font(ui_font) .border_color(cx.theme().colors().border) .rounded_xl() - .on_click(cx.listener(|_, _, _, cx| { - cx.emit(AgentNotificationEvent::Accepted); - })) .child( h_flex() .items_start() diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 06e25253ee626ba76345d7a65b14f1fed0a04b8f..bf2a305011fad924d38dcb391840661f1fd3cdbb 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -12,6 +12,8 @@ pub use settings::ModelMode; use strum::{EnumIter, EnumString}; use thiserror::Error; +pub mod batches; + pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com"; 
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] @@ -465,6 +467,7 @@ impl Model { } } +/// Generate completion with streaming. pub async fn stream_completion( client: &dyn HttpClient, api_url: &str, @@ -477,6 +480,101 @@ pub async fn stream_completion( .map(|output| output.0) } +/// Generate completion without streaming. +pub async fn non_streaming_completion( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, + beta_headers: Option, +) -> Result { + let (mut response, rate_limits) = + send_request(client, api_url, api_key, &request, beta_headers).await?; + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse)?; + + serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse) + } else { + Err(handle_error_response(response, rate_limits).await) + } +} + +async fn send_request( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: impl Serialize, + beta_headers: Option, +) -> Result<(http::Response, RateLimitInfo), AnthropicError> { + let uri = format!("{api_url}/v1/messages"); + + let mut request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Anthropic-Version", "2023-06-01") + .header("X-Api-Key", api_key.trim()) + .header("Content-Type", "application/json"); + + if let Some(beta_headers) = beta_headers { + request_builder = request_builder.header("Anthropic-Beta", beta_headers); + } + + let serialized_request = + serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?; + let request = request_builder + .body(AsyncBody::from(serialized_request)) + .map_err(AnthropicError::BuildRequestBody)?; + + let response = client + .send(request) + .await + .map_err(AnthropicError::HttpSend)?; + + let rate_limits = RateLimitInfo::from_headers(response.headers()); + + Ok((response, rate_limits)) +} + +async fn handle_error_response( + 
mut response: http::Response, + rate_limits: RateLimitInfo, +) -> AnthropicError { + if response.status().as_u16() == 529 { + return AnthropicError::ServerOverloaded { + retry_after: rate_limits.retry_after, + }; + } + + if let Some(retry_after) = rate_limits.retry_after { + return AnthropicError::RateLimit { retry_after }; + } + + let mut body = String::new(); + let read_result = response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse); + + if let Err(err) = read_result { + return err; + } + + match serde_json::from_str::(&body) { + Ok(Event::Error { error }) => AnthropicError::ApiError(error), + Ok(_) | Err(_) => AnthropicError::HttpResponseError { + status_code: response.status(), + message: body, + }, + } +} + /// An individual rate limit. #[derive(Debug)] pub struct RateLimit { @@ -580,30 +678,10 @@ pub async fn stream_completion_with_rate_limit_info( base: request, stream: true, }; - let uri = format!("{api_url}/v1/messages"); - let mut request_builder = HttpRequest::builder() - .method(Method::POST) - .uri(uri) - .header("Anthropic-Version", "2023-06-01") - .header("X-Api-Key", api_key.trim()) - .header("Content-Type", "application/json"); + let (response, rate_limits) = + send_request(client, api_url, api_key, &request, beta_headers).await?; - if let Some(beta_headers) = beta_headers { - request_builder = request_builder.header("Anthropic-Beta", beta_headers); - } - - let serialized_request = - serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?; - let request = request_builder - .body(AsyncBody::from(serialized_request)) - .map_err(AnthropicError::BuildRequestBody)?; - - let mut response = client - .send(request) - .await - .map_err(AnthropicError::HttpSend)?; - let rate_limits = RateLimitInfo::from_headers(response.headers()); if response.status().is_success() { let reader = BufReader::new(response.into_body()); let stream = reader @@ -622,27 +700,8 @@ pub async fn 
stream_completion_with_rate_limit_info( }) .boxed(); Ok((stream, Some(rate_limits))) - } else if response.status().as_u16() == 529 { - Err(AnthropicError::ServerOverloaded { - retry_after: rate_limits.retry_after, - }) - } else if let Some(retry_after) = rate_limits.retry_after { - Err(AnthropicError::RateLimit { retry_after }) } else { - let mut body = String::new(); - response - .body_mut() - .read_to_string(&mut body) - .await - .map_err(AnthropicError::ReadResponse)?; - - match serde_json::from_str::(&body) { - Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)), - Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError { - status_code: response.status(), - message: body, - }), - } + Err(handle_error_response(response, rate_limits).await) } } diff --git a/crates/anthropic/src/batches.rs b/crates/anthropic/src/batches.rs new file mode 100644 index 0000000000000000000000000000000000000000..5fb594348d45c84e8c246c2611f7cde3aa77a18d --- /dev/null +++ b/crates/anthropic/src/batches.rs @@ -0,0 +1,190 @@ +use anyhow::Result; +use futures::AsyncReadExt; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use serde::{Deserialize, Serialize}; + +use crate::{AnthropicError, ApiError, RateLimitInfo, Request, Response}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct BatchRequest { + pub custom_id: String, + pub params: Request, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CreateBatchRequest { + pub requests: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageBatchRequestCounts { + pub processing: u64, + pub succeeded: u64, + pub errored: u64, + pub canceled: u64, + pub expired: u64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MessageBatch { + pub id: String, + #[serde(rename = "type")] + pub batch_type: String, + pub processing_status: String, + pub request_counts: MessageBatchRequestCounts, + pub ended_at: Option, + pub created_at: String, + pub expires_at: String, + pub 
archived_at: Option, + pub cancel_initiated_at: Option, + pub results_url: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum BatchResult { + #[serde(rename = "succeeded")] + Succeeded { message: Response }, + #[serde(rename = "errored")] + Errored { error: ApiError }, + #[serde(rename = "canceled")] + Canceled, + #[serde(rename = "expired")] + Expired, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct BatchIndividualResponse { + pub custom_id: String, + pub result: BatchResult, +} + +pub async fn create_batch( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: CreateBatchRequest, +) -> Result { + let uri = format!("{api_url}/v1/messages/batches"); + + let request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Anthropic-Version", "2023-06-01") + .header("X-Api-Key", api_key.trim()) + .header("Content-Type", "application/json"); + + let serialized_request = + serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?; + let http_request = request_builder + .body(AsyncBody::from(serialized_request)) + .map_err(AnthropicError::BuildRequestBody)?; + + let mut response = client + .send(http_request) + .await + .map_err(AnthropicError::HttpSend)?; + + let rate_limits = RateLimitInfo::from_headers(response.headers()); + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse)?; + + serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse) + } else { + Err(crate::handle_error_response(response, rate_limits).await) + } +} + +pub async fn retrieve_batch( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + message_batch_id: &str, +) -> Result { + let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}"); + + let request_builder = HttpRequest::builder() + .method(Method::GET) + .uri(uri) + 
.header("Anthropic-Version", "2023-06-01") + .header("X-Api-Key", api_key.trim()); + + let http_request = request_builder + .body(AsyncBody::default()) + .map_err(AnthropicError::BuildRequestBody)?; + + let mut response = client + .send(http_request) + .await + .map_err(AnthropicError::HttpSend)?; + + let rate_limits = RateLimitInfo::from_headers(response.headers()); + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse)?; + + serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse) + } else { + Err(crate::handle_error_response(response, rate_limits).await) + } +} + +pub async fn retrieve_batch_results( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + message_batch_id: &str, +) -> Result, AnthropicError> { + let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}/results"); + + let request_builder = HttpRequest::builder() + .method(Method::GET) + .uri(uri) + .header("Anthropic-Version", "2023-06-01") + .header("X-Api-Key", api_key.trim()); + + let http_request = request_builder + .body(AsyncBody::default()) + .map_err(AnthropicError::BuildRequestBody)?; + + let mut response = client + .send(http_request) + .await + .map_err(AnthropicError::HttpSend)?; + + let rate_limits = RateLimitInfo::from_headers(response.headers()); + + if response.status().is_success() { + let mut body = String::new(); + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(AnthropicError::ReadResponse)?; + + let mut results = Vec::new(); + for line in body.lines() { + if line.trim().is_empty() { + continue; + } + let result: BatchIndividualResponse = + serde_json::from_str(line).map_err(AnthropicError::DeserializeResponse)?; + results.push(result); + } + + Ok(results) + } else { + Err(crate::handle_error_response(response, rate_limits).await) + } +} diff --git a/crates/assistant_text_thread/src/text_thread.rs 
b/crates/assistant_text_thread/src/text_thread.rs index 7f24c8f665f8d34aed199562dce1131797f13c9d..b808d9fb0019ccad25366d9ae60cc1f765126c74 100644 --- a/crates/assistant_text_thread/src/text_thread.rs +++ b/crates/assistant_text_thread/src/text_thread.rs @@ -14,7 +14,7 @@ use fs::{Fs, RenameOptions}; use futures::{FutureExt, StreamExt, future::Shared}; use gpui::{ App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription, - Task, + Task, WeakEntity, }; use itertools::Itertools as _; use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset}; @@ -688,7 +688,7 @@ pub struct TextThread { _subscriptions: Vec, telemetry: Option>, language_registry: Arc, - project: Option>, + project: Option>, prompt_builder: Arc, completion_mode: agent_settings::CompletionMode, } @@ -708,7 +708,7 @@ impl EventEmitter for TextThread {} impl TextThread { pub fn local( language_registry: Arc, - project: Option>, + project: Option>, telemetry: Option>, prompt_builder: Arc, slash_commands: Arc, @@ -742,7 +742,7 @@ impl TextThread { language_registry: Arc, prompt_builder: Arc, slash_commands: Arc, - project: Option>, + project: Option>, telemetry: Option>, cx: &mut Context, ) -> Self { @@ -873,7 +873,7 @@ impl TextThread { language_registry: Arc, prompt_builder: Arc, slash_commands: Arc, - project: Option>, + project: Option>, telemetry: Option>, cx: &mut Context, ) -> Self { @@ -1167,10 +1167,6 @@ impl TextThread { self.language_registry.clone() } - pub fn project(&self) -> Option> { - self.project.clone() - } - pub fn prompt_builder(&self) -> Arc { self.prompt_builder.clone() } @@ -2967,7 +2963,7 @@ impl TextThread { } fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut App) { - let Some(project) = &self.project else { + let Some(project) = self.project.as_ref().and_then(|project| project.upgrade()) else { return; }; project.read(cx).user_store().update(cx, |user_store, cx| { diff --git 
a/crates/assistant_text_thread/src/text_thread_store.rs b/crates/assistant_text_thread/src/text_thread_store.rs index 19c317baf0fa728c77faebc388b5e36008aa39b3..71fabed503e8c04a8865bed72c28ae5b30e75574 100644 --- a/crates/assistant_text_thread/src/text_thread_store.rs +++ b/crates/assistant_text_thread/src/text_thread_store.rs @@ -51,7 +51,7 @@ pub struct TextThreadStore { telemetry: Arc, _watch_updates: Task>, client: Arc, - project: Entity, + project: WeakEntity, project_is_shared: bool, client_subscription: Option, _project_subscriptions: Vec, @@ -119,10 +119,10 @@ impl TextThreadStore { ], project_is_shared: false, client: project.read(cx).client(), - project: project.clone(), + project: project.downgrade(), prompt_builder, }; - this.handle_project_shared(project.clone(), cx); + this.handle_project_shared(cx); this.synchronize_contexts(cx); this.register_context_server_handlers(cx); this.reload(cx).detach_and_log_err(cx); @@ -146,7 +146,7 @@ impl TextThreadStore { telemetry: project.read(cx).client().telemetry().clone(), _watch_updates: Task::ready(None), client: project.read(cx).client(), - project, + project: project.downgrade(), project_is_shared: false, client_subscription: None, _project_subscriptions: Default::default(), @@ -180,8 +180,10 @@ impl TextThreadStore { ) -> Result { let context_id = TextThreadId::from_proto(envelope.payload.context_id); let operations = this.update(&mut cx, |this, cx| { + let project = this.project.upgrade().context("project not found")?; + anyhow::ensure!( - !this.project.read(cx).is_via_collab(), + !project.read(cx).is_via_collab(), "only the host contexts can be opened" ); @@ -211,8 +213,9 @@ impl TextThreadStore { mut cx: AsyncApp, ) -> Result { let (context_id, operations) = this.update(&mut cx, |this, cx| { + let project = this.project.upgrade().context("project not found")?; anyhow::ensure!( - !this.project.read(cx).is_via_collab(), + !project.read(cx).is_via_collab(), "can only create contexts as the host" ); @@ -255,8 
+258,9 @@ impl TextThreadStore { mut cx: AsyncApp, ) -> Result { this.update(&mut cx, |this, cx| { + let project = this.project.upgrade().context("project not found")?; anyhow::ensure!( - !this.project.read(cx).is_via_collab(), + !project.read(cx).is_via_collab(), "only the host can synchronize contexts" ); @@ -293,8 +297,12 @@ impl TextThreadStore { })? } - fn handle_project_shared(&mut self, _: Entity, cx: &mut Context) { - let is_shared = self.project.read(cx).is_shared(); + fn handle_project_shared(&mut self, cx: &mut Context) { + let Some(project) = self.project.upgrade() else { + return; + }; + + let is_shared = project.read(cx).is_shared(); let was_shared = mem::replace(&mut self.project_is_shared, is_shared); if is_shared == was_shared { return; @@ -309,7 +317,7 @@ impl TextThreadStore { false } }); - let remote_id = self.project.read(cx).remote_id().unwrap(); + let remote_id = project.read(cx).remote_id().unwrap(); self.client_subscription = self .client .subscribe_to_entity(remote_id) @@ -323,13 +331,13 @@ impl TextThreadStore { fn handle_project_event( &mut self, - project: Entity, + _project: Entity, event: &project::Event, cx: &mut Context, ) { match event { project::Event::RemoteIdChanged(_) => { - self.handle_project_shared(project, cx); + self.handle_project_shared(cx); } project::Event::Reshared => { self.advertise_contexts(cx); @@ -382,7 +390,10 @@ impl TextThreadStore { } pub fn create_remote(&mut self, cx: &mut Context) -> Task>> { - let project = self.project.read(cx); + let Some(project) = self.project.upgrade() else { + return Task::ready(Err(anyhow::anyhow!("project was dropped"))); + }; + let project = project.read(cx); let Some(project_id) = project.remote_id() else { return Task::ready(Err(anyhow::anyhow!("project was not remote"))); }; @@ -541,7 +552,10 @@ impl TextThreadStore { text_thread_id: TextThreadId, cx: &mut Context, ) -> Task>> { - let project = self.project.read(cx); + let Some(project) = self.project.upgrade() else { + return 
Task::ready(Err(anyhow::anyhow!("project was dropped"))); + }; + let project = project.read(cx); let Some(project_id) = project.remote_id() else { return Task::ready(Err(anyhow::anyhow!("project was not remote"))); }; @@ -618,7 +632,10 @@ impl TextThreadStore { event: &TextThreadEvent, cx: &mut Context, ) { - let Some(project_id) = self.project.read(cx).remote_id() else { + let Some(project) = self.project.upgrade() else { + return; + }; + let Some(project_id) = project.read(cx).remote_id() else { return; }; @@ -652,12 +669,14 @@ impl TextThreadStore { } fn advertise_contexts(&self, cx: &App) { - let Some(project_id) = self.project.read(cx).remote_id() else { + let Some(project) = self.project.upgrade() else { + return; + }; + let Some(project_id) = project.read(cx).remote_id() else { return; }; - // For now, only the host can advertise their open contexts. - if self.project.read(cx).is_via_collab() { + if project.read(cx).is_via_collab() { return; } @@ -689,7 +708,10 @@ impl TextThreadStore { } fn synchronize_contexts(&mut self, cx: &mut Context) { - let Some(project_id) = self.project.read(cx).remote_id() else { + let Some(project) = self.project.upgrade() else { + return; + }; + let Some(project_id) = project.read(cx).remote_id() else { return; }; @@ -828,7 +850,10 @@ impl TextThreadStore { } fn register_context_server_handlers(&self, cx: &mut Context) { - let context_server_store = self.project.read(cx).context_server_store(); + let Some(project) = self.project.upgrade() else { + return; + }; + let context_server_store = project.read(cx).context_server_store(); cx.subscribe(&context_server_store, Self::handle_context_server_event) .detach(); diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index de8d69dc14870c5583679753c9a75a477e0cc759..9e590dc4cf48a82ecdda8b007c38ab15f3b602be 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -31,18 
+31,10 @@ pub struct PredictEditsRequest { /// Within `signatures` pub excerpt_parent: Option, #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub included_files: Vec, - #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub signatures: Vec, - #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub referenced_declarations: Vec, + pub related_files: Vec, pub events: Vec>, #[serde(default)] pub can_collect_data: bool, - #[serde(skip_serializing_if = "Vec::is_empty", default)] - pub diagnostic_groups: Vec, - #[serde(skip_serializing_if = "is_default", default)] - pub diagnostic_groups_truncated: bool, /// Info about the git repository state, only present when can_collect_data is true. #[serde(skip_serializing_if = "Option::is_none", default)] pub git_info: Option, @@ -58,7 +50,7 @@ pub struct PredictEditsRequest { } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct IncludedFile { +pub struct RelatedFile { pub path: Arc, pub max_row: Line, pub excerpts: Vec, @@ -72,11 +64,9 @@ pub struct Excerpt { #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)] pub enum PromptFormat { - MarkedExcerpt, - LabeledSections, - NumLinesUniDiff, + /// XML old_tex/new_text OldTextNewText, - /// Prompt format intended for use via zeta_cli + /// Prompt format intended for use via edit_prediction_cli OnlySnippets, /// One-sentence instructions used in fine-tuned models Minimal, @@ -87,7 +77,7 @@ pub enum PromptFormat { } impl PromptFormat { - pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff; + pub const DEFAULT: PromptFormat = PromptFormat::Minimal; } impl Default for PromptFormat { @@ -105,10 +95,7 @@ impl PromptFormat { impl std::fmt::Display for PromptFormat { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"), - PromptFormat::LabeledSections => write!(f, "Labeled Sections"), PromptFormat::OnlySnippets => write!(f, "Only 
Snippets"), - PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"), PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"), PromptFormat::Minimal => write!(f, "Minimal"), PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"), @@ -178,67 +165,6 @@ impl<'a> std::fmt::Display for DiffPathFmt<'a> { } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Signature { - pub text: String, - pub text_is_truncated: bool, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub parent_index: Option, - /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The - /// file is implicitly the file that contains the descendant declaration or excerpt. - pub range: Range, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ReferencedDeclaration { - pub path: Arc, - pub text: String, - pub text_is_truncated: bool, - /// Range of `text` within file, possibly truncated according to `text_is_truncated` - pub range: Range, - /// Range within `text` - pub signature_range: Range, - /// Index within `signatures`. 
- #[serde(skip_serializing_if = "Option::is_none", default)] - pub parent_index: Option, - pub score_components: DeclarationScoreComponents, - pub signature_score: f32, - pub declaration_score: f32, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DeclarationScoreComponents { - pub is_same_file: bool, - pub is_referenced_nearby: bool, - pub is_referenced_in_breadcrumb: bool, - pub reference_count: usize, - pub same_file_declaration_count: usize, - pub declaration_count: usize, - pub reference_line_distance: u32, - pub declaration_line_distance: u32, - pub excerpt_vs_item_jaccard: f32, - pub excerpt_vs_signature_jaccard: f32, - pub adjacent_vs_item_jaccard: f32, - pub adjacent_vs_signature_jaccard: f32, - pub excerpt_vs_item_weighted_overlap: f32, - pub excerpt_vs_signature_weighted_overlap: f32, - pub adjacent_vs_item_weighted_overlap: f32, - pub adjacent_vs_signature_weighted_overlap: f32, - pub path_import_match_count: usize, - pub wildcard_path_import_match_count: usize, - pub import_similarity: f32, - pub max_import_similarity: f32, - pub normalized_import_similarity: f32, - pub wildcard_import_similarity: f32, - pub normalized_wildcard_import_similarity: f32, - pub included_by_others: usize, - pub includes_others: usize, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(transparent)] -pub struct DiagnosticGroup(pub Box); - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PredictEditsResponse { pub request_id: Uuid, @@ -262,10 +188,6 @@ pub struct Edit { pub content: String, } -fn is_default(value: &T) -> bool { - *value == T::default() -} - #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)] pub struct Point { pub line: Line, diff --git a/crates/cloud_zeta2_prompt/Cargo.toml b/crates/cloud_zeta2_prompt/Cargo.toml index fa8246950f8d03029388e0276954de946efc2346..a15e3fe43c28349920433272c4040ccc58ff4cb4 100644 --- a/crates/cloud_zeta2_prompt/Cargo.toml +++ 
b/crates/cloud_zeta2_prompt/Cargo.toml @@ -15,9 +15,4 @@ path = "src/cloud_zeta2_prompt.rs" anyhow.workspace = true cloud_llm_client.workspace = true indoc.workspace = true -ordered-float.workspace = true -rustc-hash.workspace = true -schemars.workspace = true serde.workspace = true -serde_json.workspace = true -strum.workspace = true diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index d67190c17556c5eb8b901e9baad73cc2691a9c78..62bfa45f47d0fdfefa9fbd72320c0ddee71cbc47 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -1,20 +1,12 @@ -//! Zeta2 prompt planning and generation code shared with cloud. -pub mod retrieval_prompt; - -use anyhow::{Context as _, Result, anyhow}; +use anyhow::Result; use cloud_llm_client::predict_edits_v3::{ - self, DiffPathFmt, Event, Excerpt, IncludedFile, Line, Point, PromptFormat, - ReferencedDeclaration, + self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile, }; use indoc::indoc; -use ordered_float::OrderedFloat; -use rustc_hash::{FxHashMap, FxHashSet}; -use serde::Serialize; use std::cmp; use std::fmt::Write; +use std::path::Path; use std::sync::Arc; -use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path}; -use strum::{EnumIter, IntoEnumIterator}; pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024; @@ -24,69 +16,6 @@ pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_s /// NOTE: Differs from zed version of constant - includes a newline pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n"; -// TODO: use constants for markers? -const MARKED_EXCERPT_INSTRUCTIONS: &str = indoc! 
{" - You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. - - The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor|>. Please respond with edited code for that region. - - Other code is provided for context, and `…` indicates when code has been skipped. - - ## Edit History - -"}; - -const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#" - You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code. - - Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`). - - The cursor position is marked with `<|user_cursor|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it. - - Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example: - - <|current_section|> - for i in 0..16 { - println!("{i}"); - } - - ## Edit History - -"#}; - -const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#" - # Instructions - - You are an edit prediction agent in a code editor. - Your job is to predict the next edit that the user will make, - based on their last few edits and their current cursor location. - - ## Output Format - - You must briefly explain your understanding of the user's goal, in one - or two sentences, and then specify their next edit in the form of a - unified diff, like this: - - ``` - --- a/src/myapp/cli.py - +++ b/src/myapp/cli.py - @@ ... @@ - import os - import time - import sys - +from constants import LOG_LEVEL_WARNING - @@ ... 
@@ - config.headless() - config.set_interactive(false) - -config.set_log_level(LOG_L) - +config.set_log_level(LOG_LEVEL_WARNING) - config.set_use_color(True) - ``` - - ## Edit History - -"#}; - const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#" You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase. @@ -94,20 +23,6 @@ const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#" "#}; -const UNIFIED_DIFF_REMINDER: &str = indoc! {" - --- - - Analyze the edit history and the files, then provide the unified diff for your predicted edits. - Do not include the cursor marker in your output. - Your diff should include edited file paths in its file headers (lines beginning with `---` and `+++`). - Do not include line numbers in the hunk headers, use `@@ ... @@`. - Removed lines begin with `-`. - Added lines begin with `+`. - Context lines begin with an extra space. - Context and removed lines are used to match the target edit location, so make sure to include enough of them - to uniquely identify it amongst all excerpts of code provided. -"}; - const MINIMAL_PROMPT_REMINDER: &str = indoc! {" --- @@ -164,49 +79,25 @@ const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#" Remember that the edits in the edit history have already been applied. 
"#}; -pub fn build_prompt( - request: &predict_edits_v3::PredictEditsRequest, -) -> Result<(String, SectionLabels)> { - let mut section_labels = Default::default(); - +pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result { let prompt_data = PromptData { events: request.events.clone(), cursor_point: request.cursor_point, cursor_path: request.excerpt_path.clone(), - included_files: request.included_files.clone(), + included_files: request.related_files.clone(), }; match request.prompt_format { PromptFormat::MinimalQwen => { - return Ok((MinimalQwenPrompt.render(&prompt_data), section_labels)); + return Ok(MinimalQwenPrompt.render(&prompt_data)); } PromptFormat::SeedCoder1120 => { - return Ok((SeedCoder1120Prompt.render(&prompt_data), section_labels)); + return Ok(SeedCoder1120Prompt.render(&prompt_data)); } _ => (), }; - let mut insertions = match request.prompt_format { - PromptFormat::MarkedExcerpt => vec![ - ( - Point { - line: request.excerpt_line_range.start, - column: 0, - }, - EDITABLE_REGION_START_MARKER_WITH_NEWLINE, - ), - (request.cursor_point, CURSOR_MARKER), - ( - Point { - line: request.excerpt_line_range.end, - column: 0, - }, - EDITABLE_REGION_END_MARKER_WITH_NEWLINE, - ), - ], - PromptFormat::LabeledSections - | PromptFormat::NumLinesUniDiff - | PromptFormat::Minimal - | PromptFormat::OldTextNewText => { + let insertions = match request.prompt_format { + PromptFormat::Minimal | PromptFormat::OldTextNewText => { vec![(request.cursor_point, CURSOR_MARKER)] } PromptFormat::OnlySnippets => vec![], @@ -215,9 +106,6 @@ pub fn build_prompt( }; let mut prompt = match request.prompt_format { - PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(), - PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(), - PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(), PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(), PromptFormat::OnlySnippets => String::new(), 
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(), @@ -247,7 +135,7 @@ pub fn build_prompt( You can only edit exactly this part of the file. We prepend line numbers (e.g., `123|`); they are not part of the file.) "}, - PromptFormat::NumLinesUniDiff | PromptFormat::OldTextNewText => indoc! {" + PromptFormat::OldTextNewText => indoc! {" ## Code Excerpts Here is some excerpts of code that you should take into account to predict the next edit. @@ -263,64 +151,51 @@ pub fn build_prompt( Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs. "}, - _ => indoc! {" + PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => { + indoc! {" ## Code Excerpts The cursor marker <|user_cursor|> indicates the current user cursor position. The file is in current state, edits from edit history have been applied. - "}, + "} + } }; prompt.push_str(excerpts_preamble); prompt.push('\n'); - if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() { - let syntax_based_prompt = SyntaxBasedPrompt::populate(request)?; - section_labels = syntax_based_prompt.write(&mut insertions, &mut prompt)?; - } else { - if request.prompt_format == PromptFormat::LabeledSections { - anyhow::bail!("PromptFormat::LabeledSections cannot be used with ContextMode::Llm"); - } - - let include_line_numbers = matches!( - request.prompt_format, - PromptFormat::NumLinesUniDiff | PromptFormat::Minimal - ); - for related_file in &request.included_files { - if request.prompt_format == PromptFormat::Minimal { - write_codeblock_with_filename( - &related_file.path, - &related_file.excerpts, - if related_file.path == request.excerpt_path { - &insertions - } else { - &[] - }, - related_file.max_row, - include_line_numbers, - &mut prompt, - ); - } else { - write_codeblock( - &related_file.path, - &related_file.excerpts, - if related_file.path == request.excerpt_path { - &insertions - } else { - &[] - }, - 
related_file.max_row, - include_line_numbers, - &mut prompt, - ); - } + let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal); + for related_file in &request.related_files { + if request.prompt_format == PromptFormat::Minimal { + write_codeblock_with_filename( + &related_file.path, + &related_file.excerpts, + if related_file.path == request.excerpt_path { + &insertions + } else { + &[] + }, + related_file.max_row, + include_line_numbers, + &mut prompt, + ); + } else { + write_codeblock( + &related_file.path, + &related_file.excerpts, + if related_file.path == request.excerpt_path { + &insertions + } else { + &[] + }, + related_file.max_row, + include_line_numbers, + &mut prompt, + ); } } match request.prompt_format { - PromptFormat::NumLinesUniDiff => { - prompt.push_str(UNIFIED_DIFF_REMINDER); - } PromptFormat::OldTextNewText => { prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER); } @@ -330,7 +205,7 @@ pub fn build_prompt( _ => {} } - Ok((prompt, section_labels)) + Ok(prompt) } pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams { @@ -444,476 +319,11 @@ pub fn push_events(output: &mut String, events: &[Arc]) writeln!(output, "`````\n").unwrap(); } -pub struct SyntaxBasedPrompt<'a> { - request: &'a predict_edits_v3::PredictEditsRequest, - /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in - /// `to_prompt_string`. 
- snippets: Vec>, - budget_used: usize, -} - -#[derive(Clone, Debug)] -pub struct PlannedSnippet<'a> { - path: Arc, - range: Range, - text: &'a str, - // TODO: Indicate this in the output - #[allow(dead_code)] - text_is_truncated: bool, -} - -#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)] -pub enum DeclarationStyle { - Signature, - Declaration, -} - -#[derive(Default, Clone, Debug, Serialize)] -pub struct SectionLabels { - pub excerpt_index: usize, - pub section_ranges: Vec<(Arc, Range)>, -} - -impl<'a> SyntaxBasedPrompt<'a> { - /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following: - /// - /// Initializes a priority queue by populating it with each snippet, finding the - /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a - /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects - /// the cost of upgrade. - /// - /// TODO: Implement an early halting condition. One option might be to have another priority - /// queue where the score is the size, and update it accordingly. Another option might be to - /// have some simpler heuristic like bailing after N failed insertions, or based on how much - /// budget is left. - /// - /// TODO: Has the current known sources of imprecision: - /// - /// * Does not consider snippet overlap when ranking. For example, it might add a field to the - /// plan even though the containing struct is already included. - /// - /// * Does not consider cost of signatures when ranking snippets - this is tricky since - /// signatures may be shared by multiple snippets. - /// - /// * Does not include file paths / other text when considering max_bytes. 
- pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result { - let mut this = Self { - request, - snippets: Vec::new(), - budget_used: request.excerpt.len(), - }; - let mut included_parents = FxHashSet::default(); - let additional_parents = this.additional_parent_signatures( - &request.excerpt_path, - request.excerpt_parent, - &included_parents, - )?; - this.add_parents(&mut included_parents, additional_parents); - - let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES); - - if this.budget_used > max_bytes { - return Err(anyhow!( - "Excerpt + signatures size of {} already exceeds budget of {}", - this.budget_used, - max_bytes - )); - } - - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] - struct QueueEntry { - score_density: OrderedFloat, - declaration_index: usize, - style: DeclarationStyle, - } - - // Initialize priority queue with the best score for each snippet. - let mut queue: BinaryHeap = BinaryHeap::new(); - for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() { - let (style, score_density) = DeclarationStyle::iter() - .map(|style| { - ( - style, - OrderedFloat(declaration_score_density(&declaration, style)), - ) - }) - .max_by_key(|(_, score_density)| *score_density) - .unwrap(); - queue.push(QueueEntry { - score_density, - declaration_index, - style, - }); - } - - // Knapsack selection loop - while let Some(queue_entry) = queue.pop() { - let Some(declaration) = request - .referenced_declarations - .get(queue_entry.declaration_index) - else { - return Err(anyhow!( - "Invalid declaration index {}", - queue_entry.declaration_index - )); - }; - - let mut additional_bytes = declaration_size(declaration, queue_entry.style); - if this.budget_used + additional_bytes > max_bytes { - continue; - } - - let additional_parents = this.additional_parent_signatures( - &declaration.path, - declaration.parent_index, - &mut included_parents, - )?; - additional_bytes += 
additional_parents - .iter() - .map(|(_, snippet)| snippet.text.len()) - .sum::(); - if this.budget_used + additional_bytes > max_bytes { - continue; - } - - this.budget_used += additional_bytes; - this.add_parents(&mut included_parents, additional_parents); - let planned_snippet = match queue_entry.style { - DeclarationStyle::Signature => { - let Some(text) = declaration.text.get(declaration.signature_range.clone()) - else { - return Err(anyhow!( - "Invalid declaration signature_range {:?} with text.len() = {}", - declaration.signature_range, - declaration.text.len() - )); - }; - let signature_start_line = declaration.range.start - + Line( - declaration.text[..declaration.signature_range.start] - .lines() - .count() as u32, - ); - let signature_end_line = signature_start_line - + Line( - declaration.text - [declaration.signature_range.start..declaration.signature_range.end] - .lines() - .count() as u32, - ); - let range = signature_start_line..signature_end_line; - - PlannedSnippet { - path: declaration.path.clone(), - range, - text, - text_is_truncated: declaration.text_is_truncated, - } - } - DeclarationStyle::Declaration => PlannedSnippet { - path: declaration.path.clone(), - range: declaration.range.clone(), - text: &declaration.text, - text_is_truncated: declaration.text_is_truncated, - }, - }; - this.snippets.push(planned_snippet); - - // When a Signature is consumed, insert an entry for Definition style. 
- if queue_entry.style == DeclarationStyle::Signature { - let signature_size = declaration_size(&declaration, DeclarationStyle::Signature); - let declaration_size = - declaration_size(&declaration, DeclarationStyle::Declaration); - let signature_score = declaration_score(&declaration, DeclarationStyle::Signature); - let declaration_score = - declaration_score(&declaration, DeclarationStyle::Declaration); - - let score_diff = declaration_score - signature_score; - let size_diff = declaration_size.saturating_sub(signature_size); - if score_diff > 0.0001 && size_diff > 0 { - queue.push(QueueEntry { - declaration_index: queue_entry.declaration_index, - score_density: OrderedFloat(score_diff / (size_diff as f32)), - style: DeclarationStyle::Declaration, - }); - } - } - } - - anyhow::Ok(this) - } - - fn add_parents( - &mut self, - included_parents: &mut FxHashSet, - snippets: Vec<(usize, PlannedSnippet<'a>)>, - ) { - for (parent_index, snippet) in snippets { - included_parents.insert(parent_index); - self.budget_used += snippet.text.len(); - self.snippets.push(snippet); - } - } - - fn additional_parent_signatures( - &self, - path: &Arc, - parent_index: Option, - included_parents: &FxHashSet, - ) -> Result)>> { - let mut results = Vec::new(); - self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?; - Ok(results) - } - - fn additional_parent_signatures_impl( - &self, - path: &Arc, - parent_index: Option, - included_parents: &FxHashSet, - results: &mut Vec<(usize, PlannedSnippet<'a>)>, - ) -> Result<()> { - let Some(parent_index) = parent_index else { - return Ok(()); - }; - if included_parents.contains(&parent_index) { - return Ok(()); - } - let Some(parent_signature) = self.request.signatures.get(parent_index) else { - return Err(anyhow!("Invalid parent index {}", parent_index)); - }; - results.push(( - parent_index, - PlannedSnippet { - path: path.clone(), - range: parent_signature.range.clone(), - text: &parent_signature.text, - 
text_is_truncated: parent_signature.text_is_truncated, - }, - )); - self.additional_parent_signatures_impl( - path, - parent_signature.parent_index, - included_parents, - results, - ) - } - - /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple - /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive - /// chunks. - pub fn write( - &'a self, - excerpt_file_insertions: &mut Vec<(Point, &'static str)>, - prompt: &mut String, - ) -> Result { - let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> = - FxHashMap::default(); - for snippet in &self.snippets { - file_to_snippets - .entry(&snippet.path) - .or_default() - .push(snippet); - } - - // Reorder so that file with cursor comes last - let mut file_snippets = Vec::new(); - let mut excerpt_file_snippets = Vec::new(); - for (file_path, snippets) in file_to_snippets { - if file_path == self.request.excerpt_path.as_ref() { - excerpt_file_snippets = snippets; - } else { - file_snippets.push((file_path, snippets, false)); - } - } - let excerpt_snippet = PlannedSnippet { - path: self.request.excerpt_path.clone(), - range: self.request.excerpt_line_range.clone(), - text: &self.request.excerpt, - text_is_truncated: false, - }; - excerpt_file_snippets.push(&excerpt_snippet); - file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true)); - - let section_labels = - self.push_file_snippets(prompt, excerpt_file_insertions, file_snippets)?; - - Ok(section_labels) - } - - fn push_file_snippets( - &self, - output: &mut String, - excerpt_file_insertions: &mut Vec<(Point, &'static str)>, - file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>, - ) -> Result { - let mut section_ranges = Vec::new(); - let mut excerpt_index = None; - - for (file_path, mut snippets, is_excerpt_file) in file_snippets { - snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end))); - - // TODO: What if the snippets get 
expanded too large to be editable? - let mut current_snippet: Option<(&PlannedSnippet, Range)> = None; - let mut disjoint_snippets: Vec<(&PlannedSnippet, Range)> = Vec::new(); - for snippet in snippets { - if let Some((_, current_snippet_range)) = current_snippet.as_mut() - && snippet.range.start <= current_snippet_range.end - { - current_snippet_range.end = current_snippet_range.end.max(snippet.range.end); - continue; - } - if let Some(current_snippet) = current_snippet.take() { - disjoint_snippets.push(current_snippet); - } - current_snippet = Some((snippet, snippet.range.clone())); - } - if let Some(current_snippet) = current_snippet.take() { - disjoint_snippets.push(current_snippet); - } - - writeln!(output, "`````path={}", file_path.display()).ok(); - let mut skipped_last_snippet = false; - for (snippet, range) in disjoint_snippets { - let section_index = section_ranges.len(); - - match self.request.prompt_format { - PromptFormat::MarkedExcerpt - | PromptFormat::OnlySnippets - | PromptFormat::OldTextNewText - | PromptFormat::Minimal - | PromptFormat::NumLinesUniDiff => { - if range.start.0 > 0 && !skipped_last_snippet { - output.push_str("…\n"); - } - } - PromptFormat::LabeledSections => { - if is_excerpt_file - && range.start <= self.request.excerpt_line_range.start - && range.end >= self.request.excerpt_line_range.end - { - writeln!(output, "<|current_section|>").ok(); - } else { - writeln!(output, "<|section_{}|>", section_index).ok(); - } - } - PromptFormat::MinimalQwen => unreachable!(), - PromptFormat::SeedCoder1120 => unreachable!(), - } - - let push_full_snippet = |output: &mut String| { - if self.request.prompt_format == PromptFormat::NumLinesUniDiff { - for (i, line) in snippet.text.lines().enumerate() { - writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?; - } - } else { - output.push_str(&snippet.text); - } - anyhow::Ok(()) - }; - - if is_excerpt_file { - if self.request.prompt_format == PromptFormat::OnlySnippets { - if range.start >= 
self.request.excerpt_line_range.start - && range.end <= self.request.excerpt_line_range.end - { - skipped_last_snippet = true; - } else { - skipped_last_snippet = false; - output.push_str(snippet.text); - } - } else if !excerpt_file_insertions.is_empty() { - let lines = snippet.text.lines().collect::>(); - let push_line = |output: &mut String, line_ix: usize| { - if self.request.prompt_format == PromptFormat::NumLinesUniDiff { - write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?; - } - anyhow::Ok(writeln!(output, "{}", lines[line_ix])?) - }; - let mut last_line_ix = 0; - let mut insertion_ix = 0; - while insertion_ix < excerpt_file_insertions.len() { - let (point, insertion) = &excerpt_file_insertions[insertion_ix]; - let found = point.line >= range.start && point.line <= range.end; - if found { - excerpt_index = Some(section_index); - let insertion_line_ix = (point.line.0 - range.start.0) as usize; - for line_ix in last_line_ix..insertion_line_ix { - push_line(output, line_ix)?; - } - if let Some(next_line) = lines.get(insertion_line_ix) { - if self.request.prompt_format == PromptFormat::NumLinesUniDiff { - write!( - output, - "{}|", - insertion_line_ix as u32 + range.start.0 + 1 - )? 
- } - output.push_str(&next_line[..point.column as usize]); - output.push_str(insertion); - writeln!(output, "{}", &next_line[point.column as usize..])?; - } else { - writeln!(output, "{}", insertion)?; - } - last_line_ix = insertion_line_ix + 1; - excerpt_file_insertions.remove(insertion_ix); - continue; - } - insertion_ix += 1; - } - skipped_last_snippet = false; - for line_ix in last_line_ix..lines.len() { - push_line(output, line_ix)?; - } - } else { - skipped_last_snippet = false; - push_full_snippet(output)?; - } - } else { - skipped_last_snippet = false; - push_full_snippet(output)?; - } - - section_ranges.push((snippet.path.clone(), range)); - } - - output.push_str("`````\n\n"); - } - - Ok(SectionLabels { - // TODO: Clean this up - excerpt_index: match self.request.prompt_format { - PromptFormat::OnlySnippets => 0, - _ => excerpt_index.context("bug: no snippet found for excerpt")?, - }, - section_ranges, - }) - } -} - -fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 { - declaration_score(declaration, style) / declaration_size(declaration, style) as f32 -} - -fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 { - match style { - DeclarationStyle::Signature => declaration.signature_score, - DeclarationStyle::Declaration => declaration.declaration_score, - } -} - -fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize { - match style { - DeclarationStyle::Signature => declaration.signature_range.len(), - DeclarationStyle::Declaration => declaration.text.len(), - } -} - struct PromptData { events: Vec>, cursor_point: Point, cursor_path: Arc, // TODO: make a common struct with cursor_point - included_files: Vec, + included_files: Vec, } #[derive(Default)] @@ -1051,7 +461,7 @@ impl SeedCoder1120Prompt { context } - fn fmt_fim(&self, file: &IncludedFile, cursor_point: Point) -> String { + fn fmt_fim(&self, file: &RelatedFile, cursor_point: 
Point) -> String { let mut buf = String::new(); const FIM_SUFFIX: &str = "<[fim-suffix]>"; const FIM_PREFIX: &str = "<[fim-prefix]>"; diff --git a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs b/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs deleted file mode 100644 index fd35f63f03ff967491a28d817852f6622e4919ca..0000000000000000000000000000000000000000 --- a/crates/cloud_zeta2_prompt/src/retrieval_prompt.rs +++ /dev/null @@ -1,244 +0,0 @@ -use anyhow::Result; -use cloud_llm_client::predict_edits_v3::{self, Excerpt}; -use indoc::indoc; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::fmt::Write; - -use crate::{push_events, write_codeblock}; - -pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result { - let mut prompt = SEARCH_INSTRUCTIONS.to_string(); - - if !request.events.is_empty() { - writeln!(&mut prompt, "\n## User Edits\n\n")?; - push_events(&mut prompt, &request.events); - } - - writeln!(&mut prompt, "## Cursor context\n")?; - write_codeblock( - &request.excerpt_path, - &[Excerpt { - start_line: request.excerpt_line_range.start, - text: request.excerpt.into(), - }], - &[], - request.cursor_file_max_row, - true, - &mut prompt, - ); - - writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?; - - Ok(prompt) -} - -/// Search for relevant code -/// -/// For the best results, run multiple queries at once with a single invocation of this tool. 
-#[derive(Clone, Deserialize, Serialize, JsonSchema)] -pub struct SearchToolInput { - /// An array of queries to run for gathering context relevant to the next prediction - #[schemars(length(max = 3))] - #[serde(deserialize_with = "deserialize_queries")] - pub queries: Box<[SearchToolQuery]>, -} - -fn deserialize_queries<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - use serde::de::Error; - - #[derive(Deserialize)] - #[serde(untagged)] - enum QueryCollection { - Array(Box<[SearchToolQuery]>), - DoubleArray(Box<[Box<[SearchToolQuery]>]>), - Single(SearchToolQuery), - } - - #[derive(Deserialize)] - #[serde(untagged)] - enum MaybeDoubleEncoded { - SingleEncoded(QueryCollection), - DoubleEncoded(String), - } - - let result = MaybeDoubleEncoded::deserialize(deserializer)?; - - let normalized = match result { - MaybeDoubleEncoded::SingleEncoded(value) => value, - MaybeDoubleEncoded::DoubleEncoded(value) => { - serde_json::from_str(&value).map_err(D::Error::custom)? - } - }; - - Ok(match normalized { - QueryCollection::Array(items) => items, - QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]), - QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(), - }) -} - -/// Search for relevant code by path, syntax hierarchy, and content. -#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)] -pub struct SearchToolQuery { - /// 1. A glob pattern to match file paths in the codebase to search in. - pub glob: String, - /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy. - /// - /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes. - /// - /// Example: Searching for a `User` class - /// ["class\s+User"] - /// - /// Example: Searching for a `get_full_name` method under a `User` class - /// ["class\s+User", "def\sget_full_name"] - /// - /// Skip this field to match on content alone. 
- #[schemars(length(max = 3))] - #[serde(default)] - pub syntax_node: Vec, - /// 3. An optional regular expression to match the final content that should appear in the results. - /// - /// - Content will be matched within all lines of the matched syntax nodes. - /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible. - /// - If no syntax node regexes are provided, the content will be matched within the entire file. - pub content: Option, -} - -pub const TOOL_NAME: &str = "search"; - -const SEARCH_INSTRUCTIONS: &str = indoc! {r#" - You are part of an edit prediction system in a code editor. - Your role is to search for code that will serve as context for predicting the next edit. - - - Analyze the user's recent edits and current cursor context - - Use the `search` tool to find code that is relevant for predicting the next edit - - Focus on finding: - - Code patterns that might need similar changes based on the recent edits - - Functions, variables, types, and constants referenced in the current cursor context - - Related implementations, usages, or dependencies that may require consistent updates - - How items defined in the cursor excerpt are used or altered - - You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible - - Use `syntax_node` parameter whenever you're looking for a particular type, class, or function - - Avoid using wildcard globs if you already know the file path of the content you're looking for -"#}; - -const TOOL_USE_REMINDER: &str = indoc! {" - -- - Analyze the user's intent in one to two sentences, then call the `search` tool. -"}; - -#[cfg(test)] -mod tests { - use serde_json::json; - - use super::*; - - #[test] - fn test_deserialize_queries() { - let single_query_json = indoc! 
{r#"{ - "queries": { - "glob": "**/*.rs", - "syntax_node": ["fn test"], - "content": "assert" - } - }"#}; - - let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap(); - assert_eq!(flat_input.queries.len(), 1); - assert_eq!(flat_input.queries[0].glob, "**/*.rs"); - assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]); - assert_eq!(flat_input.queries[0].content, Some("assert".to_string())); - - let flat_json = indoc! {r#"{ - "queries": [ - { - "glob": "**/*.rs", - "syntax_node": ["fn test"], - "content": "assert" - }, - { - "glob": "**/*.ts", - "syntax_node": [], - "content": null - } - ] - }"#}; - - let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap(); - assert_eq!(flat_input.queries.len(), 2); - assert_eq!(flat_input.queries[0].glob, "**/*.rs"); - assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]); - assert_eq!(flat_input.queries[0].content, Some("assert".to_string())); - assert_eq!(flat_input.queries[1].glob, "**/*.ts"); - assert_eq!(flat_input.queries[1].syntax_node.len(), 0); - assert_eq!(flat_input.queries[1].content, None); - - let nested_json = indoc! 
{r#"{ - "queries": [ - [ - { - "glob": "**/*.rs", - "syntax_node": ["fn test"], - "content": "assert" - } - ], - [ - { - "glob": "**/*.ts", - "syntax_node": [], - "content": null - } - ] - ] - }"#}; - - let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap(); - - assert_eq!(nested_input.queries.len(), 2); - - assert_eq!(nested_input.queries[0].glob, "**/*.rs"); - assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]); - assert_eq!(nested_input.queries[0].content, Some("assert".to_string())); - assert_eq!(nested_input.queries[1].glob, "**/*.ts"); - assert_eq!(nested_input.queries[1].syntax_node.len(), 0); - assert_eq!(nested_input.queries[1].content, None); - - let double_encoded_queries = serde_json::to_string(&json!({ - "queries": serde_json::to_string(&json!([ - { - "glob": "**/*.rs", - "syntax_node": ["fn test"], - "content": "assert" - }, - { - "glob": "**/*.ts", - "syntax_node": [], - "content": null - } - ])).unwrap() - })) - .unwrap(); - - let double_encoded_input: SearchToolInput = - serde_json::from_str(&double_encoded_queries).unwrap(); - - assert_eq!(double_encoded_input.queries.len(), 2); - - assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs"); - assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]); - assert_eq!( - double_encoded_input.queries[0].content, - Some("assert".to_string()) - ); - assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts"); - assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0); - assert_eq!(double_encoded_input.queries[1].content, None); - - // ### ERROR Switching from var declarations to lexical declarations [RUN 073] - // invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]} - } -} diff --git a/crates/codestral/Cargo.toml b/crates/codestral/Cargo.toml index b402274a33530424349081da764a4b6766e419e9..7f3bf3b22dda8f9dbde1923c76855342c6cbac4c 100644 --- 
a/crates/codestral/Cargo.toml +++ b/crates/codestral/Cargo.toml @@ -10,7 +10,7 @@ path = "src/codestral.rs" [dependencies] anyhow.workspace = true -edit_prediction.workspace = true +edit_prediction_types.workspace = true edit_prediction_context.workspace = true futures.workspace = true gpui.workspace = true diff --git a/crates/codestral/src/codestral.rs b/crates/codestral/src/codestral.rs index 6a500acbf6ec5eea63c35a8deb83a8545cee497e..9bf0296ac357937cd1ad1470dba9a98864911de9 100644 --- a/crates/codestral/src/codestral.rs +++ b/crates/codestral/src/codestral.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result}; -use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions}; +use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate}; use futures::AsyncReadExt; use gpui::{App, Context, Entity, Task}; use http_client::HttpClient; @@ -43,17 +43,17 @@ impl CurrentCompletion { /// Attempts to adjust the edits based on changes made to the buffer since the completion was generated. /// Returns None if the user's edits conflict with the predicted edits. 
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option, Arc)>> { - edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) + edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits) } } -pub struct CodestralCompletionProvider { +pub struct CodestralEditPredictionDelegate { http_client: Arc, pending_request: Option>>, current_completion: Option, } -impl CodestralCompletionProvider { +impl CodestralEditPredictionDelegate { pub fn new(http_client: Arc) -> Self { Self { http_client, @@ -165,7 +165,7 @@ impl CodestralCompletionProvider { } } -impl EditPredictionProvider for CodestralCompletionProvider { +impl EditPredictionDelegate for CodestralEditPredictionDelegate { fn name() -> &'static str { "codestral" } @@ -174,7 +174,7 @@ impl EditPredictionProvider for CodestralCompletionProvider { "Codestral" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { true } @@ -239,7 +239,6 @@ impl EditPredictionProvider for CodestralCompletionProvider { cursor_point, &snapshot, &EXCERPT_OPTIONS, - None, ) .context("Line containing cursor doesn't fit in excerpt max bytes")?; diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index b8a4c035499d45adc494c9f8175a772d15aa96df..79fc21fe33423d7eb887744b4ad84094a022862e 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -65,7 +65,7 @@ tokio = { workspace = true, features = ["full"] } toml.workspace = true tower = "0.4" tower-http = { workspace = true, features = ["trace"] } -tracing = "0.1.40" +tracing.workspace = true tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "registry", "tracing-log"] } # workaround for https://github.com/tokio-rs/tracing/issues/2927 util.workspace = true uuid.workspace = true diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 
a736ddfd1fe3334b1b847e820bd1816cb625ddca..32a2ed2e1331fc7b16f859accd895a7bce055804 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -121,6 +121,8 @@ CREATE TABLE "project_repositories" ( "merge_message" VARCHAR, "branch_summary" VARCHAR, "head_commit_details" VARCHAR, + "remote_upstream_url" VARCHAR, + "remote_origin_url" VARCHAR, PRIMARY KEY (project_id, id) ); diff --git a/crates/collab/migrations/20251203234258_add_remote_urls_to_project_repositories.sql b/crates/collab/migrations/20251203234258_add_remote_urls_to_project_repositories.sql new file mode 100644 index 0000000000000000000000000000000000000000..e1396de27d90fb2c872197d25198743d19be86f8 --- /dev/null +++ b/crates/collab/migrations/20251203234258_add_remote_urls_to_project_repositories.sql @@ -0,0 +1,2 @@ +ALTER TABLE "project_repositories" ADD COLUMN "remote_upstream_url" VARCHAR; +ALTER TABLE "project_repositories" ADD COLUMN "remote_origin_url" VARCHAR; diff --git a/crates/collab/src/db/queries/projects.rs b/crates/collab/src/db/queries/projects.rs index 51a0ef83323ec70675283d2fdec7ca1ad791b12d..6f1d8b884d15041eadaa9073a5bd99e5ed352502 100644 --- a/crates/collab/src/db/queries/projects.rs +++ b/crates/collab/src/db/queries/projects.rs @@ -362,6 +362,8 @@ impl Database { entry_ids: ActiveValue::set("[]".into()), head_commit_details: ActiveValue::set(None), merge_message: ActiveValue::set(None), + remote_upstream_url: ActiveValue::set(None), + remote_origin_url: ActiveValue::set(None), } }), ) @@ -511,6 +513,8 @@ impl Database { serde_json::to_string(&update.current_merge_conflicts).unwrap(), )), merge_message: ActiveValue::set(update.merge_message.clone()), + remote_upstream_url: ActiveValue::set(update.remote_upstream_url.clone()), + remote_origin_url: ActiveValue::set(update.remote_origin_url.clone()), }) .on_conflict( OnConflict::columns([ @@ -1005,6 +1009,8 @@ impl Database { is_last_update: true, 
merge_message: db_repository_entry.merge_message, stash_entries: Vec::new(), + remote_upstream_url: db_repository_entry.remote_upstream_url.clone(), + remote_origin_url: db_repository_entry.remote_origin_url.clone(), }); } } diff --git a/crates/collab/src/db/queries/rooms.rs b/crates/collab/src/db/queries/rooms.rs index f020b99b5f1030cfe9391498512258e6db249bac..eafb5cac44a510bf4ced0434a9b4adfadff0ebbc 100644 --- a/crates/collab/src/db/queries/rooms.rs +++ b/crates/collab/src/db/queries/rooms.rs @@ -796,6 +796,8 @@ impl Database { is_last_update: true, merge_message: db_repository.merge_message, stash_entries: Vec::new(), + remote_upstream_url: db_repository.remote_upstream_url.clone(), + remote_origin_url: db_repository.remote_origin_url.clone(), }); } } diff --git a/crates/collab/src/db/tables/project_repository.rs b/crates/collab/src/db/tables/project_repository.rs index eb653ecee37d48ce79e26450eb85d87dec411c1e..190ae8d79c54bb78daef4a1568ec75683eb0b0f2 100644 --- a/crates/collab/src/db/tables/project_repository.rs +++ b/crates/collab/src/db/tables/project_repository.rs @@ -22,6 +22,8 @@ pub struct Model { pub branch_summary: Option, // A JSON object representing the current Head commit values pub head_commit_details: Option, + pub remote_upstream_url: Option, + pub remote_origin_url: Option, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index 7d07360b8042ed54a9f19a82a2876e448e8a14a4..3785ee0b7abaeddeac5c9acb1718407ab5bd54f2 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use call::Room; use client::ChannelId; use gpui::{Entity, TestAppContext}; @@ -18,7 +16,6 @@ mod randomized_test_helpers; mod remote_editing_collaboration_tests; mod test_server; -use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; pub use randomized_test_helpers::{ RandomizedTest, TestError, UserTestPlan, run_randomized_test, 
save_randomized_test_plan, }; @@ -51,17 +48,3 @@ fn room_participants(room: &Entity, cx: &mut TestAppContext) -> RoomPartic fn channel_id(room: &Entity, cx: &mut TestAppContext) -> Option { cx.read(|cx| room.read(cx).channel_id()) } - -fn rust_lang() -> Arc { - Arc::new(Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - )) -} diff --git a/crates/collab/src/tests/editor_tests.rs b/crates/collab/src/tests/editor_tests.rs index 785a6457c8fdb57f84a8e7b5a8487f0ceae3d025..ba92e868126c7f27fb5051021fce44fe43c8d5e7 100644 --- a/crates/collab/src/tests/editor_tests.rs +++ b/crates/collab/src/tests/editor_tests.rs @@ -1,7 +1,4 @@ -use crate::{ - rpc::RECONNECT_TIMEOUT, - tests::{TestServer, rust_lang}, -}; +use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; use call::ActiveCall; use editor::{ DocumentColorsRenderMode, Editor, FETCH_COLORS_DEBOUNCE_TIMEOUT, MultiBufferOffset, RowInfo, @@ -23,7 +20,7 @@ use gpui::{ App, Rgba, SharedString, TestAppContext, UpdateGlobal, VisualContext, VisualTestContext, }; use indoc::indoc; -use language::FakeLspAdapter; +use language::{FakeLspAdapter, rust_lang}; use lsp::LSP_REQUEST_TIMEOUT; use pretty_assertions::assert_eq; use project::{ @@ -3518,7 +3515,6 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA .into_iter() .map(|(sha, message)| (sha.parse().unwrap(), message.into())) .collect(), - remote_url: Some("git@github.com:zed-industries/zed.git".to_string()), }; client_a.fs().set_blame_for_repo( Path::new(path!("/my-repo/.git")), @@ -3603,10 +3599,6 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA for (idx, (buffer, entry)) in entries.iter().flatten().enumerate() { let details = blame.details_for_entry(*buffer, entry).unwrap(); assert_eq!(details.message, format!("message for idx-{}", idx)); 
- assert_eq!( - details.permalink.unwrap().to_string(), - format!("https://github.com/zed-industries/zed/commit/{}", entry.sha) - ); } }); }); diff --git a/crates/collab/src/tests/integration_tests.rs b/crates/collab/src/tests/integration_tests.rs index fcda8688d427f3e6b937f00edc7c3586dfdbef36..391e7355ea196dfe25d363472918837ea817f450 100644 --- a/crates/collab/src/tests/integration_tests.rs +++ b/crates/collab/src/tests/integration_tests.rs @@ -2,7 +2,7 @@ use crate::{ rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, tests::{ RoomParticipants, TestClient, TestServer, channel_id, following_tests::join_channel, - room_participants, rust_lang, + room_participants, }, }; use anyhow::{Result, anyhow}; @@ -26,7 +26,7 @@ use language::{ Diagnostic, DiagnosticEntry, DiagnosticSourceKind, FakeLspAdapter, Language, LanguageConfig, LanguageMatcher, LineEnding, OffsetRangeExt, Point, Rope, language_settings::{Formatter, FormatterList}, - tree_sitter_rust, tree_sitter_typescript, + rust_lang, tree_sitter_rust, tree_sitter_typescript, }; use lsp::{LanguageServerId, OneOf}; use parking_lot::Mutex; diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 0d3b19c0c7bd264f8ed10e53289376055f833307..459abda17573d66287e2c8ca0b995292acaf163b 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -33,7 +33,7 @@ fs.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true -edit_prediction.workspace = true +edit_prediction_types.workspace = true language.workspace = true log.workspace = true lsp.workspace = true diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index ed18a199bf2c08c8c046a8ad3e7f945b1340643e..6fbdeff807b65d22193ba7fdcb8e990f7184f70e 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -1,5 +1,5 @@ pub mod copilot_chat; -mod copilot_completion_provider; +mod copilot_edit_prediction_delegate; pub mod copilot_responses; pub mod request; mod sign_in; @@ -46,7 
+46,7 @@ use util::rel_path::RelPath; use util::{ResultExt, fs::remove_matching}; use workspace::Workspace; -pub use crate::copilot_completion_provider::CopilotCompletionProvider; +pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate; pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in}; actions!( diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_edit_prediction_delegate.rs similarity index 98% rename from crates/copilot/src/copilot_completion_provider.rs rename to crates/copilot/src/copilot_edit_prediction_delegate.rs index e92f0c7d7dd7e51c4a8fdc19f34bd6eb4189c097..961154dbeecad007f026f25eeac25de95d751d9e 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_edit_prediction_delegate.rs @@ -1,6 +1,6 @@ use crate::{Completion, Copilot}; use anyhow::Result; -use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; +use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate}; use gpui::{App, Context, Entity, EntityId, Task}; use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings}; use settings::Settings; @@ -8,7 +8,7 @@ use std::{path::Path, time::Duration}; pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); -pub struct CopilotCompletionProvider { +pub struct CopilotEditPredictionDelegate { cycled: bool, buffer_id: Option, completions: Vec, @@ -19,7 +19,7 @@ pub struct CopilotCompletionProvider { copilot: Entity, } -impl CopilotCompletionProvider { +impl CopilotEditPredictionDelegate { pub fn new(copilot: Entity) -> Self { Self { cycled: false, @@ -47,7 +47,7 @@ impl CopilotCompletionProvider { } } -impl EditPredictionProvider for CopilotCompletionProvider { +impl EditPredictionDelegate for CopilotEditPredictionDelegate { fn name() -> &'static str { "copilot" } @@ -56,7 +56,7 @@ impl EditPredictionProvider for 
CopilotCompletionProvider { "Copilot" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { true } @@ -314,7 +314,7 @@ mod tests { cx, ) .await; - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) }); @@ -546,7 +546,7 @@ mod tests { cx, ) .await; - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) }); @@ -670,7 +670,7 @@ mod tests { cx, ) .await; - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) }); @@ -753,7 +753,7 @@ mod tests { window.focus(&editor.focus_handle(cx)); }) .unwrap(); - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); editor .update(cx, |editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) @@ -848,7 +848,7 @@ mod tests { cx, ) .await; - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot)); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) }); @@ -1000,7 +1000,7 @@ mod tests { window.focus(&editor.focus_handle(cx)) }) .unwrap(); - let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot)); + let copilot_provider = cx.new(|_| 
CopilotEditPredictionDelegate::new(copilot)); editor .update(cx, |editor, window, cx| { editor.set_edit_prediction_provider(Some(copilot_provider), window, cx) diff --git a/crates/dap_adapters/src/python.rs b/crates/dap_adapters/src/python.rs index 2f84193ac9343cac9ff1cf52eb648bad1cd77896..a45e16dc32180e1e4ed3b2ec01c92ad07ffc5e93 100644 --- a/crates/dap_adapters/src/python.rs +++ b/crates/dap_adapters/src/python.rs @@ -870,7 +870,7 @@ impl DebugAdapter for PythonDebugAdapter { .active_toolchain( delegate.worktree_id(), base_path.into_arc(), - language::LanguageName::new(Self::LANGUAGE_NAME), + language::LanguageName::new_static(Self::LANGUAGE_NAME), cx, ) .await diff --git a/crates/debugger_ui/Cargo.toml b/crates/debugger_ui/Cargo.toml index 325bcc300ae637ab46c36b7a3e7875e197f7d3d2..fb79b1b0790b28d7204774720bf9c413cfed64e6 100644 --- a/crates/debugger_ui/Cargo.toml +++ b/crates/debugger_ui/Cargo.toml @@ -37,6 +37,7 @@ dap_adapters = { workspace = true, optional = true } db.workspace = true debugger_tools.workspace = true editor.workspace = true +feature_flags.workspace = true file_icons.workspace = true futures.workspace = true fuzzy.workspace = true @@ -82,6 +83,7 @@ dap_adapters = { workspace = true, features = ["test-support"] } debugger_tools = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] } tree-sitter-go.workspace = true unindent.workspace = true diff --git a/crates/debugger_ui/src/debugger_panel.rs b/crates/debugger_ui/src/debugger_panel.rs index ffdd4a22e3d092eb5d3d6626dcfe8b167ae03936..bdb308aafd0d2899f17bef732ac38239c4df6dda 100644 --- a/crates/debugger_ui/src/debugger_panel.rs +++ b/crates/debugger_ui/src/debugger_panel.rs @@ -15,10 +15,11 @@ use dap::adapters::DebugAdapterName; use dap::{DapRegistry, 
StartDebuggingRequestArguments}; use dap::{client::SessionId, debugger_settings::DebuggerSettings}; use editor::{Editor, MultiBufferOffset, ToPoint}; +use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use gpui::{ - Action, App, AsyncWindowContext, ClipboardItem, Context, DismissEvent, Entity, EntityId, - EventEmitter, FocusHandle, Focusable, MouseButton, MouseDownEvent, Point, Subscription, Task, - WeakEntity, anchored, deferred, + Action, App, AsyncWindowContext, ClipboardItem, Context, Corner, DismissEvent, Entity, + EntityId, EventEmitter, FocusHandle, Focusable, MouseButton, MouseDownEvent, Point, + Subscription, Task, WeakEntity, anchored, deferred, }; use itertools::Itertools as _; @@ -31,7 +32,9 @@ use settings::Settings; use std::sync::{Arc, LazyLock}; use task::{DebugScenario, TaskContext}; use tree_sitter::{Query, StreamingIterator as _}; -use ui::{ContextMenu, Divider, PopoverMenuHandle, Tab, Tooltip, prelude::*}; +use ui::{ + ContextMenu, Divider, PopoverMenu, PopoverMenuHandle, SplitButton, Tab, Tooltip, prelude::*, +}; use util::rel_path::RelPath; use util::{ResultExt, debug_panic, maybe}; use workspace::SplitDirection; @@ -42,6 +45,12 @@ use workspace::{ }; use zed_actions::ToggleFocus; +pub struct DebuggerHistoryFeatureFlag; + +impl FeatureFlag for DebuggerHistoryFeatureFlag { + const NAME: &'static str = "debugger-history"; +} + const DEBUG_PANEL_KEY: &str = "DebugPanel"; pub struct DebugPanel { @@ -284,7 +293,7 @@ impl DebugPanel { } }); - session.update(cx, |session, _| match &mut session.mode { + session.update(cx, |session, _| match &mut session.state { SessionState::Booting(state_task) => { *state_task = Some(boot_task); } @@ -662,6 +671,12 @@ impl DebugPanel { ) }; + let thread_status = active_session + .as_ref() + .map(|session| session.read(cx).running_state()) + .and_then(|state| state.read(cx).thread_status(cx)) + .unwrap_or(project::debugger::session::ThreadStatus::Exited); + Some( div.w_full() .py_1() @@ -679,10 +694,6 @@ impl 
DebugPanel { .as_ref() .map(|session| session.read(cx).running_state()), |this, running_state| { - let thread_status = - running_state.read(cx).thread_status(cx).unwrap_or( - project::debugger::session::ThreadStatus::Exited, - ); let capabilities = running_state.read(cx).capabilities(cx); let supports_detach = running_state.read(cx).session().read(cx).is_attached(); @@ -871,36 +882,53 @@ impl DebugPanel { } }), ) + .when(supports_detach, |div| { + div.child( + IconButton::new( + "debug-disconnect", + IconName::DebugDetach, + ) + .disabled( + thread_status != ThreadStatus::Stopped + && thread_status != ThreadStatus::Running, + ) + .icon_size(IconSize::Small) + .on_click(window.listener_for( + running_state, + |this, _, _, cx| { + this.detach_client(cx); + }, + )) + .tooltip({ + let focus_handle = focus_handle.clone(); + move |_window, cx| { + Tooltip::for_action_in( + "Detach", + &Detach, + &focus_handle, + cx, + ) + } + }), + ) + }) .when( - supports_detach, - |div| { - div.child( - IconButton::new( - "debug-disconnect", - IconName::DebugDetach, - ) - .disabled( - thread_status != ThreadStatus::Stopped - && thread_status != ThreadStatus::Running, + cx.has_flag::(), + |this| { + this.child(Divider::vertical()).child( + SplitButton::new( + self.render_history_button( + &running_state, + thread_status, + window, + ), + self.render_history_toggle_button( + thread_status, + &running_state, + ) + .into_any_element(), ) - .icon_size(IconSize::Small) - .on_click(window.listener_for( - running_state, - |this, _, _, cx| { - this.detach_client(cx); - }, - )) - .tooltip({ - let focus_handle = focus_handle.clone(); - move |_window, cx| { - Tooltip::for_action_in( - "Detach", - &Detach, - &focus_handle, - cx, - ) - } - }), + .style(ui::SplitButtonStyle::Outlined), ) }, ) @@ -1317,6 +1345,97 @@ impl DebugPanel { }); } } + + fn render_history_button( + &self, + running_state: &Entity, + thread_status: ThreadStatus, + window: &mut Window, + ) -> IconButton { + 
IconButton::new("debug-back-in-history", IconName::HistoryRerun) + .icon_size(IconSize::Small) + .on_click(window.listener_for(running_state, |this, _, _window, cx| { + this.session().update(cx, |session, cx| { + let ix = session + .active_snapshot_index() + .unwrap_or_else(|| session.historic_snapshots().len()); + + session.select_historic_snapshot(Some(ix.saturating_sub(1)), cx); + }) + })) + .disabled( + thread_status == ThreadStatus::Running || thread_status == ThreadStatus::Stepping, + ) + } + + fn render_history_toggle_button( + &self, + thread_status: ThreadStatus, + running_state: &Entity, + ) -> impl IntoElement { + PopoverMenu::new("debug-back-in-history-menu") + .trigger( + ui::ButtonLike::new_rounded_right("debug-back-in-history-menu-trigger") + .layer(ui::ElevationIndex::ModalSurface) + .size(ui::ButtonSize::None) + .child( + div() + .px_1() + .child(Icon::new(IconName::ChevronDown).size(IconSize::XSmall)), + ) + .disabled( + thread_status == ThreadStatus::Running + || thread_status == ThreadStatus::Stepping, + ), + ) + .menu({ + let running_state = running_state.clone(); + move |window, cx| { + let handler = + |ix: Option, running_state: Entity, cx: &mut App| { + running_state.update(cx, |state, cx| { + state.session().update(cx, |session, cx| { + session.select_historic_snapshot(ix, cx); + }) + }) + }; + + let running_state = running_state.clone(); + Some(ContextMenu::build( + window, + cx, + move |mut context_menu, _window, cx| { + let history = running_state + .read(cx) + .session() + .read(cx) + .historic_snapshots(); + + context_menu = context_menu.entry("Current State", None, { + let running_state = running_state.clone(); + move |_window, cx| { + handler(None, running_state.clone(), cx); + } + }); + context_menu = context_menu.separator(); + + for (ix, _) in history.iter().enumerate().rev() { + context_menu = + context_menu.entry(format!("history-{}", ix + 1), None, { + let running_state = running_state.clone(); + move |_window, cx| { + 
handler(Some(ix), running_state.clone(), cx); + } + }); + } + + context_menu + }, + )) + } + }) + .anchor(Corner::TopRight) + } } async fn register_session_inner( diff --git a/crates/debugger_ui/src/debugger_ui.rs b/crates/debugger_ui/src/debugger_ui.rs index a9abb50bb68851334285b05064176e0347474014..bd5a7cda4a21a3d3fd0ac132d6ba2e7aace68722 100644 --- a/crates/debugger_ui/src/debugger_ui.rs +++ b/crates/debugger_ui/src/debugger_ui.rs @@ -387,7 +387,7 @@ pub fn init(cx: &mut App) { window.on_action( TypeId::of::(), move |_, _, window, cx| { - maybe!({ + let status = maybe!({ let text = editor .update(cx, |editor, cx| { let range = editor @@ -411,7 +411,13 @@ pub fn init(cx: &mut App) { state.session().update(cx, |session, cx| { session - .evaluate(text, None, stack_id, None, cx) + .evaluate( + text, + Some(dap::EvaluateArgumentsContext::Repl), + stack_id, + None, + cx, + ) .detach(); }); }); @@ -419,6 +425,9 @@ pub fn init(cx: &mut App) { Some(()) }); + if status.is_some() { + cx.stop_propagation(); + } }, ); }) diff --git a/crates/debugger_ui/src/new_process_modal.rs b/crates/debugger_ui/src/new_process_modal.rs index 40187cef9cc55cb4192a3cea773f42dca15a2571..24f6cc8f341519425b33433c3324f8afe7401ab7 100644 --- a/crates/debugger_ui/src/new_process_modal.rs +++ b/crates/debugger_ui/src/new_process_modal.rs @@ -881,7 +881,6 @@ impl ConfigureMode { .label("Stop on Entry") .label_position(SwitchLabelPosition::Start) .label_size(LabelSize::Default) - .color(ui::SwitchColor::Accent) .on_click({ let this = cx.weak_entity(); move |state, _, cx| { @@ -1023,7 +1022,7 @@ impl DebugDelegate { Some(TaskSourceKind::Lsp { language_name, .. 
}) => { Some(format!("LSP: {language_name}")) } - Some(TaskSourceKind::Language { name }) => Some(format!("Lang: {name}")), + Some(TaskSourceKind::Language { name }) => Some(format!("Language: {name}")), _ => context.clone().and_then(|ctx| { ctx.task_context .task_variables diff --git a/crates/debugger_ui/src/session/running.rs b/crates/debugger_ui/src/session/running.rs index b82f839edee82f884c1419d44a2344c39c8abd0d..bc99d6ac8e42b0a706df4a09177ae2103d5939e2 100644 --- a/crates/debugger_ui/src/session/running.rs +++ b/crates/debugger_ui/src/session/running.rs @@ -1743,7 +1743,7 @@ impl RunningState { let is_building = self.session.update(cx, |session, cx| { session.shutdown(cx).detach(); - matches!(session.mode, session::SessionState::Booting(_)) + matches!(session.state, session::SessionState::Booting(_)) }); if is_building { diff --git a/crates/debugger_ui/src/session/running/loaded_source_list.rs b/crates/debugger_ui/src/session/running/loaded_source_list.rs index 921ebd8b5f5bdfe8a3c8a8f7bb1625bd1ffad7fb..e55fad336b5ee6dfbee1cb0c90ea3d19f561a2ba 100644 --- a/crates/debugger_ui/src/session/running/loaded_source_list.rs +++ b/crates/debugger_ui/src/session/running/loaded_source_list.rs @@ -17,7 +17,9 @@ impl LoadedSourceList { let list = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); let _subscription = cx.subscribe(&session, |this, _, event, cx| match event { - SessionEvent::Stopped(_) | SessionEvent::LoadedSources => { + SessionEvent::Stopped(_) + | SessionEvent::HistoricSnapshotSelected + | SessionEvent::LoadedSources => { this.invalidate = true; cx.notify(); } diff --git a/crates/debugger_ui/src/session/running/module_list.rs b/crates/debugger_ui/src/session/running/module_list.rs index 19f407eb23f8acf0aa665f5119ecfd2156eb685f..7d0228fc6851185d10a3a237257d6244d5a90c76 100644 --- a/crates/debugger_ui/src/session/running/module_list.rs +++ b/crates/debugger_ui/src/session/running/module_list.rs @@ -32,7 +32,9 @@ impl ModuleList { let focus_handle = 
cx.focus_handle(); let _subscription = cx.subscribe(&session, |this, _, event, cx| match event { - SessionEvent::Stopped(_) | SessionEvent::Modules => { + SessionEvent::Stopped(_) + | SessionEvent::HistoricSnapshotSelected + | SessionEvent::Modules => { if this._rebuild_task.is_some() { this.schedule_rebuild(cx); } diff --git a/crates/debugger_ui/src/session/running/stack_frame_list.rs b/crates/debugger_ui/src/session/running/stack_frame_list.rs index 96a910af4dd0ac901c6802c139ddd5b8b3d728bc..4dffb57a792cb5a5ed6bcc8003a8fa6f3b9af9de 100644 --- a/crates/debugger_ui/src/session/running/stack_frame_list.rs +++ b/crates/debugger_ui/src/session/running/stack_frame_list.rs @@ -4,6 +4,7 @@ use std::time::Duration; use anyhow::{Context as _, Result, anyhow}; use dap::StackFrameId; +use dap::adapters::DebugAdapterName; use db::kvp::KEY_VALUE_STORE; use gpui::{ Action, AnyElement, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, ListState, @@ -20,7 +21,7 @@ use project::debugger::breakpoint_store::ActiveStackFrame; use project::debugger::session::{Session, SessionEvent, StackFrame, ThreadStatus}; use project::{ProjectItem, ProjectPath}; use ui::{Tooltip, WithScrollbar, prelude::*}; -use workspace::{ItemHandle, Workspace}; +use workspace::{ItemHandle, Workspace, WorkspaceId}; use super::RunningState; @@ -58,6 +59,14 @@ impl From for String { } } +pub(crate) fn stack_frame_filter_key( + adapter_name: &DebugAdapterName, + workspace_id: WorkspaceId, +) -> String { + let database_id: i64 = workspace_id.into(); + format!("stack-frame-list-filter-{}-{}", adapter_name.0, database_id) +} + pub struct StackFrameList { focus_handle: FocusHandle, _subscription: Subscription, @@ -97,7 +106,9 @@ impl StackFrameList { SessionEvent::Threads => { this.schedule_refresh(false, window, cx); } - SessionEvent::Stopped(..) | SessionEvent::StackTrace => { + SessionEvent::Stopped(..) 
+ | SessionEvent::StackTrace + | SessionEvent::HistoricSnapshotSelected => { this.schedule_refresh(true, window, cx); } _ => {} @@ -105,14 +116,18 @@ impl StackFrameList { let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.)); - let list_filter = KEY_VALUE_STORE - .read_kvp(&format!( - "stack-frame-list-filter-{}", - session.read(cx).adapter().0 - )) + let list_filter = workspace + .read_with(cx, |workspace, _| workspace.database_id()) .ok() .flatten() - .map(StackFrameFilter::from_str_or_default) + .and_then(|database_id| { + let key = stack_frame_filter_key(&session.read(cx).adapter(), database_id); + KEY_VALUE_STORE + .read_kvp(&key) + .ok() + .flatten() + .map(StackFrameFilter::from_str_or_default) + }) .unwrap_or(StackFrameFilter::All); let mut this = Self { @@ -225,7 +240,6 @@ impl StackFrameList { } this.update_in(cx, |this, window, cx| { this.build_entries(select_first, window, cx); - cx.notify(); }) .ok(); }) @@ -806,15 +820,8 @@ impl StackFrameList { .ok() .flatten() { - let database_id: i64 = database_id.into(); - let save_task = KEY_VALUE_STORE.write_kvp( - format!( - "stack-frame-list-filter-{}-{}", - self.session.read(cx).adapter().0, - database_id, - ), - self.list_filter.into(), - ); + let key = stack_frame_filter_key(&self.session.read(cx).adapter(), database_id); + let save_task = KEY_VALUE_STORE.write_kvp(key, self.list_filter.into()); cx.background_spawn(save_task).detach(); } diff --git a/crates/debugger_ui/src/session/running/variable_list.rs b/crates/debugger_ui/src/session/running/variable_list.rs index 1b455b59d7d12712a3d4adc713a6ed15e8166c6e..7b23cd685d93e6353d68dc57cd3998099ea56ad7 100644 --- a/crates/debugger_ui/src/session/running/variable_list.rs +++ b/crates/debugger_ui/src/session/running/variable_list.rs @@ -217,6 +217,12 @@ impl VariableList { let _subscriptions = vec![ cx.subscribe(&stack_frame_list, Self::handle_stack_frame_list_events), cx.subscribe(&session, |this, _, event, cx| match event { + 
SessionEvent::HistoricSnapshotSelected => { + this.selection.take(); + this.edited_path.take(); + this.selected_stack_frame_id.take(); + this.build_entries(cx); + } SessionEvent::Stopped(_) => { this.selection.take(); this.edited_path.take(); @@ -225,7 +231,6 @@ impl VariableList { SessionEvent::Variables | SessionEvent::Watchers => { this.build_entries(cx); } - _ => {} }), cx.on_focus_out(&focus_handle, window, |this, _, _, cx| { diff --git a/crates/debugger_ui/src/tests/inline_values.rs b/crates/debugger_ui/src/tests/inline_values.rs index 801e6d43623b50d69ea3ce297c274c2d7e5a8b14..379bc4c98f5341b089b5936ed8571da5a6280723 100644 --- a/crates/debugger_ui/src/tests/inline_values.rs +++ b/crates/debugger_ui/src/tests/inline_values.rs @@ -4,7 +4,7 @@ use dap::{Scope, StackFrame, Variable, requests::Variables}; use editor::{Editor, EditorMode, MultiBuffer}; use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext}; use language::{ - Language, LanguageConfig, LanguageMatcher, tree_sitter_python, tree_sitter_rust, + Language, LanguageConfig, LanguageMatcher, rust_lang, tree_sitter_python, tree_sitter_typescript, }; use project::{FakeFs, Project}; @@ -224,7 +224,7 @@ fn main() { .unwrap(); buffer.update(cx, |buffer, cx| { - buffer.set_language(Some(Arc::new(rust_lang())), cx); + buffer.set_language(Some(rust_lang()), cx); }); let (editor, cx) = cx.add_window_view(|window, cx| { @@ -1521,23 +1521,6 @@ fn main() { }); } -fn rust_lang() -> Language { - let debug_variables_query = include_str!("../../../languages/src/rust/debugger.scm"); - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_debug_variables_query(debug_variables_query) - .unwrap() -} - #[gpui::test] async fn test_python_inline_values(executor: BackgroundExecutor, cx: &mut TestAppContext) { init_test(cx); @@ -1859,21 
+1842,23 @@ fn python_lang() -> Language { .unwrap() } -fn go_lang() -> Language { +fn go_lang() -> Arc { let debug_variables_query = include_str!("../../../languages/src/go/debugger.scm"); - Language::new( - LanguageConfig { - name: "Go".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["go".to_string()], + Arc::new( + Language::new( + LanguageConfig { + name: "Go".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["go".to_string()], + ..Default::default() + }, ..Default::default() }, - ..Default::default() - }, - Some(tree_sitter_go::LANGUAGE.into()), + Some(tree_sitter_go::LANGUAGE.into()), + ) + .with_debug_variables_query(debug_variables_query) + .unwrap(), ) - .with_debug_variables_query(debug_variables_query) - .unwrap() } /// Test utility function for inline values testing @@ -1891,7 +1876,7 @@ async fn test_inline_values_util( before: &str, after: &str, active_debug_line: Option, - language: Language, + language: Arc, executor: BackgroundExecutor, cx: &mut TestAppContext, ) { @@ -2091,7 +2076,7 @@ async fn test_inline_values_util( .unwrap(); buffer.update(cx, |buffer, cx| { - buffer.set_language(Some(Arc::new(language)), cx); + buffer.set_language(Some(language), cx); }); let (editor, cx) = cx.add_window_view(|window, cx| { @@ -2276,55 +2261,61 @@ fn main() { .await; } -fn javascript_lang() -> Language { +fn javascript_lang() -> Arc { let debug_variables_query = include_str!("../../../languages/src/javascript/debugger.scm"); - Language::new( - LanguageConfig { - name: "JavaScript".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["js".to_string()], + Arc::new( + Language::new( + LanguageConfig { + name: "JavaScript".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["js".to_string()], + ..Default::default() + }, ..Default::default() }, - ..Default::default() - }, - Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), + Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), + ) + 
.with_debug_variables_query(debug_variables_query) + .unwrap(), ) - .with_debug_variables_query(debug_variables_query) - .unwrap() } -fn typescript_lang() -> Language { +fn typescript_lang() -> Arc { let debug_variables_query = include_str!("../../../languages/src/typescript/debugger.scm"); - Language::new( - LanguageConfig { - name: "TypeScript".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["ts".to_string()], + Arc::new( + Language::new( + LanguageConfig { + name: "TypeScript".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["ts".to_string()], + ..Default::default() + }, ..Default::default() }, - ..Default::default() - }, - Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), + Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), + ) + .with_debug_variables_query(debug_variables_query) + .unwrap(), ) - .with_debug_variables_query(debug_variables_query) - .unwrap() } -fn tsx_lang() -> Language { +fn tsx_lang() -> Arc { let debug_variables_query = include_str!("../../../languages/src/tsx/debugger.scm"); - Language::new( - LanguageConfig { - name: "TSX".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["tsx".to_string()], + Arc::new( + Language::new( + LanguageConfig { + name: "TSX".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["tsx".to_string()], + ..Default::default() + }, ..Default::default() }, - ..Default::default() - }, - Some(tree_sitter_typescript::LANGUAGE_TSX.into()), + Some(tree_sitter_typescript::LANGUAGE_TSX.into()), + ) + .with_debug_variables_query(debug_variables_query) + .unwrap(), ) - .with_debug_variables_query(debug_variables_query) - .unwrap() } #[gpui::test] diff --git a/crates/debugger_ui/src/tests/stack_frame_list.rs b/crates/debugger_ui/src/tests/stack_frame_list.rs index 05e638e2321bb6fcb4504a8bc8c81123f1b09a33..445d5a01d9f062501c889a9d5aa920de6f037914 100644 --- a/crates/debugger_ui/src/tests/stack_frame_list.rs +++ b/crates/debugger_ui/src/tests/stack_frame_list.rs @@ -1,12 +1,15 @@ 
use crate::{ debugger_panel::DebugPanel, - session::running::stack_frame_list::{StackFrameEntry, StackFrameFilter}, + session::running::stack_frame_list::{ + StackFrameEntry, StackFrameFilter, stack_frame_filter_key, + }, tests::{active_debug_session_panel, init_test, init_test_workspace, start_debug_session}, }; use dap::{ StackFrame, requests::{Scopes, StackTrace, Threads}, }; +use db::kvp::KEY_VALUE_STORE; use editor::{Editor, ToPoint as _}; use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext}; use project::{FakeFs, Project}; @@ -1085,3 +1088,180 @@ async fn test_stack_frame_filter(executor: BackgroundExecutor, cx: &mut TestAppC ); }); } + +#[gpui::test] +async fn test_stack_frame_filter_persistence( + executor: BackgroundExecutor, + cx: &mut TestAppContext, +) { + init_test(cx); + + let fs = FakeFs::new(executor.clone()); + + fs.insert_tree( + path!("/project"), + json!({ + "src": { + "test.js": "function main() { console.log('hello'); }", + } + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let workspace = init_test_workspace(&project, cx).await; + let cx = &mut VisualTestContext::from_window(*workspace, cx); + workspace + .update(cx, |workspace, _, _| { + workspace.set_random_database_id(); + }) + .unwrap(); + + let threads_response = dap::ThreadsResponse { + threads: vec![dap::Thread { + id: 1, + name: "Thread 1".into(), + }], + }; + + let stack_trace_response = dap::StackTraceResponse { + stack_frames: vec![StackFrame { + id: 1, + name: "main".into(), + source: Some(dap::Source { + name: Some("test.js".into()), + path: Some(path!("/project/src/test.js").into()), + source_reference: None, + presentation_hint: None, + origin: None, + sources: None, + adapter_data: None, + checksums: None, + }), + line: 1, + column: 1, + end_line: None, + end_column: None, + can_restart: None, + instruction_pointer_reference: None, + module_id: None, + presentation_hint: None, + }], + total_frames: None, + 
}; + + let stopped_event = dap::StoppedEvent { + reason: dap::StoppedEventReason::Pause, + description: None, + thread_id: Some(1), + preserve_focus_hint: None, + text: None, + all_threads_stopped: None, + hit_breakpoint_ids: None, + }; + + let session = start_debug_session(&workspace, cx, |_| {}).unwrap(); + let client = session.update(cx, |session, _| session.adapter_client().unwrap()); + let adapter_name = session.update(cx, |session, _| session.adapter()); + + client.on_request::({ + let threads_response = threads_response.clone(); + move |_, _| Ok(threads_response.clone()) + }); + + client.on_request::(move |_, _| Ok(dap::ScopesResponse { scopes: vec![] })); + + client.on_request::({ + let stack_trace_response = stack_trace_response.clone(); + move |_, _| Ok(stack_trace_response.clone()) + }); + + client + .fake_event(dap::messages::Events::Stopped(stopped_event.clone())) + .await; + + cx.run_until_parked(); + + let stack_frame_list = + active_debug_session_panel(workspace, cx).update(cx, |debug_panel_item, cx| { + debug_panel_item + .running_state() + .update(cx, |state, _| state.stack_frame_list().clone()) + }); + + stack_frame_list.update(cx, |stack_frame_list, _cx| { + assert_eq!( + stack_frame_list.list_filter(), + StackFrameFilter::All, + "Initial filter should be All" + ); + }); + + stack_frame_list.update(cx, |stack_frame_list, cx| { + stack_frame_list + .toggle_frame_filter(Some(project::debugger::session::ThreadStatus::Stopped), cx); + assert_eq!( + stack_frame_list.list_filter(), + StackFrameFilter::OnlyUserFrames, + "Filter should be OnlyUserFrames after toggle" + ); + }); + + cx.run_until_parked(); + + let workspace_id = workspace + .update(cx, |workspace, _window, _cx| workspace.database_id()) + .ok() + .flatten() + .expect("workspace id has to be some for this test to work properly"); + + let key = stack_frame_filter_key(&adapter_name, workspace_id); + let stored_value = KEY_VALUE_STORE.read_kvp(&key).unwrap(); + assert_eq!( + stored_value, + 
Some(StackFrameFilter::OnlyUserFrames.into()), + "Filter should be persisted in KVP store with key: {}", + key + ); + + client + .fake_event(dap::messages::Events::Terminated(None)) + .await; + cx.run_until_parked(); + + let session2 = start_debug_session(&workspace, cx, |_| {}).unwrap(); + let client2 = session2.update(cx, |session, _| session.adapter_client().unwrap()); + + client2.on_request::({ + let threads_response = threads_response.clone(); + move |_, _| Ok(threads_response.clone()) + }); + + client2.on_request::(move |_, _| Ok(dap::ScopesResponse { scopes: vec![] })); + + client2.on_request::({ + let stack_trace_response = stack_trace_response.clone(); + move |_, _| Ok(stack_trace_response.clone()) + }); + + client2 + .fake_event(dap::messages::Events::Stopped(stopped_event.clone())) + .await; + + cx.run_until_parked(); + + let stack_frame_list2 = + active_debug_session_panel(workspace, cx).update(cx, |debug_panel_item, cx| { + debug_panel_item + .running_state() + .update(cx, |state, _| state.stack_frame_list().clone()) + }); + + stack_frame_list2.update(cx, |stack_frame_list, _cx| { + assert_eq!( + stack_frame_list.list_filter(), + StackFrameFilter::OnlyUserFrames, + "Filter should be restored from KVP store in new session" + ); + }); +} diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index 2c6888d14be49c857e7805fb63f9f9335ac32c8e..6e62cfa6f038671d595c5671de147cdc2125064d 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -11,7 +11,69 @@ workspace = true [lib] path = "src/edit_prediction.rs" +[features] +eval-support = [] + [dependencies] +ai_onboarding.workspace = true +anyhow.workspace = true +arrayvec.workspace = true +brotli.workspace = true client.workspace = true +cloud_llm_client.workspace = true +cloud_zeta2_prompt.workspace = true +collections.workspace = true +copilot.workspace = true +credentials_provider.workspace = true +db.workspace = true 
+edit_prediction_types.workspace = true +edit_prediction_context.workspace = true +feature_flags.workspace = true +fs.workspace = true +futures.workspace = true gpui.workspace = true +indoc.workspace = true +itertools.workspace = true language.workspace = true +language_model.workspace = true +log.workspace = true +lsp.workspace = true +menu.workspace = true +open_ai.workspace = true +postage.workspace = true +pretty_assertions.workspace = true +project.workspace = true +rand.workspace = true +regex.workspace = true +release_channel.workspace = true +semver.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +strsim.workspace = true +strum.workspace = true +telemetry.workspace = true +telemetry_events.workspace = true +thiserror.workspace = true +ui.workspace = true +util.workspace = true +uuid.workspace = true +workspace.workspace = true +worktree.workspace = true +zed_actions.workspace = true + +[dev-dependencies] +clock = { workspace = true, features = ["test-support"] } +cloud_api_types.workspace = true +cloud_llm_client = { workspace = true, features = ["test-support"] } +ctor.workspace = true +gpui = { workspace = true, features = ["test-support"] } +indoc.workspace = true +language = { workspace = true, features = ["test-support"] } +language_model = { workspace = true, features = ["test-support"] } +lsp.workspace = true +parking_lot.workspace = true +project = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } +zlog.workspace = true diff --git a/crates/zeta/license_examples/0bsd.txt b/crates/edit_prediction/license_examples/0bsd.txt similarity index 100% rename from crates/zeta/license_examples/0bsd.txt rename to crates/edit_prediction/license_examples/0bsd.txt diff --git a/crates/zeta/license_examples/apache-2.0-ex0.txt b/crates/edit_prediction/license_examples/apache-2.0-ex0.txt similarity index 100% rename from 
crates/zeta/license_examples/apache-2.0-ex0.txt rename to crates/edit_prediction/license_examples/apache-2.0-ex0.txt diff --git a/crates/zeta/license_examples/apache-2.0-ex1.txt b/crates/edit_prediction/license_examples/apache-2.0-ex1.txt similarity index 100% rename from crates/zeta/license_examples/apache-2.0-ex1.txt rename to crates/edit_prediction/license_examples/apache-2.0-ex1.txt diff --git a/crates/zeta/license_examples/apache-2.0-ex2.txt b/crates/edit_prediction/license_examples/apache-2.0-ex2.txt similarity index 100% rename from crates/zeta/license_examples/apache-2.0-ex2.txt rename to crates/edit_prediction/license_examples/apache-2.0-ex2.txt diff --git a/crates/zeta/license_examples/apache-2.0-ex3.txt b/crates/edit_prediction/license_examples/apache-2.0-ex3.txt similarity index 100% rename from crates/zeta/license_examples/apache-2.0-ex3.txt rename to crates/edit_prediction/license_examples/apache-2.0-ex3.txt diff --git a/crates/zeta/license_examples/apache-2.0-ex4.txt b/crates/edit_prediction/license_examples/apache-2.0-ex4.txt similarity index 100% rename from crates/zeta/license_examples/apache-2.0-ex4.txt rename to crates/edit_prediction/license_examples/apache-2.0-ex4.txt diff --git a/crates/zeta/license_examples/bsd-1-clause.txt b/crates/edit_prediction/license_examples/bsd-1-clause.txt similarity index 100% rename from crates/zeta/license_examples/bsd-1-clause.txt rename to crates/edit_prediction/license_examples/bsd-1-clause.txt diff --git a/crates/zeta/license_examples/bsd-2-clause-ex0.txt b/crates/edit_prediction/license_examples/bsd-2-clause-ex0.txt similarity index 100% rename from crates/zeta/license_examples/bsd-2-clause-ex0.txt rename to crates/edit_prediction/license_examples/bsd-2-clause-ex0.txt diff --git a/crates/zeta/license_examples/bsd-3-clause-ex0.txt b/crates/edit_prediction/license_examples/bsd-3-clause-ex0.txt similarity index 100% rename from crates/zeta/license_examples/bsd-3-clause-ex0.txt rename to 
crates/edit_prediction/license_examples/bsd-3-clause-ex0.txt diff --git a/crates/zeta/license_examples/bsd-3-clause-ex1.txt b/crates/edit_prediction/license_examples/bsd-3-clause-ex1.txt similarity index 100% rename from crates/zeta/license_examples/bsd-3-clause-ex1.txt rename to crates/edit_prediction/license_examples/bsd-3-clause-ex1.txt diff --git a/crates/zeta/license_examples/bsd-3-clause-ex2.txt b/crates/edit_prediction/license_examples/bsd-3-clause-ex2.txt similarity index 100% rename from crates/zeta/license_examples/bsd-3-clause-ex2.txt rename to crates/edit_prediction/license_examples/bsd-3-clause-ex2.txt diff --git a/crates/zeta/license_examples/bsd-3-clause-ex3.txt b/crates/edit_prediction/license_examples/bsd-3-clause-ex3.txt similarity index 100% rename from crates/zeta/license_examples/bsd-3-clause-ex3.txt rename to crates/edit_prediction/license_examples/bsd-3-clause-ex3.txt diff --git a/crates/zeta/license_examples/bsd-3-clause-ex4.txt b/crates/edit_prediction/license_examples/bsd-3-clause-ex4.txt similarity index 100% rename from crates/zeta/license_examples/bsd-3-clause-ex4.txt rename to crates/edit_prediction/license_examples/bsd-3-clause-ex4.txt diff --git a/crates/zeta/license_examples/isc.txt b/crates/edit_prediction/license_examples/isc.txt similarity index 100% rename from crates/zeta/license_examples/isc.txt rename to crates/edit_prediction/license_examples/isc.txt diff --git a/crates/zeta/license_examples/mit-ex0.txt b/crates/edit_prediction/license_examples/mit-ex0.txt similarity index 100% rename from crates/zeta/license_examples/mit-ex0.txt rename to crates/edit_prediction/license_examples/mit-ex0.txt diff --git a/crates/zeta/license_examples/mit-ex1.txt b/crates/edit_prediction/license_examples/mit-ex1.txt similarity index 100% rename from crates/zeta/license_examples/mit-ex1.txt rename to crates/edit_prediction/license_examples/mit-ex1.txt diff --git a/crates/zeta/license_examples/mit-ex2.txt 
b/crates/edit_prediction/license_examples/mit-ex2.txt similarity index 100% rename from crates/zeta/license_examples/mit-ex2.txt rename to crates/edit_prediction/license_examples/mit-ex2.txt diff --git a/crates/zeta/license_examples/mit-ex3.txt b/crates/edit_prediction/license_examples/mit-ex3.txt similarity index 100% rename from crates/zeta/license_examples/mit-ex3.txt rename to crates/edit_prediction/license_examples/mit-ex3.txt diff --git a/crates/zeta/license_examples/upl-1.0.txt b/crates/edit_prediction/license_examples/upl-1.0.txt similarity index 100% rename from crates/zeta/license_examples/upl-1.0.txt rename to crates/edit_prediction/license_examples/upl-1.0.txt diff --git a/crates/zeta/license_examples/zlib-ex0.txt b/crates/edit_prediction/license_examples/zlib-ex0.txt similarity index 100% rename from crates/zeta/license_examples/zlib-ex0.txt rename to crates/edit_prediction/license_examples/zlib-ex0.txt diff --git a/crates/zeta/license_patterns/0bsd-pattern b/crates/edit_prediction/license_patterns/0bsd-pattern similarity index 100% rename from crates/zeta/license_patterns/0bsd-pattern rename to crates/edit_prediction/license_patterns/0bsd-pattern diff --git a/crates/zeta/license_patterns/apache-2.0-pattern b/crates/edit_prediction/license_patterns/apache-2.0-pattern similarity index 100% rename from crates/zeta/license_patterns/apache-2.0-pattern rename to crates/edit_prediction/license_patterns/apache-2.0-pattern diff --git a/crates/zeta/license_patterns/apache-2.0-reference-pattern b/crates/edit_prediction/license_patterns/apache-2.0-reference-pattern similarity index 100% rename from crates/zeta/license_patterns/apache-2.0-reference-pattern rename to crates/edit_prediction/license_patterns/apache-2.0-reference-pattern diff --git a/crates/zeta/license_patterns/bsd-pattern b/crates/edit_prediction/license_patterns/bsd-pattern similarity index 100% rename from crates/zeta/license_patterns/bsd-pattern rename to 
crates/edit_prediction/license_patterns/bsd-pattern diff --git a/crates/zeta/license_patterns/isc-pattern b/crates/edit_prediction/license_patterns/isc-pattern similarity index 100% rename from crates/zeta/license_patterns/isc-pattern rename to crates/edit_prediction/license_patterns/isc-pattern diff --git a/crates/zeta/license_patterns/mit-pattern b/crates/edit_prediction/license_patterns/mit-pattern similarity index 100% rename from crates/zeta/license_patterns/mit-pattern rename to crates/edit_prediction/license_patterns/mit-pattern diff --git a/crates/zeta/license_patterns/upl-1.0-pattern b/crates/edit_prediction/license_patterns/upl-1.0-pattern similarity index 100% rename from crates/zeta/license_patterns/upl-1.0-pattern rename to crates/edit_prediction/license_patterns/upl-1.0-pattern diff --git a/crates/zeta/license_patterns/zlib-pattern b/crates/edit_prediction/license_patterns/zlib-pattern similarity index 100% rename from crates/zeta/license_patterns/zlib-pattern rename to crates/edit_prediction/license_patterns/zlib-pattern diff --git a/crates/edit_prediction/src/cursor_excerpt.rs b/crates/edit_prediction/src/cursor_excerpt.rs new file mode 100644 index 0000000000000000000000000000000000000000..1f2f1d32ebcb2eaa151433bd49d275e0e2a3b817 --- /dev/null +++ b/crates/edit_prediction/src/cursor_excerpt.rs @@ -0,0 +1,78 @@ +use language::{BufferSnapshot, Point}; +use std::ops::Range; + +pub fn editable_and_context_ranges_for_cursor_position( + position: Point, + snapshot: &BufferSnapshot, + editable_region_token_limit: usize, + context_token_limit: usize, +) -> (Range, Range) { + let mut scope_range = position..position; + let mut remaining_edit_tokens = editable_region_token_limit; + + while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) { + let parent_tokens = guess_token_count(parent.byte_range().len()); + let parent_point_range = Point::new( + parent.start_position().row as u32, + parent.start_position().column as u32, + ) + ..Point::new( 
+ parent.end_position().row as u32, + parent.end_position().column as u32, + ); + if parent_point_range == scope_range { + break; + } else if parent_tokens <= editable_region_token_limit { + scope_range = parent_point_range; + remaining_edit_tokens = editable_region_token_limit - parent_tokens; + } else { + break; + } + } + + let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens); + let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit); + (editable_range, context_range) +} + +fn expand_range( + snapshot: &BufferSnapshot, + range: Range, + mut remaining_tokens: usize, +) -> Range { + let mut expanded_range = range; + expanded_range.start.column = 0; + expanded_range.end.column = snapshot.line_len(expanded_range.end.row); + loop { + let mut expanded = false; + + if remaining_tokens > 0 && expanded_range.start.row > 0 { + expanded_range.start.row -= 1; + let line_tokens = + guess_token_count(snapshot.line_len(expanded_range.start.row) as usize); + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + + if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row { + expanded_range.end.row += 1; + expanded_range.end.column = snapshot.line_len(expanded_range.end.row); + let line_tokens = guess_token_count(expanded_range.end.column as usize); + remaining_tokens = remaining_tokens.saturating_sub(line_tokens); + expanded = true; + } + + if !expanded { + break; + } + } + expanded_range +} + +/// Typical number of string bytes per token for the purposes of limiting model input. This is +/// intentionally low to err on the side of underestimating limits. 
+pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3; + +pub fn guess_token_count(bytes: usize) -> usize { + bytes / BYTES_PER_TOKEN_GUESS +} diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 1984383a9691ae9373973a3eb9f00db4e7e795f2..141fff3063b83d7e0003fddd6b4eba2d213d5fd5 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,298 +1,1944 @@ -use std::{ops::Range, sync::Arc}; +use anyhow::Result; +use arrayvec::ArrayVec; +use client::{Client, EditPredictionUsage, UserStore}; +use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat}; +use cloud_llm_client::{ + AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, + EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, + MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, + ZED_VERSION_HEADER_NAME, +}; +use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES; +use collections::{HashMap, HashSet}; +use db::kvp::{Dismissable, KEY_VALUE_STORE}; +use edit_prediction_context::EditPredictionExcerptOptions; +use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile}; +use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; +use futures::{ + AsyncReadExt as _, FutureExt as _, StreamExt as _, + channel::{ + mpsc::{self, UnboundedReceiver}, + oneshot, + }, + select_biased, +}; +use gpui::BackgroundExecutor; +use gpui::{ + App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions, + http_client::{self, AsyncBody, Method}, + prelude::*, +}; +use language::language_settings::all_language_settings; +use language::{Anchor, Buffer, File, Point, ToPoint}; +use language::{BufferSnapshot, OffsetRangeExt}; +use language_model::{LlmApiToken, RefreshLlmTokenListener}; +use project::{Project, ProjectPath, WorktreeId}; +use release_channel::AppVersion; +use 
semver::Version; +use serde::de::DeserializeOwned; +use settings::{EditPredictionProvider, SettingsStore, update_settings_file}; +use std::collections::{VecDeque, hash_map}; +use workspace::Workspace; + +use std::ops::Range; +use std::path::Path; +use std::rc::Rc; +use std::str::FromStr as _; +use std::sync::{Arc, LazyLock}; +use std::time::{Duration, Instant}; +use std::{env, mem}; +use thiserror::Error; +use util::{RangeExt as _, ResultExt as _}; +use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; + +mod cursor_excerpt; +mod license_detection; +pub mod mercury; +mod onboarding_modal; +pub mod open_ai_response; +mod prediction; +pub mod sweep_ai; +pub mod udiff; +mod xml_edits; +mod zed_edit_prediction_delegate; +pub mod zeta1; +pub mod zeta2; + +#[cfg(test)] +mod edit_prediction_tests; + +use crate::license_detection::LicenseDetectionWatcher; +use crate::mercury::Mercury; +use crate::onboarding_modal::ZedPredictModal; +pub use crate::prediction::EditPrediction; +pub use crate::prediction::EditPredictionId; +pub use crate::prediction::EditPredictionInputs; +use crate::prediction::EditPredictionResult; +pub use crate::sweep_ai::SweepAi; +pub use telemetry_events::EditPredictionRating; +pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate; + +actions!( + edit_prediction, + [ + /// Resets the edit prediction onboarding state. + ResetOnboarding, + /// Clears the edit prediction history. + ClearHistory, + ] +); + +/// Maximum number of events to track. +const EVENT_COUNT_MAX: usize = 6; +const CHANGE_GROUPING_LINE_SPAN: u32 = 8; +const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice"; +const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15); -use client::EditPredictionUsage; -use gpui::{App, Context, Entity, SharedString}; -use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt}; +pub struct SweepFeatureFlag; -// TODO: Find a better home for `Direction`. 
-// -// This should live in an ancestor crate of `editor` and `edit_prediction`, -// but at time of writing there isn't an obvious spot. -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum Direction { - Prev, - Next, +impl FeatureFlag for SweepFeatureFlag { + const NAME: &str = "sweep-ai"; } -#[derive(Clone)] -pub enum EditPrediction { - /// Edits within the buffer that requested the prediction - Local { - id: Option, - edits: Vec<(Range, Arc)>, - edit_preview: Option, - }, - /// Jump to a different file from the one that requested the prediction - Jump { - id: Option, - snapshot: language::BufferSnapshot, - target: language::Anchor, +pub struct MercuryFeatureFlag; + +impl FeatureFlag for MercuryFeatureFlag { + const NAME: &str = "mercury"; +} + +pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { + context: EditPredictionExcerptOptions { + max_bytes: 512, + min_bytes: 128, + target_before_cursor_over_total_bytes: 0.5, }, + max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES, + prompt_format: PromptFormat::DEFAULT, +}; + +static USE_OLLAMA: LazyLock = + LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty())); + +static EDIT_PREDICTIONS_MODEL_ID: LazyLock = LazyLock::new(|| { + match env::var("ZED_ZETA2_MODEL").as_deref() { + Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten + Ok(model) => model, + Err(_) if *USE_OLLAMA => "qwen3-coder:30b", + Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten + } + .to_string() +}); +static PREDICT_EDITS_URL: LazyLock> = LazyLock::new(|| { + env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| { + if *USE_OLLAMA { + Some("http://localhost:11434/v1/chat/completions".into()) + } else { + None + } + }) +}); + +pub struct Zeta2FeatureFlag; + +impl FeatureFlag for Zeta2FeatureFlag { + const NAME: &'static str = "zeta2"; + + fn enabled_for_staff() -> bool { + true + } +} + +#[derive(Clone)] +struct EditPredictionStoreGlobal(Entity); + +impl Global for EditPredictionStoreGlobal {} + +pub struct 
EditPredictionStore { + client: Arc, + user_store: Entity, + llm_token: LlmApiToken, + _llm_token_subscription: Subscription, + projects: HashMap, + use_context: bool, + options: ZetaOptions, + update_required: bool, + debug_tx: Option>, + #[cfg(feature = "eval-support")] + eval_cache: Option>, + edit_prediction_model: EditPredictionModel, + pub sweep_ai: SweepAi, + pub mercury: Mercury, + data_collection_choice: DataCollectionChoice, + reject_predictions_tx: mpsc::UnboundedSender, + shown_predictions: VecDeque, + rated_predictions: HashSet, +} + +#[derive(Copy, Clone, Default, PartialEq, Eq)] +pub enum EditPredictionModel { + #[default] + Zeta1, + Zeta2, + Sweep, + Mercury, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ZetaOptions { + pub context: EditPredictionExcerptOptions, + pub max_prompt_bytes: usize, + pub prompt_format: predict_edits_v3::PromptFormat, +} + +#[derive(Debug)] +pub enum DebugEvent { + ContextRetrievalStarted(ContextRetrievalStartedDebugEvent), + ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent), + EditPredictionRequested(EditPredictionRequestedDebugEvent), +} + +#[derive(Debug)] +pub struct ContextRetrievalStartedDebugEvent { + pub project_entity_id: EntityId, + pub timestamp: Instant, + pub search_prompt: String, +} + +#[derive(Debug)] +pub struct ContextRetrievalFinishedDebugEvent { + pub project_entity_id: EntityId, + pub timestamp: Instant, + pub metadata: Vec<(&'static str, SharedString)>, +} + +#[derive(Debug)] +pub struct EditPredictionRequestedDebugEvent { + pub inputs: EditPredictionInputs, + pub retrieval_time: Duration, + pub buffer: WeakEntity, + pub position: Anchor, + pub local_prompt: Result, + pub response_rx: oneshot::Receiver<(Result, Duration)>, +} + +pub type RequestDebugInfo = predict_edits_v3::DebugInfo; + +struct ProjectState { + events: VecDeque>, + last_event: Option, + recent_paths: VecDeque, + registered_buffers: HashMap, + current_prediction: Option, + next_pending_prediction_id: usize, + 
pending_predictions: ArrayVec, + context_updates_tx: smol::channel::Sender<()>, + context_updates_rx: smol::channel::Receiver<()>, + last_prediction_refresh: Option<(EntityId, Instant)>, + cancelled_predictions: HashSet, + context: Entity, + license_detection_watchers: HashMap>, + _subscription: gpui::Subscription, +} + +impl ProjectState { + pub fn events(&self, cx: &App) -> Vec> { + self.events + .iter() + .cloned() + .chain( + self.last_event + .as_ref() + .and_then(|event| event.finalize(&self.license_detection_watchers, cx)), + ) + .collect() + } + + fn cancel_pending_prediction( + &mut self, + pending_prediction: PendingPrediction, + cx: &mut Context, + ) { + self.cancelled_predictions.insert(pending_prediction.id); + + cx.spawn(async move |this, cx| { + let Some(prediction_id) = pending_prediction.task.await else { + return; + }; + + this.update(cx, |this, _cx| { + this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false); + }) + .ok(); + }) + .detach() + } } -pub enum DataCollectionState { - /// The provider doesn't support data collection. - Unsupported, - /// Data collection is enabled. - Enabled { is_project_open_source: bool }, - /// Data collection is disabled or unanswered. 
- Disabled { is_project_open_source: bool }, +#[derive(Debug, Clone)] +struct CurrentEditPrediction { + pub requested_by: PredictionRequestedBy, + pub prediction: EditPrediction, + pub was_shown: bool, } -impl DataCollectionState { - pub fn is_supported(&self) -> bool { - !matches!(self, DataCollectionState::Unsupported) +impl CurrentEditPrediction { + fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool { + let Some(new_edits) = self + .prediction + .interpolate(&self.prediction.buffer.read(cx)) + else { + return false; + }; + + if self.prediction.buffer != old_prediction.prediction.buffer { + return true; + } + + let Some(old_edits) = old_prediction + .prediction + .interpolate(&old_prediction.prediction.buffer.read(cx)) + else { + return true; + }; + + let requested_by_buffer_id = self.requested_by.buffer_id(); + + // This reduces the occurrence of UI thrash from replacing edits + // + // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits. + if requested_by_buffer_id == Some(self.prediction.buffer.entity_id()) + && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id()) + && old_edits.len() == 1 + && new_edits.len() == 1 + { + let (old_range, old_text) = &old_edits[0]; + let (new_range, new_text) = &new_edits[0]; + new_range == old_range && new_text.starts_with(old_text.as_ref()) + } else { + true + } } +} - pub fn is_enabled(&self) -> bool { - matches!(self, DataCollectionState::Enabled { .. }) +#[derive(Debug, Clone)] +enum PredictionRequestedBy { + DiagnosticsUpdate, + Buffer(EntityId), +} + +impl PredictionRequestedBy { + pub fn buffer_id(&self) -> Option { + match self { + PredictionRequestedBy::DiagnosticsUpdate => None, + PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id), + } } +} + +#[derive(Debug)] +struct PendingPrediction { + id: usize, + task: Task>, +} + +/// A prediction from the perspective of a buffer. 
+#[derive(Debug)] +enum BufferEditPrediction<'a> { + Local { prediction: &'a EditPrediction }, + Jump { prediction: &'a EditPrediction }, +} + +#[cfg(test)] +impl std::ops::Deref for BufferEditPrediction<'_> { + type Target = EditPrediction; - pub fn is_project_open_source(&self) -> bool { + fn deref(&self) -> &Self::Target { match self { - Self::Enabled { - is_project_open_source, - } - | Self::Disabled { - is_project_open_source, - } => *is_project_open_source, - _ => false, + BufferEditPrediction::Local { prediction } => prediction, + BufferEditPrediction::Jump { prediction } => prediction, } } } -pub trait EditPredictionProvider: 'static + Sized { - fn name() -> &'static str; - fn display_name() -> &'static str; - fn show_completions_in_menu() -> bool; - fn show_tab_accept_marker() -> bool { - false +struct RegisteredBuffer { + snapshot: BufferSnapshot, + _subscriptions: [gpui::Subscription; 2], +} + +struct LastEvent { + old_snapshot: BufferSnapshot, + new_snapshot: BufferSnapshot, + end_edit_anchor: Option, +} + +impl LastEvent { + pub fn finalize( + &self, + license_detection_watchers: &HashMap>, + cx: &App, + ) -> Option> { + let path = buffer_path_with_id_fallback(&self.new_snapshot, cx); + let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx); + + let file = self.new_snapshot.file(); + let old_file = self.old_snapshot.file(); + + let in_open_source_repo = [file, old_file].iter().all(|file| { + file.is_some_and(|file| { + license_detection_watchers + .get(&file.worktree_id(cx)) + .is_some_and(|watcher| watcher.is_project_open_source()) + }) + }); + + let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text()); + + if path == old_path && diff.is_empty() { + None + } else { + Some(Arc::new(predict_edits_v3::Event::BufferChange { + old_path, + path, + diff, + in_open_source_repo, + // TODO: Actually detect if this edit was predicted or not + predicted: false, + })) + } } - fn supports_jump_to_edit() -> bool { - true 
+} + +fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc { + if let Some(file) = snapshot.file() { + file.full_path(cx).into() + } else { + Path::new(&format!("untitled-{}", snapshot.remote_id())).into() } +} - fn data_collection_state(&self, _cx: &App) -> DataCollectionState { - DataCollectionState::Unsupported +impl EditPredictionStore { + pub fn try_global(cx: &App) -> Option> { + cx.try_global::() + .map(|global| global.0.clone()) } - fn usage(&self, _cx: &App) -> Option { - None + pub fn global( + client: &Arc, + user_store: &Entity, + cx: &mut App, + ) -> Entity { + cx.try_global::() + .map(|global| global.0.clone()) + .unwrap_or_else(|| { + let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx)); + cx.set_global(EditPredictionStoreGlobal(ep_store.clone())); + ep_store + }) } - fn toggle_data_collection(&mut self, _cx: &mut App) {} - fn is_enabled( - &self, + pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { + let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); + let data_collection_choice = Self::load_data_collection_choice(); + + let llm_token = LlmApiToken::default(); + + let (reject_tx, reject_rx) = mpsc::unbounded(); + cx.background_spawn({ + let client = client.clone(); + let llm_token = llm_token.clone(); + let app_version = AppVersion::global(cx); + let background_executor = cx.background_executor().clone(); + async move { + Self::handle_rejected_predictions( + reject_rx, + client, + llm_token, + app_version, + background_executor, + ) + .await + } + }) + .detach(); + + let mut this = Self { + projects: HashMap::default(), + client, + user_store, + options: DEFAULT_OPTIONS, + use_context: false, + llm_token, + _llm_token_subscription: cx.subscribe( + &refresh_llm_token_listener, + |this, _listener, _event, cx| { + let client = this.client.clone(); + let llm_token = this.llm_token.clone(); + cx.spawn(async move |_this, _cx| { + llm_token.refresh(&client).await?; + 
anyhow::Ok(()) + }) + .detach_and_log_err(cx); + }, + ), + update_required: false, + debug_tx: None, + #[cfg(feature = "eval-support")] + eval_cache: None, + edit_prediction_model: EditPredictionModel::Zeta2, + sweep_ai: SweepAi::new(cx), + mercury: Mercury::new(cx), + data_collection_choice, + reject_predictions_tx: reject_tx, + rated_predictions: Default::default(), + shown_predictions: Default::default(), + }; + + this.configure_context_retrieval(cx); + let weak_this = cx.weak_entity(); + cx.on_flags_ready(move |_, cx| { + weak_this + .update(cx, |this, cx| this.configure_context_retrieval(cx)) + .ok(); + }) + .detach(); + cx.observe_global::(|this, cx| { + this.configure_context_retrieval(cx); + }) + .detach(); + + this + } + + pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) { + self.edit_prediction_model = model; + } + + pub fn has_sweep_api_token(&self) -> bool { + self.sweep_ai + .api_token + .clone() + .now_or_never() + .flatten() + .is_some() + } + + pub fn has_mercury_api_token(&self) -> bool { + self.mercury + .api_token + .clone() + .now_or_never() + .flatten() + .is_some() + } + + #[cfg(feature = "eval-support")] + pub fn with_eval_cache(&mut self, cache: Arc) { + self.eval_cache = Some(cache); + } + + pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver { + let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded(); + self.debug_tx = Some(debug_watch_tx); + debug_watch_rx + } + + pub fn options(&self) -> &ZetaOptions { + &self.options + } + + pub fn set_options(&mut self, options: ZetaOptions) { + self.options = options; + } + + pub fn set_use_context(&mut self, use_context: bool) { + self.use_context = use_context; + } + + pub fn clear_history(&mut self) { + for project_state in self.projects.values_mut() { + project_state.events.clear(); + } + } + + pub fn context_for_project<'a>( + &'a self, + project: &Entity, + cx: &'a App, + ) -> &'a [RelatedFile] { + self.projects + .get(&project.entity_id()) + .map(|project| 
project.context.read(cx).related_files()) + .unwrap_or(&[]) + } + + pub fn usage(&self, cx: &App) -> Option { + if self.edit_prediction_model == EditPredictionModel::Zeta2 { + self.user_store.read(cx).edit_prediction_usage() + } else { + None + } + } + + pub fn register_project(&mut self, project: &Entity, cx: &mut Context) { + self.get_or_init_project(project, cx); + } + + pub fn register_buffer( + &mut self, buffer: &Entity, - cursor_position: language::Anchor, - cx: &App, - ) -> bool; - fn is_refreshing(&self, cx: &App) -> bool; - fn refresh( + project: &Entity, + cx: &mut Context, + ) { + let project_state = self.get_or_init_project(project, cx); + Self::register_buffer_impl(project_state, buffer, project, cx); + } + + fn get_or_init_project( &mut self, - buffer: Entity, - cursor_position: language::Anchor, - debounce: bool, + project: &Entity, cx: &mut Context, - ); - fn cycle( + ) -> &mut ProjectState { + let entity_id = project.entity_id(); + let (context_updates_tx, context_updates_rx) = smol::channel::unbounded(); + self.projects + .entry(entity_id) + .or_insert_with(|| ProjectState { + context: { + let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx)); + cx.subscribe( + &related_excerpt_store, + move |this, _, event, _| match event { + RelatedExcerptStoreEvent::StartedRefresh => { + if let Some(debug_tx) = this.debug_tx.clone() { + debug_tx + .unbounded_send(DebugEvent::ContextRetrievalStarted( + ContextRetrievalStartedDebugEvent { + project_entity_id: entity_id, + timestamp: Instant::now(), + search_prompt: String::new(), + }, + )) + .ok(); + } + } + RelatedExcerptStoreEvent::FinishedRefresh { + cache_hit_count, + cache_miss_count, + mean_definition_latency, + max_definition_latency, + } => { + if let Some(debug_tx) = this.debug_tx.clone() { + debug_tx + .unbounded_send(DebugEvent::ContextRetrievalFinished( + ContextRetrievalFinishedDebugEvent { + project_entity_id: entity_id, + timestamp: Instant::now(), + metadata: vec![ + ( + 
"Cache Hits", + format!( + "{}/{}", + cache_hit_count, + cache_hit_count + cache_miss_count + ) + .into(), + ), + ( + "Max LSP Time", + format!( + "{} ms", + max_definition_latency.as_millis() + ) + .into(), + ), + ( + "Mean LSP Time", + format!( + "{} ms", + mean_definition_latency.as_millis() + ) + .into(), + ), + ], + }, + )) + .ok(); + } + if let Some(project_state) = this.projects.get(&entity_id) { + project_state.context_updates_tx.send_blocking(()).ok(); + } + } + }, + ) + .detach(); + related_excerpt_store + }, + events: VecDeque::new(), + last_event: None, + recent_paths: VecDeque::new(), + context_updates_rx, + context_updates_tx, + registered_buffers: HashMap::default(), + current_prediction: None, + cancelled_predictions: HashSet::default(), + pending_predictions: ArrayVec::new(), + next_pending_prediction_id: 0, + last_prediction_refresh: None, + license_detection_watchers: HashMap::default(), + _subscription: cx.subscribe(&project, Self::handle_project_event), + }) + } + + pub fn project_context_updates( + &self, + project: &Entity, + ) -> Option> { + let project_state = self.projects.get(&project.entity_id())?; + Some(project_state.context_updates_rx.clone()) + } + + fn handle_project_event( &mut self, - buffer: Entity, - cursor_position: language::Anchor, - direction: Direction, + project: Entity, + event: &project::Event, cx: &mut Context, - ); - fn accept(&mut self, cx: &mut Context); - fn discard(&mut self, cx: &mut Context); - fn did_show(&mut self, _cx: &mut Context) {} - fn suggest( + ) { + // TODO [zeta2] init with recent paths + match event { + project::Event::ActiveEntryChanged(Some(active_entry_id)) => { + let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + let path = project.read(cx).path_for_entry(*active_entry_id, cx); + if let Some(path) = path { + if let Some(ix) = project_state + .recent_paths + .iter() + .position(|probe| probe == &path) + { + project_state.recent_paths.remove(ix); + } + 
project_state.recent_paths.push_front(path); + } + } + project::Event::DiagnosticsUpdated { .. } => { + if cx.has_flag::() { + self.refresh_prediction_from_diagnostics(project, cx); + } + } + _ => (), + } + } + + fn register_buffer_impl<'a>( + project_state: &'a mut ProjectState, + buffer: &Entity, + project: &Entity, + cx: &mut Context, + ) -> &'a mut RegisteredBuffer { + let buffer_id = buffer.entity_id(); + + if let Some(file) = buffer.read(cx).file() { + let worktree_id = file.worktree_id(cx); + if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) { + project_state + .license_detection_watchers + .entry(worktree_id) + .or_insert_with(|| { + let project_entity_id = project.entity_id(); + cx.observe_release(&worktree, move |this, _worktree, _cx| { + let Some(project_state) = this.projects.get_mut(&project_entity_id) + else { + return; + }; + project_state + .license_detection_watchers + .remove(&worktree_id); + }) + .detach(); + Rc::new(LicenseDetectionWatcher::new(&worktree, cx)) + }); + } + } + + match project_state.registered_buffers.entry(buffer_id) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let snapshot = buffer.read(cx).snapshot(); + let project_entity_id = project.entity_id(); + entry.insert(RegisteredBuffer { + snapshot, + _subscriptions: [ + cx.subscribe(buffer, { + let project = project.downgrade(); + move |this, buffer, event, cx| { + if let language::BufferEvent::Edited = event + && let Some(project) = project.upgrade() + { + this.report_changes_for_buffer(&buffer, &project, cx); + } + } + }), + cx.observe_release(buffer, move |this, _buffer, _cx| { + let Some(project_state) = this.projects.get_mut(&project_entity_id) + else { + return; + }; + project_state.registered_buffers.remove(&buffer_id); + }), + ], + }) + } + } + } + + fn report_changes_for_buffer( &mut self, buffer: &Entity, - cursor_position: language::Anchor, + project: &Entity, cx: &mut Context, - ) -> Option; -} + 
) { + let project_state = self.get_or_init_project(project, cx); + let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx); + + let new_snapshot = buffer.read(cx).snapshot(); + if new_snapshot.version == registered_buffer.snapshot.version { + return; + } + + let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone()); + let end_edit_anchor = new_snapshot + .anchored_edits_since::(&old_snapshot.version) + .last() + .map(|(_, range)| range.end); + let events = &mut project_state.events; + + if let Some(LastEvent { + new_snapshot: last_new_snapshot, + end_edit_anchor: last_end_edit_anchor, + .. + }) = project_state.last_event.as_mut() + { + let is_next_snapshot_of_same_buffer = old_snapshot.remote_id() + == last_new_snapshot.remote_id() + && old_snapshot.version == last_new_snapshot.version; + + let should_coalesce = is_next_snapshot_of_same_buffer + && end_edit_anchor + .as_ref() + .zip(last_end_edit_anchor.as_ref()) + .is_some_and(|(a, b)| { + let a = a.to_point(&new_snapshot); + let b = b.to_point(&new_snapshot); + a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN + }); -pub trait EditPredictionProviderHandle { - fn name(&self) -> &'static str; - fn display_name(&self) -> &'static str; - fn is_enabled( + if should_coalesce { + *last_end_edit_anchor = end_edit_anchor; + *last_new_snapshot = new_snapshot; + return; + } + } + + if events.len() + 1 >= EVENT_COUNT_MAX { + events.pop_front(); + } + + if let Some(event) = project_state.last_event.take() { + events.extend(event.finalize(&project_state.license_detection_watchers, cx)); + } + + project_state.last_event = Some(LastEvent { + old_snapshot, + new_snapshot, + end_edit_anchor, + }); + } + + fn current_prediction_for_buffer( &self, buffer: &Entity, - cursor_position: language::Anchor, + project: &Entity, cx: &App, - ) -> bool; - fn show_completions_in_menu(&self) -> bool; - fn show_tab_accept_marker(&self) -> bool; - fn supports_jump_to_edit(&self) -> bool; 
- fn data_collection_state(&self, cx: &App) -> DataCollectionState; - fn usage(&self, cx: &App) -> Option; - fn toggle_data_collection(&self, cx: &mut App); - fn is_refreshing(&self, cx: &App) -> bool; - fn refresh( - &self, - buffer: Entity, - cursor_position: language::Anchor, - debounce: bool, - cx: &mut App, - ); - fn cycle( - &self, - buffer: Entity, - cursor_position: language::Anchor, - direction: Direction, - cx: &mut App, - ); - fn did_show(&self, cx: &mut App); - fn accept(&self, cx: &mut App); - fn discard(&self, cx: &mut App); - fn suggest( - &self, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &mut App, - ) -> Option; -} + ) -> Option> { + let project_state = self.projects.get(&project.entity_id())?; -impl EditPredictionProviderHandle for Entity -where - T: EditPredictionProvider, -{ - fn name(&self) -> &'static str { - T::name() - } + let CurrentEditPrediction { + requested_by, + prediction, + .. + } = project_state.current_prediction.as_ref()?; - fn display_name(&self) -> &'static str { - T::display_name() - } + if prediction.targets_buffer(buffer.read(cx)) { + Some(BufferEditPrediction::Local { prediction }) + } else { + let show_jump = match requested_by { + PredictionRequestedBy::Buffer(requested_by_buffer_id) => { + requested_by_buffer_id == &buffer.entity_id() + } + PredictionRequestedBy::DiagnosticsUpdate => true, + }; - fn show_completions_in_menu(&self) -> bool { - T::show_completions_in_menu() + if show_jump { + Some(BufferEditPrediction::Jump { prediction }) + } else { + None + } + } } - fn show_tab_accept_marker(&self) -> bool { - T::show_tab_accept_marker() - } + fn accept_current_prediction(&mut self, project: &Entity, cx: &mut Context) { + match self.edit_prediction_model { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {} + EditPredictionModel::Sweep | EditPredictionModel::Mercury => return, + } + + let Some(project_state) = self.projects.get_mut(&project.entity_id()) else { + return; + }; - fn 
supports_jump_to_edit(&self) -> bool { - T::supports_jump_to_edit() + let Some(prediction) = project_state.current_prediction.take() else { + return; + }; + let request_id = prediction.prediction.id.to_string(); + for pending_prediction in mem::take(&mut project_state.pending_predictions) { + project_state.cancel_pending_prediction(pending_prediction, cx); + } + + let client = self.client.clone(); + let llm_token = self.llm_token.clone(); + let app_version = AppVersion::global(cx); + cx.spawn(async move |this, cx| { + let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") { + http_client::Url::parse(&predict_edits_url)? + } else { + client + .http_client() + .build_zed_llm_url("/predict_edits/accept", &[])? + }; + + let response = cx + .background_spawn(Self::send_api_request::<()>( + move |builder| { + let req = builder.uri(url.as_ref()).body( + serde_json::to_string(&AcceptEditPredictionBody { + request_id: request_id.clone(), + })? + .into(), + ); + Ok(req?) + }, + client, + llm_token, + app_version, + )) + .await; + + Self::handle_api_response(&this, response, cx)?; + anyhow::Ok(()) + }) + .detach_and_log_err(cx); } - fn data_collection_state(&self, cx: &App) -> DataCollectionState { - self.read(cx).data_collection_state(cx) + async fn handle_rejected_predictions( + rx: UnboundedReceiver, + client: Arc, + llm_token: LlmApiToken, + app_version: Version, + background_executor: BackgroundExecutor, + ) { + let mut rx = std::pin::pin!(rx.peekable()); + let mut batched = Vec::new(); + + while let Some(rejection) = rx.next().await { + batched.push(rejection); + + if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 { + select_biased! 
{ + next = rx.as_mut().peek().fuse() => { + if next.is_some() { + continue; + } + } + () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {}, + } + } + + let url = client + .http_client() + .build_zed_llm_url("/predict_edits/reject", &[]) + .unwrap(); + + let flush_count = batched + .len() + // in case items have accumulated after failure + .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST); + let start = batched.len() - flush_count; + + let body = RejectEditPredictionsBodyRef { + rejections: &batched[start..], + }; + + let result = Self::send_api_request::<()>( + |builder| { + let req = builder + .uri(url.as_ref()) + .body(serde_json::to_string(&body)?.into()); + anyhow::Ok(req?) + }, + client.clone(), + llm_token.clone(), + app_version.clone(), + ) + .await; + + if result.log_err().is_some() { + batched.drain(start..); + } + } } - fn usage(&self, cx: &App) -> Option { - self.read(cx).usage(cx) + fn reject_current_prediction( + &mut self, + reason: EditPredictionRejectReason, + project: &Entity, + ) { + if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { + project_state.pending_predictions.clear(); + if let Some(prediction) = project_state.current_prediction.take() { + self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown); + } + }; } - fn toggle_data_collection(&self, cx: &mut App) { - self.update(cx, |this, cx| this.toggle_data_collection(cx)) + fn did_show_current_prediction(&mut self, project: &Entity, _cx: &mut Context) { + if let Some(project_state) = self.projects.get_mut(&project.entity_id()) { + if let Some(current_prediction) = project_state.current_prediction.as_mut() { + if !current_prediction.was_shown { + current_prediction.was_shown = true; + self.shown_predictions + .push_front(current_prediction.prediction.clone()); + if self.shown_predictions.len() > 50 { + let completion = self.shown_predictions.pop_back().unwrap(); + self.rated_predictions.remove(&completion.id); + } + } + } + } } - fn 
is_enabled( - &self, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &App, - ) -> bool { - self.read(cx).is_enabled(buffer, cursor_position, cx) + fn reject_prediction( + &mut self, + prediction_id: EditPredictionId, + reason: EditPredictionRejectReason, + was_shown: bool, + ) { + match self.edit_prediction_model { + EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {} + EditPredictionModel::Sweep | EditPredictionModel::Mercury => return, + } + + self.reject_predictions_tx + .unbounded_send(EditPredictionRejection { + request_id: prediction_id.to_string(), + reason, + was_shown, + }) + .log_err(); } - fn is_refreshing(&self, cx: &App) -> bool { - self.read(cx).is_refreshing(cx) + fn is_refreshing(&self, project: &Entity) -> bool { + self.projects + .get(&project.entity_id()) + .is_some_and(|project_state| !project_state.pending_predictions.is_empty()) } - fn refresh( - &self, + pub fn refresh_prediction_from_buffer( + &mut self, + project: Entity, buffer: Entity, - cursor_position: language::Anchor, - debounce: bool, - cx: &mut App, + position: language::Anchor, + cx: &mut Context, ) { - self.update(cx, |this, cx| { - this.refresh(buffer, cursor_position, debounce, cx) + self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| { + let Some(request_task) = this + .update(cx, |this, cx| { + this.request_prediction( + &project, + &buffer, + position, + PredictEditsRequestTrigger::Other, + cx, + ) + }) + .log_err() + else { + return Task::ready(anyhow::Ok(None)); + }; + + cx.spawn(async move |_cx| { + request_task.await.map(|prediction_result| { + prediction_result.map(|prediction_result| { + ( + prediction_result, + PredictionRequestedBy::Buffer(buffer.entity_id()), + ) + }) + }) + }) }) } - fn cycle( - &self, - buffer: Entity, - cursor_position: language::Anchor, - direction: Direction, - cx: &mut App, + pub fn refresh_prediction_from_diagnostics( + &mut self, + project: Entity, + cx: &mut Context, + ) { + let 
Some(project_state) = self.projects.get_mut(&project.entity_id()) else { + return; + }; + + // Prefer predictions from buffer + if project_state.current_prediction.is_some() { + return; + }; + + self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| { + let Some(open_buffer_task) = project + .update(cx, |project, cx| { + project + .active_entry() + .and_then(|entry| project.path_for_entry(entry, cx)) + .map(|path| project.open_buffer(path, cx)) + }) + .log_err() + .flatten() + else { + return Task::ready(anyhow::Ok(None)); + }; + + cx.spawn(async move |cx| { + let active_buffer = open_buffer_task.await?; + let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer, + &snapshot, + Default::default(), + Default::default(), + &project, + cx, + ) + .await? + else { + return anyhow::Ok(None); + }; + + let Some(prediction_result) = this + .update(cx, |this, cx| { + this.request_prediction( + &project, + &jump_buffer, + jump_position, + PredictEditsRequestTrigger::Diagnostics, + cx, + ) + })? + .await? 
+ else { + return anyhow::Ok(None); + }; + + this.update(cx, |this, cx| { + Some(( + if this + .get_or_init_project(&project, cx) + .current_prediction + .is_none() + { + prediction_result + } else { + EditPredictionResult { + id: prediction_result.id, + prediction: Err(EditPredictionRejectReason::CurrentPreferred), + } + }, + PredictionRequestedBy::DiagnosticsUpdate, + )) + }) + }) + }); + } + + #[cfg(not(test))] + pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); + #[cfg(test)] + pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO; + + fn queue_prediction_refresh( + &mut self, + project: Entity, + throttle_entity: EntityId, + cx: &mut Context, + do_refresh: impl FnOnce( + WeakEntity, + &mut AsyncApp, + ) + -> Task>> + + 'static, ) { - self.update(cx, |this, cx| { - this.cycle(buffer, cursor_position, direction, cx) + let project_state = self.get_or_init_project(&project, cx); + let pending_prediction_id = project_state.next_pending_prediction_id; + project_state.next_pending_prediction_id += 1; + let last_request = project_state.last_prediction_refresh; + + let task = cx.spawn(async move |this, cx| { + if let Some((last_entity, last_timestamp)) = last_request + && throttle_entity == last_entity + && let Some(timeout) = + (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now()) + { + cx.background_executor().timer(timeout).await; + } + + // If this task was cancelled before the throttle timeout expired, + // do not perform a request. 
+ let mut is_cancelled = true;
+ this.update(cx, |this, cx| {
+ let project_state = this.get_or_init_project(&project, cx);
+ if !project_state
+ .cancelled_predictions
+ .remove(&pending_prediction_id)
+ {
+ project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
+ is_cancelled = false;
+ }
+ })
+ .ok();
+ if is_cancelled {
+ return None;
+ }
+
+ let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
+ let new_prediction_id = new_prediction_result
+ .as_ref()
+ .map(|(prediction, _)| prediction.id.clone());
+
+ // When a prediction completes, remove it from the pending list, and cancel
+ // any pending predictions that were enqueued before it.
+ this.update(cx, |this, cx| {
+ let project_state = this.get_or_init_project(&project, cx);
+
+ let is_cancelled = project_state
+ .cancelled_predictions
+ .remove(&pending_prediction_id);
+
+ let new_current_prediction = if !is_cancelled
+ && let Some((prediction_result, requested_by)) = new_prediction_result
+ {
+ match prediction_result.prediction {
+ Ok(prediction) => {
+ let new_prediction = CurrentEditPrediction {
+ requested_by,
+ prediction,
+ was_shown: false,
+ };
+
+ if let Some(current_prediction) =
+ project_state.current_prediction.as_ref()
+ {
+ if new_prediction.should_replace_prediction(&current_prediction, cx)
+ {
+ this.reject_current_prediction(
+ EditPredictionRejectReason::Replaced,
+ &project,
+ );
+
+ Some(new_prediction)
+ } else {
+ this.reject_prediction(
+ new_prediction.prediction.id,
+ EditPredictionRejectReason::CurrentPreferred,
+ false,
+ );
+ None
+ }
+ } else {
+ Some(new_prediction)
+ }
+ }
+ Err(reject_reason) => {
+ this.reject_prediction(prediction_result.id, reject_reason, false);
+ None
+ }
+ }
+ } else {
+ None
+ };
+
+ let project_state = this.get_or_init_project(&project, cx);
+
+ if let Some(new_prediction) = new_current_prediction {
+ project_state.current_prediction = Some(new_prediction);
+ }
+
+ let mut pending_predictions = 
mem::take(&mut project_state.pending_predictions); + for (ix, pending_prediction) in pending_predictions.iter().enumerate() { + if pending_prediction.id == pending_prediction_id { + pending_predictions.remove(ix); + for pending_prediction in pending_predictions.drain(0..ix) { + project_state.cancel_pending_prediction(pending_prediction, cx) + } + break; + } + } + this.get_or_init_project(&project, cx).pending_predictions = pending_predictions; + cx.notify(); + }) + .ok(); + + new_prediction_id + }); + + if project_state.pending_predictions.len() <= 1 { + project_state.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + task, + }); + } else if project_state.pending_predictions.len() == 2 { + let pending_prediction = project_state.pending_predictions.pop().unwrap(); + project_state.pending_predictions.push(PendingPrediction { + id: pending_prediction_id, + task, + }); + project_state.cancel_pending_prediction(pending_prediction, cx); + } + } + + pub fn request_prediction( + &mut self, + project: &Entity, + active_buffer: &Entity, + position: language::Anchor, + trigger: PredictEditsRequestTrigger, + cx: &mut Context, + ) -> Task>> { + self.request_prediction_internal( + project.clone(), + active_buffer.clone(), + position, + trigger, + cx.has_flag::(), + cx, + ) + } + + fn request_prediction_internal( + &mut self, + project: Entity, + active_buffer: Entity, + position: language::Anchor, + trigger: PredictEditsRequestTrigger, + allow_jump: bool, + cx: &mut Context, + ) -> Task>> { + const DIAGNOSTIC_LINES_RANGE: u32 = 20; + + self.get_or_init_project(&project, cx); + let project_state = self.projects.get(&project.entity_id()).unwrap(); + let events = project_state.events(cx); + let has_events = !events.is_empty(); + + let snapshot = active_buffer.read(cx).snapshot(); + let cursor_point = position.to_point(&snapshot); + let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE); + let diagnostic_search_end = 
cursor_point.row + DIAGNOSTIC_LINES_RANGE; + let diagnostic_search_range = + Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0); + + let related_files = if self.use_context { + self.context_for_project(&project, cx).to_vec() + } else { + Vec::new() + }; + + let task = match self.edit_prediction_model { + EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1( + self, + &project, + &active_buffer, + snapshot.clone(), + position, + events, + trigger, + cx, + ), + EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2( + self, + &project, + &active_buffer, + snapshot.clone(), + position, + events, + related_files, + trigger, + cx, + ), + EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep( + &project, + &active_buffer, + snapshot.clone(), + position, + events, + &project_state.recent_paths, + related_files, + diagnostic_search_range.clone(), + cx, + ), + EditPredictionModel::Mercury => self.mercury.request_prediction( + &project, + &active_buffer, + snapshot.clone(), + position, + events, + &project_state.recent_paths, + related_files, + diagnostic_search_range.clone(), + cx, + ), + }; + + cx.spawn(async move |this, cx| { + let prediction = task.await?; + + if prediction.is_none() && allow_jump { + let cursor_point = position.to_point(&snapshot); + if has_events + && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location( + active_buffer.clone(), + &snapshot, + diagnostic_search_range, + cursor_point, + &project, + cx, + ) + .await? + { + return this + .update(cx, |this, cx| { + this.request_prediction_internal( + project, + jump_buffer, + jump_position, + trigger, + false, + cx, + ) + })? 
+ .await; + } + + return anyhow::Ok(None); + } + + Ok(prediction) }) } - fn accept(&self, cx: &mut App) { - self.update(cx, |this, cx| this.accept(cx)) + async fn next_diagnostic_location( + active_buffer: Entity, + active_buffer_snapshot: &BufferSnapshot, + active_buffer_diagnostic_search_range: Range, + active_buffer_cursor_point: Point, + project: &Entity, + cx: &mut AsyncApp, + ) -> Result, language::Anchor)>> { + // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request + let mut jump_location = active_buffer_snapshot + .diagnostic_groups(None) + .into_iter() + .filter_map(|(_, group)| { + let range = &group.entries[group.primary_ix] + .range + .to_point(&active_buffer_snapshot); + if range.overlaps(&active_buffer_diagnostic_search_range) { + None + } else { + Some(range.start) + } + }) + .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row)) + .map(|position| { + ( + active_buffer.clone(), + active_buffer_snapshot.anchor_before(position), + ) + }); + + if jump_location.is_none() { + let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| { + let file = buffer.file()?; + + Some(ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }) + })?; + + let buffer_task = project.update(cx, |project, cx| { + let (path, _, _) = project + .diagnostic_summaries(false, cx) + .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref()) + .max_by_key(|(path, _, _)| { + // find the buffer with errors that shares most parent directories + path.path + .components() + .zip( + active_buffer_path + .as_ref() + .map(|p| p.path.components()) + .unwrap_or_default(), + ) + .take_while(|(a, b)| a == b) + .count() + })?; + + Some(project.open_buffer(path, cx)) + })?; + + if let Some(buffer_task) = buffer_task { + let closest_buffer = buffer_task.await?; + + jump_location = closest_buffer + .read_with(cx, |buffer, _cx| { + buffer + .buffer_diagnostics(None) + .into_iter() + 
.min_by_key(|entry| entry.diagnostic.severity) + .map(|entry| entry.range.start) + })? + .map(|position| (closest_buffer, position)); + } + } + + anyhow::Ok(jump_location) } - fn discard(&self, cx: &mut App) { - self.update(cx, |this, cx| this.discard(cx)) + async fn send_raw_llm_request( + request: open_ai::Request, + client: Arc, + llm_token: LlmApiToken, + app_version: Version, + #[cfg(feature = "eval-support")] eval_cache: Option>, + #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind, + ) -> Result<(open_ai::Response, Option)> { + let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() { + http_client::Url::parse(&predict_edits_url)? + } else { + client + .http_client() + .build_zed_llm_url("/predict_edits/raw", &[])? + }; + + #[cfg(feature = "eval-support")] + let cache_key = if let Some(cache) = eval_cache { + use collections::FxHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = FxHasher::default(); + url.hash(&mut hasher); + let request_str = serde_json::to_string_pretty(&request)?; + request_str.hash(&mut hasher); + let hash = hasher.finish(); + + let key = (eval_cache_kind, hash); + if let Some(response_str) = cache.read(key) { + return Ok((serde_json::from_str(&response_str)?, None)); + } + + Some((cache, request_str, key)) + } else { + None + }; + + let (response, usage) = Self::send_api_request( + |builder| { + let req = builder + .uri(url.as_ref()) + .body(serde_json::to_string(&request)?.into()); + Ok(req?) 
+ }, + client, + llm_token, + app_version, + ) + .await?; + + #[cfg(feature = "eval-support")] + if let Some((cache, request, key)) = cache_key { + cache.write(key, &request, &serde_json::to_string_pretty(&response)?); + } + + Ok((response, usage)) } - fn did_show(&self, cx: &mut App) { - self.update(cx, |this, cx| this.did_show(cx)) + fn handle_api_response( + this: &WeakEntity, + response: Result<(T, Option)>, + cx: &mut gpui::AsyncApp, + ) -> Result { + match response { + Ok((data, usage)) => { + if let Some(usage) = usage { + this.update(cx, |this, cx| { + this.user_store.update(cx, |user_store, cx| { + user_store.update_edit_prediction_usage(usage, cx); + }); + }) + .ok(); + } + Ok(data) + } + Err(err) => { + if err.is::() { + cx.update(|cx| { + this.update(cx, |this, _cx| { + this.update_required = true; + }) + .ok(); + + let error_message: SharedString = err.to_string().into(); + show_app_notification( + NotificationId::unique::(), + cx, + move |cx| { + cx.new(|cx| { + ErrorMessagePrompt::new(error_message.clone(), cx) + .with_link_button("Update Zed", "https://zed.dev/releases") + }) + }, + ); + }) + .ok(); + } + Err(err) + } + } } - fn suggest( - &self, - buffer: &Entity, - cursor_position: language::Anchor, - cx: &mut App, - ) -> Option { - self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx)) - } -} - -/// Returns edits updated based on user edits since the old snapshot. None is returned if any user -/// edit is not a prefix of a predicted insertion. 
-pub fn interpolate_edits( - old_snapshot: &BufferSnapshot, - new_snapshot: &BufferSnapshot, - current_edits: &[(Range, Arc)], -) -> Option, Arc)>> { - let mut edits = Vec::new(); - - let mut model_edits = current_edits.iter().peekable(); - for user_edit in new_snapshot.edits_since::(&old_snapshot.version) { - while let Some((model_old_range, _)) = model_edits.peek() { - let model_old_range = model_old_range.to_offset(old_snapshot); - if model_old_range.end < user_edit.old.start { - let (model_old_range, model_new_text) = model_edits.next().unwrap(); - edits.push((model_old_range.clone(), model_new_text.clone())); + async fn send_api_request( + build: impl Fn(http_client::http::request::Builder) -> Result>, + client: Arc, + llm_token: LlmApiToken, + app_version: Version, + ) -> Result<(Res, Option)> + where + Res: DeserializeOwned, + { + let http_client = client.http_client(); + let mut token = llm_token.acquire(&client).await?; + let mut did_retry = false; + + loop { + let request_builder = http_client::Request::builder().method(Method::POST); + + let request = build( + request_builder + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", token)) + .header(ZED_VERSION_HEADER_NAME, app_version.to_string()), + )?; + + let mut response = http_client.send(request).await?; + + if let Some(minimum_required_version) = response + .headers() + .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME) + .and_then(|version| Version::from_str(version.to_str().ok()?).ok()) + { + anyhow::ensure!( + app_version >= minimum_required_version, + ZedUpdateRequiredError { + minimum_version: minimum_required_version + } + ); + } + + if response.status().is_success() { + let usage = EditPredictionUsage::from_headers(response.headers()).ok(); + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + return Ok((serde_json::from_slice(&body)?, usage)); + } else if !did_retry + && response + .headers() + 
.get(EXPIRED_LLM_TOKEN_HEADER_NAME) + .is_some() + { + did_retry = true; + token = llm_token.refresh(&client).await?; } else { - break; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + body + ); } } + } - if let Some((model_old_range, model_new_text)) = model_edits.peek() { - let model_old_offset_range = model_old_range.to_offset(old_snapshot); - if user_edit.old == model_old_offset_range { - let user_new_text = new_snapshot - .text_for_range(user_edit.new.clone()) - .collect::(); + pub fn refresh_context( + &mut self, + project: &Entity, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut Context, + ) { + if self.use_context { + self.get_or_init_project(project, cx) + .context + .update(cx, |store, cx| { + store.refresh(buffer.clone(), cursor_position, cx); + }); + } + } - if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { - if !model_suffix.is_empty() { - let anchor = old_snapshot.anchor_after(user_edit.old.end); - edits.push((anchor..anchor, model_suffix.into())); - } + fn is_file_open_source( + &self, + project: &Entity, + file: &Arc, + cx: &App, + ) -> bool { + if !file.is_local() || file.is_private() { + return false; + } + let Some(project_state) = self.projects.get(&project.entity_id()) else { + return false; + }; + project_state + .license_detection_watchers + .get(&file.worktree_id(cx)) + .as_ref() + .is_some_and(|watcher| watcher.is_project_open_source()) + } + + fn can_collect_file(&self, project: &Entity, file: &Arc, cx: &App) -> bool { + self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx) + } - model_edits.next(); - continue; + fn can_collect_events(&self, events: &[Arc]) -> bool { + if !self.data_collection_choice.is_enabled() { + return false; + } + events.iter().all(|event| { + matches!( + event.as_ref(), + Event::BufferChange { + in_open_source_repo: 
true, + .. } + ) + }) + } + + fn load_data_collection_choice() -> DataCollectionChoice { + let choice = KEY_VALUE_STORE + .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) + .log_err() + .flatten(); + + match choice.as_deref() { + Some("true") => DataCollectionChoice::Enabled, + Some("false") => DataCollectionChoice::Disabled, + Some(_) => { + log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'"); + DataCollectionChoice::NotAnswered } + None => DataCollectionChoice::NotAnswered, + } + } + + fn toggle_data_collection_choice(&mut self, cx: &mut Context) { + self.data_collection_choice = self.data_collection_choice.toggle(); + let new_choice = self.data_collection_choice; + db::write_and_log(cx, move || { + KEY_VALUE_STORE.write_kvp( + ZED_PREDICT_DATA_COLLECTION_CHOICE.into(), + new_choice.is_enabled().to_string(), + ) + }); + } + + pub fn shown_predictions(&self) -> impl DoubleEndedIterator { + self.shown_predictions.iter() + } + + pub fn shown_completions_len(&self) -> usize { + self.shown_predictions.len() + } + + pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool { + self.rated_predictions.contains(id) + } + + pub fn rate_prediction( + &mut self, + prediction: &EditPrediction, + rating: EditPredictionRating, + feedback: String, + cx: &mut Context, + ) { + self.rated_predictions.insert(prediction.id.clone()); + telemetry::event!( + "Edit Prediction Rated", + rating, + inputs = prediction.inputs, + output = prediction.edit_preview.as_unified_diff(&prediction.edits), + feedback + ); + self.client.telemetry().flush_events().detach(); + cx.notify(); + } + + fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) { + self.use_context = cx.has_flag::() + && all_language_settings(None, cx).edit_predictions.use_context; + } +} + +#[derive(Error, Debug)] +#[error( + "You must update to Zed version {minimum_version} or higher to continue using edit predictions." 
+)] +pub struct ZedUpdateRequiredError { + minimum_version: Version, +} + +#[cfg(feature = "eval-support")] +pub type EvalCacheKey = (EvalCacheEntryKind, u64); + +#[cfg(feature = "eval-support")] +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum EvalCacheEntryKind { + Context, + Search, + Prediction, +} + +#[cfg(feature = "eval-support")] +impl std::fmt::Display for EvalCacheEntryKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EvalCacheEntryKind::Search => write!(f, "search"), + EvalCacheEntryKind::Context => write!(f, "context"), + EvalCacheEntryKind::Prediction => write!(f, "prediction"), } + } +} - return None; +#[cfg(feature = "eval-support")] +pub trait EvalCache: Send + Sync { + fn read(&self, key: EvalCacheKey) -> Option; + fn write(&self, key: EvalCacheKey, input: &str, value: &str); +} + +#[derive(Debug, Clone, Copy)] +pub enum DataCollectionChoice { + NotAnswered, + Enabled, + Disabled, +} + +impl DataCollectionChoice { + pub fn is_enabled(self) -> bool { + match self { + Self::Enabled => true, + Self::NotAnswered | Self::Disabled => false, + } + } + + pub fn is_answered(self) -> bool { + match self { + Self::Enabled | Self::Disabled => true, + Self::NotAnswered => false, + } } - edits.extend(model_edits.cloned()); + #[must_use] + pub fn toggle(&self) -> DataCollectionChoice { + match self { + Self::Enabled => Self::Disabled, + Self::Disabled => Self::Enabled, + Self::NotAnswered => Self::Enabled, + } + } +} + +impl From for DataCollectionChoice { + fn from(value: bool) -> Self { + match value { + true => DataCollectionChoice::Enabled, + false => DataCollectionChoice::Disabled, + } + } +} + +struct ZedPredictUpsell; + +impl Dismissable for ZedPredictUpsell { + const KEY: &'static str = "dismissed-edit-predict-upsell"; + + fn dismissed() -> bool { + // To make this backwards compatible with older versions of Zed, we + // check if the user has seen the previous Edit Prediction Onboarding + // before, by 
checking the data collection choice which was written to + // the database once the user clicked on "Accept and Enable" + if KEY_VALUE_STORE + .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE) + .log_err() + .is_some_and(|s| s.is_some()) + { + return true; + } + + KEY_VALUE_STORE + .read_kvp(Self::KEY) + .log_err() + .is_some_and(|s| s.is_some()) + } +} + +pub fn should_show_upsell_modal() -> bool { + !ZedPredictUpsell::dismissed() +} + +pub fn init(cx: &mut App) { + cx.observe_new(move |workspace: &mut Workspace, _, _cx| { + workspace.register_action( + move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| { + ZedPredictModal::toggle( + workspace, + workspace.user_store().clone(), + workspace.client().clone(), + window, + cx, + ) + }, + ); - if edits.is_empty() { None } else { Some(edits) } + workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| { + update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| { + settings + .project + .all_languages + .features + .get_or_insert_default() + .edit_prediction_provider = Some(EditPredictionProvider::None) + }); + }); + }) + .detach(); } diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..0b7e289bb32b5a10c32a4bd34f118d7cb6c7d43c --- /dev/null +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -0,0 +1,1806 @@ +use super::*; +use crate::zeta1::MAX_EVENT_TOKENS; +use client::{UserStore, test::FakeServer}; +use clock::{FakeSystemClock, ReplicaId}; +use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; +use cloud_llm_client::{ + EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse, + RejectEditPredictionsBody, +}; +use edit_prediction_context::Line; +use futures::{ + AsyncReadExt, StreamExt, + channel::{mpsc, oneshot}, +}; +use gpui::{ + Entity, TestAppContext, + 
http_client::{FakeHttpClient, Response}, +}; +use indoc::indoc; +use language::{Point, ToOffset as _}; +use lsp::LanguageServerId; +use open_ai::Usage; +use parking_lot::Mutex; +use pretty_assertions::{assert_eq, assert_matches}; +use project::{FakeFs, Project}; +use serde_json::json; +use settings::SettingsStore; +use std::{path::Path, sync::Arc, time::Duration}; +use util::{path, rel_path::rel_path}; +use uuid::Uuid; + +use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE}; + +#[gpui::test] +async fn test_current_state(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "1.txt": "Hello!\nHow\nBye\n", + "2.txt": "Hola!\nComo\nAdios\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + ep_store.update(cx, |ep_store, cx| { + ep_store.register_project(&project, cx); + }); + + let buffer1 = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap(); + project.set_active_path(Some(path.clone()), cx); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot1.anchor_before(language::Point::new(1, 3)); + + // Prediction for current file + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) + }); + let (_request, respond_tx) = requests.predict.next().await.unwrap(); + + respond_tx + .send(model_response(indoc! {r" + --- a/root/1.txt + +++ b/root/1.txt + @@ ... @@ + Hello! + -How + +How are you? 
+ Bye + "})) + .unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + let prediction = ep_store + .current_prediction_for_buffer(&buffer1, &project, cx) + .unwrap(); + assert_matches!(prediction, BufferEditPrediction::Local { .. }); + }); + + ep_store.update(cx, |ep_store, _cx| { + ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project); + }); + + // Prediction for diagnostic in another file + + let diagnostic = lsp::Diagnostic { + range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: "Sentence is incomplete".to_string(), + ..Default::default() + }; + + project.update(cx, |project, cx| { + project.lsp_store().update(cx, |lsp_store, cx| { + lsp_store + .update_diagnostics( + LanguageServerId(0), + lsp::PublishDiagnosticsParams { + uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(), + diagnostics: vec![diagnostic], + version: None, + }, + None, + language::DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap(); + }); + }); + + let (_request, respond_tx) = requests.predict.next().await.unwrap(); + respond_tx + .send(model_response(indoc! {r#" + --- a/root/2.txt + +++ b/root/2.txt + Hola! + -Como + +Como estas? 
+ Adios + "#})) + .unwrap(); + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + let prediction = ep_store + .current_prediction_for_buffer(&buffer1, &project, cx) + .unwrap(); + assert_matches!( + prediction, + BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt")) + ); + }); + + let buffer2 = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/2.txt"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + ep_store.read_with(cx, |ep_store, cx| { + let prediction = ep_store + .current_prediction_for_buffer(&buffer2, &project, cx) + .unwrap(); + assert_matches!(prediction, BufferEditPrediction::Local { .. }); + }); +} + +#[gpui::test] +async fn test_simple_request(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + let prediction_task = ep_store.update(cx, |ep_store, cx| { + ep_store.request_prediction(&project, &buffer, position, Default::default(), cx) + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + + // TODO Put back when we have a structured request again + // assert_eq!( + // request.excerpt_path.as_ref(), + // Path::new(path!("root/foo.md")) + // ); + // assert_eq!( + // request.cursor_point, + // Point { + // line: Line(1), + // column: 3 + // } + // ); + + respond_tx + .send(model_response(indoc! 
{ r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? + Bye + "})) + .unwrap(); + + let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); + + assert_eq!(prediction.edits.len(), 1); + assert_eq!( + prediction.edits[0].0.to_point(&snapshot).start, + language::Point::new(1, 3) + ); + assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); +} + +#[gpui::test] +async fn test_request_events(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\n\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + + ep_store.update(cx, |ep_store, cx| { + ep_store.register_buffer(&buffer, &project, cx); + }); + + buffer.update(cx, |buffer, cx| { + buffer.edit(vec![(7..7, "How")], None, cx); + }); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + let prediction_task = ep_store.update(cx, |ep_store, cx| { + ep_store.request_prediction(&project, &buffer, position, Default::default(), cx) + }); + + let (request, respond_tx) = requests.predict.next().await.unwrap(); + + let prompt = prompt_from_request(&request); + assert!( + prompt.contains(indoc! {" + --- a/root/foo.md + +++ b/root/foo.md + @@ -1,3 +1,3 @@ + Hello! + - + +How + Bye + "}), + "{prompt}" + ); + + respond_tx + .send(model_response(indoc! {r#" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? 
+ Bye + "#})) + .unwrap(); + + let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); + + assert_eq!(prediction.edits.len(), 1); + assert_eq!( + prediction.edits[0].0.to_point(&snapshot).start, + language::Point::new(1, 3) + ); + assert_eq!(prediction.edits[0].1.as_ref(), " are you?"); +} + +#[gpui::test] +async fn test_empty_prediction(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + const NO_OP_DIFF: &str = indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! 
+ -How + +How + Bye + "}; + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let response = model_response(NO_OP_DIFF); + let id = response.id.clone(); + respond_tx.send(response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + assert!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .is_none() + ); + }); + + // prediction is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: id, + reason: EditPredictionRejectReason::Empty, + was_shown: false + }] + ); +} + +#[gpui::test] +async fn test_interpolated_empty(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + + buffer.update(cx, |buffer, cx| { + buffer.set_text("Hello!\nHow are you?\nBye", cx); + }); + + let response = model_response(SIMPLE_DIFF); + let id = response.id.clone(); + respond_tx.send(response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + assert!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .is_none() + ); + }); + + // prediction is reported as rejected + let (reject_request, _) = 
requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: id, + reason: EditPredictionRejectReason::InterpolatedEmpty, + was_shown: false + }] + ); +} + +const SIMPLE_DIFF: &str = indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? + Bye +"}; + +#[gpui::test] +async fn test_replace_current(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_tx.send(first_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + // a second request is triggered + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let second_response = model_response(SIMPLE_DIFF); + let second_id = second_response.id.clone(); + respond_tx.send(second_response).unwrap(); + + cx.run_until_parked(); + + 
ep_store.read_with(cx, |ep_store, cx| { + // second replaces first + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + second_id + ); + }); + + // first is reported as replaced + let (reject_request, _) = requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: first_id, + reason: EditPredictionRejectReason::Replaced, + was_shown: false + }] + ); +} + +#[gpui::test] +async fn test_current_preferred(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_tx.send(first_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + // a second request is triggered + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_tx) = requests.predict.next().await.unwrap(); + // worse than current prediction + let second_response = 
model_response(indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are + Bye + "}); + let second_id = second_response.id.clone(); + respond_tx.send(second_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // first is preferred over second + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + // second is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: second_id, + reason: EditPredictionRejectReason::CurrentPreferred, + was_shown: false + }] + ); +} + +#[gpui::test] +async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + // start two refresh tasks + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_first) = requests.predict.next().await.unwrap(); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_second) = requests.predict.next().await.unwrap(); + + // wait for throttle + cx.run_until_parked(); + + // second responds first + let second_response = 
model_response(SIMPLE_DIFF); + let second_id = second_response.id.clone(); + respond_second.send(second_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // current prediction is second + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + second_id + ); + }); + + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_first.send(first_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // current prediction is still second, since first was cancelled + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + second_id + ); + }); + + // first is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + cx.run_until_parked(); + + assert_eq!( + &reject_request.rejections, + &[EditPredictionRejection { + request_id: first_id, + reason: EditPredictionRejectReason::Canceled, + was_shown: false + }] + ); +} + +#[gpui::test] +async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "foo.md": "Hello!\nHow\nBye\n" + }), + ) + .await; + let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + + let buffer = project + .update(cx, |project, cx| { + let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); + project.open_buffer(path, cx) + }) + .await + .unwrap(); + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let position = snapshot.anchor_before(language::Point::new(1, 3)); + + // start two refresh tasks + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_first) = 
requests.predict.next().await.unwrap(); + + ep_store.update(cx, |ep_store, cx| { + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + }); + + let (_, respond_second) = requests.predict.next().await.unwrap(); + + // wait for throttle, so requests are sent + cx.run_until_parked(); + + ep_store.update(cx, |ep_store, cx| { + // start a third request + ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); + + // 2 are pending, so 2nd is cancelled + assert_eq!( + ep_store + .get_or_init_project(&project, cx) + .cancelled_predictions + .iter() + .copied() + .collect::>(), + [1] + ); + }); + + // wait for throttle + cx.run_until_parked(); + + let (_, respond_third) = requests.predict.next().await.unwrap(); + + let first_response = model_response(SIMPLE_DIFF); + let first_id = first_response.id.clone(); + respond_first.send(first_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // current prediction is first + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + let cancelled_response = model_response(SIMPLE_DIFF); + let cancelled_id = cancelled_response.id.clone(); + respond_second.send(cancelled_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // current prediction is still first, since second was cancelled + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + first_id + ); + }); + + let third_response = model_response(SIMPLE_DIFF); + let third_response_id = third_response.id.clone(); + respond_third.send(third_response).unwrap(); + + cx.run_until_parked(); + + ep_store.read_with(cx, |ep_store, cx| { + // third completes and replaces first + assert_eq!( + ep_store + .current_prediction_for_buffer(&buffer, &project, cx) + .unwrap() + .id + .0, + third_response_id + ); + }); + + // second 
is reported as rejected + let (reject_request, _) = requests.reject.next().await.unwrap(); + + cx.run_until_parked(); + + assert_eq!( + &reject_request.rejections, + &[ + EditPredictionRejection { + request_id: cancelled_id, + reason: EditPredictionRejectReason::Canceled, + was_shown: false + }, + EditPredictionRejection { + request_id: first_id, + reason: EditPredictionRejectReason::Replaced, + was_shown: false + } + ] + ); +} + +#[gpui::test] +async fn test_rejections_flushing(cx: &mut TestAppContext) { + let (ep_store, mut requests) = init_test_with_fake_client(cx); + + ep_store.update(cx, |ep_store, _cx| { + ep_store.reject_prediction( + EditPredictionId("test-1".into()), + EditPredictionRejectReason::Discarded, + false, + ); + ep_store.reject_prediction( + EditPredictionId("test-2".into()), + EditPredictionRejectReason::Canceled, + true, + ); + }); + + cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); + cx.run_until_parked(); + + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + // batched + assert_eq!(reject_request.rejections.len(), 2); + assert_eq!( + reject_request.rejections[0], + EditPredictionRejection { + request_id: "test-1".to_string(), + reason: EditPredictionRejectReason::Discarded, + was_shown: false + } + ); + assert_eq!( + reject_request.rejections[1], + EditPredictionRejection { + request_id: "test-2".to_string(), + reason: EditPredictionRejectReason::Canceled, + was_shown: true + } + ); + + // Reaching batch size limit sends without debounce + ep_store.update(cx, |ep_store, _cx| { + for i in 0..70 { + ep_store.reject_prediction( + EditPredictionId(format!("batch-{}", i).into()), + EditPredictionRejectReason::Discarded, + false, + ); + } + }); + + // First MAX/2 items are sent immediately + cx.run_until_parked(); + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + assert_eq!(reject_request.rejections.len(), 50); + 
assert_eq!(reject_request.rejections[0].request_id, "batch-0"); + assert_eq!(reject_request.rejections[49].request_id, "batch-49"); + + // Remaining items are debounced with the next batch + cx.executor().advance_clock(Duration::from_secs(15)); + cx.run_until_parked(); + + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + assert_eq!(reject_request.rejections.len(), 20); + assert_eq!(reject_request.rejections[0].request_id, "batch-50"); + assert_eq!(reject_request.rejections[19].request_id, "batch-69"); + + // Request failure + ep_store.update(cx, |ep_store, _cx| { + ep_store.reject_prediction( + EditPredictionId("retry-1".into()), + EditPredictionRejectReason::Discarded, + false, + ); + }); + + cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); + cx.run_until_parked(); + + let (reject_request, _respond_tx) = requests.reject.next().await.unwrap(); + assert_eq!(reject_request.rejections.len(), 1); + assert_eq!(reject_request.rejections[0].request_id, "retry-1"); + // Simulate failure + drop(_respond_tx); + + // Add another rejection + ep_store.update(cx, |ep_store, _cx| { + ep_store.reject_prediction( + EditPredictionId("retry-2".into()), + EditPredictionRejectReason::Discarded, + false, + ); + }); + + cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE); + cx.run_until_parked(); + + // Retry should include both the failed item and the new one + let (reject_request, respond_tx) = requests.reject.next().await.unwrap(); + respond_tx.send(()).unwrap(); + + assert_eq!(reject_request.rejections.len(), 2); + assert_eq!(reject_request.rejections[0].request_id, "retry-1"); + assert_eq!(reject_request.rejections[1].request_id, "retry-2"); +} + +// Skipped until we start including diagnostics in prompt +// #[gpui::test] +// async fn test_request_diagnostics(cx: &mut TestAppContext) { +// let (ep_store, mut req_rx) = init_test_with_fake_client(cx); +// let fs = FakeFs::new(cx.executor()); +// fs.insert_tree( +// 
"/root", +// json!({ +// "foo.md": "Hello!\nBye" +// }), +// ) +// .await; +// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await; + +// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap(); +// let diagnostic = lsp::Diagnostic { +// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), +// severity: Some(lsp::DiagnosticSeverity::ERROR), +// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(), +// ..Default::default() +// }; + +// project.update(cx, |project, cx| { +// project.lsp_store().update(cx, |lsp_store, cx| { +// // Create some diagnostics +// lsp_store +// .update_diagnostics( +// LanguageServerId(0), +// lsp::PublishDiagnosticsParams { +// uri: path_to_buffer_uri.clone(), +// diagnostics: vec![diagnostic], +// version: None, +// }, +// None, +// language::DiagnosticSourceKind::Pushed, +// &[], +// cx, +// ) +// .unwrap(); +// }); +// }); + +// let buffer = project +// .update(cx, |project, cx| { +// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap(); +// project.open_buffer(path, cx) +// }) +// .await +// .unwrap(); + +// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); +// let position = snapshot.anchor_before(language::Point::new(0, 0)); + +// let _prediction_task = ep_store.update(cx, |ep_store, cx| { +// ep_store.request_prediction(&project, &buffer, position, cx) +// }); + +// let (request, _respond_tx) = req_rx.next().await.unwrap(); + +// assert_eq!(request.diagnostic_groups.len(), 1); +// let value = serde_json::from_str::(request.diagnostic_groups[0].0.get()) +// .unwrap(); +// // We probably don't need all of this. 
TODO define a specific diagnostic type in predict_edits_v3 +// assert_eq!( +// value, +// json!({ +// "entries": [{ +// "range": { +// "start": 8, +// "end": 10 +// }, +// "diagnostic": { +// "source": null, +// "code": null, +// "code_description": null, +// "severity": 1, +// "message": "\"Hello\" deprecated. Use \"Hi\" instead", +// "markdown": null, +// "group_id": 0, +// "is_primary": true, +// "is_disk_based": false, +// "is_unnecessary": false, +// "source_kind": "Pushed", +// "data": null, +// "underline": true +// } +// }], +// "primary_ix": 0 +// }) +// ); +// } + +fn model_response(text: &str) -> open_ai::Response { + open_ai::Response { + id: Uuid::new_v4().to_string(), + object: "response".into(), + created: 0, + model: "model".into(), + choices: vec![open_ai::Choice { + index: 0, + message: open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain(text.to_string())), + tool_calls: vec![], + }, + finish_reason: None, + }], + usage: Usage { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }, + } +} + +fn prompt_from_request(request: &open_ai::Request) -> &str { + assert_eq!(request.messages.len(), 1); + let open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(content), + .. + } = &request.messages[0] + else { + panic!( + "Request does not have single user message of type Plain. 
{:#?}", + request + ); + }; + content +} + +struct RequestChannels { + predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender)>, + reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>, +} + +fn init_test_with_fake_client( + cx: &mut TestAppContext, +) -> (Entity, RequestChannels) { + cx.update(move |cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + zlog::init_test(); + + let (predict_req_tx, predict_req_rx) = mpsc::unbounded(); + let (reject_req_tx, reject_req_rx) = mpsc::unbounded(); + + let http_client = FakeHttpClient::create({ + move |req| { + let uri = req.uri().path().to_string(); + let mut body = req.into_body(); + let predict_req_tx = predict_req_tx.clone(); + let reject_req_tx = reject_req_tx.clone(); + async move { + let resp = match uri.as_str() { + "/client/llm_tokens" => serde_json::to_string(&json!({ + "token": "test" + })) + .unwrap(), + "/predict_edits/raw" => { + let mut buf = Vec::new(); + body.read_to_end(&mut buf).await.ok(); + let req = serde_json::from_slice(&buf).unwrap(); + + let (res_tx, res_rx) = oneshot::channel(); + predict_req_tx.unbounded_send((req, res_tx)).unwrap(); + serde_json::to_string(&res_rx.await?).unwrap() + } + "/predict_edits/reject" => { + let mut buf = Vec::new(); + body.read_to_end(&mut buf).await.ok(); + let req = serde_json::from_slice(&buf).unwrap(); + + let (res_tx, res_rx) = oneshot::channel(); + reject_req_tx.unbounded_send((req, res_tx)).unwrap(); + serde_json::to_string(&res_rx.await?).unwrap() + } + _ => { + panic!("Unexpected path: {}", uri) + } + }; + + Ok(Response::builder().body(resp.into()).unwrap()) + } + } + }); + + let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx); + client.cloud_client().set_credentials(1, "test".into()); + + language_model::init(client.clone(), cx); + + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + let ep_store = EditPredictionStore::global(&client, 
&user_store, cx); + + ( + ep_store, + RequestChannels { + predict: predict_req_rx, + reject: reject_req_rx, + }, + ) + }) +} + +const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt"); + +#[gpui::test] +async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { + let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx)); + let edits: Arc<[(Range, Arc)]> = cx.update(|cx| { + to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into() + }); + + let edit_preview = cx + .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) + .await; + + let completion = EditPrediction { + edits, + edit_preview, + buffer: buffer.clone(), + snapshot: cx.read(|cx| buffer.read(cx).snapshot()), + id: EditPredictionId("the-id".into()), + inputs: EditPredictionInputs { + events: Default::default(), + included_files: Default::default(), + cursor_point: cloud_llm_client::predict_edits_v3::Point { + line: Line(0), + column: 0, + }, + cursor_path: Path::new("").into(), + }, + buffer_snapshotted_at: Instant::now(), + response_received_at: Instant::now(), + }; + + cx.update(|cx| { + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..5, "REM".into()), (9..11, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..2, "REM".into()), (6..8, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(2..5, "REM".into()), (9..11, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + 
vec![(3..3, "EM".into()), (7..9, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".into()), (8..10, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(9..11, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".into()), (8..10, "".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); + assert_eq!( + from_completion_edits( + &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &buffer, + cx + ), + vec![(4..4, "M".into())] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); + assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None); + }) +} + +#[gpui::test] +async fn test_clean_up_diff(cx: &mut TestAppContext) { + init_test(cx); + + assert_eq!( + apply_edit_prediction( + indoc! {" + fn main() { + let word_1 = \"lorem\"; + let range = word.len()..word.len(); + } + "}, + indoc! {" + <|editable_region_start|> + fn main() { + let word_1 = \"lorem\"; + let range = word_1.len()..word_1.len(); + } + + <|editable_region_end|> + "}, + cx, + ) + .await, + indoc! {" + fn main() { + let word_1 = \"lorem\"; + let range = word_1.len()..word_1.len(); + } + "}, + ); + + assert_eq!( + apply_edit_prediction( + indoc! {" + fn main() { + let story = \"the quick\" + } + "}, + indoc! {" + <|editable_region_start|> + fn main() { + let story = \"the quick brown fox jumps over the lazy dog\"; + } + + <|editable_region_end|> + "}, + cx, + ) + .await, + indoc! 
{" + fn main() { + let story = \"the quick brown fox jumps over the lazy dog\"; + } + "}, + ); +} + +#[gpui::test] +async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) { + init_test(cx); + + let buffer_content = "lorem\n"; + let completion_response = indoc! {" + ```animals.js + <|start_of_file|> + <|editable_region_start|> + lorem + ipsum + <|editable_region_end|> + ```"}; + + assert_eq!( + apply_edit_prediction(buffer_content, completion_response, cx).await, + "lorem\nipsum" + ); +} + +#[gpui::test] +async fn test_can_collect_data(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT })) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/project/src/main.rs"), cx) + }) + .await + .unwrap(); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Disabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + + let buffer = cx.new(|_cx| { + Buffer::remote( + language::BufferId::new(1).unwrap(), + ReplicaId::new(1), + language::Capability::ReadWrite, + "fn main() {\n println!(\"Hello\");\n}", + ) + }); + + let 
(ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/project"), + json!({ + "LICENSE": BSD_0_TXT, + ".env": "SECRET_KEY=secret" + }), + ) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer("/project/.env", cx) + }) + .await + .unwrap(); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [], cx).await; + let buffer = cx.new(|cx| Buffer::local("", cx)); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + 
fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" })) + .await; + + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer("/project/main.rs", cx) + }) + .await + .unwrap(); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/open_source_worktree"), + json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }), + ) + .await; + fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" })) + .await; + + let project = Project::test( + fs.clone(), + [ + path!("/open_source_worktree").as_ref(), + path!("/closed_source_worktree").as_ref(), + ], + cx, + ) + .await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx) + }) + .await + .unwrap(); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + let closed_source_file = project + .update(cx, |project, cx| { + let worktree2 = project + .worktree_for_root_name("closed_source_worktree", cx) + .unwrap(); + worktree2.update(cx, |worktree2, cx| { + worktree2.load_file(rel_path("main.rs"), cx) + }) + }) + .await + .unwrap() + .file; + + buffer.update(cx, |buffer, 
cx| { + buffer.file_updated(closed_source_file, cx); + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); +} + +#[gpui::test] +async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/worktree1"), + json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }), + ) + .await; + fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" })) + .await; + + let project = Project::test( + fs.clone(), + [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], + cx, + ) + .await; + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/worktree1/main.rs"), cx) + }) + .await + .unwrap(); + let private_buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/worktree2/file.rs"), cx) + }) + .await + .unwrap(); + + let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await; + ep_store.update(cx, |ep_store, _cx| { + ep_store.data_collection_choice = DataCollectionChoice::Enabled + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); + + // this has a side effect of registering the buffer to watch for edits + run_edit_prediction(&private_buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + + private_buffer.update(cx, |private_buffer, cx| { + private_buffer.edit([(0..0, "An edit for the history!")], None, cx); + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + false + ); + + // make an edit that uses too many bytes, causing private_buffer edit to not be able to be + // included + 
buffer.update(cx, |buffer, cx| { + buffer.edit( + [( + 0..0, + " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS), + )], + None, + cx, + ); + }); + + run_edit_prediction(&buffer, &project, &ep_store, cx).await; + assert_eq!( + captured_request.lock().clone().unwrap().can_collect_data, + true + ); +} + +fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + }); +} + +async fn apply_edit_prediction( + buffer_content: &str, + completion_response: &str, + cx: &mut TestAppContext, +) -> String { + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); + let (ep_store, _, response) = make_test_ep_store(&project, cx).await; + *response.lock() = completion_response.to_string(); + let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await; + buffer.update(cx, |buffer, cx| { + buffer.edit(edit_prediction.edits.iter().cloned(), None, cx) + }); + buffer.read_with(cx, |buffer, _| buffer.text()) +} + +async fn run_edit_prediction( + buffer: &Entity, + project: &Entity, + ep_store: &Entity, + cx: &mut TestAppContext, +) -> EditPrediction { + let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); + ep_store.update(cx, |ep_store, cx| { + ep_store.register_buffer(buffer, &project, cx) + }); + cx.background_executor.run_until_parked(); + let prediction_task = ep_store.update(cx, |ep_store, cx| { + ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx) + }); + prediction_task.await.unwrap().unwrap().prediction.unwrap() +} + +async fn make_test_ep_store( + project: &Entity, + cx: &mut TestAppContext, +) -> ( + Entity, + Arc>>, + Arc>, +) { + let default_response = indoc! 
{" + ```main.rs + <|start_of_file|> + <|editable_region_start|> + hello world + <|editable_region_end|> + ```" + }; + let captured_request: Arc>> = Arc::new(Mutex::new(None)); + let completion_response: Arc> = + Arc::new(Mutex::new(default_response.to_string())); + let http_client = FakeHttpClient::create({ + let captured_request = captured_request.clone(); + let completion_response = completion_response.clone(); + let mut next_request_id = 0; + move |req| { + let captured_request = captured_request.clone(); + let completion_response = completion_response.clone(); + async move { + match (req.method(), req.uri().path()) { + (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&CreateLlmTokenResponse { + token: LlmToken("the-llm-token".to_string()), + }) + .unwrap() + .into(), + ) + .unwrap()), + (&Method::POST, "/predict_edits/v2") => { + let mut request_body = String::new(); + req.into_body().read_to_string(&mut request_body).await?; + *captured_request.lock() = + Some(serde_json::from_str(&request_body).unwrap()); + next_request_id += 1; + Ok(http_client::Response::builder() + .status(200) + .body( + serde_json::to_string(&PredictEditsResponse { + request_id: format!("request-{next_request_id}"), + output_excerpt: completion_response.lock().clone(), + }) + .unwrap() + .into(), + ) + .unwrap()) + } + _ => Ok(http_client::Response::builder() + .status(404) + .body("Not Found".into()) + .unwrap()), + } + } + } + }); + + let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); + cx.update(|cx| { + RefreshLlmTokenListener::register(client.clone(), cx); + }); + let _server = FakeServer::for_client(42, &client, cx).await; + + let ep_store = cx.new(|cx| { + let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx); + ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1); + + let worktrees = 
project.read(cx).worktrees(cx).collect::>(); + for worktree in worktrees { + let worktree_id = worktree.read(cx).id(); + ep_store + .get_or_init_project(project, cx) + .license_detection_watchers + .entry(worktree_id) + .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx))); + } + + ep_store + }); + + (ep_store, captured_request, completion_response) +} + +fn to_completion_edits( + iterator: impl IntoIterator, Arc)>, + buffer: &Entity, + cx: &App, +) -> Vec<(Range, Arc)> { + let buffer = buffer.read(cx); + iterator + .into_iter() + .map(|(range, text)| { + ( + buffer.anchor_after(range.start)..buffer.anchor_before(range.end), + text, + ) + }) + .collect() +} + +fn from_completion_edits( + editor_edits: &[(Range, Arc)], + buffer: &Entity, + cx: &App, +) -> Vec<(Range, Arc)> { + let buffer = buffer.read(cx); + editor_edits + .iter() + .map(|(range, text)| { + ( + range.start.to_offset(buffer)..range.end.to_offset(buffer), + text.clone(), + ) + }) + .collect() +} + +#[ctor::ctor] +fn init_logger() { + zlog::init_test(); +} diff --git a/crates/zeta/src/license_detection.rs b/crates/edit_prediction/src/license_detection.rs similarity index 100% rename from crates/zeta/src/license_detection.rs rename to crates/edit_prediction/src/license_detection.rs diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs new file mode 100644 index 0000000000000000000000000000000000000000..40c0fdfac021f937df5172fd423d3b6bfc5f8146 --- /dev/null +++ b/crates/edit_prediction/src/mercury.rs @@ -0,0 +1,340 @@ +use anyhow::{Context as _, Result}; +use cloud_llm_client::predict_edits_v3::Event; +use credentials_provider::CredentialsProvider; +use edit_prediction_context::RelatedFile; +use futures::{AsyncReadExt as _, FutureExt, future::Shared}; +use gpui::{ + App, AppContext as _, Entity, Task, + http_client::{self, AsyncBody, Method}, +}; +use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _}; +use 
project::{Project, ProjectPath}; +use std::{ + collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant, +}; + +use crate::{ + EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response, + prediction::EditPredictionResult, +}; + +const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions"; +const MAX_CONTEXT_TOKENS: usize = 150; +const MAX_REWRITE_TOKENS: usize = 350; + +pub struct Mercury { + pub api_token: Shared>>, +} + +impl Mercury { + pub fn new(cx: &App) -> Self { + Mercury { + api_token: load_api_token(cx).shared(), + } + } + + pub fn set_api_token(&mut self, api_token: Option, cx: &mut App) -> Task> { + self.api_token = Task::ready(api_token.clone()).shared(); + store_api_token_in_keychain(api_token, cx) + } + + pub fn request_prediction( + &self, + _project: &Entity, + active_buffer: &Entity, + snapshot: BufferSnapshot, + position: language::Anchor, + events: Vec>, + _recent_paths: &VecDeque, + related_files: Vec, + _diagnostic_search_range: Range, + cx: &mut App, + ) -> Task>> { + let Some(api_token) = self.api_token.clone().now_or_never().flatten() else { + return Task::ready(Ok(None)); + }; + let full_path: Arc = snapshot + .file() + .map(|file| file.full_path(cx)) + .unwrap_or_else(|| "untitled".into()) + .into(); + + let http_client = cx.http_client(); + let cursor_point = position.to_point(&snapshot); + let buffer_snapshotted_at = Instant::now(); + + let result = cx.background_spawn(async move { + let (editable_range, context_range) = + crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position( + cursor_point, + &snapshot, + MAX_CONTEXT_TOKENS, + MAX_REWRITE_TOKENS, + ); + + let offset_range = editable_range.to_offset(&snapshot); + let prompt = build_prompt( + &events, + &related_files, + &snapshot, + full_path.as_ref(), + cursor_point, + editable_range, + context_range.clone(), + ); + + let inputs = EditPredictionInputs { + events: events, + included_files: 
vec![cloud_llm_client::predict_edits_v3::RelatedFile { + path: full_path.clone(), + max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), + excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { + start_line: cloud_llm_client::predict_edits_v3::Line( + context_range.start.row, + ), + text: snapshot + .text_for_range(context_range.clone()) + .collect::() + .into(), + }], + }], + cursor_point: cloud_llm_client::predict_edits_v3::Point { + column: cursor_point.column, + line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row), + }, + cursor_path: full_path.clone(), + }; + + let request_body = open_ai::Request { + model: "mercury-coder".into(), + messages: vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }], + stream: false, + max_completion_tokens: None, + stop: vec![], + temperature: None, + tool_choice: None, + parallel_tool_calls: None, + tools: vec![], + prompt_cache_key: None, + reasoning_effort: None, + }; + + let buf = serde_json::to_vec(&request_body)?; + let body: AsyncBody = buf.into(); + + let request = http_client::Request::builder() + .uri(MERCURY_API_URL) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_token)) + .header("Connection", "keep-alive") + .method(Method::POST) + .body(body) + .context("Failed to create request")?; + + let mut response = http_client + .send(request) + .await + .context("Failed to send request")?; + + let mut body: Vec = Vec::new(); + response + .body_mut() + .read_to_end(&mut body) + .await + .context("Failed to read response body")?; + + let response_received_at = Instant::now(); + if !response.status().is_success() { + anyhow::bail!( + "Request failed with status: {:?}\nBody: {}", + response.status(), + String::from_utf8_lossy(&body), + ); + }; + + let mut response: open_ai::Response = + serde_json::from_slice(&body).context("Failed to parse response")?; + + let id = mem::take(&mut response.id); + 
let response_str = text_from_response(response).unwrap_or_default(); + + let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str); + let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str); + + let mut edits = Vec::new(); + const NO_PREDICTION_OUTPUT: &str = "None"; + + if response_str != NO_PREDICTION_OUTPUT { + let old_text = snapshot + .text_for_range(offset_range.clone()) + .collect::(); + edits.extend( + language::text_diff(&old_text, &response_str) + .into_iter() + .map(|(range, text)| { + ( + snapshot.anchor_after(offset_range.start + range.start) + ..snapshot.anchor_before(offset_range.start + range.end), + text, + ) + }), + ); + } + + anyhow::Ok((id, edits, snapshot, response_received_at, inputs)) + }); + + let buffer = active_buffer.clone(); + + cx.spawn(async move |cx| { + let (id, edits, old_snapshot, response_received_at, inputs) = + result.await.context("Mercury edit prediction failed")?; + anyhow::Ok(Some( + EditPredictionResult::new( + EditPredictionId(id.into()), + &buffer, + &old_snapshot, + edits.into(), + buffer_snapshotted_at, + response_received_at, + inputs, + cx, + ) + .await, + )) + }) + } +} + +fn build_prompt( + events: &[Arc], + related_files: &[RelatedFile], + cursor_buffer: &BufferSnapshot, + cursor_buffer_path: &Path, + cursor_point: Point, + editable_range: Range, + context_range: Range, +) -> String { + const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n"; + const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n"; + const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n"; + const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n"; + const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n"; + const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n"; + const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n"; + const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n"; + 
const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n"; + const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n"; + const CURSOR_TAG: &str = "<|cursor|>"; + const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: "; + const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: "; + + let mut prompt = String::new(); + + push_delimited( + &mut prompt, + RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END, + |prompt| { + for related_file in related_files { + for related_excerpt in &related_file.excerpts { + push_delimited( + prompt, + RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END, + |prompt| { + prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX); + prompt.push_str(related_file.path.path.as_unix_str()); + prompt.push('\n'); + prompt.push_str(&related_excerpt.text.to_string()); + }, + ); + } + } + }, + ); + + push_delimited( + &mut prompt, + CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END, + |prompt| { + prompt.push_str(CURRENT_FILE_PATH_PREFIX); + prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref()); + prompt.push('\n'); + + let prefix_range = context_range.start..editable_range.start; + let suffix_range = editable_range.end..context_range.end; + + prompt.extend(cursor_buffer.text_for_range(prefix_range)); + push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| { + let range_before_cursor = editable_range.start..cursor_point; + let range_after_cursor = cursor_point..editable_range.end; + prompt.extend(cursor_buffer.text_for_range(range_before_cursor)); + prompt.push_str(CURSOR_TAG); + prompt.extend(cursor_buffer.text_for_range(range_after_cursor)); + }); + prompt.extend(cursor_buffer.text_for_range(suffix_range)); + }, + ); + + push_delimited( + &mut prompt, + EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END, + |prompt| { + for event in events { + writeln!(prompt, "{event}").unwrap(); + } + }, + ); + + prompt +} + +fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: 
impl FnOnce(&mut String)) { + prompt.push_str(delimiters.start); + cb(prompt); + prompt.push_str(delimiters.end); +} + +pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions"; +pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token"; + +pub fn load_api_token(cx: &App) -> Task> { + if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN") + .ok() + .filter(|value| !value.is_empty()) + { + return Task::ready(Some(api_token)); + } + let credentials_provider = ::global(cx); + cx.spawn(async move |cx| { + let (_, credentials) = credentials_provider + .read_credentials(MERCURY_CREDENTIALS_URL, &cx) + .await + .ok()??; + String::from_utf8(credentials).ok() + }) +} + +fn store_api_token_in_keychain(api_token: Option, cx: &App) -> Task> { + let credentials_provider = ::global(cx); + + cx.spawn(async move |cx| { + if let Some(api_token) = api_token { + credentials_provider + .write_credentials( + MERCURY_CREDENTIALS_URL, + MERCURY_CREDENTIALS_USERNAME, + api_token.as_bytes(), + cx, + ) + .await + .context("Failed to save Mercury API token to system keychain") + } else { + credentials_provider + .delete_credentials(MERCURY_CREDENTIALS_URL, cx) + .await + .context("Failed to delete Mercury API token from system keychain") + } + }) +} diff --git a/crates/zeta/src/onboarding_modal.rs b/crates/edit_prediction/src/onboarding_modal.rs similarity index 100% rename from crates/zeta/src/onboarding_modal.rs rename to crates/edit_prediction/src/onboarding_modal.rs diff --git a/crates/edit_prediction/src/open_ai_response.rs b/crates/edit_prediction/src/open_ai_response.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7e3350936dd89c89849130ba279ad2914dd2bd8 --- /dev/null +++ b/crates/edit_prediction/src/open_ai_response.rs @@ -0,0 +1,31 @@ +pub fn text_from_response(mut res: open_ai::Response) -> Option { + let choice = res.choices.pop()?; + let output_text = match choice.message { + 
open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Plain(content)), + .. + } => content, + open_ai::RequestMessage::Assistant { + content: Some(open_ai::MessageContent::Multipart(mut content)), + .. + } => { + if content.is_empty() { + log::error!("No output from Baseten completion response"); + return None; + } + + match content.remove(0) { + open_ai::MessagePart::Text { text } => text, + open_ai::MessagePart::Image { .. } => { + log::error!("Expected text, got an image"); + return None; + } + } + } + _ => { + log::error!("Invalid response message: {:?}", choice.message); + return None; + } + }; + Some(output_text) +} diff --git a/crates/zeta/src/prediction.rs b/crates/edit_prediction/src/prediction.rs similarity index 99% rename from crates/zeta/src/prediction.rs rename to crates/edit_prediction/src/prediction.rs index fd3241730030fe8bdd95e2cae9ee87b406ade735..d169cf26e1dc4477554bfe8821ff5eae083a6124 100644 --- a/crates/zeta/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -99,7 +99,7 @@ pub struct EditPrediction { #[derive(Debug, Clone, Serialize)] pub struct EditPredictionInputs { pub events: Vec>, - pub included_files: Vec, + pub included_files: Vec, pub cursor_point: cloud_llm_client::predict_edits_v3::Point, pub cursor_path: Arc, } diff --git a/crates/zeta/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs similarity index 94% rename from crates/zeta/src/sweep_ai.rs rename to crates/edit_prediction/src/sweep_ai.rs index 8fd5398f3facc807d99951c48c749e9247fb5670..4bb014c640cb489db29c800835a58febf91a7270 100644 --- a/crates/zeta/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -1,6 +1,7 @@ use anyhow::{Context as _, Result}; use cloud_llm_client::predict_edits_v3::Event; use credentials_provider::CredentialsProvider; +use edit_prediction_context::RelatedFile; use futures::{AsyncReadExt as _, FutureExt, future::Shared}; use gpui::{ App, AppContext as _, Entity, Task, @@ -49,6 +50,7 @@ impl 
SweepAi { position: language::Anchor, events: Vec>, recent_paths: &VecDeque, + related_files: Vec, diagnostic_search_range: Range, cx: &mut App, ) -> Task>> { @@ -120,6 +122,19 @@ impl SweepAi { }) .collect::>(); + let retrieval_chunks = related_files + .iter() + .flat_map(|related_file| { + related_file.excerpts.iter().map(|excerpt| FileChunk { + file_path: related_file.path.path.as_unix_str().to_string(), + start_line: excerpt.point_range.start.row as usize, + end_line: excerpt.point_range.end.row as usize, + content: excerpt.text.to_string(), + timestamp: None, + }) + }) + .collect(); + let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false); let mut diagnostic_content = String::new(); let mut diagnostic_count = 0; @@ -168,7 +183,7 @@ impl SweepAi { multiple_suggestions: false, branch: None, file_chunks, - retrieval_chunks: vec![], + retrieval_chunks, recent_user_actions: vec![], use_bytes: true, // TODO @@ -182,7 +197,7 @@ impl SweepAi { let inputs = EditPredictionInputs { events, - included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile { + included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile { path: full_path.clone(), max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { @@ -320,7 +335,7 @@ struct AutocompleteRequest { pub cursor_position: usize, pub original_file_contents: String, pub file_chunks: Vec, - pub retrieval_chunks: Vec, + pub retrieval_chunks: Vec, pub recent_user_actions: Vec, pub multiple_suggestions: bool, pub privacy_mode_enabled: bool, @@ -337,15 +352,6 @@ struct FileChunk { pub timestamp: Option, } -#[derive(Debug, Clone, Serialize)] -struct RetrievalChunk { - pub file_path: String, - pub start_line: usize, - pub end_line: usize, - pub content: String, - pub timestamp: u64, -} - #[derive(Debug, Clone, Serialize)] struct UserAction { pub action_type: ActionType, diff --git a/crates/zeta/src/udiff.rs 
b/crates/edit_prediction/src/udiff.rs similarity index 100% rename from crates/zeta/src/udiff.rs rename to crates/edit_prediction/src/udiff.rs diff --git a/crates/zeta/src/xml_edits.rs b/crates/edit_prediction/src/xml_edits.rs similarity index 100% rename from crates/zeta/src/xml_edits.rs rename to crates/edit_prediction/src/xml_edits.rs diff --git a/crates/zeta/src/provider.rs b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs similarity index 58% rename from crates/zeta/src/provider.rs rename to crates/edit_prediction/src/zed_edit_prediction_delegate.rs index 019d780e579c079f745f56136bdbd3a4add76b50..91371d539beca012e2ded4e9ec8702c8db39bd8a 100644 --- a/crates/zeta/src/provider.rs +++ b/crates/edit_prediction/src/zed_edit_prediction_delegate.rs @@ -1,55 +1,56 @@ -use std::{cmp, sync::Arc, time::Duration}; +use std::{cmp, sync::Arc}; use client::{Client, UserStore}; use cloud_llm_client::EditPredictionRejectReason; -use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider}; +use edit_prediction_types::{DataCollectionState, Direction, EditPredictionDelegate}; use gpui::{App, Entity, prelude::*}; -use language::ToPoint as _; +use language::{Buffer, ToPoint as _}; use project::Project; -use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel}; +use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore}; -pub struct ZetaEditPredictionProvider { - zeta: Entity, +pub struct ZedEditPredictionDelegate { + store: Entity, project: Entity, + singleton_buffer: Option>, } -impl ZetaEditPredictionProvider { - pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300); - +impl ZedEditPredictionDelegate { pub fn new( project: Entity, + singleton_buffer: Option>, client: &Arc, user_store: &Entity, cx: &mut Context, ) -> Self { - let zeta = Zeta::global(client, user_store, cx); - zeta.update(cx, |zeta, cx| { - zeta.register_project(&project, cx); + let store = EditPredictionStore::global(client, user_store, cx); + 
store.update(cx, |store, cx| { + store.register_project(&project, cx); }); - cx.observe(&zeta, |_this, _zeta, cx| { + cx.observe(&store, |_this, _ep_store, cx| { cx.notify(); }) .detach(); Self { project: project, - zeta, + store: store, + singleton_buffer, } } } -impl EditPredictionProvider for ZetaEditPredictionProvider { +impl EditPredictionDelegate for ZedEditPredictionDelegate { fn name() -> &'static str { - "zed-predict2" + "zed-predict" } fn display_name() -> &'static str { - "Zed's Edit Predictions 2" + "Zed's Edit Predictions" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { true } @@ -57,17 +58,38 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { true } - fn data_collection_state(&self, _cx: &App) -> DataCollectionState { - // TODO [zeta2] - DataCollectionState::Unsupported + fn data_collection_state(&self, cx: &App) -> DataCollectionState { + if let Some(buffer) = &self.singleton_buffer + && let Some(file) = buffer.read(cx).file() + { + let is_project_open_source = + self.store + .read(cx) + .is_file_open_source(&self.project, file, cx); + if self.store.read(cx).data_collection_choice.is_enabled() { + DataCollectionState::Enabled { + is_project_open_source, + } + } else { + DataCollectionState::Disabled { + is_project_open_source, + } + } + } else { + return DataCollectionState::Disabled { + is_project_open_source: false, + }; + } } - fn toggle_data_collection(&mut self, _cx: &mut App) { - // TODO [zeta2] + fn toggle_data_collection(&mut self, cx: &mut App) { + self.store.update(cx, |store, cx| { + store.toggle_data_collection_choice(cx); + }); } fn usage(&self, cx: &App) -> Option { - self.zeta.read(cx).usage(cx) + self.store.read(cx).usage(cx) } fn is_enabled( @@ -76,16 +98,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { _cursor_position: language::Anchor, cx: &App, ) -> bool { - let zeta = self.zeta.read(cx); - if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep { - 
zeta.has_sweep_api_token() + let store = self.store.read(cx); + if store.edit_prediction_model == EditPredictionModel::Sweep { + store.has_sweep_api_token() } else { true } } fn is_refreshing(&self, cx: &App) -> bool { - self.zeta.read(cx).is_refreshing(&self.project) + self.store.read(cx).is_refreshing(&self.project) } fn refresh( @@ -95,24 +117,24 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { _debounce: bool, cx: &mut Context, ) { - let zeta = self.zeta.read(cx); + let store = self.store.read(cx); - if zeta.user_store.read_with(cx, |user_store, _cx| { + if store.user_store.read_with(cx, |user_store, _cx| { user_store.account_too_young() || user_store.has_overdue_invoices() }) { return; } - if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx) + if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx) && let BufferEditPrediction::Local { prediction } = current && prediction.interpolate(buffer.read(cx)).is_some() { return; } - self.zeta.update(cx, |zeta, cx| { - zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx); - zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx) + self.store.update(cx, |store, cx| { + store.refresh_context(&self.project, &buffer, cursor_position, cx); + store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx) }); } @@ -126,20 +148,20 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { } fn accept(&mut self, cx: &mut Context) { - self.zeta.update(cx, |zeta, cx| { - zeta.accept_current_prediction(&self.project, cx); + self.store.update(cx, |store, cx| { + store.accept_current_prediction(&self.project, cx); }); } fn discard(&mut self, cx: &mut Context) { - self.zeta.update(cx, |zeta, _cx| { - zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project); + self.store.update(cx, |store, _cx| { + 
store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project); }); } fn did_show(&mut self, cx: &mut Context) { - self.zeta.update(cx, |zeta, cx| { - zeta.did_show_current_prediction(&self.project, cx); + self.store.update(cx, |store, cx| { + store.did_show_current_prediction(&self.project, cx); }); } @@ -148,16 +170,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { buffer: &Entity, cursor_position: language::Anchor, cx: &mut Context, - ) -> Option { + ) -> Option { let prediction = - self.zeta + self.store .read(cx) .current_prediction_for_buffer(buffer, &self.project, cx)?; let prediction = match prediction { BufferEditPrediction::Local { prediction } => prediction, BufferEditPrediction::Jump { prediction } => { - return Some(edit_prediction::EditPrediction::Jump { + return Some(edit_prediction_types::EditPrediction::Jump { id: Some(prediction.id.to_string().into()), snapshot: prediction.snapshot.clone(), target: prediction.edits.first().unwrap().0.start, @@ -169,8 +191,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { let snapshot = buffer.snapshot(); let Some(edits) = prediction.interpolate(&snapshot) else { - self.zeta.update(cx, |zeta, _cx| { - zeta.reject_current_prediction( + self.store.update(cx, |store, _cx| { + store.reject_current_prediction( EditPredictionRejectReason::InterpolatedEmpty, &self.project, ); @@ -208,7 +230,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider { } } - Some(edit_prediction::EditPrediction::Local { + Some(edit_prediction_types::EditPrediction::Local { id: Some(prediction.id.to_string().into()), edits: edits[edit_start_ix..edit_end_ix].to_vec(), edit_preview: Some(prediction.edit_preview.clone()), diff --git a/crates/zeta/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs similarity index 74% rename from crates/zeta/src/zeta1.rs rename to crates/edit_prediction/src/zeta1.rs index 
0be5fad301242c51c4ad58c60a6d2fcb3441ea08..ad630484d392d75849bd33a52a55e63ea77ca23f 100644 --- a/crates/zeta/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -1,9 +1,8 @@ -mod input_excerpt; - use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant}; use crate::{ - EditPredictionId, ZedUpdateRequiredError, Zeta, + EditPredictionId, EditPredictionStore, ZedUpdateRequiredError, + cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count}, prediction::{EditPredictionInputs, EditPredictionResult}, }; use anyhow::{Context as _, Result}; @@ -12,7 +11,6 @@ use cloud_llm_client::{ predict_edits_v3::Event, }; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task}; -use input_excerpt::excerpt_for_cursor_position; use language::{ Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff, }; @@ -30,23 +28,23 @@ pub(crate) const MAX_REWRITE_TOKENS: usize = 350; pub(crate) const MAX_EVENT_TOKENS: usize = 500; pub(crate) fn request_prediction_with_zeta1( - zeta: &mut Zeta, + store: &mut EditPredictionStore, project: &Entity, buffer: &Entity, snapshot: BufferSnapshot, position: language::Anchor, events: Vec>, trigger: PredictEditsRequestTrigger, - cx: &mut Context, + cx: &mut Context, ) -> Task>> { let buffer = buffer.clone(); let buffer_snapshotted_at = Instant::now(); - let client = zeta.client.clone(); - let llm_token = zeta.llm_token.clone(); + let client = store.client.clone(); + let llm_token = store.llm_token.clone(); let app_version = AppVersion::global(cx); let (git_info, can_collect_file) = if let Some(file) = snapshot.file() { - let can_collect_file = zeta.can_collect_file(project, file, cx); + let can_collect_file = store.can_collect_file(project, file, cx); let git_info = if can_collect_file { git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) } else { @@ -102,7 +100,7 @@ pub(crate) fn request_prediction_with_zeta1( let http_client = 
client.http_client(); - let response = Zeta::send_api_request::( + let response = EditPredictionStore::send_api_request::( |request| { let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") { predict_edits_url @@ -124,7 +122,7 @@ pub(crate) fn request_prediction_with_zeta1( let inputs = EditPredictionInputs { events: included_events.into(), - included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile { + included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile { path: full_path.clone(), max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { @@ -155,8 +153,8 @@ pub(crate) fn request_prediction_with_zeta1( Err(err) => { if err.is::() { cx.update(|cx| { - this.update(cx, |zeta, _cx| { - zeta.update_required = true; + this.update(cx, |ep_store, _cx| { + ep_store.update_required = true; }) .ok(); @@ -495,10 +493,159 @@ pub fn format_event(event: &Event) -> String { } } -/// Typical number of string bytes per token for the purposes of limiting model input. This is -/// intentionally low to err on the side of underestimating limits. 
-pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3; +#[derive(Debug)] +pub struct InputExcerpt { + pub context_range: Range, + pub editable_range: Range, + pub prompt: String, +} + +pub fn excerpt_for_cursor_position( + position: Point, + path: &str, + snapshot: &BufferSnapshot, + editable_region_token_limit: usize, + context_token_limit: usize, +) -> InputExcerpt { + let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position( + position, + snapshot, + editable_region_token_limit, + context_token_limit, + ); + + let mut prompt = String::new(); + + writeln!(&mut prompt, "```{path}").unwrap(); + if context_range.start == Point::zero() { + writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap(); + } + + for chunk in snapshot.chunks(context_range.start..editable_range.start, false) { + prompt.push_str(chunk.text); + } + + push_editable_range(position, snapshot, editable_range.clone(), &mut prompt); + + for chunk in snapshot.chunks(editable_range.end..context_range.end, false) { + prompt.push_str(chunk.text); + } + write!(prompt, "\n```").unwrap(); + + InputExcerpt { + context_range, + editable_range, + prompt, + } +} + +fn push_editable_range( + cursor_position: Point, + snapshot: &BufferSnapshot, + editable_range: Range, + prompt: &mut String, +) { + writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap(); + for chunk in snapshot.chunks(editable_range.start..cursor_position, false) { + prompt.push_str(chunk.text); + } + prompt.push_str(CURSOR_MARKER); + for chunk in snapshot.chunks(cursor_position..editable_range.end, false) { + prompt.push_str(chunk.text); + } + write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap(); +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{App, AppContext}; + use indoc::indoc; + use language::Buffer; + + #[gpui::test] + fn test_excerpt_for_cursor_position(cx: &mut App) { + let text = indoc! 
{r#" + fn foo() { + let x = 42; + println!("Hello, world!"); + } + + fn bar() { + let x = 42; + let mut sum = 0; + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + return sum; + } + + fn generate_random_numbers() -> Vec { + let mut rng = rand::thread_rng(); + let mut numbers = Vec::new(); + for _ in 0..5 { + numbers.push(rng.random_range(1..101)); + } + numbers + } + "#}; + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx)); + let snapshot = buffer.read(cx).snapshot(); + + // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion + // when a larger scope doesn't fit the editable region. + let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32); + assert_eq!( + excerpt.prompt, + indoc! {r#" + ```main.rs + let x = 42; + println!("Hello, world!"); + <|editable_region_start|> + } -fn guess_token_count(bytes: usize) -> usize { - bytes / BYTES_PER_TOKEN_GUESS + fn bar() { + let x = 42; + let mut sum = 0; + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + r<|user_cursor_is_here|>eturn sum; + } + + fn generate_random_numbers() -> Vec { + <|editable_region_end|> + let mut rng = rand::thread_rng(); + let mut numbers = Vec::new(); + ```"#} + ); + + // The `bar` function won't fit within the editable region, so we resort to line-based expansion. + let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32); + assert_eq!( + excerpt.prompt, + indoc! 
{r#" + ```main.rs + fn bar() { + let x = 42; + let mut sum = 0; + <|editable_region_start|> + for i in 0..x { + sum += i; + } + println!("Sum: {}", sum); + r<|user_cursor_is_here|>eturn sum; + } + + fn generate_random_numbers() -> Vec { + let mut rng = rand::thread_rng(); + <|editable_region_end|> + let mut numbers = Vec::new(); + for _ in 0..5 { + numbers.push(rng.random_range(1..101)); + ```"#} + ); + } } diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs new file mode 100644 index 0000000000000000000000000000000000000000..e542bc7e86e6e381766bbedac6a15f431e0693f1 --- /dev/null +++ b/crates/edit_prediction/src/zeta2.rs @@ -0,0 +1,327 @@ +#[cfg(feature = "eval-support")] +use crate::EvalCacheEntryKind; +use crate::open_ai_response::text_from_response; +use crate::prediction::EditPredictionResult; +use crate::{ + DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs, + EditPredictionRequestedDebugEvent, EditPredictionStore, +}; +use anyhow::{Result, anyhow, bail}; +use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat}; +use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger}; +use cloud_zeta2_prompt::CURSOR_MARKER; +use edit_prediction_context::{EditPredictionExcerpt, Line}; +use edit_prediction_context::{RelatedExcerpt, RelatedFile}; +use futures::channel::oneshot; +use gpui::{Entity, Task, prelude::*}; +use language::{Anchor, BufferSnapshot}; +use language::{Buffer, Point, ToOffset as _, ToPoint}; +use project::{Project, ProjectItem as _}; +use release_channel::AppVersion; +use std::{ + env, + path::Path, + sync::Arc, + time::{Duration, Instant}, +}; + +pub fn request_prediction_with_zeta2( + store: &mut EditPredictionStore, + project: &Entity, + active_buffer: &Entity, + active_snapshot: BufferSnapshot, + position: Anchor, + events: Vec>, + mut included_files: Vec, + trigger: PredictEditsRequestTrigger, + cx: &mut Context, +) -> Task>> { + let options = 
store.options.clone(); + let buffer_snapshotted_at = Instant::now(); + + let Some((excerpt_path, active_project_path)) = active_snapshot + .file() + .map(|file| -> Arc { file.full_path(cx).into() }) + .zip(active_buffer.read(cx).project_path(cx)) + else { + return Task::ready(Err(anyhow!("No file path for excerpt"))); + }; + + let client = store.client.clone(); + let llm_token = store.llm_token.clone(); + let app_version = AppVersion::global(cx); + let debug_tx = store.debug_tx.clone(); + + let file = active_buffer.read(cx).file(); + + let active_file_full_path = file.as_ref().map(|f| f.full_path(cx)); + + // TODO data collection + let can_collect_data = file + .as_ref() + .map_or(false, |file| store.can_collect_file(project, file, cx)); + + #[cfg(feature = "eval-support")] + let eval_cache = store.eval_cache.clone(); + + let request_task = cx.background_spawn({ + let active_buffer = active_buffer.clone(); + async move { + let cursor_offset = position.to_offset(&active_snapshot); + let cursor_point = cursor_offset.to_point(&active_snapshot); + + let before_retrieval = Instant::now(); + + let excerpt_options = options.context; + + let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &active_snapshot, + &excerpt_options, + ) else { + return Ok((None, None)); + }; + + let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) + ..active_snapshot.anchor_before(excerpt.range.end); + let related_excerpt = RelatedExcerpt { + anchor_range: excerpt_anchor_range.clone(), + point_range: Point::new(excerpt.line_range.start.0, 0) + ..Point::new(excerpt.line_range.end.0, 0), + text: active_snapshot.as_rope().slice(excerpt.range), + }; + + if let Some(buffer_ix) = included_files + .iter() + .position(|file| file.buffer.entity_id() == active_buffer.entity_id()) + { + let file = &mut included_files[buffer_ix]; + file.excerpts.push(related_excerpt); + file.merge_excerpts(); + let last_ix = included_files.len() - 1; + 
included_files.swap(buffer_ix, last_ix); + } else { + let active_file = RelatedFile { + path: active_project_path, + buffer: active_buffer.downgrade(), + excerpts: vec![related_excerpt], + max_row: active_snapshot.max_point().row, + }; + included_files.push(active_file); + } + + let included_files = included_files + .iter() + .map(|related_file| predict_edits_v3::RelatedFile { + path: Arc::from(related_file.path.path.as_std_path()), + max_row: Line(related_file.max_row), + excerpts: related_file + .excerpts + .iter() + .map(|excerpt| predict_edits_v3::Excerpt { + start_line: Line(excerpt.point_range.start.row), + text: excerpt.text.to_string().into(), + }) + .collect(), + }) + .collect::>(); + + let cloud_request = predict_edits_v3::PredictEditsRequest { + excerpt_path, + excerpt: String::new(), + excerpt_line_range: Line(0)..Line(0), + excerpt_range: 0..0, + cursor_point: predict_edits_v3::Point { + line: predict_edits_v3::Line(cursor_point.row), + column: cursor_point.column, + }, + related_files: included_files, + events, + can_collect_data, + debug_info: debug_tx.is_some(), + prompt_max_bytes: Some(options.max_prompt_bytes), + prompt_format: options.prompt_format, + excerpt_parent: None, + git_info: None, + trigger, + }; + + let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request); + + let inputs = EditPredictionInputs { + included_files: cloud_request.related_files, + events: cloud_request.events, + cursor_point: cloud_request.cursor_point, + cursor_path: cloud_request.excerpt_path, + }; + + let retrieval_time = Instant::now() - before_retrieval; + + let debug_response_tx = if let Some(debug_tx) = &debug_tx { + let (response_tx, response_rx) = oneshot::channel(); + + debug_tx + .unbounded_send(DebugEvent::EditPredictionRequested( + EditPredictionRequestedDebugEvent { + inputs: inputs.clone(), + retrieval_time, + buffer: active_buffer.downgrade(), + local_prompt: match prompt_result.as_ref() { + Ok(prompt) => Ok(prompt.clone()), + Err(err) => 
Err(err.to_string()), + }, + position, + response_rx, + }, + )) + .ok(); + Some(response_tx) + } else { + None + }; + + if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() { + if let Some(debug_response_tx) = debug_response_tx { + debug_response_tx + .send((Err("Request skipped".to_string()), Duration::ZERO)) + .ok(); + } + anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set") + } + + let prompt = prompt_result?; + let generation_params = + cloud_zeta2_prompt::generation_params(cloud_request.prompt_format); + let request = open_ai::Request { + model: EDIT_PREDICTIONS_MODEL_ID.clone(), + messages: vec![open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(prompt), + }], + stream: false, + max_completion_tokens: None, + stop: generation_params.stop.unwrap_or_default(), + temperature: generation_params.temperature.or(Some(0.7)), + tool_choice: None, + parallel_tool_calls: None, + tools: vec![], + prompt_cache_key: None, + reasoning_effort: None, + }; + + log::trace!("Sending edit prediction request"); + + let before_request = Instant::now(); + let response = EditPredictionStore::send_raw_llm_request( + request, + client, + llm_token, + app_version, + #[cfg(feature = "eval-support")] + eval_cache, + #[cfg(feature = "eval-support")] + EvalCacheEntryKind::Prediction, + ) + .await; + let received_response_at = Instant::now(); + let request_time = received_response_at - before_request; + + log::trace!("Got edit prediction response"); + + if let Some(debug_response_tx) = debug_response_tx { + debug_response_tx + .send(( + response + .as_ref() + .map_err(|err| err.to_string()) + .map(|response| response.0.clone()), + request_time, + )) + .ok(); + } + + let (res, usage) = response?; + let request_id = EditPredictionId(res.id.clone().into()); + let Some(mut output_text) = text_from_response(res) else { + return Ok((Some((request_id, None)), usage)); + }; + + if output_text.contains(CURSOR_MARKER) { + 
log::trace!("Stripping out {CURSOR_MARKER} from response"); + output_text = output_text.replace(CURSOR_MARKER, ""); + } + + let get_buffer_from_context = |path: &Path| { + if Some(path) == active_file_full_path.as_deref() { + Some(( + &active_snapshot, + std::slice::from_ref(&excerpt_anchor_range), + )) + } else { + None + } + }; + + let (_, edits) = match options.prompt_format { + PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => { + if output_text.contains("--- a/\n+++ b/\nNo edits") { + let edits = vec![]; + (&active_snapshot, edits) + } else { + crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? + } + } + PromptFormat::OldTextNewText => { + crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await? + } + _ => { + bail!("unsupported prompt format {}", options.prompt_format) + } + }; + + anyhow::Ok(( + Some(( + request_id, + Some(( + inputs, + active_buffer, + active_snapshot.clone(), + edits, + received_response_at, + )), + )), + usage, + )) + } + }); + + cx.spawn(async move |this, cx| { + let Some((id, prediction)) = + EditPredictionStore::handle_api_response(&this, request_task.await, cx)? 
+ else { + return Ok(None); + }; + + let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) = + prediction + else { + return Ok(Some(EditPredictionResult { + id, + prediction: Err(EditPredictionRejectReason::Empty), + })); + }; + + Ok(Some( + EditPredictionResult::new( + id, + &edited_buffer, + &edited_buffer_snapshot, + edits.into(), + buffer_snapshotted_at, + received_response_at, + inputs, + cx, + ) + .await, + )) + }) +} diff --git a/crates/zeta_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml similarity index 83% rename from crates/zeta_cli/Cargo.toml rename to crates/edit_prediction_cli/Cargo.toml index 2dbca537f55377e84f306e13649dfb71ccf2f181..26a060994d75a2c194cc159c33d88fbc296dfa47 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "zeta_cli" +name = "edit_prediction_cli" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,17 +9,18 @@ license = "GPL-3.0-or-later" workspace = true [[bin]] -name = "zeta" +name = "ep_cli" path = "src/main.rs" [dependencies] - anyhow.workspace = true +anthropic.workspace = true +http_client.workspace = true chrono.workspace = true clap.workspace = true client.workspace = true cloud_llm_client.workspace= true -cloud_zeta2_prompt.workspace= true +cloud_zeta2_prompt.workspace = true collections.workspace = true debug_adapter_extension.workspace = true edit_prediction_context.workspace = true @@ -28,6 +29,7 @@ fs.workspace = true futures.workspace = true gpui.workspace = true gpui_tokio.workspace = true +indoc.workspace = true language.workspace = true language_extension.workspace = true language_model.workspace = true @@ -35,9 +37,7 @@ language_models.workspace = true languages = { workspace = true, features = ["load-grammars"] } log.workspace = true node_runtime.workspace = true -ordered-float.workspace = true paths.workspace = true -polars = { version = "0.51", features = ["lazy", "dtype-struct", 
"parquet"] } project.workspace = true prompt_store.workspace = true pulldown-cmark.workspace = true @@ -48,12 +48,13 @@ serde_json.workspace = true settings.workspace = true shellexpand.workspace = true smol.workspace = true -soa-rs = "0.8.1" +sqlez.workspace = true +sqlez_macros.workspace = true terminal_view.workspace = true toml.workspace = true util.workspace = true watch.workspace = true -zeta = { workspace = true, features = ["eval-support"] } +edit_prediction = { workspace = true, features = ["eval-support"] } zlog.workspace = true [dev-dependencies] diff --git a/crates/edit_prediction_button/LICENSE-GPL b/crates/edit_prediction_cli/LICENSE-GPL similarity index 100% rename from crates/edit_prediction_button/LICENSE-GPL rename to crates/edit_prediction_cli/LICENSE-GPL diff --git a/crates/zeta_cli/build.rs b/crates/edit_prediction_cli/build.rs similarity index 100% rename from crates/zeta_cli/build.rs rename to crates/edit_prediction_cli/build.rs diff --git a/crates/zeta_cli/src/evaluate.rs b/crates/edit_prediction_cli/src/evaluate.rs similarity index 98% rename from crates/zeta_cli/src/evaluate.rs rename to crates/edit_prediction_cli/src/evaluate.rs index 043844768557ad46f61d5fd0d809e1e85c62574f..686c8ce7e7865f265d6bf17e51ca9477194e5252 100644 --- a/crates/zeta_cli/src/evaluate.rs +++ b/crates/edit_prediction_cli/src/evaluate.rs @@ -6,17 +6,17 @@ use std::{ }; use anyhow::Result; +use edit_prediction::{EditPredictionStore, udiff::DiffLine}; use gpui::{AsyncApp, Entity}; use project::Project; use util::ResultExt as _; -use zeta::{Zeta, udiff::DiffLine}; use crate::{ EvaluateArguments, PredictionOptions, example::{Example, NamedExample}, headless::ZetaCliAppState, paths::print_run_data_dir, - predict::{PredictionDetails, perform_predict, setup_zeta}, + predict::{PredictionDetails, perform_predict, setup_store}, }; #[derive(Debug)] @@ -45,7 +45,7 @@ pub async fn run_evaluate( let project = example.setup_project(&app_state, cx).await.unwrap(); let providers = 
(0..args.repetitions) - .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap()) + .map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap()) .collect::>(); let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap(); @@ -53,7 +53,7 @@ pub async fn run_evaluate( let tasks = providers .into_iter() .enumerate() - .map(move |(repetition_ix, zeta)| { + .map(move |(repetition_ix, store)| { let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16); let example = example.clone(); let project = project.clone(); @@ -65,7 +65,7 @@ pub async fn run_evaluate( example, repetition_ix, project, - zeta, + store, options, !args.skip_prediction, cx, @@ -154,7 +154,7 @@ pub async fn run_evaluate_one( example: NamedExample, repetition_ix: Option, project: Entity, - zeta: Entity, + store: Entity, prediction_options: PredictionOptions, predict: bool, cx: &mut AsyncApp, @@ -162,7 +162,7 @@ pub async fn run_evaluate_one( let predict_result = perform_predict( example.clone(), project, - zeta, + store, repetition_ix, prediction_options, cx, diff --git a/crates/zeta_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs similarity index 92% rename from crates/zeta_cli/src/example.rs rename to crates/edit_prediction_cli/src/example.rs index a9d4c4f47c5a05d4198b1cffaee51e14a122e88d..4f8c1867cd57d7fb5dbb9c2c08b63dccf2b97d30 100644 --- a/crates/zeta_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -3,6 +3,8 @@ use std::{ cell::RefCell, fmt::{self, Display}, fs, + hash::Hash, + hash::Hasher, io::Write, mem, path::{Path, PathBuf}, @@ -14,6 +16,7 @@ use anyhow::{Context as _, Result, anyhow}; use clap::ValueEnum; use cloud_zeta2_prompt::CURSOR_MARKER; use collections::HashMap; +use edit_prediction::udiff::OpenedBuffers; use futures::{ AsyncWriteExt as _, lock::{Mutex, OwnedMutexGuard}, @@ -25,7 +28,6 @@ use project::{Project, ProjectPath}; use pulldown_cmark::CowStr; use serde::{Deserialize, 
Serialize}; use util::{paths::PathStyle, rel_path::RelPath}; -use zeta::udiff::OpenedBuffers; use crate::paths::{REPOS_DIR, WORKTREES_DIR}; @@ -43,7 +45,7 @@ pub struct NamedExample { pub example: Example, } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Hash, Serialize, Deserialize)] pub struct Example { pub repository_url: String, pub revision: String, @@ -54,6 +56,134 @@ pub struct Example { pub expected_patch: String, } +impl Example { + fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> { + // git@github.com:owner/repo.git + if self.repository_url.contains('@') { + let (owner, repo) = self + .repository_url + .split_once(':') + .context("expected : in git url")? + .1 + .split_once('/') + .context("expected / in git url")?; + Ok(( + Cow::Borrowed(owner), + Cow::Borrowed(repo.trim_end_matches(".git")), + )) + // http://github.com/owner/repo.git + } else { + let url = Url::parse(&self.repository_url)?; + let mut segments = url.path_segments().context("empty http url")?; + let owner = segments + .next() + .context("expected owner path segment")? + .to_string(); + let repo = segments + .next() + .context("expected repo path segment")? + .trim_end_matches(".git") + .to_string(); + assert!(segments.next().is_none()); + + Ok((owner.into(), repo.into())) + } + } + + pub async fn setup_worktree(&self, file_name: String) -> Result { + let (repo_owner, repo_name) = self.repo_name()?; + + let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref()); + let repo_lock = lock_repo(&repo_dir).await; + + if !repo_dir.is_dir() { + fs::create_dir_all(&repo_dir)?; + run_git(&repo_dir, &["init"]).await?; + run_git( + &repo_dir, + &["remote", "add", "origin", &self.repository_url], + ) + .await?; + } + + // Resolve the example to a revision, fetching it if needed. 
+ let revision = run_git( + &repo_dir, + &["rev-parse", &format!("{}^{{commit}}", self.revision)], + ) + .await; + let revision = if let Ok(revision) = revision { + revision + } else { + if run_git( + &repo_dir, + &["fetch", "--depth", "1", "origin", &self.revision], + ) + .await + .is_err() + { + run_git(&repo_dir, &["fetch", "origin"]).await?; + } + let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?; + if revision != self.revision { + run_git(&repo_dir, &["tag", &self.revision, &revision]).await?; + } + revision + }; + + // Create the worktree for this example if needed. + let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref()); + if worktree_path.is_dir() { + run_git(&worktree_path, &["clean", "--force", "-d"]).await?; + run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; + run_git(&worktree_path, &["checkout", revision.as_str()]).await?; + } else { + let worktree_path_string = worktree_path.to_string_lossy(); + run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?; + run_git( + &repo_dir, + &["worktree", "add", "-f", &worktree_path_string, &file_name], + ) + .await?; + } + drop(repo_lock); + + // Apply the uncommitted diff for this example. 
+ if !self.uncommitted_diff.is_empty() { + let mut apply_process = smol::process::Command::new("git") + .current_dir(&worktree_path) + .args(&["apply", "-"]) + .stdin(std::process::Stdio::piped()) + .spawn()?; + + let mut stdin = apply_process.stdin.take().unwrap(); + stdin.write_all(self.uncommitted_diff.as_bytes()).await?; + stdin.close().await?; + drop(stdin); + + let apply_result = apply_process.output().await?; + if !apply_result.status.success() { + anyhow::bail!( + "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}", + apply_result.status, + String::from_utf8_lossy(&apply_result.stderr), + String::from_utf8_lossy(&apply_result.stdout), + ); + } + } + + Ok(worktree_path) + } + + pub fn unique_name(&self) -> String { + let mut hasher = std::hash::DefaultHasher::new(); + self.hash(&mut hasher); + let disambiguator = hasher.finish(); + let hash = format!("{:04x}", disambiguator); + format!("{}_{}", &self.revision[..8], &hash[..4]) + } +} + pub type ActualExcerpt = Excerpt; #[derive(Clone, Debug, Serialize, Deserialize)] @@ -292,90 +422,7 @@ impl NamedExample { } pub async fn setup_worktree(&self) -> Result { - let (repo_owner, repo_name) = self.repo_name()?; - let file_name = self.file_name(); - - let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref()); - let repo_lock = lock_repo(&repo_dir).await; - - if !repo_dir.is_dir() { - fs::create_dir_all(&repo_dir)?; - run_git(&repo_dir, &["init"]).await?; - run_git( - &repo_dir, - &["remote", "add", "origin", &self.example.repository_url], - ) - .await?; - } - - // Resolve the example to a revision, fetching it if needed. 
- let revision = run_git( - &repo_dir, - &[ - "rev-parse", - &format!("{}^{{commit}}", self.example.revision), - ], - ) - .await; - let revision = if let Ok(revision) = revision { - revision - } else { - run_git( - &repo_dir, - &["fetch", "--depth", "1", "origin", &self.example.revision], - ) - .await?; - let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?; - if revision != self.example.revision { - run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?; - } - revision - }; - - // Create the worktree for this example if needed. - let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref()); - if worktree_path.is_dir() { - run_git(&worktree_path, &["clean", "--force", "-d"]).await?; - run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; - run_git(&worktree_path, &["checkout", revision.as_str()]).await?; - } else { - let worktree_path_string = worktree_path.to_string_lossy(); - run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?; - run_git( - &repo_dir, - &["worktree", "add", "-f", &worktree_path_string, &file_name], - ) - .await?; - } - drop(repo_lock); - - // Apply the uncommitted diff for this example. 
- if !self.example.uncommitted_diff.is_empty() { - let mut apply_process = smol::process::Command::new("git") - .current_dir(&worktree_path) - .args(&["apply", "-"]) - .stdin(std::process::Stdio::piped()) - .spawn()?; - - let mut stdin = apply_process.stdin.take().unwrap(); - stdin - .write_all(self.example.uncommitted_diff.as_bytes()) - .await?; - stdin.close().await?; - drop(stdin); - - let apply_result = apply_process.output().await?; - if !apply_result.status.success() { - anyhow::bail!( - "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}", - apply_result.status, - String::from_utf8_lossy(&apply_result.stderr), - String::from_utf8_lossy(&apply_result.stdout), - ); - } - } - - Ok(worktree_path) + self.example.setup_worktree(self.file_name()).await } pub fn file_name(&self) -> String { @@ -391,40 +438,6 @@ impl NamedExample { .collect() } - fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> { - // git@github.com:owner/repo.git - if self.example.repository_url.contains('@') { - let (owner, repo) = self - .example - .repository_url - .split_once(':') - .context("expected : in git url")? - .1 - .split_once('/') - .context("expected / in git url")?; - Ok(( - Cow::Borrowed(owner), - Cow::Borrowed(repo.trim_end_matches(".git")), - )) - // http://github.com/owner/repo.git - } else { - let url = Url::parse(&self.example.repository_url)?; - let mut segments = url.path_segments().context("empty http url")?; - let owner = segments - .next() - .context("expected owner path segment")? - .to_string(); - let repo = segments - .next() - .context("expected repo path segment")? 
- .trim_end_matches(".git") - .to_string(); - assert!(segments.next().is_none()); - - Ok((owner.into(), repo.into())) - } - } - pub async fn cursor_position( &self, project: &Entity, @@ -481,7 +494,7 @@ impl NamedExample { project: &Entity, cx: &mut AsyncApp, ) -> Result> { - zeta::udiff::apply_diff(&self.example.edit_history, project, cx).await + edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await } } diff --git a/crates/zeta_cli/src/headless.rs b/crates/edit_prediction_cli/src/headless.rs similarity index 100% rename from crates/zeta_cli/src/headless.rs rename to crates/edit_prediction_cli/src/headless.rs diff --git a/crates/zeta_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs similarity index 75% rename from crates/zeta_cli/src/main.rs rename to crates/edit_prediction_cli/src/main.rs index a9c8b5998cdd32a94a71f1894dfbdc40c22abaed..00086777f1f03112b92f11923ad2d025276699f5 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -5,7 +5,7 @@ mod metrics; mod paths; mod predict; mod source_location; -mod syntax_retrieval_stats; +mod training; mod util; use crate::{ @@ -14,27 +14,22 @@ use crate::{ headless::ZetaCliAppState, predict::run_predict, source_location::SourceLocation, - syntax_retrieval_stats::retrieval_stats, + training::{context::ContextType, distill::run_distill}, util::{open_buffer, open_buffer_with_language_server}, }; -use ::util::paths::PathStyle; +use ::util::{ResultExt, paths::PathStyle}; use anyhow::{Result, anyhow}; use clap::{Args, Parser, Subcommand, ValueEnum}; use cloud_llm_client::predict_edits_v3; -use edit_prediction_context::{ - EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions, -}; +use edit_prediction::udiff::DiffLine; +use edit_prediction_context::EditPredictionExcerptOptions; use gpui::{Application, AsyncApp, Entity, prelude::*}; use language::{Bias, Buffer, BufferSnapshot, Point}; use metrics::delta_chr_f; -use 
project::{Project, Worktree}; +use project::{Project, Worktree, lsp_store::OpenLspBufferHandle}; use reqwest_client::ReqwestClient; -use serde_json::json; use std::io::{self}; -use std::time::Duration; use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc}; -use zeta::ContextMode; -use zeta::udiff::DiffLine; #[derive(Parser, Debug)] #[command(name = "zeta")] @@ -48,9 +43,9 @@ struct ZetaCliArgs { #[derive(Subcommand, Debug)] enum Command { Context(ContextArgs), - ContextStats(ContextStatsArgs), Predict(PredictArguments), Eval(EvaluateArguments), + Distill(DistillArguments), ConvertExample { path: PathBuf, #[arg(long, value_enum, default_value_t = ExampleFormat::Md)] @@ -63,20 +58,6 @@ enum Command { Clean, } -#[derive(Debug, Args)] -struct ContextStatsArgs { - #[arg(long)] - worktree: PathBuf, - #[arg(long)] - extension: Option, - #[arg(long)] - limit: Option, - #[arg(long)] - skip: Option, - #[clap(flatten)] - zeta2_args: Zeta2Args, -} - #[derive(Debug, Args)] struct ContextArgs { #[arg(long)] @@ -97,7 +78,7 @@ struct ContextArgs { enum ContextProvider { Zeta1, #[default] - Syntax, + Zeta2, } #[derive(Clone, Debug, Args)] @@ -133,6 +114,15 @@ pub struct PredictArguments { options: PredictionOptions, } +#[derive(Debug, Args)] +pub struct DistillArguments { + split_commit_dataset: PathBuf, + #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)] + context_type: ContextType, + #[clap(long)] + batch: Option, +} + #[derive(Clone, Debug, Args)] pub struct PredictionOptions { #[clap(flatten)] @@ -204,35 +194,22 @@ enum PredictionProvider { Sweep, } -fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions { - zeta::ZetaOptions { - context: ContextMode::Syntax(EditPredictionContextOptions { - max_retrieved_declarations: args.max_retrieved_definitions, - use_imports: !args.disable_imports_gathering, - excerpt: EditPredictionExcerptOptions { - max_bytes: args.max_excerpt_bytes, - min_bytes: 
args.min_excerpt_bytes, - target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes, - }, - score: EditPredictionScoreOptions { - omit_excerpt_overlaps, - }, - }), - max_diagnostic_bytes: args.max_diagnostic_bytes, +fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions { + edit_prediction::ZetaOptions { + context: EditPredictionExcerptOptions { + max_bytes: args.max_excerpt_bytes, + min_bytes: args.min_excerpt_bytes, + target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes, + }, max_prompt_bytes: args.max_prompt_bytes, prompt_format: args.prompt_format.into(), - file_indexing_parallelism: args.file_indexing_parallelism, - buffer_change_grouping_interval: Duration::ZERO, } } #[derive(clap::ValueEnum, Default, Debug, Clone, Copy)] enum PromptFormat { - MarkedExcerpt, - LabeledSections, OnlySnippets, #[default] - NumberedLines, OldTextNewText, Minimal, MinimalQwen, @@ -242,10 +219,7 @@ enum PromptFormat { impl Into for PromptFormat { fn into(self) -> predict_edits_v3::PromptFormat { match self { - Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt, - Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections, Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets, - Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff, Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText, Self::Minimal => predict_edits_v3::PromptFormat::Minimal, Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen, @@ -295,6 +269,7 @@ struct LoadedContext { worktree: Entity, project: Entity, buffer: Entity, + lsp_open_handle: Option, } async fn load_context( @@ -330,7 +305,7 @@ async fn load_context( .await?; let mut ready_languages = HashSet::default(); - let (_lsp_open_handle, buffer) = if *use_language_server { + let (lsp_open_handle, buffer) = if *use_language_server { let (lsp_open_handle, _, buffer) = open_buffer_with_language_server( 
project.clone(), worktree.clone(), @@ -377,10 +352,11 @@ async fn load_context( worktree, project, buffer, + lsp_open_handle, }) } -async fn zeta2_syntax_context( +async fn zeta2_context( args: ContextArgs, app_state: &Arc, cx: &mut AsyncApp, @@ -390,6 +366,7 @@ async fn zeta2_syntax_context( project, buffer, clipped_cursor, + lsp_open_handle: _handle, .. } = load_context(&args, app_state, cx).await?; @@ -402,34 +379,32 @@ async fn zeta2_syntax_context( .await; let output = cx .update(|cx| { - let zeta = cx.new(|cx| { - zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) + let store = cx.new(|cx| { + edit_prediction::EditPredictionStore::new( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ) }); - let indexing_done_task = zeta.update(cx, |zeta, cx| { - zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true)); - zeta.register_buffer(&buffer, &project, cx); - zeta.wait_for_initial_indexing(&project, cx) + store.update(cx, |store, cx| { + store.set_options(zeta2_args_to_options(&args.zeta2_args)); + store.register_buffer(&buffer, &project, cx); }); cx.spawn(async move |cx| { - indexing_done_task.await?; - let request = zeta - .update(cx, |zeta, cx| { - let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor); - zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx) - })? 
- .await?; - - let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?; - - match args.zeta2_args.output_format { - OutputFormat::Prompt => anyhow::Ok(prompt_string), - OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?), - OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({ - "request": request, - "prompt": prompt_string, - "section_labels": section_labels, - }))?), - } + let updates_rx = store.update(cx, |store, cx| { + let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor); + store.set_use_context(true); + store.refresh_context(&project, &buffer, cursor, cx); + store.project_context_updates(&project).unwrap() + })?; + + updates_rx.recv().await.ok(); + + let context = store.update(cx, |store, cx| { + store.context_for_project(&project, cx).to_vec() + })?; + + anyhow::Ok(serde_json::to_string_pretty(&context).unwrap()) }) })? .await?; @@ -441,7 +416,7 @@ async fn zeta1_context( args: ContextArgs, app_state: &Arc, cx: &mut AsyncApp, -) -> Result { +) -> Result { let LoadedContext { full_path_str, snapshot, @@ -456,7 +431,7 @@ async fn zeta1_context( let prompt_for_events = move || (events, 0); cx.update(|cx| { - zeta::zeta1::gather_context( + edit_prediction::zeta1::gather_context( full_path_str, &snapshot, clipped_cursor, @@ -482,24 +457,10 @@ fn main() { None => { if args.printenv { ::util::shell_env::print_env(); - return; } else { panic!("Expected a command"); } } - Some(Command::ContextStats(arguments)) => { - let result = retrieval_stats( - arguments.worktree, - app_state, - arguments.extension, - arguments.limit, - arguments.skip, - zeta2_args_to_options(&arguments.zeta2_args, false), - cx, - ) - .await; - println!("{}", result.unwrap()); - } Some(Command::Context(context_args)) => { let result = match context_args.provider { ContextProvider::Zeta1 => { @@ -507,10 +468,8 @@ fn main() { zeta1_context(context_args, &app_state, cx).await.unwrap(); 
serde_json::to_string_pretty(&context.body).unwrap() } - ContextProvider::Syntax => { - zeta2_syntax_context(context_args, &app_state, cx) - .await - .unwrap() + ContextProvider::Zeta2 => { + zeta2_context(context_args, &app_state, cx).await.unwrap() } }; println!("{}", result); @@ -521,6 +480,13 @@ fn main() { Some(Command::Eval(arguments)) => { run_evaluate(arguments, &app_state, cx).await; } + Some(Command::Distill(arguments)) => { + let _guard = cx + .update(|cx| gpui_tokio::Tokio::handle(cx)) + .unwrap() + .enter(); + run_distill(arguments).await.log_err(); + } Some(Command::ConvertExample { path, output_format, diff --git a/crates/zeta_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs similarity index 99% rename from crates/zeta_cli/src/metrics.rs rename to crates/edit_prediction_cli/src/metrics.rs index dd08459678eef6d04a6b656d19a4572d51a5b5c1..0fdb7fb535df12d00341997a64a96b97867f6f28 100644 --- a/crates/zeta_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -1,5 +1,5 @@ use collections::{HashMap, HashSet}; -use zeta::udiff::DiffLine; +use edit_prediction::udiff::DiffLine; type Counts = HashMap; type CountsDelta = HashMap; @@ -287,7 +287,7 @@ fn count_ngrams(text: &str, n: usize) -> Counts { #[cfg(test)] mod test { use super::*; - use zeta::udiff::DiffLine; + use edit_prediction::udiff::DiffLine; #[test] fn test_delta_chr_f_perfect_match() { diff --git a/crates/zeta_cli/src/paths.rs b/crates/edit_prediction_cli/src/paths.rs similarity index 100% rename from crates/zeta_cli/src/paths.rs rename to crates/edit_prediction_cli/src/paths.rs diff --git a/crates/zeta_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs similarity index 75% rename from crates/zeta_cli/src/predict.rs rename to crates/edit_prediction_cli/src/predict.rs index 99fe65cfa3221a1deb18e767e8faa8e1a1fca0ac..74e939b887ce15790993ec15f5973c7f5fd01866 100644 --- a/crates/zeta_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -7,6 +7,7 
@@ use crate::{ use ::serde::Serialize; use anyhow::{Context, Result, anyhow}; use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; +use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey}; use futures::StreamExt as _; use gpui::{AppContext, AsyncApp, Entity}; use project::Project; @@ -18,7 +19,6 @@ use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; use std::time::{Duration, Instant}; -use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta}; pub async fn run_predict( args: PredictArguments, @@ -27,9 +27,9 @@ pub async fn run_predict( ) { let example = NamedExample::load(args.example_path).unwrap(); let project = example.setup_project(app_state, cx).await.unwrap(); - let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap(); + let store = setup_store(args.options.provider, &project, app_state, cx).unwrap(); let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap(); - let result = perform_predict(example, project, zeta, None, args.options, cx) + let result = perform_predict(example, project, store, None, args.options, cx) .await .unwrap(); result.write(args.format, std::io::stdout()).unwrap(); @@ -37,45 +37,50 @@ pub async fn run_predict( print_run_data_dir(true, std::io::stdout().is_terminal()); } -pub fn setup_zeta( +pub fn setup_store( provider: PredictionProvider, project: &Entity, app_state: &Arc, cx: &mut AsyncApp, -) -> Result> { - let zeta = - cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?; +) -> Result> { + let store = cx.new(|cx| { + edit_prediction::EditPredictionStore::new( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ) + })?; - zeta.update(cx, |zeta, _cx| { + store.update(cx, |store, _cx| { let model = match provider { - PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1, - PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2, - PredictionProvider::Sweep => 
zeta::ZetaEditPredictionModel::Sweep, + PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, + PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2, + PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, }; - zeta.set_edit_prediction_model(model); + store.set_edit_prediction_model(model); })?; let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; cx.subscribe(&buffer_store, { let project = project.clone(); - let zeta = zeta.clone(); + let store = store.clone(); move |_, event, cx| match event { BufferStoreEvent::BufferAdded(buffer) => { - zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx)); + store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx)); } _ => {} } })? .detach(); - anyhow::Ok(zeta) + anyhow::Ok(store) } pub async fn perform_predict( example: NamedExample, project: Entity, - zeta: Entity, + store: Entity, repetition_ix: Option, options: PredictionOptions, cx: &mut AsyncApp, @@ -108,8 +113,8 @@ pub async fn perform_predict( std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR) .context("creating latest link")?; - zeta.update(cx, |zeta, _cx| { - zeta.with_eval_cache(Arc::new(RunCache { + store.update(cx, |store, _cx| { + store.with_eval_cache(Arc::new(RunCache { example_run_dir: example_run_dir.clone(), cache_mode, })); @@ -121,44 +126,43 @@ pub async fn perform_predict( let prompt_format = options.zeta2.prompt_format; - zeta.update(cx, |zeta, _cx| { - let mut options = zeta.options().clone(); + store.update(cx, |store, _cx| { + let mut options = store.options().clone(); options.prompt_format = prompt_format.into(); - zeta.set_options(options); + store.set_options(options); })?; let mut debug_task = gpui::Task::ready(Ok(())); if options.provider == crate::PredictionProvider::Zeta2 { - let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?; + let mut debug_rx = store.update(cx, |store, _| 
store.debug_info())?; debug_task = cx.background_spawn({ let result = result.clone(); async move { let mut start_time = None; - let mut search_queries_generated_at = None; - let mut search_queries_executed_at = None; + let mut retrieval_finished_at = None; while let Some(event) = debug_rx.next().await { match event { - zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => { + edit_prediction::DebugEvent::ContextRetrievalStarted(info) => { start_time = Some(info.timestamp); fs::write( example_run_dir.join("search_prompt.md"), &info.search_prompt, )?; } - zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => { - search_queries_generated_at = Some(info.timestamp); - fs::write( - example_run_dir.join("search_queries.json"), - serde_json::to_string_pretty(&info.search_queries).unwrap(), - )?; - } - zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => { - search_queries_executed_at = Some(info.timestamp); + edit_prediction::DebugEvent::ContextRetrievalFinished(info) => { + retrieval_finished_at = Some(info.timestamp); + for (key, value) in &info.metadata { + if *key == "search_queries" { + fs::write( + example_run_dir.join("search_queries.json"), + value.as_bytes(), + )?; + } + } } - zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {} - zeta::ZetaDebugInfo::EditPredictionRequested(request) => { + edit_prediction::DebugEvent::EditPredictionRequested(request) => { let prediction_started_at = Instant::now(); start_time.get_or_insert(prediction_started_at); let prompt = request.local_prompt.unwrap_or_default(); @@ -194,19 +198,16 @@ pub async fn perform_predict( let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?; - let response = zeta::text_from_response(response).unwrap_or_default(); + let response = + edit_prediction::open_ai_response::text_from_response(response) + .unwrap_or_default(); let prediction_finished_at = Instant::now(); fs::write(example_run_dir.join("prediction_response.md"), &response)?; let mut result = result.lock().unwrap(); 
result.generated_len = response.chars().count(); - - result.planning_search_time = - Some(search_queries_generated_at.unwrap() - start_time.unwrap()); - result.running_search_time = Some( - search_queries_executed_at.unwrap() - - search_queries_generated_at.unwrap(), - ); + result.retrieval_time = + retrieval_finished_at.unwrap() - start_time.unwrap(); result.prediction_time = prediction_finished_at - prediction_started_at; result.total_time = prediction_finished_at - start_time.unwrap(); @@ -218,15 +219,14 @@ pub async fn perform_predict( } }); - zeta.update(cx, |zeta, cx| { - zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx) - })? - .await?; + store.update(cx, |store, cx| { + store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx) + })?; } - let prediction = zeta - .update(cx, |zeta, cx| { - zeta.request_prediction( + let prediction = store + .update(cx, |store, cx| { + store.request_prediction( &project, &cursor_buffer, cursor_anchor, @@ -321,8 +321,7 @@ pub struct PredictionDetails { pub diff: String, pub excerpts: Vec, pub excerpts_text: String, // TODO: contains the worktree root path. 
Drop this field and compute it on the fly - pub planning_search_time: Option, - pub running_search_time: Option, + pub retrieval_time: Duration, pub prediction_time: Duration, pub total_time: Duration, pub run_example_dir: PathBuf, @@ -336,8 +335,7 @@ impl PredictionDetails { diff: Default::default(), excerpts: Default::default(), excerpts_text: Default::default(), - planning_search_time: Default::default(), - running_search_time: Default::default(), + retrieval_time: Default::default(), prediction_time: Default::default(), total_time: Default::default(), run_example_dir, @@ -357,28 +355,20 @@ impl PredictionDetails { } pub fn to_markdown(&self) -> String { - let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time; - format!( "## Excerpts\n\n\ {}\n\n\ ## Prediction\n\n\ {}\n\n\ ## Time\n\n\ - Planning searches: {}ms\n\ - Running searches: {}ms\n\ - Making Prediction: {}ms\n\n\ - -------------------\n\n\ - Total: {}ms\n\ - Inference: {}ms ({:.2}%)\n", + Retrieval: {}ms\n\ + Prediction: {}ms\n\n\ + Total: {}ms\n", self.excerpts_text, self.diff, - self.planning_search_time.unwrap_or_default().as_millis(), - self.running_search_time.unwrap_or_default().as_millis(), + self.retrieval_time.as_millis(), self.prediction_time.as_millis(), self.total_time.as_millis(), - inference_time.as_millis(), - (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100. 
) } } diff --git a/crates/zeta_cli/src/source_location.rs b/crates/edit_prediction_cli/src/source_location.rs similarity index 100% rename from crates/zeta_cli/src/source_location.rs rename to crates/edit_prediction_cli/src/source_location.rs diff --git a/crates/edit_prediction_cli/src/training/context.rs b/crates/edit_prediction_cli/src/training/context.rs new file mode 100644 index 0000000000000000000000000000000000000000..7b6d9cc19c1c3750bbf03158ceec5c79a9df0340 --- /dev/null +++ b/crates/edit_prediction_cli/src/training/context.rs @@ -0,0 +1,89 @@ +use std::path::Path; + +use crate::{source_location::SourceLocation, training::teacher::TeacherModel}; + +#[derive(Debug, Clone, Default, clap::ValueEnum)] +pub enum ContextType { + #[default] + CurrentFile, +} + +const MAX_CONTEXT_SIZE: usize = 32768; + +pub fn collect_context( + context_type: &ContextType, + worktree_dir: &Path, + cursor: SourceLocation, +) -> String { + let context = match context_type { + ContextType::CurrentFile => { + let file_path = worktree_dir.join(cursor.path.as_std_path()); + let context = std::fs::read_to_string(&file_path).unwrap_or_default(); + + let context = add_special_tags(&context, worktree_dir, cursor); + context + } + }; + + let region_end_offset = context.find(TeacherModel::REGION_END); + + if context.len() <= MAX_CONTEXT_SIZE { + return context; + } + + if let Some(region_end_offset) = region_end_offset + && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE + { + let to_truncate = context.len() - MAX_CONTEXT_SIZE; + format!( + "[...{} bytes truncated]\n{}\n", + to_truncate, + &context[to_truncate..] 
+ ) + } else { + format!( + "{}\n[...{} bytes truncated]\n", + &context[..MAX_CONTEXT_SIZE], + context.len() - MAX_CONTEXT_SIZE + ) + } +} + +/// Add <|editable_region_start/end|> tags +fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String { + let path = worktree_dir.join(cursor.path.as_std_path()); + let file = std::fs::read_to_string(&path).unwrap_or_default(); + let lines = file.lines().collect::>(); + let cursor_row = cursor.point.row as usize; + let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE); + let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len()); + + let snippet = lines[start_line..end_line].join("\n"); + + if context.contains(&snippet) { + let mut cursor_line = lines[cursor_row].to_string(); + cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR); + + let mut snippet_with_tags_lines = vec![]; + snippet_with_tags_lines.push(TeacherModel::REGION_START); + snippet_with_tags_lines.extend(&lines[start_line..cursor_row]); + snippet_with_tags_lines.push(&cursor_line); + snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]); + snippet_with_tags_lines.push(TeacherModel::REGION_END); + let snippet_with_tags = snippet_with_tags_lines.join("\n"); + + context.replace(&snippet, &snippet_with_tags) + } else { + log::warn!( + "Can't find area around the cursor in the context; proceeding without special tags" + ); + context.to_string() + } +} + +pub fn strip_special_tags(context: &str) -> String { + context + .replace(TeacherModel::REGION_START, "") + .replace(TeacherModel::REGION_END, "") + .replace(TeacherModel::USER_CURSOR, "") +} diff --git a/crates/edit_prediction_cli/src/training/distill.rs b/crates/edit_prediction_cli/src/training/distill.rs new file mode 100644 index 0000000000000000000000000000000000000000..277e35551a9fbce43982de832de5ccecf8d6e92e --- /dev/null +++ b/crates/edit_prediction_cli/src/training/distill.rs @@ -0,0 +1,94 @@ +use 
serde::Deserialize; +use std::sync::Arc; + +use crate::{ + DistillArguments, + example::Example, + source_location::SourceLocation, + training::{ + context::ContextType, + llm_client::LlmClient, + teacher::{TeacherModel, TeacherOutput}, + }, +}; +use anyhow::Result; +use reqwest_client::ReqwestClient; + +#[derive(Debug, Deserialize)] +pub struct SplitCommit { + repo_url: String, + commit_sha: String, + edit_history: String, + expected_patch: String, + cursor_position: String, +} + +pub async fn run_distill(arguments: DistillArguments) -> Result<()> { + let split_commits: Vec = std::fs::read_to_string(&arguments.split_commit_dataset) + .expect("Failed to read split commit dataset") + .lines() + .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line")) + .collect(); + + let http_client: Arc = Arc::new(ReqwestClient::new()); + + let llm_client = if let Some(cache_path) = arguments.batch { + LlmClient::batch(&cache_path, http_client)? + } else { + LlmClient::plain(http_client)? + }; + + let mut teacher = TeacherModel::new( + "claude-sonnet-4-5".to_string(), + ContextType::CurrentFile, + llm_client, + ); + + let mut num_marked_for_batching = 0; + + for commit in split_commits { + if let Some(distilled) = distill_one(&mut teacher, commit).await? 
{ + println!("{}", serde_json::to_string(&distilled)?); + } else { + if num_marked_for_batching == 0 { + log::warn!("Marked for batching"); + } + num_marked_for_batching += 1; + } + } + + eprintln!( + "{} requests are marked for batching", + num_marked_for_batching + ); + let llm_client = teacher.client; + llm_client.sync_batches().await?; + + Ok(()) +} + +pub async fn distill_one( + teacher: &mut TeacherModel, + commit: SplitCommit, +) -> Result> { + let cursor: SourceLocation = commit + .cursor_position + .parse() + .expect("Failed to parse cursor position"); + + let path = cursor.path.to_rel_path_buf(); + + let example = Example { + repository_url: commit.repo_url, + revision: commit.commit_sha, + uncommitted_diff: commit.edit_history.clone(), + cursor_path: path.as_std_path().to_path_buf(), + cursor_position: commit.cursor_position, + edit_history: commit.edit_history, // todo: trim + expected_patch: commit.expected_patch, + }; + + let prediction = teacher.predict(example).await; + + prediction +} diff --git a/crates/edit_prediction_cli/src/training/llm_client.rs b/crates/edit_prediction_cli/src/training/llm_client.rs new file mode 100644 index 0000000000000000000000000000000000000000..ebecbe915d36a9a456296e818e559c654370f939 --- /dev/null +++ b/crates/edit_prediction_cli/src/training/llm_client.rs @@ -0,0 +1,417 @@ +use anthropic::{ + ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent, + Response as AnthropicResponse, Role, non_streaming_completion, +}; +use anyhow::Result; +use http_client::HttpClient; +use indoc::indoc; +use sqlez::bindable::Bind; +use sqlez::bindable::StaticColumnCount; +use sqlez_macros::sql; +use std::hash::Hash; +use std::hash::Hasher; +use std::sync::Arc; + +pub struct PlainLlmClient { + http_client: Arc, + api_key: String, +} + +impl PlainLlmClient { + fn new(http_client: Arc) -> Result { + let api_key = std::env::var("ANTHROPIC_API_KEY") + .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not 
set"))?; + Ok(Self { + http_client, + api_key, + }) + } + + async fn generate( + &self, + model: String, + max_tokens: u64, + messages: Vec, + ) -> Result { + let request = AnthropicRequest { + model, + max_tokens, + messages, + tools: Vec::new(), + thinking: None, + tool_choice: None, + system: None, + metadata: None, + stop_sequences: Vec::new(), + temperature: None, + top_k: None, + top_p: None, + }; + + let response = non_streaming_completion( + self.http_client.as_ref(), + ANTHROPIC_API_URL, + &self.api_key, + request, + None, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + Ok(response) + } +} + +pub struct BatchingLlmClient { + connection: sqlez::connection::Connection, + http_client: Arc, + api_key: String, +} + +struct CacheRow { + request_hash: String, + request: Option, + response: Option, + batch_id: Option, +} + +impl StaticColumnCount for CacheRow { + fn column_count() -> usize { + 4 + } +} + +impl Bind for CacheRow { + fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result { + let next_index = statement.bind(&self.request_hash, start_index)?; + let next_index = statement.bind(&self.request, next_index)?; + let next_index = statement.bind(&self.response, next_index)?; + let next_index = statement.bind(&self.batch_id, next_index)?; + Ok(next_index) + } +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct SerializableRequest { + model: String, + max_tokens: u64, + messages: Vec, +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct SerializableMessage { + role: String, + content: String, +} + +impl BatchingLlmClient { + fn new(cache_path: &str, http_client: Arc) -> Result { + let api_key = std::env::var("ANTHROPIC_API_KEY") + .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?; + + let connection = sqlez::connection::Connection::open_file(&cache_path); + let mut statement = sqlez::statement::Statement::prepare( + &connection, + indoc! 
{" + CREATE TABLE IF NOT EXISTS cache ( + request_hash TEXT PRIMARY KEY, + request TEXT, + response TEXT, + batch_id TEXT + ); + "}, + )?; + statement.exec()?; + drop(statement); + + Ok(Self { + connection, + http_client, + api_key, + }) + } + + pub fn lookup( + &self, + model: &str, + max_tokens: u64, + messages: &[Message], + ) -> Result> { + let request_hash_str = Self::request_hash(model, max_tokens, messages); + let response: Vec = self.connection.select_bound( + &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;), + )?(request_hash_str.as_str())?; + Ok(response + .into_iter() + .next() + .and_then(|text| serde_json::from_str(&text).ok())) + } + + pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> { + let request_hash = Self::request_hash(model, max_tokens, messages); + + let serializable_messages: Vec = messages + .iter() + .map(|msg| SerializableMessage { + role: match msg.role { + Role::User => "user".to_string(), + Role::Assistant => "assistant".to_string(), + }, + content: message_content_to_string(&msg.content), + }) + .collect(); + + let serializable_request = SerializableRequest { + model: model.to_string(), + max_tokens, + messages: serializable_messages, + }; + + let request = Some(serde_json::to_string(&serializable_request)?); + let cache_row = CacheRow { + request_hash, + request, + response: None, + batch_id: None, + }; + self.connection.exec_bound(sql!( + INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?( + cache_row, + ) + } + + async fn generate( + &self, + model: String, + max_tokens: u64, + messages: Vec, + ) -> Result> { + let response = self.lookup(&model, max_tokens, &messages)?; + if let Some(response) = response { + return Ok(Some(response)); + } + + self.mark_for_batch(&model, max_tokens, &messages)?; + + Ok(None) + } + + /// Uploads pending requests as a new batch; downloads finished batches if any. 
+ async fn sync_batches(&self) -> Result<()> { + self.upload_pending_requests().await?; + self.download_finished_batches().await + } + + async fn download_finished_batches(&self) -> Result<()> { + let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL); + let batch_ids: Vec = self.connection.select(q)?()?; + + for batch_id in batch_ids { + let batch_status = anthropic::batches::retrieve_batch( + self.http_client.as_ref(), + ANTHROPIC_API_URL, + &self.api_key, + &batch_id, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + log::info!( + "Batch {} status: {}", + batch_id, + batch_status.processing_status + ); + + if batch_status.processing_status == "ended" { + let results = anthropic::batches::retrieve_batch_results( + self.http_client.as_ref(), + ANTHROPIC_API_URL, + &self.api_key, + &batch_id, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + let mut success_count = 0; + for result in results { + let request_hash = result + .custom_id + .strip_prefix("req_hash_") + .unwrap_or(&result.custom_id) + .to_string(); + + match result.result { + anthropic::batches::BatchResult::Succeeded { message } => { + let response_json = serde_json::to_string(&message)?; + let q = sql!(UPDATE cache SET response = ? 
WHERE request_hash = ?); + self.connection.exec_bound(q)?((response_json, request_hash))?; + success_count += 1; + } + anthropic::batches::BatchResult::Errored { error } => { + log::error!("Batch request {} failed: {:?}", request_hash, error); + } + anthropic::batches::BatchResult::Canceled => { + log::warn!("Batch request {} was canceled", request_hash); + } + anthropic::batches::BatchResult::Expired => { + log::warn!("Batch request {} expired", request_hash); + } + } + } + log::info!("Uploaded {} successful requests", success_count); + } + } + + Ok(()) + } + + async fn upload_pending_requests(&self) -> Result { + let q = sql!( + SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL + ); + + let rows: Vec<(String, String)> = self.connection.select(q)?()?; + + if rows.is_empty() { + return Ok(String::new()); + } + + let batch_requests = rows + .iter() + .map(|(hash, request_str)| { + let serializable_request: SerializableRequest = + serde_json::from_str(&request_str).unwrap(); + + let messages: Vec = serializable_request + .messages + .into_iter() + .map(|msg| Message { + role: match msg.role.as_str() { + "user" => Role::User, + "assistant" => Role::Assistant, + _ => Role::User, + }, + content: vec![RequestContent::Text { + text: msg.content, + cache_control: None, + }], + }) + .collect(); + + let params = AnthropicRequest { + model: serializable_request.model, + max_tokens: serializable_request.max_tokens, + messages, + tools: Vec::new(), + thinking: None, + tool_choice: None, + system: None, + metadata: None, + stop_sequences: Vec::new(), + temperature: None, + top_k: None, + top_p: None, + }; + + let custom_id = format!("req_hash_{}", hash); + anthropic::batches::BatchRequest { custom_id, params } + }) + .collect::>(); + + let batch_len = batch_requests.len(); + let batch = anthropic::batches::create_batch( + self.http_client.as_ref(), + ANTHROPIC_API_URL, + &self.api_key, + anthropic::batches::CreateBatchRequest { + requests: 
batch_requests, + }, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + let q = sql!( + UPDATE cache SET batch_id = ? WHERE batch_id is NULL + ); + self.connection.exec_bound(q)?(batch.id.as_str())?; + + log::info!("Uploaded batch with {} requests", batch_len); + + Ok(batch.id) + } + + fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String { + let mut hasher = std::hash::DefaultHasher::new(); + model.hash(&mut hasher); + max_tokens.hash(&mut hasher); + for msg in messages { + message_content_to_string(&msg.content).hash(&mut hasher); + } + let request_hash = hasher.finish(); + format!("{request_hash:016x}") + } +} + +fn message_content_to_string(content: &[RequestContent]) -> String { + content + .iter() + .filter_map(|c| match c { + RequestContent::Text { text, .. } => Some(text.clone()), + _ => None, + }) + .collect::>() + .join("\n") +} + +pub enum LlmClient { + // No batching + Plain(PlainLlmClient), + Batch(BatchingLlmClient), + Dummy, +} + +impl LlmClient { + pub fn plain(http_client: Arc) -> Result { + Ok(Self::Plain(PlainLlmClient::new(http_client)?)) + } + + pub fn batch(cache_path: &str, http_client: Arc) -> Result { + Ok(Self::Batch(BatchingLlmClient::new( + cache_path, + http_client, + )?)) + } + + #[allow(dead_code)] + pub fn dummy() -> Self { + Self::Dummy + } + + pub async fn generate( + &self, + model: String, + max_tokens: u64, + messages: Vec, + ) -> Result> { + match self { + LlmClient::Plain(plain_llm_client) => plain_llm_client + .generate(model, max_tokens, messages) + .await + .map(Some), + LlmClient::Batch(batching_llm_client) => { + batching_llm_client + .generate(model, max_tokens, messages) + .await + } + LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"), + } + } + + pub async fn sync_batches(&self) -> Result<()> { + match self { + LlmClient::Plain(_) => Ok(()), + LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await, + LlmClient::Dummy => panic!("Dummy 
LLM client is not expected to be used"), + } + } +} diff --git a/crates/edit_prediction_cli/src/training/mod.rs b/crates/edit_prediction_cli/src/training/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..dc564c4dc86c8e095e8e93ccbdfb29d3313e922a --- /dev/null +++ b/crates/edit_prediction_cli/src/training/mod.rs @@ -0,0 +1,4 @@ +pub mod context; +pub mod distill; +pub mod llm_client; +pub mod teacher; diff --git a/crates/edit_prediction_cli/src/training/teacher.prompt.md b/crates/edit_prediction_cli/src/training/teacher.prompt.md new file mode 100644 index 0000000000000000000000000000000000000000..af67c871ef31a21a8744bf71375a50128d9699b6 --- /dev/null +++ b/crates/edit_prediction_cli/src/training/teacher.prompt.md @@ -0,0 +1,48 @@ +# Instructions + +You are a code completion assistant helping a programmer finish their work. Your task is to: + +1. Analyze the edit history to understand what the programmer is trying to achieve +2. Identify any incomplete refactoring or changes that need to be finished +3. Make the remaining edits that a human programmer would logically make next (by rewriting the corresponding code sections) +4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere. + +Focus on: +- Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs) +- Completing any partially-applied changes across the codebase +- Ensuring consistency with the programming style and patterns already established +- Making edits that maintain or improve code quality +- If the programmer started refactoring one instance of a pattern, find and update ALL similar instances +- Don't write a lot of code if you're not sure what to do + +Rules: +- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals. 
+- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code. + +Input format: +- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant. +- Never modify the context code. +- You also receive a code snippet between <|editable_region_start|> and <|editable_region_end|>. This is the editable region. +- The cursor position is marked with <|user_cursor|>. + +Output format: +- Return the entire editable region, applying any edits you make. +- Remove the <|user_cursor|> marker. +- Wrap the edited code in a block of exactly five backticks. + +Output example: +````` + // `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass + if let Some(socket) = &args.askpass {{ + askpass::main(socket); + return Ok(()); + }} +````` + +## User Edits History + +{{edit_history}} + +## Code Context + +{{context}} diff --git a/crates/edit_prediction_cli/src/training/teacher.rs b/crates/edit_prediction_cli/src/training/teacher.rs new file mode 100644 index 0000000000000000000000000000000000000000..99672db8f99a87b99a43c8876db2fd0c2f307b21 --- /dev/null +++ b/crates/edit_prediction_cli/src/training/teacher.rs @@ -0,0 +1,266 @@ +use crate::{ + example::Example, + source_location::SourceLocation, + training::{ + context::{ContextType, collect_context, strip_special_tags}, + llm_client::LlmClient, + }, +}; +use anthropic::{Message, RequestContent, ResponseContent, Role}; +use anyhow::Result; + +pub struct TeacherModel { + pub llm_name: String, + pub context: ContextType, + pub client: LlmClient, +} + +#[derive(Debug, serde::Serialize)] +pub struct TeacherOutput { + parsed_output: String, + prompt: String, + raw_llm_response: String, + context: String, + diff: String, +} + +impl TeacherModel { + const PROMPT: &str = include_str!("teacher.prompt.md"); + pub(crate) const REGION_START: &str = "<|editable_region_start|>\n"; + pub(crate) const 
REGION_END: &str = "<|editable_region_end|>"; + pub(crate) const USER_CURSOR: &str = "<|user_cursor|>"; + + /// Number of lines to include before the cursor position + pub(crate) const LEFT_CONTEXT_SIZE: usize = 5; + + /// Number of lines to include after the cursor position + pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5; + + /// Truncate edit history to this number of last lines + const MAX_HISTORY_LINES: usize = 128; + + pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self { + TeacherModel { + llm_name, + context, + client, + } + } + + pub async fn predict(&self, input: Example) -> Result> { + let name = input.unique_name(); + let worktree_dir = input.setup_worktree(name).await?; + let cursor: SourceLocation = input + .cursor_position + .parse() + .expect("Failed to parse cursor position"); + + let context = collect_context(&self.context, &worktree_dir, cursor.clone()); + let edit_history = Self::format_edit_history(&input.edit_history); + + let prompt = Self::PROMPT + .replace("{{context}}", &context) + .replace("{{edit_history}}", &edit_history); + + let messages = vec![Message { + role: Role::User, + content: vec![RequestContent::Text { + text: prompt.clone(), + cache_control: None, + }], + }]; + + let Some(response) = self + .client + .generate(self.llm_name.clone(), 16384, messages) + .await? 
+ else { + return Ok(None); + }; + + let response_text = response + .content + .into_iter() + .filter_map(|content| match content { + ResponseContent::Text { text } => Some(text), + _ => None, + }) + .collect::>() + .join("\n"); + + let parsed_output = self.parse_response(&response_text); + + let original_editable_region = Self::extract_editable_region(&context); + let context_after_edit = context.replace(&original_editable_region, &parsed_output); + let context_after_edit = strip_special_tags(&context_after_edit); + let context_before_edit = strip_special_tags(&context); + let diff = language::unified_diff(&context_before_edit, &context_after_edit); + + // zeta distill --batch batch_results.txt + // zeta distill + // 1. Run `zeta distill <2000 examples <- all examples>` for the first time + // - store LLM requests in a batch, don't actual send the request + // - send the batch (2000 requests) after all inputs are processed + // 2. `zeta send-batches` + // - upload the batch to Anthropic + + // https://platform.claude.com/docs/en/build-with-claude/batch-processing + // https://crates.io/crates/anthropic-sdk-rust + + // - poll for results + // - when ready, store results in cache (a database) + // 3. 
`zeta distill` again + // - use the cached results this time + + Ok(Some(TeacherOutput { + parsed_output, + prompt, + raw_llm_response: response_text, + context, + diff, + })) + } + + fn parse_response(&self, content: &str) -> String { + let codeblock = Self::extract_last_codeblock(content); + let editable_region = Self::extract_editable_region(&codeblock); + + editable_region + } + + /// Extract content from the last code-fenced block if any, or else return content as is + fn extract_last_codeblock(text: &str) -> String { + let mut last_block = None; + let mut search_start = 0; + + while let Some(start) = text[search_start..].find("```") { + let start = start + search_start; + let bytes = text.as_bytes(); + let mut backtick_end = start; + + while backtick_end < bytes.len() && bytes[backtick_end] == b'`' { + backtick_end += 1; + } + + let backtick_count = backtick_end - start; + let closing_backticks = "`".repeat(backtick_count); + + if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) { + let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1]; + last_block = Some(code_block.to_string()); + search_start = backtick_end + end_pos + backtick_count; + } else { + break; + } + } + + last_block.unwrap_or_else(|| text.to_string()) + } + + fn extract_editable_region(text: &str) -> String { + let start = text + .find(Self::REGION_START) + .map_or(0, |pos| pos + Self::REGION_START.len()); + let end = text.find(Self::REGION_END).unwrap_or(text.len()); + + text[start..end].to_string() + } + + /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines) + fn format_edit_history(edit_history: &str) -> String { + let lines = edit_history + .lines() + .filter(|&s| Self::is_content_line(s)) + .collect::>(); + + let history_lines = if lines.len() > Self::MAX_HISTORY_LINES { + &lines[lines.len() - Self::MAX_HISTORY_LINES..] 
+ } else { + &lines + }; + history_lines.join("\n") + } + + fn is_content_line(s: &str) -> bool { + s.starts_with("-") + || s.starts_with("+") + || s.starts_with(" ") + || s.starts_with("---") + || s.starts_with("+++") + || s.starts_with("@@") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_response() { + let teacher = TeacherModel::new( + "test".to_string(), + ContextType::CurrentFile, + LlmClient::dummy(), + ); + let response = "This is a test response."; + let parsed = teacher.parse_response(response); + assert_eq!(parsed, response.to_string()); + + let response = indoc::indoc! {" + Some thinking + + ````` + actual response + ````` + "}; + let parsed = teacher.parse_response(response); + assert_eq!(parsed, "actual response"); + } + + #[test] + fn test_extract_last_code_block() { + let text = indoc::indoc! {" + Some thinking + + ``` + first block + ``` + + ````` + last block + ````` + "}; + let last_block = TeacherModel::extract_last_codeblock(text); + assert_eq!(last_block, "last block"); + } + + #[test] + fn test_extract_editable_region() { + let teacher = TeacherModel::new( + "test".to_string(), + ContextType::CurrentFile, + LlmClient::dummy(), + ); + let response = indoc::indoc! {" + some lines + are + here + <|editable_region_start|> + one + two three + + <|editable_region_end|> + more + lines here + "}; + let parsed = teacher.parse_response(response); + assert_eq!( + parsed, + indoc::indoc! 
{" + one + two three + + "} + ); + } +} diff --git a/crates/zeta_cli/src/util.rs b/crates/edit_prediction_cli/src/util.rs similarity index 86% rename from crates/zeta_cli/src/util.rs rename to crates/edit_prediction_cli/src/util.rs index 699c1c743f67e09ef5ca7211c385114802d4ab32..f4a51d94585f82da008ac832dc62392c365738fd 100644 --- a/crates/zeta_cli/src/util.rs +++ b/crates/edit_prediction_cli/src/util.rs @@ -2,7 +2,8 @@ use anyhow::{Result, anyhow}; use futures::channel::mpsc; use futures::{FutureExt as _, StreamExt as _}; use gpui::{AsyncApp, Entity, Task}; -use language::{Buffer, LanguageId, LanguageServerId, ParseStatus}; +use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus}; +use project::lsp_store::OpenLspBufferHandle; use project::{Project, ProjectPath, Worktree}; use std::collections::HashSet; use std::sync::Arc; @@ -40,7 +41,7 @@ pub async fn open_buffer_with_language_server( path: Arc, ready_languages: &mut HashSet, cx: &mut AsyncApp, -) -> Result<(Entity>, LanguageServerId, Entity)> { +) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity)> { let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?; let (lsp_open_handle, path_style) = project.update(cx, |project, cx| { @@ -50,6 +51,17 @@ pub async fn open_buffer_with_language_server( ) })?; + let language_registry = project.read_with(cx, |project, _| project.languages().clone())?; + let result = language_registry + .load_language_for_file_path(path.as_std_path()) + .await; + + if let Err(error) = result + && !error.is::() + { + anyhow::bail!(error); + } + let Some(language_id) = buffer.read_with(cx, |buffer, _cx| { buffer.language().map(|language| language.id()) })? 
@@ -57,9 +69,9 @@ pub async fn open_buffer_with_language_server( return Err(anyhow!("No language for {}", path.display(path_style))); }; - let log_prefix = path.display(path_style); + let log_prefix = format!("{} | ", path.display(path_style)); if !ready_languages.contains(&language_id) { - wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?; + wait_for_lang_server(&project, &buffer, log_prefix, cx).await?; ready_languages.insert(language_id); } @@ -95,7 +107,7 @@ pub fn wait_for_lang_server( log_prefix: String, cx: &mut AsyncApp, ) -> Task> { - println!("{}⏵ Waiting for language server", log_prefix); + eprintln!("{}⏵ Waiting for language server", log_prefix); let (mut tx, mut rx) = mpsc::channel(1); @@ -137,7 +149,7 @@ pub fn wait_for_lang_server( .. } = event { - println!("{}⟲ {message}", log_prefix) + eprintln!("{}⟲ {message}", log_prefix) } } }), @@ -162,7 +174,7 @@ pub fn wait_for_lang_server( cx.spawn(async move |cx| { if !has_lang_server { // some buffers never have a language server, so this aborts quickly in that case. - let timeout = cx.background_executor().timer(Duration::from_secs(5)); + let timeout = cx.background_executor().timer(Duration::from_secs(500)); futures::select! { _ = added_rx.next() => {}, _ = timeout.fuse() => { @@ -173,7 +185,7 @@ pub fn wait_for_lang_server( let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5)); let result = futures::select! 
{ _ = rx.next() => { - println!("{}⚑ Language server idle", log_prefix); + eprintln!("{}⚑ Language server idle", log_prefix); anyhow::Ok(()) }, _ = timeout.fuse() => { diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index 6976831b8cbbe2b998f713ff65f1585f28fc3005..f113c3c46075ca70e61d8d07947d37502e8528e8 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -12,41 +12,32 @@ workspace = true path = "src/edit_prediction_context.rs" [dependencies] +parking_lot.workspace = true anyhow.workspace = true -arrayvec.workspace = true cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true -hashbrown.workspace = true -indoc.workspace = true -itertools.workspace = true language.workspace = true -log.workspace = true -ordered-float.workspace = true -postage.workspace = true +lsp.workspace = true project.workspace = true -regex.workspace = true +log.workspace = true serde.workspace = true -slotmap.workspace = true -strum.workspace = true -text.workspace = true +smallvec.workspace = true tree-sitter.workspace = true util.workspace = true [dev-dependencies] -clap.workspace = true +env_logger.workspace = true +indoc.workspace = true futures.workspace = true gpui = { workspace = true, features = ["test-support"] } -indoc.workspace = true language = { workspace = true, features = ["test-support"] } +lsp = { workspace = true, features = ["test-support"] } pretty_assertions.workspace = true project = {workspace= true, features = ["test-support"]} serde_json.workspace = true settings = {workspace= true, features = ["test-support"]} text = { workspace = true, features = ["test-support"] } -tree-sitter-c.workspace = true -tree-sitter-cpp.workspace = true -tree-sitter-go.workspace = true util = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/edit_prediction_context/src/assemble_excerpts.rs 
b/crates/edit_prediction_context/src/assemble_excerpts.rs new file mode 100644 index 0000000000000000000000000000000000000000..15f4c03d653429af671c22d6b5abc652d282a38e --- /dev/null +++ b/crates/edit_prediction_context/src/assemble_excerpts.rs @@ -0,0 +1,161 @@ +use crate::RelatedExcerpt; +use language::{BufferSnapshot, OffsetRangeExt as _, Point}; +use std::ops::Range; + +#[cfg(not(test))] +const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512; +#[cfg(test)] +const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 24; + +pub fn assemble_excerpts( + buffer: &BufferSnapshot, + mut input_ranges: Vec>, +) -> Vec { + merge_ranges(&mut input_ranges); + + let mut outline_ranges = Vec::new(); + let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None); + let mut outline_ix = 0; + for input_range in &mut input_ranges { + *input_range = clip_range_to_lines(input_range, false, buffer); + + while let Some(outline_item) = outline_items.get(outline_ix) { + let item_range = clip_range_to_lines(&outline_item.range, false, buffer); + + if item_range.start > input_range.start { + break; + } + + if item_range.end > input_range.start { + let body_range = outline_item + .body_range(buffer) + .map(|body| clip_range_to_lines(&body, true, buffer)) + .filter(|body_range| { + body_range.to_offset(buffer).len() > MAX_OUTLINE_ITEM_BODY_SIZE + }); + + add_outline_item( + item_range.clone(), + body_range.clone(), + buffer, + &mut outline_ranges, + ); + + if let Some(body_range) = body_range + && input_range.start < body_range.start + { + let mut child_outline_ix = outline_ix + 1; + while let Some(next_outline_item) = outline_items.get(child_outline_ix) { + if next_outline_item.range.end > body_range.end { + break; + } + if next_outline_item.depth == outline_item.depth + 1 { + let next_item_range = + clip_range_to_lines(&next_outline_item.range, false, buffer); + + add_outline_item( + next_item_range, + next_outline_item + .body_range(buffer) + .map(|body| clip_range_to_lines(&body, 
true, buffer)), + buffer, + &mut outline_ranges, + ); + } + child_outline_ix += 1; + } + } + } + + outline_ix += 1; + } + } + + input_ranges.extend_from_slice(&outline_ranges); + merge_ranges(&mut input_ranges); + + input_ranges + .into_iter() + .map(|range| { + let offset_range = range.to_offset(buffer); + RelatedExcerpt { + point_range: range, + anchor_range: buffer.anchor_before(offset_range.start) + ..buffer.anchor_after(offset_range.end), + text: buffer.as_rope().slice(offset_range), + } + }) + .collect() +} + +fn clip_range_to_lines( + range: &Range, + inward: bool, + buffer: &BufferSnapshot, +) -> Range { + let mut range = range.clone(); + if inward { + if range.start.column > 0 { + range.start.column = buffer.line_len(range.start.row); + } + range.end.column = 0; + } else { + range.start.column = 0; + if range.end.column > 0 { + range.end.column = buffer.line_len(range.end.row); + } + } + range +} + +fn add_outline_item( + mut item_range: Range, + body_range: Option>, + buffer: &BufferSnapshot, + outline_ranges: &mut Vec>, +) { + if let Some(mut body_range) = body_range { + if body_range.start.column > 0 { + body_range.start.column = buffer.line_len(body_range.start.row); + } + body_range.end.column = 0; + + let head_range = item_range.start..body_range.start; + if head_range.start < head_range.end { + outline_ranges.push(head_range); + } + + let tail_range = body_range.end..item_range.end; + if tail_range.start < tail_range.end { + outline_ranges.push(tail_range); + } + } else { + item_range.start.column = 0; + item_range.end.column = buffer.line_len(item_range.end.row); + outline_ranges.push(item_range); + } +} + +pub fn merge_ranges(ranges: &mut Vec>) { + ranges.sort_unstable_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end))); + + let mut index = 1; + while index < ranges.len() { + let mut prev_range_end = ranges[index - 1].end; + if prev_range_end.column > 0 { + prev_range_end += Point::new(1, 0); + } + + if (prev_range_end + Point::new(1, 0)) + 
.cmp(&ranges[index].start) + .is_ge() + { + let removed = ranges.remove(index); + if removed.end.cmp(&ranges[index - 1].end).is_gt() { + ranges[index - 1].end = removed.end; + } + } else { + index += 1; + } + } +} diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs deleted file mode 100644 index cc32640425ecc563b1f24a6c695be1c13199cd73..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/declaration.rs +++ /dev/null @@ -1,350 +0,0 @@ -use cloud_llm_client::predict_edits_v3::{self, Line}; -use language::{Language, LanguageId}; -use project::ProjectEntryId; -use std::ops::Range; -use std::sync::Arc; -use std::{borrow::Cow, path::Path}; -use text::{Bias, BufferId, Rope}; -use util::paths::{path_ends_with, strip_path_suffix}; -use util::rel_path::RelPath; - -use crate::outline::OutlineDeclaration; - -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct Identifier { - pub name: Arc, - pub language_id: LanguageId, -} - -slotmap::new_key_type! { - pub struct DeclarationId; -} - -#[derive(Debug, Clone)] -pub enum Declaration { - File { - project_entry_id: ProjectEntryId, - declaration: FileDeclaration, - cached_path: CachedDeclarationPath, - }, - Buffer { - project_entry_id: ProjectEntryId, - buffer_id: BufferId, - rope: Rope, - declaration: BufferDeclaration, - cached_path: CachedDeclarationPath, - }, -} - -const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024; - -impl Declaration { - pub fn identifier(&self) -> &Identifier { - match self { - Declaration::File { declaration, .. } => &declaration.identifier, - Declaration::Buffer { declaration, .. } => &declaration.identifier, - } - } - - pub fn parent(&self) -> Option { - match self { - Declaration::File { declaration, .. } => declaration.parent, - Declaration::Buffer { declaration, .. } => declaration.parent, - } - } - - pub fn as_buffer(&self) -> Option<&BufferDeclaration> { - match self { - Declaration::File { .. 
} => None, - Declaration::Buffer { declaration, .. } => Some(declaration), - } - } - - pub fn as_file(&self) -> Option<&FileDeclaration> { - match self { - Declaration::Buffer { .. } => None, - Declaration::File { declaration, .. } => Some(declaration), - } - } - - pub fn project_entry_id(&self) -> ProjectEntryId { - match self { - Declaration::File { - project_entry_id, .. - } => *project_entry_id, - Declaration::Buffer { - project_entry_id, .. - } => *project_entry_id, - } - } - - pub fn cached_path(&self) -> &CachedDeclarationPath { - match self { - Declaration::File { cached_path, .. } => cached_path, - Declaration::Buffer { cached_path, .. } => cached_path, - } - } - - pub fn item_range(&self) -> Range { - match self { - Declaration::File { declaration, .. } => declaration.item_range.clone(), - Declaration::Buffer { declaration, .. } => declaration.item_range.clone(), - } - } - - pub fn item_line_range(&self) -> Range { - match self { - Declaration::File { declaration, .. } => declaration.item_line_range.clone(), - Declaration::Buffer { - declaration, rope, .. - } => { - Line(rope.offset_to_point(declaration.item_range.start).row) - ..Line(rope.offset_to_point(declaration.item_range.end).row) - } - } - } - - pub fn item_text(&self) -> (Cow<'_, str>, bool) { - match self { - Declaration::File { declaration, .. } => ( - declaration.text.as_ref().into(), - declaration.text_is_truncated, - ), - Declaration::Buffer { - rope, declaration, .. - } => ( - rope.chunks_in_range(declaration.item_range.clone()) - .collect::>(), - declaration.item_range_is_truncated, - ), - } - } - - pub fn signature_text(&self) -> (Cow<'_, str>, bool) { - match self { - Declaration::File { declaration, .. } => ( - declaration.text[self.signature_range_in_item_text()].into(), - declaration.signature_is_truncated, - ), - Declaration::Buffer { - rope, declaration, .. 
- } => ( - rope.chunks_in_range(declaration.signature_range.clone()) - .collect::>(), - declaration.signature_range_is_truncated, - ), - } - } - - pub fn signature_range(&self) -> Range { - match self { - Declaration::File { declaration, .. } => declaration.signature_range.clone(), - Declaration::Buffer { declaration, .. } => declaration.signature_range.clone(), - } - } - - pub fn signature_line_range(&self) -> Range { - match self { - Declaration::File { declaration, .. } => declaration.signature_line_range.clone(), - Declaration::Buffer { - declaration, rope, .. - } => { - Line(rope.offset_to_point(declaration.signature_range.start).row) - ..Line(rope.offset_to_point(declaration.signature_range.end).row) - } - } - } - - pub fn signature_range_in_item_text(&self) -> Range { - let signature_range = self.signature_range(); - let item_range = self.item_range(); - signature_range.start.saturating_sub(item_range.start) - ..(signature_range.end.saturating_sub(item_range.start)).min(item_range.len()) - } -} - -fn expand_range_to_line_boundaries_and_truncate( - range: &Range, - limit: usize, - rope: &Rope, -) -> (Range, Range, bool) { - let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end); - point_range.start.column = 0; - point_range.end.row += 1; - point_range.end.column = 0; - - let mut item_range = - rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end); - let is_truncated = item_range.len() > limit; - if is_truncated { - item_range.end = item_range.start + limit; - } - item_range.end = rope.clip_offset(item_range.end, Bias::Left); - - let line_range = - predict_edits_v3::Line(point_range.start.row)..predict_edits_v3::Line(point_range.end.row); - (item_range, line_range, is_truncated) -} - -#[derive(Debug, Clone)] -pub struct FileDeclaration { - pub parent: Option, - pub identifier: Identifier, - /// offset range of the declaration in the file, expanded to line boundaries and truncated - pub item_range: 
Range, - /// line range of the declaration in the file, potentially truncated - pub item_line_range: Range, - /// text of `item_range` - pub text: Arc, - /// whether `text` was truncated - pub text_is_truncated: bool, - /// offset range of the signature in the file, expanded to line boundaries and truncated - pub signature_range: Range, - /// line range of the signature in the file, truncated - pub signature_line_range: Range, - /// whether `signature` was truncated - pub signature_is_truncated: bool, -} - -impl FileDeclaration { - pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration { - let (item_range_in_file, item_line_range_in_file, text_is_truncated) = - expand_range_to_line_boundaries_and_truncate( - &declaration.item_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - - let (mut signature_range_in_file, signature_line_range, mut signature_is_truncated) = - expand_range_to_line_boundaries_and_truncate( - &declaration.signature_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - - if signature_range_in_file.start < item_range_in_file.start { - signature_range_in_file.start = item_range_in_file.start; - signature_is_truncated = true; - } - if signature_range_in_file.end > item_range_in_file.end { - signature_range_in_file.end = item_range_in_file.end; - signature_is_truncated = true; - } - - FileDeclaration { - parent: None, - identifier: declaration.identifier, - signature_range: signature_range_in_file, - signature_line_range, - signature_is_truncated, - text: rope - .chunks_in_range(item_range_in_file.clone()) - .collect::() - .into(), - text_is_truncated, - item_range: item_range_in_file, - item_line_range: item_line_range_in_file, - } - } -} - -#[derive(Debug, Clone)] -pub struct BufferDeclaration { - pub parent: Option, - pub identifier: Identifier, - pub item_range: Range, - pub item_range_is_truncated: bool, - pub signature_range: Range, - pub signature_range_is_truncated: bool, -} - -impl BufferDeclaration { - pub fn 
from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self { - let (item_range, _item_line_range, item_range_is_truncated) = - expand_range_to_line_boundaries_and_truncate( - &declaration.item_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - let (signature_range, _signature_line_range, signature_range_is_truncated) = - expand_range_to_line_boundaries_and_truncate( - &declaration.signature_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - Self { - parent: None, - identifier: declaration.identifier, - item_range, - item_range_is_truncated, - signature_range, - signature_range_is_truncated, - } - } -} - -#[derive(Debug, Clone)] -pub struct CachedDeclarationPath { - pub worktree_abs_path: Arc, - pub rel_path: Arc, - /// The relative path of the file, possibly stripped according to `import_path_strip_regex`. - pub rel_path_after_regex_stripping: Arc, -} - -impl CachedDeclarationPath { - pub fn new( - worktree_abs_path: Arc, - path: &Arc, - language: Option<&Arc>, - ) -> Self { - let rel_path = path.clone(); - let rel_path_after_regex_stripping = if let Some(language) = language - && let Some(strip_regex) = language.config().import_path_strip_regex.as_ref() - && let Ok(stripped) = RelPath::unix(&Path::new( - strip_regex.replace_all(rel_path.as_unix_str(), "").as_ref(), - )) { - Arc::from(stripped) - } else { - rel_path.clone() - }; - CachedDeclarationPath { - worktree_abs_path, - rel_path, - rel_path_after_regex_stripping, - } - } - - #[cfg(test)] - pub fn new_for_test(worktree_abs_path: &str, rel_path: &str) -> Self { - let rel_path: Arc = util::rel_path::rel_path(rel_path).into(); - CachedDeclarationPath { - worktree_abs_path: std::path::PathBuf::from(worktree_abs_path).into(), - rel_path_after_regex_stripping: rel_path.clone(), - rel_path, - } - } - - pub fn ends_with_posix_path(&self, path: &Path) -> bool { - if path.as_os_str().len() <= self.rel_path_after_regex_stripping.as_unix_str().len() { - 
path_ends_with(self.rel_path_after_regex_stripping.as_std_path(), path) - } else { - if let Some(remaining) = - strip_path_suffix(path, self.rel_path_after_regex_stripping.as_std_path()) - { - path_ends_with(&self.worktree_abs_path, remaining) - } else { - false - } - } - } - - pub fn equals_absolute_path(&self, path: &Path) -> bool { - if let Some(remaining) = - strip_path_suffix(path, &self.rel_path_after_regex_stripping.as_std_path()) - { - self.worktree_abs_path.as_ref() == remaining - } else { - false - } - } -} diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs deleted file mode 100644 index 48a823362769770c836b44e7d8a6c1942d3a1196..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ /dev/null @@ -1,539 +0,0 @@ -use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents; -use collections::HashMap; -use language::BufferSnapshot; -use ordered_float::OrderedFloat; -use project::ProjectEntryId; -use serde::Serialize; -use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc}; -use strum::EnumIter; -use text::{Point, ToPoint}; -use util::RangeExt as _; - -use crate::{ - CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier, - imports::{Import, Imports, Module}, - reference::{Reference, ReferenceRegion}, - syntax_index::SyntaxIndexState, - text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient}, -}; - -const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16; - -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct EditPredictionScoreOptions { - pub omit_excerpt_overlaps: bool, -} - -#[derive(Clone, Debug)] -pub struct ScoredDeclaration { - /// identifier used by the local reference - pub identifier: Identifier, - pub declaration: Declaration, - pub components: DeclarationScoreComponents, -} - -#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)] -pub enum DeclarationStyle { - 
Signature, - Declaration, -} - -#[derive(Clone, Debug, Serialize, Default)] -pub struct DeclarationScores { - pub signature: f32, - pub declaration: f32, - pub retrieval: f32, -} - -impl ScoredDeclaration { - /// Returns the score for this declaration with the specified style. - pub fn score(&self, style: DeclarationStyle) -> f32 { - // TODO: handle truncation - - // Score related to how likely this is the correct declaration, range 0 to 1 - let retrieval = self.retrieval_score(); - - // Score related to the distance between the reference and cursor, range 0 to 1 - let distance_score = if self.components.is_referenced_nearby { - 1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0) - } else { - // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures - 0.5 - }; - - // For now instead of linear combination, the scores are just multiplied together. - let combined_score = 10.0 * retrieval * distance_score; - - match style { - DeclarationStyle::Signature => { - combined_score * self.components.excerpt_vs_signature_weighted_overlap - } - DeclarationStyle::Declaration => { - 2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap - } - } - } - - pub fn retrieval_score(&self) -> f32 { - let mut score = if self.components.is_same_file { - 10.0 / self.components.same_file_declaration_count as f32 - } else if self.components.path_import_match_count > 0 { - 3.0 - } else if self.components.wildcard_path_import_match_count > 0 { - 1.0 - } else if self.components.normalized_import_similarity > 0.0 { - self.components.normalized_import_similarity - } else if self.components.normalized_wildcard_import_similarity > 0.0 { - 0.5 * self.components.normalized_wildcard_import_similarity - } else { - 1.0 / self.components.declaration_count as f32 - }; - score *= 1. + self.components.included_by_others as f32 / 2.; - score *= 1. 
+ self.components.includes_others as f32 / 4.; - score - } - - pub fn size(&self, style: DeclarationStyle) -> usize { - match &self.declaration { - Declaration::File { declaration, .. } => match style { - DeclarationStyle::Signature => declaration.signature_range.len(), - DeclarationStyle::Declaration => declaration.text.len(), - }, - Declaration::Buffer { declaration, .. } => match style { - DeclarationStyle::Signature => declaration.signature_range.len(), - DeclarationStyle::Declaration => declaration.item_range.len(), - }, - } - } - - pub fn score_density(&self, style: DeclarationStyle) -> f32 { - self.score(style) / self.size(style) as f32 - } -} - -pub fn scored_declarations( - options: &EditPredictionScoreOptions, - index: &SyntaxIndexState, - excerpt: &EditPredictionExcerpt, - excerpt_occurrences: &Occurrences, - adjacent_occurrences: &Occurrences, - imports: &Imports, - identifier_to_references: HashMap>, - cursor_offset: usize, - current_buffer: &BufferSnapshot, -) -> Vec { - let cursor_point = cursor_offset.to_point(¤t_buffer); - - let mut wildcard_import_occurrences = Vec::new(); - let mut wildcard_import_paths = Vec::new(); - for wildcard_import in imports.wildcard_modules.iter() { - match wildcard_import { - Module::Namespace(namespace) => { - wildcard_import_occurrences.push(namespace.occurrences()) - } - Module::SourceExact(path) => wildcard_import_paths.push(path), - Module::SourceFuzzy(path) => { - wildcard_import_occurrences.push(Occurrences::from_path(&path)) - } - } - } - - let mut scored_declarations = Vec::new(); - let mut project_entry_id_to_outline_ranges: HashMap>> = - HashMap::default(); - for (identifier, references) in identifier_to_references { - let mut import_occurrences = Vec::new(); - let mut import_paths = Vec::new(); - let mut found_external_identifier: Option<&Identifier> = None; - - if let Some(imports) = imports.identifier_to_imports.get(&identifier) { - // only use alias when it's the only import, could be generalized if some 
language - // has overlapping aliases - // - // TODO: when an aliased declaration is included in the prompt, should include the - // aliasing in the prompt. - // - // TODO: For SourceFuzzy consider having componentwise comparison that pays - // attention to ordering. - if let [ - Import::Alias { - module, - external_identifier, - }, - ] = imports.as_slice() - { - match module { - Module::Namespace(namespace) => { - import_occurrences.push(namespace.occurrences()) - } - Module::SourceExact(path) => import_paths.push(path), - Module::SourceFuzzy(path) => { - import_occurrences.push(Occurrences::from_path(&path)) - } - } - found_external_identifier = Some(&external_identifier); - } else { - for import in imports { - match import { - Import::Direct { module } => match module { - Module::Namespace(namespace) => { - import_occurrences.push(namespace.occurrences()) - } - Module::SourceExact(path) => import_paths.push(path), - Module::SourceFuzzy(path) => { - import_occurrences.push(Occurrences::from_path(&path)) - } - }, - Import::Alias { .. } => {} - } - } - } - } - - let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier); - // TODO: update this to be able to return more declarations? Especially if there is the - // ability to quickly filter a large list (based on imports) - let identifier_declarations = index - .declarations_for_identifier::(&identifier_to_lookup); - let declaration_count = identifier_declarations.len(); - - if declaration_count == 0 { - continue; - } - - // TODO: option to filter out other candidates when same file / import match - let mut checked_declarations = Vec::with_capacity(declaration_count); - for (declaration_id, declaration) in identifier_declarations { - match declaration { - Declaration::Buffer { - buffer_id, - declaration: buffer_declaration, - .. 
- } => { - if buffer_id == ¤t_buffer.remote_id() { - let already_included_in_prompt = - range_intersection(&buffer_declaration.item_range, &excerpt.range) - .is_some() - || excerpt - .parent_declarations - .iter() - .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id); - if !options.omit_excerpt_overlaps || !already_included_in_prompt { - let declaration_line = buffer_declaration - .item_range - .start - .to_point(current_buffer) - .row; - let declaration_line_distance = - (cursor_point.row as i32 - declaration_line as i32).unsigned_abs(); - checked_declarations.push(CheckedDeclaration { - declaration, - same_file_line_distance: Some(declaration_line_distance), - path_import_match_count: 0, - wildcard_path_import_match_count: 0, - }); - } - continue; - } else { - } - } - Declaration::File { .. } => {} - } - let declaration_path = declaration.cached_path(); - let path_import_match_count = import_paths - .iter() - .filter(|import_path| { - declaration_path_matches_import(&declaration_path, import_path) - }) - .count(); - let wildcard_path_import_match_count = wildcard_import_paths - .iter() - .filter(|import_path| { - declaration_path_matches_import(&declaration_path, import_path) - }) - .count(); - checked_declarations.push(CheckedDeclaration { - declaration, - same_file_line_distance: None, - path_import_match_count, - wildcard_path_import_match_count, - }); - } - - let mut max_import_similarity = 0.0; - let mut max_wildcard_import_similarity = 0.0; - - let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len()); - for checked_declaration in checked_declarations { - let same_file_declaration_count = - index.file_declaration_count(checked_declaration.declaration); - - let declaration = score_declaration( - &identifier, - &references, - checked_declaration, - same_file_declaration_count, - declaration_count, - &excerpt_occurrences, - &adjacent_occurrences, - &import_occurrences, - &wildcard_import_occurrences, - cursor_point, 
- current_buffer, - ); - - if declaration.components.import_similarity > max_import_similarity { - max_import_similarity = declaration.components.import_similarity; - } - - if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity { - max_wildcard_import_similarity = declaration.components.wildcard_import_similarity; - } - - project_entry_id_to_outline_ranges - .entry(declaration.declaration.project_entry_id()) - .or_default() - .push(declaration.declaration.item_range()); - scored_declarations_for_identifier.push(declaration); - } - - if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 { - for declaration in scored_declarations_for_identifier.iter_mut() { - if max_import_similarity > 0.0 { - declaration.components.max_import_similarity = max_import_similarity; - declaration.components.normalized_import_similarity = - declaration.components.import_similarity / max_import_similarity; - } - if max_wildcard_import_similarity > 0.0 { - declaration.components.normalized_wildcard_import_similarity = - declaration.components.wildcard_import_similarity - / max_wildcard_import_similarity; - } - } - } - - scored_declarations.extend(scored_declarations_for_identifier); - } - - // TODO: Inform this via import / retrieval scores of outline items - // TODO: Consider using a sweepline - for scored_declaration in scored_declarations.iter_mut() { - let project_entry_id = scored_declaration.declaration.project_entry_id(); - let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else { - continue; - }; - for range in ranges { - if range.contains_inclusive(&scored_declaration.declaration.item_range()) { - scored_declaration.components.included_by_others += 1 - } else if scored_declaration - .declaration - .item_range() - .contains_inclusive(range) - { - scored_declaration.components.includes_others += 1 - } - } - } - - scored_declarations.sort_unstable_by_key(|declaration| { - Reverse(OrderedFloat( - 
declaration.score(DeclarationStyle::Declaration), - )) - }); - - scored_declarations -} - -struct CheckedDeclaration<'a> { - declaration: &'a Declaration, - same_file_line_distance: Option, - path_import_match_count: usize, - wildcard_path_import_match_count: usize, -} - -fn declaration_path_matches_import( - declaration_path: &CachedDeclarationPath, - import_path: &Arc, -) -> bool { - if import_path.is_absolute() { - declaration_path.equals_absolute_path(import_path) - } else { - declaration_path.ends_with_posix_path(import_path) - } -} - -fn range_intersection(a: &Range, b: &Range) -> Option> { - let start = a.start.clone().max(b.start.clone()); - let end = a.end.clone().min(b.end.clone()); - if start < end { - Some(Range { start, end }) - } else { - None - } -} - -fn score_declaration( - identifier: &Identifier, - references: &[Reference], - checked_declaration: CheckedDeclaration, - same_file_declaration_count: usize, - declaration_count: usize, - excerpt_occurrences: &Occurrences, - adjacent_occurrences: &Occurrences, - import_occurrences: &[Occurrences], - wildcard_import_occurrences: &[Occurrences], - cursor: Point, - current_buffer: &BufferSnapshot, -) -> ScoredDeclaration { - let CheckedDeclaration { - declaration, - same_file_line_distance, - path_import_match_count, - wildcard_path_import_match_count, - } = checked_declaration; - - let is_referenced_nearby = references - .iter() - .any(|r| r.region == ReferenceRegion::Nearby); - let is_referenced_in_breadcrumb = references - .iter() - .any(|r| r.region == ReferenceRegion::Breadcrumb); - let reference_count = references.len(); - let reference_line_distance = references - .iter() - .map(|r| { - let reference_line = r.range.start.to_point(current_buffer).row as i32; - (cursor.row as i32 - reference_line).unsigned_abs() - }) - .min() - .unwrap(); - - let is_same_file = same_file_line_distance.is_some(); - let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX); - - let 
item_source_occurrences = Occurrences::within_string(&declaration.item_text().0); - let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0); - let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences); - let excerpt_vs_signature_jaccard = - jaccard_similarity(excerpt_occurrences, &item_signature_occurrences); - let adjacent_vs_item_jaccard = - jaccard_similarity(adjacent_occurrences, &item_source_occurrences); - let adjacent_vs_signature_jaccard = - jaccard_similarity(adjacent_occurrences, &item_signature_occurrences); - - let excerpt_vs_item_weighted_overlap = - weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences); - let excerpt_vs_signature_weighted_overlap = - weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences); - let adjacent_vs_item_weighted_overlap = - weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences); - let adjacent_vs_signature_weighted_overlap = - weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences); - - let mut import_similarity = 0f32; - let mut wildcard_import_similarity = 0f32; - if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() { - let cached_path = declaration.cached_path(); - let path_occurrences = Occurrences::from_worktree_path( - cached_path - .worktree_abs_path - .file_name() - .map(|f| f.to_string_lossy()), - &cached_path.rel_path, - ); - import_similarity = import_occurrences - .iter() - .map(|namespace_occurrences| { - OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences)) - }) - .max() - .map(|similarity| similarity.into_inner()) - .unwrap_or_default(); - - // TODO: Consider something other than max - wildcard_import_similarity = wildcard_import_occurrences - .iter() - .map(|namespace_occurrences| { - OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences)) - }) - .max() - .map(|similarity| 
similarity.into_inner()) - .unwrap_or_default(); - } - - // TODO: Consider adding declaration_file_count - let score_components = DeclarationScoreComponents { - is_same_file, - is_referenced_nearby, - is_referenced_in_breadcrumb, - reference_line_distance, - declaration_line_distance, - reference_count, - same_file_declaration_count, - declaration_count, - excerpt_vs_item_jaccard, - excerpt_vs_signature_jaccard, - adjacent_vs_item_jaccard, - adjacent_vs_signature_jaccard, - excerpt_vs_item_weighted_overlap, - excerpt_vs_signature_weighted_overlap, - adjacent_vs_item_weighted_overlap, - adjacent_vs_signature_weighted_overlap, - path_import_match_count, - wildcard_path_import_match_count, - import_similarity, - max_import_similarity: 0.0, - normalized_import_similarity: 0.0, - wildcard_import_similarity, - normalized_wildcard_import_similarity: 0.0, - included_by_others: 0, - includes_others: 0, - }; - - ScoredDeclaration { - identifier: identifier.clone(), - declaration: declaration.clone(), - components: score_components, - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_declaration_path_matches() { - let declaration_path = - CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts"); - - assert!(declaration_path_matches_import( - &declaration_path, - &Path::new("maths.ts").into() - )); - - assert!(declaration_path_matches_import( - &declaration_path, - &Path::new("project/src/maths.ts").into() - )); - - assert!(declaration_path_matches_import( - &declaration_path, - &Path::new("user/project/src/maths.ts").into() - )); - - assert!(declaration_path_matches_import( - &declaration_path, - &Path::new("/home/user/project/src/maths.ts").into() - )); - - assert!(!declaration_path_matches_import( - &declaration_path, - &Path::new("other.ts").into() - )); - - assert!(!declaration_path_matches_import( - &declaration_path, - &Path::new("/home/user/project/src/other.ts").into() - )); - } -} diff --git 
a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index 65623a825c2f7e2db42b98174748e5f04fb91d2a..475050fabb8b17ad76c34234094cf798e36a76ab 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -1,335 +1,490 @@ -mod declaration; -mod declaration_scoring; +use crate::assemble_excerpts::assemble_excerpts; +use anyhow::Result; +use collections::HashMap; +use futures::{FutureExt, StreamExt as _, channel::mpsc, future}; +use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _}; +use project::{LocationLink, Project, ProjectPath}; +use serde::{Serialize, Serializer}; +use smallvec::SmallVec; +use std::{ + collections::hash_map, + ops::Range, + sync::Arc, + time::{Duration, Instant}, +}; +use util::{RangeExt as _, ResultExt}; + +mod assemble_excerpts; +#[cfg(test)] +mod edit_prediction_context_tests; mod excerpt; -mod imports; -mod outline; -mod reference; -mod syntax_index; -pub mod text_similarity; +#[cfg(test)] +mod fake_definition_lsp; -use std::{path::Path, sync::Arc}; +pub use cloud_llm_client::predict_edits_v3::Line; +pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; -use cloud_llm_client::predict_edits_v3; -use collections::HashMap; -use gpui::{App, AppContext as _, Entity, Task}; -use language::BufferSnapshot; -use text::{Point, ToOffset as _}; - -pub use declaration::*; -pub use declaration_scoring::*; -pub use excerpt::*; -pub use imports::*; -pub use reference::*; -pub use syntax_index::*; - -pub use predict_edits_v3::Line; - -#[derive(Clone, Debug, PartialEq)] -pub struct EditPredictionContextOptions { - pub use_imports: bool, - pub excerpt: EditPredictionExcerptOptions, - pub score: EditPredictionScoreOptions, - pub 
max_retrieved_declarations: u8, +const IDENTIFIER_LINE_COUNT: u32 = 3; + +pub struct RelatedExcerptStore { + project: WeakEntity, + related_files: Vec, + cache: HashMap>, + update_tx: mpsc::UnboundedSender<(Entity, Anchor)>, + identifier_line_count: u32, +} + +pub enum RelatedExcerptStoreEvent { + StartedRefresh, + FinishedRefresh { + cache_hit_count: usize, + cache_miss_count: usize, + mean_definition_latency: Duration, + max_definition_latency: Duration, + }, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct Identifier { + pub name: String, + pub range: Range, +} + +enum DefinitionTask { + CacheHit(Arc), + CacheMiss(Task>>>), +} + +#[derive(Debug)] +struct CacheEntry { + definitions: SmallVec<[CachedDefinition; 1]>, } #[derive(Clone, Debug)] -pub struct EditPredictionContext { - pub excerpt: EditPredictionExcerpt, - pub excerpt_text: EditPredictionExcerptText, - pub cursor_point: Point, - pub declarations: Vec, +struct CachedDefinition { + path: ProjectPath, + buffer: Entity, + anchor_range: Range, +} + +#[derive(Clone, Debug, Serialize)] +pub struct RelatedFile { + #[serde(serialize_with = "serialize_project_path")] + pub path: ProjectPath, + #[serde(skip)] + pub buffer: WeakEntity, + pub excerpts: Vec, + pub max_row: u32, } -impl EditPredictionContext { - pub fn gather_context_in_background( - cursor_point: Point, - buffer: BufferSnapshot, - options: EditPredictionContextOptions, - syntax_index: Option>, - cx: &mut App, - ) -> Task> { - let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| { - let mut path = f.worktree.read(cx).absolutize(&f.path); - if path.pop() { Some(path) } else { None } +impl RelatedFile { + pub fn merge_excerpts(&mut self) { + self.excerpts.sort_unstable_by(|a, b| { + a.point_range + .start + .cmp(&b.point_range.start) + .then(b.point_range.end.cmp(&a.point_range.end)) }); - if let Some(syntax_index) = syntax_index { - let index_state = - syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state())); 
- cx.background_spawn(async move { - let parent_abs_path = parent_abs_path.as_deref(); - let index_state = index_state.upgrade()?; - let index_state = index_state.lock().await; - Self::gather_context( - cursor_point, - &buffer, - parent_abs_path, - &options, - Some(&index_state), - ) - }) - } else { - cx.background_spawn(async move { - let parent_abs_path = parent_abs_path.as_deref(); - Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None) - }) + let mut index = 1; + while index < self.excerpts.len() { + if self.excerpts[index - 1] + .point_range + .end + .cmp(&self.excerpts[index].point_range.start) + .is_ge() + { + let removed = self.excerpts.remove(index); + if removed + .point_range + .end + .cmp(&self.excerpts[index - 1].point_range.end) + .is_gt() + { + self.excerpts[index - 1].point_range.end = removed.point_range.end; + self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end; + } + } else { + index += 1; + } } } +} - pub fn gather_context( - cursor_point: Point, - buffer: &BufferSnapshot, - parent_abs_path: Option<&Path>, - options: &EditPredictionContextOptions, - index_state: Option<&SyntaxIndexState>, - ) -> Option { - let imports = if options.use_imports { - Imports::gather(&buffer, parent_abs_path) - } else { - Imports::default() - }; - Self::gather_context_with_references_fn( - cursor_point, - buffer, - &imports, - options, - index_state, - references_in_excerpt, - ) - } +#[derive(Clone, Debug, Serialize)] +pub struct RelatedExcerpt { + #[serde(skip)] + pub anchor_range: Range, + #[serde(serialize_with = "serialize_point_range")] + pub point_range: Range, + #[serde(serialize_with = "serialize_rope")] + pub text: Rope, +} - pub fn gather_context_with_references_fn( - cursor_point: Point, - buffer: &BufferSnapshot, - imports: &Imports, - options: &EditPredictionContextOptions, - index_state: Option<&SyntaxIndexState>, - get_references: impl FnOnce( - &EditPredictionExcerpt, - &EditPredictionExcerptText, - 
&BufferSnapshot, - ) -> HashMap>, - ) -> Option { - let excerpt = EditPredictionExcerpt::select_from_buffer( - cursor_point, - buffer, - &options.excerpt, - index_state, - )?; - let excerpt_text = excerpt.text(buffer); - - let declarations = if options.max_retrieved_declarations > 0 - && let Some(index_state) = index_state - { - let excerpt_occurrences = - text_similarity::Occurrences::within_string(&excerpt_text.body); - - let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0); - let adjacent_end = Point::new(cursor_point.row + 1, 0); - let adjacent_occurrences = text_similarity::Occurrences::within_string( - &buffer - .text_for_range(adjacent_start..adjacent_end) - .collect::(), - ); +fn serialize_project_path( + project_path: &ProjectPath, + serializer: S, +) -> Result { + project_path.path.serialize(serializer) +} - let cursor_offset_in_file = cursor_point.to_offset(buffer); +fn serialize_rope(rope: &Rope, serializer: S) -> Result { + rope.to_string().serialize(serializer) +} - let references = get_references(&excerpt, &excerpt_text, buffer); +fn serialize_point_range( + range: &Range, + serializer: S, +) -> Result { + [ + [range.start.row, range.start.column], + [range.end.row, range.end.column], + ] + .serialize(serializer) +} - let mut declarations = scored_declarations( - &options.score, - &index_state, - &excerpt, - &excerpt_occurrences, - &adjacent_occurrences, - &imports, - references, - cursor_offset_in_file, - buffer, - ); - // TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way - declarations.truncate(options.max_retrieved_declarations as usize); - declarations - } else { - vec![] - }; +const DEBOUNCE_DURATION: Duration = Duration::from_millis(100); + +impl EventEmitter for RelatedExcerptStore {} + +impl RelatedExcerptStore { + pub fn new(project: &Entity, cx: &mut Context) -> Self { + let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity, Anchor)>(); + cx.spawn(async move |this, cx| { + let 
executor = cx.background_executor().clone(); + while let Some((mut buffer, mut position)) = update_rx.next().await { + let mut timer = executor.timer(DEBOUNCE_DURATION).fuse(); + loop { + futures::select_biased! { + next = update_rx.next() => { + if let Some((new_buffer, new_position)) = next { + buffer = new_buffer; + position = new_position; + timer = executor.timer(DEBOUNCE_DURATION).fuse(); + } else { + return anyhow::Ok(()); + } + } + _ = timer => break, + } + } - Some(Self { - excerpt, - excerpt_text, - cursor_point, - declarations, + Self::fetch_excerpts(this.clone(), buffer, position, cx).await?; + } + anyhow::Ok(()) }) + .detach_and_log_err(cx); + + RelatedExcerptStore { + project: project.downgrade(), + update_tx, + related_files: Vec::new(), + cache: Default::default(), + identifier_line_count: IDENTIFIER_LINE_COUNT, + } } -} -#[cfg(test)] -mod tests { - use super::*; - use std::sync::Arc; - - use gpui::{Entity, TestAppContext}; - use indoc::indoc; - use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use util::path; - - use crate::{EditPredictionExcerptOptions, SyntaxIndex}; - - #[gpui::test] - async fn test_call_site(cx: &mut TestAppContext) { - let (project, index, _rust_lang_id) = init_test(cx).await; - - let buffer = project - .update(cx, |project, cx| { - let project_path = project.find_project_path("c.rs", cx).unwrap(); - project.open_buffer(project_path, cx) - }) - .await - .unwrap(); - - cx.run_until_parked(); - - // first process_data call site - let cursor_point = language::Point::new(8, 21); - let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); - - let context = cx - .update(|cx| { - EditPredictionContext::gather_context_in_background( - cursor_point, - buffer_snapshot, - EditPredictionContextOptions { - use_imports: true, - excerpt: EditPredictionExcerptOptions { - max_bytes: 60, - min_bytes: 
10, - target_before_cursor_over_total_bytes: 0.5, - }, - score: EditPredictionScoreOptions { - omit_excerpt_overlaps: true, - }, - max_retrieved_declarations: u8::MAX, - }, - Some(index.clone()), - cx, - ) - }) - .await - .unwrap(); - - let mut snippet_identifiers = context - .declarations - .iter() - .map(|snippet| snippet.identifier.name.as_ref()) - .collect::>(); - snippet_identifiers.sort(); - assert_eq!(snippet_identifiers, vec!["main", "process_data"]); - drop(buffer); + pub fn set_identifier_line_count(&mut self, count: u32) { + self.identifier_line_count = count; } - async fn init_test( - cx: &mut TestAppContext, - ) -> (Entity, Entity, LanguageId) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - }); + pub fn refresh(&mut self, buffer: Entity, position: Anchor, _: &mut Context) { + self.update_tx.unbounded_send((buffer, position)).ok(); + } - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "a.rs": indoc! {r#" - fn main() { - let x = 1; - let y = 2; - let z = add(x, y); - println!("Result: {}", z); - } + pub fn related_files(&self) -> &[RelatedFile] { + &self.related_files + } - fn add(a: i32, b: i32) -> i32 { - a + b - } - "#}, - "b.rs": indoc! {" - pub struct Config { - pub name: String, - pub value: i32, - } + async fn fetch_excerpts( + this: WeakEntity, + buffer: Entity, + position: Anchor, + cx: &mut AsyncApp, + ) -> Result<()> { + let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| { + ( + this.project.upgrade(), + buffer.read(cx).snapshot(), + this.identifier_line_count, + ) + })?; + let Some(project) = project else { + return Ok(()); + }; - impl Config { - pub fn new(name: String, value: i32) -> Self { - Config { name, value } - } - } - "}, - "c.rs": indoc! {r#" - use std::collections::HashMap; - - fn main() { - let args: Vec = std::env::args().collect(); - let data: Vec = args[1..] 
- .iter() - .filter_map(|s| s.parse().ok()) - .collect(); - let result = process_data(data); - println!("{:?}", result); - } + let file = snapshot.file().cloned(); + if let Some(file) = &file { + log::debug!("retrieving_context buffer:{}", file.path().as_unix_str()); + } + + this.update(cx, |_, cx| { + cx.emit(RelatedExcerptStoreEvent::StartedRefresh); + })?; - fn process_data(data: Vec) -> HashMap { - let mut counts = HashMap::new(); - for value in data { - *counts.entry(value).or_insert(0) += 1; + let identifiers = cx + .background_spawn(async move { + identifiers_for_position(&snapshot, position, identifier_line_count) + }) + .await; + + let async_cx = cx.clone(); + let start_time = Instant::now(); + let futures = this.update(cx, |this, cx| { + identifiers + .into_iter() + .filter_map(|identifier| { + let task = if let Some(entry) = this.cache.get(&identifier) { + DefinitionTask::CacheHit(entry.clone()) + } else { + DefinitionTask::CacheMiss( + this.project + .update(cx, |project, cx| { + project.definitions(&buffer, identifier.range.start, cx) + }) + .ok()?, + ) + }; + + let cx = async_cx.clone(); + let project = project.clone(); + Some(async move { + match task { + DefinitionTask::CacheHit(cache_entry) => { + Some((identifier, cache_entry, None)) + } + DefinitionTask::CacheMiss(task) => { + let locations = task.await.log_err()??; + let duration = start_time.elapsed(); + cx.update(|cx| { + ( + identifier, + Arc::new(CacheEntry { + definitions: locations + .into_iter() + .filter_map(|location| { + process_definition(location, &project, cx) + }) + .collect(), + }), + Some(duration), + ) + }) + .ok() + } } - counts - } + }) + }) + .collect::>() + })?; + + let mut cache_hit_count = 0; + let mut cache_miss_count = 0; + let mut mean_definition_latency = Duration::ZERO; + let mut max_definition_latency = Duration::ZERO; + let mut new_cache = HashMap::default(); + new_cache.reserve(futures.len()); + for (identifier, entry, duration) in 
future::join_all(futures).await.into_iter().flatten() { + new_cache.insert(identifier, entry); + if let Some(duration) = duration { + cache_miss_count += 1; + mean_definition_latency += duration; + max_definition_latency = max_definition_latency.max(duration); + } else { + cache_hit_count += 1; + } + } + mean_definition_latency /= cache_miss_count.max(1) as u32; - #[cfg(test)] - mod tests { - use super::*; + let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?; - #[test] - fn test_process_data() { - let data = vec![1, 2, 2, 3]; - let result = process_data(data); - assert_eq!(result.get(&2), Some(&2)); - } - } - "#} - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let lang = rust_lang(); - let lang_id = lang.id(); - language_registry.add(Arc::new(lang)); - - let file_indexing_parallelism = 2; - let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx)); - cx.run_until_parked(); - - (project, index, lang_id) + if let Some(file) = &file { + log::debug!( + "finished retrieving context buffer:{}, latency:{:?}", + file.path().as_unix_str(), + start_time.elapsed() + ); + } + + this.update(cx, |this, cx| { + this.cache = new_cache; + this.related_files = related_files; + cx.emit(RelatedExcerptStoreEvent::FinishedRefresh { + cache_hit_count, + cache_miss_count, + mean_definition_latency, + max_definition_latency, + }); + })?; + + anyhow::Ok(()) + } +} + +async fn rebuild_related_files( + new_entries: HashMap>, + cx: &mut AsyncApp, +) -> Result<(HashMap>, Vec)> { + let mut snapshots = HashMap::default(); + for entry in new_entries.values() { + for definition in &entry.definitions { + if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) { + definition + .buffer + .read_with(cx, |buffer, _| buffer.parsing_idle())? 
+ .await; + e.insert( + definition + .buffer + .read_with(cx, |buffer, _| buffer.snapshot())?, + ); + } + } } - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) - .unwrap() - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() + Ok(cx + .background_spawn(async move { + let mut files = Vec::::new(); + let mut ranges_by_buffer = HashMap::<_, Vec>>::default(); + let mut paths_by_buffer = HashMap::default(); + for entry in new_entries.values() { + for definition in &entry.definitions { + let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else { + continue; + }; + paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone()); + ranges_by_buffer + .entry(definition.buffer.clone()) + .or_default() + .push(definition.anchor_range.to_point(snapshot)); + } + } + + for (buffer, ranges) in ranges_by_buffer { + let Some(snapshot) = snapshots.get(&buffer.entity_id()) else { + continue; + }; + let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else { + continue; + }; + let excerpts = assemble_excerpts(snapshot, ranges); + files.push(RelatedFile { + path: project_path.clone(), + buffer: buffer.downgrade(), + excerpts, + max_row: snapshot.max_point().row, + }); + } + + files.sort_by_key(|file| file.path.clone()); + (new_entries, files) + }) + .await) +} + +fn process_definition( + location: LocationLink, + project: &Entity, + cx: &mut App, +) -> Option { + let buffer = location.target.buffer.read(cx); + let anchor_range = location.target.range; + let file = buffer.file()?; + let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?; + if worktree.read(cx).is_single_file() { + return None; } + 
Some(CachedDefinition { + path: ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }, + buffer: location.target.buffer, + anchor_range, + }) +} + +/// Gets all of the identifiers that are present in the given line, and its containing +/// outline items. +fn identifiers_for_position( + buffer: &BufferSnapshot, + position: Anchor, + identifier_line_count: u32, +) -> Vec { + let offset = position.to_offset(buffer); + let point = buffer.offset_to_point(offset); + + // Search for identifiers on lines adjacent to the cursor. + let start = Point::new(point.row.saturating_sub(identifier_line_count), 0); + let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point()); + let line_range = start..end; + let mut ranges = vec![line_range.to_offset(&buffer)]; + + // Search for identifiers mentioned in headers/signatures of containing outline items. + let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None); + for item in outline_items { + if let Some(body_range) = item.body_range(&buffer) { + ranges.push(item.range.start..body_range.start.to_offset(&buffer)); + } else { + ranges.push(item.range.clone()); + } + } + + ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end))); + ranges.dedup_by(|a, b| { + if a.start <= b.end { + b.start = b.start.min(a.start); + b.end = b.end.max(a.end); + true + } else { + false + } + }); + + let mut identifiers = Vec::new(); + let outer_range = + ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end); + + let mut captures = buffer + .syntax + .captures(outer_range.clone(), &buffer.text, |grammar| { + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) + }); + + for range in ranges { + captures.set_byte_range(range.start..outer_range.end); + + let mut last_range = None; + while let Some(capture) = captures.peek() { + let node_range = capture.node.byte_range(); + if node_range.start > range.end { + 
break; + } + let config = captures.grammars()[capture.grammar_index] + .highlights_config + .as_ref(); + + if let Some(config) = config + && config.identifier_capture_indices.contains(&capture.index) + && range.contains_inclusive(&node_range) + && Some(&node_range) != last_range.as_ref() + { + let name = buffer.text_for_range(node_range.clone()).collect(); + identifiers.push(Identifier { + range: buffer.anchor_after(node_range.start) + ..buffer.anchor_before(node_range.end), + name, + }); + last_range = Some(node_range); + } + + captures.advance(); + } + } + + identifiers } diff --git a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..dba8d89e593ccb60e7eae5d091708e82debef0f5 --- /dev/null +++ b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs @@ -0,0 +1,510 @@ +use super::*; +use futures::channel::mpsc::UnboundedReceiver; +use gpui::TestAppContext; +use indoc::indoc; +use language::{Point, ToPoint as _, rust_lang}; +use lsp::FakeLanguageServer; +use project::{FakeFs, LocationLink, Project}; +use serde_json::json; +use settings::SettingsStore; +use std::fmt::Write as _; +use util::{path, test::marked_text_ranges}; + +#[gpui::test] +async fn test_edit_prediction_context(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/root"), test_project_1()).await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let mut servers = setup_fake_lsp(&project, cx); + + let (buffer, _handle) = project + .update(cx, |project, cx| { + project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + let _server = servers.next().await.unwrap(); + cx.run_until_parked(); + + let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(&project, cx)); + related_excerpt_store.update(cx, 
|store, cx| { + let position = { + let buffer = buffer.read(cx); + let offset = buffer.text().find("todo").unwrap(); + buffer.anchor_before(offset) + }; + + store.set_identifier_line_count(0); + store.refresh(buffer.clone(), position, cx); + }); + + cx.executor().advance_clock(DEBOUNCE_DURATION); + related_excerpt_store.update(cx, |store, _| { + let excerpts = store.related_files(); + assert_related_files( + &excerpts, + &[ + ( + "src/company.rs", + &[indoc! {" + pub struct Company { + owner: Arc, + address: Address, + }"}], + ), + ( + "src/main.rs", + &[ + indoc! {" + pub struct Session { + company: Arc, + } + + impl Session { + pub fn set_company(&mut self, company: Arc) {"}, + indoc! {" + } + }"}, + ], + ), + ( + "src/person.rs", + &[ + indoc! {" + impl Person { + pub fn get_first_name(&self) -> &str { + &self.first_name + }"}, + "}", + ], + ), + ], + ); + }); +} + +#[gpui::test] +fn test_assemble_excerpts(cx: &mut TestAppContext) { + let table = [ + ( + indoc! {r#" + struct User { + first_name: String, + «last_name»: String, + age: u32, + email: String, + create_at: Instant, + } + + impl User { + pub fn first_name(&self) -> String { + self.first_name.clone() + } + + pub fn full_name(&self) -> String { + « format!("{} {}", self.first_name, self.last_name) + » } + } + "#}, + indoc! {r#" + struct User { + first_name: String, + last_name: String, + … + } + + impl User { + … + pub fn full_name(&self) -> String { + format!("{} {}", self.first_name, self.last_name) + } + } + "#}, + ), + ( + indoc! {r#" + struct «User» { + first_name: String, + last_name: String, + age: u32, + } + + impl User { + // methods + } + "#}, + indoc! {r#" + struct User { + first_name: String, + last_name: String, + age: u32, + } + … + "#}, + ), + ( + indoc! 
{r#" + trait «FooProvider» { + const NAME: &'static str; + + fn provide_foo(&self, id: usize) -> Foo; + + fn provide_foo_batched(&self, ids: &[usize]) -> Vec { + ids.iter() + .map(|id| self.provide_foo(*id)) + .collect() + } + + fn sync(&self); + } + "# + }, + indoc! {r#" + trait FooProvider { + const NAME: &'static str; + + fn provide_foo(&self, id: usize) -> Foo; + + fn provide_foo_batched(&self, ids: &[usize]) -> Vec { + … + } + + fn sync(&self); + } + "#}, + ), + ( + indoc! {r#" + trait «Something» { + fn method1(&self, id: usize) -> Foo; + + fn method2(&self, ids: &[usize]) -> Vec { + struct Helper1 { + field1: usize, + } + + struct Helper2 { + field2: usize, + } + + struct Helper3 { + filed2: usize, + } + } + + fn sync(&self); + } + "# + }, + indoc! {r#" + trait Something { + fn method1(&self, id: usize) -> Foo; + + fn method2(&self, ids: &[usize]) -> Vec { + … + } + + fn sync(&self); + } + "#}, + ), + ]; + + for (input, expected_output) in table { + let (input, ranges) = marked_text_ranges(&input, false); + let buffer = cx.new(|cx| Buffer::local(input, cx).with_language(rust_lang(), cx)); + buffer.read_with(cx, |buffer, _cx| { + let ranges: Vec> = ranges + .into_iter() + .map(|range| range.to_point(&buffer)) + .collect(); + + let excerpts = assemble_excerpts(&buffer.snapshot(), ranges); + + let output = format_excerpts(buffer, &excerpts); + assert_eq!(output, expected_output); + }); + } +} + +#[gpui::test] +async fn test_fake_definition_lsp(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/root"), test_project_1()).await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let mut servers = setup_fake_lsp(&project, cx); + + let (buffer, _handle) = project + .update(cx, |project, cx| { + project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + let _server = servers.next().await.unwrap(); + cx.run_until_parked(); + + let 
buffer_text = buffer.read_with(cx, |buffer, _| buffer.text()); + + let definitions = project + .update(cx, |project, cx| { + let offset = buffer_text.find("Address {").unwrap(); + project.definitions(&buffer, offset, cx) + }) + .await + .unwrap() + .unwrap(); + assert_definitions(&definitions, &["pub struct Address {"], cx); + + let definitions = project + .update(cx, |project, cx| { + let offset = buffer_text.find("State::CA").unwrap(); + project.definitions(&buffer, offset, cx) + }) + .await + .unwrap() + .unwrap(); + assert_definitions(&definitions, &["pub enum State {"], cx); + + let definitions = project + .update(cx, |project, cx| { + let offset = buffer_text.find("to_string()").unwrap(); + project.definitions(&buffer, offset, cx) + }) + .await + .unwrap() + .unwrap(); + assert_definitions(&definitions, &["pub fn to_string(&self) -> String {"], cx); +} + +fn init_test(cx: &mut TestAppContext) { + let settings_store = cx.update(|cx| SettingsStore::test(cx)); + cx.set_global(settings_store); + env_logger::try_init().ok(); +} + +fn setup_fake_lsp( + project: &Entity, + cx: &mut TestAppContext, +) -> UnboundedReceiver { + let (language_registry, fs) = project.read_with(cx, |project, _| { + (project.languages().clone(), project.fs().clone()) + }); + let language = rust_lang(); + language_registry.add(language.clone()); + fake_definition_lsp::register_fake_definition_server(&language_registry, language, fs) +} + +fn test_project_1() -> serde_json::Value { + let person_rs = indoc! {r#" + pub struct Person { + first_name: String, + last_name: String, + email: String, + age: u32, + } + + impl Person { + pub fn get_first_name(&self) -> &str { + &self.first_name + } + + pub fn get_last_name(&self) -> &str { + &self.last_name + } + + pub fn get_email(&self) -> &str { + &self.email + } + + pub fn get_age(&self) -> u32 { + self.age + } + } + "#}; + + let address_rs = indoc! 
{r#" + pub struct Address { + street: String, + city: String, + state: State, + zip: u32, + } + + pub enum State { + CA, + OR, + WA, + TX, + // ... + } + + impl Address { + pub fn get_street(&self) -> &str { + &self.street + } + + pub fn get_city(&self) -> &str { + &self.city + } + + pub fn get_state(&self) -> State { + self.state + } + + pub fn get_zip(&self) -> u32 { + self.zip + } + } + "#}; + + let company_rs = indoc! {r#" + use super::person::Person; + use super::address::Address; + + pub struct Company { + owner: Arc, + address: Address, + } + + impl Company { + pub fn get_owner(&self) -> &Person { + &self.owner + } + + pub fn get_address(&self) -> &Address { + &self.address + } + + pub fn to_string(&self) -> String { + format!("{} ({})", self.owner.first_name, self.address.city) + } + } + "#}; + + let main_rs = indoc! {r#" + use std::sync::Arc; + use super::person::Person; + use super::address::Address; + use super::company::Company; + + pub struct Session { + company: Arc, + } + + impl Session { + pub fn set_company(&mut self, company: Arc) { + self.company = company; + if company.owner != self.company.owner { + log("new owner", company.owner.get_first_name()); todo(); + } + } + } + + fn main() { + let company = Company { + owner: Arc::new(Person { + first_name: "John".to_string(), + last_name: "Doe".to_string(), + email: "john@example.com".to_string(), + age: 30, + }), + address: Address { + street: "123 Main St".to_string(), + city: "Anytown".to_string(), + state: State::CA, + zip: 12345, + }, + }; + + println!("Company: {}", company.to_string()); + } + "#}; + + json!({ + "src": { + "person.rs": person_rs, + "address.rs": address_rs, + "company.rs": company_rs, + "main.rs": main_rs, + }, + }) +} + +fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &[&str])]) { + let actual_files = actual_files + .iter() + .map(|file| { + let excerpts = file + .excerpts + .iter() + .map(|excerpt| excerpt.text.to_string()) + .collect::>(); + 
(file.path.path.as_unix_str(), excerpts) + }) + .collect::>(); + let expected_excerpts = expected_files + .iter() + .map(|(path, texts)| { + ( + *path, + texts + .iter() + .map(|line| line.to_string()) + .collect::>(), + ) + }) + .collect::>(); + pretty_assertions::assert_eq!(actual_files, expected_excerpts) +} + +fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &mut TestAppContext) { + let actual_first_lines = definitions + .iter() + .map(|definition| { + definition.target.buffer.read_with(cx, |buffer, _| { + let mut start = definition.target.range.start.to_point(&buffer); + start.column = 0; + let end = Point::new(start.row, buffer.line_len(start.row)); + buffer + .text_for_range(start..end) + .collect::() + .trim() + .to_string() + }) + }) + .collect::>(); + + assert_eq!(actual_first_lines, first_lines); +} + +fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String { + let mut output = String::new(); + let file_line_count = buffer.max_point().row; + let mut current_row = 0; + for excerpt in excerpts { + if excerpt.text.is_empty() { + continue; + } + if current_row < excerpt.point_range.start.row { + writeln!(&mut output, "…").unwrap(); + } + current_row = excerpt.point_range.start.row; + + for line in excerpt.text.to_string().lines() { + output.push_str(line); + output.push('\n'); + current_row += 1; + } + } + if current_row < file_line_count { + writeln!(&mut output, "…").unwrap(); + } + output +} diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs index 7a4bb73edfa131b620a930d7f0e1c0da77e0afe6..3fc7eed4ace5a83992bf496aef3e364aea96e215 100644 --- a/crates/edit_prediction_context/src/excerpt.rs +++ b/crates/edit_prediction_context/src/excerpt.rs @@ -1,11 +1,9 @@ -use language::{BufferSnapshot, LanguageId}; +use cloud_llm_client::predict_edits_v3::Line; +use language::{BufferSnapshot, LanguageId, Point, ToOffset as _, ToPoint as _}; use std::ops::Range; -use 
text::{Point, ToOffset as _, ToPoint as _}; use tree_sitter::{Node, TreeCursor}; use util::RangeExt; -use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState}; - // TODO: // // - Test parent signatures @@ -31,19 +29,16 @@ pub struct EditPredictionExcerptOptions { pub target_before_cursor_over_total_bytes: f32, } -// TODO: consider merging these #[derive(Debug, Clone)] pub struct EditPredictionExcerpt { pub range: Range, pub line_range: Range, - pub parent_declarations: Vec<(DeclarationId, Range)>, pub size: usize, } #[derive(Debug, Clone)] pub struct EditPredictionExcerptText { pub body: String, - pub parent_signatures: Vec, pub language_id: Option, } @@ -52,17 +47,8 @@ impl EditPredictionExcerpt { let body = buffer .text_for_range(self.range.clone()) .collect::(); - let parent_signatures = self - .parent_declarations - .iter() - .map(|(_, range)| buffer.text_for_range(range.clone()).collect::()) - .collect(); let language_id = buffer.language().map(|l| l.id()); - EditPredictionExcerptText { - body, - parent_signatures, - language_id, - } + EditPredictionExcerptText { body, language_id } } /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based @@ -79,7 +65,6 @@ impl EditPredictionExcerpt { query_point: Point, buffer: &BufferSnapshot, options: &EditPredictionExcerptOptions, - syntax_index: Option<&SyntaxIndexState>, ) -> Option { if buffer.len() <= options.max_bytes { log::debug!( @@ -89,11 +74,7 @@ impl EditPredictionExcerpt { ); let offset_range = 0..buffer.len(); let line_range = Line(0)..Line(buffer.max_point().row); - return Some(EditPredictionExcerpt::new( - offset_range, - line_range, - Vec::new(), - )); + return Some(EditPredictionExcerpt::new(offset_range, line_range)); } let query_offset = query_point.to_offset(buffer); @@ -104,19 +85,10 @@ impl EditPredictionExcerpt { return None; } - let parent_declarations = if let Some(syntax_index) = syntax_index { - syntax_index - 
.buffer_declarations_containing_range(buffer.remote_id(), query_range.clone()) - .collect() - } else { - Vec::new() - }; - let excerpt_selector = ExcerptSelector { query_offset, query_range, query_line_range: Line(query_line_range.start)..Line(query_line_range.end), - parent_declarations: &parent_declarations, buffer, options, }; @@ -139,20 +111,10 @@ impl EditPredictionExcerpt { excerpt_selector.select_lines() } - fn new( - range: Range, - line_range: Range, - parent_declarations: Vec<(DeclarationId, Range)>, - ) -> Self { - let size = range.len() - + parent_declarations - .iter() - .map(|(_, range)| range.len()) - .sum::(); + fn new(range: Range, line_range: Range) -> Self { Self { + size: range.len(), range, - parent_declarations, - size, line_range, } } @@ -162,14 +124,7 @@ impl EditPredictionExcerpt { // this is an issue because parent_signature_ranges may be incorrect log::error!("bug: with_expanded_range called with disjoint range"); } - let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len()); - for (declaration_id, range) in &self.parent_declarations { - if !range.contains_inclusive(&new_range) { - break; - } - parent_declarations.push((*declaration_id, range.clone())); - } - Self::new(new_range, new_line_range, parent_declarations) + Self::new(new_range, new_line_range) } fn parent_signatures_size(&self) -> usize { @@ -181,7 +136,6 @@ struct ExcerptSelector<'a> { query_offset: usize, query_range: Range, query_line_range: Range, - parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)], buffer: &'a BufferSnapshot, options: &'a EditPredictionExcerptOptions, } @@ -409,13 +363,7 @@ impl<'a> ExcerptSelector<'a> { } fn make_excerpt(&self, range: Range, line_range: Range) -> EditPredictionExcerpt { - let parent_declarations = self - .parent_declarations - .iter() - .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range)) - .map(|(id, declaration)| (*id, declaration.signature_range.clone())) - .collect(); - 
EditPredictionExcerpt::new(range, line_range, parent_declarations) + EditPredictionExcerpt::new(range, line_range) } /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt. @@ -471,30 +419,14 @@ fn node_line_end(node: Node) -> Point { mod tests { use super::*; use gpui::{AppContext, TestAppContext}; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; + use language::Buffer; use util::test::{generate_marked_text, marked_text_offsets_by}; fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot { - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx)); buffer.read_with(cx, |buffer, _| buffer.snapshot()) } - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } - fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range) { let (text, offsets) = marked_text_offsets_by(text, vec!['ˇ', '«', '»']); (text, offsets[&'ˇ'][0], offsets[&'«'][0]..offsets[&'»'][0]) @@ -506,9 +438,8 @@ mod tests { let buffer = create_buffer(&text, cx); let cursor_point = cursor.to_point(&buffer); - let excerpt = - EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None) - .expect("Should select an excerpt"); + let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options) + .expect("Should select an excerpt"); pretty_assertions::assert_eq!( generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false), generate_marked_text(&text, &[expected_excerpt], false) diff --git a/crates/edit_prediction_context/src/fake_definition_lsp.rs 
b/crates/edit_prediction_context/src/fake_definition_lsp.rs new file mode 100644 index 0000000000000000000000000000000000000000..31fb681309c610a37c7f886390ef5adb92ee78ef --- /dev/null +++ b/crates/edit_prediction_context/src/fake_definition_lsp.rs @@ -0,0 +1,329 @@ +use collections::HashMap; +use futures::channel::mpsc::UnboundedReceiver; +use language::{Language, LanguageRegistry}; +use lsp::{ + FakeLanguageServer, LanguageServerBinary, TextDocumentSyncCapability, TextDocumentSyncKind, Uri, +}; +use parking_lot::Mutex; +use project::Fs; +use std::{ops::Range, path::PathBuf, sync::Arc}; +use tree_sitter::{Parser, QueryCursor, StreamingIterator, Tree}; + +/// Registers a fake language server that implements go-to-definition using tree-sitter, +/// making the assumption that all names are unique, and all variables' types are +/// explicitly declared. +pub fn register_fake_definition_server( + language_registry: &Arc, + language: Arc, + fs: Arc, +) -> UnboundedReceiver { + let index = Arc::new(Mutex::new(DefinitionIndex::new(language.clone()))); + + language_registry.register_fake_lsp( + language.name(), + language::FakeLspAdapter { + name: "fake-definition-lsp", + initialization_options: None, + prettier_plugins: Vec::new(), + disk_based_diagnostics_progress_token: None, + disk_based_diagnostics_sources: Vec::new(), + language_server_binary: LanguageServerBinary { + path: PathBuf::from("fake-definition-lsp"), + arguments: Vec::new(), + env: None, + }, + capabilities: lsp::ServerCapabilities { + definition_provider: Some(lsp::OneOf::Left(true)), + text_document_sync: Some(TextDocumentSyncCapability::Kind( + TextDocumentSyncKind::FULL, + )), + ..Default::default() + }, + label_for_completion: None, + initializer: Some(Box::new({ + move |server| { + server.handle_notification::({ + let index = index.clone(); + move |params, _cx| { + index + .lock() + .open_buffer(params.text_document.uri, ¶ms.text_document.text); + } + }); + + server.handle_notification::({ + let index 
= index.clone(); + let fs = fs.clone(); + move |params, cx| { + let uri = params.text_document.uri; + let path = uri.to_file_path().ok(); + index.lock().mark_buffer_closed(&uri); + + if let Some(path) = path { + let index = index.clone(); + let fs = fs.clone(); + cx.spawn(async move |_cx| { + if let Ok(content) = fs.load(&path).await { + index.lock().index_file(uri, &content); + } + }) + .detach(); + } + } + }); + + server.handle_notification::({ + let index = index.clone(); + let fs = fs.clone(); + move |params, cx| { + let index = index.clone(); + let fs = fs.clone(); + cx.spawn(async move |_cx| { + for event in params.changes { + if index.lock().is_buffer_open(&event.uri) { + continue; + } + + match event.typ { + lsp::FileChangeType::DELETED => { + index.lock().remove_definitions_for_file(&event.uri); + } + lsp::FileChangeType::CREATED + | lsp::FileChangeType::CHANGED => { + if let Some(path) = event.uri.to_file_path().ok() { + if let Ok(content) = fs.load(&path).await { + index.lock().index_file(event.uri, &content); + } + } + } + _ => {} + } + } + }) + .detach(); + } + }); + + server.handle_notification::({ + let index = index.clone(); + move |params, _cx| { + if let Some(change) = params.content_changes.into_iter().last() { + index + .lock() + .index_file(params.text_document.uri, &change.text); + } + } + }); + + server.handle_notification::( + { + let index = index.clone(); + let fs = fs.clone(); + move |params, cx| { + let index = index.clone(); + let fs = fs.clone(); + let files = fs.as_fake().files(); + cx.spawn(async move |_cx| { + for folder in params.event.added { + let Ok(path) = folder.uri.to_file_path() else { + continue; + }; + for file in &files { + if let Some(uri) = Uri::from_file_path(&file).ok() + && file.starts_with(&path) + && let Ok(content) = fs.load(&file).await + { + index.lock().index_file(uri, &content); + } + } + } + }) + .detach(); + } + }, + ); + + server.set_request_handler::({ + let index = index.clone(); + move |params, _cx| { + 
let result = index.lock().get_definitions( + params.text_document_position_params.text_document.uri, + params.text_document_position_params.position, + ); + async move { Ok(result) } + } + }); + } + })), + }, + ) +} + +struct DefinitionIndex { + language: Arc, + definitions: HashMap>, + files: HashMap, +} + +#[derive(Debug)] +struct FileEntry { + contents: String, + is_open_in_buffer: bool, +} + +impl DefinitionIndex { + fn new(language: Arc) -> Self { + Self { + language, + definitions: HashMap::default(), + files: HashMap::default(), + } + } + + fn remove_definitions_for_file(&mut self, uri: &Uri) { + self.definitions.retain(|_, locations| { + locations.retain(|loc| &loc.uri != uri); + !locations.is_empty() + }); + self.files.remove(uri); + } + + fn open_buffer(&mut self, uri: Uri, content: &str) { + self.index_file_inner(uri, content, true); + } + + fn mark_buffer_closed(&mut self, uri: &Uri) { + if let Some(entry) = self.files.get_mut(uri) { + entry.is_open_in_buffer = false; + } + } + + fn is_buffer_open(&self, uri: &Uri) -> bool { + self.files + .get(uri) + .map(|entry| entry.is_open_in_buffer) + .unwrap_or(false) + } + + fn index_file(&mut self, uri: Uri, content: &str) { + self.index_file_inner(uri, content, false); + } + + fn index_file_inner(&mut self, uri: Uri, content: &str, is_open_in_buffer: bool) -> Option<()> { + self.remove_definitions_for_file(&uri); + let grammar = self.language.grammar()?; + let outline_config = grammar.outline_config.as_ref()?; + let mut parser = Parser::new(); + parser.set_language(&grammar.ts_language).ok()?; + let tree = parser.parse(content, None)?; + let declarations = extract_declarations_from_tree(&tree, content, outline_config); + for (name, byte_range) in declarations { + let range = byte_range_to_lsp_range(content, byte_range); + let location = lsp::Location { + uri: uri.clone(), + range, + }; + self.definitions + .entry(name) + .or_insert_with(Vec::new) + .push(location); + } + self.files.insert( + uri, + FileEntry { 
+ contents: content.to_string(), + is_open_in_buffer, + }, + ); + + Some(()) + } + + fn get_definitions( + &mut self, + uri: Uri, + position: lsp::Position, + ) -> Option { + let entry = self.files.get(&uri)?; + let name = word_at_position(&entry.contents, position)?; + let locations = self.definitions.get(name).cloned()?; + Some(lsp::GotoDefinitionResponse::Array(locations)) + } +} + +fn extract_declarations_from_tree( + tree: &Tree, + content: &str, + outline_config: &language::OutlineConfig, +) -> Vec<(String, Range)> { + let mut cursor = QueryCursor::new(); + let mut declarations = Vec::new(); + let mut matches = cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()); + while let Some(query_match) = matches.next() { + let mut name_range: Option> = None; + let mut has_item_range = false; + + for capture in query_match.captures { + let range = capture.node.byte_range(); + if capture.index == outline_config.name_capture_ix { + name_range = Some(range); + } else if capture.index == outline_config.item_capture_ix { + has_item_range = true; + } + } + + if let Some(name_range) = name_range + && has_item_range + { + let name = content[name_range.clone()].to_string(); + if declarations.iter().any(|(n, _)| n == &name) { + continue; + } + declarations.push((name, name_range)); + } + } + declarations +} + +fn byte_range_to_lsp_range(content: &str, byte_range: Range) -> lsp::Range { + let start = byte_offset_to_position(content, byte_range.start); + let end = byte_offset_to_position(content, byte_range.end); + lsp::Range { start, end } +} + +fn byte_offset_to_position(content: &str, offset: usize) -> lsp::Position { + let mut line = 0; + let mut character = 0; + let mut current_offset = 0; + for ch in content.chars() { + if current_offset >= offset { + break; + } + if ch == '\n' { + line += 1; + character = 0; + } else { + character += 1; + } + current_offset += ch.len_utf8(); + } + lsp::Position { line, character } +} + +fn word_at_position(content: 
&str, position: lsp::Position) -> Option<&str> { + let mut lines = content.lines(); + let line = lines.nth(position.line as usize)?; + let column = position.character as usize; + if column > line.len() { + return None; + } + let start = line[..column] + .rfind(|c: char| !c.is_alphanumeric() && c != '_') + .map(|i| i + 1) + .unwrap_or(0); + let end = line[column..] + .find(|c: char| !c.is_alphanumeric() && c != '_') + .map(|i| i + column) + .unwrap_or(line.len()); + Some(&line[start..end]).filter(|word| !word.is_empty()) +} diff --git a/crates/edit_prediction_context/src/imports.rs b/crates/edit_prediction_context/src/imports.rs deleted file mode 100644 index 70f175159340ddb9a6f26f23db0c1b3c843e7b96..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/imports.rs +++ /dev/null @@ -1,1319 +0,0 @@ -use collections::HashMap; -use language::BufferSnapshot; -use language::ImportsConfig; -use language::Language; -use std::ops::Deref; -use std::path::Path; -use std::sync::Arc; -use std::{borrow::Cow, ops::Range}; -use text::OffsetRangeExt as _; -use util::RangeExt; -use util::paths::PathStyle; - -use crate::Identifier; -use crate::text_similarity::Occurrences; - -// TODO: Write documentation for extension authors. The @import capture must match before or in the -// same pattern as all all captures it contains - -// Future improvements to consider: -// -// * Distinguish absolute vs relative paths in captures. `#include "maths.h"` is relative whereas -// `#include ` is not. -// -// * Provide the name used when importing whole modules (see tests with "named_module" in the name). -// To be useful, will require parsing of identifier qualification. -// -// * Scoping for imports that aren't at the top level -// -// * Only scan a prefix of the file, when possible. This could look like having query matches that -// indicate it reached a declaration that is not allowed in the import section. 
-// -// * Support directly parsing to occurrences instead of storing namespaces / paths. Types should be -// generic on this, so that tests etc can still use strings. Could do similar in syntax index. -// -// * Distinguish different types of namespaces when known. E.g. "name.type" capture. Once capture -// names are more open-ended like this may make sense to build and cache a jump table (direct -// dispatch from capture index). -// -// * There are a few "Language specific:" comments on behavior that gets applied to all languages. -// Would be cleaner to be conditional on the language or otherwise configured. - -#[derive(Debug, Clone, Default)] -pub struct Imports { - pub identifier_to_imports: HashMap>, - pub wildcard_modules: Vec, -} - -#[derive(Debug, Clone)] -pub enum Import { - Direct { - module: Module, - }, - Alias { - module: Module, - external_identifier: Identifier, - }, -} - -#[derive(Debug, Clone)] -pub enum Module { - SourceExact(Arc), - SourceFuzzy(Arc), - Namespace(Namespace), -} - -impl Module { - fn empty() -> Self { - Module::Namespace(Namespace::default()) - } - - fn push_range( - &mut self, - range: &ModuleRange, - snapshot: &BufferSnapshot, - language: &Language, - parent_abs_path: Option<&Path>, - ) -> usize { - if range.is_empty() { - return 0; - } - - match range { - ModuleRange::Source(range) => { - if let Self::Namespace(namespace) = self - && namespace.0.is_empty() - { - let path = snapshot.text_for_range(range.clone()).collect::>(); - - let path = if let Some(strip_regex) = - language.config().import_path_strip_regex.as_ref() - { - strip_regex.replace_all(&path, "") - } else { - path - }; - - let path = Path::new(path.as_ref()); - if (path.starts_with(".") || path.starts_with("..")) - && let Some(parent_abs_path) = parent_abs_path - && let Ok(abs_path) = - util::paths::normalize_lexically(&parent_abs_path.join(path)) - { - *self = Self::SourceExact(abs_path.into()); - } else { - *self = Self::SourceFuzzy(path.into()); - }; - } else if 
matches!(self, Self::SourceExact(_)) - || matches!(self, Self::SourceFuzzy(_)) - { - log::warn!("bug in imports query: encountered multiple @source matches"); - } else { - log::warn!( - "bug in imports query: encountered both @namespace and @source match" - ); - } - } - ModuleRange::Namespace(range) => { - if let Self::Namespace(namespace) = self { - let segment = range_text(snapshot, range); - if language.config().ignored_import_segments.contains(&segment) { - return 0; - } else { - namespace.0.push(segment); - return 1; - } - } else { - log::warn!( - "bug in imports query: encountered both @namespace and @source match" - ); - } - } - } - 0 - } -} - -#[derive(Debug, Clone)] -enum ModuleRange { - Source(Range), - Namespace(Range), -} - -impl Deref for ModuleRange { - type Target = Range; - - fn deref(&self) -> &Self::Target { - match self { - ModuleRange::Source(range) => range, - ModuleRange::Namespace(range) => range, - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct Namespace(pub Vec>); - -impl Namespace { - pub fn occurrences(&self) -> Occurrences { - Occurrences::from_identifiers(&self.0) - } -} - -impl Imports { - pub fn gather(snapshot: &BufferSnapshot, parent_abs_path: Option<&Path>) -> Self { - // Query to match different import patterns - let mut matches = snapshot - .syntax - .matches(0..snapshot.len(), &snapshot.text, |grammar| { - grammar.imports_config().map(|imports| &imports.query) - }); - - let mut detached_nodes: Vec = Vec::new(); - let mut identifier_to_imports = HashMap::default(); - let mut wildcard_modules = Vec::new(); - let mut import_range = None; - - while let Some(query_match) = matches.peek() { - let ImportsConfig { - query: _, - import_ix, - name_ix, - namespace_ix, - source_ix, - list_ix, - wildcard_ix, - alias_ix, - } = matches.grammars()[query_match.grammar_index] - .imports_config() - .unwrap(); - - let mut new_import_range = None; - let mut alias_range = None; - let mut modules = Vec::new(); - let mut 
content: Option<(Range, ContentKind)> = None; - for capture in query_match.captures { - let capture_range = capture.node.byte_range(); - - if capture.index == *import_ix { - new_import_range = Some(capture_range); - } else if Some(capture.index) == *namespace_ix { - modules.push(ModuleRange::Namespace(capture_range)); - } else if Some(capture.index) == *source_ix { - modules.push(ModuleRange::Source(capture_range)); - } else if Some(capture.index) == *alias_ix { - alias_range = Some(capture_range); - } else { - let mut found_content = None; - if Some(capture.index) == *name_ix { - found_content = Some((capture_range, ContentKind::Name)); - } else if Some(capture.index) == *list_ix { - found_content = Some((capture_range, ContentKind::List)); - } else if Some(capture.index) == *wildcard_ix { - found_content = Some((capture_range, ContentKind::Wildcard)); - } - if let Some((found_content_range, found_kind)) = found_content { - if let Some((_, old_kind)) = content { - let point = found_content_range.to_point(snapshot); - log::warn!( - "bug in {} imports query: unexpected multiple captures of {} and {} ({}:{}:{})", - query_match.language.name(), - old_kind.capture_name(), - found_kind.capture_name(), - snapshot - .file() - .map(|p| p.path().display(PathStyle::Posix)) - .unwrap_or_default(), - point.start.row + 1, - point.start.column + 1 - ); - } - content = Some((found_content_range, found_kind)); - } - } - } - - if let Some(new_import_range) = new_import_range { - log::trace!("starting new import {:?}", new_import_range); - Self::gather_from_import_statement( - &detached_nodes, - &snapshot, - parent_abs_path, - &mut identifier_to_imports, - &mut wildcard_modules, - ); - detached_nodes.clear(); - import_range = Some(new_import_range.clone()); - } - - if let Some((content, content_kind)) = content { - if import_range - .as_ref() - .is_some_and(|import_range| import_range.contains_inclusive(&content)) - { - detached_nodes.push(DetachedNode { - modules, - content: 
content.clone(), - content_kind, - alias: alias_range.unwrap_or(0..0), - language: query_match.language.clone(), - }); - } else { - log::trace!( - "filtered out match not inside import range: {content_kind:?} at {content:?}" - ); - } - } - - matches.advance(); - } - - Self::gather_from_import_statement( - &detached_nodes, - &snapshot, - parent_abs_path, - &mut identifier_to_imports, - &mut wildcard_modules, - ); - - Imports { - identifier_to_imports, - wildcard_modules, - } - } - - fn gather_from_import_statement( - detached_nodes: &[DetachedNode], - snapshot: &BufferSnapshot, - parent_abs_path: Option<&Path>, - identifier_to_imports: &mut HashMap>, - wildcard_modules: &mut Vec, - ) { - let mut trees = Vec::new(); - - for detached_node in detached_nodes { - if let Some(node) = Self::attach_node(detached_node.into(), &mut trees) { - trees.push(node); - } - log::trace!( - "Attached node to tree\n{:#?}\nAttach result:\n{:#?}", - detached_node, - trees - .iter() - .map(|tree| tree.debug(snapshot)) - .collect::>() - ); - } - - for tree in &trees { - let mut module = Module::empty(); - Self::gather_from_tree( - tree, - snapshot, - parent_abs_path, - &mut module, - identifier_to_imports, - wildcard_modules, - ); - } - } - - fn attach_node(mut node: ImportTree, trees: &mut Vec) -> Option { - let mut tree_index = 0; - while tree_index < trees.len() { - let tree = &mut trees[tree_index]; - if !node.content.is_empty() && node.content == tree.content { - // multiple matches can apply to the same name/list/wildcard. This keeps the queries - // simpler by combining info from these matches. 
- if tree.module.is_empty() { - tree.module = node.module; - tree.module_children = node.module_children; - } - if tree.alias.is_empty() { - tree.alias = node.alias; - } - return None; - } else if !node.module.is_empty() && node.module.contains_inclusive(&tree.range()) { - node.module_children.push(trees.remove(tree_index)); - continue; - } else if !node.content.is_empty() && node.content.contains_inclusive(&tree.content) { - node.content_children.push(trees.remove(tree_index)); - continue; - } else if !tree.content.is_empty() && tree.content.contains_inclusive(&node.content) { - if let Some(node) = Self::attach_node(node, &mut tree.content_children) { - tree.content_children.push(node); - } - return None; - } - tree_index += 1; - } - Some(node) - } - - fn gather_from_tree( - tree: &ImportTree, - snapshot: &BufferSnapshot, - parent_abs_path: Option<&Path>, - current_module: &mut Module, - identifier_to_imports: &mut HashMap>, - wildcard_modules: &mut Vec, - ) { - let mut pop_count = 0; - - if tree.module_children.is_empty() { - pop_count += - current_module.push_range(&tree.module, snapshot, &tree.language, parent_abs_path); - } else { - for child in &tree.module_children { - pop_count += Self::extend_namespace_from_tree( - child, - snapshot, - parent_abs_path, - current_module, - ); - } - }; - - if tree.content_children.is_empty() && !tree.content.is_empty() { - match tree.content_kind { - ContentKind::Name | ContentKind::List => { - if tree.alias.is_empty() { - identifier_to_imports - .entry(Identifier { - language_id: tree.language.id(), - name: range_text(snapshot, &tree.content), - }) - .or_default() - .push(Import::Direct { - module: current_module.clone(), - }); - } else { - let alias_name: Arc = range_text(snapshot, &tree.alias); - let external_name = range_text(snapshot, &tree.content); - // Language specific: skip "_" aliases for Rust - if alias_name.as_ref() != "_" { - identifier_to_imports - .entry(Identifier { - language_id: tree.language.id(), - name: 
alias_name, - }) - .or_default() - .push(Import::Alias { - module: current_module.clone(), - external_identifier: Identifier { - language_id: tree.language.id(), - name: external_name, - }, - }); - } - } - } - ContentKind::Wildcard => wildcard_modules.push(current_module.clone()), - } - } else { - for child in &tree.content_children { - Self::gather_from_tree( - child, - snapshot, - parent_abs_path, - current_module, - identifier_to_imports, - wildcard_modules, - ); - } - } - - if pop_count > 0 { - match current_module { - Module::SourceExact(_) | Module::SourceFuzzy(_) => { - log::warn!( - "bug in imports query: encountered both @namespace and @source match" - ); - } - Module::Namespace(namespace) => { - namespace.0.drain(namespace.0.len() - pop_count..); - } - } - } - } - - fn extend_namespace_from_tree( - tree: &ImportTree, - snapshot: &BufferSnapshot, - parent_abs_path: Option<&Path>, - module: &mut Module, - ) -> usize { - let mut pop_count = 0; - if tree.module_children.is_empty() { - pop_count += module.push_range(&tree.module, snapshot, &tree.language, parent_abs_path); - } else { - for child in &tree.module_children { - pop_count += - Self::extend_namespace_from_tree(child, snapshot, parent_abs_path, module); - } - } - if tree.content_children.is_empty() { - pop_count += module.push_range( - &ModuleRange::Namespace(tree.content.clone()), - snapshot, - &tree.language, - parent_abs_path, - ); - } else { - for child in &tree.content_children { - pop_count += - Self::extend_namespace_from_tree(child, snapshot, parent_abs_path, module); - } - } - pop_count - } -} - -fn range_text(snapshot: &BufferSnapshot, range: &Range) -> Arc { - snapshot - .text_for_range(range.clone()) - .collect::>() - .into() -} - -#[derive(Debug)] -struct DetachedNode { - modules: Vec, - content: Range, - content_kind: ContentKind, - alias: Range, - language: Arc, -} - -#[derive(Debug, Clone, Copy)] -enum ContentKind { - Name, - Wildcard, - List, -} - -impl ContentKind { - fn 
capture_name(&self) -> &'static str { - match self { - ContentKind::Name => "name", - ContentKind::Wildcard => "wildcard", - ContentKind::List => "list", - } - } -} - -#[derive(Debug)] -struct ImportTree { - module: ModuleRange, - /// When non-empty, provides namespace / source info which should be used instead of `module`. - module_children: Vec, - content: Range, - /// When non-empty, provides content which should be used instead of `content`. - content_children: Vec, - content_kind: ContentKind, - alias: Range, - language: Arc, -} - -impl ImportTree { - fn range(&self) -> Range { - self.module.start.min(self.content.start)..self.module.end.max(self.content.end) - } - - #[allow(dead_code)] - fn debug<'a>(&'a self, snapshot: &'a BufferSnapshot) -> ImportTreeDebug<'a> { - ImportTreeDebug { - tree: self, - snapshot, - } - } - - fn from_module_range(module: &ModuleRange, language: Arc) -> Self { - ImportTree { - module: module.clone(), - module_children: Vec::new(), - content: 0..0, - content_children: Vec::new(), - content_kind: ContentKind::Name, - alias: 0..0, - language, - } - } -} - -impl From<&DetachedNode> for ImportTree { - fn from(value: &DetachedNode) -> Self { - let module; - let module_children; - match value.modules.len() { - 0 => { - module = ModuleRange::Namespace(0..0); - module_children = Vec::new(); - } - 1 => { - module = value.modules[0].clone(); - module_children = Vec::new(); - } - _ => { - module = ModuleRange::Namespace( - value.modules.first().unwrap().start..value.modules.last().unwrap().end, - ); - module_children = value - .modules - .iter() - .map(|module| ImportTree::from_module_range(module, value.language.clone())) - .collect(); - } - } - - ImportTree { - module, - module_children, - content: value.content.clone(), - content_children: Vec::new(), - content_kind: value.content_kind, - alias: value.alias.clone(), - language: value.language.clone(), - } - } -} - -struct ImportTreeDebug<'a> { - tree: &'a ImportTree, - snapshot: &'a 
BufferSnapshot, -} - -impl std::fmt::Debug for ImportTreeDebug<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ImportTree") - .field("module_range", &self.tree.module) - .field("module_text", &range_text(self.snapshot, &self.tree.module)) - .field( - "module_children", - &self - .tree - .module_children - .iter() - .map(|child| child.debug(&self.snapshot)) - .collect::>(), - ) - .field("content_range", &self.tree.content) - .field( - "content_text", - &range_text(self.snapshot, &self.tree.content), - ) - .field( - "content_children", - &self - .tree - .content_children - .iter() - .map(|child| child.debug(&self.snapshot)) - .collect::>(), - ) - .field("content_kind", &self.tree.content_kind) - .field("alias_range", &self.tree.alias) - .field("alias_text", &range_text(self.snapshot, &self.tree.alias)) - .finish() - } -} - -#[cfg(test)] -mod test { - use std::path::PathBuf; - use std::sync::{Arc, LazyLock}; - - use super::*; - use collections::HashSet; - use gpui::{TestAppContext, prelude::*}; - use indoc::indoc; - use language::{ - Buffer, Language, LanguageConfig, tree_sitter_python, tree_sitter_rust, - tree_sitter_typescript, - }; - use regex::Regex; - - #[gpui::test] - fn test_rust_simple(cx: &mut TestAppContext) { - check_imports( - &RUST, - "use std::collections::HashMap;", - &[&["std", "collections", "HashMap"]], - cx, - ); - - check_imports( - &RUST, - "pub use std::collections::HashMap;", - &[&["std", "collections", "HashMap"]], - cx, - ); - - check_imports( - &RUST, - "use std::collections::{HashMap, HashSet};", - &[ - &["std", "collections", "HashMap"], - &["std", "collections", "HashSet"], - ], - cx, - ); - } - - #[gpui::test] - fn test_rust_nested(cx: &mut TestAppContext) { - check_imports( - &RUST, - "use std::{any::TypeId, collections::{HashMap, HashSet}};", - &[ - &["std", "any", "TypeId"], - &["std", "collections", "HashMap"], - &["std", "collections", "HashSet"], - ], - cx, - ); - - check_imports( - 
&RUST, - "use a::b::c::{d::e::F, g::h::I};", - &[ - &["a", "b", "c", "d", "e", "F"], - &["a", "b", "c", "g", "h", "I"], - ], - cx, - ); - } - - #[gpui::test] - fn test_rust_multiple_imports(cx: &mut TestAppContext) { - check_imports( - &RUST, - indoc! {" - use std::collections::HashMap; - use std::any::{TypeId, Any}; - "}, - &[ - &["std", "collections", "HashMap"], - &["std", "any", "TypeId"], - &["std", "any", "Any"], - ], - cx, - ); - - check_imports( - &RUST, - indoc! {" - use std::collections::HashSet; - - fn main() { - let unqualified = HashSet::new(); - let qualified = std::collections::HashMap::new(); - } - - use std::any::TypeId; - "}, - &[ - &["std", "collections", "HashSet"], - &["std", "any", "TypeId"], - ], - cx, - ); - } - - #[gpui::test] - fn test_rust_wildcard(cx: &mut TestAppContext) { - check_imports(&RUST, "use prelude::*;", &[&["prelude", "WILDCARD"]], cx); - - check_imports( - &RUST, - "use zed::prelude::*;", - &[&["zed", "prelude", "WILDCARD"]], - cx, - ); - - check_imports(&RUST, "use prelude::{*};", &[&["prelude", "WILDCARD"]], cx); - - check_imports( - &RUST, - "use prelude::{File, *};", - &[&["prelude", "File"], &["prelude", "WILDCARD"]], - cx, - ); - - check_imports( - &RUST, - "use zed::{App, prelude::*};", - &[&["zed", "App"], &["zed", "prelude", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_rust_alias(cx: &mut TestAppContext) { - check_imports( - &RUST, - "use std::io::Result as IoResult;", - &[&["std", "io", "Result AS IoResult"]], - cx, - ); - } - - #[gpui::test] - fn test_rust_crate_and_super(cx: &mut TestAppContext) { - check_imports(&RUST, "use crate::a::b::c;", &[&["a", "b", "c"]], cx); - check_imports(&RUST, "use super::a::b::c;", &[&["a", "b", "c"]], cx); - // TODO: Consider stripping leading "::". Not done for now because for the text similarity matching usecase this - // is fine. 
- check_imports(&RUST, "use ::a::b::c;", &[&["::a", "b", "c"]], cx); - } - - #[gpui::test] - fn test_typescript_imports(cx: &mut TestAppContext) { - let parent_abs_path = PathBuf::from("/home/user/project"); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import "./maths.js";"#, - &[&["SOURCE /home/user/project/maths", "WILDCARD"]], - cx, - ); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import "../maths.js";"#, - &[&["SOURCE /home/user/maths", "WILDCARD"]], - cx, - ); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import RandomNumberGenerator, { pi as π } from "./maths.js";"#, - &[ - &["SOURCE /home/user/project/maths", "RandomNumberGenerator"], - &["SOURCE /home/user/project/maths", "pi AS π"], - ], - cx, - ); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import { pi, phi, absolute } from "./maths.js";"#, - &[ - &["SOURCE /home/user/project/maths", "pi"], - &["SOURCE /home/user/project/maths", "phi"], - &["SOURCE /home/user/project/maths", "absolute"], - ], - cx, - ); - - // index.js is removed by import_path_strip_regex - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import { pi, phi, absolute } from "./maths/index.js";"#, - &[ - &["SOURCE /home/user/project/maths", "pi"], - &["SOURCE /home/user/project/maths", "phi"], - &["SOURCE /home/user/project/maths", "absolute"], - ], - cx, - ); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import type { SomeThing } from "./some-module.js";"#, - &[&["SOURCE /home/user/project/some-module", "SomeThing"]], - cx, - ); - - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import { type SomeThing, OtherThing } from "./some-module.js";"#, - &[ - &["SOURCE /home/user/project/some-module", "SomeThing"], - &["SOURCE /home/user/project/some-module", "OtherThing"], - ], - cx, - ); - - 
// index.js is removed by import_path_strip_regex - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import { type SomeThing, OtherThing } from "./some-module/index.js";"#, - &[ - &["SOURCE /home/user/project/some-module", "SomeThing"], - &["SOURCE /home/user/project/some-module", "OtherThing"], - ], - cx, - ); - - // fuzzy paths - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import { type SomeThing, OtherThing } from "@my-app/some-module.js";"#, - &[ - &["SOURCE FUZZY @my-app/some-module", "SomeThing"], - &["SOURCE FUZZY @my-app/some-module", "OtherThing"], - ], - cx, - ); - } - - #[gpui::test] - fn test_typescript_named_module_imports(cx: &mut TestAppContext) { - let parent_abs_path = PathBuf::from("/home/user/project"); - - // TODO: These should provide the name that the module is bound to. - // For now instead these are treated as unqualified wildcard imports. - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import * as math from "./maths.js";"#, - // &[&["/home/user/project/maths.js", "WILDCARD AS math"]], - &[&["SOURCE /home/user/project/maths", "WILDCARD"]], - cx, - ); - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &TYPESCRIPT, - r#"import math = require("./maths");"#, - // &[&["/home/user/project/maths", "WILDCARD AS math"]], - &[&["SOURCE /home/user/project/maths", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_python_imports(cx: &mut TestAppContext) { - check_imports(&PYTHON, "from math import pi", &[&["math", "pi"]], cx); - - check_imports( - &PYTHON, - "from math import pi, sin, cos", - &[&["math", "pi"], &["math", "sin"], &["math", "cos"]], - cx, - ); - - check_imports(&PYTHON, "from math import *", &[&["math", "WILDCARD"]], cx); - - check_imports( - &PYTHON, - "from math import foo.bar.baz", - &[&["math", "foo", "bar", "baz"]], - cx, - ); - - check_imports( - &PYTHON, - "from math import pi as PI", - &[&["math", "pi AS PI"]], - 
cx, - ); - - check_imports( - &PYTHON, - "from serializers.json import JsonSerializer", - &[&["serializers", "json", "JsonSerializer"]], - cx, - ); - - check_imports( - &PYTHON, - "from custom.serializers import json, xml, yaml", - &[ - &["custom", "serializers", "json"], - &["custom", "serializers", "xml"], - &["custom", "serializers", "yaml"], - ], - cx, - ); - } - - #[gpui::test] - fn test_python_named_module_imports(cx: &mut TestAppContext) { - // TODO: These should provide the name that the module is bound to. - // For now instead these are treated as unqualified wildcard imports. - // - // check_imports(&PYTHON, "import math", &[&["math", "WILDCARD as math"]], cx); - // check_imports(&PYTHON, "import math as maths", &[&["math", "WILDCARD AS maths"]], cx); - // - // Something like: - // - // (import_statement - // name: [ - // (dotted_name - // (identifier)* @namespace - // (identifier) @name.module .) - // (aliased_import - // name: (dotted_name - // ((identifier) ".")* @namespace - // (identifier) @name.module .) - // alias: (identifier) @alias) - // ]) @import - - check_imports(&PYTHON, "import math", &[&["math", "WILDCARD"]], cx); - - check_imports( - &PYTHON, - "import math as maths", - &[&["math", "WILDCARD"]], - cx, - ); - - check_imports(&PYTHON, "import a.b.c", &[&["a", "b", "c", "WILDCARD"]], cx); - - check_imports( - &PYTHON, - "import a.b.c as d", - &[&["a", "b", "c", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_python_package_relative_imports(cx: &mut TestAppContext) { - // TODO: These should provide info about the dir they are relative to, to provide more - // precise resolution. Instead, fuzzy matching is used as usual. - - check_imports(&PYTHON, "from . 
import math", &[&["math"]], cx); - - check_imports(&PYTHON, "from .a import math", &[&["a", "math"]], cx); - - check_imports( - &PYTHON, - "from ..a.b import math", - &[&["a", "b", "math"]], - cx, - ); - - check_imports( - &PYTHON, - "from ..a.b import *", - &[&["a", "b", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_c_imports(cx: &mut TestAppContext) { - let parent_abs_path = PathBuf::from("/home/user/project"); - - // TODO: Distinguish that these are not relative to current path - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &C, - r#"#include "#, - &[&["SOURCE FUZZY math.h", "WILDCARD"]], - cx, - ); - - // TODO: These should be treated as relative, but don't start with ./ or ../ - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &C, - r#"#include "math.h""#, - &[&["SOURCE FUZZY math.h", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_cpp_imports(cx: &mut TestAppContext) { - let parent_abs_path = PathBuf::from("/home/user/project"); - - // TODO: Distinguish that these are not relative to current path - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &CPP, - r#"#include "#, - &[&["SOURCE FUZZY math.h", "WILDCARD"]], - cx, - ); - - // TODO: These should be treated as relative, but don't start with ./ or ../ - check_imports_with_file_abs_path( - Some(&parent_abs_path), - &CPP, - r#"#include "math.h""#, - &[&["SOURCE FUZZY math.h", "WILDCARD"]], - cx, - ); - } - - #[gpui::test] - fn test_go_imports(cx: &mut TestAppContext) { - check_imports( - &GO, - r#"import . "lib/math""#, - &[&["lib/math", "WILDCARD"]], - cx, - ); - - // not included, these are only for side-effects - check_imports(&GO, r#"import _ "lib/math""#, &[], cx); - } - - #[gpui::test] - fn test_go_named_module_imports(cx: &mut TestAppContext) { - // TODO: These should provide the name that the module is bound to. - // For now instead these are treated as unqualified wildcard imports. 
- - check_imports( - &GO, - r#"import "lib/math""#, - &[&["lib/math", "WILDCARD"]], - cx, - ); - check_imports( - &GO, - r#"import m "lib/math""#, - &[&["lib/math", "WILDCARD"]], - cx, - ); - } - - #[track_caller] - fn check_imports( - language: &Arc, - source: &str, - expected: &[&[&str]], - cx: &mut TestAppContext, - ) { - check_imports_with_file_abs_path(None, language, source, expected, cx); - } - - #[track_caller] - fn check_imports_with_file_abs_path( - parent_abs_path: Option<&Path>, - language: &Arc, - source: &str, - expected: &[&[&str]], - cx: &mut TestAppContext, - ) { - let buffer = cx.new(|cx| { - let mut buffer = Buffer::local(source, cx); - buffer.set_language(Some(language.clone()), cx); - buffer - }); - cx.run_until_parked(); - - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - - let imports = Imports::gather(&snapshot, parent_abs_path); - let mut actual_symbols = imports - .identifier_to_imports - .iter() - .flat_map(|(identifier, imports)| { - imports - .iter() - .map(|import| import.to_identifier_parts(identifier.name.as_ref())) - }) - .chain( - imports - .wildcard_modules - .iter() - .map(|module| module.to_identifier_parts("WILDCARD")), - ) - .collect::>(); - let mut expected_symbols = expected - .iter() - .map(|expected| expected.iter().map(|s| s.to_string()).collect::>()) - .collect::>(); - actual_symbols.sort(); - expected_symbols.sort(); - if actual_symbols != expected_symbols { - let top_layer = snapshot.syntax_layers().next().unwrap(); - panic!( - "Expected imports: {:?}\n\ - Actual imports: {:?}\n\ - Tree:\n{}", - expected_symbols, - actual_symbols, - tree_to_string(&top_layer.node()), - ); - } - } - - fn tree_to_string(node: &tree_sitter::Node) -> String { - let mut cursor = node.walk(); - let mut result = String::new(); - let mut depth = 0; - 'outer: loop { - result.push_str(&" ".repeat(depth)); - if let Some(field_name) = cursor.field_name() { - result.push_str(field_name); - result.push_str(": "); - } - if 
cursor.node().is_named() { - result.push_str(cursor.node().kind()); - } else { - result.push('"'); - result.push_str(cursor.node().kind()); - result.push('"'); - } - result.push('\n'); - - if cursor.goto_first_child() { - depth += 1; - continue; - } - if cursor.goto_next_sibling() { - continue; - } - while cursor.goto_parent() { - depth -= 1; - if cursor.goto_next_sibling() { - continue 'outer; - } - } - break; - } - result - } - - static RUST: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "Rust".into(), - ignored_import_segments: HashSet::from_iter(["crate".into(), "super".into()]), - import_path_strip_regex: Some(Regex::new("/(lib|mod)\\.rs$").unwrap()), - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_imports_query(include_str!("../../languages/src/rust/imports.scm")) - .unwrap(), - ) - }); - - static TYPESCRIPT: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "TypeScript".into(), - import_path_strip_regex: Some(Regex::new("(?:/index)?\\.[jt]s$").unwrap()), - ..Default::default() - }, - Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), - ) - .with_imports_query(include_str!("../../languages/src/typescript/imports.scm")) - .unwrap(), - ) - }); - - static PYTHON: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "Python".into(), - import_path_strip_regex: Some(Regex::new("/__init__\\.py$").unwrap()), - ..Default::default() - }, - Some(tree_sitter_python::LANGUAGE.into()), - ) - .with_imports_query(include_str!("../../languages/src/python/imports.scm")) - .unwrap(), - ) - }); - - // TODO: Ideally should use actual language configurations - static C: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "C".into(), - import_path_strip_regex: Some(Regex::new("^<|>$").unwrap()), - ..Default::default() - }, - Some(tree_sitter_c::LANGUAGE.into()), - ) - 
.with_imports_query(include_str!("../../languages/src/c/imports.scm")) - .unwrap(), - ) - }); - - static CPP: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "C++".into(), - import_path_strip_regex: Some(Regex::new("^<|>$").unwrap()), - ..Default::default() - }, - Some(tree_sitter_cpp::LANGUAGE.into()), - ) - .with_imports_query(include_str!("../../languages/src/cpp/imports.scm")) - .unwrap(), - ) - }); - - static GO: LazyLock> = LazyLock::new(|| { - Arc::new( - Language::new( - LanguageConfig { - name: "Go".into(), - ..Default::default() - }, - Some(tree_sitter_go::LANGUAGE.into()), - ) - .with_imports_query(include_str!("../../languages/src/go/imports.scm")) - .unwrap(), - ) - }); - - impl Import { - fn to_identifier_parts(&self, identifier: &str) -> Vec { - match self { - Import::Direct { module } => module.to_identifier_parts(identifier), - Import::Alias { - module, - external_identifier: external_name, - } => { - module.to_identifier_parts(&format!("{} AS {}", external_name.name, identifier)) - } - } - } - } - - impl Module { - fn to_identifier_parts(&self, identifier: &str) -> Vec { - match self { - Self::Namespace(namespace) => namespace.to_identifier_parts(identifier), - Self::SourceExact(path) => { - vec![ - format!("SOURCE {}", path.display().to_string().replace("\\", "/")), - identifier.to_string(), - ] - } - Self::SourceFuzzy(path) => { - vec![ - format!( - "SOURCE FUZZY {}", - path.display().to_string().replace("\\", "/") - ), - identifier.to_string(), - ] - } - } - } - } - - impl Namespace { - fn to_identifier_parts(&self, identifier: &str) -> Vec { - self.0 - .iter() - .map(|chunk| chunk.to_string()) - .chain(std::iter::once(identifier.to_string())) - .collect::>() - } - } -} diff --git a/crates/edit_prediction_context/src/outline.rs b/crates/edit_prediction_context/src/outline.rs deleted file mode 100644 index ec02c869dfae4cb861206cb801c285462e734f36..0000000000000000000000000000000000000000 --- 
a/crates/edit_prediction_context/src/outline.rs +++ /dev/null @@ -1,126 +0,0 @@ -use language::{BufferSnapshot, SyntaxMapMatches}; -use std::{cmp::Reverse, ops::Range}; - -use crate::declaration::Identifier; - -// TODO: -// -// * how to handle multiple name captures? for now last one wins -// -// * annotation ranges -// -// * new "signature" capture for outline queries -// -// * Check parent behavior of "int x, y = 0" declarations in a test - -pub struct OutlineDeclaration { - pub parent_index: Option, - pub identifier: Identifier, - pub item_range: Range, - pub signature_range: Range, -} - -pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec { - declarations_overlapping_range(0..buffer.len(), buffer) -} - -pub fn declarations_overlapping_range( - range: Range, - buffer: &BufferSnapshot, -) -> Vec { - let mut declarations = OutlineIterator::new(range, buffer).collect::>(); - declarations.sort_unstable_by_key(|item| (item.item_range.start, Reverse(item.item_range.end))); - - let mut parent_stack: Vec<(usize, Range)> = Vec::new(); - for (index, declaration) in declarations.iter_mut().enumerate() { - while let Some((top_parent_index, top_parent_range)) = parent_stack.last() { - if declaration.item_range.start >= top_parent_range.end { - parent_stack.pop(); - } else { - declaration.parent_index = Some(*top_parent_index); - break; - } - } - parent_stack.push((index, declaration.item_range.clone())); - } - declarations -} - -/// Iterates outline items without being ordered w.r.t. nested items and without populating -/// `parent`. 
-pub struct OutlineIterator<'a> { - buffer: &'a BufferSnapshot, - matches: SyntaxMapMatches<'a>, -} - -impl<'a> OutlineIterator<'a> { - pub fn new(range: Range, buffer: &'a BufferSnapshot) -> Self { - let matches = buffer.syntax.matches(range, &buffer.text, |grammar| { - grammar.outline_config.as_ref().map(|c| &c.query) - }); - - Self { buffer, matches } - } -} - -impl<'a> Iterator for OutlineIterator<'a> { - type Item = OutlineDeclaration; - - fn next(&mut self) -> Option { - while let Some(mat) = self.matches.peek() { - let config = self.matches.grammars()[mat.grammar_index] - .outline_config - .as_ref() - .unwrap(); - - let mut name_range = None; - let mut item_range = None; - let mut signature_start = None; - let mut signature_end = None; - - let mut add_to_signature = |range: Range| { - if signature_start.is_none() { - signature_start = Some(range.start); - } - signature_end = Some(range.end); - }; - - for capture in mat.captures { - let range = capture.node.byte_range(); - if capture.index == config.name_capture_ix { - name_range = Some(range.clone()); - add_to_signature(range); - } else if Some(capture.index) == config.context_capture_ix - || Some(capture.index) == config.extra_context_capture_ix - { - add_to_signature(range); - } else if capture.index == config.item_capture_ix { - item_range = Some(range.clone()); - } - } - - let language_id = mat.language.id(); - self.matches.advance(); - - if let Some(name_range) = name_range - && let Some(item_range) = item_range - && let Some(signature_start) = signature_start - && let Some(signature_end) = signature_end - { - let name = self - .buffer - .text_for_range(name_range) - .collect::() - .into(); - - return Some(OutlineDeclaration { - identifier: Identifier { name, language_id }, - item_range: item_range, - signature_range: signature_start..signature_end, - parent_index: None, - }); - } - } - None - } -} diff --git a/crates/edit_prediction_context/src/reference.rs 
b/crates/edit_prediction_context/src/reference.rs deleted file mode 100644 index 699adf1d8036802a7a4b9e34ca8e8094e4f97458..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/reference.rs +++ /dev/null @@ -1,173 +0,0 @@ -use collections::HashMap; -use language::BufferSnapshot; -use std::ops::Range; -use util::RangeExt; - -use crate::{ - declaration::Identifier, - excerpt::{EditPredictionExcerpt, EditPredictionExcerptText}, -}; - -#[derive(Debug, Clone)] -pub struct Reference { - pub identifier: Identifier, - pub range: Range, - pub region: ReferenceRegion, -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub enum ReferenceRegion { - Breadcrumb, - Nearby, -} - -pub fn references_in_excerpt( - excerpt: &EditPredictionExcerpt, - excerpt_text: &EditPredictionExcerptText, - snapshot: &BufferSnapshot, -) -> HashMap> { - let mut references = references_in_range( - excerpt.range.clone(), - excerpt_text.body.as_str(), - ReferenceRegion::Nearby, - snapshot, - ); - - for ((_, range), text) in excerpt - .parent_declarations - .iter() - .zip(excerpt_text.parent_signatures.iter()) - { - references.extend(references_in_range( - range.clone(), - text.as_str(), - ReferenceRegion::Breadcrumb, - snapshot, - )); - } - - let mut identifier_to_references: HashMap> = HashMap::default(); - for reference in references { - identifier_to_references - .entry(reference.identifier.clone()) - .or_insert_with(Vec::new) - .push(reference); - } - identifier_to_references -} - -/// Finds all nodes which have a "variable" match from the highlights query within the offset range. 
-pub fn references_in_range( - range: Range, - range_text: &str, - reference_region: ReferenceRegion, - buffer: &BufferSnapshot, -) -> Vec { - let mut matches = buffer - .syntax - .matches(range.clone(), &buffer.text, |grammar| { - grammar - .highlights_config - .as_ref() - .map(|config| &config.query) - }); - - let mut references = Vec::new(); - let mut last_added_range = None; - while let Some(mat) = matches.peek() { - let config = matches.grammars()[mat.grammar_index] - .highlights_config - .as_ref(); - - if let Some(config) = config { - for capture in mat.captures { - if config.identifier_capture_indices.contains(&capture.index) { - let node_range = capture.node.byte_range(); - - // sometimes multiple highlight queries match - this deduplicates them - if Some(node_range.clone()) == last_added_range { - continue; - } - - if !range.contains_inclusive(&node_range) { - continue; - } - - let identifier_text = - &range_text[node_range.start - range.start..node_range.end - range.start]; - - references.push(Reference { - identifier: Identifier { - name: identifier_text.into(), - language_id: mat.language.id(), - }, - range: node_range.clone(), - region: reference_region, - }); - last_added_range = Some(node_range); - } - } - } - - matches.advance(); - } - references -} - -#[cfg(test)] -mod test { - use gpui::{TestAppContext, prelude::*}; - use indoc::indoc; - use language::{BufferSnapshot, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; - - use crate::reference::{ReferenceRegion, references_in_range}; - - #[gpui::test] - fn test_identifier_node_truncated(cx: &mut TestAppContext) { - let code = indoc! 
{ r#" - fn main() { - add(1, 2); - } - - fn add(a: i32, b: i32) -> i32 { - a + b - } - "# }; - let buffer = create_buffer(code, cx); - - let range = 0..35; - let references = references_in_range( - range.clone(), - &code[range], - ReferenceRegion::Breadcrumb, - &buffer, - ); - assert_eq!(references.len(), 2); - assert_eq!(references[0].identifier.name.as_ref(), "main"); - assert_eq!(references[1].identifier.name.as_ref(), "add"); - } - - fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot { - let buffer = - cx.new(|cx| language::Buffer::local(text, cx).with_language(rust_lang().into(), cx)); - buffer.read_with(cx, |buffer, _| buffer.snapshot()) - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) - .unwrap() - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } -} diff --git a/crates/edit_prediction_context/src/syntax_index.rs b/crates/edit_prediction_context/src/syntax_index.rs deleted file mode 100644 index f489a083341b66c7cca3cdad76a9c7ea16fdc959..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/syntax_index.rs +++ /dev/null @@ -1,1069 +0,0 @@ -use anyhow::{Result, anyhow}; -use collections::{HashMap, HashSet}; -use futures::channel::mpsc; -use futures::lock::Mutex; -use futures::{FutureExt as _, StreamExt, future}; -use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity}; -use itertools::Itertools; - -use language::{Buffer, BufferEvent}; -use postage::stream::Stream as _; -use project::buffer_store::{BufferStore, BufferStoreEvent}; -use project::worktree_store::{WorktreeStore, WorktreeStoreEvent}; -use project::{PathChange, Project, ProjectEntryId, 
ProjectPath}; -use slotmap::SlotMap; -use std::iter; -use std::ops::{DerefMut, Range}; -use std::sync::Arc; -use text::BufferId; -use util::{RangeExt as _, debug_panic, some_or_debug_panic}; - -use crate::CachedDeclarationPath; -use crate::declaration::{ - BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier, -}; -use crate::outline::declarations_in_buffer; - -// TODO -// -// * Also queue / debounce buffer changes. A challenge for this is that use of -// `buffer_declarations_containing_range` assumes that the index is always immediately up to date. -// -// * Add a per language configuration for skipping indexing. -// -// * Handle tsx / ts / js referencing each-other - -// Potential future improvements: -// -// * Prevent indexing of a large file from blocking the queue. -// -// * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which -// references are present and their scores. -// -// * Include single-file worktrees / non visible worktrees? E.g. go to definition that resolves to a -// file in a build dependency. Should not be editable in that case - but how to distinguish the case -// where it should be editable? - -// Potential future optimizations: -// -// * Index files on multiple threads in Zed (currently only parallel for the CLI). Adding some kind -// of priority system to the background executor could help - it's single threaded for now to avoid -// interfering with other work. -// -// * Parse files directly instead of loading into a Rope. -// -// - This would allow the task handling dirty_files to be done entirely on the background executor. -// -// - Make SyntaxMap generic to handle embedded languages? Will also need to find line boundaries, -// but that can be done by scanning characters in the flat representation. -// -// * Use something similar to slotmap without key versions. 
-// -// * Concurrent slotmap - -pub struct SyntaxIndex { - state: Arc>, - project: WeakEntity, - initial_file_indexing_done_rx: postage::watch::Receiver, - _file_indexing_task: Option>, -} - -pub struct SyntaxIndexState { - declarations: SlotMap, - identifiers: HashMap>, - files: HashMap, - buffers: HashMap, - dirty_files: HashMap, - dirty_files_tx: mpsc::Sender<()>, -} - -#[derive(Debug, Default)] -struct FileState { - declarations: Vec, -} - -#[derive(Default)] -struct BufferState { - declarations: Vec, - task: Option>, -} - -impl SyntaxIndex { - pub fn new( - project: &Entity, - file_indexing_parallelism: usize, - cx: &mut Context, - ) -> Self { - assert!(file_indexing_parallelism > 0); - let (dirty_files_tx, mut dirty_files_rx) = mpsc::channel::<()>(1); - let (mut initial_file_indexing_done_tx, initial_file_indexing_done_rx) = - postage::watch::channel(); - - let initial_state = SyntaxIndexState { - declarations: SlotMap::default(), - identifiers: HashMap::default(), - files: HashMap::default(), - buffers: HashMap::default(), - dirty_files: HashMap::default(), - dirty_files_tx, - }; - let mut this = Self { - project: project.downgrade(), - state: Arc::new(Mutex::new(initial_state)), - initial_file_indexing_done_rx, - _file_indexing_task: None, - }; - - let worktree_store = project.read(cx).worktree_store(); - let initial_worktree_snapshots = worktree_store - .read(cx) - .worktrees() - .map(|w| w.read(cx).snapshot()) - .collect::>(); - this._file_indexing_task = Some(cx.spawn(async move |this, cx| { - let snapshots_file_count = initial_worktree_snapshots - .iter() - .map(|worktree| worktree.file_count()) - .sum::(); - if snapshots_file_count > 0 { - let chunk_size = snapshots_file_count.div_ceil(file_indexing_parallelism); - let chunk_count = snapshots_file_count.div_ceil(chunk_size); - let file_chunks = initial_worktree_snapshots - .iter() - .flat_map(|worktree| { - let worktree_id = worktree.id(); - worktree.files(false, 0).map(move |entry| { - ( - entry.id, - 
ProjectPath { - worktree_id, - path: entry.path.clone(), - }, - ) - }) - }) - .chunks(chunk_size); - - let mut tasks = Vec::with_capacity(chunk_count); - for chunk in file_chunks.into_iter() { - tasks.push(Self::update_dirty_files( - &this, - chunk.into_iter().collect(), - cx.clone(), - )); - } - futures::future::join_all(tasks).await; - log::info!("Finished initial file indexing"); - } - - *initial_file_indexing_done_tx.borrow_mut() = true; - - let Ok(state) = this.read_with(cx, |this, _cx| Arc::downgrade(&this.state)) else { - return; - }; - while dirty_files_rx.next().await.is_some() { - let Some(state) = state.upgrade() else { - return; - }; - let mut state = state.lock().await; - let was_underused = state.dirty_files.capacity() > 255 - && state.dirty_files.len() * 8 < state.dirty_files.capacity(); - let dirty_files = state.dirty_files.drain().collect::>(); - if was_underused { - state.dirty_files.shrink_to_fit(); - } - drop(state); - if dirty_files.is_empty() { - continue; - } - - let chunk_size = dirty_files.len().div_ceil(file_indexing_parallelism); - let chunk_count = dirty_files.len().div_ceil(chunk_size); - let mut tasks = Vec::with_capacity(chunk_count); - let chunks = dirty_files.into_iter().chunks(chunk_size); - for chunk in chunks.into_iter() { - tasks.push(Self::update_dirty_files( - &this, - chunk.into_iter().collect(), - cx.clone(), - )); - } - futures::future::join_all(tasks).await; - } - })); - - cx.subscribe(&worktree_store, Self::handle_worktree_store_event) - .detach(); - - let buffer_store = project.read(cx).buffer_store().clone(); - for buffer in buffer_store.read(cx).buffers().collect::>() { - this.register_buffer(&buffer, cx); - } - cx.subscribe(&buffer_store, Self::handle_buffer_store_event) - .detach(); - - this - } - - async fn update_dirty_files( - this: &WeakEntity, - dirty_files: Vec<(ProjectEntryId, ProjectPath)>, - mut cx: AsyncApp, - ) { - for (entry_id, project_path) in dirty_files { - let Ok(task) = this.update(&mut cx, |this, 
cx| { - this.update_file(entry_id, project_path, cx) - }) else { - return; - }; - task.await; - } - } - - pub fn wait_for_initial_file_indexing(&self, cx: &App) -> Task> { - if *self.initial_file_indexing_done_rx.borrow() { - Task::ready(Ok(())) - } else { - let mut rx = self.initial_file_indexing_done_rx.clone(); - cx.background_spawn(async move { - loop { - match rx.recv().await { - Some(true) => return Ok(()), - Some(false) => {} - None => { - return Err(anyhow!( - "SyntaxIndex dropped while waiting for initial file indexing" - )); - } - } - } - }) - } - } - - pub fn indexed_file_paths(&self, cx: &App) -> Task> { - let state = self.state.clone(); - let project = self.project.clone(); - - cx.spawn(async move |cx| { - let state = state.lock().await; - let Some(project) = project.upgrade() else { - return vec![]; - }; - project - .read_with(cx, |project, cx| { - state - .files - .keys() - .filter_map(|entry_id| project.path_for_entry(*entry_id, cx)) - .collect() - }) - .unwrap_or_default() - }) - } - - fn handle_worktree_store_event( - &mut self, - _worktree_store: Entity, - event: &WorktreeStoreEvent, - cx: &mut Context, - ) { - use WorktreeStoreEvent::*; - match event { - WorktreeUpdatedEntries(worktree_id, updated_entries_set) => { - let state = Arc::downgrade(&self.state); - let worktree_id = *worktree_id; - let updated_entries_set = updated_entries_set.clone(); - cx.background_spawn(async move { - let Some(state) = state.upgrade() else { return }; - let mut state = state.lock().await; - for (path, entry_id, path_change) in updated_entries_set.iter() { - if let PathChange::Removed = path_change { - state.files.remove(entry_id); - state.dirty_files.remove(entry_id); - } else { - let project_path = ProjectPath { - worktree_id, - path: path.clone(), - }; - state.dirty_files.insert(*entry_id, project_path); - } - } - match state.dirty_files_tx.try_send(()) { - Err(err) if err.is_disconnected() => { - log::error!("bug: syntax indexing queue is disconnected"); - } - 
_ => {} - } - }) - .detach(); - } - WorktreeDeletedEntry(_worktree_id, project_entry_id) => { - let project_entry_id = *project_entry_id; - self.with_state(cx, move |state| { - state.files.remove(&project_entry_id); - }) - } - _ => {} - } - } - - fn handle_buffer_store_event( - &mut self, - _buffer_store: Entity, - event: &BufferStoreEvent, - cx: &mut Context, - ) { - use BufferStoreEvent::*; - match event { - BufferAdded(buffer) => self.register_buffer(buffer, cx), - BufferOpened { .. } - | BufferChangedFilePath { .. } - | BufferDropped { .. } - | SharedBufferClosed { .. } => {} - } - } - - pub fn state(&self) -> &Arc> { - &self.state - } - - fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) { - if let Some(mut state) = self.state.try_lock() { - f(&mut state); - return; - } - let state = Arc::downgrade(&self.state); - cx.background_spawn(async move { - let Some(state) = state.upgrade() else { - return; - }; - let mut state = state.lock().await; - f(&mut state) - }) - .detach(); - } - - fn register_buffer(&self, buffer: &Entity, cx: &mut Context) { - let buffer_id = buffer.read(cx).remote_id(); - cx.observe_release(buffer, move |this, _buffer, cx| { - this.with_state(cx, move |state| { - if let Some(buffer_state) = state.buffers.remove(&buffer_id) { - SyntaxIndexState::remove_buffer_declarations( - &buffer_state.declarations, - &mut state.declarations, - &mut state.identifiers, - ); - } - }) - }) - .detach(); - cx.subscribe(buffer, Self::handle_buffer_event).detach(); - - self.update_buffer(buffer.clone(), cx); - } - - fn handle_buffer_event( - &mut self, - buffer: Entity, - event: &BufferEvent, - cx: &mut Context, - ) { - match event { - BufferEvent::Edited | - // paths are cached and so should be updated - BufferEvent::FileHandleChanged => self.update_buffer(buffer, cx), - _ => {} - } - } - - fn update_buffer(&self, buffer_entity: Entity, cx: &mut Context) { - let buffer = buffer_entity.read(cx); - if 
buffer.language().is_none() { - return; - } - - let Some((project_entry_id, cached_path)) = project::File::from_dyn(buffer.file()) - .and_then(|f| { - let project_entry_id = f.project_entry_id()?; - let cached_path = CachedDeclarationPath::new( - f.worktree.read(cx).abs_path(), - &f.path, - buffer.language(), - ); - Some((project_entry_id, cached_path)) - }) - else { - return; - }; - let buffer_id = buffer.remote_id(); - - let mut parse_status = buffer.parse_status(); - let snapshot_task = cx.spawn({ - let weak_buffer = buffer_entity.downgrade(); - async move |_, cx| { - while *parse_status.borrow() != language::ParseStatus::Idle { - parse_status.changed().await?; - } - weak_buffer.read_with(cx, |buffer, _cx| buffer.snapshot()) - } - }); - - let state = Arc::downgrade(&self.state); - let task = cx.background_spawn(async move { - // TODO: How to handle errors? - let Ok(snapshot) = snapshot_task.await else { - return; - }; - let rope = snapshot.text.as_rope(); - - let declarations = declarations_in_buffer(&snapshot) - .into_iter() - .map(|item| { - ( - item.parent_index, - BufferDeclaration::from_outline(item, &rope), - ) - }) - .collect::>(); - - let Some(state) = state.upgrade() else { - return; - }; - let mut state = state.lock().await; - let state = state.deref_mut(); - - let buffer_state = state - .buffers - .entry(buffer_id) - .or_insert_with(Default::default); - - SyntaxIndexState::remove_buffer_declarations( - &buffer_state.declarations, - &mut state.declarations, - &mut state.identifiers, - ); - - let mut new_ids = Vec::with_capacity(declarations.len()); - state.declarations.reserve(declarations.len()); - for (parent_index, mut declaration) in declarations { - declaration.parent = - parent_index.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); - - let identifier = declaration.identifier.clone(); - let declaration_id = state.declarations.insert(Declaration::Buffer { - rope: rope.clone(), - buffer_id, - declaration, - project_entry_id, - 
cached_path: cached_path.clone(), - }); - new_ids.push(declaration_id); - - state - .identifiers - .entry(identifier) - .or_default() - .insert(declaration_id); - } - - buffer_state.declarations = new_ids; - }); - - self.with_state(cx, move |state| { - state - .buffers - .entry(buffer_id) - .or_insert_with(Default::default) - .task = Some(task) - }); - } - - fn update_file( - &mut self, - entry_id: ProjectEntryId, - project_path: ProjectPath, - cx: &mut Context, - ) -> Task<()> { - let Some(project) = self.project.upgrade() else { - return Task::ready(()); - }; - let project = project.read(cx); - - let language_registry = project.languages(); - let Some(available_language) = - language_registry.language_for_file_path(project_path.path.as_std_path()) - else { - return Task::ready(()); - }; - let language = if let Some(Ok(Ok(language))) = language_registry - .load_language(&available_language) - .now_or_never() - { - if language - .grammar() - .is_none_or(|grammar| grammar.outline_config.is_none()) - { - return Task::ready(()); - } - future::Either::Left(async { Ok(language) }) - } else { - let language_registry = language_registry.clone(); - future::Either::Right(async move { - anyhow::Ok( - language_registry - .load_language(&available_language) - .await??, - ) - }) - }; - - let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) else { - return Task::ready(()); - }; - - let snapshot_task = worktree.update(cx, |worktree, cx| { - let load_task = worktree.load_file(&project_path.path, cx); - let worktree_abs_path = worktree.abs_path(); - cx.spawn(async move |_this, cx| { - let loaded_file = load_task.await?; - let language = language.await?; - - let buffer = cx.new(|cx| { - let mut buffer = Buffer::local(loaded_file.text, cx); - buffer.set_language(Some(language.clone()), cx); - buffer - })?; - - let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?; - while *parse_status.borrow() != language::ParseStatus::Idle { - 
parse_status.changed().await?; - } - - let cached_path = CachedDeclarationPath::new( - worktree_abs_path, - &project_path.path, - Some(&language), - ); - - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - - anyhow::Ok((snapshot, cached_path)) - }) - }); - - let state = Arc::downgrade(&self.state); - cx.background_spawn(async move { - // TODO: How to handle errors? - let Ok((snapshot, cached_path)) = snapshot_task.await else { - return; - }; - let rope = snapshot.as_rope(); - let declarations = declarations_in_buffer(&snapshot) - .into_iter() - .map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope))) - .collect::>(); - - let Some(state) = state.upgrade() else { - return; - }; - let mut state = state.lock().await; - let state = state.deref_mut(); - - let file_state = state.files.entry(entry_id).or_insert_with(Default::default); - for old_declaration_id in &file_state.declarations { - let Some(declaration) = state.declarations.remove(*old_declaration_id) else { - debug_panic!("declaration not found"); - continue; - }; - if let Some(identifier_declarations) = - state.identifiers.get_mut(declaration.identifier()) - { - identifier_declarations.remove(old_declaration_id); - } - } - - let mut new_ids = Vec::with_capacity(declarations.len()); - state.declarations.reserve(declarations.len()); - for (parent_index, mut declaration) in declarations { - declaration.parent = - parent_index.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); - - let identifier = declaration.identifier.clone(); - let declaration_id = state.declarations.insert(Declaration::File { - project_entry_id: entry_id, - declaration, - cached_path: cached_path.clone(), - }); - new_ids.push(declaration_id); - - state - .identifiers - .entry(identifier) - .or_default() - .insert(declaration_id); - } - file_state.declarations = new_ids; - }) - } -} - -impl SyntaxIndexState { - pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> { - 
self.declarations.get(id) - } - - /// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector. - /// - /// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded. - pub fn declarations_for_identifier( - &self, - identifier: &Identifier, - ) -> Vec<(DeclarationId, &Declaration)> { - // make sure to not have a large stack allocation - assert!(N < 32); - - let Some(declaration_ids) = self.identifiers.get(&identifier) else { - return vec![]; - }; - - let mut result = Vec::with_capacity(N); - let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new(); - let mut file_declarations = Vec::new(); - - for declaration_id in declaration_ids { - let declaration = self.declarations.get(*declaration_id); - let Some(declaration) = some_or_debug_panic(declaration) else { - continue; - }; - match declaration { - Declaration::Buffer { - project_entry_id, .. - } => { - included_buffer_entry_ids.push(*project_entry_id); - result.push((*declaration_id, declaration)); - if result.len() == N { - return Vec::new(); - } - } - Declaration::File { - project_entry_id, .. - } => { - if !included_buffer_entry_ids.contains(&project_entry_id) { - file_declarations.push((*declaration_id, declaration)); - } - } - } - } - - for (declaration_id, declaration) in file_declarations { - match declaration { - Declaration::File { - project_entry_id, .. - } => { - if !included_buffer_entry_ids.contains(&project_entry_id) { - result.push((declaration_id, declaration)); - - if result.len() == N { - return Vec::new(); - } - } - } - Declaration::Buffer { .. 
} => {} - } - } - - result - } - - pub fn buffer_declarations_containing_range( - &self, - buffer_id: BufferId, - range: Range, - ) -> impl Iterator { - let Some(buffer_state) = self.buffers.get(&buffer_id) else { - return itertools::Either::Left(iter::empty()); - }; - - let iter = buffer_state - .declarations - .iter() - .filter_map(move |declaration_id| { - let Some(declaration) = self - .declarations - .get(*declaration_id) - .and_then(|d| d.as_buffer()) - else { - log::error!("bug: missing buffer outline declaration"); - return None; - }; - if declaration.item_range.contains_inclusive(&range) { - return Some((*declaration_id, declaration)); - } - return None; - }); - itertools::Either::Right(iter) - } - - pub fn file_declaration_count(&self, declaration: &Declaration) -> usize { - match declaration { - Declaration::File { - project_entry_id, .. - } => self - .files - .get(project_entry_id) - .map(|file_state| file_state.declarations.len()) - .unwrap_or_default(), - Declaration::Buffer { buffer_id, .. 
} => self - .buffers - .get(buffer_id) - .map(|buffer_state| buffer_state.declarations.len()) - .unwrap_or_default(), - } - } - - fn remove_buffer_declarations( - old_declaration_ids: &[DeclarationId], - declarations: &mut SlotMap, - identifiers: &mut HashMap>, - ) { - for old_declaration_id in old_declaration_ids { - let Some(declaration) = declarations.remove(*old_declaration_id) else { - debug_panic!("declaration not found"); - continue; - }; - if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) { - identifier_declarations.remove(old_declaration_id); - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::sync::Arc; - - use gpui::TestAppContext; - use indoc::indoc; - use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use text::OffsetRangeExt as _; - use util::{path, rel_path::rel_path}; - - use crate::syntax_index::SyntaxIndex; - - #[gpui::test] - async fn test_unopen_indexed_files(cx: &mut TestAppContext) { - let (project, index, rust_lang_id) = init_test(cx).await; - let main = Identifier { - name: "main".into(), - language_id: rust_lang_id, - }; - - let index_state = index.read_with(cx, |index, _cx| index.state().clone()); - let index_state = index_state.lock().await; - cx.update(|cx| { - let decls = index_state.declarations_for_identifier::<8>(&main); - assert_eq!(decls.len(), 2); - - let decl = expect_file_decl("a.rs", &decls[0].1, &project, cx); - assert_eq!(decl.identifier, main); - assert_eq!(decl.item_range, 0..98); - - let decl = expect_file_decl("c.rs", &decls[1].1, &project, cx); - assert_eq!(decl.identifier, main.clone()); - assert_eq!(decl.item_range, 32..280); - }); - } - - #[gpui::test] - async fn test_parents_in_file(cx: &mut TestAppContext) { - let (project, index, rust_lang_id) = init_test(cx).await; - let test_process_data = Identifier { - name: 
"test_process_data".into(), - language_id: rust_lang_id, - }; - - let index_state = index.read_with(cx, |index, _cx| index.state().clone()); - let index_state = index_state.lock().await; - cx.update(|cx| { - let decls = index_state.declarations_for_identifier::<8>(&test_process_data); - assert_eq!(decls.len(), 1); - - let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx); - assert_eq!(decl.identifier, test_process_data); - - let parent_id = decl.parent.unwrap(); - let parent = index_state.declaration(parent_id).unwrap(); - let parent_decl = expect_file_decl("c.rs", &parent, &project, cx); - assert_eq!( - parent_decl.identifier, - Identifier { - name: "tests".into(), - language_id: rust_lang_id - } - ); - assert_eq!(parent_decl.parent, None); - }); - } - - #[gpui::test] - async fn test_parents_in_buffer(cx: &mut TestAppContext) { - let (project, index, rust_lang_id) = init_test(cx).await; - let test_process_data = Identifier { - name: "test_process_data".into(), - language_id: rust_lang_id, - }; - - let buffer = project - .update(cx, |project, cx| { - let project_path = project.find_project_path("c.rs", cx).unwrap(); - project.open_buffer(project_path, cx) - }) - .await - .unwrap(); - - cx.run_until_parked(); - - let index_state = index.read_with(cx, |index, _cx| index.state().clone()); - let index_state = index_state.lock().await; - cx.update(|cx| { - let decls = index_state.declarations_for_identifier::<8>(&test_process_data); - assert_eq!(decls.len(), 1); - - let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx); - assert_eq!(decl.identifier, test_process_data); - - let parent_id = decl.parent.unwrap(); - let parent = index_state.declaration(parent_id).unwrap(); - let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx); - assert_eq!( - parent_decl.identifier, - Identifier { - name: "tests".into(), - language_id: rust_lang_id - } - ); - assert_eq!(parent_decl.parent, None); - }); - - drop(buffer); - } - - #[gpui::test] - async fn 
test_declarations_limit(cx: &mut TestAppContext) { - let (_, index, rust_lang_id) = init_test(cx).await; - - let index_state = index.read_with(cx, |index, _cx| index.state().clone()); - let index_state = index_state.lock().await; - let decls = index_state.declarations_for_identifier::<1>(&Identifier { - name: "main".into(), - language_id: rust_lang_id, - }); - assert_eq!(decls.len(), 0); - } - - #[gpui::test] - async fn test_buffer_shadow(cx: &mut TestAppContext) { - let (project, index, rust_lang_id) = init_test(cx).await; - - let main = Identifier { - name: "main".into(), - language_id: rust_lang_id, - }; - - let buffer = project - .update(cx, |project, cx| { - let project_path = project.find_project_path("c.rs", cx).unwrap(); - project.open_buffer(project_path, cx) - }) - .await - .unwrap(); - - cx.run_until_parked(); - - let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone()); - { - let index_state = index_state_arc.lock().await; - - cx.update(|cx| { - let decls = index_state.declarations_for_identifier::<8>(&main); - assert_eq!(decls.len(), 2); - let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx); - assert_eq!(decl.identifier, main); - assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280); - - expect_file_decl("a.rs", &decls[1].1, &project, cx); - }); - } - - // Drop the buffer and wait for release - cx.update(|_| { - drop(buffer); - }); - cx.run_until_parked(); - - let index_state = index_state_arc.lock().await; - - cx.update(|cx| { - let decls = index_state.declarations_for_identifier::<8>(&main); - assert_eq!(decls.len(), 2); - expect_file_decl("a.rs", &decls[0].1, &project, cx); - expect_file_decl("c.rs", &decls[1].1, &project, cx); - }); - } - - fn expect_buffer_decl<'a>( - path: &str, - declaration: &'a Declaration, - project: &Entity, - cx: &App, - ) -> &'a BufferDeclaration { - if let Declaration::Buffer { - declaration, - project_entry_id, - .. 
- } = declaration - { - let project_path = project - .read(cx) - .path_for_entry(*project_entry_id, cx) - .unwrap(); - assert_eq!(project_path.path.as_ref(), rel_path(path),); - declaration - } else { - panic!("Expected a buffer declaration, found {:?}", declaration); - } - } - - fn expect_file_decl<'a>( - path: &str, - declaration: &'a Declaration, - project: &Entity, - cx: &App, - ) -> &'a FileDeclaration { - if let Declaration::File { - declaration, - project_entry_id: file, - .. - } = declaration - { - assert_eq!( - project - .read(cx) - .path_for_entry(*file, cx) - .unwrap() - .path - .as_ref(), - rel_path(path), - ); - declaration - } else { - panic!("Expected a file declaration, found {:?}", declaration); - } - } - - async fn init_test( - cx: &mut TestAppContext, - ) -> (Entity, Entity, LanguageId) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - }); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "a.rs": indoc! {r#" - fn main() { - let x = 1; - let y = 2; - let z = add(x, y); - println!("Result: {}", z); - } - - fn add(a: i32, b: i32) -> i32 { - a + b - } - "#}, - "b.rs": indoc! {" - pub struct Config { - pub name: String, - pub value: i32, - } - - impl Config { - pub fn new(name: String, value: i32) -> Self { - Config { name, value } - } - } - "}, - "c.rs": indoc! {r#" - use std::collections::HashMap; - - fn main() { - let args: Vec = std::env::args().collect(); - let data: Vec = args[1..] 
- .iter() - .filter_map(|s| s.parse().ok()) - .collect(); - let result = process_data(data); - println!("{:?}", result); - } - - fn process_data(data: Vec) -> HashMap { - let mut counts = HashMap::new(); - for value in data { - *counts.entry(value).or_insert(0) += 1; - } - counts - } - - #[cfg(test)] - mod tests { - use super::*; - - #[test] - fn test_process_data() { - let data = vec![1, 2, 2, 3]; - let result = process_data(data); - assert_eq!(result.get(&2), Some(&2)); - } - } - "#} - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let lang = rust_lang(); - let lang_id = lang.id(); - language_registry.add(Arc::new(lang)); - - let file_indexing_parallelism = 2; - let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx)); - cx.run_until_parked(); - - (project, index, lang_id) - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } -} diff --git a/crates/edit_prediction_context/src/text_similarity.rs b/crates/edit_prediction_context/src/text_similarity.rs deleted file mode 100644 index 308a9570206084fc223c72f2e1c49109ea157714..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_context/src/text_similarity.rs +++ /dev/null @@ -1,314 +0,0 @@ -use hashbrown::HashTable; -use regex::Regex; -use std::{ - borrow::Cow, - hash::{Hash, Hasher as _}, - path::Path, - sync::LazyLock, -}; -use util::rel_path::RelPath; - -use crate::reference::Reference; - -// TODO: Consider implementing sliding window similarity matching like -// 
https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts -// -// That implementation could actually be more efficient - no need to track words in the window that -// are not in the query. - -// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the -// two in parallel. - -static IDENTIFIER_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap()); - -/// Multiset of text occurrences for text similarity that only stores hashes and counts. -#[derive(Debug, Default)] -pub struct Occurrences { - table: HashTable, - total_count: usize, -} - -#[derive(Debug)] -struct OccurrenceEntry { - hash: u64, - count: usize, -} - -impl Occurrences { - pub fn within_string(text: &str) -> Self { - Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str())) - } - - #[allow(dead_code)] - pub fn within_references(references: &[Reference]) -> Self { - Self::from_identifiers( - references - .iter() - .map(|reference| reference.identifier.name.as_ref()), - ) - } - - pub fn from_identifiers(identifiers: impl IntoIterator>) -> Self { - let mut this = Self::default(); - // TODO: Score matches that match case higher? - // - // TODO: Also include unsplit identifier? 
- for identifier in identifiers { - for identifier_part in split_identifier(identifier.as_ref()) { - this.add_hash(fx_hash(&identifier_part.to_lowercase())); - } - } - this - } - - pub fn from_worktree_path(worktree_name: Option>, rel_path: &RelPath) -> Self { - if let Some(worktree_name) = worktree_name { - Self::from_identifiers( - std::iter::once(worktree_name) - .chain(iter_path_without_extension(rel_path.as_std_path())), - ) - } else { - Self::from_path(rel_path.as_std_path()) - } - } - - pub fn from_path(path: &Path) -> Self { - Self::from_identifiers(iter_path_without_extension(path)) - } - - fn add_hash(&mut self, hash: u64) { - self.table - .entry( - hash, - |entry: &OccurrenceEntry| entry.hash == hash, - |entry| entry.hash, - ) - .and_modify(|entry| entry.count += 1) - .or_insert(OccurrenceEntry { hash, count: 1 }); - self.total_count += 1; - } - - fn contains_hash(&self, hash: u64) -> bool { - self.get_count(hash) != 0 - } - - fn get_count(&self, hash: u64) -> usize { - self.table - .find(hash, |entry| entry.hash == hash) - .map(|entry| entry.count) - .unwrap_or(0) - } -} - -fn iter_path_without_extension(path: &Path) -> impl Iterator> { - let last_component: Option> = path.file_stem().map(|stem| stem.to_string_lossy()); - let mut path_components = path.components(); - path_components.next_back(); - path_components - .map(|component| component.as_os_str().to_string_lossy()) - .chain(last_component) -} - -pub fn fx_hash(data: &T) -> u64 { - let mut hasher = collections::FxHasher::default(); - data.hash(&mut hasher); - hasher.finish() -} - -// Splits camelcase / snakecase / kebabcase / pascalcase -// -// TODO: Make this more efficient / elegant. 
-fn split_identifier(identifier: &str) -> Vec<&str> { - let mut parts = Vec::new(); - let mut start = 0; - let chars: Vec = identifier.chars().collect(); - - if chars.is_empty() { - return parts; - } - - let mut i = 0; - while i < chars.len() { - let ch = chars[i]; - - // Handle explicit delimiters (underscore and hyphen) - if ch == '_' || ch == '-' { - if i > start { - parts.push(&identifier[start..i]); - } - start = i + 1; - i += 1; - continue; - } - - // Handle camelCase and PascalCase transitions - if i > 0 && i < chars.len() { - let prev_char = chars[i - 1]; - - // Transition from lowercase/digit to uppercase - if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() { - parts.push(&identifier[start..i]); - start = i; - } - // Handle sequences like "XMLParser" -> ["XML", "Parser"] - else if i + 1 < chars.len() - && ch.is_uppercase() - && chars[i + 1].is_lowercase() - && prev_char.is_uppercase() - { - parts.push(&identifier[start..i]); - start = i; - } - } - - i += 1; - } - - // Add the last part if there's any remaining - if start < identifier.len() { - parts.push(&identifier[start..]); - } - - // Filter out empty strings - parts.into_iter().filter(|s| !s.is_empty()).collect() -} - -pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 { - if set_a.table.len() > set_b.table.len() { - std::mem::swap(&mut set_a, &mut set_b); - } - let intersection = set_a - .table - .iter() - .filter(|entry| set_b.contains_hash(entry.hash)) - .count(); - let union = set_a.table.len() + set_b.table.len() - intersection; - intersection as f32 / union as f32 -} - -// TODO -#[allow(dead_code)] -pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 { - if set_a.table.len() > set_b.table.len() { - std::mem::swap(&mut set_a, &mut set_b); - } - let intersection = set_a - .table - .iter() - .filter(|entry| set_b.contains_hash(entry.hash)) - .count(); - intersection as f32 / 
set_a.table.len() as f32 -} - -// TODO -#[allow(dead_code)] -pub fn weighted_jaccard_similarity<'a>( - mut set_a: &'a Occurrences, - mut set_b: &'a Occurrences, -) -> f32 { - if set_a.table.len() > set_b.table.len() { - std::mem::swap(&mut set_a, &mut set_b); - } - - let mut numerator = 0; - let mut denominator_a = 0; - let mut used_count_b = 0; - for entry_a in set_a.table.iter() { - let count_a = entry_a.count; - let count_b = set_b.get_count(entry_a.hash); - numerator += count_a.min(count_b); - denominator_a += count_a.max(count_b); - used_count_b += count_b; - } - - let denominator = denominator_a + (set_b.total_count - used_count_b); - if denominator == 0 { - 0.0 - } else { - numerator as f32 / denominator as f32 - } -} - -pub fn weighted_overlap_coefficient<'a>( - mut set_a: &'a Occurrences, - mut set_b: &'a Occurrences, -) -> f32 { - if set_a.table.len() > set_b.table.len() { - std::mem::swap(&mut set_a, &mut set_b); - } - - let mut numerator = 0; - for entry_a in set_a.table.iter() { - let count_a = entry_a.count; - let count_b = set_b.get_count(entry_a.hash); - numerator += count_a.min(count_b); - } - - let denominator = set_a.total_count.min(set_b.total_count); - if denominator == 0 { - 0.0 - } else { - numerator as f32 / denominator as f32 - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_split_identifier() { - assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]); - assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]); - assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]); - assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]); - assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]); - } - - #[test] - fn test_similarity_functions() { - // 10 identifier parts, 8 unique - // Repeats: 2 "outline", 2 "items" - let set_a = Occurrences::within_string( - "let mut outline_items = query_outline_items(&language, &tree, &source);", - ); - // 14 identifier parts, 11 unique - // 
Repeats: 2 "outline", 2 "language", 2 "tree" - let set_b = Occurrences::within_string( - "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec {", - ); - - // 6 overlaps: "outline", "items", "query", "language", "tree", "source" - // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str" - assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0)); - - // Numerator is one more than before due to both having 2 "outline". - // Denominator is the same except for 3 more due to the non-overlapping duplicates - assert_eq!( - weighted_jaccard_similarity(&set_a, &set_b), - 7.0 / (7.0 + 7.0 + 3.0) - ); - - // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8. - assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0); - - // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of - // the smaller set, 10. - assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0); - } - - #[test] - fn test_iter_path_without_extension() { - let mut iter = iter_path_without_extension(Path::new("")); - assert_eq!(iter.next(), None); - - let iter = iter_path_without_extension(Path::new("foo")); - assert_eq!(iter.collect::>(), ["foo"]); - - let iter = iter_path_without_extension(Path::new("foo/bar.txt")); - assert_eq!(iter.collect::>(), ["foo", "bar"]); - - let iter = iter_path_without_extension(Path::new("foo/bar/baz.txt")); - assert_eq!(iter.collect::>(), ["foo", "bar", "baz"]); - } -} diff --git a/crates/edit_prediction_types/Cargo.toml b/crates/edit_prediction_types/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..ebc09680e1dcf99dc21e1714eca6a9db337f4a90 --- /dev/null +++ b/crates/edit_prediction_types/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "edit_prediction_types" +version = "0.1.0" +edition.workspace = true +publish.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = 
"src/edit_prediction_types.rs" + +[dependencies] +client.workspace = true +gpui.workspace = true +language.workspace = true diff --git a/crates/zeta/LICENSE-GPL b/crates/edit_prediction_types/LICENSE-GPL similarity index 100% rename from crates/zeta/LICENSE-GPL rename to crates/edit_prediction_types/LICENSE-GPL diff --git a/crates/edit_prediction_types/src/edit_prediction_types.rs b/crates/edit_prediction_types/src/edit_prediction_types.rs new file mode 100644 index 0000000000000000000000000000000000000000..1f63b8626d15dfd3e2cba78aacb50505186da01c --- /dev/null +++ b/crates/edit_prediction_types/src/edit_prediction_types.rs @@ -0,0 +1,298 @@ +use std::{ops::Range, sync::Arc}; + +use client::EditPredictionUsage; +use gpui::{App, Context, Entity, SharedString}; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt}; + +// TODO: Find a better home for `Direction`. +// +// This should live in an ancestor crate of `editor` and `edit_prediction`, +// but at time of writing there isn't an obvious spot. +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum Direction { + Prev, + Next, +} + +#[derive(Clone)] +pub enum EditPrediction { + /// Edits within the buffer that requested the prediction + Local { + id: Option, + edits: Vec<(Range, Arc)>, + edit_preview: Option, + }, + /// Jump to a different file from the one that requested the prediction + Jump { + id: Option, + snapshot: language::BufferSnapshot, + target: language::Anchor, + }, +} + +pub enum DataCollectionState { + /// The provider doesn't support data collection. + Unsupported, + /// Data collection is enabled. + Enabled { is_project_open_source: bool }, + /// Data collection is disabled or unanswered. + Disabled { is_project_open_source: bool }, +} + +impl DataCollectionState { + pub fn is_supported(&self) -> bool { + !matches!(self, DataCollectionState::Unsupported) + } + + pub fn is_enabled(&self) -> bool { + matches!(self, DataCollectionState::Enabled { .. 
}) + } + + pub fn is_project_open_source(&self) -> bool { + match self { + Self::Enabled { + is_project_open_source, + } + | Self::Disabled { + is_project_open_source, + } => *is_project_open_source, + _ => false, + } + } +} + +pub trait EditPredictionDelegate: 'static + Sized { + fn name() -> &'static str; + fn display_name() -> &'static str; + fn show_predictions_in_menu() -> bool; + fn show_tab_accept_marker() -> bool { + false + } + fn supports_jump_to_edit() -> bool { + true + } + + fn data_collection_state(&self, _cx: &App) -> DataCollectionState { + DataCollectionState::Unsupported + } + + fn usage(&self, _cx: &App) -> Option { + None + } + + fn toggle_data_collection(&mut self, _cx: &mut App) {} + fn is_enabled( + &self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &App, + ) -> bool; + fn is_refreshing(&self, cx: &App) -> bool; + fn refresh( + &mut self, + buffer: Entity, + cursor_position: language::Anchor, + debounce: bool, + cx: &mut Context, + ); + fn cycle( + &mut self, + buffer: Entity, + cursor_position: language::Anchor, + direction: Direction, + cx: &mut Context, + ); + fn accept(&mut self, cx: &mut Context); + fn discard(&mut self, cx: &mut Context); + fn did_show(&mut self, _cx: &mut Context) {} + fn suggest( + &mut self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut Context, + ) -> Option; +} + +pub trait EditPredictionDelegateHandle { + fn name(&self) -> &'static str; + fn display_name(&self) -> &'static str; + fn is_enabled( + &self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &App, + ) -> bool; + fn show_predictions_in_menu(&self) -> bool; + fn show_tab_accept_marker(&self) -> bool; + fn supports_jump_to_edit(&self) -> bool; + fn data_collection_state(&self, cx: &App) -> DataCollectionState; + fn usage(&self, cx: &App) -> Option; + fn toggle_data_collection(&self, cx: &mut App); + fn is_refreshing(&self, cx: &App) -> bool; + fn refresh( + &self, + buffer: Entity, + cursor_position: 
language::Anchor, + debounce: bool, + cx: &mut App, + ); + fn cycle( + &self, + buffer: Entity, + cursor_position: language::Anchor, + direction: Direction, + cx: &mut App, + ); + fn did_show(&self, cx: &mut App); + fn accept(&self, cx: &mut App); + fn discard(&self, cx: &mut App); + fn suggest( + &self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut App, + ) -> Option; +} + +impl EditPredictionDelegateHandle for Entity +where + T: EditPredictionDelegate, +{ + fn name(&self) -> &'static str { + T::name() + } + + fn display_name(&self) -> &'static str { + T::display_name() + } + + fn show_predictions_in_menu(&self) -> bool { + T::show_predictions_in_menu() + } + + fn show_tab_accept_marker(&self) -> bool { + T::show_tab_accept_marker() + } + + fn supports_jump_to_edit(&self) -> bool { + T::supports_jump_to_edit() + } + + fn data_collection_state(&self, cx: &App) -> DataCollectionState { + self.read(cx).data_collection_state(cx) + } + + fn usage(&self, cx: &App) -> Option { + self.read(cx).usage(cx) + } + + fn toggle_data_collection(&self, cx: &mut App) { + self.update(cx, |this, cx| this.toggle_data_collection(cx)) + } + + fn is_enabled( + &self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &App, + ) -> bool { + self.read(cx).is_enabled(buffer, cursor_position, cx) + } + + fn is_refreshing(&self, cx: &App) -> bool { + self.read(cx).is_refreshing(cx) + } + + fn refresh( + &self, + buffer: Entity, + cursor_position: language::Anchor, + debounce: bool, + cx: &mut App, + ) { + self.update(cx, |this, cx| { + this.refresh(buffer, cursor_position, debounce, cx) + }) + } + + fn cycle( + &self, + buffer: Entity, + cursor_position: language::Anchor, + direction: Direction, + cx: &mut App, + ) { + self.update(cx, |this, cx| { + this.cycle(buffer, cursor_position, direction, cx) + }) + } + + fn accept(&self, cx: &mut App) { + self.update(cx, |this, cx| this.accept(cx)) + } + + fn discard(&self, cx: &mut App) { + self.update(cx, |this, cx| 
this.discard(cx)) + } + + fn did_show(&self, cx: &mut App) { + self.update(cx, |this, cx| this.did_show(cx)) + } + + fn suggest( + &self, + buffer: &Entity, + cursor_position: language::Anchor, + cx: &mut App, + ) -> Option { + self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx)) + } +} + +/// Returns edits updated based on user edits since the old snapshot. None is returned if any user +/// edit is not a prefix of a predicted insertion. +pub fn interpolate_edits( + old_snapshot: &BufferSnapshot, + new_snapshot: &BufferSnapshot, + current_edits: &[(Range, Arc)], +) -> Option, Arc)>> { + let mut edits = Vec::new(); + + let mut model_edits = current_edits.iter().peekable(); + for user_edit in new_snapshot.edits_since::(&old_snapshot.version) { + while let Some((model_old_range, _)) = model_edits.peek() { + let model_old_range = model_old_range.to_offset(old_snapshot); + if model_old_range.end < user_edit.old.start { + let (model_old_range, model_new_text) = model_edits.next().unwrap(); + edits.push((model_old_range.clone(), model_new_text.clone())); + } else { + break; + } + } + + if let Some((model_old_range, model_new_text)) = model_edits.peek() { + let model_old_offset_range = model_old_range.to_offset(old_snapshot); + if user_edit.old == model_old_offset_range { + let user_new_text = new_snapshot + .text_for_range(user_edit.new.clone()) + .collect::(); + + if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) { + if !model_suffix.is_empty() { + let anchor = old_snapshot.anchor_after(user_edit.old.end); + edits.push((anchor..anchor, model_suffix.into())); + } + + model_edits.next(); + continue; + } + } + } + + return None; + } + + edits.extend(model_edits.cloned()); + + if edits.is_empty() { None } else { Some(edits) } +} diff --git a/crates/edit_prediction_button/Cargo.toml b/crates/edit_prediction_ui/Cargo.toml similarity index 77% rename from crates/edit_prediction_button/Cargo.toml rename to 
crates/edit_prediction_ui/Cargo.toml index d336cf66926d37ab7c0ebb1d5aa5a2172342350c..fb846f35d76ae2f6478ef675f246e4d06fe5f469 100644 --- a/crates/edit_prediction_button/Cargo.toml +++ b/crates/edit_prediction_ui/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "edit_prediction_button" +name = "edit_prediction_ui" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,35 +9,43 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/edit_prediction_button.rs" +path = "src/edit_prediction_ui.rs" doctest = false [dependencies] anyhow.workspace = true +buffer_diff.workspace = true client.workspace = true cloud_llm_client.workspace = true +cloud_zeta2_prompt.workspace = true codestral.workspace = true +command_palette_hooks.workspace = true copilot.workspace = true edit_prediction.workspace = true +edit_prediction_types.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true +futures.workspace = true gpui.workspace = true indoc.workspace = true language.workspace = true +markdown.workspace = true +menu.workspace = true +multi_buffer.workspace = true paths.workspace = true project.workspace = true regex.workspace = true settings.workspace = true supermaven.workspace = true telemetry.workspace = true +text.workspace = true +theme.workspace = true ui.workspace = true ui_input.workspace = true -menu.workspace = true util.workspace = true workspace.workspace = true zed_actions.workspace = true -zeta.workspace = true [dev-dependencies] copilot = { workspace = true, features = ["test-support"] } diff --git a/crates/zeta2_tools/LICENSE-GPL b/crates/edit_prediction_ui/LICENSE-GPL similarity index 100% rename from crates/zeta2_tools/LICENSE-GPL rename to crates/edit_prediction_ui/LICENSE-GPL diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_ui/src/edit_prediction_button.rs similarity index 88% rename from crates/edit_prediction_button/src/edit_prediction_button.rs rename 
to crates/edit_prediction_ui/src/edit_prediction_button.rs index 8ce8441859b7cc747a2b566dedd913e58259969d..04c7614689c5fdc076ab0aa9c4b4fe7d68e2f582 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_button.rs @@ -1,16 +1,14 @@ -mod sweep_api_token_modal; - -pub use sweep_api_token_modal::SweepApiKeyModal; - use anyhow::Result; use client::{Client, UserStore, zed_urls}; use cloud_llm_client::UsageLimit; -use codestral::CodestralCompletionProvider; +use codestral::CodestralEditPredictionDelegate; use copilot::{Copilot, Status}; +use edit_prediction::{MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag}; +use edit_prediction_types::EditPredictionDelegateHandle; use editor::{ Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll, }; -use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag}; +use feature_flags::FeatureFlagAppExt; use fs::Fs; use gpui::{ Action, Animation, AnimationExt, App, AsyncWindowContext, Corner, Entity, FocusHandle, @@ -25,6 +23,7 @@ use language::{ use project::DisableAiSettings; use regex::Regex; use settings::{ + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore, update_settings_file, @@ -44,7 +43,11 @@ use workspace::{ notifications::NotificationId, }; use zed_actions::OpenBrowser; -use zeta::{RateCompletions, SweepFeatureFlag, Zeta2FeatureFlag}; + +use crate::{ + ExternalProviderApiKeyModal, RatePredictions, + rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag, +}; actions!( edit_prediction, @@ -67,7 +70,7 @@ pub struct EditPredictionButton { editor_focus_handle: Option, language: Option>, file: Option>, - edit_prediction_provider: Option>, + edit_prediction_provider: Option>, fs: Arc, user_store: Entity, popover_menu_handle: PopoverMenuHandle, @@ -244,7 +247,7 @@ impl 
Render for EditPredictionButton { EditPredictionProvider::Codestral => { let enabled = self.editor_enabled.unwrap_or(true); - let has_api_key = CodestralCompletionProvider::has_api_key(cx); + let has_api_key = CodestralEditPredictionDelegate::has_api_key(cx); let fs = self.fs.clone(); let this = cx.weak_entity(); @@ -309,24 +312,34 @@ impl Render for EditPredictionButton { provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => { let enabled = self.editor_enabled.unwrap_or(true); - let is_sweep = matches!( - provider, - EditPredictionProvider::Experimental( - EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME - ) - ); + let ep_icon; + let mut missing_token = false; - let sweep_missing_token = is_sweep - && !zeta::Zeta::try_global(cx) - .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token()); - - let zeta_icon = match (is_sweep, enabled) { - (true, _) => IconName::SweepAi, - (false, true) => IconName::ZedPredict, - (false, false) => IconName::ZedPredictDisabled, + match provider { + EditPredictionProvider::Experimental( + EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, + ) => { + ep_icon = IconName::SweepAi; + missing_token = edit_prediction::EditPredictionStore::try_global(cx) + .is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token()); + } + EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + ) => { + ep_icon = IconName::Inception; + missing_token = edit_prediction::EditPredictionStore::try_global(cx) + .is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token()); + } + _ => { + ep_icon = if enabled { + IconName::ZedPredict + } else { + IconName::ZedPredictDisabled + }; + } }; - if zeta::should_show_upsell_modal() { + if edit_prediction::should_show_upsell_modal() { let tooltip_meta = if self.user_store.read(cx).current_user().is_some() { "Choose a Plan" } else { @@ -334,7 +347,7 @@ impl Render for EditPredictionButton { }; return div().child( - 
IconButton::new("zed-predict-pending-button", zeta_icon) + IconButton::new("zed-predict-pending-button", ep_icon) .shape(IconButtonShape::Square) .indicator(Indicator::dot().color(Color::Muted)) .indicator_border_color(Some(cx.theme().colors().status_bar_background)) @@ -367,7 +380,7 @@ impl Render for EditPredictionButton { let show_editor_predictions = self.editor_show_predictions; let user = self.user_store.read(cx).current_user(); - let indicator_color = if sweep_missing_token { + let indicator_color = if missing_token { Some(Color::Error) } else if enabled && (!show_editor_predictions || over_limit) { Some(if over_limit { @@ -379,7 +392,7 @@ impl Render for EditPredictionButton { None }; - let icon_button = IconButton::new("zed-predict-pending-button", zeta_icon) + let icon_button = IconButton::new("zed-predict-pending-button", ep_icon) .shape(IconButtonShape::Square) .when_some(indicator_color, |this, color| { this.indicator(Indicator::dot().color(color)) @@ -419,13 +432,13 @@ impl Render for EditPredictionButton { let this = cx.weak_entity(); - let mut popover_menu = PopoverMenu::new("zeta") + let mut popover_menu = PopoverMenu::new("edit-prediction") .when(user.is_some(), |popover_menu| { let this = this.clone(); popover_menu.menu(move |window, cx| { this.update(cx, |this, cx| { - this.build_zeta_context_menu(provider, window, cx) + this.build_edit_prediction_context_menu(provider, window, cx) }) .ok() }) @@ -485,7 +498,7 @@ impl EditPredictionButton { cx.observe_global::(move |_, cx| cx.notify()) .detach(); - CodestralCompletionProvider::ensure_api_key_loaded(client.http_client(), cx); + CodestralEditPredictionDelegate::ensure_api_key_loaded(client.http_client(), cx); Self { editor_subscription: None, @@ -520,7 +533,7 @@ impl EditPredictionButton { } } - if CodestralCompletionProvider::has_api_key(cx) { + if CodestralEditPredictionDelegate::has_api_key(cx) { providers.push(EditPredictionProvider::Codestral); } @@ -530,6 +543,12 @@ impl EditPredictionButton 
{ )); } + if cx.has_flag::() { + providers.push(EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + )); + } + if cx.has_flag::() { providers.push(EditPredictionProvider::Experimental( EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, @@ -599,8 +618,8 @@ impl EditPredictionButton { EditPredictionProvider::Experimental( EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME, ) => { - let has_api_token = zeta::Zeta::try_global(cx) - .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token()); + let has_api_token = edit_prediction::EditPredictionStore::try_global(cx) + .map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token()); let should_open_modal = !has_api_token || is_current; @@ -626,7 +645,66 @@ impl EditPredictionButton { if let Some(workspace) = window.root::().flatten() { workspace.update(cx, |workspace, cx| { workspace.toggle_modal(window, cx, |window, cx| { - SweepApiKeyModal::new(window, cx) + ExternalProviderApiKeyModal::new( + window, + cx, + |api_key, store, cx| { + store + .sweep_ai + .set_api_token(api_key, cx) + .detach_and_log_err(cx); + }, + ) + }); + }); + }; + } else { + set_completion_provider(fs.clone(), cx, provider); + } + }); + + menu.item(entry) + } + EditPredictionProvider::Experimental( + EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME, + ) => { + let has_api_token = edit_prediction::EditPredictionStore::try_global(cx) + .map_or(false, |ep_store| ep_store.read(cx).has_mercury_api_token()); + + let should_open_modal = !has_api_token || is_current; + + let entry = if has_api_token { + ContextMenuEntry::new("Mercury") + .toggleable(IconPosition::Start, is_current) + } else { + ContextMenuEntry::new("Mercury") + .icon(IconName::XCircle) + .icon_color(Color::Error) + .documentation_aside( + DocumentationSide::Left, + DocumentationEdge::Bottom, + |_| { + Label::new("Click to configure your Mercury API token") + .into_any_element() + }, + ) + }; + + let entry = entry.handler(move |window, cx| { + if 
should_open_modal { + if let Some(workspace) = window.root::().flatten() { + workspace.update(cx, |workspace, cx| { + workspace.toggle_modal(window, cx, |window, cx| { + ExternalProviderApiKeyModal::new( + window, + cx, + |api_key, store, cx| { + store + .mercury + .set_api_token(api_key, cx) + .detach_and_log_err(cx); + }, + ) }); }); }; @@ -947,8 +1025,8 @@ impl EditPredictionButton { ) .context(editor_focus_handle) .when( - cx.has_flag::(), - |this| this.action("Rate Completions", RateCompletions.boxed_clone()), + cx.has_flag::(), + |this| this.action("Rate Predictions", RatePredictions.boxed_clone()), ); } @@ -1016,7 +1094,7 @@ impl EditPredictionButton { }) } - fn build_zeta_context_menu( + fn build_edit_prediction_context_menu( &self, provider: EditPredictionProvider, window: &mut Window, @@ -1105,9 +1183,33 @@ impl EditPredictionButton { .separator(); } - let menu = self.build_language_settings_menu(menu, window, cx); - let menu = self.add_provider_switching_section(menu, provider, cx); + menu = self.build_language_settings_menu(menu, window, cx); + + if cx.has_flag::() { + let settings = all_language_settings(None, cx); + let context_retrieval = settings.edit_predictions.use_context; + menu = menu.separator().header("Context Retrieval").item( + ContextMenuEntry::new("Enable Context Retrieval") + .toggleable(IconPosition::Start, context_retrieval) + .action(workspace::ToggleEditPrediction.boxed_clone()) + .handler({ + let fs = self.fs.clone(); + move |_, cx| { + update_settings_file(fs.clone(), cx, move |settings, _| { + settings + .project + .all_languages + .features + .get_or_insert_default() + .experimental_edit_prediction_context_retrieval = + Some(!context_retrieval) + }); + } + }), + ); + } + menu = self.add_provider_switching_section(menu, provider, cx); menu }) } diff --git a/crates/edit_prediction_ui/src/edit_prediction_context_view.rs b/crates/edit_prediction_ui/src/edit_prediction_context_view.rs new file mode 100644 index 
0000000000000000000000000000000000000000..0e343fe3fcb8ed7bb6bf3e8481927344d63133ee --- /dev/null +++ b/crates/edit_prediction_ui/src/edit_prediction_context_view.rs @@ -0,0 +1,389 @@ +use std::{ + any::TypeId, + collections::VecDeque, + ops::Add, + sync::Arc, + time::{Duration, Instant}, +}; + +use anyhow::Result; +use client::{Client, UserStore}; +use editor::{Editor, PathKey}; +use futures::StreamExt as _; +use gpui::{ + Animation, AnimationExt, App, AppContext as _, Context, Entity, EventEmitter, FocusHandle, + Focusable, InteractiveElement as _, IntoElement as _, ParentElement as _, SharedString, + Styled as _, Task, TextAlign, Window, actions, div, pulsating_between, +}; +use multi_buffer::MultiBuffer; +use project::Project; +use text::OffsetRangeExt; +use ui::{ + ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName, + StyledTypography as _, h_flex, v_flex, +}; + +use edit_prediction::{ + ContextRetrievalFinishedDebugEvent, ContextRetrievalStartedDebugEvent, DebugEvent, + EditPredictionStore, +}; +use workspace::Item; + +pub struct EditPredictionContextView { + empty_focus_handle: FocusHandle, + project: Entity, + store: Entity, + runs: VecDeque, + current_ix: usize, + _update_task: Task>, +} + +#[derive(Debug)] +struct RetrievalRun { + editor: Entity, + started_at: Instant, + metadata: Vec<(&'static str, SharedString)>, + finished_at: Option, +} + +actions!( + dev, + [ + /// Go to the previous context retrieval run + EditPredictionContextGoBack, + /// Go to the next context retrieval run + EditPredictionContextGoForward + ] +); + +impl EditPredictionContextView { + pub fn new( + project: Entity, + client: &Arc, + user_store: &Entity, + window: &mut gpui::Window, + cx: &mut Context, + ) -> Self { + let store = EditPredictionStore::global(client, user_store, cx); + + let mut debug_rx = store.update(cx, |store, _| store.debug_info()); + let _update_task = cx.spawn_in(window, async move |this, cx| { + while let Some(event) = 
debug_rx.next().await { + this.update_in(cx, |this, window, cx| { + this.handle_store_event(event, window, cx) + })?; + } + Ok(()) + }); + + Self { + empty_focus_handle: cx.focus_handle(), + project, + runs: VecDeque::new(), + current_ix: 0, + store, + _update_task, + } + } + + fn handle_store_event( + &mut self, + event: DebugEvent, + window: &mut gpui::Window, + cx: &mut Context, + ) { + match event { + DebugEvent::ContextRetrievalStarted(info) => { + if info.project_entity_id == self.project.entity_id() { + self.handle_context_retrieval_started(info, window, cx); + } + } + DebugEvent::ContextRetrievalFinished(info) => { + if info.project_entity_id == self.project.entity_id() { + self.handle_context_retrieval_finished(info, window, cx); + } + } + DebugEvent::EditPredictionRequested(_) => {} + } + } + + fn handle_context_retrieval_started( + &mut self, + info: ContextRetrievalStartedDebugEvent, + window: &mut Window, + cx: &mut Context, + ) { + if self + .runs + .back() + .is_some_and(|run| run.finished_at.is_none()) + { + self.runs.pop_back(); + } + + let multibuffer = cx.new(|_| MultiBuffer::new(language::Capability::ReadOnly)); + let editor = cx + .new(|cx| Editor::for_multibuffer(multibuffer, Some(self.project.clone()), window, cx)); + + if self.runs.len() == 32 { + self.runs.pop_front(); + } + + self.runs.push_back(RetrievalRun { + editor, + started_at: info.timestamp, + finished_at: None, + metadata: Vec::new(), + }); + + cx.notify(); + } + + fn handle_context_retrieval_finished( + &mut self, + info: ContextRetrievalFinishedDebugEvent, + window: &mut Window, + cx: &mut Context, + ) { + let Some(run) = self.runs.back_mut() else { + return; + }; + + run.finished_at = Some(info.timestamp); + run.metadata = info.metadata; + + let project = self.project.clone(); + let related_files = self + .store + .read(cx) + .context_for_project(&self.project, cx) + .to_vec(); + + let editor = run.editor.clone(); + let multibuffer = run.editor.read(cx).buffer().clone(); + + if 
self.current_ix + 2 == self.runs.len() { + self.current_ix += 1; + } + + cx.spawn_in(window, async move |this, cx| { + let mut paths = Vec::new(); + for related_file in related_files { + let (buffer, point_ranges): (_, Vec<_>) = + if let Some(buffer) = related_file.buffer.upgrade() { + let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; + + ( + buffer, + related_file + .excerpts + .iter() + .map(|excerpt| excerpt.anchor_range.to_point(&snapshot)) + .collect(), + ) + } else { + ( + project + .update(cx, |project, cx| { + project.open_buffer(related_file.path.clone(), cx) + })? + .await?, + related_file + .excerpts + .iter() + .map(|excerpt| excerpt.point_range.clone()) + .collect(), + ) + }; + cx.update(|_, cx| { + let path = PathKey::for_buffer(&buffer, cx); + paths.push((path, buffer, point_ranges)); + })?; + } + + multibuffer.update(cx, |multibuffer, cx| { + multibuffer.clear(cx); + + for (path, buffer, ranges) in paths { + multibuffer.set_excerpts_for_path(path, buffer, ranges, 0, cx); + } + })?; + + editor.update_in(cx, |editor, window, cx| { + editor.move_to_beginning(&Default::default(), window, cx); + })?; + + this.update(cx, |_, cx| cx.notify()) + }) + .detach(); + } + + fn handle_go_back( + &mut self, + _: &EditPredictionContextGoBack, + window: &mut Window, + cx: &mut Context, + ) { + self.current_ix = self.current_ix.saturating_sub(1); + cx.focus_self(window); + cx.notify(); + } + + fn handle_go_forward( + &mut self, + _: &EditPredictionContextGoForward, + window: &mut Window, + cx: &mut Context, + ) { + self.current_ix = self + .current_ix + .add(1) + .min(self.runs.len().saturating_sub(1)); + cx.focus_self(window); + cx.notify(); + } + + fn render_informational_footer( + &self, + cx: &mut Context<'_, EditPredictionContextView>, + ) -> ui::Div { + let run = &self.runs[self.current_ix]; + let new_run_started = self + .runs + .back() + .map_or(false, |latest_run| latest_run.finished_at.is_none()); + + h_flex() + .p_2() + .w_full() + 
.font_buffer(cx) + .text_xs() + .border_t_1() + .gap_2() + .child(v_flex().h_full().flex_1().child({ + let t0 = run.started_at; + let mut table = ui::Table::<2>::new().width(ui::px(300.)).no_ui_font(); + for (key, value) in &run.metadata { + table = table.row([key.into_any_element(), value.clone().into_any_element()]) + } + table = table.row([ + "Total Time".into_any_element(), + format!("{} ms", (run.finished_at.unwrap_or(t0) - t0).as_millis()) + .into_any_element(), + ]); + table + })) + .child( + v_flex().h_full().text_align(TextAlign::Right).child( + h_flex() + .justify_end() + .child( + IconButton::new("go-back", IconName::ChevronLeft) + .disabled(self.current_ix == 0 || self.runs.len() < 2) + .tooltip(ui::Tooltip::for_action_title( + "Go to previous run", + &EditPredictionContextGoBack, + )) + .on_click(cx.listener(|this, _, window, cx| { + this.handle_go_back(&EditPredictionContextGoBack, window, cx); + })), + ) + .child( + div() + .child(format!("{}/{}", self.current_ix + 1, self.runs.len())) + .map(|this| { + if new_run_started { + this.with_animation( + "pulsating-count", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 0.8)), + |label, delta| label.opacity(delta), + ) + .into_any_element() + } else { + this.into_any_element() + } + }), + ) + .child( + IconButton::new("go-forward", IconName::ChevronRight) + .disabled(self.current_ix + 1 == self.runs.len()) + .tooltip(ui::Tooltip::for_action_title( + "Go to next run", + &EditPredictionContextGoBack, + )) + .on_click(cx.listener(|this, _, window, cx| { + this.handle_go_forward( + &EditPredictionContextGoForward, + window, + cx, + ); + })), + ), + ), + ) + } +} + +impl Focusable for EditPredictionContextView { + fn focus_handle(&self, cx: &App) -> FocusHandle { + self.runs + .get(self.current_ix) + .map(|run| run.editor.read(cx).focus_handle(cx)) + .unwrap_or_else(|| self.empty_focus_handle.clone()) + } +} + +impl EventEmitter<()> for EditPredictionContextView {} + 
+impl Item for EditPredictionContextView { + type Event = (); + + fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { + "Edit Prediction Context".into() + } + + fn buffer_kind(&self, _cx: &App) -> workspace::item::ItemBufferKind { + workspace::item::ItemBufferKind::Multibuffer + } + + fn act_as_type<'a>( + &'a self, + type_id: TypeId, + self_handle: &'a Entity, + _: &'a App, + ) -> Option { + if type_id == TypeId::of::() { + Some(self_handle.clone().into()) + } else if type_id == TypeId::of::() { + Some(self.runs.get(self.current_ix)?.editor.clone().into()) + } else { + None + } + } +} + +impl gpui::Render for EditPredictionContextView { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl ui::IntoElement { + v_flex() + .key_context("EditPredictionContext") + .on_action(cx.listener(Self::handle_go_back)) + .on_action(cx.listener(Self::handle_go_forward)) + .size_full() + .map(|this| { + if self.runs.is_empty() { + this.child( + v_flex() + .size_full() + .justify_center() + .items_center() + .child("No retrieval runs yet"), + ) + } else { + this.child(self.runs[self.current_ix].editor.clone()) + .child(self.render_informational_footer(cx)) + } + }) + } +} diff --git a/crates/edit_prediction_ui/src/edit_prediction_ui.rs b/crates/edit_prediction_ui/src/edit_prediction_ui.rs new file mode 100644 index 0000000000000000000000000000000000000000..c177b5233c33feb4f5ff82f60bf3fb6981cf3ee8 --- /dev/null +++ b/crates/edit_prediction_ui/src/edit_prediction_ui.rs @@ -0,0 +1,128 @@ +mod edit_prediction_button; +mod edit_prediction_context_view; +mod external_provider_api_token_modal; +mod rate_prediction_modal; + +use std::any::{Any as _, TypeId}; + +use command_palette_hooks::CommandPaletteFilter; +use edit_prediction::{ResetOnboarding, Zeta2FeatureFlag}; +use edit_prediction_context_view::EditPredictionContextView; +use feature_flags::FeatureFlagAppExt as _; +use gpui::actions; +use project::DisableAiSettings; +use 
rate_prediction_modal::RatePredictionsModal; +use settings::{Settings as _, SettingsStore}; +use ui::{App, prelude::*}; +use workspace::{SplitDirection, Workspace}; + +pub use edit_prediction_button::{EditPredictionButton, ToggleMenu}; +pub use external_provider_api_token_modal::ExternalProviderApiKeyModal; + +use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag; + +actions!( + dev, + [ + /// Opens the edit prediction context view. + OpenEditPredictionContextView, + ] +); + +actions!( + edit_prediction, + [ + /// Opens the rate completions modal. + RatePredictions, + ] +); + +pub fn init(cx: &mut App) { + feature_gate_predict_edits_actions(cx); + + cx.observe_new(move |workspace: &mut Workspace, _, _cx| { + workspace.register_action(|workspace, _: &RatePredictions, window, cx| { + if cx.has_flag::() { + RatePredictionsModal::toggle(workspace, window, cx); + } + }); + + workspace.register_action_renderer(|div, _, _, cx| { + let has_flag = cx.has_flag::(); + div.when(has_flag, |div| { + div.on_action(cx.listener( + move |workspace, _: &OpenEditPredictionContextView, window, cx| { + let project = workspace.project(); + workspace.split_item( + SplitDirection::Right, + Box::new(cx.new(|cx| { + EditPredictionContextView::new( + project.clone(), + workspace.client(), + workspace.user_store(), + window, + cx, + ) + })), + window, + cx, + ); + }, + )) + }) + }); + }) + .detach(); +} + +fn feature_gate_predict_edits_actions(cx: &mut App) { + let rate_completion_action_types = [TypeId::of::()]; + let reset_onboarding_action_types = [TypeId::of::()]; + let all_action_types = [ + TypeId::of::(), + TypeId::of::(), + zed_actions::OpenZedPredictOnboarding.type_id(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + TypeId::of::(), + ]; + + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_action_types(&rate_completion_action_types); + filter.hide_action_types(&reset_onboarding_action_types); + 
filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]); + }); + + cx.observe_global::(move |cx| { + let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai; + let has_feature_flag = cx.has_flag::(); + + CommandPaletteFilter::update_global(cx, |filter, _cx| { + if is_ai_disabled { + filter.hide_action_types(&all_action_types); + } else if has_feature_flag { + filter.show_action_types(&rate_completion_action_types); + } else { + filter.hide_action_types(&rate_completion_action_types); + } + }); + }) + .detach(); + + cx.observe_flag::(move |is_enabled, cx| { + if !DisableAiSettings::get_global(cx).disable_ai { + if is_enabled { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.show_action_types(&rate_completion_action_types); + }); + } else { + CommandPaletteFilter::update_global(cx, |filter, _cx| { + filter.hide_action_types(&rate_completion_action_types); + }); + } + } + }) + .detach(); +} diff --git a/crates/edit_prediction_button/src/sweep_api_token_modal.rs b/crates/edit_prediction_ui/src/external_provider_api_token_modal.rs similarity index 68% rename from crates/edit_prediction_button/src/sweep_api_token_modal.rs rename to crates/edit_prediction_ui/src/external_provider_api_token_modal.rs index ab2102f25a2a7291644ca67ab3c89fd47da7ac0a..bc312836e9fdd30237156ac532a055d1e23a2589 100644 --- a/crates/edit_prediction_button/src/sweep_api_token_modal.rs +++ b/crates/edit_prediction_ui/src/external_provider_api_token_modal.rs @@ -1,23 +1,29 @@ +use edit_prediction::EditPredictionStore; use gpui::{ DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, Render, }; use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*}; use ui_input::InputField; use workspace::ModalView; -use zeta::Zeta; -pub struct SweepApiKeyModal { +pub struct ExternalProviderApiKeyModal { api_key_input: Entity, focus_handle: FocusHandle, + on_confirm: Box, &mut EditPredictionStore, &mut App)>, } 
-impl SweepApiKeyModal { - pub fn new(window: &mut Window, cx: &mut Context) -> Self { - let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token")); +impl ExternalProviderApiKeyModal { + pub fn new( + window: &mut Window, + cx: &mut Context, + on_confirm: impl Fn(Option, &mut EditPredictionStore, &mut App) + 'static, + ) -> Self { + let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your API key")); Self { api_key_input, focus_handle: cx.focus_handle(), + on_confirm: Box::new(on_confirm), } } @@ -29,39 +35,35 @@ impl SweepApiKeyModal { let api_key = self.api_key_input.read(cx).text(cx); let api_key = (!api_key.trim().is_empty()).then_some(api_key); - if let Some(zeta) = Zeta::try_global(cx) { - zeta.update(cx, |zeta, cx| { - zeta.sweep_ai - .set_api_token(api_key, cx) - .detach_and_log_err(cx); - }); + if let Some(ep_store) = EditPredictionStore::try_global(cx) { + ep_store.update(cx, |ep_store, cx| (self.on_confirm)(api_key, ep_store, cx)) } cx.emit(DismissEvent); } } -impl EventEmitter for SweepApiKeyModal {} +impl EventEmitter for ExternalProviderApiKeyModal {} -impl ModalView for SweepApiKeyModal {} +impl ModalView for ExternalProviderApiKeyModal {} -impl Focusable for SweepApiKeyModal { +impl Focusable for ExternalProviderApiKeyModal { fn focus_handle(&self, _cx: &App) -> FocusHandle { self.focus_handle.clone() } } -impl Render for SweepApiKeyModal { +impl Render for ExternalProviderApiKeyModal { fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { v_flex() - .key_context("SweepApiKeyModal") + .key_context("ExternalApiKeyModal") .on_action(cx.listener(Self::cancel)) .on_action(cx.listener(Self::confirm)) .elevation_2(cx) .w(px(400.)) .p_4() .gap_3() - .child(Headline::new("Sweep API Token").size(HeadlineSize::Small)) + .child(Headline::new("API Token").size(HeadlineSize::Small)) .child(self.api_key_input.clone()) .child( h_flex() diff --git a/crates/zeta/src/rate_prediction_modal.rs 
b/crates/edit_prediction_ui/src/rate_prediction_modal.rs similarity index 95% rename from crates/zeta/src/rate_prediction_modal.rs rename to crates/edit_prediction_ui/src/rate_prediction_modal.rs index 0cceb86608ed609122c81d406c71280894789e88..8e754b33dc18c5be60bc052c33aa08cdcb980acb 100644 --- a/crates/zeta/src/rate_prediction_modal.rs +++ b/crates/edit_prediction_ui/src/rate_prediction_modal.rs @@ -1,7 +1,8 @@ -use crate::{EditPrediction, EditPredictionRating, Zeta}; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use cloud_zeta2_prompt::write_codeblock; +use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore}; use editor::{Editor, ExcerptRange, MultiBuffer}; +use feature_flags::FeatureFlag; use gpui::{ App, BorderStyle, DismissEvent, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable, Length, StyleRefinement, TextStyleRefinement, Window, actions, prelude::*, @@ -9,9 +10,7 @@ use gpui::{ use language::{LanguageRegistry, Point, language_settings}; use markdown::{Markdown, MarkdownStyle}; use settings::Settings as _; -use std::fmt::Write; -use std::sync::Arc; -use std::time::Duration; +use std::{fmt::Write, sync::Arc, time::Duration}; use theme::ThemeSettings; use ui::{KeyBinding, List, ListItem, ListItemSpacing, Tooltip, prelude::*}; use workspace::{ModalView, Workspace}; @@ -34,8 +33,14 @@ actions!( ] ); +pub struct PredictEditsRatePredictionsFeatureFlag; + +impl FeatureFlag for PredictEditsRatePredictionsFeatureFlag { + const NAME: &'static str = "predict-edits-rate-completions"; +} + pub struct RatePredictionsModal { - zeta: Entity, + ep_store: Entity, language_registry: Arc, active_prediction: Option, selected_index: usize, @@ -68,10 +73,10 @@ impl RatePredictionView { impl RatePredictionsModal { pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context) { - if let Some(zeta) = Zeta::try_global(cx) { + if let Some(ep_store) = EditPredictionStore::try_global(cx) { let language_registry = 
workspace.app_state().languages.clone(); workspace.toggle_modal(window, cx, |window, cx| { - RatePredictionsModal::new(zeta, language_registry, window, cx) + RatePredictionsModal::new(ep_store, language_registry, window, cx) }); telemetry::event!("Rate Prediction Modal Open", source = "Edit Prediction"); @@ -79,15 +84,15 @@ impl RatePredictionsModal { } pub fn new( - zeta: Entity, + ep_store: Entity, language_registry: Arc, window: &mut Window, cx: &mut Context, ) -> Self { - let subscription = cx.observe(&zeta, |_, _, cx| cx.notify()); + let subscription = cx.observe(&ep_store, |_, _, cx| cx.notify()); Self { - zeta, + ep_store, language_registry, selected_index: 0, focus_handle: cx.focus_handle(), @@ -113,7 +118,7 @@ impl RatePredictionsModal { self.selected_index += 1; self.selected_index = usize::min( self.selected_index, - self.zeta.read(cx).shown_predictions().count(), + self.ep_store.read(cx).shown_predictions().count(), ); cx.notify(); } @@ -130,7 +135,7 @@ impl RatePredictionsModal { fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context) { let next_index = self - .zeta + .ep_store .read(cx) .shown_predictions() .skip(self.selected_index) @@ -146,11 +151,11 @@ impl RatePredictionsModal { } fn select_prev_edit(&mut self, _: &PreviousEdit, _: &mut Window, cx: &mut Context) { - let zeta = self.zeta.read(cx); - let completions_len = zeta.shown_completions_len(); + let ep_store = self.ep_store.read(cx); + let completions_len = ep_store.shown_completions_len(); let prev_index = self - .zeta + .ep_store .read(cx) .shown_predictions() .rev() @@ -173,7 +178,7 @@ impl RatePredictionsModal { } fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context) { - self.selected_index = self.zeta.read(cx).shown_completions_len() - 1; + self.selected_index = self.ep_store.read(cx).shown_completions_len() - 1; cx.notify(); } @@ -183,9 +188,9 @@ impl RatePredictionsModal { window: &mut Window, cx: &mut Context, ) { - 
self.zeta.update(cx, |zeta, cx| { + self.ep_store.update(cx, |ep_store, cx| { if let Some(active) = &self.active_prediction { - zeta.rate_prediction( + ep_store.rate_prediction( &active.prediction, EditPredictionRating::Positive, active.feedback_editor.read(cx).text(cx), @@ -216,8 +221,8 @@ impl RatePredictionsModal { return; } - self.zeta.update(cx, |zeta, cx| { - zeta.rate_prediction( + self.ep_store.update(cx, |ep_store, cx| { + ep_store.rate_prediction( &active.prediction, EditPredictionRating::Negative, active.feedback_editor.read(cx).text(cx), @@ -254,7 +259,7 @@ impl RatePredictionsModal { cx: &mut Context, ) { let completion = self - .zeta + .ep_store .read(cx) .shown_predictions() .skip(self.selected_index) @@ -267,7 +272,7 @@ impl RatePredictionsModal { fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { let completion = self - .zeta + .ep_store .read(cx) .shown_predictions() .skip(self.selected_index) @@ -288,7 +293,7 @@ impl RatePredictionsModal { // Avoid resetting completion rating if it's already selected. 
if let Some(prediction) = prediction { self.selected_index = self - .zeta + .ep_store .read(cx) .shown_predictions() .enumerate() @@ -376,7 +381,7 @@ impl RatePredictionsModal { &included_file.path, &included_file.excerpts, if included_file.path == prediction.inputs.cursor_path { - cursor_insertions + cursor_insertions.as_slice() } else { &[] }, @@ -564,7 +569,7 @@ impl RatePredictionsModal { let border_color = cx.theme().colors().border; let bg_color = cx.theme().colors().editor_background; - let rated = self.zeta.read(cx).is_prediction_rated(&completion_id); + let rated = self.ep_store.read(cx).is_prediction_rated(&completion_id); let feedback_empty = active_prediction .feedback_editor .read(cx) @@ -715,7 +720,7 @@ impl RatePredictionsModal { } fn render_shown_completions(&self, cx: &Context) -> impl Iterator { - self.zeta + self.ep_store .read(cx) .shown_predictions() .cloned() @@ -725,7 +730,7 @@ impl RatePredictionsModal { .active_prediction .as_ref() .is_some_and(|selected| selected.prediction.id == completion.id); - let rated = self.zeta.read(cx).is_prediction_rated(&completion.id); + let rated = self.ep_store.read(cx).is_prediction_rated(&completion.id); let (icon_name, icon_color, tooltip_text) = match (rated, completion.edits.is_empty()) { diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index 2aa02e293dd44d5fdd920ac8cd98da48b9c1a912..f3ed28ab05c6839a478ebbf6c81ca5e66fc372e3 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -49,7 +49,7 @@ fs.workspace = true git.workspace = true gpui.workspace = true indoc.workspace = true -edit_prediction.workspace = true +edit_prediction_types.workspace = true itertools.workspace = true language.workspace = true linkify.workspace = true @@ -84,6 +84,8 @@ tree-sitter-html = { workspace = true, optional = true } tree-sitter-rust = { workspace = true, optional = true } tree-sitter-typescript = { workspace = true, optional = true } tree-sitter-python = { workspace = true, optional = 
true } +ztracing.workspace = true +tracing.workspace = true unicode-segmentation.workspace = true unicode-script.workspace = true unindent = { workspace = true, optional = true } @@ -94,6 +96,7 @@ uuid.workspace = true vim_mode_setting.workspace = true workspace.workspace = true zed_actions.workspace = true +zlog.workspace = true [dev-dependencies] criterion.workspace = true @@ -118,6 +121,7 @@ tree-sitter-rust.workspace = true tree-sitter-typescript.workspace = true tree-sitter-yaml.workspace = true tree-sitter-bash.workspace = true +tree-sitter-md.workspace = true unindent.workspace = true util = { workspace = true, features = ["test-support"] } workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/editor/src/actions.rs b/crates/editor/src/actions.rs index e823b06910fba67a38754ece6ad746f5f632e613..0bf236769efd17c26ebd55d4dda010d454ba71e4 100644 --- a/crates/editor/src/actions.rs +++ b/crates/editor/src/actions.rs @@ -453,8 +453,6 @@ actions!( CollapseAllDiffHunks, /// Expands macros recursively at cursor position. ExpandMacroRecursively, - /// Finds all references to the symbol at cursor. - FindAllReferences, /// Finds the next match in the search. FindNextMatch, /// Finds the previous match in the search. @@ -827,3 +825,20 @@ actions!( WrapSelectionsInTag ] ); + +/// Finds all references to the symbol at cursor. 
+#[derive(PartialEq, Clone, Deserialize, JsonSchema, Action)] +#[action(namespace = editor)] +#[serde(deny_unknown_fields)] +pub struct FindAllReferences { + #[serde(default = "default_true")] + pub always_open_multibuffer: bool, +} + +impl Default for FindAllReferences { + fn default() -> Self { + Self { + always_open_multibuffer: true, + } + } +} diff --git a/crates/editor/src/display_map/block_map.rs b/crates/editor/src/display_map/block_map.rs index a6744041971101dafa4957523fb7a16250f38996..79d06dbf8b6e27cffffd47d6637c83eadcb00424 100644 --- a/crates/editor/src/display_map/block_map.rs +++ b/crates/editor/src/display_map/block_map.rs @@ -164,6 +164,7 @@ impl BlockPlacement { } impl BlockPlacement { + #[ztracing::instrument(skip_all)] fn cmp(&self, other: &Self, buffer: &MultiBufferSnapshot) -> Ordering { self.start() .cmp(other.start(), buffer) @@ -171,6 +172,7 @@ impl BlockPlacement { .then_with(|| self.tie_break().cmp(&other.tie_break())) } + #[ztracing::instrument(skip_all)] fn to_wrap_row(&self, wrap_snapshot: &WrapSnapshot) -> Option> { let buffer_snapshot = wrap_snapshot.buffer_snapshot(); match self { @@ -474,6 +476,7 @@ pub struct BlockRows<'a> { } impl BlockMap { + #[ztracing::instrument(skip_all)] pub fn new( wrap_snapshot: WrapSnapshot, buffer_header_height: u32, @@ -503,6 +506,7 @@ impl BlockMap { map } + #[ztracing::instrument(skip_all)] pub fn read(&self, wrap_snapshot: WrapSnapshot, edits: WrapPatch) -> BlockMapReader<'_> { self.sync(&wrap_snapshot, edits); *self.wrap_snapshot.borrow_mut() = wrap_snapshot.clone(); @@ -518,13 +522,17 @@ impl BlockMap { } } + #[ztracing::instrument(skip_all)] pub fn write(&mut self, wrap_snapshot: WrapSnapshot, edits: WrapPatch) -> BlockMapWriter<'_> { self.sync(&wrap_snapshot, edits); *self.wrap_snapshot.borrow_mut() = wrap_snapshot; BlockMapWriter(self) } + #[ztracing::instrument(skip_all, fields(edits))] fn sync(&self, wrap_snapshot: &WrapSnapshot, mut edits: WrapPatch) { + let _timer = 
zlog::time!("BlockMap::sync").warn_if_gt(std::time::Duration::from_millis(50)); + let buffer = wrap_snapshot.buffer_snapshot(); // Handle changing the last excerpt if it is empty. @@ -784,6 +792,7 @@ impl BlockMap { *transforms = new_transforms; } + #[ztracing::instrument(skip_all)] pub fn replace_blocks(&mut self, mut renderers: HashMap) { for block in &mut self.custom_blocks { if let Some(render) = renderers.remove(&block.id) { @@ -793,6 +802,7 @@ impl BlockMap { } /// Guarantees that `wrap_row_for` is called with points in increasing order. + #[ztracing::instrument(skip_all)] fn header_and_footer_blocks<'a, R, T>( &'a self, buffer: &'a multi_buffer::MultiBufferSnapshot, @@ -880,6 +890,7 @@ impl BlockMap { }) } + #[ztracing::instrument(skip_all)] fn sort_blocks(blocks: &mut Vec<(BlockPlacement, Block)>) { blocks.sort_unstable_by(|(placement_a, block_a), (placement_b, block_b)| { placement_a @@ -1016,6 +1027,7 @@ impl DerefMut for BlockMapReader<'_> { } impl BlockMapReader<'_> { + #[ztracing::instrument(skip_all)] pub fn row_for_block(&self, block_id: CustomBlockId) -> Option { let block = self.blocks.iter().find(|block| block.id == block_id)?; let buffer_row = block @@ -1054,6 +1066,7 @@ impl BlockMapReader<'_> { } impl BlockMapWriter<'_> { + #[ztracing::instrument(skip_all)] pub fn insert( &mut self, blocks: impl IntoIterator>, @@ -1120,6 +1133,7 @@ impl BlockMapWriter<'_> { ids } + #[ztracing::instrument(skip_all)] pub fn resize(&mut self, mut heights: HashMap) { let wrap_snapshot = &*self.0.wrap_snapshot.borrow(); let buffer = wrap_snapshot.buffer_snapshot(); @@ -1172,6 +1186,7 @@ impl BlockMapWriter<'_> { self.0.sync(wrap_snapshot, edits); } + #[ztracing::instrument(skip_all)] pub fn remove(&mut self, block_ids: HashSet) { let wrap_snapshot = &*self.0.wrap_snapshot.borrow(); let buffer = wrap_snapshot.buffer_snapshot(); @@ -1217,6 +1232,7 @@ impl BlockMapWriter<'_> { self.0.sync(wrap_snapshot, edits); } + #[ztracing::instrument(skip_all)] pub fn 
remove_intersecting_replace_blocks( &mut self, ranges: impl IntoIterator>, @@ -1239,6 +1255,7 @@ impl BlockMapWriter<'_> { self.0.buffers_with_disabled_headers.insert(buffer_id); } + #[ztracing::instrument(skip_all)] pub fn fold_buffers( &mut self, buffer_ids: impl IntoIterator, @@ -1248,6 +1265,7 @@ impl BlockMapWriter<'_> { self.fold_or_unfold_buffers(true, buffer_ids, multi_buffer, cx); } + #[ztracing::instrument(skip_all)] pub fn unfold_buffers( &mut self, buffer_ids: impl IntoIterator, @@ -1257,6 +1275,7 @@ impl BlockMapWriter<'_> { self.fold_or_unfold_buffers(false, buffer_ids, multi_buffer, cx); } + #[ztracing::instrument(skip_all)] fn fold_or_unfold_buffers( &mut self, fold: bool, @@ -1292,6 +1311,7 @@ impl BlockMapWriter<'_> { self.0.sync(&wrap_snapshot, edits); } + #[ztracing::instrument(skip_all)] fn blocks_intersecting_buffer_range( &self, range: Range, @@ -1326,6 +1346,7 @@ impl BlockMapWriter<'_> { impl BlockSnapshot { #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn text(&self) -> String { self.chunks( BlockRow(0)..self.transforms.summary().output_rows, @@ -1337,6 +1358,7 @@ impl BlockSnapshot { .collect() } + #[ztracing::instrument(skip_all)] pub(crate) fn chunks<'a>( &'a self, rows: Range, @@ -1378,6 +1400,7 @@ impl BlockSnapshot { } } + #[ztracing::instrument(skip_all)] pub(super) fn row_infos(&self, start_row: BlockRow) -> BlockRows<'_> { let mut cursor = self.transforms.cursor::>(()); cursor.seek(&start_row, Bias::Right); @@ -1399,6 +1422,7 @@ impl BlockSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn blocks_in_range( &self, rows: Range, @@ -1432,6 +1456,7 @@ impl BlockSnapshot { }) } + #[ztracing::instrument(skip_all)] pub(crate) fn sticky_header_excerpt(&self, position: f64) -> Option> { let top_row = position as u32; let mut cursor = self.transforms.cursor::(()); @@ -1455,6 +1480,7 @@ impl BlockSnapshot { None } + #[ztracing::instrument(skip_all)] pub fn block_for_id(&self, block_id: BlockId) -> Option { let buffer = 
self.wrap_snapshot.buffer_snapshot(); let wrap_point = match block_id { @@ -1491,6 +1517,7 @@ impl BlockSnapshot { None } + #[ztracing::instrument(skip_all)] pub fn max_point(&self) -> BlockPoint { let row = self .transforms @@ -1500,10 +1527,12 @@ impl BlockSnapshot { BlockPoint::new(row, self.line_len(row)) } + #[ztracing::instrument(skip_all)] pub fn longest_row(&self) -> BlockRow { self.transforms.summary().longest_row } + #[ztracing::instrument(skip_all)] pub fn longest_row_in_range(&self, range: Range) -> BlockRow { let mut cursor = self.transforms.cursor::>(()); cursor.seek(&range.start, Bias::Right); @@ -1555,6 +1584,7 @@ impl BlockSnapshot { longest_row } + #[ztracing::instrument(skip_all)] pub(super) fn line_len(&self, row: BlockRow) -> u32 { let (start, _, item) = self.transforms @@ -1574,11 +1604,13 @@ impl BlockSnapshot { } } + #[ztracing::instrument(skip_all)] pub(super) fn is_block_line(&self, row: BlockRow) -> bool { let (_, _, item) = self.transforms.find::((), &row, Bias::Right); item.is_some_and(|t| t.block.is_some()) } + #[ztracing::instrument(skip_all)] pub(super) fn is_folded_buffer_header(&self, row: BlockRow) -> bool { let (_, _, item) = self.transforms.find::((), &row, Bias::Right); let Some(transform) = item else { @@ -1587,6 +1619,7 @@ impl BlockSnapshot { matches!(transform.block, Some(Block::FoldedBuffer { .. 
})) } + #[ztracing::instrument(skip_all)] pub(super) fn is_line_replaced(&self, row: MultiBufferRow) -> bool { let wrap_point = self .wrap_snapshot @@ -1602,6 +1635,7 @@ impl BlockSnapshot { }) } + #[ztracing::instrument(skip_all)] pub fn clip_point(&self, point: BlockPoint, bias: Bias) -> BlockPoint { let mut cursor = self.transforms.cursor::>(()); cursor.seek(&BlockRow(point.row), Bias::Right); @@ -1663,6 +1697,7 @@ impl BlockSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn to_block_point(&self, wrap_point: WrapPoint) -> BlockPoint { let (start, _, item) = self.transforms.find::, _>( (), @@ -1684,6 +1719,7 @@ impl BlockSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn to_wrap_point(&self, block_point: BlockPoint, bias: Bias) -> WrapPoint { let (start, end, item) = self.transforms.find::, _>( (), @@ -1719,6 +1755,7 @@ impl BlockSnapshot { impl BlockChunks<'_> { /// Go to the next transform + #[ztracing::instrument(skip_all)] fn advance(&mut self) { self.input_chunk = Chunk::default(); self.transforms.next(); @@ -1759,6 +1796,7 @@ pub struct StickyHeaderExcerpt<'a> { impl<'a> Iterator for BlockChunks<'a> { type Item = Chunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.output_row >= self.max_output_row { return None; @@ -1858,6 +1896,7 @@ impl<'a> Iterator for BlockChunks<'a> { impl Iterator for BlockRows<'_> { type Item = RowInfo; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.started { self.output_row.0 += 1; @@ -1960,14 +1999,17 @@ impl DerefMut for BlockContext<'_, '_> { } impl CustomBlock { + #[ztracing::instrument(skip_all)] pub fn render(&self, cx: &mut BlockContext) -> AnyElement { self.render.lock()(cx) } + #[ztracing::instrument(skip_all)] pub fn start(&self) -> Anchor { *self.placement.start() } + #[ztracing::instrument(skip_all)] pub fn end(&self) -> Anchor { *self.placement.end() } diff --git a/crates/editor/src/display_map/crease_map.rs 
b/crates/editor/src/display_map/crease_map.rs index a68c27886733d34a60ef0ce2ef4006b92b679db9..8f4a3781f4f335f1a3e61ec5a19818661a7c6ea5 100644 --- a/crates/editor/src/display_map/crease_map.rs +++ b/crates/editor/src/display_map/crease_map.rs @@ -19,6 +19,7 @@ pub struct CreaseMap { } impl CreaseMap { + #[ztracing::instrument(skip_all)] pub fn new(snapshot: &MultiBufferSnapshot) -> Self { CreaseMap { snapshot: CreaseSnapshot::new(snapshot), @@ -40,11 +41,13 @@ impl CreaseSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn creases(&self) -> impl Iterator)> { self.creases.iter().map(|item| (item.id, &item.crease)) } /// Returns the first Crease starting on the specified buffer row. + #[ztracing::instrument(skip_all)] pub fn query_row<'a>( &'a self, row: MultiBufferRow, @@ -69,6 +72,7 @@ impl CreaseSnapshot { None } + #[ztracing::instrument(skip_all)] pub fn creases_in_range<'a>( &'a self, range: Range, @@ -95,6 +99,7 @@ impl CreaseSnapshot { }) } + #[ztracing::instrument(skip_all)] pub fn crease_items_with_offsets( &self, snapshot: &MultiBufferSnapshot, @@ -156,6 +161,7 @@ pub struct CreaseMetadata { } impl Crease { + #[ztracing::instrument(skip_all)] pub fn simple(range: Range, placeholder: FoldPlaceholder) -> Self { Crease::Inline { range, @@ -166,6 +172,7 @@ impl Crease { } } + #[ztracing::instrument(skip_all)] pub fn block(range: Range, height: u32, style: BlockStyle, render: RenderBlock) -> Self { Self::Block { range, @@ -177,6 +184,7 @@ impl Crease { } } + #[ztracing::instrument(skip_all)] pub fn inline( range: Range, placeholder: FoldPlaceholder, @@ -216,6 +224,7 @@ impl Crease { } } + #[ztracing::instrument(skip_all)] pub fn with_metadata(self, metadata: CreaseMetadata) -> Self { match self { Crease::Inline { @@ -235,6 +244,7 @@ impl Crease { } } + #[ztracing::instrument(skip_all)] pub fn range(&self) -> &Range { match self { Crease::Inline { range, .. 
} => range, @@ -242,6 +252,7 @@ impl Crease { } } + #[ztracing::instrument(skip_all)] pub fn metadata(&self) -> Option<&CreaseMetadata> { match self { Self::Inline { metadata, .. } => metadata.as_ref(), @@ -287,6 +298,7 @@ impl CreaseMap { self.snapshot.clone() } + #[ztracing::instrument(skip_all)] pub fn insert( &mut self, creases: impl IntoIterator>, @@ -312,6 +324,7 @@ impl CreaseMap { new_ids } + #[ztracing::instrument(skip_all)] pub fn remove( &mut self, ids: impl IntoIterator, @@ -379,6 +392,7 @@ impl sum_tree::Summary for ItemSummary { impl sum_tree::Item for CreaseItem { type Summary = ItemSummary; + #[ztracing::instrument(skip_all)] fn summary(&self, _cx: &MultiBufferSnapshot) -> Self::Summary { ItemSummary { range: self.crease.range().clone(), @@ -388,12 +402,14 @@ impl sum_tree::Item for CreaseItem { /// Implements `SeekTarget` for `Range` to enable seeking within a `SumTree` of `CreaseItem`s. impl SeekTarget<'_, ItemSummary, ItemSummary> for Range { + #[ztracing::instrument(skip_all)] fn cmp(&self, cursor_location: &ItemSummary, snapshot: &MultiBufferSnapshot) -> Ordering { AnchorRangeExt::cmp(self, &cursor_location.range, snapshot) } } impl SeekTarget<'_, ItemSummary, ItemSummary> for Anchor { + #[ztracing::instrument(skip_all)] fn cmp(&self, other: &ItemSummary, snapshot: &MultiBufferSnapshot) -> Ordering { self.cmp(&other.range.start, snapshot) } @@ -461,6 +477,7 @@ mod test { } #[gpui::test] + #[ztracing::instrument(skip_all)] fn test_creases_in_range(cx: &mut App) { let text = "line1\nline2\nline3\nline4\nline5\nline6\nline7"; let buffer = MultiBuffer::build_simple(text, cx); diff --git a/crates/editor/src/display_map/custom_highlights.rs b/crates/editor/src/display_map/custom_highlights.rs index a40d1adc82f4bc79308eaec901586232e9e2e5c2..c9202280bf957fac4d729bab558f686c0f62e774 100644 --- a/crates/editor/src/display_map/custom_highlights.rs +++ b/crates/editor/src/display_map/custom_highlights.rs @@ -30,6 +30,7 @@ struct HighlightEndpoint { } 
impl<'a> CustomHighlightsChunks<'a> { + #[ztracing::instrument(skip_all)] pub fn new( range: Range, language_aware: bool, @@ -51,6 +52,7 @@ impl<'a> CustomHighlightsChunks<'a> { } } + #[ztracing::instrument(skip_all)] pub fn seek(&mut self, new_range: Range) { self.highlight_endpoints = create_highlight_endpoints(&new_range, self.text_highlights, self.multibuffer_snapshot); @@ -108,6 +110,7 @@ fn create_highlight_endpoints( impl<'a> Iterator for CustomHighlightsChunks<'a> { type Item = Chunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { let mut next_highlight_endpoint = MultiBufferOffset(usize::MAX); while let Some(endpoint) = self.highlight_endpoints.peek().copied() { diff --git a/crates/editor/src/display_map/fold_map.rs b/crates/editor/src/display_map/fold_map.rs index 2d37dea38a93cc609a3a92064a6e35cdc76eb3da..bb0d6885acc2afd95e97fe9121acd2d0580554f3 100644 --- a/crates/editor/src/display_map/fold_map.rs +++ b/crates/editor/src/display_map/fold_map.rs @@ -99,6 +99,7 @@ impl FoldPoint { &mut self.0.column } + #[ztracing::instrument(skip_all)] pub fn to_inlay_point(self, snapshot: &FoldSnapshot) -> InlayPoint { let (start, _, _) = snapshot .transforms @@ -107,6 +108,7 @@ impl FoldPoint { InlayPoint(start.1.0 + overshoot) } + #[ztracing::instrument(skip_all)] pub fn to_offset(self, snapshot: &FoldSnapshot) -> FoldOffset { let (start, _, item) = snapshot .transforms @@ -138,6 +140,7 @@ impl<'a> sum_tree::Dimension<'a, TransformSummary> for FoldPoint { pub(crate) struct FoldMapWriter<'a>(&'a mut FoldMap); impl FoldMapWriter<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn fold( &mut self, ranges: impl IntoIterator, FoldPlaceholder)>, @@ -202,6 +205,7 @@ impl FoldMapWriter<'_> { } /// Removes any folds with the given ranges. + #[ztracing::instrument(skip_all)] pub(crate) fn remove_folds( &mut self, ranges: impl IntoIterator>, @@ -215,6 +219,7 @@ impl FoldMapWriter<'_> { } /// Removes any folds whose ranges intersect the given ranges. 
+ #[ztracing::instrument(skip_all)] pub(crate) fn unfold_intersecting( &mut self, ranges: impl IntoIterator>, @@ -225,6 +230,7 @@ impl FoldMapWriter<'_> { /// Removes any folds that intersect the given ranges and for which the given predicate /// returns true. + #[ztracing::instrument(skip_all)] fn remove_folds_with( &mut self, ranges: impl IntoIterator>, @@ -277,6 +283,7 @@ impl FoldMapWriter<'_> { (self.0.snapshot.clone(), edits) } + #[ztracing::instrument(skip_all)] pub(crate) fn update_fold_widths( &mut self, new_widths: impl IntoIterator, @@ -326,6 +333,7 @@ pub struct FoldMap { } impl FoldMap { + #[ztracing::instrument(skip_all)] pub fn new(inlay_snapshot: InlaySnapshot) -> (Self, FoldSnapshot) { let this = Self { snapshot: FoldSnapshot { @@ -350,6 +358,7 @@ impl FoldMap { (this, snapshot) } + #[ztracing::instrument(skip_all)] pub fn read( &mut self, inlay_snapshot: InlaySnapshot, @@ -360,6 +369,7 @@ impl FoldMap { (self.snapshot.clone(), edits) } + #[ztracing::instrument(skip_all)] pub(crate) fn write( &mut self, inlay_snapshot: InlaySnapshot, @@ -369,6 +379,7 @@ impl FoldMap { (FoldMapWriter(self), snapshot, edits) } + #[ztracing::instrument(skip_all)] fn check_invariants(&self) { if cfg!(test) { assert_eq!( @@ -398,6 +409,7 @@ impl FoldMap { } } + #[ztracing::instrument(skip_all)] fn sync( &mut self, inlay_snapshot: InlaySnapshot, @@ -645,6 +657,7 @@ impl FoldSnapshot { &self.inlay_snapshot.buffer } + #[ztracing::instrument(skip_all)] fn fold_width(&self, fold_id: &FoldId) -> Option { self.fold_metadata_by_id.get(fold_id)?.width } @@ -665,6 +678,7 @@ impl FoldSnapshot { self.folds.items(&self.inlay_snapshot.buffer).len() } + #[ztracing::instrument(skip_all)] pub fn text_summary_for_range(&self, range: Range) -> MBTextSummary { let mut summary = MBTextSummary::default(); @@ -718,6 +732,7 @@ impl FoldSnapshot { summary } + #[ztracing::instrument(skip_all)] pub fn to_fold_point(&self, point: InlayPoint, bias: Bias) -> FoldPoint { let (start, end, item) = self 
.transforms @@ -734,6 +749,7 @@ impl FoldSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn fold_point_cursor(&self) -> FoldPointCursor<'_> { let cursor = self .transforms @@ -741,10 +757,12 @@ impl FoldSnapshot { FoldPointCursor { cursor } } + #[ztracing::instrument(skip_all)] pub fn len(&self) -> FoldOffset { FoldOffset(self.transforms.summary().output.len) } + #[ztracing::instrument(skip_all)] pub fn line_len(&self, row: u32) -> u32 { let line_start = FoldPoint::new(row, 0).to_offset(self).0; let line_end = if row >= self.max_point().row() { @@ -755,6 +773,7 @@ impl FoldSnapshot { (line_end - line_start) as u32 } + #[ztracing::instrument(skip_all)] pub fn row_infos(&self, start_row: u32) -> FoldRows<'_> { if start_row > self.transforms.summary().output.lines.row { panic!("invalid display row {}", start_row); @@ -777,6 +796,7 @@ impl FoldSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn max_point(&self) -> FoldPoint { FoldPoint(self.transforms.summary().output.lines) } @@ -786,6 +806,7 @@ impl FoldSnapshot { self.transforms.summary().output.longest_row } + #[ztracing::instrument(skip_all)] pub fn folds_in_range(&self, range: Range) -> impl Iterator where T: ToOffset, @@ -800,6 +821,7 @@ impl FoldSnapshot { }) } + #[ztracing::instrument(skip_all)] pub fn intersects_fold(&self, offset: T) -> bool where T: ToOffset, @@ -812,6 +834,7 @@ impl FoldSnapshot { item.is_some_and(|t| t.placeholder.is_some()) } + #[ztracing::instrument(skip_all)] pub fn is_line_folded(&self, buffer_row: MultiBufferRow) -> bool { let mut inlay_point = self .inlay_snapshot @@ -840,6 +863,7 @@ impl FoldSnapshot { } } + #[ztracing::instrument(skip_all)] pub(crate) fn chunks<'a>( &'a self, range: Range, @@ -884,6 +908,7 @@ impl FoldSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn chars_at(&self, start: FoldPoint) -> impl '_ + Iterator { self.chunks( start.to_offset(self)..self.len(), @@ -893,6 +918,7 @@ impl FoldSnapshot { .flat_map(|chunk| chunk.text.chars()) } + 
#[ztracing::instrument(skip_all)] pub fn chunks_at(&self, start: FoldPoint) -> FoldChunks<'_> { self.chunks( start.to_offset(self)..self.len(), @@ -902,6 +928,7 @@ impl FoldSnapshot { } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn clip_offset(&self, offset: FoldOffset, bias: Bias) -> FoldOffset { if offset > self.len() { self.len() @@ -910,6 +937,7 @@ impl FoldSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn clip_point(&self, point: FoldPoint, bias: Bias) -> FoldPoint { let (start, end, item) = self .transforms @@ -939,6 +967,7 @@ pub struct FoldPointCursor<'transforms> { } impl FoldPointCursor<'_> { + #[ztracing::instrument(skip_all)] pub fn map(&mut self, point: InlayPoint, bias: Bias) -> FoldPoint { let cursor = &mut self.cursor; if cursor.did_seek() { @@ -1267,6 +1296,7 @@ pub struct FoldRows<'a> { } impl FoldRows<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn seek(&mut self, row: u32) { let fold_point = FoldPoint::new(row, 0); self.cursor.seek(&fold_point, Bias::Left); @@ -1280,6 +1310,7 @@ impl FoldRows<'_> { impl Iterator for FoldRows<'_> { type Item = RowInfo; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { let mut traversed_fold = false; while self.fold_point > self.cursor.end().0 { @@ -1391,6 +1422,7 @@ pub struct FoldChunks<'a> { } impl FoldChunks<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn seek(&mut self, range: Range) { self.transform_cursor.seek(&range.start, Bias::Right); @@ -1425,6 +1457,7 @@ impl FoldChunks<'_> { impl<'a> Iterator for FoldChunks<'a> { type Item = Chunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.output_offset >= self.max_output_offset { return None; @@ -1524,6 +1557,7 @@ impl<'a> Iterator for FoldChunks<'a> { pub struct FoldOffset(pub MultiBufferOffset); impl FoldOffset { + #[ztracing::instrument(skip_all)] pub fn to_point(self, snapshot: &FoldSnapshot) -> FoldPoint { let (start, _, item) = snapshot .transforms @@ -1539,6 +1573,7 @@ impl 
FoldOffset { } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn to_inlay_offset(self, snapshot: &FoldSnapshot) -> InlayOffset { let (start, _, _) = snapshot .transforms diff --git a/crates/editor/src/display_map/inlay_map.rs b/crates/editor/src/display_map/inlay_map.rs index 73174c3018e1f76a16acbff3f4bad1c7af84da33..d85f761a82e2f466b6868c4ce28bcb3a4e6b061d 100644 --- a/crates/editor/src/display_map/inlay_map.rs +++ b/crates/editor/src/display_map/inlay_map.rs @@ -52,6 +52,7 @@ enum Transform { impl sum_tree::Item for Transform { type Summary = TransformSummary; + #[ztracing::instrument(skip_all)] fn summary(&self, _: ()) -> Self::Summary { match self { Transform::Isomorphic(summary) => TransformSummary { @@ -228,6 +229,7 @@ pub struct InlayChunk<'a> { } impl InlayChunks<'_> { + #[ztracing::instrument(skip_all)] pub fn seek(&mut self, new_range: Range) { self.transforms.seek(&new_range.start, Bias::Right); @@ -248,6 +250,7 @@ impl InlayChunks<'_> { impl<'a> Iterator for InlayChunks<'a> { type Item = InlayChunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.output_offset == self.max_output_offset { return None; @@ -441,6 +444,7 @@ impl<'a> Iterator for InlayChunks<'a> { } impl InlayBufferRows<'_> { + #[ztracing::instrument(skip_all)] pub fn seek(&mut self, row: u32) { let inlay_point = InlayPoint::new(row, 0); self.transforms.seek(&inlay_point, Bias::Left); @@ -465,6 +469,7 @@ impl InlayBufferRows<'_> { impl Iterator for InlayBufferRows<'_> { type Item = RowInfo; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { let buffer_row = if self.inlay_row == 0 { self.buffer_rows.next().unwrap() @@ -494,6 +499,7 @@ impl InlayPoint { } impl InlayMap { + #[ztracing::instrument(skip_all)] pub fn new(buffer: MultiBufferSnapshot) -> (Self, InlaySnapshot) { let version = 0; let snapshot = InlaySnapshot { @@ -511,6 +517,7 @@ impl InlayMap { ) } + #[ztracing::instrument(skip_all)] pub fn sync( &mut self, buffer_snapshot: 
MultiBufferSnapshot, @@ -643,6 +650,7 @@ impl InlayMap { } } + #[ztracing::instrument(skip_all)] pub fn splice( &mut self, to_remove: &[InlayId], @@ -693,11 +701,13 @@ impl InlayMap { (snapshot, edits) } + #[ztracing::instrument(skip_all)] pub fn current_inlays(&self) -> impl Iterator { self.inlays.iter() } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub(crate) fn randomly_mutate( &mut self, next_inlay_id: &mut usize, @@ -766,6 +776,7 @@ impl InlayMap { } impl InlaySnapshot { + #[ztracing::instrument(skip_all)] pub fn to_point(&self, offset: InlayOffset) -> InlayPoint { let (start, _, item) = self.transforms.find:: InlayOffset { InlayOffset(self.transforms.summary().output.len) } + #[ztracing::instrument(skip_all)] pub fn max_point(&self) -> InlayPoint { InlayPoint(self.transforms.summary().output.lines) } + #[ztracing::instrument(skip_all, fields(point))] pub fn to_offset(&self, point: InlayPoint) -> InlayOffset { let (start, _, item) = self .transforms @@ -817,6 +831,7 @@ impl InlaySnapshot { None => self.len(), } } + #[ztracing::instrument(skip_all)] pub fn to_buffer_point(&self, point: InlayPoint) -> Point { let (start, _, item) = self.transforms @@ -830,6 +845,7 @@ impl InlaySnapshot { None => self.buffer.max_point(), } } + #[ztracing::instrument(skip_all)] pub fn to_buffer_offset(&self, offset: InlayOffset) -> MultiBufferOffset { let (start, _, item) = self .transforms @@ -844,6 +860,7 @@ impl InlaySnapshot { } } + #[ztracing::instrument(skip_all)] pub fn to_inlay_offset(&self, offset: MultiBufferOffset) -> InlayOffset { let mut cursor = self .transforms @@ -880,10 +897,12 @@ impl InlaySnapshot { } } + #[ztracing::instrument(skip_all)] pub fn to_inlay_point(&self, point: Point) -> InlayPoint { self.inlay_point_cursor().map(point) } + #[ztracing::instrument(skip_all)] pub fn inlay_point_cursor(&self) -> InlayPointCursor<'_> { let cursor = self.transforms.cursor::>(()); InlayPointCursor { @@ -892,6 +911,7 @@ impl InlaySnapshot { } } + 
#[ztracing::instrument(skip_all)] pub fn clip_point(&self, mut point: InlayPoint, mut bias: Bias) -> InlayPoint { let mut cursor = self.transforms.cursor::>(()); cursor.seek(&point, Bias::Left); @@ -983,10 +1003,12 @@ impl InlaySnapshot { } } + #[ztracing::instrument(skip_all)] pub fn text_summary(&self) -> MBTextSummary { self.transforms.summary().output } + #[ztracing::instrument(skip_all)] pub fn text_summary_for_range(&self, range: Range) -> MBTextSummary { let mut summary = MBTextSummary::default(); @@ -1044,6 +1066,7 @@ impl InlaySnapshot { summary } + #[ztracing::instrument(skip_all)] pub fn row_infos(&self, row: u32) -> InlayBufferRows<'_> { let mut cursor = self.transforms.cursor::>(()); let inlay_point = InlayPoint::new(row, 0); @@ -1071,6 +1094,7 @@ impl InlaySnapshot { } } + #[ztracing::instrument(skip_all)] pub fn line_len(&self, row: u32) -> u32 { let line_start = self.to_offset(InlayPoint::new(row, 0)).0; let line_end = if row >= self.max_point().row() { @@ -1081,6 +1105,7 @@ impl InlaySnapshot { (line_end - line_start) as u32 } + #[ztracing::instrument(skip_all)] pub(crate) fn chunks<'a>( &'a self, range: Range, @@ -1115,12 +1140,14 @@ impl InlaySnapshot { } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn text(&self) -> String { self.chunks(Default::default()..self.len(), false, Highlights::default()) .map(|chunk| chunk.chunk.text) .collect() } + #[ztracing::instrument(skip_all)] fn check_invariants(&self) { #[cfg(any(debug_assertions, feature = "test-support"))] { @@ -1147,6 +1174,7 @@ pub struct InlayPointCursor<'transforms> { } impl InlayPointCursor<'_> { + #[ztracing::instrument(skip_all)] pub fn map(&mut self, point: Point) -> InlayPoint { let cursor = &mut self.cursor; if cursor.did_seek() { diff --git a/crates/editor/src/display_map/invisibles.rs b/crates/editor/src/display_map/invisibles.rs index 5622a659b7acf850d24f6a476b23b53d214d855d..90bd54ab2807bbef703ac29e4ac4eaf49bcf71fd 100644 --- a/crates/editor/src/display_map/invisibles.rs 
+++ b/crates/editor/src/display_map/invisibles.rs @@ -30,6 +30,7 @@ // ref: https://gist.github.com/ConradIrwin/f759e1fc29267143c4c7895aa495dca5?h=1 // ref: https://unicode.org/Public/emoji/13.0/emoji-test.txt // https://github.com/bits/UTF-8-Unicode-Test-Documents/blob/master/UTF-8_sequence_separated/utf8_sequence_0-0x10ffff_assigned_including-unprintable-asis.txt +#[ztracing::instrument(skip_all)] pub fn is_invisible(c: char) -> bool { if c <= '\u{1f}' { c != '\t' && c != '\n' && c != '\r' diff --git a/crates/editor/src/display_map/tab_map.rs b/crates/editor/src/display_map/tab_map.rs index 347d7732151e172812de1e0252ca8d65f4cdbb8b..4e768a477159820ea380aa48a123d103c0c2f6a2 100644 --- a/crates/editor/src/display_map/tab_map.rs +++ b/crates/editor/src/display_map/tab_map.rs @@ -20,6 +20,7 @@ const MAX_TABS: NonZeroU32 = NonZeroU32::new(SPACES.len() as u32).unwrap(); pub struct TabMap(TabSnapshot); impl TabMap { + #[ztracing::instrument(skip_all)] pub fn new(fold_snapshot: FoldSnapshot, tab_size: NonZeroU32) -> (Self, TabSnapshot) { let snapshot = TabSnapshot { fold_snapshot, @@ -36,6 +37,7 @@ impl TabMap { self.0.clone() } + #[ztracing::instrument(skip_all)] pub fn sync( &mut self, fold_snapshot: FoldSnapshot, @@ -176,10 +178,12 @@ impl std::ops::Deref for TabSnapshot { } impl TabSnapshot { + #[ztracing::instrument(skip_all)] pub fn buffer_snapshot(&self) -> &MultiBufferSnapshot { &self.fold_snapshot.inlay_snapshot.buffer } + #[ztracing::instrument(skip_all)] pub fn line_len(&self, row: u32) -> u32 { let max_point = self.max_point(); if row < max_point.row() { @@ -191,10 +195,12 @@ impl TabSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn text_summary(&self) -> TextSummary { self.text_summary_for_range(TabPoint::zero()..self.max_point()) } + #[ztracing::instrument(skip_all, fields(rows))] pub fn text_summary_for_range(&self, range: Range) -> TextSummary { let input_start = self.tab_point_to_fold_point(range.start, Bias::Left).0; let input_end = 
self.tab_point_to_fold_point(range.end, Bias::Right).0; @@ -234,6 +240,7 @@ impl TabSnapshot { } } + #[ztracing::instrument(skip_all)] pub(crate) fn chunks<'a>( &'a self, range: Range, @@ -276,11 +283,13 @@ impl TabSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn rows(&self, row: u32) -> fold_map::FoldRows<'_> { self.fold_snapshot.row_infos(row) } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn text(&self) -> String { self.chunks( TabPoint::zero()..self.max_point(), @@ -291,10 +300,12 @@ impl TabSnapshot { .collect() } + #[ztracing::instrument(skip_all)] pub fn max_point(&self) -> TabPoint { self.fold_point_to_tab_point(self.fold_snapshot.max_point()) } + #[ztracing::instrument(skip_all)] pub fn clip_point(&self, point: TabPoint, bias: Bias) -> TabPoint { self.fold_point_to_tab_point( self.fold_snapshot @@ -302,6 +313,7 @@ impl TabSnapshot { ) } + #[ztracing::instrument(skip_all)] pub fn fold_point_to_tab_point(&self, input: FoldPoint) -> TabPoint { let chunks = self.fold_snapshot.chunks_at(FoldPoint::new(input.row(), 0)); let tab_cursor = TabStopCursor::new(chunks); @@ -309,10 +321,12 @@ impl TabSnapshot { TabPoint::new(input.row(), expanded) } + #[ztracing::instrument(skip_all)] pub fn tab_point_cursor(&self) -> TabPointCursor<'_> { TabPointCursor { this: self } } + #[ztracing::instrument(skip_all)] pub fn tab_point_to_fold_point(&self, output: TabPoint, bias: Bias) -> (FoldPoint, u32, u32) { let chunks = self .fold_snapshot @@ -330,12 +344,14 @@ impl TabSnapshot { ) } + #[ztracing::instrument(skip_all)] pub fn point_to_tab_point(&self, point: Point, bias: Bias) -> TabPoint { let inlay_point = self.fold_snapshot.inlay_snapshot.to_inlay_point(point); let fold_point = self.fold_snapshot.to_fold_point(inlay_point, bias); self.fold_point_to_tab_point(fold_point) } + #[ztracing::instrument(skip_all)] pub fn tab_point_to_point(&self, point: TabPoint, bias: Bias) -> Point { let fold_point = self.tab_point_to_fold_point(point, bias).0; let inlay_point = 
fold_point.to_inlay_point(&self.fold_snapshot); @@ -344,6 +360,7 @@ impl TabSnapshot { .to_buffer_point(inlay_point) } + #[ztracing::instrument(skip_all)] fn expand_tabs<'a, I>(&self, mut cursor: TabStopCursor<'a, I>, column: u32) -> u32 where I: Iterator>, @@ -377,6 +394,7 @@ impl TabSnapshot { expanded_bytes + column.saturating_sub(collapsed_bytes) } + #[ztracing::instrument(skip_all)] fn collapse_tabs<'a, I>( &self, mut cursor: TabStopCursor<'a, I>, @@ -442,6 +460,7 @@ pub struct TabPointCursor<'this> { } impl TabPointCursor<'_> { + #[ztracing::instrument(skip_all)] pub fn map(&mut self, point: FoldPoint) -> TabPoint { self.this.fold_point_to_tab_point(point) } @@ -486,6 +505,7 @@ pub struct TextSummary { } impl<'a> From<&'a str> for TextSummary { + #[ztracing::instrument(skip_all)] fn from(text: &'a str) -> Self { let sum = text::TextSummary::from(text); @@ -500,6 +520,7 @@ impl<'a> From<&'a str> for TextSummary { } impl<'a> std::ops::AddAssign<&'a Self> for TextSummary { + #[ztracing::instrument(skip_all)] fn add_assign(&mut self, other: &'a Self) { let joined_chars = self.last_line_chars + other.first_line_chars; if joined_chars > self.longest_row_chars { @@ -541,6 +562,7 @@ pub struct TabChunks<'a> { } impl TabChunks<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn seek(&mut self, range: Range) { let (input_start, expanded_char_column, to_next_stop) = self .snapshot @@ -576,6 +598,7 @@ impl TabChunks<'_> { impl<'a> Iterator for TabChunks<'a> { type Item = Chunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.chunk.text.is_empty() { if let Some(chunk) = self.fold_chunks.next() { @@ -1452,6 +1475,7 @@ impl<'a, I> TabStopCursor<'a, I> where I: Iterator>, { + #[ztracing::instrument(skip_all)] fn new(chunks: impl IntoIterator, IntoIter = I>) -> Self { Self { chunks: chunks.into_iter(), @@ -1461,6 +1485,7 @@ where } } + #[ztracing::instrument(skip_all)] fn bytes_until_next_char(&self) -> Option { 
self.current_chunk.as_ref().and_then(|(chunk, idx)| { let mut idx = *idx; @@ -1482,6 +1507,7 @@ where }) } + #[ztracing::instrument(skip_all)] fn is_char_boundary(&self) -> bool { self.current_chunk .as_ref() @@ -1489,6 +1515,7 @@ where } /// distance: length to move forward while searching for the next tab stop + #[ztracing::instrument(skip_all)] fn seek(&mut self, distance: u32) -> Option { if distance == 0 { return None; diff --git a/crates/editor/src/display_map/wrap_map.rs b/crates/editor/src/display_map/wrap_map.rs index 20ef9391888e6a824b87fe5de2607500049904ff..51d5324c838dc7cb7f4df04b0e58577108aab6c8 100644 --- a/crates/editor/src/display_map/wrap_map.rs +++ b/crates/editor/src/display_map/wrap_map.rs @@ -86,6 +86,7 @@ pub struct WrapRows<'a> { } impl WrapRows<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn seek(&mut self, start_row: WrapRow) { self.transforms .seek(&WrapPoint::new(start_row, 0), Bias::Left); @@ -101,6 +102,7 @@ impl WrapRows<'_> { } impl WrapMap { + #[ztracing::instrument(skip_all)] pub fn new( tab_snapshot: TabSnapshot, font: Font, @@ -131,6 +133,7 @@ impl WrapMap { self.background_task.is_some() } + #[ztracing::instrument(skip_all)] pub fn sync( &mut self, tab_snapshot: TabSnapshot, @@ -150,6 +153,7 @@ impl WrapMap { (self.snapshot.clone(), mem::take(&mut self.edits_since_sync)) } + #[ztracing::instrument(skip_all)] pub fn set_font_with_size( &mut self, font: Font, @@ -167,6 +171,7 @@ impl WrapMap { } } + #[ztracing::instrument(skip_all)] pub fn set_wrap_width(&mut self, wrap_width: Option, cx: &mut Context) -> bool { if wrap_width == self.wrap_width { return false; @@ -177,6 +182,7 @@ impl WrapMap { true } + #[ztracing::instrument(skip_all)] fn rewrap(&mut self, cx: &mut Context) { self.background_task.take(); self.interpolated_edits.clear(); @@ -248,6 +254,7 @@ impl WrapMap { } } + #[ztracing::instrument(skip_all)] fn flush_edits(&mut self, cx: &mut Context) { if !self.snapshot.interpolated { let mut to_remove_len = 0; @@ 
-330,6 +337,7 @@ impl WrapMap { } impl WrapSnapshot { + #[ztracing::instrument(skip_all)] fn new(tab_snapshot: TabSnapshot) -> Self { let mut transforms = SumTree::default(); let extent = tab_snapshot.text_summary(); @@ -343,10 +351,12 @@ impl WrapSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn buffer_snapshot(&self) -> &MultiBufferSnapshot { self.tab_snapshot.buffer_snapshot() } + #[ztracing::instrument(skip_all)] fn interpolate(&mut self, new_tab_snapshot: TabSnapshot, tab_edits: &[TabEdit]) -> WrapPatch { let mut new_transforms; if tab_edits.is_empty() { @@ -411,6 +421,7 @@ impl WrapSnapshot { old_snapshot.compute_edits(tab_edits, self) } + #[ztracing::instrument(skip_all)] async fn update( &mut self, new_tab_snapshot: TabSnapshot, @@ -570,6 +581,7 @@ impl WrapSnapshot { old_snapshot.compute_edits(tab_edits, self) } + #[ztracing::instrument(skip_all)] fn compute_edits(&self, tab_edits: &[TabEdit], new_snapshot: &WrapSnapshot) -> WrapPatch { let mut wrap_edits = Vec::with_capacity(tab_edits.len()); let mut old_cursor = self.transforms.cursor::(()); @@ -606,6 +618,7 @@ impl WrapSnapshot { Patch::new(wrap_edits) } + #[ztracing::instrument(skip_all)] pub(crate) fn chunks<'a>( &'a self, rows: Range, @@ -640,10 +653,12 @@ impl WrapSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn max_point(&self) -> WrapPoint { WrapPoint(self.transforms.summary().output.lines) } + #[ztracing::instrument(skip_all)] pub fn line_len(&self, row: WrapRow) -> u32 { let (start, _, item) = self.transforms.find::, _>( (), @@ -664,6 +679,7 @@ impl WrapSnapshot { } } + #[ztracing::instrument(skip_all, fields(rows))] pub fn text_summary_for_range(&self, rows: Range) -> TextSummary { let mut summary = TextSummary::default(); @@ -725,6 +741,7 @@ impl WrapSnapshot { summary } + #[ztracing::instrument(skip_all)] pub fn soft_wrap_indent(&self, row: WrapRow) -> Option { let (.., item) = self.transforms.find::( (), @@ -740,10 +757,12 @@ impl WrapSnapshot { }) } + 
#[ztracing::instrument(skip_all)] pub fn longest_row(&self) -> u32 { self.transforms.summary().output.longest_row } + #[ztracing::instrument(skip_all)] pub fn row_infos(&self, start_row: WrapRow) -> WrapRows<'_> { let mut transforms = self .transforms @@ -766,6 +785,7 @@ impl WrapSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn to_tab_point(&self, point: WrapPoint) -> TabPoint { let (start, _, item) = self.transforms @@ -777,15 +797,18 @@ impl WrapSnapshot { TabPoint(tab_point) } + #[ztracing::instrument(skip_all)] pub fn to_point(&self, point: WrapPoint, bias: Bias) -> Point { self.tab_snapshot .tab_point_to_point(self.to_tab_point(point), bias) } + #[ztracing::instrument(skip_all)] pub fn make_wrap_point(&self, point: Point, bias: Bias) -> WrapPoint { self.tab_point_to_wrap_point(self.tab_snapshot.point_to_tab_point(point, bias)) } + #[ztracing::instrument(skip_all)] pub fn tab_point_to_wrap_point(&self, point: TabPoint) -> WrapPoint { let (start, ..) = self.transforms @@ -793,6 +816,7 @@ impl WrapSnapshot { WrapPoint(start.1.0 + (point.0 - start.0.0)) } + #[ztracing::instrument(skip_all)] pub fn wrap_point_cursor(&self) -> WrapPointCursor<'_> { WrapPointCursor { cursor: self @@ -801,6 +825,7 @@ impl WrapSnapshot { } } + #[ztracing::instrument(skip_all)] pub fn clip_point(&self, mut point: WrapPoint, bias: Bias) -> WrapPoint { if bias == Bias::Left { let (start, _, item) = self @@ -815,6 +840,7 @@ impl WrapSnapshot { self.tab_point_to_wrap_point(self.tab_snapshot.clip_point(self.to_tab_point(point), bias)) } + #[ztracing::instrument(skip_all, fields(point, ret))] pub fn prev_row_boundary(&self, mut point: WrapPoint) -> WrapRow { if self.transforms.is_empty() { return WrapRow(0); @@ -841,6 +867,7 @@ impl WrapSnapshot { unreachable!() } + #[ztracing::instrument(skip_all)] pub fn next_row_boundary(&self, mut point: WrapPoint) -> Option { point.0 += Point::new(1, 0); @@ -860,11 +887,13 @@ impl WrapSnapshot { } #[cfg(test)] + #[ztracing::instrument(skip_all)] 
pub fn text(&self) -> String { self.text_chunks(WrapRow(0)).collect() } #[cfg(test)] + #[ztracing::instrument(skip_all)] pub fn text_chunks(&self, wrap_row: WrapRow) -> impl Iterator { self.chunks( wrap_row..self.max_point().row() + WrapRow(1), @@ -874,6 +903,7 @@ impl WrapSnapshot { .map(|h| h.text) } + #[ztracing::instrument(skip_all)] fn check_invariants(&self) { #[cfg(test)] { @@ -927,6 +957,7 @@ pub struct WrapPointCursor<'transforms> { } impl WrapPointCursor<'_> { + #[ztracing::instrument(skip_all)] pub fn map(&mut self, point: TabPoint) -> WrapPoint { let cursor = &mut self.cursor; if cursor.did_seek() { @@ -939,6 +970,7 @@ impl WrapPointCursor<'_> { } impl WrapChunks<'_> { + #[ztracing::instrument(skip_all)] pub(crate) fn seek(&mut self, rows: Range) { let output_start = WrapPoint::new(rows.start, 0); let output_end = WrapPoint::new(rows.end, 0); @@ -961,6 +993,7 @@ impl WrapChunks<'_> { impl<'a> Iterator for WrapChunks<'a> { type Item = Chunk<'a>; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.output_position.row() >= self.max_output_row { return None; @@ -1033,6 +1066,7 @@ impl<'a> Iterator for WrapChunks<'a> { impl Iterator for WrapRows<'_> { type Item = RowInfo; + #[ztracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.output_row > self.max_output_row { return None; @@ -1069,6 +1103,7 @@ impl Iterator for WrapRows<'_> { } impl Transform { + #[ztracing::instrument(skip_all)] fn isomorphic(summary: TextSummary) -> Self { #[cfg(test)] assert!(!summary.lines.is_zero()); @@ -1082,6 +1117,7 @@ impl Transform { } } + #[ztracing::instrument(skip_all)] fn wrap(indent: u32) -> Self { static WRAP_TEXT: LazyLock = LazyLock::new(|| { let mut wrap_text = String::new(); @@ -1134,6 +1170,7 @@ trait SumTreeExt { } impl SumTreeExt for SumTree { + #[ztracing::instrument(skip_all)] fn push_or_extend(&mut self, transform: Transform) { let mut transform = Some(transform); self.update_last( @@ -1197,6 +1234,7 @@ impl<'a> 
sum_tree::Dimension<'a, TransformSummary> for TabPoint { } impl sum_tree::SeekTarget<'_, TransformSummary, TransformSummary> for TabPoint { + #[ztracing::instrument(skip_all)] fn cmp(&self, cursor_location: &TransformSummary, _: ()) -> std::cmp::Ordering { Ord::cmp(&self.0, &cursor_location.input.lines) } diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index a1839144a47a81f668ba2743cd5e362f6711d0e9..bfce1532ce78699e1fb524fd594df1ba83c864a5 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -1,4 +1,4 @@ -use edit_prediction::EditPredictionProvider; +use edit_prediction_types::EditPredictionDelegate; use gpui::{Entity, KeyBinding, Modifiers, prelude::*}; use indoc::indoc; use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint}; @@ -15,7 +15,7 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let absolute_zero_celsius = ˇ;"); @@ -37,7 +37,7 @@ async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let pi = ˇ\"foo\";"); @@ -59,7 +59,7 @@ async fn test_edit_prediction_jump_button(cx: &mut gpui::TestAppContext) { init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); assign_editor_completion_provider(provider.clone(), 
&mut cx); // Cursor is 2+ lines above the proposed edit @@ -128,7 +128,7 @@ async fn test_edit_prediction_invalidation_range(cx: &mut gpui::TestAppContext) init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); assign_editor_completion_provider(provider.clone(), &mut cx); // Cursor is 3+ lines above the proposed edit @@ -233,7 +233,7 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui: init_test(cx, |_| {}); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeNonZedEditPredictionProvider::default()); + let provider = cx.new(|_| FakeNonZedEditPredictionDelegate::default()); assign_editor_completion_provider_non_zed(provider.clone(), &mut cx); // Cursor is 2+ lines above the proposed edit @@ -281,7 +281,7 @@ async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestA cx.update(|cx| cx.bind_keys([KeyBinding::new("ctrl-shift-a", AcceptEditPrediction, None)])); let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); assign_editor_completion_provider(provider.clone(), &mut cx); cx.set_state("let x = ˇ;"); @@ -371,7 +371,7 @@ fn accept_completion(cx: &mut EditorTestContext) { } fn propose_edits( - provider: &Entity, + provider: &Entity, edits: Vec<(Range, &str)>, cx: &mut EditorTestContext, ) { @@ -383,7 +383,7 @@ fn propose_edits( cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local { + provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local { id: None, edits: edits.collect(), edit_preview: None, @@ -393,7 +393,7 @@ fn propose_edits( } fn assign_editor_completion_provider( - provider: Entity, + provider: Entity, cx: &mut 
EditorTestContext, ) { cx.update_editor(|editor, window, cx| { @@ -402,7 +402,7 @@ fn assign_editor_completion_provider( } fn propose_edits_non_zed( - provider: &Entity, + provider: &Entity, edits: Vec<(Range, &str)>, cx: &mut EditorTestContext, ) { @@ -414,7 +414,7 @@ fn propose_edits_non_zed( cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local { + provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local { id: None, edits: edits.collect(), edit_preview: None, @@ -424,7 +424,7 @@ fn propose_edits_non_zed( } fn assign_editor_completion_provider_non_zed( - provider: Entity, + provider: Entity, cx: &mut EditorTestContext, ) { cx.update_editor(|editor, window, cx| { @@ -433,17 +433,20 @@ fn assign_editor_completion_provider_non_zed( } #[derive(Default, Clone)] -pub struct FakeEditPredictionProvider { - pub completion: Option, +pub struct FakeEditPredictionDelegate { + pub completion: Option, } -impl FakeEditPredictionProvider { - pub fn set_edit_prediction(&mut self, completion: Option) { +impl FakeEditPredictionDelegate { + pub fn set_edit_prediction( + &mut self, + completion: Option, + ) { self.completion = completion; } } -impl EditPredictionProvider for FakeEditPredictionProvider { +impl EditPredictionDelegate for FakeEditPredictionDelegate { fn name() -> &'static str { "fake-completion-provider" } @@ -452,7 +455,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider { "Fake Completion Provider" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { true } @@ -486,7 +489,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider { &mut self, _buffer: gpui::Entity, _cursor_position: language::Anchor, - _direction: edit_prediction::Direction, + _direction: edit_prediction_types::Direction, _cx: &mut gpui::Context, ) { } @@ -500,23 +503,26 @@ impl EditPredictionProvider for FakeEditPredictionProvider { _buffer: 
&gpui::Entity, _cursor_position: language::Anchor, _cx: &mut gpui::Context, - ) -> Option { + ) -> Option { self.completion.clone() } } #[derive(Default, Clone)] -pub struct FakeNonZedEditPredictionProvider { - pub completion: Option, +pub struct FakeNonZedEditPredictionDelegate { + pub completion: Option, } -impl FakeNonZedEditPredictionProvider { - pub fn set_edit_prediction(&mut self, completion: Option) { +impl FakeNonZedEditPredictionDelegate { + pub fn set_edit_prediction( + &mut self, + completion: Option, + ) { self.completion = completion; } } -impl EditPredictionProvider for FakeNonZedEditPredictionProvider { +impl EditPredictionDelegate for FakeNonZedEditPredictionDelegate { fn name() -> &'static str { "fake-non-zed-provider" } @@ -525,7 +531,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider { "Fake Non-Zed Provider" } - fn show_completions_in_menu() -> bool { + fn show_predictions_in_menu() -> bool { false } @@ -559,7 +565,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider { &mut self, _buffer: gpui::Entity, _cursor_position: language::Anchor, - _direction: edit_prediction::Direction, + _direction: edit_prediction_types::Direction, _cx: &mut gpui::Context, ) { } @@ -573,7 +579,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider { _buffer: &gpui::Entity, _cursor_position: language::Anchor, _cx: &mut gpui::Context, - ) -> Option { + ) -> Option { self.completion.clone() } } diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 4b352e2d8298f3c9ae2c0d38bd6b443d62a61996..223c0e574cff7daa9f50b07735b40e291185a37c 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -51,7 +51,7 @@ pub mod test; pub(crate) use actions::*; pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder}; -pub use edit_prediction::Direction; +pub use edit_prediction_types::Direction; pub use editor_settings::{ CurrentLineHighlight, 
DocumentColorsRenderMode, EditorSettings, HideMouseMode, ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowMinimap, @@ -92,7 +92,7 @@ use collections::{BTreeMap, HashMap, HashSet, VecDeque}; use convert_case::{Case, Casing}; use dap::TelemetrySpawnLocation; use display_map::*; -use edit_prediction::{EditPredictionProvider, EditPredictionProviderHandle}; +use edit_prediction_types::{EditPredictionDelegate, EditPredictionDelegateHandle}; use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings}; use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layout_line}; use futures::{ @@ -1079,6 +1079,7 @@ pub struct Editor { show_breakpoints: Option, show_wrap_guides: Option, show_indent_guides: Option, + buffers_with_disabled_indent_guides: HashSet, highlight_order: usize, highlighted_rows: HashMap>, background_highlights: HashMap, @@ -1119,7 +1120,7 @@ pub struct Editor { pending_mouse_down: Option>>>, gutter_hovered: bool, hovered_link_state: Option, - edit_prediction_provider: Option, + edit_prediction_provider: Option, code_action_providers: Vec>, active_edit_prediction: Option, /// Used to prevent flickering as the user types while the menu is open @@ -1127,6 +1128,7 @@ pub struct Editor { edit_prediction_settings: EditPredictionSettings, edit_predictions_hidden_for_vim_mode: bool, show_edit_predictions_override: Option, + show_completions_on_input_override: Option, menu_edit_predictions_policy: MenuEditPredictionsPolicy, edit_prediction_preview: EditPredictionPreview, edit_prediction_indent_conflict: bool, @@ -1561,8 +1563,8 @@ pub struct RenameState { struct InvalidationStack(Vec); -struct RegisteredEditPredictionProvider { - provider: Arc, +struct RegisteredEditPredictionDelegate { + provider: Arc, _subscription: Subscription, } @@ -1590,6 +1592,45 @@ pub struct ClipboardSelection { pub is_entire_line: bool, /// The indentation of the first line when this content was originally copied. 
pub first_line_indent: u32, + #[serde(default)] + pub file_path: Option, + #[serde(default)] + pub line_range: Option>, +} + +impl ClipboardSelection { + pub fn for_buffer( + len: usize, + is_entire_line: bool, + range: Range, + buffer: &MultiBufferSnapshot, + project: Option<&Entity>, + cx: &App, + ) -> Self { + let first_line_indent = buffer + .indent_size_for_line(MultiBufferRow(range.start.row)) + .len; + + let file_path = util::maybe!({ + let project = project?.read(cx); + let file = buffer.file_at(range.start)?; + let project_path = ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + }; + project.absolute_path(&project_path, cx) + }); + + let line_range = file_path.as_ref().map(|_| range.start.row..=range.end.row); + + Self { + len, + is_entire_line, + first_line_indent, + file_path, + line_range, + } + } } // selections, scroll behavior, was newest selection reversed @@ -2204,6 +2245,7 @@ impl Editor { show_breakpoints: None, show_wrap_guides: None, show_indent_guides, + buffers_with_disabled_indent_guides: HashSet::default(), highlight_order: 0, highlighted_rows: HashMap::default(), background_highlights: HashMap::default(), @@ -2273,6 +2315,7 @@ impl Editor { editor_actions: Rc::default(), edit_predictions_hidden_for_vim_mode: false, show_edit_predictions_override: None, + show_completions_on_input_override: None, menu_edit_predictions_policy: MenuEditPredictionsPolicy::ByProvider, edit_prediction_settings: EditPredictionSettings::Disabled, edit_prediction_indent_conflict: false, @@ -2986,9 +3029,9 @@ impl Editor { window: &mut Window, cx: &mut Context, ) where - T: EditPredictionProvider, + T: EditPredictionDelegate, { - self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionProvider { + self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionDelegate { _subscription: cx.observe_in(&provider, window, |this, _, window, cx| { if this.focus_handle.is_focused(window) { 
this.update_visible_edit_prediction(window, cx); @@ -3155,6 +3198,10 @@ impl Editor { } } + pub fn set_show_completions_on_input(&mut self, show_completions_on_input: Option) { + self.show_completions_on_input_override = show_completions_on_input; + } + pub fn set_show_edit_predictions( &mut self, show_edit_predictions: Option, @@ -5531,7 +5578,10 @@ impl Editor { let language_settings = language_settings(language.clone(), buffer_snapshot.file(), cx); let completion_settings = language_settings.completions.clone(); - if !menu_is_open && trigger.is_some() && !language_settings.show_completions_on_input { + let show_completions_on_input = self + .show_completions_on_input_override + .unwrap_or(language_settings.show_completions_on_input); + if !menu_is_open && trigger.is_some() && !show_completions_on_input { return; } @@ -6909,7 +6959,11 @@ impl Editor { } } - fn hide_blame_popover(&mut self, ignore_timeout: bool, cx: &mut Context) -> bool { + pub fn has_mouse_context_menu(&self) -> bool { + self.mouse_context_menu.is_some() + } + + pub fn hide_blame_popover(&mut self, ignore_timeout: bool, cx: &mut Context) -> bool { self.inline_blame_popover_show_task.take(); if let Some(state) = &mut self.inline_blame_popover { let hide_task = cx.spawn(async move |editor, cx| { @@ -7392,7 +7446,7 @@ impl Editor { && self .edit_prediction_provider .as_ref() - .is_some_and(|provider| provider.provider.show_completions_in_menu()); + .is_some_and(|provider| provider.provider.show_predictions_in_menu()); let preview_requires_modifier = all_language_settings(file, cx).edit_predictions_mode() == EditPredictionsMode::Subtle; @@ -8093,12 +8147,12 @@ impl Editor { let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?; let (completion_id, edits, edit_preview) = match edit_prediction { - edit_prediction::EditPrediction::Local { + edit_prediction_types::EditPrediction::Local { id, edits, edit_preview, } => (id, edits, edit_preview), - 
edit_prediction::EditPrediction::Jump { + edit_prediction_types::EditPrediction::Jump { id, snapshot, target, @@ -8239,7 +8293,7 @@ impl Editor { Some(()) } - pub fn edit_prediction_provider(&self) -> Option> { + pub fn edit_prediction_provider(&self) -> Option> { Some(self.edit_prediction_provider.as_ref()?.provider.clone()) } @@ -9561,7 +9615,7 @@ impl Editor { editor_bg_color.blend(accent_color.opacity(0.6)) } fn get_prediction_provider_icon_name( - provider: &Option, + provider: &Option, ) -> IconName { match provider { Some(provider) => match provider.provider.name() { @@ -12797,13 +12851,15 @@ impl Editor { text.push_str(chunk); len += chunk.len(); } - clipboard_selections.push(ClipboardSelection { + + clipboard_selections.push(ClipboardSelection::for_buffer( len, is_entire_line, - first_line_indent: buffer - .indent_size_for_line(MultiBufferRow(selection.start.row)) - .len, - }); + selection.range(), + &buffer, + self.project.as_ref(), + cx, + )); } } @@ -12946,13 +13002,14 @@ impl Editor { text.push('\n'); len += 1; } - clipboard_selections.push(ClipboardSelection { + clipboard_selections.push(ClipboardSelection::for_buffer( len, is_entire_line, - first_line_indent: buffer - .indent_size_for_line(MultiBufferRow(trimmed_range.start.row)) - .len, - }); + trimmed_range, + &buffer, + self.project.as_ref(), + cx, + )); } } } @@ -16800,7 +16857,7 @@ impl Editor { GoToDefinitionFallback::None => Ok(Navigated::No), GoToDefinitionFallback::FindAllReferences => { match editor.update_in(cx, |editor, window, cx| { - editor.find_all_references(&FindAllReferences, window, cx) + editor.find_all_references(&FindAllReferences::default(), window, cx) })? 
{ Some(references) => references.await, None => Ok(Navigated::No), @@ -17028,9 +17085,7 @@ impl Editor { }) .collect(); - let Some(workspace) = self.workspace() else { - return Task::ready(Ok(Navigated::No)); - }; + let workspace = self.workspace(); cx.spawn_in(window, async move |editor, cx| { let locations: Vec = future::join_all(definitions) @@ -17085,6 +17140,10 @@ impl Editor { }) .context("buffer title")?; + let Some(workspace) = workspace else { + return Ok(Navigated::No); + }; + let opened = workspace .update_in(cx, |workspace, window, cx| { let allow_preview = PreviewTabsSettings::get_global(cx) @@ -17114,6 +17173,9 @@ impl Editor { // TODO(andrew): respect preview tab settings // `enable_keep_preview_on_code_navigation` and // `enable_preview_file_from_code_navigation` + let Some(workspace) = workspace else { + return Ok(Navigated::No); + }; workspace .update_in(cx, |workspace, window, cx| { workspace.open_resolved_path(path, window, cx) @@ -17137,6 +17199,9 @@ impl Editor { { editor.go_to_singleton_buffer_range(range, window, cx); } else { + let Some(workspace) = workspace else { + return Navigated::No; + }; let pane = workspace.read(cx).active_pane().clone(); window.defer(cx, move |window, cx| { let target_editor: Entity = @@ -17348,20 +17413,21 @@ impl Editor { pub fn find_all_references( &mut self, - _: &FindAllReferences, + action: &FindAllReferences, window: &mut Window, cx: &mut Context, ) -> Option>> { - let selection = self - .selections - .newest::(&self.display_snapshot(cx)); + let always_open_multibuffer = action.always_open_multibuffer; + let selection = self.selections.newest_anchor(); let multi_buffer = self.buffer.read(cx); - let head = selection.head(); - let multi_buffer_snapshot = multi_buffer.snapshot(cx); + let selection_offset = selection.map(|anchor| anchor.to_offset(&multi_buffer_snapshot)); + let selection_point = selection.map(|anchor| anchor.to_point(&multi_buffer_snapshot)); + let head = selection_offset.head(); + let 
head_anchor = multi_buffer_snapshot.anchor_at( head, - if head < selection.tail() { + if head < selection_offset.tail() { Bias::Right } else { Bias::Left @@ -17407,6 +17473,15 @@ impl Editor { let buffer = location.buffer.read(cx); (location.buffer, location.range.to_point(buffer)) }) + // if special-casing the single-match case, remove ranges + // that intersect current selection + .filter(|(location_buffer, location)| { + if always_open_multibuffer || &buffer != location_buffer { + return true; + } + + !location.contains_inclusive(&selection_point.range()) + }) .into_group_map() })?; if locations.is_empty() { @@ -17416,6 +17491,60 @@ impl Editor { ranges.sort_by_key(|range| (range.start, Reverse(range.end))); ranges.dedup(); } + let mut num_locations = 0; + for ranges in locations.values_mut() { + ranges.sort_by_key(|range| (range.start, Reverse(range.end))); + ranges.dedup(); + num_locations += ranges.len(); + } + + if num_locations == 1 && !always_open_multibuffer { + let (target_buffer, target_ranges) = locations.into_iter().next().unwrap(); + let target_range = target_ranges.first().unwrap().clone(); + + return editor.update_in(cx, |editor, window, cx| { + let range = target_range.to_point(target_buffer.read(cx)); + let range = editor.range_for_match(&range); + let range = range.start..range.start; + + if Some(&target_buffer) == editor.buffer.read(cx).as_singleton().as_ref() { + editor.go_to_singleton_buffer_range(range, window, cx); + } else { + let pane = workspace.read(cx).active_pane().clone(); + window.defer(cx, move |window, cx| { + let target_editor: Entity = + workspace.update(cx, |workspace, cx| { + let pane = workspace.active_pane().clone(); + + let preview_tabs_settings = PreviewTabsSettings::get_global(cx); + let keep_old_preview = preview_tabs_settings + .enable_keep_preview_on_code_navigation; + let allow_new_preview = preview_tabs_settings + .enable_preview_file_from_code_navigation; + + workspace.open_project_item( + pane, + 
target_buffer.clone(), + true, + true, + keep_old_preview, + allow_new_preview, + window, + cx, + ) + }); + target_editor.update(cx, |target_editor, cx| { + // When selecting a definition in a different buffer, disable the nav history + // to avoid creating a history entry at the previous cursor location. + pane.update(cx, |pane, _| pane.disable_history()); + target_editor.go_to_singleton_buffer_range(range, window, cx); + pane.update(cx, |pane, _| pane.enable_history()); + }); + }); + } + Navigated::No + }); + } workspace.update_in(cx, |workspace, window, cx| { let target = locations @@ -17453,7 +17582,7 @@ impl Editor { })) } - /// Opens a multibuffer with the given project locations in it + /// Opens a multibuffer with the given project locations in it. pub fn open_locations_in_multibuffer( workspace: &mut Workspace, locations: std::collections::HashMap, Vec>>, @@ -20090,6 +20219,20 @@ impl Editor { self.show_indent_guides } + pub fn disable_indent_guides_for_buffer( + &mut self, + buffer_id: BufferId, + cx: &mut Context, + ) { + self.buffers_with_disabled_indent_guides.insert(buffer_id); + cx.notify(); + } + + pub fn has_indent_guides_disabled_for_buffer(&self, buffer_id: BufferId) -> bool { + self.buffers_with_disabled_indent_guides + .contains(&buffer_id) + } + pub fn toggle_line_numbers( &mut self, _: &ToggleLineNumbers, @@ -22077,14 +22220,23 @@ impl Editor { None => Autoscroll::newest(), }; let nav_history = editor.nav_history.take(); + let multibuffer_snapshot = editor.buffer().read(cx).snapshot(cx); + let Some((&excerpt_id, _, buffer_snapshot)) = + multibuffer_snapshot.as_singleton() + else { + return; + }; editor.change_selections( SelectionEffects::scroll(autoscroll), window, cx, |s| { s.select_ranges(ranges.into_iter().map(|range| { - // we checked that the editor is a singleton editor so the offsets are valid - MultiBufferOffset(range.start.0)..MultiBufferOffset(range.end.0) + let range = buffer_snapshot.anchor_before(range.start) + 
..buffer_snapshot.anchor_after(range.end); + multibuffer_snapshot + .anchor_range_in_excerpt(excerpt_id, range) + .unwrap() })); }, ); diff --git a/crates/editor/src/editor_tests.rs b/crates/editor/src/editor_tests.rs index 64c335e2e4b0dc660efe1b28bb87984fba8aafb4..75aa29be7a31ef6a4707e538ea21a2ca45aff395 100644 --- a/crates/editor/src/editor_tests.rs +++ b/crates/editor/src/editor_tests.rs @@ -2,7 +2,7 @@ use super::*; use crate::{ JoinLines, code_context_menus::CodeContextMenu, - edit_prediction_tests::FakeEditPredictionProvider, + edit_prediction_tests::FakeEditPredictionDelegate, element::StickyHeader, linked_editing_ranges::LinkedEditingRanges, scroll::scroll_amount::ScrollAmount, @@ -8636,7 +8636,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext) let mut cx = EditorTestContext::new(cx).await; - let provider = cx.new(|_| FakeEditPredictionProvider::default()); + let provider = cx.new(|_| FakeEditPredictionDelegate::default()); cx.update_editor(|editor, window, cx| { editor.set_edit_prediction_provider(Some(provider.clone()), window, cx); }); @@ -8659,7 +8659,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext) cx.update(|_, cx| { provider.update(cx, |provider, _| { - provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local { + provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local { id: None, edits: vec![(edit_position..edit_position, "X".into())], edit_preview: None, @@ -9970,7 +9970,7 @@ async fn test_autoindent_disabled_with_nested_language(cx: &mut TestAppContext) ], ..Default::default() }, - name: LanguageName::new("rust"), + name: LanguageName::new_static("rust"), ..Default::default() }, Some(tree_sitter_rust::LANGUAGE.into()), @@ -22579,7 +22579,7 @@ async fn test_find_all_references_editor_reuse(cx: &mut TestAppContext) { }); let navigated = cx .update_editor(|editor, window, cx| { - editor.find_all_references(&FindAllReferences, window, cx) + 
editor.find_all_references(&FindAllReferences::default(), window, cx) }) .unwrap() .await @@ -22615,7 +22615,7 @@ async fn test_find_all_references_editor_reuse(cx: &mut TestAppContext) { ); let navigated = cx .update_editor(|editor, window, cx| { - editor.find_all_references(&FindAllReferences, window, cx) + editor.find_all_references(&FindAllReferences::default(), window, cx) }) .unwrap() .await @@ -22667,7 +22667,7 @@ async fn test_find_all_references_editor_reuse(cx: &mut TestAppContext) { }); let navigated = cx .update_editor(|editor, window, cx| { - editor.find_all_references(&FindAllReferences, window, cx) + editor.find_all_references(&FindAllReferences::default(), window, cx) }) .unwrap() .await @@ -27498,6 +27498,159 @@ async fn test_paste_url_from_other_app_creates_markdown_link_over_selected_text( )); } +#[gpui::test] +async fn test_markdown_indents(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let markdown_language = languages::language("markdown", tree_sitter_md::LANGUAGE.into()); + let mut cx = EditorTestContext::new(cx).await; + + cx.update_buffer(|buffer, cx| buffer.set_language(Some(markdown_language), cx)); + + // Case 1: Test if adding a character with multi cursors preserves nested list indents + cx.set_state(&indoc! {" + - [ ] Item 1 + - [ ] Item 1.a + - [ˇ] Item 2 + - [ˇ] Item 2.a + - [ˇ] Item 2.b + " + }); + cx.update_editor(|editor, window, cx| { + editor.handle_input("x", window, cx); + }); + cx.assert_editor_state(indoc! {" + - [ ] Item 1 + - [ ] Item 1.a + - [xˇ] Item 2 + - [xˇ] Item 2.a + - [xˇ] Item 2.b + " + }); + + // Case 2: Test adding new line after nested list preserves indent of previous line + cx.set_state(&indoc! {" + - [ ] Item 1 + - [ ] Item 1.a + - [x] Item 2 + - [x] Item 2.a + - [x] Item 2.bˇ + " + }); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.assert_editor_state(indoc! 
{" + - [ ] Item 1 + - [ ] Item 1.a + - [x] Item 2 + - [x] Item 2.a + - [x] Item 2.b + ˇ + " + }); + + // Case 3: Test adding a new nested list item preserves indent + cx.update_editor(|editor, window, cx| { + editor.handle_input("-", window, cx); + }); + cx.assert_editor_state(indoc! {" + - [ ] Item 1 + - [ ] Item 1.a + - [x] Item 2 + - [x] Item 2.a + - [x] Item 2.b + -ˇ + " + }); + cx.update_editor(|editor, window, cx| { + editor.handle_input(" [x] Item 2.c", window, cx); + }); + cx.assert_editor_state(indoc! {" + - [ ] Item 1 + - [ ] Item 1.a + - [x] Item 2 + - [x] Item 2.a + - [x] Item 2.b + - [x] Item 2.cˇ + " + }); + + // Case 4: Test adding new line after nested ordered list preserves indent of previous line + cx.set_state(indoc! {" + 1. Item 1 + 1. Item 1.a + 2. Item 2 + 1. Item 2.a + 2. Item 2.bˇ + " + }); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.assert_editor_state(indoc! {" + 1. Item 1 + 1. Item 1.a + 2. Item 2 + 1. Item 2.a + 2. Item 2.b + ˇ + " + }); + + // Case 5: Adding new ordered list item preserves indent + cx.update_editor(|editor, window, cx| { + editor.handle_input("3", window, cx); + }); + cx.assert_editor_state(indoc! {" + 1. Item 1 + 1. Item 1.a + 2. Item 2 + 1. Item 2.a + 2. Item 2.b + 3ˇ + " + }); + cx.update_editor(|editor, window, cx| { + editor.handle_input(".", window, cx); + }); + cx.assert_editor_state(indoc! {" + 1. Item 1 + 1. Item 1.a + 2. Item 2 + 1. Item 2.a + 2. Item 2.b + 3.ˇ + " + }); + cx.update_editor(|editor, window, cx| { + editor.handle_input(" Item 2.c", window, cx); + }); + cx.assert_editor_state(indoc! {" + 1. Item 1 + 1. Item 1.a + 2. Item 2 + 1. Item 2.a + 2. Item 2.b + 3. Item 2.cˇ + " + }); + + // Case 7: Test blockquote newline preserves something + cx.set_state(indoc! {" + > Item 1ˇ + " + }); + cx.update_editor(|editor, window, cx| { + editor.newline(&Newline, window, cx); + }); + cx.assert_editor_state(indoc! 
{" + > Item 1 + ˇ + " + }); +} + #[gpui::test] async fn test_paste_url_from_zed_copy_creates_markdown_link_over_selected_text( cx: &mut gpui::TestAppContext, @@ -28446,33 +28599,32 @@ async fn test_multibuffer_selections_with_folding(cx: &mut TestAppContext) { 3 "}); - // Edge case scenario: fold all buffers, then try to insert + // Test correct folded header is selected upon fold cx.update_editor(|editor, _, cx| { editor.fold_buffer(buffer_ids[0], cx); editor.fold_buffer(buffer_ids[1], cx); }); cx.assert_excerpts_with_selections(indoc! {" - [EXCERPT] - ˇ[FOLDED] [EXCERPT] [FOLDED] + [EXCERPT] + ˇ[FOLDED] "}); - // Insert should work via default selection + // Test selection inside folded buffer unfolds it on type cx.update_editor(|editor, window, cx| { editor.handle_input("W", window, cx); }); cx.update_editor(|editor, _, cx| { editor.unfold_buffer(buffer_ids[0], cx); - editor.unfold_buffer(buffer_ids[1], cx); }); cx.assert_excerpts_with_selections(indoc! {" [EXCERPT] - Wˇ1 + 1 2 3 [EXCERPT] - 1 + Wˇ1 Z 3 "}); @@ -28763,3 +28915,65 @@ async fn test_multibuffer_scroll_cursor_top_margin(cx: &mut TestAppContext) { ); }); } + +#[gpui::test] +async fn test_find_references_single_case(cx: &mut TestAppContext) { + init_test(cx, |_| {}); + let mut cx = EditorLspTestContext::new_rust( + lsp::ServerCapabilities { + references_provider: Some(lsp::OneOf::Left(true)), + ..lsp::ServerCapabilities::default() + }, + cx, + ) + .await; + + let before = indoc!( + r#" + fn main() { + let aˇbc = 123; + let xyz = abc; + } + "# + ); + let after = indoc!( + r#" + fn main() { + let abc = 123; + let xyz = ˇabc; + } + "# + ); + + cx.lsp + .set_request_handler::(async move |params, _| { + Ok(Some(vec![ + lsp::Location { + uri: params.text_document_position.text_document.uri.clone(), + range: lsp::Range::new(lsp::Position::new(1, 8), lsp::Position::new(1, 11)), + }, + lsp::Location { + uri: params.text_document_position.text_document.uri, + range: lsp::Range::new(lsp::Position::new(2, 14), 
lsp::Position::new(2, 17)), + }, + ])) + }); + + cx.set_state(before); + + let action = FindAllReferences { + always_open_multibuffer: false, + }; + + let navigated = cx + .update_editor(|editor, window, cx| editor.find_all_references(&action, window, cx)) + .expect("should have spawned a task") + .await + .unwrap(); + + assert_eq!(navigated, Navigated::No); + + cx.run_until_parked(); + + cx.assert_editor_state(after); +} diff --git a/crates/editor/src/element.rs b/crates/editor/src/element.rs index 3319af92eb04015bd3bd01760235e3dba0047975..edb3778ff94809ef880ffa167f2ff410a3199a37 100644 --- a/crates/editor/src/element.rs +++ b/crates/editor/src/element.rs @@ -2340,7 +2340,7 @@ impl EditorElement { .opacity(0.05)) .text_color(severity_to_color(&diagnostic_to_render.severity).color(cx)) .text_sm() - .font_family(style.text.font().family) + .font(style.text.font()) .child(diagnostic_to_render.message.clone()) .into_any(); @@ -3915,6 +3915,8 @@ impl EditorElement { ) -> impl IntoElement { let editor = self.editor.read(cx); let multi_buffer = editor.buffer.read(cx); + let is_read_only = self.editor.read(cx).read_only(cx); + let file_status = multi_buffer .all_diff_hunks_expanded() .then(|| editor.status_for_buffer_id(for_excerpt.buffer_id, cx)) @@ -3967,7 +3969,7 @@ impl EditorElement { .gap_1p5() .when(is_sticky, |el| el.shadow_md()) .border_1() - .map(|div| { + .map(|border| { let border_color = if is_selected && is_folded && focus_handle.contains_focused(window, cx) @@ -3976,7 +3978,7 @@ impl EditorElement { } else { colors.border }; - div.border_color(border_color) + border.border_color(border_color) }) .bg(colors.editor_subheader_background) .hover(|style| style.bg(colors.element_hover)) @@ -4056,13 +4058,15 @@ impl EditorElement { }) .take(1), ) - .child( - h_flex() - .size_3() - .justify_center() - .flex_shrink_0() - .children(indicator), - ) + .when(!is_read_only, |this| { + this.child( + h_flex() + .size_3() + .justify_center() + .flex_shrink_0() + 
.children(indicator), + ) + }) .child( h_flex() .cursor_pointer() diff --git a/crates/editor/src/git/blame.rs b/crates/editor/src/git/blame.rs index 008630faef7cc1ccb3b9703e4b11c0b88b7cf17c..67df69aadab43a45c2941703e10bb81af2b8dd78 100644 --- a/crates/editor/src/git/blame.rs +++ b/crates/editor/src/git/blame.rs @@ -508,7 +508,19 @@ impl GitBlame { let buffer_edits = buffer.update(cx, |buffer, _| buffer.subscribe()); let blame_buffer = project.blame_buffer(&buffer, None, cx); - Some(async move { (id, snapshot, buffer_edits, blame_buffer.await) }) + let remote_url = project + .git_store() + .read(cx) + .repository_and_path_for_buffer_id(buffer.read(cx).remote_id(), cx) + .and_then(|(repo, _)| { + repo.read(cx) + .remote_upstream_url + .clone() + .or(repo.read(cx).remote_origin_url.clone()) + }); + Some( + async move { (id, snapshot, buffer_edits, blame_buffer.await, remote_url) }, + ) }) .collect::>() }); @@ -524,13 +536,9 @@ impl GitBlame { .await; let mut res = vec![]; let mut errors = vec![]; - for (id, snapshot, buffer_edits, blame) in blame { + for (id, snapshot, buffer_edits, blame, remote_url) in blame { match blame { - Ok(Some(Blame { - entries, - messages, - remote_url, - })) => { + Ok(Some(Blame { entries, messages })) => { let entries = build_blame_entry_sum_tree( entries, snapshot.max_point().row, diff --git a/crates/editor/src/hover_links.rs b/crates/editor/src/hover_links.rs index 5ef52f36dfd609a49d93eb77b74b2df3287df30a..9d0261f00f8f7258023b092d4f55d40ac8abcf40 100644 --- a/crates/editor/src/hover_links.rs +++ b/crates/editor/src/hover_links.rs @@ -168,7 +168,7 @@ impl Editor { match EditorSettings::get_global(cx).go_to_definition_fallback { GoToDefinitionFallback::None => None, GoToDefinitionFallback::FindAllReferences => { - editor.find_all_references(&FindAllReferences, window, cx) + editor.find_all_references(&FindAllReferences::default(), window, cx) } } }) diff --git a/crates/editor/src/hover_popover.rs b/crates/editor/src/hover_popover.rs index 
caabe6e6f5ab6ae80b3ead9d72fdcbec59937ff6..9ef54139d39ece6e9414d8fee3c7a75c9a89036d 100644 --- a/crates/editor/src/hover_popover.rs +++ b/crates/editor/src/hover_popover.rs @@ -607,13 +607,16 @@ async fn parse_blocks( pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { let settings = ThemeSettings::get_global(cx); let ui_font_family = settings.ui_font.family.clone(); + let ui_font_features = settings.ui_font.features.clone(); let ui_font_fallbacks = settings.ui_font.fallbacks.clone(); let buffer_font_family = settings.buffer_font.family.clone(); + let buffer_font_features = settings.buffer_font.features.clone(); let buffer_font_fallbacks = settings.buffer_font.fallbacks.clone(); let mut base_text_style = window.text_style(); base_text_style.refine(&TextStyleRefinement { font_family: Some(ui_font_family), + font_features: Some(ui_font_features), font_fallbacks: ui_font_fallbacks, color: Some(cx.theme().colors().editor_foreground), ..Default::default() @@ -624,6 +627,7 @@ pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { inline_code: TextStyleRefinement { background_color: Some(cx.theme().colors().background), font_family: Some(buffer_font_family), + font_features: Some(buffer_font_features), font_fallbacks: buffer_font_fallbacks, ..Default::default() }, @@ -657,12 +661,15 @@ pub fn diagnostics_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { let settings = ThemeSettings::get_global(cx); let ui_font_family = settings.ui_font.family.clone(); let ui_font_fallbacks = settings.ui_font.fallbacks.clone(); + let ui_font_features = settings.ui_font.features.clone(); let buffer_font_family = settings.buffer_font.family.clone(); + let buffer_font_features = settings.buffer_font.features.clone(); let buffer_font_fallbacks = settings.buffer_font.fallbacks.clone(); let mut base_text_style = window.text_style(); base_text_style.refine(&TextStyleRefinement { font_family: Some(ui_font_family), + font_features: 
Some(ui_font_features), font_fallbacks: ui_font_fallbacks, color: Some(cx.theme().colors().editor_foreground), ..Default::default() @@ -673,6 +680,7 @@ pub fn diagnostics_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { inline_code: TextStyleRefinement { background_color: Some(cx.theme().colors().editor_background.opacity(0.5)), font_family: Some(buffer_font_family), + font_features: Some(buffer_font_features), font_fallbacks: buffer_font_fallbacks, ..Default::default() }, diff --git a/crates/editor/src/indent_guides.rs b/crates/editor/src/indent_guides.rs index 7c392d27531472a413ce4d32d09cce4eb722e462..f186f9da77aca5a0d34cdc05272032f93862b1d2 100644 --- a/crates/editor/src/indent_guides.rs +++ b/crates/editor/src/indent_guides.rs @@ -181,6 +181,10 @@ pub fn indent_guides_in_range( .buffer_snapshot() .indent_guides_in_range(start_anchor..end_anchor, ignore_disabled_for_language, cx) .filter(|indent_guide| { + if editor.has_indent_guides_disabled_for_buffer(indent_guide.buffer_id) { + return false; + } + if editor.is_buffer_folded(indent_guide.buffer_id, cx) { return false; } diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index ca8937bebe3d3578c7fe2fdec2c6252bdd395e6d..3b9c17f80f10116f2302bab203966922cbf0bcb2 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -1951,7 +1951,7 @@ mod tests { use super::*; use fs::MTime; use gpui::{App, VisualTestContext}; - use language::{LanguageMatcher, TestFile}; + use language::TestFile; use project::FakeFs; use std::path::{Path, PathBuf}; use util::{path, rel_path::RelPath}; @@ -1991,20 +1991,6 @@ mod tests { .unwrap() } - fn rust_language() -> Arc { - Arc::new(language::Language::new( - language::LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - )) - } - #[gpui::test] async fn test_deserialize(cx: &mut 
gpui::TestAppContext) { init_test(cx, |_| {}); @@ -2086,7 +2072,9 @@ mod tests { { let project = Project::test(fs.clone(), [path!("/file.rs").as_ref()], cx).await; // Add Rust to the language, so that we can restore the language of the buffer - project.read_with(cx, |project, _| project.languages().add(rust_language())); + project.read_with(cx, |project, _| { + project.languages().add(languages::rust_lang()) + }); let (workspace, cx) = cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); diff --git a/crates/editor/src/mouse_context_menu.rs b/crates/editor/src/mouse_context_menu.rs index 39ad8a3672511724d69ba59a366513a24edb2198..e868b105fac8a8fa87601e2d5bc8578c94bd1940 100644 --- a/crates/editor/src/mouse_context_menu.rs +++ b/crates/editor/src/mouse_context_menu.rs @@ -235,7 +235,10 @@ pub fn deploy_context_menu( .action("Go to Declaration", Box::new(GoToDeclaration)) .action("Go to Type Definition", Box::new(GoToTypeDefinition)) .action("Go to Implementation", Box::new(GoToImplementation)) - .action("Find All References", Box::new(FindAllReferences)) + .action( + "Find All References", + Box::new(FindAllReferences::default()), + ) .separator() .action("Rename Symbol", Box::new(Rename)) .action("Format Buffer", Box::new(Format)) diff --git a/crates/editor/src/selections_collection.rs b/crates/editor/src/selections_collection.rs index f8ff9da763403b0946e99a4e39c934ff43ad6634..6f744e11334fc32e7985ee77e25866ef0c6cfe4c 100644 --- a/crates/editor/src/selections_collection.rs +++ b/crates/editor/src/selections_collection.rs @@ -419,22 +419,30 @@ impl SelectionsCollection { mutable_collection.disjoint.iter().for_each(|selection| { assert!( snapshot.can_resolve(&selection.start), - "disjoint selection start is not resolvable for the given snapshot:\n{selection:?}", + "disjoint selection start is not resolvable for the given snapshot:\n{selection:?}, {excerpt:?}", + excerpt = 
snapshot.buffer_for_excerpt(selection.start.excerpt_id).map(|snapshot| snapshot.remote_id()), ); assert!( snapshot.can_resolve(&selection.end), - "disjoint selection end is not resolvable for the given snapshot: {selection:?}", + "disjoint selection end is not resolvable for the given snapshot: {selection:?}, {excerpt:?}", + excerpt = snapshot.buffer_for_excerpt(selection.end.excerpt_id).map(|snapshot| snapshot.remote_id()), ); }); if let Some(pending) = &mutable_collection.pending { let selection = &pending.selection; assert!( snapshot.can_resolve(&selection.start), - "pending selection start is not resolvable for the given snapshot: {pending:?}", + "pending selection start is not resolvable for the given snapshot: {pending:?}, {excerpt:?}", + excerpt = snapshot + .buffer_for_excerpt(selection.start.excerpt_id) + .map(|snapshot| snapshot.remote_id()), ); assert!( snapshot.can_resolve(&selection.end), - "pending selection end is not resolvable for the given snapshot: {pending:?}", + "pending selection end is not resolvable for the given snapshot: {pending:?}, {excerpt:?}", + excerpt = snapshot + .buffer_for_excerpt(selection.end.excerpt_id) + .map(|snapshot| snapshot.remote_id()), ); } } @@ -532,11 +540,18 @@ impl<'snap, 'a> MutableSelectionsCollection<'snap, 'a> { }; if filtered_selections.is_empty() { - let default_anchor = self.snapshot.anchor_before(MultiBufferOffset(0)); + let buffer_snapshot = self.snapshot.buffer_snapshot(); + let anchor = buffer_snapshot + .excerpts() + .find(|(_, buffer, _)| buffer.remote_id() == buffer_id) + .and_then(|(excerpt_id, _, range)| { + buffer_snapshot.anchor_in_excerpt(excerpt_id, range.context.start) + }) + .unwrap_or_else(|| self.snapshot.anchor_before(MultiBufferOffset(0))); self.collection.disjoint = Arc::from([Selection { id: post_inc(&mut self.collection.next_selection_id), - start: default_anchor, - end: default_anchor, + start: anchor, + end: anchor, reversed: false, goal: SelectionGoal::None, }]); diff --git 
a/crates/extension/Cargo.toml b/crates/extension/Cargo.toml index ed084a0c1c0d6ea237dd06391330d8e41917b574..a00406c43e796b1b2bafafc3d0481b09ee5fd43a 100644 --- a/crates/extension/Cargo.toml +++ b/crates/extension/Cargo.toml @@ -38,3 +38,4 @@ wasmparser.workspace = true [dev-dependencies] pretty_assertions.workspace = true +tempfile.workspace = true diff --git a/crates/extension/src/extension_builder.rs b/crates/extension/src/extension_builder.rs index 1385ec488a2de36f5894ab91ac0ae45881aa625f..b6d4cc0c4da3d1f7d998512f099eba6c9c04e1a5 100644 --- a/crates/extension/src/extension_builder.rs +++ b/crates/extension/src/extension_builder.rs @@ -247,26 +247,34 @@ impl ExtensionBuilder { let parser_path = src_path.join("parser.c"); let scanner_path = src_path.join("scanner.c"); - log::info!("compiling {grammar_name} parser"); - let clang_output = util::command::new_smol_command(&clang_path) - .args(["-fPIC", "-shared", "-Os"]) - .arg(format!("-Wl,--export=tree_sitter_{grammar_name}")) - .arg("-o") - .arg(&grammar_wasm_path) - .arg("-I") - .arg(&src_path) - .arg(&parser_path) - .args(scanner_path.exists().then_some(scanner_path)) - .output() - .await - .context("failed to run clang")?; - - if !clang_output.status.success() { - bail!( - "failed to compile {} parser with clang: {}", - grammar_name, - String::from_utf8_lossy(&clang_output.stderr), + // Skip recompiling if the WASM object is already newer than the source files + if file_newer_than_deps(&grammar_wasm_path, &[&parser_path, &scanner_path]).unwrap_or(false) + { + log::info!( + "skipping compilation of {grammar_name} parser because the existing compiled grammar is up to date" ); + } else { + log::info!("compiling {grammar_name} parser"); + let clang_output = util::command::new_smol_command(&clang_path) + .args(["-fPIC", "-shared", "-Os"]) + .arg(format!("-Wl,--export=tree_sitter_{grammar_name}")) + .arg("-o") + .arg(&grammar_wasm_path) + .arg("-I") + .arg(&src_path) + .arg(&parser_path) + 
.args(scanner_path.exists().then_some(scanner_path)) + .output() + .await + .context("failed to run clang")?; + + if !clang_output.status.success() { + bail!( + "failed to compile {} parser with clang: {}", + grammar_name, + String::from_utf8_lossy(&clang_output.stderr), + ); + } } Ok(()) @@ -643,3 +651,71 @@ fn populate_defaults(manifest: &mut ExtensionManifest, extension_path: &Path) -> Ok(()) } + +/// Returns `true` if the target exists and its last modified time is greater than that +/// of each dependency which exists (i.e., dependency paths which do not exist are ignored). +/// +/// # Errors +/// +/// Returns `Err` if any of the underlying file I/O operations fail. +fn file_newer_than_deps(target: &Path, dependencies: &[&Path]) -> Result { + if !target.try_exists()? { + return Ok(false); + } + let target_modified = target.metadata()?.modified()?; + for dependency in dependencies { + if !dependency.try_exists()? { + continue; + } + let dep_modified = dependency.metadata()?.modified()?; + if target_modified < dep_modified { + return Ok(false); + } + } + Ok(true) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::{fs, thread::sleep, time::Duration}; + + #[test] + fn test_file_newer_than_deps() { + // Don't use TempTree because we need to guarantee the order + let tmpdir = tempfile::tempdir().unwrap(); + let target = tmpdir.path().join("target.wasm"); + let dep1 = tmpdir.path().join("parser.c"); + let dep2 = tmpdir.path().join("scanner.c"); + + assert!( + !file_newer_than_deps(&target, &[&dep1, &dep2]).unwrap(), + "target doesn't exist" + ); + fs::write(&target, "foo").unwrap(); // Create target + assert!( + file_newer_than_deps(&target, &[&dep1, &dep2]).unwrap(), + "dependencies don't exist; target is newer" + ); + sleep(Duration::from_secs(1)); + fs::write(&dep1, "foo").unwrap(); // Create dep1 (newer than target) + // Dependency is newer + assert!( + !file_newer_than_deps(&target, &[&dep1, &dep2]).unwrap(), + "a dependency is newer (target {:?}, dep1 
{:?})", + target.metadata().unwrap().modified().unwrap(), + dep1.metadata().unwrap().modified().unwrap(), + ); + sleep(Duration::from_secs(1)); + fs::write(&dep2, "foo").unwrap(); // Create dep2 + sleep(Duration::from_secs(1)); + fs::write(&target, "foobar").unwrap(); // Update target + assert!( + file_newer_than_deps(&target, &[&dep1, &dep2]).unwrap(), + "target is newer than dependencies (target {:?}, dep2 {:?})", + target.metadata().unwrap().modified().unwrap(), + dep2.metadata().unwrap().modified().unwrap(), + ); + } +} diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs index b3275ff52ff7fe9fb8661f9d3cb4141cb5dfa2e7..c17484f26a06b3392cdbcd8f3c1578eb43c7b213 100644 --- a/crates/extension_host/src/extension_store_test.rs +++ b/crates/extension_host/src/extension_store_test.rs @@ -309,9 +309,9 @@ async fn test_extension_store(cx: &mut TestAppContext) { assert_eq!( language_registry.language_names(), [ - LanguageName::new("ERB"), - LanguageName::new("Plain Text"), - LanguageName::new("Ruby"), + LanguageName::new_static("ERB"), + LanguageName::new_static("Plain Text"), + LanguageName::new_static("Ruby"), ] ); assert_eq!( @@ -466,9 +466,9 @@ async fn test_extension_store(cx: &mut TestAppContext) { assert_eq!( language_registry.language_names(), [ - LanguageName::new("ERB"), - LanguageName::new("Plain Text"), - LanguageName::new("Ruby"), + LanguageName::new_static("ERB"), + LanguageName::new_static("Plain Text"), + LanguageName::new_static("Ruby"), ] ); assert_eq!( @@ -526,7 +526,7 @@ async fn test_extension_store(cx: &mut TestAppContext) { assert_eq!( language_registry.language_names(), - [LanguageName::new("Plain Text")] + [LanguageName::new_static("Plain Text")] ); assert_eq!(language_registry.grammar_names(), []); }); @@ -708,7 +708,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) { .await .unwrap(); - let mut fake_servers = 
language_registry.register_fake_language_server( + let mut fake_servers = language_registry.register_fake_lsp_server( LanguageServerName("gleam".into()), lsp::ServerCapabilities { completion_provider: Some(Default::default()), diff --git a/crates/extensions_ui/src/extensions_ui.rs b/crates/extensions_ui/src/extensions_ui.rs index 89247ae5a49a99b2a4f2261892b2656e14bb8674..3dd4803ce17adb053f86f29e3724d58d479136c6 100644 --- a/crates/extensions_ui/src/extensions_ui.rs +++ b/crates/extensions_ui/src/extensions_ui.rs @@ -1472,8 +1472,7 @@ impl ExtensionsPage { }, ); }, - )) - .color(ui::SwitchColor::Accent), + )), ), ), ) diff --git a/crates/feature_flags/src/flags.rs b/crates/feature_flags/src/flags.rs index 47b6f1230ac747c2633327d1be923d33388cf179..fe11a7b5eaa162a90ae8ba3f691ca804ab64db2d 100644 --- a/crates/feature_flags/src/flags.rs +++ b/crates/feature_flags/src/flags.rs @@ -1,11 +1,5 @@ use crate::FeatureFlag; -pub struct PredictEditsRateCompletionsFeatureFlag; - -impl FeatureFlag for PredictEditsRateCompletionsFeatureFlag { - const NAME: &'static str = "predict-edits-rate-completions"; -} - pub struct NotebookFeatureFlag; impl FeatureFlag for NotebookFeatureFlag { @@ -17,3 +11,9 @@ pub struct PanicFeatureFlag; impl FeatureFlag for PanicFeatureFlag { const NAME: &'static str = "panic"; } + +pub struct InlineAssistantV2FeatureFlag; + +impl FeatureFlag for InlineAssistantV2FeatureFlag { + const NAME: &'static str = "inline-assistant-v2"; +} diff --git a/crates/fs/src/fake_git_repo.rs b/crates/fs/src/fake_git_repo.rs index 3bc411ff2d9b917fd409c29cca03d2191ee80978..6ca8b5a58f9f8f75023aa73e7a80e8547eb156f3 100644 --- a/crates/fs/src/fake_git_repo.rs +++ b/crates/fs/src/fake_git_repo.rs @@ -381,11 +381,18 @@ impl GitRepository for FakeGitRepository { Ok(state .branches .iter() - .map(|branch_name| Branch { - is_head: Some(branch_name) == current_branch.as_ref(), - ref_name: branch_name.into(), - most_recent_commit: None, - upstream: None, + .map(|branch_name| { + let 
ref_name = if branch_name.starts_with("refs/") { + branch_name.into() + } else { + format!("refs/heads/{branch_name}").into() + }; + Branch { + is_head: Some(branch_name) == current_branch.as_ref(), + ref_name, + most_recent_commit: None, + upstream: None, + } }) .collect()) }) diff --git a/crates/fuzzy/src/matcher.rs b/crates/fuzzy/src/matcher.rs index eb844e349821394785bb61a34600f04a6fa985eb..782c9caca832d81fb6e4bce8f49b4f310664b292 100644 --- a/crates/fuzzy/src/matcher.rs +++ b/crates/fuzzy/src/matcher.rs @@ -96,7 +96,8 @@ impl<'a> Matcher<'a> { continue; } - let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len()); + let matrix_len = + self.query.len() * (lowercase_prefix.len() + lowercase_candidate_chars.len()); self.score_matrix.clear(); self.score_matrix.resize(matrix_len, None); self.best_position_matrix.clear(); @@ -596,4 +597,15 @@ mod tests { }) .collect() } + + /// Test for https://github.com/zed-industries/zed/issues/44324 + #[test] + fn test_recursive_score_match_index_out_of_bounds() { + let paths = vec!["İ/İ/İ/İ"]; + let query = "İ/İ"; + + // This panicked with "index out of bounds: the len is 21 but the index is 22" + let result = match_single_path_query(query, false, &paths); + let _ = result; + } } diff --git a/crates/git/src/blame.rs b/crates/git/src/blame.rs index e58b9cb7e0427bf3af1c88f473debba0b6f94f59..6325eacc8201d812d14dfdf4853f4004e22c263e 100644 --- a/crates/git/src/blame.rs +++ b/crates/git/src/blame.rs @@ -19,7 +19,6 @@ pub use git2 as libgit; pub struct Blame { pub entries: Vec, pub messages: HashMap, - pub remote_url: Option, } #[derive(Clone, Debug, Default)] @@ -36,7 +35,6 @@ impl Blame { working_directory: &Path, path: &RepoPath, content: &Rope, - remote_url: Option, ) -> Result { let output = run_git_blame(git_binary, working_directory, path, content).await?; let mut entries = parse_git_blame(&output)?; @@ -53,11 +51,7 @@ impl Blame { .await .context("failed to get commit messages")?; - Ok(Self { - entries, - 
messages, - remote_url, - }) + Ok(Self { entries, messages }) } } diff --git a/crates/git/src/repository.rs b/crates/git/src/repository.rs index f79bade2d6bc12553b173c4f4e86989a961e6d31..70cbf6e3c58b7d8f6b690a554370d34262f541e3 100644 --- a/crates/git/src/repository.rs +++ b/crates/git/src/repository.rs @@ -1494,28 +1494,17 @@ impl GitRepository for RealGitRepository { let git_binary_path = self.any_git_binary_path.clone(); let executor = self.executor.clone(); - async move { - let remote_url = if let Some(remote_url) = self.remote_url("upstream").await { - Some(remote_url) - } else if let Some(remote_url) = self.remote_url("origin").await { - Some(remote_url) - } else { - None - }; - executor - .spawn(async move { - crate::blame::Blame::for_path( - &git_binary_path, - &working_directory?, - &path, - &content, - remote_url, - ) - .await - }) + executor + .spawn(async move { + crate::blame::Blame::for_path( + &git_binary_path, + &working_directory?, + &path, + &content, + ) .await - } - .boxed() + }) + .boxed() } fn file_history(&self, path: RepoPath) -> BoxFuture<'_, Result> { diff --git a/crates/git_hosting_providers/Cargo.toml b/crates/git_hosting_providers/Cargo.toml index 851556151e285975cb1eb7d3d33244d7e11b5663..9480e0ec28c0ffa61c1126b2a627f22dc445d7d3 100644 --- a/crates/git_hosting_providers/Cargo.toml +++ b/crates/git_hosting_providers/Cargo.toml @@ -18,6 +18,7 @@ futures.workspace = true git.workspace = true gpui.workspace = true http_client.workspace = true +itertools.workspace = true regex.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/git_hosting_providers/src/git_hosting_providers.rs b/crates/git_hosting_providers/src/git_hosting_providers.rs index 98ea301ec984298df54ec8bca7e28f9474e373bd..37cf5882059d7a274661f5a083a23e3f25e676ff 100644 --- a/crates/git_hosting_providers/src/git_hosting_providers.rs +++ b/crates/git_hosting_providers/src/git_hosting_providers.rs @@ -26,7 +26,7 @@ pub fn init(cx: &mut App) { 
provider_registry.register_hosting_provider(Arc::new(Gitee)); provider_registry.register_hosting_provider(Arc::new(Github::public_instance())); provider_registry.register_hosting_provider(Arc::new(Gitlab::public_instance())); - provider_registry.register_hosting_provider(Arc::new(Sourcehut)); + provider_registry.register_hosting_provider(Arc::new(SourceHut::public_instance())); } /// Registers additional Git hosting providers. @@ -51,6 +51,8 @@ pub async fn register_additional_providers( provider_registry.register_hosting_provider(Arc::new(gitea_self_hosted)); } else if let Ok(bitbucket_self_hosted) = Bitbucket::from_remote_url(&origin_url) { provider_registry.register_hosting_provider(Arc::new(bitbucket_self_hosted)); + } else if let Ok(sourcehut_self_hosted) = SourceHut::from_remote_url(&origin_url) { + provider_registry.register_hosting_provider(Arc::new(sourcehut_self_hosted)); } } diff --git a/crates/git_hosting_providers/src/providers/bitbucket.rs b/crates/git_hosting_providers/src/providers/bitbucket.rs index 0c30a13758a8339087ebb146f0029baee0d3ea7e..07c6898d4e0affc3bdde4d8290607897cf2cd5be 100644 --- a/crates/git_hosting_providers/src/providers/bitbucket.rs +++ b/crates/git_hosting_providers/src/providers/bitbucket.rs @@ -1,8 +1,14 @@ -use std::str::FromStr; use std::sync::LazyLock; - -use anyhow::{Result, bail}; +use std::{str::FromStr, sync::Arc}; + +use anyhow::{Context as _, Result, bail}; +use async_trait::async_trait; +use futures::AsyncReadExt; +use gpui::SharedString; +use http_client::{AsyncBody, HttpClient, HttpRequestExt, Request}; +use itertools::Itertools as _; use regex::Regex; +use serde::Deserialize; use url::Url; use git::{ @@ -20,6 +26,42 @@ fn pull_request_regex() -> &'static Regex { &PULL_REQUEST_REGEX } +#[derive(Debug, Deserialize)] +struct CommitDetails { + author: Author, +} + +#[derive(Debug, Deserialize)] +struct Author { + user: Account, +} + +#[derive(Debug, Deserialize)] +struct Account { + links: AccountLinks, +} + 
+#[derive(Debug, Deserialize)] +struct AccountLinks { + avatar: Option, +} + +#[derive(Debug, Deserialize)] +struct Link { + href: String, +} + +#[derive(Debug, Deserialize)] +struct CommitDetailsSelfHosted { + author: AuthorSelfHosted, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct AuthorSelfHosted { + avatar_url: Option, +} + pub struct Bitbucket { name: String, base_url: Url, @@ -61,8 +103,60 @@ impl Bitbucket { .host_str() .is_some_and(|host| host != "bitbucket.org") } + + async fn fetch_bitbucket_commit_author( + &self, + repo_owner: &str, + repo: &str, + commit: &str, + client: &Arc, + ) -> Result> { + let Some(host) = self.base_url.host_str() else { + bail!("failed to get host from bitbucket base url"); + }; + let is_self_hosted = self.is_self_hosted(); + let url = if is_self_hosted { + format!( + "https://{host}/rest/api/latest/projects/{repo_owner}/repos/{repo}/commits/{commit}?avatarSize=128" + ) + } else { + format!("https://api.{host}/2.0/repositories/{repo_owner}/{repo}/commit/{commit}") + }; + + let request = Request::get(&url) + .header("Content-Type", "application/json") + .follow_redirects(http_client::RedirectPolicy::FollowAll); + + let mut response = client + .send(request.body(AsyncBody::default())?) 
+ .await + .with_context(|| format!("error fetching BitBucket commit details at {:?}", url))?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if response.status().is_client_error() { + let text = String::from_utf8_lossy(body.as_slice()); + bail!( + "status error {}, response: {text:?}", + response.status().as_u16() + ); + } + + let body_str = std::str::from_utf8(&body)?; + + if is_self_hosted { + serde_json::from_str::(body_str) + .map(|commit| commit.author.avatar_url) + } else { + serde_json::from_str::(body_str) + .map(|commit| commit.author.user.links.avatar.map(|link| link.href)) + } + .context("failed to deserialize BitBucket commit details") + } } +#[async_trait] impl GitHostingProvider for Bitbucket { fn name(&self) -> String { self.name.clone() @@ -73,7 +167,7 @@ impl GitHostingProvider for Bitbucket { } fn supports_avatars(&self) -> bool { - false + true } fn format_line_number(&self, line: u32) -> String { @@ -98,9 +192,16 @@ impl GitHostingProvider for Bitbucket { return None; } - let mut path_segments = url.path_segments()?; - let owner = path_segments.next()?; - let repo = path_segments.next()?.trim_end_matches(".git"); + let mut path_segments = url.path_segments()?.collect::>(); + let repo = path_segments.pop()?.trim_end_matches(".git"); + let owner = if path_segments.get(0).is_some_and(|v| *v == "scm") && path_segments.len() > 1 + { + // Skip the "scm" segment if it's not the only segment + // https://github.com/gitkraken/vscode-gitlens/blob/a6e3c6fbb255116507eaabaa9940c192ed7bb0e1/src/git/remotes/bitbucket-server.ts#L72-L74 + path_segments.into_iter().skip(1).join("/") + } else { + path_segments.into_iter().join("/") + }; Some(ParsedGitRemote { owner: owner.into(), @@ -176,6 +277,22 @@ impl GitHostingProvider for Bitbucket { Some(PullRequest { number, url }) } + + async fn commit_author_avatar_url( + &self, + repo_owner: &str, + repo: &str, + commit: SharedString, + http_client: Arc, + ) -> Result> { + let 
commit = commit.to_string(); + let avatar_url = self + .fetch_bitbucket_commit_author(repo_owner, repo, &commit, &http_client) + .await? + .map(|avatar_url| Url::parse(&avatar_url)) + .transpose()?; + Ok(avatar_url) + } } #[cfg(test)] @@ -264,6 +381,38 @@ mod tests { repo: "zed".into(), } ); + + // Test with "scm" in the path + let remote_url = "https://bitbucket.company.com/scm/zed-industries/zed.git"; + + let parsed_remote = Bitbucket::from_remote_url(remote_url) + .unwrap() + .parse_remote_url(remote_url) + .unwrap(); + + assert_eq!( + parsed_remote, + ParsedGitRemote { + owner: "zed-industries".into(), + repo: "zed".into(), + } + ); + + // Test with only "scm" as owner + let remote_url = "https://bitbucket.company.com/scm/zed.git"; + + let parsed_remote = Bitbucket::from_remote_url(remote_url) + .unwrap() + .parse_remote_url(remote_url) + .unwrap(); + + assert_eq!( + parsed_remote, + ParsedGitRemote { + owner: "scm".into(), + repo: "zed".into(), + } + ); } #[test] diff --git a/crates/git_hosting_providers/src/providers/sourcehut.rs b/crates/git_hosting_providers/src/providers/sourcehut.rs index 55bff551846b5f69bad8ccaeaccf3ad55868303f..41011b023bef01a06138ead26d93ba447d6a4ba1 100644 --- a/crates/git_hosting_providers/src/providers/sourcehut.rs +++ b/crates/git_hosting_providers/src/providers/sourcehut.rs @@ -1,5 +1,6 @@ use std::str::FromStr; +use anyhow::{Result, bail}; use url::Url; use git::{ @@ -7,15 +8,52 @@ use git::{ RemoteUrl, }; -pub struct Sourcehut; +use crate::get_host_from_git_remote_url; -impl GitHostingProvider for Sourcehut { +pub struct SourceHut { + name: String, + base_url: Url, +} + +impl SourceHut { + pub fn new(name: &str, base_url: Url) -> Self { + Self { + name: name.to_string(), + base_url, + } + } + + pub fn public_instance() -> Self { + Self::new("SourceHut", Url::parse("https://git.sr.ht").unwrap()) + } + + pub fn from_remote_url(remote_url: &str) -> Result { + let host = get_host_from_git_remote_url(remote_url)?; + if host == 
"git.sr.ht" { + bail!("the SourceHut instance is not self-hosted"); + } + + // TODO: detecting self hosted instances by checking whether "sourcehut" is in the url or not + // is not very reliable. See https://github.com/zed-industries/zed/issues/26393 for more + // information. + if !host.contains("sourcehut") { + bail!("not a SourceHut URL"); + } + + Ok(Self::new( + "SourceHut Self-Hosted", + Url::parse(&format!("https://{}", host))?, + )) + } +} + +impl GitHostingProvider for SourceHut { fn name(&self) -> String { - "SourceHut".to_string() + self.name.clone() } fn base_url(&self) -> Url { - Url::parse("https://git.sr.ht").unwrap() + self.base_url.clone() } fn supports_avatars(&self) -> bool { @@ -34,7 +72,7 @@ impl GitHostingProvider for Sourcehut { let url = RemoteUrl::from_str(url).ok()?; let host = url.host_str()?; - if host != "git.sr.ht" { + if host != self.base_url.host_str()? { return None; } @@ -96,7 +134,7 @@ mod tests { #[test] fn test_parse_remote_url_given_ssh_url() { - let parsed_remote = Sourcehut + let parsed_remote = SourceHut::public_instance() .parse_remote_url("git@git.sr.ht:~zed-industries/zed") .unwrap(); @@ -111,7 +149,7 @@ mod tests { #[test] fn test_parse_remote_url_given_ssh_url_with_git_suffix() { - let parsed_remote = Sourcehut + let parsed_remote = SourceHut::public_instance() .parse_remote_url("git@git.sr.ht:~zed-industries/zed.git") .unwrap(); @@ -126,7 +164,7 @@ mod tests { #[test] fn test_parse_remote_url_given_https_url() { - let parsed_remote = Sourcehut + let parsed_remote = SourceHut::public_instance() .parse_remote_url("https://git.sr.ht/~zed-industries/zed") .unwrap(); @@ -139,9 +177,63 @@ mod tests { ); } + #[test] + fn test_parse_remote_url_given_self_hosted_ssh_url() { + let remote_url = "git@sourcehut.org:~zed-industries/zed"; + + let parsed_remote = SourceHut::from_remote_url(remote_url) + .unwrap() + .parse_remote_url(remote_url) + .unwrap(); + + assert_eq!( + parsed_remote, + ParsedGitRemote { + owner: 
"zed-industries".into(), + repo: "zed".into(), + } + ); + } + + #[test] + fn test_parse_remote_url_given_self_hosted_ssh_url_with_git_suffix() { + let remote_url = "git@sourcehut.org:~zed-industries/zed.git"; + + let parsed_remote = SourceHut::from_remote_url(remote_url) + .unwrap() + .parse_remote_url(remote_url) + .unwrap(); + + assert_eq!( + parsed_remote, + ParsedGitRemote { + owner: "zed-industries".into(), + repo: "zed.git".into(), + } + ); + } + + #[test] + fn test_parse_remote_url_given_self_hosted_https_url() { + let remote_url = "https://sourcehut.org/~zed-industries/zed"; + + let parsed_remote = SourceHut::from_remote_url(remote_url) + .unwrap() + .parse_remote_url(remote_url) + .unwrap(); + + assert_eq!( + parsed_remote, + ParsedGitRemote { + owner: "zed-industries".into(), + repo: "zed".into(), + } + ); + } + #[test] fn test_build_sourcehut_permalink() { - let permalink = Sourcehut.build_permalink( + let permalink = SourceHut::public_instance().build_permalink( ParsedGitRemote { owner: "zed-industries".into(), repo: "zed".into(), @@ -159,7 +251,7 @@ mod tests { #[test] fn test_build_sourcehut_permalink_with_git_suffix() { - let permalink = Sourcehut.build_permalink( + let permalink = SourceHut::public_instance().build_permalink( ParsedGitRemote { owner: "zed-industries".into(), repo: "zed.git".into(), @@ -175,9 +267,49 @@ mod tests { assert_eq!(permalink.to_string(), expected_url.to_string()) } + #[test] + fn test_build_sourcehut_self_hosted_permalink() { + let permalink = SourceHut::from_remote_url("https://sourcehut.org/~zed-industries/zed") + .unwrap() + .build_permalink( + ParsedGitRemote { + owner: "zed-industries".into(), + repo: "zed".into(), + }, + BuildPermalinkParams::new( + "faa6f979be417239b2e070dbbf6392b909224e0b", + &repo_path("crates/editor/src/git/permalink.rs"), + None, + ), + ); + + let expected_url = "https://sourcehut.org/~zed-industries/zed/tree/faa6f979be417239b2e070dbbf6392b909224e0b/item/crates/editor/src/git/permalink.rs"; + 
assert_eq!(permalink.to_string(), expected_url.to_string()) + } + + #[test] + fn test_build_sourcehut_self_hosted_permalink_with_git_suffix() { + let permalink = SourceHut::from_remote_url("https://sourcehut.org/~zed-industries/zed.git") + .unwrap() + .build_permalink( + ParsedGitRemote { + owner: "zed-industries".into(), + repo: "zed.git".into(), + }, + BuildPermalinkParams::new( + "faa6f979be417239b2e070dbbf6392b909224e0b", + &repo_path("crates/editor/src/git/permalink.rs"), + None, + ), + ); + + let expected_url = "https://sourcehut.org/~zed-industries/zed.git/tree/faa6f979be417239b2e070dbbf6392b909224e0b/item/crates/editor/src/git/permalink.rs"; + assert_eq!(permalink.to_string(), expected_url.to_string()) + } + #[test] fn test_build_sourcehut_permalink_with_single_line_selection() { - let permalink = Sourcehut.build_permalink( + let permalink = SourceHut::public_instance().build_permalink( ParsedGitRemote { owner: "zed-industries".into(), repo: "zed".into(), @@ -195,7 +327,7 @@ mod tests { #[test] fn test_build_sourcehut_permalink_with_multi_line_selection() { - let permalink = Sourcehut.build_permalink( + let permalink = SourceHut::public_instance().build_permalink( ParsedGitRemote { owner: "zed-industries".into(), repo: "zed".into(), @@ -210,4 +342,44 @@ mod tests { let expected_url = "https://git.sr.ht/~zed-industries/zed/tree/faa6f979be417239b2e070dbbf6392b909224e0b/item/crates/editor/src/git/permalink.rs#L24-48"; assert_eq!(permalink.to_string(), expected_url.to_string()) } + + #[test] + fn test_build_sourcehut_self_hosted_permalink_with_single_line_selection() { + let permalink = SourceHut::from_remote_url("https://sourcehut.org/~zed-industries/zed") + .unwrap() + .build_permalink( + ParsedGitRemote { + owner: "zed-industries".into(), + repo: "zed".into(), + }, + BuildPermalinkParams::new( + "faa6f979be417239b2e070dbbf6392b909224e0b", + &repo_path("crates/editor/src/git/permalink.rs"), + Some(6..6), + ), + ); + + let expected_url = 
"https://sourcehut.org/~zed-industries/zed/tree/faa6f979be417239b2e070dbbf6392b909224e0b/item/crates/editor/src/git/permalink.rs#L7"; + assert_eq!(permalink.to_string(), expected_url.to_string()) + } + + #[test] + fn test_build_sourcehut_self_hosted_permalink_with_multi_line_selection() { + let permalink = SourceHut::from_remote_url("https://sourcehut.org/~zed-industries/zed") + .unwrap() + .build_permalink( + ParsedGitRemote { + owner: "zed-industries".into(), + repo: "zed".into(), + }, + BuildPermalinkParams::new( + "faa6f979be417239b2e070dbbf6392b909224e0b", + &repo_path("crates/editor/src/git/permalink.rs"), + Some(23..47), + ), + ); + + let expected_url = "https://sourcehut.org/~zed-industries/zed/tree/faa6f979be417239b2e070dbbf6392b909224e0b/item/crates/editor/src/git/permalink.rs#L24-48"; + assert_eq!(permalink.to_string(), expected_url.to_string()) + } } diff --git a/crates/git_hosting_providers/src/settings.rs b/crates/git_hosting_providers/src/settings.rs index 9bf6c1022b04cc60b0fbaead5177f9fe8e6e0280..95243cbe4e4ba68249ba89b948a5ed2644909364 100644 --- a/crates/git_hosting_providers/src/settings.rs +++ b/crates/git_hosting_providers/src/settings.rs @@ -8,7 +8,7 @@ use settings::{ use url::Url; use util::ResultExt as _; -use crate::{Bitbucket, Github, Gitlab}; +use crate::{Bitbucket, Forgejo, Gitea, Github, Gitlab, SourceHut}; pub(crate) fn init(cx: &mut App) { init_git_hosting_provider_settings(cx); @@ -46,6 +46,11 @@ fn update_git_hosting_providers_from_settings(cx: &mut App) { } GitHostingProviderKind::Github => Arc::new(Github::new(&provider.name, url)) as _, GitHostingProviderKind::Gitlab => Arc::new(Gitlab::new(&provider.name, url)) as _, + GitHostingProviderKind::Gitea => Arc::new(Gitea::new(&provider.name, url)) as _, + GitHostingProviderKind::Forgejo => Arc::new(Forgejo::new(&provider.name, url)) as _, + GitHostingProviderKind::SourceHut => { + Arc::new(SourceHut::new(&provider.name, url)) as _ + } }) }); diff --git a/crates/git_ui/Cargo.toml 
b/crates/git_ui/Cargo.toml index 5e96cd3529b48bb401ee14e1a704b9bec485e356..beaf192b0ef538fb524ff4986710255040b89f27 100644 --- a/crates/git_ui/Cargo.toml +++ b/crates/git_ui/Cargo.toml @@ -13,7 +13,6 @@ name = "git_ui" path = "src/git_ui.rs" [features] -default = [] test-support = ["multi_buffer/test-support"] [dependencies] @@ -62,7 +61,8 @@ watch.workspace = true workspace.workspace = true zed_actions.workspace = true zeroize.workspace = true - +ztracing.workspace = true +tracing.workspace = true [target.'cfg(windows)'.dependencies] windows.workspace = true @@ -78,3 +78,6 @@ settings = { workspace = true, features = ["test-support"] } unindent.workspace = true workspace = { workspace = true, features = ["test-support"] } zlog.workspace = true + +[package.metadata.cargo-machete] +ignored = ["tracing"] diff --git a/crates/git_ui/src/blame_ui.rs b/crates/git_ui/src/blame_ui.rs index 47703e09824a49c633798c7967652d7f48f821be..fe9af196021e01edd406c9ced86e0465b7e70984 100644 --- a/crates/git_ui/src/blame_ui.rs +++ b/crates/git_ui/src/blame_ui.rs @@ -55,6 +55,7 @@ impl BlameRenderer for GitBlameRenderer { } else { None }; + Some( div() .mr_2() @@ -80,7 +81,10 @@ impl BlameRenderer for GitBlameRenderer { .on_mouse_down(MouseButton::Right, { let blame_entry = blame_entry.clone(); let details = details.clone(); + let editor = editor.clone(); move |event, window, cx| { + cx.stop_propagation(); + deploy_blame_entry_context_menu( &blame_entry, details.as_ref(), @@ -107,17 +111,19 @@ impl BlameRenderer for GitBlameRenderer { ) } }) - .hoverable_tooltip(move |_window, cx| { - cx.new(|cx| { - CommitTooltip::blame_entry( - &blame_entry, - details.clone(), - repository.clone(), - workspace.clone(), - cx, - ) + .when(!editor.read(cx).has_mouse_context_menu(), |el| { + el.hoverable_tooltip(move |_window, cx| { + cx.new(|cx| { + CommitTooltip::blame_entry( + &blame_entry, + details.clone(), + repository.clone(), + workspace.clone(), + cx, + ) + }) + .into() }) - .into() }), ) 
.into_any(), @@ -148,7 +154,7 @@ impl BlameRenderer for GitBlameRenderer { h_flex() .id("inline-blame") .w_full() - .font_family(style.font().family) + .font(style.font()) .text_color(cx.theme().status().hint) .line_height(style.line_height) .child(Icon::new(IconName::FileGit).color(Color::Hint)) @@ -396,6 +402,7 @@ fn deploy_blame_entry_context_menu( }); editor.update(cx, move |editor, cx| { + editor.hide_blame_popover(false, cx); editor.deploy_mouse_context_menu(position, context_menu, window, cx); cx.notify(); }); diff --git a/crates/git_ui/src/branch_picker.rs b/crates/git_ui/src/branch_picker.rs index 42e043cada2813126af3489c9769aca9c675999f..e198fa092e0cc2d0e8b77ba954fd743512915c75 100644 --- a/crates/git_ui/src/branch_picker.rs +++ b/crates/git_ui/src/branch_picker.rs @@ -17,8 +17,8 @@ use settings::Settings; use std::sync::Arc; use time::OffsetDateTime; use ui::{ - CommonAnimationExt, Divider, HighlightedLabel, KeyBinding, ListItem, ListItemSpacing, Tooltip, - prelude::*, + CommonAnimationExt, Divider, HighlightedLabel, KeyBinding, ListHeader, ListItem, + ListItemSpacing, Tooltip, prelude::*, }; use util::ResultExt; use workspace::notifications::DetachAndPromptErr; @@ -440,13 +440,6 @@ impl BranchListDelegate { cx.emit(DismissEvent); } - fn loader(&self) -> AnyElement { - Icon::new(IconName::LoadCircle) - .size(IconSize::Small) - .with_rotate_animation(3) - .into_any_element() - } - fn delete_at(&self, idx: usize, window: &mut Window, cx: &mut Context>) { let Some(entry) = self.matches.get(idx).cloned() else { return; @@ -558,6 +551,8 @@ impl PickerDelegate for BranchListDelegate { editor.set_placeholder_text(placeholder, window, cx); }); + let focus_handle = self.focus_handle.clone(); + v_flex() .when( self.editor_position() == PickerEditorPosition::End, @@ -569,7 +564,37 @@ impl PickerDelegate for BranchListDelegate { .flex_none() .h_9() .px_2p5() - .child(editor.clone()), + .child(editor.clone()) + .when( + self.editor_position() == 
PickerEditorPosition::End, + |this| { + let tooltip_label = if self.display_remotes { + "Turn Off Remote Filter" + } else { + "Filter Remote Branches" + }; + + this.gap_1().justify_between().child({ + IconButton::new("filter-remotes", IconName::Filter) + .disabled(self.loading) + .toggle_state(self.display_remotes) + .tooltip(move |_, cx| { + Tooltip::for_action_in( + tooltip_label, + &branch_picker::FilterRemotes, + &focus_handle, + cx, + ) + }) + .on_click(|_click, window, cx| { + window.dispatch_action( + branch_picker::FilterRemotes.boxed_clone(), + cx, + ); + }) + }) + }, + ), ) .when( self.editor_position() == PickerEditorPosition::Start, @@ -683,10 +708,16 @@ impl PickerDelegate for BranchListDelegate { } else { Entry::NewBranch { name: query } }; - picker.delegate.state = if is_url { - PickerState::NewRemote + // Only transition to NewBranch/NewRemote states when we only show their list item + // Otherwise, stay in List state so footer buttons remain visible + picker.delegate.state = if matches.is_empty() { + if is_url { + PickerState::NewRemote + } else { + PickerState::NewBranch + } } else { - PickerState::NewBranch + PickerState::List }; matches.push(entry); } else { @@ -770,7 +801,7 @@ impl PickerDelegate for BranchListDelegate { } else { None }; - self.create_branch(from_branch, format!("refs/heads/{name}").into(), window, cx); + self.create_branch(from_branch, name.into(), window, cx); } } @@ -812,226 +843,256 @@ impl PickerDelegate for BranchListDelegate { }) .unwrap_or_else(|| (None, None, None)); - let icon = if let Some(default_branch) = self.default_branch.clone() { - let icon = match &entry { - Entry::Branch { .. } => Some(( - IconName::GitBranchAlt, - format!("Create branch based off default: {default_branch}"), - )), - Entry::NewUrl { url } => { - Some((IconName::Screen, format!("Create remote based off {url}"))) - } - Entry::NewBranch { .. 
} => None, - }; - - icon.map(|(icon, tooltip_text)| { - IconButton::new("branch-from-default", icon) - .on_click(cx.listener(move |this, _, window, cx| { - this.delegate.set_selected_index(ix, window, cx); - this.delegate.confirm(true, window, cx); - })) - .tooltip(move |_window, cx| { - Tooltip::for_action(tooltip_text.clone(), &menu::SecondaryConfirm, cx) - }) - }) - } else { - None - }; + let entry_icon = match entry { + Entry::NewUrl { .. } | Entry::NewBranch { .. } => { + Icon::new(IconName::Plus).color(Color::Muted) + } - let icon_element = if self.display_remotes { - Icon::new(IconName::Screen) - } else { - Icon::new(IconName::GitBranchAlt) + Entry::Branch { .. } => { + if self.display_remotes { + Icon::new(IconName::Screen).color(Color::Muted) + } else { + Icon::new(IconName::GitBranchAlt).color(Color::Muted) + } + } }; - let entry_name = match entry { - Entry::NewUrl { .. } => h_flex() - .gap_1() - .child( - Icon::new(IconName::Plus) - .size(IconSize::Small) - .color(Color::Muted), - ) - .child( - Label::new("Create remote repository".to_string()) - .single_line() - .truncate(), - ) - .into_any_element(), - Entry::NewBranch { name } => h_flex() - .gap_1() - .child( - Icon::new(IconName::Plus) - .size(IconSize::Small) - .color(Color::Muted), - ) - .child( - Label::new(format!("Create branch \"{name}\"…")) - .single_line() - .truncate(), - ) + let entry_title = match entry { + Entry::NewUrl { .. 
} => Label::new("Create Remote Repository") + .single_line() + .truncate() .into_any_element(), - Entry::Branch { branch, positions } => h_flex() - .max_w_48() - .child(h_flex().mr_1().child(icon_element)) - .child( - HighlightedLabel::new(branch.name().to_string(), positions.clone()).truncate(), - ) + Entry::NewBranch { name } => Label::new(format!("Create Branch: \"{name}\"…")) + .single_line() + .truncate() .into_any_element(), + Entry::Branch { branch, positions } => { + HighlightedLabel::new(branch.name().to_string(), positions.clone()) + .single_line() + .truncate() + .into_any_element() + } }; + let focus_handle = self.focus_handle.clone(); + let is_new_items = matches!(entry, Entry::NewUrl { .. } | Entry::NewBranch { .. }); + + let delete_branch_button = IconButton::new("delete", IconName::Trash) + .tooltip(move |_, cx| { + Tooltip::for_action_in( + "Delete Branch", + &branch_picker::DeleteBranch, + &focus_handle, + cx, + ) + }) + .on_click(cx.listener(|this, _, window, cx| { + let selected_idx = this.delegate.selected_index(); + this.delegate.delete_at(selected_idx, window, cx); + })); + + let create_from_default_button = self.default_branch.as_ref().map(|default_branch| { + let tooltip_label: SharedString = format!("Create New From: {default_branch}").into(); + let focus_handle = self.focus_handle.clone(); + + IconButton::new("create_from_default", IconName::GitBranchPlus) + .tooltip(move |_, cx| { + Tooltip::for_action_in( + tooltip_label.clone(), + &menu::SecondaryConfirm, + &focus_handle, + cx, + ) + }) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(true, window, cx); + })) + .into_any_element() + }); + Some( ListItem::new(SharedString::from(format!("vcs-menu-{ix}"))) .inset(true) .spacing(ListItemSpacing::Sparse) .toggle_state(selected) - .tooltip({ - match entry { - Entry::Branch { branch, .. } => Tooltip::text(branch.name().to_string()), - Entry::NewUrl { .. 
} => { - Tooltip::text("Create remote repository".to_string()) - } - Entry::NewBranch { name } => { - Tooltip::text(format!("Create branch \"{name}\"")) - } - } - }) .child( - v_flex() + h_flex() .w_full() - .overflow_hidden() + .gap_3() + .flex_grow() + .child(entry_icon) .child( - h_flex() - .gap_6() - .justify_between() - .overflow_x_hidden() - .child(entry_name) - .when_some(commit_time, |label, commit_time| { - label.child( - Label::new(commit_time) - .size(LabelSize::Small) - .color(Color::Muted) - .into_element(), - ) - }), - ) - .when(self.style == BranchListStyle::Modal, |el| { - el.child(div().max_w_96().child({ - let message = match entry { - Entry::NewUrl { url } => format!("based off {url}"), - Entry::NewBranch { .. } => { - if let Some(current_branch) = - self.repo.as_ref().and_then(|repo| { - repo.read(cx).branch.as_ref().map(|b| b.name()) - }) - { - format!("based off {}", current_branch) - } else { - "based off the current branch".to_string() - } - } - Entry::Branch { .. } => { - let show_author_name = ProjectSettings::get_global(cx) - .git - .branch_picker - .show_author_name; - - subject.map_or("no commits found".into(), |subject| { - if show_author_name && author_name.is_some() { - format!("{} • {}", author_name.unwrap(), subject) - } else { - subject.to_string() - } + v_flex() + .id("info_container") + .w_full() + .child(entry_title) + .child( + h_flex() + .w_full() + .justify_between() + .gap_1p5() + .when(self.style == BranchListStyle::Modal, |el| { + el.child(div().max_w_96().child({ + let message = match entry { + Entry::NewUrl { url } => { + format!("Based off {url}") + } + Entry::NewBranch { .. } => { + if let Some(current_branch) = + self.repo.as_ref().and_then(|repo| { + repo.read(cx) + .branch + .as_ref() + .map(|b| b.name()) + }) + { + format!("Based off {}", current_branch) + } else { + "Based off the current branch" + .to_string() + } + } + Entry::Branch { .. 
} => { + let show_author_name = + ProjectSettings::get_global(cx) + .git + .branch_picker + .show_author_name; + + subject.map_or( + "No commits found".into(), + |subject| { + if show_author_name + && author_name.is_some() + { + format!( + "{} • {}", + author_name.unwrap(), + subject + ) + } else { + subject.to_string() + } + }, + ) + } + }; + + Label::new(message) + .size(LabelSize::Small) + .color(Color::Muted) + .truncate() + })) }) - } - }; - - Label::new(message) - .size(LabelSize::Small) - .truncate() - .color(Color::Muted) - })) - }), + .when_some(commit_time, |label, commit_time| { + label.child( + Label::new(commit_time) + .size(LabelSize::Small) + .color(Color::Muted), + ) + }), + ) + .when_some( + entry.as_branch().map(|b| b.name().to_string()), + |this, branch_name| this.tooltip(Tooltip::text(branch_name)), + ), + ), + ) + .when( + self.editor_position() == PickerEditorPosition::End && !is_new_items, + |this| { + this.map(|this| { + if self.selected_index() == ix { + this.end_slot(delete_branch_button) + } else { + this.end_hover_slot(delete_branch_button) + } + }) + }, ) - .end_slot::(icon), + .when_some( + if self.editor_position() == PickerEditorPosition::End && is_new_items { + create_from_default_button + } else { + None + }, + |this, create_from_default_button| { + this.map(|this| { + if self.selected_index() == ix { + this.end_slot(create_from_default_button) + } else { + this.end_hover_slot(create_from_default_button) + } + }) + }, + ), ) } fn render_header( &self, _window: &mut Window, - cx: &mut Context>, + _cx: &mut Context>, ) -> Option { - if matches!( - self.state, - PickerState::CreateRemote(_) | PickerState::NewRemote | PickerState::NewBranch - ) { + matches!(self.state, PickerState::List).then(|| { + let label = if self.display_remotes { + "Remote" + } else { + "Local" + }; + + ListHeader::new(label).inset(true).into_any_element() + }) + } + + fn render_footer(&self, _: &mut Window, cx: &mut Context>) -> Option { + if 
self.editor_position() == PickerEditorPosition::End { return None; } - let label = if self.display_remotes { - "Remote" - } else { - "Local" - }; - Some( + + let focus_handle = self.focus_handle.clone(); + let loading_icon = Icon::new(IconName::LoadCircle) + .size(IconSize::Small) + .with_rotate_animation(3); + + let footer_container = || { h_flex() .w_full() .p_1p5() - .gap_1() .border_t_1() .border_color(cx.theme().colors().border_variant) - .child(Label::new(label).size(LabelSize::Small).color(Color::Muted)) - .into_any(), - ) - } - - fn render_footer(&self, _: &mut Window, cx: &mut Context>) -> Option { - let focus_handle = self.focus_handle.clone(); + }; - if self.loading { - return Some( - h_flex() - .w_full() - .p_1p5() - .gap_1() - .justify_end() - .border_t_1() - .border_color(cx.theme().colors().border_variant) - .child(self.loader()) - .into_any(), - ); - } match self.state { - PickerState::List => Some( - h_flex() - .w_full() - .p_1p5() - .gap_0p5() - .border_t_1() - .border_color(cx.theme().colors().border_variant) - .justify_between() - .child( - Button::new("filter-remotes", "Filter remotes") + PickerState::List => { + let selected_entry = self.matches.get(self.selected_index); + + let branch_from_default_button = self + .default_branch + .as_ref() + .filter(|_| matches!(selected_entry, Some(Entry::NewBranch { .. 
}))) + .map(|default_branch| { + let button_label = format!("Create New From: {default_branch}"); + + Button::new("branch-from-default", button_label) .key_binding( KeyBinding::for_action_in( - &branch_picker::FilterRemotes, + &menu::SecondaryConfirm, &focus_handle, cx, ) .map(|kb| kb.size(rems_from_px(12.))), ) - .on_click(|_click, window, cx| { - window.dispatch_action( - branch_picker::FilterRemotes.boxed_clone(), - cx, - ); - }) - .disabled(self.loading) - .style(ButtonStyle::Subtle) - .toggle_state(self.display_remotes), - ) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(true, window, cx); + })) + }); + + let delete_and_select_btns = h_flex() + .gap_1() .child( Button::new("delete-branch", "Delete") + .disabled(self.loading) .key_binding( KeyBinding::for_action_in( &branch_picker::DeleteBranch, @@ -1040,43 +1101,134 @@ impl PickerDelegate for BranchListDelegate { ) .map(|kb| kb.size(rems_from_px(12.))), ) - .disabled(self.loading) .on_click(|_, window, cx| { window .dispatch_action(branch_picker::DeleteBranch.boxed_clone(), cx); }), ) - .when(self.loading, |this| this.child(self.loader())) - .into_any(), - ), + .child( + Button::new("select_branch", "Select") + .key_binding( + KeyBinding::for_action_in(&menu::Confirm, &focus_handle, cx) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(false, window, cx); + })), + ); + + Some( + footer_container() + .map(|this| { + if branch_from_default_button.is_some() { + this.justify_end().when_some( + branch_from_default_button, + |this, button| { + this.child(button).child( + Button::new("create", "Create") + .key_binding( + KeyBinding::for_action_in( + &menu::Confirm, + &focus_handle, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(false, window, cx); + })), + ) + }, + ) + } else if self.loading { + this.justify_between() + .child(loading_icon) + 
.child(delete_and_select_btns) + } else { + this.justify_between() + .child({ + let focus_handle = focus_handle.clone(); + Button::new("filter-remotes", "Filter Remotes") + .disabled(self.loading) + .toggle_state(self.display_remotes) + .key_binding( + KeyBinding::for_action_in( + &branch_picker::FilterRemotes, + &focus_handle, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(|_click, window, cx| { + window.dispatch_action( + branch_picker::FilterRemotes.boxed_clone(), + cx, + ); + }) + }) + .child(delete_and_select_btns) + } + }) + .into_any_element(), + ) + } + PickerState::NewBranch => { + let branch_from_default_button = + self.default_branch.as_ref().map(|default_branch| { + let button_label = format!("Create New From: {default_branch}"); + + Button::new("branch-from-default", button_label) + .key_binding( + KeyBinding::for_action_in( + &menu::SecondaryConfirm, + &focus_handle, + cx, + ) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(true, window, cx); + })) + }); + + Some( + footer_container() + .gap_1() + .justify_end() + .when_some(branch_from_default_button, |this, button| { + this.child(button) + }) + .child( + Button::new("branch-from-default", "Create") + .key_binding( + KeyBinding::for_action_in(&menu::Confirm, &focus_handle, cx) + .map(|kb| kb.size(rems_from_px(12.))), + ) + .on_click(cx.listener(|this, _, window, cx| { + this.delegate.confirm(false, window, cx); + })), + ) + .into_any_element(), + ) + } PickerState::CreateRemote(_) => Some( - h_flex() - .w_full() - .p_1p5() - .gap_1() - .border_t_1() - .border_color(cx.theme().colors().border_variant) + footer_container() + .justify_end() .child( Label::new("Choose a name for this remote repository") .size(LabelSize::Small) .color(Color::Muted), ) .child( - h_flex().w_full().justify_end().child( - Label::new("Save") - .size(LabelSize::Small) - .color(Color::Muted), - ), + Label::new("Save") + 
.size(LabelSize::Small) + .color(Color::Muted), ) - .into_any(), + .into_any_element(), ), - PickerState::NewRemote | PickerState::NewBranch => None, + PickerState::NewRemote => None, } } - - fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option { - None - } } #[cfg(test)] @@ -1506,6 +1658,7 @@ mod tests { let last_match = picker.delegate.matches.last().unwrap(); assert!(last_match.is_new_branch()); assert_eq!(last_match.name(), "new-feature-branch"); + // State is NewBranch because no existing branches fuzzy-match the query assert!(matches!(picker.delegate.state, PickerState::NewBranch)); picker.delegate.confirm(false, window, cx); }) @@ -1527,10 +1680,14 @@ mod tests { .unwrap() .unwrap(); - assert!( - branches - .into_iter() - .any(|branch| branch.name() == "new-feature-branch") + let new_branch = branches + .into_iter() + .find(|branch| branch.name() == "new-feature-branch") + .expect("new-feature-branch should exist"); + assert_eq!( + new_branch.ref_name.as_ref(), + "refs/heads/new-feature-branch", + "branch ref_name should not have duplicate refs/heads/ prefix" ); } diff --git a/crates/git_ui/src/commit_view.rs b/crates/git_ui/src/commit_view.rs index 31ac8139a63be218f652204ebe29d43e526c5a02..c637ea674f7e58954c186e1557df251d0d22d36b 100644 --- a/crates/git_ui/src/commit_view.rs +++ b/crates/git_ui/src/commit_view.rs @@ -1,7 +1,9 @@ use anyhow::{Context as _, Result}; use buffer_diff::{BufferDiff, BufferDiffSnapshot}; use editor::display_map::{BlockPlacement, BlockProperties, BlockStyle}; -use editor::{Addon, Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer}; +use editor::{ + Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer, multibuffer_context_lines, +}; use git::repository::{CommitDetails, CommitDiff, RepoPath}; use git::{GitHostingProviderRegistry, GitRemote, parse_git_remote_url}; use gpui::{ @@ -10,10 +12,9 @@ use gpui::{ PromptLevel, Render, Styled, Task, WeakEntity, Window, actions, }; use language::{ - Anchor, 
Buffer, Capability, DiskState, File, LanguageRegistry, LineEnding, ReplicaId, Rope, - TextBuffer, ToPoint, + Anchor, Buffer, Capability, DiskState, File, LanguageRegistry, LineEnding, OffsetRangeExt as _, + ReplicaId, Rope, TextBuffer, }; -use multi_buffer::ExcerptInfo; use multi_buffer::PathKey; use project::{Project, WorktreeId, git_store::Repository}; use std::{ @@ -22,11 +23,9 @@ use std::{ sync::Arc, }; use theme::ActiveTheme; -use ui::{ - Avatar, Button, ButtonCommon, Clickable, Color, Icon, IconName, IconSize, Label, - LabelCommon as _, LabelSize, SharedString, div, h_flex, v_flex, -}; +use ui::{Avatar, DiffStat, Tooltip, prelude::*}; use util::{ResultExt, paths::PathStyle, rel_path::RelPath, truncate_and_trailoff}; +use workspace::item::TabTooltipContent; use workspace::{ Item, ItemHandle, ItemNavHistory, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace, @@ -151,11 +150,10 @@ impl CommitView { let editor = cx.new(|cx| { let mut editor = Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx); + editor.disable_inline_diagnostics(); editor.set_expand_all_diff_hunks(cx); - editor.register_addon(CommitViewAddon { - multibuffer: multibuffer.downgrade(), - }); + editor }); let commit_sha = Arc::::from(commit.sha.as_ref()); @@ -206,33 +204,22 @@ impl CommitView { this.multibuffer.update(cx, |multibuffer, cx| { let snapshot = buffer.read(cx).snapshot(); let path = snapshot.file().unwrap().path().clone(); - - let hunks: Vec<_> = buffer_diff.read(cx).hunks(&snapshot, cx).collect(); - - let excerpt_ranges = if hunks.is_empty() { - vec![language::Point::zero()..snapshot.max_point()] - } else { - hunks - .into_iter() - .map(|hunk| { - let start = hunk.range.start.max(language::Point::new( - hunk.range.start.row.saturating_sub(3), - 0, - )); - let end_row = - (hunk.range.end.row + 3).min(snapshot.max_point().row); - let end = - language::Point::new(end_row, snapshot.line_len(end_row)); - start..end - }) - .collect() + let 
excerpt_ranges = { + let mut hunks = buffer_diff.read(cx).hunks(&snapshot, cx).peekable(); + if hunks.peek().is_none() { + vec![language::Point::zero()..snapshot.max_point()] + } else { + hunks + .map(|hunk| hunk.buffer_range.to_point(&snapshot)) + .collect::>() + } }; let _is_newly_added = multibuffer.set_excerpts_for_path( PathKey::with_sort_prefix(FILE_NAMESPACE_SORT_PREFIX, path), buffer, excerpt_ranges, - 0, + multibuffer_context_lines(cx), cx, ); multibuffer.add_diff(buffer_diff, cx); @@ -262,6 +249,8 @@ impl CommitView { this.editor.update(cx, |editor, cx| { editor.disable_header_for_buffer(message_buffer.read(cx).remote_id(), cx); + editor + .disable_indent_guides_for_buffer(message_buffer.read(cx).remote_id(), cx); editor.insert_blocks( [BlockProperties { @@ -357,6 +346,41 @@ impl CommitView { .into_any() } + fn calculate_changed_lines(&self, cx: &App) -> (u32, u32) { + let snapshot = self.multibuffer.read(cx).snapshot(cx); + let mut total_additions = 0u32; + let mut total_deletions = 0u32; + + let mut seen_buffers = std::collections::HashSet::new(); + for (_, buffer, _) in snapshot.excerpts() { + let buffer_id = buffer.remote_id(); + if !seen_buffers.insert(buffer_id) { + continue; + } + + let Some(diff) = snapshot.diff_for_buffer_id(buffer_id) else { + continue; + }; + + let base_text = diff.base_text(); + + for hunk in diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer) { + let added_rows = hunk.range.end.row.saturating_sub(hunk.range.start.row); + total_additions += added_rows; + + let base_start = base_text + .offset_to_point(hunk.diff_base_byte_range.start) + .row; + let base_end = base_text.offset_to_point(hunk.diff_base_byte_range.end).row; + let deleted_rows = base_end.saturating_sub(base_start); + + total_deletions += deleted_rows; + } + } + + (total_additions, total_deletions) + } + fn render_header(&self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let commit = &self.commit; let author_name = 
commit.author_name.clone(); @@ -380,46 +404,72 @@ impl CommitView { ) }); - v_flex() - .p_4() - .pl_0() - .gap_4() + let (additions, deletions) = self.calculate_changed_lines(cx); + + let commit_diff_stat = if additions > 0 || deletions > 0 { + Some(DiffStat::new( + "commit-diff-stat", + additions as usize, + deletions as usize, + )) + } else { + None + }; + + h_flex() .border_b_1() - .border_color(cx.theme().colors().border) + .border_color(cx.theme().colors().border_variant) + .child( + h_flex() + .w(self.editor.read(cx).last_gutter_dimensions().full_width()) + .justify_center() + .child(self.render_commit_avatar(&commit.sha, rems_from_px(48.), window, cx)), + ) .child( h_flex() + .py_4() + .pl_1() + .pr_4() + .w_full() .items_start() - .child( - h_flex() - .w(self.editor.read(cx).last_gutter_dimensions().full_width()) - .justify_center() - .child(self.render_commit_avatar( - &commit.sha, - gpui::rems(3.0), - window, - cx, - )), - ) + .justify_between() + .flex_wrap() .child( v_flex() - .gap_1() .child( h_flex() - .gap_3() - .items_baseline() + .gap_1() .child(Label::new(author_name).color(Color::Default)) .child( - Label::new(format!("commit {}", commit.sha)) - .color(Color::Muted), + Label::new(format!("Commit:{}", commit.sha)) + .color(Color::Muted) + .size(LabelSize::Small) + .truncate() + .buffer_font(cx), ), ) - .child(Label::new(date_string).color(Color::Muted)), + .child( + h_flex() + .gap_1p5() + .child( + Label::new(date_string) + .color(Color::Muted) + .size(LabelSize::Small), + ) + .child( + Label::new("•") + .color(Color::Ignored) + .size(LabelSize::Small), + ) + .children(commit_diff_stat), + ), ) - .child(div().flex_grow()) .children(github_url.map(|url| { Button::new("view_on_github", "View on GitHub") .icon(IconName::Github) - .style(ui::ButtonStyle::Subtle) + .icon_color(Color::Muted) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) .on_click(move |_, _, cx| cx.open_url(&url)) })), ) @@ -714,55 +764,6 @@ impl language::File 
for GitBlob { // } // } -struct CommitViewAddon { - multibuffer: WeakEntity, -} - -impl Addon for CommitViewAddon { - fn render_buffer_header_controls( - &self, - excerpt: &ExcerptInfo, - _window: &Window, - cx: &App, - ) -> Option { - let multibuffer = self.multibuffer.upgrade()?; - let snapshot = multibuffer.read(cx).snapshot(cx); - let excerpts = snapshot.excerpts().collect::>(); - let current_idx = excerpts.iter().position(|(id, _, _)| *id == excerpt.id)?; - let (_, _, current_range) = &excerpts[current_idx]; - - let start_row = current_range.context.start.to_point(&excerpt.buffer).row; - - let prev_end_row = if current_idx > 0 { - let (_, prev_buffer, prev_range) = &excerpts[current_idx - 1]; - if prev_buffer.remote_id() == excerpt.buffer_id { - prev_range.context.end.to_point(&excerpt.buffer).row - } else { - 0 - } - } else { - 0 - }; - - let skipped_lines = start_row.saturating_sub(prev_end_row); - - if skipped_lines > 0 { - Some( - Label::new(format!("{} unchanged lines", skipped_lines)) - .color(Color::Muted) - .size(LabelSize::Small) - .into_any_element(), - ) - } else { - None - } - } - - fn to_any(&self) -> &dyn Any { - self - } -} - async fn build_buffer( mut text: String, blob: Arc, @@ -865,13 +866,28 @@ impl Item for CommitView { fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString { let short_sha = self.commit.sha.get(0..7).unwrap_or(&*self.commit.sha); let subject = truncate_and_trailoff(self.commit.message.split('\n').next().unwrap(), 20); - format!("{short_sha} - {subject}").into() + format!("{short_sha} — {subject}").into() } - fn tab_tooltip_text(&self, _: &App) -> Option { + fn tab_tooltip_content(&self, _: &App) -> Option { let short_sha = self.commit.sha.get(0..16).unwrap_or(&*self.commit.sha); let subject = self.commit.message.split('\n').next().unwrap(); - Some(format!("{short_sha} - {subject}").into()) + + Some(TabTooltipContent::Custom(Box::new(Tooltip::element({ + let subject = subject.to_string(); + let short_sha = 
short_sha.to_string(); + + move |_, _| { + v_flex() + .child(Label::new(subject.clone())) + .child( + Label::new(short_sha.clone()) + .color(Color::Muted) + .size(LabelSize::Small), + ) + .into_any_element() + } + })))) } fn to_item_events(event: &EditorEvent, f: impl FnMut(ItemEvent)) { @@ -988,12 +1004,11 @@ impl Item for CommitView { impl Render for CommitView { fn render(&mut self, window: &mut Window, cx: &mut Context) -> impl IntoElement { let is_stash = self.stash.is_some(); - div() + + v_flex() .key_context(if is_stash { "StashDiff" } else { "CommitDiff" }) - .bg(cx.theme().colors().editor_background) - .flex() - .flex_col() .size_full() + .bg(cx.theme().colors().editor_background) .child(self.render_header(window, cx)) .child(div().flex_grow().child(self.editor.clone())) } @@ -1013,7 +1028,7 @@ impl EventEmitter for CommitViewToolbar {} impl Render for CommitViewToolbar { fn render(&mut self, _window: &mut Window, _cx: &mut Context) -> impl IntoElement { - div() + div().hidden() } } diff --git a/crates/git_ui/src/conflict_view.rs b/crates/git_ui/src/conflict_view.rs index 91cc3ce76b3f10aa310185b566b6c6086580b69c..2f954bfe1045d9819c1f7c276346a6f811c09108 100644 --- a/crates/git_ui/src/conflict_view.rs +++ b/crates/git_ui/src/conflict_view.rs @@ -372,7 +372,7 @@ fn render_conflict_buttons( .gap_1() .bg(cx.theme().colors().editor_background) .child( - Button::new("head", "Use HEAD") + Button::new("head", format!("Use {}", conflict.ours_branch_name)) .label_size(LabelSize::Small) .on_click({ let editor = editor.clone(); @@ -392,7 +392,7 @@ fn render_conflict_buttons( }), ) .child( - Button::new("origin", "Use Origin") + Button::new("origin", format!("Use {}", conflict.theirs_branch_name)) .label_size(LabelSize::Small) .on_click({ let editor = editor.clone(); diff --git a/crates/git_ui/src/file_history_view.rs b/crates/git_ui/src/file_history_view.rs index e34806aacae48122caf3a12246b04862898f2bed..5b3588d29678ec406749ec45be3de154fd71c5f8 100644 --- 
a/crates/git_ui/src/file_history_view.rs +++ b/crates/git_ui/src/file_history_view.rs @@ -267,15 +267,19 @@ impl FileHistoryView { .child(self.render_commit_avatar(&entry.sha, window, cx)) .child( h_flex() + .min_w_0() .w_full() .justify_between() .child( h_flex() + .min_w_0() + .w_full() .gap_1() .child( Label::new(entry.author_name.clone()) .size(LabelSize::Small) - .color(Color::Default), + .color(Color::Default) + .truncate(), ) .child( Label::new(&entry.subject) @@ -285,9 +289,11 @@ impl FileHistoryView { ), ) .child( - Label::new(relative_timestamp) - .size(LabelSize::Small) - .color(Color::Muted), + h_flex().flex_none().child( + Label::new(relative_timestamp) + .size(LabelSize::Small) + .color(Color::Muted), + ), ), ), ) diff --git a/crates/git_ui/src/project_diff.rs b/crates/git_ui/src/project_diff.rs index 0a8667ba6c753f9b7925948f212388f0668c1c92..e560bba0d36ad9901fffa9b5aad4dbd88e3108b6 100644 --- a/crates/git_ui/src/project_diff.rs +++ b/crates/git_ui/src/project_diff.rs @@ -34,7 +34,6 @@ use project::{ use settings::{Settings, SettingsStore}; use smol::future::yield_now; use std::any::{Any, TypeId}; -use std::ops::Range; use std::sync::Arc; use theme::ActiveTheme; use ui::{KeyBinding, Tooltip, prelude::*, vertical_divider}; @@ -46,6 +45,7 @@ use workspace::{ notifications::NotifyTaskExt, searchable::SearchableItemHandle, }; +use ztracing::instrument; actions!( git, @@ -469,6 +469,7 @@ impl ProjectDiff { } } + #[instrument(skip_all)] fn register_buffer( &mut self, path_key: PathKey, @@ -498,23 +499,30 @@ impl ProjectDiff { let snapshot = buffer.read(cx).snapshot(); let diff_read = diff.read(cx); - let diff_hunk_ranges = diff_read - .hunks_intersecting_range( - Anchor::min_max_range_for_buffer(diff_read.buffer_id), - &snapshot, - cx, - ) - .map(|diff_hunk| diff_hunk.buffer_range); - let conflicts = conflict_addon - .conflict_set(snapshot.remote_id()) - .map(|conflict_set| conflict_set.read(cx).snapshot().conflicts) - .unwrap_or_default(); - let conflicts 
= conflicts.iter().map(|conflict| conflict.range.clone()); - - let excerpt_ranges = - merge_anchor_ranges(diff_hunk_ranges.into_iter(), conflicts, &snapshot) - .map(|range| range.to_point(&snapshot)) - .collect::>(); + + let excerpt_ranges = { + let diff_hunk_ranges = diff_read + .hunks_intersecting_range( + Anchor::min_max_range_for_buffer(diff_read.buffer_id), + &snapshot, + cx, + ) + .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot)); + let conflicts = conflict_addon + .conflict_set(snapshot.remote_id()) + .map(|conflict_set| conflict_set.read(cx).snapshot().conflicts) + .unwrap_or_default(); + let mut conflicts = conflicts + .iter() + .map(|conflict| conflict.range.to_point(&snapshot)) + .peekable(); + + if conflicts.peek().is_some() { + conflicts.collect::>() + } else { + diff_hunk_ranges.collect() + } + }; let (was_empty, is_excerpt_newly_added) = self.multibuffer.update(cx, |multibuffer, cx| { let was_empty = multibuffer.is_empty(); @@ -1542,53 +1550,6 @@ mod preview { } } -fn merge_anchor_ranges<'a>( - left: impl 'a + Iterator>, - right: impl 'a + Iterator>, - snapshot: &'a language::BufferSnapshot, -) -> impl 'a + Iterator> { - let mut left = left.fuse().peekable(); - let mut right = right.fuse().peekable(); - - std::iter::from_fn(move || { - let Some(left_range) = left.peek() else { - return right.next(); - }; - let Some(right_range) = right.peek() else { - return left.next(); - }; - - let mut next_range = if left_range.start.cmp(&right_range.start, snapshot).is_lt() { - left.next().unwrap() - } else { - right.next().unwrap() - }; - - // Extend the basic range while there's overlap with a range from either stream. 
- loop { - if let Some(left_range) = left - .peek() - .filter(|range| range.start.cmp(&next_range.end, snapshot).is_le()) - .cloned() - { - left.next(); - next_range.end = left_range.end; - } else if let Some(right_range) = right - .peek() - .filter(|range| range.start.cmp(&next_range.end, snapshot).is_le()) - .cloned() - { - right.next(); - next_range.end = right_range.end; - } else { - break; - } - } - - Some(next_range) - }) -} - struct BranchDiffAddon { branch_diff: Entity, } diff --git a/crates/gpui/src/app/entity_map.rs b/crates/gpui/src/app/entity_map.rs index 81dbfdbf5733eed92a77fc2dc18fb971bd9bd4a7..8c1bdfa1cee509dcbc061200cb651ce5d3bf4fcd 100644 --- a/crates/gpui/src/app/entity_map.rs +++ b/crates/gpui/src/app/entity_map.rs @@ -584,7 +584,33 @@ impl AnyWeakEntity { }) } - /// Assert that entity referenced by this weak handle has been released. + /// Asserts that the entity referenced by this weak handle has been fully released. + /// + /// # Example + /// + /// ```ignore + /// let entity = cx.new(|_| MyEntity::new()); + /// let weak = entity.downgrade(); + /// drop(entity); + /// + /// // Verify the entity was released + /// weak.assert_released(); + /// ``` + /// + /// # Debugging Leaks + /// + /// If this method panics due to leaked handles, set the `LEAK_BACKTRACE` environment + /// variable to see where the leaked handles were allocated: + /// + /// ```bash + /// LEAK_BACKTRACE=1 cargo test my_test + /// ``` + /// + /// # Panics + /// + /// - Panics if any strong handles to the entity are still alive. + /// - Panics if the entity was recently dropped but cleanup hasn't completed yet + /// (resources are retained until the end of the effect cycle). #[cfg(any(test, feature = "leak-detection"))] pub fn assert_released(&self) { self.entity_ref_counts @@ -814,16 +840,70 @@ impl PartialOrd for WeakEntity { } } +/// Controls whether backtraces are captured when entity handles are created. 
+/// +/// Set the `LEAK_BACKTRACE` environment variable to any non-empty value to enable +/// backtrace capture. This helps identify where leaked handles were allocated. #[cfg(any(test, feature = "leak-detection"))] static LEAK_BACKTRACE: std::sync::LazyLock = std::sync::LazyLock::new(|| std::env::var("LEAK_BACKTRACE").is_ok_and(|b| !b.is_empty())); +/// Unique identifier for a specific entity handle instance. +/// +/// This is distinct from `EntityId` - while multiple handles can point to the same +/// entity (same `EntityId`), each handle has its own unique `HandleId`. #[cfg(any(test, feature = "leak-detection"))] #[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)] pub(crate) struct HandleId { - id: u64, // id of the handle itself, not the pointed at object + id: u64, } +/// Tracks entity handle allocations to detect leaks. +/// +/// The leak detector is enabled in tests and when the `leak-detection` feature is active. +/// It tracks every `Entity` and `AnyEntity` handle that is created and released, +/// allowing you to verify that all handles to an entity have been properly dropped. +/// +/// # How do leaks happen? +/// +/// Entities are reference-counted structures that can own other entities +/// allowing to form cycles. If such a strong-reference counted cycle is +/// created, all participating strong entities in this cycle will effectively +/// leak as they cannot be released anymore. +/// +/// # Usage +/// +/// You can use `WeakEntity::assert_released` or `AnyWeakEntity::assert_released` +/// to verify that an entity has been fully released: +/// +/// ```ignore +/// let entity = cx.new(|_| MyEntity::new()); +/// let weak = entity.downgrade(); +/// drop(entity); +/// +/// // This will panic if any handles to the entity are still alive +/// weak.assert_released(); +/// ``` +/// +/// # Debugging Leaks +/// +/// When a leak is detected, the detector will panic with information about the leaked +/// handles. 
To see where the leaked handles were allocated, set the `LEAK_BACKTRACE` +/// environment variable: +/// +/// ```bash +/// LEAK_BACKTRACE=1 cargo test my_test +/// ``` +/// +/// This will capture and display backtraces for each leaked handle, helping you +/// identify where handles were created but not released. +/// +/// # How It Works +/// +/// - When an entity handle is created (via `Entity::new`, `Entity::clone`, or +/// `WeakEntity::upgrade`), `handle_created` is called to register the handle. +/// - When a handle is dropped, `handle_released` removes it from tracking. +/// - `assert_released` verifies that no handles remain for a given entity. #[cfg(any(test, feature = "leak-detection"))] pub(crate) struct LeakDetector { next_handle_id: u64, @@ -832,6 +912,11 @@ pub(crate) struct LeakDetector { #[cfg(any(test, feature = "leak-detection"))] impl LeakDetector { + /// Records that a new handle has been created for the given entity. + /// + /// Returns a unique `HandleId` that must be passed to `handle_released` when + /// the handle is dropped. If `LEAK_BACKTRACE` is set, captures a backtrace + /// at the allocation site. #[track_caller] pub fn handle_created(&mut self, entity_id: EntityId) -> HandleId { let id = util::post_inc(&mut self.next_handle_id); @@ -844,23 +929,40 @@ impl LeakDetector { handle_id } + /// Records that a handle has been released (dropped). + /// + /// This removes the handle from tracking. The `handle_id` should be the same + /// one returned by `handle_created` when the handle was allocated. pub fn handle_released(&mut self, entity_id: EntityId, handle_id: HandleId) { let handles = self.entity_handles.entry(entity_id).or_default(); handles.remove(&handle_id); } + /// Asserts that all handles to the given entity have been released. + /// + /// # Panics + /// + /// Panics if any handles to the entity are still alive. 
The panic message + /// includes backtraces for each leaked handle if `LEAK_BACKTRACE` is set, + /// otherwise it suggests setting the environment variable to get more info. pub fn assert_released(&mut self, entity_id: EntityId) { + use std::fmt::Write as _; let handles = self.entity_handles.entry(entity_id).or_default(); if !handles.is_empty() { + let mut out = String::new(); for backtrace in handles.values_mut() { if let Some(mut backtrace) = backtrace.take() { backtrace.resolve(); - eprintln!("Leaked handle: {:#?}", backtrace); + writeln!(out, "Leaked handle:\n{:?}", backtrace).unwrap(); } else { - eprintln!("Leaked handle: export LEAK_BACKTRACE to find allocation site"); + writeln!( + out, + "Leaked handle: (export LEAK_BACKTRACE to find allocation site)" + ) + .unwrap(); } } - panic!(); + panic!("{out}"); } } } diff --git a/crates/gpui/src/geometry.rs b/crates/gpui/src/geometry.rs index 859ecb3d0e6c7b5c33f5765ce4c6295cef7fd566..4daec6d15367f3e12bab3cba658ccb3f261e9f46 100644 --- a/crates/gpui/src/geometry.rs +++ b/crates/gpui/src/geometry.rs @@ -3567,7 +3567,7 @@ pub const fn relative(fraction: f32) -> DefiniteLength { } /// Returns the Golden Ratio, i.e. `~(1.0 + sqrt(5.0)) / 2.0`. -pub fn phi() -> DefiniteLength { +pub const fn phi() -> DefiniteLength { relative(1.618_034) } @@ -3580,7 +3580,7 @@ pub fn phi() -> DefiniteLength { /// # Returns /// /// A `Rems` representing the specified number of rems. -pub fn rems(rems: f32) -> Rems { +pub const fn rems(rems: f32) -> Rems { Rems(rems) } @@ -3608,7 +3608,7 @@ pub const fn px(pixels: f32) -> Pixels { /// # Returns /// /// A `Length` variant set to `Auto`. 
-pub fn auto() -> Length { +pub const fn auto() -> Length { Length::Auto } diff --git a/crates/gpui/src/taffy.rs b/crates/gpui/src/taffy.rs index c3113ad2cb91ad8c9e29360812716114a7427052..11cb0872861321c3c06c3f8a5bf79fdd30eb2275 100644 --- a/crates/gpui/src/taffy.rs +++ b/crates/gpui/src/taffy.rs @@ -8,7 +8,6 @@ use std::{fmt::Debug, ops::Range}; use taffy::{ TaffyTree, TraversePartialTree as _, geometry::{Point as TaffyPoint, Rect as TaffyRect, Size as TaffySize}, - prelude::min_content, style::AvailableSpace as TaffyAvailableSpace, tree::NodeId, }; @@ -296,7 +295,7 @@ trait ToTaffy { impl ToTaffy for Style { fn to_taffy(&self, rem_size: Pixels, scale_factor: f32) -> taffy::style::Style { - use taffy::style_helpers::{length, minmax, repeat}; + use taffy::style_helpers::{fr, length, minmax, repeat}; fn to_grid_line( placement: &Range, @@ -310,8 +309,8 @@ impl ToTaffy for Style { fn to_grid_repeat( unit: &Option, ) -> Vec> { - // grid-template-columns: repeat(, minmax(0, min-content)); - unit.map(|count| vec![repeat(count, vec![minmax(length(0.0), min_content())])]) + // grid-template-columns: repeat(, minmax(0, 1fr)); + unit.map(|count| vec![repeat(count, vec![minmax(length(0.0), fr(1.0))])]) .unwrap_or_default() } diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index ce4ba4d3fa2aa31e05a73c905fbfb8fc4e6d8a8d..6cea6db9079a38ec5c539d9bcaf9e9b84428420b 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -33,8 +33,8 @@ pub enum IconName { ArrowRightLeft, ArrowUp, ArrowUpRight, - Attach, AtSign, + Attach, AudioOff, AudioOn, Backspace, @@ -44,8 +44,8 @@ pub enum IconName { BellRing, Binary, Blocks, - BoltOutlined, BoltFilled, + BoltOutlined, Book, BookCopy, CaseSensitive, @@ -79,9 +79,9 @@ pub enum IconName { Debug, DebugBreakpoint, DebugContinue, + DebugDetach, DebugDisabledBreakpoint, DebugDisabledLogBreakpoint, - DebugDetach, DebugIgnoreBreakpoints, DebugLogBreakpoint, DebugPause, @@ -135,10 +135,12 @@ pub enum IconName { 
GenericRestore, GitBranch, GitBranchAlt, + GitBranchPlus, Github, Hash, HistoryRerun, Image, + Inception, Indicator, Info, Json, @@ -146,6 +148,7 @@ pub enum IconName { Library, LineHeight, Link, + Linux, ListCollapse, ListFilter, ListTodo, @@ -171,8 +174,8 @@ pub enum IconName { PencilUnavailable, Person, Pin, - PlayOutlined, PlayFilled, + PlayOutlined, Plus, Power, Public, @@ -258,15 +261,14 @@ pub enum IconName { ZedAssistant, ZedBurnMode, ZedBurnModeOn, - ZedSrcCustom, - ZedSrcExtension, ZedPredict, ZedPredictDisabled, ZedPredictDown, ZedPredictError, ZedPredictUp, + ZedSrcCustom, + ZedSrcExtension, ZedXCopilot, - Linux, } impl IconName { diff --git a/crates/json_schema_store/src/json_schema_store.rs b/crates/json_schema_store/src/json_schema_store.rs index b44efb8b1b135850ab78460a428b5088e5fa0928..18041545ccd404eef0035b9b50ff8244d212fa0b 100644 --- a/crates/json_schema_store/src/json_schema_store.rs +++ b/crates/json_schema_store/src/json_schema_store.rs @@ -3,8 +3,9 @@ use std::{str::FromStr, sync::Arc}; use anyhow::{Context as _, Result}; use gpui::{App, AsyncApp, BorrowAppContext as _, Entity, WeakEntity}; -use language::LanguageRegistry; +use language::{LanguageRegistry, language_settings::all_language_settings}; use project::LspStore; +use util::schemars::{AllowTrailingCommas, DefaultDenyUnknownFields}; // Origin: https://github.com/SchemaStore/schemastore const TSCONFIG_SCHEMA: &str = include_str!("schemas/tsconfig.json"); @@ -159,14 +160,35 @@ pub fn resolve_schema_request_inner( } } "snippets" => snippet_provider::format::VsSnippetsFile::generate_json_schema(), + "jsonc" => jsonc_schema(), _ => { - anyhow::bail!("Unrecognized builtin JSON schema: {}", schema_name); + anyhow::bail!("Unrecognized builtin JSON schema: {schema_name}"); } }; Ok(schema) } -pub fn all_schema_file_associations(cx: &mut App) -> serde_json::Value { +const JSONC_LANGUAGE_NAME: &str = "JSONC"; + +pub fn all_schema_file_associations( + languages: &Arc, + cx: &mut App, +) -> 
serde_json::Value { + let extension_globs = languages + .available_language_for_name(JSONC_LANGUAGE_NAME) + .map(|language| language.matcher().path_suffixes.clone()) + .into_iter() + .flatten() + // Path suffixes can be entire file names or just their extensions. + .flat_map(|path_suffix| [format!("*.{path_suffix}"), path_suffix]); + let override_globs = all_language_settings(None, cx) + .file_types + .get(JSONC_LANGUAGE_NAME) + .into_iter() + .flat_map(|(_, glob_strings)| glob_strings) + .cloned(); + let jsonc_globs = extension_globs.chain(override_globs).collect::>(); + let mut file_associations = serde_json::json!([ { "fileMatch": [ @@ -211,6 +233,10 @@ pub fn all_schema_file_associations(cx: &mut App) -> serde_json::Value { "fileMatch": ["package.json"], "url": "zed://schemas/package_json" }, + { + "fileMatch": &jsonc_globs, + "url": "zed://schemas/jsonc" + }, ]); #[cfg(debug_assertions)] @@ -233,7 +259,7 @@ pub fn all_schema_file_associations(cx: &mut App) -> serde_json::Value { let file_name = normalized_action_name_to_file_name(normalized_name.clone()); serde_json::json!({ "fileMatch": [file_name], - "url": format!("zed://schemas/action/{}", normalized_name) + "url": format!("zed://schemas/action/{normalized_name}") }) }), ); @@ -249,6 +275,26 @@ fn package_json_schema() -> serde_json::Value { serde_json::Value::from_str(PACKAGE_JSON_SCHEMA).unwrap() } +fn jsonc_schema() -> serde_json::Value { + let generator = schemars::generate::SchemaSettings::draft2019_09() + .with_transform(DefaultDenyUnknownFields) + .with_transform(AllowTrailingCommas) + .into_generator(); + let meta_schema = generator + .settings() + .meta_schema + .as_ref() + .expect("meta_schema should be present in schemars settings") + .to_string(); + let defs = generator.definitions(); + let schema = schemars::json_schema!({ + "$schema": meta_schema, + "allowTrailingCommas": true, + "$defs": defs, + }); + serde_json::to_value(schema).unwrap() +} + fn generate_inspector_style_schema() -> 
serde_json::Value { let schema = schemars::generate::SchemaSettings::draft2019_09() .with_transform(util::schemars::DefaultDenyUnknownFields) diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index a46f7cc35912d4c6da42ba69f7aee6d25caca2e7..7166a01ef64bff9e47c70cac47910f714ae2dc39 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -4022,6 +4022,20 @@ impl BufferSnapshot { }) } + pub fn outline_items_as_offsets_containing( + &self, + range: Range, + include_extra_context: bool, + theme: Option<&SyntaxTheme>, + ) -> Vec> { + self.outline_items_containing_internal( + range, + include_extra_context, + theme, + |buffer, range| range.to_offset(buffer), + ) + } + fn outline_items_containing_internal( &self, range: Range, diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs index efef0a08127bc66f9c6d8f21fe5a545dbee20fb1..54e2ef4065460547f4a3f86db7d3a3986dff65eb 100644 --- a/crates/language/src/buffer_tests.rs +++ b/crates/language/src/buffer_tests.rs @@ -6,6 +6,7 @@ use futures::FutureExt as _; use gpui::{App, AppContext as _, BorrowAppContext, Entity}; use gpui::{HighlightStyle, TestAppContext}; use indoc::indoc; +use pretty_assertions::assert_eq; use proto::deserialize_operation; use rand::prelude::*; use regex::RegexBuilder; @@ -46,8 +47,7 @@ fn test_line_endings(cx: &mut gpui::App) { init_settings(cx, |_| {}); cx.new(|cx| { - let mut buffer = - Buffer::local("one\r\ntwo\rthree", cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local("one\r\ntwo\rthree", cx).with_language(rust_lang(), cx); assert_eq!(buffer.text(), "one\ntwo\nthree"); assert_eq!(buffer.line_ending(), LineEnding::Windows); @@ -151,7 +151,7 @@ fn test_select_language(cx: &mut App) { let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); registry.add(Arc::new(Language::new( LanguageConfig { - name: LanguageName::new("Rust"), + name: LanguageName::new_static("Rust"), 
matcher: LanguageMatcher { path_suffixes: vec!["rs".to_string()], ..Default::default() @@ -173,7 +173,7 @@ fn test_select_language(cx: &mut App) { ))); registry.add(Arc::new(Language::new( LanguageConfig { - name: LanguageName::new("Make"), + name: LanguageName::new_static("Make"), matcher: LanguageMatcher { path_suffixes: vec!["Makefile".to_string(), "mk".to_string()], ..Default::default() @@ -608,7 +608,7 @@ async fn test_normalize_whitespace(cx: &mut gpui::TestAppContext) { #[gpui::test] async fn test_reparse(cx: &mut gpui::TestAppContext) { let text = "fn a() {}"; - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); // Wait for the initial text to parse cx.executor().run_until_parked(); @@ -735,7 +735,7 @@ async fn test_reparse(cx: &mut gpui::TestAppContext) { #[gpui::test] async fn test_resetting_language(cx: &mut gpui::TestAppContext) { let buffer = cx.new(|cx| { - let mut buffer = Buffer::local("{}", cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local("{}", cx).with_language(rust_lang(), cx); buffer.set_sync_parse_timeout(Duration::ZERO); buffer }); @@ -783,29 +783,49 @@ async fn test_outline(cx: &mut gpui::TestAppContext) { "# .unindent(); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); - let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); + let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); + let outline = snapshot.outline(None); assert_eq!( outline .items .iter() - .map(|item| (item.text.as_str(), item.depth)) + .map(|item| ( + item.text.as_str(), + item.depth, + item.to_point(&snapshot).body_range(&snapshot) + .map(|range| minimize_space(&snapshot.text_for_range(range).collect::())) + )) .collect::>(), &[ - ("struct Person", 0), - 
("name", 1), - ("age", 1), - ("mod module", 0), - ("enum LoginState", 1), - ("LoggedOut", 2), - ("LoggingOn", 2), - ("LoggedIn", 2), - ("person", 3), - ("time", 3), - ("impl Eq for Person", 0), - ("impl Drop for Person", 0), - ("fn drop", 1), + ("struct Person", 0, Some("name: String, age: usize,".to_string())), + ("name", 1, None), + ("age", 1, None), + ( + "mod module", + 0, + Some( + "enum LoginState { LoggedOut, LoggingOn, LoggedIn { person: Person, time: Instant, } }".to_string() + ) + ), + ( + "enum LoginState", + 1, + Some("LoggedOut, LoggingOn, LoggedIn { person: Person, time: Instant, }".to_string()) + ), + ("LoggedOut", 2, None), + ("LoggingOn", 2, None), + ("LoggedIn", 2, Some("person: Person, time: Instant,".to_string())), + ("person", 3, None), + ("time", 3, None), + ("impl Eq for Person", 0, Some("".to_string())), + ( + "impl Drop for Person", + 0, + Some("fn drop(&mut self) { println!(\"bye\"); }".to_string()) + ), + ("fn drop", 1, Some("println!(\"bye\");".to_string())), ] ); @@ -840,6 +860,11 @@ async fn test_outline(cx: &mut gpui::TestAppContext) { ] ); + fn minimize_space(text: &str) -> String { + static WHITESPACE: LazyLock = LazyLock::new(|| Regex::new("[\\n\\s]+").unwrap()); + WHITESPACE.replace_all(text, " ").trim().to_string() + } + async fn search<'a>( outline: &'a Outline, query: &'a str, @@ -865,7 +890,7 @@ async fn test_outline_nodes_with_newlines(cx: &mut gpui::TestAppContext) { "# .unindent(); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None)); assert_eq!( @@ -945,7 +970,7 @@ fn test_outline_annotations(cx: &mut App) { "# .unindent(); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let outline = 
buffer.update(cx, |buffer, _| buffer.snapshot().outline(None)); assert_eq!( @@ -993,7 +1018,7 @@ async fn test_symbols_containing(cx: &mut gpui::TestAppContext) { "# .unindent(); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); // point is at the start of an item @@ -1068,7 +1093,7 @@ async fn test_symbols_containing(cx: &mut gpui::TestAppContext) { " .unindent(), ); - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); // note, it would be nice to actually return the method test in this @@ -1087,8 +1112,7 @@ fn test_text_objects(cx: &mut App) { false, ); - let buffer = - cx.new(|cx| Buffer::local(text.clone(), cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text.clone(), cx).with_language(rust_lang(), cx)); let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot()); let matches = snapshot @@ -1105,6 +1129,14 @@ fn test_text_objects(cx: &mut App) { "fn say() -> u8 { return /* hi */ 1 }", TextObject::AroundFunction ), + ( + "fn say() -> u8 { return /* hi */ 1 }", + TextObject::InsideClass + ), + ( + "impl Hello {\n fn say() -> u8 { return /* hi */ 1 }\n}", + TextObject::AroundClass + ), ], ) } @@ -1235,7 +1267,12 @@ fn test_enclosing_bracket_ranges(cx: &mut App) { #[gpui::test] fn test_enclosing_bracket_ranges_where_brackets_are_not_outermost_children(cx: &mut App) { let mut assert = |selection_text, bracket_pair_texts| { - assert_bracket_pairs(selection_text, bracket_pair_texts, javascript_lang(), cx) + assert_bracket_pairs( + selection_text, + bracket_pair_texts, + Arc::new(javascript_lang()), + cx, + ) }; assert( @@ -1268,7 +1305,7 @@ fn 
test_enclosing_bracket_ranges_where_brackets_are_not_outermost_children(cx: & fn test_range_for_syntax_ancestor(cx: &mut App) { cx.new(|cx| { let text = "fn a() { b(|c| {}) }"; - let buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); let snapshot = buffer.snapshot(); assert_eq!( @@ -1320,7 +1357,7 @@ fn test_autoindent_with_soft_tabs(cx: &mut App) { cx.new(|cx| { let text = "fn a() {}"; - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); buffer.edit([(8..8, "\n\n")], Some(AutoindentMode::EachLine), cx); assert_eq!(buffer.text(), "fn a() {\n \n}"); @@ -1362,7 +1399,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut App) { cx.new(|cx| { let text = "fn a() {}"; - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); buffer.edit([(8..8, "\n\n")], Some(AutoindentMode::EachLine), cx); assert_eq!(buffer.text(), "fn a() {\n\t\n}"); @@ -1411,7 +1448,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut App) .unindent(), cx, ) - .with_language(Arc::new(rust_lang()), cx); + .with_language(rust_lang(), cx); // Lines 2 and 3 don't match the indentation suggestion. When editing these lines, // their indentation is not adjusted. @@ -1552,7 +1589,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut App) .unindent(), cx, ) - .with_language(Arc::new(rust_lang()), cx); + .with_language(rust_lang(), cx); // Insert a closing brace. It is outdented. 
buffer.edit_via_marked_text( @@ -1615,7 +1652,7 @@ fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut Ap .unindent(), cx, ) - .with_language(Arc::new(rust_lang()), cx); + .with_language(rust_lang(), cx); // Regression test: line does not get outdented due to syntax error buffer.edit_via_marked_text( @@ -1674,7 +1711,7 @@ fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut App) { .unindent(), cx, ) - .with_language(Arc::new(rust_lang()), cx); + .with_language(rust_lang(), cx); buffer.edit_via_marked_text( &" @@ -1724,7 +1761,7 @@ fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut App) { cx.new(|cx| { let text = "a\nb"; - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); buffer.edit( [(0..1, "\n"), (2..3, "\n")], Some(AutoindentMode::EachLine), @@ -1750,7 +1787,7 @@ fn test_autoindent_multi_line_insertion(cx: &mut App) { " .unindent(); - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); buffer.edit( [(Point::new(3, 0)..Point::new(3, 0), "e(\n f()\n);\n")], Some(AutoindentMode::EachLine), @@ -1787,7 +1824,7 @@ fn test_autoindent_block_mode(cx: &mut App) { } "# .unindent(); - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); // When this text was copied, both of the quotation marks were at the same // indent level, but the indentation of the first line was not included in @@ -1870,7 +1907,7 @@ fn test_autoindent_block_mode_with_newline(cx: &mut App) { } "# .unindent(); - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); // First line contains just '\n', it's indentation is stored in 
"original_indent_columns" let original_indent_columns = vec![Some(4)]; @@ -1922,7 +1959,7 @@ fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut App) { } "# .unindent(); - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); // The original indent columns are not known, so this text is // auto-indented in a block as if the first line was copied in @@ -2013,7 +2050,7 @@ fn test_autoindent_block_mode_multiple_adjacent_ranges(cx: &mut App) { false, ); - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); buffer.edit( [ @@ -2027,7 +2064,7 @@ fn test_autoindent_block_mode_multiple_adjacent_ranges(cx: &mut App) { cx, ); - pretty_assertions::assert_eq!( + assert_eq!( buffer.text(), " mod numbers { @@ -2221,7 +2258,7 @@ async fn test_async_autoindents_preserve_preview(cx: &mut TestAppContext) { // Then we request that a preview tab be preserved for the new version, even though it's edited. let buffer = cx.new(|cx| { let text = "fn a() {}"; - let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx); + let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx); // This causes autoindent to be async. 
buffer.set_sync_parse_timeout(Duration::ZERO); @@ -2679,7 +2716,7 @@ fn test_language_at_with_hidden_languages(cx: &mut App) { .unindent(); let language_registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - language_registry.add(Arc::new(markdown_lang())); + language_registry.add(markdown_lang()); language_registry.add(Arc::new(markdown_inline_lang())); let mut buffer = Buffer::local(text, cx); @@ -2721,9 +2758,9 @@ fn test_language_at_for_markdown_code_block(cx: &mut App) { .unindent(); let language_registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - language_registry.add(Arc::new(markdown_lang())); + language_registry.add(markdown_lang()); language_registry.add(Arc::new(markdown_inline_lang())); - language_registry.add(Arc::new(rust_lang())); + language_registry.add(rust_lang()); let mut buffer = Buffer::local(text, cx); buffer.set_language_registry(language_registry.clone()); @@ -3120,7 +3157,7 @@ async fn test_preview_edits(cx: &mut TestAppContext) { cx: &mut TestAppContext, assert_fn: impl Fn(HighlightedText), ) { - let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx)); + let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx)); let edits = buffer.read_with(cx, |buffer, _| { edits .into_iter() @@ -3531,7 +3568,7 @@ let word=öäpple.bar你 Öäpple word2-öÄpPlE-Pizza-word ÖÄPPLE word "#; let buffer = cx.new(|cx| { - let buffer = Buffer::local(contents, cx).with_language(Arc::new(rust_lang()), cx); + let buffer = Buffer::local(contents, cx).with_language(rust_lang(), cx); assert_eq!(buffer.text(), contents); buffer.check_invariants(); buffer @@ -3691,7 +3728,7 @@ fn ruby_lang() -> Language { fn html_lang() -> Language { Language::new( LanguageConfig { - name: LanguageName::new("HTML"), + name: LanguageName::new_static("HTML"), block_comment: Some(BlockCommentConfig { start: "