Detailed changes
@@ -211,7 +211,7 @@ dependencies = [
"worktree",
"zed_env_vars",
"zlog",
- "zstd 0.11.2+zstd.1.5.2",
+ "zstd",
]
[[package]]
@@ -680,21 +680,6 @@ dependencies = [
"syn 2.0.106",
]
-[[package]]
-name = "argminmax"
-version = "0.6.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "70f13d10a41ac8d2ec79ee34178d61e6f47a29c2edfe7ef1721c7383b0359e65"
-dependencies = [
- "num-traits",
-]
-
-[[package]]
-name = "array-init-cursor"
-version = "0.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ed51fe0f224d1d4ea768be38c51f9f831dee9d05c163c11fba0b8c44387b1fc3"
-
[[package]]
name = "arraydeque"
version = "0.5.1"
@@ -1278,15 +1263,6 @@ dependencies = [
"num-traits",
]
-[[package]]
-name = "atoi_simd"
-version = "0.16.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c2a49e05797ca52e312a0c658938b7d00693ef037799ef7187678f212d7684cf"
-dependencies = [
- "debug_unsafe",
-]
-
[[package]]
name = "atomic"
version = "0.5.3"
@@ -2070,26 +2046,6 @@ dependencies = [
"serde",
]
-[[package]]
-name = "bincode"
-version = "2.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740"
-dependencies = [
- "bincode_derive",
- "serde",
- "unty",
-]
-
-[[package]]
-name = "bincode_derive"
-version = "2.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09"
-dependencies = [
- "virtue",
-]
-
[[package]]
name = "bindgen"
version = "0.71.1"
@@ -2242,19 +2198,6 @@ dependencies = [
"profiling",
]
-[[package]]
-name = "blake3"
-version = "1.8.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0"
-dependencies = [
- "arrayref",
- "arrayvec",
- "cc",
- "cfg-if",
- "constant_time_eq 0.3.1",
-]
-
[[package]]
name = "block"
version = "0.1.6"
@@ -2344,12 +2287,6 @@ dependencies = [
"syn 2.0.106",
]
-[[package]]
-name = "boxcar"
-version = "0.2.14"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "36f64beae40a84da1b4b26ff2761a5b895c12adc41dc25aaee1c4f2bbfe97a6e"
-
[[package]]
name = "breadcrumbs"
version = "0.1.0"
@@ -2516,9 +2453,6 @@ name = "bytes"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
-dependencies = [
- "serde",
-]
[[package]]
name = "bytes-utils"
@@ -2805,15 +2739,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
-[[package]]
-name = "castaway"
-version = "0.2.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
-dependencies = [
- "rustversion",
-]
-
[[package]]
name = "cbc"
version = "0.1.2"
@@ -2942,16 +2867,6 @@ dependencies = [
"windows-link 0.2.1",
]
-[[package]]
-name = "chrono-tz"
-version = "0.10.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3"
-dependencies = [
- "chrono",
- "phf 0.12.1",
-]
-
[[package]]
name = "chunked_transfer"
version = "1.5.0"
@@ -3201,12 +3116,7 @@ dependencies = [
"anyhow",
"cloud_llm_client",
"indoc",
- "ordered-float 2.10.1",
- "rustc-hash 2.1.1",
- "schemars",
"serde",
- "serde_json",
- "strum 0.27.2",
]
[[package]]
@@ -3314,8 +3224,8 @@ name = "codestral"
version = "0.1.0"
dependencies = [
"anyhow",
- "edit_prediction",
"edit_prediction_context",
+ "edit_prediction_types",
"futures 0.3.31",
"gpui",
"http_client",
@@ -3505,17 +3415,6 @@ dependencies = [
"memchr",
]
-[[package]]
-name = "comfy-table"
-version = "7.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b03b7db8e0b4b2fdad6c551e634134e99ec000e5c8c3b6856c65e8bbaded7a3b"
-dependencies = [
- "crossterm",
- "unicode-segmentation",
- "unicode-width",
-]
-
[[package]]
name = "command-fds"
version = "0.3.2"
@@ -3569,21 +3468,6 @@ dependencies = [
"workspace",
]
-[[package]]
-name = "compact_str"
-version = "0.9.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a"
-dependencies = [
- "castaway",
- "cfg-if",
- "itoa",
- "rustversion",
- "ryu",
- "serde",
- "static_assertions",
-]
-
[[package]]
name = "component"
version = "0.1.0"
@@ -3689,12 +3573,6 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
-[[package]]
-name = "constant_time_eq"
-version = "0.3.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
-
[[package]]
name = "context_server"
version = "0.1.0"
@@ -3747,7 +3625,7 @@ dependencies = [
"command_palette_hooks",
"ctor",
"dirs 4.0.0",
- "edit_prediction",
+ "edit_prediction_types",
"editor",
"fs",
"futures 0.3.31",
@@ -4160,7 +4038,7 @@ dependencies = [
name = "crashes"
version = "0.1.0"
dependencies = [
- "bincode 1.3.3",
+ "bincode",
"cfg-if",
"crash-handler",
"extension_host",
@@ -4174,7 +4052,7 @@ dependencies = [
"smol",
"system_specs",
"windows 0.61.3",
- "zstd 0.11.2+zstd.1.5.2",
+ "zstd",
]
[[package]]
@@ -4319,29 +4197,6 @@ version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
-[[package]]
-name = "crossterm"
-version = "0.29.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b"
-dependencies = [
- "bitflags 2.9.4",
- "crossterm_winapi",
- "document-features",
- "parking_lot",
- "rustix 1.1.2",
- "winapi",
-]
-
-[[package]]
-name = "crossterm_winapi"
-version = "0.9.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b"
-dependencies = [
- "winapi",
-]
-
[[package]]
name = "crunchy"
version = "0.2.4"
@@ -4696,12 +4551,6 @@ dependencies = [
"util",
]
-[[package]]
-name = "debug_unsafe"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "85d3cef41d236720ed453e102153a53e4cc3d2fde848c0078a50cf249e8e3e5b"
-
[[package]]
name = "debugger_tools"
version = "0.1.0"
@@ -5109,15 +4958,6 @@ dependencies = [
"zlog",
]
-[[package]]
-name = "document-features"
-version = "0.2.11"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "95249b50c6c185bee49034bcb378a49dc2b5dff0be90ff6616d31d64febab05d"
-dependencies = [
- "litrs",
-]
-
[[package]]
name = "documented"
version = "0.9.2"
@@ -5267,86 +5107,112 @@ dependencies = [
name = "edit_prediction"
version = "0.1.0"
dependencies = [
- "client",
- "gpui",
- "language",
-]
-
-[[package]]
-name = "edit_prediction_button"
-version = "0.1.0"
-dependencies = [
+ "ai_onboarding",
"anyhow",
+ "arrayvec",
+ "brotli",
"client",
+ "clock",
+ "cloud_api_types",
"cloud_llm_client",
- "codestral",
+ "cloud_zeta2_prompt",
+ "collections",
"copilot",
- "edit_prediction",
- "editor",
+ "credentials_provider",
+ "ctor",
+ "db",
+ "edit_prediction_context",
+ "edit_prediction_types",
"feature_flags",
"fs",
"futures 0.3.31",
"gpui",
"indoc",
+ "itertools 0.14.0",
"language",
+ "language_model",
+ "log",
"lsp",
"menu",
- "paths",
+ "open_ai",
+ "parking_lot",
+ "postage",
+ "pretty_assertions",
"project",
+ "rand 0.9.2",
"regex",
+ "release_channel",
+ "semver",
+ "serde",
"serde_json",
"settings",
- "supermaven",
+ "smol",
+ "strsim",
+ "strum 0.27.2",
"telemetry",
- "theme",
+ "telemetry_events",
+ "thiserror 2.0.17",
"ui",
- "ui_input",
"util",
+ "uuid",
"workspace",
+ "worktree",
"zed_actions",
- "zeta",
+ "zlog",
]
[[package]]
-name = "edit_prediction_context"
+name = "edit_prediction_cli"
version = "0.1.0"
dependencies = [
"anyhow",
- "arrayvec",
+ "chrono",
"clap",
+ "client",
"cloud_llm_client",
+ "cloud_zeta2_prompt",
"collections",
+ "debug_adapter_extension",
+ "edit_prediction",
+ "edit_prediction_context",
+ "extension",
+ "fs",
"futures 0.3.31",
"gpui",
- "hashbrown 0.15.5",
+ "gpui_tokio",
"indoc",
- "itertools 0.14.0",
"language",
+ "language_extension",
+ "language_model",
+ "language_models",
+ "languages",
"log",
- "ordered-float 2.10.1",
- "postage",
+ "node_runtime",
+ "paths",
"pretty_assertions",
"project",
- "regex",
+ "prompt_store",
+ "pulldown-cmark 0.12.2",
+ "release_channel",
+ "reqwest_client",
"serde",
"serde_json",
"settings",
- "slotmap",
- "strum 0.27.2",
- "text",
- "tree-sitter",
- "tree-sitter-c",
- "tree-sitter-cpp",
- "tree-sitter-go",
+ "shellexpand 2.1.2",
+ "smol",
+ "terminal_view",
+ "toml 0.8.23",
"util",
+ "watch",
"zlog",
]
[[package]]
-name = "edit_prediction_context2"
+name = "edit_prediction_context"
version = "0.1.0"
dependencies = [
"anyhow",
+ "cloud_llm_client",
"collections",
"env_logger 0.11.8",
"futures 0.3.31",
@@ -5368,6 +5234,56 @@ dependencies = [
"zlog",
]
+[[package]]
+name = "edit_prediction_types"
+version = "0.1.0"
+dependencies = [
+ "client",
+ "gpui",
+ "language",
+]
+
+[[package]]
+name = "edit_prediction_ui"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "buffer_diff",
+ "client",
+ "cloud_llm_client",
+ "cloud_zeta2_prompt",
+ "codestral",
+ "command_palette_hooks",
+ "copilot",
+ "edit_prediction",
+ "edit_prediction_types",
+ "editor",
+ "feature_flags",
+ "fs",
+ "futures 0.3.31",
+ "gpui",
+ "indoc",
+ "language",
+ "lsp",
+ "markdown",
+ "menu",
+ "multi_buffer",
+ "paths",
+ "project",
+ "regex",
+ "serde_json",
+ "settings",
+ "supermaven",
+ "telemetry",
+ "text",
+ "theme",
+ "ui",
+ "ui_input",
+ "util",
+ "workspace",
+ "zed_actions",
+]
+
[[package]]
name = "editor"
version = "0.1.0"
@@ -5384,7 +5300,7 @@ dependencies = [
"ctor",
"dap",
"db",
- "edit_prediction",
+ "edit_prediction_types",
"emojis",
"feature_flags",
"file_icons",
@@ -5723,14 +5639,8 @@ dependencies = [
]
[[package]]
-name = "ethnum"
-version = "1.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ca81e6b4777c89fd810c25a4be2b1bd93ea034fbe58e6a75216a34c6b82c539b"
-
-[[package]]
-name = "euclid"
-version = "0.22.11"
+name = "euclid"
+version = "0.22.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad9cdb4b747e485a12abb0e6566612956c7a1bafa3bdb8d682c5b6d403589e48"
dependencies = [
@@ -6012,12 +5922,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
-[[package]]
-name = "fallible-streaming-iterator"
-version = "0.1.9"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
-
[[package]]
name = "fancy-regex"
version = "0.16.2"
@@ -6029,12 +5933,6 @@ dependencies = [
"regex-syntax",
]
-[[package]]
-name = "fast-float2"
-version = "0.2.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55"
-
[[package]]
name = "fast-srgb8"
version = "1.0.0"
@@ -6210,7 +6108,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9"
dependencies = [
"crc32fast",
- "libz-rs-sys",
"miniz_oxide",
]
@@ -6467,16 +6364,6 @@ dependencies = [
"winapi",
]
-[[package]]
-name = "fs4"
-version = "0.13.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4"
-dependencies = [
- "rustix 1.1.2",
- "windows-sys 0.59.0",
-]
-
[[package]]
name = "fs_benchmarks"
version = "0.1.0"
@@ -7540,7 +7427,6 @@ dependencies = [
"allocator-api2",
"equivalent",
"foldhash 0.1.5",
- "rayon",
"serde",
]
@@ -7652,7 +7538,7 @@ version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13c255bdf46e07fb840d120a36dcc81f385140d7191c76a7391672675c01a55d"
dependencies = [
- "bincode 1.3.3",
+ "bincode",
"byteorder",
"heed-traits",
"serde",
@@ -8412,7 +8298,7 @@ version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8251fb7bcd9ccd3725ed8deae9fe7db8e586495c9eb5b0c52e6233e5e75ea"
dependencies = [
- "bincode 1.3.3",
+ "bincode",
"crossbeam-channel",
"fnv",
"lazy_static",
@@ -9256,15 +9142,6 @@ dependencies = [
"webrtc-sys",
]
-[[package]]
-name = "libz-rs-sys"
-version = "0.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd"
-dependencies = [
- "zlib-rs",
-]
-
[[package]]
name = "libz-sys"
version = "1.1.22"
@@ -9327,12 +9204,6 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
-[[package]]
-name = "litrs"
-version = "0.4.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f5e54036fe321fd421e10d732f155734c4e4afd610dd556d9a82833ab3ee0bed"
-
[[package]]
name = "livekit"
version = "0.7.8"
@@ -9624,25 +9495,6 @@ dependencies = [
"num-traits",
]
-[[package]]
-name = "lz4"
-version = "1.28.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4"
-dependencies = [
- "lz4-sys",
-]
-
-[[package]]
-name = "lz4-sys"
-version = "1.11.1+lz4-1.10.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6"
-dependencies = [
- "cc",
- "libc",
-]
-
[[package]]
name = "mac"
version = "0.1.1"
@@ -10505,15 +10357,6 @@ name = "notify-types"
version = "2.0.0"
source = "git+https://github.com/zed-industries/notify.git?rev=b4588b2e5aee68f4c0e100f140e808cbce7b1419#b4588b2e5aee68f4c0e100f140e808cbce7b1419"
-[[package]]
-name = "now"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0"
-dependencies = [
- "chrono",
-]
-
[[package]]
name = "ntapi"
version = "0.4.1"
@@ -10909,41 +10752,6 @@ dependencies = [
"memchr",
]
-[[package]]
-name = "object_store"
-version = "0.12.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c1be0c6c22ec0817cdc77d3842f721a17fd30ab6965001415b5402a74e6b740"
-dependencies = [
- "async-trait",
- "base64 0.22.1",
- "bytes 1.10.1",
- "chrono",
- "form_urlencoded",
- "futures 0.3.31",
- "http 1.3.1",
- "http-body-util",
- "humantime",
- "hyper 1.7.0",
- "itertools 0.14.0",
- "parking_lot",
- "percent-encoding",
- "quick-xml 0.38.3",
- "rand 0.9.2",
- "reqwest 0.12.24",
- "ring",
- "serde",
- "serde_json",
- "serde_urlencoded",
- "thiserror 2.0.17",
- "tokio",
- "tracing",
- "url",
- "walkdir",
- "wasm-bindgen-futures",
- "web-time",
-]
-
[[package]]
name = "ollama"
version = "0.1.0"
@@ -12184,16 +11992,6 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6"
-[[package]]
-name = "planus"
-version = "1.1.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3daf8e3d4b712abe1d690838f6e29fb76b76ea19589c4afa39ec30e12f62af71"
-dependencies = [
- "array-init-cursor",
- "hashbrown 0.15.5",
-]
-
[[package]]
name = "plist"
version = "1.8.0"
@@ -12261,520 +12059,6 @@ dependencies = [
"miniz_oxide",
]
-[[package]]
-name = "polars"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a5f7feb5d56b954e691dff22a8b2d78d77433dcc93c35fe21c3777fdc121b697"
-dependencies = [
- "getrandom 0.2.16",
- "getrandom 0.3.4",
- "polars-arrow",
- "polars-core",
- "polars-error",
- "polars-io",
- "polars-lazy",
- "polars-ops",
- "polars-parquet",
- "polars-sql",
- "polars-time",
- "polars-utils",
- "version_check",
-]
-
-[[package]]
-name = "polars-arrow"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "32b4fed2343961b3eea3db2cee165540c3e1ad9d5782350cc55a9e76cf440148"
-dependencies = [
- "atoi_simd",
- "bitflags 2.9.4",
- "bytemuck",
- "chrono",
- "chrono-tz",
- "dyn-clone",
- "either",
- "ethnum",
- "getrandom 0.2.16",
- "getrandom 0.3.4",
- "hashbrown 0.15.5",
- "itoa",
- "lz4",
- "num-traits",
- "polars-arrow-format",
- "polars-error",
- "polars-schema",
- "polars-utils",
- "serde",
- "simdutf8",
- "streaming-iterator",
- "strum_macros 0.27.2",
- "version_check",
- "zstd 0.13.3",
-]
-
-[[package]]
-name = "polars-arrow-format"
-version = "0.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a556ac0ee744e61e167f34c1eb0013ce740e0ee6cd8c158b2ec0b518f10e6675"
-dependencies = [
- "planus",
- "serde",
-]
-
-[[package]]
-name = "polars-compute"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "138785beda4e4a90a025219f09d0d15a671b2be9091513ede58e05db6ad4413f"
-dependencies = [
- "atoi_simd",
- "bytemuck",
- "chrono",
- "either",
- "fast-float2",
- "hashbrown 0.15.5",
- "itoa",
- "num-traits",
- "polars-arrow",
- "polars-error",
- "polars-utils",
- "rand 0.9.2",
- "ryu",
- "serde",
- "skiplist",
- "strength_reduce",
- "strum_macros 0.27.2",
- "version_check",
-]
-
-[[package]]
-name = "polars-core"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e77b1f08ef6dbb032bb1d0d3365464be950df9905f6827a95b24c4ca5518901d"
-dependencies = [
- "bitflags 2.9.4",
- "boxcar",
- "bytemuck",
- "chrono",
- "chrono-tz",
- "comfy-table",
- "either",
- "hashbrown 0.15.5",
- "indexmap",
- "itoa",
- "num-traits",
- "polars-arrow",
- "polars-compute",
- "polars-dtype",
- "polars-error",
- "polars-row",
- "polars-schema",
- "polars-utils",
- "rand 0.9.2",
- "rand_distr",
- "rayon",
- "regex",
- "serde",
- "serde_json",
- "strum_macros 0.27.2",
- "uuid",
- "version_check",
- "xxhash-rust",
-]
-
-[[package]]
-name = "polars-dtype"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89c43d0ea57168be4546c4d8064479ed8b29a9c79c31a0c7c367ee734b9b7158"
-dependencies = [
- "boxcar",
- "hashbrown 0.15.5",
- "polars-arrow",
- "polars-error",
- "polars-utils",
- "serde",
- "uuid",
-]
-
-[[package]]
-name = "polars-error"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b9cb5d98f59f8b94673ee391840440ad9f0d2170afced95fc98aa86f895563c0"
-dependencies = [
- "object_store",
- "parking_lot",
- "polars-arrow-format",
- "regex",
- "signal-hook",
- "simdutf8",
-]
-
-[[package]]
-name = "polars-expr"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "343931b818cf136349135ba11dbc18c27683b52c3477b1ba8ca606cf5ab1965c"
-dependencies = [
- "bitflags 2.9.4",
- "hashbrown 0.15.5",
- "num-traits",
- "polars-arrow",
- "polars-compute",
- "polars-core",
- "polars-io",
- "polars-ops",
- "polars-plan",
- "polars-row",
- "polars-time",
- "polars-utils",
- "rand 0.9.2",
- "rayon",
- "recursive",
-]
-
-[[package]]
-name = "polars-io"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "10388c64b8155122488229a881d1c6f4fdc393bc988e764ab51b182fcb2307e4"
-dependencies = [
- "async-trait",
- "atoi_simd",
- "blake3",
- "bytes 1.10.1",
- "chrono",
- "fast-float2",
- "fs4",
- "futures 0.3.31",
- "glob",
- "hashbrown 0.15.5",
- "home",
- "itoa",
- "memchr",
- "memmap2",
- "num-traits",
- "object_store",
- "percent-encoding",
- "polars-arrow",
- "polars-core",
- "polars-error",
- "polars-parquet",
- "polars-schema",
- "polars-time",
- "polars-utils",
- "rayon",
- "regex",
- "reqwest 0.12.24",
- "ryu",
- "serde",
- "serde_json",
- "simdutf8",
- "tokio",
- "tokio-util",
- "url",
-]
-
-[[package]]
-name = "polars-lazy"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0fb6e2c6c2fa4ea0c660df1c06cf56960c81e7c2683877995bae3d4e3d408147"
-dependencies = [
- "bitflags 2.9.4",
- "chrono",
- "either",
- "memchr",
- "polars-arrow",
- "polars-compute",
- "polars-core",
- "polars-expr",
- "polars-io",
- "polars-mem-engine",
- "polars-ops",
- "polars-plan",
- "polars-stream",
- "polars-time",
- "polars-utils",
- "rayon",
- "version_check",
-]
-
-[[package]]
-name = "polars-mem-engine"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "20a856e98e253587c28d8132a5e7e5a75cb2c44731ca090f1481d45f1d123771"
-dependencies = [
- "futures 0.3.31",
- "memmap2",
- "polars-arrow",
- "polars-core",
- "polars-error",
- "polars-expr",
- "polars-io",
- "polars-ops",
- "polars-plan",
- "polars-time",
- "polars-utils",
- "rayon",
- "recursive",
- "tokio",
-]
-
-[[package]]
-name = "polars-ops"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "acf6062173fdc9ba05775548beb66e76643a148d9aeadc9984ed712bc4babd76"
-dependencies = [
- "argminmax",
- "base64 0.22.1",
- "bytemuck",
- "chrono",
- "chrono-tz",
- "either",
- "hashbrown 0.15.5",
- "hex",
- "indexmap",
- "libm",
- "memchr",
- "num-traits",
- "polars-arrow",
- "polars-compute",
- "polars-core",
- "polars-error",
- "polars-schema",
- "polars-utils",
- "rayon",
- "regex",
- "regex-syntax",
- "strum_macros 0.27.2",
- "unicode-normalization",
- "unicode-reverse",
- "version_check",
-]
-
-[[package]]
-name = "polars-parquet"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc1d769180dec070df0dc4b89299b364bf2cfe32b218ecc4ddd8f1a49ae60669"
-dependencies = [
- "async-stream",
- "base64 0.22.1",
- "brotli",
- "bytemuck",
- "ethnum",
- "flate2",
- "futures 0.3.31",
- "hashbrown 0.15.5",
- "lz4",
- "num-traits",
- "polars-arrow",
- "polars-compute",
- "polars-error",
- "polars-parquet-format",
- "polars-utils",
- "serde",
- "simdutf8",
- "snap",
- "streaming-decompression",
- "zstd 0.13.3",
-]
-
-[[package]]
-name = "polars-parquet-format"
-version = "0.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c025243dcfe8dbc57e94d9f82eb3bef10b565ab180d5b99bed87fd8aea319ce1"
-dependencies = [
- "async-trait",
- "futures 0.3.31",
-]
-
-[[package]]
-name = "polars-plan"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1cd3a2e33ae4484fe407ab2d2ba5684f0889d1ccf3ad6b844103c03638e6d0a0"
-dependencies = [
- "bitflags 2.9.4",
- "bytemuck",
- "bytes 1.10.1",
- "chrono",
- "chrono-tz",
- "either",
- "futures 0.3.31",
- "hashbrown 0.15.5",
- "memmap2",
- "num-traits",
- "percent-encoding",
- "polars-arrow",
- "polars-compute",
- "polars-core",
- "polars-error",
- "polars-io",
- "polars-ops",
- "polars-parquet",
- "polars-time",
- "polars-utils",
- "rayon",
- "recursive",
- "regex",
- "sha2",
- "strum_macros 0.27.2",
- "version_check",
-]
-
-[[package]]
-name = "polars-row"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "18734f17e0e348724df3ae65f3ee744c681117c04b041cac969dfceb05edabc0"
-dependencies = [
- "bitflags 2.9.4",
- "bytemuck",
- "polars-arrow",
- "polars-compute",
- "polars-dtype",
- "polars-error",
- "polars-utils",
-]
-
-[[package]]
-name = "polars-schema"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8e6c1ab13e04d5167661a9854ed1ea0482b2ed9b8a0f1118dabed7cd994a85e3"
-dependencies = [
- "indexmap",
- "polars-error",
- "polars-utils",
- "serde",
- "version_check",
-]
-
-[[package]]
-name = "polars-sql"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c4e7766da02cc1d464994404d3e88a7a0ccd4933df3627c325480fbd9bbc0a11"
-dependencies = [
- "bitflags 2.9.4",
- "hex",
- "polars-core",
- "polars-error",
- "polars-lazy",
- "polars-ops",
- "polars-plan",
- "polars-time",
- "polars-utils",
- "rand 0.9.2",
- "regex",
- "serde",
- "sqlparser",
-]
-
-[[package]]
-name = "polars-stream"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "31f6c6ca1ea01f9dea424d167e4f33f5ec44cd67fbfac9efd40575ed20521f14"
-dependencies = [
- "async-channel 2.5.0",
- "async-trait",
- "atomic-waker",
- "bitflags 2.9.4",
- "crossbeam-channel",
- "crossbeam-deque",
- "crossbeam-queue",
- "crossbeam-utils",
- "futures 0.3.31",
- "memmap2",
- "parking_lot",
- "percent-encoding",
- "pin-project-lite",
- "polars-arrow",
- "polars-core",
- "polars-error",
- "polars-expr",
- "polars-io",
- "polars-mem-engine",
- "polars-ops",
- "polars-parquet",
- "polars-plan",
- "polars-utils",
- "rand 0.9.2",
- "rayon",
- "recursive",
- "slotmap",
- "tokio",
- "tokio-util",
- "version_check",
-]
-
-[[package]]
-name = "polars-time"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f6a3a6e279a7a984a0b83715660f9e880590c6129ec2104396bfa710bcd76dee"
-dependencies = [
- "atoi_simd",
- "bytemuck",
- "chrono",
- "chrono-tz",
- "now",
- "num-traits",
- "polars-arrow",
- "polars-compute",
- "polars-core",
- "polars-error",
- "polars-ops",
- "polars-utils",
- "rayon",
- "regex",
- "strum_macros 0.27.2",
-]
-
-[[package]]
-name = "polars-utils"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "57b267021b0e5422d7fbc70fd79e51b9f9a8466c585779373a18b0199e973f29"
-dependencies = [
- "bincode 2.0.1",
- "bytemuck",
- "bytes 1.10.1",
- "compact_str",
- "either",
- "flate2",
- "foldhash 0.1.5",
- "hashbrown 0.15.5",
- "indexmap",
- "libc",
- "memmap2",
- "num-traits",
- "polars-error",
- "rand 0.9.2",
- "raw-cpuid 11.6.0",
- "rayon",
- "regex",
- "rmp-serde",
- "serde",
- "serde_json",
- "serde_stacker",
- "slotmap",
- "stacker",
- "uuid",
- "version_check",
-]
-
[[package]]
name = "polling"
version = "3.11.0"
@@ -54,10 +54,9 @@ members = [
"crates/diagnostics",
"crates/docs_preprocessor",
"crates/edit_prediction",
- "crates/edit_prediction_button",
+ "crates/edit_prediction_types",
+ "crates/edit_prediction_ui",
"crates/edit_prediction_context",
- "crates/edit_prediction_context2",
- "crates/zeta2_tools",
"crates/editor",
"crates/eval",
"crates/eval_utils",
@@ -202,8 +201,7 @@ members = [
"crates/zed",
"crates/zed_actions",
"crates/zed_env_vars",
- "crates/zeta",
- "crates/zeta_cli",
+ "crates/edit_prediction_cli",
"crates/zlog",
"crates/zlog_settings",
@@ -314,11 +312,9 @@ http_client = { path = "crates/http_client" }
http_client_tls = { path = "crates/http_client_tls" }
icons = { path = "crates/icons" }
image_viewer = { path = "crates/image_viewer" }
-edit_prediction = { path = "crates/edit_prediction" }
-edit_prediction_button = { path = "crates/edit_prediction_button" }
+edit_prediction_types = { path = "crates/edit_prediction_types" }
+edit_prediction_ui = { path = "crates/edit_prediction_ui" }
edit_prediction_context = { path = "crates/edit_prediction_context" }
-edit_prediction_context2 = { path = "crates/edit_prediction_context2" }
-zeta2_tools = { path = "crates/zeta2_tools" }
inspector_ui = { path = "crates/inspector_ui" }
install_cli = { path = "crates/install_cli" }
journal = { path = "crates/journal" }
@@ -435,7 +431,7 @@ x_ai = { path = "crates/x_ai" }
zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
-zeta = { path = "crates/zeta" }
+edit_prediction = { path = "crates/edit_prediction" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
@@ -830,7 +826,7 @@ feature_flags = { codegen-units = 1 }
file_icons = { codegen-units = 1 }
fsevent = { codegen-units = 1 }
image_viewer = { codegen-units = 1 }
-edit_prediction_button = { codegen-units = 1 }
+edit_prediction_ui = { codegen-units = 1 }
install_cli = { codegen-units = 1 }
journal = { codegen-units = 1 }
json_schema_store = { codegen-units = 1 }
@@ -41,7 +41,7 @@
"ctrl-f11": "debugger::StepInto",
"shift-f11": "debugger::StepOut",
"f11": "zed::ToggleFullScreen",
- "ctrl-alt-z": "edit_prediction::RateCompletions",
+ "ctrl-alt-z": "edit_prediction::RatePredictions",
"ctrl-alt-shift-i": "edit_prediction::ToggleMenu",
"ctrl-alt-l": "lsp_tool::ToggleMenu"
}
@@ -1322,18 +1322,10 @@
}
},
{
- "context": "Zeta2Feedback > Editor",
+ "context": "EditPredictionContext > Editor",
"bindings": {
- "enter": "editor::Newline",
- "ctrl-enter up": "dev::Zeta2RatePredictionPositive",
- "ctrl-enter down": "dev::Zeta2RatePredictionNegative"
- }
- },
- {
- "context": "Zeta2Context > Editor",
- "bindings": {
- "alt-left": "dev::Zeta2ContextGoBack",
- "alt-right": "dev::Zeta2ContextGoForward"
+ "alt-left": "dev::EditPredictionContextGoBack",
+ "alt-right": "dev::EditPredictionContextGoForward"
}
},
{
@@ -47,7 +47,7 @@
"cmd-m": "zed::Minimize",
"fn-f": "zed::ToggleFullScreen",
"ctrl-cmd-f": "zed::ToggleFullScreen",
- "ctrl-cmd-z": "edit_prediction::RateCompletions",
+ "ctrl-cmd-z": "edit_prediction::RatePredictions",
"ctrl-cmd-i": "edit_prediction::ToggleMenu",
"ctrl-cmd-l": "lsp_tool::ToggleMenu",
"ctrl-cmd-c": "editor::DisplayCursorNames"
@@ -1427,18 +1427,10 @@
}
},
{
- "context": "Zeta2Feedback > Editor",
+ "context": "EditPredictionContext > Editor",
"bindings": {
- "enter": "editor::Newline",
- "cmd-enter up": "dev::Zeta2RatePredictionPositive",
- "cmd-enter down": "dev::Zeta2RatePredictionNegative"
- }
- },
- {
- "context": "Zeta2Context > Editor",
- "bindings": {
- "alt-left": "dev::Zeta2ContextGoBack",
- "alt-right": "dev::Zeta2ContextGoForward"
+ "alt-left": "dev::EditPredictionContextGoBack",
+ "alt-right": "dev::EditPredictionContextGoForward"
}
},
{
@@ -1341,18 +1341,10 @@
}
},
{
- "context": "Zeta2Feedback > Editor",
+ "context": "EditPredictionContext > Editor",
"bindings": {
- "enter": "editor::Newline",
- "ctrl-enter up": "dev::Zeta2RatePredictionPositive",
- "ctrl-enter down": "dev::Zeta2RatePredictionNegative"
- }
- },
- {
- "context": "Zeta2Context > Editor",
- "bindings": {
- "alt-left": "dev::Zeta2ContextGoBack",
- "alt-right": "dev::Zeta2ContextGoForward"
+ "alt-left": "dev::EditPredictionContextGoBack",
+ "alt-right": "dev::EditPredictionContextGoForward"
}
},
{
@@ -31,18 +31,10 @@ pub struct PredictEditsRequest {
/// Within `signatures`
pub excerpt_parent: Option<usize>,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub included_files: Vec<IncludedFile>,
- #[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub signatures: Vec<Signature>,
- #[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub referenced_declarations: Vec<ReferencedDeclaration>,
+ pub related_files: Vec<RelatedFile>,
pub events: Vec<Arc<Event>>,
#[serde(default)]
pub can_collect_data: bool,
- #[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub diagnostic_groups: Vec<DiagnosticGroup>,
- #[serde(skip_serializing_if = "is_default", default)]
- pub diagnostic_groups_truncated: bool,
/// Info about the git repository state, only present when can_collect_data is true.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub git_info: Option<PredictEditsGitInfo>,
@@ -58,7 +50,7 @@ pub struct PredictEditsRequest {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct IncludedFile {
+pub struct RelatedFile {
pub path: Arc<Path>,
pub max_row: Line,
pub excerpts: Vec<Excerpt>,
@@ -72,11 +64,9 @@ pub struct Excerpt {
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
- MarkedExcerpt,
- LabeledSections,
- NumLinesUniDiff,
+ /// XML old_text/new_text
OldTextNewText,
- /// Prompt format intended for use via zeta_cli
+ /// Prompt format intended for use via edit_prediction_cli
OnlySnippets,
/// One-sentence instructions used in fine-tuned models
Minimal,
@@ -87,7 +77,7 @@ pub enum PromptFormat {
}
impl PromptFormat {
- pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
+ pub const DEFAULT: PromptFormat = PromptFormat::Minimal;
}
impl Default for PromptFormat {
@@ -105,10 +95,7 @@ impl PromptFormat {
impl std::fmt::Display for PromptFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
- PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
- PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
- PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
PromptFormat::Minimal => write!(f, "Minimal"),
PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
@@ -178,67 +165,6 @@ impl<'a> std::fmt::Display for DiffPathFmt<'a> {
}
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Signature {
- pub text: String,
- pub text_is_truncated: bool,
- #[serde(skip_serializing_if = "Option::is_none", default)]
- pub parent_index: Option<usize>,
- /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
- /// file is implicitly the file that contains the descendant declaration or excerpt.
- pub range: Range<Line>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ReferencedDeclaration {
- pub path: Arc<Path>,
- pub text: String,
- pub text_is_truncated: bool,
- /// Range of `text` within file, possibly truncated according to `text_is_truncated`
- pub range: Range<Line>,
- /// Range within `text`
- pub signature_range: Range<usize>,
- /// Index within `signatures`.
- #[serde(skip_serializing_if = "Option::is_none", default)]
- pub parent_index: Option<usize>,
- pub score_components: DeclarationScoreComponents,
- pub signature_score: f32,
- pub declaration_score: f32,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct DeclarationScoreComponents {
- pub is_same_file: bool,
- pub is_referenced_nearby: bool,
- pub is_referenced_in_breadcrumb: bool,
- pub reference_count: usize,
- pub same_file_declaration_count: usize,
- pub declaration_count: usize,
- pub reference_line_distance: u32,
- pub declaration_line_distance: u32,
- pub excerpt_vs_item_jaccard: f32,
- pub excerpt_vs_signature_jaccard: f32,
- pub adjacent_vs_item_jaccard: f32,
- pub adjacent_vs_signature_jaccard: f32,
- pub excerpt_vs_item_weighted_overlap: f32,
- pub excerpt_vs_signature_weighted_overlap: f32,
- pub adjacent_vs_item_weighted_overlap: f32,
- pub adjacent_vs_signature_weighted_overlap: f32,
- pub path_import_match_count: usize,
- pub wildcard_path_import_match_count: usize,
- pub import_similarity: f32,
- pub max_import_similarity: f32,
- pub normalized_import_similarity: f32,
- pub wildcard_import_similarity: f32,
- pub normalized_wildcard_import_similarity: f32,
- pub included_by_others: usize,
- pub includes_others: usize,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-#[serde(transparent)]
-pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
-
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
pub request_id: Uuid,
@@ -262,10 +188,6 @@ pub struct Edit {
pub content: String,
}
-fn is_default<T: Default + PartialEq>(value: &T) -> bool {
- *value == T::default()
-}
-
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
pub line: Line,
@@ -15,9 +15,4 @@ path = "src/cloud_zeta2_prompt.rs"
anyhow.workspace = true
cloud_llm_client.workspace = true
indoc.workspace = true
-ordered-float.workspace = true
-rustc-hash.workspace = true
-schemars.workspace = true
serde.workspace = true
-serde_json.workspace = true
-strum.workspace = true
@@ -1,20 +1,12 @@
-//! Zeta2 prompt planning and generation code shared with cloud.
-pub mod retrieval_prompt;
-
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::Result;
use cloud_llm_client::predict_edits_v3::{
- self, DiffPathFmt, Event, Excerpt, IncludedFile, Line, Point, PromptFormat,
- ReferencedDeclaration,
+ self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
};
use indoc::indoc;
-use ordered_float::OrderedFloat;
-use rustc_hash::{FxHashMap, FxHashSet};
-use serde::Serialize;
use std::cmp;
use std::fmt::Write;
+use std::path::Path;
use std::sync::Arc;
-use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
-use strum::{EnumIter, IntoEnumIterator};
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
@@ -24,69 +16,6 @@ pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_s
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
-// TODO: use constants for markers?
-const MARKED_EXCERPT_INSTRUCTIONS: &str = indoc! {"
- You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
-
- The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor|>. Please respond with edited code for that region.
-
- Other code is provided for context, and `…` indicates when code has been skipped.
-
- ## Edit History
-
-"};
-
-const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#"
- You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.
-
- Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).
-
- The cursor position is marked with `<|user_cursor|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.
-
- Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:
-
- <|current_section|>
- for i in 0..16 {
- println!("{i}");
- }
-
- ## Edit History
-
-"#};
-
-const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#"
- # Instructions
-
- You are an edit prediction agent in a code editor.
- Your job is to predict the next edit that the user will make,
- based on their last few edits and their current cursor location.
-
- ## Output Format
-
- You must briefly explain your understanding of the user's goal, in one
- or two sentences, and then specify their next edit in the form of a
- unified diff, like this:
-
- ```
- --- a/src/myapp/cli.py
- +++ b/src/myapp/cli.py
- @@ ... @@
- import os
- import time
- import sys
- +from constants import LOG_LEVEL_WARNING
- @@ ... @@
- config.headless()
- config.set_interactive(false)
- -config.set_log_level(LOG_L)
- +config.set_log_level(LOG_LEVEL_WARNING)
- config.set_use_color(True)
- ```
-
- ## Edit History
-
-"#};
-
const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
@@ -94,20 +23,6 @@ const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
"#};
-const UNIFIED_DIFF_REMINDER: &str = indoc! {"
- ---
-
- Analyze the edit history and the files, then provide the unified diff for your predicted edits.
- Do not include the cursor marker in your output.
- Your diff should include edited file paths in its file headers (lines beginning with `---` and `+++`).
- Do not include line numbers in the hunk headers, use `@@ ... @@`.
- Removed lines begin with `-`.
- Added lines begin with `+`.
- Context lines begin with an extra space.
- Context and removed lines are used to match the target edit location, so make sure to include enough of them
- to uniquely identify it amongst all excerpts of code provided.
-"};
-
const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
---
@@ -164,49 +79,25 @@ const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
Remember that the edits in the edit history have already been applied.
"#};
-pub fn build_prompt(
- request: &predict_edits_v3::PredictEditsRequest,
-) -> Result<(String, SectionLabels)> {
- let mut section_labels = Default::default();
-
+pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
let prompt_data = PromptData {
events: request.events.clone(),
cursor_point: request.cursor_point,
cursor_path: request.excerpt_path.clone(),
- included_files: request.included_files.clone(),
+ included_files: request.related_files.clone(),
};
match request.prompt_format {
PromptFormat::MinimalQwen => {
- return Ok((MinimalQwenPrompt.render(&prompt_data), section_labels));
+ return Ok(MinimalQwenPrompt.render(&prompt_data));
}
PromptFormat::SeedCoder1120 => {
- return Ok((SeedCoder1120Prompt.render(&prompt_data), section_labels));
+ return Ok(SeedCoder1120Prompt.render(&prompt_data));
}
_ => (),
};
- let mut insertions = match request.prompt_format {
- PromptFormat::MarkedExcerpt => vec![
- (
- Point {
- line: request.excerpt_line_range.start,
- column: 0,
- },
- EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
- ),
- (request.cursor_point, CURSOR_MARKER),
- (
- Point {
- line: request.excerpt_line_range.end,
- column: 0,
- },
- EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
- ),
- ],
- PromptFormat::LabeledSections
- | PromptFormat::NumLinesUniDiff
- | PromptFormat::Minimal
- | PromptFormat::OldTextNewText => {
+ let insertions = match request.prompt_format {
+ PromptFormat::Minimal | PromptFormat::OldTextNewText => {
vec![(request.cursor_point, CURSOR_MARKER)]
}
PromptFormat::OnlySnippets => vec![],
@@ -215,9 +106,6 @@ pub fn build_prompt(
};
let mut prompt = match request.prompt_format {
- PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
- PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
- PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
PromptFormat::OnlySnippets => String::new(),
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
@@ -247,7 +135,7 @@ pub fn build_prompt(
You can only edit exactly this part of the file.
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
"},
- PromptFormat::NumLinesUniDiff | PromptFormat::OldTextNewText => indoc! {"
+ PromptFormat::OldTextNewText => indoc! {"
## Code Excerpts
Here is some excerpts of code that you should take into account to predict the next edit.
@@ -263,64 +151,51 @@ pub fn build_prompt(
Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
"},
- _ => indoc! {"
+ PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
+ indoc! {"
## Code Excerpts
The cursor marker <|user_cursor|> indicates the current user cursor position.
The file is in current state, edits from edit history have been applied.
- "},
+ "}
+ }
};
prompt.push_str(excerpts_preamble);
prompt.push('\n');
- if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() {
- let syntax_based_prompt = SyntaxBasedPrompt::populate(request)?;
- section_labels = syntax_based_prompt.write(&mut insertions, &mut prompt)?;
- } else {
- if request.prompt_format == PromptFormat::LabeledSections {
- anyhow::bail!("PromptFormat::LabeledSections cannot be used with ContextMode::Llm");
- }
-
- let include_line_numbers = matches!(
- request.prompt_format,
- PromptFormat::NumLinesUniDiff | PromptFormat::Minimal
- );
- for related_file in &request.included_files {
- if request.prompt_format == PromptFormat::Minimal {
- write_codeblock_with_filename(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- include_line_numbers,
- &mut prompt,
- );
- } else {
- write_codeblock(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- include_line_numbers,
- &mut prompt,
- );
- }
+ let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
+ for related_file in &request.related_files {
+ if request.prompt_format == PromptFormat::Minimal {
+ write_codeblock_with_filename(
+ &related_file.path,
+ &related_file.excerpts,
+ if related_file.path == request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ related_file.max_row,
+ include_line_numbers,
+ &mut prompt,
+ );
+ } else {
+ write_codeblock(
+ &related_file.path,
+ &related_file.excerpts,
+ if related_file.path == request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ related_file.max_row,
+ include_line_numbers,
+ &mut prompt,
+ );
}
}
match request.prompt_format {
- PromptFormat::NumLinesUniDiff => {
- prompt.push_str(UNIFIED_DIFF_REMINDER);
- }
PromptFormat::OldTextNewText => {
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
}
@@ -330,7 +205,7 @@ pub fn build_prompt(
_ => {}
}
- Ok((prompt, section_labels))
+ Ok(prompt)
}
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
@@ -444,476 +319,11 @@ pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>])
writeln!(output, "`````\n").unwrap();
}
-pub struct SyntaxBasedPrompt<'a> {
- request: &'a predict_edits_v3::PredictEditsRequest,
- /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
- /// `to_prompt_string`.
- snippets: Vec<PlannedSnippet<'a>>,
- budget_used: usize,
-}
-
-#[derive(Clone, Debug)]
-pub struct PlannedSnippet<'a> {
- path: Arc<Path>,
- range: Range<Line>,
- text: &'a str,
- // TODO: Indicate this in the output
- #[allow(dead_code)]
- text_is_truncated: bool,
-}
-
-#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
-pub enum DeclarationStyle {
- Signature,
- Declaration,
-}
-
-#[derive(Default, Clone, Debug, Serialize)]
-pub struct SectionLabels {
- pub excerpt_index: usize,
- pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
-}
-
-impl<'a> SyntaxBasedPrompt<'a> {
- /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
- ///
- /// Initializes a priority queue by populating it with each snippet, finding the
- /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
- /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
- /// the cost of upgrade.
- ///
- /// TODO: Implement an early halting condition. One option might be to have another priority
- /// queue where the score is the size, and update it accordingly. Another option might be to
- /// have some simpler heuristic like bailing after N failed insertions, or based on how much
- /// budget is left.
- ///
- /// TODO: Has the current known sources of imprecision:
- ///
- /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
- /// plan even though the containing struct is already included.
- ///
- /// * Does not consider cost of signatures when ranking snippets - this is tricky since
- /// signatures may be shared by multiple snippets.
- ///
- /// * Does not include file paths / other text when considering max_bytes.
- pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
- let mut this = Self {
- request,
- snippets: Vec::new(),
- budget_used: request.excerpt.len(),
- };
- let mut included_parents = FxHashSet::default();
- let additional_parents = this.additional_parent_signatures(
- &request.excerpt_path,
- request.excerpt_parent,
- &included_parents,
- )?;
- this.add_parents(&mut included_parents, additional_parents);
-
- let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
-
- if this.budget_used > max_bytes {
- return Err(anyhow!(
- "Excerpt + signatures size of {} already exceeds budget of {}",
- this.budget_used,
- max_bytes
- ));
- }
-
- #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
- struct QueueEntry {
- score_density: OrderedFloat<f32>,
- declaration_index: usize,
- style: DeclarationStyle,
- }
-
- // Initialize priority queue with the best score for each snippet.
- let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
- for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
- let (style, score_density) = DeclarationStyle::iter()
- .map(|style| {
- (
- style,
- OrderedFloat(declaration_score_density(&declaration, style)),
- )
- })
- .max_by_key(|(_, score_density)| *score_density)
- .unwrap();
- queue.push(QueueEntry {
- score_density,
- declaration_index,
- style,
- });
- }
-
- // Knapsack selection loop
- while let Some(queue_entry) = queue.pop() {
- let Some(declaration) = request
- .referenced_declarations
- .get(queue_entry.declaration_index)
- else {
- return Err(anyhow!(
- "Invalid declaration index {}",
- queue_entry.declaration_index
- ));
- };
-
- let mut additional_bytes = declaration_size(declaration, queue_entry.style);
- if this.budget_used + additional_bytes > max_bytes {
- continue;
- }
-
- let additional_parents = this.additional_parent_signatures(
- &declaration.path,
- declaration.parent_index,
- &mut included_parents,
- )?;
- additional_bytes += additional_parents
- .iter()
- .map(|(_, snippet)| snippet.text.len())
- .sum::<usize>();
- if this.budget_used + additional_bytes > max_bytes {
- continue;
- }
-
- this.budget_used += additional_bytes;
- this.add_parents(&mut included_parents, additional_parents);
- let planned_snippet = match queue_entry.style {
- DeclarationStyle::Signature => {
- let Some(text) = declaration.text.get(declaration.signature_range.clone())
- else {
- return Err(anyhow!(
- "Invalid declaration signature_range {:?} with text.len() = {}",
- declaration.signature_range,
- declaration.text.len()
- ));
- };
- let signature_start_line = declaration.range.start
- + Line(
- declaration.text[..declaration.signature_range.start]
- .lines()
- .count() as u32,
- );
- let signature_end_line = signature_start_line
- + Line(
- declaration.text
- [declaration.signature_range.start..declaration.signature_range.end]
- .lines()
- .count() as u32,
- );
- let range = signature_start_line..signature_end_line;
-
- PlannedSnippet {
- path: declaration.path.clone(),
- range,
- text,
- text_is_truncated: declaration.text_is_truncated,
- }
- }
- DeclarationStyle::Declaration => PlannedSnippet {
- path: declaration.path.clone(),
- range: declaration.range.clone(),
- text: &declaration.text,
- text_is_truncated: declaration.text_is_truncated,
- },
- };
- this.snippets.push(planned_snippet);
-
- // When a Signature is consumed, insert an entry for Definition style.
- if queue_entry.style == DeclarationStyle::Signature {
- let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
- let declaration_size =
- declaration_size(&declaration, DeclarationStyle::Declaration);
- let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
- let declaration_score =
- declaration_score(&declaration, DeclarationStyle::Declaration);
-
- let score_diff = declaration_score - signature_score;
- let size_diff = declaration_size.saturating_sub(signature_size);
- if score_diff > 0.0001 && size_diff > 0 {
- queue.push(QueueEntry {
- declaration_index: queue_entry.declaration_index,
- score_density: OrderedFloat(score_diff / (size_diff as f32)),
- style: DeclarationStyle::Declaration,
- });
- }
- }
- }
-
- anyhow::Ok(this)
- }
-
- fn add_parents(
- &mut self,
- included_parents: &mut FxHashSet<usize>,
- snippets: Vec<(usize, PlannedSnippet<'a>)>,
- ) {
- for (parent_index, snippet) in snippets {
- included_parents.insert(parent_index);
- self.budget_used += snippet.text.len();
- self.snippets.push(snippet);
- }
- }
-
- fn additional_parent_signatures(
- &self,
- path: &Arc<Path>,
- parent_index: Option<usize>,
- included_parents: &FxHashSet<usize>,
- ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
- let mut results = Vec::new();
- self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
- Ok(results)
- }
-
- fn additional_parent_signatures_impl(
- &self,
- path: &Arc<Path>,
- parent_index: Option<usize>,
- included_parents: &FxHashSet<usize>,
- results: &mut Vec<(usize, PlannedSnippet<'a>)>,
- ) -> Result<()> {
- let Some(parent_index) = parent_index else {
- return Ok(());
- };
- if included_parents.contains(&parent_index) {
- return Ok(());
- }
- let Some(parent_signature) = self.request.signatures.get(parent_index) else {
- return Err(anyhow!("Invalid parent index {}", parent_index));
- };
- results.push((
- parent_index,
- PlannedSnippet {
- path: path.clone(),
- range: parent_signature.range.clone(),
- text: &parent_signature.text,
- text_is_truncated: parent_signature.text_is_truncated,
- },
- ));
- self.additional_parent_signatures_impl(
- path,
- parent_signature.parent_index,
- included_parents,
- results,
- )
- }
-
- /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
- /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
- /// chunks.
- pub fn write(
- &'a self,
- excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
- prompt: &mut String,
- ) -> Result<SectionLabels> {
- let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
- FxHashMap::default();
- for snippet in &self.snippets {
- file_to_snippets
- .entry(&snippet.path)
- .or_default()
- .push(snippet);
- }
-
- // Reorder so that file with cursor comes last
- let mut file_snippets = Vec::new();
- let mut excerpt_file_snippets = Vec::new();
- for (file_path, snippets) in file_to_snippets {
- if file_path == self.request.excerpt_path.as_ref() {
- excerpt_file_snippets = snippets;
- } else {
- file_snippets.push((file_path, snippets, false));
- }
- }
- let excerpt_snippet = PlannedSnippet {
- path: self.request.excerpt_path.clone(),
- range: self.request.excerpt_line_range.clone(),
- text: &self.request.excerpt,
- text_is_truncated: false,
- };
- excerpt_file_snippets.push(&excerpt_snippet);
- file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
-
- let section_labels =
- self.push_file_snippets(prompt, excerpt_file_insertions, file_snippets)?;
-
- Ok(section_labels)
- }
-
- fn push_file_snippets(
- &self,
- output: &mut String,
- excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
- file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
- ) -> Result<SectionLabels> {
- let mut section_ranges = Vec::new();
- let mut excerpt_index = None;
-
- for (file_path, mut snippets, is_excerpt_file) in file_snippets {
- snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
-
- // TODO: What if the snippets get expanded too large to be editable?
- let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
- let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
- for snippet in snippets {
- if let Some((_, current_snippet_range)) = current_snippet.as_mut()
- && snippet.range.start <= current_snippet_range.end
- {
- current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
- continue;
- }
- if let Some(current_snippet) = current_snippet.take() {
- disjoint_snippets.push(current_snippet);
- }
- current_snippet = Some((snippet, snippet.range.clone()));
- }
- if let Some(current_snippet) = current_snippet.take() {
- disjoint_snippets.push(current_snippet);
- }
-
- writeln!(output, "`````path={}", file_path.display()).ok();
- let mut skipped_last_snippet = false;
- for (snippet, range) in disjoint_snippets {
- let section_index = section_ranges.len();
-
- match self.request.prompt_format {
- PromptFormat::MarkedExcerpt
- | PromptFormat::OnlySnippets
- | PromptFormat::OldTextNewText
- | PromptFormat::Minimal
- | PromptFormat::NumLinesUniDiff => {
- if range.start.0 > 0 && !skipped_last_snippet {
- output.push_str("…\n");
- }
- }
- PromptFormat::LabeledSections => {
- if is_excerpt_file
- && range.start <= self.request.excerpt_line_range.start
- && range.end >= self.request.excerpt_line_range.end
- {
- writeln!(output, "<|current_section|>").ok();
- } else {
- writeln!(output, "<|section_{}|>", section_index).ok();
- }
- }
- PromptFormat::MinimalQwen => unreachable!(),
- PromptFormat::SeedCoder1120 => unreachable!(),
- }
-
- let push_full_snippet = |output: &mut String| {
- if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
- for (i, line) in snippet.text.lines().enumerate() {
- writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
- }
- } else {
- output.push_str(&snippet.text);
- }
- anyhow::Ok(())
- };
-
- if is_excerpt_file {
- if self.request.prompt_format == PromptFormat::OnlySnippets {
- if range.start >= self.request.excerpt_line_range.start
- && range.end <= self.request.excerpt_line_range.end
- {
- skipped_last_snippet = true;
- } else {
- skipped_last_snippet = false;
- output.push_str(snippet.text);
- }
- } else if !excerpt_file_insertions.is_empty() {
- let lines = snippet.text.lines().collect::<Vec<_>>();
- let push_line = |output: &mut String, line_ix: usize| {
- if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
- write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
- }
- anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
- };
- let mut last_line_ix = 0;
- let mut insertion_ix = 0;
- while insertion_ix < excerpt_file_insertions.len() {
- let (point, insertion) = &excerpt_file_insertions[insertion_ix];
- let found = point.line >= range.start && point.line <= range.end;
- if found {
- excerpt_index = Some(section_index);
- let insertion_line_ix = (point.line.0 - range.start.0) as usize;
- for line_ix in last_line_ix..insertion_line_ix {
- push_line(output, line_ix)?;
- }
- if let Some(next_line) = lines.get(insertion_line_ix) {
- if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
- write!(
- output,
- "{}|",
- insertion_line_ix as u32 + range.start.0 + 1
- )?
- }
- output.push_str(&next_line[..point.column as usize]);
- output.push_str(insertion);
- writeln!(output, "{}", &next_line[point.column as usize..])?;
- } else {
- writeln!(output, "{}", insertion)?;
- }
- last_line_ix = insertion_line_ix + 1;
- excerpt_file_insertions.remove(insertion_ix);
- continue;
- }
- insertion_ix += 1;
- }
- skipped_last_snippet = false;
- for line_ix in last_line_ix..lines.len() {
- push_line(output, line_ix)?;
- }
- } else {
- skipped_last_snippet = false;
- push_full_snippet(output)?;
- }
- } else {
- skipped_last_snippet = false;
- push_full_snippet(output)?;
- }
-
- section_ranges.push((snippet.path.clone(), range));
- }
-
- output.push_str("`````\n\n");
- }
-
- Ok(SectionLabels {
- // TODO: Clean this up
- excerpt_index: match self.request.prompt_format {
- PromptFormat::OnlySnippets => 0,
- _ => excerpt_index.context("bug: no snippet found for excerpt")?,
- },
- section_ranges,
- })
- }
-}
-
-fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
- declaration_score(declaration, style) / declaration_size(declaration, style) as f32
-}
-
-fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
- match style {
- DeclarationStyle::Signature => declaration.signature_score,
- DeclarationStyle::Declaration => declaration.declaration_score,
- }
-}
-
-fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
- match style {
- DeclarationStyle::Signature => declaration.signature_range.len(),
- DeclarationStyle::Declaration => declaration.text.len(),
- }
-}
-
struct PromptData {
events: Vec<Arc<Event>>,
cursor_point: Point,
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
- included_files: Vec<IncludedFile>,
+ included_files: Vec<RelatedFile>,
}
#[derive(Default)]
@@ -1051,7 +461,7 @@ impl SeedCoder1120Prompt {
context
}
- fn fmt_fim(&self, file: &IncludedFile, cursor_point: Point) -> String {
+ fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
let mut buf = String::new();
const FIM_SUFFIX: &str = "<[fim-suffix]>";
const FIM_PREFIX: &str = "<[fim-prefix]>";
@@ -1,244 +0,0 @@
-use anyhow::Result;
-use cloud_llm_client::predict_edits_v3::{self, Excerpt};
-use indoc::indoc;
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
-use std::fmt::Write;
-
-use crate::{push_events, write_codeblock};
-
-pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result<String> {
- let mut prompt = SEARCH_INSTRUCTIONS.to_string();
-
- if !request.events.is_empty() {
- writeln!(&mut prompt, "\n## User Edits\n\n")?;
- push_events(&mut prompt, &request.events);
- }
-
- writeln!(&mut prompt, "## Cursor context\n")?;
- write_codeblock(
- &request.excerpt_path,
- &[Excerpt {
- start_line: request.excerpt_line_range.start,
- text: request.excerpt.into(),
- }],
- &[],
- request.cursor_file_max_row,
- true,
- &mut prompt,
- );
-
- writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?;
-
- Ok(prompt)
-}
-
-/// Search for relevant code
-///
-/// For the best results, run multiple queries at once with a single invocation of this tool.
-#[derive(Clone, Deserialize, Serialize, JsonSchema)]
-pub struct SearchToolInput {
- /// An array of queries to run for gathering context relevant to the next prediction
- #[schemars(length(max = 3))]
- #[serde(deserialize_with = "deserialize_queries")]
- pub queries: Box<[SearchToolQuery]>,
-}
-
-fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
-where
- D: serde::Deserializer<'de>,
-{
- use serde::de::Error;
-
- #[derive(Deserialize)]
- #[serde(untagged)]
- enum QueryCollection {
- Array(Box<[SearchToolQuery]>),
- DoubleArray(Box<[Box<[SearchToolQuery]>]>),
- Single(SearchToolQuery),
- }
-
- #[derive(Deserialize)]
- #[serde(untagged)]
- enum MaybeDoubleEncoded {
- SingleEncoded(QueryCollection),
- DoubleEncoded(String),
- }
-
- let result = MaybeDoubleEncoded::deserialize(deserializer)?;
-
- let normalized = match result {
- MaybeDoubleEncoded::SingleEncoded(value) => value,
- MaybeDoubleEncoded::DoubleEncoded(value) => {
- serde_json::from_str(&value).map_err(D::Error::custom)?
- }
- };
-
- Ok(match normalized {
- QueryCollection::Array(items) => items,
- QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
- QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
- })
-}
-
-/// Search for relevant code by path, syntax hierarchy, and content.
-#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
-pub struct SearchToolQuery {
- /// 1. A glob pattern to match file paths in the codebase to search in.
- pub glob: String,
- /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy.
- ///
- /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes.
- ///
- /// Example: Searching for a `User` class
- /// ["class\s+User"]
- ///
- /// Example: Searching for a `get_full_name` method under a `User` class
- /// ["class\s+User", "def\sget_full_name"]
- ///
- /// Skip this field to match on content alone.
- #[schemars(length(max = 3))]
- #[serde(default)]
- pub syntax_node: Vec<String>,
- /// 3. An optional regular expression to match the final content that should appear in the results.
- ///
- /// - Content will be matched within all lines of the matched syntax nodes.
- /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible.
- /// - If no syntax node regexes are provided, the content will be matched within the entire file.
- pub content: Option<String>,
-}
-
-pub const TOOL_NAME: &str = "search";
-
-const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
- You are part of an edit prediction system in a code editor.
- Your role is to search for code that will serve as context for predicting the next edit.
-
- - Analyze the user's recent edits and current cursor context
- - Use the `search` tool to find code that is relevant for predicting the next edit
- - Focus on finding:
- - Code patterns that might need similar changes based on the recent edits
- - Functions, variables, types, and constants referenced in the current cursor context
- - Related implementations, usages, or dependencies that may require consistent updates
- - How items defined in the cursor excerpt are used or altered
- - You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible
- - Use `syntax_node` parameter whenever you're looking for a particular type, class, or function
- - Avoid using wildcard globs if you already know the file path of the content you're looking for
-"#};
-
-const TOOL_USE_REMINDER: &str = indoc! {"
- --
- Analyze the user's intent in one to two sentences, then call the `search` tool.
-"};
-
-#[cfg(test)]
-mod tests {
- use serde_json::json;
-
- use super::*;
-
- #[test]
- fn test_deserialize_queries() {
- let single_query_json = indoc! {r#"{
- "queries": {
- "glob": "**/*.rs",
- "syntax_node": ["fn test"],
- "content": "assert"
- }
- }"#};
-
- let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap();
- assert_eq!(flat_input.queries.len(), 1);
- assert_eq!(flat_input.queries[0].glob, "**/*.rs");
- assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
- assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
-
- let flat_json = indoc! {r#"{
- "queries": [
- {
- "glob": "**/*.rs",
- "syntax_node": ["fn test"],
- "content": "assert"
- },
- {
- "glob": "**/*.ts",
- "syntax_node": [],
- "content": null
- }
- ]
- }"#};
-
- let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap();
- assert_eq!(flat_input.queries.len(), 2);
- assert_eq!(flat_input.queries[0].glob, "**/*.rs");
- assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
- assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
- assert_eq!(flat_input.queries[1].glob, "**/*.ts");
- assert_eq!(flat_input.queries[1].syntax_node.len(), 0);
- assert_eq!(flat_input.queries[1].content, None);
-
- let nested_json = indoc! {r#"{
- "queries": [
- [
- {
- "glob": "**/*.rs",
- "syntax_node": ["fn test"],
- "content": "assert"
- }
- ],
- [
- {
- "glob": "**/*.ts",
- "syntax_node": [],
- "content": null
- }
- ]
- ]
- }"#};
-
- let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap();
-
- assert_eq!(nested_input.queries.len(), 2);
-
- assert_eq!(nested_input.queries[0].glob, "**/*.rs");
- assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]);
- assert_eq!(nested_input.queries[0].content, Some("assert".to_string()));
- assert_eq!(nested_input.queries[1].glob, "**/*.ts");
- assert_eq!(nested_input.queries[1].syntax_node.len(), 0);
- assert_eq!(nested_input.queries[1].content, None);
-
- let double_encoded_queries = serde_json::to_string(&json!({
- "queries": serde_json::to_string(&json!([
- {
- "glob": "**/*.rs",
- "syntax_node": ["fn test"],
- "content": "assert"
- },
- {
- "glob": "**/*.ts",
- "syntax_node": [],
- "content": null
- }
- ])).unwrap()
- }))
- .unwrap();
-
- let double_encoded_input: SearchToolInput =
- serde_json::from_str(&double_encoded_queries).unwrap();
-
- assert_eq!(double_encoded_input.queries.len(), 2);
-
- assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs");
- assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]);
- assert_eq!(
- double_encoded_input.queries[0].content,
- Some("assert".to_string())
- );
- assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts");
- assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0);
- assert_eq!(double_encoded_input.queries[1].content, None);
-
- // ### ERROR Switching from var declarations to lexical declarations [RUN 073]
- // invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}
- }
-}
@@ -10,7 +10,7 @@ path = "src/codestral.rs"
[dependencies]
anyhow.workspace = true
-edit_prediction.workspace = true
+edit_prediction_types.workspace = true
edit_prediction_context.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result};
-use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
+use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
use futures::AsyncReadExt;
use gpui::{App, Context, Entity, Task};
use http_client::HttpClient;
@@ -43,17 +43,17 @@ impl CurrentCompletion {
/// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
/// Returns None if the user's edits conflict with the predicted edits.
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
- edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
+ edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
}
}
-pub struct CodestralCompletionProvider {
+pub struct CodestralEditPredictionDelegate {
http_client: Arc<dyn HttpClient>,
pending_request: Option<Task<Result<()>>>,
current_completion: Option<CurrentCompletion>,
}
-impl CodestralCompletionProvider {
+impl CodestralEditPredictionDelegate {
pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
Self {
http_client,
@@ -165,7 +165,7 @@ impl CodestralCompletionProvider {
}
}
-impl EditPredictionProvider for CodestralCompletionProvider {
+impl EditPredictionDelegate for CodestralEditPredictionDelegate {
fn name() -> &'static str {
"codestral"
}
@@ -174,7 +174,7 @@ impl EditPredictionProvider for CodestralCompletionProvider {
"Codestral"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
true
}
@@ -239,7 +239,6 @@ impl EditPredictionProvider for CodestralCompletionProvider {
cursor_point,
&snapshot,
&EXCERPT_OPTIONS,
- None,
)
.context("Line containing cursor doesn't fit in excerpt max bytes")?;
@@ -33,7 +33,7 @@ fs.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
-edit_prediction.workspace = true
+edit_prediction_types.workspace = true
language.workspace = true
log.workspace = true
lsp.workspace = true
@@ -1,5 +1,5 @@
pub mod copilot_chat;
-mod copilot_completion_provider;
+mod copilot_edit_prediction_delegate;
pub mod copilot_responses;
pub mod request;
mod sign_in;
@@ -46,7 +46,7 @@ use util::rel_path::RelPath;
use util::{ResultExt, fs::remove_matching};
use workspace::Workspace;
-pub use crate::copilot_completion_provider::CopilotCompletionProvider;
+pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate;
pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
actions!(
@@ -1,6 +1,6 @@
use crate::{Completion, Copilot};
use anyhow::Result;
-use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
+use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
use gpui::{App, Context, Entity, EntityId, Task};
use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings};
use settings::Settings;
@@ -8,7 +8,7 @@ use std::{path::Path, time::Duration};
pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
-pub struct CopilotCompletionProvider {
+pub struct CopilotEditPredictionDelegate {
cycled: bool,
buffer_id: Option<EntityId>,
completions: Vec<Completion>,
@@ -19,7 +19,7 @@ pub struct CopilotCompletionProvider {
copilot: Entity<Copilot>,
}
-impl CopilotCompletionProvider {
+impl CopilotEditPredictionDelegate {
pub fn new(copilot: Entity<Copilot>) -> Self {
Self {
cycled: false,
@@ -47,7 +47,7 @@ impl CopilotCompletionProvider {
}
}
-impl EditPredictionProvider for CopilotCompletionProvider {
+impl EditPredictionDelegate for CopilotEditPredictionDelegate {
fn name() -> &'static str {
"copilot"
}
@@ -56,7 +56,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
"Copilot"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
true
}
@@ -314,7 +314,7 @@ mod tests {
cx,
)
.await;
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
});
@@ -546,7 +546,7 @@ mod tests {
cx,
)
.await;
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
});
@@ -670,7 +670,7 @@ mod tests {
cx,
)
.await;
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
});
@@ -753,7 +753,7 @@ mod tests {
window.focus(&editor.focus_handle(cx));
})
.unwrap();
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
editor
.update(cx, |editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
@@ -848,7 +848,7 @@ mod tests {
cx,
)
.await;
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
});
@@ -1000,7 +1000,7 @@ mod tests {
window.focus(&editor.focus_handle(cx))
})
.unwrap();
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
editor
.update(cx, |editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
@@ -11,7 +11,69 @@ workspace = true
[lib]
path = "src/edit_prediction.rs"
+[features]
+eval-support = []
+
[dependencies]
+ai_onboarding.workspace = true
+anyhow.workspace = true
+arrayvec.workspace = true
+brotli.workspace = true
client.workspace = true
+cloud_llm_client.workspace = true
+cloud_zeta2_prompt.workspace = true
+collections.workspace = true
+copilot.workspace = true
+credentials_provider.workspace = true
+db.workspace = true
+edit_prediction_types.workspace = true
+edit_prediction_context.workspace = true
+feature_flags.workspace = true
+fs.workspace = true
+futures.workspace = true
gpui.workspace = true
+indoc.workspace = true
+itertools.workspace = true
language.workspace = true
+language_model.workspace = true
+log.workspace = true
+lsp.workspace = true
+menu.workspace = true
+open_ai.workspace = true
+postage.workspace = true
+pretty_assertions.workspace = true
+project.workspace = true
+rand.workspace = true
+regex.workspace = true
+release_channel.workspace = true
+semver.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+settings.workspace = true
+smol.workspace = true
+strsim.workspace = true
+strum.workspace = true
+telemetry.workspace = true
+telemetry_events.workspace = true
+thiserror.workspace = true
+ui.workspace = true
+util.workspace = true
+uuid.workspace = true
+workspace.workspace = true
+worktree.workspace = true
+zed_actions.workspace = true
+
+[dev-dependencies]
+clock = { workspace = true, features = ["test-support"] }
+cloud_api_types.workspace = true
+cloud_llm_client = { workspace = true, features = ["test-support"] }
+ctor.workspace = true
+gpui = { workspace = true, features = ["test-support"] }
+indoc.workspace = true
+language = { workspace = true, features = ["test-support"] }
+language_model = { workspace = true, features = ["test-support"] }
+lsp.workspace = true
+parking_lot.workspace = true
+project = { workspace = true, features = ["test-support"] }
+settings = { workspace = true, features = ["test-support"] }
+zlog.workspace = true
@@ -1,298 +1,1911 @@
-use std::{ops::Range, sync::Arc};
+use anyhow::Result;
+use arrayvec::ArrayVec;
+use client::{Client, EditPredictionUsage, UserStore};
+use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
+use cloud_llm_client::{
+ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
+ EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
+ MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
+ ZED_VERSION_HEADER_NAME,
+};
+use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
+use collections::{HashMap, HashSet};
+use db::kvp::{Dismissable, KEY_VALUE_STORE};
+use edit_prediction_context::EditPredictionExcerptOptions;
+use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
+use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
+use futures::{
+ AsyncReadExt as _, FutureExt as _, StreamExt as _,
+ channel::{
+ mpsc::{self, UnboundedReceiver},
+ oneshot,
+ },
+ select_biased,
+};
+use gpui::BackgroundExecutor;
+use gpui::{
+ App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
+ http_client::{self, AsyncBody, Method},
+ prelude::*,
+};
+use language::language_settings::all_language_settings;
+use language::{Anchor, Buffer, File, Point, ToPoint};
+use language::{BufferSnapshot, OffsetRangeExt};
+use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use project::{Project, ProjectPath, WorktreeId};
+use release_channel::AppVersion;
+use semver::Version;
+use serde::de::DeserializeOwned;
+use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
+use std::collections::{VecDeque, hash_map};
+use workspace::Workspace;
+
+use std::ops::Range;
+use std::path::Path;
+use std::rc::Rc;
+use std::str::FromStr as _;
+use std::sync::{Arc, LazyLock};
+use std::time::{Duration, Instant};
+use std::{env, mem};
+use thiserror::Error;
+use util::{RangeExt as _, ResultExt as _};
+use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+
+mod license_detection;
+mod onboarding_modal;
+mod prediction;
+pub mod sweep_ai;
+pub mod udiff;
+mod xml_edits;
+mod zed_edit_prediction_delegate;
+pub mod zeta1;
+pub mod zeta2;
+
+#[cfg(test)]
+mod edit_prediction_tests;
+
+use crate::license_detection::LicenseDetectionWatcher;
+use crate::onboarding_modal::ZedPredictModal;
+pub use crate::prediction::EditPrediction;
+pub use crate::prediction::EditPredictionId;
+pub use crate::prediction::EditPredictionInputs;
+use crate::prediction::EditPredictionResult;
+pub use crate::sweep_ai::SweepAi;
+pub use telemetry_events::EditPredictionRating;
+pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
+
+actions!(
+ edit_prediction,
+ [
+ /// Resets the edit prediction onboarding state.
+ ResetOnboarding,
+ /// Clears the edit prediction history.
+ ClearHistory,
+ ]
+);
+
+/// Maximum number of events to track.
+const EVENT_COUNT_MAX: usize = 6;
+const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
+const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
+const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
-use client::EditPredictionUsage;
-use gpui::{App, Context, Entity, SharedString};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt};
+pub struct SweepFeatureFlag;
-// TODO: Find a better home for `Direction`.
-//
-// This should live in an ancestor crate of `editor` and `edit_prediction`,
-// but at time of writing there isn't an obvious spot.
-#[derive(Copy, Clone, PartialEq, Eq)]
-pub enum Direction {
- Prev,
- Next,
+impl FeatureFlag for SweepFeatureFlag {
+ const NAME: &str = "sweep-ai";
}
-#[derive(Clone)]
-pub enum EditPrediction {
- /// Edits within the buffer that requested the prediction
- Local {
- id: Option<SharedString>,
- edits: Vec<(Range<language::Anchor>, Arc<str>)>,
- edit_preview: Option<language::EditPreview>,
- },
- /// Jump to a different file from the one that requested the prediction
- Jump {
- id: Option<SharedString>,
- snapshot: language::BufferSnapshot,
- target: language::Anchor,
+pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
+ context: EditPredictionExcerptOptions {
+ max_bytes: 512,
+ min_bytes: 128,
+ target_before_cursor_over_total_bytes: 0.5,
},
+ max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
+ prompt_format: PromptFormat::DEFAULT,
+};
+
+static USE_OLLAMA: LazyLock<bool> =
+ LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
+
+static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
+ match env::var("ZED_ZETA2_MODEL").as_deref() {
+ Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
+ Ok(model) => model,
+ Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
+ Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
+ }
+ .to_string()
+});
+static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
+ env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
+ if *USE_OLLAMA {
+ Some("http://localhost:11434/v1/chat/completions".into())
+ } else {
+ None
+ }
+ })
+});
+
+pub struct Zeta2FeatureFlag;
+
+impl FeatureFlag for Zeta2FeatureFlag {
+ const NAME: &'static str = "zeta2";
+
+ fn enabled_for_staff() -> bool {
+ true
+ }
}
-pub enum DataCollectionState {
- /// The provider doesn't support data collection.
- Unsupported,
- /// Data collection is enabled.
- Enabled { is_project_open_source: bool },
- /// Data collection is disabled or unanswered.
- Disabled { is_project_open_source: bool },
+#[derive(Clone)]
+struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
+
+impl Global for EditPredictionStoreGlobal {}
+
+pub struct EditPredictionStore {
+ client: Arc<Client>,
+ user_store: Entity<UserStore>,
+ llm_token: LlmApiToken,
+ _llm_token_subscription: Subscription,
+ projects: HashMap<EntityId, ProjectState>,
+ use_context: bool,
+ options: ZetaOptions,
+ update_required: bool,
+ debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
+ #[cfg(feature = "eval-support")]
+ eval_cache: Option<Arc<dyn EvalCache>>,
+ edit_prediction_model: EditPredictionModel,
+ pub sweep_ai: SweepAi,
+ data_collection_choice: DataCollectionChoice,
+ reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
+ shown_predictions: VecDeque<EditPrediction>,
+ rated_predictions: HashSet<EditPredictionId>,
+}
+
+#[derive(Copy, Clone, Default, PartialEq, Eq)]
+pub enum EditPredictionModel {
+ #[default]
+ Zeta1,
+ Zeta2,
+ Sweep,
}
-impl DataCollectionState {
- pub fn is_supported(&self) -> bool {
- !matches!(self, DataCollectionState::Unsupported)
+#[derive(Debug, Clone, PartialEq)]
+pub struct ZetaOptions {
+ pub context: EditPredictionExcerptOptions,
+ pub max_prompt_bytes: usize,
+ pub prompt_format: predict_edits_v3::PromptFormat,
+}
+
+#[derive(Debug)]
+pub enum DebugEvent {
+ ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
+ ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
+ EditPredictionRequested(EditPredictionRequestedDebugEvent),
+}
+
+#[derive(Debug)]
+pub struct ContextRetrievalStartedDebugEvent {
+ pub project_entity_id: EntityId,
+ pub timestamp: Instant,
+ pub search_prompt: String,
+}
+
+#[derive(Debug)]
+pub struct ContextRetrievalFinishedDebugEvent {
+ pub project_entity_id: EntityId,
+ pub timestamp: Instant,
+ pub metadata: Vec<(&'static str, SharedString)>,
+}
+
+#[derive(Debug)]
+pub struct EditPredictionRequestedDebugEvent {
+ pub inputs: EditPredictionInputs,
+ pub retrieval_time: Duration,
+ pub buffer: WeakEntity<Buffer>,
+ pub position: Anchor,
+ pub local_prompt: Result<String, String>,
+ pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
+}
+
+pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
+
+struct ProjectState {
+ events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
+ last_event: Option<LastEvent>,
+ recent_paths: VecDeque<ProjectPath>,
+ registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
+ current_prediction: Option<CurrentEditPrediction>,
+ next_pending_prediction_id: usize,
+ pending_predictions: ArrayVec<PendingPrediction, 2>,
+ context_updates_tx: smol::channel::Sender<()>,
+ context_updates_rx: smol::channel::Receiver<()>,
+ last_prediction_refresh: Option<(EntityId, Instant)>,
+ cancelled_predictions: HashSet<usize>,
+ context: Entity<RelatedExcerptStore>,
+ license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+ _subscription: gpui::Subscription,
+}
+
+impl ProjectState {
+ pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
+ self.events
+ .iter()
+ .cloned()
+ .chain(
+ self.last_event
+ .as_ref()
+ .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
+ )
+ .collect()
}
- pub fn is_enabled(&self) -> bool {
- matches!(self, DataCollectionState::Enabled { .. })
+ fn cancel_pending_prediction(
+ &mut self,
+ pending_prediction: PendingPrediction,
+ cx: &mut Context<EditPredictionStore>,
+ ) {
+ self.cancelled_predictions.insert(pending_prediction.id);
+
+ cx.spawn(async move |this, cx| {
+ let Some(prediction_id) = pending_prediction.task.await else {
+ return;
+ };
+
+ this.update(cx, |this, _cx| {
+ this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
+ })
+ .ok();
+ })
+ .detach()
}
+}
+
+#[derive(Debug, Clone)]
+struct CurrentEditPrediction {
+ pub requested_by: PredictionRequestedBy,
+ pub prediction: EditPrediction,
+ pub was_shown: bool,
+}
+
+impl CurrentEditPrediction {
+ fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
+ let Some(new_edits) = self
+ .prediction
+ .interpolate(&self.prediction.buffer.read(cx))
+ else {
+ return false;
+ };
- pub fn is_project_open_source(&self) -> bool {
+ if self.prediction.buffer != old_prediction.prediction.buffer {
+ return true;
+ }
+
+ let Some(old_edits) = old_prediction
+ .prediction
+ .interpolate(&old_prediction.prediction.buffer.read(cx))
+ else {
+ return true;
+ };
+
+ let requested_by_buffer_id = self.requested_by.buffer_id();
+
+ // This reduces the occurrence of UI thrash from replacing edits
+ //
+ // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
+ if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
+ && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
+ && old_edits.len() == 1
+ && new_edits.len() == 1
+ {
+ let (old_range, old_text) = &old_edits[0];
+ let (new_range, new_text) = &new_edits[0];
+ new_range == old_range && new_text.starts_with(old_text.as_ref())
+ } else {
+ true
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+enum PredictionRequestedBy {
+ DiagnosticsUpdate,
+ Buffer(EntityId),
+}
+
+impl PredictionRequestedBy {
+ pub fn buffer_id(&self) -> Option<EntityId> {
match self {
- Self::Enabled {
- is_project_open_source,
- }
- | Self::Disabled {
- is_project_open_source,
- } => *is_project_open_source,
- _ => false,
+ PredictionRequestedBy::DiagnosticsUpdate => None,
+ PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
}
}
}
-pub trait EditPredictionProvider: 'static + Sized {
- fn name() -> &'static str;
- fn display_name() -> &'static str;
- fn show_completions_in_menu() -> bool;
- fn show_tab_accept_marker() -> bool {
- false
+#[derive(Debug)]
+struct PendingPrediction {
+ id: usize,
+ task: Task<Option<EditPredictionId>>,
+}
+
+/// A prediction from the perspective of a buffer.
+#[derive(Debug)]
+enum BufferEditPrediction<'a> {
+ Local { prediction: &'a EditPrediction },
+ Jump { prediction: &'a EditPrediction },
+}
+
+#[cfg(test)]
+impl std::ops::Deref for BufferEditPrediction<'_> {
+ type Target = EditPrediction;
+
+ fn deref(&self) -> &Self::Target {
+ match self {
+ BufferEditPrediction::Local { prediction } => prediction,
+ BufferEditPrediction::Jump { prediction } => prediction,
+ }
}
- fn supports_jump_to_edit() -> bool {
- true
+}
+
+struct RegisteredBuffer {
+ snapshot: BufferSnapshot,
+ _subscriptions: [gpui::Subscription; 2],
+}
+
+struct LastEvent {
+ old_snapshot: BufferSnapshot,
+ new_snapshot: BufferSnapshot,
+ end_edit_anchor: Option<Anchor>,
+}
+
+impl LastEvent {
+ pub fn finalize(
+ &self,
+ license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+ cx: &App,
+ ) -> Option<Arc<predict_edits_v3::Event>> {
+ let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
+ let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
+
+ let file = self.new_snapshot.file();
+ let old_file = self.old_snapshot.file();
+
+ let in_open_source_repo = [file, old_file].iter().all(|file| {
+ file.is_some_and(|file| {
+ license_detection_watchers
+ .get(&file.worktree_id(cx))
+ .is_some_and(|watcher| watcher.is_project_open_source())
+ })
+ });
+
+ let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
+
+ if path == old_path && diff.is_empty() {
+ None
+ } else {
+ Some(Arc::new(predict_edits_v3::Event::BufferChange {
+ old_path,
+ path,
+ diff,
+ in_open_source_repo,
+ // TODO: Actually detect if this edit was predicted or not
+ predicted: false,
+ }))
+ }
}
+}
- fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
- DataCollectionState::Unsupported
+fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
+ if let Some(file) = snapshot.file() {
+ file.full_path(cx).into()
+ } else {
+ Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
}
+}
- fn usage(&self, _cx: &App) -> Option<EditPredictionUsage> {
- None
+impl EditPredictionStore {
+ pub fn try_global(cx: &App) -> Option<Entity<Self>> {
+ cx.try_global::<EditPredictionStoreGlobal>()
+ .map(|global| global.0.clone())
}
- fn toggle_data_collection(&mut self, _cx: &mut App) {}
- fn is_enabled(
- &self,
+ pub fn global(
+ client: &Arc<Client>,
+ user_store: &Entity<UserStore>,
+ cx: &mut App,
+ ) -> Entity<Self> {
+ cx.try_global::<EditPredictionStoreGlobal>()
+ .map(|global| global.0.clone())
+ .unwrap_or_else(|| {
+ let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
+ cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
+ ep_store
+ })
+ }
+
+ pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+ let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+ let data_collection_choice = Self::load_data_collection_choice();
+
+ let llm_token = LlmApiToken::default();
+
+ let (reject_tx, reject_rx) = mpsc::unbounded();
+ cx.background_spawn({
+ let client = client.clone();
+ let llm_token = llm_token.clone();
+ let app_version = AppVersion::global(cx);
+ let background_executor = cx.background_executor().clone();
+ async move {
+ Self::handle_rejected_predictions(
+ reject_rx,
+ client,
+ llm_token,
+ app_version,
+ background_executor,
+ )
+ .await
+ }
+ })
+ .detach();
+
+ let mut this = Self {
+ projects: HashMap::default(),
+ client,
+ user_store,
+ options: DEFAULT_OPTIONS,
+ use_context: false,
+ llm_token,
+ _llm_token_subscription: cx.subscribe(
+ &refresh_llm_token_listener,
+ |this, _listener, _event, cx| {
+ let client = this.client.clone();
+ let llm_token = this.llm_token.clone();
+ cx.spawn(async move |_this, _cx| {
+ llm_token.refresh(&client).await?;
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ },
+ ),
+ update_required: false,
+ debug_tx: None,
+ #[cfg(feature = "eval-support")]
+ eval_cache: None,
+ edit_prediction_model: EditPredictionModel::Zeta2,
+ sweep_ai: SweepAi::new(cx),
+ data_collection_choice,
+ reject_predictions_tx: reject_tx,
+ rated_predictions: Default::default(),
+ shown_predictions: Default::default(),
+ };
+
+ this.enable_or_disable_context_retrieval(cx);
+ let weak_this = cx.weak_entity();
+ cx.on_flags_ready(move |_, cx| {
+ weak_this
+ .update(cx, |this, cx| this.enable_or_disable_context_retrieval(cx))
+ .ok();
+ })
+ .detach();
+ cx.observe_global::<SettingsStore>(|this, cx| {
+ this.enable_or_disable_context_retrieval(cx);
+ })
+ .detach();
+
+ this
+ }
+
+ pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
+ self.edit_prediction_model = model;
+ }
+
+ pub fn has_sweep_api_token(&self) -> bool {
+ self.sweep_ai
+ .api_token
+ .clone()
+ .now_or_never()
+ .flatten()
+ .is_some()
+ }
+
+ #[cfg(feature = "eval-support")]
+ pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
+ self.eval_cache = Some(cache);
+ }
+
+ pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
+ let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
+ self.debug_tx = Some(debug_watch_tx);
+ debug_watch_rx
+ }
+
+ pub fn options(&self) -> &ZetaOptions {
+ &self.options
+ }
+
+ pub fn set_options(&mut self, options: ZetaOptions) {
+ self.options = options;
+ }
+
+ pub fn set_use_context(&mut self, use_context: bool) {
+ self.use_context = use_context;
+ }
+
+ pub fn clear_history(&mut self) {
+ for project_state in self.projects.values_mut() {
+ project_state.events.clear();
+ }
+ }
+
+ pub fn context_for_project<'a>(
+ &'a self,
+ project: &Entity<Project>,
+ cx: &'a App,
+ ) -> &'a [RelatedFile] {
+ self.projects
+ .get(&project.entity_id())
+ .map(|project| project.context.read(cx).related_files())
+ .unwrap_or(&[])
+ }
+
+ pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
+ if self.edit_prediction_model == EditPredictionModel::Zeta2 {
+ self.user_store.read(cx).edit_prediction_usage()
+ } else {
+ None
+ }
+ }
+
+ pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+ self.get_or_init_project(project, cx);
+ }
+
+ pub fn register_buffer(
+ &mut self,
buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
- cx: &App,
- ) -> bool;
- fn is_refreshing(&self, cx: &App) -> bool;
- fn refresh(
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) {
+ let project_state = self.get_or_init_project(project, cx);
+ Self::register_buffer_impl(project_state, buffer, project, cx);
+ }
+
+ fn get_or_init_project(
&mut self,
- buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- debounce: bool,
+ project: &Entity<Project>,
cx: &mut Context<Self>,
- );
- fn cycle(
+ ) -> &mut ProjectState {
+ let entity_id = project.entity_id();
+ let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
+ self.projects
+ .entry(entity_id)
+ .or_insert_with(|| ProjectState {
+ context: {
+ let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
+ cx.subscribe(
+ &related_excerpt_store,
+ move |this, _, event, _| match event {
+ RelatedExcerptStoreEvent::StartedRefresh => {
+ if let Some(debug_tx) = this.debug_tx.clone() {
+ debug_tx
+ .unbounded_send(DebugEvent::ContextRetrievalStarted(
+ ContextRetrievalStartedDebugEvent {
+ project_entity_id: entity_id,
+ timestamp: Instant::now(),
+ search_prompt: String::new(),
+ },
+ ))
+ .ok();
+ }
+ }
+ RelatedExcerptStoreEvent::FinishedRefresh {
+ cache_hit_count,
+ cache_miss_count,
+ mean_definition_latency,
+ max_definition_latency,
+ } => {
+ if let Some(debug_tx) = this.debug_tx.clone() {
+ debug_tx
+ .unbounded_send(DebugEvent::ContextRetrievalFinished(
+ ContextRetrievalFinishedDebugEvent {
+ project_entity_id: entity_id,
+ timestamp: Instant::now(),
+ metadata: vec![
+ (
+ "Cache Hits",
+ format!(
+ "{}/{}",
+ cache_hit_count,
+ cache_hit_count + cache_miss_count
+ )
+ .into(),
+ ),
+ (
+ "Max LSP Time",
+ format!(
+ "{} ms",
+ max_definition_latency.as_millis()
+ )
+ .into(),
+ ),
+ (
+ "Mean LSP Time",
+ format!(
+ "{} ms",
+ mean_definition_latency.as_millis()
+ )
+ .into(),
+ ),
+ ],
+ },
+ ))
+ .ok();
+ }
+ if let Some(project_state) = this.projects.get(&entity_id) {
+ project_state.context_updates_tx.send_blocking(()).ok();
+ }
+ }
+ },
+ )
+ .detach();
+ related_excerpt_store
+ },
+ events: VecDeque::new(),
+ last_event: None,
+ recent_paths: VecDeque::new(),
+ context_updates_rx,
+ context_updates_tx,
+ registered_buffers: HashMap::default(),
+ current_prediction: None,
+ cancelled_predictions: HashSet::default(),
+ pending_predictions: ArrayVec::new(),
+ next_pending_prediction_id: 0,
+ last_prediction_refresh: None,
+ license_detection_watchers: HashMap::default(),
+ _subscription: cx.subscribe(&project, Self::handle_project_event),
+ })
+ }
+
+ pub fn project_context_updates(
+ &self,
+ project: &Entity<Project>,
+ ) -> Option<smol::channel::Receiver<()>> {
+ let project_state = self.projects.get(&project.entity_id())?;
+ Some(project_state.context_updates_rx.clone())
+ }
+
+ fn handle_project_event(
&mut self,
- buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- direction: Direction,
+ project: Entity<Project>,
+ event: &project::Event,
cx: &mut Context<Self>,
- );
- fn accept(&mut self, cx: &mut Context<Self>);
- fn discard(&mut self, cx: &mut Context<Self>);
- fn did_show(&mut self, _cx: &mut Context<Self>) {}
- fn suggest(
+ ) {
+ // TODO [zeta2] init with recent paths
+ match event {
+ project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
+ let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+ let path = project.read(cx).path_for_entry(*active_entry_id, cx);
+ if let Some(path) = path {
+ if let Some(ix) = project_state
+ .recent_paths
+ .iter()
+ .position(|probe| probe == &path)
+ {
+ project_state.recent_paths.remove(ix);
+ }
+ project_state.recent_paths.push_front(path);
+ }
+ }
+ project::Event::DiagnosticsUpdated { .. } => {
+ if cx.has_flag::<Zeta2FeatureFlag>() {
+ self.refresh_prediction_from_diagnostics(project, cx);
+ }
+ }
+ _ => (),
+ }
+ }
+
+ fn register_buffer_impl<'a>(
+ project_state: &'a mut ProjectState,
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> &'a mut RegisteredBuffer {
+ let buffer_id = buffer.entity_id();
+
+ if let Some(file) = buffer.read(cx).file() {
+ let worktree_id = file.worktree_id(cx);
+ if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
+ project_state
+ .license_detection_watchers
+ .entry(worktree_id)
+ .or_insert_with(|| {
+ let project_entity_id = project.entity_id();
+ cx.observe_release(&worktree, move |this, _worktree, _cx| {
+ let Some(project_state) = this.projects.get_mut(&project_entity_id)
+ else {
+ return;
+ };
+ project_state
+ .license_detection_watchers
+ .remove(&worktree_id);
+ })
+ .detach();
+ Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
+ });
+ }
+ }
+
+ match project_state.registered_buffers.entry(buffer_id) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ let snapshot = buffer.read(cx).snapshot();
+ let project_entity_id = project.entity_id();
+ entry.insert(RegisteredBuffer {
+ snapshot,
+ _subscriptions: [
+ cx.subscribe(buffer, {
+ let project = project.downgrade();
+ move |this, buffer, event, cx| {
+ if let language::BufferEvent::Edited = event
+ && let Some(project) = project.upgrade()
+ {
+ this.report_changes_for_buffer(&buffer, &project, cx);
+ }
+ }
+ }),
+ cx.observe_release(buffer, move |this, _buffer, _cx| {
+ let Some(project_state) = this.projects.get_mut(&project_entity_id)
+ else {
+ return;
+ };
+ project_state.registered_buffers.remove(&buffer_id);
+ }),
+ ],
+ })
+ }
+ }
+ }
+
+ fn report_changes_for_buffer(
&mut self,
buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
+ project: &Entity<Project>,
cx: &mut Context<Self>,
- ) -> Option<EditPrediction>;
-}
+ ) {
+ let project_state = self.get_or_init_project(project, cx);
+ let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
+
+ let new_snapshot = buffer.read(cx).snapshot();
+ if new_snapshot.version == registered_buffer.snapshot.version {
+ return;
+ }
+
+ let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
+ let end_edit_anchor = new_snapshot
+ .anchored_edits_since::<Point>(&old_snapshot.version)
+ .last()
+ .map(|(_, range)| range.end);
+ let events = &mut project_state.events;
-pub trait EditPredictionProviderHandle {
- fn name(&self) -> &'static str;
- fn display_name(&self) -> &'static str;
- fn is_enabled(
+ if let Some(LastEvent {
+ new_snapshot: last_new_snapshot,
+ end_edit_anchor: last_end_edit_anchor,
+ ..
+ }) = project_state.last_event.as_mut()
+ {
+ let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
+ == last_new_snapshot.remote_id()
+ && old_snapshot.version == last_new_snapshot.version;
+
+ let should_coalesce = is_next_snapshot_of_same_buffer
+ && end_edit_anchor
+ .as_ref()
+ .zip(last_end_edit_anchor.as_ref())
+ .is_some_and(|(a, b)| {
+ let a = a.to_point(&new_snapshot);
+ let b = b.to_point(&new_snapshot);
+ a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
+ });
+
+ if should_coalesce {
+ *last_end_edit_anchor = end_edit_anchor;
+ *last_new_snapshot = new_snapshot;
+ return;
+ }
+ }
+
+ if events.len() + 1 >= EVENT_COUNT_MAX {
+ events.pop_front();
+ }
+
+ if let Some(event) = project_state.last_event.take() {
+ events.extend(event.finalize(&project_state.license_detection_watchers, cx));
+ }
+
+ project_state.last_event = Some(LastEvent {
+ old_snapshot,
+ new_snapshot,
+ end_edit_anchor,
+ });
+ }
+
+ fn current_prediction_for_buffer(
&self,
buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
+ project: &Entity<Project>,
cx: &App,
- ) -> bool;
- fn show_completions_in_menu(&self) -> bool;
- fn show_tab_accept_marker(&self) -> bool;
- fn supports_jump_to_edit(&self) -> bool;
- fn data_collection_state(&self, cx: &App) -> DataCollectionState;
- fn usage(&self, cx: &App) -> Option<EditPredictionUsage>;
- fn toggle_data_collection(&self, cx: &mut App);
- fn is_refreshing(&self, cx: &App) -> bool;
- fn refresh(
- &self,
- buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- debounce: bool,
- cx: &mut App,
- );
- fn cycle(
- &self,
- buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- direction: Direction,
- cx: &mut App,
- );
- fn did_show(&self, cx: &mut App);
- fn accept(&self, cx: &mut App);
- fn discard(&self, cx: &mut App);
- fn suggest(
- &self,
- buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
- cx: &mut App,
- ) -> Option<EditPrediction>;
-}
+ ) -> Option<BufferEditPrediction<'_>> {
+ let project_state = self.projects.get(&project.entity_id())?;
-impl<T> EditPredictionProviderHandle for Entity<T>
-where
- T: EditPredictionProvider,
-{
- fn name(&self) -> &'static str {
- T::name()
- }
+ let CurrentEditPrediction {
+ requested_by,
+ prediction,
+ ..
+ } = project_state.current_prediction.as_ref()?;
- fn display_name(&self) -> &'static str {
- T::display_name()
- }
+ if prediction.targets_buffer(buffer.read(cx)) {
+ Some(BufferEditPrediction::Local { prediction })
+ } else {
+ let show_jump = match requested_by {
+ PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
+ requested_by_buffer_id == &buffer.entity_id()
+ }
+ PredictionRequestedBy::DiagnosticsUpdate => true,
+ };
- fn show_completions_in_menu(&self) -> bool {
- T::show_completions_in_menu()
+ if show_jump {
+ Some(BufferEditPrediction::Jump { prediction })
+ } else {
+ None
+ }
+ }
}
- fn show_tab_accept_marker(&self) -> bool {
- T::show_tab_accept_marker()
- }
+ fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+ match self.edit_prediction_model {
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
+ EditPredictionModel::Sweep => return,
+ }
+
+ let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+
+ let Some(prediction) = project_state.current_prediction.take() else {
+ return;
+ };
+ let request_id = prediction.prediction.id.to_string();
+ for pending_prediction in mem::take(&mut project_state.pending_predictions) {
+ project_state.cancel_pending_prediction(pending_prediction, cx);
+ }
+
+ let client = self.client.clone();
+ let llm_token = self.llm_token.clone();
+ let app_version = AppVersion::global(cx);
+ cx.spawn(async move |this, cx| {
+ let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
+ http_client::Url::parse(&predict_edits_url)?
+ } else {
+ client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/accept", &[])?
+ };
+
+ let response = cx
+ .background_spawn(Self::send_api_request::<()>(
+ move |builder| {
+ let req = builder.uri(url.as_ref()).body(
+ serde_json::to_string(&AcceptEditPredictionBody {
+ request_id: request_id.clone(),
+ })?
+ .into(),
+ );
+ Ok(req?)
+ },
+ client,
+ llm_token,
+ app_version,
+ ))
+ .await;
- fn supports_jump_to_edit(&self) -> bool {
- T::supports_jump_to_edit()
+ Self::handle_api_response(&this, response, cx)?;
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
}
- fn data_collection_state(&self, cx: &App) -> DataCollectionState {
- self.read(cx).data_collection_state(cx)
+ async fn handle_rejected_predictions(
+ rx: UnboundedReceiver<EditPredictionRejection>,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: Version,
+ background_executor: BackgroundExecutor,
+ ) {
+ let mut rx = std::pin::pin!(rx.peekable());
+ let mut batched = Vec::new();
+
+ while let Some(rejection) = rx.next().await {
+ batched.push(rejection);
+
+ if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
+ select_biased! {
+ next = rx.as_mut().peek().fuse() => {
+ if next.is_some() {
+ continue;
+ }
+ }
+ () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
+ }
+ }
+
+ let url = client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/reject", &[])
+ .unwrap();
+
+ let flush_count = batched
+ .len()
+ // in case items have accumulated after failure
+ .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
+ let start = batched.len() - flush_count;
+
+ let body = RejectEditPredictionsBodyRef {
+ rejections: &batched[start..],
+ };
+
+ let result = Self::send_api_request::<()>(
+ |builder| {
+ let req = builder
+ .uri(url.as_ref())
+ .body(serde_json::to_string(&body)?.into());
+ anyhow::Ok(req?)
+ },
+ client.clone(),
+ llm_token.clone(),
+ app_version.clone(),
+ )
+ .await;
+
+ if result.log_err().is_some() {
+ batched.drain(start..);
+ }
+ }
}
- fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- self.read(cx).usage(cx)
+ fn reject_current_prediction(
+ &mut self,
+ reason: EditPredictionRejectReason,
+ project: &Entity<Project>,
+ ) {
+ if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+ project_state.pending_predictions.clear();
+ if let Some(prediction) = project_state.current_prediction.take() {
+ self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
+ }
+ };
}
- fn toggle_data_collection(&self, cx: &mut App) {
- self.update(cx, |this, cx| this.toggle_data_collection(cx))
+ fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
+ if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+ if let Some(current_prediction) = project_state.current_prediction.as_mut() {
+ if !current_prediction.was_shown {
+ current_prediction.was_shown = true;
+ self.shown_predictions
+ .push_front(current_prediction.prediction.clone());
+ if self.shown_predictions.len() > 50 {
+ let completion = self.shown_predictions.pop_back().unwrap();
+ self.rated_predictions.remove(&completion.id);
+ }
+ }
+ }
+ }
}
- fn is_enabled(
- &self,
- buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
- cx: &App,
- ) -> bool {
- self.read(cx).is_enabled(buffer, cursor_position, cx)
+ fn reject_prediction(
+ &mut self,
+ prediction_id: EditPredictionId,
+ reason: EditPredictionRejectReason,
+ was_shown: bool,
+ ) {
+ match self.edit_prediction_model {
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
+ EditPredictionModel::Sweep => return,
+ }
+
+ self.reject_predictions_tx
+ .unbounded_send(EditPredictionRejection {
+ request_id: prediction_id.to_string(),
+ reason,
+ was_shown,
+ })
+ .log_err();
}
- fn is_refreshing(&self, cx: &App) -> bool {
- self.read(cx).is_refreshing(cx)
+ fn is_refreshing(&self, project: &Entity<Project>) -> bool {
+ self.projects
+ .get(&project.entity_id())
+ .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
}
- fn refresh(
- &self,
+ pub fn refresh_prediction_from_buffer(
+ &mut self,
+ project: Entity<Project>,
buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- debounce: bool,
- cx: &mut App,
+ position: language::Anchor,
+ cx: &mut Context<Self>,
) {
- self.update(cx, |this, cx| {
- this.refresh(buffer, cursor_position, debounce, cx)
+ self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
+ let Some(request_task) = this
+ .update(cx, |this, cx| {
+ this.request_prediction(
+ &project,
+ &buffer,
+ position,
+ PredictEditsRequestTrigger::Other,
+ cx,
+ )
+ })
+ .log_err()
+ else {
+ return Task::ready(anyhow::Ok(None));
+ };
+
+ cx.spawn(async move |_cx| {
+ request_task.await.map(|prediction_result| {
+ prediction_result.map(|prediction_result| {
+ (
+ prediction_result,
+ PredictionRequestedBy::Buffer(buffer.entity_id()),
+ )
+ })
+ })
+ })
})
}
- fn cycle(
- &self,
- buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- direction: Direction,
- cx: &mut App,
+ pub fn refresh_prediction_from_diagnostics(
+ &mut self,
+ project: Entity<Project>,
+ cx: &mut Context<Self>,
) {
- self.update(cx, |this, cx| {
- this.cycle(buffer, cursor_position, direction, cx)
+ let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+
+ // Prefer predictions from buffer
+ if project_state.current_prediction.is_some() {
+ return;
+ };
+
+ self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
+ let Some(open_buffer_task) = project
+ .update(cx, |project, cx| {
+ project
+ .active_entry()
+ .and_then(|entry| project.path_for_entry(entry, cx))
+ .map(|path| project.open_buffer(path, cx))
+ })
+ .log_err()
+ .flatten()
+ else {
+ return Task::ready(anyhow::Ok(None));
+ };
+
+ cx.spawn(async move |cx| {
+ let active_buffer = open_buffer_task.await?;
+ let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+ let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+ active_buffer,
+ &snapshot,
+ Default::default(),
+ Default::default(),
+ &project,
+ cx,
+ )
+ .await?
+ else {
+ return anyhow::Ok(None);
+ };
+
+ let Some(prediction_result) = this
+ .update(cx, |this, cx| {
+ this.request_prediction(
+ &project,
+ &jump_buffer,
+ jump_position,
+ PredictEditsRequestTrigger::Diagnostics,
+ cx,
+ )
+ })?
+ .await?
+ else {
+ return anyhow::Ok(None);
+ };
+
+ this.update(cx, |this, cx| {
+ Some((
+ if this
+ .get_or_init_project(&project, cx)
+ .current_prediction
+ .is_none()
+ {
+ prediction_result
+ } else {
+ EditPredictionResult {
+ id: prediction_result.id,
+ prediction: Err(EditPredictionRejectReason::CurrentPreferred),
+ }
+ },
+ PredictionRequestedBy::DiagnosticsUpdate,
+ ))
+ })
+ })
+ });
+ }
+
+ #[cfg(not(test))]
+ pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
+ #[cfg(test)]
+ pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
+
+ fn queue_prediction_refresh(
+ &mut self,
+ project: Entity<Project>,
+ throttle_entity: EntityId,
+ cx: &mut Context<Self>,
+ do_refresh: impl FnOnce(
+ WeakEntity<Self>,
+ &mut AsyncApp,
+ )
+ -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
+ + 'static,
+ ) {
+ let project_state = self.get_or_init_project(&project, cx);
+ let pending_prediction_id = project_state.next_pending_prediction_id;
+ project_state.next_pending_prediction_id += 1;
+ let last_request = project_state.last_prediction_refresh;
+
+ let task = cx.spawn(async move |this, cx| {
+ if let Some((last_entity, last_timestamp)) = last_request
+ && throttle_entity == last_entity
+ && let Some(timeout) =
+ (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
+ {
+ cx.background_executor().timer(timeout).await;
+ }
+
+ // If this task was cancelled before the throttle timeout expired,
+ // do not perform a request.
+ let mut is_cancelled = true;
+ this.update(cx, |this, cx| {
+ let project_state = this.get_or_init_project(&project, cx);
+ if !project_state
+ .cancelled_predictions
+ .remove(&pending_prediction_id)
+ {
+ project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
+ is_cancelled = false;
+ }
+ })
+ .ok();
+ if is_cancelled {
+ return None;
+ }
+
+ let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
+ let new_prediction_id = new_prediction_result
+ .as_ref()
+ .map(|(prediction, _)| prediction.id.clone());
+
+ // When a prediction completes, remove it from the pending list, and cancel
+ // any pending predictions that were enqueued before it.
+ this.update(cx, |this, cx| {
+ let project_state = this.get_or_init_project(&project, cx);
+
+ let is_cancelled = project_state
+ .cancelled_predictions
+ .remove(&pending_prediction_id);
+
+ let new_current_prediction = if !is_cancelled
+ && let Some((prediction_result, requested_by)) = new_prediction_result
+ {
+ match prediction_result.prediction {
+ Ok(prediction) => {
+ let new_prediction = CurrentEditPrediction {
+ requested_by,
+ prediction,
+ was_shown: false,
+ };
+
+ if let Some(current_prediction) =
+ project_state.current_prediction.as_ref()
+ {
+ if new_prediction.should_replace_prediction(¤t_prediction, cx)
+ {
+ this.reject_current_prediction(
+ EditPredictionRejectReason::Replaced,
+ &project,
+ );
+
+ Some(new_prediction)
+ } else {
+ this.reject_prediction(
+ new_prediction.prediction.id,
+ EditPredictionRejectReason::CurrentPreferred,
+ false,
+ );
+ None
+ }
+ } else {
+ Some(new_prediction)
+ }
+ }
+ Err(reject_reason) => {
+ this.reject_prediction(prediction_result.id, reject_reason, false);
+ None
+ }
+ }
+ } else {
+ None
+ };
+
+ let project_state = this.get_or_init_project(&project, cx);
+
+ if let Some(new_prediction) = new_current_prediction {
+ project_state.current_prediction = Some(new_prediction);
+ }
+
+ let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
+ for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
+ if pending_prediction.id == pending_prediction_id {
+ pending_predictions.remove(ix);
+ for pending_prediction in pending_predictions.drain(0..ix) {
+ project_state.cancel_pending_prediction(pending_prediction, cx)
+ }
+ break;
+ }
+ }
+ this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
+ cx.notify();
+ })
+ .ok();
+
+ new_prediction_id
+ });
+
+ if project_state.pending_predictions.len() <= 1 {
+ project_state.pending_predictions.push(PendingPrediction {
+ id: pending_prediction_id,
+ task,
+ });
+ } else if project_state.pending_predictions.len() == 2 {
+ let pending_prediction = project_state.pending_predictions.pop().unwrap();
+ project_state.pending_predictions.push(PendingPrediction {
+ id: pending_prediction_id,
+ task,
+ });
+ project_state.cancel_pending_prediction(pending_prediction, cx);
+ }
+ }
+
+ pub fn request_prediction(
+ &mut self,
+ project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ position: language::Anchor,
+ trigger: PredictEditsRequestTrigger,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<Option<EditPredictionResult>>> {
+ self.request_prediction_internal(
+ project.clone(),
+ active_buffer.clone(),
+ position,
+ trigger,
+ cx.has_flag::<Zeta2FeatureFlag>(),
+ cx,
+ )
+ }
+
+ fn request_prediction_internal(
+ &mut self,
+ project: Entity<Project>,
+ active_buffer: Entity<Buffer>,
+ position: language::Anchor,
+ trigger: PredictEditsRequestTrigger,
+ allow_jump: bool,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<Option<EditPredictionResult>>> {
+ const DIAGNOSTIC_LINES_RANGE: u32 = 20;
+
+ self.get_or_init_project(&project, cx);
+ let project_state = self.projects.get(&project.entity_id()).unwrap();
+ let events = project_state.events(cx);
+ let has_events = !events.is_empty();
+
+ let snapshot = active_buffer.read(cx).snapshot();
+ let cursor_point = position.to_point(&snapshot);
+ let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
+ let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
+ let diagnostic_search_range =
+ Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
+
+ let related_files = if self.use_context {
+ self.context_for_project(&project, cx).to_vec()
+ } else {
+ Vec::new()
+ };
+
+ let task = match self.edit_prediction_model {
+ EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
+ self,
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ trigger,
+ cx,
+ ),
+ EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
+ self,
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ related_files,
+ trigger,
+ cx,
+ ),
+ EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ &project_state.recent_paths,
+ related_files,
+ diagnostic_search_range.clone(),
+ cx,
+ ),
+ };
+
+ cx.spawn(async move |this, cx| {
+ let prediction = task.await?;
+
+ if prediction.is_none() && allow_jump {
+ let cursor_point = position.to_point(&snapshot);
+ if has_events
+ && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+ active_buffer.clone(),
+ &snapshot,
+ diagnostic_search_range,
+ cursor_point,
+ &project,
+ cx,
+ )
+ .await?
+ {
+ return this
+ .update(cx, |this, cx| {
+ this.request_prediction_internal(
+ project,
+ jump_buffer,
+ jump_position,
+ trigger,
+ false,
+ cx,
+ )
+ })?
+ .await;
+ }
+
+ return anyhow::Ok(None);
+ }
+
+ Ok(prediction)
})
}
- fn accept(&self, cx: &mut App) {
- self.update(cx, |this, cx| this.accept(cx))
+ async fn next_diagnostic_location(
+ active_buffer: Entity<Buffer>,
+ active_buffer_snapshot: &BufferSnapshot,
+ active_buffer_diagnostic_search_range: Range<Point>,
+ active_buffer_cursor_point: Point,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+ ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
+ // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
+ let mut jump_location = active_buffer_snapshot
+ .diagnostic_groups(None)
+ .into_iter()
+ .filter_map(|(_, group)| {
+ let range = &group.entries[group.primary_ix]
+ .range
+ .to_point(&active_buffer_snapshot);
+ if range.overlaps(&active_buffer_diagnostic_search_range) {
+ None
+ } else {
+ Some(range.start)
+ }
+ })
+ .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
+ .map(|position| {
+ (
+ active_buffer.clone(),
+ active_buffer_snapshot.anchor_before(position),
+ )
+ });
+
+ if jump_location.is_none() {
+ let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
+ let file = buffer.file()?;
+
+ Some(ProjectPath {
+ worktree_id: file.worktree_id(cx),
+ path: file.path().clone(),
+ })
+ })?;
+
+ let buffer_task = project.update(cx, |project, cx| {
+ let (path, _, _) = project
+ .diagnostic_summaries(false, cx)
+ .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
+ .max_by_key(|(path, _, _)| {
+ // find the buffer with errors that shares most parent directories
+ path.path
+ .components()
+ .zip(
+ active_buffer_path
+ .as_ref()
+ .map(|p| p.path.components())
+ .unwrap_or_default(),
+ )
+ .take_while(|(a, b)| a == b)
+ .count()
+ })?;
+
+ Some(project.open_buffer(path, cx))
+ })?;
+
+ if let Some(buffer_task) = buffer_task {
+ let closest_buffer = buffer_task.await?;
+
+ jump_location = closest_buffer
+ .read_with(cx, |buffer, _cx| {
+ buffer
+ .buffer_diagnostics(None)
+ .into_iter()
+ .min_by_key(|entry| entry.diagnostic.severity)
+ .map(|entry| entry.range.start)
+ })?
+ .map(|position| (closest_buffer, position));
+ }
+ }
+
+ anyhow::Ok(jump_location)
}
- fn discard(&self, cx: &mut App) {
- self.update(cx, |this, cx| this.discard(cx))
+ async fn send_raw_llm_request(
+ request: open_ai::Request,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: Version,
+ #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
+ #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
+ ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
+ let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
+ http_client::Url::parse(&predict_edits_url)?
+ } else {
+ client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/raw", &[])?
+ };
+
+ #[cfg(feature = "eval-support")]
+ let cache_key = if let Some(cache) = eval_cache {
+ use collections::FxHasher;
+ use std::hash::{Hash, Hasher};
+
+ let mut hasher = FxHasher::default();
+ url.hash(&mut hasher);
+ let request_str = serde_json::to_string_pretty(&request)?;
+ request_str.hash(&mut hasher);
+ let hash = hasher.finish();
+
+ let key = (eval_cache_kind, hash);
+ if let Some(response_str) = cache.read(key) {
+ return Ok((serde_json::from_str(&response_str)?, None));
+ }
+
+ Some((cache, request_str, key))
+ } else {
+ None
+ };
+
+ let (response, usage) = Self::send_api_request(
+ |builder| {
+ let req = builder
+ .uri(url.as_ref())
+ .body(serde_json::to_string(&request)?.into());
+ Ok(req?)
+ },
+ client,
+ llm_token,
+ app_version,
+ )
+ .await?;
+
+ #[cfg(feature = "eval-support")]
+ if let Some((cache, request, key)) = cache_key {
+ cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
+ }
+
+ Ok((response, usage))
}
- fn did_show(&self, cx: &mut App) {
- self.update(cx, |this, cx| this.did_show(cx))
+ fn handle_api_response<T>(
+ this: &WeakEntity<Self>,
+ response: Result<(T, Option<EditPredictionUsage>)>,
+ cx: &mut gpui::AsyncApp,
+ ) -> Result<T> {
+ match response {
+ Ok((data, usage)) => {
+ if let Some(usage) = usage {
+ this.update(cx, |this, cx| {
+ this.user_store.update(cx, |user_store, cx| {
+ user_store.update_edit_prediction_usage(usage, cx);
+ });
+ })
+ .ok();
+ }
+ Ok(data)
+ }
+ Err(err) => {
+ if err.is::<ZedUpdateRequiredError>() {
+ cx.update(|cx| {
+ this.update(cx, |this, _cx| {
+ this.update_required = true;
+ })
+ .ok();
+
+ let error_message: SharedString = err.to_string().into();
+ show_app_notification(
+ NotificationId::unique::<ZedUpdateRequiredError>(),
+ cx,
+ move |cx| {
+ cx.new(|cx| {
+ ErrorMessagePrompt::new(error_message.clone(), cx)
+ .with_link_button("Update Zed", "https://zed.dev/releases")
+ })
+ },
+ );
+ })
+ .ok();
+ }
+ Err(err)
+ }
+ }
}
- fn suggest(
- &self,
- buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
- cx: &mut App,
- ) -> Option<EditPrediction> {
- self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx))
- }
-}
-
-/// Returns edits updated based on user edits since the old snapshot. None is returned if any user
-/// edit is not a prefix of a predicted insertion.
-pub fn interpolate_edits(
- old_snapshot: &BufferSnapshot,
- new_snapshot: &BufferSnapshot,
- current_edits: &[(Range<Anchor>, Arc<str>)],
-) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
- let mut edits = Vec::new();
-
- let mut model_edits = current_edits.iter().peekable();
- for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
- while let Some((model_old_range, _)) = model_edits.peek() {
- let model_old_range = model_old_range.to_offset(old_snapshot);
- if model_old_range.end < user_edit.old.start {
- let (model_old_range, model_new_text) = model_edits.next().unwrap();
- edits.push((model_old_range.clone(), model_new_text.clone()));
+ async fn send_api_request<Res>(
+ build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: Version,
+ ) -> Result<(Res, Option<EditPredictionUsage>)>
+ where
+ Res: DeserializeOwned,
+ {
+ let http_client = client.http_client();
+ let mut token = llm_token.acquire(&client).await?;
+ let mut did_retry = false;
+
+ loop {
+ let request_builder = http_client::Request::builder().method(Method::POST);
+
+ let request = build(
+ request_builder
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", token))
+ .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
+ )?;
+
+ let mut response = http_client.send(request).await?;
+
+ if let Some(minimum_required_version) = response
+ .headers()
+ .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
+ .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
+ {
+ anyhow::ensure!(
+ app_version >= minimum_required_version,
+ ZedUpdateRequiredError {
+ minimum_version: minimum_required_version
+ }
+ );
+ }
+
+ if response.status().is_success() {
+ let usage = EditPredictionUsage::from_headers(response.headers()).ok();
+
+ let mut body = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+ return Ok((serde_json::from_slice(&body)?, usage));
+ } else if !did_retry
+ && response
+ .headers()
+ .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+ .is_some()
+ {
+ did_retry = true;
+ token = llm_token.refresh(&client).await?;
} else {
- break;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ anyhow::bail!(
+ "Request failed with status: {:?}\nBody: {}",
+ response.status(),
+ body
+ );
}
}
+ }
- if let Some((model_old_range, model_new_text)) = model_edits.peek() {
- let model_old_offset_range = model_old_range.to_offset(old_snapshot);
- if user_edit.old == model_old_offset_range {
- let user_new_text = new_snapshot
- .text_for_range(user_edit.new.clone())
- .collect::<String>();
+ pub fn refresh_context(
+ &mut self,
+ project: &Entity<Project>,
+ buffer: &Entity<language::Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) {
+ if self.use_context {
+ self.get_or_init_project(project, cx)
+ .context
+ .update(cx, |store, cx| {
+ store.refresh(buffer.clone(), cursor_position, cx);
+ });
+ }
+ }
- if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
- if !model_suffix.is_empty() {
- let anchor = old_snapshot.anchor_after(user_edit.old.end);
- edits.push((anchor..anchor, model_suffix.into()));
- }
+ fn is_file_open_source(
+ &self,
+ project: &Entity<Project>,
+ file: &Arc<dyn File>,
+ cx: &App,
+ ) -> bool {
+ if !file.is_local() || file.is_private() {
+ return false;
+ }
+ let Some(project_state) = self.projects.get(&project.entity_id()) else {
+ return false;
+ };
+ project_state
+ .license_detection_watchers
+ .get(&file.worktree_id(cx))
+ .as_ref()
+ .is_some_and(|watcher| watcher.is_project_open_source())
+ }
- model_edits.next();
- continue;
+ fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
+ self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
+ }
+
+ fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
+ if !self.data_collection_choice.is_enabled() {
+ return false;
+ }
+ events.iter().all(|event| {
+ matches!(
+ event.as_ref(),
+ Event::BufferChange {
+ in_open_source_repo: true,
+ ..
}
+ )
+ })
+ }
+
+ fn load_data_collection_choice() -> DataCollectionChoice {
+ let choice = KEY_VALUE_STORE
+ .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
+ .log_err()
+ .flatten();
+
+ match choice.as_deref() {
+ Some("true") => DataCollectionChoice::Enabled,
+ Some("false") => DataCollectionChoice::Disabled,
+ Some(_) => {
+ log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
+ DataCollectionChoice::NotAnswered
}
+ None => DataCollectionChoice::NotAnswered,
+ }
+ }
+
+ fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
+ self.data_collection_choice = self.data_collection_choice.toggle();
+ let new_choice = self.data_collection_choice;
+ db::write_and_log(cx, move || {
+ KEY_VALUE_STORE.write_kvp(
+ ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
+ new_choice.is_enabled().to_string(),
+ )
+ });
+ }
+
+ pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
+ self.shown_predictions.iter()
+ }
+
+ pub fn shown_completions_len(&self) -> usize {
+ self.shown_predictions.len()
+ }
+
+ pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
+ self.rated_predictions.contains(id)
+ }
+
+ pub fn rate_prediction(
+ &mut self,
+ prediction: &EditPrediction,
+ rating: EditPredictionRating,
+ feedback: String,
+ cx: &mut Context<Self>,
+ ) {
+ self.rated_predictions.insert(prediction.id.clone());
+ telemetry::event!(
+ "Edit Prediction Rated",
+ rating,
+ inputs = prediction.inputs,
+ output = prediction.edit_preview.as_unified_diff(&prediction.edits),
+ feedback
+ );
+ self.client.telemetry().flush_events().detach();
+ cx.notify();
+ }
+
+ fn enable_or_disable_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
+ self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
+ && all_language_settings(None, cx).edit_predictions.use_context;
+ }
+}
+
+#[derive(Error, Debug)]
+#[error(
+ "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
+)]
+pub struct ZedUpdateRequiredError {
+ minimum_version: Version,
+}
+
+#[cfg(feature = "eval-support")]
+pub type EvalCacheKey = (EvalCacheEntryKind, u64);
+
+#[cfg(feature = "eval-support")]
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum EvalCacheEntryKind {
+ Context,
+ Search,
+ Prediction,
+}
+
+#[cfg(feature = "eval-support")]
+impl std::fmt::Display for EvalCacheEntryKind {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ EvalCacheEntryKind::Search => write!(f, "search"),
+ EvalCacheEntryKind::Context => write!(f, "context"),
+ EvalCacheEntryKind::Prediction => write!(f, "prediction"),
+ }
+ }
+}
+
+#[cfg(feature = "eval-support")]
+pub trait EvalCache: Send + Sync {
+ fn read(&self, key: EvalCacheKey) -> Option<String>;
+ fn write(&self, key: EvalCacheKey, input: &str, value: &str);
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum DataCollectionChoice {
+ NotAnswered,
+ Enabled,
+ Disabled,
+}
+
+impl DataCollectionChoice {
+ pub fn is_enabled(self) -> bool {
+ match self {
+ Self::Enabled => true,
+ Self::NotAnswered | Self::Disabled => false,
}
+ }
- return None;
+ pub fn is_answered(self) -> bool {
+ match self {
+ Self::Enabled | Self::Disabled => true,
+ Self::NotAnswered => false,
+ }
}
- edits.extend(model_edits.cloned());
+ #[must_use]
+ pub fn toggle(&self) -> DataCollectionChoice {
+ match self {
+ Self::Enabled => Self::Disabled,
+ Self::Disabled => Self::Enabled,
+ Self::NotAnswered => Self::Enabled,
+ }
+ }
+}
+
+impl From<bool> for DataCollectionChoice {
+ fn from(value: bool) -> Self {
+ match value {
+ true => DataCollectionChoice::Enabled,
+ false => DataCollectionChoice::Disabled,
+ }
+ }
+}
+
+struct ZedPredictUpsell;
+
+impl Dismissable for ZedPredictUpsell {
+ const KEY: &'static str = "dismissed-edit-predict-upsell";
+
+ fn dismissed() -> bool {
+ // To make this backwards compatible with older versions of Zed, we
+ // check if the user has seen the previous Edit Prediction Onboarding
+ // before, by checking the data collection choice which was written to
+ // the database once the user clicked on "Accept and Enable"
+ if KEY_VALUE_STORE
+ .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
+ .log_err()
+ .is_some_and(|s| s.is_some())
+ {
+ return true;
+ }
+
+ KEY_VALUE_STORE
+ .read_kvp(Self::KEY)
+ .log_err()
+ .is_some_and(|s| s.is_some())
+ }
+}
+
+pub fn should_show_upsell_modal() -> bool {
+ !ZedPredictUpsell::dismissed()
+}
+
+pub fn init(cx: &mut App) {
+ cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
+ workspace.register_action(
+ move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
+ ZedPredictModal::toggle(
+ workspace,
+ workspace.user_store().clone(),
+ workspace.client().clone(),
+ window,
+ cx,
+ )
+ },
+ );
- if edits.is_empty() { None } else { Some(edits) }
+ workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
+ update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
+ settings
+ .project
+ .all_languages
+ .features
+ .get_or_insert_default()
+ .edit_prediction_provider = Some(EditPredictionProvider::None)
+ });
+ });
+ })
+ .detach();
}
@@ -0,0 +1,1806 @@
+use super::*;
+use crate::zeta1::MAX_EVENT_TOKENS;
+use client::{UserStore, test::FakeServer};
+use clock::{FakeSystemClock, ReplicaId};
+use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
+use cloud_llm_client::{
+ EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
+ RejectEditPredictionsBody,
+};
+use edit_prediction_context::Line;
+use futures::{
+ AsyncReadExt, StreamExt,
+ channel::{mpsc, oneshot},
+};
+use gpui::{
+ Entity, TestAppContext,
+ http_client::{FakeHttpClient, Response},
+};
+use indoc::indoc;
+use language::{Point, ToOffset as _};
+use lsp::LanguageServerId;
+use open_ai::Usage;
+use parking_lot::Mutex;
+use pretty_assertions::{assert_eq, assert_matches};
+use project::{FakeFs, Project};
+use serde_json::json;
+use settings::SettingsStore;
+use std::{path::Path, sync::Arc, time::Duration};
+use util::{path, rel_path::rel_path};
+use uuid::Uuid;
+
+use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
+
+#[gpui::test]
+async fn test_current_state(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "1.txt": "Hello!\nHow\nBye\n",
+ "2.txt": "Hola!\nComo\nAdios\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_project(&project, cx);
+ });
+
+ let buffer1 = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
+ project.set_active_path(Some(path.clone()), cx);
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot1.anchor_before(language::Point::new(1, 3));
+
+ // Prediction for current file
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
+ });
+ let (_request, respond_tx) = requests.predict.next().await.unwrap();
+
+ respond_tx
+ .send(model_response(indoc! {r"
+ --- a/root/1.txt
+ +++ b/root/1.txt
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "}))
+ .unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ let prediction = ep_store
+ .current_prediction_for_buffer(&buffer1, &project, cx)
+ .unwrap();
+ assert_matches!(prediction, BufferEditPrediction::Local { .. });
+ });
+
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
+ });
+
+ // Prediction for diagnostic in another file
+
+ let diagnostic = lsp::Diagnostic {
+ range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+ severity: Some(lsp::DiagnosticSeverity::ERROR),
+ message: "Sentence is incomplete".to_string(),
+ ..Default::default()
+ };
+
+ project.update(cx, |project, cx| {
+ project.lsp_store().update(cx, |lsp_store, cx| {
+ lsp_store
+ .update_diagnostics(
+ LanguageServerId(0),
+ lsp::PublishDiagnosticsParams {
+ uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
+ diagnostics: vec![diagnostic],
+ version: None,
+ },
+ None,
+ language::DiagnosticSourceKind::Pushed,
+ &[],
+ cx,
+ )
+ .unwrap();
+ });
+ });
+
+ let (_request, respond_tx) = requests.predict.next().await.unwrap();
+ respond_tx
+ .send(model_response(indoc! {r#"
+ --- a/root/2.txt
+ +++ b/root/2.txt
+ Hola!
+ -Como
+ +Como estas?
+ Adios
+ "#}))
+ .unwrap();
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ let prediction = ep_store
+ .current_prediction_for_buffer(&buffer1, &project, cx)
+ .unwrap();
+ assert_matches!(
+ prediction,
+ BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
+ );
+ });
+
+ let buffer2 = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ let prediction = ep_store
+ .current_prediction_for_buffer(&buffer2, &project, cx)
+ .unwrap();
+ assert_matches!(prediction, BufferEditPrediction::Local { .. });
+ });
+}
+
+#[gpui::test]
+async fn test_simple_request(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ let prediction_task = ep_store.update(cx, |ep_store, cx| {
+ ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+
+ // TODO Put back when we have a structured request again
+ // assert_eq!(
+ // request.excerpt_path.as_ref(),
+ // Path::new(path!("root/foo.md"))
+ // );
+ // assert_eq!(
+ // request.cursor_point,
+ // Point {
+ // line: Line(1),
+ // column: 3
+ // }
+ // );
+
+ respond_tx
+ .send(model_response(indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "}))
+ .unwrap();
+
+ let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
+
+ assert_eq!(prediction.edits.len(), 1);
+ assert_eq!(
+ prediction.edits[0].0.to_point(&snapshot).start,
+ language::Point::new(1, 3)
+ );
+ assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
+}
+
+#[gpui::test]
+async fn test_request_events(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\n\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&buffer, &project, cx);
+ });
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(vec![(7..7, "How")], None, cx);
+ });
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ let prediction_task = ep_store.update(cx, |ep_store, cx| {
+ ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
+ });
+
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+
+ let prompt = prompt_from_request(&request);
+ assert!(
+ prompt.contains(indoc! {"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ -1,3 +1,3 @@
+ Hello!
+ -
+ +How
+ Bye
+ "}),
+ "{prompt}"
+ );
+
+ respond_tx
+ .send(model_response(indoc! {r#"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "#}))
+ .unwrap();
+
+ let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
+
+ assert_eq!(prediction.edits.len(), 1);
+ assert_eq!(
+ prediction.edits[0].0.to_point(&snapshot).start,
+ language::Point::new(1, 3)
+ );
+ assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
+}
+
+#[gpui::test]
+async fn test_empty_prediction(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ const NO_OP_DIFF: &str = indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How
+ Bye
+ "};
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let response = model_response(NO_OP_DIFF);
+ let id = response.id.clone();
+ respond_tx.send(response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ assert!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .is_none()
+ );
+ });
+
+ // prediction is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: id,
+ reason: EditPredictionRejectReason::Empty,
+ was_shown: false
+ }]
+ );
+}
+
+#[gpui::test]
+async fn test_interpolated_empty(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.set_text("Hello!\nHow are you?\nBye", cx);
+ });
+
+ let response = model_response(SIMPLE_DIFF);
+ let id = response.id.clone();
+ respond_tx.send(response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ assert!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .is_none()
+ );
+ });
+
+ // prediction is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: id,
+ reason: EditPredictionRejectReason::InterpolatedEmpty,
+ was_shown: false
+ }]
+ );
+}
+
+const SIMPLE_DIFF: &str = indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+"};
+
+#[gpui::test]
+async fn test_replace_current(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_tx.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ // a second request is triggered
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let second_response = model_response(SIMPLE_DIFF);
+ let second_id = second_response.id.clone();
+ respond_tx.send(second_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // second replaces first
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ second_id
+ );
+ });
+
+ // first is reported as replaced
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: first_id,
+ reason: EditPredictionRejectReason::Replaced,
+ was_shown: false
+ }]
+ );
+}
+
+#[gpui::test]
+async fn test_current_preferred(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_tx.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ // a second request is triggered
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ // worse than current prediction
+ let second_response = model_response(indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are
+ Bye
+ "});
+ let second_id = second_response.id.clone();
+ respond_tx.send(second_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // first is preferred over second
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ // second is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: second_id,
+ reason: EditPredictionRejectReason::CurrentPreferred,
+ was_shown: false
+ }]
+ );
+}
+
+#[gpui::test]
+async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ // start two refresh tasks
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_first) = requests.predict.next().await.unwrap();
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_second) = requests.predict.next().await.unwrap();
+
+ // wait for throttle
+ cx.run_until_parked();
+
+ // second responds first
+ let second_response = model_response(SIMPLE_DIFF);
+ let second_id = second_response.id.clone();
+ respond_second.send(second_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // current prediction is second
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ second_id
+ );
+ });
+
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_first.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // current prediction is still second, since first was cancelled
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ second_id
+ );
+ });
+
+ // first is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ cx.run_until_parked();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: first_id,
+ reason: EditPredictionRejectReason::Canceled,
+ was_shown: false
+ }]
+ );
+}
+
+#[gpui::test]
+async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ // start two refresh tasks
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_first) = requests.predict.next().await.unwrap();
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_second) = requests.predict.next().await.unwrap();
+
+ // wait for throttle, so requests are sent
+ cx.run_until_parked();
+
+ ep_store.update(cx, |ep_store, cx| {
+ // start a third request
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+
+ // 2 are pending, so 2nd is cancelled
+ assert_eq!(
+ ep_store
+ .get_or_init_project(&project, cx)
+ .cancelled_predictions
+ .iter()
+ .copied()
+ .collect::<Vec<_>>(),
+ [1]
+ );
+ });
+
+ // wait for throttle
+ cx.run_until_parked();
+
+ let (_, respond_third) = requests.predict.next().await.unwrap();
+
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_first.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // current prediction is first
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ let cancelled_response = model_response(SIMPLE_DIFF);
+ let cancelled_id = cancelled_response.id.clone();
+ respond_second.send(cancelled_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // current prediction is still first, since second was cancelled
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ let third_response = model_response(SIMPLE_DIFF);
+ let third_response_id = third_response.id.clone();
+ respond_third.send(third_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // third completes and replaces first
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ third_response_id
+ );
+ });
+
+ // second is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ cx.run_until_parked();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[
+ EditPredictionRejection {
+ request_id: cancelled_id,
+ reason: EditPredictionRejectReason::Canceled,
+ was_shown: false
+ },
+ EditPredictionRejection {
+ request_id: first_id,
+ reason: EditPredictionRejectReason::Replaced,
+ was_shown: false
+ }
+ ]
+ );
+}
+
+#[gpui::test]
+async fn test_rejections_flushing(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.reject_prediction(
+ EditPredictionId("test-1".into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ ep_store.reject_prediction(
+ EditPredictionId("test-2".into()),
+ EditPredictionRejectReason::Canceled,
+ true,
+ );
+ });
+
+ cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+ cx.run_until_parked();
+
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ // batched
+ assert_eq!(reject_request.rejections.len(), 2);
+ assert_eq!(
+ reject_request.rejections[0],
+ EditPredictionRejection {
+ request_id: "test-1".to_string(),
+ reason: EditPredictionRejectReason::Discarded,
+ was_shown: false
+ }
+ );
+ assert_eq!(
+ reject_request.rejections[1],
+ EditPredictionRejection {
+ request_id: "test-2".to_string(),
+ reason: EditPredictionRejectReason::Canceled,
+ was_shown: true
+ }
+ );
+
+ // Reaching batch size limit sends without debounce
+ ep_store.update(cx, |ep_store, _cx| {
+ for i in 0..70 {
+ ep_store.reject_prediction(
+ EditPredictionId(format!("batch-{}", i).into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ }
+ });
+
+ // First MAX/2 items are sent immediately
+ cx.run_until_parked();
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ assert_eq!(reject_request.rejections.len(), 50);
+ assert_eq!(reject_request.rejections[0].request_id, "batch-0");
+ assert_eq!(reject_request.rejections[49].request_id, "batch-49");
+
+ // Remaining items are debounced with the next batch
+ cx.executor().advance_clock(Duration::from_secs(15));
+ cx.run_until_parked();
+
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ assert_eq!(reject_request.rejections.len(), 20);
+ assert_eq!(reject_request.rejections[0].request_id, "batch-50");
+ assert_eq!(reject_request.rejections[19].request_id, "batch-69");
+
+ // Request failure
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.reject_prediction(
+ EditPredictionId("retry-1".into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ });
+
+ cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+ cx.run_until_parked();
+
+ let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
+ assert_eq!(reject_request.rejections.len(), 1);
+ assert_eq!(reject_request.rejections[0].request_id, "retry-1");
+ // Simulate failure
+ drop(_respond_tx);
+
+ // Add another rejection
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.reject_prediction(
+ EditPredictionId("retry-2".into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ });
+
+ cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+ cx.run_until_parked();
+
+ // Retry should include both the failed item and the new one
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ assert_eq!(reject_request.rejections.len(), 2);
+ assert_eq!(reject_request.rejections[0].request_id, "retry-1");
+ assert_eq!(reject_request.rejections[1].request_id, "retry-2");
+}
+
+// Skipped until we start including diagnostics in prompt
+// #[gpui::test]
+// async fn test_request_diagnostics(cx: &mut TestAppContext) {
+// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
+// let fs = FakeFs::new(cx.executor());
+// fs.insert_tree(
+// "/root",
+// json!({
+// "foo.md": "Hello!\nBye"
+// }),
+// )
+// .await;
+// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
+// let diagnostic = lsp::Diagnostic {
+// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+// severity: Some(lsp::DiagnosticSeverity::ERROR),
+// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
+// ..Default::default()
+// };
+
+// project.update(cx, |project, cx| {
+// project.lsp_store().update(cx, |lsp_store, cx| {
+// // Create some diagnostics
+// lsp_store
+// .update_diagnostics(
+// LanguageServerId(0),
+// lsp::PublishDiagnosticsParams {
+// uri: path_to_buffer_uri.clone(),
+// diagnostics: vec![diagnostic],
+// version: None,
+// },
+// None,
+// language::DiagnosticSourceKind::Pushed,
+// &[],
+// cx,
+// )
+// .unwrap();
+// });
+// });
+
+// let buffer = project
+// .update(cx, |project, cx| {
+// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+// project.open_buffer(path, cx)
+// })
+// .await
+// .unwrap();
+
+// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+// let position = snapshot.anchor_before(language::Point::new(0, 0));
+
+// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
+// ep_store.request_prediction(&project, &buffer, position, cx)
+// });
+
+// let (request, _respond_tx) = req_rx.next().await.unwrap();
+
+// assert_eq!(request.diagnostic_groups.len(), 1);
+// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
+// .unwrap();
+// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
+// assert_eq!(
+// value,
+// json!({
+// "entries": [{
+// "range": {
+// "start": 8,
+// "end": 10
+// },
+// "diagnostic": {
+// "source": null,
+// "code": null,
+// "code_description": null,
+// "severity": 1,
+// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
+// "markdown": null,
+// "group_id": 0,
+// "is_primary": true,
+// "is_disk_based": false,
+// "is_unnecessary": false,
+// "source_kind": "Pushed",
+// "data": null,
+// "underline": true
+// }
+// }],
+// "primary_ix": 0
+// })
+// );
+// }
+
+fn model_response(text: &str) -> open_ai::Response {
+ open_ai::Response {
+ id: Uuid::new_v4().to_string(),
+ object: "response".into(),
+ created: 0,
+ model: "model".into(),
+ choices: vec![open_ai::Choice {
+ index: 0,
+ message: open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Plain(text.to_string())),
+ tool_calls: vec![],
+ },
+ finish_reason: None,
+ }],
+ usage: Usage {
+ prompt_tokens: 0,
+ completion_tokens: 0,
+ total_tokens: 0,
+ },
+ }
+}
+
+fn prompt_from_request(request: &open_ai::Request) -> &str {
+ assert_eq!(request.messages.len(), 1);
+ let open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(content),
+ ..
+ } = &request.messages[0]
+ else {
+ panic!(
+ "Request does not have single user message of type Plain. {:#?}",
+ request
+ );
+ };
+ content
+}
+
+struct RequestChannels {
+ predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
+ reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
+}
+
+fn init_test_with_fake_client(
+ cx: &mut TestAppContext,
+) -> (Entity<EditPredictionStore>, RequestChannels) {
+ cx.update(move |cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ zlog::init_test();
+
+ let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
+ let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
+
+ let http_client = FakeHttpClient::create({
+ move |req| {
+ let uri = req.uri().path().to_string();
+ let mut body = req.into_body();
+ let predict_req_tx = predict_req_tx.clone();
+ let reject_req_tx = reject_req_tx.clone();
+ async move {
+ let resp = match uri.as_str() {
+ "/client/llm_tokens" => serde_json::to_string(&json!({
+ "token": "test"
+ }))
+ .unwrap(),
+ "/predict_edits/raw" => {
+ let mut buf = Vec::new();
+ body.read_to_end(&mut buf).await.ok();
+ let req = serde_json::from_slice(&buf).unwrap();
+
+ let (res_tx, res_rx) = oneshot::channel();
+ predict_req_tx.unbounded_send((req, res_tx)).unwrap();
+ serde_json::to_string(&res_rx.await?).unwrap()
+ }
+ "/predict_edits/reject" => {
+ let mut buf = Vec::new();
+ body.read_to_end(&mut buf).await.ok();
+ let req = serde_json::from_slice(&buf).unwrap();
+
+ let (res_tx, res_rx) = oneshot::channel();
+ reject_req_tx.unbounded_send((req, res_tx)).unwrap();
+ serde_json::to_string(&res_rx.await?).unwrap()
+ }
+ _ => {
+ panic!("Unexpected path: {}", uri)
+ }
+ };
+
+ Ok(Response::builder().body(resp.into()).unwrap())
+ }
+ }
+ });
+
+ let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
+ client.cloud_client().set_credentials(1, "test".into());
+
+ language_model::init(client.clone(), cx);
+
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ let ep_store = EditPredictionStore::global(&client, &user_store, cx);
+
+ (
+ ep_store,
+ RequestChannels {
+ predict: predict_req_rx,
+ reject: reject_req_rx,
+ },
+ )
+ })
+}
+
+const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
+
+#[gpui::test]
+async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
+ let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
+ let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
+ to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
+ });
+
+ let edit_preview = cx
+ .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
+ .await;
+
+ let completion = EditPrediction {
+ edits,
+ edit_preview,
+ buffer: buffer.clone(),
+ snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
+ id: EditPredictionId("the-id".into()),
+ inputs: EditPredictionInputs {
+ events: Default::default(),
+ included_files: Default::default(),
+ cursor_point: cloud_llm_client::predict_edits_v3::Point {
+ line: Line(0),
+ column: 0,
+ },
+ cursor_path: Path::new("").into(),
+ },
+ buffer_snapshotted_at: Instant::now(),
+ response_received_at: Instant::now(),
+ };
+
+ cx.update(|cx| {
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..5, "REM".into()), (9..11, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..2, "REM".into()), (6..8, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.undo(cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..5, "REM".into()), (9..11, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(3..3, "EM".into()), (7..9, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".into()), (8..10, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(9..11, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".into()), (8..10, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
+ assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
+ })
+}
+
+#[gpui::test]
+async fn test_clean_up_diff(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ assert_eq!(
+ apply_edit_prediction(
+ indoc! {"
+ fn main() {
+ let word_1 = \"lorem\";
+ let range = word.len()..word.len();
+ }
+ "},
+ indoc! {"
+ <|editable_region_start|>
+ fn main() {
+ let word_1 = \"lorem\";
+ let range = word_1.len()..word_1.len();
+ }
+
+ <|editable_region_end|>
+ "},
+ cx,
+ )
+ .await,
+ indoc! {"
+ fn main() {
+ let word_1 = \"lorem\";
+ let range = word_1.len()..word_1.len();
+ }
+ "},
+ );
+
+ assert_eq!(
+ apply_edit_prediction(
+ indoc! {"
+ fn main() {
+ let story = \"the quick\"
+ }
+ "},
+ indoc! {"
+ <|editable_region_start|>
+ fn main() {
+ let story = \"the quick brown fox jumps over the lazy dog\";
+ }
+
+ <|editable_region_end|>
+ "},
+ cx,
+ )
+ .await,
+ indoc! {"
+ fn main() {
+ let story = \"the quick brown fox jumps over the lazy dog\";
+ }
+ "},
+ );
+}
+
+#[gpui::test]
+async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let buffer_content = "lorem\n";
+ let completion_response = indoc! {"
+ ```animals.js
+ <|start_of_file|>
+ <|editable_region_start|>
+ lorem
+ ipsum
+ <|editable_region_end|>
+ ```"};
+
+ assert_eq!(
+ apply_edit_prediction(buffer_content, completion_response, cx).await,
+ "lorem\nipsum"
+ );
+}
+
+#[gpui::test]
+async fn test_can_collect_data(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/project/src/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Disabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [], cx).await;
+
+ let buffer = cx.new(|_cx| {
+ Buffer::remote(
+ language::BufferId::new(1).unwrap(),
+ ReplicaId::new(1),
+ language::Capability::ReadWrite,
+ "fn main() {\n println!(\"Hello\");\n}",
+ )
+ });
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "LICENSE": BSD_0_TXT,
+ ".env": "SECRET_KEY=secret"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+            project.open_local_buffer(path!("/project/.env"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [], cx).await;
+ let buffer = cx.new(|cx| Buffer::local("", cx));
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+            project.open_local_buffer(path!("/project/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/open_source_worktree"),
+ json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
+ )
+ .await;
+ fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
+ .await;
+
+ let project = Project::test(
+ fs.clone(),
+ [
+ path!("/open_source_worktree").as_ref(),
+ path!("/closed_source_worktree").as_ref(),
+ ],
+ cx,
+ )
+ .await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+
+ let closed_source_file = project
+ .update(cx, |project, cx| {
+ let worktree2 = project
+ .worktree_for_root_name("closed_source_worktree", cx)
+ .unwrap();
+ worktree2.update(cx, |worktree2, cx| {
+ worktree2.load_file(rel_path("main.rs"), cx)
+ })
+ })
+ .await
+ .unwrap()
+ .file;
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.file_updated(closed_source_file, cx);
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/worktree1"),
+ json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
+ )
+ .await;
+ fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
+ .await;
+
+ let project = Project::test(
+ fs.clone(),
+ [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
+ cx,
+ )
+ .await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/worktree1/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+ let private_buffer = project
+ .update(cx, |project, cx| {
+            project.open_local_buffer(path!("/worktree2/private.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+
+ // this has a side effect of registering the buffer to watch for edits
+ run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+
+ private_buffer.update(cx, |private_buffer, cx| {
+ private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+
+ // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
+ // included
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(
+ [(
+ 0..0,
+ " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS),
+ )],
+ None,
+ cx,
+ );
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+}
+
+fn init_test(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ });
+}
+
+async fn apply_edit_prediction(
+ buffer_content: &str,
+ completion_response: &str,
+ cx: &mut TestAppContext,
+) -> String {
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
+ let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
+ *response.lock() = completion_response.to_string();
+ let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
+ });
+ buffer.read_with(cx, |buffer, _| buffer.text())
+}
+
+async fn run_edit_prediction(
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ ep_store: &Entity<EditPredictionStore>,
+ cx: &mut TestAppContext,
+) -> EditPrediction {
+ let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(buffer, &project, cx)
+ });
+ cx.background_executor.run_until_parked();
+ let prediction_task = ep_store.update(cx, |ep_store, cx| {
+ ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
+ });
+ prediction_task.await.unwrap().unwrap().prediction.unwrap()
+}
+
+async fn make_test_ep_store(
+ project: &Entity<Project>,
+ cx: &mut TestAppContext,
+) -> (
+ Entity<EditPredictionStore>,
+ Arc<Mutex<Option<PredictEditsBody>>>,
+ Arc<Mutex<String>>,
+) {
+ let default_response = indoc! {"
+ ```main.rs
+ <|start_of_file|>
+ <|editable_region_start|>
+ hello world
+ <|editable_region_end|>
+ ```"
+ };
+ let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
+ let completion_response: Arc<Mutex<String>> =
+ Arc::new(Mutex::new(default_response.to_string()));
+ let http_client = FakeHttpClient::create({
+ let captured_request = captured_request.clone();
+ let completion_response = completion_response.clone();
+ let mut next_request_id = 0;
+ move |req| {
+ let captured_request = captured_request.clone();
+ let completion_response = completion_response.clone();
+ async move {
+ match (req.method(), req.uri().path()) {
+ (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&CreateLlmTokenResponse {
+ token: LlmToken("the-llm-token".to_string()),
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap()),
+ (&Method::POST, "/predict_edits/v2") => {
+ let mut request_body = String::new();
+ req.into_body().read_to_string(&mut request_body).await?;
+ *captured_request.lock() =
+ Some(serde_json::from_str(&request_body).unwrap());
+ next_request_id += 1;
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&PredictEditsResponse {
+ request_id: format!("request-{next_request_id}"),
+ output_excerpt: completion_response.lock().clone(),
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap())
+ }
+ _ => Ok(http_client::Response::builder()
+ .status(404)
+ .body("Not Found".into())
+ .unwrap()),
+ }
+ }
+ }
+ });
+
+ let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+ cx.update(|cx| {
+ RefreshLlmTokenListener::register(client.clone(), cx);
+ });
+ let _server = FakeServer::for_client(42, &client, cx).await;
+
+ let ep_store = cx.new(|cx| {
+ let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
+ ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
+
+ let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
+ for worktree in worktrees {
+ let worktree_id = worktree.read(cx).id();
+ ep_store
+ .get_or_init_project(project, cx)
+ .license_detection_watchers
+ .entry(worktree_id)
+ .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
+ }
+
+ ep_store
+ });
+
+ (ep_store, captured_request, completion_response)
+}
+
+fn to_completion_edits(
+ iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
+ buffer: &Entity<Buffer>,
+ cx: &App,
+) -> Vec<(Range<Anchor>, Arc<str>)> {
+ let buffer = buffer.read(cx);
+ iterator
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
+ text,
+ )
+ })
+ .collect()
+}
+
+fn from_completion_edits(
+ editor_edits: &[(Range<Anchor>, Arc<str>)],
+ buffer: &Entity<Buffer>,
+ cx: &App,
+) -> Vec<(Range<usize>, Arc<str>)> {
+ let buffer = buffer.read(cx);
+ editor_edits
+ .iter()
+ .map(|(range, text)| {
+ (
+ range.start.to_offset(buffer)..range.end.to_offset(buffer),
+ text.clone(),
+ )
+ })
+ .collect()
+}
+
+#[ctor::ctor]
+fn init_logger() {
+ zlog::init_test();
+}
@@ -99,7 +99,7 @@ pub struct EditPrediction {
#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
- pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
+ pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
pub cursor_path: Arc<Path>,
}
@@ -1,7 +1,7 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
-use edit_prediction_context2::RelatedFile;
+use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
@@ -197,7 +197,7 @@ impl SweepAi {
let inputs = EditPredictionInputs {
events,
- included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
+ included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
@@ -1,55 +1,56 @@
-use std::{cmp, sync::Arc, time::Duration};
+use std::{cmp, sync::Arc};
use client::{Client, UserStore};
use cloud_llm_client::EditPredictionRejectReason;
-use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
+use edit_prediction_types::{DataCollectionState, Direction, EditPredictionDelegate};
use gpui::{App, Entity, prelude::*};
-use language::ToPoint as _;
+use language::{Buffer, ToPoint as _};
use project::Project;
-use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
+use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
-pub struct ZetaEditPredictionProvider {
- zeta: Entity<Zeta>,
+pub struct ZedEditPredictionDelegate {
+ store: Entity<EditPredictionStore>,
project: Entity<Project>,
+ singleton_buffer: Option<Entity<Buffer>>,
}
-impl ZetaEditPredictionProvider {
- pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
-
+impl ZedEditPredictionDelegate {
pub fn new(
project: Entity<Project>,
+ singleton_buffer: Option<Entity<Buffer>>,
client: &Arc<Client>,
user_store: &Entity<UserStore>,
cx: &mut Context<Self>,
) -> Self {
- let zeta = Zeta::global(client, user_store, cx);
- zeta.update(cx, |zeta, cx| {
- zeta.register_project(&project, cx);
+ let store = EditPredictionStore::global(client, user_store, cx);
+ store.update(cx, |store, cx| {
+ store.register_project(&project, cx);
});
- cx.observe(&zeta, |_this, _zeta, cx| {
+ cx.observe(&store, |_this, _ep_store, cx| {
cx.notify();
})
.detach();
Self {
project: project,
- zeta,
+ store: store,
+ singleton_buffer,
}
}
}
-impl EditPredictionProvider for ZetaEditPredictionProvider {
+impl EditPredictionDelegate for ZedEditPredictionDelegate {
fn name() -> &'static str {
- "zed-predict2"
+ "zed-predict"
}
fn display_name() -> &'static str {
- "Zed's Edit Predictions 2"
+ "Zed's Edit Predictions"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
true
}
@@ -57,17 +58,38 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
true
}
- fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
- // TODO [zeta2]
- DataCollectionState::Unsupported
+ fn data_collection_state(&self, cx: &App) -> DataCollectionState {
+ if let Some(buffer) = &self.singleton_buffer
+ && let Some(file) = buffer.read(cx).file()
+ {
+ let is_project_open_source =
+ self.store
+ .read(cx)
+ .is_file_open_source(&self.project, file, cx);
+ if self.store.read(cx).data_collection_choice.is_enabled() {
+ DataCollectionState::Enabled {
+ is_project_open_source,
+ }
+ } else {
+ DataCollectionState::Disabled {
+ is_project_open_source,
+ }
+ }
+ } else {
+ return DataCollectionState::Disabled {
+ is_project_open_source: false,
+ };
+ }
}
- fn toggle_data_collection(&mut self, _cx: &mut App) {
- // TODO [zeta2]
+ fn toggle_data_collection(&mut self, cx: &mut App) {
+ self.store.update(cx, |store, cx| {
+ store.toggle_data_collection_choice(cx);
+ });
}
fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
- self.zeta.read(cx).usage(cx)
+ self.store.read(cx).usage(cx)
}
fn is_enabled(
@@ -76,16 +98,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
_cursor_position: language::Anchor,
cx: &App,
) -> bool {
- let zeta = self.zeta.read(cx);
- if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
- zeta.has_sweep_api_token()
+ let store = self.store.read(cx);
+ if store.edit_prediction_model == EditPredictionModel::Sweep {
+ store.has_sweep_api_token()
} else {
true
}
}
fn is_refreshing(&self, cx: &App) -> bool {
- self.zeta.read(cx).is_refreshing(&self.project)
+ self.store.read(cx).is_refreshing(&self.project)
}
fn refresh(
@@ -95,24 +117,24 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
_debounce: bool,
cx: &mut Context<Self>,
) {
- let zeta = self.zeta.read(cx);
+ let store = self.store.read(cx);
- if zeta.user_store.read_with(cx, |user_store, _cx| {
+ if store.user_store.read_with(cx, |user_store, _cx| {
user_store.account_too_young() || user_store.has_overdue_invoices()
}) {
return;
}
- if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
+ if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
&& let BufferEditPrediction::Local { prediction } = current
&& prediction.interpolate(buffer.read(cx)).is_some()
{
return;
}
- self.zeta.update(cx, |zeta, cx| {
- zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
- zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
+ self.store.update(cx, |store, cx| {
+ store.refresh_context(&self.project, &buffer, cursor_position, cx);
+ store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
});
}
@@ -126,20 +148,20 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
fn accept(&mut self, cx: &mut Context<Self>) {
- self.zeta.update(cx, |zeta, cx| {
- zeta.accept_current_prediction(&self.project, cx);
+ self.store.update(cx, |store, cx| {
+ store.accept_current_prediction(&self.project, cx);
});
}
fn discard(&mut self, cx: &mut Context<Self>) {
- self.zeta.update(cx, |zeta, _cx| {
- zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
+ self.store.update(cx, |store, _cx| {
+ store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
});
}
fn did_show(&mut self, cx: &mut Context<Self>) {
- self.zeta.update(cx, |zeta, cx| {
- zeta.did_show_current_prediction(&self.project, cx);
+ self.store.update(cx, |store, cx| {
+ store.did_show_current_prediction(&self.project, cx);
});
}
@@ -148,16 +170,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
buffer: &Entity<language::Buffer>,
cursor_position: language::Anchor,
cx: &mut Context<Self>,
- ) -> Option<edit_prediction::EditPrediction> {
+ ) -> Option<edit_prediction_types::EditPrediction> {
let prediction =
- self.zeta
+ self.store
.read(cx)
.current_prediction_for_buffer(buffer, &self.project, cx)?;
let prediction = match prediction {
BufferEditPrediction::Local { prediction } => prediction,
BufferEditPrediction::Jump { prediction } => {
- return Some(edit_prediction::EditPrediction::Jump {
+ return Some(edit_prediction_types::EditPrediction::Jump {
id: Some(prediction.id.to_string().into()),
snapshot: prediction.snapshot.clone(),
target: prediction.edits.first().unwrap().0.start,
@@ -169,8 +191,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
let snapshot = buffer.snapshot();
let Some(edits) = prediction.interpolate(&snapshot) else {
- self.zeta.update(cx, |zeta, _cx| {
- zeta.reject_current_prediction(
+ self.store.update(cx, |store, _cx| {
+ store.reject_current_prediction(
EditPredictionRejectReason::InterpolatedEmpty,
&self.project,
);
@@ -208,7 +230,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
}
- Some(edit_prediction::EditPrediction::Local {
+ Some(edit_prediction_types::EditPrediction::Local {
id: Some(prediction.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(prediction.edit_preview.clone()),
@@ -3,7 +3,7 @@ mod input_excerpt;
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
- EditPredictionId, ZedUpdateRequiredError, Zeta,
+ EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
prediction::{EditPredictionInputs, EditPredictionResult},
};
use anyhow::{Context as _, Result};
@@ -30,23 +30,23 @@ pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
pub(crate) fn request_prediction_with_zeta1(
- zeta: &mut Zeta,
+ store: &mut EditPredictionStore,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
trigger: PredictEditsRequestTrigger,
- cx: &mut Context<Zeta>,
+ cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
- let client = zeta.client.clone();
- let llm_token = zeta.llm_token.clone();
+ let client = store.client.clone();
+ let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
- let can_collect_file = zeta.can_collect_file(project, file, cx);
+ let can_collect_file = store.can_collect_file(project, file, cx);
let git_info = if can_collect_file {
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
} else {
@@ -102,7 +102,7 @@ pub(crate) fn request_prediction_with_zeta1(
let http_client = client.http_client();
- let response = Zeta::send_api_request::<PredictEditsResponse>(
+ let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
|request| {
let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
predict_edits_url
@@ -124,7 +124,7 @@ pub(crate) fn request_prediction_with_zeta1(
let inputs = EditPredictionInputs {
events: included_events.into(),
- included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
+ included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
@@ -155,8 +155,8 @@ pub(crate) fn request_prediction_with_zeta1(
Err(err) => {
if err.is::<ZedUpdateRequiredError>() {
cx.update(|cx| {
- this.update(cx, |zeta, _cx| {
- zeta.update_required = true;
+ this.update(cx, |ep_store, _cx| {
+ ep_store.update_required = true;
})
.ok();
@@ -0,0 +1,358 @@
+#[cfg(feature = "eval-support")]
+use crate::EvalCacheEntryKind;
+use crate::prediction::EditPredictionResult;
+use crate::{
+ DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
+ EditPredictionRequestedDebugEvent, EditPredictionStore,
+};
+use anyhow::{Result, anyhow, bail};
+use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
+use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
+use cloud_zeta2_prompt::CURSOR_MARKER;
+use edit_prediction_context::{EditPredictionExcerpt, Line};
+use edit_prediction_context::{RelatedExcerpt, RelatedFile};
+use futures::channel::oneshot;
+use gpui::{Entity, Task, prelude::*};
+use language::{Anchor, BufferSnapshot};
+use language::{Buffer, Point, ToOffset as _, ToPoint};
+use project::{Project, ProjectItem as _};
+use release_channel::AppVersion;
+use std::{
+ env,
+ path::Path,
+ sync::Arc,
+ time::{Duration, Instant},
+};
+
+pub fn request_prediction_with_zeta2(
+ store: &mut EditPredictionStore,
+ project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ active_snapshot: BufferSnapshot,
+ position: Anchor,
+ events: Vec<Arc<Event>>,
+ mut included_files: Vec<RelatedFile>,
+ trigger: PredictEditsRequestTrigger,
+ cx: &mut Context<EditPredictionStore>,
+) -> Task<Result<Option<EditPredictionResult>>> {
+ let options = store.options.clone();
+ let buffer_snapshotted_at = Instant::now();
+
+ let Some((excerpt_path, active_project_path)) = active_snapshot
+ .file()
+ .map(|file| -> Arc<Path> { file.full_path(cx).into() })
+ .zip(active_buffer.read(cx).project_path(cx))
+ else {
+ return Task::ready(Err(anyhow!("No file path for excerpt")));
+ };
+
+ let client = store.client.clone();
+ let llm_token = store.llm_token.clone();
+ let app_version = AppVersion::global(cx);
+ let debug_tx = store.debug_tx.clone();
+
+ let file = active_buffer.read(cx).file();
+
+ let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
+
+ // TODO data collection
+ let can_collect_data = file
+ .as_ref()
+ .map_or(false, |file| store.can_collect_file(project, file, cx));
+
+ #[cfg(feature = "eval-support")]
+ let eval_cache = store.eval_cache.clone();
+
+ let request_task = cx.background_spawn({
+ let active_buffer = active_buffer.clone();
+ async move {
+ let cursor_offset = position.to_offset(&active_snapshot);
+ let cursor_point = cursor_offset.to_point(&active_snapshot);
+
+ let before_retrieval = Instant::now();
+
+ let excerpt_options = options.context;
+
+ let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
+ cursor_point,
+ &active_snapshot,
+ &excerpt_options,
+ ) else {
+ return Ok((None, None));
+ };
+
+ let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
+ ..active_snapshot.anchor_before(excerpt.range.end);
+ let related_excerpt = RelatedExcerpt {
+ anchor_range: excerpt_anchor_range.clone(),
+ point_range: Point::new(excerpt.line_range.start.0, 0)
+ ..Point::new(excerpt.line_range.end.0, 0),
+ text: active_snapshot.as_rope().slice(excerpt.range),
+ };
+
+ if let Some(buffer_ix) = included_files
+ .iter()
+ .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
+ {
+ let file = &mut included_files[buffer_ix];
+ file.excerpts.push(related_excerpt);
+ file.merge_excerpts();
+ let last_ix = included_files.len() - 1;
+ included_files.swap(buffer_ix, last_ix);
+ } else {
+ let active_file = RelatedFile {
+ path: active_project_path,
+ buffer: active_buffer.downgrade(),
+ excerpts: vec![related_excerpt],
+ max_row: active_snapshot.max_point().row,
+ };
+ included_files.push(active_file);
+ }
+
+ let included_files = included_files
+ .iter()
+ .map(|related_file| predict_edits_v3::RelatedFile {
+ path: Arc::from(related_file.path.path.as_std_path()),
+ max_row: Line(related_file.max_row),
+ excerpts: related_file
+ .excerpts
+ .iter()
+ .map(|excerpt| predict_edits_v3::Excerpt {
+ start_line: Line(excerpt.point_range.start.row),
+ text: excerpt.text.to_string().into(),
+ })
+ .collect(),
+ })
+ .collect::<Vec<_>>();
+
+ let cloud_request = predict_edits_v3::PredictEditsRequest {
+ excerpt_path,
+ excerpt: String::new(),
+ excerpt_line_range: Line(0)..Line(0),
+ excerpt_range: 0..0,
+ cursor_point: predict_edits_v3::Point {
+ line: predict_edits_v3::Line(cursor_point.row),
+ column: cursor_point.column,
+ },
+ related_files: included_files,
+ events,
+ can_collect_data,
+ debug_info: debug_tx.is_some(),
+ prompt_max_bytes: Some(options.max_prompt_bytes),
+ prompt_format: options.prompt_format,
+ excerpt_parent: None,
+ git_info: None,
+ trigger,
+ };
+
+ let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
+
+ let inputs = EditPredictionInputs {
+ included_files: cloud_request.related_files,
+ events: cloud_request.events,
+ cursor_point: cloud_request.cursor_point,
+ cursor_path: cloud_request.excerpt_path,
+ };
+
+ let retrieval_time = Instant::now() - before_retrieval;
+
+ let debug_response_tx = if let Some(debug_tx) = &debug_tx {
+ let (response_tx, response_rx) = oneshot::channel();
+
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionRequested(
+ EditPredictionRequestedDebugEvent {
+ inputs: inputs.clone(),
+ retrieval_time,
+ buffer: active_buffer.downgrade(),
+ local_prompt: match prompt_result.as_ref() {
+ Ok(prompt) => Ok(prompt.clone()),
+ Err(err) => Err(err.to_string()),
+ },
+ position,
+ response_rx,
+ },
+ ))
+ .ok();
+ Some(response_tx)
+ } else {
+ None
+ };
+
+ if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
+ if let Some(debug_response_tx) = debug_response_tx {
+ debug_response_tx
+ .send((Err("Request skipped".to_string()), Duration::ZERO))
+ .ok();
+ }
+ anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
+ }
+
+ let prompt = prompt_result?;
+ let generation_params =
+ cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
+ let request = open_ai::Request {
+ model: EDIT_PREDICTIONS_MODEL_ID.clone(),
+ messages: vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt),
+ }],
+ stream: false,
+ max_completion_tokens: None,
+ stop: generation_params.stop.unwrap_or_default(),
+ temperature: generation_params.temperature.unwrap_or(0.7),
+ tool_choice: None,
+ parallel_tool_calls: None,
+ tools: vec![],
+ prompt_cache_key: None,
+ reasoning_effort: None,
+ };
+
+ log::trace!("Sending edit prediction request");
+
+ let before_request = Instant::now();
+ let response = EditPredictionStore::send_raw_llm_request(
+ request,
+ client,
+ llm_token,
+ app_version,
+ #[cfg(feature = "eval-support")]
+ eval_cache,
+ #[cfg(feature = "eval-support")]
+ EvalCacheEntryKind::Prediction,
+ )
+ .await;
+ let received_response_at = Instant::now();
+ let request_time = received_response_at - before_request;
+
+ log::trace!("Got edit prediction response");
+
+ if let Some(debug_response_tx) = debug_response_tx {
+ debug_response_tx
+ .send((
+ response
+ .as_ref()
+ .map_err(|err| err.to_string())
+ .map(|response| response.0.clone()),
+ request_time,
+ ))
+ .ok();
+ }
+
+ let (res, usage) = response?;
+ let request_id = EditPredictionId(res.id.clone().into());
+ let Some(mut output_text) = text_from_response(res) else {
+ return Ok((Some((request_id, None)), usage));
+ };
+
+ if output_text.contains(CURSOR_MARKER) {
+ log::trace!("Stripping out {CURSOR_MARKER} from response");
+ output_text = output_text.replace(CURSOR_MARKER, "");
+ }
+
+ let get_buffer_from_context = |path: &Path| {
+ if Some(path) == active_file_full_path.as_deref() {
+ Some((
+ &active_snapshot,
+ std::slice::from_ref(&excerpt_anchor_range),
+ ))
+ } else {
+ None
+ }
+ };
+
+ let (_, edits) = match options.prompt_format {
+ PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
+ if output_text.contains("--- a/\n+++ b/\nNo edits") {
+ let edits = vec![];
+ (&active_snapshot, edits)
+ } else {
+ crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
+ }
+ }
+ PromptFormat::OldTextNewText => {
+ crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
+ }
+ _ => {
+ bail!("unsupported prompt format {}", options.prompt_format)
+ }
+ };
+
+ anyhow::Ok((
+ Some((
+ request_id,
+ Some((
+ inputs,
+ active_buffer,
+ active_snapshot.clone(),
+ edits,
+ received_response_at,
+ )),
+ )),
+ usage,
+ ))
+ }
+ });
+
+ cx.spawn(async move |this, cx| {
+ let Some((id, prediction)) =
+ EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
+ else {
+ return Ok(None);
+ };
+
+ let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
+ prediction
+ else {
+ return Ok(Some(EditPredictionResult {
+ id,
+ prediction: Err(EditPredictionRejectReason::Empty),
+ }));
+ };
+
+ Ok(Some(
+ EditPredictionResult::new(
+ id,
+ &edited_buffer,
+ &edited_buffer_snapshot,
+ edits.into(),
+ buffer_snapshotted_at,
+ received_response_at,
+ inputs,
+ cx,
+ )
+ .await,
+ ))
+ })
+}
+
+pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
+ let choice = res.choices.pop()?;
+ let output_text = match choice.message {
+ open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Plain(content)),
+ ..
+ } => content,
+ open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Multipart(mut content)),
+ ..
+ } => {
+ if content.is_empty() {
+ log::error!("No output from Baseten completion response");
+ return None;
+ }
+
+ match content.remove(0) {
+ open_ai::MessagePart::Text { text } => text,
+ open_ai::MessagePart::Image { .. } => {
+ log::error!("Expected text, got an image");
+ return None;
+ }
+ }
+ }
+ _ => {
+ log::error!("Invalid response message: {:?}", choice.message);
+ return None;
+ }
+ };
+ Some(output_text)
+}
@@ -1,5 +1,5 @@
[package]
-name = "zeta_cli"
+name = "edit_prediction_cli"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
workspace = true
[[bin]]
-name = "zeta"
+name = "ep_cli"
path = "src/main.rs"
[dependencies]
@@ -19,7 +19,7 @@ chrono.workspace = true
clap.workspace = true
client.workspace = true
cloud_llm_client.workspace= true
-cloud_zeta2_prompt.workspace= true
+cloud_zeta2_prompt.workspace = true
collections.workspace = true
debug_adapter_extension.workspace = true
edit_prediction_context.workspace = true
@@ -35,9 +35,7 @@ language_models.workspace = true
languages = { workspace = true, features = ["load-grammars"] }
log.workspace = true
node_runtime.workspace = true
-ordered-float.workspace = true
paths.workspace = true
-polars = { version = "0.51", features = ["lazy", "dtype-struct", "parquet"] }
project.workspace = true
prompt_store.workspace = true
pulldown-cmark.workspace = true
@@ -48,12 +46,11 @@ serde_json.workspace = true
settings.workspace = true
shellexpand.workspace = true
smol.workspace = true
-soa-rs = "0.8.1"
terminal_view.workspace = true
toml.workspace = true
util.workspace = true
watch.workspace = true
-zeta = { workspace = true, features = ["eval-support"] }
+edit_prediction = { workspace = true, features = ["eval-support"] }
zlog.workspace = true
[dev-dependencies]
@@ -6,17 +6,17 @@ use std::{
};
use anyhow::Result;
+use edit_prediction::{EditPredictionStore, udiff::DiffLine};
use gpui::{AsyncApp, Entity};
use project::Project;
use util::ResultExt as _;
-use zeta::{Zeta, udiff::DiffLine};
use crate::{
EvaluateArguments, PredictionOptions,
example::{Example, NamedExample},
headless::ZetaCliAppState,
paths::print_run_data_dir,
- predict::{PredictionDetails, perform_predict, setup_zeta},
+ predict::{PredictionDetails, perform_predict, setup_store},
};
#[derive(Debug)]
@@ -45,7 +45,7 @@ pub async fn run_evaluate(
let project = example.setup_project(&app_state, cx).await.unwrap();
let providers = (0..args.repetitions)
- .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
+ .map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
.collect::<Vec<_>>();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
@@ -53,7 +53,7 @@ pub async fn run_evaluate(
let tasks = providers
.into_iter()
.enumerate()
- .map(move |(repetition_ix, zeta)| {
+ .map(move |(repetition_ix, store)| {
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
let example = example.clone();
let project = project.clone();
@@ -65,7 +65,7 @@ pub async fn run_evaluate(
example,
repetition_ix,
project,
- zeta,
+ store,
options,
!args.skip_prediction,
cx,
@@ -154,7 +154,7 @@ pub async fn run_evaluate_one(
example: NamedExample,
repetition_ix: Option<u16>,
project: Entity<Project>,
- zeta: Entity<Zeta>,
+ store: Entity<EditPredictionStore>,
prediction_options: PredictionOptions,
predict: bool,
cx: &mut AsyncApp,
@@ -162,7 +162,7 @@ pub async fn run_evaluate_one(
let predict_result = perform_predict(
example.clone(),
project,
- zeta,
+ store,
repetition_ix,
prediction_options,
cx,
@@ -14,6 +14,7 @@ use anyhow::{Context as _, Result, anyhow};
use clap::ValueEnum;
use cloud_zeta2_prompt::CURSOR_MARKER;
use collections::HashMap;
+use edit_prediction::udiff::OpenedBuffers;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
@@ -25,7 +26,6 @@ use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use serde::{Deserialize, Serialize};
use util::{paths::PathStyle, rel_path::RelPath};
-use zeta::udiff::OpenedBuffers;
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
@@ -481,7 +481,7 @@ impl NamedExample {
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'_>> {
- zeta::udiff::apply_diff(&self.example.edit_history, project, cx).await
+ edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
}
}
@@ -5,7 +5,6 @@ mod metrics;
mod paths;
mod predict;
mod source_location;
-mod syntax_retrieval_stats;
mod util;
use crate::{
@@ -14,13 +13,13 @@ use crate::{
headless::ZetaCliAppState,
predict::run_predict,
source_location::SourceLocation,
- syntax_retrieval_stats::retrieval_stats,
util::{open_buffer, open_buffer_with_language_server},
};
use ::util::paths::PathStyle;
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand, ValueEnum};
use cloud_llm_client::predict_edits_v3;
+use edit_prediction::udiff::DiffLine;
use edit_prediction_context::EditPredictionExcerptOptions;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, Point};
@@ -28,10 +27,7 @@ use metrics::delta_chr_f;
use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
use reqwest_client::ReqwestClient;
use std::io::{self};
-use std::time::Duration;
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
-use zeta::ContextMode;
-use zeta::udiff::DiffLine;
#[derive(Parser, Debug)]
#[command(name = "zeta")]
@@ -45,7 +41,6 @@ struct ZetaCliArgs {
#[derive(Subcommand, Debug)]
enum Command {
Context(ContextArgs),
- ContextStats(ContextStatsArgs),
Predict(PredictArguments),
Eval(EvaluateArguments),
ConvertExample {
@@ -60,20 +55,6 @@ enum Command {
Clean,
}
-#[derive(Debug, Args)]
-struct ContextStatsArgs {
- #[arg(long)]
- worktree: PathBuf,
- #[arg(long)]
- extension: Option<String>,
- #[arg(long)]
- limit: Option<usize>,
- #[arg(long)]
- skip: Option<usize>,
- #[clap(flatten)]
- zeta2_args: Zeta2Args,
-}
-
#[derive(Debug, Args)]
struct ContextArgs {
#[arg(long)]
@@ -201,28 +182,22 @@ enum PredictionProvider {
Sweep,
}
-fn zeta2_args_to_options(args: &Zeta2Args) -> zeta::ZetaOptions {
- zeta::ZetaOptions {
- context: ContextMode::Lsp(EditPredictionExcerptOptions {
+fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
+ edit_prediction::ZetaOptions {
+ context: EditPredictionExcerptOptions {
max_bytes: args.max_excerpt_bytes,
min_bytes: args.min_excerpt_bytes,
target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
- }),
- max_diagnostic_bytes: args.max_diagnostic_bytes,
+ },
max_prompt_bytes: args.max_prompt_bytes,
prompt_format: args.prompt_format.into(),
- file_indexing_parallelism: args.file_indexing_parallelism,
- buffer_change_grouping_interval: Duration::ZERO,
}
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
- MarkedExcerpt,
- LabeledSections,
OnlySnippets,
#[default]
- NumberedLines,
OldTextNewText,
Minimal,
MinimalQwen,
@@ -232,10 +207,7 @@ enum PromptFormat {
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
fn into(self) -> predict_edits_v3::PromptFormat {
match self {
- Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
- Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
- Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
@@ -395,27 +367,29 @@ async fn zeta2_context(
.await;
let output = cx
.update(|cx| {
- let zeta = cx.new(|cx| {
- zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
+ let store = cx.new(|cx| {
+ edit_prediction::EditPredictionStore::new(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ )
});
- let indexing_done_task = zeta.update(cx, |zeta, cx| {
- zeta.set_options(zeta2_args_to_options(&args.zeta2_args));
- zeta.register_buffer(&buffer, &project, cx);
- zeta.wait_for_initial_indexing(&project, cx)
+ store.update(cx, |store, cx| {
+ store.set_options(zeta2_args_to_options(&args.zeta2_args));
+ store.register_buffer(&buffer, &project, cx);
});
cx.spawn(async move |cx| {
- indexing_done_task.await?;
- let updates_rx = zeta.update(cx, |zeta, cx| {
+ let updates_rx = store.update(cx, |store, cx| {
let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
- zeta.set_use_context(true);
- zeta.refresh_context_if_needed(&project, &buffer, cursor, cx);
- zeta.project_context_updates(&project).unwrap()
+ store.set_use_context(true);
+ store.refresh_context(&project, &buffer, cursor, cx);
+ store.project_context_updates(&project).unwrap()
})?;
updates_rx.recv().await.ok();
- let context = zeta.update(cx, |zeta, cx| {
- zeta.context_for_project(&project, cx).to_vec()
+ let context = store.update(cx, |store, cx| {
+ store.context_for_project(&project, cx).to_vec()
})?;
anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
@@ -430,7 +404,7 @@ async fn zeta1_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
-) -> Result<zeta::zeta1::GatherContextOutput> {
+) -> Result<edit_prediction::zeta1::GatherContextOutput> {
let LoadedContext {
full_path_str,
snapshot,
@@ -445,7 +419,7 @@ async fn zeta1_context(
let prompt_for_events = move || (events, 0);
cx.update(|cx| {
- zeta::zeta1::gather_context(
+ edit_prediction::zeta1::gather_context(
full_path_str,
&snapshot,
clipped_cursor,
@@ -475,19 +449,6 @@ fn main() {
panic!("Expected a command");
}
}
- Some(Command::ContextStats(arguments)) => {
- let result = retrieval_stats(
- arguments.worktree,
- app_state,
- arguments.extension,
- arguments.limit,
- arguments.skip,
- zeta2_args_to_options(&arguments.zeta2_args),
- cx,
- )
- .await;
- println!("{}", result.unwrap());
- }
Some(Command::Context(context_args)) => {
let result = match context_args.provider {
ContextProvider::Zeta1 => {
@@ -1,5 +1,5 @@
use collections::{HashMap, HashSet};
-use zeta::udiff::DiffLine;
+use edit_prediction::udiff::DiffLine;
type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
@@ -287,7 +287,7 @@ fn count_ngrams(text: &str, n: usize) -> Counts {
#[cfg(test)]
mod test {
use super::*;
- use zeta::udiff::DiffLine;
+ use edit_prediction::udiff::DiffLine;
#[test]
fn test_delta_chr_f_perfect_match() {
@@ -7,6 +7,7 @@ use crate::{
use ::serde::Serialize;
use anyhow::{Context, Result, anyhow};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
+use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
use project::Project;
@@ -18,7 +19,6 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
-use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
pub async fn run_predict(
args: PredictArguments,
@@ -27,9 +27,9 @@ pub async fn run_predict(
) {
let example = NamedExample::load(args.example_path).unwrap();
let project = example.setup_project(app_state, cx).await.unwrap();
- let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
+ let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
- let result = perform_predict(example, project, zeta, None, args.options, cx)
+ let result = perform_predict(example, project, store, None, args.options, cx)
.await
.unwrap();
result.write(args.format, std::io::stdout()).unwrap();
@@ -37,45 +37,50 @@ pub async fn run_predict(
print_run_data_dir(true, std::io::stdout().is_terminal());
}
-pub fn setup_zeta(
+pub fn setup_store(
provider: PredictionProvider,
project: &Entity<Project>,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
-) -> Result<Entity<Zeta>> {
- let zeta =
- cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
+) -> Result<Entity<EditPredictionStore>> {
+ let store = cx.new(|cx| {
+ edit_prediction::EditPredictionStore::new(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ )
+ })?;
- zeta.update(cx, |zeta, _cx| {
+ store.update(cx, |store, _cx| {
let model = match provider {
- PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1,
- PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2,
- PredictionProvider::Sweep => zeta::ZetaEditPredictionModel::Sweep,
+ PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
+ PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
+ PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
};
- zeta.set_edit_prediction_model(model);
+ store.set_edit_prediction_model(model);
})?;
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
cx.subscribe(&buffer_store, {
let project = project.clone();
- let zeta = zeta.clone();
+ let store = store.clone();
move |_, event, cx| match event {
BufferStoreEvent::BufferAdded(buffer) => {
- zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
+ store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
}
_ => {}
}
})?
.detach();
- anyhow::Ok(zeta)
+ anyhow::Ok(store)
}
pub async fn perform_predict(
example: NamedExample,
project: Entity<Project>,
- zeta: Entity<Zeta>,
+ store: Entity<EditPredictionStore>,
repetition_ix: Option<u16>,
options: PredictionOptions,
cx: &mut AsyncApp,
@@ -108,8 +113,8 @@ pub async fn perform_predict(
std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
.context("creating latest link")?;
- zeta.update(cx, |zeta, _cx| {
- zeta.with_eval_cache(Arc::new(RunCache {
+ store.update(cx, |store, _cx| {
+ store.with_eval_cache(Arc::new(RunCache {
example_run_dir: example_run_dir.clone(),
cache_mode,
}));
@@ -121,16 +126,16 @@ pub async fn perform_predict(
let prompt_format = options.zeta2.prompt_format;
- zeta.update(cx, |zeta, _cx| {
- let mut options = zeta.options().clone();
+ store.update(cx, |store, _cx| {
+ let mut options = store.options().clone();
options.prompt_format = prompt_format.into();
- zeta.set_options(options);
+ store.set_options(options);
})?;
let mut debug_task = gpui::Task::ready(Ok(()));
if options.provider == crate::PredictionProvider::Zeta2 {
- let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+ let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
debug_task = cx.background_spawn({
let result = result.clone();
@@ -139,14 +144,14 @@ pub async fn perform_predict(
let mut retrieval_finished_at = None;
while let Some(event) = debug_rx.next().await {
match event {
- zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+ edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
fs::write(
example_run_dir.join("search_prompt.md"),
&info.search_prompt,
)?;
}
- zeta::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+ edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
retrieval_finished_at = Some(info.timestamp);
for (key, value) in &info.metadata {
if *key == "search_queries" {
@@ -157,7 +162,7 @@ pub async fn perform_predict(
}
}
}
- zeta::ZetaDebugInfo::EditPredictionRequested(request) => {
+ edit_prediction::DebugEvent::EditPredictionRequested(request) => {
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
let prompt = request.local_prompt.unwrap_or_default();
@@ -193,7 +198,8 @@ pub async fn perform_predict(
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
- let response = zeta::text_from_response(response).unwrap_or_default();
+ let response = edit_prediction::zeta2::text_from_response(response)
+ .unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
@@ -212,20 +218,14 @@ pub async fn perform_predict(
}
});
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_context_with_agentic_retrieval(
- project.clone(),
- cursor_buffer.clone(),
- cursor_anchor,
- cx,
- )
- })?
- .await?;
+ store.update(cx, |store, cx| {
+ store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
+ })?;
}
- let prediction = zeta
- .update(cx, |zeta, cx| {
- zeta.request_prediction(
+ let prediction = store
+ .update(cx, |store, cx| {
+ store.request_prediction(
&project,
&cursor_buffer,
cursor_anchor,
@@ -12,41 +12,32 @@ workspace = true
path = "src/edit_prediction_context.rs"
[dependencies]
+parking_lot.workspace = true
anyhow.workspace = true
-arrayvec.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
-hashbrown.workspace = true
-indoc.workspace = true
-itertools.workspace = true
language.workspace = true
-log.workspace = true
-ordered-float.workspace = true
-postage.workspace = true
+lsp.workspace = true
project.workspace = true
-regex.workspace = true
+log.workspace = true
serde.workspace = true
-slotmap.workspace = true
-strum.workspace = true
-text.workspace = true
+smallvec.workspace = true
tree-sitter.workspace = true
util.workspace = true
[dev-dependencies]
-clap.workspace = true
+env_logger.workspace = true
+indoc.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["test-support"] }
-indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
+lsp = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = {workspace= true, features = ["test-support"]}
serde_json.workspace = true
settings = {workspace= true, features = ["test-support"]}
text = { workspace = true, features = ["test-support"] }
-tree-sitter-c.workspace = true
-tree-sitter-cpp.workspace = true
-tree-sitter-go.workspace = true
util = { workspace = true, features = ["test-support"] }
zlog.workspace = true
@@ -1,350 +0,0 @@
-use cloud_llm_client::predict_edits_v3::{self, Line};
-use language::{Language, LanguageId};
-use project::ProjectEntryId;
-use std::ops::Range;
-use std::sync::Arc;
-use std::{borrow::Cow, path::Path};
-use text::{Bias, BufferId, Rope};
-use util::paths::{path_ends_with, strip_path_suffix};
-use util::rel_path::RelPath;
-
-use crate::outline::OutlineDeclaration;
-
-#[derive(Debug, Clone, Eq, PartialEq, Hash)]
-pub struct Identifier {
- pub name: Arc<str>,
- pub language_id: LanguageId,
-}
-
-slotmap::new_key_type! {
- pub struct DeclarationId;
-}
-
-#[derive(Debug, Clone)]
-pub enum Declaration {
- File {
- project_entry_id: ProjectEntryId,
- declaration: FileDeclaration,
- cached_path: CachedDeclarationPath,
- },
- Buffer {
- project_entry_id: ProjectEntryId,
- buffer_id: BufferId,
- rope: Rope,
- declaration: BufferDeclaration,
- cached_path: CachedDeclarationPath,
- },
-}
-
-const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
-
-impl Declaration {
- pub fn identifier(&self) -> &Identifier {
- match self {
- Declaration::File { declaration, .. } => &declaration.identifier,
- Declaration::Buffer { declaration, .. } => &declaration.identifier,
- }
- }
-
- pub fn parent(&self) -> Option<DeclarationId> {
- match self {
- Declaration::File { declaration, .. } => declaration.parent,
- Declaration::Buffer { declaration, .. } => declaration.parent,
- }
- }
-
- pub fn as_buffer(&self) -> Option<&BufferDeclaration> {
- match self {
- Declaration::File { .. } => None,
- Declaration::Buffer { declaration, .. } => Some(declaration),
- }
- }
-
- pub fn as_file(&self) -> Option<&FileDeclaration> {
- match self {
- Declaration::Buffer { .. } => None,
- Declaration::File { declaration, .. } => Some(declaration),
- }
- }
-
- pub fn project_entry_id(&self) -> ProjectEntryId {
- match self {
- Declaration::File {
- project_entry_id, ..
- } => *project_entry_id,
- Declaration::Buffer {
- project_entry_id, ..
- } => *project_entry_id,
- }
- }
-
- pub fn cached_path(&self) -> &CachedDeclarationPath {
- match self {
- Declaration::File { cached_path, .. } => cached_path,
- Declaration::Buffer { cached_path, .. } => cached_path,
- }
- }
-
- pub fn item_range(&self) -> Range<usize> {
- match self {
- Declaration::File { declaration, .. } => declaration.item_range.clone(),
- Declaration::Buffer { declaration, .. } => declaration.item_range.clone(),
- }
- }
-
- pub fn item_line_range(&self) -> Range<Line> {
- match self {
- Declaration::File { declaration, .. } => declaration.item_line_range.clone(),
- Declaration::Buffer {
- declaration, rope, ..
- } => {
- Line(rope.offset_to_point(declaration.item_range.start).row)
- ..Line(rope.offset_to_point(declaration.item_range.end).row)
- }
- }
- }
-
- pub fn item_text(&self) -> (Cow<'_, str>, bool) {
- match self {
- Declaration::File { declaration, .. } => (
- declaration.text.as_ref().into(),
- declaration.text_is_truncated,
- ),
- Declaration::Buffer {
- rope, declaration, ..
- } => (
- rope.chunks_in_range(declaration.item_range.clone())
- .collect::<Cow<str>>(),
- declaration.item_range_is_truncated,
- ),
- }
- }
-
- pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
- match self {
- Declaration::File { declaration, .. } => (
- declaration.text[self.signature_range_in_item_text()].into(),
- declaration.signature_is_truncated,
- ),
- Declaration::Buffer {
- rope, declaration, ..
- } => (
- rope.chunks_in_range(declaration.signature_range.clone())
- .collect::<Cow<str>>(),
- declaration.signature_range_is_truncated,
- ),
- }
- }
-
- pub fn signature_range(&self) -> Range<usize> {
- match self {
- Declaration::File { declaration, .. } => declaration.signature_range.clone(),
- Declaration::Buffer { declaration, .. } => declaration.signature_range.clone(),
- }
- }
-
- pub fn signature_line_range(&self) -> Range<Line> {
- match self {
- Declaration::File { declaration, .. } => declaration.signature_line_range.clone(),
- Declaration::Buffer {
- declaration, rope, ..
- } => {
- Line(rope.offset_to_point(declaration.signature_range.start).row)
- ..Line(rope.offset_to_point(declaration.signature_range.end).row)
- }
- }
- }
-
- pub fn signature_range_in_item_text(&self) -> Range<usize> {
- let signature_range = self.signature_range();
- let item_range = self.item_range();
- signature_range.start.saturating_sub(item_range.start)
- ..(signature_range.end.saturating_sub(item_range.start)).min(item_range.len())
- }
-}
-
-fn expand_range_to_line_boundaries_and_truncate(
- range: &Range<usize>,
- limit: usize,
- rope: &Rope,
-) -> (Range<usize>, Range<predict_edits_v3::Line>, bool) {
- let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
- point_range.start.column = 0;
- point_range.end.row += 1;
- point_range.end.column = 0;
-
- let mut item_range =
- rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
- let is_truncated = item_range.len() > limit;
- if is_truncated {
- item_range.end = item_range.start + limit;
- }
- item_range.end = rope.clip_offset(item_range.end, Bias::Left);
-
- let line_range =
- predict_edits_v3::Line(point_range.start.row)..predict_edits_v3::Line(point_range.end.row);
- (item_range, line_range, is_truncated)
-}
-
-#[derive(Debug, Clone)]
-pub struct FileDeclaration {
- pub parent: Option<DeclarationId>,
- pub identifier: Identifier,
- /// offset range of the declaration in the file, expanded to line boundaries and truncated
- pub item_range: Range<usize>,
- /// line range of the declaration in the file, potentially truncated
- pub item_line_range: Range<predict_edits_v3::Line>,
- /// text of `item_range`
- pub text: Arc<str>,
- /// whether `text` was truncated
- pub text_is_truncated: bool,
- /// offset range of the signature in the file, expanded to line boundaries and truncated
- pub signature_range: Range<usize>,
- /// line range of the signature in the file, truncated
- pub signature_line_range: Range<Line>,
- /// whether `signature` was truncated
- pub signature_is_truncated: bool,
-}
-
-impl FileDeclaration {
- pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
- let (item_range_in_file, item_line_range_in_file, text_is_truncated) =
- expand_range_to_line_boundaries_and_truncate(
- &declaration.item_range,
- ITEM_TEXT_TRUNCATION_LENGTH,
- rope,
- );
-
- let (mut signature_range_in_file, signature_line_range, mut signature_is_truncated) =
- expand_range_to_line_boundaries_and_truncate(
- &declaration.signature_range,
- ITEM_TEXT_TRUNCATION_LENGTH,
- rope,
- );
-
- if signature_range_in_file.start < item_range_in_file.start {
- signature_range_in_file.start = item_range_in_file.start;
- signature_is_truncated = true;
- }
- if signature_range_in_file.end > item_range_in_file.end {
- signature_range_in_file.end = item_range_in_file.end;
- signature_is_truncated = true;
- }
-
- FileDeclaration {
- parent: None,
- identifier: declaration.identifier,
- signature_range: signature_range_in_file,
- signature_line_range,
- signature_is_truncated,
- text: rope
- .chunks_in_range(item_range_in_file.clone())
- .collect::<String>()
- .into(),
- text_is_truncated,
- item_range: item_range_in_file,
- item_line_range: item_line_range_in_file,
- }
- }
-}
-
-#[derive(Debug, Clone)]
-pub struct BufferDeclaration {
- pub parent: Option<DeclarationId>,
- pub identifier: Identifier,
- pub item_range: Range<usize>,
- pub item_range_is_truncated: bool,
- pub signature_range: Range<usize>,
- pub signature_range_is_truncated: bool,
-}
-
-impl BufferDeclaration {
- pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
- let (item_range, _item_line_range, item_range_is_truncated) =
- expand_range_to_line_boundaries_and_truncate(
- &declaration.item_range,
- ITEM_TEXT_TRUNCATION_LENGTH,
- rope,
- );
- let (signature_range, _signature_line_range, signature_range_is_truncated) =
- expand_range_to_line_boundaries_and_truncate(
- &declaration.signature_range,
- ITEM_TEXT_TRUNCATION_LENGTH,
- rope,
- );
- Self {
- parent: None,
- identifier: declaration.identifier,
- item_range,
- item_range_is_truncated,
- signature_range,
- signature_range_is_truncated,
- }
- }
-}
-
-#[derive(Debug, Clone)]
-pub struct CachedDeclarationPath {
- pub worktree_abs_path: Arc<Path>,
- pub rel_path: Arc<RelPath>,
- /// The relative path of the file, possibly stripped according to `import_path_strip_regex`.
- pub rel_path_after_regex_stripping: Arc<RelPath>,
-}
-
-impl CachedDeclarationPath {
- pub fn new(
- worktree_abs_path: Arc<Path>,
- path: &Arc<RelPath>,
- language: Option<&Arc<Language>>,
- ) -> Self {
- let rel_path = path.clone();
- let rel_path_after_regex_stripping = if let Some(language) = language
- && let Some(strip_regex) = language.config().import_path_strip_regex.as_ref()
- && let Ok(stripped) = RelPath::unix(&Path::new(
- strip_regex.replace_all(rel_path.as_unix_str(), "").as_ref(),
- )) {
- Arc::from(stripped)
- } else {
- rel_path.clone()
- };
- CachedDeclarationPath {
- worktree_abs_path,
- rel_path,
- rel_path_after_regex_stripping,
- }
- }
-
- #[cfg(test)]
- pub fn new_for_test(worktree_abs_path: &str, rel_path: &str) -> Self {
- let rel_path: Arc<RelPath> = util::rel_path::rel_path(rel_path).into();
- CachedDeclarationPath {
- worktree_abs_path: std::path::PathBuf::from(worktree_abs_path).into(),
- rel_path_after_regex_stripping: rel_path.clone(),
- rel_path,
- }
- }
-
- pub fn ends_with_posix_path(&self, path: &Path) -> bool {
- if path.as_os_str().len() <= self.rel_path_after_regex_stripping.as_unix_str().len() {
- path_ends_with(self.rel_path_after_regex_stripping.as_std_path(), path)
- } else {
- if let Some(remaining) =
- strip_path_suffix(path, self.rel_path_after_regex_stripping.as_std_path())
- {
- path_ends_with(&self.worktree_abs_path, remaining)
- } else {
- false
- }
- }
- }
-
- pub fn equals_absolute_path(&self, path: &Path) -> bool {
- if let Some(remaining) =
- strip_path_suffix(path, &self.rel_path_after_regex_stripping.as_std_path())
- {
- self.worktree_abs_path.as_ref() == remaining
- } else {
- false
- }
- }
-}
@@ -1,539 +0,0 @@
-use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
-use collections::HashMap;
-use language::BufferSnapshot;
-use ordered_float::OrderedFloat;
-use project::ProjectEntryId;
-use serde::Serialize;
-use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
-use strum::EnumIter;
-use text::{Point, ToPoint};
-use util::RangeExt as _;
-
-use crate::{
- CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
- imports::{Import, Imports, Module},
- reference::{Reference, ReferenceRegion},
- syntax_index::SyntaxIndexState,
- text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
-};
-
-const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
-
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub struct EditPredictionScoreOptions {
- pub omit_excerpt_overlaps: bool,
-}
-
-#[derive(Clone, Debug)]
-pub struct ScoredDeclaration {
- /// identifier used by the local reference
- pub identifier: Identifier,
- pub declaration: Declaration,
- pub components: DeclarationScoreComponents,
-}
-
-#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
-pub enum DeclarationStyle {
- Signature,
- Declaration,
-}
-
-#[derive(Clone, Debug, Serialize, Default)]
-pub struct DeclarationScores {
- pub signature: f32,
- pub declaration: f32,
- pub retrieval: f32,
-}
-
-impl ScoredDeclaration {
- /// Returns the score for this declaration with the specified style.
- pub fn score(&self, style: DeclarationStyle) -> f32 {
- // TODO: handle truncation
-
- // Score related to how likely this is the correct declaration, range 0 to 1
- let retrieval = self.retrieval_score();
-
- // Score related to the distance between the reference and cursor, range 0 to 1
- let distance_score = if self.components.is_referenced_nearby {
- 1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
- } else {
- // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
- 0.5
- };
-
- // For now instead of linear combination, the scores are just multiplied together.
- let combined_score = 10.0 * retrieval * distance_score;
-
- match style {
- DeclarationStyle::Signature => {
- combined_score * self.components.excerpt_vs_signature_weighted_overlap
- }
- DeclarationStyle::Declaration => {
- 2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
- }
- }
- }
-
- pub fn retrieval_score(&self) -> f32 {
- let mut score = if self.components.is_same_file {
- 10.0 / self.components.same_file_declaration_count as f32
- } else if self.components.path_import_match_count > 0 {
- 3.0
- } else if self.components.wildcard_path_import_match_count > 0 {
- 1.0
- } else if self.components.normalized_import_similarity > 0.0 {
- self.components.normalized_import_similarity
- } else if self.components.normalized_wildcard_import_similarity > 0.0 {
- 0.5 * self.components.normalized_wildcard_import_similarity
- } else {
- 1.0 / self.components.declaration_count as f32
- };
- score *= 1. + self.components.included_by_others as f32 / 2.;
- score *= 1. + self.components.includes_others as f32 / 4.;
- score
- }
-
- pub fn size(&self, style: DeclarationStyle) -> usize {
- match &self.declaration {
- Declaration::File { declaration, .. } => match style {
- DeclarationStyle::Signature => declaration.signature_range.len(),
- DeclarationStyle::Declaration => declaration.text.len(),
- },
- Declaration::Buffer { declaration, .. } => match style {
- DeclarationStyle::Signature => declaration.signature_range.len(),
- DeclarationStyle::Declaration => declaration.item_range.len(),
- },
- }
- }
-
- pub fn score_density(&self, style: DeclarationStyle) -> f32 {
- self.score(style) / self.size(style) as f32
- }
-}
-
-pub fn scored_declarations(
- options: &EditPredictionScoreOptions,
- index: &SyntaxIndexState,
- excerpt: &EditPredictionExcerpt,
- excerpt_occurrences: &Occurrences,
- adjacent_occurrences: &Occurrences,
- imports: &Imports,
- identifier_to_references: HashMap<Identifier, Vec<Reference>>,
- cursor_offset: usize,
- current_buffer: &BufferSnapshot,
-) -> Vec<ScoredDeclaration> {
-    let cursor_point = cursor_offset.to_point(&current_buffer);
-
- let mut wildcard_import_occurrences = Vec::new();
- let mut wildcard_import_paths = Vec::new();
- for wildcard_import in imports.wildcard_modules.iter() {
- match wildcard_import {
- Module::Namespace(namespace) => {
- wildcard_import_occurrences.push(namespace.occurrences())
- }
- Module::SourceExact(path) => wildcard_import_paths.push(path),
- Module::SourceFuzzy(path) => {
- wildcard_import_occurrences.push(Occurrences::from_path(&path))
- }
- }
- }
-
- let mut scored_declarations = Vec::new();
- let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
- HashMap::default();
- for (identifier, references) in identifier_to_references {
- let mut import_occurrences = Vec::new();
- let mut import_paths = Vec::new();
- let mut found_external_identifier: Option<&Identifier> = None;
-
- if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
- // only use alias when it's the only import, could be generalized if some language
- // has overlapping aliases
- //
- // TODO: when an aliased declaration is included in the prompt, should include the
- // aliasing in the prompt.
- //
- // TODO: For SourceFuzzy consider having componentwise comparison that pays
- // attention to ordering.
- if let [
- Import::Alias {
- module,
- external_identifier,
- },
- ] = imports.as_slice()
- {
- match module {
- Module::Namespace(namespace) => {
- import_occurrences.push(namespace.occurrences())
- }
- Module::SourceExact(path) => import_paths.push(path),
- Module::SourceFuzzy(path) => {
- import_occurrences.push(Occurrences::from_path(&path))
- }
- }
- found_external_identifier = Some(&external_identifier);
- } else {
- for import in imports {
- match import {
- Import::Direct { module } => match module {
- Module::Namespace(namespace) => {
- import_occurrences.push(namespace.occurrences())
- }
- Module::SourceExact(path) => import_paths.push(path),
- Module::SourceFuzzy(path) => {
- import_occurrences.push(Occurrences::from_path(&path))
- }
- },
- Import::Alias { .. } => {}
- }
- }
- }
- }
-
- let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
- // TODO: update this to be able to return more declarations? Especially if there is the
- // ability to quickly filter a large list (based on imports)
- let identifier_declarations = index
- .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
- let declaration_count = identifier_declarations.len();
-
- if declaration_count == 0 {
- continue;
- }
-
- // TODO: option to filter out other candidates when same file / import match
- let mut checked_declarations = Vec::with_capacity(declaration_count);
- for (declaration_id, declaration) in identifier_declarations {
- match declaration {
- Declaration::Buffer {
- buffer_id,
- declaration: buffer_declaration,
- ..
- } => {
-                if buffer_id == &current_buffer.remote_id() {
- let already_included_in_prompt =
- range_intersection(&buffer_declaration.item_range, &excerpt.range)
- .is_some()
- || excerpt
- .parent_declarations
- .iter()
- .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
- if !options.omit_excerpt_overlaps || !already_included_in_prompt {
- let declaration_line = buffer_declaration
- .item_range
- .start
- .to_point(current_buffer)
- .row;
- let declaration_line_distance =
- (cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
- checked_declarations.push(CheckedDeclaration {
- declaration,
- same_file_line_distance: Some(declaration_line_distance),
- path_import_match_count: 0,
- wildcard_path_import_match_count: 0,
- });
- }
- continue;
- } else {
- }
- }
- Declaration::File { .. } => {}
- }
- let declaration_path = declaration.cached_path();
- let path_import_match_count = import_paths
- .iter()
- .filter(|import_path| {
- declaration_path_matches_import(&declaration_path, import_path)
- })
- .count();
- let wildcard_path_import_match_count = wildcard_import_paths
- .iter()
- .filter(|import_path| {
- declaration_path_matches_import(&declaration_path, import_path)
- })
- .count();
- checked_declarations.push(CheckedDeclaration {
- declaration,
- same_file_line_distance: None,
- path_import_match_count,
- wildcard_path_import_match_count,
- });
- }
-
- let mut max_import_similarity = 0.0;
- let mut max_wildcard_import_similarity = 0.0;
-
- let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
- for checked_declaration in checked_declarations {
- let same_file_declaration_count =
- index.file_declaration_count(checked_declaration.declaration);
-
- let declaration = score_declaration(
- &identifier,
- &references,
- checked_declaration,
- same_file_declaration_count,
- declaration_count,
- &excerpt_occurrences,
- &adjacent_occurrences,
- &import_occurrences,
- &wildcard_import_occurrences,
- cursor_point,
- current_buffer,
- );
-
- if declaration.components.import_similarity > max_import_similarity {
- max_import_similarity = declaration.components.import_similarity;
- }
-
- if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
- max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
- }
-
- project_entry_id_to_outline_ranges
- .entry(declaration.declaration.project_entry_id())
- .or_default()
- .push(declaration.declaration.item_range());
- scored_declarations_for_identifier.push(declaration);
- }
-
- if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
- for declaration in scored_declarations_for_identifier.iter_mut() {
- if max_import_similarity > 0.0 {
- declaration.components.max_import_similarity = max_import_similarity;
- declaration.components.normalized_import_similarity =
- declaration.components.import_similarity / max_import_similarity;
- }
- if max_wildcard_import_similarity > 0.0 {
- declaration.components.normalized_wildcard_import_similarity =
- declaration.components.wildcard_import_similarity
- / max_wildcard_import_similarity;
- }
- }
- }
-
- scored_declarations.extend(scored_declarations_for_identifier);
- }
-
- // TODO: Inform this via import / retrieval scores of outline items
- // TODO: Consider using a sweepline
- for scored_declaration in scored_declarations.iter_mut() {
- let project_entry_id = scored_declaration.declaration.project_entry_id();
- let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
- continue;
- };
- for range in ranges {
- if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
- scored_declaration.components.included_by_others += 1
- } else if scored_declaration
- .declaration
- .item_range()
- .contains_inclusive(range)
- {
- scored_declaration.components.includes_others += 1
- }
- }
- }
-
- scored_declarations.sort_unstable_by_key(|declaration| {
- Reverse(OrderedFloat(
- declaration.score(DeclarationStyle::Declaration),
- ))
- });
-
- scored_declarations
-}
-
-struct CheckedDeclaration<'a> {
- declaration: &'a Declaration,
- same_file_line_distance: Option<u32>,
- path_import_match_count: usize,
- wildcard_path_import_match_count: usize,
-}
-
-fn declaration_path_matches_import(
- declaration_path: &CachedDeclarationPath,
- import_path: &Arc<Path>,
-) -> bool {
- if import_path.is_absolute() {
- declaration_path.equals_absolute_path(import_path)
- } else {
- declaration_path.ends_with_posix_path(import_path)
- }
-}
-
-fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
- let start = a.start.clone().max(b.start.clone());
- let end = a.end.clone().min(b.end.clone());
- if start < end {
- Some(Range { start, end })
- } else {
- None
- }
-}
-
-fn score_declaration(
- identifier: &Identifier,
- references: &[Reference],
- checked_declaration: CheckedDeclaration,
- same_file_declaration_count: usize,
- declaration_count: usize,
- excerpt_occurrences: &Occurrences,
- adjacent_occurrences: &Occurrences,
- import_occurrences: &[Occurrences],
- wildcard_import_occurrences: &[Occurrences],
- cursor: Point,
- current_buffer: &BufferSnapshot,
-) -> ScoredDeclaration {
- let CheckedDeclaration {
- declaration,
- same_file_line_distance,
- path_import_match_count,
- wildcard_path_import_match_count,
- } = checked_declaration;
-
- let is_referenced_nearby = references
- .iter()
- .any(|r| r.region == ReferenceRegion::Nearby);
- let is_referenced_in_breadcrumb = references
- .iter()
- .any(|r| r.region == ReferenceRegion::Breadcrumb);
- let reference_count = references.len();
- let reference_line_distance = references
- .iter()
- .map(|r| {
- let reference_line = r.range.start.to_point(current_buffer).row as i32;
- (cursor.row as i32 - reference_line).unsigned_abs()
- })
- .min()
- .unwrap();
-
- let is_same_file = same_file_line_distance.is_some();
- let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
-
- let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
- let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
- let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
- let excerpt_vs_signature_jaccard =
- jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
- let adjacent_vs_item_jaccard =
- jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
- let adjacent_vs_signature_jaccard =
- jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
-
- let excerpt_vs_item_weighted_overlap =
- weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
- let excerpt_vs_signature_weighted_overlap =
- weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
- let adjacent_vs_item_weighted_overlap =
- weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
- let adjacent_vs_signature_weighted_overlap =
- weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
-
- let mut import_similarity = 0f32;
- let mut wildcard_import_similarity = 0f32;
- if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
- let cached_path = declaration.cached_path();
- let path_occurrences = Occurrences::from_worktree_path(
- cached_path
- .worktree_abs_path
- .file_name()
- .map(|f| f.to_string_lossy()),
- &cached_path.rel_path,
- );
- import_similarity = import_occurrences
- .iter()
- .map(|namespace_occurrences| {
- OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
- })
- .max()
- .map(|similarity| similarity.into_inner())
- .unwrap_or_default();
-
- // TODO: Consider something other than max
- wildcard_import_similarity = wildcard_import_occurrences
- .iter()
- .map(|namespace_occurrences| {
- OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
- })
- .max()
- .map(|similarity| similarity.into_inner())
- .unwrap_or_default();
- }
-
- // TODO: Consider adding declaration_file_count
- let score_components = DeclarationScoreComponents {
- is_same_file,
- is_referenced_nearby,
- is_referenced_in_breadcrumb,
- reference_line_distance,
- declaration_line_distance,
- reference_count,
- same_file_declaration_count,
- declaration_count,
- excerpt_vs_item_jaccard,
- excerpt_vs_signature_jaccard,
- adjacent_vs_item_jaccard,
- adjacent_vs_signature_jaccard,
- excerpt_vs_item_weighted_overlap,
- excerpt_vs_signature_weighted_overlap,
- adjacent_vs_item_weighted_overlap,
- adjacent_vs_signature_weighted_overlap,
- path_import_match_count,
- wildcard_path_import_match_count,
- import_similarity,
- max_import_similarity: 0.0,
- normalized_import_similarity: 0.0,
- wildcard_import_similarity,
- normalized_wildcard_import_similarity: 0.0,
- included_by_others: 0,
- includes_others: 0,
- };
-
- ScoredDeclaration {
- identifier: identifier.clone(),
- declaration: declaration.clone(),
- components: score_components,
- }
-}
-
-#[cfg(test)]
-mod test {
- use super::*;
-
- #[test]
- fn test_declaration_path_matches() {
- let declaration_path =
- CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");
-
- assert!(declaration_path_matches_import(
- &declaration_path,
- &Path::new("maths.ts").into()
- ));
-
- assert!(declaration_path_matches_import(
- &declaration_path,
- &Path::new("project/src/maths.ts").into()
- ));
-
- assert!(declaration_path_matches_import(
- &declaration_path,
- &Path::new("user/project/src/maths.ts").into()
- ));
-
- assert!(declaration_path_matches_import(
- &declaration_path,
- &Path::new("/home/user/project/src/maths.ts").into()
- ));
-
- assert!(!declaration_path_matches_import(
- &declaration_path,
- &Path::new("other.ts").into()
- ));
-
- assert!(!declaration_path_matches_import(
- &declaration_path,
- &Path::new("/home/user/project/src/other.ts").into()
- ));
- }
-}
@@ -1,335 +1,469 @@
-mod declaration;
-mod declaration_scoring;
+use crate::assemble_excerpts::assemble_excerpts;
+use anyhow::Result;
+use collections::HashMap;
+use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
+use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
+use project::{LocationLink, Project, ProjectPath};
+use serde::{Serialize, Serializer};
+use smallvec::SmallVec;
+use std::{
+ collections::hash_map,
+ ops::Range,
+ sync::Arc,
+ time::{Duration, Instant},
+};
+use util::{RangeExt as _, ResultExt};
+
+mod assemble_excerpts;
+#[cfg(test)]
+mod edit_prediction_context_tests;
mod excerpt;
-mod imports;
-mod outline;
-mod reference;
-mod syntax_index;
-pub mod text_similarity;
+#[cfg(test)]
+mod fake_definition_lsp;
-use std::{path::Path, sync::Arc};
+pub use cloud_llm_client::predict_edits_v3::Line;
+pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
-use cloud_llm_client::predict_edits_v3;
-use collections::HashMap;
-use gpui::{App, AppContext as _, Entity, Task};
-use language::BufferSnapshot;
-use text::{Point, ToOffset as _};
-
-pub use declaration::*;
-pub use declaration_scoring::*;
-pub use excerpt::*;
-pub use imports::*;
-pub use reference::*;
-pub use syntax_index::*;
-
-pub use predict_edits_v3::Line;
-
-#[derive(Clone, Debug, PartialEq)]
-pub struct EditPredictionContextOptions {
- pub use_imports: bool,
- pub excerpt: EditPredictionExcerptOptions,
- pub score: EditPredictionScoreOptions,
- pub max_retrieved_declarations: u8,
+pub struct RelatedExcerptStore {
+ project: WeakEntity<Project>,
+ related_files: Vec<RelatedFile>,
+ cache: HashMap<Identifier, Arc<CacheEntry>>,
+ update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
+}
+
+pub enum RelatedExcerptStoreEvent {
+ StartedRefresh,
+ FinishedRefresh {
+ cache_hit_count: usize,
+ cache_miss_count: usize,
+ mean_definition_latency: Duration,
+ max_definition_latency: Duration,
+ },
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+struct Identifier {
+ pub name: String,
+ pub range: Range<Anchor>,
+}
+
+enum DefinitionTask {
+ CacheHit(Arc<CacheEntry>),
+ CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
+}
+
+#[derive(Debug)]
+struct CacheEntry {
+ definitions: SmallVec<[CachedDefinition; 1]>,
}
#[derive(Clone, Debug)]
-pub struct EditPredictionContext {
- pub excerpt: EditPredictionExcerpt,
- pub excerpt_text: EditPredictionExcerptText,
- pub cursor_point: Point,
- pub declarations: Vec<ScoredDeclaration>,
+struct CachedDefinition {
+ path: ProjectPath,
+ buffer: Entity<Buffer>,
+ anchor_range: Range<Anchor>,
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct RelatedFile {
+ #[serde(serialize_with = "serialize_project_path")]
+ pub path: ProjectPath,
+ #[serde(skip)]
+ pub buffer: WeakEntity<Buffer>,
+ pub excerpts: Vec<RelatedExcerpt>,
+ pub max_row: u32,
}
-impl EditPredictionContext {
- pub fn gather_context_in_background(
- cursor_point: Point,
- buffer: BufferSnapshot,
- options: EditPredictionContextOptions,
- syntax_index: Option<Entity<SyntaxIndex>>,
- cx: &mut App,
- ) -> Task<Option<Self>> {
- let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
- let mut path = f.worktree.read(cx).absolutize(&f.path);
- if path.pop() { Some(path) } else { None }
+impl RelatedFile {
+ pub fn merge_excerpts(&mut self) {
+ self.excerpts.sort_unstable_by(|a, b| {
+ a.point_range
+ .start
+ .cmp(&b.point_range.start)
+ .then(b.point_range.end.cmp(&a.point_range.end))
});
- if let Some(syntax_index) = syntax_index {
- let index_state =
- syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
- cx.background_spawn(async move {
- let parent_abs_path = parent_abs_path.as_deref();
- let index_state = index_state.upgrade()?;
- let index_state = index_state.lock().await;
- Self::gather_context(
- cursor_point,
- &buffer,
- parent_abs_path,
- &options,
- Some(&index_state),
- )
- })
- } else {
- cx.background_spawn(async move {
- let parent_abs_path = parent_abs_path.as_deref();
- Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
- })
+ let mut index = 1;
+ while index < self.excerpts.len() {
+ if self.excerpts[index - 1]
+ .point_range
+ .end
+ .cmp(&self.excerpts[index].point_range.start)
+ .is_ge()
+ {
+ let removed = self.excerpts.remove(index);
+ if removed
+ .point_range
+ .end
+ .cmp(&self.excerpts[index - 1].point_range.end)
+ .is_gt()
+ {
+ self.excerpts[index - 1].point_range.end = removed.point_range.end;
+ self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
+ }
+ } else {
+ index += 1;
+ }
}
}
+}
- pub fn gather_context(
- cursor_point: Point,
- buffer: &BufferSnapshot,
- parent_abs_path: Option<&Path>,
- options: &EditPredictionContextOptions,
- index_state: Option<&SyntaxIndexState>,
- ) -> Option<Self> {
- let imports = if options.use_imports {
- Imports::gather(&buffer, parent_abs_path)
- } else {
- Imports::default()
- };
- Self::gather_context_with_references_fn(
- cursor_point,
- buffer,
- &imports,
- options,
- index_state,
- references_in_excerpt,
- )
- }
+#[derive(Clone, Debug, Serialize)]
+pub struct RelatedExcerpt {
+ #[serde(skip)]
+ pub anchor_range: Range<Anchor>,
+ #[serde(serialize_with = "serialize_point_range")]
+ pub point_range: Range<Point>,
+ #[serde(serialize_with = "serialize_rope")]
+ pub text: Rope,
+}
- pub fn gather_context_with_references_fn(
- cursor_point: Point,
- buffer: &BufferSnapshot,
- imports: &Imports,
- options: &EditPredictionContextOptions,
- index_state: Option<&SyntaxIndexState>,
- get_references: impl FnOnce(
- &EditPredictionExcerpt,
- &EditPredictionExcerptText,
- &BufferSnapshot,
- ) -> HashMap<Identifier, Vec<Reference>>,
- ) -> Option<Self> {
- let excerpt = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- buffer,
- &options.excerpt,
- index_state,
- )?;
- let excerpt_text = excerpt.text(buffer);
-
- let declarations = if options.max_retrieved_declarations > 0
- && let Some(index_state) = index_state
- {
- let excerpt_occurrences =
- text_similarity::Occurrences::within_string(&excerpt_text.body);
-
- let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
- let adjacent_end = Point::new(cursor_point.row + 1, 0);
- let adjacent_occurrences = text_similarity::Occurrences::within_string(
- &buffer
- .text_for_range(adjacent_start..adjacent_end)
- .collect::<String>(),
- );
+fn serialize_project_path<S: Serializer>(
+ project_path: &ProjectPath,
+ serializer: S,
+) -> Result<S::Ok, S::Error> {
+ project_path.path.serialize(serializer)
+}
- let cursor_offset_in_file = cursor_point.to_offset(buffer);
+fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
+ rope.to_string().serialize(serializer)
+}
- let references = get_references(&excerpt, &excerpt_text, buffer);
+fn serialize_point_range<S: Serializer>(
+ range: &Range<Point>,
+ serializer: S,
+) -> Result<S::Ok, S::Error> {
+ [
+ [range.start.row, range.start.column],
+ [range.end.row, range.end.column],
+ ]
+ .serialize(serializer)
+}
- let mut declarations = scored_declarations(
- &options.score,
- &index_state,
- &excerpt,
- &excerpt_occurrences,
- &adjacent_occurrences,
- &imports,
- references,
- cursor_offset_in_file,
- buffer,
- );
- // TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way
- declarations.truncate(options.max_retrieved_declarations as usize);
- declarations
- } else {
- vec![]
- };
+const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
+
+impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
+
+impl RelatedExcerptStore {
+ pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
+ let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
+ cx.spawn(async move |this, cx| {
+ let executor = cx.background_executor().clone();
+ while let Some((mut buffer, mut position)) = update_rx.next().await {
+ let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
+ loop {
+ futures::select_biased! {
+ next = update_rx.next() => {
+ if let Some((new_buffer, new_position)) = next {
+ buffer = new_buffer;
+ position = new_position;
+ timer = executor.timer(DEBOUNCE_DURATION).fuse();
+ } else {
+ return anyhow::Ok(());
+ }
+ }
+ _ = timer => break,
+ }
+ }
- Some(Self {
- excerpt,
- excerpt_text,
- cursor_point,
- declarations,
+ Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
+ }
+ anyhow::Ok(())
})
+ .detach_and_log_err(cx);
+
+ RelatedExcerptStore {
+ project: project.downgrade(),
+ update_tx,
+ related_files: Vec::new(),
+ cache: Default::default(),
+ }
}
-}
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::sync::Arc;
-
- use gpui::{Entity, TestAppContext};
- use indoc::indoc;
- use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
- use project::{FakeFs, Project};
- use serde_json::json;
- use settings::SettingsStore;
- use util::path;
-
- use crate::{EditPredictionExcerptOptions, SyntaxIndex};
-
- #[gpui::test]
- async fn test_call_site(cx: &mut TestAppContext) {
- let (project, index, _rust_lang_id) = init_test(cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let project_path = project.find_project_path("c.rs", cx).unwrap();
- project.open_buffer(project_path, cx)
- })
- .await
- .unwrap();
-
- cx.run_until_parked();
-
- // first process_data call site
- let cursor_point = language::Point::new(8, 21);
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let context = cx
- .update(|cx| {
- EditPredictionContext::gather_context_in_background(
- cursor_point,
- buffer_snapshot,
- EditPredictionContextOptions {
- use_imports: true,
- excerpt: EditPredictionExcerptOptions {
- max_bytes: 60,
- min_bytes: 10,
- target_before_cursor_over_total_bytes: 0.5,
- },
- score: EditPredictionScoreOptions {
- omit_excerpt_overlaps: true,
- },
- max_retrieved_declarations: u8::MAX,
- },
- Some(index.clone()),
- cx,
- )
- })
- .await
- .unwrap();
-
- let mut snippet_identifiers = context
- .declarations
- .iter()
- .map(|snippet| snippet.identifier.name.as_ref())
- .collect::<Vec<_>>();
- snippet_identifiers.sort();
- assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
- drop(buffer);
+ pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
+ self.update_tx.unbounded_send((buffer, position)).ok();
}
- async fn init_test(
- cx: &mut TestAppContext,
- ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- });
+ pub fn related_files(&self) -> &[RelatedFile] {
+ &self.related_files
+ }
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/root"),
- json!({
- "a.rs": indoc! {r#"
- fn main() {
- let x = 1;
- let y = 2;
- let z = add(x, y);
- println!("Result: {}", z);
- }
+ async fn fetch_excerpts(
+ this: WeakEntity<Self>,
+ buffer: Entity<Buffer>,
+ position: Anchor,
+ cx: &mut AsyncApp,
+ ) -> Result<()> {
+ let (project, snapshot) = this.read_with(cx, |this, cx| {
+ (this.project.upgrade(), buffer.read(cx).snapshot())
+ })?;
+ let Some(project) = project else {
+ return Ok(());
+ };
- fn add(a: i32, b: i32) -> i32 {
- a + b
- }
- "#},
- "b.rs": indoc! {"
- pub struct Config {
- pub name: String,
- pub value: i32,
- }
+ let file = snapshot.file().cloned();
+ if let Some(file) = &file {
+ log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
+ }
- impl Config {
- pub fn new(name: String, value: i32) -> Self {
- Config { name, value }
+ this.update(cx, |_, cx| {
+ cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
+ })?;
+
+ let identifiers = cx
+ .background_spawn(async move { identifiers_for_position(&snapshot, position) })
+ .await;
+
+ let async_cx = cx.clone();
+ let start_time = Instant::now();
+ let futures = this.update(cx, |this, cx| {
+ identifiers
+ .into_iter()
+ .filter_map(|identifier| {
+ let task = if let Some(entry) = this.cache.get(&identifier) {
+ DefinitionTask::CacheHit(entry.clone())
+ } else {
+ DefinitionTask::CacheMiss(
+ this.project
+ .update(cx, |project, cx| {
+ project.definitions(&buffer, identifier.range.start, cx)
+ })
+ .ok()?,
+ )
+ };
+
+ let cx = async_cx.clone();
+ let project = project.clone();
+ Some(async move {
+ match task {
+ DefinitionTask::CacheHit(cache_entry) => {
+ Some((identifier, cache_entry, None))
+ }
+ DefinitionTask::CacheMiss(task) => {
+ let locations = task.await.log_err()??;
+ let duration = start_time.elapsed();
+ cx.update(|cx| {
+ (
+ identifier,
+ Arc::new(CacheEntry {
+ definitions: locations
+ .into_iter()
+ .filter_map(|location| {
+ process_definition(location, &project, cx)
+ })
+ .collect(),
+ }),
+ Some(duration),
+ )
+ })
+ .ok()
+ }
}
- }
- "},
- "c.rs": indoc! {r#"
- use std::collections::HashMap;
-
- fn main() {
- let args: Vec<String> = std::env::args().collect();
- let data: Vec<i32> = args[1..]
- .iter()
- .filter_map(|s| s.parse().ok())
- .collect();
- let result = process_data(data);
- println!("{:?}", result);
- }
+ })
+ })
+ .collect::<Vec<_>>()
+ })?;
+
+ let mut cache_hit_count = 0;
+ let mut cache_miss_count = 0;
+ let mut mean_definition_latency = Duration::ZERO;
+ let mut max_definition_latency = Duration::ZERO;
+ let mut new_cache = HashMap::default();
+ new_cache.reserve(futures.len());
+ for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
+ new_cache.insert(identifier, entry);
+ if let Some(duration) = duration {
+ cache_miss_count += 1;
+ mean_definition_latency += duration;
+ max_definition_latency = max_definition_latency.max(duration);
+ } else {
+ cache_hit_count += 1;
+ }
+ }
+ mean_definition_latency /= cache_miss_count.max(1) as u32;
- fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
- let mut counts = HashMap::new();
- for value in data {
- *counts.entry(value).or_insert(0) += 1;
- }
- counts
- }
+ let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
- #[cfg(test)]
- mod tests {
- use super::*;
+ if let Some(file) = &file {
+ log::debug!(
+ "finished retrieving context buffer:{}, latency:{:?}",
+ file.path().as_unix_str(),
+ start_time.elapsed()
+ );
+ }
- #[test]
- fn test_process_data() {
- let data = vec![1, 2, 2, 3];
- let result = process_data(data);
- assert_eq!(result.get(&2), Some(&2));
- }
- }
- "#}
- }),
- )
- .await;
- let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
- let language_registry = project.read_with(cx, |project, _| project.languages().clone());
- let lang = rust_lang();
- let lang_id = lang.id();
- language_registry.add(Arc::new(lang));
-
- let file_indexing_parallelism = 2;
- let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
- cx.run_until_parked();
-
- (project, index, lang_id)
+ this.update(cx, |this, cx| {
+ this.cache = new_cache;
+ this.related_files = related_files;
+ cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
+ cache_hit_count,
+ cache_miss_count,
+ mean_definition_latency,
+ max_definition_latency,
+ });
+ })?;
+
+ anyhow::Ok(())
}
+}
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
- .unwrap()
- .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
- .unwrap()
+async fn rebuild_related_files(
+ new_entries: HashMap<Identifier, Arc<CacheEntry>>,
+ cx: &mut AsyncApp,
+) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
+ let mut snapshots = HashMap::default();
+ for entry in new_entries.values() {
+ for definition in &entry.definitions {
+ if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
+ definition
+ .buffer
+ .read_with(cx, |buffer, _| buffer.parsing_idle())?
+ .await;
+ e.insert(
+ definition
+ .buffer
+ .read_with(cx, |buffer, _| buffer.snapshot())?,
+ );
+ }
+ }
}
+
+ Ok(cx
+ .background_spawn(async move {
+ let mut files = Vec::<RelatedFile>::new();
+ let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
+ let mut paths_by_buffer = HashMap::default();
+ for entry in new_entries.values() {
+ for definition in &entry.definitions {
+ let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
+ continue;
+ };
+ paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
+ ranges_by_buffer
+ .entry(definition.buffer.clone())
+ .or_default()
+ .push(definition.anchor_range.to_point(snapshot));
+ }
+ }
+
+ for (buffer, ranges) in ranges_by_buffer {
+ let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
+ continue;
+ };
+ let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
+ continue;
+ };
+ let excerpts = assemble_excerpts(snapshot, ranges);
+ files.push(RelatedFile {
+ path: project_path.clone(),
+ buffer: buffer.downgrade(),
+ excerpts,
+ max_row: snapshot.max_point().row,
+ });
+ }
+
+ files.sort_by_key(|file| file.path.clone());
+ (new_entries, files)
+ })
+ .await)
+}
+
+fn process_definition(
+ location: LocationLink,
+ project: &Entity<Project>,
+ cx: &mut App,
+) -> Option<CachedDefinition> {
+ let buffer = location.target.buffer.read(cx);
+ let anchor_range = location.target.range;
+ let file = buffer.file()?;
+ let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
+ if worktree.read(cx).is_single_file() {
+ return None;
+ }
+ Some(CachedDefinition {
+ path: ProjectPath {
+ worktree_id: file.worktree_id(cx),
+ path: file.path().clone(),
+ },
+ buffer: location.target.buffer,
+ anchor_range,
+ })
+}
+
+/// Gets all of the identifiers that are present in the given line, and its containing
+/// outline items.
+fn identifiers_for_position(buffer: &BufferSnapshot, position: Anchor) -> Vec<Identifier> {
+ let offset = position.to_offset(buffer);
+ let point = buffer.offset_to_point(offset);
+
+ let line_range = Point::new(point.row, 0)..Point::new(point.row + 1, 0).min(buffer.max_point());
+ let mut ranges = vec![line_range.to_offset(&buffer)];
+
+ // Include the range of the outline item itself, but not its body.
+ let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
+ for item in outline_items {
+ if let Some(body_range) = item.body_range(&buffer) {
+ ranges.push(item.range.start..body_range.start.to_offset(&buffer));
+ } else {
+ ranges.push(item.range.clone());
+ }
+ }
+
+ ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
+ ranges.dedup_by(|a, b| {
+ if a.start <= b.end {
+ b.start = b.start.min(a.start);
+ b.end = b.end.max(a.end);
+ true
+ } else {
+ false
+ }
+ });
+
+ let mut identifiers = Vec::new();
+ let outer_range =
+ ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
+
+ let mut captures = buffer
+ .syntax
+ .captures(outer_range.clone(), &buffer.text, |grammar| {
+ grammar
+ .highlights_config
+ .as_ref()
+ .map(|config| &config.query)
+ });
+
+ for range in ranges {
+ captures.set_byte_range(range.start..outer_range.end);
+
+ let mut last_range = None;
+ while let Some(capture) = captures.peek() {
+ let node_range = capture.node.byte_range();
+ if node_range.start > range.end {
+ break;
+ }
+ let config = captures.grammars()[capture.grammar_index]
+ .highlights_config
+ .as_ref();
+
+ if let Some(config) = config
+ && config.identifier_capture_indices.contains(&capture.index)
+ && range.contains_inclusive(&node_range)
+ && Some(&node_range) != last_range.as_ref()
+ {
+ let name = buffer.text_for_range(node_range.clone()).collect();
+ identifiers.push(Identifier {
+ range: buffer.anchor_after(node_range.start)
+ ..buffer.anchor_before(node_range.end),
+ name,
+ });
+ last_range = Some(node_range);
+ }
+
+ captures.advance();
+ }
+ }
+
+ identifiers
}
@@ -1,11 +1,9 @@
-use language::{BufferSnapshot, LanguageId};
+use cloud_llm_client::predict_edits_v3::Line;
+use language::{BufferSnapshot, LanguageId, Point, ToOffset as _, ToPoint as _};
use std::ops::Range;
-use text::{Point, ToOffset as _, ToPoint as _};
use tree_sitter::{Node, TreeCursor};
use util::RangeExt;
-use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState};
-
// TODO:
//
// - Test parent signatures
@@ -31,19 +29,16 @@ pub struct EditPredictionExcerptOptions {
pub target_before_cursor_over_total_bytes: f32,
}
-// TODO: consider merging these
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
pub range: Range<usize>,
pub line_range: Range<Line>,
- pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
pub size: usize,
}
#[derive(Debug, Clone)]
pub struct EditPredictionExcerptText {
pub body: String,
- pub parent_signatures: Vec<String>,
pub language_id: Option<LanguageId>,
}
@@ -52,17 +47,8 @@ impl EditPredictionExcerpt {
let body = buffer
.text_for_range(self.range.clone())
.collect::<String>();
- let parent_signatures = self
- .parent_declarations
- .iter()
- .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
- .collect();
let language_id = buffer.language().map(|l| l.id());
- EditPredictionExcerptText {
- body,
- parent_signatures,
- language_id,
- }
+ EditPredictionExcerptText { body, language_id }
}
/// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
@@ -79,7 +65,6 @@ impl EditPredictionExcerpt {
query_point: Point,
buffer: &BufferSnapshot,
options: &EditPredictionExcerptOptions,
- syntax_index: Option<&SyntaxIndexState>,
) -> Option<Self> {
if buffer.len() <= options.max_bytes {
log::debug!(
@@ -89,11 +74,7 @@ impl EditPredictionExcerpt {
);
let offset_range = 0..buffer.len();
let line_range = Line(0)..Line(buffer.max_point().row);
- return Some(EditPredictionExcerpt::new(
- offset_range,
- line_range,
- Vec::new(),
- ));
+ return Some(EditPredictionExcerpt::new(offset_range, line_range));
}
let query_offset = query_point.to_offset(buffer);
@@ -104,19 +85,10 @@ impl EditPredictionExcerpt {
return None;
}
- let parent_declarations = if let Some(syntax_index) = syntax_index {
- syntax_index
- .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
- .collect()
- } else {
- Vec::new()
- };
-
let excerpt_selector = ExcerptSelector {
query_offset,
query_range,
query_line_range: Line(query_line_range.start)..Line(query_line_range.end),
- parent_declarations: &parent_declarations,
buffer,
options,
};
@@ -139,20 +111,10 @@ impl EditPredictionExcerpt {
excerpt_selector.select_lines()
}
- fn new(
- range: Range<usize>,
- line_range: Range<Line>,
- parent_declarations: Vec<(DeclarationId, Range<usize>)>,
- ) -> Self {
- let size = range.len()
- + parent_declarations
- .iter()
- .map(|(_, range)| range.len())
- .sum::<usize>();
+ fn new(range: Range<usize>, line_range: Range<Line>) -> Self {
Self {
+ size: range.len(),
range,
- parent_declarations,
- size,
line_range,
}
}
@@ -162,14 +124,7 @@ impl EditPredictionExcerpt {
// this is an issue because parent_signature_ranges may be incorrect
log::error!("bug: with_expanded_range called with disjoint range");
}
- let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
- for (declaration_id, range) in &self.parent_declarations {
- if !range.contains_inclusive(&new_range) {
- break;
- }
- parent_declarations.push((*declaration_id, range.clone()));
- }
- Self::new(new_range, new_line_range, parent_declarations)
+ Self::new(new_range, new_line_range)
}
fn parent_signatures_size(&self) -> usize {
@@ -181,7 +136,6 @@ struct ExcerptSelector<'a> {
query_offset: usize,
query_range: Range<usize>,
query_line_range: Range<Line>,
- parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
buffer: &'a BufferSnapshot,
options: &'a EditPredictionExcerptOptions,
}
@@ -409,13 +363,7 @@ impl<'a> ExcerptSelector<'a> {
}
fn make_excerpt(&self, range: Range<usize>, line_range: Range<Line>) -> EditPredictionExcerpt {
- let parent_declarations = self
- .parent_declarations
- .iter()
- .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
- .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
- .collect();
- EditPredictionExcerpt::new(range, line_range, parent_declarations)
+ EditPredictionExcerpt::new(range, line_range)
}
/// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
@@ -506,9 +454,8 @@ mod tests {
let buffer = create_buffer(&text, cx);
let cursor_point = cursor.to_point(&buffer);
- let excerpt =
- EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
- .expect("Should select an excerpt");
+ let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
+ .expect("Should select an excerpt");
pretty_assertions::assert_eq!(
generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
generate_marked_text(&text, &[expected_excerpt], false)
@@ -1,1319 +0,0 @@
-use collections::HashMap;
-use language::BufferSnapshot;
-use language::ImportsConfig;
-use language::Language;
-use std::ops::Deref;
-use std::path::Path;
-use std::sync::Arc;
-use std::{borrow::Cow, ops::Range};
-use text::OffsetRangeExt as _;
-use util::RangeExt;
-use util::paths::PathStyle;
-
-use crate::Identifier;
-use crate::text_similarity::Occurrences;
-
-// TODO: Write documentation for extension authors. The @import capture must match before or in the
-// same pattern as all all captures it contains
-
-// Future improvements to consider:
-//
-// * Distinguish absolute vs relative paths in captures. `#include "maths.h"` is relative whereas
-// `#include <maths.h>` is not.
-//
-// * Provide the name used when importing whole modules (see tests with "named_module" in the name).
-// To be useful, will require parsing of identifier qualification.
-//
-// * Scoping for imports that aren't at the top level
-//
-// * Only scan a prefix of the file, when possible. This could look like having query matches that
-// indicate it reached a declaration that is not allowed in the import section.
-//
-// * Support directly parsing to occurrences instead of storing namespaces / paths. Types should be
-// generic on this, so that tests etc can still use strings. Could do similar in syntax index.
-//
-// * Distinguish different types of namespaces when known. E.g. "name.type" capture. Once capture
-// names are more open-ended like this may make sense to build and cache a jump table (direct
-// dispatch from capture index).
-//
-// * There are a few "Language specific:" comments on behavior that gets applied to all languages.
-// Would be cleaner to be conditional on the language or otherwise configured.
-
-#[derive(Debug, Clone, Default)]
-pub struct Imports {
- pub identifier_to_imports: HashMap<Identifier, Vec<Import>>,
- pub wildcard_modules: Vec<Module>,
-}
-
-#[derive(Debug, Clone)]
-pub enum Import {
- Direct {
- module: Module,
- },
- Alias {
- module: Module,
- external_identifier: Identifier,
- },
-}
-
-#[derive(Debug, Clone)]
-pub enum Module {
- SourceExact(Arc<Path>),
- SourceFuzzy(Arc<Path>),
- Namespace(Namespace),
-}
-
-impl Module {
- fn empty() -> Self {
- Module::Namespace(Namespace::default())
- }
-
- fn push_range(
- &mut self,
- range: &ModuleRange,
- snapshot: &BufferSnapshot,
- language: &Language,
- parent_abs_path: Option<&Path>,
- ) -> usize {
- if range.is_empty() {
- return 0;
- }
-
- match range {
- ModuleRange::Source(range) => {
- if let Self::Namespace(namespace) = self
- && namespace.0.is_empty()
- {
- let path = snapshot.text_for_range(range.clone()).collect::<Cow<str>>();
-
- let path = if let Some(strip_regex) =
- language.config().import_path_strip_regex.as_ref()
- {
- strip_regex.replace_all(&path, "")
- } else {
- path
- };
-
- let path = Path::new(path.as_ref());
- if (path.starts_with(".") || path.starts_with(".."))
- && let Some(parent_abs_path) = parent_abs_path
- && let Ok(abs_path) =
- util::paths::normalize_lexically(&parent_abs_path.join(path))
- {
- *self = Self::SourceExact(abs_path.into());
- } else {
- *self = Self::SourceFuzzy(path.into());
- };
- } else if matches!(self, Self::SourceExact(_))
- || matches!(self, Self::SourceFuzzy(_))
- {
- log::warn!("bug in imports query: encountered multiple @source matches");
- } else {
- log::warn!(
- "bug in imports query: encountered both @namespace and @source match"
- );
- }
- }
- ModuleRange::Namespace(range) => {
- if let Self::Namespace(namespace) = self {
- let segment = range_text(snapshot, range);
- if language.config().ignored_import_segments.contains(&segment) {
- return 0;
- } else {
- namespace.0.push(segment);
- return 1;
- }
- } else {
- log::warn!(
- "bug in imports query: encountered both @namespace and @source match"
- );
- }
- }
- }
- 0
- }
-}
-
-#[derive(Debug, Clone)]
-enum ModuleRange {
- Source(Range<usize>),
- Namespace(Range<usize>),
-}
-
-impl Deref for ModuleRange {
- type Target = Range<usize>;
-
- fn deref(&self) -> &Self::Target {
- match self {
- ModuleRange::Source(range) => range,
- ModuleRange::Namespace(range) => range,
- }
- }
-}
-
-#[derive(Debug, Clone, PartialEq, Eq, Default)]
-pub struct Namespace(pub Vec<Arc<str>>);
-
-impl Namespace {
- pub fn occurrences(&self) -> Occurrences {
- Occurrences::from_identifiers(&self.0)
- }
-}
-
-impl Imports {
- pub fn gather(snapshot: &BufferSnapshot, parent_abs_path: Option<&Path>) -> Self {
- // Query to match different import patterns
- let mut matches = snapshot
- .syntax
- .matches(0..snapshot.len(), &snapshot.text, |grammar| {
- grammar.imports_config().map(|imports| &imports.query)
- });
-
- let mut detached_nodes: Vec<DetachedNode> = Vec::new();
- let mut identifier_to_imports = HashMap::default();
- let mut wildcard_modules = Vec::new();
- let mut import_range = None;
-
- while let Some(query_match) = matches.peek() {
- let ImportsConfig {
- query: _,
- import_ix,
- name_ix,
- namespace_ix,
- source_ix,
- list_ix,
- wildcard_ix,
- alias_ix,
- } = matches.grammars()[query_match.grammar_index]
- .imports_config()
- .unwrap();
-
- let mut new_import_range = None;
- let mut alias_range = None;
- let mut modules = Vec::new();
- let mut content: Option<(Range<usize>, ContentKind)> = None;
- for capture in query_match.captures {
- let capture_range = capture.node.byte_range();
-
- if capture.index == *import_ix {
- new_import_range = Some(capture_range);
- } else if Some(capture.index) == *namespace_ix {
- modules.push(ModuleRange::Namespace(capture_range));
- } else if Some(capture.index) == *source_ix {
- modules.push(ModuleRange::Source(capture_range));
- } else if Some(capture.index) == *alias_ix {
- alias_range = Some(capture_range);
- } else {
- let mut found_content = None;
- if Some(capture.index) == *name_ix {
- found_content = Some((capture_range, ContentKind::Name));
- } else if Some(capture.index) == *list_ix {
- found_content = Some((capture_range, ContentKind::List));
- } else if Some(capture.index) == *wildcard_ix {
- found_content = Some((capture_range, ContentKind::Wildcard));
- }
- if let Some((found_content_range, found_kind)) = found_content {
- if let Some((_, old_kind)) = content {
- let point = found_content_range.to_point(snapshot);
- log::warn!(
- "bug in {} imports query: unexpected multiple captures of {} and {} ({}:{}:{})",
- query_match.language.name(),
- old_kind.capture_name(),
- found_kind.capture_name(),
- snapshot
- .file()
- .map(|p| p.path().display(PathStyle::Posix))
- .unwrap_or_default(),
- point.start.row + 1,
- point.start.column + 1
- );
- }
- content = Some((found_content_range, found_kind));
- }
- }
- }
-
- if let Some(new_import_range) = new_import_range {
- log::trace!("starting new import {:?}", new_import_range);
- Self::gather_from_import_statement(
- &detached_nodes,
- &snapshot,
- parent_abs_path,
- &mut identifier_to_imports,
- &mut wildcard_modules,
- );
- detached_nodes.clear();
- import_range = Some(new_import_range.clone());
- }
-
- if let Some((content, content_kind)) = content {
- if import_range
- .as_ref()
- .is_some_and(|import_range| import_range.contains_inclusive(&content))
- {
- detached_nodes.push(DetachedNode {
- modules,
- content: content.clone(),
- content_kind,
- alias: alias_range.unwrap_or(0..0),
- language: query_match.language.clone(),
- });
- } else {
- log::trace!(
- "filtered out match not inside import range: {content_kind:?} at {content:?}"
- );
- }
- }
-
- matches.advance();
- }
-
- Self::gather_from_import_statement(
- &detached_nodes,
- &snapshot,
- parent_abs_path,
- &mut identifier_to_imports,
- &mut wildcard_modules,
- );
-
- Imports {
- identifier_to_imports,
- wildcard_modules,
- }
- }
-
- fn gather_from_import_statement(
- detached_nodes: &[DetachedNode],
- snapshot: &BufferSnapshot,
- parent_abs_path: Option<&Path>,
- identifier_to_imports: &mut HashMap<Identifier, Vec<Import>>,
- wildcard_modules: &mut Vec<Module>,
- ) {
- let mut trees = Vec::new();
-
- for detached_node in detached_nodes {
- if let Some(node) = Self::attach_node(detached_node.into(), &mut trees) {
- trees.push(node);
- }
- log::trace!(
- "Attached node to tree\n{:#?}\nAttach result:\n{:#?}",
- detached_node,
- trees
- .iter()
- .map(|tree| tree.debug(snapshot))
- .collect::<Vec<_>>()
- );
- }
-
- for tree in &trees {
- let mut module = Module::empty();
- Self::gather_from_tree(
- tree,
- snapshot,
- parent_abs_path,
- &mut module,
- identifier_to_imports,
- wildcard_modules,
- );
- }
- }
-
- fn attach_node(mut node: ImportTree, trees: &mut Vec<ImportTree>) -> Option<ImportTree> {
- let mut tree_index = 0;
- while tree_index < trees.len() {
- let tree = &mut trees[tree_index];
- if !node.content.is_empty() && node.content == tree.content {
- // multiple matches can apply to the same name/list/wildcard. This keeps the queries
- // simpler by combining info from these matches.
- if tree.module.is_empty() {
- tree.module = node.module;
- tree.module_children = node.module_children;
- }
- if tree.alias.is_empty() {
- tree.alias = node.alias;
- }
- return None;
- } else if !node.module.is_empty() && node.module.contains_inclusive(&tree.range()) {
- node.module_children.push(trees.remove(tree_index));
- continue;
- } else if !node.content.is_empty() && node.content.contains_inclusive(&tree.content) {
- node.content_children.push(trees.remove(tree_index));
- continue;
- } else if !tree.content.is_empty() && tree.content.contains_inclusive(&node.content) {
- if let Some(node) = Self::attach_node(node, &mut tree.content_children) {
- tree.content_children.push(node);
- }
- return None;
- }
- tree_index += 1;
- }
- Some(node)
- }
-
- fn gather_from_tree(
- tree: &ImportTree,
- snapshot: &BufferSnapshot,
- parent_abs_path: Option<&Path>,
- current_module: &mut Module,
- identifier_to_imports: &mut HashMap<Identifier, Vec<Import>>,
- wildcard_modules: &mut Vec<Module>,
- ) {
- let mut pop_count = 0;
-
- if tree.module_children.is_empty() {
- pop_count +=
- current_module.push_range(&tree.module, snapshot, &tree.language, parent_abs_path);
- } else {
- for child in &tree.module_children {
- pop_count += Self::extend_namespace_from_tree(
- child,
- snapshot,
- parent_abs_path,
- current_module,
- );
- }
- };
-
- if tree.content_children.is_empty() && !tree.content.is_empty() {
- match tree.content_kind {
- ContentKind::Name | ContentKind::List => {
- if tree.alias.is_empty() {
- identifier_to_imports
- .entry(Identifier {
- language_id: tree.language.id(),
- name: range_text(snapshot, &tree.content),
- })
- .or_default()
- .push(Import::Direct {
- module: current_module.clone(),
- });
- } else {
- let alias_name: Arc<str> = range_text(snapshot, &tree.alias);
- let external_name = range_text(snapshot, &tree.content);
- // Language specific: skip "_" aliases for Rust
- if alias_name.as_ref() != "_" {
- identifier_to_imports
- .entry(Identifier {
- language_id: tree.language.id(),
- name: alias_name,
- })
- .or_default()
- .push(Import::Alias {
- module: current_module.clone(),
- external_identifier: Identifier {
- language_id: tree.language.id(),
- name: external_name,
- },
- });
- }
- }
- }
- ContentKind::Wildcard => wildcard_modules.push(current_module.clone()),
- }
- } else {
- for child in &tree.content_children {
- Self::gather_from_tree(
- child,
- snapshot,
- parent_abs_path,
- current_module,
- identifier_to_imports,
- wildcard_modules,
- );
- }
- }
-
- if pop_count > 0 {
- match current_module {
- Module::SourceExact(_) | Module::SourceFuzzy(_) => {
- log::warn!(
- "bug in imports query: encountered both @namespace and @source match"
- );
- }
- Module::Namespace(namespace) => {
- namespace.0.drain(namespace.0.len() - pop_count..);
- }
- }
- }
- }
-
- fn extend_namespace_from_tree(
- tree: &ImportTree,
- snapshot: &BufferSnapshot,
- parent_abs_path: Option<&Path>,
- module: &mut Module,
- ) -> usize {
- let mut pop_count = 0;
- if tree.module_children.is_empty() {
- pop_count += module.push_range(&tree.module, snapshot, &tree.language, parent_abs_path);
- } else {
- for child in &tree.module_children {
- pop_count +=
- Self::extend_namespace_from_tree(child, snapshot, parent_abs_path, module);
- }
- }
- if tree.content_children.is_empty() {
- pop_count += module.push_range(
- &ModuleRange::Namespace(tree.content.clone()),
- snapshot,
- &tree.language,
- parent_abs_path,
- );
- } else {
- for child in &tree.content_children {
- pop_count +=
- Self::extend_namespace_from_tree(child, snapshot, parent_abs_path, module);
- }
- }
- pop_count
- }
-}
-
-fn range_text(snapshot: &BufferSnapshot, range: &Range<usize>) -> Arc<str> {
- snapshot
- .text_for_range(range.clone())
- .collect::<Cow<str>>()
- .into()
-}
-
-#[derive(Debug)]
-struct DetachedNode {
- modules: Vec<ModuleRange>,
- content: Range<usize>,
- content_kind: ContentKind,
- alias: Range<usize>,
- language: Arc<Language>,
-}
-
-#[derive(Debug, Clone, Copy)]
-enum ContentKind {
- Name,
- Wildcard,
- List,
-}
-
-impl ContentKind {
- fn capture_name(&self) -> &'static str {
- match self {
- ContentKind::Name => "name",
- ContentKind::Wildcard => "wildcard",
- ContentKind::List => "list",
- }
- }
-}
-
-#[derive(Debug)]
-struct ImportTree {
- module: ModuleRange,
- /// When non-empty, provides namespace / source info which should be used instead of `module`.
- module_children: Vec<ImportTree>,
- content: Range<usize>,
- /// When non-empty, provides content which should be used instead of `content`.
- content_children: Vec<ImportTree>,
- content_kind: ContentKind,
- alias: Range<usize>,
- language: Arc<Language>,
-}
-
-impl ImportTree {
- fn range(&self) -> Range<usize> {
- self.module.start.min(self.content.start)..self.module.end.max(self.content.end)
- }
-
- #[allow(dead_code)]
- fn debug<'a>(&'a self, snapshot: &'a BufferSnapshot) -> ImportTreeDebug<'a> {
- ImportTreeDebug {
- tree: self,
- snapshot,
- }
- }
-
- fn from_module_range(module: &ModuleRange, language: Arc<Language>) -> Self {
- ImportTree {
- module: module.clone(),
- module_children: Vec::new(),
- content: 0..0,
- content_children: Vec::new(),
- content_kind: ContentKind::Name,
- alias: 0..0,
- language,
- }
- }
-}
-
-impl From<&DetachedNode> for ImportTree {
- fn from(value: &DetachedNode) -> Self {
- let module;
- let module_children;
- match value.modules.len() {
- 0 => {
- module = ModuleRange::Namespace(0..0);
- module_children = Vec::new();
- }
- 1 => {
- module = value.modules[0].clone();
- module_children = Vec::new();
- }
- _ => {
- module = ModuleRange::Namespace(
- value.modules.first().unwrap().start..value.modules.last().unwrap().end,
- );
- module_children = value
- .modules
- .iter()
- .map(|module| ImportTree::from_module_range(module, value.language.clone()))
- .collect();
- }
- }
-
- ImportTree {
- module,
- module_children,
- content: value.content.clone(),
- content_children: Vec::new(),
- content_kind: value.content_kind,
- alias: value.alias.clone(),
- language: value.language.clone(),
- }
- }
-}
-
-struct ImportTreeDebug<'a> {
- tree: &'a ImportTree,
- snapshot: &'a BufferSnapshot,
-}
-
-impl std::fmt::Debug for ImportTreeDebug<'_> {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.debug_struct("ImportTree")
- .field("module_range", &self.tree.module)
- .field("module_text", &range_text(self.snapshot, &self.tree.module))
- .field(
- "module_children",
- &self
- .tree
- .module_children
- .iter()
- .map(|child| child.debug(&self.snapshot))
- .collect::<Vec<Self>>(),
- )
- .field("content_range", &self.tree.content)
- .field(
- "content_text",
- &range_text(self.snapshot, &self.tree.content),
- )
- .field(
- "content_children",
- &self
- .tree
- .content_children
- .iter()
- .map(|child| child.debug(&self.snapshot))
- .collect::<Vec<Self>>(),
- )
- .field("content_kind", &self.tree.content_kind)
- .field("alias_range", &self.tree.alias)
- .field("alias_text", &range_text(self.snapshot, &self.tree.alias))
- .finish()
- }
-}
-
-#[cfg(test)]
-mod test {
- use std::path::PathBuf;
- use std::sync::{Arc, LazyLock};
-
- use super::*;
- use collections::HashSet;
- use gpui::{TestAppContext, prelude::*};
- use indoc::indoc;
- use language::{
- Buffer, Language, LanguageConfig, tree_sitter_python, tree_sitter_rust,
- tree_sitter_typescript,
- };
- use regex::Regex;
-
- #[gpui::test]
- fn test_rust_simple(cx: &mut TestAppContext) {
- check_imports(
- &RUST,
- "use std::collections::HashMap;",
- &[&["std", "collections", "HashMap"]],
- cx,
- );
-
- check_imports(
- &RUST,
- "pub use std::collections::HashMap;",
- &[&["std", "collections", "HashMap"]],
- cx,
- );
-
- check_imports(
- &RUST,
- "use std::collections::{HashMap, HashSet};",
- &[
- &["std", "collections", "HashMap"],
- &["std", "collections", "HashSet"],
- ],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_rust_nested(cx: &mut TestAppContext) {
- check_imports(
- &RUST,
- "use std::{any::TypeId, collections::{HashMap, HashSet}};",
- &[
- &["std", "any", "TypeId"],
- &["std", "collections", "HashMap"],
- &["std", "collections", "HashSet"],
- ],
- cx,
- );
-
- check_imports(
- &RUST,
- "use a::b::c::{d::e::F, g::h::I};",
- &[
- &["a", "b", "c", "d", "e", "F"],
- &["a", "b", "c", "g", "h", "I"],
- ],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_rust_multiple_imports(cx: &mut TestAppContext) {
- check_imports(
- &RUST,
- indoc! {"
- use std::collections::HashMap;
- use std::any::{TypeId, Any};
- "},
- &[
- &["std", "collections", "HashMap"],
- &["std", "any", "TypeId"],
- &["std", "any", "Any"],
- ],
- cx,
- );
-
- check_imports(
- &RUST,
- indoc! {"
- use std::collections::HashSet;
-
- fn main() {
- let unqualified = HashSet::new();
- let qualified = std::collections::HashMap::new();
- }
-
- use std::any::TypeId;
- "},
- &[
- &["std", "collections", "HashSet"],
- &["std", "any", "TypeId"],
- ],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_rust_wildcard(cx: &mut TestAppContext) {
- check_imports(&RUST, "use prelude::*;", &[&["prelude", "WILDCARD"]], cx);
-
- check_imports(
- &RUST,
- "use zed::prelude::*;",
- &[&["zed", "prelude", "WILDCARD"]],
- cx,
- );
-
- check_imports(&RUST, "use prelude::{*};", &[&["prelude", "WILDCARD"]], cx);
-
- check_imports(
- &RUST,
- "use prelude::{File, *};",
- &[&["prelude", "File"], &["prelude", "WILDCARD"]],
- cx,
- );
-
- check_imports(
- &RUST,
- "use zed::{App, prelude::*};",
- &[&["zed", "App"], &["zed", "prelude", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_rust_alias(cx: &mut TestAppContext) {
- check_imports(
- &RUST,
- "use std::io::Result as IoResult;",
- &[&["std", "io", "Result AS IoResult"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_rust_crate_and_super(cx: &mut TestAppContext) {
- check_imports(&RUST, "use crate::a::b::c;", &[&["a", "b", "c"]], cx);
- check_imports(&RUST, "use super::a::b::c;", &[&["a", "b", "c"]], cx);
- // TODO: Consider stripping leading "::". Not done for now because for the text similarity matching usecase this
- // is fine.
- check_imports(&RUST, "use ::a::b::c;", &[&["::a", "b", "c"]], cx);
- }
-
- #[gpui::test]
- fn test_typescript_imports(cx: &mut TestAppContext) {
- let parent_abs_path = PathBuf::from("/home/user/project");
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import "./maths.js";"#,
- &[&["SOURCE /home/user/project/maths", "WILDCARD"]],
- cx,
- );
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import "../maths.js";"#,
- &[&["SOURCE /home/user/maths", "WILDCARD"]],
- cx,
- );
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import RandomNumberGenerator, { pi as π } from "./maths.js";"#,
- &[
- &["SOURCE /home/user/project/maths", "RandomNumberGenerator"],
- &["SOURCE /home/user/project/maths", "pi AS π"],
- ],
- cx,
- );
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import { pi, phi, absolute } from "./maths.js";"#,
- &[
- &["SOURCE /home/user/project/maths", "pi"],
- &["SOURCE /home/user/project/maths", "phi"],
- &["SOURCE /home/user/project/maths", "absolute"],
- ],
- cx,
- );
-
- // index.js is removed by import_path_strip_regex
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import { pi, phi, absolute } from "./maths/index.js";"#,
- &[
- &["SOURCE /home/user/project/maths", "pi"],
- &["SOURCE /home/user/project/maths", "phi"],
- &["SOURCE /home/user/project/maths", "absolute"],
- ],
- cx,
- );
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import type { SomeThing } from "./some-module.js";"#,
- &[&["SOURCE /home/user/project/some-module", "SomeThing"]],
- cx,
- );
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import { type SomeThing, OtherThing } from "./some-module.js";"#,
- &[
- &["SOURCE /home/user/project/some-module", "SomeThing"],
- &["SOURCE /home/user/project/some-module", "OtherThing"],
- ],
- cx,
- );
-
- // index.js is removed by import_path_strip_regex
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import { type SomeThing, OtherThing } from "./some-module/index.js";"#,
- &[
- &["SOURCE /home/user/project/some-module", "SomeThing"],
- &["SOURCE /home/user/project/some-module", "OtherThing"],
- ],
- cx,
- );
-
- // fuzzy paths
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import { type SomeThing, OtherThing } from "@my-app/some-module.js";"#,
- &[
- &["SOURCE FUZZY @my-app/some-module", "SomeThing"],
- &["SOURCE FUZZY @my-app/some-module", "OtherThing"],
- ],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_typescript_named_module_imports(cx: &mut TestAppContext) {
- let parent_abs_path = PathBuf::from("/home/user/project");
-
- // TODO: These should provide the name that the module is bound to.
- // For now instead these are treated as unqualified wildcard imports.
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import * as math from "./maths.js";"#,
- // &[&["/home/user/project/maths.js", "WILDCARD AS math"]],
- &[&["SOURCE /home/user/project/maths", "WILDCARD"]],
- cx,
- );
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import math = require("./maths");"#,
- // &[&["/home/user/project/maths", "WILDCARD AS math"]],
- &[&["SOURCE /home/user/project/maths", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_python_imports(cx: &mut TestAppContext) {
- check_imports(&PYTHON, "from math import pi", &[&["math", "pi"]], cx);
-
- check_imports(
- &PYTHON,
- "from math import pi, sin, cos",
- &[&["math", "pi"], &["math", "sin"], &["math", "cos"]],
- cx,
- );
-
- check_imports(&PYTHON, "from math import *", &[&["math", "WILDCARD"]], cx);
-
- check_imports(
- &PYTHON,
- "from math import foo.bar.baz",
- &[&["math", "foo", "bar", "baz"]],
- cx,
- );
-
- check_imports(
- &PYTHON,
- "from math import pi as PI",
- &[&["math", "pi AS PI"]],
- cx,
- );
-
- check_imports(
- &PYTHON,
- "from serializers.json import JsonSerializer",
- &[&["serializers", "json", "JsonSerializer"]],
- cx,
- );
-
- check_imports(
- &PYTHON,
- "from custom.serializers import json, xml, yaml",
- &[
- &["custom", "serializers", "json"],
- &["custom", "serializers", "xml"],
- &["custom", "serializers", "yaml"],
- ],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_python_named_module_imports(cx: &mut TestAppContext) {
- // TODO: These should provide the name that the module is bound to.
- // For now instead these are treated as unqualified wildcard imports.
- //
- // check_imports(&PYTHON, "import math", &[&["math", "WILDCARD as math"]], cx);
- // check_imports(&PYTHON, "import math as maths", &[&["math", "WILDCARD AS maths"]], cx);
- //
- // Something like:
- //
- // (import_statement
- // name: [
- // (dotted_name
- // (identifier)* @namespace
- // (identifier) @name.module .)
- // (aliased_import
- // name: (dotted_name
- // ((identifier) ".")* @namespace
- // (identifier) @name.module .)
- // alias: (identifier) @alias)
- // ]) @import
-
- check_imports(&PYTHON, "import math", &[&["math", "WILDCARD"]], cx);
-
- check_imports(
- &PYTHON,
- "import math as maths",
- &[&["math", "WILDCARD"]],
- cx,
- );
-
- check_imports(&PYTHON, "import a.b.c", &[&["a", "b", "c", "WILDCARD"]], cx);
-
- check_imports(
- &PYTHON,
- "import a.b.c as d",
- &[&["a", "b", "c", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_python_package_relative_imports(cx: &mut TestAppContext) {
- // TODO: These should provide info about the dir they are relative to, to provide more
- // precise resolution. Instead, fuzzy matching is used as usual.
-
- check_imports(&PYTHON, "from . import math", &[&["math"]], cx);
-
- check_imports(&PYTHON, "from .a import math", &[&["a", "math"]], cx);
-
- check_imports(
- &PYTHON,
- "from ..a.b import math",
- &[&["a", "b", "math"]],
- cx,
- );
-
- check_imports(
- &PYTHON,
- "from ..a.b import *",
- &[&["a", "b", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_c_imports(cx: &mut TestAppContext) {
- let parent_abs_path = PathBuf::from("/home/user/project");
-
- // TODO: Distinguish that these are not relative to current path
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &C,
- r#"#include <math.h>"#,
- &[&["SOURCE FUZZY math.h", "WILDCARD"]],
- cx,
- );
-
- // TODO: These should be treated as relative, but don't start with ./ or ../
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &C,
- r#"#include "math.h""#,
- &[&["SOURCE FUZZY math.h", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_cpp_imports(cx: &mut TestAppContext) {
- let parent_abs_path = PathBuf::from("/home/user/project");
-
- // TODO: Distinguish that these are not relative to current path
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &CPP,
- r#"#include <math.h>"#,
- &[&["SOURCE FUZZY math.h", "WILDCARD"]],
- cx,
- );
-
- // TODO: These should be treated as relative, but don't start with ./ or ../
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &CPP,
- r#"#include "math.h""#,
- &[&["SOURCE FUZZY math.h", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_go_imports(cx: &mut TestAppContext) {
- check_imports(
- &GO,
- r#"import . "lib/math""#,
- &[&["lib/math", "WILDCARD"]],
- cx,
- );
-
- // not included, these are only for side-effects
- check_imports(&GO, r#"import _ "lib/math""#, &[], cx);
- }
-
- #[gpui::test]
- fn test_go_named_module_imports(cx: &mut TestAppContext) {
- // TODO: These should provide the name that the module is bound to.
- // For now instead these are treated as unqualified wildcard imports.
-
- check_imports(
- &GO,
- r#"import "lib/math""#,
- &[&["lib/math", "WILDCARD"]],
- cx,
- );
- check_imports(
- &GO,
- r#"import m "lib/math""#,
- &[&["lib/math", "WILDCARD"]],
- cx,
- );
- }
-
- #[track_caller]
- fn check_imports(
- language: &Arc<Language>,
- source: &str,
- expected: &[&[&str]],
- cx: &mut TestAppContext,
- ) {
- check_imports_with_file_abs_path(None, language, source, expected, cx);
- }
-
- #[track_caller]
- fn check_imports_with_file_abs_path(
- parent_abs_path: Option<&Path>,
- language: &Arc<Language>,
- source: &str,
- expected: &[&[&str]],
- cx: &mut TestAppContext,
- ) {
- let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local(source, cx);
- buffer.set_language(Some(language.clone()), cx);
- buffer
- });
- cx.run_until_parked();
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
-
- let imports = Imports::gather(&snapshot, parent_abs_path);
- let mut actual_symbols = imports
- .identifier_to_imports
- .iter()
- .flat_map(|(identifier, imports)| {
- imports
- .iter()
- .map(|import| import.to_identifier_parts(identifier.name.as_ref()))
- })
- .chain(
- imports
- .wildcard_modules
- .iter()
- .map(|module| module.to_identifier_parts("WILDCARD")),
- )
- .collect::<Vec<_>>();
- let mut expected_symbols = expected
- .iter()
- .map(|expected| expected.iter().map(|s| s.to_string()).collect::<Vec<_>>())
- .collect::<Vec<_>>();
- actual_symbols.sort();
- expected_symbols.sort();
- if actual_symbols != expected_symbols {
- let top_layer = snapshot.syntax_layers().next().unwrap();
- panic!(
- "Expected imports: {:?}\n\
- Actual imports: {:?}\n\
- Tree:\n{}",
- expected_symbols,
- actual_symbols,
- tree_to_string(&top_layer.node()),
- );
- }
- }
-
- fn tree_to_string(node: &tree_sitter::Node) -> String {
- let mut cursor = node.walk();
- let mut result = String::new();
- let mut depth = 0;
- 'outer: loop {
- result.push_str(&" ".repeat(depth));
- if let Some(field_name) = cursor.field_name() {
- result.push_str(field_name);
- result.push_str(": ");
- }
- if cursor.node().is_named() {
- result.push_str(cursor.node().kind());
- } else {
- result.push('"');
- result.push_str(cursor.node().kind());
- result.push('"');
- }
- result.push('\n');
-
- if cursor.goto_first_child() {
- depth += 1;
- continue;
- }
- if cursor.goto_next_sibling() {
- continue;
- }
- while cursor.goto_parent() {
- depth -= 1;
- if cursor.goto_next_sibling() {
- continue 'outer;
- }
- }
- break;
- }
- result
- }
-
- static RUST: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- ignored_import_segments: HashSet::from_iter(["crate".into(), "super".into()]),
- import_path_strip_regex: Some(Regex::new("/(lib|mod)\\.rs$").unwrap()),
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_imports_query(include_str!("../../languages/src/rust/imports.scm"))
- .unwrap(),
- )
- });
-
- static TYPESCRIPT: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "TypeScript".into(),
- import_path_strip_regex: Some(Regex::new("(?:/index)?\\.[jt]s$").unwrap()),
- ..Default::default()
- },
- Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
- )
- .with_imports_query(include_str!("../../languages/src/typescript/imports.scm"))
- .unwrap(),
- )
- });
-
- static PYTHON: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Python".into(),
- import_path_strip_regex: Some(Regex::new("/__init__\\.py$").unwrap()),
- ..Default::default()
- },
- Some(tree_sitter_python::LANGUAGE.into()),
- )
- .with_imports_query(include_str!("../../languages/src/python/imports.scm"))
- .unwrap(),
- )
- });
-
- // TODO: Ideally should use actual language configurations
- static C: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "C".into(),
- import_path_strip_regex: Some(Regex::new("^<|>$").unwrap()),
- ..Default::default()
- },
- Some(tree_sitter_c::LANGUAGE.into()),
- )
- .with_imports_query(include_str!("../../languages/src/c/imports.scm"))
- .unwrap(),
- )
- });
-
- static CPP: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "C++".into(),
- import_path_strip_regex: Some(Regex::new("^<|>$").unwrap()),
- ..Default::default()
- },
- Some(tree_sitter_cpp::LANGUAGE.into()),
- )
- .with_imports_query(include_str!("../../languages/src/cpp/imports.scm"))
- .unwrap(),
- )
- });
-
- static GO: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Go".into(),
- ..Default::default()
- },
- Some(tree_sitter_go::LANGUAGE.into()),
- )
- .with_imports_query(include_str!("../../languages/src/go/imports.scm"))
- .unwrap(),
- )
- });
-
- impl Import {
- fn to_identifier_parts(&self, identifier: &str) -> Vec<String> {
- match self {
- Import::Direct { module } => module.to_identifier_parts(identifier),
- Import::Alias {
- module,
- external_identifier: external_name,
- } => {
- module.to_identifier_parts(&format!("{} AS {}", external_name.name, identifier))
- }
- }
- }
- }
-
- impl Module {
- fn to_identifier_parts(&self, identifier: &str) -> Vec<String> {
- match self {
- Self::Namespace(namespace) => namespace.to_identifier_parts(identifier),
- Self::SourceExact(path) => {
- vec![
- format!("SOURCE {}", path.display().to_string().replace("\\", "/")),
- identifier.to_string(),
- ]
- }
- Self::SourceFuzzy(path) => {
- vec![
- format!(
- "SOURCE FUZZY {}",
- path.display().to_string().replace("\\", "/")
- ),
- identifier.to_string(),
- ]
- }
- }
- }
- }
-
- impl Namespace {
- fn to_identifier_parts(&self, identifier: &str) -> Vec<String> {
- self.0
- .iter()
- .map(|chunk| chunk.to_string())
- .chain(std::iter::once(identifier.to_string()))
- .collect::<Vec<_>>()
- }
- }
-}
@@ -1,126 +0,0 @@
-use language::{BufferSnapshot, SyntaxMapMatches};
-use std::{cmp::Reverse, ops::Range};
-
-use crate::declaration::Identifier;
-
-// TODO:
-//
-// * how to handle multiple name captures? for now last one wins
-//
-// * annotation ranges
-//
-// * new "signature" capture for outline queries
-//
-// * Check parent behavior of "int x, y = 0" declarations in a test
-
-pub struct OutlineDeclaration {
- pub parent_index: Option<usize>,
- pub identifier: Identifier,
- pub item_range: Range<usize>,
- pub signature_range: Range<usize>,
-}
-
-pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec<OutlineDeclaration> {
- declarations_overlapping_range(0..buffer.len(), buffer)
-}
-
-pub fn declarations_overlapping_range(
- range: Range<usize>,
- buffer: &BufferSnapshot,
-) -> Vec<OutlineDeclaration> {
- let mut declarations = OutlineIterator::new(range, buffer).collect::<Vec<_>>();
- declarations.sort_unstable_by_key(|item| (item.item_range.start, Reverse(item.item_range.end)));
-
- let mut parent_stack: Vec<(usize, Range<usize>)> = Vec::new();
- for (index, declaration) in declarations.iter_mut().enumerate() {
- while let Some((top_parent_index, top_parent_range)) = parent_stack.last() {
- if declaration.item_range.start >= top_parent_range.end {
- parent_stack.pop();
- } else {
- declaration.parent_index = Some(*top_parent_index);
- break;
- }
- }
- parent_stack.push((index, declaration.item_range.clone()));
- }
- declarations
-}
-
-/// Iterates outline items without being ordered w.r.t. nested items and without populating
-/// `parent`.
-pub struct OutlineIterator<'a> {
- buffer: &'a BufferSnapshot,
- matches: SyntaxMapMatches<'a>,
-}
-
-impl<'a> OutlineIterator<'a> {
- pub fn new(range: Range<usize>, buffer: &'a BufferSnapshot) -> Self {
- let matches = buffer.syntax.matches(range, &buffer.text, |grammar| {
- grammar.outline_config.as_ref().map(|c| &c.query)
- });
-
- Self { buffer, matches }
- }
-}
-
-impl<'a> Iterator for OutlineIterator<'a> {
- type Item = OutlineDeclaration;
-
- fn next(&mut self) -> Option<Self::Item> {
- while let Some(mat) = self.matches.peek() {
- let config = self.matches.grammars()[mat.grammar_index]
- .outline_config
- .as_ref()
- .unwrap();
-
- let mut name_range = None;
- let mut item_range = None;
- let mut signature_start = None;
- let mut signature_end = None;
-
- let mut add_to_signature = |range: Range<usize>| {
- if signature_start.is_none() {
- signature_start = Some(range.start);
- }
- signature_end = Some(range.end);
- };
-
- for capture in mat.captures {
- let range = capture.node.byte_range();
- if capture.index == config.name_capture_ix {
- name_range = Some(range.clone());
- add_to_signature(range);
- } else if Some(capture.index) == config.context_capture_ix
- || Some(capture.index) == config.extra_context_capture_ix
- {
- add_to_signature(range);
- } else if capture.index == config.item_capture_ix {
- item_range = Some(range.clone());
- }
- }
-
- let language_id = mat.language.id();
- self.matches.advance();
-
- if let Some(name_range) = name_range
- && let Some(item_range) = item_range
- && let Some(signature_start) = signature_start
- && let Some(signature_end) = signature_end
- {
- let name = self
- .buffer
- .text_for_range(name_range)
- .collect::<String>()
- .into();
-
- return Some(OutlineDeclaration {
- identifier: Identifier { name, language_id },
- item_range: item_range,
- signature_range: signature_start..signature_end,
- parent_index: None,
- });
- }
- }
- None
- }
-}
@@ -1,173 +0,0 @@
-use collections::HashMap;
-use language::BufferSnapshot;
-use std::ops::Range;
-use util::RangeExt;
-
-use crate::{
- declaration::Identifier,
- excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
-};
-
-#[derive(Debug, Clone)]
-pub struct Reference {
- pub identifier: Identifier,
- pub range: Range<usize>,
- pub region: ReferenceRegion,
-}
-
-#[derive(Copy, Clone, Debug, Eq, PartialEq)]
-pub enum ReferenceRegion {
- Breadcrumb,
- Nearby,
-}
-
-pub fn references_in_excerpt(
- excerpt: &EditPredictionExcerpt,
- excerpt_text: &EditPredictionExcerptText,
- snapshot: &BufferSnapshot,
-) -> HashMap<Identifier, Vec<Reference>> {
- let mut references = references_in_range(
- excerpt.range.clone(),
- excerpt_text.body.as_str(),
- ReferenceRegion::Nearby,
- snapshot,
- );
-
- for ((_, range), text) in excerpt
- .parent_declarations
- .iter()
- .zip(excerpt_text.parent_signatures.iter())
- {
- references.extend(references_in_range(
- range.clone(),
- text.as_str(),
- ReferenceRegion::Breadcrumb,
- snapshot,
- ));
- }
-
- let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::default();
- for reference in references {
- identifier_to_references
- .entry(reference.identifier.clone())
- .or_insert_with(Vec::new)
- .push(reference);
- }
- identifier_to_references
-}
-
-/// Finds all nodes which have a "variable" match from the highlights query within the offset range.
-pub fn references_in_range(
- range: Range<usize>,
- range_text: &str,
- reference_region: ReferenceRegion,
- buffer: &BufferSnapshot,
-) -> Vec<Reference> {
- let mut matches = buffer
- .syntax
- .matches(range.clone(), &buffer.text, |grammar| {
- grammar
- .highlights_config
- .as_ref()
- .map(|config| &config.query)
- });
-
- let mut references = Vec::new();
- let mut last_added_range = None;
- while let Some(mat) = matches.peek() {
- let config = matches.grammars()[mat.grammar_index]
- .highlights_config
- .as_ref();
-
- if let Some(config) = config {
- for capture in mat.captures {
- if config.identifier_capture_indices.contains(&capture.index) {
- let node_range = capture.node.byte_range();
-
- // sometimes multiple highlight queries match - this deduplicates them
- if Some(node_range.clone()) == last_added_range {
- continue;
- }
-
- if !range.contains_inclusive(&node_range) {
- continue;
- }
-
- let identifier_text =
- &range_text[node_range.start - range.start..node_range.end - range.start];
-
- references.push(Reference {
- identifier: Identifier {
- name: identifier_text.into(),
- language_id: mat.language.id(),
- },
- range: node_range.clone(),
- region: reference_region,
- });
- last_added_range = Some(node_range);
- }
- }
- }
-
- matches.advance();
- }
- references
-}
-
-#[cfg(test)]
-mod test {
- use gpui::{TestAppContext, prelude::*};
- use indoc::indoc;
- use language::{BufferSnapshot, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
-
- use crate::reference::{ReferenceRegion, references_in_range};
-
- #[gpui::test]
- fn test_identifier_node_truncated(cx: &mut TestAppContext) {
- let code = indoc! { r#"
- fn main() {
- add(1, 2);
- }
-
- fn add(a: i32, b: i32) -> i32 {
- a + b
- }
- "# };
- let buffer = create_buffer(code, cx);
-
- let range = 0..35;
- let references = references_in_range(
- range.clone(),
- &code[range],
- ReferenceRegion::Breadcrumb,
- &buffer,
- );
- assert_eq!(references.len(), 2);
- assert_eq!(references[0].identifier.name.as_ref(), "main");
- assert_eq!(references[1].identifier.name.as_ref(), "add");
- }
-
- fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
- let buffer =
- cx.new(|cx| language::Buffer::local(text, cx).with_language(rust_lang().into(), cx));
- buffer.read_with(cx, |buffer, _| buffer.snapshot())
- }
-
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
- .unwrap()
- .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
- .unwrap()
- }
-}
@@ -1,1069 +0,0 @@
-use anyhow::{Result, anyhow};
-use collections::{HashMap, HashSet};
-use futures::channel::mpsc;
-use futures::lock::Mutex;
-use futures::{FutureExt as _, StreamExt, future};
-use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
-use itertools::Itertools;
-
-use language::{Buffer, BufferEvent};
-use postage::stream::Stream as _;
-use project::buffer_store::{BufferStore, BufferStoreEvent};
-use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
-use project::{PathChange, Project, ProjectEntryId, ProjectPath};
-use slotmap::SlotMap;
-use std::iter;
-use std::ops::{DerefMut, Range};
-use std::sync::Arc;
-use text::BufferId;
-use util::{RangeExt as _, debug_panic, some_or_debug_panic};
-
-use crate::CachedDeclarationPath;
-use crate::declaration::{
- BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
-};
-use crate::outline::declarations_in_buffer;
-
-// TODO
-//
-// * Also queue / debounce buffer changes. A challenge for this is that use of
-// `buffer_declarations_containing_range` assumes that the index is always immediately up to date.
-//
-// * Add a per language configuration for skipping indexing.
-//
-// * Handle tsx / ts / js referencing each-other
-
-// Potential future improvements:
-//
-// * Prevent indexing of a large file from blocking the queue.
-//
-// * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which
-// references are present and their scores.
-//
-// * Include single-file worktrees / non visible worktrees? E.g. go to definition that resolves to a
-// file in a build dependency. Should not be editable in that case - but how to distinguish the case
-// where it should be editable?
-
-// Potential future optimizations:
-//
-// * Index files on multiple threads in Zed (currently only parallel for the CLI). Adding some kind
-// of priority system to the background executor could help - it's single threaded for now to avoid
-// interfering with other work.
-//
-// * Parse files directly instead of loading into a Rope.
-//
-// - This would allow the task handling dirty_files to be done entirely on the background executor.
-//
-// - Make SyntaxMap generic to handle embedded languages? Will also need to find line boundaries,
-// but that can be done by scanning characters in the flat representation.
-//
-// * Use something similar to slotmap without key versions.
-//
-// * Concurrent slotmap
-
-pub struct SyntaxIndex {
- state: Arc<Mutex<SyntaxIndexState>>,
- project: WeakEntity<Project>,
- initial_file_indexing_done_rx: postage::watch::Receiver<bool>,
- _file_indexing_task: Option<Task<()>>,
-}
-
-pub struct SyntaxIndexState {
- declarations: SlotMap<DeclarationId, Declaration>,
- identifiers: HashMap<Identifier, HashSet<DeclarationId>>,
- files: HashMap<ProjectEntryId, FileState>,
- buffers: HashMap<BufferId, BufferState>,
- dirty_files: HashMap<ProjectEntryId, ProjectPath>,
- dirty_files_tx: mpsc::Sender<()>,
-}
-
-#[derive(Debug, Default)]
-struct FileState {
- declarations: Vec<DeclarationId>,
-}
-
-#[derive(Default)]
-struct BufferState {
- declarations: Vec<DeclarationId>,
- task: Option<Task<()>>,
-}
-
-impl SyntaxIndex {
- pub fn new(
- project: &Entity<Project>,
- file_indexing_parallelism: usize,
- cx: &mut Context<Self>,
- ) -> Self {
- assert!(file_indexing_parallelism > 0);
- let (dirty_files_tx, mut dirty_files_rx) = mpsc::channel::<()>(1);
- let (mut initial_file_indexing_done_tx, initial_file_indexing_done_rx) =
- postage::watch::channel();
-
- let initial_state = SyntaxIndexState {
- declarations: SlotMap::default(),
- identifiers: HashMap::default(),
- files: HashMap::default(),
- buffers: HashMap::default(),
- dirty_files: HashMap::default(),
- dirty_files_tx,
- };
- let mut this = Self {
- project: project.downgrade(),
- state: Arc::new(Mutex::new(initial_state)),
- initial_file_indexing_done_rx,
- _file_indexing_task: None,
- };
-
- let worktree_store = project.read(cx).worktree_store();
- let initial_worktree_snapshots = worktree_store
- .read(cx)
- .worktrees()
- .map(|w| w.read(cx).snapshot())
- .collect::<Vec<_>>();
- this._file_indexing_task = Some(cx.spawn(async move |this, cx| {
- let snapshots_file_count = initial_worktree_snapshots
- .iter()
- .map(|worktree| worktree.file_count())
- .sum::<usize>();
- if snapshots_file_count > 0 {
- let chunk_size = snapshots_file_count.div_ceil(file_indexing_parallelism);
- let chunk_count = snapshots_file_count.div_ceil(chunk_size);
- let file_chunks = initial_worktree_snapshots
- .iter()
- .flat_map(|worktree| {
- let worktree_id = worktree.id();
- worktree.files(false, 0).map(move |entry| {
- (
- entry.id,
- ProjectPath {
- worktree_id,
- path: entry.path.clone(),
- },
- )
- })
- })
- .chunks(chunk_size);
-
- let mut tasks = Vec::with_capacity(chunk_count);
- for chunk in file_chunks.into_iter() {
- tasks.push(Self::update_dirty_files(
- &this,
- chunk.into_iter().collect(),
- cx.clone(),
- ));
- }
- futures::future::join_all(tasks).await;
- log::info!("Finished initial file indexing");
- }
-
- *initial_file_indexing_done_tx.borrow_mut() = true;
-
- let Ok(state) = this.read_with(cx, |this, _cx| Arc::downgrade(&this.state)) else {
- return;
- };
- while dirty_files_rx.next().await.is_some() {
- let Some(state) = state.upgrade() else {
- return;
- };
- let mut state = state.lock().await;
- let was_underused = state.dirty_files.capacity() > 255
- && state.dirty_files.len() * 8 < state.dirty_files.capacity();
- let dirty_files = state.dirty_files.drain().collect::<Vec<_>>();
- if was_underused {
- state.dirty_files.shrink_to_fit();
- }
- drop(state);
- if dirty_files.is_empty() {
- continue;
- }
-
- let chunk_size = dirty_files.len().div_ceil(file_indexing_parallelism);
- let chunk_count = dirty_files.len().div_ceil(chunk_size);
- let mut tasks = Vec::with_capacity(chunk_count);
- let chunks = dirty_files.into_iter().chunks(chunk_size);
- for chunk in chunks.into_iter() {
- tasks.push(Self::update_dirty_files(
- &this,
- chunk.into_iter().collect(),
- cx.clone(),
- ));
- }
- futures::future::join_all(tasks).await;
- }
- }));
-
- cx.subscribe(&worktree_store, Self::handle_worktree_store_event)
- .detach();
-
- let buffer_store = project.read(cx).buffer_store().clone();
- for buffer in buffer_store.read(cx).buffers().collect::<Vec<_>>() {
- this.register_buffer(&buffer, cx);
- }
- cx.subscribe(&buffer_store, Self::handle_buffer_store_event)
- .detach();
-
- this
- }
-
- async fn update_dirty_files(
- this: &WeakEntity<Self>,
- dirty_files: Vec<(ProjectEntryId, ProjectPath)>,
- mut cx: AsyncApp,
- ) {
- for (entry_id, project_path) in dirty_files {
- let Ok(task) = this.update(&mut cx, |this, cx| {
- this.update_file(entry_id, project_path, cx)
- }) else {
- return;
- };
- task.await;
- }
- }
-
- pub fn wait_for_initial_file_indexing(&self, cx: &App) -> Task<Result<()>> {
- if *self.initial_file_indexing_done_rx.borrow() {
- Task::ready(Ok(()))
- } else {
- let mut rx = self.initial_file_indexing_done_rx.clone();
- cx.background_spawn(async move {
- loop {
- match rx.recv().await {
- Some(true) => return Ok(()),
- Some(false) => {}
- None => {
- return Err(anyhow!(
- "SyntaxIndex dropped while waiting for initial file indexing"
- ));
- }
- }
- }
- })
- }
- }
-
- pub fn indexed_file_paths(&self, cx: &App) -> Task<Vec<ProjectPath>> {
- let state = self.state.clone();
- let project = self.project.clone();
-
- cx.spawn(async move |cx| {
- let state = state.lock().await;
- let Some(project) = project.upgrade() else {
- return vec![];
- };
- project
- .read_with(cx, |project, cx| {
- state
- .files
- .keys()
- .filter_map(|entry_id| project.path_for_entry(*entry_id, cx))
- .collect()
- })
- .unwrap_or_default()
- })
- }
-
- fn handle_worktree_store_event(
- &mut self,
- _worktree_store: Entity<WorktreeStore>,
- event: &WorktreeStoreEvent,
- cx: &mut Context<Self>,
- ) {
- use WorktreeStoreEvent::*;
- match event {
- WorktreeUpdatedEntries(worktree_id, updated_entries_set) => {
- let state = Arc::downgrade(&self.state);
- let worktree_id = *worktree_id;
- let updated_entries_set = updated_entries_set.clone();
- cx.background_spawn(async move {
- let Some(state) = state.upgrade() else { return };
- let mut state = state.lock().await;
- for (path, entry_id, path_change) in updated_entries_set.iter() {
- if let PathChange::Removed = path_change {
- state.files.remove(entry_id);
- state.dirty_files.remove(entry_id);
- } else {
- let project_path = ProjectPath {
- worktree_id,
- path: path.clone(),
- };
- state.dirty_files.insert(*entry_id, project_path);
- }
- }
- match state.dirty_files_tx.try_send(()) {
- Err(err) if err.is_disconnected() => {
- log::error!("bug: syntax indexing queue is disconnected");
- }
- _ => {}
- }
- })
- .detach();
- }
- WorktreeDeletedEntry(_worktree_id, project_entry_id) => {
- let project_entry_id = *project_entry_id;
- self.with_state(cx, move |state| {
- state.files.remove(&project_entry_id);
- })
- }
- _ => {}
- }
- }
-
- fn handle_buffer_store_event(
- &mut self,
- _buffer_store: Entity<BufferStore>,
- event: &BufferStoreEvent,
- cx: &mut Context<Self>,
- ) {
- use BufferStoreEvent::*;
- match event {
- BufferAdded(buffer) => self.register_buffer(buffer, cx),
- BufferOpened { .. }
- | BufferChangedFilePath { .. }
- | BufferDropped { .. }
- | SharedBufferClosed { .. } => {}
- }
- }
-
- pub fn state(&self) -> &Arc<Mutex<SyntaxIndexState>> {
- &self.state
- }
-
- fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) {
- if let Some(mut state) = self.state.try_lock() {
- f(&mut state);
- return;
- }
- let state = Arc::downgrade(&self.state);
- cx.background_spawn(async move {
- let Some(state) = state.upgrade() else {
- return;
- };
- let mut state = state.lock().await;
- f(&mut state)
- })
- .detach();
- }
-
- fn register_buffer(&self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
- let buffer_id = buffer.read(cx).remote_id();
- cx.observe_release(buffer, move |this, _buffer, cx| {
- this.with_state(cx, move |state| {
- if let Some(buffer_state) = state.buffers.remove(&buffer_id) {
- SyntaxIndexState::remove_buffer_declarations(
- &buffer_state.declarations,
- &mut state.declarations,
- &mut state.identifiers,
- );
- }
- })
- })
- .detach();
- cx.subscribe(buffer, Self::handle_buffer_event).detach();
-
- self.update_buffer(buffer.clone(), cx);
- }
-
- fn handle_buffer_event(
- &mut self,
- buffer: Entity<Buffer>,
- event: &BufferEvent,
- cx: &mut Context<Self>,
- ) {
- match event {
- BufferEvent::Edited |
- // paths are cached and so should be updated
- BufferEvent::FileHandleChanged => self.update_buffer(buffer, cx),
- _ => {}
- }
- }
-
- fn update_buffer(&self, buffer_entity: Entity<Buffer>, cx: &mut Context<Self>) {
- let buffer = buffer_entity.read(cx);
- if buffer.language().is_none() {
- return;
- }
-
- let Some((project_entry_id, cached_path)) = project::File::from_dyn(buffer.file())
- .and_then(|f| {
- let project_entry_id = f.project_entry_id()?;
- let cached_path = CachedDeclarationPath::new(
- f.worktree.read(cx).abs_path(),
- &f.path,
- buffer.language(),
- );
- Some((project_entry_id, cached_path))
- })
- else {
- return;
- };
- let buffer_id = buffer.remote_id();
-
- let mut parse_status = buffer.parse_status();
- let snapshot_task = cx.spawn({
- let weak_buffer = buffer_entity.downgrade();
- async move |_, cx| {
- while *parse_status.borrow() != language::ParseStatus::Idle {
- parse_status.changed().await?;
- }
- weak_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())
- }
- });
-
- let state = Arc::downgrade(&self.state);
- let task = cx.background_spawn(async move {
- // TODO: How to handle errors?
- let Ok(snapshot) = snapshot_task.await else {
- return;
- };
- let rope = snapshot.text.as_rope();
-
- let declarations = declarations_in_buffer(&snapshot)
- .into_iter()
- .map(|item| {
- (
- item.parent_index,
- BufferDeclaration::from_outline(item, &rope),
- )
- })
- .collect::<Vec<_>>();
-
- let Some(state) = state.upgrade() else {
- return;
- };
- let mut state = state.lock().await;
- let state = state.deref_mut();
-
- let buffer_state = state
- .buffers
- .entry(buffer_id)
- .or_insert_with(Default::default);
-
- SyntaxIndexState::remove_buffer_declarations(
- &buffer_state.declarations,
- &mut state.declarations,
- &mut state.identifiers,
- );
-
- let mut new_ids = Vec::with_capacity(declarations.len());
- state.declarations.reserve(declarations.len());
- for (parent_index, mut declaration) in declarations {
- declaration.parent =
- parent_index.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
-
- let identifier = declaration.identifier.clone();
- let declaration_id = state.declarations.insert(Declaration::Buffer {
- rope: rope.clone(),
- buffer_id,
- declaration,
- project_entry_id,
- cached_path: cached_path.clone(),
- });
- new_ids.push(declaration_id);
-
- state
- .identifiers
- .entry(identifier)
- .or_default()
- .insert(declaration_id);
- }
-
- buffer_state.declarations = new_ids;
- });
-
- self.with_state(cx, move |state| {
- state
- .buffers
- .entry(buffer_id)
- .or_insert_with(Default::default)
- .task = Some(task)
- });
- }
-
- fn update_file(
- &mut self,
- entry_id: ProjectEntryId,
- project_path: ProjectPath,
- cx: &mut Context<Self>,
- ) -> Task<()> {
- let Some(project) = self.project.upgrade() else {
- return Task::ready(());
- };
- let project = project.read(cx);
-
- let language_registry = project.languages();
- let Some(available_language) =
- language_registry.language_for_file_path(project_path.path.as_std_path())
- else {
- return Task::ready(());
- };
- let language = if let Some(Ok(Ok(language))) = language_registry
- .load_language(&available_language)
- .now_or_never()
- {
- if language
- .grammar()
- .is_none_or(|grammar| grammar.outline_config.is_none())
- {
- return Task::ready(());
- }
- future::Either::Left(async { Ok(language) })
- } else {
- let language_registry = language_registry.clone();
- future::Either::Right(async move {
- anyhow::Ok(
- language_registry
- .load_language(&available_language)
- .await??,
- )
- })
- };
-
- let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) else {
- return Task::ready(());
- };
-
- let snapshot_task = worktree.update(cx, |worktree, cx| {
- let load_task = worktree.load_file(&project_path.path, cx);
- let worktree_abs_path = worktree.abs_path();
- cx.spawn(async move |_this, cx| {
- let loaded_file = load_task.await?;
- let language = language.await?;
-
- let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local(loaded_file.text, cx);
- buffer.set_language(Some(language.clone()), cx);
- buffer
- })?;
-
- let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
- while *parse_status.borrow() != language::ParseStatus::Idle {
- parse_status.changed().await?;
- }
-
- let cached_path = CachedDeclarationPath::new(
- worktree_abs_path,
- &project_path.path,
- Some(&language),
- );
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-
- anyhow::Ok((snapshot, cached_path))
- })
- });
-
- let state = Arc::downgrade(&self.state);
- cx.background_spawn(async move {
- // TODO: How to handle errors?
- let Ok((snapshot, cached_path)) = snapshot_task.await else {
- return;
- };
- let rope = snapshot.as_rope();
- let declarations = declarations_in_buffer(&snapshot)
- .into_iter()
- .map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope)))
- .collect::<Vec<_>>();
-
- let Some(state) = state.upgrade() else {
- return;
- };
- let mut state = state.lock().await;
- let state = state.deref_mut();
-
- let file_state = state.files.entry(entry_id).or_insert_with(Default::default);
- for old_declaration_id in &file_state.declarations {
- let Some(declaration) = state.declarations.remove(*old_declaration_id) else {
- debug_panic!("declaration not found");
- continue;
- };
- if let Some(identifier_declarations) =
- state.identifiers.get_mut(declaration.identifier())
- {
- identifier_declarations.remove(old_declaration_id);
- }
- }
-
- let mut new_ids = Vec::with_capacity(declarations.len());
- state.declarations.reserve(declarations.len());
- for (parent_index, mut declaration) in declarations {
- declaration.parent =
- parent_index.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
-
- let identifier = declaration.identifier.clone();
- let declaration_id = state.declarations.insert(Declaration::File {
- project_entry_id: entry_id,
- declaration,
- cached_path: cached_path.clone(),
- });
- new_ids.push(declaration_id);
-
- state
- .identifiers
- .entry(identifier)
- .or_default()
- .insert(declaration_id);
- }
- file_state.declarations = new_ids;
- })
- }
-}
-
-impl SyntaxIndexState {
- pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
- self.declarations.get(id)
- }
-
- /// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector.
- ///
- /// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded.
- pub fn declarations_for_identifier<const N: usize>(
- &self,
- identifier: &Identifier,
- ) -> Vec<(DeclarationId, &Declaration)> {
- // make sure to not have a large stack allocation
- assert!(N < 32);
-
- let Some(declaration_ids) = self.identifiers.get(&identifier) else {
- return vec![];
- };
-
- let mut result = Vec::with_capacity(N);
- let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
- let mut file_declarations = Vec::new();
-
- for declaration_id in declaration_ids {
- let declaration = self.declarations.get(*declaration_id);
- let Some(declaration) = some_or_debug_panic(declaration) else {
- continue;
- };
- match declaration {
- Declaration::Buffer {
- project_entry_id, ..
- } => {
- included_buffer_entry_ids.push(*project_entry_id);
- result.push((*declaration_id, declaration));
- if result.len() == N {
- return Vec::new();
- }
- }
- Declaration::File {
- project_entry_id, ..
- } => {
- if !included_buffer_entry_ids.contains(&project_entry_id) {
- file_declarations.push((*declaration_id, declaration));
- }
- }
- }
- }
-
- for (declaration_id, declaration) in file_declarations {
- match declaration {
- Declaration::File {
- project_entry_id, ..
- } => {
- if !included_buffer_entry_ids.contains(&project_entry_id) {
- result.push((declaration_id, declaration));
-
- if result.len() == N {
- return Vec::new();
- }
- }
- }
- Declaration::Buffer { .. } => {}
- }
- }
-
- result
- }
-
- pub fn buffer_declarations_containing_range(
- &self,
- buffer_id: BufferId,
- range: Range<usize>,
- ) -> impl Iterator<Item = (DeclarationId, &BufferDeclaration)> {
- let Some(buffer_state) = self.buffers.get(&buffer_id) else {
- return itertools::Either::Left(iter::empty());
- };
-
- let iter = buffer_state
- .declarations
- .iter()
- .filter_map(move |declaration_id| {
- let Some(declaration) = self
- .declarations
- .get(*declaration_id)
- .and_then(|d| d.as_buffer())
- else {
- log::error!("bug: missing buffer outline declaration");
- return None;
- };
- if declaration.item_range.contains_inclusive(&range) {
- return Some((*declaration_id, declaration));
- }
- return None;
- });
- itertools::Either::Right(iter)
- }
-
- pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
- match declaration {
- Declaration::File {
- project_entry_id, ..
- } => self
- .files
- .get(project_entry_id)
- .map(|file_state| file_state.declarations.len())
- .unwrap_or_default(),
- Declaration::Buffer { buffer_id, .. } => self
- .buffers
- .get(buffer_id)
- .map(|buffer_state| buffer_state.declarations.len())
- .unwrap_or_default(),
- }
- }
-
- fn remove_buffer_declarations(
- old_declaration_ids: &[DeclarationId],
- declarations: &mut SlotMap<DeclarationId, Declaration>,
- identifiers: &mut HashMap<Identifier, HashSet<DeclarationId>>,
- ) {
- for old_declaration_id in old_declaration_ids {
- let Some(declaration) = declarations.remove(*old_declaration_id) else {
- debug_panic!("declaration not found");
- continue;
- };
- if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) {
- identifier_declarations.remove(old_declaration_id);
- }
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::sync::Arc;
-
- use gpui::TestAppContext;
- use indoc::indoc;
- use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
- use project::{FakeFs, Project};
- use serde_json::json;
- use settings::SettingsStore;
- use text::OffsetRangeExt as _;
- use util::{path, rel_path::rel_path};
-
- use crate::syntax_index::SyntaxIndex;
-
- #[gpui::test]
- async fn test_unopen_indexed_files(cx: &mut TestAppContext) {
- let (project, index, rust_lang_id) = init_test(cx).await;
- let main = Identifier {
- name: "main".into(),
- language_id: rust_lang_id,
- };
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone());
- let index_state = index_state.lock().await;
- cx.update(|cx| {
- let decls = index_state.declarations_for_identifier::<8>(&main);
- assert_eq!(decls.len(), 2);
-
- let decl = expect_file_decl("a.rs", &decls[0].1, &project, cx);
- assert_eq!(decl.identifier, main);
- assert_eq!(decl.item_range, 0..98);
-
- let decl = expect_file_decl("c.rs", &decls[1].1, &project, cx);
- assert_eq!(decl.identifier, main.clone());
- assert_eq!(decl.item_range, 32..280);
- });
- }
-
- #[gpui::test]
- async fn test_parents_in_file(cx: &mut TestAppContext) {
- let (project, index, rust_lang_id) = init_test(cx).await;
- let test_process_data = Identifier {
- name: "test_process_data".into(),
- language_id: rust_lang_id,
- };
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone());
- let index_state = index_state.lock().await;
- cx.update(|cx| {
- let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
- assert_eq!(decls.len(), 1);
-
- let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
- assert_eq!(decl.identifier, test_process_data);
-
- let parent_id = decl.parent.unwrap();
- let parent = index_state.declaration(parent_id).unwrap();
- let parent_decl = expect_file_decl("c.rs", &parent, &project, cx);
- assert_eq!(
- parent_decl.identifier,
- Identifier {
- name: "tests".into(),
- language_id: rust_lang_id
- }
- );
- assert_eq!(parent_decl.parent, None);
- });
- }
-
- #[gpui::test]
- async fn test_parents_in_buffer(cx: &mut TestAppContext) {
- let (project, index, rust_lang_id) = init_test(cx).await;
- let test_process_data = Identifier {
- name: "test_process_data".into(),
- language_id: rust_lang_id,
- };
-
- let buffer = project
- .update(cx, |project, cx| {
- let project_path = project.find_project_path("c.rs", cx).unwrap();
- project.open_buffer(project_path, cx)
- })
- .await
- .unwrap();
-
- cx.run_until_parked();
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone());
- let index_state = index_state.lock().await;
- cx.update(|cx| {
- let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
- assert_eq!(decls.len(), 1);
-
- let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
- assert_eq!(decl.identifier, test_process_data);
-
- let parent_id = decl.parent.unwrap();
- let parent = index_state.declaration(parent_id).unwrap();
- let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx);
- assert_eq!(
- parent_decl.identifier,
- Identifier {
- name: "tests".into(),
- language_id: rust_lang_id
- }
- );
- assert_eq!(parent_decl.parent, None);
- });
-
- drop(buffer);
- }
-
- #[gpui::test]
- async fn test_declarations_limit(cx: &mut TestAppContext) {
- let (_, index, rust_lang_id) = init_test(cx).await;
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone());
- let index_state = index_state.lock().await;
- let decls = index_state.declarations_for_identifier::<1>(&Identifier {
- name: "main".into(),
- language_id: rust_lang_id,
- });
- assert_eq!(decls.len(), 0);
- }
-
- #[gpui::test]
- async fn test_buffer_shadow(cx: &mut TestAppContext) {
- let (project, index, rust_lang_id) = init_test(cx).await;
-
- let main = Identifier {
- name: "main".into(),
- language_id: rust_lang_id,
- };
-
- let buffer = project
- .update(cx, |project, cx| {
- let project_path = project.find_project_path("c.rs", cx).unwrap();
- project.open_buffer(project_path, cx)
- })
- .await
- .unwrap();
-
- cx.run_until_parked();
-
- let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone());
- {
- let index_state = index_state_arc.lock().await;
-
- cx.update(|cx| {
- let decls = index_state.declarations_for_identifier::<8>(&main);
- assert_eq!(decls.len(), 2);
- let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
- assert_eq!(decl.identifier, main);
- assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280);
-
- expect_file_decl("a.rs", &decls[1].1, &project, cx);
- });
- }
-
- // Drop the buffer and wait for release
- cx.update(|_| {
- drop(buffer);
- });
- cx.run_until_parked();
-
- let index_state = index_state_arc.lock().await;
-
- cx.update(|cx| {
- let decls = index_state.declarations_for_identifier::<8>(&main);
- assert_eq!(decls.len(), 2);
- expect_file_decl("a.rs", &decls[0].1, &project, cx);
- expect_file_decl("c.rs", &decls[1].1, &project, cx);
- });
- }
-
- fn expect_buffer_decl<'a>(
- path: &str,
- declaration: &'a Declaration,
- project: &Entity<Project>,
- cx: &App,
- ) -> &'a BufferDeclaration {
- if let Declaration::Buffer {
- declaration,
- project_entry_id,
- ..
- } = declaration
- {
- let project_path = project
- .read(cx)
- .path_for_entry(*project_entry_id, cx)
- .unwrap();
- assert_eq!(project_path.path.as_ref(), rel_path(path),);
- declaration
- } else {
- panic!("Expected a buffer declaration, found {:?}", declaration);
- }
- }
-
- fn expect_file_decl<'a>(
- path: &str,
- declaration: &'a Declaration,
- project: &Entity<Project>,
- cx: &App,
- ) -> &'a FileDeclaration {
- if let Declaration::File {
- declaration,
- project_entry_id: file,
- ..
- } = declaration
- {
- assert_eq!(
- project
- .read(cx)
- .path_for_entry(*file, cx)
- .unwrap()
- .path
- .as_ref(),
- rel_path(path),
- );
- declaration
- } else {
- panic!("Expected a file declaration, found {:?}", declaration);
- }
- }
-
- async fn init_test(
- cx: &mut TestAppContext,
- ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- });
-
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/root"),
- json!({
- "a.rs": indoc! {r#"
- fn main() {
- let x = 1;
- let y = 2;
- let z = add(x, y);
- println!("Result: {}", z);
- }
-
- fn add(a: i32, b: i32) -> i32 {
- a + b
- }
- "#},
- "b.rs": indoc! {"
- pub struct Config {
- pub name: String,
- pub value: i32,
- }
-
- impl Config {
- pub fn new(name: String, value: i32) -> Self {
- Config { name, value }
- }
- }
- "},
- "c.rs": indoc! {r#"
- use std::collections::HashMap;
-
- fn main() {
- let args: Vec<String> = std::env::args().collect();
- let data: Vec<i32> = args[1..]
- .iter()
- .filter_map(|s| s.parse().ok())
- .collect();
- let result = process_data(data);
- println!("{:?}", result);
- }
-
- fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
- let mut counts = HashMap::new();
- for value in data {
- *counts.entry(value).or_insert(0) += 1;
- }
- counts
- }
-
- #[cfg(test)]
- mod tests {
- use super::*;
-
- #[test]
- fn test_process_data() {
- let data = vec![1, 2, 2, 3];
- let result = process_data(data);
- assert_eq!(result.get(&2), Some(&2));
- }
- }
- "#}
- }),
- )
- .await;
- let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
- let language_registry = project.read_with(cx, |project, _| project.languages().clone());
- let lang = rust_lang();
- let lang_id = lang.id();
- language_registry.add(Arc::new(lang));
-
- let file_indexing_parallelism = 2;
- let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
- cx.run_until_parked();
-
- (project, index, lang_id)
- }
-
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
- .unwrap()
- }
-}
@@ -1,314 +0,0 @@
-use hashbrown::HashTable;
-use regex::Regex;
-use std::{
- borrow::Cow,
- hash::{Hash, Hasher as _},
- path::Path,
- sync::LazyLock,
-};
-use util::rel_path::RelPath;
-
-use crate::reference::Reference;
-
-// TODO: Consider implementing sliding window similarity matching like
-// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
-//
-// That implementation could actually be more efficient - no need to track words in the window that
-// are not in the query.
-
-// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
-// two in parallel.
-
-static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
-
-/// Multiset of text occurrences for text similarity that only stores hashes and counts.
-#[derive(Debug, Default)]
-pub struct Occurrences {
- table: HashTable<OccurrenceEntry>,
- total_count: usize,
-}
-
-#[derive(Debug)]
-struct OccurrenceEntry {
- hash: u64,
- count: usize,
-}
-
-impl Occurrences {
- pub fn within_string(text: &str) -> Self {
- Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str()))
- }
-
- #[allow(dead_code)]
- pub fn within_references(references: &[Reference]) -> Self {
- Self::from_identifiers(
- references
- .iter()
- .map(|reference| reference.identifier.name.as_ref()),
- )
- }
-
- pub fn from_identifiers(identifiers: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
- let mut this = Self::default();
- // TODO: Score matches that match case higher?
- //
- // TODO: Also include unsplit identifier?
- for identifier in identifiers {
- for identifier_part in split_identifier(identifier.as_ref()) {
- this.add_hash(fx_hash(&identifier_part.to_lowercase()));
- }
- }
- this
- }
-
- pub fn from_worktree_path(worktree_name: Option<Cow<'_, str>>, rel_path: &RelPath) -> Self {
- if let Some(worktree_name) = worktree_name {
- Self::from_identifiers(
- std::iter::once(worktree_name)
- .chain(iter_path_without_extension(rel_path.as_std_path())),
- )
- } else {
- Self::from_path(rel_path.as_std_path())
- }
- }
-
- pub fn from_path(path: &Path) -> Self {
- Self::from_identifiers(iter_path_without_extension(path))
- }
-
- fn add_hash(&mut self, hash: u64) {
- self.table
- .entry(
- hash,
- |entry: &OccurrenceEntry| entry.hash == hash,
- |entry| entry.hash,
- )
- .and_modify(|entry| entry.count += 1)
- .or_insert(OccurrenceEntry { hash, count: 1 });
- self.total_count += 1;
- }
-
- fn contains_hash(&self, hash: u64) -> bool {
- self.get_count(hash) != 0
- }
-
- fn get_count(&self, hash: u64) -> usize {
- self.table
- .find(hash, |entry| entry.hash == hash)
- .map(|entry| entry.count)
- .unwrap_or(0)
- }
-}
-
-fn iter_path_without_extension(path: &Path) -> impl Iterator<Item = Cow<'_, str>> {
- let last_component: Option<Cow<'_, str>> = path.file_stem().map(|stem| stem.to_string_lossy());
- let mut path_components = path.components();
- path_components.next_back();
- path_components
- .map(|component| component.as_os_str().to_string_lossy())
- .chain(last_component)
-}
-
-pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
- let mut hasher = collections::FxHasher::default();
- data.hash(&mut hasher);
- hasher.finish()
-}
-
-// Splits camelcase / snakecase / kebabcase / pascalcase
-//
-// TODO: Make this more efficient / elegant.
-fn split_identifier(identifier: &str) -> Vec<&str> {
- let mut parts = Vec::new();
- let mut start = 0;
- let chars: Vec<char> = identifier.chars().collect();
-
- if chars.is_empty() {
- return parts;
- }
-
- let mut i = 0;
- while i < chars.len() {
- let ch = chars[i];
-
- // Handle explicit delimiters (underscore and hyphen)
- if ch == '_' || ch == '-' {
- if i > start {
- parts.push(&identifier[start..i]);
- }
- start = i + 1;
- i += 1;
- continue;
- }
-
- // Handle camelCase and PascalCase transitions
- if i > 0 && i < chars.len() {
- let prev_char = chars[i - 1];
-
- // Transition from lowercase/digit to uppercase
- if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
- parts.push(&identifier[start..i]);
- start = i;
- }
- // Handle sequences like "XMLParser" -> ["XML", "Parser"]
- else if i + 1 < chars.len()
- && ch.is_uppercase()
- && chars[i + 1].is_lowercase()
- && prev_char.is_uppercase()
- {
- parts.push(&identifier[start..i]);
- start = i;
- }
- }
-
- i += 1;
- }
-
- // Add the last part if there's any remaining
- if start < identifier.len() {
- parts.push(&identifier[start..]);
- }
-
- // Filter out empty strings
- parts.into_iter().filter(|s| !s.is_empty()).collect()
-}
-
-pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
- if set_a.table.len() > set_b.table.len() {
- std::mem::swap(&mut set_a, &mut set_b);
- }
- let intersection = set_a
- .table
- .iter()
- .filter(|entry| set_b.contains_hash(entry.hash))
- .count();
- let union = set_a.table.len() + set_b.table.len() - intersection;
- intersection as f32 / union as f32
-}
-
-// TODO
-#[allow(dead_code)]
-pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
- if set_a.table.len() > set_b.table.len() {
- std::mem::swap(&mut set_a, &mut set_b);
- }
- let intersection = set_a
- .table
- .iter()
- .filter(|entry| set_b.contains_hash(entry.hash))
- .count();
- intersection as f32 / set_a.table.len() as f32
-}
-
-// TODO
-#[allow(dead_code)]
-pub fn weighted_jaccard_similarity<'a>(
- mut set_a: &'a Occurrences,
- mut set_b: &'a Occurrences,
-) -> f32 {
- if set_a.table.len() > set_b.table.len() {
- std::mem::swap(&mut set_a, &mut set_b);
- }
-
- let mut numerator = 0;
- let mut denominator_a = 0;
- let mut used_count_b = 0;
- for entry_a in set_a.table.iter() {
- let count_a = entry_a.count;
- let count_b = set_b.get_count(entry_a.hash);
- numerator += count_a.min(count_b);
- denominator_a += count_a.max(count_b);
- used_count_b += count_b;
- }
-
- let denominator = denominator_a + (set_b.total_count - used_count_b);
- if denominator == 0 {
- 0.0
- } else {
- numerator as f32 / denominator as f32
- }
-}
-
-pub fn weighted_overlap_coefficient<'a>(
- mut set_a: &'a Occurrences,
- mut set_b: &'a Occurrences,
-) -> f32 {
- if set_a.table.len() > set_b.table.len() {
- std::mem::swap(&mut set_a, &mut set_b);
- }
-
- let mut numerator = 0;
- for entry_a in set_a.table.iter() {
- let count_a = entry_a.count;
- let count_b = set_b.get_count(entry_a.hash);
- numerator += count_a.min(count_b);
- }
-
- let denominator = set_a.total_count.min(set_b.total_count);
- if denominator == 0 {
- 0.0
- } else {
- numerator as f32 / denominator as f32
- }
-}
-
-#[cfg(test)]
-mod test {
- use super::*;
-
- #[test]
- fn test_split_identifier() {
- assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]);
- assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]);
- assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]);
- assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]);
- assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]);
- }
-
- #[test]
- fn test_similarity_functions() {
- // 10 identifier parts, 8 unique
- // Repeats: 2 "outline", 2 "items"
- let set_a = Occurrences::within_string(
- "let mut outline_items = query_outline_items(&language, &tree, &source);",
- );
- // 14 identifier parts, 11 unique
- // Repeats: 2 "outline", 2 "language", 2 "tree"
- let set_b = Occurrences::within_string(
- "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
- );
-
- // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
- // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
- assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));
-
- // Numerator is one more than before due to both having 2 "outline".
- // Denominator is the same except for 3 more due to the non-overlapping duplicates
- assert_eq!(
- weighted_jaccard_similarity(&set_a, &set_b),
- 7.0 / (7.0 + 7.0 + 3.0)
- );
-
- // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
- assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);
-
- // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
- // the smaller set, 10.
- assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
- }
-
- #[test]
- fn test_iter_path_without_extension() {
- let mut iter = iter_path_without_extension(Path::new(""));
- assert_eq!(iter.next(), None);
-
- let iter = iter_path_without_extension(Path::new("foo"));
- assert_eq!(iter.collect::<Vec<_>>(), ["foo"]);
-
- let iter = iter_path_without_extension(Path::new("foo/bar.txt"));
- assert_eq!(iter.collect::<Vec<_>>(), ["foo", "bar"]);
-
- let iter = iter_path_without_extension(Path::new("foo/bar/baz.txt"));
- assert_eq!(iter.collect::<Vec<_>>(), ["foo", "bar", "baz"]);
- }
-}
@@ -1,42 +0,0 @@
-[package]
-name = "edit_prediction_context2"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/edit_prediction_context2.rs"
-
-[dependencies]
-parking_lot.workspace = true
-anyhow.workspace = true
-collections.workspace = true
-futures.workspace = true
-gpui.workspace = true
-language.workspace = true
-lsp.workspace = true
-project.workspace = true
-log.workspace = true
-serde.workspace = true
-smallvec.workspace = true
-tree-sitter.workspace = true
-util.workspace = true
-
-[dev-dependencies]
-env_logger.workspace = true
-indoc.workspace = true
-futures.workspace = true
-gpui = { workspace = true, features = ["test-support"] }
-language = { workspace = true, features = ["test-support"] }
-lsp = { workspace = true, features = ["test-support"] }
-pretty_assertions.workspace = true
-project = {workspace= true, features = ["test-support"]}
-serde_json.workspace = true
-settings = {workspace= true, features = ["test-support"]}
-text = { workspace = true, features = ["test-support"] }
-util = { workspace = true, features = ["test-support"] }
-zlog.workspace = true
@@ -1,465 +0,0 @@
-use crate::assemble_excerpts::assemble_excerpts;
-use anyhow::Result;
-use collections::HashMap;
-use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
-use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
-use project::{LocationLink, Project, ProjectPath};
-use serde::{Serialize, Serializer};
-use smallvec::SmallVec;
-use std::{
- collections::hash_map,
- ops::Range,
- sync::Arc,
- time::{Duration, Instant},
-};
-use util::{RangeExt as _, ResultExt};
-
-mod assemble_excerpts;
-#[cfg(test)]
-mod edit_prediction_context_tests;
-#[cfg(test)]
-mod fake_definition_lsp;
-
-pub struct RelatedExcerptStore {
- project: WeakEntity<Project>,
- related_files: Vec<RelatedFile>,
- cache: HashMap<Identifier, Arc<CacheEntry>>,
- update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
-}
-
-pub enum RelatedExcerptStoreEvent {
- StartedRefresh,
- FinishedRefresh {
- cache_hit_count: usize,
- cache_miss_count: usize,
- mean_definition_latency: Duration,
- max_definition_latency: Duration,
- },
-}
-
-#[derive(Clone, Debug, PartialEq, Eq, Hash)]
-struct Identifier {
- pub name: String,
- pub range: Range<Anchor>,
-}
-
-enum DefinitionTask {
- CacheHit(Arc<CacheEntry>),
- CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
-}
-
-#[derive(Debug)]
-struct CacheEntry {
- definitions: SmallVec<[CachedDefinition; 1]>,
-}
-
-#[derive(Clone, Debug)]
-struct CachedDefinition {
- path: ProjectPath,
- buffer: Entity<Buffer>,
- anchor_range: Range<Anchor>,
-}
-
-#[derive(Clone, Debug, Serialize)]
-pub struct RelatedFile {
- #[serde(serialize_with = "serialize_project_path")]
- pub path: ProjectPath,
- #[serde(skip)]
- pub buffer: WeakEntity<Buffer>,
- pub excerpts: Vec<RelatedExcerpt>,
- pub max_row: u32,
-}
-
-impl RelatedFile {
- pub fn merge_excerpts(&mut self) {
- self.excerpts.sort_unstable_by(|a, b| {
- a.point_range
- .start
- .cmp(&b.point_range.start)
- .then(b.point_range.end.cmp(&a.point_range.end))
- });
-
- let mut index = 1;
- while index < self.excerpts.len() {
- if self.excerpts[index - 1]
- .point_range
- .end
- .cmp(&self.excerpts[index].point_range.start)
- .is_ge()
- {
- let removed = self.excerpts.remove(index);
- if removed
- .point_range
- .end
- .cmp(&self.excerpts[index - 1].point_range.end)
- .is_gt()
- {
- self.excerpts[index - 1].point_range.end = removed.point_range.end;
- self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
- }
- } else {
- index += 1;
- }
- }
- }
-}
-
-#[derive(Clone, Debug, Serialize)]
-pub struct RelatedExcerpt {
- #[serde(skip)]
- pub anchor_range: Range<Anchor>,
- #[serde(serialize_with = "serialize_point_range")]
- pub point_range: Range<Point>,
- #[serde(serialize_with = "serialize_rope")]
- pub text: Rope,
-}
-
-fn serialize_project_path<S: Serializer>(
- project_path: &ProjectPath,
- serializer: S,
-) -> Result<S::Ok, S::Error> {
- project_path.path.serialize(serializer)
-}
-
-fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
- rope.to_string().serialize(serializer)
-}
-
-fn serialize_point_range<S: Serializer>(
- range: &Range<Point>,
- serializer: S,
-) -> Result<S::Ok, S::Error> {
- [
- [range.start.row, range.start.column],
- [range.end.row, range.end.column],
- ]
- .serialize(serializer)
-}
-
-const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
-
-impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
-
-impl RelatedExcerptStore {
- pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
- let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
- cx.spawn(async move |this, cx| {
- let executor = cx.background_executor().clone();
- while let Some((mut buffer, mut position)) = update_rx.next().await {
- let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
- loop {
- futures::select_biased! {
- next = update_rx.next() => {
- if let Some((new_buffer, new_position)) = next {
- buffer = new_buffer;
- position = new_position;
- timer = executor.timer(DEBOUNCE_DURATION).fuse();
- } else {
- return anyhow::Ok(());
- }
- }
- _ = timer => break,
- }
- }
-
- Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
- }
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
-
- RelatedExcerptStore {
- project: project.downgrade(),
- update_tx,
- related_files: Vec::new(),
- cache: Default::default(),
- }
- }
-
- pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
- self.update_tx.unbounded_send((buffer, position)).ok();
- }
-
- pub fn related_files(&self) -> &[RelatedFile] {
- &self.related_files
- }
-
- async fn fetch_excerpts(
- this: WeakEntity<Self>,
- buffer: Entity<Buffer>,
- position: Anchor,
- cx: &mut AsyncApp,
- ) -> Result<()> {
- let (project, snapshot) = this.read_with(cx, |this, cx| {
- (this.project.upgrade(), buffer.read(cx).snapshot())
- })?;
- let Some(project) = project else {
- return Ok(());
- };
-
- let file = snapshot.file().cloned();
- if let Some(file) = &file {
- log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
- }
-
- this.update(cx, |_, cx| {
- cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
- })?;
-
- let identifiers = cx
- .background_spawn(async move { identifiers_for_position(&snapshot, position) })
- .await;
-
- let async_cx = cx.clone();
- let start_time = Instant::now();
- let futures = this.update(cx, |this, cx| {
- identifiers
- .into_iter()
- .filter_map(|identifier| {
- let task = if let Some(entry) = this.cache.get(&identifier) {
- DefinitionTask::CacheHit(entry.clone())
- } else {
- DefinitionTask::CacheMiss(
- this.project
- .update(cx, |project, cx| {
- project.definitions(&buffer, identifier.range.start, cx)
- })
- .ok()?,
- )
- };
-
- let cx = async_cx.clone();
- let project = project.clone();
- Some(async move {
- match task {
- DefinitionTask::CacheHit(cache_entry) => {
- Some((identifier, cache_entry, None))
- }
- DefinitionTask::CacheMiss(task) => {
- let locations = task.await.log_err()??;
- let duration = start_time.elapsed();
- cx.update(|cx| {
- (
- identifier,
- Arc::new(CacheEntry {
- definitions: locations
- .into_iter()
- .filter_map(|location| {
- process_definition(location, &project, cx)
- })
- .collect(),
- }),
- Some(duration),
- )
- })
- .ok()
- }
- }
- })
- })
- .collect::<Vec<_>>()
- })?;
-
- let mut cache_hit_count = 0;
- let mut cache_miss_count = 0;
- let mut mean_definition_latency = Duration::ZERO;
- let mut max_definition_latency = Duration::ZERO;
- let mut new_cache = HashMap::default();
- new_cache.reserve(futures.len());
- for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
- new_cache.insert(identifier, entry);
- if let Some(duration) = duration {
- cache_miss_count += 1;
- mean_definition_latency += duration;
- max_definition_latency = max_definition_latency.max(duration);
- } else {
- cache_hit_count += 1;
- }
- }
- mean_definition_latency /= cache_miss_count.max(1) as u32;
-
- let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
-
- if let Some(file) = &file {
- log::debug!(
- "finished retrieving context buffer:{}, latency:{:?}",
- file.path().as_unix_str(),
- start_time.elapsed()
- );
- }
-
- this.update(cx, |this, cx| {
- this.cache = new_cache;
- this.related_files = related_files;
- cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
- cache_hit_count,
- cache_miss_count,
- mean_definition_latency,
- max_definition_latency,
- });
- })?;
-
- anyhow::Ok(())
- }
-}
-
-async fn rebuild_related_files(
- new_entries: HashMap<Identifier, Arc<CacheEntry>>,
- cx: &mut AsyncApp,
-) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
- let mut snapshots = HashMap::default();
- for entry in new_entries.values() {
- for definition in &entry.definitions {
- if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
- definition
- .buffer
- .read_with(cx, |buffer, _| buffer.parsing_idle())?
- .await;
- e.insert(
- definition
- .buffer
- .read_with(cx, |buffer, _| buffer.snapshot())?,
- );
- }
- }
- }
-
- Ok(cx
- .background_spawn(async move {
- let mut files = Vec::<RelatedFile>::new();
- let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
- let mut paths_by_buffer = HashMap::default();
- for entry in new_entries.values() {
- for definition in &entry.definitions {
- let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
- continue;
- };
- paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
- ranges_by_buffer
- .entry(definition.buffer.clone())
- .or_default()
- .push(definition.anchor_range.to_point(snapshot));
- }
- }
-
- for (buffer, ranges) in ranges_by_buffer {
- let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
- continue;
- };
- let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
- continue;
- };
- let excerpts = assemble_excerpts(snapshot, ranges);
- files.push(RelatedFile {
- path: project_path.clone(),
- buffer: buffer.downgrade(),
- excerpts,
- max_row: snapshot.max_point().row,
- });
- }
-
- files.sort_by_key(|file| file.path.clone());
- (new_entries, files)
- })
- .await)
-}
-
-fn process_definition(
- location: LocationLink,
- project: &Entity<Project>,
- cx: &mut App,
-) -> Option<CachedDefinition> {
- let buffer = location.target.buffer.read(cx);
- let anchor_range = location.target.range;
- let file = buffer.file()?;
- let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
- if worktree.read(cx).is_single_file() {
- return None;
- }
- Some(CachedDefinition {
- path: ProjectPath {
- worktree_id: file.worktree_id(cx),
- path: file.path().clone(),
- },
- buffer: location.target.buffer,
- anchor_range,
- })
-}
-
-/// Gets all of the identifiers that are present in the given line, and its containing
-/// outline items.
-fn identifiers_for_position(buffer: &BufferSnapshot, position: Anchor) -> Vec<Identifier> {
- let offset = position.to_offset(buffer);
- let point = buffer.offset_to_point(offset);
-
- let line_range = Point::new(point.row, 0)..Point::new(point.row + 1, 0).min(buffer.max_point());
- let mut ranges = vec![line_range.to_offset(&buffer)];
-
- // Include the range of the outline item itself, but not its body.
- let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
- for item in outline_items {
- if let Some(body_range) = item.body_range(&buffer) {
- ranges.push(item.range.start..body_range.start.to_offset(&buffer));
- } else {
- ranges.push(item.range.clone());
- }
- }
-
- ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
- ranges.dedup_by(|a, b| {
- if a.start <= b.end {
- b.start = b.start.min(a.start);
- b.end = b.end.max(a.end);
- true
- } else {
- false
- }
- });
-
- let mut identifiers = Vec::new();
- let outer_range =
- ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
-
- let mut captures = buffer
- .syntax
- .captures(outer_range.clone(), &buffer.text, |grammar| {
- grammar
- .highlights_config
- .as_ref()
- .map(|config| &config.query)
- });
-
- for range in ranges {
- captures.set_byte_range(range.start..outer_range.end);
-
- let mut last_range = None;
- while let Some(capture) = captures.peek() {
- let node_range = capture.node.byte_range();
- if node_range.start > range.end {
- break;
- }
- let config = captures.grammars()[capture.grammar_index]
- .highlights_config
- .as_ref();
-
- if let Some(config) = config
- && config.identifier_capture_indices.contains(&capture.index)
- && range.contains_inclusive(&node_range)
- && Some(&node_range) != last_range.as_ref()
- {
- let name = buffer.text_for_range(node_range.clone()).collect();
- identifiers.push(Identifier {
- range: buffer.anchor_after(node_range.start)
- ..buffer.anchor_before(node_range.end),
- name,
- });
- last_range = Some(node_range);
- }
-
- captures.advance();
- }
- }
-
- identifiers
-}
@@ -0,0 +1,17 @@
+[package]
+name = "edit_prediction_types"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/edit_prediction_types.rs"
+
+[dependencies]
+client.workspace = true
+gpui.workspace = true
+language.workspace = true
@@ -0,0 +1,298 @@
+use std::{ops::Range, sync::Arc};
+
+use client::EditPredictionUsage;
+use gpui::{App, Context, Entity, SharedString};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt};
+
+// TODO: Find a better home for `Direction`.
+//
+// This should live in an ancestor crate of `editor` and `edit_prediction`,
+// but at time of writing there isn't an obvious spot.
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum Direction {
+ Prev,
+ Next,
+}
+
+#[derive(Clone)]
+pub enum EditPrediction {
+ /// Edits within the buffer that requested the prediction
+ Local {
+ id: Option<SharedString>,
+ edits: Vec<(Range<language::Anchor>, Arc<str>)>,
+ edit_preview: Option<language::EditPreview>,
+ },
+ /// Jump to a different file from the one that requested the prediction
+ Jump {
+ id: Option<SharedString>,
+ snapshot: language::BufferSnapshot,
+ target: language::Anchor,
+ },
+}
+
+pub enum DataCollectionState {
+ /// The provider doesn't support data collection.
+ Unsupported,
+ /// Data collection is enabled.
+ Enabled { is_project_open_source: bool },
+ /// Data collection is disabled or unanswered.
+ Disabled { is_project_open_source: bool },
+}
+
+impl DataCollectionState {
+ pub fn is_supported(&self) -> bool {
+ !matches!(self, DataCollectionState::Unsupported)
+ }
+
+ pub fn is_enabled(&self) -> bool {
+ matches!(self, DataCollectionState::Enabled { .. })
+ }
+
+ pub fn is_project_open_source(&self) -> bool {
+ match self {
+ Self::Enabled {
+ is_project_open_source,
+ }
+ | Self::Disabled {
+ is_project_open_source,
+ } => *is_project_open_source,
+ _ => false,
+ }
+ }
+}
+
+pub trait EditPredictionDelegate: 'static + Sized {
+ fn name() -> &'static str;
+ fn display_name() -> &'static str;
+ fn show_predictions_in_menu() -> bool;
+ fn show_tab_accept_marker() -> bool {
+ false
+ }
+ fn supports_jump_to_edit() -> bool {
+ true
+ }
+
+ fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
+ DataCollectionState::Unsupported
+ }
+
+ fn usage(&self, _cx: &App) -> Option<EditPredictionUsage> {
+ None
+ }
+
+ fn toggle_data_collection(&mut self, _cx: &mut App) {}
+ fn is_enabled(
+ &self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &App,
+ ) -> bool;
+ fn is_refreshing(&self, cx: &App) -> bool;
+ fn refresh(
+ &mut self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ debounce: bool,
+ cx: &mut Context<Self>,
+ );
+ fn cycle(
+ &mut self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ direction: Direction,
+ cx: &mut Context<Self>,
+ );
+ fn accept(&mut self, cx: &mut Context<Self>);
+ fn discard(&mut self, cx: &mut Context<Self>);
+ fn did_show(&mut self, _cx: &mut Context<Self>) {}
+ fn suggest(
+ &mut self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) -> Option<EditPrediction>;
+}
+
+pub trait EditPredictionDelegateHandle {
+ fn name(&self) -> &'static str;
+ fn display_name(&self) -> &'static str;
+ fn is_enabled(
+ &self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &App,
+ ) -> bool;
+ fn show_predictions_in_menu(&self) -> bool;
+ fn show_tab_accept_marker(&self) -> bool;
+ fn supports_jump_to_edit(&self) -> bool;
+ fn data_collection_state(&self, cx: &App) -> DataCollectionState;
+ fn usage(&self, cx: &App) -> Option<EditPredictionUsage>;
+ fn toggle_data_collection(&self, cx: &mut App);
+ fn is_refreshing(&self, cx: &App) -> bool;
+ fn refresh(
+ &self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ debounce: bool,
+ cx: &mut App,
+ );
+ fn cycle(
+ &self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ direction: Direction,
+ cx: &mut App,
+ );
+ fn did_show(&self, cx: &mut App);
+ fn accept(&self, cx: &mut App);
+ fn discard(&self, cx: &mut App);
+ fn suggest(
+ &self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut App,
+ ) -> Option<EditPrediction>;
+}
+
+impl<T> EditPredictionDelegateHandle for Entity<T>
+where
+ T: EditPredictionDelegate,
+{
+ fn name(&self) -> &'static str {
+ T::name()
+ }
+
+ fn display_name(&self) -> &'static str {
+ T::display_name()
+ }
+
+ fn show_predictions_in_menu(&self) -> bool {
+ T::show_predictions_in_menu()
+ }
+
+ fn show_tab_accept_marker(&self) -> bool {
+ T::show_tab_accept_marker()
+ }
+
+ fn supports_jump_to_edit(&self) -> bool {
+ T::supports_jump_to_edit()
+ }
+
+ fn data_collection_state(&self, cx: &App) -> DataCollectionState {
+ self.read(cx).data_collection_state(cx)
+ }
+
+ fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
+ self.read(cx).usage(cx)
+ }
+
+ fn toggle_data_collection(&self, cx: &mut App) {
+ self.update(cx, |this, cx| this.toggle_data_collection(cx))
+ }
+
+ fn is_enabled(
+ &self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &App,
+ ) -> bool {
+ self.read(cx).is_enabled(buffer, cursor_position, cx)
+ }
+
+ fn is_refreshing(&self, cx: &App) -> bool {
+ self.read(cx).is_refreshing(cx)
+ }
+
+ fn refresh(
+ &self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ debounce: bool,
+ cx: &mut App,
+ ) {
+ self.update(cx, |this, cx| {
+ this.refresh(buffer, cursor_position, debounce, cx)
+ })
+ }
+
+ fn cycle(
+ &self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ direction: Direction,
+ cx: &mut App,
+ ) {
+ self.update(cx, |this, cx| {
+ this.cycle(buffer, cursor_position, direction, cx)
+ })
+ }
+
+ fn accept(&self, cx: &mut App) {
+ self.update(cx, |this, cx| this.accept(cx))
+ }
+
+ fn discard(&self, cx: &mut App) {
+ self.update(cx, |this, cx| this.discard(cx))
+ }
+
+ fn did_show(&self, cx: &mut App) {
+ self.update(cx, |this, cx| this.did_show(cx))
+ }
+
+ fn suggest(
+ &self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut App,
+ ) -> Option<EditPrediction> {
+ self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx))
+ }
+}
+
+/// Returns edits updated based on user edits since the old snapshot. None is returned if any user
+/// edit is not a prefix of a predicted insertion.
+pub fn interpolate_edits(
+ old_snapshot: &BufferSnapshot,
+ new_snapshot: &BufferSnapshot,
+ current_edits: &[(Range<Anchor>, Arc<str>)],
+) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
+ let mut edits = Vec::new();
+
+ let mut model_edits = current_edits.iter().peekable();
+ for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
+ while let Some((model_old_range, _)) = model_edits.peek() {
+ let model_old_range = model_old_range.to_offset(old_snapshot);
+ if model_old_range.end < user_edit.old.start {
+ let (model_old_range, model_new_text) = model_edits.next().unwrap();
+ edits.push((model_old_range.clone(), model_new_text.clone()));
+ } else {
+ break;
+ }
+ }
+
+ if let Some((model_old_range, model_new_text)) = model_edits.peek() {
+ let model_old_offset_range = model_old_range.to_offset(old_snapshot);
+ if user_edit.old == model_old_offset_range {
+ let user_new_text = new_snapshot
+ .text_for_range(user_edit.new.clone())
+ .collect::<String>();
+
+ if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
+ if !model_suffix.is_empty() {
+ let anchor = old_snapshot.anchor_after(user_edit.old.end);
+ edits.push((anchor..anchor, model_suffix.into()));
+ }
+
+ model_edits.next();
+ continue;
+ }
+ }
+ }
+
+ return None;
+ }
+
+ edits.extend(model_edits.cloned());
+
+ if edits.is_empty() { None } else { Some(edits) }
+}
@@ -1,5 +1,5 @@
[package]
-name = "edit_prediction_button"
+name = "edit_prediction_ui"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
@@ -9,35 +9,43 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
-path = "src/edit_prediction_button.rs"
+path = "src/edit_prediction_ui.rs"
doctest = false
[dependencies]
anyhow.workspace = true
+buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
+cloud_zeta2_prompt.workspace = true
codestral.workspace = true
+command_palette_hooks.workspace = true
copilot.workspace = true
edit_prediction.workspace = true
+edit_prediction_types.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
+futures.workspace = true
gpui.workspace = true
indoc.workspace = true
language.workspace = true
+markdown.workspace = true
+menu.workspace = true
+multi_buffer.workspace = true
paths.workspace = true
project.workspace = true
regex.workspace = true
settings.workspace = true
supermaven.workspace = true
telemetry.workspace = true
+text.workspace = true
+theme.workspace = true
ui.workspace = true
ui_input.workspace = true
-menu.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
-zeta.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }
@@ -1,16 +1,14 @@
-mod sweep_api_token_modal;
-
-pub use sweep_api_token_modal::SweepApiKeyModal;
-
use anyhow::Result;
use client::{Client, UserStore, zed_urls};
use cloud_llm_client::UsageLimit;
-use codestral::CodestralCompletionProvider;
+use codestral::CodestralEditPredictionDelegate;
use copilot::{Copilot, Status};
+use edit_prediction::{SweepFeatureFlag, Zeta2FeatureFlag};
+use edit_prediction_types::EditPredictionDelegateHandle;
use editor::{
Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll,
};
-use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag};
+use feature_flags::FeatureFlagAppExt;
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt, App, AsyncWindowContext, Corner, Entity, FocusHandle,
@@ -44,7 +42,11 @@ use workspace::{
notifications::NotificationId,
};
use zed_actions::OpenBrowser;
-use zeta::{RateCompletions, SweepFeatureFlag, Zeta2FeatureFlag};
+
+use crate::{
+ RatePredictions, SweepApiKeyModal,
+ rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag,
+};
actions!(
edit_prediction,
@@ -67,7 +69,7 @@ pub struct EditPredictionButton {
editor_focus_handle: Option<FocusHandle>,
language: Option<Arc<Language>>,
file: Option<Arc<dyn File>>,
- edit_prediction_provider: Option<Arc<dyn edit_prediction::EditPredictionProviderHandle>>,
+ edit_prediction_provider: Option<Arc<dyn EditPredictionDelegateHandle>>,
fs: Arc<dyn Fs>,
user_store: Entity<UserStore>,
popover_menu_handle: PopoverMenuHandle<ContextMenu>,
@@ -244,7 +246,7 @@ impl Render for EditPredictionButton {
EditPredictionProvider::Codestral => {
let enabled = self.editor_enabled.unwrap_or(true);
- let has_api_key = CodestralCompletionProvider::has_api_key(cx);
+ let has_api_key = CodestralEditPredictionDelegate::has_api_key(cx);
let fs = self.fs.clone();
let this = cx.weak_entity();
@@ -317,16 +319,16 @@ impl Render for EditPredictionButton {
);
let sweep_missing_token = is_sweep
- && !zeta::Zeta::try_global(cx)
- .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
+ && !edit_prediction::EditPredictionStore::try_global(cx)
+ .map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token());
- let zeta_icon = match (is_sweep, enabled) {
+ let ep_icon = match (is_sweep, enabled) {
(true, _) => IconName::SweepAi,
(false, true) => IconName::ZedPredict,
(false, false) => IconName::ZedPredictDisabled,
};
- if zeta::should_show_upsell_modal() {
+ if edit_prediction::should_show_upsell_modal() {
let tooltip_meta = if self.user_store.read(cx).current_user().is_some() {
"Choose a Plan"
} else {
@@ -334,7 +336,7 @@ impl Render for EditPredictionButton {
};
return div().child(
- IconButton::new("zed-predict-pending-button", zeta_icon)
+ IconButton::new("zed-predict-pending-button", ep_icon)
.shape(IconButtonShape::Square)
.indicator(Indicator::dot().color(Color::Muted))
.indicator_border_color(Some(cx.theme().colors().status_bar_background))
@@ -379,7 +381,7 @@ impl Render for EditPredictionButton {
None
};
- let icon_button = IconButton::new("zed-predict-pending-button", zeta_icon)
+ let icon_button = IconButton::new("zed-predict-pending-button", ep_icon)
.shape(IconButtonShape::Square)
.when_some(indicator_color, |this, color| {
this.indicator(Indicator::dot().color(color))
@@ -419,13 +421,13 @@ impl Render for EditPredictionButton {
let this = cx.weak_entity();
- let mut popover_menu = PopoverMenu::new("zeta")
+ let mut popover_menu = PopoverMenu::new("edit-prediction")
.when(user.is_some(), |popover_menu| {
let this = this.clone();
popover_menu.menu(move |window, cx| {
this.update(cx, |this, cx| {
- this.build_zeta_context_menu(provider, window, cx)
+ this.build_edit_prediction_context_menu(provider, window, cx)
})
.ok()
})
@@ -485,7 +487,7 @@ impl EditPredictionButton {
cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
.detach();
- CodestralCompletionProvider::ensure_api_key_loaded(client.http_client(), cx);
+ CodestralEditPredictionDelegate::ensure_api_key_loaded(client.http_client(), cx);
Self {
editor_subscription: None,
@@ -520,7 +522,7 @@ impl EditPredictionButton {
}
}
- if CodestralCompletionProvider::has_api_key(cx) {
+ if CodestralEditPredictionDelegate::has_api_key(cx) {
providers.push(EditPredictionProvider::Codestral);
}
@@ -599,8 +601,8 @@ impl EditPredictionButton {
EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
) => {
- let has_api_token = zeta::Zeta::try_global(cx)
- .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
+ let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
+ .map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token());
let should_open_modal = !has_api_token || is_current;
@@ -947,8 +949,8 @@ impl EditPredictionButton {
)
.context(editor_focus_handle)
.when(
- cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>(),
- |this| this.action("Rate Completions", RateCompletions.boxed_clone()),
+ cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>(),
+ |this| this.action("Rate Predictions", RatePredictions.boxed_clone()),
);
}
@@ -1016,7 +1018,7 @@ impl EditPredictionButton {
})
}
- fn build_zeta_context_menu(
+ fn build_edit_prediction_context_menu(
&self,
provider: EditPredictionProvider,
window: &mut Window,
@@ -23,16 +23,16 @@ use ui::{
StyledTypography as _, h_flex, v_flex,
};
-use workspace::Item;
-use zeta::{
- Zeta, ZetaContextRetrievalFinishedDebugInfo, ZetaContextRetrievalStartedDebugInfo,
- ZetaDebugInfo,
+use edit_prediction::{
+ ContextRetrievalFinishedDebugEvent, ContextRetrievalStartedDebugEvent, DebugEvent,
+ EditPredictionStore,
};
+use workspace::Item;
-pub struct Zeta2ContextView {
+pub struct EditPredictionContextView {
empty_focus_handle: FocusHandle,
project: Entity<Project>,
- zeta: Entity<Zeta>,
+ store: Entity<EditPredictionStore>,
runs: VecDeque<RetrievalRun>,
current_ix: usize,
_update_task: Task<Result<()>>,
@@ -50,13 +50,13 @@ actions!(
dev,
[
/// Go to the previous context retrieval run
- Zeta2ContextGoBack,
+ EditPredictionContextGoBack,
/// Go to the next context retrieval run
- Zeta2ContextGoForward
+ EditPredictionContextGoForward
]
);
-impl Zeta2ContextView {
+impl EditPredictionContextView {
pub fn new(
project: Entity<Project>,
client: &Arc<Client>,
@@ -64,13 +64,13 @@ impl Zeta2ContextView {
window: &mut gpui::Window,
cx: &mut Context<Self>,
) -> Self {
- let zeta = Zeta::global(client, user_store, cx);
+ let store = EditPredictionStore::global(client, user_store, cx);
- let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info());
+ let mut debug_rx = store.update(cx, |store, _| store.debug_info());
let _update_task = cx.spawn_in(window, async move |this, cx| {
while let Some(event) = debug_rx.next().await {
this.update_in(cx, |this, window, cx| {
- this.handle_zeta_event(event, window, cx)
+ this.handle_store_event(event, window, cx)
})?;
}
Ok(())
@@ -81,35 +81,35 @@ impl Zeta2ContextView {
project,
runs: VecDeque::new(),
current_ix: 0,
- zeta,
+ store,
_update_task,
}
}
- fn handle_zeta_event(
+ fn handle_store_event(
&mut self,
- event: ZetaDebugInfo,
+ event: DebugEvent,
window: &mut gpui::Window,
cx: &mut Context<Self>,
) {
match event {
- ZetaDebugInfo::ContextRetrievalStarted(info) => {
+ DebugEvent::ContextRetrievalStarted(info) => {
if info.project_entity_id == self.project.entity_id() {
self.handle_context_retrieval_started(info, window, cx);
}
}
- ZetaDebugInfo::ContextRetrievalFinished(info) => {
+ DebugEvent::ContextRetrievalFinished(info) => {
if info.project_entity_id == self.project.entity_id() {
self.handle_context_retrieval_finished(info, window, cx);
}
}
- ZetaDebugInfo::EditPredictionRequested(_) => {}
+ DebugEvent::EditPredictionRequested(_) => {}
}
}
fn handle_context_retrieval_started(
&mut self,
- info: ZetaContextRetrievalStartedDebugInfo,
+ info: ContextRetrievalStartedDebugEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -141,7 +141,7 @@ impl Zeta2ContextView {
fn handle_context_retrieval_finished(
&mut self,
- info: ZetaContextRetrievalFinishedDebugInfo,
+ info: ContextRetrievalFinishedDebugEvent,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -154,7 +154,7 @@ impl Zeta2ContextView {
let project = self.project.clone();
let related_files = self
- .zeta
+ .store
.read(cx)
.context_for_project(&self.project, cx)
.to_vec();
@@ -220,7 +220,7 @@ impl Zeta2ContextView {
fn handle_go_back(
&mut self,
- _: &Zeta2ContextGoBack,
+ _: &EditPredictionContextGoBack,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -231,7 +231,7 @@ impl Zeta2ContextView {
fn handle_go_forward(
&mut self,
- _: &Zeta2ContextGoForward,
+ _: &EditPredictionContextGoForward,
window: &mut Window,
cx: &mut Context<Self>,
) {
@@ -243,7 +243,10 @@ impl Zeta2ContextView {
cx.notify();
}
- fn render_informational_footer(&self, cx: &mut Context<'_, Zeta2ContextView>) -> ui::Div {
+ fn render_informational_footer(
+ &self,
+ cx: &mut Context<'_, EditPredictionContextView>,
+ ) -> ui::Div {
let run = &self.runs[self.current_ix];
let new_run_started = self
.runs
@@ -279,10 +282,10 @@ impl Zeta2ContextView {
.disabled(self.current_ix == 0 || self.runs.len() < 2)
.tooltip(ui::Tooltip::for_action_title(
"Go to previous run",
- &Zeta2ContextGoBack,
+ &EditPredictionContextGoBack,
))
.on_click(cx.listener(|this, _, window, cx| {
- this.handle_go_back(&Zeta2ContextGoBack, window, cx);
+ this.handle_go_back(&EditPredictionContextGoBack, window, cx);
})),
)
.child(
@@ -308,10 +311,14 @@ impl Zeta2ContextView {
.disabled(self.current_ix + 1 == self.runs.len())
.tooltip(ui::Tooltip::for_action_title(
"Go to next run",
- &Zeta2ContextGoBack,
+ &EditPredictionContextGoBack,
))
.on_click(cx.listener(|this, _, window, cx| {
- this.handle_go_forward(&Zeta2ContextGoForward, window, cx);
+ this.handle_go_forward(
+ &EditPredictionContextGoForward,
+ window,
+ cx,
+ );
})),
),
),
@@ -319,7 +326,7 @@ impl Zeta2ContextView {
}
}
-impl Focusable for Zeta2ContextView {
+impl Focusable for EditPredictionContextView {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.runs
.get(self.current_ix)
@@ -328,9 +335,9 @@ impl Focusable for Zeta2ContextView {
}
}
-impl EventEmitter<()> for Zeta2ContextView {}
+impl EventEmitter<()> for EditPredictionContextView {}
-impl Item for Zeta2ContextView {
+impl Item for EditPredictionContextView {
type Event = ();
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
@@ -357,10 +364,10 @@ impl Item for Zeta2ContextView {
}
}
-impl gpui::Render for Zeta2ContextView {
+impl gpui::Render for EditPredictionContextView {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl ui::IntoElement {
v_flex()
- .key_context("Zeta2Context")
+ .key_context("EditPredictionContext")
.on_action(cx.listener(Self::handle_go_back))
.on_action(cx.listener(Self::handle_go_forward))
.size_full()
@@ -0,0 +1,128 @@
+mod edit_prediction_button;
+mod edit_prediction_context_view;
+mod rate_prediction_modal;
+mod sweep_api_token_modal;
+
+use std::any::{Any as _, TypeId};
+
+use command_palette_hooks::CommandPaletteFilter;
+use edit_prediction::{ResetOnboarding, Zeta2FeatureFlag};
+use edit_prediction_context_view::EditPredictionContextView;
+use feature_flags::FeatureFlagAppExt as _;
+use gpui::actions;
+use project::DisableAiSettings;
+use rate_prediction_modal::RatePredictionsModal;
+use settings::{Settings as _, SettingsStore};
+use ui::{App, prelude::*};
+use workspace::{SplitDirection, Workspace};
+
+pub use edit_prediction_button::{EditPredictionButton, ToggleMenu};
+pub use sweep_api_token_modal::SweepApiKeyModal;
+
+use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag;
+
+actions!(
+ dev,
+ [
+ /// Opens the edit prediction context view.
+ OpenEditPredictionContextView,
+ ]
+);
+
+actions!(
+ edit_prediction,
+ [
+        /// Opens the rate predictions modal.
+ RatePredictions,
+ ]
+);
+
+pub fn init(cx: &mut App) {
+ feature_gate_predict_edits_actions(cx);
+
+ cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
+ workspace.register_action(|workspace, _: &RatePredictions, window, cx| {
+ if cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>() {
+ RatePredictionsModal::toggle(workspace, window, cx);
+ }
+ });
+
+ workspace.register_action_renderer(|div, _, _, cx| {
+ let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
+ div.when(has_flag, |div| {
+ div.on_action(cx.listener(
+ move |workspace, _: &OpenEditPredictionContextView, window, cx| {
+ let project = workspace.project();
+ workspace.split_item(
+ SplitDirection::Right,
+ Box::new(cx.new(|cx| {
+ EditPredictionContextView::new(
+ project.clone(),
+ workspace.client(),
+ workspace.user_store(),
+ window,
+ cx,
+ )
+ })),
+ window,
+ cx,
+ );
+ },
+ ))
+ })
+ });
+ })
+ .detach();
+}
+
+fn feature_gate_predict_edits_actions(cx: &mut App) {
+ let rate_completion_action_types = [TypeId::of::<RatePredictions>()];
+ let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
+ let all_action_types = [
+ TypeId::of::<RatePredictions>(),
+ TypeId::of::<edit_prediction::ResetOnboarding>(),
+ zed_actions::OpenZedPredictOnboarding.type_id(),
+ TypeId::of::<edit_prediction::ClearHistory>(),
+ TypeId::of::<rate_prediction_modal::ThumbsUpActivePrediction>(),
+ TypeId::of::<rate_prediction_modal::ThumbsDownActivePrediction>(),
+ TypeId::of::<rate_prediction_modal::NextEdit>(),
+ TypeId::of::<rate_prediction_modal::PreviousEdit>(),
+ ];
+
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ filter.hide_action_types(&rate_completion_action_types);
+ filter.hide_action_types(&reset_onboarding_action_types);
+ filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
+ });
+
+ cx.observe_global::<SettingsStore>(move |cx| {
+ let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
+ let has_feature_flag = cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>();
+
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ if is_ai_disabled {
+ filter.hide_action_types(&all_action_types);
+ } else if has_feature_flag {
+ filter.show_action_types(&rate_completion_action_types);
+ } else {
+ filter.hide_action_types(&rate_completion_action_types);
+ }
+ });
+ })
+ .detach();
+
+ cx.observe_flag::<PredictEditsRatePredictionsFeatureFlag, _>(move |is_enabled, cx| {
+ if !DisableAiSettings::get_global(cx).disable_ai {
+ if is_enabled {
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ filter.show_action_types(&rate_completion_action_types);
+ });
+ } else {
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ filter.hide_action_types(&rate_completion_action_types);
+ });
+ }
+ }
+ })
+ .detach();
+}
@@ -1,7 +1,8 @@
-use crate::{EditPrediction, EditPredictionRating, Zeta};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use cloud_zeta2_prompt::write_codeblock;
+use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
use editor::{Editor, ExcerptRange, MultiBuffer};
+use feature_flags::FeatureFlag;
use gpui::{
App, BorderStyle, DismissEvent, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable,
Length, StyleRefinement, TextStyleRefinement, Window, actions, prelude::*,
@@ -9,9 +10,7 @@ use gpui::{
use language::{LanguageRegistry, Point, language_settings};
use markdown::{Markdown, MarkdownStyle};
use settings::Settings as _;
-use std::fmt::Write;
-use std::sync::Arc;
-use std::time::Duration;
+use std::{fmt::Write, sync::Arc, time::Duration};
use theme::ThemeSettings;
use ui::{KeyBinding, List, ListItem, ListItemSpacing, Tooltip, prelude::*};
use workspace::{ModalView, Workspace};
@@ -34,8 +33,14 @@ actions!(
]
);
+pub struct PredictEditsRatePredictionsFeatureFlag;
+
+impl FeatureFlag for PredictEditsRatePredictionsFeatureFlag {
+ const NAME: &'static str = "predict-edits-rate-completions";
+}
+
pub struct RatePredictionsModal {
- zeta: Entity<Zeta>,
+ ep_store: Entity<EditPredictionStore>,
language_registry: Arc<LanguageRegistry>,
active_prediction: Option<ActivePrediction>,
selected_index: usize,
@@ -68,10 +73,10 @@ impl RatePredictionView {
impl RatePredictionsModal {
pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context<Workspace>) {
- if let Some(zeta) = Zeta::try_global(cx) {
+ if let Some(ep_store) = EditPredictionStore::try_global(cx) {
let language_registry = workspace.app_state().languages.clone();
workspace.toggle_modal(window, cx, |window, cx| {
- RatePredictionsModal::new(zeta, language_registry, window, cx)
+ RatePredictionsModal::new(ep_store, language_registry, window, cx)
});
telemetry::event!("Rate Prediction Modal Open", source = "Edit Prediction");
@@ -79,15 +84,15 @@ impl RatePredictionsModal {
}
pub fn new(
- zeta: Entity<Zeta>,
+ ep_store: Entity<EditPredictionStore>,
language_registry: Arc<LanguageRegistry>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
- let subscription = cx.observe(&zeta, |_, _, cx| cx.notify());
+ let subscription = cx.observe(&ep_store, |_, _, cx| cx.notify());
Self {
- zeta,
+ ep_store,
language_registry,
selected_index: 0,
focus_handle: cx.focus_handle(),
@@ -113,7 +118,7 @@ impl RatePredictionsModal {
self.selected_index += 1;
self.selected_index = usize::min(
self.selected_index,
- self.zeta.read(cx).shown_predictions().count(),
+ self.ep_store.read(cx).shown_predictions().count(),
);
cx.notify();
}
@@ -130,7 +135,7 @@ impl RatePredictionsModal {
fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {
let next_index = self
- .zeta
+ .ep_store
.read(cx)
.shown_predictions()
.skip(self.selected_index)
@@ -146,11 +151,11 @@ impl RatePredictionsModal {
}
fn select_prev_edit(&mut self, _: &PreviousEdit, _: &mut Window, cx: &mut Context<Self>) {
- let zeta = self.zeta.read(cx);
- let completions_len = zeta.shown_completions_len();
+ let ep_store = self.ep_store.read(cx);
+ let completions_len = ep_store.shown_completions_len();
let prev_index = self
- .zeta
+ .ep_store
.read(cx)
.shown_predictions()
.rev()
@@ -173,7 +178,7 @@ impl RatePredictionsModal {
}
fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
- self.selected_index = self.zeta.read(cx).shown_completions_len() - 1;
+ self.selected_index = self.ep_store.read(cx).shown_completions_len() - 1;
cx.notify();
}
@@ -183,9 +188,9 @@ impl RatePredictionsModal {
window: &mut Window,
cx: &mut Context<Self>,
) {
- self.zeta.update(cx, |zeta, cx| {
+ self.ep_store.update(cx, |ep_store, cx| {
if let Some(active) = &self.active_prediction {
- zeta.rate_prediction(
+ ep_store.rate_prediction(
&active.prediction,
EditPredictionRating::Positive,
active.feedback_editor.read(cx).text(cx),
@@ -216,8 +221,8 @@ impl RatePredictionsModal {
return;
}
- self.zeta.update(cx, |zeta, cx| {
- zeta.rate_prediction(
+ self.ep_store.update(cx, |ep_store, cx| {
+ ep_store.rate_prediction(
&active.prediction,
EditPredictionRating::Negative,
active.feedback_editor.read(cx).text(cx),
@@ -254,7 +259,7 @@ impl RatePredictionsModal {
cx: &mut Context<Self>,
) {
let completion = self
- .zeta
+ .ep_store
.read(cx)
.shown_predictions()
.skip(self.selected_index)
@@ -267,7 +272,7 @@ impl RatePredictionsModal {
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let completion = self
- .zeta
+ .ep_store
.read(cx)
.shown_predictions()
.skip(self.selected_index)
@@ -288,7 +293,7 @@ impl RatePredictionsModal {
// Avoid resetting completion rating if it's already selected.
if let Some(prediction) = prediction {
self.selected_index = self
- .zeta
+ .ep_store
.read(cx)
.shown_predictions()
.enumerate()
@@ -376,7 +381,7 @@ impl RatePredictionsModal {
&included_file.path,
&included_file.excerpts,
if included_file.path == prediction.inputs.cursor_path {
- cursor_insertions
+ cursor_insertions.as_slice()
} else {
&[]
},
@@ -564,7 +569,7 @@ impl RatePredictionsModal {
let border_color = cx.theme().colors().border;
let bg_color = cx.theme().colors().editor_background;
- let rated = self.zeta.read(cx).is_prediction_rated(&completion_id);
+ let rated = self.ep_store.read(cx).is_prediction_rated(&completion_id);
let feedback_empty = active_prediction
.feedback_editor
.read(cx)
@@ -715,7 +720,7 @@ impl RatePredictionsModal {
}
fn render_shown_completions(&self, cx: &Context<Self>) -> impl Iterator<Item = ListItem> {
- self.zeta
+ self.ep_store
.read(cx)
.shown_predictions()
.cloned()
@@ -725,7 +730,7 @@ impl RatePredictionsModal {
.active_prediction
.as_ref()
.is_some_and(|selected| selected.prediction.id == completion.id);
- let rated = self.zeta.read(cx).is_prediction_rated(&completion.id);
+ let rated = self.ep_store.read(cx).is_prediction_rated(&completion.id);
let (icon_name, icon_color, tooltip_text) =
match (rated, completion.edits.is_empty()) {
@@ -1,10 +1,10 @@
+use edit_prediction::EditPredictionStore;
use gpui::{
DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, Render,
};
use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*};
use ui_input::InputField;
use workspace::ModalView;
-use zeta::Zeta;
pub struct SweepApiKeyModal {
api_key_input: Entity<InputField>,
@@ -29,9 +29,10 @@ impl SweepApiKeyModal {
let api_key = self.api_key_input.read(cx).text(cx);
let api_key = (!api_key.trim().is_empty()).then_some(api_key);
- if let Some(zeta) = Zeta::try_global(cx) {
- zeta.update(cx, |zeta, cx| {
- zeta.sweep_ai
+ if let Some(ep_store) = EditPredictionStore::try_global(cx) {
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store
+ .sweep_ai
.set_api_token(api_key, cx)
.detach_and_log_err(cx);
});
@@ -49,7 +49,7 @@ fs.workspace = true
git.workspace = true
gpui.workspace = true
indoc.workspace = true
-edit_prediction.workspace = true
+edit_prediction_types.workspace = true
itertools.workspace = true
language.workspace = true
linkify.workspace = true
@@ -1,4 +1,4 @@
-use edit_prediction::EditPredictionProvider;
+use edit_prediction_types::EditPredictionDelegate;
use gpui::{Entity, KeyBinding, Modifiers, prelude::*};
use indoc::indoc;
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
@@ -15,7 +15,7 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
assign_editor_completion_provider(provider.clone(), &mut cx);
cx.set_state("let absolute_zero_celsius = ˇ;");
@@ -37,7 +37,7 @@ async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
assign_editor_completion_provider(provider.clone(), &mut cx);
cx.set_state("let pi = ˇ\"foo\";");
@@ -59,7 +59,7 @@ async fn test_edit_prediction_jump_button(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
assign_editor_completion_provider(provider.clone(), &mut cx);
// Cursor is 2+ lines above the proposed edit
@@ -128,7 +128,7 @@ async fn test_edit_prediction_invalidation_range(cx: &mut gpui::TestAppContext)
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
assign_editor_completion_provider(provider.clone(), &mut cx);
// Cursor is 3+ lines above the proposed edit
@@ -233,7 +233,7 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui:
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeNonZedEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeNonZedEditPredictionDelegate::default());
assign_editor_completion_provider_non_zed(provider.clone(), &mut cx);
// Cursor is 2+ lines above the proposed edit
@@ -281,7 +281,7 @@ async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestA
cx.update(|cx| cx.bind_keys([KeyBinding::new("ctrl-shift-a", AcceptEditPrediction, None)]));
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
assign_editor_completion_provider(provider.clone(), &mut cx);
cx.set_state("let x = ˇ;");
@@ -371,7 +371,7 @@ fn accept_completion(cx: &mut EditorTestContext) {
}
fn propose_edits<T: ToOffset>(
- provider: &Entity<FakeEditPredictionProvider>,
+ provider: &Entity<FakeEditPredictionDelegate>,
edits: Vec<(Range<T>, &str)>,
cx: &mut EditorTestContext,
) {
@@ -383,7 +383,7 @@ fn propose_edits<T: ToOffset>(
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
- provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
+ provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
id: None,
edits: edits.collect(),
edit_preview: None,
@@ -393,7 +393,7 @@ fn propose_edits<T: ToOffset>(
}
fn assign_editor_completion_provider(
- provider: Entity<FakeEditPredictionProvider>,
+ provider: Entity<FakeEditPredictionDelegate>,
cx: &mut EditorTestContext,
) {
cx.update_editor(|editor, window, cx| {
@@ -402,7 +402,7 @@ fn assign_editor_completion_provider(
}
fn propose_edits_non_zed<T: ToOffset>(
- provider: &Entity<FakeNonZedEditPredictionProvider>,
+ provider: &Entity<FakeNonZedEditPredictionDelegate>,
edits: Vec<(Range<T>, &str)>,
cx: &mut EditorTestContext,
) {
@@ -414,7 +414,7 @@ fn propose_edits_non_zed<T: ToOffset>(
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
- provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
+ provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
id: None,
edits: edits.collect(),
edit_preview: None,
@@ -424,7 +424,7 @@ fn propose_edits_non_zed<T: ToOffset>(
}
fn assign_editor_completion_provider_non_zed(
- provider: Entity<FakeNonZedEditPredictionProvider>,
+ provider: Entity<FakeNonZedEditPredictionDelegate>,
cx: &mut EditorTestContext,
) {
cx.update_editor(|editor, window, cx| {
@@ -433,17 +433,20 @@ fn assign_editor_completion_provider_non_zed(
}
#[derive(Default, Clone)]
-pub struct FakeEditPredictionProvider {
- pub completion: Option<edit_prediction::EditPrediction>,
+pub struct FakeEditPredictionDelegate {
+ pub completion: Option<edit_prediction_types::EditPrediction>,
}
-impl FakeEditPredictionProvider {
- pub fn set_edit_prediction(&mut self, completion: Option<edit_prediction::EditPrediction>) {
+impl FakeEditPredictionDelegate {
+ pub fn set_edit_prediction(
+ &mut self,
+ completion: Option<edit_prediction_types::EditPrediction>,
+ ) {
self.completion = completion;
}
}
-impl EditPredictionProvider for FakeEditPredictionProvider {
+impl EditPredictionDelegate for FakeEditPredictionDelegate {
fn name() -> &'static str {
"fake-completion-provider"
}
@@ -452,7 +455,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
"Fake Completion Provider"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
true
}
@@ -486,7 +489,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
&mut self,
_buffer: gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
- _direction: edit_prediction::Direction,
+ _direction: edit_prediction_types::Direction,
_cx: &mut gpui::Context<Self>,
) {
}
@@ -500,23 +503,26 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
_buffer: &gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_cx: &mut gpui::Context<Self>,
- ) -> Option<edit_prediction::EditPrediction> {
+ ) -> Option<edit_prediction_types::EditPrediction> {
self.completion.clone()
}
}
#[derive(Default, Clone)]
-pub struct FakeNonZedEditPredictionProvider {
- pub completion: Option<edit_prediction::EditPrediction>,
+pub struct FakeNonZedEditPredictionDelegate {
+ pub completion: Option<edit_prediction_types::EditPrediction>,
}
-impl FakeNonZedEditPredictionProvider {
- pub fn set_edit_prediction(&mut self, completion: Option<edit_prediction::EditPrediction>) {
+impl FakeNonZedEditPredictionDelegate {
+ pub fn set_edit_prediction(
+ &mut self,
+ completion: Option<edit_prediction_types::EditPrediction>,
+ ) {
self.completion = completion;
}
}
-impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
+impl EditPredictionDelegate for FakeNonZedEditPredictionDelegate {
fn name() -> &'static str {
"fake-non-zed-provider"
}
@@ -525,7 +531,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
"Fake Non-Zed Provider"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
false
}
@@ -559,7 +565,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
&mut self,
_buffer: gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
- _direction: edit_prediction::Direction,
+ _direction: edit_prediction_types::Direction,
_cx: &mut gpui::Context<Self>,
) {
}
@@ -573,7 +579,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
_buffer: &gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_cx: &mut gpui::Context<Self>,
- ) -> Option<edit_prediction::EditPrediction> {
+ ) -> Option<edit_prediction_types::EditPrediction> {
self.completion.clone()
}
}
@@ -51,7 +51,7 @@ pub mod test;
pub(crate) use actions::*;
pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder};
-pub use edit_prediction::Direction;
+pub use edit_prediction_types::Direction;
pub use editor_settings::{
CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode,
ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowMinimap,
@@ -92,7 +92,7 @@ use collections::{BTreeMap, HashMap, HashSet, VecDeque};
use convert_case::{Case, Casing};
use dap::TelemetrySpawnLocation;
use display_map::*;
-use edit_prediction::{EditPredictionProvider, EditPredictionProviderHandle};
+use edit_prediction_types::{EditPredictionDelegate, EditPredictionDelegateHandle};
use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings};
use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layout_line};
use futures::{
@@ -1120,7 +1120,7 @@ pub struct Editor {
pending_mouse_down: Option<Rc<RefCell<Option<MouseDownEvent>>>>,
gutter_hovered: bool,
hovered_link_state: Option<HoveredLinkState>,
- edit_prediction_provider: Option<RegisteredEditPredictionProvider>,
+ edit_prediction_provider: Option<RegisteredEditPredictionDelegate>,
code_action_providers: Vec<Rc<dyn CodeActionProvider>>,
active_edit_prediction: Option<EditPredictionState>,
/// Used to prevent flickering as the user types while the menu is open
@@ -1562,8 +1562,8 @@ pub struct RenameState {
struct InvalidationStack<T>(Vec<T>);
-struct RegisteredEditPredictionProvider {
- provider: Arc<dyn EditPredictionProviderHandle>,
+struct RegisteredEditPredictionDelegate {
+ provider: Arc<dyn EditPredictionDelegateHandle>,
_subscription: Subscription,
}
@@ -2988,9 +2988,9 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) where
- T: EditPredictionProvider,
+ T: EditPredictionDelegate,
{
- self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionProvider {
+ self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionDelegate {
_subscription: cx.observe_in(&provider, window, |this, _, window, cx| {
if this.focus_handle.is_focused(window) {
this.update_visible_edit_prediction(window, cx);
@@ -7394,7 +7394,7 @@ impl Editor {
&& self
.edit_prediction_provider
.as_ref()
- .is_some_and(|provider| provider.provider.show_completions_in_menu());
+ .is_some_and(|provider| provider.provider.show_predictions_in_menu());
let preview_requires_modifier =
all_language_settings(file, cx).edit_predictions_mode() == EditPredictionsMode::Subtle;
@@ -8095,12 +8095,12 @@ impl Editor {
let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?;
let (completion_id, edits, edit_preview) = match edit_prediction {
- edit_prediction::EditPrediction::Local {
+ edit_prediction_types::EditPrediction::Local {
id,
edits,
edit_preview,
} => (id, edits, edit_preview),
- edit_prediction::EditPrediction::Jump {
+ edit_prediction_types::EditPrediction::Jump {
id,
snapshot,
target,
@@ -8241,7 +8241,7 @@ impl Editor {
Some(())
}
- pub fn edit_prediction_provider(&self) -> Option<Arc<dyn EditPredictionProviderHandle>> {
+ pub fn edit_prediction_provider(&self) -> Option<Arc<dyn EditPredictionDelegateHandle>> {
Some(self.edit_prediction_provider.as_ref()?.provider.clone())
}
@@ -9563,7 +9563,7 @@ impl Editor {
editor_bg_color.blend(accent_color.opacity(0.6))
}
fn get_prediction_provider_icon_name(
- provider: &Option<RegisteredEditPredictionProvider>,
+ provider: &Option<RegisteredEditPredictionDelegate>,
) -> IconName {
match provider {
Some(provider) => match provider.provider.name() {
@@ -2,7 +2,7 @@ use super::*;
use crate::{
JoinLines,
code_context_menus::CodeContextMenu,
- edit_prediction_tests::FakeEditPredictionProvider,
+ edit_prediction_tests::FakeEditPredictionDelegate,
element::StickyHeader,
linked_editing_ranges::LinkedEditingRanges,
scroll::scroll_amount::ScrollAmount,
@@ -8636,7 +8636,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext)
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(provider.clone()), window, cx);
});
@@ -8659,7 +8659,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext)
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
- provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
+ provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
id: None,
edits: vec![(edit_position..edit_position, "X".into())],
edit_preview: None,
@@ -1,11 +1,5 @@
use crate::FeatureFlag;
-pub struct PredictEditsRateCompletionsFeatureFlag;
-
-impl FeatureFlag for PredictEditsRateCompletionsFeatureFlag {
- const NAME: &'static str = "predict-edits-rate-completions";
-}
-
pub struct NotebookFeatureFlag;
impl FeatureFlag for NotebookFeatureFlag {
@@ -16,7 +16,7 @@ doctest = false
anyhow.workspace = true
client.workspace = true
collections.workspace = true
-edit_prediction.workspace = true
+edit_prediction_types.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
@@ -1,7 +1,7 @@
mod messages;
-mod supermaven_completion_provider;
+mod supermaven_edit_prediction_delegate;
-pub use supermaven_completion_provider::*;
+pub use supermaven_edit_prediction_delegate::*;
use anyhow::{Context as _, Result};
#[allow(unused_imports)]
@@ -1,6 +1,6 @@
use crate::{Supermaven, SupermavenCompletionStateId};
use anyhow::Result;
-use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
+use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
use futures::StreamExt as _;
use gpui::{App, Context, Entity, EntityId, Task};
use language::{Anchor, Buffer, BufferSnapshot};
@@ -15,7 +15,7 @@ use unicode_segmentation::UnicodeSegmentation;
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
-pub struct SupermavenCompletionProvider {
+pub struct SupermavenEditPredictionDelegate {
supermaven: Entity<Supermaven>,
buffer_id: Option<EntityId>,
completion_id: Option<SupermavenCompletionStateId>,
@@ -25,7 +25,7 @@ pub struct SupermavenCompletionProvider {
completion_position: Option<language::Anchor>,
}
-impl SupermavenCompletionProvider {
+impl SupermavenEditPredictionDelegate {
pub fn new(supermaven: Entity<Supermaven>) -> Self {
Self {
supermaven,
@@ -104,7 +104,7 @@ fn completion_from_diff(
}
}
-impl EditPredictionProvider for SupermavenCompletionProvider {
+impl EditPredictionDelegate for SupermavenEditPredictionDelegate {
fn name() -> &'static str {
"supermaven"
}
@@ -113,7 +113,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
"Supermaven"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
true
}
@@ -269,8 +269,8 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
}
fn reset_completion_cache(
- provider: &mut SupermavenCompletionProvider,
- _cx: &mut Context<SupermavenCompletionProvider>,
+ provider: &mut SupermavenEditPredictionDelegate,
+ _cx: &mut Context<SupermavenEditPredictionDelegate>,
) {
provider.pending_refresh = None;
provider.completion_id = None;
@@ -50,7 +50,6 @@ debugger_tools.workspace = true
debugger_ui.workspace = true
diagnostics.workspace = true
editor.workspace = true
-zeta2_tools.workspace = true
env_logger.workspace = true
extension.workspace = true
extension_host.workspace = true
@@ -74,7 +73,8 @@ gpui = { workspace = true, features = [
gpui_tokio.workspace = true
rayon.workspace = true
-edit_prediction_button.workspace = true
+edit_prediction.workspace = true
+edit_prediction_ui.workspace = true
http_client.workspace = true
image_viewer.workspace = true
inspector_ui.workspace = true
@@ -160,7 +160,6 @@ web_search_providers.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_env_vars.workspace = true
-zeta.workspace = true
zlog.workspace = true
zlog_settings.workspace = true
chrono.workspace = true
@@ -581,7 +581,7 @@ pub fn main() {
language_model::init(app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
acp_tools::init(cx);
- zeta2_tools::init(cx);
+ edit_prediction_ui::init(cx);
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
snippet_provider::init(cx);
@@ -640,7 +640,7 @@ pub fn main() {
settings_ui::init(cx);
keymap_editor::init(cx);
extensions_ui::init(cx);
- zeta::init(cx);
+ edit_prediction::init(cx);
inspector_ui::init(app_state.clone(), cx);
json_schema_store::init(cx);
miniprofiler_ui::init(*STARTUP_TIME.get().unwrap(), cx);
@@ -401,8 +401,8 @@ pub fn initialize_workspace(
unstable_version_notification(cx);
let edit_prediction_menu_handle = PopoverMenuHandle::default();
- let edit_prediction_button = cx.new(|cx| {
- edit_prediction_button::EditPredictionButton::new(
+ let edit_prediction_ui = cx.new(|cx| {
+ edit_prediction_ui::EditPredictionButton::new(
app_state.fs.clone(),
app_state.user_store.clone(),
edit_prediction_menu_handle.clone(),
@@ -411,7 +411,7 @@ pub fn initialize_workspace(
)
});
workspace.register_action({
- move |_, _: &edit_prediction_button::ToggleMenu, window, cx| {
+ move |_, _: &edit_prediction_ui::ToggleMenu, window, cx| {
edit_prediction_menu_handle.toggle(window, cx);
}
});
@@ -450,7 +450,7 @@ pub fn initialize_workspace(
status_bar.add_left_item(lsp_button, window, cx);
status_bar.add_left_item(diagnostic_summary, window, cx);
status_bar.add_left_item(activity_indicator, window, cx);
- status_bar.add_right_item(edit_prediction_button, window, cx);
+ status_bar.add_right_item(edit_prediction_ui, window, cx);
status_bar.add_right_item(active_buffer_language, window, cx);
status_bar.add_right_item(active_toolchain_language, window, cx);
status_bar.add_right_item(line_ending_indicator, window, cx);
@@ -1,7 +1,8 @@
use client::{Client, UserStore};
-use codestral::CodestralCompletionProvider;
+use codestral::CodestralEditPredictionDelegate;
use collections::HashMap;
-use copilot::{Copilot, CopilotCompletionProvider};
+use copilot::{Copilot, CopilotEditPredictionDelegate};
+use edit_prediction::{SweepFeatureFlag, ZedEditPredictionDelegate, Zeta2FeatureFlag};
use editor::Editor;
use feature_flags::FeatureFlagAppExt;
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
@@ -12,9 +13,8 @@ use settings::{
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore,
};
use std::{cell::RefCell, rc::Rc, sync::Arc};
-use supermaven::{Supermaven, SupermavenCompletionProvider};
+use supermaven::{Supermaven, SupermavenEditPredictionDelegate};
use ui::Window;
-use zeta::{SweepFeatureFlag, Zeta2FeatureFlag, ZetaEditPredictionProvider};
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
@@ -59,7 +59,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
})
.detach();
- cx.on_action(clear_zeta_edit_history);
+ cx.on_action(clear_edit_prediction_store_edit_history);
let mut provider = all_language_settings(None, cx).edit_predictions.provider;
cx.subscribe(&user_store, {
@@ -100,9 +100,9 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
.detach();
}
-fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) {
- if let Some(zeta) = zeta::Zeta::try_global(cx) {
- zeta.update(cx, |zeta, _| zeta.clear_history());
+fn clear_edit_prediction_store_edit_history(_: &edit_prediction::ClearHistory, cx: &mut App) {
+ if let Some(ep_store) = edit_prediction::EditPredictionStore::try_global(cx) {
+ ep_store.update(cx, |ep_store, _| ep_store.clear_history());
}
}
@@ -176,7 +176,7 @@ fn assign_edit_prediction_provider(
match provider {
EditPredictionProvider::None => {
- editor.set_edit_prediction_provider::<ZetaEditPredictionProvider>(None, window, cx);
+ editor.set_edit_prediction_provider::<ZedEditPredictionDelegate>(None, window, cx);
}
EditPredictionProvider::Copilot => {
if let Some(copilot) = Copilot::global(cx) {
@@ -187,55 +187,61 @@ fn assign_edit_prediction_provider(
copilot.register_buffer(&buffer, cx);
});
}
- let provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
EditPredictionProvider::Supermaven => {
if let Some(supermaven) = Supermaven::global(cx) {
- let provider = cx.new(|_| SupermavenCompletionProvider::new(supermaven));
+ let provider = cx.new(|_| SupermavenEditPredictionDelegate::new(supermaven));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
EditPredictionProvider::Codestral => {
let http_client = client.http_client();
- let provider = cx.new(|_| CodestralCompletionProvider::new(http_client));
+ let provider = cx.new(|_| CodestralEditPredictionDelegate::new(http_client));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
- let zeta = zeta::Zeta::global(client, &user_store, cx);
+ let ep_store = edit_prediction::EditPredictionStore::global(client, &user_store, cx);
if let Some(project) = editor.project()
&& let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
{
- let has_model = zeta.update(cx, |zeta, cx| {
+ let has_model = ep_store.update(cx, |ep_store, cx| {
let model = if let EditPredictionProvider::Experimental(name) = value {
if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<SweepFeatureFlag>()
{
- zeta::ZetaEditPredictionModel::Sweep
+ edit_prediction::EditPredictionModel::Sweep
} else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<Zeta2FeatureFlag>()
{
- zeta::ZetaEditPredictionModel::Zeta2
+ edit_prediction::EditPredictionModel::Zeta2
} else {
return false;
}
} else if user_store.read(cx).current_user().is_some() {
- zeta::ZetaEditPredictionModel::Zeta1
+ edit_prediction::EditPredictionModel::Zeta1
} else {
return false;
};
- zeta.set_edit_prediction_model(model);
- zeta.register_buffer(buffer, project, cx);
+ ep_store.set_edit_prediction_model(model);
+ ep_store.register_buffer(buffer, project, cx);
true
});
if has_model {
let provider = cx.new(|cx| {
- ZetaEditPredictionProvider::new(project.clone(), &client, &user_store, cx)
+ ZedEditPredictionDelegate::new(
+ project.clone(),
+ singleton_buffer,
+ &client,
+ &user_store,
+ cx,
+ )
});
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
@@ -1,85 +0,0 @@
-[package]
-name = "zeta"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/zeta.rs"
-
-[features]
-eval-support = []
-
-[dependencies]
-ai_onboarding.workspace = true
-anyhow.workspace = true
-arrayvec.workspace = true
-brotli.workspace = true
-buffer_diff.workspace = true
-client.workspace = true
-cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
-collections.workspace = true
-command_palette_hooks.workspace = true
-copilot.workspace = true
-credentials_provider.workspace = true
-db.workspace = true
-edit_prediction.workspace = true
-edit_prediction_context.workspace = true
-edit_prediction_context2.workspace = true
-editor.workspace = true
-feature_flags.workspace = true
-fs.workspace = true
-futures.workspace = true
-gpui.workspace = true
-indoc.workspace = true
-itertools.workspace = true
-language.workspace = true
-language_model.workspace = true
-log.workspace = true
-lsp.workspace = true
-markdown.workspace = true
-menu.workspace = true
-open_ai.workspace = true
-postage.workspace = true
-pretty_assertions.workspace = true
-project.workspace = true
-rand.workspace = true
-regex.workspace = true
-release_channel.workspace = true
-semver.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-settings.workspace = true
-smol.workspace = true
-strsim.workspace = true
-strum.workspace = true
-telemetry.workspace = true
-telemetry_events.workspace = true
-theme.workspace = true
-thiserror.workspace = true
-ui.workspace = true
-util.workspace = true
-uuid.workspace = true
-workspace.workspace = true
-worktree.workspace = true
-zed_actions.workspace = true
-
-[dev-dependencies]
-clock = { workspace = true, features = ["test-support"] }
-cloud_api_types.workspace = true
-cloud_llm_client = { workspace = true, features = ["test-support"] }
-ctor.workspace = true
-gpui = { workspace = true, features = ["test-support"] }
-indoc.workspace = true
-language = { workspace = true, features = ["test-support"] }
-language_model = { workspace = true, features = ["test-support"] }
-lsp.workspace = true
-parking_lot.workspace = true
-project = { workspace = true, features = ["test-support"] }
-settings = { workspace = true, features = ["test-support"] }
-zlog.workspace = true
@@ -1,490 +0,0 @@
-use anyhow::Result;
-use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
-use collections::HashMap;
-use edit_prediction_context2::{RelatedExcerpt, RelatedFile};
-use futures::{
- StreamExt,
- channel::mpsc::{self, UnboundedSender},
-};
-use gpui::{AppContext, AsyncApp, Entity};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
-use project::{
- Project, ProjectPath, WorktreeSettings,
- search::{SearchQuery, SearchResult},
-};
-use smol::channel;
-use std::ops::Range;
-use util::{
- ResultExt as _,
- paths::{PathMatcher, PathStyle},
-};
-use workspace::item::Settings as _;
-
-#[cfg(feature = "eval-support")]
-type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<(u32, u32)>>>;
-
-pub async fn run_retrieval_searches(
- queries: Vec<SearchToolQuery>,
- project: Entity<Project>,
- #[cfg(feature = "eval-support")] eval_cache: Option<std::sync::Arc<dyn crate::EvalCache>>,
- cx: &mut AsyncApp,
-) -> Result<Vec<RelatedFile>> {
- #[cfg(feature = "eval-support")]
- let cache = if let Some(eval_cache) = eval_cache {
- use crate::EvalCacheEntryKind;
- use anyhow::Context;
- use collections::FxHasher;
- use std::hash::{Hash, Hasher};
-
- let mut hasher = FxHasher::default();
- project.read_with(cx, |project, cx| {
- let mut worktrees = project.worktrees(cx);
- let Some(worktree) = worktrees.next() else {
- panic!("Expected a single worktree in eval project. Found none.");
- };
- assert!(
- worktrees.next().is_none(),
- "Expected a single worktree in eval project. Found more than one."
- );
- worktree.read(cx).abs_path().hash(&mut hasher);
- })?;
-
- queries.hash(&mut hasher);
- let key = (EvalCacheEntryKind::Search, hasher.finish());
-
- if let Some(cached_results) = eval_cache.read(key) {
- let file_results = serde_json::from_str::<CachedSearchResults>(&cached_results)
- .context("Failed to deserialize cached search results")?;
- let mut results = Vec::new();
-
- for (path, ranges) in file_results {
- let project_path = project.update(cx, |project, cx| {
- project.find_project_path(path, cx).unwrap()
- })?;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_buffer(project_path.clone(), cx)
- })?
- .await?;
- let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
- let mut ranges: Vec<_> = ranges
- .into_iter()
- .map(
- |Range {
- start: (start_row, start_col),
- end: (end_row, end_col),
- }| {
- snapshot.anchor_before(Point::new(start_row, start_col))
- ..snapshot.anchor_after(Point::new(end_row, end_col))
- },
- )
- .collect();
- merge_anchor_ranges(&mut ranges, &snapshot);
- results.push(RelatedFile {
- path: project_path,
- buffer: buffer.downgrade(),
- excerpts: ranges
- .into_iter()
- .map(|range| RelatedExcerpt {
- point_range: range.to_point(&snapshot),
- text: snapshot.as_rope().slice(range.to_offset(&snapshot)),
- anchor_range: range,
- })
- .collect(),
- max_row: snapshot.max_point().row,
- });
- }
-
- return Ok(results);
- }
-
- Some((eval_cache, serde_json::to_string_pretty(&queries)?, key))
- } else {
- None
- };
-
- let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
- let global_settings = WorktreeSettings::get_global(cx);
- let exclude_patterns = global_settings
- .file_scan_exclusions
- .sources()
- .chain(global_settings.private_files.sources());
- let path_style = project.path_style(cx);
- anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
- })??;
-
- let (results_tx, mut results_rx) = mpsc::unbounded();
-
- for query in queries {
- let exclude_matcher = exclude_matcher.clone();
- let results_tx = results_tx.clone();
- let project = project.clone();
- cx.spawn(async move |cx| {
- run_query(
- query,
- results_tx.clone(),
- path_style,
- exclude_matcher,
- &project,
- cx,
- )
- .await
- .log_err();
- })
- .detach()
- }
- drop(results_tx);
-
- #[cfg(feature = "eval-support")]
- let cache = cache.clone();
- cx.background_spawn(async move {
- let mut results: Vec<RelatedFile> = Vec::default();
- let mut snapshots = HashMap::default();
-
- let mut total_bytes = 0;
- 'outer: while let Some((project_path, buffer, snapshot, excerpts)) = results_rx.next().await
- {
- let existing = results
- .iter_mut()
- .find(|related_file| related_file.buffer.entity_id() == buffer.entity_id());
- let existing = match existing {
- Some(existing) => existing,
- None => {
- results.push(RelatedFile {
- path: project_path,
- buffer: buffer.downgrade(),
- excerpts: Vec::new(),
- max_row: snapshot.max_point().row,
- });
- results.last_mut().unwrap()
- }
- };
- // let existing = results.entry(buffer).or_default();
- existing.excerpts.reserve(excerpts.len());
-
- for (range, size) in excerpts {
- // Blunt trimming of the results until we have a proper algorithmic filtering step
- if (total_bytes + size) > MAX_RESULTS_LEN {
- log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
- break 'outer;
- }
- total_bytes += size;
- existing.excerpts.push(RelatedExcerpt {
- point_range: range.to_point(&snapshot),
- text: snapshot.as_rope().slice(range.to_offset(&snapshot)),
- anchor_range: range,
- });
- }
- snapshots.insert(buffer.entity_id(), snapshot);
- }
-
- #[cfg(feature = "eval-support")]
- if let Some((cache, queries, key)) = cache {
- let cached_results: CachedSearchResults = results
- .iter()
- .map(|related_file| {
- let mut ranges = related_file
- .excerpts
- .iter()
- .map(
- |RelatedExcerpt {
- point_range: Range { start, end },
- ..
- }| {
- (start.row, start.column)..(end.row, end.column)
- },
- )
- .collect::<Vec<_>>();
- ranges.sort_unstable_by_key(|range| (range.start, range.end));
- (related_file.path.path.as_std_path().to_path_buf(), ranges)
- })
- .collect();
- cache.write(
- key,
- &queries,
- &serde_json::to_string_pretty(&cached_results)?,
- );
- }
-
- for related_file in results.iter_mut() {
- related_file.merge_excerpts();
- }
-
- Ok(results)
- })
- .await
-}
-
-#[cfg(feature = "eval-support")]
-pub(crate) fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
- ranges.sort_unstable_by(|a, b| {
- a.start
- .cmp(&b.start, snapshot)
- .then(b.end.cmp(&a.end, snapshot))
- });
-
- let mut index = 1;
- while index < ranges.len() {
- if ranges[index - 1]
- .end
- .cmp(&ranges[index].start, snapshot)
- .is_ge()
- {
- let removed = ranges.remove(index);
- if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() {
- ranges[index - 1].end = removed.end;
- }
- } else {
- index += 1;
- }
- }
-}
-
-const MAX_EXCERPT_LEN: usize = 768;
-const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
-
-struct SearchJob {
- buffer: Entity<Buffer>,
- snapshot: BufferSnapshot,
- project_path: ProjectPath,
- ranges: Vec<Range<usize>>,
- query_ix: usize,
- jobs_tx: channel::Sender<SearchJob>,
-}
-
-async fn run_query(
- input_query: SearchToolQuery,
- results_tx: UnboundedSender<(
- ProjectPath,
- Entity<Buffer>,
- BufferSnapshot,
- Vec<(Range<Anchor>, usize)>,
- )>,
- path_style: PathStyle,
- exclude_matcher: PathMatcher,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
-) -> Result<()> {
- let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
-
- let make_search = |regex: &str| -> Result<SearchQuery> {
- SearchQuery::regex(
- regex,
- false,
- true,
- false,
- true,
- include_matcher.clone(),
- exclude_matcher.clone(),
- true,
- None,
- )
- };
-
- if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
- let outer_syntax_query = make_search(outer_syntax_regex)?;
- let nested_syntax_queries = input_query
- .syntax_node
- .into_iter()
- .skip(1)
- .map(|query| make_search(&query))
- .collect::<Result<Vec<_>>>()?;
- let content_query = input_query
- .content
- .map(|regex| make_search(®ex))
- .transpose()?;
-
- let (jobs_tx, jobs_rx) = channel::unbounded();
-
- let outer_search_results_rx =
- project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
-
- let outer_search_task = cx.spawn(async move |cx| {
- futures::pin_mut!(outer_search_results_rx);
- while let Some(SearchResult::Buffer { buffer, ranges }) =
- outer_search_results_rx.next().await
- {
- buffer
- .read_with(cx, |buffer, _| buffer.parsing_idle())?
- .await;
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
- let Some(file) = snapshot.file() else {
- continue;
- };
-
- let project_path = cx.update(|cx| ProjectPath {
- worktree_id: file.worktree_id(cx),
- path: file.path().clone(),
- })?;
- let expanded_ranges: Vec<_> = ranges
- .into_iter()
- .filter_map(|range| expand_to_parent_range(&range, &snapshot))
- .collect();
- jobs_tx
- .send(SearchJob {
- project_path,
- buffer,
- snapshot,
- ranges: expanded_ranges,
- query_ix: 0,
- jobs_tx: jobs_tx.clone(),
- })
- .await?;
- }
- anyhow::Ok(())
- });
-
- let n_workers = cx.background_executor().num_cpus();
- let search_job_task = cx.background_executor().scoped(|scope| {
- for _ in 0..n_workers {
- scope.spawn(async {
- while let Ok(job) = jobs_rx.recv().await {
- process_nested_search_job(
- &results_tx,
- &nested_syntax_queries,
- &content_query,
- job,
- )
- .await;
- }
- });
- }
- });
-
- search_job_task.await;
- outer_search_task.await?;
- } else if let Some(content_regex) = &input_query.content {
- let search_query = make_search(&content_regex)?;
-
- let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
- futures::pin_mut!(results_rx);
-
- while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
- let Some(file) = snapshot.file() else {
- continue;
- };
- let project_path = cx.update(|cx| ProjectPath {
- worktree_id: file.worktree_id(cx),
- path: file.path().clone(),
- })?;
-
- let ranges = ranges
- .into_iter()
- .map(|range| {
- let range = range.to_offset(&snapshot);
- let range = expand_to_entire_lines(range, &snapshot);
- let size = range.len();
- let range =
- snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
- (range, size)
- })
- .collect();
-
- let send_result =
- results_tx.unbounded_send((project_path, buffer.clone(), snapshot.clone(), ranges));
-
- if let Err(err) = send_result
- && !err.is_disconnected()
- {
- log::error!("{err}");
- }
- }
- } else {
- log::warn!("Context gathering model produced a glob-only search");
- }
-
- anyhow::Ok(())
-}
-
-async fn process_nested_search_job(
- results_tx: &UnboundedSender<(
- ProjectPath,
- Entity<Buffer>,
- BufferSnapshot,
- Vec<(Range<Anchor>, usize)>,
- )>,
- queries: &Vec<SearchQuery>,
- content_query: &Option<SearchQuery>,
- job: SearchJob,
-) {
- if let Some(search_query) = queries.get(job.query_ix) {
- let mut subranges = Vec::new();
- for range in job.ranges {
- let start = range.start;
- let search_results = search_query.search(&job.snapshot, Some(range)).await;
- for subrange in search_results {
- let subrange = start + subrange.start..start + subrange.end;
- subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
- }
- }
- job.jobs_tx
- .send(SearchJob {
- project_path: job.project_path,
- buffer: job.buffer,
- snapshot: job.snapshot,
- ranges: subranges,
- query_ix: job.query_ix + 1,
- jobs_tx: job.jobs_tx.clone(),
- })
- .await
- .ok();
- } else {
- let ranges = if let Some(content_query) = content_query {
- let mut subranges = Vec::new();
- for range in job.ranges {
- let start = range.start;
- let search_results = content_query.search(&job.snapshot, Some(range)).await;
- for subrange in search_results {
- let subrange = start + subrange.start..start + subrange.end;
- subranges.push(subrange);
- }
- }
- subranges
- } else {
- job.ranges
- };
-
- let matches = ranges
- .into_iter()
- .map(|range| {
- let snapshot = &job.snapshot;
- let range = expand_to_entire_lines(range, snapshot);
- let size = range.len();
- let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
- (range, size)
- })
- .collect();
-
- let send_result =
- results_tx.unbounded_send((job.project_path, job.buffer, job.snapshot, matches));
-
- if let Err(err) = send_result
- && !err.is_disconnected()
- {
- log::error!("{err}");
- }
- }
-}
-
-fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
- let mut point_range = range.to_point(snapshot);
- point_range.start.column = 0;
- if point_range.end.column > 0 {
- point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
- }
- point_range.to_offset(snapshot)
-}
-
-fn expand_to_parent_range<T: ToPoint + ToOffset>(
- range: &Range<T>,
- snapshot: &BufferSnapshot,
-) -> Option<Range<usize>> {
- let mut line_range = range.to_point(&snapshot);
- line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
- line_range.end.column = snapshot.line_len(line_range.end.row);
- // TODO skip result if matched line isn't the first node line?
-
- let node = snapshot.syntax_ancestor(line_range)?;
- Some(node.byte_range())
-}
@@ -1,3890 +0,0 @@
-use anyhow::{Context as _, Result, anyhow, bail};
-use arrayvec::ArrayVec;
-use client::{Client, EditPredictionUsage, UserStore};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
-use cloud_llm_client::{
- AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
- EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
- MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
- ZED_VERSION_HEADER_NAME,
-};
-use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
-use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
-use collections::{HashMap, HashSet};
-use command_palette_hooks::CommandPaletteFilter;
-use db::kvp::{Dismissable, KEY_VALUE_STORE};
-use edit_prediction_context::{
- EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions,
- EditPredictionScoreOptions, Line, SyntaxIndex,
-};
-use edit_prediction_context2::{
- RelatedExcerpt, RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile,
-};
-use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
-use futures::{
- AsyncReadExt as _, FutureExt as _, StreamExt as _,
- channel::{
- mpsc::{self, UnboundedReceiver},
- oneshot,
- },
- select_biased,
-};
-use gpui::BackgroundExecutor;
-use gpui::{
- App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
- http_client::{self, AsyncBody, Method},
- prelude::*,
-};
-use language::language_settings::all_language_settings;
-use language::{
- Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
-};
-use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, RefreshLlmTokenListener};
-use open_ai::FunctionDefinition;
-use project::{DisableAiSettings, Project, ProjectItem as _, ProjectPath, WorktreeId};
-use release_channel::AppVersion;
-use semver::Version;
-use serde::de::DeserializeOwned;
-use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file};
-use std::any::{Any as _, TypeId};
-use std::collections::{VecDeque, hash_map};
-use telemetry_events::EditPredictionRating;
-use workspace::Workspace;
-
-use std::ops::Range;
-use std::path::Path;
-use std::rc::Rc;
-use std::str::FromStr as _;
-use std::sync::{Arc, LazyLock};
-use std::time::{Duration, Instant};
-use std::{env, mem};
-use thiserror::Error;
-use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
-use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
-
-mod license_detection;
-mod onboarding_modal;
-mod prediction;
-mod provider;
-mod rate_prediction_modal;
-pub mod retrieval_search;
-pub mod sweep_ai;
-pub mod udiff;
-mod xml_edits;
-pub mod zeta1;
-
-#[cfg(test)]
-mod zeta_tests;
-
-use crate::license_detection::LicenseDetectionWatcher;
-use crate::onboarding_modal::ZedPredictModal;
-pub use crate::prediction::EditPrediction;
-pub use crate::prediction::EditPredictionId;
-pub use crate::prediction::EditPredictionInputs;
-use crate::prediction::EditPredictionResult;
-use crate::rate_prediction_modal::{
- NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
- ThumbsUpActivePrediction,
-};
-pub use crate::sweep_ai::SweepAi;
-use crate::zeta1::request_prediction_with_zeta1;
-pub use provider::ZetaEditPredictionProvider;
-
-actions!(
- edit_prediction,
- [
- /// Resets the edit prediction onboarding state.
- ResetOnboarding,
- /// Opens the rate completions modal.
- RateCompletions,
- /// Clears the edit prediction history.
- ClearHistory,
- ]
-);
-
-/// Maximum number of events to track.
-const EVENT_COUNT_MAX: usize = 6;
-const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
-const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
-const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
-
-pub struct SweepFeatureFlag;
-
-impl FeatureFlag for SweepFeatureFlag {
- const NAME: &str = "sweep-ai";
-}
-pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
- max_bytes: 512,
- min_bytes: 128,
- target_before_cursor_over_total_bytes: 0.5,
-};
-
-pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Lsp(DEFAULT_EXCERPT_OPTIONS);
-
-pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
- excerpt: DEFAULT_EXCERPT_OPTIONS,
-};
-
-pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
- EditPredictionContextOptions {
- use_imports: true,
- max_retrieved_declarations: 0,
- excerpt: DEFAULT_EXCERPT_OPTIONS,
- score: EditPredictionScoreOptions {
- omit_excerpt_overlaps: true,
- },
- };
-
-pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
- context: DEFAULT_CONTEXT_OPTIONS,
- max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
- max_diagnostic_bytes: 2048,
- prompt_format: PromptFormat::DEFAULT,
- file_indexing_parallelism: 1,
- buffer_change_grouping_interval: Duration::from_secs(1),
-};
-
-static USE_OLLAMA: LazyLock<bool> =
- LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
-static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
- env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
- "qwen3-coder:30b".to_string()
- } else {
- "yqvev8r3".to_string()
- })
-});
-static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
- match env::var("ZED_ZETA2_MODEL").as_deref() {
- Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
- Ok(model) => model,
- Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
- Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
- }
- .to_string()
-});
-static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
- env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
- if *USE_OLLAMA {
- Some("http://localhost:11434/v1/chat/completions".into())
- } else {
- None
- }
- })
-});
-
-pub struct Zeta2FeatureFlag;
-
-impl FeatureFlag for Zeta2FeatureFlag {
- const NAME: &'static str = "zeta2";
-
- fn enabled_for_staff() -> bool {
- true
- }
-}
-
-#[derive(Clone)]
-struct ZetaGlobal(Entity<Zeta>);
-
-impl Global for ZetaGlobal {}
-
-pub struct Zeta {
- client: Arc<Client>,
- user_store: Entity<UserStore>,
- llm_token: LlmApiToken,
- _llm_token_subscription: Subscription,
- projects: HashMap<EntityId, ZetaProject>,
- use_context: bool,
- options: ZetaOptions,
- update_required: bool,
- debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
- #[cfg(feature = "eval-support")]
- eval_cache: Option<Arc<dyn EvalCache>>,
- edit_prediction_model: ZetaEditPredictionModel,
- pub sweep_ai: SweepAi,
- data_collection_choice: DataCollectionChoice,
- reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
- shown_predictions: VecDeque<EditPrediction>,
- rated_predictions: HashSet<EditPredictionId>,
-}
-
-#[derive(Copy, Clone, Default, PartialEq, Eq)]
-pub enum ZetaEditPredictionModel {
- #[default]
- Zeta1,
- Zeta2,
- Sweep,
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct ZetaOptions {
- pub context: ContextMode,
- pub max_prompt_bytes: usize,
- pub max_diagnostic_bytes: usize,
- pub prompt_format: predict_edits_v3::PromptFormat,
- pub file_indexing_parallelism: usize,
- pub buffer_change_grouping_interval: Duration,
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub enum ContextMode {
- Agentic(AgenticContextOptions),
- Syntax(EditPredictionContextOptions),
- Lsp(EditPredictionExcerptOptions),
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct AgenticContextOptions {
- pub excerpt: EditPredictionExcerptOptions,
-}
-
-impl ContextMode {
- pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
- match self {
- ContextMode::Agentic(options) => &options.excerpt,
- ContextMode::Syntax(options) => &options.excerpt,
- ContextMode::Lsp(options) => &options,
- }
- }
-}
-
-#[derive(Debug)]
-pub enum ZetaDebugInfo {
- ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
- ContextRetrievalFinished(ZetaContextRetrievalFinishedDebugInfo),
- EditPredictionRequested(ZetaEditPredictionDebugInfo),
-}
-
-#[derive(Debug)]
-pub struct ZetaContextRetrievalStartedDebugInfo {
- pub project_entity_id: EntityId,
- pub timestamp: Instant,
- pub search_prompt: String,
-}
-
-#[derive(Debug)]
-pub struct ZetaContextRetrievalFinishedDebugInfo {
- pub project_entity_id: EntityId,
- pub timestamp: Instant,
- pub metadata: Vec<(&'static str, SharedString)>,
-}
-
-#[derive(Debug)]
-pub struct ZetaEditPredictionDebugInfo {
- pub inputs: EditPredictionInputs,
- pub retrieval_time: Duration,
- pub buffer: WeakEntity<Buffer>,
- pub position: language::Anchor,
- pub local_prompt: Result<String, String>,
- pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
-}
-
-pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
-
-struct ZetaProject {
- events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
- last_event: Option<LastEvent>,
- recent_paths: VecDeque<ProjectPath>,
- registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
- current_prediction: Option<CurrentEditPrediction>,
- next_pending_prediction_id: usize,
- pending_predictions: ArrayVec<PendingPrediction, 2>,
- context_updates_tx: smol::channel::Sender<()>,
- context_updates_rx: smol::channel::Receiver<()>,
- last_prediction_refresh: Option<(EntityId, Instant)>,
- cancelled_predictions: HashSet<usize>,
- context: ZetaProjectContext,
- license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
- _subscription: gpui::Subscription,
-}
-
-enum ZetaProjectContext {
- Syntax(Entity<SyntaxIndex>),
- Lsp(Entity<RelatedExcerptStore>),
- Agentic {
- refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
- refresh_context_debounce_task: Option<Task<Option<()>>>,
- refresh_context_timestamp: Option<Instant>,
- context: Vec<RelatedFile>,
- },
-}
-
-impl ZetaProject {
- pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
- self.events
- .iter()
- .cloned()
- .chain(
- self.last_event
- .as_ref()
- .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
- )
- .collect()
- }
-
- fn cancel_pending_prediction(
- &mut self,
- pending_prediction: PendingPrediction,
- cx: &mut Context<Zeta>,
- ) {
- self.cancelled_predictions.insert(pending_prediction.id);
-
- cx.spawn(async move |this, cx| {
- let Some(prediction_id) = pending_prediction.task.await else {
- return;
- };
-
- this.update(cx, |this, _cx| {
- this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
- })
- .ok();
- })
- .detach()
- }
-}
-
-#[derive(Debug, Clone)]
-struct CurrentEditPrediction {
- pub requested_by: PredictionRequestedBy,
- pub prediction: EditPrediction,
- pub was_shown: bool,
-}
-
-impl CurrentEditPrediction {
- fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
- let Some(new_edits) = self
- .prediction
- .interpolate(&self.prediction.buffer.read(cx))
- else {
- return false;
- };
-
- if self.prediction.buffer != old_prediction.prediction.buffer {
- return true;
- }
-
- let Some(old_edits) = old_prediction
- .prediction
- .interpolate(&old_prediction.prediction.buffer.read(cx))
- else {
- return true;
- };
-
- let requested_by_buffer_id = self.requested_by.buffer_id();
-
- // This reduces the occurrence of UI thrash from replacing edits
- //
- // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
- if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
- && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
- && old_edits.len() == 1
- && new_edits.len() == 1
- {
- let (old_range, old_text) = &old_edits[0];
- let (new_range, new_text) = &new_edits[0];
- new_range == old_range && new_text.starts_with(old_text.as_ref())
- } else {
- true
- }
- }
-}
-
-#[derive(Debug, Clone)]
-enum PredictionRequestedBy {
- DiagnosticsUpdate,
- Buffer(EntityId),
-}
-
-impl PredictionRequestedBy {
- pub fn buffer_id(&self) -> Option<EntityId> {
- match self {
- PredictionRequestedBy::DiagnosticsUpdate => None,
- PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
- }
- }
-}
-
-#[derive(Debug)]
-struct PendingPrediction {
- id: usize,
- task: Task<Option<EditPredictionId>>,
-}
-
-/// A prediction from the perspective of a buffer.
-#[derive(Debug)]
-enum BufferEditPrediction<'a> {
- Local { prediction: &'a EditPrediction },
- Jump { prediction: &'a EditPrediction },
-}
-
-#[cfg(test)]
-impl std::ops::Deref for BufferEditPrediction<'_> {
- type Target = EditPrediction;
-
- fn deref(&self) -> &Self::Target {
- match self {
- BufferEditPrediction::Local { prediction } => prediction,
- BufferEditPrediction::Jump { prediction } => prediction,
- }
- }
-}
-
-struct RegisteredBuffer {
- snapshot: BufferSnapshot,
- _subscriptions: [gpui::Subscription; 2],
-}
-
-struct LastEvent {
- old_snapshot: BufferSnapshot,
- new_snapshot: BufferSnapshot,
- end_edit_anchor: Option<Anchor>,
-}
-
-impl LastEvent {
- pub fn finalize(
- &self,
- license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
- cx: &App,
- ) -> Option<Arc<predict_edits_v3::Event>> {
- let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
- let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
-
- let file = self.new_snapshot.file();
- let old_file = self.old_snapshot.file();
-
- let in_open_source_repo = [file, old_file].iter().all(|file| {
- file.is_some_and(|file| {
- license_detection_watchers
- .get(&file.worktree_id(cx))
- .is_some_and(|watcher| watcher.is_project_open_source())
- })
- });
-
- let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
-
- if path == old_path && diff.is_empty() {
- None
- } else {
- Some(Arc::new(predict_edits_v3::Event::BufferChange {
- old_path,
- path,
- diff,
- in_open_source_repo,
- // TODO: Actually detect if this edit was predicted or not
- predicted: false,
- }))
- }
- }
-}
-
-fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
- if let Some(file) = snapshot.file() {
- file.full_path(cx).into()
- } else {
- Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
- }
-}
-
-impl Zeta {
- pub fn try_global(cx: &App) -> Option<Entity<Self>> {
- cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
- }
-
- pub fn global(
- client: &Arc<Client>,
- user_store: &Entity<UserStore>,
- cx: &mut App,
- ) -> Entity<Self> {
- cx.try_global::<ZetaGlobal>()
- .map(|global| global.0.clone())
- .unwrap_or_else(|| {
- let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
- cx.set_global(ZetaGlobal(zeta.clone()));
- zeta
- })
- }
-
- pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
- let data_collection_choice = Self::load_data_collection_choice();
-
- let llm_token = LlmApiToken::default();
-
- let (reject_tx, reject_rx) = mpsc::unbounded();
- cx.background_spawn({
- let client = client.clone();
- let llm_token = llm_token.clone();
- let app_version = AppVersion::global(cx);
- let background_executor = cx.background_executor().clone();
- async move {
- Self::handle_rejected_predictions(
- reject_rx,
- client,
- llm_token,
- app_version,
- background_executor,
- )
- .await
- }
- })
- .detach();
-
- let mut this = Self {
- projects: HashMap::default(),
- client,
- user_store,
- options: DEFAULT_OPTIONS,
- use_context: false,
- llm_token,
- _llm_token_subscription: cx.subscribe(
- &refresh_llm_token_listener,
- |this, _listener, _event, cx| {
- let client = this.client.clone();
- let llm_token = this.llm_token.clone();
- cx.spawn(async move |_this, _cx| {
- llm_token.refresh(&client).await?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- },
- ),
- update_required: false,
- debug_tx: None,
- #[cfg(feature = "eval-support")]
- eval_cache: None,
- edit_prediction_model: ZetaEditPredictionModel::Zeta2,
- sweep_ai: SweepAi::new(cx),
- data_collection_choice,
- reject_predictions_tx: reject_tx,
- rated_predictions: Default::default(),
- shown_predictions: Default::default(),
- };
-
- this.enable_or_disable_context_retrieval(cx);
- let weak_this = cx.weak_entity();
- cx.on_flags_ready(move |_, cx| {
- weak_this
- .update(cx, |this, cx| this.enable_or_disable_context_retrieval(cx))
- .ok();
- })
- .detach();
- cx.observe_global::<SettingsStore>(|this, cx| {
- this.enable_or_disable_context_retrieval(cx);
- })
- .detach();
-
- this
- }
-
- pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
- self.edit_prediction_model = model;
- }
-
- pub fn has_sweep_api_token(&self) -> bool {
- self.sweep_ai
- .api_token
- .clone()
- .now_or_never()
- .flatten()
- .is_some()
- }
-
- #[cfg(feature = "eval-support")]
- pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
- self.eval_cache = Some(cache);
- }
-
- pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
- let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
- self.debug_tx = Some(debug_watch_tx);
- debug_watch_rx
- }
-
- pub fn options(&self) -> &ZetaOptions {
- &self.options
- }
-
- pub fn set_options(&mut self, options: ZetaOptions) {
- self.options = options;
- }
-
- pub fn set_use_context(&mut self, use_context: bool) {
- self.use_context = use_context;
- }
-
- pub fn clear_history(&mut self) {
- for zeta_project in self.projects.values_mut() {
- zeta_project.events.clear();
- }
- }
-
- pub fn context_for_project<'a>(
- &'a self,
- project: &Entity<Project>,
- cx: &'a App,
- ) -> &'a [RelatedFile] {
- self.projects
- .get(&project.entity_id())
- .and_then(|project| match &project.context {
- ZetaProjectContext::Syntax(_) => None,
- ZetaProjectContext::Lsp(store) => Some(store.read(cx).related_files()),
- ZetaProjectContext::Agentic { context, .. } => Some(context.as_slice()),
- })
- .unwrap_or(&[])
- }
-
- pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
- self.user_store.read(cx).edit_prediction_usage()
- } else {
- None
- }
- }
-
- pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
- self.get_or_init_zeta_project(project, cx);
- }
-
- pub fn register_buffer(
- &mut self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let zeta_project = self.get_or_init_zeta_project(project, cx);
- Self::register_buffer_impl(zeta_project, buffer, project, cx);
- }
-
- fn get_or_init_zeta_project(
- &mut self,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> &mut ZetaProject {
- let entity_id = project.entity_id();
- let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
- self.projects
- .entry(entity_id)
- .or_insert_with(|| ZetaProject {
- context: match &self.options.context {
- ContextMode::Agentic(_) => ZetaProjectContext::Agentic {
- refresh_context_task: None,
- refresh_context_debounce_task: None,
- refresh_context_timestamp: None,
- context: Vec::new(),
- },
- ContextMode::Syntax(_) => ZetaProjectContext::Syntax(cx.new(|cx| {
- SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
- })),
- ContextMode::Lsp(_) => {
- let related_excerpt_store =
- cx.new(|cx| RelatedExcerptStore::new(project, cx));
- cx.subscribe(
- &related_excerpt_store,
- move |this, _, event, _| match event {
- RelatedExcerptStoreEvent::StartedRefresh => {
- if let Some(debug_tx) = this.debug_tx.clone() {
- debug_tx
- .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
- ZetaContextRetrievalStartedDebugInfo {
- project_entity_id: entity_id,
- timestamp: Instant::now(),
- search_prompt: String::new(),
- },
- ))
- .ok();
- }
- }
- RelatedExcerptStoreEvent::FinishedRefresh {
- cache_hit_count,
- cache_miss_count,
- mean_definition_latency,
- max_definition_latency,
- } => {
- if let Some(debug_tx) = this.debug_tx.clone() {
- debug_tx
- .unbounded_send(
- ZetaDebugInfo::ContextRetrievalFinished(
- ZetaContextRetrievalFinishedDebugInfo {
- project_entity_id: entity_id,
- timestamp: Instant::now(),
- metadata: vec![
- (
- "Cache Hits",
- format!(
- "{}/{}",
- cache_hit_count,
- cache_hit_count
- + cache_miss_count
- )
- .into(),
- ),
- (
- "Max LSP Time",
- format!(
- "{} ms",
- max_definition_latency
- .as_millis()
- )
- .into(),
- ),
- (
- "Mean LSP Time",
- format!(
- "{} ms",
- mean_definition_latency
- .as_millis()
- )
- .into(),
- ),
- ],
- },
- ),
- )
- .ok();
- }
- if let Some(project_state) = this.projects.get(&entity_id) {
- project_state.context_updates_tx.send_blocking(()).ok();
- }
- }
- },
- )
- .detach();
- ZetaProjectContext::Lsp(related_excerpt_store)
- }
- },
- events: VecDeque::new(),
- last_event: None,
- recent_paths: VecDeque::new(),
- context_updates_rx,
- context_updates_tx,
- registered_buffers: HashMap::default(),
- current_prediction: None,
- cancelled_predictions: HashSet::default(),
- pending_predictions: ArrayVec::new(),
- next_pending_prediction_id: 0,
- last_prediction_refresh: None,
- license_detection_watchers: HashMap::default(),
- _subscription: cx.subscribe(&project, Self::handle_project_event),
- })
- }
-
- pub fn project_context_updates(
- &self,
- project: &Entity<Project>,
- ) -> Option<smol::channel::Receiver<()>> {
- let project_state = self.projects.get(&project.entity_id())?;
- Some(project_state.context_updates_rx.clone())
- }
-
- fn handle_project_event(
- &mut self,
- project: Entity<Project>,
- event: &project::Event,
- cx: &mut Context<Self>,
- ) {
- // TODO [zeta2] init with recent paths
- match event {
- project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
- let path = project.read(cx).path_for_entry(*active_entry_id, cx);
- if let Some(path) = path {
- if let Some(ix) = zeta_project
- .recent_paths
- .iter()
- .position(|probe| probe == &path)
- {
- zeta_project.recent_paths.remove(ix);
- }
- zeta_project.recent_paths.push_front(path);
- }
- }
- project::Event::DiagnosticsUpdated { .. } => {
- if cx.has_flag::<Zeta2FeatureFlag>() {
- self.refresh_prediction_from_diagnostics(project, cx);
- }
- }
- _ => (),
- }
- }
-
- fn register_buffer_impl<'a>(
- zeta_project: &'a mut ZetaProject,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> &'a mut RegisteredBuffer {
- let buffer_id = buffer.entity_id();
-
- if let Some(file) = buffer.read(cx).file() {
- let worktree_id = file.worktree_id(cx);
- if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
- zeta_project
- .license_detection_watchers
- .entry(worktree_id)
- .or_insert_with(|| {
- let project_entity_id = project.entity_id();
- cx.observe_release(&worktree, move |this, _worktree, _cx| {
- let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
- else {
- return;
- };
- zeta_project.license_detection_watchers.remove(&worktree_id);
- })
- .detach();
- Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
- });
- }
- }
-
- match zeta_project.registered_buffers.entry(buffer_id) {
- hash_map::Entry::Occupied(entry) => entry.into_mut(),
- hash_map::Entry::Vacant(entry) => {
- let snapshot = buffer.read(cx).snapshot();
- let project_entity_id = project.entity_id();
- entry.insert(RegisteredBuffer {
- snapshot,
- _subscriptions: [
- cx.subscribe(buffer, {
- let project = project.downgrade();
- move |this, buffer, event, cx| {
- if let language::BufferEvent::Edited = event
- && let Some(project) = project.upgrade()
- {
- this.report_changes_for_buffer(&buffer, &project, cx);
- }
- }
- }),
- cx.observe_release(buffer, move |this, _buffer, _cx| {
- let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
- else {
- return;
- };
- zeta_project.registered_buffers.remove(&buffer_id);
- }),
- ],
- })
- }
- }
- }
-
- fn report_changes_for_buffer(
- &mut self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let project_state = self.get_or_init_zeta_project(project, cx);
- let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
-
- let new_snapshot = buffer.read(cx).snapshot();
- if new_snapshot.version == registered_buffer.snapshot.version {
- return;
- }
-
- let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
- let end_edit_anchor = new_snapshot
- .anchored_edits_since::<Point>(&old_snapshot.version)
- .last()
- .map(|(_, range)| range.end);
- let events = &mut project_state.events;
-
- if let Some(LastEvent {
- new_snapshot: last_new_snapshot,
- end_edit_anchor: last_end_edit_anchor,
- ..
- }) = project_state.last_event.as_mut()
- {
- let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
- == last_new_snapshot.remote_id()
- && old_snapshot.version == last_new_snapshot.version;
-
- let should_coalesce = is_next_snapshot_of_same_buffer
- && end_edit_anchor
- .as_ref()
- .zip(last_end_edit_anchor.as_ref())
- .is_some_and(|(a, b)| {
- let a = a.to_point(&new_snapshot);
- let b = b.to_point(&new_snapshot);
- a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
- });
-
- if should_coalesce {
- *last_end_edit_anchor = end_edit_anchor;
- *last_new_snapshot = new_snapshot;
- return;
- }
- }
-
- if events.len() + 1 >= EVENT_COUNT_MAX {
- events.pop_front();
- }
-
- if let Some(event) = project_state.last_event.take() {
- events.extend(event.finalize(&project_state.license_detection_watchers, cx));
- }
-
- project_state.last_event = Some(LastEvent {
- old_snapshot,
- new_snapshot,
- end_edit_anchor,
- });
- }
-
- fn current_prediction_for_buffer(
- &self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &App,
- ) -> Option<BufferEditPrediction<'_>> {
- let project_state = self.projects.get(&project.entity_id())?;
-
- let CurrentEditPrediction {
- requested_by,
- prediction,
- ..
- } = project_state.current_prediction.as_ref()?;
-
- if prediction.targets_buffer(buffer.read(cx)) {
- Some(BufferEditPrediction::Local { prediction })
- } else {
- let show_jump = match requested_by {
- PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
- requested_by_buffer_id == &buffer.entity_id()
- }
- PredictionRequestedBy::DiagnosticsUpdate => true,
- };
-
- if show_jump {
- Some(BufferEditPrediction::Jump { prediction })
- } else {
- None
- }
- }
- }
-
- fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
- match self.edit_prediction_model {
- ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
- ZetaEditPredictionModel::Sweep => return,
- }
-
- let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- let Some(prediction) = project_state.current_prediction.take() else {
- return;
- };
- let request_id = prediction.prediction.id.to_string();
- for pending_prediction in mem::take(&mut project_state.pending_predictions) {
- project_state.cancel_pending_prediction(pending_prediction, cx);
- }
-
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let app_version = AppVersion::global(cx);
- cx.spawn(async move |this, cx| {
- let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
- http_client::Url::parse(&predict_edits_url)?
- } else {
- client
- .http_client()
- .build_zed_llm_url("/predict_edits/accept", &[])?
- };
-
- let response = cx
- .background_spawn(Self::send_api_request::<()>(
- move |builder| {
- let req = builder.uri(url.as_ref()).body(
- serde_json::to_string(&AcceptEditPredictionBody {
- request_id: request_id.clone(),
- })?
- .into(),
- );
- Ok(req?)
- },
- client,
- llm_token,
- app_version,
- ))
- .await;
-
- Self::handle_api_response(&this, response, cx)?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- }
-
- async fn handle_rejected_predictions(
- rx: UnboundedReceiver<EditPredictionRejection>,
- client: Arc<Client>,
- llm_token: LlmApiToken,
- app_version: Version,
- background_executor: BackgroundExecutor,
- ) {
- let mut rx = std::pin::pin!(rx.peekable());
- let mut batched = Vec::new();
-
- while let Some(rejection) = rx.next().await {
- batched.push(rejection);
-
- if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
- select_biased! {
- next = rx.as_mut().peek().fuse() => {
- if next.is_some() {
- continue;
- }
- }
- () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
- }
- }
-
- let url = client
- .http_client()
- .build_zed_llm_url("/predict_edits/reject", &[])
- .unwrap();
-
- let flush_count = batched
- .len()
- // in case items have accumulated after failure
- .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
- let start = batched.len() - flush_count;
-
- let body = RejectEditPredictionsBodyRef {
- rejections: &batched[start..],
- };
-
- let result = Self::send_api_request::<()>(
- |builder| {
- let req = builder
- .uri(url.as_ref())
- .body(serde_json::to_string(&body)?.into());
- anyhow::Ok(req?)
- },
- client.clone(),
- llm_token.clone(),
- app_version.clone(),
- )
- .await;
-
- if result.log_err().is_some() {
- batched.drain(start..);
- }
- }
- }
-
- fn reject_current_prediction(
- &mut self,
- reason: EditPredictionRejectReason,
- project: &Entity<Project>,
- ) {
- if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
- project_state.pending_predictions.clear();
- if let Some(prediction) = project_state.current_prediction.take() {
- self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
- }
- };
- }
-
- fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
- if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
- if let Some(current_prediction) = project_state.current_prediction.as_mut() {
- if !current_prediction.was_shown {
- current_prediction.was_shown = true;
- self.shown_predictions
- .push_front(current_prediction.prediction.clone());
- if self.shown_predictions.len() > 50 {
- let completion = self.shown_predictions.pop_back().unwrap();
- self.rated_predictions.remove(&completion.id);
- }
- }
- }
- }
- }
-
- fn reject_prediction(
- &mut self,
- prediction_id: EditPredictionId,
- reason: EditPredictionRejectReason,
- was_shown: bool,
- ) {
- match self.edit_prediction_model {
- ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
- ZetaEditPredictionModel::Sweep => return,
- }
-
- self.reject_predictions_tx
- .unbounded_send(EditPredictionRejection {
- request_id: prediction_id.to_string(),
- reason,
- was_shown,
- })
- .log_err();
- }
-
- fn is_refreshing(&self, project: &Entity<Project>) -> bool {
- self.projects
- .get(&project.entity_id())
- .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
- }
-
- pub fn refresh_prediction_from_buffer(
- &mut self,
- project: Entity<Project>,
- buffer: Entity<Buffer>,
- position: language::Anchor,
- cx: &mut Context<Self>,
- ) {
- self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
- let Some(request_task) = this
- .update(cx, |this, cx| {
- this.request_prediction(
- &project,
- &buffer,
- position,
- PredictEditsRequestTrigger::Other,
- cx,
- )
- })
- .log_err()
- else {
- return Task::ready(anyhow::Ok(None));
- };
-
- cx.spawn(async move |_cx| {
- request_task.await.map(|prediction_result| {
- prediction_result.map(|prediction_result| {
- (
- prediction_result,
- PredictionRequestedBy::Buffer(buffer.entity_id()),
- )
- })
- })
- })
- })
- }
-
- pub fn refresh_prediction_from_diagnostics(
- &mut self,
- project: Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- // Prefer predictions from buffer
- if zeta_project.current_prediction.is_some() {
- return;
- };
-
- self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
- let Some(open_buffer_task) = project
- .update(cx, |project, cx| {
- project
- .active_entry()
- .and_then(|entry| project.path_for_entry(entry, cx))
- .map(|path| project.open_buffer(path, cx))
- })
- .log_err()
- .flatten()
- else {
- return Task::ready(anyhow::Ok(None));
- };
-
- cx.spawn(async move |cx| {
- let active_buffer = open_buffer_task.await?;
- let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-
- let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
- active_buffer,
- &snapshot,
- Default::default(),
- Default::default(),
- &project,
- cx,
- )
- .await?
- else {
- return anyhow::Ok(None);
- };
-
- let Some(prediction_result) = this
- .update(cx, |this, cx| {
- this.request_prediction(
- &project,
- &jump_buffer,
- jump_position,
- PredictEditsRequestTrigger::Diagnostics,
- cx,
- )
- })?
- .await?
- else {
- return anyhow::Ok(None);
- };
-
- this.update(cx, |this, cx| {
- Some((
- if this
- .get_or_init_zeta_project(&project, cx)
- .current_prediction
- .is_none()
- {
- prediction_result
- } else {
- EditPredictionResult {
- id: prediction_result.id,
- prediction: Err(EditPredictionRejectReason::CurrentPreferred),
- }
- },
- PredictionRequestedBy::DiagnosticsUpdate,
- ))
- })
- })
- });
- }
-
- #[cfg(not(test))]
- pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
- #[cfg(test)]
- pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
-
- fn queue_prediction_refresh(
- &mut self,
- project: Entity<Project>,
- throttle_entity: EntityId,
- cx: &mut Context<Self>,
- do_refresh: impl FnOnce(
- WeakEntity<Self>,
- &mut AsyncApp,
- )
- -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
- + 'static,
- ) {
- let zeta_project = self.get_or_init_zeta_project(&project, cx);
- let pending_prediction_id = zeta_project.next_pending_prediction_id;
- zeta_project.next_pending_prediction_id += 1;
- let last_request = zeta_project.last_prediction_refresh;
-
- let task = cx.spawn(async move |this, cx| {
- if let Some((last_entity, last_timestamp)) = last_request
- && throttle_entity == last_entity
- && let Some(timeout) =
- (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
- {
- cx.background_executor().timer(timeout).await;
- }
-
- // If this task was cancelled before the throttle timeout expired,
- // do not perform a request.
- let mut is_cancelled = true;
- this.update(cx, |this, cx| {
- let project_state = this.get_or_init_zeta_project(&project, cx);
- if !project_state
- .cancelled_predictions
- .remove(&pending_prediction_id)
- {
- project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
- is_cancelled = false;
- }
- })
- .ok();
- if is_cancelled {
- return None;
- }
-
- let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
- let new_prediction_id = new_prediction_result
- .as_ref()
- .map(|(prediction, _)| prediction.id.clone());
-
- // When a prediction completes, remove it from the pending list, and cancel
- // any pending predictions that were enqueued before it.
- this.update(cx, |this, cx| {
- let zeta_project = this.get_or_init_zeta_project(&project, cx);
-
- let is_cancelled = zeta_project
- .cancelled_predictions
- .remove(&pending_prediction_id);
-
- let new_current_prediction = if !is_cancelled
- && let Some((prediction_result, requested_by)) = new_prediction_result
- {
- match prediction_result.prediction {
- Ok(prediction) => {
- let new_prediction = CurrentEditPrediction {
- requested_by,
- prediction,
- was_shown: false,
- };
-
- if let Some(current_prediction) =
- zeta_project.current_prediction.as_ref()
- {
- if new_prediction.should_replace_prediction(&current_prediction, cx)
- {
- this.reject_current_prediction(
- EditPredictionRejectReason::Replaced,
- &project,
- );
-
- Some(new_prediction)
- } else {
- this.reject_prediction(
- new_prediction.prediction.id,
- EditPredictionRejectReason::CurrentPreferred,
- false,
- );
- None
- }
- } else {
- Some(new_prediction)
- }
- }
- Err(reject_reason) => {
- this.reject_prediction(prediction_result.id, reject_reason, false);
- None
- }
- }
- } else {
- None
- };
-
- let zeta_project = this.get_or_init_zeta_project(&project, cx);
-
- if let Some(new_prediction) = new_current_prediction {
- zeta_project.current_prediction = Some(new_prediction);
- }
-
- let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
- for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
- if pending_prediction.id == pending_prediction_id {
- pending_predictions.remove(ix);
- for pending_prediction in pending_predictions.drain(0..ix) {
- zeta_project.cancel_pending_prediction(pending_prediction, cx)
- }
- break;
- }
- }
- this.get_or_init_zeta_project(&project, cx)
- .pending_predictions = pending_predictions;
- cx.notify();
- })
- .ok();
-
- new_prediction_id
- });
-
- if zeta_project.pending_predictions.len() <= 1 {
- zeta_project.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- task,
- });
- } else if zeta_project.pending_predictions.len() == 2 {
- let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
- zeta_project.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- task,
- });
- zeta_project.cancel_pending_prediction(pending_prediction, cx);
- }
- }
-
- pub fn request_prediction(
- &mut self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- position: language::Anchor,
- trigger: PredictEditsRequestTrigger,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPredictionResult>>> {
- self.request_prediction_internal(
- project.clone(),
- active_buffer.clone(),
- position,
- trigger,
- cx.has_flag::<Zeta2FeatureFlag>(),
- cx,
- )
- }
-
- fn request_prediction_internal(
- &mut self,
- project: Entity<Project>,
- active_buffer: Entity<Buffer>,
- position: language::Anchor,
- trigger: PredictEditsRequestTrigger,
- allow_jump: bool,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPredictionResult>>> {
- const DIAGNOSTIC_LINES_RANGE: u32 = 20;
-
- self.get_or_init_zeta_project(&project, cx);
- let zeta_project = self.projects.get(&project.entity_id()).unwrap();
- let events = zeta_project.events(cx);
- let has_events = !events.is_empty();
-
- let snapshot = active_buffer.read(cx).snapshot();
- let cursor_point = position.to_point(&snapshot);
- let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
- let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
- let diagnostic_search_range =
- Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
-
- let task = match self.edit_prediction_model {
- ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
- self,
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- trigger,
- cx,
- ),
- ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- trigger,
- cx,
- ),
- ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- &zeta_project.recent_paths,
- if self.use_context {
- self.context_for_project(&project, cx).to_vec()
- } else {
- Vec::new()
- },
- diagnostic_search_range.clone(),
- cx,
- ),
- };
-
- cx.spawn(async move |this, cx| {
- let prediction = task.await?;
-
- if prediction.is_none() && allow_jump {
- let cursor_point = position.to_point(&snapshot);
- if has_events
- && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
- active_buffer.clone(),
- &snapshot,
- diagnostic_search_range,
- cursor_point,
- &project,
- cx,
- )
- .await?
- {
- return this
- .update(cx, |this, cx| {
- this.request_prediction_internal(
- project,
- jump_buffer,
- jump_position,
- trigger,
- false,
- cx,
- )
- })?
- .await;
- }
-
- return anyhow::Ok(None);
- }
-
- Ok(prediction)
- })
- }
-
- async fn next_diagnostic_location(
- active_buffer: Entity<Buffer>,
- active_buffer_snapshot: &BufferSnapshot,
- active_buffer_diagnostic_search_range: Range<Point>,
- active_buffer_cursor_point: Point,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
- ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
- // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
- let mut jump_location = active_buffer_snapshot
- .diagnostic_groups(None)
- .into_iter()
- .filter_map(|(_, group)| {
- let range = &group.entries[group.primary_ix]
- .range
- .to_point(&active_buffer_snapshot);
- if range.overlaps(&active_buffer_diagnostic_search_range) {
- None
- } else {
- Some(range.start)
- }
- })
- .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
- .map(|position| {
- (
- active_buffer.clone(),
- active_buffer_snapshot.anchor_before(position),
- )
- });
-
- if jump_location.is_none() {
- let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
- let file = buffer.file()?;
-
- Some(ProjectPath {
- worktree_id: file.worktree_id(cx),
- path: file.path().clone(),
- })
- })?;
-
- let buffer_task = project.update(cx, |project, cx| {
- let (path, _, _) = project
- .diagnostic_summaries(false, cx)
- .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
- .max_by_key(|(path, _, _)| {
- // find the buffer with errors that shares most parent directories
- path.path
- .components()
- .zip(
- active_buffer_path
- .as_ref()
- .map(|p| p.path.components())
- .unwrap_or_default(),
- )
- .take_while(|(a, b)| a == b)
- .count()
- })?;
-
- Some(project.open_buffer(path, cx))
- })?;
-
- if let Some(buffer_task) = buffer_task {
- let closest_buffer = buffer_task.await?;
-
- jump_location = closest_buffer
- .read_with(cx, |buffer, _cx| {
- buffer
- .buffer_diagnostics(None)
- .into_iter()
- .min_by_key(|entry| entry.diagnostic.severity)
- .map(|entry| entry.range.start)
- })?
- .map(|position| (closest_buffer, position));
- }
- }
-
- anyhow::Ok(jump_location)
- }
-
- fn request_prediction_with_zeta2(
- &mut self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- active_snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec<Arc<Event>>,
- trigger: PredictEditsRequestTrigger,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPredictionResult>>> {
- let options = self.options.clone();
- let buffer_snapshotted_at = Instant::now();
-
- let Some((excerpt_path, active_project_path)) = active_snapshot
- .file()
- .map(|file| -> Arc<Path> { file.full_path(cx).into() })
- .zip(active_buffer.read(cx).project_path(cx))
- else {
- return Task::ready(Err(anyhow!("No file path for excerpt")));
- };
-
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let app_version = AppVersion::global(cx);
- let debug_tx = self.debug_tx.clone();
-
- let diagnostics = active_snapshot.diagnostic_sets().clone();
-
- let file = active_buffer.read(cx).file();
-
- let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
-
- // TODO data collection
- let can_collect_data = file
- .as_ref()
- .map_or(false, |file| self.can_collect_file(project, file, cx));
-
- let mut included_files = self.context_for_project(project, cx).to_vec();
-
- #[cfg(feature = "eval-support")]
- let eval_cache = self.eval_cache.clone();
-
- let request_task = cx.background_spawn({
- let active_buffer = active_buffer.clone();
- async move {
- let cursor_offset = position.to_offset(&active_snapshot);
- let cursor_point = cursor_offset.to_point(&active_snapshot);
-
- let before_retrieval = Instant::now();
-
- let (diagnostic_groups, diagnostic_groups_truncated) =
- Self::gather_nearby_diagnostics(
- cursor_offset,
- &diagnostics,
- &active_snapshot,
- options.max_diagnostic_bytes,
- );
-
- let excerpt_options = options.context.excerpt();
-
- let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- &active_snapshot,
- &excerpt_options,
- None,
- ) else {
- return Ok((None, None));
- };
-
- let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
- ..active_snapshot.anchor_before(excerpt.range.end);
- let related_excerpt = RelatedExcerpt {
- anchor_range: excerpt_anchor_range.clone(),
- point_range: Point::new(excerpt.line_range.start.0, 0)
- ..Point::new(excerpt.line_range.end.0, 0),
- text: active_snapshot.as_rope().slice(excerpt.range),
- };
-
- if let Some(buffer_ix) = included_files
- .iter()
- .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
- {
- let file = &mut included_files[buffer_ix];
- file.excerpts.push(related_excerpt);
- file.merge_excerpts();
- let last_ix = included_files.len() - 1;
- included_files.swap(buffer_ix, last_ix);
- } else {
- let active_file = RelatedFile {
- path: active_project_path,
- buffer: active_buffer.downgrade(),
- excerpts: vec![related_excerpt],
- max_row: active_snapshot.max_point().row,
- };
- included_files.push(active_file);
- }
-
- let included_files = included_files
- .iter()
- .map(|related_file| predict_edits_v3::IncludedFile {
- path: Arc::from(related_file.path.path.as_std_path()),
- max_row: Line(related_file.max_row),
- excerpts: related_file
- .excerpts
- .iter()
- .map(|excerpt| predict_edits_v3::Excerpt {
- start_line: Line(excerpt.point_range.start.row),
- text: excerpt.text.to_string().into(),
- })
- .collect(),
- })
- .collect::<Vec<_>>();
-
- let cloud_request = predict_edits_v3::PredictEditsRequest {
- excerpt_path,
- excerpt: String::new(),
- excerpt_line_range: Line(0)..Line(0),
- excerpt_range: 0..0,
- cursor_point: predict_edits_v3::Point {
- line: predict_edits_v3::Line(cursor_point.row),
- column: cursor_point.column,
- },
- included_files,
- referenced_declarations: vec![],
- events,
- can_collect_data,
- diagnostic_groups,
- diagnostic_groups_truncated,
- debug_info: debug_tx.is_some(),
- prompt_max_bytes: Some(options.max_prompt_bytes),
- prompt_format: options.prompt_format,
- // TODO [zeta2]
- signatures: vec![],
- excerpt_parent: None,
- git_info: None,
- trigger,
- };
-
- let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
-
- let inputs = EditPredictionInputs {
- included_files: cloud_request.included_files,
- events: cloud_request.events,
- cursor_point: cloud_request.cursor_point,
- cursor_path: cloud_request.excerpt_path,
- };
-
- let retrieval_time = Instant::now() - before_retrieval;
-
- let debug_response_tx = if let Some(debug_tx) = &debug_tx {
- let (response_tx, response_rx) = oneshot::channel();
-
- debug_tx
- .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
- ZetaEditPredictionDebugInfo {
- inputs: inputs.clone(),
- retrieval_time,
- buffer: active_buffer.downgrade(),
- local_prompt: match prompt_result.as_ref() {
- Ok((prompt, _)) => Ok(prompt.clone()),
- Err(err) => Err(err.to_string()),
- },
- position,
- response_rx,
- },
- ))
- .ok();
- Some(response_tx)
- } else {
- None
- };
-
- if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((Err("Request skipped".to_string()), Duration::ZERO))
- .ok();
- }
- anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
- }
-
- let (prompt, _) = prompt_result?;
- let generation_params =
- cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
- let request = open_ai::Request {
- model: EDIT_PREDICTIONS_MODEL_ID.clone(),
- messages: vec![open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(prompt),
- }],
- stream: false,
- max_completion_tokens: None,
- stop: generation_params.stop.unwrap_or_default(),
- temperature: generation_params.temperature.unwrap_or(0.7),
- tool_choice: None,
- parallel_tool_calls: None,
- tools: vec![],
- prompt_cache_key: None,
- reasoning_effort: None,
- };
-
- log::trace!("Sending edit prediction request");
-
- let before_request = Instant::now();
- let response = Self::send_raw_llm_request(
- request,
- client,
- llm_token,
- app_version,
- #[cfg(feature = "eval-support")]
- eval_cache,
- #[cfg(feature = "eval-support")]
- EvalCacheEntryKind::Prediction,
- )
- .await;
- let received_response_at = Instant::now();
- let request_time = received_response_at - before_request;
-
- log::trace!("Got edit prediction response");
-
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((
- response
- .as_ref()
- .map_err(|err| err.to_string())
- .map(|response| response.0.clone()),
- request_time,
- ))
- .ok();
- }
-
- let (res, usage) = response?;
- let request_id = EditPredictionId(res.id.clone().into());
- let Some(mut output_text) = text_from_response(res) else {
- return Ok((Some((request_id, None)), usage));
- };
-
- if output_text.contains(CURSOR_MARKER) {
- log::trace!("Stripping out {CURSOR_MARKER} from response");
- output_text = output_text.replace(CURSOR_MARKER, "");
- }
-
- let get_buffer_from_context = |path: &Path| {
- if Some(path) == active_file_full_path.as_deref() {
- Some((
- &active_snapshot,
- std::slice::from_ref(&excerpt_anchor_range),
- ))
- } else {
- None
- }
- };
-
- let (_, edits) = match options.prompt_format {
- PromptFormat::NumLinesUniDiff => {
- // TODO: Implement parsing of multi-file diffs
- crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
- }
- PromptFormat::Minimal
- | PromptFormat::MinimalQwen
- | PromptFormat::SeedCoder1120 => {
- if output_text.contains("--- a/\n+++ b/\nNo edits") {
- let edits = vec![];
- (&active_snapshot, edits)
- } else {
- crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
- }
- }
- PromptFormat::OldTextNewText => {
- crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
- .await?
- }
- _ => {
- bail!("unsupported prompt format {}", options.prompt_format)
- }
- };
-
- anyhow::Ok((
- Some((
- request_id,
- Some((
- inputs,
- active_buffer,
- active_snapshot.clone(),
- edits,
- received_response_at,
- )),
- )),
- usage,
- ))
- }
- });
-
- cx.spawn({
- async move |this, cx| {
- let Some((id, prediction)) =
- Self::handle_api_response(&this, request_task.await, cx)?
- else {
- return Ok(None);
- };
-
- let Some((
- inputs,
- edited_buffer,
- edited_buffer_snapshot,
- edits,
- received_response_at,
- )) = prediction
- else {
- return Ok(Some(EditPredictionResult {
- id,
- prediction: Err(EditPredictionRejectReason::Empty),
- }));
- };
-
- // TODO telemetry: duration, etc
- Ok(Some(
- EditPredictionResult::new(
- id,
- &edited_buffer,
- &edited_buffer_snapshot,
- edits.into(),
- buffer_snapshotted_at,
- received_response_at,
- inputs,
- cx,
- )
- .await,
- ))
- }
- })
- }
-
- async fn send_raw_llm_request(
- request: open_ai::Request,
- client: Arc<Client>,
- llm_token: LlmApiToken,
- app_version: Version,
- #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
- #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
- ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
- let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
- http_client::Url::parse(&predict_edits_url)?
- } else {
- client
- .http_client()
- .build_zed_llm_url("/predict_edits/raw", &[])?
- };
-
- #[cfg(feature = "eval-support")]
- let cache_key = if let Some(cache) = eval_cache {
- use collections::FxHasher;
- use std::hash::{Hash, Hasher};
-
- let mut hasher = FxHasher::default();
- url.hash(&mut hasher);
- let request_str = serde_json::to_string_pretty(&request)?;
- request_str.hash(&mut hasher);
- let hash = hasher.finish();
-
- let key = (eval_cache_kind, hash);
- if let Some(response_str) = cache.read(key) {
- return Ok((serde_json::from_str(&response_str)?, None));
- }
-
- Some((cache, request_str, key))
- } else {
- None
- };
-
- let (response, usage) = Self::send_api_request(
- |builder| {
- let req = builder
- .uri(url.as_ref())
- .body(serde_json::to_string(&request)?.into());
- Ok(req?)
- },
- client,
- llm_token,
- app_version,
- )
- .await?;
-
- #[cfg(feature = "eval-support")]
- if let Some((cache, request, key)) = cache_key {
- cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
- }
-
- Ok((response, usage))
- }
-
- fn handle_api_response<T>(
- this: &WeakEntity<Self>,
- response: Result<(T, Option<EditPredictionUsage>)>,
- cx: &mut gpui::AsyncApp,
- ) -> Result<T> {
- match response {
- Ok((data, usage)) => {
- if let Some(usage) = usage {
- this.update(cx, |this, cx| {
- this.user_store.update(cx, |user_store, cx| {
- user_store.update_edit_prediction_usage(usage, cx);
- });
- })
- .ok();
- }
- Ok(data)
- }
- Err(err) => {
- if err.is::<ZedUpdateRequiredError>() {
- cx.update(|cx| {
- this.update(cx, |this, _cx| {
- this.update_required = true;
- })
- .ok();
-
- let error_message: SharedString = err.to_string().into();
- show_app_notification(
- NotificationId::unique::<ZedUpdateRequiredError>(),
- cx,
- move |cx| {
- cx.new(|cx| {
- ErrorMessagePrompt::new(error_message.clone(), cx)
- .with_link_button("Update Zed", "https://zed.dev/releases")
- })
- },
- );
- })
- .ok();
- }
- Err(err)
- }
- }
- }
-
- async fn send_api_request<Res>(
- build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
- client: Arc<Client>,
- llm_token: LlmApiToken,
- app_version: Version,
- ) -> Result<(Res, Option<EditPredictionUsage>)>
- where
- Res: DeserializeOwned,
- {
- let http_client = client.http_client();
- let mut token = llm_token.acquire(&client).await?;
- let mut did_retry = false;
-
- loop {
- let request_builder = http_client::Request::builder().method(Method::POST);
-
- let request = build(
- request_builder
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", token))
- .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
- )?;
-
- let mut response = http_client.send(request).await?;
-
- if let Some(minimum_required_version) = response
- .headers()
- .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
- .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
- {
- anyhow::ensure!(
- app_version >= minimum_required_version,
- ZedUpdateRequiredError {
- minimum_version: minimum_required_version
- }
- );
- }
-
- if response.status().is_success() {
- let usage = EditPredictionUsage::from_headers(response.headers()).ok();
-
- let mut body = Vec::new();
- response.body_mut().read_to_end(&mut body).await?;
- return Ok((serde_json::from_slice(&body)?, usage));
- } else if !did_retry
- && response
- .headers()
- .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
- .is_some()
- {
- did_retry = true;
- token = llm_token.refresh(&client).await?;
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- anyhow::bail!(
- "Request failed with status: {:?}\nBody: {}",
- response.status(),
- body
- );
- }
- }
- }
-
- pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
- pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
-
- pub fn refresh_context_if_needed(
- &mut self,
- project: &Entity<Project>,
- buffer: &Entity<language::Buffer>,
- cursor_position: language::Anchor,
- cx: &mut Context<Self>,
- ) {
- if !self.use_context {
- return;
- }
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- match &mut zeta_project.context {
- ZetaProjectContext::Syntax(_entity) => {}
- ZetaProjectContext::Lsp(related_excerpt_store) => {
- related_excerpt_store.update(cx, |store, cx| {
- store.refresh(buffer.clone(), cursor_position, cx);
- });
- }
- ZetaProjectContext::Agentic {
- refresh_context_debounce_task,
- refresh_context_timestamp,
- ..
- } => {
- let now = Instant::now();
- let was_idle = refresh_context_timestamp.map_or(true, |timestamp| {
- now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
- });
- *refresh_context_timestamp = Some(now);
- *refresh_context_debounce_task = Some(cx.spawn({
- let buffer = buffer.clone();
- let project = project.clone();
- async move |this, cx| {
- if was_idle {
- log::debug!("refetching edit prediction context after idle");
- } else {
- cx.background_executor()
- .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
- .await;
- log::debug!("refetching edit prediction context after pause");
- }
- this.update(cx, |this, cx| {
- let task = this.refresh_context_with_agentic_retrieval(
- project.clone(),
- buffer,
- cursor_position,
- cx,
- );
-
- if let Some(zeta_project) = this.projects.get_mut(&project.entity_id())
- {
- if let ZetaProjectContext::Agentic {
- refresh_context_task,
- ..
- } = &mut zeta_project.context
- {
- *refresh_context_task = Some(task.log_err());
- }
- };
- })
- .ok()
- }
- }));
- }
- }
- }
-
- // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
- // and avoid spawning more than one concurrent task.
- pub fn refresh_context_with_agentic_retrieval(
- &mut self,
- project: Entity<Project>,
- buffer: Entity<language::Buffer>,
- cursor_position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
- return Task::ready(anyhow::Ok(()));
- };
-
- let ContextMode::Agentic(options) = &self.options().context else {
- return Task::ready(anyhow::Ok(()));
- };
-
- let snapshot = buffer.read(cx).snapshot();
- let cursor_point = cursor_position.to_point(&snapshot);
- let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- &snapshot,
- &options.excerpt,
- None,
- ) else {
- return Task::ready(Ok(()));
- };
-
- let app_version = AppVersion::global(cx);
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let debug_tx = self.debug_tx.clone();
- let current_file_path: Arc<Path> = snapshot
- .file()
- .map(|f| f.full_path(cx).into())
- .unwrap_or_else(|| Path::new("untitled").into());
-
- let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
- predict_edits_v3::PlanContextRetrievalRequest {
- excerpt: cursor_excerpt.text(&snapshot).body,
- excerpt_path: current_file_path,
- excerpt_line_range: cursor_excerpt.line_range,
- cursor_file_max_row: Line(snapshot.max_point().row),
- events: zeta_project.events(cx),
- },
- ) {
- Ok(prompt) => prompt,
- Err(err) => {
- return Task::ready(Err(err));
- }
- };
-
- let retrieval_started_at = Instant::now();
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
- ZetaContextRetrievalStartedDebugInfo {
- project_entity_id: project.entity_id(),
- timestamp: retrieval_started_at,
- search_prompt: prompt.clone(),
- },
- ))
- .ok();
- }
-
- pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
- let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
- language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
- );
-
- let description = schema
- .get("description")
- .and_then(|description| description.as_str())
- .unwrap()
- .to_string();
-
- (schema.into(), description)
- });
-
- let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
-
- let request = open_ai::Request {
- model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
- messages: vec![open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(prompt),
- }],
- stream: false,
- max_completion_tokens: None,
- stop: Default::default(),
- temperature: 0.7,
- tool_choice: None,
- parallel_tool_calls: None,
- tools: vec![open_ai::ToolDefinition::Function {
- function: FunctionDefinition {
- name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
- description: Some(tool_description),
- parameters: Some(tool_schema),
- },
- }],
- prompt_cache_key: None,
- reasoning_effort: None,
- };
-
- #[cfg(feature = "eval-support")]
- let eval_cache = self.eval_cache.clone();
-
- cx.spawn(async move |this, cx| {
- log::trace!("Sending search planning request");
- let response = Self::send_raw_llm_request(
- request,
- client,
- llm_token,
- app_version,
- #[cfg(feature = "eval-support")]
- eval_cache.clone(),
- #[cfg(feature = "eval-support")]
- EvalCacheEntryKind::Context,
- )
- .await;
- let mut response = Self::handle_api_response(&this, response, cx)?;
- log::trace!("Got search planning response");
-
- let choice = response
- .choices
- .pop()
- .context("No choices in retrieval response")?;
- let open_ai::RequestMessage::Assistant {
- content: _,
- tool_calls,
- } = choice.message
- else {
- anyhow::bail!("Retrieval response didn't include an assistant message");
- };
-
- let mut queries: Vec<SearchToolQuery> = Vec::new();
- for tool_call in tool_calls {
- let open_ai::ToolCallContent::Function { function } = tool_call.content;
- if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
- log::warn!(
- "Context retrieval response tried to call an unknown tool: {}",
- function.name
- );
-
- continue;
- }
-
- let input: SearchToolInput = serde_json::from_str(&function.arguments)
- .with_context(|| format!("invalid search json {}", &function.arguments))?;
- queries.extend(input.queries);
- }
-
- log::trace!("Running retrieval search: {queries:#?}");
- let query_generation_finished_at = Instant::now();
-
- let related_excerpts_result = retrieval_search::run_retrieval_searches(
- queries,
- project.clone(),
- #[cfg(feature = "eval-support")]
- eval_cache,
- cx,
- )
- .await;
-
- log::trace!("Search queries executed");
- let query_execution_finished_at = Instant::now();
-
- this.update(cx, |this, _cx| {
- let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
- return Ok(());
- };
- if let ZetaProjectContext::Agentic {
- refresh_context_task,
- context,
- ..
- } = &mut zeta_project.context
- {
- refresh_context_task.take();
- if let Some(debug_tx) = &this.debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
- ZetaContextRetrievalFinishedDebugInfo {
- project_entity_id: project.entity_id(),
- timestamp: Instant::now(),
- metadata: vec![
- (
- "query_generation",
- format!(
- "{:?}",
- query_generation_finished_at - retrieval_started_at
- )
- .into(),
- ),
- (
- "search_execution",
- format!(
- "{:?}",
- query_execution_finished_at
- - query_generation_finished_at
- )
- .into(),
- ),
- ],
- },
- ))
- .ok();
- }
- match related_excerpts_result {
- Ok(excerpts) => {
- *context = excerpts;
- Ok(())
- }
- Err(error) => Err(error),
- }
- } else {
- Ok(())
- }
- })?
- })
- }
-
- fn gather_nearby_diagnostics(
- cursor_offset: usize,
- diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
- snapshot: &BufferSnapshot,
- max_diagnostics_bytes: usize,
- ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
- // TODO: Could make this more efficient
- let mut diagnostic_groups = Vec::new();
- for (language_server_id, diagnostics) in diagnostic_sets {
- let mut groups = Vec::new();
- diagnostics.groups(*language_server_id, &mut groups, &snapshot);
- diagnostic_groups.extend(
- groups
- .into_iter()
- .map(|(_, group)| group.resolve::<usize>(&snapshot)),
- );
- }
-
- // sort by proximity to cursor
- diagnostic_groups.sort_by_key(|group| {
- let range = &group.entries[group.primary_ix].range;
- if range.start >= cursor_offset {
- range.start - cursor_offset
- } else if cursor_offset >= range.end {
- cursor_offset - range.end
- } else {
- (cursor_offset - range.start).min(range.end - cursor_offset)
- }
- });
-
- let mut results = Vec::new();
- let mut diagnostic_groups_truncated = false;
- let mut diagnostics_byte_count = 0;
- for group in diagnostic_groups {
- let raw_value = serde_json::value::to_raw_value(&group).unwrap();
- diagnostics_byte_count += raw_value.get().len();
- if diagnostics_byte_count > max_diagnostics_bytes {
- diagnostic_groups_truncated = true;
- break;
- }
- results.push(predict_edits_v3::DiagnosticGroup(raw_value));
- }
-
- (results, diagnostic_groups_truncated)
- }
-
- pub fn wait_for_initial_indexing(
- &mut self,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- let zeta_project = self.get_or_init_zeta_project(project, cx);
- if let ZetaProjectContext::Syntax(syntax_index) = &zeta_project.context {
- syntax_index.read(cx).wait_for_initial_file_indexing(cx)
- } else {
- Task::ready(Ok(()))
- }
- }
-
- fn is_file_open_source(
- &self,
- project: &Entity<Project>,
- file: &Arc<dyn File>,
- cx: &App,
- ) -> bool {
- if !file.is_local() || file.is_private() {
- return false;
- }
- let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
- return false;
- };
- zeta_project
- .license_detection_watchers
- .get(&file.worktree_id(cx))
- .as_ref()
- .is_some_and(|watcher| watcher.is_project_open_source())
- }
-
- fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
- self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
- }
-
- fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
- if !self.data_collection_choice.is_enabled() {
- return false;
- }
- events.iter().all(|event| {
- matches!(
- event.as_ref(),
- Event::BufferChange {
- in_open_source_repo: true,
- ..
- }
- )
- })
- }
-
- fn load_data_collection_choice() -> DataCollectionChoice {
- let choice = KEY_VALUE_STORE
- .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
- .log_err()
- .flatten();
-
- match choice.as_deref() {
- Some("true") => DataCollectionChoice::Enabled,
- Some("false") => DataCollectionChoice::Disabled,
- Some(_) => {
- log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
- DataCollectionChoice::NotAnswered
- }
- None => DataCollectionChoice::NotAnswered,
- }
- }
-
- pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
- self.shown_predictions.iter()
- }
-
- pub fn shown_completions_len(&self) -> usize {
- self.shown_predictions.len()
- }
-
- pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
- self.rated_predictions.contains(id)
- }
-
- pub fn rate_prediction(
- &mut self,
- prediction: &EditPrediction,
- rating: EditPredictionRating,
- feedback: String,
- cx: &mut Context<Self>,
- ) {
- self.rated_predictions.insert(prediction.id.clone());
- telemetry::event!(
- "Edit Prediction Rated",
- rating,
- inputs = prediction.inputs,
- output = prediction.edit_preview.as_unified_diff(&prediction.edits),
- feedback
- );
- self.client.telemetry().flush_events().detach();
- cx.notify();
- }
-
- fn enable_or_disable_context_retrieval(&mut self, cx: &mut Context<'_, Zeta>) {
- self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
- && all_language_settings(None, cx).edit_predictions.use_context;
- }
-}
-
-pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
- let choice = res.choices.pop()?;
- let output_text = match choice.message {
- open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(content)),
- ..
- } => content,
- open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Multipart(mut content)),
- ..
- } => {
- if content.is_empty() {
- log::error!("No output from Baseten completion response");
- return None;
- }
-
- match content.remove(0) {
- open_ai::MessagePart::Text { text } => text,
- open_ai::MessagePart::Image { .. } => {
- log::error!("Expected text, got an image");
- return None;
- }
- }
- }
- _ => {
- log::error!("Invalid response message: {:?}", choice.message);
- return None;
- }
- };
- Some(output_text)
-}
-
-#[derive(Error, Debug)]
-#[error(
- "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
-)]
-pub struct ZedUpdateRequiredError {
- minimum_version: Version,
-}
-
-#[cfg(feature = "eval-support")]
-pub type EvalCacheKey = (EvalCacheEntryKind, u64);
-
-#[cfg(feature = "eval-support")]
-#[derive(Debug, Clone, Copy, PartialEq)]
-pub enum EvalCacheEntryKind {
- Context,
- Search,
- Prediction,
-}
-
-#[cfg(feature = "eval-support")]
-impl std::fmt::Display for EvalCacheEntryKind {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- EvalCacheEntryKind::Search => write!(f, "search"),
- EvalCacheEntryKind::Context => write!(f, "context"),
- EvalCacheEntryKind::Prediction => write!(f, "prediction"),
- }
- }
-}
-
-#[cfg(feature = "eval-support")]
-pub trait EvalCache: Send + Sync {
- fn read(&self, key: EvalCacheKey) -> Option<String>;
- fn write(&self, key: EvalCacheKey, input: &str, value: &str);
-}
-
-#[derive(Debug, Clone, Copy)]
-pub enum DataCollectionChoice {
- NotAnswered,
- Enabled,
- Disabled,
-}
-
-impl DataCollectionChoice {
- pub fn is_enabled(self) -> bool {
- match self {
- Self::Enabled => true,
- Self::NotAnswered | Self::Disabled => false,
- }
- }
-
- pub fn is_answered(self) -> bool {
- match self {
- Self::Enabled | Self::Disabled => true,
- Self::NotAnswered => false,
- }
- }
-
- #[must_use]
- pub fn toggle(&self) -> DataCollectionChoice {
- match self {
- Self::Enabled => Self::Disabled,
- Self::Disabled => Self::Enabled,
- Self::NotAnswered => Self::Enabled,
- }
- }
-}
-
-impl From<bool> for DataCollectionChoice {
- fn from(value: bool) -> Self {
- match value {
- true => DataCollectionChoice::Enabled,
- false => DataCollectionChoice::Disabled,
- }
- }
-}
-
-struct ZedPredictUpsell;
-
-impl Dismissable for ZedPredictUpsell {
- const KEY: &'static str = "dismissed-edit-predict-upsell";
-
- fn dismissed() -> bool {
- // To make this backwards compatible with older versions of Zed, we
- // check if the user has seen the previous Edit Prediction Onboarding
- // before, by checking the data collection choice which was written to
- // the database once the user clicked on "Accept and Enable"
- if KEY_VALUE_STORE
- .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
- .log_err()
- .is_some_and(|s| s.is_some())
- {
- return true;
- }
-
- KEY_VALUE_STORE
- .read_kvp(Self::KEY)
- .log_err()
- .is_some_and(|s| s.is_some())
- }
-}
-
-pub fn should_show_upsell_modal() -> bool {
- !ZedPredictUpsell::dismissed()
-}
-
-pub fn init(cx: &mut App) {
- feature_gate_predict_edits_actions(cx);
-
- cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
- workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
- if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
- RatePredictionsModal::toggle(workspace, window, cx);
- }
- });
-
- workspace.register_action(
- move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
- ZedPredictModal::toggle(
- workspace,
- workspace.user_store().clone(),
- workspace.client().clone(),
- window,
- cx,
- )
- },
- );
-
- workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
- update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
- settings
- .project
- .all_languages
- .features
- .get_or_insert_default()
- .edit_prediction_provider = Some(EditPredictionProvider::None)
- });
- });
- })
- .detach();
-}
-
-fn feature_gate_predict_edits_actions(cx: &mut App) {
- let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
- let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
- let zeta_all_action_types = [
- TypeId::of::<RateCompletions>(),
- TypeId::of::<ResetOnboarding>(),
- zed_actions::OpenZedPredictOnboarding.type_id(),
- TypeId::of::<ClearHistory>(),
- TypeId::of::<ThumbsUpActivePrediction>(),
- TypeId::of::<ThumbsDownActivePrediction>(),
- TypeId::of::<NextEdit>(),
- TypeId::of::<PreviousEdit>(),
- ];
-
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.hide_action_types(&rate_completion_action_types);
- filter.hide_action_types(&reset_onboarding_action_types);
- filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
- });
-
- cx.observe_global::<SettingsStore>(move |cx| {
- let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
- let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();
-
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- if is_ai_disabled {
- filter.hide_action_types(&zeta_all_action_types);
- } else if has_feature_flag {
- filter.show_action_types(&rate_completion_action_types);
- } else {
- filter.hide_action_types(&rate_completion_action_types);
- }
- });
- })
- .detach();
-
- cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
- if !DisableAiSettings::get_global(cx).disable_ai {
- if is_enabled {
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.show_action_types(&rate_completion_action_types);
- });
- } else {
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.hide_action_types(&rate_completion_action_types);
- });
- }
- }
- })
- .detach();
-}
-
-#[cfg(test)]
-mod tests {
- use std::{path::Path, sync::Arc, time::Duration};
-
- use client::UserStore;
- use clock::FakeSystemClock;
- use cloud_llm_client::{
- EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
- };
- use futures::{
- AsyncReadExt, StreamExt,
- channel::{mpsc, oneshot},
- };
- use gpui::{
- Entity, TestAppContext,
- http_client::{FakeHttpClient, Response},
- prelude::*,
- };
- use indoc::indoc;
- use language::OffsetRangeExt as _;
- use lsp::LanguageServerId;
- use open_ai::Usage;
- use pretty_assertions::{assert_eq, assert_matches};
- use project::{FakeFs, Project};
- use serde_json::json;
- use settings::SettingsStore;
- use util::path;
- use uuid::Uuid;
-
- use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
-
- #[gpui::test]
- async fn test_current_state(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "1.txt": "Hello!\nHow\nBye\n",
- "2.txt": "Hola!\nComo\nAdios\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- zeta.update(cx, |zeta, cx| {
- zeta.register_project(&project, cx);
- });
-
- let buffer1 = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
- project.set_active_path(Some(path.clone()), cx);
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot1.anchor_before(language::Point::new(1, 3));
-
- // Prediction for current file
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
- });
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
-
- respond_tx
- .send(model_response(indoc! {r"
- --- a/root/1.txt
- +++ b/root/1.txt
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
- .unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- let prediction = zeta
- .current_prediction_for_buffer(&buffer1, &project, cx)
- .unwrap();
- assert_matches!(prediction, BufferEditPrediction::Local { .. });
- });
-
- zeta.update(cx, |zeta, _cx| {
- zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
- });
-
- // Prediction for diagnostic in another file
-
- let diagnostic = lsp::Diagnostic {
- range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
- severity: Some(lsp::DiagnosticSeverity::ERROR),
- message: "Sentence is incomplete".to_string(),
- ..Default::default()
- };
-
- project.update(cx, |project, cx| {
- project.lsp_store().update(cx, |lsp_store, cx| {
- lsp_store
- .update_diagnostics(
- LanguageServerId(0),
- lsp::PublishDiagnosticsParams {
- uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
- diagnostics: vec![diagnostic],
- version: None,
- },
- None,
- language::DiagnosticSourceKind::Pushed,
- &[],
- cx,
- )
- .unwrap();
- });
- });
-
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
- respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/2.txt
- +++ b/root/2.txt
- Hola!
- -Como
- +Como estas?
- Adios
- "#}))
- .unwrap();
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- let prediction = zeta
- .current_prediction_for_buffer(&buffer1, &project, cx)
- .unwrap();
- assert_matches!(
- prediction,
- BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
- );
- });
-
- let buffer2 = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
-
- zeta.read_with(cx, |zeta, cx| {
- let prediction = zeta
- .current_prediction_for_buffer(&buffer2, &project, cx)
- .unwrap();
- assert_matches!(prediction, BufferEditPrediction::Local { .. });
- });
- }
-
- #[gpui::test]
- async fn test_simple_request(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- let prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
-
- // TODO Put back when we have a structured request again
- // assert_eq!(
- // request.excerpt_path.as_ref(),
- // Path::new(path!("root/foo.md"))
- // );
- // assert_eq!(
- // request.cursor_point,
- // Point {
- // line: Line(1),
- // column: 3
- // }
- // );
-
- respond_tx
- .send(model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
- .unwrap();
-
- let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
-
- assert_eq!(prediction.edits.len(), 1);
- assert_eq!(
- prediction.edits[0].0.to_point(&snapshot).start,
- language::Point::new(1, 3)
- );
- assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
- }
-
- #[gpui::test]
- async fn test_request_events(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\n\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
-
- zeta.update(cx, |zeta, cx| {
- zeta.register_buffer(&buffer, &project, cx);
- });
-
- buffer.update(cx, |buffer, cx| {
- buffer.edit(vec![(7..7, "How")], None, cx);
- });
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- let prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
- });
-
- let (request, respond_tx) = requests.predict.next().await.unwrap();
-
- let prompt = prompt_from_request(&request);
- assert!(
- prompt.contains(indoc! {"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ -1,3 +1,3 @@
- Hello!
- -
- +How
- Bye
- "}),
- "{prompt}"
- );
-
- respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "#}))
- .unwrap();
-
- let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
-
- assert_eq!(prediction.edits.len(), 1);
- assert_eq!(
- prediction.edits[0].0.to_point(&snapshot).start,
- language::Point::new(1, 3)
- );
- assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
- }
-
- #[gpui::test]
- async fn test_empty_prediction(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- const NO_OP_DIFF: &str = indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How
- Bye
- "};
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let response = model_response(NO_OP_DIFF);
- let id = response.id.clone();
- respond_tx.send(response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- assert!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .is_none()
- );
- });
-
- // prediction is reported as rejected
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- assert_eq!(
- &reject_request.rejections,
- &[EditPredictionRejection {
- request_id: id,
- reason: EditPredictionRejectReason::Empty,
- was_shown: false
- }]
- );
- }
-
- #[gpui::test]
- async fn test_interpolated_empty(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
-
- buffer.update(cx, |buffer, cx| {
- buffer.set_text("Hello!\nHow are you?\nBye", cx);
- });
-
- let response = model_response(SIMPLE_DIFF);
- let id = response.id.clone();
- respond_tx.send(response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- assert!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .is_none()
- );
- });
-
- // prediction is reported as rejected
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- assert_eq!(
- &reject_request.rejections,
- &[EditPredictionRejection {
- request_id: id,
- reason: EditPredictionRejectReason::InterpolatedEmpty,
- was_shown: false
- }]
- );
- }
-
- const SIMPLE_DIFF: &str = indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "};
-
- #[gpui::test]
- async fn test_replace_current(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
- let first_id = first_response.id.clone();
- respond_tx.send(first_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- first_id
- );
- });
-
- // a second request is triggered
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let second_response = model_response(SIMPLE_DIFF);
- let second_id = second_response.id.clone();
- respond_tx.send(second_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // second replaces first
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- second_id
- );
- });
-
- // first is reported as replaced
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- assert_eq!(
- &reject_request.rejections,
- &[EditPredictionRejection {
- request_id: first_id,
- reason: EditPredictionRejectReason::Replaced,
- was_shown: false
- }]
- );
- }
-
- #[gpui::test]
- async fn test_current_preferred(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
- let first_id = first_response.id.clone();
- respond_tx.send(first_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- first_id
- );
- });
-
- // a second request is triggered
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- // worse than current prediction
- let second_response = model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are
- Bye
- "});
- let second_id = second_response.id.clone();
- respond_tx.send(second_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // first is preferred over second
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- first_id
- );
- });
-
- // second is reported as rejected
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- assert_eq!(
- &reject_request.rejections,
- &[EditPredictionRejection {
- request_id: second_id,
- reason: EditPredictionRejectReason::CurrentPreferred,
- was_shown: false
- }]
- );
- }
-
- #[gpui::test]
- async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- // start two refresh tasks
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_first) = requests.predict.next().await.unwrap();
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_second) = requests.predict.next().await.unwrap();
-
- // wait for throttle
- cx.run_until_parked();
-
- // second responds first
- let second_response = model_response(SIMPLE_DIFF);
- let second_id = second_response.id.clone();
- respond_second.send(second_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // current prediction is second
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- second_id
- );
- });
-
- let first_response = model_response(SIMPLE_DIFF);
- let first_id = first_response.id.clone();
- respond_first.send(first_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // current prediction is still second, since first was cancelled
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- second_id
- );
- });
-
- // first is reported as rejected
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- cx.run_until_parked();
-
- assert_eq!(
- &reject_request.rejections,
- &[EditPredictionRejection {
- request_id: first_id,
- reason: EditPredictionRejectReason::Canceled,
- was_shown: false
- }]
- );
- }
-
- #[gpui::test]
- async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- // start two refresh tasks
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_first) = requests.predict.next().await.unwrap();
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_second) = requests.predict.next().await.unwrap();
-
- // wait for throttle, so requests are sent
- cx.run_until_parked();
-
- zeta.update(cx, |zeta, cx| {
- // start a third request
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
-
- // 2 are pending, so 2nd is cancelled
- assert_eq!(
- zeta.get_or_init_zeta_project(&project, cx)
- .cancelled_predictions
- .iter()
- .copied()
- .collect::<Vec<_>>(),
- [1]
- );
- });
-
- // wait for throttle
- cx.run_until_parked();
-
- let (_, respond_third) = requests.predict.next().await.unwrap();
-
- let first_response = model_response(SIMPLE_DIFF);
- let first_id = first_response.id.clone();
- respond_first.send(first_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // current prediction is first
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- first_id
- );
- });
-
- let cancelled_response = model_response(SIMPLE_DIFF);
- let cancelled_id = cancelled_response.id.clone();
- respond_second.send(cancelled_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // current prediction is still first, since second was cancelled
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- first_id
- );
- });
-
- let third_response = model_response(SIMPLE_DIFF);
- let third_response_id = third_response.id.clone();
- respond_third.send(third_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // third completes and replaces first
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- third_response_id
- );
- });
-
- // second is reported as rejected
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- cx.run_until_parked();
-
- assert_eq!(
- &reject_request.rejections,
- &[
- EditPredictionRejection {
- request_id: cancelled_id,
- reason: EditPredictionRejectReason::Canceled,
- was_shown: false
- },
- EditPredictionRejection {
- request_id: first_id,
- reason: EditPredictionRejectReason::Replaced,
- was_shown: false
- }
- ]
- );
- }
-
- #[gpui::test]
- async fn test_rejections_flushing(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
-
- zeta.update(cx, |zeta, _cx| {
- zeta.reject_prediction(
- EditPredictionId("test-1".into()),
- EditPredictionRejectReason::Discarded,
- false,
- );
- zeta.reject_prediction(
- EditPredictionId("test-2".into()),
- EditPredictionRejectReason::Canceled,
- true,
- );
- });
-
- cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
- cx.run_until_parked();
-
- let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
- respond_tx.send(()).unwrap();
-
- // batched
- assert_eq!(reject_request.rejections.len(), 2);
- assert_eq!(
- reject_request.rejections[0],
- EditPredictionRejection {
- request_id: "test-1".to_string(),
- reason: EditPredictionRejectReason::Discarded,
- was_shown: false
- }
- );
- assert_eq!(
- reject_request.rejections[1],
- EditPredictionRejection {
- request_id: "test-2".to_string(),
- reason: EditPredictionRejectReason::Canceled,
- was_shown: true
- }
- );
-
- // Reaching batch size limit sends without debounce
- zeta.update(cx, |zeta, _cx| {
- for i in 0..70 {
- zeta.reject_prediction(
- EditPredictionId(format!("batch-{}", i).into()),
- EditPredictionRejectReason::Discarded,
- false,
- );
- }
- });
-
- // First MAX/2 items are sent immediately
- cx.run_until_parked();
- let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
- respond_tx.send(()).unwrap();
-
- assert_eq!(reject_request.rejections.len(), 50);
- assert_eq!(reject_request.rejections[0].request_id, "batch-0");
- assert_eq!(reject_request.rejections[49].request_id, "batch-49");
-
- // Remaining items are debounced with the next batch
- cx.executor().advance_clock(Duration::from_secs(15));
- cx.run_until_parked();
-
- let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
- respond_tx.send(()).unwrap();
-
- assert_eq!(reject_request.rejections.len(), 20);
- assert_eq!(reject_request.rejections[0].request_id, "batch-50");
- assert_eq!(reject_request.rejections[19].request_id, "batch-69");
-
- // Request failure
- zeta.update(cx, |zeta, _cx| {
- zeta.reject_prediction(
- EditPredictionId("retry-1".into()),
- EditPredictionRejectReason::Discarded,
- false,
- );
- });
-
- cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
- cx.run_until_parked();
-
- let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
- assert_eq!(reject_request.rejections.len(), 1);
- assert_eq!(reject_request.rejections[0].request_id, "retry-1");
- // Simulate failure
- drop(_respond_tx);
-
- // Add another rejection
- zeta.update(cx, |zeta, _cx| {
- zeta.reject_prediction(
- EditPredictionId("retry-2".into()),
- EditPredictionRejectReason::Discarded,
- false,
- );
- });
-
- cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
- cx.run_until_parked();
-
- // Retry should include both the failed item and the new one
- let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
- respond_tx.send(()).unwrap();
-
- assert_eq!(reject_request.rejections.len(), 2);
- assert_eq!(reject_request.rejections[0].request_id, "retry-1");
- assert_eq!(reject_request.rejections[1].request_id, "retry-2");
- }
-
- // Skipped until we start including diagnostics in prompt
- // #[gpui::test]
- // async fn test_request_diagnostics(cx: &mut TestAppContext) {
- // let (zeta, mut req_rx) = init_test(cx);
- // let fs = FakeFs::new(cx.executor());
- // fs.insert_tree(
- // "/root",
- // json!({
- // "foo.md": "Hello!\nBye"
- // }),
- // )
- // .await;
- // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
- // let diagnostic = lsp::Diagnostic {
- // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
- // severity: Some(lsp::DiagnosticSeverity::ERROR),
- // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
- // ..Default::default()
- // };
-
- // project.update(cx, |project, cx| {
- // project.lsp_store().update(cx, |lsp_store, cx| {
- // // Create some diagnostics
- // lsp_store
- // .update_diagnostics(
- // LanguageServerId(0),
- // lsp::PublishDiagnosticsParams {
- // uri: path_to_buffer_uri.clone(),
- // diagnostics: vec![diagnostic],
- // version: None,
- // },
- // None,
- // language::DiagnosticSourceKind::Pushed,
- // &[],
- // cx,
- // )
- // .unwrap();
- // });
- // });
-
- // let buffer = project
- // .update(cx, |project, cx| {
- // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- // project.open_buffer(path, cx)
- // })
- // .await
- // .unwrap();
-
- // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- // let position = snapshot.anchor_before(language::Point::new(0, 0));
-
- // let _prediction_task = zeta.update(cx, |zeta, cx| {
- // zeta.request_prediction(&project, &buffer, position, cx)
- // });
-
- // let (request, _respond_tx) = req_rx.next().await.unwrap();
-
- // assert_eq!(request.diagnostic_groups.len(), 1);
- // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
- // .unwrap();
- // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
- // assert_eq!(
- // value,
- // json!({
- // "entries": [{
- // "range": {
- // "start": 8,
- // "end": 10
- // },
- // "diagnostic": {
- // "source": null,
- // "code": null,
- // "code_description": null,
- // "severity": 1,
- // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
- // "markdown": null,
- // "group_id": 0,
- // "is_primary": true,
- // "is_disk_based": false,
- // "is_unnecessary": false,
- // "source_kind": "Pushed",
- // "data": null,
- // "underline": true
- // }
- // }],
- // "primary_ix": 0
- // })
- // );
- // }
-
- fn model_response(text: &str) -> open_ai::Response {
- open_ai::Response {
- id: Uuid::new_v4().to_string(),
- object: "response".into(),
- created: 0,
- model: "model".into(),
- choices: vec![open_ai::Choice {
- index: 0,
- message: open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(text.to_string())),
- tool_calls: vec![],
- },
- finish_reason: None,
- }],
- usage: Usage {
- prompt_tokens: 0,
- completion_tokens: 0,
- total_tokens: 0,
- },
- }
- }
-
- fn prompt_from_request(request: &open_ai::Request) -> &str {
- assert_eq!(request.messages.len(), 1);
- let open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(content),
- ..
- } = &request.messages[0]
- else {
- panic!(
- "Request does not have single user message of type Plain. {:#?}",
- request
- );
- };
- content
- }
-
- struct RequestChannels {
- predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
- reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
- }
-
- fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
- cx.update(move |cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- zlog::init_test();
-
- let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
- let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
-
- let http_client = FakeHttpClient::create({
- move |req| {
- let uri = req.uri().path().to_string();
- let mut body = req.into_body();
- let predict_req_tx = predict_req_tx.clone();
- let reject_req_tx = reject_req_tx.clone();
- async move {
- let resp = match uri.as_str() {
- "/client/llm_tokens" => serde_json::to_string(&json!({
- "token": "test"
- }))
- .unwrap(),
- "/predict_edits/raw" => {
- let mut buf = Vec::new();
- body.read_to_end(&mut buf).await.ok();
- let req = serde_json::from_slice(&buf).unwrap();
- let (res_tx, res_rx) = oneshot::channel();
- predict_req_tx.unbounded_send((req, res_tx)).unwrap();
- serde_json::to_string(&res_rx.await?).unwrap()
- }
- "/predict_edits/reject" => {
- let mut buf = Vec::new();
- body.read_to_end(&mut buf).await.ok();
- let req = serde_json::from_slice(&buf).unwrap();
-
- let (res_tx, res_rx) = oneshot::channel();
- reject_req_tx.unbounded_send((req, res_tx)).unwrap();
- serde_json::to_string(&res_rx.await?).unwrap()
- }
- _ => {
- panic!("Unexpected path: {}", uri)
- }
- };
-
- Ok(Response::builder().body(resp.into()).unwrap())
- }
- }
- });
-
- let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
- client.cloud_client().set_credentials(1, "test".into());
-
- language_model::init(client.clone(), cx);
-
- let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- let zeta = Zeta::global(&client, &user_store, cx);
-
- (
- zeta,
- RequestChannels {
- predict: predict_req_rx,
- reject: reject_req_rx,
- },
- )
- })
- }
-}
@@ -1,671 +0,0 @@
-use client::test::FakeServer;
-use clock::{FakeSystemClock, ReplicaId};
-use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
-use cloud_llm_client::{PredictEditsBody, PredictEditsResponse};
-use gpui::TestAppContext;
-use http_client::FakeHttpClient;
-use indoc::indoc;
-use language::Point;
-use parking_lot::Mutex;
-use serde_json::json;
-use settings::SettingsStore;
-use util::{path, rel_path::rel_path};
-
-use crate::zeta1::MAX_EVENT_TOKENS;
-
-use super::*;
-
-const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
-
-#[gpui::test]
-async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
- let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
- let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
- to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
- });
-
- let edit_preview = cx
- .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
- .await;
-
- let completion = EditPrediction {
- edits,
- edit_preview,
- buffer: buffer.clone(),
- snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
- id: EditPredictionId("the-id".into()),
- inputs: EditPredictionInputs {
- events: Default::default(),
- included_files: Default::default(),
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- line: Line(0),
- column: 0,
- },
- cursor_path: Path::new("").into(),
- },
- buffer_snapshotted_at: Instant::now(),
- response_received_at: Instant::now(),
- };
-
- cx.update(|cx| {
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..5, "REM".into()), (9..11, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..2, "REM".into()), (6..8, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.undo(cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..5, "REM".into()), (9..11, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(3..3, "EM".into()), (7..9, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".into()), (8..10, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(9..11, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".into()), (8..10, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
- assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
- })
-}
-
-#[gpui::test]
-async fn test_clean_up_diff(cx: &mut TestAppContext) {
- init_test(cx);
-
- assert_eq!(
- apply_edit_prediction(
- indoc! {"
- fn main() {
- let word_1 = \"lorem\";
- let range = word.len()..word.len();
- }
- "},
- indoc! {"
- <|editable_region_start|>
- fn main() {
- let word_1 = \"lorem\";
- let range = word_1.len()..word_1.len();
- }
-
- <|editable_region_end|>
- "},
- cx,
- )
- .await,
- indoc! {"
- fn main() {
- let word_1 = \"lorem\";
- let range = word_1.len()..word_1.len();
- }
- "},
- );
-
- assert_eq!(
- apply_edit_prediction(
- indoc! {"
- fn main() {
- let story = \"the quick\"
- }
- "},
- indoc! {"
- <|editable_region_start|>
- fn main() {
- let story = \"the quick brown fox jumps over the lazy dog\";
- }
-
- <|editable_region_end|>
- "},
- cx,
- )
- .await,
- indoc! {"
- fn main() {
- let story = \"the quick brown fox jumps over the lazy dog\";
- }
- "},
- );
-}
-
-#[gpui::test]
-async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
- init_test(cx);
-
- let buffer_content = "lorem\n";
- let completion_response = indoc! {"
- ```animals.js
- <|start_of_file|>
- <|editable_region_start|>
- lorem
- ipsum
- <|editable_region_end|>
- ```"};
-
- assert_eq!(
- apply_edit_prediction(buffer_content, completion_response, cx).await,
- "lorem\nipsum"
- );
-}
-
-#[gpui::test]
-async fn test_can_collect_data(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
- .await;
-
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/project/src/main.rs"), cx)
- })
- .await
- .unwrap();
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Disabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [], cx).await;
-
- let buffer = cx.new(|_cx| {
- Buffer::remote(
- language::BufferId::new(1).unwrap(),
- ReplicaId::new(1),
- language::Capability::ReadWrite,
- "fn main() {\n println!(\"Hello\");\n}",
- )
- });
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/project"),
- json!({
- "LICENSE": BSD_0_TXT,
- ".env": "SECRET_KEY=secret"
- }),
- )
- .await;
-
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer("/project/.env", cx)
- })
- .await
- .unwrap();
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [], cx).await;
- let buffer = cx.new(|cx| Buffer::local("", cx));
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
- .await;
-
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer("/project/main.rs", cx)
- })
- .await
- .unwrap();
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/open_source_worktree"),
- json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
- )
- .await;
- fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
- .await;
-
- let project = Project::test(
- fs.clone(),
- [
- path!("/open_source_worktree").as_ref(),
- path!("/closed_source_worktree").as_ref(),
- ],
- cx,
- )
- .await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
- })
- .await
- .unwrap();
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-
- let closed_source_file = project
- .update(cx, |project, cx| {
- let worktree2 = project
- .worktree_for_root_name("closed_source_worktree", cx)
- .unwrap();
- worktree2.update(cx, |worktree2, cx| {
- worktree2.load_file(rel_path("main.rs"), cx)
- })
- })
- .await
- .unwrap()
- .file;
-
- buffer.update(cx, |buffer, cx| {
- buffer.file_updated(closed_source_file, cx);
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/worktree1"),
- json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
- )
- .await;
- fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
- .await;
-
- let project = Project::test(
- fs.clone(),
- [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
- cx,
- )
- .await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/worktree1/main.rs"), cx)
- })
- .await
- .unwrap();
- let private_buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/worktree2/file.rs"), cx)
- })
- .await
- .unwrap();
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-
- // this has a side effect of registering the buffer to watch for edits
- run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-
- private_buffer.update(cx, |private_buffer, cx| {
- private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-
- // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
- // included
- buffer.update(cx, |buffer, cx| {
- buffer.edit(
- [(
- 0..0,
- " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS),
- )],
- None,
- cx,
- );
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-}
-
-fn init_test(cx: &mut TestAppContext) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- });
-}
-
-async fn apply_edit_prediction(
- buffer_content: &str,
- completion_response: &str,
- cx: &mut TestAppContext,
-) -> String {
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
- let (zeta, _, response) = make_test_zeta(&project, cx).await;
- *response.lock() = completion_response.to_string();
- let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
- buffer.update(cx, |buffer, cx| {
- buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
- });
- buffer.read_with(cx, |buffer, _| buffer.text())
-}
-
-async fn run_edit_prediction(
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- zeta: &Entity<Zeta>,
- cx: &mut TestAppContext,
-) -> EditPrediction {
- let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
- zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
- cx.background_executor.run_until_parked();
- let prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, buffer, cursor, Default::default(), cx)
- });
- prediction_task.await.unwrap().unwrap().prediction.unwrap()
-}
-
-async fn make_test_zeta(
- project: &Entity<Project>,
- cx: &mut TestAppContext,
-) -> (
- Entity<Zeta>,
- Arc<Mutex<Option<PredictEditsBody>>>,
- Arc<Mutex<String>>,
-) {
- let default_response = indoc! {"
- ```main.rs
- <|start_of_file|>
- <|editable_region_start|>
- hello world
- <|editable_region_end|>
- ```"
- };
- let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
- let completion_response: Arc<Mutex<String>> =
- Arc::new(Mutex::new(default_response.to_string()));
- let http_client = FakeHttpClient::create({
- let captured_request = captured_request.clone();
- let completion_response = completion_response.clone();
- let mut next_request_id = 0;
- move |req| {
- let captured_request = captured_request.clone();
- let completion_response = completion_response.clone();
- async move {
- match (req.method(), req.uri().path()) {
- (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
- .status(200)
- .body(
- serde_json::to_string(&CreateLlmTokenResponse {
- token: LlmToken("the-llm-token".to_string()),
- })
- .unwrap()
- .into(),
- )
- .unwrap()),
- (&Method::POST, "/predict_edits/v2") => {
- let mut request_body = String::new();
- req.into_body().read_to_string(&mut request_body).await?;
- *captured_request.lock() =
- Some(serde_json::from_str(&request_body).unwrap());
- next_request_id += 1;
- Ok(http_client::Response::builder()
- .status(200)
- .body(
- serde_json::to_string(&PredictEditsResponse {
- request_id: format!("request-{next_request_id}"),
- output_excerpt: completion_response.lock().clone(),
- })
- .unwrap()
- .into(),
- )
- .unwrap())
- }
- _ => Ok(http_client::Response::builder()
- .status(404)
- .body("Not Found".into())
- .unwrap()),
- }
- }
- }
- });
-
- let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
- cx.update(|cx| {
- RefreshLlmTokenListener::register(client.clone(), cx);
- });
- let _server = FakeServer::for_client(42, &client, cx).await;
-
- let zeta = cx.new(|cx| {
- let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);
- zeta.set_edit_prediction_model(ZetaEditPredictionModel::Zeta1);
-
- let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
- for worktree in worktrees {
- let worktree_id = worktree.read(cx).id();
- zeta.get_or_init_zeta_project(project, cx)
- .license_detection_watchers
- .entry(worktree_id)
- .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
- }
-
- zeta
- });
-
- (zeta, captured_request, completion_response)
-}
-
-fn to_completion_edits(
- iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
- buffer: &Entity<Buffer>,
- cx: &App,
-) -> Vec<(Range<Anchor>, Arc<str>)> {
- let buffer = buffer.read(cx);
- iterator
- .into_iter()
- .map(|(range, text)| {
- (
- buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
- text,
- )
- })
- .collect()
-}
-
-fn from_completion_edits(
- editor_edits: &[(Range<Anchor>, Arc<str>)],
- buffer: &Entity<Buffer>,
- cx: &App,
-) -> Vec<(Range<usize>, Arc<str>)> {
- let buffer = buffer.read(cx);
- editor_edits
- .iter()
- .map(|(range, text)| {
- (
- range.start.to_offset(buffer)..range.end.to_offset(buffer),
- text.clone(),
- )
- })
- .collect()
-}
-
-#[ctor::ctor]
-fn init_logger() {
- zlog::init_test();
-}
@@ -1,48 +0,0 @@
-[package]
-name = "zeta2_tools"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/zeta2_tools.rs"
-
-[dependencies]
-anyhow.workspace = true
-client.workspace = true
-cloud_llm_client.workspace = true
-collections.workspace = true
-edit_prediction_context.workspace = true
-editor.workspace = true
-feature_flags.workspace = true
-futures.workspace = true
-gpui.workspace = true
-language.workspace = true
-multi_buffer.workspace = true
-project.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-telemetry.workspace = true
-text.workspace = true
-ui.workspace = true
-ui_input.workspace = true
-util.workspace = true
-workspace.workspace = true
-zeta.workspace = true
-
-[dev-dependencies]
-clap.workspace = true
-gpui = { workspace = true, features = ["test-support"] }
-indoc.workspace = true
-language = { workspace = true, features = ["test-support"] }
-pretty_assertions.workspace = true
-project = { workspace = true, features = ["test-support"] }
-serde_json.workspace = true
-settings = { workspace = true, features = ["test-support"] }
-text = { workspace = true, features = ["test-support"] }
-util = { workspace = true, features = ["test-support"] }
-zlog.workspace = true
@@ -1 +0,0 @@
-../../LICENSE-GPL
@@ -1,1035 +0,0 @@
-mod zeta2_context_view;
-
-use std::{str::FromStr, sync::Arc, time::Duration};
-
-use client::{Client, UserStore};
-use cloud_llm_client::predict_edits_v3::PromptFormat;
-use collections::HashMap;
-use editor::{Editor, EditorEvent, EditorMode, MultiBuffer};
-use feature_flags::FeatureFlagAppExt as _;
-use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared};
-use gpui::{
- Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
- prelude::*,
-};
-use language::Buffer;
-use project::{Project, telemetry_snapshot::TelemetrySnapshot};
-use ui::{ButtonLike, ContextMenu, ContextMenuEntry, DropdownMenu, KeyBinding, prelude::*};
-use ui_input::InputField;
-use util::ResultExt;
-use workspace::{Item, SplitDirection, Workspace};
-use zeta::{
- AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, EditPredictionInputs, Zeta,
- Zeta2FeatureFlag, ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions,
-};
-
-use edit_prediction_context::{EditPredictionContextOptions, EditPredictionExcerptOptions};
-use zeta2_context_view::Zeta2ContextView;
-
-actions!(
- dev,
- [
- /// Opens the edit prediction context view.
- OpenZeta2ContextView,
- /// Opens the edit prediction inspector.
- OpenZeta2Inspector,
- /// Rate prediction as positive.
- Zeta2RatePredictionPositive,
- /// Rate prediction as negative.
- Zeta2RatePredictionNegative,
- ]
-);
-
-pub fn init(cx: &mut App) {
- cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
- workspace.register_action_renderer(|div, _, _, cx| {
- let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
- div.when(has_flag, |div| {
- div.on_action(
- cx.listener(move |workspace, _: &OpenZeta2Inspector, window, cx| {
- let project = workspace.project();
- workspace.split_item(
- SplitDirection::Right,
- Box::new(cx.new(|cx| {
- Zeta2Inspector::new(
- &project,
- workspace.client(),
- workspace.user_store(),
- window,
- cx,
- )
- })),
- window,
- cx,
- )
- }),
- )
- .on_action(cx.listener(
- move |workspace, _: &OpenZeta2ContextView, window, cx| {
- let project = workspace.project();
- workspace.split_item(
- SplitDirection::Right,
- Box::new(cx.new(|cx| {
- Zeta2ContextView::new(
- project.clone(),
- workspace.client(),
- workspace.user_store(),
- window,
- cx,
- )
- })),
- window,
- cx,
- );
- },
- ))
- })
- });
- })
- .detach();
-}
-
-// TODO show included diagnostics, and events
-
-pub struct Zeta2Inspector {
- focus_handle: FocusHandle,
- project: Entity<Project>,
- last_prediction: Option<LastPrediction>,
- max_excerpt_bytes_input: Entity<InputField>,
- min_excerpt_bytes_input: Entity<InputField>,
- cursor_context_ratio_input: Entity<InputField>,
- max_prompt_bytes_input: Entity<InputField>,
- context_mode: ContextModeState,
- zeta: Entity<Zeta>,
- _active_editor_subscription: Option<Subscription>,
- _update_state_task: Task<()>,
- _receive_task: Task<()>,
-}
-
-pub enum ContextModeState {
- Llm,
- Lsp,
- Syntax {
- max_retrieved_declarations: Entity<InputField>,
- },
-}
-
-struct LastPrediction {
- prompt_editor: Entity<Editor>,
- retrieval_time: Duration,
- request_time: Option<Duration>,
- buffer: WeakEntity<Buffer>,
- position: language::Anchor,
- state: LastPredictionState,
- inputs: EditPredictionInputs,
- project_snapshot: Shared<Task<Arc<TelemetrySnapshot>>>,
- _task: Option<Task<()>>,
-}
-
-#[derive(Clone, Copy, PartialEq)]
-enum Feedback {
- Positive,
- Negative,
-}
-
-enum LastPredictionState {
- Requested,
- Success {
- model_response_editor: Entity<Editor>,
- feedback_editor: Entity<Editor>,
- feedback: Option<Feedback>,
- request_id: String,
- },
- Failed {
- message: String,
- },
-}
-
-impl Zeta2Inspector {
- pub fn new(
- project: &Entity<Project>,
- client: &Arc<Client>,
- user_store: &Entity<UserStore>,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) -> Self {
- let zeta = Zeta::global(client, user_store, cx);
- let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info());
-
- let receive_task = cx.spawn_in(window, async move |this, cx| {
- while let Some(prediction) = request_rx.next().await {
- this.update_in(cx, |this, window, cx| {
- this.update_last_prediction(prediction, window, cx)
- })
- .ok();
- }
- });
-
- let mut this = Self {
- focus_handle: cx.focus_handle(),
- project: project.clone(),
- last_prediction: None,
- max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx),
- min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx),
- cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx),
- max_prompt_bytes_input: Self::number_input("Max Prompt Bytes", window, cx),
- context_mode: ContextModeState::Llm,
- zeta: zeta.clone(),
- _active_editor_subscription: None,
- _update_state_task: Task::ready(()),
- _receive_task: receive_task,
- };
- this.set_options_state(&zeta.read(cx).options().clone(), window, cx);
- this
- }
-
- fn set_options_state(
- &mut self,
- options: &ZetaOptions,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- let excerpt_options = options.context.excerpt();
- self.max_excerpt_bytes_input.update(cx, |input, cx| {
- input.set_text(excerpt_options.max_bytes.to_string(), window, cx);
- });
- self.min_excerpt_bytes_input.update(cx, |input, cx| {
- input.set_text(excerpt_options.min_bytes.to_string(), window, cx);
- });
- self.cursor_context_ratio_input.update(cx, |input, cx| {
- input.set_text(
- format!(
- "{:.2}",
- excerpt_options.target_before_cursor_over_total_bytes
- ),
- window,
- cx,
- );
- });
- self.max_prompt_bytes_input.update(cx, |input, cx| {
- input.set_text(options.max_prompt_bytes.to_string(), window, cx);
- });
-
- match &options.context {
- ContextMode::Agentic(_) => {
- self.context_mode = ContextModeState::Llm;
- }
- ContextMode::Syntax(_) => {
- self.context_mode = ContextModeState::Syntax {
- max_retrieved_declarations: Self::number_input(
- "Max Retrieved Definitions",
- window,
- cx,
- ),
- };
- }
- ContextMode::Lsp(_) => {
- self.context_mode = ContextModeState::Lsp;
- }
- }
- cx.notify();
- }
-
- fn set_zeta_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
- self.zeta.update(cx, |this, _cx| this.set_options(options));
-
- if let Some(prediction) = self.last_prediction.as_mut() {
- if let Some(buffer) = prediction.buffer.upgrade() {
- let position = prediction.position;
- let project = self.project.clone();
- self.zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project, buffer, position, cx)
- });
- prediction.state = LastPredictionState::Requested;
- } else {
- self.last_prediction.take();
- }
- }
-
- cx.notify();
- }
-
- fn number_input(
- label: &'static str,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) -> Entity<InputField> {
- let input = cx.new(|cx| {
- InputField::new(window, cx, "")
- .label(label)
- .label_min_width(px(64.))
- });
-
- cx.subscribe_in(
- &input.read(cx).editor().clone(),
- window,
- |this, _, event, _window, cx| {
- let EditorEvent::BufferEdited = event else {
- return;
- };
-
- fn number_input_value<T: FromStr + Default>(
- input: &Entity<InputField>,
- cx: &App,
- ) -> T {
- input
- .read(cx)
- .editor()
- .read(cx)
- .text(cx)
- .parse::<T>()
- .unwrap_or_default()
- }
-
- let zeta_options = this.zeta.read(cx).options().clone();
-
- let excerpt_options = EditPredictionExcerptOptions {
- max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx),
- min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx),
- target_before_cursor_over_total_bytes: number_input_value(
- &this.cursor_context_ratio_input,
- cx,
- ),
- };
-
- let context = match zeta_options.context {
- ContextMode::Agentic(_context_options) => {
- ContextMode::Agentic(AgenticContextOptions {
- excerpt: excerpt_options,
- })
- }
- ContextMode::Syntax(context_options) => {
- let max_retrieved_declarations = match &this.context_mode {
- ContextModeState::Llm => {
- zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations
- }
- ContextModeState::Syntax {
- max_retrieved_declarations,
- } => number_input_value(max_retrieved_declarations, cx),
- ContextModeState::Lsp => {
- zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations
- }
- };
-
- ContextMode::Syntax(EditPredictionContextOptions {
- excerpt: excerpt_options,
- max_retrieved_declarations,
- ..context_options
- })
- }
- ContextMode::Lsp(excerpt_options) => ContextMode::Lsp(excerpt_options),
- };
-
- this.set_zeta_options(
- ZetaOptions {
- context,
- max_prompt_bytes: number_input_value(&this.max_prompt_bytes_input, cx),
- max_diagnostic_bytes: zeta_options.max_diagnostic_bytes,
- prompt_format: zeta_options.prompt_format,
- file_indexing_parallelism: zeta_options.file_indexing_parallelism,
- buffer_change_grouping_interval: zeta_options
- .buffer_change_grouping_interval,
- },
- cx,
- );
- },
- )
- .detach();
- input
- }
-
- fn update_last_prediction(
- &mut self,
- prediction: zeta::ZetaDebugInfo,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self._update_state_task = cx.spawn_in(window, {
- let language_registry = self.project.read(cx).languages().clone();
- async move |this, cx| {
- let mut languages = HashMap::default();
- let ZetaDebugInfo::EditPredictionRequested(prediction) = prediction else {
- return;
- };
- for ext in prediction
- .inputs
- .included_files
- .iter()
- .filter_map(|file| file.path.extension())
- {
- if !languages.contains_key(ext) {
- // Most snippets are gonna be the same language,
- // so we think it's fine to do this sequentially for now
- languages.insert(
- ext.to_owned(),
- language_registry
- .language_for_name_or_extension(&ext.to_string_lossy())
- .await
- .ok(),
- );
- }
- }
-
- let markdown_language = language_registry
- .language_for_name("Markdown")
- .await
- .log_err();
-
- let json_language = language_registry.language_for_name("Json").await.log_err();
-
- this.update_in(cx, |this, window, cx| {
- let ZetaEditPredictionDebugInfo {
- response_rx,
- position,
- buffer,
- retrieval_time,
- local_prompt,
- ..
- } = prediction;
-
- let task = cx.spawn_in(window, {
- let markdown_language = markdown_language.clone();
- let json_language = json_language.clone();
- async move |this, cx| {
- let response = response_rx.await;
-
- this.update_in(cx, |this, window, cx| {
- if let Some(prediction) = this.last_prediction.as_mut() {
- prediction.state = match response {
- Ok((Ok(response), request_time)) => {
- prediction.request_time = Some(request_time);
-
- let feedback_editor = cx.new(|cx| {
- let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local("", cx);
- buffer.set_language(
- markdown_language.clone(),
- cx,
- );
- buffer
- });
- let buffer =
- cx.new(|cx| MultiBuffer::singleton(buffer, cx));
- let mut editor = Editor::new(
- EditorMode::AutoHeight {
- min_lines: 3,
- max_lines: None,
- },
- buffer,
- None,
- window,
- cx,
- );
- editor.set_placeholder_text(
- "Write feedback here",
- window,
- cx,
- );
- editor.set_show_line_numbers(false, cx);
- editor.set_show_gutter(false, cx);
- editor.set_show_scrollbars(false, cx);
- editor
- });
-
- cx.subscribe_in(
- &feedback_editor,
- window,
- |this, editor, ev, window, cx| match ev {
- EditorEvent::BufferEdited => {
- if let Some(last_prediction) =
- this.last_prediction.as_mut()
- && let LastPredictionState::Success {
- feedback: feedback_state,
- ..
- } = &mut last_prediction.state
- {
- if feedback_state.take().is_some() {
- editor.update(cx, |editor, cx| {
- editor.set_placeholder_text(
- "Write feedback here",
- window,
- cx,
- );
- });
- cx.notify();
- }
- }
- }
- _ => {}
- },
- )
- .detach();
-
- LastPredictionState::Success {
- model_response_editor: cx.new(|cx| {
- let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local(
- serde_json::to_string_pretty(&response)
- .unwrap_or_default(),
- cx,
- );
- buffer.set_language(json_language, cx);
- buffer
- });
- let buffer = cx.new(|cx| {
- MultiBuffer::singleton(buffer, cx)
- });
- let mut editor = Editor::new(
- EditorMode::full(),
- buffer,
- None,
- window,
- cx,
- );
- editor.set_read_only(true);
- editor.set_show_line_numbers(false, cx);
- editor.set_show_gutter(false, cx);
- editor.set_show_scrollbars(false, cx);
- editor
- }),
- feedback_editor,
- feedback: None,
- request_id: response.id.clone(),
- }
- }
- Ok((Err(err), request_time)) => {
- prediction.request_time = Some(request_time);
- LastPredictionState::Failed { message: err }
- }
- Err(oneshot::Canceled) => LastPredictionState::Failed {
- message: "Canceled".to_string(),
- },
- };
- }
- })
- .ok();
- }
- });
-
- let project_snapshot_task = TelemetrySnapshot::new(&this.project, cx);
-
- this.last_prediction = Some(LastPrediction {
- prompt_editor: cx.new(|cx| {
- let buffer = cx.new(|cx| {
- let mut buffer =
- Buffer::local(local_prompt.unwrap_or_else(|err| err), cx);
- buffer.set_language(markdown_language.clone(), cx);
- buffer
- });
- let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
- let mut editor =
- Editor::new(EditorMode::full(), buffer, None, window, cx);
- editor.set_read_only(true);
- editor.set_show_line_numbers(false, cx);
- editor.set_show_gutter(false, cx);
- editor.set_show_scrollbars(false, cx);
- editor
- }),
- retrieval_time,
- request_time: None,
- buffer,
- position,
- state: LastPredictionState::Requested,
- project_snapshot: cx
- .foreground_executor()
- .spawn(async move { Arc::new(project_snapshot_task.await) })
- .shared(),
- inputs: prediction.inputs,
- _task: Some(task),
- });
- cx.notify();
- })
- .ok();
- }
- });
- }
-
- fn handle_rate_positive(
- &mut self,
- _action: &Zeta2RatePredictionPositive,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self.handle_rate(Feedback::Positive, window, cx);
- }
-
- fn handle_rate_negative(
- &mut self,
- _action: &Zeta2RatePredictionNegative,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self.handle_rate(Feedback::Negative, window, cx);
- }
-
- fn handle_rate(&mut self, kind: Feedback, window: &mut Window, cx: &mut Context<Self>) {
- let Some(last_prediction) = self.last_prediction.as_mut() else {
- return;
- };
-
- let project_snapshot_task = last_prediction.project_snapshot.clone();
-
- cx.spawn_in(window, async move |this, cx| {
- let project_snapshot = project_snapshot_task.await;
- this.update_in(cx, |this, window, cx| {
- let Some(last_prediction) = this.last_prediction.as_mut() else {
- return;
- };
-
- let LastPredictionState::Success {
- feedback: feedback_state,
- feedback_editor,
- model_response_editor,
- request_id,
- ..
- } = &mut last_prediction.state
- else {
- return;
- };
-
- *feedback_state = Some(kind);
- let text = feedback_editor.update(cx, |feedback_editor, cx| {
- feedback_editor.set_placeholder_text(
- "Submitted. Edit or submit again to change.",
- window,
- cx,
- );
- feedback_editor.text(cx)
- });
- cx.notify();
-
- cx.defer_in(window, {
- let model_response_editor = model_response_editor.downgrade();
- move |_, window, cx| {
- if let Some(model_response_editor) = model_response_editor.upgrade() {
- model_response_editor.focus_handle(cx).focus(window);
- }
- }
- });
-
- let kind = match kind {
- Feedback::Positive => "positive",
- Feedback::Negative => "negative",
- };
-
- telemetry::event!(
- "Zeta2 Prediction Rated",
- id = request_id,
- kind = kind,
- text = text,
- request = last_prediction.inputs,
- project_snapshot = project_snapshot,
- );
- })
- .log_err();
- })
- .detach();
- }
-
- fn render_options(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
- v_flex()
- .gap_2()
- .child(
- h_flex()
- .child(Headline::new("Options").size(HeadlineSize::Small))
- .justify_between()
- .child(
- ui::Button::new("reset-options", "Reset")
- .disabled(self.zeta.read(cx).options() == &zeta::DEFAULT_OPTIONS)
- .style(ButtonStyle::Outlined)
- .size(ButtonSize::Large)
- .on_click(cx.listener(|this, _, window, cx| {
- this.set_options_state(&zeta::DEFAULT_OPTIONS, window, cx);
- })),
- ),
- )
- .child(
- v_flex()
- .gap_2()
- .child(
- h_flex()
- .gap_2()
- .items_end()
- .child(self.max_excerpt_bytes_input.clone())
- .child(self.min_excerpt_bytes_input.clone())
- .child(self.cursor_context_ratio_input.clone())
- .child(self.render_context_mode_dropdown(window, cx)),
- )
- .child(
- h_flex()
- .gap_2()
- .items_end()
- .children(match &self.context_mode {
- ContextModeState::Llm => None,
- ContextModeState::Syntax {
- max_retrieved_declarations,
- } => Some(max_retrieved_declarations.clone()),
- ContextModeState::Lsp => None,
- })
- .child(self.max_prompt_bytes_input.clone())
- .child(self.render_prompt_format_dropdown(window, cx)),
- ),
- )
- }
-
- fn render_context_mode_dropdown(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
- let this = cx.weak_entity();
-
- v_flex()
- .gap_1p5()
- .child(
- Label::new("Context Mode")
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- .child(
- DropdownMenu::new(
- "ep-ctx-mode",
- match &self.context_mode {
- ContextModeState::Llm => "LLM-based",
- ContextModeState::Syntax { .. } => "Syntax",
- ContextModeState::Lsp => "LSP-based",
- },
- ContextMenu::build(window, cx, move |menu, _window, _cx| {
- menu.item(
- ContextMenuEntry::new("LLM-based")
- .toggleable(
- IconPosition::End,
- matches!(self.context_mode, ContextModeState::Llm),
- )
- .handler({
- let this = this.clone();
- move |window, cx| {
- this.update(cx, |this, cx| {
- let current_options =
- this.zeta.read(cx).options().clone();
- match current_options.context.clone() {
- ContextMode::Agentic(_) => {}
- ContextMode::Lsp(_) => {}
- ContextMode::Syntax(context_options) => {
- let options = ZetaOptions {
- context: ContextMode::Agentic(
- AgenticContextOptions {
- excerpt: context_options.excerpt,
- },
- ),
- ..current_options
- };
- this.set_options_state(&options, window, cx);
- this.set_zeta_options(options, cx);
- }
- }
- })
- .ok();
- }
- }),
- )
- .item(
- ContextMenuEntry::new("Syntax")
- .toggleable(
- IconPosition::End,
- matches!(self.context_mode, ContextModeState::Syntax { .. }),
- )
- .handler({
- move |window, cx| {
- this.update(cx, |this, cx| {
- let current_options =
- this.zeta.read(cx).options().clone();
- match current_options.context.clone() {
- ContextMode::Agentic(context_options) => {
- let options = ZetaOptions {
- context: ContextMode::Syntax(
- EditPredictionContextOptions {
- excerpt: context_options.excerpt,
- ..DEFAULT_SYNTAX_CONTEXT_OPTIONS
- },
- ),
- ..current_options
- };
- this.set_options_state(&options, window, cx);
- this.set_zeta_options(options, cx);
- }
- ContextMode::Syntax(_) => {}
- ContextMode::Lsp(_) => {}
- }
- })
- .ok();
- }
- }),
- )
- }),
- )
- .style(ui::DropdownStyle::Outlined),
- )
- }
-
- fn render_prompt_format_dropdown(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
- let active_format = self.zeta.read(cx).options().prompt_format;
- let this = cx.weak_entity();
-
- v_flex()
- .gap_1p5()
- .child(
- Label::new("Prompt Format")
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- .child(
- DropdownMenu::new(
- "ep-prompt-format",
- active_format.to_string(),
- ContextMenu::build(window, cx, move |mut menu, _window, _cx| {
- for prompt_format in PromptFormat::iter() {
- menu = menu.item(
- ContextMenuEntry::new(prompt_format.to_string())
- .toggleable(IconPosition::End, active_format == prompt_format)
- .handler({
- let this = this.clone();
- move |_window, cx| {
- this.update(cx, |this, cx| {
- let current_options =
- this.zeta.read(cx).options().clone();
- let options = ZetaOptions {
- prompt_format,
- ..current_options
- };
- this.set_zeta_options(options, cx);
- })
- .ok();
- }
- }),
- )
- }
- menu
- }),
- )
- .style(ui::DropdownStyle::Outlined),
- )
- }
-
- fn render_stats(&self) -> Option<Div> {
- let Some(prediction) = self.last_prediction.as_ref() else {
- return None;
- };
-
- Some(
- v_flex()
- .p_4()
- .gap_2()
- .min_w(px(160.))
- .child(Headline::new("Stats").size(HeadlineSize::Small))
- .child(Self::render_duration(
- "Context retrieval",
- Some(prediction.retrieval_time),
- ))
- .child(Self::render_duration("Request", prediction.request_time)),
- )
- }
-
- fn render_duration(name: &'static str, time: Option<Duration>) -> Div {
- h_flex()
- .gap_1()
- .child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
- .child(match time {
- Some(time) => Label::new(if time.as_micros() >= 1000 {
- format!("{} ms", time.as_millis())
- } else {
- format!("{} µs", time.as_micros())
- })
- .size(LabelSize::Small),
- None => Label::new("...").size(LabelSize::Small),
- })
- }
-
- fn render_content(&self, _: &mut Window, cx: &mut Context<Self>) -> AnyElement {
- if !cx.has_flag::<Zeta2FeatureFlag>() {
- return Self::render_message("`zeta2` feature flag is not enabled");
- }
-
- match self.last_prediction.as_ref() {
- None => Self::render_message("No prediction"),
- Some(prediction) => self.render_last_prediction(prediction, cx).into_any(),
- }
- }
-
- fn render_message(message: impl Into<SharedString>) -> AnyElement {
- v_flex()
- .size_full()
- .justify_center()
- .items_center()
- .child(Label::new(message).size(LabelSize::Large))
- .into_any()
- }
-
- fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
- h_flex()
- .items_start()
- .w_full()
- .flex_1()
- .border_t_1()
- .border_color(cx.theme().colors().border)
- .bg(cx.theme().colors().editor_background)
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .p_4()
- .h_full()
- .child(
- h_flex()
- .justify_between()
- .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
- .child(match prediction.state {
- LastPredictionState::Requested
- | LastPredictionState::Failed { .. } => ui::Chip::new("Local")
- .bg_color(cx.theme().status().warning_background)
- .label_color(Color::Success),
- LastPredictionState::Success { .. } => ui::Chip::new("Cloud")
- .bg_color(cx.theme().status().success_background)
- .label_color(Color::Success),
- }),
- )
- .child(prediction.prompt_editor.clone()),
- )
- .child(ui::vertical_divider())
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .h_full()
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .p_4()
- .child(
- ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall),
- )
- .child(match &prediction.state {
- LastPredictionState::Success {
- model_response_editor,
- ..
- } => model_response_editor.clone().into_any_element(),
- LastPredictionState::Requested => v_flex()
- .gap_2()
- .child(Label::new("Loading...").buffer_font(cx))
- .into_any_element(),
- LastPredictionState::Failed { message } => v_flex()
- .gap_2()
- .max_w_96()
- .child(Label::new(message.clone()).buffer_font(cx))
- .into_any_element(),
- }),
- )
- .child(ui::divider())
- .child(
- if let LastPredictionState::Success {
- feedback_editor,
- feedback: feedback_state,
- ..
- } = &prediction.state
- {
- v_flex()
- .key_context("Zeta2Feedback")
- .on_action(cx.listener(Self::handle_rate_positive))
- .on_action(cx.listener(Self::handle_rate_negative))
- .gap_2()
- .p_2()
- .child(feedback_editor.clone())
- .child(
- h_flex()
- .justify_end()
- .w_full()
- .child(
- ButtonLike::new("rate-positive")
- .when(
- *feedback_state == Some(Feedback::Positive),
- |this| this.style(ButtonStyle::Filled),
- )
- .child(
- KeyBinding::for_action(
- &Zeta2RatePredictionPositive,
- cx,
- )
- .size(TextSize::Small.rems(cx)),
- )
- .child(ui::Icon::new(ui::IconName::ThumbsUp))
- .on_click(cx.listener(|this, _, window, cx| {
- this.handle_rate_positive(
- &Zeta2RatePredictionPositive,
- window,
- cx,
- );
- })),
- )
- .child(
- ButtonLike::new("rate-negative")
- .when(
- *feedback_state == Some(Feedback::Negative),
- |this| this.style(ButtonStyle::Filled),
- )
- .child(
- KeyBinding::for_action(
- &Zeta2RatePredictionNegative,
- cx,
- )
- .size(TextSize::Small.rems(cx)),
- )
- .child(ui::Icon::new(ui::IconName::ThumbsDown))
- .on_click(cx.listener(|this, _, window, cx| {
- this.handle_rate_negative(
- &Zeta2RatePredictionNegative,
- window,
- cx,
- );
- })),
- ),
- )
- .into_any()
- } else {
- Empty.into_any_element()
- },
- ),
- )
- }
-}
-
-impl Focusable for Zeta2Inspector {
- fn focus_handle(&self, _cx: &App) -> FocusHandle {
- self.focus_handle.clone()
- }
-}
-
-impl Item for Zeta2Inspector {
- type Event = ();
-
- fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
- "Zeta2 Inspector".into()
- }
-}
-
-impl EventEmitter<()> for Zeta2Inspector {}
-
-impl Render for Zeta2Inspector {
- fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- v_flex()
- .size_full()
- .bg(cx.theme().colors().editor_background)
- .child(
- h_flex()
- .w_full()
- .child(
- v_flex()
- .flex_1()
- .p_4()
- .h_full()
- .justify_between()
- .child(self.render_options(window, cx))
- .gap_4(),
- )
- .child(ui::vertical_divider())
- .children(self.render_stats()),
- )
- .child(self.render_content(window, cx))
- }
-}
@@ -1 +0,0 @@
-../../LICENSE-GPL
@@ -1,1260 +0,0 @@
-use ::util::rel_path::RelPath;
-use ::util::{RangeExt, ResultExt as _};
-use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
-use edit_prediction_context::{
- Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, Identifier,
- Imports, Reference, ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
-};
-use futures::StreamExt as _;
-use futures::channel::mpsc;
-use gpui::Entity;
-use gpui::{AppContext, AsyncApp};
-use language::OffsetRangeExt;
-use language::{BufferSnapshot, Point};
-use ordered_float::OrderedFloat;
-use polars::prelude::*;
-use project::{Project, ProjectEntryId, ProjectPath, Worktree};
-use serde::{Deserialize, Serialize};
-use std::fs;
-use std::{
- cmp::Reverse,
- collections::{HashMap, HashSet},
- fs::File,
- hash::{Hash, Hasher},
- io::{BufRead, BufReader, BufWriter, Write as _},
- ops::Range,
- path::{Path, PathBuf},
- sync::{
- Arc,
- atomic::{self, AtomicUsize},
- },
- time::Duration,
-};
-use util::paths::PathStyle;
-use zeta::ContextMode;
-
-use crate::headless::ZetaCliAppState;
-use crate::source_location::SourceLocation;
-use crate::util::{open_buffer, open_buffer_with_language_server};
-
-pub async fn retrieval_stats(
- worktree: PathBuf,
- app_state: Arc<ZetaCliAppState>,
- only_extension: Option<String>,
- file_limit: Option<usize>,
- skip_files: Option<usize>,
- options: zeta::ZetaOptions,
- cx: &mut AsyncApp,
-) -> Result<String> {
- let ContextMode::Syntax(context_options) = options.context.clone() else {
- anyhow::bail!("retrieval stats only works in ContextMode::Syntax");
- };
-
- let options = Arc::new(options);
- let worktree_path = worktree.canonicalize()?;
-
- let project = cx.update(|cx| {
- Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- )
- })?;
-
- let worktree = project
- .update(cx, |project, cx| {
- project.create_worktree(&worktree_path, true, cx)
- })?
- .await?;
-
- // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
- worktree
- .read_with(cx, |worktree, _cx| {
- worktree.as_local().unwrap().scan_complete()
- })?
- .await;
-
- let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?;
- index
- .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
- .await?;
- let indexed_files = index
- .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
- .await;
- let mut filtered_files = indexed_files
- .into_iter()
- .filter(|project_path| {
- let file_extension = project_path.path.extension();
- if let Some(only_extension) = only_extension.as_ref() {
- file_extension.is_some_and(|extension| extension == only_extension)
- } else {
- file_extension
- .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
- }
- })
- .collect::<Vec<_>>();
- filtered_files.sort_by(|a, b| a.path.cmp(&b.path));
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone())?;
- cx.update(|_| {
- drop(index);
- })?;
- let index_state = Arc::new(
- Arc::into_inner(index_state)
- .context("Index state had more than 1 reference")?
- .into_inner(),
- );
-
- struct FileSnapshot {
- project_entry_id: ProjectEntryId,
- snapshot: BufferSnapshot,
- hash: u64,
- parent_abs_path: Arc<Path>,
- }
-
- let files: Vec<FileSnapshot> = futures::future::try_join_all({
- filtered_files
- .iter()
- .map(|file| {
- let buffer_task =
- open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx);
- cx.spawn(async move |cx| {
- let buffer = buffer_task.await?;
- let (project_entry_id, parent_abs_path, snapshot) =
- buffer.read_with(cx, |buffer, cx| {
- let file = project::File::from_dyn(buffer.file()).unwrap();
- let project_entry_id = file.project_entry_id().unwrap();
- let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path);
- if !parent_abs_path.pop() {
- panic!("Invalid worktree path");
- }
-
- (project_entry_id, parent_abs_path, buffer.snapshot())
- })?;
-
- anyhow::Ok(
- cx.background_spawn(async move {
- let mut hasher = collections::FxHasher::default();
- snapshot.text().hash(&mut hasher);
- FileSnapshot {
- project_entry_id,
- snapshot,
- hash: hasher.finish(),
- parent_abs_path: parent_abs_path.into(),
- }
- })
- .await,
- )
- })
- })
- .collect::<Vec<_>>()
- })
- .await?;
-
- let mut file_snapshots = HashMap::default();
- let mut hasher = collections::FxHasher::default();
- for FileSnapshot {
- project_entry_id,
- snapshot,
- hash,
- ..
- } in &files
- {
- file_snapshots.insert(*project_entry_id, snapshot.clone());
- hash.hash(&mut hasher);
- }
- let files_hash = hasher.finish();
- let file_snapshots = Arc::new(file_snapshots);
- let target_cli_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../target/zeta_cli");
- fs::create_dir_all(&target_cli_dir).unwrap();
- let target_cli_dir = target_cli_dir.canonicalize().unwrap();
-
- let lsp_cache_dir = target_cli_dir.join("cache");
- fs::create_dir_all(&lsp_cache_dir).unwrap();
-
- let lsp_definitions_path = lsp_cache_dir.join(format!(
- "{}-{:x}.jsonl",
- worktree_path.file_stem().unwrap_or_default().display(),
- files_hash
- ));
-
- let mut lsp_definitions = HashMap::default();
- let mut lsp_files = 0;
-
- if fs::exists(&lsp_definitions_path)? {
- log::info!(
- "Using cached LSP definitions from {}",
- lsp_definitions_path.display()
- );
-
- let file = File::options()
- .read(true)
- .write(true)
- .open(&lsp_definitions_path)?;
- let lines = BufReader::new(&file).lines();
- let mut valid_len: usize = 0;
-
- for (line, expected_file) in lines.zip(files.iter()) {
- let line = line?;
- let FileLspDefinitions { path, references } = match serde_json::from_str(&line) {
- Ok(ok) => ok,
- Err(_) => {
- log::error!("Found invalid cache line. Truncating to #{lsp_files}.",);
- file.set_len(valid_len as u64)?;
- break;
- }
- };
- let expected_path = expected_file.snapshot.file().unwrap().path().as_unix_str();
- if expected_path != path.as_ref() {
- log::error!(
- "Expected file #{} to be {expected_path}, but found {path}. Truncating to #{lsp_files}.",
- lsp_files + 1
- );
- file.set_len(valid_len as u64)?;
- break;
- }
- for (point, ranges) in references {
- let Ok(path) = RelPath::new(Path::new(path.as_ref()), PathStyle::Posix) else {
- log::warn!("Invalid path: {}", path);
- continue;
- };
- lsp_definitions.insert(
- SourceLocation {
- path: path.into_arc(),
- point: point.into(),
- },
- ranges,
- );
- }
- lsp_files += 1;
- valid_len += line.len() + 1
- }
- }
-
- if lsp_files < files.len() {
- if lsp_files == 0 {
- log::warn!(
- "No LSP definitions found, populating {}",
- lsp_definitions_path.display()
- );
- } else {
- log::warn!("{} files missing from LSP cache", files.len() - lsp_files);
- }
-
- gather_lsp_definitions(
- &lsp_definitions_path,
- lsp_files,
- &filtered_files,
- &worktree,
- &project,
- &mut lsp_definitions,
- cx,
- )
- .await?;
- }
- let files_len = files.len().min(file_limit.unwrap_or(usize::MAX));
- let done_count = Arc::new(AtomicUsize::new(0));
-
- let (output_tx, output_rx) = mpsc::unbounded::<ReferenceRetrievalResult>();
-
- let tasks = files
- .into_iter()
- .skip(skip_files.unwrap_or(0))
- .take(file_limit.unwrap_or(usize::MAX))
- .map(|project_file| {
- let index_state = index_state.clone();
- let lsp_definitions = lsp_definitions.clone();
- let output_tx = output_tx.clone();
- let done_count = done_count.clone();
- let file_snapshots = file_snapshots.clone();
- let context_options = context_options.clone();
- cx.background_spawn(async move {
- let snapshot = project_file.snapshot;
-
- let full_range = 0..snapshot.len();
- let references = references_in_range(
- full_range,
- &snapshot.text(),
- ReferenceRegion::Nearby,
- &snapshot,
- );
-
- let imports = if context_options.use_imports {
- Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
- } else {
- Imports::default()
- };
-
- let path = snapshot.file().unwrap().path();
-
- for reference in references {
- let query_point = snapshot.offset_to_point(reference.range.start);
- let source_location = SourceLocation {
- path: path.clone(),
- point: query_point,
- };
- let lsp_definitions = lsp_definitions
- .get(&source_location)
- .cloned()
- .unwrap_or_else(|| {
- log::warn!(
- "No definitions found for source location: {:?}",
- source_location
- );
- Vec::new()
- });
-
- let retrieve_result = retrieve_definitions(
- &reference,
- &imports,
- query_point,
- &snapshot,
- &index_state,
- &file_snapshots,
- &context_options,
- )
- .await?;
-
- let result = ReferenceRetrievalResult {
- cursor_path: path.clone(),
- identifier: reference.identifier,
- cursor_point: query_point,
- lsp_definitions,
- retrieved_definitions: retrieve_result.definitions,
- excerpt_range: retrieve_result.excerpt_range,
- };
-
- output_tx.unbounded_send(result).ok();
- }
-
- println!(
- "{:02}/{:02} done",
- done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
- files_len,
- );
-
- anyhow::Ok(())
- })
- })
- .collect::<Vec<_>>();
-
- drop(output_tx);
-
- let df_task = cx.background_spawn(build_dataframe(output_rx));
-
- futures::future::try_join_all(tasks).await?;
- let mut df = df_task.await?;
-
- let run_id = format!(
- "{}-{}",
- worktree_path.file_stem().unwrap_or_default().display(),
- chrono::Local::now().format("%Y%m%d_%H%M%S")
- );
- let run_dir = target_cli_dir.join(run_id);
- fs::create_dir(&run_dir).unwrap();
-
- let parquet_path = run_dir.join("stats.parquet");
- let mut parquet_file = fs::File::create(&parquet_path)?;
-
- ParquetWriter::new(&mut parquet_file)
- .finish(&mut df)
- .unwrap();
-
- let stats = SummaryStats::from_dataframe(df)?;
-
- let stats_path = run_dir.join("stats.txt");
- fs::write(&stats_path, format!("{}", stats))?;
-
- println!("{}", stats);
- println!("\nWrote:");
- println!("- {}", relativize_path(&parquet_path).display());
- println!("- {}", relativize_path(&stats_path).display());
- println!("- {}", relativize_path(&lsp_definitions_path).display());
-
- Ok("".to_string())
-}
-
-async fn build_dataframe(
- mut output_rx: mpsc::UnboundedReceiver<ReferenceRetrievalResult>,
-) -> Result<DataFrame> {
- use soa_rs::{Soa, Soars};
-
- #[derive(Default, Soars)]
- struct Row {
- ref_id: u32,
- cursor_path: String,
- cursor_row: u32,
- cursor_column: u32,
- cursor_identifier: String,
- gold_in_excerpt: bool,
- gold_path: String,
- gold_row: u32,
- gold_column: u32,
- gold_is_external: bool,
- candidate_count: u32,
- candidate_path: Option<String>,
- candidate_row: Option<u32>,
- candidate_column: Option<u32>,
- candidate_is_gold: Option<bool>,
- candidate_rank: Option<u32>,
- candidate_is_same_file: Option<bool>,
- candidate_is_referenced_nearby: Option<bool>,
- candidate_is_referenced_in_breadcrumb: Option<bool>,
- candidate_reference_count: Option<u32>,
- candidate_same_file_declaration_count: Option<u32>,
- candidate_declaration_count: Option<u32>,
- candidate_reference_line_distance: Option<u32>,
- candidate_declaration_line_distance: Option<u32>,
- candidate_excerpt_vs_item_jaccard: Option<f32>,
- candidate_excerpt_vs_signature_jaccard: Option<f32>,
- candidate_adjacent_vs_item_jaccard: Option<f32>,
- candidate_adjacent_vs_signature_jaccard: Option<f32>,
- candidate_excerpt_vs_item_weighted_overlap: Option<f32>,
- candidate_excerpt_vs_signature_weighted_overlap: Option<f32>,
- candidate_adjacent_vs_item_weighted_overlap: Option<f32>,
- candidate_adjacent_vs_signature_weighted_overlap: Option<f32>,
- candidate_path_import_match_count: Option<u32>,
- candidate_wildcard_path_import_match_count: Option<u32>,
- candidate_import_similarity: Option<f32>,
- candidate_max_import_similarity: Option<f32>,
- candidate_normalized_import_similarity: Option<f32>,
- candidate_wildcard_import_similarity: Option<f32>,
- candidate_normalized_wildcard_import_similarity: Option<f32>,
- candidate_included_by_others: Option<u32>,
- candidate_includes_others: Option<u32>,
- }
- let mut rows = Soa::<Row>::new();
- let mut next_ref_id = 0;
-
- while let Some(result) = output_rx.next().await {
- let mut gold_is_external = false;
- let mut gold_in_excerpt = false;
- let cursor_path = result.cursor_path.as_unix_str();
- let cursor_row = result.cursor_point.row + 1;
- let cursor_column = result.cursor_point.column + 1;
- let cursor_identifier = result.identifier.name.to_string();
- let ref_id = next_ref_id;
- next_ref_id += 1;
-
- for lsp_definition in result.lsp_definitions {
- let SourceRange {
- path: gold_path,
- point_range: gold_point_range,
- offset_range: gold_offset_range,
- } = lsp_definition;
- let lsp_point_range =
- SerializablePoint::into_language_point_range(gold_point_range.clone());
-
- gold_is_external = gold_is_external
- || gold_path.is_absolute()
- || gold_path
- .components()
- .any(|component| component.as_os_str() == "node_modules");
-
- gold_in_excerpt = gold_in_excerpt
- || result.excerpt_range.as_ref().is_some_and(|excerpt_range| {
- excerpt_range.contains_inclusive(&gold_offset_range)
- });
-
- let gold_row = gold_point_range.start.row;
- let gold_column = gold_point_range.start.column;
- let candidate_count = result.retrieved_definitions.len() as u32;
-
- for (candidate_rank, retrieved_definition) in
- result.retrieved_definitions.iter().enumerate()
- {
- let candidate_is_gold = gold_path.as_path()
- == retrieved_definition.path.as_std_path()
- && retrieved_definition
- .range
- .contains_inclusive(&lsp_point_range);
-
- let candidate_row = retrieved_definition.range.start.row + 1;
- let candidate_column = retrieved_definition.range.start.column + 1;
-
- let DeclarationScoreComponents {
- is_same_file,
- is_referenced_nearby,
- is_referenced_in_breadcrumb,
- reference_count,
- same_file_declaration_count,
- declaration_count,
- reference_line_distance,
- declaration_line_distance,
- excerpt_vs_item_jaccard,
- excerpt_vs_signature_jaccard,
- adjacent_vs_item_jaccard,
- adjacent_vs_signature_jaccard,
- excerpt_vs_item_weighted_overlap,
- excerpt_vs_signature_weighted_overlap,
- adjacent_vs_item_weighted_overlap,
- adjacent_vs_signature_weighted_overlap,
- path_import_match_count,
- wildcard_path_import_match_count,
- import_similarity,
- max_import_similarity,
- normalized_import_similarity,
- wildcard_import_similarity,
- normalized_wildcard_import_similarity,
- included_by_others,
- includes_others,
- } = retrieved_definition.components;
-
- rows.push(Row {
- ref_id,
- cursor_path: cursor_path.to_string(),
- cursor_row,
- cursor_column,
- cursor_identifier: cursor_identifier.clone(),
- gold_in_excerpt,
- gold_path: gold_path.to_string_lossy().to_string(),
- gold_row,
- gold_column,
- gold_is_external,
- candidate_count,
- candidate_path: Some(retrieved_definition.path.as_unix_str().to_string()),
- candidate_row: Some(candidate_row),
- candidate_column: Some(candidate_column),
- candidate_is_gold: Some(candidate_is_gold),
- candidate_rank: Some(candidate_rank as u32),
- candidate_is_same_file: Some(is_same_file),
- candidate_is_referenced_nearby: Some(is_referenced_nearby),
- candidate_is_referenced_in_breadcrumb: Some(is_referenced_in_breadcrumb),
- candidate_reference_count: Some(reference_count as u32),
- candidate_same_file_declaration_count: Some(same_file_declaration_count as u32),
- candidate_declaration_count: Some(declaration_count as u32),
- candidate_reference_line_distance: Some(reference_line_distance),
- candidate_declaration_line_distance: Some(declaration_line_distance),
- candidate_excerpt_vs_item_jaccard: Some(excerpt_vs_item_jaccard),
- candidate_excerpt_vs_signature_jaccard: Some(excerpt_vs_signature_jaccard),
- candidate_adjacent_vs_item_jaccard: Some(adjacent_vs_item_jaccard),
- candidate_adjacent_vs_signature_jaccard: Some(adjacent_vs_signature_jaccard),
- candidate_excerpt_vs_item_weighted_overlap: Some(
- excerpt_vs_item_weighted_overlap,
- ),
- candidate_excerpt_vs_signature_weighted_overlap: Some(
- excerpt_vs_signature_weighted_overlap,
- ),
- candidate_adjacent_vs_item_weighted_overlap: Some(
- adjacent_vs_item_weighted_overlap,
- ),
- candidate_adjacent_vs_signature_weighted_overlap: Some(
- adjacent_vs_signature_weighted_overlap,
- ),
- candidate_path_import_match_count: Some(path_import_match_count as u32),
- candidate_wildcard_path_import_match_count: Some(
- wildcard_path_import_match_count as u32,
- ),
- candidate_import_similarity: Some(import_similarity),
- candidate_max_import_similarity: Some(max_import_similarity),
- candidate_normalized_import_similarity: Some(normalized_import_similarity),
- candidate_wildcard_import_similarity: Some(wildcard_import_similarity),
- candidate_normalized_wildcard_import_similarity: Some(
- normalized_wildcard_import_similarity,
- ),
- candidate_included_by_others: Some(included_by_others as u32),
- candidate_includes_others: Some(includes_others as u32),
- });
- }
-
- if result.retrieved_definitions.is_empty() {
- rows.push(Row {
- ref_id,
- cursor_path: cursor_path.to_string(),
- cursor_row,
- cursor_column,
- cursor_identifier: cursor_identifier.clone(),
- gold_in_excerpt,
- gold_path: gold_path.to_string_lossy().to_string(),
- gold_row,
- gold_column,
- gold_is_external,
- candidate_count,
- ..Default::default()
- });
- }
- }
- }
- let slices = rows.slices();
-
- let RowSlices {
- ref_id,
- cursor_path,
- cursor_row,
- cursor_column,
- cursor_identifier,
- gold_in_excerpt,
- gold_path,
- gold_row,
- gold_column,
- gold_is_external,
- candidate_path,
- candidate_row,
- candidate_column,
- candidate_is_gold,
- candidate_rank,
- candidate_count,
- candidate_is_same_file,
- candidate_is_referenced_nearby,
- candidate_is_referenced_in_breadcrumb,
- candidate_reference_count,
- candidate_same_file_declaration_count,
- candidate_declaration_count,
- candidate_reference_line_distance,
- candidate_declaration_line_distance,
- candidate_excerpt_vs_item_jaccard,
- candidate_excerpt_vs_signature_jaccard,
- candidate_adjacent_vs_item_jaccard,
- candidate_adjacent_vs_signature_jaccard,
- candidate_excerpt_vs_item_weighted_overlap,
- candidate_excerpt_vs_signature_weighted_overlap,
- candidate_adjacent_vs_item_weighted_overlap,
- candidate_adjacent_vs_signature_weighted_overlap,
- candidate_path_import_match_count,
- candidate_wildcard_path_import_match_count,
- candidate_import_similarity,
- candidate_max_import_similarity,
- candidate_normalized_import_similarity,
- candidate_wildcard_import_similarity,
- candidate_normalized_wildcard_import_similarity,
- candidate_included_by_others,
- candidate_includes_others,
- } = slices;
-
- let df = DataFrame::new(vec![
- Series::new(PlSmallStr::from_str("ref_id"), ref_id).into(),
- Series::new(PlSmallStr::from_str("cursor_path"), cursor_path).into(),
- Series::new(PlSmallStr::from_str("cursor_row"), cursor_row).into(),
- Series::new(PlSmallStr::from_str("cursor_column"), cursor_column).into(),
- Series::new(PlSmallStr::from_str("cursor_identifier"), cursor_identifier).into(),
- Series::new(PlSmallStr::from_str("gold_in_excerpt"), gold_in_excerpt).into(),
- Series::new(PlSmallStr::from_str("gold_path"), gold_path).into(),
- Series::new(PlSmallStr::from_str("gold_row"), gold_row).into(),
- Series::new(PlSmallStr::from_str("gold_column"), gold_column).into(),
- Series::new(PlSmallStr::from_str("gold_is_external"), gold_is_external).into(),
- Series::new(PlSmallStr::from_str("candidate_count"), candidate_count).into(),
- Series::new(PlSmallStr::from_str("candidate_path"), candidate_path).into(),
- Series::new(PlSmallStr::from_str("candidate_row"), candidate_row).into(),
- Series::new(PlSmallStr::from_str("candidate_column"), candidate_column).into(),
- Series::new(PlSmallStr::from_str("candidate_is_gold"), candidate_is_gold).into(),
- Series::new(PlSmallStr::from_str("candidate_rank"), candidate_rank).into(),
- Series::new(
- PlSmallStr::from_str("candidate_is_same_file"),
- candidate_is_same_file,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_is_referenced_nearby"),
- candidate_is_referenced_nearby,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_is_referenced_in_breadcrumb"),
- candidate_is_referenced_in_breadcrumb,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_reference_count"),
- candidate_reference_count,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_same_file_declaration_count"),
- candidate_same_file_declaration_count,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_declaration_count"),
- candidate_declaration_count,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_reference_line_distance"),
- candidate_reference_line_distance,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_declaration_line_distance"),
- candidate_declaration_line_distance,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_excerpt_vs_item_jaccard"),
- candidate_excerpt_vs_item_jaccard,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_excerpt_vs_signature_jaccard"),
- candidate_excerpt_vs_signature_jaccard,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_adjacent_vs_item_jaccard"),
- candidate_adjacent_vs_item_jaccard,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_adjacent_vs_signature_jaccard"),
- candidate_adjacent_vs_signature_jaccard,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_excerpt_vs_item_weighted_overlap"),
- candidate_excerpt_vs_item_weighted_overlap,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_excerpt_vs_signature_weighted_overlap"),
- candidate_excerpt_vs_signature_weighted_overlap,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_adjacent_vs_item_weighted_overlap"),
- candidate_adjacent_vs_item_weighted_overlap,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_adjacent_vs_signature_weighted_overlap"),
- candidate_adjacent_vs_signature_weighted_overlap,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_path_import_match_count"),
- candidate_path_import_match_count,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_wildcard_path_import_match_count"),
- candidate_wildcard_path_import_match_count,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_import_similarity"),
- candidate_import_similarity,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_max_import_similarity"),
- candidate_max_import_similarity,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_normalized_import_similarity"),
- candidate_normalized_import_similarity,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_wildcard_import_similarity"),
- candidate_wildcard_import_similarity,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_normalized_wildcard_import_similarity"),
- candidate_normalized_wildcard_import_similarity,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_included_by_others"),
- candidate_included_by_others,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_includes_others"),
- candidate_includes_others,
- )
- .into(),
- ])?;
-
- Ok(df)
-}
-
-fn relativize_path(path: &Path) -> &Path {
- path.strip_prefix(std::env::current_dir().unwrap())
- .unwrap_or(path)
-}
-
-struct SummaryStats {
- references_count: u32,
- retrieved_count: u32,
- top_match_count: u32,
- non_top_match_count: u32,
- ranking_involved_top_match_count: u32,
- missing_none_retrieved: u32,
- missing_wrong_retrieval: u32,
- missing_external: u32,
- in_excerpt_count: u32,
-}
-
-impl SummaryStats {
- fn from_dataframe(df: DataFrame) -> Result<Self> {
- // TODO: use lazy more
- let unique_refs =
- df.unique::<(), ()>(Some(&["ref_id".into()]), UniqueKeepStrategy::Any, None)?;
- let references_count = unique_refs.height() as u32;
-
- let gold_mask = df.column("candidate_is_gold")?.bool()?;
- let gold_df = df.filter(&gold_mask)?;
- let retrieved_count = gold_df.height() as u32;
-
- let top_match_mask = gold_df.column("candidate_rank")?.u32()?.equal(0);
- let top_match_df = gold_df.filter(&top_match_mask)?;
- let top_match_count = top_match_df.height() as u32;
-
- let ranking_involved_top_match_count = top_match_df
- .column("candidate_count")?
- .u32()?
- .gt(1)
- .sum()
- .unwrap_or_default();
-
- let non_top_match_count = (!top_match_mask).sum().unwrap_or(0);
-
- let not_retrieved_df = df
- .lazy()
- .group_by(&[col("ref_id"), col("candidate_count")])
- .agg(&[
- col("candidate_is_gold")
- .fill_null(false)
- .sum()
- .alias("gold_count"),
- col("gold_in_excerpt").sum().alias("gold_in_excerpt_count"),
- col("gold_is_external")
- .sum()
- .alias("gold_is_external_count"),
- ])
- .filter(col("gold_count").eq(lit(0)))
- .collect()?;
-
- let in_excerpt_mask = not_retrieved_df
- .column("gold_in_excerpt_count")?
- .u32()?
- .gt(0);
- let in_excerpt_count = in_excerpt_mask.sum().unwrap_or(0);
-
- let missing_df = not_retrieved_df.filter(&!in_excerpt_mask)?;
-
- let missing_none_retrieved_mask = missing_df.column("candidate_count")?.u32()?.equal(0);
- let missing_none_retrieved = missing_none_retrieved_mask.sum().unwrap_or(0);
- let external_mask = missing_df.column("gold_is_external_count")?.u32()?.gt(0);
- let missing_external = (missing_none_retrieved_mask & external_mask)
- .sum()
- .unwrap_or(0);
-
- let missing_wrong_retrieval = missing_df
- .column("candidate_count")?
- .u32()?
- .gt(0)
- .sum()
- .unwrap_or(0);
-
- Ok(SummaryStats {
- references_count,
- retrieved_count,
- top_match_count,
- non_top_match_count,
- ranking_involved_top_match_count,
- missing_none_retrieved,
- missing_wrong_retrieval,
- missing_external,
- in_excerpt_count,
- })
- }
-
- fn count_and_percentage(part: u32, total: u32) -> String {
- format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
- }
-}
-
-impl std::fmt::Display for SummaryStats {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- let included = self.in_excerpt_count + self.retrieved_count;
- let missing = self.references_count - included;
- writeln!(f)?;
- writeln!(f, "╮ references: {}", self.references_count)?;
- writeln!(
- f,
- "├─╮ included: {}",
- Self::count_and_percentage(included, self.references_count),
- )?;
- writeln!(
- f,
- "│ ├─╮ retrieved: {}",
- Self::count_and_percentage(self.retrieved_count, self.references_count)
- )?;
- writeln!(
- f,
- "│ │ ├─╮ top match : {}",
- Self::count_and_percentage(self.top_match_count, self.retrieved_count)
- )?;
- writeln!(
- f,
- "│ │ │ ╰─╴ involving ranking: {}",
- Self::count_and_percentage(self.ranking_involved_top_match_count, self.top_match_count)
- )?;
- writeln!(
- f,
- "│ │ ╰─╴ non-top match: {}",
- Self::count_and_percentage(self.non_top_match_count, self.retrieved_count)
- )?;
- writeln!(
- f,
- "│ ╰─╴ in excerpt: {}",
- Self::count_and_percentage(self.in_excerpt_count, included)
- )?;
- writeln!(
- f,
- "╰─╮ missing: {}",
- Self::count_and_percentage(missing, self.references_count)
- )?;
- writeln!(
- f,
- " ├─╮ none retrieved: {}",
- Self::count_and_percentage(self.missing_none_retrieved, missing)
- )?;
- writeln!(
- f,
- " │ ╰─╴ external (expected): {}",
- Self::count_and_percentage(self.missing_external, missing)
- )?;
- writeln!(
- f,
- " ╰─╴ wrong retrieval: {}",
- Self::count_and_percentage(self.missing_wrong_retrieval, missing)
- )?;
- Ok(())
- }
-}
-
-#[derive(Debug)]
-struct ReferenceRetrievalResult {
- cursor_path: Arc<RelPath>,
- cursor_point: Point,
- identifier: Identifier,
- excerpt_range: Option<Range<usize>>,
- lsp_definitions: Vec<SourceRange>,
- retrieved_definitions: Vec<RetrievedDefinition>,
-}
-
-#[derive(Debug)]
-struct RetrievedDefinition {
- path: Arc<RelPath>,
- range: Range<Point>,
- score: f32,
- #[allow(dead_code)]
- retrieval_score: f32,
- #[allow(dead_code)]
- components: DeclarationScoreComponents,
-}
-
-struct RetrieveResult {
- definitions: Vec<RetrievedDefinition>,
- excerpt_range: Option<Range<usize>>,
-}
-
-async fn retrieve_definitions(
- reference: &Reference,
- imports: &Imports,
- query_point: Point,
- snapshot: &BufferSnapshot,
- index: &Arc<SyntaxIndexState>,
- file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
- context_options: &EditPredictionContextOptions,
-) -> Result<RetrieveResult> {
- let mut single_reference_map = HashMap::default();
- single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
- let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
- query_point,
- snapshot,
- imports,
- &context_options,
- Some(&index),
- |_, _, _| single_reference_map,
- );
-
- let Some(edit_prediction_context) = edit_prediction_context else {
- return Ok(RetrieveResult {
- definitions: Vec::new(),
- excerpt_range: None,
- });
- };
-
- let mut retrieved_definitions = Vec::new();
- for scored_declaration in edit_prediction_context.declarations {
- match &scored_declaration.declaration {
- Declaration::File {
- project_entry_id,
- declaration,
- ..
- } => {
- let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
- log::error!("bug: file project entry not found");
- continue;
- };
- let path = snapshot.file().unwrap().path().clone();
- retrieved_definitions.push(RetrievedDefinition {
- path,
- range: snapshot.offset_to_point(declaration.item_range.start)
- ..snapshot.offset_to_point(declaration.item_range.end),
- score: scored_declaration.score(DeclarationStyle::Declaration),
- retrieval_score: scored_declaration.retrieval_score(),
- components: scored_declaration.components,
- });
- }
- Declaration::Buffer {
- project_entry_id,
- rope,
- declaration,
- ..
- } => {
- let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
- // This case happens when dependency buffers have been opened by
- // go-to-definition, resulting in single-file worktrees.
- continue;
- };
- let path = snapshot.file().unwrap().path().clone();
- retrieved_definitions.push(RetrievedDefinition {
- path,
- range: rope.offset_to_point(declaration.item_range.start)
- ..rope.offset_to_point(declaration.item_range.end),
- score: scored_declaration.score(DeclarationStyle::Declaration),
- retrieval_score: scored_declaration.retrieval_score(),
- components: scored_declaration.components,
- });
- }
- }
- }
- retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));
-
- Ok(RetrieveResult {
- definitions: retrieved_definitions,
- excerpt_range: Some(edit_prediction_context.excerpt.range),
- })
-}
-
-async fn gather_lsp_definitions(
- lsp_definitions_path: &Path,
- start_index: usize,
- files: &[ProjectPath],
- worktree: &Entity<Worktree>,
- project: &Entity<Project>,
- definitions: &mut HashMap<SourceLocation, Vec<SourceRange>>,
- cx: &mut AsyncApp,
-) -> Result<()> {
- let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
-
- let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
- cx.subscribe(&lsp_store, {
- move |_, event, _| {
- if let project::LspStoreEvent::LanguageServerUpdate {
- message:
- client::proto::update_language_server::Variant::WorkProgress(
- client::proto::LspWorkProgress {
- message: Some(message),
- ..
- },
- ),
- ..
- } = event
- {
- println!("⟲ {message}")
- }
- }
- })?
- .detach();
-
- let (cache_line_tx, mut cache_line_rx) = mpsc::unbounded::<FileLspDefinitions>();
-
- let cache_file = File::options()
- .append(true)
- .create(true)
- .open(lsp_definitions_path)
- .unwrap();
-
- let cache_task = cx.background_spawn(async move {
- let mut writer = BufWriter::new(cache_file);
- while let Some(line) = cache_line_rx.next().await {
- serde_json::to_writer(&mut writer, &line).unwrap();
- writer.write_all(&[b'\n']).unwrap();
- }
- writer.flush().unwrap();
- });
-
- let mut error_count = 0;
- let mut lsp_open_handles = Vec::new();
- let mut ready_languages = HashSet::default();
- for (file_index, project_path) in files[start_index..].iter().enumerate() {
- println!(
- "Processing file {} of {}: {}",
- start_index + file_index + 1,
- files.len(),
- project_path.path.display(PathStyle::Posix)
- );
-
- let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
- project.clone(),
- worktree.clone(),
- project_path.path.clone(),
- &mut ready_languages,
- cx,
- )
- .await
- .log_err() else {
- continue;
- };
- lsp_open_handles.push(lsp_open_handle);
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
- let full_range = 0..snapshot.len();
- let references = references_in_range(
- full_range,
- &snapshot.text(),
- ReferenceRegion::Nearby,
- &snapshot,
- );
-
- loop {
- let is_ready = lsp_store
- .read_with(cx, |lsp_store, _cx| {
- lsp_store
- .language_server_statuses
- .get(&language_server_id)
- .is_some_and(|status| status.pending_work.is_empty())
- })
- .unwrap();
- if is_ready {
- break;
- }
- cx.background_executor()
- .timer(Duration::from_millis(10))
- .await;
- }
-
- let mut cache_line_references = Vec::with_capacity(references.len());
-
- for reference in references {
- // TODO: Rename declaration to definition in edit_prediction_context?
- let lsp_result = project
- .update(cx, |project, cx| {
- project.definitions(&buffer, reference.range.start, cx)
- })?
- .await;
-
- match lsp_result {
- Ok(lsp_definitions) => {
- let mut targets = Vec::new();
- for target in lsp_definitions.unwrap_or_default() {
- let buffer = target.target.buffer;
- let anchor_range = target.target.range;
- buffer.read_with(cx, |buffer, cx| {
- let Some(file) = project::File::from_dyn(buffer.file()) else {
- return;
- };
- let file_worktree = file.worktree.read(cx);
- let file_worktree_id = file_worktree.id();
- // Relative paths for worktree files, absolute for all others
- let path = if worktree_id != file_worktree_id {
- file.worktree.read(cx).absolutize(&file.path)
- } else {
- file.path.as_std_path().to_path_buf()
- };
- let offset_range = anchor_range.to_offset(&buffer);
- let point_range = SerializablePoint::from_language_point_range(
- offset_range.to_point(&buffer),
- );
- targets.push(SourceRange {
- path,
- offset_range,
- point_range,
- });
- })?;
- }
-
- let point = snapshot.offset_to_point(reference.range.start);
-
- cache_line_references.push((point.into(), targets.clone()));
- definitions.insert(
- SourceLocation {
- path: project_path.path.clone(),
- point,
- },
- targets,
- );
- }
- Err(err) => {
- log::error!("Language server error: {err}");
- error_count += 1;
- }
- }
- }
-
- cache_line_tx
- .unbounded_send(FileLspDefinitions {
- path: project_path.path.as_unix_str().into(),
- references: cache_line_references,
- })
- .log_err();
- }
-
- drop(cache_line_tx);
-
- if error_count > 0 {
- log::error!("Encountered {} language server errors", error_count);
- }
-
- cache_task.await;
-
- Ok(())
-}
-
-#[derive(Serialize, Deserialize)]
-struct FileLspDefinitions {
- path: Arc<str>,
- references: Vec<(SerializablePoint, Vec<SourceRange>)>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-struct SourceRange {
- path: PathBuf,
- point_range: Range<SerializablePoint>,
- offset_range: Range<usize>,
-}
-
-/// Serializes to 1-based row and column indices.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct SerializablePoint {
- pub row: u32,
- pub column: u32,
-}
-
-impl SerializablePoint {
- pub fn into_language_point_range(range: Range<Self>) -> Range<Point> {
- range.start.into()..range.end.into()
- }
-
- pub fn from_language_point_range(range: Range<Point>) -> Range<Self> {
- range.start.into()..range.end.into()
- }
-}
-
-impl From<Point> for SerializablePoint {
- fn from(point: Point) -> Self {
- SerializablePoint {
- row: point.row + 1,
- column: point.column + 1,
- }
- }
-}
-
-impl From<SerializablePoint> for Point {
- fn from(serializable: SerializablePoint) -> Self {
- Point {
- row: serializable.row.saturating_sub(1),
- column: serializable.column.saturating_sub(1),
- }
- }
-}