.gitignore ๐
@@ -39,3 +39,6 @@ xcuserdata/
# Don't commit any secrets to the repo.
.env
.env.secret.toml
+
+# `nix build` output
+/result
Richard Feldman created
.gitignore | 3
Cargo.lock | 1136
Cargo.toml | 48
assets/icons/git_branch_plus.svg | 8
assets/icons/inception.svg | 11
assets/keymaps/default-linux.json | 21
assets/keymaps/default-macos.json | 21
assets/keymaps/default-windows.json | 24
assets/prompts/content_prompt_v2.hbs | 44
crates/acp_thread/src/acp_thread.rs | 2
crates/acp_thread/src/terminal.rs | 26
crates/agent/src/edit_agent/evals/fixtures/zode/prompt.md | 4
crates/agent/src/tests/mod.rs | 2
crates/agent/src/thread.rs | 53
crates/agent/src/tools.rs | 4
crates/agent/src/tools/edit_file_tool.rs | 6
crates/agent/src/tools/find_path_tool.rs | 2
crates/agent/src/tools/grep_tool.rs | 19
crates/agent/src/tools/read_file_tool.rs | 55
crates/agent/src/tools/web_search_tool.rs | 2
crates/agent_servers/src/acp.rs | 36
crates/agent_ui/src/acp/entry_view_state.rs | 12
crates/agent_ui/src/acp/message_editor.rs | 179
crates/agent_ui/src/acp/thread_view.rs | 8
crates/agent_ui/src/agent_configuration.rs | 3
crates/agent_ui/src/agent_model_selector.rs | 2
crates/agent_ui/src/buffer_codegen.rs | 299
crates/agent_ui/src/inline_assistant.rs | 111
crates/agent_ui/src/inline_prompt_editor.rs | 141
crates/agent_ui/src/text_thread_editor.rs | 92
crates/agent_ui/src/ui/agent_notification.rs | 3
crates/anthropic/src/anthropic.rs | 143
crates/anthropic/src/batches.rs | 190
crates/assistant_text_thread/src/text_thread.rs | 16
crates/assistant_text_thread/src/text_thread_store.rs | 65
crates/cloud_llm_client/src/predict_edits_v3.rs | 88
crates/cloud_zeta2_prompt/Cargo.toml | 5
crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs | 680
crates/cloud_zeta2_prompt/src/retrieval_prompt.rs | 244
crates/codestral/Cargo.toml | 2
crates/codestral/src/codestral.rs | 13
crates/collab/Cargo.toml | 2
crates/collab/migrations.sqlite/20221109000000_test_schema.sql | 2
crates/collab/migrations/20251203234258_add_remote_urls_to_project_repositories.sql | 2
crates/collab/src/db/queries/projects.rs | 6
crates/collab/src/db/queries/rooms.rs | 2
crates/collab/src/db/tables/project_repository.rs | 2
crates/collab/src/tests.rs | 17
crates/collab/src/tests/editor_tests.rs | 12
crates/collab/src/tests/integration_tests.rs | 4
crates/copilot/Cargo.toml | 2
crates/copilot/src/copilot.rs | 4
crates/copilot/src/copilot_edit_prediction_delegate.rs | 22
crates/dap_adapters/src/python.rs | 2
crates/debugger_ui/Cargo.toml | 2
crates/debugger_ui/src/debugger_panel.rs | 193
crates/debugger_ui/src/debugger_ui.rs | 13
crates/debugger_ui/src/new_process_modal.rs | 3
crates/debugger_ui/src/session/running.rs | 2
crates/debugger_ui/src/session/running/loaded_source_list.rs | 4
crates/debugger_ui/src/session/running/module_list.rs | 4
crates/debugger_ui/src/session/running/stack_frame_list.rs | 43
crates/debugger_ui/src/session/running/variable_list.rs | 7
crates/debugger_ui/src/tests/inline_values.rs | 121
crates/debugger_ui/src/tests/stack_frame_list.rs | 182
crates/edit_prediction/Cargo.toml | 62
crates/edit_prediction/license_examples/0bsd.txt | 0
crates/edit_prediction/license_examples/apache-2.0-ex0.txt | 0
crates/edit_prediction/license_examples/apache-2.0-ex1.txt | 0
crates/edit_prediction/license_examples/apache-2.0-ex2.txt | 0
crates/edit_prediction/license_examples/apache-2.0-ex3.txt | 0
crates/edit_prediction/license_examples/apache-2.0-ex4.txt | 0
crates/edit_prediction/license_examples/bsd-1-clause.txt | 0
crates/edit_prediction/license_examples/bsd-2-clause-ex0.txt | 0
crates/edit_prediction/license_examples/bsd-3-clause-ex0.txt | 0
crates/edit_prediction/license_examples/bsd-3-clause-ex1.txt | 0
crates/edit_prediction/license_examples/bsd-3-clause-ex2.txt | 0
crates/edit_prediction/license_examples/bsd-3-clause-ex3.txt | 0
crates/edit_prediction/license_examples/bsd-3-clause-ex4.txt | 0
crates/edit_prediction/license_examples/isc.txt | 0
crates/edit_prediction/license_examples/mit-ex0.txt | 0
crates/edit_prediction/license_examples/mit-ex1.txt | 0
crates/edit_prediction/license_examples/mit-ex2.txt | 0
crates/edit_prediction/license_examples/mit-ex3.txt | 0
crates/edit_prediction/license_examples/upl-1.0.txt | 0
crates/edit_prediction/license_examples/zlib-ex0.txt | 0
crates/edit_prediction/license_patterns/0bsd-pattern | 0
crates/edit_prediction/license_patterns/apache-2.0-pattern | 0
crates/edit_prediction/license_patterns/apache-2.0-reference-pattern | 0
crates/edit_prediction/license_patterns/bsd-pattern | 0
crates/edit_prediction/license_patterns/isc-pattern | 0
crates/edit_prediction/license_patterns/mit-pattern | 0
crates/edit_prediction/license_patterns/upl-1.0-pattern | 0
crates/edit_prediction/license_patterns/zlib-pattern | 0
crates/edit_prediction/src/cursor_excerpt.rs | 78
crates/edit_prediction/src/edit_prediction.rs | 2074
crates/edit_prediction/src/edit_prediction_tests.rs | 1806
crates/edit_prediction/src/license_detection.rs | 0
crates/edit_prediction/src/mercury.rs | 340
crates/edit_prediction/src/onboarding_modal.rs | 0
crates/edit_prediction/src/open_ai_response.rs | 31
crates/edit_prediction/src/prediction.rs | 2
crates/edit_prediction/src/sweep_ai.rs | 30
crates/edit_prediction/src/udiff.rs | 0
crates/edit_prediction/src/xml_edits.rs | 0
crates/edit_prediction/src/zed_edit_prediction_delegate.rs | 114
crates/edit_prediction/src/zeta1.rs | 183
crates/edit_prediction/src/zeta2.rs | 327
crates/edit_prediction_cli/Cargo.toml | 17
crates/edit_prediction_cli/LICENSE-GPL | 0
crates/edit_prediction_cli/build.rs | 0
crates/edit_prediction_cli/src/evaluate.rs | 14
crates/edit_prediction_cli/src/example.rs | 255
crates/edit_prediction_cli/src/headless.rs | 0
crates/edit_prediction_cli/src/main.rs | 160
crates/edit_prediction_cli/src/metrics.rs | 4
crates/edit_prediction_cli/src/paths.rs | 0
crates/edit_prediction_cli/src/predict.rs | 124
crates/edit_prediction_cli/src/source_location.rs | 0
crates/edit_prediction_cli/src/training/context.rs | 89
crates/edit_prediction_cli/src/training/distill.rs | 94
crates/edit_prediction_cli/src/training/llm_client.rs | 417
crates/edit_prediction_cli/src/training/mod.rs | 4
crates/edit_prediction_cli/src/training/teacher.prompt.md | 48
crates/edit_prediction_cli/src/training/teacher.rs | 266
crates/edit_prediction_cli/src/util.rs | 28
crates/edit_prediction_context/Cargo.toml | 23
crates/edit_prediction_context/src/assemble_excerpts.rs | 161
crates/edit_prediction_context/src/declaration.rs | 350
crates/edit_prediction_context/src/declaration_scoring.rs | 539
crates/edit_prediction_context/src/edit_prediction_context.rs | 757
crates/edit_prediction_context/src/edit_prediction_context_tests.rs | 510
crates/edit_prediction_context/src/excerpt.rs | 93
crates/edit_prediction_context/src/fake_definition_lsp.rs | 329
crates/edit_prediction_context/src/imports.rs | 1319
crates/edit_prediction_context/src/outline.rs | 126
crates/edit_prediction_context/src/reference.rs | 173
crates/edit_prediction_context/src/syntax_index.rs | 1069
crates/edit_prediction_context/src/text_similarity.rs | 314
crates/edit_prediction_types/Cargo.toml | 17
crates/edit_prediction_types/LICENSE-GPL | 0
crates/edit_prediction_types/src/edit_prediction_types.rs | 298
crates/edit_prediction_ui/Cargo.toml | 16
crates/edit_prediction_ui/LICENSE-GPL | 0
crates/edit_prediction_ui/src/edit_prediction_button.rs | 180
crates/edit_prediction_ui/src/edit_prediction_context_view.rs | 389
crates/edit_prediction_ui/src/edit_prediction_ui.rs | 128
crates/edit_prediction_ui/src/external_provider_api_token_modal.rs | 36
crates/edit_prediction_ui/src/rate_prediction_modal.rs | 59
crates/editor/Cargo.toml | 6
crates/editor/src/actions.rs | 19
crates/editor/src/display_map/block_map.rs | 42
crates/editor/src/display_map/crease_map.rs | 17
crates/editor/src/display_map/custom_highlights.rs | 3
crates/editor/src/display_map/fold_map.rs | 35
crates/editor/src/display_map/inlay_map.rs | 28
crates/editor/src/display_map/invisibles.rs | 1
crates/editor/src/display_map/tab_map.rs | 27
crates/editor/src/display_map/wrap_map.rs | 38
crates/editor/src/edit_prediction_tests.rs | 64
crates/editor/src/editor.rs | 228
crates/editor/src/editor_tests.rs | 242
crates/editor/src/element.rs | 24
crates/editor/src/git/blame.rs | 22
crates/editor/src/hover_links.rs | 2
crates/editor/src/hover_popover.rs | 8
crates/editor/src/indent_guides.rs | 4
crates/editor/src/items.rs | 20
crates/editor/src/mouse_context_menu.rs | 5
crates/editor/src/selections_collection.rs | 29
crates/extension/Cargo.toml | 1
crates/extension/src/extension_builder.rs | 114
crates/extension_host/src/extension_store_test.rs | 16
crates/extensions_ui/src/extensions_ui.rs | 3
crates/feature_flags/src/flags.rs | 12
crates/fs/src/fake_git_repo.rs | 17
crates/fuzzy/src/matcher.rs | 14
crates/git/src/blame.rs | 8
crates/git/src/repository.rs | 31
crates/git_hosting_providers/Cargo.toml | 1
crates/git_hosting_providers/src/git_hosting_providers.rs | 4
crates/git_hosting_providers/src/providers/bitbucket.rs | 163
crates/git_hosting_providers/src/providers/sourcehut.rs | 196
crates/git_hosting_providers/src/settings.rs | 7
crates/git_ui/Cargo.toml | 7
crates/git_ui/src/blame_ui.rs | 29
crates/git_ui/src/branch_picker.rs | 591
crates/git_ui/src/commit_view.rs | 241
crates/git_ui/src/conflict_view.rs | 4
crates/git_ui/src/file_history_view.rs | 14
crates/git_ui/src/project_diff.rs | 91
crates/gpui/src/app/entity_map.rs | 112
crates/gpui/src/geometry.rs | 6
crates/gpui/src/taffy.rs | 7
crates/icons/src/icons.rs | 16
crates/json_schema_store/src/json_schema_store.rs | 54
crates/language/src/buffer.rs | 14
crates/language/src/buffer_tests.rs | 240
crates/language/src/language.rs | 52
crates/language/src/language_registry.rs | 34
crates/language/src/language_settings.rs | 17
crates/language/src/outline.rs | 50
crates/language/src/syntax_map.rs | 13
crates/language/src/syntax_map/syntax_map_tests.rs | 66
crates/language_extension/src/extension_lsp_adapter.rs | 2
crates/language_model/src/language_model.rs | 34
crates/language_models/src/provider/open_ai.rs | 2
crates/languages/src/javascript/highlights.scm | 39
crates/languages/src/jsdoc/highlights.scm | 1
crates/languages/src/json.rs | 15
crates/languages/src/lib.rs | 2
crates/languages/src/markdown/config.toml | 1
crates/languages/src/markdown/indents.scm | 1
crates/languages/src/python.rs | 4
crates/languages/src/tailwind.rs | 40
crates/languages/src/tsx/highlights.scm | 93
crates/languages/src/typescript.rs | 6
crates/languages/src/typescript/highlights.scm | 94
crates/languages/src/vtsls.rs | 6
crates/markdown/src/markdown.rs | 30
crates/markdown_preview/Cargo.toml | 1
crates/markdown_preview/src/markdown_parser.rs | 21
crates/markdown_preview/src/markdown_preview_view.rs | 2
crates/multi_buffer/Cargo.toml | 5
crates/multi_buffer/src/multi_buffer.rs | 10
crates/multi_buffer/src/path_key.rs | 872
crates/open_ai/src/open_ai.rs | 3
crates/outline/src/outline.rs | 88
crates/outline_panel/src/outline_panel.rs | 298
crates/project/Cargo.toml | 5
crates/project/src/agent_server_store.rs | 149
crates/project/src/debugger/dap_store.rs | 2
crates/project/src/debugger/session.rs | 267
crates/project/src/git_store.rs | 19
crates/project/src/git_store/branch_diff.rs | 3
crates/project/src/git_store/conflict_set.rs | 72
crates/project/src/lsp_store.rs | 11
crates/project/src/project_tests.rs | 20
crates/project/src/terminals.rs | 4
crates/prompt_store/src/prompts.rs | 92
crates/proto/proto/git.proto | 4
crates/recent_projects/src/recent_projects.rs | 49
crates/remote/src/transport.rs | 109
crates/remote/src/transport/ssh.rs | 65
crates/remote/src/transport/wsl.rs | 24
crates/remote_server/src/remote_editing_tests.rs | 6
crates/repl/src/kernels/mod.rs | 2
crates/rope/Cargo.toml | 5
crates/rope/src/rope.rs | 2
crates/settings/src/keymap_file.rs | 5
crates/settings/src/settings_content/language.rs | 10
crates/settings/src/settings_content/project.rs | 5
crates/settings/src/settings_content/terminal.rs | 5
crates/settings/src/settings_store.rs | 3
crates/settings_ui/src/settings_ui.rs | 5
crates/snippet_provider/src/format.rs | 3
crates/sum_tree/Cargo.toml | 5
crates/sum_tree/src/cursor.rs | 5
crates/sum_tree/src/sum_tree.rs | 3
crates/supermaven/Cargo.toml | 2
crates/supermaven/src/supermaven.rs | 4
crates/supermaven/src/supermaven_edit_prediction_delegate.rs | 14
crates/task/src/debug_format.rs | 1
crates/task/src/task_template.rs | 3
crates/terminal/src/terminal_hyperlinks.rs | 51
crates/terminal/src/terminal_settings.rs | 2
crates/terminal_view/src/terminal_panel.rs | 2
crates/terminal_view/src/terminal_tab_tooltip.rs | 36
crates/terminal_view/src/terminal_view.rs | 30
crates/text/src/anchor.rs | 8
crates/title_bar/src/collab.rs | 7
crates/ui/src/components/button/split_button.rs | 32
crates/ui/src/components/data_table.rs | 22
crates/ui/src/components/toggle.rs | 117
crates/util/src/command.rs | 2
crates/util/src/schemars.rs | 17
crates/util/src/util.rs | 2
crates/vim/src/normal/paste.rs | 2
crates/vim/src/normal/yank.rs | 14
crates/vim/src/object.rs | 7
crates/workspace/src/persistence.rs | 10
crates/workspace/src/workspace.rs | 25
crates/worktree/src/worktree.rs | 44
crates/worktree/src/worktree_tests.rs | 169
crates/x_ai/src/x_ai.rs | 38
crates/zed/Cargo.toml | 12
crates/zed/src/main.rs | 7
crates/zed/src/zed.rs | 71
crates/zed/src/zed/app_menus.rs | 5
crates/zed/src/zed/edit_prediction_registry.rs | 51
crates/zeta/Cargo.toml | 84
crates/zeta/src/assemble_excerpts.rs | 173
crates/zeta/src/retrieval_search.rs | 642
crates/zeta/src/zeta.rs | 4062
crates/zeta/src/zeta1/input_excerpt.rs | 231
crates/zeta/src/zeta_tests.rs | 671
crates/zeta2_tools/Cargo.toml | 49
crates/zeta2_tools/src/zeta2_context_view.rs | 438
crates/zeta2_tools/src/zeta2_tools.rs | 1023
crates/zeta_cli/src/syntax_retrieval_stats.rs | 1260
crates/zlog/src/filter.rs | 30
crates/zlog/src/sink.rs | 6
crates/zlog/src/zlog.rs | 20
crates/ztracing/Cargo.toml | 20
crates/ztracing/LICENSE-AGPL | 1
crates/ztracing/LICENSE-APACHE | 1
crates/ztracing/LICENSE-GPL | 0
crates/ztracing/build.rs | 9
crates/ztracing/src/lib.rs | 16
crates/ztracing_macro/Cargo.toml | 11
crates/ztracing_macro/LICENSE-AGPL | 1
crates/ztracing_macro/LICENSE-APACHE | 1
crates/ztracing_macro/LICENSE-GPL | 1
crates/ztracing_macro/src/lib.rs | 7
docs/src/SUMMARY.md | 1
docs/src/development/debuggers.md | 2
docs/src/languages/astro.md | 2
docs/src/languages/php.md | 114
docs/src/languages/rego.md | 2
docs/src/performance.md | 52
docs/src/remote-development.md | 28
docs/src/tab-switcher.md | 46
nix/build.nix | 1
323 files changed, 15,102 insertions(+), 18,377 deletions(-)
@@ -39,3 +39,6 @@ xcuserdata/
# Don't commit any secrets to the repo.
.env
.env.secret.toml
+
+# `nix build` output
+/result
@@ -211,14 +211,14 @@ dependencies = [
"worktree",
"zed_env_vars",
"zlog",
- "zstd 0.11.2+zstd.1.5.2",
+ "zstd",
]
[[package]]
name = "agent-client-protocol"
-version = "0.8.0"
+version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3e639d6b544ad39f5b4e05802db5eb04e1518284eb05fda1839931003e0244c8"
+checksum = "c2ffe7d502c1e451aafc5aff655000f84d09c9af681354ac0012527009b1af13"
dependencies = [
"agent-client-protocol-schema",
"anyhow",
@@ -233,15 +233,16 @@ dependencies = [
[[package]]
name = "agent-client-protocol-schema"
-version = "0.9.1"
+version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f182f5e14bef8232b239719bd99166bb11e986c08fc211f28e392f880d3093ba"
+checksum = "8af81cc2d5c3f9c04f73db452efd058333735ba9d51c2cf7ef33c9fee038e7e6"
dependencies = [
"anyhow",
"derive_more 2.0.1",
"schemars",
"serde",
"serde_json",
+ "strum 0.27.2",
]
[[package]]
@@ -680,21 +681,6 @@ dependencies = [
"syn 2.0.106",
]
-[[package]]
-name = "argminmax"
-version = "0.6.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "70f13d10a41ac8d2ec79ee34178d61e6f47a29c2edfe7ef1721c7383b0359e65"
-dependencies = [
- "num-traits",
-]
-
-[[package]]
-name = "array-init-cursor"
-version = "0.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ed51fe0f224d1d4ea768be38c51f9f831dee9d05c163c11fba0b8c44387b1fc3"
-
[[package]]
name = "arraydeque"
version = "0.5.1"
@@ -1278,15 +1264,6 @@ dependencies = [
"num-traits",
]
-[[package]]
-name = "atoi_simd"
-version = "0.16.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c2a49e05797ca52e312a0c658938b7d00693ef037799ef7187678f212d7684cf"
-dependencies = [
- "debug_unsafe",
-]
-
[[package]]
name = "atomic"
version = "0.5.3"
@@ -2070,26 +2047,6 @@ dependencies = [
"serde",
]
-[[package]]
-name = "bincode"
-version = "2.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740"
-dependencies = [
- "bincode_derive",
- "serde",
- "unty",
-]
-
-[[package]]
-name = "bincode_derive"
-version = "2.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09"
-dependencies = [
- "virtue",
-]
-
[[package]]
name = "bindgen"
version = "0.71.1"
@@ -2242,19 +2199,6 @@ dependencies = [
"profiling",
]
-[[package]]
-name = "blake3"
-version = "1.8.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0"
-dependencies = [
- "arrayref",
- "arrayvec",
- "cc",
- "cfg-if",
- "constant_time_eq 0.3.1",
-]
-
[[package]]
name = "block"
version = "0.1.6"
@@ -2344,12 +2288,6 @@ dependencies = [
"syn 2.0.106",
]
-[[package]]
-name = "boxcar"
-version = "0.2.14"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "36f64beae40a84da1b4b26ff2761a5b895c12adc41dc25aaee1c4f2bbfe97a6e"
-
[[package]]
name = "breadcrumbs"
version = "0.1.0"
@@ -2516,9 +2454,6 @@ name = "bytes"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
-dependencies = [
- "serde",
-]
[[package]]
name = "bytes-utils"
@@ -2805,15 +2740,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
-[[package]]
-name = "castaway"
-version = "0.2.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
-dependencies = [
- "rustversion",
-]
-
[[package]]
name = "cbc"
version = "0.1.2"
@@ -2942,16 +2868,6 @@ dependencies = [
"windows-link 0.2.1",
]
-[[package]]
-name = "chrono-tz"
-version = "0.10.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3"
-dependencies = [
- "chrono",
- "phf 0.12.1",
-]
-
[[package]]
name = "chunked_transfer"
version = "1.5.0"
@@ -3201,12 +3117,7 @@ dependencies = [
"anyhow",
"cloud_llm_client",
"indoc",
- "ordered-float 2.10.1",
- "rustc-hash 2.1.1",
- "schemars",
"serde",
- "serde_json",
- "strum 0.27.2",
]
[[package]]
@@ -3314,8 +3225,8 @@ name = "codestral"
version = "0.1.0"
dependencies = [
"anyhow",
- "edit_prediction",
"edit_prediction_context",
+ "edit_prediction_types",
"futures 0.3.31",
"gpui",
"http_client",
@@ -3505,17 +3416,6 @@ dependencies = [
"memchr",
]
-[[package]]
-name = "comfy-table"
-version = "7.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b03b7db8e0b4b2fdad6c551e634134e99ec000e5c8c3b6856c65e8bbaded7a3b"
-dependencies = [
- "crossterm",
- "unicode-segmentation",
- "unicode-width",
-]
-
[[package]]
name = "command-fds"
version = "0.3.2"
@@ -3569,21 +3469,6 @@ dependencies = [
"workspace",
]
-[[package]]
-name = "compact_str"
-version = "0.9.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a"
-dependencies = [
- "castaway",
- "cfg-if",
- "itoa",
- "rustversion",
- "ryu",
- "serde",
- "static_assertions",
-]
-
[[package]]
name = "component"
version = "0.1.0"
@@ -3689,12 +3574,6 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc"
-[[package]]
-name = "constant_time_eq"
-version = "0.3.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
-
[[package]]
name = "context_server"
version = "0.1.0"
@@ -3747,7 +3626,7 @@ dependencies = [
"command_palette_hooks",
"ctor",
"dirs 4.0.0",
- "edit_prediction",
+ "edit_prediction_types",
"editor",
"fs",
"futures 0.3.31",
@@ -4160,7 +4039,7 @@ dependencies = [
name = "crashes"
version = "0.1.0"
dependencies = [
- "bincode 1.3.3",
+ "bincode",
"cfg-if",
"crash-handler",
"extension_host",
@@ -4174,7 +4053,7 @@ dependencies = [
"smol",
"system_specs",
"windows 0.61.3",
- "zstd 0.11.2+zstd.1.5.2",
+ "zstd",
]
[[package]]
@@ -4319,29 +4198,6 @@ version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
-[[package]]
-name = "crossterm"
-version = "0.29.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b"
-dependencies = [
- "bitflags 2.9.4",
- "crossterm_winapi",
- "document-features",
- "parking_lot",
- "rustix 1.1.2",
- "winapi",
-]
-
-[[package]]
-name = "crossterm_winapi"
-version = "0.9.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b"
-dependencies = [
- "winapi",
-]
-
[[package]]
name = "crunchy"
version = "0.2.4"
@@ -4696,12 +4552,6 @@ dependencies = [
"util",
]
-[[package]]
-name = "debug_unsafe"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "85d3cef41d236720ed453e102153a53e4cc3d2fde848c0078a50cf249e8e3e5b"
-
[[package]]
name = "debugger_tools"
version = "0.1.0"
@@ -4734,6 +4584,7 @@ dependencies = [
"db",
"debugger_tools",
"editor",
+ "feature_flags",
"file_icons",
"futures 0.3.31",
"fuzzy",
@@ -5109,15 +4960,6 @@ dependencies = [
"zlog",
]
-[[package]]
-name = "document-features"
-version = "0.2.11"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "95249b50c6c185bee49034bcb378a49dc2b5dff0be90ff6616d31d64febab05d"
-dependencies = [
- "litrs",
-]
-
[[package]]
name = "documented"
version = "0.9.2"
@@ -5267,144 +5109,252 @@ dependencies = [
name = "edit_prediction"
version = "0.1.0"
dependencies = [
- "client",
- "gpui",
- "language",
-]
-
-[[package]]
-name = "edit_prediction_button"
-version = "0.1.0"
-dependencies = [
+ "ai_onboarding",
"anyhow",
+ "arrayvec",
+ "brotli",
"client",
+ "clock",
+ "cloud_api_types",
"cloud_llm_client",
- "codestral",
+ "cloud_zeta2_prompt",
+ "collections",
"copilot",
- "edit_prediction",
- "editor",
+ "credentials_provider",
+ "ctor",
+ "db",
+ "edit_prediction_context",
+ "edit_prediction_types",
"feature_flags",
"fs",
"futures 0.3.31",
"gpui",
"indoc",
+ "itertools 0.14.0",
"language",
+ "language_model",
+ "log",
"lsp",
"menu",
- "paths",
+ "open_ai",
+ "parking_lot",
+ "postage",
+ "pretty_assertions",
"project",
+ "rand 0.9.2",
"regex",
+ "release_channel",
+ "semver",
+ "serde",
"serde_json",
"settings",
- "supermaven",
+ "smol",
+ "strsim",
+ "strum 0.27.2",
"telemetry",
- "theme",
+ "telemetry_events",
+ "thiserror 2.0.17",
"ui",
- "ui_input",
"util",
+ "uuid",
"workspace",
+ "worktree",
"zed_actions",
- "zeta",
+ "zlog",
]
[[package]]
-name = "edit_prediction_context"
+name = "edit_prediction_cli"
version = "0.1.0"
dependencies = [
+ "anthropic",
"anyhow",
- "arrayvec",
+ "chrono",
"clap",
+ "client",
"cloud_llm_client",
+ "cloud_zeta2_prompt",
"collections",
+ "debug_adapter_extension",
+ "edit_prediction",
+ "edit_prediction_context",
+ "extension",
+ "fs",
"futures 0.3.31",
"gpui",
- "hashbrown 0.15.5",
+ "gpui_tokio",
+ "http_client",
"indoc",
- "itertools 0.14.0",
"language",
+ "language_extension",
+ "language_model",
+ "language_models",
+ "languages",
"log",
- "ordered-float 2.10.1",
- "postage",
+ "node_runtime",
+ "paths",
"pretty_assertions",
"project",
- "regex",
+ "prompt_store",
+ "pulldown-cmark 0.12.2",
+ "release_channel",
+ "reqwest_client",
"serde",
"serde_json",
"settings",
- "slotmap",
- "strum 0.27.2",
- "text",
- "tree-sitter",
- "tree-sitter-c",
- "tree-sitter-cpp",
- "tree-sitter-go",
+ "shellexpand 2.1.2",
+ "smol",
+ "sqlez",
+ "sqlez_macros",
+ "terminal_view",
+ "toml 0.8.23",
"util",
+ "watch",
"zlog",
]
[[package]]
-name = "editor"
+name = "edit_prediction_context"
version = "0.1.0"
dependencies = [
- "aho-corasick",
"anyhow",
- "assets",
- "buffer_diff",
- "client",
- "clock",
+ "cloud_llm_client",
"collections",
- "convert_case 0.8.0",
- "criterion",
- "ctor",
- "dap",
- "db",
- "edit_prediction",
- "emojis",
- "feature_flags",
- "file_icons",
- "fs",
+ "env_logger 0.11.8",
"futures 0.3.31",
- "fuzzy",
- "git",
"gpui",
- "http_client",
"indoc",
- "itertools 0.14.0",
"language",
- "languages",
- "linkify",
"log",
"lsp",
- "markdown",
- "menu",
- "multi_buffer",
- "ordered-float 2.10.1",
"parking_lot",
"pretty_assertions",
"project",
- "rand 0.9.2",
- "regex",
- "release_channel",
- "rope",
- "rpc",
- "schemars",
- "semver",
"serde",
"serde_json",
"settings",
"smallvec",
- "smol",
- "snippet",
- "sum_tree",
- "task",
- "telemetry",
- "tempfile",
+ "text",
+ "tree-sitter",
+ "util",
+ "zlog",
+]
+
+[[package]]
+name = "edit_prediction_types"
+version = "0.1.0"
+dependencies = [
+ "client",
+ "gpui",
+ "language",
+]
+
+[[package]]
+name = "edit_prediction_ui"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "buffer_diff",
+ "client",
+ "cloud_llm_client",
+ "cloud_zeta2_prompt",
+ "codestral",
+ "command_palette_hooks",
+ "copilot",
+ "edit_prediction",
+ "edit_prediction_types",
+ "editor",
+ "feature_flags",
+ "fs",
+ "futures 0.3.31",
+ "gpui",
+ "indoc",
+ "language",
+ "lsp",
+ "markdown",
+ "menu",
+ "multi_buffer",
+ "paths",
+ "project",
+ "regex",
+ "serde_json",
+ "settings",
+ "supermaven",
+ "telemetry",
+ "text",
+ "theme",
+ "ui",
+ "ui_input",
+ "util",
+ "workspace",
+ "zed_actions",
+]
+
+[[package]]
+name = "editor"
+version = "0.1.0"
+dependencies = [
+ "aho-corasick",
+ "anyhow",
+ "assets",
+ "buffer_diff",
+ "client",
+ "clock",
+ "collections",
+ "convert_case 0.8.0",
+ "criterion",
+ "ctor",
+ "dap",
+ "db",
+ "edit_prediction_types",
+ "emojis",
+ "feature_flags",
+ "file_icons",
+ "fs",
+ "futures 0.3.31",
+ "fuzzy",
+ "git",
+ "gpui",
+ "http_client",
+ "indoc",
+ "itertools 0.14.0",
+ "language",
+ "languages",
+ "linkify",
+ "log",
+ "lsp",
+ "markdown",
+ "menu",
+ "multi_buffer",
+ "ordered-float 2.10.1",
+ "parking_lot",
+ "pretty_assertions",
+ "project",
+ "rand 0.9.2",
+ "regex",
+ "release_channel",
+ "rope",
+ "rpc",
+ "schemars",
+ "semver",
+ "serde",
+ "serde_json",
+ "settings",
+ "smallvec",
+ "smol",
+ "snippet",
+ "sum_tree",
+ "task",
+ "telemetry",
+ "tempfile",
"text",
"theme",
"time",
+ "tracing",
"tree-sitter-bash",
"tree-sitter-c",
"tree-sitter-html",
+ "tree-sitter-md",
"tree-sitter-python",
"tree-sitter-rust",
"tree-sitter-typescript",
@@ -5420,6 +5370,7 @@ dependencies = [
"workspace",
"zed_actions",
"zlog",
+ "ztracing",
]
[[package]]
@@ -5695,12 +5646,6 @@ dependencies = [
"windows-sys 0.48.0",
]
-[[package]]
-name = "ethnum"
-version = "1.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ca81e6b4777c89fd810c25a4be2b1bd93ea034fbe58e6a75216a34c6b82c539b"
-
[[package]]
name = "euclid"
version = "0.22.11"
@@ -5864,6 +5809,7 @@ dependencies = [
"serde",
"serde_json",
"task",
+ "tempfile",
"toml 0.8.23",
"url",
"util",
@@ -5993,12 +5939,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
-[[package]]
-name = "fallible-streaming-iterator"
-version = "0.1.9"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
-
[[package]]
name = "fancy-regex"
version = "0.16.2"
@@ -6010,12 +5950,6 @@ dependencies = [
"regex-syntax",
]
-[[package]]
-name = "fast-float2"
-version = "0.2.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55"
-
[[package]]
name = "fast-srgb8"
version = "1.0.0"
@@ -6191,7 +6125,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9"
dependencies = [
"crc32fast",
- "libz-rs-sys",
"miniz_oxide",
]
@@ -6448,16 +6381,6 @@ dependencies = [
"winapi",
]
-[[package]]
-name = "fs4"
-version = "0.13.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8640e34b88f7652208ce9e88b1a37a2ae95227d84abec377ccd3c5cfeb141ed4"
-dependencies = [
- "rustix 1.1.2",
- "windows-sys 0.59.0",
-]
-
[[package]]
name = "fs_benchmarks"
version = "0.1.0"
@@ -6918,6 +6841,20 @@ dependencies = [
"seq-macro",
]
+[[package]]
+name = "generator"
+version = "0.8.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "605183a538e3e2a9c1038635cc5c2d194e2ee8fd0d1b66b8349fad7dbacce5a2"
+dependencies = [
+ "cc",
+ "cfg-if",
+ "libc",
+ "log",
+ "rustversion",
+ "windows 0.61.3",
+]
+
[[package]]
name = "generic-array"
version = "0.14.7"
@@ -7079,6 +7016,7 @@ dependencies = [
"gpui",
"http_client",
"indoc",
+ "itertools 0.14.0",
"pretty_assertions",
"regex",
"serde",
@@ -7136,6 +7074,7 @@ dependencies = [
"theme",
"time",
"time_format",
+ "tracing",
"ui",
"unindent",
"util",
@@ -7145,6 +7084,7 @@ dependencies = [
"zed_actions",
"zeroize",
"zlog",
+ "ztracing",
]
[[package]]
@@ -7521,7 +7461,6 @@ dependencies = [
"allocator-api2",
"equivalent",
"foldhash 0.1.5",
- "rayon",
"serde",
]
@@ -7633,7 +7572,7 @@ version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13c255bdf46e07fb840d120a36dcc81f385140d7191c76a7391672675c01a55d"
dependencies = [
- "bincode 1.3.3",
+ "bincode",
"byteorder",
"heed-traits",
"serde",
@@ -8223,7 +8162,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5"
dependencies = [
"equivalent",
- "hashbrown 0.15.5",
+ "hashbrown 0.16.1",
"serde",
"serde_core",
]
@@ -8393,7 +8332,7 @@ version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8251fb7bcd9ccd3725ed8deae9fe7db8e586495c9eb5b0c52e6233e5e75ea"
dependencies = [
- "bincode 1.3.3",
+ "bincode",
"crossbeam-channel",
"fnv",
"lazy_static",
@@ -9234,15 +9173,6 @@ dependencies = [
"webrtc-sys",
]
-[[package]]
-name = "libz-rs-sys"
-version = "0.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd"
-dependencies = [
- "zlib-rs",
-]
-
[[package]]
name = "libz-sys"
version = "1.1.22"
@@ -9305,12 +9235,6 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
-[[package]]
-name = "litrs"
-version = "0.4.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f5e54036fe321fd421e10d732f155734c4e4afd610dd556d9a82833ab3ee0bed"
-
[[package]]
name = "livekit"
version = "0.7.8"
@@ -9480,6 +9404,19 @@ dependencies = [
"value-bag",
]
+[[package]]
+name = "loom"
+version = "0.7.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca"
+dependencies = [
+ "cfg-if",
+ "generator",
+ "scoped-tls",
+ "tracing",
+ "tracing-subscriber",
+]
+
[[package]]
name = "loop9"
version = "0.1.5"
@@ -9602,25 +9539,6 @@ dependencies = [
"num-traits",
]
-[[package]]
-name = "lz4"
-version = "1.28.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4"
-dependencies = [
- "lz4-sys",
-]
-
-[[package]]
-name = "lz4-sys"
-version = "1.11.1+lz4-1.10.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6"
-dependencies = [
- "cc",
- "libc",
-]
-
[[package]]
name = "mac"
version = "0.1.1"
@@ -10169,9 +10087,11 @@ dependencies = [
"sum_tree",
"text",
"theme",
+ "tracing",
"tree-sitter",
"util",
"zlog",
+ "ztracing",
]
[[package]]
@@ -10483,15 +10403,6 @@ name = "notify-types"
version = "2.0.0"
source = "git+https://github.com/zed-industries/notify.git?rev=b4588b2e5aee68f4c0e100f140e808cbce7b1419#b4588b2e5aee68f4c0e100f140e808cbce7b1419"
-[[package]]
-name = "now"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6d89e9874397a1f0a52fc1f197a8effd9735223cb2390e9dcc83ac6cd02923d0"
-dependencies = [
- "chrono",
-]
-
[[package]]
name = "ntapi"
version = "0.4.1"
@@ -10887,41 +10798,6 @@ dependencies = [
"memchr",
]
-[[package]]
-name = "object_store"
-version = "0.12.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c1be0c6c22ec0817cdc77d3842f721a17fd30ab6965001415b5402a74e6b740"
-dependencies = [
- "async-trait",
- "base64 0.22.1",
- "bytes 1.10.1",
- "chrono",
- "form_urlencoded",
- "futures 0.3.31",
- "http 1.3.1",
- "http-body-util",
- "humantime",
- "hyper 1.7.0",
- "itertools 0.14.0",
- "parking_lot",
- "percent-encoding",
- "quick-xml 0.38.3",
- "rand 0.9.2",
- "reqwest 0.12.24",
- "ring",
- "serde",
- "serde_json",
- "serde_urlencoded",
- "thiserror 2.0.17",
- "tokio",
- "tracing",
- "url",
- "walkdir",
- "wasm-bindgen-futures",
- "web-time",
-]
-
[[package]]
name = "ollama"
version = "0.1.0"
@@ -12156,16 +12032,6 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6"
-[[package]]
-name = "planus"
-version = "1.1.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3daf8e3d4b712abe1d690838f6e29fb76b76ea19589c4afa39ec30e12f62af71"
-dependencies = [
- "array-init-cursor",
- "hashbrown 0.15.5",
-]
-
[[package]]
name = "plist"
version = "1.8.0"
@@ -12234,544 +12100,30 @@ dependencies = [
]
[[package]]
-name = "polars"
-version = "0.51.0"
+name = "polling"
+version = "3.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a5f7feb5d56b954e691dff22a8b2d78d77433dcc93c35fe21c3777fdc121b697"
+checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218"
dependencies = [
- "getrandom 0.2.16",
- "getrandom 0.3.4",
- "polars-arrow",
- "polars-core",
- "polars-error",
- "polars-io",
- "polars-lazy",
- "polars-ops",
- "polars-parquet",
- "polars-sql",
- "polars-time",
- "polars-utils",
- "version_check",
+ "cfg-if",
+ "concurrent-queue",
+ "hermit-abi",
+ "pin-project-lite",
+ "rustix 1.1.2",
+ "windows-sys 0.61.2",
]
[[package]]
-name = "polars-arrow"
-version = "0.51.0"
+name = "pollster"
+version = "0.2.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5da3b0203fd7ee5720aa0b5e790b591aa5d3f41c3ed2c34a3a393382198af2f7"
+
+[[package]]
+name = "pori"
+version = "0.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "32b4fed2343961b3eea3db2cee165540c3e1ad9d5782350cc55a9e76cf440148"
-dependencies = [
- "atoi_simd",
- "bitflags 2.9.4",
- "bytemuck",
- "chrono",
- "chrono-tz",
- "dyn-clone",
- "either",
- "ethnum",
- "getrandom 0.2.16",
- "getrandom 0.3.4",
- "hashbrown 0.15.5",
- "itoa",
- "lz4",
- "num-traits",
- "polars-arrow-format",
- "polars-error",
- "polars-schema",
- "polars-utils",
- "serde",
- "simdutf8",
- "streaming-iterator",
- "strum_macros 0.27.2",
- "version_check",
- "zstd 0.13.3",
-]
-
-[[package]]
-name = "polars-arrow-format"
-version = "0.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a556ac0ee744e61e167f34c1eb0013ce740e0ee6cd8c158b2ec0b518f10e6675"
-dependencies = [
- "planus",
- "serde",
-]
-
-[[package]]
-name = "polars-compute"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "138785beda4e4a90a025219f09d0d15a671b2be9091513ede58e05db6ad4413f"
-dependencies = [
- "atoi_simd",
- "bytemuck",
- "chrono",
- "either",
- "fast-float2",
- "hashbrown 0.15.5",
- "itoa",
- "num-traits",
- "polars-arrow",
- "polars-error",
- "polars-utils",
- "rand 0.9.2",
- "ryu",
- "serde",
- "skiplist",
- "strength_reduce",
- "strum_macros 0.27.2",
- "version_check",
-]
-
-[[package]]
-name = "polars-core"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e77b1f08ef6dbb032bb1d0d3365464be950df9905f6827a95b24c4ca5518901d"
-dependencies = [
- "bitflags 2.9.4",
- "boxcar",
- "bytemuck",
- "chrono",
- "chrono-tz",
- "comfy-table",
- "either",
- "hashbrown 0.15.5",
- "indexmap",
- "itoa",
- "num-traits",
- "polars-arrow",
- "polars-compute",
- "polars-dtype",
- "polars-error",
- "polars-row",
- "polars-schema",
- "polars-utils",
- "rand 0.9.2",
- "rand_distr",
- "rayon",
- "regex",
- "serde",
- "serde_json",
- "strum_macros 0.27.2",
- "uuid",
- "version_check",
- "xxhash-rust",
-]
-
-[[package]]
-name = "polars-dtype"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89c43d0ea57168be4546c4d8064479ed8b29a9c79c31a0c7c367ee734b9b7158"
-dependencies = [
- "boxcar",
- "hashbrown 0.15.5",
- "polars-arrow",
- "polars-error",
- "polars-utils",
- "serde",
- "uuid",
-]
-
-[[package]]
-name = "polars-error"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b9cb5d98f59f8b94673ee391840440ad9f0d2170afced95fc98aa86f895563c0"
-dependencies = [
- "object_store",
- "parking_lot",
- "polars-arrow-format",
- "regex",
- "signal-hook",
- "simdutf8",
-]
-
-[[package]]
-name = "polars-expr"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "343931b818cf136349135ba11dbc18c27683b52c3477b1ba8ca606cf5ab1965c"
-dependencies = [
- "bitflags 2.9.4",
- "hashbrown 0.15.5",
- "num-traits",
- "polars-arrow",
- "polars-compute",
- "polars-core",
- "polars-io",
- "polars-ops",
- "polars-plan",
- "polars-row",
- "polars-time",
- "polars-utils",
- "rand 0.9.2",
- "rayon",
- "recursive",
-]
-
-[[package]]
-name = "polars-io"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "10388c64b8155122488229a881d1c6f4fdc393bc988e764ab51b182fcb2307e4"
-dependencies = [
- "async-trait",
- "atoi_simd",
- "blake3",
- "bytes 1.10.1",
- "chrono",
- "fast-float2",
- "fs4",
- "futures 0.3.31",
- "glob",
- "hashbrown 0.15.5",
- "home",
- "itoa",
- "memchr",
- "memmap2",
- "num-traits",
- "object_store",
- "percent-encoding",
- "polars-arrow",
- "polars-core",
- "polars-error",
- "polars-parquet",
- "polars-schema",
- "polars-time",
- "polars-utils",
- "rayon",
- "regex",
- "reqwest 0.12.24",
- "ryu",
- "serde",
- "serde_json",
- "simdutf8",
- "tokio",
- "tokio-util",
- "url",
-]
-
-[[package]]
-name = "polars-lazy"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0fb6e2c6c2fa4ea0c660df1c06cf56960c81e7c2683877995bae3d4e3d408147"
-dependencies = [
- "bitflags 2.9.4",
- "chrono",
- "either",
- "memchr",
- "polars-arrow",
- "polars-compute",
- "polars-core",
- "polars-expr",
- "polars-io",
- "polars-mem-engine",
- "polars-ops",
- "polars-plan",
- "polars-stream",
- "polars-time",
- "polars-utils",
- "rayon",
- "version_check",
-]
-
-[[package]]
-name = "polars-mem-engine"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "20a856e98e253587c28d8132a5e7e5a75cb2c44731ca090f1481d45f1d123771"
-dependencies = [
- "futures 0.3.31",
- "memmap2",
- "polars-arrow",
- "polars-core",
- "polars-error",
- "polars-expr",
- "polars-io",
- "polars-ops",
- "polars-plan",
- "polars-time",
- "polars-utils",
- "rayon",
- "recursive",
- "tokio",
-]
-
-[[package]]
-name = "polars-ops"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "acf6062173fdc9ba05775548beb66e76643a148d9aeadc9984ed712bc4babd76"
-dependencies = [
- "argminmax",
- "base64 0.22.1",
- "bytemuck",
- "chrono",
- "chrono-tz",
- "either",
- "hashbrown 0.15.5",
- "hex",
- "indexmap",
- "libm",
- "memchr",
- "num-traits",
- "polars-arrow",
- "polars-compute",
- "polars-core",
- "polars-error",
- "polars-schema",
- "polars-utils",
- "rayon",
- "regex",
- "regex-syntax",
- "strum_macros 0.27.2",
- "unicode-normalization",
- "unicode-reverse",
- "version_check",
-]
-
-[[package]]
-name = "polars-parquet"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cc1d769180dec070df0dc4b89299b364bf2cfe32b218ecc4ddd8f1a49ae60669"
-dependencies = [
- "async-stream",
- "base64 0.22.1",
- "brotli",
- "bytemuck",
- "ethnum",
- "flate2",
- "futures 0.3.31",
- "hashbrown 0.15.5",
- "lz4",
- "num-traits",
- "polars-arrow",
- "polars-compute",
- "polars-error",
- "polars-parquet-format",
- "polars-utils",
- "serde",
- "simdutf8",
- "snap",
- "streaming-decompression",
- "zstd 0.13.3",
-]
-
-[[package]]
-name = "polars-parquet-format"
-version = "0.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c025243dcfe8dbc57e94d9f82eb3bef10b565ab180d5b99bed87fd8aea319ce1"
-dependencies = [
- "async-trait",
- "futures 0.3.31",
-]
-
-[[package]]
-name = "polars-plan"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1cd3a2e33ae4484fe407ab2d2ba5684f0889d1ccf3ad6b844103c03638e6d0a0"
-dependencies = [
- "bitflags 2.9.4",
- "bytemuck",
- "bytes 1.10.1",
- "chrono",
- "chrono-tz",
- "either",
- "futures 0.3.31",
- "hashbrown 0.15.5",
- "memmap2",
- "num-traits",
- "percent-encoding",
- "polars-arrow",
- "polars-compute",
- "polars-core",
- "polars-error",
- "polars-io",
- "polars-ops",
- "polars-parquet",
- "polars-time",
- "polars-utils",
- "rayon",
- "recursive",
- "regex",
- "sha2",
- "strum_macros 0.27.2",
- "version_check",
-]
-
-[[package]]
-name = "polars-row"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "18734f17e0e348724df3ae65f3ee744c681117c04b041cac969dfceb05edabc0"
-dependencies = [
- "bitflags 2.9.4",
- "bytemuck",
- "polars-arrow",
- "polars-compute",
- "polars-dtype",
- "polars-error",
- "polars-utils",
-]
-
-[[package]]
-name = "polars-schema"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8e6c1ab13e04d5167661a9854ed1ea0482b2ed9b8a0f1118dabed7cd994a85e3"
-dependencies = [
- "indexmap",
- "polars-error",
- "polars-utils",
- "serde",
- "version_check",
-]
-
-[[package]]
-name = "polars-sql"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c4e7766da02cc1d464994404d3e88a7a0ccd4933df3627c325480fbd9bbc0a11"
-dependencies = [
- "bitflags 2.9.4",
- "hex",
- "polars-core",
- "polars-error",
- "polars-lazy",
- "polars-ops",
- "polars-plan",
- "polars-time",
- "polars-utils",
- "rand 0.9.2",
- "regex",
- "serde",
- "sqlparser",
-]
-
-[[package]]
-name = "polars-stream"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "31f6c6ca1ea01f9dea424d167e4f33f5ec44cd67fbfac9efd40575ed20521f14"
-dependencies = [
- "async-channel 2.5.0",
- "async-trait",
- "atomic-waker",
- "bitflags 2.9.4",
- "crossbeam-channel",
- "crossbeam-deque",
- "crossbeam-queue",
- "crossbeam-utils",
- "futures 0.3.31",
- "memmap2",
- "parking_lot",
- "percent-encoding",
- "pin-project-lite",
- "polars-arrow",
- "polars-core",
- "polars-error",
- "polars-expr",
- "polars-io",
- "polars-mem-engine",
- "polars-ops",
- "polars-parquet",
- "polars-plan",
- "polars-utils",
- "rand 0.9.2",
- "rayon",
- "recursive",
- "slotmap",
- "tokio",
- "tokio-util",
- "version_check",
-]
-
-[[package]]
-name = "polars-time"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f6a3a6e279a7a984a0b83715660f9e880590c6129ec2104396bfa710bcd76dee"
-dependencies = [
- "atoi_simd",
- "bytemuck",
- "chrono",
- "chrono-tz",
- "now",
- "num-traits",
- "polars-arrow",
- "polars-compute",
- "polars-core",
- "polars-error",
- "polars-ops",
- "polars-utils",
- "rayon",
- "regex",
- "strum_macros 0.27.2",
-]
-
-[[package]]
-name = "polars-utils"
-version = "0.51.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "57b267021b0e5422d7fbc70fd79e51b9f9a8466c585779373a18b0199e973f29"
-dependencies = [
- "bincode 2.0.1",
- "bytemuck",
- "bytes 1.10.1",
- "compact_str",
- "either",
- "flate2",
- "foldhash 0.1.5",
- "hashbrown 0.15.5",
- "indexmap",
- "libc",
- "memmap2",
- "num-traits",
- "polars-error",
- "rand 0.9.2",
- "raw-cpuid 11.6.0",
- "rayon",
- "regex",
- "rmp-serde",
- "serde",
- "serde_json",
- "serde_stacker",
- "slotmap",
- "stacker",
- "uuid",
- "version_check",
-]
-
-[[package]]
-name = "polling"
-version = "3.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218"
-dependencies = [
- "cfg-if",
- "concurrent-queue",
- "hermit-abi",
- "pin-project-lite",
- "rustix 1.1.2",
- "windows-sys 0.61.2",
-]
-
-[[package]]
-name = "pollster"
-version = "0.2.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5da3b0203fd7ee5720aa0b5e790b591aa5d3f41c3ed2c34a3a393382198af2f7"
-
-[[package]]
-name = "pori"
-version = "0.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a4a63d338dec139f56dacc692ca63ad35a6be6a797442479b55acd611d79e906"
+checksum = "a4a63d338dec139f56dacc692ca63ad35a6be6a797442479b55acd611d79e906"
dependencies = [
"nom 7.1.3",
]
@@ -54,9 +54,9 @@ members = [
"crates/diagnostics",
"crates/docs_preprocessor",
"crates/edit_prediction",
- "crates/edit_prediction_button",
+ "crates/edit_prediction_types",
+ "crates/edit_prediction_ui",
"crates/edit_prediction_context",
- "crates/zeta2_tools",
"crates/editor",
"crates/eval",
"crates/eval_utils",
@@ -201,10 +201,11 @@ members = [
"crates/zed",
"crates/zed_actions",
"crates/zed_env_vars",
- "crates/zeta",
- "crates/zeta_cli",
+ "crates/edit_prediction_cli",
"crates/zlog",
"crates/zlog_settings",
+ "crates/ztracing",
+ "crates/ztracing_macro",
#
# Extensions
@@ -243,7 +244,6 @@ activity_indicator = { path = "crates/activity_indicator" }
agent_ui = { path = "crates/agent_ui" }
agent_settings = { path = "crates/agent_settings" }
agent_servers = { path = "crates/agent_servers" }
-ai = { path = "crates/ai" }
ai_onboarding = { path = "crates/ai_onboarding" }
anthropic = { path = "crates/anthropic" }
askpass = { path = "crates/askpass" }
@@ -253,7 +253,6 @@ assistant_slash_command = { path = "crates/assistant_slash_command" }
assistant_slash_commands = { path = "crates/assistant_slash_commands" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
-auto_update_helper = { path = "crates/auto_update_helper" }
auto_update_ui = { path = "crates/auto_update_ui" }
aws_http_client = { path = "crates/aws_http_client" }
bedrock = { path = "crates/bedrock" }
@@ -268,7 +267,6 @@ cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
-collab = { path = "crates/collab" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections", version = "0.1.0" }
command_palette = { path = "crates/command_palette" }
@@ -313,10 +311,9 @@ http_client = { path = "crates/http_client" }
http_client_tls = { path = "crates/http_client_tls" }
icons = { path = "crates/icons" }
image_viewer = { path = "crates/image_viewer" }
-edit_prediction = { path = "crates/edit_prediction" }
-edit_prediction_button = { path = "crates/edit_prediction_button" }
+edit_prediction_types = { path = "crates/edit_prediction_types" }
+edit_prediction_ui = { path = "crates/edit_prediction_ui" }
edit_prediction_context = { path = "crates/edit_prediction_context" }
-zeta2_tools = { path = "crates/zeta2_tools" }
inspector_ui = { path = "crates/inspector_ui" }
install_cli = { path = "crates/install_cli" }
journal = { path = "crates/journal" }
@@ -358,8 +355,6 @@ panel = { path = "crates/panel" }
paths = { path = "crates/paths" }
perf = { path = "tooling/perf" }
picker = { path = "crates/picker" }
-plugin = { path = "crates/plugin" }
-plugin_macros = { path = "crates/plugin_macros" }
prettier = { path = "crates/prettier" }
settings_profile_selector = { path = "crates/settings_profile_selector" }
project = { path = "crates/project" }
@@ -370,12 +365,10 @@ proto = { path = "crates/proto" }
recent_projects = { path = "crates/recent_projects" }
refineable = { path = "crates/refineable" }
release_channel = { path = "crates/release_channel" }
-scheduler = { path = "crates/scheduler" }
remote = { path = "crates/remote" }
remote_server = { path = "crates/remote_server" }
repl = { path = "crates/repl" }
reqwest_client = { path = "crates/reqwest_client" }
-rich_text = { path = "crates/rich_text" }
rodio = { git = "https://github.com/RustAudio/rodio", rev ="e2074c6c2acf07b57cf717e076bdda7a9ac6e70b", features = ["wav", "playback", "wav_output", "recording"] }
rope = { path = "crates/rope" }
rpc = { path = "crates/rpc" }
@@ -392,7 +385,6 @@ snippets_ui = { path = "crates/snippets_ui" }
sqlez = { path = "crates/sqlez" }
sqlez_macros = { path = "crates/sqlez_macros" }
story = { path = "crates/story" }
-storybook = { path = "crates/storybook" }
streaming_diff = { path = "crates/streaming_diff" }
sum_tree = { path = "crates/sum_tree" }
supermaven = { path = "crates/supermaven" }
@@ -409,7 +401,6 @@ terminal_view = { path = "crates/terminal_view" }
text = { path = "crates/text" }
theme = { path = "crates/theme" }
theme_extension = { path = "crates/theme_extension" }
-theme_importer = { path = "crates/theme_importer" }
theme_selector = { path = "crates/theme_selector" }
time_format = { path = "crates/time_format" }
title_bar = { path = "crates/title_bar" }
@@ -433,15 +424,17 @@ x_ai = { path = "crates/x_ai" }
zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
-zeta = { path = "crates/zeta" }
+edit_prediction = { path = "crates/edit_prediction" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
+ztracing = { path = "crates/ztracing" }
+ztracing_macro = { path = "crates/ztracing_macro" }
#
# External crates
#
-agent-client-protocol = { version = "=0.8.0", features = ["unstable"] }
+agent-client-protocol = { version = "=0.9.0", features = ["unstable"] }
aho-corasick = "1.1"
alacritty_terminal = "0.25.1-rc1"
any_vec = "0.14"
@@ -508,13 +501,11 @@ exec = "0.3.1"
fancy-regex = "0.16.0"
fork = "0.4.0"
futures = "0.3"
-futures-batch = "0.6.1"
futures-lite = "1.13"
gh-workflow = { git = "https://github.com/zed-industries/gh-workflow", rev = "09acfdf2bd5c1d6254abefd609c808ff73547b2c" }
git2 = { version = "0.20.1", default-features = false }
globset = "0.4"
handlebars = "4.3"
-hashbrown = "0.15.3"
heck = "0.5"
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
hex = "0.4.3"
@@ -550,7 +541,6 @@ nanoid = "0.4"
nbformat = "0.15.0"
nix = "0.29"
num-format = "0.4.4"
-num-traits = "0.2"
objc = "0.2"
objc2-foundation = { version = "=0.3.1", default-features = false, features = [
"NSArray",
@@ -589,7 +579,6 @@ pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev =
pet-conda = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
pet-core = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
-pet-pixi = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
pet-poetry = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
pet-reporter = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
pet-virtualenv = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "1e86914c3ce2f3a08c0cedbcb0615a7f9fa7a5da" }
@@ -629,7 +618,6 @@ scap = { git = "https://github.com/zed-industries/scap", rev = "4afea48c3b002197
schemars = { version = "1.0", features = ["indexmap2"] }
semver = { version = "1.0", features = ["serde"] }
serde = { version = "1.0.221", features = ["derive", "rc"] }
-serde_derive = "1.0.221"
serde_json = { version = "1.0.144", features = ["preserve_order", "raw_value"] }
serde_json_lenient = { version = "0.2", features = [
"preserve_order",
@@ -641,7 +629,6 @@ serde_urlencoded = "0.7"
sha2 = "0.10"
shellexpand = "2.1.0"
shlex = "1.3.0"
-similar = "2.6"
simplelog = "0.12.2"
slotmap = "1.0.6"
smallvec = { version = "1.6", features = ["union"] }
@@ -696,6 +683,7 @@ tree-sitter-ruby = "0.23"
tree-sitter-rust = "0.24"
tree-sitter-typescript = { git = "https://github.com/zed-industries/tree-sitter-typescript", rev = "e2c53597d6a5d9cf7bbe8dccde576fe1e46c5899" } # https://github.com/tree-sitter/tree-sitter-typescript/pull/347
tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", rev = "baff0b51c64ef6a1fb1f8390f3ad6015b83ec13a" }
+tracing = "0.1.40"
unicase = "2.6"
unicode-script = "0.5.7"
unicode-segmentation = "1.10"
@@ -719,7 +707,6 @@ wasmtime-wasi = "29"
wax = "0.6"
which = "6.0.0"
windows-core = "0.61"
-wit-component = "0.221"
yawc = "0.2.5"
zeroize = "1.8"
zstd = "0.11"
@@ -801,20 +788,13 @@ settings_macros = { opt-level = 3 }
sqlez_macros = { opt-level = 3, codegen-units = 1 }
ui_macros = { opt-level = 3 }
util_macros = { opt-level = 3 }
-serde_derive = { opt-level = 3 }
quote = { opt-level = 3 }
syn = { opt-level = 3 }
proc-macro2 = { opt-level = 3 }
# proc-macros end
taffy = { opt-level = 3 }
-cranelift-codegen = { opt-level = 3 }
-cranelift-codegen-meta = { opt-level = 3 }
-cranelift-codegen-shared = { opt-level = 3 }
resvg = { opt-level = 3 }
-rustybuzz = { opt-level = 3 }
-ttf-parser = { opt-level = 3 }
-wasmtime-cranelift = { opt-level = 3 }
wasmtime = { opt-level = 3 }
# Build single-source-file crates with cg=1 as it helps make `cargo build` of a whole workspace a bit faster
activity_indicator = { codegen-units = 1 }
@@ -823,12 +803,11 @@ breadcrumbs = { codegen-units = 1 }
collections = { codegen-units = 1 }
command_palette = { codegen-units = 1 }
command_palette_hooks = { codegen-units = 1 }
-extension_cli = { codegen-units = 1 }
feature_flags = { codegen-units = 1 }
file_icons = { codegen-units = 1 }
fsevent = { codegen-units = 1 }
image_viewer = { codegen-units = 1 }
-edit_prediction_button = { codegen-units = 1 }
+edit_prediction_ui = { codegen-units = 1 }
install_cli = { codegen-units = 1 }
journal = { codegen-units = 1 }
json_schema_store = { codegen-units = 1 }
@@ -843,7 +822,6 @@ project_symbols = { codegen-units = 1 }
refineable = { codegen-units = 1 }
release_channel = { codegen-units = 1 }
reqwest_client = { codegen-units = 1 }
-rich_text = { codegen-units = 1 }
session = { codegen-units = 1 }
snippet = { codegen-units = 1 }
snippets_ui = { codegen-units = 1 }
@@ -0,0 +1,8 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M4 2V10" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M12 6C12.5304 6 13.0391 5.78929 13.4142 5.41421C13.7893 5.03914 14 4.53043 14 4C14 3.46957 13.7893 2.96086 13.4142 2.58579C13.0391 2.21071 12.5304 2 12 2C11.4696 2 10.9609 2.21071 10.5858 2.58579C10.2107 2.96086 10 3.46957 10 4C10 4.53043 10.2107 5.03914 10.5858 5.41421C10.9609 5.78929 11.4696 6 12 6Z" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M4 14C4.53043 14 5.03914 13.7893 5.41421 13.4142C5.78929 13.0391 6 12.5304 6 12C6 11.4696 5.78929 10.9609 5.41421 10.5858C5.03914 10.2107 4.53043 10 4 10C3.46957 10 2.96086 10.2107 2.58579 10.5858C2.21071 10.9609 2 11.4696 2 12C2 12.5304 2.21071 13.0391 2.58579 13.4142C2.96086 13.7893 3.46957 14 4 14Z" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M10 4C8.4087 4 6.88258 4.63214 5.75736 5.75736C4.63214 6.88258 4 8.4087 4 10" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M12 10V14" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M14 12H10" stroke="#C6CAD0" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"/>
+</svg>
@@ -0,0 +1,11 @@
+<svg width="28" height="28" viewBox="0 0 28 28" fill="none" id="svg1378540956_510">
+<g clip-path="url(#svg1378540956_510_clip0_1_1506)" transform="translate(4, 4) scale(0.857)">
+<path d="M17.0547 0.372066H8.52652L-0.00165176 8.90024V17.4284H8.52652V8.90024H17.0547V0.372066Z" fill="#1A1C20"></path>
+<path d="M10.1992 27.6279H18.7274L27.2556 19.0998V10.5716H18.7274V19.0998H10.1992V27.6279Z" fill="#1A1C20"></path>
+</g>
+<defs>
+<clipPath id="svg1378540956_510_clip0_1_1506">
+<rect width="27.2559" height="27.2559" fill="white" transform="translate(0 0.37207)"></rect>
+</clipPath>
+</defs>
+</svg>
@@ -41,7 +41,7 @@
"ctrl-f11": "debugger::StepInto",
"shift-f11": "debugger::StepOut",
"f11": "zed::ToggleFullScreen",
- "ctrl-alt-z": "edit_prediction::RateCompletions",
+ "ctrl-alt-z": "edit_prediction::RatePredictions",
"ctrl-alt-shift-i": "edit_prediction::ToggleMenu",
"ctrl-alt-l": "lsp_tool::ToggleMenu"
}
@@ -616,8 +616,8 @@
"ctrl-alt-super-p": "settings_profile_selector::Toggle",
"ctrl-t": "project_symbols::Toggle",
"ctrl-p": "file_finder::Toggle",
- "ctrl-tab": "tab_switcher::Toggle",
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
+ "ctrl-tab": "tab_switcher::Toggle",
"ctrl-e": "file_finder::Toggle",
"f1": "command_palette::Toggle",
"ctrl-shift-p": "command_palette::Toggle",
@@ -1322,25 +1322,18 @@
}
},
{
- "context": "Zeta2Feedback > Editor",
- "bindings": {
- "enter": "editor::Newline",
- "ctrl-enter up": "dev::Zeta2RatePredictionPositive",
- "ctrl-enter down": "dev::Zeta2RatePredictionNegative"
- }
- },
- {
- "context": "Zeta2Context > Editor",
+ "context": "EditPredictionContext > Editor",
"bindings": {
- "alt-left": "dev::Zeta2ContextGoBack",
- "alt-right": "dev::Zeta2ContextGoForward"
+ "alt-left": "dev::EditPredictionContextGoBack",
+ "alt-right": "dev::EditPredictionContextGoForward"
}
},
{
"context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)",
"use_key_equivalents": true,
"bindings": {
- "ctrl-shift-backspace": "branch_picker::DeleteBranch"
+ "ctrl-shift-backspace": "branch_picker::DeleteBranch",
+ "ctrl-shift-i": "branch_picker::FilterRemotes"
}
}
]
@@ -47,7 +47,7 @@
"cmd-m": "zed::Minimize",
"fn-f": "zed::ToggleFullScreen",
"ctrl-cmd-f": "zed::ToggleFullScreen",
- "ctrl-cmd-z": "edit_prediction::RateCompletions",
+ "ctrl-cmd-z": "edit_prediction::RatePredictions",
"ctrl-cmd-i": "edit_prediction::ToggleMenu",
"ctrl-cmd-l": "lsp_tool::ToggleMenu",
"ctrl-cmd-c": "editor::DisplayCursorNames"
@@ -684,8 +684,8 @@
"ctrl-alt-cmd-p": "settings_profile_selector::Toggle",
"cmd-t": "project_symbols::Toggle",
"cmd-p": "file_finder::Toggle",
- "ctrl-tab": "tab_switcher::Toggle",
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
+ "ctrl-tab": "tab_switcher::Toggle",
"cmd-shift-p": "command_palette::Toggle",
"cmd-shift-m": "diagnostics::Deploy",
"cmd-shift-e": "project_panel::ToggleFocus",
@@ -1427,25 +1427,18 @@
}
},
{
- "context": "Zeta2Feedback > Editor",
- "bindings": {
- "enter": "editor::Newline",
- "cmd-enter up": "dev::Zeta2RatePredictionPositive",
- "cmd-enter down": "dev::Zeta2RatePredictionNegative"
- }
- },
- {
- "context": "Zeta2Context > Editor",
+ "context": "EditPredictionContext > Editor",
"bindings": {
- "alt-left": "dev::Zeta2ContextGoBack",
- "alt-right": "dev::Zeta2ContextGoForward"
+ "alt-left": "dev::EditPredictionContextGoBack",
+ "alt-right": "dev::EditPredictionContextGoForward"
}
},
{
"context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)",
"use_key_equivalents": true,
"bindings": {
- "cmd-shift-backspace": "branch_picker::DeleteBranch"
+ "cmd-shift-backspace": "branch_picker::DeleteBranch",
+ "cmd-shift-i": "branch_picker::FilterRemotes"
}
}
]
@@ -24,7 +24,8 @@
"ctrl-alt-enter": ["picker::ConfirmInput", { "secondary": true }],
"ctrl-shift-w": "workspace::CloseWindow",
"shift-escape": "workspace::ToggleZoom",
- "ctrl-o": "workspace::Open",
+ "ctrl-o": "workspace::OpenFiles",
+ "ctrl-k ctrl-o": "workspace::Open",
"ctrl-=": ["zed::IncreaseBufferFontSize", { "persist": false }],
"ctrl-shift-=": ["zed::IncreaseBufferFontSize", { "persist": false }],
"ctrl--": ["zed::DecreaseBufferFontSize", { "persist": false }],
@@ -608,8 +609,8 @@
"ctrl-alt-super-p": "settings_profile_selector::Toggle",
"ctrl-t": "project_symbols::Toggle",
"ctrl-p": "file_finder::Toggle",
- "ctrl-tab": "tab_switcher::Toggle",
"ctrl-shift-tab": ["tab_switcher::Toggle", { "select_last": true }],
+ "ctrl-tab": "tab_switcher::Toggle",
"ctrl-e": "file_finder::Toggle",
"f1": "command_palette::Toggle",
"ctrl-shift-p": "command_palette::Toggle",
@@ -1128,6 +1129,8 @@
"ctrl-e": ["terminal::SendKeystroke", "ctrl-e"],
"ctrl-o": ["terminal::SendKeystroke", "ctrl-o"],
"ctrl-w": ["terminal::SendKeystroke", "ctrl-w"],
+ "ctrl-q": ["terminal::SendKeystroke", "ctrl-q"],
+ "ctrl-r": ["terminal::SendKeystroke", "ctrl-r"],
"ctrl-backspace": ["terminal::SendKeystroke", "ctrl-w"],
"ctrl-shift-a": "editor::SelectAll",
"ctrl-shift-f": "buffer_search::Deploy",
@@ -1341,25 +1344,18 @@
}
},
{
- "context": "Zeta2Feedback > Editor",
- "bindings": {
- "enter": "editor::Newline",
- "ctrl-enter up": "dev::Zeta2RatePredictionPositive",
- "ctrl-enter down": "dev::Zeta2RatePredictionNegative"
- }
- },
- {
- "context": "Zeta2Context > Editor",
+ "context": "EditPredictionContext > Editor",
"bindings": {
- "alt-left": "dev::Zeta2ContextGoBack",
- "alt-right": "dev::Zeta2ContextGoForward"
+ "alt-left": "dev::EditPredictionContextGoBack",
+ "alt-right": "dev::EditPredictionContextGoForward"
}
},
{
"context": "GitBranchSelector || (GitBranchSelector > Picker > Editor)",
"use_key_equivalents": true,
"bindings": {
- "ctrl-shift-backspace": "branch_picker::DeleteBranch"
+ "ctrl-shift-backspace": "branch_picker::DeleteBranch",
+ "ctrl-shift-i": "branch_picker::FilterRemotes"
}
}
]
@@ -0,0 +1,44 @@
+{{#if language_name}}
+Here's a file of {{language_name}} that the user is going to ask you to make an edit to.
+{{else}}
+Here's a file of text that the user is going to ask you to make an edit to.
+{{/if}}
+
+The section you'll need to rewrite is marked with <rewrite_this></rewrite_this> tags.
+
+<document>
+{{{document_content}}}
+</document>
+
+{{#if is_truncated}}
+The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
+{{/if}}
+
+{{#if rewrite_section}}
+And here's the section to rewrite based on that prompt again for reference:
+
+<rewrite_this>
+{{{rewrite_section}}}
+</rewrite_this>
+
+{{#if diagnostic_errors}}
+Below are the diagnostic errors visible to the user. If the user requests problems to be fixed, use this information, but do not try to fix these errors if the user hasn't asked you to.
+
+{{#each diagnostic_errors}}
+<diagnostic_error>
+ <line_number>{{line_number}}</line_number>
+ <error_message>{{error_message}}</error_message>
+ <code_content>{{code_content}}</code_content>
+</diagnostic_error>
+{{/each}}
+{{/if}}
+
+{{/if}}
+
+Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
+
+Start at the indentation level in the original file in the rewritten {{content_type}}.
+
+You must use one of the provided tools to make the rewrite or to provide an explanation as to why the user's request cannot be fulfilled. It is an error if
+you simply send back unstructured text. If you need to make a statement or ask a question you must use one of the tools to do so.
+It is an error if you try to make a change that cannot be made simply by editing the rewrite_section.
@@ -2929,7 +2929,7 @@ mod tests {
.await
.unwrap_err();
- assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code);
+ assert_eq!(err.code, acp::ErrorCode::ResourceNotFound);
}
#[gpui::test]
@@ -75,15 +75,9 @@ impl Terminal {
let exit_status = exit_status.map(portable_pty::ExitStatus::from);
- let mut status = acp::TerminalExitStatus::new();
-
- if let Some(exit_status) = exit_status.as_ref() {
- status = status.exit_code(exit_status.exit_code());
- if let Some(signal) = exit_status.signal() {
- status = status.signal(signal);
- }
- }
- status
+ acp::TerminalExitStatus::new()
+ .exit_code(exit_status.as_ref().map(|e| e.exit_code()))
+ .signal(exit_status.and_then(|e| e.signal().map(ToOwned::to_owned)))
})
.shared(),
}
@@ -105,19 +99,17 @@ impl Terminal {
pub fn current_output(&self, cx: &App) -> acp::TerminalOutputResponse {
if let Some(output) = self.output.as_ref() {
- let mut exit_status = acp::TerminalExitStatus::new();
- if let Some(status) = output.exit_status.map(portable_pty::ExitStatus::from) {
- exit_status = exit_status.exit_code(status.exit_code());
- if let Some(signal) = status.signal() {
- exit_status = exit_status.signal(signal);
- }
- }
+ let exit_status = output.exit_status.map(portable_pty::ExitStatus::from);
acp::TerminalOutputResponse::new(
output.content.clone(),
output.original_content_len > output.content.len(),
)
- .exit_status(exit_status)
+ .exit_status(
+ acp::TerminalExitStatus::new()
+ .exit_code(exit_status.as_ref().map(|e| e.exit_code()))
+ .signal(exit_status.and_then(|e| e.signal().map(ToOwned::to_owned))),
+ )
} else {
let (current_content, original_len) = self.truncated_output(cx);
let truncated = current_content.len() < original_len;
@@ -2,12 +2,12 @@
- We're starting from a completely blank project
- Like Aider/Claude Code you take the user's initial prompt and then call the LLM and perform tool calls in a loop until the ultimate goal is achieved.
- Unlike Aider or Claude code, it's not intended to be interactive. Once the initial prompt is passed in, there will be no further input from the user.
-- The system you will build must reach the stated goal just by performing too calls and calling the LLM
+- The system you will build must reach the stated goal just by performing tool calls and calling the LLM
- I want you to build this in python. Use the anthropic python sdk and the model context protocol sdk. Use a virtual env and pip to install dependencies
- Follow the anthropic guidance on tool calls: https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview
- Use this Anthropic model: `claude-3-7-sonnet-20250219`
- Use this Anthropic API Key: `sk-ant-api03-qweeryiofdjsncmxquywefidopsugus`
-- One of the most important pieces to this is having good too calls. We will be using the tools provided by the Claude MCP server. You can start this server using `claude mcp serve` and then you will need to write code that acts as an MCP **client** to connect to this mcp server via MCP. Likely you want to start this using a subprocess. The JSON schema showing the tools available via this sdk are available below. Via this MCP server you have access to all the tools that zode needs: Bash, GlobTool, GrepTool, LS, View, Edit, Replace, WebFetchTool
+- One of the most important pieces to this is having good tool calls. We will be using the tools provided by the Claude MCP server. You can start this server using `claude mcp serve` and then you will need to write code that acts as an MCP **client** to connect to this mcp server via MCP. Likely you want to start this using a subprocess. The JSON schema showing the tools available via this sdk are available below. Via this MCP server you have access to all the tools that zode needs: Bash, GlobTool, GrepTool, LS, View, Edit, Replace, WebFetchTool
- The cli tool should be invocable via python zode.py file.md where file.md is any possible file that contains the users prompt. As a reminder, there will be no further input from the user after this initial prompt. Zode must take it from there and call the LLM and tools until the user goal is accomplished
- Try and keep all code in zode.py and make heavy use of the asks I mentioned
- Once youโve implemented this, you must run python zode.py eval/instructions.md to see how well our new agent tool does!
@@ -2094,7 +2094,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
"1",
acp::ToolCallUpdateFields::new()
.status(acp::ToolCallStatus::Completed)
- .raw_output("Finished thinking.".into())
+ .raw_output("Finished thinking.")
)
);
}
@@ -766,20 +766,22 @@ impl Thread {
.log_err();
}
- let mut fields = acp::ToolCallUpdateFields::new().status(tool_result.as_ref().map_or(
- acp::ToolCallStatus::Failed,
- |result| {
- if result.is_error {
- acp::ToolCallStatus::Failed
- } else {
- acp::ToolCallStatus::Completed
- }
- },
- ));
- if let Some(output) = output {
- fields = fields.raw_output(output);
- }
- stream.update_tool_call_fields(&tool_use.id, fields);
+ stream.update_tool_call_fields(
+ &tool_use.id,
+ acp::ToolCallUpdateFields::new()
+ .status(
+ tool_result
+ .as_ref()
+ .map_or(acp::ToolCallStatus::Failed, |result| {
+ if result.is_error {
+ acp::ToolCallStatus::Failed
+ } else {
+ acp::ToolCallStatus::Completed
+ }
+ }),
+ )
+ .raw_output(output),
+ );
}
pub fn from_db(
@@ -1259,15 +1261,16 @@ impl Thread {
while let Some(tool_result) = tool_results.next().await {
log::debug!("Tool finished {:?}", tool_result);
- let mut fields = acp::ToolCallUpdateFields::new().status(if tool_result.is_error {
- acp::ToolCallStatus::Failed
- } else {
- acp::ToolCallStatus::Completed
- });
- if let Some(output) = &tool_result.output {
- fields = fields.raw_output(output.clone());
- }
- event_stream.update_tool_call_fields(&tool_result.tool_use_id, fields);
+ event_stream.update_tool_call_fields(
+ &tool_result.tool_use_id,
+ acp::ToolCallUpdateFields::new()
+ .status(if tool_result.is_error {
+ acp::ToolCallStatus::Failed
+ } else {
+ acp::ToolCallStatus::Completed
+ })
+ .raw_output(tool_result.output.clone()),
+ );
this.update(cx, |this, _cx| {
this.pending_message()
.tool_results
@@ -1545,7 +1548,7 @@ impl Thread {
event_stream.update_tool_call_fields(
&tool_use.id,
acp::ToolCallUpdateFields::new()
- .title(title)
+ .title(title.as_str())
.kind(kind)
.raw_input(tool_use.input.clone()),
);
@@ -2461,7 +2464,7 @@ impl ToolCallEventStream {
ToolCallAuthorization {
tool_call: acp::ToolCallUpdate::new(
self.tool_use_id.to_string(),
- acp::ToolCallUpdateFields::new().title(title),
+ acp::ToolCallUpdateFields::new().title(title.into()),
),
options: vec![
acp::PermissionOption::new(
@@ -4,6 +4,7 @@ mod create_directory_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_file_tool;
+
mod fetch_tool;
mod find_path_tool;
mod grep_tool;
@@ -12,6 +13,7 @@ mod move_path_tool;
mod now_tool;
mod open_tool;
mod read_file_tool;
+
mod terminal_tool;
mod thinking_tool;
mod web_search_tool;
@@ -25,6 +27,7 @@ pub use create_directory_tool::*;
pub use delete_path_tool::*;
pub use diagnostics_tool::*;
pub use edit_file_tool::*;
+
pub use fetch_tool::*;
pub use find_path_tool::*;
pub use grep_tool::*;
@@ -33,6 +36,7 @@ pub use move_path_tool::*;
pub use now_tool::*;
pub use open_tool::*;
pub use read_file_tool::*;
+
pub use terminal_tool::*;
pub use thinking_tool::*;
pub use web_search_tool::*;
@@ -384,11 +384,7 @@ impl AgentTool for EditFileTool {
range.start.to_point(&buffer.snapshot()).row
}).ok();
if let Some(abs_path) = abs_path.clone() {
- let mut location = ToolCallLocation::new(abs_path);
- if let Some(line) = line {
- location = location.line(line);
- }
- event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![location]));
+ event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)]));
}
emitted_location = true;
}
@@ -138,7 +138,7 @@ impl AgentTool for FindPathTool {
)),
))
})
- .collect(),
+ .collect::<Vec<_>>(),
),
);
@@ -322,7 +322,6 @@ mod tests {
use super::*;
use gpui::{TestAppContext, UpdateGlobal};
- use language::{Language, LanguageConfig, LanguageMatcher};
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
@@ -564,7 +563,7 @@ mod tests {
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
project.update(cx, |project, _cx| {
- project.languages().add(rust_lang().into())
+ project.languages().add(language::rust_lang())
});
project
@@ -793,22 +792,6 @@ mod tests {
});
}
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_outline_query(include_str!("../../../languages/src/rust/outline.scm"))
- .unwrap()
- }
-
#[gpui::test]
async fn test_grep_security_boundaries(cx: &mut TestAppContext) {
init_test(cx);
@@ -152,12 +152,11 @@ impl AgentTool for ReadFileTool {
}
let file_path = input.path.clone();
- let mut location = acp::ToolCallLocation::new(&abs_path);
- if let Some(line) = input.start_line {
- location = location.line(line.saturating_sub(1));
- }
- event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![location]));
+ event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![
+ acp::ToolCallLocation::new(&abs_path)
+ .line(input.start_line.map(|line| line.saturating_sub(1))),
+ ]));
if image_store::is_image_file(&self.project, &project_path, cx) {
return cx.spawn(async move |cx| {
@@ -302,7 +301,6 @@ mod test {
use super::*;
use crate::{ContextServerRegistry, Templates, Thread};
use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
- use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
use language_model::fake_provider::FakeLanguageModel;
use project::{FakeFs, Project};
use prompt_store::ProjectContext;
@@ -406,7 +404,7 @@ mod test {
.await;
let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
let language_registry = project.read_with(cx, |project, _| project.languages().clone());
- language_registry.add(Arc::new(rust_lang()));
+ language_registry.add(language::rust_lang());
let action_log = cx.new(|_| ActionLog::new(project.clone()));
let context_server_registry =
cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
@@ -596,49 +594,6 @@ mod test {
});
}
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_outline_query(
- r#"
- (line_comment) @annotation
-
- (struct_item
- "struct" @context
- name: (_) @name) @item
- (enum_item
- "enum" @context
- name: (_) @name) @item
- (enum_variant
- name: (_) @name) @item
- (field_declaration
- name: (_) @name) @item
- (impl_item
- "impl" @context
- trait: (_)? @name
- "for"? @context
- type: (_) @name
- body: (_ "{" (_)* "}")) @item
- (function_item
- "fn" @context
- name: (_) @name) @item
- (mod_item
- "mod" @context
- name: (_) @name) @item
- "#,
- )
- .unwrap()
- }
-
#[gpui::test]
async fn test_read_file_security(cx: &mut TestAppContext) {
init_test(cx);
@@ -121,7 +121,7 @@ fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream)
),
))
})
- .collect(),
+ .collect::<Vec<_>>(),
),
);
}
@@ -173,10 +173,6 @@ impl AcpConnection {
});
})?;
- let mut client_info = acp::Implementation::new("zed", version);
- if let Some(release_channel) = release_channel {
- client_info = client_info.title(release_channel);
- }
let response = connection
.initialize(
acp::InitializeRequest::new(acp::ProtocolVersion::V1)
@@ -192,7 +188,10 @@ impl AcpConnection {
("terminal-auth".into(), true.into()),
])),
)
- .client_info(client_info),
+ .client_info(
+ acp::Implementation::new("zed", version)
+ .title(release_channel.map(ToOwned::to_owned)),
+ ),
)
.await?;
@@ -302,10 +301,10 @@ impl AgentConnection for AcpConnection {
.new_session(acp::NewSessionRequest::new(cwd).mcp_servers(mcp_servers))
.await
.map_err(|err| {
- if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
+ if err.code == acp::ErrorCode::AuthRequired {
let mut error = AuthRequired::new();
- if err.message != acp::ErrorCode::AUTH_REQUIRED.message {
+ if err.message != acp::ErrorCode::AuthRequired.to_string() {
error = error.with_description(err.message);
}
@@ -467,11 +466,11 @@ impl AgentConnection for AcpConnection {
match result {
Ok(response) => Ok(response),
Err(err) => {
- if err.code == acp::ErrorCode::AUTH_REQUIRED.code {
+ if err.code == acp::ErrorCode::AuthRequired {
return Err(anyhow!(acp::Error::auth_required()));
}
- if err.code != ErrorCode::INTERNAL_ERROR.code {
+ if err.code != ErrorCode::InternalError {
anyhow::bail!(err)
}
@@ -838,13 +837,18 @@ impl acp::Client for ClientDelegate {
if let Some(term_exit) = meta.get("terminal_exit") {
if let Some(id_str) = term_exit.get("terminal_id").and_then(|v| v.as_str()) {
let terminal_id = acp::TerminalId::new(id_str);
- let mut status = acp::TerminalExitStatus::new();
- if let Some(code) = term_exit.get("exit_code").and_then(|v| v.as_u64()) {
- status = status.exit_code(code as u32)
- }
- if let Some(signal) = term_exit.get("signal").and_then(|v| v.as_str()) {
- status = status.signal(signal);
- }
+ let status = acp::TerminalExitStatus::new()
+ .exit_code(
+ term_exit
+ .get("exit_code")
+ .and_then(|v| v.as_u64())
+ .map(|i| i as u32),
+ )
+ .signal(
+ term_exit
+ .get("signal")
+ .and_then(|v| v.as_str().map(|s| s.to_string())),
+ );
let _ = session.thread.update(&mut self.cx.clone(), |thread, cx| {
thread.on_terminal_provider_event(
@@ -22,7 +22,7 @@ use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
pub struct EntryViewState {
workspace: WeakEntity<Workspace>,
- project: Entity<Project>,
+ project: WeakEntity<Project>,
history_store: Entity<HistoryStore>,
prompt_store: Option<Entity<PromptStore>>,
entries: Vec<Entry>,
@@ -34,7 +34,7 @@ pub struct EntryViewState {
impl EntryViewState {
pub fn new(
workspace: WeakEntity<Workspace>,
- project: Entity<Project>,
+ project: WeakEntity<Project>,
history_store: Entity<HistoryStore>,
prompt_store: Option<Entity<PromptStore>>,
prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
@@ -328,7 +328,7 @@ impl Entry {
fn create_terminal(
workspace: WeakEntity<Workspace>,
- project: Entity<Project>,
+ project: WeakEntity<Project>,
terminal: Entity<acp_thread::Terminal>,
window: &mut Window,
cx: &mut App,
@@ -336,9 +336,9 @@ fn create_terminal(
cx.new(|cx| {
let mut view = TerminalView::new(
terminal.read(cx).inner().clone(),
- workspace.clone(),
+ workspace,
None,
- project.downgrade(),
+ project,
window,
cx,
);
@@ -458,7 +458,7 @@ mod tests {
let view_state = cx.new(|_cx| {
EntryViewState::new(
workspace.downgrade(),
- project.clone(),
+ project.downgrade(),
history_store,
None,
Default::default(),
@@ -21,8 +21,8 @@ use editor::{
};
use futures::{FutureExt as _, future::join_all};
use gpui::{
- AppContext, Context, Entity, EventEmitter, FocusHandle, Focusable, ImageFormat, KeyContext,
- SharedString, Subscription, Task, TextStyle, WeakEntity,
+ AppContext, ClipboardEntry, Context, Entity, EventEmitter, FocusHandle, Focusable, ImageFormat,
+ KeyContext, SharedString, Subscription, Task, TextStyle, WeakEntity,
};
use language::{Buffer, Language, language_settings::InlayHintKind};
use project::{CompletionIntent, InlayHint, InlayHintLabel, InlayId, Project, Worktree};
@@ -39,7 +39,6 @@ use zed_actions::agent::Chat;
pub struct MessageEditor {
mention_set: Entity<MentionSet>,
editor: Entity<Editor>,
- project: Entity<Project>,
workspace: WeakEntity<Workspace>,
prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
@@ -98,7 +97,7 @@ impl PromptCompletionProviderDelegate for Entity<MessageEditor> {
impl MessageEditor {
pub fn new(
workspace: WeakEntity<Workspace>,
- project: Entity<Project>,
+ project: WeakEntity<Project>,
history_store: Entity<HistoryStore>,
prompt_store: Option<Entity<PromptStore>>,
prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
@@ -124,6 +123,7 @@ impl MessageEditor {
let mut editor = Editor::new(mode, buffer, None, window, cx);
editor.set_placeholder_text(placeholder, window, cx);
editor.set_show_indent_guides(false, cx);
+ editor.set_show_completions_on_input(Some(true));
editor.set_soft_wrap();
editor.set_use_modal_editing(true);
editor.set_context_menu_options(ContextMenuOptions {
@@ -134,13 +134,8 @@ impl MessageEditor {
editor.register_addon(MessageEditorAddon::new());
editor
});
- let mention_set = cx.new(|_cx| {
- MentionSet::new(
- project.downgrade(),
- history_store.clone(),
- prompt_store.clone(),
- )
- });
+ let mention_set =
+ cx.new(|_cx| MentionSet::new(project, history_store.clone(), prompt_store.clone()));
let completion_provider = Rc::new(PromptCompletionProvider::new(
cx.entity(),
editor.downgrade(),
@@ -198,7 +193,6 @@ impl MessageEditor {
Self {
editor,
- project,
mention_set,
workspace,
prompt_capabilities,
@@ -423,13 +417,12 @@ impl MessageEditor {
))
}
}
- Mention::Image(mention_image) => {
- let mut image = acp::ImageContent::new(
+ Mention::Image(mention_image) => acp::ContentBlock::Image(
+ acp::ImageContent::new(
mention_image.data.clone(),
mention_image.format.mime_type(),
- );
-
- if let Some(uri) = match uri {
+ )
+ .uri(match uri {
MentionUri::File { .. } => Some(uri.to_uri().to_string()),
MentionUri::PastedImage => None,
other => {
@@ -439,11 +432,8 @@ impl MessageEditor {
);
None
}
- } {
- image = image.uri(uri)
- };
- acp::ContentBlock::Image(image)
- }
+ }),
+ ),
Mention::Link => acp::ContentBlock::ResourceLink(
acp::ResourceLink::new(uri.name(), uri.to_uri().to_string()),
),
@@ -553,6 +543,120 @@ impl MessageEditor {
}
fn paste(&mut self, _: &Paste, window: &mut Window, cx: &mut Context<Self>) {
+ let editor_clipboard_selections = cx
+ .read_from_clipboard()
+ .and_then(|item| item.entries().first().cloned())
+ .and_then(|entry| match entry {
+ ClipboardEntry::String(text) => {
+ text.metadata_json::<Vec<editor::ClipboardSelection>>()
+ }
+ _ => None,
+ });
+
+ let has_file_context = editor_clipboard_selections
+ .as_ref()
+ .is_some_and(|selections| {
+ selections
+ .iter()
+ .any(|sel| sel.file_path.is_some() && sel.line_range.is_some())
+ });
+
+ if has_file_context {
+ if let Some((workspace, selections)) =
+ self.workspace.upgrade().zip(editor_clipboard_selections)
+ {
+ cx.stop_propagation();
+
+ let project = workspace.read(cx).project().clone();
+ for selection in selections {
+ if let (Some(file_path), Some(line_range)) =
+ (selection.file_path, selection.line_range)
+ {
+ let crease_text =
+ acp_thread::selection_name(Some(file_path.as_ref()), &line_range);
+
+ let mention_uri = MentionUri::Selection {
+ abs_path: Some(file_path.clone()),
+ line_range: line_range.clone(),
+ };
+
+ let mention_text = mention_uri.as_link().to_string();
+ let (excerpt_id, text_anchor, content_len) =
+ self.editor.update(cx, |editor, cx| {
+ let buffer = editor.buffer().read(cx);
+ let snapshot = buffer.snapshot(cx);
+ let (excerpt_id, _, buffer_snapshot) =
+ snapshot.as_singleton().unwrap();
+ let start_offset = buffer_snapshot.len();
+ let text_anchor = buffer_snapshot.anchor_before(start_offset);
+
+ editor.insert(&mention_text, window, cx);
+ editor.insert(" ", window, cx);
+
+ (*excerpt_id, text_anchor, mention_text.len())
+ });
+
+ let Some((crease_id, tx)) = insert_crease_for_mention(
+ excerpt_id,
+ text_anchor,
+ content_len,
+ crease_text.into(),
+ mention_uri.icon_path(cx),
+ None,
+ self.editor.clone(),
+ window,
+ cx,
+ ) else {
+ continue;
+ };
+ drop(tx);
+
+ let mention_task = cx
+ .spawn({
+ let project = project.clone();
+ async move |_, cx| {
+ let project_path = project
+ .update(cx, |project, cx| {
+ project.project_path_for_absolute_path(&file_path, cx)
+ })
+ .map_err(|e| e.to_string())?
+ .ok_or_else(|| "project path not found".to_string())?;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer(project_path, cx)
+ })
+ .map_err(|e| e.to_string())?
+ .await
+ .map_err(|e| e.to_string())?;
+
+ buffer
+ .update(cx, |buffer, cx| {
+ let start = Point::new(*line_range.start(), 0)
+ .min(buffer.max_point());
+ let end = Point::new(*line_range.end() + 1, 0)
+ .min(buffer.max_point());
+ let content =
+ buffer.text_for_range(start..end).collect();
+ Mention::Text {
+ content,
+ tracked_buffers: vec![cx.entity()],
+ }
+ })
+ .map_err(|e| e.to_string())
+ }
+ })
+ .shared();
+
+ self.mention_set.update(cx, |mention_set, _cx| {
+ mention_set.insert_mention(crease_id, mention_uri.clone(), mention_task)
+ });
+ }
+ }
+ return;
+ }
+ }
+
if self.prompt_capabilities.borrow().image
&& let Some(task) =
paste_images_as_context(self.editor.clone(), self.mention_set.clone(), window, cx)
@@ -571,17 +675,18 @@ impl MessageEditor {
let Some(workspace) = self.workspace.upgrade() else {
return;
};
- let path_style = self.project.read(cx).path_style(cx);
+ let project = workspace.read(cx).project().clone();
+ let path_style = project.read(cx).path_style(cx);
let buffer = self.editor.read(cx).buffer().clone();
let Some(buffer) = buffer.read(cx).as_singleton() else {
return;
};
let mut tasks = Vec::new();
for path in paths {
- let Some(entry) = self.project.read(cx).entry_for_path(&path, cx) else {
+ let Some(entry) = project.read(cx).entry_for_path(&path, cx) else {
continue;
};
- let Some(worktree) = self.project.read(cx).worktree_for_id(path.worktree_id, cx) else {
+ let Some(worktree) = project.read(cx).worktree_for_id(path.worktree_id, cx) else {
continue;
};
let abs_path = worktree.read(cx).absolutize(&path.path);
@@ -689,9 +794,13 @@ impl MessageEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
+ let Some(workspace) = self.workspace.upgrade() else {
+ return;
+ };
+
self.clear(window, cx);
- let path_style = self.project.read(cx).path_style(cx);
+ let path_style = workspace.read(cx).project().read(cx).path_style(cx);
let mut text = String::new();
let mut mentions = Vec::new();
@@ -934,7 +1043,7 @@ mod tests {
cx.new(|cx| {
MessageEditor::new(
workspace.downgrade(),
- project.clone(),
+ project.downgrade(),
history_store.clone(),
None,
Default::default(),
@@ -1045,7 +1154,7 @@ mod tests {
cx.new(|cx| {
MessageEditor::new(
workspace_handle.clone(),
- project.clone(),
+ project.downgrade(),
history_store.clone(),
None,
prompt_capabilities.clone(),
@@ -1206,7 +1315,7 @@ mod tests {
let message_editor = cx.new(|cx| {
MessageEditor::new(
workspace_handle,
- project.clone(),
+ project.downgrade(),
history_store.clone(),
None,
prompt_capabilities.clone(),
@@ -1428,7 +1537,7 @@ mod tests {
let message_editor = cx.new(|cx| {
MessageEditor::new(
workspace_handle,
- project.clone(),
+ project.downgrade(),
history_store.clone(),
None,
prompt_capabilities.clone(),
@@ -1919,7 +2028,7 @@ mod tests {
cx.new(|cx| {
let editor = MessageEditor::new(
workspace.downgrade(),
- project.clone(),
+ project.downgrade(),
history_store.clone(),
None,
Default::default(),
@@ -2024,7 +2133,7 @@ mod tests {
cx.new(|cx| {
let mut editor = MessageEditor::new(
workspace.downgrade(),
- project.clone(),
+ project.downgrade(),
history_store.clone(),
None,
Default::default(),
@@ -2093,7 +2202,7 @@ mod tests {
cx.new(|cx| {
MessageEditor::new(
workspace.downgrade(),
- project.clone(),
+ project.downgrade(),
history_store.clone(),
None,
Default::default(),
@@ -2156,7 +2265,7 @@ mod tests {
let message_editor = cx.new(|cx| {
MessageEditor::new(
workspace_handle,
- project.clone(),
+ project.downgrade(),
history_store.clone(),
None,
Default::default(),
@@ -2314,7 +2423,7 @@ mod tests {
let message_editor = cx.new(|cx| {
MessageEditor::new(
workspace_handle,
- project.clone(),
+ project.downgrade(),
history_store.clone(),
None,
Default::default(),
@@ -100,7 +100,7 @@ impl ThreadError {
{
Self::ModelRequestLimitReached(error.plan)
} else if let Some(acp_error) = error.downcast_ref::<acp::Error>()
- && acp_error.code == acp::ErrorCode::AUTH_REQUIRED.code
+ && acp_error.code == acp::ErrorCode::AuthRequired
{
Self::AuthenticationRequired(acp_error.message.clone().into())
} else {
@@ -344,7 +344,7 @@ impl AcpThreadView {
let message_editor = cx.new(|cx| {
let mut editor = MessageEditor::new(
workspace.clone(),
- project.clone(),
+ project.downgrade(),
history_store.clone(),
prompt_store.clone(),
prompt_capabilities.clone(),
@@ -369,7 +369,7 @@ impl AcpThreadView {
let entry_view_state = cx.new(|_| {
EntryViewState::new(
workspace.clone(),
- project.clone(),
+ project.downgrade(),
history_store.clone(),
prompt_store.clone(),
prompt_capabilities.clone(),
@@ -6243,7 +6243,7 @@ pub(crate) mod tests {
StubAgentConnection::new().with_permission_requests(HashMap::from_iter([(
tool_call_id,
vec![acp::PermissionOption::new(
- "1".into(),
+ "1",
"Allow",
acp::PermissionOptionKind::AllowOnce,
)],
@@ -36,7 +36,7 @@ use settings::{Settings, SettingsStore, update_settings_file};
use ui::{
Button, ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure,
Divider, DividerColor, ElevationIndex, IconName, IconPosition, IconSize, Indicator, LabelSize,
- PopoverMenu, Switch, SwitchColor, Tooltip, WithScrollbar, prelude::*,
+ PopoverMenu, Switch, Tooltip, WithScrollbar, prelude::*,
};
use util::ResultExt as _;
use workspace::{Workspace, create_and_open_local_file};
@@ -883,7 +883,6 @@ impl AgentConfiguration {
.child(context_server_configuration_menu)
.child(
Switch::new("context-server-switch", is_running.into())
- .color(SwitchColor::Accent)
.on_click({
let context_server_manager = self.context_server_store.clone();
let fs = self.fs.clone();
@@ -108,7 +108,7 @@ impl Render for AgentModelSelector {
.child(
Icon::new(IconName::ChevronDown)
.color(color)
- .size(IconSize::XSmall),
+ .size(IconSize::Small),
),
move |_window, cx| {
Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx)
@@ -5,22 +5,26 @@ use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
+use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag};
use futures::{
SinkExt, Stream, StreamExt, TryStreamExt as _,
channel::mpsc,
future::{LocalBoxFuture, Shared},
join,
};
-use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task};
+use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
- LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelTextStream, Role, report_assistant_event,
+ LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
+ report_assistant_event,
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use prompt_store::PromptBuilder;
use rope::Rope;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
use smol::future::FutureExt;
use std::{
cmp,
@@ -34,6 +38,29 @@ use std::{
};
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
+use ui::SharedString;
+
+/// Use this tool to provide a message to the user when you're unable to complete a task.
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct FailureMessageInput {
+ /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
+ ///
+ /// The message may use markdown formatting if you wish.
+ pub message: String,
+}
+
+/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct RewriteSectionInput {
+ /// A brief description of the edit you have made.
+ ///
+ /// The description may use markdown formatting if you wish.
+ /// This is optional - if the edit is simple or obvious, you should leave it empty.
+ pub description: String,
+
+ /// The text to replace the section with.
+ pub replacement_text: String,
+}
pub struct BufferCodegen {
alternatives: Vec<Entity<CodegenAlternative>>,
@@ -238,6 +265,7 @@ pub struct CodegenAlternative {
elapsed_time: Option<f64>,
completion: Option<String>,
pub message_id: Option<String>,
+ pub model_explanation: Option<SharedString>,
}
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
@@ -288,14 +316,15 @@ impl CodegenAlternative {
generation: Task::ready(()),
diff: Diff::default(),
telemetry,
- _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
builder,
- active,
+ active: active,
edits: Vec::new(),
line_operations: Vec::new(),
range,
elapsed_time: None,
completion: None,
+ model_explanation: None,
+ _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
}
}
@@ -358,18 +387,124 @@ impl CodegenAlternative {
let api_key = model.api_key(cx);
let telemetry_id = model.telemetry_id();
let provider_id = model.provider_id();
- let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
- if user_prompt.trim().to_lowercase() == "delete" {
- async { Ok(LanguageModelTextStream::default()) }.boxed_local()
+
+ if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
+ let request = self.build_request(&model, user_prompt, context_task, cx)?;
+ let tool_use =
+ cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
+ self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
+ } else {
+ let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
+ if user_prompt.trim().to_lowercase() == "delete" {
+ async { Ok(LanguageModelTextStream::default()) }.boxed_local()
+ } else {
+ let request = self.build_request(&model, user_prompt, context_task, cx)?;
+ cx.spawn(async move |_, cx| {
+ Ok(model.stream_completion_text(request.await, cx).await?)
+ })
+ .boxed_local()
+ };
+ self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
+ }
+
+ Ok(())
+ }
+
+ fn build_request_v2(
+ &self,
+ model: &Arc<dyn LanguageModel>,
+ user_prompt: String,
+ context_task: Shared<Task<Option<LoadedContext>>>,
+ cx: &mut App,
+ ) -> Result<Task<LanguageModelRequest>> {
+ let buffer = self.buffer.read(cx).snapshot(cx);
+ let language = buffer.language_at(self.range.start);
+ let language_name = if let Some(language) = language.as_ref() {
+ if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
+ None
} else {
- let request = self.build_request(&model, user_prompt, context_task, cx)?;
- cx.spawn(async move |_, cx| {
- Ok(model.stream_completion_text(request.await, cx).await?)
- })
- .boxed_local()
+ Some(language.name())
+ }
+ } else {
+ None
+ };
+
+ let language_name = language_name.as_ref();
+ let start = buffer.point_to_buffer_offset(self.range.start);
+ let end = buffer.point_to_buffer_offset(self.range.end);
+ let (buffer, range) = if let Some((start, end)) = start.zip(end) {
+ let (start_buffer, start_buffer_offset) = start;
+ let (end_buffer, end_buffer_offset) = end;
+ if start_buffer.remote_id() == end_buffer.remote_id() {
+ (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
+ } else {
+ anyhow::bail!("invalid transformation range");
+ }
+ } else {
+ anyhow::bail!("invalid transformation range");
+ };
+
+ let system_prompt = self
+ .builder
+ .generate_inline_transformation_prompt_v2(
+ language_name,
+ buffer,
+ range.start.0..range.end.0,
+ )
+ .context("generating content prompt")?;
+
+ let temperature = AgentSettings::temperature_for_model(model, cx);
+
+ let tool_input_format = model.tool_input_format();
+
+ Ok(cx.spawn(async move |_cx| {
+ let mut messages = vec![LanguageModelRequestMessage {
+ role: Role::System,
+ content: vec![system_prompt.into()],
+ cache: false,
+ reasoning_details: None,
+ }];
+
+ let mut user_message = LanguageModelRequestMessage {
+ role: Role::User,
+ content: Vec::new(),
+ cache: false,
+ reasoning_details: None,
};
- self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
- Ok(())
+
+ if let Some(context) = context_task.await {
+ context.add_to_request_message(&mut user_message);
+ }
+
+ user_message.content.push(user_prompt.into());
+ messages.push(user_message);
+
+ let tools = vec![
+ LanguageModelRequestTool {
+ name: "rewrite_section".to_string(),
+ description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
+ input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
+ },
+ LanguageModelRequestTool {
+ name: "failure_message".to_string(),
+ description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
+ input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
+ },
+ ];
+
+ LanguageModelRequest {
+ thread_id: None,
+ prompt_id: None,
+ intent: Some(CompletionIntent::InlineAssist),
+ mode: None,
+ tools,
+ tool_choice: None,
+ stop: Vec::new(),
+ temperature,
+ messages,
+ thinking_allowed: false,
+ }
+ }))
}
fn build_request(
@@ -379,6 +514,10 @@ impl CodegenAlternative {
context_task: Shared<Task<Option<LoadedContext>>>,
cx: &mut App,
) -> Result<Task<LanguageModelRequest>> {
+ if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
+ return self.build_request_v2(model, user_prompt, context_task, cx);
+ }
+
let buffer = self.buffer.read(cx).snapshot(cx);
let language = buffer.language_at(self.range.start);
let language_name = if let Some(language) = language.as_ref() {
@@ -510,6 +649,7 @@ impl CodegenAlternative {
self.generation = cx.spawn(async move |codegen, cx| {
let stream = stream.await;
+
let token_usage = stream
.as_ref()
.ok()
@@ -899,6 +1039,101 @@ impl CodegenAlternative {
.ok();
})
}
+
+ fn handle_tool_use(
+ &mut self,
+ _telemetry_id: String,
+ _provider_id: String,
+ _api_key: Option<String>,
+ tool_use: impl 'static
+ + Future<
+ Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
+ >,
+ cx: &mut Context<Self>,
+ ) {
+ self.diff = Diff::default();
+ self.status = CodegenStatus::Pending;
+
+ self.generation = cx.spawn(async move |codegen, cx| {
+ let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
+ let _ = codegen.update(cx, |this, cx| {
+ this.status = status;
+ cx.emit(CodegenEvent::Finished);
+ cx.notify();
+ });
+ };
+
+ let tool_use = tool_use.await;
+
+ match tool_use {
+ Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
+ // Parse the input JSON into RewriteSectionInput
+ match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
+ Ok(input) => {
+ // Store the description if non-empty
+ let description = if !input.description.trim().is_empty() {
+ Some(input.description.clone())
+ } else {
+ None
+ };
+
+ // Apply the replacement text to the buffer and compute diff
+ let batch_diff_task = codegen
+ .update(cx, |this, cx| {
+ this.model_explanation = description.map(Into::into);
+ let range = this.range.clone();
+ this.apply_edits(
+ std::iter::once((range, input.replacement_text)),
+ cx,
+ );
+ this.reapply_batch_diff(cx)
+ })
+ .ok();
+
+ // Wait for the diff computation to complete
+ if let Some(diff_task) = batch_diff_task {
+ diff_task.await;
+ }
+
+ finish_with_status(CodegenStatus::Done, cx);
+ return;
+ }
+ Err(e) => {
+ finish_with_status(CodegenStatus::Error(e.into()), cx);
+ return;
+ }
+ }
+ }
+ Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
+ // Handle failure message tool use
+ match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
+ Ok(input) => {
+ let _ = codegen.update(cx, |this, _cx| {
+ // Store the failure message as the tool description
+ this.model_explanation = Some(input.message.into());
+ });
+ finish_with_status(CodegenStatus::Done, cx);
+ return;
+ }
+ Err(e) => {
+ finish_with_status(CodegenStatus::Error(e.into()), cx);
+ return;
+ }
+ }
+ }
+ Ok(_tool_use) => {
+ // Unexpected tool.
+ finish_with_status(CodegenStatus::Done, cx);
+ return;
+ }
+ Err(e) => {
+ finish_with_status(CodegenStatus::Error(e.into()), cx);
+ return;
+ }
+ }
+ });
+ cx.notify();
+ }
}
#[derive(Copy, Clone, Debug)]
@@ -1060,8 +1295,9 @@ mod tests {
};
use gpui::TestAppContext;
use indoc::indoc;
- use language::{Buffer, Language, LanguageConfig, LanguageMatcher, Point, tree_sitter_rust};
+ use language::{Buffer, Point};
use language_model::{LanguageModelRegistry, TokenUsage};
+ use languages::rust_lang;
use rand::prelude::*;
use settings::SettingsStore;
use std::{future, sync::Arc};
@@ -1078,7 +1314,7 @@ mod tests {
}
}
"};
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
@@ -1140,7 +1376,7 @@ mod tests {
le
}
"};
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
@@ -1204,7 +1440,7 @@ mod tests {
" \n",
"}\n" //
);
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
@@ -1320,7 +1556,7 @@ mod tests {
let x = 0;
}
"};
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
let range = buffer.read_with(cx, |buffer, cx| {
let snapshot = buffer.snapshot(cx);
@@ -1437,27 +1673,4 @@ mod tests {
});
chunks_tx
}
-
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_indents_query(
- r#"
- (call_expression) @indent
- (field_expression) @indent
- (_ "(" ")" @end) @indent
- (_ "{" "}" @end) @indent
- "#,
- )
- .unwrap()
- }
}
@@ -387,17 +387,9 @@ impl InlineAssistant {
let mut selections = Vec::<Selection<Point>>::new();
let mut newest_selection = None;
for mut selection in initial_selections {
- if selection.end > selection.start {
- selection.start.column = 0;
- // If the selection ends at the start of the line, we don't want to include it.
- if selection.end.column == 0 {
- selection.end.row -= 1;
- }
- selection.end.column = snapshot
- .buffer_snapshot()
- .line_len(MultiBufferRow(selection.end.row));
- } else if let Some(fold) =
- snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row))
+ if selection.end == selection.start
+ && let Some(fold) =
+ snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row))
{
selection.start = fold.range().start;
selection.end = fold.range().end;
@@ -424,6 +416,15 @@ impl InlineAssistant {
}
}
}
+ } else {
+ selection.start.column = 0;
+ // If the selection ends at the start of the line, we don't want to include it.
+ if selection.end.column == 0 && selection.start.row != selection.end.row {
+ selection.end.row -= 1;
+ }
+ selection.end.column = snapshot
+ .buffer_snapshot()
+ .line_len(MultiBufferRow(selection.end.row));
}
if let Some(prev_selection) = selections.last_mut()
@@ -544,14 +545,15 @@ impl InlineAssistant {
}
}
- let [prompt_block_id, end_block_id] =
- self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
+ let [prompt_block_id, tool_description_block_id, end_block_id] =
+ self.insert_assist_blocks(&editor, &range, &prompt_editor, cx);
assists.push((
assist_id,
range.clone(),
prompt_editor,
prompt_block_id,
+ tool_description_block_id,
end_block_id,
));
}
@@ -570,7 +572,15 @@ impl InlineAssistant {
};
let mut assist_group = InlineAssistGroup::new();
- for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
+ for (
+ assist_id,
+ range,
+ prompt_editor,
+ prompt_block_id,
+ tool_description_block_id,
+ end_block_id,
+ ) in assists
+ {
let codegen = prompt_editor.read(cx).codegen().clone();
self.assists.insert(
@@ -581,6 +591,7 @@ impl InlineAssistant {
editor,
&prompt_editor,
prompt_block_id,
+ tool_description_block_id,
end_block_id,
range,
codegen,
@@ -689,7 +700,7 @@ impl InlineAssistant {
range: &Range<Anchor>,
prompt_editor: &Entity<PromptEditor<BufferCodegen>>,
cx: &mut App,
- ) -> [CustomBlockId; 2] {
+ ) -> [CustomBlockId; 3] {
let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
prompt_editor
.editor
@@ -703,6 +714,14 @@ impl InlineAssistant {
render: build_assist_editor_renderer(prompt_editor),
priority: 0,
},
+ // Placeholder for tool description - will be updated dynamically
+ BlockProperties {
+ style: BlockStyle::Flex,
+ placement: BlockPlacement::Below(range.end),
+ height: Some(0),
+ render: Arc::new(|_cx| div().into_any_element()),
+ priority: 0,
+ },
BlockProperties {
style: BlockStyle::Sticky,
placement: BlockPlacement::Below(range.end),
@@ -721,7 +740,7 @@ impl InlineAssistant {
editor.update(cx, |editor, cx| {
let block_ids = editor.insert_blocks(assist_blocks, None, cx);
- [block_ids[0], block_ids[1]]
+ [block_ids[0], block_ids[1], block_ids[2]]
})
}
@@ -1113,6 +1132,9 @@ impl InlineAssistant {
let mut to_remove = decorations.removed_line_block_ids;
to_remove.insert(decorations.prompt_block_id);
to_remove.insert(decorations.end_block_id);
+ if let Some(tool_description_block_id) = decorations.model_explanation {
+ to_remove.insert(tool_description_block_id);
+ }
editor.remove_blocks(to_remove, None, cx);
});
@@ -1433,8 +1455,60 @@ impl InlineAssistant {
let old_snapshot = codegen.snapshot(cx);
let old_buffer = codegen.old_buffer(cx);
let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();
+ // let model_explanation = codegen.model_explanation(cx);
editor.update(cx, |editor, cx| {
+ // Update tool description block
+ // if let Some(description) = model_explanation {
+ // if let Some(block_id) = decorations.model_explanation {
+ // editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
+ // let new_block_id = editor.insert_blocks(
+ // [BlockProperties {
+ // style: BlockStyle::Flex,
+ // placement: BlockPlacement::Below(assist.range.end),
+ // height: Some(1),
+ // render: Arc::new({
+ // let description = description.clone();
+ // move |cx| {
+ // div()
+ // .w_full()
+ // .py_1()
+ // .px_2()
+ // .bg(cx.theme().colors().editor_background)
+ // .border_y_1()
+ // .border_color(cx.theme().status().info_border)
+ // .child(
+ // Label::new(description.clone())
+ // .color(Color::Muted)
+ // .size(LabelSize::Small),
+ // )
+ // .into_any_element()
+ // }
+ // }),
+ // priority: 0,
+ // }],
+ // None,
+ // cx,
+ // );
+ // decorations.model_explanation = new_block_id.into_iter().next();
+ // }
+ // } else if let Some(block_id) = decorations.model_explanation {
+ // // Hide the block if there's no description
+ // editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
+ // let new_block_id = editor.insert_blocks(
+ // [BlockProperties {
+ // style: BlockStyle::Flex,
+ // placement: BlockPlacement::Below(assist.range.end),
+ // height: Some(0),
+ // render: Arc::new(|_cx| div().into_any_element()),
+ // priority: 0,
+ // }],
+ // None,
+ // cx,
+ // );
+ // decorations.model_explanation = new_block_id.into_iter().next();
+ // }
+
let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
editor.remove_blocks(old_blocks, None, cx);
@@ -1686,6 +1760,7 @@ impl InlineAssist {
editor: &Entity<Editor>,
prompt_editor: &Entity<PromptEditor<BufferCodegen>>,
prompt_block_id: CustomBlockId,
+ tool_description_block_id: CustomBlockId,
end_block_id: CustomBlockId,
range: Range<Anchor>,
codegen: Entity<BufferCodegen>,
@@ -1700,7 +1775,8 @@ impl InlineAssist {
decorations: Some(InlineAssistDecorations {
prompt_block_id,
prompt_editor: prompt_editor.clone(),
- removed_line_block_ids: HashSet::default(),
+ removed_line_block_ids: Default::default(),
+ model_explanation: Some(tool_description_block_id),
end_block_id,
}),
range,
@@ -1804,6 +1880,7 @@ struct InlineAssistDecorations {
prompt_block_id: CustomBlockId,
prompt_editor: Entity<PromptEditor<BufferCodegen>>,
removed_line_block_ids: HashSet<CustomBlockId>,
+ model_explanation: Option<CustomBlockId>,
end_block_id: CustomBlockId,
}
@@ -10,10 +10,11 @@ use editor::{
};
use fs::Fs;
use gpui::{
- AnyElement, App, Context, CursorStyle, Entity, EventEmitter, FocusHandle, Focusable,
- Subscription, TextStyle, WeakEntity, Window,
+ AnyElement, App, Context, Entity, EventEmitter, FocusHandle, Focusable, Subscription,
+ TextStyle, TextStyleRefinement, WeakEntity, Window,
};
use language_model::{LanguageModel, LanguageModelRegistry};
+use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
use parking_lot::Mutex;
use project::Project;
use prompt_store::PromptStore;
@@ -65,7 +66,7 @@ impl<T: 'static> Render for PromptEditor<T> {
const RIGHT_PADDING: Pixels = px(9.);
- let (left_gutter_width, right_padding) = match &self.mode {
+ let (left_gutter_width, right_padding, explanation) = match &self.mode {
PromptEditorMode::Buffer {
id: _,
codegen,
@@ -83,17 +84,23 @@ impl<T: 'static> Render for PromptEditor<T> {
let left_gutter_width = gutter.full_width() + (gutter.margin / 2.0);
let right_padding = editor_margins.right + RIGHT_PADDING;
- (left_gutter_width, right_padding)
+ let explanation = codegen
+ .active_alternative()
+ .read(cx)
+ .model_explanation
+ .clone();
+
+ (left_gutter_width, right_padding, explanation)
}
PromptEditorMode::Terminal { .. } => {
// Give the equivalent of the same left-padding that we're using on the right
- (Pixels::from(40.0), Pixels::from(24.))
+ (Pixels::from(40.0), Pixels::from(24.), None)
}
};
let bottom_padding = match &self.mode {
PromptEditorMode::Buffer { .. } => rems_from_px(2.0),
- PromptEditorMode::Terminal { .. } => rems_from_px(8.0),
+ PromptEditorMode::Terminal { .. } => rems_from_px(4.0),
};
buttons.extend(self.render_buttons(window, cx));
@@ -111,22 +118,33 @@ impl<T: 'static> Render for PromptEditor<T> {
this.trigger_completion_menu(window, cx);
}));
+ let markdown = window.use_state(cx, |_, cx| Markdown::new("".into(), None, None, cx));
+
+ if let Some(explanation) = &explanation {
+ markdown.update(cx, |markdown, cx| {
+ markdown.reset(explanation.clone(), cx);
+ });
+ }
+
+ let explanation_label = self
+ .render_markdown(markdown, markdown_style(window, cx))
+ .into_any_element();
+
v_flex()
.key_context("PromptEditor")
.capture_action(cx.listener(Self::paste))
- .bg(cx.theme().colors().editor_background)
.block_mouse_except_scroll()
- .gap_0p5()
- .border_y_1()
- .border_color(cx.theme().status().info_border)
.size_full()
.pt_0p5()
.pb(bottom_padding)
.pr(right_padding)
+ .gap_0p5()
+ .justify_center()
+ .border_y_1()
+ .border_color(cx.theme().colors().border)
+ .bg(cx.theme().colors().editor_background)
.child(
h_flex()
- .items_start()
- .cursor(CursorStyle::Arrow)
.on_action(cx.listener(|this, _: &ToggleModelSelector, window, cx| {
this.model_selector
.update(cx, |model_selector, cx| model_selector.toggle(window, cx));
@@ -139,14 +157,14 @@ impl<T: 'static> Render for PromptEditor<T> {
.capture_action(cx.listener(Self::cycle_next))
.child(
WithRemSize::new(ui_font_size)
+ .h_full()
+ .w(left_gutter_width)
.flex()
.flex_row()
.flex_shrink_0()
.items_center()
- .h_full()
- .w(left_gutter_width)
.justify_center()
- .gap_2()
+ .gap_1()
.child(self.render_close_button(cx))
.map(|el| {
let CodegenStatus::Error(error) = self.codegen_status(cx) else {
@@ -177,26 +195,83 @@ impl<T: 'static> Render for PromptEditor<T> {
.flex_row()
.items_center()
.gap_1()
+ .child(add_context_button)
+ .child(self.model_selector.clone())
.children(buttons),
),
),
)
- .child(
- WithRemSize::new(ui_font_size)
- .flex()
- .flex_row()
- .items_center()
- .child(h_flex().flex_shrink_0().w(left_gutter_width))
- .child(
- h_flex()
- .w_full()
- .pl_1()
- .items_start()
- .justify_between()
- .child(add_context_button)
- .child(self.model_selector.clone()),
- ),
- )
+ .when_some(explanation, |this, _| {
+ this.child(
+ h_flex()
+ .size_full()
+ .justify_center()
+ .child(div().w(left_gutter_width + px(6.)))
+ .child(
+ div()
+ .size_full()
+ .min_w_0()
+ .pt(rems_from_px(3.))
+ .pl_0p5()
+ .flex_1()
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ .child(explanation_label),
+ ),
+ )
+ })
+ }
+}
+
+fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
+ let theme_settings = ThemeSettings::get_global(cx);
+ let colors = cx.theme().colors();
+ let mut text_style = window.text_style();
+
+ text_style.refine(&TextStyleRefinement {
+ font_family: Some(theme_settings.ui_font.family.clone()),
+ color: Some(colors.text),
+ ..Default::default()
+ });
+
+ MarkdownStyle {
+ base_text_style: text_style.clone(),
+ syntax: cx.theme().syntax().clone(),
+ selection_background_color: colors.element_selection_background,
+ heading_level_styles: Some(HeadingLevelStyles {
+ h1: Some(TextStyleRefinement {
+ font_size: Some(rems(1.15).into()),
+ ..Default::default()
+ }),
+ h2: Some(TextStyleRefinement {
+ font_size: Some(rems(1.1).into()),
+ ..Default::default()
+ }),
+ h3: Some(TextStyleRefinement {
+ font_size: Some(rems(1.05).into()),
+ ..Default::default()
+ }),
+ h4: Some(TextStyleRefinement {
+ font_size: Some(rems(1.).into()),
+ ..Default::default()
+ }),
+ h5: Some(TextStyleRefinement {
+ font_size: Some(rems(0.95).into()),
+ ..Default::default()
+ }),
+ h6: Some(TextStyleRefinement {
+ font_size: Some(rems(0.875).into()),
+ ..Default::default()
+ }),
+ }),
+ inline_code: TextStyleRefinement {
+ font_family: Some(theme_settings.buffer_font.family.clone()),
+ font_fallbacks: theme_settings.buffer_font.fallbacks.clone(),
+ font_features: Some(theme_settings.buffer_font.features.clone()),
+ background_color: Some(colors.editor_foreground.opacity(0.08)),
+ ..Default::default()
+ },
+ ..Default::default()
}
}
@@ -759,6 +834,10 @@ impl<T: 'static> PromptEditor<T> {
})
.into_any_element()
}
+
+ fn render_markdown(&self, markdown: Entity<Markdown>, style: MarkdownStyle) -> MarkdownElement {
+ MarkdownElement::new(markdown, style)
+ }
}
pub enum PromptEditorMode {
@@ -1682,6 +1682,98 @@ impl TextThreadEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
+ let editor_clipboard_selections = cx
+ .read_from_clipboard()
+ .and_then(|item| item.entries().first().cloned())
+ .and_then(|entry| match entry {
+ ClipboardEntry::String(text) => {
+ text.metadata_json::<Vec<editor::ClipboardSelection>>()
+ }
+ _ => None,
+ });
+
+ let has_file_context = editor_clipboard_selections
+ .as_ref()
+ .is_some_and(|selections| {
+ selections
+ .iter()
+ .any(|sel| sel.file_path.is_some() && sel.line_range.is_some())
+ });
+
+ if has_file_context {
+ if let Some(clipboard_item) = cx.read_from_clipboard() {
+ if let Some(ClipboardEntry::String(clipboard_text)) =
+ clipboard_item.entries().first()
+ {
+ if let Some(selections) = editor_clipboard_selections {
+ cx.stop_propagation();
+
+ let text = clipboard_text.text();
+ self.editor.update(cx, |editor, cx| {
+ let mut current_offset = 0;
+ let weak_editor = cx.entity().downgrade();
+
+ for selection in selections {
+ if let (Some(file_path), Some(line_range)) =
+ (selection.file_path, selection.line_range)
+ {
+ let selected_text =
+ &text[current_offset..current_offset + selection.len];
+ let fence = assistant_slash_commands::codeblock_fence_for_path(
+ file_path.to_str(),
+ Some(line_range.clone()),
+ );
+ let formatted_text = format!("{fence}{selected_text}\n```");
+
+ let insert_point = editor
+ .selections
+ .newest::<Point>(&editor.display_snapshot(cx))
+ .head();
+ let start_row = MultiBufferRow(insert_point.row);
+
+ editor.insert(&formatted_text, window, cx);
+
+ let snapshot = editor.buffer().read(cx).snapshot(cx);
+ let anchor_before = snapshot.anchor_after(insert_point);
+ let anchor_after = editor
+ .selections
+ .newest_anchor()
+ .head()
+ .bias_left(&snapshot);
+
+ editor.insert("\n", window, cx);
+
+ let crease_text = acp_thread::selection_name(
+ Some(file_path.as_ref()),
+ &line_range,
+ );
+
+ let fold_placeholder = quote_selection_fold_placeholder(
+ crease_text,
+ weak_editor.clone(),
+ );
+ let crease = Crease::inline(
+ anchor_before..anchor_after,
+ fold_placeholder,
+ render_quote_selection_output_toggle,
+ |_, _, _, _| Empty.into_any(),
+ );
+ editor.insert_creases(vec![crease], cx);
+ editor.fold_at(start_row, window, cx);
+
+ current_offset += selection.len;
+ if !selection.is_entire_line && current_offset < text.len() {
+ current_offset += 1;
+ }
+ }
+ }
+ });
+ return;
+ }
+ }
+ }
+ }
+
cx.stop_propagation();
let mut images = if let Some(item) = cx.read_from_clipboard() {
@@ -106,9 +106,6 @@ impl Render for AgentNotification {
.font(ui_font)
.border_color(cx.theme().colors().border)
.rounded_xl()
- .on_click(cx.listener(|_, _, _, cx| {
- cx.emit(AgentNotificationEvent::Accepted);
- }))
.child(
h_flex()
.items_start()
@@ -12,6 +12,8 @@ pub use settings::ModelMode;
use strum::{EnumIter, EnumString};
use thiserror::Error;
+pub mod batches;
+
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
@@ -465,6 +467,7 @@ impl Model {
}
}
+/// Generate completion with streaming.
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
@@ -477,6 +480,101 @@ pub async fn stream_completion(
.map(|output| output.0)
}
+/// Generate completion without streaming.
+pub async fn non_streaming_completion(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+ beta_headers: Option<String>,
+) -> Result<Response, AnthropicError> {
+ let (mut response, rate_limits) =
+ send_request(client, api_url, api_key, &request, beta_headers).await?;
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse)?;
+
+ serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+ } else {
+ Err(handle_error_response(response, rate_limits).await)
+ }
+}
+
+async fn send_request(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: impl Serialize,
+ beta_headers: Option<String>,
+) -> Result<(http::Response<AsyncBody>, RateLimitInfo), AnthropicError> {
+ let uri = format!("{api_url}/v1/messages");
+
+ let mut request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("X-Api-Key", api_key.trim())
+ .header("Content-Type", "application/json");
+
+ if let Some(beta_headers) = beta_headers {
+ request_builder = request_builder.header("Anthropic-Beta", beta_headers);
+ }
+
+ let serialized_request =
+ serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
+ let request = request_builder
+ .body(AsyncBody::from(serialized_request))
+ .map_err(AnthropicError::BuildRequestBody)?;
+
+ let response = client
+ .send(request)
+ .await
+ .map_err(AnthropicError::HttpSend)?;
+
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+ Ok((response, rate_limits))
+}
+
+async fn handle_error_response(
+ mut response: http::Response<AsyncBody>,
+ rate_limits: RateLimitInfo,
+) -> AnthropicError {
+ if response.status().as_u16() == 529 {
+ return AnthropicError::ServerOverloaded {
+ retry_after: rate_limits.retry_after,
+ };
+ }
+
+ if let Some(retry_after) = rate_limits.retry_after {
+ return AnthropicError::RateLimit { retry_after };
+ }
+
+ let mut body = String::new();
+ let read_result = response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse);
+
+ if let Err(err) = read_result {
+ return err;
+ }
+
+ match serde_json::from_str::<Event>(&body) {
+ Ok(Event::Error { error }) => AnthropicError::ApiError(error),
+ Ok(_) | Err(_) => AnthropicError::HttpResponseError {
+ status_code: response.status(),
+ message: body,
+ },
+ }
+}
+
/// An individual rate limit.
#[derive(Debug)]
pub struct RateLimit {
@@ -580,30 +678,10 @@ pub async fn stream_completion_with_rate_limit_info(
base: request,
stream: true,
};
- let uri = format!("{api_url}/v1/messages");
- let mut request_builder = HttpRequest::builder()
- .method(Method::POST)
- .uri(uri)
- .header("Anthropic-Version", "2023-06-01")
- .header("X-Api-Key", api_key.trim())
- .header("Content-Type", "application/json");
+ let (response, rate_limits) =
+ send_request(client, api_url, api_key, &request, beta_headers).await?;
- if let Some(beta_headers) = beta_headers {
- request_builder = request_builder.header("Anthropic-Beta", beta_headers);
- }
-
- let serialized_request =
- serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
- let request = request_builder
- .body(AsyncBody::from(serialized_request))
- .map_err(AnthropicError::BuildRequestBody)?;
-
- let mut response = client
- .send(request)
- .await
- .map_err(AnthropicError::HttpSend)?;
- let rate_limits = RateLimitInfo::from_headers(response.headers());
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
let stream = reader
@@ -622,27 +700,8 @@ pub async fn stream_completion_with_rate_limit_info(
})
.boxed();
Ok((stream, Some(rate_limits)))
- } else if response.status().as_u16() == 529 {
- Err(AnthropicError::ServerOverloaded {
- retry_after: rate_limits.retry_after,
- })
- } else if let Some(retry_after) = rate_limits.retry_after {
- Err(AnthropicError::RateLimit { retry_after })
} else {
- let mut body = String::new();
- response
- .body_mut()
- .read_to_string(&mut body)
- .await
- .map_err(AnthropicError::ReadResponse)?;
-
- match serde_json::from_str::<Event>(&body) {
- Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
- Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
- status_code: response.status(),
- message: body,
- }),
- }
+ Err(handle_error_response(response, rate_limits).await)
}
}
@@ -0,0 +1,190 @@
+use anyhow::Result;
+use futures::AsyncReadExt;
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+
+use crate::{AnthropicError, ApiError, RateLimitInfo, Request, Response};
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchRequest {
+ pub custom_id: String,
+ pub params: Request,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CreateBatchRequest {
+ pub requests: Vec<BatchRequest>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct MessageBatchRequestCounts {
+ pub processing: u64,
+ pub succeeded: u64,
+ pub errored: u64,
+ pub canceled: u64,
+ pub expired: u64,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct MessageBatch {
+ pub id: String,
+ #[serde(rename = "type")]
+ pub batch_type: String,
+ pub processing_status: String,
+ pub request_counts: MessageBatchRequestCounts,
+ pub ended_at: Option<String>,
+ pub created_at: String,
+ pub expires_at: String,
+ pub archived_at: Option<String>,
+ pub cancel_initiated_at: Option<String>,
+ pub results_url: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum BatchResult {
+ #[serde(rename = "succeeded")]
+ Succeeded { message: Response },
+ #[serde(rename = "errored")]
+ Errored { error: ApiError },
+ #[serde(rename = "canceled")]
+ Canceled,
+ #[serde(rename = "expired")]
+ Expired,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchIndividualResponse {
+ pub custom_id: String,
+ pub result: BatchResult,
+}
+
+pub async fn create_batch(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: CreateBatchRequest,
+) -> Result<MessageBatch, AnthropicError> {
+ let uri = format!("{api_url}/v1/messages/batches");
+
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("X-Api-Key", api_key.trim())
+ .header("Content-Type", "application/json");
+
+ let serialized_request =
+ serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
+ let http_request = request_builder
+ .body(AsyncBody::from(serialized_request))
+ .map_err(AnthropicError::BuildRequestBody)?;
+
+ let mut response = client
+ .send(http_request)
+ .await
+ .map_err(AnthropicError::HttpSend)?;
+
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse)?;
+
+ serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+ } else {
+ Err(crate::handle_error_response(response, rate_limits).await)
+ }
+}
+
+pub async fn retrieve_batch(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ message_batch_id: &str,
+) -> Result<MessageBatch, AnthropicError> {
+ let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}");
+
+ let request_builder = HttpRequest::builder()
+ .method(Method::GET)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("X-Api-Key", api_key.trim());
+
+ let http_request = request_builder
+ .body(AsyncBody::default())
+ .map_err(AnthropicError::BuildRequestBody)?;
+
+ let mut response = client
+ .send(http_request)
+ .await
+ .map_err(AnthropicError::HttpSend)?;
+
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse)?;
+
+ serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
+ } else {
+ Err(crate::handle_error_response(response, rate_limits).await)
+ }
+}
+
+pub async fn retrieve_batch_results(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ message_batch_id: &str,
+) -> Result<Vec<BatchIndividualResponse>, AnthropicError> {
+ let uri = format!("{api_url}/v1/messages/batches/{message_batch_id}/results");
+
+ let request_builder = HttpRequest::builder()
+ .method(Method::GET)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("X-Api-Key", api_key.trim());
+
+ let http_request = request_builder
+ .body(AsyncBody::default())
+ .map_err(AnthropicError::BuildRequestBody)?;
+
+ let mut response = client
+ .send(http_request)
+ .await
+ .map_err(AnthropicError::HttpSend)?;
+
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
+
+ if response.status().is_success() {
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse)?;
+
+ let mut results = Vec::new();
+ for line in body.lines() {
+ if line.trim().is_empty() {
+ continue;
+ }
+ let result: BatchIndividualResponse =
+ serde_json::from_str(line).map_err(AnthropicError::DeserializeResponse)?;
+ results.push(result);
+ }
+
+ Ok(results)
+ } else {
+ Err(crate::handle_error_response(response, rate_limits).await)
+ }
+}
@@ -14,7 +14,7 @@ use fs::{Fs, RenameOptions};
use futures::{FutureExt, StreamExt, future::Shared};
use gpui::{
App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
- Task,
+ Task, WeakEntity,
};
use itertools::Itertools as _;
use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
@@ -688,7 +688,7 @@ pub struct TextThread {
_subscriptions: Vec<Subscription>,
telemetry: Option<Arc<Telemetry>>,
language_registry: Arc<LanguageRegistry>,
- project: Option<Entity<Project>>,
+ project: Option<WeakEntity<Project>>,
prompt_builder: Arc<PromptBuilder>,
completion_mode: agent_settings::CompletionMode,
}
@@ -708,7 +708,7 @@ impl EventEmitter<TextThreadEvent> for TextThread {}
impl TextThread {
pub fn local(
language_registry: Arc<LanguageRegistry>,
- project: Option<Entity<Project>>,
+ project: Option<WeakEntity<Project>>,
telemetry: Option<Arc<Telemetry>>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
@@ -742,7 +742,7 @@ impl TextThread {
language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
- project: Option<Entity<Project>>,
+ project: Option<WeakEntity<Project>>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut Context<Self>,
) -> Self {
@@ -873,7 +873,7 @@ impl TextThread {
language_registry: Arc<LanguageRegistry>,
prompt_builder: Arc<PromptBuilder>,
slash_commands: Arc<SlashCommandWorkingSet>,
- project: Option<Entity<Project>>,
+ project: Option<WeakEntity<Project>>,
telemetry: Option<Arc<Telemetry>>,
cx: &mut Context<Self>,
) -> Self {
@@ -1167,10 +1167,6 @@ impl TextThread {
self.language_registry.clone()
}
- pub fn project(&self) -> Option<Entity<Project>> {
- self.project.clone()
- }
-
pub fn prompt_builder(&self) -> Arc<PromptBuilder> {
self.prompt_builder.clone()
}
@@ -2967,7 +2963,7 @@ impl TextThread {
}
fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut App) {
- let Some(project) = &self.project else {
+ let Some(project) = self.project.as_ref().and_then(|project| project.upgrade()) else {
return;
};
project.read(cx).user_store().update(cx, |user_store, cx| {
@@ -51,7 +51,7 @@ pub struct TextThreadStore {
telemetry: Arc<Telemetry>,
_watch_updates: Task<Option<()>>,
client: Arc<Client>,
- project: Entity<Project>,
+ project: WeakEntity<Project>,
project_is_shared: bool,
client_subscription: Option<client::Subscription>,
_project_subscriptions: Vec<gpui::Subscription>,
@@ -119,10 +119,10 @@ impl TextThreadStore {
],
project_is_shared: false,
client: project.read(cx).client(),
- project: project.clone(),
+ project: project.downgrade(),
prompt_builder,
};
- this.handle_project_shared(project.clone(), cx);
+ this.handle_project_shared(cx);
this.synchronize_contexts(cx);
this.register_context_server_handlers(cx);
this.reload(cx).detach_and_log_err(cx);
@@ -146,7 +146,7 @@ impl TextThreadStore {
telemetry: project.read(cx).client().telemetry().clone(),
_watch_updates: Task::ready(None),
client: project.read(cx).client(),
- project,
+ project: project.downgrade(),
project_is_shared: false,
client_subscription: None,
_project_subscriptions: Default::default(),
@@ -180,8 +180,10 @@ impl TextThreadStore {
) -> Result<proto::OpenContextResponse> {
let context_id = TextThreadId::from_proto(envelope.payload.context_id);
let operations = this.update(&mut cx, |this, cx| {
+ let project = this.project.upgrade().context("project not found")?;
+
anyhow::ensure!(
- !this.project.read(cx).is_via_collab(),
+ !project.read(cx).is_via_collab(),
"only the host contexts can be opened"
);
@@ -211,8 +213,9 @@ impl TextThreadStore {
mut cx: AsyncApp,
) -> Result<proto::CreateContextResponse> {
let (context_id, operations) = this.update(&mut cx, |this, cx| {
+ let project = this.project.upgrade().context("project not found")?;
anyhow::ensure!(
- !this.project.read(cx).is_via_collab(),
+ !project.read(cx).is_via_collab(),
"can only create contexts as the host"
);
@@ -255,8 +258,9 @@ impl TextThreadStore {
mut cx: AsyncApp,
) -> Result<proto::SynchronizeContextsResponse> {
this.update(&mut cx, |this, cx| {
+ let project = this.project.upgrade().context("project not found")?;
anyhow::ensure!(
- !this.project.read(cx).is_via_collab(),
+ !project.read(cx).is_via_collab(),
"only the host can synchronize contexts"
);
@@ -293,8 +297,12 @@ impl TextThreadStore {
})?
}
- fn handle_project_shared(&mut self, _: Entity<Project>, cx: &mut Context<Self>) {
- let is_shared = self.project.read(cx).is_shared();
+ fn handle_project_shared(&mut self, cx: &mut Context<Self>) {
+ let Some(project) = self.project.upgrade() else {
+ return;
+ };
+
+ let is_shared = project.read(cx).is_shared();
let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
if is_shared == was_shared {
return;
@@ -309,7 +317,7 @@ impl TextThreadStore {
false
}
});
- let remote_id = self.project.read(cx).remote_id().unwrap();
+ let remote_id = project.read(cx).remote_id().unwrap();
self.client_subscription = self
.client
.subscribe_to_entity(remote_id)
@@ -323,13 +331,13 @@ impl TextThreadStore {
fn handle_project_event(
&mut self,
- project: Entity<Project>,
+ _project: Entity<Project>,
event: &project::Event,
cx: &mut Context<Self>,
) {
match event {
project::Event::RemoteIdChanged(_) => {
- self.handle_project_shared(project, cx);
+ self.handle_project_shared(cx);
}
project::Event::Reshared => {
self.advertise_contexts(cx);
@@ -382,7 +390,10 @@ impl TextThreadStore {
}
pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
- let project = self.project.read(cx);
+ let Some(project) = self.project.upgrade() else {
+ return Task::ready(Err(anyhow::anyhow!("project was dropped")));
+ };
+ let project = project.read(cx);
let Some(project_id) = project.remote_id() else {
return Task::ready(Err(anyhow::anyhow!("project was not remote")));
};
@@ -541,7 +552,10 @@ impl TextThreadStore {
text_thread_id: TextThreadId,
cx: &mut Context<Self>,
) -> Task<Result<Entity<TextThread>>> {
- let project = self.project.read(cx);
+ let Some(project) = self.project.upgrade() else {
+ return Task::ready(Err(anyhow::anyhow!("project was dropped")));
+ };
+ let project = project.read(cx);
let Some(project_id) = project.remote_id() else {
return Task::ready(Err(anyhow::anyhow!("project was not remote")));
};
@@ -618,7 +632,10 @@ impl TextThreadStore {
event: &TextThreadEvent,
cx: &mut Context<Self>,
) {
- let Some(project_id) = self.project.read(cx).remote_id() else {
+ let Some(project) = self.project.upgrade() else {
+ return;
+ };
+ let Some(project_id) = project.read(cx).remote_id() else {
return;
};
@@ -652,12 +669,14 @@ impl TextThreadStore {
}
fn advertise_contexts(&self, cx: &App) {
- let Some(project_id) = self.project.read(cx).remote_id() else {
+ let Some(project) = self.project.upgrade() else {
+ return;
+ };
+ let Some(project_id) = project.read(cx).remote_id() else {
return;
};
-
// For now, only the host can advertise their open contexts.
- if self.project.read(cx).is_via_collab() {
+ if project.read(cx).is_via_collab() {
return;
}
@@ -689,7 +708,10 @@ impl TextThreadStore {
}
fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
- let Some(project_id) = self.project.read(cx).remote_id() else {
+ let Some(project) = self.project.upgrade() else {
+ return;
+ };
+ let Some(project_id) = project.read(cx).remote_id() else {
return;
};
@@ -828,7 +850,10 @@ impl TextThreadStore {
}
fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
- let context_server_store = self.project.read(cx).context_server_store();
+ let Some(project) = self.project.upgrade() else {
+ return;
+ };
+ let context_server_store = project.read(cx).context_server_store();
cx.subscribe(&context_server_store, Self::handle_context_server_event)
.detach();
@@ -31,18 +31,10 @@ pub struct PredictEditsRequest {
/// Within `signatures`
pub excerpt_parent: Option<usize>,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub included_files: Vec<IncludedFile>,
- #[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub signatures: Vec<Signature>,
- #[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub referenced_declarations: Vec<ReferencedDeclaration>,
+ pub related_files: Vec<RelatedFile>,
pub events: Vec<Arc<Event>>,
#[serde(default)]
pub can_collect_data: bool,
- #[serde(skip_serializing_if = "Vec::is_empty", default)]
- pub diagnostic_groups: Vec<DiagnosticGroup>,
- #[serde(skip_serializing_if = "is_default", default)]
- pub diagnostic_groups_truncated: bool,
/// Info about the git repository state, only present when can_collect_data is true.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub git_info: Option<PredictEditsGitInfo>,
@@ -58,7 +50,7 @@ pub struct PredictEditsRequest {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct IncludedFile {
+pub struct RelatedFile {
pub path: Arc<Path>,
pub max_row: Line,
pub excerpts: Vec<Excerpt>,
@@ -72,11 +64,9 @@ pub struct Excerpt {
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
- MarkedExcerpt,
- LabeledSections,
- NumLinesUniDiff,
+ /// XML old_tex/new_text
OldTextNewText,
- /// Prompt format intended for use via zeta_cli
+ /// Prompt format intended for use via edit_prediction_cli
OnlySnippets,
/// One-sentence instructions used in fine-tuned models
Minimal,
@@ -87,7 +77,7 @@ pub enum PromptFormat {
}
impl PromptFormat {
- pub const DEFAULT: PromptFormat = PromptFormat::NumLinesUniDiff;
+ pub const DEFAULT: PromptFormat = PromptFormat::Minimal;
}
impl Default for PromptFormat {
@@ -105,10 +95,7 @@ impl PromptFormat {
impl std::fmt::Display for PromptFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
- PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
- PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
- PromptFormat::NumLinesUniDiff => write!(f, "Numbered Lines / Unified Diff"),
PromptFormat::OldTextNewText => write!(f, "Old Text / New Text"),
PromptFormat::Minimal => write!(f, "Minimal"),
PromptFormat::MinimalQwen => write!(f, "Minimal + Qwen FIM"),
@@ -178,67 +165,6 @@ impl<'a> std::fmt::Display for DiffPathFmt<'a> {
}
}
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Signature {
- pub text: String,
- pub text_is_truncated: bool,
- #[serde(skip_serializing_if = "Option::is_none", default)]
- pub parent_index: Option<usize>,
- /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
- /// file is implicitly the file that contains the descendant declaration or excerpt.
- pub range: Range<Line>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ReferencedDeclaration {
- pub path: Arc<Path>,
- pub text: String,
- pub text_is_truncated: bool,
- /// Range of `text` within file, possibly truncated according to `text_is_truncated`
- pub range: Range<Line>,
- /// Range within `text`
- pub signature_range: Range<usize>,
- /// Index within `signatures`.
- #[serde(skip_serializing_if = "Option::is_none", default)]
- pub parent_index: Option<usize>,
- pub score_components: DeclarationScoreComponents,
- pub signature_score: f32,
- pub declaration_score: f32,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct DeclarationScoreComponents {
- pub is_same_file: bool,
- pub is_referenced_nearby: bool,
- pub is_referenced_in_breadcrumb: bool,
- pub reference_count: usize,
- pub same_file_declaration_count: usize,
- pub declaration_count: usize,
- pub reference_line_distance: u32,
- pub declaration_line_distance: u32,
- pub excerpt_vs_item_jaccard: f32,
- pub excerpt_vs_signature_jaccard: f32,
- pub adjacent_vs_item_jaccard: f32,
- pub adjacent_vs_signature_jaccard: f32,
- pub excerpt_vs_item_weighted_overlap: f32,
- pub excerpt_vs_signature_weighted_overlap: f32,
- pub adjacent_vs_item_weighted_overlap: f32,
- pub adjacent_vs_signature_weighted_overlap: f32,
- pub path_import_match_count: usize,
- pub wildcard_path_import_match_count: usize,
- pub import_similarity: f32,
- pub max_import_similarity: f32,
- pub normalized_import_similarity: f32,
- pub wildcard_import_similarity: f32,
- pub normalized_wildcard_import_similarity: f32,
- pub included_by_others: usize,
- pub includes_others: usize,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-#[serde(transparent)]
-pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
-
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsResponse {
pub request_id: Uuid,
@@ -262,10 +188,6 @@ pub struct Edit {
pub content: String,
}
-fn is_default<T: Default + PartialEq>(value: &T) -> bool {
- *value == T::default()
-}
-
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
pub struct Point {
pub line: Line,
@@ -15,9 +15,4 @@ path = "src/cloud_zeta2_prompt.rs"
anyhow.workspace = true
cloud_llm_client.workspace = true
indoc.workspace = true
-ordered-float.workspace = true
-rustc-hash.workspace = true
-schemars.workspace = true
serde.workspace = true
-serde_json.workspace = true
-strum.workspace = true
@@ -1,20 +1,12 @@
-//! Zeta2 prompt planning and generation code shared with cloud.
-pub mod retrieval_prompt;
-
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::Result;
use cloud_llm_client::predict_edits_v3::{
- self, DiffPathFmt, Event, Excerpt, IncludedFile, Line, Point, PromptFormat,
- ReferencedDeclaration,
+ self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
};
use indoc::indoc;
-use ordered_float::OrderedFloat;
-use rustc_hash::{FxHashMap, FxHashSet};
-use serde::Serialize;
use std::cmp;
use std::fmt::Write;
+use std::path::Path;
use std::sync::Arc;
-use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
-use strum::{EnumIter, IntoEnumIterator};
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
@@ -24,69 +16,6 @@ pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_s
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
-// TODO: use constants for markers?
-const MARKED_EXCERPT_INSTRUCTIONS: &str = indoc! {"
- You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
-
- The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor|>. Please respond with edited code for that region.
-
- Other code is provided for context, and `โฆ` indicates when code has been skipped.
-
- ## Edit History
-
-"};
-
-const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#"
- You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.
-
- Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).
-
- The cursor position is marked with `<|user_cursor|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.
-
- Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:
-
- <|current_section|>
- for i in 0..16 {
- println!("{i}");
- }
-
- ## Edit History
-
-"#};
-
-const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#"
- # Instructions
-
- You are an edit prediction agent in a code editor.
- Your job is to predict the next edit that the user will make,
- based on their last few edits and their current cursor location.
-
- ## Output Format
-
- You must briefly explain your understanding of the user's goal, in one
- or two sentences, and then specify their next edit in the form of a
- unified diff, like this:
-
- ```
- --- a/src/myapp/cli.py
- +++ b/src/myapp/cli.py
- @@ ... @@
- import os
- import time
- import sys
- +from constants import LOG_LEVEL_WARNING
- @@ ... @@
- config.headless()
- config.set_interactive(false)
- -config.set_log_level(LOG_L)
- +config.set_log_level(LOG_LEVEL_WARNING)
- config.set_use_color(True)
- ```
-
- ## Edit History
-
-"#};
-
const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
@@ -94,20 +23,6 @@ const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
"#};
-const UNIFIED_DIFF_REMINDER: &str = indoc! {"
- ---
-
- Analyze the edit history and the files, then provide the unified diff for your predicted edits.
- Do not include the cursor marker in your output.
- Your diff should include edited file paths in its file headers (lines beginning with `---` and `+++`).
- Do not include line numbers in the hunk headers, use `@@ ... @@`.
- Removed lines begin with `-`.
- Added lines begin with `+`.
- Context lines begin with an extra space.
- Context and removed lines are used to match the target edit location, so make sure to include enough of them
- to uniquely identify it amongst all excerpts of code provided.
-"};
-
const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
---
@@ -164,49 +79,25 @@ const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
Remember that the edits in the edit history have already been applied.
"#};
-pub fn build_prompt(
- request: &predict_edits_v3::PredictEditsRequest,
-) -> Result<(String, SectionLabels)> {
- let mut section_labels = Default::default();
-
+pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
let prompt_data = PromptData {
events: request.events.clone(),
cursor_point: request.cursor_point,
cursor_path: request.excerpt_path.clone(),
- included_files: request.included_files.clone(),
+ included_files: request.related_files.clone(),
};
match request.prompt_format {
PromptFormat::MinimalQwen => {
- return Ok((MinimalQwenPrompt.render(&prompt_data), section_labels));
+ return Ok(MinimalQwenPrompt.render(&prompt_data));
}
PromptFormat::SeedCoder1120 => {
- return Ok((SeedCoder1120Prompt.render(&prompt_data), section_labels));
+ return Ok(SeedCoder1120Prompt.render(&prompt_data));
}
_ => (),
};
- let mut insertions = match request.prompt_format {
- PromptFormat::MarkedExcerpt => vec![
- (
- Point {
- line: request.excerpt_line_range.start,
- column: 0,
- },
- EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
- ),
- (request.cursor_point, CURSOR_MARKER),
- (
- Point {
- line: request.excerpt_line_range.end,
- column: 0,
- },
- EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
- ),
- ],
- PromptFormat::LabeledSections
- | PromptFormat::NumLinesUniDiff
- | PromptFormat::Minimal
- | PromptFormat::OldTextNewText => {
+ let insertions = match request.prompt_format {
+ PromptFormat::Minimal | PromptFormat::OldTextNewText => {
vec![(request.cursor_point, CURSOR_MARKER)]
}
PromptFormat::OnlySnippets => vec![],
@@ -215,9 +106,6 @@ pub fn build_prompt(
};
let mut prompt = match request.prompt_format {
- PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
- PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
- PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
PromptFormat::OnlySnippets => String::new(),
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
@@ -247,7 +135,7 @@ pub fn build_prompt(
You can only edit exactly this part of the file.
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
"},
- PromptFormat::NumLinesUniDiff | PromptFormat::OldTextNewText => indoc! {"
+ PromptFormat::OldTextNewText => indoc! {"
## Code Excerpts
Here is some excerpts of code that you should take into account to predict the next edit.
@@ -263,64 +151,51 @@ pub fn build_prompt(
Lines starting with `โฆ` indicate omitted line ranges. These may appear inside multi-line code constructs.
"},
- _ => indoc! {"
+ PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
+ indoc! {"
## Code Excerpts
The cursor marker <|user_cursor|> indicates the current user cursor position.
The file is in current state, edits from edit history have been applied.
- "},
+ "}
+ }
};
prompt.push_str(excerpts_preamble);
prompt.push('\n');
- if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() {
- let syntax_based_prompt = SyntaxBasedPrompt::populate(request)?;
- section_labels = syntax_based_prompt.write(&mut insertions, &mut prompt)?;
- } else {
- if request.prompt_format == PromptFormat::LabeledSections {
- anyhow::bail!("PromptFormat::LabeledSections cannot be used with ContextMode::Llm");
- }
-
- let include_line_numbers = matches!(
- request.prompt_format,
- PromptFormat::NumLinesUniDiff | PromptFormat::Minimal
- );
- for related_file in &request.included_files {
- if request.prompt_format == PromptFormat::Minimal {
- write_codeblock_with_filename(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- include_line_numbers,
- &mut prompt,
- );
- } else {
- write_codeblock(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- include_line_numbers,
- &mut prompt,
- );
- }
+ let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
+ for related_file in &request.related_files {
+ if request.prompt_format == PromptFormat::Minimal {
+ write_codeblock_with_filename(
+ &related_file.path,
+ &related_file.excerpts,
+ if related_file.path == request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ related_file.max_row,
+ include_line_numbers,
+ &mut prompt,
+ );
+ } else {
+ write_codeblock(
+ &related_file.path,
+ &related_file.excerpts,
+ if related_file.path == request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ related_file.max_row,
+ include_line_numbers,
+ &mut prompt,
+ );
}
}
match request.prompt_format {
- PromptFormat::NumLinesUniDiff => {
- prompt.push_str(UNIFIED_DIFF_REMINDER);
- }
PromptFormat::OldTextNewText => {
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
}
@@ -330,7 +205,7 @@ pub fn build_prompt(
_ => {}
}
- Ok((prompt, section_labels))
+ Ok(prompt)
}
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
@@ -444,476 +319,11 @@ pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>])
writeln!(output, "`````\n").unwrap();
}
-pub struct SyntaxBasedPrompt<'a> {
- request: &'a predict_edits_v3::PredictEditsRequest,
- /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
- /// `to_prompt_string`.
- snippets: Vec<PlannedSnippet<'a>>,
- budget_used: usize,
-}
-
-#[derive(Clone, Debug)]
-pub struct PlannedSnippet<'a> {
- path: Arc<Path>,
- range: Range<Line>,
- text: &'a str,
- // TODO: Indicate this in the output
- #[allow(dead_code)]
- text_is_truncated: bool,
-}
-
-#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
-pub enum DeclarationStyle {
- Signature,
- Declaration,
-}
-
-#[derive(Default, Clone, Debug, Serialize)]
-pub struct SectionLabels {
- pub excerpt_index: usize,
- pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
-}
-
-impl<'a> SyntaxBasedPrompt<'a> {
- /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
- ///
- /// Initializes a priority queue by populating it with each snippet, finding the
- /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
- /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
- /// the cost of upgrade.
- ///
- /// TODO: Implement an early halting condition. One option might be to have another priority
- /// queue where the score is the size, and update it accordingly. Another option might be to
- /// have some simpler heuristic like bailing after N failed insertions, or based on how much
- /// budget is left.
- ///
- /// TODO: Has the current known sources of imprecision:
- ///
- /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
- /// plan even though the containing struct is already included.
- ///
- /// * Does not consider cost of signatures when ranking snippets - this is tricky since
- /// signatures may be shared by multiple snippets.
- ///
- /// * Does not include file paths / other text when considering max_bytes.
- pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
- let mut this = Self {
- request,
- snippets: Vec::new(),
- budget_used: request.excerpt.len(),
- };
- let mut included_parents = FxHashSet::default();
- let additional_parents = this.additional_parent_signatures(
- &request.excerpt_path,
- request.excerpt_parent,
- &included_parents,
- )?;
- this.add_parents(&mut included_parents, additional_parents);
-
- let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
-
- if this.budget_used > max_bytes {
- return Err(anyhow!(
- "Excerpt + signatures size of {} already exceeds budget of {}",
- this.budget_used,
- max_bytes
- ));
- }
-
- #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
- struct QueueEntry {
- score_density: OrderedFloat<f32>,
- declaration_index: usize,
- style: DeclarationStyle,
- }
-
- // Initialize priority queue with the best score for each snippet.
- let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
- for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
- let (style, score_density) = DeclarationStyle::iter()
- .map(|style| {
- (
- style,
- OrderedFloat(declaration_score_density(&declaration, style)),
- )
- })
- .max_by_key(|(_, score_density)| *score_density)
- .unwrap();
- queue.push(QueueEntry {
- score_density,
- declaration_index,
- style,
- });
- }
-
- // Knapsack selection loop
- while let Some(queue_entry) = queue.pop() {
- let Some(declaration) = request
- .referenced_declarations
- .get(queue_entry.declaration_index)
- else {
- return Err(anyhow!(
- "Invalid declaration index {}",
- queue_entry.declaration_index
- ));
- };
-
- let mut additional_bytes = declaration_size(declaration, queue_entry.style);
- if this.budget_used + additional_bytes > max_bytes {
- continue;
- }
-
- let additional_parents = this.additional_parent_signatures(
- &declaration.path,
- declaration.parent_index,
- &mut included_parents,
- )?;
- additional_bytes += additional_parents
- .iter()
- .map(|(_, snippet)| snippet.text.len())
- .sum::<usize>();
- if this.budget_used + additional_bytes > max_bytes {
- continue;
- }
-
- this.budget_used += additional_bytes;
- this.add_parents(&mut included_parents, additional_parents);
- let planned_snippet = match queue_entry.style {
- DeclarationStyle::Signature => {
- let Some(text) = declaration.text.get(declaration.signature_range.clone())
- else {
- return Err(anyhow!(
- "Invalid declaration signature_range {:?} with text.len() = {}",
- declaration.signature_range,
- declaration.text.len()
- ));
- };
- let signature_start_line = declaration.range.start
- + Line(
- declaration.text[..declaration.signature_range.start]
- .lines()
- .count() as u32,
- );
- let signature_end_line = signature_start_line
- + Line(
- declaration.text
- [declaration.signature_range.start..declaration.signature_range.end]
- .lines()
- .count() as u32,
- );
- let range = signature_start_line..signature_end_line;
-
- PlannedSnippet {
- path: declaration.path.clone(),
- range,
- text,
- text_is_truncated: declaration.text_is_truncated,
- }
- }
- DeclarationStyle::Declaration => PlannedSnippet {
- path: declaration.path.clone(),
- range: declaration.range.clone(),
- text: &declaration.text,
- text_is_truncated: declaration.text_is_truncated,
- },
- };
- this.snippets.push(planned_snippet);
-
- // When a Signature is consumed, insert an entry for Definition style.
- if queue_entry.style == DeclarationStyle::Signature {
- let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
- let declaration_size =
- declaration_size(&declaration, DeclarationStyle::Declaration);
- let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
- let declaration_score =
- declaration_score(&declaration, DeclarationStyle::Declaration);
-
- let score_diff = declaration_score - signature_score;
- let size_diff = declaration_size.saturating_sub(signature_size);
- if score_diff > 0.0001 && size_diff > 0 {
- queue.push(QueueEntry {
- declaration_index: queue_entry.declaration_index,
- score_density: OrderedFloat(score_diff / (size_diff as f32)),
- style: DeclarationStyle::Declaration,
- });
- }
- }
- }
-
- anyhow::Ok(this)
- }
-
- fn add_parents(
- &mut self,
- included_parents: &mut FxHashSet<usize>,
- snippets: Vec<(usize, PlannedSnippet<'a>)>,
- ) {
- for (parent_index, snippet) in snippets {
- included_parents.insert(parent_index);
- self.budget_used += snippet.text.len();
- self.snippets.push(snippet);
- }
- }
-
- fn additional_parent_signatures(
- &self,
- path: &Arc<Path>,
- parent_index: Option<usize>,
- included_parents: &FxHashSet<usize>,
- ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
- let mut results = Vec::new();
- self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
- Ok(results)
- }
-
- fn additional_parent_signatures_impl(
- &self,
- path: &Arc<Path>,
- parent_index: Option<usize>,
- included_parents: &FxHashSet<usize>,
- results: &mut Vec<(usize, PlannedSnippet<'a>)>,
- ) -> Result<()> {
- let Some(parent_index) = parent_index else {
- return Ok(());
- };
- if included_parents.contains(&parent_index) {
- return Ok(());
- }
- let Some(parent_signature) = self.request.signatures.get(parent_index) else {
- return Err(anyhow!("Invalid parent index {}", parent_index));
- };
- results.push((
- parent_index,
- PlannedSnippet {
- path: path.clone(),
- range: parent_signature.range.clone(),
- text: &parent_signature.text,
- text_is_truncated: parent_signature.text_is_truncated,
- },
- ));
- self.additional_parent_signatures_impl(
- path,
- parent_signature.parent_index,
- included_parents,
- results,
- )
- }
-
- /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
- /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
- /// chunks.
- pub fn write(
- &'a self,
- excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
- prompt: &mut String,
- ) -> Result<SectionLabels> {
- let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
- FxHashMap::default();
- for snippet in &self.snippets {
- file_to_snippets
- .entry(&snippet.path)
- .or_default()
- .push(snippet);
- }
-
- // Reorder so that file with cursor comes last
- let mut file_snippets = Vec::new();
- let mut excerpt_file_snippets = Vec::new();
- for (file_path, snippets) in file_to_snippets {
- if file_path == self.request.excerpt_path.as_ref() {
- excerpt_file_snippets = snippets;
- } else {
- file_snippets.push((file_path, snippets, false));
- }
- }
- let excerpt_snippet = PlannedSnippet {
- path: self.request.excerpt_path.clone(),
- range: self.request.excerpt_line_range.clone(),
- text: &self.request.excerpt,
- text_is_truncated: false,
- };
- excerpt_file_snippets.push(&excerpt_snippet);
- file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
-
- let section_labels =
- self.push_file_snippets(prompt, excerpt_file_insertions, file_snippets)?;
-
- Ok(section_labels)
- }
-
- fn push_file_snippets(
- &self,
- output: &mut String,
- excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
- file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
- ) -> Result<SectionLabels> {
- let mut section_ranges = Vec::new();
- let mut excerpt_index = None;
-
- for (file_path, mut snippets, is_excerpt_file) in file_snippets {
- snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
-
- // TODO: What if the snippets get expanded too large to be editable?
- let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
- let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
- for snippet in snippets {
- if let Some((_, current_snippet_range)) = current_snippet.as_mut()
- && snippet.range.start <= current_snippet_range.end
- {
- current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
- continue;
- }
- if let Some(current_snippet) = current_snippet.take() {
- disjoint_snippets.push(current_snippet);
- }
- current_snippet = Some((snippet, snippet.range.clone()));
- }
- if let Some(current_snippet) = current_snippet.take() {
- disjoint_snippets.push(current_snippet);
- }
-
- writeln!(output, "`````path={}", file_path.display()).ok();
- let mut skipped_last_snippet = false;
- for (snippet, range) in disjoint_snippets {
- let section_index = section_ranges.len();
-
- match self.request.prompt_format {
- PromptFormat::MarkedExcerpt
- | PromptFormat::OnlySnippets
- | PromptFormat::OldTextNewText
- | PromptFormat::Minimal
- | PromptFormat::NumLinesUniDiff => {
- if range.start.0 > 0 && !skipped_last_snippet {
- output.push_str("โฆ\n");
- }
- }
- PromptFormat::LabeledSections => {
- if is_excerpt_file
- && range.start <= self.request.excerpt_line_range.start
- && range.end >= self.request.excerpt_line_range.end
- {
- writeln!(output, "<|current_section|>").ok();
- } else {
- writeln!(output, "<|section_{}|>", section_index).ok();
- }
- }
- PromptFormat::MinimalQwen => unreachable!(),
- PromptFormat::SeedCoder1120 => unreachable!(),
- }
-
- let push_full_snippet = |output: &mut String| {
- if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
- for (i, line) in snippet.text.lines().enumerate() {
- writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
- }
- } else {
- output.push_str(&snippet.text);
- }
- anyhow::Ok(())
- };
-
- if is_excerpt_file {
- if self.request.prompt_format == PromptFormat::OnlySnippets {
- if range.start >= self.request.excerpt_line_range.start
- && range.end <= self.request.excerpt_line_range.end
- {
- skipped_last_snippet = true;
- } else {
- skipped_last_snippet = false;
- output.push_str(snippet.text);
- }
- } else if !excerpt_file_insertions.is_empty() {
- let lines = snippet.text.lines().collect::<Vec<_>>();
- let push_line = |output: &mut String, line_ix: usize| {
- if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
- write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
- }
- anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
- };
- let mut last_line_ix = 0;
- let mut insertion_ix = 0;
- while insertion_ix < excerpt_file_insertions.len() {
- let (point, insertion) = &excerpt_file_insertions[insertion_ix];
- let found = point.line >= range.start && point.line <= range.end;
- if found {
- excerpt_index = Some(section_index);
- let insertion_line_ix = (point.line.0 - range.start.0) as usize;
- for line_ix in last_line_ix..insertion_line_ix {
- push_line(output, line_ix)?;
- }
- if let Some(next_line) = lines.get(insertion_line_ix) {
- if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
- write!(
- output,
- "{}|",
- insertion_line_ix as u32 + range.start.0 + 1
- )?
- }
- output.push_str(&next_line[..point.column as usize]);
- output.push_str(insertion);
- writeln!(output, "{}", &next_line[point.column as usize..])?;
- } else {
- writeln!(output, "{}", insertion)?;
- }
- last_line_ix = insertion_line_ix + 1;
- excerpt_file_insertions.remove(insertion_ix);
- continue;
- }
- insertion_ix += 1;
- }
- skipped_last_snippet = false;
- for line_ix in last_line_ix..lines.len() {
- push_line(output, line_ix)?;
- }
- } else {
- skipped_last_snippet = false;
- push_full_snippet(output)?;
- }
- } else {
- skipped_last_snippet = false;
- push_full_snippet(output)?;
- }
-
- section_ranges.push((snippet.path.clone(), range));
- }
-
- output.push_str("`````\n\n");
- }
-
- Ok(SectionLabels {
- // TODO: Clean this up
- excerpt_index: match self.request.prompt_format {
- PromptFormat::OnlySnippets => 0,
- _ => excerpt_index.context("bug: no snippet found for excerpt")?,
- },
- section_ranges,
- })
- }
-}
-
-fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
- declaration_score(declaration, style) / declaration_size(declaration, style) as f32
-}
-
-fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
- match style {
- DeclarationStyle::Signature => declaration.signature_score,
- DeclarationStyle::Declaration => declaration.declaration_score,
- }
-}
-
-fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
- match style {
- DeclarationStyle::Signature => declaration.signature_range.len(),
- DeclarationStyle::Declaration => declaration.text.len(),
- }
-}
-
struct PromptData {
events: Vec<Arc<Event>>,
cursor_point: Point,
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
- included_files: Vec<IncludedFile>,
+ included_files: Vec<RelatedFile>,
}
#[derive(Default)]
@@ -1051,7 +461,7 @@ impl SeedCoder1120Prompt {
context
}
- fn fmt_fim(&self, file: &IncludedFile, cursor_point: Point) -> String {
+ fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
let mut buf = String::new();
const FIM_SUFFIX: &str = "<[fim-suffix]>";
const FIM_PREFIX: &str = "<[fim-prefix]>";
@@ -1,244 +0,0 @@
-use anyhow::Result;
-use cloud_llm_client::predict_edits_v3::{self, Excerpt};
-use indoc::indoc;
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
-use std::fmt::Write;
-
-use crate::{push_events, write_codeblock};
-
-pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result<String> {
- let mut prompt = SEARCH_INSTRUCTIONS.to_string();
-
- if !request.events.is_empty() {
- writeln!(&mut prompt, "\n## User Edits\n\n")?;
- push_events(&mut prompt, &request.events);
- }
-
- writeln!(&mut prompt, "## Cursor context\n")?;
- write_codeblock(
- &request.excerpt_path,
- &[Excerpt {
- start_line: request.excerpt_line_range.start,
- text: request.excerpt.into(),
- }],
- &[],
- request.cursor_file_max_row,
- true,
- &mut prompt,
- );
-
- writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?;
-
- Ok(prompt)
-}
-
-/// Search for relevant code
-///
-/// For the best results, run multiple queries at once with a single invocation of this tool.
-#[derive(Clone, Deserialize, Serialize, JsonSchema)]
-pub struct SearchToolInput {
- /// An array of queries to run for gathering context relevant to the next prediction
- #[schemars(length(max = 3))]
- #[serde(deserialize_with = "deserialize_queries")]
- pub queries: Box<[SearchToolQuery]>,
-}
-
-fn deserialize_queries<'de, D>(deserializer: D) -> Result<Box<[SearchToolQuery]>, D::Error>
-where
- D: serde::Deserializer<'de>,
-{
- use serde::de::Error;
-
- #[derive(Deserialize)]
- #[serde(untagged)]
- enum QueryCollection {
- Array(Box<[SearchToolQuery]>),
- DoubleArray(Box<[Box<[SearchToolQuery]>]>),
- Single(SearchToolQuery),
- }
-
- #[derive(Deserialize)]
- #[serde(untagged)]
- enum MaybeDoubleEncoded {
- SingleEncoded(QueryCollection),
- DoubleEncoded(String),
- }
-
- let result = MaybeDoubleEncoded::deserialize(deserializer)?;
-
- let normalized = match result {
- MaybeDoubleEncoded::SingleEncoded(value) => value,
- MaybeDoubleEncoded::DoubleEncoded(value) => {
- serde_json::from_str(&value).map_err(D::Error::custom)?
- }
- };
-
- Ok(match normalized {
- QueryCollection::Array(items) => items,
- QueryCollection::Single(search_tool_query) => Box::new([search_tool_query]),
- QueryCollection::DoubleArray(double_array) => double_array.into_iter().flatten().collect(),
- })
-}
-
-/// Search for relevant code by path, syntax hierarchy, and content.
-#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
-pub struct SearchToolQuery {
- /// 1. A glob pattern to match file paths in the codebase to search in.
- pub glob: String,
- /// 2. Regular expressions to match syntax nodes **by their first line** and hierarchy.
- ///
- /// Subsequent regexes match nodes within the full content of the nodes matched by the previous regexes.
- ///
- /// Example: Searching for a `User` class
- /// ["class\s+User"]
- ///
- /// Example: Searching for a `get_full_name` method under a `User` class
- /// ["class\s+User", "def\sget_full_name"]
- ///
- /// Skip this field to match on content alone.
- #[schemars(length(max = 3))]
- #[serde(default)]
- pub syntax_node: Vec<String>,
- /// 3. An optional regular expression to match the final content that should appear in the results.
- ///
- /// - Content will be matched within all lines of the matched syntax nodes.
- /// - If syntax node regexes are provided, this field can be skipped to include as much of the node itself as possible.
- /// - If no syntax node regexes are provided, the content will be matched within the entire file.
- pub content: Option<String>,
-}
-
-pub const TOOL_NAME: &str = "search";
-
-const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
- You are part of an edit prediction system in a code editor.
- Your role is to search for code that will serve as context for predicting the next edit.
-
- - Analyze the user's recent edits and current cursor context
- - Use the `search` tool to find code that is relevant for predicting the next edit
- - Focus on finding:
- - Code patterns that might need similar changes based on the recent edits
- - Functions, variables, types, and constants referenced in the current cursor context
- - Related implementations, usages, or dependencies that may require consistent updates
- - How items defined in the cursor excerpt are used or altered
- - You will not be able to filter results or perform subsequent queries, so keep searches as targeted as possible
- - Use `syntax_node` parameter whenever you're looking for a particular type, class, or function
- - Avoid using wildcard globs if you already know the file path of the content you're looking for
-"#};
-
-const TOOL_USE_REMINDER: &str = indoc! {"
- --
- Analyze the user's intent in one to two sentences, then call the `search` tool.
-"};
-
-#[cfg(test)]
-mod tests {
- use serde_json::json;
-
- use super::*;
-
- #[test]
- fn test_deserialize_queries() {
- let single_query_json = indoc! {r#"{
- "queries": {
- "glob": "**/*.rs",
- "syntax_node": ["fn test"],
- "content": "assert"
- }
- }"#};
-
- let flat_input: SearchToolInput = serde_json::from_str(single_query_json).unwrap();
- assert_eq!(flat_input.queries.len(), 1);
- assert_eq!(flat_input.queries[0].glob, "**/*.rs");
- assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
- assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
-
- let flat_json = indoc! {r#"{
- "queries": [
- {
- "glob": "**/*.rs",
- "syntax_node": ["fn test"],
- "content": "assert"
- },
- {
- "glob": "**/*.ts",
- "syntax_node": [],
- "content": null
- }
- ]
- }"#};
-
- let flat_input: SearchToolInput = serde_json::from_str(flat_json).unwrap();
- assert_eq!(flat_input.queries.len(), 2);
- assert_eq!(flat_input.queries[0].glob, "**/*.rs");
- assert_eq!(flat_input.queries[0].syntax_node, vec!["fn test"]);
- assert_eq!(flat_input.queries[0].content, Some("assert".to_string()));
- assert_eq!(flat_input.queries[1].glob, "**/*.ts");
- assert_eq!(flat_input.queries[1].syntax_node.len(), 0);
- assert_eq!(flat_input.queries[1].content, None);
-
- let nested_json = indoc! {r#"{
- "queries": [
- [
- {
- "glob": "**/*.rs",
- "syntax_node": ["fn test"],
- "content": "assert"
- }
- ],
- [
- {
- "glob": "**/*.ts",
- "syntax_node": [],
- "content": null
- }
- ]
- ]
- }"#};
-
- let nested_input: SearchToolInput = serde_json::from_str(nested_json).unwrap();
-
- assert_eq!(nested_input.queries.len(), 2);
-
- assert_eq!(nested_input.queries[0].glob, "**/*.rs");
- assert_eq!(nested_input.queries[0].syntax_node, vec!["fn test"]);
- assert_eq!(nested_input.queries[0].content, Some("assert".to_string()));
- assert_eq!(nested_input.queries[1].glob, "**/*.ts");
- assert_eq!(nested_input.queries[1].syntax_node.len(), 0);
- assert_eq!(nested_input.queries[1].content, None);
-
- let double_encoded_queries = serde_json::to_string(&json!({
- "queries": serde_json::to_string(&json!([
- {
- "glob": "**/*.rs",
- "syntax_node": ["fn test"],
- "content": "assert"
- },
- {
- "glob": "**/*.ts",
- "syntax_node": [],
- "content": null
- }
- ])).unwrap()
- }))
- .unwrap();
-
- let double_encoded_input: SearchToolInput =
- serde_json::from_str(&double_encoded_queries).unwrap();
-
- assert_eq!(double_encoded_input.queries.len(), 2);
-
- assert_eq!(double_encoded_input.queries[0].glob, "**/*.rs");
- assert_eq!(double_encoded_input.queries[0].syntax_node, vec!["fn test"]);
- assert_eq!(
- double_encoded_input.queries[0].content,
- Some("assert".to_string())
- );
- assert_eq!(double_encoded_input.queries[1].glob, "**/*.ts");
- assert_eq!(double_encoded_input.queries[1].syntax_node.len(), 0);
- assert_eq!(double_encoded_input.queries[1].content, None);
-
- // ### ERROR Switching from var declarations to lexical declarations [RUN 073]
- // invalid search json {"queries": ["express/lib/response.js", "var\\s+[a-zA-Z_][a-zA-Z0-9_]*\\s*=.*;", "function.*\\(.*\\).*\\{.*\\}"]}
- }
-}
@@ -10,7 +10,7 @@ path = "src/codestral.rs"
[dependencies]
anyhow.workspace = true
-edit_prediction.workspace = true
+edit_prediction_types.workspace = true
edit_prediction_context.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -1,6 +1,6 @@
use anyhow::{Context as _, Result};
-use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
+use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
use futures::AsyncReadExt;
use gpui::{App, Context, Entity, Task};
use http_client::HttpClient;
@@ -43,17 +43,17 @@ impl CurrentCompletion {
/// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
/// Returns None if the user's edits conflict with the predicted edits.
fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
- edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
+ edit_prediction_types::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
}
}
-pub struct CodestralCompletionProvider {
+pub struct CodestralEditPredictionDelegate {
http_client: Arc<dyn HttpClient>,
pending_request: Option<Task<Result<()>>>,
current_completion: Option<CurrentCompletion>,
}
-impl CodestralCompletionProvider {
+impl CodestralEditPredictionDelegate {
pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
Self {
http_client,
@@ -165,7 +165,7 @@ impl CodestralCompletionProvider {
}
}
-impl EditPredictionProvider for CodestralCompletionProvider {
+impl EditPredictionDelegate for CodestralEditPredictionDelegate {
fn name() -> &'static str {
"codestral"
}
@@ -174,7 +174,7 @@ impl EditPredictionProvider for CodestralCompletionProvider {
"Codestral"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
true
}
@@ -239,7 +239,6 @@ impl EditPredictionProvider for CodestralCompletionProvider {
cursor_point,
&snapshot,
&EXCERPT_OPTIONS,
- None,
)
.context("Line containing cursor doesn't fit in excerpt max bytes")?;
@@ -65,7 +65,7 @@ tokio = { workspace = true, features = ["full"] }
toml.workspace = true
tower = "0.4"
tower-http = { workspace = true, features = ["trace"] }
-tracing = "0.1.40"
+tracing.workspace = true
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "registry", "tracing-log"] } # workaround for https://github.com/tokio-rs/tracing/issues/2927
util.workspace = true
uuid.workspace = true
@@ -121,6 +121,8 @@ CREATE TABLE "project_repositories" (
"merge_message" VARCHAR,
"branch_summary" VARCHAR,
"head_commit_details" VARCHAR,
+ "remote_upstream_url" VARCHAR,
+ "remote_origin_url" VARCHAR,
PRIMARY KEY (project_id, id)
);
@@ -0,0 +1,2 @@
+ALTER TABLE "project_repositories" ADD COLUMN "remote_upstream_url" VARCHAR;
+ALTER TABLE "project_repositories" ADD COLUMN "remote_origin_url" VARCHAR;
@@ -362,6 +362,8 @@ impl Database {
entry_ids: ActiveValue::set("[]".into()),
head_commit_details: ActiveValue::set(None),
merge_message: ActiveValue::set(None),
+ remote_upstream_url: ActiveValue::set(None),
+ remote_origin_url: ActiveValue::set(None),
}
}),
)
@@ -511,6 +513,8 @@ impl Database {
serde_json::to_string(&update.current_merge_conflicts).unwrap(),
)),
merge_message: ActiveValue::set(update.merge_message.clone()),
+ remote_upstream_url: ActiveValue::set(update.remote_upstream_url.clone()),
+ remote_origin_url: ActiveValue::set(update.remote_origin_url.clone()),
})
.on_conflict(
OnConflict::columns([
@@ -1005,6 +1009,8 @@ impl Database {
is_last_update: true,
merge_message: db_repository_entry.merge_message,
stash_entries: Vec::new(),
+ remote_upstream_url: db_repository_entry.remote_upstream_url.clone(),
+ remote_origin_url: db_repository_entry.remote_origin_url.clone(),
});
}
}
@@ -796,6 +796,8 @@ impl Database {
is_last_update: true,
merge_message: db_repository.merge_message,
stash_entries: Vec::new(),
+ remote_upstream_url: db_repository.remote_upstream_url.clone(),
+ remote_origin_url: db_repository.remote_origin_url.clone(),
});
}
}
@@ -22,6 +22,8 @@ pub struct Model {
pub branch_summary: Option<String>,
// A JSON object representing the current Head commit values
pub head_commit_details: Option<String>,
+ pub remote_upstream_url: Option<String>,
+ pub remote_origin_url: Option<String>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
use call::Room;
use client::ChannelId;
use gpui::{Entity, TestAppContext};
@@ -18,7 +16,6 @@ mod randomized_test_helpers;
mod remote_editing_collaboration_tests;
mod test_server;
-use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
pub use randomized_test_helpers::{
RandomizedTest, TestError, UserTestPlan, run_randomized_test, save_randomized_test_plan,
};
@@ -51,17 +48,3 @@ fn room_participants(room: &Entity<Room>, cx: &mut TestAppContext) -> RoomPartic
fn channel_id(room: &Entity<Room>, cx: &mut TestAppContext) -> Option<ChannelId> {
cx.read(|cx| room.read(cx).channel_id())
}
-
-fn rust_lang() -> Arc<Language> {
- Arc::new(Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- ))
-}
@@ -1,7 +1,4 @@
-use crate::{
- rpc::RECONNECT_TIMEOUT,
- tests::{TestServer, rust_lang},
-};
+use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
use call::ActiveCall;
use editor::{
DocumentColorsRenderMode, Editor, FETCH_COLORS_DEBOUNCE_TIMEOUT, MultiBufferOffset, RowInfo,
@@ -23,7 +20,7 @@ use gpui::{
App, Rgba, SharedString, TestAppContext, UpdateGlobal, VisualContext, VisualTestContext,
};
use indoc::indoc;
-use language::FakeLspAdapter;
+use language::{FakeLspAdapter, rust_lang};
use lsp::LSP_REQUEST_TIMEOUT;
use pretty_assertions::assert_eq;
use project::{
@@ -3518,7 +3515,6 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA
.into_iter()
.map(|(sha, message)| (sha.parse().unwrap(), message.into()))
.collect(),
- remote_url: Some("git@github.com:zed-industries/zed.git".to_string()),
};
client_a.fs().set_blame_for_repo(
Path::new(path!("/my-repo/.git")),
@@ -3603,10 +3599,6 @@ async fn test_git_blame_is_forwarded(cx_a: &mut TestAppContext, cx_b: &mut TestA
for (idx, (buffer, entry)) in entries.iter().flatten().enumerate() {
let details = blame.details_for_entry(*buffer, entry).unwrap();
assert_eq!(details.message, format!("message for idx-{}", idx));
- assert_eq!(
- details.permalink.unwrap().to_string(),
- format!("https://github.com/zed-industries/zed/commit/{}", entry.sha)
- );
}
});
});
@@ -2,7 +2,7 @@ use crate::{
rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
tests::{
RoomParticipants, TestClient, TestServer, channel_id, following_tests::join_channel,
- room_participants, rust_lang,
+ room_participants,
},
};
use anyhow::{Result, anyhow};
@@ -26,7 +26,7 @@ use language::{
Diagnostic, DiagnosticEntry, DiagnosticSourceKind, FakeLspAdapter, Language, LanguageConfig,
LanguageMatcher, LineEnding, OffsetRangeExt, Point, Rope,
language_settings::{Formatter, FormatterList},
- tree_sitter_rust, tree_sitter_typescript,
+ rust_lang, tree_sitter_rust, tree_sitter_typescript,
};
use lsp::{LanguageServerId, OneOf};
use parking_lot::Mutex;
@@ -33,7 +33,7 @@ fs.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
-edit_prediction.workspace = true
+edit_prediction_types.workspace = true
language.workspace = true
log.workspace = true
lsp.workspace = true
@@ -1,5 +1,5 @@
pub mod copilot_chat;
-mod copilot_completion_provider;
+mod copilot_edit_prediction_delegate;
pub mod copilot_responses;
pub mod request;
mod sign_in;
@@ -46,7 +46,7 @@ use util::rel_path::RelPath;
use util::{ResultExt, fs::remove_matching};
use workspace::Workspace;
-pub use crate::copilot_completion_provider::CopilotCompletionProvider;
+pub use crate::copilot_edit_prediction_delegate::CopilotEditPredictionDelegate;
pub use crate::sign_in::{CopilotCodeVerification, initiate_sign_in, reinstall_and_sign_in};
actions!(
@@ -1,6 +1,6 @@
use crate::{Completion, Copilot};
use anyhow::Result;
-use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
+use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
use gpui::{App, Context, Entity, EntityId, Task};
use language::{Buffer, OffsetRangeExt, ToOffset, language_settings::AllLanguageSettings};
use settings::Settings;
@@ -8,7 +8,7 @@ use std::{path::Path, time::Duration};
pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
-pub struct CopilotCompletionProvider {
+pub struct CopilotEditPredictionDelegate {
cycled: bool,
buffer_id: Option<EntityId>,
completions: Vec<Completion>,
@@ -19,7 +19,7 @@ pub struct CopilotCompletionProvider {
copilot: Entity<Copilot>,
}
-impl CopilotCompletionProvider {
+impl CopilotEditPredictionDelegate {
pub fn new(copilot: Entity<Copilot>) -> Self {
Self {
cycled: false,
@@ -47,7 +47,7 @@ impl CopilotCompletionProvider {
}
}
-impl EditPredictionProvider for CopilotCompletionProvider {
+impl EditPredictionDelegate for CopilotEditPredictionDelegate {
fn name() -> &'static str {
"copilot"
}
@@ -56,7 +56,7 @@ impl EditPredictionProvider for CopilotCompletionProvider {
"Copilot"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
true
}
@@ -314,7 +314,7 @@ mod tests {
cx,
)
.await;
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
});
@@ -546,7 +546,7 @@ mod tests {
cx,
)
.await;
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
});
@@ -670,7 +670,7 @@ mod tests {
cx,
)
.await;
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
});
@@ -753,7 +753,7 @@ mod tests {
window.focus(&editor.focus_handle(cx));
})
.unwrap();
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
editor
.update(cx, |editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
@@ -848,7 +848,7 @@ mod tests {
cx,
)
.await;
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
});
@@ -1000,7 +1000,7 @@ mod tests {
window.focus(&editor.focus_handle(cx))
})
.unwrap();
- let copilot_provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let copilot_provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
editor
.update(cx, |editor, window, cx| {
editor.set_edit_prediction_provider(Some(copilot_provider), window, cx)
@@ -870,7 +870,7 @@ impl DebugAdapter for PythonDebugAdapter {
.active_toolchain(
delegate.worktree_id(),
base_path.into_arc(),
- language::LanguageName::new(Self::LANGUAGE_NAME),
+ language::LanguageName::new_static(Self::LANGUAGE_NAME),
cx,
)
.await
@@ -37,6 +37,7 @@ dap_adapters = { workspace = true, optional = true }
db.workspace = true
debugger_tools.workspace = true
editor.workspace = true
+feature_flags.workspace = true
file_icons.workspace = true
futures.workspace = true
fuzzy.workspace = true
@@ -82,6 +83,7 @@ dap_adapters = { workspace = true, features = ["test-support"] }
debugger_tools = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
+language = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
tree-sitter-go.workspace = true
unindent.workspace = true
@@ -15,10 +15,11 @@ use dap::adapters::DebugAdapterName;
use dap::{DapRegistry, StartDebuggingRequestArguments};
use dap::{client::SessionId, debugger_settings::DebuggerSettings};
use editor::{Editor, MultiBufferOffset, ToPoint};
+use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use gpui::{
- Action, App, AsyncWindowContext, ClipboardItem, Context, DismissEvent, Entity, EntityId,
- EventEmitter, FocusHandle, Focusable, MouseButton, MouseDownEvent, Point, Subscription, Task,
- WeakEntity, anchored, deferred,
+ Action, App, AsyncWindowContext, ClipboardItem, Context, Corner, DismissEvent, Entity,
+ EntityId, EventEmitter, FocusHandle, Focusable, MouseButton, MouseDownEvent, Point,
+ Subscription, Task, WeakEntity, anchored, deferred,
};
use itertools::Itertools as _;
@@ -31,7 +32,9 @@ use settings::Settings;
use std::sync::{Arc, LazyLock};
use task::{DebugScenario, TaskContext};
use tree_sitter::{Query, StreamingIterator as _};
-use ui::{ContextMenu, Divider, PopoverMenuHandle, Tab, Tooltip, prelude::*};
+use ui::{
+ ContextMenu, Divider, PopoverMenu, PopoverMenuHandle, SplitButton, Tab, Tooltip, prelude::*,
+};
use util::rel_path::RelPath;
use util::{ResultExt, debug_panic, maybe};
use workspace::SplitDirection;
@@ -42,6 +45,12 @@ use workspace::{
};
use zed_actions::ToggleFocus;
+pub struct DebuggerHistoryFeatureFlag;
+
+impl FeatureFlag for DebuggerHistoryFeatureFlag {
+ const NAME: &'static str = "debugger-history";
+}
+
const DEBUG_PANEL_KEY: &str = "DebugPanel";
pub struct DebugPanel {
@@ -284,7 +293,7 @@ impl DebugPanel {
}
});
- session.update(cx, |session, _| match &mut session.mode {
+ session.update(cx, |session, _| match &mut session.state {
SessionState::Booting(state_task) => {
*state_task = Some(boot_task);
}
@@ -662,6 +671,12 @@ impl DebugPanel {
)
};
+ let thread_status = active_session
+ .as_ref()
+ .map(|session| session.read(cx).running_state())
+ .and_then(|state| state.read(cx).thread_status(cx))
+ .unwrap_or(project::debugger::session::ThreadStatus::Exited);
+
Some(
div.w_full()
.py_1()
@@ -679,10 +694,6 @@ impl DebugPanel {
.as_ref()
.map(|session| session.read(cx).running_state()),
|this, running_state| {
- let thread_status =
- running_state.read(cx).thread_status(cx).unwrap_or(
- project::debugger::session::ThreadStatus::Exited,
- );
let capabilities = running_state.read(cx).capabilities(cx);
let supports_detach =
running_state.read(cx).session().read(cx).is_attached();
@@ -871,36 +882,53 @@ impl DebugPanel {
}
}),
)
+ .when(supports_detach, |div| {
+ div.child(
+ IconButton::new(
+ "debug-disconnect",
+ IconName::DebugDetach,
+ )
+ .disabled(
+ thread_status != ThreadStatus::Stopped
+ && thread_status != ThreadStatus::Running,
+ )
+ .icon_size(IconSize::Small)
+ .on_click(window.listener_for(
+ running_state,
+ |this, _, _, cx| {
+ this.detach_client(cx);
+ },
+ ))
+ .tooltip({
+ let focus_handle = focus_handle.clone();
+ move |_window, cx| {
+ Tooltip::for_action_in(
+ "Detach",
+ &Detach,
+ &focus_handle,
+ cx,
+ )
+ }
+ }),
+ )
+ })
.when(
- supports_detach,
- |div| {
- div.child(
- IconButton::new(
- "debug-disconnect",
- IconName::DebugDetach,
- )
- .disabled(
- thread_status != ThreadStatus::Stopped
- && thread_status != ThreadStatus::Running,
+ cx.has_flag::<DebuggerHistoryFeatureFlag>(),
+ |this| {
+ this.child(Divider::vertical()).child(
+ SplitButton::new(
+ self.render_history_button(
+ &running_state,
+ thread_status,
+ window,
+ ),
+ self.render_history_toggle_button(
+ thread_status,
+ &running_state,
+ )
+ .into_any_element(),
)
- .icon_size(IconSize::Small)
- .on_click(window.listener_for(
- running_state,
- |this, _, _, cx| {
- this.detach_client(cx);
- },
- ))
- .tooltip({
- let focus_handle = focus_handle.clone();
- move |_window, cx| {
- Tooltip::for_action_in(
- "Detach",
- &Detach,
- &focus_handle,
- cx,
- )
- }
- }),
+ .style(ui::SplitButtonStyle::Outlined),
)
},
)
@@ -1317,6 +1345,97 @@ impl DebugPanel {
});
}
}
+
+ fn render_history_button(
+ &self,
+ running_state: &Entity<RunningState>,
+ thread_status: ThreadStatus,
+ window: &mut Window,
+ ) -> IconButton {
+ IconButton::new("debug-back-in-history", IconName::HistoryRerun)
+ .icon_size(IconSize::Small)
+ .on_click(window.listener_for(running_state, |this, _, _window, cx| {
+ this.session().update(cx, |session, cx| {
+ let ix = session
+ .active_snapshot_index()
+ .unwrap_or_else(|| session.historic_snapshots().len());
+
+ session.select_historic_snapshot(Some(ix.saturating_sub(1)), cx);
+ })
+ }))
+ .disabled(
+ thread_status == ThreadStatus::Running || thread_status == ThreadStatus::Stepping,
+ )
+ }
+
+ fn render_history_toggle_button(
+ &self,
+ thread_status: ThreadStatus,
+ running_state: &Entity<RunningState>,
+ ) -> impl IntoElement {
+ PopoverMenu::new("debug-back-in-history-menu")
+ .trigger(
+ ui::ButtonLike::new_rounded_right("debug-back-in-history-menu-trigger")
+ .layer(ui::ElevationIndex::ModalSurface)
+ .size(ui::ButtonSize::None)
+ .child(
+ div()
+ .px_1()
+ .child(Icon::new(IconName::ChevronDown).size(IconSize::XSmall)),
+ )
+ .disabled(
+ thread_status == ThreadStatus::Running
+ || thread_status == ThreadStatus::Stepping,
+ ),
+ )
+ .menu({
+ let running_state = running_state.clone();
+ move |window, cx| {
+ let handler =
+ |ix: Option<usize>, running_state: Entity<RunningState>, cx: &mut App| {
+ running_state.update(cx, |state, cx| {
+ state.session().update(cx, |session, cx| {
+ session.select_historic_snapshot(ix, cx);
+ })
+ })
+ };
+
+ let running_state = running_state.clone();
+ Some(ContextMenu::build(
+ window,
+ cx,
+ move |mut context_menu, _window, cx| {
+ let history = running_state
+ .read(cx)
+ .session()
+ .read(cx)
+ .historic_snapshots();
+
+ context_menu = context_menu.entry("Current State", None, {
+ let running_state = running_state.clone();
+ move |_window, cx| {
+ handler(None, running_state.clone(), cx);
+ }
+ });
+ context_menu = context_menu.separator();
+
+ for (ix, _) in history.iter().enumerate().rev() {
+ context_menu =
+ context_menu.entry(format!("history-{}", ix + 1), None, {
+ let running_state = running_state.clone();
+ move |_window, cx| {
+ handler(Some(ix), running_state.clone(), cx);
+ }
+ });
+ }
+
+ context_menu
+ },
+ ))
+ }
+ })
+ .anchor(Corner::TopRight)
+ }
}
async fn register_session_inner(
@@ -387,7 +387,7 @@ pub fn init(cx: &mut App) {
window.on_action(
TypeId::of::<editor::actions::EvaluateSelectedText>(),
move |_, _, window, cx| {
- maybe!({
+ let status = maybe!({
let text = editor
.update(cx, |editor, cx| {
let range = editor
@@ -411,7 +411,13 @@ pub fn init(cx: &mut App) {
state.session().update(cx, |session, cx| {
session
- .evaluate(text, None, stack_id, None, cx)
+ .evaluate(
+ text,
+ Some(dap::EvaluateArgumentsContext::Repl),
+ stack_id,
+ None,
+ cx,
+ )
.detach();
});
});
@@ -419,6 +425,9 @@ pub fn init(cx: &mut App) {
Some(())
});
+ if status.is_some() {
+ cx.stop_propagation();
+ }
},
);
})
@@ -881,7 +881,6 @@ impl ConfigureMode {
.label("Stop on Entry")
.label_position(SwitchLabelPosition::Start)
.label_size(LabelSize::Default)
- .color(ui::SwitchColor::Accent)
.on_click({
let this = cx.weak_entity();
move |state, _, cx| {
@@ -1023,7 +1022,7 @@ impl DebugDelegate {
Some(TaskSourceKind::Lsp { language_name, .. }) => {
Some(format!("LSP: {language_name}"))
}
- Some(TaskSourceKind::Language { name }) => Some(format!("Lang: {name}")),
+ Some(TaskSourceKind::Language { name }) => Some(format!("Language: {name}")),
_ => context.clone().and_then(|ctx| {
ctx.task_context
.task_variables
@@ -1743,7 +1743,7 @@ impl RunningState {
let is_building = self.session.update(cx, |session, cx| {
session.shutdown(cx).detach();
- matches!(session.mode, session::SessionState::Booting(_))
+ matches!(session.state, session::SessionState::Booting(_))
});
if is_building {
@@ -17,7 +17,9 @@ impl LoadedSourceList {
let list = ListState::new(0, gpui::ListAlignment::Top, px(1000.));
let _subscription = cx.subscribe(&session, |this, _, event, cx| match event {
- SessionEvent::Stopped(_) | SessionEvent::LoadedSources => {
+ SessionEvent::Stopped(_)
+ | SessionEvent::HistoricSnapshotSelected
+ | SessionEvent::LoadedSources => {
this.invalidate = true;
cx.notify();
}
@@ -32,7 +32,9 @@ impl ModuleList {
let focus_handle = cx.focus_handle();
let _subscription = cx.subscribe(&session, |this, _, event, cx| match event {
- SessionEvent::Stopped(_) | SessionEvent::Modules => {
+ SessionEvent::Stopped(_)
+ | SessionEvent::HistoricSnapshotSelected
+ | SessionEvent::Modules => {
if this._rebuild_task.is_some() {
this.schedule_rebuild(cx);
}
@@ -4,6 +4,7 @@ use std::time::Duration;
use anyhow::{Context as _, Result, anyhow};
use dap::StackFrameId;
+use dap::adapters::DebugAdapterName;
use db::kvp::KEY_VALUE_STORE;
use gpui::{
Action, AnyElement, Entity, EventEmitter, FocusHandle, Focusable, FontWeight, ListState,
@@ -20,7 +21,7 @@ use project::debugger::breakpoint_store::ActiveStackFrame;
use project::debugger::session::{Session, SessionEvent, StackFrame, ThreadStatus};
use project::{ProjectItem, ProjectPath};
use ui::{Tooltip, WithScrollbar, prelude::*};
-use workspace::{ItemHandle, Workspace};
+use workspace::{ItemHandle, Workspace, WorkspaceId};
use super::RunningState;
@@ -58,6 +59,14 @@ impl From<StackFrameFilter> for String {
}
}
+pub(crate) fn stack_frame_filter_key(
+ adapter_name: &DebugAdapterName,
+ workspace_id: WorkspaceId,
+) -> String {
+ let database_id: i64 = workspace_id.into();
+ format!("stack-frame-list-filter-{}-{}", adapter_name.0, database_id)
+}
+
pub struct StackFrameList {
focus_handle: FocusHandle,
_subscription: Subscription,
@@ -97,7 +106,9 @@ impl StackFrameList {
SessionEvent::Threads => {
this.schedule_refresh(false, window, cx);
}
- SessionEvent::Stopped(..) | SessionEvent::StackTrace => {
+ SessionEvent::Stopped(..)
+ | SessionEvent::StackTrace
+ | SessionEvent::HistoricSnapshotSelected => {
this.schedule_refresh(true, window, cx);
}
_ => {}
@@ -105,14 +116,18 @@ impl StackFrameList {
let list_state = ListState::new(0, gpui::ListAlignment::Top, px(1000.));
- let list_filter = KEY_VALUE_STORE
- .read_kvp(&format!(
- "stack-frame-list-filter-{}",
- session.read(cx).adapter().0
- ))
+ let list_filter = workspace
+ .read_with(cx, |workspace, _| workspace.database_id())
.ok()
.flatten()
- .map(StackFrameFilter::from_str_or_default)
+ .and_then(|database_id| {
+ let key = stack_frame_filter_key(&session.read(cx).adapter(), database_id);
+ KEY_VALUE_STORE
+ .read_kvp(&key)
+ .ok()
+ .flatten()
+ .map(StackFrameFilter::from_str_or_default)
+ })
.unwrap_or(StackFrameFilter::All);
let mut this = Self {
@@ -225,7 +240,6 @@ impl StackFrameList {
}
this.update_in(cx, |this, window, cx| {
this.build_entries(select_first, window, cx);
- cx.notify();
})
.ok();
})
@@ -806,15 +820,8 @@ impl StackFrameList {
.ok()
.flatten()
{
- let database_id: i64 = database_id.into();
- let save_task = KEY_VALUE_STORE.write_kvp(
- format!(
- "stack-frame-list-filter-{}-{}",
- self.session.read(cx).adapter().0,
- database_id,
- ),
- self.list_filter.into(),
- );
+ let key = stack_frame_filter_key(&self.session.read(cx).adapter(), database_id);
+ let save_task = KEY_VALUE_STORE.write_kvp(key, self.list_filter.into());
cx.background_spawn(save_task).detach();
}
@@ -217,6 +217,12 @@ impl VariableList {
let _subscriptions = vec![
cx.subscribe(&stack_frame_list, Self::handle_stack_frame_list_events),
cx.subscribe(&session, |this, _, event, cx| match event {
+ SessionEvent::HistoricSnapshotSelected => {
+ this.selection.take();
+ this.edited_path.take();
+ this.selected_stack_frame_id.take();
+ this.build_entries(cx);
+ }
SessionEvent::Stopped(_) => {
this.selection.take();
this.edited_path.take();
@@ -225,7 +231,6 @@ impl VariableList {
SessionEvent::Variables | SessionEvent::Watchers => {
this.build_entries(cx);
}
-
_ => {}
}),
cx.on_focus_out(&focus_handle, window, |this, _, _, cx| {
@@ -4,7 +4,7 @@ use dap::{Scope, StackFrame, Variable, requests::Variables};
use editor::{Editor, EditorMode, MultiBuffer};
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use language::{
- Language, LanguageConfig, LanguageMatcher, tree_sitter_python, tree_sitter_rust,
+ Language, LanguageConfig, LanguageMatcher, rust_lang, tree_sitter_python,
tree_sitter_typescript,
};
use project::{FakeFs, Project};
@@ -224,7 +224,7 @@ fn main() {
.unwrap();
buffer.update(cx, |buffer, cx| {
- buffer.set_language(Some(Arc::new(rust_lang())), cx);
+ buffer.set_language(Some(rust_lang()), cx);
});
let (editor, cx) = cx.add_window_view(|window, cx| {
@@ -1521,23 +1521,6 @@ fn main() {
});
}
-fn rust_lang() -> Language {
- let debug_variables_query = include_str!("../../../languages/src/rust/debugger.scm");
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_debug_variables_query(debug_variables_query)
- .unwrap()
-}
-
#[gpui::test]
async fn test_python_inline_values(executor: BackgroundExecutor, cx: &mut TestAppContext) {
init_test(cx);
@@ -1859,21 +1842,23 @@ fn python_lang() -> Language {
.unwrap()
}
-fn go_lang() -> Language {
+fn go_lang() -> Arc<Language> {
let debug_variables_query = include_str!("../../../languages/src/go/debugger.scm");
- Language::new(
- LanguageConfig {
- name: "Go".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["go".to_string()],
+ Arc::new(
+ Language::new(
+ LanguageConfig {
+ name: "Go".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["go".to_string()],
+ ..Default::default()
+ },
..Default::default()
},
- ..Default::default()
- },
- Some(tree_sitter_go::LANGUAGE.into()),
+ Some(tree_sitter_go::LANGUAGE.into()),
+ )
+ .with_debug_variables_query(debug_variables_query)
+ .unwrap(),
)
- .with_debug_variables_query(debug_variables_query)
- .unwrap()
}
/// Test utility function for inline values testing
@@ -1891,7 +1876,7 @@ async fn test_inline_values_util(
before: &str,
after: &str,
active_debug_line: Option<usize>,
- language: Language,
+ language: Arc<Language>,
executor: BackgroundExecutor,
cx: &mut TestAppContext,
) {
@@ -2091,7 +2076,7 @@ async fn test_inline_values_util(
.unwrap();
buffer.update(cx, |buffer, cx| {
- buffer.set_language(Some(Arc::new(language)), cx);
+ buffer.set_language(Some(language), cx);
});
let (editor, cx) = cx.add_window_view(|window, cx| {
@@ -2276,55 +2261,61 @@ fn main() {
.await;
}
-fn javascript_lang() -> Language {
+fn javascript_lang() -> Arc<Language> {
let debug_variables_query = include_str!("../../../languages/src/javascript/debugger.scm");
- Language::new(
- LanguageConfig {
- name: "JavaScript".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["js".to_string()],
+ Arc::new(
+ Language::new(
+ LanguageConfig {
+ name: "JavaScript".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["js".to_string()],
+ ..Default::default()
+ },
..Default::default()
},
- ..Default::default()
- },
- Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
+ Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
+ )
+ .with_debug_variables_query(debug_variables_query)
+ .unwrap(),
)
- .with_debug_variables_query(debug_variables_query)
- .unwrap()
}
-fn typescript_lang() -> Language {
+fn typescript_lang() -> Arc<Language> {
let debug_variables_query = include_str!("../../../languages/src/typescript/debugger.scm");
- Language::new(
- LanguageConfig {
- name: "TypeScript".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["ts".to_string()],
+ Arc::new(
+ Language::new(
+ LanguageConfig {
+ name: "TypeScript".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["ts".to_string()],
+ ..Default::default()
+ },
..Default::default()
},
- ..Default::default()
- },
- Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
+ Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
+ )
+ .with_debug_variables_query(debug_variables_query)
+ .unwrap(),
)
- .with_debug_variables_query(debug_variables_query)
- .unwrap()
}
-fn tsx_lang() -> Language {
+fn tsx_lang() -> Arc<Language> {
let debug_variables_query = include_str!("../../../languages/src/tsx/debugger.scm");
- Language::new(
- LanguageConfig {
- name: "TSX".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["tsx".to_string()],
+ Arc::new(
+ Language::new(
+ LanguageConfig {
+ name: "TSX".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["tsx".to_string()],
+ ..Default::default()
+ },
..Default::default()
},
- ..Default::default()
- },
- Some(tree_sitter_typescript::LANGUAGE_TSX.into()),
+ Some(tree_sitter_typescript::LANGUAGE_TSX.into()),
+ )
+ .with_debug_variables_query(debug_variables_query)
+ .unwrap(),
)
- .with_debug_variables_query(debug_variables_query)
- .unwrap()
}
#[gpui::test]
@@ -1,12 +1,15 @@
use crate::{
debugger_panel::DebugPanel,
- session::running::stack_frame_list::{StackFrameEntry, StackFrameFilter},
+ session::running::stack_frame_list::{
+ StackFrameEntry, StackFrameFilter, stack_frame_filter_key,
+ },
tests::{active_debug_session_panel, init_test, init_test_workspace, start_debug_session},
};
use dap::{
StackFrame,
requests::{Scopes, StackTrace, Threads},
};
+use db::kvp::KEY_VALUE_STORE;
use editor::{Editor, ToPoint as _};
use gpui::{BackgroundExecutor, TestAppContext, VisualTestContext};
use project::{FakeFs, Project};
@@ -1085,3 +1088,180 @@ async fn test_stack_frame_filter(executor: BackgroundExecutor, cx: &mut TestAppC
);
});
}
+
+#[gpui::test]
+async fn test_stack_frame_filter_persistence(
+ executor: BackgroundExecutor,
+ cx: &mut TestAppContext,
+) {
+ init_test(cx);
+
+ let fs = FakeFs::new(executor.clone());
+
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "src": {
+ "test.js": "function main() { console.log('hello'); }",
+ }
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let workspace = init_test_workspace(&project, cx).await;
+ let cx = &mut VisualTestContext::from_window(*workspace, cx);
+ workspace
+ .update(cx, |workspace, _, _| {
+ workspace.set_random_database_id();
+ })
+ .unwrap();
+
+ let threads_response = dap::ThreadsResponse {
+ threads: vec![dap::Thread {
+ id: 1,
+ name: "Thread 1".into(),
+ }],
+ };
+
+ let stack_trace_response = dap::StackTraceResponse {
+ stack_frames: vec![StackFrame {
+ id: 1,
+ name: "main".into(),
+ source: Some(dap::Source {
+ name: Some("test.js".into()),
+ path: Some(path!("/project/src/test.js").into()),
+ source_reference: None,
+ presentation_hint: None,
+ origin: None,
+ sources: None,
+ adapter_data: None,
+ checksums: None,
+ }),
+ line: 1,
+ column: 1,
+ end_line: None,
+ end_column: None,
+ can_restart: None,
+ instruction_pointer_reference: None,
+ module_id: None,
+ presentation_hint: None,
+ }],
+ total_frames: None,
+ };
+
+ let stopped_event = dap::StoppedEvent {
+ reason: dap::StoppedEventReason::Pause,
+ description: None,
+ thread_id: Some(1),
+ preserve_focus_hint: None,
+ text: None,
+ all_threads_stopped: None,
+ hit_breakpoint_ids: None,
+ };
+
+ let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
+ let client = session.update(cx, |session, _| session.adapter_client().unwrap());
+ let adapter_name = session.update(cx, |session, _| session.adapter());
+
+ client.on_request::<Threads, _>({
+ let threads_response = threads_response.clone();
+ move |_, _| Ok(threads_response.clone())
+ });
+
+ client.on_request::<Scopes, _>(move |_, _| Ok(dap::ScopesResponse { scopes: vec![] }));
+
+ client.on_request::<StackTrace, _>({
+ let stack_trace_response = stack_trace_response.clone();
+ move |_, _| Ok(stack_trace_response.clone())
+ });
+
+ client
+ .fake_event(dap::messages::Events::Stopped(stopped_event.clone()))
+ .await;
+
+ cx.run_until_parked();
+
+ let stack_frame_list =
+ active_debug_session_panel(workspace, cx).update(cx, |debug_panel_item, cx| {
+ debug_panel_item
+ .running_state()
+ .update(cx, |state, _| state.stack_frame_list().clone())
+ });
+
+ stack_frame_list.update(cx, |stack_frame_list, _cx| {
+ assert_eq!(
+ stack_frame_list.list_filter(),
+ StackFrameFilter::All,
+ "Initial filter should be All"
+ );
+ });
+
+ stack_frame_list.update(cx, |stack_frame_list, cx| {
+ stack_frame_list
+ .toggle_frame_filter(Some(project::debugger::session::ThreadStatus::Stopped), cx);
+ assert_eq!(
+ stack_frame_list.list_filter(),
+ StackFrameFilter::OnlyUserFrames,
+ "Filter should be OnlyUserFrames after toggle"
+ );
+ });
+
+ cx.run_until_parked();
+
+ let workspace_id = workspace
+ .update(cx, |workspace, _window, _cx| workspace.database_id())
+ .ok()
+ .flatten()
+ .expect("workspace id has to be some for this test to work properly");
+
+ let key = stack_frame_filter_key(&adapter_name, workspace_id);
+ let stored_value = KEY_VALUE_STORE.read_kvp(&key).unwrap();
+ assert_eq!(
+ stored_value,
+ Some(StackFrameFilter::OnlyUserFrames.into()),
+ "Filter should be persisted in KVP store with key: {}",
+ key
+ );
+
+ client
+ .fake_event(dap::messages::Events::Terminated(None))
+ .await;
+ cx.run_until_parked();
+
+ let session2 = start_debug_session(&workspace, cx, |_| {}).unwrap();
+ let client2 = session2.update(cx, |session, _| session.adapter_client().unwrap());
+
+ client2.on_request::<Threads, _>({
+ let threads_response = threads_response.clone();
+ move |_, _| Ok(threads_response.clone())
+ });
+
+ client2.on_request::<Scopes, _>(move |_, _| Ok(dap::ScopesResponse { scopes: vec![] }));
+
+ client2.on_request::<StackTrace, _>({
+ let stack_trace_response = stack_trace_response.clone();
+ move |_, _| Ok(stack_trace_response.clone())
+ });
+
+ client2
+ .fake_event(dap::messages::Events::Stopped(stopped_event.clone()))
+ .await;
+
+ cx.run_until_parked();
+
+ let stack_frame_list2 =
+ active_debug_session_panel(workspace, cx).update(cx, |debug_panel_item, cx| {
+ debug_panel_item
+ .running_state()
+ .update(cx, |state, _| state.stack_frame_list().clone())
+ });
+
+ stack_frame_list2.update(cx, |stack_frame_list, _cx| {
+ assert_eq!(
+ stack_frame_list.list_filter(),
+ StackFrameFilter::OnlyUserFrames,
+ "Filter should be restored from KVP store in new session"
+ );
+ });
+}
@@ -11,7 +11,69 @@ workspace = true
[lib]
path = "src/edit_prediction.rs"
+[features]
+eval-support = []
+
[dependencies]
+ai_onboarding.workspace = true
+anyhow.workspace = true
+arrayvec.workspace = true
+brotli.workspace = true
client.workspace = true
+cloud_llm_client.workspace = true
+cloud_zeta2_prompt.workspace = true
+collections.workspace = true
+copilot.workspace = true
+credentials_provider.workspace = true
+db.workspace = true
+edit_prediction_types.workspace = true
+edit_prediction_context.workspace = true
+feature_flags.workspace = true
+fs.workspace = true
+futures.workspace = true
gpui.workspace = true
+indoc.workspace = true
+itertools.workspace = true
language.workspace = true
+language_model.workspace = true
+log.workspace = true
+lsp.workspace = true
+menu.workspace = true
+open_ai.workspace = true
+postage.workspace = true
+pretty_assertions.workspace = true
+project.workspace = true
+rand.workspace = true
+regex.workspace = true
+release_channel.workspace = true
+semver.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+settings.workspace = true
+smol.workspace = true
+strsim.workspace = true
+strum.workspace = true
+telemetry.workspace = true
+telemetry_events.workspace = true
+thiserror.workspace = true
+ui.workspace = true
+util.workspace = true
+uuid.workspace = true
+workspace.workspace = true
+worktree.workspace = true
+zed_actions.workspace = true
+
+[dev-dependencies]
+clock = { workspace = true, features = ["test-support"] }
+cloud_api_types.workspace = true
+cloud_llm_client = { workspace = true, features = ["test-support"] }
+ctor.workspace = true
+gpui = { workspace = true, features = ["test-support"] }
+indoc.workspace = true
+language = { workspace = true, features = ["test-support"] }
+language_model = { workspace = true, features = ["test-support"] }
+lsp.workspace = true
+parking_lot.workspace = true
+project = { workspace = true, features = ["test-support"] }
+settings = { workspace = true, features = ["test-support"] }
+zlog.workspace = true
@@ -0,0 +1,78 @@
+use language::{BufferSnapshot, Point};
+use std::ops::Range;
+
+pub fn editable_and_context_ranges_for_cursor_position(
+ position: Point,
+ snapshot: &BufferSnapshot,
+ editable_region_token_limit: usize,
+ context_token_limit: usize,
+) -> (Range<Point>, Range<Point>) {
+ let mut scope_range = position..position;
+ let mut remaining_edit_tokens = editable_region_token_limit;
+
+ while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
+ let parent_tokens = guess_token_count(parent.byte_range().len());
+ let parent_point_range = Point::new(
+ parent.start_position().row as u32,
+ parent.start_position().column as u32,
+ )
+ ..Point::new(
+ parent.end_position().row as u32,
+ parent.end_position().column as u32,
+ );
+ if parent_point_range == scope_range {
+ break;
+ } else if parent_tokens <= editable_region_token_limit {
+ scope_range = parent_point_range;
+ remaining_edit_tokens = editable_region_token_limit - parent_tokens;
+ } else {
+ break;
+ }
+ }
+
+ let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
+ let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
+ (editable_range, context_range)
+}
+
+fn expand_range(
+ snapshot: &BufferSnapshot,
+ range: Range<Point>,
+ mut remaining_tokens: usize,
+) -> Range<Point> {
+ let mut expanded_range = range;
+ expanded_range.start.column = 0;
+ expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
+ loop {
+ let mut expanded = false;
+
+ if remaining_tokens > 0 && expanded_range.start.row > 0 {
+ expanded_range.start.row -= 1;
+ let line_tokens =
+ guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
+ remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
+ expanded = true;
+ }
+
+ if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
+ expanded_range.end.row += 1;
+ expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
+ let line_tokens = guess_token_count(expanded_range.end.column as usize);
+ remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
+ expanded = true;
+ }
+
+ if !expanded {
+ break;
+ }
+ }
+ expanded_range
+}
+
+/// Typical number of string bytes per token for the purposes of limiting model input. This is
+/// intentionally low to err on the side of underestimating limits.
+pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
+
+pub fn guess_token_count(bytes: usize) -> usize {
+ bytes / BYTES_PER_TOKEN_GUESS
+}
@@ -1,298 +1,1944 @@
-use std::{ops::Range, sync::Arc};
+use anyhow::Result;
+use arrayvec::ArrayVec;
+use client::{Client, EditPredictionUsage, UserStore};
+use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
+use cloud_llm_client::{
+ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
+ EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
+ MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
+ ZED_VERSION_HEADER_NAME,
+};
+use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
+use collections::{HashMap, HashSet};
+use db::kvp::{Dismissable, KEY_VALUE_STORE};
+use edit_prediction_context::EditPredictionExcerptOptions;
+use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile};
+use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
+use futures::{
+ AsyncReadExt as _, FutureExt as _, StreamExt as _,
+ channel::{
+ mpsc::{self, UnboundedReceiver},
+ oneshot,
+ },
+ select_biased,
+};
+use gpui::BackgroundExecutor;
+use gpui::{
+ App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
+ http_client::{self, AsyncBody, Method},
+ prelude::*,
+};
+use language::language_settings::all_language_settings;
+use language::{Anchor, Buffer, File, Point, ToPoint};
+use language::{BufferSnapshot, OffsetRangeExt};
+use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use project::{Project, ProjectPath, WorktreeId};
+use release_channel::AppVersion;
+use semver::Version;
+use serde::de::DeserializeOwned;
+use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
+use std::collections::{VecDeque, hash_map};
+use workspace::Workspace;
+
+use std::ops::Range;
+use std::path::Path;
+use std::rc::Rc;
+use std::str::FromStr as _;
+use std::sync::{Arc, LazyLock};
+use std::time::{Duration, Instant};
+use std::{env, mem};
+use thiserror::Error;
+use util::{RangeExt as _, ResultExt as _};
+use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+
+mod cursor_excerpt;
+mod license_detection;
+pub mod mercury;
+mod onboarding_modal;
+pub mod open_ai_response;
+mod prediction;
+pub mod sweep_ai;
+pub mod udiff;
+mod xml_edits;
+mod zed_edit_prediction_delegate;
+pub mod zeta1;
+pub mod zeta2;
+
+#[cfg(test)]
+mod edit_prediction_tests;
+
+use crate::license_detection::LicenseDetectionWatcher;
+use crate::mercury::Mercury;
+use crate::onboarding_modal::ZedPredictModal;
+pub use crate::prediction::EditPrediction;
+pub use crate::prediction::EditPredictionId;
+pub use crate::prediction::EditPredictionInputs;
+use crate::prediction::EditPredictionResult;
+pub use crate::sweep_ai::SweepAi;
+pub use telemetry_events::EditPredictionRating;
+pub use zed_edit_prediction_delegate::ZedEditPredictionDelegate;
+
+actions!(
+ edit_prediction,
+ [
+ /// Resets the edit prediction onboarding state.
+ ResetOnboarding,
+ /// Clears the edit prediction history.
+ ClearHistory,
+ ]
+);
+
+/// Maximum number of events to track.
+const EVENT_COUNT_MAX: usize = 6;
+const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
+const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
+const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
-use client::EditPredictionUsage;
-use gpui::{App, Context, Entity, SharedString};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt};
+pub struct SweepFeatureFlag;
-// TODO: Find a better home for `Direction`.
-//
-// This should live in an ancestor crate of `editor` and `edit_prediction`,
-// but at time of writing there isn't an obvious spot.
-#[derive(Copy, Clone, PartialEq, Eq)]
-pub enum Direction {
- Prev,
- Next,
+impl FeatureFlag for SweepFeatureFlag {
+ const NAME: &str = "sweep-ai";
}
-#[derive(Clone)]
-pub enum EditPrediction {
- /// Edits within the buffer that requested the prediction
- Local {
- id: Option<SharedString>,
- edits: Vec<(Range<language::Anchor>, Arc<str>)>,
- edit_preview: Option<language::EditPreview>,
- },
- /// Jump to a different file from the one that requested the prediction
- Jump {
- id: Option<SharedString>,
- snapshot: language::BufferSnapshot,
- target: language::Anchor,
+pub struct MercuryFeatureFlag;
+
+impl FeatureFlag for MercuryFeatureFlag {
+ const NAME: &str = "mercury";
+}
+
+pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
+ context: EditPredictionExcerptOptions {
+ max_bytes: 512,
+ min_bytes: 128,
+ target_before_cursor_over_total_bytes: 0.5,
},
+ max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
+ prompt_format: PromptFormat::DEFAULT,
+};
+
+static USE_OLLAMA: LazyLock<bool> =
+ LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
+
+static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
+ match env::var("ZED_ZETA2_MODEL").as_deref() {
+ Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
+ Ok(model) => model,
+ Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
+ Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
+ }
+ .to_string()
+});
+static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
+ env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
+ if *USE_OLLAMA {
+ Some("http://localhost:11434/v1/chat/completions".into())
+ } else {
+ None
+ }
+ })
+});
+
+pub struct Zeta2FeatureFlag;
+
+impl FeatureFlag for Zeta2FeatureFlag {
+ const NAME: &'static str = "zeta2";
+
+ fn enabled_for_staff() -> bool {
+ true
+ }
+}
+
+#[derive(Clone)]
+struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
+
+impl Global for EditPredictionStoreGlobal {}
+
+pub struct EditPredictionStore {
+ client: Arc<Client>,
+ user_store: Entity<UserStore>,
+ llm_token: LlmApiToken,
+ _llm_token_subscription: Subscription,
+ projects: HashMap<EntityId, ProjectState>,
+ use_context: bool,
+ options: ZetaOptions,
+ update_required: bool,
+ debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
+ #[cfg(feature = "eval-support")]
+ eval_cache: Option<Arc<dyn EvalCache>>,
+ edit_prediction_model: EditPredictionModel,
+ pub sweep_ai: SweepAi,
+ pub mercury: Mercury,
+ data_collection_choice: DataCollectionChoice,
+ reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
+ shown_predictions: VecDeque<EditPrediction>,
+ rated_predictions: HashSet<EditPredictionId>,
+}
+
+#[derive(Copy, Clone, Default, PartialEq, Eq)]
+pub enum EditPredictionModel {
+ #[default]
+ Zeta1,
+ Zeta2,
+ Sweep,
+ Mercury,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct ZetaOptions {
+ pub context: EditPredictionExcerptOptions,
+ pub max_prompt_bytes: usize,
+ pub prompt_format: predict_edits_v3::PromptFormat,
+}
+
+#[derive(Debug)]
+pub enum DebugEvent {
+ ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
+ ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
+ EditPredictionRequested(EditPredictionRequestedDebugEvent),
+}
+
+#[derive(Debug)]
+pub struct ContextRetrievalStartedDebugEvent {
+ pub project_entity_id: EntityId,
+ pub timestamp: Instant,
+ pub search_prompt: String,
+}
+
+#[derive(Debug)]
+pub struct ContextRetrievalFinishedDebugEvent {
+ pub project_entity_id: EntityId,
+ pub timestamp: Instant,
+ pub metadata: Vec<(&'static str, SharedString)>,
+}
+
+#[derive(Debug)]
+pub struct EditPredictionRequestedDebugEvent {
+ pub inputs: EditPredictionInputs,
+ pub retrieval_time: Duration,
+ pub buffer: WeakEntity<Buffer>,
+ pub position: Anchor,
+ pub local_prompt: Result<String, String>,
+ pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
+}
+
+pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
+
+struct ProjectState {
+ events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
+ last_event: Option<LastEvent>,
+ recent_paths: VecDeque<ProjectPath>,
+ registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
+ current_prediction: Option<CurrentEditPrediction>,
+ next_pending_prediction_id: usize,
+ pending_predictions: ArrayVec<PendingPrediction, 2>,
+ context_updates_tx: smol::channel::Sender<()>,
+ context_updates_rx: smol::channel::Receiver<()>,
+ last_prediction_refresh: Option<(EntityId, Instant)>,
+ cancelled_predictions: HashSet<usize>,
+ context: Entity<RelatedExcerptStore>,
+ license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+ _subscription: gpui::Subscription,
+}
+
+impl ProjectState {
+ pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
+ self.events
+ .iter()
+ .cloned()
+ .chain(
+ self.last_event
+ .as_ref()
+ .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
+ )
+ .collect()
+ }
+
+ fn cancel_pending_prediction(
+ &mut self,
+ pending_prediction: PendingPrediction,
+ cx: &mut Context<EditPredictionStore>,
+ ) {
+ self.cancelled_predictions.insert(pending_prediction.id);
+
+ cx.spawn(async move |this, cx| {
+ let Some(prediction_id) = pending_prediction.task.await else {
+ return;
+ };
+
+ this.update(cx, |this, _cx| {
+ this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
+ })
+ .ok();
+ })
+ .detach()
+ }
}
-pub enum DataCollectionState {
- /// The provider doesn't support data collection.
- Unsupported,
- /// Data collection is enabled.
- Enabled { is_project_open_source: bool },
- /// Data collection is disabled or unanswered.
- Disabled { is_project_open_source: bool },
+#[derive(Debug, Clone)]
+struct CurrentEditPrediction {
+ pub requested_by: PredictionRequestedBy,
+ pub prediction: EditPrediction,
+ pub was_shown: bool,
}
-impl DataCollectionState {
- pub fn is_supported(&self) -> bool {
- !matches!(self, DataCollectionState::Unsupported)
+impl CurrentEditPrediction {
+ fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
+ let Some(new_edits) = self
+ .prediction
+ .interpolate(&self.prediction.buffer.read(cx))
+ else {
+ return false;
+ };
+
+ if self.prediction.buffer != old_prediction.prediction.buffer {
+ return true;
+ }
+
+ let Some(old_edits) = old_prediction
+ .prediction
+ .interpolate(&old_prediction.prediction.buffer.read(cx))
+ else {
+ return true;
+ };
+
+ let requested_by_buffer_id = self.requested_by.buffer_id();
+
+ // This reduces the occurrence of UI thrash from replacing edits
+ //
+ // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
+ if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
+ && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
+ && old_edits.len() == 1
+ && new_edits.len() == 1
+ {
+ let (old_range, old_text) = &old_edits[0];
+ let (new_range, new_text) = &new_edits[0];
+ new_range == old_range && new_text.starts_with(old_text.as_ref())
+ } else {
+ true
+ }
}
+}
- pub fn is_enabled(&self) -> bool {
- matches!(self, DataCollectionState::Enabled { .. })
+#[derive(Debug, Clone)]
+enum PredictionRequestedBy {
+ DiagnosticsUpdate,
+ Buffer(EntityId),
+}
+
+impl PredictionRequestedBy {
+ pub fn buffer_id(&self) -> Option<EntityId> {
+ match self {
+ PredictionRequestedBy::DiagnosticsUpdate => None,
+ PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
+ }
}
+}
+
+#[derive(Debug)]
+struct PendingPrediction {
+ id: usize,
+ task: Task<Option<EditPredictionId>>,
+}
+
+/// A prediction from the perspective of a buffer.
+#[derive(Debug)]
+enum BufferEditPrediction<'a> {
+ Local { prediction: &'a EditPrediction },
+ Jump { prediction: &'a EditPrediction },
+}
+
+#[cfg(test)]
+impl std::ops::Deref for BufferEditPrediction<'_> {
+ type Target = EditPrediction;
- pub fn is_project_open_source(&self) -> bool {
+ fn deref(&self) -> &Self::Target {
match self {
- Self::Enabled {
- is_project_open_source,
- }
- | Self::Disabled {
- is_project_open_source,
- } => *is_project_open_source,
- _ => false,
+ BufferEditPrediction::Local { prediction } => prediction,
+ BufferEditPrediction::Jump { prediction } => prediction,
}
}
}
-pub trait EditPredictionProvider: 'static + Sized {
- fn name() -> &'static str;
- fn display_name() -> &'static str;
- fn show_completions_in_menu() -> bool;
- fn show_tab_accept_marker() -> bool {
- false
+struct RegisteredBuffer {
+ snapshot: BufferSnapshot,
+ _subscriptions: [gpui::Subscription; 2],
+}
+
+struct LastEvent {
+ old_snapshot: BufferSnapshot,
+ new_snapshot: BufferSnapshot,
+ end_edit_anchor: Option<Anchor>,
+}
+
+impl LastEvent {
+ pub fn finalize(
+ &self,
+ license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
+ cx: &App,
+ ) -> Option<Arc<predict_edits_v3::Event>> {
+ let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
+ let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
+
+ let file = self.new_snapshot.file();
+ let old_file = self.old_snapshot.file();
+
+ let in_open_source_repo = [file, old_file].iter().all(|file| {
+ file.is_some_and(|file| {
+ license_detection_watchers
+ .get(&file.worktree_id(cx))
+ .is_some_and(|watcher| watcher.is_project_open_source())
+ })
+ });
+
+ let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
+
+ if path == old_path && diff.is_empty() {
+ None
+ } else {
+ Some(Arc::new(predict_edits_v3::Event::BufferChange {
+ old_path,
+ path,
+ diff,
+ in_open_source_repo,
+ // TODO: Actually detect if this edit was predicted or not
+ predicted: false,
+ }))
+ }
}
- fn supports_jump_to_edit() -> bool {
- true
+}
+
+fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
+ if let Some(file) = snapshot.file() {
+ file.full_path(cx).into()
+ } else {
+ Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
}
+}
- fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
- DataCollectionState::Unsupported
+impl EditPredictionStore {
+ pub fn try_global(cx: &App) -> Option<Entity<Self>> {
+ cx.try_global::<EditPredictionStoreGlobal>()
+ .map(|global| global.0.clone())
}
- fn usage(&self, _cx: &App) -> Option<EditPredictionUsage> {
- None
+ pub fn global(
+ client: &Arc<Client>,
+ user_store: &Entity<UserStore>,
+ cx: &mut App,
+ ) -> Entity<Self> {
+ cx.try_global::<EditPredictionStoreGlobal>()
+ .map(|global| global.0.clone())
+ .unwrap_or_else(|| {
+ let ep_store = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
+ cx.set_global(EditPredictionStoreGlobal(ep_store.clone()));
+ ep_store
+ })
}
- fn toggle_data_collection(&mut self, _cx: &mut App) {}
- fn is_enabled(
- &self,
+ pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+ let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
+ let data_collection_choice = Self::load_data_collection_choice();
+
+ let llm_token = LlmApiToken::default();
+
+ let (reject_tx, reject_rx) = mpsc::unbounded();
+ cx.background_spawn({
+ let client = client.clone();
+ let llm_token = llm_token.clone();
+ let app_version = AppVersion::global(cx);
+ let background_executor = cx.background_executor().clone();
+ async move {
+ Self::handle_rejected_predictions(
+ reject_rx,
+ client,
+ llm_token,
+ app_version,
+ background_executor,
+ )
+ .await
+ }
+ })
+ .detach();
+
+ let mut this = Self {
+ projects: HashMap::default(),
+ client,
+ user_store,
+ options: DEFAULT_OPTIONS,
+ use_context: false,
+ llm_token,
+ _llm_token_subscription: cx.subscribe(
+ &refresh_llm_token_listener,
+ |this, _listener, _event, cx| {
+ let client = this.client.clone();
+ let llm_token = this.llm_token.clone();
+ cx.spawn(async move |_this, _cx| {
+ llm_token.refresh(&client).await?;
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ },
+ ),
+ update_required: false,
+ debug_tx: None,
+ #[cfg(feature = "eval-support")]
+ eval_cache: None,
+ edit_prediction_model: EditPredictionModel::Zeta2,
+ sweep_ai: SweepAi::new(cx),
+ mercury: Mercury::new(cx),
+ data_collection_choice,
+ reject_predictions_tx: reject_tx,
+ rated_predictions: Default::default(),
+ shown_predictions: Default::default(),
+ };
+
+ this.configure_context_retrieval(cx);
+ let weak_this = cx.weak_entity();
+ cx.on_flags_ready(move |_, cx| {
+ weak_this
+ .update(cx, |this, cx| this.configure_context_retrieval(cx))
+ .ok();
+ })
+ .detach();
+ cx.observe_global::<SettingsStore>(|this, cx| {
+ this.configure_context_retrieval(cx);
+ })
+ .detach();
+
+ this
+ }
+
+ pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
+ self.edit_prediction_model = model;
+ }
+
+ pub fn has_sweep_api_token(&self) -> bool {
+ self.sweep_ai
+ .api_token
+ .clone()
+ .now_or_never()
+ .flatten()
+ .is_some()
+ }
+
+ pub fn has_mercury_api_token(&self) -> bool {
+ self.mercury
+ .api_token
+ .clone()
+ .now_or_never()
+ .flatten()
+ .is_some()
+ }
+
+ #[cfg(feature = "eval-support")]
+ pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
+ self.eval_cache = Some(cache);
+ }
+
+ pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
+ let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
+ self.debug_tx = Some(debug_watch_tx);
+ debug_watch_rx
+ }
+
+ pub fn options(&self) -> &ZetaOptions {
+ &self.options
+ }
+
+ pub fn set_options(&mut self, options: ZetaOptions) {
+ self.options = options;
+ }
+
+ pub fn set_use_context(&mut self, use_context: bool) {
+ self.use_context = use_context;
+ }
+
+ pub fn clear_history(&mut self) {
+ for project_state in self.projects.values_mut() {
+ project_state.events.clear();
+ }
+ }
+
+ pub fn context_for_project<'a>(
+ &'a self,
+ project: &Entity<Project>,
+ cx: &'a App,
+ ) -> &'a [RelatedFile] {
+ self.projects
+ .get(&project.entity_id())
+ .map(|project| project.context.read(cx).related_files())
+ .unwrap_or(&[])
+ }
+
+ pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
+ if self.edit_prediction_model == EditPredictionModel::Zeta2 {
+ self.user_store.read(cx).edit_prediction_usage()
+ } else {
+ None
+ }
+ }
+
+ pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+ self.get_or_init_project(project, cx);
+ }
+
+ pub fn register_buffer(
+ &mut self,
buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
- cx: &App,
- ) -> bool;
- fn is_refreshing(&self, cx: &App) -> bool;
- fn refresh(
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) {
+ let project_state = self.get_or_init_project(project, cx);
+ Self::register_buffer_impl(project_state, buffer, project, cx);
+ }
+
+ fn get_or_init_project(
&mut self,
- buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- debounce: bool,
+ project: &Entity<Project>,
cx: &mut Context<Self>,
- );
- fn cycle(
+ ) -> &mut ProjectState {
+ let entity_id = project.entity_id();
+ let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
+ self.projects
+ .entry(entity_id)
+ .or_insert_with(|| ProjectState {
+ context: {
+ let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
+ cx.subscribe(
+ &related_excerpt_store,
+ move |this, _, event, _| match event {
+ RelatedExcerptStoreEvent::StartedRefresh => {
+ if let Some(debug_tx) = this.debug_tx.clone() {
+ debug_tx
+ .unbounded_send(DebugEvent::ContextRetrievalStarted(
+ ContextRetrievalStartedDebugEvent {
+ project_entity_id: entity_id,
+ timestamp: Instant::now(),
+ search_prompt: String::new(),
+ },
+ ))
+ .ok();
+ }
+ }
+ RelatedExcerptStoreEvent::FinishedRefresh {
+ cache_hit_count,
+ cache_miss_count,
+ mean_definition_latency,
+ max_definition_latency,
+ } => {
+ if let Some(debug_tx) = this.debug_tx.clone() {
+ debug_tx
+ .unbounded_send(DebugEvent::ContextRetrievalFinished(
+ ContextRetrievalFinishedDebugEvent {
+ project_entity_id: entity_id,
+ timestamp: Instant::now(),
+ metadata: vec![
+ (
+ "Cache Hits",
+ format!(
+ "{}/{}",
+ cache_hit_count,
+ cache_hit_count + cache_miss_count
+ )
+ .into(),
+ ),
+ (
+ "Max LSP Time",
+ format!(
+ "{} ms",
+ max_definition_latency.as_millis()
+ )
+ .into(),
+ ),
+ (
+ "Mean LSP Time",
+ format!(
+ "{} ms",
+ mean_definition_latency.as_millis()
+ )
+ .into(),
+ ),
+ ],
+ },
+ ))
+ .ok();
+ }
+ if let Some(project_state) = this.projects.get(&entity_id) {
+ project_state.context_updates_tx.send_blocking(()).ok();
+ }
+ }
+ },
+ )
+ .detach();
+ related_excerpt_store
+ },
+ events: VecDeque::new(),
+ last_event: None,
+ recent_paths: VecDeque::new(),
+ context_updates_rx,
+ context_updates_tx,
+ registered_buffers: HashMap::default(),
+ current_prediction: None,
+ cancelled_predictions: HashSet::default(),
+ pending_predictions: ArrayVec::new(),
+ next_pending_prediction_id: 0,
+ last_prediction_refresh: None,
+ license_detection_watchers: HashMap::default(),
+ _subscription: cx.subscribe(&project, Self::handle_project_event),
+ })
+ }
+
+ pub fn project_context_updates(
+ &self,
+ project: &Entity<Project>,
+ ) -> Option<smol::channel::Receiver<()>> {
+ let project_state = self.projects.get(&project.entity_id())?;
+ Some(project_state.context_updates_rx.clone())
+ }
+
+ fn handle_project_event(
&mut self,
- buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- direction: Direction,
+ project: Entity<Project>,
+ event: &project::Event,
cx: &mut Context<Self>,
- );
- fn accept(&mut self, cx: &mut Context<Self>);
- fn discard(&mut self, cx: &mut Context<Self>);
- fn did_show(&mut self, _cx: &mut Context<Self>) {}
- fn suggest(
+ ) {
+ // TODO [zeta2] init with recent paths
+ match event {
+ project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
+ let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+ let path = project.read(cx).path_for_entry(*active_entry_id, cx);
+ if let Some(path) = path {
+ if let Some(ix) = project_state
+ .recent_paths
+ .iter()
+ .position(|probe| probe == &path)
+ {
+ project_state.recent_paths.remove(ix);
+ }
+ project_state.recent_paths.push_front(path);
+ }
+ }
+ project::Event::DiagnosticsUpdated { .. } => {
+ if cx.has_flag::<Zeta2FeatureFlag>() {
+ self.refresh_prediction_from_diagnostics(project, cx);
+ }
+ }
+ _ => (),
+ }
+ }
+
+ fn register_buffer_impl<'a>(
+ project_state: &'a mut ProjectState,
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ cx: &mut Context<Self>,
+ ) -> &'a mut RegisteredBuffer {
+ let buffer_id = buffer.entity_id();
+
+ if let Some(file) = buffer.read(cx).file() {
+ let worktree_id = file.worktree_id(cx);
+ if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
+ project_state
+ .license_detection_watchers
+ .entry(worktree_id)
+ .or_insert_with(|| {
+ let project_entity_id = project.entity_id();
+ cx.observe_release(&worktree, move |this, _worktree, _cx| {
+ let Some(project_state) = this.projects.get_mut(&project_entity_id)
+ else {
+ return;
+ };
+ project_state
+ .license_detection_watchers
+ .remove(&worktree_id);
+ })
+ .detach();
+ Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
+ });
+ }
+ }
+
+ match project_state.registered_buffers.entry(buffer_id) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ let snapshot = buffer.read(cx).snapshot();
+ let project_entity_id = project.entity_id();
+ entry.insert(RegisteredBuffer {
+ snapshot,
+ _subscriptions: [
+ cx.subscribe(buffer, {
+ let project = project.downgrade();
+ move |this, buffer, event, cx| {
+ if let language::BufferEvent::Edited = event
+ && let Some(project) = project.upgrade()
+ {
+ this.report_changes_for_buffer(&buffer, &project, cx);
+ }
+ }
+ }),
+ cx.observe_release(buffer, move |this, _buffer, _cx| {
+ let Some(project_state) = this.projects.get_mut(&project_entity_id)
+ else {
+ return;
+ };
+ project_state.registered_buffers.remove(&buffer_id);
+ }),
+ ],
+ })
+ }
+ }
+ }
+
+ fn report_changes_for_buffer(
&mut self,
buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
+ project: &Entity<Project>,
cx: &mut Context<Self>,
- ) -> Option<EditPrediction>;
-}
+ ) {
+ let project_state = self.get_or_init_project(project, cx);
+ let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
+
+ let new_snapshot = buffer.read(cx).snapshot();
+ if new_snapshot.version == registered_buffer.snapshot.version {
+ return;
+ }
+
+ let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
+ let end_edit_anchor = new_snapshot
+ .anchored_edits_since::<Point>(&old_snapshot.version)
+ .last()
+ .map(|(_, range)| range.end);
+ let events = &mut project_state.events;
+
+ if let Some(LastEvent {
+ new_snapshot: last_new_snapshot,
+ end_edit_anchor: last_end_edit_anchor,
+ ..
+ }) = project_state.last_event.as_mut()
+ {
+ let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
+ == last_new_snapshot.remote_id()
+ && old_snapshot.version == last_new_snapshot.version;
+
+ let should_coalesce = is_next_snapshot_of_same_buffer
+ && end_edit_anchor
+ .as_ref()
+ .zip(last_end_edit_anchor.as_ref())
+ .is_some_and(|(a, b)| {
+ let a = a.to_point(&new_snapshot);
+ let b = b.to_point(&new_snapshot);
+ a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
+ });
-pub trait EditPredictionProviderHandle {
- fn name(&self) -> &'static str;
- fn display_name(&self) -> &'static str;
- fn is_enabled(
+ if should_coalesce {
+ *last_end_edit_anchor = end_edit_anchor;
+ *last_new_snapshot = new_snapshot;
+ return;
+ }
+ }
+
+ if events.len() + 1 >= EVENT_COUNT_MAX {
+ events.pop_front();
+ }
+
+ if let Some(event) = project_state.last_event.take() {
+ events.extend(event.finalize(&project_state.license_detection_watchers, cx));
+ }
+
+ project_state.last_event = Some(LastEvent {
+ old_snapshot,
+ new_snapshot,
+ end_edit_anchor,
+ });
+ }
+
+ fn current_prediction_for_buffer(
&self,
buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
+ project: &Entity<Project>,
cx: &App,
- ) -> bool;
- fn show_completions_in_menu(&self) -> bool;
- fn show_tab_accept_marker(&self) -> bool;
- fn supports_jump_to_edit(&self) -> bool;
- fn data_collection_state(&self, cx: &App) -> DataCollectionState;
- fn usage(&self, cx: &App) -> Option<EditPredictionUsage>;
- fn toggle_data_collection(&self, cx: &mut App);
- fn is_refreshing(&self, cx: &App) -> bool;
- fn refresh(
- &self,
- buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- debounce: bool,
- cx: &mut App,
- );
- fn cycle(
- &self,
- buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- direction: Direction,
- cx: &mut App,
- );
- fn did_show(&self, cx: &mut App);
- fn accept(&self, cx: &mut App);
- fn discard(&self, cx: &mut App);
- fn suggest(
- &self,
- buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
- cx: &mut App,
- ) -> Option<EditPrediction>;
-}
+ ) -> Option<BufferEditPrediction<'_>> {
+ let project_state = self.projects.get(&project.entity_id())?;
-impl<T> EditPredictionProviderHandle for Entity<T>
-where
- T: EditPredictionProvider,
-{
- fn name(&self) -> &'static str {
- T::name()
- }
+ let CurrentEditPrediction {
+ requested_by,
+ prediction,
+ ..
+ } = project_state.current_prediction.as_ref()?;
- fn display_name(&self) -> &'static str {
- T::display_name()
- }
+ if prediction.targets_buffer(buffer.read(cx)) {
+ Some(BufferEditPrediction::Local { prediction })
+ } else {
+ let show_jump = match requested_by {
+ PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
+ requested_by_buffer_id == &buffer.entity_id()
+ }
+ PredictionRequestedBy::DiagnosticsUpdate => true,
+ };
- fn show_completions_in_menu(&self) -> bool {
- T::show_completions_in_menu()
+ if show_jump {
+ Some(BufferEditPrediction::Jump { prediction })
+ } else {
+ None
+ }
+ }
}
- fn show_tab_accept_marker(&self) -> bool {
- T::show_tab_accept_marker()
- }
+ fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+ match self.edit_prediction_model {
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
+ EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
+ }
+
+ let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
- fn supports_jump_to_edit(&self) -> bool {
- T::supports_jump_to_edit()
+ let Some(prediction) = project_state.current_prediction.take() else {
+ return;
+ };
+ let request_id = prediction.prediction.id.to_string();
+ for pending_prediction in mem::take(&mut project_state.pending_predictions) {
+ project_state.cancel_pending_prediction(pending_prediction, cx);
+ }
+
+ let client = self.client.clone();
+ let llm_token = self.llm_token.clone();
+ let app_version = AppVersion::global(cx);
+ cx.spawn(async move |this, cx| {
+ let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
+ http_client::Url::parse(&predict_edits_url)?
+ } else {
+ client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/accept", &[])?
+ };
+
+ let response = cx
+ .background_spawn(Self::send_api_request::<()>(
+ move |builder| {
+ let req = builder.uri(url.as_ref()).body(
+ serde_json::to_string(&AcceptEditPredictionBody {
+ request_id: request_id.clone(),
+ })?
+ .into(),
+ );
+ Ok(req?)
+ },
+ client,
+ llm_token,
+ app_version,
+ ))
+ .await;
+
+ Self::handle_api_response(&this, response, cx)?;
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
}
- fn data_collection_state(&self, cx: &App) -> DataCollectionState {
- self.read(cx).data_collection_state(cx)
+ async fn handle_rejected_predictions(
+ rx: UnboundedReceiver<EditPredictionRejection>,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: Version,
+ background_executor: BackgroundExecutor,
+ ) {
+ let mut rx = std::pin::pin!(rx.peekable());
+ let mut batched = Vec::new();
+
+ while let Some(rejection) = rx.next().await {
+ batched.push(rejection);
+
+ if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
+ select_biased! {
+ next = rx.as_mut().peek().fuse() => {
+ if next.is_some() {
+ continue;
+ }
+ }
+ () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
+ }
+ }
+
+ let url = client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/reject", &[])
+ .unwrap();
+
+ let flush_count = batched
+ .len()
+ // in case items have accumulated after failure
+ .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
+ let start = batched.len() - flush_count;
+
+ let body = RejectEditPredictionsBodyRef {
+ rejections: &batched[start..],
+ };
+
+ let result = Self::send_api_request::<()>(
+ |builder| {
+ let req = builder
+ .uri(url.as_ref())
+ .body(serde_json::to_string(&body)?.into());
+ anyhow::Ok(req?)
+ },
+ client.clone(),
+ llm_token.clone(),
+ app_version.clone(),
+ )
+ .await;
+
+ if result.log_err().is_some() {
+ batched.drain(start..);
+ }
+ }
}
- fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- self.read(cx).usage(cx)
+ fn reject_current_prediction(
+ &mut self,
+ reason: EditPredictionRejectReason,
+ project: &Entity<Project>,
+ ) {
+ if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+ project_state.pending_predictions.clear();
+ if let Some(prediction) = project_state.current_prediction.take() {
+ self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
+ }
+ };
}
- fn toggle_data_collection(&self, cx: &mut App) {
- self.update(cx, |this, cx| this.toggle_data_collection(cx))
+ fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
+ if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
+ if let Some(current_prediction) = project_state.current_prediction.as_mut() {
+ if !current_prediction.was_shown {
+ current_prediction.was_shown = true;
+ self.shown_predictions
+ .push_front(current_prediction.prediction.clone());
+ if self.shown_predictions.len() > 50 {
+ let completion = self.shown_predictions.pop_back().unwrap();
+ self.rated_predictions.remove(&completion.id);
+ }
+ }
+ }
+ }
}
- fn is_enabled(
- &self,
- buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
- cx: &App,
- ) -> bool {
- self.read(cx).is_enabled(buffer, cursor_position, cx)
+ fn reject_prediction(
+ &mut self,
+ prediction_id: EditPredictionId,
+ reason: EditPredictionRejectReason,
+ was_shown: bool,
+ ) {
+ match self.edit_prediction_model {
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
+ EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
+ }
+
+ self.reject_predictions_tx
+ .unbounded_send(EditPredictionRejection {
+ request_id: prediction_id.to_string(),
+ reason,
+ was_shown,
+ })
+ .log_err();
}
- fn is_refreshing(&self, cx: &App) -> bool {
- self.read(cx).is_refreshing(cx)
+ fn is_refreshing(&self, project: &Entity<Project>) -> bool {
+ self.projects
+ .get(&project.entity_id())
+ .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
}
- fn refresh(
- &self,
+ pub fn refresh_prediction_from_buffer(
+ &mut self,
+ project: Entity<Project>,
buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- debounce: bool,
- cx: &mut App,
+ position: language::Anchor,
+ cx: &mut Context<Self>,
) {
- self.update(cx, |this, cx| {
- this.refresh(buffer, cursor_position, debounce, cx)
+ self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
+ let Some(request_task) = this
+ .update(cx, |this, cx| {
+ this.request_prediction(
+ &project,
+ &buffer,
+ position,
+ PredictEditsRequestTrigger::Other,
+ cx,
+ )
+ })
+ .log_err()
+ else {
+ return Task::ready(anyhow::Ok(None));
+ };
+
+ cx.spawn(async move |_cx| {
+ request_task.await.map(|prediction_result| {
+ prediction_result.map(|prediction_result| {
+ (
+ prediction_result,
+ PredictionRequestedBy::Buffer(buffer.entity_id()),
+ )
+ })
+ })
+ })
})
}
- fn cycle(
- &self,
- buffer: Entity<Buffer>,
- cursor_position: language::Anchor,
- direction: Direction,
- cx: &mut App,
+ pub fn refresh_prediction_from_diagnostics(
+ &mut self,
+ project: Entity<Project>,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+
+ // Prefer predictions from buffer
+ if project_state.current_prediction.is_some() {
+ return;
+ };
+
+ self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
+ let Some(open_buffer_task) = project
+ .update(cx, |project, cx| {
+ project
+ .active_entry()
+ .and_then(|entry| project.path_for_entry(entry, cx))
+ .map(|path| project.open_buffer(path, cx))
+ })
+ .log_err()
+ .flatten()
+ else {
+ return Task::ready(anyhow::Ok(None));
+ };
+
+ cx.spawn(async move |cx| {
+ let active_buffer = open_buffer_task.await?;
+ let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+ let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+ active_buffer,
+ &snapshot,
+ Default::default(),
+ Default::default(),
+ &project,
+ cx,
+ )
+ .await?
+ else {
+ return anyhow::Ok(None);
+ };
+
+ let Some(prediction_result) = this
+ .update(cx, |this, cx| {
+ this.request_prediction(
+ &project,
+ &jump_buffer,
+ jump_position,
+ PredictEditsRequestTrigger::Diagnostics,
+ cx,
+ )
+ })?
+ .await?
+ else {
+ return anyhow::Ok(None);
+ };
+
+ this.update(cx, |this, cx| {
+ Some((
+ if this
+ .get_or_init_project(&project, cx)
+ .current_prediction
+ .is_none()
+ {
+ prediction_result
+ } else {
+ EditPredictionResult {
+ id: prediction_result.id,
+ prediction: Err(EditPredictionRejectReason::CurrentPreferred),
+ }
+ },
+ PredictionRequestedBy::DiagnosticsUpdate,
+ ))
+ })
+ })
+ });
+ }
+
+ #[cfg(not(test))]
+ pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
+ #[cfg(test)]
+ pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
+
+ fn queue_prediction_refresh(
+ &mut self,
+ project: Entity<Project>,
+ throttle_entity: EntityId,
+ cx: &mut Context<Self>,
+ do_refresh: impl FnOnce(
+ WeakEntity<Self>,
+ &mut AsyncApp,
+ )
+ -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
+ + 'static,
) {
- self.update(cx, |this, cx| {
- this.cycle(buffer, cursor_position, direction, cx)
+ let project_state = self.get_or_init_project(&project, cx);
+ let pending_prediction_id = project_state.next_pending_prediction_id;
+ project_state.next_pending_prediction_id += 1;
+ let last_request = project_state.last_prediction_refresh;
+
+ let task = cx.spawn(async move |this, cx| {
+ if let Some((last_entity, last_timestamp)) = last_request
+ && throttle_entity == last_entity
+ && let Some(timeout) =
+ (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
+ {
+ cx.background_executor().timer(timeout).await;
+ }
+
+ // If this task was cancelled before the throttle timeout expired,
+ // do not perform a request.
+ let mut is_cancelled = true;
+ this.update(cx, |this, cx| {
+ let project_state = this.get_or_init_project(&project, cx);
+ if !project_state
+ .cancelled_predictions
+ .remove(&pending_prediction_id)
+ {
+ project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
+ is_cancelled = false;
+ }
+ })
+ .ok();
+ if is_cancelled {
+ return None;
+ }
+
+ let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
+ let new_prediction_id = new_prediction_result
+ .as_ref()
+ .map(|(prediction, _)| prediction.id.clone());
+
+ // When a prediction completes, remove it from the pending list, and cancel
+ // any pending predictions that were enqueued before it.
+ this.update(cx, |this, cx| {
+ let project_state = this.get_or_init_project(&project, cx);
+
+ let is_cancelled = project_state
+ .cancelled_predictions
+ .remove(&pending_prediction_id);
+
+ let new_current_prediction = if !is_cancelled
+ && let Some((prediction_result, requested_by)) = new_prediction_result
+ {
+ match prediction_result.prediction {
+ Ok(prediction) => {
+ let new_prediction = CurrentEditPrediction {
+ requested_by,
+ prediction,
+ was_shown: false,
+ };
+
+ if let Some(current_prediction) =
+ project_state.current_prediction.as_ref()
+ {
+ if new_prediction.should_replace_prediction(¤t_prediction, cx)
+ {
+ this.reject_current_prediction(
+ EditPredictionRejectReason::Replaced,
+ &project,
+ );
+
+ Some(new_prediction)
+ } else {
+ this.reject_prediction(
+ new_prediction.prediction.id,
+ EditPredictionRejectReason::CurrentPreferred,
+ false,
+ );
+ None
+ }
+ } else {
+ Some(new_prediction)
+ }
+ }
+ Err(reject_reason) => {
+ this.reject_prediction(prediction_result.id, reject_reason, false);
+ None
+ }
+ }
+ } else {
+ None
+ };
+
+ let project_state = this.get_or_init_project(&project, cx);
+
+ if let Some(new_prediction) = new_current_prediction {
+ project_state.current_prediction = Some(new_prediction);
+ }
+
+ let mut pending_predictions = mem::take(&mut project_state.pending_predictions);
+ for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
+ if pending_prediction.id == pending_prediction_id {
+ pending_predictions.remove(ix);
+ for pending_prediction in pending_predictions.drain(0..ix) {
+ project_state.cancel_pending_prediction(pending_prediction, cx)
+ }
+ break;
+ }
+ }
+ this.get_or_init_project(&project, cx).pending_predictions = pending_predictions;
+ cx.notify();
+ })
+ .ok();
+
+ new_prediction_id
+ });
+
+ if project_state.pending_predictions.len() <= 1 {
+ project_state.pending_predictions.push(PendingPrediction {
+ id: pending_prediction_id,
+ task,
+ });
+ } else if project_state.pending_predictions.len() == 2 {
+ let pending_prediction = project_state.pending_predictions.pop().unwrap();
+ project_state.pending_predictions.push(PendingPrediction {
+ id: pending_prediction_id,
+ task,
+ });
+ project_state.cancel_pending_prediction(pending_prediction, cx);
+ }
+ }
+
+ pub fn request_prediction(
+ &mut self,
+ project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ position: language::Anchor,
+ trigger: PredictEditsRequestTrigger,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<Option<EditPredictionResult>>> {
+ self.request_prediction_internal(
+ project.clone(),
+ active_buffer.clone(),
+ position,
+ trigger,
+ cx.has_flag::<Zeta2FeatureFlag>(),
+ cx,
+ )
+ }
+
+ fn request_prediction_internal(
+ &mut self,
+ project: Entity<Project>,
+ active_buffer: Entity<Buffer>,
+ position: language::Anchor,
+ trigger: PredictEditsRequestTrigger,
+ allow_jump: bool,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<Option<EditPredictionResult>>> {
+ const DIAGNOSTIC_LINES_RANGE: u32 = 20;
+
+ self.get_or_init_project(&project, cx);
+ let project_state = self.projects.get(&project.entity_id()).unwrap();
+ let events = project_state.events(cx);
+ let has_events = !events.is_empty();
+
+ let snapshot = active_buffer.read(cx).snapshot();
+ let cursor_point = position.to_point(&snapshot);
+ let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
+ let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
+ let diagnostic_search_range =
+ Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
+
+ let related_files = if self.use_context {
+ self.context_for_project(&project, cx).to_vec()
+ } else {
+ Vec::new()
+ };
+
+ let task = match self.edit_prediction_model {
+ EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
+ self,
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ trigger,
+ cx,
+ ),
+ EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
+ self,
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ related_files,
+ trigger,
+ cx,
+ ),
+ EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ &project_state.recent_paths,
+ related_files,
+ diagnostic_search_range.clone(),
+ cx,
+ ),
+ EditPredictionModel::Mercury => self.mercury.request_prediction(
+ &project,
+ &active_buffer,
+ snapshot.clone(),
+ position,
+ events,
+ &project_state.recent_paths,
+ related_files,
+ diagnostic_search_range.clone(),
+ cx,
+ ),
+ };
+
+ cx.spawn(async move |this, cx| {
+ let prediction = task.await?;
+
+ if prediction.is_none() && allow_jump {
+ let cursor_point = position.to_point(&snapshot);
+ if has_events
+ && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
+ active_buffer.clone(),
+ &snapshot,
+ diagnostic_search_range,
+ cursor_point,
+ &project,
+ cx,
+ )
+ .await?
+ {
+ return this
+ .update(cx, |this, cx| {
+ this.request_prediction_internal(
+ project,
+ jump_buffer,
+ jump_position,
+ trigger,
+ false,
+ cx,
+ )
+ })?
+ .await;
+ }
+
+ return anyhow::Ok(None);
+ }
+
+ Ok(prediction)
})
}
- fn accept(&self, cx: &mut App) {
- self.update(cx, |this, cx| this.accept(cx))
+ async fn next_diagnostic_location(
+ active_buffer: Entity<Buffer>,
+ active_buffer_snapshot: &BufferSnapshot,
+ active_buffer_diagnostic_search_range: Range<Point>,
+ active_buffer_cursor_point: Point,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+ ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
+ // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
+ let mut jump_location = active_buffer_snapshot
+ .diagnostic_groups(None)
+ .into_iter()
+ .filter_map(|(_, group)| {
+ let range = &group.entries[group.primary_ix]
+ .range
+ .to_point(&active_buffer_snapshot);
+ if range.overlaps(&active_buffer_diagnostic_search_range) {
+ None
+ } else {
+ Some(range.start)
+ }
+ })
+ .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
+ .map(|position| {
+ (
+ active_buffer.clone(),
+ active_buffer_snapshot.anchor_before(position),
+ )
+ });
+
+ if jump_location.is_none() {
+ let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
+ let file = buffer.file()?;
+
+ Some(ProjectPath {
+ worktree_id: file.worktree_id(cx),
+ path: file.path().clone(),
+ })
+ })?;
+
+ let buffer_task = project.update(cx, |project, cx| {
+ let (path, _, _) = project
+ .diagnostic_summaries(false, cx)
+ .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
+ .max_by_key(|(path, _, _)| {
+ // find the buffer with errors that shares most parent directories
+ path.path
+ .components()
+ .zip(
+ active_buffer_path
+ .as_ref()
+ .map(|p| p.path.components())
+ .unwrap_or_default(),
+ )
+ .take_while(|(a, b)| a == b)
+ .count()
+ })?;
+
+ Some(project.open_buffer(path, cx))
+ })?;
+
+ if let Some(buffer_task) = buffer_task {
+ let closest_buffer = buffer_task.await?;
+
+ jump_location = closest_buffer
+ .read_with(cx, |buffer, _cx| {
+ buffer
+ .buffer_diagnostics(None)
+ .into_iter()
+ .min_by_key(|entry| entry.diagnostic.severity)
+ .map(|entry| entry.range.start)
+ })?
+ .map(|position| (closest_buffer, position));
+ }
+ }
+
+ anyhow::Ok(jump_location)
}
- fn discard(&self, cx: &mut App) {
- self.update(cx, |this, cx| this.discard(cx))
+ async fn send_raw_llm_request(
+ request: open_ai::Request,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: Version,
+ #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
+ #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
+ ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
+ let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
+ http_client::Url::parse(&predict_edits_url)?
+ } else {
+ client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/raw", &[])?
+ };
+
+ #[cfg(feature = "eval-support")]
+ let cache_key = if let Some(cache) = eval_cache {
+ use collections::FxHasher;
+ use std::hash::{Hash, Hasher};
+
+ let mut hasher = FxHasher::default();
+ url.hash(&mut hasher);
+ let request_str = serde_json::to_string_pretty(&request)?;
+ request_str.hash(&mut hasher);
+ let hash = hasher.finish();
+
+ let key = (eval_cache_kind, hash);
+ if let Some(response_str) = cache.read(key) {
+ return Ok((serde_json::from_str(&response_str)?, None));
+ }
+
+ Some((cache, request_str, key))
+ } else {
+ None
+ };
+
+ let (response, usage) = Self::send_api_request(
+ |builder| {
+ let req = builder
+ .uri(url.as_ref())
+ .body(serde_json::to_string(&request)?.into());
+ Ok(req?)
+ },
+ client,
+ llm_token,
+ app_version,
+ )
+ .await?;
+
+ #[cfg(feature = "eval-support")]
+ if let Some((cache, request, key)) = cache_key {
+ cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
+ }
+
+ Ok((response, usage))
}
- fn did_show(&self, cx: &mut App) {
- self.update(cx, |this, cx| this.did_show(cx))
+ fn handle_api_response<T>(
+ this: &WeakEntity<Self>,
+ response: Result<(T, Option<EditPredictionUsage>)>,
+ cx: &mut gpui::AsyncApp,
+ ) -> Result<T> {
+ match response {
+ Ok((data, usage)) => {
+ if let Some(usage) = usage {
+ this.update(cx, |this, cx| {
+ this.user_store.update(cx, |user_store, cx| {
+ user_store.update_edit_prediction_usage(usage, cx);
+ });
+ })
+ .ok();
+ }
+ Ok(data)
+ }
+ Err(err) => {
+ if err.is::<ZedUpdateRequiredError>() {
+ cx.update(|cx| {
+ this.update(cx, |this, _cx| {
+ this.update_required = true;
+ })
+ .ok();
+
+ let error_message: SharedString = err.to_string().into();
+ show_app_notification(
+ NotificationId::unique::<ZedUpdateRequiredError>(),
+ cx,
+ move |cx| {
+ cx.new(|cx| {
+ ErrorMessagePrompt::new(error_message.clone(), cx)
+ .with_link_button("Update Zed", "https://zed.dev/releases")
+ })
+ },
+ );
+ })
+ .ok();
+ }
+ Err(err)
+ }
+ }
}
- fn suggest(
- &self,
- buffer: &Entity<Buffer>,
- cursor_position: language::Anchor,
- cx: &mut App,
- ) -> Option<EditPrediction> {
- self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx))
- }
-}
-
-/// Returns edits updated based on user edits since the old snapshot. None is returned if any user
-/// edit is not a prefix of a predicted insertion.
-pub fn interpolate_edits(
- old_snapshot: &BufferSnapshot,
- new_snapshot: &BufferSnapshot,
- current_edits: &[(Range<Anchor>, Arc<str>)],
-) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
- let mut edits = Vec::new();
-
- let mut model_edits = current_edits.iter().peekable();
- for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
- while let Some((model_old_range, _)) = model_edits.peek() {
- let model_old_range = model_old_range.to_offset(old_snapshot);
- if model_old_range.end < user_edit.old.start {
- let (model_old_range, model_new_text) = model_edits.next().unwrap();
- edits.push((model_old_range.clone(), model_new_text.clone()));
+ async fn send_api_request<Res>(
+ build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
+ client: Arc<Client>,
+ llm_token: LlmApiToken,
+ app_version: Version,
+ ) -> Result<(Res, Option<EditPredictionUsage>)>
+ where
+ Res: DeserializeOwned,
+ {
+ let http_client = client.http_client();
+ let mut token = llm_token.acquire(&client).await?;
+ let mut did_retry = false;
+
+ loop {
+ let request_builder = http_client::Request::builder().method(Method::POST);
+
+ let request = build(
+ request_builder
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", token))
+ .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
+ )?;
+
+ let mut response = http_client.send(request).await?;
+
+ if let Some(minimum_required_version) = response
+ .headers()
+ .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
+ .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
+ {
+ anyhow::ensure!(
+ app_version >= minimum_required_version,
+ ZedUpdateRequiredError {
+ minimum_version: minimum_required_version
+ }
+ );
+ }
+
+ if response.status().is_success() {
+ let usage = EditPredictionUsage::from_headers(response.headers()).ok();
+
+ let mut body = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+ return Ok((serde_json::from_slice(&body)?, usage));
+ } else if !did_retry
+ && response
+ .headers()
+ .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+ .is_some()
+ {
+ did_retry = true;
+ token = llm_token.refresh(&client).await?;
} else {
- break;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ anyhow::bail!(
+ "Request failed with status: {:?}\nBody: {}",
+ response.status(),
+ body
+ );
}
}
+ }
- if let Some((model_old_range, model_new_text)) = model_edits.peek() {
- let model_old_offset_range = model_old_range.to_offset(old_snapshot);
- if user_edit.old == model_old_offset_range {
- let user_new_text = new_snapshot
- .text_for_range(user_edit.new.clone())
- .collect::<String>();
+ pub fn refresh_context(
+ &mut self,
+ project: &Entity<Project>,
+ buffer: &Entity<language::Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) {
+ if self.use_context {
+ self.get_or_init_project(project, cx)
+ .context
+ .update(cx, |store, cx| {
+ store.refresh(buffer.clone(), cursor_position, cx);
+ });
+ }
+ }
- if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
- if !model_suffix.is_empty() {
- let anchor = old_snapshot.anchor_after(user_edit.old.end);
- edits.push((anchor..anchor, model_suffix.into()));
- }
+ fn is_file_open_source(
+ &self,
+ project: &Entity<Project>,
+ file: &Arc<dyn File>,
+ cx: &App,
+ ) -> bool {
+ if !file.is_local() || file.is_private() {
+ return false;
+ }
+ let Some(project_state) = self.projects.get(&project.entity_id()) else {
+ return false;
+ };
+ project_state
+ .license_detection_watchers
+ .get(&file.worktree_id(cx))
+ .as_ref()
+ .is_some_and(|watcher| watcher.is_project_open_source())
+ }
+
+ fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
+ self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
+ }
- model_edits.next();
- continue;
+ fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
+ if !self.data_collection_choice.is_enabled() {
+ return false;
+ }
+ events.iter().all(|event| {
+ matches!(
+ event.as_ref(),
+ Event::BufferChange {
+ in_open_source_repo: true,
+ ..
}
+ )
+ })
+ }
+
+ fn load_data_collection_choice() -> DataCollectionChoice {
+ let choice = KEY_VALUE_STORE
+ .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
+ .log_err()
+ .flatten();
+
+ match choice.as_deref() {
+ Some("true") => DataCollectionChoice::Enabled,
+ Some("false") => DataCollectionChoice::Disabled,
+ Some(_) => {
+ log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
+ DataCollectionChoice::NotAnswered
}
+ None => DataCollectionChoice::NotAnswered,
+ }
+ }
+
+ fn toggle_data_collection_choice(&mut self, cx: &mut Context<Self>) {
+ self.data_collection_choice = self.data_collection_choice.toggle();
+ let new_choice = self.data_collection_choice;
+ db::write_and_log(cx, move || {
+ KEY_VALUE_STORE.write_kvp(
+ ZED_PREDICT_DATA_COLLECTION_CHOICE.into(),
+ new_choice.is_enabled().to_string(),
+ )
+ });
+ }
+
+ pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
+ self.shown_predictions.iter()
+ }
+
+ pub fn shown_completions_len(&self) -> usize {
+ self.shown_predictions.len()
+ }
+
+ pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
+ self.rated_predictions.contains(id)
+ }
+
+ pub fn rate_prediction(
+ &mut self,
+ prediction: &EditPrediction,
+ rating: EditPredictionRating,
+ feedback: String,
+ cx: &mut Context<Self>,
+ ) {
+ self.rated_predictions.insert(prediction.id.clone());
+ telemetry::event!(
+ "Edit Prediction Rated",
+ rating,
+ inputs = prediction.inputs,
+ output = prediction.edit_preview.as_unified_diff(&prediction.edits),
+ feedback
+ );
+ self.client.telemetry().flush_events().detach();
+ cx.notify();
+ }
+
+ fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
+ self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
+ && all_language_settings(None, cx).edit_predictions.use_context;
+ }
+}
+
+#[derive(Error, Debug)]
+#[error(
+ "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
+)]
+pub struct ZedUpdateRequiredError {
+ minimum_version: Version,
+}
+
+#[cfg(feature = "eval-support")]
+pub type EvalCacheKey = (EvalCacheEntryKind, u64);
+
+#[cfg(feature = "eval-support")]
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum EvalCacheEntryKind {
+ Context,
+ Search,
+ Prediction,
+}
+
+#[cfg(feature = "eval-support")]
+impl std::fmt::Display for EvalCacheEntryKind {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ EvalCacheEntryKind::Search => write!(f, "search"),
+ EvalCacheEntryKind::Context => write!(f, "context"),
+ EvalCacheEntryKind::Prediction => write!(f, "prediction"),
}
+ }
+}
- return None;
+#[cfg(feature = "eval-support")]
+pub trait EvalCache: Send + Sync {
+ fn read(&self, key: EvalCacheKey) -> Option<String>;
+ fn write(&self, key: EvalCacheKey, input: &str, value: &str);
+}
+
+#[derive(Debug, Clone, Copy)]
+pub enum DataCollectionChoice {
+ NotAnswered,
+ Enabled,
+ Disabled,
+}
+
+impl DataCollectionChoice {
+ pub fn is_enabled(self) -> bool {
+ match self {
+ Self::Enabled => true,
+ Self::NotAnswered | Self::Disabled => false,
+ }
+ }
+
+ pub fn is_answered(self) -> bool {
+ match self {
+ Self::Enabled | Self::Disabled => true,
+ Self::NotAnswered => false,
+ }
}
- edits.extend(model_edits.cloned());
+ #[must_use]
+ pub fn toggle(&self) -> DataCollectionChoice {
+ match self {
+ Self::Enabled => Self::Disabled,
+ Self::Disabled => Self::Enabled,
+ Self::NotAnswered => Self::Enabled,
+ }
+ }
+}
+
+impl From<bool> for DataCollectionChoice {
+ fn from(value: bool) -> Self {
+ match value {
+ true => DataCollectionChoice::Enabled,
+ false => DataCollectionChoice::Disabled,
+ }
+ }
+}
+
+struct ZedPredictUpsell;
+
+impl Dismissable for ZedPredictUpsell {
+ const KEY: &'static str = "dismissed-edit-predict-upsell";
+
+ fn dismissed() -> bool {
+ // To make this backwards compatible with older versions of Zed, we
+ // check if the user has seen the previous Edit Prediction Onboarding
+ // before, by checking the data collection choice which was written to
+ // the database once the user clicked on "Accept and Enable"
+ if KEY_VALUE_STORE
+ .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
+ .log_err()
+ .is_some_and(|s| s.is_some())
+ {
+ return true;
+ }
+
+ KEY_VALUE_STORE
+ .read_kvp(Self::KEY)
+ .log_err()
+ .is_some_and(|s| s.is_some())
+ }
+}
+
+pub fn should_show_upsell_modal() -> bool {
+ !ZedPredictUpsell::dismissed()
+}
+
+pub fn init(cx: &mut App) {
+ cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
+ workspace.register_action(
+ move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
+ ZedPredictModal::toggle(
+ workspace,
+ workspace.user_store().clone(),
+ workspace.client().clone(),
+ window,
+ cx,
+ )
+ },
+ );
- if edits.is_empty() { None } else { Some(edits) }
+ workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
+ update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
+ settings
+ .project
+ .all_languages
+ .features
+ .get_or_insert_default()
+ .edit_prediction_provider = Some(EditPredictionProvider::None)
+ });
+ });
+ })
+ .detach();
}
@@ -0,0 +1,1806 @@
+use super::*;
+use crate::zeta1::MAX_EVENT_TOKENS;
+use client::{UserStore, test::FakeServer};
+use clock::{FakeSystemClock, ReplicaId};
+use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
+use cloud_llm_client::{
+ EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
+ RejectEditPredictionsBody,
+};
+use edit_prediction_context::Line;
+use futures::{
+ AsyncReadExt, StreamExt,
+ channel::{mpsc, oneshot},
+};
+use gpui::{
+ Entity, TestAppContext,
+ http_client::{FakeHttpClient, Response},
+};
+use indoc::indoc;
+use language::{Point, ToOffset as _};
+use lsp::LanguageServerId;
+use open_ai::Usage;
+use parking_lot::Mutex;
+use pretty_assertions::{assert_eq, assert_matches};
+use project::{FakeFs, Project};
+use serde_json::json;
+use settings::SettingsStore;
+use std::{path::Path, sync::Arc, time::Duration};
+use util::{path, rel_path::rel_path};
+use uuid::Uuid;
+
+use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
+
+#[gpui::test]
+async fn test_current_state(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "1.txt": "Hello!\nHow\nBye\n",
+ "2.txt": "Hola!\nComo\nAdios\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_project(&project, cx);
+ });
+
+ let buffer1 = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
+ project.set_active_path(Some(path.clone()), cx);
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot1.anchor_before(language::Point::new(1, 3));
+
+ // Prediction for current file
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
+ });
+ let (_request, respond_tx) = requests.predict.next().await.unwrap();
+
+ respond_tx
+ .send(model_response(indoc! {r"
+ --- a/root/1.txt
+ +++ b/root/1.txt
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "}))
+ .unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ let prediction = ep_store
+ .current_prediction_for_buffer(&buffer1, &project, cx)
+ .unwrap();
+ assert_matches!(prediction, BufferEditPrediction::Local { .. });
+ });
+
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
+ });
+
+ // Prediction for diagnostic in another file
+
+ let diagnostic = lsp::Diagnostic {
+ range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+ severity: Some(lsp::DiagnosticSeverity::ERROR),
+ message: "Sentence is incomplete".to_string(),
+ ..Default::default()
+ };
+
+ project.update(cx, |project, cx| {
+ project.lsp_store().update(cx, |lsp_store, cx| {
+ lsp_store
+ .update_diagnostics(
+ LanguageServerId(0),
+ lsp::PublishDiagnosticsParams {
+ uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(),
+ diagnostics: vec![diagnostic],
+ version: None,
+ },
+ None,
+ language::DiagnosticSourceKind::Pushed,
+ &[],
+ cx,
+ )
+ .unwrap();
+ });
+ });
+
+ let (_request, respond_tx) = requests.predict.next().await.unwrap();
+ respond_tx
+ .send(model_response(indoc! {r#"
+ --- a/root/2.txt
+ +++ b/root/2.txt
+ Hola!
+ -Como
+ +Como estas?
+ Adios
+ "#}))
+ .unwrap();
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ let prediction = ep_store
+ .current_prediction_for_buffer(&buffer1, &project, cx)
+ .unwrap();
+ assert_matches!(
+ prediction,
+ BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
+ );
+ });
+
+ let buffer2 = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ let prediction = ep_store
+ .current_prediction_for_buffer(&buffer2, &project, cx)
+ .unwrap();
+ assert_matches!(prediction, BufferEditPrediction::Local { .. });
+ });
+}
+
+#[gpui::test]
+async fn test_simple_request(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ let prediction_task = ep_store.update(cx, |ep_store, cx| {
+ ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+
+ // TODO Put back when we have a structured request again
+ // assert_eq!(
+ // request.excerpt_path.as_ref(),
+ // Path::new(path!("root/foo.md"))
+ // );
+ // assert_eq!(
+ // request.cursor_point,
+ // Point {
+ // line: Line(1),
+ // column: 3
+ // }
+ // );
+
+ respond_tx
+ .send(model_response(indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "}))
+ .unwrap();
+
+ let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
+
+ assert_eq!(prediction.edits.len(), 1);
+ assert_eq!(
+ prediction.edits[0].0.to_point(&snapshot).start,
+ language::Point::new(1, 3)
+ );
+ assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
+}
+
+#[gpui::test]
+async fn test_request_events(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\n\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&buffer, &project, cx);
+ });
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(vec![(7..7, "How")], None, cx);
+ });
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ let prediction_task = ep_store.update(cx, |ep_store, cx| {
+ ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
+ });
+
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+
+ let prompt = prompt_from_request(&request);
+ assert!(
+ prompt.contains(indoc! {"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ -1,3 +1,3 @@
+ Hello!
+ -
+ +How
+ Bye
+ "}),
+ "{prompt}"
+ );
+
+ respond_tx
+ .send(model_response(indoc! {r#"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "#}))
+ .unwrap();
+
+ let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
+
+ assert_eq!(prediction.edits.len(), 1);
+ assert_eq!(
+ prediction.edits[0].0.to_point(&snapshot).start,
+ language::Point::new(1, 3)
+ );
+ assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
+}
+
+#[gpui::test]
+async fn test_empty_prediction(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ const NO_OP_DIFF: &str = indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How
+ Bye
+ "};
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let response = model_response(NO_OP_DIFF);
+ let id = response.id.clone();
+ respond_tx.send(response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ assert!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .is_none()
+ );
+ });
+
+ // prediction is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: id,
+ reason: EditPredictionRejectReason::Empty,
+ was_shown: false
+ }]
+ );
+}
+
+#[gpui::test]
+async fn test_interpolated_empty(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.set_text("Hello!\nHow are you?\nBye", cx);
+ });
+
+ let response = model_response(SIMPLE_DIFF);
+ let id = response.id.clone();
+ respond_tx.send(response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ assert!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .is_none()
+ );
+ });
+
+ // prediction is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: id,
+ reason: EditPredictionRejectReason::InterpolatedEmpty,
+ was_shown: false
+ }]
+ );
+}
+
+const SIMPLE_DIFF: &str = indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+"};
+
+#[gpui::test]
+async fn test_replace_current(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_tx.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ // a second request is triggered
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let second_response = model_response(SIMPLE_DIFF);
+ let second_id = second_response.id.clone();
+ respond_tx.send(second_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // second replaces first
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ second_id
+ );
+ });
+
+ // first is reported as replaced
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: first_id,
+ reason: EditPredictionRejectReason::Replaced,
+ was_shown: false
+ }]
+ );
+}
+
+#[gpui::test]
+async fn test_current_preferred(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_tx.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ // a second request is triggered
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_tx) = requests.predict.next().await.unwrap();
+ // worse than current prediction
+ let second_response = model_response(indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are
+ Bye
+ "});
+ let second_id = second_response.id.clone();
+ respond_tx.send(second_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // first is preferred over second
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ // second is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: second_id,
+ reason: EditPredictionRejectReason::CurrentPreferred,
+ was_shown: false
+ }]
+ );
+}
+
+#[gpui::test]
+async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ // start two refresh tasks
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_first) = requests.predict.next().await.unwrap();
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_second) = requests.predict.next().await.unwrap();
+
+ // wait for throttle
+ cx.run_until_parked();
+
+ // second responds first
+ let second_response = model_response(SIMPLE_DIFF);
+ let second_id = second_response.id.clone();
+ respond_second.send(second_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // current prediction is second
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ second_id
+ );
+ });
+
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_first.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // current prediction is still second, since first was cancelled
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ second_id
+ );
+ });
+
+ // first is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ cx.run_until_parked();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[EditPredictionRejection {
+ request_id: first_id,
+ reason: EditPredictionRejectReason::Canceled,
+ was_shown: false
+ }]
+ );
+}
+
+#[gpui::test]
+async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.md": "Hello!\nHow\nBye\n"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(1, 3));
+
+ // start two refresh tasks
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_first) = requests.predict.next().await.unwrap();
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_, respond_second) = requests.predict.next().await.unwrap();
+
+ // wait for throttle, so requests are sent
+ cx.run_until_parked();
+
+ ep_store.update(cx, |ep_store, cx| {
+ // start a third request
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+
+ // 2 are pending, so 2nd is cancelled
+ assert_eq!(
+ ep_store
+ .get_or_init_project(&project, cx)
+ .cancelled_predictions
+ .iter()
+ .copied()
+ .collect::<Vec<_>>(),
+ [1]
+ );
+ });
+
+ // wait for throttle
+ cx.run_until_parked();
+
+ let (_, respond_third) = requests.predict.next().await.unwrap();
+
+ let first_response = model_response(SIMPLE_DIFF);
+ let first_id = first_response.id.clone();
+ respond_first.send(first_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // current prediction is first
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ let cancelled_response = model_response(SIMPLE_DIFF);
+ let cancelled_id = cancelled_response.id.clone();
+ respond_second.send(cancelled_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // current prediction is still first, since second was cancelled
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ first_id
+ );
+ });
+
+ let third_response = model_response(SIMPLE_DIFF);
+ let third_response_id = third_response.id.clone();
+ respond_third.send(third_response).unwrap();
+
+ cx.run_until_parked();
+
+ ep_store.read_with(cx, |ep_store, cx| {
+ // third completes and replaces first
+ assert_eq!(
+ ep_store
+ .current_prediction_for_buffer(&buffer, &project, cx)
+ .unwrap()
+ .id
+ .0,
+ third_response_id
+ );
+ });
+
+ // second is reported as rejected
+ let (reject_request, _) = requests.reject.next().await.unwrap();
+
+ cx.run_until_parked();
+
+ assert_eq!(
+ &reject_request.rejections,
+ &[
+ EditPredictionRejection {
+ request_id: cancelled_id,
+ reason: EditPredictionRejectReason::Canceled,
+ was_shown: false
+ },
+ EditPredictionRejection {
+ request_id: first_id,
+ reason: EditPredictionRejectReason::Replaced,
+ was_shown: false
+ }
+ ]
+ );
+}
+
+#[gpui::test]
+async fn test_rejections_flushing(cx: &mut TestAppContext) {
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.reject_prediction(
+ EditPredictionId("test-1".into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ ep_store.reject_prediction(
+ EditPredictionId("test-2".into()),
+ EditPredictionRejectReason::Canceled,
+ true,
+ );
+ });
+
+ cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+ cx.run_until_parked();
+
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ // batched
+ assert_eq!(reject_request.rejections.len(), 2);
+ assert_eq!(
+ reject_request.rejections[0],
+ EditPredictionRejection {
+ request_id: "test-1".to_string(),
+ reason: EditPredictionRejectReason::Discarded,
+ was_shown: false
+ }
+ );
+ assert_eq!(
+ reject_request.rejections[1],
+ EditPredictionRejection {
+ request_id: "test-2".to_string(),
+ reason: EditPredictionRejectReason::Canceled,
+ was_shown: true
+ }
+ );
+
+ // Reaching batch size limit sends without debounce
+ ep_store.update(cx, |ep_store, _cx| {
+ for i in 0..70 {
+ ep_store.reject_prediction(
+ EditPredictionId(format!("batch-{}", i).into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ }
+ });
+
+ // First MAX/2 items are sent immediately
+ cx.run_until_parked();
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ assert_eq!(reject_request.rejections.len(), 50);
+ assert_eq!(reject_request.rejections[0].request_id, "batch-0");
+ assert_eq!(reject_request.rejections[49].request_id, "batch-49");
+
+ // Remaining items are debounced with the next batch
+ cx.executor().advance_clock(Duration::from_secs(15));
+ cx.run_until_parked();
+
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ assert_eq!(reject_request.rejections.len(), 20);
+ assert_eq!(reject_request.rejections[0].request_id, "batch-50");
+ assert_eq!(reject_request.rejections[19].request_id, "batch-69");
+
+ // Request failure
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.reject_prediction(
+ EditPredictionId("retry-1".into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ });
+
+ cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+ cx.run_until_parked();
+
+ let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
+ assert_eq!(reject_request.rejections.len(), 1);
+ assert_eq!(reject_request.rejections[0].request_id, "retry-1");
+ // Simulate failure
+ drop(_respond_tx);
+
+ // Add another rejection
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.reject_prediction(
+ EditPredictionId("retry-2".into()),
+ EditPredictionRejectReason::Discarded,
+ false,
+ );
+ });
+
+ cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
+ cx.run_until_parked();
+
+ // Retry should include both the failed item and the new one
+ let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
+ respond_tx.send(()).unwrap();
+
+ assert_eq!(reject_request.rejections.len(), 2);
+ assert_eq!(reject_request.rejections[0].request_id, "retry-1");
+ assert_eq!(reject_request.rejections[1].request_id, "retry-2");
+}
+
+// Skipped until we start including diagnostics in prompt
+// #[gpui::test]
+// async fn test_request_diagnostics(cx: &mut TestAppContext) {
+// let (ep_store, mut req_rx) = init_test_with_fake_client(cx);
+// let fs = FakeFs::new(cx.executor());
+// fs.insert_tree(
+// "/root",
+// json!({
+// "foo.md": "Hello!\nBye"
+// }),
+// )
+// .await;
+// let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+// let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
+// let diagnostic = lsp::Diagnostic {
+// range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+// severity: Some(lsp::DiagnosticSeverity::ERROR),
+// message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
+// ..Default::default()
+// };
+
+// project.update(cx, |project, cx| {
+// project.lsp_store().update(cx, |lsp_store, cx| {
+// // Create some diagnostics
+// lsp_store
+// .update_diagnostics(
+// LanguageServerId(0),
+// lsp::PublishDiagnosticsParams {
+// uri: path_to_buffer_uri.clone(),
+// diagnostics: vec![diagnostic],
+// version: None,
+// },
+// None,
+// language::DiagnosticSourceKind::Pushed,
+// &[],
+// cx,
+// )
+// .unwrap();
+// });
+// });
+
+// let buffer = project
+// .update(cx, |project, cx| {
+// let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+// project.open_buffer(path, cx)
+// })
+// .await
+// .unwrap();
+
+// let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+// let position = snapshot.anchor_before(language::Point::new(0, 0));
+
+// let _prediction_task = ep_store.update(cx, |ep_store, cx| {
+// ep_store.request_prediction(&project, &buffer, position, cx)
+// });
+
+// let (request, _respond_tx) = req_rx.next().await.unwrap();
+
+// assert_eq!(request.diagnostic_groups.len(), 1);
+// let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
+// .unwrap();
+// // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
+// assert_eq!(
+// value,
+// json!({
+// "entries": [{
+// "range": {
+// "start": 8,
+// "end": 10
+// },
+// "diagnostic": {
+// "source": null,
+// "code": null,
+// "code_description": null,
+// "severity": 1,
+// "message": "\"Hello\" deprecated. Use \"Hi\" instead",
+// "markdown": null,
+// "group_id": 0,
+// "is_primary": true,
+// "is_disk_based": false,
+// "is_unnecessary": false,
+// "source_kind": "Pushed",
+// "data": null,
+// "underline": true
+// }
+// }],
+// "primary_ix": 0
+// })
+// );
+// }
+
+fn model_response(text: &str) -> open_ai::Response {
+ open_ai::Response {
+ id: Uuid::new_v4().to_string(),
+ object: "response".into(),
+ created: 0,
+ model: "model".into(),
+ choices: vec![open_ai::Choice {
+ index: 0,
+ message: open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Plain(text.to_string())),
+ tool_calls: vec![],
+ },
+ finish_reason: None,
+ }],
+ usage: Usage {
+ prompt_tokens: 0,
+ completion_tokens: 0,
+ total_tokens: 0,
+ },
+ }
+}
+
+fn prompt_from_request(request: &open_ai::Request) -> &str {
+ assert_eq!(request.messages.len(), 1);
+ let open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(content),
+ ..
+ } = &request.messages[0]
+ else {
+ panic!(
+ "Request does not have single user message of type Plain. {:#?}",
+ request
+ );
+ };
+ content
+}
+
+struct RequestChannels {
+ predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
+ reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
+}
+
+fn init_test_with_fake_client(
+ cx: &mut TestAppContext,
+) -> (Entity<EditPredictionStore>, RequestChannels) {
+ cx.update(move |cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ zlog::init_test();
+
+ let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
+ let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
+
+ let http_client = FakeHttpClient::create({
+ move |req| {
+ let uri = req.uri().path().to_string();
+ let mut body = req.into_body();
+ let predict_req_tx = predict_req_tx.clone();
+ let reject_req_tx = reject_req_tx.clone();
+ async move {
+ let resp = match uri.as_str() {
+ "/client/llm_tokens" => serde_json::to_string(&json!({
+ "token": "test"
+ }))
+ .unwrap(),
+ "/predict_edits/raw" => {
+ let mut buf = Vec::new();
+ body.read_to_end(&mut buf).await.ok();
+ let req = serde_json::from_slice(&buf).unwrap();
+
+ let (res_tx, res_rx) = oneshot::channel();
+ predict_req_tx.unbounded_send((req, res_tx)).unwrap();
+ serde_json::to_string(&res_rx.await?).unwrap()
+ }
+ "/predict_edits/reject" => {
+ let mut buf = Vec::new();
+ body.read_to_end(&mut buf).await.ok();
+ let req = serde_json::from_slice(&buf).unwrap();
+
+ let (res_tx, res_rx) = oneshot::channel();
+ reject_req_tx.unbounded_send((req, res_tx)).unwrap();
+ serde_json::to_string(&res_rx.await?).unwrap()
+ }
+ _ => {
+ panic!("Unexpected path: {}", uri)
+ }
+ };
+
+ Ok(Response::builder().body(resp.into()).unwrap())
+ }
+ }
+ });
+
+ let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
+ client.cloud_client().set_credentials(1, "test".into());
+
+ language_model::init(client.clone(), cx);
+
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ let ep_store = EditPredictionStore::global(&client, &user_store, cx);
+
+ (
+ ep_store,
+ RequestChannels {
+ predict: predict_req_rx,
+ reject: reject_req_rx,
+ },
+ )
+ })
+}
+
+const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
+
+#[gpui::test]
+async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
+ let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
+ let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
+ to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
+ });
+
+ let edit_preview = cx
+ .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
+ .await;
+
+ let completion = EditPrediction {
+ edits,
+ edit_preview,
+ buffer: buffer.clone(),
+ snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
+ id: EditPredictionId("the-id".into()),
+ inputs: EditPredictionInputs {
+ events: Default::default(),
+ included_files: Default::default(),
+ cursor_point: cloud_llm_client::predict_edits_v3::Point {
+ line: Line(0),
+ column: 0,
+ },
+ cursor_path: Path::new("").into(),
+ },
+ buffer_snapshotted_at: Instant::now(),
+ response_received_at: Instant::now(),
+ };
+
+ cx.update(|cx| {
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..5, "REM".into()), (9..11, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..2, "REM".into()), (6..8, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.undo(cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(2..5, "REM".into()), (9..11, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(3..3, "EM".into()), (7..9, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".into()), (8..10, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(9..11, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".into()), (8..10, "".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
+ assert_eq!(
+ from_completion_edits(
+ &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &buffer,
+ cx
+ ),
+ vec![(4..4, "M".into())]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
+ assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
+ })
+}
+
+#[gpui::test]
+async fn test_clean_up_diff(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ assert_eq!(
+ apply_edit_prediction(
+ indoc! {"
+ fn main() {
+ let word_1 = \"lorem\";
+ let range = word.len()..word.len();
+ }
+ "},
+ indoc! {"
+ <|editable_region_start|>
+ fn main() {
+ let word_1 = \"lorem\";
+ let range = word_1.len()..word_1.len();
+ }
+
+ <|editable_region_end|>
+ "},
+ cx,
+ )
+ .await,
+ indoc! {"
+ fn main() {
+ let word_1 = \"lorem\";
+ let range = word_1.len()..word_1.len();
+ }
+ "},
+ );
+
+ assert_eq!(
+ apply_edit_prediction(
+ indoc! {"
+ fn main() {
+ let story = \"the quick\"
+ }
+ "},
+ indoc! {"
+ <|editable_region_start|>
+ fn main() {
+ let story = \"the quick brown fox jumps over the lazy dog\";
+ }
+
+ <|editable_region_end|>
+ "},
+ cx,
+ )
+ .await,
+ indoc! {"
+ fn main() {
+ let story = \"the quick brown fox jumps over the lazy dog\";
+ }
+ "},
+ );
+}
+
+#[gpui::test]
+async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let buffer_content = "lorem\n";
+ let completion_response = indoc! {"
+ ```animals.js
+ <|start_of_file|>
+ <|editable_region_start|>
+ lorem
+ ipsum
+ <|editable_region_end|>
+ ```"};
+
+ assert_eq!(
+ apply_edit_prediction(buffer_content, completion_response, cx).await,
+ "lorem\nipsum"
+ );
+}
+
+#[gpui::test]
+async fn test_can_collect_data(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/project/src/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Disabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [], cx).await;
+
+ let buffer = cx.new(|_cx| {
+ Buffer::remote(
+ language::BufferId::new(1).unwrap(),
+ ReplicaId::new(1),
+ language::Capability::ReadWrite,
+ "fn main() {\n println!(\"Hello\");\n}",
+ )
+ });
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/project"),
+ json!({
+ "LICENSE": BSD_0_TXT,
+ ".env": "SECRET_KEY=secret"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer("/project/.env", cx)
+ })
+ .await
+ .unwrap();
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [], cx).await;
+ let buffer = cx.new(|cx| Buffer::local("", cx));
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer("/project/main.rs", cx)
+ })
+ .await
+ .unwrap();
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/open_source_worktree"),
+ json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
+ )
+ .await;
+ fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
+ .await;
+
+ let project = Project::test(
+ fs.clone(),
+ [
+ path!("/open_source_worktree").as_ref(),
+ path!("/closed_source_worktree").as_ref(),
+ ],
+ cx,
+ )
+ .await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+
+ let closed_source_file = project
+ .update(cx, |project, cx| {
+ let worktree2 = project
+ .worktree_for_root_name("closed_source_worktree", cx)
+ .unwrap();
+ worktree2.update(cx, |worktree2, cx| {
+ worktree2.load_file(rel_path("main.rs"), cx)
+ })
+ })
+ .await
+ .unwrap()
+ .file;
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.file_updated(closed_source_file, cx);
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+}
+
+#[gpui::test]
+async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/worktree1"),
+ json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
+ )
+ .await;
+ fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
+ .await;
+
+ let project = Project::test(
+ fs.clone(),
+ [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
+ cx,
+ )
+ .await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/worktree1/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+ let private_buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/worktree2/file.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let (ep_store, captured_request, _) = make_test_ep_store(&project, cx).await;
+ ep_store.update(cx, |ep_store, _cx| {
+ ep_store.data_collection_choice = DataCollectionChoice::Enabled
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+
+ // this has a side effect of registering the buffer to watch for edits
+ run_edit_prediction(&private_buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+
+ private_buffer.update(cx, |private_buffer, cx| {
+ private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ false
+ );
+
+ // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
+ // included
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(
+ [(
+ 0..0,
+ " ".repeat(MAX_EVENT_TOKENS * cursor_excerpt::BYTES_PER_TOKEN_GUESS),
+ )],
+ None,
+ cx,
+ );
+ });
+
+ run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ assert_eq!(
+ captured_request.lock().clone().unwrap().can_collect_data,
+ true
+ );
+}
+
+fn init_test(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ });
+}
+
+async fn apply_edit_prediction(
+ buffer_content: &str,
+ completion_response: &str,
+ cx: &mut TestAppContext,
+) -> String {
+ let fs = project::FakeFs::new(cx.executor());
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+ let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
+ let (ep_store, _, response) = make_test_ep_store(&project, cx).await;
+ *response.lock() = completion_response.to_string();
+ let edit_prediction = run_edit_prediction(&buffer, &project, &ep_store, cx).await;
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
+ });
+ buffer.read_with(cx, |buffer, _| buffer.text())
+}
+
+async fn run_edit_prediction(
+ buffer: &Entity<Buffer>,
+ project: &Entity<Project>,
+ ep_store: &Entity<EditPredictionStore>,
+ cx: &mut TestAppContext,
+) -> EditPrediction {
+ let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(buffer, &project, cx)
+ });
+ cx.background_executor.run_until_parked();
+ let prediction_task = ep_store.update(cx, |ep_store, cx| {
+ ep_store.request_prediction(&project, buffer, cursor, Default::default(), cx)
+ });
+ prediction_task.await.unwrap().unwrap().prediction.unwrap()
+}
+
+async fn make_test_ep_store(
+ project: &Entity<Project>,
+ cx: &mut TestAppContext,
+) -> (
+ Entity<EditPredictionStore>,
+ Arc<Mutex<Option<PredictEditsBody>>>,
+ Arc<Mutex<String>>,
+) {
+ let default_response = indoc! {"
+ ```main.rs
+ <|start_of_file|>
+ <|editable_region_start|>
+ hello world
+ <|editable_region_end|>
+ ```"
+ };
+ let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
+ let completion_response: Arc<Mutex<String>> =
+ Arc::new(Mutex::new(default_response.to_string()));
+ let http_client = FakeHttpClient::create({
+ let captured_request = captured_request.clone();
+ let completion_response = completion_response.clone();
+ let mut next_request_id = 0;
+ move |req| {
+ let captured_request = captured_request.clone();
+ let completion_response = completion_response.clone();
+ async move {
+ match (req.method(), req.uri().path()) {
+ (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&CreateLlmTokenResponse {
+ token: LlmToken("the-llm-token".to_string()),
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap()),
+ (&Method::POST, "/predict_edits/v2") => {
+ let mut request_body = String::new();
+ req.into_body().read_to_string(&mut request_body).await?;
+ *captured_request.lock() =
+ Some(serde_json::from_str(&request_body).unwrap());
+ next_request_id += 1;
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&PredictEditsResponse {
+ request_id: format!("request-{next_request_id}"),
+ output_excerpt: completion_response.lock().clone(),
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap())
+ }
+ _ => Ok(http_client::Response::builder()
+ .status(404)
+ .body("Not Found".into())
+ .unwrap()),
+ }
+ }
+ }
+ });
+
+ let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+ cx.update(|cx| {
+ RefreshLlmTokenListener::register(client.clone(), cx);
+ });
+ let _server = FakeServer::for_client(42, &client, cx).await;
+
+ let ep_store = cx.new(|cx| {
+ let mut ep_store = EditPredictionStore::new(client, project.read(cx).user_store(), cx);
+ ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
+
+ let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
+ for worktree in worktrees {
+ let worktree_id = worktree.read(cx).id();
+ ep_store
+ .get_or_init_project(project, cx)
+ .license_detection_watchers
+ .entry(worktree_id)
+ .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
+ }
+
+ ep_store
+ });
+
+ (ep_store, captured_request, completion_response)
+}
+
+fn to_completion_edits(
+ iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
+ buffer: &Entity<Buffer>,
+ cx: &App,
+) -> Vec<(Range<Anchor>, Arc<str>)> {
+ let buffer = buffer.read(cx);
+ iterator
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
+ text,
+ )
+ })
+ .collect()
+}
+
+fn from_completion_edits(
+ editor_edits: &[(Range<Anchor>, Arc<str>)],
+ buffer: &Entity<Buffer>,
+ cx: &App,
+) -> Vec<(Range<usize>, Arc<str>)> {
+ let buffer = buffer.read(cx);
+ editor_edits
+ .iter()
+ .map(|(range, text)| {
+ (
+ range.start.to_offset(buffer)..range.end.to_offset(buffer),
+ text.clone(),
+ )
+ })
+ .collect()
+}
+
+#[ctor::ctor]
+fn init_logger() {
+ zlog::init_test();
+}
@@ -0,0 +1,340 @@
+use anyhow::{Context as _, Result};
+use cloud_llm_client::predict_edits_v3::Event;
+use credentials_provider::CredentialsProvider;
+use edit_prediction_context::RelatedFile;
+use futures::{AsyncReadExt as _, FutureExt, future::Shared};
+use gpui::{
+ App, AppContext as _, Entity, Task,
+ http_client::{self, AsyncBody, Method},
+};
+use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
+use project::{Project, ProjectPath};
+use std::{
+ collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
+};
+
+use crate::{
+ EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
+ prediction::EditPredictionResult,
+};
+
+const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
+const MAX_CONTEXT_TOKENS: usize = 150;
+const MAX_REWRITE_TOKENS: usize = 350;
+
+pub struct Mercury {
+ pub api_token: Shared<Task<Option<String>>>,
+}
+
+impl Mercury {
+ pub fn new(cx: &App) -> Self {
+ Mercury {
+ api_token: load_api_token(cx).shared(),
+ }
+ }
+
+ pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
+ self.api_token = Task::ready(api_token.clone()).shared();
+ store_api_token_in_keychain(api_token, cx)
+ }
+
+ pub fn request_prediction(
+ &self,
+ _project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ snapshot: BufferSnapshot,
+ position: language::Anchor,
+ events: Vec<Arc<Event>>,
+ _recent_paths: &VecDeque<ProjectPath>,
+ related_files: Vec<RelatedFile>,
+ _diagnostic_search_range: Range<Point>,
+ cx: &mut App,
+ ) -> Task<Result<Option<EditPredictionResult>>> {
+ let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
+ return Task::ready(Ok(None));
+ };
+ let full_path: Arc<Path> = snapshot
+ .file()
+ .map(|file| file.full_path(cx))
+ .unwrap_or_else(|| "untitled".into())
+ .into();
+
+ let http_client = cx.http_client();
+ let cursor_point = position.to_point(&snapshot);
+ let buffer_snapshotted_at = Instant::now();
+
+ let result = cx.background_spawn(async move {
+ let (editable_range, context_range) =
+ crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
+ cursor_point,
+ &snapshot,
+ MAX_CONTEXT_TOKENS,
+ MAX_REWRITE_TOKENS,
+ );
+
+ let offset_range = editable_range.to_offset(&snapshot);
+ let prompt = build_prompt(
+ &events,
+ &related_files,
+ &snapshot,
+ full_path.as_ref(),
+ cursor_point,
+ editable_range,
+ context_range.clone(),
+ );
+
+ let inputs = EditPredictionInputs {
+ events: events,
+ included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
+ path: full_path.clone(),
+ max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
+ excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
+ start_line: cloud_llm_client::predict_edits_v3::Line(
+ context_range.start.row,
+ ),
+ text: snapshot
+ .text_for_range(context_range.clone())
+ .collect::<String>()
+ .into(),
+ }],
+ }],
+ cursor_point: cloud_llm_client::predict_edits_v3::Point {
+ column: cursor_point.column,
+ line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
+ },
+ cursor_path: full_path.clone(),
+ };
+
+ let request_body = open_ai::Request {
+ model: "mercury-coder".into(),
+ messages: vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt),
+ }],
+ stream: false,
+ max_completion_tokens: None,
+ stop: vec![],
+ temperature: None,
+ tool_choice: None,
+ parallel_tool_calls: None,
+ tools: vec![],
+ prompt_cache_key: None,
+ reasoning_effort: None,
+ };
+
+ let buf = serde_json::to_vec(&request_body)?;
+ let body: AsyncBody = buf.into();
+
+ let request = http_client::Request::builder()
+ .uri(MERCURY_API_URL)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_token))
+ .header("Connection", "keep-alive")
+ .method(Method::POST)
+ .body(body)
+ .context("Failed to create request")?;
+
+ let mut response = http_client
+ .send(request)
+ .await
+ .context("Failed to send request")?;
+
+ let mut body: Vec<u8> = Vec::new();
+ response
+ .body_mut()
+ .read_to_end(&mut body)
+ .await
+ .context("Failed to read response body")?;
+
+ let response_received_at = Instant::now();
+ if !response.status().is_success() {
+ anyhow::bail!(
+ "Request failed with status: {:?}\nBody: {}",
+ response.status(),
+ String::from_utf8_lossy(&body),
+ );
+ };
+
+ let mut response: open_ai::Response =
+ serde_json::from_slice(&body).context("Failed to parse response")?;
+
+ let id = mem::take(&mut response.id);
+ let response_str = text_from_response(response).unwrap_or_default();
+
+ let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
+ let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
+
+ let mut edits = Vec::new();
+ const NO_PREDICTION_OUTPUT: &str = "None";
+
+ if response_str != NO_PREDICTION_OUTPUT {
+ let old_text = snapshot
+ .text_for_range(offset_range.clone())
+ .collect::<String>();
+ edits.extend(
+ language::text_diff(&old_text, &response_str)
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ snapshot.anchor_after(offset_range.start + range.start)
+ ..snapshot.anchor_before(offset_range.start + range.end),
+ text,
+ )
+ }),
+ );
+ }
+
+ anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
+ });
+
+ let buffer = active_buffer.clone();
+
+ cx.spawn(async move |cx| {
+ let (id, edits, old_snapshot, response_received_at, inputs) =
+ result.await.context("Mercury edit prediction failed")?;
+ anyhow::Ok(Some(
+ EditPredictionResult::new(
+ EditPredictionId(id.into()),
+ &buffer,
+ &old_snapshot,
+ edits.into(),
+ buffer_snapshotted_at,
+ response_received_at,
+ inputs,
+ cx,
+ )
+ .await,
+ ))
+ })
+ }
+}
+
+fn build_prompt(
+ events: &[Arc<Event>],
+ related_files: &[RelatedFile],
+ cursor_buffer: &BufferSnapshot,
+ cursor_buffer_path: &Path,
+ cursor_point: Point,
+ editable_range: Range<Point>,
+ context_range: Range<Point>,
+) -> String {
+ const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
+ const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
+ const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
+ const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
+ const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
+ const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
+ const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
+ const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
+ const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
+ const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
+ const CURSOR_TAG: &str = "<|cursor|>";
+ const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
+ const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
+
+ let mut prompt = String::new();
+
+ push_delimited(
+ &mut prompt,
+ RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
+ |prompt| {
+ for related_file in related_files {
+ for related_excerpt in &related_file.excerpts {
+ push_delimited(
+ prompt,
+ RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
+ |prompt| {
+ prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
+ prompt.push_str(related_file.path.path.as_unix_str());
+ prompt.push('\n');
+ prompt.push_str(&related_excerpt.text.to_string());
+ },
+ );
+ }
+ }
+ },
+ );
+
+ push_delimited(
+ &mut prompt,
+ CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
+ |prompt| {
+ prompt.push_str(CURRENT_FILE_PATH_PREFIX);
+ prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
+ prompt.push('\n');
+
+ let prefix_range = context_range.start..editable_range.start;
+ let suffix_range = editable_range.end..context_range.end;
+
+ prompt.extend(cursor_buffer.text_for_range(prefix_range));
+ push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
+ let range_before_cursor = editable_range.start..cursor_point;
+ let range_after_cursor = cursor_point..editable_range.end;
+ prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
+ prompt.push_str(CURSOR_TAG);
+ prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
+ });
+ prompt.extend(cursor_buffer.text_for_range(suffix_range));
+ },
+ );
+
+ push_delimited(
+ &mut prompt,
+ EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
+ |prompt| {
+ for event in events {
+ writeln!(prompt, "{event}").unwrap();
+ }
+ },
+ );
+
+ prompt
+}
+
+fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
+ prompt.push_str(delimiters.start);
+ cb(prompt);
+ prompt.push_str(delimiters.end);
+}
+
+pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
+pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
+
+pub fn load_api_token(cx: &App) -> Task<Option<String>> {
+ if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN")
+ .ok()
+ .filter(|value| !value.is_empty())
+ {
+ return Task::ready(Some(api_token));
+ }
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+ cx.spawn(async move |cx| {
+ let (_, credentials) = credentials_provider
+ .read_credentials(MERCURY_CREDENTIALS_URL, &cx)
+ .await
+ .ok()??;
+ String::from_utf8(credentials).ok()
+ })
+}
+
+fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
+ let credentials_provider = <dyn CredentialsProvider>::global(cx);
+
+ cx.spawn(async move |cx| {
+ if let Some(api_token) = api_token {
+ credentials_provider
+ .write_credentials(
+ MERCURY_CREDENTIALS_URL,
+ MERCURY_CREDENTIALS_USERNAME,
+ api_token.as_bytes(),
+ cx,
+ )
+ .await
+ .context("Failed to save Mercury API token to system keychain")
+ } else {
+ credentials_provider
+ .delete_credentials(MERCURY_CREDENTIALS_URL, cx)
+ .await
+ .context("Failed to delete Mercury API token from system keychain")
+ }
+ })
+}
@@ -0,0 +1,31 @@
+pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
+ let choice = res.choices.pop()?;
+ let output_text = match choice.message {
+ open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Plain(content)),
+ ..
+ } => content,
+ open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Multipart(mut content)),
+ ..
+ } => {
+ if content.is_empty() {
+ log::error!("No output from Baseten completion response");
+ return None;
+ }
+
+ match content.remove(0) {
+ open_ai::MessagePart::Text { text } => text,
+ open_ai::MessagePart::Image { .. } => {
+ log::error!("Expected text, got an image");
+ return None;
+ }
+ }
+ }
+ _ => {
+ log::error!("Invalid response message: {:?}", choice.message);
+ return None;
+ }
+ };
+ Some(output_text)
+}
@@ -99,7 +99,7 @@ pub struct EditPrediction {
#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
- pub included_files: Vec<cloud_llm_client::predict_edits_v3::IncludedFile>,
+ pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
pub cursor_path: Arc<Path>,
}
@@ -1,6 +1,7 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
+use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
@@ -49,6 +50,7 @@ impl SweepAi {
position: language::Anchor,
events: Vec<Arc<Event>>,
recent_paths: &VecDeque<ProjectPath>,
+ related_files: Vec<RelatedFile>,
diagnostic_search_range: Range<Point>,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
@@ -120,6 +122,19 @@ impl SweepAi {
})
.collect::<Vec<_>>();
+ let retrieval_chunks = related_files
+ .iter()
+ .flat_map(|related_file| {
+ related_file.excerpts.iter().map(|excerpt| FileChunk {
+ file_path: related_file.path.path.as_unix_str().to_string(),
+ start_line: excerpt.point_range.start.row as usize,
+ end_line: excerpt.point_range.end.row as usize,
+ content: excerpt.text.to_string(),
+ timestamp: None,
+ })
+ })
+ .collect();
+
let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
@@ -168,7 +183,7 @@ impl SweepAi {
multiple_suggestions: false,
branch: None,
file_chunks,
- retrieval_chunks: vec![],
+ retrieval_chunks,
recent_user_actions: vec![],
use_bytes: true,
// TODO
@@ -182,7 +197,7 @@ impl SweepAi {
let inputs = EditPredictionInputs {
events,
- included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
+ included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
@@ -320,7 +335,7 @@ struct AutocompleteRequest {
pub cursor_position: usize,
pub original_file_contents: String,
pub file_chunks: Vec<FileChunk>,
- pub retrieval_chunks: Vec<RetrievalChunk>,
+ pub retrieval_chunks: Vec<FileChunk>,
pub recent_user_actions: Vec<UserAction>,
pub multiple_suggestions: bool,
pub privacy_mode_enabled: bool,
@@ -337,15 +352,6 @@ struct FileChunk {
pub timestamp: Option<u64>,
}
-#[derive(Debug, Clone, Serialize)]
-struct RetrievalChunk {
- pub file_path: String,
- pub start_line: usize,
- pub end_line: usize,
- pub content: String,
- pub timestamp: u64,
-}
-
#[derive(Debug, Clone, Serialize)]
struct UserAction {
pub action_type: ActionType,
@@ -1,55 +1,56 @@
-use std::{cmp, sync::Arc, time::Duration};
+use std::{cmp, sync::Arc};
use client::{Client, UserStore};
use cloud_llm_client::EditPredictionRejectReason;
-use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
+use edit_prediction_types::{DataCollectionState, Direction, EditPredictionDelegate};
use gpui::{App, Entity, prelude::*};
-use language::ToPoint as _;
+use language::{Buffer, ToPoint as _};
use project::Project;
-use crate::{BufferEditPrediction, Zeta, ZetaEditPredictionModel};
+use crate::{BufferEditPrediction, EditPredictionModel, EditPredictionStore};
-pub struct ZetaEditPredictionProvider {
- zeta: Entity<Zeta>,
+pub struct ZedEditPredictionDelegate {
+ store: Entity<EditPredictionStore>,
project: Entity<Project>,
+ singleton_buffer: Option<Entity<Buffer>>,
}
-impl ZetaEditPredictionProvider {
- pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
-
+impl ZedEditPredictionDelegate {
pub fn new(
project: Entity<Project>,
+ singleton_buffer: Option<Entity<Buffer>>,
client: &Arc<Client>,
user_store: &Entity<UserStore>,
cx: &mut Context<Self>,
) -> Self {
- let zeta = Zeta::global(client, user_store, cx);
- zeta.update(cx, |zeta, cx| {
- zeta.register_project(&project, cx);
+ let store = EditPredictionStore::global(client, user_store, cx);
+ store.update(cx, |store, cx| {
+ store.register_project(&project, cx);
});
- cx.observe(&zeta, |_this, _zeta, cx| {
+ cx.observe(&store, |_this, _ep_store, cx| {
cx.notify();
})
.detach();
Self {
project: project,
- zeta,
+ store: store,
+ singleton_buffer,
}
}
}
-impl EditPredictionProvider for ZetaEditPredictionProvider {
+impl EditPredictionDelegate for ZedEditPredictionDelegate {
fn name() -> &'static str {
- "zed-predict2"
+ "zed-predict"
}
fn display_name() -> &'static str {
- "Zed's Edit Predictions 2"
+ "Zed's Edit Predictions"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
true
}
@@ -57,17 +58,38 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
true
}
- fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
- // TODO [zeta2]
- DataCollectionState::Unsupported
+ fn data_collection_state(&self, cx: &App) -> DataCollectionState {
+ if let Some(buffer) = &self.singleton_buffer
+ && let Some(file) = buffer.read(cx).file()
+ {
+ let is_project_open_source =
+ self.store
+ .read(cx)
+ .is_file_open_source(&self.project, file, cx);
+ if self.store.read(cx).data_collection_choice.is_enabled() {
+ DataCollectionState::Enabled {
+ is_project_open_source,
+ }
+ } else {
+ DataCollectionState::Disabled {
+ is_project_open_source,
+ }
+ }
+ } else {
+ return DataCollectionState::Disabled {
+ is_project_open_source: false,
+ };
+ }
}
- fn toggle_data_collection(&mut self, _cx: &mut App) {
- // TODO [zeta2]
+ fn toggle_data_collection(&mut self, cx: &mut App) {
+ self.store.update(cx, |store, cx| {
+ store.toggle_data_collection_choice(cx);
+ });
}
fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
- self.zeta.read(cx).usage(cx)
+ self.store.read(cx).usage(cx)
}
fn is_enabled(
@@ -76,16 +98,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
_cursor_position: language::Anchor,
cx: &App,
) -> bool {
- let zeta = self.zeta.read(cx);
- if zeta.edit_prediction_model == ZetaEditPredictionModel::Sweep {
- zeta.has_sweep_api_token()
+ let store = self.store.read(cx);
+ if store.edit_prediction_model == EditPredictionModel::Sweep {
+ store.has_sweep_api_token()
} else {
true
}
}
fn is_refreshing(&self, cx: &App) -> bool {
- self.zeta.read(cx).is_refreshing(&self.project)
+ self.store.read(cx).is_refreshing(&self.project)
}
fn refresh(
@@ -95,24 +117,24 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
_debounce: bool,
cx: &mut Context<Self>,
) {
- let zeta = self.zeta.read(cx);
+ let store = self.store.read(cx);
- if zeta.user_store.read_with(cx, |user_store, _cx| {
+ if store.user_store.read_with(cx, |user_store, _cx| {
user_store.account_too_young() || user_store.has_overdue_invoices()
}) {
return;
}
- if let Some(current) = zeta.current_prediction_for_buffer(&buffer, &self.project, cx)
+ if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
&& let BufferEditPrediction::Local { prediction } = current
&& prediction.interpolate(buffer.read(cx)).is_some()
{
return;
}
- self.zeta.update(cx, |zeta, cx| {
- zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
- zeta.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
+ self.store.update(cx, |store, cx| {
+ store.refresh_context(&self.project, &buffer, cursor_position, cx);
+ store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
});
}
@@ -126,20 +148,20 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
fn accept(&mut self, cx: &mut Context<Self>) {
- self.zeta.update(cx, |zeta, cx| {
- zeta.accept_current_prediction(&self.project, cx);
+ self.store.update(cx, |store, cx| {
+ store.accept_current_prediction(&self.project, cx);
});
}
fn discard(&mut self, cx: &mut Context<Self>) {
- self.zeta.update(cx, |zeta, _cx| {
- zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
+ self.store.update(cx, |store, _cx| {
+ store.reject_current_prediction(EditPredictionRejectReason::Discarded, &self.project);
});
}
fn did_show(&mut self, cx: &mut Context<Self>) {
- self.zeta.update(cx, |zeta, cx| {
- zeta.did_show_current_prediction(&self.project, cx);
+ self.store.update(cx, |store, cx| {
+ store.did_show_current_prediction(&self.project, cx);
});
}
@@ -148,16 +170,16 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
buffer: &Entity<language::Buffer>,
cursor_position: language::Anchor,
cx: &mut Context<Self>,
- ) -> Option<edit_prediction::EditPrediction> {
+ ) -> Option<edit_prediction_types::EditPrediction> {
let prediction =
- self.zeta
+ self.store
.read(cx)
.current_prediction_for_buffer(buffer, &self.project, cx)?;
let prediction = match prediction {
BufferEditPrediction::Local { prediction } => prediction,
BufferEditPrediction::Jump { prediction } => {
- return Some(edit_prediction::EditPrediction::Jump {
+ return Some(edit_prediction_types::EditPrediction::Jump {
id: Some(prediction.id.to_string().into()),
snapshot: prediction.snapshot.clone(),
target: prediction.edits.first().unwrap().0.start,
@@ -169,8 +191,8 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
let snapshot = buffer.snapshot();
let Some(edits) = prediction.interpolate(&snapshot) else {
- self.zeta.update(cx, |zeta, _cx| {
- zeta.reject_current_prediction(
+ self.store.update(cx, |store, _cx| {
+ store.reject_current_prediction(
EditPredictionRejectReason::InterpolatedEmpty,
&self.project,
);
@@ -208,7 +230,7 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
}
}
- Some(edit_prediction::EditPrediction::Local {
+ Some(edit_prediction_types::EditPrediction::Local {
id: Some(prediction.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(prediction.edit_preview.clone()),
@@ -1,9 +1,8 @@
-mod input_excerpt;
-
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
- EditPredictionId, ZedUpdateRequiredError, Zeta,
+ EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
+ cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
prediction::{EditPredictionInputs, EditPredictionResult},
};
use anyhow::{Context as _, Result};
@@ -12,7 +11,6 @@ use cloud_llm_client::{
predict_edits_v3::Event,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
-use input_excerpt::excerpt_for_cursor_position;
use language::{
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
};
@@ -30,23 +28,23 @@ pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
pub(crate) fn request_prediction_with_zeta1(
- zeta: &mut Zeta,
+ store: &mut EditPredictionStore,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
trigger: PredictEditsRequestTrigger,
- cx: &mut Context<Zeta>,
+ cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
- let client = zeta.client.clone();
- let llm_token = zeta.llm_token.clone();
+ let client = store.client.clone();
+ let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
- let can_collect_file = zeta.can_collect_file(project, file, cx);
+ let can_collect_file = store.can_collect_file(project, file, cx);
let git_info = if can_collect_file {
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
} else {
@@ -102,7 +100,7 @@ pub(crate) fn request_prediction_with_zeta1(
let http_client = client.http_client();
- let response = Zeta::send_api_request::<PredictEditsResponse>(
+ let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
|request| {
let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
predict_edits_url
@@ -124,7 +122,7 @@ pub(crate) fn request_prediction_with_zeta1(
let inputs = EditPredictionInputs {
events: included_events.into(),
- included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
+ included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
@@ -155,8 +153,8 @@ pub(crate) fn request_prediction_with_zeta1(
Err(err) => {
if err.is::<ZedUpdateRequiredError>() {
cx.update(|cx| {
- this.update(cx, |zeta, _cx| {
- zeta.update_required = true;
+ this.update(cx, |ep_store, _cx| {
+ ep_store.update_required = true;
})
.ok();
@@ -495,10 +493,159 @@ pub fn format_event(event: &Event) -> String {
}
}
-/// Typical number of string bytes per token for the purposes of limiting model input. This is
-/// intentionally low to err on the side of underestimating limits.
-pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
+#[derive(Debug)]
+pub struct InputExcerpt {
+ pub context_range: Range<Point>,
+ pub editable_range: Range<Point>,
+ pub prompt: String,
+}
+
+pub fn excerpt_for_cursor_position(
+ position: Point,
+ path: &str,
+ snapshot: &BufferSnapshot,
+ editable_region_token_limit: usize,
+ context_token_limit: usize,
+) -> InputExcerpt {
+ let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
+ position,
+ snapshot,
+ editable_region_token_limit,
+ context_token_limit,
+ );
+
+ let mut prompt = String::new();
+
+ writeln!(&mut prompt, "```{path}").unwrap();
+ if context_range.start == Point::zero() {
+ writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
+ }
+
+ for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
+ prompt.push_str(chunk.text);
+ }
+
+ push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
+
+ for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
+ prompt.push_str(chunk.text);
+ }
+ write!(prompt, "\n```").unwrap();
+
+ InputExcerpt {
+ context_range,
+ editable_range,
+ prompt,
+ }
+}
+
+fn push_editable_range(
+ cursor_position: Point,
+ snapshot: &BufferSnapshot,
+ editable_range: Range<Point>,
+ prompt: &mut String,
+) {
+ writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
+ for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
+ prompt.push_str(chunk.text);
+ }
+ prompt.push_str(CURSOR_MARKER);
+ for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
+ prompt.push_str(chunk.text);
+ }
+ write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use gpui::{App, AppContext};
+ use indoc::indoc;
+ use language::Buffer;
+
+ #[gpui::test]
+ fn test_excerpt_for_cursor_position(cx: &mut App) {
+ let text = indoc! {r#"
+ fn foo() {
+ let x = 42;
+ println!("Hello, world!");
+ }
+
+ fn bar() {
+ let x = 42;
+ let mut sum = 0;
+ for i in 0..x {
+ sum += i;
+ }
+ println!("Sum: {}", sum);
+ return sum;
+ }
+
+ fn generate_random_numbers() -> Vec<i32> {
+ let mut rng = rand::thread_rng();
+ let mut numbers = Vec::new();
+ for _ in 0..5 {
+ numbers.push(rng.random_range(1..101));
+ }
+ numbers
+ }
+ "#};
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
+ let snapshot = buffer.read(cx).snapshot();
+
+ // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
+ // when a larger scope doesn't fit the editable region.
+ let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
+ assert_eq!(
+ excerpt.prompt,
+ indoc! {r#"
+ ```main.rs
+ let x = 42;
+ println!("Hello, world!");
+ <|editable_region_start|>
+ }
-fn guess_token_count(bytes: usize) -> usize {
- bytes / BYTES_PER_TOKEN_GUESS
+ fn bar() {
+ let x = 42;
+ let mut sum = 0;
+ for i in 0..x {
+ sum += i;
+ }
+ println!("Sum: {}", sum);
+ r<|user_cursor_is_here|>eturn sum;
+ }
+
+ fn generate_random_numbers() -> Vec<i32> {
+ <|editable_region_end|>
+ let mut rng = rand::thread_rng();
+ let mut numbers = Vec::new();
+ ```"#}
+ );
+
+ // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
+ let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
+ assert_eq!(
+ excerpt.prompt,
+ indoc! {r#"
+ ```main.rs
+ fn bar() {
+ let x = 42;
+ let mut sum = 0;
+ <|editable_region_start|>
+ for i in 0..x {
+ sum += i;
+ }
+ println!("Sum: {}", sum);
+ r<|user_cursor_is_here|>eturn sum;
+ }
+
+ fn generate_random_numbers() -> Vec<i32> {
+ let mut rng = rand::thread_rng();
+ <|editable_region_end|>
+ let mut numbers = Vec::new();
+ for _ in 0..5 {
+ numbers.push(rng.random_range(1..101));
+ ```"#}
+ );
+ }
}
@@ -0,0 +1,327 @@
+#[cfg(feature = "eval-support")]
+use crate::EvalCacheEntryKind;
+use crate::open_ai_response::text_from_response;
+use crate::prediction::EditPredictionResult;
+use crate::{
+ DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
+ EditPredictionRequestedDebugEvent, EditPredictionStore,
+};
+use anyhow::{Result, anyhow, bail};
+use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
+use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
+use cloud_zeta2_prompt::CURSOR_MARKER;
+use edit_prediction_context::{EditPredictionExcerpt, Line};
+use edit_prediction_context::{RelatedExcerpt, RelatedFile};
+use futures::channel::oneshot;
+use gpui::{Entity, Task, prelude::*};
+use language::{Anchor, BufferSnapshot};
+use language::{Buffer, Point, ToOffset as _, ToPoint};
+use project::{Project, ProjectItem as _};
+use release_channel::AppVersion;
+use std::{
+ env,
+ path::Path,
+ sync::Arc,
+ time::{Duration, Instant},
+};
+
+pub fn request_prediction_with_zeta2(
+ store: &mut EditPredictionStore,
+ project: &Entity<Project>,
+ active_buffer: &Entity<Buffer>,
+ active_snapshot: BufferSnapshot,
+ position: Anchor,
+ events: Vec<Arc<Event>>,
+ mut included_files: Vec<RelatedFile>,
+ trigger: PredictEditsRequestTrigger,
+ cx: &mut Context<EditPredictionStore>,
+) -> Task<Result<Option<EditPredictionResult>>> {
+ let options = store.options.clone();
+ let buffer_snapshotted_at = Instant::now();
+
+ let Some((excerpt_path, active_project_path)) = active_snapshot
+ .file()
+ .map(|file| -> Arc<Path> { file.full_path(cx).into() })
+ .zip(active_buffer.read(cx).project_path(cx))
+ else {
+ return Task::ready(Err(anyhow!("No file path for excerpt")));
+ };
+
+ let client = store.client.clone();
+ let llm_token = store.llm_token.clone();
+ let app_version = AppVersion::global(cx);
+ let debug_tx = store.debug_tx.clone();
+
+ let file = active_buffer.read(cx).file();
+
+ let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
+
+ // TODO data collection
+ let can_collect_data = file
+ .as_ref()
+ .map_or(false, |file| store.can_collect_file(project, file, cx));
+
+ #[cfg(feature = "eval-support")]
+ let eval_cache = store.eval_cache.clone();
+
+ let request_task = cx.background_spawn({
+ let active_buffer = active_buffer.clone();
+ async move {
+ let cursor_offset = position.to_offset(&active_snapshot);
+ let cursor_point = cursor_offset.to_point(&active_snapshot);
+
+ let before_retrieval = Instant::now();
+
+ let excerpt_options = options.context;
+
+ let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
+ cursor_point,
+ &active_snapshot,
+ &excerpt_options,
+ ) else {
+ return Ok((None, None));
+ };
+
+ let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
+ ..active_snapshot.anchor_before(excerpt.range.end);
+ let related_excerpt = RelatedExcerpt {
+ anchor_range: excerpt_anchor_range.clone(),
+ point_range: Point::new(excerpt.line_range.start.0, 0)
+ ..Point::new(excerpt.line_range.end.0, 0),
+ text: active_snapshot.as_rope().slice(excerpt.range),
+ };
+
+ if let Some(buffer_ix) = included_files
+ .iter()
+ .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
+ {
+ let file = &mut included_files[buffer_ix];
+ file.excerpts.push(related_excerpt);
+ file.merge_excerpts();
+ let last_ix = included_files.len() - 1;
+ included_files.swap(buffer_ix, last_ix);
+ } else {
+ let active_file = RelatedFile {
+ path: active_project_path,
+ buffer: active_buffer.downgrade(),
+ excerpts: vec![related_excerpt],
+ max_row: active_snapshot.max_point().row,
+ };
+ included_files.push(active_file);
+ }
+
+ let included_files = included_files
+ .iter()
+ .map(|related_file| predict_edits_v3::RelatedFile {
+ path: Arc::from(related_file.path.path.as_std_path()),
+ max_row: Line(related_file.max_row),
+ excerpts: related_file
+ .excerpts
+ .iter()
+ .map(|excerpt| predict_edits_v3::Excerpt {
+ start_line: Line(excerpt.point_range.start.row),
+ text: excerpt.text.to_string().into(),
+ })
+ .collect(),
+ })
+ .collect::<Vec<_>>();
+
+ let cloud_request = predict_edits_v3::PredictEditsRequest {
+ excerpt_path,
+ excerpt: String::new(),
+ excerpt_line_range: Line(0)..Line(0),
+ excerpt_range: 0..0,
+ cursor_point: predict_edits_v3::Point {
+ line: predict_edits_v3::Line(cursor_point.row),
+ column: cursor_point.column,
+ },
+ related_files: included_files,
+ events,
+ can_collect_data,
+ debug_info: debug_tx.is_some(),
+ prompt_max_bytes: Some(options.max_prompt_bytes),
+ prompt_format: options.prompt_format,
+ excerpt_parent: None,
+ git_info: None,
+ trigger,
+ };
+
+ let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
+
+ let inputs = EditPredictionInputs {
+ included_files: cloud_request.related_files,
+ events: cloud_request.events,
+ cursor_point: cloud_request.cursor_point,
+ cursor_path: cloud_request.excerpt_path,
+ };
+
+ let retrieval_time = Instant::now() - before_retrieval;
+
+ let debug_response_tx = if let Some(debug_tx) = &debug_tx {
+ let (response_tx, response_rx) = oneshot::channel();
+
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionRequested(
+ EditPredictionRequestedDebugEvent {
+ inputs: inputs.clone(),
+ retrieval_time,
+ buffer: active_buffer.downgrade(),
+ local_prompt: match prompt_result.as_ref() {
+ Ok(prompt) => Ok(prompt.clone()),
+ Err(err) => Err(err.to_string()),
+ },
+ position,
+ response_rx,
+ },
+ ))
+ .ok();
+ Some(response_tx)
+ } else {
+ None
+ };
+
+ if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
+ if let Some(debug_response_tx) = debug_response_tx {
+ debug_response_tx
+ .send((Err("Request skipped".to_string()), Duration::ZERO))
+ .ok();
+ }
+ anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
+ }
+
+ let prompt = prompt_result?;
+ let generation_params =
+ cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
+ let request = open_ai::Request {
+ model: EDIT_PREDICTIONS_MODEL_ID.clone(),
+ messages: vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt),
+ }],
+ stream: false,
+ max_completion_tokens: None,
+ stop: generation_params.stop.unwrap_or_default(),
+ temperature: generation_params.temperature.or(Some(0.7)),
+ tool_choice: None,
+ parallel_tool_calls: None,
+ tools: vec![],
+ prompt_cache_key: None,
+ reasoning_effort: None,
+ };
+
+ log::trace!("Sending edit prediction request");
+
+ let before_request = Instant::now();
+ let response = EditPredictionStore::send_raw_llm_request(
+ request,
+ client,
+ llm_token,
+ app_version,
+ #[cfg(feature = "eval-support")]
+ eval_cache,
+ #[cfg(feature = "eval-support")]
+ EvalCacheEntryKind::Prediction,
+ )
+ .await;
+ let received_response_at = Instant::now();
+ let request_time = received_response_at - before_request;
+
+ log::trace!("Got edit prediction response");
+
+ if let Some(debug_response_tx) = debug_response_tx {
+ debug_response_tx
+ .send((
+ response
+ .as_ref()
+ .map_err(|err| err.to_string())
+ .map(|response| response.0.clone()),
+ request_time,
+ ))
+ .ok();
+ }
+
+ let (res, usage) = response?;
+ let request_id = EditPredictionId(res.id.clone().into());
+ let Some(mut output_text) = text_from_response(res) else {
+ return Ok((Some((request_id, None)), usage));
+ };
+
+ if output_text.contains(CURSOR_MARKER) {
+ log::trace!("Stripping out {CURSOR_MARKER} from response");
+ output_text = output_text.replace(CURSOR_MARKER, "");
+ }
+
+ let get_buffer_from_context = |path: &Path| {
+ if Some(path) == active_file_full_path.as_deref() {
+ Some((
+ &active_snapshot,
+ std::slice::from_ref(&excerpt_anchor_range),
+ ))
+ } else {
+ None
+ }
+ };
+
+ let (_, edits) = match options.prompt_format {
+ PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
+ if output_text.contains("--- a/\n+++ b/\nNo edits") {
+ let edits = vec![];
+ (&active_snapshot, edits)
+ } else {
+ crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
+ }
+ }
+ PromptFormat::OldTextNewText => {
+ crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
+ }
+ _ => {
+ bail!("unsupported prompt format {}", options.prompt_format)
+ }
+ };
+
+ anyhow::Ok((
+ Some((
+ request_id,
+ Some((
+ inputs,
+ active_buffer,
+ active_snapshot.clone(),
+ edits,
+ received_response_at,
+ )),
+ )),
+ usage,
+ ))
+ }
+ });
+
+ cx.spawn(async move |this, cx| {
+ let Some((id, prediction)) =
+ EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
+ else {
+ return Ok(None);
+ };
+
+ let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
+ prediction
+ else {
+ return Ok(Some(EditPredictionResult {
+ id,
+ prediction: Err(EditPredictionRejectReason::Empty),
+ }));
+ };
+
+ Ok(Some(
+ EditPredictionResult::new(
+ id,
+ &edited_buffer,
+ &edited_buffer_snapshot,
+ edits.into(),
+ buffer_snapshotted_at,
+ received_response_at,
+ inputs,
+ cx,
+ )
+ .await,
+ ))
+ })
+}
@@ -1,5 +1,5 @@
[package]
-name = "zeta_cli"
+name = "edit_prediction_cli"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
@@ -9,17 +9,18 @@ license = "GPL-3.0-or-later"
workspace = true
[[bin]]
-name = "zeta"
+name = "ep_cli"
path = "src/main.rs"
[dependencies]
-
anyhow.workspace = true
+anthropic.workspace = true
+http_client.workspace = true
chrono.workspace = true
clap.workspace = true
client.workspace = true
cloud_llm_client.workspace= true
-cloud_zeta2_prompt.workspace= true
+cloud_zeta2_prompt.workspace = true
collections.workspace = true
debug_adapter_extension.workspace = true
edit_prediction_context.workspace = true
@@ -28,6 +29,7 @@ fs.workspace = true
futures.workspace = true
gpui.workspace = true
gpui_tokio.workspace = true
+indoc.workspace = true
language.workspace = true
language_extension.workspace = true
language_model.workspace = true
@@ -35,9 +37,7 @@ language_models.workspace = true
languages = { workspace = true, features = ["load-grammars"] }
log.workspace = true
node_runtime.workspace = true
-ordered-float.workspace = true
paths.workspace = true
-polars = { version = "0.51", features = ["lazy", "dtype-struct", "parquet"] }
project.workspace = true
prompt_store.workspace = true
pulldown-cmark.workspace = true
@@ -48,12 +48,13 @@ serde_json.workspace = true
settings.workspace = true
shellexpand.workspace = true
smol.workspace = true
-soa-rs = "0.8.1"
+sqlez.workspace = true
+sqlez_macros.workspace = true
terminal_view.workspace = true
toml.workspace = true
util.workspace = true
watch.workspace = true
-zeta = { workspace = true, features = ["eval-support"] }
+edit_prediction = { workspace = true, features = ["eval-support"] }
zlog.workspace = true
[dev-dependencies]
@@ -6,17 +6,17 @@ use std::{
};
use anyhow::Result;
+use edit_prediction::{EditPredictionStore, udiff::DiffLine};
use gpui::{AsyncApp, Entity};
use project::Project;
use util::ResultExt as _;
-use zeta::{Zeta, udiff::DiffLine};
use crate::{
EvaluateArguments, PredictionOptions,
example::{Example, NamedExample},
headless::ZetaCliAppState,
paths::print_run_data_dir,
- predict::{PredictionDetails, perform_predict, setup_zeta},
+ predict::{PredictionDetails, perform_predict, setup_store},
};
#[derive(Debug)]
@@ -45,7 +45,7 @@ pub async fn run_evaluate(
let project = example.setup_project(&app_state, cx).await.unwrap();
let providers = (0..args.repetitions)
- .map(|_| setup_zeta(args.options.provider, &project, &app_state, cx).unwrap())
+ .map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
.collect::<Vec<_>>();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
@@ -53,7 +53,7 @@ pub async fn run_evaluate(
let tasks = providers
.into_iter()
.enumerate()
- .map(move |(repetition_ix, zeta)| {
+ .map(move |(repetition_ix, store)| {
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
let example = example.clone();
let project = project.clone();
@@ -65,7 +65,7 @@ pub async fn run_evaluate(
example,
repetition_ix,
project,
- zeta,
+ store,
options,
!args.skip_prediction,
cx,
@@ -154,7 +154,7 @@ pub async fn run_evaluate_one(
example: NamedExample,
repetition_ix: Option<u16>,
project: Entity<Project>,
- zeta: Entity<Zeta>,
+ store: Entity<EditPredictionStore>,
prediction_options: PredictionOptions,
predict: bool,
cx: &mut AsyncApp,
@@ -162,7 +162,7 @@ pub async fn run_evaluate_one(
let predict_result = perform_predict(
example.clone(),
project,
- zeta,
+ store,
repetition_ix,
prediction_options,
cx,
@@ -3,6 +3,8 @@ use std::{
cell::RefCell,
fmt::{self, Display},
fs,
+ hash::Hash,
+ hash::Hasher,
io::Write,
mem,
path::{Path, PathBuf},
@@ -14,6 +16,7 @@ use anyhow::{Context as _, Result, anyhow};
use clap::ValueEnum;
use cloud_zeta2_prompt::CURSOR_MARKER;
use collections::HashMap;
+use edit_prediction::udiff::OpenedBuffers;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
@@ -25,7 +28,6 @@ use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use serde::{Deserialize, Serialize};
use util::{paths::PathStyle, rel_path::RelPath};
-use zeta::udiff::OpenedBuffers;
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
@@ -43,7 +45,7 @@ pub struct NamedExample {
pub example: Example,
}
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
pub struct Example {
pub repository_url: String,
pub revision: String,
@@ -54,6 +56,134 @@ pub struct Example {
pub expected_patch: String,
}
+impl Example {
+ fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
+ // git@github.com:owner/repo.git
+ if self.repository_url.contains('@') {
+ let (owner, repo) = self
+ .repository_url
+ .split_once(':')
+ .context("expected : in git url")?
+ .1
+ .split_once('/')
+ .context("expected / in git url")?;
+ Ok((
+ Cow::Borrowed(owner),
+ Cow::Borrowed(repo.trim_end_matches(".git")),
+ ))
+ // http://github.com/owner/repo.git
+ } else {
+ let url = Url::parse(&self.repository_url)?;
+ let mut segments = url.path_segments().context("empty http url")?;
+ let owner = segments
+ .next()
+ .context("expected owner path segment")?
+ .to_string();
+ let repo = segments
+ .next()
+ .context("expected repo path segment")?
+ .trim_end_matches(".git")
+ .to_string();
+ assert!(segments.next().is_none());
+
+ Ok((owner.into(), repo.into()))
+ }
+ }
+
+ pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
+ let (repo_owner, repo_name) = self.repo_name()?;
+
+ let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
+ let repo_lock = lock_repo(&repo_dir).await;
+
+ if !repo_dir.is_dir() {
+ fs::create_dir_all(&repo_dir)?;
+ run_git(&repo_dir, &["init"]).await?;
+ run_git(
+ &repo_dir,
+ &["remote", "add", "origin", &self.repository_url],
+ )
+ .await?;
+ }
+
+ // Resolve the example to a revision, fetching it if needed.
+ let revision = run_git(
+ &repo_dir,
+ &["rev-parse", &format!("{}^{{commit}}", self.revision)],
+ )
+ .await;
+ let revision = if let Ok(revision) = revision {
+ revision
+ } else {
+ if run_git(
+ &repo_dir,
+ &["fetch", "--depth", "1", "origin", &self.revision],
+ )
+ .await
+ .is_err()
+ {
+ run_git(&repo_dir, &["fetch", "origin"]).await?;
+ }
+ let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
+ if revision != self.revision {
+ run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
+ }
+ revision
+ };
+
+ // Create the worktree for this example if needed.
+ let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
+ if worktree_path.is_dir() {
+ run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
+ run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
+ run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
+ } else {
+ let worktree_path_string = worktree_path.to_string_lossy();
+ run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
+ run_git(
+ &repo_dir,
+ &["worktree", "add", "-f", &worktree_path_string, &file_name],
+ )
+ .await?;
+ }
+ drop(repo_lock);
+
+ // Apply the uncommitted diff for this example.
+ if !self.uncommitted_diff.is_empty() {
+ let mut apply_process = smol::process::Command::new("git")
+ .current_dir(&worktree_path)
+ .args(&["apply", "-"])
+ .stdin(std::process::Stdio::piped())
+ .spawn()?;
+
+ let mut stdin = apply_process.stdin.take().unwrap();
+ stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
+ stdin.close().await?;
+ drop(stdin);
+
+ let apply_result = apply_process.output().await?;
+ if !apply_result.status.success() {
+ anyhow::bail!(
+ "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
+ apply_result.status,
+ String::from_utf8_lossy(&apply_result.stderr),
+ String::from_utf8_lossy(&apply_result.stdout),
+ );
+ }
+ }
+
+ Ok(worktree_path)
+ }
+
+ pub fn unique_name(&self) -> String {
+ let mut hasher = std::hash::DefaultHasher::new();
+ self.hash(&mut hasher);
+ let disambiguator = hasher.finish();
+ let hash = format!("{:04x}", disambiguator);
+ format!("{}_{}", &self.revision[..8], &hash[..4])
+ }
+}
+
pub type ActualExcerpt = Excerpt;
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -292,90 +422,7 @@ impl NamedExample {
}
pub async fn setup_worktree(&self) -> Result<PathBuf> {
- let (repo_owner, repo_name) = self.repo_name()?;
- let file_name = self.file_name();
-
- let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
- let repo_lock = lock_repo(&repo_dir).await;
-
- if !repo_dir.is_dir() {
- fs::create_dir_all(&repo_dir)?;
- run_git(&repo_dir, &["init"]).await?;
- run_git(
- &repo_dir,
- &["remote", "add", "origin", &self.example.repository_url],
- )
- .await?;
- }
-
- // Resolve the example to a revision, fetching it if needed.
- let revision = run_git(
- &repo_dir,
- &[
- "rev-parse",
- &format!("{}^{{commit}}", self.example.revision),
- ],
- )
- .await;
- let revision = if let Ok(revision) = revision {
- revision
- } else {
- run_git(
- &repo_dir,
- &["fetch", "--depth", "1", "origin", &self.example.revision],
- )
- .await?;
- let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
- if revision != self.example.revision {
- run_git(&repo_dir, &["tag", &self.example.revision, &revision]).await?;
- }
- revision
- };
-
- // Create the worktree for this example if needed.
- let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
- if worktree_path.is_dir() {
- run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
- run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
- run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
- } else {
- let worktree_path_string = worktree_path.to_string_lossy();
- run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
- run_git(
- &repo_dir,
- &["worktree", "add", "-f", &worktree_path_string, &file_name],
- )
- .await?;
- }
- drop(repo_lock);
-
- // Apply the uncommitted diff for this example.
- if !self.example.uncommitted_diff.is_empty() {
- let mut apply_process = smol::process::Command::new("git")
- .current_dir(&worktree_path)
- .args(&["apply", "-"])
- .stdin(std::process::Stdio::piped())
- .spawn()?;
-
- let mut stdin = apply_process.stdin.take().unwrap();
- stdin
- .write_all(self.example.uncommitted_diff.as_bytes())
- .await?;
- stdin.close().await?;
- drop(stdin);
-
- let apply_result = apply_process.output().await?;
- if !apply_result.status.success() {
- anyhow::bail!(
- "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
- apply_result.status,
- String::from_utf8_lossy(&apply_result.stderr),
- String::from_utf8_lossy(&apply_result.stdout),
- );
- }
- }
-
- Ok(worktree_path)
+ self.example.setup_worktree(self.file_name()).await
}
pub fn file_name(&self) -> String {
@@ -391,40 +438,6 @@ impl NamedExample {
.collect()
}
- fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
- // git@github.com:owner/repo.git
- if self.example.repository_url.contains('@') {
- let (owner, repo) = self
- .example
- .repository_url
- .split_once(':')
- .context("expected : in git url")?
- .1
- .split_once('/')
- .context("expected / in git url")?;
- Ok((
- Cow::Borrowed(owner),
- Cow::Borrowed(repo.trim_end_matches(".git")),
- ))
- // http://github.com/owner/repo.git
- } else {
- let url = Url::parse(&self.example.repository_url)?;
- let mut segments = url.path_segments().context("empty http url")?;
- let owner = segments
- .next()
- .context("expected owner path segment")?
- .to_string();
- let repo = segments
- .next()
- .context("expected repo path segment")?
- .trim_end_matches(".git")
- .to_string();
- assert!(segments.next().is_none());
-
- Ok((owner.into(), repo.into()))
- }
- }
-
pub async fn cursor_position(
&self,
project: &Entity<Project>,
@@ -481,7 +494,7 @@ impl NamedExample {
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'_>> {
- zeta::udiff::apply_diff(&self.example.edit_history, project, cx).await
+ edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
}
}
@@ -5,7 +5,7 @@ mod metrics;
mod paths;
mod predict;
mod source_location;
-mod syntax_retrieval_stats;
+mod training;
mod util;
use crate::{
@@ -14,27 +14,22 @@ use crate::{
headless::ZetaCliAppState,
predict::run_predict,
source_location::SourceLocation,
- syntax_retrieval_stats::retrieval_stats,
+ training::{context::ContextType, distill::run_distill},
util::{open_buffer, open_buffer_with_language_server},
};
-use ::util::paths::PathStyle;
+use ::util::{ResultExt, paths::PathStyle};
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand, ValueEnum};
use cloud_llm_client::predict_edits_v3;
-use edit_prediction_context::{
- EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
-};
+use edit_prediction::udiff::DiffLine;
+use edit_prediction_context::EditPredictionExcerptOptions;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, Point};
use metrics::delta_chr_f;
-use project::{Project, Worktree};
+use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
use reqwest_client::ReqwestClient;
-use serde_json::json;
use std::io::{self};
-use std::time::Duration;
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
-use zeta::ContextMode;
-use zeta::udiff::DiffLine;
#[derive(Parser, Debug)]
#[command(name = "zeta")]
@@ -48,9 +43,9 @@ struct ZetaCliArgs {
#[derive(Subcommand, Debug)]
enum Command {
Context(ContextArgs),
- ContextStats(ContextStatsArgs),
Predict(PredictArguments),
Eval(EvaluateArguments),
+ Distill(DistillArguments),
ConvertExample {
path: PathBuf,
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
@@ -63,20 +58,6 @@ enum Command {
Clean,
}
-#[derive(Debug, Args)]
-struct ContextStatsArgs {
- #[arg(long)]
- worktree: PathBuf,
- #[arg(long)]
- extension: Option<String>,
- #[arg(long)]
- limit: Option<usize>,
- #[arg(long)]
- skip: Option<usize>,
- #[clap(flatten)]
- zeta2_args: Zeta2Args,
-}
-
#[derive(Debug, Args)]
struct ContextArgs {
#[arg(long)]
@@ -97,7 +78,7 @@ struct ContextArgs {
enum ContextProvider {
Zeta1,
#[default]
- Syntax,
+ Zeta2,
}
#[derive(Clone, Debug, Args)]
@@ -133,6 +114,15 @@ pub struct PredictArguments {
options: PredictionOptions,
}
+#[derive(Debug, Args)]
+pub struct DistillArguments {
+ split_commit_dataset: PathBuf,
+ #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
+ context_type: ContextType,
+ #[clap(long)]
+ batch: Option<String>,
+}
+
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
#[clap(flatten)]
@@ -204,35 +194,22 @@ enum PredictionProvider {
Sweep,
}
-fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
- zeta::ZetaOptions {
- context: ContextMode::Syntax(EditPredictionContextOptions {
- max_retrieved_declarations: args.max_retrieved_definitions,
- use_imports: !args.disable_imports_gathering,
- excerpt: EditPredictionExcerptOptions {
- max_bytes: args.max_excerpt_bytes,
- min_bytes: args.min_excerpt_bytes,
- target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
- },
- score: EditPredictionScoreOptions {
- omit_excerpt_overlaps,
- },
- }),
- max_diagnostic_bytes: args.max_diagnostic_bytes,
+fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
+ edit_prediction::ZetaOptions {
+ context: EditPredictionExcerptOptions {
+ max_bytes: args.max_excerpt_bytes,
+ min_bytes: args.min_excerpt_bytes,
+ target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
+ },
max_prompt_bytes: args.max_prompt_bytes,
prompt_format: args.prompt_format.into(),
- file_indexing_parallelism: args.file_indexing_parallelism,
- buffer_change_grouping_interval: Duration::ZERO,
}
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
- MarkedExcerpt,
- LabeledSections,
OnlySnippets,
#[default]
- NumberedLines,
OldTextNewText,
Minimal,
MinimalQwen,
@@ -242,10 +219,7 @@ enum PromptFormat {
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
fn into(self) -> predict_edits_v3::PromptFormat {
match self {
- Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
- Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
- Self::NumberedLines => predict_edits_v3::PromptFormat::NumLinesUniDiff,
Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
@@ -295,6 +269,7 @@ struct LoadedContext {
worktree: Entity<Worktree>,
project: Entity<Project>,
buffer: Entity<Buffer>,
+ lsp_open_handle: Option<OpenLspBufferHandle>,
}
async fn load_context(
@@ -330,7 +305,7 @@ async fn load_context(
.await?;
let mut ready_languages = HashSet::default();
- let (_lsp_open_handle, buffer) = if *use_language_server {
+ let (lsp_open_handle, buffer) = if *use_language_server {
let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
project.clone(),
worktree.clone(),
@@ -377,10 +352,11 @@ async fn load_context(
worktree,
project,
buffer,
+ lsp_open_handle,
})
}
-async fn zeta2_syntax_context(
+async fn zeta2_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
@@ -390,6 +366,7 @@ async fn zeta2_syntax_context(
project,
buffer,
clipped_cursor,
+ lsp_open_handle: _handle,
..
} = load_context(&args, app_state, cx).await?;
@@ -402,34 +379,32 @@ async fn zeta2_syntax_context(
.await;
let output = cx
.update(|cx| {
- let zeta = cx.new(|cx| {
- zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
+ let store = cx.new(|cx| {
+ edit_prediction::EditPredictionStore::new(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ )
});
- let indexing_done_task = zeta.update(cx, |zeta, cx| {
- zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
- zeta.register_buffer(&buffer, &project, cx);
- zeta.wait_for_initial_indexing(&project, cx)
+ store.update(cx, |store, cx| {
+ store.set_options(zeta2_args_to_options(&args.zeta2_args));
+ store.register_buffer(&buffer, &project, cx);
});
cx.spawn(async move |cx| {
- indexing_done_task.await?;
- let request = zeta
- .update(cx, |zeta, cx| {
- let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
- zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
- })?
- .await?;
-
- let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
-
- match args.zeta2_args.output_format {
- OutputFormat::Prompt => anyhow::Ok(prompt_string),
- OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
- OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
- "request": request,
- "prompt": prompt_string,
- "section_labels": section_labels,
- }))?),
- }
+ let updates_rx = store.update(cx, |store, cx| {
+ let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
+ store.set_use_context(true);
+ store.refresh_context(&project, &buffer, cursor, cx);
+ store.project_context_updates(&project).unwrap()
+ })?;
+
+ updates_rx.recv().await.ok();
+
+ let context = store.update(cx, |store, cx| {
+ store.context_for_project(&project, cx).to_vec()
+ })?;
+
+ anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
})
})?
.await?;
@@ -441,7 +416,7 @@ async fn zeta1_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
-) -> Result<zeta::zeta1::GatherContextOutput> {
+) -> Result<edit_prediction::zeta1::GatherContextOutput> {
let LoadedContext {
full_path_str,
snapshot,
@@ -456,7 +431,7 @@ async fn zeta1_context(
let prompt_for_events = move || (events, 0);
cx.update(|cx| {
- zeta::zeta1::gather_context(
+ edit_prediction::zeta1::gather_context(
full_path_str,
&snapshot,
clipped_cursor,
@@ -482,24 +457,10 @@ fn main() {
None => {
if args.printenv {
::util::shell_env::print_env();
- return;
} else {
panic!("Expected a command");
}
}
- Some(Command::ContextStats(arguments)) => {
- let result = retrieval_stats(
- arguments.worktree,
- app_state,
- arguments.extension,
- arguments.limit,
- arguments.skip,
- zeta2_args_to_options(&arguments.zeta2_args, false),
- cx,
- )
- .await;
- println!("{}", result.unwrap());
- }
Some(Command::Context(context_args)) => {
let result = match context_args.provider {
ContextProvider::Zeta1 => {
@@ -507,10 +468,8 @@ fn main() {
zeta1_context(context_args, &app_state, cx).await.unwrap();
serde_json::to_string_pretty(&context.body).unwrap()
}
- ContextProvider::Syntax => {
- zeta2_syntax_context(context_args, &app_state, cx)
- .await
- .unwrap()
+ ContextProvider::Zeta2 => {
+ zeta2_context(context_args, &app_state, cx).await.unwrap()
}
};
println!("{}", result);
@@ -521,6 +480,13 @@ fn main() {
Some(Command::Eval(arguments)) => {
run_evaluate(arguments, &app_state, cx).await;
}
+ Some(Command::Distill(arguments)) => {
+ let _guard = cx
+ .update(|cx| gpui_tokio::Tokio::handle(cx))
+ .unwrap()
+ .enter();
+ run_distill(arguments).await.log_err();
+ }
Some(Command::ConvertExample {
path,
output_format,
@@ -1,5 +1,5 @@
use collections::{HashMap, HashSet};
-use zeta::udiff::DiffLine;
+use edit_prediction::udiff::DiffLine;
type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
@@ -287,7 +287,7 @@ fn count_ngrams(text: &str, n: usize) -> Counts {
#[cfg(test)]
mod test {
use super::*;
- use zeta::udiff::DiffLine;
+ use edit_prediction::udiff::DiffLine;
#[test]
fn test_delta_chr_f_perfect_match() {
@@ -7,6 +7,7 @@ use crate::{
use ::serde::Serialize;
use anyhow::{Context, Result, anyhow};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
+use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
use project::Project;
@@ -18,7 +19,6 @@ use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
-use zeta::{EvalCache, EvalCacheEntryKind, EvalCacheKey, Zeta};
pub async fn run_predict(
args: PredictArguments,
@@ -27,9 +27,9 @@ pub async fn run_predict(
) {
let example = NamedExample::load(args.example_path).unwrap();
let project = example.setup_project(app_state, cx).await.unwrap();
- let zeta = setup_zeta(args.options.provider, &project, app_state, cx).unwrap();
+ let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
- let result = perform_predict(example, project, zeta, None, args.options, cx)
+ let result = perform_predict(example, project, store, None, args.options, cx)
.await
.unwrap();
result.write(args.format, std::io::stdout()).unwrap();
@@ -37,45 +37,50 @@ pub async fn run_predict(
print_run_data_dir(true, std::io::stdout().is_terminal());
}
-pub fn setup_zeta(
+pub fn setup_store(
provider: PredictionProvider,
project: &Entity<Project>,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
-) -> Result<Entity<Zeta>> {
- let zeta =
- cx.new(|cx| zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx))?;
+) -> Result<Entity<EditPredictionStore>> {
+ let store = cx.new(|cx| {
+ edit_prediction::EditPredictionStore::new(
+ app_state.client.clone(),
+ app_state.user_store.clone(),
+ cx,
+ )
+ })?;
- zeta.update(cx, |zeta, _cx| {
+ store.update(cx, |store, _cx| {
let model = match provider {
- PredictionProvider::Zeta1 => zeta::ZetaEditPredictionModel::Zeta1,
- PredictionProvider::Zeta2 => zeta::ZetaEditPredictionModel::Zeta2,
- PredictionProvider::Sweep => zeta::ZetaEditPredictionModel::Sweep,
+ PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
+ PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
+ PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
};
- zeta.set_edit_prediction_model(model);
+ store.set_edit_prediction_model(model);
})?;
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
cx.subscribe(&buffer_store, {
let project = project.clone();
- let zeta = zeta.clone();
+ let store = store.clone();
move |_, event, cx| match event {
BufferStoreEvent::BufferAdded(buffer) => {
- zeta.update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
+ store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
}
_ => {}
}
})?
.detach();
- anyhow::Ok(zeta)
+ anyhow::Ok(store)
}
pub async fn perform_predict(
example: NamedExample,
project: Entity<Project>,
- zeta: Entity<Zeta>,
+ store: Entity<EditPredictionStore>,
repetition_ix: Option<u16>,
options: PredictionOptions,
cx: &mut AsyncApp,
@@ -108,8 +113,8 @@ pub async fn perform_predict(
std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
.context("creating latest link")?;
- zeta.update(cx, |zeta, _cx| {
- zeta.with_eval_cache(Arc::new(RunCache {
+ store.update(cx, |store, _cx| {
+ store.with_eval_cache(Arc::new(RunCache {
example_run_dir: example_run_dir.clone(),
cache_mode,
}));
@@ -121,44 +126,43 @@ pub async fn perform_predict(
let prompt_format = options.zeta2.prompt_format;
- zeta.update(cx, |zeta, _cx| {
- let mut options = zeta.options().clone();
+ store.update(cx, |store, _cx| {
+ let mut options = store.options().clone();
options.prompt_format = prompt_format.into();
- zeta.set_options(options);
+ store.set_options(options);
})?;
let mut debug_task = gpui::Task::ready(Ok(()));
if options.provider == crate::PredictionProvider::Zeta2 {
- let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
+ let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
debug_task = cx.background_spawn({
let result = result.clone();
async move {
let mut start_time = None;
- let mut search_queries_generated_at = None;
- let mut search_queries_executed_at = None;
+ let mut retrieval_finished_at = None;
while let Some(event) = debug_rx.next().await {
match event {
- zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => {
+ edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
fs::write(
example_run_dir.join("search_prompt.md"),
&info.search_prompt,
)?;
}
- zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => {
- search_queries_generated_at = Some(info.timestamp);
- fs::write(
- example_run_dir.join("search_queries.json"),
- serde_json::to_string_pretty(&info.search_queries).unwrap(),
- )?;
- }
- zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => {
- search_queries_executed_at = Some(info.timestamp);
+ edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
+ retrieval_finished_at = Some(info.timestamp);
+ for (key, value) in &info.metadata {
+ if *key == "search_queries" {
+ fs::write(
+ example_run_dir.join("search_queries.json"),
+ value.as_bytes(),
+ )?;
+ }
+ }
}
- zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
- zeta::ZetaDebugInfo::EditPredictionRequested(request) => {
+ edit_prediction::DebugEvent::EditPredictionRequested(request) => {
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
let prompt = request.local_prompt.unwrap_or_default();
@@ -194,19 +198,16 @@ pub async fn perform_predict(
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
- let response = zeta::text_from_response(response).unwrap_or_default();
+ let response =
+ edit_prediction::open_ai_response::text_from_response(response)
+ .unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
let mut result = result.lock().unwrap();
result.generated_len = response.chars().count();
-
- result.planning_search_time =
- Some(search_queries_generated_at.unwrap() - start_time.unwrap());
- result.running_search_time = Some(
- search_queries_executed_at.unwrap()
- - search_queries_generated_at.unwrap(),
- );
+ result.retrieval_time =
+ retrieval_finished_at.unwrap() - start_time.unwrap();
result.prediction_time = prediction_finished_at - prediction_started_at;
result.total_time = prediction_finished_at - start_time.unwrap();
@@ -218,15 +219,14 @@ pub async fn perform_predict(
}
});
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
- })?
- .await?;
+ store.update(cx, |store, cx| {
+ store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
+ })?;
}
- let prediction = zeta
- .update(cx, |zeta, cx| {
- zeta.request_prediction(
+ let prediction = store
+ .update(cx, |store, cx| {
+ store.request_prediction(
&project,
&cursor_buffer,
cursor_anchor,
@@ -321,8 +321,7 @@ pub struct PredictionDetails {
pub diff: String,
pub excerpts: Vec<ActualExcerpt>,
pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
- pub planning_search_time: Option<Duration>,
- pub running_search_time: Option<Duration>,
+ pub retrieval_time: Duration,
pub prediction_time: Duration,
pub total_time: Duration,
pub run_example_dir: PathBuf,
@@ -336,8 +335,7 @@ impl PredictionDetails {
diff: Default::default(),
excerpts: Default::default(),
excerpts_text: Default::default(),
- planning_search_time: Default::default(),
- running_search_time: Default::default(),
+ retrieval_time: Default::default(),
prediction_time: Default::default(),
total_time: Default::default(),
run_example_dir,
@@ -357,28 +355,20 @@ impl PredictionDetails {
}
pub fn to_markdown(&self) -> String {
- let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
-
format!(
"## Excerpts\n\n\
{}\n\n\
## Prediction\n\n\
{}\n\n\
## Time\n\n\
- Planning searches: {}ms\n\
- Running searches: {}ms\n\
- Making Prediction: {}ms\n\n\
- -------------------\n\n\
- Total: {}ms\n\
- Inference: {}ms ({:.2}%)\n",
+ Retrieval: {}ms\n\
+ Prediction: {}ms\n\n\
+ Total: {}ms\n",
self.excerpts_text,
self.diff,
- self.planning_search_time.unwrap_or_default().as_millis(),
- self.running_search_time.unwrap_or_default().as_millis(),
+ self.retrieval_time.as_millis(),
self.prediction_time.as_millis(),
self.total_time.as_millis(),
- inference_time.as_millis(),
- (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
)
}
}
@@ -0,0 +1,89 @@
+use std::path::Path;
+
+use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
+
+#[derive(Debug, Clone, Default, clap::ValueEnum)]
+pub enum ContextType {
+ #[default]
+ CurrentFile,
+}
+
+const MAX_CONTEXT_SIZE: usize = 32768;
+
+pub fn collect_context(
+ context_type: &ContextType,
+ worktree_dir: &Path,
+ cursor: SourceLocation,
+) -> String {
+ let context = match context_type {
+ ContextType::CurrentFile => {
+ let file_path = worktree_dir.join(cursor.path.as_std_path());
+ let context = std::fs::read_to_string(&file_path).unwrap_or_default();
+
+ let context = add_special_tags(&context, worktree_dir, cursor);
+ context
+ }
+ };
+
+ let region_end_offset = context.find(TeacherModel::REGION_END);
+
+ if context.len() <= MAX_CONTEXT_SIZE {
+ return context;
+ }
+
+ if let Some(region_end_offset) = region_end_offset
+ && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
+ {
+ let to_truncate = context.len() - MAX_CONTEXT_SIZE;
+ format!(
+ "[...{} bytes truncated]\n{}\n",
+ to_truncate,
+ &context[to_truncate..]
+ )
+ } else {
+ format!(
+ "{}\n[...{} bytes truncated]\n",
+ &context[..MAX_CONTEXT_SIZE],
+ context.len() - MAX_CONTEXT_SIZE
+ )
+ }
+}
+
+/// Add <|editable_region_start/end|> tags
+fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
+ let path = worktree_dir.join(cursor.path.as_std_path());
+ let file = std::fs::read_to_string(&path).unwrap_or_default();
+ let lines = file.lines().collect::<Vec<_>>();
+ let cursor_row = cursor.point.row as usize;
+ let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
+ let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
+
+ let snippet = lines[start_line..end_line].join("\n");
+
+ if context.contains(&snippet) {
+ let mut cursor_line = lines[cursor_row].to_string();
+ cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
+
+ let mut snippet_with_tags_lines = vec![];
+ snippet_with_tags_lines.push(TeacherModel::REGION_START);
+ snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
+ snippet_with_tags_lines.push(&cursor_line);
+ snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
+ snippet_with_tags_lines.push(TeacherModel::REGION_END);
+ let snippet_with_tags = snippet_with_tags_lines.join("\n");
+
+ context.replace(&snippet, &snippet_with_tags)
+ } else {
+ log::warn!(
+ "Can't find area around the cursor in the context; proceeding without special tags"
+ );
+ context.to_string()
+ }
+}
+
+pub fn strip_special_tags(context: &str) -> String {
+ context
+ .replace(TeacherModel::REGION_START, "")
+ .replace(TeacherModel::REGION_END, "")
+ .replace(TeacherModel::USER_CURSOR, "")
+}
@@ -0,0 +1,94 @@
+use serde::Deserialize;
+use std::sync::Arc;
+
+use crate::{
+ DistillArguments,
+ example::Example,
+ source_location::SourceLocation,
+ training::{
+ context::ContextType,
+ llm_client::LlmClient,
+ teacher::{TeacherModel, TeacherOutput},
+ },
+};
+use anyhow::Result;
+use reqwest_client::ReqwestClient;
+
+#[derive(Debug, Deserialize)]
+pub struct SplitCommit {
+ repo_url: String,
+ commit_sha: String,
+ edit_history: String,
+ expected_patch: String,
+ cursor_position: String,
+}
+
+pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
+ let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
+ .expect("Failed to read split commit dataset")
+ .lines()
+ .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
+ .collect();
+
+ let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
+
+ let llm_client = if let Some(cache_path) = arguments.batch {
+ LlmClient::batch(&cache_path, http_client)?
+ } else {
+ LlmClient::plain(http_client)?
+ };
+
+ let mut teacher = TeacherModel::new(
+ "claude-sonnet-4-5".to_string(),
+ ContextType::CurrentFile,
+ llm_client,
+ );
+
+ let mut num_marked_for_batching = 0;
+
+ for commit in split_commits {
+ if let Some(distilled) = distill_one(&mut teacher, commit).await? {
+ println!("{}", serde_json::to_string(&distilled)?);
+ } else {
+ if num_marked_for_batching == 0 {
+ log::warn!("Marked for batching");
+ }
+ num_marked_for_batching += 1;
+ }
+ }
+
+ eprintln!(
+ "{} requests are marked for batching",
+ num_marked_for_batching
+ );
+ let llm_client = teacher.client;
+ llm_client.sync_batches().await?;
+
+ Ok(())
+}
+
+pub async fn distill_one(
+ teacher: &mut TeacherModel,
+ commit: SplitCommit,
+) -> Result<Option<TeacherOutput>> {
+ let cursor: SourceLocation = commit
+ .cursor_position
+ .parse()
+ .expect("Failed to parse cursor position");
+
+ let path = cursor.path.to_rel_path_buf();
+
+ let example = Example {
+ repository_url: commit.repo_url,
+ revision: commit.commit_sha,
+ uncommitted_diff: commit.edit_history.clone(),
+ cursor_path: path.as_std_path().to_path_buf(),
+ cursor_position: commit.cursor_position,
+ edit_history: commit.edit_history, // todo: trim
+ expected_patch: commit.expected_patch,
+ };
+
+ let prediction = teacher.predict(example).await;
+
+ prediction
+}
@@ -0,0 +1,417 @@
+use anthropic::{
+ ANTHROPIC_API_URL, Message, Request as AnthropicRequest, RequestContent,
+ Response as AnthropicResponse, Role, non_streaming_completion,
+};
+use anyhow::Result;
+use http_client::HttpClient;
+use indoc::indoc;
+use sqlez::bindable::Bind;
+use sqlez::bindable::StaticColumnCount;
+use sqlez_macros::sql;
+use std::hash::Hash;
+use std::hash::Hasher;
+use std::sync::Arc;
+
+pub struct PlainLlmClient {
+ http_client: Arc<dyn HttpClient>,
+ api_key: String,
+}
+
+impl PlainLlmClient {
+ fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
+ let api_key = std::env::var("ANTHROPIC_API_KEY")
+ .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
+ Ok(Self {
+ http_client,
+ api_key,
+ })
+ }
+
+ async fn generate(
+ &self,
+ model: String,
+ max_tokens: u64,
+ messages: Vec<Message>,
+ ) -> Result<AnthropicResponse> {
+ let request = AnthropicRequest {
+ model,
+ max_tokens,
+ messages,
+ tools: Vec::new(),
+ thinking: None,
+ tool_choice: None,
+ system: None,
+ metadata: None,
+ stop_sequences: Vec::new(),
+ temperature: None,
+ top_k: None,
+ top_p: None,
+ };
+
+ let response = non_streaming_completion(
+ self.http_client.as_ref(),
+ ANTHROPIC_API_URL,
+ &self.api_key,
+ request,
+ None,
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+ Ok(response)
+ }
+}
+
+pub struct BatchingLlmClient {
+ connection: sqlez::connection::Connection,
+ http_client: Arc<dyn HttpClient>,
+ api_key: String,
+}
+
+struct CacheRow {
+ request_hash: String,
+ request: Option<String>,
+ response: Option<String>,
+ batch_id: Option<String>,
+}
+
+impl StaticColumnCount for CacheRow {
+ fn column_count() -> usize {
+ 4
+ }
+}
+
+impl Bind for CacheRow {
+ fn bind(&self, statement: &sqlez::statement::Statement, start_index: i32) -> Result<i32> {
+ let next_index = statement.bind(&self.request_hash, start_index)?;
+ let next_index = statement.bind(&self.request, next_index)?;
+ let next_index = statement.bind(&self.response, next_index)?;
+ let next_index = statement.bind(&self.batch_id, next_index)?;
+ Ok(next_index)
+ }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct SerializableRequest {
+ model: String,
+ max_tokens: u64,
+ messages: Vec<SerializableMessage>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+struct SerializableMessage {
+ role: String,
+ content: String,
+}
+
+impl BatchingLlmClient {
+ fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
+ let api_key = std::env::var("ANTHROPIC_API_KEY")
+ .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
+
+ let connection = sqlez::connection::Connection::open_file(&cache_path);
+ let mut statement = sqlez::statement::Statement::prepare(
+ &connection,
+ indoc! {"
+ CREATE TABLE IF NOT EXISTS cache (
+ request_hash TEXT PRIMARY KEY,
+ request TEXT,
+ response TEXT,
+ batch_id TEXT
+ );
+ "},
+ )?;
+ statement.exec()?;
+ drop(statement);
+
+ Ok(Self {
+ connection,
+ http_client,
+ api_key,
+ })
+ }
+
+ pub fn lookup(
+ &self,
+ model: &str,
+ max_tokens: u64,
+ messages: &[Message],
+ ) -> Result<Option<AnthropicResponse>> {
+ let request_hash_str = Self::request_hash(model, max_tokens, messages);
+ let response: Vec<String> = self.connection.select_bound(
+ &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
+ )?(request_hash_str.as_str())?;
+ Ok(response
+ .into_iter()
+ .next()
+ .and_then(|text| serde_json::from_str(&text).ok()))
+ }
+
+ pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
+ let request_hash = Self::request_hash(model, max_tokens, messages);
+
+ let serializable_messages: Vec<SerializableMessage> = messages
+ .iter()
+ .map(|msg| SerializableMessage {
+ role: match msg.role {
+ Role::User => "user".to_string(),
+ Role::Assistant => "assistant".to_string(),
+ },
+ content: message_content_to_string(&msg.content),
+ })
+ .collect();
+
+ let serializable_request = SerializableRequest {
+ model: model.to_string(),
+ max_tokens,
+ messages: serializable_messages,
+ };
+
+ let request = Some(serde_json::to_string(&serializable_request)?);
+ let cache_row = CacheRow {
+ request_hash,
+ request,
+ response: None,
+ batch_id: None,
+ };
+ self.connection.exec_bound(sql!(
+ INSERT OR IGNORE INTO cache(request_hash, request, response, batch_id) VALUES (?, ?, ?, ?)))?(
+ cache_row,
+ )
+ }
+
+ async fn generate(
+ &self,
+ model: String,
+ max_tokens: u64,
+ messages: Vec<Message>,
+ ) -> Result<Option<AnthropicResponse>> {
+ let response = self.lookup(&model, max_tokens, &messages)?;
+ if let Some(response) = response {
+ return Ok(Some(response));
+ }
+
+ self.mark_for_batch(&model, max_tokens, &messages)?;
+
+ Ok(None)
+ }
+
+ /// Uploads pending requests as a new batch; downloads finished batches if any.
+ async fn sync_batches(&self) -> Result<()> {
+ self.upload_pending_requests().await?;
+ self.download_finished_batches().await
+ }
+
+ async fn download_finished_batches(&self) -> Result<()> {
+ let q = sql!(SELECT DISTINCT batch_id FROM cache WHERE batch_id IS NOT NULL AND response IS NULL);
+ let batch_ids: Vec<String> = self.connection.select(q)?()?;
+
+ for batch_id in batch_ids {
+ let batch_status = anthropic::batches::retrieve_batch(
+ self.http_client.as_ref(),
+ ANTHROPIC_API_URL,
+ &self.api_key,
+ &batch_id,
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+ log::info!(
+ "Batch {} status: {}",
+ batch_id,
+ batch_status.processing_status
+ );
+
+ if batch_status.processing_status == "ended" {
+ let results = anthropic::batches::retrieve_batch_results(
+ self.http_client.as_ref(),
+ ANTHROPIC_API_URL,
+ &self.api_key,
+ &batch_id,
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+ let mut success_count = 0;
+ for result in results {
+ let request_hash = result
+ .custom_id
+ .strip_prefix("req_hash_")
+ .unwrap_or(&result.custom_id)
+ .to_string();
+
+ match result.result {
+ anthropic::batches::BatchResult::Succeeded { message } => {
+ let response_json = serde_json::to_string(&message)?;
+ let q = sql!(UPDATE cache SET response = ? WHERE request_hash = ?);
+ self.connection.exec_bound(q)?((response_json, request_hash))?;
+ success_count += 1;
+ }
+ anthropic::batches::BatchResult::Errored { error } => {
+ log::error!("Batch request {} failed: {:?}", request_hash, error);
+ }
+ anthropic::batches::BatchResult::Canceled => {
+ log::warn!("Batch request {} was canceled", request_hash);
+ }
+ anthropic::batches::BatchResult::Expired => {
+ log::warn!("Batch request {} expired", request_hash);
+ }
+ }
+ }
+ log::info!("Uploaded {} successful requests", success_count);
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn upload_pending_requests(&self) -> Result<String> {
+ let q = sql!(
+ SELECT request_hash, request FROM cache WHERE batch_id IS NULL AND response IS NULL
+ );
+
+ let rows: Vec<(String, String)> = self.connection.select(q)?()?;
+
+ if rows.is_empty() {
+ return Ok(String::new());
+ }
+
+ let batch_requests = rows
+ .iter()
+ .map(|(hash, request_str)| {
+ let serializable_request: SerializableRequest =
+ serde_json::from_str(&request_str).unwrap();
+
+ let messages: Vec<Message> = serializable_request
+ .messages
+ .into_iter()
+ .map(|msg| Message {
+ role: match msg.role.as_str() {
+ "user" => Role::User,
+ "assistant" => Role::Assistant,
+ _ => Role::User,
+ },
+ content: vec![RequestContent::Text {
+ text: msg.content,
+ cache_control: None,
+ }],
+ })
+ .collect();
+
+ let params = AnthropicRequest {
+ model: serializable_request.model,
+ max_tokens: serializable_request.max_tokens,
+ messages,
+ tools: Vec::new(),
+ thinking: None,
+ tool_choice: None,
+ system: None,
+ metadata: None,
+ stop_sequences: Vec::new(),
+ temperature: None,
+ top_k: None,
+ top_p: None,
+ };
+
+ let custom_id = format!("req_hash_{}", hash);
+ anthropic::batches::BatchRequest { custom_id, params }
+ })
+ .collect::<Vec<_>>();
+
+ let batch_len = batch_requests.len();
+ let batch = anthropic::batches::create_batch(
+ self.http_client.as_ref(),
+ ANTHROPIC_API_URL,
+ &self.api_key,
+ anthropic::batches::CreateBatchRequest {
+ requests: batch_requests,
+ },
+ )
+ .await
+ .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+ let q = sql!(
+ UPDATE cache SET batch_id = ? WHERE batch_id is NULL
+ );
+ self.connection.exec_bound(q)?(batch.id.as_str())?;
+
+ log::info!("Uploaded batch with {} requests", batch_len);
+
+ Ok(batch.id)
+ }
+
+ fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
+ let mut hasher = std::hash::DefaultHasher::new();
+ model.hash(&mut hasher);
+ max_tokens.hash(&mut hasher);
+ for msg in messages {
+ message_content_to_string(&msg.content).hash(&mut hasher);
+ }
+ let request_hash = hasher.finish();
+ format!("{request_hash:016x}")
+ }
+}
+
+fn message_content_to_string(content: &[RequestContent]) -> String {
+ content
+ .iter()
+ .filter_map(|c| match c {
+ RequestContent::Text { text, .. } => Some(text.clone()),
+ _ => None,
+ })
+ .collect::<Vec<String>>()
+ .join("\n")
+}
+
+pub enum LlmClient {
+ // No batching
+ Plain(PlainLlmClient),
+ Batch(BatchingLlmClient),
+ Dummy,
+}
+
+impl LlmClient {
+ pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
+ Ok(Self::Plain(PlainLlmClient::new(http_client)?))
+ }
+
+ pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
+ Ok(Self::Batch(BatchingLlmClient::new(
+ cache_path,
+ http_client,
+ )?))
+ }
+
+ #[allow(dead_code)]
+ pub fn dummy() -> Self {
+ Self::Dummy
+ }
+
+ pub async fn generate(
+ &self,
+ model: String,
+ max_tokens: u64,
+ messages: Vec<Message>,
+ ) -> Result<Option<AnthropicResponse>> {
+ match self {
+ LlmClient::Plain(plain_llm_client) => plain_llm_client
+ .generate(model, max_tokens, messages)
+ .await
+ .map(Some),
+ LlmClient::Batch(batching_llm_client) => {
+ batching_llm_client
+ .generate(model, max_tokens, messages)
+ .await
+ }
+ LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+ }
+ }
+
+ pub async fn sync_batches(&self) -> Result<()> {
+ match self {
+ LlmClient::Plain(_) => Ok(()),
+ LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
+ LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+ }
+ }
+}
@@ -0,0 +1,4 @@
+pub mod context;
+pub mod distill;
+pub mod llm_client;
+pub mod teacher;
@@ -0,0 +1,48 @@
+# Instructions
+
+You are a code completion assistant helping a programmer finish their work. Your task is to:
+
+1. Analyze the edit history to understand what the programmer is trying to achieve
+2. Identify any incomplete refactoring or changes that need to be finished
+3. Make the remaining edits that a human programmer would logically make next (by rewriting the corresponding code sections)
+4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.
+
+Focus on:
+- Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
+- Completing any partially-applied changes across the codebase
+- Ensuring consistency with the programming style and patterns already established
+- Making edits that maintain or improve code quality
+- If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
+- Don't write a lot of code if you're not sure what to do
+
+Rules:
+- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
+- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
+
+Input format:
+- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.
+- Never modify the context code.
+- You also receive a code snippet between <|editable_region_start|> and <|editable_region_end|>. This is the editable region.
+- The cursor position is marked with <|user_cursor|>.
+
+Output format:
+- Return the entire editable region, applying any edits you make.
+- Remove the <|user_cursor|> marker.
+- Wrap the edited code in a block of exactly five backticks.
+
+Output example:
+`````
+ // `zed --askpass` Makes zed operate in nc/netcat mode for use with askpass
+ if let Some(socket) = &args.askpass {{
+ askpass::main(socket);
+ return Ok(());
+ }}
+`````
+
+## User Edits History
+
+{{edit_history}}
+
+## Code Context
+
+{{context}}
@@ -0,0 +1,266 @@
+use crate::{
+ example::Example,
+ source_location::SourceLocation,
+ training::{
+ context::{ContextType, collect_context, strip_special_tags},
+ llm_client::LlmClient,
+ },
+};
+use anthropic::{Message, RequestContent, ResponseContent, Role};
+use anyhow::Result;
+
+pub struct TeacherModel {
+ pub llm_name: String,
+ pub context: ContextType,
+ pub client: LlmClient,
+}
+
+#[derive(Debug, serde::Serialize)]
+pub struct TeacherOutput {
+ parsed_output: String,
+ prompt: String,
+ raw_llm_response: String,
+ context: String,
+ diff: String,
+}
+
+impl TeacherModel {
+ const PROMPT: &str = include_str!("teacher.prompt.md");
+ pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
+ pub(crate) const REGION_END: &str = "<|editable_region_end|>";
+ pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
+
+ /// Number of lines to include before the cursor position
+ pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
+
+ /// Number of lines to include after the cursor position
+ pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
+
+ /// Truncate edit history to this number of last lines
+ const MAX_HISTORY_LINES: usize = 128;
+
+ pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
+ TeacherModel {
+ llm_name,
+ context,
+ client,
+ }
+ }
+
+ pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
+ let name = input.unique_name();
+ let worktree_dir = input.setup_worktree(name).await?;
+ let cursor: SourceLocation = input
+ .cursor_position
+ .parse()
+ .expect("Failed to parse cursor position");
+
+ let context = collect_context(&self.context, &worktree_dir, cursor.clone());
+ let edit_history = Self::format_edit_history(&input.edit_history);
+
+ let prompt = Self::PROMPT
+ .replace("{{context}}", &context)
+ .replace("{{edit_history}}", &edit_history);
+
+ let messages = vec![Message {
+ role: Role::User,
+ content: vec![RequestContent::Text {
+ text: prompt.clone(),
+ cache_control: None,
+ }],
+ }];
+
+ let Some(response) = self
+ .client
+ .generate(self.llm_name.clone(), 16384, messages)
+ .await?
+ else {
+ return Ok(None);
+ };
+
+ let response_text = response
+ .content
+ .into_iter()
+ .filter_map(|content| match content {
+ ResponseContent::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<String>>()
+ .join("\n");
+
+ let parsed_output = self.parse_response(&response_text);
+
+ let original_editable_region = Self::extract_editable_region(&context);
+ let context_after_edit = context.replace(&original_editable_region, &parsed_output);
+ let context_after_edit = strip_special_tags(&context_after_edit);
+ let context_before_edit = strip_special_tags(&context);
+ let diff = language::unified_diff(&context_before_edit, &context_after_edit);
+
+ // zeta distill --batch batch_results.txt
+ // zeta distill
+ // 1. Run `zeta distill <2000 examples <- all examples>` for the first time
+ // - store LLM requests in a batch, don't actual send the request
+ // - send the batch (2000 requests) after all inputs are processed
+ // 2. `zeta send-batches`
+ // - upload the batch to Anthropic
+
+ // https://platform.claude.com/docs/en/build-with-claude/batch-processing
+ // https://crates.io/crates/anthropic-sdk-rust
+
+ // - poll for results
+ // - when ready, store results in cache (a database)
+ // 3. `zeta distill` again
+ // - use the cached results this time
+
+ Ok(Some(TeacherOutput {
+ parsed_output,
+ prompt,
+ raw_llm_response: response_text,
+ context,
+ diff,
+ }))
+ }
+
+ fn parse_response(&self, content: &str) -> String {
+ let codeblock = Self::extract_last_codeblock(content);
+ let editable_region = Self::extract_editable_region(&codeblock);
+
+ editable_region
+ }
+
+ /// Extract content from the last code-fenced block if any, or else return content as is
+ fn extract_last_codeblock(text: &str) -> String {
+ let mut last_block = None;
+ let mut search_start = 0;
+
+ while let Some(start) = text[search_start..].find("```") {
+ let start = start + search_start;
+ let bytes = text.as_bytes();
+ let mut backtick_end = start;
+
+ while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
+ backtick_end += 1;
+ }
+
+ let backtick_count = backtick_end - start;
+ let closing_backticks = "`".repeat(backtick_count);
+
+ if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
+ let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
+ last_block = Some(code_block.to_string());
+ search_start = backtick_end + end_pos + backtick_count;
+ } else {
+ break;
+ }
+ }
+
+ last_block.unwrap_or_else(|| text.to_string())
+ }
+
+ fn extract_editable_region(text: &str) -> String {
+ let start = text
+ .find(Self::REGION_START)
+ .map_or(0, |pos| pos + Self::REGION_START.len());
+ let end = text.find(Self::REGION_END).unwrap_or(text.len());
+
+ text[start..end].to_string()
+ }
+
+ /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
+ fn format_edit_history(edit_history: &str) -> String {
+ let lines = edit_history
+ .lines()
+ .filter(|&s| Self::is_content_line(s))
+ .collect::<Vec<_>>();
+
+ let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
+ &lines[lines.len() - Self::MAX_HISTORY_LINES..]
+ } else {
+ &lines
+ };
+ history_lines.join("\n")
+ }
+
+ fn is_content_line(s: &str) -> bool {
+ s.starts_with("-")
+ || s.starts_with("+")
+ || s.starts_with(" ")
+ || s.starts_with("---")
+ || s.starts_with("+++")
+ || s.starts_with("@@")
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_parse_response() {
+ let teacher = TeacherModel::new(
+ "test".to_string(),
+ ContextType::CurrentFile,
+ LlmClient::dummy(),
+ );
+ let response = "This is a test response.";
+ let parsed = teacher.parse_response(response);
+ assert_eq!(parsed, response.to_string());
+
+ let response = indoc::indoc! {"
+ Some thinking
+
+ `````
+ actual response
+ `````
+ "};
+ let parsed = teacher.parse_response(response);
+ assert_eq!(parsed, "actual response");
+ }
+
+ #[test]
+ fn test_extract_last_code_block() {
+ let text = indoc::indoc! {"
+ Some thinking
+
+ ```
+ first block
+ ```
+
+ `````
+ last block
+ `````
+ "};
+ let last_block = TeacherModel::extract_last_codeblock(text);
+ assert_eq!(last_block, "last block");
+ }
+
+ #[test]
+ fn test_extract_editable_region() {
+ let teacher = TeacherModel::new(
+ "test".to_string(),
+ ContextType::CurrentFile,
+ LlmClient::dummy(),
+ );
+ let response = indoc::indoc! {"
+ some lines
+ are
+ here
+ <|editable_region_start|>
+ one
+ two three
+
+ <|editable_region_end|>
+ more
+ lines here
+ "};
+ let parsed = teacher.parse_response(response);
+ assert_eq!(
+ parsed,
+ indoc::indoc! {"
+ one
+ two three
+
+ "}
+ );
+ }
+}
@@ -2,7 +2,8 @@ use anyhow::{Result, anyhow};
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AsyncApp, Entity, Task};
-use language::{Buffer, LanguageId, LanguageServerId, ParseStatus};
+use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
+use project::lsp_store::OpenLspBufferHandle;
use project::{Project, ProjectPath, Worktree};
use std::collections::HashSet;
use std::sync::Arc;
@@ -40,7 +41,7 @@ pub async fn open_buffer_with_language_server(
path: Arc<RelPath>,
ready_languages: &mut HashSet<LanguageId>,
cx: &mut AsyncApp,
-) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
+) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
@@ -50,6 +51,17 @@ pub async fn open_buffer_with_language_server(
)
})?;
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
+ let result = language_registry
+ .load_language_for_file_path(path.as_std_path())
+ .await;
+
+ if let Err(error) = result
+ && !error.is::<LanguageNotFound>()
+ {
+ anyhow::bail!(error);
+ }
+
let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
buffer.language().map(|language| language.id())
})?
@@ -57,9 +69,9 @@ pub async fn open_buffer_with_language_server(
return Err(anyhow!("No language for {}", path.display(path_style)));
};
- let log_prefix = path.display(path_style);
+ let log_prefix = format!("{} | ", path.display(path_style));
if !ready_languages.contains(&language_id) {
- wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
+ wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
ready_languages.insert(language_id);
}
@@ -95,7 +107,7 @@ pub fn wait_for_lang_server(
log_prefix: String,
cx: &mut AsyncApp,
) -> Task<Result<()>> {
- println!("{}โต Waiting for language server", log_prefix);
+ eprintln!("{}โต Waiting for language server", log_prefix);
let (mut tx, mut rx) = mpsc::channel(1);
@@ -137,7 +149,7 @@ pub fn wait_for_lang_server(
..
} = event
{
- println!("{}โฒ {message}", log_prefix)
+ eprintln!("{}โฒ {message}", log_prefix)
}
}
}),
@@ -162,7 +174,7 @@ pub fn wait_for_lang_server(
cx.spawn(async move |cx| {
if !has_lang_server {
// some buffers never have a language server, so this aborts quickly in that case.
- let timeout = cx.background_executor().timer(Duration::from_secs(5));
+ let timeout = cx.background_executor().timer(Duration::from_secs(500));
futures::select! {
_ = added_rx.next() => {},
_ = timeout.fuse() => {
@@ -173,7 +185,7 @@ pub fn wait_for_lang_server(
let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
let result = futures::select! {
_ = rx.next() => {
- println!("{}โ Language server idle", log_prefix);
+ eprintln!("{}โ Language server idle", log_prefix);
anyhow::Ok(())
},
_ = timeout.fuse() => {
@@ -12,41 +12,32 @@ workspace = true
path = "src/edit_prediction_context.rs"
[dependencies]
+parking_lot.workspace = true
anyhow.workspace = true
-arrayvec.workspace = true
cloud_llm_client.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
-hashbrown.workspace = true
-indoc.workspace = true
-itertools.workspace = true
language.workspace = true
-log.workspace = true
-ordered-float.workspace = true
-postage.workspace = true
+lsp.workspace = true
project.workspace = true
-regex.workspace = true
+log.workspace = true
serde.workspace = true
-slotmap.workspace = true
-strum.workspace = true
-text.workspace = true
+smallvec.workspace = true
tree-sitter.workspace = true
util.workspace = true
[dev-dependencies]
-clap.workspace = true
+env_logger.workspace = true
+indoc.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["test-support"] }
-indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
+lsp = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = {workspace= true, features = ["test-support"]}
serde_json.workspace = true
settings = {workspace= true, features = ["test-support"]}
text = { workspace = true, features = ["test-support"] }
-tree-sitter-c.workspace = true
-tree-sitter-cpp.workspace = true
-tree-sitter-go.workspace = true
util = { workspace = true, features = ["test-support"] }
zlog.workspace = true
@@ -0,0 +1,161 @@
+use crate::RelatedExcerpt;
+use language::{BufferSnapshot, OffsetRangeExt as _, Point};
+use std::ops::Range;
+
+#[cfg(not(test))]
+const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
+#[cfg(test)]
+const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 24;
+
+pub fn assemble_excerpts(
+ buffer: &BufferSnapshot,
+ mut input_ranges: Vec<Range<Point>>,
+) -> Vec<RelatedExcerpt> {
+ merge_ranges(&mut input_ranges);
+
+ let mut outline_ranges = Vec::new();
+ let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
+ let mut outline_ix = 0;
+ for input_range in &mut input_ranges {
+ *input_range = clip_range_to_lines(input_range, false, buffer);
+
+ while let Some(outline_item) = outline_items.get(outline_ix) {
+ let item_range = clip_range_to_lines(&outline_item.range, false, buffer);
+
+ if item_range.start > input_range.start {
+ break;
+ }
+
+ if item_range.end > input_range.start {
+ let body_range = outline_item
+ .body_range(buffer)
+ .map(|body| clip_range_to_lines(&body, true, buffer))
+ .filter(|body_range| {
+ body_range.to_offset(buffer).len() > MAX_OUTLINE_ITEM_BODY_SIZE
+ });
+
+ add_outline_item(
+ item_range.clone(),
+ body_range.clone(),
+ buffer,
+ &mut outline_ranges,
+ );
+
+ if let Some(body_range) = body_range
+ && input_range.start < body_range.start
+ {
+ let mut child_outline_ix = outline_ix + 1;
+ while let Some(next_outline_item) = outline_items.get(child_outline_ix) {
+ if next_outline_item.range.end > body_range.end {
+ break;
+ }
+ if next_outline_item.depth == outline_item.depth + 1 {
+ let next_item_range =
+ clip_range_to_lines(&next_outline_item.range, false, buffer);
+
+ add_outline_item(
+ next_item_range,
+ next_outline_item
+ .body_range(buffer)
+ .map(|body| clip_range_to_lines(&body, true, buffer)),
+ buffer,
+ &mut outline_ranges,
+ );
+ }
+ child_outline_ix += 1;
+ }
+ }
+ }
+
+ outline_ix += 1;
+ }
+ }
+
+ input_ranges.extend_from_slice(&outline_ranges);
+ merge_ranges(&mut input_ranges);
+
+ input_ranges
+ .into_iter()
+ .map(|range| {
+ let offset_range = range.to_offset(buffer);
+ RelatedExcerpt {
+ point_range: range,
+ anchor_range: buffer.anchor_before(offset_range.start)
+ ..buffer.anchor_after(offset_range.end),
+ text: buffer.as_rope().slice(offset_range),
+ }
+ })
+ .collect()
+}
+
+fn clip_range_to_lines(
+ range: &Range<Point>,
+ inward: bool,
+ buffer: &BufferSnapshot,
+) -> Range<Point> {
+ let mut range = range.clone();
+ if inward {
+ if range.start.column > 0 {
+ range.start.column = buffer.line_len(range.start.row);
+ }
+ range.end.column = 0;
+ } else {
+ range.start.column = 0;
+ if range.end.column > 0 {
+ range.end.column = buffer.line_len(range.end.row);
+ }
+ }
+ range
+}
+
+fn add_outline_item(
+ mut item_range: Range<Point>,
+ body_range: Option<Range<Point>>,
+ buffer: &BufferSnapshot,
+ outline_ranges: &mut Vec<Range<Point>>,
+) {
+ if let Some(mut body_range) = body_range {
+ if body_range.start.column > 0 {
+ body_range.start.column = buffer.line_len(body_range.start.row);
+ }
+ body_range.end.column = 0;
+
+ let head_range = item_range.start..body_range.start;
+ if head_range.start < head_range.end {
+ outline_ranges.push(head_range);
+ }
+
+ let tail_range = body_range.end..item_range.end;
+ if tail_range.start < tail_range.end {
+ outline_ranges.push(tail_range);
+ }
+ } else {
+ item_range.start.column = 0;
+ item_range.end.column = buffer.line_len(item_range.end.row);
+ outline_ranges.push(item_range);
+ }
+}
+
+pub fn merge_ranges(ranges: &mut Vec<Range<Point>>) {
+ ranges.sort_unstable_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
+
+ let mut index = 1;
+ while index < ranges.len() {
+ let mut prev_range_end = ranges[index - 1].end;
+ if prev_range_end.column > 0 {
+ prev_range_end += Point::new(1, 0);
+ }
+
+ if (prev_range_end + Point::new(1, 0))
+ .cmp(&ranges[index].start)
+ .is_ge()
+ {
+ let removed = ranges.remove(index);
+ if removed.end.cmp(&ranges[index - 1].end).is_gt() {
+ ranges[index - 1].end = removed.end;
+ }
+ } else {
+ index += 1;
+ }
+ }
+}
@@ -1,350 +0,0 @@
-use cloud_llm_client::predict_edits_v3::{self, Line};
-use language::{Language, LanguageId};
-use project::ProjectEntryId;
-use std::ops::Range;
-use std::sync::Arc;
-use std::{borrow::Cow, path::Path};
-use text::{Bias, BufferId, Rope};
-use util::paths::{path_ends_with, strip_path_suffix};
-use util::rel_path::RelPath;
-
-use crate::outline::OutlineDeclaration;
-
-#[derive(Debug, Clone, Eq, PartialEq, Hash)]
-pub struct Identifier {
- pub name: Arc<str>,
- pub language_id: LanguageId,
-}
-
-slotmap::new_key_type! {
- pub struct DeclarationId;
-}
-
-#[derive(Debug, Clone)]
-pub enum Declaration {
- File {
- project_entry_id: ProjectEntryId,
- declaration: FileDeclaration,
- cached_path: CachedDeclarationPath,
- },
- Buffer {
- project_entry_id: ProjectEntryId,
- buffer_id: BufferId,
- rope: Rope,
- declaration: BufferDeclaration,
- cached_path: CachedDeclarationPath,
- },
-}
-
-const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
-
-impl Declaration {
- pub fn identifier(&self) -> &Identifier {
- match self {
- Declaration::File { declaration, .. } => &declaration.identifier,
- Declaration::Buffer { declaration, .. } => &declaration.identifier,
- }
- }
-
- pub fn parent(&self) -> Option<DeclarationId> {
- match self {
- Declaration::File { declaration, .. } => declaration.parent,
- Declaration::Buffer { declaration, .. } => declaration.parent,
- }
- }
-
- pub fn as_buffer(&self) -> Option<&BufferDeclaration> {
- match self {
- Declaration::File { .. } => None,
- Declaration::Buffer { declaration, .. } => Some(declaration),
- }
- }
-
- pub fn as_file(&self) -> Option<&FileDeclaration> {
- match self {
- Declaration::Buffer { .. } => None,
- Declaration::File { declaration, .. } => Some(declaration),
- }
- }
-
- pub fn project_entry_id(&self) -> ProjectEntryId {
- match self {
- Declaration::File {
- project_entry_id, ..
- } => *project_entry_id,
- Declaration::Buffer {
- project_entry_id, ..
- } => *project_entry_id,
- }
- }
-
- pub fn cached_path(&self) -> &CachedDeclarationPath {
- match self {
- Declaration::File { cached_path, .. } => cached_path,
- Declaration::Buffer { cached_path, .. } => cached_path,
- }
- }
-
- pub fn item_range(&self) -> Range<usize> {
- match self {
- Declaration::File { declaration, .. } => declaration.item_range.clone(),
- Declaration::Buffer { declaration, .. } => declaration.item_range.clone(),
- }
- }
-
- pub fn item_line_range(&self) -> Range<Line> {
- match self {
- Declaration::File { declaration, .. } => declaration.item_line_range.clone(),
- Declaration::Buffer {
- declaration, rope, ..
- } => {
- Line(rope.offset_to_point(declaration.item_range.start).row)
- ..Line(rope.offset_to_point(declaration.item_range.end).row)
- }
- }
- }
-
- pub fn item_text(&self) -> (Cow<'_, str>, bool) {
- match self {
- Declaration::File { declaration, .. } => (
- declaration.text.as_ref().into(),
- declaration.text_is_truncated,
- ),
- Declaration::Buffer {
- rope, declaration, ..
- } => (
- rope.chunks_in_range(declaration.item_range.clone())
- .collect::<Cow<str>>(),
- declaration.item_range_is_truncated,
- ),
- }
- }
-
- pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
- match self {
- Declaration::File { declaration, .. } => (
- declaration.text[self.signature_range_in_item_text()].into(),
- declaration.signature_is_truncated,
- ),
- Declaration::Buffer {
- rope, declaration, ..
- } => (
- rope.chunks_in_range(declaration.signature_range.clone())
- .collect::<Cow<str>>(),
- declaration.signature_range_is_truncated,
- ),
- }
- }
-
- pub fn signature_range(&self) -> Range<usize> {
- match self {
- Declaration::File { declaration, .. } => declaration.signature_range.clone(),
- Declaration::Buffer { declaration, .. } => declaration.signature_range.clone(),
- }
- }
-
- pub fn signature_line_range(&self) -> Range<Line> {
- match self {
- Declaration::File { declaration, .. } => declaration.signature_line_range.clone(),
- Declaration::Buffer {
- declaration, rope, ..
- } => {
- Line(rope.offset_to_point(declaration.signature_range.start).row)
- ..Line(rope.offset_to_point(declaration.signature_range.end).row)
- }
- }
- }
-
- pub fn signature_range_in_item_text(&self) -> Range<usize> {
- let signature_range = self.signature_range();
- let item_range = self.item_range();
- signature_range.start.saturating_sub(item_range.start)
- ..(signature_range.end.saturating_sub(item_range.start)).min(item_range.len())
- }
-}
-
-fn expand_range_to_line_boundaries_and_truncate(
- range: &Range<usize>,
- limit: usize,
- rope: &Rope,
-) -> (Range<usize>, Range<predict_edits_v3::Line>, bool) {
- let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
- point_range.start.column = 0;
- point_range.end.row += 1;
- point_range.end.column = 0;
-
- let mut item_range =
- rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
- let is_truncated = item_range.len() > limit;
- if is_truncated {
- item_range.end = item_range.start + limit;
- }
- item_range.end = rope.clip_offset(item_range.end, Bias::Left);
-
- let line_range =
- predict_edits_v3::Line(point_range.start.row)..predict_edits_v3::Line(point_range.end.row);
- (item_range, line_range, is_truncated)
-}
-
-#[derive(Debug, Clone)]
-pub struct FileDeclaration {
- pub parent: Option<DeclarationId>,
- pub identifier: Identifier,
- /// offset range of the declaration in the file, expanded to line boundaries and truncated
- pub item_range: Range<usize>,
- /// line range of the declaration in the file, potentially truncated
- pub item_line_range: Range<predict_edits_v3::Line>,
- /// text of `item_range`
- pub text: Arc<str>,
- /// whether `text` was truncated
- pub text_is_truncated: bool,
- /// offset range of the signature in the file, expanded to line boundaries and truncated
- pub signature_range: Range<usize>,
- /// line range of the signature in the file, truncated
- pub signature_line_range: Range<Line>,
- /// whether `signature` was truncated
- pub signature_is_truncated: bool,
-}
-
-impl FileDeclaration {
- pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
- let (item_range_in_file, item_line_range_in_file, text_is_truncated) =
- expand_range_to_line_boundaries_and_truncate(
- &declaration.item_range,
- ITEM_TEXT_TRUNCATION_LENGTH,
- rope,
- );
-
- let (mut signature_range_in_file, signature_line_range, mut signature_is_truncated) =
- expand_range_to_line_boundaries_and_truncate(
- &declaration.signature_range,
- ITEM_TEXT_TRUNCATION_LENGTH,
- rope,
- );
-
- if signature_range_in_file.start < item_range_in_file.start {
- signature_range_in_file.start = item_range_in_file.start;
- signature_is_truncated = true;
- }
- if signature_range_in_file.end > item_range_in_file.end {
- signature_range_in_file.end = item_range_in_file.end;
- signature_is_truncated = true;
- }
-
- FileDeclaration {
- parent: None,
- identifier: declaration.identifier,
- signature_range: signature_range_in_file,
- signature_line_range,
- signature_is_truncated,
- text: rope
- .chunks_in_range(item_range_in_file.clone())
- .collect::<String>()
- .into(),
- text_is_truncated,
- item_range: item_range_in_file,
- item_line_range: item_line_range_in_file,
- }
- }
-}
-
-#[derive(Debug, Clone)]
-pub struct BufferDeclaration {
- pub parent: Option<DeclarationId>,
- pub identifier: Identifier,
- pub item_range: Range<usize>,
- pub item_range_is_truncated: bool,
- pub signature_range: Range<usize>,
- pub signature_range_is_truncated: bool,
-}
-
-impl BufferDeclaration {
- pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
- let (item_range, _item_line_range, item_range_is_truncated) =
- expand_range_to_line_boundaries_and_truncate(
- &declaration.item_range,
- ITEM_TEXT_TRUNCATION_LENGTH,
- rope,
- );
- let (signature_range, _signature_line_range, signature_range_is_truncated) =
- expand_range_to_line_boundaries_and_truncate(
- &declaration.signature_range,
- ITEM_TEXT_TRUNCATION_LENGTH,
- rope,
- );
- Self {
- parent: None,
- identifier: declaration.identifier,
- item_range,
- item_range_is_truncated,
- signature_range,
- signature_range_is_truncated,
- }
- }
-}
-
-#[derive(Debug, Clone)]
-pub struct CachedDeclarationPath {
- pub worktree_abs_path: Arc<Path>,
- pub rel_path: Arc<RelPath>,
- /// The relative path of the file, possibly stripped according to `import_path_strip_regex`.
- pub rel_path_after_regex_stripping: Arc<RelPath>,
-}
-
-impl CachedDeclarationPath {
- pub fn new(
- worktree_abs_path: Arc<Path>,
- path: &Arc<RelPath>,
- language: Option<&Arc<Language>>,
- ) -> Self {
- let rel_path = path.clone();
- let rel_path_after_regex_stripping = if let Some(language) = language
- && let Some(strip_regex) = language.config().import_path_strip_regex.as_ref()
- && let Ok(stripped) = RelPath::unix(&Path::new(
- strip_regex.replace_all(rel_path.as_unix_str(), "").as_ref(),
- )) {
- Arc::from(stripped)
- } else {
- rel_path.clone()
- };
- CachedDeclarationPath {
- worktree_abs_path,
- rel_path,
- rel_path_after_regex_stripping,
- }
- }
-
- #[cfg(test)]
- pub fn new_for_test(worktree_abs_path: &str, rel_path: &str) -> Self {
- let rel_path: Arc<RelPath> = util::rel_path::rel_path(rel_path).into();
- CachedDeclarationPath {
- worktree_abs_path: std::path::PathBuf::from(worktree_abs_path).into(),
- rel_path_after_regex_stripping: rel_path.clone(),
- rel_path,
- }
- }
-
- pub fn ends_with_posix_path(&self, path: &Path) -> bool {
- if path.as_os_str().len() <= self.rel_path_after_regex_stripping.as_unix_str().len() {
- path_ends_with(self.rel_path_after_regex_stripping.as_std_path(), path)
- } else {
- if let Some(remaining) =
- strip_path_suffix(path, self.rel_path_after_regex_stripping.as_std_path())
- {
- path_ends_with(&self.worktree_abs_path, remaining)
- } else {
- false
- }
- }
- }
-
- pub fn equals_absolute_path(&self, path: &Path) -> bool {
- if let Some(remaining) =
- strip_path_suffix(path, &self.rel_path_after_regex_stripping.as_std_path())
- {
- self.worktree_abs_path.as_ref() == remaining
- } else {
- false
- }
- }
-}
@@ -1,539 +0,0 @@
-use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
-use collections::HashMap;
-use language::BufferSnapshot;
-use ordered_float::OrderedFloat;
-use project::ProjectEntryId;
-use serde::Serialize;
-use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
-use strum::EnumIter;
-use text::{Point, ToPoint};
-use util::RangeExt as _;
-
-use crate::{
- CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
- imports::{Import, Imports, Module},
- reference::{Reference, ReferenceRegion},
- syntax_index::SyntaxIndexState,
- text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
-};
-
-const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
-
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub struct EditPredictionScoreOptions {
- pub omit_excerpt_overlaps: bool,
-}
-
-#[derive(Clone, Debug)]
-pub struct ScoredDeclaration {
- /// identifier used by the local reference
- pub identifier: Identifier,
- pub declaration: Declaration,
- pub components: DeclarationScoreComponents,
-}
-
-#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
-pub enum DeclarationStyle {
- Signature,
- Declaration,
-}
-
-#[derive(Clone, Debug, Serialize, Default)]
-pub struct DeclarationScores {
- pub signature: f32,
- pub declaration: f32,
- pub retrieval: f32,
-}
-
-impl ScoredDeclaration {
- /// Returns the score for this declaration with the specified style.
- pub fn score(&self, style: DeclarationStyle) -> f32 {
- // TODO: handle truncation
-
- // Score related to how likely this is the correct declaration, range 0 to 1
- let retrieval = self.retrieval_score();
-
- // Score related to the distance between the reference and cursor, range 0 to 1
- let distance_score = if self.components.is_referenced_nearby {
- 1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
- } else {
- // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
- 0.5
- };
-
- // For now instead of linear combination, the scores are just multiplied together.
- let combined_score = 10.0 * retrieval * distance_score;
-
- match style {
- DeclarationStyle::Signature => {
- combined_score * self.components.excerpt_vs_signature_weighted_overlap
- }
- DeclarationStyle::Declaration => {
- 2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
- }
- }
- }
-
- pub fn retrieval_score(&self) -> f32 {
- let mut score = if self.components.is_same_file {
- 10.0 / self.components.same_file_declaration_count as f32
- } else if self.components.path_import_match_count > 0 {
- 3.0
- } else if self.components.wildcard_path_import_match_count > 0 {
- 1.0
- } else if self.components.normalized_import_similarity > 0.0 {
- self.components.normalized_import_similarity
- } else if self.components.normalized_wildcard_import_similarity > 0.0 {
- 0.5 * self.components.normalized_wildcard_import_similarity
- } else {
- 1.0 / self.components.declaration_count as f32
- };
- score *= 1. + self.components.included_by_others as f32 / 2.;
- score *= 1. + self.components.includes_others as f32 / 4.;
- score
- }
-
- pub fn size(&self, style: DeclarationStyle) -> usize {
- match &self.declaration {
- Declaration::File { declaration, .. } => match style {
- DeclarationStyle::Signature => declaration.signature_range.len(),
- DeclarationStyle::Declaration => declaration.text.len(),
- },
- Declaration::Buffer { declaration, .. } => match style {
- DeclarationStyle::Signature => declaration.signature_range.len(),
- DeclarationStyle::Declaration => declaration.item_range.len(),
- },
- }
- }
-
- pub fn score_density(&self, style: DeclarationStyle) -> f32 {
- self.score(style) / self.size(style) as f32
- }
-}
-
-pub fn scored_declarations(
- options: &EditPredictionScoreOptions,
- index: &SyntaxIndexState,
- excerpt: &EditPredictionExcerpt,
- excerpt_occurrences: &Occurrences,
- adjacent_occurrences: &Occurrences,
- imports: &Imports,
- identifier_to_references: HashMap<Identifier, Vec<Reference>>,
- cursor_offset: usize,
- current_buffer: &BufferSnapshot,
-) -> Vec<ScoredDeclaration> {
- let cursor_point = cursor_offset.to_point(¤t_buffer);
-
- let mut wildcard_import_occurrences = Vec::new();
- let mut wildcard_import_paths = Vec::new();
- for wildcard_import in imports.wildcard_modules.iter() {
- match wildcard_import {
- Module::Namespace(namespace) => {
- wildcard_import_occurrences.push(namespace.occurrences())
- }
- Module::SourceExact(path) => wildcard_import_paths.push(path),
- Module::SourceFuzzy(path) => {
- wildcard_import_occurrences.push(Occurrences::from_path(&path))
- }
- }
- }
-
- let mut scored_declarations = Vec::new();
- let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
- HashMap::default();
- for (identifier, references) in identifier_to_references {
- let mut import_occurrences = Vec::new();
- let mut import_paths = Vec::new();
- let mut found_external_identifier: Option<&Identifier> = None;
-
- if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
- // only use alias when it's the only import, could be generalized if some language
- // has overlapping aliases
- //
- // TODO: when an aliased declaration is included in the prompt, should include the
- // aliasing in the prompt.
- //
- // TODO: For SourceFuzzy consider having componentwise comparison that pays
- // attention to ordering.
- if let [
- Import::Alias {
- module,
- external_identifier,
- },
- ] = imports.as_slice()
- {
- match module {
- Module::Namespace(namespace) => {
- import_occurrences.push(namespace.occurrences())
- }
- Module::SourceExact(path) => import_paths.push(path),
- Module::SourceFuzzy(path) => {
- import_occurrences.push(Occurrences::from_path(&path))
- }
- }
- found_external_identifier = Some(&external_identifier);
- } else {
- for import in imports {
- match import {
- Import::Direct { module } => match module {
- Module::Namespace(namespace) => {
- import_occurrences.push(namespace.occurrences())
- }
- Module::SourceExact(path) => import_paths.push(path),
- Module::SourceFuzzy(path) => {
- import_occurrences.push(Occurrences::from_path(&path))
- }
- },
- Import::Alias { .. } => {}
- }
- }
- }
- }
-
- let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
- // TODO: update this to be able to return more declarations? Especially if there is the
- // ability to quickly filter a large list (based on imports)
- let identifier_declarations = index
- .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
- let declaration_count = identifier_declarations.len();
-
- if declaration_count == 0 {
- continue;
- }
-
- // TODO: option to filter out other candidates when same file / import match
- let mut checked_declarations = Vec::with_capacity(declaration_count);
- for (declaration_id, declaration) in identifier_declarations {
- match declaration {
- Declaration::Buffer {
- buffer_id,
- declaration: buffer_declaration,
- ..
- } => {
- if buffer_id == ¤t_buffer.remote_id() {
- let already_included_in_prompt =
- range_intersection(&buffer_declaration.item_range, &excerpt.range)
- .is_some()
- || excerpt
- .parent_declarations
- .iter()
- .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
- if !options.omit_excerpt_overlaps || !already_included_in_prompt {
- let declaration_line = buffer_declaration
- .item_range
- .start
- .to_point(current_buffer)
- .row;
- let declaration_line_distance =
- (cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
- checked_declarations.push(CheckedDeclaration {
- declaration,
- same_file_line_distance: Some(declaration_line_distance),
- path_import_match_count: 0,
- wildcard_path_import_match_count: 0,
- });
- }
- continue;
- } else {
- }
- }
- Declaration::File { .. } => {}
- }
- let declaration_path = declaration.cached_path();
- let path_import_match_count = import_paths
- .iter()
- .filter(|import_path| {
- declaration_path_matches_import(&declaration_path, import_path)
- })
- .count();
- let wildcard_path_import_match_count = wildcard_import_paths
- .iter()
- .filter(|import_path| {
- declaration_path_matches_import(&declaration_path, import_path)
- })
- .count();
- checked_declarations.push(CheckedDeclaration {
- declaration,
- same_file_line_distance: None,
- path_import_match_count,
- wildcard_path_import_match_count,
- });
- }
-
- let mut max_import_similarity = 0.0;
- let mut max_wildcard_import_similarity = 0.0;
-
- let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
- for checked_declaration in checked_declarations {
- let same_file_declaration_count =
- index.file_declaration_count(checked_declaration.declaration);
-
- let declaration = score_declaration(
- &identifier,
- &references,
- checked_declaration,
- same_file_declaration_count,
- declaration_count,
- &excerpt_occurrences,
- &adjacent_occurrences,
- &import_occurrences,
- &wildcard_import_occurrences,
- cursor_point,
- current_buffer,
- );
-
- if declaration.components.import_similarity > max_import_similarity {
- max_import_similarity = declaration.components.import_similarity;
- }
-
- if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
- max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
- }
-
- project_entry_id_to_outline_ranges
- .entry(declaration.declaration.project_entry_id())
- .or_default()
- .push(declaration.declaration.item_range());
- scored_declarations_for_identifier.push(declaration);
- }
-
- if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
- for declaration in scored_declarations_for_identifier.iter_mut() {
- if max_import_similarity > 0.0 {
- declaration.components.max_import_similarity = max_import_similarity;
- declaration.components.normalized_import_similarity =
- declaration.components.import_similarity / max_import_similarity;
- }
- if max_wildcard_import_similarity > 0.0 {
- declaration.components.normalized_wildcard_import_similarity =
- declaration.components.wildcard_import_similarity
- / max_wildcard_import_similarity;
- }
- }
- }
-
- scored_declarations.extend(scored_declarations_for_identifier);
- }
-
- // TODO: Inform this via import / retrieval scores of outline items
- // TODO: Consider using a sweepline
- for scored_declaration in scored_declarations.iter_mut() {
- let project_entry_id = scored_declaration.declaration.project_entry_id();
- let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
- continue;
- };
- for range in ranges {
- if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
- scored_declaration.components.included_by_others += 1
- } else if scored_declaration
- .declaration
- .item_range()
- .contains_inclusive(range)
- {
- scored_declaration.components.includes_others += 1
- }
- }
- }
-
- scored_declarations.sort_unstable_by_key(|declaration| {
- Reverse(OrderedFloat(
- declaration.score(DeclarationStyle::Declaration),
- ))
- });
-
- scored_declarations
-}
-
-struct CheckedDeclaration<'a> {
- declaration: &'a Declaration,
- same_file_line_distance: Option<u32>,
- path_import_match_count: usize,
- wildcard_path_import_match_count: usize,
-}
-
-fn declaration_path_matches_import(
- declaration_path: &CachedDeclarationPath,
- import_path: &Arc<Path>,
-) -> bool {
- if import_path.is_absolute() {
- declaration_path.equals_absolute_path(import_path)
- } else {
- declaration_path.ends_with_posix_path(import_path)
- }
-}
-
-fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
- let start = a.start.clone().max(b.start.clone());
- let end = a.end.clone().min(b.end.clone());
- if start < end {
- Some(Range { start, end })
- } else {
- None
- }
-}
-
-fn score_declaration(
- identifier: &Identifier,
- references: &[Reference],
- checked_declaration: CheckedDeclaration,
- same_file_declaration_count: usize,
- declaration_count: usize,
- excerpt_occurrences: &Occurrences,
- adjacent_occurrences: &Occurrences,
- import_occurrences: &[Occurrences],
- wildcard_import_occurrences: &[Occurrences],
- cursor: Point,
- current_buffer: &BufferSnapshot,
-) -> ScoredDeclaration {
- let CheckedDeclaration {
- declaration,
- same_file_line_distance,
- path_import_match_count,
- wildcard_path_import_match_count,
- } = checked_declaration;
-
- let is_referenced_nearby = references
- .iter()
- .any(|r| r.region == ReferenceRegion::Nearby);
- let is_referenced_in_breadcrumb = references
- .iter()
- .any(|r| r.region == ReferenceRegion::Breadcrumb);
- let reference_count = references.len();
- let reference_line_distance = references
- .iter()
- .map(|r| {
- let reference_line = r.range.start.to_point(current_buffer).row as i32;
- (cursor.row as i32 - reference_line).unsigned_abs()
- })
- .min()
- .unwrap();
-
- let is_same_file = same_file_line_distance.is_some();
- let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
-
- let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
- let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
- let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
- let excerpt_vs_signature_jaccard =
- jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
- let adjacent_vs_item_jaccard =
- jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
- let adjacent_vs_signature_jaccard =
- jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
-
- let excerpt_vs_item_weighted_overlap =
- weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
- let excerpt_vs_signature_weighted_overlap =
- weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
- let adjacent_vs_item_weighted_overlap =
- weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
- let adjacent_vs_signature_weighted_overlap =
- weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
-
- let mut import_similarity = 0f32;
- let mut wildcard_import_similarity = 0f32;
- if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
- let cached_path = declaration.cached_path();
- let path_occurrences = Occurrences::from_worktree_path(
- cached_path
- .worktree_abs_path
- .file_name()
- .map(|f| f.to_string_lossy()),
- &cached_path.rel_path,
- );
- import_similarity = import_occurrences
- .iter()
- .map(|namespace_occurrences| {
- OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
- })
- .max()
- .map(|similarity| similarity.into_inner())
- .unwrap_or_default();
-
- // TODO: Consider something other than max
- wildcard_import_similarity = wildcard_import_occurrences
- .iter()
- .map(|namespace_occurrences| {
- OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
- })
- .max()
- .map(|similarity| similarity.into_inner())
- .unwrap_or_default();
- }
-
- // TODO: Consider adding declaration_file_count
- let score_components = DeclarationScoreComponents {
- is_same_file,
- is_referenced_nearby,
- is_referenced_in_breadcrumb,
- reference_line_distance,
- declaration_line_distance,
- reference_count,
- same_file_declaration_count,
- declaration_count,
- excerpt_vs_item_jaccard,
- excerpt_vs_signature_jaccard,
- adjacent_vs_item_jaccard,
- adjacent_vs_signature_jaccard,
- excerpt_vs_item_weighted_overlap,
- excerpt_vs_signature_weighted_overlap,
- adjacent_vs_item_weighted_overlap,
- adjacent_vs_signature_weighted_overlap,
- path_import_match_count,
- wildcard_path_import_match_count,
- import_similarity,
- max_import_similarity: 0.0,
- normalized_import_similarity: 0.0,
- wildcard_import_similarity,
- normalized_wildcard_import_similarity: 0.0,
- included_by_others: 0,
- includes_others: 0,
- };
-
- ScoredDeclaration {
- identifier: identifier.clone(),
- declaration: declaration.clone(),
- components: score_components,
- }
-}
-
-#[cfg(test)]
-mod test {
- use super::*;
-
- #[test]
- fn test_declaration_path_matches() {
- let declaration_path =
- CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");
-
- assert!(declaration_path_matches_import(
- &declaration_path,
- &Path::new("maths.ts").into()
- ));
-
- assert!(declaration_path_matches_import(
- &declaration_path,
- &Path::new("project/src/maths.ts").into()
- ));
-
- assert!(declaration_path_matches_import(
- &declaration_path,
- &Path::new("user/project/src/maths.ts").into()
- ));
-
- assert!(declaration_path_matches_import(
- &declaration_path,
- &Path::new("/home/user/project/src/maths.ts").into()
- ));
-
- assert!(!declaration_path_matches_import(
- &declaration_path,
- &Path::new("other.ts").into()
- ));
-
- assert!(!declaration_path_matches_import(
- &declaration_path,
- &Path::new("/home/user/project/src/other.ts").into()
- ));
- }
-}
@@ -1,335 +1,490 @@
-mod declaration;
-mod declaration_scoring;
+use crate::assemble_excerpts::assemble_excerpts;
+use anyhow::Result;
+use collections::HashMap;
+use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
+use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
+use project::{LocationLink, Project, ProjectPath};
+use serde::{Serialize, Serializer};
+use smallvec::SmallVec;
+use std::{
+ collections::hash_map,
+ ops::Range,
+ sync::Arc,
+ time::{Duration, Instant},
+};
+use util::{RangeExt as _, ResultExt};
+
+mod assemble_excerpts;
+#[cfg(test)]
+mod edit_prediction_context_tests;
mod excerpt;
-mod imports;
-mod outline;
-mod reference;
-mod syntax_index;
-pub mod text_similarity;
+#[cfg(test)]
+mod fake_definition_lsp;
-use std::{path::Path, sync::Arc};
+pub use cloud_llm_client::predict_edits_v3::Line;
+pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
-use cloud_llm_client::predict_edits_v3;
-use collections::HashMap;
-use gpui::{App, AppContext as _, Entity, Task};
-use language::BufferSnapshot;
-use text::{Point, ToOffset as _};
-
-pub use declaration::*;
-pub use declaration_scoring::*;
-pub use excerpt::*;
-pub use imports::*;
-pub use reference::*;
-pub use syntax_index::*;
-
-pub use predict_edits_v3::Line;
-
-#[derive(Clone, Debug, PartialEq)]
-pub struct EditPredictionContextOptions {
- pub use_imports: bool,
- pub excerpt: EditPredictionExcerptOptions,
- pub score: EditPredictionScoreOptions,
- pub max_retrieved_declarations: u8,
+const IDENTIFIER_LINE_COUNT: u32 = 3;
+
+pub struct RelatedExcerptStore {
+ project: WeakEntity<Project>,
+ related_files: Vec<RelatedFile>,
+ cache: HashMap<Identifier, Arc<CacheEntry>>,
+ update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
+ identifier_line_count: u32,
+}
+
+pub enum RelatedExcerptStoreEvent {
+ StartedRefresh,
+ FinishedRefresh {
+ cache_hit_count: usize,
+ cache_miss_count: usize,
+ mean_definition_latency: Duration,
+ max_definition_latency: Duration,
+ },
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+struct Identifier {
+ pub name: String,
+ pub range: Range<Anchor>,
+}
+
+enum DefinitionTask {
+ CacheHit(Arc<CacheEntry>),
+ CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
+}
+
+#[derive(Debug)]
+struct CacheEntry {
+ definitions: SmallVec<[CachedDefinition; 1]>,
}
#[derive(Clone, Debug)]
-pub struct EditPredictionContext {
- pub excerpt: EditPredictionExcerpt,
- pub excerpt_text: EditPredictionExcerptText,
- pub cursor_point: Point,
- pub declarations: Vec<ScoredDeclaration>,
+struct CachedDefinition {
+ path: ProjectPath,
+ buffer: Entity<Buffer>,
+ anchor_range: Range<Anchor>,
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct RelatedFile {
+ #[serde(serialize_with = "serialize_project_path")]
+ pub path: ProjectPath,
+ #[serde(skip)]
+ pub buffer: WeakEntity<Buffer>,
+ pub excerpts: Vec<RelatedExcerpt>,
+ pub max_row: u32,
}
-impl EditPredictionContext {
- pub fn gather_context_in_background(
- cursor_point: Point,
- buffer: BufferSnapshot,
- options: EditPredictionContextOptions,
- syntax_index: Option<Entity<SyntaxIndex>>,
- cx: &mut App,
- ) -> Task<Option<Self>> {
- let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| {
- let mut path = f.worktree.read(cx).absolutize(&f.path);
- if path.pop() { Some(path) } else { None }
+impl RelatedFile {
+ pub fn merge_excerpts(&mut self) {
+ self.excerpts.sort_unstable_by(|a, b| {
+ a.point_range
+ .start
+ .cmp(&b.point_range.start)
+ .then(b.point_range.end.cmp(&a.point_range.end))
});
- if let Some(syntax_index) = syntax_index {
- let index_state =
- syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state()));
- cx.background_spawn(async move {
- let parent_abs_path = parent_abs_path.as_deref();
- let index_state = index_state.upgrade()?;
- let index_state = index_state.lock().await;
- Self::gather_context(
- cursor_point,
- &buffer,
- parent_abs_path,
- &options,
- Some(&index_state),
- )
- })
- } else {
- cx.background_spawn(async move {
- let parent_abs_path = parent_abs_path.as_deref();
- Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None)
- })
+ let mut index = 1;
+ while index < self.excerpts.len() {
+ if self.excerpts[index - 1]
+ .point_range
+ .end
+ .cmp(&self.excerpts[index].point_range.start)
+ .is_ge()
+ {
+ let removed = self.excerpts.remove(index);
+ if removed
+ .point_range
+ .end
+ .cmp(&self.excerpts[index - 1].point_range.end)
+ .is_gt()
+ {
+ self.excerpts[index - 1].point_range.end = removed.point_range.end;
+ self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
+ }
+ } else {
+ index += 1;
+ }
}
}
+}
- pub fn gather_context(
- cursor_point: Point,
- buffer: &BufferSnapshot,
- parent_abs_path: Option<&Path>,
- options: &EditPredictionContextOptions,
- index_state: Option<&SyntaxIndexState>,
- ) -> Option<Self> {
- let imports = if options.use_imports {
- Imports::gather(&buffer, parent_abs_path)
- } else {
- Imports::default()
- };
- Self::gather_context_with_references_fn(
- cursor_point,
- buffer,
- &imports,
- options,
- index_state,
- references_in_excerpt,
- )
- }
+#[derive(Clone, Debug, Serialize)]
+pub struct RelatedExcerpt {
+ #[serde(skip)]
+ pub anchor_range: Range<Anchor>,
+ #[serde(serialize_with = "serialize_point_range")]
+ pub point_range: Range<Point>,
+ #[serde(serialize_with = "serialize_rope")]
+ pub text: Rope,
+}
- pub fn gather_context_with_references_fn(
- cursor_point: Point,
- buffer: &BufferSnapshot,
- imports: &Imports,
- options: &EditPredictionContextOptions,
- index_state: Option<&SyntaxIndexState>,
- get_references: impl FnOnce(
- &EditPredictionExcerpt,
- &EditPredictionExcerptText,
- &BufferSnapshot,
- ) -> HashMap<Identifier, Vec<Reference>>,
- ) -> Option<Self> {
- let excerpt = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- buffer,
- &options.excerpt,
- index_state,
- )?;
- let excerpt_text = excerpt.text(buffer);
-
- let declarations = if options.max_retrieved_declarations > 0
- && let Some(index_state) = index_state
- {
- let excerpt_occurrences =
- text_similarity::Occurrences::within_string(&excerpt_text.body);
-
- let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
- let adjacent_end = Point::new(cursor_point.row + 1, 0);
- let adjacent_occurrences = text_similarity::Occurrences::within_string(
- &buffer
- .text_for_range(adjacent_start..adjacent_end)
- .collect::<String>(),
- );
+fn serialize_project_path<S: Serializer>(
+ project_path: &ProjectPath,
+ serializer: S,
+) -> Result<S::Ok, S::Error> {
+ project_path.path.serialize(serializer)
+}
- let cursor_offset_in_file = cursor_point.to_offset(buffer);
+fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
+ rope.to_string().serialize(serializer)
+}
- let references = get_references(&excerpt, &excerpt_text, buffer);
+fn serialize_point_range<S: Serializer>(
+ range: &Range<Point>,
+ serializer: S,
+) -> Result<S::Ok, S::Error> {
+ [
+ [range.start.row, range.start.column],
+ [range.end.row, range.end.column],
+ ]
+ .serialize(serializer)
+}
- let mut declarations = scored_declarations(
- &options.score,
- &index_state,
- &excerpt,
- &excerpt_occurrences,
- &adjacent_occurrences,
- &imports,
- references,
- cursor_offset_in_file,
- buffer,
- );
- // TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way
- declarations.truncate(options.max_retrieved_declarations as usize);
- declarations
- } else {
- vec![]
- };
+const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
+
+impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
+
+impl RelatedExcerptStore {
+ pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
+ let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
+ cx.spawn(async move |this, cx| {
+ let executor = cx.background_executor().clone();
+ while let Some((mut buffer, mut position)) = update_rx.next().await {
+ let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
+ loop {
+ futures::select_biased! {
+ next = update_rx.next() => {
+ if let Some((new_buffer, new_position)) = next {
+ buffer = new_buffer;
+ position = new_position;
+ timer = executor.timer(DEBOUNCE_DURATION).fuse();
+ } else {
+ return anyhow::Ok(());
+ }
+ }
+ _ = timer => break,
+ }
+ }
- Some(Self {
- excerpt,
- excerpt_text,
- cursor_point,
- declarations,
+ Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
+ }
+ anyhow::Ok(())
})
+ .detach_and_log_err(cx);
+
+ RelatedExcerptStore {
+ project: project.downgrade(),
+ update_tx,
+ related_files: Vec::new(),
+ cache: Default::default(),
+ identifier_line_count: IDENTIFIER_LINE_COUNT,
+ }
}
-}
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::sync::Arc;
-
- use gpui::{Entity, TestAppContext};
- use indoc::indoc;
- use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
- use project::{FakeFs, Project};
- use serde_json::json;
- use settings::SettingsStore;
- use util::path;
-
- use crate::{EditPredictionExcerptOptions, SyntaxIndex};
-
- #[gpui::test]
- async fn test_call_site(cx: &mut TestAppContext) {
- let (project, index, _rust_lang_id) = init_test(cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let project_path = project.find_project_path("c.rs", cx).unwrap();
- project.open_buffer(project_path, cx)
- })
- .await
- .unwrap();
-
- cx.run_until_parked();
-
- // first process_data call site
- let cursor_point = language::Point::new(8, 21);
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let context = cx
- .update(|cx| {
- EditPredictionContext::gather_context_in_background(
- cursor_point,
- buffer_snapshot,
- EditPredictionContextOptions {
- use_imports: true,
- excerpt: EditPredictionExcerptOptions {
- max_bytes: 60,
- min_bytes: 10,
- target_before_cursor_over_total_bytes: 0.5,
- },
- score: EditPredictionScoreOptions {
- omit_excerpt_overlaps: true,
- },
- max_retrieved_declarations: u8::MAX,
- },
- Some(index.clone()),
- cx,
- )
- })
- .await
- .unwrap();
-
- let mut snippet_identifiers = context
- .declarations
- .iter()
- .map(|snippet| snippet.identifier.name.as_ref())
- .collect::<Vec<_>>();
- snippet_identifiers.sort();
- assert_eq!(snippet_identifiers, vec!["main", "process_data"]);
- drop(buffer);
+ pub fn set_identifier_line_count(&mut self, count: u32) {
+ self.identifier_line_count = count;
}
- async fn init_test(
- cx: &mut TestAppContext,
- ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- });
+ pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
+ self.update_tx.unbounded_send((buffer, position)).ok();
+ }
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/root"),
- json!({
- "a.rs": indoc! {r#"
- fn main() {
- let x = 1;
- let y = 2;
- let z = add(x, y);
- println!("Result: {}", z);
- }
+ pub fn related_files(&self) -> &[RelatedFile] {
+ &self.related_files
+ }
- fn add(a: i32, b: i32) -> i32 {
- a + b
- }
- "#},
- "b.rs": indoc! {"
- pub struct Config {
- pub name: String,
- pub value: i32,
- }
+ async fn fetch_excerpts(
+ this: WeakEntity<Self>,
+ buffer: Entity<Buffer>,
+ position: Anchor,
+ cx: &mut AsyncApp,
+ ) -> Result<()> {
+ let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
+ (
+ this.project.upgrade(),
+ buffer.read(cx).snapshot(),
+ this.identifier_line_count,
+ )
+ })?;
+ let Some(project) = project else {
+ return Ok(());
+ };
- impl Config {
- pub fn new(name: String, value: i32) -> Self {
- Config { name, value }
- }
- }
- "},
- "c.rs": indoc! {r#"
- use std::collections::HashMap;
-
- fn main() {
- let args: Vec<String> = std::env::args().collect();
- let data: Vec<i32> = args[1..]
- .iter()
- .filter_map(|s| s.parse().ok())
- .collect();
- let result = process_data(data);
- println!("{:?}", result);
- }
+ let file = snapshot.file().cloned();
+ if let Some(file) = &file {
+ log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
+ }
+
+ this.update(cx, |_, cx| {
+ cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
+ })?;
- fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
- let mut counts = HashMap::new();
- for value in data {
- *counts.entry(value).or_insert(0) += 1;
+ let identifiers = cx
+ .background_spawn(async move {
+ identifiers_for_position(&snapshot, position, identifier_line_count)
+ })
+ .await;
+
+ let async_cx = cx.clone();
+ let start_time = Instant::now();
+ let futures = this.update(cx, |this, cx| {
+ identifiers
+ .into_iter()
+ .filter_map(|identifier| {
+ let task = if let Some(entry) = this.cache.get(&identifier) {
+ DefinitionTask::CacheHit(entry.clone())
+ } else {
+ DefinitionTask::CacheMiss(
+ this.project
+ .update(cx, |project, cx| {
+ project.definitions(&buffer, identifier.range.start, cx)
+ })
+ .ok()?,
+ )
+ };
+
+ let cx = async_cx.clone();
+ let project = project.clone();
+ Some(async move {
+ match task {
+ DefinitionTask::CacheHit(cache_entry) => {
+ Some((identifier, cache_entry, None))
+ }
+ DefinitionTask::CacheMiss(task) => {
+ let locations = task.await.log_err()??;
+ let duration = start_time.elapsed();
+ cx.update(|cx| {
+ (
+ identifier,
+ Arc::new(CacheEntry {
+ definitions: locations
+ .into_iter()
+ .filter_map(|location| {
+ process_definition(location, &project, cx)
+ })
+ .collect(),
+ }),
+ Some(duration),
+ )
+ })
+ .ok()
+ }
}
- counts
- }
+ })
+ })
+ .collect::<Vec<_>>()
+ })?;
+
+ let mut cache_hit_count = 0;
+ let mut cache_miss_count = 0;
+ let mut mean_definition_latency = Duration::ZERO;
+ let mut max_definition_latency = Duration::ZERO;
+ let mut new_cache = HashMap::default();
+ new_cache.reserve(futures.len());
+ for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
+ new_cache.insert(identifier, entry);
+ if let Some(duration) = duration {
+ cache_miss_count += 1;
+ mean_definition_latency += duration;
+ max_definition_latency = max_definition_latency.max(duration);
+ } else {
+ cache_hit_count += 1;
+ }
+ }
+ mean_definition_latency /= cache_miss_count.max(1) as u32;
- #[cfg(test)]
- mod tests {
- use super::*;
+ let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
- #[test]
- fn test_process_data() {
- let data = vec![1, 2, 2, 3];
- let result = process_data(data);
- assert_eq!(result.get(&2), Some(&2));
- }
- }
- "#}
- }),
- )
- .await;
- let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
- let language_registry = project.read_with(cx, |project, _| project.languages().clone());
- let lang = rust_lang();
- let lang_id = lang.id();
- language_registry.add(Arc::new(lang));
-
- let file_indexing_parallelism = 2;
- let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
- cx.run_until_parked();
-
- (project, index, lang_id)
+ if let Some(file) = &file {
+ log::debug!(
+ "finished retrieving context buffer:{}, latency:{:?}",
+ file.path().as_unix_str(),
+ start_time.elapsed()
+ );
+ }
+
+ this.update(cx, |this, cx| {
+ this.cache = new_cache;
+ this.related_files = related_files;
+ cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
+ cache_hit_count,
+ cache_miss_count,
+ mean_definition_latency,
+ max_definition_latency,
+ });
+ })?;
+
+ anyhow::Ok(())
+ }
+}
+
+async fn rebuild_related_files(
+ new_entries: HashMap<Identifier, Arc<CacheEntry>>,
+ cx: &mut AsyncApp,
+) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
+ let mut snapshots = HashMap::default();
+ for entry in new_entries.values() {
+ for definition in &entry.definitions {
+ if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
+ definition
+ .buffer
+ .read_with(cx, |buffer, _| buffer.parsing_idle())?
+ .await;
+ e.insert(
+ definition
+ .buffer
+ .read_with(cx, |buffer, _| buffer.snapshot())?,
+ );
+ }
+ }
}
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
- .unwrap()
- .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
- .unwrap()
+ Ok(cx
+ .background_spawn(async move {
+ let mut files = Vec::<RelatedFile>::new();
+ let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
+ let mut paths_by_buffer = HashMap::default();
+ for entry in new_entries.values() {
+ for definition in &entry.definitions {
+ let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
+ continue;
+ };
+ paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
+ ranges_by_buffer
+ .entry(definition.buffer.clone())
+ .or_default()
+ .push(definition.anchor_range.to_point(snapshot));
+ }
+ }
+
+ for (buffer, ranges) in ranges_by_buffer {
+ let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
+ continue;
+ };
+ let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
+ continue;
+ };
+ let excerpts = assemble_excerpts(snapshot, ranges);
+ files.push(RelatedFile {
+ path: project_path.clone(),
+ buffer: buffer.downgrade(),
+ excerpts,
+ max_row: snapshot.max_point().row,
+ });
+ }
+
+ files.sort_by_key(|file| file.path.clone());
+ (new_entries, files)
+ })
+ .await)
+}
+
+fn process_definition(
+ location: LocationLink,
+ project: &Entity<Project>,
+ cx: &mut App,
+) -> Option<CachedDefinition> {
+ let buffer = location.target.buffer.read(cx);
+ let anchor_range = location.target.range;
+ let file = buffer.file()?;
+ let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
+ if worktree.read(cx).is_single_file() {
+ return None;
}
+ Some(CachedDefinition {
+ path: ProjectPath {
+ worktree_id: file.worktree_id(cx),
+ path: file.path().clone(),
+ },
+ buffer: location.target.buffer,
+ anchor_range,
+ })
+}
+
+/// Gets all of the identifiers that are present in the given line, and its containing
+/// outline items.
+fn identifiers_for_position(
+ buffer: &BufferSnapshot,
+ position: Anchor,
+ identifier_line_count: u32,
+) -> Vec<Identifier> {
+ let offset = position.to_offset(buffer);
+ let point = buffer.offset_to_point(offset);
+
+ // Search for identifiers on lines adjacent to the cursor.
+ let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
+ let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
+ let line_range = start..end;
+ let mut ranges = vec![line_range.to_offset(&buffer)];
+
+ // Search for identifiers mentioned in headers/signatures of containing outline items.
+ let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
+ for item in outline_items {
+ if let Some(body_range) = item.body_range(&buffer) {
+ ranges.push(item.range.start..body_range.start.to_offset(&buffer));
+ } else {
+ ranges.push(item.range.clone());
+ }
+ }
+
+ ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
+ ranges.dedup_by(|a, b| {
+ if a.start <= b.end {
+ b.start = b.start.min(a.start);
+ b.end = b.end.max(a.end);
+ true
+ } else {
+ false
+ }
+ });
+
+ let mut identifiers = Vec::new();
+ let outer_range =
+ ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
+
+ let mut captures = buffer
+ .syntax
+ .captures(outer_range.clone(), &buffer.text, |grammar| {
+ grammar
+ .highlights_config
+ .as_ref()
+ .map(|config| &config.query)
+ });
+
+ for range in ranges {
+ captures.set_byte_range(range.start..outer_range.end);
+
+ let mut last_range = None;
+ while let Some(capture) = captures.peek() {
+ let node_range = capture.node.byte_range();
+ if node_range.start > range.end {
+ break;
+ }
+ let config = captures.grammars()[capture.grammar_index]
+ .highlights_config
+ .as_ref();
+
+ if let Some(config) = config
+ && config.identifier_capture_indices.contains(&capture.index)
+ && range.contains_inclusive(&node_range)
+ && Some(&node_range) != last_range.as_ref()
+ {
+ let name = buffer.text_for_range(node_range.clone()).collect();
+ identifiers.push(Identifier {
+ range: buffer.anchor_after(node_range.start)
+ ..buffer.anchor_before(node_range.end),
+ name,
+ });
+ last_range = Some(node_range);
+ }
+
+ captures.advance();
+ }
+ }
+
+ identifiers
}
@@ -0,0 +1,510 @@
+use super::*;
+use futures::channel::mpsc::UnboundedReceiver;
+use gpui::TestAppContext;
+use indoc::indoc;
+use language::{Point, ToPoint as _, rust_lang};
+use lsp::FakeLanguageServer;
+use project::{FakeFs, LocationLink, Project};
+use serde_json::json;
+use settings::SettingsStore;
+use std::fmt::Write as _;
+use util::{path, test::marked_text_ranges};
+
+#[gpui::test]
+async fn test_edit_prediction_context(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/root"), test_project_1()).await;
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let mut servers = setup_fake_lsp(&project, cx);
+
+ let (buffer, _handle) = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let _server = servers.next().await.unwrap();
+ cx.run_until_parked();
+
+ let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(&project, cx));
+ related_excerpt_store.update(cx, |store, cx| {
+ let position = {
+ let buffer = buffer.read(cx);
+ let offset = buffer.text().find("todo").unwrap();
+ buffer.anchor_before(offset)
+ };
+
+ store.set_identifier_line_count(0);
+ store.refresh(buffer.clone(), position, cx);
+ });
+
+ cx.executor().advance_clock(DEBOUNCE_DURATION);
+ related_excerpt_store.update(cx, |store, _| {
+ let excerpts = store.related_files();
+ assert_related_files(
+ &excerpts,
+ &[
+ (
+ "src/company.rs",
+ &[indoc! {"
+ pub struct Company {
+ owner: Arc<Person>,
+ address: Address,
+ }"}],
+ ),
+ (
+ "src/main.rs",
+ &[
+ indoc! {"
+ pub struct Session {
+ company: Arc<Company>,
+ }
+
+ impl Session {
+ pub fn set_company(&mut self, company: Arc<Company>) {"},
+ indoc! {"
+ }
+ }"},
+ ],
+ ),
+ (
+ "src/person.rs",
+ &[
+ indoc! {"
+ impl Person {
+ pub fn get_first_name(&self) -> &str {
+ &self.first_name
+ }"},
+ "}",
+ ],
+ ),
+ ],
+ );
+ });
+}
+
+#[gpui::test]
+fn test_assemble_excerpts(cx: &mut TestAppContext) {
+ let table = [
+ (
+ indoc! {r#"
+ struct User {
+ first_name: String,
+ ยซlast_nameยป: String,
+ age: u32,
+ email: String,
+ create_at: Instant,
+ }
+
+ impl User {
+ pub fn first_name(&self) -> String {
+ self.first_name.clone()
+ }
+
+ pub fn full_name(&self) -> String {
+ ยซ format!("{} {}", self.first_name, self.last_name)
+ ยป }
+ }
+ "#},
+ indoc! {r#"
+ struct User {
+ first_name: String,
+ last_name: String,
+ โฆ
+ }
+
+ impl User {
+ โฆ
+ pub fn full_name(&self) -> String {
+ format!("{} {}", self.first_name, self.last_name)
+ }
+ }
+ "#},
+ ),
+ (
+ indoc! {r#"
+ struct ยซUserยป {
+ first_name: String,
+ last_name: String,
+ age: u32,
+ }
+
+ impl User {
+ // methods
+ }
+ "#},
+ indoc! {r#"
+ struct User {
+ first_name: String,
+ last_name: String,
+ age: u32,
+ }
+ โฆ
+ "#},
+ ),
+ (
+ indoc! {r#"
+ trait ยซFooProviderยป {
+ const NAME: &'static str;
+
+ fn provide_foo(&self, id: usize) -> Foo;
+
+ fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
+ ids.iter()
+ .map(|id| self.provide_foo(*id))
+ .collect()
+ }
+
+ fn sync(&self);
+ }
+ "#
+ },
+ indoc! {r#"
+ trait FooProvider {
+ const NAME: &'static str;
+
+ fn provide_foo(&self, id: usize) -> Foo;
+
+ fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
+ โฆ
+ }
+
+ fn sync(&self);
+ }
+ "#},
+ ),
+ (
+ indoc! {r#"
+ trait ยซSomethingยป {
+ fn method1(&self, id: usize) -> Foo;
+
+ fn method2(&self, ids: &[usize]) -> Vec<Foo> {
+ struct Helper1 {
+ field1: usize,
+ }
+
+ struct Helper2 {
+ field2: usize,
+ }
+
+ struct Helper3 {
+ filed2: usize,
+ }
+ }
+
+ fn sync(&self);
+ }
+ "#
+ },
+ indoc! {r#"
+ trait Something {
+ fn method1(&self, id: usize) -> Foo;
+
+ fn method2(&self, ids: &[usize]) -> Vec<Foo> {
+ โฆ
+ }
+
+ fn sync(&self);
+ }
+ "#},
+ ),
+ ];
+
+ for (input, expected_output) in table {
+ let (input, ranges) = marked_text_ranges(&input, false);
+ let buffer = cx.new(|cx| Buffer::local(input, cx).with_language(rust_lang(), cx));
+ buffer.read_with(cx, |buffer, _cx| {
+ let ranges: Vec<Range<Point>> = ranges
+ .into_iter()
+ .map(|range| range.to_point(&buffer))
+ .collect();
+
+ let excerpts = assemble_excerpts(&buffer.snapshot(), ranges);
+
+ let output = format_excerpts(buffer, &excerpts);
+ assert_eq!(output, expected_output);
+ });
+ }
+}
+
+#[gpui::test]
+async fn test_fake_definition_lsp(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(path!("/root"), test_project_1()).await;
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let mut servers = setup_fake_lsp(&project, cx);
+
+ let (buffer, _handle) = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ let _server = servers.next().await.unwrap();
+ cx.run_until_parked();
+
+ let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text());
+
+ let definitions = project
+ .update(cx, |project, cx| {
+ let offset = buffer_text.find("Address {").unwrap();
+ project.definitions(&buffer, offset, cx)
+ })
+ .await
+ .unwrap()
+ .unwrap();
+ assert_definitions(&definitions, &["pub struct Address {"], cx);
+
+ let definitions = project
+ .update(cx, |project, cx| {
+ let offset = buffer_text.find("State::CA").unwrap();
+ project.definitions(&buffer, offset, cx)
+ })
+ .await
+ .unwrap()
+ .unwrap();
+ assert_definitions(&definitions, &["pub enum State {"], cx);
+
+ let definitions = project
+ .update(cx, |project, cx| {
+ let offset = buffer_text.find("to_string()").unwrap();
+ project.definitions(&buffer, offset, cx)
+ })
+ .await
+ .unwrap()
+ .unwrap();
+ assert_definitions(&definitions, &["pub fn to_string(&self) -> String {"], cx);
+}
+
+fn init_test(cx: &mut TestAppContext) {
+ let settings_store = cx.update(|cx| SettingsStore::test(cx));
+ cx.set_global(settings_store);
+ env_logger::try_init().ok();
+}
+
+fn setup_fake_lsp(
+ project: &Entity<Project>,
+ cx: &mut TestAppContext,
+) -> UnboundedReceiver<FakeLanguageServer> {
+ let (language_registry, fs) = project.read_with(cx, |project, _| {
+ (project.languages().clone(), project.fs().clone())
+ });
+ let language = rust_lang();
+ language_registry.add(language.clone());
+ fake_definition_lsp::register_fake_definition_server(&language_registry, language, fs)
+}
+
+fn test_project_1() -> serde_json::Value {
+ let person_rs = indoc! {r#"
+ pub struct Person {
+ first_name: String,
+ last_name: String,
+ email: String,
+ age: u32,
+ }
+
+ impl Person {
+ pub fn get_first_name(&self) -> &str {
+ &self.first_name
+ }
+
+ pub fn get_last_name(&self) -> &str {
+ &self.last_name
+ }
+
+ pub fn get_email(&self) -> &str {
+ &self.email
+ }
+
+ pub fn get_age(&self) -> u32 {
+ self.age
+ }
+ }
+ "#};
+
+ let address_rs = indoc! {r#"
+ pub struct Address {
+ street: String,
+ city: String,
+ state: State,
+ zip: u32,
+ }
+
+ pub enum State {
+ CA,
+ OR,
+ WA,
+ TX,
+ // ...
+ }
+
+ impl Address {
+ pub fn get_street(&self) -> &str {
+ &self.street
+ }
+
+ pub fn get_city(&self) -> &str {
+ &self.city
+ }
+
+ pub fn get_state(&self) -> State {
+ self.state
+ }
+
+ pub fn get_zip(&self) -> u32 {
+ self.zip
+ }
+ }
+ "#};
+
+ let company_rs = indoc! {r#"
+ use super::person::Person;
+ use super::address::Address;
+
+ pub struct Company {
+ owner: Arc<Person>,
+ address: Address,
+ }
+
+ impl Company {
+ pub fn get_owner(&self) -> &Person {
+ &self.owner
+ }
+
+ pub fn get_address(&self) -> &Address {
+ &self.address
+ }
+
+ pub fn to_string(&self) -> String {
+ format!("{} ({})", self.owner.first_name, self.address.city)
+ }
+ }
+ "#};
+
+ let main_rs = indoc! {r#"
+ use std::sync::Arc;
+ use super::person::Person;
+ use super::address::Address;
+ use super::company::Company;
+
+ pub struct Session {
+ company: Arc<Company>,
+ }
+
+ impl Session {
+ pub fn set_company(&mut self, company: Arc<Company>) {
+ self.company = company;
+ if company.owner != self.company.owner {
+ log("new owner", company.owner.get_first_name()); todo();
+ }
+ }
+ }
+
+ fn main() {
+ let company = Company {
+ owner: Arc::new(Person {
+ first_name: "John".to_string(),
+ last_name: "Doe".to_string(),
+ email: "john@example.com".to_string(),
+ age: 30,
+ }),
+ address: Address {
+ street: "123 Main St".to_string(),
+ city: "Anytown".to_string(),
+ state: State::CA,
+ zip: 12345,
+ },
+ };
+
+ println!("Company: {}", company.to_string());
+ }
+ "#};
+
+ json!({
+ "src": {
+ "person.rs": person_rs,
+ "address.rs": address_rs,
+ "company.rs": company_rs,
+ "main.rs": main_rs,
+ },
+ })
+}
+
+fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &[&str])]) {
+ let actual_files = actual_files
+ .iter()
+ .map(|file| {
+ let excerpts = file
+ .excerpts
+ .iter()
+ .map(|excerpt| excerpt.text.to_string())
+ .collect::<Vec<_>>();
+ (file.path.path.as_unix_str(), excerpts)
+ })
+ .collect::<Vec<_>>();
+ let expected_excerpts = expected_files
+ .iter()
+ .map(|(path, texts)| {
+ (
+ *path,
+ texts
+ .iter()
+ .map(|line| line.to_string())
+ .collect::<Vec<_>>(),
+ )
+ })
+ .collect::<Vec<_>>();
+ pretty_assertions::assert_eq!(actual_files, expected_excerpts)
+}
+
+fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &mut TestAppContext) {
+ let actual_first_lines = definitions
+ .iter()
+ .map(|definition| {
+ definition.target.buffer.read_with(cx, |buffer, _| {
+ let mut start = definition.target.range.start.to_point(&buffer);
+ start.column = 0;
+ let end = Point::new(start.row, buffer.line_len(start.row));
+ buffer
+ .text_for_range(start..end)
+ .collect::<String>()
+ .trim()
+ .to_string()
+ })
+ })
+ .collect::<Vec<String>>();
+
+ assert_eq!(actual_first_lines, first_lines);
+}
+
+fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
+ let mut output = String::new();
+ let file_line_count = buffer.max_point().row;
+ let mut current_row = 0;
+ for excerpt in excerpts {
+ if excerpt.text.is_empty() {
+ continue;
+ }
+ if current_row < excerpt.point_range.start.row {
+ writeln!(&mut output, "โฆ").unwrap();
+ }
+ current_row = excerpt.point_range.start.row;
+
+ for line in excerpt.text.to_string().lines() {
+ output.push_str(line);
+ output.push('\n');
+ current_row += 1;
+ }
+ }
+ if current_row < file_line_count {
+ writeln!(&mut output, "โฆ").unwrap();
+ }
+ output
+}
@@ -1,11 +1,9 @@
-use language::{BufferSnapshot, LanguageId};
+use cloud_llm_client::predict_edits_v3::Line;
+use language::{BufferSnapshot, LanguageId, Point, ToOffset as _, ToPoint as _};
use std::ops::Range;
-use text::{Point, ToOffset as _, ToPoint as _};
use tree_sitter::{Node, TreeCursor};
use util::RangeExt;
-use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState};
-
// TODO:
//
// - Test parent signatures
@@ -31,19 +29,16 @@ pub struct EditPredictionExcerptOptions {
pub target_before_cursor_over_total_bytes: f32,
}
-// TODO: consider merging these
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
pub range: Range<usize>,
pub line_range: Range<Line>,
- pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
pub size: usize,
}
#[derive(Debug, Clone)]
pub struct EditPredictionExcerptText {
pub body: String,
- pub parent_signatures: Vec<String>,
pub language_id: Option<LanguageId>,
}
@@ -52,17 +47,8 @@ impl EditPredictionExcerpt {
let body = buffer
.text_for_range(self.range.clone())
.collect::<String>();
- let parent_signatures = self
- .parent_declarations
- .iter()
- .map(|(_, range)| buffer.text_for_range(range.clone()).collect::<String>())
- .collect();
let language_id = buffer.language().map(|l| l.id());
- EditPredictionExcerptText {
- body,
- parent_signatures,
- language_id,
- }
+ EditPredictionExcerptText { body, language_id }
}
/// Selects an excerpt around a buffer position, attempting to choose logical boundaries based
@@ -79,7 +65,6 @@ impl EditPredictionExcerpt {
query_point: Point,
buffer: &BufferSnapshot,
options: &EditPredictionExcerptOptions,
- syntax_index: Option<&SyntaxIndexState>,
) -> Option<Self> {
if buffer.len() <= options.max_bytes {
log::debug!(
@@ -89,11 +74,7 @@ impl EditPredictionExcerpt {
);
let offset_range = 0..buffer.len();
let line_range = Line(0)..Line(buffer.max_point().row);
- return Some(EditPredictionExcerpt::new(
- offset_range,
- line_range,
- Vec::new(),
- ));
+ return Some(EditPredictionExcerpt::new(offset_range, line_range));
}
let query_offset = query_point.to_offset(buffer);
@@ -104,19 +85,10 @@ impl EditPredictionExcerpt {
return None;
}
- let parent_declarations = if let Some(syntax_index) = syntax_index {
- syntax_index
- .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone())
- .collect()
- } else {
- Vec::new()
- };
-
let excerpt_selector = ExcerptSelector {
query_offset,
query_range,
query_line_range: Line(query_line_range.start)..Line(query_line_range.end),
- parent_declarations: &parent_declarations,
buffer,
options,
};
@@ -139,20 +111,10 @@ impl EditPredictionExcerpt {
excerpt_selector.select_lines()
}
- fn new(
- range: Range<usize>,
- line_range: Range<Line>,
- parent_declarations: Vec<(DeclarationId, Range<usize>)>,
- ) -> Self {
- let size = range.len()
- + parent_declarations
- .iter()
- .map(|(_, range)| range.len())
- .sum::<usize>();
+ fn new(range: Range<usize>, line_range: Range<Line>) -> Self {
Self {
+ size: range.len(),
range,
- parent_declarations,
- size,
line_range,
}
}
@@ -162,14 +124,7 @@ impl EditPredictionExcerpt {
// this is an issue because parent_signature_ranges may be incorrect
log::error!("bug: with_expanded_range called with disjoint range");
}
- let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len());
- for (declaration_id, range) in &self.parent_declarations {
- if !range.contains_inclusive(&new_range) {
- break;
- }
- parent_declarations.push((*declaration_id, range.clone()));
- }
- Self::new(new_range, new_line_range, parent_declarations)
+ Self::new(new_range, new_line_range)
}
fn parent_signatures_size(&self) -> usize {
@@ -181,7 +136,6 @@ struct ExcerptSelector<'a> {
query_offset: usize,
query_range: Range<usize>,
query_line_range: Range<Line>,
- parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
buffer: &'a BufferSnapshot,
options: &'a EditPredictionExcerptOptions,
}
@@ -409,13 +363,7 @@ impl<'a> ExcerptSelector<'a> {
}
fn make_excerpt(&self, range: Range<usize>, line_range: Range<Line>) -> EditPredictionExcerpt {
- let parent_declarations = self
- .parent_declarations
- .iter()
- .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
- .map(|(id, declaration)| (*id, declaration.signature_range.clone()))
- .collect();
- EditPredictionExcerpt::new(range, line_range, parent_declarations)
+ EditPredictionExcerpt::new(range, line_range)
}
/// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
@@ -471,30 +419,14 @@ fn node_line_end(node: Node) -> Point {
mod tests {
use super::*;
use gpui::{AppContext, TestAppContext};
- use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
+ use language::Buffer;
use util::test::{generate_marked_text, marked_text_offsets_by};
fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang().into(), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
buffer.read_with(cx, |buffer, _| buffer.snapshot())
}
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
- .unwrap()
- }
-
fn cursor_and_excerpt_range(text: &str) -> (String, usize, Range<usize>) {
let (text, offsets) = marked_text_offsets_by(text, vec!['ห', 'ยซ', 'ยป']);
(text, offsets[&'ห'][0], offsets[&'ยซ'][0]..offsets[&'ยป'][0])
@@ -506,9 +438,8 @@ mod tests {
let buffer = create_buffer(&text, cx);
let cursor_point = cursor.to_point(&buffer);
- let excerpt =
- EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None)
- .expect("Should select an excerpt");
+ let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options)
+ .expect("Should select an excerpt");
pretty_assertions::assert_eq!(
generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false),
generate_marked_text(&text, &[expected_excerpt], false)
@@ -0,0 +1,329 @@
+use collections::HashMap;
+use futures::channel::mpsc::UnboundedReceiver;
+use language::{Language, LanguageRegistry};
+use lsp::{
+ FakeLanguageServer, LanguageServerBinary, TextDocumentSyncCapability, TextDocumentSyncKind, Uri,
+};
+use parking_lot::Mutex;
+use project::Fs;
+use std::{ops::Range, path::PathBuf, sync::Arc};
+use tree_sitter::{Parser, QueryCursor, StreamingIterator, Tree};
+
+/// Registers a fake language server that implements go-to-definition using tree-sitter,
+/// making the assumption that all names are unique, and all variables' types are
+/// explicitly declared.
+pub fn register_fake_definition_server(
+ language_registry: &Arc<LanguageRegistry>,
+ language: Arc<Language>,
+ fs: Arc<dyn Fs>,
+) -> UnboundedReceiver<FakeLanguageServer> {
+ let index = Arc::new(Mutex::new(DefinitionIndex::new(language.clone())));
+
+ language_registry.register_fake_lsp(
+ language.name(),
+ language::FakeLspAdapter {
+ name: "fake-definition-lsp",
+ initialization_options: None,
+ prettier_plugins: Vec::new(),
+ disk_based_diagnostics_progress_token: None,
+ disk_based_diagnostics_sources: Vec::new(),
+ language_server_binary: LanguageServerBinary {
+ path: PathBuf::from("fake-definition-lsp"),
+ arguments: Vec::new(),
+ env: None,
+ },
+ capabilities: lsp::ServerCapabilities {
+ definition_provider: Some(lsp::OneOf::Left(true)),
+ text_document_sync: Some(TextDocumentSyncCapability::Kind(
+ TextDocumentSyncKind::FULL,
+ )),
+ ..Default::default()
+ },
+ label_for_completion: None,
+ initializer: Some(Box::new({
+ move |server| {
+ server.handle_notification::<lsp::notification::DidOpenTextDocument, _>({
+ let index = index.clone();
+ move |params, _cx| {
+ index
+ .lock()
+ .open_buffer(params.text_document.uri, ¶ms.text_document.text);
+ }
+ });
+
+ server.handle_notification::<lsp::notification::DidCloseTextDocument, _>({
+ let index = index.clone();
+ let fs = fs.clone();
+ move |params, cx| {
+ let uri = params.text_document.uri;
+ let path = uri.to_file_path().ok();
+ index.lock().mark_buffer_closed(&uri);
+
+ if let Some(path) = path {
+ let index = index.clone();
+ let fs = fs.clone();
+ cx.spawn(async move |_cx| {
+ if let Ok(content) = fs.load(&path).await {
+ index.lock().index_file(uri, &content);
+ }
+ })
+ .detach();
+ }
+ }
+ });
+
+ server.handle_notification::<lsp::notification::DidChangeWatchedFiles, _>({
+ let index = index.clone();
+ let fs = fs.clone();
+ move |params, cx| {
+ let index = index.clone();
+ let fs = fs.clone();
+ cx.spawn(async move |_cx| {
+ for event in params.changes {
+ if index.lock().is_buffer_open(&event.uri) {
+ continue;
+ }
+
+ match event.typ {
+ lsp::FileChangeType::DELETED => {
+ index.lock().remove_definitions_for_file(&event.uri);
+ }
+ lsp::FileChangeType::CREATED
+ | lsp::FileChangeType::CHANGED => {
+ if let Some(path) = event.uri.to_file_path().ok() {
+ if let Ok(content) = fs.load(&path).await {
+ index.lock().index_file(event.uri, &content);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ })
+ .detach();
+ }
+ });
+
+ server.handle_notification::<lsp::notification::DidChangeTextDocument, _>({
+ let index = index.clone();
+ move |params, _cx| {
+ if let Some(change) = params.content_changes.into_iter().last() {
+ index
+ .lock()
+ .index_file(params.text_document.uri, &change.text);
+ }
+ }
+ });
+
+ server.handle_notification::<lsp::notification::DidChangeWorkspaceFolders, _>(
+ {
+ let index = index.clone();
+ let fs = fs.clone();
+ move |params, cx| {
+ let index = index.clone();
+ let fs = fs.clone();
+ let files = fs.as_fake().files();
+ cx.spawn(async move |_cx| {
+ for folder in params.event.added {
+ let Ok(path) = folder.uri.to_file_path() else {
+ continue;
+ };
+ for file in &files {
+ if let Some(uri) = Uri::from_file_path(&file).ok()
+ && file.starts_with(&path)
+ && let Ok(content) = fs.load(&file).await
+ {
+ index.lock().index_file(uri, &content);
+ }
+ }
+ }
+ })
+ .detach();
+ }
+ },
+ );
+
+ server.set_request_handler::<lsp::request::GotoDefinition, _, _>({
+ let index = index.clone();
+ move |params, _cx| {
+ let result = index.lock().get_definitions(
+ params.text_document_position_params.text_document.uri,
+ params.text_document_position_params.position,
+ );
+ async move { Ok(result) }
+ }
+ });
+ }
+ })),
+ },
+ )
+}
+
+struct DefinitionIndex {
+ language: Arc<Language>,
+ definitions: HashMap<String, Vec<lsp::Location>>,
+ files: HashMap<Uri, FileEntry>,
+}
+
+#[derive(Debug)]
+struct FileEntry {
+ contents: String,
+ is_open_in_buffer: bool,
+}
+
+impl DefinitionIndex {
+ fn new(language: Arc<Language>) -> Self {
+ Self {
+ language,
+ definitions: HashMap::default(),
+ files: HashMap::default(),
+ }
+ }
+
+ fn remove_definitions_for_file(&mut self, uri: &Uri) {
+ self.definitions.retain(|_, locations| {
+ locations.retain(|loc| &loc.uri != uri);
+ !locations.is_empty()
+ });
+ self.files.remove(uri);
+ }
+
+ fn open_buffer(&mut self, uri: Uri, content: &str) {
+ self.index_file_inner(uri, content, true);
+ }
+
+ fn mark_buffer_closed(&mut self, uri: &Uri) {
+ if let Some(entry) = self.files.get_mut(uri) {
+ entry.is_open_in_buffer = false;
+ }
+ }
+
+ fn is_buffer_open(&self, uri: &Uri) -> bool {
+ self.files
+ .get(uri)
+ .map(|entry| entry.is_open_in_buffer)
+ .unwrap_or(false)
+ }
+
+ fn index_file(&mut self, uri: Uri, content: &str) {
+ self.index_file_inner(uri, content, false);
+ }
+
+ fn index_file_inner(&mut self, uri: Uri, content: &str, is_open_in_buffer: bool) -> Option<()> {
+ self.remove_definitions_for_file(&uri);
+ let grammar = self.language.grammar()?;
+ let outline_config = grammar.outline_config.as_ref()?;
+ let mut parser = Parser::new();
+ parser.set_language(&grammar.ts_language).ok()?;
+ let tree = parser.parse(content, None)?;
+ let declarations = extract_declarations_from_tree(&tree, content, outline_config);
+ for (name, byte_range) in declarations {
+ let range = byte_range_to_lsp_range(content, byte_range);
+ let location = lsp::Location {
+ uri: uri.clone(),
+ range,
+ };
+ self.definitions
+ .entry(name)
+ .or_insert_with(Vec::new)
+ .push(location);
+ }
+ self.files.insert(
+ uri,
+ FileEntry {
+ contents: content.to_string(),
+ is_open_in_buffer,
+ },
+ );
+
+ Some(())
+ }
+
+ fn get_definitions(
+ &mut self,
+ uri: Uri,
+ position: lsp::Position,
+ ) -> Option<lsp::GotoDefinitionResponse> {
+ let entry = self.files.get(&uri)?;
+ let name = word_at_position(&entry.contents, position)?;
+ let locations = self.definitions.get(name).cloned()?;
+ Some(lsp::GotoDefinitionResponse::Array(locations))
+ }
+}
+
+fn extract_declarations_from_tree(
+ tree: &Tree,
+ content: &str,
+ outline_config: &language::OutlineConfig,
+) -> Vec<(String, Range<usize>)> {
+ let mut cursor = QueryCursor::new();
+ let mut declarations = Vec::new();
+ let mut matches = cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes());
+ while let Some(query_match) = matches.next() {
+ let mut name_range: Option<Range<usize>> = None;
+ let mut has_item_range = false;
+
+ for capture in query_match.captures {
+ let range = capture.node.byte_range();
+ if capture.index == outline_config.name_capture_ix {
+ name_range = Some(range);
+ } else if capture.index == outline_config.item_capture_ix {
+ has_item_range = true;
+ }
+ }
+
+ if let Some(name_range) = name_range
+ && has_item_range
+ {
+ let name = content[name_range.clone()].to_string();
+ if declarations.iter().any(|(n, _)| n == &name) {
+ continue;
+ }
+ declarations.push((name, name_range));
+ }
+ }
+ declarations
+}
+
+fn byte_range_to_lsp_range(content: &str, byte_range: Range<usize>) -> lsp::Range {
+ let start = byte_offset_to_position(content, byte_range.start);
+ let end = byte_offset_to_position(content, byte_range.end);
+ lsp::Range { start, end }
+}
+
+fn byte_offset_to_position(content: &str, offset: usize) -> lsp::Position {
+ let mut line = 0;
+ let mut character = 0;
+ let mut current_offset = 0;
+ for ch in content.chars() {
+ if current_offset >= offset {
+ break;
+ }
+ if ch == '\n' {
+ line += 1;
+ character = 0;
+ } else {
+ character += 1;
+ }
+ current_offset += ch.len_utf8();
+ }
+ lsp::Position { line, character }
+}
+
+fn word_at_position(content: &str, position: lsp::Position) -> Option<&str> {
+ let mut lines = content.lines();
+ let line = lines.nth(position.line as usize)?;
+ let column = position.character as usize;
+ if column > line.len() {
+ return None;
+ }
+ let start = line[..column]
+ .rfind(|c: char| !c.is_alphanumeric() && c != '_')
+ .map(|i| i + 1)
+ .unwrap_or(0);
+ let end = line[column..]
+ .find(|c: char| !c.is_alphanumeric() && c != '_')
+ .map(|i| i + column)
+ .unwrap_or(line.len());
+ Some(&line[start..end]).filter(|word| !word.is_empty())
+}
@@ -1,1319 +0,0 @@
-use collections::HashMap;
-use language::BufferSnapshot;
-use language::ImportsConfig;
-use language::Language;
-use std::ops::Deref;
-use std::path::Path;
-use std::sync::Arc;
-use std::{borrow::Cow, ops::Range};
-use text::OffsetRangeExt as _;
-use util::RangeExt;
-use util::paths::PathStyle;
-
-use crate::Identifier;
-use crate::text_similarity::Occurrences;
-
-// TODO: Write documentation for extension authors. The @import capture must match before or in the
-// same pattern as all all captures it contains
-
-// Future improvements to consider:
-//
-// * Distinguish absolute vs relative paths in captures. `#include "maths.h"` is relative whereas
-// `#include <maths.h>` is not.
-//
-// * Provide the name used when importing whole modules (see tests with "named_module" in the name).
-// To be useful, will require parsing of identifier qualification.
-//
-// * Scoping for imports that aren't at the top level
-//
-// * Only scan a prefix of the file, when possible. This could look like having query matches that
-// indicate it reached a declaration that is not allowed in the import section.
-//
-// * Support directly parsing to occurrences instead of storing namespaces / paths. Types should be
-// generic on this, so that tests etc can still use strings. Could do similar in syntax index.
-//
-// * Distinguish different types of namespaces when known. E.g. "name.type" capture. Once capture
-// names are more open-ended like this may make sense to build and cache a jump table (direct
-// dispatch from capture index).
-//
-// * There are a few "Language specific:" comments on behavior that gets applied to all languages.
-// Would be cleaner to be conditional on the language or otherwise configured.
-
-#[derive(Debug, Clone, Default)]
-pub struct Imports {
- pub identifier_to_imports: HashMap<Identifier, Vec<Import>>,
- pub wildcard_modules: Vec<Module>,
-}
-
-#[derive(Debug, Clone)]
-pub enum Import {
- Direct {
- module: Module,
- },
- Alias {
- module: Module,
- external_identifier: Identifier,
- },
-}
-
-#[derive(Debug, Clone)]
-pub enum Module {
- SourceExact(Arc<Path>),
- SourceFuzzy(Arc<Path>),
- Namespace(Namespace),
-}
-
-impl Module {
- fn empty() -> Self {
- Module::Namespace(Namespace::default())
- }
-
- fn push_range(
- &mut self,
- range: &ModuleRange,
- snapshot: &BufferSnapshot,
- language: &Language,
- parent_abs_path: Option<&Path>,
- ) -> usize {
- if range.is_empty() {
- return 0;
- }
-
- match range {
- ModuleRange::Source(range) => {
- if let Self::Namespace(namespace) = self
- && namespace.0.is_empty()
- {
- let path = snapshot.text_for_range(range.clone()).collect::<Cow<str>>();
-
- let path = if let Some(strip_regex) =
- language.config().import_path_strip_regex.as_ref()
- {
- strip_regex.replace_all(&path, "")
- } else {
- path
- };
-
- let path = Path::new(path.as_ref());
- if (path.starts_with(".") || path.starts_with(".."))
- && let Some(parent_abs_path) = parent_abs_path
- && let Ok(abs_path) =
- util::paths::normalize_lexically(&parent_abs_path.join(path))
- {
- *self = Self::SourceExact(abs_path.into());
- } else {
- *self = Self::SourceFuzzy(path.into());
- };
- } else if matches!(self, Self::SourceExact(_))
- || matches!(self, Self::SourceFuzzy(_))
- {
- log::warn!("bug in imports query: encountered multiple @source matches");
- } else {
- log::warn!(
- "bug in imports query: encountered both @namespace and @source match"
- );
- }
- }
- ModuleRange::Namespace(range) => {
- if let Self::Namespace(namespace) = self {
- let segment = range_text(snapshot, range);
- if language.config().ignored_import_segments.contains(&segment) {
- return 0;
- } else {
- namespace.0.push(segment);
- return 1;
- }
- } else {
- log::warn!(
- "bug in imports query: encountered both @namespace and @source match"
- );
- }
- }
- }
- 0
- }
-}
-
-#[derive(Debug, Clone)]
-enum ModuleRange {
- Source(Range<usize>),
- Namespace(Range<usize>),
-}
-
-impl Deref for ModuleRange {
- type Target = Range<usize>;
-
- fn deref(&self) -> &Self::Target {
- match self {
- ModuleRange::Source(range) => range,
- ModuleRange::Namespace(range) => range,
- }
- }
-}
-
-#[derive(Debug, Clone, PartialEq, Eq, Default)]
-pub struct Namespace(pub Vec<Arc<str>>);
-
-impl Namespace {
- pub fn occurrences(&self) -> Occurrences {
- Occurrences::from_identifiers(&self.0)
- }
-}
-
-impl Imports {
- pub fn gather(snapshot: &BufferSnapshot, parent_abs_path: Option<&Path>) -> Self {
- // Query to match different import patterns
- let mut matches = snapshot
- .syntax
- .matches(0..snapshot.len(), &snapshot.text, |grammar| {
- grammar.imports_config().map(|imports| &imports.query)
- });
-
- let mut detached_nodes: Vec<DetachedNode> = Vec::new();
- let mut identifier_to_imports = HashMap::default();
- let mut wildcard_modules = Vec::new();
- let mut import_range = None;
-
- while let Some(query_match) = matches.peek() {
- let ImportsConfig {
- query: _,
- import_ix,
- name_ix,
- namespace_ix,
- source_ix,
- list_ix,
- wildcard_ix,
- alias_ix,
- } = matches.grammars()[query_match.grammar_index]
- .imports_config()
- .unwrap();
-
- let mut new_import_range = None;
- let mut alias_range = None;
- let mut modules = Vec::new();
- let mut content: Option<(Range<usize>, ContentKind)> = None;
- for capture in query_match.captures {
- let capture_range = capture.node.byte_range();
-
- if capture.index == *import_ix {
- new_import_range = Some(capture_range);
- } else if Some(capture.index) == *namespace_ix {
- modules.push(ModuleRange::Namespace(capture_range));
- } else if Some(capture.index) == *source_ix {
- modules.push(ModuleRange::Source(capture_range));
- } else if Some(capture.index) == *alias_ix {
- alias_range = Some(capture_range);
- } else {
- let mut found_content = None;
- if Some(capture.index) == *name_ix {
- found_content = Some((capture_range, ContentKind::Name));
- } else if Some(capture.index) == *list_ix {
- found_content = Some((capture_range, ContentKind::List));
- } else if Some(capture.index) == *wildcard_ix {
- found_content = Some((capture_range, ContentKind::Wildcard));
- }
- if let Some((found_content_range, found_kind)) = found_content {
- if let Some((_, old_kind)) = content {
- let point = found_content_range.to_point(snapshot);
- log::warn!(
- "bug in {} imports query: unexpected multiple captures of {} and {} ({}:{}:{})",
- query_match.language.name(),
- old_kind.capture_name(),
- found_kind.capture_name(),
- snapshot
- .file()
- .map(|p| p.path().display(PathStyle::Posix))
- .unwrap_or_default(),
- point.start.row + 1,
- point.start.column + 1
- );
- }
- content = Some((found_content_range, found_kind));
- }
- }
- }
-
- if let Some(new_import_range) = new_import_range {
- log::trace!("starting new import {:?}", new_import_range);
- Self::gather_from_import_statement(
- &detached_nodes,
- &snapshot,
- parent_abs_path,
- &mut identifier_to_imports,
- &mut wildcard_modules,
- );
- detached_nodes.clear();
- import_range = Some(new_import_range.clone());
- }
-
- if let Some((content, content_kind)) = content {
- if import_range
- .as_ref()
- .is_some_and(|import_range| import_range.contains_inclusive(&content))
- {
- detached_nodes.push(DetachedNode {
- modules,
- content: content.clone(),
- content_kind,
- alias: alias_range.unwrap_or(0..0),
- language: query_match.language.clone(),
- });
- } else {
- log::trace!(
- "filtered out match not inside import range: {content_kind:?} at {content:?}"
- );
- }
- }
-
- matches.advance();
- }
-
- Self::gather_from_import_statement(
- &detached_nodes,
- &snapshot,
- parent_abs_path,
- &mut identifier_to_imports,
- &mut wildcard_modules,
- );
-
- Imports {
- identifier_to_imports,
- wildcard_modules,
- }
- }
-
- fn gather_from_import_statement(
- detached_nodes: &[DetachedNode],
- snapshot: &BufferSnapshot,
- parent_abs_path: Option<&Path>,
- identifier_to_imports: &mut HashMap<Identifier, Vec<Import>>,
- wildcard_modules: &mut Vec<Module>,
- ) {
- let mut trees = Vec::new();
-
- for detached_node in detached_nodes {
- if let Some(node) = Self::attach_node(detached_node.into(), &mut trees) {
- trees.push(node);
- }
- log::trace!(
- "Attached node to tree\n{:#?}\nAttach result:\n{:#?}",
- detached_node,
- trees
- .iter()
- .map(|tree| tree.debug(snapshot))
- .collect::<Vec<_>>()
- );
- }
-
- for tree in &trees {
- let mut module = Module::empty();
- Self::gather_from_tree(
- tree,
- snapshot,
- parent_abs_path,
- &mut module,
- identifier_to_imports,
- wildcard_modules,
- );
- }
- }
-
- fn attach_node(mut node: ImportTree, trees: &mut Vec<ImportTree>) -> Option<ImportTree> {
- let mut tree_index = 0;
- while tree_index < trees.len() {
- let tree = &mut trees[tree_index];
- if !node.content.is_empty() && node.content == tree.content {
- // multiple matches can apply to the same name/list/wildcard. This keeps the queries
- // simpler by combining info from these matches.
- if tree.module.is_empty() {
- tree.module = node.module;
- tree.module_children = node.module_children;
- }
- if tree.alias.is_empty() {
- tree.alias = node.alias;
- }
- return None;
- } else if !node.module.is_empty() && node.module.contains_inclusive(&tree.range()) {
- node.module_children.push(trees.remove(tree_index));
- continue;
- } else if !node.content.is_empty() && node.content.contains_inclusive(&tree.content) {
- node.content_children.push(trees.remove(tree_index));
- continue;
- } else if !tree.content.is_empty() && tree.content.contains_inclusive(&node.content) {
- if let Some(node) = Self::attach_node(node, &mut tree.content_children) {
- tree.content_children.push(node);
- }
- return None;
- }
- tree_index += 1;
- }
- Some(node)
- }
-
- fn gather_from_tree(
- tree: &ImportTree,
- snapshot: &BufferSnapshot,
- parent_abs_path: Option<&Path>,
- current_module: &mut Module,
- identifier_to_imports: &mut HashMap<Identifier, Vec<Import>>,
- wildcard_modules: &mut Vec<Module>,
- ) {
- let mut pop_count = 0;
-
- if tree.module_children.is_empty() {
- pop_count +=
- current_module.push_range(&tree.module, snapshot, &tree.language, parent_abs_path);
- } else {
- for child in &tree.module_children {
- pop_count += Self::extend_namespace_from_tree(
- child,
- snapshot,
- parent_abs_path,
- current_module,
- );
- }
- };
-
- if tree.content_children.is_empty() && !tree.content.is_empty() {
- match tree.content_kind {
- ContentKind::Name | ContentKind::List => {
- if tree.alias.is_empty() {
- identifier_to_imports
- .entry(Identifier {
- language_id: tree.language.id(),
- name: range_text(snapshot, &tree.content),
- })
- .or_default()
- .push(Import::Direct {
- module: current_module.clone(),
- });
- } else {
- let alias_name: Arc<str> = range_text(snapshot, &tree.alias);
- let external_name = range_text(snapshot, &tree.content);
- // Language specific: skip "_" aliases for Rust
- if alias_name.as_ref() != "_" {
- identifier_to_imports
- .entry(Identifier {
- language_id: tree.language.id(),
- name: alias_name,
- })
- .or_default()
- .push(Import::Alias {
- module: current_module.clone(),
- external_identifier: Identifier {
- language_id: tree.language.id(),
- name: external_name,
- },
- });
- }
- }
- }
- ContentKind::Wildcard => wildcard_modules.push(current_module.clone()),
- }
- } else {
- for child in &tree.content_children {
- Self::gather_from_tree(
- child,
- snapshot,
- parent_abs_path,
- current_module,
- identifier_to_imports,
- wildcard_modules,
- );
- }
- }
-
- if pop_count > 0 {
- match current_module {
- Module::SourceExact(_) | Module::SourceFuzzy(_) => {
- log::warn!(
- "bug in imports query: encountered both @namespace and @source match"
- );
- }
- Module::Namespace(namespace) => {
- namespace.0.drain(namespace.0.len() - pop_count..);
- }
- }
- }
- }
-
- fn extend_namespace_from_tree(
- tree: &ImportTree,
- snapshot: &BufferSnapshot,
- parent_abs_path: Option<&Path>,
- module: &mut Module,
- ) -> usize {
- let mut pop_count = 0;
- if tree.module_children.is_empty() {
- pop_count += module.push_range(&tree.module, snapshot, &tree.language, parent_abs_path);
- } else {
- for child in &tree.module_children {
- pop_count +=
- Self::extend_namespace_from_tree(child, snapshot, parent_abs_path, module);
- }
- }
- if tree.content_children.is_empty() {
- pop_count += module.push_range(
- &ModuleRange::Namespace(tree.content.clone()),
- snapshot,
- &tree.language,
- parent_abs_path,
- );
- } else {
- for child in &tree.content_children {
- pop_count +=
- Self::extend_namespace_from_tree(child, snapshot, parent_abs_path, module);
- }
- }
- pop_count
- }
-}
-
-fn range_text(snapshot: &BufferSnapshot, range: &Range<usize>) -> Arc<str> {
- snapshot
- .text_for_range(range.clone())
- .collect::<Cow<str>>()
- .into()
-}
-
-#[derive(Debug)]
-struct DetachedNode {
- modules: Vec<ModuleRange>,
- content: Range<usize>,
- content_kind: ContentKind,
- alias: Range<usize>,
- language: Arc<Language>,
-}
-
-#[derive(Debug, Clone, Copy)]
-enum ContentKind {
- Name,
- Wildcard,
- List,
-}
-
-impl ContentKind {
- fn capture_name(&self) -> &'static str {
- match self {
- ContentKind::Name => "name",
- ContentKind::Wildcard => "wildcard",
- ContentKind::List => "list",
- }
- }
-}
-
-#[derive(Debug)]
-struct ImportTree {
- module: ModuleRange,
- /// When non-empty, provides namespace / source info which should be used instead of `module`.
- module_children: Vec<ImportTree>,
- content: Range<usize>,
- /// When non-empty, provides content which should be used instead of `content`.
- content_children: Vec<ImportTree>,
- content_kind: ContentKind,
- alias: Range<usize>,
- language: Arc<Language>,
-}
-
-impl ImportTree {
- fn range(&self) -> Range<usize> {
- self.module.start.min(self.content.start)..self.module.end.max(self.content.end)
- }
-
- #[allow(dead_code)]
- fn debug<'a>(&'a self, snapshot: &'a BufferSnapshot) -> ImportTreeDebug<'a> {
- ImportTreeDebug {
- tree: self,
- snapshot,
- }
- }
-
- fn from_module_range(module: &ModuleRange, language: Arc<Language>) -> Self {
- ImportTree {
- module: module.clone(),
- module_children: Vec::new(),
- content: 0..0,
- content_children: Vec::new(),
- content_kind: ContentKind::Name,
- alias: 0..0,
- language,
- }
- }
-}
-
-impl From<&DetachedNode> for ImportTree {
- fn from(value: &DetachedNode) -> Self {
- let module;
- let module_children;
- match value.modules.len() {
- 0 => {
- module = ModuleRange::Namespace(0..0);
- module_children = Vec::new();
- }
- 1 => {
- module = value.modules[0].clone();
- module_children = Vec::new();
- }
- _ => {
- module = ModuleRange::Namespace(
- value.modules.first().unwrap().start..value.modules.last().unwrap().end,
- );
- module_children = value
- .modules
- .iter()
- .map(|module| ImportTree::from_module_range(module, value.language.clone()))
- .collect();
- }
- }
-
- ImportTree {
- module,
- module_children,
- content: value.content.clone(),
- content_children: Vec::new(),
- content_kind: value.content_kind,
- alias: value.alias.clone(),
- language: value.language.clone(),
- }
- }
-}
-
-struct ImportTreeDebug<'a> {
- tree: &'a ImportTree,
- snapshot: &'a BufferSnapshot,
-}
-
-impl std::fmt::Debug for ImportTreeDebug<'_> {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.debug_struct("ImportTree")
- .field("module_range", &self.tree.module)
- .field("module_text", &range_text(self.snapshot, &self.tree.module))
- .field(
- "module_children",
- &self
- .tree
- .module_children
- .iter()
- .map(|child| child.debug(&self.snapshot))
- .collect::<Vec<Self>>(),
- )
- .field("content_range", &self.tree.content)
- .field(
- "content_text",
- &range_text(self.snapshot, &self.tree.content),
- )
- .field(
- "content_children",
- &self
- .tree
- .content_children
- .iter()
- .map(|child| child.debug(&self.snapshot))
- .collect::<Vec<Self>>(),
- )
- .field("content_kind", &self.tree.content_kind)
- .field("alias_range", &self.tree.alias)
- .field("alias_text", &range_text(self.snapshot, &self.tree.alias))
- .finish()
- }
-}
-
-#[cfg(test)]
-mod test {
- use std::path::PathBuf;
- use std::sync::{Arc, LazyLock};
-
- use super::*;
- use collections::HashSet;
- use gpui::{TestAppContext, prelude::*};
- use indoc::indoc;
- use language::{
- Buffer, Language, LanguageConfig, tree_sitter_python, tree_sitter_rust,
- tree_sitter_typescript,
- };
- use regex::Regex;
-
- #[gpui::test]
- fn test_rust_simple(cx: &mut TestAppContext) {
- check_imports(
- &RUST,
- "use std::collections::HashMap;",
- &[&["std", "collections", "HashMap"]],
- cx,
- );
-
- check_imports(
- &RUST,
- "pub use std::collections::HashMap;",
- &[&["std", "collections", "HashMap"]],
- cx,
- );
-
- check_imports(
- &RUST,
- "use std::collections::{HashMap, HashSet};",
- &[
- &["std", "collections", "HashMap"],
- &["std", "collections", "HashSet"],
- ],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_rust_nested(cx: &mut TestAppContext) {
- check_imports(
- &RUST,
- "use std::{any::TypeId, collections::{HashMap, HashSet}};",
- &[
- &["std", "any", "TypeId"],
- &["std", "collections", "HashMap"],
- &["std", "collections", "HashSet"],
- ],
- cx,
- );
-
- check_imports(
- &RUST,
- "use a::b::c::{d::e::F, g::h::I};",
- &[
- &["a", "b", "c", "d", "e", "F"],
- &["a", "b", "c", "g", "h", "I"],
- ],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_rust_multiple_imports(cx: &mut TestAppContext) {
- check_imports(
- &RUST,
- indoc! {"
- use std::collections::HashMap;
- use std::any::{TypeId, Any};
- "},
- &[
- &["std", "collections", "HashMap"],
- &["std", "any", "TypeId"],
- &["std", "any", "Any"],
- ],
- cx,
- );
-
- check_imports(
- &RUST,
- indoc! {"
- use std::collections::HashSet;
-
- fn main() {
- let unqualified = HashSet::new();
- let qualified = std::collections::HashMap::new();
- }
-
- use std::any::TypeId;
- "},
- &[
- &["std", "collections", "HashSet"],
- &["std", "any", "TypeId"],
- ],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_rust_wildcard(cx: &mut TestAppContext) {
- check_imports(&RUST, "use prelude::*;", &[&["prelude", "WILDCARD"]], cx);
-
- check_imports(
- &RUST,
- "use zed::prelude::*;",
- &[&["zed", "prelude", "WILDCARD"]],
- cx,
- );
-
- check_imports(&RUST, "use prelude::{*};", &[&["prelude", "WILDCARD"]], cx);
-
- check_imports(
- &RUST,
- "use prelude::{File, *};",
- &[&["prelude", "File"], &["prelude", "WILDCARD"]],
- cx,
- );
-
- check_imports(
- &RUST,
- "use zed::{App, prelude::*};",
- &[&["zed", "App"], &["zed", "prelude", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_rust_alias(cx: &mut TestAppContext) {
- check_imports(
- &RUST,
- "use std::io::Result as IoResult;",
- &[&["std", "io", "Result AS IoResult"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_rust_crate_and_super(cx: &mut TestAppContext) {
- check_imports(&RUST, "use crate::a::b::c;", &[&["a", "b", "c"]], cx);
- check_imports(&RUST, "use super::a::b::c;", &[&["a", "b", "c"]], cx);
- // TODO: Consider stripping leading "::". Not done for now because for the text similarity matching usecase this
- // is fine.
- check_imports(&RUST, "use ::a::b::c;", &[&["::a", "b", "c"]], cx);
- }
-
- #[gpui::test]
- fn test_typescript_imports(cx: &mut TestAppContext) {
- let parent_abs_path = PathBuf::from("/home/user/project");
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import "./maths.js";"#,
- &[&["SOURCE /home/user/project/maths", "WILDCARD"]],
- cx,
- );
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import "../maths.js";"#,
- &[&["SOURCE /home/user/maths", "WILDCARD"]],
- cx,
- );
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import RandomNumberGenerator, { pi as ฯ } from "./maths.js";"#,
- &[
- &["SOURCE /home/user/project/maths", "RandomNumberGenerator"],
- &["SOURCE /home/user/project/maths", "pi AS ฯ"],
- ],
- cx,
- );
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import { pi, phi, absolute } from "./maths.js";"#,
- &[
- &["SOURCE /home/user/project/maths", "pi"],
- &["SOURCE /home/user/project/maths", "phi"],
- &["SOURCE /home/user/project/maths", "absolute"],
- ],
- cx,
- );
-
- // index.js is removed by import_path_strip_regex
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import { pi, phi, absolute } from "./maths/index.js";"#,
- &[
- &["SOURCE /home/user/project/maths", "pi"],
- &["SOURCE /home/user/project/maths", "phi"],
- &["SOURCE /home/user/project/maths", "absolute"],
- ],
- cx,
- );
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import type { SomeThing } from "./some-module.js";"#,
- &[&["SOURCE /home/user/project/some-module", "SomeThing"]],
- cx,
- );
-
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import { type SomeThing, OtherThing } from "./some-module.js";"#,
- &[
- &["SOURCE /home/user/project/some-module", "SomeThing"],
- &["SOURCE /home/user/project/some-module", "OtherThing"],
- ],
- cx,
- );
-
- // index.js is removed by import_path_strip_regex
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import { type SomeThing, OtherThing } from "./some-module/index.js";"#,
- &[
- &["SOURCE /home/user/project/some-module", "SomeThing"],
- &["SOURCE /home/user/project/some-module", "OtherThing"],
- ],
- cx,
- );
-
- // fuzzy paths
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import { type SomeThing, OtherThing } from "@my-app/some-module.js";"#,
- &[
- &["SOURCE FUZZY @my-app/some-module", "SomeThing"],
- &["SOURCE FUZZY @my-app/some-module", "OtherThing"],
- ],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_typescript_named_module_imports(cx: &mut TestAppContext) {
- let parent_abs_path = PathBuf::from("/home/user/project");
-
- // TODO: These should provide the name that the module is bound to.
- // For now instead these are treated as unqualified wildcard imports.
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import * as math from "./maths.js";"#,
- // &[&["/home/user/project/maths.js", "WILDCARD AS math"]],
- &[&["SOURCE /home/user/project/maths", "WILDCARD"]],
- cx,
- );
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &TYPESCRIPT,
- r#"import math = require("./maths");"#,
- // &[&["/home/user/project/maths", "WILDCARD AS math"]],
- &[&["SOURCE /home/user/project/maths", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_python_imports(cx: &mut TestAppContext) {
- check_imports(&PYTHON, "from math import pi", &[&["math", "pi"]], cx);
-
- check_imports(
- &PYTHON,
- "from math import pi, sin, cos",
- &[&["math", "pi"], &["math", "sin"], &["math", "cos"]],
- cx,
- );
-
- check_imports(&PYTHON, "from math import *", &[&["math", "WILDCARD"]], cx);
-
- check_imports(
- &PYTHON,
- "from math import foo.bar.baz",
- &[&["math", "foo", "bar", "baz"]],
- cx,
- );
-
- check_imports(
- &PYTHON,
- "from math import pi as PI",
- &[&["math", "pi AS PI"]],
- cx,
- );
-
- check_imports(
- &PYTHON,
- "from serializers.json import JsonSerializer",
- &[&["serializers", "json", "JsonSerializer"]],
- cx,
- );
-
- check_imports(
- &PYTHON,
- "from custom.serializers import json, xml, yaml",
- &[
- &["custom", "serializers", "json"],
- &["custom", "serializers", "xml"],
- &["custom", "serializers", "yaml"],
- ],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_python_named_module_imports(cx: &mut TestAppContext) {
- // TODO: These should provide the name that the module is bound to.
- // For now instead these are treated as unqualified wildcard imports.
- //
- // check_imports(&PYTHON, "import math", &[&["math", "WILDCARD as math"]], cx);
- // check_imports(&PYTHON, "import math as maths", &[&["math", "WILDCARD AS maths"]], cx);
- //
- // Something like:
- //
- // (import_statement
- // name: [
- // (dotted_name
- // (identifier)* @namespace
- // (identifier) @name.module .)
- // (aliased_import
- // name: (dotted_name
- // ((identifier) ".")* @namespace
- // (identifier) @name.module .)
- // alias: (identifier) @alias)
- // ]) @import
-
- check_imports(&PYTHON, "import math", &[&["math", "WILDCARD"]], cx);
-
- check_imports(
- &PYTHON,
- "import math as maths",
- &[&["math", "WILDCARD"]],
- cx,
- );
-
- check_imports(&PYTHON, "import a.b.c", &[&["a", "b", "c", "WILDCARD"]], cx);
-
- check_imports(
- &PYTHON,
- "import a.b.c as d",
- &[&["a", "b", "c", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_python_package_relative_imports(cx: &mut TestAppContext) {
- // TODO: These should provide info about the dir they are relative to, to provide more
- // precise resolution. Instead, fuzzy matching is used as usual.
-
- check_imports(&PYTHON, "from . import math", &[&["math"]], cx);
-
- check_imports(&PYTHON, "from .a import math", &[&["a", "math"]], cx);
-
- check_imports(
- &PYTHON,
- "from ..a.b import math",
- &[&["a", "b", "math"]],
- cx,
- );
-
- check_imports(
- &PYTHON,
- "from ..a.b import *",
- &[&["a", "b", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_c_imports(cx: &mut TestAppContext) {
- let parent_abs_path = PathBuf::from("/home/user/project");
-
- // TODO: Distinguish that these are not relative to current path
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &C,
- r#"#include <math.h>"#,
- &[&["SOURCE FUZZY math.h", "WILDCARD"]],
- cx,
- );
-
- // TODO: These should be treated as relative, but don't start with ./ or ../
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &C,
- r#"#include "math.h""#,
- &[&["SOURCE FUZZY math.h", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_cpp_imports(cx: &mut TestAppContext) {
- let parent_abs_path = PathBuf::from("/home/user/project");
-
- // TODO: Distinguish that these are not relative to current path
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &CPP,
- r#"#include <math.h>"#,
- &[&["SOURCE FUZZY math.h", "WILDCARD"]],
- cx,
- );
-
- // TODO: These should be treated as relative, but don't start with ./ or ../
- check_imports_with_file_abs_path(
- Some(&parent_abs_path),
- &CPP,
- r#"#include "math.h""#,
- &[&["SOURCE FUZZY math.h", "WILDCARD"]],
- cx,
- );
- }
-
- #[gpui::test]
- fn test_go_imports(cx: &mut TestAppContext) {
- check_imports(
- &GO,
- r#"import . "lib/math""#,
- &[&["lib/math", "WILDCARD"]],
- cx,
- );
-
- // not included, these are only for side-effects
- check_imports(&GO, r#"import _ "lib/math""#, &[], cx);
- }
-
- #[gpui::test]
- fn test_go_named_module_imports(cx: &mut TestAppContext) {
- // TODO: These should provide the name that the module is bound to.
- // For now instead these are treated as unqualified wildcard imports.
-
- check_imports(
- &GO,
- r#"import "lib/math""#,
- &[&["lib/math", "WILDCARD"]],
- cx,
- );
- check_imports(
- &GO,
- r#"import m "lib/math""#,
- &[&["lib/math", "WILDCARD"]],
- cx,
- );
- }
-
- #[track_caller]
- fn check_imports(
- language: &Arc<Language>,
- source: &str,
- expected: &[&[&str]],
- cx: &mut TestAppContext,
- ) {
- check_imports_with_file_abs_path(None, language, source, expected, cx);
- }
-
- #[track_caller]
- fn check_imports_with_file_abs_path(
- parent_abs_path: Option<&Path>,
- language: &Arc<Language>,
- source: &str,
- expected: &[&[&str]],
- cx: &mut TestAppContext,
- ) {
- let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local(source, cx);
- buffer.set_language(Some(language.clone()), cx);
- buffer
- });
- cx.run_until_parked();
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
-
- let imports = Imports::gather(&snapshot, parent_abs_path);
- let mut actual_symbols = imports
- .identifier_to_imports
- .iter()
- .flat_map(|(identifier, imports)| {
- imports
- .iter()
- .map(|import| import.to_identifier_parts(identifier.name.as_ref()))
- })
- .chain(
- imports
- .wildcard_modules
- .iter()
- .map(|module| module.to_identifier_parts("WILDCARD")),
- )
- .collect::<Vec<_>>();
- let mut expected_symbols = expected
- .iter()
- .map(|expected| expected.iter().map(|s| s.to_string()).collect::<Vec<_>>())
- .collect::<Vec<_>>();
- actual_symbols.sort();
- expected_symbols.sort();
- if actual_symbols != expected_symbols {
- let top_layer = snapshot.syntax_layers().next().unwrap();
- panic!(
- "Expected imports: {:?}\n\
- Actual imports: {:?}\n\
- Tree:\n{}",
- expected_symbols,
- actual_symbols,
- tree_to_string(&top_layer.node()),
- );
- }
- }
-
- fn tree_to_string(node: &tree_sitter::Node) -> String {
- let mut cursor = node.walk();
- let mut result = String::new();
- let mut depth = 0;
- 'outer: loop {
- result.push_str(&" ".repeat(depth));
- if let Some(field_name) = cursor.field_name() {
- result.push_str(field_name);
- result.push_str(": ");
- }
- if cursor.node().is_named() {
- result.push_str(cursor.node().kind());
- } else {
- result.push('"');
- result.push_str(cursor.node().kind());
- result.push('"');
- }
- result.push('\n');
-
- if cursor.goto_first_child() {
- depth += 1;
- continue;
- }
- if cursor.goto_next_sibling() {
- continue;
- }
- while cursor.goto_parent() {
- depth -= 1;
- if cursor.goto_next_sibling() {
- continue 'outer;
- }
- }
- break;
- }
- result
- }
-
- static RUST: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- ignored_import_segments: HashSet::from_iter(["crate".into(), "super".into()]),
- import_path_strip_regex: Some(Regex::new("/(lib|mod)\\.rs$").unwrap()),
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_imports_query(include_str!("../../languages/src/rust/imports.scm"))
- .unwrap(),
- )
- });
-
- static TYPESCRIPT: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "TypeScript".into(),
- import_path_strip_regex: Some(Regex::new("(?:/index)?\\.[jt]s$").unwrap()),
- ..Default::default()
- },
- Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
- )
- .with_imports_query(include_str!("../../languages/src/typescript/imports.scm"))
- .unwrap(),
- )
- });
-
- static PYTHON: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Python".into(),
- import_path_strip_regex: Some(Regex::new("/__init__\\.py$").unwrap()),
- ..Default::default()
- },
- Some(tree_sitter_python::LANGUAGE.into()),
- )
- .with_imports_query(include_str!("../../languages/src/python/imports.scm"))
- .unwrap(),
- )
- });
-
- // TODO: Ideally should use actual language configurations
- static C: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "C".into(),
- import_path_strip_regex: Some(Regex::new("^<|>$").unwrap()),
- ..Default::default()
- },
- Some(tree_sitter_c::LANGUAGE.into()),
- )
- .with_imports_query(include_str!("../../languages/src/c/imports.scm"))
- .unwrap(),
- )
- });
-
- static CPP: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "C++".into(),
- import_path_strip_regex: Some(Regex::new("^<|>$").unwrap()),
- ..Default::default()
- },
- Some(tree_sitter_cpp::LANGUAGE.into()),
- )
- .with_imports_query(include_str!("../../languages/src/cpp/imports.scm"))
- .unwrap(),
- )
- });
-
- static GO: LazyLock<Arc<Language>> = LazyLock::new(|| {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Go".into(),
- ..Default::default()
- },
- Some(tree_sitter_go::LANGUAGE.into()),
- )
- .with_imports_query(include_str!("../../languages/src/go/imports.scm"))
- .unwrap(),
- )
- });
-
- impl Import {
- fn to_identifier_parts(&self, identifier: &str) -> Vec<String> {
- match self {
- Import::Direct { module } => module.to_identifier_parts(identifier),
- Import::Alias {
- module,
- external_identifier: external_name,
- } => {
- module.to_identifier_parts(&format!("{} AS {}", external_name.name, identifier))
- }
- }
- }
- }
-
- impl Module {
- fn to_identifier_parts(&self, identifier: &str) -> Vec<String> {
- match self {
- Self::Namespace(namespace) => namespace.to_identifier_parts(identifier),
- Self::SourceExact(path) => {
- vec![
- format!("SOURCE {}", path.display().to_string().replace("\\", "/")),
- identifier.to_string(),
- ]
- }
- Self::SourceFuzzy(path) => {
- vec![
- format!(
- "SOURCE FUZZY {}",
- path.display().to_string().replace("\\", "/")
- ),
- identifier.to_string(),
- ]
- }
- }
- }
- }
-
- impl Namespace {
- fn to_identifier_parts(&self, identifier: &str) -> Vec<String> {
- self.0
- .iter()
- .map(|chunk| chunk.to_string())
- .chain(std::iter::once(identifier.to_string()))
- .collect::<Vec<_>>()
- }
- }
-}
@@ -1,126 +0,0 @@
-use language::{BufferSnapshot, SyntaxMapMatches};
-use std::{cmp::Reverse, ops::Range};
-
-use crate::declaration::Identifier;
-
-// TODO:
-//
-// * how to handle multiple name captures? for now last one wins
-//
-// * annotation ranges
-//
-// * new "signature" capture for outline queries
-//
-// * Check parent behavior of "int x, y = 0" declarations in a test
-
-pub struct OutlineDeclaration {
- pub parent_index: Option<usize>,
- pub identifier: Identifier,
- pub item_range: Range<usize>,
- pub signature_range: Range<usize>,
-}
-
-pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec<OutlineDeclaration> {
- declarations_overlapping_range(0..buffer.len(), buffer)
-}
-
-pub fn declarations_overlapping_range(
- range: Range<usize>,
- buffer: &BufferSnapshot,
-) -> Vec<OutlineDeclaration> {
- let mut declarations = OutlineIterator::new(range, buffer).collect::<Vec<_>>();
- declarations.sort_unstable_by_key(|item| (item.item_range.start, Reverse(item.item_range.end)));
-
- let mut parent_stack: Vec<(usize, Range<usize>)> = Vec::new();
- for (index, declaration) in declarations.iter_mut().enumerate() {
- while let Some((top_parent_index, top_parent_range)) = parent_stack.last() {
- if declaration.item_range.start >= top_parent_range.end {
- parent_stack.pop();
- } else {
- declaration.parent_index = Some(*top_parent_index);
- break;
- }
- }
- parent_stack.push((index, declaration.item_range.clone()));
- }
- declarations
-}
-
-/// Iterates outline items without being ordered w.r.t. nested items and without populating
-/// `parent`.
-pub struct OutlineIterator<'a> {
- buffer: &'a BufferSnapshot,
- matches: SyntaxMapMatches<'a>,
-}
-
-impl<'a> OutlineIterator<'a> {
- pub fn new(range: Range<usize>, buffer: &'a BufferSnapshot) -> Self {
- let matches = buffer.syntax.matches(range, &buffer.text, |grammar| {
- grammar.outline_config.as_ref().map(|c| &c.query)
- });
-
- Self { buffer, matches }
- }
-}
-
-impl<'a> Iterator for OutlineIterator<'a> {
- type Item = OutlineDeclaration;
-
- fn next(&mut self) -> Option<Self::Item> {
- while let Some(mat) = self.matches.peek() {
- let config = self.matches.grammars()[mat.grammar_index]
- .outline_config
- .as_ref()
- .unwrap();
-
- let mut name_range = None;
- let mut item_range = None;
- let mut signature_start = None;
- let mut signature_end = None;
-
- let mut add_to_signature = |range: Range<usize>| {
- if signature_start.is_none() {
- signature_start = Some(range.start);
- }
- signature_end = Some(range.end);
- };
-
- for capture in mat.captures {
- let range = capture.node.byte_range();
- if capture.index == config.name_capture_ix {
- name_range = Some(range.clone());
- add_to_signature(range);
- } else if Some(capture.index) == config.context_capture_ix
- || Some(capture.index) == config.extra_context_capture_ix
- {
- add_to_signature(range);
- } else if capture.index == config.item_capture_ix {
- item_range = Some(range.clone());
- }
- }
-
- let language_id = mat.language.id();
- self.matches.advance();
-
- if let Some(name_range) = name_range
- && let Some(item_range) = item_range
- && let Some(signature_start) = signature_start
- && let Some(signature_end) = signature_end
- {
- let name = self
- .buffer
- .text_for_range(name_range)
- .collect::<String>()
- .into();
-
- return Some(OutlineDeclaration {
- identifier: Identifier { name, language_id },
- item_range: item_range,
- signature_range: signature_start..signature_end,
- parent_index: None,
- });
- }
- }
- None
- }
-}
@@ -1,173 +0,0 @@
-use collections::HashMap;
-use language::BufferSnapshot;
-use std::ops::Range;
-use util::RangeExt;
-
-use crate::{
- declaration::Identifier,
- excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
-};
-
-#[derive(Debug, Clone)]
-pub struct Reference {
- pub identifier: Identifier,
- pub range: Range<usize>,
- pub region: ReferenceRegion,
-}
-
-#[derive(Copy, Clone, Debug, Eq, PartialEq)]
-pub enum ReferenceRegion {
- Breadcrumb,
- Nearby,
-}
-
-pub fn references_in_excerpt(
- excerpt: &EditPredictionExcerpt,
- excerpt_text: &EditPredictionExcerptText,
- snapshot: &BufferSnapshot,
-) -> HashMap<Identifier, Vec<Reference>> {
- let mut references = references_in_range(
- excerpt.range.clone(),
- excerpt_text.body.as_str(),
- ReferenceRegion::Nearby,
- snapshot,
- );
-
- for ((_, range), text) in excerpt
- .parent_declarations
- .iter()
- .zip(excerpt_text.parent_signatures.iter())
- {
- references.extend(references_in_range(
- range.clone(),
- text.as_str(),
- ReferenceRegion::Breadcrumb,
- snapshot,
- ));
- }
-
- let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::default();
- for reference in references {
- identifier_to_references
- .entry(reference.identifier.clone())
- .or_insert_with(Vec::new)
- .push(reference);
- }
- identifier_to_references
-}
-
-/// Finds all nodes which have a "variable" match from the highlights query within the offset range.
-pub fn references_in_range(
- range: Range<usize>,
- range_text: &str,
- reference_region: ReferenceRegion,
- buffer: &BufferSnapshot,
-) -> Vec<Reference> {
- let mut matches = buffer
- .syntax
- .matches(range.clone(), &buffer.text, |grammar| {
- grammar
- .highlights_config
- .as_ref()
- .map(|config| &config.query)
- });
-
- let mut references = Vec::new();
- let mut last_added_range = None;
- while let Some(mat) = matches.peek() {
- let config = matches.grammars()[mat.grammar_index]
- .highlights_config
- .as_ref();
-
- if let Some(config) = config {
- for capture in mat.captures {
- if config.identifier_capture_indices.contains(&capture.index) {
- let node_range = capture.node.byte_range();
-
- // sometimes multiple highlight queries match - this deduplicates them
- if Some(node_range.clone()) == last_added_range {
- continue;
- }
-
- if !range.contains_inclusive(&node_range) {
- continue;
- }
-
- let identifier_text =
- &range_text[node_range.start - range.start..node_range.end - range.start];
-
- references.push(Reference {
- identifier: Identifier {
- name: identifier_text.into(),
- language_id: mat.language.id(),
- },
- range: node_range.clone(),
- region: reference_region,
- });
- last_added_range = Some(node_range);
- }
- }
- }
-
- matches.advance();
- }
- references
-}
-
-#[cfg(test)]
-mod test {
- use gpui::{TestAppContext, prelude::*};
- use indoc::indoc;
- use language::{BufferSnapshot, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
-
- use crate::reference::{ReferenceRegion, references_in_range};
-
- #[gpui::test]
- fn test_identifier_node_truncated(cx: &mut TestAppContext) {
- let code = indoc! { r#"
- fn main() {
- add(1, 2);
- }
-
- fn add(a: i32, b: i32) -> i32 {
- a + b
- }
- "# };
- let buffer = create_buffer(code, cx);
-
- let range = 0..35;
- let references = references_in_range(
- range.clone(),
- &code[range],
- ReferenceRegion::Breadcrumb,
- &buffer,
- );
- assert_eq!(references.len(), 2);
- assert_eq!(references[0].identifier.name.as_ref(), "main");
- assert_eq!(references[1].identifier.name.as_ref(), "add");
- }
-
- fn create_buffer(text: &str, cx: &mut TestAppContext) -> BufferSnapshot {
- let buffer =
- cx.new(|cx| language::Buffer::local(text, cx).with_language(rust_lang().into(), cx));
- buffer.read_with(cx, |buffer, _| buffer.snapshot())
- }
-
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
- .unwrap()
- .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
- .unwrap()
- }
-}
@@ -1,1069 +0,0 @@
-use anyhow::{Result, anyhow};
-use collections::{HashMap, HashSet};
-use futures::channel::mpsc;
-use futures::lock::Mutex;
-use futures::{FutureExt as _, StreamExt, future};
-use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
-use itertools::Itertools;
-
-use language::{Buffer, BufferEvent};
-use postage::stream::Stream as _;
-use project::buffer_store::{BufferStore, BufferStoreEvent};
-use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
-use project::{PathChange, Project, ProjectEntryId, ProjectPath};
-use slotmap::SlotMap;
-use std::iter;
-use std::ops::{DerefMut, Range};
-use std::sync::Arc;
-use text::BufferId;
-use util::{RangeExt as _, debug_panic, some_or_debug_panic};
-
-use crate::CachedDeclarationPath;
-use crate::declaration::{
- BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
-};
-use crate::outline::declarations_in_buffer;
-
-// TODO
-//
-// * Also queue / debounce buffer changes. A challenge for this is that use of
-// `buffer_declarations_containing_range` assumes that the index is always immediately up to date.
-//
-// * Add a per language configuration for skipping indexing.
-//
-// * Handle tsx / ts / js referencing each-other
-
-// Potential future improvements:
-//
-// * Prevent indexing of a large file from blocking the queue.
-//
-// * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which
-// references are present and their scores.
-//
-// * Include single-file worktrees / non visible worktrees? E.g. go to definition that resolves to a
-// file in a build dependency. Should not be editable in that case - but how to distinguish the case
-// where it should be editable?
-
-// Potential future optimizations:
-//
-// * Index files on multiple threads in Zed (currently only parallel for the CLI). Adding some kind
-// of priority system to the background executor could help - it's single threaded for now to avoid
-// interfering with other work.
-//
-// * Parse files directly instead of loading into a Rope.
-//
-// - This would allow the task handling dirty_files to be done entirely on the background executor.
-//
-// - Make SyntaxMap generic to handle embedded languages? Will also need to find line boundaries,
-// but that can be done by scanning characters in the flat representation.
-//
-// * Use something similar to slotmap without key versions.
-//
-// * Concurrent slotmap
-
-pub struct SyntaxIndex {
- state: Arc<Mutex<SyntaxIndexState>>,
- project: WeakEntity<Project>,
- initial_file_indexing_done_rx: postage::watch::Receiver<bool>,
- _file_indexing_task: Option<Task<()>>,
-}
-
-pub struct SyntaxIndexState {
- declarations: SlotMap<DeclarationId, Declaration>,
- identifiers: HashMap<Identifier, HashSet<DeclarationId>>,
- files: HashMap<ProjectEntryId, FileState>,
- buffers: HashMap<BufferId, BufferState>,
- dirty_files: HashMap<ProjectEntryId, ProjectPath>,
- dirty_files_tx: mpsc::Sender<()>,
-}
-
-#[derive(Debug, Default)]
-struct FileState {
- declarations: Vec<DeclarationId>,
-}
-
-#[derive(Default)]
-struct BufferState {
- declarations: Vec<DeclarationId>,
- task: Option<Task<()>>,
-}
-
-impl SyntaxIndex {
- pub fn new(
- project: &Entity<Project>,
- file_indexing_parallelism: usize,
- cx: &mut Context<Self>,
- ) -> Self {
- assert!(file_indexing_parallelism > 0);
- let (dirty_files_tx, mut dirty_files_rx) = mpsc::channel::<()>(1);
- let (mut initial_file_indexing_done_tx, initial_file_indexing_done_rx) =
- postage::watch::channel();
-
- let initial_state = SyntaxIndexState {
- declarations: SlotMap::default(),
- identifiers: HashMap::default(),
- files: HashMap::default(),
- buffers: HashMap::default(),
- dirty_files: HashMap::default(),
- dirty_files_tx,
- };
- let mut this = Self {
- project: project.downgrade(),
- state: Arc::new(Mutex::new(initial_state)),
- initial_file_indexing_done_rx,
- _file_indexing_task: None,
- };
-
- let worktree_store = project.read(cx).worktree_store();
- let initial_worktree_snapshots = worktree_store
- .read(cx)
- .worktrees()
- .map(|w| w.read(cx).snapshot())
- .collect::<Vec<_>>();
- this._file_indexing_task = Some(cx.spawn(async move |this, cx| {
- let snapshots_file_count = initial_worktree_snapshots
- .iter()
- .map(|worktree| worktree.file_count())
- .sum::<usize>();
- if snapshots_file_count > 0 {
- let chunk_size = snapshots_file_count.div_ceil(file_indexing_parallelism);
- let chunk_count = snapshots_file_count.div_ceil(chunk_size);
- let file_chunks = initial_worktree_snapshots
- .iter()
- .flat_map(|worktree| {
- let worktree_id = worktree.id();
- worktree.files(false, 0).map(move |entry| {
- (
- entry.id,
- ProjectPath {
- worktree_id,
- path: entry.path.clone(),
- },
- )
- })
- })
- .chunks(chunk_size);
-
- let mut tasks = Vec::with_capacity(chunk_count);
- for chunk in file_chunks.into_iter() {
- tasks.push(Self::update_dirty_files(
- &this,
- chunk.into_iter().collect(),
- cx.clone(),
- ));
- }
- futures::future::join_all(tasks).await;
- log::info!("Finished initial file indexing");
- }
-
- *initial_file_indexing_done_tx.borrow_mut() = true;
-
- let Ok(state) = this.read_with(cx, |this, _cx| Arc::downgrade(&this.state)) else {
- return;
- };
- while dirty_files_rx.next().await.is_some() {
- let Some(state) = state.upgrade() else {
- return;
- };
- let mut state = state.lock().await;
- let was_underused = state.dirty_files.capacity() > 255
- && state.dirty_files.len() * 8 < state.dirty_files.capacity();
- let dirty_files = state.dirty_files.drain().collect::<Vec<_>>();
- if was_underused {
- state.dirty_files.shrink_to_fit();
- }
- drop(state);
- if dirty_files.is_empty() {
- continue;
- }
-
- let chunk_size = dirty_files.len().div_ceil(file_indexing_parallelism);
- let chunk_count = dirty_files.len().div_ceil(chunk_size);
- let mut tasks = Vec::with_capacity(chunk_count);
- let chunks = dirty_files.into_iter().chunks(chunk_size);
- for chunk in chunks.into_iter() {
- tasks.push(Self::update_dirty_files(
- &this,
- chunk.into_iter().collect(),
- cx.clone(),
- ));
- }
- futures::future::join_all(tasks).await;
- }
- }));
-
- cx.subscribe(&worktree_store, Self::handle_worktree_store_event)
- .detach();
-
- let buffer_store = project.read(cx).buffer_store().clone();
- for buffer in buffer_store.read(cx).buffers().collect::<Vec<_>>() {
- this.register_buffer(&buffer, cx);
- }
- cx.subscribe(&buffer_store, Self::handle_buffer_store_event)
- .detach();
-
- this
- }
-
- async fn update_dirty_files(
- this: &WeakEntity<Self>,
- dirty_files: Vec<(ProjectEntryId, ProjectPath)>,
- mut cx: AsyncApp,
- ) {
- for (entry_id, project_path) in dirty_files {
- let Ok(task) = this.update(&mut cx, |this, cx| {
- this.update_file(entry_id, project_path, cx)
- }) else {
- return;
- };
- task.await;
- }
- }
-
- pub fn wait_for_initial_file_indexing(&self, cx: &App) -> Task<Result<()>> {
- if *self.initial_file_indexing_done_rx.borrow() {
- Task::ready(Ok(()))
- } else {
- let mut rx = self.initial_file_indexing_done_rx.clone();
- cx.background_spawn(async move {
- loop {
- match rx.recv().await {
- Some(true) => return Ok(()),
- Some(false) => {}
- None => {
- return Err(anyhow!(
- "SyntaxIndex dropped while waiting for initial file indexing"
- ));
- }
- }
- }
- })
- }
- }
-
- pub fn indexed_file_paths(&self, cx: &App) -> Task<Vec<ProjectPath>> {
- let state = self.state.clone();
- let project = self.project.clone();
-
- cx.spawn(async move |cx| {
- let state = state.lock().await;
- let Some(project) = project.upgrade() else {
- return vec![];
- };
- project
- .read_with(cx, |project, cx| {
- state
- .files
- .keys()
- .filter_map(|entry_id| project.path_for_entry(*entry_id, cx))
- .collect()
- })
- .unwrap_or_default()
- })
- }
-
- fn handle_worktree_store_event(
- &mut self,
- _worktree_store: Entity<WorktreeStore>,
- event: &WorktreeStoreEvent,
- cx: &mut Context<Self>,
- ) {
- use WorktreeStoreEvent::*;
- match event {
- WorktreeUpdatedEntries(worktree_id, updated_entries_set) => {
- let state = Arc::downgrade(&self.state);
- let worktree_id = *worktree_id;
- let updated_entries_set = updated_entries_set.clone();
- cx.background_spawn(async move {
- let Some(state) = state.upgrade() else { return };
- let mut state = state.lock().await;
- for (path, entry_id, path_change) in updated_entries_set.iter() {
- if let PathChange::Removed = path_change {
- state.files.remove(entry_id);
- state.dirty_files.remove(entry_id);
- } else {
- let project_path = ProjectPath {
- worktree_id,
- path: path.clone(),
- };
- state.dirty_files.insert(*entry_id, project_path);
- }
- }
- match state.dirty_files_tx.try_send(()) {
- Err(err) if err.is_disconnected() => {
- log::error!("bug: syntax indexing queue is disconnected");
- }
- _ => {}
- }
- })
- .detach();
- }
- WorktreeDeletedEntry(_worktree_id, project_entry_id) => {
- let project_entry_id = *project_entry_id;
- self.with_state(cx, move |state| {
- state.files.remove(&project_entry_id);
- })
- }
- _ => {}
- }
- }
-
- fn handle_buffer_store_event(
- &mut self,
- _buffer_store: Entity<BufferStore>,
- event: &BufferStoreEvent,
- cx: &mut Context<Self>,
- ) {
- use BufferStoreEvent::*;
- match event {
- BufferAdded(buffer) => self.register_buffer(buffer, cx),
- BufferOpened { .. }
- | BufferChangedFilePath { .. }
- | BufferDropped { .. }
- | SharedBufferClosed { .. } => {}
- }
- }
-
- pub fn state(&self) -> &Arc<Mutex<SyntaxIndexState>> {
- &self.state
- }
-
- fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) {
- if let Some(mut state) = self.state.try_lock() {
- f(&mut state);
- return;
- }
- let state = Arc::downgrade(&self.state);
- cx.background_spawn(async move {
- let Some(state) = state.upgrade() else {
- return;
- };
- let mut state = state.lock().await;
- f(&mut state)
- })
- .detach();
- }
-
- fn register_buffer(&self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
- let buffer_id = buffer.read(cx).remote_id();
- cx.observe_release(buffer, move |this, _buffer, cx| {
- this.with_state(cx, move |state| {
- if let Some(buffer_state) = state.buffers.remove(&buffer_id) {
- SyntaxIndexState::remove_buffer_declarations(
- &buffer_state.declarations,
- &mut state.declarations,
- &mut state.identifiers,
- );
- }
- })
- })
- .detach();
- cx.subscribe(buffer, Self::handle_buffer_event).detach();
-
- self.update_buffer(buffer.clone(), cx);
- }
-
- fn handle_buffer_event(
- &mut self,
- buffer: Entity<Buffer>,
- event: &BufferEvent,
- cx: &mut Context<Self>,
- ) {
- match event {
- BufferEvent::Edited |
- // paths are cached and so should be updated
- BufferEvent::FileHandleChanged => self.update_buffer(buffer, cx),
- _ => {}
- }
- }
-
- fn update_buffer(&self, buffer_entity: Entity<Buffer>, cx: &mut Context<Self>) {
- let buffer = buffer_entity.read(cx);
- if buffer.language().is_none() {
- return;
- }
-
- let Some((project_entry_id, cached_path)) = project::File::from_dyn(buffer.file())
- .and_then(|f| {
- let project_entry_id = f.project_entry_id()?;
- let cached_path = CachedDeclarationPath::new(
- f.worktree.read(cx).abs_path(),
- &f.path,
- buffer.language(),
- );
- Some((project_entry_id, cached_path))
- })
- else {
- return;
- };
- let buffer_id = buffer.remote_id();
-
- let mut parse_status = buffer.parse_status();
- let snapshot_task = cx.spawn({
- let weak_buffer = buffer_entity.downgrade();
- async move |_, cx| {
- while *parse_status.borrow() != language::ParseStatus::Idle {
- parse_status.changed().await?;
- }
- weak_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())
- }
- });
-
- let state = Arc::downgrade(&self.state);
- let task = cx.background_spawn(async move {
- // TODO: How to handle errors?
- let Ok(snapshot) = snapshot_task.await else {
- return;
- };
- let rope = snapshot.text.as_rope();
-
- let declarations = declarations_in_buffer(&snapshot)
- .into_iter()
- .map(|item| {
- (
- item.parent_index,
- BufferDeclaration::from_outline(item, &rope),
- )
- })
- .collect::<Vec<_>>();
-
- let Some(state) = state.upgrade() else {
- return;
- };
- let mut state = state.lock().await;
- let state = state.deref_mut();
-
- let buffer_state = state
- .buffers
- .entry(buffer_id)
- .or_insert_with(Default::default);
-
- SyntaxIndexState::remove_buffer_declarations(
- &buffer_state.declarations,
- &mut state.declarations,
- &mut state.identifiers,
- );
-
- let mut new_ids = Vec::with_capacity(declarations.len());
- state.declarations.reserve(declarations.len());
- for (parent_index, mut declaration) in declarations {
- declaration.parent =
- parent_index.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
-
- let identifier = declaration.identifier.clone();
- let declaration_id = state.declarations.insert(Declaration::Buffer {
- rope: rope.clone(),
- buffer_id,
- declaration,
- project_entry_id,
- cached_path: cached_path.clone(),
- });
- new_ids.push(declaration_id);
-
- state
- .identifiers
- .entry(identifier)
- .or_default()
- .insert(declaration_id);
- }
-
- buffer_state.declarations = new_ids;
- });
-
- self.with_state(cx, move |state| {
- state
- .buffers
- .entry(buffer_id)
- .or_insert_with(Default::default)
- .task = Some(task)
- });
- }
-
- fn update_file(
- &mut self,
- entry_id: ProjectEntryId,
- project_path: ProjectPath,
- cx: &mut Context<Self>,
- ) -> Task<()> {
- let Some(project) = self.project.upgrade() else {
- return Task::ready(());
- };
- let project = project.read(cx);
-
- let language_registry = project.languages();
- let Some(available_language) =
- language_registry.language_for_file_path(project_path.path.as_std_path())
- else {
- return Task::ready(());
- };
- let language = if let Some(Ok(Ok(language))) = language_registry
- .load_language(&available_language)
- .now_or_never()
- {
- if language
- .grammar()
- .is_none_or(|grammar| grammar.outline_config.is_none())
- {
- return Task::ready(());
- }
- future::Either::Left(async { Ok(language) })
- } else {
- let language_registry = language_registry.clone();
- future::Either::Right(async move {
- anyhow::Ok(
- language_registry
- .load_language(&available_language)
- .await??,
- )
- })
- };
-
- let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) else {
- return Task::ready(());
- };
-
- let snapshot_task = worktree.update(cx, |worktree, cx| {
- let load_task = worktree.load_file(&project_path.path, cx);
- let worktree_abs_path = worktree.abs_path();
- cx.spawn(async move |_this, cx| {
- let loaded_file = load_task.await?;
- let language = language.await?;
-
- let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local(loaded_file.text, cx);
- buffer.set_language(Some(language.clone()), cx);
- buffer
- })?;
-
- let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
- while *parse_status.borrow() != language::ParseStatus::Idle {
- parse_status.changed().await?;
- }
-
- let cached_path = CachedDeclarationPath::new(
- worktree_abs_path,
- &project_path.path,
- Some(&language),
- );
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-
- anyhow::Ok((snapshot, cached_path))
- })
- });
-
- let state = Arc::downgrade(&self.state);
- cx.background_spawn(async move {
- // TODO: How to handle errors?
- let Ok((snapshot, cached_path)) = snapshot_task.await else {
- return;
- };
- let rope = snapshot.as_rope();
- let declarations = declarations_in_buffer(&snapshot)
- .into_iter()
- .map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope)))
- .collect::<Vec<_>>();
-
- let Some(state) = state.upgrade() else {
- return;
- };
- let mut state = state.lock().await;
- let state = state.deref_mut();
-
- let file_state = state.files.entry(entry_id).or_insert_with(Default::default);
- for old_declaration_id in &file_state.declarations {
- let Some(declaration) = state.declarations.remove(*old_declaration_id) else {
- debug_panic!("declaration not found");
- continue;
- };
- if let Some(identifier_declarations) =
- state.identifiers.get_mut(declaration.identifier())
- {
- identifier_declarations.remove(old_declaration_id);
- }
- }
-
- let mut new_ids = Vec::with_capacity(declarations.len());
- state.declarations.reserve(declarations.len());
- for (parent_index, mut declaration) in declarations {
- declaration.parent =
- parent_index.and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
-
- let identifier = declaration.identifier.clone();
- let declaration_id = state.declarations.insert(Declaration::File {
- project_entry_id: entry_id,
- declaration,
- cached_path: cached_path.clone(),
- });
- new_ids.push(declaration_id);
-
- state
- .identifiers
- .entry(identifier)
- .or_default()
- .insert(declaration_id);
- }
- file_state.declarations = new_ids;
- })
- }
-}
-
-impl SyntaxIndexState {
- pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
- self.declarations.get(id)
- }
-
- /// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector.
- ///
- /// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded.
- pub fn declarations_for_identifier<const N: usize>(
- &self,
- identifier: &Identifier,
- ) -> Vec<(DeclarationId, &Declaration)> {
- // make sure to not have a large stack allocation
- assert!(N < 32);
-
- let Some(declaration_ids) = self.identifiers.get(&identifier) else {
- return vec![];
- };
-
- let mut result = Vec::with_capacity(N);
- let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
- let mut file_declarations = Vec::new();
-
- for declaration_id in declaration_ids {
- let declaration = self.declarations.get(*declaration_id);
- let Some(declaration) = some_or_debug_panic(declaration) else {
- continue;
- };
- match declaration {
- Declaration::Buffer {
- project_entry_id, ..
- } => {
- included_buffer_entry_ids.push(*project_entry_id);
- result.push((*declaration_id, declaration));
- if result.len() == N {
- return Vec::new();
- }
- }
- Declaration::File {
- project_entry_id, ..
- } => {
- if !included_buffer_entry_ids.contains(&project_entry_id) {
- file_declarations.push((*declaration_id, declaration));
- }
- }
- }
- }
-
- for (declaration_id, declaration) in file_declarations {
- match declaration {
- Declaration::File {
- project_entry_id, ..
- } => {
- if !included_buffer_entry_ids.contains(&project_entry_id) {
- result.push((declaration_id, declaration));
-
- if result.len() == N {
- return Vec::new();
- }
- }
- }
- Declaration::Buffer { .. } => {}
- }
- }
-
- result
- }
-
- pub fn buffer_declarations_containing_range(
- &self,
- buffer_id: BufferId,
- range: Range<usize>,
- ) -> impl Iterator<Item = (DeclarationId, &BufferDeclaration)> {
- let Some(buffer_state) = self.buffers.get(&buffer_id) else {
- return itertools::Either::Left(iter::empty());
- };
-
- let iter = buffer_state
- .declarations
- .iter()
- .filter_map(move |declaration_id| {
- let Some(declaration) = self
- .declarations
- .get(*declaration_id)
- .and_then(|d| d.as_buffer())
- else {
- log::error!("bug: missing buffer outline declaration");
- return None;
- };
- if declaration.item_range.contains_inclusive(&range) {
- return Some((*declaration_id, declaration));
- }
- return None;
- });
- itertools::Either::Right(iter)
- }
-
- pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
- match declaration {
- Declaration::File {
- project_entry_id, ..
- } => self
- .files
- .get(project_entry_id)
- .map(|file_state| file_state.declarations.len())
- .unwrap_or_default(),
- Declaration::Buffer { buffer_id, .. } => self
- .buffers
- .get(buffer_id)
- .map(|buffer_state| buffer_state.declarations.len())
- .unwrap_or_default(),
- }
- }
-
- fn remove_buffer_declarations(
- old_declaration_ids: &[DeclarationId],
- declarations: &mut SlotMap<DeclarationId, Declaration>,
- identifiers: &mut HashMap<Identifier, HashSet<DeclarationId>>,
- ) {
- for old_declaration_id in old_declaration_ids {
- let Some(declaration) = declarations.remove(*old_declaration_id) else {
- debug_panic!("declaration not found");
- continue;
- };
- if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) {
- identifier_declarations.remove(old_declaration_id);
- }
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::sync::Arc;
-
- use gpui::TestAppContext;
- use indoc::indoc;
- use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
- use project::{FakeFs, Project};
- use serde_json::json;
- use settings::SettingsStore;
- use text::OffsetRangeExt as _;
- use util::{path, rel_path::rel_path};
-
- use crate::syntax_index::SyntaxIndex;
-
- #[gpui::test]
- async fn test_unopen_indexed_files(cx: &mut TestAppContext) {
- let (project, index, rust_lang_id) = init_test(cx).await;
- let main = Identifier {
- name: "main".into(),
- language_id: rust_lang_id,
- };
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone());
- let index_state = index_state.lock().await;
- cx.update(|cx| {
- let decls = index_state.declarations_for_identifier::<8>(&main);
- assert_eq!(decls.len(), 2);
-
- let decl = expect_file_decl("a.rs", &decls[0].1, &project, cx);
- assert_eq!(decl.identifier, main);
- assert_eq!(decl.item_range, 0..98);
-
- let decl = expect_file_decl("c.rs", &decls[1].1, &project, cx);
- assert_eq!(decl.identifier, main.clone());
- assert_eq!(decl.item_range, 32..280);
- });
- }
-
- #[gpui::test]
- async fn test_parents_in_file(cx: &mut TestAppContext) {
- let (project, index, rust_lang_id) = init_test(cx).await;
- let test_process_data = Identifier {
- name: "test_process_data".into(),
- language_id: rust_lang_id,
- };
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone());
- let index_state = index_state.lock().await;
- cx.update(|cx| {
- let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
- assert_eq!(decls.len(), 1);
-
- let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
- assert_eq!(decl.identifier, test_process_data);
-
- let parent_id = decl.parent.unwrap();
- let parent = index_state.declaration(parent_id).unwrap();
- let parent_decl = expect_file_decl("c.rs", &parent, &project, cx);
- assert_eq!(
- parent_decl.identifier,
- Identifier {
- name: "tests".into(),
- language_id: rust_lang_id
- }
- );
- assert_eq!(parent_decl.parent, None);
- });
- }
-
- #[gpui::test]
- async fn test_parents_in_buffer(cx: &mut TestAppContext) {
- let (project, index, rust_lang_id) = init_test(cx).await;
- let test_process_data = Identifier {
- name: "test_process_data".into(),
- language_id: rust_lang_id,
- };
-
- let buffer = project
- .update(cx, |project, cx| {
- let project_path = project.find_project_path("c.rs", cx).unwrap();
- project.open_buffer(project_path, cx)
- })
- .await
- .unwrap();
-
- cx.run_until_parked();
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone());
- let index_state = index_state.lock().await;
- cx.update(|cx| {
- let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
- assert_eq!(decls.len(), 1);
-
- let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
- assert_eq!(decl.identifier, test_process_data);
-
- let parent_id = decl.parent.unwrap();
- let parent = index_state.declaration(parent_id).unwrap();
- let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx);
- assert_eq!(
- parent_decl.identifier,
- Identifier {
- name: "tests".into(),
- language_id: rust_lang_id
- }
- );
- assert_eq!(parent_decl.parent, None);
- });
-
- drop(buffer);
- }
-
- #[gpui::test]
- async fn test_declarations_limit(cx: &mut TestAppContext) {
- let (_, index, rust_lang_id) = init_test(cx).await;
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone());
- let index_state = index_state.lock().await;
- let decls = index_state.declarations_for_identifier::<1>(&Identifier {
- name: "main".into(),
- language_id: rust_lang_id,
- });
- assert_eq!(decls.len(), 0);
- }
-
- #[gpui::test]
- async fn test_buffer_shadow(cx: &mut TestAppContext) {
- let (project, index, rust_lang_id) = init_test(cx).await;
-
- let main = Identifier {
- name: "main".into(),
- language_id: rust_lang_id,
- };
-
- let buffer = project
- .update(cx, |project, cx| {
- let project_path = project.find_project_path("c.rs", cx).unwrap();
- project.open_buffer(project_path, cx)
- })
- .await
- .unwrap();
-
- cx.run_until_parked();
-
- let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone());
- {
- let index_state = index_state_arc.lock().await;
-
- cx.update(|cx| {
- let decls = index_state.declarations_for_identifier::<8>(&main);
- assert_eq!(decls.len(), 2);
- let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx);
- assert_eq!(decl.identifier, main);
- assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280);
-
- expect_file_decl("a.rs", &decls[1].1, &project, cx);
- });
- }
-
- // Drop the buffer and wait for release
- cx.update(|_| {
- drop(buffer);
- });
- cx.run_until_parked();
-
- let index_state = index_state_arc.lock().await;
-
- cx.update(|cx| {
- let decls = index_state.declarations_for_identifier::<8>(&main);
- assert_eq!(decls.len(), 2);
- expect_file_decl("a.rs", &decls[0].1, &project, cx);
- expect_file_decl("c.rs", &decls[1].1, &project, cx);
- });
- }
-
- fn expect_buffer_decl<'a>(
- path: &str,
- declaration: &'a Declaration,
- project: &Entity<Project>,
- cx: &App,
- ) -> &'a BufferDeclaration {
- if let Declaration::Buffer {
- declaration,
- project_entry_id,
- ..
- } = declaration
- {
- let project_path = project
- .read(cx)
- .path_for_entry(*project_entry_id, cx)
- .unwrap();
- assert_eq!(project_path.path.as_ref(), rel_path(path),);
- declaration
- } else {
- panic!("Expected a buffer declaration, found {:?}", declaration);
- }
- }
-
- fn expect_file_decl<'a>(
- path: &str,
- declaration: &'a Declaration,
- project: &Entity<Project>,
- cx: &App,
- ) -> &'a FileDeclaration {
- if let Declaration::File {
- declaration,
- project_entry_id: file,
- ..
- } = declaration
- {
- assert_eq!(
- project
- .read(cx)
- .path_for_entry(*file, cx)
- .unwrap()
- .path
- .as_ref(),
- rel_path(path),
- );
- declaration
- } else {
- panic!("Expected a file declaration, found {:?}", declaration);
- }
- }
-
- async fn init_test(
- cx: &mut TestAppContext,
- ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- });
-
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/root"),
- json!({
- "a.rs": indoc! {r#"
- fn main() {
- let x = 1;
- let y = 2;
- let z = add(x, y);
- println!("Result: {}", z);
- }
-
- fn add(a: i32, b: i32) -> i32 {
- a + b
- }
- "#},
- "b.rs": indoc! {"
- pub struct Config {
- pub name: String,
- pub value: i32,
- }
-
- impl Config {
- pub fn new(name: String, value: i32) -> Self {
- Config { name, value }
- }
- }
- "},
- "c.rs": indoc! {r#"
- use std::collections::HashMap;
-
- fn main() {
- let args: Vec<String> = std::env::args().collect();
- let data: Vec<i32> = args[1..]
- .iter()
- .filter_map(|s| s.parse().ok())
- .collect();
- let result = process_data(data);
- println!("{:?}", result);
- }
-
- fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
- let mut counts = HashMap::new();
- for value in data {
- *counts.entry(value).or_insert(0) += 1;
- }
- counts
- }
-
- #[cfg(test)]
- mod tests {
- use super::*;
-
- #[test]
- fn test_process_data() {
- let data = vec![1, 2, 2, 3];
- let result = process_data(data);
- assert_eq!(result.get(&2), Some(&2));
- }
- }
- "#}
- }),
- )
- .await;
- let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
- let language_registry = project.read_with(cx, |project, _| project.languages().clone());
- let lang = rust_lang();
- let lang_id = lang.id();
- language_registry.add(Arc::new(lang));
-
- let file_indexing_parallelism = 2;
- let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx));
- cx.run_until_parked();
-
- (project, index, lang_id)
- }
-
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
- .unwrap()
- }
-}
@@ -1,314 +0,0 @@
-use hashbrown::HashTable;
-use regex::Regex;
-use std::{
- borrow::Cow,
- hash::{Hash, Hasher as _},
- path::Path,
- sync::LazyLock,
-};
-use util::rel_path::RelPath;
-
-use crate::reference::Reference;
-
-// TODO: Consider implementing sliding window similarity matching like
-// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
-//
-// That implementation could actually be more efficient - no need to track words in the window that
-// are not in the query.
-
-// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
-// two in parallel.
-
-static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
-
-/// Multiset of text occurrences for text similarity that only stores hashes and counts.
-#[derive(Debug, Default)]
-pub struct Occurrences {
- table: HashTable<OccurrenceEntry>,
- total_count: usize,
-}
-
-#[derive(Debug)]
-struct OccurrenceEntry {
- hash: u64,
- count: usize,
-}
-
-impl Occurrences {
- pub fn within_string(text: &str) -> Self {
- Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str()))
- }
-
- #[allow(dead_code)]
- pub fn within_references(references: &[Reference]) -> Self {
- Self::from_identifiers(
- references
- .iter()
- .map(|reference| reference.identifier.name.as_ref()),
- )
- }
-
- pub fn from_identifiers(identifiers: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
- let mut this = Self::default();
- // TODO: Score matches that match case higher?
- //
- // TODO: Also include unsplit identifier?
- for identifier in identifiers {
- for identifier_part in split_identifier(identifier.as_ref()) {
- this.add_hash(fx_hash(&identifier_part.to_lowercase()));
- }
- }
- this
- }
-
- pub fn from_worktree_path(worktree_name: Option<Cow<'_, str>>, rel_path: &RelPath) -> Self {
- if let Some(worktree_name) = worktree_name {
- Self::from_identifiers(
- std::iter::once(worktree_name)
- .chain(iter_path_without_extension(rel_path.as_std_path())),
- )
- } else {
- Self::from_path(rel_path.as_std_path())
- }
- }
-
- pub fn from_path(path: &Path) -> Self {
- Self::from_identifiers(iter_path_without_extension(path))
- }
-
- fn add_hash(&mut self, hash: u64) {
- self.table
- .entry(
- hash,
- |entry: &OccurrenceEntry| entry.hash == hash,
- |entry| entry.hash,
- )
- .and_modify(|entry| entry.count += 1)
- .or_insert(OccurrenceEntry { hash, count: 1 });
- self.total_count += 1;
- }
-
- fn contains_hash(&self, hash: u64) -> bool {
- self.get_count(hash) != 0
- }
-
- fn get_count(&self, hash: u64) -> usize {
- self.table
- .find(hash, |entry| entry.hash == hash)
- .map(|entry| entry.count)
- .unwrap_or(0)
- }
-}
-
-fn iter_path_without_extension(path: &Path) -> impl Iterator<Item = Cow<'_, str>> {
- let last_component: Option<Cow<'_, str>> = path.file_stem().map(|stem| stem.to_string_lossy());
- let mut path_components = path.components();
- path_components.next_back();
- path_components
- .map(|component| component.as_os_str().to_string_lossy())
- .chain(last_component)
-}
-
-pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
- let mut hasher = collections::FxHasher::default();
- data.hash(&mut hasher);
- hasher.finish()
-}
-
-// Splits camelcase / snakecase / kebabcase / pascalcase
-//
-// TODO: Make this more efficient / elegant.
-fn split_identifier(identifier: &str) -> Vec<&str> {
- let mut parts = Vec::new();
- let mut start = 0;
- let chars: Vec<char> = identifier.chars().collect();
-
- if chars.is_empty() {
- return parts;
- }
-
- let mut i = 0;
- while i < chars.len() {
- let ch = chars[i];
-
- // Handle explicit delimiters (underscore and hyphen)
- if ch == '_' || ch == '-' {
- if i > start {
- parts.push(&identifier[start..i]);
- }
- start = i + 1;
- i += 1;
- continue;
- }
-
- // Handle camelCase and PascalCase transitions
- if i > 0 && i < chars.len() {
- let prev_char = chars[i - 1];
-
- // Transition from lowercase/digit to uppercase
- if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
- parts.push(&identifier[start..i]);
- start = i;
- }
- // Handle sequences like "XMLParser" -> ["XML", "Parser"]
- else if i + 1 < chars.len()
- && ch.is_uppercase()
- && chars[i + 1].is_lowercase()
- && prev_char.is_uppercase()
- {
- parts.push(&identifier[start..i]);
- start = i;
- }
- }
-
- i += 1;
- }
-
- // Add the last part if there's any remaining
- if start < identifier.len() {
- parts.push(&identifier[start..]);
- }
-
- // Filter out empty strings
- parts.into_iter().filter(|s| !s.is_empty()).collect()
-}
-
-pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
- if set_a.table.len() > set_b.table.len() {
- std::mem::swap(&mut set_a, &mut set_b);
- }
- let intersection = set_a
- .table
- .iter()
- .filter(|entry| set_b.contains_hash(entry.hash))
- .count();
- let union = set_a.table.len() + set_b.table.len() - intersection;
- intersection as f32 / union as f32
-}
-
-// TODO
-#[allow(dead_code)]
-pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
- if set_a.table.len() > set_b.table.len() {
- std::mem::swap(&mut set_a, &mut set_b);
- }
- let intersection = set_a
- .table
- .iter()
- .filter(|entry| set_b.contains_hash(entry.hash))
- .count();
- intersection as f32 / set_a.table.len() as f32
-}
-
-// TODO
-#[allow(dead_code)]
-pub fn weighted_jaccard_similarity<'a>(
- mut set_a: &'a Occurrences,
- mut set_b: &'a Occurrences,
-) -> f32 {
- if set_a.table.len() > set_b.table.len() {
- std::mem::swap(&mut set_a, &mut set_b);
- }
-
- let mut numerator = 0;
- let mut denominator_a = 0;
- let mut used_count_b = 0;
- for entry_a in set_a.table.iter() {
- let count_a = entry_a.count;
- let count_b = set_b.get_count(entry_a.hash);
- numerator += count_a.min(count_b);
- denominator_a += count_a.max(count_b);
- used_count_b += count_b;
- }
-
- let denominator = denominator_a + (set_b.total_count - used_count_b);
- if denominator == 0 {
- 0.0
- } else {
- numerator as f32 / denominator as f32
- }
-}
-
-pub fn weighted_overlap_coefficient<'a>(
- mut set_a: &'a Occurrences,
- mut set_b: &'a Occurrences,
-) -> f32 {
- if set_a.table.len() > set_b.table.len() {
- std::mem::swap(&mut set_a, &mut set_b);
- }
-
- let mut numerator = 0;
- for entry_a in set_a.table.iter() {
- let count_a = entry_a.count;
- let count_b = set_b.get_count(entry_a.hash);
- numerator += count_a.min(count_b);
- }
-
- let denominator = set_a.total_count.min(set_b.total_count);
- if denominator == 0 {
- 0.0
- } else {
- numerator as f32 / denominator as f32
- }
-}
-
-#[cfg(test)]
-mod test {
- use super::*;
-
- #[test]
- fn test_split_identifier() {
- assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]);
- assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]);
- assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]);
- assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]);
- assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]);
- }
-
- #[test]
- fn test_similarity_functions() {
- // 10 identifier parts, 8 unique
- // Repeats: 2 "outline", 2 "items"
- let set_a = Occurrences::within_string(
- "let mut outline_items = query_outline_items(&language, &tree, &source);",
- );
- // 14 identifier parts, 11 unique
- // Repeats: 2 "outline", 2 "language", 2 "tree"
- let set_b = Occurrences::within_string(
- "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
- );
-
- // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
- // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
- assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));
-
- // Numerator is one more than before due to both having 2 "outline".
- // Denominator is the same except for 3 more due to the non-overlapping duplicates
- assert_eq!(
- weighted_jaccard_similarity(&set_a, &set_b),
- 7.0 / (7.0 + 7.0 + 3.0)
- );
-
- // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
- assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);
-
- // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
- // the smaller set, 10.
- assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
- }
-
- #[test]
- fn test_iter_path_without_extension() {
- let mut iter = iter_path_without_extension(Path::new(""));
- assert_eq!(iter.next(), None);
-
- let iter = iter_path_without_extension(Path::new("foo"));
- assert_eq!(iter.collect::<Vec<_>>(), ["foo"]);
-
- let iter = iter_path_without_extension(Path::new("foo/bar.txt"));
- assert_eq!(iter.collect::<Vec<_>>(), ["foo", "bar"]);
-
- let iter = iter_path_without_extension(Path::new("foo/bar/baz.txt"));
- assert_eq!(iter.collect::<Vec<_>>(), ["foo", "bar", "baz"]);
- }
-}
@@ -0,0 +1,17 @@
+[package]
+name = "edit_prediction_types"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/edit_prediction_types.rs"
+
+[dependencies]
+client.workspace = true
+gpui.workspace = true
+language.workspace = true
@@ -0,0 +1,298 @@
+use std::{ops::Range, sync::Arc};
+
+use client::EditPredictionUsage;
+use gpui::{App, Context, Entity, SharedString};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt};
+
+// TODO: Find a better home for `Direction`.
+//
+// This should live in an ancestor crate of `editor` and `edit_prediction`,
+// but at time of writing there isn't an obvious spot.
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum Direction {
+ Prev,
+ Next,
+}
+
+#[derive(Clone)]
+pub enum EditPrediction {
+ /// Edits within the buffer that requested the prediction
+ Local {
+ id: Option<SharedString>,
+ edits: Vec<(Range<language::Anchor>, Arc<str>)>,
+ edit_preview: Option<language::EditPreview>,
+ },
+ /// Jump to a different file from the one that requested the prediction
+ Jump {
+ id: Option<SharedString>,
+ snapshot: language::BufferSnapshot,
+ target: language::Anchor,
+ },
+}
+
+pub enum DataCollectionState {
+ /// The provider doesn't support data collection.
+ Unsupported,
+ /// Data collection is enabled.
+ Enabled { is_project_open_source: bool },
+ /// Data collection is disabled or unanswered.
+ Disabled { is_project_open_source: bool },
+}
+
+impl DataCollectionState {
+ pub fn is_supported(&self) -> bool {
+ !matches!(self, DataCollectionState::Unsupported)
+ }
+
+ pub fn is_enabled(&self) -> bool {
+ matches!(self, DataCollectionState::Enabled { .. })
+ }
+
+ pub fn is_project_open_source(&self) -> bool {
+ match self {
+ Self::Enabled {
+ is_project_open_source,
+ }
+ | Self::Disabled {
+ is_project_open_source,
+ } => *is_project_open_source,
+ _ => false,
+ }
+ }
+}
+
+pub trait EditPredictionDelegate: 'static + Sized {
+ fn name() -> &'static str;
+ fn display_name() -> &'static str;
+ fn show_predictions_in_menu() -> bool;
+ fn show_tab_accept_marker() -> bool {
+ false
+ }
+ fn supports_jump_to_edit() -> bool {
+ true
+ }
+
+ fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
+ DataCollectionState::Unsupported
+ }
+
+ fn usage(&self, _cx: &App) -> Option<EditPredictionUsage> {
+ None
+ }
+
+ fn toggle_data_collection(&mut self, _cx: &mut App) {}
+ fn is_enabled(
+ &self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &App,
+ ) -> bool;
+ fn is_refreshing(&self, cx: &App) -> bool;
+ fn refresh(
+ &mut self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ debounce: bool,
+ cx: &mut Context<Self>,
+ );
+ fn cycle(
+ &mut self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ direction: Direction,
+ cx: &mut Context<Self>,
+ );
+ fn accept(&mut self, cx: &mut Context<Self>);
+ fn discard(&mut self, cx: &mut Context<Self>);
+ fn did_show(&mut self, _cx: &mut Context<Self>) {}
+ fn suggest(
+ &mut self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) -> Option<EditPrediction>;
+}
+
+pub trait EditPredictionDelegateHandle {
+ fn name(&self) -> &'static str;
+ fn display_name(&self) -> &'static str;
+ fn is_enabled(
+ &self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &App,
+ ) -> bool;
+ fn show_predictions_in_menu(&self) -> bool;
+ fn show_tab_accept_marker(&self) -> bool;
+ fn supports_jump_to_edit(&self) -> bool;
+ fn data_collection_state(&self, cx: &App) -> DataCollectionState;
+ fn usage(&self, cx: &App) -> Option<EditPredictionUsage>;
+ fn toggle_data_collection(&self, cx: &mut App);
+ fn is_refreshing(&self, cx: &App) -> bool;
+ fn refresh(
+ &self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ debounce: bool,
+ cx: &mut App,
+ );
+ fn cycle(
+ &self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ direction: Direction,
+ cx: &mut App,
+ );
+ fn did_show(&self, cx: &mut App);
+ fn accept(&self, cx: &mut App);
+ fn discard(&self, cx: &mut App);
+ fn suggest(
+ &self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut App,
+ ) -> Option<EditPrediction>;
+}
+
+impl<T> EditPredictionDelegateHandle for Entity<T>
+where
+ T: EditPredictionDelegate,
+{
+ fn name(&self) -> &'static str {
+ T::name()
+ }
+
+ fn display_name(&self) -> &'static str {
+ T::display_name()
+ }
+
+ fn show_predictions_in_menu(&self) -> bool {
+ T::show_predictions_in_menu()
+ }
+
+ fn show_tab_accept_marker(&self) -> bool {
+ T::show_tab_accept_marker()
+ }
+
+ fn supports_jump_to_edit(&self) -> bool {
+ T::supports_jump_to_edit()
+ }
+
+ fn data_collection_state(&self, cx: &App) -> DataCollectionState {
+ self.read(cx).data_collection_state(cx)
+ }
+
+ fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
+ self.read(cx).usage(cx)
+ }
+
+ fn toggle_data_collection(&self, cx: &mut App) {
+ self.update(cx, |this, cx| this.toggle_data_collection(cx))
+ }
+
+ fn is_enabled(
+ &self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &App,
+ ) -> bool {
+ self.read(cx).is_enabled(buffer, cursor_position, cx)
+ }
+
+ fn is_refreshing(&self, cx: &App) -> bool {
+ self.read(cx).is_refreshing(cx)
+ }
+
+ fn refresh(
+ &self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ debounce: bool,
+ cx: &mut App,
+ ) {
+ self.update(cx, |this, cx| {
+ this.refresh(buffer, cursor_position, debounce, cx)
+ })
+ }
+
+ fn cycle(
+ &self,
+ buffer: Entity<Buffer>,
+ cursor_position: language::Anchor,
+ direction: Direction,
+ cx: &mut App,
+ ) {
+ self.update(cx, |this, cx| {
+ this.cycle(buffer, cursor_position, direction, cx)
+ })
+ }
+
+ fn accept(&self, cx: &mut App) {
+ self.update(cx, |this, cx| this.accept(cx))
+ }
+
+ fn discard(&self, cx: &mut App) {
+ self.update(cx, |this, cx| this.discard(cx))
+ }
+
+ fn did_show(&self, cx: &mut App) {
+ self.update(cx, |this, cx| this.did_show(cx))
+ }
+
+ fn suggest(
+ &self,
+ buffer: &Entity<Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut App,
+ ) -> Option<EditPrediction> {
+ self.update(cx, |this, cx| this.suggest(buffer, cursor_position, cx))
+ }
+}
+
+/// Returns edits updated based on user edits since the old snapshot. None is returned if any user
+/// edit is not a prefix of a predicted insertion.
+pub fn interpolate_edits(
+ old_snapshot: &BufferSnapshot,
+ new_snapshot: &BufferSnapshot,
+ current_edits: &[(Range<Anchor>, Arc<str>)],
+) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
+ let mut edits = Vec::new();
+
+ let mut model_edits = current_edits.iter().peekable();
+ for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
+ while let Some((model_old_range, _)) = model_edits.peek() {
+ let model_old_range = model_old_range.to_offset(old_snapshot);
+ if model_old_range.end < user_edit.old.start {
+ let (model_old_range, model_new_text) = model_edits.next().unwrap();
+ edits.push((model_old_range.clone(), model_new_text.clone()));
+ } else {
+ break;
+ }
+ }
+
+ if let Some((model_old_range, model_new_text)) = model_edits.peek() {
+ let model_old_offset_range = model_old_range.to_offset(old_snapshot);
+ if user_edit.old == model_old_offset_range {
+ let user_new_text = new_snapshot
+ .text_for_range(user_edit.new.clone())
+ .collect::<String>();
+
+ if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
+ if !model_suffix.is_empty() {
+ let anchor = old_snapshot.anchor_after(user_edit.old.end);
+ edits.push((anchor..anchor, model_suffix.into()));
+ }
+
+ model_edits.next();
+ continue;
+ }
+ }
+ }
+
+ return None;
+ }
+
+ edits.extend(model_edits.cloned());
+
+ if edits.is_empty() { None } else { Some(edits) }
+}
@@ -1,5 +1,5 @@
[package]
-name = "edit_prediction_button"
+name = "edit_prediction_ui"
version = "0.1.0"
edition.workspace = true
publish.workspace = true
@@ -9,35 +9,43 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
-path = "src/edit_prediction_button.rs"
+path = "src/edit_prediction_ui.rs"
doctest = false
[dependencies]
anyhow.workspace = true
+buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
+cloud_zeta2_prompt.workspace = true
codestral.workspace = true
+command_palette_hooks.workspace = true
copilot.workspace = true
edit_prediction.workspace = true
+edit_prediction_types.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
+futures.workspace = true
gpui.workspace = true
indoc.workspace = true
language.workspace = true
+markdown.workspace = true
+menu.workspace = true
+multi_buffer.workspace = true
paths.workspace = true
project.workspace = true
regex.workspace = true
settings.workspace = true
supermaven.workspace = true
telemetry.workspace = true
+text.workspace = true
+theme.workspace = true
ui.workspace = true
ui_input.workspace = true
-menu.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
-zeta.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }
@@ -1,16 +1,14 @@
-mod sweep_api_token_modal;
-
-pub use sweep_api_token_modal::SweepApiKeyModal;
-
use anyhow::Result;
use client::{Client, UserStore, zed_urls};
use cloud_llm_client::UsageLimit;
-use codestral::CodestralCompletionProvider;
+use codestral::CodestralEditPredictionDelegate;
use copilot::{Copilot, Status};
+use edit_prediction::{MercuryFeatureFlag, SweepFeatureFlag, Zeta2FeatureFlag};
+use edit_prediction_types::EditPredictionDelegateHandle;
use editor::{
Editor, MultiBufferOffset, SelectionEffects, actions::ShowEditPrediction, scroll::Autoscroll,
};
-use feature_flags::{FeatureFlagAppExt, PredictEditsRateCompletionsFeatureFlag};
+use feature_flags::FeatureFlagAppExt;
use fs::Fs;
use gpui::{
Action, Animation, AnimationExt, App, AsyncWindowContext, Corner, Entity, FocusHandle,
@@ -25,6 +23,7 @@ use language::{
use project::DisableAiSettings;
use regex::Regex;
use settings::{
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, Settings, SettingsStore,
update_settings_file,
@@ -44,7 +43,11 @@ use workspace::{
notifications::NotificationId,
};
use zed_actions::OpenBrowser;
-use zeta::{RateCompletions, SweepFeatureFlag, Zeta2FeatureFlag};
+
+use crate::{
+ ExternalProviderApiKeyModal, RatePredictions,
+ rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag,
+};
actions!(
edit_prediction,
@@ -67,7 +70,7 @@ pub struct EditPredictionButton {
editor_focus_handle: Option<FocusHandle>,
language: Option<Arc<Language>>,
file: Option<Arc<dyn File>>,
- edit_prediction_provider: Option<Arc<dyn edit_prediction::EditPredictionProviderHandle>>,
+ edit_prediction_provider: Option<Arc<dyn EditPredictionDelegateHandle>>,
fs: Arc<dyn Fs>,
user_store: Entity<UserStore>,
popover_menu_handle: PopoverMenuHandle<ContextMenu>,
@@ -244,7 +247,7 @@ impl Render for EditPredictionButton {
EditPredictionProvider::Codestral => {
let enabled = self.editor_enabled.unwrap_or(true);
- let has_api_key = CodestralCompletionProvider::has_api_key(cx);
+ let has_api_key = CodestralEditPredictionDelegate::has_api_key(cx);
let fs = self.fs.clone();
let this = cx.weak_entity();
@@ -309,24 +312,34 @@ impl Render for EditPredictionButton {
provider @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
let enabled = self.editor_enabled.unwrap_or(true);
- let is_sweep = matches!(
- provider,
- EditPredictionProvider::Experimental(
- EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
- )
- );
+ let ep_icon;
+ let mut missing_token = false;
- let sweep_missing_token = is_sweep
- && !zeta::Zeta::try_global(cx)
- .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
-
- let zeta_icon = match (is_sweep, enabled) {
- (true, _) => IconName::SweepAi,
- (false, true) => IconName::ZedPredict,
- (false, false) => IconName::ZedPredictDisabled,
+ match provider {
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
+ ) => {
+ ep_icon = IconName::SweepAi;
+ missing_token = edit_prediction::EditPredictionStore::try_global(cx)
+ .is_some_and(|ep_store| !ep_store.read(cx).has_sweep_api_token());
+ }
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
+ ) => {
+ ep_icon = IconName::Inception;
+ missing_token = edit_prediction::EditPredictionStore::try_global(cx)
+ .is_some_and(|ep_store| !ep_store.read(cx).has_mercury_api_token());
+ }
+ _ => {
+ ep_icon = if enabled {
+ IconName::ZedPredict
+ } else {
+ IconName::ZedPredictDisabled
+ };
+ }
};
- if zeta::should_show_upsell_modal() {
+ if edit_prediction::should_show_upsell_modal() {
let tooltip_meta = if self.user_store.read(cx).current_user().is_some() {
"Choose a Plan"
} else {
@@ -334,7 +347,7 @@ impl Render for EditPredictionButton {
};
return div().child(
- IconButton::new("zed-predict-pending-button", zeta_icon)
+ IconButton::new("zed-predict-pending-button", ep_icon)
.shape(IconButtonShape::Square)
.indicator(Indicator::dot().color(Color::Muted))
.indicator_border_color(Some(cx.theme().colors().status_bar_background))
@@ -367,7 +380,7 @@ impl Render for EditPredictionButton {
let show_editor_predictions = self.editor_show_predictions;
let user = self.user_store.read(cx).current_user();
- let indicator_color = if sweep_missing_token {
+ let indicator_color = if missing_token {
Some(Color::Error)
} else if enabled && (!show_editor_predictions || over_limit) {
Some(if over_limit {
@@ -379,7 +392,7 @@ impl Render for EditPredictionButton {
None
};
- let icon_button = IconButton::new("zed-predict-pending-button", zeta_icon)
+ let icon_button = IconButton::new("zed-predict-pending-button", ep_icon)
.shape(IconButtonShape::Square)
.when_some(indicator_color, |this, color| {
this.indicator(Indicator::dot().color(color))
@@ -419,13 +432,13 @@ impl Render for EditPredictionButton {
let this = cx.weak_entity();
- let mut popover_menu = PopoverMenu::new("zeta")
+ let mut popover_menu = PopoverMenu::new("edit-prediction")
.when(user.is_some(), |popover_menu| {
let this = this.clone();
popover_menu.menu(move |window, cx| {
this.update(cx, |this, cx| {
- this.build_zeta_context_menu(provider, window, cx)
+ this.build_edit_prediction_context_menu(provider, window, cx)
})
.ok()
})
@@ -485,7 +498,7 @@ impl EditPredictionButton {
cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
.detach();
- CodestralCompletionProvider::ensure_api_key_loaded(client.http_client(), cx);
+ CodestralEditPredictionDelegate::ensure_api_key_loaded(client.http_client(), cx);
Self {
editor_subscription: None,
@@ -520,7 +533,7 @@ impl EditPredictionButton {
}
}
- if CodestralCompletionProvider::has_api_key(cx) {
+ if CodestralEditPredictionDelegate::has_api_key(cx) {
providers.push(EditPredictionProvider::Codestral);
}
@@ -530,6 +543,12 @@ impl EditPredictionButton {
));
}
+ if cx.has_flag::<MercuryFeatureFlag>() {
+ providers.push(EditPredictionProvider::Experimental(
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
+ ));
+ }
+
if cx.has_flag::<Zeta2FeatureFlag>() {
providers.push(EditPredictionProvider::Experimental(
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME,
@@ -599,8 +618,8 @@ impl EditPredictionButton {
EditPredictionProvider::Experimental(
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
) => {
- let has_api_token = zeta::Zeta::try_global(cx)
- .map_or(false, |zeta| zeta.read(cx).has_sweep_api_token());
+ let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
+ .map_or(false, |ep_store| ep_store.read(cx).has_sweep_api_token());
let should_open_modal = !has_api_token || is_current;
@@ -626,7 +645,66 @@ impl EditPredictionButton {
if let Some(workspace) = window.root::<Workspace>().flatten() {
workspace.update(cx, |workspace, cx| {
workspace.toggle_modal(window, cx, |window, cx| {
- SweepApiKeyModal::new(window, cx)
+ ExternalProviderApiKeyModal::new(
+ window,
+ cx,
+ |api_key, store, cx| {
+ store
+ .sweep_ai
+ .set_api_token(api_key, cx)
+ .detach_and_log_err(cx);
+ },
+ )
+ });
+ });
+ };
+ } else {
+ set_completion_provider(fs.clone(), cx, provider);
+ }
+ });
+
+ menu.item(entry)
+ }
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
+ ) => {
+ let has_api_token = edit_prediction::EditPredictionStore::try_global(cx)
+ .map_or(false, |ep_store| ep_store.read(cx).has_mercury_api_token());
+
+ let should_open_modal = !has_api_token || is_current;
+
+ let entry = if has_api_token {
+ ContextMenuEntry::new("Mercury")
+ .toggleable(IconPosition::Start, is_current)
+ } else {
+ ContextMenuEntry::new("Mercury")
+ .icon(IconName::XCircle)
+ .icon_color(Color::Error)
+ .documentation_aside(
+ DocumentationSide::Left,
+ DocumentationEdge::Bottom,
+ |_| {
+ Label::new("Click to configure your Mercury API token")
+ .into_any_element()
+ },
+ )
+ };
+
+ let entry = entry.handler(move |window, cx| {
+ if should_open_modal {
+ if let Some(workspace) = window.root::<Workspace>().flatten() {
+ workspace.update(cx, |workspace, cx| {
+ workspace.toggle_modal(window, cx, |window, cx| {
+ ExternalProviderApiKeyModal::new(
+ window,
+ cx,
+ |api_key, store, cx| {
+ store
+ .mercury
+ .set_api_token(api_key, cx)
+ .detach_and_log_err(cx);
+ },
+ )
});
});
};
@@ -947,8 +1025,8 @@ impl EditPredictionButton {
)
.context(editor_focus_handle)
.when(
- cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>(),
- |this| this.action("Rate Completions", RateCompletions.boxed_clone()),
+ cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>(),
+ |this| this.action("Rate Predictions", RatePredictions.boxed_clone()),
);
}
@@ -1016,7 +1094,7 @@ impl EditPredictionButton {
})
}
- fn build_zeta_context_menu(
+ fn build_edit_prediction_context_menu(
&self,
provider: EditPredictionProvider,
window: &mut Window,
@@ -1105,9 +1183,33 @@ impl EditPredictionButton {
.separator();
}
- let menu = self.build_language_settings_menu(menu, window, cx);
- let menu = self.add_provider_switching_section(menu, provider, cx);
+ menu = self.build_language_settings_menu(menu, window, cx);
+
+ if cx.has_flag::<Zeta2FeatureFlag>() {
+ let settings = all_language_settings(None, cx);
+ let context_retrieval = settings.edit_predictions.use_context;
+ menu = menu.separator().header("Context Retrieval").item(
+ ContextMenuEntry::new("Enable Context Retrieval")
+ .toggleable(IconPosition::Start, context_retrieval)
+ .action(workspace::ToggleEditPrediction.boxed_clone())
+ .handler({
+ let fs = self.fs.clone();
+ move |_, cx| {
+ update_settings_file(fs.clone(), cx, move |settings, _| {
+ settings
+ .project
+ .all_languages
+ .features
+ .get_or_insert_default()
+ .experimental_edit_prediction_context_retrieval =
+ Some(!context_retrieval)
+ });
+ }
+ }),
+ );
+ }
+ menu = self.add_provider_switching_section(menu, provider, cx);
menu
})
}
@@ -0,0 +1,389 @@
+use std::{
+ any::TypeId,
+ collections::VecDeque,
+ ops::Add,
+ sync::Arc,
+ time::{Duration, Instant},
+};
+
+use anyhow::Result;
+use client::{Client, UserStore};
+use editor::{Editor, PathKey};
+use futures::StreamExt as _;
+use gpui::{
+ Animation, AnimationExt, App, AppContext as _, Context, Entity, EventEmitter, FocusHandle,
+ Focusable, InteractiveElement as _, IntoElement as _, ParentElement as _, SharedString,
+ Styled as _, Task, TextAlign, Window, actions, div, pulsating_between,
+};
+use multi_buffer::MultiBuffer;
+use project::Project;
+use text::OffsetRangeExt;
+use ui::{
+ ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
+ StyledTypography as _, h_flex, v_flex,
+};
+
+use edit_prediction::{
+ ContextRetrievalFinishedDebugEvent, ContextRetrievalStartedDebugEvent, DebugEvent,
+ EditPredictionStore,
+};
+use workspace::Item;
+
+pub struct EditPredictionContextView {
+ empty_focus_handle: FocusHandle,
+ project: Entity<Project>,
+ store: Entity<EditPredictionStore>,
+ runs: VecDeque<RetrievalRun>,
+ current_ix: usize,
+ _update_task: Task<Result<()>>,
+}
+
+#[derive(Debug)]
+struct RetrievalRun {
+ editor: Entity<Editor>,
+ started_at: Instant,
+ metadata: Vec<(&'static str, SharedString)>,
+ finished_at: Option<Instant>,
+}
+
+actions!(
+ dev,
+ [
+ /// Go to the previous context retrieval run
+ EditPredictionContextGoBack,
+ /// Go to the next context retrieval run
+ EditPredictionContextGoForward
+ ]
+);
+
+impl EditPredictionContextView {
+ pub fn new(
+ project: Entity<Project>,
+ client: &Arc<Client>,
+ user_store: &Entity<UserStore>,
+ window: &mut gpui::Window,
+ cx: &mut Context<Self>,
+ ) -> Self {
+ let store = EditPredictionStore::global(client, user_store, cx);
+
+ let mut debug_rx = store.update(cx, |store, _| store.debug_info());
+ let _update_task = cx.spawn_in(window, async move |this, cx| {
+ while let Some(event) = debug_rx.next().await {
+ this.update_in(cx, |this, window, cx| {
+ this.handle_store_event(event, window, cx)
+ })?;
+ }
+ Ok(())
+ });
+
+ Self {
+ empty_focus_handle: cx.focus_handle(),
+ project,
+ runs: VecDeque::new(),
+ current_ix: 0,
+ store,
+ _update_task,
+ }
+ }
+
+ fn handle_store_event(
+ &mut self,
+ event: DebugEvent,
+ window: &mut gpui::Window,
+ cx: &mut Context<Self>,
+ ) {
+ match event {
+ DebugEvent::ContextRetrievalStarted(info) => {
+ if info.project_entity_id == self.project.entity_id() {
+ self.handle_context_retrieval_started(info, window, cx);
+ }
+ }
+ DebugEvent::ContextRetrievalFinished(info) => {
+ if info.project_entity_id == self.project.entity_id() {
+ self.handle_context_retrieval_finished(info, window, cx);
+ }
+ }
+ DebugEvent::EditPredictionRequested(_) => {}
+ }
+ }
+
+ fn handle_context_retrieval_started(
+ &mut self,
+ info: ContextRetrievalStartedDebugEvent,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ if self
+ .runs
+ .back()
+ .is_some_and(|run| run.finished_at.is_none())
+ {
+ self.runs.pop_back();
+ }
+
+ let multibuffer = cx.new(|_| MultiBuffer::new(language::Capability::ReadOnly));
+ let editor = cx
+ .new(|cx| Editor::for_multibuffer(multibuffer, Some(self.project.clone()), window, cx));
+
+ if self.runs.len() == 32 {
+ self.runs.pop_front();
+ }
+
+ self.runs.push_back(RetrievalRun {
+ editor,
+ started_at: info.timestamp,
+ finished_at: None,
+ metadata: Vec::new(),
+ });
+
+ cx.notify();
+ }
+
+ fn handle_context_retrieval_finished(
+ &mut self,
+ info: ContextRetrievalFinishedDebugEvent,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(run) = self.runs.back_mut() else {
+ return;
+ };
+
+ run.finished_at = Some(info.timestamp);
+ run.metadata = info.metadata;
+
+ let project = self.project.clone();
+ let related_files = self
+ .store
+ .read(cx)
+ .context_for_project(&self.project, cx)
+ .to_vec();
+
+ let editor = run.editor.clone();
+ let multibuffer = run.editor.read(cx).buffer().clone();
+
+ if self.current_ix + 2 == self.runs.len() {
+ self.current_ix += 1;
+ }
+
+ cx.spawn_in(window, async move |this, cx| {
+ let mut paths = Vec::new();
+ for related_file in related_files {
+ let (buffer, point_ranges): (_, Vec<_>) =
+ if let Some(buffer) = related_file.buffer.upgrade() {
+ let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
+
+ (
+ buffer,
+ related_file
+ .excerpts
+ .iter()
+ .map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
+ .collect(),
+ )
+ } else {
+ (
+ project
+ .update(cx, |project, cx| {
+ project.open_buffer(related_file.path.clone(), cx)
+ })?
+ .await?,
+ related_file
+ .excerpts
+ .iter()
+ .map(|excerpt| excerpt.point_range.clone())
+ .collect(),
+ )
+ };
+ cx.update(|_, cx| {
+ let path = PathKey::for_buffer(&buffer, cx);
+ paths.push((path, buffer, point_ranges));
+ })?;
+ }
+
+ multibuffer.update(cx, |multibuffer, cx| {
+ multibuffer.clear(cx);
+
+ for (path, buffer, ranges) in paths {
+ multibuffer.set_excerpts_for_path(path, buffer, ranges, 0, cx);
+ }
+ })?;
+
+ editor.update_in(cx, |editor, window, cx| {
+ editor.move_to_beginning(&Default::default(), window, cx);
+ })?;
+
+ this.update(cx, |_, cx| cx.notify())
+ })
+ .detach();
+ }
+
+ fn handle_go_back(
+ &mut self,
+ _: &EditPredictionContextGoBack,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.current_ix = self.current_ix.saturating_sub(1);
+ cx.focus_self(window);
+ cx.notify();
+ }
+
+ fn handle_go_forward(
+ &mut self,
+ _: &EditPredictionContextGoForward,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.current_ix = self
+ .current_ix
+ .add(1)
+ .min(self.runs.len().saturating_sub(1));
+ cx.focus_self(window);
+ cx.notify();
+ }
+
+ fn render_informational_footer(
+ &self,
+ cx: &mut Context<'_, EditPredictionContextView>,
+ ) -> ui::Div {
+ let run = &self.runs[self.current_ix];
+ let new_run_started = self
+ .runs
+ .back()
+ .map_or(false, |latest_run| latest_run.finished_at.is_none());
+
+ h_flex()
+ .p_2()
+ .w_full()
+ .font_buffer(cx)
+ .text_xs()
+ .border_t_1()
+ .gap_2()
+ .child(v_flex().h_full().flex_1().child({
+ let t0 = run.started_at;
+ let mut table = ui::Table::<2>::new().width(ui::px(300.)).no_ui_font();
+ for (key, value) in &run.metadata {
+ table = table.row([key.into_any_element(), value.clone().into_any_element()])
+ }
+ table = table.row([
+ "Total Time".into_any_element(),
+ format!("{} ms", (run.finished_at.unwrap_or(t0) - t0).as_millis())
+ .into_any_element(),
+ ]);
+ table
+ }))
+ .child(
+ v_flex().h_full().text_align(TextAlign::Right).child(
+ h_flex()
+ .justify_end()
+ .child(
+ IconButton::new("go-back", IconName::ChevronLeft)
+ .disabled(self.current_ix == 0 || self.runs.len() < 2)
+ .tooltip(ui::Tooltip::for_action_title(
+ "Go to previous run",
+ &EditPredictionContextGoBack,
+ ))
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.handle_go_back(&EditPredictionContextGoBack, window, cx);
+ })),
+ )
+ .child(
+ div()
+ .child(format!("{}/{}", self.current_ix + 1, self.runs.len()))
+ .map(|this| {
+ if new_run_started {
+ this.with_animation(
+ "pulsating-count",
+ Animation::new(Duration::from_secs(2))
+ .repeat()
+ .with_easing(pulsating_between(0.4, 0.8)),
+ |label, delta| label.opacity(delta),
+ )
+ .into_any_element()
+ } else {
+ this.into_any_element()
+ }
+ }),
+ )
+ .child(
+ IconButton::new("go-forward", IconName::ChevronRight)
+ .disabled(self.current_ix + 1 == self.runs.len())
+ .tooltip(ui::Tooltip::for_action_title(
+ "Go to next run",
+ &EditPredictionContextGoBack,
+ ))
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.handle_go_forward(
+ &EditPredictionContextGoForward,
+ window,
+ cx,
+ );
+ })),
+ ),
+ ),
+ )
+ }
+}
+
+impl Focusable for EditPredictionContextView {
+ fn focus_handle(&self, cx: &App) -> FocusHandle {
+ self.runs
+ .get(self.current_ix)
+ .map(|run| run.editor.read(cx).focus_handle(cx))
+ .unwrap_or_else(|| self.empty_focus_handle.clone())
+ }
+}
+
+impl EventEmitter<()> for EditPredictionContextView {}
+
+impl Item for EditPredictionContextView {
+ type Event = ();
+
+ fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
+ "Edit Prediction Context".into()
+ }
+
+ fn buffer_kind(&self, _cx: &App) -> workspace::item::ItemBufferKind {
+ workspace::item::ItemBufferKind::Multibuffer
+ }
+
+ fn act_as_type<'a>(
+ &'a self,
+ type_id: TypeId,
+ self_handle: &'a Entity<Self>,
+ _: &'a App,
+ ) -> Option<gpui::AnyEntity> {
+ if type_id == TypeId::of::<Self>() {
+ Some(self_handle.clone().into())
+ } else if type_id == TypeId::of::<Editor>() {
+ Some(self.runs.get(self.current_ix)?.editor.clone().into())
+ } else {
+ None
+ }
+ }
+}
+
+impl gpui::Render for EditPredictionContextView {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl ui::IntoElement {
+ v_flex()
+ .key_context("EditPredictionContext")
+ .on_action(cx.listener(Self::handle_go_back))
+ .on_action(cx.listener(Self::handle_go_forward))
+ .size_full()
+ .map(|this| {
+ if self.runs.is_empty() {
+ this.child(
+ v_flex()
+ .size_full()
+ .justify_center()
+ .items_center()
+ .child("No retrieval runs yet"),
+ )
+ } else {
+ this.child(self.runs[self.current_ix].editor.clone())
+ .child(self.render_informational_footer(cx))
+ }
+ })
+ }
+}
@@ -0,0 +1,128 @@
+mod edit_prediction_button;
+mod edit_prediction_context_view;
+mod external_provider_api_token_modal;
+mod rate_prediction_modal;
+
+use std::any::{Any as _, TypeId};
+
+use command_palette_hooks::CommandPaletteFilter;
+use edit_prediction::{ResetOnboarding, Zeta2FeatureFlag};
+use edit_prediction_context_view::EditPredictionContextView;
+use feature_flags::FeatureFlagAppExt as _;
+use gpui::actions;
+use project::DisableAiSettings;
+use rate_prediction_modal::RatePredictionsModal;
+use settings::{Settings as _, SettingsStore};
+use ui::{App, prelude::*};
+use workspace::{SplitDirection, Workspace};
+
+pub use edit_prediction_button::{EditPredictionButton, ToggleMenu};
+pub use external_provider_api_token_modal::ExternalProviderApiKeyModal;
+
+use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag;
+
+actions!(
+ dev,
+ [
+ /// Opens the edit prediction context view.
+ OpenEditPredictionContextView,
+ ]
+);
+
+actions!(
+ edit_prediction,
+ [
+ /// Opens the rate completions modal.
+ RatePredictions,
+ ]
+);
+
+pub fn init(cx: &mut App) {
+ feature_gate_predict_edits_actions(cx);
+
+ cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
+ workspace.register_action(|workspace, _: &RatePredictions, window, cx| {
+ if cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>() {
+ RatePredictionsModal::toggle(workspace, window, cx);
+ }
+ });
+
+ workspace.register_action_renderer(|div, _, _, cx| {
+ let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
+ div.when(has_flag, |div| {
+ div.on_action(cx.listener(
+ move |workspace, _: &OpenEditPredictionContextView, window, cx| {
+ let project = workspace.project();
+ workspace.split_item(
+ SplitDirection::Right,
+ Box::new(cx.new(|cx| {
+ EditPredictionContextView::new(
+ project.clone(),
+ workspace.client(),
+ workspace.user_store(),
+ window,
+ cx,
+ )
+ })),
+ window,
+ cx,
+ );
+ },
+ ))
+ })
+ });
+ })
+ .detach();
+}
+
+fn feature_gate_predict_edits_actions(cx: &mut App) {
+ let rate_completion_action_types = [TypeId::of::<RatePredictions>()];
+ let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
+ let all_action_types = [
+ TypeId::of::<RatePredictions>(),
+ TypeId::of::<edit_prediction::ResetOnboarding>(),
+ zed_actions::OpenZedPredictOnboarding.type_id(),
+ TypeId::of::<edit_prediction::ClearHistory>(),
+ TypeId::of::<rate_prediction_modal::ThumbsUpActivePrediction>(),
+ TypeId::of::<rate_prediction_modal::ThumbsDownActivePrediction>(),
+ TypeId::of::<rate_prediction_modal::NextEdit>(),
+ TypeId::of::<rate_prediction_modal::PreviousEdit>(),
+ ];
+
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ filter.hide_action_types(&rate_completion_action_types);
+ filter.hide_action_types(&reset_onboarding_action_types);
+ filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
+ });
+
+ cx.observe_global::<SettingsStore>(move |cx| {
+ let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
+ let has_feature_flag = cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>();
+
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ if is_ai_disabled {
+ filter.hide_action_types(&all_action_types);
+ } else if has_feature_flag {
+ filter.show_action_types(&rate_completion_action_types);
+ } else {
+ filter.hide_action_types(&rate_completion_action_types);
+ }
+ });
+ })
+ .detach();
+
+ cx.observe_flag::<PredictEditsRatePredictionsFeatureFlag, _>(move |is_enabled, cx| {
+ if !DisableAiSettings::get_global(cx).disable_ai {
+ if is_enabled {
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ filter.show_action_types(&rate_completion_action_types);
+ });
+ } else {
+ CommandPaletteFilter::update_global(cx, |filter, _cx| {
+ filter.hide_action_types(&rate_completion_action_types);
+ });
+ }
+ }
+ })
+ .detach();
+}
@@ -1,23 +1,29 @@
+use edit_prediction::EditPredictionStore;
use gpui::{
DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, IntoElement, ParentElement, Render,
};
use ui::{Button, ButtonStyle, Clickable, Headline, HeadlineSize, prelude::*};
use ui_input::InputField;
use workspace::ModalView;
-use zeta::Zeta;
-pub struct SweepApiKeyModal {
+pub struct ExternalProviderApiKeyModal {
api_key_input: Entity<InputField>,
focus_handle: FocusHandle,
+ on_confirm: Box<dyn Fn(Option<String>, &mut EditPredictionStore, &mut App)>,
}
-impl SweepApiKeyModal {
- pub fn new(window: &mut Window, cx: &mut Context<Self>) -> Self {
- let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your Sweep API token"));
+impl ExternalProviderApiKeyModal {
+ pub fn new(
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ on_confirm: impl Fn(Option<String>, &mut EditPredictionStore, &mut App) + 'static,
+ ) -> Self {
+ let api_key_input = cx.new(|cx| InputField::new(window, cx, "Enter your API key"));
Self {
api_key_input,
focus_handle: cx.focus_handle(),
+ on_confirm: Box::new(on_confirm),
}
}
@@ -29,39 +35,35 @@ impl SweepApiKeyModal {
let api_key = self.api_key_input.read(cx).text(cx);
let api_key = (!api_key.trim().is_empty()).then_some(api_key);
- if let Some(zeta) = Zeta::try_global(cx) {
- zeta.update(cx, |zeta, cx| {
- zeta.sweep_ai
- .set_api_token(api_key, cx)
- .detach_and_log_err(cx);
- });
+ if let Some(ep_store) = EditPredictionStore::try_global(cx) {
+ ep_store.update(cx, |ep_store, cx| (self.on_confirm)(api_key, ep_store, cx))
}
cx.emit(DismissEvent);
}
}
-impl EventEmitter<DismissEvent> for SweepApiKeyModal {}
+impl EventEmitter<DismissEvent> for ExternalProviderApiKeyModal {}
-impl ModalView for SweepApiKeyModal {}
+impl ModalView for ExternalProviderApiKeyModal {}
-impl Focusable for SweepApiKeyModal {
+impl Focusable for ExternalProviderApiKeyModal {
fn focus_handle(&self, _cx: &App) -> FocusHandle {
self.focus_handle.clone()
}
}
-impl Render for SweepApiKeyModal {
+impl Render for ExternalProviderApiKeyModal {
fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
- .key_context("SweepApiKeyModal")
+ .key_context("ExternalApiKeyModal")
.on_action(cx.listener(Self::cancel))
.on_action(cx.listener(Self::confirm))
.elevation_2(cx)
.w(px(400.))
.p_4()
.gap_3()
- .child(Headline::new("Sweep API Token").size(HeadlineSize::Small))
+ .child(Headline::new("API Token").size(HeadlineSize::Small))
.child(self.api_key_input.clone())
.child(
h_flex()
@@ -1,7 +1,8 @@
-use crate::{EditPrediction, EditPredictionRating, Zeta};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use cloud_zeta2_prompt::write_codeblock;
+use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
use editor::{Editor, ExcerptRange, MultiBuffer};
+use feature_flags::FeatureFlag;
use gpui::{
App, BorderStyle, DismissEvent, EdgesRefinement, Entity, EventEmitter, FocusHandle, Focusable,
Length, StyleRefinement, TextStyleRefinement, Window, actions, prelude::*,
@@ -9,9 +10,7 @@ use gpui::{
use language::{LanguageRegistry, Point, language_settings};
use markdown::{Markdown, MarkdownStyle};
use settings::Settings as _;
-use std::fmt::Write;
-use std::sync::Arc;
-use std::time::Duration;
+use std::{fmt::Write, sync::Arc, time::Duration};
use theme::ThemeSettings;
use ui::{KeyBinding, List, ListItem, ListItemSpacing, Tooltip, prelude::*};
use workspace::{ModalView, Workspace};
@@ -34,8 +33,14 @@ actions!(
]
);
+pub struct PredictEditsRatePredictionsFeatureFlag;
+
+impl FeatureFlag for PredictEditsRatePredictionsFeatureFlag {
+ const NAME: &'static str = "predict-edits-rate-completions";
+}
+
pub struct RatePredictionsModal {
- zeta: Entity<Zeta>,
+ ep_store: Entity<EditPredictionStore>,
language_registry: Arc<LanguageRegistry>,
active_prediction: Option<ActivePrediction>,
selected_index: usize,
@@ -68,10 +73,10 @@ impl RatePredictionView {
impl RatePredictionsModal {
pub fn toggle(workspace: &mut Workspace, window: &mut Window, cx: &mut Context<Workspace>) {
- if let Some(zeta) = Zeta::try_global(cx) {
+ if let Some(ep_store) = EditPredictionStore::try_global(cx) {
let language_registry = workspace.app_state().languages.clone();
workspace.toggle_modal(window, cx, |window, cx| {
- RatePredictionsModal::new(zeta, language_registry, window, cx)
+ RatePredictionsModal::new(ep_store, language_registry, window, cx)
});
telemetry::event!("Rate Prediction Modal Open", source = "Edit Prediction");
@@ -79,15 +84,15 @@ impl RatePredictionsModal {
}
pub fn new(
- zeta: Entity<Zeta>,
+ ep_store: Entity<EditPredictionStore>,
language_registry: Arc<LanguageRegistry>,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
- let subscription = cx.observe(&zeta, |_, _, cx| cx.notify());
+ let subscription = cx.observe(&ep_store, |_, _, cx| cx.notify());
Self {
- zeta,
+ ep_store,
language_registry,
selected_index: 0,
focus_handle: cx.focus_handle(),
@@ -113,7 +118,7 @@ impl RatePredictionsModal {
self.selected_index += 1;
self.selected_index = usize::min(
self.selected_index,
- self.zeta.read(cx).shown_predictions().count(),
+ self.ep_store.read(cx).shown_predictions().count(),
);
cx.notify();
}
@@ -130,7 +135,7 @@ impl RatePredictionsModal {
fn select_next_edit(&mut self, _: &NextEdit, _: &mut Window, cx: &mut Context<Self>) {
let next_index = self
- .zeta
+ .ep_store
.read(cx)
.shown_predictions()
.skip(self.selected_index)
@@ -146,11 +151,11 @@ impl RatePredictionsModal {
}
fn select_prev_edit(&mut self, _: &PreviousEdit, _: &mut Window, cx: &mut Context<Self>) {
- let zeta = self.zeta.read(cx);
- let completions_len = zeta.shown_completions_len();
+ let ep_store = self.ep_store.read(cx);
+ let completions_len = ep_store.shown_completions_len();
let prev_index = self
- .zeta
+ .ep_store
.read(cx)
.shown_predictions()
.rev()
@@ -173,7 +178,7 @@ impl RatePredictionsModal {
}
fn select_last(&mut self, _: &menu::SelectLast, _window: &mut Window, cx: &mut Context<Self>) {
- self.selected_index = self.zeta.read(cx).shown_completions_len() - 1;
+ self.selected_index = self.ep_store.read(cx).shown_completions_len() - 1;
cx.notify();
}
@@ -183,9 +188,9 @@ impl RatePredictionsModal {
window: &mut Window,
cx: &mut Context<Self>,
) {
- self.zeta.update(cx, |zeta, cx| {
+ self.ep_store.update(cx, |ep_store, cx| {
if let Some(active) = &self.active_prediction {
- zeta.rate_prediction(
+ ep_store.rate_prediction(
&active.prediction,
EditPredictionRating::Positive,
active.feedback_editor.read(cx).text(cx),
@@ -216,8 +221,8 @@ impl RatePredictionsModal {
return;
}
- self.zeta.update(cx, |zeta, cx| {
- zeta.rate_prediction(
+ self.ep_store.update(cx, |ep_store, cx| {
+ ep_store.rate_prediction(
&active.prediction,
EditPredictionRating::Negative,
active.feedback_editor.read(cx).text(cx),
@@ -254,7 +259,7 @@ impl RatePredictionsModal {
cx: &mut Context<Self>,
) {
let completion = self
- .zeta
+ .ep_store
.read(cx)
.shown_predictions()
.skip(self.selected_index)
@@ -267,7 +272,7 @@ impl RatePredictionsModal {
fn confirm(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
let completion = self
- .zeta
+ .ep_store
.read(cx)
.shown_predictions()
.skip(self.selected_index)
@@ -288,7 +293,7 @@ impl RatePredictionsModal {
// Avoid resetting completion rating if it's already selected.
if let Some(prediction) = prediction {
self.selected_index = self
- .zeta
+ .ep_store
.read(cx)
.shown_predictions()
.enumerate()
@@ -376,7 +381,7 @@ impl RatePredictionsModal {
&included_file.path,
&included_file.excerpts,
if included_file.path == prediction.inputs.cursor_path {
- cursor_insertions
+ cursor_insertions.as_slice()
} else {
&[]
},
@@ -564,7 +569,7 @@ impl RatePredictionsModal {
let border_color = cx.theme().colors().border;
let bg_color = cx.theme().colors().editor_background;
- let rated = self.zeta.read(cx).is_prediction_rated(&completion_id);
+ let rated = self.ep_store.read(cx).is_prediction_rated(&completion_id);
let feedback_empty = active_prediction
.feedback_editor
.read(cx)
@@ -715,7 +720,7 @@ impl RatePredictionsModal {
}
fn render_shown_completions(&self, cx: &Context<Self>) -> impl Iterator<Item = ListItem> {
- self.zeta
+ self.ep_store
.read(cx)
.shown_predictions()
.cloned()
@@ -725,7 +730,7 @@ impl RatePredictionsModal {
.active_prediction
.as_ref()
.is_some_and(|selected| selected.prediction.id == completion.id);
- let rated = self.zeta.read(cx).is_prediction_rated(&completion.id);
+ let rated = self.ep_store.read(cx).is_prediction_rated(&completion.id);
let (icon_name, icon_color, tooltip_text) =
match (rated, completion.edits.is_empty()) {
@@ -49,7 +49,7 @@ fs.workspace = true
git.workspace = true
gpui.workspace = true
indoc.workspace = true
-edit_prediction.workspace = true
+edit_prediction_types.workspace = true
itertools.workspace = true
language.workspace = true
linkify.workspace = true
@@ -84,6 +84,8 @@ tree-sitter-html = { workspace = true, optional = true }
tree-sitter-rust = { workspace = true, optional = true }
tree-sitter-typescript = { workspace = true, optional = true }
tree-sitter-python = { workspace = true, optional = true }
+ztracing.workspace = true
+tracing.workspace = true
unicode-segmentation.workspace = true
unicode-script.workspace = true
unindent = { workspace = true, optional = true }
@@ -94,6 +96,7 @@ uuid.workspace = true
vim_mode_setting.workspace = true
workspace.workspace = true
zed_actions.workspace = true
+zlog.workspace = true
[dev-dependencies]
criterion.workspace = true
@@ -118,6 +121,7 @@ tree-sitter-rust.workspace = true
tree-sitter-typescript.workspace = true
tree-sitter-yaml.workspace = true
tree-sitter-bash.workspace = true
+tree-sitter-md.workspace = true
unindent.workspace = true
util = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }
@@ -453,8 +453,6 @@ actions!(
CollapseAllDiffHunks,
/// Expands macros recursively at cursor position.
ExpandMacroRecursively,
- /// Finds all references to the symbol at cursor.
- FindAllReferences,
/// Finds the next match in the search.
FindNextMatch,
/// Finds the previous match in the search.
@@ -827,3 +825,20 @@ actions!(
WrapSelectionsInTag
]
);
+
+/// Finds all references to the symbol at cursor.
+#[derive(PartialEq, Clone, Deserialize, JsonSchema, Action)]
+#[action(namespace = editor)]
+#[serde(deny_unknown_fields)]
+pub struct FindAllReferences {
+ #[serde(default = "default_true")]
+ pub always_open_multibuffer: bool,
+}
+
+impl Default for FindAllReferences {
+ fn default() -> Self {
+ Self {
+ always_open_multibuffer: true,
+ }
+ }
+}
@@ -164,6 +164,7 @@ impl<T> BlockPlacement<T> {
}
impl BlockPlacement<Anchor> {
+ #[ztracing::instrument(skip_all)]
fn cmp(&self, other: &Self, buffer: &MultiBufferSnapshot) -> Ordering {
self.start()
.cmp(other.start(), buffer)
@@ -171,6 +172,7 @@ impl BlockPlacement<Anchor> {
.then_with(|| self.tie_break().cmp(&other.tie_break()))
}
+ #[ztracing::instrument(skip_all)]
fn to_wrap_row(&self, wrap_snapshot: &WrapSnapshot) -> Option<BlockPlacement<WrapRow>> {
let buffer_snapshot = wrap_snapshot.buffer_snapshot();
match self {
@@ -474,6 +476,7 @@ pub struct BlockRows<'a> {
}
impl BlockMap {
+ #[ztracing::instrument(skip_all)]
pub fn new(
wrap_snapshot: WrapSnapshot,
buffer_header_height: u32,
@@ -503,6 +506,7 @@ impl BlockMap {
map
}
+ #[ztracing::instrument(skip_all)]
pub fn read(&self, wrap_snapshot: WrapSnapshot, edits: WrapPatch) -> BlockMapReader<'_> {
self.sync(&wrap_snapshot, edits);
*self.wrap_snapshot.borrow_mut() = wrap_snapshot.clone();
@@ -518,13 +522,17 @@ impl BlockMap {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn write(&mut self, wrap_snapshot: WrapSnapshot, edits: WrapPatch) -> BlockMapWriter<'_> {
self.sync(&wrap_snapshot, edits);
*self.wrap_snapshot.borrow_mut() = wrap_snapshot;
BlockMapWriter(self)
}
+ #[ztracing::instrument(skip_all, fields(edits))]
fn sync(&self, wrap_snapshot: &WrapSnapshot, mut edits: WrapPatch) {
+ let _timer = zlog::time!("BlockMap::sync").warn_if_gt(std::time::Duration::from_millis(50));
+
let buffer = wrap_snapshot.buffer_snapshot();
// Handle changing the last excerpt if it is empty.
@@ -784,6 +792,7 @@ impl BlockMap {
*transforms = new_transforms;
}
+ #[ztracing::instrument(skip_all)]
pub fn replace_blocks(&mut self, mut renderers: HashMap<CustomBlockId, RenderBlock>) {
for block in &mut self.custom_blocks {
if let Some(render) = renderers.remove(&block.id) {
@@ -793,6 +802,7 @@ impl BlockMap {
}
/// Guarantees that `wrap_row_for` is called with points in increasing order.
+ #[ztracing::instrument(skip_all)]
fn header_and_footer_blocks<'a, R, T>(
&'a self,
buffer: &'a multi_buffer::MultiBufferSnapshot,
@@ -880,6 +890,7 @@ impl BlockMap {
})
}
+ #[ztracing::instrument(skip_all)]
fn sort_blocks(blocks: &mut Vec<(BlockPlacement<WrapRow>, Block)>) {
blocks.sort_unstable_by(|(placement_a, block_a), (placement_b, block_b)| {
placement_a
@@ -1016,6 +1027,7 @@ impl DerefMut for BlockMapReader<'_> {
}
impl BlockMapReader<'_> {
+ #[ztracing::instrument(skip_all)]
pub fn row_for_block(&self, block_id: CustomBlockId) -> Option<BlockRow> {
let block = self.blocks.iter().find(|block| block.id == block_id)?;
let buffer_row = block
@@ -1054,6 +1066,7 @@ impl BlockMapReader<'_> {
}
impl BlockMapWriter<'_> {
+ #[ztracing::instrument(skip_all)]
pub fn insert(
&mut self,
blocks: impl IntoIterator<Item = BlockProperties<Anchor>>,
@@ -1120,6 +1133,7 @@ impl BlockMapWriter<'_> {
ids
}
+ #[ztracing::instrument(skip_all)]
pub fn resize(&mut self, mut heights: HashMap<CustomBlockId, u32>) {
let wrap_snapshot = &*self.0.wrap_snapshot.borrow();
let buffer = wrap_snapshot.buffer_snapshot();
@@ -1172,6 +1186,7 @@ impl BlockMapWriter<'_> {
self.0.sync(wrap_snapshot, edits);
}
+ #[ztracing::instrument(skip_all)]
pub fn remove(&mut self, block_ids: HashSet<CustomBlockId>) {
let wrap_snapshot = &*self.0.wrap_snapshot.borrow();
let buffer = wrap_snapshot.buffer_snapshot();
@@ -1217,6 +1232,7 @@ impl BlockMapWriter<'_> {
self.0.sync(wrap_snapshot, edits);
}
+ #[ztracing::instrument(skip_all)]
pub fn remove_intersecting_replace_blocks(
&mut self,
ranges: impl IntoIterator<Item = Range<MultiBufferOffset>>,
@@ -1239,6 +1255,7 @@ impl BlockMapWriter<'_> {
self.0.buffers_with_disabled_headers.insert(buffer_id);
}
+ #[ztracing::instrument(skip_all)]
pub fn fold_buffers(
&mut self,
buffer_ids: impl IntoIterator<Item = BufferId>,
@@ -1248,6 +1265,7 @@ impl BlockMapWriter<'_> {
self.fold_or_unfold_buffers(true, buffer_ids, multi_buffer, cx);
}
+ #[ztracing::instrument(skip_all)]
pub fn unfold_buffers(
&mut self,
buffer_ids: impl IntoIterator<Item = BufferId>,
@@ -1257,6 +1275,7 @@ impl BlockMapWriter<'_> {
self.fold_or_unfold_buffers(false, buffer_ids, multi_buffer, cx);
}
+ #[ztracing::instrument(skip_all)]
fn fold_or_unfold_buffers(
&mut self,
fold: bool,
@@ -1292,6 +1311,7 @@ impl BlockMapWriter<'_> {
self.0.sync(&wrap_snapshot, edits);
}
+ #[ztracing::instrument(skip_all)]
fn blocks_intersecting_buffer_range(
&self,
range: Range<MultiBufferOffset>,
@@ -1326,6 +1346,7 @@ impl BlockMapWriter<'_> {
impl BlockSnapshot {
#[cfg(test)]
+ #[ztracing::instrument(skip_all)]
pub fn text(&self) -> String {
self.chunks(
BlockRow(0)..self.transforms.summary().output_rows,
@@ -1337,6 +1358,7 @@ impl BlockSnapshot {
.collect()
}
+ #[ztracing::instrument(skip_all)]
pub(crate) fn chunks<'a>(
&'a self,
rows: Range<BlockRow>,
@@ -1378,6 +1400,7 @@ impl BlockSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub(super) fn row_infos(&self, start_row: BlockRow) -> BlockRows<'_> {
let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(());
cursor.seek(&start_row, Bias::Right);
@@ -1399,6 +1422,7 @@ impl BlockSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn blocks_in_range(
&self,
rows: Range<BlockRow>,
@@ -1432,6 +1456,7 @@ impl BlockSnapshot {
})
}
+ #[ztracing::instrument(skip_all)]
pub(crate) fn sticky_header_excerpt(&self, position: f64) -> Option<StickyHeaderExcerpt<'_>> {
let top_row = position as u32;
let mut cursor = self.transforms.cursor::<BlockRow>(());
@@ -1455,6 +1480,7 @@ impl BlockSnapshot {
None
}
+ #[ztracing::instrument(skip_all)]
pub fn block_for_id(&self, block_id: BlockId) -> Option<Block> {
let buffer = self.wrap_snapshot.buffer_snapshot();
let wrap_point = match block_id {
@@ -1491,6 +1517,7 @@ impl BlockSnapshot {
None
}
+ #[ztracing::instrument(skip_all)]
pub fn max_point(&self) -> BlockPoint {
let row = self
.transforms
@@ -1500,10 +1527,12 @@ impl BlockSnapshot {
BlockPoint::new(row, self.line_len(row))
}
+ #[ztracing::instrument(skip_all)]
pub fn longest_row(&self) -> BlockRow {
self.transforms.summary().longest_row
}
+ #[ztracing::instrument(skip_all)]
pub fn longest_row_in_range(&self, range: Range<BlockRow>) -> BlockRow {
let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(());
cursor.seek(&range.start, Bias::Right);
@@ -1555,6 +1584,7 @@ impl BlockSnapshot {
longest_row
}
+ #[ztracing::instrument(skip_all)]
pub(super) fn line_len(&self, row: BlockRow) -> u32 {
let (start, _, item) =
self.transforms
@@ -1574,11 +1604,13 @@ impl BlockSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub(super) fn is_block_line(&self, row: BlockRow) -> bool {
let (_, _, item) = self.transforms.find::<BlockRow, _>((), &row, Bias::Right);
item.is_some_and(|t| t.block.is_some())
}
+ #[ztracing::instrument(skip_all)]
pub(super) fn is_folded_buffer_header(&self, row: BlockRow) -> bool {
let (_, _, item) = self.transforms.find::<BlockRow, _>((), &row, Bias::Right);
let Some(transform) = item else {
@@ -1587,6 +1619,7 @@ impl BlockSnapshot {
matches!(transform.block, Some(Block::FoldedBuffer { .. }))
}
+ #[ztracing::instrument(skip_all)]
pub(super) fn is_line_replaced(&self, row: MultiBufferRow) -> bool {
let wrap_point = self
.wrap_snapshot
@@ -1602,6 +1635,7 @@ impl BlockSnapshot {
})
}
+ #[ztracing::instrument(skip_all)]
pub fn clip_point(&self, point: BlockPoint, bias: Bias) -> BlockPoint {
let mut cursor = self.transforms.cursor::<Dimensions<BlockRow, WrapRow>>(());
cursor.seek(&BlockRow(point.row), Bias::Right);
@@ -1663,6 +1697,7 @@ impl BlockSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn to_block_point(&self, wrap_point: WrapPoint) -> BlockPoint {
let (start, _, item) = self.transforms.find::<Dimensions<WrapRow, BlockRow>, _>(
(),
@@ -1684,6 +1719,7 @@ impl BlockSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn to_wrap_point(&self, block_point: BlockPoint, bias: Bias) -> WrapPoint {
let (start, end, item) = self.transforms.find::<Dimensions<BlockRow, WrapRow>, _>(
(),
@@ -1719,6 +1755,7 @@ impl BlockSnapshot {
impl BlockChunks<'_> {
/// Go to the next transform
+ #[ztracing::instrument(skip_all)]
fn advance(&mut self) {
self.input_chunk = Chunk::default();
self.transforms.next();
@@ -1759,6 +1796,7 @@ pub struct StickyHeaderExcerpt<'a> {
impl<'a> Iterator for BlockChunks<'a> {
type Item = Chunk<'a>;
+ #[ztracing::instrument(skip_all)]
fn next(&mut self) -> Option<Self::Item> {
if self.output_row >= self.max_output_row {
return None;
@@ -1858,6 +1896,7 @@ impl<'a> Iterator for BlockChunks<'a> {
impl Iterator for BlockRows<'_> {
type Item = RowInfo;
+ #[ztracing::instrument(skip_all)]
fn next(&mut self) -> Option<Self::Item> {
if self.started {
self.output_row.0 += 1;
@@ -1960,14 +1999,17 @@ impl DerefMut for BlockContext<'_, '_> {
}
impl CustomBlock {
+ #[ztracing::instrument(skip_all)]
pub fn render(&self, cx: &mut BlockContext) -> AnyElement {
self.render.lock()(cx)
}
+ #[ztracing::instrument(skip_all)]
pub fn start(&self) -> Anchor {
*self.placement.start()
}
+ #[ztracing::instrument(skip_all)]
pub fn end(&self) -> Anchor {
*self.placement.end()
}
@@ -19,6 +19,7 @@ pub struct CreaseMap {
}
impl CreaseMap {
+ #[ztracing::instrument(skip_all)]
pub fn new(snapshot: &MultiBufferSnapshot) -> Self {
CreaseMap {
snapshot: CreaseSnapshot::new(snapshot),
@@ -40,11 +41,13 @@ impl CreaseSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn creases(&self) -> impl Iterator<Item = (CreaseId, &Crease<Anchor>)> {
self.creases.iter().map(|item| (item.id, &item.crease))
}
/// Returns the first Crease starting on the specified buffer row.
+ #[ztracing::instrument(skip_all)]
pub fn query_row<'a>(
&'a self,
row: MultiBufferRow,
@@ -69,6 +72,7 @@ impl CreaseSnapshot {
None
}
+ #[ztracing::instrument(skip_all)]
pub fn creases_in_range<'a>(
&'a self,
range: Range<MultiBufferRow>,
@@ -95,6 +99,7 @@ impl CreaseSnapshot {
})
}
+ #[ztracing::instrument(skip_all)]
pub fn crease_items_with_offsets(
&self,
snapshot: &MultiBufferSnapshot,
@@ -156,6 +161,7 @@ pub struct CreaseMetadata {
}
impl<T> Crease<T> {
+ #[ztracing::instrument(skip_all)]
pub fn simple(range: Range<T>, placeholder: FoldPlaceholder) -> Self {
Crease::Inline {
range,
@@ -166,6 +172,7 @@ impl<T> Crease<T> {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn block(range: Range<T>, height: u32, style: BlockStyle, render: RenderBlock) -> Self {
Self::Block {
range,
@@ -177,6 +184,7 @@ impl<T> Crease<T> {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn inline<RenderToggle, ToggleElement, RenderTrailer, TrailerElement>(
range: Range<T>,
placeholder: FoldPlaceholder,
@@ -216,6 +224,7 @@ impl<T> Crease<T> {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn with_metadata(self, metadata: CreaseMetadata) -> Self {
match self {
Crease::Inline {
@@ -235,6 +244,7 @@ impl<T> Crease<T> {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn range(&self) -> &Range<T> {
match self {
Crease::Inline { range, .. } => range,
@@ -242,6 +252,7 @@ impl<T> Crease<T> {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn metadata(&self) -> Option<&CreaseMetadata> {
match self {
Self::Inline { metadata, .. } => metadata.as_ref(),
@@ -287,6 +298,7 @@ impl CreaseMap {
self.snapshot.clone()
}
+ #[ztracing::instrument(skip_all)]
pub fn insert(
&mut self,
creases: impl IntoIterator<Item = Crease<Anchor>>,
@@ -312,6 +324,7 @@ impl CreaseMap {
new_ids
}
+ #[ztracing::instrument(skip_all)]
pub fn remove(
&mut self,
ids: impl IntoIterator<Item = CreaseId>,
@@ -379,6 +392,7 @@ impl sum_tree::Summary for ItemSummary {
impl sum_tree::Item for CreaseItem {
type Summary = ItemSummary;
+ #[ztracing::instrument(skip_all)]
fn summary(&self, _cx: &MultiBufferSnapshot) -> Self::Summary {
ItemSummary {
range: self.crease.range().clone(),
@@ -388,12 +402,14 @@ impl sum_tree::Item for CreaseItem {
/// Implements `SeekTarget` for `Range<Anchor>` to enable seeking within a `SumTree` of `CreaseItem`s.
impl SeekTarget<'_, ItemSummary, ItemSummary> for Range<Anchor> {
+ #[ztracing::instrument(skip_all)]
fn cmp(&self, cursor_location: &ItemSummary, snapshot: &MultiBufferSnapshot) -> Ordering {
AnchorRangeExt::cmp(self, &cursor_location.range, snapshot)
}
}
impl SeekTarget<'_, ItemSummary, ItemSummary> for Anchor {
+ #[ztracing::instrument(skip_all)]
fn cmp(&self, other: &ItemSummary, snapshot: &MultiBufferSnapshot) -> Ordering {
self.cmp(&other.range.start, snapshot)
}
@@ -461,6 +477,7 @@ mod test {
}
#[gpui::test]
+ #[ztracing::instrument(skip_all)]
fn test_creases_in_range(cx: &mut App) {
let text = "line1\nline2\nline3\nline4\nline5\nline6\nline7";
let buffer = MultiBuffer::build_simple(text, cx);
@@ -30,6 +30,7 @@ struct HighlightEndpoint {
}
impl<'a> CustomHighlightsChunks<'a> {
+ #[ztracing::instrument(skip_all)]
pub fn new(
range: Range<MultiBufferOffset>,
language_aware: bool,
@@ -51,6 +52,7 @@ impl<'a> CustomHighlightsChunks<'a> {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn seek(&mut self, new_range: Range<MultiBufferOffset>) {
self.highlight_endpoints =
create_highlight_endpoints(&new_range, self.text_highlights, self.multibuffer_snapshot);
@@ -108,6 +110,7 @@ fn create_highlight_endpoints(
impl<'a> Iterator for CustomHighlightsChunks<'a> {
type Item = Chunk<'a>;
+ #[ztracing::instrument(skip_all)]
fn next(&mut self) -> Option<Self::Item> {
let mut next_highlight_endpoint = MultiBufferOffset(usize::MAX);
while let Some(endpoint) = self.highlight_endpoints.peek().copied() {
@@ -99,6 +99,7 @@ impl FoldPoint {
&mut self.0.column
}
+ #[ztracing::instrument(skip_all)]
pub fn to_inlay_point(self, snapshot: &FoldSnapshot) -> InlayPoint {
let (start, _, _) = snapshot
.transforms
@@ -107,6 +108,7 @@ impl FoldPoint {
InlayPoint(start.1.0 + overshoot)
}
+ #[ztracing::instrument(skip_all)]
pub fn to_offset(self, snapshot: &FoldSnapshot) -> FoldOffset {
let (start, _, item) = snapshot
.transforms
@@ -138,6 +140,7 @@ impl<'a> sum_tree::Dimension<'a, TransformSummary> for FoldPoint {
pub(crate) struct FoldMapWriter<'a>(&'a mut FoldMap);
impl FoldMapWriter<'_> {
+ #[ztracing::instrument(skip_all)]
pub(crate) fn fold<T: ToOffset>(
&mut self,
ranges: impl IntoIterator<Item = (Range<T>, FoldPlaceholder)>,
@@ -202,6 +205,7 @@ impl FoldMapWriter<'_> {
}
/// Removes any folds with the given ranges.
+ #[ztracing::instrument(skip_all)]
pub(crate) fn remove_folds<T: ToOffset>(
&mut self,
ranges: impl IntoIterator<Item = Range<T>>,
@@ -215,6 +219,7 @@ impl FoldMapWriter<'_> {
}
/// Removes any folds whose ranges intersect the given ranges.
+ #[ztracing::instrument(skip_all)]
pub(crate) fn unfold_intersecting<T: ToOffset>(
&mut self,
ranges: impl IntoIterator<Item = Range<T>>,
@@ -225,6 +230,7 @@ impl FoldMapWriter<'_> {
/// Removes any folds that intersect the given ranges and for which the given predicate
/// returns true.
+ #[ztracing::instrument(skip_all)]
fn remove_folds_with<T: ToOffset>(
&mut self,
ranges: impl IntoIterator<Item = Range<T>>,
@@ -277,6 +283,7 @@ impl FoldMapWriter<'_> {
(self.0.snapshot.clone(), edits)
}
+ #[ztracing::instrument(skip_all)]
pub(crate) fn update_fold_widths(
&mut self,
new_widths: impl IntoIterator<Item = (ChunkRendererId, Pixels)>,
@@ -326,6 +333,7 @@ pub struct FoldMap {
}
impl FoldMap {
+ #[ztracing::instrument(skip_all)]
pub fn new(inlay_snapshot: InlaySnapshot) -> (Self, FoldSnapshot) {
let this = Self {
snapshot: FoldSnapshot {
@@ -350,6 +358,7 @@ impl FoldMap {
(this, snapshot)
}
+ #[ztracing::instrument(skip_all)]
pub fn read(
&mut self,
inlay_snapshot: InlaySnapshot,
@@ -360,6 +369,7 @@ impl FoldMap {
(self.snapshot.clone(), edits)
}
+ #[ztracing::instrument(skip_all)]
pub(crate) fn write(
&mut self,
inlay_snapshot: InlaySnapshot,
@@ -369,6 +379,7 @@ impl FoldMap {
(FoldMapWriter(self), snapshot, edits)
}
+ #[ztracing::instrument(skip_all)]
fn check_invariants(&self) {
if cfg!(test) {
assert_eq!(
@@ -398,6 +409,7 @@ impl FoldMap {
}
}
+ #[ztracing::instrument(skip_all)]
fn sync(
&mut self,
inlay_snapshot: InlaySnapshot,
@@ -645,6 +657,7 @@ impl FoldSnapshot {
&self.inlay_snapshot.buffer
}
+ #[ztracing::instrument(skip_all)]
fn fold_width(&self, fold_id: &FoldId) -> Option<Pixels> {
self.fold_metadata_by_id.get(fold_id)?.width
}
@@ -665,6 +678,7 @@ impl FoldSnapshot {
self.folds.items(&self.inlay_snapshot.buffer).len()
}
+ #[ztracing::instrument(skip_all)]
pub fn text_summary_for_range(&self, range: Range<FoldPoint>) -> MBTextSummary {
let mut summary = MBTextSummary::default();
@@ -718,6 +732,7 @@ impl FoldSnapshot {
summary
}
+ #[ztracing::instrument(skip_all)]
pub fn to_fold_point(&self, point: InlayPoint, bias: Bias) -> FoldPoint {
let (start, end, item) = self
.transforms
@@ -734,6 +749,7 @@ impl FoldSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn fold_point_cursor(&self) -> FoldPointCursor<'_> {
let cursor = self
.transforms
@@ -741,10 +757,12 @@ impl FoldSnapshot {
FoldPointCursor { cursor }
}
+ #[ztracing::instrument(skip_all)]
pub fn len(&self) -> FoldOffset {
FoldOffset(self.transforms.summary().output.len)
}
+ #[ztracing::instrument(skip_all)]
pub fn line_len(&self, row: u32) -> u32 {
let line_start = FoldPoint::new(row, 0).to_offset(self).0;
let line_end = if row >= self.max_point().row() {
@@ -755,6 +773,7 @@ impl FoldSnapshot {
(line_end - line_start) as u32
}
+ #[ztracing::instrument(skip_all)]
pub fn row_infos(&self, start_row: u32) -> FoldRows<'_> {
if start_row > self.transforms.summary().output.lines.row {
panic!("invalid display row {}", start_row);
@@ -777,6 +796,7 @@ impl FoldSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn max_point(&self) -> FoldPoint {
FoldPoint(self.transforms.summary().output.lines)
}
@@ -786,6 +806,7 @@ impl FoldSnapshot {
self.transforms.summary().output.longest_row
}
+ #[ztracing::instrument(skip_all)]
pub fn folds_in_range<T>(&self, range: Range<T>) -> impl Iterator<Item = &Fold>
where
T: ToOffset,
@@ -800,6 +821,7 @@ impl FoldSnapshot {
})
}
+ #[ztracing::instrument(skip_all)]
pub fn intersects_fold<T>(&self, offset: T) -> bool
where
T: ToOffset,
@@ -812,6 +834,7 @@ impl FoldSnapshot {
item.is_some_and(|t| t.placeholder.is_some())
}
+ #[ztracing::instrument(skip_all)]
pub fn is_line_folded(&self, buffer_row: MultiBufferRow) -> bool {
let mut inlay_point = self
.inlay_snapshot
@@ -840,6 +863,7 @@ impl FoldSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub(crate) fn chunks<'a>(
&'a self,
range: Range<FoldOffset>,
@@ -884,6 +908,7 @@ impl FoldSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn chars_at(&self, start: FoldPoint) -> impl '_ + Iterator<Item = char> {
self.chunks(
start.to_offset(self)..self.len(),
@@ -893,6 +918,7 @@ impl FoldSnapshot {
.flat_map(|chunk| chunk.text.chars())
}
+ #[ztracing::instrument(skip_all)]
pub fn chunks_at(&self, start: FoldPoint) -> FoldChunks<'_> {
self.chunks(
start.to_offset(self)..self.len(),
@@ -902,6 +928,7 @@ impl FoldSnapshot {
}
#[cfg(test)]
+ #[ztracing::instrument(skip_all)]
pub fn clip_offset(&self, offset: FoldOffset, bias: Bias) -> FoldOffset {
if offset > self.len() {
self.len()
@@ -910,6 +937,7 @@ impl FoldSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn clip_point(&self, point: FoldPoint, bias: Bias) -> FoldPoint {
let (start, end, item) = self
.transforms
@@ -939,6 +967,7 @@ pub struct FoldPointCursor<'transforms> {
}
impl FoldPointCursor<'_> {
+ #[ztracing::instrument(skip_all)]
pub fn map(&mut self, point: InlayPoint, bias: Bias) -> FoldPoint {
let cursor = &mut self.cursor;
if cursor.did_seek() {
@@ -1267,6 +1296,7 @@ pub struct FoldRows<'a> {
}
impl FoldRows<'_> {
+ #[ztracing::instrument(skip_all)]
pub(crate) fn seek(&mut self, row: u32) {
let fold_point = FoldPoint::new(row, 0);
self.cursor.seek(&fold_point, Bias::Left);
@@ -1280,6 +1310,7 @@ impl FoldRows<'_> {
impl Iterator for FoldRows<'_> {
type Item = RowInfo;
+ #[ztracing::instrument(skip_all)]
fn next(&mut self) -> Option<Self::Item> {
let mut traversed_fold = false;
while self.fold_point > self.cursor.end().0 {
@@ -1391,6 +1422,7 @@ pub struct FoldChunks<'a> {
}
impl FoldChunks<'_> {
+ #[ztracing::instrument(skip_all)]
pub(crate) fn seek(&mut self, range: Range<FoldOffset>) {
self.transform_cursor.seek(&range.start, Bias::Right);
@@ -1425,6 +1457,7 @@ impl FoldChunks<'_> {
impl<'a> Iterator for FoldChunks<'a> {
type Item = Chunk<'a>;
+ #[ztracing::instrument(skip_all)]
fn next(&mut self) -> Option<Self::Item> {
if self.output_offset >= self.max_output_offset {
return None;
@@ -1524,6 +1557,7 @@ impl<'a> Iterator for FoldChunks<'a> {
pub struct FoldOffset(pub MultiBufferOffset);
impl FoldOffset {
+ #[ztracing::instrument(skip_all)]
pub fn to_point(self, snapshot: &FoldSnapshot) -> FoldPoint {
let (start, _, item) = snapshot
.transforms
@@ -1539,6 +1573,7 @@ impl FoldOffset {
}
#[cfg(test)]
+ #[ztracing::instrument(skip_all)]
pub fn to_inlay_offset(self, snapshot: &FoldSnapshot) -> InlayOffset {
let (start, _, _) = snapshot
.transforms
@@ -52,6 +52,7 @@ enum Transform {
impl sum_tree::Item for Transform {
type Summary = TransformSummary;
+ #[ztracing::instrument(skip_all)]
fn summary(&self, _: ()) -> Self::Summary {
match self {
Transform::Isomorphic(summary) => TransformSummary {
@@ -228,6 +229,7 @@ pub struct InlayChunk<'a> {
}
impl InlayChunks<'_> {
+ #[ztracing::instrument(skip_all)]
pub fn seek(&mut self, new_range: Range<InlayOffset>) {
self.transforms.seek(&new_range.start, Bias::Right);
@@ -248,6 +250,7 @@ impl InlayChunks<'_> {
impl<'a> Iterator for InlayChunks<'a> {
type Item = InlayChunk<'a>;
+ #[ztracing::instrument(skip_all)]
fn next(&mut self) -> Option<Self::Item> {
if self.output_offset == self.max_output_offset {
return None;
@@ -441,6 +444,7 @@ impl<'a> Iterator for InlayChunks<'a> {
}
impl InlayBufferRows<'_> {
+ #[ztracing::instrument(skip_all)]
pub fn seek(&mut self, row: u32) {
let inlay_point = InlayPoint::new(row, 0);
self.transforms.seek(&inlay_point, Bias::Left);
@@ -465,6 +469,7 @@ impl InlayBufferRows<'_> {
impl Iterator for InlayBufferRows<'_> {
type Item = RowInfo;
+ #[ztracing::instrument(skip_all)]
fn next(&mut self) -> Option<Self::Item> {
let buffer_row = if self.inlay_row == 0 {
self.buffer_rows.next().unwrap()
@@ -494,6 +499,7 @@ impl InlayPoint {
}
impl InlayMap {
+ #[ztracing::instrument(skip_all)]
pub fn new(buffer: MultiBufferSnapshot) -> (Self, InlaySnapshot) {
let version = 0;
let snapshot = InlaySnapshot {
@@ -511,6 +517,7 @@ impl InlayMap {
)
}
+ #[ztracing::instrument(skip_all)]
pub fn sync(
&mut self,
buffer_snapshot: MultiBufferSnapshot,
@@ -643,6 +650,7 @@ impl InlayMap {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn splice(
&mut self,
to_remove: &[InlayId],
@@ -693,11 +701,13 @@ impl InlayMap {
(snapshot, edits)
}
+ #[ztracing::instrument(skip_all)]
pub fn current_inlays(&self) -> impl Iterator<Item = &Inlay> {
self.inlays.iter()
}
#[cfg(test)]
+ #[ztracing::instrument(skip_all)]
pub(crate) fn randomly_mutate(
&mut self,
next_inlay_id: &mut usize,
@@ -766,6 +776,7 @@ impl InlayMap {
}
impl InlaySnapshot {
+ #[ztracing::instrument(skip_all)]
pub fn to_point(&self, offset: InlayOffset) -> InlayPoint {
let (start, _, item) = self.transforms.find::<Dimensions<
InlayOffset,
@@ -789,14 +800,17 @@ impl InlaySnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn len(&self) -> InlayOffset {
InlayOffset(self.transforms.summary().output.len)
}
+ #[ztracing::instrument(skip_all)]
pub fn max_point(&self) -> InlayPoint {
InlayPoint(self.transforms.summary().output.lines)
}
+ #[ztracing::instrument(skip_all, fields(point))]
pub fn to_offset(&self, point: InlayPoint) -> InlayOffset {
let (start, _, item) = self
.transforms
@@ -817,6 +831,7 @@ impl InlaySnapshot {
None => self.len(),
}
}
+ #[ztracing::instrument(skip_all)]
pub fn to_buffer_point(&self, point: InlayPoint) -> Point {
let (start, _, item) =
self.transforms
@@ -830,6 +845,7 @@ impl InlaySnapshot {
None => self.buffer.max_point(),
}
}
+ #[ztracing::instrument(skip_all)]
pub fn to_buffer_offset(&self, offset: InlayOffset) -> MultiBufferOffset {
let (start, _, item) = self
.transforms
@@ -844,6 +860,7 @@ impl InlaySnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn to_inlay_offset(&self, offset: MultiBufferOffset) -> InlayOffset {
let mut cursor = self
.transforms
@@ -880,10 +897,12 @@ impl InlaySnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn to_inlay_point(&self, point: Point) -> InlayPoint {
self.inlay_point_cursor().map(point)
}
+ #[ztracing::instrument(skip_all)]
pub fn inlay_point_cursor(&self) -> InlayPointCursor<'_> {
let cursor = self.transforms.cursor::<Dimensions<Point, InlayPoint>>(());
InlayPointCursor {
@@ -892,6 +911,7 @@ impl InlaySnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn clip_point(&self, mut point: InlayPoint, mut bias: Bias) -> InlayPoint {
let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(());
cursor.seek(&point, Bias::Left);
@@ -983,10 +1003,12 @@ impl InlaySnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn text_summary(&self) -> MBTextSummary {
self.transforms.summary().output
}
+ #[ztracing::instrument(skip_all)]
pub fn text_summary_for_range(&self, range: Range<InlayOffset>) -> MBTextSummary {
let mut summary = MBTextSummary::default();
@@ -1044,6 +1066,7 @@ impl InlaySnapshot {
summary
}
+ #[ztracing::instrument(skip_all)]
pub fn row_infos(&self, row: u32) -> InlayBufferRows<'_> {
let mut cursor = self.transforms.cursor::<Dimensions<InlayPoint, Point>>(());
let inlay_point = InlayPoint::new(row, 0);
@@ -1071,6 +1094,7 @@ impl InlaySnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn line_len(&self, row: u32) -> u32 {
let line_start = self.to_offset(InlayPoint::new(row, 0)).0;
let line_end = if row >= self.max_point().row() {
@@ -1081,6 +1105,7 @@ impl InlaySnapshot {
(line_end - line_start) as u32
}
+ #[ztracing::instrument(skip_all)]
pub(crate) fn chunks<'a>(
&'a self,
range: Range<InlayOffset>,
@@ -1115,12 +1140,14 @@ impl InlaySnapshot {
}
#[cfg(test)]
+ #[ztracing::instrument(skip_all)]
pub fn text(&self) -> String {
self.chunks(Default::default()..self.len(), false, Highlights::default())
.map(|chunk| chunk.chunk.text)
.collect()
}
+ #[ztracing::instrument(skip_all)]
fn check_invariants(&self) {
#[cfg(any(debug_assertions, feature = "test-support"))]
{
@@ -1147,6 +1174,7 @@ pub struct InlayPointCursor<'transforms> {
}
impl InlayPointCursor<'_> {
+ #[ztracing::instrument(skip_all)]
pub fn map(&mut self, point: Point) -> InlayPoint {
let cursor = &mut self.cursor;
if cursor.did_seek() {
@@ -30,6 +30,7 @@
// ref: https://gist.github.com/ConradIrwin/f759e1fc29267143c4c7895aa495dca5?h=1
// ref: https://unicode.org/Public/emoji/13.0/emoji-test.txt
// https://github.com/bits/UTF-8-Unicode-Test-Documents/blob/master/UTF-8_sequence_separated/utf8_sequence_0-0x10ffff_assigned_including-unprintable-asis.txt
+#[ztracing::instrument(skip_all)]
pub fn is_invisible(c: char) -> bool {
if c <= '\u{1f}' {
c != '\t' && c != '\n' && c != '\r'
@@ -20,6 +20,7 @@ const MAX_TABS: NonZeroU32 = NonZeroU32::new(SPACES.len() as u32).unwrap();
pub struct TabMap(TabSnapshot);
impl TabMap {
+ #[ztracing::instrument(skip_all)]
pub fn new(fold_snapshot: FoldSnapshot, tab_size: NonZeroU32) -> (Self, TabSnapshot) {
let snapshot = TabSnapshot {
fold_snapshot,
@@ -36,6 +37,7 @@ impl TabMap {
self.0.clone()
}
+ #[ztracing::instrument(skip_all)]
pub fn sync(
&mut self,
fold_snapshot: FoldSnapshot,
@@ -176,10 +178,12 @@ impl std::ops::Deref for TabSnapshot {
}
impl TabSnapshot {
+ #[ztracing::instrument(skip_all)]
pub fn buffer_snapshot(&self) -> &MultiBufferSnapshot {
&self.fold_snapshot.inlay_snapshot.buffer
}
+ #[ztracing::instrument(skip_all)]
pub fn line_len(&self, row: u32) -> u32 {
let max_point = self.max_point();
if row < max_point.row() {
@@ -191,10 +195,12 @@ impl TabSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn text_summary(&self) -> TextSummary {
self.text_summary_for_range(TabPoint::zero()..self.max_point())
}
+ #[ztracing::instrument(skip_all, fields(rows))]
pub fn text_summary_for_range(&self, range: Range<TabPoint>) -> TextSummary {
let input_start = self.tab_point_to_fold_point(range.start, Bias::Left).0;
let input_end = self.tab_point_to_fold_point(range.end, Bias::Right).0;
@@ -234,6 +240,7 @@ impl TabSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub(crate) fn chunks<'a>(
&'a self,
range: Range<TabPoint>,
@@ -276,11 +283,13 @@ impl TabSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn rows(&self, row: u32) -> fold_map::FoldRows<'_> {
self.fold_snapshot.row_infos(row)
}
#[cfg(test)]
+ #[ztracing::instrument(skip_all)]
pub fn text(&self) -> String {
self.chunks(
TabPoint::zero()..self.max_point(),
@@ -291,10 +300,12 @@ impl TabSnapshot {
.collect()
}
+ #[ztracing::instrument(skip_all)]
pub fn max_point(&self) -> TabPoint {
self.fold_point_to_tab_point(self.fold_snapshot.max_point())
}
+ #[ztracing::instrument(skip_all)]
pub fn clip_point(&self, point: TabPoint, bias: Bias) -> TabPoint {
self.fold_point_to_tab_point(
self.fold_snapshot
@@ -302,6 +313,7 @@ impl TabSnapshot {
)
}
+ #[ztracing::instrument(skip_all)]
pub fn fold_point_to_tab_point(&self, input: FoldPoint) -> TabPoint {
let chunks = self.fold_snapshot.chunks_at(FoldPoint::new(input.row(), 0));
let tab_cursor = TabStopCursor::new(chunks);
@@ -309,10 +321,12 @@ impl TabSnapshot {
TabPoint::new(input.row(), expanded)
}
+ #[ztracing::instrument(skip_all)]
pub fn tab_point_cursor(&self) -> TabPointCursor<'_> {
TabPointCursor { this: self }
}
+ #[ztracing::instrument(skip_all)]
pub fn tab_point_to_fold_point(&self, output: TabPoint, bias: Bias) -> (FoldPoint, u32, u32) {
let chunks = self
.fold_snapshot
@@ -330,12 +344,14 @@ impl TabSnapshot {
)
}
+ #[ztracing::instrument(skip_all)]
pub fn point_to_tab_point(&self, point: Point, bias: Bias) -> TabPoint {
let inlay_point = self.fold_snapshot.inlay_snapshot.to_inlay_point(point);
let fold_point = self.fold_snapshot.to_fold_point(inlay_point, bias);
self.fold_point_to_tab_point(fold_point)
}
+ #[ztracing::instrument(skip_all)]
pub fn tab_point_to_point(&self, point: TabPoint, bias: Bias) -> Point {
let fold_point = self.tab_point_to_fold_point(point, bias).0;
let inlay_point = fold_point.to_inlay_point(&self.fold_snapshot);
@@ -344,6 +360,7 @@ impl TabSnapshot {
.to_buffer_point(inlay_point)
}
+ #[ztracing::instrument(skip_all)]
fn expand_tabs<'a, I>(&self, mut cursor: TabStopCursor<'a, I>, column: u32) -> u32
where
I: Iterator<Item = Chunk<'a>>,
@@ -377,6 +394,7 @@ impl TabSnapshot {
expanded_bytes + column.saturating_sub(collapsed_bytes)
}
+ #[ztracing::instrument(skip_all)]
fn collapse_tabs<'a, I>(
&self,
mut cursor: TabStopCursor<'a, I>,
@@ -442,6 +460,7 @@ pub struct TabPointCursor<'this> {
}
impl TabPointCursor<'_> {
+ #[ztracing::instrument(skip_all)]
pub fn map(&mut self, point: FoldPoint) -> TabPoint {
self.this.fold_point_to_tab_point(point)
}
@@ -486,6 +505,7 @@ pub struct TextSummary {
}
impl<'a> From<&'a str> for TextSummary {
+ #[ztracing::instrument(skip_all)]
fn from(text: &'a str) -> Self {
let sum = text::TextSummary::from(text);
@@ -500,6 +520,7 @@ impl<'a> From<&'a str> for TextSummary {
}
impl<'a> std::ops::AddAssign<&'a Self> for TextSummary {
+ #[ztracing::instrument(skip_all)]
fn add_assign(&mut self, other: &'a Self) {
let joined_chars = self.last_line_chars + other.first_line_chars;
if joined_chars > self.longest_row_chars {
@@ -541,6 +562,7 @@ pub struct TabChunks<'a> {
}
impl TabChunks<'_> {
+ #[ztracing::instrument(skip_all)]
pub(crate) fn seek(&mut self, range: Range<TabPoint>) {
let (input_start, expanded_char_column, to_next_stop) = self
.snapshot
@@ -576,6 +598,7 @@ impl TabChunks<'_> {
impl<'a> Iterator for TabChunks<'a> {
type Item = Chunk<'a>;
+ #[ztracing::instrument(skip_all)]
fn next(&mut self) -> Option<Self::Item> {
if self.chunk.text.is_empty() {
if let Some(chunk) = self.fold_chunks.next() {
@@ -1452,6 +1475,7 @@ impl<'a, I> TabStopCursor<'a, I>
where
I: Iterator<Item = Chunk<'a>>,
{
+ #[ztracing::instrument(skip_all)]
fn new(chunks: impl IntoIterator<Item = Chunk<'a>, IntoIter = I>) -> Self {
Self {
chunks: chunks.into_iter(),
@@ -1461,6 +1485,7 @@ where
}
}
+ #[ztracing::instrument(skip_all)]
fn bytes_until_next_char(&self) -> Option<usize> {
self.current_chunk.as_ref().and_then(|(chunk, idx)| {
let mut idx = *idx;
@@ -1482,6 +1507,7 @@ where
})
}
+ #[ztracing::instrument(skip_all)]
fn is_char_boundary(&self) -> bool {
self.current_chunk
.as_ref()
@@ -1489,6 +1515,7 @@ where
}
/// distance: length to move forward while searching for the next tab stop
+ #[ztracing::instrument(skip_all)]
fn seek(&mut self, distance: u32) -> Option<TabStop> {
if distance == 0 {
return None;
@@ -86,6 +86,7 @@ pub struct WrapRows<'a> {
}
impl WrapRows<'_> {
+ #[ztracing::instrument(skip_all)]
pub(crate) fn seek(&mut self, start_row: WrapRow) {
self.transforms
.seek(&WrapPoint::new(start_row, 0), Bias::Left);
@@ -101,6 +102,7 @@ impl WrapRows<'_> {
}
impl WrapMap {
+ #[ztracing::instrument(skip_all)]
pub fn new(
tab_snapshot: TabSnapshot,
font: Font,
@@ -131,6 +133,7 @@ impl WrapMap {
self.background_task.is_some()
}
+ #[ztracing::instrument(skip_all)]
pub fn sync(
&mut self,
tab_snapshot: TabSnapshot,
@@ -150,6 +153,7 @@ impl WrapMap {
(self.snapshot.clone(), mem::take(&mut self.edits_since_sync))
}
+ #[ztracing::instrument(skip_all)]
pub fn set_font_with_size(
&mut self,
font: Font,
@@ -167,6 +171,7 @@ impl WrapMap {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn set_wrap_width(&mut self, wrap_width: Option<Pixels>, cx: &mut Context<Self>) -> bool {
if wrap_width == self.wrap_width {
return false;
@@ -177,6 +182,7 @@ impl WrapMap {
true
}
+ #[ztracing::instrument(skip_all)]
fn rewrap(&mut self, cx: &mut Context<Self>) {
self.background_task.take();
self.interpolated_edits.clear();
@@ -248,6 +254,7 @@ impl WrapMap {
}
}
+ #[ztracing::instrument(skip_all)]
fn flush_edits(&mut self, cx: &mut Context<Self>) {
if !self.snapshot.interpolated {
let mut to_remove_len = 0;
@@ -330,6 +337,7 @@ impl WrapMap {
}
impl WrapSnapshot {
+ #[ztracing::instrument(skip_all)]
fn new(tab_snapshot: TabSnapshot) -> Self {
let mut transforms = SumTree::default();
let extent = tab_snapshot.text_summary();
@@ -343,10 +351,12 @@ impl WrapSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn buffer_snapshot(&self) -> &MultiBufferSnapshot {
self.tab_snapshot.buffer_snapshot()
}
+ #[ztracing::instrument(skip_all)]
fn interpolate(&mut self, new_tab_snapshot: TabSnapshot, tab_edits: &[TabEdit]) -> WrapPatch {
let mut new_transforms;
if tab_edits.is_empty() {
@@ -411,6 +421,7 @@ impl WrapSnapshot {
old_snapshot.compute_edits(tab_edits, self)
}
+ #[ztracing::instrument(skip_all)]
async fn update(
&mut self,
new_tab_snapshot: TabSnapshot,
@@ -570,6 +581,7 @@ impl WrapSnapshot {
old_snapshot.compute_edits(tab_edits, self)
}
+ #[ztracing::instrument(skip_all)]
fn compute_edits(&self, tab_edits: &[TabEdit], new_snapshot: &WrapSnapshot) -> WrapPatch {
let mut wrap_edits = Vec::with_capacity(tab_edits.len());
let mut old_cursor = self.transforms.cursor::<TransformSummary>(());
@@ -606,6 +618,7 @@ impl WrapSnapshot {
Patch::new(wrap_edits)
}
+ #[ztracing::instrument(skip_all)]
pub(crate) fn chunks<'a>(
&'a self,
rows: Range<WrapRow>,
@@ -640,10 +653,12 @@ impl WrapSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn max_point(&self) -> WrapPoint {
WrapPoint(self.transforms.summary().output.lines)
}
+ #[ztracing::instrument(skip_all)]
pub fn line_len(&self, row: WrapRow) -> u32 {
let (start, _, item) = self.transforms.find::<Dimensions<WrapPoint, TabPoint>, _>(
(),
@@ -664,6 +679,7 @@ impl WrapSnapshot {
}
}
+ #[ztracing::instrument(skip_all, fields(rows))]
pub fn text_summary_for_range(&self, rows: Range<WrapRow>) -> TextSummary {
let mut summary = TextSummary::default();
@@ -725,6 +741,7 @@ impl WrapSnapshot {
summary
}
+ #[ztracing::instrument(skip_all)]
pub fn soft_wrap_indent(&self, row: WrapRow) -> Option<u32> {
let (.., item) = self.transforms.find::<WrapPoint, _>(
(),
@@ -740,10 +757,12 @@ impl WrapSnapshot {
})
}
+ #[ztracing::instrument(skip_all)]
pub fn longest_row(&self) -> u32 {
self.transforms.summary().output.longest_row
}
+ #[ztracing::instrument(skip_all)]
pub fn row_infos(&self, start_row: WrapRow) -> WrapRows<'_> {
let mut transforms = self
.transforms
@@ -766,6 +785,7 @@ impl WrapSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn to_tab_point(&self, point: WrapPoint) -> TabPoint {
let (start, _, item) =
self.transforms
@@ -777,15 +797,18 @@ impl WrapSnapshot {
TabPoint(tab_point)
}
+ #[ztracing::instrument(skip_all)]
pub fn to_point(&self, point: WrapPoint, bias: Bias) -> Point {
self.tab_snapshot
.tab_point_to_point(self.to_tab_point(point), bias)
}
+ #[ztracing::instrument(skip_all)]
pub fn make_wrap_point(&self, point: Point, bias: Bias) -> WrapPoint {
self.tab_point_to_wrap_point(self.tab_snapshot.point_to_tab_point(point, bias))
}
+ #[ztracing::instrument(skip_all)]
pub fn tab_point_to_wrap_point(&self, point: TabPoint) -> WrapPoint {
let (start, ..) =
self.transforms
@@ -793,6 +816,7 @@ impl WrapSnapshot {
WrapPoint(start.1.0 + (point.0 - start.0.0))
}
+ #[ztracing::instrument(skip_all)]
pub fn wrap_point_cursor(&self) -> WrapPointCursor<'_> {
WrapPointCursor {
cursor: self
@@ -801,6 +825,7 @@ impl WrapSnapshot {
}
}
+ #[ztracing::instrument(skip_all)]
pub fn clip_point(&self, mut point: WrapPoint, bias: Bias) -> WrapPoint {
if bias == Bias::Left {
let (start, _, item) = self
@@ -815,6 +840,7 @@ impl WrapSnapshot {
self.tab_point_to_wrap_point(self.tab_snapshot.clip_point(self.to_tab_point(point), bias))
}
+ #[ztracing::instrument(skip_all, fields(point, ret))]
pub fn prev_row_boundary(&self, mut point: WrapPoint) -> WrapRow {
if self.transforms.is_empty() {
return WrapRow(0);
@@ -841,6 +867,7 @@ impl WrapSnapshot {
unreachable!()
}
+ #[ztracing::instrument(skip_all)]
pub fn next_row_boundary(&self, mut point: WrapPoint) -> Option<WrapRow> {
point.0 += Point::new(1, 0);
@@ -860,11 +887,13 @@ impl WrapSnapshot {
}
#[cfg(test)]
+ #[ztracing::instrument(skip_all)]
pub fn text(&self) -> String {
self.text_chunks(WrapRow(0)).collect()
}
#[cfg(test)]
+ #[ztracing::instrument(skip_all)]
pub fn text_chunks(&self, wrap_row: WrapRow) -> impl Iterator<Item = &str> {
self.chunks(
wrap_row..self.max_point().row() + WrapRow(1),
@@ -874,6 +903,7 @@ impl WrapSnapshot {
.map(|h| h.text)
}
+ #[ztracing::instrument(skip_all)]
fn check_invariants(&self) {
#[cfg(test)]
{
@@ -927,6 +957,7 @@ pub struct WrapPointCursor<'transforms> {
}
impl WrapPointCursor<'_> {
+ #[ztracing::instrument(skip_all)]
pub fn map(&mut self, point: TabPoint) -> WrapPoint {
let cursor = &mut self.cursor;
if cursor.did_seek() {
@@ -939,6 +970,7 @@ impl WrapPointCursor<'_> {
}
impl WrapChunks<'_> {
+ #[ztracing::instrument(skip_all)]
pub(crate) fn seek(&mut self, rows: Range<WrapRow>) {
let output_start = WrapPoint::new(rows.start, 0);
let output_end = WrapPoint::new(rows.end, 0);
@@ -961,6 +993,7 @@ impl WrapChunks<'_> {
impl<'a> Iterator for WrapChunks<'a> {
type Item = Chunk<'a>;
+ #[ztracing::instrument(skip_all)]
fn next(&mut self) -> Option<Self::Item> {
if self.output_position.row() >= self.max_output_row {
return None;
@@ -1033,6 +1066,7 @@ impl<'a> Iterator for WrapChunks<'a> {
impl Iterator for WrapRows<'_> {
type Item = RowInfo;
+ #[ztracing::instrument(skip_all)]
fn next(&mut self) -> Option<Self::Item> {
if self.output_row > self.max_output_row {
return None;
@@ -1069,6 +1103,7 @@ impl Iterator for WrapRows<'_> {
}
impl Transform {
+ #[ztracing::instrument(skip_all)]
fn isomorphic(summary: TextSummary) -> Self {
#[cfg(test)]
assert!(!summary.lines.is_zero());
@@ -1082,6 +1117,7 @@ impl Transform {
}
}
+ #[ztracing::instrument(skip_all)]
fn wrap(indent: u32) -> Self {
static WRAP_TEXT: LazyLock<String> = LazyLock::new(|| {
let mut wrap_text = String::new();
@@ -1134,6 +1170,7 @@ trait SumTreeExt {
}
impl SumTreeExt for SumTree<Transform> {
+ #[ztracing::instrument(skip_all)]
fn push_or_extend(&mut self, transform: Transform) {
let mut transform = Some(transform);
self.update_last(
@@ -1197,6 +1234,7 @@ impl<'a> sum_tree::Dimension<'a, TransformSummary> for TabPoint {
}
impl sum_tree::SeekTarget<'_, TransformSummary, TransformSummary> for TabPoint {
+ #[ztracing::instrument(skip_all)]
fn cmp(&self, cursor_location: &TransformSummary, _: ()) -> std::cmp::Ordering {
Ord::cmp(&self.0, &cursor_location.input.lines)
}
@@ -1,4 +1,4 @@
-use edit_prediction::EditPredictionProvider;
+use edit_prediction_types::EditPredictionDelegate;
use gpui::{Entity, KeyBinding, Modifiers, prelude::*};
use indoc::indoc;
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
@@ -15,7 +15,7 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
assign_editor_completion_provider(provider.clone(), &mut cx);
cx.set_state("let absolute_zero_celsius = ห;");
@@ -37,7 +37,7 @@ async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
assign_editor_completion_provider(provider.clone(), &mut cx);
cx.set_state("let pi = ห\"foo\";");
@@ -59,7 +59,7 @@ async fn test_edit_prediction_jump_button(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
assign_editor_completion_provider(provider.clone(), &mut cx);
// Cursor is 2+ lines above the proposed edit
@@ -128,7 +128,7 @@ async fn test_edit_prediction_invalidation_range(cx: &mut gpui::TestAppContext)
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
assign_editor_completion_provider(provider.clone(), &mut cx);
// Cursor is 3+ lines above the proposed edit
@@ -233,7 +233,7 @@ async fn test_edit_prediction_jump_disabled_for_non_zed_providers(cx: &mut gpui:
init_test(cx, |_| {});
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeNonZedEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeNonZedEditPredictionDelegate::default());
assign_editor_completion_provider_non_zed(provider.clone(), &mut cx);
// Cursor is 2+ lines above the proposed edit
@@ -281,7 +281,7 @@ async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestA
cx.update(|cx| cx.bind_keys([KeyBinding::new("ctrl-shift-a", AcceptEditPrediction, None)]));
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
assign_editor_completion_provider(provider.clone(), &mut cx);
cx.set_state("let x = ห;");
@@ -371,7 +371,7 @@ fn accept_completion(cx: &mut EditorTestContext) {
}
fn propose_edits<T: ToOffset>(
- provider: &Entity<FakeEditPredictionProvider>,
+ provider: &Entity<FakeEditPredictionDelegate>,
edits: Vec<(Range<T>, &str)>,
cx: &mut EditorTestContext,
) {
@@ -383,7 +383,7 @@ fn propose_edits<T: ToOffset>(
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
- provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
+ provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
id: None,
edits: edits.collect(),
edit_preview: None,
@@ -393,7 +393,7 @@ fn propose_edits<T: ToOffset>(
}
fn assign_editor_completion_provider(
- provider: Entity<FakeEditPredictionProvider>,
+ provider: Entity<FakeEditPredictionDelegate>,
cx: &mut EditorTestContext,
) {
cx.update_editor(|editor, window, cx| {
@@ -402,7 +402,7 @@ fn assign_editor_completion_provider(
}
fn propose_edits_non_zed<T: ToOffset>(
- provider: &Entity<FakeNonZedEditPredictionProvider>,
+ provider: &Entity<FakeNonZedEditPredictionDelegate>,
edits: Vec<(Range<T>, &str)>,
cx: &mut EditorTestContext,
) {
@@ -414,7 +414,7 @@ fn propose_edits_non_zed<T: ToOffset>(
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
- provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
+ provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
id: None,
edits: edits.collect(),
edit_preview: None,
@@ -424,7 +424,7 @@ fn propose_edits_non_zed<T: ToOffset>(
}
fn assign_editor_completion_provider_non_zed(
- provider: Entity<FakeNonZedEditPredictionProvider>,
+ provider: Entity<FakeNonZedEditPredictionDelegate>,
cx: &mut EditorTestContext,
) {
cx.update_editor(|editor, window, cx| {
@@ -433,17 +433,20 @@ fn assign_editor_completion_provider_non_zed(
}
#[derive(Default, Clone)]
-pub struct FakeEditPredictionProvider {
- pub completion: Option<edit_prediction::EditPrediction>,
+pub struct FakeEditPredictionDelegate {
+ pub completion: Option<edit_prediction_types::EditPrediction>,
}
-impl FakeEditPredictionProvider {
- pub fn set_edit_prediction(&mut self, completion: Option<edit_prediction::EditPrediction>) {
+impl FakeEditPredictionDelegate {
+ pub fn set_edit_prediction(
+ &mut self,
+ completion: Option<edit_prediction_types::EditPrediction>,
+ ) {
self.completion = completion;
}
}
-impl EditPredictionProvider for FakeEditPredictionProvider {
+impl EditPredictionDelegate for FakeEditPredictionDelegate {
fn name() -> &'static str {
"fake-completion-provider"
}
@@ -452,7 +455,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
"Fake Completion Provider"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
true
}
@@ -486,7 +489,7 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
&mut self,
_buffer: gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
- _direction: edit_prediction::Direction,
+ _direction: edit_prediction_types::Direction,
_cx: &mut gpui::Context<Self>,
) {
}
@@ -500,23 +503,26 @@ impl EditPredictionProvider for FakeEditPredictionProvider {
_buffer: &gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_cx: &mut gpui::Context<Self>,
- ) -> Option<edit_prediction::EditPrediction> {
+ ) -> Option<edit_prediction_types::EditPrediction> {
self.completion.clone()
}
}
#[derive(Default, Clone)]
-pub struct FakeNonZedEditPredictionProvider {
- pub completion: Option<edit_prediction::EditPrediction>,
+pub struct FakeNonZedEditPredictionDelegate {
+ pub completion: Option<edit_prediction_types::EditPrediction>,
}
-impl FakeNonZedEditPredictionProvider {
- pub fn set_edit_prediction(&mut self, completion: Option<edit_prediction::EditPrediction>) {
+impl FakeNonZedEditPredictionDelegate {
+ pub fn set_edit_prediction(
+ &mut self,
+ completion: Option<edit_prediction_types::EditPrediction>,
+ ) {
self.completion = completion;
}
}
-impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
+impl EditPredictionDelegate for FakeNonZedEditPredictionDelegate {
fn name() -> &'static str {
"fake-non-zed-provider"
}
@@ -525,7 +531,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
"Fake Non-Zed Provider"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
false
}
@@ -559,7 +565,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
&mut self,
_buffer: gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
- _direction: edit_prediction::Direction,
+ _direction: edit_prediction_types::Direction,
_cx: &mut gpui::Context<Self>,
) {
}
@@ -573,7 +579,7 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider {
_buffer: &gpui::Entity<language::Buffer>,
_cursor_position: language::Anchor,
_cx: &mut gpui::Context<Self>,
- ) -> Option<edit_prediction::EditPrediction> {
+ ) -> Option<edit_prediction_types::EditPrediction> {
self.completion.clone()
}
}
@@ -51,7 +51,7 @@ pub mod test;
pub(crate) use actions::*;
pub use display_map::{ChunkRenderer, ChunkRendererContext, DisplayPoint, FoldPlaceholder};
-pub use edit_prediction::Direction;
+pub use edit_prediction_types::Direction;
pub use editor_settings::{
CurrentLineHighlight, DocumentColorsRenderMode, EditorSettings, HideMouseMode,
ScrollBeyondLastLine, ScrollbarAxes, SearchSettings, ShowMinimap,
@@ -92,7 +92,7 @@ use collections::{BTreeMap, HashMap, HashSet, VecDeque};
use convert_case::{Case, Casing};
use dap::TelemetrySpawnLocation;
use display_map::*;
-use edit_prediction::{EditPredictionProvider, EditPredictionProviderHandle};
+use edit_prediction_types::{EditPredictionDelegate, EditPredictionDelegateHandle};
use editor_settings::{GoToDefinitionFallback, Minimap as MinimapSettings};
use element::{AcceptEditPredictionBinding, LineWithInvisibles, PositionMap, layout_line};
use futures::{
@@ -1079,6 +1079,7 @@ pub struct Editor {
show_breakpoints: Option<bool>,
show_wrap_guides: Option<bool>,
show_indent_guides: Option<bool>,
+ buffers_with_disabled_indent_guides: HashSet<BufferId>,
highlight_order: usize,
highlighted_rows: HashMap<TypeId, Vec<RowHighlight>>,
background_highlights: HashMap<HighlightKey, BackgroundHighlight>,
@@ -1119,7 +1120,7 @@ pub struct Editor {
pending_mouse_down: Option<Rc<RefCell<Option<MouseDownEvent>>>>,
gutter_hovered: bool,
hovered_link_state: Option<HoveredLinkState>,
- edit_prediction_provider: Option<RegisteredEditPredictionProvider>,
+ edit_prediction_provider: Option<RegisteredEditPredictionDelegate>,
code_action_providers: Vec<Rc<dyn CodeActionProvider>>,
active_edit_prediction: Option<EditPredictionState>,
/// Used to prevent flickering as the user types while the menu is open
@@ -1127,6 +1128,7 @@ pub struct Editor {
edit_prediction_settings: EditPredictionSettings,
edit_predictions_hidden_for_vim_mode: bool,
show_edit_predictions_override: Option<bool>,
+ show_completions_on_input_override: Option<bool>,
menu_edit_predictions_policy: MenuEditPredictionsPolicy,
edit_prediction_preview: EditPredictionPreview,
edit_prediction_indent_conflict: bool,
@@ -1561,8 +1563,8 @@ pub struct RenameState {
struct InvalidationStack<T>(Vec<T>);
-struct RegisteredEditPredictionProvider {
- provider: Arc<dyn EditPredictionProviderHandle>,
+struct RegisteredEditPredictionDelegate {
+ provider: Arc<dyn EditPredictionDelegateHandle>,
_subscription: Subscription,
}
@@ -1590,6 +1592,45 @@ pub struct ClipboardSelection {
pub is_entire_line: bool,
/// The indentation of the first line when this content was originally copied.
pub first_line_indent: u32,
+ #[serde(default)]
+ pub file_path: Option<PathBuf>,
+ #[serde(default)]
+ pub line_range: Option<RangeInclusive<u32>>,
+}
+
+impl ClipboardSelection {
+ pub fn for_buffer(
+ len: usize,
+ is_entire_line: bool,
+ range: Range<Point>,
+ buffer: &MultiBufferSnapshot,
+ project: Option<&Entity<Project>>,
+ cx: &App,
+ ) -> Self {
+ let first_line_indent = buffer
+ .indent_size_for_line(MultiBufferRow(range.start.row))
+ .len;
+
+ let file_path = util::maybe!({
+ let project = project?.read(cx);
+ let file = buffer.file_at(range.start)?;
+ let project_path = ProjectPath {
+ worktree_id: file.worktree_id(cx),
+ path: file.path().clone(),
+ };
+ project.absolute_path(&project_path, cx)
+ });
+
+ let line_range = file_path.as_ref().map(|_| range.start.row..=range.end.row);
+
+ Self {
+ len,
+ is_entire_line,
+ first_line_indent,
+ file_path,
+ line_range,
+ }
+ }
}
// selections, scroll behavior, was newest selection reversed
@@ -2204,6 +2245,7 @@ impl Editor {
show_breakpoints: None,
show_wrap_guides: None,
show_indent_guides,
+ buffers_with_disabled_indent_guides: HashSet::default(),
highlight_order: 0,
highlighted_rows: HashMap::default(),
background_highlights: HashMap::default(),
@@ -2273,6 +2315,7 @@ impl Editor {
editor_actions: Rc::default(),
edit_predictions_hidden_for_vim_mode: false,
show_edit_predictions_override: None,
+ show_completions_on_input_override: None,
menu_edit_predictions_policy: MenuEditPredictionsPolicy::ByProvider,
edit_prediction_settings: EditPredictionSettings::Disabled,
edit_prediction_indent_conflict: false,
@@ -2986,9 +3029,9 @@ impl Editor {
window: &mut Window,
cx: &mut Context<Self>,
) where
- T: EditPredictionProvider,
+ T: EditPredictionDelegate,
{
- self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionProvider {
+ self.edit_prediction_provider = provider.map(|provider| RegisteredEditPredictionDelegate {
_subscription: cx.observe_in(&provider, window, |this, _, window, cx| {
if this.focus_handle.is_focused(window) {
this.update_visible_edit_prediction(window, cx);
@@ -3155,6 +3198,10 @@ impl Editor {
}
}
+ pub fn set_show_completions_on_input(&mut self, show_completions_on_input: Option<bool>) {
+ self.show_completions_on_input_override = show_completions_on_input;
+ }
+
pub fn set_show_edit_predictions(
&mut self,
show_edit_predictions: Option<bool>,
@@ -5531,7 +5578,10 @@ impl Editor {
let language_settings = language_settings(language.clone(), buffer_snapshot.file(), cx);
let completion_settings = language_settings.completions.clone();
- if !menu_is_open && trigger.is_some() && !language_settings.show_completions_on_input {
+ let show_completions_on_input = self
+ .show_completions_on_input_override
+ .unwrap_or(language_settings.show_completions_on_input);
+ if !menu_is_open && trigger.is_some() && !show_completions_on_input {
return;
}
@@ -6909,7 +6959,11 @@ impl Editor {
}
}
- fn hide_blame_popover(&mut self, ignore_timeout: bool, cx: &mut Context<Self>) -> bool {
+ pub fn has_mouse_context_menu(&self) -> bool {
+ self.mouse_context_menu.is_some()
+ }
+
+ pub fn hide_blame_popover(&mut self, ignore_timeout: bool, cx: &mut Context<Self>) -> bool {
self.inline_blame_popover_show_task.take();
if let Some(state) = &mut self.inline_blame_popover {
let hide_task = cx.spawn(async move |editor, cx| {
@@ -7392,7 +7446,7 @@ impl Editor {
&& self
.edit_prediction_provider
.as_ref()
- .is_some_and(|provider| provider.provider.show_completions_in_menu());
+ .is_some_and(|provider| provider.provider.show_predictions_in_menu());
let preview_requires_modifier =
all_language_settings(file, cx).edit_predictions_mode() == EditPredictionsMode::Subtle;
@@ -8093,12 +8147,12 @@ impl Editor {
let edit_prediction = provider.suggest(&buffer, cursor_buffer_position, cx)?;
let (completion_id, edits, edit_preview) = match edit_prediction {
- edit_prediction::EditPrediction::Local {
+ edit_prediction_types::EditPrediction::Local {
id,
edits,
edit_preview,
} => (id, edits, edit_preview),
- edit_prediction::EditPrediction::Jump {
+ edit_prediction_types::EditPrediction::Jump {
id,
snapshot,
target,
@@ -8239,7 +8293,7 @@ impl Editor {
Some(())
}
- pub fn edit_prediction_provider(&self) -> Option<Arc<dyn EditPredictionProviderHandle>> {
+ pub fn edit_prediction_provider(&self) -> Option<Arc<dyn EditPredictionDelegateHandle>> {
Some(self.edit_prediction_provider.as_ref()?.provider.clone())
}
@@ -9561,7 +9615,7 @@ impl Editor {
editor_bg_color.blend(accent_color.opacity(0.6))
}
fn get_prediction_provider_icon_name(
- provider: &Option<RegisteredEditPredictionProvider>,
+ provider: &Option<RegisteredEditPredictionDelegate>,
) -> IconName {
match provider {
Some(provider) => match provider.provider.name() {
@@ -12797,13 +12851,15 @@ impl Editor {
text.push_str(chunk);
len += chunk.len();
}
- clipboard_selections.push(ClipboardSelection {
+
+ clipboard_selections.push(ClipboardSelection::for_buffer(
len,
is_entire_line,
- first_line_indent: buffer
- .indent_size_for_line(MultiBufferRow(selection.start.row))
- .len,
- });
+ selection.range(),
+ &buffer,
+ self.project.as_ref(),
+ cx,
+ ));
}
}
@@ -12946,13 +13002,14 @@ impl Editor {
text.push('\n');
len += 1;
}
- clipboard_selections.push(ClipboardSelection {
+ clipboard_selections.push(ClipboardSelection::for_buffer(
len,
is_entire_line,
- first_line_indent: buffer
- .indent_size_for_line(MultiBufferRow(trimmed_range.start.row))
- .len,
- });
+ trimmed_range,
+ &buffer,
+ self.project.as_ref(),
+ cx,
+ ));
}
}
}
@@ -16800,7 +16857,7 @@ impl Editor {
GoToDefinitionFallback::None => Ok(Navigated::No),
GoToDefinitionFallback::FindAllReferences => {
match editor.update_in(cx, |editor, window, cx| {
- editor.find_all_references(&FindAllReferences, window, cx)
+ editor.find_all_references(&FindAllReferences::default(), window, cx)
})? {
Some(references) => references.await,
None => Ok(Navigated::No),
@@ -17028,9 +17085,7 @@ impl Editor {
})
.collect();
- let Some(workspace) = self.workspace() else {
- return Task::ready(Ok(Navigated::No));
- };
+ let workspace = self.workspace();
cx.spawn_in(window, async move |editor, cx| {
let locations: Vec<Location> = future::join_all(definitions)
@@ -17085,6 +17140,10 @@ impl Editor {
})
.context("buffer title")?;
+ let Some(workspace) = workspace else {
+ return Ok(Navigated::No);
+ };
+
let opened = workspace
.update_in(cx, |workspace, window, cx| {
let allow_preview = PreviewTabsSettings::get_global(cx)
@@ -17114,6 +17173,9 @@ impl Editor {
// TODO(andrew): respect preview tab settings
// `enable_keep_preview_on_code_navigation` and
// `enable_preview_file_from_code_navigation`
+ let Some(workspace) = workspace else {
+ return Ok(Navigated::No);
+ };
workspace
.update_in(cx, |workspace, window, cx| {
workspace.open_resolved_path(path, window, cx)
@@ -17137,6 +17199,9 @@ impl Editor {
{
editor.go_to_singleton_buffer_range(range, window, cx);
} else {
+ let Some(workspace) = workspace else {
+ return Navigated::No;
+ };
let pane = workspace.read(cx).active_pane().clone();
window.defer(cx, move |window, cx| {
let target_editor: Entity<Self> =
@@ -17348,20 +17413,21 @@ impl Editor {
pub fn find_all_references(
&mut self,
- _: &FindAllReferences,
+ action: &FindAllReferences,
window: &mut Window,
cx: &mut Context<Self>,
) -> Option<Task<Result<Navigated>>> {
- let selection = self
- .selections
- .newest::<MultiBufferOffset>(&self.display_snapshot(cx));
+ let always_open_multibuffer = action.always_open_multibuffer;
+ let selection = self.selections.newest_anchor();
let multi_buffer = self.buffer.read(cx);
- let head = selection.head();
-
let multi_buffer_snapshot = multi_buffer.snapshot(cx);
+ let selection_offset = selection.map(|anchor| anchor.to_offset(&multi_buffer_snapshot));
+ let selection_point = selection.map(|anchor| anchor.to_point(&multi_buffer_snapshot));
+ let head = selection_offset.head();
+
let head_anchor = multi_buffer_snapshot.anchor_at(
head,
- if head < selection.tail() {
+ if head < selection_offset.tail() {
Bias::Right
} else {
Bias::Left
@@ -17407,6 +17473,15 @@ impl Editor {
let buffer = location.buffer.read(cx);
(location.buffer, location.range.to_point(buffer))
})
+ // if special-casing the single-match case, remove ranges
+ // that intersect current selection
+ .filter(|(location_buffer, location)| {
+ if always_open_multibuffer || &buffer != location_buffer {
+ return true;
+ }
+
+ !location.contains_inclusive(&selection_point.range())
+ })
.into_group_map()
})?;
if locations.is_empty() {
@@ -17416,6 +17491,60 @@ impl Editor {
ranges.sort_by_key(|range| (range.start, Reverse(range.end)));
ranges.dedup();
}
+ let mut num_locations = 0;
+ for ranges in locations.values_mut() {
+ ranges.sort_by_key(|range| (range.start, Reverse(range.end)));
+ ranges.dedup();
+ num_locations += ranges.len();
+ }
+
+ if num_locations == 1 && !always_open_multibuffer {
+ let (target_buffer, target_ranges) = locations.into_iter().next().unwrap();
+ let target_range = target_ranges.first().unwrap().clone();
+
+ return editor.update_in(cx, |editor, window, cx| {
+ let range = target_range.to_point(target_buffer.read(cx));
+ let range = editor.range_for_match(&range);
+ let range = range.start..range.start;
+
+ if Some(&target_buffer) == editor.buffer.read(cx).as_singleton().as_ref() {
+ editor.go_to_singleton_buffer_range(range, window, cx);
+ } else {
+ let pane = workspace.read(cx).active_pane().clone();
+ window.defer(cx, move |window, cx| {
+ let target_editor: Entity<Self> =
+ workspace.update(cx, |workspace, cx| {
+ let pane = workspace.active_pane().clone();
+
+ let preview_tabs_settings = PreviewTabsSettings::get_global(cx);
+ let keep_old_preview = preview_tabs_settings
+ .enable_keep_preview_on_code_navigation;
+ let allow_new_preview = preview_tabs_settings
+ .enable_preview_file_from_code_navigation;
+
+ workspace.open_project_item(
+ pane,
+ target_buffer.clone(),
+ true,
+ true,
+ keep_old_preview,
+ allow_new_preview,
+ window,
+ cx,
+ )
+ });
+ target_editor.update(cx, |target_editor, cx| {
+ // When selecting a definition in a different buffer, disable the nav history
+ // to avoid creating a history entry at the previous cursor location.
+ pane.update(cx, |pane, _| pane.disable_history());
+ target_editor.go_to_singleton_buffer_range(range, window, cx);
+ pane.update(cx, |pane, _| pane.enable_history());
+ });
+ });
+ }
+ Navigated::No
+ });
+ }
workspace.update_in(cx, |workspace, window, cx| {
let target = locations
@@ -17453,7 +17582,7 @@ impl Editor {
}))
}
- /// Opens a multibuffer with the given project locations in it
+ /// Opens a multibuffer with the given project locations in it.
pub fn open_locations_in_multibuffer(
workspace: &mut Workspace,
locations: std::collections::HashMap<Entity<Buffer>, Vec<Range<Point>>>,
@@ -20090,6 +20219,20 @@ impl Editor {
self.show_indent_guides
}
+ pub fn disable_indent_guides_for_buffer(
+ &mut self,
+ buffer_id: BufferId,
+ cx: &mut Context<Self>,
+ ) {
+ self.buffers_with_disabled_indent_guides.insert(buffer_id);
+ cx.notify();
+ }
+
+ pub fn has_indent_guides_disabled_for_buffer(&self, buffer_id: BufferId) -> bool {
+ self.buffers_with_disabled_indent_guides
+ .contains(&buffer_id)
+ }
+
pub fn toggle_line_numbers(
&mut self,
_: &ToggleLineNumbers,
@@ -22077,14 +22220,23 @@ impl Editor {
None => Autoscroll::newest(),
};
let nav_history = editor.nav_history.take();
+ let multibuffer_snapshot = editor.buffer().read(cx).snapshot(cx);
+ let Some((&excerpt_id, _, buffer_snapshot)) =
+ multibuffer_snapshot.as_singleton()
+ else {
+ return;
+ };
editor.change_selections(
SelectionEffects::scroll(autoscroll),
window,
cx,
|s| {
s.select_ranges(ranges.into_iter().map(|range| {
- // we checked that the editor is a singleton editor so the offsets are valid
- MultiBufferOffset(range.start.0)..MultiBufferOffset(range.end.0)
+ let range = buffer_snapshot.anchor_before(range.start)
+ ..buffer_snapshot.anchor_after(range.end);
+ multibuffer_snapshot
+ .anchor_range_in_excerpt(excerpt_id, range)
+ .unwrap()
}));
},
);
@@ -2,7 +2,7 @@ use super::*;
use crate::{
JoinLines,
code_context_menus::CodeContextMenu,
- edit_prediction_tests::FakeEditPredictionProvider,
+ edit_prediction_tests::FakeEditPredictionDelegate,
element::StickyHeader,
linked_editing_ranges::LinkedEditingRanges,
scroll::scroll_amount::ScrollAmount,
@@ -8636,7 +8636,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext)
let mut cx = EditorTestContext::new(cx).await;
- let provider = cx.new(|_| FakeEditPredictionProvider::default());
+ let provider = cx.new(|_| FakeEditPredictionDelegate::default());
cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(provider.clone()), window, cx);
});
@@ -8659,7 +8659,7 @@ async fn test_undo_edit_prediction_scrolls_to_edit_pos(cx: &mut TestAppContext)
cx.update(|_, cx| {
provider.update(cx, |provider, _| {
- provider.set_edit_prediction(Some(edit_prediction::EditPrediction::Local {
+ provider.set_edit_prediction(Some(edit_prediction_types::EditPrediction::Local {
id: None,
edits: vec![(edit_position..edit_position, "X".into())],
edit_preview: None,
@@ -9970,7 +9970,7 @@ async fn test_autoindent_disabled_with_nested_language(cx: &mut TestAppContext)
],
..Default::default()
},
- name: LanguageName::new("rust"),
+ name: LanguageName::new_static("rust"),
..Default::default()
},
Some(tree_sitter_rust::LANGUAGE.into()),
@@ -22579,7 +22579,7 @@ async fn test_find_all_references_editor_reuse(cx: &mut TestAppContext) {
});
let navigated = cx
.update_editor(|editor, window, cx| {
- editor.find_all_references(&FindAllReferences, window, cx)
+ editor.find_all_references(&FindAllReferences::default(), window, cx)
})
.unwrap()
.await
@@ -22615,7 +22615,7 @@ async fn test_find_all_references_editor_reuse(cx: &mut TestAppContext) {
);
let navigated = cx
.update_editor(|editor, window, cx| {
- editor.find_all_references(&FindAllReferences, window, cx)
+ editor.find_all_references(&FindAllReferences::default(), window, cx)
})
.unwrap()
.await
@@ -22667,7 +22667,7 @@ async fn test_find_all_references_editor_reuse(cx: &mut TestAppContext) {
});
let navigated = cx
.update_editor(|editor, window, cx| {
- editor.find_all_references(&FindAllReferences, window, cx)
+ editor.find_all_references(&FindAllReferences::default(), window, cx)
})
.unwrap()
.await
@@ -27498,6 +27498,159 @@ async fn test_paste_url_from_other_app_creates_markdown_link_over_selected_text(
));
}
+#[gpui::test]
+async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
+ init_test(cx, |_| {});
+
+ let markdown_language = languages::language("markdown", tree_sitter_md::LANGUAGE.into());
+ let mut cx = EditorTestContext::new(cx).await;
+
+ cx.update_buffer(|buffer, cx| buffer.set_language(Some(markdown_language), cx));
+
+ // Case 1: Test if adding a character with multi cursors preserves nested list indents
+ cx.set_state(&indoc! {"
+ - [ ] Item 1
+ - [ ] Item 1.a
+ - [ห] Item 2
+ - [ห] Item 2.a
+ - [ห] Item 2.b
+ "
+ });
+ cx.update_editor(|editor, window, cx| {
+ editor.handle_input("x", window, cx);
+ });
+ cx.assert_editor_state(indoc! {"
+ - [ ] Item 1
+ - [ ] Item 1.a
+ - [xห] Item 2
+ - [xห] Item 2.a
+ - [xห] Item 2.b
+ "
+ });
+
+ // Case 2: Test adding new line after nested list preserves indent of previous line
+ cx.set_state(&indoc! {"
+ - [ ] Item 1
+ - [ ] Item 1.a
+ - [x] Item 2
+ - [x] Item 2.a
+ - [x] Item 2.bห
+ "
+ });
+ cx.update_editor(|editor, window, cx| {
+ editor.newline(&Newline, window, cx);
+ });
+ cx.assert_editor_state(indoc! {"
+ - [ ] Item 1
+ - [ ] Item 1.a
+ - [x] Item 2
+ - [x] Item 2.a
+ - [x] Item 2.b
+ ห
+ "
+ });
+
+ // Case 3: Test adding a new nested list item preserves indent
+ cx.update_editor(|editor, window, cx| {
+ editor.handle_input("-", window, cx);
+ });
+ cx.assert_editor_state(indoc! {"
+ - [ ] Item 1
+ - [ ] Item 1.a
+ - [x] Item 2
+ - [x] Item 2.a
+ - [x] Item 2.b
+ -ห
+ "
+ });
+ cx.update_editor(|editor, window, cx| {
+ editor.handle_input(" [x] Item 2.c", window, cx);
+ });
+ cx.assert_editor_state(indoc! {"
+ - [ ] Item 1
+ - [ ] Item 1.a
+ - [x] Item 2
+ - [x] Item 2.a
+ - [x] Item 2.b
+ - [x] Item 2.cห
+ "
+ });
+
+ // Case 4: Test adding new line after nested ordered list preserves indent of previous line
+ cx.set_state(indoc! {"
+ 1. Item 1
+ 1. Item 1.a
+ 2. Item 2
+ 1. Item 2.a
+ 2. Item 2.bห
+ "
+ });
+ cx.update_editor(|editor, window, cx| {
+ editor.newline(&Newline, window, cx);
+ });
+ cx.assert_editor_state(indoc! {"
+ 1. Item 1
+ 1. Item 1.a
+ 2. Item 2
+ 1. Item 2.a
+ 2. Item 2.b
+ ห
+ "
+ });
+
+ // Case 5: Adding new ordered list item preserves indent
+ cx.update_editor(|editor, window, cx| {
+ editor.handle_input("3", window, cx);
+ });
+ cx.assert_editor_state(indoc! {"
+ 1. Item 1
+ 1. Item 1.a
+ 2. Item 2
+ 1. Item 2.a
+ 2. Item 2.b
+ 3ห
+ "
+ });
+ cx.update_editor(|editor, window, cx| {
+ editor.handle_input(".", window, cx);
+ });
+ cx.assert_editor_state(indoc! {"
+ 1. Item 1
+ 1. Item 1.a
+ 2. Item 2
+ 1. Item 2.a
+ 2. Item 2.b
+ 3.ห
+ "
+ });
+ cx.update_editor(|editor, window, cx| {
+ editor.handle_input(" Item 2.c", window, cx);
+ });
+ cx.assert_editor_state(indoc! {"
+ 1. Item 1
+ 1. Item 1.a
+ 2. Item 2
+ 1. Item 2.a
+ 2. Item 2.b
+ 3. Item 2.cห
+ "
+ });
+
+ // Case 7: Test blockquote newline preserves something
+ cx.set_state(indoc! {"
+ > Item 1ห
+ "
+ });
+ cx.update_editor(|editor, window, cx| {
+ editor.newline(&Newline, window, cx);
+ });
+ cx.assert_editor_state(indoc! {"
+ > Item 1
+ ห
+ "
+ });
+}
+
#[gpui::test]
async fn test_paste_url_from_zed_copy_creates_markdown_link_over_selected_text(
cx: &mut gpui::TestAppContext,
@@ -28446,33 +28599,32 @@ async fn test_multibuffer_selections_with_folding(cx: &mut TestAppContext) {
3
"});
- // Edge case scenario: fold all buffers, then try to insert
+ // Test correct folded header is selected upon fold
cx.update_editor(|editor, _, cx| {
editor.fold_buffer(buffer_ids[0], cx);
editor.fold_buffer(buffer_ids[1], cx);
});
cx.assert_excerpts_with_selections(indoc! {"
- [EXCERPT]
- ห[FOLDED]
[EXCERPT]
[FOLDED]
+ [EXCERPT]
+ ห[FOLDED]
"});
- // Insert should work via default selection
+ // Test selection inside folded buffer unfolds it on type
cx.update_editor(|editor, window, cx| {
editor.handle_input("W", window, cx);
});
cx.update_editor(|editor, _, cx| {
editor.unfold_buffer(buffer_ids[0], cx);
- editor.unfold_buffer(buffer_ids[1], cx);
});
cx.assert_excerpts_with_selections(indoc! {"
[EXCERPT]
- Wห1
+ 1
2
3
[EXCERPT]
- 1
+ Wห1
Z
3
"});
@@ -28763,3 +28915,65 @@ async fn test_multibuffer_scroll_cursor_top_margin(cx: &mut TestAppContext) {
);
});
}
+
+#[gpui::test]
+async fn test_find_references_single_case(cx: &mut TestAppContext) {
+ init_test(cx, |_| {});
+ let mut cx = EditorLspTestContext::new_rust(
+ lsp::ServerCapabilities {
+ references_provider: Some(lsp::OneOf::Left(true)),
+ ..lsp::ServerCapabilities::default()
+ },
+ cx,
+ )
+ .await;
+
+ let before = indoc!(
+ r#"
+ fn main() {
+ let aหbc = 123;
+ let xyz = abc;
+ }
+ "#
+ );
+ let after = indoc!(
+ r#"
+ fn main() {
+ let abc = 123;
+ let xyz = หabc;
+ }
+ "#
+ );
+
+ cx.lsp
+ .set_request_handler::<lsp::request::References, _, _>(async move |params, _| {
+ Ok(Some(vec![
+ lsp::Location {
+ uri: params.text_document_position.text_document.uri.clone(),
+ range: lsp::Range::new(lsp::Position::new(1, 8), lsp::Position::new(1, 11)),
+ },
+ lsp::Location {
+ uri: params.text_document_position.text_document.uri,
+ range: lsp::Range::new(lsp::Position::new(2, 14), lsp::Position::new(2, 17)),
+ },
+ ]))
+ });
+
+ cx.set_state(before);
+
+ let action = FindAllReferences {
+ always_open_multibuffer: false,
+ };
+
+ let navigated = cx
+ .update_editor(|editor, window, cx| editor.find_all_references(&action, window, cx))
+ .expect("should have spawned a task")
+ .await
+ .unwrap();
+
+ assert_eq!(navigated, Navigated::No);
+
+ cx.run_until_parked();
+
+ cx.assert_editor_state(after);
+}
@@ -2340,7 +2340,7 @@ impl EditorElement {
.opacity(0.05))
.text_color(severity_to_color(&diagnostic_to_render.severity).color(cx))
.text_sm()
- .font_family(style.text.font().family)
+ .font(style.text.font())
.child(diagnostic_to_render.message.clone())
.into_any();
@@ -3915,6 +3915,8 @@ impl EditorElement {
) -> impl IntoElement {
let editor = self.editor.read(cx);
let multi_buffer = editor.buffer.read(cx);
+ let is_read_only = self.editor.read(cx).read_only(cx);
+
let file_status = multi_buffer
.all_diff_hunks_expanded()
.then(|| editor.status_for_buffer_id(for_excerpt.buffer_id, cx))
@@ -3967,7 +3969,7 @@ impl EditorElement {
.gap_1p5()
.when(is_sticky, |el| el.shadow_md())
.border_1()
- .map(|div| {
+ .map(|border| {
let border_color = if is_selected
&& is_folded
&& focus_handle.contains_focused(window, cx)
@@ -3976,7 +3978,7 @@ impl EditorElement {
} else {
colors.border
};
- div.border_color(border_color)
+ border.border_color(border_color)
})
.bg(colors.editor_subheader_background)
.hover(|style| style.bg(colors.element_hover))
@@ -4056,13 +4058,15 @@ impl EditorElement {
})
.take(1),
)
- .child(
- h_flex()
- .size_3()
- .justify_center()
- .flex_shrink_0()
- .children(indicator),
- )
+ .when(!is_read_only, |this| {
+ this.child(
+ h_flex()
+ .size_3()
+ .justify_center()
+ .flex_shrink_0()
+ .children(indicator),
+ )
+ })
.child(
h_flex()
.cursor_pointer()
@@ -508,7 +508,19 @@ impl GitBlame {
let buffer_edits = buffer.update(cx, |buffer, _| buffer.subscribe());
let blame_buffer = project.blame_buffer(&buffer, None, cx);
- Some(async move { (id, snapshot, buffer_edits, blame_buffer.await) })
+ let remote_url = project
+ .git_store()
+ .read(cx)
+ .repository_and_path_for_buffer_id(buffer.read(cx).remote_id(), cx)
+ .and_then(|(repo, _)| {
+ repo.read(cx)
+ .remote_upstream_url
+ .clone()
+ .or(repo.read(cx).remote_origin_url.clone())
+ });
+ Some(
+ async move { (id, snapshot, buffer_edits, blame_buffer.await, remote_url) },
+ )
})
.collect::<Vec<_>>()
});
@@ -524,13 +536,9 @@ impl GitBlame {
.await;
let mut res = vec![];
let mut errors = vec![];
- for (id, snapshot, buffer_edits, blame) in blame {
+ for (id, snapshot, buffer_edits, blame, remote_url) in blame {
match blame {
- Ok(Some(Blame {
- entries,
- messages,
- remote_url,
- })) => {
+ Ok(Some(Blame { entries, messages })) => {
let entries = build_blame_entry_sum_tree(
entries,
snapshot.max_point().row,
@@ -168,7 +168,7 @@ impl Editor {
match EditorSettings::get_global(cx).go_to_definition_fallback {
GoToDefinitionFallback::None => None,
GoToDefinitionFallback::FindAllReferences => {
- editor.find_all_references(&FindAllReferences, window, cx)
+ editor.find_all_references(&FindAllReferences::default(), window, cx)
}
}
})
@@ -607,13 +607,16 @@ async fn parse_blocks(
pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let settings = ThemeSettings::get_global(cx);
let ui_font_family = settings.ui_font.family.clone();
+ let ui_font_features = settings.ui_font.features.clone();
let ui_font_fallbacks = settings.ui_font.fallbacks.clone();
let buffer_font_family = settings.buffer_font.family.clone();
+ let buffer_font_features = settings.buffer_font.features.clone();
let buffer_font_fallbacks = settings.buffer_font.fallbacks.clone();
let mut base_text_style = window.text_style();
base_text_style.refine(&TextStyleRefinement {
font_family: Some(ui_font_family),
+ font_features: Some(ui_font_features),
font_fallbacks: ui_font_fallbacks,
color: Some(cx.theme().colors().editor_foreground),
..Default::default()
@@ -624,6 +627,7 @@ pub fn hover_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
inline_code: TextStyleRefinement {
background_color: Some(cx.theme().colors().background),
font_family: Some(buffer_font_family),
+ font_features: Some(buffer_font_features),
font_fallbacks: buffer_font_fallbacks,
..Default::default()
},
@@ -657,12 +661,15 @@ pub fn diagnostics_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let settings = ThemeSettings::get_global(cx);
let ui_font_family = settings.ui_font.family.clone();
let ui_font_fallbacks = settings.ui_font.fallbacks.clone();
+ let ui_font_features = settings.ui_font.features.clone();
let buffer_font_family = settings.buffer_font.family.clone();
+ let buffer_font_features = settings.buffer_font.features.clone();
let buffer_font_fallbacks = settings.buffer_font.fallbacks.clone();
let mut base_text_style = window.text_style();
base_text_style.refine(&TextStyleRefinement {
font_family: Some(ui_font_family),
+ font_features: Some(ui_font_features),
font_fallbacks: ui_font_fallbacks,
color: Some(cx.theme().colors().editor_foreground),
..Default::default()
@@ -673,6 +680,7 @@ pub fn diagnostics_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
inline_code: TextStyleRefinement {
background_color: Some(cx.theme().colors().editor_background.opacity(0.5)),
font_family: Some(buffer_font_family),
+ font_features: Some(buffer_font_features),
font_fallbacks: buffer_font_fallbacks,
..Default::default()
},
@@ -181,6 +181,10 @@ pub fn indent_guides_in_range(
.buffer_snapshot()
.indent_guides_in_range(start_anchor..end_anchor, ignore_disabled_for_language, cx)
.filter(|indent_guide| {
+ if editor.has_indent_guides_disabled_for_buffer(indent_guide.buffer_id) {
+ return false;
+ }
+
if editor.is_buffer_folded(indent_guide.buffer_id, cx) {
return false;
}
@@ -1951,7 +1951,7 @@ mod tests {
use super::*;
use fs::MTime;
use gpui::{App, VisualTestContext};
- use language::{LanguageMatcher, TestFile};
+ use language::TestFile;
use project::FakeFs;
use std::path::{Path, PathBuf};
use util::{path, rel_path::RelPath};
@@ -1991,20 +1991,6 @@ mod tests {
.unwrap()
}
- fn rust_language() -> Arc<language::Language> {
- Arc::new(language::Language::new(
- language::LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- ))
- }
-
#[gpui::test]
async fn test_deserialize(cx: &mut gpui::TestAppContext) {
init_test(cx, |_| {});
@@ -2086,7 +2072,9 @@ mod tests {
{
let project = Project::test(fs.clone(), [path!("/file.rs").as_ref()], cx).await;
// Add Rust to the language, so that we can restore the language of the buffer
- project.read_with(cx, |project, _| project.languages().add(rust_language()));
+ project.read_with(cx, |project, _| {
+ project.languages().add(languages::rust_lang())
+ });
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
@@ -235,7 +235,10 @@ pub fn deploy_context_menu(
.action("Go to Declaration", Box::new(GoToDeclaration))
.action("Go to Type Definition", Box::new(GoToTypeDefinition))
.action("Go to Implementation", Box::new(GoToImplementation))
- .action("Find All References", Box::new(FindAllReferences))
+ .action(
+ "Find All References",
+ Box::new(FindAllReferences::default()),
+ )
.separator()
.action("Rename Symbol", Box::new(Rename))
.action("Format Buffer", Box::new(Format))
@@ -419,22 +419,30 @@ impl SelectionsCollection {
mutable_collection.disjoint.iter().for_each(|selection| {
assert!(
snapshot.can_resolve(&selection.start),
- "disjoint selection start is not resolvable for the given snapshot:\n{selection:?}",
+ "disjoint selection start is not resolvable for the given snapshot:\n{selection:?}, {excerpt:?}",
+ excerpt = snapshot.buffer_for_excerpt(selection.start.excerpt_id).map(|snapshot| snapshot.remote_id()),
);
assert!(
snapshot.can_resolve(&selection.end),
- "disjoint selection end is not resolvable for the given snapshot: {selection:?}",
+ "disjoint selection end is not resolvable for the given snapshot: {selection:?}, {excerpt:?}",
+ excerpt = snapshot.buffer_for_excerpt(selection.end.excerpt_id).map(|snapshot| snapshot.remote_id()),
);
});
if let Some(pending) = &mutable_collection.pending {
let selection = &pending.selection;
assert!(
snapshot.can_resolve(&selection.start),
- "pending selection start is not resolvable for the given snapshot: {pending:?}",
+ "pending selection start is not resolvable for the given snapshot: {pending:?}, {excerpt:?}",
+ excerpt = snapshot
+ .buffer_for_excerpt(selection.start.excerpt_id)
+ .map(|snapshot| snapshot.remote_id()),
);
assert!(
snapshot.can_resolve(&selection.end),
- "pending selection end is not resolvable for the given snapshot: {pending:?}",
+ "pending selection end is not resolvable for the given snapshot: {pending:?}, {excerpt:?}",
+ excerpt = snapshot
+ .buffer_for_excerpt(selection.end.excerpt_id)
+ .map(|snapshot| snapshot.remote_id()),
);
}
}
@@ -532,11 +540,18 @@ impl<'snap, 'a> MutableSelectionsCollection<'snap, 'a> {
};
if filtered_selections.is_empty() {
- let default_anchor = self.snapshot.anchor_before(MultiBufferOffset(0));
+ let buffer_snapshot = self.snapshot.buffer_snapshot();
+ let anchor = buffer_snapshot
+ .excerpts()
+ .find(|(_, buffer, _)| buffer.remote_id() == buffer_id)
+ .and_then(|(excerpt_id, _, range)| {
+ buffer_snapshot.anchor_in_excerpt(excerpt_id, range.context.start)
+ })
+ .unwrap_or_else(|| self.snapshot.anchor_before(MultiBufferOffset(0)));
self.collection.disjoint = Arc::from([Selection {
id: post_inc(&mut self.collection.next_selection_id),
- start: default_anchor,
- end: default_anchor,
+ start: anchor,
+ end: anchor,
reversed: false,
goal: SelectionGoal::None,
}]);
@@ -38,3 +38,4 @@ wasmparser.workspace = true
[dev-dependencies]
pretty_assertions.workspace = true
+tempfile.workspace = true
@@ -247,26 +247,34 @@ impl ExtensionBuilder {
let parser_path = src_path.join("parser.c");
let scanner_path = src_path.join("scanner.c");
- log::info!("compiling {grammar_name} parser");
- let clang_output = util::command::new_smol_command(&clang_path)
- .args(["-fPIC", "-shared", "-Os"])
- .arg(format!("-Wl,--export=tree_sitter_{grammar_name}"))
- .arg("-o")
- .arg(&grammar_wasm_path)
- .arg("-I")
- .arg(&src_path)
- .arg(&parser_path)
- .args(scanner_path.exists().then_some(scanner_path))
- .output()
- .await
- .context("failed to run clang")?;
-
- if !clang_output.status.success() {
- bail!(
- "failed to compile {} parser with clang: {}",
- grammar_name,
- String::from_utf8_lossy(&clang_output.stderr),
+ // Skip recompiling if the WASM object is already newer than the source files
+ if file_newer_than_deps(&grammar_wasm_path, &[&parser_path, &scanner_path]).unwrap_or(false)
+ {
+ log::info!(
+ "skipping compilation of {grammar_name} parser because the existing compiled grammar is up to date"
);
+ } else {
+ log::info!("compiling {grammar_name} parser");
+ let clang_output = util::command::new_smol_command(&clang_path)
+ .args(["-fPIC", "-shared", "-Os"])
+ .arg(format!("-Wl,--export=tree_sitter_{grammar_name}"))
+ .arg("-o")
+ .arg(&grammar_wasm_path)
+ .arg("-I")
+ .arg(&src_path)
+ .arg(&parser_path)
+ .args(scanner_path.exists().then_some(scanner_path))
+ .output()
+ .await
+ .context("failed to run clang")?;
+
+ if !clang_output.status.success() {
+ bail!(
+ "failed to compile {} parser with clang: {}",
+ grammar_name,
+ String::from_utf8_lossy(&clang_output.stderr),
+ );
+ }
}
Ok(())
@@ -643,3 +651,71 @@ fn populate_defaults(manifest: &mut ExtensionManifest, extension_path: &Path) ->
Ok(())
}
+
+/// Returns `true` if the target exists and its last modified time is greater than that
+/// of each dependency which exists (i.e., dependency paths which do not exist are ignored).
+///
+/// # Errors
+///
+/// Returns `Err` if any of the underlying file I/O operations fail.
+fn file_newer_than_deps(target: &Path, dependencies: &[&Path]) -> Result<bool, std::io::Error> {
+ if !target.try_exists()? {
+ return Ok(false);
+ }
+ let target_modified = target.metadata()?.modified()?;
+ for dependency in dependencies {
+ if !dependency.try_exists()? {
+ continue;
+ }
+ let dep_modified = dependency.metadata()?.modified()?;
+ if target_modified < dep_modified {
+ return Ok(false);
+ }
+ }
+ Ok(true)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use std::{fs, thread::sleep, time::Duration};
+
+ #[test]
+ fn test_file_newer_than_deps() {
+ // Don't use TempTree because we need to guarantee the order
+ let tmpdir = tempfile::tempdir().unwrap();
+ let target = tmpdir.path().join("target.wasm");
+ let dep1 = tmpdir.path().join("parser.c");
+ let dep2 = tmpdir.path().join("scanner.c");
+
+ assert!(
+ !file_newer_than_deps(&target, &[&dep1, &dep2]).unwrap(),
+ "target doesn't exist"
+ );
+ fs::write(&target, "foo").unwrap(); // Create target
+ assert!(
+ file_newer_than_deps(&target, &[&dep1, &dep2]).unwrap(),
+ "dependencies don't exist; target is newer"
+ );
+ sleep(Duration::from_secs(1));
+ fs::write(&dep1, "foo").unwrap(); // Create dep1 (newer than target)
+ // Dependency is newer
+ assert!(
+ !file_newer_than_deps(&target, &[&dep1, &dep2]).unwrap(),
+ "a dependency is newer (target {:?}, dep1 {:?})",
+ target.metadata().unwrap().modified().unwrap(),
+ dep1.metadata().unwrap().modified().unwrap(),
+ );
+ sleep(Duration::from_secs(1));
+ fs::write(&dep2, "foo").unwrap(); // Create dep2
+ sleep(Duration::from_secs(1));
+ fs::write(&target, "foobar").unwrap(); // Update target
+ assert!(
+ file_newer_than_deps(&target, &[&dep1, &dep2]).unwrap(),
+ "target is newer than dependencies (target {:?}, dep2 {:?})",
+ target.metadata().unwrap().modified().unwrap(),
+ dep2.metadata().unwrap().modified().unwrap(),
+ );
+ }
+}
@@ -309,9 +309,9 @@ async fn test_extension_store(cx: &mut TestAppContext) {
assert_eq!(
language_registry.language_names(),
[
- LanguageName::new("ERB"),
- LanguageName::new("Plain Text"),
- LanguageName::new("Ruby"),
+ LanguageName::new_static("ERB"),
+ LanguageName::new_static("Plain Text"),
+ LanguageName::new_static("Ruby"),
]
);
assert_eq!(
@@ -466,9 +466,9 @@ async fn test_extension_store(cx: &mut TestAppContext) {
assert_eq!(
language_registry.language_names(),
[
- LanguageName::new("ERB"),
- LanguageName::new("Plain Text"),
- LanguageName::new("Ruby"),
+ LanguageName::new_static("ERB"),
+ LanguageName::new_static("Plain Text"),
+ LanguageName::new_static("Ruby"),
]
);
assert_eq!(
@@ -526,7 +526,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
assert_eq!(
language_registry.language_names(),
- [LanguageName::new("Plain Text")]
+ [LanguageName::new_static("Plain Text")]
);
assert_eq!(language_registry.grammar_names(), []);
});
@@ -708,7 +708,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
.await
.unwrap();
- let mut fake_servers = language_registry.register_fake_language_server(
+ let mut fake_servers = language_registry.register_fake_lsp_server(
LanguageServerName("gleam".into()),
lsp::ServerCapabilities {
completion_provider: Some(Default::default()),
@@ -1472,8 +1472,7 @@ impl ExtensionsPage {
},
);
},
- ))
- .color(ui::SwitchColor::Accent),
+ )),
),
),
)
@@ -1,11 +1,5 @@
use crate::FeatureFlag;
-pub struct PredictEditsRateCompletionsFeatureFlag;
-
-impl FeatureFlag for PredictEditsRateCompletionsFeatureFlag {
- const NAME: &'static str = "predict-edits-rate-completions";
-}
-
pub struct NotebookFeatureFlag;
impl FeatureFlag for NotebookFeatureFlag {
@@ -17,3 +11,9 @@ pub struct PanicFeatureFlag;
impl FeatureFlag for PanicFeatureFlag {
const NAME: &'static str = "panic";
}
+
+pub struct InlineAssistantV2FeatureFlag;
+
+impl FeatureFlag for InlineAssistantV2FeatureFlag {
+ const NAME: &'static str = "inline-assistant-v2";
+}
@@ -381,11 +381,18 @@ impl GitRepository for FakeGitRepository {
Ok(state
.branches
.iter()
- .map(|branch_name| Branch {
- is_head: Some(branch_name) == current_branch.as_ref(),
- ref_name: branch_name.into(),
- most_recent_commit: None,
- upstream: None,
+ .map(|branch_name| {
+ let ref_name = if branch_name.starts_with("refs/") {
+ branch_name.into()
+ } else {
+ format!("refs/heads/{branch_name}").into()
+ };
+ Branch {
+ is_head: Some(branch_name) == current_branch.as_ref(),
+ ref_name,
+ most_recent_commit: None,
+ upstream: None,
+ }
})
.collect())
})
@@ -96,7 +96,8 @@ impl<'a> Matcher<'a> {
continue;
}
- let matrix_len = self.query.len() * (prefix.len() + candidate_chars.len());
+ let matrix_len =
+ self.query.len() * (lowercase_prefix.len() + lowercase_candidate_chars.len());
self.score_matrix.clear();
self.score_matrix.resize(matrix_len, None);
self.best_position_matrix.clear();
@@ -596,4 +597,15 @@ mod tests {
})
.collect()
}
+
+ /// Test for https://github.com/zed-industries/zed/issues/44324
+ #[test]
+ fn test_recursive_score_match_index_out_of_bounds() {
+ let paths = vec!["ฤฐ/ฤฐ/ฤฐ/ฤฐ"];
+ let query = "ฤฐ/ฤฐ";
+
+ // This panicked with "index out of bounds: the len is 21 but the index is 22"
+ let result = match_single_path_query(query, false, &paths);
+ let _ = result;
+ }
}
@@ -19,7 +19,6 @@ pub use git2 as libgit;
pub struct Blame {
pub entries: Vec<BlameEntry>,
pub messages: HashMap<Oid, String>,
- pub remote_url: Option<String>,
}
#[derive(Clone, Debug, Default)]
@@ -36,7 +35,6 @@ impl Blame {
working_directory: &Path,
path: &RepoPath,
content: &Rope,
- remote_url: Option<String>,
) -> Result<Self> {
let output = run_git_blame(git_binary, working_directory, path, content).await?;
let mut entries = parse_git_blame(&output)?;
@@ -53,11 +51,7 @@ impl Blame {
.await
.context("failed to get commit messages")?;
- Ok(Self {
- entries,
- messages,
- remote_url,
- })
+ Ok(Self { entries, messages })
}
}
@@ -1494,28 +1494,17 @@ impl GitRepository for RealGitRepository {
let git_binary_path = self.any_git_binary_path.clone();
let executor = self.executor.clone();
- async move {
- let remote_url = if let Some(remote_url) = self.remote_url("upstream").await {
- Some(remote_url)
- } else if let Some(remote_url) = self.remote_url("origin").await {
- Some(remote_url)
- } else {
- None
- };
- executor
- .spawn(async move {
- crate::blame::Blame::for_path(
- &git_binary_path,
- &working_directory?,
- &path,
- &content,
- remote_url,
- )
- .await
- })
+ executor
+ .spawn(async move {
+ crate::blame::Blame::for_path(
+ &git_binary_path,
+ &working_directory?,
+ &path,
+ &content,
+ )
.await
- }
- .boxed()
+ })
+ .boxed()
}
fn file_history(&self, path: RepoPath) -> BoxFuture<'_, Result<FileHistory>> {
@@ -18,6 +18,7 @@ futures.workspace = true
git.workspace = true
gpui.workspace = true
http_client.workspace = true
+itertools.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -26,7 +26,7 @@ pub fn init(cx: &mut App) {
provider_registry.register_hosting_provider(Arc::new(Gitee));
provider_registry.register_hosting_provider(Arc::new(Github::public_instance()));
provider_registry.register_hosting_provider(Arc::new(Gitlab::public_instance()));
- provider_registry.register_hosting_provider(Arc::new(Sourcehut));
+ provider_registry.register_hosting_provider(Arc::new(SourceHut::public_instance()));
}
/// Registers additional Git hosting providers.
@@ -51,6 +51,8 @@ pub async fn register_additional_providers(
provider_registry.register_hosting_provider(Arc::new(gitea_self_hosted));
} else if let Ok(bitbucket_self_hosted) = Bitbucket::from_remote_url(&origin_url) {
provider_registry.register_hosting_provider(Arc::new(bitbucket_self_hosted));
+ } else if let Ok(sourcehut_self_hosted) = SourceHut::from_remote_url(&origin_url) {
+ provider_registry.register_hosting_provider(Arc::new(sourcehut_self_hosted));
}
}
@@ -1,8 +1,14 @@
-use std::str::FromStr;
use std::sync::LazyLock;
-
-use anyhow::{Result, bail};
+use std::{str::FromStr, sync::Arc};
+
+use anyhow::{Context as _, Result, bail};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::SharedString;
+use http_client::{AsyncBody, HttpClient, HttpRequestExt, Request};
+use itertools::Itertools as _;
use regex::Regex;
+use serde::Deserialize;
use url::Url;
use git::{
@@ -20,6 +26,42 @@ fn pull_request_regex() -> &'static Regex {
&PULL_REQUEST_REGEX
}
+#[derive(Debug, Deserialize)]
+struct CommitDetails {
+ author: Author,
+}
+
+#[derive(Debug, Deserialize)]
+struct Author {
+ user: Account,
+}
+
+#[derive(Debug, Deserialize)]
+struct Account {
+ links: AccountLinks,
+}
+
+#[derive(Debug, Deserialize)]
+struct AccountLinks {
+ avatar: Option<Link>,
+}
+
+#[derive(Debug, Deserialize)]
+struct Link {
+ href: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct CommitDetailsSelfHosted {
+ author: AuthorSelfHosted,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct AuthorSelfHosted {
+ avatar_url: Option<String>,
+}
+
pub struct Bitbucket {
name: String,
base_url: Url,
@@ -61,8 +103,60 @@ impl Bitbucket {
.host_str()
.is_some_and(|host| host != "bitbucket.org")
}
+
+ async fn fetch_bitbucket_commit_author(
+ &self,
+ repo_owner: &str,
+ repo: &str,
+ commit: &str,
+ client: &Arc<dyn HttpClient>,
+ ) -> Result<Option<String>> {
+ let Some(host) = self.base_url.host_str() else {
+ bail!("failed to get host from bitbucket base url");
+ };
+ let is_self_hosted = self.is_self_hosted();
+ let url = if is_self_hosted {
+ format!(
+ "https://{host}/rest/api/latest/projects/{repo_owner}/repos/{repo}/commits/{commit}?avatarSize=128"
+ )
+ } else {
+ format!("https://api.{host}/2.0/repositories/{repo_owner}/{repo}/commit/{commit}")
+ };
+
+ let request = Request::get(&url)
+ .header("Content-Type", "application/json")
+ .follow_redirects(http_client::RedirectPolicy::FollowAll);
+
+ let mut response = client
+ .send(request.body(AsyncBody::default())?)
+ .await
+ .with_context(|| format!("error fetching BitBucket commit details at {:?}", url))?;
+
+ let mut body = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+
+ if response.status().is_client_error() {
+ let text = String::from_utf8_lossy(body.as_slice());
+ bail!(
+ "status error {}, response: {text:?}",
+ response.status().as_u16()
+ );
+ }
+
+ let body_str = std::str::from_utf8(&body)?;
+
+ if is_self_hosted {
+ serde_json::from_str::<CommitDetailsSelfHosted>(body_str)
+ .map(|commit| commit.author.avatar_url)
+ } else {
+ serde_json::from_str::<CommitDetails>(body_str)
+ .map(|commit| commit.author.user.links.avatar.map(|link| link.href))
+ }
+ .context("failed to deserialize BitBucket commit details")
+ }
}
+#[async_trait]
impl GitHostingProvider for Bitbucket {
fn name(&self) -> String {
self.name.clone()
@@ -73,7 +167,7 @@ impl GitHostingProvider for Bitbucket {
}
fn supports_avatars(&self) -> bool {
- false
+ true
}
fn format_line_number(&self, line: u32) -> String {
@@ -98,9 +192,16 @@ impl GitHostingProvider for Bitbucket {
return None;
}
- let mut path_segments = url.path_segments()?;
- let owner = path_segments.next()?;
- let repo = path_segments.next()?.trim_end_matches(".git");
+ let mut path_segments = url.path_segments()?.collect::<Vec<_>>();
+ let repo = path_segments.pop()?.trim_end_matches(".git");
+ let owner = if path_segments.get(0).is_some_and(|v| *v == "scm") && path_segments.len() > 1
+ {
+ // Skip the "scm" segment if it's not the only segment
+ // https://github.com/gitkraken/vscode-gitlens/blob/a6e3c6fbb255116507eaabaa9940c192ed7bb0e1/src/git/remotes/bitbucket-server.ts#L72-L74
+ path_segments.into_iter().skip(1).join("/")
+ } else {
+ path_segments.into_iter().join("/")
+ };
Some(ParsedGitRemote {
owner: owner.into(),
@@ -176,6 +277,22 @@ impl GitHostingProvider for Bitbucket {
Some(PullRequest { number, url })
}
+
+ async fn commit_author_avatar_url(
+ &self,
+ repo_owner: &str,
+ repo: &str,
+ commit: SharedString,
+ http_client: Arc<dyn HttpClient>,
+ ) -> Result<Option<Url>> {
+ let commit = commit.to_string();
+ let avatar_url = self
+ .fetch_bitbucket_commit_author(repo_owner, repo, &commit, &http_client)
+ .await?
+ .map(|avatar_url| Url::parse(&avatar_url))
+ .transpose()?;
+ Ok(avatar_url)
+ }
}
#[cfg(test)]
@@ -264,6 +381,38 @@ mod tests {
repo: "zed".into(),
}
);
+
+ // Test with "scm" in the path
+ let remote_url = "https://bitbucket.company.com/scm/zed-industries/zed.git";
+
+ let parsed_remote = Bitbucket::from_remote_url(remote_url)
+ .unwrap()
+ .parse_remote_url(remote_url)
+ .unwrap();
+
+ assert_eq!(
+ parsed_remote,
+ ParsedGitRemote {
+ owner: "zed-industries".into(),
+ repo: "zed".into(),
+ }
+ );
+
+ // Test with only "scm" as owner
+ let remote_url = "https://bitbucket.company.com/scm/zed.git";
+
+ let parsed_remote = Bitbucket::from_remote_url(remote_url)
+ .unwrap()
+ .parse_remote_url(remote_url)
+ .unwrap();
+
+ assert_eq!(
+ parsed_remote,
+ ParsedGitRemote {
+ owner: "scm".into(),
+ repo: "zed".into(),
+ }
+ );
}
#[test]
@@ -1,5 +1,6 @@
use std::str::FromStr;
+use anyhow::{Result, bail};
use url::Url;
use git::{
@@ -7,15 +8,52 @@ use git::{
RemoteUrl,
};
-pub struct Sourcehut;
+use crate::get_host_from_git_remote_url;
-impl GitHostingProvider for Sourcehut {
+pub struct SourceHut {
+ name: String,
+ base_url: Url,
+}
+
+impl SourceHut {
+ pub fn new(name: &str, base_url: Url) -> Self {
+ Self {
+ name: name.to_string(),
+ base_url,
+ }
+ }
+
+ pub fn public_instance() -> Self {
+ Self::new("SourceHut", Url::parse("https://git.sr.ht").unwrap())
+ }
+
+ pub fn from_remote_url(remote_url: &str) -> Result<Self> {
+ let host = get_host_from_git_remote_url(remote_url)?;
+ if host == "git.sr.ht" {
+ bail!("the SourceHut instance is not self-hosted");
+ }
+
+ // TODO: detecting self hosted instances by checking whether "sourcehut" is in the url or not
+ // is not very reliable. See https://github.com/zed-industries/zed/issues/26393 for more
+ // information.
+ if !host.contains("sourcehut") {
+ bail!("not a SourceHut URL");
+ }
+
+ Ok(Self::new(
+ "SourceHut Self-Hosted",
+ Url::parse(&format!("https://{}", host))?,
+ ))
+ }
+}
+
+impl GitHostingProvider for SourceHut {
fn name(&self) -> String {
- "SourceHut".to_string()
+ self.name.clone()
}
fn base_url(&self) -> Url {
- Url::parse("https://git.sr.ht").unwrap()
+ self.base_url.clone()
}
fn supports_avatars(&self) -> bool {
@@ -34,7 +72,7 @@ impl GitHostingProvider for Sourcehut {
let url = RemoteUrl::from_str(url).ok()?;
let host = url.host_str()?;
- if host != "git.sr.ht" {
+ if host != self.base_url.host_str()? {
return None;
}
@@ -96,7 +134,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_ssh_url() {
- let parsed_remote = Sourcehut
+ let parsed_remote = SourceHut::public_instance()
.parse_remote_url("git@git.sr.ht:~zed-industries/zed")
.unwrap();
@@ -111,7 +149,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_ssh_url_with_git_suffix() {
- let parsed_remote = Sourcehut
+ let parsed_remote = SourceHut::public_instance()
.parse_remote_url("git@git.sr.ht:~zed-industries/zed.git")
.unwrap();
@@ -126,7 +164,7 @@ mod tests {
#[test]
fn test_parse_remote_url_given_https_url() {
- let parsed_remote = Sourcehut
+ let parsed_remote = SourceHut::public_instance()
.parse_remote_url("https://git.sr.ht/~zed-industries/zed")
.unwrap();
@@ -139,9 +177,63 @@ mod tests {
);
}
+ #[test]
+ fn test_parse_remote_url_given_self_hosted_ssh_url() {
+ let remote_url = "git@sourcehut.org:~zed-industries/zed";
+
+ let parsed_remote = SourceHut::from_remote_url(remote_url)
+ .unwrap()
+ .parse_remote_url(remote_url)
+ .unwrap();
+
+ assert_eq!(
+ parsed_remote,
+ ParsedGitRemote {
+ owner: "zed-industries".into(),
+ repo: "zed".into(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_parse_remote_url_given_self_hosted_ssh_url_with_git_suffix() {
+ let remote_url = "git@sourcehut.org:~zed-industries/zed.git";
+
+ let parsed_remote = SourceHut::from_remote_url(remote_url)
+ .unwrap()
+ .parse_remote_url(remote_url)
+ .unwrap();
+
+ assert_eq!(
+ parsed_remote,
+ ParsedGitRemote {
+ owner: "zed-industries".into(),
+ repo: "zed.git".into(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_parse_remote_url_given_self_hosted_https_url() {
+ let remote_url = "https://sourcehut.org/~zed-industries/zed";
+
+ let parsed_remote = SourceHut::from_remote_url(remote_url)
+ .unwrap()
+ .parse_remote_url(remote_url)
+ .unwrap();
+
+ assert_eq!(
+ parsed_remote,
+ ParsedGitRemote {
+ owner: "zed-industries".into(),
+ repo: "zed".into(),
+ }
+ );
+ }
+
#[test]
fn test_build_sourcehut_permalink() {
- let permalink = Sourcehut.build_permalink(
+ let permalink = SourceHut::public_instance().build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
@@ -159,7 +251,7 @@ mod tests {
#[test]
fn test_build_sourcehut_permalink_with_git_suffix() {
- let permalink = Sourcehut.build_permalink(
+ let permalink = SourceHut::public_instance().build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed.git".into(),
@@ -175,9 +267,49 @@ mod tests {
assert_eq!(permalink.to_string(), expected_url.to_string())
}
+ #[test]
+ fn test_build_sourcehut_self_hosted_permalink() {
+ let permalink = SourceHut::from_remote_url("https://sourcehut.org/~zed-industries/zed")
+ .unwrap()
+ .build_permalink(
+ ParsedGitRemote {
+ owner: "zed-industries".into(),
+ repo: "zed".into(),
+ },
+ BuildPermalinkParams::new(
+ "faa6f979be417239b2e070dbbf6392b909224e0b",
+ &repo_path("crates/editor/src/git/permalink.rs"),
+ None,
+ ),
+ );
+
+ let expected_url = "https://sourcehut.org/~zed-industries/zed/tree/faa6f979be417239b2e070dbbf6392b909224e0b/item/crates/editor/src/git/permalink.rs";
+ assert_eq!(permalink.to_string(), expected_url.to_string())
+ }
+
+ #[test]
+ fn test_build_sourcehut_self_hosted_permalink_with_git_suffix() {
+ let permalink = SourceHut::from_remote_url("https://sourcehut.org/~zed-industries/zed.git")
+ .unwrap()
+ .build_permalink(
+ ParsedGitRemote {
+ owner: "zed-industries".into(),
+ repo: "zed.git".into(),
+ },
+ BuildPermalinkParams::new(
+ "faa6f979be417239b2e070dbbf6392b909224e0b",
+ &repo_path("crates/editor/src/git/permalink.rs"),
+ None,
+ ),
+ );
+
+ let expected_url = "https://sourcehut.org/~zed-industries/zed.git/tree/faa6f979be417239b2e070dbbf6392b909224e0b/item/crates/editor/src/git/permalink.rs";
+ assert_eq!(permalink.to_string(), expected_url.to_string())
+ }
+
#[test]
fn test_build_sourcehut_permalink_with_single_line_selection() {
- let permalink = Sourcehut.build_permalink(
+ let permalink = SourceHut::public_instance().build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
@@ -195,7 +327,7 @@ mod tests {
#[test]
fn test_build_sourcehut_permalink_with_multi_line_selection() {
- let permalink = Sourcehut.build_permalink(
+ let permalink = SourceHut::public_instance().build_permalink(
ParsedGitRemote {
owner: "zed-industries".into(),
repo: "zed".into(),
@@ -210,4 +342,44 @@ mod tests {
let expected_url = "https://git.sr.ht/~zed-industries/zed/tree/faa6f979be417239b2e070dbbf6392b909224e0b/item/crates/editor/src/git/permalink.rs#L24-48";
assert_eq!(permalink.to_string(), expected_url.to_string())
}
+
+ #[test]
+ fn test_build_sourcehut_self_hosted_permalink_with_single_line_selection() {
+ let permalink = SourceHut::from_remote_url("https://sourcehut.org/~zed-industries/zed")
+ .unwrap()
+ .build_permalink(
+ ParsedGitRemote {
+ owner: "zed-industries".into(),
+ repo: "zed".into(),
+ },
+ BuildPermalinkParams::new(
+ "faa6f979be417239b2e070dbbf6392b909224e0b",
+ &repo_path("crates/editor/src/git/permalink.rs"),
+ Some(6..6),
+ ),
+ );
+
+ let expected_url = "https://sourcehut.org/~zed-industries/zed/tree/faa6f979be417239b2e070dbbf6392b909224e0b/item/crates/editor/src/git/permalink.rs#L7";
+ assert_eq!(permalink.to_string(), expected_url.to_string())
+ }
+
+ #[test]
+ fn test_build_sourcehut_self_hosted_permalink_with_multi_line_selection() {
+ let permalink = SourceHut::from_remote_url("https://sourcehut.org/~zed-industries/zed")
+ .unwrap()
+ .build_permalink(
+ ParsedGitRemote {
+ owner: "zed-industries".into(),
+ repo: "zed".into(),
+ },
+ BuildPermalinkParams::new(
+ "faa6f979be417239b2e070dbbf6392b909224e0b",
+ &repo_path("crates/editor/src/git/permalink.rs"),
+ Some(23..47),
+ ),
+ );
+
+ let expected_url = "https://sourcehut.org/~zed-industries/zed/tree/faa6f979be417239b2e070dbbf6392b909224e0b/item/crates/editor/src/git/permalink.rs#L24-48";
+ assert_eq!(permalink.to_string(), expected_url.to_string())
+ }
}
@@ -8,7 +8,7 @@ use settings::{
use url::Url;
use util::ResultExt as _;
-use crate::{Bitbucket, Github, Gitlab};
+use crate::{Bitbucket, Forgejo, Gitea, Github, Gitlab, SourceHut};
pub(crate) fn init(cx: &mut App) {
init_git_hosting_provider_settings(cx);
@@ -46,6 +46,11 @@ fn update_git_hosting_providers_from_settings(cx: &mut App) {
}
GitHostingProviderKind::Github => Arc::new(Github::new(&provider.name, url)) as _,
GitHostingProviderKind::Gitlab => Arc::new(Gitlab::new(&provider.name, url)) as _,
+ GitHostingProviderKind::Gitea => Arc::new(Gitea::new(&provider.name, url)) as _,
+ GitHostingProviderKind::Forgejo => Arc::new(Forgejo::new(&provider.name, url)) as _,
+ GitHostingProviderKind::SourceHut => {
+ Arc::new(SourceHut::new(&provider.name, url)) as _
+ }
})
});
@@ -13,7 +13,6 @@ name = "git_ui"
path = "src/git_ui.rs"
[features]
-default = []
test-support = ["multi_buffer/test-support"]
[dependencies]
@@ -62,7 +61,8 @@ watch.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zeroize.workspace = true
-
+ztracing.workspace = true
+tracing.workspace = true
[target.'cfg(windows)'.dependencies]
windows.workspace = true
@@ -78,3 +78,6 @@ settings = { workspace = true, features = ["test-support"] }
unindent.workspace = true
workspace = { workspace = true, features = ["test-support"] }
zlog.workspace = true
+
+[package.metadata.cargo-machete]
+ignored = ["tracing"]
@@ -55,6 +55,7 @@ impl BlameRenderer for GitBlameRenderer {
} else {
None
};
+
Some(
div()
.mr_2()
@@ -80,7 +81,10 @@ impl BlameRenderer for GitBlameRenderer {
.on_mouse_down(MouseButton::Right, {
let blame_entry = blame_entry.clone();
let details = details.clone();
+ let editor = editor.clone();
move |event, window, cx| {
+ cx.stop_propagation();
+
deploy_blame_entry_context_menu(
&blame_entry,
details.as_ref(),
@@ -107,17 +111,19 @@ impl BlameRenderer for GitBlameRenderer {
)
}
})
- .hoverable_tooltip(move |_window, cx| {
- cx.new(|cx| {
- CommitTooltip::blame_entry(
- &blame_entry,
- details.clone(),
- repository.clone(),
- workspace.clone(),
- cx,
- )
+ .when(!editor.read(cx).has_mouse_context_menu(), |el| {
+ el.hoverable_tooltip(move |_window, cx| {
+ cx.new(|cx| {
+ CommitTooltip::blame_entry(
+ &blame_entry,
+ details.clone(),
+ repository.clone(),
+ workspace.clone(),
+ cx,
+ )
+ })
+ .into()
})
- .into()
}),
)
.into_any(),
@@ -148,7 +154,7 @@ impl BlameRenderer for GitBlameRenderer {
h_flex()
.id("inline-blame")
.w_full()
- .font_family(style.font().family)
+ .font(style.font())
.text_color(cx.theme().status().hint)
.line_height(style.line_height)
.child(Icon::new(IconName::FileGit).color(Color::Hint))
@@ -396,6 +402,7 @@ fn deploy_blame_entry_context_menu(
});
editor.update(cx, move |editor, cx| {
+ editor.hide_blame_popover(false, cx);
editor.deploy_mouse_context_menu(position, context_menu, window, cx);
cx.notify();
});
@@ -17,8 +17,8 @@ use settings::Settings;
use std::sync::Arc;
use time::OffsetDateTime;
use ui::{
- CommonAnimationExt, Divider, HighlightedLabel, KeyBinding, ListItem, ListItemSpacing, Tooltip,
- prelude::*,
+ CommonAnimationExt, Divider, HighlightedLabel, KeyBinding, ListHeader, ListItem,
+ ListItemSpacing, Tooltip, prelude::*,
};
use util::ResultExt;
use workspace::notifications::DetachAndPromptErr;
@@ -440,13 +440,6 @@ impl BranchListDelegate {
cx.emit(DismissEvent);
}
- fn loader(&self) -> AnyElement {
- Icon::new(IconName::LoadCircle)
- .size(IconSize::Small)
- .with_rotate_animation(3)
- .into_any_element()
- }
-
fn delete_at(&self, idx: usize, window: &mut Window, cx: &mut Context<Picker<Self>>) {
let Some(entry) = self.matches.get(idx).cloned() else {
return;
@@ -558,6 +551,8 @@ impl PickerDelegate for BranchListDelegate {
editor.set_placeholder_text(placeholder, window, cx);
});
+ let focus_handle = self.focus_handle.clone();
+
v_flex()
.when(
self.editor_position() == PickerEditorPosition::End,
@@ -569,7 +564,37 @@ impl PickerDelegate for BranchListDelegate {
.flex_none()
.h_9()
.px_2p5()
- .child(editor.clone()),
+ .child(editor.clone())
+ .when(
+ self.editor_position() == PickerEditorPosition::End,
+ |this| {
+ let tooltip_label = if self.display_remotes {
+ "Turn Off Remote Filter"
+ } else {
+ "Filter Remote Branches"
+ };
+
+ this.gap_1().justify_between().child({
+ IconButton::new("filter-remotes", IconName::Filter)
+ .disabled(self.loading)
+ .toggle_state(self.display_remotes)
+ .tooltip(move |_, cx| {
+ Tooltip::for_action_in(
+ tooltip_label,
+ &branch_picker::FilterRemotes,
+ &focus_handle,
+ cx,
+ )
+ })
+ .on_click(|_click, window, cx| {
+ window.dispatch_action(
+ branch_picker::FilterRemotes.boxed_clone(),
+ cx,
+ );
+ })
+ })
+ },
+ ),
)
.when(
self.editor_position() == PickerEditorPosition::Start,
@@ -683,10 +708,16 @@ impl PickerDelegate for BranchListDelegate {
} else {
Entry::NewBranch { name: query }
};
- picker.delegate.state = if is_url {
- PickerState::NewRemote
+ // Only transition to NewBranch/NewRemote states when we only show their list item
+ // Otherwise, stay in List state so footer buttons remain visible
+ picker.delegate.state = if matches.is_empty() {
+ if is_url {
+ PickerState::NewRemote
+ } else {
+ PickerState::NewBranch
+ }
} else {
- PickerState::NewBranch
+ PickerState::List
};
matches.push(entry);
} else {
@@ -770,7 +801,7 @@ impl PickerDelegate for BranchListDelegate {
} else {
None
};
- self.create_branch(from_branch, format!("refs/heads/{name}").into(), window, cx);
+ self.create_branch(from_branch, name.into(), window, cx);
}
}
@@ -812,226 +843,256 @@ impl PickerDelegate for BranchListDelegate {
})
.unwrap_or_else(|| (None, None, None));
- let icon = if let Some(default_branch) = self.default_branch.clone() {
- let icon = match &entry {
- Entry::Branch { .. } => Some((
- IconName::GitBranchAlt,
- format!("Create branch based off default: {default_branch}"),
- )),
- Entry::NewUrl { url } => {
- Some((IconName::Screen, format!("Create remote based off {url}")))
- }
- Entry::NewBranch { .. } => None,
- };
-
- icon.map(|(icon, tooltip_text)| {
- IconButton::new("branch-from-default", icon)
- .on_click(cx.listener(move |this, _, window, cx| {
- this.delegate.set_selected_index(ix, window, cx);
- this.delegate.confirm(true, window, cx);
- }))
- .tooltip(move |_window, cx| {
- Tooltip::for_action(tooltip_text.clone(), &menu::SecondaryConfirm, cx)
- })
- })
- } else {
- None
- };
+ let entry_icon = match entry {
+ Entry::NewUrl { .. } | Entry::NewBranch { .. } => {
+ Icon::new(IconName::Plus).color(Color::Muted)
+ }
- let icon_element = if self.display_remotes {
- Icon::new(IconName::Screen)
- } else {
- Icon::new(IconName::GitBranchAlt)
+ Entry::Branch { .. } => {
+ if self.display_remotes {
+ Icon::new(IconName::Screen).color(Color::Muted)
+ } else {
+ Icon::new(IconName::GitBranchAlt).color(Color::Muted)
+ }
+ }
};
- let entry_name = match entry {
- Entry::NewUrl { .. } => h_flex()
- .gap_1()
- .child(
- Icon::new(IconName::Plus)
- .size(IconSize::Small)
- .color(Color::Muted),
- )
- .child(
- Label::new("Create remote repository".to_string())
- .single_line()
- .truncate(),
- )
- .into_any_element(),
- Entry::NewBranch { name } => h_flex()
- .gap_1()
- .child(
- Icon::new(IconName::Plus)
- .size(IconSize::Small)
- .color(Color::Muted),
- )
- .child(
- Label::new(format!("Create branch \"{name}\"โฆ"))
- .single_line()
- .truncate(),
- )
+ let entry_title = match entry {
+ Entry::NewUrl { .. } => Label::new("Create Remote Repository")
+ .single_line()
+ .truncate()
.into_any_element(),
- Entry::Branch { branch, positions } => h_flex()
- .max_w_48()
- .child(h_flex().mr_1().child(icon_element))
- .child(
- HighlightedLabel::new(branch.name().to_string(), positions.clone()).truncate(),
- )
+ Entry::NewBranch { name } => Label::new(format!("Create Branch: \"{name}\"โฆ"))
+ .single_line()
+ .truncate()
.into_any_element(),
+ Entry::Branch { branch, positions } => {
+ HighlightedLabel::new(branch.name().to_string(), positions.clone())
+ .single_line()
+ .truncate()
+ .into_any_element()
+ }
};
+ let focus_handle = self.focus_handle.clone();
+ let is_new_items = matches!(entry, Entry::NewUrl { .. } | Entry::NewBranch { .. });
+
+ let delete_branch_button = IconButton::new("delete", IconName::Trash)
+ .tooltip(move |_, cx| {
+ Tooltip::for_action_in(
+ "Delete Branch",
+ &branch_picker::DeleteBranch,
+ &focus_handle,
+ cx,
+ )
+ })
+ .on_click(cx.listener(|this, _, window, cx| {
+ let selected_idx = this.delegate.selected_index();
+ this.delegate.delete_at(selected_idx, window, cx);
+ }));
+
+ let create_from_default_button = self.default_branch.as_ref().map(|default_branch| {
+ let tooltip_label: SharedString = format!("Create New From: {default_branch}").into();
+ let focus_handle = self.focus_handle.clone();
+
+ IconButton::new("create_from_default", IconName::GitBranchPlus)
+ .tooltip(move |_, cx| {
+ Tooltip::for_action_in(
+ tooltip_label.clone(),
+ &menu::SecondaryConfirm,
+ &focus_handle,
+ cx,
+ )
+ })
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.delegate.confirm(true, window, cx);
+ }))
+ .into_any_element()
+ });
+
Some(
ListItem::new(SharedString::from(format!("vcs-menu-{ix}")))
.inset(true)
.spacing(ListItemSpacing::Sparse)
.toggle_state(selected)
- .tooltip({
- match entry {
- Entry::Branch { branch, .. } => Tooltip::text(branch.name().to_string()),
- Entry::NewUrl { .. } => {
- Tooltip::text("Create remote repository".to_string())
- }
- Entry::NewBranch { name } => {
- Tooltip::text(format!("Create branch \"{name}\""))
- }
- }
- })
.child(
- v_flex()
+ h_flex()
.w_full()
- .overflow_hidden()
+ .gap_3()
+ .flex_grow()
+ .child(entry_icon)
.child(
- h_flex()
- .gap_6()
- .justify_between()
- .overflow_x_hidden()
- .child(entry_name)
- .when_some(commit_time, |label, commit_time| {
- label.child(
- Label::new(commit_time)
- .size(LabelSize::Small)
- .color(Color::Muted)
- .into_element(),
- )
- }),
- )
- .when(self.style == BranchListStyle::Modal, |el| {
- el.child(div().max_w_96().child({
- let message = match entry {
- Entry::NewUrl { url } => format!("based off {url}"),
- Entry::NewBranch { .. } => {
- if let Some(current_branch) =
- self.repo.as_ref().and_then(|repo| {
- repo.read(cx).branch.as_ref().map(|b| b.name())
- })
- {
- format!("based off {}", current_branch)
- } else {
- "based off the current branch".to_string()
- }
- }
- Entry::Branch { .. } => {
- let show_author_name = ProjectSettings::get_global(cx)
- .git
- .branch_picker
- .show_author_name;
-
- subject.map_or("no commits found".into(), |subject| {
- if show_author_name && author_name.is_some() {
- format!("{} โข {}", author_name.unwrap(), subject)
- } else {
- subject.to_string()
- }
+ v_flex()
+ .id("info_container")
+ .w_full()
+ .child(entry_title)
+ .child(
+ h_flex()
+ .w_full()
+ .justify_between()
+ .gap_1p5()
+ .when(self.style == BranchListStyle::Modal, |el| {
+ el.child(div().max_w_96().child({
+ let message = match entry {
+ Entry::NewUrl { url } => {
+ format!("Based off {url}")
+ }
+ Entry::NewBranch { .. } => {
+ if let Some(current_branch) =
+ self.repo.as_ref().and_then(|repo| {
+ repo.read(cx)
+ .branch
+ .as_ref()
+ .map(|b| b.name())
+ })
+ {
+ format!("Based off {}", current_branch)
+ } else {
+ "Based off the current branch"
+ .to_string()
+ }
+ }
+ Entry::Branch { .. } => {
+ let show_author_name =
+ ProjectSettings::get_global(cx)
+ .git
+ .branch_picker
+ .show_author_name;
+
+ subject.map_or(
+ "No commits found".into(),
+ |subject| {
+ if show_author_name
+ && author_name.is_some()
+ {
+ format!(
+ "{} โข {}",
+ author_name.unwrap(),
+ subject
+ )
+ } else {
+ subject.to_string()
+ }
+ },
+ )
+ }
+ };
+
+ Label::new(message)
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .truncate()
+ }))
})
- }
- };
-
- Label::new(message)
- .size(LabelSize::Small)
- .truncate()
- .color(Color::Muted)
- }))
- }),
+ .when_some(commit_time, |label, commit_time| {
+ label.child(
+ Label::new(commit_time)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ }),
+ )
+ .when_some(
+ entry.as_branch().map(|b| b.name().to_string()),
+ |this, branch_name| this.tooltip(Tooltip::text(branch_name)),
+ ),
+ ),
+ )
+ .when(
+ self.editor_position() == PickerEditorPosition::End && !is_new_items,
+ |this| {
+ this.map(|this| {
+ if self.selected_index() == ix {
+ this.end_slot(delete_branch_button)
+ } else {
+ this.end_hover_slot(delete_branch_button)
+ }
+ })
+ },
)
- .end_slot::<IconButton>(icon),
+ .when_some(
+ if self.editor_position() == PickerEditorPosition::End && is_new_items {
+ create_from_default_button
+ } else {
+ None
+ },
+ |this, create_from_default_button| {
+ this.map(|this| {
+ if self.selected_index() == ix {
+ this.end_slot(create_from_default_button)
+ } else {
+ this.end_hover_slot(create_from_default_button)
+ }
+ })
+ },
+ ),
)
}
fn render_header(
&self,
_window: &mut Window,
- cx: &mut Context<Picker<Self>>,
+ _cx: &mut Context<Picker<Self>>,
) -> Option<AnyElement> {
- if matches!(
- self.state,
- PickerState::CreateRemote(_) | PickerState::NewRemote | PickerState::NewBranch
- ) {
+ matches!(self.state, PickerState::List).then(|| {
+ let label = if self.display_remotes {
+ "Remote"
+ } else {
+ "Local"
+ };
+
+ ListHeader::new(label).inset(true).into_any_element()
+ })
+ }
+
+ fn render_footer(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
+ if self.editor_position() == PickerEditorPosition::End {
return None;
}
- let label = if self.display_remotes {
- "Remote"
- } else {
- "Local"
- };
- Some(
+
+ let focus_handle = self.focus_handle.clone();
+ let loading_icon = Icon::new(IconName::LoadCircle)
+ .size(IconSize::Small)
+ .with_rotate_animation(3);
+
+ let footer_container = || {
h_flex()
.w_full()
.p_1p5()
- .gap_1()
.border_t_1()
.border_color(cx.theme().colors().border_variant)
- .child(Label::new(label).size(LabelSize::Small).color(Color::Muted))
- .into_any(),
- )
- }
-
- fn render_footer(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
- let focus_handle = self.focus_handle.clone();
+ };
- if self.loading {
- return Some(
- h_flex()
- .w_full()
- .p_1p5()
- .gap_1()
- .justify_end()
- .border_t_1()
- .border_color(cx.theme().colors().border_variant)
- .child(self.loader())
- .into_any(),
- );
- }
match self.state {
- PickerState::List => Some(
- h_flex()
- .w_full()
- .p_1p5()
- .gap_0p5()
- .border_t_1()
- .border_color(cx.theme().colors().border_variant)
- .justify_between()
- .child(
- Button::new("filter-remotes", "Filter remotes")
+ PickerState::List => {
+ let selected_entry = self.matches.get(self.selected_index);
+
+ let branch_from_default_button = self
+ .default_branch
+ .as_ref()
+ .filter(|_| matches!(selected_entry, Some(Entry::NewBranch { .. })))
+ .map(|default_branch| {
+ let button_label = format!("Create New From: {default_branch}");
+
+ Button::new("branch-from-default", button_label)
.key_binding(
KeyBinding::for_action_in(
- &branch_picker::FilterRemotes,
+ &menu::SecondaryConfirm,
&focus_handle,
cx,
)
.map(|kb| kb.size(rems_from_px(12.))),
)
- .on_click(|_click, window, cx| {
- window.dispatch_action(
- branch_picker::FilterRemotes.boxed_clone(),
- cx,
- );
- })
- .disabled(self.loading)
- .style(ButtonStyle::Subtle)
- .toggle_state(self.display_remotes),
- )
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.delegate.confirm(true, window, cx);
+ }))
+ });
+
+ let delete_and_select_btns = h_flex()
+ .gap_1()
.child(
Button::new("delete-branch", "Delete")
+ .disabled(self.loading)
.key_binding(
KeyBinding::for_action_in(
&branch_picker::DeleteBranch,
@@ -1040,43 +1101,134 @@ impl PickerDelegate for BranchListDelegate {
)
.map(|kb| kb.size(rems_from_px(12.))),
)
- .disabled(self.loading)
.on_click(|_, window, cx| {
window
.dispatch_action(branch_picker::DeleteBranch.boxed_clone(), cx);
}),
)
- .when(self.loading, |this| this.child(self.loader()))
- .into_any(),
- ),
+ .child(
+ Button::new("select_branch", "Select")
+ .key_binding(
+ KeyBinding::for_action_in(&menu::Confirm, &focus_handle, cx)
+ .map(|kb| kb.size(rems_from_px(12.))),
+ )
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.delegate.confirm(false, window, cx);
+ })),
+ );
+
+ Some(
+ footer_container()
+ .map(|this| {
+ if branch_from_default_button.is_some() {
+ this.justify_end().when_some(
+ branch_from_default_button,
+ |this, button| {
+ this.child(button).child(
+ Button::new("create", "Create")
+ .key_binding(
+ KeyBinding::for_action_in(
+ &menu::Confirm,
+ &focus_handle,
+ cx,
+ )
+ .map(|kb| kb.size(rems_from_px(12.))),
+ )
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.delegate.confirm(false, window, cx);
+ })),
+ )
+ },
+ )
+ } else if self.loading {
+ this.justify_between()
+ .child(loading_icon)
+ .child(delete_and_select_btns)
+ } else {
+ this.justify_between()
+ .child({
+ let focus_handle = focus_handle.clone();
+ Button::new("filter-remotes", "Filter Remotes")
+ .disabled(self.loading)
+ .toggle_state(self.display_remotes)
+ .key_binding(
+ KeyBinding::for_action_in(
+ &branch_picker::FilterRemotes,
+ &focus_handle,
+ cx,
+ )
+ .map(|kb| kb.size(rems_from_px(12.))),
+ )
+ .on_click(|_click, window, cx| {
+ window.dispatch_action(
+ branch_picker::FilterRemotes.boxed_clone(),
+ cx,
+ );
+ })
+ })
+ .child(delete_and_select_btns)
+ }
+ })
+ .into_any_element(),
+ )
+ }
+ PickerState::NewBranch => {
+ let branch_from_default_button =
+ self.default_branch.as_ref().map(|default_branch| {
+ let button_label = format!("Create New From: {default_branch}");
+
+ Button::new("branch-from-default", button_label)
+ .key_binding(
+ KeyBinding::for_action_in(
+ &menu::SecondaryConfirm,
+ &focus_handle,
+ cx,
+ )
+ .map(|kb| kb.size(rems_from_px(12.))),
+ )
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.delegate.confirm(true, window, cx);
+ }))
+ });
+
+ Some(
+ footer_container()
+ .gap_1()
+ .justify_end()
+ .when_some(branch_from_default_button, |this, button| {
+ this.child(button)
+ })
+ .child(
+ Button::new("branch-from-default", "Create")
+ .key_binding(
+ KeyBinding::for_action_in(&menu::Confirm, &focus_handle, cx)
+ .map(|kb| kb.size(rems_from_px(12.))),
+ )
+ .on_click(cx.listener(|this, _, window, cx| {
+ this.delegate.confirm(false, window, cx);
+ })),
+ )
+ .into_any_element(),
+ )
+ }
PickerState::CreateRemote(_) => Some(
- h_flex()
- .w_full()
- .p_1p5()
- .gap_1()
- .border_t_1()
- .border_color(cx.theme().colors().border_variant)
+ footer_container()
+ .justify_end()
.child(
Label::new("Choose a name for this remote repository")
.size(LabelSize::Small)
.color(Color::Muted),
)
.child(
- h_flex().w_full().justify_end().child(
- Label::new("Save")
- .size(LabelSize::Small)
- .color(Color::Muted),
- ),
+ Label::new("Save")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
)
- .into_any(),
+ .into_any_element(),
),
- PickerState::NewRemote | PickerState::NewBranch => None,
+ PickerState::NewRemote => None,
}
}
-
- fn no_matches_text(&self, _window: &mut Window, _cx: &mut App) -> Option<SharedString> {
- None
- }
}
#[cfg(test)]
@@ -1506,6 +1658,7 @@ mod tests {
let last_match = picker.delegate.matches.last().unwrap();
assert!(last_match.is_new_branch());
assert_eq!(last_match.name(), "new-feature-branch");
+ // State is NewBranch because no existing branches fuzzy-match the query
assert!(matches!(picker.delegate.state, PickerState::NewBranch));
picker.delegate.confirm(false, window, cx);
})
@@ -1527,10 +1680,14 @@ mod tests {
.unwrap()
.unwrap();
- assert!(
- branches
- .into_iter()
- .any(|branch| branch.name() == "new-feature-branch")
+ let new_branch = branches
+ .into_iter()
+ .find(|branch| branch.name() == "new-feature-branch")
+ .expect("new-feature-branch should exist");
+ assert_eq!(
+ new_branch.ref_name.as_ref(),
+ "refs/heads/new-feature-branch",
+ "branch ref_name should not have duplicate refs/heads/ prefix"
);
}
@@ -1,7 +1,9 @@
use anyhow::{Context as _, Result};
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::display_map::{BlockPlacement, BlockProperties, BlockStyle};
-use editor::{Addon, Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer};
+use editor::{
+ Editor, EditorEvent, ExcerptId, ExcerptRange, MultiBuffer, multibuffer_context_lines,
+};
use git::repository::{CommitDetails, CommitDiff, RepoPath};
use git::{GitHostingProviderRegistry, GitRemote, parse_git_remote_url};
use gpui::{
@@ -10,10 +12,9 @@ use gpui::{
PromptLevel, Render, Styled, Task, WeakEntity, Window, actions,
};
use language::{
- Anchor, Buffer, Capability, DiskState, File, LanguageRegistry, LineEnding, ReplicaId, Rope,
- TextBuffer, ToPoint,
+ Anchor, Buffer, Capability, DiskState, File, LanguageRegistry, LineEnding, OffsetRangeExt as _,
+ ReplicaId, Rope, TextBuffer,
};
-use multi_buffer::ExcerptInfo;
use multi_buffer::PathKey;
use project::{Project, WorktreeId, git_store::Repository};
use std::{
@@ -22,11 +23,9 @@ use std::{
sync::Arc,
};
use theme::ActiveTheme;
-use ui::{
- Avatar, Button, ButtonCommon, Clickable, Color, Icon, IconName, IconSize, Label,
- LabelCommon as _, LabelSize, SharedString, div, h_flex, v_flex,
-};
+use ui::{Avatar, DiffStat, Tooltip, prelude::*};
use util::{ResultExt, paths::PathStyle, rel_path::RelPath, truncate_and_trailoff};
+use workspace::item::TabTooltipContent;
use workspace::{
Item, ItemHandle, ItemNavHistory, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
Workspace,
@@ -151,11 +150,10 @@ impl CommitView {
let editor = cx.new(|cx| {
let mut editor =
Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx);
+
editor.disable_inline_diagnostics();
editor.set_expand_all_diff_hunks(cx);
- editor.register_addon(CommitViewAddon {
- multibuffer: multibuffer.downgrade(),
- });
+
editor
});
let commit_sha = Arc::<str>::from(commit.sha.as_ref());
@@ -206,33 +204,22 @@ impl CommitView {
this.multibuffer.update(cx, |multibuffer, cx| {
let snapshot = buffer.read(cx).snapshot();
let path = snapshot.file().unwrap().path().clone();
-
- let hunks: Vec<_> = buffer_diff.read(cx).hunks(&snapshot, cx).collect();
-
- let excerpt_ranges = if hunks.is_empty() {
- vec![language::Point::zero()..snapshot.max_point()]
- } else {
- hunks
- .into_iter()
- .map(|hunk| {
- let start = hunk.range.start.max(language::Point::new(
- hunk.range.start.row.saturating_sub(3),
- 0,
- ));
- let end_row =
- (hunk.range.end.row + 3).min(snapshot.max_point().row);
- let end =
- language::Point::new(end_row, snapshot.line_len(end_row));
- start..end
- })
- .collect()
+ let excerpt_ranges = {
+ let mut hunks = buffer_diff.read(cx).hunks(&snapshot, cx).peekable();
+ if hunks.peek().is_none() {
+ vec![language::Point::zero()..snapshot.max_point()]
+ } else {
+ hunks
+ .map(|hunk| hunk.buffer_range.to_point(&snapshot))
+ .collect::<Vec<_>>()
+ }
};
let _is_newly_added = multibuffer.set_excerpts_for_path(
PathKey::with_sort_prefix(FILE_NAMESPACE_SORT_PREFIX, path),
buffer,
excerpt_ranges,
- 0,
+ multibuffer_context_lines(cx),
cx,
);
multibuffer.add_diff(buffer_diff, cx);
@@ -262,6 +249,8 @@ impl CommitView {
this.editor.update(cx, |editor, cx| {
editor.disable_header_for_buffer(message_buffer.read(cx).remote_id(), cx);
+ editor
+ .disable_indent_guides_for_buffer(message_buffer.read(cx).remote_id(), cx);
editor.insert_blocks(
[BlockProperties {
@@ -357,6 +346,41 @@ impl CommitView {
.into_any()
}
+ fn calculate_changed_lines(&self, cx: &App) -> (u32, u32) {
+ let snapshot = self.multibuffer.read(cx).snapshot(cx);
+ let mut total_additions = 0u32;
+ let mut total_deletions = 0u32;
+
+ let mut seen_buffers = std::collections::HashSet::new();
+ for (_, buffer, _) in snapshot.excerpts() {
+ let buffer_id = buffer.remote_id();
+ if !seen_buffers.insert(buffer_id) {
+ continue;
+ }
+
+ let Some(diff) = snapshot.diff_for_buffer_id(buffer_id) else {
+ continue;
+ };
+
+ let base_text = diff.base_text();
+
+ for hunk in diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer) {
+ let added_rows = hunk.range.end.row.saturating_sub(hunk.range.start.row);
+ total_additions += added_rows;
+
+ let base_start = base_text
+ .offset_to_point(hunk.diff_base_byte_range.start)
+ .row;
+ let base_end = base_text.offset_to_point(hunk.diff_base_byte_range.end).row;
+ let deleted_rows = base_end.saturating_sub(base_start);
+
+ total_deletions += deleted_rows;
+ }
+ }
+
+ (total_additions, total_deletions)
+ }
+
fn render_header(&self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let commit = &self.commit;
let author_name = commit.author_name.clone();
@@ -380,46 +404,72 @@ impl CommitView {
)
});
- v_flex()
- .p_4()
- .pl_0()
- .gap_4()
+ let (additions, deletions) = self.calculate_changed_lines(cx);
+
+ let commit_diff_stat = if additions > 0 || deletions > 0 {
+ Some(DiffStat::new(
+ "commit-diff-stat",
+ additions as usize,
+ deletions as usize,
+ ))
+ } else {
+ None
+ };
+
+ h_flex()
.border_b_1()
- .border_color(cx.theme().colors().border)
+ .border_color(cx.theme().colors().border_variant)
+ .child(
+ h_flex()
+ .w(self.editor.read(cx).last_gutter_dimensions().full_width())
+ .justify_center()
+ .child(self.render_commit_avatar(&commit.sha, rems_from_px(48.), window, cx)),
+ )
.child(
h_flex()
+ .py_4()
+ .pl_1()
+ .pr_4()
+ .w_full()
.items_start()
- .child(
- h_flex()
- .w(self.editor.read(cx).last_gutter_dimensions().full_width())
- .justify_center()
- .child(self.render_commit_avatar(
- &commit.sha,
- gpui::rems(3.0),
- window,
- cx,
- )),
- )
+ .justify_between()
+ .flex_wrap()
.child(
v_flex()
- .gap_1()
.child(
h_flex()
- .gap_3()
- .items_baseline()
+ .gap_1()
.child(Label::new(author_name).color(Color::Default))
.child(
- Label::new(format!("commit {}", commit.sha))
- .color(Color::Muted),
+ Label::new(format!("Commit:{}", commit.sha))
+ .color(Color::Muted)
+ .size(LabelSize::Small)
+ .truncate()
+ .buffer_font(cx),
),
)
- .child(Label::new(date_string).color(Color::Muted)),
+ .child(
+ h_flex()
+ .gap_1p5()
+ .child(
+ Label::new(date_string)
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ )
+ .child(
+ Label::new("โข")
+ .color(Color::Ignored)
+ .size(LabelSize::Small),
+ )
+ .children(commit_diff_stat),
+ ),
)
- .child(div().flex_grow())
.children(github_url.map(|url| {
Button::new("view_on_github", "View on GitHub")
.icon(IconName::Github)
- .style(ui::ButtonStyle::Subtle)
+ .icon_color(Color::Muted)
+ .icon_size(IconSize::Small)
+ .icon_position(IconPosition::Start)
.on_click(move |_, _, cx| cx.open_url(&url))
})),
)
@@ -714,55 +764,6 @@ impl language::File for GitBlob {
// }
// }
-struct CommitViewAddon {
- multibuffer: WeakEntity<MultiBuffer>,
-}
-
-impl Addon for CommitViewAddon {
- fn render_buffer_header_controls(
- &self,
- excerpt: &ExcerptInfo,
- _window: &Window,
- cx: &App,
- ) -> Option<AnyElement> {
- let multibuffer = self.multibuffer.upgrade()?;
- let snapshot = multibuffer.read(cx).snapshot(cx);
- let excerpts = snapshot.excerpts().collect::<Vec<_>>();
- let current_idx = excerpts.iter().position(|(id, _, _)| *id == excerpt.id)?;
- let (_, _, current_range) = &excerpts[current_idx];
-
- let start_row = current_range.context.start.to_point(&excerpt.buffer).row;
-
- let prev_end_row = if current_idx > 0 {
- let (_, prev_buffer, prev_range) = &excerpts[current_idx - 1];
- if prev_buffer.remote_id() == excerpt.buffer_id {
- prev_range.context.end.to_point(&excerpt.buffer).row
- } else {
- 0
- }
- } else {
- 0
- };
-
- let skipped_lines = start_row.saturating_sub(prev_end_row);
-
- if skipped_lines > 0 {
- Some(
- Label::new(format!("{} unchanged lines", skipped_lines))
- .color(Color::Muted)
- .size(LabelSize::Small)
- .into_any_element(),
- )
- } else {
- None
- }
- }
-
- fn to_any(&self) -> &dyn Any {
- self
- }
-}
-
async fn build_buffer(
mut text: String,
blob: Arc<dyn File>,
@@ -865,13 +866,28 @@ impl Item for CommitView {
fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
let short_sha = self.commit.sha.get(0..7).unwrap_or(&*self.commit.sha);
let subject = truncate_and_trailoff(self.commit.message.split('\n').next().unwrap(), 20);
- format!("{short_sha} - {subject}").into()
+ format!("{short_sha} โ {subject}").into()
}
- fn tab_tooltip_text(&self, _: &App) -> Option<ui::SharedString> {
+ fn tab_tooltip_content(&self, _: &App) -> Option<TabTooltipContent> {
let short_sha = self.commit.sha.get(0..16).unwrap_or(&*self.commit.sha);
let subject = self.commit.message.split('\n').next().unwrap();
- Some(format!("{short_sha} - {subject}").into())
+
+ Some(TabTooltipContent::Custom(Box::new(Tooltip::element({
+ let subject = subject.to_string();
+ let short_sha = short_sha.to_string();
+
+ move |_, _| {
+ v_flex()
+ .child(Label::new(subject.clone()))
+ .child(
+ Label::new(short_sha.clone())
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ )
+ .into_any_element()
+ }
+ }))))
}
fn to_item_events(event: &EditorEvent, f: impl FnMut(ItemEvent)) {
@@ -988,12 +1004,11 @@ impl Item for CommitView {
impl Render for CommitView {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_stash = self.stash.is_some();
- div()
+
+ v_flex()
.key_context(if is_stash { "StashDiff" } else { "CommitDiff" })
- .bg(cx.theme().colors().editor_background)
- .flex()
- .flex_col()
.size_full()
+ .bg(cx.theme().colors().editor_background)
.child(self.render_header(window, cx))
.child(div().flex_grow().child(self.editor.clone()))
}
@@ -1013,7 +1028,7 @@ impl EventEmitter<ToolbarItemEvent> for CommitViewToolbar {}
impl Render for CommitViewToolbar {
fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
- div()
+ div().hidden()
}
}
@@ -372,7 +372,7 @@ fn render_conflict_buttons(
.gap_1()
.bg(cx.theme().colors().editor_background)
.child(
- Button::new("head", "Use HEAD")
+ Button::new("head", format!("Use {}", conflict.ours_branch_name))
.label_size(LabelSize::Small)
.on_click({
let editor = editor.clone();
@@ -392,7 +392,7 @@ fn render_conflict_buttons(
}),
)
.child(
- Button::new("origin", "Use Origin")
+ Button::new("origin", format!("Use {}", conflict.theirs_branch_name))
.label_size(LabelSize::Small)
.on_click({
let editor = editor.clone();
@@ -267,15 +267,19 @@ impl FileHistoryView {
.child(self.render_commit_avatar(&entry.sha, window, cx))
.child(
h_flex()
+ .min_w_0()
.w_full()
.justify_between()
.child(
h_flex()
+ .min_w_0()
+ .w_full()
.gap_1()
.child(
Label::new(entry.author_name.clone())
.size(LabelSize::Small)
- .color(Color::Default),
+ .color(Color::Default)
+ .truncate(),
)
.child(
Label::new(&entry.subject)
@@ -285,9 +289,11 @@ impl FileHistoryView {
),
)
.child(
- Label::new(relative_timestamp)
- .size(LabelSize::Small)
- .color(Color::Muted),
+ h_flex().flex_none().child(
+ Label::new(relative_timestamp)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
),
),
)
@@ -34,7 +34,6 @@ use project::{
use settings::{Settings, SettingsStore};
use smol::future::yield_now;
use std::any::{Any, TypeId};
-use std::ops::Range;
use std::sync::Arc;
use theme::ActiveTheme;
use ui::{KeyBinding, Tooltip, prelude::*, vertical_divider};
@@ -46,6 +45,7 @@ use workspace::{
notifications::NotifyTaskExt,
searchable::SearchableItemHandle,
};
+use ztracing::instrument;
actions!(
git,
@@ -469,6 +469,7 @@ impl ProjectDiff {
}
}
+ #[instrument(skip_all)]
fn register_buffer(
&mut self,
path_key: PathKey,
@@ -498,23 +499,30 @@ impl ProjectDiff {
let snapshot = buffer.read(cx).snapshot();
let diff_read = diff.read(cx);
- let diff_hunk_ranges = diff_read
- .hunks_intersecting_range(
- Anchor::min_max_range_for_buffer(diff_read.buffer_id),
- &snapshot,
- cx,
- )
- .map(|diff_hunk| diff_hunk.buffer_range);
- let conflicts = conflict_addon
- .conflict_set(snapshot.remote_id())
- .map(|conflict_set| conflict_set.read(cx).snapshot().conflicts)
- .unwrap_or_default();
- let conflicts = conflicts.iter().map(|conflict| conflict.range.clone());
-
- let excerpt_ranges =
- merge_anchor_ranges(diff_hunk_ranges.into_iter(), conflicts, &snapshot)
- .map(|range| range.to_point(&snapshot))
- .collect::<Vec<_>>();
+
+ let excerpt_ranges = {
+ let diff_hunk_ranges = diff_read
+ .hunks_intersecting_range(
+ Anchor::min_max_range_for_buffer(diff_read.buffer_id),
+ &snapshot,
+ cx,
+ )
+ .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot));
+ let conflicts = conflict_addon
+ .conflict_set(snapshot.remote_id())
+ .map(|conflict_set| conflict_set.read(cx).snapshot().conflicts)
+ .unwrap_or_default();
+ let mut conflicts = conflicts
+ .iter()
+ .map(|conflict| conflict.range.to_point(&snapshot))
+ .peekable();
+
+ if conflicts.peek().is_some() {
+ conflicts.collect::<Vec<_>>()
+ } else {
+ diff_hunk_ranges.collect()
+ }
+ };
let (was_empty, is_excerpt_newly_added) = self.multibuffer.update(cx, |multibuffer, cx| {
let was_empty = multibuffer.is_empty();
@@ -1542,53 +1550,6 @@ mod preview {
}
}
-fn merge_anchor_ranges<'a>(
- left: impl 'a + Iterator<Item = Range<Anchor>>,
- right: impl 'a + Iterator<Item = Range<Anchor>>,
- snapshot: &'a language::BufferSnapshot,
-) -> impl 'a + Iterator<Item = Range<Anchor>> {
- let mut left = left.fuse().peekable();
- let mut right = right.fuse().peekable();
-
- std::iter::from_fn(move || {
- let Some(left_range) = left.peek() else {
- return right.next();
- };
- let Some(right_range) = right.peek() else {
- return left.next();
- };
-
- let mut next_range = if left_range.start.cmp(&right_range.start, snapshot).is_lt() {
- left.next().unwrap()
- } else {
- right.next().unwrap()
- };
-
- // Extend the basic range while there's overlap with a range from either stream.
- loop {
- if let Some(left_range) = left
- .peek()
- .filter(|range| range.start.cmp(&next_range.end, snapshot).is_le())
- .cloned()
- {
- left.next();
- next_range.end = left_range.end;
- } else if let Some(right_range) = right
- .peek()
- .filter(|range| range.start.cmp(&next_range.end, snapshot).is_le())
- .cloned()
- {
- right.next();
- next_range.end = right_range.end;
- } else {
- break;
- }
- }
-
- Some(next_range)
- })
-}
-
struct BranchDiffAddon {
branch_diff: Entity<branch_diff::BranchDiff>,
}
@@ -584,7 +584,33 @@ impl AnyWeakEntity {
})
}
- /// Assert that entity referenced by this weak handle has been released.
+ /// Asserts that the entity referenced by this weak handle has been fully released.
+ ///
+ /// # Example
+ ///
+ /// ```ignore
+ /// let entity = cx.new(|_| MyEntity::new());
+ /// let weak = entity.downgrade();
+ /// drop(entity);
+ ///
+ /// // Verify the entity was released
+ /// weak.assert_released();
+ /// ```
+ ///
+ /// # Debugging Leaks
+ ///
+ /// If this method panics due to leaked handles, set the `LEAK_BACKTRACE` environment
+ /// variable to see where the leaked handles were allocated:
+ ///
+ /// ```bash
+ /// LEAK_BACKTRACE=1 cargo test my_test
+ /// ```
+ ///
+ /// # Panics
+ ///
+ /// - Panics if any strong handles to the entity are still alive.
+ /// - Panics if the entity was recently dropped but cleanup hasn't completed yet
+ /// (resources are retained until the end of the effect cycle).
#[cfg(any(test, feature = "leak-detection"))]
pub fn assert_released(&self) {
self.entity_ref_counts
@@ -814,16 +840,70 @@ impl<T: 'static> PartialOrd for WeakEntity<T> {
}
}
+/// Controls whether backtraces are captured when entity handles are created.
+///
+/// Set the `LEAK_BACKTRACE` environment variable to any non-empty value to enable
+/// backtrace capture. This helps identify where leaked handles were allocated.
#[cfg(any(test, feature = "leak-detection"))]
static LEAK_BACKTRACE: std::sync::LazyLock<bool> =
std::sync::LazyLock::new(|| std::env::var("LEAK_BACKTRACE").is_ok_and(|b| !b.is_empty()));
+/// Unique identifier for a specific entity handle instance.
+///
+/// This is distinct from `EntityId` - while multiple handles can point to the same
+/// entity (same `EntityId`), each handle has its own unique `HandleId`.
#[cfg(any(test, feature = "leak-detection"))]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub(crate) struct HandleId {
- id: u64, // id of the handle itself, not the pointed at object
+ id: u64,
}
+/// Tracks entity handle allocations to detect leaks.
+///
+/// The leak detector is enabled in tests and when the `leak-detection` feature is active.
+/// It tracks every `Entity<T>` and `AnyEntity` handle that is created and released,
+/// allowing you to verify that all handles to an entity have been properly dropped.
+///
+/// # How do leaks happen?
+///
+/// Entities are reference-counted structures that can own other entities
+/// allowing to form cycles. If such a strong-reference counted cycle is
+/// created, all participating strong entities in this cycle will effectively
+/// leak as they cannot be released anymore.
+///
+/// # Usage
+///
+/// You can use `WeakEntity::assert_released` or `AnyWeakEntity::assert_released`
+/// to verify that an entity has been fully released:
+///
+/// ```ignore
+/// let entity = cx.new(|_| MyEntity::new());
+/// let weak = entity.downgrade();
+/// drop(entity);
+///
+/// // This will panic if any handles to the entity are still alive
+/// weak.assert_released();
+/// ```
+///
+/// # Debugging Leaks
+///
+/// When a leak is detected, the detector will panic with information about the leaked
+/// handles. To see where the leaked handles were allocated, set the `LEAK_BACKTRACE`
+/// environment variable:
+///
+/// ```bash
+/// LEAK_BACKTRACE=1 cargo test my_test
+/// ```
+///
+/// This will capture and display backtraces for each leaked handle, helping you
+/// identify where handles were created but not released.
+///
+/// # How It Works
+///
+/// - When an entity handle is created (via `Entity::new`, `Entity::clone`, or
+/// `WeakEntity::upgrade`), `handle_created` is called to register the handle.
+/// - When a handle is dropped, `handle_released` removes it from tracking.
+/// - `assert_released` verifies that no handles remain for a given entity.
#[cfg(any(test, feature = "leak-detection"))]
pub(crate) struct LeakDetector {
next_handle_id: u64,
@@ -832,6 +912,11 @@ pub(crate) struct LeakDetector {
#[cfg(any(test, feature = "leak-detection"))]
impl LeakDetector {
+ /// Records that a new handle has been created for the given entity.
+ ///
+ /// Returns a unique `HandleId` that must be passed to `handle_released` when
+ /// the handle is dropped. If `LEAK_BACKTRACE` is set, captures a backtrace
+ /// at the allocation site.
#[track_caller]
pub fn handle_created(&mut self, entity_id: EntityId) -> HandleId {
let id = util::post_inc(&mut self.next_handle_id);
@@ -844,23 +929,40 @@ impl LeakDetector {
handle_id
}
+ /// Records that a handle has been released (dropped).
+ ///
+ /// This removes the handle from tracking. The `handle_id` should be the same
+ /// one returned by `handle_created` when the handle was allocated.
pub fn handle_released(&mut self, entity_id: EntityId, handle_id: HandleId) {
let handles = self.entity_handles.entry(entity_id).or_default();
handles.remove(&handle_id);
}
+ /// Asserts that all handles to the given entity have been released.
+ ///
+ /// # Panics
+ ///
+ /// Panics if any handles to the entity are still alive. The panic message
+ /// includes backtraces for each leaked handle if `LEAK_BACKTRACE` is set,
+ /// otherwise it suggests setting the environment variable to get more info.
pub fn assert_released(&mut self, entity_id: EntityId) {
+ use std::fmt::Write as _;
let handles = self.entity_handles.entry(entity_id).or_default();
if !handles.is_empty() {
+ let mut out = String::new();
for backtrace in handles.values_mut() {
if let Some(mut backtrace) = backtrace.take() {
backtrace.resolve();
- eprintln!("Leaked handle: {:#?}", backtrace);
+ writeln!(out, "Leaked handle:\n{:?}", backtrace).unwrap();
} else {
- eprintln!("Leaked handle: export LEAK_BACKTRACE to find allocation site");
+ writeln!(
+ out,
+ "Leaked handle: (export LEAK_BACKTRACE to find allocation site)"
+ )
+ .unwrap();
}
}
- panic!();
+ panic!("{out}");
}
}
}
@@ -3567,7 +3567,7 @@ pub const fn relative(fraction: f32) -> DefiniteLength {
}
/// Returns the Golden Ratio, i.e. `~(1.0 + sqrt(5.0)) / 2.0`.
-pub fn phi() -> DefiniteLength {
+pub const fn phi() -> DefiniteLength {
relative(1.618_034)
}
@@ -3580,7 +3580,7 @@ pub fn phi() -> DefiniteLength {
/// # Returns
///
/// A `Rems` representing the specified number of rems.
-pub fn rems(rems: f32) -> Rems {
+pub const fn rems(rems: f32) -> Rems {
Rems(rems)
}
@@ -3608,7 +3608,7 @@ pub const fn px(pixels: f32) -> Pixels {
/// # Returns
///
/// A `Length` variant set to `Auto`.
-pub fn auto() -> Length {
+pub const fn auto() -> Length {
Length::Auto
}
@@ -8,7 +8,6 @@ use std::{fmt::Debug, ops::Range};
use taffy::{
TaffyTree, TraversePartialTree as _,
geometry::{Point as TaffyPoint, Rect as TaffyRect, Size as TaffySize},
- prelude::min_content,
style::AvailableSpace as TaffyAvailableSpace,
tree::NodeId,
};
@@ -296,7 +295,7 @@ trait ToTaffy<Output> {
impl ToTaffy<taffy::style::Style> for Style {
fn to_taffy(&self, rem_size: Pixels, scale_factor: f32) -> taffy::style::Style {
- use taffy::style_helpers::{length, minmax, repeat};
+ use taffy::style_helpers::{fr, length, minmax, repeat};
fn to_grid_line(
placement: &Range<crate::GridPlacement>,
@@ -310,8 +309,8 @@ impl ToTaffy<taffy::style::Style> for Style {
fn to_grid_repeat<T: taffy::style::CheapCloneStr>(
unit: &Option<u16>,
) -> Vec<taffy::GridTemplateComponent<T>> {
- // grid-template-columns: repeat(<number>, minmax(0, min-content));
- unit.map(|count| vec![repeat(count, vec![minmax(length(0.0), min_content())])])
+ // grid-template-columns: repeat(<number>, minmax(0, 1fr));
+ unit.map(|count| vec![repeat(count, vec![minmax(length(0.0), fr(1.0))])])
.unwrap_or_default()
}
@@ -33,8 +33,8 @@ pub enum IconName {
ArrowRightLeft,
ArrowUp,
ArrowUpRight,
- Attach,
AtSign,
+ Attach,
AudioOff,
AudioOn,
Backspace,
@@ -44,8 +44,8 @@ pub enum IconName {
BellRing,
Binary,
Blocks,
- BoltOutlined,
BoltFilled,
+ BoltOutlined,
Book,
BookCopy,
CaseSensitive,
@@ -79,9 +79,9 @@ pub enum IconName {
Debug,
DebugBreakpoint,
DebugContinue,
+ DebugDetach,
DebugDisabledBreakpoint,
DebugDisabledLogBreakpoint,
- DebugDetach,
DebugIgnoreBreakpoints,
DebugLogBreakpoint,
DebugPause,
@@ -135,10 +135,12 @@ pub enum IconName {
GenericRestore,
GitBranch,
GitBranchAlt,
+ GitBranchPlus,
Github,
Hash,
HistoryRerun,
Image,
+ Inception,
Indicator,
Info,
Json,
@@ -146,6 +148,7 @@ pub enum IconName {
Library,
LineHeight,
Link,
+ Linux,
ListCollapse,
ListFilter,
ListTodo,
@@ -171,8 +174,8 @@ pub enum IconName {
PencilUnavailable,
Person,
Pin,
- PlayOutlined,
PlayFilled,
+ PlayOutlined,
Plus,
Power,
Public,
@@ -258,15 +261,14 @@ pub enum IconName {
ZedAssistant,
ZedBurnMode,
ZedBurnModeOn,
- ZedSrcCustom,
- ZedSrcExtension,
ZedPredict,
ZedPredictDisabled,
ZedPredictDown,
ZedPredictError,
ZedPredictUp,
+ ZedSrcCustom,
+ ZedSrcExtension,
ZedXCopilot,
- Linux,
}
impl IconName {
@@ -3,8 +3,9 @@ use std::{str::FromStr, sync::Arc};
use anyhow::{Context as _, Result};
use gpui::{App, AsyncApp, BorrowAppContext as _, Entity, WeakEntity};
-use language::LanguageRegistry;
+use language::{LanguageRegistry, language_settings::all_language_settings};
use project::LspStore;
+use util::schemars::{AllowTrailingCommas, DefaultDenyUnknownFields};
// Origin: https://github.com/SchemaStore/schemastore
const TSCONFIG_SCHEMA: &str = include_str!("schemas/tsconfig.json");
@@ -159,14 +160,35 @@ pub fn resolve_schema_request_inner(
}
}
"snippets" => snippet_provider::format::VsSnippetsFile::generate_json_schema(),
+ "jsonc" => jsonc_schema(),
_ => {
- anyhow::bail!("Unrecognized builtin JSON schema: {}", schema_name);
+ anyhow::bail!("Unrecognized builtin JSON schema: {schema_name}");
}
};
Ok(schema)
}
-pub fn all_schema_file_associations(cx: &mut App) -> serde_json::Value {
+const JSONC_LANGUAGE_NAME: &str = "JSONC";
+
+pub fn all_schema_file_associations(
+ languages: &Arc<LanguageRegistry>,
+ cx: &mut App,
+) -> serde_json::Value {
+ let extension_globs = languages
+ .available_language_for_name(JSONC_LANGUAGE_NAME)
+ .map(|language| language.matcher().path_suffixes.clone())
+ .into_iter()
+ .flatten()
+ // Path suffixes can be entire file names or just their extensions.
+ .flat_map(|path_suffix| [format!("*.{path_suffix}"), path_suffix]);
+ let override_globs = all_language_settings(None, cx)
+ .file_types
+ .get(JSONC_LANGUAGE_NAME)
+ .into_iter()
+ .flat_map(|(_, glob_strings)| glob_strings)
+ .cloned();
+ let jsonc_globs = extension_globs.chain(override_globs).collect::<Vec<_>>();
+
let mut file_associations = serde_json::json!([
{
"fileMatch": [
@@ -211,6 +233,10 @@ pub fn all_schema_file_associations(cx: &mut App) -> serde_json::Value {
"fileMatch": ["package.json"],
"url": "zed://schemas/package_json"
},
+ {
+ "fileMatch": &jsonc_globs,
+ "url": "zed://schemas/jsonc"
+ },
]);
#[cfg(debug_assertions)]
@@ -233,7 +259,7 @@ pub fn all_schema_file_associations(cx: &mut App) -> serde_json::Value {
let file_name = normalized_action_name_to_file_name(normalized_name.clone());
serde_json::json!({
"fileMatch": [file_name],
- "url": format!("zed://schemas/action/{}", normalized_name)
+ "url": format!("zed://schemas/action/{normalized_name}")
})
}),
);
@@ -249,6 +275,26 @@ fn package_json_schema() -> serde_json::Value {
serde_json::Value::from_str(PACKAGE_JSON_SCHEMA).unwrap()
}
+fn jsonc_schema() -> serde_json::Value {
+ let generator = schemars::generate::SchemaSettings::draft2019_09()
+ .with_transform(DefaultDenyUnknownFields)
+ .with_transform(AllowTrailingCommas)
+ .into_generator();
+ let meta_schema = generator
+ .settings()
+ .meta_schema
+ .as_ref()
+ .expect("meta_schema should be present in schemars settings")
+ .to_string();
+ let defs = generator.definitions();
+ let schema = schemars::json_schema!({
+ "$schema": meta_schema,
+ "allowTrailingCommas": true,
+ "$defs": defs,
+ });
+ serde_json::to_value(schema).unwrap()
+}
+
fn generate_inspector_style_schema() -> serde_json::Value {
let schema = schemars::generate::SchemaSettings::draft2019_09()
.with_transform(util::schemars::DefaultDenyUnknownFields)
@@ -4022,6 +4022,20 @@ impl BufferSnapshot {
})
}
+ pub fn outline_items_as_offsets_containing<T: ToOffset>(
+ &self,
+ range: Range<T>,
+ include_extra_context: bool,
+ theme: Option<&SyntaxTheme>,
+ ) -> Vec<OutlineItem<usize>> {
+ self.outline_items_containing_internal(
+ range,
+ include_extra_context,
+ theme,
+ |buffer, range| range.to_offset(buffer),
+ )
+ }
+
fn outline_items_containing_internal<T: ToOffset, U>(
&self,
range: Range<T>,
@@ -6,6 +6,7 @@ use futures::FutureExt as _;
use gpui::{App, AppContext as _, BorrowAppContext, Entity};
use gpui::{HighlightStyle, TestAppContext};
use indoc::indoc;
+use pretty_assertions::assert_eq;
use proto::deserialize_operation;
use rand::prelude::*;
use regex::RegexBuilder;
@@ -46,8 +47,7 @@ fn test_line_endings(cx: &mut gpui::App) {
init_settings(cx, |_| {});
cx.new(|cx| {
- let mut buffer =
- Buffer::local("one\r\ntwo\rthree", cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local("one\r\ntwo\rthree", cx).with_language(rust_lang(), cx);
assert_eq!(buffer.text(), "one\ntwo\nthree");
assert_eq!(buffer.line_ending(), LineEnding::Windows);
@@ -151,7 +151,7 @@ fn test_select_language(cx: &mut App) {
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
registry.add(Arc::new(Language::new(
LanguageConfig {
- name: LanguageName::new("Rust"),
+ name: LanguageName::new_static("Rust"),
matcher: LanguageMatcher {
path_suffixes: vec!["rs".to_string()],
..Default::default()
@@ -173,7 +173,7 @@ fn test_select_language(cx: &mut App) {
)));
registry.add(Arc::new(Language::new(
LanguageConfig {
- name: LanguageName::new("Make"),
+ name: LanguageName::new_static("Make"),
matcher: LanguageMatcher {
path_suffixes: vec!["Makefile".to_string(), "mk".to_string()],
..Default::default()
@@ -608,7 +608,7 @@ async fn test_normalize_whitespace(cx: &mut gpui::TestAppContext) {
#[gpui::test]
async fn test_reparse(cx: &mut gpui::TestAppContext) {
let text = "fn a() {}";
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
// Wait for the initial text to parse
cx.executor().run_until_parked();
@@ -735,7 +735,7 @@ async fn test_reparse(cx: &mut gpui::TestAppContext) {
#[gpui::test]
async fn test_resetting_language(cx: &mut gpui::TestAppContext) {
let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local("{}", cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local("{}", cx).with_language(rust_lang(), cx);
buffer.set_sync_parse_timeout(Duration::ZERO);
buffer
});
@@ -783,29 +783,49 @@ async fn test_outline(cx: &mut gpui::TestAppContext) {
"#
.unindent();
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
- let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
+ let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
+ let outline = snapshot.outline(None);
assert_eq!(
outline
.items
.iter()
- .map(|item| (item.text.as_str(), item.depth))
+ .map(|item| (
+ item.text.as_str(),
+ item.depth,
+ item.to_point(&snapshot).body_range(&snapshot)
+ .map(|range| minimize_space(&snapshot.text_for_range(range).collect::<String>()))
+ ))
.collect::<Vec<_>>(),
&[
- ("struct Person", 0),
- ("name", 1),
- ("age", 1),
- ("mod module", 0),
- ("enum LoginState", 1),
- ("LoggedOut", 2),
- ("LoggingOn", 2),
- ("LoggedIn", 2),
- ("person", 3),
- ("time", 3),
- ("impl Eq for Person", 0),
- ("impl Drop for Person", 0),
- ("fn drop", 1),
+ ("struct Person", 0, Some("name: String, age: usize,".to_string())),
+ ("name", 1, None),
+ ("age", 1, None),
+ (
+ "mod module",
+ 0,
+ Some(
+ "enum LoginState { LoggedOut, LoggingOn, LoggedIn { person: Person, time: Instant, } }".to_string()
+ )
+ ),
+ (
+ "enum LoginState",
+ 1,
+ Some("LoggedOut, LoggingOn, LoggedIn { person: Person, time: Instant, }".to_string())
+ ),
+ ("LoggedOut", 2, None),
+ ("LoggingOn", 2, None),
+ ("LoggedIn", 2, Some("person: Person, time: Instant,".to_string())),
+ ("person", 3, None),
+ ("time", 3, None),
+ ("impl Eq for Person", 0, Some("".to_string())),
+ (
+ "impl Drop for Person",
+ 0,
+ Some("fn drop(&mut self) { println!(\"bye\"); }".to_string())
+ ),
+ ("fn drop", 1, Some("println!(\"bye\");".to_string())),
]
);
@@ -840,6 +860,11 @@ async fn test_outline(cx: &mut gpui::TestAppContext) {
]
);
+ fn minimize_space(text: &str) -> String {
+ static WHITESPACE: LazyLock<Regex> = LazyLock::new(|| Regex::new("[\\n\\s]+").unwrap());
+ WHITESPACE.replace_all(text, " ").trim().to_string()
+ }
+
async fn search<'a>(
outline: &'a Outline<Anchor>,
query: &'a str,
@@ -865,7 +890,7 @@ async fn test_outline_nodes_with_newlines(cx: &mut gpui::TestAppContext) {
"#
.unindent();
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None));
assert_eq!(
@@ -945,7 +970,7 @@ fn test_outline_annotations(cx: &mut App) {
"#
.unindent();
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None));
assert_eq!(
@@ -993,7 +1018,7 @@ async fn test_symbols_containing(cx: &mut gpui::TestAppContext) {
"#
.unindent();
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
// point is at the start of an item
@@ -1068,7 +1093,7 @@ async fn test_symbols_containing(cx: &mut gpui::TestAppContext) {
"
.unindent(),
);
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
// note, it would be nice to actually return the method test in this
@@ -1087,8 +1112,7 @@ fn test_text_objects(cx: &mut App) {
false,
);
- let buffer =
- cx.new(|cx| Buffer::local(text.clone(), cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text.clone(), cx).with_language(rust_lang(), cx));
let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
let matches = snapshot
@@ -1105,6 +1129,14 @@ fn test_text_objects(cx: &mut App) {
"fn say() -> u8 { return /* hi */ 1 }",
TextObject::AroundFunction
),
+ (
+ "fn say() -> u8 { return /* hi */ 1 }",
+ TextObject::InsideClass
+ ),
+ (
+ "impl Hello {\n fn say() -> u8 { return /* hi */ 1 }\n}",
+ TextObject::AroundClass
+ ),
],
)
}
@@ -1235,7 +1267,12 @@ fn test_enclosing_bracket_ranges(cx: &mut App) {
#[gpui::test]
fn test_enclosing_bracket_ranges_where_brackets_are_not_outermost_children(cx: &mut App) {
let mut assert = |selection_text, bracket_pair_texts| {
- assert_bracket_pairs(selection_text, bracket_pair_texts, javascript_lang(), cx)
+ assert_bracket_pairs(
+ selection_text,
+ bracket_pair_texts,
+ Arc::new(javascript_lang()),
+ cx,
+ )
};
assert(
@@ -1268,7 +1305,7 @@ fn test_enclosing_bracket_ranges_where_brackets_are_not_outermost_children(cx: &
fn test_range_for_syntax_ancestor(cx: &mut App) {
cx.new(|cx| {
let text = "fn a() { b(|c| {}) }";
- let buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx);
+ let buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
let snapshot = buffer.snapshot();
assert_eq!(
@@ -1320,7 +1357,7 @@ fn test_autoindent_with_soft_tabs(cx: &mut App) {
cx.new(|cx| {
let text = "fn a() {}";
- let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
buffer.edit([(8..8, "\n\n")], Some(AutoindentMode::EachLine), cx);
assert_eq!(buffer.text(), "fn a() {\n \n}");
@@ -1362,7 +1399,7 @@ fn test_autoindent_with_hard_tabs(cx: &mut App) {
cx.new(|cx| {
let text = "fn a() {}";
- let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
buffer.edit([(8..8, "\n\n")], Some(AutoindentMode::EachLine), cx);
assert_eq!(buffer.text(), "fn a() {\n\t\n}");
@@ -1411,7 +1448,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut App)
.unindent(),
cx,
)
- .with_language(Arc::new(rust_lang()), cx);
+ .with_language(rust_lang(), cx);
// Lines 2 and 3 don't match the indentation suggestion. When editing these lines,
// their indentation is not adjusted.
@@ -1552,7 +1589,7 @@ fn test_autoindent_does_not_adjust_lines_with_unchanged_suggestion(cx: &mut App)
.unindent(),
cx,
)
- .with_language(Arc::new(rust_lang()), cx);
+ .with_language(rust_lang(), cx);
// Insert a closing brace. It is outdented.
buffer.edit_via_marked_text(
@@ -1615,7 +1652,7 @@ fn test_autoindent_does_not_adjust_lines_within_newly_created_errors(cx: &mut Ap
.unindent(),
cx,
)
- .with_language(Arc::new(rust_lang()), cx);
+ .with_language(rust_lang(), cx);
// Regression test: line does not get outdented due to syntax error
buffer.edit_via_marked_text(
@@ -1674,7 +1711,7 @@ fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut App) {
.unindent(),
cx,
)
- .with_language(Arc::new(rust_lang()), cx);
+ .with_language(rust_lang(), cx);
buffer.edit_via_marked_text(
&"
@@ -1724,7 +1761,7 @@ fn test_autoindent_with_edit_at_end_of_buffer(cx: &mut App) {
cx.new(|cx| {
let text = "a\nb";
- let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
buffer.edit(
[(0..1, "\n"), (2..3, "\n")],
Some(AutoindentMode::EachLine),
@@ -1750,7 +1787,7 @@ fn test_autoindent_multi_line_insertion(cx: &mut App) {
"
.unindent();
- let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
buffer.edit(
[(Point::new(3, 0)..Point::new(3, 0), "e(\n f()\n);\n")],
Some(AutoindentMode::EachLine),
@@ -1787,7 +1824,7 @@ fn test_autoindent_block_mode(cx: &mut App) {
}
"#
.unindent();
- let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
// When this text was copied, both of the quotation marks were at the same
// indent level, but the indentation of the first line was not included in
@@ -1870,7 +1907,7 @@ fn test_autoindent_block_mode_with_newline(cx: &mut App) {
}
"#
.unindent();
- let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
// First line contains just '\n', it's indentation is stored in "original_indent_columns"
let original_indent_columns = vec![Some(4)];
@@ -1922,7 +1959,7 @@ fn test_autoindent_block_mode_without_original_indent_columns(cx: &mut App) {
}
"#
.unindent();
- let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
// The original indent columns are not known, so this text is
// auto-indented in a block as if the first line was copied in
@@ -2013,7 +2050,7 @@ fn test_autoindent_block_mode_multiple_adjacent_ranges(cx: &mut App) {
false,
);
- let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
buffer.edit(
[
@@ -2027,7 +2064,7 @@ fn test_autoindent_block_mode_multiple_adjacent_ranges(cx: &mut App) {
cx,
);
- pretty_assertions::assert_eq!(
+ assert_eq!(
buffer.text(),
"
mod numbers {
@@ -2221,7 +2258,7 @@ async fn test_async_autoindents_preserve_preview(cx: &mut TestAppContext) {
// Then we request that a preview tab be preserved for the new version, even though it's edited.
let buffer = cx.new(|cx| {
let text = "fn a() {}";
- let mut buffer = Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx);
+ let mut buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
// This causes autoindent to be async.
buffer.set_sync_parse_timeout(Duration::ZERO);
@@ -2679,7 +2716,7 @@ fn test_language_at_with_hidden_languages(cx: &mut App) {
.unindent();
let language_registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
- language_registry.add(Arc::new(markdown_lang()));
+ language_registry.add(markdown_lang());
language_registry.add(Arc::new(markdown_inline_lang()));
let mut buffer = Buffer::local(text, cx);
@@ -2721,9 +2758,9 @@ fn test_language_at_for_markdown_code_block(cx: &mut App) {
.unindent();
let language_registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
- language_registry.add(Arc::new(markdown_lang()));
+ language_registry.add(markdown_lang());
language_registry.add(Arc::new(markdown_inline_lang()));
- language_registry.add(Arc::new(rust_lang()));
+ language_registry.add(rust_lang());
let mut buffer = Buffer::local(text, cx);
buffer.set_language_registry(language_registry.clone());
@@ -3120,7 +3157,7 @@ async fn test_preview_edits(cx: &mut TestAppContext) {
cx: &mut TestAppContext,
assert_fn: impl Fn(HighlightedText),
) {
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
let edits = buffer.read_with(cx, |buffer, _| {
edits
.into_iter()
@@ -3531,7 +3568,7 @@ let word=รถรคpple.barไฝ รรคpple word2-รถรpPlE-Pizza-word รรPPLE word
"#;
let buffer = cx.new(|cx| {
- let buffer = Buffer::local(contents, cx).with_language(Arc::new(rust_lang()), cx);
+ let buffer = Buffer::local(contents, cx).with_language(rust_lang(), cx);
assert_eq!(buffer.text(), contents);
buffer.check_invariants();
buffer
@@ -3691,7 +3728,7 @@ fn ruby_lang() -> Language {
fn html_lang() -> Language {
Language::new(
LanguageConfig {
- name: LanguageName::new("HTML"),
+ name: LanguageName::new_static("HTML"),
block_comment: Some(BlockCommentConfig {
start: "<!--".into(),
prefix: "".into(),
@@ -3756,78 +3793,6 @@ fn erb_lang() -> Language {
.unwrap()
}
-fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_indents_query(
- r#"
- (call_expression) @indent
- (field_expression) @indent
- (_ "(" ")" @end) @indent
- (_ "{" "}" @end) @indent
- "#,
- )
- .unwrap()
- .with_brackets_query(
- r#"
- ("{" @open "}" @close)
- "#,
- )
- .unwrap()
- .with_text_object_query(
- r#"
- (function_item
- body: (_
- "{"
- (_)* @function.inside
- "}" )) @function.around
-
- (line_comment)+ @comment.around
-
- (block_comment) @comment.around
- "#,
- )
- .unwrap()
- .with_outline_query(
- r#"
- (line_comment) @annotation
-
- (struct_item
- "struct" @context
- name: (_) @name) @item
- (enum_item
- "enum" @context
- name: (_) @name) @item
- (enum_variant
- name: (_) @name) @item
- (field_declaration
- name: (_) @name) @item
- (impl_item
- "impl" @context
- trait: (_)? @name
- "for"? @context
- type: (_) @name
- body: (_ "{" (_)* "}")) @item
- (function_item
- "fn" @context
- name: (_) @name) @item
- (mod_item
- "mod" @context
- name: (_) @name) @item
- "#,
- )
- .unwrap()
-}
-
fn json_lang() -> Language {
Language::new(
LanguageConfig {
@@ -3865,32 +3830,6 @@ fn javascript_lang() -> Language {
.unwrap()
}
-pub fn markdown_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Markdown".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["md".into()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_md::LANGUAGE.into()),
- )
- .with_injection_query(
- r#"
- (fenced_code_block
- (info_string
- (language) @injection.language)
- (code_fence_content) @injection.content)
-
- ((inline) @injection.content
- (#set! injection.language "markdown-inline"))
- "#,
- )
- .unwrap()
-}
-
pub fn markdown_inline_lang() -> Language {
Language::new(
LanguageConfig {
@@ -3917,12 +3856,11 @@ fn get_tree_sexp(buffer: &Entity<Buffer>, cx: &mut gpui::TestAppContext) -> Stri
fn assert_bracket_pairs(
selection_text: &'static str,
bracket_pair_texts: Vec<&'static str>,
- language: Language,
+ language: Arc<Language>,
cx: &mut App,
) {
let (expected_text, selection_ranges) = marked_text_ranges(selection_text, false);
- let buffer =
- cx.new(|cx| Buffer::local(expected_text.clone(), cx).with_language(Arc::new(language), cx));
+ let buffer = cx.new(|cx| Buffer::local(expected_text.clone(), cx).with_language(language, cx));
let buffer = buffer.update(cx, |buffer, _cx| buffer.snapshot());
let selection_range = selection_ranges[0].clone();
@@ -982,7 +982,7 @@ impl<T> Override<T> {
impl Default for LanguageConfig {
fn default() -> Self {
Self {
- name: LanguageName::new(""),
+ name: LanguageName::new_static(""),
code_fence_block_name: None,
grammar: None,
matcher: LanguageMatcher::default(),
@@ -2656,7 +2656,28 @@ pub fn rust_lang() -> Arc<Language> {
text_objects: Some(Cow::from(include_str!(
"../../languages/src/rust/textobjects.scm"
))),
- ..LanguageQueries::default()
+ highlights: Some(Cow::from(include_str!(
+ "../../languages/src/rust/highlights.scm"
+ ))),
+ embedding: Some(Cow::from(include_str!(
+ "../../languages/src/rust/embedding.scm"
+ ))),
+ injections: Some(Cow::from(include_str!(
+ "../../languages/src/rust/injections.scm"
+ ))),
+ overrides: Some(Cow::from(include_str!(
+ "../../languages/src/rust/overrides.scm"
+ ))),
+ redactions: None,
+ runnables: Some(Cow::from(include_str!(
+ "../../languages/src/rust/runnables.scm"
+ ))),
+ debugger: Some(Cow::from(include_str!(
+ "../../languages/src/rust/debugger.scm"
+ ))),
+ imports: Some(Cow::from(include_str!(
+ "../../languages/src/rust/imports.scm"
+ ))),
})
.expect("Could not parse queries");
Arc::new(language)
@@ -2685,6 +2706,15 @@ pub fn markdown_lang() -> Arc<Language> {
injections: Some(Cow::from(include_str!(
"../../languages/src/markdown/injections.scm"
))),
+ highlights: Some(Cow::from(include_str!(
+ "../../languages/src/markdown/highlights.scm"
+ ))),
+ indents: Some(Cow::from(include_str!(
+ "../../languages/src/markdown/indents.scm"
+ ))),
+ outline: Some(Cow::from(include_str!(
+ "../../languages/src/markdown/outline.scm"
+ ))),
..LanguageQueries::default()
})
.expect("Could not parse markdown queries");
@@ -2726,9 +2756,9 @@ mod tests {
assert_eq!(
languages.language_names(),
&[
- LanguageName::new("JSON"),
- LanguageName::new("Plain Text"),
- LanguageName::new("Rust"),
+ LanguageName::new_static("JSON"),
+ LanguageName::new_static("Plain Text"),
+ LanguageName::new_static("Rust"),
]
);
@@ -2739,9 +2769,9 @@ mod tests {
assert_eq!(
languages.language_names(),
&[
- LanguageName::new("JSON"),
- LanguageName::new("Plain Text"),
- LanguageName::new("Rust"),
+ LanguageName::new_static("JSON"),
+ LanguageName::new_static("Plain Text"),
+ LanguageName::new_static("Rust"),
]
);
@@ -2752,9 +2782,9 @@ mod tests {
assert_eq!(
languages.language_names(),
&[
- LanguageName::new("JSON"),
- LanguageName::new("Plain Text"),
- LanguageName::new("Rust"),
+ LanguageName::new_static("JSON"),
+ LanguageName::new_static("Plain Text"),
+ LanguageName::new_static("Rust"),
]
);
@@ -43,12 +43,18 @@ impl LanguageName {
Self(SharedString::new(s))
}
+ pub fn new_static(s: &'static str) -> Self {
+ Self(SharedString::new_static(s))
+ }
+
pub fn from_proto(s: String) -> Self {
Self(SharedString::from(s))
}
+
pub fn to_proto(&self) -> String {
self.0.to_string()
}
+
pub fn lsp_id(&self) -> String {
match self.0.as_ref() {
"Plain Text" => "plaintext".to_string(),
@@ -87,9 +93,9 @@ impl std::fmt::Display for LanguageName {
}
}
-impl<'a> From<&'a str> for LanguageName {
- fn from(str: &'a str) -> LanguageName {
- LanguageName(SharedString::new(str))
+impl From<&'static str> for LanguageName {
+ fn from(str: &'static str) -> Self {
+ Self(SharedString::new_static(str))
}
}
@@ -437,26 +443,14 @@ impl LanguageRegistry {
language_name: impl Into<LanguageName>,
mut adapter: crate::FakeLspAdapter,
) -> futures::channel::mpsc::UnboundedReceiver<lsp::FakeLanguageServer> {
- let language_name = language_name.into();
let adapter_name = LanguageServerName(adapter.name.into());
let capabilities = adapter.capabilities.clone();
let initializer = adapter.initializer.take();
- let adapter = CachedLspAdapter::new(Arc::new(adapter));
- {
- let mut state = self.state.write();
- state
- .lsp_adapters
- .entry(language_name)
- .or_default()
- .push(adapter.clone());
- state.all_lsp_adapters.insert(adapter.name(), adapter);
- }
-
- self.register_fake_language_server(adapter_name, capabilities, initializer)
+ self.register_fake_lsp_adapter(language_name, adapter);
+ self.register_fake_lsp_server(adapter_name, capabilities, initializer)
}
/// Register a fake lsp adapter (without the language server)
- /// The returned channel receives a new instance of the language server every time it is started
#[cfg(any(feature = "test-support", test))]
pub fn register_fake_lsp_adapter(
&self,
@@ -479,7 +473,7 @@ impl LanguageRegistry {
/// Register a fake language server (without the adapter)
/// The returned channel receives a new instance of the language server every time it is started
#[cfg(any(feature = "test-support", test))]
- pub fn register_fake_language_server(
+ pub fn register_fake_lsp_server(
&self,
lsp_name: LanguageServerName,
capabilities: lsp::ServerCapabilities,
@@ -757,7 +751,7 @@ impl LanguageRegistry {
self: &Arc<Self>,
path: &Path,
content: Option<&Rope>,
- user_file_types: Option<&FxHashMap<Arc<str>, GlobSet>>,
+ user_file_types: Option<&FxHashMap<Arc<str>, (GlobSet, Vec<String>)>>,
) -> Option<AvailableLanguage> {
let filename = path.file_name().and_then(|filename| filename.to_str());
// `Path.extension()` returns None for files with a leading '.'
@@ -800,7 +794,7 @@ impl LanguageRegistry {
let path_matches_custom_suffix = || {
user_file_types
.and_then(|types| types.get(language_name.as_ref()))
- .map_or(None, |custom_suffixes| {
+ .map_or(None, |(custom_suffixes, _)| {
path_suffixes
.iter()
.find(|(_, candidate)| custom_suffixes.is_match_candidate(candidate))
@@ -51,7 +51,7 @@ pub struct AllLanguageSettings {
pub edit_predictions: EditPredictionSettings,
pub defaults: LanguageSettings,
languages: HashMap<LanguageName, LanguageSettings>,
- pub(crate) file_types: FxHashMap<Arc<str>, GlobSet>,
+ pub file_types: FxHashMap<Arc<str>, (GlobSet, Vec<String>)>,
}
#[derive(Debug, Clone, PartialEq)]
@@ -373,6 +373,8 @@ impl InlayHintSettings {
pub struct EditPredictionSettings {
/// The provider that supplies edit predictions.
pub provider: settings::EditPredictionProvider,
+ /// Whether to use the experimental edit prediction context retrieval system.
+ pub use_context: bool,
/// A list of globs representing files that edit predictions should be disabled for.
/// This list adds to a pre-existing, sensible default set of globs.
/// Any additional ones you add are combined with them.
@@ -622,6 +624,11 @@ impl settings::Settings for AllLanguageSettings {
.features
.as_ref()
.and_then(|f| f.edit_prediction_provider);
+ let use_edit_prediction_context = all_languages
+ .features
+ .as_ref()
+ .and_then(|f| f.experimental_edit_prediction_context_retrieval)
+ .unwrap_or_default();
let edit_predictions = all_languages.edit_predictions.clone().unwrap();
let edit_predictions_mode = edit_predictions.mode.unwrap();
@@ -649,7 +656,7 @@ impl settings::Settings for AllLanguageSettings {
let enabled_in_text_threads = edit_predictions.enabled_in_text_threads.unwrap();
- let mut file_types: FxHashMap<Arc<str>, GlobSet> = FxHashMap::default();
+ let mut file_types: FxHashMap<Arc<str>, (GlobSet, Vec<String>)> = FxHashMap::default();
for (language, patterns) in all_languages.file_types.iter().flatten() {
let mut builder = GlobSetBuilder::new();
@@ -658,7 +665,10 @@ impl settings::Settings for AllLanguageSettings {
builder.add(Glob::new(pattern).unwrap());
}
- file_types.insert(language.clone(), builder.build().unwrap());
+ file_types.insert(
+ language.clone(),
+ (builder.build().unwrap(), patterns.0.clone()),
+ );
}
Self {
@@ -668,6 +678,7 @@ impl settings::Settings for AllLanguageSettings {
} else {
EditPredictionProvider::None
},
+ use_context: use_edit_prediction_context,
disabled_globs: disabled_globs
.iter()
.filter_map(|g| {
@@ -1,4 +1,4 @@
-use crate::{BufferSnapshot, Point, ToPoint};
+use crate::{BufferSnapshot, Point, ToPoint, ToTreeSitterPoint};
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{BackgroundExecutor, HighlightStyle};
use std::ops::Range;
@@ -48,6 +48,54 @@ impl<T: ToPoint> OutlineItem<T> {
.map(|r| r.start.to_point(buffer)..r.end.to_point(buffer)),
}
}
+
+ pub fn body_range(&self, buffer: &BufferSnapshot) -> Option<Range<Point>> {
+ if let Some(range) = self.body_range.as_ref() {
+ return Some(range.start.to_point(buffer)..range.end.to_point(buffer));
+ }
+
+ let range = self.range.start.to_point(buffer)..self.range.end.to_point(buffer);
+ let start_indent = buffer.indent_size_for_line(range.start.row);
+ let node = buffer.syntax_ancestor(range.clone())?;
+
+ let mut cursor = node.walk();
+ loop {
+ let node = cursor.node();
+ if node.start_position() >= range.start.to_ts_point()
+ && node.end_position() <= range.end.to_ts_point()
+ {
+ break;
+ }
+ cursor.goto_first_child_for_point(range.start.to_ts_point());
+ }
+
+ if !cursor.goto_last_child() {
+ return None;
+ }
+ let body_node = loop {
+ let node = cursor.node();
+ if node.child_count() > 0 {
+ break node;
+ }
+ if !cursor.goto_previous_sibling() {
+ return None;
+ }
+ };
+
+ let mut start_row = body_node.start_position().row as u32;
+ let mut end_row = body_node.end_position().row as u32;
+
+ while start_row < end_row && buffer.indent_size_for_line(start_row) == start_indent {
+ start_row += 1;
+ }
+ while start_row < end_row && buffer.indent_size_for_line(end_row - 1) == start_indent {
+ end_row -= 1;
+ }
+ if start_row < end_row {
+ return Some(Point::new(start_row, 0)..Point::new(end_row, 0));
+ }
+ None
+ }
}
impl<T> Outline<T> {
@@ -1215,6 +1215,19 @@ impl<'a> SyntaxMapMatches<'a> {
true
}
+
+ // pub fn set_byte_range(&mut self, range: Range<usize>) {
+ // for layer in &mut self.layers {
+ // layer.matches.set_byte_range(range.clone());
+ // layer.advance();
+ // }
+ // self.layers.sort_unstable_by_key(|layer| layer.sort_key());
+ // self.active_layer_count = self
+ // .layers
+ // .iter()
+ // .position(|layer| !layer.has_next)
+ // .unwrap_or(self.layers.len());
+ // }
}
impl SyntaxMapCapturesLayer<'_> {
@@ -1,9 +1,9 @@
use super::*;
use crate::{
- LanguageConfig, LanguageMatcher,
- buffer_tests::{markdown_inline_lang, markdown_lang},
+ LanguageConfig, LanguageMatcher, buffer_tests::markdown_inline_lang, markdown_lang, rust_lang,
};
use gpui::App;
+use pretty_assertions::assert_eq;
use rand::rngs::StdRng;
use std::{env, ops::Range, sync::Arc};
use text::{Buffer, BufferId, ReplicaId};
@@ -84,7 +84,7 @@ fn test_splice_included_ranges() {
#[gpui::test]
fn test_syntax_map_layers_for_range(cx: &mut App) {
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
- let language = Arc::new(rust_lang());
+ let language = rust_lang();
registry.add(language.clone());
let mut buffer = Buffer::new(
@@ -181,11 +181,11 @@ fn test_syntax_map_layers_for_range(cx: &mut App) {
#[gpui::test]
fn test_dynamic_language_injection(cx: &mut App) {
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
- let markdown = Arc::new(markdown_lang());
+ let markdown = markdown_lang();
let markdown_inline = Arc::new(markdown_inline_lang());
registry.add(markdown.clone());
registry.add(markdown_inline.clone());
- registry.add(Arc::new(rust_lang()));
+ registry.add(rust_lang());
registry.add(Arc::new(ruby_lang()));
let mut buffer = Buffer::new(
@@ -291,7 +291,7 @@ fn test_typing_multiple_new_injections(cx: &mut App) {
assert_capture_ranges(
&syntax_map,
&buffer,
- &["field"],
+ &["property"],
"fn a() { test_macro!(b.ยซcยป(vec![d.ยซeยป])) }",
);
}
@@ -329,16 +329,16 @@ fn test_pasting_new_injection_line_between_others(cx: &mut App) {
assert_capture_ranges(
&syntax_map,
&buffer,
- &["struct"],
+ &["type"],
"
fn a() {
- b!(ยซB {}ยป);
- c!(ยซC {}ยป);
- d!(ยซD {}ยป);
- h!(ยซH {}ยป);
- e!(ยซE {}ยป);
- f!(ยซF {}ยป);
- g!(ยซG {}ยป);
+ b!(ยซBยป {});
+ c!(ยซCยป {});
+ d!(ยซDยป {});
+ h!(ยซHยป {});
+ e!(ยซEยป {});
+ f!(ยซFยป {});
+ g!(ยซGยป {});
}
",
);
@@ -376,7 +376,7 @@ fn test_joining_injections_with_child_injections(cx: &mut App) {
assert_capture_ranges(
&syntax_map,
&buffer,
- &["field"],
+ &["property"],
"
fn a() {
b!(
@@ -900,7 +900,7 @@ fn test_random_syntax_map_edits_rust_macros(rng: StdRng, cx: &mut App) {
.repeat(2);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
- let language = Arc::new(rust_lang());
+ let language = rust_lang();
registry.add(language.clone());
test_random_edits(text, registry, language, rng);
@@ -1147,11 +1147,11 @@ fn test_edit_sequence(language_name: &str, steps: &[&str], cx: &mut App) -> (Buf
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
registry.add(Arc::new(elixir_lang()));
registry.add(Arc::new(heex_lang()));
- registry.add(Arc::new(rust_lang()));
+ registry.add(rust_lang());
registry.add(Arc::new(ruby_lang()));
registry.add(Arc::new(html_lang()));
registry.add(Arc::new(erb_lang()));
- registry.add(Arc::new(markdown_lang()));
+ registry.add(markdown_lang());
registry.add(Arc::new(markdown_inline_lang()));
let language = registry
@@ -1287,35 +1287,6 @@ fn erb_lang() -> Language {
.unwrap()
}
-fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_highlights_query(
- r#"
- (field_identifier) @field
- (struct_expression) @struct
- "#,
- )
- .unwrap()
- .with_injection_query(
- r#"
- (macro_invocation
- (token_tree) @injection.content
- (#set! injection.language "rust"))
- "#,
- )
- .unwrap()
-}
-
fn elixir_lang() -> Language {
Language::new(
LanguageConfig {
@@ -1425,6 +1396,7 @@ fn assert_capture_ranges(
actual_ranges.push(capture.node.byte_range());
}
}
+ actual_ranges.dedup();
let (text, expected_ranges) = marked_text_ranges(&marked_string.unindent(), false);
assert_eq!(text, buffer.text());
@@ -245,7 +245,7 @@ impl LspAdapter for ExtensionLspAdapter {
// We can remove once the following extension versions no longer see any use:
// - php@0.0.1
if self.extension.manifest().id.as_ref() == "php" {
- return HashMap::from_iter([(LanguageName::new("PHP"), "php".into())]);
+ return HashMap::from_iter([(LanguageName::new_static("PHP"), "php".into())]);
}
self.extension
@@ -707,6 +707,40 @@ pub trait LanguageModel: Send + Sync {
.boxed()
}
+ fn stream_completion_tool(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
+ let future = self.stream_completion(request, cx);
+
+ async move {
+ let events = future.await?;
+ let mut events = events.fuse();
+
+ // Iterate through events until we find a complete ToolUse
+ while let Some(event) = events.next().await {
+ match event {
+ Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
+ if tool_use.is_input_complete =>
+ {
+ return Ok(tool_use);
+ }
+ Err(err) => {
+ return Err(err);
+ }
+ _ => {}
+ }
+ }
+
+ // Stream ended without a complete tool use
+ Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
+ "Stream ended without receiving a complete tool use"
+ )))
+ }
+ .boxed()
+ }
+
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
None
}
@@ -109,7 +109,7 @@ pub fn into_open_ai(
messages,
stream,
stop: request.stop,
- temperature: request.temperature.unwrap_or(1.0),
+ temperature: request.temperature.or(Some(1.0)),
max_completion_tokens: max_output_tokens,
parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
Some(false)
@@ -47,10 +47,45 @@
left: (identifier) @function
right: [(function_expression) (arrow_function)])
+; Parameters
+
+(required_parameter
+ (identifier) @variable.parameter)
+
+(required_parameter
+ (_
+ ([
+ (identifier)
+ (shorthand_property_identifier_pattern)
+ ]) @variable.parameter))
+
+(optional_parameter
+ (identifier) @variable.parameter)
+
+(optional_parameter
+ (_
+ ([
+ (identifier)
+ (shorthand_property_identifier_pattern)
+ ]) @variable.parameter))
+
+(catch_clause
+ parameter: (identifier) @variable.parameter)
+
+(index_signature
+ name: (identifier) @variable.parameter)
+
+(arrow_function
+ parameter: (identifier) @variable.parameter)
+
; Special identifiers
+;
+(class_declaration
+ (type_identifier) @type.class)
+
+(extends_clause
+ value: (identifier) @type.class)
-((identifier) @type
- (#match? @type "^[A-Z]"))
(type_identifier) @type
(predefined_type) @type.builtin
@@ -1,2 +1,3 @@
(tag_name) @keyword.jsdoc
(type) @type.jsdoc
+(identifier) @variable.jsdoc
@@ -7,8 +7,8 @@ use futures::StreamExt;
use gpui::{App, AsyncApp, Task};
use http_client::github::{GitHubLspBinaryVersion, latest_github_release};
use language::{
- ContextProvider, LanguageName, LocalFile as _, LspAdapter, LspAdapterDelegate, LspInstaller,
- Toolchain,
+ ContextProvider, LanguageName, LanguageRegistry, LocalFile as _, LspAdapter,
+ LspAdapterDelegate, LspInstaller, Toolchain,
};
use lsp::{LanguageServerBinary, LanguageServerName, Uri};
use node_runtime::{NodeRuntime, VersionStrategy};
@@ -129,14 +129,15 @@ fn server_binary_arguments(server_path: &Path) -> Vec<OsString> {
}
pub struct JsonLspAdapter {
+ languages: Arc<LanguageRegistry>,
node: NodeRuntime,
}
impl JsonLspAdapter {
const PACKAGE_NAME: &str = "vscode-langservers-extracted";
- pub fn new(node: NodeRuntime) -> Self {
- Self { node }
+ pub fn new(languages: Arc<LanguageRegistry>, node: NodeRuntime) -> Self {
+ Self { languages, node }
}
}
@@ -255,7 +256,7 @@ impl LspAdapter for JsonLspAdapter {
cx: &mut AsyncApp,
) -> Result<Value> {
let mut config = cx.update(|cx| {
- let schemas = json_schema_store::all_schema_file_associations(cx);
+ let schemas = json_schema_store::all_schema_file_associations(&self.languages, cx);
// This can be viewed via `dev: open language server logs` -> `json-language-server` ->
// `Server Info`
@@ -285,8 +286,8 @@ impl LspAdapter for JsonLspAdapter {
fn language_ids(&self) -> HashMap<LanguageName, String> {
[
- (LanguageName::new("JSON"), "json".into()),
- (LanguageName::new("JSONC"), "jsonc".into()),
+ (LanguageName::new_static("JSON"), "json".into()),
+ (LanguageName::new_static("JSONC"), "jsonc".into()),
]
.into_iter()
.collect()
@@ -89,7 +89,7 @@ pub fn init(languages: Arc<LanguageRegistry>, fs: Arc<dyn Fs>, node: NodeRuntime
let go_context_provider = Arc::new(go::GoContextProvider);
let go_lsp_adapter = Arc::new(go::GoLspAdapter);
let json_context_provider = Arc::new(JsonTaskProvider);
- let json_lsp_adapter = Arc::new(json::JsonLspAdapter::new(node.clone()));
+ let json_lsp_adapter = Arc::new(json::JsonLspAdapter::new(languages.clone(), node.clone()));
let node_version_lsp_adapter = Arc::new(json::NodeVersionAdapter);
let py_lsp_adapter = Arc::new(python::PyLspAdapter::new());
let ty_lsp_adapter = Arc::new(python::TyLspAdapter::new(fs.clone()));
@@ -24,4 +24,5 @@ rewrap_prefixes = [
auto_indent_on_paste = false
auto_indent_using_last_non_empty_line = false
tab_size = 2
+decrease_indent_pattern = "^.*$"
prettier_parser_name = "markdown"
@@ -0,0 +1 @@
+(list (list_item) @indent)
@@ -903,7 +903,7 @@ impl ContextProvider for PythonContextProvider {
fn selected_test_runner(location: Option<&Arc<dyn language::File>>, cx: &App) -> TestRunner {
const TEST_RUNNER_VARIABLE: &str = "TEST_RUNNER";
- language_settings(Some(LanguageName::new("Python")), location, cx)
+ language_settings(Some(LanguageName::new_static("Python")), location, cx)
.tasks
.variables
.get(TEST_RUNNER_VARIABLE)
@@ -1397,7 +1397,7 @@ async fn venv_to_toolchain(venv: PythonEnvironment, fs: &dyn Fs) -> Option<Toolc
.to_str()?
.to_owned()
.into(),
- language_name: LanguageName::new("Python"),
+ language_name: LanguageName::new_static("Python"),
as_json: serde_json::to_value(data).ok()?,
})
}
@@ -174,20 +174,32 @@ impl LspAdapter for TailwindLspAdapter {
fn language_ids(&self) -> HashMap<LanguageName, String> {
HashMap::from_iter([
- (LanguageName::new("Astro"), "astro".to_string()),
- (LanguageName::new("HTML"), "html".to_string()),
- (LanguageName::new("Gleam"), "html".to_string()),
- (LanguageName::new("CSS"), "css".to_string()),
- (LanguageName::new("JavaScript"), "javascript".to_string()),
- (LanguageName::new("TypeScript"), "typescript".to_string()),
- (LanguageName::new("TSX"), "typescriptreact".to_string()),
- (LanguageName::new("Svelte"), "svelte".to_string()),
- (LanguageName::new("Elixir"), "phoenix-heex".to_string()),
- (LanguageName::new("HEEX"), "phoenix-heex".to_string()),
- (LanguageName::new("ERB"), "erb".to_string()),
- (LanguageName::new("HTML+ERB"), "erb".to_string()),
- (LanguageName::new("PHP"), "php".to_string()),
- (LanguageName::new("Vue.js"), "vue".to_string()),
+ (LanguageName::new_static("Astro"), "astro".to_string()),
+ (LanguageName::new_static("HTML"), "html".to_string()),
+ (LanguageName::new_static("Gleam"), "html".to_string()),
+ (LanguageName::new_static("CSS"), "css".to_string()),
+ (
+ LanguageName::new_static("JavaScript"),
+ "javascript".to_string(),
+ ),
+ (
+ LanguageName::new_static("TypeScript"),
+ "typescript".to_string(),
+ ),
+ (
+ LanguageName::new_static("TSX"),
+ "typescriptreact".to_string(),
+ ),
+ (LanguageName::new_static("Svelte"), "svelte".to_string()),
+ (
+ LanguageName::new_static("Elixir"),
+ "phoenix-heex".to_string(),
+ ),
+ (LanguageName::new_static("HEEX"), "phoenix-heex".to_string()),
+ (LanguageName::new_static("ERB"), "erb".to_string()),
+ (LanguageName::new_static("HTML+ERB"), "erb".to_string()),
+ (LanguageName::new_static("PHP"), "php".to_string()),
+ (LanguageName::new_static("Vue.js"), "vue".to_string()),
])
}
}
@@ -47,13 +47,68 @@
left: (identifier) @function
right: [(function_expression) (arrow_function)])
+; Parameters
+
+(required_parameter
+ (identifier) @variable.parameter)
+
+(required_parameter
+ (_
+ ([
+ (identifier)
+ (shorthand_property_identifier_pattern)
+ ]) @variable.parameter))
+
+(optional_parameter
+ (identifier) @variable.parameter)
+
+(optional_parameter
+ (_
+ ([
+ (identifier)
+ (shorthand_property_identifier_pattern)
+ ]) @variable.parameter))
+
+(catch_clause
+ parameter: (identifier) @variable.parameter)
+
+(index_signature
+ name: (identifier) @variable.parameter)
+
+(arrow_function
+ parameter: (identifier) @variable.parameter)
+
+(type_predicate
+ name: (identifier) @variable.parameter)
+
; Special identifiers
-((identifier) @type
- (#match? @type "^[A-Z]"))
+(type_annotation) @type
(type_identifier) @type
(predefined_type) @type.builtin
+(type_alias_declaration
+ (type_identifier) @type)
+
+(type_alias_declaration
+ value: (_
+ (type_identifier) @type))
+
+(interface_declaration
+ (type_identifier) @type)
+
+(class_declaration
+ (type_identifier) @type.class)
+
+(extends_clause
+ value: (identifier) @type.class)
+
+(extends_type_clause
+ type: (type_identifier) @type)
+
+(implements_clause
+ (type_identifier) @type)
+
([
(identifier)
(shorthand_property_identifier)
@@ -231,8 +286,42 @@
"<" @punctuation.bracket
">" @punctuation.bracket)
+(type_parameters
+ "<" @punctuation.bracket
+ ">" @punctuation.bracket)
+
(decorator "@" @punctuation.special)
+(union_type
+ ("|") @punctuation.special)
+
+(intersection_type
+ ("&") @punctuation.special)
+
+(type_annotation
+ (":") @punctuation.special)
+
+(index_signature
+ (":") @punctuation.special)
+
+(type_predicate_annotation
+ (":") @punctuation.special)
+
+(public_field_definition
+ ("?") @punctuation.special)
+
+(property_signature
+ ("?") @punctuation.special)
+
+(method_signature
+ ("?") @punctuation.special)
+
+(optional_parameter
+ ([
+ "?"
+ ":"
+ ]) @punctuation.special)
+
; Keywords
[ "abstract"
@@ -822,9 +822,9 @@ impl LspAdapter for TypeScriptLspAdapter {
fn language_ids(&self) -> HashMap<LanguageName, String> {
HashMap::from_iter([
- (LanguageName::new("TypeScript"), "typescript".into()),
- (LanguageName::new("JavaScript"), "javascript".into()),
- (LanguageName::new("TSX"), "typescriptreact".into()),
+ (LanguageName::new_static("TypeScript"), "typescript".into()),
+ (LanguageName::new_static("JavaScript"), "javascript".into()),
+ (LanguageName::new_static("TSX"), "typescriptreact".into()),
])
}
}
@@ -4,11 +4,33 @@
; Special identifiers
-((identifier) @type
- (#match? @type "^[A-Z]"))
+(type_annotation) @type
+
(type_identifier) @type
(predefined_type) @type.builtin
+(type_alias_declaration
+ (type_identifier) @type)
+
+(type_alias_declaration
+ value: (_
+ (type_identifier) @type))
+
+(interface_declaration
+ (type_identifier) @type)
+
+(class_declaration
+ (type_identifier) @type.class)
+
+(extends_clause
+ value: (identifier) @type.class)
+
+(extends_type_clause
+ type: (type_identifier) @type)
+
+(implements_clause
+ (type_identifier) @type)
+
;; Enables ts-pretty-errors
;; The Lsp returns "snippets" of typescript, which are not valid typescript in totality,
;; but should still be highlighted
@@ -114,6 +136,40 @@
(arrow_function) @function
+; Parameters
+
+(required_parameter
+ (identifier) @variable.parameter)
+
+(required_parameter
+ (_
+ ([
+ (identifier)
+ (shorthand_property_identifier_pattern)
+ ]) @variable.parameter))
+
+(optional_parameter
+ (identifier) @variable.parameter)
+
+(optional_parameter
+ (_
+ ([
+ (identifier)
+ (shorthand_property_identifier_pattern)
+ ]) @variable.parameter))
+
+(catch_clause
+ parameter: (identifier) @variable.parameter)
+
+(index_signature
+ name: (identifier) @variable.parameter)
+
+(arrow_function
+ parameter: (identifier) @variable.parameter)
+
+(type_predicate
+ name: (identifier) @variable.parameter)
+
; Literals
(this) @variable.special
@@ -244,8 +300,42 @@
"<" @punctuation.bracket
">" @punctuation.bracket)
+(type_parameters
+ "<" @punctuation.bracket
+ ">" @punctuation.bracket)
+
(decorator "@" @punctuation.special)
+(union_type
+ ("|") @punctuation.special)
+
+(intersection_type
+ ("&") @punctuation.special)
+
+(type_annotation
+ (":") @punctuation.special)
+
+(index_signature
+ (":") @punctuation.special)
+
+(type_predicate_annotation
+ (":") @punctuation.special)
+
+(public_field_definition
+ ("?") @punctuation.special)
+
+(property_signature
+ ("?") @punctuation.special)
+
+(method_signature
+ ("?") @punctuation.special)
+
+(optional_parameter
+ ([
+ "?"
+ ":"
+ ]) @punctuation.special)
+
; Keywords
[
@@ -296,9 +296,9 @@ impl LspAdapter for VtslsLspAdapter {
fn language_ids(&self) -> HashMap<LanguageName, String> {
HashMap::from_iter([
- (LanguageName::new("TypeScript"), "typescript".into()),
- (LanguageName::new("JavaScript"), "javascript".into()),
- (LanguageName::new("TSX"), "typescriptreact".into()),
+ (LanguageName::new_static("TypeScript"), "typescript".into()),
+ (LanguageName::new_static("JavaScript"), "javascript".into()),
+ (LanguageName::new_static("TSX"), "typescriptreact".into()),
])
}
}
@@ -7,6 +7,7 @@ use gpui::HitboxBehavior;
use language::LanguageName;
use log::Level;
pub use path_range::{LineCol, PathWithRange};
+use ui::Checkbox;
use std::borrow::Cow;
use std::iter;
@@ -795,7 +796,7 @@ impl Element for MarkdownElement {
let mut code_block_ids = HashSet::default();
let mut current_img_block_range: Option<Range<usize>> = None;
- for (range, event) in parsed_markdown.events.iter() {
+ for (index, (range, event)) in parsed_markdown.events.iter().enumerate() {
// Skip alt text for images that rendered
if let Some(current_img_block_range) = ¤t_img_block_range
&& current_img_block_range.end > range.end
@@ -945,13 +946,29 @@ impl Element for MarkdownElement {
MarkdownTag::HtmlBlock => builder.push_div(div(), range, markdown_end),
MarkdownTag::List(bullet_index) => {
builder.push_list(*bullet_index);
- builder.push_div(div().pl_4(), range, markdown_end);
+ builder.push_div(div().pl_2p5(), range, markdown_end);
}
MarkdownTag::Item => {
- let bullet = if let Some(bullet_index) = builder.next_bullet_index() {
- format!("{}.", bullet_index)
+ let bullet = if let Some((_, MarkdownEvent::TaskListMarker(checked))) =
+ parsed_markdown.events.get(index.saturating_add(1))
+ {
+ let source = &parsed_markdown.source()[range.clone()];
+
+ Checkbox::new(
+ ElementId::Name(source.to_string().into()),
+ if *checked {
+ ToggleState::Selected
+ } else {
+ ToggleState::Unselected
+ },
+ )
+ .fill()
+ .visualization_only(true)
+ .into_any_element()
+ } else if let Some(bullet_index) = builder.next_bullet_index() {
+ div().child(format!("{}.", bullet_index)).into_any_element()
} else {
- "โข".to_string()
+ div().child("โข").into_any_element()
};
builder.push_div(
div()
@@ -1226,6 +1243,9 @@ impl Element for MarkdownElement {
}
MarkdownEvent::SoftBreak => builder.push_text(" ", range.clone()),
MarkdownEvent::HardBreak => builder.push_text("\n", range.clone()),
+ MarkdownEvent::TaskListMarker(_) => {
+ // handled inside the `MarkdownTag::Item` case
+ }
_ => log::debug!("unsupported markdown event {:?}", event),
}
}
@@ -37,3 +37,4 @@ workspace.workspace = true
[dev-dependencies]
editor = { workspace = true, features = ["test-support"] }
+language = { workspace = true, features = ["test-support"] }
@@ -1467,9 +1467,7 @@ mod tests {
use ParsedMarkdownListItemType::*;
use core::panic;
use gpui::{AbsoluteLength, BackgroundExecutor, DefiniteLength};
- use language::{
- HighlightId, Language, LanguageConfig, LanguageMatcher, LanguageRegistry, tree_sitter_rust,
- };
+ use language::{HighlightId, LanguageRegistry};
use pretty_assertions::assert_eq;
async fn parse(input: &str) -> ParsedMarkdown {
@@ -3053,7 +3051,7 @@ fn main() {
#[gpui::test]
async fn test_code_block_with_language(executor: BackgroundExecutor) {
let language_registry = Arc::new(LanguageRegistry::test(executor.clone()));
- language_registry.add(rust_lang());
+ language_registry.add(language::rust_lang());
let parsed = parse_markdown(
"\
@@ -3079,21 +3077,6 @@ fn main() {
);
}
- fn rust_lang() -> Arc<Language> {
- Arc::new(Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".into()],
- ..Default::default()
- },
- collapsed_placeholder: " /* ... */ ".to_string(),
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- ))
- }
-
fn h1(contents: MarkdownParagraph, source_range: Range<usize>) -> ParsedMarkdownElement {
ParsedMarkdownElement::Heading(ParsedMarkdownHeading {
source_range,
@@ -524,7 +524,7 @@ impl Render for MarkdownPreviewView {
if e.checked() { "[x]" } else { "[ ]" };
editor.edit(
- vec![(
+ [(
MultiBufferOffset(
e.source_range().start,
)
@@ -42,6 +42,8 @@ sum_tree.workspace = true
text.workspace = true
theme.workspace = true
tree-sitter.workspace = true
+ztracing.workspace = true
+tracing.workspace = true
util.workspace = true
[dev-dependencies]
@@ -56,3 +58,6 @@ settings = { workspace = true, features = ["test-support"] }
text = { workspace = true, features = ["test-support"] }
util = { workspace = true, features = ["test-support"] }
zlog.workspace = true
+
+[package.metadata.cargo-machete]
+ignored = ["tracing"]
@@ -57,6 +57,7 @@ use text::{
};
use theme::SyntaxTheme;
use util::post_inc;
+use ztracing::instrument;
pub use self::path_key::PathKey;
@@ -1671,6 +1672,7 @@ impl MultiBuffer {
self.insert_excerpts_after(ExcerptId::max(), buffer, ranges, cx)
}
+ #[instrument(skip_all)]
fn merge_excerpt_ranges<'a>(
expanded_ranges: impl IntoIterator<Item = &'a ExcerptRange<Point>> + 'a,
) -> (Vec<ExcerptRange<Point>>, Vec<usize>) {
@@ -4483,6 +4485,7 @@ impl MultiBufferSnapshot {
self.convert_dimension(point, text::BufferSnapshot::point_utf16_to_point)
}
+ #[instrument(skip_all)]
pub fn point_to_offset(&self, point: Point) -> MultiBufferOffset {
self.convert_dimension(point, text::BufferSnapshot::point_to_offset)
}
@@ -4536,6 +4539,7 @@ impl MultiBufferSnapshot {
}
}
+ #[instrument(skip_all)]
fn convert_dimension<MBR1, MBR2, BR1, BR2>(
&self,
key: MBR1,
@@ -6453,12 +6457,13 @@ impl MultiBufferSnapshot {
}
/// Returns the excerpt for the given id. The returned excerpt is guaranteed
- /// to have the same excerpt id as the one passed in, with the exception of
- /// `ExcerptId::max()`.
+ /// to have the latest excerpt id for the one passed in and will also remap
+ /// `ExcerptId::max()` to the corresponding excertp ID.
///
/// Callers of this function should generally use the resulting excerpt's `id` field
/// afterwards.
fn excerpt(&self, excerpt_id: ExcerptId) -> Option<&Excerpt> {
+ let excerpt_id = self.latest_excerpt_id(excerpt_id);
let mut cursor = self.excerpts.cursor::<Option<&Locator>>(());
let locator = self.excerpt_locator_for_id(excerpt_id);
cursor.seek(&Some(locator), Bias::Left);
@@ -6684,6 +6689,7 @@ where
MBD: MultiBufferDimension + Ord + Sub + ops::AddAssign<<MBD as Sub>::Output>,
BD: TextDimension + AddAssign<<MBD as Sub>::Output>,
{
+ #[instrument(skip_all)]
fn seek(&mut self, position: &MBD) {
let position = OutputDimension(*position);
self.cached_region.take();
@@ -1,435 +1,437 @@
-use std::{mem, ops::Range, sync::Arc};
-
-use collections::HashSet;
-use gpui::{App, AppContext, Context, Entity};
-use itertools::Itertools;
-use language::{Buffer, BufferSnapshot};
-use rope::Point;
-use text::{Bias, BufferId, OffsetRangeExt, locator::Locator};
-use util::{post_inc, rel_path::RelPath};
-
-use crate::{
- Anchor, ExcerptId, ExcerptRange, ExpandExcerptDirection, MultiBuffer, build_excerpt_ranges,
-};
-
-#[derive(PartialEq, Eq, Ord, PartialOrd, Clone, Hash, Debug)]
-pub struct PathKey {
- // Used by the derived PartialOrd & Ord
- pub sort_prefix: Option<u64>,
- pub path: Arc<RelPath>,
-}
-
-impl PathKey {
- pub fn with_sort_prefix(sort_prefix: u64, path: Arc<RelPath>) -> Self {
- Self {
- sort_prefix: Some(sort_prefix),
- path,
- }
- }
-
- pub fn for_buffer(buffer: &Entity<Buffer>, cx: &App) -> Self {
- if let Some(file) = buffer.read(cx).file() {
- Self::with_sort_prefix(file.worktree_id(cx).to_proto(), file.path().clone())
- } else {
- Self {
- sort_prefix: None,
- path: RelPath::unix(&buffer.entity_id().to_string())
- .unwrap()
- .into_arc(),
- }
- }
- }
-}
-
-impl MultiBuffer {
- pub fn paths(&self) -> impl Iterator<Item = PathKey> + '_ {
- self.excerpts_by_path.keys().cloned()
- }
-
- pub fn remove_excerpts_for_path(&mut self, path: PathKey, cx: &mut Context<Self>) {
- if let Some(to_remove) = self.excerpts_by_path.remove(&path) {
- self.remove_excerpts(to_remove, cx)
- }
- if let Some(follower) = &self.follower {
- follower.update(cx, |follower, cx| {
- follower.remove_excerpts_for_path(path, cx);
- });
- }
- }
-
- pub fn location_for_path(&self, path: &PathKey, cx: &App) -> Option<Anchor> {
- let excerpt_id = self.excerpts_by_path.get(path)?.first()?;
- let snapshot = self.read(cx);
- let excerpt = snapshot.excerpt(*excerpt_id)?;
- Some(Anchor::in_buffer(excerpt.id, excerpt.range.context.start))
- }
-
- pub fn excerpt_paths(&self) -> impl Iterator<Item = &PathKey> {
- self.excerpts_by_path.keys()
- }
-
- /// Sets excerpts, returns `true` if at least one new excerpt was added.
- pub fn set_excerpts_for_path(
- &mut self,
- path: PathKey,
- buffer: Entity<Buffer>,
- ranges: impl IntoIterator<Item = Range<Point>>,
- context_line_count: u32,
- cx: &mut Context<Self>,
- ) -> (Vec<Range<Anchor>>, bool) {
- let buffer_snapshot = buffer.read(cx).snapshot();
- let excerpt_ranges = build_excerpt_ranges(ranges, context_line_count, &buffer_snapshot);
-
- let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges);
- self.set_merged_excerpt_ranges_for_path(
- path,
- buffer,
- excerpt_ranges,
- &buffer_snapshot,
- new,
- counts,
- cx,
- )
- }
-
- pub fn set_excerpt_ranges_for_path(
- &mut self,
- path: PathKey,
- buffer: Entity<Buffer>,
- buffer_snapshot: &BufferSnapshot,
- excerpt_ranges: Vec<ExcerptRange<Point>>,
- cx: &mut Context<Self>,
- ) -> (Vec<Range<Anchor>>, bool) {
- let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges);
- self.set_merged_excerpt_ranges_for_path(
- path,
- buffer,
- excerpt_ranges,
- buffer_snapshot,
- new,
- counts,
- cx,
- )
- }
-
- pub fn set_anchored_excerpts_for_path(
- &self,
- path_key: PathKey,
- buffer: Entity<Buffer>,
- ranges: Vec<Range<text::Anchor>>,
- context_line_count: u32,
- cx: &Context<Self>,
- ) -> impl Future<Output = Vec<Range<Anchor>>> + use<> {
- let buffer_snapshot = buffer.read(cx).snapshot();
- let multi_buffer = cx.weak_entity();
- let mut app = cx.to_async();
- async move {
- let snapshot = buffer_snapshot.clone();
- let (excerpt_ranges, new, counts) = app
- .background_spawn(async move {
- let ranges = ranges.into_iter().map(|range| range.to_point(&snapshot));
- let excerpt_ranges =
- build_excerpt_ranges(ranges, context_line_count, &snapshot);
- let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges);
- (excerpt_ranges, new, counts)
- })
- .await;
-
- multi_buffer
- .update(&mut app, move |multi_buffer, cx| {
- let (ranges, _) = multi_buffer.set_merged_excerpt_ranges_for_path(
- path_key,
- buffer,
- excerpt_ranges,
- &buffer_snapshot,
- new,
- counts,
- cx,
- );
- ranges
- })
- .ok()
- .unwrap_or_default()
- }
- }
-
- pub fn remove_excerpts_for_buffer(&mut self, buffer: BufferId, cx: &mut Context<Self>) {
- self.remove_excerpts(
- self.excerpts_for_buffer(buffer, cx)
- .into_iter()
- .map(|(excerpt, _)| excerpt),
- cx,
- );
- }
-
- pub(super) fn expand_excerpts_with_paths(
- &mut self,
- ids: impl IntoIterator<Item = ExcerptId>,
- line_count: u32,
- direction: ExpandExcerptDirection,
- cx: &mut Context<Self>,
- ) {
- let grouped = ids
- .into_iter()
- .chunk_by(|id| self.paths_by_excerpt.get(id).cloned())
- .into_iter()
- .filter_map(|(k, v)| Some((k?, v.into_iter().collect::<Vec<_>>())))
- .collect::<Vec<_>>();
- let snapshot = self.snapshot(cx);
-
- for (path, ids) in grouped.into_iter() {
- let Some(excerpt_ids) = self.excerpts_by_path.get(&path) else {
- continue;
- };
-
- let ids_to_expand = HashSet::from_iter(ids);
- let mut excerpt_id_ = None;
- let expanded_ranges = excerpt_ids.iter().filter_map(|excerpt_id| {
- let excerpt = snapshot.excerpt(*excerpt_id)?;
- let excerpt_id = excerpt.id;
- if excerpt_id_.is_none() {
- excerpt_id_ = Some(excerpt_id);
- }
-
- let mut context = excerpt.range.context.to_point(&excerpt.buffer);
- if ids_to_expand.contains(&excerpt_id) {
- match direction {
- ExpandExcerptDirection::Up => {
- context.start.row = context.start.row.saturating_sub(line_count);
- context.start.column = 0;
- }
- ExpandExcerptDirection::Down => {
- context.end.row =
- (context.end.row + line_count).min(excerpt.buffer.max_point().row);
- context.end.column = excerpt.buffer.line_len(context.end.row);
- }
- ExpandExcerptDirection::UpAndDown => {
- context.start.row = context.start.row.saturating_sub(line_count);
- context.start.column = 0;
- context.end.row =
- (context.end.row + line_count).min(excerpt.buffer.max_point().row);
- context.end.column = excerpt.buffer.line_len(context.end.row);
- }
- }
- }
-
- Some(ExcerptRange {
- context,
- primary: excerpt.range.primary.to_point(&excerpt.buffer),
- })
- });
- let mut merged_ranges: Vec<ExcerptRange<Point>> = Vec::new();
- for range in expanded_ranges {
- if let Some(last_range) = merged_ranges.last_mut()
- && last_range.context.end >= range.context.start
- {
- last_range.context.end = range.context.end;
- continue;
- }
- merged_ranges.push(range)
- }
- let Some(excerpt_id) = excerpt_id_ else {
- continue;
- };
- let Some(buffer_id) = &snapshot.buffer_id_for_excerpt(excerpt_id) else {
- continue;
- };
-
- let Some(buffer) = self.buffers.get(buffer_id).map(|b| b.buffer.clone()) else {
- continue;
- };
-
- let buffer_snapshot = buffer.read(cx).snapshot();
- self.update_path_excerpts(path.clone(), buffer, &buffer_snapshot, merged_ranges, cx);
- }
- }
-
- /// Sets excerpts, returns `true` if at least one new excerpt was added.
- fn set_merged_excerpt_ranges_for_path(
- &mut self,
- path: PathKey,
- buffer: Entity<Buffer>,
- ranges: Vec<ExcerptRange<Point>>,
- buffer_snapshot: &BufferSnapshot,
- new: Vec<ExcerptRange<Point>>,
- counts: Vec<usize>,
- cx: &mut Context<Self>,
- ) -> (Vec<Range<Anchor>>, bool) {
- let (excerpt_ids, added_a_new_excerpt) =
- self.update_path_excerpts(path, buffer, buffer_snapshot, new, cx);
-
- let mut result = Vec::new();
- let mut ranges = ranges.into_iter();
- for (excerpt_id, range_count) in excerpt_ids.into_iter().zip(counts.into_iter()) {
- for range in ranges.by_ref().take(range_count) {
- let range = Anchor::range_in_buffer(
- excerpt_id,
- buffer_snapshot.anchor_before(&range.primary.start)
- ..buffer_snapshot.anchor_after(&range.primary.end),
- );
- result.push(range)
- }
- }
- (result, added_a_new_excerpt)
- }
-
- fn update_path_excerpts(
- &mut self,
- path: PathKey,
- buffer: Entity<Buffer>,
- buffer_snapshot: &BufferSnapshot,
- new: Vec<ExcerptRange<Point>>,
- cx: &mut Context<Self>,
- ) -> (Vec<ExcerptId>, bool) {
- let mut insert_after = self
- .excerpts_by_path
- .range(..path.clone())
- .next_back()
- .and_then(|(_, value)| value.last().copied())
- .unwrap_or(ExcerptId::min());
-
- let existing = self
- .excerpts_by_path
- .get(&path)
- .cloned()
- .unwrap_or_default();
- let mut new_iter = new.into_iter().peekable();
- let mut existing_iter = existing.into_iter().peekable();
-
- let mut excerpt_ids = Vec::new();
- let mut to_remove = Vec::new();
- let mut to_insert: Vec<(ExcerptId, ExcerptRange<Point>)> = Vec::new();
- let mut added_a_new_excerpt = false;
- let snapshot = self.snapshot(cx);
-
- let mut next_excerpt_id =
- // is this right? What if we remove the last excerpt, then we might reallocate with a wrong mapping?
- if let Some(last_entry) = self.snapshot.borrow().excerpt_ids.last() {
- last_entry.id.0 + 1
- } else {
- 1
- };
-
- let mut next_excerpt_id = move || ExcerptId(post_inc(&mut next_excerpt_id));
-
- let mut excerpts_cursor = snapshot.excerpts.cursor::<Option<&Locator>>(());
- excerpts_cursor.next();
-
- loop {
- let existing = if let Some(&existing_id) = existing_iter.peek() {
- let locator = snapshot.excerpt_locator_for_id(existing_id);
- excerpts_cursor.seek_forward(&Some(locator), Bias::Left);
- if let Some(excerpt) = excerpts_cursor.item() {
- if excerpt.buffer_id != buffer_snapshot.remote_id() {
- to_remove.push(existing_id);
- existing_iter.next();
- continue;
- }
- Some((existing_id, excerpt.range.context.to_point(buffer_snapshot)))
- } else {
- None
- }
- } else {
- None
- };
-
- let new = new_iter.peek();
- if let Some((last_id, last)) = to_insert.last_mut() {
- if let Some(new) = new
- && last.context.end >= new.context.start
- {
- last.context.end = last.context.end.max(new.context.end);
- excerpt_ids.push(*last_id);
- new_iter.next();
- continue;
- }
- if let Some((existing_id, existing_range)) = &existing
- && last.context.end >= existing_range.start
- {
- last.context.end = last.context.end.max(existing_range.end);
- to_remove.push(*existing_id);
- self.snapshot
- .get_mut()
- .replaced_excerpts
- .insert(*existing_id, *last_id);
- existing_iter.next();
- continue;
- }
- }
-
- match (new, existing) {
- (None, None) => break,
- (None, Some((existing_id, _))) => {
- existing_iter.next();
- to_remove.push(existing_id);
- continue;
- }
- (Some(_), None) => {
- added_a_new_excerpt = true;
- let new_id = next_excerpt_id();
- excerpt_ids.push(new_id);
- to_insert.push((new_id, new_iter.next().unwrap()));
- continue;
- }
- (Some(new), Some((_, existing_range))) => {
- if existing_range.end < new.context.start {
- let existing_id = existing_iter.next().unwrap();
- to_remove.push(existing_id);
- continue;
- } else if existing_range.start > new.context.end {
- let new_id = next_excerpt_id();
- excerpt_ids.push(new_id);
- to_insert.push((new_id, new_iter.next().unwrap()));
- continue;
- }
-
- if existing_range.start == new.context.start
- && existing_range.end == new.context.end
- {
- self.insert_excerpts_with_ids_after(
- insert_after,
- buffer.clone(),
- mem::take(&mut to_insert),
- cx,
- );
- insert_after = existing_iter.next().unwrap();
- excerpt_ids.push(insert_after);
- new_iter.next();
- } else {
- let existing_id = existing_iter.next().unwrap();
- let new_id = next_excerpt_id();
- self.snapshot
- .get_mut()
- .replaced_excerpts
- .insert(existing_id, new_id);
- to_remove.push(existing_id);
- let mut range = new_iter.next().unwrap();
- range.context.start = range.context.start.min(existing_range.start);
- range.context.end = range.context.end.max(existing_range.end);
- excerpt_ids.push(new_id);
- to_insert.push((new_id, range));
- }
- }
- };
- }
-
- self.insert_excerpts_with_ids_after(insert_after, buffer, to_insert, cx);
- // todo(lw): There is a logic bug somewhere that causes the to_remove vector to be not ordered correctly
- to_remove.sort_by_cached_key(|&id| snapshot.excerpt_locator_for_id(id));
- self.remove_excerpts(to_remove, cx);
-
- if excerpt_ids.is_empty() {
- self.excerpts_by_path.remove(&path);
- } else {
- for excerpt_id in &excerpt_ids {
- self.paths_by_excerpt.insert(*excerpt_id, path.clone());
- }
- let snapshot = &*self.snapshot.get_mut();
- let mut excerpt_ids: Vec<_> = excerpt_ids.iter().dedup().cloned().collect();
- excerpt_ids.sort_by_cached_key(|&id| snapshot.excerpt_locator_for_id(id));
- self.excerpts_by_path.insert(path, excerpt_ids);
- }
-
- (excerpt_ids, added_a_new_excerpt)
- }
-}
+use std::{mem, ops::Range, sync::Arc};
+
+use collections::HashSet;
+use gpui::{App, AppContext, Context, Entity};
+use itertools::Itertools;
+use language::{Buffer, BufferSnapshot};
+use rope::Point;
+use text::{Bias, BufferId, OffsetRangeExt, locator::Locator};
+use util::{post_inc, rel_path::RelPath};
+use ztracing::instrument;
+
+use crate::{
+ Anchor, ExcerptId, ExcerptRange, ExpandExcerptDirection, MultiBuffer, build_excerpt_ranges,
+};
+
+#[derive(PartialEq, Eq, Ord, PartialOrd, Clone, Hash, Debug)]
+pub struct PathKey {
+ // Used by the derived PartialOrd & Ord
+ pub sort_prefix: Option<u64>,
+ pub path: Arc<RelPath>,
+}
+
+impl PathKey {
+ pub fn with_sort_prefix(sort_prefix: u64, path: Arc<RelPath>) -> Self {
+ Self {
+ sort_prefix: Some(sort_prefix),
+ path,
+ }
+ }
+
+ pub fn for_buffer(buffer: &Entity<Buffer>, cx: &App) -> Self {
+ if let Some(file) = buffer.read(cx).file() {
+ Self::with_sort_prefix(file.worktree_id(cx).to_proto(), file.path().clone())
+ } else {
+ Self {
+ sort_prefix: None,
+ path: RelPath::unix(&buffer.entity_id().to_string())
+ .unwrap()
+ .into_arc(),
+ }
+ }
+ }
+}
+
+impl MultiBuffer {
+ pub fn paths(&self) -> impl Iterator<Item = PathKey> + '_ {
+ self.excerpts_by_path.keys().cloned()
+ }
+
+ pub fn remove_excerpts_for_path(&mut self, path: PathKey, cx: &mut Context<Self>) {
+ if let Some(to_remove) = self.excerpts_by_path.remove(&path) {
+ self.remove_excerpts(to_remove, cx)
+ }
+ if let Some(follower) = &self.follower {
+ follower.update(cx, |follower, cx| {
+ follower.remove_excerpts_for_path(path, cx);
+ });
+ }
+ }
+
+ pub fn location_for_path(&self, path: &PathKey, cx: &App) -> Option<Anchor> {
+ let excerpt_id = self.excerpts_by_path.get(path)?.first()?;
+ let snapshot = self.read(cx);
+ let excerpt = snapshot.excerpt(*excerpt_id)?;
+ Some(Anchor::in_buffer(excerpt.id, excerpt.range.context.start))
+ }
+
+ pub fn excerpt_paths(&self) -> impl Iterator<Item = &PathKey> {
+ self.excerpts_by_path.keys()
+ }
+
+ /// Sets excerpts, returns `true` if at least one new excerpt was added.
+ #[instrument(skip_all)]
+ pub fn set_excerpts_for_path(
+ &mut self,
+ path: PathKey,
+ buffer: Entity<Buffer>,
+ ranges: impl IntoIterator<Item = Range<Point>>,
+ context_line_count: u32,
+ cx: &mut Context<Self>,
+ ) -> (Vec<Range<Anchor>>, bool) {
+ let buffer_snapshot = buffer.read(cx).snapshot();
+ let excerpt_ranges = build_excerpt_ranges(ranges, context_line_count, &buffer_snapshot);
+
+ let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges);
+ self.set_merged_excerpt_ranges_for_path(
+ path,
+ buffer,
+ excerpt_ranges,
+ &buffer_snapshot,
+ new,
+ counts,
+ cx,
+ )
+ }
+
+ pub fn set_excerpt_ranges_for_path(
+ &mut self,
+ path: PathKey,
+ buffer: Entity<Buffer>,
+ buffer_snapshot: &BufferSnapshot,
+ excerpt_ranges: Vec<ExcerptRange<Point>>,
+ cx: &mut Context<Self>,
+ ) -> (Vec<Range<Anchor>>, bool) {
+ let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges);
+ self.set_merged_excerpt_ranges_for_path(
+ path,
+ buffer,
+ excerpt_ranges,
+ buffer_snapshot,
+ new,
+ counts,
+ cx,
+ )
+ }
+
+ pub fn set_anchored_excerpts_for_path(
+ &self,
+ path_key: PathKey,
+ buffer: Entity<Buffer>,
+ ranges: Vec<Range<text::Anchor>>,
+ context_line_count: u32,
+ cx: &Context<Self>,
+ ) -> impl Future<Output = Vec<Range<Anchor>>> + use<> {
+ let buffer_snapshot = buffer.read(cx).snapshot();
+ let multi_buffer = cx.weak_entity();
+ let mut app = cx.to_async();
+ async move {
+ let snapshot = buffer_snapshot.clone();
+ let (excerpt_ranges, new, counts) = app
+ .background_spawn(async move {
+ let ranges = ranges.into_iter().map(|range| range.to_point(&snapshot));
+ let excerpt_ranges =
+ build_excerpt_ranges(ranges, context_line_count, &snapshot);
+ let (new, counts) = Self::merge_excerpt_ranges(&excerpt_ranges);
+ (excerpt_ranges, new, counts)
+ })
+ .await;
+
+ multi_buffer
+ .update(&mut app, move |multi_buffer, cx| {
+ let (ranges, _) = multi_buffer.set_merged_excerpt_ranges_for_path(
+ path_key,
+ buffer,
+ excerpt_ranges,
+ &buffer_snapshot,
+ new,
+ counts,
+ cx,
+ );
+ ranges
+ })
+ .ok()
+ .unwrap_or_default()
+ }
+ }
+
+ pub fn remove_excerpts_for_buffer(&mut self, buffer: BufferId, cx: &mut Context<Self>) {
+ self.remove_excerpts(
+ self.excerpts_for_buffer(buffer, cx)
+ .into_iter()
+ .map(|(excerpt, _)| excerpt),
+ cx,
+ );
+ }
+
+ pub(super) fn expand_excerpts_with_paths(
+ &mut self,
+ ids: impl IntoIterator<Item = ExcerptId>,
+ line_count: u32,
+ direction: ExpandExcerptDirection,
+ cx: &mut Context<Self>,
+ ) {
+ let grouped = ids
+ .into_iter()
+ .chunk_by(|id| self.paths_by_excerpt.get(id).cloned())
+ .into_iter()
+ .filter_map(|(k, v)| Some((k?, v.into_iter().collect::<Vec<_>>())))
+ .collect::<Vec<_>>();
+ let snapshot = self.snapshot(cx);
+
+ for (path, ids) in grouped.into_iter() {
+ let Some(excerpt_ids) = self.excerpts_by_path.get(&path) else {
+ continue;
+ };
+
+ let ids_to_expand = HashSet::from_iter(ids);
+ let mut excerpt_id_ = None;
+ let expanded_ranges = excerpt_ids.iter().filter_map(|excerpt_id| {
+ let excerpt = snapshot.excerpt(*excerpt_id)?;
+ let excerpt_id = excerpt.id;
+ if excerpt_id_.is_none() {
+ excerpt_id_ = Some(excerpt_id);
+ }
+
+ let mut context = excerpt.range.context.to_point(&excerpt.buffer);
+ if ids_to_expand.contains(&excerpt_id) {
+ match direction {
+ ExpandExcerptDirection::Up => {
+ context.start.row = context.start.row.saturating_sub(line_count);
+ context.start.column = 0;
+ }
+ ExpandExcerptDirection::Down => {
+ context.end.row =
+ (context.end.row + line_count).min(excerpt.buffer.max_point().row);
+ context.end.column = excerpt.buffer.line_len(context.end.row);
+ }
+ ExpandExcerptDirection::UpAndDown => {
+ context.start.row = context.start.row.saturating_sub(line_count);
+ context.start.column = 0;
+ context.end.row =
+ (context.end.row + line_count).min(excerpt.buffer.max_point().row);
+ context.end.column = excerpt.buffer.line_len(context.end.row);
+ }
+ }
+ }
+
+ Some(ExcerptRange {
+ context,
+ primary: excerpt.range.primary.to_point(&excerpt.buffer),
+ })
+ });
+ let mut merged_ranges: Vec<ExcerptRange<Point>> = Vec::new();
+ for range in expanded_ranges {
+ if let Some(last_range) = merged_ranges.last_mut()
+ && last_range.context.end >= range.context.start
+ {
+ last_range.context.end = range.context.end;
+ continue;
+ }
+ merged_ranges.push(range)
+ }
+ let Some(excerpt_id) = excerpt_id_ else {
+ continue;
+ };
+ let Some(buffer_id) = &snapshot.buffer_id_for_excerpt(excerpt_id) else {
+ continue;
+ };
+
+ let Some(buffer) = self.buffers.get(buffer_id).map(|b| b.buffer.clone()) else {
+ continue;
+ };
+
+ let buffer_snapshot = buffer.read(cx).snapshot();
+ self.update_path_excerpts(path.clone(), buffer, &buffer_snapshot, merged_ranges, cx);
+ }
+ }
+
+ /// Sets excerpts, returns `true` if at least one new excerpt was added.
+ fn set_merged_excerpt_ranges_for_path(
+ &mut self,
+ path: PathKey,
+ buffer: Entity<Buffer>,
+ ranges: Vec<ExcerptRange<Point>>,
+ buffer_snapshot: &BufferSnapshot,
+ new: Vec<ExcerptRange<Point>>,
+ counts: Vec<usize>,
+ cx: &mut Context<Self>,
+ ) -> (Vec<Range<Anchor>>, bool) {
+ let (excerpt_ids, added_a_new_excerpt) =
+ self.update_path_excerpts(path, buffer, buffer_snapshot, new, cx);
+
+ let mut result = Vec::new();
+ let mut ranges = ranges.into_iter();
+ for (excerpt_id, range_count) in excerpt_ids.into_iter().zip(counts.into_iter()) {
+ for range in ranges.by_ref().take(range_count) {
+ let range = Anchor::range_in_buffer(
+ excerpt_id,
+ buffer_snapshot.anchor_before(&range.primary.start)
+ ..buffer_snapshot.anchor_after(&range.primary.end),
+ );
+ result.push(range)
+ }
+ }
+ (result, added_a_new_excerpt)
+ }
+
+ fn update_path_excerpts(
+ &mut self,
+ path: PathKey,
+ buffer: Entity<Buffer>,
+ buffer_snapshot: &BufferSnapshot,
+ new: Vec<ExcerptRange<Point>>,
+ cx: &mut Context<Self>,
+ ) -> (Vec<ExcerptId>, bool) {
+ let mut insert_after = self
+ .excerpts_by_path
+ .range(..path.clone())
+ .next_back()
+ .and_then(|(_, value)| value.last().copied())
+ .unwrap_or(ExcerptId::min());
+
+ let existing = self
+ .excerpts_by_path
+ .get(&path)
+ .cloned()
+ .unwrap_or_default();
+ let mut new_iter = new.into_iter().peekable();
+ let mut existing_iter = existing.into_iter().peekable();
+
+ let mut excerpt_ids = Vec::new();
+ let mut to_remove = Vec::new();
+ let mut to_insert: Vec<(ExcerptId, ExcerptRange<Point>)> = Vec::new();
+ let mut added_a_new_excerpt = false;
+ let snapshot = self.snapshot(cx);
+
+ let mut next_excerpt_id =
+ // todo(lw): is this right? What if we remove the last excerpt, then we might reallocate with a wrong mapping?
+ if let Some(last_entry) = self.snapshot.borrow().excerpt_ids.last() {
+ last_entry.id.0 + 1
+ } else {
+ 1
+ };
+
+ let mut next_excerpt_id = move || ExcerptId(post_inc(&mut next_excerpt_id));
+
+ let mut excerpts_cursor = snapshot.excerpts.cursor::<Option<&Locator>>(());
+ excerpts_cursor.next();
+
+ loop {
+ let existing = if let Some(&existing_id) = existing_iter.peek() {
+ let locator = snapshot.excerpt_locator_for_id(existing_id);
+ excerpts_cursor.seek_forward(&Some(locator), Bias::Left);
+ if let Some(excerpt) = excerpts_cursor.item() {
+ if excerpt.buffer_id != buffer_snapshot.remote_id() {
+ to_remove.push(existing_id);
+ existing_iter.next();
+ continue;
+ }
+ Some((existing_id, excerpt.range.context.to_point(buffer_snapshot)))
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ let new = new_iter.peek();
+ if let Some((last_id, last)) = to_insert.last_mut() {
+ if let Some(new) = new
+ && last.context.end >= new.context.start
+ {
+ last.context.end = last.context.end.max(new.context.end);
+ excerpt_ids.push(*last_id);
+ new_iter.next();
+ continue;
+ }
+ if let Some((existing_id, existing_range)) = &existing
+ && last.context.end >= existing_range.start
+ {
+ last.context.end = last.context.end.max(existing_range.end);
+ to_remove.push(*existing_id);
+ self.snapshot
+ .get_mut()
+ .replaced_excerpts
+ .insert(*existing_id, *last_id);
+ existing_iter.next();
+ continue;
+ }
+ }
+
+ match (new, existing) {
+ (None, None) => break,
+ (None, Some((existing_id, _))) => {
+ existing_iter.next();
+ to_remove.push(existing_id);
+ continue;
+ }
+ (Some(_), None) => {
+ added_a_new_excerpt = true;
+ let new_id = next_excerpt_id();
+ excerpt_ids.push(new_id);
+ to_insert.push((new_id, new_iter.next().unwrap()));
+ continue;
+ }
+ (Some(new), Some((_, existing_range))) => {
+ if existing_range.end < new.context.start {
+ let existing_id = existing_iter.next().unwrap();
+ to_remove.push(existing_id);
+ continue;
+ } else if existing_range.start > new.context.end {
+ let new_id = next_excerpt_id();
+ excerpt_ids.push(new_id);
+ to_insert.push((new_id, new_iter.next().unwrap()));
+ continue;
+ }
+
+ if existing_range.start == new.context.start
+ && existing_range.end == new.context.end
+ {
+ self.insert_excerpts_with_ids_after(
+ insert_after,
+ buffer.clone(),
+ mem::take(&mut to_insert),
+ cx,
+ );
+ insert_after = existing_iter.next().unwrap();
+ excerpt_ids.push(insert_after);
+ new_iter.next();
+ } else {
+ let existing_id = existing_iter.next().unwrap();
+ let new_id = next_excerpt_id();
+ self.snapshot
+ .get_mut()
+ .replaced_excerpts
+ .insert(existing_id, new_id);
+ to_remove.push(existing_id);
+ let mut range = new_iter.next().unwrap();
+ range.context.start = range.context.start.min(existing_range.start);
+ range.context.end = range.context.end.max(existing_range.end);
+ excerpt_ids.push(new_id);
+ to_insert.push((new_id, range));
+ }
+ }
+ };
+ }
+
+ self.insert_excerpts_with_ids_after(insert_after, buffer, to_insert, cx);
+ // todo(lw): There is a logic bug somewhere that causes the to_remove vector to be not ordered correctly
+ to_remove.sort_by_cached_key(|&id| snapshot.excerpt_locator_for_id(id));
+ self.remove_excerpts(to_remove, cx);
+
+ if excerpt_ids.is_empty() {
+ self.excerpts_by_path.remove(&path);
+ } else {
+ for excerpt_id in &excerpt_ids {
+ self.paths_by_excerpt.insert(*excerpt_id, path.clone());
+ }
+ let snapshot = &*self.snapshot.get_mut();
+ let mut excerpt_ids: Vec<_> = excerpt_ids.iter().dedup().cloned().collect();
+ excerpt_ids.sort_by_cached_key(|&id| snapshot.excerpt_locator_for_id(id));
+ self.excerpts_by_path.insert(path, excerpt_ids);
+ }
+
+ (excerpt_ids, added_a_new_excerpt)
+ }
+}
@@ -266,7 +266,8 @@ pub struct Request {
pub max_completion_tokens: Option<u64>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop: Vec<String>,
- pub temperature: f32,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub temperature: Option<f32>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Whether to enable parallel function calling during tool use.
@@ -391,7 +391,6 @@ mod tests {
use super::*;
use gpui::{TestAppContext, VisualTestContext};
use indoc::indoc;
- use language::{Language, LanguageConfig, LanguageMatcher};
use project::{FakeFs, Project};
use serde_json::json;
use util::{path, rel_path::rel_path};
@@ -418,7 +417,9 @@ mod tests {
.await;
let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
- project.read_with(cx, |project, _| project.languages().add(rust_lang()));
+ project.read_with(cx, |project, _| {
+ project.languages().add(language::rust_lang())
+ });
let (workspace, cx) =
cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
@@ -581,89 +582,6 @@ mod tests {
})
}
- fn rust_lang() -> Arc<Language> {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_outline_query(
- r#"(struct_item
- (visibility_modifier)? @context
- "struct" @context
- name: (_) @name) @item
-
- (enum_item
- (visibility_modifier)? @context
- "enum" @context
- name: (_) @name) @item
-
- (enum_variant
- (visibility_modifier)? @context
- name: (_) @name) @item
-
- (impl_item
- "impl" @context
- trait: (_)? @name
- "for"? @context
- type: (_) @name) @item
-
- (trait_item
- (visibility_modifier)? @context
- "trait" @context
- name: (_) @name) @item
-
- (function_item
- (visibility_modifier)? @context
- (function_modifiers)? @context
- "fn" @context
- name: (_) @name) @item
-
- (function_signature_item
- (visibility_modifier)? @context
- (function_modifiers)? @context
- "fn" @context
- name: (_) @name) @item
-
- (macro_definition
- . "macro_rules!" @context
- name: (_) @name) @item
-
- (mod_item
- (visibility_modifier)? @context
- "mod" @context
- name: (_) @name) @item
-
- (type_item
- (visibility_modifier)? @context
- "type" @context
- name: (_) @name) @item
-
- (associated_type
- "type" @context
- name: (_) @name) @item
-
- (const_item
- (visibility_modifier)? @context
- "const" @context
- name: (_) @name) @item
-
- (field_declaration
- (visibility_modifier)? @context
- name: (_) @name) @item
-"#,
- )
- .unwrap(),
- )
- }
-
#[track_caller]
fn assert_single_caret_at_row(
editor: &Entity<Editor>,
@@ -5220,7 +5220,7 @@ impl GenerationState {
mod tests {
use db::indoc;
use gpui::{TestAppContext, VisualTestContext, WindowHandle};
- use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
+ use language::rust_lang;
use pretty_assertions::assert_eq;
use project::FakeFs;
use search::{
@@ -5243,9 +5243,7 @@ mod tests {
let root = path!("/rust-analyzer");
populate_with_test_ra_project(&fs, root).await;
let project = Project::test(fs.clone(), [Path::new(root)], cx).await;
- project.read_with(cx, |project, _| {
- project.languages().add(Arc::new(rust_lang()))
- });
+ project.read_with(cx, |project, _| project.languages().add(rust_lang()));
let workspace = add_outline_panel(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let outline_panel = outline_panel(&workspace, cx);
@@ -5478,9 +5476,7 @@ mod tests {
let root = path!("/rust-analyzer");
populate_with_test_ra_project(&fs, root).await;
let project = Project::test(fs.clone(), [Path::new(root)], cx).await;
- project.read_with(cx, |project, _| {
- project.languages().add(Arc::new(rust_lang()))
- });
+ project.read_with(cx, |project, _| project.languages().add(rust_lang()));
let workspace = add_outline_panel(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let outline_panel = outline_panel(&workspace, cx);
@@ -5617,9 +5613,7 @@ mod tests {
let root = path!("/rust-analyzer");
populate_with_test_ra_project(&fs, root).await;
let project = Project::test(fs.clone(), [Path::new(root)], cx).await;
- project.read_with(cx, |project, _| {
- project.languages().add(Arc::new(rust_lang()))
- });
+ project.read_with(cx, |project, _| project.languages().add(rust_lang()));
let workspace = add_outline_panel(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let outline_panel = outline_panel(&workspace, cx);
@@ -5816,7 +5810,8 @@ mod tests {
outline_panel.selected_entry(),
cx,
),
- "fn_lifetime_fn.rs <==== selected"
+ "outline: pub(super) fn hints
+outline: fn hints_lifetimes_named <==== selected"
);
assert_eq!(
selected_row_text(&new_active_editor, cx),
@@ -6029,24 +6024,7 @@ struct OutlineEntryExcerpt {
)
.await;
let project = Project::test(fs.clone(), [Path::new(root)], cx).await;
- project.read_with(cx, |project, _| {
- project.languages().add(Arc::new(
- rust_lang()
- .with_outline_query(
- r#"
- (struct_item
- (visibility_modifier)? @context
- "struct" @context
- name: (_) @name) @item
-
- (field_declaration
- (visibility_modifier)? @context
- name: (_) @name) @item
-"#,
- )
- .unwrap(),
- ))
- });
+ project.read_with(cx, |project, _| project.languages().add(rust_lang()));
let workspace = add_outline_panel(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let outline_panel = outline_panel(&workspace, cx);
@@ -6601,11 +6579,13 @@ outline: struct OutlineEntryExcerpt
format!(
r#"frontend-project/
public/lottie/
- syntax-tree.json <==== selected
+ syntax-tree.json
+ search: {{ "something": "ยซstaticยป" }}
src/
app/(site)/
components/
- ErrorBoundary.tsx"#
+ ErrorBoundary.tsx <==== selected
+ search: ยซstaticยป"#
)
);
});
@@ -6647,7 +6627,7 @@ outline: struct OutlineEntryExcerpt
format!(
r#"frontend-project/
public/lottie/
- syntax-tree.json <==== selected
+ syntax-tree.json
search: {{ "something": "ยซstaticยป" }}
src/
app/(site)/
@@ -6658,7 +6638,7 @@ outline: struct OutlineEntryExcerpt
page.tsx
search: ยซstaticยป
components/
- ErrorBoundary.tsx
+ ErrorBoundary.tsx <==== selected
search: ยซstaticยป"#
)
);
@@ -6992,35 +6972,6 @@ outline: struct OutlineEntryExcerpt
.await;
}
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_highlights_query(
- r#"
- (field_identifier) @field
- (struct_expression) @struct
- "#,
- )
- .unwrap()
- .with_injection_query(
- r#"
- (macro_invocation
- (token_tree) @injection.content
- (#set! injection.language "rust"))
- "#,
- )
- .unwrap()
- }
-
fn snapshot(outline_panel: &OutlinePanel, cx: &App) -> MultiBufferSnapshot {
outline_panel
.active_editor()
@@ -7086,44 +7037,7 @@ outline: struct OutlineEntryExcerpt
.await;
let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await;
- project.read_with(cx, |project, _| {
- project.languages().add(Arc::new(
- rust_lang()
- .with_outline_query(
- r#"
- (struct_item
- (visibility_modifier)? @context
- "struct" @context
- name: (_) @name) @item
- (impl_item
- "impl" @context
- trait: (_)? @context
- "for"? @context
- type: (_) @context
- body: (_)) @item
- (function_item
- (visibility_modifier)? @context
- "fn" @context
- name: (_) @name
- parameters: (_) @context) @item
- (mod_item
- (visibility_modifier)? @context
- "mod" @context
- name: (_) @name) @item
- (enum_item
- (visibility_modifier)? @context
- "enum" @context
- name: (_) @name) @item
- (field_declaration
- (visibility_modifier)? @context
- name: (_) @name
- ":" @context
- type: (_) @context) @item
- "#,
- )
- .unwrap(),
- ))
- });
+ project.read_with(cx, |project, _| project.languages().add(rust_lang()));
let workspace = add_outline_panel(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let outline_panel = outline_panel(&workspace, cx);
@@ -7174,15 +7088,15 @@ outline: struct OutlineEntryExcerpt
"
outline: mod outer <==== selected
outline: pub struct OuterStruct
- outline: field: String
+ outline: field
outline: impl OuterStruct
- outline: pub fn new()
- outline: pub fn method(&self)
+ outline: pub fn new
+ outline: pub fn method
outline: mod inner
- outline: pub fn inner_function()
+ outline: pub fn inner_function
outline: pub struct InnerStruct
- outline: value: i32
-outline: fn main()"
+ outline: value
+outline: fn main"
)
);
});
@@ -7232,7 +7146,7 @@ outline: fn main()"
indoc!(
"
outline: mod outer <==== selected
-outline: fn main()"
+outline: fn main"
)
);
});
@@ -7257,15 +7171,15 @@ outline: fn main()"
"
outline: mod outer <==== selected
outline: pub struct OuterStruct
- outline: field: String
+ outline: field
outline: impl OuterStruct
- outline: pub fn new()
- outline: pub fn method(&self)
+ outline: pub fn new
+ outline: pub fn method
outline: mod inner
- outline: pub fn inner_function()
+ outline: pub fn inner_function
outline: pub struct InnerStruct
- outline: value: i32
-outline: fn main()"
+ outline: value
+outline: fn main"
)
);
});
@@ -7321,7 +7235,7 @@ outline: fn main()"
indoc!(
"
outline: mod outer
-outline: fn main()"
+outline: fn main"
)
);
});
@@ -7378,44 +7292,7 @@ outline: fn main()"
.await;
let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await;
- project.read_with(cx, |project, _| {
- project.languages().add(Arc::new(
- rust_lang()
- .with_outline_query(
- r#"
- (struct_item
- (visibility_modifier)? @context
- "struct" @context
- name: (_) @name) @item
- (impl_item
- "impl" @context
- trait: (_)? @context
- "for"? @context
- type: (_) @context
- body: (_)) @item
- (function_item
- (visibility_modifier)? @context
- "fn" @context
- name: (_) @name
- parameters: (_) @context) @item
- (mod_item
- (visibility_modifier)? @context
- "mod" @context
- name: (_) @name) @item
- (enum_item
- (visibility_modifier)? @context
- "enum" @context
- name: (_) @name) @item
- (field_declaration
- (visibility_modifier)? @context
- name: (_) @name
- ":" @context
- type: (_) @context) @item
- "#,
- )
- .unwrap(),
- ))
- });
+ project.read_with(cx, |project, _| project.languages().add(rust_lang()));
let workspace = add_outline_panel(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
@@ -7462,14 +7339,16 @@ outline: fn main()"
indoc!(
"
outline: struct Config
- outline: name: String
- outline: value: i32
+ outline: name
+ outline: value
outline: impl Config
- outline: fn new(name: String)
- outline: fn get_value(&self)
+ outline: fn new
+ outline: fn get_value
outline: enum Status
-outline: fn process_config(config: Config)
-outline: fn main()"
+ outline: Active
+ outline: Inactive
+outline: fn process_config
+outline: fn main"
)
);
});
@@ -7500,14 +7379,16 @@ outline: fn main()"
indoc!(
"
outline: struct Config <==== selected
- outline: name: String
- outline: value: i32
+ outline: name
+ outline: value
outline: impl Config
- outline: fn new(name: String)
- outline: fn get_value(&self)
+ outline: fn new
+ outline: fn get_value
outline: enum Status
-outline: fn process_config(config: Config)
-outline: fn main()"
+ outline: Active
+ outline: Inactive
+outline: fn process_config
+outline: fn main"
)
);
});
@@ -7535,11 +7416,13 @@ outline: fn main()"
"
outline: struct Config <==== selected
outline: impl Config
- outline: fn new(name: String)
- outline: fn get_value(&self)
+ outline: fn new
+ outline: fn get_value
outline: enum Status
-outline: fn process_config(config: Config)
-outline: fn main()"
+ outline: Active
+ outline: Inactive
+outline: fn process_config
+outline: fn main"
)
);
});
@@ -7566,14 +7449,16 @@ outline: fn main()"
indoc!(
"
outline: struct Config <==== selected
- outline: name: String
- outline: value: i32
+ outline: name
+ outline: value
outline: impl Config
- outline: fn new(name: String)
- outline: fn get_value(&self)
+ outline: fn new
+ outline: fn get_value
outline: enum Status
-outline: fn process_config(config: Config)
-outline: fn main()"
+ outline: Active
+ outline: Inactive
+outline: fn process_config
+outline: fn main"
)
);
});
@@ -7622,44 +7507,7 @@ outline: fn main()"
.await;
let project = Project::test(fs.clone(), ["/test".as_ref()], cx).await;
- project.read_with(cx, |project, _| {
- project.languages().add(Arc::new(
- rust_lang()
- .with_outline_query(
- r#"
- (struct_item
- (visibility_modifier)? @context
- "struct" @context
- name: (_) @name) @item
- (impl_item
- "impl" @context
- trait: (_)? @context
- "for"? @context
- type: (_) @context
- body: (_)) @item
- (function_item
- (visibility_modifier)? @context
- "fn" @context
- name: (_) @name
- parameters: (_) @context) @item
- (mod_item
- (visibility_modifier)? @context
- "mod" @context
- name: (_) @name) @item
- (enum_item
- (visibility_modifier)? @context
- "enum" @context
- name: (_) @name) @item
- (field_declaration
- (visibility_modifier)? @context
- name: (_) @name
- ":" @context
- type: (_) @context) @item
- "#,
- )
- .unwrap(),
- ))
- });
+ project.read_with(cx, |project, _| project.languages().add(rust_lang()));
let workspace = add_outline_panel(&project, cx).await;
let cx = &mut VisualTestContext::from_window(*workspace, cx);
let outline_panel = outline_panel(&workspace, cx);
@@ -7710,15 +7558,15 @@ outline: fn main()"
"
outline: mod outer <==== selected
outline: pub struct OuterStruct
- outline: field: String
+ outline: field
outline: impl OuterStruct
- outline: pub fn new()
- outline: pub fn method(&self)
+ outline: pub fn new
+ outline: pub fn method
outline: mod inner
- outline: pub fn inner_function()
+ outline: pub fn inner_function
outline: pub struct InnerStruct
- outline: value: i32
-outline: fn main()"
+ outline: value
+outline: fn main"
)
);
});
@@ -7759,7 +7607,7 @@ outline: fn main()"
let expected_collapsed_output = indoc!(
"
outline: mod outer <==== selected
- outline: fn main()"
+ outline: fn main"
);
outline_panel.update(cx, |panel, cx| {
@@ -7787,15 +7635,15 @@ outline: fn main()"
"
outline: mod outer <==== selected
outline: pub struct OuterStruct
- outline: field: String
+ outline: field
outline: impl OuterStruct
- outline: pub fn new()
- outline: pub fn method(&self)
+ outline: pub fn new
+ outline: pub fn method
outline: mod inner
- outline: pub fn inner_function()
+ outline: pub fn inner_function
outline: pub struct InnerStruct
- outline: value: i32
- outline: fn main()"
+ outline: value
+ outline: fn main"
);
outline_panel.update(cx, |panel, cx| {
@@ -91,6 +91,8 @@ which.workspace = true
worktree.workspace = true
zeroize.workspace = true
zlog.workspace = true
+ztracing.workspace = true
+tracing.workspace = true
[dev-dependencies]
client = { workspace = true, features = ["test-support"] }
@@ -113,3 +115,6 @@ snippet_provider = { workspace = true, features = ["test-support"] }
unindent.workspace = true
util = { workspace = true, features = ["test-support"] }
worktree = { workspace = true, features = ["test-support"] }
+
+[package.metadata.cargo-machete]
+ignored = ["tracing"]
@@ -1365,7 +1365,7 @@ impl ExternalAgentServer for LocalCodex {
&mut self,
root_dir: Option<&str>,
extra_env: HashMap<String, String>,
- status_tx: Option<watch::Sender<SharedString>>,
+ mut status_tx: Option<watch::Sender<SharedString>>,
_new_version_available_tx: Option<watch::Sender<Option<String>>>,
cx: &mut AsyncApp,
) -> Task<Result<(AgentServerCommand, String, Option<task::SpawnInTerminal>)>> {
@@ -1402,58 +1402,115 @@ impl ExternalAgentServer for LocalCodex {
let dir = paths::external_agents_dir().join(CODEX_NAME);
fs.create_dir(&dir).await?;
- // Find or install the latest Codex release (no update checks for now).
- let release = ::http_client::github::latest_github_release(
+ let bin_name = if cfg!(windows) {
+ "codex-acp.exe"
+ } else {
+ "codex-acp"
+ };
+
+ let find_latest_local_version = async || -> Option<PathBuf> {
+ let mut local_versions: Vec<(semver::Version, String)> = Vec::new();
+ let mut stream = fs.read_dir(&dir).await.ok()?;
+ while let Some(entry) = stream.next().await {
+ let Ok(entry) = entry else { continue };
+ let Some(file_name) = entry.file_name() else {
+ continue;
+ };
+ let version_path = dir.join(&file_name);
+ if fs.is_file(&version_path.join(bin_name)).await {
+ let version_str = file_name.to_string_lossy();
+ if let Ok(version) =
+ semver::Version::from_str(version_str.trim_start_matches('v'))
+ {
+ local_versions.push((version, version_str.into_owned()));
+ }
+ }
+ }
+ local_versions.sort_by(|(a, _), (b, _)| a.cmp(b));
+ local_versions.last().map(|(_, v)| dir.join(v))
+ };
+
+ let fallback_to_latest_local_version =
+ async |err: anyhow::Error| -> Result<PathBuf, anyhow::Error> {
+ if let Some(local) = find_latest_local_version().await {
+ log::info!(
+ "Falling back to locally installed Codex version: {}",
+ local.display()
+ );
+ Ok(local)
+ } else {
+ Err(err)
+ }
+ };
+
+ let version_dir = match ::http_client::github::latest_github_release(
CODEX_ACP_REPO,
true,
false,
http.clone(),
)
.await
- .context("fetching Codex latest release")?;
-
- let version_dir = dir.join(&release.tag_name);
- if !fs.is_dir(&version_dir).await {
- if let Some(mut status_tx) = status_tx {
- status_tx.send("Installingโฆ".into()).ok();
- }
+ {
+ Ok(release) => {
+ let version_dir = dir.join(&release.tag_name);
+ if !fs.is_dir(&version_dir).await {
+ if let Some(ref mut status_tx) = status_tx {
+ status_tx.send("Installingโฆ".into()).ok();
+ }
- let tag = release.tag_name.clone();
- let version_number = tag.trim_start_matches('v');
- let asset_name = asset_name(version_number)
- .context("codex acp is not supported for this architecture")?;
- let asset = release
- .assets
- .into_iter()
- .find(|asset| asset.name == asset_name)
- .with_context(|| format!("no asset found matching `{asset_name:?}`"))?;
- // Strip "sha256:" prefix from digest if present (GitHub API format)
- let digest = asset
- .digest
- .as_deref()
- .and_then(|d| d.strip_prefix("sha256:").or(Some(d)));
- ::http_client::github_download::download_server_binary(
- &*http,
- &asset.browser_download_url,
- digest,
- &version_dir,
- if cfg!(target_os = "windows") && cfg!(target_arch = "x86_64") {
- AssetKind::Zip
+ let tag = release.tag_name.clone();
+ let version_number = tag.trim_start_matches('v');
+ let asset_name = asset_name(version_number)
+ .context("codex acp is not supported for this architecture")?;
+ let asset = release
+ .assets
+ .into_iter()
+ .find(|asset| asset.name == asset_name)
+ .with_context(|| {
+ format!("no asset found matching `{asset_name:?}`")
+ })?;
+ // Strip "sha256:" prefix from digest if present (GitHub API format)
+ let digest = asset
+ .digest
+ .as_deref()
+ .and_then(|d| d.strip_prefix("sha256:").or(Some(d)));
+ match ::http_client::github_download::download_server_binary(
+ &*http,
+ &asset.browser_download_url,
+ digest,
+ &version_dir,
+ if cfg!(target_os = "windows") && cfg!(target_arch = "x86_64") {
+ AssetKind::Zip
+ } else {
+ AssetKind::TarGz
+ },
+ )
+ .await
+ {
+ Ok(()) => {
+ // remove older versions
+ util::fs::remove_matching(&dir, |entry| entry != version_dir)
+ .await;
+ version_dir
+ }
+ Err(err) => {
+ log::error!(
+ "Failed to download Codex release {}: {err:#}",
+ release.tag_name
+ );
+ fallback_to_latest_local_version(err).await?
+ }
+ }
} else {
- AssetKind::TarGz
- },
- )
- .await?;
-
- // remove older versions
- util::fs::remove_matching(&dir, |entry| entry != version_dir).await;
- }
-
- let bin_name = if cfg!(windows) {
- "codex-acp.exe"
- } else {
- "codex-acp"
+ version_dir
+ }
+ }
+ Err(err) => {
+ log::error!("Failed to fetch Codex latest release: {err:#}");
+ fallback_to_latest_local_version(err).await?
+ }
};
+
let bin_path = version_dir.join(bin_name);
anyhow::ensure!(
fs.is_file(&bin_path).await,
@@ -1501,8 +1558,8 @@ fn get_platform_info() -> Option<(&'static str, &'static str, &'static str)> {
return None;
};
- // Only Windows x86_64 uses .zip in release assets
- let ext = if cfg!(target_os = "windows") && cfg!(target_arch = "x86_64") {
+ // Windows uses .zip in release assets
+ let ext = if cfg!(target_os = "windows") {
"zip"
} else {
"tar.gz"
@@ -692,7 +692,7 @@ impl DapStore {
}
VariableLookupKind::Expression => {
let Ok(eval_task) = session.read_with(cx, |session, _| {
- session.mode.request_dap(EvaluateCommand {
+ session.state.request_dap(EvaluateCommand {
expression: inline_value_location.variable_name.clone(),
frame_id: Some(stack_frame_id),
source: None,
@@ -1,7 +1,3 @@
-use crate::debugger::breakpoint_store::BreakpointSessionState;
-use crate::debugger::dap_command::{DataBreakpointContext, ReadMemory};
-use crate::debugger::memory::{self, Memory, MemoryIterator, MemoryPageBuilder, PageAddress};
-
use super::breakpoint_store::{
BreakpointStore, BreakpointStoreEvent, BreakpointUpdatedReason, SourceBreakpoint,
};
@@ -14,6 +10,9 @@ use super::dap_command::{
TerminateCommand, TerminateThreadsCommand, ThreadsCommand, VariablesCommand,
};
use super::dap_store::DapStore;
+use crate::debugger::breakpoint_store::BreakpointSessionState;
+use crate::debugger::dap_command::{DataBreakpointContext, ReadMemory};
+use crate::debugger::memory::{self, Memory, MemoryIterator, MemoryPageBuilder, PageAddress};
use anyhow::{Context as _, Result, anyhow, bail};
use base64::Engine;
use collections::{HashMap, HashSet, IndexMap};
@@ -42,15 +41,13 @@ use gpui::{
Task, WeakEntity,
};
use http_client::HttpClient;
-
use node_runtime::NodeRuntime;
use remote::RemoteClient;
-use rpc::ErrorExt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use smol::net::{TcpListener, TcpStream};
use std::any::TypeId;
-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, VecDeque};
use std::net::Ipv4Addr;
use std::ops::RangeInclusive;
use std::path::PathBuf;
@@ -71,6 +68,9 @@ use util::command::new_smol_command;
use util::{ResultExt, debug_panic, maybe};
use worktree::Worktree;
+const MAX_TRACKED_OUTPUT_EVENTS: usize = 5000;
+const DEBUG_HISTORY_LIMIT: usize = 10;
+
#[derive(Debug, Copy, Clone, Hash, PartialEq, PartialOrd, Ord, Eq)]
#[repr(transparent)]
pub struct ThreadId(pub i64);
@@ -118,11 +118,11 @@ impl ThreadStatus {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Thread {
dap: dap::Thread,
stack_frames: Vec<StackFrame>,
- stack_frames_error: Option<anyhow::Error>,
+ stack_frames_error: Option<SharedString>,
_has_stopped: bool,
}
@@ -672,7 +672,18 @@ impl ThreadStates {
.any(|status| *status == ThreadStatus::Stopped)
}
}
-const MAX_TRACKED_OUTPUT_EVENTS: usize = 5000;
+
+// TODO(debugger): Wrap dap types with reference counting so the UI doesn't have to clone them on refresh
+#[derive(Default)]
+pub struct SessionSnapshot {
+ threads: IndexMap<ThreadId, Thread>,
+ thread_states: ThreadStates,
+ variables: HashMap<VariableReference, Vec<dap::Variable>>,
+ stack_frames: IndexMap<StackFrameId, StackFrame>,
+ locations: HashMap<u64, dap::LocationsResponse>,
+ modules: Vec<dap::Module>,
+ loaded_sources: Vec<dap::Source>,
+}
type IsEnabled = bool;
@@ -680,23 +691,19 @@ type IsEnabled = bool;
pub struct OutputToken(pub usize);
/// Represents a current state of a single debug adapter and provides ways to mutate it.
pub struct Session {
- pub mode: SessionState,
+ pub state: SessionState,
+ active_snapshot: SessionSnapshot,
+ snapshots: VecDeque<SessionSnapshot>,
+ selected_snapshot_index: Option<usize>,
id: SessionId,
label: Option<SharedString>,
adapter: DebugAdapterName,
pub(super) capabilities: Capabilities,
child_session_ids: HashSet<SessionId>,
parent_session: Option<Entity<Session>>,
- modules: Vec<dap::Module>,
- loaded_sources: Vec<dap::Source>,
output_token: OutputToken,
output: Box<circular_buffer::CircularBuffer<MAX_TRACKED_OUTPUT_EVENTS, dap::OutputEvent>>,
- threads: IndexMap<ThreadId, Thread>,
- thread_states: ThreadStates,
watchers: HashMap<SharedString, Watcher>,
- variables: HashMap<VariableReference, Vec<dap::Variable>>,
- stack_frames: IndexMap<StackFrameId, StackFrame>,
- locations: HashMap<u64, dap::LocationsResponse>,
is_session_terminated: bool,
requests: HashMap<TypeId, HashMap<RequestSlot, Shared<Task<Option<()>>>>>,
pub(crate) breakpoint_store: Entity<BreakpointStore>,
@@ -801,6 +808,7 @@ pub enum SessionEvent {
},
DataBreakpointInfo,
ConsoleOutput,
+ HistoricSnapshotSelected,
}
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -858,24 +866,20 @@ impl Session {
.detach();
Self {
- mode: SessionState::Booting(None),
+ state: SessionState::Booting(None),
+ snapshots: VecDeque::with_capacity(DEBUG_HISTORY_LIMIT),
+ selected_snapshot_index: None,
+ active_snapshot: Default::default(),
id: session_id,
child_session_ids: HashSet::default(),
parent_session,
capabilities: Capabilities::default(),
watchers: HashMap::default(),
- variables: Default::default(),
- stack_frames: Default::default(),
- thread_states: ThreadStates::default(),
output_token: OutputToken(0),
output: circular_buffer::CircularBuffer::boxed(),
requests: HashMap::default(),
- modules: Vec::default(),
- loaded_sources: Vec::default(),
- threads: IndexMap::default(),
background_tasks: Vec::default(),
restart_task: None,
- locations: Default::default(),
is_session_terminated: false,
ignore_breakpoints: false,
breakpoint_store,
@@ -899,7 +903,7 @@ impl Session {
}
pub fn worktree(&self) -> Option<Entity<Worktree>> {
- match &self.mode {
+ match &self.state {
SessionState::Booting(_) => None,
SessionState::Running(local_mode) => local_mode.worktree.upgrade(),
}
@@ -960,7 +964,7 @@ impl Session {
)
.await?;
this.update(cx, |this, cx| {
- match &mut this.mode {
+ match &mut this.state {
SessionState::Booting(task) if task.is_some() => {
task.take().unwrap().detach_and_log_err(cx);
}
@@ -969,7 +973,7 @@ impl Session {
debug_panic!("Attempting to boot a session that is already running");
}
};
- this.mode = SessionState::Running(mode);
+ this.state = SessionState::Running(mode);
cx.emit(SessionStateEvent::Running);
})?;
@@ -1061,7 +1065,7 @@ impl Session {
}
pub fn binary(&self) -> Option<&DebugAdapterBinary> {
- match &self.mode {
+ match &self.state {
SessionState::Booting(_) => None,
SessionState::Running(running_mode) => Some(&running_mode.binary),
}
@@ -1107,25 +1111,25 @@ impl Session {
}
pub fn is_started(&self) -> bool {
- match &self.mode {
+ match &self.state {
SessionState::Booting(_) => false,
SessionState::Running(running) => running.is_started,
}
}
pub fn is_building(&self) -> bool {
- matches!(self.mode, SessionState::Booting(_))
+ matches!(self.state, SessionState::Booting(_))
}
pub fn as_running_mut(&mut self) -> Option<&mut RunningMode> {
- match &mut self.mode {
+ match &mut self.state {
SessionState::Running(local_mode) => Some(local_mode),
SessionState::Booting(_) => None,
}
}
pub fn as_running(&self) -> Option<&RunningMode> {
- match &self.mode {
+ match &self.state {
SessionState::Running(local_mode) => Some(local_mode),
SessionState::Booting(_) => None,
}
@@ -1269,7 +1273,7 @@ impl Session {
let adapter_id = self.adapter().to_string();
let request = Initialize { adapter_id };
- let SessionState::Running(running) = &self.mode else {
+ let SessionState::Running(running) = &self.state else {
return Task::ready(Err(anyhow!(
"Cannot send initialize request, task still building"
)));
@@ -1317,7 +1321,7 @@ impl Session {
dap_store: WeakEntity<DapStore>,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
- match &self.mode {
+ match &self.state {
SessionState::Running(local_mode) => {
local_mode.initialize_sequence(&self.capabilities, initialize_rx, dap_store, cx)
}
@@ -1333,10 +1337,12 @@ impl Session {
active_thread_id: ThreadId,
cx: &mut Context<Self>,
) {
- match &mut self.mode {
+ match &mut self.state {
SessionState::Running(local_mode) => {
if !matches!(
- self.thread_states.thread_state(active_thread_id),
+ self.active_snapshot
+ .thread_states
+ .thread_state(active_thread_id),
Some(ThreadStatus::Stopped)
) {
return;
@@ -1411,8 +1417,55 @@ impl Session {
})
}
+ fn session_state(&self) -> &SessionSnapshot {
+ self.selected_snapshot_index
+ .and_then(|ix| self.snapshots.get(ix))
+ .unwrap_or_else(|| &self.active_snapshot)
+ }
+
+ fn push_to_history(&mut self) {
+ if !self.has_ever_stopped() {
+ return;
+ }
+
+ while self.snapshots.len() >= DEBUG_HISTORY_LIMIT {
+ self.snapshots.pop_front();
+ }
+
+ self.snapshots
+ .push_back(std::mem::take(&mut self.active_snapshot));
+ }
+
+ pub fn historic_snapshots(&self) -> &VecDeque<SessionSnapshot> {
+ &self.snapshots
+ }
+
+ pub fn select_historic_snapshot(&mut self, ix: Option<usize>, cx: &mut Context<Session>) {
+ if self.selected_snapshot_index == ix {
+ return;
+ }
+
+ if self
+ .selected_snapshot_index
+ .is_some_and(|ix| self.snapshots.len() <= ix)
+ {
+ debug_panic!("Attempted to select a debug session with an out of bounds index");
+ return;
+ }
+
+ self.selected_snapshot_index = ix;
+ cx.emit(SessionEvent::HistoricSnapshotSelected);
+ cx.notify();
+ }
+
+ pub fn active_snapshot_index(&self) -> Option<usize> {
+ self.selected_snapshot_index
+ }
+
fn handle_stopped_event(&mut self, event: StoppedEvent, cx: &mut Context<Self>) {
- self.mode.stopped();
+ self.push_to_history();
+
+ self.state.stopped();
// todo(debugger): Find a clean way to get around the clone
let breakpoint_store = self.breakpoint_store.clone();
if let Some((local, path)) = self.as_running_mut().and_then(|local| {
@@ -1431,14 +1484,16 @@ impl Session {
};
if event.all_threads_stopped.unwrap_or_default() || event.thread_id.is_none() {
- self.thread_states.stop_all_threads();
+ self.active_snapshot.thread_states.stop_all_threads();
self.invalidate_command_type::<StackTraceCommand>();
}
// Event if we stopped all threads we still need to insert the thread_id
// to our own data
if let Some(thread_id) = event.thread_id {
- self.thread_states.stop_thread(ThreadId(thread_id));
+ self.active_snapshot
+ .thread_states
+ .stop_thread(ThreadId(thread_id));
self.invalidate_state(
&StackTraceCommand {
@@ -1451,8 +1506,8 @@ impl Session {
}
self.invalidate_generic();
- self.threads.clear();
- self.variables.clear();
+ self.active_snapshot.threads.clear();
+ self.active_snapshot.variables.clear();
cx.emit(SessionEvent::Stopped(
event
.thread_id
@@ -1474,12 +1529,13 @@ impl Session {
Events::Stopped(event) => self.handle_stopped_event(event, cx),
Events::Continued(event) => {
if event.all_threads_continued.unwrap_or_default() {
- self.thread_states.continue_all_threads();
+ self.active_snapshot.thread_states.continue_all_threads();
self.breakpoint_store.update(cx, |store, cx| {
store.remove_active_position(Some(self.session_id()), cx)
});
} else {
- self.thread_states
+ self.active_snapshot
+ .thread_states
.continue_thread(ThreadId(event.thread_id));
}
// todo(debugger): We should be able to get away with only invalidating generic if all threads were continued
@@ -1496,10 +1552,12 @@ impl Session {
match event.reason {
dap::ThreadEventReason::Started => {
- self.thread_states.continue_thread(thread_id);
+ self.active_snapshot
+ .thread_states
+ .continue_thread(thread_id);
}
dap::ThreadEventReason::Exited => {
- self.thread_states.exit_thread(thread_id);
+ self.active_snapshot.thread_states.exit_thread(thread_id);
}
reason => {
log::error!("Unhandled thread event reason {:?}", reason);
@@ -1526,10 +1584,11 @@ impl Session {
Events::Module(event) => {
match event.reason {
dap::ModuleEventReason::New => {
- self.modules.push(event.module);
+ self.active_snapshot.modules.push(event.module);
}
dap::ModuleEventReason::Changed => {
if let Some(module) = self
+ .active_snapshot
.modules
.iter_mut()
.find(|other| event.module.id == other.id)
@@ -1538,7 +1597,9 @@ impl Session {
}
}
dap::ModuleEventReason::Removed => {
- self.modules.retain(|other| event.module.id != other.id);
+ self.active_snapshot
+ .modules
+ .retain(|other| event.module.id != other.id);
}
}
@@ -1612,8 +1673,9 @@ impl Session {
);
}
- if !self.thread_states.any_stopped_thread()
- && request.type_id() != TypeId::of::<ThreadsCommand>()
+ if (!self.active_snapshot.thread_states.any_stopped_thread()
+ && request.type_id() != TypeId::of::<ThreadsCommand>())
+ || self.selected_snapshot_index.is_some()
|| self.is_session_terminated
{
return;
@@ -1629,7 +1691,7 @@ impl Session {
let task = Self::request_inner::<Arc<T>>(
&self.capabilities,
- &self.mode,
+ &self.state,
command,
|this, result, cx| {
process_result(this, result, cx);
@@ -1697,7 +1759,7 @@ impl Session {
+ 'static,
cx: &mut Context<Self>,
) -> Task<Option<T::Response>> {
- Self::request_inner(&self.capabilities, &self.mode, request, process_result, cx)
+ Self::request_inner(&self.capabilities, &self.state, request, process_result, cx)
}
fn invalidate_command_type<Command: LocalDapCommand>(&mut self) {
@@ -1730,11 +1792,11 @@ impl Session {
}
pub fn any_stopped_thread(&self) -> bool {
- self.thread_states.any_stopped_thread()
+ self.active_snapshot.thread_states.any_stopped_thread()
}
pub fn thread_status(&self, thread_id: ThreadId) -> ThreadStatus {
- self.thread_states.thread_status(thread_id)
+ self.active_snapshot.thread_states.thread_status(thread_id)
}
pub fn threads(&mut self, cx: &mut Context<Self>) -> Vec<(dap::Thread, ThreadStatus)> {
@@ -1745,7 +1807,7 @@ impl Session {
return;
};
- this.threads = result
+ this.active_snapshot.threads = result
.into_iter()
.map(|thread| (ThreadId(thread.id), Thread::from(thread)))
.collect();
@@ -1757,12 +1819,14 @@ impl Session {
cx,
);
- self.threads
+ let state = self.session_state();
+ state
+ .threads
.values()
.map(|thread| {
(
thread.dap.clone(),
- self.thread_states.thread_status(ThreadId(thread.dap.id)),
+ state.thread_states.thread_status(ThreadId(thread.dap.id)),
)
})
.collect()
@@ -1776,14 +1840,14 @@ impl Session {
return;
};
- this.modules = result;
+ this.active_snapshot.modules = result;
cx.emit(SessionEvent::Modules);
cx.notify();
},
cx,
);
- &self.modules
+ &self.session_state().modules
}
// CodeLLDB returns the size of a pointed-to-memory, which we can use to make the experience of go-to-memory better.
@@ -2034,14 +2098,13 @@ impl Session {
let Some(result) = result.log_err() else {
return;
};
- this.loaded_sources = result;
+ this.active_snapshot.loaded_sources = result;
cx.emit(SessionEvent::LoadedSources);
cx.notify();
},
cx,
);
-
- &self.loaded_sources
+ &self.session_state().loaded_sources
}
fn fallback_to_manual_restart(
@@ -2073,7 +2136,7 @@ impl Session {
Some(response)
}
None => {
- this.thread_states.stop_thread(thread_id);
+ this.active_snapshot.thread_states.stop_thread(thread_id);
cx.notify();
None
}
@@ -2149,10 +2212,10 @@ impl Session {
}
self.is_session_terminated = true;
- self.thread_states.exit_all_threads();
+ self.active_snapshot.thread_states.exit_all_threads();
cx.notify();
- let task = match &mut self.mode {
+ let task = match &mut self.state {
SessionState::Running(_) => {
if self
.capabilities
@@ -2213,9 +2276,13 @@ impl Session {
}
pub fn continue_thread(&mut self, thread_id: ThreadId, cx: &mut Context<Self>) {
+ self.select_historic_snapshot(None, cx);
+
let supports_single_thread_execution_requests =
self.capabilities.supports_single_thread_execution_requests;
- self.thread_states.continue_thread(thread_id);
+ self.active_snapshot
+ .thread_states
+ .continue_thread(thread_id);
self.request(
ContinueCommand {
args: ContinueArguments {
@@ -2230,21 +2297,24 @@ impl Session {
}
pub fn adapter_client(&self) -> Option<Arc<DebugAdapterClient>> {
- match self.mode {
+ match self.state {
SessionState::Running(ref local) => Some(local.client.clone()),
SessionState::Booting(_) => None,
}
}
pub fn has_ever_stopped(&self) -> bool {
- self.mode.has_ever_stopped()
+ self.state.has_ever_stopped()
}
+
pub fn step_over(
&mut self,
thread_id: ThreadId,
granularity: SteppingGranularity,
cx: &mut Context<Self>,
) {
+ self.select_historic_snapshot(None, cx);
+
let supports_single_thread_execution_requests =
self.capabilities.supports_single_thread_execution_requests;
let supports_stepping_granularity = self
@@ -2260,7 +2330,7 @@ impl Session {
},
};
- self.thread_states.process_step(thread_id);
+ self.active_snapshot.thread_states.process_step(thread_id);
self.request(
command,
Self::on_step_response::<NextCommand>(thread_id),
@@ -2275,6 +2345,8 @@ impl Session {
granularity: SteppingGranularity,
cx: &mut Context<Self>,
) {
+ self.select_historic_snapshot(None, cx);
+
let supports_single_thread_execution_requests =
self.capabilities.supports_single_thread_execution_requests;
let supports_stepping_granularity = self
@@ -2290,7 +2362,7 @@ impl Session {
},
};
- self.thread_states.process_step(thread_id);
+ self.active_snapshot.thread_states.process_step(thread_id);
self.request(
command,
Self::on_step_response::<StepInCommand>(thread_id),
@@ -2305,6 +2377,8 @@ impl Session {
granularity: SteppingGranularity,
cx: &mut Context<Self>,
) {
+ self.select_historic_snapshot(None, cx);
+
let supports_single_thread_execution_requests =
self.capabilities.supports_single_thread_execution_requests;
let supports_stepping_granularity = self
@@ -2320,7 +2394,7 @@ impl Session {
},
};
- self.thread_states.process_step(thread_id);
+ self.active_snapshot.thread_states.process_step(thread_id);
self.request(
command,
Self::on_step_response::<StepOutCommand>(thread_id),
@@ -2335,6 +2409,8 @@ impl Session {
granularity: SteppingGranularity,
cx: &mut Context<Self>,
) {
+ self.select_historic_snapshot(None, cx);
+
let supports_single_thread_execution_requests =
self.capabilities.supports_single_thread_execution_requests;
let supports_stepping_granularity = self
@@ -2350,7 +2426,7 @@ impl Session {
},
};
- self.thread_states.process_step(thread_id);
+ self.active_snapshot.thread_states.process_step(thread_id);
self.request(
command,
@@ -2365,9 +2441,9 @@ impl Session {
thread_id: ThreadId,
cx: &mut Context<Self>,
) -> Result<Vec<StackFrame>> {
- if self.thread_states.thread_status(thread_id) == ThreadStatus::Stopped
+ if self.active_snapshot.thread_states.thread_status(thread_id) == ThreadStatus::Stopped
&& self.requests.contains_key(&ThreadsCommand.type_id())
- && self.threads.contains_key(&thread_id)
+ && self.active_snapshot.threads.contains_key(&thread_id)
// ^ todo(debugger): We need a better way to check that we're not querying stale data
// We could still be using an old thread id and have sent a new thread's request
// This isn't the biggest concern right now because it hasn't caused any issues outside of tests
@@ -2381,7 +2457,8 @@ impl Session {
},
move |this, stack_frames, cx| {
let entry =
- this.threads
+ this.active_snapshot
+ .threads
.entry(thread_id)
.and_modify(|thread| match &stack_frames {
Ok(stack_frames) => {
@@ -2394,7 +2471,7 @@ impl Session {
}
Err(error) => {
thread.stack_frames.clear();
- thread.stack_frames_error = Some(error.cloned());
+ thread.stack_frames_error = Some(error.to_string().into());
}
});
debug_assert!(
@@ -2402,7 +2479,7 @@ impl Session {
"Sent request for thread_id that doesn't exist"
);
if let Ok(stack_frames) = stack_frames {
- this.stack_frames.extend(
+ this.active_snapshot.stack_frames.extend(
stack_frames
.into_iter()
.filter(|frame| {
@@ -2427,10 +2504,10 @@ impl Session {
);
}
- match self.threads.get(&thread_id) {
+ match self.session_state().threads.get(&thread_id) {
Some(thread) => {
if let Some(error) = &thread.stack_frames_error {
- Err(error.cloned())
+ Err(anyhow!(error.to_string()))
} else {
Ok(thread.stack_frames.clone())
}
@@ -2457,6 +2534,7 @@ impl Session {
}
let entry = this
+ .active_snapshot
.stack_frames
.entry(stack_frame_id)
.and_modify(|stack_frame| {
@@ -2474,7 +2552,8 @@ impl Session {
);
}
- self.stack_frames
+ self.session_state()
+ .stack_frames
.get(&stack_frame_id)
.map(|frame| frame.scopes.as_slice())
.unwrap_or_default()
@@ -2486,7 +2565,8 @@ impl Session {
globals: bool,
locals: bool,
) -> Vec<dap::Variable> {
- let Some(stack_frame) = self.stack_frames.get(&stack_frame_id) else {
+ let state = self.session_state();
+ let Some(stack_frame) = state.stack_frames.get(&stack_frame_id) else {
return Vec::new();
};
@@ -2497,7 +2577,7 @@ impl Session {
(scope.name.to_lowercase().contains("local") && locals)
|| (scope.name.to_lowercase().contains("global") && globals)
})
- .filter_map(|scope| self.variables.get(&scope.variables_reference))
+ .filter_map(|scope| state.variables.get(&scope.variables_reference))
.flatten()
.cloned()
.collect()
@@ -2513,7 +2593,7 @@ impl Session {
frame_id: u64,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
- let request = self.mode.request_dap(EvaluateCommand {
+ let request = self.state.request_dap(EvaluateCommand {
expression: expression.to_string(),
context: Some(EvaluateArgumentsContext::Watch),
frame_id: Some(frame_id),
@@ -2570,7 +2650,9 @@ impl Session {
return;
};
- this.variables.insert(variables_reference, variables);
+ this.active_snapshot
+ .variables
+ .insert(variables_reference, variables);
cx.emit(SessionEvent::Variables);
cx.emit(SessionEvent::InvalidateInlineValue);
@@ -2578,7 +2660,8 @@ impl Session {
cx,
);
- self.variables
+ self.session_state()
+ .variables
.get(&variables_reference)
.cloned()
.unwrap_or_default()
@@ -2645,7 +2728,7 @@ impl Session {
location_reference: None,
};
self.push_output(event);
- let request = self.mode.request_dap(EvaluateCommand {
+ let request = self.state.request_dap(EvaluateCommand {
expression,
context,
frame_id,
@@ -2656,6 +2739,8 @@ impl Session {
this.update(cx, |this, cx| {
this.memory.clear(cx.background_executor());
this.invalidate_command_type::<ReadMemory>();
+ this.invalidate_command_type::<VariablesCommand>();
+ cx.emit(SessionEvent::Variables);
match response {
Ok(response) => {
let event = dap::OutputEvent {
@@ -2703,15 +2788,15 @@ impl Session {
let Some(response) = response.log_err() else {
return;
};
- this.locations.insert(reference, response);
+ this.active_snapshot.locations.insert(reference, response);
},
cx,
);
- self.locations.get(&reference).cloned()
+ self.session_state().locations.get(&reference).cloned()
}
pub fn is_attached(&self) -> bool {
- let SessionState::Running(local_mode) = &self.mode else {
+ let SessionState::Running(local_mode) = &self.state else {
return false;
};
local_mode.binary.request_args.request == StartDebuggingRequestArgumentsRequest::Attach
@@ -2747,7 +2832,7 @@ impl Session {
}
pub fn thread_state(&self, thread_id: ThreadId) -> Option<ThreadStatus> {
- self.thread_states.thread_state(thread_id)
+ self.session_state().thread_states.thread_state(thread_id)
}
pub fn quirks(&self) -> SessionQuirks {
@@ -3296,6 +3296,8 @@ impl RepositorySnapshot {
.iter()
.map(stash_to_proto)
.collect(),
+ remote_upstream_url: self.remote_upstream_url.clone(),
+ remote_origin_url: self.remote_origin_url.clone(),
}
}
@@ -3365,6 +3367,8 @@ impl RepositorySnapshot {
.iter()
.map(stash_to_proto)
.collect(),
+ remote_upstream_url: self.remote_upstream_url.clone(),
+ remote_origin_url: self.remote_origin_url.clone(),
}
}
@@ -5395,6 +5399,8 @@ impl Repository {
cx.emit(RepositoryEvent::StashEntriesChanged)
}
self.snapshot.stash_entries = new_stash_entries;
+ self.snapshot.remote_upstream_url = update.remote_upstream_url;
+ self.snapshot.remote_origin_url = update.remote_origin_url;
let edits = update
.removed_statuses
@@ -5954,11 +5960,7 @@ fn serialize_blame_buffer_response(blame: Option<git::blame::Blame>) -> proto::B
.collect::<Vec<_>>();
proto::BlameBufferResponse {
- blame_response: Some(proto::blame_buffer_response::BlameResponse {
- entries,
- messages,
- remote_url: blame.remote_url,
- }),
+ blame_response: Some(proto::blame_buffer_response::BlameResponse { entries, messages }),
}
}
@@ -5995,11 +5997,7 @@ fn deserialize_blame_buffer_response(
.filter_map(|message| Some((git::Oid::from_bytes(&message.oid).ok()?, message.message)))
.collect::<HashMap<_, _>>();
- Some(Blame {
- entries,
- messages,
- remote_url: response.remote_url,
- })
+ Some(Blame { entries, messages })
}
fn branch_to_proto(branch: &git::repository::Branch) -> proto::Branch {
@@ -6147,7 +6145,6 @@ async fn compute_snapshot(
events.push(RepositoryEvent::BranchChanged);
}
- // Used by edit prediction data collection
let remote_origin_url = backend.remote_url("origin").await;
let remote_upstream_url = backend.remote_url("upstream").await;
@@ -14,6 +14,7 @@ use gpui::{
use language::Buffer;
use text::BufferId;
use util::ResultExt;
+use ztracing::instrument;
use crate::{
Project,
@@ -254,6 +255,7 @@ impl BranchDiff {
self.repo.as_ref()
}
+ #[instrument(skip_all)]
pub fn load_buffers(&mut self, cx: &mut Context<Self>) -> Vec<DiffBuffer> {
let mut output = Vec::default();
let Some(repo) = self.repo.clone() else {
@@ -318,6 +320,7 @@ impl BranchDiff {
output
}
+ #[instrument(skip_all)]
fn load_buffer(
branch_diff: Option<git::status::TreeDiffStatus>,
project_path: crate::ProjectPath,
@@ -1,4 +1,4 @@
-use gpui::{App, Context, Entity, EventEmitter};
+use gpui::{App, Context, Entity, EventEmitter, SharedString};
use std::{cmp::Ordering, ops::Range, sync::Arc};
use text::{Anchor, BufferId, OffsetRangeExt as _};
@@ -92,6 +92,8 @@ impl ConflictSetSnapshot {
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConflictRegion {
+ pub ours_branch_name: SharedString,
+ pub theirs_branch_name: SharedString,
pub range: Range<Anchor>,
pub ours: Range<Anchor>,
pub theirs: Range<Anchor>,
@@ -179,18 +181,25 @@ impl ConflictSet {
let mut conflict_start: Option<usize> = None;
let mut ours_start: Option<usize> = None;
let mut ours_end: Option<usize> = None;
+ let mut ours_branch_name: Option<SharedString> = None;
let mut base_start: Option<usize> = None;
let mut base_end: Option<usize> = None;
let mut theirs_start: Option<usize> = None;
+ let mut theirs_branch_name: Option<SharedString> = None;
while let Some(line) = lines.next() {
let line_end = line_pos + line.len();
- if line.starts_with("<<<<<<< ") {
+ if let Some(branch_name) = line.strip_prefix("<<<<<<< ") {
// If we see a new conflict marker while already parsing one,
// abandon the previous one and start a new one
conflict_start = Some(line_pos);
ours_start = Some(line_end + 1);
+
+ let branch_name = branch_name.trim();
+ if !branch_name.is_empty() {
+ ours_branch_name = Some(SharedString::new(branch_name));
+ }
} else if line.starts_with("||||||| ")
&& conflict_start.is_some()
&& ours_start.is_some()
@@ -208,12 +217,17 @@ impl ConflictSet {
base_end = Some(line_pos);
}
theirs_start = Some(line_end + 1);
- } else if line.starts_with(">>>>>>> ")
+ } else if let Some(branch_name) = line.strip_prefix(">>>>>>> ")
&& conflict_start.is_some()
&& ours_start.is_some()
&& ours_end.is_some()
&& theirs_start.is_some()
{
+ let branch_name = branch_name.trim();
+ if !branch_name.is_empty() {
+ theirs_branch_name = Some(SharedString::new(branch_name));
+ }
+
let theirs_end = line_pos;
let conflict_end = (line_end + 1).min(buffer_len);
@@ -229,6 +243,12 @@ impl ConflictSet {
.map(|(start, end)| buffer.anchor_after(start)..buffer.anchor_before(end));
conflicts.push(ConflictRegion {
+ ours_branch_name: ours_branch_name
+ .take()
+ .unwrap_or_else(|| SharedString::new_static("HEAD")),
+ theirs_branch_name: theirs_branch_name
+ .take()
+ .unwrap_or_else(|| SharedString::new_static("Origin")),
range,
ours,
theirs,
@@ -304,6 +324,8 @@ mod tests {
let first = &conflict_snapshot.conflicts[0];
assert!(first.base.is_none());
+ assert_eq!(first.ours_branch_name.as_ref(), "HEAD");
+ assert_eq!(first.theirs_branch_name.as_ref(), "branch-name");
let our_text = snapshot
.text_for_range(first.ours.clone())
.collect::<String>();
@@ -315,6 +337,8 @@ mod tests {
let second = &conflict_snapshot.conflicts[1];
assert!(second.base.is_some());
+ assert_eq!(second.ours_branch_name.as_ref(), "HEAD");
+ assert_eq!(second.theirs_branch_name.as_ref(), "branch-name");
let our_text = snapshot
.text_for_range(second.ours.clone())
.collect::<String>();
@@ -381,6 +405,8 @@ mod tests {
// The conflict should have our version, their version, but no base
let conflict = &conflict_snapshot.conflicts[0];
assert!(conflict.base.is_none());
+ assert_eq!(conflict.ours_branch_name.as_ref(), "HEAD");
+ assert_eq!(conflict.theirs_branch_name.as_ref(), "branch-nested");
// Check that the nested conflict was detected correctly
let our_text = snapshot
@@ -407,6 +433,14 @@ mod tests {
let conflict_snapshot = ConflictSet::parse(&snapshot);
assert_eq!(conflict_snapshot.conflicts.len(), 1);
+ assert_eq!(
+ conflict_snapshot.conflicts[0].ours_branch_name.as_ref(),
+ "ours"
+ );
+ assert_eq!(
+ conflict_snapshot.conflicts[0].theirs_branch_name.as_ref(),
+ "Origin" // default branch name if there is none
+ );
}
#[test]
@@ -449,6 +483,38 @@ mod tests {
let conflict_snapshot = ConflictSet::parse(&snapshot);
assert_eq!(conflict_snapshot.conflicts.len(), 4);
+ assert_eq!(
+ conflict_snapshot.conflicts[0].ours_branch_name.as_ref(),
+ "HEAD1"
+ );
+ assert_eq!(
+ conflict_snapshot.conflicts[0].theirs_branch_name.as_ref(),
+ "branch1"
+ );
+ assert_eq!(
+ conflict_snapshot.conflicts[1].ours_branch_name.as_ref(),
+ "HEAD2"
+ );
+ assert_eq!(
+ conflict_snapshot.conflicts[1].theirs_branch_name.as_ref(),
+ "branch2"
+ );
+ assert_eq!(
+ conflict_snapshot.conflicts[2].ours_branch_name.as_ref(),
+ "HEAD3"
+ );
+ assert_eq!(
+ conflict_snapshot.conflicts[2].theirs_branch_name.as_ref(),
+ "branch3"
+ );
+ assert_eq!(
+ conflict_snapshot.conflicts[3].ours_branch_name.as_ref(),
+ "HEAD4"
+ );
+ assert_eq!(
+ conflict_snapshot.conflicts[3].theirs_branch_name.as_ref(),
+ "branch4"
+ );
let range = test_content.find("seven").unwrap()..test_content.find("eleven").unwrap();
let range = buffer.anchor_before(range.start)..buffer.anchor_after(range.end);
@@ -12647,30 +12647,29 @@ impl LspStore {
.language_servers
.get_mut(&server_id)
.context("Could not obtain Language Servers state")?;
- local
+ let registrations = local
.language_server_dynamic_registrations
.get_mut(&server_id)
.with_context(|| {
format!("Expected dynamic registration to exist for server {server_id}")
- })?.diagnostics
+ })?;
+ registrations.diagnostics
.remove(&Some(unreg.id.clone()))
.with_context(|| format!(
"Attempted to unregister non-existent diagnostic registration with ID {}",
unreg.id)
)?;
+ let removed_last_diagnostic_provider = registrations.diagnostics.is_empty();
- let mut has_any_diagnostic_providers_still = true;
if let LanguageServerState::Running {
workspace_diagnostics_refresh_tasks,
..
} = state
{
workspace_diagnostics_refresh_tasks.remove(&Some(unreg.id.clone()));
- has_any_diagnostic_providers_still =
- !workspace_diagnostics_refresh_tasks.is_empty();
}
- if !has_any_diagnostic_providers_still {
+ if removed_last_diagnostic_provider {
server.update_capabilities(|capabilities| {
debug_assert!(capabilities.diagnostic_provider.is_some());
capabilities.diagnostic_provider = None;
@@ -28,7 +28,7 @@ use language::{
ManifestName, ManifestProvider, ManifestQuery, OffsetRangeExt, Point, ToPoint, ToolchainList,
ToolchainLister,
language_settings::{LanguageSettingsContent, language_settings},
- tree_sitter_rust, tree_sitter_typescript,
+ rust_lang, tree_sitter_typescript,
};
use lsp::{
DiagnosticSeverity, DocumentChanges, FileOperationFilter, NumberOrString, TextDocumentEdit,
@@ -746,7 +746,7 @@ async fn test_running_multiple_instances_of_a_single_server_in_one_worktree(
worktree_id,
path: rel_path("project-b/source_file.py").into(),
},
- LanguageName::new("Python"),
+ LanguageName::new_static("Python"),
cx,
)
})
@@ -762,7 +762,7 @@ async fn test_running_multiple_instances_of_a_single_server_in_one_worktree(
worktree_id,
path: rel_path("project-b/source_file.py").into(),
},
- LanguageName::new("Python"),
+ LanguageName::new_static("Python"),
cx,
)
})
@@ -10468,20 +10468,6 @@ fn js_lang() -> Arc<Language> {
))
}
-fn rust_lang() -> Arc<Language> {
- Arc::new(Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- ))
-}
-
fn python_lang(fs: Arc<FakeFs>) -> Arc<Language> {
struct PythonMootToolchainLister(Arc<FakeFs>);
#[async_trait]
@@ -111,7 +111,7 @@ impl Project {
);
let toolchains = project_path_contexts
.filter(|_| detect_venv)
- .map(|p| self.active_toolchain(p, LanguageName::new("Python"), cx))
+ .map(|p| self.active_toolchain(p, LanguageName::new_static("Python"), cx))
.collect::<Vec<_>>();
let lang_registry = self.languages.clone();
cx.spawn(async move |project, cx| {
@@ -311,7 +311,7 @@ impl Project {
);
let toolchains = project_path_contexts
.filter(|_| detect_venv)
- .map(|p| self.active_toolchain(p, LanguageName::new("Python"), cx))
+ .map(|p| self.active_toolchain(p, LanguageName::new_static("Python"), cx))
.collect::<Vec<_>>();
let remote_client = self.remote_client.clone();
let shell = match &remote_client {
@@ -94,6 +94,16 @@ pub struct ContentPromptContext {
pub diagnostic_errors: Vec<ContentPromptDiagnosticContext>,
}
+#[derive(Serialize)]
+pub struct ContentPromptContextV2 {
+ pub content_type: String,
+ pub language_name: Option<String>,
+ pub is_truncated: bool,
+ pub document_content: String,
+ pub rewrite_section: Option<String>,
+ pub diagnostic_errors: Vec<ContentPromptDiagnosticContext>,
+}
+
#[derive(Serialize)]
pub struct TerminalAssistantPromptContext {
pub os: String,
@@ -276,6 +286,88 @@ impl PromptBuilder {
Ok(())
}
+ pub fn generate_inline_transformation_prompt_v2(
+ &self,
+ language_name: Option<&LanguageName>,
+ buffer: BufferSnapshot,
+ range: Range<usize>,
+ ) -> Result<String, RenderError> {
+ let content_type = match language_name.as_ref().map(|l| l.as_ref()) {
+ None | Some("Markdown" | "Plain Text") => "text",
+ Some(_) => "code",
+ };
+
+ const MAX_CTX: usize = 50000;
+ let is_insert = range.is_empty();
+ let mut is_truncated = false;
+
+ let before_range = 0..range.start;
+ let truncated_before = if before_range.len() > MAX_CTX {
+ is_truncated = true;
+ let start = buffer.clip_offset(range.start - MAX_CTX, text::Bias::Right);
+ start..range.start
+ } else {
+ before_range
+ };
+
+ let after_range = range.end..buffer.len();
+ let truncated_after = if after_range.len() > MAX_CTX {
+ is_truncated = true;
+ let end = buffer.clip_offset(range.end + MAX_CTX, text::Bias::Left);
+ range.end..end
+ } else {
+ after_range
+ };
+
+ let mut document_content = String::new();
+ for chunk in buffer.text_for_range(truncated_before) {
+ document_content.push_str(chunk);
+ }
+ if is_insert {
+ document_content.push_str("<insert_here></insert_here>");
+ } else {
+ document_content.push_str("<rewrite_this>\n");
+ for chunk in buffer.text_for_range(range.clone()) {
+ document_content.push_str(chunk);
+ }
+ document_content.push_str("\n</rewrite_this>");
+ }
+ for chunk in buffer.text_for_range(truncated_after) {
+ document_content.push_str(chunk);
+ }
+
+ let rewrite_section = if !is_insert {
+ let mut section = String::new();
+ for chunk in buffer.text_for_range(range.clone()) {
+ section.push_str(chunk);
+ }
+ Some(section)
+ } else {
+ None
+ };
+ let diagnostics = buffer.diagnostics_in_range::<_, Point>(range, false);
+ let diagnostic_errors: Vec<ContentPromptDiagnosticContext> = diagnostics
+ .map(|entry| {
+ let start = entry.range.start;
+ ContentPromptDiagnosticContext {
+ line_number: (start.row + 1) as usize,
+ error_message: entry.diagnostic.message.clone(),
+ code_content: buffer.text_for_range(entry.range).collect(),
+ }
+ })
+ .collect();
+
+ let context = ContentPromptContextV2 {
+ content_type: content_type.to_string(),
+ language_name: language_name.map(|s| s.to_string()),
+ is_truncated,
+ document_content,
+ rewrite_section,
+ diagnostic_errors,
+ };
+ self.handlebars.lock().render("content_prompt_v2", &context)
+ }
+
pub fn generate_inline_transformation_prompt(
&self,
user_prompt: String,
@@ -124,6 +124,8 @@ message UpdateRepository {
optional GitCommitDetails head_commit_details = 11;
optional string merge_message = 12;
repeated StashEntry stash_entries = 13;
+ optional string remote_upstream_url = 14;
+ optional string remote_origin_url = 15;
}
message RemoveRepository {
@@ -500,8 +502,8 @@ message BlameBufferResponse {
message BlameResponse {
repeated BlameEntry entries = 1;
repeated CommitMessage messages = 2;
- optional string remote_url = 4;
reserved 3;
+ reserved 4;
}
optional BlameResponse blame_response = 5;
@@ -132,7 +132,8 @@ pub fn init(cx: &mut App) {
let create_new_window = open_recent.create_new_window;
with_active_or_new_workspace(cx, move |workspace, window, cx| {
let Some(recent_projects) = workspace.active_modal::<RecentProjects>(cx) else {
- RecentProjects::open(workspace, create_new_window, window, cx);
+ let focus_handle = workspace.focus_handle(cx);
+ RecentProjects::open(workspace, create_new_window, window, focus_handle, cx);
return;
};
@@ -246,11 +247,12 @@ impl RecentProjects {
workspace: &mut Workspace,
create_new_window: bool,
window: &mut Window,
+ focus_handle: FocusHandle,
cx: &mut Context<Workspace>,
) {
let weak = cx.entity().downgrade();
workspace.toggle_modal(window, cx, |window, cx| {
- let delegate = RecentProjectsDelegate::new(weak, create_new_window, true);
+ let delegate = RecentProjectsDelegate::new(weak, create_new_window, true, focus_handle);
Self::new(delegate, 34., window, cx)
})
@@ -289,10 +291,16 @@ pub struct RecentProjectsDelegate {
// Flag to reset index when there is a new query vs not reset index when user delete an item
reset_selected_match_index: bool,
has_any_non_local_projects: bool,
+ focus_handle: FocusHandle,
}
impl RecentProjectsDelegate {
- fn new(workspace: WeakEntity<Workspace>, create_new_window: bool, render_paths: bool) -> Self {
+ fn new(
+ workspace: WeakEntity<Workspace>,
+ create_new_window: bool,
+ render_paths: bool,
+ focus_handle: FocusHandle,
+ ) -> Self {
Self {
workspace,
workspaces: Vec::new(),
@@ -302,6 +310,7 @@ impl RecentProjectsDelegate {
render_paths,
reset_selected_match_index: true,
has_any_non_local_projects: false,
+ focus_handle,
}
}
@@ -532,8 +541,8 @@ impl PickerDelegate for RecentProjectsDelegate {
.unzip();
let prefix = match &location {
- SerializedWorkspaceLocation::Remote(RemoteConnectionOptions::Wsl(wsl)) => {
- Some(SharedString::from(&wsl.distro_name))
+ SerializedWorkspaceLocation::Remote(options) => {
+ Some(SharedString::from(options.display_name()))
}
_ => None,
};
@@ -544,12 +553,23 @@ impl PickerDelegate for RecentProjectsDelegate {
paths,
};
+ let focus_handle = self.focus_handle.clone();
+
let secondary_actions = h_flex()
.gap_px()
.child(
IconButton::new("open_new_window", IconName::ArrowUpRight)
.icon_size(IconSize::XSmall)
- .tooltip(Tooltip::text("Open Project in New Window"))
+ .tooltip({
+ move |_, cx| {
+ Tooltip::for_action_in(
+ "Open Project in New Window",
+ &menu::SecondaryConfirm,
+ &focus_handle,
+ cx,
+ )
+ }
+ })
.on_click(cx.listener(move |this, _event, window, cx| {
cx.stop_propagation();
window.prevent_default();
@@ -577,8 +597,9 @@ impl PickerDelegate for RecentProjectsDelegate {
.spacing(ListItemSpacing::Sparse)
.child(
h_flex()
- .flex_grow()
+ .id("projecy_info_container")
.gap_3()
+ .flex_grow()
.when(self.has_any_non_local_projects, |this| {
this.child(match location {
SerializedWorkspaceLocation::Local => Icon::new(IconName::Screen)
@@ -600,6 +621,13 @@ impl PickerDelegate for RecentProjectsDelegate {
highlighted.paths.clear();
}
highlighted.render(window, cx)
+ })
+ .tooltip(move |_, cx| {
+ let tooltip_highlighted_location = highlighted_match.clone();
+ cx.new(|_| MatchTooltip {
+ highlighted_location: tooltip_highlighted_location,
+ })
+ .into()
}),
)
.map(|el| {
@@ -608,13 +636,6 @@ impl PickerDelegate for RecentProjectsDelegate {
} else {
el.end_hover_slot(secondary_actions)
}
- })
- .tooltip(move |_, cx| {
- let tooltip_highlighted_location = highlighted_match.clone();
- cx.new(|_| MatchTooltip {
- highlighted_location: tooltip_highlighted_location,
- })
- .into()
}),
)
}
@@ -1,4 +1,5 @@
use crate::{
+ RemotePlatform,
json_log::LogRecord,
protocol::{MESSAGE_LEN_SIZE, message_len_from_buffer, read_message_with_len, write_message},
};
@@ -14,6 +15,54 @@ use smol::process::Child;
pub mod ssh;
pub mod wsl;
+/// Parses the output of `uname -sm` to determine the remote platform.
+/// Takes the last line to skip possible shell initialization output.
+fn parse_platform(output: &str) -> Result<RemotePlatform> {
+ let output = output.trim();
+ let uname = output.rsplit_once('\n').map_or(output, |(_, last)| last);
+ let Some((os, arch)) = uname.split_once(" ") else {
+ anyhow::bail!("unknown uname: {uname:?}")
+ };
+
+ let os = match os {
+ "Darwin" => "macos",
+ "Linux" => "linux",
+ _ => anyhow::bail!(
+ "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
+ ),
+ };
+
+ // exclude armv5,6,7 as they are 32-bit.
+ let arch = if arch.starts_with("armv8")
+ || arch.starts_with("armv9")
+ || arch.starts_with("arm64")
+ || arch.starts_with("aarch64")
+ {
+ "aarch64"
+ } else if arch.starts_with("x86") {
+ "x86_64"
+ } else {
+ anyhow::bail!(
+ "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
+ )
+ };
+
+ Ok(RemotePlatform { os, arch })
+}
+
+/// Parses the output of `echo $SHELL` to determine the remote shell.
+/// Takes the last line to skip possible shell initialization output.
+fn parse_shell(output: &str, fallback_shell: &str) -> String {
+ let output = output.trim();
+ let shell = output.rsplit_once('\n').map_or(output, |(_, last)| last);
+ if shell.is_empty() {
+ log::error!("$SHELL is not set, falling back to {fallback_shell}");
+ fallback_shell.to_owned()
+ } else {
+ shell.to_owned()
+ }
+}
+
fn handle_rpc_messages_over_child_process_stdio(
mut ssh_proxy_process: Child,
incoming_tx: UnboundedSender<Envelope>,
@@ -316,3 +365,63 @@ async fn which(
)),
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_parse_platform() {
+ let result = parse_platform("Linux x86_64\n").unwrap();
+ assert_eq!(result.os, "linux");
+ assert_eq!(result.arch, "x86_64");
+
+ let result = parse_platform("Darwin arm64\n").unwrap();
+ assert_eq!(result.os, "macos");
+ assert_eq!(result.arch, "aarch64");
+
+ let result = parse_platform("Linux x86_64").unwrap();
+ assert_eq!(result.os, "linux");
+ assert_eq!(result.arch, "x86_64");
+
+ let result = parse_platform("some shell init output\nLinux aarch64\n").unwrap();
+ assert_eq!(result.os, "linux");
+ assert_eq!(result.arch, "aarch64");
+
+ let result = parse_platform("some shell init output\nLinux aarch64").unwrap();
+ assert_eq!(result.os, "linux");
+ assert_eq!(result.arch, "aarch64");
+
+ assert_eq!(parse_platform("Linux armv8l\n").unwrap().arch, "aarch64");
+ assert_eq!(parse_platform("Linux aarch64\n").unwrap().arch, "aarch64");
+ assert_eq!(parse_platform("Linux x86_64\n").unwrap().arch, "x86_64");
+
+ let result = parse_platform(
+ r#"Linux x86_64 - What you're referring to as Linux, is in fact, GNU/Linux...\n"#,
+ )
+ .unwrap();
+ assert_eq!(result.os, "linux");
+ assert_eq!(result.arch, "x86_64");
+
+ assert!(parse_platform("Windows x86_64\n").is_err());
+ assert!(parse_platform("Linux armv7l\n").is_err());
+ }
+
+ #[test]
+ fn test_parse_shell() {
+ assert_eq!(parse_shell("/bin/bash\n", "sh"), "/bin/bash");
+ assert_eq!(parse_shell("/bin/zsh\n", "sh"), "/bin/zsh");
+
+ assert_eq!(parse_shell("/bin/bash", "sh"), "/bin/bash");
+ assert_eq!(
+ parse_shell("some shell init output\n/bin/bash\n", "sh"),
+ "/bin/bash"
+ );
+ assert_eq!(
+ parse_shell("some shell init output\n/bin/bash", "sh"),
+ "/bin/bash"
+ );
+ assert_eq!(parse_shell("", "sh"), "sh");
+ assert_eq!(parse_shell("\n", "sh"), "sh");
+ }
+}
@@ -1,6 +1,7 @@
use crate::{
RemoteClientDelegate, RemotePlatform,
remote_client::{CommandTemplate, RemoteConnection, RemoteConnectionOptions},
+ transport::{parse_platform, parse_shell},
};
use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
@@ -668,6 +669,8 @@ impl SshRemoteConnection {
delegate.set_status(Some("Downloading remote development server on host"), cx);
+ const CONNECT_TIMEOUT_SECS: &str = "10";
+
match self
.socket
.run_command(
@@ -676,6 +679,8 @@ impl SshRemoteConnection {
&[
"-f",
"-L",
+ "--connect-timeout",
+ CONNECT_TIMEOUT_SECS,
url,
"-o",
&tmp_path_gz.display(self.path_style()),
@@ -701,7 +706,15 @@ impl SshRemoteConnection {
.run_command(
self.ssh_shell_kind,
"wget",
- &[url, "-O", &tmp_path_gz.display(self.path_style())],
+ &[
+ "--connect-timeout",
+ CONNECT_TIMEOUT_SECS,
+ "--tries",
+ "1",
+ url,
+ "-O",
+ &tmp_path_gz.display(self.path_style()),
+ ],
true,
)
.await
@@ -1055,52 +1068,20 @@ impl SshSocket {
}
async fn platform(&self, shell: ShellKind) -> Result<RemotePlatform> {
- let uname = self.run_command(shell, "uname", &["-sm"], false).await?;
- let Some((os, arch)) = uname.split_once(" ") else {
- anyhow::bail!("unknown uname: {uname:?}")
- };
-
- let os = match os.trim() {
- "Darwin" => "macos",
- "Linux" => "linux",
- _ => anyhow::bail!(
- "Prebuilt remote servers are not yet available for {os:?}. See https://zed.dev/docs/remote-development"
- ),
- };
- // exclude armv5,6,7 as they are 32-bit.
- let arch = if arch.starts_with("armv8")
- || arch.starts_with("armv9")
- || arch.starts_with("arm64")
- || arch.starts_with("aarch64")
- {
- "aarch64"
- } else if arch.starts_with("x86") {
- "x86_64"
- } else {
- anyhow::bail!(
- "Prebuilt remote servers are not yet available for {arch:?}. See https://zed.dev/docs/remote-development"
- )
- };
-
- Ok(RemotePlatform { os, arch })
+ let output = self.run_command(shell, "uname", &["-sm"], false).await?;
+ parse_platform(&output)
}
async fn shell(&self) -> String {
- let default_shell = "sh";
+ const DEFAULT_SHELL: &str = "sh";
match self
.run_command(ShellKind::Posix, "sh", &["-c", "echo $SHELL"], false)
.await
{
- Ok(shell) => match shell.trim() {
- "" => {
- log::error!("$SHELL is not set, falling back to {default_shell}");
- default_shell.to_owned()
- }
- shell => shell.to_owned(),
- },
+ Ok(output) => parse_shell(&output, DEFAULT_SHELL),
Err(e) => {
- log::error!("Failed to get shell: {e}");
- default_shell.to_owned()
+ log::error!("Failed to detect remote shell: {e}");
+ DEFAULT_SHELL.to_owned()
}
}
}
@@ -1502,12 +1483,8 @@ mod tests {
"-p".to_string(),
"2222".to_string(),
"-o".to_string(),
- "StrictHostKeyChecking=no".to_string()
+ "StrictHostKeyChecking=no".to_string(),
]
);
- assert!(
- scp_args.iter().all(|arg| !arg.starts_with("-L")),
- "scp args should not contain port forward flags: {scp_args:?}"
- );
}
}
@@ -1,6 +1,7 @@
use crate::{
RemoteClientDelegate, RemotePlatform,
remote_client::{CommandTemplate, RemoteConnection, RemoteConnectionOptions},
+ transport::{parse_platform, parse_shell},
};
use anyhow::{Context, Result, anyhow, bail};
use async_trait::async_trait;
@@ -107,23 +108,22 @@ impl WslRemoteConnection {
async fn detect_platform(&self) -> Result<RemotePlatform> {
let program = self.shell_kind.prepend_command_prefix("uname");
- let arch_str = self.run_wsl_command_with_output(&program, &["-m"]).await?;
- let arch_str = arch_str.trim().to_string();
- let arch = match arch_str.as_str() {
- "x86_64" => "x86_64",
- "aarch64" | "arm64" => "aarch64",
- _ => "x86_64",
- };
- Ok(RemotePlatform { os: "linux", arch })
+ let output = self.run_wsl_command_with_output(&program, &["-sm"]).await?;
+ parse_platform(&output)
}
async fn detect_shell(&self) -> Result<String> {
- Ok(self
+ const DEFAULT_SHELL: &str = "sh";
+ match self
.run_wsl_command_with_output("sh", &["-c", "echo $SHELL"])
.await
- .inspect_err(|err| log::error!("Failed to detect remote shell: {err}"))
- .ok()
- .unwrap_or_else(|| "/bin/sh".to_string()))
+ {
+ Ok(output) => Ok(parse_shell(&output, DEFAULT_SHELL)),
+ Err(e) => {
+ log::error!("Failed to detect remote shell: {e}");
+ Ok(DEFAULT_SHELL.to_owned())
+ }
+ }
}
async fn detect_has_wsl_interop(&self) -> Result<bool> {
@@ -452,7 +452,7 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext
});
let mut fake_lsp = server_cx.update(|cx| {
- headless.read(cx).languages.register_fake_language_server(
+ headless.read(cx).languages.register_fake_lsp_server(
LanguageServerName("rust-analyzer".into()),
lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions::default()),
@@ -476,7 +476,7 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext
..FakeLspAdapter::default()
},
);
- headless.read(cx).languages.register_fake_language_server(
+ headless.read(cx).languages.register_fake_lsp_server(
LanguageServerName("fake-analyzer".into()),
lsp::ServerCapabilities {
completion_provider: Some(lsp::CompletionOptions::default()),
@@ -669,7 +669,7 @@ async fn test_remote_cancel_language_server_work(
});
let mut fake_lsp = server_cx.update(|cx| {
- headless.read(cx).languages.register_fake_language_server(
+ headless.read(cx).languages.register_fake_lsp_server(
LanguageServerName("rust-analyzer".into()),
Default::default(),
None,
@@ -81,7 +81,7 @@ pub fn python_env_kernel_specifications(
worktree_id: WorktreeId,
cx: &mut App,
) -> impl Future<Output = Result<Vec<KernelSpecification>>> + use<> {
- let python_language = LanguageName::new("Python");
+ let python_language = LanguageName::new_static("Python");
let toolchains = project.read(cx).available_toolchains(
ProjectPath {
worktree_id,
@@ -18,6 +18,8 @@ rayon.workspace = true
sum_tree.workspace = true
unicode-segmentation.workspace = true
util.workspace = true
+ztracing.workspace = true
+tracing.workspace = true
[dev-dependencies]
ctor.workspace = true
@@ -30,3 +32,6 @@ zlog.workspace = true
[[bench]]
name = "rope_benchmark"
harness = false
+
+[package.metadata.cargo-machete]
+ignored = ["tracing"]
@@ -12,6 +12,7 @@ use std::{
str,
};
use sum_tree::{Bias, Dimension, Dimensions, SumTree};
+use ztracing::instrument;
pub use chunk::{Chunk, ChunkSlice};
pub use offset_utf16::OffsetUtf16;
@@ -428,6 +429,7 @@ impl Rope {
})
}
+ #[instrument(skip_all)]
pub fn point_to_offset(&self, point: Point) -> usize {
if point >= self.summary().lines {
return self.summary().len;
@@ -15,6 +15,7 @@ use util::ResultExt as _;
use util::{
asset_str,
markdown::{MarkdownEscaped, MarkdownInlineCode, MarkdownString},
+ schemars::AllowTrailingCommas,
};
use crate::SettingsAssets;
@@ -451,7 +452,9 @@ impl KeymapFile {
/// Creates a JSON schema generator, suitable for generating json schemas
/// for actions
pub fn action_schema_generator() -> schemars::SchemaGenerator {
- schemars::generate::SchemaSettings::draft2019_09().into_generator()
+ schemars::generate::SchemaSettings::draft2019_09()
+ .with_transform(AllowTrailingCommas)
+ .into_generator()
}
pub fn generate_json_schema_for_registered_actions(cx: &mut App) -> Value {
@@ -62,6 +62,8 @@ impl merge_from::MergeFrom for AllLanguageSettingsContent {
pub struct FeaturesContent {
/// Determines which edit prediction provider to use.
pub edit_prediction_provider: Option<EditPredictionProvider>,
+ /// Enables the experimental edit prediction context retrieval system.
+ pub experimental_edit_prediction_context_retrieval: Option<bool>,
}
/// The provider that supplies edit predictions.
@@ -79,6 +81,7 @@ pub enum EditPredictionProvider {
pub const EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME: &str = "sweep";
pub const EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME: &str = "zeta2";
+pub const EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME: &str = "mercury";
impl<'de> Deserialize<'de> for EditPredictionProvider {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
@@ -109,6 +112,13 @@ impl<'de> Deserialize<'de> for EditPredictionProvider {
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
)
}
+ Content::Experimental(name)
+ if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME =>
+ {
+ EditPredictionProvider::Experimental(
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
+ )
+ }
Content::Experimental(name)
if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME =>
{
@@ -543,7 +543,7 @@ pub enum DiagnosticSeverityContent {
pub struct GitHostingProviderConfig {
/// The type of the provider.
///
- /// Must be one of `github`, `gitlab`, or `bitbucket`.
+ /// Must be one of `github`, `gitlab`, `bitbucket`, `gitea`, `forgejo`, or `source_hut`.
pub provider: GitHostingProviderKind,
/// The base URL for the provider (e.g., "https://code.corp.big.com").
@@ -559,4 +559,7 @@ pub enum GitHostingProviderKind {
Github,
Gitlab,
Bitbucket,
+ Gitea,
+ Forgejo,
+ SourceHut,
}
@@ -1,7 +1,7 @@
use std::path::PathBuf;
use collections::HashMap;
-use gpui::{AbsoluteLength, FontFeatures, SharedString, px};
+use gpui::{AbsoluteLength, FontFeatures, FontWeight, SharedString, px};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings_macros::{MergeFrom, with_fallible_options};
@@ -96,8 +96,7 @@ pub struct TerminalSettingsContent {
pub line_height: Option<TerminalLineHeight>,
pub font_features: Option<FontFeatures>,
/// Sets the terminal's font weight in CSS weight units 0-900.
- #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")]
- pub font_weight: Option<f32>,
+ pub font_weight: Option<FontWeight>,
/// Default cursor shape for the terminal.
/// Can be "bar", "block", "underline", or "hollow".
///
@@ -25,7 +25,7 @@ use std::{
use util::{
ResultExt as _,
rel_path::RelPath,
- schemars::{DefaultDenyUnknownFields, replace_subschema},
+ schemars::{AllowTrailingCommas, DefaultDenyUnknownFields, replace_subschema},
};
pub type EditorconfigProperties = ec4rs::Properties;
@@ -1010,6 +1010,7 @@ impl SettingsStore {
pub fn json_schema(&self, params: &SettingsJsonSchemaParams) -> Value {
let mut generator = schemars::generate::SchemaSettings::draft2019_09()
.with_transform(DefaultDenyUnknownFields)
+ .with_transform(AllowTrailingCommas)
.into_generator();
UserSettingsContent::json_schema(&mut generator);
@@ -29,8 +29,8 @@ use std::{
use title_bar::platform_title_bar::PlatformTitleBar;
use ui::{
Banner, ContextMenu, Divider, DividerColor, DropdownMenu, DropdownStyle, IconButtonShape,
- KeyBinding, KeybindingHint, PopoverMenu, Switch, SwitchColor, Tooltip, TreeViewItem,
- WithScrollbar, prelude::*,
+ KeyBinding, KeybindingHint, PopoverMenu, Switch, Tooltip, TreeViewItem, WithScrollbar,
+ prelude::*,
};
use ui_input::{NumberField, NumberFieldType};
use util::{ResultExt as _, paths::PathStyle, rel_path::RelPath};
@@ -3501,7 +3501,6 @@ fn render_toggle_button<B: Into<bool> + From<bool> + Copy>(
Switch::new("toggle_button", toggle_state)
.tab_index(0_isize)
- .color(SwitchColor::Accent)
.on_click({
move |state, _window, cx| {
telemetry::event!("Settings Change", setting = field.json_path, type = file.setting_type());
@@ -2,7 +2,7 @@ use collections::HashMap;
use schemars::{JsonSchema, json_schema};
use serde::Deserialize;
use std::borrow::Cow;
-use util::schemars::DefaultDenyUnknownFields;
+use util::schemars::{AllowTrailingCommas, DefaultDenyUnknownFields};
#[derive(Deserialize)]
pub struct VsSnippetsFile {
@@ -14,6 +14,7 @@ impl VsSnippetsFile {
pub fn generate_json_schema() -> serde_json::Value {
let schema = schemars::generate::SchemaSettings::draft2019_09()
.with_transform(DefaultDenyUnknownFields)
+ .with_transform(AllowTrailingCommas)
.into_generator()
.root_schema_for::<Self>();
@@ -17,8 +17,13 @@ doctest = false
arrayvec = "0.7.1"
rayon.workspace = true
log.workspace = true
+ztracing.workspace = true
+tracing.workspace = true
[dev-dependencies]
ctor.workspace = true
rand.workspace = true
zlog.workspace = true
+
+[package.metadata.cargo-machete]
+ignored = ["tracing"]
@@ -1,6 +1,7 @@
use super::*;
use arrayvec::ArrayVec;
use std::{cmp::Ordering, mem, sync::Arc};
+use ztracing::instrument;
#[derive(Clone)]
struct StackEntry<'a, T: Item, D> {
@@ -211,6 +212,7 @@ where
}
#[track_caller]
+ #[instrument(skip_all)]
pub fn prev(&mut self) {
self.search_backward(|_| true)
}
@@ -394,6 +396,7 @@ where
{
/// Returns whether we found the item you were seeking for.
#[track_caller]
+ #[instrument(skip_all)]
pub fn seek<Target>(&mut self, pos: &Target, bias: Bias) -> bool
where
Target: SeekTarget<'a, T::Summary, D>,
@@ -408,6 +411,7 @@ where
///
/// If we did not seek before, use seek instead in that case.
#[track_caller]
+ #[instrument(skip_all)]
pub fn seek_forward<Target>(&mut self, pos: &Target, bias: Bias) -> bool
where
Target: SeekTarget<'a, T::Summary, D>,
@@ -449,6 +453,7 @@ where
/// Returns whether we found the item you were seeking for.
#[track_caller]
+ #[instrument(skip_all)]
fn seek_internal(
&mut self,
target: &dyn SeekTarget<'a, T::Summary, D>,
@@ -8,6 +8,7 @@ use std::marker::PhantomData;
use std::mem;
use std::{cmp::Ordering, fmt, iter::FromIterator, sync::Arc};
pub use tree_map::{MapSeekTarget, TreeMap, TreeSet};
+use ztracing::instrument;
#[cfg(test)]
pub const TREE_BASE: usize = 2;
@@ -379,6 +380,7 @@ impl<T: Item> SumTree<T> {
/// A more efficient version of `Cursor::new()` + `Cursor::seek()` + `Cursor::item()`.
///
/// Only returns the item that exactly has the target match.
+ #[instrument(skip_all)]
pub fn find_exact<'a, 'slf, D, Target>(
&'slf self,
cx: <T::Summary as Summary>::Context<'a>,
@@ -404,6 +406,7 @@ impl<T: Item> SumTree<T> {
}
/// A more efficient version of `Cursor::new()` + `Cursor::seek()` + `Cursor::item()`
+ #[instrument(skip_all)]
pub fn find<'a, 'slf, D, Target>(
&'slf self,
cx: <T::Summary as Summary>::Context<'a>,
@@ -16,7 +16,7 @@ doctest = false
anyhow.workspace = true
client.workspace = true
collections.workspace = true
-edit_prediction.workspace = true
+edit_prediction_types.workspace = true
futures.workspace = true
gpui.workspace = true
language.workspace = true
@@ -1,7 +1,7 @@
mod messages;
-mod supermaven_completion_provider;
+mod supermaven_edit_prediction_delegate;
-pub use supermaven_completion_provider::*;
+pub use supermaven_edit_prediction_delegate::*;
use anyhow::{Context as _, Result};
#[allow(unused_imports)]
@@ -1,6 +1,6 @@
use crate::{Supermaven, SupermavenCompletionStateId};
use anyhow::Result;
-use edit_prediction::{Direction, EditPrediction, EditPredictionProvider};
+use edit_prediction_types::{Direction, EditPrediction, EditPredictionDelegate};
use futures::StreamExt as _;
use gpui::{App, Context, Entity, EntityId, Task};
use language::{Anchor, Buffer, BufferSnapshot};
@@ -15,7 +15,7 @@ use unicode_segmentation::UnicodeSegmentation;
pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
-pub struct SupermavenCompletionProvider {
+pub struct SupermavenEditPredictionDelegate {
supermaven: Entity<Supermaven>,
buffer_id: Option<EntityId>,
completion_id: Option<SupermavenCompletionStateId>,
@@ -25,7 +25,7 @@ pub struct SupermavenCompletionProvider {
completion_position: Option<language::Anchor>,
}
-impl SupermavenCompletionProvider {
+impl SupermavenEditPredictionDelegate {
pub fn new(supermaven: Entity<Supermaven>) -> Self {
Self {
supermaven,
@@ -104,7 +104,7 @@ fn completion_from_diff(
}
}
-impl EditPredictionProvider for SupermavenCompletionProvider {
+impl EditPredictionDelegate for SupermavenEditPredictionDelegate {
fn name() -> &'static str {
"supermaven"
}
@@ -113,7 +113,7 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
"Supermaven"
}
- fn show_completions_in_menu() -> bool {
+ fn show_predictions_in_menu() -> bool {
true
}
@@ -269,8 +269,8 @@ impl EditPredictionProvider for SupermavenCompletionProvider {
}
fn reset_completion_cache(
- provider: &mut SupermavenCompletionProvider,
- _cx: &mut Context<SupermavenCompletionProvider>,
+ provider: &mut SupermavenEditPredictionDelegate,
+ _cx: &mut Context<SupermavenEditPredictionDelegate>,
) {
provider.pending_refresh = None;
provider.completion_id = None;
@@ -357,6 +357,7 @@ impl DebugTaskFile {
"$schema": meta_schema,
"title": "Debug Configurations",
"description": "Configuration for debug scenarios",
+ "allowTrailingCommas": true,
"type": "array",
"items": {
"type": "object",
@@ -4,7 +4,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::path::PathBuf;
-use util::schemars::DefaultDenyUnknownFields;
+use util::schemars::{AllowTrailingCommas, DefaultDenyUnknownFields};
use util::serde::default_true;
use util::{ResultExt, truncate_and_remove_front};
@@ -118,6 +118,7 @@ impl TaskTemplates {
pub fn generate_json_schema() -> serde_json::Value {
let schema = schemars::generate::SchemaSettings::draft2019_09()
.with_transform(DefaultDenyUnknownFields)
+ .with_transform(AllowTrailingCommas)
.into_generator()
.root_schema_for::<Self>();
@@ -208,7 +208,8 @@ fn path_match<T>(
if path_hyperlink_regexes.is_empty() || path_hyperlink_timeout.as_millis() == 0 {
return None;
}
-
+ debug_assert!(line_start <= hovered);
+ debug_assert!(line_end >= hovered);
let search_start_time = Instant::now();
let timed_out = || {
@@ -224,13 +225,35 @@ fn path_match<T>(
let mut line = String::with_capacity(
(line_end.line.0 - line_start.line.0 + 1) as usize * term.grid().columns(),
);
- line.push(term.grid()[line_start].c);
+ let first_cell = &term.grid()[line_start];
+ line.push(first_cell.c);
+ let mut start_offset = 0;
+ let mut hovered_point_byte_offset = None;
+
+ if !first_cell.flags.intersects(WIDE_CHAR_SPACERS) {
+ start_offset += first_cell.c.len_utf8();
+ if line_start == hovered {
+ hovered_point_byte_offset = Some(0);
+ }
+ }
+
for cell in term.grid().iter_from(line_start) {
if cell.point > line_end {
break;
}
+ let is_spacer = cell.flags.intersects(WIDE_CHAR_SPACERS);
+ if cell.point == hovered {
+ debug_assert!(hovered_point_byte_offset.is_none());
+ if start_offset > 0 && cell.flags.contains(Flags::WIDE_CHAR_SPACER) {
+ // If we hovered on a trailing spacer, back up to the end of the previous char's bytes.
+ start_offset -= 1;
+ }
+ hovered_point_byte_offset = Some(start_offset);
+ } else if cell.point < hovered && !is_spacer {
+ start_offset += cell.c.len_utf8();
+ }
- if !cell.flags.intersects(WIDE_CHAR_SPACERS) {
+ if !is_spacer {
line.push(match cell.c {
'\t' => ' ',
c @ _ => c,
@@ -238,7 +261,7 @@ fn path_match<T>(
}
}
let line = line.trim_ascii_end();
-
+ let hovered_point_byte_offset = hovered_point_byte_offset?;
let found_from_range = |path_range: Range<usize>,
link_range: Range<usize>,
position: Option<(u32, Option<u32>)>| {
@@ -268,7 +291,7 @@ fn path_match<T>(
.expand_wide(link_end, AlacDirection::Left)
.sub(term, Boundary::Grid, 1);
- Some((
+ (
{
let mut path = line[path_range].to_string();
position.inspect(|(line, column)| {
@@ -278,7 +301,7 @@ fn path_match<T>(
path
},
link_match,
- ))
+ )
};
for regex in path_hyperlink_regexes {
@@ -296,7 +319,7 @@ fn path_match<T>(
continue;
}
};
-
+ path_found = true;
let match_range = captures.get(0).unwrap().range();
let (path_range, line_column) = if let Some(path) = captures.name("path") {
let parse = |name: &str| {
@@ -314,14 +337,16 @@ fn path_match<T>(
};
let link_range = captures
.name("link")
- .map_or(match_range, |link| link.range());
+ .map_or_else(|| match_range.clone(), |link| link.range());
+
+ if !link_range.contains(&hovered_point_byte_offset) {
+ // No match, just skip.
+ continue;
+ }
let found = found_from_range(path_range, link_range, line_column);
- if let Some(found) = found {
- path_found = true;
- if found.1.contains(&hovered) {
- return Some(found);
- }
+ if found.1.contains(&hovered) {
+ return Some(found);
}
}
@@ -95,7 +95,7 @@ impl settings::Settings for TerminalSettings {
)
}),
font_features: user_content.font_features,
- font_weight: user_content.font_weight.map(FontWeight),
+ font_weight: user_content.font_weight,
line_height: user_content.line_height.unwrap(),
env: project_content.env.unwrap(),
cursor_shape: user_content.cursor_shape.unwrap().into(),
@@ -167,7 +167,7 @@ impl TerminalPanel {
// hence we focus that first. Otherwise, we'd end up without a focused element, as
// context menu will be gone the moment we spawn the modal.
.action(
- "Spawn task",
+ "Spawn Task",
zed_actions::Spawn::modal().boxed_clone(),
)
});
@@ -1,36 +0,0 @@
-use gpui::{IntoElement, Render};
-use ui::{Divider, prelude::*, tooltip_container};
-
-pub struct TerminalTooltip {
- title: SharedString,
- pid: u32,
-}
-
-impl TerminalTooltip {
- pub fn new(title: impl Into<SharedString>, pid: u32) -> Self {
- Self {
- title: title.into(),
- pid,
- }
- }
-}
-
-impl Render for TerminalTooltip {
- fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- tooltip_container(cx, move |this, _cx| {
- this.occlude()
- .on_mouse_move(|_, _window, cx| cx.stop_propagation())
- .child(
- v_flex()
- .gap_1()
- .child(Label::new(self.title.clone()))
- .child(Divider::horizontal())
- .child(
- Label::new(format!("Process ID (PID): {}", self.pid))
- .color(Color::Muted)
- .size(LabelSize::Small),
- ),
- )
- })
- }
-}
@@ -4,7 +4,6 @@ pub mod terminal_panel;
mod terminal_path_like_target;
pub mod terminal_scrollbar;
mod terminal_slash_command;
-pub mod terminal_tab_tooltip;
use assistant_slash_command::SlashCommandRegistry;
use editor::{EditorSettings, actions::SelectAll, blink_manager::BlinkManager};
@@ -32,9 +31,8 @@ use terminal_panel::TerminalPanel;
use terminal_path_like_target::{hover_path_like_target, open_path_like_target};
use terminal_scrollbar::TerminalScrollHandle;
use terminal_slash_command::TerminalSlashCommand;
-use terminal_tab_tooltip::TerminalTooltip;
use ui::{
- ContextMenu, Icon, IconName, Label, ScrollAxes, Scrollbars, Tooltip, WithScrollbar, h_flex,
+ ContextMenu, Divider, ScrollAxes, Scrollbars, Tooltip, WithScrollbar,
prelude::*,
scrollbars::{self, GlobalSetting, ScrollbarVisibility},
};
@@ -1140,14 +1138,24 @@ impl Item for TerminalView {
type Event = ItemEvent;
fn tab_tooltip_content(&self, cx: &App) -> Option<TabTooltipContent> {
- let terminal = self.terminal().read(cx);
- let title = terminal.title(false);
- let pid = terminal.pid_getter()?.fallback_pid();
-
- Some(TabTooltipContent::Custom(Box::new(move |_window, cx| {
- cx.new(|_| TerminalTooltip::new(title.clone(), pid.as_u32()))
- .into()
- })))
+ Some(TabTooltipContent::Custom(Box::new(Tooltip::element({
+ let terminal = self.terminal().read(cx);
+ let title = terminal.title(false);
+ let pid = terminal.pid_getter()?.fallback_pid();
+
+ move |_, _| {
+ v_flex()
+ .gap_1()
+ .child(Label::new(title.clone()))
+ .child(h_flex().flex_grow().child(Divider::horizontal()))
+ .child(
+ Label::new(format!("Process ID (PID): {}", pid))
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ )
+ .into_any_element()
+ }
+ }))))
}
fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
@@ -8,10 +8,14 @@ use sum_tree::{Bias, Dimensions};
/// A timestamped position in a buffer
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct Anchor {
+ /// The timestamp of the operation that inserted the text
+ /// in which this anchor is located.
pub timestamp: clock::Lamport,
- /// The byte offset in the buffer
+ /// The byte offset into the text inserted in the operation
+ /// at `timestamp`.
pub offset: usize,
- /// Describes which character the anchor is biased towards
+ /// Whether this anchor stays attached to the character *before* or *after*
+ /// the offset.
pub bias: Bias,
pub buffer_id: Option<BufferId>,
}
@@ -182,7 +182,9 @@ impl TitleBar {
this.children(current_user_face_pile.map(|face_pile| {
v_flex()
- .on_mouse_down(MouseButton::Left, |_, _, cx| cx.stop_propagation())
+ .on_mouse_down(MouseButton::Left, |_, window, _| {
+ window.prevent_default()
+ })
.child(face_pile)
.child(render_color_ribbon(player_colors.local().cursor))
}))
@@ -217,6 +219,9 @@ impl TitleBar {
.child(facepile)
.child(render_color_ribbon(player_color.cursor))
.cursor_pointer()
+ .on_mouse_down(MouseButton::Left, |_, window, _| {
+ window.prevent_default()
+ })
.on_click({
let peer_id = collaborator.peer_id;
cx.listener(move |this, _, window, cx| {
@@ -4,7 +4,7 @@ use gpui::{
};
use theme::ActiveTheme;
-use crate::{ElevationIndex, h_flex};
+use crate::{ElevationIndex, IconButton, h_flex};
use super::ButtonLike;
@@ -15,6 +15,23 @@ pub enum SplitButtonStyle {
Transparent,
}
+pub enum SplitButtonKind {
+ ButtonLike(ButtonLike),
+ IconButton(IconButton),
+}
+
+impl From<IconButton> for SplitButtonKind {
+ fn from(icon_button: IconButton) -> Self {
+ Self::IconButton(icon_button)
+ }
+}
+
+impl From<ButtonLike> for SplitButtonKind {
+ fn from(button_like: ButtonLike) -> Self {
+ Self::ButtonLike(button_like)
+ }
+}
+
/// /// A button with two parts: a primary action on the left and a secondary action on the right.
///
/// The left side is a [`ButtonLike`] with the main action, while the right side can contain
@@ -23,15 +40,15 @@ pub enum SplitButtonStyle {
/// The two sections are visually separated by a divider, but presented as a unified control.
#[derive(IntoElement)]
pub struct SplitButton {
- pub left: ButtonLike,
- pub right: AnyElement,
+ left: SplitButtonKind,
+ right: AnyElement,
style: SplitButtonStyle,
}
impl SplitButton {
- pub fn new(left: ButtonLike, right: AnyElement) -> Self {
+ pub fn new(left: impl Into<SplitButtonKind>, right: AnyElement) -> Self {
Self {
- left,
+ left: left.into(),
right,
style: SplitButtonStyle::Filled,
}
@@ -56,7 +73,10 @@ impl RenderOnce for SplitButton {
this.border_1()
.border_color(cx.theme().colors().border.opacity(0.8))
})
- .child(div().flex_grow().child(self.left))
+ .child(div().flex_grow().child(match self.left {
+ SplitButtonKind::ButtonLike(button) => button.into_any_element(),
+ SplitButtonKind::IconButton(icon) => icon.into_any_element(),
+ }))
.child(
div()
.h_full()
@@ -485,6 +485,7 @@ pub struct Table<const COLS: usize = 3> {
interaction_state: Option<WeakEntity<TableInteractionState>>,
col_widths: Option<TableWidths<COLS>>,
map_row: Option<Rc<dyn Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement>>,
+ use_ui_font: bool,
empty_table_callback: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyElement>>,
}
@@ -498,6 +499,7 @@ impl<const COLS: usize> Table<COLS> {
rows: TableContents::Vec(Vec::new()),
interaction_state: None,
map_row: None,
+ use_ui_font: true,
empty_table_callback: None,
col_widths: None,
}
@@ -590,6 +592,11 @@ impl<const COLS: usize> Table<COLS> {
self
}
+ pub fn no_ui_font(mut self) -> Self {
+ self.use_ui_font = false;
+ self
+ }
+
pub fn map_row(
mut self,
callback: impl Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement + 'static,
@@ -618,8 +625,8 @@ fn base_cell_style(width: Option<Length>) -> Div {
.overflow_hidden()
}
-fn base_cell_style_text(width: Option<Length>, cx: &App) -> Div {
- base_cell_style(width).text_ui(cx)
+fn base_cell_style_text(width: Option<Length>, use_ui_font: bool, cx: &App) -> Div {
+ base_cell_style(width).when(use_ui_font, |el| el.text_ui(cx))
}
pub fn render_table_row<const COLS: usize>(
@@ -656,7 +663,12 @@ pub fn render_table_row<const COLS: usize>(
.map(IntoElement::into_any_element)
.into_iter()
.zip(column_widths)
- .map(|(cell, width)| base_cell_style_text(width, cx).px_1().py_0p5().child(cell)),
+ .map(|(cell, width)| {
+ base_cell_style_text(width, table_context.use_ui_font, cx)
+ .px_1()
+ .py_0p5()
+ .child(cell)
+ }),
);
let row = if let Some(map_row) = table_context.map_row {
@@ -700,7 +712,7 @@ pub fn render_table_header<const COLS: usize>(
.border_color(cx.theme().colors().border)
.children(headers.into_iter().enumerate().zip(column_widths).map(
|((header_idx, h), width)| {
- base_cell_style_text(width, cx)
+ base_cell_style_text(width, table_context.use_ui_font, cx)
.child(h)
.id(ElementId::NamedInteger(
shared_element_id.clone(),
@@ -739,6 +751,7 @@ pub struct TableRenderContext<const COLS: usize> {
pub total_row_count: usize,
pub column_widths: Option<[Length; COLS]>,
pub map_row: Option<Rc<dyn Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement>>,
+ pub use_ui_font: bool,
}
impl<const COLS: usize> TableRenderContext<COLS> {
@@ -748,6 +761,7 @@ impl<const COLS: usize> TableRenderContext<COLS> {
total_row_count: table.rows.len(),
column_widths: table.col_widths.as_ref().map(|widths| widths.lengths(cx)),
map_row: table.map_row.clone(),
+ use_ui_font: table.use_ui_font,
}
}
}
@@ -44,15 +44,16 @@ pub enum ToggleStyle {
pub struct Checkbox {
id: ElementId,
toggle_state: ToggleState,
+ style: ToggleStyle,
disabled: bool,
placeholder: bool,
- on_click: Option<Box<dyn Fn(&ToggleState, &ClickEvent, &mut Window, &mut App) + 'static>>,
filled: bool,
- style: ToggleStyle,
- tooltip: Option<Box<dyn Fn(&mut Window, &mut App) -> AnyView>>,
+ visualization: bool,
label: Option<SharedString>,
label_size: LabelSize,
label_color: Color,
+ tooltip: Option<Box<dyn Fn(&mut Window, &mut App) -> AnyView>>,
+ on_click: Option<Box<dyn Fn(&ToggleState, &ClickEvent, &mut Window, &mut App) + 'static>>,
}
impl Checkbox {
@@ -61,15 +62,16 @@ impl Checkbox {
Self {
id: id.into(),
toggle_state: checked,
+ style: ToggleStyle::default(),
disabled: false,
- on_click: None,
+ placeholder: false,
filled: false,
- style: ToggleStyle::default(),
- tooltip: None,
+ visualization: false,
label: None,
label_size: LabelSize::Default,
label_color: Color::Muted,
- placeholder: false,
+ tooltip: None,
+ on_click: None,
}
}
@@ -110,6 +112,13 @@ impl Checkbox {
self
}
+ /// Makes the checkbox look enabled but without pointer cursor and hover styles.
+ /// Primarily used for uninteractive markdown previews.
+ pub fn visualization_only(mut self, visualization: bool) -> Self {
+ self.visualization = visualization;
+ self
+ }
+
/// Sets the style of the checkbox using the specified [`ToggleStyle`].
pub fn style(mut self, style: ToggleStyle) -> Self {
self.style = style;
@@ -209,11 +218,10 @@ impl RenderOnce for Checkbox {
let size = Self::container_size();
let checkbox = h_flex()
+ .group(group_id.clone())
.id(self.id.clone())
- .justify_center()
- .items_center()
.size(size)
- .group(group_id.clone())
+ .justify_center()
.child(
div()
.flex()
@@ -230,7 +238,7 @@ impl RenderOnce for Checkbox {
.when(self.disabled, |this| {
this.bg(cx.theme().colors().element_disabled.opacity(0.6))
})
- .when(!self.disabled, |this| {
+ .when(!self.disabled && !self.visualization, |this| {
this.group_hover(group_id.clone(), |el| el.border_color(hover_border_color))
})
.when(self.placeholder, |this| {
@@ -250,20 +258,14 @@ impl RenderOnce for Checkbox {
.map(|this| {
if self.disabled {
this.cursor_not_allowed()
+ } else if self.visualization {
+ this.cursor_default()
} else {
this.cursor_pointer()
}
})
.gap(DynamicSpacing::Base06.rems(cx))
.child(checkbox)
- .when_some(
- self.on_click.filter(|_| !self.disabled),
- |this, on_click| {
- this.on_click(move |click, window, cx| {
- on_click(&self.toggle_state.inverse(), click, window, cx)
- })
- },
- )
.when_some(self.label, |this, label| {
this.child(
Label::new(label)
@@ -274,6 +276,14 @@ impl RenderOnce for Checkbox {
.when_some(self.tooltip, |this, tooltip| {
this.tooltip(move |window, cx| tooltip(window, cx))
})
+ .when_some(
+ self.on_click.filter(|_| !self.disabled),
+ |this, on_click| {
+ this.on_click(move |click, window, cx| {
+ on_click(&self.toggle_state.inverse(), click, window, cx)
+ })
+ },
+ )
}
}
@@ -281,11 +291,7 @@ impl RenderOnce for Checkbox {
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Default)]
pub enum SwitchColor {
#[default]
- Default,
Accent,
- Error,
- Warning,
- Success,
Custom(Hsla),
}
@@ -299,27 +305,10 @@ impl SwitchColor {
}
match self {
- SwitchColor::Default => {
- let colors = cx.theme().colors();
- let base_color = colors.text;
- let bg_color = colors.element_background.blend(base_color.opacity(0.08));
- (bg_color, colors.border_variant)
- }
SwitchColor::Accent => {
let status = cx.theme().status();
- (status.info.opacity(0.4), status.info.opacity(0.2))
- }
- SwitchColor::Error => {
- let status = cx.theme().status();
- (status.error.opacity(0.4), status.error.opacity(0.2))
- }
- SwitchColor::Warning => {
- let status = cx.theme().status();
- (status.warning.opacity(0.4), status.warning.opacity(0.2))
- }
- SwitchColor::Success => {
- let status = cx.theme().status();
- (status.success.opacity(0.4), status.success.opacity(0.2))
+ let colors = cx.theme().colors();
+ (status.info.opacity(0.4), colors.text_accent.opacity(0.2))
}
SwitchColor::Custom(color) => (*color, color.opacity(0.6)),
}
@@ -329,11 +318,7 @@ impl SwitchColor {
impl From<SwitchColor> for Color {
fn from(color: SwitchColor) -> Self {
match color {
- SwitchColor::Default => Color::Default,
SwitchColor::Accent => Color::Accent,
- SwitchColor::Error => Color::Error,
- SwitchColor::Warning => Color::Warning,
- SwitchColor::Success => Color::Success,
SwitchColor::Custom(_) => Color::Default,
}
}
@@ -939,6 +924,15 @@ impl Component for Checkbox {
.into_any_element(),
)],
),
+ example_group_with_title(
+ "Extra",
+ vec![single_example(
+ "Visualization-Only",
+ Checkbox::new("viz_only", ToggleState::Selected)
+ .visualization_only(true)
+ .into_any_element(),
+ )],
+ ),
])
.into_any_element(),
)
@@ -980,37 +974,8 @@ impl Component for Switch {
"Colors",
vec![
single_example(
- "Default",
- Switch::new("switch_default_style", ToggleState::Selected)
- .color(SwitchColor::Default)
- .on_click(|_, _, _cx| {})
- .into_any_element(),
- ),
- single_example(
- "Accent",
+ "Accent (Default)",
Switch::new("switch_accent_style", ToggleState::Selected)
- .color(SwitchColor::Accent)
- .on_click(|_, _, _cx| {})
- .into_any_element(),
- ),
- single_example(
- "Error",
- Switch::new("switch_error_style", ToggleState::Selected)
- .color(SwitchColor::Error)
- .on_click(|_, _, _cx| {})
- .into_any_element(),
- ),
- single_example(
- "Warning",
- Switch::new("switch_warning_style", ToggleState::Selected)
- .color(SwitchColor::Warning)
- .on_click(|_, _, _cx| {})
- .into_any_element(),
- ),
- single_example(
- "Success",
- Switch::new("switch_success_style", ToggleState::Selected)
- .color(SwitchColor::Success)
.on_click(|_, _, _cx| {})
.into_any_element(),
),
@@ -58,7 +58,7 @@ pub fn new_smol_command(program: impl AsRef<OsStr>) -> smol::process::Command {
}
#[cfg(target_os = "macos")]
-fn reset_exception_ports() {
+pub fn reset_exception_ports() {
use mach2::exception_types::{
EXC_MASK_ALL, EXCEPTION_DEFAULT, exception_behavior_t, exception_mask_t,
};
@@ -53,3 +53,20 @@ impl schemars::transform::Transform for DefaultDenyUnknownFields {
transform_subschemas(self, schema);
}
}
+
+/// Defaults `allowTrailingCommas` to `true`, for use with `json-language-server`.
+/// This can be applied to any schema that will be treated as `jsonc`.
+///
+/// Note that this is non-recursive and only applied to the root schema.
+#[derive(Clone)]
+pub struct AllowTrailingCommas;
+
+impl schemars::transform::Transform for AllowTrailingCommas {
+ fn transform(&mut self, schema: &mut schemars::Schema) {
+ if let Some(object) = schema.as_object_mut()
+ && !object.contains_key("allowTrailingCommas")
+ {
+ object.insert("allowTrailingCommas".to_string(), true.into());
+ }
+ }
+}
@@ -390,6 +390,8 @@ pub fn set_pre_exec_to_start_new_session(
use std::os::unix::process::CommandExt;
command.pre_exec(|| {
libc::setsid();
+ #[cfg(target_os = "macos")]
+ crate::command::reset_exception_ports();
Ok(())
});
};
@@ -717,7 +717,7 @@ mod test {
cx.update_global(|store: &mut SettingsStore, cx| {
store.update_user_settings(cx, |settings| {
settings.project.all_languages.languages.0.insert(
- LanguageName::new("Rust").0,
+ LanguageName::new_static("Rust").0,
LanguageSettingsContent {
auto_indent_on_paste: Some(false),
..Default::default()
@@ -11,7 +11,6 @@ use editor::{ClipboardSelection, Editor, SelectionEffects};
use gpui::Context;
use gpui::Window;
use language::Point;
-use multi_buffer::MultiBufferRow;
use settings::Settings;
struct HighlightOnYank;
@@ -198,11 +197,14 @@ impl Vim {
if kind.linewise() {
text.push('\n');
}
- clipboard_selections.push(ClipboardSelection {
- len: text.len() - initial_len,
- is_entire_line: false,
- first_line_indent: buffer.indent_size_for_line(MultiBufferRow(start.row)).len,
- });
+ clipboard_selections.push(ClipboardSelection::for_buffer(
+ text.len() - initial_len,
+ false,
+ start..end,
+ &buffer,
+ editor.project(),
+ cx,
+ ));
}
}
@@ -2382,9 +2382,10 @@ mod test {
Mode::Insert,
);
- cx.set_state("let a = (test::call(), 'p', my_macro!{ห});", Mode::Normal);
- cx.simulate_keystrokes("c a a");
- cx.assert_state("let a = (test::call(), 'p'ห);", Mode::Insert);
+ // TODO regressed with the up-to-date Rust grammar.
+ // cx.set_state("let a = (test::call(), 'p', my_macro!{ห});", Mode::Normal);
+ // cx.simulate_keystrokes("c a a");
+ // cx.assert_state("let a = (test::call(), 'p'ห);", Mode::Insert);
cx.set_state("let a = [test::call(ห), 300];", Mode::Normal);
cx.simulate_keystrokes("c i a");
@@ -1359,11 +1359,11 @@ impl WorkspaceDb {
// If a local workspace points to WSL, this check will cause us to wait for the
// WSL VM and file server to boot up. This can block for many seconds.
// Supported scenarios use remote workspaces.
- if !has_wsl_path
- && paths.paths().iter().all(|path| path.exists())
- && paths.paths().iter().any(|path| path.is_dir())
- {
- result.push((id, SerializedWorkspaceLocation::Local, paths));
+ if !has_wsl_path && paths.paths().iter().all(|path| path.exists()) {
+ // Only show directories in recent projects
+ if paths.paths().iter().any(|path| path.is_dir()) {
+ result.push((id, SerializedWorkspaceLocation::Local, paths));
+ }
} else {
delete_tasks.push(self.delete_workspace_by_id(id));
}
@@ -1641,20 +1641,18 @@ impl Workspace {
let (window_bounds, display) = if let Some(bounds) = window_bounds_override {
(Some(WindowBounds::Windowed(bounds)), None)
- } else {
- let restorable_bounds = serialized_workspace
- .as_ref()
- .and_then(|workspace| Some((workspace.display?, workspace.window_bounds?)))
- .or_else(|| {
- let (display, window_bounds) = DB.last_window().log_err()?;
- Some((display?, window_bounds?))
- });
-
- if let Some((serialized_display, serialized_status)) = restorable_bounds {
- (Some(serialized_status.0), Some(serialized_display))
+ } else if let Some(workspace) = serialized_workspace.as_ref() {
+ // Reopening an existing workspace - restore its saved bounds
+ if let (Some(display), Some(bounds)) =
+ (workspace.display, workspace.window_bounds.as_ref())
+ {
+ (Some(bounds.0), Some(display))
} else {
(None, None)
}
+ } else {
+ // New window - let GPUI's default_bounds() handle cascading
+ (None, None)
};
// Use the serialized workspace to construct the new window
@@ -6069,6 +6067,11 @@ impl Workspace {
.on_action(cx.listener(Workspace::cancel))
}
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn set_random_database_id(&mut self) {
+ self.database_id = Some(WorkspaceId(Uuid::new_v4().as_u64_pair().0 as i64));
+ }
+
#[cfg(any(test, feature = "test-support"))]
pub fn test_new(project: Entity<Project>, window: &mut Window, cx: &mut Context<Self>) -> Self {
use node_runtime::NodeRuntime;
@@ -428,7 +428,7 @@ impl Worktree {
let mut entry = Entry::new(
RelPath::empty().into(),
&metadata,
- &next_entry_id,
+ ProjectEntryId::new(&next_entry_id),
snapshot.root_char_bag,
None,
);
@@ -2736,13 +2736,30 @@ impl BackgroundScannerState {
}
}
- async fn insert_entry(
+ fn entry_id_for(
&mut self,
- mut entry: Entry,
- fs: &dyn Fs,
- watcher: &dyn Watcher,
- ) -> Entry {
- self.reuse_entry_id(&mut entry);
+ next_entry_id: &AtomicUsize,
+ path: &RelPath,
+ metadata: &fs::Metadata,
+ ) -> ProjectEntryId {
+ // If an entry with the same inode was removed from the worktree during this scan,
+ // then it *might* represent the same file or directory. But the OS might also have
+ // re-used the inode for a completely different file or directory.
+ //
+ // Conditionally reuse the old entry's id:
+ // * if the mtime is the same, the file was probably been renamed.
+ // * if the path is the same, the file may just have been updated
+ if let Some(removed_entry) = self.removed_entries.remove(&metadata.inode) {
+ if removed_entry.mtime == Some(metadata.mtime) || *removed_entry.path == *path {
+ return removed_entry.id;
+ }
+ } else if let Some(existing_entry) = self.snapshot.entry_for_path(path) {
+ return existing_entry.id;
+ }
+ ProjectEntryId::new(next_entry_id)
+ }
+
+ async fn insert_entry(&mut self, entry: Entry, fs: &dyn Fs, watcher: &dyn Watcher) -> Entry {
let entry = self.snapshot.insert_entry(entry, fs);
if entry.path.file_name() == Some(&DOT_GIT) {
self.insert_git_repository(entry.path.clone(), fs, watcher)
@@ -3389,13 +3406,13 @@ impl Entry {
fn new(
path: Arc<RelPath>,
metadata: &fs::Metadata,
- next_entry_id: &AtomicUsize,
+ id: ProjectEntryId,
root_char_bag: CharBag,
canonical_path: Option<Arc<Path>>,
) -> Self {
let char_bag = char_bag_for_path(root_char_bag, &path);
Self {
- id: ProjectEntryId::new(next_entry_id),
+ id,
kind: if metadata.is_dir {
EntryKind::PendingDir
} else {
@@ -3682,8 +3699,10 @@ impl BackgroundScanner {
.await;
if ignore_stack.is_abs_path_ignored(root_abs_path.as_path(), true) {
root_entry.is_ignored = true;
+ let mut root_entry = root_entry.clone();
+ state.reuse_entry_id(&mut root_entry);
state
- .insert_entry(root_entry.clone(), self.fs.as_ref(), self.watcher.as_ref())
+ .insert_entry(root_entry, self.fs.as_ref(), self.watcher.as_ref())
.await;
}
if root_entry.is_dir() {
@@ -4289,7 +4308,7 @@ impl BackgroundScanner {
let mut child_entry = Entry::new(
child_path.clone(),
&child_metadata,
- &next_entry_id,
+ ProjectEntryId::new(&next_entry_id),
root_char_bag,
None,
);
@@ -4476,10 +4495,11 @@ impl BackgroundScanner {
.ignore_stack_for_abs_path(&abs_path, metadata.is_dir, self.fs.as_ref())
.await;
let is_external = !canonical_path.starts_with(&root_canonical_path);
+ let entry_id = state.entry_id_for(self.next_entry_id.as_ref(), path, &metadata);
let mut fs_entry = Entry::new(
path.clone(),
&metadata,
- self.next_entry_id.as_ref(),
+ entry_id,
state.snapshot.root_char_bag,
if metadata.is_symlink {
Some(canonical_path.as_path().to_path_buf().into())
@@ -1533,6 +1533,175 @@ async fn test_create_dir_all_on_create_entry(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_create_file_in_expanded_gitignored_dir(cx: &mut TestAppContext) {
+ // Tests the behavior of our worktree refresh when a file in a gitignored directory
+ // is created.
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ "/root",
+ json!({
+ ".gitignore": "ignored_dir\n",
+ "ignored_dir": {
+ "existing_file.txt": "existing content",
+ "another_file.txt": "another content",
+ },
+ }),
+ )
+ .await;
+
+ let tree = Worktree::local(
+ Path::new("/root"),
+ true,
+ fs.clone(),
+ Default::default(),
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
+
+ cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete())
+ .await;
+
+ tree.read_with(cx, |tree, _| {
+ let ignored_dir = tree.entry_for_path(rel_path("ignored_dir")).unwrap();
+ assert!(ignored_dir.is_ignored);
+ assert_eq!(ignored_dir.kind, EntryKind::UnloadedDir);
+ });
+
+ tree.update(cx, |tree, cx| {
+ tree.load_file(rel_path("ignored_dir/existing_file.txt"), cx)
+ })
+ .await
+ .unwrap();
+
+ tree.read_with(cx, |tree, _| {
+ let ignored_dir = tree.entry_for_path(rel_path("ignored_dir")).unwrap();
+ assert!(ignored_dir.is_ignored);
+ assert_eq!(ignored_dir.kind, EntryKind::Dir);
+
+ assert!(
+ tree.entry_for_path(rel_path("ignored_dir/existing_file.txt"))
+ .is_some()
+ );
+ assert!(
+ tree.entry_for_path(rel_path("ignored_dir/another_file.txt"))
+ .is_some()
+ );
+ });
+
+ let entry = tree
+ .update(cx, |tree, cx| {
+ tree.create_entry(rel_path("ignored_dir/new_file.txt").into(), false, None, cx)
+ })
+ .await
+ .unwrap();
+ assert!(entry.into_included().is_some());
+
+ cx.executor().run_until_parked();
+
+ tree.read_with(cx, |tree, _| {
+ let ignored_dir = tree.entry_for_path(rel_path("ignored_dir")).unwrap();
+ assert!(ignored_dir.is_ignored);
+ assert_eq!(
+ ignored_dir.kind,
+ EntryKind::Dir,
+ "ignored_dir should still be loaded, not UnloadedDir"
+ );
+
+ assert!(
+ tree.entry_for_path(rel_path("ignored_dir/existing_file.txt"))
+ .is_some(),
+ "existing_file.txt should still be visible"
+ );
+ assert!(
+ tree.entry_for_path(rel_path("ignored_dir/another_file.txt"))
+ .is_some(),
+ "another_file.txt should still be visible"
+ );
+ assert!(
+ tree.entry_for_path(rel_path("ignored_dir/new_file.txt"))
+ .is_some(),
+ "new_file.txt should be visible"
+ );
+ });
+}
+
+#[gpui::test]
+async fn test_fs_event_for_gitignored_dir_does_not_lose_contents(cx: &mut TestAppContext) {
+ // Tests the behavior of our worktree refresh when a directory modification for a gitignored directory
+ // is triggered.
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ "/root",
+ json!({
+ ".gitignore": "ignored_dir\n",
+ "ignored_dir": {
+ "file1.txt": "content1",
+ "file2.txt": "content2",
+ },
+ }),
+ )
+ .await;
+
+ let tree = Worktree::local(
+ Path::new("/root"),
+ true,
+ fs.clone(),
+ Default::default(),
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
+
+ cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete())
+ .await;
+
+ // Load a file to expand the ignored directory
+ tree.update(cx, |tree, cx| {
+ tree.load_file(rel_path("ignored_dir/file1.txt"), cx)
+ })
+ .await
+ .unwrap();
+
+ tree.read_with(cx, |tree, _| {
+ let ignored_dir = tree.entry_for_path(rel_path("ignored_dir")).unwrap();
+ assert_eq!(ignored_dir.kind, EntryKind::Dir);
+ assert!(
+ tree.entry_for_path(rel_path("ignored_dir/file1.txt"))
+ .is_some()
+ );
+ assert!(
+ tree.entry_for_path(rel_path("ignored_dir/file2.txt"))
+ .is_some()
+ );
+ });
+
+ fs.emit_fs_event("/root/ignored_dir", Some(fs::PathEventKind::Changed));
+ tree.flush_fs_events(cx).await;
+
+ tree.read_with(cx, |tree, _| {
+ let ignored_dir = tree.entry_for_path(rel_path("ignored_dir")).unwrap();
+ assert_eq!(
+ ignored_dir.kind,
+ EntryKind::Dir,
+ "ignored_dir should still be loaded (Dir), not UnloadedDir"
+ );
+ assert!(
+ tree.entry_for_path(rel_path("ignored_dir/file1.txt"))
+ .is_some(),
+ "file1.txt should still be visible after directory fs event"
+ );
+ assert!(
+ tree.entry_for_path(rel_path("ignored_dir/file2.txt"))
+ .is_some(),
+ "file2.txt should still be visible after directory fs event"
+ );
+ });
+}
+
#[gpui::test(iterations = 100)]
async fn test_random_worktree_operations_during_initial_scan(
cx: &mut TestAppContext,
@@ -30,6 +30,17 @@ pub enum Model {
alias = "grok-4-fast-non-reasoning-latest"
)]
Grok4FastNonReasoning,
+ #[serde(
+ rename = "grok-4-1-fast-non-reasoning",
+ alias = "grok-4-1-fast-non-reasoning-latest"
+ )]
+ Grok41FastNonReasoning,
+ #[serde(
+ rename = "grok-4-1-fast-reasoning",
+ alias = "grok-4-1-fast-reasoning-latest",
+ alias = "grok-4-1-fast"
+ )]
+ Grok41FastReasoning,
#[serde(rename = "grok-code-fast-1", alias = "grok-code-fast-1-0825")]
GrokCodeFast1,
#[serde(rename = "custom")]
@@ -56,6 +67,9 @@ impl Model {
"grok-4" => Ok(Self::Grok4),
"grok-4-fast-reasoning" => Ok(Self::Grok4FastReasoning),
"grok-4-fast-non-reasoning" => Ok(Self::Grok4FastNonReasoning),
+ "grok-4-1-fast-non-reasoning" => Ok(Self::Grok41FastNonReasoning),
+ "grok-4-1-fast-reasoning" => Ok(Self::Grok41FastReasoning),
+ "grok-4-1-fast" => Ok(Self::Grok41FastReasoning),
"grok-2-vision" => Ok(Self::Grok2Vision),
"grok-3" => Ok(Self::Grok3),
"grok-3-mini" => Ok(Self::Grok3Mini),
@@ -76,6 +90,8 @@ impl Model {
Self::Grok4 => "grok-4",
Self::Grok4FastReasoning => "grok-4-fast-reasoning",
Self::Grok4FastNonReasoning => "grok-4-fast-non-reasoning",
+ Self::Grok41FastNonReasoning => "grok-4-1-fast-non-reasoning",
+ Self::Grok41FastReasoning => "grok-4-1-fast-reasoning",
Self::GrokCodeFast1 => "grok-code-fast-1",
Self::Custom { name, .. } => name,
}
@@ -91,6 +107,8 @@ impl Model {
Self::Grok4 => "Grok 4",
Self::Grok4FastReasoning => "Grok 4 Fast",
Self::Grok4FastNonReasoning => "Grok 4 Fast (Non-Reasoning)",
+ Self::Grok41FastNonReasoning => "Grok 4.1 Fast (Non-Reasoning)",
+ Self::Grok41FastReasoning => "Grok 4.1 Fast",
Self::GrokCodeFast1 => "Grok Code Fast 1",
Self::Custom {
name, display_name, ..
@@ -102,7 +120,10 @@ impl Model {
match self {
Self::Grok3 | Self::Grok3Mini | Self::Grok3Fast | Self::Grok3MiniFast => 131_072,
Self::Grok4 | Self::GrokCodeFast1 => 256_000,
- Self::Grok4FastReasoning | Self::Grok4FastNonReasoning => 128_000,
+ Self::Grok4FastReasoning
+ | Self::Grok4FastNonReasoning
+ | Self::Grok41FastNonReasoning
+ | Self::Grok41FastReasoning => 2_000_000,
Self::Grok2Vision => 8_192,
Self::Custom { max_tokens, .. } => *max_tokens,
}
@@ -114,6 +135,8 @@ impl Model {
Self::Grok4
| Self::Grok4FastReasoning
| Self::Grok4FastNonReasoning
+ | Self::Grok41FastNonReasoning
+ | Self::Grok41FastReasoning
| Self::GrokCodeFast1 => Some(64_000),
Self::Grok2Vision => Some(4_096),
Self::Custom {
@@ -131,7 +154,9 @@ impl Model {
| Self::Grok3MiniFast
| Self::Grok4
| Self::Grok4FastReasoning
- | Self::Grok4FastNonReasoning => true,
+ | Self::Grok4FastNonReasoning
+ | Self::Grok41FastNonReasoning
+ | Self::Grok41FastReasoning => true,
Self::Custom {
parallel_tool_calls: Some(support),
..
@@ -154,6 +179,8 @@ impl Model {
| Self::Grok4
| Self::Grok4FastReasoning
| Self::Grok4FastNonReasoning
+ | Self::Grok41FastNonReasoning
+ | Self::Grok41FastReasoning
| Self::GrokCodeFast1 => true,
Self::Custom {
supports_tools: Some(support),
@@ -165,7 +192,12 @@ impl Model {
pub fn supports_images(&self) -> bool {
match self {
- Self::Grok2Vision => true,
+ Self::Grok2Vision
+ | Self::Grok4
+ | Self::Grok4FastReasoning
+ | Self::Grok4FastNonReasoning
+ | Self::Grok41FastNonReasoning
+ | Self::Grok41FastReasoning => true,
Self::Custom {
supports_images: Some(support),
..
@@ -10,6 +10,9 @@ authors = ["Zed Team <hi@zed.dev>"]
[lints]
workspace = true
+[features]
+tracy = ["ztracing/tracy"]
+
[[bin]]
name = "zed"
path = "src/zed-main.rs"
@@ -50,7 +53,6 @@ debugger_tools.workspace = true
debugger_ui.workspace = true
diagnostics.workspace = true
editor.workspace = true
-zeta2_tools.workspace = true
env_logger.workspace = true
extension.workspace = true
extension_host.workspace = true
@@ -74,7 +76,8 @@ gpui = { workspace = true, features = [
gpui_tokio.workspace = true
rayon.workspace = true
-edit_prediction_button.workspace = true
+edit_prediction.workspace = true
+edit_prediction_ui.workspace = true
http_client.workspace = true
image_viewer.workspace = true
inspector_ui.workspace = true
@@ -144,6 +147,8 @@ theme_extension.workspace = true
theme_selector.workspace = true
time.workspace = true
title_bar.workspace = true
+ztracing.workspace = true
+tracing.workspace = true
toolchain_selector.workspace = true
ui.workspace = true
ui_input.workspace = true
@@ -160,7 +165,6 @@ web_search_providers.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zed_env_vars.workspace = true
-zeta.workspace = true
zlog.workspace = true
zlog_settings.workspace = true
chrono.workspace = true
@@ -224,4 +228,4 @@ osx_info_plist_exts = ["resources/info/*"]
osx_url_schemes = ["zed"]
[package.metadata.cargo-machete]
-ignored = ["profiling", "zstd"]
+ignored = ["profiling", "zstd", "tracing"]
@@ -162,10 +162,11 @@ fn fail_to_open_window(e: anyhow::Error, _cx: &mut App) {
.detach();
}
}
-
pub static STARTUP_TIME: OnceLock<Instant> = OnceLock::new();
pub fn main() {
+ ztracing::init();
+
STARTUP_TIME.get_or_init(|| Instant::now());
#[cfg(unix)]
@@ -581,7 +582,7 @@ pub fn main() {
language_model::init(app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
acp_tools::init(cx);
- zeta2_tools::init(cx);
+ edit_prediction_ui::init(cx);
web_search::init(cx);
web_search_providers::init(app_state.client.clone(), cx);
snippet_provider::init(cx);
@@ -640,7 +641,7 @@ pub fn main() {
settings_ui::init(cx);
keymap_editor::init(cx);
extensions_ui::init(cx);
- zeta::init(cx);
+ edit_prediction::init(cx);
inspector_ui::init(app_state.clone(), cx);
json_schema_store::init(cx);
miniprofiler_ui::init(*STARTUP_TIME.get().unwrap(), cx);
@@ -401,8 +401,8 @@ pub fn initialize_workspace(
unstable_version_notification(cx);
let edit_prediction_menu_handle = PopoverMenuHandle::default();
- let edit_prediction_button = cx.new(|cx| {
- edit_prediction_button::EditPredictionButton::new(
+ let edit_prediction_ui = cx.new(|cx| {
+ edit_prediction_ui::EditPredictionButton::new(
app_state.fs.clone(),
app_state.user_store.clone(),
edit_prediction_menu_handle.clone(),
@@ -411,7 +411,7 @@ pub fn initialize_workspace(
)
});
workspace.register_action({
- move |_, _: &edit_prediction_button::ToggleMenu, window, cx| {
+ move |_, _: &edit_prediction_ui::ToggleMenu, window, cx| {
edit_prediction_menu_handle.toggle(window, cx);
}
});
@@ -450,7 +450,7 @@ pub fn initialize_workspace(
status_bar.add_left_item(lsp_button, window, cx);
status_bar.add_left_item(diagnostic_summary, window, cx);
status_bar.add_left_item(activity_indicator, window, cx);
- status_bar.add_right_item(edit_prediction_button, window, cx);
+ status_bar.add_right_item(edit_prediction_ui, window, cx);
status_bar.add_right_item(active_buffer_language, window, cx);
status_bar.add_right_item(active_toolchain_language, window, cx);
status_bar.add_right_item(line_ending_indicator, window, cx);
@@ -2255,7 +2255,8 @@ mod tests {
Action, AnyWindowHandle, App, AssetSource, BorrowAppContext, TestAppContext, UpdateGlobal,
VisualTestContext, WindowHandle, actions,
};
- use language::{LanguageMatcher, LanguageRegistry};
+ use language::LanguageRegistry;
+ use languages::{markdown_lang, rust_lang};
use pretty_assertions::{assert_eq, assert_ne};
use project::{Project, ProjectPath};
use semver::Version;
@@ -2895,9 +2896,7 @@ mod tests {
.await;
let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await;
- project.update(cx, |project, _cx| {
- project.languages().add(markdown_language())
- });
+ project.update(cx, |project, _cx| project.languages().add(markdown_lang()));
let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx));
let workspace = window.root(cx).unwrap();
@@ -3327,9 +3326,7 @@ mod tests {
.await;
let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await;
- project.update(cx, |project, _cx| {
- project.languages().add(markdown_language())
- });
+ project.update(cx, |project, _cx| project.languages().add(markdown_lang()));
let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx));
let workspace = window.root(cx).unwrap();
@@ -3421,9 +3418,7 @@ mod tests {
.await;
let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await;
- project.update(cx, |project, _cx| {
- project.languages().add(markdown_language())
- });
+ project.update(cx, |project, _cx| project.languages().add(markdown_lang()));
let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx));
let workspace = window.root(cx).unwrap();
@@ -3494,7 +3489,7 @@ mod tests {
let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await;
project.update(cx, |project, _| {
- project.languages().add(markdown_language());
+ project.languages().add(markdown_lang());
project.languages().add(rust_lang());
});
let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx));
@@ -3647,8 +3642,8 @@ mod tests {
let project = Project::test(app_state.fs.clone(), [], cx).await;
project.update(cx, |project, _| {
- project.languages().add(rust_lang());
- project.languages().add(markdown_language());
+ project.languages().add(language::rust_lang());
+ project.languages().add(language::markdown_lang());
});
let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx));
@@ -3727,9 +3722,7 @@ mod tests {
.await;
let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await;
- project.update(cx, |project, _cx| {
- project.languages().add(markdown_language())
- });
+ project.update(cx, |project, _cx| project.languages().add(markdown_lang()));
let window = cx.add_window(|window, cx| Workspace::test_new(project, window, cx));
let workspace = window.root(cx).unwrap();
@@ -3831,9 +3824,7 @@ mod tests {
.await;
let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await;
- project.update(cx, |project, _cx| {
- project.languages().add(markdown_language())
- });
+ project.update(cx, |project, _cx| project.languages().add(markdown_lang()));
let workspace =
cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
let pane = workspace
@@ -4225,9 +4216,7 @@ mod tests {
.await;
let project = Project::test(app_state.fs.clone(), [path!("/root").as_ref()], cx).await;
- project.update(cx, |project, _cx| {
- project.languages().add(markdown_language())
- });
+ project.update(cx, |project, _cx| project.languages().add(markdown_lang()));
let workspace = cx.add_window(|window, cx| Workspace::test_new(project, window, cx));
let pane = workspace
.read_with(cx, |workspace, _| workspace.active_pane().clone())
@@ -4914,7 +4903,7 @@ mod tests {
let state = Arc::get_mut(&mut app_state).unwrap();
state.build_window_options = build_window_options;
- app_state.languages.add(markdown_language());
+ app_state.languages.add(markdown_lang());
gpui_tokio::init(cx);
theme::init(theme::LoadThemes::JustBase, cx);
@@ -4965,34 +4954,6 @@ mod tests {
})
}
- fn rust_lang() -> Arc<language::Language> {
- Arc::new(language::Language::new(
- language::LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- ))
- }
-
- fn markdown_language() -> Arc<language::Language> {
- Arc::new(language::Language::new(
- language::LanguageConfig {
- name: "Markdown".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["md".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_md::LANGUAGE.into()),
- ))
- }
-
#[track_caller]
fn assert_key_bindings_for(
window: AnyWindowHandle,
@@ -247,7 +247,10 @@ pub fn app_menus(cx: &mut App) -> Vec<Menu> {
MenuItem::action("Go to Definition", editor::actions::GoToDefinition),
MenuItem::action("Go to Declaration", editor::actions::GoToDeclaration),
MenuItem::action("Go to Type Definition", editor::actions::GoToTypeDefinition),
- MenuItem::action("Find All References", editor::actions::FindAllReferences),
+ MenuItem::action(
+ "Find All References",
+ editor::actions::FindAllReferences::default(),
+ ),
MenuItem::separator(),
MenuItem::action("Next Problem", editor::actions::GoToDiagnostic::default()),
MenuItem::action(
@@ -1,20 +1,21 @@
use client::{Client, UserStore};
-use codestral::CodestralCompletionProvider;
+use codestral::CodestralEditPredictionDelegate;
use collections::HashMap;
-use copilot::{Copilot, CopilotCompletionProvider};
+use copilot::{Copilot, CopilotEditPredictionDelegate};
+use edit_prediction::{SweepFeatureFlag, ZedEditPredictionDelegate, Zeta2FeatureFlag};
use editor::Editor;
use feature_flags::FeatureFlagAppExt;
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
use language_models::MistralLanguageModelProvider;
use settings::{
+ EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME,
EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME, SettingsStore,
};
use std::{cell::RefCell, rc::Rc, sync::Arc};
-use supermaven::{Supermaven, SupermavenCompletionProvider};
+use supermaven::{Supermaven, SupermavenEditPredictionDelegate};
use ui::Window;
-use zeta::{SweepFeatureFlag, Zeta2FeatureFlag, ZetaEditPredictionProvider};
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
@@ -59,7 +60,7 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
})
.detach();
- cx.on_action(clear_zeta_edit_history);
+ cx.on_action(clear_edit_prediction_store_edit_history);
let mut provider = all_language_settings(None, cx).edit_predictions.provider;
cx.subscribe(&user_store, {
@@ -100,9 +101,9 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
.detach();
}
-fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) {
- if let Some(zeta) = zeta::Zeta::try_global(cx) {
- zeta.update(cx, |zeta, _| zeta.clear_history());
+fn clear_edit_prediction_store_edit_history(_: &edit_prediction::ClearHistory, cx: &mut App) {
+ if let Some(ep_store) = edit_prediction::EditPredictionStore::try_global(cx) {
+ ep_store.update(cx, |ep_store, _| ep_store.clear_history());
}
}
@@ -176,7 +177,7 @@ fn assign_edit_prediction_provider(
match provider {
EditPredictionProvider::None => {
- editor.set_edit_prediction_provider::<ZetaEditPredictionProvider>(None, window, cx);
+ editor.set_edit_prediction_provider::<ZedEditPredictionDelegate>(None, window, cx);
}
EditPredictionProvider::Copilot => {
if let Some(copilot) = Copilot::global(cx) {
@@ -187,55 +188,65 @@ fn assign_edit_prediction_provider(
copilot.register_buffer(&buffer, cx);
});
}
- let provider = cx.new(|_| CopilotCompletionProvider::new(copilot));
+ let provider = cx.new(|_| CopilotEditPredictionDelegate::new(copilot));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
EditPredictionProvider::Supermaven => {
if let Some(supermaven) = Supermaven::global(cx) {
- let provider = cx.new(|_| SupermavenCompletionProvider::new(supermaven));
+ let provider = cx.new(|_| SupermavenEditPredictionDelegate::new(supermaven));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
EditPredictionProvider::Codestral => {
let http_client = client.http_client();
- let provider = cx.new(|_| CodestralCompletionProvider::new(http_client));
+ let provider = cx.new(|_| CodestralEditPredictionDelegate::new(http_client));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
value @ (EditPredictionProvider::Experimental(_) | EditPredictionProvider::Zed) => {
- let zeta = zeta::Zeta::global(client, &user_store, cx);
+ let ep_store = edit_prediction::EditPredictionStore::global(client, &user_store, cx);
if let Some(project) = editor.project()
&& let Some(buffer) = &singleton_buffer
&& buffer.read(cx).file().is_some()
{
- let has_model = zeta.update(cx, |zeta, cx| {
+ let has_model = ep_store.update(cx, |ep_store, cx| {
let model = if let EditPredictionProvider::Experimental(name) = value {
if name == EXPERIMENTAL_SWEEP_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<SweepFeatureFlag>()
{
- zeta::ZetaEditPredictionModel::Sweep
+ edit_prediction::EditPredictionModel::Sweep
} else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<Zeta2FeatureFlag>()
{
- zeta::ZetaEditPredictionModel::Zeta2
+ edit_prediction::EditPredictionModel::Zeta2
+ } else if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME
+ && cx.has_flag::<Zeta2FeatureFlag>()
+ {
+ edit_prediction::EditPredictionModel::Mercury
} else {
return false;
}
} else if user_store.read(cx).current_user().is_some() {
- zeta::ZetaEditPredictionModel::Zeta1
+ edit_prediction::EditPredictionModel::Zeta1
} else {
return false;
};
- zeta.set_edit_prediction_model(model);
- zeta.register_buffer(buffer, project, cx);
+ ep_store.set_edit_prediction_model(model);
+ ep_store.register_buffer(buffer, project, cx);
true
});
if has_model {
let provider = cx.new(|cx| {
- ZetaEditPredictionProvider::new(project.clone(), &client, &user_store, cx)
+ ZedEditPredictionDelegate::new(
+ project.clone(),
+ singleton_buffer,
+ &client,
+ &user_store,
+ cx,
+ )
});
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
@@ -1,84 +0,0 @@
-[package]
-name = "zeta"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/zeta.rs"
-
-[features]
-eval-support = []
-
-[dependencies]
-ai_onboarding.workspace = true
-anyhow.workspace = true
-arrayvec.workspace = true
-brotli.workspace = true
-buffer_diff.workspace = true
-client.workspace = true
-cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
-collections.workspace = true
-command_palette_hooks.workspace = true
-copilot.workspace = true
-credentials_provider.workspace = true
-db.workspace = true
-edit_prediction.workspace = true
-edit_prediction_context.workspace = true
-editor.workspace = true
-feature_flags.workspace = true
-fs.workspace = true
-futures.workspace = true
-gpui.workspace = true
-indoc.workspace = true
-itertools.workspace = true
-language.workspace = true
-language_model.workspace = true
-log.workspace = true
-lsp.workspace = true
-markdown.workspace = true
-menu.workspace = true
-open_ai.workspace = true
-postage.workspace = true
-pretty_assertions.workspace = true
-project.workspace = true
-rand.workspace = true
-regex.workspace = true
-release_channel.workspace = true
-semver.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-settings.workspace = true
-smol.workspace = true
-strsim.workspace = true
-strum.workspace = true
-telemetry.workspace = true
-telemetry_events.workspace = true
-theme.workspace = true
-thiserror.workspace = true
-ui.workspace = true
-util.workspace = true
-uuid.workspace = true
-workspace.workspace = true
-worktree.workspace = true
-zed_actions.workspace = true
-
-[dev-dependencies]
-clock = { workspace = true, features = ["test-support"] }
-cloud_api_types.workspace = true
-cloud_llm_client = { workspace = true, features = ["test-support"] }
-ctor.workspace = true
-gpui = { workspace = true, features = ["test-support"] }
-indoc.workspace = true
-language = { workspace = true, features = ["test-support"] }
-language_model = { workspace = true, features = ["test-support"] }
-lsp.workspace = true
-parking_lot.workspace = true
-project = { workspace = true, features = ["test-support"] }
-settings = { workspace = true, features = ["test-support"] }
-zlog.workspace = true
@@ -1,173 +0,0 @@
-use cloud_llm_client::predict_edits_v3::Excerpt;
-use edit_prediction_context::Line;
-use language::{BufferSnapshot, Point};
-use std::ops::Range;
-
-pub fn assemble_excerpts(
- buffer: &BufferSnapshot,
- merged_line_ranges: impl IntoIterator<Item = Range<Line>>,
-) -> Vec<Excerpt> {
- let mut output = Vec::new();
-
- let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
- let mut outline_items = outline_items.into_iter().peekable();
-
- for range in merged_line_ranges {
- let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0);
-
- while let Some(outline_item) = outline_items.peek() {
- if outline_item.range.start >= point_range.start {
- break;
- }
- if outline_item.range.end > point_range.start {
- let mut point_range = outline_item.source_range_for_text.clone();
- point_range.start.column = 0;
- point_range.end.column = buffer.line_len(point_range.end.row);
-
- output.push(Excerpt {
- start_line: Line(point_range.start.row),
- text: buffer
- .text_for_range(point_range.clone())
- .collect::<String>()
- .into(),
- })
- }
- outline_items.next();
- }
-
- output.push(Excerpt {
- start_line: Line(point_range.start.row),
- text: buffer
- .text_for_range(point_range.clone())
- .collect::<String>()
- .into(),
- })
- }
-
- output
-}
-
-#[cfg(test)]
-mod tests {
- use std::sync::Arc;
-
- use super::*;
- use cloud_llm_client::predict_edits_v3;
- use gpui::{TestAppContext, prelude::*};
- use indoc::indoc;
- use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt};
- use pretty_assertions::assert_eq;
- use util::test::marked_text_ranges;
-
- #[gpui::test]
- fn test_rust(cx: &mut TestAppContext) {
- let table = [
- (
- indoc! {r#"
- struct User {
- first_name: String,
- ยซ last_name: String,
- ageห: u32,
- ยป email: String,
- create_at: Instant,
- }
-
- impl User {
- pub fn first_name(&self) -> String {
- self.first_name.clone()
- }
-
- pub fn full_name(&self) -> String {
- ยซ format!("{} {}", self.first_name, self.last_name)
- ยป }
- }
- "#},
- indoc! {r#"
- 1|struct User {
- โฆ
- 3| last_name: String,
- 4| age<|cursor|>: u32,
- โฆ
- 9|impl User {
- โฆ
- 14| pub fn full_name(&self) -> String {
- 15| format!("{} {}", self.first_name, self.last_name)
- โฆ
- "#},
- ),
- (
- indoc! {r#"
- struct User {
- first_name: String,
- ยซ last_name: String,
- age: u32,
- }
- ยป"#
- },
- indoc! {r#"
- 1|struct User {
- โฆ
- 3| last_name: String,
- 4| age: u32,
- 5|}
- "#},
- ),
- ];
-
- for (input, expected_output) in table {
- let input_without_ranges = input.replace(['ยซ', 'ยป'], "");
- let input_without_caret = input.replace('ห', "");
- let cursor_offset = input_without_ranges.find('ห');
- let (input, ranges) = marked_text_ranges(&input_without_caret, false);
- let buffer =
- cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx));
- buffer.read_with(cx, |buffer, _cx| {
- let insertions = cursor_offset
- .map(|offset| {
- let point = buffer.offset_to_point(offset);
- vec![(
- predict_edits_v3::Point {
- line: Line(point.row),
- column: point.column,
- },
- "<|cursor|>",
- )]
- })
- .unwrap_or_default();
- let ranges: Vec<Range<Line>> = ranges
- .into_iter()
- .map(|range| {
- let point_range = range.to_point(&buffer);
- Line(point_range.start.row)..Line(point_range.end.row)
- })
- .collect();
-
- let mut output = String::new();
- cloud_zeta2_prompt::write_excerpts(
- assemble_excerpts(&buffer.snapshot(), ranges).iter(),
- &insertions,
- Line(buffer.max_point().row),
- true,
- &mut output,
- );
- assert_eq!(output, expected_output);
- });
- }
- }
-
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(language::tree_sitter_rust::LANGUAGE.into()),
- )
- .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
- .unwrap()
- }
-}
@@ -1,642 +0,0 @@
-use anyhow::Result;
-use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
-use collections::HashMap;
-use futures::{
- StreamExt,
- channel::mpsc::{self, UnboundedSender},
-};
-use gpui::{AppContext, AsyncApp, Entity};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
-use project::{
- Project, WorktreeSettings,
- search::{SearchQuery, SearchResult},
-};
-use smol::channel;
-use std::ops::Range;
-use util::{
- ResultExt as _,
- paths::{PathMatcher, PathStyle},
-};
-use workspace::item::Settings as _;
-
-#[cfg(feature = "eval-support")]
-type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<usize>>>;
-
-pub async fn run_retrieval_searches(
- queries: Vec<SearchToolQuery>,
- project: Entity<Project>,
- #[cfg(feature = "eval-support")] eval_cache: Option<std::sync::Arc<dyn crate::EvalCache>>,
- cx: &mut AsyncApp,
-) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
- #[cfg(feature = "eval-support")]
- let cache = if let Some(eval_cache) = eval_cache {
- use crate::EvalCacheEntryKind;
- use anyhow::Context;
- use collections::FxHasher;
- use std::hash::{Hash, Hasher};
-
- let mut hasher = FxHasher::default();
- project.read_with(cx, |project, cx| {
- let mut worktrees = project.worktrees(cx);
- let Some(worktree) = worktrees.next() else {
- panic!("Expected a single worktree in eval project. Found none.");
- };
- assert!(
- worktrees.next().is_none(),
- "Expected a single worktree in eval project. Found more than one."
- );
- worktree.read(cx).abs_path().hash(&mut hasher);
- })?;
-
- queries.hash(&mut hasher);
- let key = (EvalCacheEntryKind::Search, hasher.finish());
-
- if let Some(cached_results) = eval_cache.read(key) {
- let file_results = serde_json::from_str::<CachedSearchResults>(&cached_results)
- .context("Failed to deserialize cached search results")?;
- let mut results = HashMap::default();
-
- for (path, ranges) in file_results {
- let buffer = project
- .update(cx, |project, cx| {
- let project_path = project.find_project_path(path, cx).unwrap();
- project.open_buffer(project_path, cx)
- })?
- .await?;
- let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
- let mut ranges: Vec<_> = ranges
- .into_iter()
- .map(|range| {
- snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
- })
- .collect();
- merge_anchor_ranges(&mut ranges, &snapshot);
- results.insert(buffer, ranges);
- }
-
- return Ok(results);
- }
-
- Some((eval_cache, serde_json::to_string_pretty(&queries)?, key))
- } else {
- None
- };
-
- let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
- let global_settings = WorktreeSettings::get_global(cx);
- let exclude_patterns = global_settings
- .file_scan_exclusions
- .sources()
- .chain(global_settings.private_files.sources());
- let path_style = project.path_style(cx);
- anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
- })??;
-
- let (results_tx, mut results_rx) = mpsc::unbounded();
-
- for query in queries {
- let exclude_matcher = exclude_matcher.clone();
- let results_tx = results_tx.clone();
- let project = project.clone();
- cx.spawn(async move |cx| {
- run_query(
- query,
- results_tx.clone(),
- path_style,
- exclude_matcher,
- &project,
- cx,
- )
- .await
- .log_err();
- })
- .detach()
- }
- drop(results_tx);
-
- #[cfg(feature = "eval-support")]
- let cache = cache.clone();
- cx.background_spawn(async move {
- let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
- let mut snapshots = HashMap::default();
-
- let mut total_bytes = 0;
- 'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
- snapshots.insert(buffer.entity_id(), snapshot);
- let existing = results.entry(buffer).or_default();
- existing.reserve(excerpts.len());
-
- for (range, size) in excerpts {
- // Blunt trimming of the results until we have a proper algorithmic filtering step
- if (total_bytes + size) > MAX_RESULTS_LEN {
- log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
- break 'outer;
- }
- total_bytes += size;
- existing.push(range);
- }
- }
-
- #[cfg(feature = "eval-support")]
- if let Some((cache, queries, key)) = cache {
- let cached_results: CachedSearchResults = results
- .iter()
- .filter_map(|(buffer, ranges)| {
- let snapshot = snapshots.get(&buffer.entity_id())?;
- let path = snapshot.file().map(|f| f.path());
- let mut ranges = ranges
- .iter()
- .map(|range| range.to_offset(&snapshot))
- .collect::<Vec<_>>();
- ranges.sort_unstable_by_key(|range| (range.start, range.end));
-
- Some((path?.as_std_path().to_path_buf(), ranges))
- })
- .collect();
- cache.write(
- key,
- &queries,
- &serde_json::to_string_pretty(&cached_results)?,
- );
- }
-
- for (buffer, ranges) in results.iter_mut() {
- if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
- merge_anchor_ranges(ranges, snapshot);
- }
- }
-
- Ok(results)
- })
- .await
-}
-
-pub(crate) fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
- ranges.sort_unstable_by(|a, b| {
- a.start
- .cmp(&b.start, snapshot)
- .then(b.end.cmp(&a.end, snapshot))
- });
-
- let mut index = 1;
- while index < ranges.len() {
- if ranges[index - 1]
- .end
- .cmp(&ranges[index].start, snapshot)
- .is_ge()
- {
- let removed = ranges.remove(index);
- if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() {
- ranges[index - 1].end = removed.end;
- }
- } else {
- index += 1;
- }
- }
-}
-
-const MAX_EXCERPT_LEN: usize = 768;
-const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
-
-struct SearchJob {
- buffer: Entity<Buffer>,
- snapshot: BufferSnapshot,
- ranges: Vec<Range<usize>>,
- query_ix: usize,
- jobs_tx: channel::Sender<SearchJob>,
-}
-
-async fn run_query(
- input_query: SearchToolQuery,
- results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
- path_style: PathStyle,
- exclude_matcher: PathMatcher,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
-) -> Result<()> {
- let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
-
- let make_search = |regex: &str| -> Result<SearchQuery> {
- SearchQuery::regex(
- regex,
- false,
- true,
- false,
- true,
- include_matcher.clone(),
- exclude_matcher.clone(),
- true,
- None,
- )
- };
-
- if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
- let outer_syntax_query = make_search(outer_syntax_regex)?;
- let nested_syntax_queries = input_query
- .syntax_node
- .into_iter()
- .skip(1)
- .map(|query| make_search(&query))
- .collect::<Result<Vec<_>>>()?;
- let content_query = input_query
- .content
- .map(|regex| make_search(®ex))
- .transpose()?;
-
- let (jobs_tx, jobs_rx) = channel::unbounded();
-
- let outer_search_results_rx =
- project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
-
- let outer_search_task = cx.spawn(async move |cx| {
- futures::pin_mut!(outer_search_results_rx);
- while let Some(SearchResult::Buffer { buffer, ranges }) =
- outer_search_results_rx.next().await
- {
- buffer
- .read_with(cx, |buffer, _| buffer.parsing_idle())?
- .await;
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
- let expanded_ranges: Vec<_> = ranges
- .into_iter()
- .filter_map(|range| expand_to_parent_range(&range, &snapshot))
- .collect();
- jobs_tx
- .send(SearchJob {
- buffer,
- snapshot,
- ranges: expanded_ranges,
- query_ix: 0,
- jobs_tx: jobs_tx.clone(),
- })
- .await?;
- }
- anyhow::Ok(())
- });
-
- let n_workers = cx.background_executor().num_cpus();
- let search_job_task = cx.background_executor().scoped(|scope| {
- for _ in 0..n_workers {
- scope.spawn(async {
- while let Ok(job) = jobs_rx.recv().await {
- process_nested_search_job(
- &results_tx,
- &nested_syntax_queries,
- &content_query,
- job,
- )
- .await;
- }
- });
- }
- });
-
- search_job_task.await;
- outer_search_task.await?;
- } else if let Some(content_regex) = &input_query.content {
- let search_query = make_search(&content_regex)?;
-
- let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
- futures::pin_mut!(results_rx);
-
- while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-
- let ranges = ranges
- .into_iter()
- .map(|range| {
- let range = range.to_offset(&snapshot);
- let range = expand_to_entire_lines(range, &snapshot);
- let size = range.len();
- let range =
- snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
- (range, size)
- })
- .collect();
-
- let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
-
- if let Err(err) = send_result
- && !err.is_disconnected()
- {
- log::error!("{err}");
- }
- }
- } else {
- log::warn!("Context gathering model produced a glob-only search");
- }
-
- anyhow::Ok(())
-}
-
-async fn process_nested_search_job(
- results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
- queries: &Vec<SearchQuery>,
- content_query: &Option<SearchQuery>,
- job: SearchJob,
-) {
- if let Some(search_query) = queries.get(job.query_ix) {
- let mut subranges = Vec::new();
- for range in job.ranges {
- let start = range.start;
- let search_results = search_query.search(&job.snapshot, Some(range)).await;
- for subrange in search_results {
- let subrange = start + subrange.start..start + subrange.end;
- subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
- }
- }
- job.jobs_tx
- .send(SearchJob {
- buffer: job.buffer,
- snapshot: job.snapshot,
- ranges: subranges,
- query_ix: job.query_ix + 1,
- jobs_tx: job.jobs_tx.clone(),
- })
- .await
- .ok();
- } else {
- let ranges = if let Some(content_query) = content_query {
- let mut subranges = Vec::new();
- for range in job.ranges {
- let start = range.start;
- let search_results = content_query.search(&job.snapshot, Some(range)).await;
- for subrange in search_results {
- let subrange = start + subrange.start..start + subrange.end;
- subranges.push(subrange);
- }
- }
- subranges
- } else {
- job.ranges
- };
-
- let matches = ranges
- .into_iter()
- .map(|range| {
- let snapshot = &job.snapshot;
- let range = expand_to_entire_lines(range, snapshot);
- let size = range.len();
- let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
- (range, size)
- })
- .collect();
-
- let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
-
- if let Err(err) = send_result
- && !err.is_disconnected()
- {
- log::error!("{err}");
- }
- }
-}
-
-fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
- let mut point_range = range.to_point(snapshot);
- point_range.start.column = 0;
- if point_range.end.column > 0 {
- point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
- }
- point_range.to_offset(snapshot)
-}
-
-fn expand_to_parent_range<T: ToPoint + ToOffset>(
- range: &Range<T>,
- snapshot: &BufferSnapshot,
-) -> Option<Range<usize>> {
- let mut line_range = range.to_point(&snapshot);
- line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
- line_range.end.column = snapshot.line_len(line_range.end.row);
- // TODO skip result if matched line isn't the first node line?
-
- let node = snapshot.syntax_ancestor(line_range)?;
- Some(node.byte_range())
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::assemble_excerpts::assemble_excerpts;
- use cloud_zeta2_prompt::write_codeblock;
- use edit_prediction_context::Line;
- use gpui::TestAppContext;
- use indoc::indoc;
- use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
- use pretty_assertions::assert_eq;
- use project::FakeFs;
- use serde_json::json;
- use settings::SettingsStore;
- use std::path::Path;
- use util::path;
-
- #[gpui::test]
- async fn test_retrieval(cx: &mut TestAppContext) {
- init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/root"),
- json!({
- "user.rs": indoc!{"
- pub struct Organization {
- owner: Arc<User>,
- }
-
- pub struct User {
- first_name: String,
- last_name: String,
- }
-
- impl Organization {
- pub fn owner(&self) -> Arc<User> {
- self.owner.clone()
- }
- }
-
- impl User {
- pub fn new(first_name: String, last_name: String) -> Self {
- Self {
- first_name,
- last_name
- }
- }
-
- pub fn first_name(&self) -> String {
- self.first_name.clone()
- }
-
- pub fn last_name(&self) -> String {
- self.last_name.clone()
- }
- }
- "},
- "main.rs": indoc!{r#"
- fn main() {
- let user = User::new(FIRST_NAME.clone(), "doe".into());
- println!("user {:?}", user);
- }
- "#},
- }),
- )
- .await;
-
- let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
- project.update(cx, |project, _cx| {
- project.languages().add(rust_lang().into())
- });
-
- assert_results(
- &project,
- SearchToolQuery {
- glob: "user.rs".into(),
- syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
- content: None,
- },
- indoc! {r#"
- `````root/user.rs
- โฆ
- impl User {
- โฆ
- pub fn first_name(&self) -> String {
- self.first_name.clone()
- }
- โฆ
- `````
- "#},
- cx,
- )
- .await;
-
- assert_results(
- &project,
- SearchToolQuery {
- glob: "user.rs".into(),
- syntax_node: vec!["impl\\s+User".into()],
- content: Some("\\.clone".into()),
- },
- indoc! {r#"
- `````root/user.rs
- โฆ
- impl User {
- โฆ
- pub fn first_name(&self) -> String {
- self.first_name.clone()
- โฆ
- pub fn last_name(&self) -> String {
- self.last_name.clone()
- โฆ
- `````
- "#},
- cx,
- )
- .await;
-
- assert_results(
- &project,
- SearchToolQuery {
- glob: "*.rs".into(),
- syntax_node: vec![],
- content: Some("\\.clone".into()),
- },
- indoc! {r#"
- `````root/main.rs
- fn main() {
- let user = User::new(FIRST_NAME.clone(), "doe".into());
- โฆ
- `````
-
- `````root/user.rs
- โฆ
- impl Organization {
- pub fn owner(&self) -> Arc<User> {
- self.owner.clone()
- โฆ
- impl User {
- โฆ
- pub fn first_name(&self) -> String {
- self.first_name.clone()
- โฆ
- pub fn last_name(&self) -> String {
- self.last_name.clone()
- โฆ
- `````
- "#},
- cx,
- )
- .await;
- }
-
- async fn assert_results(
- project: &Entity<Project>,
- query: SearchToolQuery,
- expected_output: &str,
- cx: &mut TestAppContext,
- ) {
- let results = run_retrieval_searches(
- vec![query],
- project.clone(),
- #[cfg(feature = "eval-support")]
- None,
- &mut cx.to_async(),
- )
- .await
- .unwrap();
-
- let mut results = results.into_iter().collect::<Vec<_>>();
- results.sort_by_key(|results| {
- results
- .0
- .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
- });
-
- let mut output = String::new();
- for (buffer, ranges) in results {
- buffer.read_with(cx, |buffer, cx| {
- let excerpts = ranges.into_iter().map(|range| {
- let point_range = range.to_point(buffer);
- if point_range.end.column > 0 {
- Line(point_range.start.row)..Line(point_range.end.row + 1)
- } else {
- Line(point_range.start.row)..Line(point_range.end.row)
- }
- });
-
- write_codeblock(
- &buffer.file().unwrap().full_path(cx),
- assemble_excerpts(&buffer.snapshot(), excerpts).iter(),
- &[],
- Line(buffer.max_point().row),
- false,
- &mut output,
- );
- });
- }
- output.pop();
-
- assert_eq!(output, expected_output);
- }
-
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
- .unwrap()
- }
-
- fn init_test(cx: &mut TestAppContext) {
- cx.update(move |cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- zlog::init_test();
- });
- }
-}
@@ -1,4062 +0,0 @@
-use anyhow::{Context as _, Result, anyhow, bail};
-use arrayvec::ArrayVec;
-use client::{Client, EditPredictionUsage, UserStore};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature};
-use cloud_llm_client::{
- AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
- EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
- MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
- ZED_VERSION_HEADER_NAME,
-};
-use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
-use cloud_zeta2_prompt::{CURSOR_MARKER, DEFAULT_MAX_PROMPT_BYTES};
-use collections::{HashMap, HashSet};
-use command_palette_hooks::CommandPaletteFilter;
-use db::kvp::{Dismissable, KEY_VALUE_STORE};
-use edit_prediction_context::{
- DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
- EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
- SyntaxIndex, SyntaxIndexState,
-};
-use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag};
-use futures::channel::mpsc::UnboundedReceiver;
-use futures::channel::{mpsc, oneshot};
-use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased};
-use gpui::BackgroundExecutor;
-use gpui::{
- App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
- http_client::{self, AsyncBody, Method},
- prelude::*,
-};
-use language::{
- Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint,
-};
-use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, RefreshLlmTokenListener};
-use open_ai::FunctionDefinition;
-use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
-use release_channel::AppVersion;
-use semver::Version;
-use serde::de::DeserializeOwned;
-use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file};
-use std::any::{Any as _, TypeId};
-use std::collections::{VecDeque, hash_map};
-use telemetry_events::EditPredictionRating;
-use workspace::Workspace;
-
-use std::ops::Range;
-use std::path::Path;
-use std::rc::Rc;
-use std::str::FromStr as _;
-use std::sync::{Arc, LazyLock};
-use std::time::{Duration, Instant};
-use std::{env, mem};
-use thiserror::Error;
-use util::rel_path::RelPathBuf;
-use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt};
-use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
-
-pub mod assemble_excerpts;
-mod license_detection;
-mod onboarding_modal;
-mod prediction;
-mod provider;
-mod rate_prediction_modal;
-pub mod retrieval_search;
-pub mod sweep_ai;
-pub mod udiff;
-mod xml_edits;
-pub mod zeta1;
-
-#[cfg(test)]
-mod zeta_tests;
-
-use crate::assemble_excerpts::assemble_excerpts;
-use crate::license_detection::LicenseDetectionWatcher;
-use crate::onboarding_modal::ZedPredictModal;
-pub use crate::prediction::EditPrediction;
-pub use crate::prediction::EditPredictionId;
-pub use crate::prediction::EditPredictionInputs;
-use crate::prediction::EditPredictionResult;
-use crate::rate_prediction_modal::{
- NextEdit, PreviousEdit, RatePredictionsModal, ThumbsDownActivePrediction,
- ThumbsUpActivePrediction,
-};
-pub use crate::sweep_ai::SweepAi;
-use crate::zeta1::request_prediction_with_zeta1;
-pub use provider::ZetaEditPredictionProvider;
-
-actions!(
- edit_prediction,
- [
- /// Resets the edit prediction onboarding state.
- ResetOnboarding,
- /// Opens the rate completions modal.
- RateCompletions,
- /// Clears the edit prediction history.
- ClearHistory,
- ]
-);
-
-/// Maximum number of events to track.
-const EVENT_COUNT_MAX: usize = 6;
-const CHANGE_GROUPING_LINE_SPAN: u32 = 8;
-const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
-const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
-
-pub struct SweepFeatureFlag;
-
-impl FeatureFlag for SweepFeatureFlag {
- const NAME: &str = "sweep-ai";
-}
-pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
- max_bytes: 512,
- min_bytes: 128,
- target_before_cursor_over_total_bytes: 0.5,
-};
-
-pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
- ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);
-
-pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
- excerpt: DEFAULT_EXCERPT_OPTIONS,
-};
-
-pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
- EditPredictionContextOptions {
- use_imports: true,
- max_retrieved_declarations: 0,
- excerpt: DEFAULT_EXCERPT_OPTIONS,
- score: EditPredictionScoreOptions {
- omit_excerpt_overlaps: true,
- },
- };
-
-pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
- context: DEFAULT_CONTEXT_OPTIONS,
- max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
- max_diagnostic_bytes: 2048,
- prompt_format: PromptFormat::DEFAULT,
- file_indexing_parallelism: 1,
- buffer_change_grouping_interval: Duration::from_secs(1),
-};
-
-static USE_OLLAMA: LazyLock<bool> =
- LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty()));
-static CONTEXT_RETRIEVAL_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
- env::var("ZED_ZETA2_CONTEXT_MODEL").unwrap_or(if *USE_OLLAMA {
- "qwen3-coder:30b".to_string()
- } else {
- "yqvev8r3".to_string()
- })
-});
-static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
- match env::var("ZED_ZETA2_MODEL").as_deref() {
- Ok("zeta2-exp") => "4w5n28vw", // Fine-tuned model @ Baseten
- Ok(model) => model,
- Err(_) if *USE_OLLAMA => "qwen3-coder:30b",
- Err(_) => "yqvev8r3", // Vanilla qwen3-coder @ Baseten
- }
- .to_string()
-});
-static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
- env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
- if *USE_OLLAMA {
- Some("http://localhost:11434/v1/chat/completions".into())
- } else {
- None
- }
- })
-});
-
-pub struct Zeta2FeatureFlag;
-
-impl FeatureFlag for Zeta2FeatureFlag {
- const NAME: &'static str = "zeta2";
-
- fn enabled_for_staff() -> bool {
- true
- }
-}
-
-#[derive(Clone)]
-struct ZetaGlobal(Entity<Zeta>);
-
-impl Global for ZetaGlobal {}
-
-pub struct Zeta {
- client: Arc<Client>,
- user_store: Entity<UserStore>,
- llm_token: LlmApiToken,
- _llm_token_subscription: Subscription,
- projects: HashMap<EntityId, ZetaProject>,
- options: ZetaOptions,
- update_required: bool,
- debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
- #[cfg(feature = "eval-support")]
- eval_cache: Option<Arc<dyn EvalCache>>,
- edit_prediction_model: ZetaEditPredictionModel,
- pub sweep_ai: SweepAi,
- data_collection_choice: DataCollectionChoice,
- reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
- shown_predictions: VecDeque<EditPrediction>,
- rated_predictions: HashSet<EditPredictionId>,
-}
-
-#[derive(Copy, Clone, Default, PartialEq, Eq)]
-pub enum ZetaEditPredictionModel {
- #[default]
- Zeta1,
- Zeta2,
- Sweep,
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct ZetaOptions {
- pub context: ContextMode,
- pub max_prompt_bytes: usize,
- pub max_diagnostic_bytes: usize,
- pub prompt_format: predict_edits_v3::PromptFormat,
- pub file_indexing_parallelism: usize,
- pub buffer_change_grouping_interval: Duration,
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub enum ContextMode {
- Agentic(AgenticContextOptions),
- Syntax(EditPredictionContextOptions),
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct AgenticContextOptions {
- pub excerpt: EditPredictionExcerptOptions,
-}
-
-impl ContextMode {
- pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
- match self {
- ContextMode::Agentic(options) => &options.excerpt,
- ContextMode::Syntax(options) => &options.excerpt,
- }
- }
-}
-
-#[derive(Debug)]
-pub enum ZetaDebugInfo {
- ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
- SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
- SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
- ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
- EditPredictionRequested(ZetaEditPredictionDebugInfo),
-}
-
-#[derive(Debug)]
-pub struct ZetaContextRetrievalStartedDebugInfo {
- pub project: Entity<Project>,
- pub timestamp: Instant,
- pub search_prompt: String,
-}
-
-#[derive(Debug)]
-pub struct ZetaContextRetrievalDebugInfo {
- pub project: Entity<Project>,
- pub timestamp: Instant,
-}
-
-#[derive(Debug)]
-pub struct ZetaEditPredictionDebugInfo {
- pub inputs: EditPredictionInputs,
- pub retrieval_time: Duration,
- pub buffer: WeakEntity<Buffer>,
- pub position: language::Anchor,
- pub local_prompt: Result<String, String>,
- pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
-}
-
-#[derive(Debug)]
-pub struct ZetaSearchQueryDebugInfo {
- pub project: Entity<Project>,
- pub timestamp: Instant,
- pub search_queries: Vec<SearchToolQuery>,
-}
-
-pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
-
-struct ZetaProject {
- syntax_index: Option<Entity<SyntaxIndex>>,
- events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
- last_event: Option<LastEvent>,
- recent_paths: VecDeque<ProjectPath>,
- registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
- current_prediction: Option<CurrentEditPrediction>,
- next_pending_prediction_id: usize,
- pending_predictions: ArrayVec<PendingPrediction, 2>,
- last_prediction_refresh: Option<(EntityId, Instant)>,
- cancelled_predictions: HashSet<usize>,
- context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
- refresh_context_task: Option<LogErrorFuture<Task<Result<()>>>>,
- refresh_context_debounce_task: Option<Task<Option<()>>>,
- refresh_context_timestamp: Option<Instant>,
- license_detection_watchers: HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
- _subscription: gpui::Subscription,
-}
-
-impl ZetaProject {
- pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
- self.events
- .iter()
- .cloned()
- .chain(
- self.last_event
- .as_ref()
- .and_then(|event| event.finalize(&self.license_detection_watchers, cx)),
- )
- .collect()
- }
-
- fn cancel_pending_prediction(
- &mut self,
- pending_prediction: PendingPrediction,
- cx: &mut Context<Zeta>,
- ) {
- self.cancelled_predictions.insert(pending_prediction.id);
-
- cx.spawn(async move |this, cx| {
- let Some(prediction_id) = pending_prediction.task.await else {
- return;
- };
-
- this.update(cx, |this, _cx| {
- this.reject_prediction(prediction_id, EditPredictionRejectReason::Canceled, false);
- })
- .ok();
- })
- .detach()
- }
-}
-
-#[derive(Debug, Clone)]
-struct CurrentEditPrediction {
- pub requested_by: PredictionRequestedBy,
- pub prediction: EditPrediction,
- pub was_shown: bool,
-}
-
-impl CurrentEditPrediction {
- fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
- let Some(new_edits) = self
- .prediction
- .interpolate(&self.prediction.buffer.read(cx))
- else {
- return false;
- };
-
- if self.prediction.buffer != old_prediction.prediction.buffer {
- return true;
- }
-
- let Some(old_edits) = old_prediction
- .prediction
- .interpolate(&old_prediction.prediction.buffer.read(cx))
- else {
- return true;
- };
-
- let requested_by_buffer_id = self.requested_by.buffer_id();
-
- // This reduces the occurrence of UI thrash from replacing edits
- //
- // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
- if requested_by_buffer_id == Some(self.prediction.buffer.entity_id())
- && requested_by_buffer_id == Some(old_prediction.prediction.buffer.entity_id())
- && old_edits.len() == 1
- && new_edits.len() == 1
- {
- let (old_range, old_text) = &old_edits[0];
- let (new_range, new_text) = &new_edits[0];
- new_range == old_range && new_text.starts_with(old_text.as_ref())
- } else {
- true
- }
- }
-}
-
-#[derive(Debug, Clone)]
-enum PredictionRequestedBy {
- DiagnosticsUpdate,
- Buffer(EntityId),
-}
-
-impl PredictionRequestedBy {
- pub fn buffer_id(&self) -> Option<EntityId> {
- match self {
- PredictionRequestedBy::DiagnosticsUpdate => None,
- PredictionRequestedBy::Buffer(buffer_id) => Some(*buffer_id),
- }
- }
-}
-
-#[derive(Debug)]
-struct PendingPrediction {
- id: usize,
- task: Task<Option<EditPredictionId>>,
-}
-
-/// A prediction from the perspective of a buffer.
-#[derive(Debug)]
-enum BufferEditPrediction<'a> {
- Local { prediction: &'a EditPrediction },
- Jump { prediction: &'a EditPrediction },
-}
-
-#[cfg(test)]
-impl std::ops::Deref for BufferEditPrediction<'_> {
- type Target = EditPrediction;
-
- fn deref(&self) -> &Self::Target {
- match self {
- BufferEditPrediction::Local { prediction } => prediction,
- BufferEditPrediction::Jump { prediction } => prediction,
- }
- }
-}
-
-struct RegisteredBuffer {
- snapshot: BufferSnapshot,
- _subscriptions: [gpui::Subscription; 2],
-}
-
-struct LastEvent {
- old_snapshot: BufferSnapshot,
- new_snapshot: BufferSnapshot,
- end_edit_anchor: Option<Anchor>,
-}
-
-impl LastEvent {
- pub fn finalize(
- &self,
- license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
- cx: &App,
- ) -> Option<Arc<predict_edits_v3::Event>> {
- let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
- let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
-
- let file = self.new_snapshot.file();
- let old_file = self.old_snapshot.file();
-
- let in_open_source_repo = [file, old_file].iter().all(|file| {
- file.is_some_and(|file| {
- license_detection_watchers
- .get(&file.worktree_id(cx))
- .is_some_and(|watcher| watcher.is_project_open_source())
- })
- });
-
- let diff = language::unified_diff(&self.old_snapshot.text(), &self.new_snapshot.text());
-
- if path == old_path && diff.is_empty() {
- None
- } else {
- Some(Arc::new(predict_edits_v3::Event::BufferChange {
- old_path,
- path,
- diff,
- in_open_source_repo,
- // TODO: Actually detect if this edit was predicted or not
- predicted: false,
- }))
- }
- }
-}
-
-fn buffer_path_with_id_fallback(snapshot: &BufferSnapshot, cx: &App) -> Arc<Path> {
- if let Some(file) = snapshot.file() {
- file.full_path(cx).into()
- } else {
- Path::new(&format!("untitled-{}", snapshot.remote_id())).into()
- }
-}
-
-impl Zeta {
- pub fn try_global(cx: &App) -> Option<Entity<Self>> {
- cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
- }
-
- pub fn global(
- client: &Arc<Client>,
- user_store: &Entity<UserStore>,
- cx: &mut App,
- ) -> Entity<Self> {
- cx.try_global::<ZetaGlobal>()
- .map(|global| global.0.clone())
- .unwrap_or_else(|| {
- let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
- cx.set_global(ZetaGlobal(zeta.clone()));
- zeta
- })
- }
-
- pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
- let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
- let data_collection_choice = Self::load_data_collection_choice();
-
- let llm_token = LlmApiToken::default();
-
- let (reject_tx, reject_rx) = mpsc::unbounded();
- cx.background_spawn({
- let client = client.clone();
- let llm_token = llm_token.clone();
- let app_version = AppVersion::global(cx);
- let background_executor = cx.background_executor().clone();
- async move {
- Self::handle_rejected_predictions(
- reject_rx,
- client,
- llm_token,
- app_version,
- background_executor,
- )
- .await
- }
- })
- .detach();
-
- Self {
- projects: HashMap::default(),
- client,
- user_store,
- options: DEFAULT_OPTIONS,
- llm_token,
- _llm_token_subscription: cx.subscribe(
- &refresh_llm_token_listener,
- |this, _listener, _event, cx| {
- let client = this.client.clone();
- let llm_token = this.llm_token.clone();
- cx.spawn(async move |_this, _cx| {
- llm_token.refresh(&client).await?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- },
- ),
- update_required: false,
- debug_tx: None,
- #[cfg(feature = "eval-support")]
- eval_cache: None,
- edit_prediction_model: ZetaEditPredictionModel::Zeta2,
- sweep_ai: SweepAi::new(cx),
- data_collection_choice,
- reject_predictions_tx: reject_tx,
- rated_predictions: Default::default(),
- shown_predictions: Default::default(),
- }
- }
-
- pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) {
- self.edit_prediction_model = model;
- }
-
- pub fn has_sweep_api_token(&self) -> bool {
- self.sweep_ai
- .api_token
- .clone()
- .now_or_never()
- .flatten()
- .is_some()
- }
-
- #[cfg(feature = "eval-support")]
- pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
- self.eval_cache = Some(cache);
- }
-
- pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<ZetaDebugInfo> {
- let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
- self.debug_tx = Some(debug_watch_tx);
- debug_watch_rx
- }
-
- pub fn options(&self) -> &ZetaOptions {
- &self.options
- }
-
- pub fn set_options(&mut self, options: ZetaOptions) {
- self.options = options;
- }
-
- pub fn clear_history(&mut self) {
- for zeta_project in self.projects.values_mut() {
- zeta_project.events.clear();
- }
- }
-
- pub fn context_for_project(
- &self,
- project: &Entity<Project>,
- ) -> impl Iterator<Item = (Entity<Buffer>, &[Range<Anchor>])> {
- self.projects
- .get(&project.entity_id())
- .and_then(|project| {
- Some(
- project
- .context
- .as_ref()?
- .iter()
- .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())),
- )
- })
- .into_iter()
- .flatten()
- }
-
- pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- if self.edit_prediction_model == ZetaEditPredictionModel::Zeta2 {
- self.user_store.read(cx).edit_prediction_usage()
- } else {
- None
- }
- }
-
- pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
- self.get_or_init_zeta_project(project, cx);
- }
-
- pub fn register_buffer(
- &mut self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let zeta_project = self.get_or_init_zeta_project(project, cx);
- Self::register_buffer_impl(zeta_project, buffer, project, cx);
- }
-
- fn get_or_init_zeta_project(
- &mut self,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> &mut ZetaProject {
- self.projects
- .entry(project.entity_id())
- .or_insert_with(|| ZetaProject {
- syntax_index: if let ContextMode::Syntax(_) = &self.options.context {
- Some(cx.new(|cx| {
- SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
- }))
- } else {
- None
- },
- events: VecDeque::new(),
- last_event: None,
- recent_paths: VecDeque::new(),
- registered_buffers: HashMap::default(),
- current_prediction: None,
- cancelled_predictions: HashSet::default(),
- pending_predictions: ArrayVec::new(),
- next_pending_prediction_id: 0,
- last_prediction_refresh: None,
- context: None,
- refresh_context_task: None,
- refresh_context_debounce_task: None,
- refresh_context_timestamp: None,
- license_detection_watchers: HashMap::default(),
- _subscription: cx.subscribe(&project, Self::handle_project_event),
- })
- }
-
- fn handle_project_event(
- &mut self,
- project: Entity<Project>,
- event: &project::Event,
- cx: &mut Context<Self>,
- ) {
- // TODO [zeta2] init with recent paths
- match event {
- project::Event::ActiveEntryChanged(Some(active_entry_id)) => {
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
- let path = project.read(cx).path_for_entry(*active_entry_id, cx);
- if let Some(path) = path {
- if let Some(ix) = zeta_project
- .recent_paths
- .iter()
- .position(|probe| probe == &path)
- {
- zeta_project.recent_paths.remove(ix);
- }
- zeta_project.recent_paths.push_front(path);
- }
- }
- project::Event::DiagnosticsUpdated { .. } => {
- if cx.has_flag::<Zeta2FeatureFlag>() {
- self.refresh_prediction_from_diagnostics(project, cx);
- }
- }
- _ => (),
- }
- }
-
- fn register_buffer_impl<'a>(
- zeta_project: &'a mut ZetaProject,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> &'a mut RegisteredBuffer {
- let buffer_id = buffer.entity_id();
-
- if let Some(file) = buffer.read(cx).file() {
- let worktree_id = file.worktree_id(cx);
- if let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) {
- zeta_project
- .license_detection_watchers
- .entry(worktree_id)
- .or_insert_with(|| {
- let project_entity_id = project.entity_id();
- cx.observe_release(&worktree, move |this, _worktree, _cx| {
- let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
- else {
- return;
- };
- zeta_project.license_detection_watchers.remove(&worktree_id);
- })
- .detach();
- Rc::new(LicenseDetectionWatcher::new(&worktree, cx))
- });
- }
- }
-
- match zeta_project.registered_buffers.entry(buffer_id) {
- hash_map::Entry::Occupied(entry) => entry.into_mut(),
- hash_map::Entry::Vacant(entry) => {
- let snapshot = buffer.read(cx).snapshot();
- let project_entity_id = project.entity_id();
- entry.insert(RegisteredBuffer {
- snapshot,
- _subscriptions: [
- cx.subscribe(buffer, {
- let project = project.downgrade();
- move |this, buffer, event, cx| {
- if let language::BufferEvent::Edited = event
- && let Some(project) = project.upgrade()
- {
- this.report_changes_for_buffer(&buffer, &project, cx);
- }
- }
- }),
- cx.observe_release(buffer, move |this, _buffer, _cx| {
- let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
- else {
- return;
- };
- zeta_project.registered_buffers.remove(&buffer_id);
- }),
- ],
- })
- }
- }
- }
-
- fn report_changes_for_buffer(
- &mut self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let project_state = self.get_or_init_zeta_project(project, cx);
- let registered_buffer = Self::register_buffer_impl(project_state, buffer, project, cx);
-
- let new_snapshot = buffer.read(cx).snapshot();
- if new_snapshot.version == registered_buffer.snapshot.version {
- return;
- }
-
- let old_snapshot = mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
- let end_edit_anchor = new_snapshot
- .anchored_edits_since::<Point>(&old_snapshot.version)
- .last()
- .map(|(_, range)| range.end);
- let events = &mut project_state.events;
-
- if let Some(LastEvent {
- new_snapshot: last_new_snapshot,
- end_edit_anchor: last_end_edit_anchor,
- ..
- }) = project_state.last_event.as_mut()
- {
- let is_next_snapshot_of_same_buffer = old_snapshot.remote_id()
- == last_new_snapshot.remote_id()
- && old_snapshot.version == last_new_snapshot.version;
-
- let should_coalesce = is_next_snapshot_of_same_buffer
- && end_edit_anchor
- .as_ref()
- .zip(last_end_edit_anchor.as_ref())
- .is_some_and(|(a, b)| {
- let a = a.to_point(&new_snapshot);
- let b = b.to_point(&new_snapshot);
- a.row.abs_diff(b.row) <= CHANGE_GROUPING_LINE_SPAN
- });
-
- if should_coalesce {
- *last_end_edit_anchor = end_edit_anchor;
- *last_new_snapshot = new_snapshot;
- return;
- }
- }
-
- if events.len() + 1 >= EVENT_COUNT_MAX {
- events.pop_front();
- }
-
- if let Some(event) = project_state.last_event.take() {
- events.extend(event.finalize(&project_state.license_detection_watchers, cx));
- }
-
- project_state.last_event = Some(LastEvent {
- old_snapshot,
- new_snapshot,
- end_edit_anchor,
- });
- }
-
- fn current_prediction_for_buffer(
- &self,
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- cx: &App,
- ) -> Option<BufferEditPrediction<'_>> {
- let project_state = self.projects.get(&project.entity_id())?;
-
- let CurrentEditPrediction {
- requested_by,
- prediction,
- ..
- } = project_state.current_prediction.as_ref()?;
-
- if prediction.targets_buffer(buffer.read(cx)) {
- Some(BufferEditPrediction::Local { prediction })
- } else {
- let show_jump = match requested_by {
- PredictionRequestedBy::Buffer(requested_by_buffer_id) => {
- requested_by_buffer_id == &buffer.entity_id()
- }
- PredictionRequestedBy::DiagnosticsUpdate => true,
- };
-
- if show_jump {
- Some(BufferEditPrediction::Jump { prediction })
- } else {
- None
- }
- }
- }
-
- fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
- match self.edit_prediction_model {
- ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
- ZetaEditPredictionModel::Sweep => return,
- }
-
- let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- let Some(prediction) = project_state.current_prediction.take() else {
- return;
- };
- let request_id = prediction.prediction.id.to_string();
- for pending_prediction in mem::take(&mut project_state.pending_predictions) {
- project_state.cancel_pending_prediction(pending_prediction, cx);
- }
-
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let app_version = AppVersion::global(cx);
- cx.spawn(async move |this, cx| {
- let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
- http_client::Url::parse(&predict_edits_url)?
- } else {
- client
- .http_client()
- .build_zed_llm_url("/predict_edits/accept", &[])?
- };
-
- let response = cx
- .background_spawn(Self::send_api_request::<()>(
- move |builder| {
- let req = builder.uri(url.as_ref()).body(
- serde_json::to_string(&AcceptEditPredictionBody {
- request_id: request_id.clone(),
- })?
- .into(),
- );
- Ok(req?)
- },
- client,
- llm_token,
- app_version,
- ))
- .await;
-
- Self::handle_api_response(&this, response, cx)?;
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- }
-
- async fn handle_rejected_predictions(
- rx: UnboundedReceiver<EditPredictionRejection>,
- client: Arc<Client>,
- llm_token: LlmApiToken,
- app_version: Version,
- background_executor: BackgroundExecutor,
- ) {
- let mut rx = std::pin::pin!(rx.peekable());
- let mut batched = Vec::new();
-
- while let Some(rejection) = rx.next().await {
- batched.push(rejection);
-
- if batched.len() < MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST / 2 {
- select_biased! {
- next = rx.as_mut().peek().fuse() => {
- if next.is_some() {
- continue;
- }
- }
- () = background_executor.timer(REJECT_REQUEST_DEBOUNCE).fuse() => {},
- }
- }
-
- let url = client
- .http_client()
- .build_zed_llm_url("/predict_edits/reject", &[])
- .unwrap();
-
- let flush_count = batched
- .len()
- // in case items have accumulated after failure
- .min(MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST);
- let start = batched.len() - flush_count;
-
- let body = RejectEditPredictionsBodyRef {
- rejections: &batched[start..],
- };
-
- let result = Self::send_api_request::<()>(
- |builder| {
- let req = builder
- .uri(url.as_ref())
- .body(serde_json::to_string(&body)?.into());
- anyhow::Ok(req?)
- },
- client.clone(),
- llm_token.clone(),
- app_version.clone(),
- )
- .await;
-
- if result.log_err().is_some() {
- batched.drain(start..);
- }
- }
- }
-
- fn reject_current_prediction(
- &mut self,
- reason: EditPredictionRejectReason,
- project: &Entity<Project>,
- ) {
- if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
- project_state.pending_predictions.clear();
- if let Some(prediction) = project_state.current_prediction.take() {
- self.reject_prediction(prediction.prediction.id, reason, prediction.was_shown);
- }
- };
- }
-
- fn did_show_current_prediction(&mut self, project: &Entity<Project>, _cx: &mut Context<Self>) {
- if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
- if let Some(current_prediction) = project_state.current_prediction.as_mut() {
- if !current_prediction.was_shown {
- current_prediction.was_shown = true;
- self.shown_predictions
- .push_front(current_prediction.prediction.clone());
- if self.shown_predictions.len() > 50 {
- let completion = self.shown_predictions.pop_back().unwrap();
- self.rated_predictions.remove(&completion.id);
- }
- }
- }
- }
- }
-
- fn reject_prediction(
- &mut self,
- prediction_id: EditPredictionId,
- reason: EditPredictionRejectReason,
- was_shown: bool,
- ) {
- match self.edit_prediction_model {
- ZetaEditPredictionModel::Zeta1 | ZetaEditPredictionModel::Zeta2 => {}
- ZetaEditPredictionModel::Sweep => return,
- }
-
- self.reject_predictions_tx
- .unbounded_send(EditPredictionRejection {
- request_id: prediction_id.to_string(),
- reason,
- was_shown,
- })
- .log_err();
- }
-
- fn is_refreshing(&self, project: &Entity<Project>) -> bool {
- self.projects
- .get(&project.entity_id())
- .is_some_and(|project_state| !project_state.pending_predictions.is_empty())
- }
-
- pub fn refresh_prediction_from_buffer(
- &mut self,
- project: Entity<Project>,
- buffer: Entity<Buffer>,
- position: language::Anchor,
- cx: &mut Context<Self>,
- ) {
- self.queue_prediction_refresh(project.clone(), buffer.entity_id(), cx, move |this, cx| {
- let Some(request_task) = this
- .update(cx, |this, cx| {
- this.request_prediction(
- &project,
- &buffer,
- position,
- PredictEditsRequestTrigger::Other,
- cx,
- )
- })
- .log_err()
- else {
- return Task::ready(anyhow::Ok(None));
- };
-
- cx.spawn(async move |_cx| {
- request_task.await.map(|prediction_result| {
- prediction_result.map(|prediction_result| {
- (
- prediction_result,
- PredictionRequestedBy::Buffer(buffer.entity_id()),
- )
- })
- })
- })
- })
- }
-
- pub fn refresh_prediction_from_diagnostics(
- &mut self,
- project: Entity<Project>,
- cx: &mut Context<Self>,
- ) {
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- // Prefer predictions from buffer
- if zeta_project.current_prediction.is_some() {
- return;
- };
-
- self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
- let Some(open_buffer_task) = project
- .update(cx, |project, cx| {
- project
- .active_entry()
- .and_then(|entry| project.path_for_entry(entry, cx))
- .map(|path| project.open_buffer(path, cx))
- })
- .log_err()
- .flatten()
- else {
- return Task::ready(anyhow::Ok(None));
- };
-
- cx.spawn(async move |cx| {
- let active_buffer = open_buffer_task.await?;
- let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-
- let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
- active_buffer,
- &snapshot,
- Default::default(),
- Default::default(),
- &project,
- cx,
- )
- .await?
- else {
- return anyhow::Ok(None);
- };
-
- let Some(prediction_result) = this
- .update(cx, |this, cx| {
- this.request_prediction(
- &project,
- &jump_buffer,
- jump_position,
- PredictEditsRequestTrigger::Diagnostics,
- cx,
- )
- })?
- .await?
- else {
- return anyhow::Ok(None);
- };
-
- this.update(cx, |this, cx| {
- Some((
- if this
- .get_or_init_zeta_project(&project, cx)
- .current_prediction
- .is_none()
- {
- prediction_result
- } else {
- EditPredictionResult {
- id: prediction_result.id,
- prediction: Err(EditPredictionRejectReason::CurrentPreferred),
- }
- },
- PredictionRequestedBy::DiagnosticsUpdate,
- ))
- })
- })
- });
- }
-
- #[cfg(not(test))]
- pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
- #[cfg(test)]
- pub const THROTTLE_TIMEOUT: Duration = Duration::ZERO;
-
- fn queue_prediction_refresh(
- &mut self,
- project: Entity<Project>,
- throttle_entity: EntityId,
- cx: &mut Context<Self>,
- do_refresh: impl FnOnce(
- WeakEntity<Self>,
- &mut AsyncApp,
- )
- -> Task<Result<Option<(EditPredictionResult, PredictionRequestedBy)>>>
- + 'static,
- ) {
- let zeta_project = self.get_or_init_zeta_project(&project, cx);
- let pending_prediction_id = zeta_project.next_pending_prediction_id;
- zeta_project.next_pending_prediction_id += 1;
- let last_request = zeta_project.last_prediction_refresh;
-
- let task = cx.spawn(async move |this, cx| {
- if let Some((last_entity, last_timestamp)) = last_request
- && throttle_entity == last_entity
- && let Some(timeout) =
- (last_timestamp + Self::THROTTLE_TIMEOUT).checked_duration_since(Instant::now())
- {
- cx.background_executor().timer(timeout).await;
- }
-
- // If this task was cancelled before the throttle timeout expired,
- // do not perform a request.
- let mut is_cancelled = true;
- this.update(cx, |this, cx| {
- let project_state = this.get_or_init_zeta_project(&project, cx);
- if !project_state
- .cancelled_predictions
- .remove(&pending_prediction_id)
- {
- project_state.last_prediction_refresh = Some((throttle_entity, Instant::now()));
- is_cancelled = false;
- }
- })
- .ok();
- if is_cancelled {
- return None;
- }
-
- let new_prediction_result = do_refresh(this.clone(), cx).await.log_err().flatten();
- let new_prediction_id = new_prediction_result
- .as_ref()
- .map(|(prediction, _)| prediction.id.clone());
-
- // When a prediction completes, remove it from the pending list, and cancel
- // any pending predictions that were enqueued before it.
- this.update(cx, |this, cx| {
- let zeta_project = this.get_or_init_zeta_project(&project, cx);
-
- let is_cancelled = zeta_project
- .cancelled_predictions
- .remove(&pending_prediction_id);
-
- let new_current_prediction = if !is_cancelled
- && let Some((prediction_result, requested_by)) = new_prediction_result
- {
- match prediction_result.prediction {
- Ok(prediction) => {
- let new_prediction = CurrentEditPrediction {
- requested_by,
- prediction,
- was_shown: false,
- };
-
- if let Some(current_prediction) =
- zeta_project.current_prediction.as_ref()
- {
- if new_prediction.should_replace_prediction(¤t_prediction, cx)
- {
- this.reject_current_prediction(
- EditPredictionRejectReason::Replaced,
- &project,
- );
-
- Some(new_prediction)
- } else {
- this.reject_prediction(
- new_prediction.prediction.id,
- EditPredictionRejectReason::CurrentPreferred,
- false,
- );
- None
- }
- } else {
- Some(new_prediction)
- }
- }
- Err(reject_reason) => {
- this.reject_prediction(prediction_result.id, reject_reason, false);
- None
- }
- }
- } else {
- None
- };
-
- let zeta_project = this.get_or_init_zeta_project(&project, cx);
-
- if let Some(new_prediction) = new_current_prediction {
- zeta_project.current_prediction = Some(new_prediction);
- }
-
- let mut pending_predictions = mem::take(&mut zeta_project.pending_predictions);
- for (ix, pending_prediction) in pending_predictions.iter().enumerate() {
- if pending_prediction.id == pending_prediction_id {
- pending_predictions.remove(ix);
- for pending_prediction in pending_predictions.drain(0..ix) {
- zeta_project.cancel_pending_prediction(pending_prediction, cx)
- }
- break;
- }
- }
- this.get_or_init_zeta_project(&project, cx)
- .pending_predictions = pending_predictions;
- cx.notify();
- })
- .ok();
-
- new_prediction_id
- });
-
- if zeta_project.pending_predictions.len() <= 1 {
- zeta_project.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- task,
- });
- } else if zeta_project.pending_predictions.len() == 2 {
- let pending_prediction = zeta_project.pending_predictions.pop().unwrap();
- zeta_project.pending_predictions.push(PendingPrediction {
- id: pending_prediction_id,
- task,
- });
- zeta_project.cancel_pending_prediction(pending_prediction, cx);
- }
- }
-
- pub fn request_prediction(
- &mut self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- position: language::Anchor,
- trigger: PredictEditsRequestTrigger,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPredictionResult>>> {
- self.request_prediction_internal(
- project.clone(),
- active_buffer.clone(),
- position,
- trigger,
- cx.has_flag::<Zeta2FeatureFlag>(),
- cx,
- )
- }
-
- fn request_prediction_internal(
- &mut self,
- project: Entity<Project>,
- active_buffer: Entity<Buffer>,
- position: language::Anchor,
- trigger: PredictEditsRequestTrigger,
- allow_jump: bool,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPredictionResult>>> {
- const DIAGNOSTIC_LINES_RANGE: u32 = 20;
-
- self.get_or_init_zeta_project(&project, cx);
- let zeta_project = self.projects.get(&project.entity_id()).unwrap();
- let events = zeta_project.events(cx);
- let has_events = !events.is_empty();
-
- let snapshot = active_buffer.read(cx).snapshot();
- let cursor_point = position.to_point(&snapshot);
- let diagnostic_search_start = cursor_point.row.saturating_sub(DIAGNOSTIC_LINES_RANGE);
- let diagnostic_search_end = cursor_point.row + DIAGNOSTIC_LINES_RANGE;
- let diagnostic_search_range =
- Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
-
- let task = match self.edit_prediction_model {
- ZetaEditPredictionModel::Zeta1 => request_prediction_with_zeta1(
- self,
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- trigger,
- cx,
- ),
- ZetaEditPredictionModel::Zeta2 => self.request_prediction_with_zeta2(
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- trigger,
- cx,
- ),
- ZetaEditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- &zeta_project.recent_paths,
- diagnostic_search_range.clone(),
- cx,
- ),
- };
-
- cx.spawn(async move |this, cx| {
- let prediction = task.await?;
-
- if prediction.is_none() && allow_jump {
- let cursor_point = position.to_point(&snapshot);
- if has_events
- && let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
- active_buffer.clone(),
- &snapshot,
- diagnostic_search_range,
- cursor_point,
- &project,
- cx,
- )
- .await?
- {
- return this
- .update(cx, |this, cx| {
- this.request_prediction_internal(
- project,
- jump_buffer,
- jump_position,
- trigger,
- false,
- cx,
- )
- })?
- .await;
- }
-
- return anyhow::Ok(None);
- }
-
- Ok(prediction)
- })
- }
-
- async fn next_diagnostic_location(
- active_buffer: Entity<Buffer>,
- active_buffer_snapshot: &BufferSnapshot,
- active_buffer_diagnostic_search_range: Range<Point>,
- active_buffer_cursor_point: Point,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
- ) -> Result<Option<(Entity<Buffer>, language::Anchor)>> {
- // find the closest diagnostic to the cursor that wasn't close enough to be included in the last request
- let mut jump_location = active_buffer_snapshot
- .diagnostic_groups(None)
- .into_iter()
- .filter_map(|(_, group)| {
- let range = &group.entries[group.primary_ix]
- .range
- .to_point(&active_buffer_snapshot);
- if range.overlaps(&active_buffer_diagnostic_search_range) {
- None
- } else {
- Some(range.start)
- }
- })
- .min_by_key(|probe| probe.row.abs_diff(active_buffer_cursor_point.row))
- .map(|position| {
- (
- active_buffer.clone(),
- active_buffer_snapshot.anchor_before(position),
- )
- });
-
- if jump_location.is_none() {
- let active_buffer_path = active_buffer.read_with(cx, |buffer, cx| {
- let file = buffer.file()?;
-
- Some(ProjectPath {
- worktree_id: file.worktree_id(cx),
- path: file.path().clone(),
- })
- })?;
-
- let buffer_task = project.update(cx, |project, cx| {
- let (path, _, _) = project
- .diagnostic_summaries(false, cx)
- .filter(|(path, _, _)| Some(path) != active_buffer_path.as_ref())
- .max_by_key(|(path, _, _)| {
- // find the buffer with errors that shares most parent directories
- path.path
- .components()
- .zip(
- active_buffer_path
- .as_ref()
- .map(|p| p.path.components())
- .unwrap_or_default(),
- )
- .take_while(|(a, b)| a == b)
- .count()
- })?;
-
- Some(project.open_buffer(path, cx))
- })?;
-
- if let Some(buffer_task) = buffer_task {
- let closest_buffer = buffer_task.await?;
-
- jump_location = closest_buffer
- .read_with(cx, |buffer, _cx| {
- buffer
- .buffer_diagnostics(None)
- .into_iter()
- .min_by_key(|entry| entry.diagnostic.severity)
- .map(|entry| entry.range.start)
- })?
- .map(|position| (closest_buffer, position));
- }
- }
-
- anyhow::Ok(jump_location)
- }
-
- fn request_prediction_with_zeta2(
- &mut self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- active_snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec<Arc<Event>>,
- trigger: PredictEditsRequestTrigger,
- cx: &mut Context<Self>,
- ) -> Task<Result<Option<EditPredictionResult>>> {
- let project_state = self.projects.get(&project.entity_id());
-
- let index_state = project_state.and_then(|state| {
- state
- .syntax_index
- .as_ref()
- .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone()))
- });
- let options = self.options.clone();
- let buffer_snapshotted_at = Instant::now();
- let Some(excerpt_path) = active_snapshot
- .file()
- .map(|path| -> Arc<Path> { path.full_path(cx).into() })
- else {
- return Task::ready(Err(anyhow!("No file path for excerpt")));
- };
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let app_version = AppVersion::global(cx);
- let worktree_snapshots = project
- .read(cx)
- .worktrees(cx)
- .map(|worktree| worktree.read(cx).snapshot())
- .collect::<Vec<_>>();
- let debug_tx = self.debug_tx.clone();
-
- let diagnostics = active_snapshot.diagnostic_sets().clone();
-
- let file = active_buffer.read(cx).file();
- let parent_abs_path = project::File::from_dyn(file).and_then(|f| {
- let mut path = f.worktree.read(cx).absolutize(&f.path);
- if path.pop() { Some(path) } else { None }
- });
-
- // TODO data collection
- let can_collect_data = file
- .as_ref()
- .map_or(false, |file| self.can_collect_file(project, file, cx));
-
- let empty_context_files = HashMap::default();
- let context_files = project_state
- .and_then(|project_state| project_state.context.as_ref())
- .unwrap_or(&empty_context_files);
-
- #[cfg(feature = "eval-support")]
- let parsed_fut = futures::future::join_all(
- context_files
- .keys()
- .map(|buffer| buffer.read(cx).parsing_idle()),
- );
-
- let mut included_files = context_files
- .iter()
- .filter_map(|(buffer_entity, ranges)| {
- let buffer = buffer_entity.read(cx);
- Some((
- buffer_entity.clone(),
- buffer.snapshot(),
- buffer.file()?.full_path(cx).into(),
- ranges.clone(),
- ))
- })
- .collect::<Vec<_>>();
-
- included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| {
- (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len()))
- });
-
- #[cfg(feature = "eval-support")]
- let eval_cache = self.eval_cache.clone();
-
- let request_task = cx.background_spawn({
- let active_buffer = active_buffer.clone();
- async move {
- #[cfg(feature = "eval-support")]
- parsed_fut.await;
-
- let index_state = if let Some(index_state) = index_state {
- Some(index_state.lock_owned().await)
- } else {
- None
- };
-
- let cursor_offset = position.to_offset(&active_snapshot);
- let cursor_point = cursor_offset.to_point(&active_snapshot);
-
- let before_retrieval = Instant::now();
-
- let (diagnostic_groups, diagnostic_groups_truncated) =
- Self::gather_nearby_diagnostics(
- cursor_offset,
- &diagnostics,
- &active_snapshot,
- options.max_diagnostic_bytes,
- );
-
- let cloud_request = match options.context {
- ContextMode::Agentic(context_options) => {
- let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- &active_snapshot,
- &context_options.excerpt,
- index_state.as_deref(),
- ) else {
- return Ok((None, None));
- };
-
- let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
- ..active_snapshot.anchor_before(excerpt.range.end);
-
- if let Some(buffer_ix) =
- included_files.iter().position(|(_, snapshot, _, _)| {
- snapshot.remote_id() == active_snapshot.remote_id()
- })
- {
- let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
- ranges.push(excerpt_anchor_range);
- retrieval_search::merge_anchor_ranges(ranges, buffer);
- let last_ix = included_files.len() - 1;
- included_files.swap(buffer_ix, last_ix);
- } else {
- included_files.push((
- active_buffer.clone(),
- active_snapshot.clone(),
- excerpt_path.clone(),
- vec![excerpt_anchor_range],
- ));
- }
-
- let included_files = included_files
- .iter()
- .map(|(_, snapshot, path, ranges)| {
- let ranges = ranges
- .iter()
- .map(|range| {
- let point_range = range.to_point(&snapshot);
- Line(point_range.start.row)..Line(point_range.end.row)
- })
- .collect::<Vec<_>>();
- let excerpts = assemble_excerpts(&snapshot, ranges);
- predict_edits_v3::IncludedFile {
- path: path.clone(),
- max_row: Line(snapshot.max_point().row),
- excerpts,
- }
- })
- .collect::<Vec<_>>();
-
- predict_edits_v3::PredictEditsRequest {
- excerpt_path,
- excerpt: String::new(),
- excerpt_line_range: Line(0)..Line(0),
- excerpt_range: 0..0,
- cursor_point: predict_edits_v3::Point {
- line: predict_edits_v3::Line(cursor_point.row),
- column: cursor_point.column,
- },
- included_files,
- referenced_declarations: vec![],
- events,
- can_collect_data,
- diagnostic_groups,
- diagnostic_groups_truncated,
- debug_info: debug_tx.is_some(),
- prompt_max_bytes: Some(options.max_prompt_bytes),
- prompt_format: options.prompt_format,
- // TODO [zeta2]
- signatures: vec![],
- excerpt_parent: None,
- git_info: None,
- trigger,
- }
- }
- ContextMode::Syntax(context_options) => {
- let Some(context) = EditPredictionContext::gather_context(
- cursor_point,
- &active_snapshot,
- parent_abs_path.as_deref(),
- &context_options,
- index_state.as_deref(),
- ) else {
- return Ok((None, None));
- };
-
- make_syntax_context_cloud_request(
- excerpt_path,
- context,
- events,
- can_collect_data,
- diagnostic_groups,
- diagnostic_groups_truncated,
- None,
- debug_tx.is_some(),
- &worktree_snapshots,
- index_state.as_deref(),
- Some(options.max_prompt_bytes),
- options.prompt_format,
- trigger,
- )
- }
- };
-
- let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
-
- let inputs = EditPredictionInputs {
- included_files: cloud_request.included_files,
- events: cloud_request.events,
- cursor_point: cloud_request.cursor_point,
- cursor_path: cloud_request.excerpt_path,
- };
-
- let retrieval_time = Instant::now() - before_retrieval;
-
- let debug_response_tx = if let Some(debug_tx) = &debug_tx {
- let (response_tx, response_rx) = oneshot::channel();
-
- debug_tx
- .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
- ZetaEditPredictionDebugInfo {
- inputs: inputs.clone(),
- retrieval_time,
- buffer: active_buffer.downgrade(),
- local_prompt: match prompt_result.as_ref() {
- Ok((prompt, _)) => Ok(prompt.clone()),
- Err(err) => Err(err.to_string()),
- },
- position,
- response_rx,
- },
- ))
- .ok();
- Some(response_tx)
- } else {
- None
- };
-
- if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((Err("Request skipped".to_string()), Duration::ZERO))
- .ok();
- }
- anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
- }
-
- let (prompt, _) = prompt_result?;
- let generation_params =
- cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
- let request = open_ai::Request {
- model: EDIT_PREDICTIONS_MODEL_ID.clone(),
- messages: vec![open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(prompt),
- }],
- stream: false,
- max_completion_tokens: None,
- stop: generation_params.stop.unwrap_or_default(),
- temperature: generation_params.temperature.unwrap_or(0.7),
- tool_choice: None,
- parallel_tool_calls: None,
- tools: vec![],
- prompt_cache_key: None,
- reasoning_effort: None,
- };
-
- log::trace!("Sending edit prediction request");
-
- let before_request = Instant::now();
- let response = Self::send_raw_llm_request(
- request,
- client,
- llm_token,
- app_version,
- #[cfg(feature = "eval-support")]
- eval_cache,
- #[cfg(feature = "eval-support")]
- EvalCacheEntryKind::Prediction,
- )
- .await;
- let received_response_at = Instant::now();
- let request_time = received_response_at - before_request;
-
- log::trace!("Got edit prediction response");
-
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((
- response
- .as_ref()
- .map_err(|err| err.to_string())
- .map(|response| response.0.clone()),
- request_time,
- ))
- .ok();
- }
-
- let (res, usage) = response?;
- let request_id = EditPredictionId(res.id.clone().into());
- let Some(mut output_text) = text_from_response(res) else {
- return Ok((Some((request_id, None)), usage));
- };
-
- if output_text.contains(CURSOR_MARKER) {
- log::trace!("Stripping out {CURSOR_MARKER} from response");
- output_text = output_text.replace(CURSOR_MARKER, "");
- }
-
- let get_buffer_from_context = |path: &Path| {
- included_files
- .iter()
- .find_map(|(_, buffer, probe_path, ranges)| {
- if probe_path.as_ref() == path {
- Some((buffer, ranges.as_slice()))
- } else {
- None
- }
- })
- };
-
- let (edited_buffer_snapshot, edits) = match options.prompt_format {
- PromptFormat::NumLinesUniDiff => {
- // TODO: Implement parsing of multi-file diffs
- crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
- }
- PromptFormat::Minimal
- | PromptFormat::MinimalQwen
- | PromptFormat::SeedCoder1120 => {
- if output_text.contains("--- a/\n+++ b/\nNo edits") {
- let edits = vec![];
- (&active_snapshot, edits)
- } else {
- crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
- }
- }
- PromptFormat::OldTextNewText => {
- crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context)
- .await?
- }
- _ => {
- bail!("unsupported prompt format {}", options.prompt_format)
- }
- };
-
- let edited_buffer = included_files
- .iter()
- .find_map(|(buffer, snapshot, _, _)| {
- if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
- Some(buffer.clone())
- } else {
- None
- }
- })
- .context("Failed to find buffer in included_buffers")?;
-
- anyhow::Ok((
- Some((
- request_id,
- Some((
- inputs,
- edited_buffer,
- edited_buffer_snapshot.clone(),
- edits,
- received_response_at,
- )),
- )),
- usage,
- ))
- }
- });
-
- cx.spawn({
- async move |this, cx| {
- let Some((id, prediction)) =
- Self::handle_api_response(&this, request_task.await, cx)?
- else {
- return Ok(None);
- };
-
- let Some((
- inputs,
- edited_buffer,
- edited_buffer_snapshot,
- edits,
- received_response_at,
- )) = prediction
- else {
- return Ok(Some(EditPredictionResult {
- id,
- prediction: Err(EditPredictionRejectReason::Empty),
- }));
- };
-
- // TODO telemetry: duration, etc
- Ok(Some(
- EditPredictionResult::new(
- id,
- &edited_buffer,
- &edited_buffer_snapshot,
- edits.into(),
- buffer_snapshotted_at,
- received_response_at,
- inputs,
- cx,
- )
- .await,
- ))
- }
- })
- }
-
- async fn send_raw_llm_request(
- request: open_ai::Request,
- client: Arc<Client>,
- llm_token: LlmApiToken,
- app_version: Version,
- #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
- #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
- ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
- let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
- http_client::Url::parse(&predict_edits_url)?
- } else {
- client
- .http_client()
- .build_zed_llm_url("/predict_edits/raw", &[])?
- };
-
- #[cfg(feature = "eval-support")]
- let cache_key = if let Some(cache) = eval_cache {
- use collections::FxHasher;
- use std::hash::{Hash, Hasher};
-
- let mut hasher = FxHasher::default();
- url.hash(&mut hasher);
- let request_str = serde_json::to_string_pretty(&request)?;
- request_str.hash(&mut hasher);
- let hash = hasher.finish();
-
- let key = (eval_cache_kind, hash);
- if let Some(response_str) = cache.read(key) {
- return Ok((serde_json::from_str(&response_str)?, None));
- }
-
- Some((cache, request_str, key))
- } else {
- None
- };
-
- let (response, usage) = Self::send_api_request(
- |builder| {
- let req = builder
- .uri(url.as_ref())
- .body(serde_json::to_string(&request)?.into());
- Ok(req?)
- },
- client,
- llm_token,
- app_version,
- )
- .await?;
-
- #[cfg(feature = "eval-support")]
- if let Some((cache, request, key)) = cache_key {
- cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
- }
-
- Ok((response, usage))
- }
-
- fn handle_api_response<T>(
- this: &WeakEntity<Self>,
- response: Result<(T, Option<EditPredictionUsage>)>,
- cx: &mut gpui::AsyncApp,
- ) -> Result<T> {
- match response {
- Ok((data, usage)) => {
- if let Some(usage) = usage {
- this.update(cx, |this, cx| {
- this.user_store.update(cx, |user_store, cx| {
- user_store.update_edit_prediction_usage(usage, cx);
- });
- })
- .ok();
- }
- Ok(data)
- }
- Err(err) => {
- if err.is::<ZedUpdateRequiredError>() {
- cx.update(|cx| {
- this.update(cx, |this, _cx| {
- this.update_required = true;
- })
- .ok();
-
- let error_message: SharedString = err.to_string().into();
- show_app_notification(
- NotificationId::unique::<ZedUpdateRequiredError>(),
- cx,
- move |cx| {
- cx.new(|cx| {
- ErrorMessagePrompt::new(error_message.clone(), cx)
- .with_link_button("Update Zed", "https://zed.dev/releases")
- })
- },
- );
- })
- .ok();
- }
- Err(err)
- }
- }
- }
-
- async fn send_api_request<Res>(
- build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
- client: Arc<Client>,
- llm_token: LlmApiToken,
- app_version: Version,
- ) -> Result<(Res, Option<EditPredictionUsage>)>
- where
- Res: DeserializeOwned,
- {
- let http_client = client.http_client();
- let mut token = llm_token.acquire(&client).await?;
- let mut did_retry = false;
-
- loop {
- let request_builder = http_client::Request::builder().method(Method::POST);
-
- let request = build(
- request_builder
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", token))
- .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
- )?;
-
- let mut response = http_client.send(request).await?;
-
- if let Some(minimum_required_version) = response
- .headers()
- .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
- .and_then(|version| Version::from_str(version.to_str().ok()?).ok())
- {
- anyhow::ensure!(
- app_version >= minimum_required_version,
- ZedUpdateRequiredError {
- minimum_version: minimum_required_version
- }
- );
- }
-
- if response.status().is_success() {
- let usage = EditPredictionUsage::from_headers(response.headers()).ok();
-
- let mut body = Vec::new();
- response.body_mut().read_to_end(&mut body).await?;
- return Ok((serde_json::from_slice(&body)?, usage));
- } else if !did_retry
- && response
- .headers()
- .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
- .is_some()
- {
- did_retry = true;
- token = llm_token.refresh(&client).await?;
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- anyhow::bail!(
- "Request failed with status: {:?}\nBody: {}",
- response.status(),
- body
- );
- }
- }
- }
-
- pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
- pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
-
- // Refresh the related excerpts when the user just beguns editing after
- // an idle period, and after they pause editing.
- fn refresh_context_if_needed(
- &mut self,
- project: &Entity<Project>,
- buffer: &Entity<language::Buffer>,
- cursor_position: language::Anchor,
- cx: &mut Context<Self>,
- ) {
- if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) {
- return;
- }
-
- if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
- return;
- }
-
- let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
- return;
- };
-
- let now = Instant::now();
- let was_idle = zeta_project
- .refresh_context_timestamp
- .map_or(true, |timestamp| {
- now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
- });
- zeta_project.refresh_context_timestamp = Some(now);
- zeta_project.refresh_context_debounce_task = Some(cx.spawn({
- let buffer = buffer.clone();
- let project = project.clone();
- async move |this, cx| {
- if was_idle {
- log::debug!("refetching edit prediction context after idle");
- } else {
- cx.background_executor()
- .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
- .await;
- log::debug!("refetching edit prediction context after pause");
- }
- this.update(cx, |this, cx| {
- let task = this.refresh_context(project.clone(), buffer, cursor_position, cx);
-
- if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) {
- zeta_project.refresh_context_task = Some(task.log_err());
- };
- })
- .ok()
- }
- }));
- }
-
- // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
- // and avoid spawning more than one concurrent task.
- pub fn refresh_context(
- &mut self,
- project: Entity<Project>,
- buffer: Entity<language::Buffer>,
- cursor_position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
- return Task::ready(anyhow::Ok(()));
- };
-
- let ContextMode::Agentic(options) = &self.options().context else {
- return Task::ready(anyhow::Ok(()));
- };
-
- let snapshot = buffer.read(cx).snapshot();
- let cursor_point = cursor_position.to_point(&snapshot);
- let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- &snapshot,
- &options.excerpt,
- None,
- ) else {
- return Task::ready(Ok(()));
- };
-
- let app_version = AppVersion::global(cx);
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let debug_tx = self.debug_tx.clone();
- let current_file_path: Arc<Path> = snapshot
- .file()
- .map(|f| f.full_path(cx).into())
- .unwrap_or_else(|| Path::new("untitled").into());
-
- let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
- predict_edits_v3::PlanContextRetrievalRequest {
- excerpt: cursor_excerpt.text(&snapshot).body,
- excerpt_path: current_file_path,
- excerpt_line_range: cursor_excerpt.line_range,
- cursor_file_max_row: Line(snapshot.max_point().row),
- events: zeta_project.events(cx),
- },
- ) {
- Ok(prompt) => prompt,
- Err(err) => {
- return Task::ready(Err(err));
- }
- };
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
- ZetaContextRetrievalStartedDebugInfo {
- project: project.clone(),
- timestamp: Instant::now(),
- search_prompt: prompt.clone(),
- },
- ))
- .ok();
- }
-
- pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
- let schema = language_model::tool_schema::root_schema_for::<SearchToolInput>(
- language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
- );
-
- let description = schema
- .get("description")
- .and_then(|description| description.as_str())
- .unwrap()
- .to_string();
-
- (schema.into(), description)
- });
-
- let (tool_schema, tool_description) = TOOL_SCHEMA.clone();
-
- let request = open_ai::Request {
- model: CONTEXT_RETRIEVAL_MODEL_ID.clone(),
- messages: vec![open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(prompt),
- }],
- stream: false,
- max_completion_tokens: None,
- stop: Default::default(),
- temperature: 0.7,
- tool_choice: None,
- parallel_tool_calls: None,
- tools: vec![open_ai::ToolDefinition::Function {
- function: FunctionDefinition {
- name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
- description: Some(tool_description),
- parameters: Some(tool_schema),
- },
- }],
- prompt_cache_key: None,
- reasoning_effort: None,
- };
-
- #[cfg(feature = "eval-support")]
- let eval_cache = self.eval_cache.clone();
-
- cx.spawn(async move |this, cx| {
- log::trace!("Sending search planning request");
- let response = Self::send_raw_llm_request(
- request,
- client,
- llm_token,
- app_version,
- #[cfg(feature = "eval-support")]
- eval_cache.clone(),
- #[cfg(feature = "eval-support")]
- EvalCacheEntryKind::Context,
- )
- .await;
- let mut response = Self::handle_api_response(&this, response, cx)?;
- log::trace!("Got search planning response");
-
- let choice = response
- .choices
- .pop()
- .context("No choices in retrieval response")?;
- let open_ai::RequestMessage::Assistant {
- content: _,
- tool_calls,
- } = choice.message
- else {
- anyhow::bail!("Retrieval response didn't include an assistant message");
- };
-
- let mut queries: Vec<SearchToolQuery> = Vec::new();
- for tool_call in tool_calls {
- let open_ai::ToolCallContent::Function { function } = tool_call.content;
- if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
- log::warn!(
- "Context retrieval response tried to call an unknown tool: {}",
- function.name
- );
-
- continue;
- }
-
- let input: SearchToolInput = serde_json::from_str(&function.arguments)
- .with_context(|| format!("invalid search json {}", &function.arguments))?;
- queries.extend(input.queries);
- }
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
- ZetaSearchQueryDebugInfo {
- project: project.clone(),
- timestamp: Instant::now(),
- search_queries: queries.clone(),
- },
- ))
- .ok();
- }
-
- log::trace!("Running retrieval search: {queries:#?}");
-
- let related_excerpts_result = retrieval_search::run_retrieval_searches(
- queries,
- project.clone(),
- #[cfg(feature = "eval-support")]
- eval_cache,
- cx,
- )
- .await;
-
- log::trace!("Search queries executed");
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
- ZetaContextRetrievalDebugInfo {
- project: project.clone(),
- timestamp: Instant::now(),
- },
- ))
- .ok();
- }
-
- this.update(cx, |this, _cx| {
- let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
- return Ok(());
- };
- zeta_project.refresh_context_task.take();
- if let Some(debug_tx) = &this.debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished(
- ZetaContextRetrievalDebugInfo {
- project,
- timestamp: Instant::now(),
- },
- ))
- .ok();
- }
- match related_excerpts_result {
- Ok(excerpts) => {
- zeta_project.context = Some(excerpts);
- Ok(())
- }
- Err(error) => Err(error),
- }
- })?
- })
- }
-
- pub fn set_context(
- &mut self,
- project: Entity<Project>,
- context: HashMap<Entity<Buffer>, Vec<Range<Anchor>>>,
- ) {
- if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) {
- zeta_project.context = Some(context);
- }
- }
-
- fn gather_nearby_diagnostics(
- cursor_offset: usize,
- diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
- snapshot: &BufferSnapshot,
- max_diagnostics_bytes: usize,
- ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
- // TODO: Could make this more efficient
- let mut diagnostic_groups = Vec::new();
- for (language_server_id, diagnostics) in diagnostic_sets {
- let mut groups = Vec::new();
- diagnostics.groups(*language_server_id, &mut groups, &snapshot);
- diagnostic_groups.extend(
- groups
- .into_iter()
- .map(|(_, group)| group.resolve::<usize>(&snapshot)),
- );
- }
-
- // sort by proximity to cursor
- diagnostic_groups.sort_by_key(|group| {
- let range = &group.entries[group.primary_ix].range;
- if range.start >= cursor_offset {
- range.start - cursor_offset
- } else if cursor_offset >= range.end {
- cursor_offset - range.end
- } else {
- (cursor_offset - range.start).min(range.end - cursor_offset)
- }
- });
-
- let mut results = Vec::new();
- let mut diagnostic_groups_truncated = false;
- let mut diagnostics_byte_count = 0;
- for group in diagnostic_groups {
- let raw_value = serde_json::value::to_raw_value(&group).unwrap();
- diagnostics_byte_count += raw_value.get().len();
- if diagnostics_byte_count > max_diagnostics_bytes {
- diagnostic_groups_truncated = true;
- break;
- }
- results.push(predict_edits_v3::DiagnosticGroup(raw_value));
- }
-
- (results, diagnostic_groups_truncated)
- }
-
- // TODO: Dedupe with similar code in request_prediction?
- pub fn cloud_request_for_zeta_cli(
- &mut self,
- project: &Entity<Project>,
- buffer: &Entity<Buffer>,
- position: language::Anchor,
- cx: &mut Context<Self>,
- ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
- let project_state = self.projects.get(&project.entity_id());
-
- let index_state = project_state.and_then(|state| {
- state
- .syntax_index
- .as_ref()
- .map(|index| index.read_with(cx, |index, _cx| index.state().clone()))
- });
- let options = self.options.clone();
- let snapshot = buffer.read(cx).snapshot();
- let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
- return Task::ready(Err(anyhow!("No file path for excerpt")));
- };
- let worktree_snapshots = project
- .read(cx)
- .worktrees(cx)
- .map(|worktree| worktree.read(cx).snapshot())
- .collect::<Vec<_>>();
-
- let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
- let mut path = f.worktree.read(cx).absolutize(&f.path);
- if path.pop() { Some(path) } else { None }
- });
-
- cx.background_spawn(async move {
- let index_state = if let Some(index_state) = index_state {
- Some(index_state.lock_owned().await)
- } else {
- None
- };
-
- let cursor_point = position.to_point(&snapshot);
-
- let debug_info = true;
- EditPredictionContext::gather_context(
- cursor_point,
- &snapshot,
- parent_abs_path.as_deref(),
- match &options.context {
- ContextMode::Agentic(_) => {
- // TODO
- panic!("Llm mode not supported in zeta cli yet");
- }
- ContextMode::Syntax(edit_prediction_context_options) => {
- edit_prediction_context_options
- }
- },
- index_state.as_deref(),
- )
- .context("Failed to select excerpt")
- .map(|context| {
- make_syntax_context_cloud_request(
- excerpt_path.into(),
- context,
- // TODO pass everything
- Vec::new(),
- false,
- Vec::new(),
- false,
- None,
- debug_info,
- &worktree_snapshots,
- index_state.as_deref(),
- Some(options.max_prompt_bytes),
- options.prompt_format,
- PredictEditsRequestTrigger::Other,
- )
- })
- })
- }
-
- pub fn wait_for_initial_indexing(
- &mut self,
- project: &Entity<Project>,
- cx: &mut Context<Self>,
- ) -> Task<Result<()>> {
- let zeta_project = self.get_or_init_zeta_project(project, cx);
- if let Some(syntax_index) = &zeta_project.syntax_index {
- syntax_index.read(cx).wait_for_initial_file_indexing(cx)
- } else {
- Task::ready(Ok(()))
- }
- }
-
- fn is_file_open_source(
- &self,
- project: &Entity<Project>,
- file: &Arc<dyn File>,
- cx: &App,
- ) -> bool {
- if !file.is_local() || file.is_private() {
- return false;
- }
- let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
- return false;
- };
- zeta_project
- .license_detection_watchers
- .get(&file.worktree_id(cx))
- .as_ref()
- .is_some_and(|watcher| watcher.is_project_open_source())
- }
-
- fn can_collect_file(&self, project: &Entity<Project>, file: &Arc<dyn File>, cx: &App) -> bool {
- self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
- }
-
- fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
- if !self.data_collection_choice.is_enabled() {
- return false;
- }
- events.iter().all(|event| {
- matches!(
- event.as_ref(),
- Event::BufferChange {
- in_open_source_repo: true,
- ..
- }
- )
- })
- }
-
- fn load_data_collection_choice() -> DataCollectionChoice {
- let choice = KEY_VALUE_STORE
- .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
- .log_err()
- .flatten();
-
- match choice.as_deref() {
- Some("true") => DataCollectionChoice::Enabled,
- Some("false") => DataCollectionChoice::Disabled,
- Some(_) => {
- log::error!("unknown value in '{ZED_PREDICT_DATA_COLLECTION_CHOICE}'");
- DataCollectionChoice::NotAnswered
- }
- None => DataCollectionChoice::NotAnswered,
- }
- }
-
- pub fn shown_predictions(&self) -> impl DoubleEndedIterator<Item = &EditPrediction> {
- self.shown_predictions.iter()
- }
-
- pub fn shown_completions_len(&self) -> usize {
- self.shown_predictions.len()
- }
-
- pub fn is_prediction_rated(&self, id: &EditPredictionId) -> bool {
- self.rated_predictions.contains(id)
- }
-
- pub fn rate_prediction(
- &mut self,
- prediction: &EditPrediction,
- rating: EditPredictionRating,
- feedback: String,
- cx: &mut Context<Self>,
- ) {
- self.rated_predictions.insert(prediction.id.clone());
- telemetry::event!(
- "Edit Prediction Rated",
- rating,
- inputs = prediction.inputs,
- output = prediction.edit_preview.as_unified_diff(&prediction.edits),
- feedback
- );
- self.client.telemetry().flush_events().detach();
- cx.notify();
- }
-}
-
-pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
- let choice = res.choices.pop()?;
- let output_text = match choice.message {
- open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(content)),
- ..
- } => content,
- open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Multipart(mut content)),
- ..
- } => {
- if content.is_empty() {
- log::error!("No output from Baseten completion response");
- return None;
- }
-
- match content.remove(0) {
- open_ai::MessagePart::Text { text } => text,
- open_ai::MessagePart::Image { .. } => {
- log::error!("Expected text, got an image");
- return None;
- }
- }
- }
- _ => {
- log::error!("Invalid response message: {:?}", choice.message);
- return None;
- }
- };
- Some(output_text)
-}
-
-#[derive(Error, Debug)]
-#[error(
- "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
-)]
-pub struct ZedUpdateRequiredError {
- minimum_version: Version,
-}
-
-fn make_syntax_context_cloud_request(
- excerpt_path: Arc<Path>,
- context: EditPredictionContext,
- events: Vec<Arc<predict_edits_v3::Event>>,
- can_collect_data: bool,
- diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
- diagnostic_groups_truncated: bool,
- git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
- debug_info: bool,
- worktrees: &Vec<worktree::Snapshot>,
- index_state: Option<&SyntaxIndexState>,
- prompt_max_bytes: Option<usize>,
- prompt_format: PromptFormat,
- trigger: PredictEditsRequestTrigger,
-) -> predict_edits_v3::PredictEditsRequest {
- let mut signatures = Vec::new();
- let mut declaration_to_signature_index = HashMap::default();
- let mut referenced_declarations = Vec::new();
-
- for snippet in context.declarations {
- let project_entry_id = snippet.declaration.project_entry_id();
- let Some(path) = worktrees.iter().find_map(|worktree| {
- worktree.entry_for_id(project_entry_id).map(|entry| {
- let mut full_path = RelPathBuf::new();
- full_path.push(worktree.root_name());
- full_path.push(&entry.path);
- full_path
- })
- }) else {
- continue;
- };
-
- let parent_index = index_state.and_then(|index_state| {
- snippet.declaration.parent().and_then(|parent| {
- add_signature(
- parent,
- &mut declaration_to_signature_index,
- &mut signatures,
- index_state,
- )
- })
- });
-
- let (text, text_is_truncated) = snippet.declaration.item_text();
- referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
- path: path.as_std_path().into(),
- text: text.into(),
- range: snippet.declaration.item_line_range(),
- text_is_truncated,
- signature_range: snippet.declaration.signature_range_in_item_text(),
- parent_index,
- signature_score: snippet.score(DeclarationStyle::Signature),
- declaration_score: snippet.score(DeclarationStyle::Declaration),
- score_components: snippet.components,
- });
- }
-
- let excerpt_parent = index_state.and_then(|index_state| {
- context
- .excerpt
- .parent_declarations
- .last()
- .and_then(|(parent, _)| {
- add_signature(
- *parent,
- &mut declaration_to_signature_index,
- &mut signatures,
- index_state,
- )
- })
- });
-
- predict_edits_v3::PredictEditsRequest {
- excerpt_path,
- excerpt: context.excerpt_text.body,
- excerpt_line_range: context.excerpt.line_range,
- excerpt_range: context.excerpt.range,
- cursor_point: predict_edits_v3::Point {
- line: predict_edits_v3::Line(context.cursor_point.row),
- column: context.cursor_point.column,
- },
- referenced_declarations,
- included_files: vec![],
- signatures,
- excerpt_parent,
- events,
- can_collect_data,
- diagnostic_groups,
- diagnostic_groups_truncated,
- git_info,
- debug_info,
- prompt_max_bytes,
- prompt_format,
- trigger,
- }
-}
-
-fn add_signature(
- declaration_id: DeclarationId,
- declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
- signatures: &mut Vec<Signature>,
- index: &SyntaxIndexState,
-) -> Option<usize> {
- if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
- return Some(*signature_index);
- }
- let Some(parent_declaration) = index.declaration(declaration_id) else {
- log::error!("bug: missing parent declaration");
- return None;
- };
- let parent_index = parent_declaration.parent().and_then(|parent| {
- add_signature(parent, declaration_to_signature_index, signatures, index)
- });
- let (text, text_is_truncated) = parent_declaration.signature_text();
- let signature_index = signatures.len();
- signatures.push(Signature {
- text: text.into(),
- text_is_truncated,
- parent_index,
- range: parent_declaration.signature_line_range(),
- });
- declaration_to_signature_index.insert(declaration_id, signature_index);
- Some(signature_index)
-}
-
-#[cfg(feature = "eval-support")]
-pub type EvalCacheKey = (EvalCacheEntryKind, u64);
-
-#[cfg(feature = "eval-support")]
-#[derive(Debug, Clone, Copy, PartialEq)]
-pub enum EvalCacheEntryKind {
- Context,
- Search,
- Prediction,
-}
-
-#[cfg(feature = "eval-support")]
-impl std::fmt::Display for EvalCacheEntryKind {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- EvalCacheEntryKind::Search => write!(f, "search"),
- EvalCacheEntryKind::Context => write!(f, "context"),
- EvalCacheEntryKind::Prediction => write!(f, "prediction"),
- }
- }
-}
-
-#[cfg(feature = "eval-support")]
-pub trait EvalCache: Send + Sync {
- fn read(&self, key: EvalCacheKey) -> Option<String>;
- fn write(&self, key: EvalCacheKey, input: &str, value: &str);
-}
-
-#[derive(Debug, Clone, Copy)]
-pub enum DataCollectionChoice {
- NotAnswered,
- Enabled,
- Disabled,
-}
-
-impl DataCollectionChoice {
- pub fn is_enabled(self) -> bool {
- match self {
- Self::Enabled => true,
- Self::NotAnswered | Self::Disabled => false,
- }
- }
-
- pub fn is_answered(self) -> bool {
- match self {
- Self::Enabled | Self::Disabled => true,
- Self::NotAnswered => false,
- }
- }
-
- #[must_use]
- pub fn toggle(&self) -> DataCollectionChoice {
- match self {
- Self::Enabled => Self::Disabled,
- Self::Disabled => Self::Enabled,
- Self::NotAnswered => Self::Enabled,
- }
- }
-}
-
-impl From<bool> for DataCollectionChoice {
- fn from(value: bool) -> Self {
- match value {
- true => DataCollectionChoice::Enabled,
- false => DataCollectionChoice::Disabled,
- }
- }
-}
-
-struct ZedPredictUpsell;
-
-impl Dismissable for ZedPredictUpsell {
- const KEY: &'static str = "dismissed-edit-predict-upsell";
-
- fn dismissed() -> bool {
- // To make this backwards compatible with older versions of Zed, we
- // check if the user has seen the previous Edit Prediction Onboarding
- // before, by checking the data collection choice which was written to
- // the database once the user clicked on "Accept and Enable"
- if KEY_VALUE_STORE
- .read_kvp(ZED_PREDICT_DATA_COLLECTION_CHOICE)
- .log_err()
- .is_some_and(|s| s.is_some())
- {
- return true;
- }
-
- KEY_VALUE_STORE
- .read_kvp(Self::KEY)
- .log_err()
- .is_some_and(|s| s.is_some())
- }
-}
-
-pub fn should_show_upsell_modal() -> bool {
- !ZedPredictUpsell::dismissed()
-}
-
-pub fn init(cx: &mut App) {
- feature_gate_predict_edits_actions(cx);
-
- cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
- workspace.register_action(|workspace, _: &RateCompletions, window, cx| {
- if cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>() {
- RatePredictionsModal::toggle(workspace, window, cx);
- }
- });
-
- workspace.register_action(
- move |workspace, _: &zed_actions::OpenZedPredictOnboarding, window, cx| {
- ZedPredictModal::toggle(
- workspace,
- workspace.user_store().clone(),
- workspace.client().clone(),
- window,
- cx,
- )
- },
- );
-
- workspace.register_action(|workspace, _: &ResetOnboarding, _window, cx| {
- update_settings_file(workspace.app_state().fs.clone(), cx, move |settings, _| {
- settings
- .project
- .all_languages
- .features
- .get_or_insert_default()
- .edit_prediction_provider = Some(EditPredictionProvider::None)
- });
- });
- })
- .detach();
-}
-
-fn feature_gate_predict_edits_actions(cx: &mut App) {
- let rate_completion_action_types = [TypeId::of::<RateCompletions>()];
- let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
- let zeta_all_action_types = [
- TypeId::of::<RateCompletions>(),
- TypeId::of::<ResetOnboarding>(),
- zed_actions::OpenZedPredictOnboarding.type_id(),
- TypeId::of::<ClearHistory>(),
- TypeId::of::<ThumbsUpActivePrediction>(),
- TypeId::of::<ThumbsDownActivePrediction>(),
- TypeId::of::<NextEdit>(),
- TypeId::of::<PreviousEdit>(),
- ];
-
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.hide_action_types(&rate_completion_action_types);
- filter.hide_action_types(&reset_onboarding_action_types);
- filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
- });
-
- cx.observe_global::<SettingsStore>(move |cx| {
- let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
- let has_feature_flag = cx.has_flag::<PredictEditsRateCompletionsFeatureFlag>();
-
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- if is_ai_disabled {
- filter.hide_action_types(&zeta_all_action_types);
- } else if has_feature_flag {
- filter.show_action_types(&rate_completion_action_types);
- } else {
- filter.hide_action_types(&rate_completion_action_types);
- }
- });
- })
- .detach();
-
- cx.observe_flag::<PredictEditsRateCompletionsFeatureFlag, _>(move |is_enabled, cx| {
- if !DisableAiSettings::get_global(cx).disable_ai {
- if is_enabled {
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.show_action_types(&rate_completion_action_types);
- });
- } else {
- CommandPaletteFilter::update_global(cx, |filter, _cx| {
- filter.hide_action_types(&rate_completion_action_types);
- });
- }
- }
- })
- .detach();
-}
-
-#[cfg(test)]
-mod tests {
- use std::{path::Path, sync::Arc, time::Duration};
-
- use client::UserStore;
- use clock::FakeSystemClock;
- use cloud_llm_client::{
- EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody,
- };
- use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
- use futures::{
- AsyncReadExt, StreamExt,
- channel::{mpsc, oneshot},
- };
- use gpui::{
- Entity, TestAppContext,
- http_client::{FakeHttpClient, Response},
- prelude::*,
- };
- use indoc::indoc;
- use language::OffsetRangeExt as _;
- use open_ai::Usage;
- use pretty_assertions::{assert_eq, assert_matches};
- use project::{FakeFs, Project};
- use serde_json::json;
- use settings::SettingsStore;
- use util::path;
- use uuid::Uuid;
-
- use crate::{BufferEditPrediction, EditPredictionId, REJECT_REQUEST_DEBOUNCE, Zeta};
-
- #[gpui::test]
- async fn test_current_state(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "1.txt": "Hello!\nHow\nBye\n",
- "2.txt": "Hola!\nComo\nAdios\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- zeta.update(cx, |zeta, cx| {
- zeta.register_project(&project, cx);
- });
-
- let buffer1 = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot1.anchor_before(language::Point::new(1, 3));
-
- // Prediction for current file
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
- });
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
-
- respond_tx
- .send(model_response(indoc! {r"
- --- a/root/1.txt
- +++ b/root/1.txt
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
- .unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- let prediction = zeta
- .current_prediction_for_buffer(&buffer1, &project, cx)
- .unwrap();
- assert_matches!(prediction, BufferEditPrediction::Local { .. });
- });
-
- // Context refresh
- let refresh_task = zeta.update(cx, |zeta, cx| {
- zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
- });
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
- respond_tx
- .send(open_ai::Response {
- id: Uuid::new_v4().to_string(),
- object: "response".into(),
- created: 0,
- model: "model".into(),
- choices: vec![open_ai::Choice {
- index: 0,
- message: open_ai::RequestMessage::Assistant {
- content: None,
- tool_calls: vec![open_ai::ToolCall {
- id: "search".into(),
- content: open_ai::ToolCallContent::Function {
- function: open_ai::FunctionContent {
- name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
- .to_string(),
- arguments: serde_json::to_string(&SearchToolInput {
- queries: Box::new([SearchToolQuery {
- glob: "root/2.txt".to_string(),
- syntax_node: vec![],
- content: Some(".".into()),
- }]),
- })
- .unwrap(),
- },
- },
- }],
- },
- finish_reason: None,
- }],
- usage: Usage {
- prompt_tokens: 0,
- completion_tokens: 0,
- total_tokens: 0,
- },
- })
- .unwrap();
- refresh_task.await.unwrap();
-
- zeta.update(cx, |zeta, _cx| {
- zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project);
- });
-
- // Prediction for another file
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
- });
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
- respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/2.txt
- +++ b/root/2.txt
- Hola!
- -Como
- +Como estas?
- Adios
- "#}))
- .unwrap();
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- let prediction = zeta
- .current_prediction_for_buffer(&buffer1, &project, cx)
- .unwrap();
- assert_matches!(
- prediction,
- BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
- );
- });
-
- let buffer2 = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
-
- zeta.read_with(cx, |zeta, cx| {
- let prediction = zeta
- .current_prediction_for_buffer(&buffer2, &project, cx)
- .unwrap();
- assert_matches!(prediction, BufferEditPrediction::Local { .. });
- });
- }
-
- #[gpui::test]
- async fn test_simple_request(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- let prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
-
- // TODO Put back when we have a structured request again
- // assert_eq!(
- // request.excerpt_path.as_ref(),
- // Path::new(path!("root/foo.md"))
- // );
- // assert_eq!(
- // request.cursor_point,
- // Point {
- // line: Line(1),
- // column: 3
- // }
- // );
-
- respond_tx
- .send(model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
- .unwrap();
-
- let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
-
- assert_eq!(prediction.edits.len(), 1);
- assert_eq!(
- prediction.edits[0].0.to_point(&snapshot).start,
- language::Point::new(1, 3)
- );
- assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
- }
-
- #[gpui::test]
- async fn test_request_events(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\n\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
-
- zeta.update(cx, |zeta, cx| {
- zeta.register_buffer(&buffer, &project, cx);
- });
-
- buffer.update(cx, |buffer, cx| {
- buffer.edit(vec![(7..7, "How")], None, cx);
- });
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- let prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &buffer, position, Default::default(), cx)
- });
-
- let (request, respond_tx) = requests.predict.next().await.unwrap();
-
- let prompt = prompt_from_request(&request);
- assert!(
- prompt.contains(indoc! {"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ -1,3 +1,3 @@
- Hello!
- -
- +How
- Bye
- "}),
- "{prompt}"
- );
-
- respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "#}))
- .unwrap();
-
- let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
-
- assert_eq!(prediction.edits.len(), 1);
- assert_eq!(
- prediction.edits[0].0.to_point(&snapshot).start,
- language::Point::new(1, 3)
- );
- assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
- }
-
- #[gpui::test]
- async fn test_empty_prediction(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- const NO_OP_DIFF: &str = indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How
- Bye
- "};
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let response = model_response(NO_OP_DIFF);
- let id = response.id.clone();
- respond_tx.send(response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- assert!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .is_none()
- );
- });
-
- // prediction is reported as rejected
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- assert_eq!(
- &reject_request.rejections,
- &[EditPredictionRejection {
- request_id: id,
- reason: EditPredictionRejectReason::Empty,
- was_shown: false
- }]
- );
- }
-
- #[gpui::test]
- async fn test_interpolated_empty(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
-
- buffer.update(cx, |buffer, cx| {
- buffer.set_text("Hello!\nHow are you?\nBye", cx);
- });
-
- let response = model_response(SIMPLE_DIFF);
- let id = response.id.clone();
- respond_tx.send(response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- assert!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .is_none()
- );
- });
-
- // prediction is reported as rejected
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- assert_eq!(
- &reject_request.rejections,
- &[EditPredictionRejection {
- request_id: id,
- reason: EditPredictionRejectReason::InterpolatedEmpty,
- was_shown: false
- }]
- );
- }
-
- const SIMPLE_DIFF: &str = indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "};
-
- #[gpui::test]
- async fn test_replace_current(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
- let first_id = first_response.id.clone();
- respond_tx.send(first_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- first_id
- );
- });
-
- // a second request is triggered
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let second_response = model_response(SIMPLE_DIFF);
- let second_id = second_response.id.clone();
- respond_tx.send(second_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // second replaces first
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- second_id
- );
- });
-
- // first is reported as replaced
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- assert_eq!(
- &reject_request.rejections,
- &[EditPredictionRejection {
- request_id: first_id,
- reason: EditPredictionRejectReason::Replaced,
- was_shown: false
- }]
- );
- }
-
- #[gpui::test]
- async fn test_current_preferred(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
- let first_id = first_response.id.clone();
- respond_tx.send(first_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- first_id
- );
- });
-
- // a second request is triggered
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- // worse than current prediction
- let second_response = model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are
- Bye
- "});
- let second_id = second_response.id.clone();
- respond_tx.send(second_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // first is preferred over second
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- first_id
- );
- });
-
- // second is reported as rejected
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- assert_eq!(
- &reject_request.rejections,
- &[EditPredictionRejection {
- request_id: second_id,
- reason: EditPredictionRejectReason::CurrentPreferred,
- was_shown: false
- }]
- );
- }
-
- #[gpui::test]
- async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- // start two refresh tasks
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_first) = requests.predict.next().await.unwrap();
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_second) = requests.predict.next().await.unwrap();
-
- // wait for throttle
- cx.run_until_parked();
-
- // second responds first
- let second_response = model_response(SIMPLE_DIFF);
- let second_id = second_response.id.clone();
- respond_second.send(second_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // current prediction is second
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- second_id
- );
- });
-
- let first_response = model_response(SIMPLE_DIFF);
- let first_id = first_response.id.clone();
- respond_first.send(first_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // current prediction is still second, since first was cancelled
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- second_id
- );
- });
-
- // first is reported as rejected
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- cx.run_until_parked();
-
- assert_eq!(
- &reject_request.rejections,
- &[EditPredictionRejection {
- request_id: first_id,
- reason: EditPredictionRejectReason::Canceled,
- was_shown: false
- }]
- );
- }
-
- #[gpui::test]
- async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nHow\nBye\n"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(1, 3));
-
- // start two refresh tasks
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_first) = requests.predict.next().await.unwrap();
-
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
- });
-
- let (_, respond_second) = requests.predict.next().await.unwrap();
-
- // wait for throttle, so requests are sent
- cx.run_until_parked();
-
- zeta.update(cx, |zeta, cx| {
- // start a third request
- zeta.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
-
- // 2 are pending, so 2nd is cancelled
- assert_eq!(
- zeta.get_or_init_zeta_project(&project, cx)
- .cancelled_predictions
- .iter()
- .copied()
- .collect::<Vec<_>>(),
- [1]
- );
- });
-
- // wait for throttle
- cx.run_until_parked();
-
- let (_, respond_third) = requests.predict.next().await.unwrap();
-
- let first_response = model_response(SIMPLE_DIFF);
- let first_id = first_response.id.clone();
- respond_first.send(first_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // current prediction is first
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- first_id
- );
- });
-
- let cancelled_response = model_response(SIMPLE_DIFF);
- let cancelled_id = cancelled_response.id.clone();
- respond_second.send(cancelled_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // current prediction is still first, since second was cancelled
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- first_id
- );
- });
-
- let third_response = model_response(SIMPLE_DIFF);
- let third_response_id = third_response.id.clone();
- respond_third.send(third_response).unwrap();
-
- cx.run_until_parked();
-
- zeta.read_with(cx, |zeta, cx| {
- // third completes and replaces first
- assert_eq!(
- zeta.current_prediction_for_buffer(&buffer, &project, cx)
- .unwrap()
- .id
- .0,
- third_response_id
- );
- });
-
- // second is reported as rejected
- let (reject_request, _) = requests.reject.next().await.unwrap();
-
- cx.run_until_parked();
-
- assert_eq!(
- &reject_request.rejections,
- &[
- EditPredictionRejection {
- request_id: cancelled_id,
- reason: EditPredictionRejectReason::Canceled,
- was_shown: false
- },
- EditPredictionRejection {
- request_id: first_id,
- reason: EditPredictionRejectReason::Replaced,
- was_shown: false
- }
- ]
- );
- }
-
- #[gpui::test]
- async fn test_rejections_flushing(cx: &mut TestAppContext) {
- let (zeta, mut requests) = init_test(cx);
-
- zeta.update(cx, |zeta, _cx| {
- zeta.reject_prediction(
- EditPredictionId("test-1".into()),
- EditPredictionRejectReason::Discarded,
- false,
- );
- zeta.reject_prediction(
- EditPredictionId("test-2".into()),
- EditPredictionRejectReason::Canceled,
- true,
- );
- });
-
- cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
- cx.run_until_parked();
-
- let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
- respond_tx.send(()).unwrap();
-
- // batched
- assert_eq!(reject_request.rejections.len(), 2);
- assert_eq!(
- reject_request.rejections[0],
- EditPredictionRejection {
- request_id: "test-1".to_string(),
- reason: EditPredictionRejectReason::Discarded,
- was_shown: false
- }
- );
- assert_eq!(
- reject_request.rejections[1],
- EditPredictionRejection {
- request_id: "test-2".to_string(),
- reason: EditPredictionRejectReason::Canceled,
- was_shown: true
- }
- );
-
- // Reaching batch size limit sends without debounce
- zeta.update(cx, |zeta, _cx| {
- for i in 0..70 {
- zeta.reject_prediction(
- EditPredictionId(format!("batch-{}", i).into()),
- EditPredictionRejectReason::Discarded,
- false,
- );
- }
- });
-
- // First MAX/2 items are sent immediately
- cx.run_until_parked();
- let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
- respond_tx.send(()).unwrap();
-
- assert_eq!(reject_request.rejections.len(), 50);
- assert_eq!(reject_request.rejections[0].request_id, "batch-0");
- assert_eq!(reject_request.rejections[49].request_id, "batch-49");
-
- // Remaining items are debounced with the next batch
- cx.executor().advance_clock(Duration::from_secs(15));
- cx.run_until_parked();
-
- let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
- respond_tx.send(()).unwrap();
-
- assert_eq!(reject_request.rejections.len(), 20);
- assert_eq!(reject_request.rejections[0].request_id, "batch-50");
- assert_eq!(reject_request.rejections[19].request_id, "batch-69");
-
- // Request failure
- zeta.update(cx, |zeta, _cx| {
- zeta.reject_prediction(
- EditPredictionId("retry-1".into()),
- EditPredictionRejectReason::Discarded,
- false,
- );
- });
-
- cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
- cx.run_until_parked();
-
- let (reject_request, _respond_tx) = requests.reject.next().await.unwrap();
- assert_eq!(reject_request.rejections.len(), 1);
- assert_eq!(reject_request.rejections[0].request_id, "retry-1");
- // Simulate failure
- drop(_respond_tx);
-
- // Add another rejection
- zeta.update(cx, |zeta, _cx| {
- zeta.reject_prediction(
- EditPredictionId("retry-2".into()),
- EditPredictionRejectReason::Discarded,
- false,
- );
- });
-
- cx.executor().advance_clock(REJECT_REQUEST_DEBOUNCE);
- cx.run_until_parked();
-
- // Retry should include both the failed item and the new one
- let (reject_request, respond_tx) = requests.reject.next().await.unwrap();
- respond_tx.send(()).unwrap();
-
- assert_eq!(reject_request.rejections.len(), 2);
- assert_eq!(reject_request.rejections[0].request_id, "retry-1");
- assert_eq!(reject_request.rejections[1].request_id, "retry-2");
- }
-
- // Skipped until we start including diagnostics in prompt
- // #[gpui::test]
- // async fn test_request_diagnostics(cx: &mut TestAppContext) {
- // let (zeta, mut req_rx) = init_test(cx);
- // let fs = FakeFs::new(cx.executor());
- // fs.insert_tree(
- // "/root",
- // json!({
- // "foo.md": "Hello!\nBye"
- // }),
- // )
- // .await;
- // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
-
- // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
- // let diagnostic = lsp::Diagnostic {
- // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
- // severity: Some(lsp::DiagnosticSeverity::ERROR),
- // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
- // ..Default::default()
- // };
-
- // project.update(cx, |project, cx| {
- // project.lsp_store().update(cx, |lsp_store, cx| {
- // // Create some diagnostics
- // lsp_store
- // .update_diagnostics(
- // LanguageServerId(0),
- // lsp::PublishDiagnosticsParams {
- // uri: path_to_buffer_uri.clone(),
- // diagnostics: vec![diagnostic],
- // version: None,
- // },
- // None,
- // language::DiagnosticSourceKind::Pushed,
- // &[],
- // cx,
- // )
- // .unwrap();
- // });
- // });
-
- // let buffer = project
- // .update(cx, |project, cx| {
- // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- // project.open_buffer(path, cx)
- // })
- // .await
- // .unwrap();
-
- // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- // let position = snapshot.anchor_before(language::Point::new(0, 0));
-
- // let _prediction_task = zeta.update(cx, |zeta, cx| {
- // zeta.request_prediction(&project, &buffer, position, cx)
- // });
-
- // let (request, _respond_tx) = req_rx.next().await.unwrap();
-
- // assert_eq!(request.diagnostic_groups.len(), 1);
- // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
- // .unwrap();
- // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
- // assert_eq!(
- // value,
- // json!({
- // "entries": [{
- // "range": {
- // "start": 8,
- // "end": 10
- // },
- // "diagnostic": {
- // "source": null,
- // "code": null,
- // "code_description": null,
- // "severity": 1,
- // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
- // "markdown": null,
- // "group_id": 0,
- // "is_primary": true,
- // "is_disk_based": false,
- // "is_unnecessary": false,
- // "source_kind": "Pushed",
- // "data": null,
- // "underline": true
- // }
- // }],
- // "primary_ix": 0
- // })
- // );
- // }
-
- fn model_response(text: &str) -> open_ai::Response {
- open_ai::Response {
- id: Uuid::new_v4().to_string(),
- object: "response".into(),
- created: 0,
- model: "model".into(),
- choices: vec![open_ai::Choice {
- index: 0,
- message: open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(text.to_string())),
- tool_calls: vec![],
- },
- finish_reason: None,
- }],
- usage: Usage {
- prompt_tokens: 0,
- completion_tokens: 0,
- total_tokens: 0,
- },
- }
- }
-
- fn prompt_from_request(request: &open_ai::Request) -> &str {
- assert_eq!(request.messages.len(), 1);
- let open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(content),
- ..
- } = &request.messages[0]
- else {
- panic!(
- "Request does not have single user message of type Plain. {:#?}",
- request
- );
- };
- content
- }
-
- struct RequestChannels {
- predict: mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
- reject: mpsc::UnboundedReceiver<(RejectEditPredictionsBody, oneshot::Sender<()>)>,
- }
-
- fn init_test(cx: &mut TestAppContext) -> (Entity<Zeta>, RequestChannels) {
- cx.update(move |cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- zlog::init_test();
-
- let (predict_req_tx, predict_req_rx) = mpsc::unbounded();
- let (reject_req_tx, reject_req_rx) = mpsc::unbounded();
-
- let http_client = FakeHttpClient::create({
- move |req| {
- let uri = req.uri().path().to_string();
- let mut body = req.into_body();
- let predict_req_tx = predict_req_tx.clone();
- let reject_req_tx = reject_req_tx.clone();
- async move {
- let resp = match uri.as_str() {
- "/client/llm_tokens" => serde_json::to_string(&json!({
- "token": "test"
- }))
- .unwrap(),
- "/predict_edits/raw" => {
- let mut buf = Vec::new();
- body.read_to_end(&mut buf).await.ok();
- let req = serde_json::from_slice(&buf).unwrap();
-
- let (res_tx, res_rx) = oneshot::channel();
- predict_req_tx.unbounded_send((req, res_tx)).unwrap();
- serde_json::to_string(&res_rx.await?).unwrap()
- }
- "/predict_edits/reject" => {
- let mut buf = Vec::new();
- body.read_to_end(&mut buf).await.ok();
- let req = serde_json::from_slice(&buf).unwrap();
-
- let (res_tx, res_rx) = oneshot::channel();
- reject_req_tx.unbounded_send((req, res_tx)).unwrap();
- serde_json::to_string(&res_rx.await?).unwrap()
- }
- _ => {
- panic!("Unexpected path: {}", uri)
- }
- };
-
- Ok(Response::builder().body(resp.into()).unwrap())
- }
- }
- });
-
- let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
- client.cloud_client().set_credentials(1, "test".into());
-
- language_model::init(client.clone(), cx);
-
- let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
- let zeta = Zeta::global(&client, &user_store, cx);
-
- (
- zeta,
- RequestChannels {
- predict: predict_req_rx,
- reject: reject_req_rx,
- },
- )
- })
- }
-}
@@ -1,231 +0,0 @@
-use super::{
- CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER, START_OF_FILE_MARKER,
- guess_token_count,
-};
-use language::{BufferSnapshot, Point};
-use std::{fmt::Write, ops::Range};
-
-#[derive(Debug)]
-pub struct InputExcerpt {
- pub context_range: Range<Point>,
- pub editable_range: Range<Point>,
- pub prompt: String,
-}
-
-pub fn excerpt_for_cursor_position(
- position: Point,
- path: &str,
- snapshot: &BufferSnapshot,
- editable_region_token_limit: usize,
- context_token_limit: usize,
-) -> InputExcerpt {
- let mut scope_range = position..position;
- let mut remaining_edit_tokens = editable_region_token_limit;
-
- while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
- let parent_tokens = guess_token_count(parent.byte_range().len());
- let parent_point_range = Point::new(
- parent.start_position().row as u32,
- parent.start_position().column as u32,
- )
- ..Point::new(
- parent.end_position().row as u32,
- parent.end_position().column as u32,
- );
- if parent_point_range == scope_range {
- break;
- } else if parent_tokens <= editable_region_token_limit {
- scope_range = parent_point_range;
- remaining_edit_tokens = editable_region_token_limit - parent_tokens;
- } else {
- break;
- }
- }
-
- let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
- let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
-
- let mut prompt = String::new();
-
- writeln!(&mut prompt, "```{path}").unwrap();
- if context_range.start == Point::zero() {
- writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
- }
-
- for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
- prompt.push_str(chunk.text);
- }
-
- push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
-
- for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
- prompt.push_str(chunk.text);
- }
- write!(prompt, "\n```").unwrap();
-
- InputExcerpt {
- context_range,
- editable_range,
- prompt,
- }
-}
-
-fn push_editable_range(
- cursor_position: Point,
- snapshot: &BufferSnapshot,
- editable_range: Range<Point>,
- prompt: &mut String,
-) {
- writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
- for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
- prompt.push_str(chunk.text);
- }
- prompt.push_str(CURSOR_MARKER);
- for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
- prompt.push_str(chunk.text);
- }
- write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
-}
-
-fn expand_range(
- snapshot: &BufferSnapshot,
- range: Range<Point>,
- mut remaining_tokens: usize,
-) -> Range<Point> {
- let mut expanded_range = range;
- expanded_range.start.column = 0;
- expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
- loop {
- let mut expanded = false;
-
- if remaining_tokens > 0 && expanded_range.start.row > 0 {
- expanded_range.start.row -= 1;
- let line_tokens =
- guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
- remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
- expanded = true;
- }
-
- if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
- expanded_range.end.row += 1;
- expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
- let line_tokens = guess_token_count(expanded_range.end.column as usize);
- remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
- expanded = true;
- }
-
- if !expanded {
- break;
- }
- }
- expanded_range
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use gpui::{App, AppContext};
- use indoc::indoc;
- use language::{Buffer, Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
- use std::sync::Arc;
-
- #[gpui::test]
- fn test_excerpt_for_cursor_position(cx: &mut App) {
- let text = indoc! {r#"
- fn foo() {
- let x = 42;
- println!("Hello, world!");
- }
-
- fn bar() {
- let x = 42;
- let mut sum = 0;
- for i in 0..x {
- sum += i;
- }
- println!("Sum: {}", sum);
- return sum;
- }
-
- fn generate_random_numbers() -> Vec<i32> {
- let mut rng = rand::thread_rng();
- let mut numbers = Vec::new();
- for _ in 0..5 {
- numbers.push(rng.random_range(1..101));
- }
- numbers
- }
- "#};
- let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
- let snapshot = buffer.read(cx).snapshot();
-
- // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
- // when a larger scope doesn't fit the editable region.
- let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
- assert_eq!(
- excerpt.prompt,
- indoc! {r#"
- ```main.rs
- let x = 42;
- println!("Hello, world!");
- <|editable_region_start|>
- }
-
- fn bar() {
- let x = 42;
- let mut sum = 0;
- for i in 0..x {
- sum += i;
- }
- println!("Sum: {}", sum);
- r<|user_cursor_is_here|>eturn sum;
- }
-
- fn generate_random_numbers() -> Vec<i32> {
- <|editable_region_end|>
- let mut rng = rand::thread_rng();
- let mut numbers = Vec::new();
- ```"#}
- );
-
- // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
- let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
- assert_eq!(
- excerpt.prompt,
- indoc! {r#"
- ```main.rs
- fn bar() {
- let x = 42;
- let mut sum = 0;
- <|editable_region_start|>
- for i in 0..x {
- sum += i;
- }
- println!("Sum: {}", sum);
- r<|user_cursor_is_here|>eturn sum;
- }
-
- fn generate_random_numbers() -> Vec<i32> {
- let mut rng = rand::thread_rng();
- <|editable_region_end|>
- let mut numbers = Vec::new();
- for _ in 0..5 {
- numbers.push(rng.random_range(1..101));
- ```"#}
- );
- }
-
- fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::LANGUAGE.into()),
- )
- }
-}
@@ -1,671 +0,0 @@
-use client::test::FakeServer;
-use clock::{FakeSystemClock, ReplicaId};
-use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
-use cloud_llm_client::{PredictEditsBody, PredictEditsResponse};
-use gpui::TestAppContext;
-use http_client::FakeHttpClient;
-use indoc::indoc;
-use language::Point;
-use parking_lot::Mutex;
-use serde_json::json;
-use settings::SettingsStore;
-use util::{path, rel_path::rel_path};
-
-use crate::zeta1::MAX_EVENT_TOKENS;
-
-use super::*;
-
-const BSD_0_TXT: &str = include_str!("../license_examples/0bsd.txt");
-
-#[gpui::test]
-async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
- let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
- let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
- to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
- });
-
- let edit_preview = cx
- .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
- .await;
-
- let completion = EditPrediction {
- edits,
- edit_preview,
- buffer: buffer.clone(),
- snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
- id: EditPredictionId("the-id".into()),
- inputs: EditPredictionInputs {
- events: Default::default(),
- included_files: Default::default(),
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- line: Line(0),
- column: 0,
- },
- cursor_path: Path::new("").into(),
- },
- buffer_snapshotted_at: Instant::now(),
- response_received_at: Instant::now(),
- };
-
- cx.update(|cx| {
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..5, "REM".into()), (9..11, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..2, "REM".into()), (6..8, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.undo(cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(2..5, "REM".into()), (9..11, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(3..3, "EM".into()), (7..9, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".into()), (8..10, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(9..11, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".into()), (8..10, "".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
- assert_eq!(
- from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
- &buffer,
- cx
- ),
- vec![(4..4, "M".into())]
- );
-
- buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
- assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
- })
-}
-
-#[gpui::test]
-async fn test_clean_up_diff(cx: &mut TestAppContext) {
- init_test(cx);
-
- assert_eq!(
- apply_edit_prediction(
- indoc! {"
- fn main() {
- let word_1 = \"lorem\";
- let range = word.len()..word.len();
- }
- "},
- indoc! {"
- <|editable_region_start|>
- fn main() {
- let word_1 = \"lorem\";
- let range = word_1.len()..word_1.len();
- }
-
- <|editable_region_end|>
- "},
- cx,
- )
- .await,
- indoc! {"
- fn main() {
- let word_1 = \"lorem\";
- let range = word_1.len()..word_1.len();
- }
- "},
- );
-
- assert_eq!(
- apply_edit_prediction(
- indoc! {"
- fn main() {
- let story = \"the quick\"
- }
- "},
- indoc! {"
- <|editable_region_start|>
- fn main() {
- let story = \"the quick brown fox jumps over the lazy dog\";
- }
-
- <|editable_region_end|>
- "},
- cx,
- )
- .await,
- indoc! {"
- fn main() {
- let story = \"the quick brown fox jumps over the lazy dog\";
- }
- "},
- );
-}
-
-#[gpui::test]
-async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
- init_test(cx);
-
- let buffer_content = "lorem\n";
- let completion_response = indoc! {"
- ```animals.js
- <|start_of_file|>
- <|editable_region_start|>
- lorem
- ipsum
- <|editable_region_end|>
- ```"};
-
- assert_eq!(
- apply_edit_prediction(buffer_content, completion_response, cx).await,
- "lorem\nipsum"
- );
-}
-
-#[gpui::test]
-async fn test_can_collect_data(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(path!("/project"), json!({ "LICENSE": BSD_0_TXT }))
- .await;
-
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/project/src/main.rs"), cx)
- })
- .await
- .unwrap();
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Disabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_remote_file(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [], cx).await;
-
- let buffer = cx.new(|_cx| {
- Buffer::remote(
- language::BufferId::new(1).unwrap(),
- ReplicaId::new(1),
- language::Capability::ReadWrite,
- "fn main() {\n println!(\"Hello\");\n}",
- )
- });
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_private_file(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/project"),
- json!({
- "LICENSE": BSD_0_TXT,
- ".env": "SECRET_KEY=secret"
- }),
- )
- .await;
-
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer("/project/.env", cx)
- })
- .await
- .unwrap();
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_untitled_buffer(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [], cx).await;
- let buffer = cx.new(|cx| Buffer::local("", cx));
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_when_closed_source(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(path!("/project"), json!({ "main.rs": "fn main() {}" }))
- .await;
-
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer("/project/main.rs", cx)
- })
- .await
- .unwrap();
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_data_collection_status_changes_on_move(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/open_source_worktree"),
- json!({ "LICENSE": BSD_0_TXT, "main.rs": "" }),
- )
- .await;
- fs.insert_tree(path!("/closed_source_worktree"), json!({ "main.rs": "" }))
- .await;
-
- let project = Project::test(
- fs.clone(),
- [
- path!("/open_source_worktree").as_ref(),
- path!("/closed_source_worktree").as_ref(),
- ],
- cx,
- )
- .await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/open_source_worktree/main.rs"), cx)
- })
- .await
- .unwrap();
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-
- let closed_source_file = project
- .update(cx, |project, cx| {
- let worktree2 = project
- .worktree_for_root_name("closed_source_worktree", cx)
- .unwrap();
- worktree2.update(cx, |worktree2, cx| {
- worktree2.load_file(rel_path("main.rs"), cx)
- })
- })
- .await
- .unwrap()
- .file;
-
- buffer.update(cx, |buffer, cx| {
- buffer.file_updated(closed_source_file, cx);
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-}
-
-#[gpui::test]
-async fn test_no_data_collection_for_events_in_uncollectable_buffers(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = project::FakeFs::new(cx.executor());
- fs.insert_tree(
- path!("/worktree1"),
- json!({ "LICENSE": BSD_0_TXT, "main.rs": "", "other.rs": "" }),
- )
- .await;
- fs.insert_tree(path!("/worktree2"), json!({ "private.rs": "" }))
- .await;
-
- let project = Project::test(
- fs.clone(),
- [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()],
- cx,
- )
- .await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/worktree1/main.rs"), cx)
- })
- .await
- .unwrap();
- let private_buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/worktree2/file.rs"), cx)
- })
- .await
- .unwrap();
-
- let (zeta, captured_request, _) = make_test_zeta(&project, cx).await;
- zeta.update(cx, |zeta, _cx| {
- zeta.data_collection_choice = DataCollectionChoice::Enabled
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-
- // this has a side effect of registering the buffer to watch for edits
- run_edit_prediction(&private_buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-
- private_buffer.update(cx, |private_buffer, cx| {
- private_buffer.edit([(0..0, "An edit for the history!")], None, cx);
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- false
- );
-
- // make an edit that uses too many bytes, causing private_buffer edit to not be able to be
- // included
- buffer.update(cx, |buffer, cx| {
- buffer.edit(
- [(
- 0..0,
- " ".repeat(MAX_EVENT_TOKENS * zeta1::BYTES_PER_TOKEN_GUESS),
- )],
- None,
- cx,
- );
- });
-
- run_edit_prediction(&buffer, &project, &zeta, cx).await;
- assert_eq!(
- captured_request.lock().clone().unwrap().can_collect_data,
- true
- );
-}
-
-fn init_test(cx: &mut TestAppContext) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- });
-}
-
-async fn apply_edit_prediction(
- buffer_content: &str,
- completion_response: &str,
- cx: &mut TestAppContext,
-) -> String {
- let fs = project::FakeFs::new(cx.executor());
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
- let buffer = cx.new(|cx| Buffer::local(buffer_content, cx));
- let (zeta, _, response) = make_test_zeta(&project, cx).await;
- *response.lock() = completion_response.to_string();
- let edit_prediction = run_edit_prediction(&buffer, &project, &zeta, cx).await;
- buffer.update(cx, |buffer, cx| {
- buffer.edit(edit_prediction.edits.iter().cloned(), None, cx)
- });
- buffer.read_with(cx, |buffer, _| buffer.text())
-}
-
-async fn run_edit_prediction(
- buffer: &Entity<Buffer>,
- project: &Entity<Project>,
- zeta: &Entity<Zeta>,
- cx: &mut TestAppContext,
-) -> EditPrediction {
- let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0)));
- zeta.update(cx, |zeta, cx| zeta.register_buffer(buffer, &project, cx));
- cx.background_executor.run_until_parked();
- let prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, buffer, cursor, Default::default(), cx)
- });
- prediction_task.await.unwrap().unwrap().prediction.unwrap()
-}
-
-async fn make_test_zeta(
- project: &Entity<Project>,
- cx: &mut TestAppContext,
-) -> (
- Entity<Zeta>,
- Arc<Mutex<Option<PredictEditsBody>>>,
- Arc<Mutex<String>>,
-) {
- let default_response = indoc! {"
- ```main.rs
- <|start_of_file|>
- <|editable_region_start|>
- hello world
- <|editable_region_end|>
- ```"
- };
- let captured_request: Arc<Mutex<Option<PredictEditsBody>>> = Arc::new(Mutex::new(None));
- let completion_response: Arc<Mutex<String>> =
- Arc::new(Mutex::new(default_response.to_string()));
- let http_client = FakeHttpClient::create({
- let captured_request = captured_request.clone();
- let completion_response = completion_response.clone();
- let mut next_request_id = 0;
- move |req| {
- let captured_request = captured_request.clone();
- let completion_response = completion_response.clone();
- async move {
- match (req.method(), req.uri().path()) {
- (&Method::POST, "/client/llm_tokens") => Ok(http_client::Response::builder()
- .status(200)
- .body(
- serde_json::to_string(&CreateLlmTokenResponse {
- token: LlmToken("the-llm-token".to_string()),
- })
- .unwrap()
- .into(),
- )
- .unwrap()),
- (&Method::POST, "/predict_edits/v2") => {
- let mut request_body = String::new();
- req.into_body().read_to_string(&mut request_body).await?;
- *captured_request.lock() =
- Some(serde_json::from_str(&request_body).unwrap());
- next_request_id += 1;
- Ok(http_client::Response::builder()
- .status(200)
- .body(
- serde_json::to_string(&PredictEditsResponse {
- request_id: format!("request-{next_request_id}"),
- output_excerpt: completion_response.lock().clone(),
- })
- .unwrap()
- .into(),
- )
- .unwrap())
- }
- _ => Ok(http_client::Response::builder()
- .status(404)
- .body("Not Found".into())
- .unwrap()),
- }
- }
- }
- });
-
- let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
- cx.update(|cx| {
- RefreshLlmTokenListener::register(client.clone(), cx);
- });
- let _server = FakeServer::for_client(42, &client, cx).await;
-
- let zeta = cx.new(|cx| {
- let mut zeta = Zeta::new(client, project.read(cx).user_store(), cx);
- zeta.set_edit_prediction_model(ZetaEditPredictionModel::Zeta1);
-
- let worktrees = project.read(cx).worktrees(cx).collect::<Vec<_>>();
- for worktree in worktrees {
- let worktree_id = worktree.read(cx).id();
- zeta.get_or_init_zeta_project(project, cx)
- .license_detection_watchers
- .entry(worktree_id)
- .or_insert_with(|| Rc::new(LicenseDetectionWatcher::new(&worktree, cx)));
- }
-
- zeta
- });
-
- (zeta, captured_request, completion_response)
-}
-
-fn to_completion_edits(
- iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
- buffer: &Entity<Buffer>,
- cx: &App,
-) -> Vec<(Range<Anchor>, Arc<str>)> {
- let buffer = buffer.read(cx);
- iterator
- .into_iter()
- .map(|(range, text)| {
- (
- buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
- text,
- )
- })
- .collect()
-}
-
-fn from_completion_edits(
- editor_edits: &[(Range<Anchor>, Arc<str>)],
- buffer: &Entity<Buffer>,
- cx: &App,
-) -> Vec<(Range<usize>, Arc<str>)> {
- let buffer = buffer.read(cx);
- editor_edits
- .iter()
- .map(|(range, text)| {
- (
- range.start.to_offset(buffer)..range.end.to_offset(buffer),
- text.clone(),
- )
- })
- .collect()
-}
-
-#[ctor::ctor]
-fn init_logger() {
- zlog::init_test();
-}
@@ -1,49 +0,0 @@
-[package]
-name = "zeta2_tools"
-version = "0.1.0"
-edition.workspace = true
-publish.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/zeta2_tools.rs"
-
-[dependencies]
-anyhow.workspace = true
-client.workspace = true
-cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
-collections.workspace = true
-edit_prediction_context.workspace = true
-editor.workspace = true
-feature_flags.workspace = true
-futures.workspace = true
-gpui.workspace = true
-language.workspace = true
-multi_buffer.workspace = true
-project.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-telemetry.workspace = true
-text.workspace = true
-ui.workspace = true
-ui_input.workspace = true
-util.workspace = true
-workspace.workspace = true
-zeta.workspace = true
-
-[dev-dependencies]
-clap.workspace = true
-gpui = { workspace = true, features = ["test-support"] }
-indoc.workspace = true
-language = { workspace = true, features = ["test-support"] }
-pretty_assertions.workspace = true
-project = { workspace = true, features = ["test-support"] }
-serde_json.workspace = true
-settings = { workspace = true, features = ["test-support"] }
-text = { workspace = true, features = ["test-support"] }
-util = { workspace = true, features = ["test-support"] }
-zlog.workspace = true
@@ -1,438 +0,0 @@
-use std::{
- any::TypeId,
- collections::VecDeque,
- ops::Add,
- sync::Arc,
- time::{Duration, Instant},
-};
-
-use anyhow::Result;
-use client::{Client, UserStore};
-use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
-use editor::{Editor, PathKey};
-use futures::StreamExt as _;
-use gpui::{
- Animation, AnimationExt, App, AppContext as _, Context, Entity, EventEmitter, FocusHandle,
- Focusable, ParentElement as _, SharedString, Styled as _, Task, TextAlign, Window, actions,
- pulsating_between,
-};
-use multi_buffer::MultiBuffer;
-use project::Project;
-use text::OffsetRangeExt;
-use ui::{
- ButtonCommon, Clickable, Color, Disableable, FluentBuilder as _, Icon, IconButton, IconName,
- IconSize, InteractiveElement, IntoElement, ListHeader, ListItem, StyledTypography, div, h_flex,
- v_flex,
-};
-use workspace::Item;
-use zeta::{
- Zeta, ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
- ZetaSearchQueryDebugInfo,
-};
-
-pub struct Zeta2ContextView {
- empty_focus_handle: FocusHandle,
- project: Entity<Project>,
- zeta: Entity<Zeta>,
- runs: VecDeque<RetrievalRun>,
- current_ix: usize,
- _update_task: Task<Result<()>>,
-}
-
-#[derive(Debug)]
-struct RetrievalRun {
- editor: Entity<Editor>,
- search_queries: Vec<SearchToolQuery>,
- started_at: Instant,
- search_results_generated_at: Option<Instant>,
- search_results_executed_at: Option<Instant>,
- finished_at: Option<Instant>,
-}
-
-actions!(
- dev,
- [
- /// Go to the previous context retrieval run
- Zeta2ContextGoBack,
- /// Go to the next context retrieval run
- Zeta2ContextGoForward
- ]
-);
-
-impl Zeta2ContextView {
- pub fn new(
- project: Entity<Project>,
- client: &Arc<Client>,
- user_store: &Entity<UserStore>,
- window: &mut gpui::Window,
- cx: &mut Context<Self>,
- ) -> Self {
- let zeta = Zeta::global(client, user_store, cx);
-
- let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info());
- let _update_task = cx.spawn_in(window, async move |this, cx| {
- while let Some(event) = debug_rx.next().await {
- this.update_in(cx, |this, window, cx| {
- this.handle_zeta_event(event, window, cx)
- })?;
- }
- Ok(())
- });
-
- Self {
- empty_focus_handle: cx.focus_handle(),
- project,
- runs: VecDeque::new(),
- current_ix: 0,
- zeta,
- _update_task,
- }
- }
-
- fn handle_zeta_event(
- &mut self,
- event: ZetaDebugInfo,
- window: &mut gpui::Window,
- cx: &mut Context<Self>,
- ) {
- match event {
- ZetaDebugInfo::ContextRetrievalStarted(info) => {
- if info.project == self.project {
- self.handle_context_retrieval_started(info, window, cx);
- }
- }
- ZetaDebugInfo::SearchQueriesGenerated(info) => {
- if info.project == self.project {
- self.handle_search_queries_generated(info, window, cx);
- }
- }
- ZetaDebugInfo::SearchQueriesExecuted(info) => {
- if info.project == self.project {
- self.handle_search_queries_executed(info, window, cx);
- }
- }
- ZetaDebugInfo::ContextRetrievalFinished(info) => {
- if info.project == self.project {
- self.handle_context_retrieval_finished(info, window, cx);
- }
- }
- ZetaDebugInfo::EditPredictionRequested(_) => {}
- }
- }
-
- fn handle_context_retrieval_started(
- &mut self,
- info: ZetaContextRetrievalStartedDebugInfo,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- if self
- .runs
- .back()
- .is_some_and(|run| run.search_results_executed_at.is_none())
- {
- self.runs.pop_back();
- }
-
- let multibuffer = cx.new(|_| MultiBuffer::new(language::Capability::ReadOnly));
- let editor = cx
- .new(|cx| Editor::for_multibuffer(multibuffer, Some(self.project.clone()), window, cx));
-
- if self.runs.len() == 32 {
- self.runs.pop_front();
- }
-
- self.runs.push_back(RetrievalRun {
- editor,
- search_queries: Vec::new(),
- started_at: info.timestamp,
- search_results_generated_at: None,
- search_results_executed_at: None,
- finished_at: None,
- });
-
- cx.notify();
- }
-
- fn handle_context_retrieval_finished(
- &mut self,
- info: ZetaContextRetrievalDebugInfo,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- let Some(run) = self.runs.back_mut() else {
- return;
- };
-
- run.finished_at = Some(info.timestamp);
-
- let multibuffer = run.editor.read(cx).buffer().clone();
- multibuffer.update(cx, |multibuffer, cx| {
- multibuffer.clear(cx);
-
- let context = self.zeta.read(cx).context_for_project(&self.project);
- let mut paths = Vec::new();
- for (buffer, ranges) in context {
- let path = PathKey::for_buffer(&buffer, cx);
- let snapshot = buffer.read(cx).snapshot();
- let ranges = ranges
- .iter()
- .map(|range| range.to_point(&snapshot))
- .collect::<Vec<_>>();
- paths.push((path, buffer, ranges));
- }
-
- for (path, buffer, ranges) in paths {
- multibuffer.set_excerpts_for_path(path, buffer, ranges, 0, cx);
- }
- });
-
- run.editor.update(cx, |editor, cx| {
- editor.move_to_beginning(&Default::default(), window, cx);
- });
-
- cx.notify();
- }
-
- fn handle_search_queries_generated(
- &mut self,
- info: ZetaSearchQueryDebugInfo,
- _window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- let Some(run) = self.runs.back_mut() else {
- return;
- };
-
- run.search_results_generated_at = Some(info.timestamp);
- run.search_queries = info.search_queries;
- cx.notify();
- }
-
- fn handle_search_queries_executed(
- &mut self,
- info: ZetaContextRetrievalDebugInfo,
- _window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- if self.current_ix + 2 == self.runs.len() {
- // Switch to latest when the queries are executed
- self.current_ix += 1;
- }
-
- let Some(run) = self.runs.back_mut() else {
- return;
- };
-
- run.search_results_executed_at = Some(info.timestamp);
- cx.notify();
- }
-
- fn handle_go_back(
- &mut self,
- _: &Zeta2ContextGoBack,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self.current_ix = self.current_ix.saturating_sub(1);
- cx.focus_self(window);
- cx.notify();
- }
-
- fn handle_go_forward(
- &mut self,
- _: &Zeta2ContextGoForward,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self.current_ix = self
- .current_ix
- .add(1)
- .min(self.runs.len().saturating_sub(1));
- cx.focus_self(window);
- cx.notify();
- }
-
- fn render_informational_footer(&self, cx: &mut Context<'_, Zeta2ContextView>) -> ui::Div {
- let is_latest = self.runs.len() == self.current_ix + 1;
- let run = &self.runs[self.current_ix];
-
- h_flex()
- .p_2()
- .w_full()
- .font_buffer(cx)
- .text_xs()
- .border_t_1()
- .gap_2()
- .child(
- v_flex().h_full().flex_1().children(
- run.search_queries
- .iter()
- .enumerate()
- .flat_map(|(ix, query)| {
- std::iter::once(ListHeader::new(query.glob.clone()).into_any_element())
- .chain(query.syntax_node.iter().enumerate().map(
- move |(regex_ix, regex)| {
- ListItem::new(ix * 100 + regex_ix)
- .start_slot(
- Icon::new(IconName::MagnifyingGlass)
- .color(Color::Muted)
- .size(IconSize::Small),
- )
- .child(regex.clone())
- .into_any_element()
- },
- ))
- .chain(query.content.as_ref().map(move |regex| {
- ListItem::new(ix * 100 + query.syntax_node.len())
- .start_slot(
- Icon::new(IconName::MagnifyingGlass)
- .color(Color::Muted)
- .size(IconSize::Small),
- )
- .child(regex.clone())
- .into_any_element()
- }))
- }),
- ),
- )
- .child(
- v_flex()
- .h_full()
- .text_align(TextAlign::Right)
- .child(
- h_flex()
- .justify_end()
- .child(
- IconButton::new("go-back", IconName::ChevronLeft)
- .disabled(self.current_ix == 0 || self.runs.len() < 2)
- .tooltip(ui::Tooltip::for_action_title(
- "Go to previous run",
- &Zeta2ContextGoBack,
- ))
- .on_click(cx.listener(|this, _, window, cx| {
- this.handle_go_back(&Zeta2ContextGoBack, window, cx);
- })),
- )
- .child(
- div()
- .child(format!("{}/{}", self.current_ix + 1, self.runs.len()))
- .map(|this| {
- if self.runs.back().is_some_and(|back| {
- back.search_results_executed_at.is_none()
- }) {
- this.with_animation(
- "pulsating-count",
- Animation::new(Duration::from_secs(2))
- .repeat()
- .with_easing(pulsating_between(0.4, 0.8)),
- |label, delta| label.opacity(delta),
- )
- .into_any_element()
- } else {
- this.into_any_element()
- }
- }),
- )
- .child(
- IconButton::new("go-forward", IconName::ChevronRight)
- .disabled(self.current_ix + 1 == self.runs.len())
- .tooltip(ui::Tooltip::for_action_title(
- "Go to next run",
- &Zeta2ContextGoBack,
- ))
- .on_click(cx.listener(|this, _, window, cx| {
- this.handle_go_forward(&Zeta2ContextGoForward, window, cx);
- })),
- ),
- )
- .map(|mut div| {
- let pending_message = |div: ui::Div, msg: &'static str| {
- if is_latest {
- return div.child(msg);
- } else {
- return div.child("Canceled");
- }
- };
-
- let t0 = run.started_at;
- let Some(t1) = run.search_results_generated_at else {
- return pending_message(div, "Planning search...");
- };
- div = div.child(format!("Planned search: {:>5} ms", (t1 - t0).as_millis()));
-
- let Some(t2) = run.search_results_executed_at else {
- return pending_message(div, "Running search...");
- };
- div = div.child(format!("Ran search: {:>5} ms", (t2 - t1).as_millis()));
-
- div.child(format!(
- "Total: {:>5} ms",
- (run.finished_at.unwrap_or(t0) - t0).as_millis()
- ))
- }),
- )
- }
-}
-
-impl Focusable for Zeta2ContextView {
- fn focus_handle(&self, cx: &App) -> FocusHandle {
- self.runs
- .get(self.current_ix)
- .map(|run| run.editor.read(cx).focus_handle(cx))
- .unwrap_or_else(|| self.empty_focus_handle.clone())
- }
-}
-
-impl EventEmitter<()> for Zeta2ContextView {}
-
-impl Item for Zeta2ContextView {
- type Event = ();
-
- fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
- "Edit Prediction Context".into()
- }
-
- fn buffer_kind(&self, _cx: &App) -> workspace::item::ItemBufferKind {
- workspace::item::ItemBufferKind::Multibuffer
- }
-
- fn act_as_type<'a>(
- &'a self,
- type_id: TypeId,
- self_handle: &'a Entity<Self>,
- _: &'a App,
- ) -> Option<gpui::AnyEntity> {
- if type_id == TypeId::of::<Self>() {
- Some(self_handle.clone().into())
- } else if type_id == TypeId::of::<Editor>() {
- Some(self.runs.get(self.current_ix)?.editor.clone().into())
- } else {
- None
- }
- }
-}
-
-impl gpui::Render for Zeta2ContextView {
- fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl ui::IntoElement {
- v_flex()
- .key_context("Zeta2Context")
- .on_action(cx.listener(Self::handle_go_back))
- .on_action(cx.listener(Self::handle_go_forward))
- .size_full()
- .map(|this| {
- if self.runs.is_empty() {
- this.child(
- v_flex()
- .size_full()
- .justify_center()
- .items_center()
- .child("No retrieval runs yet"),
- )
- } else {
- this.child(self.runs[self.current_ix].editor.clone())
- .child(self.render_informational_footer(cx))
- }
- })
- }
-}
@@ -1,1023 +0,0 @@
-mod zeta2_context_view;
-
-use std::{str::FromStr, sync::Arc, time::Duration};
-
-use client::{Client, UserStore};
-use cloud_llm_client::predict_edits_v3::PromptFormat;
-use collections::HashMap;
-use editor::{Editor, EditorEvent, EditorMode, MultiBuffer};
-use feature_flags::FeatureFlagAppExt as _;
-use futures::{FutureExt, StreamExt as _, channel::oneshot, future::Shared};
-use gpui::{
- Empty, Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
- prelude::*,
-};
-use language::Buffer;
-use project::{Project, telemetry_snapshot::TelemetrySnapshot};
-use ui::{ButtonLike, ContextMenu, ContextMenuEntry, DropdownMenu, KeyBinding, prelude::*};
-use ui_input::InputField;
-use util::ResultExt;
-use workspace::{Item, SplitDirection, Workspace};
-use zeta::{
- AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, EditPredictionInputs, Zeta,
- Zeta2FeatureFlag, ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions,
-};
-
-use edit_prediction_context::{EditPredictionContextOptions, EditPredictionExcerptOptions};
-use zeta2_context_view::Zeta2ContextView;
-
-actions!(
- dev,
- [
- /// Opens the edit prediction context view.
- OpenZeta2ContextView,
- /// Opens the edit prediction inspector.
- OpenZeta2Inspector,
- /// Rate prediction as positive.
- Zeta2RatePredictionPositive,
- /// Rate prediction as negative.
- Zeta2RatePredictionNegative,
- ]
-);
-
-pub fn init(cx: &mut App) {
- cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
- workspace.register_action_renderer(|div, _, _, cx| {
- let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
- div.when(has_flag, |div| {
- div.on_action(
- cx.listener(move |workspace, _: &OpenZeta2Inspector, window, cx| {
- let project = workspace.project();
- workspace.split_item(
- SplitDirection::Right,
- Box::new(cx.new(|cx| {
- Zeta2Inspector::new(
- &project,
- workspace.client(),
- workspace.user_store(),
- window,
- cx,
- )
- })),
- window,
- cx,
- )
- }),
- )
- .on_action(cx.listener(
- move |workspace, _: &OpenZeta2ContextView, window, cx| {
- let project = workspace.project();
- workspace.split_item(
- SplitDirection::Right,
- Box::new(cx.new(|cx| {
- Zeta2ContextView::new(
- project.clone(),
- workspace.client(),
- workspace.user_store(),
- window,
- cx,
- )
- })),
- window,
- cx,
- );
- },
- ))
- })
- });
- })
- .detach();
-}
-
-// TODO show included diagnostics, and events
-
-pub struct Zeta2Inspector {
- focus_handle: FocusHandle,
- project: Entity<Project>,
- last_prediction: Option<LastPrediction>,
- max_excerpt_bytes_input: Entity<InputField>,
- min_excerpt_bytes_input: Entity<InputField>,
- cursor_context_ratio_input: Entity<InputField>,
- max_prompt_bytes_input: Entity<InputField>,
- context_mode: ContextModeState,
- zeta: Entity<Zeta>,
- _active_editor_subscription: Option<Subscription>,
- _update_state_task: Task<()>,
- _receive_task: Task<()>,
-}
-
-pub enum ContextModeState {
- Llm,
- Syntax {
- max_retrieved_declarations: Entity<InputField>,
- },
-}
-
-struct LastPrediction {
- prompt_editor: Entity<Editor>,
- retrieval_time: Duration,
- request_time: Option<Duration>,
- buffer: WeakEntity<Buffer>,
- position: language::Anchor,
- state: LastPredictionState,
- inputs: EditPredictionInputs,
- project_snapshot: Shared<Task<Arc<TelemetrySnapshot>>>,
- _task: Option<Task<()>>,
-}
-
-#[derive(Clone, Copy, PartialEq)]
-enum Feedback {
- Positive,
- Negative,
-}
-
-enum LastPredictionState {
- Requested,
- Success {
- model_response_editor: Entity<Editor>,
- feedback_editor: Entity<Editor>,
- feedback: Option<Feedback>,
- request_id: String,
- },
- Failed {
- message: String,
- },
-}
-
-impl Zeta2Inspector {
- pub fn new(
- project: &Entity<Project>,
- client: &Arc<Client>,
- user_store: &Entity<UserStore>,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) -> Self {
- let zeta = Zeta::global(client, user_store, cx);
- let mut request_rx = zeta.update(cx, |zeta, _cx| zeta.debug_info());
-
- let receive_task = cx.spawn_in(window, async move |this, cx| {
- while let Some(prediction) = request_rx.next().await {
- this.update_in(cx, |this, window, cx| {
- this.update_last_prediction(prediction, window, cx)
- })
- .ok();
- }
- });
-
- let mut this = Self {
- focus_handle: cx.focus_handle(),
- project: project.clone(),
- last_prediction: None,
- max_excerpt_bytes_input: Self::number_input("Max Excerpt Bytes", window, cx),
- min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx),
- cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx),
- max_prompt_bytes_input: Self::number_input("Max Prompt Bytes", window, cx),
- context_mode: ContextModeState::Llm,
- zeta: zeta.clone(),
- _active_editor_subscription: None,
- _update_state_task: Task::ready(()),
- _receive_task: receive_task,
- };
- this.set_options_state(&zeta.read(cx).options().clone(), window, cx);
- this
- }
-
- fn set_options_state(
- &mut self,
- options: &ZetaOptions,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- let excerpt_options = options.context.excerpt();
- self.max_excerpt_bytes_input.update(cx, |input, cx| {
- input.set_text(excerpt_options.max_bytes.to_string(), window, cx);
- });
- self.min_excerpt_bytes_input.update(cx, |input, cx| {
- input.set_text(excerpt_options.min_bytes.to_string(), window, cx);
- });
- self.cursor_context_ratio_input.update(cx, |input, cx| {
- input.set_text(
- format!(
- "{:.2}",
- excerpt_options.target_before_cursor_over_total_bytes
- ),
- window,
- cx,
- );
- });
- self.max_prompt_bytes_input.update(cx, |input, cx| {
- input.set_text(options.max_prompt_bytes.to_string(), window, cx);
- });
-
- match &options.context {
- ContextMode::Agentic(_) => {
- self.context_mode = ContextModeState::Llm;
- }
- ContextMode::Syntax(_) => {
- self.context_mode = ContextModeState::Syntax {
- max_retrieved_declarations: Self::number_input(
- "Max Retrieved Definitions",
- window,
- cx,
- ),
- };
- }
- }
- cx.notify();
- }
-
- fn set_zeta_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
- self.zeta.update(cx, |this, _cx| this.set_options(options));
-
- if let Some(prediction) = self.last_prediction.as_mut() {
- if let Some(buffer) = prediction.buffer.upgrade() {
- let position = prediction.position;
- let project = self.project.clone();
- self.zeta.update(cx, |zeta, cx| {
- zeta.refresh_prediction_from_buffer(project, buffer, position, cx)
- });
- prediction.state = LastPredictionState::Requested;
- } else {
- self.last_prediction.take();
- }
- }
-
- cx.notify();
- }
-
- fn number_input(
- label: &'static str,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) -> Entity<InputField> {
- let input = cx.new(|cx| {
- InputField::new(window, cx, "")
- .label(label)
- .label_min_width(px(64.))
- });
-
- cx.subscribe_in(
- &input.read(cx).editor().clone(),
- window,
- |this, _, event, _window, cx| {
- let EditorEvent::BufferEdited = event else {
- return;
- };
-
- fn number_input_value<T: FromStr + Default>(
- input: &Entity<InputField>,
- cx: &App,
- ) -> T {
- input
- .read(cx)
- .editor()
- .read(cx)
- .text(cx)
- .parse::<T>()
- .unwrap_or_default()
- }
-
- let zeta_options = this.zeta.read(cx).options().clone();
-
- let excerpt_options = EditPredictionExcerptOptions {
- max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx),
- min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx),
- target_before_cursor_over_total_bytes: number_input_value(
- &this.cursor_context_ratio_input,
- cx,
- ),
- };
-
- let context = match zeta_options.context {
- ContextMode::Agentic(_context_options) => {
- ContextMode::Agentic(AgenticContextOptions {
- excerpt: excerpt_options,
- })
- }
- ContextMode::Syntax(context_options) => {
- let max_retrieved_declarations = match &this.context_mode {
- ContextModeState::Llm => {
- zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations
- }
- ContextModeState::Syntax {
- max_retrieved_declarations,
- } => number_input_value(max_retrieved_declarations, cx),
- };
-
- ContextMode::Syntax(EditPredictionContextOptions {
- excerpt: excerpt_options,
- max_retrieved_declarations,
- ..context_options
- })
- }
- };
-
- this.set_zeta_options(
- ZetaOptions {
- context,
- max_prompt_bytes: number_input_value(&this.max_prompt_bytes_input, cx),
- max_diagnostic_bytes: zeta_options.max_diagnostic_bytes,
- prompt_format: zeta_options.prompt_format,
- file_indexing_parallelism: zeta_options.file_indexing_parallelism,
- buffer_change_grouping_interval: zeta_options
- .buffer_change_grouping_interval,
- },
- cx,
- );
- },
- )
- .detach();
- input
- }
-
- fn update_last_prediction(
- &mut self,
- prediction: zeta::ZetaDebugInfo,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self._update_state_task = cx.spawn_in(window, {
- let language_registry = self.project.read(cx).languages().clone();
- async move |this, cx| {
- let mut languages = HashMap::default();
- let ZetaDebugInfo::EditPredictionRequested(prediction) = prediction else {
- return;
- };
- for ext in prediction
- .inputs
- .included_files
- .iter()
- .filter_map(|file| file.path.extension())
- {
- if !languages.contains_key(ext) {
- // Most snippets are gonna be the same language,
- // so we think it's fine to do this sequentially for now
- languages.insert(
- ext.to_owned(),
- language_registry
- .language_for_name_or_extension(&ext.to_string_lossy())
- .await
- .ok(),
- );
- }
- }
-
- let markdown_language = language_registry
- .language_for_name("Markdown")
- .await
- .log_err();
-
- let json_language = language_registry.language_for_name("Json").await.log_err();
-
- this.update_in(cx, |this, window, cx| {
- let ZetaEditPredictionDebugInfo {
- response_rx,
- position,
- buffer,
- retrieval_time,
- local_prompt,
- ..
- } = prediction;
-
- let task = cx.spawn_in(window, {
- let markdown_language = markdown_language.clone();
- let json_language = json_language.clone();
- async move |this, cx| {
- let response = response_rx.await;
-
- this.update_in(cx, |this, window, cx| {
- if let Some(prediction) = this.last_prediction.as_mut() {
- prediction.state = match response {
- Ok((Ok(response), request_time)) => {
- prediction.request_time = Some(request_time);
-
- let feedback_editor = cx.new(|cx| {
- let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local("", cx);
- buffer.set_language(
- markdown_language.clone(),
- cx,
- );
- buffer
- });
- let buffer =
- cx.new(|cx| MultiBuffer::singleton(buffer, cx));
- let mut editor = Editor::new(
- EditorMode::AutoHeight {
- min_lines: 3,
- max_lines: None,
- },
- buffer,
- None,
- window,
- cx,
- );
- editor.set_placeholder_text(
- "Write feedback here",
- window,
- cx,
- );
- editor.set_show_line_numbers(false, cx);
- editor.set_show_gutter(false, cx);
- editor.set_show_scrollbars(false, cx);
- editor
- });
-
- cx.subscribe_in(
- &feedback_editor,
- window,
- |this, editor, ev, window, cx| match ev {
- EditorEvent::BufferEdited => {
- if let Some(last_prediction) =
- this.last_prediction.as_mut()
- && let LastPredictionState::Success {
- feedback: feedback_state,
- ..
- } = &mut last_prediction.state
- {
- if feedback_state.take().is_some() {
- editor.update(cx, |editor, cx| {
- editor.set_placeholder_text(
- "Write feedback here",
- window,
- cx,
- );
- });
- cx.notify();
- }
- }
- }
- _ => {}
- },
- )
- .detach();
-
- LastPredictionState::Success {
- model_response_editor: cx.new(|cx| {
- let buffer = cx.new(|cx| {
- let mut buffer = Buffer::local(
- serde_json::to_string_pretty(&response)
- .unwrap_or_default(),
- cx,
- );
- buffer.set_language(json_language, cx);
- buffer
- });
- let buffer = cx.new(|cx| {
- MultiBuffer::singleton(buffer, cx)
- });
- let mut editor = Editor::new(
- EditorMode::full(),
- buffer,
- None,
- window,
- cx,
- );
- editor.set_read_only(true);
- editor.set_show_line_numbers(false, cx);
- editor.set_show_gutter(false, cx);
- editor.set_show_scrollbars(false, cx);
- editor
- }),
- feedback_editor,
- feedback: None,
- request_id: response.id.clone(),
- }
- }
- Ok((Err(err), request_time)) => {
- prediction.request_time = Some(request_time);
- LastPredictionState::Failed { message: err }
- }
- Err(oneshot::Canceled) => LastPredictionState::Failed {
- message: "Canceled".to_string(),
- },
- };
- }
- })
- .ok();
- }
- });
-
- let project_snapshot_task = TelemetrySnapshot::new(&this.project, cx);
-
- this.last_prediction = Some(LastPrediction {
- prompt_editor: cx.new(|cx| {
- let buffer = cx.new(|cx| {
- let mut buffer =
- Buffer::local(local_prompt.unwrap_or_else(|err| err), cx);
- buffer.set_language(markdown_language.clone(), cx);
- buffer
- });
- let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
- let mut editor =
- Editor::new(EditorMode::full(), buffer, None, window, cx);
- editor.set_read_only(true);
- editor.set_show_line_numbers(false, cx);
- editor.set_show_gutter(false, cx);
- editor.set_show_scrollbars(false, cx);
- editor
- }),
- retrieval_time,
- request_time: None,
- buffer,
- position,
- state: LastPredictionState::Requested,
- project_snapshot: cx
- .foreground_executor()
- .spawn(async move { Arc::new(project_snapshot_task.await) })
- .shared(),
- inputs: prediction.inputs,
- _task: Some(task),
- });
- cx.notify();
- })
- .ok();
- }
- });
- }
-
- fn handle_rate_positive(
- &mut self,
- _action: &Zeta2RatePredictionPositive,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self.handle_rate(Feedback::Positive, window, cx);
- }
-
- fn handle_rate_negative(
- &mut self,
- _action: &Zeta2RatePredictionNegative,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self.handle_rate(Feedback::Negative, window, cx);
- }
-
- fn handle_rate(&mut self, kind: Feedback, window: &mut Window, cx: &mut Context<Self>) {
- let Some(last_prediction) = self.last_prediction.as_mut() else {
- return;
- };
-
- let project_snapshot_task = last_prediction.project_snapshot.clone();
-
- cx.spawn_in(window, async move |this, cx| {
- let project_snapshot = project_snapshot_task.await;
- this.update_in(cx, |this, window, cx| {
- let Some(last_prediction) = this.last_prediction.as_mut() else {
- return;
- };
-
- let LastPredictionState::Success {
- feedback: feedback_state,
- feedback_editor,
- model_response_editor,
- request_id,
- ..
- } = &mut last_prediction.state
- else {
- return;
- };
-
- *feedback_state = Some(kind);
- let text = feedback_editor.update(cx, |feedback_editor, cx| {
- feedback_editor.set_placeholder_text(
- "Submitted. Edit or submit again to change.",
- window,
- cx,
- );
- feedback_editor.text(cx)
- });
- cx.notify();
-
- cx.defer_in(window, {
- let model_response_editor = model_response_editor.downgrade();
- move |_, window, cx| {
- if let Some(model_response_editor) = model_response_editor.upgrade() {
- model_response_editor.focus_handle(cx).focus(window);
- }
- }
- });
-
- let kind = match kind {
- Feedback::Positive => "positive",
- Feedback::Negative => "negative",
- };
-
- telemetry::event!(
- "Zeta2 Prediction Rated",
- id = request_id,
- kind = kind,
- text = text,
- request = last_prediction.inputs,
- project_snapshot = project_snapshot,
- );
- })
- .log_err();
- })
- .detach();
- }
-
- fn render_options(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
- v_flex()
- .gap_2()
- .child(
- h_flex()
- .child(Headline::new("Options").size(HeadlineSize::Small))
- .justify_between()
- .child(
- ui::Button::new("reset-options", "Reset")
- .disabled(self.zeta.read(cx).options() == &zeta::DEFAULT_OPTIONS)
- .style(ButtonStyle::Outlined)
- .size(ButtonSize::Large)
- .on_click(cx.listener(|this, _, window, cx| {
- this.set_options_state(&zeta::DEFAULT_OPTIONS, window, cx);
- })),
- ),
- )
- .child(
- v_flex()
- .gap_2()
- .child(
- h_flex()
- .gap_2()
- .items_end()
- .child(self.max_excerpt_bytes_input.clone())
- .child(self.min_excerpt_bytes_input.clone())
- .child(self.cursor_context_ratio_input.clone())
- .child(self.render_context_mode_dropdown(window, cx)),
- )
- .child(
- h_flex()
- .gap_2()
- .items_end()
- .children(match &self.context_mode {
- ContextModeState::Llm => None,
- ContextModeState::Syntax {
- max_retrieved_declarations,
- } => Some(max_retrieved_declarations.clone()),
- })
- .child(self.max_prompt_bytes_input.clone())
- .child(self.render_prompt_format_dropdown(window, cx)),
- ),
- )
- }
-
- fn render_context_mode_dropdown(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
- let this = cx.weak_entity();
-
- v_flex()
- .gap_1p5()
- .child(
- Label::new("Context Mode")
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- .child(
- DropdownMenu::new(
- "ep-ctx-mode",
- match &self.context_mode {
- ContextModeState::Llm => "LLM-based",
- ContextModeState::Syntax { .. } => "Syntax",
- },
- ContextMenu::build(window, cx, move |menu, _window, _cx| {
- menu.item(
- ContextMenuEntry::new("LLM-based")
- .toggleable(
- IconPosition::End,
- matches!(self.context_mode, ContextModeState::Llm),
- )
- .handler({
- let this = this.clone();
- move |window, cx| {
- this.update(cx, |this, cx| {
- let current_options =
- this.zeta.read(cx).options().clone();
- match current_options.context.clone() {
- ContextMode::Agentic(_) => {}
- ContextMode::Syntax(context_options) => {
- let options = ZetaOptions {
- context: ContextMode::Agentic(
- AgenticContextOptions {
- excerpt: context_options.excerpt,
- },
- ),
- ..current_options
- };
- this.set_options_state(&options, window, cx);
- this.set_zeta_options(options, cx);
- }
- }
- })
- .ok();
- }
- }),
- )
- .item(
- ContextMenuEntry::new("Syntax")
- .toggleable(
- IconPosition::End,
- matches!(self.context_mode, ContextModeState::Syntax { .. }),
- )
- .handler({
- move |window, cx| {
- this.update(cx, |this, cx| {
- let current_options =
- this.zeta.read(cx).options().clone();
- match current_options.context.clone() {
- ContextMode::Agentic(context_options) => {
- let options = ZetaOptions {
- context: ContextMode::Syntax(
- EditPredictionContextOptions {
- excerpt: context_options.excerpt,
- ..DEFAULT_SYNTAX_CONTEXT_OPTIONS
- },
- ),
- ..current_options
- };
- this.set_options_state(&options, window, cx);
- this.set_zeta_options(options, cx);
- }
- ContextMode::Syntax(_) => {}
- }
- })
- .ok();
- }
- }),
- )
- }),
- )
- .style(ui::DropdownStyle::Outlined),
- )
- }
-
- fn render_prompt_format_dropdown(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
- let active_format = self.zeta.read(cx).options().prompt_format;
- let this = cx.weak_entity();
-
- v_flex()
- .gap_1p5()
- .child(
- Label::new("Prompt Format")
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- .child(
- DropdownMenu::new(
- "ep-prompt-format",
- active_format.to_string(),
- ContextMenu::build(window, cx, move |mut menu, _window, _cx| {
- for prompt_format in PromptFormat::iter() {
- menu = menu.item(
- ContextMenuEntry::new(prompt_format.to_string())
- .toggleable(IconPosition::End, active_format == prompt_format)
- .handler({
- let this = this.clone();
- move |_window, cx| {
- this.update(cx, |this, cx| {
- let current_options =
- this.zeta.read(cx).options().clone();
- let options = ZetaOptions {
- prompt_format,
- ..current_options
- };
- this.set_zeta_options(options, cx);
- })
- .ok();
- }
- }),
- )
- }
- menu
- }),
- )
- .style(ui::DropdownStyle::Outlined),
- )
- }
-
- fn render_stats(&self) -> Option<Div> {
- let Some(prediction) = self.last_prediction.as_ref() else {
- return None;
- };
-
- Some(
- v_flex()
- .p_4()
- .gap_2()
- .min_w(px(160.))
- .child(Headline::new("Stats").size(HeadlineSize::Small))
- .child(Self::render_duration(
- "Context retrieval",
- Some(prediction.retrieval_time),
- ))
- .child(Self::render_duration("Request", prediction.request_time)),
- )
- }
-
- fn render_duration(name: &'static str, time: Option<Duration>) -> Div {
- h_flex()
- .gap_1()
- .child(Label::new(name).color(Color::Muted).size(LabelSize::Small))
- .child(match time {
- Some(time) => Label::new(if time.as_micros() >= 1000 {
- format!("{} ms", time.as_millis())
- } else {
- format!("{} ยตs", time.as_micros())
- })
- .size(LabelSize::Small),
- None => Label::new("...").size(LabelSize::Small),
- })
- }
-
- fn render_content(&self, _: &mut Window, cx: &mut Context<Self>) -> AnyElement {
- if !cx.has_flag::<Zeta2FeatureFlag>() {
- return Self::render_message("`zeta2` feature flag is not enabled");
- }
-
- match self.last_prediction.as_ref() {
- None => Self::render_message("No prediction"),
- Some(prediction) => self.render_last_prediction(prediction, cx).into_any(),
- }
- }
-
- fn render_message(message: impl Into<SharedString>) -> AnyElement {
- v_flex()
- .size_full()
- .justify_center()
- .items_center()
- .child(Label::new(message).size(LabelSize::Large))
- .into_any()
- }
-
- fn render_last_prediction(&self, prediction: &LastPrediction, cx: &mut Context<Self>) -> Div {
- h_flex()
- .items_start()
- .w_full()
- .flex_1()
- .border_t_1()
- .border_color(cx.theme().colors().border)
- .bg(cx.theme().colors().editor_background)
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .p_4()
- .h_full()
- .child(
- h_flex()
- .justify_between()
- .child(ui::Headline::new("Prompt").size(ui::HeadlineSize::XSmall))
- .child(match prediction.state {
- LastPredictionState::Requested
- | LastPredictionState::Failed { .. } => ui::Chip::new("Local")
- .bg_color(cx.theme().status().warning_background)
- .label_color(Color::Success),
- LastPredictionState::Success { .. } => ui::Chip::new("Cloud")
- .bg_color(cx.theme().status().success_background)
- .label_color(Color::Success),
- }),
- )
- .child(prediction.prompt_editor.clone()),
- )
- .child(ui::vertical_divider())
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .h_full()
- .child(
- v_flex()
- .flex_1()
- .gap_2()
- .p_4()
- .child(
- ui::Headline::new("Model Response").size(ui::HeadlineSize::XSmall),
- )
- .child(match &prediction.state {
- LastPredictionState::Success {
- model_response_editor,
- ..
- } => model_response_editor.clone().into_any_element(),
- LastPredictionState::Requested => v_flex()
- .gap_2()
- .child(Label::new("Loading...").buffer_font(cx))
- .into_any_element(),
- LastPredictionState::Failed { message } => v_flex()
- .gap_2()
- .max_w_96()
- .child(Label::new(message.clone()).buffer_font(cx))
- .into_any_element(),
- }),
- )
- .child(ui::divider())
- .child(
- if let LastPredictionState::Success {
- feedback_editor,
- feedback: feedback_state,
- ..
- } = &prediction.state
- {
- v_flex()
- .key_context("Zeta2Feedback")
- .on_action(cx.listener(Self::handle_rate_positive))
- .on_action(cx.listener(Self::handle_rate_negative))
- .gap_2()
- .p_2()
- .child(feedback_editor.clone())
- .child(
- h_flex()
- .justify_end()
- .w_full()
- .child(
- ButtonLike::new("rate-positive")
- .when(
- *feedback_state == Some(Feedback::Positive),
- |this| this.style(ButtonStyle::Filled),
- )
- .child(
- KeyBinding::for_action(
- &Zeta2RatePredictionPositive,
- cx,
- )
- .size(TextSize::Small.rems(cx)),
- )
- .child(ui::Icon::new(ui::IconName::ThumbsUp))
- .on_click(cx.listener(|this, _, window, cx| {
- this.handle_rate_positive(
- &Zeta2RatePredictionPositive,
- window,
- cx,
- );
- })),
- )
- .child(
- ButtonLike::new("rate-negative")
- .when(
- *feedback_state == Some(Feedback::Negative),
- |this| this.style(ButtonStyle::Filled),
- )
- .child(
- KeyBinding::for_action(
- &Zeta2RatePredictionNegative,
- cx,
- )
- .size(TextSize::Small.rems(cx)),
- )
- .child(ui::Icon::new(ui::IconName::ThumbsDown))
- .on_click(cx.listener(|this, _, window, cx| {
- this.handle_rate_negative(
- &Zeta2RatePredictionNegative,
- window,
- cx,
- );
- })),
- ),
- )
- .into_any()
- } else {
- Empty.into_any_element()
- },
- ),
- )
- }
-}
-
-impl Focusable for Zeta2Inspector {
- fn focus_handle(&self, _cx: &App) -> FocusHandle {
- self.focus_handle.clone()
- }
-}
-
-impl Item for Zeta2Inspector {
- type Event = ();
-
- fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
- "Zeta2 Inspector".into()
- }
-}
-
-impl EventEmitter<()> for Zeta2Inspector {}
-
-impl Render for Zeta2Inspector {
- fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- v_flex()
- .size_full()
- .bg(cx.theme().colors().editor_background)
- .child(
- h_flex()
- .w_full()
- .child(
- v_flex()
- .flex_1()
- .p_4()
- .h_full()
- .justify_between()
- .child(self.render_options(window, cx))
- .gap_4(),
- )
- .child(ui::vertical_divider())
- .children(self.render_stats()),
- )
- .child(self.render_content(window, cx))
- }
-}
@@ -1,1260 +0,0 @@
-use ::util::rel_path::RelPath;
-use ::util::{RangeExt, ResultExt as _};
-use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
-use edit_prediction_context::{
- Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, Identifier,
- Imports, Reference, ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
-};
-use futures::StreamExt as _;
-use futures::channel::mpsc;
-use gpui::Entity;
-use gpui::{AppContext, AsyncApp};
-use language::OffsetRangeExt;
-use language::{BufferSnapshot, Point};
-use ordered_float::OrderedFloat;
-use polars::prelude::*;
-use project::{Project, ProjectEntryId, ProjectPath, Worktree};
-use serde::{Deserialize, Serialize};
-use std::fs;
-use std::{
- cmp::Reverse,
- collections::{HashMap, HashSet},
- fs::File,
- hash::{Hash, Hasher},
- io::{BufRead, BufReader, BufWriter, Write as _},
- ops::Range,
- path::{Path, PathBuf},
- sync::{
- Arc,
- atomic::{self, AtomicUsize},
- },
- time::Duration,
-};
-use util::paths::PathStyle;
-use zeta::ContextMode;
-
-use crate::headless::ZetaCliAppState;
-use crate::source_location::SourceLocation;
-use crate::util::{open_buffer, open_buffer_with_language_server};
-
-pub async fn retrieval_stats(
- worktree: PathBuf,
- app_state: Arc<ZetaCliAppState>,
- only_extension: Option<String>,
- file_limit: Option<usize>,
- skip_files: Option<usize>,
- options: zeta::ZetaOptions,
- cx: &mut AsyncApp,
-) -> Result<String> {
- let ContextMode::Syntax(context_options) = options.context.clone() else {
- anyhow::bail!("retrieval stats only works in ContextMode::Syntax");
- };
-
- let options = Arc::new(options);
- let worktree_path = worktree.canonicalize()?;
-
- let project = cx.update(|cx| {
- Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- )
- })?;
-
- let worktree = project
- .update(cx, |project, cx| {
- project.create_worktree(&worktree_path, true, cx)
- })?
- .await?;
-
- // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree.
- worktree
- .read_with(cx, |worktree, _cx| {
- worktree.as_local().unwrap().scan_complete()
- })?
- .await;
-
- let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?;
- index
- .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))?
- .await?;
- let indexed_files = index
- .read_with(cx, |index, cx| index.indexed_file_paths(cx))?
- .await;
- let mut filtered_files = indexed_files
- .into_iter()
- .filter(|project_path| {
- let file_extension = project_path.path.extension();
- if let Some(only_extension) = only_extension.as_ref() {
- file_extension.is_some_and(|extension| extension == only_extension)
- } else {
- file_extension
- .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension))
- }
- })
- .collect::<Vec<_>>();
- filtered_files.sort_by(|a, b| a.path.cmp(&b.path));
-
- let index_state = index.read_with(cx, |index, _cx| index.state().clone())?;
- cx.update(|_| {
- drop(index);
- })?;
- let index_state = Arc::new(
- Arc::into_inner(index_state)
- .context("Index state had more than 1 reference")?
- .into_inner(),
- );
-
- struct FileSnapshot {
- project_entry_id: ProjectEntryId,
- snapshot: BufferSnapshot,
- hash: u64,
- parent_abs_path: Arc<Path>,
- }
-
- let files: Vec<FileSnapshot> = futures::future::try_join_all({
- filtered_files
- .iter()
- .map(|file| {
- let buffer_task =
- open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx);
- cx.spawn(async move |cx| {
- let buffer = buffer_task.await?;
- let (project_entry_id, parent_abs_path, snapshot) =
- buffer.read_with(cx, |buffer, cx| {
- let file = project::File::from_dyn(buffer.file()).unwrap();
- let project_entry_id = file.project_entry_id().unwrap();
- let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path);
- if !parent_abs_path.pop() {
- panic!("Invalid worktree path");
- }
-
- (project_entry_id, parent_abs_path, buffer.snapshot())
- })?;
-
- anyhow::Ok(
- cx.background_spawn(async move {
- let mut hasher = collections::FxHasher::default();
- snapshot.text().hash(&mut hasher);
- FileSnapshot {
- project_entry_id,
- snapshot,
- hash: hasher.finish(),
- parent_abs_path: parent_abs_path.into(),
- }
- })
- .await,
- )
- })
- })
- .collect::<Vec<_>>()
- })
- .await?;
-
- let mut file_snapshots = HashMap::default();
- let mut hasher = collections::FxHasher::default();
- for FileSnapshot {
- project_entry_id,
- snapshot,
- hash,
- ..
- } in &files
- {
- file_snapshots.insert(*project_entry_id, snapshot.clone());
- hash.hash(&mut hasher);
- }
- let files_hash = hasher.finish();
- let file_snapshots = Arc::new(file_snapshots);
- let target_cli_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("../../target/zeta_cli");
- fs::create_dir_all(&target_cli_dir).unwrap();
- let target_cli_dir = target_cli_dir.canonicalize().unwrap();
-
- let lsp_cache_dir = target_cli_dir.join("cache");
- fs::create_dir_all(&lsp_cache_dir).unwrap();
-
- let lsp_definitions_path = lsp_cache_dir.join(format!(
- "{}-{:x}.jsonl",
- worktree_path.file_stem().unwrap_or_default().display(),
- files_hash
- ));
-
- let mut lsp_definitions = HashMap::default();
- let mut lsp_files = 0;
-
- if fs::exists(&lsp_definitions_path)? {
- log::info!(
- "Using cached LSP definitions from {}",
- lsp_definitions_path.display()
- );
-
- let file = File::options()
- .read(true)
- .write(true)
- .open(&lsp_definitions_path)?;
- let lines = BufReader::new(&file).lines();
- let mut valid_len: usize = 0;
-
- for (line, expected_file) in lines.zip(files.iter()) {
- let line = line?;
- let FileLspDefinitions { path, references } = match serde_json::from_str(&line) {
- Ok(ok) => ok,
- Err(_) => {
- log::error!("Found invalid cache line. Truncating to #{lsp_files}.",);
- file.set_len(valid_len as u64)?;
- break;
- }
- };
- let expected_path = expected_file.snapshot.file().unwrap().path().as_unix_str();
- if expected_path != path.as_ref() {
- log::error!(
- "Expected file #{} to be {expected_path}, but found {path}. Truncating to #{lsp_files}.",
- lsp_files + 1
- );
- file.set_len(valid_len as u64)?;
- break;
- }
- for (point, ranges) in references {
- let Ok(path) = RelPath::new(Path::new(path.as_ref()), PathStyle::Posix) else {
- log::warn!("Invalid path: {}", path);
- continue;
- };
- lsp_definitions.insert(
- SourceLocation {
- path: path.into_arc(),
- point: point.into(),
- },
- ranges,
- );
- }
- lsp_files += 1;
- valid_len += line.len() + 1
- }
- }
-
- if lsp_files < files.len() {
- if lsp_files == 0 {
- log::warn!(
- "No LSP definitions found, populating {}",
- lsp_definitions_path.display()
- );
- } else {
- log::warn!("{} files missing from LSP cache", files.len() - lsp_files);
- }
-
- gather_lsp_definitions(
- &lsp_definitions_path,
- lsp_files,
- &filtered_files,
- &worktree,
- &project,
- &mut lsp_definitions,
- cx,
- )
- .await?;
- }
- let files_len = files.len().min(file_limit.unwrap_or(usize::MAX));
- let done_count = Arc::new(AtomicUsize::new(0));
-
- let (output_tx, output_rx) = mpsc::unbounded::<ReferenceRetrievalResult>();
-
- let tasks = files
- .into_iter()
- .skip(skip_files.unwrap_or(0))
- .take(file_limit.unwrap_or(usize::MAX))
- .map(|project_file| {
- let index_state = index_state.clone();
- let lsp_definitions = lsp_definitions.clone();
- let output_tx = output_tx.clone();
- let done_count = done_count.clone();
- let file_snapshots = file_snapshots.clone();
- let context_options = context_options.clone();
- cx.background_spawn(async move {
- let snapshot = project_file.snapshot;
-
- let full_range = 0..snapshot.len();
- let references = references_in_range(
- full_range,
- &snapshot.text(),
- ReferenceRegion::Nearby,
- &snapshot,
- );
-
- let imports = if context_options.use_imports {
- Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
- } else {
- Imports::default()
- };
-
- let path = snapshot.file().unwrap().path();
-
- for reference in references {
- let query_point = snapshot.offset_to_point(reference.range.start);
- let source_location = SourceLocation {
- path: path.clone(),
- point: query_point,
- };
- let lsp_definitions = lsp_definitions
- .get(&source_location)
- .cloned()
- .unwrap_or_else(|| {
- log::warn!(
- "No definitions found for source location: {:?}",
- source_location
- );
- Vec::new()
- });
-
- let retrieve_result = retrieve_definitions(
- &reference,
- &imports,
- query_point,
- &snapshot,
- &index_state,
- &file_snapshots,
- &context_options,
- )
- .await?;
-
- let result = ReferenceRetrievalResult {
- cursor_path: path.clone(),
- identifier: reference.identifier,
- cursor_point: query_point,
- lsp_definitions,
- retrieved_definitions: retrieve_result.definitions,
- excerpt_range: retrieve_result.excerpt_range,
- };
-
- output_tx.unbounded_send(result).ok();
- }
-
- println!(
- "{:02}/{:02} done",
- done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
- files_len,
- );
-
- anyhow::Ok(())
- })
- })
- .collect::<Vec<_>>();
-
- drop(output_tx);
-
- let df_task = cx.background_spawn(build_dataframe(output_rx));
-
- futures::future::try_join_all(tasks).await?;
- let mut df = df_task.await?;
-
- let run_id = format!(
- "{}-{}",
- worktree_path.file_stem().unwrap_or_default().display(),
- chrono::Local::now().format("%Y%m%d_%H%M%S")
- );
- let run_dir = target_cli_dir.join(run_id);
- fs::create_dir(&run_dir).unwrap();
-
- let parquet_path = run_dir.join("stats.parquet");
- let mut parquet_file = fs::File::create(&parquet_path)?;
-
- ParquetWriter::new(&mut parquet_file)
- .finish(&mut df)
- .unwrap();
-
- let stats = SummaryStats::from_dataframe(df)?;
-
- let stats_path = run_dir.join("stats.txt");
- fs::write(&stats_path, format!("{}", stats))?;
-
- println!("{}", stats);
- println!("\nWrote:");
- println!("- {}", relativize_path(&parquet_path).display());
- println!("- {}", relativize_path(&stats_path).display());
- println!("- {}", relativize_path(&lsp_definitions_path).display());
-
- Ok("".to_string())
-}
-
-async fn build_dataframe(
- mut output_rx: mpsc::UnboundedReceiver<ReferenceRetrievalResult>,
-) -> Result<DataFrame> {
- use soa_rs::{Soa, Soars};
-
- #[derive(Default, Soars)]
- struct Row {
- ref_id: u32,
- cursor_path: String,
- cursor_row: u32,
- cursor_column: u32,
- cursor_identifier: String,
- gold_in_excerpt: bool,
- gold_path: String,
- gold_row: u32,
- gold_column: u32,
- gold_is_external: bool,
- candidate_count: u32,
- candidate_path: Option<String>,
- candidate_row: Option<u32>,
- candidate_column: Option<u32>,
- candidate_is_gold: Option<bool>,
- candidate_rank: Option<u32>,
- candidate_is_same_file: Option<bool>,
- candidate_is_referenced_nearby: Option<bool>,
- candidate_is_referenced_in_breadcrumb: Option<bool>,
- candidate_reference_count: Option<u32>,
- candidate_same_file_declaration_count: Option<u32>,
- candidate_declaration_count: Option<u32>,
- candidate_reference_line_distance: Option<u32>,
- candidate_declaration_line_distance: Option<u32>,
- candidate_excerpt_vs_item_jaccard: Option<f32>,
- candidate_excerpt_vs_signature_jaccard: Option<f32>,
- candidate_adjacent_vs_item_jaccard: Option<f32>,
- candidate_adjacent_vs_signature_jaccard: Option<f32>,
- candidate_excerpt_vs_item_weighted_overlap: Option<f32>,
- candidate_excerpt_vs_signature_weighted_overlap: Option<f32>,
- candidate_adjacent_vs_item_weighted_overlap: Option<f32>,
- candidate_adjacent_vs_signature_weighted_overlap: Option<f32>,
- candidate_path_import_match_count: Option<u32>,
- candidate_wildcard_path_import_match_count: Option<u32>,
- candidate_import_similarity: Option<f32>,
- candidate_max_import_similarity: Option<f32>,
- candidate_normalized_import_similarity: Option<f32>,
- candidate_wildcard_import_similarity: Option<f32>,
- candidate_normalized_wildcard_import_similarity: Option<f32>,
- candidate_included_by_others: Option<u32>,
- candidate_includes_others: Option<u32>,
- }
- let mut rows = Soa::<Row>::new();
- let mut next_ref_id = 0;
-
- while let Some(result) = output_rx.next().await {
- let mut gold_is_external = false;
- let mut gold_in_excerpt = false;
- let cursor_path = result.cursor_path.as_unix_str();
- let cursor_row = result.cursor_point.row + 1;
- let cursor_column = result.cursor_point.column + 1;
- let cursor_identifier = result.identifier.name.to_string();
- let ref_id = next_ref_id;
- next_ref_id += 1;
-
- for lsp_definition in result.lsp_definitions {
- let SourceRange {
- path: gold_path,
- point_range: gold_point_range,
- offset_range: gold_offset_range,
- } = lsp_definition;
- let lsp_point_range =
- SerializablePoint::into_language_point_range(gold_point_range.clone());
-
- gold_is_external = gold_is_external
- || gold_path.is_absolute()
- || gold_path
- .components()
- .any(|component| component.as_os_str() == "node_modules");
-
- gold_in_excerpt = gold_in_excerpt
- || result.excerpt_range.as_ref().is_some_and(|excerpt_range| {
- excerpt_range.contains_inclusive(&gold_offset_range)
- });
-
- let gold_row = gold_point_range.start.row;
- let gold_column = gold_point_range.start.column;
- let candidate_count = result.retrieved_definitions.len() as u32;
-
- for (candidate_rank, retrieved_definition) in
- result.retrieved_definitions.iter().enumerate()
- {
- let candidate_is_gold = gold_path.as_path()
- == retrieved_definition.path.as_std_path()
- && retrieved_definition
- .range
- .contains_inclusive(&lsp_point_range);
-
- let candidate_row = retrieved_definition.range.start.row + 1;
- let candidate_column = retrieved_definition.range.start.column + 1;
-
- let DeclarationScoreComponents {
- is_same_file,
- is_referenced_nearby,
- is_referenced_in_breadcrumb,
- reference_count,
- same_file_declaration_count,
- declaration_count,
- reference_line_distance,
- declaration_line_distance,
- excerpt_vs_item_jaccard,
- excerpt_vs_signature_jaccard,
- adjacent_vs_item_jaccard,
- adjacent_vs_signature_jaccard,
- excerpt_vs_item_weighted_overlap,
- excerpt_vs_signature_weighted_overlap,
- adjacent_vs_item_weighted_overlap,
- adjacent_vs_signature_weighted_overlap,
- path_import_match_count,
- wildcard_path_import_match_count,
- import_similarity,
- max_import_similarity,
- normalized_import_similarity,
- wildcard_import_similarity,
- normalized_wildcard_import_similarity,
- included_by_others,
- includes_others,
- } = retrieved_definition.components;
-
- rows.push(Row {
- ref_id,
- cursor_path: cursor_path.to_string(),
- cursor_row,
- cursor_column,
- cursor_identifier: cursor_identifier.clone(),
- gold_in_excerpt,
- gold_path: gold_path.to_string_lossy().to_string(),
- gold_row,
- gold_column,
- gold_is_external,
- candidate_count,
- candidate_path: Some(retrieved_definition.path.as_unix_str().to_string()),
- candidate_row: Some(candidate_row),
- candidate_column: Some(candidate_column),
- candidate_is_gold: Some(candidate_is_gold),
- candidate_rank: Some(candidate_rank as u32),
- candidate_is_same_file: Some(is_same_file),
- candidate_is_referenced_nearby: Some(is_referenced_nearby),
- candidate_is_referenced_in_breadcrumb: Some(is_referenced_in_breadcrumb),
- candidate_reference_count: Some(reference_count as u32),
- candidate_same_file_declaration_count: Some(same_file_declaration_count as u32),
- candidate_declaration_count: Some(declaration_count as u32),
- candidate_reference_line_distance: Some(reference_line_distance),
- candidate_declaration_line_distance: Some(declaration_line_distance),
- candidate_excerpt_vs_item_jaccard: Some(excerpt_vs_item_jaccard),
- candidate_excerpt_vs_signature_jaccard: Some(excerpt_vs_signature_jaccard),
- candidate_adjacent_vs_item_jaccard: Some(adjacent_vs_item_jaccard),
- candidate_adjacent_vs_signature_jaccard: Some(adjacent_vs_signature_jaccard),
- candidate_excerpt_vs_item_weighted_overlap: Some(
- excerpt_vs_item_weighted_overlap,
- ),
- candidate_excerpt_vs_signature_weighted_overlap: Some(
- excerpt_vs_signature_weighted_overlap,
- ),
- candidate_adjacent_vs_item_weighted_overlap: Some(
- adjacent_vs_item_weighted_overlap,
- ),
- candidate_adjacent_vs_signature_weighted_overlap: Some(
- adjacent_vs_signature_weighted_overlap,
- ),
- candidate_path_import_match_count: Some(path_import_match_count as u32),
- candidate_wildcard_path_import_match_count: Some(
- wildcard_path_import_match_count as u32,
- ),
- candidate_import_similarity: Some(import_similarity),
- candidate_max_import_similarity: Some(max_import_similarity),
- candidate_normalized_import_similarity: Some(normalized_import_similarity),
- candidate_wildcard_import_similarity: Some(wildcard_import_similarity),
- candidate_normalized_wildcard_import_similarity: Some(
- normalized_wildcard_import_similarity,
- ),
- candidate_included_by_others: Some(included_by_others as u32),
- candidate_includes_others: Some(includes_others as u32),
- });
- }
-
- if result.retrieved_definitions.is_empty() {
- rows.push(Row {
- ref_id,
- cursor_path: cursor_path.to_string(),
- cursor_row,
- cursor_column,
- cursor_identifier: cursor_identifier.clone(),
- gold_in_excerpt,
- gold_path: gold_path.to_string_lossy().to_string(),
- gold_row,
- gold_column,
- gold_is_external,
- candidate_count,
- ..Default::default()
- });
- }
- }
- }
- let slices = rows.slices();
-
- let RowSlices {
- ref_id,
- cursor_path,
- cursor_row,
- cursor_column,
- cursor_identifier,
- gold_in_excerpt,
- gold_path,
- gold_row,
- gold_column,
- gold_is_external,
- candidate_path,
- candidate_row,
- candidate_column,
- candidate_is_gold,
- candidate_rank,
- candidate_count,
- candidate_is_same_file,
- candidate_is_referenced_nearby,
- candidate_is_referenced_in_breadcrumb,
- candidate_reference_count,
- candidate_same_file_declaration_count,
- candidate_declaration_count,
- candidate_reference_line_distance,
- candidate_declaration_line_distance,
- candidate_excerpt_vs_item_jaccard,
- candidate_excerpt_vs_signature_jaccard,
- candidate_adjacent_vs_item_jaccard,
- candidate_adjacent_vs_signature_jaccard,
- candidate_excerpt_vs_item_weighted_overlap,
- candidate_excerpt_vs_signature_weighted_overlap,
- candidate_adjacent_vs_item_weighted_overlap,
- candidate_adjacent_vs_signature_weighted_overlap,
- candidate_path_import_match_count,
- candidate_wildcard_path_import_match_count,
- candidate_import_similarity,
- candidate_max_import_similarity,
- candidate_normalized_import_similarity,
- candidate_wildcard_import_similarity,
- candidate_normalized_wildcard_import_similarity,
- candidate_included_by_others,
- candidate_includes_others,
- } = slices;
-
- let df = DataFrame::new(vec![
- Series::new(PlSmallStr::from_str("ref_id"), ref_id).into(),
- Series::new(PlSmallStr::from_str("cursor_path"), cursor_path).into(),
- Series::new(PlSmallStr::from_str("cursor_row"), cursor_row).into(),
- Series::new(PlSmallStr::from_str("cursor_column"), cursor_column).into(),
- Series::new(PlSmallStr::from_str("cursor_identifier"), cursor_identifier).into(),
- Series::new(PlSmallStr::from_str("gold_in_excerpt"), gold_in_excerpt).into(),
- Series::new(PlSmallStr::from_str("gold_path"), gold_path).into(),
- Series::new(PlSmallStr::from_str("gold_row"), gold_row).into(),
- Series::new(PlSmallStr::from_str("gold_column"), gold_column).into(),
- Series::new(PlSmallStr::from_str("gold_is_external"), gold_is_external).into(),
- Series::new(PlSmallStr::from_str("candidate_count"), candidate_count).into(),
- Series::new(PlSmallStr::from_str("candidate_path"), candidate_path).into(),
- Series::new(PlSmallStr::from_str("candidate_row"), candidate_row).into(),
- Series::new(PlSmallStr::from_str("candidate_column"), candidate_column).into(),
- Series::new(PlSmallStr::from_str("candidate_is_gold"), candidate_is_gold).into(),
- Series::new(PlSmallStr::from_str("candidate_rank"), candidate_rank).into(),
- Series::new(
- PlSmallStr::from_str("candidate_is_same_file"),
- candidate_is_same_file,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_is_referenced_nearby"),
- candidate_is_referenced_nearby,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_is_referenced_in_breadcrumb"),
- candidate_is_referenced_in_breadcrumb,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_reference_count"),
- candidate_reference_count,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_same_file_declaration_count"),
- candidate_same_file_declaration_count,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_declaration_count"),
- candidate_declaration_count,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_reference_line_distance"),
- candidate_reference_line_distance,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_declaration_line_distance"),
- candidate_declaration_line_distance,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_excerpt_vs_item_jaccard"),
- candidate_excerpt_vs_item_jaccard,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_excerpt_vs_signature_jaccard"),
- candidate_excerpt_vs_signature_jaccard,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_adjacent_vs_item_jaccard"),
- candidate_adjacent_vs_item_jaccard,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_adjacent_vs_signature_jaccard"),
- candidate_adjacent_vs_signature_jaccard,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_excerpt_vs_item_weighted_overlap"),
- candidate_excerpt_vs_item_weighted_overlap,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_excerpt_vs_signature_weighted_overlap"),
- candidate_excerpt_vs_signature_weighted_overlap,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_adjacent_vs_item_weighted_overlap"),
- candidate_adjacent_vs_item_weighted_overlap,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_adjacent_vs_signature_weighted_overlap"),
- candidate_adjacent_vs_signature_weighted_overlap,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_path_import_match_count"),
- candidate_path_import_match_count,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_wildcard_path_import_match_count"),
- candidate_wildcard_path_import_match_count,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_import_similarity"),
- candidate_import_similarity,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_max_import_similarity"),
- candidate_max_import_similarity,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_normalized_import_similarity"),
- candidate_normalized_import_similarity,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_wildcard_import_similarity"),
- candidate_wildcard_import_similarity,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_normalized_wildcard_import_similarity"),
- candidate_normalized_wildcard_import_similarity,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_included_by_others"),
- candidate_included_by_others,
- )
- .into(),
- Series::new(
- PlSmallStr::from_str("candidate_includes_others"),
- candidate_includes_others,
- )
- .into(),
- ])?;
-
- Ok(df)
-}
-
-fn relativize_path(path: &Path) -> &Path {
- path.strip_prefix(std::env::current_dir().unwrap())
- .unwrap_or(path)
-}
-
-struct SummaryStats {
- references_count: u32,
- retrieved_count: u32,
- top_match_count: u32,
- non_top_match_count: u32,
- ranking_involved_top_match_count: u32,
- missing_none_retrieved: u32,
- missing_wrong_retrieval: u32,
- missing_external: u32,
- in_excerpt_count: u32,
-}
-
-impl SummaryStats {
- fn from_dataframe(df: DataFrame) -> Result<Self> {
- // TODO: use lazy more
- let unique_refs =
- df.unique::<(), ()>(Some(&["ref_id".into()]), UniqueKeepStrategy::Any, None)?;
- let references_count = unique_refs.height() as u32;
-
- let gold_mask = df.column("candidate_is_gold")?.bool()?;
- let gold_df = df.filter(&gold_mask)?;
- let retrieved_count = gold_df.height() as u32;
-
- let top_match_mask = gold_df.column("candidate_rank")?.u32()?.equal(0);
- let top_match_df = gold_df.filter(&top_match_mask)?;
- let top_match_count = top_match_df.height() as u32;
-
- let ranking_involved_top_match_count = top_match_df
- .column("candidate_count")?
- .u32()?
- .gt(1)
- .sum()
- .unwrap_or_default();
-
- let non_top_match_count = (!top_match_mask).sum().unwrap_or(0);
-
- let not_retrieved_df = df
- .lazy()
- .group_by(&[col("ref_id"), col("candidate_count")])
- .agg(&[
- col("candidate_is_gold")
- .fill_null(false)
- .sum()
- .alias("gold_count"),
- col("gold_in_excerpt").sum().alias("gold_in_excerpt_count"),
- col("gold_is_external")
- .sum()
- .alias("gold_is_external_count"),
- ])
- .filter(col("gold_count").eq(lit(0)))
- .collect()?;
-
- let in_excerpt_mask = not_retrieved_df
- .column("gold_in_excerpt_count")?
- .u32()?
- .gt(0);
- let in_excerpt_count = in_excerpt_mask.sum().unwrap_or(0);
-
- let missing_df = not_retrieved_df.filter(&!in_excerpt_mask)?;
-
- let missing_none_retrieved_mask = missing_df.column("candidate_count")?.u32()?.equal(0);
- let missing_none_retrieved = missing_none_retrieved_mask.sum().unwrap_or(0);
- let external_mask = missing_df.column("gold_is_external_count")?.u32()?.gt(0);
- let missing_external = (missing_none_retrieved_mask & external_mask)
- .sum()
- .unwrap_or(0);
-
- let missing_wrong_retrieval = missing_df
- .column("candidate_count")?
- .u32()?
- .gt(0)
- .sum()
- .unwrap_or(0);
-
- Ok(SummaryStats {
- references_count,
- retrieved_count,
- top_match_count,
- non_top_match_count,
- ranking_involved_top_match_count,
- missing_none_retrieved,
- missing_wrong_retrieval,
- missing_external,
- in_excerpt_count,
- })
- }
-
- fn count_and_percentage(part: u32, total: u32) -> String {
- format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
- }
-}
-
-impl std::fmt::Display for SummaryStats {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- let included = self.in_excerpt_count + self.retrieved_count;
- let missing = self.references_count - included;
- writeln!(f)?;
- writeln!(f, "โฎ references: {}", self.references_count)?;
- writeln!(
- f,
- "โโโฎ included: {}",
- Self::count_and_percentage(included, self.references_count),
- )?;
- writeln!(
- f,
- "โ โโโฎ retrieved: {}",
- Self::count_and_percentage(self.retrieved_count, self.references_count)
- )?;
- writeln!(
- f,
- "โ โ โโโฎ top match : {}",
- Self::count_and_percentage(self.top_match_count, self.retrieved_count)
- )?;
- writeln!(
- f,
- "โ โ โ โฐโโด involving ranking: {}",
- Self::count_and_percentage(self.ranking_involved_top_match_count, self.top_match_count)
- )?;
- writeln!(
- f,
- "โ โ โฐโโด non-top match: {}",
- Self::count_and_percentage(self.non_top_match_count, self.retrieved_count)
- )?;
- writeln!(
- f,
- "โ โฐโโด in excerpt: {}",
- Self::count_and_percentage(self.in_excerpt_count, included)
- )?;
- writeln!(
- f,
- "โฐโโฎ missing: {}",
- Self::count_and_percentage(missing, self.references_count)
- )?;
- writeln!(
- f,
- " โโโฎ none retrieved: {}",
- Self::count_and_percentage(self.missing_none_retrieved, missing)
- )?;
- writeln!(
- f,
- " โ โฐโโด external (expected): {}",
- Self::count_and_percentage(self.missing_external, missing)
- )?;
- writeln!(
- f,
- " โฐโโด wrong retrieval: {}",
- Self::count_and_percentage(self.missing_wrong_retrieval, missing)
- )?;
- Ok(())
- }
-}
-
-#[derive(Debug)]
-struct ReferenceRetrievalResult {
- cursor_path: Arc<RelPath>,
- cursor_point: Point,
- identifier: Identifier,
- excerpt_range: Option<Range<usize>>,
- lsp_definitions: Vec<SourceRange>,
- retrieved_definitions: Vec<RetrievedDefinition>,
-}
-
-#[derive(Debug)]
-struct RetrievedDefinition {
- path: Arc<RelPath>,
- range: Range<Point>,
- score: f32,
- #[allow(dead_code)]
- retrieval_score: f32,
- #[allow(dead_code)]
- components: DeclarationScoreComponents,
-}
-
-struct RetrieveResult {
- definitions: Vec<RetrievedDefinition>,
- excerpt_range: Option<Range<usize>>,
-}
-
-async fn retrieve_definitions(
- reference: &Reference,
- imports: &Imports,
- query_point: Point,
- snapshot: &BufferSnapshot,
- index: &Arc<SyntaxIndexState>,
- file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
- context_options: &EditPredictionContextOptions,
-) -> Result<RetrieveResult> {
- let mut single_reference_map = HashMap::default();
- single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
- let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
- query_point,
- snapshot,
- imports,
- &context_options,
- Some(&index),
- |_, _, _| single_reference_map,
- );
-
- let Some(edit_prediction_context) = edit_prediction_context else {
- return Ok(RetrieveResult {
- definitions: Vec::new(),
- excerpt_range: None,
- });
- };
-
- let mut retrieved_definitions = Vec::new();
- for scored_declaration in edit_prediction_context.declarations {
- match &scored_declaration.declaration {
- Declaration::File {
- project_entry_id,
- declaration,
- ..
- } => {
- let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
- log::error!("bug: file project entry not found");
- continue;
- };
- let path = snapshot.file().unwrap().path().clone();
- retrieved_definitions.push(RetrievedDefinition {
- path,
- range: snapshot.offset_to_point(declaration.item_range.start)
- ..snapshot.offset_to_point(declaration.item_range.end),
- score: scored_declaration.score(DeclarationStyle::Declaration),
- retrieval_score: scored_declaration.retrieval_score(),
- components: scored_declaration.components,
- });
- }
- Declaration::Buffer {
- project_entry_id,
- rope,
- declaration,
- ..
- } => {
- let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
- // This case happens when dependency buffers have been opened by
- // go-to-definition, resulting in single-file worktrees.
- continue;
- };
- let path = snapshot.file().unwrap().path().clone();
- retrieved_definitions.push(RetrievedDefinition {
- path,
- range: rope.offset_to_point(declaration.item_range.start)
- ..rope.offset_to_point(declaration.item_range.end),
- score: scored_declaration.score(DeclarationStyle::Declaration),
- retrieval_score: scored_declaration.retrieval_score(),
- components: scored_declaration.components,
- });
- }
- }
- }
- retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));
-
- Ok(RetrieveResult {
- definitions: retrieved_definitions,
- excerpt_range: Some(edit_prediction_context.excerpt.range),
- })
-}
-
-async fn gather_lsp_definitions(
- lsp_definitions_path: &Path,
- start_index: usize,
- files: &[ProjectPath],
- worktree: &Entity<Worktree>,
- project: &Entity<Project>,
- definitions: &mut HashMap<SourceLocation, Vec<SourceRange>>,
- cx: &mut AsyncApp,
-) -> Result<()> {
- let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
-
- let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
- cx.subscribe(&lsp_store, {
- move |_, event, _| {
- if let project::LspStoreEvent::LanguageServerUpdate {
- message:
- client::proto::update_language_server::Variant::WorkProgress(
- client::proto::LspWorkProgress {
- message: Some(message),
- ..
- },
- ),
- ..
- } = event
- {
- println!("โฒ {message}")
- }
- }
- })?
- .detach();
-
- let (cache_line_tx, mut cache_line_rx) = mpsc::unbounded::<FileLspDefinitions>();
-
- let cache_file = File::options()
- .append(true)
- .create(true)
- .open(lsp_definitions_path)
- .unwrap();
-
- let cache_task = cx.background_spawn(async move {
- let mut writer = BufWriter::new(cache_file);
- while let Some(line) = cache_line_rx.next().await {
- serde_json::to_writer(&mut writer, &line).unwrap();
- writer.write_all(&[b'\n']).unwrap();
- }
- writer.flush().unwrap();
- });
-
- let mut error_count = 0;
- let mut lsp_open_handles = Vec::new();
- let mut ready_languages = HashSet::default();
- for (file_index, project_path) in files[start_index..].iter().enumerate() {
- println!(
- "Processing file {} of {}: {}",
- start_index + file_index + 1,
- files.len(),
- project_path.path.display(PathStyle::Posix)
- );
-
- let Some((lsp_open_handle, language_server_id, buffer)) = open_buffer_with_language_server(
- project.clone(),
- worktree.clone(),
- project_path.path.clone(),
- &mut ready_languages,
- cx,
- )
- .await
- .log_err() else {
- continue;
- };
- lsp_open_handles.push(lsp_open_handle);
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
- let full_range = 0..snapshot.len();
- let references = references_in_range(
- full_range,
- &snapshot.text(),
- ReferenceRegion::Nearby,
- &snapshot,
- );
-
- loop {
- let is_ready = lsp_store
- .read_with(cx, |lsp_store, _cx| {
- lsp_store
- .language_server_statuses
- .get(&language_server_id)
- .is_some_and(|status| status.pending_work.is_empty())
- })
- .unwrap();
- if is_ready {
- break;
- }
- cx.background_executor()
- .timer(Duration::from_millis(10))
- .await;
- }
-
- let mut cache_line_references = Vec::with_capacity(references.len());
-
- for reference in references {
- // TODO: Rename declaration to definition in edit_prediction_context?
- let lsp_result = project
- .update(cx, |project, cx| {
- project.definitions(&buffer, reference.range.start, cx)
- })?
- .await;
-
- match lsp_result {
- Ok(lsp_definitions) => {
- let mut targets = Vec::new();
- for target in lsp_definitions.unwrap_or_default() {
- let buffer = target.target.buffer;
- let anchor_range = target.target.range;
- buffer.read_with(cx, |buffer, cx| {
- let Some(file) = project::File::from_dyn(buffer.file()) else {
- return;
- };
- let file_worktree = file.worktree.read(cx);
- let file_worktree_id = file_worktree.id();
- // Relative paths for worktree files, absolute for all others
- let path = if worktree_id != file_worktree_id {
- file.worktree.read(cx).absolutize(&file.path)
- } else {
- file.path.as_std_path().to_path_buf()
- };
- let offset_range = anchor_range.to_offset(&buffer);
- let point_range = SerializablePoint::from_language_point_range(
- offset_range.to_point(&buffer),
- );
- targets.push(SourceRange {
- path,
- offset_range,
- point_range,
- });
- })?;
- }
-
- let point = snapshot.offset_to_point(reference.range.start);
-
- cache_line_references.push((point.into(), targets.clone()));
- definitions.insert(
- SourceLocation {
- path: project_path.path.clone(),
- point,
- },
- targets,
- );
- }
- Err(err) => {
- log::error!("Language server error: {err}");
- error_count += 1;
- }
- }
- }
-
- cache_line_tx
- .unbounded_send(FileLspDefinitions {
- path: project_path.path.as_unix_str().into(),
- references: cache_line_references,
- })
- .log_err();
- }
-
- drop(cache_line_tx);
-
- if error_count > 0 {
- log::error!("Encountered {} language server errors", error_count);
- }
-
- cache_task.await;
-
- Ok(())
-}
-
-#[derive(Serialize, Deserialize)]
-struct FileLspDefinitions {
- path: Arc<str>,
- references: Vec<(SerializablePoint, Vec<SourceRange>)>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-struct SourceRange {
- path: PathBuf,
- point_range: Range<SerializablePoint>,
- offset_range: Range<usize>,
-}
-
-/// Serializes to 1-based row and column indices.
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct SerializablePoint {
- pub row: u32,
- pub column: u32,
-}
-
-impl SerializablePoint {
- pub fn into_language_point_range(range: Range<Self>) -> Range<Point> {
- range.start.into()..range.end.into()
- }
-
- pub fn from_language_point_range(range: Range<Point>) -> Range<Self> {
- range.start.into()..range.end.into()
- }
-}
-
-impl From<Point> for SerializablePoint {
- fn from(point: Point) -> Self {
- SerializablePoint {
- row: point.row + 1,
- column: point.column + 1,
- }
- }
-}
-
-impl From<SerializablePoint> for Point {
- fn from(serializable: SerializablePoint) -> Self {
- Point {
- row: serializable.row.saturating_sub(1),
- column: serializable.column.saturating_sub(1),
- }
- }
-}
@@ -5,12 +5,12 @@ use std::sync::{
atomic::{AtomicU8, Ordering},
};
-use crate::{SCOPE_DEPTH_MAX, SCOPE_STRING_SEP_STR, Scope, ScopeAlloc, env_config, private};
+use crate::{SCOPE_DEPTH_MAX, SCOPE_STRING_SEP_STR, ScopeAlloc, ScopeRef, env_config, private};
use log;
static ENV_FILTER: OnceLock<env_config::EnvFilter> = OnceLock::new();
-static SCOPE_MAP: RwLock<Option<ScopeMap>> = RwLock::new(None);
+static SCOPE_MAP: RwLock<ScopeMap> = RwLock::new(ScopeMap::empty());
pub const LEVEL_ENABLED_MAX_DEFAULT: log::LevelFilter = log::LevelFilter::Info;
/// The maximum log level of verbosity that is enabled by default.
@@ -59,7 +59,11 @@ pub fn is_possibly_enabled_level(level: log::Level) -> bool {
level as u8 <= LEVEL_ENABLED_MAX_CONFIG.load(Ordering::Acquire)
}
-pub fn is_scope_enabled(scope: &Scope, module_path: Option<&str>, level: log::Level) -> bool {
+pub fn is_scope_enabled(
+ scope: &ScopeRef<'_>,
+ module_path: Option<&str>,
+ level: log::Level,
+) -> bool {
// TODO: is_always_allowed_level that checks against LEVEL_ENABLED_MIN_CONFIG
if !is_possibly_enabled_level(level) {
// [FAST PATH]
@@ -74,16 +78,11 @@ pub fn is_scope_enabled(scope: &Scope, module_path: Option<&str>, level: log::Le
err.into_inner()
});
- let Some(map) = global_scope_map.as_ref() else {
- // on failure, return false because it's not <= LEVEL_ENABLED_MAX_STATIC
- return is_enabled_by_default;
- };
-
- if map.is_empty() {
+ if global_scope_map.is_empty() {
// if no scopes are enabled, return false because it's not <= LEVEL_ENABLED_MAX_STATIC
return is_enabled_by_default;
}
- let enabled_status = map.is_enabled(scope, module_path, level);
+ let enabled_status = global_scope_map.is_enabled(scope, module_path, level);
match enabled_status {
EnabledStatus::NotConfigured => is_enabled_by_default,
EnabledStatus::Enabled => true,
@@ -107,7 +106,7 @@ pub fn refresh_from_settings(settings: &HashMap<String, String>) {
SCOPE_MAP.clear_poison();
err.into_inner()
});
- global_map.replace(map_new);
+ *global_map = map_new;
}
log::trace!("Log configuration updated");
}
@@ -395,12 +394,21 @@ impl ScopeMap {
}
EnabledStatus::NotConfigured
}
+
+ const fn empty() -> ScopeMap {
+ ScopeMap {
+ entries: vec![],
+ modules: vec![],
+ root_count: 0,
+ }
+ }
}
#[cfg(test)]
mod tests {
use log::LevelFilter;
+ use crate::Scope;
use crate::private::scope_new;
use super::*;
@@ -8,7 +8,7 @@ use std::{
},
};
-use crate::{SCOPE_STRING_SEP_CHAR, Scope};
+use crate::{SCOPE_STRING_SEP_CHAR, ScopeRef};
// ANSI color escape codes for log levels
const ANSI_RESET: &str = "\x1b[0m";
@@ -35,7 +35,7 @@ static SINK_FILE_SIZE_BYTES: AtomicU64 = AtomicU64::new(0);
const SINK_FILE_SIZE_BYTES_MAX: u64 = 1024 * 1024; // 1 MB
pub struct Record<'a> {
- pub scope: Scope,
+ pub scope: ScopeRef<'a>,
pub level: log::Level,
pub message: &'a std::fmt::Arguments<'a>,
pub module_path: Option<&'a str>,
@@ -208,7 +208,7 @@ pub fn flush() {
}
struct SourceFmt<'a> {
- scope: Scope,
+ scope: ScopeRef<'a>,
module_path: Option<&'a str>,
line: Option<u32>,
ansi: bool,
@@ -70,15 +70,18 @@ impl log::Log for Zlog {
if !self.enabled(record.metadata()) {
return;
}
- let (crate_name_scope, module_scope) = match record.module_path_static() {
+ let module_path = record.module_path().or(record.file());
+ let (crate_name_scope, module_scope) = match module_path {
Some(module_path) => {
let crate_name = private::extract_crate_name_from_module_path(module_path);
- let crate_name_scope = private::scope_new(&[crate_name]);
- let module_scope = private::scope_new(&[module_path]);
+ let crate_name_scope = private::scope_ref_new(&[crate_name]);
+ let module_scope = private::scope_ref_new(&[module_path]);
(crate_name_scope, module_scope)
}
- // TODO: when do we hit this
- None => (private::scope_new(&[]), private::scope_new(&["*unknown*"])),
+ None => {
+ // TODO: when do we hit this
+ (private::scope_new(&[]), private::scope_new(&["*unknown*"]))
+ }
};
let level = record.metadata().level();
if !filter::is_scope_enabled(&crate_name_scope, Some(record.target()), level) {
@@ -89,7 +92,7 @@ impl log::Log for Zlog {
level,
message: record.args(),
// PERF(batching): store non-static paths in a cache + leak them and pass static str here
- module_path: record.module_path().or(record.file()),
+ module_path,
line: record.line(),
});
}
@@ -252,6 +255,10 @@ pub mod private {
}
pub const fn scope_new(scopes: &[&'static str]) -> Scope {
+ scope_ref_new(scopes)
+ }
+
+ pub const fn scope_ref_new<'a>(scopes: &[&'a str]) -> ScopeRef<'a> {
assert!(scopes.len() <= SCOPE_DEPTH_MAX);
let mut scope = [""; SCOPE_DEPTH_MAX];
let mut i = 0;
@@ -275,6 +282,7 @@ pub mod private {
}
pub type Scope = [&'static str; SCOPE_DEPTH_MAX];
+pub type ScopeRef<'a> = [&'a str; SCOPE_DEPTH_MAX];
pub type ScopeAlloc = [String; SCOPE_DEPTH_MAX];
const SCOPE_STRING_SEP_STR: &str = ".";
const SCOPE_STRING_SEP_CHAR: char = '.';
@@ -0,0 +1,20 @@
+[package]
+name = "ztracing"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[features]
+tracy = ["tracing-tracy"]
+
+[dependencies]
+tracing.workspace = true
+
+tracing-subscriber = "0.3.22"
+tracing-tracy = { version = "0.11.4", optional = true, features = ["enable", "ondemand"] }
+
+ztracing_macro.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-AGPL
@@ -0,0 +1 @@
+../../LICENSE-APACHE
@@ -0,0 +1,9 @@
+use std::env;
+
+fn main() {
+ if env::var_os("ZTRACING").is_some() {
+ println!(r"cargo::rustc-cfg=ztracing");
+ }
+ println!("cargo::rerun-if-changed=build.rs");
+ println!("cargo::rerun-if-env-changed=ZTRACING");
+}
@@ -0,0 +1,16 @@
+#[cfg(ztracing)]
+pub use tracing::instrument;
+#[cfg(not(ztracing))]
+pub use ztracing_macro::instrument;
+
+#[cfg(ztracing)]
+pub fn init() {
+ use tracing_subscriber::prelude::*;
+ tracing::subscriber::set_global_default(
+ tracing_subscriber::registry().with(tracing_tracy::TracyLayer::default()),
+ )
+ .expect("setup tracy layer");
+}
+
+#[cfg(not(ztracing))]
+pub fn init() {}
@@ -0,0 +1,11 @@
+[package]
+name = "ztracing_macro"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lib]
+proc-macro = true
+
+[dependencies]
@@ -0,0 +1 @@
+../../LICENSE-AGPL
@@ -0,0 +1 @@
+../../LICENSE-APACHE
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,7 @@
+#[proc_macro_attribute]
+pub fn instrument(
+ _attr: proc_macro::TokenStream,
+ item: proc_macro::TokenStream,
+) -> proc_macro::TokenStream {
+ item
+}
@@ -41,6 +41,7 @@
- [Debugger](./debugger.md)
- [Diagnostics](./diagnostics.md)
- [Tasks](./tasks.md)
+- [Tab Switcher](./tab-switcher.md)
- [Remote Development](./remote-development.md)
- [Environment Variables](./environment.md)
- [REPL](./repl.md)
@@ -5,7 +5,7 @@
## Using Zed's built-in debugger
-While the Zed project is open you can open the `New Process Modal` and select the `Debug` tab. There you can see to debug configurations to debug Zed with, one for GDB and one for LLDB. Select the configuration you want and Zed will build and launch the binary.
+While the Zed project is open you can open the `New Process Modal` and select the `Debug` tab. There you can see two debug configurations to debug Zed with, one for GDB and one for LLDB. Select the configuration you want and Zed will build and launch the binary.
Please note, GDB isn't supported on arm Macbooks
@@ -3,7 +3,7 @@
Astro support is available through the [Astro extension](https://github.com/zed-extensions/astro).
- Tree-sitter: [virchau13/tree-sitter-astro](https://github.com/virchau13/tree-sitter-astro)
-- Language Server: [withastro/language-tools](https://github.com/withastro/language-tools)
+- Language Server: [withastro/language-tools](https://github.com/withastro/astro/tree/main/packages/language-tools/language-server)
<!--
TBD: Documentation Astro usage / configuration
@@ -2,34 +2,44 @@
PHP support is available through the [PHP extension](https://github.com/zed-extensions/php).
-- Tree-sitter: https://github.com/tree-sitter/tree-sitter-php
-- Language Servers:
- - [phpactor](https://github.com/phpactor/phpactor)
- - [intelephense](https://github.com/bmewburn/vscode-intelephense/)
+- Tree-sitter: [tree-sitter/tree-sitter-php](https://github.com/tree-sitter/tree-sitter-php)
+- Language Server: [phpactor/phpactor](https://github.com/phpactor/phpactor)
+- Alternate Language Server: [bmewburn/vscode-intelephense](https://github.com/bmewburn/vscode-intelephense/)
-## Choosing a language server
+## Install PHP
-The PHP extension offers both `phpactor` and `intelephense` language server support.
+The PHP extension requires PHP to be installed and available in your `PATH`:
-`phpactor` is enabled by default.
+```sh
+# macOS via Homebrew
+brew install php
-### Phpactor
+# Debian/Ubuntu
+sudo apt-get install php-cli
-The Zed PHP Extension can install `phpactor` automatically but requires `php` to be installed and available in your path:
+# CentOS 8+/RHEL
+sudo dnf install php-cli
-```sh
-# brew install php # macOS
-# sudo apt-get install php # Debian/Ubuntu
-# yum install php # CentOS/RHEL
-# pacman -S php # Arch Linux
+# Arch Linux
+sudo pacman -S php
+
+# check PHP path
+## macOS and Linux
which php
+
+## Windows
+where php
```
+## Choosing a language server
+
+The PHP extension uses [LSP language servers](https://microsoft.github.io/language-server-protocol) with Phpactor as the default. If you want to use other language servers that support Zed (e.g. Intelephense or PHP Tools), make sure to follow the documentation on how to implement it.
+
### Intelephense
-[Intelephense](https://intelephense.com/) is a [proprietary](https://github.com/bmewburn/vscode-intelephense/blob/master/LICENSE.txt#L29) language server for PHP operating under a freemium model. Certain features require purchase of a [premium license](https://intelephense.com/).
+[Intelephense](https://intelephense.com/) is a [proprietary](https://github.com/bmewburn/vscode-intelephense/blob/master/LICENSE.txt#L29) language server for PHP operating under a freemium model. Certain features require purchase of a [premium license](https://intelephense.com/buy).
-To switch to `intelephense`, add the following to your `settings.json`:
+To use Intelephense, add the following to your `settings.json`:
```json [settings]
{
@@ -41,7 +51,9 @@ To switch to `intelephense`, add the following to your `settings.json`:
}
```
-To use the premium features, you can place your [licence.txt file](https://intelephense.com/faq.html) at `~/intelephense/licence.txt` inside your home directory. Alternatively, you can pass the licence key or a path to a file containing the licence key as an initialization option for the `intelephense` language server. To do this, add the following to your `settings.json`:
+To use the premium features, you can place your license file inside your home directory at `~/intelephense/licence.txt` for macOS and Linux, or `%USERPROFILE%/intelephense/licence.txt` on Windows.
+
+Alternatively, you can pass the licence key or a path to a file containing the licence key as an initialization option. To do this, add the following to your `settings.json`:
```json [settings]
{
@@ -55,15 +67,67 @@ To use the premium features, you can place your [licence.txt file](https://intel
}
```
+### PHP Tools
+
+[PHP Tools](https://www.devsense.com/) is a proprietary language server that offers free and premium features. You need to [purchase a license](https://www.devsense.com/en/purchase) to activate the premium features.
+
+To use PHP Tools, add the following to your `settings.json`:
+
+```json [settings]
+{
+ "languages": {
+ "PHP": {
+ "language_servers": ["phptools", "!intelephense", "!phpactor", "..."]
+ }
+ }
+}
+```
+
+To use the premium features, you can add your license in `initialization_options` in your `settings.json`:
+
+```json [settings]
+{
+ "lsp": {
+ "phptools": {
+ "initialization_options": {
+ "0": "your_license_key"
+ }
+ }
+ }
+}
+```
+
+or, set environment variable `DEVSENSE_PHP_LS_LICENSE` on `.env` file in your project.
+
+```env
+DEVSENSE_PHP_LS_LICENSE="your_license_key"
+```
+
+Check out the documentation of [PHP Tools for Zed](https://docs.devsense.com/other/zed/) for more details.
+
+### Phpactor
+
+To use Phpactor instead of Intelephense or any other tools, add the following to your `settings.json`:
+
+```json [settings]
+{
+ "languages": {
+ "PHP": {
+ "language_servers": ["phpactor", "!intelephense", "!phptools", "..."]
+ }
+ }
+}
+```
+
## PHPDoc
Zed supports syntax highlighting for PHPDoc comments.
- Tree-sitter: [claytonrcarter/tree-sitter-phpdoc](https://github.com/claytonrcarter/tree-sitter-phpdoc)
-## Setting up Xdebug
+## Debugging
-Zedโs PHP extension provides a debug adapter for PHP and Xdebug. The adapter name is `Xdebug`. Here a couple ways you can use it:
+The PHP extension provides a debug adapter for PHP via Xdebug. There are several ways to use it:
```json
[
@@ -83,10 +147,10 @@ Zedโs PHP extension provides a debug adapter for PHP and Xdebug. The adapter n
]
```
-In case you run into issues:
+These are common troubleshooting tips, in case you run into issues:
-- ensure that you have Xdebug installed for the version of PHP youโre running
-- ensure that Xdebug is configured to run in `debug` mode
-- ensure that Xdebug is actually starting a debugging session
-- check that the host and port matches between Xdebug and Zed
-- look at the diagnostics log by using the `xdebug_info()` function in the page youโre trying to debug
+- Ensure that you have Xdebug installed for the version of PHP youโre running.
+- Ensure that Xdebug is configured to run in `debug` mode.
+- Ensure that Xdebug is actually starting a debugging session.
+- Ensure that the host and port matches between Xdebug and Zed.
+- Look at the diagnostics log by using the `xdebug_info()` function in the page youโre trying to debug.
@@ -3,7 +3,7 @@
Rego language support in Zed is provided by the community-maintained [Rego extension](https://github.com/StyraInc/zed-rego).
- Tree-sitter: [FallenAngel97/tree-sitter-rego](https://github.com/FallenAngel97/tree-sitter-rego)
-- Language Server: [StyraInc/regal](https://github.com/StyraInc/regal)
+- Language Server: [open-policy-agent/regal](https://github.com/open-policy-agent/regal)
## Installation
@@ -1,6 +1,6 @@
How to use our internal tools to profile and keep Zed fast.
-# Flamechart/CPU profiling
+# Rough quick CPU profiling (Flamechart)
See what the CPU spends the most time on. Strongly recommend you use
[samply](https://github.com/mstange/samply). It opens an interactive profile in
@@ -12,6 +12,46 @@ The profile.json does not contain any symbols. Firefox profiler can add the loca
<img width="851" height="613" alt="image" src="https://github.com/user-attachments/assets/cbef2b51-0442-4ee9-bc5c-95f6ccf9be2c" />
+# In depth CPU profiling (Tracing)
+
+See how long each annotated function call took and its arguments (if
+configured).
+
+Annotate any function you need appear in the profile with instrument. For more
+details see
+[tracing-instrument](https://docs.rs/tracing/latest/tracing/attr.instrument.html):
+
+```rust
+#[instrument(skip_all)]
+fn should_appear_in_profile(kitty: Cat) {
+ sleep(QUITE_LONG)
+}
+```
+
+Then either compile Zed with `ZTRACING=1 cargo r --features tracy --release`. The release build is optional but highly recommended as like every program Zeds performance characteristics change dramatically with optimizations. You do not want to chase slowdowns that do not exist in release.
+
+## One time Setup/Building the profiler:
+
+Download the profiler:
+[linux x86_64](https://zed-tracy-import-miniprofiler.nyc3.digitaloceanspaces.com/tracy-profiler-linux-x86_64)
+[macos aarch64](https://zed-tracy-import-miniprofiler.nyc3.digitaloceanspaces.com/tracy-profiler-0.13.0-macos-aarch64)
+
+### Alternative: Building it yourself
+
+- Clone the repo at git@github.com:wolfpld/tracy.git
+- `cd profiler && mkdir build && cd build`
+- Run cmake to generate build files: `cmake -G Ninja -DCMAKE_BUILD_TYPE=Release ..`
+- Build the profiler: `ninja`
+- [Optional] move the profiler somewhere nice like ~/.local/bin on linux
+
+## Usage
+
+Open the profiler (tracy-profiler), you should see zed in the list of `Discovered clients` click it.
+<img width="392" height="287" alt="image" src="https://github.com/user-attachments/assets/b6f06fc3-6b25-41c7-ade9-558cc93d6033" />
+
+To find functions that take a long time follow this image:
+<img width="888" height="1159" alt="image" src="https://github.com/user-attachments/assets/77087617-f53a-4331-863d-e59f8a5b6f0b" />
+
# Task/Async profiling
Get a profile of the zed foreground executor and background executors. Check if
@@ -23,11 +63,17 @@ look at the results live.
## Setup/Building the importer:
+Download the importer
+[linux x86_64](https://zed-tracy-import-miniprofiler.nyc3.digitaloceanspaces.com/tracy-import-miniprofiler-linux-x86_64)
+[mac aarch64](https://zed-tracy-import-miniprofiler.nyc3.digitaloceanspaces.com/tracy-import-miniprofiler-macos-aarch64)
+
+### Alternative: Building it yourself
+
- Clone the repo at git@github.com:zed-industries/tracy.git on v0.12.2 branch
-- `cd profiler && mkdir build && cd build`
+- `cd import && mkdir build && cd build`
- Run cmake to generate build files: `cmake -G Ninja -DCMAKE_BUILD_TYPE=Release ..`
- Build the importer: `ninja`
-- Run the impoter on the trace file: `./tracy-import-miniprofiler /path/to/trace.miniprof /path/to/output.tracy`
+- Run the importer on the trace file: `./tracy-import-miniprofiler /path/to/trace.miniprof /path/to/output.tracy`
- Open the trace in tracy:
- If you're on windows download the v0.12.2 version from the releases on the upstream repo
- If you're on other platforms open it on the website: https://tracy.nereid.pl/ (the version might mismatch so your luck might vary, we need to host our own ideally..)
@@ -174,14 +174,38 @@ When opening a remote project there are three relevant settings locations:
Both the local Zed and the server Zed read the project settings, but they are not aware of the other's main `settings.json`.
-Depending on the kind of setting you want to make, which settings file you should use:
+Which settings file you should use depends on the kind of setting you want to make:
- Project settings should be used for things that affect the project: indentation settings, which formatter / language server to use, etc.
-- Server settings should be used for things that affect the server: paths to language servers, etc.
+- Server settings should be used for things that affect the server: paths to language servers, proxy settings, etc.
- Local settings should be used for things that affect the UI: font size, etc.
In addition any extensions you have installed locally will be propagated to the remote server. This means that language servers, etc. will run correctly.
+## Proxy Configuration
+
+The remote server will not use your local machine's proxy configuration because they may be under different network policies. If your remote server requires a proxy to access the internet, you must configure it on the remote server itself.
+
+In most cases, your remote server will already have proxy environment variables configured. Zed will automatically use them when downloading language servers, communicating with LLM models, etc.
+
+If needed, you can set these environment variables in the server's shell configuration (e.g., `~/.bashrc`):
+
+```bash
+export http_proxy="http://proxy.example.com:8080"
+export https_proxy="http://proxy.example.com:8080"
+export no_proxy="localhost,127.0.0.1"
+```
+
+Alternatively, you can configure the proxy in the remote machine's `~/.config/zed/settings.json` (Linux) or `~/.zed/settings.json` (macOS):
+
+```json
+{
+ "proxy": "http://proxy.example.com:8080"
+}
+```
+
+See the [proxy documentation](./configuring-zed.md#network-proxy) for supported proxy types and additional configuration options.
+
## Initializing the remote server
Once you provide the SSH options, Zed shells out to `ssh` on your local machine to create a ControlMaster connection with the options you provide.
@@ -0,0 +1,46 @@
+# Tab Switcher
+
+The Tab Switcher provides a quick way to navigate between open tabs in Zed. It
+displays a list of your open tabs sorted by recent usage, making it easy to jump
+back to whatever you were just working on.
+
+
+
+## Quick Switching
+
+When the Tab Switcher is opened using {#kb tab_switcher::Toggle}, instead of
+running the {#action tab_switcher::Toggle} from the command palette, it'll stay
+active as long as the <kbd class="keybinding">ctrl</kbd> key is held down.
+
+While holding down <kbd class="keybinding">ctrl</kbd>, each subsequent <kbd
+class="keybinding">tab</kbd> press cycles to the next item (<kbd
+class="keybinding">shift</kbd> to cycle backwards) and, when <kbd
+class="keybinding">ctrl</kbd> is released, the selected item is confirmed and
+the switcher is closed.
+
+## Opening the Tab Switcher
+
+The Tab Switcher can also be opened with either {#action tab_switcher::Toggle}
+or {#action tab_switcher::ToggleAll}. Using {#kb tab_switcher::Toggle} will show
+only the tabs for the current pane, while {#kb tab_switcher::ToggleAll} shows
+all tabs for all panes.
+
+While the Tab Switcher is open, you can:
+
+- Press {#kb menu::SelectNext} to move to the next tab in the list
+- Press {#kb menu::SelectPrevious} to move to the previous tab
+- Press <kbd class="keybinding">enter</kbd> to confirm the selected tab and close the switcher
+- Press <kbd class="keybinding">escape</kbd> to close the switcher and return to the original tab from which
+ the switcher was opened
+- Press {#kb tab_switcher::CloseSelectedItem} to close the currently selected tab
+
+As you navigate through the list, Zed will update the pane's active item to
+match the selected tab.
+
+## Action Reference
+
+| Action | Description |
+| ----------------------------------------- | ------------------------------------------------- |
+| {#action tab_switcher::Toggle} | Open the Tab Switcher for the current pane |
+| {#action tab_switcher::ToggleAll} | Open the Tab Switcher showing tabs from all panes |
+| {#action tab_switcher::CloseSelectedItem} | Close the selected tab in the Tab Switcher |
@@ -177,6 +177,7 @@ let
ZED_UPDATE_EXPLANATION = "Zed has been installed using Nix. Auto-updates have thus been disabled.";
RELEASE_VERSION = version;
LK_CUSTOM_WEBRTC = livekit-libwebrtc;
+ PROTOC="${protobuf}/bin/protoc";
CARGO_PROFILE = profile;
# need to handle some profiles specially https://github.com/rust-lang/cargo/issues/11053