From 3f1319162af7f5b6b6d6a3db51cc07a258c45e95 Mon Sep 17 00:00:00 2001 From: Bennet Fenner Date: Fri, 17 Oct 2025 18:49:11 +0200 Subject: [PATCH] Remove agent1 code (#40495) Release Notes: - N/A --- .zed/settings.json | 2 +- Cargo.lock | 266 +- Cargo.toml | 8 +- assets/keymaps/default-linux.json | 4 +- assets/keymaps/default-macos.json | 4 +- assets/keymaps/default-windows.json | 4 +- clippy.toml | 2 +- crates/agent/Cargo.toml | 72 +- crates/agent/src/agent.rs | 1658 +++- crates/agent/src/agent_profile.rs | 341 - crates/agent/src/context_server_tool.rs | 140 - crates/{agent2 => agent}/src/db.rs | 124 +- .../src/edit_agent.rs | 0 .../src/edit_agent/create_file_parser.rs | 0 .../src/edit_agent/edit_parser.rs | 0 .../src/edit_agent/evals.rs | 80 +- .../fixtures/add_overwrite_test/before.rs | 0 .../fixtures/delete_run_git_blame/after.rs | 0 .../fixtures/delete_run_git_blame/before.rs | 0 .../disable_cursor_blinking/before.rs | 0 .../disable_cursor_blinking/possible-01.diff | 0 .../disable_cursor_blinking/possible-02.diff | 0 .../disable_cursor_blinking/possible-03.diff | 0 .../disable_cursor_blinking/possible-04.diff | 0 .../extract_handle_command_output/before.rs | 0 .../possible-01.diff | 0 .../possible-02.diff | 0 .../possible-03.diff | 0 .../possible-04.diff | 0 .../possible-05.diff | 0 .../possible-06.diff | 0 .../possible-07.diff | 0 .../possible-08.diff | 0 .../from_pixels_constructor/before.rs | 0 .../fixtures/translate_doc_comments/before.rs | 0 .../before.rs | 0 .../edit_agent/evals/fixtures/zode/prompt.md | 0 .../edit_agent/evals/fixtures/zode/react.py | 0 .../evals/fixtures/zode/react_test.py | 0 .../src/edit_agent/streaming_fuzzy_matcher.rs | 0 crates/{agent2 => agent}/src/history_store.rs | 104 +- crates/agent/src/legacy_thread.rs | 402 + .../src/native_agent_server.rs | 0 .../{assistant_tool => agent}/src/outline.rs | 158 +- .../src/prompts/stale_files_prompt_header.txt | 3 - crates/{agent2 => agent}/src/templates.rs | 0 .../src/templates/create_file_prompt.hbs | 0 .../src/templates/diff_judge.hbs | 0 .../edit_file_prompt_diff_fenced.hbs | 0 .../src/templates/edit_file_prompt_xml.hbs | 0 .../src/templates/system_prompt.hbs | 0 crates/{agent2 => agent}/src/tests/mod.rs | 20 +- .../{agent2 => agent}/src/tests/test_tools.rs | 0 crates/agent/src/thread.rs | 7030 +++++------------ crates/agent/src/thread_store.rs | 1287 --- .../src/tool_schema.rs | 43 +- crates/agent/src/tool_use.rs | 575 -- crates/agent/src/tools.rs | 88 + .../src/tools/context_server_registry.rs | 13 +- .../src/tools/copy_path_tool.rs | 0 .../src/tools/create_directory_tool.rs | 0 .../src/tools/delete_path_tool.rs | 0 .../src/tools/diagnostics_tool.rs | 0 .../src/tools/edit_file_tool.rs | 30 +- .../{agent2 => agent}/src/tools/fetch_tool.rs | 0 .../src/tools/find_path_tool.rs | 0 .../{agent2 => agent}/src/tools/grep_tool.rs | 0 .../src/tools/list_directory_tool.rs | 0 .../src/tools/move_path_tool.rs | 0 .../{agent2 => agent}/src/tools/now_tool.rs | 0 .../{agent2 => agent}/src/tools/open_tool.rs | 0 .../src/tools/read_file_tool.rs | 3 +- .../src/tools/terminal_tool.rs | 0 .../src/tools/thinking_tool.rs | 0 .../src/tools/web_search_tool.rs | 0 crates/agent2/Cargo.toml | 102 - crates/agent2/LICENSE-GPL | 1 - crates/agent2/src/agent.rs | 1588 ---- crates/agent2/src/agent2.rs | 19 - crates/agent2/src/thread.rs | 2663 ------- crates/agent2/src/tool_schema.rs | 43 - crates/agent2/src/tools.rs | 60 - crates/agent_settings/src/agent_settings.rs | 5 +- .../summarize_thread_detailed_prompt.txt | 0 .../src/prompts/summarize_thread_prompt.txt | 0 crates/agent_ui/Cargo.toml | 5 +- .../agent_ui/src/acp/completion_provider.rs | 39 +- crates/agent_ui/src/acp/entry_view_state.rs | 4 +- crates/agent_ui/src/acp/message_editor.rs | 16 +- crates/agent_ui/src/acp/thread_history.rs | 9 +- crates/agent_ui/src/acp/thread_view.rs | 16 +- crates/agent_ui/src/agent_configuration.rs | 58 +- .../configure_context_server_tools_modal.rs | 37 +- .../manage_profiles_modal.rs | 28 +- .../src/agent_configuration/tool_picker.rs | 100 +- crates/agent_ui/src/agent_panel.rs | 179 +- crates/agent_ui/src/agent_ui.rs | 13 +- crates/agent_ui/src/buffer_codegen.rs | 13 +- crates/{agent => agent_ui}/src/context.rs | 154 +- crates/agent_ui/src/context_picker.rs | 149 +- .../src/context_picker/completion_provider.rs | 126 +- .../context_picker/fetch_context_picker.rs | 3 +- .../src/context_picker/file_context_picker.rs | 6 +- .../context_picker/rules_context_picker.rs | 24 +- .../context_picker/symbol_context_picker.rs | 6 +- .../context_picker/thread_context_picker.rs | 219 +- .../{agent => agent_ui}/src/context_store.rs | 107 +- crates/agent_ui/src/context_strip.rs | 28 +- crates/agent_ui/src/inline_assistant.rs | 35 +- crates/agent_ui/src/inline_prompt_editor.rs | 23 +- crates/agent_ui/src/message_editor.rs | 38 +- .../agent_ui/src/terminal_inline_assistant.rs | 20 +- crates/agent_ui/src/ui/context_pill.rs | 12 +- crates/assistant_tool/Cargo.toml | 50 - crates/assistant_tool/LICENSE-GPL | 1 - crates/assistant_tool/src/assistant_tool.rs | 269 - crates/assistant_tool/src/tool_registry.rs | 74 - crates/assistant_tool/src/tool_working_set.rs | 415 - crates/assistant_tools/Cargo.toml | 92 - crates/assistant_tools/LICENSE-GPL | 1 - crates/assistant_tools/src/assistant_tools.rs | 167 - crates/assistant_tools/src/copy_path_tool.rs | 123 - .../src/copy_path_tool/description.md | 6 - .../src/create_directory_tool.rs | 100 - .../src/create_directory_tool/description.md | 3 - .../assistant_tools/src/delete_path_tool.rs | 144 - .../src/delete_path_tool/description.md | 1 - .../assistant_tools/src/diagnostics_tool.rs | 171 - .../src/diagnostics_tool/description.md | 21 - crates/assistant_tools/src/edit_file_tool.rs | 2423 ------ .../src/edit_file_tool/description.md | 8 - crates/assistant_tools/src/fetch_tool.rs | 178 - .../src/fetch_tool/description.md | 1 - crates/assistant_tools/src/find_path_tool.rs | 472 -- .../src/find_path_tool/description.md | 7 - crates/assistant_tools/src/grep_tool.rs | 1308 --- .../src/grep_tool/description.md | 9 - .../src/list_directory_tool.rs | 869 -- .../src/list_directory_tool/description.md | 1 - crates/assistant_tools/src/move_path_tool.rs | 132 - .../src/move_path_tool/description.md | 5 - crates/assistant_tools/src/now_tool.rs | 84 - crates/assistant_tools/src/open_tool.rs | 170 - .../src/open_tool/description.md | 9 - .../src/project_notifications_tool.rs | 360 - .../project_notifications_tool/description.md | 3 - .../prompt_header.txt | 3 - crates/assistant_tools/src/read_file_tool.rs | 1190 --- .../src/read_file_tool/description.md | 3 - crates/assistant_tools/src/schema.rs | 60 - crates/assistant_tools/src/templates.rs | 32 - crates/assistant_tools/src/terminal_tool.rs | 883 --- .../src/terminal_tool/description.md | 11 - crates/assistant_tools/src/thinking_tool.rs | 69 - .../src/thinking_tool/description.md | 1 - crates/assistant_tools/src/ui.rs | 5 - .../src/ui/tool_call_card_header.rs | 131 - .../src/ui/tool_output_preview.rs | 115 - crates/assistant_tools/src/web_search_tool.rs | 327 - crates/eval/Cargo.toml | 4 +- crates/eval/src/eval.rs | 1 - crates/eval/src/example.rs | 4 +- .../eval/src/examples/comment_translation.rs | 2 +- crates/eval/src/examples/file_search.rs | 2 +- .../src/examples/grep_params_escapement.rs | 1 - crates/eval/src/examples/overwrite_file.rs | 1 - crates/eval/src/examples/planets.rs | 7 +- crates/eval/src/instance.rs | 3 +- crates/language_model/Cargo.toml | 1 - crates/language_model/src/language_model.rs | 8 +- crates/remote_server/Cargo.toml | 3 +- .../remote_server/src/remote_editing_tests.rs | 42 +- crates/zed/Cargo.toml | 2 - crates/zed/src/main.rs | 1 - script/danger/dangerfile.ts | 11 +- 175 files changed, 5271 insertions(+), 23738 deletions(-) delete mode 100644 crates/agent/src/agent_profile.rs delete mode 100644 crates/agent/src/context_server_tool.rs rename crates/{agent2 => agent}/src/db.rs (78%) rename crates/{assistant_tools => agent}/src/edit_agent.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/create_file_parser.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/edit_parser.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals.rs (97%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/add_overwrite_test/before.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/delete_run_git_blame/after.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/delete_run_git_blame/before.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/disable_cursor_blinking/before.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/extract_handle_command_output/before.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-01.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-02.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-03.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-04.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-05.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-06.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-07.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-08.diff (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/from_pixels_constructor/before.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/translate_doc_comments/before.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/zode/prompt.md (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/zode/react.py (100%) rename crates/{assistant_tools => agent}/src/edit_agent/evals/fixtures/zode/react_test.py (100%) rename crates/{assistant_tools => agent}/src/edit_agent/streaming_fuzzy_matcher.rs (100%) rename crates/{agent2 => agent}/src/history_store.rs (80%) create mode 100644 crates/agent/src/legacy_thread.rs rename crates/{agent2 => agent}/src/native_agent_server.rs (100%) rename crates/{assistant_tool => agent}/src/outline.rs (76%) delete mode 100644 crates/agent/src/prompts/stale_files_prompt_header.txt rename crates/{agent2 => agent}/src/templates.rs (100%) rename crates/{assistant_tools => agent}/src/templates/create_file_prompt.hbs (100%) rename crates/{assistant_tools => agent}/src/templates/diff_judge.hbs (100%) rename crates/{assistant_tools => agent}/src/templates/edit_file_prompt_diff_fenced.hbs (100%) rename crates/{assistant_tools => agent}/src/templates/edit_file_prompt_xml.hbs (100%) rename crates/{agent2 => agent}/src/templates/system_prompt.hbs (100%) rename crates/{agent2 => agent}/src/tests/mod.rs (99%) rename crates/{agent2 => agent}/src/tests/test_tools.rs (100%) delete mode 100644 crates/agent/src/thread_store.rs rename crates/{assistant_tool => agent}/src/tool_schema.rs (85%) delete mode 100644 crates/agent/src/tool_use.rs create mode 100644 crates/agent/src/tools.rs rename crates/{agent2 => agent}/src/tools/context_server_registry.rs (95%) rename crates/{agent2 => agent}/src/tools/copy_path_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/create_directory_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/delete_path_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/diagnostics_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/edit_file_tool.rs (98%) rename crates/{agent2 => agent}/src/tools/fetch_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/find_path_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/grep_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/list_directory_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/move_path_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/now_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/open_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/read_file_tool.rs (99%) rename crates/{agent2 => agent}/src/tools/terminal_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/thinking_tool.rs (100%) rename crates/{agent2 => agent}/src/tools/web_search_tool.rs (100%) delete mode 100644 crates/agent2/Cargo.toml delete mode 120000 crates/agent2/LICENSE-GPL delete mode 100644 crates/agent2/src/agent.rs delete mode 100644 crates/agent2/src/agent2.rs delete mode 100644 crates/agent2/src/thread.rs delete mode 100644 crates/agent2/src/tool_schema.rs delete mode 100644 crates/agent2/src/tools.rs rename crates/{agent => agent_settings}/src/prompts/summarize_thread_detailed_prompt.txt (100%) rename crates/{agent => agent_settings}/src/prompts/summarize_thread_prompt.txt (100%) rename crates/{agent => agent_ui}/src/context.rs (90%) rename crates/{agent => agent_ui}/src/context_store.rs (87%) delete mode 100644 crates/assistant_tool/Cargo.toml delete mode 120000 crates/assistant_tool/LICENSE-GPL delete mode 100644 crates/assistant_tool/src/assistant_tool.rs delete mode 100644 crates/assistant_tool/src/tool_registry.rs delete mode 100644 crates/assistant_tool/src/tool_working_set.rs delete mode 100644 crates/assistant_tools/Cargo.toml delete mode 120000 crates/assistant_tools/LICENSE-GPL delete mode 100644 crates/assistant_tools/src/assistant_tools.rs delete mode 100644 crates/assistant_tools/src/copy_path_tool.rs delete mode 100644 crates/assistant_tools/src/copy_path_tool/description.md delete mode 100644 crates/assistant_tools/src/create_directory_tool.rs delete mode 100644 crates/assistant_tools/src/create_directory_tool/description.md delete mode 100644 crates/assistant_tools/src/delete_path_tool.rs delete mode 100644 crates/assistant_tools/src/delete_path_tool/description.md delete mode 100644 crates/assistant_tools/src/diagnostics_tool.rs delete mode 100644 crates/assistant_tools/src/diagnostics_tool/description.md delete mode 100644 crates/assistant_tools/src/edit_file_tool.rs delete mode 100644 crates/assistant_tools/src/edit_file_tool/description.md delete mode 100644 crates/assistant_tools/src/fetch_tool.rs delete mode 100644 crates/assistant_tools/src/fetch_tool/description.md delete mode 100644 crates/assistant_tools/src/find_path_tool.rs delete mode 100644 crates/assistant_tools/src/find_path_tool/description.md delete mode 100644 crates/assistant_tools/src/grep_tool.rs delete mode 100644 crates/assistant_tools/src/grep_tool/description.md delete mode 100644 crates/assistant_tools/src/list_directory_tool.rs delete mode 100644 crates/assistant_tools/src/list_directory_tool/description.md delete mode 100644 crates/assistant_tools/src/move_path_tool.rs delete mode 100644 crates/assistant_tools/src/move_path_tool/description.md delete mode 100644 crates/assistant_tools/src/now_tool.rs delete mode 100644 crates/assistant_tools/src/open_tool.rs delete mode 100644 crates/assistant_tools/src/open_tool/description.md delete mode 100644 crates/assistant_tools/src/project_notifications_tool.rs delete mode 100644 crates/assistant_tools/src/project_notifications_tool/description.md delete mode 100644 crates/assistant_tools/src/project_notifications_tool/prompt_header.txt delete mode 100644 crates/assistant_tools/src/read_file_tool.rs delete mode 100644 crates/assistant_tools/src/read_file_tool/description.md delete mode 100644 crates/assistant_tools/src/schema.rs delete mode 100644 crates/assistant_tools/src/templates.rs delete mode 100644 crates/assistant_tools/src/terminal_tool.rs delete mode 100644 crates/assistant_tools/src/terminal_tool/description.md delete mode 100644 crates/assistant_tools/src/thinking_tool.rs delete mode 100644 crates/assistant_tools/src/thinking_tool/description.md delete mode 100644 crates/assistant_tools/src/ui.rs delete mode 100644 crates/assistant_tools/src/ui/tool_call_card_header.rs delete mode 100644 crates/assistant_tools/src/ui/tool_output_preview.rs delete mode 100644 crates/assistant_tools/src/web_search_tool.rs diff --git a/.zed/settings.json b/.zed/settings.json index 68e05a426f2474cb663aa5ff843905f375170e0f..2760be95819e9340acf55f60616a9c22105ff52a 100644 --- a/.zed/settings.json +++ b/.zed/settings.json @@ -48,7 +48,7 @@ "remove_trailing_whitespace_on_save": true, "ensure_final_newline_on_save": true, "file_scan_exclusions": [ - "crates/assistant_tools/src/edit_agent/evals/fixtures", + "crates/agent/src/edit_agent/evals/fixtures", "crates/eval/worktrees/", "crates/eval/repos/", "**/.git", diff --git a/Cargo.lock b/Cargo.lock index 43c9c672eb83da9b02f6d885508189b55c5f0080..bb3b71a3ee45e052670ef3e67c877833253c76b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -139,90 +139,14 @@ dependencies = [ [[package]] name = "agent" version = "0.1.0" -dependencies = [ - "action_log", - "agent_settings", - "anyhow", - "assistant_context", - "assistant_tool", - "assistant_tools", - "chrono", - "client", - "cloud_llm_client", - "collections", - "component", - "context_server", - "convert_case 0.8.0", - "fs", - "futures 0.3.31", - "git", - "gpui", - "heed", - "http_client", - "icons", - "indoc", - "language", - "language_model", - "log", - "parking_lot", - "paths", - "postage", - "pretty_assertions", - "project", - "prompt_store", - "rand 0.9.1", - "ref-cast", - "rope", - "schemars 1.0.1", - "serde", - "serde_json", - "settings", - "smol", - "sqlez", - "telemetry", - "text", - "theme", - "thiserror 2.0.12", - "time", - "util", - "uuid", - "workspace", - "workspace-hack", - "zed_env_vars", - "zstd 0.11.2+zstd.1.5.2", -] - -[[package]] -name = "agent-client-protocol" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3aaa2bd05a2401887945f8bfd70026e90bc3cf96c62ab9eba2779835bf21dc60" -dependencies = [ - "anyhow", - "async-broadcast", - "async-trait", - "futures 0.3.31", - "log", - "parking_lot", - "schemars 1.0.1", - "serde", - "serde_json", -] - -[[package]] -name = "agent2" -version = "0.1.0" dependencies = [ "acp_thread", "action_log", - "agent", "agent-client-protocol", "agent_servers", "agent_settings", "anyhow", "assistant_context", - "assistant_tool", - "assistant_tools", "chrono", "client", "clock", @@ -231,6 +155,7 @@ dependencies = [ "context_server", "ctor", "db", + "derive_more", "editor", "env_logger 0.11.8", "fs", @@ -254,14 +179,19 @@ dependencies = [ "pretty_assertions", "project", "prompt_store", + "rand 0.9.1", + "regex", "reqwest_client", "rust-embed", "schemars 1.0.1", "serde", "serde_json", "settings", + "smallvec", "smol", "sqlez", + "streaming_diff", + "strsim", "task", "telemetry", "tempfile", @@ -283,6 +213,23 @@ dependencies = [ "zstd 0.11.2+zstd.1.5.2", ] +[[package]] +name = "agent-client-protocol" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3aaa2bd05a2401887945f8bfd70026e90bc3cf96c62ab9eba2779835bf21dc60" +dependencies = [ + "anyhow", + "async-broadcast", + "async-trait", + "futures 0.3.31", + "log", + "parking_lot", + "schemars 1.0.1", + "serde", + "serde_json", +] + [[package]] name = "agent_servers" version = "0.1.0" @@ -356,7 +303,6 @@ dependencies = [ "action_log", "agent", "agent-client-protocol", - "agent2", "agent_servers", "agent_settings", "ai_onboarding", @@ -365,8 +311,6 @@ dependencies = [ "assistant_context", "assistant_slash_command", "assistant_slash_commands", - "assistant_tool", - "assistant_tools", "audio", "buffer_diff", "chrono", @@ -411,6 +355,7 @@ dependencies = [ "prompt_store", "proto", "rand 0.9.1", + "ref-cast", "release_channel", "rope", "rules_library", @@ -965,106 +910,6 @@ dependencies = [ "zlog", ] -[[package]] -name = "assistant_tool" -version = "0.1.0" -dependencies = [ - "action_log", - "anyhow", - "buffer_diff", - "clock", - "collections", - "ctor", - "derive_more", - "gpui", - "icons", - "indoc", - "language", - "language_model", - "log", - "parking_lot", - "pretty_assertions", - "project", - "rand 0.9.1", - "regex", - "serde", - "serde_json", - "settings", - "text", - "util", - "workspace", - "workspace-hack", - "zlog", -] - -[[package]] -name = "assistant_tools" -version = "0.1.0" -dependencies = [ - "action_log", - "agent_settings", - "anyhow", - "assistant_tool", - "buffer_diff", - "chrono", - "client", - "clock", - "cloud_llm_client", - "collections", - "component", - "derive_more", - "diffy", - "editor", - "feature_flags", - "fs", - "futures 0.3.31", - "gpui", - "gpui_tokio", - "handlebars 4.5.0", - "html_to_markdown", - "http_client", - "indoc", - "itertools 0.14.0", - "language", - "language_model", - "language_models", - "log", - "lsp", - "markdown", - "open", - "paths", - "portable-pty", - "pretty_assertions", - "project", - "prompt_store", - "rand 0.9.1", - "regex", - "reqwest_client", - "rust-embed", - "schemars 1.0.1", - "serde", - "serde_json", - "settings", - "smallvec", - "smol", - "streaming_diff", - "strsim", - "task", - "tempfile", - "terminal", - "terminal_view", - "theme", - "tree-sitter-rust", - "ui", - "unindent", - "util", - "watch", - "web_search", - "workspace", - "workspace-hack", - "zlog", -] - [[package]] name = "async-attributes" version = "1.1.2" @@ -5819,63 +5664,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "eval" -version = "0.1.0" -dependencies = [ - "agent", - "agent_settings", - "agent_ui", - "anyhow", - "assistant_tool", - "assistant_tools", - "async-trait", - "buffer_diff", - "chrono", - "clap", - "client", - "cloud_llm_client", - "collections", - "debug_adapter_extension", - "dirs 4.0.0", - "dotenvy", - "env_logger 0.11.8", - "extension", - "fs", - "futures 0.3.31", - "gpui", - "gpui_tokio", - "handlebars 4.5.0", - "language", - "language_extension", - "language_model", - "language_models", - "languages", - "markdown", - "node_runtime", - "pathdiff", - "paths", - "pretty_assertions", - "project", - "prompt_store", - "regex", - "release_channel", - "reqwest_client", - "serde", - "serde_json", - "settings", - "shellexpand 2.1.2", - "smol", - "telemetry", - "terminal_view", - "toml 0.8.20", - "unindent", - "util", - "uuid", - "watch", - "workspace-hack", -] - [[package]] name = "event-listener" version = "2.5.3" @@ -8987,7 +8775,6 @@ dependencies = [ "open_router", "parking_lot", "proto", - "schemars 1.0.1", "serde", "serde_json", "settings", @@ -14006,10 +13793,9 @@ name = "remote_server" version = "0.1.0" dependencies = [ "action_log", + "agent", "anyhow", "askpass", - "assistant_tool", - "assistant_tools", "cargo_toml", "clap", "client", @@ -21242,14 +21028,12 @@ version = "0.210.0" dependencies = [ "acp_tools", "activity_indicator", - "agent", "agent_settings", "agent_ui", "anyhow", "ashpd 0.11.0", "askpass", "assets", - "assistant_tools", "audio", "auto_update", "auto_update_ui", diff --git a/Cargo.toml b/Cargo.toml index 3bc4123a0967a1f3bdc5d4d48d37ffedfbf372ce..33f3fa2ed3fa912e33bc24fa9303e3c2b4790dad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = [ "crates/action_log", "crates/activity_indicator", "crates/agent", - "crates/agent2", "crates/agent_servers", "crates/agent_settings", "crates/agent_ui", @@ -17,8 +16,6 @@ members = [ "crates/assistant_context", "crates/assistant_slash_command", "crates/assistant_slash_commands", - "crates/assistant_tool", - "crates/assistant_tools", "crates/audio", "crates/auto_update", "crates/auto_update_helper", @@ -61,7 +58,7 @@ members = [ "crates/edit_prediction_context", "crates/zeta2_tools", "crates/editor", - "crates/eval", + # "crates/eval", "crates/explorer_command_injector", "crates/extension", "crates/extension_api", @@ -240,7 +237,6 @@ acp_tools = { path = "crates/acp_tools" } acp_thread = { path = "crates/acp_thread" } action_log = { path = "crates/action_log" } agent = { path = "crates/agent" } -agent2 = { path = "crates/agent2" } activity_indicator = { path = "crates/activity_indicator" } agent_ui = { path = "crates/agent_ui" } agent_settings = { path = "crates/agent_settings" } @@ -253,8 +249,6 @@ assets = { path = "crates/assets" } assistant_context = { path = "crates/assistant_context" } assistant_slash_command = { path = "crates/assistant_slash_command" } assistant_slash_commands = { path = "crates/assistant_slash_commands" } -assistant_tool = { path = "crates/assistant_tool" } -assistant_tools = { path = "crates/assistant_tools" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } auto_update_helper = { path = "crates/auto_update_helper" } diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 6f3b0ced8feaf5ca9ca3873c47b446117cedb6e8..ff5d7533f412872908d52228590fa3afe45a02d0 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -269,14 +269,14 @@ } }, { - "context": "AgentPanel && prompt_editor", + "context": "AgentPanel && text_thread", "bindings": { "ctrl-n": "agent::NewTextThread", "ctrl-alt-t": "agent::NewThread" } }, { - "context": "AgentPanel && external_agent_thread", + "context": "AgentPanel && acp_thread", "use_key_equivalents": true, "bindings": { "ctrl-n": "agent::NewExternalAgentThread", diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index ffa29b719f74e76c19dc7476c9c6b9643791a22f..9b20c267feeed5068b05cb08e0755d3faab75c96 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -307,7 +307,7 @@ } }, { - "context": "AgentPanel && prompt_editor", + "context": "AgentPanel && text_thread", "use_key_equivalents": true, "bindings": { "cmd-n": "agent::NewTextThread", @@ -315,7 +315,7 @@ } }, { - "context": "AgentPanel && external_agent_thread", + "context": "AgentPanel && acp_thread", "use_key_equivalents": true, "bindings": { "cmd-n": "agent::NewExternalAgentThread", diff --git a/assets/keymaps/default-windows.json b/assets/keymaps/default-windows.json index 4daacca4e2eaa224dacd3880fef2d046139587dd..87e1c350dc10c1f47ceb260b4ce2a03a032b0996 100644 --- a/assets/keymaps/default-windows.json +++ b/assets/keymaps/default-windows.json @@ -270,7 +270,7 @@ } }, { - "context": "AgentPanel && prompt_editor", + "context": "AgentPanel && text_thread", "use_key_equivalents": true, "bindings": { "ctrl-n": "agent::NewTextThread", @@ -278,7 +278,7 @@ } }, { - "context": "AgentPanel && external_agent_thread", + "context": "AgentPanel && acp_thread", "use_key_equivalents": true, "bindings": { "ctrl-n": "agent::NewExternalAgentThread", diff --git a/clippy.toml b/clippy.toml index 0976e2eba301b9bf679baecc232c6ae7084fd5e8..4e9f2de8585e74afe76840c59306ad8ed87fd947 100644 --- a/clippy.toml +++ b/clippy.toml @@ -3,7 +3,7 @@ avoid-breaking-exported-api = false ignore-interior-mutability = [ # Suppresses clippy::mutable_key_type, which is a false positive as the Eq # and Hash impls do not use fields with interior mutability. - "agent::context::AgentContextKey" + "agent_ui::context::AgentContextKey" ] disallowed-methods = [ { path = "std::process::Command::spawn", reason = "Spawning `std::process::Command` can block the current thread for an unknown duration", replacement = "smol::process::Command::spawn" }, diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index ebd043d0c3c61bed287507e303637035a5b8156f..5fb8f915b8f19d5adf6132f0fbefda0f5081bbae 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -5,74 +5,100 @@ edition.workspace = true publish.workspace = true license = "GPL-3.0-or-later" -[lints] -workspace = true - [lib] path = "src/agent.rs" -doctest = false [features] -test-support = [ - "gpui/test-support", - "language/test-support", -] +test-support = ["db/test-support"] +e2e = [] + +[lints] +workspace = true [dependencies] +acp_thread.workspace = true action_log.workspace = true +agent-client-protocol.workspace = true +agent_servers.workspace = true agent_settings.workspace = true anyhow.workspace = true assistant_context.workspace = true -assistant_tool.workspace = true chrono.workspace = true client.workspace = true cloud_llm_client.workspace = true collections.workspace = true -component.workspace = true context_server.workspace = true -convert_case.workspace = true +db.workspace = true +derive_more.workspace = true fs.workspace = true futures.workspace = true git.workspace = true gpui.workspace = true -heed.workspace = true +handlebars = { workspace = true, features = ["rust-embed"] } +html_to_markdown.workspace = true http_client.workspace = true -icons.workspace = true indoc.workspace = true +itertools.workspace = true language.workspace = true language_model.workspace = true +language_models.workspace = true log.workspace = true +open.workspace = true +parking_lot.workspace = true paths.workspace = true -postage.workspace = true project.workspace = true prompt_store.workspace = true -ref-cast.workspace = true -rope.workspace = true +regex.workspace = true +rust-embed.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +smallvec.workspace = true smol.workspace = true sqlez.workspace = true +streaming_diff.workspace = true +strsim.workspace = true +task.workspace = true telemetry.workspace = true +terminal.workspace = true text.workspace = true -theme.workspace = true thiserror.workspace = true -time.workspace = true +ui.workspace = true util.workspace = true uuid.workspace = true +watch.workspace = true +web_search.workspace = true workspace-hack.workspace = true zed_env_vars.workspace = true zstd.workspace = true [dev-dependencies] -assistant_tools.workspace = true +agent_servers = { workspace = true, "features" = ["test-support"] } +assistant_context = { workspace = true, "features" = ["test-support"] } +client = { workspace = true, "features" = ["test-support"] } +clock = { workspace = true, "features" = ["test-support"] } +context_server = { workspace = true, "features" = ["test-support"] } +ctor.workspace = true +db = { workspace = true, "features" = ["test-support"] } +editor = { workspace = true, "features" = ["test-support"] } +env_logger.workspace = true +fs = { workspace = true, "features" = ["test-support"] } +git = { workspace = true, "features" = ["test-support"] } gpui = { workspace = true, "features" = ["test-support"] } -indoc.workspace = true +gpui_tokio.workspace = true language = { workspace = true, "features" = ["test-support"] } language_model = { workspace = true, "features" = ["test-support"] } -parking_lot.workspace = true +lsp = { workspace = true, "features" = ["test-support"] } pretty_assertions.workspace = true -project = { workspace = true, features = ["test-support"] } -workspace = { workspace = true, features = ["test-support"] } +project = { workspace = true, "features" = ["test-support"] } rand.workspace = true +reqwest_client.workspace = true +settings = { workspace = true, "features" = ["test-support"] } +tempfile.workspace = true +terminal = { workspace = true, "features" = ["test-support"] } +theme = { workspace = true, "features" = ["test-support"] } +tree-sitter-rust.workspace = true +unindent = { workspace = true } +worktree = { workspace = true, "features" = ["test-support"] } +zlog.workspace = true diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 056c380e78de576fd7b0c065e3e5de631fdc37bb..32dec9f723a6776fd14def29be3be4eb21afa72d 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -1,21 +1,1645 @@ -pub mod agent_profile; -pub mod context; -pub mod context_server_tool; -pub mod context_store; -pub mod thread; -pub mod thread_store; -pub mod tool_use; - -pub use context::{AgentContext, ContextId, ContextLoadResult}; -pub use context_store::ContextStore; +mod db; +mod edit_agent; +mod history_store; +mod legacy_thread; +mod native_agent_server; +pub mod outline; +mod templates; +mod thread; +mod tool_schema; +mod tools; + +#[cfg(test)] +mod tests; + +pub use db::*; +pub use history_store::*; +pub use native_agent_server::NativeAgentServer; +pub use templates::*; +pub use thread::*; +pub use tools::*; + +use acp_thread::{AcpThread, AgentModelSelector}; +use agent_client_protocol as acp; +use anyhow::{Context as _, Result, anyhow}; +use chrono::{DateTime, Utc}; +use collections::{HashSet, IndexMap}; use fs::Fs; -use std::sync::Arc; -pub use thread::{ - LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, Thread, ThreadError, - ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio, +use futures::channel::{mpsc, oneshot}; +use futures::future::Shared; +use futures::{StreamExt, future}; +use gpui::{ + App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, +}; +use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry}; +use project::{Project, ProjectItem, ProjectPath, Worktree}; +use prompt_store::{ + ProjectContext, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, }; -pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore}; +use serde::{Deserialize, Serialize}; +use settings::{LanguageModelSelection, update_settings_file}; +use std::any::Any; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::rc::Rc; +use std::sync::Arc; +use util::ResultExt; +use util::rel_path::RelPath; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ProjectSnapshot { + pub worktree_snapshots: Vec, + pub timestamp: DateTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct WorktreeSnapshot { + pub worktree_path: String, + pub git_state: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct GitState { + pub remote_url: Option, + pub head_sha: Option, + pub current_branch: Option, + pub diff: Option, +} + +const RULES_FILE_NAMES: [&str; 9] = [ + ".rules", + ".cursorrules", + ".windsurfrules", + ".clinerules", + ".github/copilot-instructions.md", + "CLAUDE.md", + "AGENT.md", + "AGENTS.md", + "GEMINI.md", +]; + +pub struct RulesLoadingError { + pub message: SharedString, +} + +/// Holds both the internal Thread and the AcpThread for a session +struct Session { + /// The internal thread that processes messages + thread: Entity, + /// The ACP thread that handles protocol communication + acp_thread: WeakEntity, + pending_save: Task<()>, + _subscriptions: Vec, +} + +pub struct LanguageModels { + /// Access language model by ID + models: HashMap>, + /// Cached list for returning language model information + model_list: acp_thread::AgentModelList, + refresh_models_rx: watch::Receiver<()>, + refresh_models_tx: watch::Sender<()>, + _authenticate_all_providers_task: Task<()>, +} + +impl LanguageModels { + fn new(cx: &mut App) -> Self { + let (refresh_models_tx, refresh_models_rx) = watch::channel(()); + + let mut this = Self { + models: HashMap::default(), + model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()), + refresh_models_rx, + refresh_models_tx, + _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx), + }; + this.refresh_list(cx); + this + } + + fn refresh_list(&mut self, cx: &App) { + let providers = LanguageModelRegistry::global(cx) + .read(cx) + .providers() + .into_iter() + .filter(|provider| provider.is_authenticated(cx)) + .collect::>(); + + let mut language_model_list = IndexMap::default(); + let mut recommended_models = HashSet::default(); + + let mut recommended = Vec::new(); + for provider in &providers { + for model in provider.recommended_models(cx) { + recommended_models.insert((model.provider_id(), model.id())); + recommended.push(Self::map_language_model_to_info(&model, provider)); + } + } + if !recommended.is_empty() { + language_model_list.insert( + acp_thread::AgentModelGroupName("Recommended".into()), + recommended, + ); + } + + let mut models = HashMap::default(); + for provider in providers { + let mut provider_models = Vec::new(); + for model in provider.provided_models(cx) { + let model_info = Self::map_language_model_to_info(&model, &provider); + let model_id = model_info.id.clone(); + if !recommended_models.contains(&(model.provider_id(), model.id())) { + provider_models.push(model_info); + } + models.insert(model_id, model); + } + if !provider_models.is_empty() { + language_model_list.insert( + acp_thread::AgentModelGroupName(provider.name().0.clone()), + provider_models, + ); + } + } + + self.models = models; + self.model_list = acp_thread::AgentModelList::Grouped(language_model_list); + self.refresh_models_tx.send(()).ok(); + } + + fn watch(&self) -> watch::Receiver<()> { + self.refresh_models_rx.clone() + } + + pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option> { + self.models.get(model_id).cloned() + } + + fn map_language_model_to_info( + model: &Arc, + provider: &Arc, + ) -> acp_thread::AgentModelInfo { + acp_thread::AgentModelInfo { + id: Self::model_id(model), + name: model.name().0, + description: None, + icon: Some(provider.icon()), + } + } + + fn model_id(model: &Arc) -> acp::ModelId { + acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into()) + } + + fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> { + let authenticate_all_providers = LanguageModelRegistry::global(cx) + .read(cx) + .providers() + .iter() + .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx))) + .collect::>(); + + cx.background_spawn(async move { + for (provider_id, provider_name, authenticate_task) in authenticate_all_providers { + if let Err(err) = authenticate_task.await { + match err { + language_model::AuthenticateError::CredentialsNotFound => { + // Since we're authenticating these providers in the + // background for the purposes of populating the + // language selector, we don't care about providers + // where the credentials are not found. + } + language_model::AuthenticateError::ConnectionRefused => { + // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures. + // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it. + // TODO: Better manage LM Studio auth logic to avoid these noisy failures. + } + _ => { + // Some providers have noisy failure states that we + // don't want to spam the logs with every time the + // language model selector is initialized. + // + // Ideally these should have more clear failure modes + // that we know are safe to ignore here, like what we do + // with `CredentialsNotFound` above. + match provider_id.0.as_ref() { + "lmstudio" | "ollama" => { + // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated". + // + // These fail noisily, so we don't log them. + } + "copilot_chat" => { + // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors. + } + _ => { + log::error!( + "Failed to authenticate provider: {}: {err}", + provider_name.0 + ); + } + } + } + } + } + } + }) + } +} + +pub struct NativeAgent { + /// Session ID -> Session mapping + sessions: HashMap, + history: Entity, + /// Shared project context for all threads + project_context: Entity, + project_context_needs_refresh: watch::Sender<()>, + _maintain_project_context: Task>, + context_server_registry: Entity, + /// Shared templates for all threads + templates: Arc, + /// Cached model information + models: LanguageModels, + project: Entity, + prompt_store: Option>, + fs: Arc, + _subscriptions: Vec, +} + +impl NativeAgent { + pub async fn new( + project: Entity, + history: Entity, + templates: Arc, + prompt_store: Option>, + fs: Arc, + cx: &mut AsyncApp, + ) -> Result> { + log::debug!("Creating new NativeAgent"); + + let project_context = cx + .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))? + .await; + + cx.new(|cx| { + let mut subscriptions = vec![ + cx.subscribe(&project, Self::handle_project_event), + cx.subscribe( + &LanguageModelRegistry::global(cx), + Self::handle_models_updated_event, + ), + ]; + if let Some(prompt_store) = prompt_store.as_ref() { + subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event)) + } + + let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = + watch::channel(()); + Self { + sessions: HashMap::new(), + history, + project_context: cx.new(|_| project_context), + project_context_needs_refresh: project_context_needs_refresh_tx, + _maintain_project_context: cx.spawn(async move |this, cx| { + Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await + }), + context_server_registry: cx.new(|cx| { + ContextServerRegistry::new(project.read(cx).context_server_store(), cx) + }), + templates, + models: LanguageModels::new(cx), + project, + prompt_store, + fs, + _subscriptions: subscriptions, + } + }) + } + + fn register_session( + &mut self, + thread_handle: Entity, + cx: &mut Context, + ) -> Entity { + let connection = Rc::new(NativeAgentConnection(cx.entity())); + + let thread = thread_handle.read(cx); + let session_id = thread.id().clone(); + let title = thread.title(); + let project = thread.project.clone(); + let action_log = thread.action_log.clone(); + let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone(); + let acp_thread = cx.new(|cx| { + acp_thread::AcpThread::new( + title, + connection, + project.clone(), + action_log.clone(), + session_id.clone(), + prompt_capabilities_rx, + cx, + ) + }); + + let registry = LanguageModelRegistry::read_global(cx); + let summarization_model = registry.thread_summary_model().map(|c| c.model); + + thread_handle.update(cx, |thread, cx| { + thread.set_summarization_model(summarization_model, cx); + thread.add_default_tools( + Rc::new(AcpThreadEnvironment { + acp_thread: acp_thread.downgrade(), + }) as _, + cx, + ) + }); + + let subscriptions = vec![ + cx.observe_release(&acp_thread, |this, acp_thread, _cx| { + this.sessions.remove(acp_thread.session_id()); + }), + cx.subscribe(&thread_handle, Self::handle_thread_title_updated), + cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated), + cx.observe(&thread_handle, move |this, thread, cx| { + this.save_thread(thread, cx) + }), + ]; + + self.sessions.insert( + session_id, + Session { + thread: thread_handle, + acp_thread: acp_thread.downgrade(), + _subscriptions: subscriptions, + pending_save: Task::ready(()), + }, + ); + acp_thread + } + + pub fn models(&self) -> &LanguageModels { + &self.models + } + + async fn maintain_project_context( + this: WeakEntity, + mut needs_refresh: watch::Receiver<()>, + cx: &mut AsyncApp, + ) -> Result<()> { + while needs_refresh.changed().await.is_ok() { + let project_context = this + .update(cx, |this, cx| { + Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx) + })? + .await; + this.update(cx, |this, cx| { + this.project_context = cx.new(|_| project_context); + })?; + } + + Ok(()) + } + + fn build_project_context( + project: &Entity, + prompt_store: Option<&Entity>, + cx: &mut App, + ) -> Task { + let worktrees = project.read(cx).visible_worktrees(cx).collect::>(); + let worktree_tasks = worktrees + .into_iter() + .map(|worktree| { + Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx) + }) + .collect::>(); + let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() { + prompt_store.read_with(cx, |prompt_store, cx| { + let prompts = prompt_store.default_prompt_metadata(); + let load_tasks = prompts.into_iter().map(|prompt_metadata| { + let contents = prompt_store.load(prompt_metadata.id, cx); + async move { (contents.await, prompt_metadata) } + }); + cx.background_spawn(future::join_all(load_tasks)) + }) + } else { + Task::ready(vec![]) + }; + + cx.spawn(async move |_cx| { + let (worktrees, default_user_rules) = + future::join(future::join_all(worktree_tasks), default_user_rules_task).await; + + let worktrees = worktrees + .into_iter() + .map(|(worktree, _rules_error)| { + // TODO: show error message + // if let Some(rules_error) = rules_error { + // this.update(cx, |_, cx| cx.emit(rules_error)).ok(); + // } + worktree + }) + .collect::>(); + + let default_user_rules = default_user_rules + .into_iter() + .flat_map(|(contents, prompt_metadata)| match contents { + Ok(contents) => Some(UserRulesContext { + uuid: match prompt_metadata.id { + prompt_store::PromptId::User { uuid } => uuid, + prompt_store::PromptId::EditWorkflow => return None, + }, + title: prompt_metadata.title.map(|title| title.to_string()), + contents, + }), + Err(_err) => { + // TODO: show error message + // this.update(cx, |_, cx| { + // cx.emit(RulesLoadingError { + // message: format!("{err:?}").into(), + // }); + // }) + // .ok(); + None + } + }) + .collect::>(); + + ProjectContext::new(worktrees, default_user_rules) + }) + } + + fn load_worktree_info_for_system_prompt( + worktree: Entity, + project: Entity, + cx: &mut App, + ) -> Task<(WorktreeContext, Option)> { + let tree = worktree.read(cx); + let root_name = tree.root_name_str().into(); + let abs_path = tree.abs_path(); + + let mut context = WorktreeContext { + root_name, + abs_path, + rules_file: None, + }; + + let rules_task = Self::load_worktree_rules_file(worktree, project, cx); + let Some(rules_task) = rules_task else { + return Task::ready((context, None)); + }; + + cx.spawn(async move |_| { + let (rules_file, rules_file_error) = match rules_task.await { + Ok(rules_file) => (Some(rules_file), None), + Err(err) => ( + None, + Some(RulesLoadingError { + message: format!("{err}").into(), + }), + ), + }; + context.rules_file = rules_file; + (context, rules_file_error) + }) + } + + fn load_worktree_rules_file( + worktree: Entity, + project: Entity, + cx: &mut App, + ) -> Option>> { + let worktree = worktree.read(cx); + let worktree_id = worktree.id(); + let selected_rules_file = RULES_FILE_NAMES + .into_iter() + .filter_map(|name| { + worktree + .entry_for_path(RelPath::unix(name).unwrap()) + .filter(|entry| entry.is_file()) + .map(|entry| entry.path.clone()) + }) + .next(); + + // Note that Cline supports `.clinerules` being a directory, but that is not currently + // supported. This doesn't seem to occur often in GitHub repositories. + selected_rules_file.map(|path_in_worktree| { + let project_path = ProjectPath { + worktree_id, + path: path_in_worktree.clone(), + }; + let buffer_task = + project.update(cx, |project, cx| project.open_buffer(project_path, cx)); + let rope_task = cx.spawn(async move |cx| { + buffer_task.await?.read_with(cx, |buffer, cx| { + let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?; + anyhow::Ok((project_entry_id, buffer.as_rope().clone())) + })? + }); + // Build a string from the rope on a background thread. + cx.background_spawn(async move { + let (project_entry_id, rope) = rope_task.await?; + anyhow::Ok(RulesFileContext { + path_in_worktree, + text: rope.to_string().trim().to_string(), + project_entry_id: project_entry_id.to_usize(), + }) + }) + }) + } + + fn handle_thread_title_updated( + &mut self, + thread: Entity, + _: &TitleUpdated, + cx: &mut Context, + ) { + let session_id = thread.read(cx).id(); + let Some(session) = self.sessions.get(session_id) else { + return; + }; + let thread = thread.downgrade(); + let acp_thread = session.acp_thread.clone(); + cx.spawn(async move |_, cx| { + let title = thread.read_with(cx, |thread, _| thread.title())?; + let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?; + task.await + }) + .detach_and_log_err(cx); + } + + fn handle_thread_token_usage_updated( + &mut self, + thread: Entity, + usage: &TokenUsageUpdated, + cx: &mut Context, + ) { + let Some(session) = self.sessions.get(thread.read(cx).id()) else { + return; + }; + session + .acp_thread + .update(cx, |acp_thread, cx| { + acp_thread.update_token_usage(usage.0.clone(), cx); + }) + .ok(); + } + + fn handle_project_event( + &mut self, + _project: Entity, + event: &project::Event, + _cx: &mut Context, + ) { + match event { + project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => { + self.project_context_needs_refresh.send(()).ok(); + } + project::Event::WorktreeUpdatedEntries(_, items) => { + if items.iter().any(|(path, _, _)| { + RULES_FILE_NAMES + .iter() + .any(|name| path.as_ref() == RelPath::unix(name).unwrap()) + }) { + self.project_context_needs_refresh.send(()).ok(); + } + } + _ => {} + } + } + + fn handle_prompts_updated_event( + &mut self, + _prompt_store: Entity, + _event: &prompt_store::PromptsUpdatedEvent, + _cx: &mut Context, + ) { + self.project_context_needs_refresh.send(()).ok(); + } + + fn handle_models_updated_event( + &mut self, + _registry: Entity, + _event: &language_model::Event, + cx: &mut Context, + ) { + self.models.refresh_list(cx); + + let registry = LanguageModelRegistry::read_global(cx); + let default_model = registry.default_model().map(|m| m.model); + let summarization_model = registry.thread_summary_model().map(|m| m.model); + + for session in self.sessions.values_mut() { + session.thread.update(cx, |thread, cx| { + if thread.model().is_none() + && let Some(model) = default_model.clone() + { + thread.set_model(model, cx); + cx.notify(); + } + thread.set_summarization_model(summarization_model.clone(), cx); + }); + } + } + + pub fn load_thread( + &mut self, + id: acp::SessionId, + cx: &mut Context, + ) -> Task>> { + let database_future = ThreadsDatabase::connect(cx); + cx.spawn(async move |this, cx| { + let database = database_future.await.map_err(|err| anyhow!(err))?; + let db_thread = database + .load_thread(id.clone()) + .await? + .with_context(|| format!("no thread found with ID: {id:?}"))?; + + this.update(cx, |this, cx| { + let summarization_model = LanguageModelRegistry::read_global(cx) + .thread_summary_model() + .map(|c| c.model); + + cx.new(|cx| { + let mut thread = Thread::from_db( + id.clone(), + db_thread, + this.project.clone(), + this.project_context.clone(), + this.context_server_registry.clone(), + this.templates.clone(), + cx, + ); + thread.set_summarization_model(summarization_model, cx); + thread + }) + }) + }) + } + + pub fn open_thread( + &mut self, + id: acp::SessionId, + cx: &mut Context, + ) -> Task>> { + let task = self.load_thread(id, cx); + cx.spawn(async move |this, cx| { + let thread = task.await?; + let acp_thread = + this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?; + let events = thread.update(cx, |thread, cx| thread.replay(cx))?; + cx.update(|cx| { + NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx) + })? + .await?; + Ok(acp_thread) + }) + } + + pub fn thread_summary( + &mut self, + id: acp::SessionId, + cx: &mut Context, + ) -> Task> { + let thread = self.open_thread(id.clone(), cx); + cx.spawn(async move |this, cx| { + let acp_thread = thread.await?; + let result = this + .update(cx, |this, cx| { + this.sessions + .get(&id) + .unwrap() + .thread + .update(cx, |thread, cx| thread.summary(cx)) + })? + .await + .context("Failed to generate summary")?; + drop(acp_thread); + Ok(result) + }) + } + + fn save_thread(&mut self, thread: Entity, cx: &mut Context) { + if thread.read(cx).is_empty() { + return; + } + + let database_future = ThreadsDatabase::connect(cx); + let (id, db_thread) = + thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx))); + let Some(session) = self.sessions.get_mut(&id) else { + return; + }; + let history = self.history.clone(); + session.pending_save = cx.spawn(async move |_, cx| { + let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else { + return; + }; + let db_thread = db_thread.await; + database.save_thread(id, db_thread).await.log_err(); + history.update(cx, |history, cx| history.reload(cx)).ok(); + }); + } +} + +/// Wrapper struct that implements the AgentConnection trait +#[derive(Clone)] +pub struct NativeAgentConnection(pub Entity); + +impl NativeAgentConnection { + pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option> { + self.0 + .read(cx) + .sessions + .get(session_id) + .map(|session| session.thread.clone()) + } + + pub fn load_thread(&self, id: acp::SessionId, cx: &mut App) -> Task>> { + self.0.update(cx, |this, cx| this.load_thread(id, cx)) + } + + fn run_turn( + &self, + session_id: acp::SessionId, + cx: &mut App, + f: impl 'static + + FnOnce(Entity, &mut App) -> Result>>, + ) -> Task> { + let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { + agent + .sessions + .get_mut(&session_id) + .map(|s| (s.thread.clone(), s.acp_thread.clone())) + }) else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + log::debug!("Found session for: {}", session_id); + + let response_stream = match f(thread, cx) { + Ok(stream) => stream, + Err(err) => return Task::ready(Err(err)), + }; + Self::handle_thread_events(response_stream, acp_thread, cx) + } + + fn handle_thread_events( + mut events: mpsc::UnboundedReceiver>, + acp_thread: WeakEntity, + cx: &App, + ) -> Task> { + cx.spawn(async move |cx| { + // Handle response stream and forward to session.acp_thread + while let Some(result) = events.next().await { + match result { + Ok(event) => { + log::trace!("Received completion event: {:?}", event); + + match event { + ThreadEvent::UserMessage(message) => { + acp_thread.update(cx, |thread, cx| { + for content in message.content { + thread.push_user_content_block( + Some(message.id.clone()), + content.into(), + cx, + ); + } + })?; + } + ThreadEvent::AgentText(text) => { + acp_thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + meta: None, + }), + false, + cx, + ) + })?; + } + ThreadEvent::AgentThinking(text) => { + acp_thread.update(cx, |thread, cx| { + thread.push_assistant_content_block( + acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + meta: None, + }), + true, + cx, + ) + })?; + } + ThreadEvent::ToolCallAuthorization(ToolCallAuthorization { + tool_call, + options, + response, + }) => { + let outcome_task = acp_thread.update(cx, |thread, cx| { + thread.request_tool_call_authorization( + tool_call, options, true, cx, + ) + })??; + cx.background_spawn(async move { + if let acp::RequestPermissionOutcome::Selected { option_id } = + outcome_task.await + { + response + .send(option_id) + .map(|_| anyhow!("authorization receiver was dropped")) + .log_err(); + } + }) + .detach(); + } + ThreadEvent::ToolCall(tool_call) => { + acp_thread.update(cx, |thread, cx| { + thread.upsert_tool_call(tool_call, cx) + })??; + } + ThreadEvent::ToolCallUpdate(update) => { + acp_thread.update(cx, |thread, cx| { + thread.update_tool_call(update, cx) + })??; + } + ThreadEvent::Retry(status) => { + acp_thread.update(cx, |thread, cx| { + thread.update_retry_status(status, cx) + })?; + } + ThreadEvent::Stop(stop_reason) => { + log::debug!("Assistant message complete: {:?}", stop_reason); + return Ok(acp::PromptResponse { + stop_reason, + meta: None, + }); + } + } + } + Err(e) => { + log::error!("Error in model response stream: {:?}", e); + return Err(e); + } + } + } + + log::debug!("Response stream completed"); + anyhow::Ok(acp::PromptResponse { + stop_reason: acp::StopReason::EndTurn, + meta: None, + }) + }) + } +} + +struct NativeAgentModelSelector { + session_id: acp::SessionId, + connection: NativeAgentConnection, +} + +impl acp_thread::AgentModelSelector for NativeAgentModelSelector { + fn list_models(&self, cx: &mut App) -> Task> { + log::debug!("NativeAgentConnection::list_models called"); + let list = self.connection.0.read(cx).models.model_list.clone(); + Task::ready(if list.is_empty() { + Err(anyhow::anyhow!("No models available")) + } else { + Ok(list) + }) + } + + fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task> { + log::debug!( + "Setting model for session {}: {}", + self.session_id, + model_id + ); + let Some(thread) = self + .connection + .0 + .read(cx) + .sessions + .get(&self.session_id) + .map(|session| session.thread.clone()) + else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + + let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else { + return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); + }; + + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + }); + + update_settings_file( + self.connection.0.read(cx).fs.clone(), + cx, + move |settings, _cx| { + let provider = model.provider_id().0.to_string(); + let model = model.id().0.to_string(); + settings + .agent + .get_or_insert_default() + .set_model(LanguageModelSelection { + provider: provider.into(), + model, + }); + }, + ); + + Task::ready(Ok(())) + } + + fn selected_model(&self, cx: &mut App) -> Task> { + let Some(thread) = self + .connection + .0 + .read(cx) + .sessions + .get(&self.session_id) + .map(|session| session.thread.clone()) + else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + let Some(model) = thread.read(cx).model() else { + return Task::ready(Err(anyhow!("Model not found"))); + }; + let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id()) + else { + return Task::ready(Err(anyhow!("Provider not found"))); + }; + Task::ready(Ok(LanguageModels::map_language_model_to_info( + model, &provider, + ))) + } + + fn watch(&self, cx: &mut App) -> Option> { + Some(self.connection.0.read(cx).models.watch()) + } +} + +impl acp_thread::AgentConnection for NativeAgentConnection { + fn new_thread( + self: Rc, + project: Entity, + cwd: &Path, + cx: &mut App, + ) -> Task>> { + let agent = self.0.clone(); + log::debug!("Creating new thread for project at: {:?}", cwd); + + cx.spawn(async move |cx| { + log::debug!("Starting thread creation in async context"); + + // Create Thread + let thread = agent.update( + cx, + |agent, cx: &mut gpui::Context| -> Result<_> { + // Fetch default model from registry settings + let registry = LanguageModelRegistry::read_global(cx); + // Log available models for debugging + let available_count = registry.available_models(cx).count(); + log::debug!("Total available models: {}", available_count); + + let default_model = registry.default_model().and_then(|default_model| { + agent + .models + .model_from_id(&LanguageModels::model_id(&default_model.model)) + }); + Ok(cx.new(|cx| { + Thread::new( + project.clone(), + agent.project_context.clone(), + agent.context_server_registry.clone(), + agent.templates.clone(), + default_model, + cx, + ) + })) + }, + )??; + agent.update(cx, |agent, cx| agent.register_session(thread, cx)) + }) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + &[] // No auth for in-process + } + + fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task> { + Task::ready(Ok(())) + } + + fn model_selector(&self, session_id: &acp::SessionId) -> Option> { + Some(Rc::new(NativeAgentModelSelector { + session_id: session_id.clone(), + connection: self.clone(), + }) as Rc) + } + + fn prompt( + &self, + id: Option, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + let id = id.expect("UserMessageId is required"); + let session_id = params.session_id.clone(); + log::info!("Received prompt request for session: {}", session_id); + log::debug!("Prompt blocks count: {}", params.prompt.len()); + + self.run_turn(session_id, cx, |thread, cx| { + let content: Vec = params + .prompt + .into_iter() + .map(Into::into) + .collect::>(); + log::debug!("Converted prompt to message: {} chars", content.len()); + log::debug!("Message id: {:?}", id); + log::debug!("Message content: {:?}", content); + + thread.update(cx, |thread, cx| thread.send(id, content, cx)) + }) + } + + fn resume( + &self, + session_id: &acp::SessionId, + _cx: &App, + ) -> Option> { + Some(Rc::new(NativeAgentSessionResume { + connection: self.clone(), + session_id: session_id.clone(), + }) as _) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + log::info!("Cancelling on session: {}", session_id); + self.0.update(cx, |agent, cx| { + if let Some(agent) = agent.sessions.get(session_id) { + agent.thread.update(cx, |thread, cx| thread.cancel(cx)); + } + }); + } + + fn truncate( + &self, + session_id: &agent_client_protocol::SessionId, + cx: &App, + ) -> Option> { + self.0.read_with(cx, |agent, _cx| { + agent.sessions.get(session_id).map(|session| { + Rc::new(NativeAgentSessionTruncate { + thread: session.thread.clone(), + acp_thread: session.acp_thread.clone(), + }) as _ + }) + }) + } + + fn set_title( + &self, + session_id: &acp::SessionId, + _cx: &App, + ) -> Option> { + Some(Rc::new(NativeAgentSessionSetTitle { + connection: self.clone(), + session_id: session_id.clone(), + }) as _) + } + + fn telemetry(&self) -> Option> { + Some(Rc::new(self.clone()) as Rc) + } + + fn into_any(self: Rc) -> Rc { + self + } +} + +impl acp_thread::AgentTelemetry for NativeAgentConnection { + fn agent_name(&self) -> String { + "Zed".into() + } + + fn thread_data( + &self, + session_id: &acp::SessionId, + cx: &mut App, + ) -> Task> { + let Some(session) = self.0.read(cx).sessions.get(session_id) else { + return Task::ready(Err(anyhow!("Session not found"))); + }; + + let task = session.thread.read(cx).to_db(cx); + cx.background_spawn(async move { + serde_json::to_value(task.await).context("Failed to serialize thread") + }) + } +} + +struct NativeAgentSessionTruncate { + thread: Entity, + acp_thread: WeakEntity, +} + +impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate { + fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { + match self.thread.update(cx, |thread, cx| { + thread.truncate(message_id.clone(), cx)?; + Ok(thread.latest_token_usage()) + }) { + Ok(usage) => { + self.acp_thread + .update(cx, |thread, cx| { + thread.update_token_usage(usage, cx); + }) + .ok(); + Task::ready(Ok(())) + } + Err(error) => Task::ready(Err(error)), + } + } +} + +struct NativeAgentSessionResume { + connection: NativeAgentConnection, + session_id: acp::SessionId, +} + +impl acp_thread::AgentSessionResume for NativeAgentSessionResume { + fn run(&self, cx: &mut App) -> Task> { + self.connection + .run_turn(self.session_id.clone(), cx, |thread, cx| { + thread.update(cx, |thread, cx| thread.resume(cx)) + }) + } +} + +struct NativeAgentSessionSetTitle { + connection: NativeAgentConnection, + session_id: acp::SessionId, +} + +impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle { + fn run(&self, title: SharedString, cx: &mut App) -> Task> { + let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else { + return Task::ready(Err(anyhow!("session not found"))); + }; + let thread = session.thread.clone(); + thread.update(cx, |thread, cx| thread.set_title(title, cx)); + Task::ready(Ok(())) + } +} + +pub struct AcpThreadEnvironment { + acp_thread: WeakEntity, +} + +impl ThreadEnvironment for AcpThreadEnvironment { + fn create_terminal( + &self, + command: String, + cwd: Option, + output_byte_limit: Option, + cx: &mut AsyncApp, + ) -> Task>> { + let task = self.acp_thread.update(cx, |thread, cx| { + thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx) + }); + + let acp_thread = self.acp_thread.clone(); + cx.spawn(async move |cx| { + let terminal = task?.await?; + + let (drop_tx, drop_rx) = oneshot::channel(); + let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?; + + cx.spawn(async move |cx| { + drop_rx.await.ok(); + acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx)) + }) + .detach(); + + let handle = AcpTerminalHandle { + terminal, + _drop_tx: Some(drop_tx), + }; + + Ok(Rc::new(handle) as _) + }) + } +} + +pub struct AcpTerminalHandle { + terminal: Entity, + _drop_tx: Option>, +} + +impl TerminalHandle for AcpTerminalHandle { + fn id(&self, cx: &AsyncApp) -> Result { + self.terminal.read_with(cx, |term, _cx| term.id().clone()) + } + + fn wait_for_exit(&self, cx: &AsyncApp) -> Result>> { + self.terminal + .read_with(cx, |term, _cx| term.wait_for_exit()) + } + + fn current_output(&self, cx: &AsyncApp) -> Result { + self.terminal + .read_with(cx, |term, cx| term.current_output(cx)) + } +} + +#[cfg(test)] +mod internal_tests { + use crate::HistoryEntryId; + + use super::*; + use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri}; + use fs::FakeFs; + use gpui::TestAppContext; + use indoc::formatdoc; + use language_model::fake_provider::FakeLanguageModel; + use serde_json::json; + use settings::SettingsStore; + use util::{path, rel_path::rel_path}; + + #[gpui::test] + async fn test_maintaining_project_context(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": {} + }), + ) + .await; + let project = Project::test(fs.clone(), [], cx).await; + let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let agent = NativeAgent::new( + project.clone(), + history_store, + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + agent.read_with(cx, |agent, cx| { + assert_eq!(agent.project_context.read(cx).worktrees, vec![]) + }); + + let worktree = project + .update(cx, |project, cx| project.create_worktree("/a", true, cx)) + .await + .unwrap(); + cx.run_until_parked(); + agent.read_with(cx, |agent, cx| { + assert_eq!( + agent.project_context.read(cx).worktrees, + vec![WorktreeContext { + root_name: "a".into(), + abs_path: Path::new("/a").into(), + rules_file: None + }] + ) + }); + + // Creating `/a/.rules` updates the project context. + fs.insert_file("/a/.rules", Vec::new()).await; + cx.run_until_parked(); + agent.read_with(cx, |agent, cx| { + let rules_entry = worktree + .read(cx) + .entry_for_path(rel_path(".rules")) + .unwrap(); + assert_eq!( + agent.project_context.read(cx).worktrees, + vec![WorktreeContext { + root_name: "a".into(), + abs_path: Path::new("/a").into(), + rules_file: Some(RulesFileContext { + path_in_worktree: rel_path(".rules").into(), + text: "".into(), + project_entry_id: rules_entry.id.to_usize() + }) + }] + ) + }); + } + + #[gpui::test] + async fn test_listing_models(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/", json!({ "a": {} })).await; + let project = Project::test(fs.clone(), [], cx).await; + let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let connection = NativeAgentConnection( + NativeAgent::new( + project.clone(), + history_store, + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(), + ); + + // Create a thread/session + let acp_thread = cx + .update(|cx| { + Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) + }) + .await + .unwrap(); + + let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); + + let models = cx + .update(|cx| { + connection + .model_selector(&session_id) + .unwrap() + .list_models(cx) + }) + .await + .unwrap(); + + let acp_thread::AgentModelList::Grouped(models) = models else { + panic!("Unexpected model group"); + }; + assert_eq!( + models, + IndexMap::from_iter([( + AgentModelGroupName("Fake".into()), + vec![AgentModelInfo { + id: acp::ModelId("fake/fake".into()), + name: "Fake".into(), + description: None, + icon: Some(ui::IconName::ZedAssistant), + }] + )]) + ); + } + + #[gpui::test] + async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.create_dir(paths::settings_file().parent().unwrap()) + .await + .unwrap(); + fs.insert_file( + paths::settings_file(), + json!({ + "agent": { + "default_model": { + "provider": "foo", + "model": "bar" + } + } + }) + .to_string() + .into_bytes(), + ) + .await; + let project = Project::test(fs.clone(), [], cx).await; + + let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + + // Create the agent and connection + let agent = NativeAgent::new( + project.clone(), + history_store, + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = NativeAgentConnection(agent.clone()); + + // Create a thread/session + let acp_thread = cx + .update(|cx| { + Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) + }) + .await + .unwrap(); + + let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); + + // Select a model + let selector = connection.model_selector(&session_id).unwrap(); + let model_id = acp::ModelId("fake/fake".into()); + cx.update(|cx| selector.select_model(model_id.clone(), cx)) + .await + .unwrap(); + + // Verify the thread has the selected model + agent.read_with(cx, |agent, _| { + let session = agent.sessions.get(&session_id).unwrap(); + session.thread.read_with(cx, |thread, _| { + assert_eq!(thread.model().unwrap().id().0, "fake"); + }); + }); + + cx.run_until_parked(); + + // Verify settings file was updated + let settings_content = fs.load(paths::settings_file()).await.unwrap(); + let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap(); + + // Check that the agent settings contain the selected model + assert_eq!( + settings_json["agent"]["default_model"]["model"], + json!("fake") + ); + assert_eq!( + settings_json["agent"]["default_model"]["provider"], + json!("fake") + ); + } + + #[gpui::test] + async fn test_save_load_thread(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/", + json!({ + "a": { + "b.md": "Lorem" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; + let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); + let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); + let agent = NativeAgent::new( + project.clone(), + history_store.clone(), + Templates::new(), + None, + fs.clone(), + &mut cx.to_async(), + ) + .await + .unwrap(); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection + .clone() + .new_thread(project.clone(), Path::new(""), cx) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + + // Ensure empty threads are not saved, even if they get mutated. + let model = Arc::new(FakeLanguageModel::default()); + let summary_model = Arc::new(FakeLanguageModel::default()); + thread.update(cx, |thread, cx| { + thread.set_model(model.clone(), cx); + thread.set_summarization_model(Some(summary_model.clone()), cx); + }); + cx.run_until_parked(); + assert_eq!(history_entries(&history_store, cx), vec![]); + + let send = acp_thread.update(cx, |thread, cx| { + thread.send( + vec![ + "What does ".into(), + acp::ContentBlock::ResourceLink(acp::ResourceLink { + name: "b.md".into(), + uri: MentionUri::File { + abs_path: path!("/a/b.md").into(), + } + .to_uri() + .to_string(), + annotations: None, + description: None, + mime_type: None, + size: None, + title: None, + meta: None, + }), + " mean?".into(), + ], + cx, + ) + }); + let send = cx.foreground_executor().spawn(send); + cx.run_until_parked(); + + model.send_last_completion_stream_text_chunk("Lorem."); + model.end_last_completion_stream(); + cx.run_until_parked(); + summary_model + .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md"))); + summary_model.end_last_completion_stream(); + + send.await.unwrap(); + let uri = MentionUri::File { + abs_path: path!("/a/b.md").into(), + } + .to_uri(); + acp_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + formatdoc! {" + ## User + + What does [@b.md]({uri}) mean? + + ## Assistant + + Lorem. + + "} + ) + }); + + cx.run_until_parked(); + + // Drop the ACP thread, which should cause the session to be dropped as well. + cx.update(|_| { + drop(thread); + drop(acp_thread); + }); + agent.read_with(cx, |agent, _| { + assert_eq!(agent.sessions.keys().cloned().collect::>(), []); + }); + + // Ensure the thread can be reloaded from disk. + assert_eq!( + history_entries(&history_store, cx), + vec![( + HistoryEntryId::AcpThread(session_id.clone()), + format!("Explaining {}", path!("/a/b.md")) + )] + ); + let acp_thread = agent + .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx)) + .await + .unwrap(); + acp_thread.read_with(cx, |thread, cx| { + assert_eq!( + thread.to_markdown(cx), + formatdoc! {" + ## User + + What does [@b.md]({uri}) mean? + + ## Assistant + + Lorem. + + "} + ) + }); + } + + fn history_entries( + history: &Entity, + cx: &mut TestAppContext, + ) -> Vec<(HistoryEntryId, String)> { + history.read_with(cx, |history, _| { + history + .entries() + .map(|e| (e.id(), e.title().to_string())) + .collect::>() + }) + } -pub fn init(fs: Arc, cx: &mut gpui::App) { - thread_store::init(fs, cx); + fn init_test(cx: &mut TestAppContext) { + env_logger::try_init().ok(); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + Project::init_settings(cx); + agent_settings::init(cx); + language::init(cx); + LanguageModelRegistry::test(cx); + }); + } } diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs deleted file mode 100644 index 40ba2f07db7ad425a5d0e9befe91499eb746b74e..0000000000000000000000000000000000000000 --- a/crates/agent/src/agent_profile.rs +++ /dev/null @@ -1,341 +0,0 @@ -use std::sync::Arc; - -use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings}; -use assistant_tool::{Tool, ToolSource, ToolWorkingSet, UniqueToolName}; -use collections::IndexMap; -use convert_case::{Case, Casing}; -use fs::Fs; -use gpui::{App, Entity, SharedString}; -use settings::{Settings, update_settings_file}; -use util::ResultExt; - -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct AgentProfile { - id: AgentProfileId, - tool_set: Entity, -} - -pub type AvailableProfiles = IndexMap; - -impl AgentProfile { - pub fn new(id: AgentProfileId, tool_set: Entity) -> Self { - Self { id, tool_set } - } - - /// Saves a new profile to the settings. - pub fn create( - name: String, - base_profile_id: Option, - fs: Arc, - cx: &App, - ) -> AgentProfileId { - let id = AgentProfileId(name.to_case(Case::Kebab).into()); - - let base_profile = - base_profile_id.and_then(|id| AgentSettings::get_global(cx).profiles.get(&id).cloned()); - - let profile_settings = AgentProfileSettings { - name: name.into(), - tools: base_profile - .as_ref() - .map(|profile| profile.tools.clone()) - .unwrap_or_default(), - enable_all_context_servers: base_profile - .as_ref() - .map(|profile| profile.enable_all_context_servers) - .unwrap_or_default(), - context_servers: base_profile - .map(|profile| profile.context_servers) - .unwrap_or_default(), - }; - - update_settings_file(fs, cx, { - let id = id.clone(); - move |settings, _cx| { - profile_settings.save_to_settings(id, settings).log_err(); - } - }); - - id - } - - /// Returns a map of AgentProfileIds to their names - pub fn available_profiles(cx: &App) -> AvailableProfiles { - let mut profiles = AvailableProfiles::default(); - for (id, profile) in AgentSettings::get_global(cx).profiles.iter() { - profiles.insert(id.clone(), profile.name.clone()); - } - profiles - } - - pub fn id(&self) -> &AgentProfileId { - &self.id - } - - pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc)> { - let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else { - return Vec::new(); - }; - - self.tool_set - .read(cx) - .tools(cx) - .into_iter() - .filter(|(_, tool)| Self::is_enabled(settings, tool.source(), tool.name())) - .collect() - } - - pub fn is_tool_enabled(&self, source: ToolSource, tool_name: String, cx: &App) -> bool { - let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else { - return false; - }; - - Self::is_enabled(settings, source, tool_name) - } - - fn is_enabled(settings: &AgentProfileSettings, source: ToolSource, name: String) -> bool { - match source { - ToolSource::Native => *settings.tools.get(name.as_str()).unwrap_or(&false), - ToolSource::ContextServer { id } => settings - .context_servers - .get(id.as_ref()) - .and_then(|preset| preset.tools.get(name.as_str()).copied()) - .unwrap_or(settings.enable_all_context_servers), - } - } -} - -#[cfg(test)] -mod tests { - use agent_settings::ContextServerPreset; - use assistant_tool::ToolRegistry; - use collections::IndexMap; - use gpui::SharedString; - use gpui::{AppContext, TestAppContext}; - use http_client::FakeHttpClient; - use project::Project; - use settings::{Settings, SettingsStore}; - - use super::*; - - #[gpui::test] - async fn test_enabled_built_in_tools_for_profile(cx: &mut TestAppContext) { - init_test_settings(cx); - - let id = AgentProfileId::default(); - let profile_settings = cx.read(|cx| { - AgentSettings::get_global(cx) - .profiles - .get(&id) - .unwrap() - .clone() - }); - let tool_set = default_tool_set(cx); - - let profile = AgentProfile::new(id, tool_set); - - let mut enabled_tools = cx - .read(|cx| profile.enabled_tools(cx)) - .into_iter() - .map(|(_, tool)| tool.name()) - .collect::>(); - enabled_tools.sort(); - - let mut expected_tools = profile_settings - .tools - .into_iter() - .filter_map(|(tool, enabled)| enabled.then_some(tool.to_string())) - // Provider dependent - .filter(|tool| tool != "web_search") - .collect::>(); - // Plus all registered MCP tools - expected_tools.extend(["enabled_mcp_tool".into(), "disabled_mcp_tool".into()]); - expected_tools.sort(); - - assert_eq!(enabled_tools, expected_tools); - } - - #[gpui::test] - async fn test_custom_mcp_settings(cx: &mut TestAppContext) { - init_test_settings(cx); - - let id = AgentProfileId("custom_mcp".into()); - let profile_settings = cx.read(|cx| { - AgentSettings::get_global(cx) - .profiles - .get(&id) - .unwrap() - .clone() - }); - let tool_set = default_tool_set(cx); - - let profile = AgentProfile::new(id, tool_set); - - let mut enabled_tools = cx - .read(|cx| profile.enabled_tools(cx)) - .into_iter() - .map(|(_, tool)| tool.name()) - .collect::>(); - enabled_tools.sort(); - - let mut expected_tools = profile_settings.context_servers["mcp"] - .tools - .iter() - .filter_map(|(key, enabled)| enabled.then(|| key.to_string())) - .collect::>(); - expected_tools.sort(); - - assert_eq!(enabled_tools, expected_tools); - } - - #[gpui::test] - async fn test_only_built_in(cx: &mut TestAppContext) { - init_test_settings(cx); - - let id = AgentProfileId("write_minus_mcp".into()); - let profile_settings = cx.read(|cx| { - AgentSettings::get_global(cx) - .profiles - .get(&id) - .unwrap() - .clone() - }); - let tool_set = default_tool_set(cx); - - let profile = AgentProfile::new(id, tool_set); - - let mut enabled_tools = cx - .read(|cx| profile.enabled_tools(cx)) - .into_iter() - .map(|(_, tool)| tool.name()) - .collect::>(); - enabled_tools.sort(); - - let mut expected_tools = profile_settings - .tools - .into_iter() - .filter_map(|(tool, enabled)| enabled.then_some(tool.to_string())) - // Provider dependent - .filter(|tool| tool != "web_search") - .collect::>(); - expected_tools.sort(); - - assert_eq!(enabled_tools, expected_tools); - } - - fn init_test_settings(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - AgentSettings::register(cx); - language_model::init_settings(cx); - ToolRegistry::default_global(cx); - assistant_tools::init(FakeHttpClient::with_404_response(), cx); - }); - - cx.update(|cx| { - let mut agent_settings = AgentSettings::get_global(cx).clone(); - agent_settings.profiles.insert( - AgentProfileId("write_minus_mcp".into()), - AgentProfileSettings { - name: "write_minus_mcp".into(), - enable_all_context_servers: false, - ..agent_settings.profiles[&AgentProfileId::default()].clone() - }, - ); - agent_settings.profiles.insert( - AgentProfileId("custom_mcp".into()), - AgentProfileSettings { - name: "mcp".into(), - tools: IndexMap::default(), - enable_all_context_servers: false, - context_servers: IndexMap::from_iter([("mcp".into(), context_server_preset())]), - }, - ); - AgentSettings::override_global(agent_settings, cx); - }) - } - - fn context_server_preset() -> ContextServerPreset { - ContextServerPreset { - tools: IndexMap::from_iter([ - ("enabled_mcp_tool".into(), true), - ("disabled_mcp_tool".into(), false), - ]), - } - } - - fn default_tool_set(cx: &mut TestAppContext) -> Entity { - cx.new(|cx| { - let mut tool_set = ToolWorkingSet::default(); - tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")), cx); - tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")), cx); - tool_set - }) - } - - struct FakeTool { - name: String, - source: SharedString, - } - - impl FakeTool { - fn new(name: impl Into, source: impl Into) -> Self { - Self { - name: name.into(), - source: source.into(), - } - } - } - - impl Tool for FakeTool { - fn name(&self) -> String { - self.name.clone() - } - - fn source(&self) -> ToolSource { - ToolSource::ContextServer { - id: self.source.clone(), - } - } - - fn description(&self) -> String { - unimplemented!() - } - - fn icon(&self) -> icons::IconName { - unimplemented!() - } - - fn needs_confirmation( - &self, - _input: &serde_json::Value, - _project: &Entity, - _cx: &App, - ) -> bool { - unimplemented!() - } - - fn ui_text(&self, _input: &serde_json::Value) -> String { - unimplemented!() - } - - fn run( - self: Arc, - _input: serde_json::Value, - _request: Arc, - _project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - _cx: &mut App, - ) -> assistant_tool::ToolResult { - unimplemented!() - } - - fn may_perform_edits(&self) -> bool { - unimplemented!() - } - } -} diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs deleted file mode 100644 index 696c569356bca36adf54bc84ec52fa7295048b75..0000000000000000000000000000000000000000 --- a/crates/agent/src/context_server_tool.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::sync::Arc; - -use action_log::ActionLog; -use anyhow::{Result, anyhow, bail}; -use assistant_tool::{Tool, ToolResult, ToolSource}; -use context_server::{ContextServerId, types}; -use gpui::{AnyWindowHandle, App, Entity, Task}; -use icons::IconName; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::{Project, context_server_store::ContextServerStore}; - -pub struct ContextServerTool { - store: Entity, - server_id: ContextServerId, - tool: types::Tool, -} - -impl ContextServerTool { - pub fn new( - store: Entity, - server_id: ContextServerId, - tool: types::Tool, - ) -> Self { - Self { - store, - server_id, - tool, - } - } -} - -impl Tool for ContextServerTool { - fn name(&self) -> String { - self.tool.name.clone() - } - - fn description(&self) -> String { - self.tool.description.clone().unwrap_or_default() - } - - fn icon(&self) -> IconName { - IconName::ToolHammer - } - - fn source(&self) -> ToolSource { - ToolSource::ContextServer { - id: self.server_id.clone().0.into(), - } - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - true - } - - fn may_perform_edits(&self) -> bool { - true - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - let mut schema = self.tool.input_schema.clone(); - assistant_tool::adapt_schema_to_format(&mut schema, format)?; - Ok(match schema { - serde_json::Value::Null => { - serde_json::json!({ "type": "object", "properties": [] }) - } - serde_json::Value::Object(map) if map.is_empty() => { - serde_json::json!({ "type": "object", "properties": [] }) - } - _ => schema, - }) - } - - fn ui_text(&self, _input: &serde_json::Value) -> String { - format!("Run MCP tool `{}`", self.tool.name) - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - _project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - if let Some(server) = self.store.read(cx).get_running_server(&self.server_id) { - let tool_name = self.tool.name.clone(); - - cx.spawn(async move |_cx| { - let Some(protocol) = server.client() else { - bail!("Context server not initialized"); - }; - - let arguments = if let serde_json::Value::Object(map) = input { - Some(map.into_iter().collect()) - } else { - None - }; - - log::trace!( - "Running tool: {} with arguments: {:?}", - tool_name, - arguments - ); - let response = protocol - .request::( - context_server::types::CallToolParams { - name: tool_name, - arguments, - meta: None, - }, - ) - .await?; - - let mut result = String::new(); - for content in response.content { - match content { - types::ToolResponseContent::Text { text } => { - result.push_str(&text); - } - types::ToolResponseContent::Image { .. } => { - log::warn!("Ignoring image content from tool response"); - } - types::ToolResponseContent::Audio { .. } => { - log::warn!("Ignoring audio content from tool response"); - } - types::ToolResponseContent::Resource { .. } => { - log::warn!("Ignoring resource content from tool response"); - } - } - } - Ok(result.into()) - }) - .into() - } else { - Task::ready(Err(anyhow!("Context server not found"))).into() - } - } -} diff --git a/crates/agent2/src/db.rs b/crates/agent/src/db.rs similarity index 78% rename from crates/agent2/src/db.rs rename to crates/agent/src/db.rs index 563ccdd7ca5b2c2cc63a8c7f30c59b9443f8a0bd..c72e20571e2761788157a5fd10df147c2b414e4a 100644 --- a/crates/agent2/src/db.rs +++ b/crates/agent/src/db.rs @@ -1,6 +1,5 @@ use crate::{AgentMessage, AgentMessageContent, UserMessage, UserMessageContent}; use acp_thread::UserMessageId; -use agent::{thread::DetailedSummaryState, thread_store}; use agent_client_protocol as acp; use agent_settings::{AgentProfileId, CompletionMode}; use anyhow::{Result, anyhow}; @@ -21,8 +20,8 @@ use ui::{App, SharedString}; use zed_env_vars::ZED_STATELESS; pub type DbMessage = crate::Message; -pub type DbSummary = DetailedSummaryState; -pub type DbLanguageModel = thread_store::SerializedLanguageModel; +pub type DbSummary = crate::legacy_thread::DetailedSummaryState; +pub type DbLanguageModel = crate::legacy_thread::SerializedLanguageModel; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbThreadMetadata { @@ -40,7 +39,7 @@ pub struct DbThread { #[serde(default)] pub detailed_summary: Option, #[serde(default)] - pub initial_project_snapshot: Option>, + pub initial_project_snapshot: Option>, #[serde(default)] pub cumulative_token_usage: language_model::TokenUsage, #[serde(default)] @@ -61,13 +60,17 @@ impl DbThread { match saved_thread_json.get("version") { Some(serde_json::Value::String(version)) => match version.as_str() { Self::VERSION => Ok(serde_json::from_value(saved_thread_json)?), - _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?), + _ => Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json( + json, + )?), }, - _ => Self::upgrade_from_agent_1(agent::SerializedThread::from_json(json)?), + _ => { + Self::upgrade_from_agent_1(crate::legacy_thread::SerializedThread::from_json(json)?) + } } } - fn upgrade_from_agent_1(thread: agent::SerializedThread) -> Result { + fn upgrade_from_agent_1(thread: crate::legacy_thread::SerializedThread) -> Result { let mut messages = Vec::new(); let mut request_token_usage = HashMap::default(); @@ -80,14 +83,19 @@ impl DbThread { // Convert segments to content for segment in msg.segments { match segment { - thread_store::SerializedMessageSegment::Text { text } => { + crate::legacy_thread::SerializedMessageSegment::Text { text } => { content.push(UserMessageContent::Text(text)); } - thread_store::SerializedMessageSegment::Thinking { text, .. } => { + crate::legacy_thread::SerializedMessageSegment::Thinking { + text, + .. + } => { // User messages don't have thinking segments, but handle gracefully content.push(UserMessageContent::Text(text)); } - thread_store::SerializedMessageSegment::RedactedThinking { .. } => { + crate::legacy_thread::SerializedMessageSegment::RedactedThinking { + .. + } => { // User messages don't have redacted thinking, skip. } } @@ -113,16 +121,18 @@ impl DbThread { // Convert segments to content for segment in msg.segments { match segment { - thread_store::SerializedMessageSegment::Text { text } => { + crate::legacy_thread::SerializedMessageSegment::Text { text } => { content.push(AgentMessageContent::Text(text)); } - thread_store::SerializedMessageSegment::Thinking { + crate::legacy_thread::SerializedMessageSegment::Thinking { text, signature, } => { content.push(AgentMessageContent::Thinking { text, signature }); } - thread_store::SerializedMessageSegment::RedactedThinking { data } => { + crate::legacy_thread::SerializedMessageSegment::RedactedThinking { + data, + } => { content.push(AgentMessageContent::RedactedThinking(data)); } } @@ -187,10 +197,9 @@ impl DbThread { messages, updated_at: thread.updated_at, detailed_summary: match thread.detailed_summary_state { - DetailedSummaryState::NotGenerated | DetailedSummaryState::Generating { .. } => { - None - } - DetailedSummaryState::Generated { text, .. } => Some(text), + crate::legacy_thread::DetailedSummaryState::NotGenerated + | crate::legacy_thread::DetailedSummaryState::Generating => None, + crate::legacy_thread::DetailedSummaryState::Generated { text, .. } => Some(text), }, initial_project_snapshot: thread.initial_project_snapshot, cumulative_token_usage: thread.cumulative_token_usage, @@ -414,84 +423,3 @@ impl ThreadsDatabase { }) } } - -#[cfg(test)] -mod tests { - - use super::*; - use agent::MessageSegment; - use agent::context::LoadedContext; - use client::Client; - use fs::{FakeFs, Fs}; - use gpui::AppContext; - use gpui::TestAppContext; - use http_client::FakeHttpClient; - use language_model::Role; - use project::Project; - use settings::SettingsStore; - - fn init_test(fs: Arc, cx: &mut TestAppContext) { - env_logger::try_init().ok(); - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - language::init(cx); - - let http_client = FakeHttpClient::with_404_response(); - let clock = Arc::new(clock::FakeSystemClock::new()); - let client = Client::new(clock, http_client, cx); - agent::init(fs, cx); - agent_settings::init(cx); - language_model::init(client, cx); - }); - } - - #[gpui::test] - async fn test_retrieving_old_thread(cx: &mut TestAppContext) { - let fs = FakeFs::new(cx.executor()); - init_test(fs.clone(), cx); - let project = Project::test(fs, [], cx).await; - - // Save a thread using the old agent. - let thread_store = cx.new(|cx| agent::ThreadStore::fake(project, cx)); - let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx)); - thread.update(cx, |thread, cx| { - thread.insert_message( - Role::User, - vec![MessageSegment::Text("Hey!".into())], - LoadedContext::default(), - vec![], - false, - cx, - ); - thread.insert_message( - Role::Assistant, - vec![MessageSegment::Text("How're you doing?".into())], - LoadedContext::default(), - vec![], - false, - cx, - ) - }); - thread_store - .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) - .await - .unwrap(); - - // Open that same thread using the new agent. - let db = cx.update(ThreadsDatabase::connect).await.unwrap(); - let threads = db.list_threads().await.unwrap(); - assert_eq!(threads.len(), 1); - let thread = db - .load_thread(threads[0].id.clone()) - .await - .unwrap() - .unwrap(); - assert_eq!(thread.messages[0].to_markdown(), "## User\n\nHey!\n"); - assert_eq!( - thread.messages[1].to_markdown(), - "## Assistant\n\nHow're you doing?\n" - ); - } -} diff --git a/crates/assistant_tools/src/edit_agent.rs b/crates/agent/src/edit_agent.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent.rs rename to crates/agent/src/edit_agent.rs diff --git a/crates/assistant_tools/src/edit_agent/create_file_parser.rs b/crates/agent/src/edit_agent/create_file_parser.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/create_file_parser.rs rename to crates/agent/src/edit_agent/create_file_parser.rs diff --git a/crates/assistant_tools/src/edit_agent/edit_parser.rs b/crates/agent/src/edit_agent/edit_parser.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/edit_parser.rs rename to crates/agent/src/edit_agent/edit_parser.rs diff --git a/crates/assistant_tools/src/edit_agent/evals.rs b/crates/agent/src/edit_agent/evals.rs similarity index 97% rename from crates/assistant_tools/src/edit_agent/evals.rs rename to crates/agent/src/edit_agent/evals.rs index 515e22d5f8b184a875cd91038d7bfa0a7d8127a7..b3043f0a81256568338f5d4be22bfe02de277076 100644 --- a/crates/assistant_tools/src/edit_agent/evals.rs +++ b/crates/agent/src/edit_agent/evals.rs @@ -1,12 +1,8 @@ use super::*; use crate::{ - ReadFileToolInput, - edit_file_tool::{EditFileMode, EditFileToolInput}, - grep_tool::GrepToolInput, - list_directory_tool::ListDirectoryToolInput, + EditFileMode, EditFileToolInput, GrepToolInput, ListDirectoryToolInput, ReadFileToolInput, }; use Role::*; -use assistant_tool::ToolRegistry; use client::{Client, UserStore}; use collections::HashMap; use fs::FakeFs; @@ -15,11 +11,11 @@ use gpui::{AppContext, TestAppContext, Timer}; use http_client::StatusCode; use indoc::{formatdoc, indoc}; use language_model::{ - LanguageModelRegistry, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, SelectedModel, + LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolResultContent, + LanguageModelToolUse, LanguageModelToolUseId, SelectedModel, }; use project::Project; -use prompt_store::{ModelContext, ProjectContext, PromptBuilder, WorktreeContext}; +use prompt_store::{ProjectContext, WorktreeContext}; use rand::prelude::*; use reqwest_client::ReqwestClient; use serde_json::json; @@ -121,6 +117,7 @@ fn eval_delete_run_git_blame() { // gemini-2.5-pro-06-05 | 1.0 (2025-06-16) // gemini-2.5-flash | // gpt-4.1 | + let input_file_path = "root/blame.rs"; let input_file_content = include_str!("evals/fixtures/delete_run_git_blame/before.rs"); let output_file_content = include_str!("evals/fixtures/delete_run_git_blame/after.rs"); @@ -184,6 +181,7 @@ fn eval_translate_doc_comments() { // gemini-2.5-pro-preview-03-25 | 1.0 (2025-05-22) // gemini-2.5-flash-preview-04-17 | // gpt-4.1 | + let input_file_path = "root/canvas.rs"; let input_file_content = include_str!("evals/fixtures/translate_doc_comments/before.rs"); let edit_description = "Translate all doc comments to Italian"; @@ -246,6 +244,7 @@ fn eval_use_wasi_sdk_in_compile_parser_to_wasm() { // gemini-2.5-pro-preview-latest | 0.99 (2025-06-16) // gemini-2.5-flash-preview-04-17 | // gpt-4.1 | + let input_file_path = "root/lib.rs"; let input_file_content = include_str!("evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs"); @@ -371,6 +370,7 @@ fn eval_disable_cursor_blinking() { // gemini-2.5-pro | 0.95 (2025-07-14) // gemini-2.5-flash-preview-04-17 | 0.78 (2025-07-14) // gpt-4.1 | 0.00 (2025-07-14) (follows edit_description too literally) + let input_file_path = "root/editor.rs"; let input_file_content = include_str!("evals/fixtures/disable_cursor_blinking/before.rs"); let edit_description = "Comment out the call to `BlinkManager::enable`"; @@ -463,6 +463,7 @@ fn eval_from_pixels_constructor() { // claude-3.7-sonnet | 2025-06-14 | 0.88 // gemini-2.5-pro-preview-06-05 | 2025-06-16 | 0.98 // gpt-4.1 | + let input_file_path = "root/canvas.rs"; let input_file_content = include_str!("evals/fixtures/from_pixels_constructor/before.rs"); let edit_description = "Implement from_pixels constructor and add tests."; @@ -665,6 +666,7 @@ fn eval_zode() { // gemini-2.5-pro-preview-03-25 | 1.0 (2025-05-22) // gemini-2.5-flash-preview-04-17 | 1.0 (2025-05-22) // gpt-4.1 | 1.0 (2025-05-22) + let input_file_path = "root/zode.py"; let input_content = None; let edit_description = "Create the main Zode CLI script"; @@ -771,6 +773,7 @@ fn eval_add_overwrite_test() { // gemini-2.5-pro-preview-03-25 | 0.35 (2025-05-22) // gemini-2.5-flash-preview-04-17 | // gpt-4.1 | + let input_file_path = "root/action_log.rs"; let input_file_content = include_str!("evals/fixtures/add_overwrite_test/before.rs"); let edit_description = "Add a new test for overwriting a file in action_log.rs"; @@ -1010,7 +1013,7 @@ fn eval_create_empty_file() { // // TODO: gpt-4.1-mini errored 38 times: // "data did not match any variant of untagged enum ResponseStreamResult" - // + let input_file_content = None; let expected_output_content = String::new(); eval( @@ -1475,19 +1478,16 @@ impl EditAgentTest { language::init(cx); language_model::init(client.clone(), cx); language_models::init(user_store, client.clone(), cx); - crate::init(client.http_client(), cx); }); fs.insert_tree("/root", json!({})).await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; let agent_model = SelectedModel::from_str( - &std::env::var("ZED_AGENT_MODEL") - .unwrap_or("anthropic/claude-3-7-sonnet-latest".into()), + &std::env::var("ZED_AGENT_MODEL").unwrap_or("anthropic/claude-4-sonnet-latest".into()), ) .unwrap(); let judge_model = SelectedModel::from_str( - &std::env::var("ZED_JUDGE_MODEL") - .unwrap_or("anthropic/claude-3-7-sonnet-latest".into()), + &std::env::var("ZED_JUDGE_MODEL").unwrap_or("anthropic/claude-4-sonnet-latest".into()), ) .unwrap(); let (agent_model, judge_model) = cx @@ -1553,39 +1553,27 @@ impl EditAgentTest { .update(cx, |project, cx| project.open_buffer(path, cx)) .await .unwrap(); - let tools = cx.update(|cx| { - ToolRegistry::default_global(cx) - .tools() - .into_iter() - .filter_map(|tool| { - let input_schema = tool - .input_schema(self.agent.model.tool_input_format()) - .ok()?; - Some(LanguageModelRequestTool { - name: tool.name(), - description: tool.description(), - input_schema, - }) - }) - .collect::>() - }); - let tool_names = tools - .iter() - .map(|tool| tool.name.clone()) - .collect::>(); - let worktrees = vec![WorktreeContext { - root_name: "root".to_string(), - abs_path: Path::new("/path/to/root").into(), - rules_file: None, - }]; - let prompt_builder = PromptBuilder::new(None)?; - let project_context = ProjectContext::new(worktrees, Vec::default()); - let system_prompt = prompt_builder.generate_assistant_system_prompt( - &project_context, - &ModelContext { + + let tools = crate::built_in_tools().collect::>(); + + let system_prompt = { + let worktrees = vec![WorktreeContext { + root_name: "root".to_string(), + abs_path: Path::new("/path/to/root").into(), + rules_file: None, + }]; + let project_context = ProjectContext::new(worktrees, Vec::default()); + let tool_names = tools + .iter() + .map(|tool| tool.name.clone().into()) + .collect::>(); + let template = crate::SystemPromptTemplate { + project: &project_context, available_tools: tool_names, - }, - )?; + }; + let templates = Templates::new(); + template.render(&templates).unwrap() + }; let has_system_prompt = eval .conversation diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/add_overwrite_test/before.rs b/crates/agent/src/edit_agent/evals/fixtures/add_overwrite_test/before.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/add_overwrite_test/before.rs rename to crates/agent/src/edit_agent/evals/fixtures/add_overwrite_test/before.rs diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/delete_run_git_blame/after.rs b/crates/agent/src/edit_agent/evals/fixtures/delete_run_git_blame/after.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/delete_run_git_blame/after.rs rename to crates/agent/src/edit_agent/evals/fixtures/delete_run_git_blame/after.rs diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/delete_run_git_blame/before.rs b/crates/agent/src/edit_agent/evals/fixtures/delete_run_git_blame/before.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/delete_run_git_blame/before.rs rename to crates/agent/src/edit_agent/evals/fixtures/delete_run_git_blame/before.rs diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/before.rs b/crates/agent/src/edit_agent/evals/fixtures/disable_cursor_blinking/before.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/before.rs rename to crates/agent/src/edit_agent/evals/fixtures/disable_cursor_blinking/before.rs diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff b/crates/agent/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff rename to crates/agent/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-01.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff b/crates/agent/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff rename to crates/agent/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-02.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff b/crates/agent/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff rename to crates/agent/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-03.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff b/crates/agent/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff rename to crates/agent/src/edit_agent/evals/fixtures/disable_cursor_blinking/possible-04.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/before.rs b/crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/before.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/before.rs rename to crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/before.rs diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-01.diff b/crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-01.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-01.diff rename to crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-01.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-02.diff b/crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-02.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-02.diff rename to crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-02.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-03.diff b/crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-03.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-03.diff rename to crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-03.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-04.diff b/crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-04.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-04.diff rename to crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-04.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-05.diff b/crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-05.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-05.diff rename to crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-05.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-06.diff b/crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-06.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-06.diff rename to crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-06.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-07.diff b/crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-07.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-07.diff rename to crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-07.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-08.diff b/crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-08.diff similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-08.diff rename to crates/agent/src/edit_agent/evals/fixtures/extract_handle_command_output/possible-08.diff diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/from_pixels_constructor/before.rs b/crates/agent/src/edit_agent/evals/fixtures/from_pixels_constructor/before.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/from_pixels_constructor/before.rs rename to crates/agent/src/edit_agent/evals/fixtures/from_pixels_constructor/before.rs diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/translate_doc_comments/before.rs b/crates/agent/src/edit_agent/evals/fixtures/translate_doc_comments/before.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/translate_doc_comments/before.rs rename to crates/agent/src/edit_agent/evals/fixtures/translate_doc_comments/before.rs diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs b/crates/agent/src/edit_agent/evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs rename to crates/agent/src/edit_agent/evals/fixtures/use_wasi_sdk_in_compile_parser_to_wasm/before.rs diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/zode/prompt.md b/crates/agent/src/edit_agent/evals/fixtures/zode/prompt.md similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/zode/prompt.md rename to crates/agent/src/edit_agent/evals/fixtures/zode/prompt.md diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/zode/react.py b/crates/agent/src/edit_agent/evals/fixtures/zode/react.py similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/zode/react.py rename to crates/agent/src/edit_agent/evals/fixtures/zode/react.py diff --git a/crates/assistant_tools/src/edit_agent/evals/fixtures/zode/react_test.py b/crates/agent/src/edit_agent/evals/fixtures/zode/react_test.py similarity index 100% rename from crates/assistant_tools/src/edit_agent/evals/fixtures/zode/react_test.py rename to crates/agent/src/edit_agent/evals/fixtures/zode/react_test.py diff --git a/crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs b/crates/agent/src/edit_agent/streaming_fuzzy_matcher.rs similarity index 100% rename from crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs rename to crates/agent/src/edit_agent/streaming_fuzzy_matcher.rs diff --git a/crates/agent2/src/history_store.rs b/crates/agent/src/history_store.rs similarity index 80% rename from crates/agent2/src/history_store.rs rename to crates/agent/src/history_store.rs index ff6caacc78e5dba4ee38f160fa6ded7fcb45a845..c342110f3ee289b6e84241517b69fe9a86efcf16 100644 --- a/crates/agent2/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -1,4 +1,4 @@ -use crate::{DbThreadMetadata, ThreadsDatabase}; +use crate::{DbThread, DbThreadMetadata, ThreadsDatabase}; use acp_thread::MentionUri; use agent_client_protocol as acp; use anyhow::{Context as _, Result, anyhow}; @@ -8,8 +8,9 @@ use db::kvp::KEY_VALUE_STORE; use gpui::{App, AsyncApp, Entity, SharedString, Task, prelude::*}; use itertools::Itertools; use paths::contexts_dir; +use project::Project; use serde::{Deserialize, Serialize}; -use std::{collections::VecDeque, path::Path, sync::Arc, time::Duration}; +use std::{collections::VecDeque, path::Path, rc::Rc, sync::Arc, time::Duration}; use ui::ElementId; use util::ResultExt as _; @@ -19,6 +20,33 @@ const SAVE_RECENTLY_OPENED_ENTRIES_DEBOUNCE: Duration = Duration::from_millis(50 const DEFAULT_TITLE: &SharedString = &SharedString::new_static("New Thread"); +//todo: We should remove this function once we support loading all acp thread +pub fn load_agent_thread( + session_id: acp::SessionId, + history_store: Entity, + project: Entity, + cx: &mut App, +) -> Task>> { + use agent_servers::{AgentServer, AgentServerDelegate}; + + let server = Rc::new(crate::NativeAgentServer::new( + project.read(cx).fs().clone(), + history_store, + )); + let delegate = AgentServerDelegate::new( + project.read(cx).agent_server_store().clone(), + project.clone(), + None, + None, + ); + let connection = server.connect(None, delegate, cx); + cx.spawn(async move |cx| { + let (agent, _) = connection.await?; + let agent = agent.downcast::().unwrap(); + cx.update(|cx| agent.load_thread(session_id, cx))?.await + }) +} + #[derive(Clone, Debug)] pub enum HistoryEntry { AcpThread(DbThreadMetadata), @@ -55,8 +83,13 @@ impl HistoryEntry { pub fn title(&self) -> &SharedString { match self { - HistoryEntry::AcpThread(thread) if thread.title.is_empty() => DEFAULT_TITLE, - HistoryEntry::AcpThread(thread) => &thread.title, + HistoryEntry::AcpThread(thread) => { + if thread.title.is_empty() { + DEFAULT_TITLE + } else { + &thread.title + } + } HistoryEntry::TextThread(context) => &context.title, } } @@ -87,7 +120,7 @@ enum SerializedRecentOpen { pub struct HistoryStore { threads: Vec, entries: Vec, - context_store: Entity, + text_thread_store: Entity, recently_opened_entries: VecDeque, _subscriptions: Vec, _save_recently_opened_entries_task: Task<()>, @@ -95,10 +128,11 @@ pub struct HistoryStore { impl HistoryStore { pub fn new( - context_store: Entity, + text_thread_store: Entity, cx: &mut Context, ) -> Self { - let subscriptions = vec![cx.observe(&context_store, |this, _, cx| this.update_entries(cx))]; + let subscriptions = + vec![cx.observe(&text_thread_store, |this, _, cx| this.update_entries(cx))]; cx.spawn(async move |this, cx| { let entries = Self::load_recently_opened_entries(cx).await; @@ -114,7 +148,7 @@ impl HistoryStore { .detach(); Self { - context_store, + text_thread_store, recently_opened_entries: VecDeque::default(), threads: Vec::default(), entries: Vec::default(), @@ -127,6 +161,18 @@ impl HistoryStore { self.threads.iter().find(|thread| &thread.id == session_id) } + pub fn load_thread( + &mut self, + id: acp::SessionId, + cx: &mut Context, + ) -> Task>> { + let database_future = ThreadsDatabase::connect(cx); + cx.background_spawn(async move { + let database = database_future.await.map_err(|err| anyhow!(err))?; + database.load_thread(id).await + }) + } + pub fn delete_thread( &mut self, id: acp::SessionId, @@ -145,9 +191,8 @@ impl HistoryStore { path: Arc, cx: &mut Context, ) -> Task> { - self.context_store.update(cx, |context_store, cx| { - context_store.delete_local_context(path, cx) - }) + self.text_thread_store + .update(cx, |store, cx| store.delete_local_context(path, cx)) } pub fn load_text_thread( @@ -155,9 +200,8 @@ impl HistoryStore { path: Arc, cx: &mut Context, ) -> Task>> { - self.context_store.update(cx, |context_store, cx| { - context_store.open_local_context(path, cx) - }) + self.text_thread_store + .update(cx, |store, cx| store.open_local_context(path, cx)) } pub fn reload(&self, cx: &mut Context) { @@ -197,7 +241,7 @@ impl HistoryStore { let mut history_entries = Vec::new(); history_entries.extend(self.threads.iter().cloned().map(HistoryEntry::AcpThread)); history_entries.extend( - self.context_store + self.text_thread_store .read(cx) .unordered_contexts() .cloned() @@ -231,21 +275,21 @@ impl HistoryStore { }) }); - let context_entries = - self.context_store - .read(cx) - .unordered_contexts() - .flat_map(|context| { - self.recently_opened_entries - .iter() - .enumerate() - .flat_map(|(index, entry)| match entry { - HistoryEntryId::TextThread(path) if &context.path == path => { - Some((index, HistoryEntry::TextThread(context.clone()))) - } - _ => None, - }) - }); + let context_entries = self + .text_thread_store + .read(cx) + .unordered_contexts() + .flat_map(|context| { + self.recently_opened_entries + .iter() + .enumerate() + .flat_map(|(index, entry)| match entry { + HistoryEntryId::TextThread(path) if &context.path == path => { + Some((index, HistoryEntry::TextThread(context.clone()))) + } + _ => None, + }) + }); thread_entries .chain(context_entries) diff --git a/crates/agent/src/legacy_thread.rs b/crates/agent/src/legacy_thread.rs new file mode 100644 index 0000000000000000000000000000000000000000..34babb800616e7a3d5390abdaccc0cafa24ff386 --- /dev/null +++ b/crates/agent/src/legacy_thread.rs @@ -0,0 +1,402 @@ +use crate::ProjectSnapshot; +use agent_settings::{AgentProfileId, CompletionMode}; +use anyhow::Result; +use chrono::{DateTime, Utc}; +use gpui::SharedString; +use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub enum DetailedSummaryState { + #[default] + NotGenerated, + Generating, + Generated { + text: SharedString, + }, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] +pub struct MessageId(pub usize); + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub struct SerializedThread { + pub version: String, + pub summary: SharedString, + pub updated_at: DateTime, + pub messages: Vec, + #[serde(default)] + pub initial_project_snapshot: Option>, + #[serde(default)] + pub cumulative_token_usage: TokenUsage, + #[serde(default)] + pub request_token_usage: Vec, + #[serde(default)] + pub detailed_summary_state: DetailedSummaryState, + #[serde(default)] + pub model: Option, + #[serde(default)] + pub completion_mode: Option, + #[serde(default)] + pub tool_use_limit_reached: bool, + #[serde(default)] + pub profile: Option, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub struct SerializedLanguageModel { + pub provider: String, + pub model: String, +} + +impl SerializedThread { + pub const VERSION: &'static str = "0.2.0"; + + pub fn from_json(json: &[u8]) -> Result { + let saved_thread_json = serde_json::from_slice::(json)?; + match saved_thread_json.get("version") { + Some(serde_json::Value::String(version)) => match version.as_str() { + SerializedThreadV0_1_0::VERSION => { + let saved_thread = + serde_json::from_value::(saved_thread_json)?; + Ok(saved_thread.upgrade()) + } + SerializedThread::VERSION => Ok(serde_json::from_value::( + saved_thread_json, + )?), + _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"), + }, + None => { + let saved_thread = + serde_json::from_value::(saved_thread_json)?; + Ok(saved_thread.upgrade()) + } + version => anyhow::bail!("unrecognized serialized thread version: {version:?}"), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SerializedThreadV0_1_0( + // The structure did not change, so we are reusing the latest SerializedThread. + // When making the next version, make sure this points to SerializedThreadV0_2_0 + SerializedThread, +); + +impl SerializedThreadV0_1_0 { + pub const VERSION: &'static str = "0.1.0"; + + pub fn upgrade(self) -> SerializedThread { + debug_assert_eq!(SerializedThread::VERSION, "0.2.0"); + + let mut messages: Vec = Vec::with_capacity(self.0.messages.len()); + + for message in self.0.messages { + if message.role == Role::User + && !message.tool_results.is_empty() + && let Some(last_message) = messages.last_mut() + { + debug_assert!(last_message.role == Role::Assistant); + + last_message.tool_results = message.tool_results; + continue; + } + + messages.push(message); + } + + SerializedThread { + messages, + version: SerializedThread::VERSION.to_string(), + ..self.0 + } + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct SerializedMessage { + pub id: MessageId, + pub role: Role, + #[serde(default)] + pub segments: Vec, + #[serde(default)] + pub tool_uses: Vec, + #[serde(default)] + pub tool_results: Vec, + #[serde(default)] + pub context: String, + #[serde(default)] + pub creases: Vec, + #[serde(default)] + pub is_hidden: bool, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type")] +pub enum SerializedMessageSegment { + #[serde(rename = "text")] + Text { + text: String, + }, + #[serde(rename = "thinking")] + Thinking { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + signature: Option, + }, + RedactedThinking { + data: String, + }, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct SerializedToolUse { + pub id: LanguageModelToolUseId, + pub name: SharedString, + pub input: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct SerializedToolResult { + pub tool_use_id: LanguageModelToolUseId, + pub is_error: bool, + pub content: LanguageModelToolResultContent, + pub output: Option, +} + +#[derive(Serialize, Deserialize)] +struct LegacySerializedThread { + pub summary: SharedString, + pub updated_at: DateTime, + pub messages: Vec, + #[serde(default)] + pub initial_project_snapshot: Option>, +} + +impl LegacySerializedThread { + pub fn upgrade(self) -> SerializedThread { + SerializedThread { + version: SerializedThread::VERSION.to_string(), + summary: self.summary, + updated_at: self.updated_at, + messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(), + initial_project_snapshot: self.initial_project_snapshot, + cumulative_token_usage: TokenUsage::default(), + request_token_usage: Vec::new(), + detailed_summary_state: DetailedSummaryState::default(), + model: None, + completion_mode: None, + tool_use_limit_reached: false, + profile: None, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct LegacySerializedMessage { + pub id: MessageId, + pub role: Role, + pub text: String, + #[serde(default)] + pub tool_uses: Vec, + #[serde(default)] + pub tool_results: Vec, +} + +impl LegacySerializedMessage { + fn upgrade(self) -> SerializedMessage { + SerializedMessage { + id: self.id, + role: self.role, + segments: vec![SerializedMessageSegment::Text { text: self.text }], + tool_uses: self.tool_uses, + tool_results: self.tool_results, + context: String::new(), + creases: Vec::new(), + is_hidden: false, + } + } +} + +#[derive(Debug, Serialize, Deserialize, PartialEq)] +pub struct SerializedCrease { + pub start: usize, + pub end: usize, + pub icon_path: SharedString, + pub label: SharedString, +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::Utc; + use language_model::{Role, TokenUsage}; + use pretty_assertions::assert_eq; + + #[test] + fn test_legacy_serialized_thread_upgrade() { + let updated_at = Utc::now(); + let legacy_thread = LegacySerializedThread { + summary: "Test conversation".into(), + updated_at, + messages: vec![LegacySerializedMessage { + id: MessageId(1), + role: Role::User, + text: "Hello, world!".to_string(), + tool_uses: vec![], + tool_results: vec![], + }], + initial_project_snapshot: None, + }; + + let upgraded = legacy_thread.upgrade(); + + assert_eq!( + upgraded, + SerializedThread { + summary: "Test conversation".into(), + updated_at, + messages: vec![SerializedMessage { + id: MessageId(1), + role: Role::User, + segments: vec![SerializedMessageSegment::Text { + text: "Hello, world!".to_string() + }], + tool_uses: vec![], + tool_results: vec![], + context: "".to_string(), + creases: vec![], + is_hidden: false + }], + version: SerializedThread::VERSION.to_string(), + initial_project_snapshot: None, + cumulative_token_usage: TokenUsage::default(), + request_token_usage: vec![], + detailed_summary_state: DetailedSummaryState::default(), + model: None, + completion_mode: None, + tool_use_limit_reached: false, + profile: None + } + ) + } + + #[test] + fn test_serialized_threadv0_1_0_upgrade() { + let updated_at = Utc::now(); + let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread { + summary: "Test conversation".into(), + updated_at, + messages: vec![ + SerializedMessage { + id: MessageId(1), + role: Role::User, + segments: vec![SerializedMessageSegment::Text { + text: "Use tool_1".to_string(), + }], + tool_uses: vec![], + tool_results: vec![], + context: "".to_string(), + creases: vec![], + is_hidden: false, + }, + SerializedMessage { + id: MessageId(2), + role: Role::Assistant, + segments: vec![SerializedMessageSegment::Text { + text: "I want to use a tool".to_string(), + }], + tool_uses: vec![SerializedToolUse { + id: "abc".into(), + name: "tool_1".into(), + input: serde_json::Value::Null, + }], + tool_results: vec![], + context: "".to_string(), + creases: vec![], + is_hidden: false, + }, + SerializedMessage { + id: MessageId(1), + role: Role::User, + segments: vec![SerializedMessageSegment::Text { + text: "Here is the tool result".to_string(), + }], + tool_uses: vec![], + tool_results: vec![SerializedToolResult { + tool_use_id: "abc".into(), + is_error: false, + content: LanguageModelToolResultContent::Text("abcdef".into()), + output: Some(serde_json::Value::Null), + }], + context: "".to_string(), + creases: vec![], + is_hidden: false, + }, + ], + version: SerializedThreadV0_1_0::VERSION.to_string(), + initial_project_snapshot: None, + cumulative_token_usage: TokenUsage::default(), + request_token_usage: vec![], + detailed_summary_state: DetailedSummaryState::default(), + model: None, + completion_mode: None, + tool_use_limit_reached: false, + profile: None, + }); + let upgraded = thread_v0_1_0.upgrade(); + + assert_eq!( + upgraded, + SerializedThread { + summary: "Test conversation".into(), + updated_at, + messages: vec![ + SerializedMessage { + id: MessageId(1), + role: Role::User, + segments: vec![SerializedMessageSegment::Text { + text: "Use tool_1".to_string() + }], + tool_uses: vec![], + tool_results: vec![], + context: "".to_string(), + creases: vec![], + is_hidden: false + }, + SerializedMessage { + id: MessageId(2), + role: Role::Assistant, + segments: vec![SerializedMessageSegment::Text { + text: "I want to use a tool".to_string(), + }], + tool_uses: vec![SerializedToolUse { + id: "abc".into(), + name: "tool_1".into(), + input: serde_json::Value::Null, + }], + tool_results: vec![SerializedToolResult { + tool_use_id: "abc".into(), + is_error: false, + content: LanguageModelToolResultContent::Text("abcdef".into()), + output: Some(serde_json::Value::Null), + }], + context: "".to_string(), + creases: vec![], + is_hidden: false, + }, + ], + version: SerializedThread::VERSION.to_string(), + initial_project_snapshot: None, + cumulative_token_usage: TokenUsage::default(), + request_token_usage: vec![], + detailed_summary_state: DetailedSummaryState::default(), + model: None, + completion_mode: None, + tool_use_limit_reached: false, + profile: None + } + ) + } +} diff --git a/crates/agent2/src/native_agent_server.rs b/crates/agent/src/native_agent_server.rs similarity index 100% rename from crates/agent2/src/native_agent_server.rs rename to crates/agent/src/native_agent_server.rs diff --git a/crates/assistant_tool/src/outline.rs b/crates/agent/src/outline.rs similarity index 76% rename from crates/assistant_tool/src/outline.rs rename to crates/agent/src/outline.rs index 4c8e2efefd67e25c630d38e16bda8a8dff34fb16..bc78290fb52ae208742b9dea0e6dbbe560022419 100644 --- a/crates/assistant_tool/src/outline.rs +++ b/crates/agent/src/outline.rs @@ -1,8 +1,6 @@ -use action_log::ActionLog; -use anyhow::{Context as _, Result}; +use anyhow::Result; use gpui::{AsyncApp, Entity}; use language::{Buffer, OutlineItem, ParseStatus}; -use project::Project; use regex::Regex; use std::fmt::Write; use text::Point; @@ -11,51 +9,66 @@ use text::Point; /// we automatically provide the file's symbol outline instead, with line numbers. pub const AUTO_OUTLINE_SIZE: usize = 16384; -pub async fn file_outline( - project: Entity, - path: String, - action_log: Entity, - regex: Option, - cx: &mut AsyncApp, -) -> anyhow::Result { - let buffer = { - let project_path = project.read_with(cx, |project, cx| { - project - .find_project_path(&path, cx) - .with_context(|| format!("Path {path} not found in project")) - })??; - - project - .update(cx, |project, cx| project.open_buffer(project_path, cx))? - .await? - }; +/// Result of getting buffer content, which can be either full content or an outline. +pub struct BufferContent { + /// The actual content (either full text or outline) + pub text: String, + /// Whether this is an outline (true) or full content (false) + pub is_outline: bool, +} - action_log.update(cx, |action_log, cx| { - action_log.buffer_read(buffer.clone(), cx); - })?; +/// Returns either the full content of a buffer or its outline, depending on size. +/// For files larger than AUTO_OUTLINE_SIZE, returns an outline with a header. +/// For smaller files, returns the full content. +pub async fn get_buffer_content_or_outline( + buffer: Entity, + path: Option<&str>, + cx: &AsyncApp, +) -> Result { + let file_size = buffer.read_with(cx, |buffer, _| buffer.text().len())?; - // Wait until the buffer has been fully parsed, so that we can read its outline. - let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?; - while *parse_status.borrow() != ParseStatus::Idle { - parse_status.changed().await?; - } + if file_size > AUTO_OUTLINE_SIZE { + // For large files, use outline instead of full content + // Wait until the buffer has been fully parsed, so we can read its outline + let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?; + while *parse_status.borrow() != ParseStatus::Idle { + parse_status.changed().await?; + } + + let outline_items = buffer.read_with(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + snapshot + .outline(None) + .items + .into_iter() + .map(|item| item.to_point(&snapshot)) + .collect::>() + })?; - let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; - let outline = snapshot.outline(None); - - render_outline( - outline - .items - .into_iter() - .map(|item| item.to_point(&snapshot)), - regex, - 0, - usize::MAX, - ) - .await + let outline_text = render_outline(outline_items, None, 0, usize::MAX).await?; + + let text = if let Some(path) = path { + format!( + "# File outline for {path} (file too large to show full content)\n\n{outline_text}", + ) + } else { + format!("# File outline (file too large to show full content)\n\n{outline_text}",) + }; + Ok(BufferContent { + text, + is_outline: true, + }) + } else { + // File is small enough, return full content + let text = buffer.read_with(cx, |buffer, _| buffer.text())?; + Ok(BufferContent { + text, + is_outline: false, + }) + } } -pub async fn render_outline( +async fn render_outline( items: impl IntoIterator>, regex: Option, offset: usize, @@ -128,62 +141,3 @@ fn render_entries( entries_rendered } - -/// Result of getting buffer content, which can be either full content or an outline. -pub struct BufferContent { - /// The actual content (either full text or outline) - pub text: String, - /// Whether this is an outline (true) or full content (false) - pub is_outline: bool, -} - -/// Returns either the full content of a buffer or its outline, depending on size. -/// For files larger than AUTO_OUTLINE_SIZE, returns an outline with a header. -/// For smaller files, returns the full content. -pub async fn get_buffer_content_or_outline( - buffer: Entity, - path: Option<&str>, - cx: &AsyncApp, -) -> Result { - let file_size = buffer.read_with(cx, |buffer, _| buffer.text().len())?; - - if file_size > AUTO_OUTLINE_SIZE { - // For large files, use outline instead of full content - // Wait until the buffer has been fully parsed, so we can read its outline - let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?; - while *parse_status.borrow() != ParseStatus::Idle { - parse_status.changed().await?; - } - - let outline_items = buffer.read_with(cx, |buffer, _| { - let snapshot = buffer.snapshot(); - snapshot - .outline(None) - .items - .into_iter() - .map(|item| item.to_point(&snapshot)) - .collect::>() - })?; - - let outline_text = render_outline(outline_items, None, 0, usize::MAX).await?; - - let text = if let Some(path) = path { - format!( - "# File outline for {path} (file too large to show full content)\n\n{outline_text}", - ) - } else { - format!("# File outline (file too large to show full content)\n\n{outline_text}",) - }; - Ok(BufferContent { - text, - is_outline: true, - }) - } else { - // File is small enough, return full content - let text = buffer.read_with(cx, |buffer, _| buffer.text())?; - Ok(BufferContent { - text, - is_outline: false, - }) - } -} diff --git a/crates/agent/src/prompts/stale_files_prompt_header.txt b/crates/agent/src/prompts/stale_files_prompt_header.txt deleted file mode 100644 index f743e239c883c7456f7bdc6e089185c6b994cb44..0000000000000000000000000000000000000000 --- a/crates/agent/src/prompts/stale_files_prompt_header.txt +++ /dev/null @@ -1,3 +0,0 @@ -[The following is an auto-generated notification; do not reply] - -These files have changed since the last read: diff --git a/crates/agent2/src/templates.rs b/crates/agent/src/templates.rs similarity index 100% rename from crates/agent2/src/templates.rs rename to crates/agent/src/templates.rs diff --git a/crates/assistant_tools/src/templates/create_file_prompt.hbs b/crates/agent/src/templates/create_file_prompt.hbs similarity index 100% rename from crates/assistant_tools/src/templates/create_file_prompt.hbs rename to crates/agent/src/templates/create_file_prompt.hbs diff --git a/crates/assistant_tools/src/templates/diff_judge.hbs b/crates/agent/src/templates/diff_judge.hbs similarity index 100% rename from crates/assistant_tools/src/templates/diff_judge.hbs rename to crates/agent/src/templates/diff_judge.hbs diff --git a/crates/assistant_tools/src/templates/edit_file_prompt_diff_fenced.hbs b/crates/agent/src/templates/edit_file_prompt_diff_fenced.hbs similarity index 100% rename from crates/assistant_tools/src/templates/edit_file_prompt_diff_fenced.hbs rename to crates/agent/src/templates/edit_file_prompt_diff_fenced.hbs diff --git a/crates/assistant_tools/src/templates/edit_file_prompt_xml.hbs b/crates/agent/src/templates/edit_file_prompt_xml.hbs similarity index 100% rename from crates/assistant_tools/src/templates/edit_file_prompt_xml.hbs rename to crates/agent/src/templates/edit_file_prompt_xml.hbs diff --git a/crates/agent2/src/templates/system_prompt.hbs b/crates/agent/src/templates/system_prompt.hbs similarity index 100% rename from crates/agent2/src/templates/system_prompt.hbs rename to crates/agent/src/templates/system_prompt.hbs diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent/src/tests/mod.rs similarity index 99% rename from crates/agent2/src/tests/mod.rs rename to crates/agent/src/tests/mod.rs index 2e63aa5856501f880fec94f7659b13be321b03b3..6b7d30b37f825bf664ee270bee9f965ee194291c 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -975,9 +975,9 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { vec![context_server::types::Tool { name: "echo".into(), description: None, - input_schema: serde_json::to_value( - EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema), - ) + input_schema: serde_json::to_value(EchoTool::input_schema( + LanguageModelToolSchemaFormat::JsonSchema, + )) .unwrap(), output_schema: None, annotations: None, @@ -1149,9 +1149,9 @@ async fn test_mcp_tool_truncation(cx: &mut TestAppContext) { context_server::types::Tool { name: "echo".into(), // Conflicts with native EchoTool description: None, - input_schema: serde_json::to_value( - EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema), - ) + input_schema: serde_json::to_value(EchoTool::input_schema( + LanguageModelToolSchemaFormat::JsonSchema, + )) .unwrap(), output_schema: None, annotations: None, @@ -1174,9 +1174,9 @@ async fn test_mcp_tool_truncation(cx: &mut TestAppContext) { context_server::types::Tool { name: "echo".into(), // Also conflicts with native EchoTool description: None, - input_schema: serde_json::to_value( - EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema), - ) + input_schema: serde_json::to_value(EchoTool::input_schema( + LanguageModelToolSchemaFormat::JsonSchema, + )) .unwrap(), output_schema: None, annotations: None, @@ -1864,7 +1864,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let selector_opt = connection.model_selector(&session_id); assert!( selector_opt.is_some(), - "agent2 should always support ModelSelector" + "agent should always support ModelSelector" ); let selector = selector_opt.unwrap(); diff --git a/crates/agent2/src/tests/test_tools.rs b/crates/agent/src/tests/test_tools.rs similarity index 100% rename from crates/agent2/src/tests/test_tools.rs rename to crates/agent/src/tests/test_tools.rs diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index d189b7611209d2fbea5c882ea548318f73ddbfb3..ec9d50ff2f62c5602dd91e5da47593764ea01c85 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1,95 +1,65 @@ use crate::{ - agent_profile::AgentProfile, - context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext}, - thread_store::{ - SerializedCrease, SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, - SerializedThread, SerializedToolResult, SerializedToolUse, SharedProjectContext, - ThreadStore, - }, - tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState}, + ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread, + DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GitState, GrepTool, + ListDirectoryTool, MovePathTool, NowTool, OpenTool, ProjectSnapshot, ReadFileTool, + SystemPromptTemplate, Template, Templates, TerminalTool, ThinkingTool, WebSearchTool, + WorktreeSnapshot, }; +use acp_thread::{MentionUri, UserMessageId}; use action_log::ActionLog; + +use agent_client_protocol as acp; use agent_settings::{ - AgentProfileId, AgentSettings, CompletionMode, SUMMARIZE_THREAD_DETAILED_PROMPT, - SUMMARIZE_THREAD_PROMPT, + AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode, + SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT, }; -use anyhow::{Result, anyhow}; -use assistant_tool::{AnyToolCard, Tool, ToolWorkingSet}; +use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; -use client::{ModelRequestUsage, RequestUsage}; +use client::{ModelRequestUsage, RequestUsage, UserStore}; use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; -use collections::HashMap; -use futures::{FutureExt, StreamExt as _, future::Shared}; +use collections::{HashMap, HashSet, IndexMap}; +use fs::Fs; +use futures::stream; +use futures::{ + FutureExt, + channel::{mpsc, oneshot}, + future::Shared, + stream::FuturesUnordered, +}; use git::repository::DiffType; use gpui::{ - AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, - WeakEntity, Window, + App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, }; -use http_client::StatusCode; use language_model::{ - ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelExt as _, LanguageModelId, LanguageModelRegistry, LanguageModelRequest, + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt, + LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolResultContent, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, - ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason, - TokenUsage, + LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, + LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID, }; -use postage::stream::Stream as _; use project::{ Project, - git_store::{GitStore, GitStoreCheckpoint, RepositoryState}, + git_store::{GitStore, RepositoryState}, }; -use prompt_store::{ModelContext, PromptBuilder}; -use schemars::JsonSchema; +use prompt_store::ProjectContext; +use schemars::{JsonSchema, Schema}; use serde::{Deserialize, Serialize}; -use settings::Settings; +use settings::{Settings, update_settings_file}; +use smol::stream::StreamExt; use std::{ - io::Write, - ops::Range, + collections::BTreeMap, + ops::RangeInclusive, + path::Path, + rc::Rc, sync::Arc, time::{Duration, Instant}, }; -use thiserror::Error; -use util::{ResultExt as _, post_inc}; +use std::{fmt::Write, path::PathBuf}; +use util::{ResultExt, debug_panic, markdown::MarkdownCodeBlock}; use uuid::Uuid; -const MAX_RETRY_ATTEMPTS: u8 = 4; -const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); - -#[derive(Debug, Clone)] -enum RetryStrategy { - ExponentialBackoff { - initial_delay: Duration, - max_attempts: u8, - }, - Fixed { - delay: Duration, - max_attempts: u8, - }, -} - -#[derive( - Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema, -)] -pub struct ThreadId(Arc); - -impl ThreadId { - pub fn new() -> Self { - Self(Uuid::new_v4().to_string().into()) - } -} - -impl std::fmt::Display for ThreadId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl From<&str> for ThreadId { - fn from(value: &str) -> Self { - Self(value.into()) - } -} +const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; +pub const MAX_TOOL_NAME_LENGTH: usize = 64; /// The ID of the user prompt that initiated a request. /// @@ -109,2014 +79,1958 @@ impl std::fmt::Display for PromptId { } } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] -pub struct MessageId(pub usize); - -impl MessageId { - fn post_inc(&mut self) -> Self { - Self(post_inc(&mut self.0)) - } - - pub fn as_usize(&self) -> usize { - self.0 - } -} +pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4; +pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); -/// Stored information that can be used to resurrect a context crease when creating an editor for a past message. -#[derive(Clone, Debug)] -pub struct MessageCrease { - pub range: Range, - pub icon_path: SharedString, - pub label: SharedString, - /// None for a deserialized message, Some otherwise. - pub context: Option, +#[derive(Debug, Clone)] +enum RetryStrategy { + ExponentialBackoff { + initial_delay: Duration, + max_attempts: u8, + }, + Fixed { + delay: Duration, + max_attempts: u8, + }, } -/// A message in a [`Thread`]. -#[derive(Debug, Clone)] -pub struct Message { - pub id: MessageId, - pub role: Role, - pub segments: Vec, - pub loaded_context: LoadedContext, - pub creases: Vec, - pub is_hidden: bool, - pub ui_only: bool, +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum Message { + User(UserMessage), + Agent(AgentMessage), + Resume, } impl Message { - /// Returns whether the message contains any meaningful text that should be displayed - /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`]) - pub fn should_display_content(&self) -> bool { - self.segments.iter().all(|segment| segment.should_display()) + pub fn as_agent_message(&self) -> Option<&AgentMessage> { + match self { + Message::Agent(agent_message) => Some(agent_message), + _ => None, + } } - pub fn push_thinking(&mut self, text: &str, signature: Option) { - if let Some(MessageSegment::Thinking { - text: segment, - signature: current_signature, - }) = self.segments.last_mut() - { - if let Some(signature) = signature { - *current_signature = Some(signature); - } - segment.push_str(text); - } else { - self.segments.push(MessageSegment::Thinking { - text: text.to_string(), - signature, - }); + pub fn to_request(&self) -> Vec { + match self { + Message::User(message) => vec![message.to_request()], + Message::Agent(message) => message.to_request(), + Message::Resume => vec![LanguageModelRequestMessage { + role: Role::User, + content: vec!["Continue where you left off".into()], + cache: false, + }], } } - pub fn push_redacted_thinking(&mut self, data: String) { - self.segments.push(MessageSegment::RedactedThinking(data)); + pub fn to_markdown(&self) -> String { + match self { + Message::User(message) => message.to_markdown(), + Message::Agent(message) => message.to_markdown(), + Message::Resume => "[resume]\n".into(), + } } - pub fn push_text(&mut self, text: &str) { - if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() { - segment.push_str(text); - } else { - self.segments.push(MessageSegment::Text(text.to_string())); + pub fn role(&self) -> Role { + match self { + Message::User(_) | Message::Resume => Role::User, + Message::Agent(_) => Role::Assistant, } } +} - pub fn to_message_content(&self) -> String { - let mut result = String::new(); +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct UserMessage { + pub id: UserMessageId, + pub content: Vec, +} - if !self.loaded_context.text.is_empty() { - result.push_str(&self.loaded_context.text); - } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UserMessageContent { + Text(String), + Mention { uri: MentionUri, content: String }, + Image(LanguageModelImage), +} + +impl UserMessage { + pub fn to_markdown(&self) -> String { + let mut markdown = String::from("## User\n\n"); - for segment in &self.segments { - match segment { - MessageSegment::Text(text) => result.push_str(text), - MessageSegment::Thinking { text, .. } => { - result.push_str("\n"); - result.push_str(text); - result.push_str("\n"); + for content in &self.content { + match content { + UserMessageContent::Text(text) => { + markdown.push_str(text); + markdown.push('\n'); + } + UserMessageContent::Image(_) => { + markdown.push_str("\n"); + } + UserMessageContent::Mention { uri, content } => { + if !content.is_empty() { + let _ = writeln!(&mut markdown, "{}\n\n{}", uri.as_link(), content); + } else { + let _ = writeln!(&mut markdown, "{}", uri.as_link()); + } } - MessageSegment::RedactedThinking(_) => {} } } - result + markdown } -} -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum MessageSegment { - Text(String), - Thinking { - text: String, - signature: Option, - }, - RedactedThinking(String), -} + fn to_request(&self) -> LanguageModelRequestMessage { + let mut message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::with_capacity(self.content.len()), + cache: false, + }; -impl MessageSegment { - pub fn should_display(&self) -> bool { - match self { - Self::Text(text) => text.is_empty(), - Self::Thinking { text, .. } => text.is_empty(), - Self::RedactedThinking(_) => false, + const OPEN_CONTEXT: &str = "\n\ + The following items were attached by the user. \ + They are up-to-date and don't need to be re-read.\n\n"; + + const OPEN_FILES_TAG: &str = ""; + const OPEN_DIRECTORIES_TAG: &str = ""; + const OPEN_SYMBOLS_TAG: &str = ""; + const OPEN_SELECTIONS_TAG: &str = ""; + const OPEN_THREADS_TAG: &str = ""; + const OPEN_FETCH_TAG: &str = ""; + const OPEN_RULES_TAG: &str = + "\nThe user has specified the following rules that should be applied:\n"; + + let mut file_context = OPEN_FILES_TAG.to_string(); + let mut directory_context = OPEN_DIRECTORIES_TAG.to_string(); + let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); + let mut selection_context = OPEN_SELECTIONS_TAG.to_string(); + let mut thread_context = OPEN_THREADS_TAG.to_string(); + let mut fetch_context = OPEN_FETCH_TAG.to_string(); + let mut rules_context = OPEN_RULES_TAG.to_string(); + + for chunk in &self.content { + let chunk = match chunk { + UserMessageContent::Text(text) => { + language_model::MessageContent::Text(text.clone()) + } + UserMessageContent::Image(value) => { + language_model::MessageContent::Image(value.clone()) + } + UserMessageContent::Mention { uri, content } => { + match uri { + MentionUri::File { abs_path } => { + write!( + &mut file_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag(abs_path, None), + text: &content.to_string(), + } + ) + .ok(); + } + MentionUri::PastedImage => { + debug_panic!("pasted image URI should not be used in mention content") + } + MentionUri::Directory { .. } => { + write!(&mut directory_context, "\n{}\n", content).ok(); + } + MentionUri::Symbol { + abs_path: path, + line_range, + .. + } => { + write!( + &mut symbol_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag(path, Some(line_range)), + text: content + } + ) + .ok(); + } + MentionUri::Selection { + abs_path: path, + line_range, + .. + } => { + write!( + &mut selection_context, + "\n{}", + MarkdownCodeBlock { + tag: &codeblock_tag( + path.as_deref().unwrap_or("Untitled".as_ref()), + Some(line_range) + ), + text: content + } + ) + .ok(); + } + MentionUri::Thread { .. } => { + write!(&mut thread_context, "\n{}\n", content).ok(); + } + MentionUri::TextThread { .. } => { + write!(&mut thread_context, "\n{}\n", content).ok(); + } + MentionUri::Rule { .. } => { + write!( + &mut rules_context, + "\n{}", + MarkdownCodeBlock { + tag: "", + text: content + } + ) + .ok(); + } + MentionUri::Fetch { url } => { + write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok(); + } + } + + language_model::MessageContent::Text(uri.as_link().to_string()) + } + }; + + message.content.push(chunk); } - } - pub fn text(&self) -> Option<&str> { - match self { - MessageSegment::Text(text) => Some(text), - _ => None, + let len_before_context = message.content.len(); + + if file_context.len() > OPEN_FILES_TAG.len() { + file_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(file_context)); } - } -} -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct ProjectSnapshot { - pub worktree_snapshots: Vec, - pub timestamp: DateTime, -} + if directory_context.len() > OPEN_DIRECTORIES_TAG.len() { + directory_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(directory_context)); + } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct WorktreeSnapshot { - pub worktree_path: String, - pub git_state: Option, -} + if symbol_context.len() > OPEN_SYMBOLS_TAG.len() { + symbol_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(symbol_context)); + } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct GitState { - pub remote_url: Option, - pub head_sha: Option, - pub current_branch: Option, - pub diff: Option, -} + if selection_context.len() > OPEN_SELECTIONS_TAG.len() { + selection_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(selection_context)); + } -#[derive(Clone, Debug)] -pub struct ThreadCheckpoint { - message_id: MessageId, - git_checkpoint: GitStoreCheckpoint, -} + if thread_context.len() > OPEN_THREADS_TAG.len() { + thread_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(thread_context)); + } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ThreadFeedback { - Positive, - Negative, -} + if fetch_context.len() > OPEN_FETCH_TAG.len() { + fetch_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(fetch_context)); + } -pub enum LastRestoreCheckpoint { - Pending { - message_id: MessageId, - }, - Error { - message_id: MessageId, - error: String, - }, -} + if rules_context.len() > OPEN_RULES_TAG.len() { + rules_context.push_str("\n"); + message + .content + .push(language_model::MessageContent::Text(rules_context)); + } -impl LastRestoreCheckpoint { - pub fn message_id(&self) -> MessageId { - match self { - LastRestoreCheckpoint::Pending { message_id } => *message_id, - LastRestoreCheckpoint::Error { message_id, .. } => *message_id, + if message.content.len() > len_before_context { + message.content.insert( + len_before_context, + language_model::MessageContent::Text(OPEN_CONTEXT.into()), + ); + message + .content + .push(language_model::MessageContent::Text("".into())); } + + message } } -#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] -pub enum DetailedSummaryState { - #[default] - NotGenerated, - Generating { - message_id: MessageId, - }, - Generated { - text: SharedString, - message_id: MessageId, - }, -} +fn codeblock_tag(full_path: &Path, line_range: Option<&RangeInclusive>) -> String { + let mut result = String::new(); -impl DetailedSummaryState { - fn text(&self) -> Option { - if let Self::Generated { text, .. } = self { - Some(text.clone()) + if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) { + let _ = write!(result, "{} ", extension); + } + + let _ = write!(result, "{}", full_path.display()); + + if let Some(range) = line_range { + if range.start() == range.end() { + let _ = write!(result, ":{}", range.start() + 1); } else { - None + let _ = write!(result, ":{}-{}", range.start() + 1, range.end() + 1); } } -} -#[derive(Default, Debug)] -pub struct TotalTokenUsage { - pub total: u64, - pub max: u64, + result } -impl TotalTokenUsage { - pub fn ratio(&self) -> TokenUsageRatio { - #[cfg(debug_assertions)] - let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") - .unwrap_or("0.8".to_string()) - .parse() - .unwrap(); - #[cfg(not(debug_assertions))] - let warning_threshold: f32 = 0.8; - - // When the maximum is unknown because there is no selected model, - // avoid showing the token limit warning. - if self.max == 0 { - TokenUsageRatio::Normal - } else if self.total >= self.max { - TokenUsageRatio::Exceeded - } else if self.total as f32 / self.max as f32 >= warning_threshold { - TokenUsageRatio::Warning - } else { - TokenUsageRatio::Normal +impl AgentMessage { + pub fn to_markdown(&self) -> String { + let mut markdown = String::from("## Assistant\n\n"); + + for content in &self.content { + match content { + AgentMessageContent::Text(text) => { + markdown.push_str(text); + markdown.push('\n'); + } + AgentMessageContent::Thinking { text, .. } => { + markdown.push_str(""); + markdown.push_str(text); + markdown.push_str("\n"); + } + AgentMessageContent::RedactedThinking(_) => { + markdown.push_str("\n") + } + AgentMessageContent::ToolUse(tool_use) => { + markdown.push_str(&format!( + "**Tool Use**: {} (ID: {})\n", + tool_use.name, tool_use.id + )); + markdown.push_str(&format!( + "{}\n", + MarkdownCodeBlock { + tag: "json", + text: &format!("{:#}", tool_use.input) + } + )); + } + } + } + + for tool_result in self.tool_results.values() { + markdown.push_str(&format!( + "**Tool Result**: {} (ID: {})\n\n", + tool_result.tool_name, tool_result.tool_use_id + )); + if tool_result.is_error { + markdown.push_str("**ERROR:**\n"); + } + + match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + writeln!(markdown, "{text}\n").ok(); + } + LanguageModelToolResultContent::Image(_) => { + writeln!(markdown, "\n").ok(); + } + } + + if let Some(output) = tool_result.output.as_ref() { + writeln!( + markdown, + "**Debug Output**:\n\n```json\n{}\n```\n", + serde_json::to_string_pretty(output).unwrap() + ) + .unwrap(); + } } + + markdown } - pub fn add(&self, tokens: u64) -> TotalTokenUsage { - TotalTokenUsage { - total: self.total + tokens, - max: self.max, + pub fn to_request(&self) -> Vec { + let mut assistant_message = LanguageModelRequestMessage { + role: Role::Assistant, + content: Vec::with_capacity(self.content.len()), + cache: false, + }; + for chunk in &self.content { + match chunk { + AgentMessageContent::Text(text) => { + assistant_message + .content + .push(language_model::MessageContent::Text(text.clone())); + } + AgentMessageContent::Thinking { text, signature } => { + assistant_message + .content + .push(language_model::MessageContent::Thinking { + text: text.clone(), + signature: signature.clone(), + }); + } + AgentMessageContent::RedactedThinking(value) => { + assistant_message.content.push( + language_model::MessageContent::RedactedThinking(value.clone()), + ); + } + AgentMessageContent::ToolUse(tool_use) => { + if self.tool_results.contains_key(&tool_use.id) { + assistant_message + .content + .push(language_model::MessageContent::ToolUse(tool_use.clone())); + } + } + }; + } + + let mut user_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), + cache: false, + }; + + for tool_result in self.tool_results.values() { + let mut tool_result = tool_result.clone(); + // Surprisingly, the API fails if we return an empty string here. + // It thinks we are sending a tool use without a tool result. + if tool_result.content.is_empty() { + tool_result.content = "".into(); + } + user_message + .content + .push(language_model::MessageContent::ToolResult(tool_result)); + } + + let mut messages = Vec::new(); + if !assistant_message.content.is_empty() { + messages.push(assistant_message); + } + if !user_message.content.is_empty() { + messages.push(user_message); } + messages } } -#[derive(Debug, Default, PartialEq, Eq)] -pub enum TokenUsageRatio { - #[default] - Normal, - Warning, - Exceeded, +#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct AgentMessage { + pub content: Vec, + pub tool_results: IndexMap, } -#[derive(Debug, Clone, Copy)] -pub enum QueueState { - Sending, - Queued { position: usize }, - Started, +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum AgentMessageContent { + Text(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking(String), + ToolUse(LanguageModelToolUse), } -/// A thread of conversation with the LLM. -pub struct Thread { - id: ThreadId, - updated_at: DateTime, - summary: ThreadSummary, - pending_summary: Task>, - detailed_summary_task: Task>, - detailed_summary_tx: postage::watch::Sender, - detailed_summary_rx: postage::watch::Receiver, - completion_mode: agent_settings::CompletionMode, - messages: Vec, - next_message_id: MessageId, - last_prompt_id: PromptId, - project_context: SharedProjectContext, - checkpoints_by_message: HashMap, - completion_count: usize, - pending_completions: Vec, - project: Entity, - prompt_builder: Arc, - tools: Entity, - tool_use: ToolUseState, - action_log: Entity, - last_restore_checkpoint: Option, - pending_checkpoint: Option, - initial_project_snapshot: Shared>>>, - request_token_usage: Vec, - cumulative_token_usage: TokenUsage, - exceeded_window_error: Option, - tool_use_limit_reached: bool, - retry_state: Option, - message_feedback: HashMap, - last_received_chunk_at: Option, - request_callback: Option< - Box])>, - >, - remaining_turns: u32, - configured_model: Option, - profile: AgentProfile, - last_error_context: Option<(Arc, CompletionIntent)>, +pub trait TerminalHandle { + fn id(&self, cx: &AsyncApp) -> Result; + fn current_output(&self, cx: &AsyncApp) -> Result; + fn wait_for_exit(&self, cx: &AsyncApp) -> Result>>; } -#[derive(Clone, Debug)] -struct RetryState { - attempt: u8, - max_attempts: u8, - intent: CompletionIntent, +pub trait ThreadEnvironment { + fn create_terminal( + &self, + command: String, + cwd: Option, + output_byte_limit: Option, + cx: &mut AsyncApp, + ) -> Task>>; } -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum ThreadSummary { - Pending, - Generating, - Ready(SharedString), - Error, +#[derive(Debug)] +pub enum ThreadEvent { + UserMessage(UserMessage), + AgentText(String), + AgentThinking(String), + ToolCall(acp::ToolCall), + ToolCallUpdate(acp_thread::ToolCallUpdate), + ToolCallAuthorization(ToolCallAuthorization), + Retry(acp_thread::RetryStatus), + Stop(acp::StopReason), } -impl ThreadSummary { - pub const DEFAULT: SharedString = SharedString::new_static("New Thread"); - - pub fn or_default(&self) -> SharedString { - self.unwrap_or(Self::DEFAULT) - } +#[derive(Debug)] +pub struct NewTerminal { + pub command: String, + pub output_byte_limit: Option, + pub cwd: Option, + pub response: oneshot::Sender>>, +} - pub fn unwrap_or(&self, message: impl Into) -> SharedString { - self.ready().unwrap_or_else(|| message.into()) - } +#[derive(Debug)] +pub struct ToolCallAuthorization { + pub tool_call: acp::ToolCallUpdate, + pub options: Vec, + pub response: oneshot::Sender, +} - pub fn ready(&self) -> Option { - match self { - ThreadSummary::Ready(summary) => Some(summary.clone()), - ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None, - } - } +#[derive(Debug, thiserror::Error)] +enum CompletionError { + #[error("max tokens")] + MaxTokens, + #[error("refusal")] + Refusal, + #[error(transparent)] + Other(#[from] anyhow::Error), } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct ExceededWindowError { - /// Model used when last message exceeded context window - model_id: LanguageModelId, - /// Token count including last message - token_count: u64, +pub struct Thread { + id: acp::SessionId, + prompt_id: PromptId, + updated_at: DateTime, + title: Option, + pending_title_generation: Option>, + pending_summary_generation: Option>>>, + summary: Option, + messages: Vec, + user_store: Entity, + completion_mode: CompletionMode, + /// Holds the task that handles agent interaction until the end of the turn. + /// Survives across multiple requests as the model performs tool calls and + /// we run tools, report their results. + running_turn: Option, + pending_message: Option, + tools: BTreeMap>, + tool_use_limit_reached: bool, + request_token_usage: HashMap, + #[allow(unused)] + cumulative_token_usage: TokenUsage, + #[allow(unused)] + initial_project_snapshot: Shared>>>, + context_server_registry: Entity, + profile_id: AgentProfileId, + project_context: Entity, + templates: Arc, + model: Option>, + summarization_model: Option>, + prompt_capabilities_tx: watch::Sender, + pub(crate) prompt_capabilities_rx: watch::Receiver, + pub(crate) project: Entity, + pub(crate) action_log: Entity, } impl Thread { + fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities { + let image = model.map_or(true, |model| model.supports_images()); + acp::PromptCapabilities { + meta: None, + image, + audio: false, + embedded_context: true, + } + } + pub fn new( project: Entity, - tools: Entity, - prompt_builder: Arc, - system_prompt: SharedProjectContext, + project_context: Entity, + context_server_registry: Entity, + templates: Arc, + model: Option>, cx: &mut Context, ) -> Self { - let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel(); - let configured_model = LanguageModelRegistry::read_global(cx).default_model(); let profile_id = AgentSettings::get_global(cx).default_profile.clone(); - + let action_log = cx.new(|_cx| ActionLog::new(project.clone())); + let (prompt_capabilities_tx, prompt_capabilities_rx) = + watch::channel(Self::prompt_capabilities(model.as_deref())); Self { - id: ThreadId::new(), + id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()), + prompt_id: PromptId::new(), updated_at: Utc::now(), - summary: ThreadSummary::Pending, - pending_summary: Task::ready(None), - detailed_summary_task: Task::ready(None), - detailed_summary_tx, - detailed_summary_rx, - completion_mode: AgentSettings::get_global(cx).preferred_completion_mode, + title: None, + pending_title_generation: None, + pending_summary_generation: None, + summary: None, messages: Vec::new(), - next_message_id: MessageId(0), - last_prompt_id: PromptId::new(), - project_context: system_prompt, - checkpoints_by_message: HashMap::default(), - completion_count: 0, - pending_completions: Vec::new(), - project: project.clone(), - prompt_builder, - tools: tools.clone(), - last_restore_checkpoint: None, - pending_checkpoint: None, - tool_use: ToolUseState::new(tools.clone()), - action_log: cx.new(|_| ActionLog::new(project.clone())), + user_store: project.read(cx).user_store(), + completion_mode: AgentSettings::get_global(cx).preferred_completion_mode, + running_turn: None, + pending_message: None, + tools: BTreeMap::default(), + tool_use_limit_reached: false, + request_token_usage: HashMap::default(), + cumulative_token_usage: TokenUsage::default(), initial_project_snapshot: { - let project_snapshot = Self::project_snapshot(project, cx); + let project_snapshot = Self::project_snapshot(project.clone(), cx); cx.foreground_executor() .spawn(async move { Some(project_snapshot.await) }) .shared() }, - request_token_usage: Vec::new(), - cumulative_token_usage: TokenUsage::default(), - exceeded_window_error: None, - tool_use_limit_reached: false, - retry_state: None, - message_feedback: HashMap::default(), - last_error_context: None, - last_received_chunk_at: None, - request_callback: None, - remaining_turns: u32::MAX, - configured_model, - profile: AgentProfile::new(profile_id, tools), + context_server_registry, + profile_id, + project_context, + templates, + model, + summarization_model: None, + prompt_capabilities_tx, + prompt_capabilities_rx, + project, + action_log, } } - pub fn deserialize( - id: ThreadId, - serialized: SerializedThread, - project: Entity, - tools: Entity, - prompt_builder: Arc, - project_context: SharedProjectContext, - window: Option<&mut Window>, // None in headless mode - cx: &mut Context, - ) -> Self { - let next_message_id = MessageId( - serialized - .messages - .last() - .map(|message| message.id.0 + 1) - .unwrap_or(0), - ); - let tool_use = ToolUseState::from_serialized_messages( - tools.clone(), - &serialized.messages, - project.clone(), - window, - cx, - ); - let (detailed_summary_tx, detailed_summary_rx) = - postage::watch::channel_with(serialized.detailed_summary_state); - - let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - serialized - .model - .and_then(|model| { - let model = SelectedModel { - provider: model.provider.clone().into(), - model: model.model.into(), - }; - registry.select_model(&model, cx) - }) - .or_else(|| registry.default_model()) - }); - - let completion_mode = serialized - .completion_mode - .unwrap_or_else(|| AgentSettings::get_global(cx).preferred_completion_mode); - let profile_id = serialized - .profile - .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); - - Self { - id, - updated_at: serialized.updated_at, - summary: ThreadSummary::Ready(serialized.summary), - pending_summary: Task::ready(None), - detailed_summary_task: Task::ready(None), - detailed_summary_tx, - detailed_summary_rx, - completion_mode, - retry_state: None, - messages: serialized - .messages - .into_iter() - .map(|message| Message { - id: message.id, - role: message.role, - segments: message - .segments - .into_iter() - .map(|segment| match segment { - SerializedMessageSegment::Text { text } => MessageSegment::Text(text), - SerializedMessageSegment::Thinking { text, signature } => { - MessageSegment::Thinking { text, signature } - } - SerializedMessageSegment::RedactedThinking { data } => { - MessageSegment::RedactedThinking(data) - } - }) - .collect(), - loaded_context: LoadedContext { - contexts: Vec::new(), - text: message.context, - images: Vec::new(), - }, - creases: message - .creases - .into_iter() - .map(|crease| MessageCrease { - range: crease.start..crease.end, - icon_path: crease.icon_path, - label: crease.label, - context: None, - }) - .collect(), - is_hidden: message.is_hidden, - ui_only: false, // UI-only messages are not persisted - }) - .collect(), - next_message_id, - last_prompt_id: PromptId::new(), - project_context, - checkpoints_by_message: HashMap::default(), - completion_count: 0, - pending_completions: Vec::new(), - last_restore_checkpoint: None, - pending_checkpoint: None, - project: project.clone(), - prompt_builder, - tools: tools.clone(), - tool_use, - action_log: cx.new(|_| ActionLog::new(project)), - initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(), - request_token_usage: serialized.request_token_usage, - cumulative_token_usage: serialized.cumulative_token_usage, - exceeded_window_error: None, - tool_use_limit_reached: serialized.tool_use_limit_reached, - message_feedback: HashMap::default(), - last_error_context: None, - last_received_chunk_at: None, - request_callback: None, - remaining_turns: u32::MAX, - configured_model, - profile: AgentProfile::new(profile_id, tools), - } + pub fn id(&self) -> &acp::SessionId { + &self.id } - pub fn set_request_callback( + pub fn replay( &mut self, - callback: impl 'static - + FnMut(&LanguageModelRequest, &[Result]), - ) { - self.request_callback = Some(Box::new(callback)); - } - - pub fn id(&self) -> &ThreadId { - &self.id - } - - pub fn profile(&self) -> &AgentProfile { - &self.profile - } - - pub fn set_profile(&mut self, id: AgentProfileId, cx: &mut Context) { - if &id != self.profile.id() { - self.profile = AgentProfile::new(id, self.tools.clone()); - cx.emit(ThreadEvent::ProfileChanged); + cx: &mut Context, + ) -> mpsc::UnboundedReceiver> { + let (tx, rx) = mpsc::unbounded(); + let stream = ThreadEventStream(tx); + for message in &self.messages { + match message { + Message::User(user_message) => stream.send_user_message(user_message), + Message::Agent(assistant_message) => { + for content in &assistant_message.content { + match content { + AgentMessageContent::Text(text) => stream.send_text(text), + AgentMessageContent::Thinking { text, .. } => { + stream.send_thinking(text) + } + AgentMessageContent::RedactedThinking(_) => {} + AgentMessageContent::ToolUse(tool_use) => { + self.replay_tool_call( + tool_use, + assistant_message.tool_results.get(&tool_use.id), + &stream, + cx, + ); + } + } + } + } + Message::Resume => {} + } } + rx } - pub fn is_empty(&self) -> bool { - self.messages.is_empty() - } + fn replay_tool_call( + &self, + tool_use: &LanguageModelToolUse, + tool_result: Option<&LanguageModelToolResult>, + stream: &ThreadEventStream, + cx: &mut Context, + ) { + let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| { + self.context_server_registry + .read(cx) + .servers() + .find_map(|(_, tools)| { + if let Some(tool) = tools.get(tool_use.name.as_ref()) { + Some(tool.clone()) + } else { + None + } + }) + }); - pub fn updated_at(&self) -> DateTime { - self.updated_at - } + let Some(tool) = tool else { + stream + .0 + .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall { + meta: None, + id: acp::ToolCallId(tool_use.id.to_string().into()), + title: tool_use.name.to_string(), + kind: acp::ToolKind::Other, + status: acp::ToolCallStatus::Failed, + content: Vec::new(), + locations: Vec::new(), + raw_input: Some(tool_use.input.clone()), + raw_output: None, + }))) + .ok(); + return; + }; - pub fn touch_updated_at(&mut self) { - self.updated_at = Utc::now(); + let title = tool.initial_title(tool_use.input.clone(), cx); + let kind = tool.kind(); + stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); + + let output = tool_result + .as_ref() + .and_then(|result| result.output.clone()); + if let Some(output) = output.clone() { + let tool_event_stream = ToolCallEventStream::new( + tool_use.id.clone(), + stream.clone(), + Some(self.project.read(cx).fs().clone()), + ); + tool.replay(tool_use.input.clone(), output, tool_event_stream, cx) + .log_err(); + } + + stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields { + status: Some( + tool_result + .as_ref() + .map_or(acp::ToolCallStatus::Failed, |result| { + if result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + } + }), + ), + raw_output: output, + ..Default::default() + }, + ); } - pub fn advance_prompt_id(&mut self) { - self.last_prompt_id = PromptId::new(); - } + pub fn from_db( + id: acp::SessionId, + db_thread: DbThread, + project: Entity, + project_context: Entity, + context_server_registry: Entity, + templates: Arc, + cx: &mut Context, + ) -> Self { + let profile_id = db_thread + .profile + .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); + let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + db_thread + .model + .and_then(|model| { + let model = SelectedModel { + provider: model.provider.clone().into(), + model: model.model.into(), + }; + registry.select_model(&model, cx) + }) + .or_else(|| registry.default_model()) + .map(|model| model.model) + }); + let (prompt_capabilities_tx, prompt_capabilities_rx) = + watch::channel(Self::prompt_capabilities(model.as_deref())); - pub fn project_context(&self) -> SharedProjectContext { - self.project_context.clone() - } + let action_log = cx.new(|_| ActionLog::new(project.clone())); - pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option { - if self.configured_model.is_none() { - self.configured_model = LanguageModelRegistry::read_global(cx).default_model(); + Self { + id, + prompt_id: PromptId::new(), + title: if db_thread.title.is_empty() { + None + } else { + Some(db_thread.title.clone()) + }, + pending_title_generation: None, + pending_summary_generation: None, + summary: db_thread.detailed_summary, + messages: db_thread.messages, + user_store: project.read(cx).user_store(), + completion_mode: db_thread.completion_mode.unwrap_or_default(), + running_turn: None, + pending_message: None, + tools: BTreeMap::default(), + tool_use_limit_reached: false, + request_token_usage: db_thread.request_token_usage.clone(), + cumulative_token_usage: db_thread.cumulative_token_usage, + initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(), + context_server_registry, + profile_id, + project_context, + templates, + model, + summarization_model: None, + project, + action_log, + updated_at: db_thread.updated_at, + prompt_capabilities_tx, + prompt_capabilities_rx, } - self.configured_model.clone() - } - - pub fn configured_model(&self) -> Option { - self.configured_model.clone() - } - - pub fn set_configured_model(&mut self, model: Option, cx: &mut Context) { - self.configured_model = model; - cx.notify(); } - pub fn summary(&self) -> &ThreadSummary { - &self.summary - } - - pub fn set_summary(&mut self, new_summary: impl Into, cx: &mut Context) { - let current_summary = match &self.summary { - ThreadSummary::Pending | ThreadSummary::Generating => return, - ThreadSummary::Ready(summary) => summary, - ThreadSummary::Error => &ThreadSummary::DEFAULT, + pub fn to_db(&self, cx: &App) -> Task { + let initial_project_snapshot = self.initial_project_snapshot.clone(); + let mut thread = DbThread { + title: self.title(), + messages: self.messages.clone(), + updated_at: self.updated_at, + detailed_summary: self.summary.clone(), + initial_project_snapshot: None, + cumulative_token_usage: self.cumulative_token_usage, + request_token_usage: self.request_token_usage.clone(), + model: self.model.as_ref().map(|model| DbLanguageModel { + provider: model.provider_id().to_string(), + model: model.name().0.to_string(), + }), + completion_mode: Some(self.completion_mode), + profile: Some(self.profile_id.clone()), }; - let mut new_summary = new_summary.into(); + cx.background_spawn(async move { + let initial_project_snapshot = initial_project_snapshot.await; + thread.initial_project_snapshot = initial_project_snapshot; + thread + }) + } - if new_summary.is_empty() { - new_summary = ThreadSummary::DEFAULT; - } + /// Create a snapshot of the current project state including git information and unsaved buffers. + fn project_snapshot( + project: Entity, + cx: &mut Context, + ) -> Task> { + let git_store = project.read(cx).git_store().clone(); + let worktree_snapshots: Vec<_> = project + .read(cx) + .visible_worktrees(cx) + .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) + .collect(); - if current_summary != &new_summary { - self.summary = ThreadSummary::Ready(new_summary); - cx.emit(ThreadEvent::SummaryChanged); - } - } + cx.spawn(async move |_, _| { + let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; - pub fn completion_mode(&self) -> CompletionMode { - self.completion_mode + Arc::new(ProjectSnapshot { + worktree_snapshots, + timestamp: Utc::now(), + }) + }) } - pub fn set_completion_mode(&mut self, mode: CompletionMode) { - self.completion_mode = mode; - } + fn worktree_snapshot( + worktree: Entity, + git_store: Entity, + cx: &App, + ) -> Task { + cx.spawn(async move |cx| { + // Get worktree path and snapshot + let worktree_info = cx.update(|app_cx| { + let worktree = worktree.read(app_cx); + let path = worktree.abs_path().to_string_lossy().into_owned(); + let snapshot = worktree.snapshot(); + (path, snapshot) + }); - pub fn message(&self, id: MessageId) -> Option<&Message> { - let index = self - .messages - .binary_search_by(|message| message.id.cmp(&id)) - .ok()?; + let Ok((worktree_path, _snapshot)) = worktree_info else { + return WorktreeSnapshot { + worktree_path: String::new(), + git_state: None, + }; + }; - self.messages.get(index) - } + let git_state = git_store + .update(cx, |git_store, cx| { + git_store + .repositories() + .values() + .find(|repo| { + repo.read(cx) + .abs_path_to_repo_path(&worktree.read(cx).abs_path()) + .is_some() + }) + .cloned() + }) + .ok() + .flatten() + .map(|repo| { + repo.update(cx, |repo, _| { + let current_branch = + repo.branch.as_ref().map(|branch| branch.name().to_owned()); + repo.send_job(None, |state, _| async move { + let RepositoryState::Local { backend, .. } = state else { + return GitState { + remote_url: None, + head_sha: None, + current_branch, + diff: None, + }; + }; - pub fn messages(&self) -> impl ExactSizeIterator { - self.messages.iter() - } + let remote_url = backend.remote_url("origin"); + let head_sha = backend.head_sha().await; + let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); - pub fn is_generating(&self) -> bool { - !self.pending_completions.is_empty() || !self.all_tools_finished() - } + GitState { + remote_url, + head_sha, + current_branch, + diff, + } + }) + }) + }); - /// Indicates whether streaming of language model events is stale. - /// When `is_generating()` is false, this method returns `None`. - pub fn is_generation_stale(&self) -> Option { - const STALE_THRESHOLD: u128 = 250; + let git_state = match git_state { + Some(git_state) => match git_state.ok() { + Some(git_state) => git_state.await.ok(), + None => None, + }, + None => None, + }; - self.last_received_chunk_at - .map(|instant| instant.elapsed().as_millis() > STALE_THRESHOLD) + WorktreeSnapshot { + worktree_path, + git_state, + } + }) } - fn received_chunk(&mut self) { - self.last_received_chunk_at = Some(Instant::now()); + pub fn project_context(&self) -> &Entity { + &self.project_context } - pub fn queue_state(&self) -> Option { - self.pending_completions - .first() - .map(|pending_completion| pending_completion.queue_state) + pub fn project(&self) -> &Entity { + &self.project } - pub fn tools(&self) -> &Entity { - &self.tools + pub fn action_log(&self) -> &Entity { + &self.action_log } - pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> { - self.tool_use - .pending_tool_uses() - .into_iter() - .find(|tool_use| &tool_use.id == id) + pub fn is_empty(&self) -> bool { + self.messages.is_empty() && self.title.is_none() } - pub fn tools_needing_confirmation(&self) -> impl Iterator { - self.tool_use - .pending_tool_uses() - .into_iter() - .filter(|tool_use| tool_use.status.needs_confirmation()) + pub fn model(&self) -> Option<&Arc> { + self.model.as_ref() } - pub fn has_pending_tool_uses(&self) -> bool { - !self.tool_use.pending_tool_uses().is_empty() + pub fn set_model(&mut self, model: Arc, cx: &mut Context) { + let old_usage = self.latest_token_usage(); + self.model = Some(model); + let new_caps = Self::prompt_capabilities(self.model.as_deref()); + let new_usage = self.latest_token_usage(); + if old_usage != new_usage { + cx.emit(TokenUsageUpdated(new_usage)); + } + self.prompt_capabilities_tx.send(new_caps).log_err(); + cx.notify() } - pub fn checkpoint_for_message(&self, id: MessageId) -> Option { - self.checkpoints_by_message.get(&id).cloned() + pub fn summarization_model(&self) -> Option<&Arc> { + self.summarization_model.as_ref() } - pub fn restore_checkpoint( + pub fn set_summarization_model( &mut self, - checkpoint: ThreadCheckpoint, + model: Option>, cx: &mut Context, - ) -> Task> { - self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending { - message_id: checkpoint.message_id, - }); - cx.emit(ThreadEvent::CheckpointChanged); - cx.notify(); + ) { + self.summarization_model = model; + cx.notify() + } - let git_store = self.project().read(cx).git_store().clone(); - let restore = git_store.update(cx, |git_store, cx| { - git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx) - }); + pub fn completion_mode(&self) -> CompletionMode { + self.completion_mode + } - cx.spawn(async move |this, cx| { - let result = restore.await; - this.update(cx, |this, cx| { - if let Err(err) = result.as_ref() { - this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error { - message_id: checkpoint.message_id, - error: err.to_string(), - }); - } else { - this.truncate(checkpoint.message_id, cx); - this.last_restore_checkpoint = None; - } - this.pending_checkpoint = None; - cx.emit(ThreadEvent::CheckpointChanged); - cx.notify(); - })?; - result - }) + pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context) { + let old_usage = self.latest_token_usage(); + self.completion_mode = mode; + let new_usage = self.latest_token_usage(); + if old_usage != new_usage { + cx.emit(TokenUsageUpdated(new_usage)); + } + cx.notify() } - fn finalize_pending_checkpoint(&mut self, cx: &mut Context) { - let pending_checkpoint = if self.is_generating() { - return; - } else if let Some(checkpoint) = self.pending_checkpoint.take() { - checkpoint + #[cfg(any(test, feature = "test-support"))] + pub fn last_message(&self) -> Option { + if let Some(message) = self.pending_message.clone() { + Some(Message::Agent(message)) } else { - return; - }; - - self.finalize_checkpoint(pending_checkpoint, cx); + self.messages.last().cloned() + } } - fn finalize_checkpoint( + pub fn add_default_tools( &mut self, - pending_checkpoint: ThreadCheckpoint, + environment: Rc, cx: &mut Context, ) { - let git_store = self.project.read(cx).git_store().clone(); - let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx)); - cx.spawn(async move |this, cx| match final_checkpoint.await { - Ok(final_checkpoint) => { - let equal = git_store - .update(cx, |store, cx| { - store.compare_checkpoints( - pending_checkpoint.git_checkpoint.clone(), - final_checkpoint.clone(), - cx, - ) - })? - .await - .unwrap_or(false); - - this.update(cx, |this, cx| { - this.pending_checkpoint = if equal { - Some(pending_checkpoint) - } else { - this.insert_checkpoint(pending_checkpoint, cx); - Some(ThreadCheckpoint { - message_id: this.next_message_id, - git_checkpoint: final_checkpoint, - }) - } - })?; - - Ok(()) - } - Err(_) => this.update(cx, |this, cx| { - this.insert_checkpoint(pending_checkpoint, cx) - }), - }) - .detach(); + let language_registry = self.project.read(cx).languages().clone(); + self.add_tool(CopyPathTool::new(self.project.clone())); + self.add_tool(CreateDirectoryTool::new(self.project.clone())); + self.add_tool(DeletePathTool::new( + self.project.clone(), + self.action_log.clone(), + )); + self.add_tool(DiagnosticsTool::new(self.project.clone())); + self.add_tool(EditFileTool::new( + self.project.clone(), + cx.weak_entity(), + language_registry, + Templates::new(), + )); + self.add_tool(FetchTool::new(self.project.read(cx).client().http_client())); + self.add_tool(FindPathTool::new(self.project.clone())); + self.add_tool(GrepTool::new(self.project.clone())); + self.add_tool(ListDirectoryTool::new(self.project.clone())); + self.add_tool(MovePathTool::new(self.project.clone())); + self.add_tool(NowTool); + self.add_tool(OpenTool::new(self.project.clone())); + self.add_tool(ReadFileTool::new( + self.project.clone(), + self.action_log.clone(), + )); + self.add_tool(TerminalTool::new(self.project.clone(), environment)); + self.add_tool(ThinkingTool); + self.add_tool(WebSearchTool); } - fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context) { - self.checkpoints_by_message - .insert(checkpoint.message_id, checkpoint); - cx.emit(ThreadEvent::CheckpointChanged); - cx.notify(); + pub fn add_tool(&mut self, tool: T) { + self.tools.insert(T::name().into(), tool.erase()); } - pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> { - self.last_restore_checkpoint.as_ref() + pub fn remove_tool(&mut self, name: &str) -> bool { + self.tools.remove(name).is_some() } - pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context) { - let Some(message_ix) = self - .messages - .iter() - .rposition(|message| message.id == message_id) - else { - return; - }; - for deleted_message in self.messages.drain(message_ix..) { - self.checkpoints_by_message.remove(&deleted_message.id); - } - cx.notify(); + pub fn profile(&self) -> &AgentProfileId { + &self.profile_id } - pub fn context_for_message(&self, id: MessageId) -> impl Iterator { - self.messages - .iter() - .find(|message| message.id == id) - .into_iter() - .flat_map(|message| message.loaded_context.contexts.iter()) + pub fn set_profile(&mut self, profile_id: AgentProfileId) { + self.profile_id = profile_id; } - pub fn is_turn_end(&self, ix: usize) -> bool { - if self.messages.is_empty() { - return false; + pub fn cancel(&mut self, cx: &mut Context) { + if let Some(running_turn) = self.running_turn.take() { + running_turn.cancel(); } + self.flush_pending_message(cx); + } - if !self.is_generating() && ix == self.messages.len() - 1 { - return true; - } + fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context) { + let Some(last_user_message) = self.last_user_message() else { + return; + }; - let Some(message) = self.messages.get(ix) else { - return false; + self.request_token_usage + .insert(last_user_message.id.clone(), update); + cx.emit(TokenUsageUpdated(self.latest_token_usage())); + cx.notify(); + } + + pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { + self.cancel(cx); + let Some(position) = self.messages.iter().position( + |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), + ) else { + return Err(anyhow!("Message not found")); }; - if message.role != Role::Assistant { - return false; + for message in self.messages.drain(position..) { + match message { + Message::User(message) => { + self.request_token_usage.remove(&message.id); + } + Message::Agent(_) | Message::Resume => {} + } } - - self.messages - .get(ix + 1) - .and_then(|message| { - self.message(message.id) - .map(|next_message| next_message.role == Role::User && !next_message.is_hidden) - }) - .unwrap_or(false) + self.clear_summary(); + cx.notify(); + Ok(()) } - pub fn tool_use_limit_reached(&self) -> bool { - self.tool_use_limit_reached - } + pub fn latest_token_usage(&self) -> Option { + let last_user_message = self.last_user_message()?; + let tokens = self.request_token_usage.get(&last_user_message.id)?; + let model = self.model.clone()?; - /// Returns whether all of the tool uses have finished running. - pub fn all_tools_finished(&self) -> bool { - // If the only pending tool uses left are the ones with errors, then - // that means that we've finished running all of the pending tools. - self.tool_use - .pending_tool_uses() - .iter() - .all(|pending_tool_use| pending_tool_use.status.is_error()) + Some(acp_thread::TokenUsage { + max_tokens: model.max_token_count_for_mode(self.completion_mode.into()), + used_tokens: tokens.total_tokens(), + }) } - /// Returns whether any pending tool uses may perform edits - pub fn has_pending_edit_tool_uses(&self) -> bool { - self.tool_use - .pending_tool_uses() - .iter() - .filter(|pending_tool_use| !pending_tool_use.status.is_error()) - .any(|pending_tool_use| pending_tool_use.may_perform_edits) - } + pub fn resume( + &mut self, + cx: &mut Context, + ) -> Result>> { + self.messages.push(Message::Resume); + cx.notify(); - pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec { - self.tool_use.tool_uses_for_message(id, &self.project, cx) + log::debug!("Total messages in thread: {}", self.messages.len()); + self.run_turn(cx) } - pub fn tool_results_for_message( - &self, - assistant_message_id: MessageId, - ) -> Vec<&LanguageModelToolResult> { - self.tool_use.tool_results_for_message(assistant_message_id) - } + /// Sending a message results in the model streaming a response, which could include tool calls. + /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. + /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. + pub fn send( + &mut self, + id: UserMessageId, + content: impl IntoIterator, + cx: &mut Context, + ) -> Result>> + where + T: Into, + { + let model = self.model().context("No language model configured")?; - pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> { - self.tool_use.tool_result(id) - } + log::info!("Thread::send called with model: {}", model.name().0); + self.advance_prompt_id(); - pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc> { - match &self.tool_use.tool_result(id)?.content { - LanguageModelToolResultContent::Text(text) => Some(text), - LanguageModelToolResultContent::Image(_) => { - // TODO: We should display image - None - } - } - } + let content = content.into_iter().map(Into::into).collect::>(); + log::debug!("Thread::send content: {:?}", content); - pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option { - self.tool_use.tool_result_card(id).cloned() - } + self.messages + .push(Message::User(UserMessage { id, content })); + cx.notify(); - /// Return tools that are both enabled and supported by the model - pub fn available_tools( - &self, - cx: &App, - model: Arc, - ) -> Vec { - if model.supports_tools() { - self.profile - .enabled_tools(cx) - .into_iter() - .filter_map(|(name, tool)| { - // Skip tools that cannot be supported - let input_schema = tool.input_schema(model.tool_input_format()).ok()?; - Some(LanguageModelRequestTool { - name: name.into(), - description: tool.description(), - input_schema, - }) - }) - .collect() - } else { - Vec::default() - } + log::debug!("Total messages in thread: {}", self.messages.len()); + self.run_turn(cx) } - pub fn insert_user_message( + fn run_turn( &mut self, - text: impl Into, - loaded_context: ContextLoadResult, - git_checkpoint: Option, - creases: Vec, cx: &mut Context, - ) -> MessageId { - if !loaded_context.referenced_buffers.is_empty() { - self.action_log.update(cx, |log, cx| { - for buffer in loaded_context.referenced_buffers { - log.buffer_read(buffer, cx); + ) -> Result>> { + self.cancel(cx); + + let model = self.model.clone().context("No language model configured")?; + let profile = AgentSettings::get_global(cx) + .profiles + .get(&self.profile_id) + .context("Profile not found")?; + let (events_tx, events_rx) = mpsc::unbounded::>(); + let event_stream = ThreadEventStream(events_tx); + let message_ix = self.messages.len().saturating_sub(1); + self.tool_use_limit_reached = false; + self.clear_summary(); + self.running_turn = Some(RunningTurn { + event_stream: event_stream.clone(), + tools: self.enabled_tools(profile, &model, cx), + _task: cx.spawn(async move |this, cx| { + log::debug!("Starting agent turn execution"); + + let turn_result = Self::run_turn_internal(&this, model, &event_stream, cx).await; + _ = this.update(cx, |this, cx| this.flush_pending_message(cx)); + + match turn_result { + Ok(()) => { + log::debug!("Turn execution completed"); + event_stream.send_stop(acp::StopReason::EndTurn); + } + Err(error) => { + log::error!("Turn execution failed: {:?}", error); + match error.downcast::() { + Ok(CompletionError::Refusal) => { + event_stream.send_stop(acp::StopReason::Refusal); + _ = this.update(cx, |this, _| this.messages.truncate(message_ix)); + } + Ok(CompletionError::MaxTokens) => { + event_stream.send_stop(acp::StopReason::MaxTokens); + } + Ok(CompletionError::Other(error)) | Err(error) => { + event_stream.send_error(error); + } + } + } } - }); - } - let message_id = self.insert_message( - Role::User, - vec![MessageSegment::Text(text.into())], - loaded_context.loaded_context, - creases, - false, - cx, - ); + _ = this.update(cx, |this, _| this.running_turn.take()); + }), + }); + Ok(events_rx) + } - if let Some(git_checkpoint) = git_checkpoint { - self.pending_checkpoint = Some(ThreadCheckpoint { - message_id, - git_checkpoint, - }); - } + async fn run_turn_internal( + this: &WeakEntity, + model: Arc, + event_stream: &ThreadEventStream, + cx: &mut AsyncApp, + ) -> Result<()> { + let mut attempt = 0; + let mut intent = CompletionIntent::UserPrompt; + loop { + let request = + this.update(cx, |this, cx| this.build_completion_request(intent, cx))??; - message_id - } + telemetry::event!( + "Agent Thread Completion", + thread_id = this.read_with(cx, |this, _| this.id.to_string())?, + prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?, + model = model.telemetry_id(), + model_provider = model.provider_id().to_string(), + attempt + ); - pub fn insert_invisible_continue_message(&mut self, cx: &mut Context) -> MessageId { - let id = self.insert_message( - Role::User, - vec![MessageSegment::Text("Continue where you left off".into())], - LoadedContext::default(), - vec![], - true, - cx, - ); - self.pending_checkpoint = None; + log::debug!("Calling model.stream_completion, attempt {}", attempt); - id - } + let (mut events, mut error) = match model.stream_completion(request, cx).await { + Ok(events) => (events, None), + Err(err) => (stream::empty().boxed(), Some(err)), + }; + let mut tool_results = FuturesUnordered::new(); + while let Some(event) = events.next().await { + log::trace!("Received completion event: {:?}", event); + match event { + Ok(event) => { + tool_results.extend(this.update(cx, |this, cx| { + this.handle_completion_event(event, event_stream, cx) + })??); + } + Err(err) => { + error = Some(err); + break; + } + } + } - pub fn insert_assistant_message( - &mut self, - segments: Vec, - cx: &mut Context, - ) -> MessageId { - self.insert_message( - Role::Assistant, - segments, - LoadedContext::default(), - Vec::new(), - false, - cx, - ) - } + let end_turn = tool_results.is_empty(); + while let Some(tool_result) = tool_results.next().await { + log::debug!("Tool finished {:?}", tool_result); - pub fn insert_message( - &mut self, - role: Role, - segments: Vec, - loaded_context: LoadedContext, - creases: Vec, - is_hidden: bool, - cx: &mut Context, - ) -> MessageId { - let id = self.next_message_id.post_inc(); - self.messages.push(Message { - id, - role, - segments, - loaded_context, - creases, - is_hidden, - ui_only: false, - }); - self.touch_updated_at(); - cx.emit(ThreadEvent::MessageAdded(id)); - id + event_stream.update_tool_call_fields( + &tool_result.tool_use_id, + acp::ToolCallUpdateFields { + status: Some(if tool_result.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }), + raw_output: tool_result.output.clone(), + ..Default::default() + }, + ); + this.update(cx, |this, _cx| { + this.pending_message() + .tool_results + .insert(tool_result.tool_use_id.clone(), tool_result); + })?; + } + + this.update(cx, |this, cx| { + this.flush_pending_message(cx); + if this.title.is_none() && this.pending_title_generation.is_none() { + this.generate_title(cx); + } + })?; + + if let Some(error) = error { + attempt += 1; + let retry = this.update(cx, |this, cx| { + let user_store = this.user_store.read(cx); + this.handle_completion_error(error, attempt, user_store.plan()) + })??; + let timer = cx.background_executor().timer(retry.duration); + event_stream.send_retry(retry); + timer.await; + this.update(cx, |this, _cx| { + if let Some(Message::Agent(message)) = this.messages.last() { + if message.tool_results.is_empty() { + intent = CompletionIntent::UserPrompt; + this.messages.push(Message::Resume); + } + } + })?; + } else if this.read_with(cx, |this, _| this.tool_use_limit_reached)? { + return Err(language_model::ToolUseLimitReachedError.into()); + } else if end_turn { + return Ok(()); + } else { + intent = CompletionIntent::ToolResults; + attempt = 0; + } + } } - pub fn edit_message( + fn handle_completion_error( &mut self, - id: MessageId, - new_role: Role, - new_segments: Vec, - creases: Vec, - loaded_context: Option, - checkpoint: Option, - cx: &mut Context, - ) -> bool { - let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else { - return false; + error: LanguageModelCompletionError, + attempt: u8, + plan: Option, + ) -> Result { + let Some(model) = self.model.as_ref() else { + return Err(anyhow!(error)); + }; + + let auto_retry = if model.provider_id() == ZED_CLOUD_PROVIDER_ID { + match plan { + Some(Plan::V2(_)) => true, + Some(Plan::V1(_)) => self.completion_mode == CompletionMode::Burn, + None => false, + } + } else { + true }; - message.role = new_role; - message.segments = new_segments; - message.creases = creases; - if let Some(context) = loaded_context { - message.loaded_context = context; + + if !auto_retry { + return Err(anyhow!(error)); } - if let Some(git_checkpoint) = checkpoint { - self.checkpoints_by_message.insert( - id, - ThreadCheckpoint { - message_id: id, - git_checkpoint, - }, - ); + + let Some(strategy) = Self::retry_strategy_for(&error) else { + return Err(anyhow!(error)); + }; + + let max_attempts = match &strategy { + RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, + RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, + }; + + if attempt > max_attempts { + return Err(anyhow!(error)); } - self.touch_updated_at(); - cx.emit(ThreadEvent::MessageEdited(id)); - true - } - pub fn delete_message(&mut self, id: MessageId, cx: &mut Context) -> bool { - let Some(index) = self.messages.iter().position(|message| message.id == id) else { - return false; + let delay = match &strategy { + RetryStrategy::ExponentialBackoff { initial_delay, .. } => { + let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); + Duration::from_secs(delay_secs) + } + RetryStrategy::Fixed { delay, .. } => *delay, }; - self.messages.remove(index); - self.touch_updated_at(); - cx.emit(ThreadEvent::MessageDeleted(id)); - true + log::debug!("Retry attempt {attempt} with delay {delay:?}"); + + Ok(acp_thread::RetryStatus { + last_error: error.to_string().into(), + attempt: attempt as usize, + max_attempts: max_attempts as usize, + started_at: Instant::now(), + duration: delay, + }) } - /// Returns the representation of this [`Thread`] in a textual form. - /// - /// This is the representation we use when attaching a thread as context to another thread. - pub fn text(&self) -> String { - let mut text = String::new(); - - for message in &self.messages { - text.push_str(match message.role { - language_model::Role::User => "User:", - language_model::Role::Assistant => "Agent:", - language_model::Role::System => "System:", - }); - text.push('\n'); + /// A helper method that's called on every streamed completion event. + /// Returns an optional tool result task, which the main agentic loop will + /// send back to the model when it resolves. + fn handle_completion_event( + &mut self, + event: LanguageModelCompletionEvent, + event_stream: &ThreadEventStream, + cx: &mut Context, + ) -> Result>> { + log::trace!("Handling streamed completion event: {:?}", event); + use LanguageModelCompletionEvent::*; - for segment in &message.segments { - match segment { - MessageSegment::Text(content) => text.push_str(content), - MessageSegment::Thinking { text: content, .. } => { - text.push_str(&format!("{}", content)) - } - MessageSegment::RedactedThinking(_) => {} - } + match event { + StartMessage { .. } => { + self.flush_pending_message(cx); + self.pending_message = Some(AgentMessage::default()); + } + Text(new_text) => self.handle_text_event(new_text, event_stream, cx), + Thinking { text, signature } => { + self.handle_thinking_event(text, signature, event_stream, cx) + } + RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx), + ToolUse(tool_use) => { + return Ok(self.handle_tool_use_event(tool_use, event_stream, cx)); + } + ToolUseJsonParseError { + id, + tool_name, + raw_input, + json_parse_error, + } => { + return Ok(Some(Task::ready( + self.handle_tool_use_json_parse_error_event( + id, + tool_name, + raw_input, + json_parse_error, + ), + ))); + } + UsageUpdate(usage) => { + telemetry::event!( + "Agent Thread Completion Usage Updated", + thread_id = self.id.to_string(), + prompt_id = self.prompt_id.to_string(), + model = self.model.as_ref().map(|m| m.telemetry_id()), + model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()), + input_tokens = usage.input_tokens, + output_tokens = usage.output_tokens, + cache_creation_input_tokens = usage.cache_creation_input_tokens, + cache_read_input_tokens = usage.cache_read_input_tokens, + ); + self.update_token_usage(usage, cx); + } + StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => { + self.update_model_request_usage(amount, limit, cx); } - text.push('\n'); + StatusUpdate( + CompletionRequestStatus::Started + | CompletionRequestStatus::Queued { .. } + | CompletionRequestStatus::Failed { .. }, + ) => {} + StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => { + self.tool_use_limit_reached = true; + } + Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()), + Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()), + Stop(StopReason::ToolUse | StopReason::EndTurn) => {} } - text + Ok(None) } - /// Serializes this thread into a format for storage or telemetry. - pub fn serialize(&self, cx: &mut Context) -> Task> { - let initial_project_snapshot = self.initial_project_snapshot.clone(); - cx.spawn(async move |this, cx| { - let initial_project_snapshot = initial_project_snapshot.await; - this.read_with(cx, |this, cx| SerializedThread { - version: SerializedThread::VERSION.to_string(), - summary: this.summary().or_default(), - updated_at: this.updated_at(), - messages: this - .messages() - .filter(|message| !message.ui_only) - .map(|message| SerializedMessage { - id: message.id, - role: message.role, - segments: message - .segments - .iter() - .map(|segment| match segment { - MessageSegment::Text(text) => { - SerializedMessageSegment::Text { text: text.clone() } - } - MessageSegment::Thinking { text, signature } => { - SerializedMessageSegment::Thinking { - text: text.clone(), - signature: signature.clone(), - } - } - MessageSegment::RedactedThinking(data) => { - SerializedMessageSegment::RedactedThinking { - data: data.clone(), - } - } - }) - .collect(), - tool_uses: this - .tool_uses_for_message(message.id, cx) - .into_iter() - .map(|tool_use| SerializedToolUse { - id: tool_use.id, - name: tool_use.name, - input: tool_use.input, - }) - .collect(), - tool_results: this - .tool_results_for_message(message.id) - .into_iter() - .map(|tool_result| SerializedToolResult { - tool_use_id: tool_result.tool_use_id.clone(), - is_error: tool_result.is_error, - content: tool_result.content.clone(), - output: tool_result.output.clone(), - }) - .collect(), - context: message.loaded_context.text.clone(), - creases: message - .creases - .iter() - .map(|crease| SerializedCrease { - start: crease.range.start, - end: crease.range.end, - icon_path: crease.icon_path.clone(), - label: crease.label.clone(), - }) - .collect(), - is_hidden: message.is_hidden, - }) - .collect(), - initial_project_snapshot, - cumulative_token_usage: this.cumulative_token_usage, - request_token_usage: this.request_token_usage.clone(), - detailed_summary_state: this.detailed_summary_rx.borrow().clone(), - exceeded_window_error: this.exceeded_window_error.clone(), - model: this - .configured_model - .as_ref() - .map(|model| SerializedLanguageModel { - provider: model.provider.id().0.to_string(), - model: model.model.id().0.to_string(), - }), - completion_mode: Some(this.completion_mode), - tool_use_limit_reached: this.tool_use_limit_reached, - profile: Some(this.profile.id().clone()), - }) - }) - } + fn handle_text_event( + &mut self, + new_text: String, + event_stream: &ThreadEventStream, + cx: &mut Context, + ) { + event_stream.send_text(&new_text); - pub fn remaining_turns(&self) -> u32 { - self.remaining_turns - } + let last_message = self.pending_message(); + if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() { + text.push_str(&new_text); + } else { + last_message + .content + .push(AgentMessageContent::Text(new_text)); + } - pub fn set_remaining_turns(&mut self, remaining_turns: u32) { - self.remaining_turns = remaining_turns; + cx.notify(); } - pub fn send_to_model( + fn handle_thinking_event( &mut self, - model: Arc, - intent: CompletionIntent, - window: Option, + new_text: String, + new_signature: Option, + event_stream: &ThreadEventStream, cx: &mut Context, ) { - if self.remaining_turns == 0 { - return; - } + event_stream.send_thinking(&new_text); - self.remaining_turns -= 1; + let last_message = self.pending_message(); + if let Some(AgentMessageContent::Thinking { text, signature }) = + last_message.content.last_mut() + { + text.push_str(&new_text); + *signature = new_signature.or(signature.take()); + } else { + last_message.content.push(AgentMessageContent::Thinking { + text: new_text, + signature: new_signature, + }); + } - self.flush_notifications(model.clone(), intent, cx); + cx.notify(); + } - let _checkpoint = self.finalize_pending_checkpoint(cx); - self.stream_completion( - self.to_completion_request(model.clone(), intent, cx), - model, - intent, - window, - cx, - ); + fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context) { + let last_message = self.pending_message(); + last_message + .content + .push(AgentMessageContent::RedactedThinking(data)); + cx.notify(); } - pub fn to_completion_request( - &self, - model: Arc, - intent: CompletionIntent, + fn handle_tool_use_event( + &mut self, + tool_use: LanguageModelToolUse, + event_stream: &ThreadEventStream, cx: &mut Context, - ) -> LanguageModelRequest { - let mut request = LanguageModelRequest { - thread_id: Some(self.id.to_string()), - prompt_id: Some(self.last_prompt_id.to_string()), - intent: Some(intent), - mode: None, - messages: vec![], - tools: Vec::new(), - tool_choice: None, - stop: Vec::new(), - temperature: AgentSettings::temperature_for_model(&model, cx), - thinking_allowed: true, - }; - - let available_tools = self.available_tools(cx, model.clone()); - let available_tool_names = available_tools - .iter() - .map(|tool| tool.name.clone()) - .collect(); - - let model_context = &ModelContext { - available_tools: available_tool_names, - }; + ) -> Option> { + cx.notify(); - if let Some(project_context) = self.project_context.borrow().as_ref() { - match self - .prompt_builder - .generate_assistant_system_prompt(project_context, model_context) - { - Err(err) => { - let message = format!("{err:?}").into(); - log::error!("{message}"); - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Error generating system prompt".into(), - message, - })); - } - Ok(system_prompt) => { - request.messages.push(LanguageModelRequestMessage { - role: Role::System, - content: vec![MessageContent::Text(system_prompt)], - cache: true, - }); + let tool = self.tool(tool_use.name.as_ref()); + let mut title = SharedString::from(&tool_use.name); + let mut kind = acp::ToolKind::Other; + if let Some(tool) = tool.as_ref() { + title = tool.initial_title(tool_use.input.clone(), cx); + kind = tool.kind(); + } + + // Ensure the last message ends in the current tool use + let last_message = self.pending_message(); + let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| { + if let AgentMessageContent::ToolUse(last_tool_use) = content { + if last_tool_use.id == tool_use.id { + *last_tool_use = tool_use.clone(); + false + } else { + true } + } else { + true } + }); + + if push_new_tool_use { + event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); + last_message + .content + .push(AgentMessageContent::ToolUse(tool_use.clone())); } else { - let message = "Context for system prompt unexpectedly not ready.".into(); - log::error!("{message}"); - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Error generating system prompt".into(), - message, - })); + event_stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields { + title: Some(title.into()), + kind: Some(kind), + raw_input: Some(tool_use.input.clone()), + ..Default::default() + }, + ); } - let mut message_ix_to_cache = None; - for message in &self.messages { - // ui_only messages are for the UI only, not for the model - if message.ui_only { - continue; - } - - let mut request_message = LanguageModelRequestMessage { - role: message.role, - content: Vec::new(), - cache: false, - }; + if !tool_use.is_input_complete { + return None; + } - message - .loaded_context - .add_to_request_message(&mut request_message); - - for segment in &message.segments { - match segment { - MessageSegment::Text(text) => { - let text = text.trim_end(); - if !text.is_empty() { - request_message - .content - .push(MessageContent::Text(text.into())); - } - } - MessageSegment::Thinking { text, signature } => { - if !text.is_empty() { - request_message.content.push(MessageContent::Thinking { - text: text.into(), - signature: signature.clone(), - }); - } - } - MessageSegment::RedactedThinking(data) => { - request_message - .content - .push(MessageContent::RedactedThinking(data.clone())); - } - }; - } + let Some(tool) = tool else { + let content = format!("No tool named {} exists", tool_use.name); + return Some(Task::ready(LanguageModelToolResult { + content: LanguageModelToolResultContent::Text(Arc::from(content)), + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + output: None, + })); + }; - let mut cache_message = true; - let mut tool_results_message = LanguageModelRequestMessage { - role: Role::User, - content: Vec::new(), - cache: false, - }; - for (tool_use, tool_result) in self.tool_use.tool_results(message.id) { - if let Some(tool_result) = tool_result { - request_message - .content - .push(MessageContent::ToolUse(tool_use.clone())); - tool_results_message - .content - .push(MessageContent::ToolResult(LanguageModelToolResult { - tool_use_id: tool_use.id.clone(), - tool_name: tool_result.tool_name.clone(), - is_error: tool_result.is_error, - content: if tool_result.content.is_empty() { - // Surprisingly, the API fails if we return an empty string here. - // It thinks we are sending a tool use without a tool result. - "".into() - } else { - tool_result.content.clone() - }, - output: None, - })); - } else { - cache_message = false; - log::debug!( - "skipped tool use {:?} because it is still pending", - tool_use - ); + let fs = self.project.read(cx).fs().clone(); + let tool_event_stream = + ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs)); + tool_event_stream.update_fields(acp::ToolCallUpdateFields { + status: Some(acp::ToolCallStatus::InProgress), + ..Default::default() + }); + let supports_images = self.model().is_some_and(|model| model.supports_images()); + let tool_result = tool.run(tool_use.input, tool_event_stream, cx); + log::debug!("Running tool {}", tool_use.name); + Some(cx.foreground_executor().spawn(async move { + let tool_result = tool_result.await.and_then(|output| { + if let LanguageModelToolResultContent::Image(_) = &output.llm_output + && !supports_images + { + return Err(anyhow!( + "Attempted to read an image, but this model doesn't support it.", + )); } - } + Ok(output) + }); - if cache_message { - message_ix_to_cache = Some(request.messages.len()); + match tool_result { + Ok(output) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: false, + content: output.llm_output, + output: Some(output.raw_output), + }, + Err(error) => LanguageModelToolResult { + tool_use_id: tool_use.id, + tool_name: tool_use.name, + is_error: true, + content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())), + output: Some(error.to_string().into()), + }, } - request.messages.push(request_message); + })) + } - if !tool_results_message.content.is_empty() { - if cache_message { - message_ix_to_cache = Some(request.messages.len()); - } - request.messages.push(tool_results_message); - } + fn handle_tool_use_json_parse_error_event( + &mut self, + tool_use_id: LanguageModelToolUseId, + tool_name: Arc, + raw_input: Arc, + json_parse_error: String, + ) -> LanguageModelToolResult { + let tool_output = format!("Error parsing input JSON: {json_parse_error}"); + LanguageModelToolResult { + tool_use_id, + tool_name, + is_error: true, + content: LanguageModelToolResultContent::Text(tool_output.into()), + output: Some(serde_json::Value::String(raw_input.to_string())), } + } - // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching - if let Some(message_ix_to_cache) = message_ix_to_cache { - request.messages[message_ix_to_cache].cache = true; - } + fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context) { + self.project + .read(cx) + .user_store() + .update(cx, |user_store, cx| { + user_store.update_model_request_usage( + ModelRequestUsage(RequestUsage { + amount: amount as i32, + limit, + }), + cx, + ) + }); + } - request.tools = available_tools; - request.mode = if model.supports_burn_mode() { - Some(self.completion_mode.into()) - } else { - Some(CompletionMode::Normal.into()) - }; + pub fn title(&self) -> SharedString { + self.title.clone().unwrap_or("New Thread".into()) + } - request + pub fn is_generating_summary(&self) -> bool { + self.pending_summary_generation.is_some() } - fn to_summarize_request( - &self, - model: &Arc, - intent: CompletionIntent, - added_user_message: String, - cx: &App, - ) -> LanguageModelRequest { + pub fn summary(&mut self, cx: &mut Context) -> Shared>> { + if let Some(summary) = self.summary.as_ref() { + return Task::ready(Some(summary.clone())).shared(); + } + if let Some(task) = self.pending_summary_generation.clone() { + return task; + } + let Some(model) = self.summarization_model.clone() else { + log::error!("No summarization model available"); + return Task::ready(None).shared(); + }; let mut request = LanguageModelRequest { - thread_id: None, - prompt_id: None, - intent: Some(intent), - mode: None, - messages: vec![], - tools: Vec::new(), - tool_choice: None, - stop: Vec::new(), - temperature: AgentSettings::temperature_for_model(model, cx), - thinking_allowed: false, + intent: Some(CompletionIntent::ThreadContextSummarization), + temperature: AgentSettings::temperature_for_model(&model, cx), + ..Default::default() }; for message in &self.messages { - let mut request_message = LanguageModelRequestMessage { - role: message.role, - content: Vec::new(), - cache: false, - }; - - for segment in &message.segments { - match segment { - MessageSegment::Text(text) => request_message - .content - .push(MessageContent::Text(text.clone())), - MessageSegment::Thinking { .. } => {} - MessageSegment::RedactedThinking(_) => {} - } - } - - if request_message.content.is_empty() { - continue; - } - - request.messages.push(request_message); + request.messages.extend(message.to_request()); } request.messages.push(LanguageModelRequestMessage { role: Role::User, - content: vec![MessageContent::Text(added_user_message)], + content: vec![SUMMARIZE_THREAD_DETAILED_PROMPT.into()], cache: false, }); - request - } + let task = cx + .spawn(async move |this, cx| { + let mut summary = String::new(); + let mut messages = model.stream_completion(request, cx).await.log_err()?; + while let Some(event) = messages.next().await { + let event = event.log_err()?; + let text = match event { + LanguageModelCompletionEvent::Text(text) => text, + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { amount, limit }, + ) => { + this.update(cx, |thread, cx| { + thread.update_model_request_usage(amount, limit, cx); + }) + .ok()?; + continue; + } + _ => continue, + }; - /// Insert auto-generated notifications (if any) to the thread - fn flush_notifications( - &mut self, - model: Arc, - intent: CompletionIntent, - cx: &mut Context, - ) { - match intent { - CompletionIntent::UserPrompt | CompletionIntent::ToolResults => { - if let Some(pending_tool_use) = self.attach_tracked_files_state(model, cx) { - cx.emit(ThreadEvent::ToolFinished { - tool_use_id: pending_tool_use.id.clone(), - pending_tool_use: Some(pending_tool_use), - }); + let mut lines = text.lines(); + summary.extend(lines.next()); } - } - CompletionIntent::ThreadSummarization - | CompletionIntent::ThreadContextSummarization - | CompletionIntent::CreateFile - | CompletionIntent::EditFile - | CompletionIntent::InlineAssist - | CompletionIntent::TerminalInlineAssist - | CompletionIntent::GenerateGitCommitMessage => {} - }; - } - - fn attach_tracked_files_state( - &mut self, - model: Arc, - cx: &mut App, - ) -> Option { - // Represent notification as a simulated `project_notifications` tool call - let tool_name = Arc::from("project_notifications"); - let tool = self.tools.read(cx).tool(&tool_name, cx)?; - - if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) { - return None; - } - if self - .action_log - .update(cx, |log, cx| log.unnotified_user_edits(cx).is_none()) - { - return None; - } + log::debug!("Setting summary: {}", summary); + let summary = SharedString::from(summary); - let input = serde_json::json!({}); - let request = Arc::new(LanguageModelRequest::default()); // unused - let window = None; - let tool_result = tool.run( - input, - request, - self.project.clone(), - self.action_log.clone(), - model.clone(), - window, - cx, - ); + this.update(cx, |this, cx| { + this.summary = Some(summary.clone()); + this.pending_summary_generation = None; + cx.notify() + }) + .ok()?; - let tool_use_id = - LanguageModelToolUseId::from(format!("project_notifications_{}", self.messages.len())); + Some(summary) + }) + .shared(); + self.pending_summary_generation = Some(task.clone()); + task + } - let tool_use = LanguageModelToolUse { - id: tool_use_id.clone(), - name: tool_name.clone(), - raw_input: "{}".to_string(), - input: serde_json::json!({}), - is_input_complete: true, + fn generate_title(&mut self, cx: &mut Context) { + let Some(model) = self.summarization_model.clone() else { + return; }; - let tool_output = cx.background_executor().block(tool_result.output); - - // Attach a project_notification tool call to the latest existing - // Assistant message. We cannot create a new Assistant message - // because thinking models require a `thinking` block that we - // cannot mock. We cannot send a notification as a normal - // (non-tool-use) User message because this distracts Agent - // too much. - let tool_message_id = self - .messages - .iter() - .enumerate() - .rfind(|(_, message)| message.role == Role::Assistant) - .map(|(_, message)| message.id)?; - - let tool_use_metadata = ToolUseMetadata { - model: model.clone(), - thread_id: self.id.clone(), - prompt_id: self.last_prompt_id.clone(), + log::debug!( + "Generating title with model: {:?}", + self.summarization_model.as_ref().map(|model| model.name()) + ); + let mut request = LanguageModelRequest { + intent: Some(CompletionIntent::ThreadSummarization), + temperature: AgentSettings::temperature_for_model(&model, cx), + ..Default::default() }; - self.tool_use - .request_tool_use(tool_message_id, tool_use, tool_use_metadata, cx); - - self.tool_use.insert_tool_output( - tool_use_id, - tool_name, - tool_output, - self.configured_model.as_ref(), - self.completion_mode, - ) - } - - pub fn stream_completion( - &mut self, - request: LanguageModelRequest, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) { - self.tool_use_limit_reached = false; - - let pending_completion_id = post_inc(&mut self.completion_count); - let mut request_callback_parameters = if self.request_callback.is_some() { - Some((request.clone(), Vec::new())) - } else { - None - }; - let prompt_id = self.last_prompt_id.clone(); - let tool_use_metadata = ToolUseMetadata { - model: model.clone(), - thread_id: self.id.clone(), - prompt_id: prompt_id.clone(), - }; + for message in &self.messages { + request.messages.extend(message.to_request()); + } - let completion_mode = request - .mode - .unwrap_or(cloud_llm_client::CompletionMode::Normal); + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![SUMMARIZE_THREAD_PROMPT.into()], + cache: false, + }); + self.pending_title_generation = Some(cx.spawn(async move |this, cx| { + let mut title = String::new(); - self.last_received_chunk_at = Some(Instant::now()); + let generate = async { + let mut messages = model.stream_completion(request, cx).await?; + while let Some(event) = messages.next().await { + let event = event?; + let text = match event { + LanguageModelCompletionEvent::Text(text) => text, + LanguageModelCompletionEvent::StatusUpdate( + CompletionRequestStatus::UsageUpdated { amount, limit }, + ) => { + this.update(cx, |thread, cx| { + thread.update_model_request_usage(amount, limit, cx); + })?; + continue; + } + _ => continue, + }; - let task = cx.spawn(async move |thread, cx| { - let stream_completion_future = model.stream_completion(request, cx); - let initial_token_usage = - thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage); - let stream_completion = async { - let mut events = stream_completion_future.await?; + let mut lines = text.lines(); + title.extend(lines.next()); - let mut stop_reason = StopReason::EndTurn; - let mut current_token_usage = TokenUsage::default(); + // Stop if the LLM generated multiple lines. + if lines.next().is_some() { + break; + } + } + anyhow::Ok(()) + }; - thread - .update(cx, |_thread, cx| { - cx.emit(ThreadEvent::NewRequest); - }) - .ok(); + if generate.await.context("failed to generate title").is_ok() { + _ = this.update(cx, |this, cx| this.set_title(title.into(), cx)); + } + _ = this.update(cx, |this, _| this.pending_title_generation = None); + })); + } - let mut request_assistant_message_id = None; + pub fn set_title(&mut self, title: SharedString, cx: &mut Context) { + self.pending_title_generation = None; + if Some(&title) != self.title.as_ref() { + self.title = Some(title); + cx.emit(TitleUpdated); + cx.notify(); + } + } - while let Some(event) = events.next().await { - if let Some((_, response_events)) = request_callback_parameters.as_mut() { - response_events - .push(event.as_ref().map_err(|error| error.to_string()).cloned()); - } + fn clear_summary(&mut self) { + self.summary = None; + self.pending_summary_generation = None; + } - thread.update(cx, |thread, cx| { - match event? { - LanguageModelCompletionEvent::StartMessage { .. } => { - request_assistant_message_id = - Some(thread.insert_assistant_message( - vec![MessageSegment::Text(String::new())], - cx, - )); - } - LanguageModelCompletionEvent::Stop(reason) => { - stop_reason = reason; - } - LanguageModelCompletionEvent::UsageUpdate(token_usage) => { - thread.update_token_usage_at_last_message(token_usage); - thread.cumulative_token_usage = thread.cumulative_token_usage - + token_usage - - current_token_usage; - current_token_usage = token_usage; - } - LanguageModelCompletionEvent::Text(chunk) => { - thread.received_chunk(); - - cx.emit(ThreadEvent::ReceivedTextChunk); - if let Some(last_message) = thread.messages.last_mut() { - if last_message.role == Role::Assistant - && !thread.tool_use.has_tool_results(last_message.id) - { - last_message.push_text(&chunk); - cx.emit(ThreadEvent::StreamedAssistantText( - last_message.id, - chunk, - )); - } else { - // If we won't have an Assistant message yet, assume this chunk marks the beginning - // of a new Assistant response. - // - // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it - // will result in duplicating the text of the chunk in the rendered Markdown. - request_assistant_message_id = - Some(thread.insert_assistant_message( - vec![MessageSegment::Text(chunk.to_string())], - cx, - )); - }; - } - } - LanguageModelCompletionEvent::Thinking { - text: chunk, - signature, - } => { - thread.received_chunk(); - - if let Some(last_message) = thread.messages.last_mut() { - if last_message.role == Role::Assistant - && !thread.tool_use.has_tool_results(last_message.id) - { - last_message.push_thinking(&chunk, signature); - cx.emit(ThreadEvent::StreamedAssistantThinking( - last_message.id, - chunk, - )); - } else { - // If we won't have an Assistant message yet, assume this chunk marks the beginning - // of a new Assistant response. - // - // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it - // will result in duplicating the text of the chunk in the rendered Markdown. - request_assistant_message_id = - Some(thread.insert_assistant_message( - vec![MessageSegment::Thinking { - text: chunk.to_string(), - signature, - }], - cx, - )); - }; - } - } - LanguageModelCompletionEvent::RedactedThinking { data } => { - thread.received_chunk(); - - if let Some(last_message) = thread.messages.last_mut() { - if last_message.role == Role::Assistant - && !thread.tool_use.has_tool_results(last_message.id) - { - last_message.push_redacted_thinking(data); - } else { - request_assistant_message_id = - Some(thread.insert_assistant_message( - vec![MessageSegment::RedactedThinking(data)], - cx, - )); - }; - } - } - LanguageModelCompletionEvent::ToolUse(tool_use) => { - let last_assistant_message_id = request_assistant_message_id - .unwrap_or_else(|| { - let new_assistant_message_id = - thread.insert_assistant_message(vec![], cx); - request_assistant_message_id = - Some(new_assistant_message_id); - new_assistant_message_id - }); - - let tool_use_id = tool_use.id.clone(); - let streamed_input = if tool_use.is_input_complete { - None - } else { - Some(tool_use.input.clone()) - }; + fn last_user_message(&self) -> Option<&UserMessage> { + self.messages + .iter() + .rev() + .find_map(|message| match message { + Message::User(user_message) => Some(user_message), + Message::Agent(_) => None, + Message::Resume => None, + }) + } - let ui_text = thread.tool_use.request_tool_use( - last_assistant_message_id, - tool_use, - tool_use_metadata.clone(), - cx, - ); + fn pending_message(&mut self) -> &mut AgentMessage { + self.pending_message.get_or_insert_default() + } - if let Some(input) = streamed_input { - cx.emit(ThreadEvent::StreamedToolUse { - tool_use_id, - ui_text, - input, - }); - } - } - LanguageModelCompletionEvent::ToolUseJsonParseError { - id, - tool_name, - raw_input: invalid_input_json, - json_parse_error, - } => { - thread.receive_invalid_tool_json( - id, - tool_name, - invalid_input_json, - json_parse_error, - window, - cx, - ); - } - LanguageModelCompletionEvent::StatusUpdate(status_update) => { - if let Some(completion) = thread - .pending_completions - .iter_mut() - .find(|completion| completion.id == pending_completion_id) - { - match status_update { - CompletionRequestStatus::Queued { position } => { - completion.queue_state = - QueueState::Queued { position }; - } - CompletionRequestStatus::Started => { - completion.queue_state = QueueState::Started; - } - CompletionRequestStatus::Failed { - code, - message, - request_id: _, - retry_after, - } => { - return Err( - LanguageModelCompletionError::from_cloud_failure( - model.upstream_provider_name(), - code, - message, - retry_after.map(Duration::from_secs_f64), - ), - ); - } - CompletionRequestStatus::UsageUpdated { amount, limit } => { - thread.update_model_request_usage( - amount as u32, - limit, - cx, - ); - } - CompletionRequestStatus::ToolUseLimitReached => { - thread.tool_use_limit_reached = true; - cx.emit(ThreadEvent::ToolUseLimitReached); - } - } - } - } - } + fn flush_pending_message(&mut self, cx: &mut Context) { + let Some(mut message) = self.pending_message.take() else { + return; + }; - thread.touch_updated_at(); - cx.emit(ThreadEvent::StreamedCompletion); - cx.notify(); + if message.content.is_empty() { + return; + } - Ok(()) - })??; + for content in &message.content { + let AgentMessageContent::ToolUse(tool_use) = content else { + continue; + }; - smol::future::yield_now().await; - } + if !message.tool_results.contains_key(&tool_use.id) { + message.tool_results.insert( + tool_use.id.clone(), + LanguageModelToolResult { + tool_use_id: tool_use.id.clone(), + tool_name: tool_use.name.clone(), + is_error: true, + content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()), + output: None, + }, + ); + } + } - thread.update(cx, |thread, cx| { - thread.last_received_chunk_at = None; - thread - .pending_completions - .retain(|completion| completion.id != pending_completion_id); - - // If there is a response without tool use, summarize the message. Otherwise, - // allow two tool uses before summarizing. - if matches!(thread.summary, ThreadSummary::Pending) - && thread.messages.len() >= 2 - && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6) - { - thread.summarize(cx); - } - })?; + self.messages.push(Message::Agent(message)); + self.updated_at = Utc::now(); + self.clear_summary(); + cx.notify() + } - anyhow::Ok(stop_reason) - }; + pub(crate) fn build_completion_request( + &self, + completion_intent: CompletionIntent, + cx: &App, + ) -> Result { + let model = self.model().context("No language model configured")?; + let tools = if let Some(turn) = self.running_turn.as_ref() { + turn.tools + .iter() + .filter_map(|(tool_name, tool)| { + log::trace!("Including tool: {}", tool_name); + Some(LanguageModelRequestTool { + name: tool_name.to_string(), + description: tool.description().to_string(), + input_schema: tool.input_schema(model.tool_input_format()).log_err()?, + }) + }) + .collect::>() + } else { + Vec::new() + }; - let result = stream_completion.await; - let mut retry_scheduled = false; + log::debug!("Building completion request"); + log::debug!("Completion intent: {:?}", completion_intent); + log::debug!("Completion mode: {:?}", self.completion_mode); - thread - .update(cx, |thread, cx| { - thread.finalize_pending_checkpoint(cx); - match result.as_ref() { - Ok(stop_reason) => { - match stop_reason { - StopReason::ToolUse => { - let tool_uses = - thread.use_pending_tools(window, model.clone(), cx); - cx.emit(ThreadEvent::UsePendingTools { tool_uses }); - } - StopReason::EndTurn | StopReason::MaxTokens => { - thread.project.update(cx, |project, cx| { - project.set_agent_location(None, cx); - }); - } - StopReason::Refusal => { - thread.project.update(cx, |project, cx| { - project.set_agent_location(None, cx); - }); - - // Remove the turn that was refused. - // - // https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals#reset-context-after-refusal - { - let mut messages_to_remove = Vec::new(); - - for (ix, message) in - thread.messages.iter().enumerate().rev() - { - messages_to_remove.push(message.id); - - if message.role == Role::User { - if ix == 0 { - break; - } - - if let Some(prev_message) = - thread.messages.get(ix - 1) - && prev_message.role == Role::Assistant { - break; - } - } - } - - for message_id in messages_to_remove { - thread.delete_message(message_id, cx); - } - } - - cx.emit(ThreadEvent::ShowError(ThreadError::Message { - header: "Language model refusal".into(), - message: - "Model refused to generate content for safety reasons." - .into(), - })); - } - } + let messages = self.build_request_messages(cx); + log::debug!("Request will include {} messages", messages.len()); + log::debug!("Request includes {} tools", tools.len()); - // We successfully completed, so cancel any remaining retries. - thread.retry_state = None; - } - Err(error) => { - thread.project.update(cx, |project, cx| { - project.set_agent_location(None, cx); - }); - - if error.is::() { - cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired)); - } else if let Some(error) = - error.downcast_ref::() - { - cx.emit(ThreadEvent::ShowError( - ThreadError::ModelRequestLimitReached { plan: error.plan }, - )); - } else if let Some(completion_error) = - error.downcast_ref::() - { - match &completion_error { - LanguageModelCompletionError::PromptTooLarge { - tokens, .. - } => { - let tokens = tokens.unwrap_or_else(|| { - // We didn't get an exact token count from the API, so fall back on our estimate. - thread - .total_token_usage() - .map(|usage| usage.total) - .unwrap_or(0) - // We know the context window was exceeded in practice, so if our estimate was - // lower than max tokens, the estimate was wrong; return that we exceeded by 1. - .max( - model - .max_token_count_for_mode(completion_mode) - .saturating_add(1), - ) - }); - thread.exceeded_window_error = Some(ExceededWindowError { - model_id: model.id(), - token_count: tokens, - }); - cx.notify(); - } - _ => { - if let Some(retry_strategy) = - Thread::get_retry_strategy(completion_error) - { - log::info!( - "Retrying with {:?} for language model completion error {:?}", - retry_strategy, - completion_error - ); - - retry_scheduled = thread - .handle_retryable_error_with_delay( - completion_error, - Some(retry_strategy), - model.clone(), - intent, - window, - cx, - ); - } - } - } - } + let request = LanguageModelRequest { + thread_id: Some(self.id.to_string()), + prompt_id: Some(self.prompt_id.to_string()), + intent: Some(completion_intent), + mode: Some(self.completion_mode.into()), + messages, + tools, + tool_choice: None, + stop: Vec::new(), + temperature: AgentSettings::temperature_for_model(model, cx), + thinking_allowed: true, + }; - if !retry_scheduled { - thread.cancel_last_completion(window, cx); - } - } - } + log::debug!("Completion request built successfully"); + Ok(request) + } - if !retry_scheduled { - cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new))); - } + fn enabled_tools( + &self, + profile: &AgentProfileSettings, + model: &Arc, + cx: &App, + ) -> BTreeMap> { + fn truncate(tool_name: &SharedString) -> SharedString { + if tool_name.len() > MAX_TOOL_NAME_LENGTH { + let mut truncated = tool_name.to_string(); + truncated.truncate(MAX_TOOL_NAME_LENGTH); + truncated.into() + } else { + tool_name.clone() + } + } - if let Some((request_callback, (request, response_events))) = thread - .request_callback - .as_mut() - .zip(request_callback_parameters.as_ref()) - { - request_callback(request, response_events); + let mut tools = self + .tools + .iter() + .filter_map(|(tool_name, tool)| { + if tool.supported_provider(&model.provider_id()) + && profile.is_tool_enabled(tool_name) + { + Some((truncate(tool_name), tool.clone())) + } else { + None + } + }) + .collect::>(); + + let mut context_server_tools = Vec::new(); + let mut seen_tools = tools.keys().cloned().collect::>(); + let mut duplicate_tool_names = HashSet::default(); + for (server_id, server_tools) in self.context_server_registry.read(cx).servers() { + for (tool_name, tool) in server_tools { + if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) { + let tool_name = truncate(tool_name); + if !seen_tools.insert(tool_name.clone()) { + duplicate_tool_names.insert(tool_name.clone()); } + context_server_tools.push((server_id.clone(), tool_name, tool.clone())); + } + } + } - if let Ok(initial_usage) = initial_token_usage { - let usage = thread.cumulative_token_usage - initial_usage; - - telemetry::event!( - "Assistant Thread Completion", - thread_id = thread.id().to_string(), - prompt_id = prompt_id, - model = model.telemetry_id(), - model_provider = model.provider_id().to_string(), - input_tokens = usage.input_tokens, - output_tokens = usage.output_tokens, - cache_creation_input_tokens = usage.cache_creation_input_tokens, - cache_read_input_tokens = usage.cache_read_input_tokens, - ); - } - }) - .ok(); - }); + // When there are duplicate tool names, disambiguate by prefixing them + // with the server ID. In the rare case there isn't enough space for the + // disambiguated tool name, keep only the last tool with this name. + for (server_id, tool_name, tool) in context_server_tools { + if duplicate_tool_names.contains(&tool_name) { + let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len()); + if available >= 2 { + let mut disambiguated = server_id.0.to_string(); + disambiguated.truncate(available - 1); + disambiguated.push('_'); + disambiguated.push_str(&tool_name); + tools.insert(disambiguated.into(), tool.clone()); + } else { + tools.insert(tool_name, tool.clone()); + } + } else { + tools.insert(tool_name, tool.clone()); + } + } - self.pending_completions.push(PendingCompletion { - id: pending_completion_id, - queue_state: QueueState::Sending, - _task: task, - }); + tools } - pub fn summarize(&mut self, cx: &mut Context) { - let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else { - println!("No thread summary model"); - return; - }; - - if !model.provider.is_authenticated(cx) { - return; - } + fn tool(&self, name: &str) -> Option> { + self.running_turn.as_ref()?.tools.get(name).cloned() + } - let request = self.to_summarize_request( - &model.model, - CompletionIntent::ThreadSummarization, - SUMMARIZE_THREAD_PROMPT.into(), - cx, + fn build_request_messages(&self, cx: &App) -> Vec { + log::trace!( + "Building request messages from {} thread messages", + self.messages.len() ); - self.summary = ThreadSummary::Generating; - - self.pending_summary = cx.spawn(async move |this, cx| { - let result = async { - let mut messages = model.model.stream_completion(request, cx).await?; + let system_prompt = SystemPromptTemplate { + project: self.project_context.read(cx), + available_tools: self.tools.keys().cloned().collect(), + } + .render(&self.templates) + .context("failed to build system prompt") + .expect("Invalid template"); + let mut messages = vec![LanguageModelRequestMessage { + role: Role::System, + content: vec![system_prompt.into()], + cache: false, + }]; + for message in &self.messages { + messages.extend(message.to_request()); + } - let mut new_summary = String::new(); - while let Some(event) = messages.next().await { - let Ok(event) = event else { - continue; - }; - let text = match event { - LanguageModelCompletionEvent::Text(text) => text, - LanguageModelCompletionEvent::StatusUpdate( - CompletionRequestStatus::UsageUpdated { amount, limit }, - ) => { - this.update(cx, |thread, cx| { - thread.update_model_request_usage(amount as u32, limit, cx); - })?; - continue; - } - _ => continue, - }; + if let Some(last_message) = messages.last_mut() { + last_message.cache = true; + } - let mut lines = text.lines(); - new_summary.extend(lines.next()); + if let Some(message) = self.pending_message.as_ref() { + messages.extend(message.to_request()); + } - // Stop if the LLM generated multiple lines. - if lines.next().is_some() { - break; - } - } + messages + } - anyhow::Ok(new_summary) + pub fn to_markdown(&self) -> String { + let mut markdown = String::new(); + for (ix, message) in self.messages.iter().enumerate() { + if ix > 0 { + markdown.push('\n'); } - .await; + markdown.push_str(&message.to_markdown()); + } - this.update(cx, |this, cx| { - match result { - Ok(new_summary) => { - if new_summary.is_empty() { - this.summary = ThreadSummary::Error; - } else { - this.summary = ThreadSummary::Ready(new_summary.into()); - } - } - Err(err) => { - this.summary = ThreadSummary::Error; - log::error!("Failed to generate thread summary: {}", err); - } - } - cx.emit(ThreadEvent::SummaryGenerated); - }) - .log_err()?; + if let Some(message) = self.pending_message.as_ref() { + markdown.push('\n'); + markdown.push_str(&message.to_markdown()); + } - Some(()) - }); + markdown } - fn get_retry_strategy(error: &LanguageModelCompletionError) -> Option { + fn advance_prompt_id(&mut self) { + self.prompt_id = PromptId::new(); + } + + fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option { use LanguageModelCompletionError::*; + use http_client::StatusCode; // General strategy here: // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all. @@ -2205,8 +2119,8 @@ impl Thread { }) } Other(err) - if err.is::() - || err.is::() => + if err.is::() + || err.is::() => { // Retrying won't help for Payment Required or Model Request Limit errors (where // the user must upgrade to usage-based billing to get more requests, or else wait @@ -2220,3166 +2134,556 @@ impl Thread { }), } } +} - fn handle_retryable_error_with_delay( - &mut self, - error: &LanguageModelCompletionError, - strategy: Option, - model: Arc, - intent: CompletionIntent, - window: Option, - cx: &mut Context, - ) -> bool { - // Store context for the Retry button - self.last_error_context = Some((model.clone(), intent)); - - // Only auto-retry if Burn Mode is enabled - if self.completion_mode != CompletionMode::Burn { - // Show error with retry options - cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError { - message: format!( - "{}\n\nTo automatically retry when similar errors happen, enable Burn Mode.", - error - ) - .into(), - can_enable_burn_mode: true, - })); - return false; - } +struct RunningTurn { + /// Holds the task that handles agent interaction until the end of the turn. + /// Survives across multiple requests as the model performs tool calls and + /// we run tools, report their results. + _task: Task<()>, + /// The current event stream for the running turn. Used to report a final + /// cancellation event if we cancel the turn. + event_stream: ThreadEventStream, + /// The tools that were enabled for this turn. + tools: BTreeMap>, +} - let Some(strategy) = strategy.or_else(|| Self::get_retry_strategy(error)) else { - return false; - }; +impl RunningTurn { + fn cancel(self) { + log::debug!("Cancelling in progress turn"); + self.event_stream.send_canceled(); + } +} - let max_attempts = match &strategy { - RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, - RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, - }; +pub struct TokenUsageUpdated(pub Option); - let retry_state = self.retry_state.get_or_insert(RetryState { - attempt: 0, - max_attempts, - intent, - }); +impl EventEmitter for Thread {} - retry_state.attempt += 1; - let attempt = retry_state.attempt; - let max_attempts = retry_state.max_attempts; - let intent = retry_state.intent; +pub struct TitleUpdated; - if attempt <= max_attempts { - let delay = match &strategy { - RetryStrategy::ExponentialBackoff { initial_delay, .. } => { - let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); - Duration::from_secs(delay_secs) - } - RetryStrategy::Fixed { delay, .. } => *delay, - }; +impl EventEmitter for Thread {} - // Add a transient message to inform the user - let delay_secs = delay.as_secs(); - let retry_message = if max_attempts == 1 { - format!("{error}. Retrying in {delay_secs} seconds...") - } else { - format!( - "{error}. Retrying (attempt {attempt} of {max_attempts}) \ - in {delay_secs} seconds..." - ) - }; - log::warn!( - "Retrying completion request (attempt {attempt} of {max_attempts}) \ - in {delay_secs} seconds: {error:?}", - ); +pub trait AgentTool +where + Self: 'static + Sized, +{ + type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema; + type Output: for<'de> Deserialize<'de> + Serialize + Into; - // Add a UI-only message instead of a regular message - let id = self.next_message_id.post_inc(); - self.messages.push(Message { - id, - role: Role::System, - segments: vec![MessageSegment::Text(retry_message)], - loaded_context: LoadedContext::default(), - creases: Vec::new(), - is_hidden: false, - ui_only: true, - }); - cx.emit(ThreadEvent::MessageAdded(id)); + fn name() -> &'static str; - // Schedule the retry - let thread_handle = cx.entity().downgrade(); + fn description() -> SharedString { + let schema = schemars::schema_for!(Self::Input); + SharedString::new( + schema + .get("description") + .and_then(|description| description.as_str()) + .unwrap_or_default(), + ) + } - cx.spawn(async move |_thread, cx| { - cx.background_executor().timer(delay).await; + fn kind() -> acp::ToolKind; - thread_handle - .update(cx, |thread, cx| { - // Retry the completion - thread.send_to_model(model, intent, window, cx); - }) - .log_err(); - }) - .detach(); + /// The initial tool title to display. Can be updated during the tool run. + fn initial_title( + &self, + input: Result, + cx: &mut App, + ) -> SharedString; - true - } else { - // Max retries exceeded - self.retry_state = None; + /// Returns the JSON schema that describes the tool's input. + fn input_schema(format: LanguageModelToolSchemaFormat) -> Schema { + crate::tool_schema::root_schema_for::(format) + } - // Stop generating since we're giving up on retrying. - self.pending_completions.clear(); + /// Some tools rely on a provider for the underlying billing or other reasons. + /// Allow the tool to check if they are compatible, or should be filtered out. + fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { + true + } - // Show error alongside a Retry button, but no - // Enable Burn Mode button (since it's already enabled) - cx.emit(ThreadEvent::ShowError(ThreadError::RetryableError { - message: format!("Failed after retrying: {}", error).into(), - can_enable_burn_mode: false, - })); + /// Runs the tool with the provided input. + fn run( + self: Arc, + input: Self::Input, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task>; - false - } + /// Emits events for a previous execution of the tool. + fn replay( + &self, + _input: Self::Input, + _output: Self::Output, + _event_stream: ToolCallEventStream, + _cx: &mut App, + ) -> Result<()> { + Ok(()) } - pub fn start_generating_detailed_summary_if_needed( - &mut self, - thread_store: WeakEntity, - cx: &mut Context, - ) { - let Some(last_message_id) = self.messages.last().map(|message| message.id) else { - return; - }; - - match &*self.detailed_summary_rx.borrow() { - DetailedSummaryState::Generating { message_id, .. } - | DetailedSummaryState::Generated { message_id, .. } - if *message_id == last_message_id => - { - // Already up-to-date - return; - } - _ => {} - } - - let Some(ConfiguredModel { model, provider }) = - LanguageModelRegistry::read_global(cx).thread_summary_model() - else { - return; - }; + fn erase(self) -> Arc { + Arc::new(Erased(Arc::new(self))) + } +} - if !provider.is_authenticated(cx) { - return; - } +pub struct Erased(T); - let request = self.to_summarize_request( - &model, - CompletionIntent::ThreadContextSummarization, - SUMMARIZE_THREAD_DETAILED_PROMPT.into(), - cx, - ); +pub struct AgentToolOutput { + pub llm_output: LanguageModelToolResultContent, + pub raw_output: serde_json::Value, +} - *self.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generating { - message_id: last_message_id, - }; +pub trait AnyAgentTool { + fn name(&self) -> SharedString; + fn description(&self) -> SharedString; + fn kind(&self) -> acp::ToolKind; + fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString; + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { + true + } + fn run( + self: Arc, + input: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task>; + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()>; +} - // Replace the detailed summarization task if there is one, cancelling it. It would probably - // be better to allow the old task to complete, but this would require logic for choosing - // which result to prefer (the old task could complete after the new one, resulting in a - // stale summary). - self.detailed_summary_task = cx.spawn(async move |thread, cx| { - let stream = model.stream_completion_text(request, cx); - let Some(mut messages) = stream.await.log_err() else { - thread - .update(cx, |thread, _cx| { - *thread.detailed_summary_tx.borrow_mut() = - DetailedSummaryState::NotGenerated; - }) - .ok()?; - return None; - }; +impl AnyAgentTool for Erased> +where + T: AgentTool, +{ + fn name(&self) -> SharedString { + T::name().into() + } - let mut new_detailed_summary = String::new(); + fn description(&self) -> SharedString { + T::description() + } - while let Some(chunk) = messages.stream.next().await { - if let Some(chunk) = chunk.log_err() { - new_detailed_summary.push_str(&chunk); - } - } + fn kind(&self) -> agent_client_protocol::ToolKind { + T::kind() + } - thread - .update(cx, |thread, _cx| { - *thread.detailed_summary_tx.borrow_mut() = DetailedSummaryState::Generated { - text: new_detailed_summary.into(), - message_id: last_message_id, - }; - }) - .ok()?; + fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString { + let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input); + self.0.initial_title(parsed_input, _cx) + } - // Save thread so its summary can be reused later - if let Some(thread) = thread.upgrade() - && let Ok(Ok(save_task)) = cx.update(|cx| { - thread_store - .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)) - }) - { - save_task.await.log_err(); - } + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + let mut json = serde_json::to_value(T::input_schema(format))?; + crate::tool_schema::adapt_schema_to_format(&mut json, format)?; + Ok(json) + } - Some(()) - }); + fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool { + self.0.supported_provider(provider) } - pub async fn wait_for_detailed_summary_or_text( - this: &Entity, - cx: &mut AsyncApp, - ) -> Option { - let mut detailed_summary_rx = this - .read_with(cx, |this, _cx| this.detailed_summary_rx.clone()) - .ok()?; - loop { - match detailed_summary_rx.recv().await? { - DetailedSummaryState::Generating { .. } => {} - DetailedSummaryState::NotGenerated => { - return this.read_with(cx, |this, _cx| this.text().into()).ok(); - } - DetailedSummaryState::Generated { text, .. } => return Some(text), - } - } + fn run( + self: Arc, + input: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.spawn(async move |cx| { + let input = serde_json::from_value(input)?; + let output = cx + .update(|cx| self.0.clone().run(input, event_stream, cx))? + .await?; + let raw_output = serde_json::to_value(&output)?; + Ok(AgentToolOutput { + llm_output: output.into(), + raw_output, + }) + }) } - pub fn latest_detailed_summary_or_text(&self) -> SharedString { - self.detailed_summary_rx - .borrow() - .text() - .unwrap_or_else(|| self.text().into()) + fn replay( + &self, + input: serde_json::Value, + output: serde_json::Value, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + let input = serde_json::from_value(input)?; + let output = serde_json::from_value(output)?; + self.0.replay(input, output, event_stream, cx) } +} - pub fn is_generating_detailed_summary(&self) -> bool { - matches!( - &*self.detailed_summary_rx.borrow(), - DetailedSummaryState::Generating { .. } - ) +#[derive(Clone)] +struct ThreadEventStream(mpsc::UnboundedSender>); + +impl ThreadEventStream { + fn send_user_message(&self, message: &UserMessage) { + self.0 + .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone()))) + .ok(); } - pub fn use_pending_tools( - &mut self, - window: Option, - model: Arc, - cx: &mut Context, - ) -> Vec { - let request = - Arc::new(self.to_completion_request(model.clone(), CompletionIntent::ToolResults, cx)); - let pending_tool_uses = self - .tool_use - .pending_tool_uses() - .into_iter() - .filter(|tool_use| tool_use.status.is_idle()) - .cloned() - .collect::>(); - - for tool_use in pending_tool_uses.iter() { - self.use_pending_tool(tool_use.clone(), request.clone(), model.clone(), window, cx); - } + fn send_text(&self, text: &str) { + self.0 + .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string()))) + .ok(); + } - pending_tool_uses + fn send_thinking(&self, text: &str) { + self.0 + .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string()))) + .ok(); } - fn use_pending_tool( - &mut self, - tool_use: PendingToolUse, - request: Arc, - model: Arc, - window: Option, - cx: &mut Context, + fn send_tool_call( + &self, + id: &LanguageModelToolUseId, + title: SharedString, + kind: acp::ToolKind, + input: serde_json::Value, ) { - let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) else { - return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx); - }; - - if !self.profile.is_tool_enabled(tool.source(), tool.name(), cx) { - return self.handle_hallucinated_tool_use(tool_use.id, tool_use.name, window, cx); - } + self.0 + .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call( + id, + title.to_string(), + kind, + input, + )))) + .ok(); + } - if tool.needs_confirmation(&tool_use.input, &self.project, cx) - && !AgentSettings::get_global(cx).always_allow_tool_actions - { - self.tool_use.confirm_tool_use( - tool_use.id, - tool_use.ui_text, - tool_use.input, - request, - tool, - ); - cx.emit(ThreadEvent::ToolConfirmationNeeded); - } else { - self.run_tool( - tool_use.id, - tool_use.ui_text, - tool_use.input, - request, - tool, - model, - window, - cx, - ); + fn initial_tool_call( + id: &LanguageModelToolUseId, + title: String, + kind: acp::ToolKind, + input: serde_json::Value, + ) -> acp::ToolCall { + acp::ToolCall { + meta: None, + id: acp::ToolCallId(id.to_string().into()), + title, + kind, + status: acp::ToolCallStatus::Pending, + content: vec![], + locations: vec![], + raw_input: Some(input), + raw_output: None, } } - pub fn handle_hallucinated_tool_use( - &mut self, - tool_use_id: LanguageModelToolUseId, - hallucinated_tool_name: Arc, - window: Option, - cx: &mut Context, + fn update_tool_call_fields( + &self, + tool_use_id: &LanguageModelToolUseId, + fields: acp::ToolCallUpdateFields, ) { - let available_tools = self.profile.enabled_tools(cx); - - let tool_list = available_tools - .iter() - .map(|(name, tool)| format!("- {}: {}", name, tool.description())) - .collect::>() - .join("\n"); - - let error_message = format!( - "The tool '{}' doesn't exist or is not enabled. Available tools:\n{}", - hallucinated_tool_name, tool_list - ); + self.0 + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( + acp::ToolCallUpdate { + meta: None, + id: acp::ToolCallId(tool_use_id.to_string().into()), + fields, + } + .into(), + ))) + .ok(); + } - let pending_tool_use = self.tool_use.insert_tool_output( - tool_use_id.clone(), - hallucinated_tool_name, - Err(anyhow!("Missing tool call: {error_message}")), - self.configured_model.as_ref(), - self.completion_mode, - ); + fn send_retry(&self, status: acp_thread::RetryStatus) { + self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok(); + } - cx.emit(ThreadEvent::MissingToolUse { - tool_use_id: tool_use_id.clone(), - ui_text: error_message.into(), - }); + fn send_stop(&self, reason: acp::StopReason) { + self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok(); + } - self.tool_finished(tool_use_id, pending_tool_use, false, window, cx); + fn send_canceled(&self) { + self.0 + .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled))) + .ok(); } - pub fn receive_invalid_tool_json( - &mut self, - tool_use_id: LanguageModelToolUseId, - tool_name: Arc, - invalid_json: Arc, - error: String, - window: Option, - cx: &mut Context, - ) { - log::error!("The model returned invalid input JSON: {invalid_json}"); + fn send_error(&self, error: impl Into) { + self.0.unbounded_send(Err(error.into())).ok(); + } +} - let pending_tool_use = self.tool_use.insert_tool_output( - tool_use_id.clone(), - tool_name, - Err(anyhow!("Error parsing input JSON: {error}")), - self.configured_model.as_ref(), - self.completion_mode, - ); - let ui_text = if let Some(pending_tool_use) = &pending_tool_use { - pending_tool_use.ui_text.clone() - } else { - log::error!( - "There was no pending tool use for tool use {tool_use_id}, even though it finished (with invalid input JSON)." - ); - format!("Unknown tool {}", tool_use_id).into() - }; +#[derive(Clone)] +pub struct ToolCallEventStream { + tool_use_id: LanguageModelToolUseId, + stream: ThreadEventStream, + fs: Option>, +} - cx.emit(ThreadEvent::InvalidToolInput { - tool_use_id: tool_use_id.clone(), - ui_text, - invalid_input_json: invalid_json, - }); +impl ToolCallEventStream { + #[cfg(any(test, feature = "test-support"))] + pub fn test() -> (Self, ToolCallEventStreamReceiver) { + let (events_tx, events_rx) = mpsc::unbounded::>(); - self.tool_finished(tool_use_id, pending_tool_use, false, window, cx); - } + let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None); - pub fn run_tool( - &mut self, - tool_use_id: LanguageModelToolUseId, - ui_text: impl Into, - input: serde_json::Value, - request: Arc, - tool: Arc, - model: Arc, - window: Option, - cx: &mut Context, - ) { - let task = - self.spawn_tool_use(tool_use_id.clone(), request, input, tool, model, window, cx); - self.tool_use - .run_pending_tool(tool_use_id, ui_text.into(), task); + (stream, ToolCallEventStreamReceiver(events_rx)) } - fn spawn_tool_use( - &mut self, + fn new( tool_use_id: LanguageModelToolUseId, - request: Arc, - input: serde_json::Value, - tool: Arc, - model: Arc, - window: Option, - cx: &mut Context, - ) -> Task<()> { - let tool_name: Arc = tool.name().into(); - - let tool_result = tool.run( - input, - request, - self.project.clone(), - self.action_log.clone(), - model, - window, - cx, - ); - - // Store the card separately if it exists - if let Some(card) = tool_result.card.clone() { - self.tool_use - .insert_tool_result_card(tool_use_id.clone(), card); + stream: ThreadEventStream, + fs: Option>, + ) -> Self { + Self { + tool_use_id, + stream, + fs, } - - cx.spawn({ - async move |thread: WeakEntity, cx| { - let output = tool_result.output.await; - - thread - .update(cx, |thread, cx| { - let pending_tool_use = thread.tool_use.insert_tool_output( - tool_use_id.clone(), - tool_name, - output, - thread.configured_model.as_ref(), - thread.completion_mode, - ); - thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx); - }) - .ok(); - } - }) } - fn tool_finished( - &mut self, - tool_use_id: LanguageModelToolUseId, - pending_tool_use: Option, - canceled: bool, - window: Option, - cx: &mut Context, - ) { - if self.all_tools_finished() - && let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() - && !canceled - { - self.send_to_model(model.clone(), CompletionIntent::ToolResults, window, cx); - } + pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) { + self.stream + .update_tool_call_fields(&self.tool_use_id, fields); + } - cx.emit(ThreadEvent::ToolFinished { - tool_use_id, - pending_tool_use, - }); + pub fn update_diff(&self, diff: Entity) { + self.stream + .0 + .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( + acp_thread::ToolCallUpdateDiff { + id: acp::ToolCallId(self.tool_use_id.to_string().into()), + diff, + } + .into(), + ))) + .ok(); } - /// Cancels the last pending completion, if there are any pending. - /// - /// Returns whether a completion was canceled. - pub fn cancel_last_completion( - &mut self, - window: Option, - cx: &mut Context, - ) -> bool { - let mut canceled = self.pending_completions.pop().is_some() || self.retry_state.is_some(); - - self.retry_state = None; - - for pending_tool_use in self.tool_use.cancel_pending() { - canceled = true; - self.tool_finished( - pending_tool_use.id.clone(), - Some(pending_tool_use), - true, - window, - cx, - ); + pub fn authorize(&self, title: impl Into, cx: &mut App) -> Task> { + if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { + return Task::ready(Ok(())); } - if canceled { - cx.emit(ThreadEvent::CompletionCanceled); + let (response_tx, response_rx) = oneshot::channel(); + self.stream + .0 + .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization( + ToolCallAuthorization { + tool_call: acp::ToolCallUpdate { + meta: None, + id: acp::ToolCallId(self.tool_use_id.to_string().into()), + fields: acp::ToolCallUpdateFields { + title: Some(title.into()), + ..Default::default() + }, + }, + options: vec![ + acp::PermissionOption { + id: acp::PermissionOptionId("always_allow".into()), + name: "Always Allow".into(), + kind: acp::PermissionOptionKind::AllowAlways, + meta: None, + }, + acp::PermissionOption { + id: acp::PermissionOptionId("allow".into()), + name: "Allow".into(), + kind: acp::PermissionOptionKind::AllowOnce, + meta: None, + }, + acp::PermissionOption { + id: acp::PermissionOptionId("deny".into()), + name: "Deny".into(), + kind: acp::PermissionOptionKind::RejectOnce, + meta: None, + }, + ], + response: response_tx, + }, + ))) + .ok(); + let fs = self.fs.clone(); + cx.spawn(async move |cx| match response_rx.await?.0.as_ref() { + "always_allow" => { + if let Some(fs) = fs.clone() { + cx.update(|cx| { + update_settings_file(fs, cx, |settings, _| { + settings + .agent + .get_or_insert_default() + .set_always_allow_tool_actions(true); + }); + })?; + } - // When canceled, we always want to insert the checkpoint. - // (We skip over finalize_pending_checkpoint, because it - // would conclude we didn't have anything to insert here.) - if let Some(checkpoint) = self.pending_checkpoint.take() { - self.insert_checkpoint(checkpoint, cx); + Ok(()) } + "allow" => Ok(()), + _ => Err(anyhow!("Permission to run tool denied by user")), + }) + } +} + +#[cfg(any(test, feature = "test-support"))] +pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); + +#[cfg(any(test, feature = "test-support"))] +impl ToolCallEventStreamReceiver { + pub async fn expect_authorization(&mut self) -> ToolCallAuthorization { + let event = self.0.next().await; + if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event { + auth } else { - self.finalize_pending_checkpoint(cx); + panic!("Expected ToolCallAuthorization but got: {:?}", event); } - - canceled } - /// Signals that any in-progress editing should be canceled. - /// - /// This method is used to notify listeners (like ActiveThread) that - /// they should cancel any editing operations. - pub fn cancel_editing(&mut self, cx: &mut Context) { - cx.emit(ThreadEvent::CancelEditing); + pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields { + let event = self.0.next().await; + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( + update, + )))) = event + { + update.fields + } else { + panic!("Expected update fields but got: {:?}", event); + } } - pub fn message_feedback(&self, message_id: MessageId) -> Option { - self.message_feedback.get(&message_id).copied() + pub async fn expect_diff(&mut self) -> Entity { + let event = self.0.next().await; + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff( + update, + )))) = event + { + update.diff + } else { + panic!("Expected diff but got: {:?}", event); + } } - pub fn report_message_feedback( - &mut self, - message_id: MessageId, - feedback: ThreadFeedback, - cx: &mut Context, - ) -> Task> { - if self.message_feedback.get(&message_id) == Some(&feedback) { - return Task::ready(Ok(())); + pub async fn expect_terminal(&mut self) -> Entity { + let event = self.0.next().await; + if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal( + update, + )))) = event + { + update.terminal + } else { + panic!("Expected terminal but got: {:?}", event); } - - let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx); - let serialized_thread = self.serialize(cx); - let thread_id = self.id().clone(); - let client = self.project.read(cx).client(); - - let enabled_tool_names: Vec = self - .profile - .enabled_tools(cx) - .iter() - .map(|(name, _)| name.clone().into()) - .collect(); - - self.message_feedback.insert(message_id, feedback); - - cx.notify(); - - let message_content = self - .message(message_id) - .map(|msg| msg.to_message_content()) - .unwrap_or_default(); - - cx.background_spawn(async move { - let final_project_snapshot = final_project_snapshot.await; - let serialized_thread = serialized_thread.await?; - let thread_data = - serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null); - - let rating = match feedback { - ThreadFeedback::Positive => "positive", - ThreadFeedback::Negative => "negative", - }; - telemetry::event!( - "Assistant Thread Rated", - rating, - thread_id, - enabled_tool_names, - message_id = message_id.0, - message_content, - thread_data, - final_project_snapshot - ); - client.telemetry().flush_events().await; - - Ok(()) - }) } +} - /// Create a snapshot of the current project state including git information and unsaved buffers. - fn project_snapshot( - project: Entity, - cx: &mut Context, - ) -> Task> { - let git_store = project.read(cx).git_store().clone(); - let worktree_snapshots: Vec<_> = project - .read(cx) - .visible_worktrees(cx) - .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) - .collect(); +#[cfg(any(test, feature = "test-support"))] +impl std::ops::Deref for ToolCallEventStreamReceiver { + type Target = mpsc::UnboundedReceiver>; - cx.spawn(async move |_, _| { - let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; + fn deref(&self) -> &Self::Target { + &self.0 + } +} - Arc::new(ProjectSnapshot { - worktree_snapshots, - timestamp: Utc::now(), - }) - }) +#[cfg(any(test, feature = "test-support"))] +impl std::ops::DerefMut for ToolCallEventStreamReceiver { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } +} - fn worktree_snapshot( - worktree: Entity, - git_store: Entity, - cx: &App, - ) -> Task { - cx.spawn(async move |cx| { - // Get worktree path and snapshot - let worktree_info = cx.update(|app_cx| { - let worktree = worktree.read(app_cx); - let path = worktree.abs_path().to_string_lossy().into_owned(); - let snapshot = worktree.snapshot(); - (path, snapshot) - }); +impl From<&str> for UserMessageContent { + fn from(text: &str) -> Self { + Self::Text(text.into()) + } +} - let Ok((worktree_path, _snapshot)) = worktree_info else { - return WorktreeSnapshot { - worktree_path: String::new(), - git_state: None, - }; - }; +impl From for UserMessageContent { + fn from(value: acp::ContentBlock) -> Self { + match value { + acp::ContentBlock::Text(text_content) => Self::Text(text_content.text), + acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)), + acp::ContentBlock::Audio(_) => { + // TODO + Self::Text("[audio]".to_string()) + } + acp::ContentBlock::ResourceLink(resource_link) => { + match MentionUri::parse(&resource_link.uri) { + Ok(uri) => Self::Mention { + uri, + content: String::new(), + }, + Err(err) => { + log::error!("Failed to parse mention link: {}", err); + Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri)) + } + } + } + acp::ContentBlock::Resource(resource) => match resource.resource { + acp::EmbeddedResourceResource::TextResourceContents(resource) => { + match MentionUri::parse(&resource.uri) { + Ok(uri) => Self::Mention { + uri, + content: resource.text, + }, + Err(err) => { + log::error!("Failed to parse mention link: {}", err); + Self::Text( + MarkdownCodeBlock { + tag: &resource.uri, + text: &resource.text, + } + .to_string(), + ) + } + } + } + acp::EmbeddedResourceResource::BlobResourceContents(_) => { + // TODO + Self::Text("[blob]".to_string()) + } + }, + } + } +} - let git_state = git_store - .update(cx, |git_store, cx| { - git_store - .repositories() - .values() - .find(|repo| { - repo.read(cx) - .abs_path_to_repo_path(&worktree.read(cx).abs_path()) - .is_some() - }) - .cloned() +impl From for acp::ContentBlock { + fn from(content: UserMessageContent) -> Self { + match content { + UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent { + text, + annotations: None, + meta: None, + }), + UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent { + data: image.source.to_string(), + mime_type: "image/png".to_string(), + meta: None, + annotations: None, + uri: None, + }), + UserMessageContent::Mention { uri, content } => { + acp::ContentBlock::Resource(acp::EmbeddedResource { + meta: None, + resource: acp::EmbeddedResourceResource::TextResourceContents( + acp::TextResourceContents { + meta: None, + mime_type: None, + text: content, + uri: uri.to_uri().to_string(), + }, + ), + annotations: None, }) - .ok() - .flatten() - .map(|repo| { - repo.update(cx, |repo, _| { - let current_branch = - repo.branch.as_ref().map(|branch| branch.name().to_owned()); - repo.send_job(None, |state, _| async move { - let RepositoryState::Local { backend, .. } = state else { - return GitState { - remote_url: None, - head_sha: None, - current_branch, - diff: None, - }; - }; - - let remote_url = backend.remote_url("origin"); - let head_sha = backend.head_sha().await; - let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); - - GitState { - remote_url, - head_sha, - current_branch, - diff, - } - }) - }) - }); - - let git_state = match git_state { - Some(git_state) => match git_state.ok() { - Some(git_state) => git_state.await.ok(), - None => None, - }, - None => None, - }; - - WorktreeSnapshot { - worktree_path, - git_state, - } - }) - } - - pub fn to_markdown(&self, cx: &App) -> Result { - let mut markdown = Vec::new(); - - let summary = self.summary().or_default(); - writeln!(markdown, "# {summary}\n")?; - - for message in self.messages() { - writeln!( - markdown, - "## {role}\n", - role = match message.role { - Role::User => "User", - Role::Assistant => "Agent", - Role::System => "System", - } - )?; - - if !message.loaded_context.text.is_empty() { - writeln!(markdown, "{}", message.loaded_context.text)?; - } - - if !message.loaded_context.images.is_empty() { - writeln!( - markdown, - "\n{} images attached as context.\n", - message.loaded_context.images.len() - )?; - } - - for segment in &message.segments { - match segment { - MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?, - MessageSegment::Thinking { text, .. } => { - writeln!(markdown, "\n{}\n\n", text)? - } - MessageSegment::RedactedThinking(_) => {} - } - } - - for tool_use in self.tool_uses_for_message(message.id, cx) { - writeln!( - markdown, - "**Use Tool: {} ({})**", - tool_use.name, tool_use.id - )?; - writeln!(markdown, "```json")?; - writeln!( - markdown, - "{}", - serde_json::to_string_pretty(&tool_use.input)? - )?; - writeln!(markdown, "```")?; - } - - for tool_result in self.tool_results_for_message(message.id) { - write!(markdown, "\n**Tool Results: {}", tool_result.tool_use_id)?; - if tool_result.is_error { - write!(markdown, " (Error)")?; - } - - writeln!(markdown, "**\n")?; - match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - writeln!(markdown, "{text}")?; - } - LanguageModelToolResultContent::Image(image) => { - writeln!(markdown, "![Image](data:base64,{})", image.source)?; - } - } - - if let Some(output) = tool_result.output.as_ref() { - writeln!( - markdown, - "\n\nDebug Output:\n\n```json\n{}\n```\n", - serde_json::to_string_pretty(output)? - )?; - } } } - - Ok(String::from_utf8_lossy(&markdown).to_string()) - } - - pub fn keep_edits_in_range( - &mut self, - buffer: Entity, - buffer_range: Range, - cx: &mut Context, - ) { - self.action_log.update(cx, |action_log, cx| { - action_log.keep_edits_in_range(buffer, buffer_range, cx) - }); - } - - pub fn keep_all_edits(&mut self, cx: &mut Context) { - self.action_log - .update(cx, |action_log, cx| action_log.keep_all_edits(cx)); - } - - pub fn reject_edits_in_ranges( - &mut self, - buffer: Entity, - buffer_ranges: Vec>, - cx: &mut Context, - ) -> Task> { - self.action_log.update(cx, |action_log, cx| { - action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx) - }) - } - - pub fn action_log(&self) -> &Entity { - &self.action_log - } - - pub fn project(&self) -> &Entity { - &self.project - } - - pub fn cumulative_token_usage(&self) -> TokenUsage { - self.cumulative_token_usage - } - - pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage { - let Some(model) = self.configured_model.as_ref() else { - return TotalTokenUsage::default(); - }; - - let max = model - .model - .max_token_count_for_mode(self.completion_mode().into()); - - let index = self - .messages - .iter() - .position(|msg| msg.id == message_id) - .unwrap_or(0); - - if index == 0 { - return TotalTokenUsage { total: 0, max }; - } - - let token_usage = &self - .request_token_usage - .get(index - 1) - .cloned() - .unwrap_or_default(); - - TotalTokenUsage { - total: token_usage.total_tokens(), - max, - } - } - - pub fn total_token_usage(&self) -> Option { - let model = self.configured_model.as_ref()?; - - let max = model - .model - .max_token_count_for_mode(self.completion_mode().into()); - - if let Some(exceeded_error) = &self.exceeded_window_error - && model.model.id() == exceeded_error.model_id - { - return Some(TotalTokenUsage { - total: exceeded_error.token_count, - max, - }); - } - - let total = self - .token_usage_at_last_message() - .unwrap_or_default() - .total_tokens(); - - Some(TotalTokenUsage { total, max }) - } - - fn token_usage_at_last_message(&self) -> Option { - self.request_token_usage - .get(self.messages.len().saturating_sub(1)) - .or_else(|| self.request_token_usage.last()) - .cloned() - } - - fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) { - let placeholder = self.token_usage_at_last_message().unwrap_or_default(); - self.request_token_usage - .resize(self.messages.len(), placeholder); - - if let Some(last) = self.request_token_usage.last_mut() { - *last = token_usage; - } - } - - fn update_model_request_usage(&self, amount: u32, limit: UsageLimit, cx: &mut Context) { - self.project - .read(cx) - .user_store() - .update(cx, |user_store, cx| { - user_store.update_model_request_usage( - ModelRequestUsage(RequestUsage { - amount: amount as i32, - limit, - }), - cx, - ) - }); - } - - pub fn deny_tool_use( - &mut self, - tool_use_id: LanguageModelToolUseId, - tool_name: Arc, - window: Option, - cx: &mut Context, - ) { - let err = Err(anyhow::anyhow!( - "Permission to run tool action denied by user" - )); - - self.tool_use.insert_tool_output( - tool_use_id.clone(), - tool_name, - err, - self.configured_model.as_ref(), - self.completion_mode, - ); - self.tool_finished(tool_use_id, None, true, window, cx); } } -#[derive(Debug, Clone, Error)] -pub enum ThreadError { - #[error("Payment required")] - PaymentRequired, - #[error("Model request limit reached")] - ModelRequestLimitReached { plan: Plan }, - #[error("Message {header}: {message}")] - Message { - header: SharedString, - message: SharedString, - }, - #[error("Retryable error: {message}")] - RetryableError { - message: SharedString, - can_enable_burn_mode: bool, - }, -} - -#[derive(Debug, Clone)] -pub enum ThreadEvent { - ShowError(ThreadError), - StreamedCompletion, - ReceivedTextChunk, - NewRequest, - StreamedAssistantText(MessageId, String), - StreamedAssistantThinking(MessageId, String), - StreamedToolUse { - tool_use_id: LanguageModelToolUseId, - ui_text: Arc, - input: serde_json::Value, - }, - MissingToolUse { - tool_use_id: LanguageModelToolUseId, - ui_text: Arc, - }, - InvalidToolInput { - tool_use_id: LanguageModelToolUseId, - ui_text: Arc, - invalid_input_json: Arc, - }, - Stopped(Result>), - MessageAdded(MessageId), - MessageEdited(MessageId), - MessageDeleted(MessageId), - SummaryGenerated, - SummaryChanged, - UsePendingTools { - tool_uses: Vec, - }, - ToolFinished { - #[allow(unused)] - tool_use_id: LanguageModelToolUseId, - /// The pending tool use that corresponds to this tool. - pending_tool_use: Option, - }, - CheckpointChanged, - ToolConfirmationNeeded, - ToolUseLimitReached, - CancelEditing, - CompletionCanceled, - ProfileChanged, -} - -impl EventEmitter for Thread {} - -struct PendingCompletion { - id: usize, - queue_state: QueueState, - _task: Task<()>, -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - context::load_context, context_store::ContextStore, thread_store, thread_store::ThreadStore, - }; - - // Test-specific constants - const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30; - use agent_settings::{AgentProfileId, AgentSettings}; - use assistant_tool::ToolRegistry; - use assistant_tools; - use fs::Fs; - use futures::StreamExt; - use futures::future::BoxFuture; - use futures::stream::BoxStream; - use gpui::TestAppContext; - use http_client; - use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; - use language_model::{ - LanguageModelCompletionError, LanguageModelName, LanguageModelProviderId, - LanguageModelProviderName, LanguageModelToolChoice, - }; - use parking_lot::Mutex; - use project::{FakeFs, Project}; - use prompt_store::PromptBuilder; - use serde_json::json; - use settings::{LanguageModelParameters, Settings, SettingsStore}; - use std::sync::Arc; - use std::time::Duration; - use util::path; - use workspace::Workspace; - - #[gpui::test] - async fn test_message_with_context(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project( - &fs, - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (_workspace, _thread_store, thread, context_store, model) = - setup_test_environment(cx, project.clone()).await; - - add_file_to_context(&project, &context_store, "test/code.rs", cx) - .await - .unwrap(); - - let context = - context_store.read_with(cx, |store, _| store.context().next().cloned().unwrap()); - let loaded_context = cx - .update(|cx| load_context(vec![context], &project, &None, cx)) - .await; - - // Insert user message with context - let message_id = thread.update(cx, |thread, cx| { - thread.insert_user_message( - "Please explain this code", - loaded_context, - None, - Vec::new(), - cx, - ) - }); - - // Check content and context in message object - let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone()); - - // Use different path format strings based on platform for the test - #[cfg(windows)] - let path_part = r"test\code.rs"; - #[cfg(not(windows))] - let path_part = "test/code.rs"; - - let expected_context = format!( - r#" - -The following items were attached by the user. They are up-to-date and don't need to be re-read. - - -```rs {path_part} -fn main() {{ - println!("Hello, world!"); -}} -``` - - -"# - ); - - assert_eq!(message.role, Role::User); - assert_eq!(message.segments.len(), 1); - assert_eq!( - message.segments[0], - MessageSegment::Text("Please explain this code".to_string()) - ); - assert_eq!(message.loaded_context.text, expected_context); - - // Check message in request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - - assert_eq!(request.messages.len(), 2); - let expected_full_message = format!("{}Please explain this code", expected_context); - assert_eq!(request.messages[1].string_contents(), expected_full_message); - } - - #[gpui::test] - async fn test_only_include_new_contexts(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project( - &fs, - cx, - json!({ - "file1.rs": "fn function1() {}\n", - "file2.rs": "fn function2() {}\n", - "file3.rs": "fn function3() {}\n", - "file4.rs": "fn function4() {}\n", - }), - ) - .await; - - let (_, _thread_store, thread, context_store, model) = - setup_test_environment(cx, project.clone()).await; - - // First message with context 1 - add_file_to_context(&project, &context_store, "test/file1.rs", cx) - .await - .unwrap(); - let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx), None) - }); - assert_eq!(new_contexts.len(), 1); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await; - let message1_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Message 1", loaded_context, None, Vec::new(), cx) - }); - - // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included) - add_file_to_context(&project, &context_store, "test/file2.rs", cx) - .await - .unwrap(); - let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx), None) - }); - assert_eq!(new_contexts.len(), 1); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await; - let message2_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Message 2", loaded_context, None, Vec::new(), cx) - }); - - // Third message with all three contexts (contexts 1 and 2 should be skipped) - // - add_file_to_context(&project, &context_store, "test/file3.rs", cx) - .await - .unwrap(); - let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx), None) - }); - assert_eq!(new_contexts.len(), 1); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await; - let message3_id = thread.update(cx, |thread, cx| { - thread.insert_user_message("Message 3", loaded_context, None, Vec::new(), cx) - }); - - // Check what contexts are included in each message - let (message1, message2, message3) = thread.read_with(cx, |thread, _| { - ( - thread.message(message1_id).unwrap().clone(), - thread.message(message2_id).unwrap().clone(), - thread.message(message3_id).unwrap().clone(), - ) - }); - - // First message should include context 1 - assert!(message1.loaded_context.text.contains("file1.rs")); - - // Second message should include only context 2 (not 1) - assert!(!message2.loaded_context.text.contains("file1.rs")); - assert!(message2.loaded_context.text.contains("file2.rs")); - - // Third message should include only context 3 (not 1 or 2) - assert!(!message3.loaded_context.text.contains("file1.rs")); - assert!(!message3.loaded_context.text.contains("file2.rs")); - assert!(message3.loaded_context.text.contains("file3.rs")); - - // Check entire request to make sure all contexts are properly included - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - - // The request should contain all 3 messages - assert_eq!(request.messages.len(), 4); - - // Check that the contexts are properly formatted in each message - assert!(request.messages[1].string_contents().contains("file1.rs")); - assert!(!request.messages[1].string_contents().contains("file2.rs")); - assert!(!request.messages[1].string_contents().contains("file3.rs")); - - assert!(!request.messages[2].string_contents().contains("file1.rs")); - assert!(request.messages[2].string_contents().contains("file2.rs")); - assert!(!request.messages[2].string_contents().contains("file3.rs")); - - assert!(!request.messages[3].string_contents().contains("file1.rs")); - assert!(!request.messages[3].string_contents().contains("file2.rs")); - assert!(request.messages[3].string_contents().contains("file3.rs")); - - add_file_to_context(&project, &context_store, "test/file4.rs", cx) - .await - .unwrap(); - let new_contexts = context_store.update(cx, |store, cx| { - store.new_context_for_thread(thread.read(cx), Some(message2_id)) - }); - assert_eq!(new_contexts.len(), 3); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await - .loaded_context; - - assert!(!loaded_context.text.contains("file1.rs")); - assert!(loaded_context.text.contains("file2.rs")); - assert!(loaded_context.text.contains("file3.rs")); - assert!(loaded_context.text.contains("file4.rs")); - - let new_contexts = context_store.update(cx, |store, cx| { - // Remove file4.rs - store.remove_context(&loaded_context.contexts[2].handle(), cx); - store.new_context_for_thread(thread.read(cx), Some(message2_id)) - }); - assert_eq!(new_contexts.len(), 2); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await - .loaded_context; - - assert!(!loaded_context.text.contains("file1.rs")); - assert!(loaded_context.text.contains("file2.rs")); - assert!(loaded_context.text.contains("file3.rs")); - assert!(!loaded_context.text.contains("file4.rs")); - - let new_contexts = context_store.update(cx, |store, cx| { - // Remove file3.rs - store.remove_context(&loaded_context.contexts[1].handle(), cx); - store.new_context_for_thread(thread.read(cx), Some(message2_id)) - }); - assert_eq!(new_contexts.len(), 1); - let loaded_context = cx - .update(|cx| load_context(new_contexts, &project, &None, cx)) - .await - .loaded_context; - - assert!(!loaded_context.text.contains("file1.rs")); - assert!(loaded_context.text.contains("file2.rs")); - assert!(!loaded_context.text.contains("file3.rs")); - assert!(!loaded_context.text.contains("file4.rs")); - } - - #[gpui::test] - async fn test_message_without_files(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project( - &fs, - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (_, _thread_store, thread, _context_store, model) = - setup_test_environment(cx, project.clone()).await; - - // Insert user message without any context (empty context vector) - let message_id = thread.update(cx, |thread, cx| { - thread.insert_user_message( - "What is the best way to learn Rust?", - ContextLoadResult::default(), - None, - Vec::new(), - cx, - ) - }); - - // Check content and context in message object - let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone()); - - // Context should be empty when no files are included - assert_eq!(message.role, Role::User); - assert_eq!(message.segments.len(), 1); - assert_eq!( - message.segments[0], - MessageSegment::Text("What is the best way to learn Rust?".to_string()) - ); - assert_eq!(message.loaded_context.text, ""); - - // Check message in request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - - assert_eq!(request.messages.len(), 2); - assert_eq!( - request.messages[1].string_contents(), - "What is the best way to learn Rust?" - ); - - // Add second message, also without context - let message2_id = thread.update(cx, |thread, cx| { - thread.insert_user_message( - "Are there any good books?", - ContextLoadResult::default(), - None, - Vec::new(), - cx, - ) - }); - - let message2 = - thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone()); - assert_eq!(message2.loaded_context.text, ""); - - // Check that both messages appear in the request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - - assert_eq!(request.messages.len(), 3); - assert_eq!( - request.messages[1].string_contents(), - "What is the best way to learn Rust?" - ); - assert_eq!( - request.messages[2].string_contents(), - "Are there any good books?" - ); - } - - #[gpui::test] - #[ignore] // turn this test on when project_notifications tool is re-enabled - async fn test_stale_buffer_notification(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project( - &fs, - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (_workspace, _thread_store, thread, context_store, model) = - setup_test_environment(cx, project.clone()).await; - - // Add a buffer to the context. This will be a tracked buffer - let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx) - .await - .unwrap(); - - let context = context_store - .read_with(cx, |store, _| store.context().next().cloned()) - .unwrap(); - let loaded_context = cx - .update(|cx| load_context(vec![context], &project, &None, cx)) - .await; - - // Insert user message and assistant response - thread.update(cx, |thread, cx| { - thread.insert_user_message("Explain this code", loaded_context, None, Vec::new(), cx); - thread.insert_assistant_message( - vec![MessageSegment::Text("This code prints 42.".into())], - cx, - ); - }); - cx.run_until_parked(); - - // We shouldn't have a stale buffer notification yet - let notifications = thread.read_with(cx, |thread, _| { - find_tool_uses(thread, "project_notifications") - }); - assert!( - notifications.is_empty(), - "Should not have stale buffer notification before buffer is modified" - ); - - // Modify the buffer - buffer.update(cx, |buffer, cx| { - buffer.edit( - [(1..1, "\n println!(\"Added a new line\");\n")], - None, - cx, - ); - }); - - // Insert another user message - thread.update(cx, |thread, cx| { - thread.insert_user_message( - "What does the code do now?", - ContextLoadResult::default(), - None, - Vec::new(), - cx, - ) - }); - cx.run_until_parked(); - - // Check for the stale buffer warning - thread.update(cx, |thread, cx| { - thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx) - }); - cx.run_until_parked(); - - let notifications = thread.read_with(cx, |thread, _cx| { - find_tool_uses(thread, "project_notifications") - }); - - let [notification] = notifications.as_slice() else { - panic!("Should have a `project_notifications` tool use"); - }; - - let Some(notification_content) = notification.content.to_str() else { - panic!("`project_notifications` should return text"); - }; - - assert!(notification_content.contains("These files have changed since the last read:")); - assert!(notification_content.contains("code.rs")); - - // Insert another user message and flush notifications again - thread.update(cx, |thread, cx| { - thread.insert_user_message( - "Can you tell me more?", - ContextLoadResult::default(), - None, - Vec::new(), - cx, - ) - }); - - thread.update(cx, |thread, cx| { - thread.flush_notifications(model.clone(), CompletionIntent::UserPrompt, cx) - }); - cx.run_until_parked(); - - // There should be no new notifications (we already flushed one) - let notifications = thread.read_with(cx, |thread, _cx| { - find_tool_uses(thread, "project_notifications") - }); - - assert_eq!( - notifications.len(), - 1, - "Should still have only one notification after second flush - no duplicates" - ); - } - - fn find_tool_uses(thread: &Thread, tool_name: &str) -> Vec { - thread - .messages() - .flat_map(|message| { - thread - .tool_results_for_message(message.id) - .into_iter() - .filter(|result| result.tool_name == tool_name.into()) - .cloned() - .collect::>() - }) - .collect() - } - - #[gpui::test] - async fn test_storing_profile_setting_per_thread(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project( - &fs, - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (_workspace, thread_store, thread, _context_store, _model) = - setup_test_environment(cx, project.clone()).await; - - // Check that we are starting with the default profile - let profile = cx.read(|cx| thread.read(cx).profile.clone()); - let tool_set = cx.read(|cx| thread_store.read(cx).tools()); - assert_eq!( - profile, - AgentProfile::new(AgentProfileId::default(), tool_set) - ); - } - - #[gpui::test] - async fn test_serializing_thread_profile(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project( - &fs, - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (_workspace, thread_store, thread, _context_store, _model) = - setup_test_environment(cx, project.clone()).await; - - // Profile gets serialized with default values - let serialized = thread - .update(cx, |thread, cx| thread.serialize(cx)) - .await - .unwrap(); - - assert_eq!(serialized.profile, Some(AgentProfileId::default())); - - let deserialized = cx.update(|cx| { - thread.update(cx, |thread, cx| { - Thread::deserialize( - thread.id.clone(), - serialized, - thread.project.clone(), - thread.tools.clone(), - thread.prompt_builder.clone(), - thread.project_context.clone(), - None, - cx, - ) - }) - }); - let tool_set = cx.read(|cx| thread_store.read(cx).tools()); - - assert_eq!( - deserialized.profile, - AgentProfile::new(AgentProfileId::default(), tool_set) - ); - } - - #[gpui::test] - async fn test_temperature_setting(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project( - &fs, - cx, - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let (_workspace, _thread_store, thread, _context_store, model) = - setup_test_environment(cx, project.clone()).await; - - // Both model and provider - cx.update(|cx| { - AgentSettings::override_global( - AgentSettings { - model_parameters: vec![LanguageModelParameters { - provider: Some(model.provider_id().0.to_string().into()), - model: Some(model.id().0), - temperature: Some(0.66), - }], - ..AgentSettings::get_global(cx).clone() - }, - cx, - ); - }); - - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - assert_eq!(request.temperature, Some(0.66)); - - // Only model - cx.update(|cx| { - AgentSettings::override_global( - AgentSettings { - model_parameters: vec![LanguageModelParameters { - provider: None, - model: Some(model.id().0), - temperature: Some(0.66), - }], - ..AgentSettings::get_global(cx).clone() - }, - cx, - ); - }); - - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - assert_eq!(request.temperature, Some(0.66)); - - // Only provider - cx.update(|cx| { - AgentSettings::override_global( - AgentSettings { - model_parameters: vec![LanguageModelParameters { - provider: Some(model.provider_id().0.to_string().into()), - model: None, - temperature: Some(0.66), - }], - ..AgentSettings::get_global(cx).clone() - }, - cx, - ); - }); - - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - assert_eq!(request.temperature, Some(0.66)); - - // Same model name, different provider - cx.update(|cx| { - AgentSettings::override_global( - AgentSettings { - model_parameters: vec![LanguageModelParameters { - provider: Some("anthropic".into()), - model: Some(model.id().0), - temperature: Some(0.66), - }], - ..AgentSettings::get_global(cx).clone() - }, - cx, - ); - }); - - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - assert_eq!(request.temperature, None); - } - - #[gpui::test] - async fn test_thread_summary(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - - let (_, _thread_store, thread, _context_store, model) = - setup_test_environment(cx, project.clone()).await; - - // Initial state should be pending - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Pending)); - assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); - }); - - // Manually setting the summary should not be allowed in this state - thread.update(cx, |thread, cx| { - thread.set_summary("This should not work", cx); - }); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Pending)); - }); - - // Send a message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); - thread.send_to_model( - model.clone(), - CompletionIntent::ThreadSummarization, - None, - cx, - ); - }); - - let fake_model = model.as_fake(); - simulate_successful_response(fake_model, cx); - - // Should start generating summary when there are >= 2 messages - thread.read_with(cx, |thread, _| { - assert_eq!(*thread.summary(), ThreadSummary::Generating); - }); - - // Should not be able to set the summary while generating - thread.update(cx, |thread, cx| { - thread.set_summary("This should not work either", cx); - }); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Generating)); - assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); - }); - - cx.run_until_parked(); - fake_model.send_last_completion_stream_text_chunk("Brief"); - fake_model.send_last_completion_stream_text_chunk(" Introduction"); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - // Summary should be set - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); - assert_eq!(thread.summary().or_default(), "Brief Introduction"); - }); - - // Now we should be able to set a summary - thread.update(cx, |thread, cx| { - thread.set_summary("Brief Intro", cx); - }); - - thread.read_with(cx, |thread, _| { - assert_eq!(thread.summary().or_default(), "Brief Intro"); - }); - - // Test setting an empty summary (should default to DEFAULT) - thread.update(cx, |thread, cx| { - thread.set_summary("", cx); - }); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); - assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); - }); - } - - #[gpui::test] - async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - - let (_, _thread_store, thread, _context_store, model) = - setup_test_environment(cx, project.clone()).await; - - test_summarize_error(&model, &thread, cx); - - // Now we should be able to set a summary - thread.update(cx, |thread, cx| { - thread.set_summary("Brief Intro", cx); - }); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); - assert_eq!(thread.summary().or_default(), "Brief Intro"); - }); - } - - #[gpui::test] - async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - - let (_, _thread_store, thread, _context_store, model) = - setup_test_environment(cx, project.clone()).await; - - test_summarize_error(&model, &thread, cx); - - // Sending another message should not trigger another summarize request - thread.update(cx, |thread, cx| { - thread.insert_user_message( - "How are you?", - ContextLoadResult::default(), - None, - vec![], - cx, - ); - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - let fake_model = model.as_fake(); - simulate_successful_response(fake_model, cx); - - thread.read_with(cx, |thread, _| { - // State is still Error, not Generating - assert!(matches!(thread.summary(), ThreadSummary::Error)); - }); - - // But the summarize request can be invoked manually - thread.update(cx, |thread, cx| { - thread.summarize(cx); - }); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Generating)); - }); - - cx.run_until_parked(); - fake_model.send_last_completion_stream_text_chunk("A successful summary"); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); - assert_eq!(thread.summary().or_default(), "A successful summary"); - }); - } - - // Helper to create a model that returns errors - enum TestError { - Overloaded, - InternalServerError, - } - - struct ErrorInjector { - inner: Arc, - error_type: TestError, - } - - impl ErrorInjector { - fn new(error_type: TestError) -> Self { - Self { - inner: Arc::new(FakeLanguageModel::default()), - error_type, - } - } - } - - impl LanguageModel for ErrorInjector { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - _request: LanguageModelRequest, - _cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - let error = match self.error_type { - TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded { - provider: self.provider_name(), - retry_after: None, - }, - TestError::InternalServerError => { - LanguageModelCompletionError::ApiInternalServerError { - provider: self.provider_name(), - message: "I'm a teapot orbiting the sun".to_string(), - } - } - }; - async move { - let stream = futures::stream::once(async move { Err(error) }); - Ok(stream.boxed()) - } - .boxed() - } - - fn as_fake(&self) -> &FakeLanguageModel { - &self.inner - } - } - - #[gpui::test] - async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Enable Burn Mode to allow retries - thread.update(cx, |thread, _| { - thread.set_completion_mode(CompletionMode::Burn); - }); - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should retry MAX_RETRY_ATTEMPTS times for overloaded errors" - ); - }); - - // Check that a retry message was added - thread.read_with(cx, |thread, _| { - let mut messages = thread.messages(); - assert!( - messages.any(|msg| { - msg.role == Role::System - && msg.ui_only - && msg.segments.iter().any(|seg| { - if let MessageSegment::Text(text) = seg { - text.contains("overloaded") - && text - .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) - } else { - false - } - }) - }), - "Should have added a system retry message" - ); - }); - - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - - assert_eq!(retry_count, 1, "Should have one retry message"); - } - - #[gpui::test] - async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Enable Burn Mode to allow retries - thread.update(cx, |thread, _| { - thread.set_completion_mode(CompletionMode::Burn); - }); - - // Create model that returns internal server error - let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - // Check retry state on thread - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); - assert_eq!( - retry_state.max_attempts, 3, - "Should have correct max attempts" - ); - }); - - // Check that a retry message was added with provider name - thread.read_with(cx, |thread, _| { - let mut messages = thread.messages(); - assert!( - messages.any(|msg| { - msg.role == Role::System - && msg.ui_only - && msg.segments.iter().any(|seg| { - if let MessageSegment::Text(text) = seg { - text.contains("internal") - && text.contains("Fake") - && text.contains("Retrying") - && text.contains("attempt 1 of 3") - && text.contains("seconds") - } else { - false - } - }) - }), - "Should have added a system retry message with provider name" - ); - }); - - // Count retry messages - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - - assert_eq!(retry_count, 1, "Should have one retry message"); - } - - #[gpui::test] - async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Enable Burn Mode to allow retries - thread.update(cx, |thread, _| { - thread.set_completion_mode(CompletionMode::Burn); - }); - - // Create model that returns internal server error - let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Track retry events and completion count - // Track completion events - let completion_count = Arc::new(Mutex::new(0)); - let completion_count_clone = completion_count.clone(); - - let _subscription = thread.update(cx, |_, cx| { - cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::NewRequest = event { - *completion_count_clone.lock() += 1; - } - }) - }); - - // First attempt - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - cx.run_until_parked(); - - // Should have scheduled first retry - count retry messages - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!(retry_count, 1, "Should have scheduled first retry"); - - // Check retry state - thread.read_with(cx, |thread, _| { - assert!(thread.retry_state.is_some(), "Should have retry state"); - let retry_state = thread.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); - assert_eq!( - retry_state.max_attempts, 3, - "Internal server errors should retry up to 3 times" - ); - }); - - // Advance clock for first retry - cx.executor().advance_clock(BASE_RETRY_DELAY); - cx.run_until_parked(); - - // Advance clock for second retry - cx.executor().advance_clock(BASE_RETRY_DELAY); - cx.run_until_parked(); - - // Advance clock for third retry - cx.executor().advance_clock(BASE_RETRY_DELAY); - cx.run_until_parked(); - - // Should have completed all retries - count retry messages - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!( - retry_count, 3, - "Should have 3 retries for internal server errors" - ); - - // For internal server errors, we retry 3 times and then give up - // Check that retry_state is cleared after all retries - thread.read_with(cx, |thread, _| { - assert!( - thread.retry_state.is_none(), - "Retry state should be cleared after all retries" - ); - }); - - // Verify total attempts (1 initial + 3 retries) - assert_eq!( - *completion_count.lock(), - 4, - "Should have attempted once plus 3 retries" - ); - } - - #[gpui::test] - async fn test_max_retries_exceeded(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Enable Burn Mode to allow retries - thread.update(cx, |thread, _| { - thread.set_completion_mode(CompletionMode::Burn); - }); - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Track events - let stopped_with_error = Arc::new(Mutex::new(false)); - let stopped_with_error_clone = stopped_with_error.clone(); - - let _subscription = thread.update(cx, |_, cx| { - cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::Stopped(Err(_)) = event { - *stopped_with_error_clone.lock() = true; - } - }) - }); - - // Start initial completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - cx.run_until_parked(); - - // Advance through all retries - for _ in 0..MAX_RETRY_ATTEMPTS { - cx.executor().advance_clock(BASE_RETRY_DELAY); - cx.run_until_parked(); - } - - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - - // After max retries, should emit Stopped(Err(...)) event - assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have attempted MAX_RETRY_ATTEMPTS retries for overloaded errors" - ); - assert!( - *stopped_with_error.lock(), - "Should emit Stopped(Err(...)) event after max retries exceeded" - ); - - // Retry state should be cleared - thread.read_with(cx, |thread, _| { - assert!( - thread.retry_state.is_none(), - "Retry state should be cleared after max retries" - ); - - // Verify we have the expected number of retry messages - let retry_messages = thread - .messages - .iter() - .filter(|msg| msg.ui_only && msg.role == Role::System) - .count(); - assert_eq!( - retry_messages, MAX_RETRY_ATTEMPTS as usize, - "Should have MAX_RETRY_ATTEMPTS retry messages for overloaded errors" - ); - }); - } - - #[gpui::test] - async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Enable Burn Mode to allow retries - thread.update(cx, |thread, _| { - thread.set_completion_mode(CompletionMode::Burn); - }); - - // We'll use a wrapper to switch behavior after first failure - struct RetryTestModel { - inner: Arc, - failed_once: Arc>, - } - - impl LanguageModel for RetryTestModel { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - if !*self.failed_once.lock() { - *self.failed_once.lock() = true; - let provider = self.provider_name(); - // Return error on first attempt - let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::ServerOverloaded { - provider, - retry_after: None, - }) - }); - async move { Ok(stream.boxed()) }.boxed() - } else { - // Succeed on retry - self.inner.stream_completion(request, cx) - } - } - - fn as_fake(&self) -> &FakeLanguageModel { - &self.inner - } - } - - let model = Arc::new(RetryTestModel { - inner: Arc::new(FakeLanguageModel::default()), - failed_once: Arc::new(Mutex::new(false)), - }); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Track message deletions - // Track when retry completes successfully - let retry_completed = Arc::new(Mutex::new(false)); - let retry_completed_clone = retry_completed.clone(); - - let _subscription = thread.update(cx, |_, cx| { - cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::StreamedCompletion = event { - *retry_completed_clone.lock() = true; - } - }) - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - cx.run_until_parked(); - - // Get the retry message ID - let retry_message_id = thread.read_with(cx, |thread, _| { - thread - .messages() - .find(|msg| msg.role == Role::System && msg.ui_only) - .map(|msg| msg.id) - .expect("Should have a retry message") - }); - - // Wait for retry - cx.executor().advance_clock(BASE_RETRY_DELAY); - cx.run_until_parked(); - - // Stream some successful content - let fake_model = model.as_fake(); - // After the retry, there should be a new pending completion - let pending = fake_model.pending_completions(); - assert!( - !pending.is_empty(), - "Should have a pending completion after retry" - ); - fake_model.send_completion_stream_text_chunk(&pending[0], "Success!"); - fake_model.end_completion_stream(&pending[0]); - cx.run_until_parked(); - - // Check that the retry completed successfully - assert!( - *retry_completed.lock(), - "Retry should have completed successfully" - ); - - // Retry message should still exist but be marked as ui_only - thread.read_with(cx, |thread, _| { - let retry_msg = thread - .message(retry_message_id) - .expect("Retry message should still exist"); - assert!(retry_msg.ui_only, "Retry message should be ui_only"); - assert_eq!( - retry_msg.role, - Role::System, - "Retry message should have System role" - ); - }); - } - - #[gpui::test] - async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Enable Burn Mode to allow retries - thread.update(cx, |thread, _| { - thread.set_completion_mode(CompletionMode::Burn); - }); - - // Create a model that fails once then succeeds - struct FailOnceModel { - inner: Arc, - failed_once: Arc>, - } - - impl LanguageModel for FailOnceModel { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - if !*self.failed_once.lock() { - *self.failed_once.lock() = true; - let provider = self.provider_name(); - // Return error on first attempt - let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::ServerOverloaded { - provider, - retry_after: None, - }) - }); - async move { Ok(stream.boxed()) }.boxed() - } else { - // Succeed on retry - self.inner.stream_completion(request, cx) - } - } - } - - let fail_once_model = Arc::new(FailOnceModel { - inner: Arc::new(FakeLanguageModel::default()), - failed_once: Arc::new(Mutex::new(false)), - }); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message( - "Test message", - ContextLoadResult::default(), - None, - vec![], - cx, - ); - }); - - // Start completion with fail-once model - thread.update(cx, |thread, cx| { - thread.send_to_model( - fail_once_model.clone(), - CompletionIntent::UserPrompt, - None, - cx, - ); - }); - - cx.run_until_parked(); - - // Verify retry state exists after first failure - thread.read_with(cx, |thread, _| { - assert!( - thread.retry_state.is_some(), - "Should have retry state after failure" - ); - }); - - // Wait for retry delay - cx.executor().advance_clock(BASE_RETRY_DELAY); - cx.run_until_parked(); - - // The retry should now use our FailOnceModel which should succeed - // We need to help the FakeLanguageModel complete the stream - let inner_fake = fail_once_model.inner.clone(); - - // Wait a bit for the retry to start - cx.run_until_parked(); - - // Check for pending completions and complete them - if let Some(pending) = inner_fake.pending_completions().first() { - inner_fake.send_completion_stream_text_chunk(pending, "Success!"); - inner_fake.end_completion_stream(pending); - } - cx.run_until_parked(); - - thread.read_with(cx, |thread, _| { - assert!( - thread.retry_state.is_none(), - "Retry state should be cleared after successful completion" - ); - - let has_assistant_message = thread - .messages - .iter() - .any(|msg| msg.role == Role::Assistant && !msg.ui_only); - assert!( - has_assistant_message, - "Should have an assistant message after successful retry" - ); - }); - } - - #[gpui::test] - async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Enable Burn Mode to allow retries - thread.update(cx, |thread, _| { - thread.set_completion_mode(CompletionMode::Burn); - }); - - // Create a model that returns rate limit error with retry_after - struct RateLimitModel { - inner: Arc, - } - - impl LanguageModel for RateLimitModel { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - _request: LanguageModelRequest, - _cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - let provider = self.provider_name(); - async move { - let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::RateLimitExceeded { - provider, - retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)), - }) - }); - Ok(stream.boxed()) - } - .boxed() - } - - fn as_fake(&self) -> &FakeLanguageModel { - &self.inner - } - } - - let model = Arc::new(RateLimitModel { - inner: Arc::new(FakeLanguageModel::default()), - }); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - let retry_count = thread.update(cx, |thread, _| { - thread - .messages - .iter() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("rate limit exceeded") - } else { - false - } - }) - }) - .count() - }); - assert_eq!(retry_count, 1, "Should have scheduled one retry"); - - thread.read_with(cx, |thread, _| { - assert!( - thread.retry_state.is_some(), - "Rate limit errors should set retry_state" - ); - if let Some(retry_state) = &thread.retry_state { - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Rate limit errors should use MAX_RETRY_ATTEMPTS" - ); - } - }); - - // Verify we have one retry message - thread.read_with(cx, |thread, _| { - let retry_messages = thread - .messages - .iter() - .filter(|msg| { - msg.ui_only - && msg.segments.iter().any(|seg| { - if let MessageSegment::Text(text) = seg { - text.contains("rate limit exceeded") - } else { - false - } - }) - }) - .count(); - assert_eq!( - retry_messages, 1, - "Should have one rate limit retry message" - ); - }); - - // Check that retry message doesn't include attempt count - thread.read_with(cx, |thread, _| { - let retry_message = thread - .messages - .iter() - .find(|msg| msg.role == Role::System && msg.ui_only) - .expect("Should have a retry message"); - - // Check that the message contains attempt count since we use retry_state - if let Some(MessageSegment::Text(text)) = retry_message.segments.first() { - assert!( - text.contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)), - "Rate limit retry message should contain attempt count with MAX_RETRY_ATTEMPTS" - ); - assert!( - text.contains("Retrying"), - "Rate limit retry message should contain retry text" - ); - } - }); - } - - #[gpui::test] - async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - let (_, _, thread, _, model) = setup_test_environment(cx, project.clone()).await; - - // Insert a regular user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Insert a UI-only message (like our retry notifications) - thread.update(cx, |thread, cx| { - let id = thread.next_message_id.post_inc(); - thread.messages.push(Message { - id, - role: Role::System, - segments: vec![MessageSegment::Text( - "This is a UI-only message that should not be sent to the model".to_string(), - )], - loaded_context: LoadedContext::default(), - creases: Vec::new(), - is_hidden: true, - ui_only: true, - }); - cx.emit(ThreadEvent::MessageAdded(id)); - }); - - // Insert another regular message - thread.update(cx, |thread, cx| { - thread.insert_user_message( - "How are you?", - ContextLoadResult::default(), - None, - vec![], - cx, - ); - }); - - // Generate the completion request - let request = thread.update(cx, |thread, cx| { - thread.to_completion_request(model.clone(), CompletionIntent::UserPrompt, cx) - }); - - // Verify that the request only contains non-UI-only messages - // Should have system prompt + 2 user messages, but not the UI-only message - let user_messages: Vec<_> = request - .messages - .iter() - .filter(|msg| msg.role == Role::User) - .collect(); - assert_eq!( - user_messages.len(), - 2, - "Should have exactly 2 user messages" - ); - - // Verify the UI-only content is not present anywhere in the request - let request_text = request - .messages - .iter() - .flat_map(|msg| &msg.content) - .filter_map(|content| match content { - MessageContent::Text(text) => Some(text.as_str()), - _ => None, - }) - .collect::(); - - assert!( - !request_text.contains("UI-only message"), - "UI-only message content should not be in the request" - ); - - // Verify the thread still has all 3 messages (including UI-only) - thread.read_with(cx, |thread, _| { - assert_eq!( - thread.messages().count(), - 3, - "Thread should have 3 messages" - ); - assert_eq!( - thread.messages().filter(|m| m.ui_only).count(), - 1, - "Thread should have 1 UI-only message" - ); - }); - - // Verify that UI-only messages are not serialized - let serialized = thread - .update(cx, |thread, cx| thread.serialize(cx)) - .await - .unwrap(); - assert_eq!( - serialized.messages.len(), - 2, - "Serialized thread should only have 2 messages (no UI-only)" - ); - } - - #[gpui::test] - async fn test_no_retry_without_burn_mode(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Ensure we're in Normal mode (not Burn mode) - thread.update(cx, |thread, _| { - thread.set_completion_mode(CompletionMode::Normal); - }); - - // Track error events - let error_events = Arc::new(Mutex::new(Vec::new())); - let error_events_clone = error_events.clone(); - - let _subscription = thread.update(cx, |_, cx| { - cx.subscribe(&thread, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::ShowError(error) = event { - error_events_clone.lock().push(error.clone()); - } - }) - }); - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - // Verify no retry state was created - thread.read_with(cx, |thread, _| { - assert!( - thread.retry_state.is_none(), - "Should not have retry state in Normal mode" - ); - }); - - // Check that a retryable error was reported - let errors = error_events.lock(); - assert!(!errors.is_empty(), "Should have received an error event"); - - if let ThreadError::RetryableError { - message: _, - can_enable_burn_mode, - } = &errors[0] - { - assert!( - *can_enable_burn_mode, - "Error should indicate burn mode can be enabled" - ); - } else { - panic!("Expected RetryableError, got {:?}", errors[0]); - } - - // Verify the thread is no longer generating - thread.read_with(cx, |thread, _| { - assert!( - !thread.is_generating(), - "Should not be generating after error without retry" - ); - }); - } - - #[gpui::test] - async fn test_retry_canceled_on_stop(cx: &mut TestAppContext) { - let fs = init_test_settings(cx); - - let project = create_test_project(&fs, cx, json!({})).await; - let (_, _, thread, _, _base_model) = setup_test_environment(cx, project.clone()).await; - - // Enable Burn Mode to allow retries - thread.update(cx, |thread, _| { - thread.set_completion_mode(CompletionMode::Burn); - }); - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - thread.update(cx, |thread, cx| { - thread.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - // Verify retry was scheduled by checking for retry message - let has_retry_message = thread.read_with(cx, |thread, _| { - thread.messages.iter().any(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - }); - assert!(has_retry_message, "Should have scheduled a retry"); - - // Cancel the completion before the retry happens - thread.update(cx, |thread, cx| { - thread.cancel_last_completion(None, cx); - }); - - cx.run_until_parked(); - - // The retry should not have happened - no pending completions - let fake_model = model.as_fake(); - assert_eq!( - fake_model.pending_completions().len(), - 0, - "Should have no pending completions after cancellation" - ); - - // Verify the retry was canceled by checking retry state - thread.read_with(cx, |thread, _| { - if let Some(retry_state) = &thread.retry_state { - panic!( - "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}", - retry_state.attempt, retry_state.max_attempts, retry_state.intent - ); - } - }); - } - - fn test_summarize_error( - model: &Arc, - thread: &Entity, - cx: &mut TestAppContext, - ) { - thread.update(cx, |thread, cx| { - thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); - thread.send_to_model( - model.clone(), - CompletionIntent::ThreadSummarization, - None, - cx, - ); - }); - - let fake_model = model.as_fake(); - simulate_successful_response(fake_model, cx); - - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Generating)); - assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); - }); - - // Simulate summary request ending - cx.run_until_parked(); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - - // State is set to Error and default message - thread.read_with(cx, |thread, _| { - assert!(matches!(thread.summary(), ThreadSummary::Error)); - assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); - }); - } - - fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { - cx.run_until_parked(); - fake_model.send_last_completion_stream_text_chunk("Assistant response"); - fake_model.end_last_completion_stream(); - cx.run_until_parked(); - } - - fn init_test_settings(cx: &mut TestAppContext) -> Arc { - let fs = FakeFs::new(cx.executor()); - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - AgentSettings::register(cx); - prompt_store::init(cx); - thread_store::init(fs.clone(), cx); - workspace::init_settings(cx); - language_model::init_settings(cx); - theme::init(theme::LoadThemes::JustBase, cx); - ToolRegistry::default_global(cx); - assistant_tool::init(cx); - - let http_client = Arc::new(http_client::HttpClientWithUrl::new( - http_client::FakeHttpClient::with_200_response(), - "http://localhost".to_string(), - None, - )); - assistant_tools::init(http_client, cx); - }); - fs - } - - // Helper to create a test project with test files - async fn create_test_project( - fs: &Arc, - cx: &mut TestAppContext, - files: serde_json::Value, - ) -> Entity { - fs.as_fake().insert_tree(path!("/test"), files).await; - Project::test(fs.clone(), [path!("/test").as_ref()], cx).await - } - - async fn setup_test_environment( - cx: &mut TestAppContext, - project: Entity, - ) -> ( - Entity, - Entity, - Entity, - Entity, - Arc, - ) { - let (workspace, cx) = - cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx)); - - let thread_store = cx - .update(|_, cx| { - ThreadStore::load( - project.clone(), - cx.new(|_| ToolWorkingSet::default()), - None, - Arc::new(PromptBuilder::new(None).unwrap()), - cx, - ) - }) - .await - .unwrap(); - - let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); - let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); - - let provider = Arc::new(FakeLanguageModelProvider::default()); - let model = provider.test_model(); - let model: Arc = Arc::new(model); - - cx.update(|_, cx| { - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - registry.set_default_model( - Some(ConfiguredModel { - provider: provider.clone(), - model: model.clone(), - }), - cx, - ); - registry.set_thread_summary_model( - Some(ConfiguredModel { - provider, - model: model.clone(), - }), - cx, - ); - }) - }); - - (workspace, thread_store, thread, context_store, model) - } - - async fn add_file_to_context( - project: &Entity, - context_store: &Entity, - path: &str, - cx: &mut TestAppContext, - ) -> Result> { - let buffer_path = project - .read_with(cx, |project, cx| project.find_project_path(path, cx)) - .unwrap(); - - let buffer = project - .update(cx, |project, cx| { - project.open_buffer(buffer_path.clone(), cx) - }) - .await - .unwrap(); - - context_store.update(cx, |context_store, cx| { - context_store.add_file_from_buffer(&buffer_path, buffer.clone(), false, cx); - }); - - Ok(buffer) +fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { + LanguageModelImage { + source: image_content.data.into(), + // TODO: make this optional? + size: gpui::Size::new(0.into(), 0.into()), } } diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs deleted file mode 100644 index 2139f232e3e99b1affb78928dec70e1aaef2a03a..0000000000000000000000000000000000000000 --- a/crates/agent/src/thread_store.rs +++ /dev/null @@ -1,1287 +0,0 @@ -use crate::{ - context_server_tool::ContextServerTool, - thread::{ - DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId, - }, -}; -use agent_settings::{AgentProfileId, CompletionMode}; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{Tool, ToolId, ToolWorkingSet}; -use chrono::{DateTime, Utc}; -use collections::HashMap; -use context_server::ContextServerId; -use fs::{Fs, RemoveOptions}; -use futures::{ - FutureExt as _, StreamExt as _, - channel::{mpsc, oneshot}, - future::{self, BoxFuture, Shared}, -}; -use gpui::{ - App, BackgroundExecutor, Context, Entity, EventEmitter, Global, ReadGlobal, SharedString, - Subscription, Task, Window, prelude::*, -}; -use indoc::indoc; -use language_model::{LanguageModelToolResultContent, LanguageModelToolUseId, Role, TokenUsage}; -use project::context_server_store::{ContextServerStatus, ContextServerStore}; -use project::{Project, ProjectItem, ProjectPath, Worktree}; -use prompt_store::{ - ProjectContext, PromptBuilder, PromptId, PromptStore, PromptsUpdatedEvent, RulesFileContext, - UserRulesContext, WorktreeContext, -}; -use serde::{Deserialize, Serialize}; -use sqlez::{ - bindable::{Bind, Column}, - connection::Connection, - statement::Statement, -}; -use std::{ - cell::{Ref, RefCell}, - path::{Path, PathBuf}, - rc::Rc, - sync::{Arc, LazyLock, Mutex}, -}; -use util::{ResultExt as _, rel_path::RelPath}; - -use zed_env_vars::ZED_STATELESS; - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum DataType { - #[serde(rename = "json")] - Json, - #[serde(rename = "zstd")] - Zstd, -} - -impl Bind for DataType { - fn bind(&self, statement: &Statement, start_index: i32) -> Result { - let value = match self { - DataType::Json => "json", - DataType::Zstd => "zstd", - }; - value.bind(statement, start_index) - } -} - -impl Column for DataType { - fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let (value, next_index) = String::column(statement, start_index)?; - let data_type = match value.as_str() { - "json" => DataType::Json, - "zstd" => DataType::Zstd, - _ => anyhow::bail!("Unknown data type: {}", value), - }; - Ok((data_type, next_index)) - } -} - -static RULES_FILE_NAMES: LazyLock<[&RelPath; 9]> = LazyLock::new(|| { - [ - RelPath::unix(".rules").unwrap(), - RelPath::unix(".cursorrules").unwrap(), - RelPath::unix(".windsurfrules").unwrap(), - RelPath::unix(".clinerules").unwrap(), - RelPath::unix(".github/copilot-instructions.md").unwrap(), - RelPath::unix("CLAUDE.md").unwrap(), - RelPath::unix("AGENT.md").unwrap(), - RelPath::unix("AGENTS.md").unwrap(), - RelPath::unix("GEMINI.md").unwrap(), - ] -}); - -pub fn init(fs: Arc, cx: &mut App) { - ThreadsDatabase::init(fs, cx); -} - -/// A system prompt shared by all threads created by this ThreadStore -#[derive(Clone, Default)] -pub struct SharedProjectContext(Rc>>); - -impl SharedProjectContext { - pub fn borrow(&self) -> Ref<'_, Option> { - self.0.borrow() - } -} - -pub type TextThreadStore = assistant_context::ContextStore; - -pub struct ThreadStore { - project: Entity, - tools: Entity, - prompt_builder: Arc, - prompt_store: Option>, - context_server_tool_ids: HashMap>, - threads: Vec, - project_context: SharedProjectContext, - reload_system_prompt_tx: mpsc::Sender<()>, - _reload_system_prompt_task: Task<()>, - _subscriptions: Vec, -} - -pub struct RulesLoadingError { - pub message: SharedString, -} - -impl EventEmitter for ThreadStore {} - -impl ThreadStore { - pub fn load( - project: Entity, - tools: Entity, - prompt_store: Option>, - prompt_builder: Arc, - cx: &mut App, - ) -> Task>> { - cx.spawn(async move |cx| { - let (thread_store, ready_rx) = cx.update(|cx| { - let mut option_ready_rx = None; - let thread_store = cx.new(|cx| { - let (thread_store, ready_rx) = - Self::new(project, tools, prompt_builder, prompt_store, cx); - option_ready_rx = Some(ready_rx); - thread_store - }); - (thread_store, option_ready_rx.take().unwrap()) - })?; - ready_rx.await?; - Ok(thread_store) - }) - } - - fn new( - project: Entity, - tools: Entity, - prompt_builder: Arc, - prompt_store: Option>, - cx: &mut Context, - ) -> (Self, oneshot::Receiver<()>) { - let mut subscriptions = vec![cx.subscribe(&project, Self::handle_project_event)]; - - if let Some(prompt_store) = prompt_store.as_ref() { - subscriptions.push(cx.subscribe( - prompt_store, - |this, _prompt_store, PromptsUpdatedEvent, _cx| { - this.enqueue_system_prompt_reload(); - }, - )) - } - - // This channel and task prevent concurrent and redundant loading of the system prompt. - let (reload_system_prompt_tx, mut reload_system_prompt_rx) = mpsc::channel(1); - let (ready_tx, ready_rx) = oneshot::channel(); - let mut ready_tx = Some(ready_tx); - let reload_system_prompt_task = cx.spawn({ - let prompt_store = prompt_store.clone(); - async move |thread_store, cx| { - loop { - let Some(reload_task) = thread_store - .update(cx, |thread_store, cx| { - thread_store.reload_system_prompt(prompt_store.clone(), cx) - }) - .ok() - else { - return; - }; - reload_task.await; - if let Some(ready_tx) = ready_tx.take() { - ready_tx.send(()).ok(); - } - reload_system_prompt_rx.next().await; - } - } - }); - - let this = Self { - project, - tools, - prompt_builder, - prompt_store, - context_server_tool_ids: HashMap::default(), - threads: Vec::new(), - project_context: SharedProjectContext::default(), - reload_system_prompt_tx, - _reload_system_prompt_task: reload_system_prompt_task, - _subscriptions: subscriptions, - }; - this.register_context_server_handlers(cx); - this.reload(cx).detach_and_log_err(cx); - (this, ready_rx) - } - - #[cfg(any(test, feature = "test-support"))] - pub fn fake(project: Entity, cx: &mut App) -> Self { - Self { - project, - tools: cx.new(|_| ToolWorkingSet::default()), - prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()), - prompt_store: None, - context_server_tool_ids: HashMap::default(), - threads: Vec::new(), - project_context: SharedProjectContext::default(), - reload_system_prompt_tx: mpsc::channel(0).0, - _reload_system_prompt_task: Task::ready(()), - _subscriptions: vec![], - } - } - - fn handle_project_event( - &mut self, - _project: Entity, - event: &project::Event, - _cx: &mut Context, - ) { - match event { - project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => { - self.enqueue_system_prompt_reload(); - } - project::Event::WorktreeUpdatedEntries(_, items) => { - if items - .iter() - .any(|(path, _, _)| RULES_FILE_NAMES.iter().any(|name| path.as_ref() == *name)) - { - self.enqueue_system_prompt_reload(); - } - } - _ => {} - } - } - - fn enqueue_system_prompt_reload(&mut self) { - self.reload_system_prompt_tx.try_send(()).ok(); - } - - // Note that this should only be called from `reload_system_prompt_task`. - fn reload_system_prompt( - &self, - prompt_store: Option>, - cx: &mut Context, - ) -> Task<()> { - let worktrees = self - .project - .read(cx) - .visible_worktrees(cx) - .collect::>(); - let worktree_tasks = worktrees - .into_iter() - .map(|worktree| { - Self::load_worktree_info_for_system_prompt(worktree, self.project.clone(), cx) - }) - .collect::>(); - let default_user_rules_task = match prompt_store { - None => Task::ready(vec![]), - Some(prompt_store) => prompt_store.read_with(cx, |prompt_store, cx| { - let prompts = prompt_store.default_prompt_metadata(); - let load_tasks = prompts.into_iter().map(|prompt_metadata| { - let contents = prompt_store.load(prompt_metadata.id, cx); - async move { (contents.await, prompt_metadata) } - }); - cx.background_spawn(future::join_all(load_tasks)) - }), - }; - - cx.spawn(async move |this, cx| { - let (worktrees, default_user_rules) = - future::join(future::join_all(worktree_tasks), default_user_rules_task).await; - - let worktrees = worktrees - .into_iter() - .map(|(worktree, rules_error)| { - if let Some(rules_error) = rules_error { - this.update(cx, |_, cx| cx.emit(rules_error)).ok(); - } - worktree - }) - .collect::>(); - - let default_user_rules = default_user_rules - .into_iter() - .flat_map(|(contents, prompt_metadata)| match contents { - Ok(contents) => Some(UserRulesContext { - uuid: match prompt_metadata.id { - PromptId::User { uuid } => uuid, - PromptId::EditWorkflow => return None, - }, - title: prompt_metadata.title.map(|title| title.to_string()), - contents, - }), - Err(err) => { - this.update(cx, |_, cx| { - cx.emit(RulesLoadingError { - message: format!("{err:?}").into(), - }); - }) - .ok(); - None - } - }) - .collect::>(); - - this.update(cx, |this, _cx| { - *this.project_context.0.borrow_mut() = - Some(ProjectContext::new(worktrees, default_user_rules)); - }) - .ok(); - }) - } - - fn load_worktree_info_for_system_prompt( - worktree: Entity, - project: Entity, - cx: &mut App, - ) -> Task<(WorktreeContext, Option)> { - let tree = worktree.read(cx); - let root_name = tree.root_name_str().into(); - let abs_path = tree.abs_path(); - - let mut context = WorktreeContext { - root_name, - abs_path, - rules_file: None, - }; - - let rules_task = Self::load_worktree_rules_file(worktree, project, cx); - let Some(rules_task) = rules_task else { - return Task::ready((context, None)); - }; - - cx.spawn(async move |_| { - let (rules_file, rules_file_error) = match rules_task.await { - Ok(rules_file) => (Some(rules_file), None), - Err(err) => ( - None, - Some(RulesLoadingError { - message: format!("{err}").into(), - }), - ), - }; - context.rules_file = rules_file; - (context, rules_file_error) - }) - } - - fn load_worktree_rules_file( - worktree: Entity, - project: Entity, - cx: &mut App, - ) -> Option>> { - let worktree = worktree.read(cx); - let worktree_id = worktree.id(); - let selected_rules_file = RULES_FILE_NAMES - .into_iter() - .filter_map(|name| { - worktree - .entry_for_path(name) - .filter(|entry| entry.is_file()) - .map(|entry| entry.path.clone()) - }) - .next(); - - // Note that Cline supports `.clinerules` being a directory, but that is not currently - // supported. This doesn't seem to occur often in GitHub repositories. - selected_rules_file.map(|path_in_worktree| { - let project_path = ProjectPath { - worktree_id, - path: path_in_worktree.clone(), - }; - let buffer_task = - project.update(cx, |project, cx| project.open_buffer(project_path, cx)); - let rope_task = cx.spawn(async move |cx| { - buffer_task.await?.read_with(cx, |buffer, cx| { - let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?; - anyhow::Ok((project_entry_id, buffer.as_rope().clone())) - })? - }); - // Build a string from the rope on a background thread. - cx.background_spawn(async move { - let (project_entry_id, rope) = rope_task.await?; - anyhow::Ok(RulesFileContext { - path_in_worktree, - text: rope.to_string().trim().to_string(), - project_entry_id: project_entry_id.to_usize(), - }) - }) - }) - } - - pub fn prompt_store(&self) -> &Option> { - &self.prompt_store - } - - pub fn tools(&self) -> Entity { - self.tools.clone() - } - - /// Returns the number of threads. - pub fn thread_count(&self) -> usize { - self.threads.len() - } - - pub fn reverse_chronological_threads(&self) -> impl Iterator { - // ordering is from "ORDER BY" in `list_threads` - self.threads.iter() - } - - pub fn create_thread(&mut self, cx: &mut Context) -> Entity { - cx.new(|cx| { - Thread::new( - self.project.clone(), - self.tools.clone(), - self.prompt_builder.clone(), - self.project_context.clone(), - cx, - ) - }) - } - - pub fn create_thread_from_serialized( - &mut self, - serialized: SerializedThread, - cx: &mut Context, - ) -> Entity { - cx.new(|cx| { - Thread::deserialize( - ThreadId::new(), - serialized, - self.project.clone(), - self.tools.clone(), - self.prompt_builder.clone(), - self.project_context.clone(), - None, - cx, - ) - }) - } - - pub fn open_thread( - &self, - id: &ThreadId, - window: &mut Window, - cx: &mut Context, - ) -> Task>> { - let id = id.clone(); - let database_future = ThreadsDatabase::global_future(cx); - let this = cx.weak_entity(); - window.spawn(cx, async move |cx| { - let database = database_future.await.map_err(|err| anyhow!(err))?; - let thread = database - .try_find_thread(id.clone()) - .await? - .with_context(|| format!("no thread found with ID: {id:?}"))?; - - let thread = this.update_in(cx, |this, window, cx| { - cx.new(|cx| { - Thread::deserialize( - id.clone(), - thread, - this.project.clone(), - this.tools.clone(), - this.prompt_builder.clone(), - this.project_context.clone(), - Some(window), - cx, - ) - }) - })?; - - Ok(thread) - }) - } - - pub fn save_thread(&self, thread: &Entity, cx: &mut Context) -> Task> { - let (metadata, serialized_thread) = - thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx))); - - let database_future = ThreadsDatabase::global_future(cx); - cx.spawn(async move |this, cx| { - let serialized_thread = serialized_thread.await?; - let database = database_future.await.map_err(|err| anyhow!(err))?; - database.save_thread(metadata, serialized_thread).await?; - - this.update(cx, |this, cx| this.reload(cx))?.await - }) - } - - pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context) -> Task> { - let id = id.clone(); - let database_future = ThreadsDatabase::global_future(cx); - cx.spawn(async move |this, cx| { - let database = database_future.await.map_err(|err| anyhow!(err))?; - database.delete_thread(id.clone()).await?; - - this.update(cx, |this, cx| { - this.threads.retain(|thread| thread.id != id); - cx.notify(); - }) - }) - } - - pub fn reload(&self, cx: &mut Context) -> Task> { - let database_future = ThreadsDatabase::global_future(cx); - cx.spawn(async move |this, cx| { - let threads = database_future - .await - .map_err(|err| anyhow!(err))? - .list_threads() - .await?; - - this.update(cx, |this, cx| { - this.threads = threads; - cx.notify(); - }) - }) - } - - fn register_context_server_handlers(&self, cx: &mut Context) { - let context_server_store = self.project.read(cx).context_server_store(); - cx.subscribe(&context_server_store, Self::handle_context_server_event) - .detach(); - - // Check for any servers that were already running before the handler was registered - for server in context_server_store.read(cx).running_servers() { - self.load_context_server_tools(server.id(), context_server_store.clone(), cx); - } - } - - fn handle_context_server_event( - &mut self, - context_server_store: Entity, - event: &project::context_server_store::Event, - cx: &mut Context, - ) { - let tool_working_set = self.tools.clone(); - match event { - project::context_server_store::Event::ServerStatusChanged { server_id, status } => { - match status { - ContextServerStatus::Starting => {} - ContextServerStatus::Running => { - self.load_context_server_tools(server_id.clone(), context_server_store, cx); - } - ContextServerStatus::Stopped | ContextServerStatus::Error(_) => { - if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { - tool_working_set.update(cx, |tool_working_set, cx| { - tool_working_set.remove(&tool_ids, cx); - }); - } - } - } - } - } - } - - fn load_context_server_tools( - &self, - server_id: ContextServerId, - context_server_store: Entity, - cx: &mut Context, - ) { - let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else { - return; - }; - let tool_working_set = self.tools.clone(); - cx.spawn(async move |this, cx| { - let Some(protocol) = server.client() else { - return; - }; - - if protocol.capable(context_server::protocol::ServerCapability::Tools) - && let Some(response) = protocol - .request::(()) - .await - .log_err() - { - let tool_ids = tool_working_set - .update(cx, |tool_working_set, cx| { - tool_working_set.extend( - response.tools.into_iter().map(|tool| { - Arc::new(ContextServerTool::new( - context_server_store.clone(), - server.id(), - tool, - )) as Arc - }), - cx, - ) - }) - .log_err(); - - if let Some(tool_ids) = tool_ids { - this.update(cx, |this, _| { - this.context_server_tool_ids.insert(server_id, tool_ids); - }) - .log_err(); - } - } - }) - .detach(); - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SerializedThreadMetadata { - pub id: ThreadId, - pub summary: SharedString, - pub updated_at: DateTime, -} - -#[derive(Serialize, Deserialize, Debug, PartialEq)] -pub struct SerializedThread { - pub version: String, - pub summary: SharedString, - pub updated_at: DateTime, - pub messages: Vec, - #[serde(default)] - pub initial_project_snapshot: Option>, - #[serde(default)] - pub cumulative_token_usage: TokenUsage, - #[serde(default)] - pub request_token_usage: Vec, - #[serde(default)] - pub detailed_summary_state: DetailedSummaryState, - #[serde(default)] - pub exceeded_window_error: Option, - #[serde(default)] - pub model: Option, - #[serde(default)] - pub completion_mode: Option, - #[serde(default)] - pub tool_use_limit_reached: bool, - #[serde(default)] - pub profile: Option, -} - -#[derive(Serialize, Deserialize, Debug, PartialEq)] -pub struct SerializedLanguageModel { - pub provider: String, - pub model: String, -} - -impl SerializedThread { - pub const VERSION: &'static str = "0.2.0"; - - pub fn from_json(json: &[u8]) -> Result { - let saved_thread_json = serde_json::from_slice::(json)?; - match saved_thread_json.get("version") { - Some(serde_json::Value::String(version)) => match version.as_str() { - SerializedThreadV0_1_0::VERSION => { - let saved_thread = - serde_json::from_value::(saved_thread_json)?; - Ok(saved_thread.upgrade()) - } - SerializedThread::VERSION => Ok(serde_json::from_value::( - saved_thread_json, - )?), - _ => anyhow::bail!("unrecognized serialized thread version: {version:?}"), - }, - None => { - let saved_thread = - serde_json::from_value::(saved_thread_json)?; - Ok(saved_thread.upgrade()) - } - version => anyhow::bail!("unrecognized serialized thread version: {version:?}"), - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct SerializedThreadV0_1_0( - // The structure did not change, so we are reusing the latest SerializedThread. - // When making the next version, make sure this points to SerializedThreadV0_2_0 - SerializedThread, -); - -impl SerializedThreadV0_1_0 { - pub const VERSION: &'static str = "0.1.0"; - - pub fn upgrade(self) -> SerializedThread { - debug_assert_eq!(SerializedThread::VERSION, "0.2.0"); - - let mut messages: Vec = Vec::with_capacity(self.0.messages.len()); - - for message in self.0.messages { - if message.role == Role::User - && !message.tool_results.is_empty() - && let Some(last_message) = messages.last_mut() - { - debug_assert!(last_message.role == Role::Assistant); - - last_message.tool_results = message.tool_results; - continue; - } - - messages.push(message); - } - - SerializedThread { - messages, - version: SerializedThread::VERSION.to_string(), - ..self.0 - } - } -} - -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct SerializedMessage { - pub id: MessageId, - pub role: Role, - #[serde(default)] - pub segments: Vec, - #[serde(default)] - pub tool_uses: Vec, - #[serde(default)] - pub tool_results: Vec, - #[serde(default)] - pub context: String, - #[serde(default)] - pub creases: Vec, - #[serde(default)] - pub is_hidden: bool, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq)] -#[serde(tag = "type")] -pub enum SerializedMessageSegment { - #[serde(rename = "text")] - Text { - text: String, - }, - #[serde(rename = "thinking")] - Thinking { - text: String, - #[serde(skip_serializing_if = "Option::is_none")] - signature: Option, - }, - RedactedThinking { - data: String, - }, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct SerializedToolUse { - pub id: LanguageModelToolUseId, - pub name: SharedString, - pub input: serde_json::Value, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct SerializedToolResult { - pub tool_use_id: LanguageModelToolUseId, - pub is_error: bool, - pub content: LanguageModelToolResultContent, - pub output: Option, -} - -#[derive(Serialize, Deserialize)] -struct LegacySerializedThread { - pub summary: SharedString, - pub updated_at: DateTime, - pub messages: Vec, - #[serde(default)] - pub initial_project_snapshot: Option>, -} - -impl LegacySerializedThread { - pub fn upgrade(self) -> SerializedThread { - SerializedThread { - version: SerializedThread::VERSION.to_string(), - summary: self.summary, - updated_at: self.updated_at, - messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(), - initial_project_snapshot: self.initial_project_snapshot, - cumulative_token_usage: TokenUsage::default(), - request_token_usage: Vec::new(), - detailed_summary_state: DetailedSummaryState::default(), - exceeded_window_error: None, - model: None, - completion_mode: None, - tool_use_limit_reached: false, - profile: None, - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -struct LegacySerializedMessage { - pub id: MessageId, - pub role: Role, - pub text: String, - #[serde(default)] - pub tool_uses: Vec, - #[serde(default)] - pub tool_results: Vec, -} - -impl LegacySerializedMessage { - fn upgrade(self) -> SerializedMessage { - SerializedMessage { - id: self.id, - role: self.role, - segments: vec![SerializedMessageSegment::Text { text: self.text }], - tool_uses: self.tool_uses, - tool_results: self.tool_results, - context: String::new(), - creases: Vec::new(), - is_hidden: false, - } - } -} - -#[derive(Debug, Serialize, Deserialize, PartialEq)] -pub struct SerializedCrease { - pub start: usize, - pub end: usize, - pub icon_path: SharedString, - pub label: SharedString, -} - -struct GlobalThreadsDatabase( - Shared, Arc>>>, -); - -impl Global for GlobalThreadsDatabase {} - -pub(crate) struct ThreadsDatabase { - executor: BackgroundExecutor, - connection: Arc>, -} - -impl ThreadsDatabase { - fn connection(&self) -> Arc> { - self.connection.clone() - } - - const COMPRESSION_LEVEL: i32 = 3; -} - -impl Bind for ThreadId { - fn bind(&self, statement: &Statement, start_index: i32) -> Result { - self.to_string().bind(statement, start_index) - } -} - -impl Column for ThreadId { - fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { - let (id_str, next_index) = String::column(statement, start_index)?; - Ok((ThreadId::from(id_str.as_str()), next_index)) - } -} - -impl ThreadsDatabase { - fn global_future( - cx: &mut App, - ) -> Shared, Arc>>> { - GlobalThreadsDatabase::global(cx).0.clone() - } - - fn init(fs: Arc, cx: &mut App) { - let executor = cx.background_executor().clone(); - let database_future = executor - .spawn({ - let executor = executor.clone(); - let threads_dir = paths::data_dir().join("threads"); - async move { ThreadsDatabase::new(fs, threads_dir, executor).await } - }) - .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new))) - .boxed() - .shared(); - - cx.set_global(GlobalThreadsDatabase(database_future)); - } - - pub async fn new( - fs: Arc, - threads_dir: PathBuf, - executor: BackgroundExecutor, - ) -> Result { - fs.create_dir(&threads_dir).await?; - - let sqlite_path = threads_dir.join("threads.db"); - let mdb_path = threads_dir.join("threads-db.1.mdb"); - - let needs_migration_from_heed = fs.is_file(&mdb_path).await; - - let connection = if *ZED_STATELESS { - Connection::open_memory(Some("THREAD_FALLBACK_DB")) - } else if cfg!(any(feature = "test-support", test)) { - // rust stores the name of the test on the current thread. - // We use this to automatically create a database that will - // be shared within the test (for the test_retrieve_old_thread) - // but not with concurrent tests. - let thread = std::thread::current(); - let test_name = thread.name(); - Connection::open_memory(Some(&format!( - "THREAD_FALLBACK_{}", - test_name.unwrap_or_default() - ))) - } else { - Connection::open_file(&sqlite_path.to_string_lossy()) - }; - - connection.exec(indoc! {" - CREATE TABLE IF NOT EXISTS threads ( - id TEXT PRIMARY KEY, - summary TEXT NOT NULL, - updated_at TEXT NOT NULL, - data_type TEXT NOT NULL, - data BLOB NOT NULL - ) - "})?() - .map_err(|e| anyhow!("Failed to create threads table: {}", e))?; - - let db = Self { - executor: executor.clone(), - connection: Arc::new(Mutex::new(connection)), - }; - - if needs_migration_from_heed { - let db_connection = db.connection(); - let executor_clone = executor.clone(); - executor - .spawn(async move { - log::info!("Starting threads.db migration"); - Self::migrate_from_heed(&mdb_path, db_connection, executor_clone)?; - fs.remove_dir( - &mdb_path, - RemoveOptions { - recursive: true, - ignore_if_not_exists: true, - }, - ) - .await?; - log::info!("threads.db migrated to sqlite"); - Ok::<(), anyhow::Error>(()) - }) - .detach(); - } - - Ok(db) - } - - // Remove this migration after 2025-09-01 - fn migrate_from_heed( - mdb_path: &Path, - connection: Arc>, - _executor: BackgroundExecutor, - ) -> Result<()> { - use heed::types::SerdeBincode; - struct SerializedThreadHeed(SerializedThread); - - impl heed::BytesEncode<'_> for SerializedThreadHeed { - type EItem = SerializedThreadHeed; - - fn bytes_encode( - item: &Self::EItem, - ) -> Result, heed::BoxedError> { - serde_json::to_vec(&item.0) - .map(std::borrow::Cow::Owned) - .map_err(Into::into) - } - } - - impl<'a> heed::BytesDecode<'a> for SerializedThreadHeed { - type DItem = SerializedThreadHeed; - - fn bytes_decode(bytes: &'a [u8]) -> Result { - SerializedThread::from_json(bytes) - .map(SerializedThreadHeed) - .map_err(Into::into) - } - } - - const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024; - - let env = unsafe { - heed::EnvOpenOptions::new() - .map_size(ONE_GB_IN_BYTES) - .max_dbs(1) - .open(mdb_path)? - }; - - let txn = env.write_txn()?; - let threads: heed::Database, SerializedThreadHeed> = env - .open_database(&txn, Some("threads"))? - .ok_or_else(|| anyhow!("threads database not found"))?; - - for result in threads.iter(&txn)? { - let (thread_id, thread_heed) = result?; - Self::save_thread_sync(&connection, thread_id, thread_heed.0)?; - } - - Ok(()) - } - - fn save_thread_sync( - connection: &Arc>, - id: ThreadId, - thread: SerializedThread, - ) -> Result<()> { - let json_data = serde_json::to_string(&thread)?; - let summary = thread.summary.to_string(); - let updated_at = thread.updated_at.to_rfc3339(); - - let connection = connection.lock().unwrap(); - - let compressed = zstd::encode_all(json_data.as_bytes(), Self::COMPRESSION_LEVEL)?; - let data_type = DataType::Zstd; - let data = compressed; - - let mut insert = connection.exec_bound::<(ThreadId, String, String, DataType, Vec)>(indoc! {" - INSERT OR REPLACE INTO threads (id, summary, updated_at, data_type, data) VALUES (?, ?, ?, ?, ?) - "})?; - - insert((id, summary, updated_at, data_type, data))?; - - Ok(()) - } - - pub fn list_threads(&self) -> Task>> { - let connection = self.connection.clone(); - - self.executor.spawn(async move { - let connection = connection.lock().unwrap(); - let mut select = - connection.select_bound::<(), (ThreadId, String, String)>(indoc! {" - SELECT id, summary, updated_at FROM threads ORDER BY updated_at DESC - "})?; - - let rows = select(())?; - let mut threads = Vec::new(); - - for (id, summary, updated_at) in rows { - threads.push(SerializedThreadMetadata { - id, - summary: summary.into(), - updated_at: DateTime::parse_from_rfc3339(&updated_at)?.with_timezone(&Utc), - }); - } - - Ok(threads) - }) - } - - pub fn try_find_thread(&self, id: ThreadId) -> Task>> { - let connection = self.connection.clone(); - - self.executor.spawn(async move { - let connection = connection.lock().unwrap(); - let mut select = connection.select_bound::)>(indoc! {" - SELECT data_type, data FROM threads WHERE id = ? LIMIT 1 - "})?; - - let rows = select(id)?; - if let Some((data_type, data)) = rows.into_iter().next() { - let json_data = match data_type { - DataType::Zstd => { - let decompressed = zstd::decode_all(&data[..])?; - String::from_utf8(decompressed)? - } - DataType::Json => String::from_utf8(data)?, - }; - - let thread = SerializedThread::from_json(json_data.as_bytes())?; - Ok(Some(thread)) - } else { - Ok(None) - } - }) - } - - pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task> { - let connection = self.connection.clone(); - - self.executor - .spawn(async move { Self::save_thread_sync(&connection, id, thread) }) - } - - pub fn delete_thread(&self, id: ThreadId) -> Task> { - let connection = self.connection.clone(); - - self.executor.spawn(async move { - let connection = connection.lock().unwrap(); - - let mut delete = connection.exec_bound::(indoc! {" - DELETE FROM threads WHERE id = ? - "})?; - - delete(id)?; - - Ok(()) - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::thread::{DetailedSummaryState, MessageId}; - use chrono::Utc; - use language_model::{Role, TokenUsage}; - use pretty_assertions::assert_eq; - - #[test] - fn test_legacy_serialized_thread_upgrade() { - let updated_at = Utc::now(); - let legacy_thread = LegacySerializedThread { - summary: "Test conversation".into(), - updated_at, - messages: vec![LegacySerializedMessage { - id: MessageId(1), - role: Role::User, - text: "Hello, world!".to_string(), - tool_uses: vec![], - tool_results: vec![], - }], - initial_project_snapshot: None, - }; - - let upgraded = legacy_thread.upgrade(); - - assert_eq!( - upgraded, - SerializedThread { - summary: "Test conversation".into(), - updated_at, - messages: vec![SerializedMessage { - id: MessageId(1), - role: Role::User, - segments: vec![SerializedMessageSegment::Text { - text: "Hello, world!".to_string() - }], - tool_uses: vec![], - tool_results: vec![], - context: "".to_string(), - creases: vec![], - is_hidden: false - }], - version: SerializedThread::VERSION.to_string(), - initial_project_snapshot: None, - cumulative_token_usage: TokenUsage::default(), - request_token_usage: vec![], - detailed_summary_state: DetailedSummaryState::default(), - exceeded_window_error: None, - model: None, - completion_mode: None, - tool_use_limit_reached: false, - profile: None - } - ) - } - - #[test] - fn test_serialized_threadv0_1_0_upgrade() { - let updated_at = Utc::now(); - let thread_v0_1_0 = SerializedThreadV0_1_0(SerializedThread { - summary: "Test conversation".into(), - updated_at, - messages: vec![ - SerializedMessage { - id: MessageId(1), - role: Role::User, - segments: vec![SerializedMessageSegment::Text { - text: "Use tool_1".to_string(), - }], - tool_uses: vec![], - tool_results: vec![], - context: "".to_string(), - creases: vec![], - is_hidden: false, - }, - SerializedMessage { - id: MessageId(2), - role: Role::Assistant, - segments: vec![SerializedMessageSegment::Text { - text: "I want to use a tool".to_string(), - }], - tool_uses: vec![SerializedToolUse { - id: "abc".into(), - name: "tool_1".into(), - input: serde_json::Value::Null, - }], - tool_results: vec![], - context: "".to_string(), - creases: vec![], - is_hidden: false, - }, - SerializedMessage { - id: MessageId(1), - role: Role::User, - segments: vec![SerializedMessageSegment::Text { - text: "Here is the tool result".to_string(), - }], - tool_uses: vec![], - tool_results: vec![SerializedToolResult { - tool_use_id: "abc".into(), - is_error: false, - content: LanguageModelToolResultContent::Text("abcdef".into()), - output: Some(serde_json::Value::Null), - }], - context: "".to_string(), - creases: vec![], - is_hidden: false, - }, - ], - version: SerializedThreadV0_1_0::VERSION.to_string(), - initial_project_snapshot: None, - cumulative_token_usage: TokenUsage::default(), - request_token_usage: vec![], - detailed_summary_state: DetailedSummaryState::default(), - exceeded_window_error: None, - model: None, - completion_mode: None, - tool_use_limit_reached: false, - profile: None, - }); - let upgraded = thread_v0_1_0.upgrade(); - - assert_eq!( - upgraded, - SerializedThread { - summary: "Test conversation".into(), - updated_at, - messages: vec![ - SerializedMessage { - id: MessageId(1), - role: Role::User, - segments: vec![SerializedMessageSegment::Text { - text: "Use tool_1".to_string() - }], - tool_uses: vec![], - tool_results: vec![], - context: "".to_string(), - creases: vec![], - is_hidden: false - }, - SerializedMessage { - id: MessageId(2), - role: Role::Assistant, - segments: vec![SerializedMessageSegment::Text { - text: "I want to use a tool".to_string(), - }], - tool_uses: vec![SerializedToolUse { - id: "abc".into(), - name: "tool_1".into(), - input: serde_json::Value::Null, - }], - tool_results: vec![SerializedToolResult { - tool_use_id: "abc".into(), - is_error: false, - content: LanguageModelToolResultContent::Text("abcdef".into()), - output: Some(serde_json::Value::Null), - }], - context: "".to_string(), - creases: vec![], - is_hidden: false, - }, - ], - version: SerializedThread::VERSION.to_string(), - initial_project_snapshot: None, - cumulative_token_usage: TokenUsage::default(), - request_token_usage: vec![], - detailed_summary_state: DetailedSummaryState::default(), - exceeded_window_error: None, - model: None, - completion_mode: None, - tool_use_limit_reached: false, - profile: None - } - ) - } -} diff --git a/crates/assistant_tool/src/tool_schema.rs b/crates/agent/src/tool_schema.rs similarity index 85% rename from crates/assistant_tool/src/tool_schema.rs rename to crates/agent/src/tool_schema.rs index 192f7c8a2bb565ece01a3472a9e46dad316377f4..4b0de3e5c63fb0c5ccafbb89a22dad8a33072b35 100644 --- a/crates/assistant_tool/src/tool_schema.rs +++ b/crates/agent/src/tool_schema.rs @@ -1,7 +1,48 @@ use anyhow::Result; +use language_model::LanguageModelToolSchemaFormat; +use schemars::{ + JsonSchema, Schema, + generate::SchemaSettings, + transform::{Transform, transform_subschemas}, +}; use serde_json::Value; -use crate::LanguageModelToolSchemaFormat; +pub(crate) fn root_schema_for(format: LanguageModelToolSchemaFormat) -> Schema { + let mut generator = match format { + LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(), + LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3() + .with(|settings| { + settings.meta_schema = None; + settings.inline_subschemas = true; + }) + .with_transform(ToJsonSchemaSubsetTransform) + .into_generator(), + }; + generator.root_schema_for::() +} + +#[derive(Debug, Clone)] +struct ToJsonSchemaSubsetTransform; + +impl Transform for ToJsonSchemaSubsetTransform { + fn transform(&mut self, schema: &mut Schema) { + // Ensure that the type field is not an array, this happens when we use + // Option, the type will be [T, "null"]. + if let Some(type_field) = schema.get_mut("type") + && let Some(types) = type_field.as_array() + && let Some(first_type) = types.first() + { + *type_field = first_type.clone(); + } + + // oneOf is not supported, use anyOf instead + if let Some(one_of) = schema.remove("oneOf") { + schema.insert("anyOf".to_string(), one_of); + } + + transform_subschemas(self, schema); + } +} /// Tries to adapt a JSON schema representation to be compatible with the specified format. /// diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs deleted file mode 100644 index 962dca591fb66f4679d44b8e8a4733c879bc2e0c..0000000000000000000000000000000000000000 --- a/crates/agent/src/tool_use.rs +++ /dev/null @@ -1,575 +0,0 @@ -use crate::{ - thread::{MessageId, PromptId, ThreadId}, - thread_store::SerializedMessage, -}; -use agent_settings::CompletionMode; -use anyhow::Result; -use assistant_tool::{ - AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet, -}; -use collections::HashMap; -use futures::{FutureExt as _, future::Shared}; -use gpui::{App, Entity, SharedString, Task, Window}; -use icons::IconName; -use language_model::{ - ConfiguredModel, LanguageModel, LanguageModelExt, LanguageModelRequest, - LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse, - LanguageModelToolUseId, Role, -}; -use project::Project; -use std::sync::Arc; -use util::truncate_lines_to_byte_limit; - -#[derive(Debug)] -pub struct ToolUse { - pub id: LanguageModelToolUseId, - pub name: SharedString, - pub ui_text: SharedString, - pub status: ToolUseStatus, - pub input: serde_json::Value, - pub icon: icons::IconName, - pub needs_confirmation: bool, -} - -pub struct ToolUseState { - tools: Entity, - tool_uses_by_assistant_message: HashMap>, - tool_results: HashMap, - pending_tool_uses_by_id: HashMap, - tool_result_cards: HashMap, - tool_use_metadata_by_id: HashMap, -} - -impl ToolUseState { - pub fn new(tools: Entity) -> Self { - Self { - tools, - tool_uses_by_assistant_message: HashMap::default(), - tool_results: HashMap::default(), - pending_tool_uses_by_id: HashMap::default(), - tool_result_cards: HashMap::default(), - tool_use_metadata_by_id: HashMap::default(), - } - } - - /// Constructs a [`ToolUseState`] from the given list of [`SerializedMessage`]s. - /// - /// Accepts a function to filter the tools that should be used to populate the state. - /// - /// If `window` is `None` (e.g., when in headless mode or when running evals), - /// tool cards won't be deserialized - pub fn from_serialized_messages( - tools: Entity, - messages: &[SerializedMessage], - project: Entity, - window: Option<&mut Window>, // None in headless mode - cx: &mut App, - ) -> Self { - let mut this = Self::new(tools); - let mut tool_names_by_id = HashMap::default(); - let mut window = window; - - for message in messages { - match message.role { - Role::Assistant => { - if !message.tool_uses.is_empty() { - let tool_uses = message - .tool_uses - .iter() - .map(|tool_use| LanguageModelToolUse { - id: tool_use.id.clone(), - name: tool_use.name.clone().into(), - raw_input: tool_use.input.to_string(), - input: tool_use.input.clone(), - is_input_complete: true, - }) - .collect::>(); - - tool_names_by_id.extend( - tool_uses - .iter() - .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())), - ); - - this.tool_uses_by_assistant_message - .insert(message.id, tool_uses); - - for tool_result in &message.tool_results { - let tool_use_id = tool_result.tool_use_id.clone(); - let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else { - log::warn!("no tool name found for tool use: {tool_use_id:?}"); - continue; - }; - - this.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name: tool_use.clone(), - is_error: tool_result.is_error, - content: tool_result.content.clone(), - output: tool_result.output.clone(), - }, - ); - - if let Some(window) = &mut window - && let Some(tool) = this.tools.read(cx).tool(tool_use, cx) - && let Some(output) = tool_result.output.clone() - && let Some(card) = - tool.deserialize_card(output, project.clone(), window, cx) - { - this.tool_result_cards.insert(tool_use_id, card); - } - } - } - } - Role::System | Role::User => {} - } - } - - this - } - - pub fn cancel_pending(&mut self) -> Vec { - let mut canceled_tool_uses = Vec::new(); - self.pending_tool_uses_by_id - .retain(|tool_use_id, tool_use| { - if matches!(tool_use.status, PendingToolUseStatus::Error { .. }) { - return true; - } - - let content = "Tool canceled by user".into(); - self.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name: tool_use.name.clone(), - content, - output: None, - is_error: true, - }, - ); - canceled_tool_uses.push(tool_use.clone()); - false - }); - canceled_tool_uses - } - - pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { - self.pending_tool_uses_by_id.values().collect() - } - - pub fn tool_uses_for_message( - &self, - id: MessageId, - project: &Entity, - cx: &App, - ) -> Vec { - let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else { - return Vec::new(); - }; - - let mut tool_uses = Vec::new(); - - for tool_use in tool_uses_for_message.iter() { - let tool_result = self.tool_results.get(&tool_use.id); - - let status = (|| { - if let Some(tool_result) = tool_result { - let content = tool_result - .content - .to_str() - .map(|str| str.to_owned().into()) - .unwrap_or_default(); - - return if tool_result.is_error { - ToolUseStatus::Error(content) - } else { - ToolUseStatus::Finished(content) - }; - } - - if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) { - match pending_tool_use.status { - PendingToolUseStatus::Idle => ToolUseStatus::Pending, - PendingToolUseStatus::NeedsConfirmation { .. } => { - ToolUseStatus::NeedsConfirmation - } - PendingToolUseStatus::Running { .. } => ToolUseStatus::Running, - PendingToolUseStatus::Error(ref err) => { - ToolUseStatus::Error(err.clone().into()) - } - PendingToolUseStatus::InputStillStreaming => { - ToolUseStatus::InputStillStreaming - } - } - } else { - ToolUseStatus::Pending - } - })(); - - let (icon, needs_confirmation) = - if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) { - ( - tool.icon(), - tool.needs_confirmation(&tool_use.input, project, cx), - ) - } else { - (IconName::Cog, false) - }; - - tool_uses.push(ToolUse { - id: tool_use.id.clone(), - name: tool_use.name.clone().into(), - ui_text: self.tool_ui_label( - &tool_use.name, - &tool_use.input, - tool_use.is_input_complete, - cx, - ), - input: tool_use.input.clone(), - status, - icon, - needs_confirmation, - }) - } - - tool_uses - } - - pub fn tool_ui_label( - &self, - tool_name: &str, - input: &serde_json::Value, - is_input_complete: bool, - cx: &App, - ) -> SharedString { - if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) { - if is_input_complete { - tool.ui_text(input).into() - } else { - tool.still_streaming_ui_text(input).into() - } - } else { - format!("Unknown tool {tool_name:?}").into() - } - } - - pub fn tool_results_for_message( - &self, - assistant_message_id: MessageId, - ) -> Vec<&LanguageModelToolResult> { - let Some(tool_uses) = self - .tool_uses_by_assistant_message - .get(&assistant_message_id) - else { - return Vec::new(); - }; - - tool_uses - .iter() - .filter_map(|tool_use| self.tool_results.get(&tool_use.id)) - .collect() - } - - pub fn message_has_tool_results(&self, assistant_message_id: MessageId) -> bool { - self.tool_uses_by_assistant_message - .get(&assistant_message_id) - .is_some_and(|results| !results.is_empty()) - } - - pub fn tool_result( - &self, - tool_use_id: &LanguageModelToolUseId, - ) -> Option<&LanguageModelToolResult> { - self.tool_results.get(tool_use_id) - } - - pub fn tool_result_card(&self, tool_use_id: &LanguageModelToolUseId) -> Option<&AnyToolCard> { - self.tool_result_cards.get(tool_use_id) - } - - pub fn insert_tool_result_card( - &mut self, - tool_use_id: LanguageModelToolUseId, - card: AnyToolCard, - ) { - self.tool_result_cards.insert(tool_use_id, card); - } - - pub fn request_tool_use( - &mut self, - assistant_message_id: MessageId, - tool_use: LanguageModelToolUse, - metadata: ToolUseMetadata, - cx: &App, - ) -> Arc { - let tool_uses = self - .tool_uses_by_assistant_message - .entry(assistant_message_id) - .or_default(); - - let mut existing_tool_use_found = false; - - for existing_tool_use in tool_uses.iter_mut() { - if existing_tool_use.id == tool_use.id { - *existing_tool_use = tool_use.clone(); - existing_tool_use_found = true; - } - } - - if !existing_tool_use_found { - tool_uses.push(tool_use.clone()); - } - - let status = if tool_use.is_input_complete { - self.tool_use_metadata_by_id - .insert(tool_use.id.clone(), metadata); - - PendingToolUseStatus::Idle - } else { - PendingToolUseStatus::InputStillStreaming - }; - - let ui_text: Arc = self - .tool_ui_label( - &tool_use.name, - &tool_use.input, - tool_use.is_input_complete, - cx, - ) - .into(); - - let may_perform_edits = self - .tools - .read(cx) - .tool(&tool_use.name, cx) - .is_some_and(|tool| tool.may_perform_edits()); - - self.pending_tool_uses_by_id.insert( - tool_use.id.clone(), - PendingToolUse { - assistant_message_id, - id: tool_use.id, - name: tool_use.name.clone(), - ui_text: ui_text.clone(), - input: tool_use.input, - may_perform_edits, - status, - }, - ); - - ui_text - } - - pub fn run_pending_tool( - &mut self, - tool_use_id: LanguageModelToolUseId, - ui_text: SharedString, - task: Task<()>, - ) { - if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { - tool_use.ui_text = ui_text.into(); - tool_use.status = PendingToolUseStatus::Running { - _task: task.shared(), - }; - } - } - - pub fn confirm_tool_use( - &mut self, - tool_use_id: LanguageModelToolUseId, - ui_text: impl Into>, - input: serde_json::Value, - request: Arc, - tool: Arc, - ) { - if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { - let ui_text = ui_text.into(); - tool_use.ui_text = ui_text.clone(); - let confirmation = Confirmation { - tool_use_id, - input, - request, - tool, - ui_text, - }; - tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation)); - } - } - - pub fn insert_tool_output( - &mut self, - tool_use_id: LanguageModelToolUseId, - tool_name: Arc, - output: Result, - configured_model: Option<&ConfiguredModel>, - completion_mode: CompletionMode, - ) -> Option { - let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id); - - telemetry::event!( - "Agent Tool Finished", - model = metadata - .as_ref() - .map(|metadata| metadata.model.telemetry_id()), - model_provider = metadata - .as_ref() - .map(|metadata| metadata.model.provider_id().to_string()), - thread_id = metadata.as_ref().map(|metadata| metadata.thread_id.clone()), - prompt_id = metadata.as_ref().map(|metadata| metadata.prompt_id.clone()), - tool_name, - success = output.is_ok() - ); - - match output { - Ok(output) => { - let tool_result = output.content; - const BYTES_PER_TOKEN_ESTIMATE: usize = 3; - - let old_use = self.pending_tool_uses_by_id.remove(&tool_use_id); - - // Protect from overly large output - let tool_output_limit = configured_model - .map(|model| { - model.model.max_token_count_for_mode(completion_mode.into()) as usize - * BYTES_PER_TOKEN_ESTIMATE - }) - .unwrap_or(usize::MAX); - - let content = match tool_result { - ToolResultContent::Text(text) => { - let text = if text.len() < tool_output_limit { - text - } else { - let truncated = truncate_lines_to_byte_limit(&text, tool_output_limit); - format!( - "Tool result too long. The first {} bytes:\n\n{}", - truncated.len(), - truncated - ) - }; - LanguageModelToolResultContent::Text(text.into()) - } - ToolResultContent::Image(language_model_image) => { - if language_model_image.estimate_tokens() < tool_output_limit { - LanguageModelToolResultContent::Image(language_model_image) - } else { - self.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name, - content: "Tool responded with an image that would exceeded the remaining tokens".into(), - is_error: true, - output: None, - }, - ); - - return old_use; - } - } - }; - - self.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name, - content, - is_error: false, - output: output.output, - }, - ); - - old_use - } - Err(err) => { - self.tool_results.insert( - tool_use_id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use_id.clone(), - tool_name, - content: LanguageModelToolResultContent::Text(err.to_string().into()), - is_error: true, - output: None, - }, - ); - - if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { - tool_use.status = PendingToolUseStatus::Error(err.to_string().into()); - } - - self.pending_tool_uses_by_id.get(&tool_use_id).cloned() - } - } - } - - pub fn has_tool_results(&self, assistant_message_id: MessageId) -> bool { - self.tool_uses_by_assistant_message - .contains_key(&assistant_message_id) - } - - pub fn tool_results( - &self, - assistant_message_id: MessageId, - ) -> impl Iterator)> { - self.tool_uses_by_assistant_message - .get(&assistant_message_id) - .into_iter() - .flatten() - .map(|tool_use| (tool_use, self.tool_results.get(&tool_use.id))) - } -} - -#[derive(Debug, Clone)] -pub struct PendingToolUse { - pub id: LanguageModelToolUseId, - /// The ID of the Assistant message in which the tool use was requested. - #[allow(unused)] - pub assistant_message_id: MessageId, - pub name: Arc, - pub ui_text: Arc, - pub input: serde_json::Value, - pub status: PendingToolUseStatus, - pub may_perform_edits: bool, -} - -#[derive(Debug, Clone)] -pub struct Confirmation { - pub tool_use_id: LanguageModelToolUseId, - pub input: serde_json::Value, - pub ui_text: Arc, - pub request: Arc, - pub tool: Arc, -} - -#[derive(Debug, Clone)] -pub enum PendingToolUseStatus { - InputStillStreaming, - Idle, - NeedsConfirmation(Arc), - Running { _task: Shared> }, - Error(#[allow(unused)] Arc), -} - -impl PendingToolUseStatus { - pub fn is_idle(&self) -> bool { - matches!(self, PendingToolUseStatus::Idle) - } - - pub fn is_error(&self) -> bool { - matches!(self, PendingToolUseStatus::Error(_)) - } - - pub fn needs_confirmation(&self) -> bool { - matches!(self, PendingToolUseStatus::NeedsConfirmation { .. }) - } -} - -#[derive(Clone)] -pub struct ToolUseMetadata { - pub model: Arc, - pub thread_id: ThreadId, - pub prompt_id: PromptId, -} diff --git a/crates/agent/src/tools.rs b/crates/agent/src/tools.rs new file mode 100644 index 0000000000000000000000000000000000000000..831efcad8f154de9aac19d9fd587fafb345d1aad --- /dev/null +++ b/crates/agent/src/tools.rs @@ -0,0 +1,88 @@ +mod context_server_registry; +mod copy_path_tool; +mod create_directory_tool; +mod delete_path_tool; +mod diagnostics_tool; +mod edit_file_tool; +mod fetch_tool; +mod find_path_tool; +mod grep_tool; +mod list_directory_tool; +mod move_path_tool; +mod now_tool; +mod open_tool; +mod read_file_tool; +mod terminal_tool; +mod thinking_tool; +mod web_search_tool; + +use crate::AgentTool; +use language_model::{LanguageModelRequestTool, LanguageModelToolSchemaFormat}; + +pub use context_server_registry::*; +pub use copy_path_tool::*; +pub use create_directory_tool::*; +pub use delete_path_tool::*; +pub use diagnostics_tool::*; +pub use edit_file_tool::*; +pub use fetch_tool::*; +pub use find_path_tool::*; +pub use grep_tool::*; +pub use list_directory_tool::*; +pub use move_path_tool::*; +pub use now_tool::*; +pub use open_tool::*; +pub use read_file_tool::*; +pub use terminal_tool::*; +pub use thinking_tool::*; +pub use web_search_tool::*; + +macro_rules! tools { + ($($tool:ty),* $(,)?) => { + /// A list of all built-in tool names + pub fn built_in_tool_names() -> impl Iterator { + [ + $( + <$tool>::name().to_string(), + )* + ] + .into_iter() + } + + /// A list of all built-in tools + pub fn built_in_tools() -> impl Iterator { + fn language_model_tool() -> LanguageModelRequestTool { + LanguageModelRequestTool { + name: T::name().to_string(), + description: T::description().to_string(), + input_schema: T::input_schema(LanguageModelToolSchemaFormat::JsonSchema).to_value(), + } + } + [ + $( + language_model_tool::<$tool>(), + )* + ] + .into_iter() + } + }; +} + +tools! { + CopyPathTool, + CreateDirectoryTool, + DeletePathTool, + DiagnosticsTool, + EditFileTool, + FetchTool, + FindPathTool, + GrepTool, + ListDirectoryTool, + MovePathTool, + NowTool, + OpenTool, + ReadFileTool, + TerminalTool, + ThinkingTool, + WebSearchTool, +} diff --git a/crates/agent2/src/tools/context_server_registry.rs b/crates/agent/src/tools/context_server_registry.rs similarity index 95% rename from crates/agent2/src/tools/context_server_registry.rs rename to crates/agent/src/tools/context_server_registry.rs index 46fa0298044de017464dc1a2e5bd21bf57c1bfcf..382d2ba9be74b4518de853037c858fd054366d5d 100644 --- a/crates/agent2/src/tools/context_server_registry.rs +++ b/crates/agent/src/tools/context_server_registry.rs @@ -32,6 +32,17 @@ impl ContextServerRegistry { this } + pub fn tools_for_server( + &self, + server_id: &ContextServerId, + ) -> impl Iterator> { + self.registered_servers + .get(server_id) + .map(|server| server.tools.values()) + .into_iter() + .flatten() + } + pub fn servers( &self, ) -> impl Iterator< @@ -154,7 +165,7 @@ impl AnyAgentTool for ContextServerTool { format: language_model::LanguageModelToolSchemaFormat, ) -> Result { let mut schema = self.tool.input_schema.clone(); - assistant_tool::adapt_schema_to_format(&mut schema, format)?; + crate::tool_schema::adapt_schema_to_format(&mut schema, format)?; Ok(match schema { serde_json::Value::Null => { serde_json::json!({ "type": "object", "properties": [] }) diff --git a/crates/agent2/src/tools/copy_path_tool.rs b/crates/agent/src/tools/copy_path_tool.rs similarity index 100% rename from crates/agent2/src/tools/copy_path_tool.rs rename to crates/agent/src/tools/copy_path_tool.rs diff --git a/crates/agent2/src/tools/create_directory_tool.rs b/crates/agent/src/tools/create_directory_tool.rs similarity index 100% rename from crates/agent2/src/tools/create_directory_tool.rs rename to crates/agent/src/tools/create_directory_tool.rs diff --git a/crates/agent2/src/tools/delete_path_tool.rs b/crates/agent/src/tools/delete_path_tool.rs similarity index 100% rename from crates/agent2/src/tools/delete_path_tool.rs rename to crates/agent/src/tools/delete_path_tool.rs diff --git a/crates/agent2/src/tools/diagnostics_tool.rs b/crates/agent/src/tools/diagnostics_tool.rs similarity index 100% rename from crates/agent2/src/tools/diagnostics_tool.rs rename to crates/agent/src/tools/diagnostics_tool.rs diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs similarity index 98% rename from crates/agent2/src/tools/edit_file_tool.rs rename to crates/agent/src/tools/edit_file_tool.rs index 90bb68979439b92eef685a500a81147fde8099d6..0adff2dee3571f09b40ee69896c05e50c56b51b9 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -1,8 +1,10 @@ -use crate::{AgentTool, Thread, ToolCallEventStream}; +use crate::{ + AgentTool, Templates, Thread, ToolCallEventStream, + edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}, +}; use acp_thread::Diff; use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields}; use anyhow::{Context as _, Result, anyhow}; -use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}; use cloud_llm_client::CompletionIntent; use collections::HashSet; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; @@ -34,7 +36,7 @@ const DEFAULT_UI_TEXT: &str = "Editing file"; /// /// 2. Verify the directory path is correct (only applicable when creating new files): /// - Use the `list_directory` tool to verify the parent directory exists and is the correct location -#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct EditFileToolInput { /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit. /// @@ -75,7 +77,7 @@ pub struct EditFileToolInput { pub mode: EditFileMode, } -#[derive(Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] struct EditFileToolPartialInput { #[serde(default)] path: String, @@ -123,6 +125,7 @@ pub struct EditFileTool { thread: WeakEntity, language_registry: Arc, project: Entity, + templates: Arc, } impl EditFileTool { @@ -130,11 +133,13 @@ impl EditFileTool { project: Entity, thread: WeakEntity, language_registry: Arc, + templates: Arc, ) -> Self { Self { project, thread, language_registry, + templates, } } @@ -294,8 +299,7 @@ impl AgentTool for EditFileTool { model, project.clone(), action_log.clone(), - // TODO: move edit agent to this crate so we can use our templates - assistant_tools::templates::Templates::new(), + self.templates.clone(), edit_format, ); @@ -599,6 +603,7 @@ mod tests { project, thread.downgrade(), language_registry, + Templates::new(), )) .run(input, ToolCallEventStream::test().0, cx) }) @@ -807,6 +812,7 @@ mod tests { project.clone(), thread.downgrade(), language_registry.clone(), + Templates::new(), )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -865,6 +871,7 @@ mod tests { project.clone(), thread.downgrade(), language_registry, + Templates::new(), )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -951,6 +958,7 @@ mod tests { project.clone(), thread.downgrade(), language_registry.clone(), + Templates::new(), )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -1005,6 +1013,7 @@ mod tests { project.clone(), thread.downgrade(), language_registry, + Templates::new(), )) .run(input, ToolCallEventStream::test().0, cx) }); @@ -1057,6 +1066,7 @@ mod tests { project.clone(), thread.downgrade(), language_registry, + Templates::new(), )); fs.insert_tree("/root", json!({})).await; @@ -1197,6 +1207,7 @@ mod tests { project.clone(), thread.downgrade(), language_registry, + Templates::new(), )); // Test global config paths - these should require confirmation if they exist and are outside the project @@ -1309,6 +1320,7 @@ mod tests { project.clone(), thread.downgrade(), language_registry, + Templates::new(), )); // Test files in different worktrees @@ -1393,6 +1405,7 @@ mod tests { project.clone(), thread.downgrade(), language_registry, + Templates::new(), )); // Test edge cases @@ -1482,6 +1495,7 @@ mod tests { project.clone(), thread.downgrade(), language_registry, + Templates::new(), )); // Test different EditFileMode values @@ -1566,6 +1580,7 @@ mod tests { project, thread.downgrade(), language_registry, + Templates::new(), )); cx.update(|cx| { @@ -1653,6 +1668,7 @@ mod tests { project.clone(), thread.downgrade(), languages.clone(), + Templates::new(), )); let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { @@ -1682,6 +1698,7 @@ mod tests { project.clone(), thread.downgrade(), languages.clone(), + Templates::new(), )); let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { @@ -1709,6 +1726,7 @@ mod tests { project.clone(), thread.downgrade(), languages.clone(), + Templates::new(), )); let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { diff --git a/crates/agent2/src/tools/fetch_tool.rs b/crates/agent/src/tools/fetch_tool.rs similarity index 100% rename from crates/agent2/src/tools/fetch_tool.rs rename to crates/agent/src/tools/fetch_tool.rs diff --git a/crates/agent2/src/tools/find_path_tool.rs b/crates/agent/src/tools/find_path_tool.rs similarity index 100% rename from crates/agent2/src/tools/find_path_tool.rs rename to crates/agent/src/tools/find_path_tool.rs diff --git a/crates/agent2/src/tools/grep_tool.rs b/crates/agent/src/tools/grep_tool.rs similarity index 100% rename from crates/agent2/src/tools/grep_tool.rs rename to crates/agent/src/tools/grep_tool.rs diff --git a/crates/agent2/src/tools/list_directory_tool.rs b/crates/agent/src/tools/list_directory_tool.rs similarity index 100% rename from crates/agent2/src/tools/list_directory_tool.rs rename to crates/agent/src/tools/list_directory_tool.rs diff --git a/crates/agent2/src/tools/move_path_tool.rs b/crates/agent/src/tools/move_path_tool.rs similarity index 100% rename from crates/agent2/src/tools/move_path_tool.rs rename to crates/agent/src/tools/move_path_tool.rs diff --git a/crates/agent2/src/tools/now_tool.rs b/crates/agent/src/tools/now_tool.rs similarity index 100% rename from crates/agent2/src/tools/now_tool.rs rename to crates/agent/src/tools/now_tool.rs diff --git a/crates/agent2/src/tools/open_tool.rs b/crates/agent/src/tools/open_tool.rs similarity index 100% rename from crates/agent2/src/tools/open_tool.rs rename to crates/agent/src/tools/open_tool.rs diff --git a/crates/agent2/src/tools/read_file_tool.rs b/crates/agent/src/tools/read_file_tool.rs similarity index 99% rename from crates/agent2/src/tools/read_file_tool.rs rename to crates/agent/src/tools/read_file_tool.rs index ce8dcba10236aa194e8b30d3fe6855d8c5fa5148..f3ce8e35f2856a3dd53770eef48ec1091fe9b116 100644 --- a/crates/agent2/src/tools/read_file_tool.rs +++ b/crates/agent/src/tools/read_file_tool.rs @@ -1,7 +1,6 @@ use action_log::ActionLog; use agent_client_protocol::{self as acp, ToolCallUpdateFields}; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::outline; use gpui::{App, Entity, SharedString, Task}; use indoc::formatdoc; use language::Point; @@ -13,7 +12,7 @@ use settings::Settings; use std::sync::Arc; use util::markdown::MarkdownCodeBlock; -use crate::{AgentTool, ToolCallEventStream}; +use crate::{AgentTool, ToolCallEventStream, outline}; /// Reads the content of the given file in the project. /// diff --git a/crates/agent2/src/tools/terminal_tool.rs b/crates/agent/src/tools/terminal_tool.rs similarity index 100% rename from crates/agent2/src/tools/terminal_tool.rs rename to crates/agent/src/tools/terminal_tool.rs diff --git a/crates/agent2/src/tools/thinking_tool.rs b/crates/agent/src/tools/thinking_tool.rs similarity index 100% rename from crates/agent2/src/tools/thinking_tool.rs rename to crates/agent/src/tools/thinking_tool.rs diff --git a/crates/agent2/src/tools/web_search_tool.rs b/crates/agent/src/tools/web_search_tool.rs similarity index 100% rename from crates/agent2/src/tools/web_search_tool.rs rename to crates/agent/src/tools/web_search_tool.rs diff --git a/crates/agent2/Cargo.toml b/crates/agent2/Cargo.toml deleted file mode 100644 index b712bed258dfb69ddf81a1ba431ec7a3566b9baf..0000000000000000000000000000000000000000 --- a/crates/agent2/Cargo.toml +++ /dev/null @@ -1,102 +0,0 @@ -[package] -name = "agent2" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lib] -path = "src/agent2.rs" - -[features] -test-support = ["db/test-support"] -e2e = [] - -[lints] -workspace = true - -[dependencies] -acp_thread.workspace = true -action_log.workspace = true -agent.workspace = true -agent-client-protocol.workspace = true -agent_servers.workspace = true -agent_settings.workspace = true -anyhow.workspace = true -assistant_context.workspace = true -assistant_tool.workspace = true -assistant_tools.workspace = true -chrono.workspace = true -client.workspace = true -cloud_llm_client.workspace = true -collections.workspace = true -context_server.workspace = true -db.workspace = true -fs.workspace = true -futures.workspace = true -git.workspace = true -gpui.workspace = true -handlebars = { workspace = true, features = ["rust-embed"] } -html_to_markdown.workspace = true -http_client.workspace = true -indoc.workspace = true -itertools.workspace = true -language.workspace = true -language_model.workspace = true -language_models.workspace = true -log.workspace = true -open.workspace = true -parking_lot.workspace = true -paths.workspace = true -project.workspace = true -prompt_store.workspace = true -rust-embed.workspace = true -schemars.workspace = true -serde.workspace = true -serde_json.workspace = true -settings.workspace = true -smol.workspace = true -sqlez.workspace = true -task.workspace = true -telemetry.workspace = true -terminal.workspace = true -thiserror.workspace = true -text.workspace = true -ui.workspace = true -util.workspace = true -uuid.workspace = true -watch.workspace = true -web_search.workspace = true -workspace-hack.workspace = true -zed_env_vars.workspace = true -zstd.workspace = true - -[dev-dependencies] -agent = { workspace = true, "features" = ["test-support"] } -agent_servers = { workspace = true, "features" = ["test-support"] } -assistant_context = { workspace = true, "features" = ["test-support"] } -ctor.workspace = true -client = { workspace = true, "features" = ["test-support"] } -clock = { workspace = true, "features" = ["test-support"] } -context_server = { workspace = true, "features" = ["test-support"] } -db = { workspace = true, "features" = ["test-support"] } -editor = { workspace = true, "features" = ["test-support"] } -env_logger.workspace = true -fs = { workspace = true, "features" = ["test-support"] } -git = { workspace = true, "features" = ["test-support"] } -gpui = { workspace = true, "features" = ["test-support"] } -gpui_tokio.workspace = true -language = { workspace = true, "features" = ["test-support"] } -language_model = { workspace = true, "features" = ["test-support"] } -lsp = { workspace = true, "features" = ["test-support"] } -pretty_assertions.workspace = true -project = { workspace = true, "features" = ["test-support"] } -reqwest_client.workspace = true -settings = { workspace = true, "features" = ["test-support"] } -tempfile.workspace = true -terminal = { workspace = true, "features" = ["test-support"] } -theme = { workspace = true, "features" = ["test-support"] } -tree-sitter-rust.workspace = true -unindent = { workspace = true } -worktree = { workspace = true, "features" = ["test-support"] } -zlog.workspace = true diff --git a/crates/agent2/LICENSE-GPL b/crates/agent2/LICENSE-GPL deleted file mode 120000 index 89e542f750cd3860a0598eff0dc34b56d7336dc4..0000000000000000000000000000000000000000 --- a/crates/agent2/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs deleted file mode 100644 index bf1fe8b5bb72038e197eafc842ca02e417b9e7c3..0000000000000000000000000000000000000000 --- a/crates/agent2/src/agent.rs +++ /dev/null @@ -1,1588 +0,0 @@ -use crate::{ - ContextServerRegistry, Thread, ThreadEvent, ThreadsDatabase, ToolCallAuthorization, - UserMessageContent, templates::Templates, -}; -use crate::{HistoryStore, TerminalHandle, ThreadEnvironment, TitleUpdated, TokenUsageUpdated}; -use acp_thread::{AcpThread, AgentModelSelector}; -use action_log::ActionLog; -use agent_client_protocol as acp; -use anyhow::{Context as _, Result, anyhow}; -use collections::{HashSet, IndexMap}; -use fs::Fs; -use futures::channel::{mpsc, oneshot}; -use futures::future::Shared; -use futures::{StreamExt, future}; -use gpui::{ - App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity, -}; -use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry}; -use project::{Project, ProjectItem, ProjectPath, Worktree}; -use prompt_store::{ - ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext, -}; -use settings::{LanguageModelSelection, update_settings_file}; -use std::any::Any; -use std::collections::HashMap; -use std::path::{Path, PathBuf}; -use std::rc::Rc; -use std::sync::Arc; -use util::ResultExt; -use util::rel_path::RelPath; - -const RULES_FILE_NAMES: [&str; 9] = [ - ".rules", - ".cursorrules", - ".windsurfrules", - ".clinerules", - ".github/copilot-instructions.md", - "CLAUDE.md", - "AGENT.md", - "AGENTS.md", - "GEMINI.md", -]; - -pub struct RulesLoadingError { - pub message: SharedString, -} - -/// Holds both the internal Thread and the AcpThread for a session -struct Session { - /// The internal thread that processes messages - thread: Entity, - /// The ACP thread that handles protocol communication - acp_thread: WeakEntity, - pending_save: Task<()>, - _subscriptions: Vec, -} - -pub struct LanguageModels { - /// Access language model by ID - models: HashMap>, - /// Cached list for returning language model information - model_list: acp_thread::AgentModelList, - refresh_models_rx: watch::Receiver<()>, - refresh_models_tx: watch::Sender<()>, - _authenticate_all_providers_task: Task<()>, -} - -impl LanguageModels { - fn new(cx: &mut App) -> Self { - let (refresh_models_tx, refresh_models_rx) = watch::channel(()); - - let mut this = Self { - models: HashMap::default(), - model_list: acp_thread::AgentModelList::Grouped(IndexMap::default()), - refresh_models_rx, - refresh_models_tx, - _authenticate_all_providers_task: Self::authenticate_all_language_model_providers(cx), - }; - this.refresh_list(cx); - this - } - - fn refresh_list(&mut self, cx: &App) { - let providers = LanguageModelRegistry::global(cx) - .read(cx) - .providers() - .into_iter() - .filter(|provider| provider.is_authenticated(cx)) - .collect::>(); - - let mut language_model_list = IndexMap::default(); - let mut recommended_models = HashSet::default(); - - let mut recommended = Vec::new(); - for provider in &providers { - for model in provider.recommended_models(cx) { - recommended_models.insert((model.provider_id(), model.id())); - recommended.push(Self::map_language_model_to_info(&model, provider)); - } - } - if !recommended.is_empty() { - language_model_list.insert( - acp_thread::AgentModelGroupName("Recommended".into()), - recommended, - ); - } - - let mut models = HashMap::default(); - for provider in providers { - let mut provider_models = Vec::new(); - for model in provider.provided_models(cx) { - let model_info = Self::map_language_model_to_info(&model, &provider); - let model_id = model_info.id.clone(); - if !recommended_models.contains(&(model.provider_id(), model.id())) { - provider_models.push(model_info); - } - models.insert(model_id, model); - } - if !provider_models.is_empty() { - language_model_list.insert( - acp_thread::AgentModelGroupName(provider.name().0.clone()), - provider_models, - ); - } - } - - self.models = models; - self.model_list = acp_thread::AgentModelList::Grouped(language_model_list); - self.refresh_models_tx.send(()).ok(); - } - - fn watch(&self) -> watch::Receiver<()> { - self.refresh_models_rx.clone() - } - - pub fn model_from_id(&self, model_id: &acp::ModelId) -> Option> { - self.models.get(model_id).cloned() - } - - fn map_language_model_to_info( - model: &Arc, - provider: &Arc, - ) -> acp_thread::AgentModelInfo { - acp_thread::AgentModelInfo { - id: Self::model_id(model), - name: model.name().0, - description: None, - icon: Some(provider.icon()), - } - } - - fn model_id(model: &Arc) -> acp::ModelId { - acp::ModelId(format!("{}/{}", model.provider_id().0, model.id().0).into()) - } - - fn authenticate_all_language_model_providers(cx: &mut App) -> Task<()> { - let authenticate_all_providers = LanguageModelRegistry::global(cx) - .read(cx) - .providers() - .iter() - .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx))) - .collect::>(); - - cx.background_spawn(async move { - for (provider_id, provider_name, authenticate_task) in authenticate_all_providers { - if let Err(err) = authenticate_task.await { - match err { - language_model::AuthenticateError::CredentialsNotFound => { - // Since we're authenticating these providers in the - // background for the purposes of populating the - // language selector, we don't care about providers - // where the credentials are not found. - } - language_model::AuthenticateError::ConnectionRefused => { - // Not logging connection refused errors as they are mostly from LM Studio's noisy auth failures. - // LM Studio only has one auth method (endpoint call) which fails for users who haven't enabled it. - // TODO: Better manage LM Studio auth logic to avoid these noisy failures. - } - _ => { - // Some providers have noisy failure states that we - // don't want to spam the logs with every time the - // language model selector is initialized. - // - // Ideally these should have more clear failure modes - // that we know are safe to ignore here, like what we do - // with `CredentialsNotFound` above. - match provider_id.0.as_ref() { - "lmstudio" | "ollama" => { - // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated". - // - // These fail noisily, so we don't log them. - } - "copilot_chat" => { - // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors. - } - _ => { - log::error!( - "Failed to authenticate provider: {}: {err}", - provider_name.0 - ); - } - } - } - } - } - } - }) - } -} - -pub struct NativeAgent { - /// Session ID -> Session mapping - sessions: HashMap, - history: Entity, - /// Shared project context for all threads - project_context: Entity, - project_context_needs_refresh: watch::Sender<()>, - _maintain_project_context: Task>, - context_server_registry: Entity, - /// Shared templates for all threads - templates: Arc, - /// Cached model information - models: LanguageModels, - project: Entity, - prompt_store: Option>, - fs: Arc, - _subscriptions: Vec, -} - -impl NativeAgent { - pub async fn new( - project: Entity, - history: Entity, - templates: Arc, - prompt_store: Option>, - fs: Arc, - cx: &mut AsyncApp, - ) -> Result> { - log::debug!("Creating new NativeAgent"); - - let project_context = cx - .update(|cx| Self::build_project_context(&project, prompt_store.as_ref(), cx))? - .await; - - cx.new(|cx| { - let mut subscriptions = vec![ - cx.subscribe(&project, Self::handle_project_event), - cx.subscribe( - &LanguageModelRegistry::global(cx), - Self::handle_models_updated_event, - ), - ]; - if let Some(prompt_store) = prompt_store.as_ref() { - subscriptions.push(cx.subscribe(prompt_store, Self::handle_prompts_updated_event)) - } - - let (project_context_needs_refresh_tx, project_context_needs_refresh_rx) = - watch::channel(()); - Self { - sessions: HashMap::new(), - history, - project_context: cx.new(|_| project_context), - project_context_needs_refresh: project_context_needs_refresh_tx, - _maintain_project_context: cx.spawn(async move |this, cx| { - Self::maintain_project_context(this, project_context_needs_refresh_rx, cx).await - }), - context_server_registry: cx.new(|cx| { - ContextServerRegistry::new(project.read(cx).context_server_store(), cx) - }), - templates, - models: LanguageModels::new(cx), - project, - prompt_store, - fs, - _subscriptions: subscriptions, - } - }) - } - - fn register_session( - &mut self, - thread_handle: Entity, - cx: &mut Context, - ) -> Entity { - let connection = Rc::new(NativeAgentConnection(cx.entity())); - - let thread = thread_handle.read(cx); - let session_id = thread.id().clone(); - let title = thread.title(); - let project = thread.project.clone(); - let action_log = thread.action_log.clone(); - let prompt_capabilities_rx = thread.prompt_capabilities_rx.clone(); - let acp_thread = cx.new(|cx| { - acp_thread::AcpThread::new( - title, - connection, - project.clone(), - action_log.clone(), - session_id.clone(), - prompt_capabilities_rx, - cx, - ) - }); - - let registry = LanguageModelRegistry::read_global(cx); - let summarization_model = registry.thread_summary_model().map(|c| c.model); - - thread_handle.update(cx, |thread, cx| { - thread.set_summarization_model(summarization_model, cx); - thread.add_default_tools( - Rc::new(AcpThreadEnvironment { - acp_thread: acp_thread.downgrade(), - }) as _, - cx, - ) - }); - - let subscriptions = vec![ - cx.observe_release(&acp_thread, |this, acp_thread, _cx| { - this.sessions.remove(acp_thread.session_id()); - }), - cx.subscribe(&thread_handle, Self::handle_thread_title_updated), - cx.subscribe(&thread_handle, Self::handle_thread_token_usage_updated), - cx.observe(&thread_handle, move |this, thread, cx| { - this.save_thread(thread, cx) - }), - ]; - - self.sessions.insert( - session_id, - Session { - thread: thread_handle, - acp_thread: acp_thread.downgrade(), - _subscriptions: subscriptions, - pending_save: Task::ready(()), - }, - ); - acp_thread - } - - pub fn models(&self) -> &LanguageModels { - &self.models - } - - async fn maintain_project_context( - this: WeakEntity, - mut needs_refresh: watch::Receiver<()>, - cx: &mut AsyncApp, - ) -> Result<()> { - while needs_refresh.changed().await.is_ok() { - let project_context = this - .update(cx, |this, cx| { - Self::build_project_context(&this.project, this.prompt_store.as_ref(), cx) - })? - .await; - this.update(cx, |this, cx| { - this.project_context = cx.new(|_| project_context); - })?; - } - - Ok(()) - } - - fn build_project_context( - project: &Entity, - prompt_store: Option<&Entity>, - cx: &mut App, - ) -> Task { - let worktrees = project.read(cx).visible_worktrees(cx).collect::>(); - let worktree_tasks = worktrees - .into_iter() - .map(|worktree| { - Self::load_worktree_info_for_system_prompt(worktree, project.clone(), cx) - }) - .collect::>(); - let default_user_rules_task = if let Some(prompt_store) = prompt_store.as_ref() { - prompt_store.read_with(cx, |prompt_store, cx| { - let prompts = prompt_store.default_prompt_metadata(); - let load_tasks = prompts.into_iter().map(|prompt_metadata| { - let contents = prompt_store.load(prompt_metadata.id, cx); - async move { (contents.await, prompt_metadata) } - }); - cx.background_spawn(future::join_all(load_tasks)) - }) - } else { - Task::ready(vec![]) - }; - - cx.spawn(async move |_cx| { - let (worktrees, default_user_rules) = - future::join(future::join_all(worktree_tasks), default_user_rules_task).await; - - let worktrees = worktrees - .into_iter() - .map(|(worktree, _rules_error)| { - // TODO: show error message - // if let Some(rules_error) = rules_error { - // this.update(cx, |_, cx| cx.emit(rules_error)).ok(); - // } - worktree - }) - .collect::>(); - - let default_user_rules = default_user_rules - .into_iter() - .flat_map(|(contents, prompt_metadata)| match contents { - Ok(contents) => Some(UserRulesContext { - uuid: match prompt_metadata.id { - PromptId::User { uuid } => uuid, - PromptId::EditWorkflow => return None, - }, - title: prompt_metadata.title.map(|title| title.to_string()), - contents, - }), - Err(_err) => { - // TODO: show error message - // this.update(cx, |_, cx| { - // cx.emit(RulesLoadingError { - // message: format!("{err:?}").into(), - // }); - // }) - // .ok(); - None - } - }) - .collect::>(); - - ProjectContext::new(worktrees, default_user_rules) - }) - } - - fn load_worktree_info_for_system_prompt( - worktree: Entity, - project: Entity, - cx: &mut App, - ) -> Task<(WorktreeContext, Option)> { - let tree = worktree.read(cx); - let root_name = tree.root_name_str().into(); - let abs_path = tree.abs_path(); - - let mut context = WorktreeContext { - root_name, - abs_path, - rules_file: None, - }; - - let rules_task = Self::load_worktree_rules_file(worktree, project, cx); - let Some(rules_task) = rules_task else { - return Task::ready((context, None)); - }; - - cx.spawn(async move |_| { - let (rules_file, rules_file_error) = match rules_task.await { - Ok(rules_file) => (Some(rules_file), None), - Err(err) => ( - None, - Some(RulesLoadingError { - message: format!("{err}").into(), - }), - ), - }; - context.rules_file = rules_file; - (context, rules_file_error) - }) - } - - fn load_worktree_rules_file( - worktree: Entity, - project: Entity, - cx: &mut App, - ) -> Option>> { - let worktree = worktree.read(cx); - let worktree_id = worktree.id(); - let selected_rules_file = RULES_FILE_NAMES - .into_iter() - .filter_map(|name| { - worktree - .entry_for_path(RelPath::unix(name).unwrap()) - .filter(|entry| entry.is_file()) - .map(|entry| entry.path.clone()) - }) - .next(); - - // Note that Cline supports `.clinerules` being a directory, but that is not currently - // supported. This doesn't seem to occur often in GitHub repositories. - selected_rules_file.map(|path_in_worktree| { - let project_path = ProjectPath { - worktree_id, - path: path_in_worktree.clone(), - }; - let buffer_task = - project.update(cx, |project, cx| project.open_buffer(project_path, cx)); - let rope_task = cx.spawn(async move |cx| { - buffer_task.await?.read_with(cx, |buffer, cx| { - let project_entry_id = buffer.entry_id(cx).context("buffer has no file")?; - anyhow::Ok((project_entry_id, buffer.as_rope().clone())) - })? - }); - // Build a string from the rope on a background thread. - cx.background_spawn(async move { - let (project_entry_id, rope) = rope_task.await?; - anyhow::Ok(RulesFileContext { - path_in_worktree, - text: rope.to_string().trim().to_string(), - project_entry_id: project_entry_id.to_usize(), - }) - }) - }) - } - - fn handle_thread_title_updated( - &mut self, - thread: Entity, - _: &TitleUpdated, - cx: &mut Context, - ) { - let session_id = thread.read(cx).id(); - let Some(session) = self.sessions.get(session_id) else { - return; - }; - let thread = thread.downgrade(); - let acp_thread = session.acp_thread.clone(); - cx.spawn(async move |_, cx| { - let title = thread.read_with(cx, |thread, _| thread.title())?; - let task = acp_thread.update(cx, |acp_thread, cx| acp_thread.set_title(title, cx))?; - task.await - }) - .detach_and_log_err(cx); - } - - fn handle_thread_token_usage_updated( - &mut self, - thread: Entity, - usage: &TokenUsageUpdated, - cx: &mut Context, - ) { - let Some(session) = self.sessions.get(thread.read(cx).id()) else { - return; - }; - session - .acp_thread - .update(cx, |acp_thread, cx| { - acp_thread.update_token_usage(usage.0.clone(), cx); - }) - .ok(); - } - - fn handle_project_event( - &mut self, - _project: Entity, - event: &project::Event, - _cx: &mut Context, - ) { - match event { - project::Event::WorktreeAdded(_) | project::Event::WorktreeRemoved(_) => { - self.project_context_needs_refresh.send(()).ok(); - } - project::Event::WorktreeUpdatedEntries(_, items) => { - if items.iter().any(|(path, _, _)| { - RULES_FILE_NAMES - .iter() - .any(|name| path.as_ref() == RelPath::unix(name).unwrap()) - }) { - self.project_context_needs_refresh.send(()).ok(); - } - } - _ => {} - } - } - - fn handle_prompts_updated_event( - &mut self, - _prompt_store: Entity, - _event: &prompt_store::PromptsUpdatedEvent, - _cx: &mut Context, - ) { - self.project_context_needs_refresh.send(()).ok(); - } - - fn handle_models_updated_event( - &mut self, - _registry: Entity, - _event: &language_model::Event, - cx: &mut Context, - ) { - self.models.refresh_list(cx); - - let registry = LanguageModelRegistry::read_global(cx); - let default_model = registry.default_model().map(|m| m.model); - let summarization_model = registry.thread_summary_model().map(|m| m.model); - - for session in self.sessions.values_mut() { - session.thread.update(cx, |thread, cx| { - if thread.model().is_none() - && let Some(model) = default_model.clone() - { - thread.set_model(model, cx); - cx.notify(); - } - thread.set_summarization_model(summarization_model.clone(), cx); - }); - } - } - - pub fn open_thread( - &mut self, - id: acp::SessionId, - cx: &mut Context, - ) -> Task>> { - let database_future = ThreadsDatabase::connect(cx); - cx.spawn(async move |this, cx| { - let database = database_future.await.map_err(|err| anyhow!(err))?; - let db_thread = database - .load_thread(id.clone()) - .await? - .with_context(|| format!("no thread found with ID: {id:?}"))?; - - let thread = this.update(cx, |this, cx| { - let action_log = cx.new(|_cx| ActionLog::new(this.project.clone())); - cx.new(|cx| { - Thread::from_db( - id.clone(), - db_thread, - this.project.clone(), - this.project_context.clone(), - this.context_server_registry.clone(), - action_log.clone(), - this.templates.clone(), - cx, - ) - }) - })?; - let acp_thread = - this.update(cx, |this, cx| this.register_session(thread.clone(), cx))?; - let events = thread.update(cx, |thread, cx| thread.replay(cx))?; - cx.update(|cx| { - NativeAgentConnection::handle_thread_events(events, acp_thread.downgrade(), cx) - })? - .await?; - Ok(acp_thread) - }) - } - - pub fn thread_summary( - &mut self, - id: acp::SessionId, - cx: &mut Context, - ) -> Task> { - let thread = self.open_thread(id.clone(), cx); - cx.spawn(async move |this, cx| { - let acp_thread = thread.await?; - let result = this - .update(cx, |this, cx| { - this.sessions - .get(&id) - .unwrap() - .thread - .update(cx, |thread, cx| thread.summary(cx)) - })? - .await?; - drop(acp_thread); - Ok(result) - }) - } - - fn save_thread(&mut self, thread: Entity, cx: &mut Context) { - if thread.read(cx).is_empty() { - return; - } - - let database_future = ThreadsDatabase::connect(cx); - let (id, db_thread) = - thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx))); - let Some(session) = self.sessions.get_mut(&id) else { - return; - }; - let history = self.history.clone(); - session.pending_save = cx.spawn(async move |_, cx| { - let Some(database) = database_future.await.map_err(|err| anyhow!(err)).log_err() else { - return; - }; - let db_thread = db_thread.await; - database.save_thread(id, db_thread).await.log_err(); - history.update(cx, |history, cx| history.reload(cx)).ok(); - }); - } -} - -/// Wrapper struct that implements the AgentConnection trait -#[derive(Clone)] -pub struct NativeAgentConnection(pub Entity); - -impl NativeAgentConnection { - pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option> { - self.0 - .read(cx) - .sessions - .get(session_id) - .map(|session| session.thread.clone()) - } - - fn run_turn( - &self, - session_id: acp::SessionId, - cx: &mut App, - f: impl 'static - + FnOnce(Entity, &mut App) -> Result>>, - ) -> Task> { - let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| { - agent - .sessions - .get_mut(&session_id) - .map(|s| (s.thread.clone(), s.acp_thread.clone())) - }) else { - return Task::ready(Err(anyhow!("Session not found"))); - }; - log::debug!("Found session for: {}", session_id); - - let response_stream = match f(thread, cx) { - Ok(stream) => stream, - Err(err) => return Task::ready(Err(err)), - }; - Self::handle_thread_events(response_stream, acp_thread, cx) - } - - fn handle_thread_events( - mut events: mpsc::UnboundedReceiver>, - acp_thread: WeakEntity, - cx: &App, - ) -> Task> { - cx.spawn(async move |cx| { - // Handle response stream and forward to session.acp_thread - while let Some(result) = events.next().await { - match result { - Ok(event) => { - log::trace!("Received completion event: {:?}", event); - - match event { - ThreadEvent::UserMessage(message) => { - acp_thread.update(cx, |thread, cx| { - for content in message.content { - thread.push_user_content_block( - Some(message.id.clone()), - content.into(), - cx, - ); - } - })?; - } - ThreadEvent::AgentText(text) => { - acp_thread.update(cx, |thread, cx| { - thread.push_assistant_content_block( - acp::ContentBlock::Text(acp::TextContent { - text, - annotations: None, - meta: None, - }), - false, - cx, - ) - })?; - } - ThreadEvent::AgentThinking(text) => { - acp_thread.update(cx, |thread, cx| { - thread.push_assistant_content_block( - acp::ContentBlock::Text(acp::TextContent { - text, - annotations: None, - meta: None, - }), - true, - cx, - ) - })?; - } - ThreadEvent::ToolCallAuthorization(ToolCallAuthorization { - tool_call, - options, - response, - }) => { - let outcome_task = acp_thread.update(cx, |thread, cx| { - thread.request_tool_call_authorization( - tool_call, options, true, cx, - ) - })??; - cx.background_spawn(async move { - if let acp::RequestPermissionOutcome::Selected { option_id } = - outcome_task.await - { - response - .send(option_id) - .map(|_| anyhow!("authorization receiver was dropped")) - .log_err(); - } - }) - .detach(); - } - ThreadEvent::ToolCall(tool_call) => { - acp_thread.update(cx, |thread, cx| { - thread.upsert_tool_call(tool_call, cx) - })??; - } - ThreadEvent::ToolCallUpdate(update) => { - acp_thread.update(cx, |thread, cx| { - thread.update_tool_call(update, cx) - })??; - } - ThreadEvent::Retry(status) => { - acp_thread.update(cx, |thread, cx| { - thread.update_retry_status(status, cx) - })?; - } - ThreadEvent::Stop(stop_reason) => { - log::debug!("Assistant message complete: {:?}", stop_reason); - return Ok(acp::PromptResponse { - stop_reason, - meta: None, - }); - } - } - } - Err(e) => { - log::error!("Error in model response stream: {:?}", e); - return Err(e); - } - } - } - - log::debug!("Response stream completed"); - anyhow::Ok(acp::PromptResponse { - stop_reason: acp::StopReason::EndTurn, - meta: None, - }) - }) - } -} - -struct NativeAgentModelSelector { - session_id: acp::SessionId, - connection: NativeAgentConnection, -} - -impl acp_thread::AgentModelSelector for NativeAgentModelSelector { - fn list_models(&self, cx: &mut App) -> Task> { - log::debug!("NativeAgentConnection::list_models called"); - let list = self.connection.0.read(cx).models.model_list.clone(); - Task::ready(if list.is_empty() { - Err(anyhow::anyhow!("No models available")) - } else { - Ok(list) - }) - } - - fn select_model(&self, model_id: acp::ModelId, cx: &mut App) -> Task> { - log::debug!( - "Setting model for session {}: {}", - self.session_id, - model_id - ); - let Some(thread) = self - .connection - .0 - .read(cx) - .sessions - .get(&self.session_id) - .map(|session| session.thread.clone()) - else { - return Task::ready(Err(anyhow!("Session not found"))); - }; - - let Some(model) = self.connection.0.read(cx).models.model_from_id(&model_id) else { - return Task::ready(Err(anyhow!("Invalid model ID {}", model_id))); - }; - - thread.update(cx, |thread, cx| { - thread.set_model(model.clone(), cx); - }); - - update_settings_file( - self.connection.0.read(cx).fs.clone(), - cx, - move |settings, _cx| { - let provider = model.provider_id().0.to_string(); - let model = model.id().0.to_string(); - settings - .agent - .get_or_insert_default() - .set_model(LanguageModelSelection { - provider: provider.into(), - model, - }); - }, - ); - - Task::ready(Ok(())) - } - - fn selected_model(&self, cx: &mut App) -> Task> { - let Some(thread) = self - .connection - .0 - .read(cx) - .sessions - .get(&self.session_id) - .map(|session| session.thread.clone()) - else { - return Task::ready(Err(anyhow!("Session not found"))); - }; - let Some(model) = thread.read(cx).model() else { - return Task::ready(Err(anyhow!("Model not found"))); - }; - let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id()) - else { - return Task::ready(Err(anyhow!("Provider not found"))); - }; - Task::ready(Ok(LanguageModels::map_language_model_to_info( - model, &provider, - ))) - } - - fn watch(&self, cx: &mut App) -> Option> { - Some(self.connection.0.read(cx).models.watch()) - } -} - -impl acp_thread::AgentConnection for NativeAgentConnection { - fn new_thread( - self: Rc, - project: Entity, - cwd: &Path, - cx: &mut App, - ) -> Task>> { - let agent = self.0.clone(); - log::debug!("Creating new thread for project at: {:?}", cwd); - - cx.spawn(async move |cx| { - log::debug!("Starting thread creation in async context"); - - // Create Thread - let thread = agent.update( - cx, - |agent, cx: &mut gpui::Context| -> Result<_> { - // Fetch default model from registry settings - let registry = LanguageModelRegistry::read_global(cx); - // Log available models for debugging - let available_count = registry.available_models(cx).count(); - log::debug!("Total available models: {}", available_count); - - let default_model = registry.default_model().and_then(|default_model| { - agent - .models - .model_from_id(&LanguageModels::model_id(&default_model.model)) - }); - Ok(cx.new(|cx| { - Thread::new( - project.clone(), - agent.project_context.clone(), - agent.context_server_registry.clone(), - agent.templates.clone(), - default_model, - cx, - ) - })) - }, - )??; - agent.update(cx, |agent, cx| agent.register_session(thread, cx)) - }) - } - - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] // No auth for in-process - } - - fn authenticate(&self, _method: acp::AuthMethodId, _cx: &mut App) -> Task> { - Task::ready(Ok(())) - } - - fn model_selector(&self, session_id: &acp::SessionId) -> Option> { - Some(Rc::new(NativeAgentModelSelector { - session_id: session_id.clone(), - connection: self.clone(), - }) as Rc) - } - - fn prompt( - &self, - id: Option, - params: acp::PromptRequest, - cx: &mut App, - ) -> Task> { - let id = id.expect("UserMessageId is required"); - let session_id = params.session_id.clone(); - log::info!("Received prompt request for session: {}", session_id); - log::debug!("Prompt blocks count: {}", params.prompt.len()); - - self.run_turn(session_id, cx, |thread, cx| { - let content: Vec = params - .prompt - .into_iter() - .map(Into::into) - .collect::>(); - log::debug!("Converted prompt to message: {} chars", content.len()); - log::debug!("Message id: {:?}", id); - log::debug!("Message content: {:?}", content); - - thread.update(cx, |thread, cx| thread.send(id, content, cx)) - }) - } - - fn resume( - &self, - session_id: &acp::SessionId, - _cx: &App, - ) -> Option> { - Some(Rc::new(NativeAgentSessionResume { - connection: self.clone(), - session_id: session_id.clone(), - }) as _) - } - - fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { - log::info!("Cancelling on session: {}", session_id); - self.0.update(cx, |agent, cx| { - if let Some(agent) = agent.sessions.get(session_id) { - agent.thread.update(cx, |thread, cx| thread.cancel(cx)); - } - }); - } - - fn truncate( - &self, - session_id: &agent_client_protocol::SessionId, - cx: &App, - ) -> Option> { - self.0.read_with(cx, |agent, _cx| { - agent.sessions.get(session_id).map(|session| { - Rc::new(NativeAgentSessionTruncate { - thread: session.thread.clone(), - acp_thread: session.acp_thread.clone(), - }) as _ - }) - }) - } - - fn set_title( - &self, - session_id: &acp::SessionId, - _cx: &App, - ) -> Option> { - Some(Rc::new(NativeAgentSessionSetTitle { - connection: self.clone(), - session_id: session_id.clone(), - }) as _) - } - - fn telemetry(&self) -> Option> { - Some(Rc::new(self.clone()) as Rc) - } - - fn into_any(self: Rc) -> Rc { - self - } -} - -impl acp_thread::AgentTelemetry for NativeAgentConnection { - fn agent_name(&self) -> String { - "Zed".into() - } - - fn thread_data( - &self, - session_id: &acp::SessionId, - cx: &mut App, - ) -> Task> { - let Some(session) = self.0.read(cx).sessions.get(session_id) else { - return Task::ready(Err(anyhow!("Session not found"))); - }; - - let task = session.thread.read(cx).to_db(cx); - cx.background_spawn(async move { - serde_json::to_value(task.await).context("Failed to serialize thread") - }) - } -} - -struct NativeAgentSessionTruncate { - thread: Entity, - acp_thread: WeakEntity, -} - -impl acp_thread::AgentSessionTruncate for NativeAgentSessionTruncate { - fn run(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task> { - match self.thread.update(cx, |thread, cx| { - thread.truncate(message_id.clone(), cx)?; - Ok(thread.latest_token_usage()) - }) { - Ok(usage) => { - self.acp_thread - .update(cx, |thread, cx| { - thread.update_token_usage(usage, cx); - }) - .ok(); - Task::ready(Ok(())) - } - Err(error) => Task::ready(Err(error)), - } - } -} - -struct NativeAgentSessionResume { - connection: NativeAgentConnection, - session_id: acp::SessionId, -} - -impl acp_thread::AgentSessionResume for NativeAgentSessionResume { - fn run(&self, cx: &mut App) -> Task> { - self.connection - .run_turn(self.session_id.clone(), cx, |thread, cx| { - thread.update(cx, |thread, cx| thread.resume(cx)) - }) - } -} - -struct NativeAgentSessionSetTitle { - connection: NativeAgentConnection, - session_id: acp::SessionId, -} - -impl acp_thread::AgentSessionSetTitle for NativeAgentSessionSetTitle { - fn run(&self, title: SharedString, cx: &mut App) -> Task> { - let Some(session) = self.connection.0.read(cx).sessions.get(&self.session_id) else { - return Task::ready(Err(anyhow!("session not found"))); - }; - let thread = session.thread.clone(); - thread.update(cx, |thread, cx| thread.set_title(title, cx)); - Task::ready(Ok(())) - } -} - -pub struct AcpThreadEnvironment { - acp_thread: WeakEntity, -} - -impl ThreadEnvironment for AcpThreadEnvironment { - fn create_terminal( - &self, - command: String, - cwd: Option, - output_byte_limit: Option, - cx: &mut AsyncApp, - ) -> Task>> { - let task = self.acp_thread.update(cx, |thread, cx| { - thread.create_terminal(command, vec![], vec![], cwd, output_byte_limit, cx) - }); - - let acp_thread = self.acp_thread.clone(); - cx.spawn(async move |cx| { - let terminal = task?.await?; - - let (drop_tx, drop_rx) = oneshot::channel(); - let terminal_id = terminal.read_with(cx, |terminal, _cx| terminal.id().clone())?; - - cx.spawn(async move |cx| { - drop_rx.await.ok(); - acp_thread.update(cx, |thread, cx| thread.release_terminal(terminal_id, cx)) - }) - .detach(); - - let handle = AcpTerminalHandle { - terminal, - _drop_tx: Some(drop_tx), - }; - - Ok(Rc::new(handle) as _) - }) - } -} - -pub struct AcpTerminalHandle { - terminal: Entity, - _drop_tx: Option>, -} - -impl TerminalHandle for AcpTerminalHandle { - fn id(&self, cx: &AsyncApp) -> Result { - self.terminal.read_with(cx, |term, _cx| term.id().clone()) - } - - fn wait_for_exit(&self, cx: &AsyncApp) -> Result>> { - self.terminal - .read_with(cx, |term, _cx| term.wait_for_exit()) - } - - fn current_output(&self, cx: &AsyncApp) -> Result { - self.terminal - .read_with(cx, |term, cx| term.current_output(cx)) - } -} - -#[cfg(test)] -mod tests { - use crate::HistoryEntryId; - - use super::*; - use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelInfo, MentionUri}; - use fs::FakeFs; - use gpui::TestAppContext; - use indoc::formatdoc; - use language_model::fake_provider::FakeLanguageModel; - use serde_json::json; - use settings::SettingsStore; - use util::{path, rel_path::rel_path}; - - #[gpui::test] - async fn test_maintaining_project_context(cx: &mut TestAppContext) { - init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/", - json!({ - "a": {} - }), - ) - .await; - let project = Project::test(fs.clone(), [], cx).await; - let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); - let agent = NativeAgent::new( - project.clone(), - history_store, - Templates::new(), - None, - fs.clone(), - &mut cx.to_async(), - ) - .await - .unwrap(); - agent.read_with(cx, |agent, cx| { - assert_eq!(agent.project_context.read(cx).worktrees, vec![]) - }); - - let worktree = project - .update(cx, |project, cx| project.create_worktree("/a", true, cx)) - .await - .unwrap(); - cx.run_until_parked(); - agent.read_with(cx, |agent, cx| { - assert_eq!( - agent.project_context.read(cx).worktrees, - vec![WorktreeContext { - root_name: "a".into(), - abs_path: Path::new("/a").into(), - rules_file: None - }] - ) - }); - - // Creating `/a/.rules` updates the project context. - fs.insert_file("/a/.rules", Vec::new()).await; - cx.run_until_parked(); - agent.read_with(cx, |agent, cx| { - let rules_entry = worktree - .read(cx) - .entry_for_path(rel_path(".rules")) - .unwrap(); - assert_eq!( - agent.project_context.read(cx).worktrees, - vec![WorktreeContext { - root_name: "a".into(), - abs_path: Path::new("/a").into(), - rules_file: Some(RulesFileContext { - path_in_worktree: rel_path(".rules").into(), - text: "".into(), - project_entry_id: rules_entry.id.to_usize() - }) - }] - ) - }); - } - - #[gpui::test] - async fn test_listing_models(cx: &mut TestAppContext) { - init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree("/", json!({ "a": {} })).await; - let project = Project::test(fs.clone(), [], cx).await; - let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); - let connection = NativeAgentConnection( - NativeAgent::new( - project.clone(), - history_store, - Templates::new(), - None, - fs.clone(), - &mut cx.to_async(), - ) - .await - .unwrap(), - ); - - // Create a thread/session - let acp_thread = cx - .update(|cx| { - Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) - }) - .await - .unwrap(); - - let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); - - let models = cx - .update(|cx| { - connection - .model_selector(&session_id) - .unwrap() - .list_models(cx) - }) - .await - .unwrap(); - - let acp_thread::AgentModelList::Grouped(models) = models else { - panic!("Unexpected model group"); - }; - assert_eq!( - models, - IndexMap::from_iter([( - AgentModelGroupName("Fake".into()), - vec![AgentModelInfo { - id: acp::ModelId("fake/fake".into()), - name: "Fake".into(), - description: None, - icon: Some(ui::IconName::ZedAssistant), - }] - )]) - ); - } - - #[gpui::test] - async fn test_model_selection_persists_to_settings(cx: &mut TestAppContext) { - init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.create_dir(paths::settings_file().parent().unwrap()) - .await - .unwrap(); - fs.insert_file( - paths::settings_file(), - json!({ - "agent": { - "default_model": { - "provider": "foo", - "model": "bar" - } - } - }) - .to_string() - .into_bytes(), - ) - .await; - let project = Project::test(fs.clone(), [], cx).await; - - let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); - - // Create the agent and connection - let agent = NativeAgent::new( - project.clone(), - history_store, - Templates::new(), - None, - fs.clone(), - &mut cx.to_async(), - ) - .await - .unwrap(); - let connection = NativeAgentConnection(agent.clone()); - - // Create a thread/session - let acp_thread = cx - .update(|cx| { - Rc::new(connection.clone()).new_thread(project.clone(), Path::new("/a"), cx) - }) - .await - .unwrap(); - - let session_id = cx.update(|cx| acp_thread.read(cx).session_id().clone()); - - // Select a model - let selector = connection.model_selector(&session_id).unwrap(); - let model_id = acp::ModelId("fake/fake".into()); - cx.update(|cx| selector.select_model(model_id.clone(), cx)) - .await - .unwrap(); - - // Verify the thread has the selected model - agent.read_with(cx, |agent, _| { - let session = agent.sessions.get(&session_id).unwrap(); - session.thread.read_with(cx, |thread, _| { - assert_eq!(thread.model().unwrap().id().0, "fake"); - }); - }); - - cx.run_until_parked(); - - // Verify settings file was updated - let settings_content = fs.load(paths::settings_file()).await.unwrap(); - let settings_json: serde_json::Value = serde_json::from_str(&settings_content).unwrap(); - - // Check that the agent settings contain the selected model - assert_eq!( - settings_json["agent"]["default_model"]["model"], - json!("fake") - ); - assert_eq!( - settings_json["agent"]["default_model"]["provider"], - json!("fake") - ); - } - - #[gpui::test] - async fn test_save_load_thread(cx: &mut TestAppContext) { - init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/", - json!({ - "a": { - "b.md": "Lorem" - } - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await; - let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx)); - let history_store = cx.new(|cx| HistoryStore::new(context_store, cx)); - let agent = NativeAgent::new( - project.clone(), - history_store.clone(), - Templates::new(), - None, - fs.clone(), - &mut cx.to_async(), - ) - .await - .unwrap(); - let connection = Rc::new(NativeAgentConnection(agent.clone())); - - let acp_thread = cx - .update(|cx| { - connection - .clone() - .new_thread(project.clone(), Path::new(""), cx) - }) - .await - .unwrap(); - let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); - let thread = agent.read_with(cx, |agent, _| { - agent.sessions.get(&session_id).unwrap().thread.clone() - }); - - // Ensure empty threads are not saved, even if they get mutated. - let model = Arc::new(FakeLanguageModel::default()); - let summary_model = Arc::new(FakeLanguageModel::default()); - thread.update(cx, |thread, cx| { - thread.set_model(model.clone(), cx); - thread.set_summarization_model(Some(summary_model.clone()), cx); - }); - cx.run_until_parked(); - assert_eq!(history_entries(&history_store, cx), vec![]); - - let send = acp_thread.update(cx, |thread, cx| { - thread.send( - vec![ - "What does ".into(), - acp::ContentBlock::ResourceLink(acp::ResourceLink { - name: "b.md".into(), - uri: MentionUri::File { - abs_path: path!("/a/b.md").into(), - } - .to_uri() - .to_string(), - annotations: None, - description: None, - mime_type: None, - size: None, - title: None, - meta: None, - }), - " mean?".into(), - ], - cx, - ) - }); - let send = cx.foreground_executor().spawn(send); - cx.run_until_parked(); - - model.send_last_completion_stream_text_chunk("Lorem."); - model.end_last_completion_stream(); - cx.run_until_parked(); - summary_model - .send_last_completion_stream_text_chunk(&format!("Explaining {}", path!("/a/b.md"))); - summary_model.end_last_completion_stream(); - - send.await.unwrap(); - let uri = MentionUri::File { - abs_path: path!("/a/b.md").into(), - } - .to_uri(); - acp_thread.read_with(cx, |thread, cx| { - assert_eq!( - thread.to_markdown(cx), - formatdoc! {" - ## User - - What does [@b.md]({uri}) mean? - - ## Assistant - - Lorem. - - "} - ) - }); - - cx.run_until_parked(); - - // Drop the ACP thread, which should cause the session to be dropped as well. - cx.update(|_| { - drop(thread); - drop(acp_thread); - }); - agent.read_with(cx, |agent, _| { - assert_eq!(agent.sessions.keys().cloned().collect::>(), []); - }); - - // Ensure the thread can be reloaded from disk. - assert_eq!( - history_entries(&history_store, cx), - vec![( - HistoryEntryId::AcpThread(session_id.clone()), - format!("Explaining {}", path!("/a/b.md")) - )] - ); - let acp_thread = agent - .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx)) - .await - .unwrap(); - acp_thread.read_with(cx, |thread, cx| { - assert_eq!( - thread.to_markdown(cx), - formatdoc! {" - ## User - - What does [@b.md]({uri}) mean? - - ## Assistant - - Lorem. - - "} - ) - }); - } - - fn history_entries( - history: &Entity, - cx: &mut TestAppContext, - ) -> Vec<(HistoryEntryId, String)> { - history.read_with(cx, |history, _| { - history - .entries() - .map(|e| (e.id(), e.title().to_string())) - .collect::>() - }) - } - - fn init_test(cx: &mut TestAppContext) { - env_logger::try_init().ok(); - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - agent_settings::init(cx); - language::init(cx); - LanguageModelRegistry::test(cx); - }); - } -} diff --git a/crates/agent2/src/agent2.rs b/crates/agent2/src/agent2.rs deleted file mode 100644 index 1fc9c1cb956d1676c42713b5d9bb2a0b51e8ac90..0000000000000000000000000000000000000000 --- a/crates/agent2/src/agent2.rs +++ /dev/null @@ -1,19 +0,0 @@ -mod agent; -mod db; -mod history_store; -mod native_agent_server; -mod templates; -mod thread; -mod tool_schema; -mod tools; - -#[cfg(test)] -mod tests; - -pub use agent::*; -pub use db::*; -pub use history_store::*; -pub use native_agent_server::NativeAgentServer; -pub use templates::*; -pub use thread::*; -pub use tools::*; diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs deleted file mode 100644 index 756b868dcfc26239911d6e5c0cd8ad984cd7dc4e..0000000000000000000000000000000000000000 --- a/crates/agent2/src/thread.rs +++ /dev/null @@ -1,2663 +0,0 @@ -use crate::{ - ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DbLanguageModel, DbThread, - DeletePathTool, DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, - ListDirectoryTool, MovePathTool, NowTool, OpenTool, ReadFileTool, SystemPromptTemplate, - Template, Templates, TerminalTool, ThinkingTool, WebSearchTool, -}; -use acp_thread::{MentionUri, UserMessageId}; -use action_log::ActionLog; -use agent::thread::{GitState, ProjectSnapshot, WorktreeSnapshot}; -use agent_client_protocol as acp; -use agent_settings::{ - AgentProfileId, AgentProfileSettings, AgentSettings, CompletionMode, - SUMMARIZE_THREAD_DETAILED_PROMPT, SUMMARIZE_THREAD_PROMPT, -}; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::adapt_schema_to_format; -use chrono::{DateTime, Utc}; -use client::{ModelRequestUsage, RequestUsage, UserStore}; -use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; -use collections::{HashMap, HashSet, IndexMap}; -use fs::Fs; -use futures::stream; -use futures::{ - FutureExt, - channel::{mpsc, oneshot}, - future::Shared, - stream::FuturesUnordered, -}; -use git::repository::DiffType; -use gpui::{ - App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, -}; -use language_model::{ - LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt, - LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, - LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, - LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID, -}; -use project::{ - Project, - git_store::{GitStore, RepositoryState}, -}; -use prompt_store::ProjectContext; -use schemars::{JsonSchema, Schema}; -use serde::{Deserialize, Serialize}; -use settings::{Settings, update_settings_file}; -use smol::stream::StreamExt; -use std::{ - collections::BTreeMap, - ops::RangeInclusive, - path::Path, - rc::Rc, - sync::Arc, - time::{Duration, Instant}, -}; -use std::{fmt::Write, path::PathBuf}; -use util::{ResultExt, debug_panic, markdown::MarkdownCodeBlock}; -use uuid::Uuid; - -const TOOL_CANCELED_MESSAGE: &str = "Tool canceled by user"; -pub const MAX_TOOL_NAME_LENGTH: usize = 64; - -/// The ID of the user prompt that initiated a request. -/// -/// This equates to the user physically submitting a message to the model (e.g., by pressing the Enter key). -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] -pub struct PromptId(Arc); - -impl PromptId { - pub fn new() -> Self { - Self(Uuid::new_v4().to_string().into()) - } -} - -impl std::fmt::Display for PromptId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -pub(crate) const MAX_RETRY_ATTEMPTS: u8 = 4; -pub(crate) const BASE_RETRY_DELAY: Duration = Duration::from_secs(5); - -#[derive(Debug, Clone)] -enum RetryStrategy { - ExponentialBackoff { - initial_delay: Duration, - max_attempts: u8, - }, - Fixed { - delay: Duration, - max_attempts: u8, - }, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum Message { - User(UserMessage), - Agent(AgentMessage), - Resume, -} - -impl Message { - pub fn as_agent_message(&self) -> Option<&AgentMessage> { - match self { - Message::Agent(agent_message) => Some(agent_message), - _ => None, - } - } - - pub fn to_request(&self) -> Vec { - match self { - Message::User(message) => vec![message.to_request()], - Message::Agent(message) => message.to_request(), - Message::Resume => vec![LanguageModelRequestMessage { - role: Role::User, - content: vec!["Continue where you left off".into()], - cache: false, - }], - } - } - - pub fn to_markdown(&self) -> String { - match self { - Message::User(message) => message.to_markdown(), - Message::Agent(message) => message.to_markdown(), - Message::Resume => "[resume]\n".into(), - } - } - - pub fn role(&self) -> Role { - match self { - Message::User(_) | Message::Resume => Role::User, - Message::Agent(_) => Role::Assistant, - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct UserMessage { - pub id: UserMessageId, - pub content: Vec, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum UserMessageContent { - Text(String), - Mention { uri: MentionUri, content: String }, - Image(LanguageModelImage), -} - -impl UserMessage { - pub fn to_markdown(&self) -> String { - let mut markdown = String::from("## User\n\n"); - - for content in &self.content { - match content { - UserMessageContent::Text(text) => { - markdown.push_str(text); - markdown.push('\n'); - } - UserMessageContent::Image(_) => { - markdown.push_str("\n"); - } - UserMessageContent::Mention { uri, content } => { - if !content.is_empty() { - let _ = writeln!(&mut markdown, "{}\n\n{}", uri.as_link(), content); - } else { - let _ = writeln!(&mut markdown, "{}", uri.as_link()); - } - } - } - } - - markdown - } - - fn to_request(&self) -> LanguageModelRequestMessage { - let mut message = LanguageModelRequestMessage { - role: Role::User, - content: Vec::with_capacity(self.content.len()), - cache: false, - }; - - const OPEN_CONTEXT: &str = "\n\ - The following items were attached by the user. \ - They are up-to-date and don't need to be re-read.\n\n"; - - const OPEN_FILES_TAG: &str = ""; - const OPEN_DIRECTORIES_TAG: &str = ""; - const OPEN_SYMBOLS_TAG: &str = ""; - const OPEN_SELECTIONS_TAG: &str = ""; - const OPEN_THREADS_TAG: &str = ""; - const OPEN_FETCH_TAG: &str = ""; - const OPEN_RULES_TAG: &str = - "\nThe user has specified the following rules that should be applied:\n"; - - let mut file_context = OPEN_FILES_TAG.to_string(); - let mut directory_context = OPEN_DIRECTORIES_TAG.to_string(); - let mut symbol_context = OPEN_SYMBOLS_TAG.to_string(); - let mut selection_context = OPEN_SELECTIONS_TAG.to_string(); - let mut thread_context = OPEN_THREADS_TAG.to_string(); - let mut fetch_context = OPEN_FETCH_TAG.to_string(); - let mut rules_context = OPEN_RULES_TAG.to_string(); - - for chunk in &self.content { - let chunk = match chunk { - UserMessageContent::Text(text) => { - language_model::MessageContent::Text(text.clone()) - } - UserMessageContent::Image(value) => { - language_model::MessageContent::Image(value.clone()) - } - UserMessageContent::Mention { uri, content } => { - match uri { - MentionUri::File { abs_path } => { - write!( - &mut file_context, - "\n{}", - MarkdownCodeBlock { - tag: &codeblock_tag(abs_path, None), - text: &content.to_string(), - } - ) - .ok(); - } - MentionUri::PastedImage => { - debug_panic!("pasted image URI should not be used in mention content") - } - MentionUri::Directory { .. } => { - write!(&mut directory_context, "\n{}\n", content).ok(); - } - MentionUri::Symbol { - abs_path: path, - line_range, - .. - } => { - write!( - &mut symbol_context, - "\n{}", - MarkdownCodeBlock { - tag: &codeblock_tag(path, Some(line_range)), - text: content - } - ) - .ok(); - } - MentionUri::Selection { - abs_path: path, - line_range, - .. - } => { - write!( - &mut selection_context, - "\n{}", - MarkdownCodeBlock { - tag: &codeblock_tag( - path.as_deref().unwrap_or("Untitled".as_ref()), - Some(line_range) - ), - text: content - } - ) - .ok(); - } - MentionUri::Thread { .. } => { - write!(&mut thread_context, "\n{}\n", content).ok(); - } - MentionUri::TextThread { .. } => { - write!(&mut thread_context, "\n{}\n", content).ok(); - } - MentionUri::Rule { .. } => { - write!( - &mut rules_context, - "\n{}", - MarkdownCodeBlock { - tag: "", - text: content - } - ) - .ok(); - } - MentionUri::Fetch { url } => { - write!(&mut fetch_context, "\nFetch: {}\n\n{}", url, content).ok(); - } - } - - language_model::MessageContent::Text(uri.as_link().to_string()) - } - }; - - message.content.push(chunk); - } - - let len_before_context = message.content.len(); - - if file_context.len() > OPEN_FILES_TAG.len() { - file_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(file_context)); - } - - if directory_context.len() > OPEN_DIRECTORIES_TAG.len() { - directory_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(directory_context)); - } - - if symbol_context.len() > OPEN_SYMBOLS_TAG.len() { - symbol_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(symbol_context)); - } - - if selection_context.len() > OPEN_SELECTIONS_TAG.len() { - selection_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(selection_context)); - } - - if thread_context.len() > OPEN_THREADS_TAG.len() { - thread_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(thread_context)); - } - - if fetch_context.len() > OPEN_FETCH_TAG.len() { - fetch_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(fetch_context)); - } - - if rules_context.len() > OPEN_RULES_TAG.len() { - rules_context.push_str("\n"); - message - .content - .push(language_model::MessageContent::Text(rules_context)); - } - - if message.content.len() > len_before_context { - message.content.insert( - len_before_context, - language_model::MessageContent::Text(OPEN_CONTEXT.into()), - ); - message - .content - .push(language_model::MessageContent::Text("".into())); - } - - message - } -} - -fn codeblock_tag(full_path: &Path, line_range: Option<&RangeInclusive>) -> String { - let mut result = String::new(); - - if let Some(extension) = full_path.extension().and_then(|ext| ext.to_str()) { - let _ = write!(result, "{} ", extension); - } - - let _ = write!(result, "{}", full_path.display()); - - if let Some(range) = line_range { - if range.start() == range.end() { - let _ = write!(result, ":{}", range.start() + 1); - } else { - let _ = write!(result, ":{}-{}", range.start() + 1, range.end() + 1); - } - } - - result -} - -impl AgentMessage { - pub fn to_markdown(&self) -> String { - let mut markdown = String::from("## Assistant\n\n"); - - for content in &self.content { - match content { - AgentMessageContent::Text(text) => { - markdown.push_str(text); - markdown.push('\n'); - } - AgentMessageContent::Thinking { text, .. } => { - markdown.push_str(""); - markdown.push_str(text); - markdown.push_str("\n"); - } - AgentMessageContent::RedactedThinking(_) => { - markdown.push_str("\n") - } - AgentMessageContent::ToolUse(tool_use) => { - markdown.push_str(&format!( - "**Tool Use**: {} (ID: {})\n", - tool_use.name, tool_use.id - )); - markdown.push_str(&format!( - "{}\n", - MarkdownCodeBlock { - tag: "json", - text: &format!("{:#}", tool_use.input) - } - )); - } - } - } - - for tool_result in self.tool_results.values() { - markdown.push_str(&format!( - "**Tool Result**: {} (ID: {})\n\n", - tool_result.tool_name, tool_result.tool_use_id - )); - if tool_result.is_error { - markdown.push_str("**ERROR:**\n"); - } - - match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - writeln!(markdown, "{text}\n").ok(); - } - LanguageModelToolResultContent::Image(_) => { - writeln!(markdown, "\n").ok(); - } - } - - if let Some(output) = tool_result.output.as_ref() { - writeln!( - markdown, - "**Debug Output**:\n\n```json\n{}\n```\n", - serde_json::to_string_pretty(output).unwrap() - ) - .unwrap(); - } - } - - markdown - } - - pub fn to_request(&self) -> Vec { - let mut assistant_message = LanguageModelRequestMessage { - role: Role::Assistant, - content: Vec::with_capacity(self.content.len()), - cache: false, - }; - for chunk in &self.content { - match chunk { - AgentMessageContent::Text(text) => { - assistant_message - .content - .push(language_model::MessageContent::Text(text.clone())); - } - AgentMessageContent::Thinking { text, signature } => { - assistant_message - .content - .push(language_model::MessageContent::Thinking { - text: text.clone(), - signature: signature.clone(), - }); - } - AgentMessageContent::RedactedThinking(value) => { - assistant_message.content.push( - language_model::MessageContent::RedactedThinking(value.clone()), - ); - } - AgentMessageContent::ToolUse(tool_use) => { - if self.tool_results.contains_key(&tool_use.id) { - assistant_message - .content - .push(language_model::MessageContent::ToolUse(tool_use.clone())); - } - } - }; - } - - let mut user_message = LanguageModelRequestMessage { - role: Role::User, - content: Vec::new(), - cache: false, - }; - - for tool_result in self.tool_results.values() { - let mut tool_result = tool_result.clone(); - // Surprisingly, the API fails if we return an empty string here. - // It thinks we are sending a tool use without a tool result. - if tool_result.content.is_empty() { - tool_result.content = "".into(); - } - user_message - .content - .push(language_model::MessageContent::ToolResult(tool_result)); - } - - let mut messages = Vec::new(); - if !assistant_message.content.is_empty() { - messages.push(assistant_message); - } - if !user_message.content.is_empty() { - messages.push(user_message); - } - messages - } -} - -#[derive(Default, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct AgentMessage { - pub content: Vec, - pub tool_results: IndexMap, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum AgentMessageContent { - Text(String), - Thinking { - text: String, - signature: Option, - }, - RedactedThinking(String), - ToolUse(LanguageModelToolUse), -} - -pub trait TerminalHandle { - fn id(&self, cx: &AsyncApp) -> Result; - fn current_output(&self, cx: &AsyncApp) -> Result; - fn wait_for_exit(&self, cx: &AsyncApp) -> Result>>; -} - -pub trait ThreadEnvironment { - fn create_terminal( - &self, - command: String, - cwd: Option, - output_byte_limit: Option, - cx: &mut AsyncApp, - ) -> Task>>; -} - -#[derive(Debug)] -pub enum ThreadEvent { - UserMessage(UserMessage), - AgentText(String), - AgentThinking(String), - ToolCall(acp::ToolCall), - ToolCallUpdate(acp_thread::ToolCallUpdate), - ToolCallAuthorization(ToolCallAuthorization), - Retry(acp_thread::RetryStatus), - Stop(acp::StopReason), -} - -#[derive(Debug)] -pub struct NewTerminal { - pub command: String, - pub output_byte_limit: Option, - pub cwd: Option, - pub response: oneshot::Sender>>, -} - -#[derive(Debug)] -pub struct ToolCallAuthorization { - pub tool_call: acp::ToolCallUpdate, - pub options: Vec, - pub response: oneshot::Sender, -} - -#[derive(Debug, thiserror::Error)] -enum CompletionError { - #[error("max tokens")] - MaxTokens, - #[error("refusal")] - Refusal, - #[error(transparent)] - Other(#[from] anyhow::Error), -} - -pub struct Thread { - id: acp::SessionId, - prompt_id: PromptId, - updated_at: DateTime, - title: Option, - pending_title_generation: Option>, - summary: Option, - messages: Vec, - user_store: Entity, - completion_mode: CompletionMode, - /// Holds the task that handles agent interaction until the end of the turn. - /// Survives across multiple requests as the model performs tool calls and - /// we run tools, report their results. - running_turn: Option, - pending_message: Option, - tools: BTreeMap>, - tool_use_limit_reached: bool, - request_token_usage: HashMap, - #[allow(unused)] - cumulative_token_usage: TokenUsage, - #[allow(unused)] - initial_project_snapshot: Shared>>>, - context_server_registry: Entity, - profile_id: AgentProfileId, - project_context: Entity, - templates: Arc, - model: Option>, - summarization_model: Option>, - prompt_capabilities_tx: watch::Sender, - pub(crate) prompt_capabilities_rx: watch::Receiver, - pub(crate) project: Entity, - pub(crate) action_log: Entity, -} - -impl Thread { - fn prompt_capabilities(model: Option<&dyn LanguageModel>) -> acp::PromptCapabilities { - let image = model.map_or(true, |model| model.supports_images()); - acp::PromptCapabilities { - meta: None, - image, - audio: false, - embedded_context: true, - } - } - - pub fn new( - project: Entity, - project_context: Entity, - context_server_registry: Entity, - templates: Arc, - model: Option>, - cx: &mut Context, - ) -> Self { - let profile_id = AgentSettings::get_global(cx).default_profile.clone(); - let action_log = cx.new(|_cx| ActionLog::new(project.clone())); - let (prompt_capabilities_tx, prompt_capabilities_rx) = - watch::channel(Self::prompt_capabilities(model.as_deref())); - Self { - id: acp::SessionId(uuid::Uuid::new_v4().to_string().into()), - prompt_id: PromptId::new(), - updated_at: Utc::now(), - title: None, - pending_title_generation: None, - summary: None, - messages: Vec::new(), - user_store: project.read(cx).user_store(), - completion_mode: AgentSettings::get_global(cx).preferred_completion_mode, - running_turn: None, - pending_message: None, - tools: BTreeMap::default(), - tool_use_limit_reached: false, - request_token_usage: HashMap::default(), - cumulative_token_usage: TokenUsage::default(), - initial_project_snapshot: { - let project_snapshot = Self::project_snapshot(project.clone(), cx); - cx.foreground_executor() - .spawn(async move { Some(project_snapshot.await) }) - .shared() - }, - context_server_registry, - profile_id, - project_context, - templates, - model, - summarization_model: None, - prompt_capabilities_tx, - prompt_capabilities_rx, - project, - action_log, - } - } - - pub fn id(&self) -> &acp::SessionId { - &self.id - } - - pub fn replay( - &mut self, - cx: &mut Context, - ) -> mpsc::UnboundedReceiver> { - let (tx, rx) = mpsc::unbounded(); - let stream = ThreadEventStream(tx); - for message in &self.messages { - match message { - Message::User(user_message) => stream.send_user_message(user_message), - Message::Agent(assistant_message) => { - for content in &assistant_message.content { - match content { - AgentMessageContent::Text(text) => stream.send_text(text), - AgentMessageContent::Thinking { text, .. } => { - stream.send_thinking(text) - } - AgentMessageContent::RedactedThinking(_) => {} - AgentMessageContent::ToolUse(tool_use) => { - self.replay_tool_call( - tool_use, - assistant_message.tool_results.get(&tool_use.id), - &stream, - cx, - ); - } - } - } - } - Message::Resume => {} - } - } - rx - } - - fn replay_tool_call( - &self, - tool_use: &LanguageModelToolUse, - tool_result: Option<&LanguageModelToolResult>, - stream: &ThreadEventStream, - cx: &mut Context, - ) { - let tool = self.tools.get(tool_use.name.as_ref()).cloned().or_else(|| { - self.context_server_registry - .read(cx) - .servers() - .find_map(|(_, tools)| { - if let Some(tool) = tools.get(tool_use.name.as_ref()) { - Some(tool.clone()) - } else { - None - } - }) - }); - - let Some(tool) = tool else { - stream - .0 - .unbounded_send(Ok(ThreadEvent::ToolCall(acp::ToolCall { - meta: None, - id: acp::ToolCallId(tool_use.id.to_string().into()), - title: tool_use.name.to_string(), - kind: acp::ToolKind::Other, - status: acp::ToolCallStatus::Failed, - content: Vec::new(), - locations: Vec::new(), - raw_input: Some(tool_use.input.clone()), - raw_output: None, - }))) - .ok(); - return; - }; - - let title = tool.initial_title(tool_use.input.clone(), cx); - let kind = tool.kind(); - stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); - - let output = tool_result - .as_ref() - .and_then(|result| result.output.clone()); - if let Some(output) = output.clone() { - let tool_event_stream = ToolCallEventStream::new( - tool_use.id.clone(), - stream.clone(), - Some(self.project.read(cx).fs().clone()), - ); - tool.replay(tool_use.input.clone(), output, tool_event_stream, cx) - .log_err(); - } - - stream.update_tool_call_fields( - &tool_use.id, - acp::ToolCallUpdateFields { - status: Some( - tool_result - .as_ref() - .map_or(acp::ToolCallStatus::Failed, |result| { - if result.is_error { - acp::ToolCallStatus::Failed - } else { - acp::ToolCallStatus::Completed - } - }), - ), - raw_output: output, - ..Default::default() - }, - ); - } - - pub fn from_db( - id: acp::SessionId, - db_thread: DbThread, - project: Entity, - project_context: Entity, - context_server_registry: Entity, - action_log: Entity, - templates: Arc, - cx: &mut Context, - ) -> Self { - let profile_id = db_thread - .profile - .unwrap_or_else(|| AgentSettings::get_global(cx).default_profile.clone()); - let model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| { - db_thread - .model - .and_then(|model| { - let model = SelectedModel { - provider: model.provider.clone().into(), - model: model.model.into(), - }; - registry.select_model(&model, cx) - }) - .or_else(|| registry.default_model()) - .map(|model| model.model) - }); - let (prompt_capabilities_tx, prompt_capabilities_rx) = - watch::channel(Self::prompt_capabilities(model.as_deref())); - - Self { - id, - prompt_id: PromptId::new(), - title: if db_thread.title.is_empty() { - None - } else { - Some(db_thread.title.clone()) - }, - pending_title_generation: None, - summary: db_thread.detailed_summary, - messages: db_thread.messages, - user_store: project.read(cx).user_store(), - completion_mode: db_thread.completion_mode.unwrap_or_default(), - running_turn: None, - pending_message: None, - tools: BTreeMap::default(), - tool_use_limit_reached: false, - request_token_usage: db_thread.request_token_usage.clone(), - cumulative_token_usage: db_thread.cumulative_token_usage, - initial_project_snapshot: Task::ready(db_thread.initial_project_snapshot).shared(), - context_server_registry, - profile_id, - project_context, - templates, - model, - summarization_model: None, - project, - action_log, - updated_at: db_thread.updated_at, - prompt_capabilities_tx, - prompt_capabilities_rx, - } - } - - pub fn to_db(&self, cx: &App) -> Task { - let initial_project_snapshot = self.initial_project_snapshot.clone(); - let mut thread = DbThread { - title: self.title(), - messages: self.messages.clone(), - updated_at: self.updated_at, - detailed_summary: self.summary.clone(), - initial_project_snapshot: None, - cumulative_token_usage: self.cumulative_token_usage, - request_token_usage: self.request_token_usage.clone(), - model: self.model.as_ref().map(|model| DbLanguageModel { - provider: model.provider_id().to_string(), - model: model.name().0.to_string(), - }), - completion_mode: Some(self.completion_mode), - profile: Some(self.profile_id.clone()), - }; - - cx.background_spawn(async move { - let initial_project_snapshot = initial_project_snapshot.await; - thread.initial_project_snapshot = initial_project_snapshot; - thread - }) - } - - /// Create a snapshot of the current project state including git information and unsaved buffers. - fn project_snapshot( - project: Entity, - cx: &mut Context, - ) -> Task> { - let git_store = project.read(cx).git_store().clone(); - let worktree_snapshots: Vec<_> = project - .read(cx) - .visible_worktrees(cx) - .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx)) - .collect(); - - cx.spawn(async move |_, _| { - let worktree_snapshots = futures::future::join_all(worktree_snapshots).await; - - Arc::new(ProjectSnapshot { - worktree_snapshots, - timestamp: Utc::now(), - }) - }) - } - - fn worktree_snapshot( - worktree: Entity, - git_store: Entity, - cx: &App, - ) -> Task { - cx.spawn(async move |cx| { - // Get worktree path and snapshot - let worktree_info = cx.update(|app_cx| { - let worktree = worktree.read(app_cx); - let path = worktree.abs_path().to_string_lossy().into_owned(); - let snapshot = worktree.snapshot(); - (path, snapshot) - }); - - let Ok((worktree_path, _snapshot)) = worktree_info else { - return WorktreeSnapshot { - worktree_path: String::new(), - git_state: None, - }; - }; - - let git_state = git_store - .update(cx, |git_store, cx| { - git_store - .repositories() - .values() - .find(|repo| { - repo.read(cx) - .abs_path_to_repo_path(&worktree.read(cx).abs_path()) - .is_some() - }) - .cloned() - }) - .ok() - .flatten() - .map(|repo| { - repo.update(cx, |repo, _| { - let current_branch = - repo.branch.as_ref().map(|branch| branch.name().to_owned()); - repo.send_job(None, |state, _| async move { - let RepositoryState::Local { backend, .. } = state else { - return GitState { - remote_url: None, - head_sha: None, - current_branch, - diff: None, - }; - }; - - let remote_url = backend.remote_url("origin"); - let head_sha = backend.head_sha().await; - let diff = backend.diff(DiffType::HeadToWorktree).await.ok(); - - GitState { - remote_url, - head_sha, - current_branch, - diff, - } - }) - }) - }); - - let git_state = match git_state { - Some(git_state) => match git_state.ok() { - Some(git_state) => git_state.await.ok(), - None => None, - }, - None => None, - }; - - WorktreeSnapshot { - worktree_path, - git_state, - } - }) - } - - pub fn project_context(&self) -> &Entity { - &self.project_context - } - - pub fn project(&self) -> &Entity { - &self.project - } - - pub fn action_log(&self) -> &Entity { - &self.action_log - } - - pub fn is_empty(&self) -> bool { - self.messages.is_empty() && self.title.is_none() - } - - pub fn model(&self) -> Option<&Arc> { - self.model.as_ref() - } - - pub fn set_model(&mut self, model: Arc, cx: &mut Context) { - let old_usage = self.latest_token_usage(); - self.model = Some(model); - let new_caps = Self::prompt_capabilities(self.model.as_deref()); - let new_usage = self.latest_token_usage(); - if old_usage != new_usage { - cx.emit(TokenUsageUpdated(new_usage)); - } - self.prompt_capabilities_tx.send(new_caps).log_err(); - cx.notify() - } - - pub fn summarization_model(&self) -> Option<&Arc> { - self.summarization_model.as_ref() - } - - pub fn set_summarization_model( - &mut self, - model: Option>, - cx: &mut Context, - ) { - self.summarization_model = model; - cx.notify() - } - - pub fn completion_mode(&self) -> CompletionMode { - self.completion_mode - } - - pub fn set_completion_mode(&mut self, mode: CompletionMode, cx: &mut Context) { - let old_usage = self.latest_token_usage(); - self.completion_mode = mode; - let new_usage = self.latest_token_usage(); - if old_usage != new_usage { - cx.emit(TokenUsageUpdated(new_usage)); - } - cx.notify() - } - - #[cfg(any(test, feature = "test-support"))] - pub fn last_message(&self) -> Option { - if let Some(message) = self.pending_message.clone() { - Some(Message::Agent(message)) - } else { - self.messages.last().cloned() - } - } - - pub fn add_default_tools( - &mut self, - environment: Rc, - cx: &mut Context, - ) { - let language_registry = self.project.read(cx).languages().clone(); - self.add_tool(CopyPathTool::new(self.project.clone())); - self.add_tool(CreateDirectoryTool::new(self.project.clone())); - self.add_tool(DeletePathTool::new( - self.project.clone(), - self.action_log.clone(), - )); - self.add_tool(DiagnosticsTool::new(self.project.clone())); - self.add_tool(EditFileTool::new( - self.project.clone(), - cx.weak_entity(), - language_registry, - )); - self.add_tool(FetchTool::new(self.project.read(cx).client().http_client())); - self.add_tool(FindPathTool::new(self.project.clone())); - self.add_tool(GrepTool::new(self.project.clone())); - self.add_tool(ListDirectoryTool::new(self.project.clone())); - self.add_tool(MovePathTool::new(self.project.clone())); - self.add_tool(NowTool); - self.add_tool(OpenTool::new(self.project.clone())); - self.add_tool(ReadFileTool::new( - self.project.clone(), - self.action_log.clone(), - )); - self.add_tool(TerminalTool::new(self.project.clone(), environment)); - self.add_tool(ThinkingTool); - self.add_tool(WebSearchTool); - } - - pub fn add_tool(&mut self, tool: T) { - self.tools.insert(T::name().into(), tool.erase()); - } - - pub fn remove_tool(&mut self, name: &str) -> bool { - self.tools.remove(name).is_some() - } - - pub fn profile(&self) -> &AgentProfileId { - &self.profile_id - } - - pub fn set_profile(&mut self, profile_id: AgentProfileId) { - self.profile_id = profile_id; - } - - pub fn cancel(&mut self, cx: &mut Context) { - if let Some(running_turn) = self.running_turn.take() { - running_turn.cancel(); - } - self.flush_pending_message(cx); - } - - fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context) { - let Some(last_user_message) = self.last_user_message() else { - return; - }; - - self.request_token_usage - .insert(last_user_message.id.clone(), update); - cx.emit(TokenUsageUpdated(self.latest_token_usage())); - cx.notify(); - } - - pub fn truncate(&mut self, message_id: UserMessageId, cx: &mut Context) -> Result<()> { - self.cancel(cx); - let Some(position) = self.messages.iter().position( - |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id), - ) else { - return Err(anyhow!("Message not found")); - }; - - for message in self.messages.drain(position..) { - match message { - Message::User(message) => { - self.request_token_usage.remove(&message.id); - } - Message::Agent(_) | Message::Resume => {} - } - } - self.summary = None; - cx.notify(); - Ok(()) - } - - pub fn latest_token_usage(&self) -> Option { - let last_user_message = self.last_user_message()?; - let tokens = self.request_token_usage.get(&last_user_message.id)?; - let model = self.model.clone()?; - - Some(acp_thread::TokenUsage { - max_tokens: model.max_token_count_for_mode(self.completion_mode.into()), - used_tokens: tokens.total_tokens(), - }) - } - - pub fn resume( - &mut self, - cx: &mut Context, - ) -> Result>> { - self.messages.push(Message::Resume); - cx.notify(); - - log::debug!("Total messages in thread: {}", self.messages.len()); - self.run_turn(cx) - } - - /// Sending a message results in the model streaming a response, which could include tool calls. - /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent. - /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn. - pub fn send( - &mut self, - id: UserMessageId, - content: impl IntoIterator, - cx: &mut Context, - ) -> Result>> - where - T: Into, - { - let model = self.model().context("No language model configured")?; - - log::info!("Thread::send called with model: {}", model.name().0); - self.advance_prompt_id(); - - let content = content.into_iter().map(Into::into).collect::>(); - log::debug!("Thread::send content: {:?}", content); - - self.messages - .push(Message::User(UserMessage { id, content })); - cx.notify(); - - log::debug!("Total messages in thread: {}", self.messages.len()); - self.run_turn(cx) - } - - fn run_turn( - &mut self, - cx: &mut Context, - ) -> Result>> { - self.cancel(cx); - - let model = self.model.clone().context("No language model configured")?; - let profile = AgentSettings::get_global(cx) - .profiles - .get(&self.profile_id) - .context("Profile not found")?; - let (events_tx, events_rx) = mpsc::unbounded::>(); - let event_stream = ThreadEventStream(events_tx); - let message_ix = self.messages.len().saturating_sub(1); - self.tool_use_limit_reached = false; - self.summary = None; - self.running_turn = Some(RunningTurn { - event_stream: event_stream.clone(), - tools: self.enabled_tools(profile, &model, cx), - _task: cx.spawn(async move |this, cx| { - log::debug!("Starting agent turn execution"); - - let turn_result = Self::run_turn_internal(&this, model, &event_stream, cx).await; - _ = this.update(cx, |this, cx| this.flush_pending_message(cx)); - - match turn_result { - Ok(()) => { - log::debug!("Turn execution completed"); - event_stream.send_stop(acp::StopReason::EndTurn); - } - Err(error) => { - log::error!("Turn execution failed: {:?}", error); - match error.downcast::() { - Ok(CompletionError::Refusal) => { - event_stream.send_stop(acp::StopReason::Refusal); - _ = this.update(cx, |this, _| this.messages.truncate(message_ix)); - } - Ok(CompletionError::MaxTokens) => { - event_stream.send_stop(acp::StopReason::MaxTokens); - } - Ok(CompletionError::Other(error)) | Err(error) => { - event_stream.send_error(error); - } - } - } - } - - _ = this.update(cx, |this, _| this.running_turn.take()); - }), - }); - Ok(events_rx) - } - - async fn run_turn_internal( - this: &WeakEntity, - model: Arc, - event_stream: &ThreadEventStream, - cx: &mut AsyncApp, - ) -> Result<()> { - let mut attempt = 0; - let mut intent = CompletionIntent::UserPrompt; - loop { - let request = - this.update(cx, |this, cx| this.build_completion_request(intent, cx))??; - - telemetry::event!( - "Agent Thread Completion", - thread_id = this.read_with(cx, |this, _| this.id.to_string())?, - prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?, - model = model.telemetry_id(), - model_provider = model.provider_id().to_string(), - attempt - ); - - log::debug!("Calling model.stream_completion, attempt {}", attempt); - - let (mut events, mut error) = match model.stream_completion(request, cx).await { - Ok(events) => (events, None), - Err(err) => (stream::empty().boxed(), Some(err)), - }; - let mut tool_results = FuturesUnordered::new(); - while let Some(event) = events.next().await { - log::trace!("Received completion event: {:?}", event); - match event { - Ok(event) => { - tool_results.extend(this.update(cx, |this, cx| { - this.handle_completion_event(event, event_stream, cx) - })??); - } - Err(err) => { - error = Some(err); - break; - } - } - } - - let end_turn = tool_results.is_empty(); - while let Some(tool_result) = tool_results.next().await { - log::debug!("Tool finished {:?}", tool_result); - - event_stream.update_tool_call_fields( - &tool_result.tool_use_id, - acp::ToolCallUpdateFields { - status: Some(if tool_result.is_error { - acp::ToolCallStatus::Failed - } else { - acp::ToolCallStatus::Completed - }), - raw_output: tool_result.output.clone(), - ..Default::default() - }, - ); - this.update(cx, |this, _cx| { - this.pending_message() - .tool_results - .insert(tool_result.tool_use_id.clone(), tool_result); - })?; - } - - this.update(cx, |this, cx| { - this.flush_pending_message(cx); - if this.title.is_none() && this.pending_title_generation.is_none() { - this.generate_title(cx); - } - })?; - - if let Some(error) = error { - attempt += 1; - let retry = this.update(cx, |this, cx| { - let user_store = this.user_store.read(cx); - this.handle_completion_error(error, attempt, user_store.plan()) - })??; - let timer = cx.background_executor().timer(retry.duration); - event_stream.send_retry(retry); - timer.await; - this.update(cx, |this, _cx| { - if let Some(Message::Agent(message)) = this.messages.last() { - if message.tool_results.is_empty() { - intent = CompletionIntent::UserPrompt; - this.messages.push(Message::Resume); - } - } - })?; - } else if this.read_with(cx, |this, _| this.tool_use_limit_reached)? { - return Err(language_model::ToolUseLimitReachedError.into()); - } else if end_turn { - return Ok(()); - } else { - intent = CompletionIntent::ToolResults; - attempt = 0; - } - } - } - - fn handle_completion_error( - &mut self, - error: LanguageModelCompletionError, - attempt: u8, - plan: Option, - ) -> Result { - let Some(model) = self.model.as_ref() else { - return Err(anyhow!(error)); - }; - - let auto_retry = if model.provider_id() == ZED_CLOUD_PROVIDER_ID { - match plan { - Some(Plan::V2(_)) => true, - Some(Plan::V1(_)) => self.completion_mode == CompletionMode::Burn, - None => false, - } - } else { - true - }; - - if !auto_retry { - return Err(anyhow!(error)); - } - - let Some(strategy) = Self::retry_strategy_for(&error) else { - return Err(anyhow!(error)); - }; - - let max_attempts = match &strategy { - RetryStrategy::ExponentialBackoff { max_attempts, .. } => *max_attempts, - RetryStrategy::Fixed { max_attempts, .. } => *max_attempts, - }; - - if attempt > max_attempts { - return Err(anyhow!(error)); - } - - let delay = match &strategy { - RetryStrategy::ExponentialBackoff { initial_delay, .. } => { - let delay_secs = initial_delay.as_secs() * 2u64.pow((attempt - 1) as u32); - Duration::from_secs(delay_secs) - } - RetryStrategy::Fixed { delay, .. } => *delay, - }; - log::debug!("Retry attempt {attempt} with delay {delay:?}"); - - Ok(acp_thread::RetryStatus { - last_error: error.to_string().into(), - attempt: attempt as usize, - max_attempts: max_attempts as usize, - started_at: Instant::now(), - duration: delay, - }) - } - - /// A helper method that's called on every streamed completion event. - /// Returns an optional tool result task, which the main agentic loop will - /// send back to the model when it resolves. - fn handle_completion_event( - &mut self, - event: LanguageModelCompletionEvent, - event_stream: &ThreadEventStream, - cx: &mut Context, - ) -> Result>> { - log::trace!("Handling streamed completion event: {:?}", event); - use LanguageModelCompletionEvent::*; - - match event { - StartMessage { .. } => { - self.flush_pending_message(cx); - self.pending_message = Some(AgentMessage::default()); - } - Text(new_text) => self.handle_text_event(new_text, event_stream, cx), - Thinking { text, signature } => { - self.handle_thinking_event(text, signature, event_stream, cx) - } - RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx), - ToolUse(tool_use) => { - return Ok(self.handle_tool_use_event(tool_use, event_stream, cx)); - } - ToolUseJsonParseError { - id, - tool_name, - raw_input, - json_parse_error, - } => { - return Ok(Some(Task::ready( - self.handle_tool_use_json_parse_error_event( - id, - tool_name, - raw_input, - json_parse_error, - ), - ))); - } - UsageUpdate(usage) => { - telemetry::event!( - "Agent Thread Completion Usage Updated", - thread_id = self.id.to_string(), - prompt_id = self.prompt_id.to_string(), - model = self.model.as_ref().map(|m| m.telemetry_id()), - model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()), - input_tokens = usage.input_tokens, - output_tokens = usage.output_tokens, - cache_creation_input_tokens = usage.cache_creation_input_tokens, - cache_read_input_tokens = usage.cache_read_input_tokens, - ); - self.update_token_usage(usage, cx); - } - StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => { - self.update_model_request_usage(amount, limit, cx); - } - StatusUpdate( - CompletionRequestStatus::Started - | CompletionRequestStatus::Queued { .. } - | CompletionRequestStatus::Failed { .. }, - ) => {} - StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => { - self.tool_use_limit_reached = true; - } - Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()), - Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()), - Stop(StopReason::ToolUse | StopReason::EndTurn) => {} - } - - Ok(None) - } - - fn handle_text_event( - &mut self, - new_text: String, - event_stream: &ThreadEventStream, - cx: &mut Context, - ) { - event_stream.send_text(&new_text); - - let last_message = self.pending_message(); - if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() { - text.push_str(&new_text); - } else { - last_message - .content - .push(AgentMessageContent::Text(new_text)); - } - - cx.notify(); - } - - fn handle_thinking_event( - &mut self, - new_text: String, - new_signature: Option, - event_stream: &ThreadEventStream, - cx: &mut Context, - ) { - event_stream.send_thinking(&new_text); - - let last_message = self.pending_message(); - if let Some(AgentMessageContent::Thinking { text, signature }) = - last_message.content.last_mut() - { - text.push_str(&new_text); - *signature = new_signature.or(signature.take()); - } else { - last_message.content.push(AgentMessageContent::Thinking { - text: new_text, - signature: new_signature, - }); - } - - cx.notify(); - } - - fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context) { - let last_message = self.pending_message(); - last_message - .content - .push(AgentMessageContent::RedactedThinking(data)); - cx.notify(); - } - - fn handle_tool_use_event( - &mut self, - tool_use: LanguageModelToolUse, - event_stream: &ThreadEventStream, - cx: &mut Context, - ) -> Option> { - cx.notify(); - - let tool = self.tool(tool_use.name.as_ref()); - let mut title = SharedString::from(&tool_use.name); - let mut kind = acp::ToolKind::Other; - if let Some(tool) = tool.as_ref() { - title = tool.initial_title(tool_use.input.clone(), cx); - kind = tool.kind(); - } - - // Ensure the last message ends in the current tool use - let last_message = self.pending_message(); - let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| { - if let AgentMessageContent::ToolUse(last_tool_use) = content { - if last_tool_use.id == tool_use.id { - *last_tool_use = tool_use.clone(); - false - } else { - true - } - } else { - true - } - }); - - if push_new_tool_use { - event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone()); - last_message - .content - .push(AgentMessageContent::ToolUse(tool_use.clone())); - } else { - event_stream.update_tool_call_fields( - &tool_use.id, - acp::ToolCallUpdateFields { - title: Some(title.into()), - kind: Some(kind), - raw_input: Some(tool_use.input.clone()), - ..Default::default() - }, - ); - } - - if !tool_use.is_input_complete { - return None; - } - - let Some(tool) = tool else { - let content = format!("No tool named {} exists", tool_use.name); - return Some(Task::ready(LanguageModelToolResult { - content: LanguageModelToolResultContent::Text(Arc::from(content)), - tool_use_id: tool_use.id, - tool_name: tool_use.name, - is_error: true, - output: None, - })); - }; - - let fs = self.project.read(cx).fs().clone(); - let tool_event_stream = - ToolCallEventStream::new(tool_use.id.clone(), event_stream.clone(), Some(fs)); - tool_event_stream.update_fields(acp::ToolCallUpdateFields { - status: Some(acp::ToolCallStatus::InProgress), - ..Default::default() - }); - let supports_images = self.model().is_some_and(|model| model.supports_images()); - let tool_result = tool.run(tool_use.input, tool_event_stream, cx); - log::debug!("Running tool {}", tool_use.name); - Some(cx.foreground_executor().spawn(async move { - let tool_result = tool_result.await.and_then(|output| { - if let LanguageModelToolResultContent::Image(_) = &output.llm_output - && !supports_images - { - return Err(anyhow!( - "Attempted to read an image, but this model doesn't support it.", - )); - } - Ok(output) - }); - - match tool_result { - Ok(output) => LanguageModelToolResult { - tool_use_id: tool_use.id, - tool_name: tool_use.name, - is_error: false, - content: output.llm_output, - output: Some(output.raw_output), - }, - Err(error) => LanguageModelToolResult { - tool_use_id: tool_use.id, - tool_name: tool_use.name, - is_error: true, - content: LanguageModelToolResultContent::Text(Arc::from(error.to_string())), - output: Some(error.to_string().into()), - }, - } - })) - } - - fn handle_tool_use_json_parse_error_event( - &mut self, - tool_use_id: LanguageModelToolUseId, - tool_name: Arc, - raw_input: Arc, - json_parse_error: String, - ) -> LanguageModelToolResult { - let tool_output = format!("Error parsing input JSON: {json_parse_error}"); - LanguageModelToolResult { - tool_use_id, - tool_name, - is_error: true, - content: LanguageModelToolResultContent::Text(tool_output.into()), - output: Some(serde_json::Value::String(raw_input.to_string())), - } - } - - fn update_model_request_usage(&self, amount: usize, limit: UsageLimit, cx: &mut Context) { - self.project - .read(cx) - .user_store() - .update(cx, |user_store, cx| { - user_store.update_model_request_usage( - ModelRequestUsage(RequestUsage { - amount: amount as i32, - limit, - }), - cx, - ) - }); - } - - pub fn title(&self) -> SharedString { - self.title.clone().unwrap_or("New Thread".into()) - } - - pub fn summary(&mut self, cx: &mut Context) -> Task> { - if let Some(summary) = self.summary.as_ref() { - return Task::ready(Ok(summary.clone())); - } - let Some(model) = self.summarization_model.clone() else { - return Task::ready(Err(anyhow!("No summarization model available"))); - }; - let mut request = LanguageModelRequest { - intent: Some(CompletionIntent::ThreadContextSummarization), - temperature: AgentSettings::temperature_for_model(&model, cx), - ..Default::default() - }; - - for message in &self.messages { - request.messages.extend(message.to_request()); - } - - request.messages.push(LanguageModelRequestMessage { - role: Role::User, - content: vec![SUMMARIZE_THREAD_DETAILED_PROMPT.into()], - cache: false, - }); - cx.spawn(async move |this, cx| { - let mut summary = String::new(); - let mut messages = model.stream_completion(request, cx).await?; - while let Some(event) = messages.next().await { - let event = event?; - let text = match event { - LanguageModelCompletionEvent::Text(text) => text, - LanguageModelCompletionEvent::StatusUpdate( - CompletionRequestStatus::UsageUpdated { amount, limit }, - ) => { - this.update(cx, |thread, cx| { - thread.update_model_request_usage(amount, limit, cx); - })?; - continue; - } - _ => continue, - }; - - let mut lines = text.lines(); - summary.extend(lines.next()); - } - - log::debug!("Setting summary: {}", summary); - let summary = SharedString::from(summary); - - this.update(cx, |this, cx| { - this.summary = Some(summary.clone()); - cx.notify() - })?; - - Ok(summary) - }) - } - - fn generate_title(&mut self, cx: &mut Context) { - let Some(model) = self.summarization_model.clone() else { - return; - }; - - log::debug!( - "Generating title with model: {:?}", - self.summarization_model.as_ref().map(|model| model.name()) - ); - let mut request = LanguageModelRequest { - intent: Some(CompletionIntent::ThreadSummarization), - temperature: AgentSettings::temperature_for_model(&model, cx), - ..Default::default() - }; - - for message in &self.messages { - request.messages.extend(message.to_request()); - } - - request.messages.push(LanguageModelRequestMessage { - role: Role::User, - content: vec![SUMMARIZE_THREAD_PROMPT.into()], - cache: false, - }); - self.pending_title_generation = Some(cx.spawn(async move |this, cx| { - let mut title = String::new(); - - let generate = async { - let mut messages = model.stream_completion(request, cx).await?; - while let Some(event) = messages.next().await { - let event = event?; - let text = match event { - LanguageModelCompletionEvent::Text(text) => text, - LanguageModelCompletionEvent::StatusUpdate( - CompletionRequestStatus::UsageUpdated { amount, limit }, - ) => { - this.update(cx, |thread, cx| { - thread.update_model_request_usage(amount, limit, cx); - })?; - continue; - } - _ => continue, - }; - - let mut lines = text.lines(); - title.extend(lines.next()); - - // Stop if the LLM generated multiple lines. - if lines.next().is_some() { - break; - } - } - anyhow::Ok(()) - }; - - if generate.await.context("failed to generate title").is_ok() { - _ = this.update(cx, |this, cx| this.set_title(title.into(), cx)); - } - _ = this.update(cx, |this, _| this.pending_title_generation = None); - })); - } - - pub fn set_title(&mut self, title: SharedString, cx: &mut Context) { - self.pending_title_generation = None; - if Some(&title) != self.title.as_ref() { - self.title = Some(title); - cx.emit(TitleUpdated); - cx.notify(); - } - } - - fn last_user_message(&self) -> Option<&UserMessage> { - self.messages - .iter() - .rev() - .find_map(|message| match message { - Message::User(user_message) => Some(user_message), - Message::Agent(_) => None, - Message::Resume => None, - }) - } - - fn pending_message(&mut self) -> &mut AgentMessage { - self.pending_message.get_or_insert_default() - } - - fn flush_pending_message(&mut self, cx: &mut Context) { - let Some(mut message) = self.pending_message.take() else { - return; - }; - - if message.content.is_empty() { - return; - } - - for content in &message.content { - let AgentMessageContent::ToolUse(tool_use) = content else { - continue; - }; - - if !message.tool_results.contains_key(&tool_use.id) { - message.tool_results.insert( - tool_use.id.clone(), - LanguageModelToolResult { - tool_use_id: tool_use.id.clone(), - tool_name: tool_use.name.clone(), - is_error: true, - content: LanguageModelToolResultContent::Text(TOOL_CANCELED_MESSAGE.into()), - output: None, - }, - ); - } - } - - self.messages.push(Message::Agent(message)); - self.updated_at = Utc::now(); - self.summary = None; - cx.notify() - } - - pub(crate) fn build_completion_request( - &self, - completion_intent: CompletionIntent, - cx: &App, - ) -> Result { - let model = self.model().context("No language model configured")?; - let tools = if let Some(turn) = self.running_turn.as_ref() { - turn.tools - .iter() - .filter_map(|(tool_name, tool)| { - log::trace!("Including tool: {}", tool_name); - Some(LanguageModelRequestTool { - name: tool_name.to_string(), - description: tool.description().to_string(), - input_schema: tool.input_schema(model.tool_input_format()).log_err()?, - }) - }) - .collect::>() - } else { - Vec::new() - }; - - log::debug!("Building completion request"); - log::debug!("Completion intent: {:?}", completion_intent); - log::debug!("Completion mode: {:?}", self.completion_mode); - - let messages = self.build_request_messages(cx); - log::debug!("Request will include {} messages", messages.len()); - log::debug!("Request includes {} tools", tools.len()); - - let request = LanguageModelRequest { - thread_id: Some(self.id.to_string()), - prompt_id: Some(self.prompt_id.to_string()), - intent: Some(completion_intent), - mode: Some(self.completion_mode.into()), - messages, - tools, - tool_choice: None, - stop: Vec::new(), - temperature: AgentSettings::temperature_for_model(model, cx), - thinking_allowed: true, - }; - - log::debug!("Completion request built successfully"); - Ok(request) - } - - fn enabled_tools( - &self, - profile: &AgentProfileSettings, - model: &Arc, - cx: &App, - ) -> BTreeMap> { - fn truncate(tool_name: &SharedString) -> SharedString { - if tool_name.len() > MAX_TOOL_NAME_LENGTH { - let mut truncated = tool_name.to_string(); - truncated.truncate(MAX_TOOL_NAME_LENGTH); - truncated.into() - } else { - tool_name.clone() - } - } - - let mut tools = self - .tools - .iter() - .filter_map(|(tool_name, tool)| { - if tool.supported_provider(&model.provider_id()) - && profile.is_tool_enabled(tool_name) - { - Some((truncate(tool_name), tool.clone())) - } else { - None - } - }) - .collect::>(); - - let mut context_server_tools = Vec::new(); - let mut seen_tools = tools.keys().cloned().collect::>(); - let mut duplicate_tool_names = HashSet::default(); - for (server_id, server_tools) in self.context_server_registry.read(cx).servers() { - for (tool_name, tool) in server_tools { - if profile.is_context_server_tool_enabled(&server_id.0, &tool_name) { - let tool_name = truncate(tool_name); - if !seen_tools.insert(tool_name.clone()) { - duplicate_tool_names.insert(tool_name.clone()); - } - context_server_tools.push((server_id.clone(), tool_name, tool.clone())); - } - } - } - - // When there are duplicate tool names, disambiguate by prefixing them - // with the server ID. In the rare case there isn't enough space for the - // disambiguated tool name, keep only the last tool with this name. - for (server_id, tool_name, tool) in context_server_tools { - if duplicate_tool_names.contains(&tool_name) { - let available = MAX_TOOL_NAME_LENGTH.saturating_sub(tool_name.len()); - if available >= 2 { - let mut disambiguated = server_id.0.to_string(); - disambiguated.truncate(available - 1); - disambiguated.push('_'); - disambiguated.push_str(&tool_name); - tools.insert(disambiguated.into(), tool.clone()); - } else { - tools.insert(tool_name, tool.clone()); - } - } else { - tools.insert(tool_name, tool.clone()); - } - } - - tools - } - - fn tool(&self, name: &str) -> Option> { - self.running_turn.as_ref()?.tools.get(name).cloned() - } - - fn build_request_messages(&self, cx: &App) -> Vec { - log::trace!( - "Building request messages from {} thread messages", - self.messages.len() - ); - - let system_prompt = SystemPromptTemplate { - project: self.project_context.read(cx), - available_tools: self.tools.keys().cloned().collect(), - } - .render(&self.templates) - .context("failed to build system prompt") - .expect("Invalid template"); - let mut messages = vec![LanguageModelRequestMessage { - role: Role::System, - content: vec![system_prompt.into()], - cache: false, - }]; - for message in &self.messages { - messages.extend(message.to_request()); - } - - if let Some(last_message) = messages.last_mut() { - last_message.cache = true; - } - - if let Some(message) = self.pending_message.as_ref() { - messages.extend(message.to_request()); - } - - messages - } - - pub fn to_markdown(&self) -> String { - let mut markdown = String::new(); - for (ix, message) in self.messages.iter().enumerate() { - if ix > 0 { - markdown.push('\n'); - } - markdown.push_str(&message.to_markdown()); - } - - if let Some(message) = self.pending_message.as_ref() { - markdown.push('\n'); - markdown.push_str(&message.to_markdown()); - } - - markdown - } - - fn advance_prompt_id(&mut self) { - self.prompt_id = PromptId::new(); - } - - fn retry_strategy_for(error: &LanguageModelCompletionError) -> Option { - use LanguageModelCompletionError::*; - use http_client::StatusCode; - - // General strategy here: - // - If retrying won't help (e.g. invalid API key or payload too large), return None so we don't retry at all. - // - If it's a time-based issue (e.g. server overloaded, rate limit exceeded), retry up to 4 times with exponential backoff. - // - If it's an issue that *might* be fixed by retrying (e.g. internal server error), retry up to 3 times. - match error { - HttpResponseError { - status_code: StatusCode::TOO_MANY_REQUESTS, - .. - } => Some(RetryStrategy::ExponentialBackoff { - initial_delay: BASE_RETRY_DELAY, - max_attempts: MAX_RETRY_ATTEMPTS, - }), - ServerOverloaded { retry_after, .. } | RateLimitExceeded { retry_after, .. } => { - Some(RetryStrategy::Fixed { - delay: retry_after.unwrap_or(BASE_RETRY_DELAY), - max_attempts: MAX_RETRY_ATTEMPTS, - }) - } - UpstreamProviderError { - status, - retry_after, - .. - } => match *status { - StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE => { - Some(RetryStrategy::Fixed { - delay: retry_after.unwrap_or(BASE_RETRY_DELAY), - max_attempts: MAX_RETRY_ATTEMPTS, - }) - } - StatusCode::INTERNAL_SERVER_ERROR => Some(RetryStrategy::Fixed { - delay: retry_after.unwrap_or(BASE_RETRY_DELAY), - // Internal Server Error could be anything, retry up to 3 times. - max_attempts: 3, - }), - status => { - // There is no StatusCode variant for the unofficial HTTP 529 ("The service is overloaded"), - // but we frequently get them in practice. See https://http.dev/529 - if status.as_u16() == 529 { - Some(RetryStrategy::Fixed { - delay: retry_after.unwrap_or(BASE_RETRY_DELAY), - max_attempts: MAX_RETRY_ATTEMPTS, - }) - } else { - Some(RetryStrategy::Fixed { - delay: retry_after.unwrap_or(BASE_RETRY_DELAY), - max_attempts: 2, - }) - } - } - }, - ApiInternalServerError { .. } => Some(RetryStrategy::Fixed { - delay: BASE_RETRY_DELAY, - max_attempts: 3, - }), - ApiReadResponseError { .. } - | HttpSend { .. } - | DeserializeResponse { .. } - | BadRequestFormat { .. } => Some(RetryStrategy::Fixed { - delay: BASE_RETRY_DELAY, - max_attempts: 3, - }), - // Retrying these errors definitely shouldn't help. - HttpResponseError { - status_code: - StatusCode::PAYLOAD_TOO_LARGE | StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED, - .. - } - | AuthenticationError { .. } - | PermissionError { .. } - | NoApiKey { .. } - | ApiEndpointNotFound { .. } - | PromptTooLarge { .. } => None, - // These errors might be transient, so retry them - SerializeRequest { .. } | BuildRequestBody { .. } => Some(RetryStrategy::Fixed { - delay: BASE_RETRY_DELAY, - max_attempts: 1, - }), - // Retry all other 4xx and 5xx errors once. - HttpResponseError { status_code, .. } - if status_code.is_client_error() || status_code.is_server_error() => - { - Some(RetryStrategy::Fixed { - delay: BASE_RETRY_DELAY, - max_attempts: 3, - }) - } - Other(err) - if err.is::() - || err.is::() => - { - // Retrying won't help for Payment Required or Model Request Limit errors (where - // the user must upgrade to usage-based billing to get more requests, or else wait - // for a significant amount of time for the request limit to reset). - None - } - // Conservatively assume that any other errors are non-retryable - HttpResponseError { .. } | Other(..) => Some(RetryStrategy::Fixed { - delay: BASE_RETRY_DELAY, - max_attempts: 2, - }), - } - } -} - -struct RunningTurn { - /// Holds the task that handles agent interaction until the end of the turn. - /// Survives across multiple requests as the model performs tool calls and - /// we run tools, report their results. - _task: Task<()>, - /// The current event stream for the running turn. Used to report a final - /// cancellation event if we cancel the turn. - event_stream: ThreadEventStream, - /// The tools that were enabled for this turn. - tools: BTreeMap>, -} - -impl RunningTurn { - fn cancel(self) { - log::debug!("Cancelling in progress turn"); - self.event_stream.send_canceled(); - } -} - -pub struct TokenUsageUpdated(pub Option); - -impl EventEmitter for Thread {} - -pub struct TitleUpdated; - -impl EventEmitter for Thread {} - -pub trait AgentTool -where - Self: 'static + Sized, -{ - type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema; - type Output: for<'de> Deserialize<'de> + Serialize + Into; - - fn name() -> &'static str; - - fn description(&self) -> SharedString { - let schema = schemars::schema_for!(Self::Input); - SharedString::new( - schema - .get("description") - .and_then(|description| description.as_str()) - .unwrap_or_default(), - ) - } - - fn kind() -> acp::ToolKind; - - /// The initial tool title to display. Can be updated during the tool run. - fn initial_title( - &self, - input: Result, - cx: &mut App, - ) -> SharedString; - - /// Returns the JSON schema that describes the tool's input. - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Schema { - crate::tool_schema::root_schema_for::(format) - } - - /// Some tools rely on a provider for the underlying billing or other reasons. - /// Allow the tool to check if they are compatible, or should be filtered out. - fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { - true - } - - /// Runs the tool with the provided input. - fn run( - self: Arc, - input: Self::Input, - event_stream: ToolCallEventStream, - cx: &mut App, - ) -> Task>; - - /// Emits events for a previous execution of the tool. - fn replay( - &self, - _input: Self::Input, - _output: Self::Output, - _event_stream: ToolCallEventStream, - _cx: &mut App, - ) -> Result<()> { - Ok(()) - } - - fn erase(self) -> Arc { - Arc::new(Erased(Arc::new(self))) - } -} - -pub struct Erased(T); - -pub struct AgentToolOutput { - pub llm_output: LanguageModelToolResultContent, - pub raw_output: serde_json::Value, -} - -pub trait AnyAgentTool { - fn name(&self) -> SharedString; - fn description(&self) -> SharedString; - fn kind(&self) -> acp::ToolKind; - fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString; - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; - fn supported_provider(&self, _provider: &LanguageModelProviderId) -> bool { - true - } - fn run( - self: Arc, - input: serde_json::Value, - event_stream: ToolCallEventStream, - cx: &mut App, - ) -> Task>; - fn replay( - &self, - input: serde_json::Value, - output: serde_json::Value, - event_stream: ToolCallEventStream, - cx: &mut App, - ) -> Result<()>; -} - -impl AnyAgentTool for Erased> -where - T: AgentTool, -{ - fn name(&self) -> SharedString { - T::name().into() - } - - fn description(&self) -> SharedString { - self.0.description() - } - - fn kind(&self) -> agent_client_protocol::ToolKind { - T::kind() - } - - fn initial_title(&self, input: serde_json::Value, _cx: &mut App) -> SharedString { - let parsed_input = serde_json::from_value(input.clone()).map_err(|_| input); - self.0.initial_title(parsed_input, _cx) - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - let mut json = serde_json::to_value(self.0.input_schema(format))?; - adapt_schema_to_format(&mut json, format)?; - Ok(json) - } - - fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool { - self.0.supported_provider(provider) - } - - fn run( - self: Arc, - input: serde_json::Value, - event_stream: ToolCallEventStream, - cx: &mut App, - ) -> Task> { - cx.spawn(async move |cx| { - let input = serde_json::from_value(input)?; - let output = cx - .update(|cx| self.0.clone().run(input, event_stream, cx))? - .await?; - let raw_output = serde_json::to_value(&output)?; - Ok(AgentToolOutput { - llm_output: output.into(), - raw_output, - }) - }) - } - - fn replay( - &self, - input: serde_json::Value, - output: serde_json::Value, - event_stream: ToolCallEventStream, - cx: &mut App, - ) -> Result<()> { - let input = serde_json::from_value(input)?; - let output = serde_json::from_value(output)?; - self.0.replay(input, output, event_stream, cx) - } -} - -#[derive(Clone)] -struct ThreadEventStream(mpsc::UnboundedSender>); - -impl ThreadEventStream { - fn send_user_message(&self, message: &UserMessage) { - self.0 - .unbounded_send(Ok(ThreadEvent::UserMessage(message.clone()))) - .ok(); - } - - fn send_text(&self, text: &str) { - self.0 - .unbounded_send(Ok(ThreadEvent::AgentText(text.to_string()))) - .ok(); - } - - fn send_thinking(&self, text: &str) { - self.0 - .unbounded_send(Ok(ThreadEvent::AgentThinking(text.to_string()))) - .ok(); - } - - fn send_tool_call( - &self, - id: &LanguageModelToolUseId, - title: SharedString, - kind: acp::ToolKind, - input: serde_json::Value, - ) { - self.0 - .unbounded_send(Ok(ThreadEvent::ToolCall(Self::initial_tool_call( - id, - title.to_string(), - kind, - input, - )))) - .ok(); - } - - fn initial_tool_call( - id: &LanguageModelToolUseId, - title: String, - kind: acp::ToolKind, - input: serde_json::Value, - ) -> acp::ToolCall { - acp::ToolCall { - meta: None, - id: acp::ToolCallId(id.to_string().into()), - title, - kind, - status: acp::ToolCallStatus::Pending, - content: vec![], - locations: vec![], - raw_input: Some(input), - raw_output: None, - } - } - - fn update_tool_call_fields( - &self, - tool_use_id: &LanguageModelToolUseId, - fields: acp::ToolCallUpdateFields, - ) { - self.0 - .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( - acp::ToolCallUpdate { - meta: None, - id: acp::ToolCallId(tool_use_id.to_string().into()), - fields, - } - .into(), - ))) - .ok(); - } - - fn send_retry(&self, status: acp_thread::RetryStatus) { - self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok(); - } - - fn send_stop(&self, reason: acp::StopReason) { - self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok(); - } - - fn send_canceled(&self) { - self.0 - .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled))) - .ok(); - } - - fn send_error(&self, error: impl Into) { - self.0.unbounded_send(Err(error.into())).ok(); - } -} - -#[derive(Clone)] -pub struct ToolCallEventStream { - tool_use_id: LanguageModelToolUseId, - stream: ThreadEventStream, - fs: Option>, -} - -impl ToolCallEventStream { - #[cfg(test)] - pub fn test() -> (Self, ToolCallEventStreamReceiver) { - let (events_tx, events_rx) = mpsc::unbounded::>(); - - let stream = ToolCallEventStream::new("test_id".into(), ThreadEventStream(events_tx), None); - - (stream, ToolCallEventStreamReceiver(events_rx)) - } - - fn new( - tool_use_id: LanguageModelToolUseId, - stream: ThreadEventStream, - fs: Option>, - ) -> Self { - Self { - tool_use_id, - stream, - fs, - } - } - - pub fn update_fields(&self, fields: acp::ToolCallUpdateFields) { - self.stream - .update_tool_call_fields(&self.tool_use_id, fields); - } - - pub fn update_diff(&self, diff: Entity) { - self.stream - .0 - .unbounded_send(Ok(ThreadEvent::ToolCallUpdate( - acp_thread::ToolCallUpdateDiff { - id: acp::ToolCallId(self.tool_use_id.to_string().into()), - diff, - } - .into(), - ))) - .ok(); - } - - pub fn authorize(&self, title: impl Into, cx: &mut App) -> Task> { - if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { - return Task::ready(Ok(())); - } - - let (response_tx, response_rx) = oneshot::channel(); - self.stream - .0 - .unbounded_send(Ok(ThreadEvent::ToolCallAuthorization( - ToolCallAuthorization { - tool_call: acp::ToolCallUpdate { - meta: None, - id: acp::ToolCallId(self.tool_use_id.to_string().into()), - fields: acp::ToolCallUpdateFields { - title: Some(title.into()), - ..Default::default() - }, - }, - options: vec![ - acp::PermissionOption { - id: acp::PermissionOptionId("always_allow".into()), - name: "Always Allow".into(), - kind: acp::PermissionOptionKind::AllowAlways, - meta: None, - }, - acp::PermissionOption { - id: acp::PermissionOptionId("allow".into()), - name: "Allow".into(), - kind: acp::PermissionOptionKind::AllowOnce, - meta: None, - }, - acp::PermissionOption { - id: acp::PermissionOptionId("deny".into()), - name: "Deny".into(), - kind: acp::PermissionOptionKind::RejectOnce, - meta: None, - }, - ], - response: response_tx, - }, - ))) - .ok(); - let fs = self.fs.clone(); - cx.spawn(async move |cx| match response_rx.await?.0.as_ref() { - "always_allow" => { - if let Some(fs) = fs.clone() { - cx.update(|cx| { - update_settings_file(fs, cx, |settings, _| { - settings - .agent - .get_or_insert_default() - .set_always_allow_tool_actions(true); - }); - })?; - } - - Ok(()) - } - "allow" => Ok(()), - _ => Err(anyhow!("Permission to run tool denied by user")), - }) - } -} - -#[cfg(test)] -pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver>); - -#[cfg(test)] -impl ToolCallEventStreamReceiver { - pub async fn expect_authorization(&mut self) -> ToolCallAuthorization { - let event = self.0.next().await; - if let Some(Ok(ThreadEvent::ToolCallAuthorization(auth))) = event { - auth - } else { - panic!("Expected ToolCallAuthorization but got: {:?}", event); - } - } - - pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields { - let event = self.0.next().await; - if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields( - update, - )))) = event - { - update.fields - } else { - panic!("Expected update fields but got: {:?}", event); - } - } - - pub async fn expect_diff(&mut self) -> Entity { - let event = self.0.next().await; - if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff( - update, - )))) = event - { - update.diff - } else { - panic!("Expected diff but got: {:?}", event); - } - } - - pub async fn expect_terminal(&mut self) -> Entity { - let event = self.0.next().await; - if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal( - update, - )))) = event - { - update.terminal - } else { - panic!("Expected terminal but got: {:?}", event); - } - } -} - -#[cfg(test)] -impl std::ops::Deref for ToolCallEventStreamReceiver { - type Target = mpsc::UnboundedReceiver>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[cfg(test)] -impl std::ops::DerefMut for ToolCallEventStreamReceiver { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl From<&str> for UserMessageContent { - fn from(text: &str) -> Self { - Self::Text(text.into()) - } -} - -impl From for UserMessageContent { - fn from(value: acp::ContentBlock) -> Self { - match value { - acp::ContentBlock::Text(text_content) => Self::Text(text_content.text), - acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)), - acp::ContentBlock::Audio(_) => { - // TODO - Self::Text("[audio]".to_string()) - } - acp::ContentBlock::ResourceLink(resource_link) => { - match MentionUri::parse(&resource_link.uri) { - Ok(uri) => Self::Mention { - uri, - content: String::new(), - }, - Err(err) => { - log::error!("Failed to parse mention link: {}", err); - Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri)) - } - } - } - acp::ContentBlock::Resource(resource) => match resource.resource { - acp::EmbeddedResourceResource::TextResourceContents(resource) => { - match MentionUri::parse(&resource.uri) { - Ok(uri) => Self::Mention { - uri, - content: resource.text, - }, - Err(err) => { - log::error!("Failed to parse mention link: {}", err); - Self::Text( - MarkdownCodeBlock { - tag: &resource.uri, - text: &resource.text, - } - .to_string(), - ) - } - } - } - acp::EmbeddedResourceResource::BlobResourceContents(_) => { - // TODO - Self::Text("[blob]".to_string()) - } - }, - } - } -} - -impl From for acp::ContentBlock { - fn from(content: UserMessageContent) -> Self { - match content { - UserMessageContent::Text(text) => acp::ContentBlock::Text(acp::TextContent { - text, - annotations: None, - meta: None, - }), - UserMessageContent::Image(image) => acp::ContentBlock::Image(acp::ImageContent { - data: image.source.to_string(), - mime_type: "image/png".to_string(), - meta: None, - annotations: None, - uri: None, - }), - UserMessageContent::Mention { uri, content } => { - acp::ContentBlock::Resource(acp::EmbeddedResource { - meta: None, - resource: acp::EmbeddedResourceResource::TextResourceContents( - acp::TextResourceContents { - meta: None, - mime_type: None, - text: content, - uri: uri.to_uri().to_string(), - }, - ), - annotations: None, - }) - } - } - } -} - -fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { - LanguageModelImage { - source: image_content.data.into(), - // TODO: make this optional? - size: gpui::Size::new(0.into(), 0.into()), - } -} diff --git a/crates/agent2/src/tool_schema.rs b/crates/agent2/src/tool_schema.rs deleted file mode 100644 index f608336b416a72885e52abba58ef472029421e4f..0000000000000000000000000000000000000000 --- a/crates/agent2/src/tool_schema.rs +++ /dev/null @@ -1,43 +0,0 @@ -use language_model::LanguageModelToolSchemaFormat; -use schemars::{ - JsonSchema, Schema, - generate::SchemaSettings, - transform::{Transform, transform_subschemas}, -}; - -pub(crate) fn root_schema_for(format: LanguageModelToolSchemaFormat) -> Schema { - let mut generator = match format { - LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(), - LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3() - .with(|settings| { - settings.meta_schema = None; - settings.inline_subschemas = true; - }) - .with_transform(ToJsonSchemaSubsetTransform) - .into_generator(), - }; - generator.root_schema_for::() -} - -#[derive(Debug, Clone)] -struct ToJsonSchemaSubsetTransform; - -impl Transform for ToJsonSchemaSubsetTransform { - fn transform(&mut self, schema: &mut Schema) { - // Ensure that the type field is not an array, this happens when we use - // Option, the type will be [T, "null"]. - if let Some(type_field) = schema.get_mut("type") - && let Some(types) = type_field.as_array() - && let Some(first_type) = types.first() - { - *type_field = first_type.clone(); - } - - // oneOf is not supported, use anyOf instead - if let Some(one_of) = schema.remove("oneOf") { - schema.insert("anyOf".to_string(), one_of); - } - - transform_subschemas(self, schema); - } -} diff --git a/crates/agent2/src/tools.rs b/crates/agent2/src/tools.rs deleted file mode 100644 index bcca7eecd185b9381afded26fb573d14f50bc5be..0000000000000000000000000000000000000000 --- a/crates/agent2/src/tools.rs +++ /dev/null @@ -1,60 +0,0 @@ -mod context_server_registry; -mod copy_path_tool; -mod create_directory_tool; -mod delete_path_tool; -mod diagnostics_tool; -mod edit_file_tool; -mod fetch_tool; -mod find_path_tool; -mod grep_tool; -mod list_directory_tool; -mod move_path_tool; -mod now_tool; -mod open_tool; -mod read_file_tool; -mod terminal_tool; -mod thinking_tool; -mod web_search_tool; - -/// A list of all built in tool names, for use in deduplicating MCP tool names -pub fn default_tool_names() -> impl Iterator { - [ - CopyPathTool::name(), - CreateDirectoryTool::name(), - DeletePathTool::name(), - DiagnosticsTool::name(), - EditFileTool::name(), - FetchTool::name(), - FindPathTool::name(), - GrepTool::name(), - ListDirectoryTool::name(), - MovePathTool::name(), - NowTool::name(), - OpenTool::name(), - ReadFileTool::name(), - TerminalTool::name(), - ThinkingTool::name(), - WebSearchTool::name(), - ] - .into_iter() -} - -pub use context_server_registry::*; -pub use copy_path_tool::*; -pub use create_directory_tool::*; -pub use delete_path_tool::*; -pub use diagnostics_tool::*; -pub use edit_file_tool::*; -pub use fetch_tool::*; -pub use find_path_tool::*; -pub use grep_tool::*; -pub use list_directory_tool::*; -pub use move_path_tool::*; -pub use now_tool::*; -pub use open_tool::*; -pub use read_file_tool::*; -pub use terminal_tool::*; -pub use thinking_tool::*; -pub use web_search_tool::*; - -use crate::AgentTool; diff --git a/crates/agent_settings/src/agent_settings.rs b/crates/agent_settings/src/agent_settings.rs index ec05c95672fa29b6e4813207e3e592fff9d3be15..adab899fadf3b36d199dd13ee19dc8421da9da8f 100644 --- a/crates/agent_settings/src/agent_settings.rs +++ b/crates/agent_settings/src/agent_settings.rs @@ -15,10 +15,9 @@ use settings::{ pub use crate::agent_profile::*; -pub const SUMMARIZE_THREAD_PROMPT: &str = - include_str!("../../agent/src/prompts/summarize_thread_prompt.txt"); +pub const SUMMARIZE_THREAD_PROMPT: &str = include_str!("prompts/summarize_thread_prompt.txt"); pub const SUMMARIZE_THREAD_DETAILED_PROMPT: &str = - include_str!("../../agent/src/prompts/summarize_thread_detailed_prompt.txt"); + include_str!("prompts/summarize_thread_detailed_prompt.txt"); pub fn init(cx: &mut App) { AgentSettings::register(cx); diff --git a/crates/agent/src/prompts/summarize_thread_detailed_prompt.txt b/crates/agent_settings/src/prompts/summarize_thread_detailed_prompt.txt similarity index 100% rename from crates/agent/src/prompts/summarize_thread_detailed_prompt.txt rename to crates/agent_settings/src/prompts/summarize_thread_detailed_prompt.txt diff --git a/crates/agent/src/prompts/summarize_thread_prompt.txt b/crates/agent_settings/src/prompts/summarize_thread_prompt.txt similarity index 100% rename from crates/agent/src/prompts/summarize_thread_prompt.txt rename to crates/agent_settings/src/prompts/summarize_thread_prompt.txt diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 47d9f6d6a27a2ad5102e831094912208e66a9b43..d8f495c79614ff1aaf23c017160516c1e54065ab 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -20,7 +20,6 @@ acp_thread.workspace = true action_log.workspace = true agent-client-protocol.workspace = true agent.workspace = true -agent2.workspace = true agent_servers.workspace = true agent_settings.workspace = true ai_onboarding.workspace = true @@ -29,7 +28,6 @@ arrayvec.workspace = true assistant_context.workspace = true assistant_slash_command.workspace = true assistant_slash_commands.workspace = true -assistant_tool.workspace = true audio.workspace = true buffer_diff.workspace = true chrono.workspace = true @@ -71,6 +69,7 @@ postage.workspace = true project.workspace = true prompt_store.workspace = true proto.workspace = true +ref-cast.workspace = true release_channel.workspace = true rope.workspace = true rules_library.workspace = true @@ -104,9 +103,7 @@ zed_actions.workspace = true [dev-dependencies] acp_thread = { workspace = true, features = ["test-support"] } agent = { workspace = true, features = ["test-support"] } -agent2 = { workspace = true, features = ["test-support"] } assistant_context = { workspace = true, features = ["test-support"] } -assistant_tools.workspace = true buffer_diff = { workspace = true, features = ["test-support"] } db = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_ui/src/acp/completion_provider.rs b/crates/agent_ui/src/acp/completion_provider.rs index 7588e9f53b32302b3a078f44b3cf85be56ca1b4b..73f0622df878c2abc1d2feef945ef2e771dceaf9 100644 --- a/crates/agent_ui/src/acp/completion_provider.rs +++ b/crates/agent_ui/src/acp/completion_provider.rs @@ -6,8 +6,8 @@ use std::sync::Arc; use std::sync::atomic::AtomicBool; use acp_thread::MentionUri; +use agent::{HistoryEntry, HistoryStore}; use agent_client_protocol as acp; -use agent2::{HistoryEntry, HistoryStore}; use anyhow::Result; use editor::{CompletionProvider, Editor, ExcerptId}; use fuzzy::{StringMatch, StringMatchCandidate}; @@ -32,6 +32,7 @@ use crate::context_picker::file_context_picker::{FileMatch, search_files}; use crate::context_picker::rules_context_picker::{RulesContextEntry, search_rules}; use crate::context_picker::symbol_context_picker::SymbolMatch; use crate::context_picker::symbol_context_picker::search_symbols; +use crate::context_picker::thread_context_picker::search_threads; use crate::context_picker::{ ContextPickerAction, ContextPickerEntry, ContextPickerMode, selection_ranges, }; @@ -938,42 +939,6 @@ impl CompletionProvider for ContextPickerCompletionProvider { } } -pub(crate) fn search_threads( - query: String, - cancellation_flag: Arc, - history_store: &Entity, - cx: &mut App, -) -> Task> { - let threads = history_store.read(cx).entries().collect(); - if query.is_empty() { - return Task::ready(threads); - } - - let executor = cx.background_executor().clone(); - cx.background_spawn(async move { - let candidates = threads - .iter() - .enumerate() - .map(|(id, thread)| StringMatchCandidate::new(id, thread.title())) - .collect::>(); - let matches = fuzzy::match_strings( - &candidates, - &query, - false, - true, - 100, - &cancellation_flag, - executor, - ) - .await; - - matches - .into_iter() - .map(|mat| threads[mat.candidate_id].clone()) - .collect() - }) -} - fn confirm_completion_callback( crease_text: SharedString, start: Anchor, diff --git a/crates/agent_ui/src/acp/entry_view_state.rs b/crates/agent_ui/src/acp/entry_view_state.rs index ee506b98810ba51d0fb933a2ca21e650d0cacc0b..8123c4a422b9d95a2da45e75ceb4079675d845fd 100644 --- a/crates/agent_ui/src/acp/entry_view_state.rs +++ b/crates/agent_ui/src/acp/entry_view_state.rs @@ -1,8 +1,8 @@ use std::{cell::RefCell, ops::Range, rc::Rc}; use acp_thread::{AcpThread, AgentThreadEntry}; +use agent::HistoryStore; use agent_client_protocol::{self as acp, ToolCallId}; -use agent2::HistoryStore; use collections::HashMap; use editor::{Editor, EditorMode, MinimapVisibility}; use gpui::{ @@ -399,9 +399,9 @@ mod tests { use std::{path::Path, rc::Rc}; use acp_thread::{AgentConnection, StubAgentConnection}; + use agent::HistoryStore; use agent_client_protocol as acp; use agent_settings::AgentSettings; - use agent2::HistoryStore; use assistant_context::ContextStore; use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind}; use editor::{EditorSettings, RowInfo}; diff --git a/crates/agent_ui/src/acp/message_editor.rs b/crates/agent_ui/src/acp/message_editor.rs index be1c205cee661d401d577b0bcb2d50dc62b4e38c..57157e59c6b48541ff82bdc417bc119ed01bb997 100644 --- a/crates/agent_ui/src/acp/message_editor.rs +++ b/crates/agent_ui/src/acp/message_editor.rs @@ -3,12 +3,11 @@ use crate::{ context_picker::{ContextPickerAction, fetch_context_picker::fetch_url_content}, }; use acp_thread::{MentionUri, selection_name}; +use agent::{HistoryStore, outline}; use agent_client_protocol as acp; use agent_servers::{AgentServer, AgentServerDelegate}; -use agent2::HistoryStore; use anyhow::{Result, anyhow}; use assistant_slash_commands::codeblock_fence_for_path; -use assistant_tool::outline; use collections::{HashMap, HashSet}; use editor::{ Addon, Anchor, AnchorRangeExt, ContextMenuOptions, ContextMenuPlacement, Editor, EditorElement, @@ -230,7 +229,7 @@ impl MessageEditor { pub fn insert_thread_summary( &mut self, - thread: agent2::DbThreadMetadata, + thread: agent::DbThreadMetadata, window: &mut Window, cx: &mut Context, ) { @@ -599,7 +598,7 @@ impl MessageEditor { id: acp::SessionId, cx: &mut Context, ) -> Task> { - let server = Rc::new(agent2::NativeAgentServer::new( + let server = Rc::new(agent::NativeAgentServer::new( self.project.read(cx).fs().clone(), self.history_store.clone(), )); @@ -612,7 +611,7 @@ impl MessageEditor { let connection = server.connect(None, delegate, cx); cx.spawn(async move |_, cx| { let (agent, _) = connection.await?; - let agent = agent.downcast::().unwrap(); + let agent = agent.downcast::().unwrap(); let summary = agent .0 .update(cx, |agent, cx| agent.thread_summary(id, cx))? @@ -629,8 +628,8 @@ impl MessageEditor { path: PathBuf, cx: &mut Context, ) -> Task> { - let context = self.history_store.update(cx, |text_thread_store, cx| { - text_thread_store.load_text_thread(path.as_path().into(), cx) + let context = self.history_store.update(cx, |store, cx| { + store.load_text_thread(path.as_path().into(), cx) }); cx.spawn(async move |_, cx| { let context = context.await?; @@ -1589,10 +1588,9 @@ mod tests { use std::{cell::RefCell, ops::Range, path::Path, rc::Rc, sync::Arc}; use acp_thread::MentionUri; + use agent::{HistoryStore, outline}; use agent_client_protocol as acp; - use agent2::HistoryStore; use assistant_context::ContextStore; - use assistant_tool::outline; use editor::{AnchorRangeExt as _, Editor, EditorMode}; use fs::FakeFs; use futures::StreamExt as _; diff --git a/crates/agent_ui/src/acp/thread_history.rs b/crates/agent_ui/src/acp/thread_history.rs index cd696f33fa44976e0784c79d1945b548feb20a50..ee280eb9a123e46ba5cf3b75cdeaf67c4b98b71c 100644 --- a/crates/agent_ui/src/acp/thread_history.rs +++ b/crates/agent_ui/src/acp/thread_history.rs @@ -1,6 +1,6 @@ use crate::acp::AcpThreadView; use crate::{AgentPanel, RemoveSelectedThread}; -use agent2::{HistoryEntry, HistoryStore}; +use agent::{HistoryEntry, HistoryStore}; use chrono::{Datelike as _, Local, NaiveDate, TimeDelta}; use editor::{Editor, EditorEvent}; use fuzzy::StringMatchCandidate; @@ -23,11 +23,8 @@ pub struct AcpThreadHistory { hovered_index: Option, search_editor: Entity, search_query: SharedString, - visible_items: Vec, - local_timezone: UtcOffset, - _update_task: Task<()>, _subscriptions: Vec, } @@ -62,7 +59,7 @@ impl EventEmitter for AcpThreadHistory {} impl AcpThreadHistory { pub(crate) fn new( - history_store: Entity, + history_store: Entity, window: &mut Window, cx: &mut Context, ) -> Self { @@ -642,7 +639,7 @@ impl RenderOnce for AcpHistoryEntryElement { if let Some(panel) = workspace.read(cx).panel::(cx) { panel.update(cx, |panel, cx| { panel - .open_saved_prompt_editor( + .open_saved_text_thread( context.path.clone(), window, cx, diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index a5af5b2521894b2051e6edfbe8677aa86177f6f1..cb2e8be2701c2152ef889f7bdc9925f8014f9519 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -5,10 +5,10 @@ use acp_thread::{ }; use acp_thread::{AgentConnection, Plan}; use action_log::ActionLog; +use agent::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer}; use agent_client_protocol::{self as acp, PromptCapabilities}; use agent_servers::{AgentServer, AgentServerDelegate}; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; -use agent2::{DbThreadMetadata, HistoryEntry, HistoryEntryId, HistoryStore, NativeAgentServer}; use anyhow::{Result, anyhow, bail}; use arrayvec::ArrayVec; use audio::{Audio, Sound}; @@ -117,7 +117,7 @@ impl ThreadError { } } -impl ProfileProvider for Entity { +impl ProfileProvider for Entity { fn profile_id(&self, cx: &App) -> AgentProfileId { self.read(cx).profile().clone() } @@ -529,7 +529,7 @@ impl AcpThreadView { let result = if let Some(native_agent) = connection .clone() - .downcast::() + .downcast::() && let Some(resume) = resume_thread.clone() { cx.update(|_, cx| { @@ -3106,7 +3106,7 @@ impl AcpThreadView { let render_history = self .agent .clone() - .downcast::() + .downcast::() .is_some() && self .history_store @@ -4011,12 +4011,12 @@ impl AcpThreadView { pub(crate) fn as_native_connection( &self, cx: &App, - ) -> Option> { + ) -> Option> { let acp_thread = self.thread()?.read(cx); acp_thread.connection().clone().downcast() } - pub(crate) fn as_native_thread(&self, cx: &App) -> Option> { + pub(crate) fn as_native_thread(&self, cx: &App) -> Option> { let acp_thread = self.thread()?.read(cx); self.as_native_connection(cx)? .thread(acp_thread.session_id(), cx) @@ -4404,7 +4404,7 @@ impl AcpThreadView { if let Some(panel) = workspace.panel::(cx) { panel.update(cx, |panel, cx| { panel - .open_saved_prompt_editor(path.as_path().into(), window, cx) + .open_saved_text_thread(path.as_path().into(), window, cx) .detach_and_log_err(cx); }); } @@ -5137,7 +5137,7 @@ impl AcpThreadView { if self .agent .clone() - .downcast::() + .downcast::() .is_some() { // Native agent - use the model name diff --git a/crates/agent_ui/src/agent_configuration.rs b/crates/agent_ui/src/agent_configuration.rs index 386eeaca1924ace86e4138fdfc283bfe0c20fae0..ef0d4735d2d7690111ee2549cdee8ab31e32196e 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -6,8 +6,8 @@ mod tool_picker; use std::{ops::Range, sync::Arc}; +use agent::ContextServerRegistry; use anyhow::Result; -use assistant_tool::{ToolSource, ToolWorkingSet}; use cloud_llm_client::{Plan, PlanV1, PlanV2}; use collections::HashMap; use context_server::ContextServerId; @@ -17,7 +17,7 @@ use extension_host::ExtensionStore; use fs::Fs; use gpui::{ Action, AnyView, App, AsyncWindowContext, Corner, Entity, EventEmitter, FocusHandle, Focusable, - Hsla, ScrollHandle, Subscription, Task, WeakEntity, + ScrollHandle, Subscription, Task, WeakEntity, }; use language::LanguageRegistry; use language_model::{ @@ -54,9 +54,8 @@ pub struct AgentConfiguration { focus_handle: FocusHandle, configuration_views_by_provider: HashMap, context_server_store: Entity, - expanded_context_server_tools: HashMap, expanded_provider_configurations: HashMap, - tools: Entity, + context_server_registry: Entity, _registry_subscription: Subscription, scroll_handle: ScrollHandle, _check_for_gemini: Task<()>, @@ -67,7 +66,7 @@ impl AgentConfiguration { fs: Arc, agent_server_store: Entity, context_server_store: Entity, - tools: Entity, + context_server_registry: Entity, language_registry: Arc, workspace: WeakEntity, window: &mut Window, @@ -103,9 +102,8 @@ impl AgentConfiguration { configuration_views_by_provider: HashMap::default(), agent_server_store, context_server_store, - expanded_context_server_tools: HashMap::default(), expanded_provider_configurations: HashMap::default(), - tools, + context_server_registry, _registry_subscription: registry_subscription, scroll_handle: ScrollHandle::new(), _check_for_gemini: Task::ready(()), @@ -438,10 +436,6 @@ impl AgentConfiguration { } } - fn card_item_border_color(&self, cx: &mut Context) -> Hsla { - cx.theme().colors().border.opacity(0.6) - } - fn render_context_servers_section( &mut self, window: &mut Window, @@ -567,7 +561,6 @@ impl AgentConfiguration { window: &mut Window, cx: &mut Context, ) -> impl use<> + IntoElement { - let tools_by_source = self.tools.read(cx).tools_by_source(cx); let server_status = self .context_server_store .read(cx) @@ -596,17 +589,11 @@ impl AgentConfiguration { None }; - let are_tools_expanded = self - .expanded_context_server_tools - .get(&context_server_id) - .copied() - .unwrap_or_default(); - let tools = tools_by_source - .get(&ToolSource::ContextServer { - id: context_server_id.0.clone().into(), - }) - .map_or([].as_slice(), |tools| tools.as_slice()); - let tool_count = tools.len(); + let tool_count = self + .context_server_registry + .read(cx) + .tools_for_server(&context_server_id) + .count(); let (source_icon, source_tooltip) = if is_from_extension { ( @@ -660,7 +647,7 @@ impl AgentConfiguration { let language_registry = self.language_registry.clone(); let context_server_store = self.context_server_store.clone(); let workspace = self.workspace.clone(); - let tools = self.tools.clone(); + let context_server_registry = self.context_server_registry.clone(); move |window, cx| { Some(ContextMenu::build(window, cx, |menu, _window, _cx| { @@ -678,20 +665,16 @@ impl AgentConfiguration { ) .detach_and_log_err(cx); } - }).when(tool_count >= 1, |this| this.entry("View Tools", None, { + }).when(tool_count > 0, |this| this.entry("View Tools", None, { let context_server_id = context_server_id.clone(); - let tools = tools.clone(); + let context_server_registry = context_server_registry.clone(); let workspace = workspace.clone(); - move |window, cx| { let context_server_id = context_server_id.clone(); - let tools = tools.clone(); - let workspace = workspace.clone(); - workspace.update(cx, |workspace, cx| { ConfigureContextServerToolsModal::toggle( context_server_id, - tools, + context_server_registry.clone(), workspace, window, cx, @@ -773,14 +756,6 @@ impl AgentConfiguration { .child( h_flex() .justify_between() - .when( - error.is_none() && are_tools_expanded && tool_count >= 1, - |element| { - element - .border_b_1() - .border_color(self.card_item_border_color(cx)) - }, - ) .child( h_flex() .flex_1() @@ -904,11 +879,6 @@ impl AgentConfiguration { ), ); } - - if !are_tools_expanded || tools.is_empty() { - return parent; - } - parent }) } diff --git a/crates/agent_ui/src/agent_configuration/configure_context_server_tools_modal.rs b/crates/agent_ui/src/agent_configuration/configure_context_server_tools_modal.rs index 5a59806972ecf1b6cbc0702809c98acf1a86b387..3fe0b8d1b1400b4362192261995ed5b6bd1cb662 100644 --- a/crates/agent_ui/src/agent_configuration/configure_context_server_tools_modal.rs +++ b/crates/agent_ui/src/agent_configuration/configure_context_server_tools_modal.rs @@ -1,4 +1,5 @@ -use assistant_tool::{ToolSource, ToolWorkingSet}; +use agent::ContextServerRegistry; +use collections::HashMap; use context_server::ContextServerId; use gpui::{ DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle, Window, prelude::*, @@ -8,37 +9,37 @@ use workspace::{ModalView, Workspace}; pub struct ConfigureContextServerToolsModal { context_server_id: ContextServerId, - tools: Entity, + context_server_registry: Entity, focus_handle: FocusHandle, - expanded_tools: std::collections::HashMap, + expanded_tools: HashMap, scroll_handle: ScrollHandle, } impl ConfigureContextServerToolsModal { fn new( context_server_id: ContextServerId, - tools: Entity, + context_server_registry: Entity, _window: &mut Window, cx: &mut Context, ) -> Self { Self { context_server_id, - tools, + context_server_registry, focus_handle: cx.focus_handle(), - expanded_tools: std::collections::HashMap::new(), + expanded_tools: HashMap::default(), scroll_handle: ScrollHandle::new(), } } pub fn toggle( context_server_id: ContextServerId, - tools: Entity, + context_server_registry: Entity, workspace: &mut Workspace, window: &mut Window, cx: &mut Context, ) { workspace.toggle_modal(window, cx, |window, cx| { - Self::new(context_server_id, tools, window, cx) + Self::new(context_server_id, context_server_registry, window, cx) }); } @@ -51,13 +52,11 @@ impl ConfigureContextServerToolsModal { window: &mut Window, cx: &mut Context, ) -> impl IntoElement { - let tools_by_source = self.tools.read(cx).tools_by_source(cx); - let server_tools = tools_by_source - .get(&ToolSource::ContextServer { - id: self.context_server_id.0.clone().into(), - }) - .map(|tools| tools.as_slice()) - .unwrap_or(&[]); + let tools = self + .context_server_registry + .read(cx) + .tools_for_server(&self.context_server_id) + .collect::>(); div() .size_full() @@ -70,11 +69,11 @@ impl ConfigureContextServerToolsModal { .max_h_128() .overflow_y_scroll() .track_scroll(&self.scroll_handle) - .children(server_tools.iter().enumerate().flat_map(|(index, tool)| { + .children(tools.iter().enumerate().flat_map(|(index, tool)| { let tool_name = tool.name(); let is_expanded = self .expanded_tools - .get(&tool_name) + .get(tool_name.as_ref()) .copied() .unwrap_or(false); @@ -110,7 +109,7 @@ impl ConfigureContextServerToolsModal { move |this, _event, _window, _cx| { let current = this .expanded_tools - .get(&tool_name) + .get(tool_name.as_ref()) .copied() .unwrap_or(false); this.expanded_tools @@ -127,7 +126,7 @@ impl ConfigureContextServerToolsModal { .into_any_element(), ]; - if index < server_tools.len() - 1 { + if index < tools.len() - 1 { items.push( h_flex() .w_full() diff --git a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs index 9a7f0ed602a52d3b27dde565383453f2c5c325fb..fc4bde2c784894b94b7ce35e6e262e52865ffcd1 100644 --- a/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs +++ b/crates/agent_ui/src/agent_configuration/manage_profiles_modal.rs @@ -2,8 +2,8 @@ mod profile_modal_header; use std::sync::Arc; +use agent::ContextServerRegistry; use agent_settings::{AgentProfile, AgentProfileId, AgentSettings, builtin_profiles}; -use assistant_tool::ToolWorkingSet; use editor::Editor; use fs::Fs; use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Subscription, prelude::*}; @@ -17,8 +17,6 @@ use crate::agent_configuration::manage_profiles_modal::profile_modal_header::Pro use crate::agent_configuration::tool_picker::{ToolPicker, ToolPickerDelegate}; use crate::{AgentPanel, ManageProfiles}; -use super::tool_picker::ToolPickerMode; - enum Mode { ChooseProfile(ChooseProfileMode), NewProfile(NewProfileMode), @@ -97,7 +95,7 @@ pub struct NewProfileMode { pub struct ManageProfilesModal { fs: Arc, - tools: Entity, + context_server_registry: Entity, focus_handle: FocusHandle, mode: Mode, } @@ -111,10 +109,9 @@ impl ManageProfilesModal { workspace.register_action(|workspace, action: &ManageProfiles, window, cx| { if let Some(panel) = workspace.panel::(cx) { let fs = workspace.app_state().fs.clone(); - let thread_store = panel.read(cx).thread_store(); - let tools = thread_store.read(cx).tools(); + let context_server_registry = panel.read(cx).context_server_registry().clone(); workspace.toggle_modal(window, cx, |window, cx| { - let mut this = Self::new(fs, tools, window, cx); + let mut this = Self::new(fs, context_server_registry, window, cx); if let Some(profile_id) = action.customize_tools.clone() { this.configure_builtin_tools(profile_id, window, cx); @@ -128,7 +125,7 @@ impl ManageProfilesModal { pub fn new( fs: Arc, - tools: Entity, + context_server_registry: Entity, window: &mut Window, cx: &mut Context, ) -> Self { @@ -136,7 +133,7 @@ impl ManageProfilesModal { Self { fs, - tools, + context_server_registry, focus_handle, mode: Mode::choose_profile(window, cx), } @@ -193,10 +190,9 @@ impl ManageProfilesModal { }; let tool_picker = cx.new(|cx| { - let delegate = ToolPickerDelegate::new( - ToolPickerMode::McpTools, + let delegate = ToolPickerDelegate::mcp_tools( + &self.context_server_registry, self.fs.clone(), - self.tools.clone(), profile_id.clone(), profile, cx, @@ -230,10 +226,12 @@ impl ManageProfilesModal { }; let tool_picker = cx.new(|cx| { - let delegate = ToolPickerDelegate::new( - ToolPickerMode::BuiltinTools, + let delegate = ToolPickerDelegate::builtin_tools( + //todo: This causes the web search tool to show up even it only works when using zed hosted models + agent::built_in_tool_names() + .map(|s| s.into()) + .collect::>(), self.fs.clone(), - self.tools.clone(), profile_id.clone(), profile, cx, diff --git a/crates/agent_ui/src/agent_configuration/tool_picker.rs b/crates/agent_ui/src/agent_configuration/tool_picker.rs index c624948944c0624e75e385d1b4b15aa77fea9bcd..6b84205e1bd6336d70751090d8f0451b1b1925b0 100644 --- a/crates/agent_ui/src/agent_configuration/tool_picker.rs +++ b/crates/agent_ui/src/agent_configuration/tool_picker.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, sync::Arc}; +use agent::ContextServerRegistry; use agent_settings::{AgentProfileId, AgentProfileSettings}; -use assistant_tool::{ToolSource, ToolWorkingSet}; use fs::Fs; use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window}; use picker::{Picker, PickerDelegate}; @@ -14,7 +14,7 @@ pub struct ToolPicker { } #[derive(Clone, Copy, Debug, PartialEq)] -pub enum ToolPickerMode { +enum ToolPickerMode { BuiltinTools, McpTools, } @@ -76,59 +76,79 @@ pub struct ToolPickerDelegate { } impl ToolPickerDelegate { - pub fn new( - mode: ToolPickerMode, + pub fn builtin_tools( + tool_names: Vec>, fs: Arc, - tool_set: Entity, profile_id: AgentProfileId, profile_settings: AgentProfileSettings, cx: &mut Context, ) -> Self { - let items = Arc::new(Self::resolve_items(mode, &tool_set, cx)); + Self::new( + Arc::new( + tool_names + .into_iter() + .map(|name| PickerItem::Tool { + name, + server_id: None, + }) + .collect(), + ), + ToolPickerMode::BuiltinTools, + fs, + profile_id, + profile_settings, + cx, + ) + } + pub fn mcp_tools( + registry: &Entity, + fs: Arc, + profile_id: AgentProfileId, + profile_settings: AgentProfileSettings, + cx: &mut Context, + ) -> Self { + let mut items = Vec::new(); + + for (id, tools) in registry.read(cx).servers() { + let server_id = id.clone().0; + items.push(PickerItem::ContextServer { + server_id: server_id.clone(), + }); + items.extend(tools.keys().map(|tool_name| PickerItem::Tool { + name: tool_name.clone().into(), + server_id: Some(server_id.clone()), + })); + } + + Self::new( + Arc::new(items), + ToolPickerMode::McpTools, + fs, + profile_id, + profile_settings, + cx, + ) + } + + fn new( + items: Arc>, + mode: ToolPickerMode, + fs: Arc, + profile_id: AgentProfileId, + profile_settings: AgentProfileSettings, + cx: &mut Context, + ) -> Self { Self { tool_picker: cx.entity().downgrade(), + mode, fs, items, profile_id, profile_settings, filtered_items: Vec::new(), selected_index: 0, - mode, - } - } - - fn resolve_items( - mode: ToolPickerMode, - tool_set: &Entity, - cx: &mut App, - ) -> Vec { - let mut items = Vec::new(); - for (source, tools) in tool_set.read(cx).tools_by_source(cx) { - match source { - ToolSource::Native => { - if mode == ToolPickerMode::BuiltinTools { - items.extend(tools.into_iter().map(|tool| PickerItem::Tool { - name: tool.name().into(), - server_id: None, - })); - } - } - ToolSource::ContextServer { id } => { - if mode == ToolPickerMode::McpTools && !tools.is_empty() { - let server_id: Arc = id.clone().into(); - items.push(PickerItem::ContextServer { - server_id: server_id.clone(), - }); - items.extend(tools.into_iter().map(|tool| PickerItem::Tool { - name: tool.name().into(), - server_id: Some(server_id.clone()), - })); - } - } - } } - items } } diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index bcba02e3cf2056a27b58b53ab6947b8775e4bfda..2def41c74dd715637f269e572342d04b944505b5 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -4,7 +4,7 @@ use std::rc::Rc; use std::sync::Arc; use acp_thread::AcpThread; -use agent2::{DbThreadMetadata, HistoryEntry}; +use agent::{ContextServerRegistry, DbThreadMetadata, HistoryEntry, HistoryStore}; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use project::agent_server_store::{ AgentServerCommand, AllAgentServersSettings, CLAUDE_CODE_NAME, CODEX_NAME, GEMINI_NAME, @@ -17,6 +17,7 @@ use zed_actions::OpenBrowser; use zed_actions::agent::{OpenClaudeCodeOnboardingModal, ReauthenticateAgent}; use crate::acp::{AcpThreadHistory, ThreadHistoryEvent}; +use crate::context_store::ContextStore; use crate::ui::{AcpOnboardingModal, ClaudeCodeOnboardingModal}; use crate::{ AddContextServer, AgentDiffPane, DeleteRecentlyOpenThread, Follow, InlineAssistant, @@ -32,16 +33,11 @@ use crate::{ use crate::{ ExternalAgent, NewExternalAgentThread, NewNativeAgentThreadFromSummary, placeholder_command, }; -use agent::{ - context_store::ContextStore, - thread_store::{TextThreadStore, ThreadStore}, -}; use agent_settings::AgentSettings; use ai_onboarding::AgentPanelOnboarding; use anyhow::{Result, anyhow}; use assistant_context::{AssistantContext, ContextEvent, ContextSummary}; use assistant_slash_command::SlashCommandWorkingSet; -use assistant_tool::ToolWorkingSet; use client::{UserStore, zed_urls}; use cloud_llm_client::{Plan, PlanV1, PlanV2, UsageLimit}; use editor::{Anchor, AnchorRangeExt as _, Editor, EditorEvent, MultiBuffer}; @@ -118,7 +114,7 @@ pub fn init(cx: &mut App) { .register_action(|workspace, _: &NewTextThread, window, cx| { if let Some(panel) = workspace.panel::(cx) { workspace.focus_panel::(window, cx); - panel.update(cx, |panel, cx| panel.new_prompt_editor(window, cx)); + panel.update(cx, |panel, cx| panel.new_text_thread(window, cx)); } }) .register_action(|workspace, action: &NewExternalAgentThread, window, cx| { @@ -281,7 +277,7 @@ impl ActiveView { pub fn native_agent( fs: Arc, prompt_store: Option>, - acp_history_store: Entity, + history_store: Entity, project: Entity, workspace: WeakEntity, window: &mut Window, @@ -289,12 +285,12 @@ impl ActiveView { ) -> Self { let thread_view = cx.new(|cx| { crate::acp::AcpThreadView::new( - ExternalAgent::NativeAgent.server(fs, acp_history_store.clone()), + ExternalAgent::NativeAgent.server(fs, history_store.clone()), None, None, workspace, project, - acp_history_store, + history_store, prompt_store, window, cx, @@ -304,9 +300,9 @@ impl ActiveView { Self::ExternalAgentThread { thread_view } } - pub fn prompt_editor( + pub fn text_thread( context_editor: Entity, - acp_history_store: Entity, + acp_history_store: Entity, language_registry: Arc, window: &mut Window, cx: &mut App, @@ -379,7 +375,7 @@ impl ActiveView { .replace_recently_opened_text_thread(old_path, new_path, cx); } else { history_store.push_recently_opened_entry( - agent2::HistoryEntryId::TextThread(new_path.clone()), + agent::HistoryEntryId::TextThread(new_path.clone()), cx, ); } @@ -412,11 +408,11 @@ pub struct AgentPanel { project: Entity, fs: Arc, language_registry: Arc, - thread_store: Entity, acp_history: Entity, - history_store: Entity, - context_store: Entity, + history_store: Entity, + text_thread_store: Entity, prompt_store: Option>, + context_server_registry: Entity, inline_assist_context_store: Entity, configuration: Option>, configuration_subscription: Option, @@ -424,8 +420,8 @@ pub struct AgentPanel { previous_view: Option, new_thread_menu_handle: PopoverMenuHandle, agent_panel_menu_handle: PopoverMenuHandle, - assistant_navigation_menu_handle: PopoverMenuHandle, - assistant_navigation_menu: Option>, + agent_navigation_menu_handle: PopoverMenuHandle, + agent_navigation_menu: Option>, width: Option, height: Option, zoomed: bool, @@ -463,33 +459,6 @@ impl AgentPanel { Ok(prompt_store) => prompt_store.await.ok(), Err(_) => None, }; - let tools = cx.new(|_| ToolWorkingSet::default())?; - let thread_store = workspace - .update(cx, |workspace, cx| { - let project = workspace.project().clone(); - ThreadStore::load( - project, - tools.clone(), - prompt_store.clone(), - prompt_builder.clone(), - cx, - ) - })? - .await?; - - let slash_commands = Arc::new(SlashCommandWorkingSet::default()); - let context_store = workspace - .update(cx, |workspace, cx| { - let project = workspace.project().clone(); - assistant_context::ContextStore::new( - project, - prompt_builder.clone(), - slash_commands, - cx, - ) - })? - .await?; - let serialized_panel = if let Some(panel) = cx .background_spawn(async move { KEY_VALUE_STORE.read_kvp(AGENT_PANEL_KEY) }) .await @@ -501,17 +470,22 @@ impl AgentPanel { None }; - let panel = workspace.update_in(cx, |workspace, window, cx| { - let panel = cx.new(|cx| { - Self::new( - workspace, - thread_store, - context_store, - prompt_store, - window, + let slash_commands = Arc::new(SlashCommandWorkingSet::default()); + let text_thread_store = workspace + .update(cx, |workspace, cx| { + let project = workspace.project().clone(); + assistant_context::ContextStore::new( + project, + prompt_builder, + slash_commands, cx, ) - }); + })? + .await?; + + let panel = workspace.update_in(cx, |workspace, window, cx| { + let panel = + cx.new(|cx| Self::new(workspace, text_thread_store, prompt_store, window, cx)); panel.as_mut(cx).loading = true; if let Some(serialized_panel) = serialized_panel { @@ -538,8 +512,7 @@ impl AgentPanel { fn new( workspace: &Workspace, - thread_store: Entity, - context_store: Entity, + text_thread_store: Entity, prompt_store: Option>, window: &mut Window, cx: &mut Context, @@ -551,10 +524,11 @@ impl AgentPanel { let client = workspace.client().clone(); let workspace = workspace.weak_handle(); - let inline_assist_context_store = - cx.new(|_cx| ContextStore::new(project.downgrade(), Some(thread_store.downgrade()))); + let inline_assist_context_store = cx.new(|_cx| ContextStore::new(project.downgrade())); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let history_store = cx.new(|cx| agent2::HistoryStore::new(context_store.clone(), cx)); + let history_store = cx.new(|cx| agent::HistoryStore::new(text_thread_store.clone(), cx)); let acp_history = cx.new(|cx| AcpThreadHistory::new(history_store.clone(), window, cx)); cx.subscribe_in( &acp_history, @@ -570,7 +544,7 @@ impl AgentPanel { ); } ThreadHistoryEvent::Open(HistoryEntry::TextThread(thread)) => { - this.open_saved_prompt_editor(thread.path.clone(), window, cx) + this.open_saved_text_thread(thread.path.clone(), window, cx) .detach_and_log_err(cx); } }, @@ -589,8 +563,7 @@ impl AgentPanel { cx, ), DefaultView::TextThread => { - let context = - context_store.update(cx, |context_store, cx| context_store.create(cx)); + let context = text_thread_store.update(cx, |store, cx| store.create(cx)); let lsp_adapter_delegate = make_lsp_adapter_delegate(&project.clone(), cx).unwrap(); let context_editor = cx.new(|cx| { let mut editor = TextThreadEditor::for_context( @@ -605,7 +578,7 @@ impl AgentPanel { editor.insert_default_prompt(window, cx); editor }); - ActiveView::prompt_editor( + ActiveView::text_thread( context_editor, history_store.clone(), language_registry.clone(), @@ -619,7 +592,7 @@ impl AgentPanel { window.defer(cx, move |window, cx| { let panel = weak_panel.clone(); - let assistant_navigation_menu = + let agent_navigation_menu = ContextMenu::build_persistent(window, cx, move |mut menu, _window, cx| { if let Some(panel) = panel.upgrade() { menu = Self::populate_recently_opened_menu_section(menu, panel, cx); @@ -633,7 +606,7 @@ impl AgentPanel { weak_panel .update(cx, |panel, cx| { cx.subscribe_in( - &assistant_navigation_menu, + &agent_navigation_menu, window, |_, menu, _: &DismissEvent, window, cx| { menu.update(cx, |menu, _| { @@ -643,7 +616,7 @@ impl AgentPanel { }, ) .detach(); - panel.assistant_navigation_menu = Some(assistant_navigation_menu); + panel.agent_navigation_menu = Some(agent_navigation_menu); }) .ok(); }); @@ -666,17 +639,17 @@ impl AgentPanel { project: project.clone(), fs: fs.clone(), language_registry, - thread_store: thread_store.clone(), - context_store, + text_thread_store, prompt_store, configuration: None, configuration_subscription: None, + context_server_registry, inline_assist_context_store, previous_view: None, new_thread_menu_handle: PopoverMenuHandle::default(), agent_panel_menu_handle: PopoverMenuHandle::default(), - assistant_navigation_menu_handle: PopoverMenuHandle::default(), - assistant_navigation_menu: None, + agent_navigation_menu_handle: PopoverMenuHandle::default(), + agent_navigation_menu: None, width: None, height: None, zoomed: false, @@ -711,12 +684,12 @@ impl AgentPanel { &self.inline_assist_context_store } - pub(crate) fn thread_store(&self) -> &Entity { - &self.thread_store + pub(crate) fn thread_store(&self) -> &Entity { + &self.history_store } - pub(crate) fn text_thread_store(&self) -> &Entity { - &self.context_store + pub(crate) fn context_server_registry(&self) -> &Entity { + &self.context_server_registry } fn active_thread_view(&self) -> Option<&Entity> { @@ -753,11 +726,11 @@ impl AgentPanel { ); } - fn new_prompt_editor(&mut self, window: &mut Window, cx: &mut Context) { + fn new_text_thread(&mut self, window: &mut Window, cx: &mut Context) { telemetry::event!("Agent Thread Started", agent = "zed-text"); let context = self - .context_store + .text_thread_store .update(cx, |context_store, cx| context_store.create(cx)); let lsp_adapter_delegate = make_lsp_adapter_delegate(&self.project, cx) .log_err() @@ -783,7 +756,7 @@ impl AgentPanel { } self.set_active_view( - ActiveView::prompt_editor( + ActiveView::text_thread( context_editor.clone(), self.history_store.clone(), self.language_registry.clone(), @@ -921,32 +894,29 @@ impl AgentPanel { self.set_active_view(previous_view, window, cx); } } else { - self.thread_store - .update(cx, |thread_store, cx| thread_store.reload(cx)) - .detach_and_log_err(cx); self.set_active_view(ActiveView::History, window, cx); } cx.notify(); } - pub(crate) fn open_saved_prompt_editor( + pub(crate) fn open_saved_text_thread( &mut self, path: Arc, window: &mut Window, cx: &mut Context, ) -> Task> { let context = self - .context_store - .update(cx, |store, cx| store.open_local_context(path, cx)); + .history_store + .update(cx, |store, cx| store.load_text_thread(path, cx)); cx.spawn_in(window, async move |this, cx| { let context = context.await?; this.update_in(cx, |this, window, cx| { - this.open_prompt_editor(context, window, cx); + this.open_text_thread(context, window, cx); }) }) } - pub(crate) fn open_prompt_editor( + pub(crate) fn open_text_thread( &mut self, context: Entity, window: &mut Window, @@ -973,7 +943,7 @@ impl AgentPanel { } self.set_active_view( - ActiveView::prompt_editor( + ActiveView::text_thread( editor, self.history_store.clone(), self.language_registry.clone(), @@ -1013,7 +983,7 @@ impl AgentPanel { window: &mut Window, cx: &mut Context, ) { - self.assistant_navigation_menu_handle.toggle(window, cx); + self.agent_navigation_menu_handle.toggle(window, cx); } pub fn toggle_options_menu( @@ -1106,7 +1076,6 @@ impl AgentPanel { pub(crate) fn open_configuration(&mut self, window: &mut Window, cx: &mut Context) { let agent_server_store = self.project.read(cx).agent_server_store().clone(); let context_server_store = self.project.read(cx).context_server_store(); - let tools = self.thread_store.read(cx).tools(); let fs = self.fs.clone(); self.set_active_view(ActiveView::Configuration, window, cx); @@ -1115,7 +1084,7 @@ impl AgentPanel { fs, agent_server_store, context_server_store, - tools, + self.context_server_registry.clone(), self.language_registry.clone(), self.workspace.clone(), window, @@ -1183,7 +1152,7 @@ impl AgentPanel { }); } - self.new_thread(&NewThread::default(), window, cx); + self.new_thread(&NewThread, window, cx); if let Some((thread, model)) = self .active_native_agent_thread(cx) .zip(provider.default_model(cx)) @@ -1205,7 +1174,7 @@ impl AgentPanel { } } - pub(crate) fn active_native_agent_thread(&self, cx: &App) -> Option> { + pub(crate) fn active_native_agent_thread(&self, cx: &App) -> Option> { match &self.active_view { ActiveView::ExternalAgentThread { thread_view, .. } => { thread_view.read(cx).as_native_thread(cx) @@ -1241,7 +1210,7 @@ impl AgentPanel { self.history_store.update(cx, |store, cx| { if let Some(path) = context_editor.read(cx).context().read(cx).path() { store.push_recently_opened_entry( - agent2::HistoryEntryId::TextThread(path.clone()), + agent::HistoryEntryId::TextThread(path.clone()), cx, ) } @@ -1295,15 +1264,15 @@ impl AgentPanel { let entry = entry.clone(); panel .update(cx, move |this, cx| match &entry { - agent2::HistoryEntry::AcpThread(entry) => this.external_thread( + agent::HistoryEntry::AcpThread(entry) => this.external_thread( Some(ExternalAgent::NativeAgent), Some(entry.clone()), None, window, cx, ), - agent2::HistoryEntry::TextThread(entry) => this - .open_saved_prompt_editor(entry.path.clone(), window, cx) + agent::HistoryEntry::TextThread(entry) => this + .open_saved_text_thread(entry.path.clone(), window, cx) .detach_and_log_err(cx), }) .ok(); @@ -1730,9 +1699,9 @@ impl AgentPanel { }, ) .anchor(corner) - .with_handle(self.assistant_navigation_menu_handle.clone()) + .with_handle(self.agent_navigation_menu_handle.clone()) .menu({ - let menu = self.assistant_navigation_menu.clone(); + let menu = self.agent_navigation_menu.clone(); move |window, cx| { telemetry::event!("View Thread History Clicked"); @@ -1832,7 +1801,7 @@ impl AgentPanel { }) .item( ContextMenuEntry::new("New Thread") - .action(NewThread::default().boxed_clone()) + .action(NewThread.boxed_clone()) .icon(IconName::Thread) .icon_color(Color::Muted) .handler({ @@ -2278,7 +2247,7 @@ impl AgentPanel { } } - fn render_prompt_editor( + fn render_text_thread( &self, context_editor: &Entity, buffer_search_bar: &Entity, @@ -2409,8 +2378,8 @@ impl AgentPanel { let mut key_context = KeyContext::new_with_defaults(); key_context.add("AgentPanel"); match &self.active_view { - ActiveView::ExternalAgentThread { .. } => key_context.add("external_agent_thread"), - ActiveView::TextThread { .. } => key_context.add("prompt_editor"), + ActiveView::ExternalAgentThread { .. } => key_context.add("acp_thread"), + ActiveView::TextThread { .. } => key_context.add("text_thread"), ActiveView::History | ActiveView::Configuration => {} } key_context @@ -2487,7 +2456,7 @@ impl Render for AgentPanel { this } }) - .child(self.render_prompt_editor( + .child(self.render_text_thread( context_editor, buffer_search_bar, window, @@ -2538,8 +2507,7 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist { }; let prompt_store = None; let thread_store = None; - let text_thread_store = None; - let context_store = cx.new(|_| ContextStore::new(project.clone(), None)); + let context_store = cx.new(|_| ContextStore::new(project.clone())); assistant.assist( prompt_editor, self.workspace.clone(), @@ -2547,7 +2515,6 @@ impl rules_library::InlineAssistDelegate for PromptLibraryInlineAssist { project, prompt_store, thread_store, - text_thread_store, initial_prompt, window, cx, @@ -2590,7 +2557,7 @@ impl AgentPanelDelegate for ConcreteAssistantPanelDelegate { }; panel.update(cx, |panel, cx| { - panel.open_saved_prompt_editor(path, window, cx) + panel.open_saved_text_thread(path, window, cx) }) } diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 26d37378776b52be5fb88f3dad820986fb812d07..7c31500c937a6513c932c66560cf8754cbafbf1c 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -4,8 +4,10 @@ mod agent_diff; mod agent_model_selector; mod agent_panel; mod buffer_codegen; +mod context; mod context_picker; mod context_server_configuration; +mod context_store; mod context_strip; mod inline_assistant; mod inline_prompt_editor; @@ -22,7 +24,6 @@ mod ui; use std::rc::Rc; use std::sync::Arc; -use agent::ThreadId; use agent_settings::{AgentProfileId, AgentSettings}; use assistant_slash_command::SlashCommandRegistry; use client::Client; @@ -139,10 +140,7 @@ pub struct QuoteSelection; #[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)] #[action(namespace = agent)] #[serde(deny_unknown_fields)] -pub struct NewThread { - #[serde(default)] - from_thread_id: Option, -} +pub struct NewThread; /// Creates a new external agent conversation thread. #[derive(Default, Clone, PartialEq, Deserialize, JsonSchema, Action)] @@ -196,13 +194,13 @@ impl ExternalAgent { pub fn server( &self, fs: Arc, - history: Entity, + history: Entity, ) -> Rc { match self { Self::Gemini => Rc::new(agent_servers::Gemini), Self::ClaudeCode => Rc::new(agent_servers::ClaudeCode), Self::Codex => Rc::new(agent_servers::Codex), - Self::NativeAgent => Rc::new(agent2::NativeAgentServer::new(fs, history)), + Self::NativeAgent => Rc::new(agent::NativeAgentServer::new(fs, history)), Self::Custom { name, command: _ } => { Rc::new(agent_servers::CustomAgentServer::new(name.clone())) } @@ -266,7 +264,6 @@ pub fn init( init_language_model_settings(cx); } assistant_slash_command::init(cx); - agent::init(fs.clone(), cx); agent_panel::init(cx); context_server_configuration::init(language_registry.clone(), fs.clone(), cx); TextThreadEditor::init(cx); diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 2309aad754aee55af5ad040c39d22304486446a4..215e2a74d7be9cbcb18442dcefa1581d08eec7b2 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -1,7 +1,5 @@ -use crate::inline_prompt_editor::CodegenStatus; -use agent::{ - ContextStore, - context::{ContextLoadResult, load_context}, +use crate::{ + context::load_context, context_store::ContextStore, inline_prompt_editor::CodegenStatus, }; use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; @@ -434,16 +432,16 @@ impl CodegenAlternative { .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range) .context("generating content prompt")?; - let context_task = self.context_store.as_ref().map(|context_store| { + let context_task = self.context_store.as_ref().and_then(|context_store| { if let Some(project) = self.project.upgrade() { let context = context_store .read(cx) .context() .cloned() .collect::>(); - load_context(context, &project, &self.prompt_store, cx) + Some(load_context(context, &project, &self.prompt_store, cx)) } else { - Task::ready(ContextLoadResult::default()) + None } }); @@ -459,7 +457,6 @@ impl CodegenAlternative { if let Some(context_task) = context_task { context_task .await - .loaded_context .add_to_request_message(&mut request_message); } diff --git a/crates/agent/src/context.rs b/crates/agent_ui/src/context.rs similarity index 90% rename from crates/agent/src/context.rs rename to crates/agent_ui/src/context.rs index 3b2922087a94c497c07f1df67a8d4d9adf759909..3d0600605153fd8343205f3889953c100bde7a7a 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent_ui/src/context.rs @@ -1,11 +1,8 @@ -use crate::thread::Thread; +use agent::outline; use assistant_context::AssistantContext; -use assistant_tool::outline; -use collections::HashSet; use futures::future; use futures::{FutureExt, future::Shared}; use gpui::{App, AppContext as _, ElementId, Entity, SharedString, Task}; -use icons::IconName; use language::Buffer; use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent}; use project::{Project, ProjectEntryId, ProjectPath, Worktree}; @@ -17,6 +14,7 @@ use std::hash::{Hash, Hasher}; use std::path::PathBuf; use std::{ops::Range, path::Path, sync::Arc}; use text::{Anchor, OffsetRangeExt as _}; +use ui::IconName; use util::markdown::MarkdownCodeBlock; use util::rel_path::RelPath; use util::{ResultExt as _, post_inc}; @@ -181,7 +179,7 @@ impl FileContextHandle { }) } - fn load(self, cx: &App) -> Task>)>> { + fn load(self, cx: &App) -> Task> { let buffer_ref = self.buffer.read(cx); let Some(file) = buffer_ref.file() else { log::error!("file context missing path"); @@ -206,7 +204,7 @@ impl FileContextHandle { text: buffer_content.text.into(), is_outline: buffer_content.is_outline, }); - Some((context, vec![buffer])) + Some(context) }) } } @@ -256,11 +254,7 @@ impl DirectoryContextHandle { self.entry_id.hash(state) } - fn load( - self, - project: Entity, - cx: &mut App, - ) -> Task>)>> { + fn load(self, project: Entity, cx: &mut App) -> Task> { let Some(worktree) = project.read(cx).worktree_for_entry(self.entry_id, cx) else { return Task::ready(None); }; @@ -307,7 +301,7 @@ impl DirectoryContextHandle { }); cx.background_spawn(async move { - let (rope, buffer) = rope_task.await?; + let (rope, _buffer) = rope_task.await?; let fenced_codeblock = MarkdownCodeBlock { tag: &codeblock_tag(&full_path, None), text: &rope.to_string(), @@ -318,18 +312,22 @@ impl DirectoryContextHandle { rel_path, fenced_codeblock, }; - Some((descendant, buffer)) + Some(descendant) }) })); cx.background_spawn(async move { - let (descendants, buffers) = descendants_future.await.into_iter().flatten().unzip(); + let descendants = descendants_future + .await + .into_iter() + .flatten() + .collect::>(); let context = AgentContext::Directory(DirectoryContext { handle: self, full_path: directory_full_path, descendants, }); - Some((context, buffers)) + Some(context) }) } } @@ -397,7 +395,7 @@ impl SymbolContextHandle { .into() } - fn load(self, cx: &App) -> Task>)>> { + fn load(self, cx: &App) -> Task> { let buffer_ref = self.buffer.read(cx); let Some(file) = buffer_ref.file() else { log::error!("symbol context's file has no path"); @@ -406,14 +404,13 @@ impl SymbolContextHandle { let full_path = file.full_path(cx).to_string_lossy().into_owned(); let line_range = self.enclosing_range.to_point(&buffer_ref.snapshot()); let text = self.text(cx); - let buffer = self.buffer.clone(); let context = AgentContext::Symbol(SymbolContext { handle: self, full_path, line_range, text, }); - Task::ready(Some((context, vec![buffer]))) + Task::ready(Some(context)) } } @@ -468,13 +465,12 @@ impl SelectionContextHandle { .into() } - fn load(self, cx: &App) -> Task>)>> { + fn load(self, cx: &App) -> Task> { let Some(full_path) = self.full_path(cx) else { log::error!("selection context's file has no path"); return Task::ready(None); }; let text = self.text(cx); - let buffer = self.buffer.clone(); let context = AgentContext::Selection(SelectionContext { full_path: full_path.to_string_lossy().into_owned(), line_range: self.line_range(cx), @@ -482,7 +478,7 @@ impl SelectionContextHandle { handle: self, }); - Task::ready(Some((context, vec![buffer]))) + Task::ready(Some(context)) } } @@ -523,8 +519,8 @@ impl FetchedUrlContext { })) } - pub fn load(self) -> Task>)>> { - Task::ready(Some((AgentContext::FetchedUrl(self), vec![]))) + pub fn load(self) -> Task> { + Task::ready(Some(AgentContext::FetchedUrl(self))) } } @@ -537,7 +533,7 @@ impl Display for FetchedUrlContext { #[derive(Debug, Clone)] pub struct ThreadContextHandle { - pub thread: Entity, + pub thread: Entity, pub context_id: ContextId, } @@ -558,22 +554,20 @@ impl ThreadContextHandle { } pub fn title(&self, cx: &App) -> SharedString { - self.thread.read(cx).summary().or_default() + self.thread.read(cx).title() } - fn load(self, cx: &App) -> Task>)>> { - cx.spawn(async move |cx| { - let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?; - let title = self - .thread - .read_with(cx, |thread, _cx| thread.summary().or_default()) - .ok()?; + fn load(self, cx: &mut App) -> Task> { + let task = self.thread.update(cx, |thread, cx| thread.summary(cx)); + let title = self.title(cx); + cx.background_spawn(async move { + let text = task.await?; let context = AgentContext::Thread(ThreadContext { title, text, handle: self, }); - Some((context, vec![])) + Some(context) }) } } @@ -612,7 +606,7 @@ impl TextThreadContextHandle { self.context.read(cx).summary().or_default() } - fn load(self, cx: &App) -> Task>)>> { + fn load(self, cx: &App) -> Task> { let title = self.title(cx); let text = self.context.read(cx).to_xml(cx); let context = AgentContext::TextThread(TextThreadContext { @@ -620,7 +614,7 @@ impl TextThreadContextHandle { text: text.into(), handle: self, }); - Task::ready(Some((context, vec![]))) + Task::ready(Some(context)) } } @@ -666,7 +660,7 @@ impl RulesContextHandle { self, prompt_store: &Option>, cx: &App, - ) -> Task>)>> { + ) -> Task> { let Some(prompt_store) = prompt_store.as_ref() else { return Task::ready(None); }; @@ -685,7 +679,7 @@ impl RulesContextHandle { title, text, }); - Some((context, vec![])) + Some(context) }) } } @@ -748,32 +742,21 @@ impl ImageContext { } } - pub fn load(self, cx: &App) -> Task>)>> { + pub fn load(self, cx: &App) -> Task> { cx.background_spawn(async move { self.image_task.clone().await; - Some((AgentContext::Image(self), vec![])) + Some(AgentContext::Image(self)) }) } } -#[derive(Debug, Clone, Default)] -pub struct ContextLoadResult { - pub loaded_context: LoadedContext, - pub referenced_buffers: HashSet>, -} - #[derive(Debug, Clone, Default)] pub struct LoadedContext { - pub contexts: Vec, pub text: String, pub images: Vec, } impl LoadedContext { - pub fn is_empty(&self) -> bool { - self.text.is_empty() && self.images.is_empty() - } - pub fn add_to_request_message(&self, request_message: &mut LanguageModelRequestMessage) { if !self.text.is_empty() { request_message @@ -804,7 +787,7 @@ pub fn load_context( project: &Entity, prompt_store: &Option>, cx: &mut App, -) -> Task { +) -> Task { let load_tasks: Vec<_> = contexts .into_iter() .map(|context| match context { @@ -823,16 +806,7 @@ pub fn load_context( cx.background_spawn(async move { let load_results = future::join_all(load_tasks).await; - let mut contexts = Vec::new(); let mut text = String::new(); - let mut referenced_buffers = HashSet::default(); - for context in load_results { - let Some((context, buffers)) = context else { - continue; - }; - contexts.push(context); - referenced_buffers.extend(buffers); - } let mut file_context = Vec::new(); let mut directory_context = Vec::new(); @@ -843,7 +817,7 @@ pub fn load_context( let mut text_thread_context = Vec::new(); let mut rules_context = Vec::new(); let mut images = Vec::new(); - for context in &contexts { + for context in load_results.into_iter().flatten() { match context { AgentContext::File(context) => file_context.push(context), AgentContext::Directory(context) => directory_context.push(context), @@ -868,14 +842,7 @@ pub fn load_context( && text_thread_context.is_empty() && rules_context.is_empty() { - return ContextLoadResult { - loaded_context: LoadedContext { - contexts, - text, - images, - }, - referenced_buffers, - }; + return LoadedContext { text, images }; } text.push_str( @@ -961,14 +928,7 @@ pub fn load_context( text.push_str("\n"); - ContextLoadResult { - loaded_context: LoadedContext { - contexts, - text, - images, - }, - referenced_buffers, - } + LoadedContext { text, images } }) } @@ -1131,11 +1091,13 @@ mod tests { assert!(content_len > outline::AUTO_OUTLINE_SIZE); - let file_context = file_context_for(large_content, cx).await; + let file_context = load_context_for("file.txt", large_content, cx).await; assert!( - file_context.is_outline, - "Large file should use outline format" + file_context + .text + .contains(&format!("# File outline for {}", path!("test/file.txt"))), + "Large files should not get an outline" ); assert!( @@ -1153,29 +1115,38 @@ mod tests { assert!(content_len < outline::AUTO_OUTLINE_SIZE); - let file_context = file_context_for(small_content.to_string(), cx).await; + let file_context = load_context_for("file.txt", small_content.to_string(), cx).await; assert!( - !file_context.is_outline, + !file_context + .text + .contains(&format!("# File outline for {}", path!("test/file.txt"))), "Small files should not get an outline" ); - assert_eq!(file_context.text, small_content); + assert!( + file_context.text.contains(small_content), + "Small files should use full content" + ); } - async fn file_context_for(content: String, cx: &mut TestAppContext) -> FileContext { + async fn load_context_for( + filename: &str, + content: String, + cx: &mut TestAppContext, + ) -> LoadedContext { // Create a test project with the file let project = create_test_project( cx, json!({ - "file.txt": content, + filename: content, }), ) .await; // Open the buffer let buffer_path = project - .read_with(cx, |project, cx| project.find_project_path("file.txt", cx)) + .read_with(cx, |project, cx| project.find_project_path(filename, cx)) .unwrap(); let buffer = project @@ -1190,16 +1161,5 @@ mod tests { cx.update(|cx| load_context(vec![context_handle], &project, &None, cx)) .await - .loaded_context - .contexts - .into_iter() - .find_map(|ctx| { - if let AgentContext::File(file_ctx) = ctx { - Some(file_ctx) - } else { - None - } - }) - .expect("Should have found a file context") } } diff --git a/crates/agent_ui/src/context_picker.rs b/crates/agent_ui/src/context_picker.rs index 58edecdf3da6b16bca82a7d4c0e73dcac3969e03..cfb2ce0a60441c18d62965dddf6a626c4b4a4243 100644 --- a/crates/agent_ui/src/context_picker.rs +++ b/crates/agent_ui/src/context_picker.rs @@ -9,6 +9,8 @@ use std::ops::Range; use std::path::PathBuf; use std::sync::Arc; +use agent::{HistoryEntry, HistoryEntryId, HistoryStore}; +use agent_client_protocol as acp; use anyhow::{Result, anyhow}; use collections::HashSet; pub use completion_provider::ContextPickerCompletionProvider; @@ -27,9 +29,7 @@ use project::ProjectPath; use prompt_store::PromptStore; use rules_context_picker::{RulesContextEntry, RulesContextPicker}; use symbol_context_picker::SymbolContextPicker; -use thread_context_picker::{ - ThreadContextEntry, ThreadContextPicker, render_thread_context_entry, unordered_thread_entries, -}; +use thread_context_picker::render_thread_context_entry; use ui::{ ButtonLike, ContextMenu, ContextMenuEntry, ContextMenuItem, Disclosure, TintColor, prelude::*, }; @@ -37,12 +37,8 @@ use util::paths::PathStyle; use util::rel_path::RelPath; use workspace::{Workspace, notifications::NotifyResultExt}; -use agent::{ - ThreadId, - context::RULES_ICON, - context_store::ContextStore, - thread_store::{TextThreadStore, ThreadStore}, -}; +use crate::context_picker::thread_context_picker::ThreadContextPicker; +use crate::{context::RULES_ICON, context_store::ContextStore}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum ContextPickerEntry { @@ -168,17 +164,16 @@ pub(super) struct ContextPicker { mode: ContextPickerState, workspace: WeakEntity, context_store: WeakEntity, - thread_store: Option>, - text_thread_store: Option>, - prompt_store: Option>, + thread_store: Option>, + prompt_store: Option>, _subscriptions: Vec, } impl ContextPicker { pub fn new( workspace: WeakEntity, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, + prompt_store: Option>, context_store: WeakEntity, window: &mut Window, cx: &mut Context, @@ -199,13 +194,6 @@ impl ContextPicker { ) .collect::>(); - let prompt_store = thread_store.as_ref().and_then(|thread_store| { - thread_store - .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone()) - .ok() - .flatten() - }); - ContextPicker { mode: ContextPickerState::Default(ContextMenu::build( window, @@ -215,7 +203,6 @@ impl ContextPicker { workspace, context_store, thread_store, - text_thread_store, prompt_store, _subscriptions: subscriptions, } @@ -355,17 +342,13 @@ impl ContextPicker { })); } ContextPickerMode::Thread => { - if let Some((thread_store, text_thread_store)) = self - .thread_store - .as_ref() - .zip(self.text_thread_store.as_ref()) - { + if let Some(thread_store) = self.thread_store.clone() { self.mode = ContextPickerState::Thread(cx.new(|cx| { ThreadContextPicker::new( - thread_store.clone(), - text_thread_store.clone(), + thread_store, context_picker.clone(), self.context_store.clone(), + self.workspace.clone(), window, cx, ) @@ -480,16 +463,23 @@ impl ContextPicker { fn add_recent_thread( &self, - entry: ThreadContextEntry, - window: &mut Window, + entry: HistoryEntry, + _window: &mut Window, cx: &mut Context, ) -> Task> { let Some(context_store) = self.context_store.upgrade() else { return Task::ready(Err(anyhow!("context store not available"))); }; + let Some(project) = self + .workspace + .upgrade() + .map(|workspace| workspace.read(cx).project().clone()) + else { + return Task::ready(Err(anyhow!("project not available"))); + }; match entry { - ThreadContextEntry::Thread { id, .. } => { + HistoryEntry::AcpThread(thread) => { let Some(thread_store) = self .thread_store .as_ref() @@ -497,28 +487,28 @@ impl ContextPicker { else { return Task::ready(Err(anyhow!("thread store not available"))); }; - - let open_thread_task = - thread_store.update(cx, |this, cx| this.open_thread(&id, window, cx)); + let load_thread_task = + agent::load_agent_thread(thread.id, thread_store, project, cx); cx.spawn(async move |this, cx| { - let thread = open_thread_task.await?; + let thread = load_thread_task.await?; context_store.update(cx, |context_store, cx| { context_store.add_thread(thread, true, cx); })?; this.update(cx, |_this, cx| cx.notify()) }) } - ThreadContextEntry::Context { path, .. } => { - let Some(text_thread_store) = self - .text_thread_store + HistoryEntry::TextThread(thread) => { + let Some(thread_store) = self + .thread_store .as_ref() .and_then(|thread_store| thread_store.upgrade()) else { return Task::ready(Err(anyhow!("text thread store not available"))); }; - let task = text_thread_store - .update(cx, |this, cx| this.open_local_context(path.clone(), cx)); + let task = thread_store.update(cx, |this, cx| { + this.load_text_thread(thread.path.clone(), cx) + }); cx.spawn(async move |this, cx| { let thread = task.await?; context_store.update(cx, |context_store, cx| { @@ -542,7 +532,6 @@ impl ContextPicker { recent_context_picker_entries_with_store( context_store, self.thread_store.clone(), - self.text_thread_store.clone(), workspace, None, cx, @@ -599,12 +588,12 @@ pub(crate) enum RecentEntry { project_path: ProjectPath, path_prefix: Arc, }, - Thread(ThreadContextEntry), + Thread(HistoryEntry), } pub(crate) fn available_context_picker_entries( - prompt_store: &Option>, - thread_store: &Option>, + prompt_store: &Option>, + thread_store: &Option>, workspace: &Entity, cx: &mut App, ) -> Vec { @@ -639,8 +628,7 @@ pub(crate) fn available_context_picker_entries( fn recent_context_picker_entries_with_store( context_store: Entity, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, workspace: Entity, exclude_path: Option, cx: &App, @@ -657,22 +645,14 @@ fn recent_context_picker_entries_with_store( let exclude_threads = context_store.read(cx).thread_ids(); - recent_context_picker_entries( - thread_store, - text_thread_store, - workspace, - &exclude_paths, - exclude_threads, - cx, - ) + recent_context_picker_entries(thread_store, workspace, &exclude_paths, exclude_threads, cx) } pub(crate) fn recent_context_picker_entries( - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, workspace: Entity, exclude_paths: &HashSet, - _exclude_threads: &HashSet, + exclude_threads: &HashSet, cx: &App, ) -> Vec { let mut recent = Vec::with_capacity(6); @@ -698,30 +678,21 @@ pub(crate) fn recent_context_picker_entries( }), ); - if let Some((thread_store, text_thread_store)) = thread_store - .and_then(|store| store.upgrade()) - .zip(text_thread_store.and_then(|store| store.upgrade())) - { - let mut threads = unordered_thread_entries(thread_store, text_thread_store, cx) - .filter(|(_, thread)| match thread { - ThreadContextEntry::Thread { .. } => false, - ThreadContextEntry::Context { .. } => true, - }) - .collect::>(); - - const RECENT_COUNT: usize = 2; - if threads.len() > RECENT_COUNT { - threads.select_nth_unstable_by_key(RECENT_COUNT - 1, |(updated_at, _)| { - std::cmp::Reverse(*updated_at) - }); - threads.truncate(RECENT_COUNT); - } - threads.sort_unstable_by_key(|(updated_at, _)| std::cmp::Reverse(*updated_at)); - + if let Some(thread_store) = thread_store.and_then(|store| store.upgrade()) { + const RECENT_THREADS_COUNT: usize = 2; recent.extend( - threads - .into_iter() - .map(|(_, thread)| RecentEntry::Thread(thread)), + thread_store + .read(cx) + .recently_opened_entries(cx) + .iter() + .filter(|e| match e.id() { + HistoryEntryId::AcpThread(session_id) => !exclude_threads.contains(&session_id), + HistoryEntryId::TextThread(path) => { + !exclude_paths.contains(&path.to_path_buf()) + } + }) + .take(RECENT_THREADS_COUNT) + .map(|thread| RecentEntry::Thread(thread.clone())), ); } @@ -915,17 +886,21 @@ impl MentionLink { ) } - pub fn for_thread(thread: &ThreadContextEntry) -> String { + pub fn for_thread(thread: &HistoryEntry) -> String { match thread { - ThreadContextEntry::Thread { id, title } => { - format!("[@{}]({}:{})", title, Self::THREAD, id) + HistoryEntry::AcpThread(thread) => { + format!("[@{}]({}:{})", thread.title, Self::THREAD, thread.id) } - ThreadContextEntry::Context { path, title } => { - let filename = path.file_name().unwrap_or_default().to_string_lossy(); + HistoryEntry::TextThread(thread) => { + let filename = thread + .path + .file_name() + .unwrap_or_default() + .to_string_lossy(); let escaped_filename = urlencoding::encode(&filename); format!( "[@{}]({}:{}{})", - title, + thread.title, Self::THREAD, Self::TEXT_THREAD_URL_PREFIX, escaped_filename diff --git a/crates/agent_ui/src/context_picker/completion_provider.rs b/crates/agent_ui/src/context_picker/completion_provider.rs index e030779eb8c37347410507a74d27299dbcdfbf7d..56444141f12903db4868f9e154cccdb872b48514 100644 --- a/crates/agent_ui/src/context_picker/completion_provider.rs +++ b/crates/agent_ui/src/context_picker/completion_provider.rs @@ -3,7 +3,7 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::atomic::AtomicBool; -use agent::context_store::ContextStore; +use agent::{HistoryEntry, HistoryStore}; use anyhow::Result; use editor::{CompletionProvider, Editor, ExcerptId, ToOffset as _}; use file_icons::FileIcons; @@ -15,8 +15,8 @@ use language::{Buffer, CodeLabel, CodeLabelBuilder, HighlightId}; use lsp::CompletionContext; use project::lsp_store::SymbolLocation; use project::{ - Completion, CompletionDisplayOptions, CompletionIntent, CompletionResponse, ProjectPath, - Symbol, WorktreeId, + Completion, CompletionDisplayOptions, CompletionIntent, CompletionResponse, Project, + ProjectPath, Symbol, WorktreeId, }; use prompt_store::PromptStore; use rope::Point; @@ -27,10 +27,9 @@ use util::paths::PathStyle; use util::rel_path::RelPath; use workspace::Workspace; -use agent::{ - Thread, +use crate::{ context::{AgentContextHandle, AgentContextKey, RULES_ICON}, - thread_store::{TextThreadStore, ThreadStore}, + context_store::ContextStore, }; use super::fetch_context_picker::fetch_url_content; @@ -38,7 +37,7 @@ use super::file_context_picker::{FileMatch, search_files}; use super::rules_context_picker::{RulesContextEntry, search_rules}; use super::symbol_context_picker::SymbolMatch; use super::symbol_context_picker::search_symbols; -use super::thread_context_picker::{ThreadContextEntry, ThreadMatch, search_threads}; +use super::thread_context_picker::search_threads; use super::{ ContextPickerAction, ContextPickerEntry, ContextPickerMode, MentionLink, RecentEntry, available_context_picker_entries, recent_context_picker_entries_with_store, selection_ranges, @@ -48,7 +47,8 @@ use crate::message_editor::ContextCreasesAddon; pub(crate) enum Match { File(FileMatch), Symbol(SymbolMatch), - Thread(ThreadMatch), + Thread(HistoryEntry), + RecentThread(HistoryEntry), Fetch(SharedString), Rules(RulesContextEntry), Entry(EntryMatch), @@ -65,6 +65,7 @@ impl Match { Match::File(file) => file.mat.score, Match::Entry(mode) => mode.mat.as_ref().map(|mat| mat.score).unwrap_or(1.), Match::Thread(_) => 1., + Match::RecentThread(_) => 1., Match::Symbol(_) => 1., Match::Fetch(_) => 1., Match::Rules(_) => 1., @@ -77,9 +78,8 @@ fn search( query: String, cancellation_flag: Arc, recent_entries: Vec, - prompt_store: Option>, - thread_store: Option>, - text_thread_context_store: Option>, + prompt_store: Option>, + thread_store: Option>, workspace: Entity, cx: &mut App, ) -> Task> { @@ -107,13 +107,9 @@ fn search( } Some(ContextPickerMode::Thread) => { - if let Some((thread_store, context_store)) = thread_store - .as_ref() - .and_then(|t| t.upgrade()) - .zip(text_thread_context_store.as_ref().and_then(|t| t.upgrade())) - { + if let Some(thread_store) = thread_store.as_ref().and_then(|t| t.upgrade()) { let search_threads_task = - search_threads(query, cancellation_flag, thread_store, context_store, cx); + search_threads(query, cancellation_flag, &thread_store, cx); cx.background_spawn(async move { search_threads_task .await @@ -135,8 +131,8 @@ fn search( } Some(ContextPickerMode::Rules) => { - if let Some(prompt_store) = prompt_store.as_ref() { - let search_rules_task = search_rules(query, cancellation_flag, prompt_store, cx); + if let Some(prompt_store) = prompt_store.as_ref().and_then(|p| p.upgrade()) { + let search_rules_task = search_rules(query, cancellation_flag, &prompt_store, cx); cx.background_spawn(async move { search_rules_task .await @@ -169,12 +165,7 @@ fn search( }, is_recent: true, }), - super::RecentEntry::Thread(thread_context_entry) => { - Match::Thread(ThreadMatch { - thread: thread_context_entry, - is_recent: true, - }) - } + super::RecentEntry::Thread(entry) => Match::RecentThread(entry), }) .collect::>(); @@ -245,8 +236,8 @@ fn search( pub struct ContextPickerCompletionProvider { workspace: WeakEntity, context_store: WeakEntity, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, + prompt_store: Option>, editor: WeakEntity, excluded_buffer: Option>, } @@ -255,8 +246,8 @@ impl ContextPickerCompletionProvider { pub fn new( workspace: WeakEntity, context_store: WeakEntity, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, + prompt_store: Option>, editor: WeakEntity, exclude_buffer: Option>, ) -> Self { @@ -264,7 +255,7 @@ impl ContextPickerCompletionProvider { workspace, context_store, thread_store, - text_thread_store, + prompt_store, editor, excluded_buffer: exclude_buffer, } @@ -406,14 +397,14 @@ impl ContextPickerCompletionProvider { } fn completion_for_thread( - thread_entry: ThreadContextEntry, + thread_entry: HistoryEntry, excerpt_id: ExcerptId, source_range: Range, recent: bool, editor: Entity, context_store: Entity, - thread_store: Entity, - text_thread_store: Entity, + thread_store: Entity, + project: Entity, ) -> Completion { let icon_for_completion = if recent { IconName::HistoryRerun @@ -439,18 +430,16 @@ impl ContextPickerCompletionProvider { editor, context_store.clone(), move |window, cx| match &thread_entry { - ThreadContextEntry::Thread { id, .. } => { - let thread_id = id.clone(); + HistoryEntry::AcpThread(thread) => { let context_store = context_store.clone(); - let thread_store = thread_store.clone(); + let load_thread_task = agent::load_agent_thread( + thread.id.clone(), + thread_store.clone(), + project.clone(), + cx, + ); window.spawn::<_, Option<_>>(cx, async move |cx| { - let thread: Entity = thread_store - .update_in(cx, |thread_store, window, cx| { - thread_store.open_thread(&thread_id, window, cx) - }) - .ok()? - .await - .log_err()?; + let thread = load_thread_task.await.log_err()?; let context = context_store .update(cx, |context_store, cx| { context_store.add_thread(thread, false, cx) @@ -459,13 +448,13 @@ impl ContextPickerCompletionProvider { Some(context) }) } - ThreadContextEntry::Context { path, .. } => { - let path = path.clone(); + HistoryEntry::TextThread(thread) => { + let path = thread.path.clone(); let context_store = context_store.clone(); - let text_thread_store = text_thread_store.clone(); + let thread_store = thread_store.clone(); cx.spawn::<_, Option<_>>(async move |cx| { - let thread = text_thread_store - .update(cx, |store, cx| store.open_local_context(path, cx)) + let thread = thread_store + .update(cx, |store, cx| store.load_text_thread(path, cx)) .ok()? .await .log_err()?; @@ -774,7 +763,7 @@ impl CompletionProvider for ContextPickerCompletionProvider { ..snapshot.anchor_after(state.source_range.end); let thread_store = self.thread_store.clone(); - let text_thread_store = self.text_thread_store.clone(); + let prompt_store = self.prompt_store.clone(); let editor = self.editor.clone(); let http_client = workspace.read(cx).client().http_client(); let path_style = workspace.read(cx).path_style(cx); @@ -792,19 +781,11 @@ impl CompletionProvider for ContextPickerCompletionProvider { let recent_entries = recent_context_picker_entries_with_store( context_store.clone(), thread_store.clone(), - text_thread_store.clone(), workspace.clone(), excluded_path.clone(), cx, ); - let prompt_store = thread_store.as_ref().and_then(|thread_store| { - thread_store - .read_with(cx, |thread_store, _cx| thread_store.prompt_store().clone()) - .ok() - .flatten() - }); - let search_task = search( mode, query, @@ -812,14 +793,14 @@ impl CompletionProvider for ContextPickerCompletionProvider { recent_entries, prompt_store, thread_store.clone(), - text_thread_store.clone(), workspace.clone(), cx, ); + let project = workspace.read(cx).project().downgrade(); cx.spawn(async move |_, cx| { let matches = search_task.await; - let Some(editor) = editor.upgrade() else { + let Some((editor, project)) = editor.upgrade().zip(project.upgrade()) else { return Ok(Vec::new()); }; @@ -860,25 +841,32 @@ impl CompletionProvider for ContextPickerCompletionProvider { workspace.clone(), cx, ), - - Match::Thread(ThreadMatch { - thread, is_recent, .. - }) => { + Match::Thread(thread) => { let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?; - let text_thread_store = - text_thread_store.as_ref().and_then(|t| t.upgrade())?; Some(Self::completion_for_thread( thread, excerpt_id, source_range.clone(), - is_recent, + false, editor.clone(), context_store.clone(), thread_store, - text_thread_store, + project.clone(), + )) + } + Match::RecentThread(thread) => { + let thread_store = thread_store.as_ref().and_then(|t| t.upgrade())?; + Some(Self::completion_for_thread( + thread, + excerpt_id, + source_range.clone(), + true, + editor.clone(), + context_store.clone(), + thread_store, + project.clone(), )) } - Match::Rules(user_rules) => Some(Self::completion_for_rules( user_rules, excerpt_id, @@ -1281,7 +1269,7 @@ mod tests { editor }); - let context_store = cx.new(|_| ContextStore::new(project.downgrade(), None)); + let context_store = cx.new(|_| ContextStore::new(project.downgrade())); let editor_entity = editor.downgrade(); editor.update_in(&mut cx, |editor, window, cx| { diff --git a/crates/agent_ui/src/context_picker/fetch_context_picker.rs b/crates/agent_ui/src/context_picker/fetch_context_picker.rs index dd558b2a1c88f60e68313b208b076a0974b30f85..31fc45aca3ccbf561793769939169d214aaa2d99 100644 --- a/crates/agent_ui/src/context_picker/fetch_context_picker.rs +++ b/crates/agent_ui/src/context_picker/fetch_context_picker.rs @@ -2,7 +2,6 @@ use std::cell::RefCell; use std::rc::Rc; use std::sync::Arc; -use agent::context_store::ContextStore; use anyhow::{Context as _, Result, bail}; use futures::AsyncReadExt as _; use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity}; @@ -12,7 +11,7 @@ use picker::{Picker, PickerDelegate}; use ui::{Context, ListItem, Window, prelude::*}; use workspace::Workspace; -use crate::context_picker::ContextPicker; +use crate::{context_picker::ContextPicker, context_store::ContextStore}; pub struct FetchContextPicker { picker: Entity>, diff --git a/crates/agent_ui/src/context_picker/file_context_picker.rs b/crates/agent_ui/src/context_picker/file_context_picker.rs index 4f7a4308406f9d9fbdfa42cc86adc1ffe7593396..8d1e5cb46dfba7bc89770356334fb08a7bf7a0c5 100644 --- a/crates/agent_ui/src/context_picker/file_context_picker.rs +++ b/crates/agent_ui/src/context_picker/file_context_picker.rs @@ -12,8 +12,10 @@ use ui::{ListItem, Tooltip, prelude::*}; use util::{ResultExt as _, paths::PathStyle, rel_path::RelPath}; use workspace::Workspace; -use crate::context_picker::ContextPicker; -use agent::context_store::{ContextStore, FileInclusion}; +use crate::{ + context_picker::ContextPicker, + context_store::{ContextStore, FileInclusion}, +}; pub struct FileContextPicker { picker: Entity>, diff --git a/crates/agent_ui/src/context_picker/rules_context_picker.rs b/crates/agent_ui/src/context_picker/rules_context_picker.rs index 677011577aef23296a34203acdb10e5228ca7cd7..68f4917a4fd5689aab1a418dd78d2c8a322cd717 100644 --- a/crates/agent_ui/src/context_picker/rules_context_picker.rs +++ b/crates/agent_ui/src/context_picker/rules_context_picker.rs @@ -7,9 +7,11 @@ use prompt_store::{PromptId, PromptStore, UserPromptId}; use ui::{ListItem, prelude::*}; use util::ResultExt as _; -use crate::context_picker::ContextPicker; -use agent::context::RULES_ICON; -use agent::context_store::{self, ContextStore}; +use crate::{ + context::RULES_ICON, + context_picker::ContextPicker, + context_store::{self, ContextStore}, +}; pub struct RulesContextPicker { picker: Entity>, @@ -17,7 +19,7 @@ pub struct RulesContextPicker { impl RulesContextPicker { pub fn new( - prompt_store: Entity, + prompt_store: WeakEntity, context_picker: WeakEntity, context_store: WeakEntity, window: &mut Window, @@ -49,7 +51,7 @@ pub struct RulesContextEntry { } pub struct RulesContextPickerDelegate { - prompt_store: Entity, + prompt_store: WeakEntity, context_picker: WeakEntity, context_store: WeakEntity, matches: Vec, @@ -58,7 +60,7 @@ pub struct RulesContextPickerDelegate { impl RulesContextPickerDelegate { pub fn new( - prompt_store: Entity, + prompt_store: WeakEntity, context_picker: WeakEntity, context_store: WeakEntity, ) -> Self { @@ -102,12 +104,10 @@ impl PickerDelegate for RulesContextPickerDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { - let search_task = search_rules( - query, - Arc::new(AtomicBool::default()), - &self.prompt_store, - cx, - ); + let Some(prompt_store) = self.prompt_store.upgrade() else { + return Task::ready(()); + }; + let search_task = search_rules(query, Arc::new(AtomicBool::default()), &prompt_store, cx); cx.spawn_in(window, async move |this, cx| { let matches = search_task.await; this.update(cx, |this, cx| { diff --git a/crates/agent_ui/src/context_picker/symbol_context_picker.rs b/crates/agent_ui/src/context_picker/symbol_context_picker.rs index 5b89f09de884067a94832c7bf474a2949e78c420..fbce71d94efd84b1acc6e0b5d4ea11cb2b9243d5 100644 --- a/crates/agent_ui/src/context_picker/symbol_context_picker.rs +++ b/crates/agent_ui/src/context_picker/symbol_context_picker.rs @@ -15,9 +15,9 @@ use ui::{ListItem, prelude::*}; use util::ResultExt as _; use workspace::Workspace; -use crate::context_picker::ContextPicker; -use agent::context::AgentContextHandle; -use agent::context_store::ContextStore; +use crate::{ + context::AgentContextHandle, context_picker::ContextPicker, context_store::ContextStore, +}; pub struct SymbolContextPicker { picker: Entity>, diff --git a/crates/agent_ui/src/context_picker/thread_context_picker.rs b/crates/agent_ui/src/context_picker/thread_context_picker.rs index 9e843779c2216a89fe23dce514553e50043b8187..d6a3a270742fe28c483d2d7d39894eb9e3c021ea 100644 --- a/crates/agent_ui/src/context_picker/thread_context_picker.rs +++ b/crates/agent_ui/src/context_picker/thread_context_picker.rs @@ -1,19 +1,16 @@ -use std::path::Path; use std::sync::Arc; use std::sync::atomic::AtomicBool; -use chrono::{DateTime, Utc}; +use crate::{ + context_picker::ContextPicker, + context_store::{self, ContextStore}, +}; +use agent::{HistoryEntry, HistoryStore}; use fuzzy::StringMatchCandidate; use gpui::{App, DismissEvent, Entity, FocusHandle, Focusable, Task, WeakEntity}; use picker::{Picker, PickerDelegate}; use ui::{ListItem, prelude::*}; - -use crate::context_picker::ContextPicker; -use agent::{ - ThreadId, - context_store::{self, ContextStore}, - thread_store::{TextThreadStore, ThreadStore}, -}; +use workspace::Workspace; pub struct ThreadContextPicker { picker: Entity>, @@ -21,18 +18,18 @@ pub struct ThreadContextPicker { impl ThreadContextPicker { pub fn new( - thread_store: WeakEntity, - text_thread_context_store: WeakEntity, + thread_store: WeakEntity, context_picker: WeakEntity, context_store: WeakEntity, + workspace: WeakEntity, window: &mut Window, cx: &mut Context, ) -> Self { let delegate = ThreadContextPickerDelegate::new( thread_store, - text_thread_context_store, context_picker, context_store, + workspace, ); let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx)); @@ -52,48 +49,27 @@ impl Render for ThreadContextPicker { } } -#[derive(Debug, Clone)] -pub enum ThreadContextEntry { - Thread { - id: ThreadId, - title: SharedString, - }, - Context { - path: Arc, - title: SharedString, - }, -} - -impl ThreadContextEntry { - pub fn title(&self) -> &SharedString { - match self { - Self::Thread { title, .. } => title, - Self::Context { title, .. } => title, - } - } -} - pub struct ThreadContextPickerDelegate { - thread_store: WeakEntity, - text_thread_store: WeakEntity, + thread_store: WeakEntity, context_picker: WeakEntity, context_store: WeakEntity, - matches: Vec, + workspace: WeakEntity, + matches: Vec, selected_index: usize, } impl ThreadContextPickerDelegate { pub fn new( - thread_store: WeakEntity, - text_thread_store: WeakEntity, + thread_store: WeakEntity, context_picker: WeakEntity, context_store: WeakEntity, + workspace: WeakEntity, ) -> Self { ThreadContextPickerDelegate { thread_store, context_picker, context_store, - text_thread_store, + workspace, matches: Vec::new(), selected_index: 0, } @@ -130,25 +106,15 @@ impl PickerDelegate for ThreadContextPickerDelegate { window: &mut Window, cx: &mut Context>, ) -> Task<()> { - let Some((thread_store, text_thread_context_store)) = self - .thread_store - .upgrade() - .zip(self.text_thread_store.upgrade()) - else { + let Some(thread_store) = self.thread_store.upgrade() else { return Task::ready(()); }; - let search_task = search_threads( - query, - Arc::new(AtomicBool::default()), - thread_store, - text_thread_context_store, - cx, - ); + let search_task = search_threads(query, Arc::new(AtomicBool::default()), &thread_store, cx); cx.spawn_in(window, async move |this, cx| { let matches = search_task.await; this.update(cx, |this, cx| { - this.delegate.matches = matches.into_iter().map(|mat| mat.thread).collect(); + this.delegate.matches = matches; this.delegate.selected_index = 0; cx.notify(); }) @@ -156,21 +122,29 @@ impl PickerDelegate for ThreadContextPickerDelegate { }) } - fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context>) { - let Some(entry) = self.matches.get(self.selected_index) else { + fn confirm(&mut self, _secondary: bool, _window: &mut Window, cx: &mut Context>) { + let Some(project) = self + .workspace + .upgrade() + .map(|w| w.read(cx).project().clone()) + else { + return; + }; + let Some((entry, thread_store)) = self + .matches + .get(self.selected_index) + .zip(self.thread_store.upgrade()) + else { return; }; match entry { - ThreadContextEntry::Thread { id, .. } => { - let Some(thread_store) = self.thread_store.upgrade() else { - return; - }; - let open_thread_task = - thread_store.update(cx, |this, cx| this.open_thread(id, window, cx)); + HistoryEntry::AcpThread(thread) => { + let load_thread_task = + agent::load_agent_thread(thread.id.clone(), thread_store, project, cx); cx.spawn(async move |this, cx| { - let thread = open_thread_task.await?; + let thread = load_thread_task.await?; this.update(cx, |this, cx| { this.delegate .context_store @@ -182,12 +156,10 @@ impl PickerDelegate for ThreadContextPickerDelegate { }) .detach_and_log_err(cx); } - ThreadContextEntry::Context { path, .. } => { - let Some(text_thread_store) = self.text_thread_store.upgrade() else { - return; - }; - let task = text_thread_store - .update(cx, |this, cx| this.open_local_context(path.clone(), cx)); + HistoryEntry::TextThread(thread) => { + let task = thread_store.update(cx, |this, cx| { + this.load_text_thread(thread.path.clone(), cx) + }); cx.spawn(async move |this, cx| { let thread = task.await?; @@ -229,17 +201,17 @@ impl PickerDelegate for ThreadContextPickerDelegate { } pub fn render_thread_context_entry( - entry: &ThreadContextEntry, + entry: &HistoryEntry, context_store: WeakEntity, cx: &mut App, ) -> Div { let is_added = match entry { - ThreadContextEntry::Thread { id, .. } => context_store + HistoryEntry::AcpThread(thread) => context_store .upgrade() - .is_some_and(|ctx_store| ctx_store.read(cx).includes_thread(id)), - ThreadContextEntry::Context { path, .. } => context_store + .is_some_and(|ctx_store| ctx_store.read(cx).includes_thread(&thread.id)), + HistoryEntry::TextThread(thread) => context_store .upgrade() - .is_some_and(|ctx_store| ctx_store.read(cx).includes_text_thread(path)), + .is_some_and(|ctx_store| ctx_store.read(cx).includes_text_thread(&thread.path)), }; h_flex() @@ -271,91 +243,38 @@ pub fn render_thread_context_entry( }) } -#[derive(Clone)] -pub struct ThreadMatch { - pub thread: ThreadContextEntry, - pub is_recent: bool, -} - -pub fn unordered_thread_entries( - thread_store: Entity, - text_thread_store: Entity, - cx: &App, -) -> impl Iterator, ThreadContextEntry)> { - let threads = thread_store - .read(cx) - .reverse_chronological_threads() - .map(|thread| { - ( - thread.updated_at, - ThreadContextEntry::Thread { - id: thread.id.clone(), - title: thread.summary.clone(), - }, - ) - }); - - let text_threads = text_thread_store - .read(cx) - .unordered_contexts() - .map(|context| { - ( - context.mtime.to_utc(), - ThreadContextEntry::Context { - path: context.path.clone(), - title: context.title.clone(), - }, - ) - }); - - threads.chain(text_threads) -} - pub(crate) fn search_threads( query: String, cancellation_flag: Arc, - thread_store: Entity, - text_thread_store: Entity, + thread_store: &Entity, cx: &mut App, -) -> Task> { - let mut threads = - unordered_thread_entries(thread_store, text_thread_store, cx).collect::>(); - threads.sort_unstable_by_key(|(updated_at, _)| std::cmp::Reverse(*updated_at)); +) -> Task> { + let threads = thread_store.read(cx).entries().collect(); + if query.is_empty() { + return Task::ready(threads); + } let executor = cx.background_executor().clone(); cx.background_spawn(async move { - if query.is_empty() { - threads - .into_iter() - .map(|(_, thread)| ThreadMatch { - thread, - is_recent: false, - }) - .collect() - } else { - let candidates = threads - .iter() - .enumerate() - .map(|(id, (_, thread))| StringMatchCandidate::new(id, thread.title())) - .collect::>(); - let matches = fuzzy::match_strings( - &candidates, - &query, - false, - true, - 100, - &cancellation_flag, - executor, - ) - .await; + let candidates = threads + .iter() + .enumerate() + .map(|(id, thread)| StringMatchCandidate::new(id, thread.title())) + .collect::>(); + let matches = fuzzy::match_strings( + &candidates, + &query, + false, + true, + 100, + &cancellation_flag, + executor, + ) + .await; - matches - .into_iter() - .map(|mat| ThreadMatch { - thread: threads[mat.candidate_id].1.clone(), - is_recent: false, - }) - .collect() - } + matches + .into_iter() + .map(|mat| threads[mat.candidate_id].clone()) + .collect() }) } diff --git a/crates/agent/src/context_store.rs b/crates/agent_ui/src/context_store.rs similarity index 87% rename from crates/agent/src/context_store.rs rename to crates/agent_ui/src/context_store.rs index cf35840cc4215695a966931701257c838c00af18..e2ee1cd0c94fd6132719ffcc0bd352865b5f9cf9 100644 --- a/crates/agent/src/context_store.rs +++ b/crates/agent_ui/src/context_store.rs @@ -1,12 +1,9 @@ -use crate::{ - context::{ - AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle, - FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle, - SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle, - }, - thread::{MessageId, Thread, ThreadId}, - thread_store::ThreadStore, +use crate::context::{ + AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle, + FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle, + SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle, }; +use agent_client_protocol as acp; use anyhow::{Context as _, Result, anyhow}; use assistant_context::AssistantContext; use collections::{HashSet, IndexSet}; @@ -29,10 +26,9 @@ use text::{Anchor, OffsetRangeExt}; pub struct ContextStore { project: WeakEntity, - thread_store: Option>, next_context_id: ContextId, context_set: IndexSet, - context_thread_ids: HashSet, + context_thread_ids: HashSet, context_text_thread_paths: HashSet>, } @@ -43,13 +39,9 @@ pub enum ContextStoreEvent { impl EventEmitter for ContextStore {} impl ContextStore { - pub fn new( - project: WeakEntity, - thread_store: Option>, - ) -> Self { + pub fn new(project: WeakEntity) -> Self { Self { project, - thread_store, next_context_id: ContextId::zero(), context_set: IndexSet::default(), context_thread_ids: HashSet::default(), @@ -67,29 +59,6 @@ impl ContextStore { cx.notify(); } - pub fn new_context_for_thread( - &self, - thread: &Thread, - exclude_messages_from_id: Option, - ) -> Vec { - let existing_context = thread - .messages() - .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id)) - .flat_map(|message| { - message - .loaded_context - .contexts - .iter() - .map(|context| AgentContextKey(context.handle())) - }) - .collect::>(); - self.context_set - .iter() - .filter(|context| !existing_context.contains(context)) - .map(|entry| entry.0.clone()) - .collect::>() - } - pub fn add_file_from_path( &mut self, project_path: ProjectPath, @@ -209,7 +178,7 @@ impl ContextStore { pub fn add_thread( &mut self, - thread: Entity, + thread: Entity, remove_if_exists: bool, cx: &mut Context, ) -> Option { @@ -384,15 +353,15 @@ impl ContextStore { ); }; } - SuggestedContext::Thread { thread, name: _ } => { - if let Some(thread) = thread.upgrade() { - let context_id = self.next_context_id.post_inc(); - self.insert_context( - AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }), - cx, - ); - } - } + // SuggestedContext::Thread { thread, name: _ } => { + // if let Some(thread) = thread.upgrade() { + // let context_id = self.next_context_id.post_inc(); + // self.insert_context( + // AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }), + // cx, + // ); + // } + // } SuggestedContext::TextThread { context, name: _ } => { if let Some(context) = context.upgrade() { let context_id = self.next_context_id.post_inc(); @@ -410,17 +379,17 @@ impl ContextStore { fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context) -> bool { match &context { - AgentContextHandle::Thread(thread_context) => { - if let Some(thread_store) = self.thread_store.clone() { - thread_context.thread.update(cx, |thread, cx| { - thread.start_generating_detailed_summary_if_needed(thread_store, cx); - }); - self.context_thread_ids - .insert(thread_context.thread.read(cx).id().clone()); - } else { - return false; - } - } + // AgentContextHandle::Thread(thread_context) => { + // if let Some(thread_store) = self.thread_store.clone() { + // thread_context.thread.update(cx, |thread, cx| { + // thread.start_generating_detailed_summary_if_needed(thread_store, cx); + // }); + // self.context_thread_ids + // .insert(thread_context.thread.read(cx).id().clone()); + // } else { + // return false; + // } + // } AgentContextHandle::TextThread(text_thread_context) => { self.context_text_thread_paths .extend(text_thread_context.context.read(cx).path().cloned()); @@ -514,7 +483,7 @@ impl ContextStore { }) } - pub fn includes_thread(&self, thread_id: &ThreadId) -> bool { + pub fn includes_thread(&self, thread_id: &acp::SessionId) -> bool { self.context_thread_ids.contains(thread_id) } @@ -547,9 +516,9 @@ impl ContextStore { } AgentContextHandle::Directory(_) | AgentContextHandle::Symbol(_) + | AgentContextHandle::Thread(_) | AgentContextHandle::Selection(_) | AgentContextHandle::FetchedUrl(_) - | AgentContextHandle::Thread(_) | AgentContextHandle::TextThread(_) | AgentContextHandle::Rules(_) | AgentContextHandle::Image(_) => None, @@ -557,7 +526,7 @@ impl ContextStore { .collect() } - pub fn thread_ids(&self) -> &HashSet { + pub fn thread_ids(&self) -> &HashSet { &self.context_thread_ids } } @@ -569,10 +538,10 @@ pub enum SuggestedContext { icon_path: Option, buffer: WeakEntity, }, - Thread { - name: SharedString, - thread: WeakEntity, - }, + // Thread { + // name: SharedString, + // thread: WeakEntity, + // }, TextThread { name: SharedString, context: WeakEntity, @@ -583,7 +552,7 @@ impl SuggestedContext { pub fn name(&self) -> &SharedString { match self { Self::File { name, .. } => name, - Self::Thread { name, .. } => name, + // Self::Thread { name, .. } => name, Self::TextThread { name, .. } => name, } } @@ -591,7 +560,7 @@ impl SuggestedContext { pub fn icon_path(&self) -> Option { match self { Self::File { icon_path, .. } => icon_path.clone(), - Self::Thread { .. } => None, + // Self::Thread { .. } => None, Self::TextThread { .. } => None, } } @@ -599,7 +568,7 @@ impl SuggestedContext { pub fn kind(&self) -> ContextKind { match self { Self::File { .. } => ContextKind::File, - Self::Thread { .. } => ContextKind::Thread, + // Self::Thread { .. } => ContextKind::Thread, Self::TextThread { .. } => ContextKind::TextThread, } } diff --git a/crates/agent_ui/src/context_strip.rs b/crates/agent_ui/src/context_strip.rs index b75b933de40f19557d9dfa83c874c3427773445b..1f40da3d945df5f066289932b83065dc33d8e169 100644 --- a/crates/agent_ui/src/context_strip.rs +++ b/crates/agent_ui/src/context_strip.rs @@ -4,12 +4,11 @@ use crate::{ context_picker::ContextPicker, ui::{AddedContext, ContextPill}, }; -use agent::context_store::SuggestedContext; -use agent::{ +use crate::{ context::AgentContextHandle, - context_store::ContextStore, - thread_store::{TextThreadStore, ThreadStore}, + context_store::{ContextStore, SuggestedContext}, }; +use agent::HistoryStore; use collections::HashSet; use editor::Editor; use gpui::{ @@ -18,6 +17,7 @@ use gpui::{ }; use itertools::Itertools; use project::ProjectItem; +use prompt_store::PromptStore; use rope::Point; use std::rc::Rc; use text::ToPoint as _; @@ -33,7 +33,7 @@ pub struct ContextStrip { focus_handle: FocusHandle, suggest_context_kind: SuggestContextKind, workspace: WeakEntity, - thread_store: Option>, + prompt_store: Option>, _subscriptions: Vec, focused_index: Option, children_bounds: Option>>, @@ -44,8 +44,8 @@ impl ContextStrip { pub fn new( context_store: Entity, workspace: WeakEntity, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, + prompt_store: Option>, context_picker_menu_handle: PopoverMenuHandle, suggest_context_kind: SuggestContextKind, model_usage_context: ModelUsageContext, @@ -56,7 +56,7 @@ impl ContextStrip { ContextPicker::new( workspace.clone(), thread_store.clone(), - text_thread_store, + prompt_store.clone(), context_store.downgrade(), window, cx, @@ -79,7 +79,7 @@ impl ContextStrip { focus_handle, suggest_context_kind, workspace, - thread_store, + prompt_store, _subscriptions: subscriptions, focused_index: None, children_bounds: None, @@ -96,11 +96,7 @@ impl ContextStrip { fn added_contexts(&self, cx: &App) -> Vec { if let Some(workspace) = self.workspace.upgrade() { let project = workspace.read(cx).project().read(cx); - let prompt_store = self - .thread_store - .as_ref() - .and_then(|thread_store| thread_store.upgrade()) - .and_then(|thread_store| thread_store.read(cx).prompt_store().as_ref()); + let prompt_store = self.prompt_store.as_ref().and_then(|p| p.upgrade()); let current_model = self.model_usage_context.language_model(cx); @@ -110,7 +106,7 @@ impl ContextStrip { .flat_map(|context| { AddedContext::new_pending( context.clone(), - prompt_store, + prompt_store.as_ref(), project, current_model.as_ref(), cx, @@ -339,7 +335,7 @@ impl ContextStrip { let context = text_thread_context.context.clone(); window.defer(cx, move |window, cx| { panel.update(cx, |panel, cx| { - panel.open_prompt_editor(context, window, cx) + panel.open_text_thread(context, window, cx) }); }); } diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 3d25e614ad69d264700476d52ddc0407590b9e9c..4c09e475b10881ab9bc2327b5b18b1c66e2ba4ad 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -7,13 +7,11 @@ use std::sync::Arc; use crate::{ AgentPanel, buffer_codegen::{BufferCodegen, CodegenAlternative, CodegenEvent}, + context_store::ContextStore, inline_prompt_editor::{CodegenStatus, InlineAssistId, PromptEditor, PromptEditorEvent}, terminal_inline_assistant::TerminalInlineAssistant, }; -use agent::{ - context_store::ContextStore, - thread_store::{TextThreadStore, ThreadStore}, -}; +use agent::HistoryStore; use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; @@ -209,24 +207,21 @@ impl InlineAssistant { window: &mut Window, cx: &mut App, ) { - let is_assistant2_enabled = !DisableAiSettings::get_global(cx).disable_ai; + let is_ai_enabled = !DisableAiSettings::get_global(cx).disable_ai; if let Some(editor) = item.act_as::(cx) { editor.update(cx, |editor, cx| { - if is_assistant2_enabled { + if is_ai_enabled { let panel = workspace.read(cx).panel::(cx); let thread_store = panel .as_ref() .map(|agent_panel| agent_panel.read(cx).thread_store().downgrade()); - let text_thread_store = panel - .map(|agent_panel| agent_panel.read(cx).text_thread_store().downgrade()); editor.add_code_action_provider( Rc::new(AssistantCodeActionProvider { editor: cx.entity().downgrade(), workspace: workspace.downgrade(), thread_store, - text_thread_store, }), window, cx, @@ -283,7 +278,6 @@ impl InlineAssistant { let prompt_store = agent_panel.prompt_store().as_ref().cloned(); let thread_store = Some(agent_panel.thread_store().downgrade()); - let text_thread_store = Some(agent_panel.text_thread_store().downgrade()); let context_store = agent_panel.inline_assist_context_store().clone(); let handle_assist = @@ -297,7 +291,6 @@ impl InlineAssistant { workspace.project().downgrade(), prompt_store, thread_store, - text_thread_store, action.prompt.clone(), window, cx, @@ -312,7 +305,6 @@ impl InlineAssistant { workspace.project().downgrade(), prompt_store, thread_store, - text_thread_store, action.prompt.clone(), window, cx, @@ -365,8 +357,7 @@ impl InlineAssistant { context_store: Entity, project: WeakEntity, prompt_store: Option>, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, initial_prompt: Option, window: &mut Window, cx: &mut App, @@ -517,7 +508,7 @@ impl InlineAssistant { context_store.clone(), workspace.clone(), thread_store.clone(), - text_thread_store.clone(), + prompt_store.as_ref().map(|s| s.downgrade()), window, cx, ) @@ -589,8 +580,7 @@ impl InlineAssistant { focus: bool, workspace: Entity, prompt_store: Option>, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, window: &mut Window, cx: &mut App, ) -> InlineAssistId { @@ -608,7 +598,7 @@ impl InlineAssistant { } let project = workspace.read(cx).project().downgrade(); - let context_store = cx.new(|_cx| ContextStore::new(project.clone(), thread_store.clone())); + let context_store = cx.new(|_cx| ContextStore::new(project.clone())); let codegen = cx.new(|cx| { BufferCodegen::new( @@ -617,7 +607,7 @@ impl InlineAssistant { initial_transaction_id, context_store.clone(), project, - prompt_store, + prompt_store.clone(), self.telemetry.clone(), self.prompt_builder.clone(), cx, @@ -636,7 +626,7 @@ impl InlineAssistant { context_store, workspace.downgrade(), thread_store, - text_thread_store, + prompt_store.map(|s| s.downgrade()), window, cx, ) @@ -1773,8 +1763,7 @@ struct InlineAssistDecorations { struct AssistantCodeActionProvider { editor: WeakEntity, workspace: WeakEntity, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, } const ASSISTANT_CODE_ACTION_PROVIDER_ID: &str = "assistant2"; @@ -1846,7 +1835,6 @@ impl CodeActionProvider for AssistantCodeActionProvider { let editor = self.editor.clone(); let workspace = self.workspace.clone(); let thread_store = self.thread_store.clone(); - let text_thread_store = self.text_thread_store.clone(); let prompt_store = PromptStore::global(cx); window.spawn(cx, async move |cx| { let workspace = workspace.upgrade().context("workspace was released")?; @@ -1894,7 +1882,6 @@ impl CodeActionProvider for AssistantCodeActionProvider { workspace, prompt_store, thread_store, - text_thread_store, window, cx, ); diff --git a/crates/agent_ui/src/inline_prompt_editor.rs b/crates/agent_ui/src/inline_prompt_editor.rs index f6347dcb6b80c1b5c939a5c4cd650b9fadf92c62..70d6009e466e3e2f6ba3cd65076f77f7d12b22e0 100644 --- a/crates/agent_ui/src/inline_prompt_editor.rs +++ b/crates/agent_ui/src/inline_prompt_editor.rs @@ -1,7 +1,5 @@ -use agent::{ - context_store::ContextStore, - thread_store::{TextThreadStore, ThreadStore}, -}; +use crate::context_store::ContextStore; +use agent::HistoryStore; use collections::VecDeque; use editor::actions::Paste; use editor::display_map::EditorMargins; @@ -16,6 +14,7 @@ use gpui::{ }; use language_model::{LanguageModel, LanguageModelRegistry}; use parking_lot::Mutex; +use prompt_store::PromptStore; use settings::Settings; use std::cmp; use std::rc::Rc; @@ -777,8 +776,8 @@ impl PromptEditor { fs: Arc, context_store: Entity, workspace: WeakEntity, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, + prompt_store: Option>, window: &mut Window, cx: &mut Context>, ) -> PromptEditor { @@ -823,7 +822,7 @@ impl PromptEditor { workspace.clone(), context_store.downgrade(), thread_store.clone(), - text_thread_store.clone(), + prompt_store.clone(), prompt_editor_entity, codegen_buffer.as_ref().map(Entity::downgrade), )))); @@ -837,7 +836,7 @@ impl PromptEditor { context_store.clone(), workspace.clone(), thread_store.clone(), - text_thread_store.clone(), + prompt_store, context_picker_menu_handle.clone(), SuggestContextKind::Thread, ModelUsageContext::InlineAssistant, @@ -949,8 +948,8 @@ impl PromptEditor { fs: Arc, context_store: Entity, workspace: WeakEntity, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, + prompt_store: Option>, window: &mut Window, cx: &mut Context, ) -> Self { @@ -988,7 +987,7 @@ impl PromptEditor { workspace.clone(), context_store.downgrade(), thread_store.clone(), - text_thread_store.clone(), + prompt_store.clone(), prompt_editor_entity, None, )))); @@ -1002,7 +1001,7 @@ impl PromptEditor { context_store.clone(), workspace.clone(), thread_store.clone(), - text_thread_store.clone(), + prompt_store.clone(), context_picker_menu_handle.clone(), SuggestContextKind::Thread, ModelUsageContext::InlineAssistant, diff --git a/crates/agent_ui/src/message_editor.rs b/crates/agent_ui/src/message_editor.rs index a1311f39233c7eaaf0b416401676fb2e43e51a26..42607833e4b5734424988d1edaa32d10bec06506 100644 --- a/crates/agent_ui/src/message_editor.rs +++ b/crates/agent_ui/src/message_editor.rs @@ -1,31 +1,25 @@ -use agent::{context::AgentContextKey, context_store::ContextStoreEvent}; -use agent_settings::AgentProfileId; +use std::ops::Range; + use collections::HashMap; use editor::display_map::CreaseId; use editor::{Addon, AnchorRangeExt, Editor}; -use gpui::{App, Entity, Subscription}; +use gpui::{Entity, Subscription}; use ui::prelude::*; -use crate::context_picker::crease_for_mention; -use crate::profile_selector::ProfileProvider; -use agent::{MessageCrease, Thread, context_store::ContextStore}; - -impl ProfileProvider for Entity { - fn profiles_supported(&self, cx: &App) -> bool { - self.read(cx) - .configured_model() - .is_some_and(|model| model.model.supports_tools()) - } - - fn profile_id(&self, cx: &App) -> AgentProfileId { - self.read(cx).profile().id().clone() - } +use crate::{ + context::{AgentContextHandle, AgentContextKey}, + context_picker::crease_for_mention, + context_store::{ContextStore, ContextStoreEvent}, +}; - fn set_profile(&self, profile_id: AgentProfileId, cx: &mut App) { - self.update(cx, |this, cx| { - this.set_profile(profile_id, cx); - }); - } +/// Stored information that can be used to resurrect a context crease when creating an editor for a past message. +#[derive(Clone, Debug)] +pub struct MessageCrease { + pub range: Range, + pub icon_path: SharedString, + pub label: SharedString, + /// None for a deserialized message, Some otherwise. + pub context: Option, } #[derive(Default)] diff --git a/crates/agent_ui/src/terminal_inline_assistant.rs b/crates/agent_ui/src/terminal_inline_assistant.rs index 4385d2420511c8a148b2a7a58fa8845bd2c19a07..9e653dcce1dcf1487af9998662b57ea4f998c7de 100644 --- a/crates/agent_ui/src/terminal_inline_assistant.rs +++ b/crates/agent_ui/src/terminal_inline_assistant.rs @@ -1,12 +1,12 @@ -use crate::inline_prompt_editor::{ - CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId, -}; -use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen}; -use agent::{ +use crate::{ context::load_context, context_store::ContextStore, - thread_store::{TextThreadStore, ThreadStore}, + inline_prompt_editor::{ + CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId, + }, + terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen}, }; +use agent::HistoryStore; use agent_settings::AgentSettings; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; @@ -74,8 +74,7 @@ impl TerminalInlineAssistant { workspace: WeakEntity, project: WeakEntity, prompt_store: Option>, - thread_store: Option>, - text_thread_store: Option>, + thread_store: Option>, initial_prompt: Option, window: &mut Window, cx: &mut App, @@ -88,7 +87,7 @@ impl TerminalInlineAssistant { cx, ) }); - let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone())); + let context_store = cx.new(|_cx| ContextStore::new(project)); let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone())); let prompt_editor = cx.new(|cx| { @@ -101,7 +100,7 @@ impl TerminalInlineAssistant { context_store.clone(), workspace.clone(), thread_store.clone(), - text_thread_store.clone(), + prompt_store.as_ref().map(|s| s.downgrade()), window, cx, ) @@ -282,7 +281,6 @@ impl TerminalInlineAssistant { context_load_task .await - .loaded_context .add_to_request_message(&mut request_message); request_message.content.push(prompt.into()); diff --git a/crates/agent_ui/src/ui/context_pill.rs b/crates/agent_ui/src/ui/context_pill.rs index f85a06455439d8e52a7b4272bc7f8069f36548ac..ea1f1136794e1ac3a23e2caeaa3006acccf9bce0 100644 --- a/crates/agent_ui/src/ui/context_pill.rs +++ b/crates/agent_ui/src/ui/context_pill.rs @@ -11,13 +11,13 @@ use project::Project; use prompt_store::PromptStore; use rope::Point; use ui::{IconButtonShape, Tooltip, prelude::*, tooltip_container}; +use util::paths::PathStyle; -use agent::context::{ +use crate::context::{ AgentContextHandle, ContextId, ContextKind, DirectoryContextHandle, FetchedUrlContext, FileContextHandle, ImageContext, ImageStatus, RulesContextHandle, SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle, }; -use util::paths::PathStyle; #[derive(IntoElement)] pub enum ContextPill { @@ -466,7 +466,7 @@ impl AddedContext { parent: None, tooltip: None, icon_path: None, - status: if handle.thread.read(cx).is_generating_detailed_summary() { + status: if handle.thread.read(cx).is_generating_summary() { ContextStatus::Loading { message: "Summarizing…".into(), } @@ -476,7 +476,11 @@ impl AddedContext { render_hover: { let thread = handle.thread.clone(); Some(Rc::new(move |_, cx| { - let text = thread.read(cx).latest_detailed_summary_or_text(); + let text = thread + .update(cx, |thread, cx| thread.summary(cx)) + .now_or_never() + .flatten() + .unwrap_or_else(|| SharedString::from(thread.read(cx).to_markdown())); ContextPillHover::new_text(text, cx).into() })) }, diff --git a/crates/assistant_tool/Cargo.toml b/crates/assistant_tool/Cargo.toml deleted file mode 100644 index c95695052a4778209010b2f9e7a4a57be4cb6cf7..0000000000000000000000000000000000000000 --- a/crates/assistant_tool/Cargo.toml +++ /dev/null @@ -1,50 +0,0 @@ -[package] -name = "assistant_tool" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/assistant_tool.rs" - -[dependencies] -action_log.workspace = true -anyhow.workspace = true -collections.workspace = true -derive_more.workspace = true -gpui.workspace = true -icons.workspace = true -language.workspace = true -language_model.workspace = true -log.workspace = true -parking_lot.workspace = true -project.workspace = true -regex.workspace = true -serde.workspace = true -serde_json.workspace = true -text.workspace = true -util.workspace = true -workspace.workspace = true -workspace-hack.workspace = true - -[dev-dependencies] -buffer_diff = { workspace = true, features = ["test-support"] } -collections = { workspace = true, features = ["test-support"] } -clock = { workspace = true, features = ["test-support"] } -ctor.workspace = true -gpui = { workspace = true, features = ["test-support"] } -indoc.workspace = true -language = { workspace = true, features = ["test-support"] } -language_model = { workspace = true, features = ["test-support"] } -log.workspace = true -pretty_assertions.workspace = true -project = { workspace = true, features = ["test-support"] } -rand.workspace = true -settings = { workspace = true, features = ["test-support"] } -text = { workspace = true, features = ["test-support"] } -util = { workspace = true, features = ["test-support"] } -zlog.workspace = true diff --git a/crates/assistant_tool/LICENSE-GPL b/crates/assistant_tool/LICENSE-GPL deleted file mode 120000 index 89e542f750cd3860a0598eff0dc34b56d7336dc4..0000000000000000000000000000000000000000 --- a/crates/assistant_tool/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs deleted file mode 100644 index 9c5825d0f0ecc9c31277bfff5123d3d80501511b..0000000000000000000000000000000000000000 --- a/crates/assistant_tool/src/assistant_tool.rs +++ /dev/null @@ -1,269 +0,0 @@ -pub mod outline; -mod tool_registry; -mod tool_schema; -mod tool_working_set; - -use std::fmt; -use std::fmt::Debug; -use std::fmt::Formatter; -use std::ops::Deref; -use std::sync::Arc; - -use action_log::ActionLog; -use anyhow::Result; -use gpui::AnyElement; -use gpui::AnyWindowHandle; -use gpui::Context; -use gpui::IntoElement; -use gpui::Window; -use gpui::{App, Entity, SharedString, Task, WeakEntity}; -use icons::IconName; -use language_model::LanguageModel; -use language_model::LanguageModelImage; -use language_model::LanguageModelRequest; -use language_model::LanguageModelToolSchemaFormat; -use project::Project; -use workspace::Workspace; - -pub use crate::tool_registry::*; -pub use crate::tool_schema::*; -pub use crate::tool_working_set::*; - -pub fn init(cx: &mut App) { - ToolRegistry::default_global(cx); -} - -#[derive(Debug, Clone)] -pub enum ToolUseStatus { - InputStillStreaming, - NeedsConfirmation, - Pending, - Running, - Finished(SharedString), - Error(SharedString), -} - -impl ToolUseStatus { - pub fn text(&self) -> SharedString { - match self { - ToolUseStatus::NeedsConfirmation => "".into(), - ToolUseStatus::InputStillStreaming => "".into(), - ToolUseStatus::Pending => "".into(), - ToolUseStatus::Running => "".into(), - ToolUseStatus::Finished(out) => out.clone(), - ToolUseStatus::Error(out) => out.clone(), - } - } - - pub fn error(&self) -> Option { - match self { - ToolUseStatus::Error(out) => Some(out.clone()), - _ => None, - } - } -} - -#[derive(Debug)] -pub struct ToolResultOutput { - pub content: ToolResultContent, - pub output: Option, -} - -#[derive(Debug, PartialEq, Eq)] -pub enum ToolResultContent { - Text(String), - Image(LanguageModelImage), -} - -impl ToolResultContent { - pub fn len(&self) -> usize { - match self { - ToolResultContent::Text(str) => str.len(), - ToolResultContent::Image(image) => image.len(), - } - } - - pub fn is_empty(&self) -> bool { - match self { - ToolResultContent::Text(str) => str.is_empty(), - ToolResultContent::Image(image) => image.is_empty(), - } - } - - pub fn as_str(&self) -> Option<&str> { - match self { - ToolResultContent::Text(str) => Some(str), - ToolResultContent::Image(_) => None, - } - } -} - -impl From for ToolResultOutput { - fn from(value: String) -> Self { - ToolResultOutput { - content: ToolResultContent::Text(value), - output: None, - } - } -} - -impl Deref for ToolResultOutput { - type Target = ToolResultContent; - - fn deref(&self) -> &Self::Target { - &self.content - } -} - -/// The result of running a tool, containing both the asynchronous output -/// and an optional card view that can be rendered immediately. -pub struct ToolResult { - /// The asynchronous task that will eventually resolve to the tool's output - pub output: Task>, - /// An optional view to present the output of the tool. - pub card: Option, -} - -pub trait ToolCard: 'static + Sized { - fn render( - &mut self, - status: &ToolUseStatus, - window: &mut Window, - workspace: WeakEntity, - cx: &mut Context, - ) -> impl IntoElement; -} - -#[derive(Clone)] -pub struct AnyToolCard { - entity: gpui::AnyEntity, - render: fn( - entity: gpui::AnyEntity, - status: &ToolUseStatus, - window: &mut Window, - workspace: WeakEntity, - cx: &mut App, - ) -> AnyElement, -} - -impl From> for AnyToolCard { - fn from(entity: Entity) -> Self { - fn downcast_render( - entity: gpui::AnyEntity, - status: &ToolUseStatus, - window: &mut Window, - workspace: WeakEntity, - cx: &mut App, - ) -> AnyElement { - let entity = entity.downcast::().unwrap(); - entity.update(cx, |entity, cx| { - entity - .render(status, window, workspace, cx) - .into_any_element() - }) - } - - Self { - entity: entity.into(), - render: downcast_render::, - } - } -} - -impl AnyToolCard { - pub fn render( - &self, - status: &ToolUseStatus, - window: &mut Window, - workspace: WeakEntity, - cx: &mut App, - ) -> AnyElement { - (self.render)(self.entity.clone(), status, window, workspace, cx) - } -} - -impl From>> for ToolResult { - /// Convert from a task to a ToolResult with no card - fn from(output: Task>) -> Self { - Self { output, card: None } - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)] -pub enum ToolSource { - /// A native tool built-in to Zed. - Native, - /// A tool provided by a context server. - ContextServer { id: SharedString }, -} - -/// A tool that can be used by a language model. -pub trait Tool: 'static + Send + Sync { - /// Returns the name of the tool. - fn name(&self) -> String; - - /// Returns the description of the tool. - fn description(&self) -> String; - - /// Returns the icon for the tool. - fn icon(&self) -> IconName; - - /// Returns the source of the tool. - fn source(&self) -> ToolSource { - ToolSource::Native - } - - /// Returns true if the tool needs the users's confirmation - /// before having permission to run. - fn needs_confirmation( - &self, - input: &serde_json::Value, - project: &Entity, - cx: &App, - ) -> bool; - - /// Returns true if the tool may perform edits. - fn may_perform_edits(&self) -> bool; - - /// Returns the JSON schema that describes the tool's input. - fn input_schema(&self, _: LanguageModelToolSchemaFormat) -> Result { - Ok(serde_json::Value::Object(serde_json::Map::default())) - } - - /// Returns markdown to be displayed in the UI for this tool. - fn ui_text(&self, input: &serde_json::Value) -> String; - - /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming - /// (so information may be missing). - fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { - self.ui_text(input) - } - - /// Runs the tool with the provided input. - fn run( - self: Arc, - input: serde_json::Value, - request: Arc, - project: Entity, - action_log: Entity, - model: Arc, - window: Option, - cx: &mut App, - ) -> ToolResult; - - fn deserialize_card( - self: Arc, - _output: serde_json::Value, - _project: Entity, - _window: &mut Window, - _cx: &mut App, - ) -> Option { - None - } -} - -impl Debug for dyn Tool { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("Tool").field("name", &self.name()).finish() - } -} diff --git a/crates/assistant_tool/src/tool_registry.rs b/crates/assistant_tool/src/tool_registry.rs deleted file mode 100644 index 26b4821a6d1af05a5e42d639f465486b9311d427..0000000000000000000000000000000000000000 --- a/crates/assistant_tool/src/tool_registry.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::sync::Arc; - -use collections::HashMap; -use derive_more::{Deref, DerefMut}; -use gpui::Global; -use gpui::{App, ReadGlobal}; -use parking_lot::RwLock; - -use crate::Tool; - -#[derive(Default, Deref, DerefMut)] -struct GlobalToolRegistry(Arc); - -impl Global for GlobalToolRegistry {} - -#[derive(Default)] -struct ToolRegistryState { - tools: HashMap, Arc>, -} - -#[derive(Default)] -pub struct ToolRegistry { - state: RwLock, -} - -impl ToolRegistry { - /// Returns the global [`ToolRegistry`]. - pub fn global(cx: &App) -> Arc { - GlobalToolRegistry::global(cx).0.clone() - } - - /// Returns the global [`ToolRegistry`]. - /// - /// Inserts a default [`ToolRegistry`] if one does not yet exist. - pub fn default_global(cx: &mut App) -> Arc { - cx.default_global::().0.clone() - } - - pub fn new() -> Arc { - Arc::new(Self { - state: RwLock::new(ToolRegistryState { - tools: HashMap::default(), - }), - }) - } - - /// Registers the provided [`Tool`]. - pub fn register_tool(&self, tool: impl Tool) { - let mut state = self.state.write(); - let tool_name: Arc = tool.name().into(); - state.tools.insert(tool_name, Arc::new(tool)); - } - - /// Unregisters the provided [`Tool`]. - pub fn unregister_tool(&self, tool: impl Tool) { - self.unregister_tool_by_name(tool.name().as_str()) - } - - /// Unregisters the tool with the given name. - pub fn unregister_tool_by_name(&self, tool_name: &str) { - let mut state = self.state.write(); - state.tools.remove(tool_name); - } - - /// Returns the list of tools in the registry. - pub fn tools(&self) -> Vec> { - self.state.read().tools.values().cloned().collect() - } - - /// Returns the [`Tool`] with the given name. - pub fn tool(&self, name: &str) -> Option> { - self.state.read().tools.get(name).cloned() - } -} diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs deleted file mode 100644 index 61f57affc76aad9e4d2185665b539f9092e3491c..0000000000000000000000000000000000000000 --- a/crates/assistant_tool/src/tool_working_set.rs +++ /dev/null @@ -1,415 +0,0 @@ -use std::{borrow::Borrow, sync::Arc}; - -use crate::{Tool, ToolRegistry, ToolSource}; -use collections::{HashMap, HashSet, IndexMap}; -use gpui::{App, SharedString}; -use util::debug_panic; - -#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)] -pub struct ToolId(usize); - -/// A unique identifier for a tool within a working set. -#[derive(Clone, PartialEq, Eq, Hash, Default)] -pub struct UniqueToolName(SharedString); - -impl Borrow for UniqueToolName { - fn borrow(&self) -> &str { - &self.0 - } -} - -impl From for UniqueToolName { - fn from(value: String) -> Self { - UniqueToolName(SharedString::new(value)) - } -} - -impl Into for UniqueToolName { - fn into(self) -> String { - self.0.into() - } -} - -impl std::fmt::Debug for UniqueToolName { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } -} - -impl std::fmt::Display for UniqueToolName { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0.as_ref()) - } -} - -/// A working set of tools for use in one instance of the Assistant Panel. -#[derive(Default)] -pub struct ToolWorkingSet { - context_server_tools_by_id: HashMap>, - context_server_tools_by_name: HashMap>, - next_tool_id: ToolId, -} - -impl ToolWorkingSet { - pub fn tool(&self, name: &str, cx: &App) -> Option> { - self.context_server_tools_by_name - .get(name) - .cloned() - .or_else(|| ToolRegistry::global(cx).tool(name)) - } - - pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc)> { - let mut tools = ToolRegistry::global(cx) - .tools() - .into_iter() - .map(|tool| (UniqueToolName(tool.name().into()), tool)) - .collect::>(); - tools.extend(self.context_server_tools_by_name.clone()); - tools - } - - pub fn tools_by_source(&self, cx: &App) -> IndexMap>> { - let mut tools_by_source = IndexMap::default(); - - for (_, tool) in self.tools(cx) { - tools_by_source - .entry(tool.source()) - .or_insert_with(Vec::new) - .push(tool); - } - - for tools in tools_by_source.values_mut() { - tools.sort_by_key(|tool| tool.name()); - } - - tools_by_source.sort_unstable_keys(); - - tools_by_source - } - - pub fn insert(&mut self, tool: Arc, cx: &App) -> ToolId { - let tool_id = self.register_tool(tool); - self.tools_changed(cx); - tool_id - } - - pub fn extend(&mut self, tools: impl Iterator>, cx: &App) -> Vec { - let ids = tools.map(|tool| self.register_tool(tool)).collect(); - self.tools_changed(cx); - ids - } - - pub fn remove(&mut self, tool_ids_to_remove: &[ToolId], cx: &App) { - self.context_server_tools_by_id - .retain(|id, _| !tool_ids_to_remove.contains(id)); - self.tools_changed(cx); - } - - fn register_tool(&mut self, tool: Arc) -> ToolId { - let tool_id = self.next_tool_id; - self.next_tool_id.0 += 1; - self.context_server_tools_by_id - .insert(tool_id, tool.clone()); - tool_id - } - - fn tools_changed(&mut self, cx: &App) { - self.context_server_tools_by_name = resolve_context_server_tool_name_conflicts( - &self - .context_server_tools_by_id - .values() - .cloned() - .collect::>(), - &ToolRegistry::global(cx).tools(), - ); - } -} - -fn resolve_context_server_tool_name_conflicts( - context_server_tools: &[Arc], - native_tools: &[Arc], -) -> HashMap> { - fn resolve_tool_name(tool: &Arc) -> String { - let mut tool_name = tool.name(); - tool_name.truncate(MAX_TOOL_NAME_LENGTH); - tool_name - } - - const MAX_TOOL_NAME_LENGTH: usize = 64; - - let mut duplicated_tool_names = HashSet::default(); - let mut seen_tool_names = HashSet::default(); - seen_tool_names.extend(native_tools.iter().map(|tool| tool.name())); - for tool in context_server_tools { - let tool_name = resolve_tool_name(tool); - if seen_tool_names.contains(&tool_name) { - debug_assert!( - tool.source() != ToolSource::Native, - "Expected MCP tool but got a native tool: {}", - tool_name - ); - duplicated_tool_names.insert(tool_name); - } else { - seen_tool_names.insert(tool_name); - } - } - - if duplicated_tool_names.is_empty() { - return context_server_tools - .iter() - .map(|tool| (resolve_tool_name(tool).into(), tool.clone())) - .collect(); - } - - context_server_tools - .iter() - .filter_map(|tool| { - let mut tool_name = resolve_tool_name(tool); - if !duplicated_tool_names.contains(&tool_name) { - return Some((tool_name.into(), tool.clone())); - } - match tool.source() { - ToolSource::Native => { - debug_panic!("Expected MCP tool but got a native tool: {}", tool_name); - // Built-in tools always keep their original name - Some((tool_name.into(), tool.clone())) - } - ToolSource::ContextServer { id } => { - // Context server tools are prefixed with the context server ID, and truncated if necessary - tool_name.insert(0, '_'); - if tool_name.len() + id.len() > MAX_TOOL_NAME_LENGTH { - let len = MAX_TOOL_NAME_LENGTH - tool_name.len(); - let mut id = id.to_string(); - id.truncate(len); - tool_name.insert_str(0, &id); - } else { - tool_name.insert_str(0, &id); - } - - tool_name.truncate(MAX_TOOL_NAME_LENGTH); - - if seen_tool_names.contains(&tool_name) { - log::error!("Cannot resolve tool name conflict for tool {}", tool.name()); - None - } else { - Some((tool_name.into(), tool.clone())) - } - } - } - }) - .collect() -} -#[cfg(test)] -mod tests { - use gpui::{AnyWindowHandle, Entity, Task, TestAppContext}; - use language_model::{LanguageModel, LanguageModelRequest}; - use project::Project; - - use crate::{ActionLog, ToolResult}; - - use super::*; - - #[gpui::test] - fn test_unique_tool_names(cx: &mut TestAppContext) { - fn assert_tool( - tool_working_set: &ToolWorkingSet, - unique_name: &str, - expected_name: &str, - expected_source: ToolSource, - cx: &App, - ) { - let tool = tool_working_set.tool(unique_name, cx).unwrap(); - assert_eq!(tool.name(), expected_name); - assert_eq!(tool.source(), expected_source); - } - - let tool_registry = cx.update(ToolRegistry::default_global); - tool_registry.register_tool(TestTool::new("tool1", ToolSource::Native)); - tool_registry.register_tool(TestTool::new("tool2", ToolSource::Native)); - - let mut tool_working_set = ToolWorkingSet::default(); - cx.update(|cx| { - tool_working_set.extend( - vec![ - Arc::new(TestTool::new( - "tool2", - ToolSource::ContextServer { id: "mcp-1".into() }, - )) as Arc, - Arc::new(TestTool::new( - "tool2", - ToolSource::ContextServer { id: "mcp-2".into() }, - )) as Arc, - ] - .into_iter(), - cx, - ); - }); - - cx.update(|cx| { - assert_tool(&tool_working_set, "tool1", "tool1", ToolSource::Native, cx); - assert_tool(&tool_working_set, "tool2", "tool2", ToolSource::Native, cx); - assert_tool( - &tool_working_set, - "mcp-1_tool2", - "tool2", - ToolSource::ContextServer { id: "mcp-1".into() }, - cx, - ); - assert_tool( - &tool_working_set, - "mcp-2_tool2", - "tool2", - ToolSource::ContextServer { id: "mcp-2".into() }, - cx, - ); - }) - } - - #[gpui::test] - fn test_resolve_context_server_tool_name_conflicts() { - assert_resolve_context_server_tool_name_conflicts( - vec![ - TestTool::new("tool1", ToolSource::Native), - TestTool::new("tool2", ToolSource::Native), - ], - vec![TestTool::new( - "tool3", - ToolSource::ContextServer { id: "mcp-1".into() }, - )], - vec!["tool3"], - ); - - assert_resolve_context_server_tool_name_conflicts( - vec![ - TestTool::new("tool1", ToolSource::Native), - TestTool::new("tool2", ToolSource::Native), - ], - vec![ - TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), - TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }), - ], - vec!["mcp-1_tool3", "mcp-2_tool3"], - ); - - assert_resolve_context_server_tool_name_conflicts( - vec![ - TestTool::new("tool1", ToolSource::Native), - TestTool::new("tool2", ToolSource::Native), - TestTool::new("tool3", ToolSource::Native), - ], - vec![ - TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-1".into() }), - TestTool::new("tool3", ToolSource::ContextServer { id: "mcp-2".into() }), - ], - vec!["mcp-1_tool3", "mcp-2_tool3"], - ); - - // Test deduplication of tools with very long names, in this case the mcp server name should be truncated - assert_resolve_context_server_tool_name_conflicts( - vec![TestTool::new( - "tool-with-very-very-very-long-name", - ToolSource::Native, - )], - vec![TestTool::new( - "tool-with-very-very-very-long-name", - ToolSource::ContextServer { - id: "mcp-with-very-very-very-long-name".into(), - }, - )], - vec!["mcp-with-very-very-very-long-_tool-with-very-very-very-long-name"], - ); - - fn assert_resolve_context_server_tool_name_conflicts( - builtin_tools: Vec, - context_server_tools: Vec, - expected: Vec<&'static str>, - ) { - let context_server_tools: Vec> = context_server_tools - .into_iter() - .map(|t| Arc::new(t) as Arc) - .collect(); - let builtin_tools: Vec> = builtin_tools - .into_iter() - .map(|t| Arc::new(t) as Arc) - .collect(); - let tools = - resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools); - assert_eq!(tools.len(), expected.len()); - for (i, (name, _)) in tools.into_iter().enumerate() { - assert_eq!( - name.0.as_ref(), - expected[i], - "Expected '{}' got '{}' at index {}", - expected[i], - name, - i - ); - } - } - } - - struct TestTool { - name: String, - source: ToolSource, - } - - impl TestTool { - fn new(name: impl Into, source: ToolSource) -> Self { - Self { - name: name.into(), - source, - } - } - } - - impl Tool for TestTool { - fn name(&self) -> String { - self.name.clone() - } - - fn icon(&self) -> icons::IconName { - icons::IconName::Ai - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn needs_confirmation( - &self, - _input: &serde_json::Value, - _project: &Entity, - _cx: &App, - ) -> bool { - true - } - - fn source(&self) -> ToolSource { - self.source.clone() - } - - fn description(&self) -> String { - "Test tool".to_string() - } - - fn ui_text(&self, _input: &serde_json::Value) -> String { - "Test tool".to_string() - } - - fn run( - self: Arc, - _input: serde_json::Value, - _request: Arc, - _project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - _cx: &mut App, - ) -> ToolResult { - ToolResult { - output: Task::ready(Err(anyhow::anyhow!("No content"))), - card: None, - } - } - } -} diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml deleted file mode 100644 index 9b9b8196d1c342c536d605306a1a062e73768c56..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/Cargo.toml +++ /dev/null @@ -1,92 +0,0 @@ -[package] -name = "assistant_tools" -version = "0.1.0" -edition.workspace = true -publish.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/assistant_tools.rs" - -[features] -eval = [] - -[dependencies] -action_log.workspace = true -agent_settings.workspace = true -anyhow.workspace = true -assistant_tool.workspace = true -buffer_diff.workspace = true -chrono.workspace = true -client.workspace = true -cloud_llm_client.workspace = true -collections.workspace = true -component.workspace = true -derive_more.workspace = true -diffy = "0.4.2" -editor.workspace = true -feature_flags.workspace = true -futures.workspace = true -gpui.workspace = true -handlebars = { workspace = true, features = ["rust-embed"] } -html_to_markdown.workspace = true -http_client.workspace = true -indoc.workspace = true -itertools.workspace = true -language.workspace = true -language_model.workspace = true -log.workspace = true -lsp.workspace = true -markdown.workspace = true -open.workspace = true -paths.workspace = true -portable-pty.workspace = true -project.workspace = true -prompt_store.workspace = true -regex.workspace = true -rust-embed.workspace = true -schemars.workspace = true -serde.workspace = true -serde_json.workspace = true -settings.workspace = true -smallvec.workspace = true -streaming_diff.workspace = true -strsim.workspace = true -task.workspace = true -terminal.workspace = true -terminal_view.workspace = true -theme.workspace = true -ui.workspace = true -util.workspace = true -watch.workspace = true -web_search.workspace = true -workspace-hack.workspace = true -workspace.workspace = true - -[dev-dependencies] -lsp = { workspace = true, features = ["test-support"] } -client = { workspace = true, features = ["test-support"] } -clock = { workspace = true, features = ["test-support"] } -collections = { workspace = true, features = ["test-support"] } -gpui = { workspace = true, features = ["test-support"] } -gpui_tokio.workspace = true -fs = { workspace = true, features = ["test-support"] } -language = { workspace = true, features = ["test-support"] } -language_model = { workspace = true, features = ["test-support"] } -language_models.workspace = true -project = { workspace = true, features = ["test-support"] } -rand.workspace = true -pretty_assertions.workspace = true -reqwest_client.workspace = true -settings = { workspace = true, features = ["test-support"] } -smol.workspace = true -task = { workspace = true, features = ["test-support"]} -tempfile.workspace = true -theme.workspace = true -tree-sitter-rust.workspace = true -workspace = { workspace = true, features = ["test-support"] } -unindent.workspace = true -zlog.workspace = true diff --git a/crates/assistant_tools/LICENSE-GPL b/crates/assistant_tools/LICENSE-GPL deleted file mode 120000 index 89e542f750cd3860a0598eff0dc34b56d7336dc4..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs deleted file mode 100644 index 17e2ba12f706387859ca3393aa44f5c05570e50a..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/assistant_tools.rs +++ /dev/null @@ -1,167 +0,0 @@ -mod copy_path_tool; -mod create_directory_tool; -mod delete_path_tool; -mod diagnostics_tool; -pub mod edit_agent; -mod edit_file_tool; -mod fetch_tool; -mod find_path_tool; -mod grep_tool; -mod list_directory_tool; -mod move_path_tool; -mod now_tool; -mod open_tool; -mod project_notifications_tool; -mod read_file_tool; -mod schema; -pub mod templates; -mod terminal_tool; -mod thinking_tool; -mod ui; -mod web_search_tool; - -use assistant_tool::ToolRegistry; -use copy_path_tool::CopyPathTool; -use gpui::{App, Entity}; -use http_client::HttpClientWithUrl; -use language_model::LanguageModelRegistry; -use move_path_tool::MovePathTool; -use std::sync::Arc; -use web_search_tool::WebSearchTool; - -pub(crate) use templates::*; - -use crate::create_directory_tool::CreateDirectoryTool; -use crate::delete_path_tool::DeletePathTool; -use crate::diagnostics_tool::DiagnosticsTool; -use crate::edit_file_tool::EditFileTool; -use crate::fetch_tool::FetchTool; -use crate::list_directory_tool::ListDirectoryTool; -use crate::now_tool::NowTool; -use crate::thinking_tool::ThinkingTool; - -pub use edit_file_tool::{EditFileMode, EditFileToolInput}; -pub use find_path_tool::*; -pub use grep_tool::{GrepTool, GrepToolInput}; -pub use open_tool::OpenTool; -pub use project_notifications_tool::ProjectNotificationsTool; -pub use read_file_tool::{ReadFileTool, ReadFileToolInput}; -pub use terminal_tool::TerminalTool; - -pub fn init(http_client: Arc, cx: &mut App) { - assistant_tool::init(cx); - - let registry = ToolRegistry::global(cx); - registry.register_tool(TerminalTool); - registry.register_tool(CreateDirectoryTool); - registry.register_tool(CopyPathTool); - registry.register_tool(DeletePathTool); - registry.register_tool(MovePathTool); - registry.register_tool(DiagnosticsTool); - registry.register_tool(ListDirectoryTool); - registry.register_tool(NowTool); - registry.register_tool(OpenTool); - registry.register_tool(ProjectNotificationsTool); - registry.register_tool(FindPathTool); - registry.register_tool(ReadFileTool); - registry.register_tool(GrepTool); - registry.register_tool(ThinkingTool); - registry.register_tool(FetchTool::new(http_client)); - registry.register_tool(EditFileTool); - - register_web_search_tool(&LanguageModelRegistry::global(cx), cx); - cx.subscribe( - &LanguageModelRegistry::global(cx), - move |registry, event, cx| { - if let language_model::Event::DefaultModelChanged = event { - register_web_search_tool(®istry, cx); - } - }, - ) - .detach(); -} - -fn register_web_search_tool(registry: &Entity, cx: &mut App) { - let using_zed_provider = registry - .read(cx) - .default_model() - .is_some_and(|default| default.is_provided_by_zed()); - if using_zed_provider { - ToolRegistry::global(cx).register_tool(WebSearchTool); - } else { - ToolRegistry::global(cx).unregister_tool(WebSearchTool); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use agent_settings::AgentSettings; - use client::Client; - use clock::FakeSystemClock; - use http_client::FakeHttpClient; - use schemars::JsonSchema; - use serde::Serialize; - use settings::Settings; - - #[test] - fn test_json_schema() { - #[derive(Serialize, JsonSchema)] - struct GetWeatherTool { - location: String, - } - - let schema = schema::json_schema_for::( - language_model::LanguageModelToolSchemaFormat::JsonSchema, - ) - .unwrap(); - - assert_eq!( - schema, - serde_json::json!({ - "type": "object", - "properties": { - "location": { - "type": "string" - } - }, - "required": ["location"], - "additionalProperties": false - }) - ); - } - - #[gpui::test] - fn test_builtin_tool_schema_compatibility(cx: &mut App) { - settings::init(cx); - AgentSettings::register(cx); - - let client = Client::new( - Arc::new(FakeSystemClock::new()), - FakeHttpClient::with_200_response(), - cx, - ); - language_model::init(client.clone(), cx); - crate::init(client.http_client(), cx); - - for tool in ToolRegistry::global(cx).tools() { - let actual_schema = tool - .input_schema(language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset) - .unwrap(); - let mut expected_schema = actual_schema.clone(); - assistant_tool::adapt_schema_to_format( - &mut expected_schema, - language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset, - ) - .unwrap(); - - let error_message = format!( - "Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\ - Are you using `schema::json_schema_for(format)` to generate the schema?", - tool.name(), - ); - - assert_eq!(actual_schema, expected_schema, "{}", error_message) - } - } -} diff --git a/crates/assistant_tools/src/copy_path_tool.rs b/crates/assistant_tools/src/copy_path_tool.rs deleted file mode 100644 index 572eddcb1079557b464ba29d125aa44929409cc5..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/copy_path_tool.rs +++ /dev/null @@ -1,123 +0,0 @@ -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use gpui::AnyWindowHandle; -use gpui::{App, AppContext, Entity, Task}; -use language_model::LanguageModel; -use language_model::{LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::sync::Arc; -use ui::IconName; -use util::markdown::MarkdownInlineCode; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct CopyPathToolInput { - /// The source path of the file or directory to copy. - /// If a directory is specified, its contents will be copied recursively (like `cp -r`). - /// - /// - /// If the project has the following files: - /// - /// - directory1/a/something.txt - /// - directory2/a/things.txt - /// - directory3/a/other.txt - /// - /// You can copy the first file by providing a source_path of "directory1/a/something.txt" - /// - pub source_path: String, - - /// The destination path where the file or directory should be copied to. - /// - /// - /// To copy "directory1/a/something.txt" to "directory2/b/copy.txt", - /// provide a destination_path of "directory2/b/copy.txt" - /// - pub destination_path: String, -} - -pub struct CopyPathTool; - -impl Tool for CopyPathTool { - fn name(&self) -> String { - "copy_path".into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - true - } - - fn description(&self) -> String { - include_str!("./copy_path_tool/description.md").into() - } - - fn icon(&self) -> IconName { - IconName::ToolCopy - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let src = MarkdownInlineCode(&input.source_path); - let dest = MarkdownInlineCode(&input.destination_path); - format!("Copy {src} to {dest}") - } - Err(_) => "Copy path".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - let copy_task = project.update(cx, |project, cx| { - match project - .find_project_path(&input.source_path, cx) - .and_then(|project_path| project.entry_for_path(&project_path, cx)) - { - Some(entity) => match project.find_project_path(&input.destination_path, cx) { - Some(project_path) => project.copy_entry(entity.id, project_path, cx), - None => Task::ready(Err(anyhow!( - "Destination path {} was outside the project.", - input.destination_path - ))), - }, - None => Task::ready(Err(anyhow!( - "Source path {} was not found in the project.", - input.source_path - ))), - } - }); - - cx.background_spawn(async move { - let _ = copy_task.await.with_context(|| { - format!( - "Copying {} to {}", - input.source_path, input.destination_path - ) - })?; - Ok(format!("Copied {} to {}", input.source_path, input.destination_path).into()) - }) - .into() - } -} diff --git a/crates/assistant_tools/src/copy_path_tool/description.md b/crates/assistant_tools/src/copy_path_tool/description.md deleted file mode 100644 index a5105e6f18c705e93aa9c30b9588f84dd8db542a..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/copy_path_tool/description.md +++ /dev/null @@ -1,6 +0,0 @@ -Copies a file or directory in the project, and returns confirmation that the copy succeeded. -Directory contents will be copied recursively (like `cp -r`). - -This tool should be used when it's desirable to create a copy of a file or directory without modifying the original. -It's much more efficient than doing this by separately reading and then writing the file or directory's contents, -so this tool should be preferred over that approach whenever copying is the goal. diff --git a/crates/assistant_tools/src/create_directory_tool.rs b/crates/assistant_tools/src/create_directory_tool.rs deleted file mode 100644 index 85eea463dc1dfd429dd70ded8c18faf6ee8421c5..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/create_directory_tool.rs +++ /dev/null @@ -1,100 +0,0 @@ -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use gpui::AnyWindowHandle; -use gpui::{App, Entity, Task}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::sync::Arc; -use ui::IconName; -use util::markdown::MarkdownInlineCode; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct CreateDirectoryToolInput { - /// The path of the new directory. - /// - /// - /// If the project has the following structure: - /// - /// - directory1/ - /// - directory2/ - /// - /// You can create a new directory by providing a path of "directory1/new_directory" - /// - pub path: String, -} - -pub struct CreateDirectoryTool; - -impl Tool for CreateDirectoryTool { - fn name(&self) -> String { - "create_directory".into() - } - - fn description(&self) -> String { - include_str!("./create_directory_tool/description.md").into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn icon(&self) -> IconName { - IconName::ToolFolder - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - format!("Create directory {}", MarkdownInlineCode(&input.path)) - } - Err(_) => "Create directory".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - let project_path = match project.read(cx).find_project_path(&input.path, cx) { - Some(project_path) => project_path, - None => { - return Task::ready(Err(anyhow!("Path to create was outside the project"))).into(); - } - }; - let destination_path: Arc = input.path.as_str().into(); - - cx.spawn(async move |cx| { - project - .update(cx, |project, cx| { - project.create_entry(project_path.clone(), true, cx) - })? - .await - .with_context(|| format!("Creating directory {destination_path}"))?; - - Ok(format!("Created directory {destination_path}").into()) - }) - .into() - } -} diff --git a/crates/assistant_tools/src/create_directory_tool/description.md b/crates/assistant_tools/src/create_directory_tool/description.md deleted file mode 100644 index 52056518c23517bf9fd36bf7d41d7e46947b15b6..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/create_directory_tool/description.md +++ /dev/null @@ -1,3 +0,0 @@ -Creates a new directory at the specified path within the project. Returns confirmation that the directory was created. - -This tool creates a directory and all necessary parent directories (similar to `mkdir -p`). It should be used whenever you need to create new directories within the project. diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs deleted file mode 100644 index 7c85f1ed7552931822500f76bb9f3b1b1f47fd0c..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ /dev/null @@ -1,144 +0,0 @@ -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use futures::{SinkExt, StreamExt, channel::mpsc}; -use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::{Project, ProjectPath}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::sync::Arc; -use ui::IconName; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct DeletePathToolInput { - /// The path of the file or directory to delete. - /// - /// - /// If the project has the following files: - /// - /// - directory1/a/something.txt - /// - directory2/a/things.txt - /// - directory3/a/other.txt - /// - /// You can delete the first file by providing a path of "directory1/a/something.txt" - /// - pub path: String, -} - -pub struct DeletePathTool; - -impl Tool for DeletePathTool { - fn name(&self) -> String { - "delete_path".into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - true - } - - fn may_perform_edits(&self) -> bool { - true - } - - fn description(&self) -> String { - include_str!("./delete_path_tool/description.md").into() - } - - fn icon(&self) -> IconName { - IconName::ToolDeleteFile - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => format!("Delete “`{}`”", input.path), - Err(_) => "Delete path".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let path_str = match serde_json::from_value::(input) { - Ok(input) => input.path, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - let Some(project_path) = project.read(cx).find_project_path(&path_str, cx) else { - return Task::ready(Err(anyhow!( - "Couldn't delete {path_str} because that path isn't in this project." - ))) - .into(); - }; - - let Some(worktree) = project - .read(cx) - .worktree_for_id(project_path.worktree_id, cx) - else { - return Task::ready(Err(anyhow!( - "Couldn't delete {path_str} because that path isn't in this project." - ))) - .into(); - }; - - let worktree_snapshot = worktree.read(cx).snapshot(); - let (mut paths_tx, mut paths_rx) = mpsc::channel(256); - cx.background_spawn({ - let project_path = project_path.clone(); - async move { - for entry in - worktree_snapshot.traverse_from_path(true, false, false, &project_path.path) - { - if !entry.path.starts_with(&project_path.path) { - break; - } - paths_tx - .send(ProjectPath { - worktree_id: project_path.worktree_id, - path: entry.path.clone(), - }) - .await?; - } - anyhow::Ok(()) - } - }) - .detach(); - - cx.spawn(async move |cx| { - while let Some(path) = paths_rx.next().await { - if let Ok(buffer) = project - .update(cx, |project, cx| project.open_buffer(path, cx))? - .await - { - action_log.update(cx, |action_log, cx| { - action_log.will_delete_buffer(buffer.clone(), cx) - })?; - } - } - - let deletion_task = project - .update(cx, |project, cx| { - project.delete_file(project_path, false, cx) - })? - .with_context(|| { - format!("Couldn't delete {path_str} because that path isn't in this project.") - })?; - deletion_task - .await - .with_context(|| format!("Deleting {path_str}"))?; - Ok(format!("Deleted {path_str}").into()) - }) - .into() - } -} diff --git a/crates/assistant_tools/src/delete_path_tool/description.md b/crates/assistant_tools/src/delete_path_tool/description.md deleted file mode 100644 index dfd4388bf04cf32038d04cacf169e9ea4bf05c56..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/delete_path_tool/description.md +++ /dev/null @@ -1 +0,0 @@ -Deletes the file or directory (and the directory's contents, recursively) at the specified path in the project, and returns confirmation of the deletion. diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs deleted file mode 100644 index 75bd683512b58d2fdb6c43fc319d266f6609f926..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ /dev/null @@ -1,171 +0,0 @@ -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use gpui::{AnyWindowHandle, App, Entity, Task}; -use language::{DiagnosticSeverity, OffsetRangeExt}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::{fmt::Write, sync::Arc}; -use ui::IconName; -use util::markdown::MarkdownInlineCode; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct DiagnosticsToolInput { - /// The path to get diagnostics for. If not provided, returns a project-wide summary. - /// - /// This path should never be absolute, and the first component - /// of the path should always be a root directory in a project. - /// - /// - /// If the project has the following root directories: - /// - /// - lorem - /// - ipsum - /// - /// If you wanna access diagnostics for `dolor.txt` in `ipsum`, you should use the path `ipsum/dolor.txt`. - /// - #[serde(deserialize_with = "deserialize_path")] - pub path: Option, -} - -fn deserialize_path<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - let opt = Option::::deserialize(deserializer)?; - // The model passes an empty string sometimes - Ok(opt.filter(|s| !s.is_empty())) -} - -pub struct DiagnosticsTool; - -impl Tool for DiagnosticsTool { - fn name(&self) -> String { - "diagnostics".into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn description(&self) -> String { - include_str!("./diagnostics_tool/description.md").into() - } - - fn icon(&self) -> IconName { - IconName::ToolDiagnostics - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - if let Some(path) = serde_json::from_value::(input.clone()) - .ok() - .and_then(|input| match input.path { - Some(path) if !path.is_empty() => Some(path), - _ => None, - }) - { - format!("Check diagnostics for {}", MarkdownInlineCode(&path)) - } else { - "Check project diagnostics".to_string() - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - match serde_json::from_value::(input) - .ok() - .and_then(|input| input.path) - { - Some(path) if !path.is_empty() => { - let Some(project_path) = project.read(cx).find_project_path(&path, cx) else { - return Task::ready(Err(anyhow!("Could not find path {path} in project",))) - .into(); - }; - - let buffer = - project.update(cx, |project, cx| project.open_buffer(project_path, cx)); - - cx.spawn(async move |cx| { - let mut output = String::new(); - let buffer = buffer.await?; - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - - for (_, group) in snapshot.diagnostic_groups(None) { - let entry = &group.entries[group.primary_ix]; - let range = entry.range.to_point(&snapshot); - let severity = match entry.diagnostic.severity { - DiagnosticSeverity::ERROR => "error", - DiagnosticSeverity::WARNING => "warning", - _ => continue, - }; - - writeln!( - output, - "{} at line {}: {}", - severity, - range.start.row + 1, - entry.diagnostic.message - )?; - } - - if output.is_empty() { - Ok("File doesn't have errors or warnings!".to_string().into()) - } else { - Ok(output.into()) - } - }) - .into() - } - _ => { - let project = project.read(cx); - let mut output = String::new(); - let mut has_diagnostics = false; - - for (project_path, _, summary) in project.diagnostic_summaries(true, cx) { - if summary.error_count > 0 || summary.warning_count > 0 { - let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) - else { - continue; - }; - - has_diagnostics = true; - output.push_str(&format!( - "{}: {} error(s), {} warning(s)\n", - worktree.read(cx).absolutize(&project_path.path).display(), - summary.error_count, - summary.warning_count - )); - } - } - - if has_diagnostics { - Task::ready(Ok(output.into())).into() - } else { - Task::ready(Ok("No errors or warnings found in the project." - .to_string() - .into())) - .into() - } - } - } - } -} diff --git a/crates/assistant_tools/src/diagnostics_tool/description.md b/crates/assistant_tools/src/diagnostics_tool/description.md deleted file mode 100644 index 90dc00f1e408c0bd4d79de68833db9d4bafc0d2c..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/diagnostics_tool/description.md +++ /dev/null @@ -1,21 +0,0 @@ -Get errors and warnings for the project or a specific file. - -This tool can be invoked after a series of edits to determine if further edits are necessary, or if the user asks to fix errors or warnings in their codebase. - -When a path is provided, shows all diagnostics for that specific file. -When no path is provided, shows a summary of error and warning counts for all files in the project. - - -To get diagnostics for a specific file: -{ - "path": "src/main.rs" -} - -To get a project-wide diagnostic summary: -{} - - - -- If you think you can fix a diagnostic, make 1-2 attempts and then give up. -- Don't remove code you've generated just because you can't fix an error. The user can help you fix it. - diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs deleted file mode 100644 index 840f34aaae9381882e39f8435242625022dfc26c..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ /dev/null @@ -1,2423 +0,0 @@ -use crate::{ - Templates, - edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat}, - schema::json_schema_for, - ui::{COLLAPSED_LINES, ToolOutputPreview}, -}; -use action_log::ActionLog; -use agent_settings; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ - AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, -}; -use buffer_diff::{BufferDiff, BufferDiffSnapshot}; -use editor::{ - Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey, multibuffer_context_lines, -}; -use futures::StreamExt; -use gpui::{ - Animation, AnimationExt, AnyWindowHandle, App, AppContext, AsyncApp, Entity, Task, - TextStyleRefinement, WeakEntity, pulsating_between, -}; -use indoc::formatdoc; -use language::{ - Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Point, Rope, - TextBuffer, - language_settings::{self, FormatOnSave, SoftWrap}, -}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use markdown::{Markdown, MarkdownElement, MarkdownStyle}; -use paths; -use project::{ - Project, ProjectPath, - lsp_store::{FormatTrigger, LspFormatTarget}, -}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use settings::Settings; -use std::{ - cmp::Reverse, - collections::HashSet, - ffi::OsStr, - ops::Range, - path::{Path, PathBuf}, - sync::Arc, - time::Duration, -}; -use theme::ThemeSettings; -use ui::{CommonAnimationExt, Disclosure, Tooltip, prelude::*}; -use util::{ResultExt, rel_path::RelPath}; -use workspace::Workspace; - -pub struct EditFileTool; - -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -pub struct EditFileToolInput { - /// A one-line, user-friendly markdown description of the edit. This will be - /// shown in the UI and also passed to another model to perform the edit. - /// - /// Be terse, but also descriptive in what you want to achieve with this - /// edit. Avoid generic instructions. - /// - /// NEVER mention the file path in this description. - /// - /// Fix API endpoint URLs - /// Update copyright year in `page_footer` - /// - /// Make sure to include this field before all the others in the input object - /// so that we can display it immediately. - pub display_description: String, - - /// The full path of the file to create or modify in the project. - /// - /// WARNING: When specifying which file path need changing, you MUST - /// start each path with one of the project's root directories. - /// - /// The following examples assume we have two root directories in the project: - /// - /a/b/backend - /// - /c/d/frontend - /// - /// - /// `backend/src/main.rs` - /// - /// Notice how the file path starts with `backend`. Without that, the path - /// would be ambiguous and the call would fail! - /// - /// - /// - /// `frontend/db.js` - /// - pub path: PathBuf, - - /// The mode of operation on the file. Possible values: - /// - 'edit': Make granular edits to an existing file. - /// - 'create': Create a new file if it doesn't exist. - /// - 'overwrite': Replace the entire contents of an existing file. - /// - /// When a file already exists or you just created it, prefer editing - /// it as opposed to recreating it from scratch. - pub mode: EditFileMode, -} - -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -#[serde(rename_all = "lowercase")] -pub enum EditFileMode { - Edit, - Create, - Overwrite, -} - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct EditFileToolOutput { - pub original_path: PathBuf, - pub new_text: String, - pub old_text: Arc, - pub raw_output: Option, -} - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -struct PartialInput { - #[serde(default)] - path: String, - #[serde(default)] - display_description: String, -} - -const DEFAULT_UI_TEXT: &str = "Editing file"; - -impl Tool for EditFileTool { - fn name(&self) -> String { - "edit_file".into() - } - - fn needs_confirmation( - &self, - input: &serde_json::Value, - project: &Entity, - cx: &App, - ) -> bool { - if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions { - return false; - } - - let Ok(input) = serde_json::from_value::(input.clone()) else { - // If it's not valid JSON, it's going to error and confirming won't do anything. - return false; - }; - - // If any path component matches the local settings folder, then this could affect - // the editor in ways beyond the project source, so prompt. - let local_settings_folder = paths::local_settings_folder_name(); - let path = Path::new(&input.path); - if path - .components() - .any(|c| c.as_os_str() == >::as_ref(local_settings_folder)) - { - return true; - } - - // It's also possible that the global config dir is configured to be inside the project, - // so check for that edge case too. - if let Ok(canonical_path) = std::fs::canonicalize(&input.path) - && canonical_path.starts_with(paths::config_dir()) - { - return true; - } - - // Check if path is inside the global config directory - // First check if it's already inside project - if not, try to canonicalize - let project_path = project.read(cx).find_project_path(&input.path, cx); - - // If the path is inside the project, and it's not one of the above edge cases, - // then no confirmation is necessary. Otherwise, confirmation is necessary. - project_path.is_none() - } - - fn may_perform_edits(&self) -> bool { - true - } - - fn description(&self) -> String { - include_str!("edit_file_tool/description.md").to_string() - } - - fn icon(&self) -> IconName { - IconName::ToolPencil - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let path = Path::new(&input.path); - let mut description = input.display_description.clone(); - - // Add context about why confirmation may be needed - let local_settings_folder = paths::local_settings_folder_name(); - if path - .components() - .any(|c| c.as_os_str() == >::as_ref(local_settings_folder)) - { - description.push_str(" (local settings)"); - } else if let Ok(canonical_path) = std::fs::canonicalize(&input.path) - && canonical_path.starts_with(paths::config_dir()) - { - description.push_str(" (global settings)"); - } - - description - } - Err(_) => "Editing file".to_string(), - } - } - - fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { - if let Some(input) = serde_json::from_value::(input.clone()).ok() { - let description = input.display_description.trim(); - if !description.is_empty() { - return description.to_string(); - } - - let path = input.path.trim(); - if !path.is_empty() { - return path.to_string(); - } - } - - DEFAULT_UI_TEXT.to_string() - } - - fn run( - self: Arc, - input: serde_json::Value, - request: Arc, - project: Entity, - action_log: Entity, - model: Arc, - window: Option, - cx: &mut App, - ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - - let project_path = match resolve_path(&input, project.clone(), cx) { - Ok(path) => path, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - - let card = window.and_then(|window| { - window - .update(cx, |_, window, cx| { - cx.new(|cx| { - EditFileToolCard::new(input.path.clone(), project.clone(), window, cx) - }) - }) - .ok() - }); - - let card_clone = card.clone(); - let action_log_clone = action_log.clone(); - let task = cx.spawn(async move |cx: &mut AsyncApp| { - let edit_format = EditFormat::from_model(model.clone())?; - let edit_agent = EditAgent::new( - model, - project.clone(), - action_log_clone, - Templates::new(), - edit_format, - ); - - let buffer = project - .update(cx, |project, cx| { - project.open_buffer(project_path.clone(), cx) - })? - .await?; - - let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let old_text = cx - .background_spawn({ - let old_snapshot = old_snapshot.clone(); - async move { Arc::new(old_snapshot.text()) } - }) - .await; - - if let Some(card) = card_clone.as_ref() { - card.update(cx, |card, cx| card.initialize(buffer.clone(), cx))?; - } - - let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) { - edit_agent.edit( - buffer.clone(), - input.display_description.clone(), - &request, - cx, - ) - } else { - edit_agent.overwrite( - buffer.clone(), - input.display_description.clone(), - &request, - cx, - ) - }; - - let mut hallucinated_old_text = false; - let mut ambiguous_ranges = Vec::new(); - while let Some(event) = events.next().await { - match event { - EditAgentOutputEvent::Edited { .. } => { - if let Some(card) = card_clone.as_ref() { - card.update(cx, |card, cx| card.update_diff(cx))?; - } - } - EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true, - EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges, - EditAgentOutputEvent::ResolvingEditRange(range) => { - if let Some(card) = card_clone.as_ref() { - card.update(cx, |card, cx| card.reveal_range(range, cx))?; - } - } - } - } - let agent_output = output.await?; - - // If format_on_save is enabled, format the buffer - let format_on_save_enabled = buffer - .read_with(cx, |buffer, cx| { - let settings = language_settings::language_settings( - buffer.language().map(|l| l.name()), - buffer.file(), - cx, - ); - !matches!(settings.format_on_save, FormatOnSave::Off) - }) - .unwrap_or(false); - - if format_on_save_enabled { - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); - })?; - let format_task = project.update(cx, |project, cx| { - project.format( - HashSet::from_iter([buffer.clone()]), - LspFormatTarget::Buffers, - false, // Don't push to history since the tool did it. - FormatTrigger::Save, - cx, - ) - })?; - format_task.await.log_err(); - } - - project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))? - .await?; - - // Notify the action log that we've edited the buffer (*after* formatting has completed). - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); - })?; - - let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - let (new_text, diff) = cx - .background_spawn({ - let new_snapshot = new_snapshot.clone(); - let old_text = old_text.clone(); - async move { - let new_text = new_snapshot.text(); - let diff = language::unified_diff(&old_text, &new_text); - - (new_text, diff) - } - }) - .await; - - let output = EditFileToolOutput { - original_path: project_path.path.as_std_path().to_owned(), - new_text, - old_text, - raw_output: Some(agent_output), - }; - - if let Some(card) = card_clone { - card.update(cx, |card, cx| { - card.update_diff(cx); - card.finalize(cx) - }) - .log_err(); - } - - let input_path = input.path.display(); - if diff.is_empty() { - anyhow::ensure!( - !hallucinated_old_text, - formatdoc! {" - Some edits were produced but none of them could be applied. - Read the relevant sections of {input_path} again so that - I can perform the requested edits. - "} - ); - anyhow::ensure!( - ambiguous_ranges.is_empty(), - { - let line_numbers = ambiguous_ranges - .iter() - .map(|range| range.start.to_string()) - .collect::>() - .join(", "); - formatdoc! {" - matches more than one position in the file (lines: {line_numbers}). Read the - relevant sections of {input_path} again and extend so - that I can perform the requested edits. - "} - } - ); - Ok(ToolResultOutput { - content: ToolResultContent::Text("No edits were made.".into()), - output: serde_json::to_value(output).ok(), - }) - } else { - Ok(ToolResultOutput { - content: ToolResultContent::Text(format!( - "Edited {}:\n\n```diff\n{}\n```", - input_path, diff - )), - output: serde_json::to_value(output).ok(), - }) - } - }); - - ToolResult { - output: task, - card: card.map(AnyToolCard::from), - } - } - - fn deserialize_card( - self: Arc, - output: serde_json::Value, - project: Entity, - window: &mut Window, - cx: &mut App, - ) -> Option { - let output = match serde_json::from_value::(output) { - Ok(output) => output, - Err(_) => return None, - }; - - let card = cx.new(|cx| { - EditFileToolCard::new(output.original_path.clone(), project.clone(), window, cx) - }); - - cx.spawn({ - let path: Arc = output.original_path.into(); - let language_registry = project.read(cx).languages().clone(); - let card = card.clone(); - async move |cx| { - let buffer = - build_buffer(output.new_text, path.clone(), &language_registry, cx).await?; - let buffer_diff = - build_buffer_diff(output.old_text.clone(), &buffer, &language_registry, cx) - .await?; - card.update(cx, |card, cx| { - card.multibuffer.update(cx, |multibuffer, cx| { - let snapshot = buffer.read(cx).snapshot(); - let diff = buffer_diff.read(cx); - let diff_hunk_ranges = diff - .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx) - .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot)) - .collect::>(); - - multibuffer.set_excerpts_for_path( - PathKey::for_buffer(&buffer, cx), - buffer, - diff_hunk_ranges, - multibuffer_context_lines(cx), - cx, - ); - multibuffer.add_diff(buffer_diff, cx); - let end = multibuffer.len(cx); - card.total_lines = - Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1); - }); - - cx.notify(); - })?; - anyhow::Ok(()) - } - }) - .detach_and_log_err(cx); - - Some(card.into()) - } -} - -/// Validate that the file path is valid, meaning: -/// -/// - For `edit` and `overwrite`, the path must point to an existing file. -/// - For `create`, the file must not already exist, but it's parent dir must exist. -fn resolve_path( - input: &EditFileToolInput, - project: Entity, - cx: &mut App, -) -> Result { - let project = project.read(cx); - - match input.mode { - EditFileMode::Edit | EditFileMode::Overwrite => { - let path = project - .find_project_path(&input.path, cx) - .context("Can't edit file: path not found")?; - - let entry = project - .entry_for_path(&path, cx) - .context("Can't edit file: path not found")?; - - anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory"); - Ok(path) - } - - EditFileMode::Create => { - if let Some(path) = project.find_project_path(&input.path, cx) { - anyhow::ensure!( - project.entry_for_path(&path, cx).is_none(), - "Can't create file: file already exists" - ); - } - - let parent_path = input - .path - .parent() - .context("Can't create file: incorrect path")?; - - let parent_project_path = project.find_project_path(&parent_path, cx); - - let parent_entry = parent_project_path - .as_ref() - .and_then(|path| project.entry_for_path(path, cx)) - .context("Can't create file: parent directory doesn't exist")?; - - anyhow::ensure!( - parent_entry.is_dir(), - "Can't create file: parent is not a directory" - ); - - let file_name = input - .path - .file_name() - .and_then(|file_name| file_name.to_str()) - .context("Can't create file: invalid filename")?; - - let new_file_path = parent_project_path.map(|parent| ProjectPath { - path: parent.path.join(RelPath::unix(file_name).unwrap()), - ..parent - }); - - new_file_path.context("Can't create file") - } - } -} - -pub struct EditFileToolCard { - path: PathBuf, - editor: Entity, - multibuffer: Entity, - project: Entity, - buffer: Option>, - base_text: Option>, - buffer_diff: Option>, - revealed_ranges: Vec>, - diff_task: Option>>, - preview_expanded: bool, - error_expanded: Option>, - full_height_expanded: bool, - total_lines: Option, -} - -impl EditFileToolCard { - pub fn new(path: PathBuf, project: Entity, window: &mut Window, cx: &mut App) -> Self { - let expand_edit_card = agent_settings::AgentSettings::get_global(cx).expand_edit_card; - let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly)); - - let editor = cx.new(|cx| { - let mut editor = Editor::new( - EditorMode::Full { - scale_ui_elements_with_buffer_font_size: false, - show_active_line_background: false, - sized_by_content: true, - }, - multibuffer.clone(), - Some(project.clone()), - window, - cx, - ); - editor.set_show_gutter(false, cx); - editor.disable_inline_diagnostics(); - editor.disable_expand_excerpt_buttons(cx); - // Keep horizontal scrollbar so user can scroll horizontally if needed - editor.set_show_vertical_scrollbar(false, cx); - editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx); - editor.set_soft_wrap_mode(SoftWrap::None, cx); - editor.scroll_manager.set_forbid_vertical_scroll(true); - editor.set_show_indent_guides(false, cx); - editor.set_read_only(true); - editor.set_show_breakpoints(false, cx); - editor.set_show_code_actions(false, cx); - editor.set_show_git_diff_gutter(false, cx); - editor.set_expand_all_diff_hunks(cx); - editor - }); - Self { - path, - project, - editor, - multibuffer, - buffer: None, - base_text: None, - buffer_diff: None, - revealed_ranges: Vec::new(), - diff_task: None, - preview_expanded: true, - error_expanded: None, - full_height_expanded: expand_edit_card, - total_lines: None, - } - } - - pub fn initialize(&mut self, buffer: Entity, cx: &mut App) { - let buffer_snapshot = buffer.read(cx).snapshot(); - let base_text = buffer_snapshot.text(); - let language_registry = buffer.read(cx).language_registry(); - let text_snapshot = buffer.read(cx).text_snapshot(); - - // Create a buffer diff with the current text as the base - let buffer_diff = cx.new(|cx| { - let mut diff = BufferDiff::new(&text_snapshot, cx); - let _ = diff.set_base_text( - buffer_snapshot.clone(), - language_registry, - text_snapshot, - cx, - ); - diff - }); - - self.buffer = Some(buffer); - self.base_text = Some(base_text.into()); - self.buffer_diff = Some(buffer_diff.clone()); - - // Add the diff to the multibuffer - self.multibuffer - .update(cx, |multibuffer, cx| multibuffer.add_diff(buffer_diff, cx)); - } - - pub fn is_loading(&self) -> bool { - self.total_lines.is_none() - } - - pub fn update_diff(&mut self, cx: &mut Context) { - let Some(buffer) = self.buffer.as_ref() else { - return; - }; - let Some(buffer_diff) = self.buffer_diff.as_ref() else { - return; - }; - - let buffer = buffer.clone(); - let buffer_diff = buffer_diff.clone(); - let base_text = self.base_text.clone(); - self.diff_task = Some(cx.spawn(async move |this, cx| { - let text_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot())?; - let diff_snapshot = BufferDiff::update_diff( - buffer_diff.clone(), - text_snapshot.clone(), - base_text, - false, - false, - None, - None, - cx, - ) - .await?; - buffer_diff.update(cx, |diff, cx| { - diff.set_snapshot(diff_snapshot, &text_snapshot, cx) - })?; - this.update(cx, |this, cx| this.update_visible_ranges(cx)) - })); - } - - pub fn reveal_range(&mut self, range: Range, cx: &mut Context) { - self.revealed_ranges.push(range); - self.update_visible_ranges(cx); - } - - fn update_visible_ranges(&mut self, cx: &mut Context) { - let Some(buffer) = self.buffer.as_ref() else { - return; - }; - - let ranges = self.excerpt_ranges(cx); - self.total_lines = self.multibuffer.update(cx, |multibuffer, cx| { - multibuffer.set_excerpts_for_path( - PathKey::for_buffer(buffer, cx), - buffer.clone(), - ranges, - multibuffer_context_lines(cx), - cx, - ); - let end = multibuffer.len(cx); - Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1) - }); - cx.notify(); - } - - fn excerpt_ranges(&self, cx: &App) -> Vec> { - let Some(buffer) = self.buffer.as_ref() else { - return Vec::new(); - }; - let Some(diff) = self.buffer_diff.as_ref() else { - return Vec::new(); - }; - - let buffer = buffer.read(cx); - let diff = diff.read(cx); - let mut ranges = diff - .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, buffer, cx) - .map(|diff_hunk| diff_hunk.buffer_range.to_point(buffer)) - .collect::>(); - ranges.extend( - self.revealed_ranges - .iter() - .map(|range| range.to_point(buffer)), - ); - ranges.sort_unstable_by_key(|range| (range.start, Reverse(range.end))); - - // Merge adjacent ranges - let mut ranges = ranges.into_iter().peekable(); - let mut merged_ranges = Vec::new(); - while let Some(mut range) = ranges.next() { - while let Some(next_range) = ranges.peek() { - if range.end >= next_range.start { - range.end = range.end.max(next_range.end); - ranges.next(); - } else { - break; - } - } - - merged_ranges.push(range); - } - merged_ranges - } - - pub fn finalize(&mut self, cx: &mut Context) -> Result<()> { - let ranges = self.excerpt_ranges(cx); - let buffer = self.buffer.take().context("card was already finalized")?; - let base_text = self - .base_text - .take() - .context("card was already finalized")?; - let language_registry = self.project.read(cx).languages().clone(); - - // Replace the buffer in the multibuffer with the snapshot - let buffer = cx.new(|cx| { - let language = buffer.read(cx).language().cloned(); - let buffer = TextBuffer::new_normalized( - 0, - cx.entity_id().as_non_zero_u64().into(), - buffer.read(cx).line_ending(), - buffer.read(cx).as_rope().clone(), - ); - let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite); - buffer.set_language(language, cx); - buffer - }); - - let buffer_diff = cx.spawn({ - let buffer = buffer.clone(); - async move |_this, cx| { - build_buffer_diff(base_text, &buffer, &language_registry, cx).await - } - }); - - cx.spawn(async move |this, cx| { - let buffer_diff = buffer_diff.await?; - this.update(cx, |this, cx| { - this.multibuffer.update(cx, |multibuffer, cx| { - let path_key = PathKey::for_buffer(&buffer, cx); - multibuffer.clear(cx); - multibuffer.set_excerpts_for_path( - path_key, - buffer, - ranges, - multibuffer_context_lines(cx), - cx, - ); - multibuffer.add_diff(buffer_diff.clone(), cx); - }); - - cx.notify(); - }) - }) - .detach_and_log_err(cx); - Ok(()) - } -} - -impl ToolCard for EditFileToolCard { - fn render( - &mut self, - status: &ToolUseStatus, - window: &mut Window, - workspace: WeakEntity, - cx: &mut Context, - ) -> impl IntoElement { - let error_message = match status { - ToolUseStatus::Error(err) => Some(err), - _ => None, - }; - - let running_or_pending = match status { - ToolUseStatus::Running | ToolUseStatus::Pending => Some(()), - _ => None, - }; - - let should_show_loading = running_or_pending.is_some() && !self.full_height_expanded; - - let path_label_button = h_flex() - .id(("edit-tool-path-label-button", self.editor.entity_id())) - .w_full() - .max_w_full() - .px_1() - .gap_0p5() - .cursor_pointer() - .rounded_sm() - .opacity(0.8) - .hover(|label| { - label - .opacity(1.) - .bg(cx.theme().colors().element_hover.opacity(0.5)) - }) - .tooltip(Tooltip::text("Jump to File")) - .child( - h_flex() - .child( - Icon::new(IconName::ToolPencil) - .size(IconSize::Small) - .color(Color::Muted), - ) - .child( - div() - .text_size(rems(0.8125)) - .child(self.path.display().to_string()) - .ml_1p5() - .mr_0p5(), - ) - .child( - Icon::new(IconName::ArrowUpRight) - .size(IconSize::Small) - .color(Color::Ignored), - ), - ) - .on_click({ - let path = self.path.clone(); - move |_, window, cx| { - workspace - .update(cx, { - |workspace, cx| { - let Some(project_path) = - workspace.project().read(cx).find_project_path(&path, cx) - else { - return; - }; - let open_task = - workspace.open_path(project_path, None, true, window, cx); - window - .spawn(cx, async move |cx| { - let item = open_task.await?; - if let Some(active_editor) = item.downcast::() { - active_editor - .update_in(cx, |editor, window, cx| { - let snapshot = - editor.buffer().read(cx).snapshot(cx); - let first_hunk = editor - .diff_hunks_in_ranges( - &[editor::Anchor::min() - ..editor::Anchor::max()], - &snapshot, - ) - .next(); - if let Some(first_hunk) = first_hunk { - let first_hunk_start = - first_hunk.multi_buffer_range().start; - editor.change_selections( - Default::default(), - window, - cx, - |selections| { - selections.select_anchor_ranges([ - first_hunk_start - ..first_hunk_start, - ]); - }, - ) - } - }) - .log_err(); - } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } - }) - .ok(); - } - }) - .into_any_element(); - - let codeblock_header_bg = cx - .theme() - .colors() - .element_background - .blend(cx.theme().colors().editor_foreground.opacity(0.025)); - - let codeblock_header = h_flex() - .flex_none() - .p_1() - .gap_1() - .justify_between() - .rounded_t_md() - .when(error_message.is_none(), |header| { - header.bg(codeblock_header_bg) - }) - .child(path_label_button) - .when(should_show_loading, |header| { - header.pr_1p5().child( - Icon::new(IconName::ArrowCircle) - .size(IconSize::XSmall) - .color(Color::Info) - .with_rotate_animation(2), - ) - }) - .when_some(error_message, |header, error_message| { - header.child( - h_flex() - .gap_1() - .child( - Icon::new(IconName::Close) - .size(IconSize::Small) - .color(Color::Error), - ) - .child( - Disclosure::new( - ("edit-file-error-disclosure", self.editor.entity_id()), - self.error_expanded.is_some(), - ) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronDown) - .on_click(cx.listener({ - let error_message = error_message.clone(); - - move |this, _event, _window, cx| { - if this.error_expanded.is_some() { - this.error_expanded.take(); - } else { - this.error_expanded = Some(cx.new(|cx| { - Markdown::new(error_message.clone(), None, None, cx) - })) - } - cx.notify(); - } - })), - ), - ) - }) - .when(error_message.is_none() && !self.is_loading(), |header| { - header.child( - Disclosure::new( - ("edit-file-disclosure", self.editor.entity_id()), - self.preview_expanded, - ) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronDown) - .on_click(cx.listener( - move |this, _event, _window, _cx| { - this.preview_expanded = !this.preview_expanded; - }, - )), - ) - }); - - let (editor, editor_line_height) = self.editor.update(cx, |editor, cx| { - let line_height = editor - .style() - .map(|style| style.text.line_height_in_pixels(window.rem_size())) - .unwrap_or_default(); - - editor.set_text_style_refinement(TextStyleRefinement { - font_size: Some( - TextSize::Small - .rems(cx) - .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx)) - .into(), - ), - ..TextStyleRefinement::default() - }); - let element = editor.render(window, cx); - (element.into_any_element(), line_height) - }); - - let border_color = cx.theme().colors().border.opacity(0.6); - - let waiting_for_diff = { - let styles = [ - ("w_4_5", (0.1, 0.85), 2000), - ("w_1_4", (0.2, 0.75), 2200), - ("w_2_4", (0.15, 0.64), 1900), - ("w_3_5", (0.25, 0.72), 2300), - ("w_2_5", (0.3, 0.56), 1800), - ]; - - let mut container = v_flex() - .p_3() - .gap_1() - .border_t_1() - .rounded_b_md() - .border_color(border_color) - .bg(cx.theme().colors().editor_background); - - for (width_method, pulse_range, duration_ms) in styles.iter() { - let (min_opacity, max_opacity) = *pulse_range; - let placeholder = match *width_method { - "w_4_5" => div().w_3_4(), - "w_1_4" => div().w_1_4(), - "w_2_4" => div().w_2_4(), - "w_3_5" => div().w_3_5(), - "w_2_5" => div().w_2_5(), - _ => div().w_1_2(), - } - .id("loading_div") - .h_1() - .rounded_full() - .bg(cx.theme().colors().element_active) - .with_animation( - "loading_pulsate", - Animation::new(Duration::from_millis(*duration_ms)) - .repeat() - .with_easing(pulsating_between(min_opacity, max_opacity)), - |label, delta| label.opacity(delta), - ); - - container = container.child(placeholder); - } - - container - }; - - v_flex() - .mb_2() - .border_1() - .when(error_message.is_some(), |card| card.border_dashed()) - .border_color(border_color) - .rounded_md() - .overflow_hidden() - .child(codeblock_header) - .when_some(self.error_expanded.as_ref(), |card, error_markdown| { - card.child( - v_flex() - .p_2() - .gap_1() - .border_t_1() - .border_dashed() - .border_color(border_color) - .bg(cx.theme().colors().editor_background) - .rounded_b_md() - .child( - Label::new("Error") - .size(LabelSize::XSmall) - .color(Color::Error), - ) - .child( - div() - .rounded_md() - .text_ui_sm(cx) - .bg(cx.theme().colors().editor_background) - .child(MarkdownElement::new( - error_markdown.clone(), - markdown_style(window, cx), - )), - ), - ) - }) - .when(self.is_loading() && error_message.is_none(), |card| { - card.child(waiting_for_diff) - }) - .when(self.preview_expanded && !self.is_loading(), |card| { - let editor_view = v_flex() - .relative() - .h_full() - .when(!self.full_height_expanded, |editor_container| { - editor_container.max_h(COLLAPSED_LINES as f32 * editor_line_height) - }) - .overflow_hidden() - .border_t_1() - .border_color(border_color) - .bg(cx.theme().colors().editor_background) - .child(editor); - - card.child( - ToolOutputPreview::new(editor_view.into_any_element(), self.editor.entity_id()) - .with_total_lines(self.total_lines.unwrap_or(0) as usize) - .toggle_state(self.full_height_expanded) - .with_collapsed_fade() - .on_toggle({ - let this = cx.entity().downgrade(); - move |is_expanded, _window, cx| { - if let Some(this) = this.upgrade() { - this.update(cx, |this, _cx| { - this.full_height_expanded = is_expanded; - }); - } - } - }), - ) - }) - } -} - -fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle { - let theme_settings = ThemeSettings::get_global(cx); - let ui_font_size = TextSize::Default.rems(cx); - let mut text_style = window.text_style(); - - text_style.refine(&TextStyleRefinement { - font_family: Some(theme_settings.ui_font.family.clone()), - font_fallbacks: theme_settings.ui_font.fallbacks.clone(), - font_features: Some(theme_settings.ui_font.features.clone()), - font_size: Some(ui_font_size.into()), - color: Some(cx.theme().colors().text), - ..Default::default() - }); - - MarkdownStyle { - base_text_style: text_style.clone(), - selection_background_color: cx.theme().colors().element_selection_background, - ..Default::default() - } -} - -async fn build_buffer( - mut text: String, - path: Arc, - language_registry: &Arc, - cx: &mut AsyncApp, -) -> Result> { - let line_ending = LineEnding::detect(&text); - LineEnding::normalize(&mut text); - let text = Rope::from(text); - let language = cx - .update(|_cx| language_registry.load_language_for_file_path(&path))? - .await - .ok(); - let buffer = cx.new(|cx| { - let buffer = TextBuffer::new_normalized( - 0, - cx.entity_id().as_non_zero_u64().into(), - line_ending, - text, - ); - let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite); - buffer.set_language(language, cx); - buffer - })?; - Ok(buffer) -} - -async fn build_buffer_diff( - old_text: Arc, - buffer: &Entity, - language_registry: &Arc, - cx: &mut AsyncApp, -) -> Result> { - let buffer = cx.update(|cx| buffer.read(cx).snapshot())?; - - let old_text_rope = cx - .background_spawn({ - let old_text = old_text.clone(); - async move { Rope::from(old_text.as_str()) } - }) - .await; - let base_buffer = cx - .update(|cx| { - Buffer::build_snapshot( - old_text_rope, - buffer.language().cloned(), - Some(language_registry.clone()), - cx, - ) - })? - .await; - - let diff_snapshot = cx - .update(|cx| { - BufferDiffSnapshot::new_with_base_buffer( - buffer.text.clone(), - Some(old_text), - base_buffer, - cx, - ) - })? - .await; - - let secondary_diff = cx.new(|cx| { - let mut diff = BufferDiff::new(&buffer, cx); - diff.set_snapshot(diff_snapshot.clone(), &buffer, cx); - diff - })?; - - cx.new(|cx| { - let mut diff = BufferDiff::new(&buffer.text, cx); - diff.set_snapshot(diff_snapshot, &buffer, cx); - diff.set_secondary_diff(secondary_diff); - diff - }) -} - -#[cfg(test)] -mod tests { - use super::*; - use ::fs::Fs; - use client::TelemetrySettings; - use gpui::{TestAppContext, UpdateGlobal}; - use language_model::fake_provider::FakeLanguageModel; - use serde_json::json; - use settings::SettingsStore; - use std::fs; - use util::{path, rel_path::rel_path}; - - #[gpui::test] - async fn test_edit_nonexistent_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let result = cx - .update(|cx| { - let input = serde_json::to_value(EditFileToolInput { - display_description: "Some edit".into(), - path: "root/nonexistent_file.txt".into(), - mode: EditFileMode::Edit, - }) - .unwrap(); - Arc::new(EditFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log, - model, - None, - cx, - ) - .output - }) - .await; - assert_eq!( - result.unwrap_err().to_string(), - "Can't edit file: path not found" - ); - } - - #[gpui::test] - async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) { - let mode = &EditFileMode::Create; - - let result = test_resolve_path(mode, "root/new.txt", cx); - assert_resolved_path_eq(result.await, "new.txt"); - - let result = test_resolve_path(mode, "new.txt", cx); - assert_resolved_path_eq(result.await, "new.txt"); - - let result = test_resolve_path(mode, "dir/new.txt", cx); - assert_resolved_path_eq(result.await, "dir/new.txt"); - - let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx); - assert_eq!( - result.await.unwrap_err().to_string(), - "Can't create file: file already exists" - ); - - let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx); - assert_eq!( - result.await.unwrap_err().to_string(), - "Can't create file: parent directory doesn't exist" - ); - } - - #[gpui::test] - async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) { - let mode = &EditFileMode::Edit; - - let path_with_root = "root/dir/subdir/existing.txt"; - let path_without_root = "dir/subdir/existing.txt"; - let result = test_resolve_path(mode, path_with_root, cx); - assert_resolved_path_eq(result.await, path_without_root); - - let result = test_resolve_path(mode, path_without_root, cx); - assert_resolved_path_eq(result.await, path_without_root); - - let result = test_resolve_path(mode, "root/nonexistent.txt", cx); - assert_eq!( - result.await.unwrap_err().to_string(), - "Can't edit file: path not found" - ); - - let result = test_resolve_path(mode, "root/dir", cx); - assert_eq!( - result.await.unwrap_err().to_string(), - "Can't edit file: path is a directory" - ); - } - - async fn test_resolve_path( - mode: &EditFileMode, - path: &str, - cx: &mut TestAppContext, - ) -> anyhow::Result { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "dir": { - "subdir": { - "existing.txt": "hello" - } - } - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - - let input = EditFileToolInput { - display_description: "Some edit".into(), - path: path.into(), - mode: mode.clone(), - }; - - cx.update(|cx| resolve_path(&input, project, cx)) - } - - #[track_caller] - fn assert_resolved_path_eq(path: anyhow::Result, expected: &str) { - let actual = path.expect("Should return valid path").path; - assert_eq!(actual.as_ref(), rel_path(expected)); - } - - #[test] - fn still_streaming_ui_text_with_path() { - let input = json!({ - "path": "src/main.rs", - "display_description": "", - "old_string": "old code", - "new_string": "new code" - }); - - assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs"); - } - - #[test] - fn still_streaming_ui_text_with_description() { - let input = json!({ - "path": "", - "display_description": "Fix error handling", - "old_string": "old code", - "new_string": "new code" - }); - - assert_eq!( - EditFileTool.still_streaming_ui_text(&input), - "Fix error handling", - ); - } - - #[test] - fn still_streaming_ui_text_with_path_and_description() { - let input = json!({ - "path": "src/main.rs", - "display_description": "Fix error handling", - "old_string": "old code", - "new_string": "new code" - }); - - assert_eq!( - EditFileTool.still_streaming_ui_text(&input), - "Fix error handling", - ); - } - - #[test] - fn still_streaming_ui_text_no_path_or_description() { - let input = json!({ - "path": "", - "display_description": "", - "old_string": "old code", - "new_string": "new code" - }); - - assert_eq!( - EditFileTool.still_streaming_ui_text(&input), - DEFAULT_UI_TEXT, - ); - } - - #[test] - fn still_streaming_ui_text_with_null() { - let input = serde_json::Value::Null; - - assert_eq!( - EditFileTool.still_streaming_ui_text(&input), - DEFAULT_UI_TEXT, - ); - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - TelemetrySettings::register(cx); - agent_settings::AgentSettings::register(cx); - Project::init_settings(cx); - }); - } - - fn init_test_with_config(cx: &mut TestAppContext, data_dir: &Path) { - cx.update(|cx| { - paths::set_custom_data_dir(data_dir.to_str().unwrap()); - // Set custom data directory (config will be under data_dir/config) - - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - TelemetrySettings::register(cx); - agent_settings::AgentSettings::register(cx); - Project::init_settings(cx); - }); - } - - #[gpui::test] - async fn test_format_on_save(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"src": {}})).await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - - // Set up a Rust language with LSP formatting support - let rust_language = Arc::new(language::Language::new( - language::LanguageConfig { - name: "Rust".into(), - matcher: language::LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - None, - )); - - // Register the language and fake LSP - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - language_registry.add(rust_language); - - let mut fake_language_servers = language_registry.register_fake_lsp( - "Rust", - language::FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - document_formatting_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() - }, - ); - - // Create the file - fs.save( - path!("/root/src/main.rs").as_ref(), - &"initial content".into(), - language::LineEnding::Unix, - ) - .await - .unwrap(); - - // Open the buffer to trigger LSP initialization - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/root/src/main.rs"), cx) - }) - .await - .unwrap(); - - // Register the buffer with language servers - let _handle = project.update(cx, |project, cx| { - project.register_buffer_with_language_servers(&buffer, cx) - }); - - const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n"; - const FORMATTED_CONTENT: &str = - "This file was formatted by the fake formatter in the test.\n"; - - // Get the fake language server and set up formatting handler - let fake_language_server = fake_language_servers.next().await.unwrap(); - fake_language_server.set_request_handler::({ - |_, _| async move { - Ok(Some(vec![lsp::TextEdit { - range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)), - new_text: FORMATTED_CONTENT.to_string(), - }])) - } - }); - - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - - // First, test with format_on_save enabled - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On); - settings.project.all_languages.defaults.formatter = - Some(language::language_settings::FormatterList::default()); - }); - }); - }); - - // Have the model stream unformatted content - let edit_result = { - let edit_task = cx.update(|cx| { - let input = serde_json::to_value(EditFileToolInput { - display_description: "Create main function".into(), - path: "root/src/main.rs".into(), - mode: EditFileMode::Overwrite, - }) - .unwrap(); - Arc::new(EditFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }); - - // Stream the unformatted content - cx.executor().run_until_parked(); - model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); - model.end_last_completion_stream(); - - edit_task.await - }; - assert!(edit_result.is_ok()); - - // Wait for any async operations (e.g. formatting) to complete - cx.executor().run_until_parked(); - - // Read the file to verify it was formatted automatically - let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); - assert_eq!( - // Ignore carriage returns on Windows - new_content.replace("\r\n", "\n"), - FORMATTED_CONTENT, - "Code should be formatted when format_on_save is enabled" - ); - - let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count()); - - assert_eq!( - stale_buffer_count, 0, - "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \ - This causes the agent to think the file was modified externally when it was just formatted.", - stale_buffer_count - ); - - // Next, test with format_on_save disabled - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings.project.all_languages.defaults.format_on_save = - Some(FormatOnSave::Off); - }); - }); - }); - - // Stream unformatted edits again - let edit_result = { - let edit_task = cx.update(|cx| { - let input = serde_json::to_value(EditFileToolInput { - display_description: "Update main function".into(), - path: "root/src/main.rs".into(), - mode: EditFileMode::Overwrite, - }) - .unwrap(); - Arc::new(EditFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }); - - // Stream the unformatted content - cx.executor().run_until_parked(); - model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string()); - model.end_last_completion_stream(); - - edit_task.await - }; - assert!(edit_result.is_ok()); - - // Wait for any async operations (e.g. formatting) to complete - cx.executor().run_until_parked(); - - // Verify the file was not formatted - let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); - assert_eq!( - // Ignore carriage returns on Windows - new_content.replace("\r\n", "\n"), - UNFORMATTED_CONTENT, - "Code should not be formatted when format_on_save is disabled" - ); - } - - #[gpui::test] - async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"src": {}})).await; - - // Create a simple file with trailing whitespace - fs.save( - path!("/root/src/main.rs").as_ref(), - &"initial content".into(), - language::LineEnding::Unix, - ) - .await - .unwrap(); - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - - // First, test with remove_trailing_whitespace_on_save enabled - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings - .project - .all_languages - .defaults - .remove_trailing_whitespace_on_save = Some(true); - }); - }); - }); - - const CONTENT_WITH_TRAILING_WHITESPACE: &str = - "fn main() { \n println!(\"Hello!\"); \n}\n"; - - // Have the model stream content that contains trailing whitespace - let edit_result = { - let edit_task = cx.update(|cx| { - let input = serde_json::to_value(EditFileToolInput { - display_description: "Create main function".into(), - path: "root/src/main.rs".into(), - mode: EditFileMode::Overwrite, - }) - .unwrap(); - Arc::new(EditFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }); - - // Stream the content with trailing whitespace - cx.executor().run_until_parked(); - model.send_last_completion_stream_text_chunk( - CONTENT_WITH_TRAILING_WHITESPACE.to_string(), - ); - model.end_last_completion_stream(); - - edit_task.await - }; - assert!(edit_result.is_ok()); - - // Wait for any async operations (e.g. formatting) to complete - cx.executor().run_until_parked(); - - // Read the file to verify trailing whitespace was removed automatically - assert_eq!( - // Ignore carriage returns on Windows - fs.load(path!("/root/src/main.rs").as_ref()) - .await - .unwrap() - .replace("\r\n", "\n"), - "fn main() {\n println!(\"Hello!\");\n}\n", - "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled" - ); - - // Next, test with remove_trailing_whitespace_on_save disabled - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings - .project - .all_languages - .defaults - .remove_trailing_whitespace_on_save = Some(false); - }); - }); - }); - - // Stream edits again with trailing whitespace - let edit_result = { - let edit_task = cx.update(|cx| { - let input = serde_json::to_value(EditFileToolInput { - display_description: "Update main function".into(), - path: "root/src/main.rs".into(), - mode: EditFileMode::Overwrite, - }) - .unwrap(); - Arc::new(EditFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }); - - // Stream the content with trailing whitespace - cx.executor().run_until_parked(); - model.send_last_completion_stream_text_chunk( - CONTENT_WITH_TRAILING_WHITESPACE.to_string(), - ); - model.end_last_completion_stream(); - - edit_task.await - }; - assert!(edit_result.is_ok()); - - // Wait for any async operations (e.g. formatting) to complete - cx.executor().run_until_parked(); - - // Verify the file still has trailing whitespace - // Read the file again - it should still have trailing whitespace - let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); - assert_eq!( - // Ignore carriage returns on Windows - final_content.replace("\r\n", "\n"), - CONTENT_WITH_TRAILING_WHITESPACE, - "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" - ); - } - - #[gpui::test] - async fn test_needs_confirmation(cx: &mut TestAppContext) { - init_test(cx); - let tool = Arc::new(EditFileTool); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({})).await; - - // Test 1: Path with .zed component should require confirmation - let input_with_zed = json!({ - "display_description": "Edit settings", - "path": ".zed/settings.json", - "mode": "edit" - }); - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - cx.update(|cx| { - assert!( - tool.needs_confirmation(&input_with_zed, &project, cx), - "Path with .zed component should require confirmation" - ); - }); - - // Test 2: Absolute path should require confirmation - let input_absolute = json!({ - "display_description": "Edit file", - "path": "/etc/hosts", - "mode": "edit" - }); - cx.update(|cx| { - assert!( - tool.needs_confirmation(&input_absolute, &project, cx), - "Absolute path should require confirmation" - ); - }); - - // Test 3: Relative path without .zed should not require confirmation - let input_relative = json!({ - "display_description": "Edit file", - "path": "root/src/main.rs", - "mode": "edit" - }); - cx.update(|cx| { - assert!( - !tool.needs_confirmation(&input_relative, &project, cx), - "Relative path without .zed should not require confirmation" - ); - }); - - // Test 4: Path with .zed in the middle should require confirmation - let input_zed_middle = json!({ - "display_description": "Edit settings", - "path": "root/.zed/tasks.json", - "mode": "edit" - }); - cx.update(|cx| { - assert!( - tool.needs_confirmation(&input_zed_middle, &project, cx), - "Path with .zed in any component should require confirmation" - ); - }); - - // Test 5: When always_allow_tool_actions is enabled, no confirmation needed - cx.update(|cx| { - let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); - settings.always_allow_tool_actions = true; - agent_settings::AgentSettings::override_global(settings, cx); - - assert!( - !tool.needs_confirmation(&input_with_zed, &project, cx), - "When always_allow_tool_actions is true, no confirmation should be needed" - ); - assert!( - !tool.needs_confirmation(&input_absolute, &project, cx), - "When always_allow_tool_actions is true, no confirmation should be needed for absolute paths" - ); - }); - } - - #[gpui::test] - async fn test_ui_text_shows_correct_context(cx: &mut TestAppContext) { - // Set up a custom config directory for testing - let temp_dir = tempfile::tempdir().unwrap(); - init_test_with_config(cx, temp_dir.path()); - - let tool = Arc::new(EditFileTool); - - // Test ui_text shows context for various paths - let test_cases = vec![ - ( - json!({ - "display_description": "Update config", - "path": ".zed/settings.json", - "mode": "edit" - }), - "Update config (local settings)", - ".zed path should show local settings context", - ), - ( - json!({ - "display_description": "Fix bug", - "path": "src/.zed/local.json", - "mode": "edit" - }), - "Fix bug (local settings)", - "Nested .zed path should show local settings context", - ), - ( - json!({ - "display_description": "Update readme", - "path": "README.md", - "mode": "edit" - }), - "Update readme", - "Normal path should not show additional context", - ), - ( - json!({ - "display_description": "Edit config", - "path": "config.zed", - "mode": "edit" - }), - "Edit config", - ".zed as extension should not show context", - ), - ]; - - for (input, expected_text, description) in test_cases { - cx.update(|_cx| { - let ui_text = tool.ui_text(&input); - assert_eq!(ui_text, expected_text, "Failed for case: {}", description); - }); - } - } - - #[gpui::test] - async fn test_needs_confirmation_outside_project(cx: &mut TestAppContext) { - init_test(cx); - let tool = Arc::new(EditFileTool); - let fs = project::FakeFs::new(cx.executor()); - - // Create a project in /project directory - fs.insert_tree("/project", json!({})).await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - - // Test file outside project requires confirmation - let input_outside = json!({ - "display_description": "Edit file", - "path": "/outside/file.txt", - "mode": "edit" - }); - cx.update(|cx| { - assert!( - tool.needs_confirmation(&input_outside, &project, cx), - "File outside project should require confirmation" - ); - }); - - // Test file inside project doesn't require confirmation - let input_inside = json!({ - "display_description": "Edit file", - "path": "project/file.txt", - "mode": "edit" - }); - cx.update(|cx| { - assert!( - !tool.needs_confirmation(&input_inside, &project, cx), - "File inside project should not require confirmation" - ); - }); - } - - #[gpui::test] - async fn test_needs_confirmation_config_paths(cx: &mut TestAppContext) { - // Set up a custom data directory for testing - let temp_dir = tempfile::tempdir().unwrap(); - init_test_with_config(cx, temp_dir.path()); - - let tool = Arc::new(EditFileTool); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/home/user/myproject", json!({})).await; - let project = Project::test(fs.clone(), [path!("/home/user/myproject").as_ref()], cx).await; - - // Get the actual local settings folder name - let local_settings_folder = paths::local_settings_folder_name(); - - // Test various config path patterns - let test_cases = vec![ - ( - format!("{local_settings_folder}/settings.json"), - true, - "Top-level local settings file".to_string(), - ), - ( - format!("myproject/{local_settings_folder}/settings.json"), - true, - "Local settings in project path".to_string(), - ), - ( - format!("src/{local_settings_folder}/config.toml"), - true, - "Local settings in subdirectory".to_string(), - ), - ( - ".zed.backup/file.txt".to_string(), - true, - ".zed.backup is outside project".to_string(), - ), - ( - "my.zed/file.txt".to_string(), - true, - "my.zed is outside project".to_string(), - ), - ( - "myproject/src/file.zed".to_string(), - false, - ".zed as file extension".to_string(), - ), - ( - "myproject/normal/path/file.rs".to_string(), - false, - "Normal file without config paths".to_string(), - ), - ]; - - for (path, should_confirm, description) in test_cases { - let input = json!({ - "display_description": "Edit file", - "path": path, - "mode": "edit" - }); - cx.update(|cx| { - assert_eq!( - tool.needs_confirmation(&input, &project, cx), - should_confirm, - "Failed for case: {} - path: {}", - description, - path - ); - }); - } - } - - #[gpui::test] - async fn test_needs_confirmation_global_config(cx: &mut TestAppContext) { - // Set up a custom data directory for testing - let temp_dir = tempfile::tempdir().unwrap(); - init_test_with_config(cx, temp_dir.path()); - - let tool = Arc::new(EditFileTool); - let fs = project::FakeFs::new(cx.executor()); - - // Create test files in the global config directory - let global_config_dir = paths::config_dir(); - fs::create_dir_all(&global_config_dir).unwrap(); - let global_settings_path = global_config_dir.join("settings.json"); - fs::write(&global_settings_path, "{}").unwrap(); - - fs.insert_tree("/project", json!({})).await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - - // Test global config paths - let test_cases = vec![ - ( - global_settings_path.to_str().unwrap().to_string(), - true, - "Global settings file should require confirmation", - ), - ( - global_config_dir - .join("keymap.json") - .to_str() - .unwrap() - .to_string(), - true, - "Global keymap file should require confirmation", - ), - ( - "project/normal_file.rs".to_string(), - false, - "Normal project file should not require confirmation", - ), - ]; - - for (path, should_confirm, description) in test_cases { - let input = json!({ - "display_description": "Edit file", - "path": path, - "mode": "edit" - }); - cx.update(|cx| { - assert_eq!( - tool.needs_confirmation(&input, &project, cx), - should_confirm, - "Failed for case: {}", - description - ); - }); - } - } - - #[gpui::test] - async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { - init_test(cx); - let tool = Arc::new(EditFileTool); - let fs = project::FakeFs::new(cx.executor()); - - // Create multiple worktree directories - fs.insert_tree( - "/workspace/frontend", - json!({ - "src": { - "main.js": "console.log('frontend');" - } - }), - ) - .await; - fs.insert_tree( - "/workspace/backend", - json!({ - "src": { - "main.rs": "fn main() {}" - } - }), - ) - .await; - fs.insert_tree( - "/workspace/shared", - json!({ - ".zed": { - "settings.json": "{}" - } - }), - ) - .await; - - // Create project with multiple worktrees - let project = Project::test( - fs.clone(), - [ - path!("/workspace/frontend").as_ref(), - path!("/workspace/backend").as_ref(), - path!("/workspace/shared").as_ref(), - ], - cx, - ) - .await; - - // Test files in different worktrees - let test_cases = vec![ - ("frontend/src/main.js", false, "File in first worktree"), - ("backend/src/main.rs", false, "File in second worktree"), - ( - "shared/.zed/settings.json", - true, - ".zed file in third worktree", - ), - ("/etc/hosts", true, "Absolute path outside all worktrees"), - ( - "../outside/file.txt", - true, - "Relative path outside worktrees", - ), - ]; - - for (path, should_confirm, description) in test_cases { - let input = json!({ - "display_description": "Edit file", - "path": path, - "mode": "edit" - }); - cx.update(|cx| { - assert_eq!( - tool.needs_confirmation(&input, &project, cx), - should_confirm, - "Failed for case: {} - path: {}", - description, - path - ); - }); - } - } - - #[gpui::test] - async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) { - init_test(cx); - let tool = Arc::new(EditFileTool); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/project", - json!({ - ".zed": { - "settings.json": "{}" - }, - "src": { - ".zed": { - "local.json": "{}" - } - } - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - - // Test edge cases - let test_cases = vec![ - // Empty path - find_project_path returns Some for empty paths - ("", false, "Empty path is treated as project root"), - // Root directory - ("/", true, "Root directory should be outside project"), - ("project/../other", true, "Path with .. is outside project"), - ( - "project/./src/file.rs", - false, - "Path with . should work normally", - ), - // Windows-style paths (if on Windows) - #[cfg(target_os = "windows")] - ("C:\\Windows\\System32\\hosts", true, "Windows system path"), - #[cfg(target_os = "windows")] - ("project\\src\\main.rs", false, "Windows-style project path"), - ]; - - for (path, should_confirm, description) in test_cases { - let input = json!({ - "display_description": "Edit file", - "path": path, - "mode": "edit" - }); - cx.update(|cx| { - assert_eq!( - tool.needs_confirmation(&input, &project, cx), - should_confirm, - "Failed for case: {} - path: {}", - description, - path - ); - }); - } - } - - #[gpui::test] - async fn test_ui_text_with_all_path_types(cx: &mut TestAppContext) { - init_test(cx); - let tool = Arc::new(EditFileTool); - - // Test UI text for various scenarios - let test_cases = vec![ - ( - json!({ - "display_description": "Update config", - "path": ".zed/settings.json", - "mode": "edit" - }), - "Update config (local settings)", - ".zed path should show local settings context", - ), - ( - json!({ - "display_description": "Fix bug", - "path": "src/.zed/local.json", - "mode": "edit" - }), - "Fix bug (local settings)", - "Nested .zed path should show local settings context", - ), - ( - json!({ - "display_description": "Update readme", - "path": "README.md", - "mode": "edit" - }), - "Update readme", - "Normal path should not show additional context", - ), - ( - json!({ - "display_description": "Edit config", - "path": "config.zed", - "mode": "edit" - }), - "Edit config", - ".zed as extension should not show context", - ), - ]; - - for (input, expected_text, description) in test_cases { - cx.update(|_cx| { - let ui_text = tool.ui_text(&input); - assert_eq!(ui_text, expected_text, "Failed for case: {}", description); - }); - } - } - - #[gpui::test] - async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) { - init_test(cx); - let tool = Arc::new(EditFileTool); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/project", - json!({ - "existing.txt": "content", - ".zed": { - "settings.json": "{}" - } - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - - // Test different EditFileMode values - let modes = vec![ - EditFileMode::Edit, - EditFileMode::Create, - EditFileMode::Overwrite, - ]; - - for mode in modes { - // Test .zed path with different modes - let input_zed = json!({ - "display_description": "Edit settings", - "path": "project/.zed/settings.json", - "mode": mode - }); - cx.update(|cx| { - assert!( - tool.needs_confirmation(&input_zed, &project, cx), - ".zed path should require confirmation regardless of mode: {:?}", - mode - ); - }); - - // Test outside path with different modes - let input_outside = json!({ - "display_description": "Edit file", - "path": "/outside/file.txt", - "mode": mode - }); - cx.update(|cx| { - assert!( - tool.needs_confirmation(&input_outside, &project, cx), - "Outside path should require confirmation regardless of mode: {:?}", - mode - ); - }); - - // Test normal path with different modes - let input_normal = json!({ - "display_description": "Edit file", - "path": "project/normal.txt", - "mode": mode - }); - cx.update(|cx| { - assert!( - !tool.needs_confirmation(&input_normal, &project, cx), - "Normal path should not require confirmation regardless of mode: {:?}", - mode - ); - }); - } - } - - #[gpui::test] - async fn test_always_allow_tool_actions_bypasses_all_checks(cx: &mut TestAppContext) { - // Set up with custom directories for deterministic testing - let temp_dir = tempfile::tempdir().unwrap(); - init_test_with_config(cx, temp_dir.path()); - - let tool = Arc::new(EditFileTool); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/project", json!({})).await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - - // Enable always_allow_tool_actions - cx.update(|cx| { - let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); - settings.always_allow_tool_actions = true; - agent_settings::AgentSettings::override_global(settings, cx); - }); - - // Test that all paths that normally require confirmation are bypassed - let global_settings_path = paths::config_dir().join("settings.json"); - fs::create_dir_all(paths::config_dir()).unwrap(); - fs::write(&global_settings_path, "{}").unwrap(); - - let test_cases = vec![ - ".zed/settings.json", - "project/.zed/config.toml", - global_settings_path.to_str().unwrap(), - "/etc/hosts", - "/absolute/path/file.txt", - "../outside/project.txt", - ]; - - for path in test_cases { - let input = json!({ - "display_description": "Edit file", - "path": path, - "mode": "edit" - }); - cx.update(|cx| { - assert!( - !tool.needs_confirmation(&input, &project, cx), - "Path {} should not require confirmation when always_allow_tool_actions is true", - path - ); - }); - } - - // Disable always_allow_tool_actions and verify confirmation is required again - cx.update(|cx| { - let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); - settings.always_allow_tool_actions = false; - agent_settings::AgentSettings::override_global(settings, cx); - }); - - // Verify .zed path requires confirmation again - let input = json!({ - "display_description": "Edit file", - "path": ".zed/settings.json", - "mode": "edit" - }); - cx.update(|cx| { - assert!( - tool.needs_confirmation(&input, &project, cx), - ".zed path should require confirmation when always_allow_tool_actions is false" - ); - }); - } -} diff --git a/crates/assistant_tools/src/edit_file_tool/description.md b/crates/assistant_tools/src/edit_file_tool/description.md deleted file mode 100644 index 27f8e49dd626a2d1a5266b90413a3a5f8e02e6d8..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/edit_file_tool/description.md +++ /dev/null @@ -1,8 +0,0 @@ -This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. - -Before using this tool: - -1. Use the `read_file` tool to understand the file's contents and context - -2. Verify the directory path is correct (only applicable when creating new files): - - Use the `list_directory` tool to verify the parent directory exists and is the correct location diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs deleted file mode 100644 index cc22c9fc09f73914720c4b639f8d273207d7ca53..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/fetch_tool.rs +++ /dev/null @@ -1,178 +0,0 @@ -use std::rc::Rc; -use std::sync::Arc; -use std::{borrow::Cow, cell::RefCell}; - -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Context as _, Result, anyhow, bail}; -use assistant_tool::{Tool, ToolResult}; -use futures::AsyncReadExt as _; -use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task}; -use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown}; -use http_client::{AsyncBody, HttpClientWithUrl}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use ui::IconName; -use util::markdown::MarkdownEscaped; - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] -enum ContentType { - Html, - Plaintext, - Json, -} - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct FetchToolInput { - /// The URL to fetch. - url: String, -} - -pub struct FetchTool { - http_client: Arc, -} - -impl FetchTool { - pub fn new(http_client: Arc) -> Self { - Self { http_client } - } - - async fn build_message(http_client: Arc, url: &str) -> Result { - let url = if !url.starts_with("https://") && !url.starts_with("http://") { - Cow::Owned(format!("https://{url}")) - } else { - Cow::Borrowed(url) - }; - - let mut response = http_client.get(&url, AsyncBody::default(), true).await?; - - let mut body = Vec::new(); - response - .body_mut() - .read_to_end(&mut body) - .await - .context("error reading response body")?; - - if response.status().is_client_error() { - let text = String::from_utf8_lossy(body.as_slice()); - bail!( - "status error {}, response: {text:?}", - response.status().as_u16() - ); - } - - let Some(content_type) = response.headers().get("content-type") else { - bail!("missing Content-Type header"); - }; - let content_type = content_type - .to_str() - .context("invalid Content-Type header")?; - let content_type = match content_type { - "text/html" | "application/xhtml+xml" => ContentType::Html, - "application/json" => ContentType::Json, - _ => ContentType::Plaintext, - }; - - match content_type { - ContentType::Html => { - let mut handlers: Vec = vec![ - Rc::new(RefCell::new(markdown::WebpageChromeRemover)), - Rc::new(RefCell::new(markdown::ParagraphHandler)), - Rc::new(RefCell::new(markdown::HeadingHandler)), - Rc::new(RefCell::new(markdown::ListHandler)), - Rc::new(RefCell::new(markdown::TableHandler::new())), - Rc::new(RefCell::new(markdown::StyledTextHandler)), - ]; - if url.contains("wikipedia.org") { - use html_to_markdown::structure::wikipedia; - - handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaChromeRemover))); - handlers.push(Rc::new(RefCell::new(wikipedia::WikipediaInfoboxHandler))); - handlers.push(Rc::new( - RefCell::new(wikipedia::WikipediaCodeHandler::new()), - )); - } else { - handlers.push(Rc::new(RefCell::new(markdown::CodeHandler))); - } - - convert_html_to_markdown(&body[..], &mut handlers) - } - ContentType::Plaintext => Ok(std::str::from_utf8(&body)?.to_owned()), - ContentType::Json => { - let json: serde_json::Value = serde_json::from_slice(&body)?; - - Ok(format!( - "```json\n{}\n```", - serde_json::to_string_pretty(&json)? - )) - } - } - } -} - -impl Tool for FetchTool { - fn name(&self) -> String { - "fetch".to_string() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - true - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn description(&self) -> String { - include_str!("./fetch_tool/description.md").to_string() - } - - fn icon(&self) -> IconName { - IconName::ToolWeb - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)), - Err(_) => "Fetch URL".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - _project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - - let text = cx.background_spawn({ - let http_client = self.http_client.clone(); - async move { Self::build_message(http_client, &input.url).await } - }); - - cx.foreground_executor() - .spawn(async move { - let text = text.await?; - if text.trim().is_empty() { - bail!("no textual content found"); - } - - Ok(text.into()) - }) - .into() - } -} diff --git a/crates/assistant_tools/src/fetch_tool/description.md b/crates/assistant_tools/src/fetch_tool/description.md deleted file mode 100644 index 007ba6c60864c2185740b40222a32b05d2819bf0..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/fetch_tool/description.md +++ /dev/null @@ -1 +0,0 @@ -Fetches a URL and returns the content as Markdown. diff --git a/crates/assistant_tools/src/find_path_tool.rs b/crates/assistant_tools/src/find_path_tool.rs deleted file mode 100644 index 0bc478251cb5d3d558dda4fb41df02e85eaafde2..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/find_path_tool.rs +++ /dev/null @@ -1,472 +0,0 @@ -use crate::{schema::json_schema_for, ui::ToolCallCardHeader}; -use action_log::ActionLog; -use anyhow::{Result, anyhow}; -use assistant_tool::{ - Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, -}; -use editor::Editor; -use futures::channel::oneshot::{self, Receiver}; -use gpui::{ - AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window, -}; -use language; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::fmt::Write; -use std::{cmp, path::PathBuf, sync::Arc}; -use ui::{Disclosure, Tooltip, prelude::*}; -use util::{ResultExt, paths::PathMatcher}; -use workspace::Workspace; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct FindPathToolInput { - /// The glob to match against every path in the project. - /// - /// - /// If the project has the following root directories: - /// - /// - directory1/a/something.txt - /// - directory2/a/things.txt - /// - directory3/a/other.txt - /// - /// You can get back the first two paths by providing a glob of "*thing*.txt" - /// - pub glob: String, - - /// Optional starting position for paginated results (0-based). - /// When not provided, starts from the beginning. - #[serde(default)] - pub offset: usize, -} - -#[derive(Debug, Serialize, Deserialize)] -struct FindPathToolOutput { - glob: String, - paths: Vec, -} - -const RESULTS_PER_PAGE: usize = 50; - -pub struct FindPathTool; - -impl Tool for FindPathTool { - fn name(&self) -> String { - "find_path".into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn description(&self) -> String { - include_str!("./find_path_tool/description.md").into() - } - - fn icon(&self) -> IconName { - IconName::ToolSearch - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => format!("Find paths matching “`{}`”", input.glob), - Err(_) => "Search paths".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let (offset, glob) = match serde_json::from_value::(input) { - Ok(input) => (input.offset, input.glob), - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - - let (sender, receiver) = oneshot::channel(); - - let card = cx.new(|cx| FindPathToolCard::new(glob.clone(), receiver, cx)); - - let search_paths_task = search_paths(&glob, project, cx); - - let task = cx.background_spawn(async move { - let matches = search_paths_task.await?; - let paginated_matches: &[PathBuf] = &matches[cmp::min(offset, matches.len()) - ..cmp::min(offset + RESULTS_PER_PAGE, matches.len())]; - - sender.send(paginated_matches.to_vec()).log_err(); - - if matches.is_empty() { - Ok("No matches found".to_string().into()) - } else { - let mut message = format!("Found {} total matches.", matches.len()); - if matches.len() > RESULTS_PER_PAGE { - write!( - &mut message, - "\nShowing results {}-{} (provide 'offset' parameter for more results):", - offset + 1, - offset + paginated_matches.len() - ) - .unwrap(); - } - - for mat in matches.iter().skip(offset).take(RESULTS_PER_PAGE) { - write!(&mut message, "\n{}", mat.display()).unwrap(); - } - - let output = FindPathToolOutput { - glob, - paths: matches, - }; - - Ok(ToolResultOutput { - content: ToolResultContent::Text(message), - output: Some(serde_json::to_value(output)?), - }) - } - }); - - ToolResult { - output: task, - card: Some(card.into()), - } - } - - fn deserialize_card( - self: Arc, - output: serde_json::Value, - _project: Entity, - _window: &mut Window, - cx: &mut App, - ) -> Option { - let output = serde_json::from_value::(output).ok()?; - let card = cx.new(|_| FindPathToolCard::from_output(output)); - Some(card.into()) - } -} - -fn search_paths(glob: &str, project: Entity, cx: &mut App) -> Task>> { - let path_matcher = match PathMatcher::new( - [ - // Sometimes models try to search for "". In this case, return all paths in the project. - if glob.is_empty() { "*" } else { glob }, - ], - project.read(cx).path_style(cx), - ) { - Ok(matcher) => matcher, - Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))), - }; - let snapshots: Vec<_> = project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect(); - - cx.background_spawn(async move { - Ok(snapshots - .iter() - .flat_map(|snapshot| { - snapshot - .entries(false, 0) - .map(move |entry| { - snapshot - .root_name() - .join(&entry.path) - .as_std_path() - .to_path_buf() - }) - .filter(|path| path_matcher.is_match(&path)) - }) - .collect()) - }) -} - -struct FindPathToolCard { - paths: Vec, - expanded: bool, - glob: String, - _receiver_task: Option>>, -} - -impl FindPathToolCard { - fn new(glob: String, receiver: Receiver>, cx: &mut Context) -> Self { - let _receiver_task = cx.spawn(async move |this, cx| { - let paths = receiver.await?; - - this.update(cx, |this, _cx| { - this.paths = paths; - }) - .log_err(); - - Ok(()) - }); - - Self { - paths: Vec::new(), - expanded: false, - glob, - _receiver_task: Some(_receiver_task), - } - } - - fn from_output(output: FindPathToolOutput) -> Self { - Self { - glob: output.glob, - paths: output.paths, - expanded: false, - _receiver_task: None, - } - } -} - -impl ToolCard for FindPathToolCard { - fn render( - &mut self, - _status: &ToolUseStatus, - _window: &mut Window, - workspace: WeakEntity, - cx: &mut Context, - ) -> impl IntoElement { - let matches_label: SharedString = if self.paths.is_empty() { - "No matches".into() - } else if self.paths.len() == 1 { - "1 match".into() - } else { - format!("{} matches", self.paths.len()).into() - }; - - let content = if !self.paths.is_empty() && self.expanded { - Some( - v_flex() - .relative() - .ml_1p5() - .px_1p5() - .gap_0p5() - .border_l_1() - .border_color(cx.theme().colors().border_variant) - .children(self.paths.iter().enumerate().map(|(index, path)| { - let path_clone = path.clone(); - let workspace_clone = workspace.clone(); - let button_label = path.to_string_lossy().into_owned(); - - Button::new(("path", index), button_label) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Small) - .icon_position(IconPosition::End) - .label_size(LabelSize::Small) - .color(Color::Muted) - .tooltip(Tooltip::text("Jump to File")) - .on_click(move |_, window, cx| { - workspace_clone - .update(cx, |workspace, cx| { - let path = PathBuf::from(&path_clone); - let Some(project_path) = workspace - .project() - .read(cx) - .find_project_path(&path, cx) - else { - return; - }; - let open_task = workspace.open_path( - project_path, - None, - true, - window, - cx, - ); - window - .spawn(cx, async move |cx| { - let item = open_task.await?; - if let Some(active_editor) = - item.downcast::() - { - active_editor - .update_in(cx, |editor, window, cx| { - editor.go_to_singleton_buffer_point( - language::Point::new(0, 0), - window, - cx, - ); - }) - .log_err(); - } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - }) - .ok(); - }) - })) - .into_any(), - ) - } else { - None - }; - - v_flex() - .mb_2() - .gap_1() - .child( - ToolCallCardHeader::new(IconName::ToolSearch, matches_label) - .with_code_path(&self.glob) - .disclosure_slot( - Disclosure::new("path-search-disclosure", self.expanded) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronDown) - .disabled(self.paths.is_empty()) - .on_click(cx.listener(move |this, _, _, _cx| { - this.expanded = !this.expanded; - })), - ), - ) - .children(content) - } -} - -impl Component for FindPathTool { - fn scope() -> ComponentScope { - ComponentScope::Agent - } - - fn sort_name() -> &'static str { - "FindPathTool" - } - - fn preview(window: &mut Window, cx: &mut App) -> Option { - let successful_card = cx.new(|_| FindPathToolCard { - paths: vec![ - PathBuf::from("src/main.rs"), - PathBuf::from("src/lib.rs"), - PathBuf::from("tests/test.rs"), - ], - expanded: true, - glob: "*.rs".to_string(), - _receiver_task: None, - }); - - let empty_card = cx.new(|_| FindPathToolCard { - paths: Vec::new(), - expanded: false, - glob: "*.nonexistent".to_string(), - _receiver_task: None, - }); - - Some( - v_flex() - .gap_6() - .children(vec![example_group(vec![ - single_example( - "With Paths", - div() - .size_full() - .child(successful_card.update(cx, |tool, cx| { - tool.render( - &ToolUseStatus::Finished("".into()), - window, - WeakEntity::new_invalid(), - cx, - ) - .into_any_element() - })) - .into_any_element(), - ), - single_example( - "No Paths", - div() - .size_full() - .child(empty_card.update(cx, |tool, cx| { - tool.render( - &ToolUseStatus::Finished("".into()), - window, - WeakEntity::new_invalid(), - cx, - ) - .into_any_element() - })) - .into_any_element(), - ), - ])]) - .into_any_element(), - ) - } -} - -#[cfg(test)] -mod test { - use super::*; - use gpui::TestAppContext; - use project::{FakeFs, Project}; - use settings::SettingsStore; - use util::path; - - #[gpui::test] - async fn test_find_path_tool(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - serde_json::json!({ - "apple": { - "banana": { - "carrot": "1", - }, - "bandana": { - "carbonara": "2", - }, - "endive": "3" - } - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - - let matches = cx - .update(|cx| search_paths("root/**/car*", project.clone(), cx)) - .await - .unwrap(); - assert_eq!( - matches, - &[ - PathBuf::from(path!("root/apple/banana/carrot")), - PathBuf::from(path!("root/apple/bandana/carbonara")) - ] - ); - - let matches = cx - .update(|cx| search_paths("**/car*", project.clone(), cx)) - .await - .unwrap(); - assert_eq!( - matches, - &[ - PathBuf::from(path!("root/apple/banana/carrot")), - PathBuf::from(path!("root/apple/bandana/carbonara")) - ] - ); - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - }); - } -} diff --git a/crates/assistant_tools/src/find_path_tool/description.md b/crates/assistant_tools/src/find_path_tool/description.md deleted file mode 100644 index f7a697c467b2807c1f4cf1706ef660a77b9ee727..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/find_path_tool/description.md +++ /dev/null @@ -1,7 +0,0 @@ -Fast file path pattern matching tool that works with any codebase size - -- Supports glob patterns like "**/*.js" or "src/**/*.ts" -- Returns matching file paths sorted alphabetically -- Prefer the `grep` tool to this tool when searching for symbols unless you have specific information about paths. -- Use this tool when you need to find files by name patterns -- Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages. diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs deleted file mode 100644 index 609e25f338d11995ea6f587ba476e4f95274e4e9..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/grep_tool.rs +++ /dev/null @@ -1,1308 +0,0 @@ -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use futures::StreamExt; -use gpui::{AnyWindowHandle, App, Entity, Task}; -use language::{OffsetRangeExt, ParseStatus, Point}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::{ - Project, WorktreeSettings, - search::{SearchQuery, SearchResult}, -}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use settings::Settings; -use std::{cmp, fmt::Write, sync::Arc}; -use ui::IconName; -use util::RangeExt; -use util::markdown::MarkdownInlineCode; -use util::paths::PathMatcher; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct GrepToolInput { - /// A regex pattern to search for in the entire project. Note that the regex - /// will be parsed by the Rust `regex` crate. - /// - /// Do NOT specify a path here! This will only be matched against the code **content**. - pub regex: String, - - /// A glob pattern for the paths of files to include in the search. - /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts". - /// If omitted, all files in the project will be searched. - pub include_pattern: Option, - - /// Optional starting position for paginated results (0-based). - /// When not provided, starts from the beginning. - #[serde(default)] - pub offset: u32, - - /// Whether the regex is case-sensitive. Defaults to false (case-insensitive). - #[serde(default)] - pub case_sensitive: bool, -} - -impl GrepToolInput { - /// Which page of search results this is. - pub fn page(&self) -> u32 { - 1 + (self.offset / RESULTS_PER_PAGE) - } -} - -const RESULTS_PER_PAGE: u32 = 20; - -pub struct GrepTool; - -impl Tool for GrepTool { - fn name(&self) -> String { - "grep".into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn description(&self) -> String { - include_str!("./grep_tool/description.md").into() - } - - fn icon(&self) -> IconName { - IconName::ToolRegex - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let page = input.page(); - let regex_str = MarkdownInlineCode(&input.regex); - let case_info = if input.case_sensitive { - " (case-sensitive)" - } else { - "" - }; - - if page > 1 { - format!("Get page {page} of search results for regex {regex_str}{case_info}") - } else { - format!("Search files for regex {regex_str}{case_info}") - } - } - Err(_) => "Search with regex".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - const CONTEXT_LINES: u32 = 2; - const MAX_ANCESTOR_LINES: u32 = 10; - - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(error) => { - return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into(); - } - }; - - let include_matcher = match PathMatcher::new( - input - .include_pattern - .as_ref() - .into_iter() - .collect::>(), - project.read(cx).path_style(cx), - ) { - Ok(matcher) => matcher, - Err(error) => { - return Task::ready(Err(anyhow!("invalid include glob pattern: {error}"))).into(); - } - }; - - // Exclude global file_scan_exclusions and private_files settings - let exclude_matcher = { - let global_settings = WorktreeSettings::get_global(cx); - let exclude_patterns = global_settings - .file_scan_exclusions - .sources() - .iter() - .chain(global_settings.private_files.sources().iter()); - - match PathMatcher::new(exclude_patterns, project.read(cx).path_style(cx)) { - Ok(matcher) => matcher, - Err(error) => { - return Task::ready(Err(anyhow!("invalid exclude pattern: {error}"))).into(); - } - } - }; - - let query = match SearchQuery::regex( - &input.regex, - false, - input.case_sensitive, - false, - false, - include_matcher, - exclude_matcher, - true, // Always match file include pattern against *full project paths* that start with a project root. - None, - ) { - Ok(query) => query, - Err(error) => return Task::ready(Err(error)).into(), - }; - - let results = project.update(cx, |project, cx| project.search(query, cx)); - - cx.spawn(async move |cx| { - futures::pin_mut!(results); - - let mut output = String::new(); - let mut skips_remaining = input.offset; - let mut matches_found = 0; - let mut has_more_matches = false; - - 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await { - if ranges.is_empty() { - continue; - } - - let Ok((Some(path), mut parse_status)) = buffer.read_with(cx, |buffer, cx| { - (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status()) - }) else { - continue; - }; - - // Check if this file should be excluded based on its worktree settings - if let Ok(Some(project_path)) = project.read_with(cx, |project, cx| { - project.find_project_path(&path, cx) - }) - && cx.update(|cx| { - let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); - worktree_settings.is_path_excluded(&project_path.path) - || worktree_settings.is_path_private(&project_path.path) - }).unwrap_or(false) { - continue; - } - - while *parse_status.borrow() != ParseStatus::Idle { - parse_status.changed().await?; - } - - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - - let mut ranges = ranges - .into_iter() - .map(|range| { - let matched = range.to_point(&snapshot); - let matched_end_line_len = snapshot.line_len(matched.end.row); - let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len); - let symbols = snapshot.symbols_containing(matched.start, None); - - if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) { - let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot); - let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES); - let end_col = snapshot.line_len(end_row); - let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col); - - if capped_ancestor_range.contains_inclusive(&full_lines) { - return (capped_ancestor_range, Some(full_ancestor_range), symbols) - } - } - - let mut matched = matched; - matched.start.column = 0; - matched.start.row = - matched.start.row.saturating_sub(CONTEXT_LINES); - matched.end.row = cmp::min( - snapshot.max_point().row, - matched.end.row + CONTEXT_LINES, - ); - matched.end.column = snapshot.line_len(matched.end.row); - - (matched, None, symbols) - }) - .peekable(); - - let mut file_header_written = false; - - while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){ - if skips_remaining > 0 { - skips_remaining -= 1; - continue; - } - - // We'd already found a full page of matches, and we just found one more. - if matches_found >= RESULTS_PER_PAGE { - has_more_matches = true; - break 'outer; - } - - while let Some((next_range, _, _)) = ranges.peek() { - if range.end.row >= next_range.start.row { - range.end = next_range.end; - ranges.next(); - } else { - break; - } - } - - if !file_header_written { - writeln!(output, "\n## Matches in {}", path.display())?; - file_header_written = true; - } - - let end_row = range.end.row; - output.push_str("\n### "); - - for symbol in parent_symbols { - write!(output, "{} › ", symbol.text)?; - } - - if range.start.row == end_row { - writeln!(output, "L{}", range.start.row + 1)?; - } else { - writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?; - } - - output.push_str("```\n"); - output.extend(snapshot.text_for_range(range)); - output.push_str("\n```\n"); - - if let Some(ancestor_range) = ancestor_range - && end_row < ancestor_range.end.row { - let remaining_lines = ancestor_range.end.row - end_row; - writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?; - } - - matches_found += 1; - } - } - - if matches_found == 0 { - Ok("No matches found".to_string().into()) - } else if has_more_matches { - Ok(format!( - "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}", - input.offset + 1, - input.offset + matches_found, - input.offset + RESULTS_PER_PAGE, - ).into()) - } else { - Ok(format!("Found {matches_found} matches:\n{output}").into()) - } - }).into() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use assistant_tool::Tool; - use gpui::{AppContext, TestAppContext, UpdateGlobal}; - use language::{Language, LanguageConfig, LanguageMatcher}; - use language_model::fake_provider::FakeLanguageModel; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use unindent::Unindent; - use util::path; - - #[gpui::test] - async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) { - init_test(cx); - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - serde_json::json!({ - "src": { - "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}", - "utils": { - "helper.rs": "fn helper() {\n println!(\"I'm a helper!\");\n}", - }, - }, - "tests": { - "test_main.rs": "fn test_main() {\n assert!(true);\n}", - } - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - - // Test with include pattern for Rust files inside the root of the project - let input = serde_json::to_value(GrepToolInput { - regex: "println".to_string(), - include_pattern: Some("root/**/*.rs".to_string()), - offset: 0, - case_sensitive: false, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - assert!(result.contains("main.rs"), "Should find matches in main.rs"); - assert!( - result.contains("helper.rs"), - "Should find matches in helper.rs" - ); - assert!( - !result.contains("test_main.rs"), - "Should not include test_main.rs even though it's a .rs file (because it doesn't have the pattern)" - ); - - // Test with include pattern for src directory only - let input = serde_json::to_value(GrepToolInput { - regex: "fn".to_string(), - include_pattern: Some("root/**/src/**".to_string()), - offset: 0, - case_sensitive: false, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - assert!( - result.contains("main.rs"), - "Should find matches in src/main.rs" - ); - assert!( - result.contains("helper.rs"), - "Should find matches in src/utils/helper.rs" - ); - assert!( - !result.contains("test_main.rs"), - "Should not include test_main.rs as it's not in src directory" - ); - - // Test with empty include pattern (should default to all files) - let input = serde_json::to_value(GrepToolInput { - regex: "fn".to_string(), - include_pattern: None, - offset: 0, - case_sensitive: false, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - assert!(result.contains("main.rs"), "Should find matches in main.rs"); - assert!( - result.contains("helper.rs"), - "Should find matches in helper.rs" - ); - assert!( - result.contains("test_main.rs"), - "Should include test_main.rs" - ); - } - - #[gpui::test] - async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) { - init_test(cx); - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - serde_json::json!({ - "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true", - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - - // Test case-insensitive search (default) - let input = serde_json::to_value(GrepToolInput { - regex: "uppercase".to_string(), - include_pattern: Some("**/*.txt".to_string()), - offset: 0, - case_sensitive: false, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - assert!( - result.contains("UPPERCASE"), - "Case-insensitive search should match uppercase" - ); - - // Test case-sensitive search - let input = serde_json::to_value(GrepToolInput { - regex: "uppercase".to_string(), - include_pattern: Some("**/*.txt".to_string()), - offset: 0, - case_sensitive: true, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - assert!( - !result.contains("UPPERCASE"), - "Case-sensitive search should not match uppercase" - ); - - // Test case-sensitive search - let input = serde_json::to_value(GrepToolInput { - regex: "LOWERCASE".to_string(), - include_pattern: Some("**/*.txt".to_string()), - offset: 0, - case_sensitive: true, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - - assert!( - !result.contains("lowercase"), - "Case-sensitive search should match lowercase" - ); - - // Test case-sensitive search for lowercase pattern - let input = serde_json::to_value(GrepToolInput { - regex: "lowercase".to_string(), - include_pattern: Some("**/*.txt".to_string()), - offset: 0, - case_sensitive: true, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - assert!( - result.contains("lowercase"), - "Case-sensitive search should match lowercase text" - ); - } - - /// Helper function to set up a syntax test environment - async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity { - use unindent::Unindent; - init_test(cx); - cx.executor().allow_parking(); - - let fs = FakeFs::new(cx.executor()); - - // Create test file with syntax structures - fs.insert_tree( - path!("/root"), - serde_json::json!({ - "test_syntax.rs": r#" - fn top_level_function() { - println!("This is at the top level"); - } - - mod feature_module { - pub mod nested_module { - pub fn nested_function( - first_arg: String, - second_arg: i32, - ) { - println!("Function in nested module"); - println!("{first_arg}"); - println!("{second_arg}"); - } - } - } - - struct MyStruct { - field1: String, - field2: i32, - } - - impl MyStruct { - fn method_with_block() { - let condition = true; - if condition { - println!("Inside if block"); - } - } - - fn long_function() { - println!("Line 1"); - println!("Line 2"); - println!("Line 3"); - println!("Line 4"); - println!("Line 5"); - println!("Line 6"); - println!("Line 7"); - println!("Line 8"); - println!("Line 9"); - println!("Line 10"); - println!("Line 11"); - println!("Line 12"); - } - } - - trait Processor { - fn process(&self, input: &str) -> String; - } - - impl Processor for MyStruct { - fn process(&self, input: &str) -> String { - format!("Processed: {}", input) - } - } - "#.unindent().trim(), - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - - project.update(cx, |project, _cx| { - project.languages().add(rust_lang().into()) - }); - - project - } - - #[gpui::test] - async fn test_grep_top_level_function(cx: &mut TestAppContext) { - let project = setup_syntax_test(cx).await; - - // Test: Line at the top level of the file - let input = serde_json::to_value(GrepToolInput { - regex: "This is at the top level".to_string(), - include_pattern: Some("**/*.rs".to_string()), - offset: 0, - case_sensitive: false, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - let expected = r#" - Found 1 matches: - - ## Matches in root/test_syntax.rs - - ### fn top_level_function › L1-3 - ``` - fn top_level_function() { - println!("This is at the top level"); - } - ``` - "# - .unindent(); - assert_eq!(result, expected); - } - - #[gpui::test] - async fn test_grep_function_body(cx: &mut TestAppContext) { - let project = setup_syntax_test(cx).await; - - // Test: Line inside a function body - let input = serde_json::to_value(GrepToolInput { - regex: "Function in nested module".to_string(), - include_pattern: Some("**/*.rs".to_string()), - offset: 0, - case_sensitive: false, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - let expected = r#" - Found 1 matches: - - ## Matches in root/test_syntax.rs - - ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14 - ``` - ) { - println!("Function in nested module"); - println!("{first_arg}"); - println!("{second_arg}"); - } - ``` - "# - .unindent(); - assert_eq!(result, expected); - } - - #[gpui::test] - async fn test_grep_function_args_and_body(cx: &mut TestAppContext) { - let project = setup_syntax_test(cx).await; - - // Test: Line with a function argument - let input = serde_json::to_value(GrepToolInput { - regex: "second_arg".to_string(), - include_pattern: Some("**/*.rs".to_string()), - offset: 0, - case_sensitive: false, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - let expected = r#" - Found 1 matches: - - ## Matches in root/test_syntax.rs - - ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14 - ``` - pub fn nested_function( - first_arg: String, - second_arg: i32, - ) { - println!("Function in nested module"); - println!("{first_arg}"); - println!("{second_arg}"); - } - ``` - "# - .unindent(); - assert_eq!(result, expected); - } - - #[gpui::test] - async fn test_grep_if_block(cx: &mut TestAppContext) { - use unindent::Unindent; - let project = setup_syntax_test(cx).await; - - // Test: Line inside an if block - let input = serde_json::to_value(GrepToolInput { - regex: "Inside if block".to_string(), - include_pattern: Some("**/*.rs".to_string()), - offset: 0, - case_sensitive: false, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - let expected = r#" - Found 1 matches: - - ## Matches in root/test_syntax.rs - - ### impl MyStruct › fn method_with_block › L26-28 - ``` - if condition { - println!("Inside if block"); - } - ``` - "# - .unindent(); - assert_eq!(result, expected); - } - - #[gpui::test] - async fn test_grep_long_function_top(cx: &mut TestAppContext) { - use unindent::Unindent; - let project = setup_syntax_test(cx).await; - - // Test: Line in the middle of a long function - should show message about remaining lines - let input = serde_json::to_value(GrepToolInput { - regex: "Line 5".to_string(), - include_pattern: Some("**/*.rs".to_string()), - offset: 0, - case_sensitive: false, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - let expected = r#" - Found 1 matches: - - ## Matches in root/test_syntax.rs - - ### impl MyStruct › fn long_function › L31-41 - ``` - fn long_function() { - println!("Line 1"); - println!("Line 2"); - println!("Line 3"); - println!("Line 4"); - println!("Line 5"); - println!("Line 6"); - println!("Line 7"); - println!("Line 8"); - println!("Line 9"); - println!("Line 10"); - ``` - - 3 lines remaining in ancestor node. Read the file to see all. - "# - .unindent(); - assert_eq!(result, expected); - } - - #[gpui::test] - async fn test_grep_long_function_bottom(cx: &mut TestAppContext) { - use unindent::Unindent; - let project = setup_syntax_test(cx).await; - - // Test: Line in the long function - let input = serde_json::to_value(GrepToolInput { - regex: "Line 12".to_string(), - include_pattern: Some("**/*.rs".to_string()), - offset: 0, - case_sensitive: false, - }) - .unwrap(); - - let result = run_grep_tool(input, project.clone(), cx).await; - let expected = r#" - Found 1 matches: - - ## Matches in root/test_syntax.rs - - ### impl MyStruct › fn long_function › L41-45 - ``` - println!("Line 10"); - println!("Line 11"); - println!("Line 12"); - } - } - ``` - "# - .unindent(); - assert_eq!(result, expected); - } - - async fn run_grep_tool( - input: serde_json::Value, - project: Entity, - cx: &mut TestAppContext, - ) -> String { - let tool = Arc::new(GrepTool); - let action_log = cx.new(|_cx| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let task = - cx.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx)); - - match task.output.await { - Ok(result) => { - if cfg!(windows) { - result.content.as_str().unwrap().replace("root\\", "root/") - } else { - result.content.as_str().unwrap().to_string() - } - } - Err(e) => panic!("Failed to run grep tool: {}", e), - } - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - }); - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } - - #[gpui::test] - async fn test_grep_security_boundaries(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - - fs.insert_tree( - path!("/"), - json!({ - "project_root": { - "allowed_file.rs": "fn main() { println!(\"This file is in the project\"); }", - ".mysecrets": "SECRET_KEY=abc123\nfn secret() { /* private */ }", - ".secretdir": { - "config": "fn special_configuration() { /* excluded */ }" - }, - ".mymetadata": "fn custom_metadata() { /* excluded */ }", - "subdir": { - "normal_file.rs": "fn normal_file_content() { /* Normal */ }", - "special.privatekey": "fn private_key_content() { /* private */ }", - "data.mysensitive": "fn sensitive_data() { /* private */ }" - } - }, - "outside_project": { - "sensitive_file.rs": "fn outside_function() { /* This file is outside the project */ }" - } - }), - ) - .await; - - cx.update(|cx| { - use gpui::UpdateGlobal; - use settings::SettingsStore; - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings.project.worktree.file_scan_exclusions = Some(vec![ - "**/.secretdir".to_string(), - "**/.mymetadata".to_string(), - ]); - settings.project.worktree.private_files = Some( - vec![ - "**/.mysecrets".to_string(), - "**/*.privatekey".to_string(), - "**/*.mysensitive".to_string(), - ] - .into(), - ); - }); - }); - }); - - let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - - // Searching for files outside the project worktree should return no results - let result = cx - .update(|cx| { - let input = json!({ - "regex": "outside_function" - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - let results = result.unwrap(); - let paths = extract_paths_from_results(results.content.as_str().unwrap()); - assert!( - paths.is_empty(), - "grep_tool should not find files outside the project worktree" - ); - - // Searching within the project should succeed - let result = cx - .update(|cx| { - let input = json!({ - "regex": "main" - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - let results = result.unwrap(); - let paths = extract_paths_from_results(results.content.as_str().unwrap()); - assert!( - paths.iter().any(|p| p.contains("allowed_file.rs")), - "grep_tool should be able to search files inside worktrees" - ); - - // Searching files that match file_scan_exclusions should return no results - let result = cx - .update(|cx| { - let input = json!({ - "regex": "special_configuration" - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - let results = result.unwrap(); - let paths = extract_paths_from_results(results.content.as_str().unwrap()); - assert!( - paths.is_empty(), - "grep_tool should not search files in .secretdir (file_scan_exclusions)" - ); - - let result = cx - .update(|cx| { - let input = json!({ - "regex": "custom_metadata" - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - let results = result.unwrap(); - let paths = extract_paths_from_results(results.content.as_str().unwrap()); - assert!( - paths.is_empty(), - "grep_tool should not search .mymetadata files (file_scan_exclusions)" - ); - - // Searching private files should return no results - let result = cx - .update(|cx| { - let input = json!({ - "regex": "SECRET_KEY" - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - let results = result.unwrap(); - let paths = extract_paths_from_results(results.content.as_str().unwrap()); - assert!( - paths.is_empty(), - "grep_tool should not search .mysecrets (private_files)" - ); - - let result = cx - .update(|cx| { - let input = json!({ - "regex": "private_key_content" - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - let results = result.unwrap(); - let paths = extract_paths_from_results(results.content.as_str().unwrap()); - assert!( - paths.is_empty(), - "grep_tool should not search .privatekey files (private_files)" - ); - - let result = cx - .update(|cx| { - let input = json!({ - "regex": "sensitive_data" - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - let results = result.unwrap(); - let paths = extract_paths_from_results(results.content.as_str().unwrap()); - assert!( - paths.is_empty(), - "grep_tool should not search .mysensitive files (private_files)" - ); - - // Searching a normal file should still work, even with private_files configured - let result = cx - .update(|cx| { - let input = json!({ - "regex": "normal_file_content" - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - let results = result.unwrap(); - let paths = extract_paths_from_results(results.content.as_str().unwrap()); - assert!( - paths.iter().any(|p| p.contains("normal_file.rs")), - "Should be able to search normal files" - ); - - // Path traversal attempts with .. in include_pattern should not escape project - let result = cx - .update(|cx| { - let input = json!({ - "regex": "outside_function", - "include_pattern": "../outside_project/**/*.rs" - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - let results = result.unwrap(); - let paths = extract_paths_from_results(results.content.as_str().unwrap()); - assert!( - paths.is_empty(), - "grep_tool should not allow escaping project boundaries with relative paths" - ); - } - - #[gpui::test] - async fn test_grep_with_multiple_worktree_settings(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - - // Create first worktree with its own private files - fs.insert_tree( - path!("/worktree1"), - json!({ - ".zed": { - "settings.json": r#"{ - "file_scan_exclusions": ["**/fixture.*"], - "private_files": ["**/secret.rs"] - }"# - }, - "src": { - "main.rs": "fn main() { let secret_key = \"hidden\"; }", - "secret.rs": "const API_KEY: &str = \"secret_value\";", - "utils.rs": "pub fn get_config() -> String { \"config\".to_string() }" - }, - "tests": { - "test.rs": "fn test_secret() { assert!(true); }", - "fixture.sql": "SELECT * FROM secret_table;" - } - }), - ) - .await; - - // Create second worktree with different private files - fs.insert_tree( - path!("/worktree2"), - json!({ - ".zed": { - "settings.json": r#"{ - "file_scan_exclusions": ["**/internal.*"], - "private_files": ["**/private.js", "**/data.json"] - }"# - }, - "lib": { - "public.js": "export function getSecret() { return 'public'; }", - "private.js": "const SECRET_KEY = \"private_value\";", - "data.json": "{\"secret_data\": \"hidden\"}" - }, - "docs": { - "README.md": "# Documentation with secret info", - "internal.md": "Internal secret documentation" - } - }), - ) - .await; - - // Set global settings - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings.project.worktree.file_scan_exclusions = - Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]); - settings.project.worktree.private_files = - Some(vec!["**/.env".to_string()].into()); - }); - }); - }); - - let project = Project::test( - fs.clone(), - [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], - cx, - ) - .await; - - // Wait for worktrees to be fully scanned - cx.executor().run_until_parked(); - - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - - // Search for "secret" - should exclude files based on worktree-specific settings - let result = cx - .update(|cx| { - let input = json!({ - "regex": "secret", - "case_sensitive": false - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - let paths = extract_paths_from_results(content); - - // Should find matches in non-private files - assert!( - paths.iter().any(|p| p.contains("main.rs")), - "Should find 'secret' in worktree1/src/main.rs" - ); - assert!( - paths.iter().any(|p| p.contains("test.rs")), - "Should find 'secret' in worktree1/tests/test.rs" - ); - assert!( - paths.iter().any(|p| p.contains("public.js")), - "Should find 'secret' in worktree2/lib/public.js" - ); - assert!( - paths.iter().any(|p| p.contains("README.md")), - "Should find 'secret' in worktree2/docs/README.md" - ); - - // Should NOT find matches in private/excluded files based on worktree settings - assert!( - !paths.iter().any(|p| p.contains("secret.rs")), - "Should not search in worktree1/src/secret.rs (local private_files)" - ); - assert!( - !paths.iter().any(|p| p.contains("fixture.sql")), - "Should not search in worktree1/tests/fixture.sql (local file_scan_exclusions)" - ); - assert!( - !paths.iter().any(|p| p.contains("private.js")), - "Should not search in worktree2/lib/private.js (local private_files)" - ); - assert!( - !paths.iter().any(|p| p.contains("data.json")), - "Should not search in worktree2/lib/data.json (local private_files)" - ); - assert!( - !paths.iter().any(|p| p.contains("internal.md")), - "Should not search in worktree2/docs/internal.md (local file_scan_exclusions)" - ); - - // Test with `include_pattern` specific to one worktree - let result = cx - .update(|cx| { - let input = json!({ - "regex": "secret", - "include_pattern": "worktree1/**/*.rs" - }); - Arc::new(GrepTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - let paths = extract_paths_from_results(content); - - // Should only find matches in worktree1 *.rs files (excluding private ones) - assert!( - paths.iter().any(|p| p.contains("main.rs")), - "Should find match in worktree1/src/main.rs" - ); - assert!( - paths.iter().any(|p| p.contains("test.rs")), - "Should find match in worktree1/tests/test.rs" - ); - assert!( - !paths.iter().any(|p| p.contains("secret.rs")), - "Should not find match in excluded worktree1/src/secret.rs" - ); - assert!( - paths.iter().all(|p| !p.contains("worktree2")), - "Should not find any matches in worktree2" - ); - } - - // Helper function to extract file paths from grep results - fn extract_paths_from_results(results: &str) -> Vec { - results - .lines() - .filter(|line| line.starts_with("## Matches in ")) - .map(|line| { - line.strip_prefix("## Matches in ") - .unwrap() - .trim() - .to_string() - }) - .collect() - } -} diff --git a/crates/assistant_tools/src/grep_tool/description.md b/crates/assistant_tools/src/grep_tool/description.md deleted file mode 100644 index e3c0b43f31da53df49ce905e764dedcc5ea530de..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/grep_tool/description.md +++ /dev/null @@ -1,9 +0,0 @@ -Searches the contents of files in the project with a regular expression - -- Prefer this tool to path search when searching for symbols in the project, because you won't need to guess what path it's in. -- Supports full regex syntax (eg. "log.*Error", "function\\s+\\w+", etc.) -- Pass an `include_pattern` if you know how to narrow your search on the files system -- Never use this tool to search for paths. Only search file contents with this tool. -- Use this tool when you need to find files containing specific patterns -- Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages. -- DO NOT use HTML entities solely to escape characters in the tool parameters. diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs deleted file mode 100644 index 7d70f41a8c5000b433d47e8caa2a60d3a8024b99..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ /dev/null @@ -1,869 +0,0 @@ -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use gpui::{AnyWindowHandle, App, Entity, Task}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::{Project, ProjectPath, WorktreeSettings}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use settings::Settings; -use std::{fmt::Write, sync::Arc}; -use ui::IconName; -use util::markdown::MarkdownInlineCode; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct ListDirectoryToolInput { - /// The fully-qualified path of the directory to list in the project. - /// - /// This path should never be absolute, and the first component - /// of the path should always be a root directory in a project. - /// - /// - /// If the project has the following root directories: - /// - /// - directory1 - /// - directory2 - /// - /// You can list the contents of `directory1` by using the path `directory1`. - /// - /// - /// - /// If the project has the following root directories: - /// - /// - foo - /// - bar - /// - /// If you wanna list contents in the directory `foo/baz`, you should use the path `foo/baz`. - /// - pub path: String, -} - -pub struct ListDirectoryTool; - -impl Tool for ListDirectoryTool { - fn name(&self) -> String { - "list_directory".into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn description(&self) -> String { - include_str!("./list_directory_tool/description.md").into() - } - - fn icon(&self) -> IconName { - IconName::ToolFolder - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let path = MarkdownInlineCode(&input.path); - format!("List the {path} directory's contents") - } - Err(_) => "List directory".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let path_style = project.read(cx).path_style(cx); - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - - // Sometimes models will return these even though we tell it to give a path and not a glob. - // When this happens, just list the root worktree directories. - if matches!(input.path.as_str(), "." | "" | "./" | "*") { - let output = project - .read(cx) - .worktrees(cx) - .filter_map(|worktree| { - worktree.read(cx).root_entry().and_then(|entry| { - if entry.is_dir() { - Some(entry.path.display(path_style)) - } else { - None - } - }) - }) - .collect::>() - .join("\n"); - - return Task::ready(Ok(output.into())).into(); - } - - let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { - return Task::ready(Err(anyhow!("Path {} not found in project", input.path))).into(); - }; - let Some(worktree) = project - .read(cx) - .worktree_for_id(project_path.worktree_id, cx) - else { - return Task::ready(Err(anyhow!("Worktree not found"))).into(); - }; - - // Check if the directory whose contents we're listing is itself excluded or private - let global_settings = WorktreeSettings::get_global(cx); - if global_settings.is_path_excluded(&project_path.path) { - return Task::ready(Err(anyhow!( - "Cannot list directory because its path matches the user's global `file_scan_exclusions` setting: {}", - &input.path - ))) - .into(); - } - - if global_settings.is_path_private(&project_path.path) { - return Task::ready(Err(anyhow!( - "Cannot list directory because its path matches the user's global `private_files` setting: {}", - &input.path - ))) - .into(); - } - - let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); - if worktree_settings.is_path_excluded(&project_path.path) { - return Task::ready(Err(anyhow!( - "Cannot list directory because its path matches the user's worktree`file_scan_exclusions` setting: {}", - &input.path - ))) - .into(); - } - - if worktree_settings.is_path_private(&project_path.path) { - return Task::ready(Err(anyhow!( - "Cannot list directory because its path matches the user's worktree `private_paths` setting: {}", - &input.path - ))) - .into(); - } - - let worktree_snapshot = worktree.read(cx).snapshot(); - - let Some(entry) = worktree_snapshot.entry_for_path(&project_path.path) else { - return Task::ready(Err(anyhow!("Path not found: {}", input.path))).into(); - }; - - if !entry.is_dir() { - return Task::ready(Err(anyhow!("{} is not a directory.", input.path))).into(); - } - let worktree_snapshot = worktree.read(cx).snapshot(); - - let mut folders = Vec::new(); - let mut files = Vec::new(); - - for entry in worktree_snapshot.child_entries(&project_path.path) { - // Skip private and excluded files and directories - if global_settings.is_path_private(&entry.path) - || global_settings.is_path_excluded(&entry.path) - { - continue; - } - - let project_path = ProjectPath { - worktree_id: worktree_snapshot.id(), - path: entry.path.clone(), - }; - let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); - - if worktree_settings.is_path_excluded(&project_path.path) - || worktree_settings.is_path_private(&project_path.path) - { - continue; - } - - let full_path = worktree_snapshot - .root_name() - .join(&entry.path) - .display(worktree_snapshot.path_style()) - .to_string(); - if entry.is_dir() { - folders.push(full_path); - } else { - files.push(full_path); - } - } - - let mut output = String::new(); - - if !folders.is_empty() { - writeln!(output, "# Folders:\n{}", folders.join("\n")).unwrap(); - } - - if !files.is_empty() { - writeln!(output, "\n# Files:\n{}", files.join("\n")).unwrap(); - } - - if output.is_empty() { - writeln!(output, "{} is empty.", input.path).unwrap(); - } - - Task::ready(Ok(output.into())).into() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use assistant_tool::Tool; - use gpui::{AppContext, TestAppContext, UpdateGlobal}; - use indoc::indoc; - use language_model::fake_provider::FakeLanguageModel; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use util::path; - - fn platform_paths(path_str: &str) -> String { - if cfg!(target_os = "windows") { - path_str.replace("/", "\\") - } else { - path_str.to_string() - } - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - }); - } - - #[gpui::test] - async fn test_list_directory_separates_files_and_dirs(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/project"), - json!({ - "src": { - "main.rs": "fn main() {}", - "lib.rs": "pub fn hello() {}", - "models": { - "user.rs": "struct User {}", - "post.rs": "struct Post {}" - }, - "utils": { - "helper.rs": "pub fn help() {}" - } - }, - "tests": { - "integration_test.rs": "#[test] fn test() {}" - }, - "README.md": "# Project", - "Cargo.toml": "[package]" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let tool = Arc::new(ListDirectoryTool); - - // Test listing root directory - let input = json!({ - "path": "project" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - assert_eq!( - content, - platform_paths(indoc! {" - # Folders: - project/src - project/tests - - # Files: - project/Cargo.toml - project/README.md - "}) - ); - - // Test listing src directory - let input = json!({ - "path": "project/src" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - assert_eq!( - content, - platform_paths(indoc! {" - # Folders: - project/src/models - project/src/utils - - # Files: - project/src/lib.rs - project/src/main.rs - "}) - ); - - // Test listing directory with only files - let input = json!({ - "path": "project/tests" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - assert!(!content.contains("# Folders:")); - assert!(content.contains("# Files:")); - assert!(content.contains(&platform_paths("project/tests/integration_test.rs"))); - } - - #[gpui::test] - async fn test_list_directory_empty_directory(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/project"), - json!({ - "empty_dir": {} - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let tool = Arc::new(ListDirectoryTool); - - let input = json!({ - "path": "project/empty_dir" - }); - - let result = cx - .update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx)) - .output - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - assert_eq!(content, "project/empty_dir is empty.\n"); - } - - #[gpui::test] - async fn test_list_directory_error_cases(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/project"), - json!({ - "file.txt": "content" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let tool = Arc::new(ListDirectoryTool); - - // Test non-existent path - let input = json!({ - "path": "project/nonexistent" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await; - - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Path not found")); - - // Test trying to list a file instead of directory - let input = json!({ - "path": "project/file.txt" - }); - - let result = cx - .update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx)) - .output - .await; - - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("is not a directory") - ); - } - - #[gpui::test] - async fn test_list_directory_security(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/project"), - json!({ - "normal_dir": { - "file1.txt": "content", - "file2.txt": "content" - }, - ".mysecrets": "SECRET_KEY=abc123", - ".secretdir": { - "config": "special configuration", - "secret.txt": "secret content" - }, - ".mymetadata": "custom metadata", - "visible_dir": { - "normal.txt": "normal content", - "special.privatekey": "private key content", - "data.mysensitive": "sensitive data", - ".hidden_subdir": { - "hidden_file.txt": "hidden content" - } - } - }), - ) - .await; - - // Configure settings explicitly - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings.project.worktree.file_scan_exclusions = Some(vec![ - "**/.secretdir".to_string(), - "**/.mymetadata".to_string(), - "**/.hidden_subdir".to_string(), - ]); - settings.project.worktree.private_files = Some( - vec![ - "**/.mysecrets".to_string(), - "**/*.privatekey".to_string(), - "**/*.mysensitive".to_string(), - ] - .into(), - ); - }); - }); - }); - - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let tool = Arc::new(ListDirectoryTool); - - // Listing root directory should exclude private and excluded files - let input = json!({ - "path": "project" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - - // Should include normal directories - assert!(content.contains("normal_dir"), "Should list normal_dir"); - assert!(content.contains("visible_dir"), "Should list visible_dir"); - - // Should NOT include excluded or private files - assert!( - !content.contains(".secretdir"), - "Should not list .secretdir (file_scan_exclusions)" - ); - assert!( - !content.contains(".mymetadata"), - "Should not list .mymetadata (file_scan_exclusions)" - ); - assert!( - !content.contains(".mysecrets"), - "Should not list .mysecrets (private_files)" - ); - - // Trying to list an excluded directory should fail - let input = json!({ - "path": "project/.secretdir" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await; - - assert!( - result.is_err(), - "Should not be able to list excluded directory" - ); - assert!( - result - .unwrap_err() - .to_string() - .contains("file_scan_exclusions"), - "Error should mention file_scan_exclusions" - ); - - // Listing a directory should exclude private files within it - let input = json!({ - "path": "project/visible_dir" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - - // Should include normal files - assert!(content.contains("normal.txt"), "Should list normal.txt"); - - // Should NOT include private files - assert!( - !content.contains("privatekey"), - "Should not list .privatekey files (private_files)" - ); - assert!( - !content.contains("mysensitive"), - "Should not list .mysensitive files (private_files)" - ); - - // Should NOT include subdirectories that match exclusions - assert!( - !content.contains(".hidden_subdir"), - "Should not list .hidden_subdir (file_scan_exclusions)" - ); - } - - #[gpui::test] - async fn test_list_directory_with_multiple_worktree_settings(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - - // Create first worktree with its own private files - fs.insert_tree( - path!("/worktree1"), - json!({ - ".zed": { - "settings.json": r#"{ - "file_scan_exclusions": ["**/fixture.*"], - "private_files": ["**/secret.rs", "**/config.toml"] - }"# - }, - "src": { - "main.rs": "fn main() { println!(\"Hello from worktree1\"); }", - "secret.rs": "const API_KEY: &str = \"secret_key_1\";", - "config.toml": "[database]\nurl = \"postgres://localhost/db1\"" - }, - "tests": { - "test.rs": "mod tests { fn test_it() {} }", - "fixture.sql": "CREATE TABLE users (id INT, name VARCHAR(255));" - } - }), - ) - .await; - - // Create second worktree with different private files - fs.insert_tree( - path!("/worktree2"), - json!({ - ".zed": { - "settings.json": r#"{ - "file_scan_exclusions": ["**/internal.*"], - "private_files": ["**/private.js", "**/data.json"] - }"# - }, - "lib": { - "public.js": "export function greet() { return 'Hello from worktree2'; }", - "private.js": "const SECRET_TOKEN = \"private_token_2\";", - "data.json": "{\"api_key\": \"json_secret_key\"}" - }, - "docs": { - "README.md": "# Public Documentation", - "internal.md": "# Internal Secrets and Configuration" - } - }), - ) - .await; - - // Set global settings - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings.project.worktree.file_scan_exclusions = - Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]); - settings.project.worktree.private_files = - Some(vec!["**/.env".to_string()].into()); - }); - }); - }); - - let project = Project::test( - fs.clone(), - [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], - cx, - ) - .await; - - // Wait for worktrees to be fully scanned - cx.executor().run_until_parked(); - - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let tool = Arc::new(ListDirectoryTool); - - // Test listing worktree1/src - should exclude secret.rs and config.toml based on local settings - let input = json!({ - "path": "worktree1/src" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - assert!(content.contains("main.rs"), "Should list main.rs"); - assert!( - !content.contains("secret.rs"), - "Should not list secret.rs (local private_files)" - ); - assert!( - !content.contains("config.toml"), - "Should not list config.toml (local private_files)" - ); - - // Test listing worktree1/tests - should exclude fixture.sql based on local settings - let input = json!({ - "path": "worktree1/tests" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - assert!(content.contains("test.rs"), "Should list test.rs"); - assert!( - !content.contains("fixture.sql"), - "Should not list fixture.sql (local file_scan_exclusions)" - ); - - // Test listing worktree2/lib - should exclude private.js and data.json based on local settings - let input = json!({ - "path": "worktree2/lib" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - assert!(content.contains("public.js"), "Should list public.js"); - assert!( - !content.contains("private.js"), - "Should not list private.js (local private_files)" - ); - assert!( - !content.contains("data.json"), - "Should not list data.json (local private_files)" - ); - - // Test listing worktree2/docs - should exclude internal.md based on local settings - let input = json!({ - "path": "worktree2/docs" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - let content = result.content.as_str().unwrap(); - assert!(content.contains("README.md"), "Should list README.md"); - assert!( - !content.contains("internal.md"), - "Should not list internal.md (local file_scan_exclusions)" - ); - - // Test trying to list an excluded directory directly - let input = json!({ - "path": "worktree1/src/secret.rs" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await; - - // This should fail because we're trying to list a file, not a directory - assert!(result.is_err(), "Should fail when trying to list a file"); - } -} diff --git a/crates/assistant_tools/src/list_directory_tool/description.md b/crates/assistant_tools/src/list_directory_tool/description.md deleted file mode 100644 index 30dcc012ff316c944a7495dc14457cfd9df93bb7..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/list_directory_tool/description.md +++ /dev/null @@ -1 +0,0 @@ -Lists files and directories in a given path. Prefer the `grep` or `find_path` tools when searching the codebase. diff --git a/crates/assistant_tools/src/move_path_tool.rs b/crates/assistant_tools/src/move_path_tool.rs deleted file mode 100644 index 22dbe9e625468d8c2688b60bdcd94a7da594730e..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/move_path_tool.rs +++ /dev/null @@ -1,132 +0,0 @@ -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::{path::Path, sync::Arc}; -use ui::IconName; -use util::markdown::MarkdownInlineCode; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct MovePathToolInput { - /// The source path of the file or directory to move/rename. - /// - /// - /// If the project has the following files: - /// - /// - directory1/a/something.txt - /// - directory2/a/things.txt - /// - directory3/a/other.txt - /// - /// You can move the first file by providing a source_path of "directory1/a/something.txt" - /// - pub source_path: String, - - /// The destination path where the file or directory should be moved/renamed to. - /// If the paths are the same except for the filename, then this will be a rename. - /// - /// - /// To move "directory1/a/something.txt" to "directory2/b/renamed.txt", - /// provide a destination_path of "directory2/b/renamed.txt" - /// - pub destination_path: String, -} - -pub struct MovePathTool; - -impl Tool for MovePathTool { - fn name(&self) -> String { - "move_path".into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - true - } - - fn description(&self) -> String { - include_str!("./move_path_tool/description.md").into() - } - - fn icon(&self) -> IconName { - IconName::ArrowRightLeft - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let src = MarkdownInlineCode(&input.source_path); - let dest = MarkdownInlineCode(&input.destination_path); - let src_path = Path::new(&input.source_path); - let dest_path = Path::new(&input.destination_path); - - match dest_path - .file_name() - .and_then(|os_str| os_str.to_os_string().into_string().ok()) - { - Some(filename) if src_path.parent() == dest_path.parent() => { - let filename = MarkdownInlineCode(&filename); - format!("Rename {src} to {filename}") - } - _ => { - format!("Move {src} to {dest}") - } - } - } - Err(_) => "Move path".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - let rename_task = project.update(cx, |project, cx| { - match project - .find_project_path(&input.source_path, cx) - .and_then(|project_path| project.entry_for_path(&project_path, cx)) - { - Some(entity) => match project.find_project_path(&input.destination_path, cx) { - Some(project_path) => project.rename_entry(entity.id, project_path, cx), - None => Task::ready(Err(anyhow!( - "Destination path {} was outside the project.", - input.destination_path - ))), - }, - None => Task::ready(Err(anyhow!( - "Source path {} was not found in the project.", - input.source_path - ))), - } - }); - - cx.background_spawn(async move { - let _ = rename_task.await.with_context(|| { - format!("Moving {} to {}", input.source_path, input.destination_path) - })?; - Ok(format!("Moved {} to {}", input.source_path, input.destination_path).into()) - }) - .into() - } -} diff --git a/crates/assistant_tools/src/move_path_tool/description.md b/crates/assistant_tools/src/move_path_tool/description.md deleted file mode 100644 index 76bc3003d003c44afdd9036cb6691d5fc432291d..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/move_path_tool/description.md +++ /dev/null @@ -1,5 +0,0 @@ -Moves or rename a file or directory in the project, and returns confirmation that the move succeeded. -If the source and destination directories are the same, but the filename is different, this performs -a rename. Otherwise, it performs a move. - -This tool should be used when it's desirable to move or rename a file or directory without changing its contents at all. diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs deleted file mode 100644 index f50ad065d1cd320aa1a82e4ce17f744d6b04be2c..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/now_tool.rs +++ /dev/null @@ -1,84 +0,0 @@ -use std::sync::Arc; - -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use chrono::{Local, Utc}; -use gpui::{AnyWindowHandle, App, Entity, Task}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use ui::IconName; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -#[serde(rename_all = "snake_case")] -pub enum Timezone { - /// Use UTC for the datetime. - Utc, - /// Use local time for the datetime. - Local, -} - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct NowToolInput { - /// The timezone to use for the datetime. - timezone: Timezone, -} - -pub struct NowTool; - -impl Tool for NowTool { - fn name(&self) -> String { - "now".into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn description(&self) -> String { - "Returns the current datetime in RFC 3339 format. Only use this tool when the user specifically asks for it or the current task would benefit from knowing the current datetime.".into() - } - - fn icon(&self) -> IconName { - IconName::Info - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, _input: &serde_json::Value) -> String { - "Get current time".to_string() - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - _project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - _cx: &mut App, - ) -> ToolResult { - let input: NowToolInput = match serde_json::from_value(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - - let now = match input.timezone { - Timezone::Utc => Utc::now().to_rfc3339(), - Timezone::Local => Local::now().to_rfc3339(), - }; - let text = format!("The current datetime is {now}."); - - Task::ready(Ok(text.into())).into() - } -} diff --git a/crates/assistant_tools/src/open_tool.rs b/crates/assistant_tools/src/open_tool.rs deleted file mode 100644 index a1aafad041364b0ffca01cc1890c2cc10b3d7b01..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/open_tool.rs +++ /dev/null @@ -1,170 +0,0 @@ -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::{path::PathBuf, sync::Arc}; -use ui::IconName; -use util::markdown::MarkdownEscaped; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct OpenToolInput { - /// The path or URL to open with the default application. - path_or_url: String, -} - -pub struct OpenTool; - -impl Tool for OpenTool { - fn name(&self) -> String { - "open".to_string() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - true - } - fn may_perform_edits(&self) -> bool { - false - } - fn description(&self) -> String { - include_str!("./open_tool/description.md").to_string() - } - - fn icon(&self) -> IconName { - IconName::ArrowUpRight - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => format!("Open `{}`", MarkdownEscaped(&input.path_or_url)), - Err(_) => "Open file or URL".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let input: OpenToolInput = match serde_json::from_value(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - - // If path_or_url turns out to be a path in the project, make it absolute. - let abs_path = to_absolute_path(&input.path_or_url, project, cx); - - cx.background_spawn(async move { - match abs_path { - Some(path) => open::that(path), - None => open::that(&input.path_or_url), - } - .context("Failed to open URL or file path")?; - - Ok(format!("Successfully opened {}", input.path_or_url).into()) - }) - .into() - } -} - -fn to_absolute_path( - potential_path: &str, - project: Entity, - cx: &mut App, -) -> Option { - let project = project.read(cx); - project - .find_project_path(PathBuf::from(potential_path), cx) - .and_then(|project_path| project.absolute_path(&project_path, cx)) -} - -#[cfg(test)] -mod tests { - use super::*; - use gpui::TestAppContext; - use project::{FakeFs, Project}; - use settings::SettingsStore; - use std::path::Path; - use tempfile::TempDir; - - #[gpui::test] - async fn test_to_absolute_path(cx: &mut TestAppContext) { - init_test(cx); - let temp_dir = TempDir::new().expect("Failed to create temp directory"); - let temp_path = temp_dir.path().to_string_lossy().into_owned(); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - &temp_path, - serde_json::json!({ - "src": { - "main.rs": "fn main() {}", - "lib.rs": "pub fn lib_fn() {}" - }, - "docs": { - "readme.md": "# Project Documentation" - } - }), - ) - .await; - - // Use the temp_path as the root directory, not just its filename - let project = Project::test(fs.clone(), [temp_dir.path()], cx).await; - - // Test cases where the function should return Some - cx.update(|cx| { - // Project-relative paths should return Some - // Create paths using the last segment of the temp path to simulate a project-relative path - let root_dir_name = Path::new(&temp_path) - .file_name() - .unwrap_or_else(|| std::ffi::OsStr::new("temp")) - .to_string_lossy(); - - assert!( - to_absolute_path(&format!("{root_dir_name}/src/main.rs"), project.clone(), cx) - .is_some(), - "Failed to resolve main.rs path" - ); - - assert!( - to_absolute_path( - &format!("{root_dir_name}/docs/readme.md",), - project.clone(), - cx, - ) - .is_some(), - "Failed to resolve readme.md path" - ); - - // External URL should return None - let result = to_absolute_path("https://example.com", project.clone(), cx); - assert_eq!(result, None, "External URLs should return None"); - - // Path outside project - let result = to_absolute_path("../invalid/path", project.clone(), cx); - assert_eq!(result, None, "Paths outside the project should return None"); - }); - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - }); - } -} diff --git a/crates/assistant_tools/src/open_tool/description.md b/crates/assistant_tools/src/open_tool/description.md deleted file mode 100644 index 99ccbb0524473b8c740d6ecd2d9ca9555e1e7028..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/open_tool/description.md +++ /dev/null @@ -1,9 +0,0 @@ -This tool opens a file or URL with the default application associated with it on the user's operating system: -- On macOS, it's equivalent to the `open` command -- On Windows, it's equivalent to `start` -- On Linux, it uses something like `xdg-open`, `gio open`, `gnome-open`, `kde-open`, `wslview` as appropriate - -For example, it can open a web browser with a URL, open a PDF file with the default PDF viewer, etc. - -You MUST ONLY use this tool when the user has explicitly requested opening something. You MUST NEVER assume that -the user would like for you to use this tool. diff --git a/crates/assistant_tools/src/project_notifications_tool.rs b/crates/assistant_tools/src/project_notifications_tool.rs deleted file mode 100644 index e30d80207dae4de1e69efe99724a2a5343b57664..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/project_notifications_tool.rs +++ /dev/null @@ -1,360 +0,0 @@ -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::Result; -use assistant_tool::{Tool, ToolResult}; -use gpui::{AnyWindowHandle, App, Entity, Task}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use std::{fmt::Write, sync::Arc}; -use ui::IconName; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct ProjectUpdatesToolInput {} - -pub struct ProjectNotificationsTool; - -impl Tool for ProjectNotificationsTool { - fn name(&self) -> String { - "project_notifications".to_string() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - fn may_perform_edits(&self) -> bool { - false - } - fn description(&self) -> String { - include_str!("./project_notifications_tool/description.md").to_string() - } - - fn icon(&self) -> IconName { - IconName::ToolNotification - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, _input: &serde_json::Value) -> String { - "Check project notifications".into() - } - - fn run( - self: Arc, - _input: serde_json::Value, - _request: Arc, - _project: Entity, - action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let Some(user_edits_diff) = - action_log.update(cx, |log, cx| log.flush_unnotified_user_edits(cx)) - else { - return result("No new notifications"); - }; - - // NOTE: Changes to this prompt require a symmetric update in the LLM Worker - const HEADER: &str = include_str!("./project_notifications_tool/prompt_header.txt"); - const MAX_BYTES: usize = 8000; - let diff = fit_patch_to_size(&user_edits_diff, MAX_BYTES); - result(&format!("{HEADER}\n\n```diff\n{diff}\n```\n").replace("\r\n", "\n")) - } -} - -fn result(response: &str) -> ToolResult { - Task::ready(Ok(response.to_string().into())).into() -} - -/// Make sure that the patch fits into the size limit (in bytes). -/// Compress the patch by omitting some parts if needed. -/// Unified diff format is assumed. -fn fit_patch_to_size(patch: &str, max_size: usize) -> String { - if patch.len() <= max_size { - return patch.to_string(); - } - - // Compression level 1: remove context lines in diff bodies, but - // leave the counts and positions of inserted/deleted lines - let mut current_size = patch.len(); - let mut file_patches = split_patch(patch); - file_patches.sort_by_key(|patch| patch.len()); - let compressed_patches = file_patches - .iter() - .rev() - .map(|patch| { - if current_size > max_size { - let compressed = compress_patch(patch).unwrap_or_else(|_| patch.to_string()); - current_size -= patch.len() - compressed.len(); - compressed - } else { - patch.to_string() - } - }) - .collect::>(); - - if current_size <= max_size { - return compressed_patches.join("\n\n"); - } - - // Compression level 2: list paths of the changed files only - let filenames = file_patches - .iter() - .map(|patch| { - let patch = diffy::Patch::from_str(patch).unwrap(); - let path = patch - .modified() - .and_then(|path| path.strip_prefix("b/")) - .unwrap_or_default(); - format!("- {path}\n") - }) - .collect::>(); - - filenames.join("") -} - -/// Split a potentially multi-file patch into multiple single-file patches -fn split_patch(patch: &str) -> Vec { - let mut result = Vec::new(); - let mut current_patch = String::new(); - - for line in patch.lines() { - if line.starts_with("---") && !current_patch.is_empty() { - result.push(current_patch.trim_end_matches('\n').into()); - current_patch = String::new(); - } - current_patch.push_str(line); - current_patch.push('\n'); - } - - if !current_patch.is_empty() { - result.push(current_patch.trim_end_matches('\n').into()); - } - - result -} - -fn compress_patch(patch: &str) -> anyhow::Result { - let patch = diffy::Patch::from_str(patch)?; - let mut out = String::new(); - - writeln!(out, "--- {}", patch.original().unwrap_or("a"))?; - writeln!(out, "+++ {}", patch.modified().unwrap_or("b"))?; - - for hunk in patch.hunks() { - writeln!(out, "@@ -{} +{} @@", hunk.old_range(), hunk.new_range())?; - writeln!(out, "[...skipped...]")?; - } - - Ok(out) -} - -#[cfg(test)] -mod tests { - use super::*; - use assistant_tool::ToolResultContent; - use gpui::{AppContext, TestAppContext}; - use indoc::indoc; - use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModelProvider}; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use std::sync::Arc; - use util::path; - - #[gpui::test] - async fn test_stale_buffer_notification(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/test"), - json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}), - ) - .await; - - let project = Project::test(fs, [path!("/test").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - - let buffer_path = project - .read_with(cx, |project, cx| { - project.find_project_path("test/code.rs", cx) - }) - .unwrap(); - - let buffer = project - .update(cx, |project, cx| { - project.open_buffer(buffer_path.clone(), cx) - }) - .await - .unwrap(); - - // Start tracking the buffer - action_log.update(cx, |log, cx| { - log.buffer_read(buffer.clone(), cx); - }); - cx.run_until_parked(); - - // Run the tool before any changes - let tool = Arc::new(ProjectNotificationsTool); - let provider = Arc::new(FakeLanguageModelProvider::default()); - let model: Arc = Arc::new(provider.test_model()); - let request = Arc::new(LanguageModelRequest::default()); - let tool_input = json!({}); - - let result = cx.update(|cx| { - tool.clone().run( - tool_input.clone(), - request.clone(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }); - cx.run_until_parked(); - - let response = result.output.await.unwrap(); - let response_text = match &response.content { - ToolResultContent::Text(text) => text.clone(), - _ => panic!("Expected text response"), - }; - assert_eq!( - response_text.as_str(), - "No new notifications", - "Tool should return 'No new notifications' when no stale buffers" - ); - - // Modify the buffer (makes it stale) - buffer.update(cx, |buffer, cx| { - buffer.edit([(1..1, "\nChange!\n")], None, cx); - }); - cx.run_until_parked(); - - // Run the tool again - let result = cx.update(|cx| { - tool.clone().run( - tool_input.clone(), - request.clone(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }); - cx.run_until_parked(); - - // This time the buffer is stale, so the tool should return a notification - let response = result.output.await.unwrap(); - let response_text = match &response.content { - ToolResultContent::Text(text) => text.clone(), - _ => panic!("Expected text response"), - }; - - assert!( - response_text.contains("These files have changed"), - "Tool should return the stale buffer notification" - ); - assert!( - response_text.contains("test/code.rs"), - "Tool should return the stale buffer notification" - ); - - // Run the tool once more without any changes - should get no new notifications - let result = cx.update(|cx| { - tool.run( - tool_input.clone(), - request.clone(), - project.clone(), - action_log, - model.clone(), - None, - cx, - ) - }); - cx.run_until_parked(); - - let response = result.output.await.unwrap(); - let response_text = match &response.content { - ToolResultContent::Text(text) => text.clone(), - _ => panic!("Expected text response"), - }; - - assert_eq!( - response_text.as_str(), - "No new notifications", - "Tool should return 'No new notifications' when running again without changes" - ); - } - - #[test] - fn test_patch_compression() { - // Given a patch that doesn't fit into the size budget - let patch = indoc! {" - --- a/dir/test.txt - +++ b/dir/test.txt - @@ -1,3 +1,3 @@ - line 1 - -line 2 - +CHANGED - line 3 - @@ -10,2 +10,2 @@ - line 10 - -line 11 - +line eleven - - - --- a/dir/another.txt - +++ b/dir/another.txt - @@ -100,1 +1,1 @@ - -before - +after - "}; - - // When the size deficit can be compensated by dropping the body, - // then the body should be trimmed for larger files first - let limit = patch.len() - 10; - let compressed = fit_patch_to_size(patch, limit); - let expected = indoc! {" - --- a/dir/test.txt - +++ b/dir/test.txt - @@ -1,3 +1,3 @@ - [...skipped...] - @@ -10,2 +10,2 @@ - [...skipped...] - - - --- a/dir/another.txt - +++ b/dir/another.txt - @@ -100,1 +1,1 @@ - -before - +after"}; - assert_eq!(compressed, expected); - - // When the size deficit is too large, then only file paths - // should be returned - let limit = 10; - let compressed = fit_patch_to_size(patch, limit); - let expected = indoc! {" - - dir/another.txt - - dir/test.txt - "}; - assert_eq!(compressed, expected); - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - assistant_tool::init(cx); - }); - } -} diff --git a/crates/assistant_tools/src/project_notifications_tool/description.md b/crates/assistant_tools/src/project_notifications_tool/description.md deleted file mode 100644 index 24ff678f5e7fd728b94ad4ebce06f2a1dcc6a658..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/project_notifications_tool/description.md +++ /dev/null @@ -1,3 +0,0 @@ -This tool reports which files have been modified by the user since the agent last accessed them. - -It serves as a notification mechanism to inform the agent of recent changes. No immediate action is required in response to these updates. diff --git a/crates/assistant_tools/src/project_notifications_tool/prompt_header.txt b/crates/assistant_tools/src/project_notifications_tool/prompt_header.txt deleted file mode 100644 index f743e239c883c7456f7bdc6e089185c6b994cb44..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/project_notifications_tool/prompt_header.txt +++ /dev/null @@ -1,3 +0,0 @@ -[The following is an auto-generated notification; do not reply] - -These files have changed since the last read: diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs deleted file mode 100644 index f9f68491e5846fa1ead09d6976d1f9a9bc99b501..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/read_file_tool.rs +++ /dev/null @@ -1,1190 +0,0 @@ -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use assistant_tool::{ToolResultContent, outline}; -use gpui::{AnyWindowHandle, App, Entity, Task}; -use project::{ImageItem, image_store}; - -use assistant_tool::ToolResultOutput; -use indoc::formatdoc; -use itertools::Itertools; -use language::{Anchor, Point}; -use language_model::{ - LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat, -}; -use project::{AgentLocation, Project, WorktreeSettings}; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use settings::Settings; -use std::sync::Arc; -use ui::IconName; - -/// If the model requests to read a file whose size exceeds this, then -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct ReadFileToolInput { - /// The relative path of the file to read. - /// - /// This path should never be absolute, and the first component - /// of the path should always be a root directory in a project. - /// - /// - /// If the project has the following root directories: - /// - /// - /a/b/directory1 - /// - /c/d/directory2 - /// - /// If you want to access `file.txt` in `directory1`, you should use the path `directory1/file.txt`. - /// If you want to access `file.txt` in `directory2`, you should use the path `directory2/file.txt`. - /// - pub path: String, - - /// Optional line number to start reading on (1-based index) - #[serde(default)] - pub start_line: Option, - - /// Optional line number to end reading on (1-based index, inclusive) - #[serde(default)] - pub end_line: Option, -} - -pub struct ReadFileTool; - -impl Tool for ReadFileTool { - fn name(&self) -> String { - "read_file".into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn description(&self) -> String { - include_str!("./read_file_tool/description.md").into() - } - - fn icon(&self) -> IconName { - IconName::ToolSearch - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let path = &input.path; - match (input.start_line, input.end_line) { - (Some(start), Some(end)) => { - format!( - "[Read file `{}` (lines {}-{})](@selection:{}:({}-{}))", - path, start, end, path, start, end - ) - } - (Some(start), None) => { - format!( - "[Read file `{}` (from line {})](@selection:{}:({}-{}))", - path, start, path, start, start - ) - } - _ => format!("[Read file `{}`](@file:{})", path, path), - } - } - Err(_) => "Read file".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - action_log: Entity, - model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - - let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { - return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into(); - }; - - // Error out if this path is either excluded or private in global settings - let global_settings = WorktreeSettings::get_global(cx); - if global_settings.is_path_excluded(&project_path.path) { - return Task::ready(Err(anyhow!( - "Cannot read file because its path matches the global `file_scan_exclusions` setting: {}", - &input.path - ))) - .into(); - } - - if global_settings.is_path_private(&project_path.path) { - return Task::ready(Err(anyhow!( - "Cannot read file because its path matches the global `private_files` setting: {}", - &input.path - ))) - .into(); - } - - // Error out if this path is either excluded or private in worktree settings - let worktree_settings = WorktreeSettings::get(Some((&project_path).into()), cx); - if worktree_settings.is_path_excluded(&project_path.path) { - return Task::ready(Err(anyhow!( - "Cannot read file because its path matches the worktree `file_scan_exclusions` setting: {}", - &input.path - ))) - .into(); - } - - if worktree_settings.is_path_private(&project_path.path) { - return Task::ready(Err(anyhow!( - "Cannot read file because its path matches the worktree `private_files` setting: {}", - &input.path - ))) - .into(); - } - - let file_path = input.path.clone(); - - if image_store::is_image_file(&project, &project_path, cx) { - if !model.supports_images() { - return Task::ready(Err(anyhow!( - "Attempted to read an image, but Zed doesn't currently support sending images to {}.", - model.name().0 - ))) - .into(); - } - - let task = cx.spawn(async move |cx| -> Result { - let image_entity: Entity = cx - .update(|cx| { - project.update(cx, |project, cx| { - project.open_image(project_path.clone(), cx) - }) - })? - .await?; - - let image = - image_entity.read_with(cx, |image_item, _| Arc::clone(&image_item.image))?; - - let language_model_image = cx - .update(|cx| LanguageModelImage::from_image(image, cx))? - .await - .context("processing image")?; - - Ok(ToolResultOutput { - content: ToolResultContent::Image(language_model_image), - output: None, - }) - }); - - return task.into(); - } - - cx.spawn(async move |cx| { - let buffer = cx - .update(|cx| { - project.update(cx, |project, cx| project.open_buffer(project_path, cx)) - })? - .await?; - if buffer.read_with(cx, |buffer, _| { - buffer - .file() - .as_ref() - .is_none_or(|file| !file.disk_state().exists()) - })? { - anyhow::bail!("{file_path} not found"); - } - - project.update(cx, |project, cx| { - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position: Anchor::MIN, - }), - cx, - ); - })?; - - // Check if specific line ranges are provided - if input.start_line.is_some() || input.end_line.is_some() { - let mut anchor = None; - let result = buffer.read_with(cx, |buffer, _cx| { - let text = buffer.text(); - // .max(1) because despite instructions to be 1-indexed, sometimes the model passes 0. - let start = input.start_line.unwrap_or(1).max(1); - let start_row = start - 1; - if start_row <= buffer.max_point().row { - let column = buffer.line_indent_for_row(start_row).raw_len(); - anchor = Some(buffer.anchor_before(Point::new(start_row, column))); - } - - let lines = text.split('\n').skip(start_row as usize); - if let Some(end) = input.end_line { - let count = end.saturating_sub(start).saturating_add(1); // Ensure at least 1 line - Itertools::intersperse(lines.take(count as usize), "\n") - .collect::() - .into() - } else { - Itertools::intersperse(lines, "\n") - .collect::() - .into() - } - })?; - - action_log.update(cx, |log, cx| { - log.buffer_read(buffer.clone(), cx); - })?; - - if let Some(anchor) = anchor { - project.update(cx, |project, cx| { - project.set_agent_location( - Some(AgentLocation { - buffer: buffer.downgrade(), - position: anchor, - }), - cx, - ); - })?; - } - - Ok(result) - } else { - // No line ranges specified, so check file size to see if it's too big. - let buffer_content = - outline::get_buffer_content_or_outline(buffer.clone(), Some(&file_path), cx) - .await?; - - action_log.update(cx, |log, cx| { - log.buffer_read(buffer, cx); - })?; - - if buffer_content.is_outline { - Ok(formatdoc! {" - This file was too big to read all at once. - - {} - - Using the line numbers in this outline, you can call this tool again - while specifying the start_line and end_line fields to see the - implementations of symbols in the outline. - - Alternatively, you can fall back to the `grep` tool (if available) - to search the file for specific content.", buffer_content.text - } - .into()) - } else { - Ok(buffer_content.text.into()) - } - } - }) - .into() - } -} - -#[cfg(test)] -mod test { - use super::*; - use gpui::{AppContext, TestAppContext, UpdateGlobal}; - use language::{Language, LanguageConfig, LanguageMatcher}; - use language_model::fake_provider::FakeLanguageModel; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use util::path; - - #[gpui::test] - async fn test_read_nonexistent_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree(path!("/root"), json!({})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let result = cx - .update(|cx| { - let input = json!({ - "path": "root/nonexistent_file.txt" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log, - model, - None, - cx, - ) - .output - }) - .await; - assert_eq!( - result.unwrap_err().to_string(), - "root/nonexistent_file.txt not found" - ); - } - - #[gpui::test] - async fn test_read_small_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "small_file.txt": "This is a small file content" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let result = cx - .update(|cx| { - let input = json!({ - "path": "root/small_file.txt" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log, - model, - None, - cx, - ) - .output - }) - .await; - assert_eq!( - result.unwrap().content.as_str(), - Some("This is a small file content") - ); - } - - #[gpui::test] - async fn test_read_large_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n a: u32,\n b: usize,\n}}", i)).collect::>().join("\n") - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - language_registry.add(Arc::new(rust_lang())); - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - - let result = cx - .update(|cx| { - let input = json!({ - "path": "root/large_file.rs" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - let content = result.unwrap(); - let content = content.as_str().unwrap(); - assert_eq!( - content.lines().skip(4).take(6).collect::>(), - vec![ - "struct Test0 [L1-4]", - " a [L2]", - " b [L3]", - "struct Test1 [L5-8]", - " a [L6]", - " b [L7]", - ] - ); - - let result = cx - .update(|cx| { - let input = json!({ - "path": "root/large_file.rs", - "offset": 1 - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log, - model, - None, - cx, - ) - .output - }) - .await; - let content = result.unwrap(); - let expected_content = (0..1000) - .flat_map(|i| { - vec![ - format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4), - format!(" a [L{}]", i * 4 + 2), - format!(" b [L{}]", i * 4 + 3), - ] - }) - .collect::>(); - pretty_assertions::assert_eq!( - content - .as_str() - .unwrap() - .lines() - .skip(4) - .take(expected_content.len()) - .collect::>(), - expected_content - ); - } - - #[gpui::test] - async fn test_read_file_with_line_range(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let result = cx - .update(|cx| { - let input = json!({ - "path": "root/multiline.txt", - "start_line": 2, - "end_line": 4 - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log, - model, - None, - cx, - ) - .output - }) - .await; - assert_eq!( - result.unwrap().content.as_str(), - Some("Line 2\nLine 3\nLine 4") - ); - } - - #[gpui::test] - async fn test_read_file_line_range_edge_cases(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - - // start_line of 0 should be treated as 1 - let result = cx - .update(|cx| { - let input = json!({ - "path": "root/multiline.txt", - "start_line": 0, - "end_line": 2 - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert_eq!(result.unwrap().content.as_str(), Some("Line 1\nLine 2")); - - // end_line of 0 should result in at least 1 line - let result = cx - .update(|cx| { - let input = json!({ - "path": "root/multiline.txt", - "start_line": 1, - "end_line": 0 - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert_eq!(result.unwrap().content.as_str(), Some("Line 1")); - - // when start_line > end_line, should still return at least 1 line - let result = cx - .update(|cx| { - let input = json!({ - "path": "root/multiline.txt", - "start_line": 3, - "end_line": 2 - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log, - model, - None, - cx, - ) - .output - }) - .await; - assert_eq!(result.unwrap().content.as_str(), Some("Line 3")); - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - }); - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query( - r#" - (line_comment) @annotation - - (struct_item - "struct" @context - name: (_) @name) @item - (enum_item - "enum" @context - name: (_) @name) @item - (enum_variant - name: (_) @name) @item - (field_declaration - name: (_) @name) @item - (impl_item - "impl" @context - trait: (_)? @name - "for"? @context - type: (_) @name - body: (_ "{" (_)* "}")) @item - (function_item - "fn" @context - name: (_) @name) @item - (mod_item - "mod" @context - name: (_) @name) @item - "#, - ) - .unwrap() - } - - #[gpui::test] - async fn test_read_file_security(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - - fs.insert_tree( - path!("/"), - json!({ - "project_root": { - "allowed_file.txt": "This file is in the project", - ".mysecrets": "SECRET_KEY=abc123", - ".secretdir": { - "config": "special configuration" - }, - ".mymetadata": "custom metadata", - "subdir": { - "normal_file.txt": "Normal file content", - "special.privatekey": "private key content", - "data.mysensitive": "sensitive data" - } - }, - "outside_project": { - "sensitive_file.txt": "This file is outside the project" - } - }), - ) - .await; - - cx.update(|cx| { - use gpui::UpdateGlobal; - use settings::SettingsStore; - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings.project.worktree.file_scan_exclusions = Some(vec![ - "**/.secretdir".to_string(), - "**/.mymetadata".to_string(), - ]); - settings.project.worktree.private_files = Some( - vec![ - "**/.mysecrets".to_string(), - "**/*.privatekey".to_string(), - "**/*.mysensitive".to_string(), - ] - .into(), - ); - }); - }); - }); - - let project = Project::test(fs.clone(), [path!("/project_root").as_ref()], cx).await; - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - - // Reading a file outside the project worktree should fail - let result = cx - .update(|cx| { - let input = json!({ - "path": "/outside_project/sensitive_file.txt" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert!( - result.is_err(), - "read_file_tool should error when attempting to read an absolute path outside a worktree" - ); - - // Reading a file within the project should succeed - let result = cx - .update(|cx| { - let input = json!({ - "path": "project_root/allowed_file.txt" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert!( - result.is_ok(), - "read_file_tool should be able to read files inside worktrees" - ); - - // Reading files that match file_scan_exclusions should fail - let result = cx - .update(|cx| { - let input = json!({ - "path": "project_root/.secretdir/config" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert!( - result.is_err(), - "read_file_tool should error when attempting to read files in .secretdir (file_scan_exclusions)" - ); - - let result = cx - .update(|cx| { - let input = json!({ - "path": "project_root/.mymetadata" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert!( - result.is_err(), - "read_file_tool should error when attempting to read .mymetadata files (file_scan_exclusions)" - ); - - // Reading private files should fail - let result = cx - .update(|cx| { - let input = json!({ - "path": "project_root/.mysecrets" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert!( - result.is_err(), - "read_file_tool should error when attempting to read .mysecrets (private_files)" - ); - - let result = cx - .update(|cx| { - let input = json!({ - "path": "project_root/subdir/special.privatekey" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert!( - result.is_err(), - "read_file_tool should error when attempting to read .privatekey files (private_files)" - ); - - let result = cx - .update(|cx| { - let input = json!({ - "path": "project_root/subdir/data.mysensitive" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert!( - result.is_err(), - "read_file_tool should error when attempting to read .mysensitive files (private_files)" - ); - - // Reading a normal file should still work, even with private_files configured - let result = cx - .update(|cx| { - let input = json!({ - "path": "project_root/subdir/normal_file.txt" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert!(result.is_ok(), "Should be able to read normal files"); - assert_eq!( - result.unwrap().content.as_str().unwrap(), - "Normal file content" - ); - - // Path traversal attempts with .. should fail - let result = cx - .update(|cx| { - let input = json!({ - "path": "project_root/../outside_project/sensitive_file.txt" - }); - Arc::new(ReadFileTool) - .run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - .output - }) - .await; - assert!( - result.is_err(), - "read_file_tool should error when attempting to read a relative path that resolves to outside a worktree" - ); - } - - #[gpui::test] - async fn test_read_file_with_multiple_worktree_settings(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.executor()); - - // Create first worktree with its own private_files setting - fs.insert_tree( - path!("/worktree1"), - json!({ - "src": { - "main.rs": "fn main() { println!(\"Hello from worktree1\"); }", - "secret.rs": "const API_KEY: &str = \"secret_key_1\";", - "config.toml": "[database]\nurl = \"postgres://localhost/db1\"" - }, - "tests": { - "test.rs": "mod tests { fn test_it() {} }", - "fixture.sql": "CREATE TABLE users (id INT, name VARCHAR(255));" - }, - ".zed": { - "settings.json": r#"{ - "file_scan_exclusions": ["**/fixture.*"], - "private_files": ["**/secret.rs", "**/config.toml"] - }"# - } - }), - ) - .await; - - // Create second worktree with different private_files setting - fs.insert_tree( - path!("/worktree2"), - json!({ - "lib": { - "public.js": "export function greet() { return 'Hello from worktree2'; }", - "private.js": "const SECRET_TOKEN = \"private_token_2\";", - "data.json": "{\"api_key\": \"json_secret_key\"}" - }, - "docs": { - "README.md": "# Public Documentation", - "internal.md": "# Internal Secrets and Configuration" - }, - ".zed": { - "settings.json": r#"{ - "file_scan_exclusions": ["**/internal.*"], - "private_files": ["**/private.js", "**/data.json"] - }"# - } - }), - ) - .await; - - // Set global settings - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings.project.worktree.file_scan_exclusions = - Some(vec!["**/.git".to_string(), "**/node_modules".to_string()]); - settings.project.worktree.private_files = - Some(vec!["**/.env".to_string()].into()); - }); - }); - }); - - let project = Project::test( - fs.clone(), - [path!("/worktree1").as_ref(), path!("/worktree2").as_ref()], - cx, - ) - .await; - - let action_log = cx.new(|_| ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let tool = Arc::new(ReadFileTool); - - // Test reading allowed files in worktree1 - let input = json!({ - "path": "worktree1/src/main.rs" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - assert_eq!( - result.content.as_str().unwrap(), - "fn main() { println!(\"Hello from worktree1\"); }" - ); - - // Test reading private file in worktree1 should fail - let input = json!({ - "path": "worktree1/src/secret.rs" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await; - - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("worktree `private_files` setting"), - "Error should mention worktree private_files setting" - ); - - // Test reading excluded file in worktree1 should fail - let input = json!({ - "path": "worktree1/tests/fixture.sql" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await; - - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("worktree `file_scan_exclusions` setting"), - "Error should mention worktree file_scan_exclusions setting" - ); - - // Test reading allowed files in worktree2 - let input = json!({ - "path": "worktree2/lib/public.js" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await - .unwrap(); - - assert_eq!( - result.content.as_str().unwrap(), - "export function greet() { return 'Hello from worktree2'; }" - ); - - // Test reading private file in worktree2 should fail - let input = json!({ - "path": "worktree2/lib/private.js" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await; - - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("worktree `private_files` setting"), - "Error should mention worktree private_files setting" - ); - - // Test reading excluded file in worktree2 should fail - let input = json!({ - "path": "worktree2/docs/internal.md" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await; - - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("worktree `file_scan_exclusions` setting"), - "Error should mention worktree file_scan_exclusions setting" - ); - - // Test that files allowed in one worktree but not in another are handled correctly - // (e.g., config.toml is private in worktree1 but doesn't exist in worktree2) - let input = json!({ - "path": "worktree1/src/config.toml" - }); - - let result = cx - .update(|cx| { - tool.clone().run( - input, - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }) - .output - .await; - - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("worktree `private_files` setting"), - "Config.toml should be blocked by worktree1's private_files setting" - ); - } -} diff --git a/crates/assistant_tools/src/read_file_tool/description.md b/crates/assistant_tools/src/read_file_tool/description.md deleted file mode 100644 index 7bcebc03341541496ab090090ab7ef8beb3f2ebe..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/read_file_tool/description.md +++ /dev/null @@ -1,3 +0,0 @@ -Reads the content of the given file in the project. - -- Never attempt to read a path that hasn't been previously mentioned. diff --git a/crates/assistant_tools/src/schema.rs b/crates/assistant_tools/src/schema.rs deleted file mode 100644 index dab7384efd8ba23669db645c87dcf79e95538d3a..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/schema.rs +++ /dev/null @@ -1,60 +0,0 @@ -use anyhow::Result; -use language_model::LanguageModelToolSchemaFormat; -use schemars::{ - JsonSchema, Schema, - generate::SchemaSettings, - transform::{Transform, transform_subschemas}, -}; - -pub fn json_schema_for( - format: LanguageModelToolSchemaFormat, -) -> Result { - let schema = root_schema_for::(format); - schema_to_json(&schema, format) -} - -fn schema_to_json( - schema: &Schema, - format: LanguageModelToolSchemaFormat, -) -> Result { - let mut value = serde_json::to_value(schema)?; - assistant_tool::adapt_schema_to_format(&mut value, format)?; - Ok(value) -} - -fn root_schema_for(format: LanguageModelToolSchemaFormat) -> Schema { - let mut generator = match format { - LanguageModelToolSchemaFormat::JsonSchema => SchemaSettings::draft07().into_generator(), - LanguageModelToolSchemaFormat::JsonSchemaSubset => SchemaSettings::openapi3() - .with(|settings| { - settings.meta_schema = None; - settings.inline_subschemas = true; - }) - .with_transform(ToJsonSchemaSubsetTransform) - .into_generator(), - }; - generator.root_schema_for::() -} - -#[derive(Debug, Clone)] -struct ToJsonSchemaSubsetTransform; - -impl Transform for ToJsonSchemaSubsetTransform { - fn transform(&mut self, schema: &mut Schema) { - // Ensure that the type field is not an array, this happens when we use - // Option, the type will be [T, "null"]. - if let Some(type_field) = schema.get_mut("type") - && let Some(types) = type_field.as_array() - && let Some(first_type) = types.first() - { - *type_field = first_type.clone(); - } - - // oneOf is not supported, use anyOf instead - if let Some(one_of) = schema.remove("oneOf") { - schema.insert("anyOf".to_string(), one_of); - } - - transform_subschemas(self, schema); - } -} diff --git a/crates/assistant_tools/src/templates.rs b/crates/assistant_tools/src/templates.rs deleted file mode 100644 index c83601199cca11e7a92f07e4159ac6241378d725..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/templates.rs +++ /dev/null @@ -1,32 +0,0 @@ -use anyhow::Result; -use handlebars::Handlebars; -use rust_embed::RustEmbed; -use serde::Serialize; -use std::sync::Arc; - -#[derive(RustEmbed)] -#[folder = "src/templates"] -#[include = "*.hbs"] -struct Assets; - -pub struct Templates(Handlebars<'static>); - -impl Templates { - pub fn new() -> Arc { - let mut handlebars = Handlebars::new(); - handlebars.register_embed_templates::().unwrap(); - handlebars.register_escape_fn(|text| text.into()); - Arc::new(Self(handlebars)) - } -} - -pub trait Template: Sized { - const TEMPLATE_NAME: &'static str; - - fn render(&self, templates: &Templates) -> Result - where - Self: Serialize + Sized, - { - Ok(templates.0.render(Self::TEMPLATE_NAME, self)?) - } -} diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs deleted file mode 100644 index cab1498c0bfda186e3d52c7bce02b8f457d4fd85..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/terminal_tool.rs +++ /dev/null @@ -1,883 +0,0 @@ -use crate::{ - schema::json_schema_for, - ui::{COLLAPSED_LINES, ToolOutputPreview}, -}; -use action_log::ActionLog; -use agent_settings; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{Tool, ToolCard, ToolResult, ToolUseStatus}; -use futures::FutureExt as _; -use gpui::{ - AnyWindowHandle, App, AppContext, Empty, Entity, EntityId, Task, TextStyleRefinement, - WeakEntity, Window, -}; -use language::LineEnding; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use markdown::{Markdown, MarkdownElement, MarkdownStyle}; -use portable_pty::{CommandBuilder, PtySize, native_pty_system}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsLocation}; -use std::{ - env, - path::{Path, PathBuf}, - process::ExitStatus, - sync::Arc, - time::{Duration, Instant}, -}; -use task::{Shell, ShellBuilder}; -use terminal::terminal_settings::TerminalSettings; -use terminal_view::TerminalView; -use theme::ThemeSettings; -use ui::{CommonAnimationExt, Disclosure, Tooltip, prelude::*}; -use util::{ - ResultExt, get_default_system_shell_preferring_bash, markdown::MarkdownInlineCode, - size::format_file_size, time::duration_alt_display, -}; -use workspace::Workspace; - -const COMMAND_OUTPUT_LIMIT: usize = 16 * 1024; - -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -pub struct TerminalToolInput { - /// The one-liner command to execute. - command: String, - /// Working directory for the command. This must be one of the root directories of the project. - cd: String, -} - -pub struct TerminalTool; - -impl TerminalTool { - pub const NAME: &str = "terminal"; -} - -impl Tool for TerminalTool { - fn name(&self) -> String { - Self::NAME.to_string() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - true - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn description(&self) -> String { - include_str!("./terminal_tool/description.md").to_string() - } - - fn icon(&self) -> IconName { - IconName::ToolTerminal - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let mut lines = input.command.lines(); - let first_line = lines.next().unwrap_or_default(); - let remaining_line_count = lines.count(); - match remaining_line_count { - 0 => MarkdownInlineCode(first_line).to_string(), - 1 => MarkdownInlineCode(&format!( - "{} - {} more line", - first_line, remaining_line_count - )) - .to_string(), - n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n)) - .to_string(), - } - } - Err(_) => "Run terminal command".to_string(), - } - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - project: Entity, - _action_log: Entity, - _model: Arc, - window: Option, - cx: &mut App, - ) -> ToolResult { - let input: TerminalToolInput = match serde_json::from_value(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - - let working_dir = match working_dir(&input, &project, cx) { - Ok(dir) => dir, - Err(err) => return Task::ready(Err(err)).into(), - }; - - let cwd = working_dir.clone(); - let env = match &cwd { - Some(dir) => project.update(cx, |project, cx| { - let worktree = project.find_worktree(dir.as_path(), cx); - let shell = TerminalSettings::get( - worktree.as_ref().map(|(worktree, path)| SettingsLocation { - worktree_id: worktree.read(cx).id(), - path: &path, - }), - cx, - ) - .shell - .clone(); - project.directory_environment(&shell, dir.as_path().into(), cx) - }), - None => Task::ready(None).shared(), - }; - let is_windows = project.read(cx).path_style(cx).is_windows(); - let shell = project - .update(cx, |project, cx| { - project - .remote_client() - .and_then(|r| r.read(cx).default_system_shell()) - }) - .unwrap_or_else(|| get_default_system_shell_preferring_bash()); - - let env = cx.spawn(async move |_| { - let mut env = env.await.unwrap_or_default(); - if cfg!(unix) { - env.insert("PAGER".into(), "cat".into()); - } - env - }); - - let build_cmd = { - let input_command = input.command.clone(); - move || { - ShellBuilder::new(&Shell::Program(shell), is_windows) - .redirect_stdin_to_dev_null() - .build(Some(input_command), &[]) - } - }; - - let Some(window) = window else { - // Headless setup, a test or eval. Our terminal subsystem requires a workspace, - // so bypass it and provide a convincing imitation using a pty. - let task = cx.background_spawn(async move { - let env = env.await; - let pty_system = native_pty_system(); - let (command, args) = build_cmd(); - let mut cmd = CommandBuilder::new(command); - cmd.args(args); - for (k, v) in env { - cmd.env(k, v); - } - if let Some(cwd) = cwd { - cmd.cwd(cwd); - } - let pair = pty_system.openpty(PtySize { - rows: 24, - cols: 80, - ..Default::default() - })?; - let mut child = pair.slave.spawn_command(cmd)?; - let mut reader = pair.master.try_clone_reader()?; - drop(pair); - let mut content = String::new(); - reader.read_to_string(&mut content)?; - // Massage the pty output a bit to try to match what the terminal codepath gives us - LineEnding::normalize(&mut content); - content = content - .chars() - .filter(|c| c.is_ascii_whitespace() || !c.is_ascii_control()) - .collect(); - let content = content.trim_start().trim_start_matches("^D"); - let exit_status = child.wait()?; - let (processed_content, _) = - process_content(content, &input.command, Some(exit_status)); - Ok(processed_content.into()) - }); - return ToolResult { - output: task, - card: None, - }; - }; - - let terminal = cx.spawn({ - let project = project.downgrade(); - async move |cx| { - let (command, args) = build_cmd(); - let env = env.await; - project - .update(cx, |project, cx| { - project.create_terminal_task( - task::SpawnInTerminal { - command: Some(command), - args, - cwd, - env, - ..Default::default() - }, - cx, - ) - })? - .await - } - }); - - let command_markdown = cx.new(|cx| { - Markdown::new( - format!("```bash\n{}\n```", input.command).into(), - None, - None, - cx, - ) - }); - - let card = - cx.new(|cx| TerminalToolCard::new(command_markdown, working_dir, cx.entity_id(), cx)); - - let output = cx.spawn({ - let card = card.clone(); - async move |cx| { - let terminal = terminal.await?; - let workspace = window - .downcast::() - .and_then(|handle| handle.entity(cx).ok()) - .context("no workspace entity in root of window")?; - - let terminal_view = window.update(cx, |_, window, cx| { - cx.new(|cx| { - let mut view = TerminalView::new( - terminal.clone(), - workspace.downgrade(), - None, - project.downgrade(), - window, - cx, - ); - view.set_embedded_mode(None, cx); - view - }) - })?; - - card.update(cx, |card, _| { - card.terminal = Some(terminal_view.clone()); - card.start_instant = Instant::now(); - }) - .log_err(); - - let exit_status = terminal - .update(cx, |terminal, cx| terminal.wait_for_completed_task(cx))? - .await; - let (content, content_line_count) = terminal.read_with(cx, |terminal, _| { - (terminal.get_content(), terminal.total_lines()) - })?; - - let previous_len = content.len(); - let (processed_content, finished_with_empty_output) = process_content( - &content, - &input.command, - exit_status.map(portable_pty::ExitStatus::from), - ); - - card.update(cx, |card, _| { - card.command_finished = true; - card.exit_status = exit_status; - card.was_content_truncated = processed_content.len() < previous_len; - card.original_content_len = previous_len; - card.content_line_count = content_line_count; - card.finished_with_empty_output = finished_with_empty_output; - card.elapsed_time = Some(card.start_instant.elapsed()); - }) - .log_err(); - - Ok(processed_content.into()) - } - }); - - ToolResult { - output, - card: Some(card.into()), - } - } -} - -fn process_content( - content: &str, - command: &str, - exit_status: Option, -) -> (String, bool) { - let should_truncate = content.len() > COMMAND_OUTPUT_LIMIT; - - let content = if should_truncate { - let mut end_ix = COMMAND_OUTPUT_LIMIT.min(content.len()); - while !content.is_char_boundary(end_ix) { - end_ix -= 1; - } - // Don't truncate mid-line, clear the remainder of the last line - end_ix = content[..end_ix].rfind('\n').unwrap_or(end_ix); - &content[..end_ix] - } else { - content - }; - let content = content.trim(); - let is_empty = content.is_empty(); - let content = format!("```\n{content}\n```"); - let content = if should_truncate { - format!( - "Command output too long. The first {} bytes:\n\n{content}", - content.len(), - ) - } else { - content - }; - - let content = match exit_status { - Some(exit_status) if exit_status.success() => { - if is_empty { - "Command executed successfully.".to_string() - } else { - content - } - } - Some(exit_status) => { - if is_empty { - format!( - "Command \"{command}\" failed with exit code {}.", - exit_status.exit_code() - ) - } else { - format!( - "Command \"{command}\" failed with exit code {}.\n\n{content}", - exit_status.exit_code() - ) - } - } - None => { - format!( - "Command failed or was interrupted.\nPartial output captured:\n\n{}", - content, - ) - } - }; - (content, is_empty) -} - -fn working_dir( - input: &TerminalToolInput, - project: &Entity, - cx: &mut App, -) -> Result> { - let project = project.read(cx); - let cd = &input.cd; - - if cd == "." || cd.is_empty() { - // Accept "." or "" as meaning "the one worktree" if we only have one worktree. - let mut worktrees = project.worktrees(cx); - - match worktrees.next() { - Some(worktree) => { - anyhow::ensure!( - worktrees.next().is_none(), - "'.' is ambiguous in multi-root workspaces. Please specify a root directory explicitly.", - ); - Ok(Some(worktree.read(cx).abs_path().to_path_buf())) - } - None => Ok(None), - } - } else { - let input_path = Path::new(cd); - - if input_path.is_absolute() { - // Absolute paths are allowed, but only if they're in one of the project's worktrees. - if project - .worktrees(cx) - .any(|worktree| input_path.starts_with(&worktree.read(cx).abs_path())) - { - return Ok(Some(input_path.into())); - } - } else if let Some(worktree) = project.worktree_for_root_name(cd, cx) { - return Ok(Some(worktree.read(cx).abs_path().to_path_buf())); - } - - anyhow::bail!("`cd` directory {cd:?} was not in any of the project's worktrees."); - } -} - -struct TerminalToolCard { - input_command: Entity, - working_dir: Option, - entity_id: EntityId, - exit_status: Option, - terminal: Option>, - command_finished: bool, - was_content_truncated: bool, - finished_with_empty_output: bool, - content_line_count: usize, - original_content_len: usize, - preview_expanded: bool, - start_instant: Instant, - elapsed_time: Option, -} - -impl TerminalToolCard { - pub fn new( - input_command: Entity, - working_dir: Option, - entity_id: EntityId, - cx: &mut Context, - ) -> Self { - let expand_terminal_card = - agent_settings::AgentSettings::get_global(cx).expand_terminal_card; - Self { - input_command, - working_dir, - entity_id, - exit_status: None, - terminal: None, - command_finished: false, - was_content_truncated: false, - finished_with_empty_output: false, - original_content_len: 0, - content_line_count: 0, - preview_expanded: expand_terminal_card, - start_instant: Instant::now(), - elapsed_time: None, - } - } -} - -impl ToolCard for TerminalToolCard { - fn render( - &mut self, - status: &ToolUseStatus, - window: &mut Window, - _workspace: WeakEntity, - cx: &mut Context, - ) -> impl IntoElement { - let Some(terminal) = self.terminal.as_ref() else { - return Empty.into_any(); - }; - - let tool_failed = matches!(status, ToolUseStatus::Error(_)); - - let command_failed = - self.command_finished && self.exit_status.is_none_or(|code| !code.success()); - - if (tool_failed || command_failed) && self.elapsed_time.is_none() { - self.elapsed_time = Some(self.start_instant.elapsed()); - } - let time_elapsed = self - .elapsed_time - .unwrap_or_else(|| self.start_instant.elapsed()); - - let header_bg = cx - .theme() - .colors() - .element_background - .blend(cx.theme().colors().editor_foreground.opacity(0.025)); - - let border_color = cx.theme().colors().border.opacity(0.6); - - let path = self - .working_dir - .as_ref() - .cloned() - .or_else(|| env::current_dir().ok()) - .map(|path| path.display().to_string()) - .unwrap_or_else(|| "current directory".to_string()); - - let header = h_flex() - .flex_none() - .gap_1() - .justify_between() - .rounded_t_md() - .child( - div() - .id(("command-target-path", self.entity_id)) - .w_full() - .max_w_full() - .overflow_x_scroll() - .child( - Label::new(path) - .buffer_font(cx) - .size(LabelSize::XSmall) - .color(Color::Muted), - ), - ) - .when(!self.command_finished, |header| { - header.child( - Icon::new(IconName::ArrowCircle) - .size(IconSize::XSmall) - .color(Color::Info) - .with_rotate_animation(2), - ) - }) - .when(tool_failed || command_failed, |header| { - header.child( - div() - .id(("terminal-tool-error-code-indicator", self.entity_id)) - .child( - Icon::new(IconName::Close) - .size(IconSize::Small) - .color(Color::Error), - ) - .when(command_failed && self.exit_status.is_some(), |this| { - this.tooltip(Tooltip::text(format!( - "Exited with code {}", - self.exit_status - .and_then(|status| status.code()) - .unwrap_or(-1), - ))) - }) - .when( - !command_failed && tool_failed && status.error().is_some(), - |this| { - this.tooltip(Tooltip::text(format!( - "Error: {}", - status.error().unwrap(), - ))) - }, - ), - ) - }) - .when(self.was_content_truncated, |header| { - let tooltip = if self.content_line_count + 10 > terminal::MAX_SCROLL_HISTORY_LINES { - "Output exceeded terminal max lines and was \ - truncated, the model received the first 16 KB." - .to_string() - } else { - format!( - "Output is {} long, to avoid unexpected token usage, \ - only 16 KB was sent back to the model.", - format_file_size(self.original_content_len as u64, true), - ) - }; - header.child( - h_flex() - .id(("terminal-tool-truncated-label", self.entity_id)) - .tooltip(Tooltip::text(tooltip)) - .gap_1() - .child( - Icon::new(IconName::Info) - .size(IconSize::XSmall) - .color(Color::Ignored), - ) - .child( - Label::new("Truncated") - .color(Color::Muted) - .size(LabelSize::Small), - ), - ) - }) - .when(time_elapsed > Duration::from_secs(10), |header| { - header.child( - Label::new(format!("({})", duration_alt_display(time_elapsed))) - .buffer_font(cx) - .color(Color::Muted) - .size(LabelSize::Small), - ) - }) - .when(!self.finished_with_empty_output, |header| { - header.child( - Disclosure::new( - ("terminal-tool-disclosure", self.entity_id), - self.preview_expanded, - ) - .opened_icon(IconName::ChevronUp) - .closed_icon(IconName::ChevronDown) - .on_click(cx.listener( - move |this, _event, _window, _cx| { - this.preview_expanded = !this.preview_expanded; - }, - )), - ) - }); - - v_flex() - .mb_2() - .border_1() - .when(tool_failed || command_failed, |card| card.border_dashed()) - .border_color(border_color) - .rounded_lg() - .overflow_hidden() - .child( - v_flex() - .p_2() - .gap_0p5() - .bg(header_bg) - .text_xs() - .child(header) - .child( - MarkdownElement::new( - self.input_command.clone(), - markdown_style(window, cx), - ) - .code_block_renderer( - markdown::CodeBlockRenderer::Default { - copy_button: false, - copy_button_on_hover: true, - border: false, - }, - ), - ), - ) - .when( - self.preview_expanded && !self.finished_with_empty_output, - |this| { - this.child( - div() - .pt_2() - .border_t_1() - .when(tool_failed || command_failed, |card| card.border_dashed()) - .border_color(border_color) - .bg(cx.theme().colors().editor_background) - .rounded_b_md() - .text_ui_sm(cx) - .child({ - let content_mode = terminal.read(cx).content_mode(window, cx); - - if content_mode.is_scrollable() { - div().h_72().child(terminal.clone()).into_any_element() - } else { - ToolOutputPreview::new( - terminal.clone().into_any_element(), - terminal.entity_id(), - ) - .with_total_lines(self.content_line_count) - .toggle_state(!content_mode.is_limited()) - .on_toggle({ - let terminal = terminal.clone(); - move |is_expanded, _, cx| { - terminal.update(cx, |terminal, cx| { - terminal.set_embedded_mode( - if is_expanded { - None - } else { - Some(COLLAPSED_LINES) - }, - cx, - ); - }); - } - }) - .into_any_element() - } - }), - ) - }, - ) - .into_any() - } -} - -fn markdown_style(window: &Window, cx: &App) -> MarkdownStyle { - let theme_settings = ThemeSettings::get_global(cx); - let buffer_font_size = TextSize::Default.rems(cx); - let mut text_style = window.text_style(); - - text_style.refine(&TextStyleRefinement { - font_family: Some(theme_settings.buffer_font.family.clone()), - font_fallbacks: theme_settings.buffer_font.fallbacks.clone(), - font_features: Some(theme_settings.buffer_font.features.clone()), - font_size: Some(buffer_font_size.into()), - color: Some(cx.theme().colors().text), - ..Default::default() - }); - - MarkdownStyle { - base_text_style: text_style.clone(), - selection_background_color: cx.theme().colors().element_selection_background, - ..Default::default() - } -} - -#[cfg(test)] -mod tests { - use editor::EditorSettings; - use fs::RealFs; - use gpui::{BackgroundExecutor, TestAppContext}; - use language_model::fake_provider::FakeLanguageModel; - use pretty_assertions::assert_eq; - use serde_json::json; - use settings::{Settings, SettingsStore}; - use terminal::terminal_settings::TerminalSettings; - use util::{ResultExt as _, test::TempTree}; - - use super::*; - - fn init_test(executor: &BackgroundExecutor, cx: &mut TestAppContext) { - zlog::init_test(); - - executor.allow_parking(); - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - workspace::init_settings(cx); - theme::init(theme::LoadThemes::JustBase, cx); - TerminalSettings::register(cx); - EditorSettings::register(cx); - }); - } - - #[gpui::test] - async fn test_interactive_command(executor: BackgroundExecutor, cx: &mut TestAppContext) { - if cfg!(windows) { - return; - } - init_test(&executor, cx); - - let fs = Arc::new(RealFs::new(None, executor)); - let tree = TempTree::new(json!({ - "project": {}, - })); - let project: Entity = - Project::test(fs, [tree.path().join("project").as_path()], cx).await; - let action_log = cx.update(|cx| cx.new(|_| ActionLog::new(project.clone()))); - let model = Arc::new(FakeLanguageModel::default()); - - let input = TerminalToolInput { - command: "cat".to_owned(), - cd: tree - .path() - .join("project") - .as_path() - .to_string_lossy() - .to_string(), - }; - let result = cx.update(|cx| { - TerminalTool::run( - Arc::new(TerminalTool), - serde_json::to_value(input).unwrap(), - Arc::default(), - project.clone(), - action_log.clone(), - model, - None, - cx, - ) - }); - - let output = result.output.await.log_err().unwrap().content; - assert_eq!(output.as_str().unwrap(), "Command executed successfully."); - } - - #[gpui::test] - async fn test_working_directory(executor: BackgroundExecutor, cx: &mut TestAppContext) { - if cfg!(windows) { - return; - } - init_test(&executor, cx); - - let fs = Arc::new(RealFs::new(None, executor)); - let tree = TempTree::new(json!({ - "project": {}, - "other-project": {}, - })); - let project: Entity = - Project::test(fs, [tree.path().join("project").as_path()], cx).await; - let action_log = cx.update(|cx| cx.new(|_| ActionLog::new(project.clone()))); - let model = Arc::new(FakeLanguageModel::default()); - - let check = |input, expected, cx: &mut App| { - let headless_result = TerminalTool::run( - Arc::new(TerminalTool), - serde_json::to_value(input).unwrap(), - Arc::default(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ); - cx.spawn(async move |_| { - let output = headless_result.output.await.map(|output| output.content); - assert_eq!( - output - .ok() - .and_then(|content| content.as_str().map(ToString::to_string)), - expected - ); - }) - }; - - cx.update(|cx| { - check( - TerminalToolInput { - command: "pwd".into(), - cd: ".".into(), - }, - Some(format!( - "```\n{}\n```", - tree.path().join("project").display() - )), - cx, - ) - }) - .await; - - cx.update(|cx| { - check( - TerminalToolInput { - command: "pwd".into(), - cd: "other-project".into(), - }, - None, // other-project is a dir, but *not* a worktree (yet) - cx, - ) - }) - .await; - - // Absolute path above the worktree root - cx.update(|cx| { - check( - TerminalToolInput { - command: "pwd".into(), - cd: tree.path().to_string_lossy().into(), - }, - None, - cx, - ) - }) - .await; - - project - .update(cx, |project, cx| { - project.create_worktree(tree.path().join("other-project"), true, cx) - }) - .await - .unwrap(); - - cx.update(|cx| { - check( - TerminalToolInput { - command: "pwd".into(), - cd: "other-project".into(), - }, - Some(format!( - "```\n{}\n```", - tree.path().join("other-project").display() - )), - cx, - ) - }) - .await; - - cx.update(|cx| { - check( - TerminalToolInput { - command: "pwd".into(), - cd: ".".into(), - }, - None, - cx, - ) - }) - .await; - } -} diff --git a/crates/assistant_tools/src/terminal_tool/description.md b/crates/assistant_tools/src/terminal_tool/description.md deleted file mode 100644 index 3cb5d87d163b3919abafa899ed2fbdba67500773..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/terminal_tool/description.md +++ /dev/null @@ -1,11 +0,0 @@ -Executes a shell one-liner and returns the combined output. - -This tool spawns a process using the user's shell, reads from stdout and stderr (preserving the order of writes), and returns a string with the combined output result. - -The output results will be shown to the user already, only list it again if necessary, avoid being redundant. - -Make sure you use the `cd` parameter to navigate to one of the root directories of the project. NEVER do it as part of the `command` itself, otherwise it will error. - -Do not use this tool for commands that run indefinitely, such as servers (like `npm run start`, `npm run dev`, `python -m http.server`, etc) or file watchers that don't terminate on their own. - -Remember that each invocation of this tool will spawn a new shell process, so you can't rely on any state from previous invocations. diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs deleted file mode 100644 index 17ce4afc2eeeff8c6f37834cd9e8c4ff71e7cd70..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/thinking_tool.rs +++ /dev/null @@ -1,69 +0,0 @@ -use std::sync::Arc; - -use crate::schema::json_schema_for; -use action_log::ActionLog; -use anyhow::{Result, anyhow}; -use assistant_tool::{Tool, ToolResult}; -use gpui::{AnyWindowHandle, App, Entity, Task}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use ui::IconName; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct ThinkingToolInput { - /// Content to think about. This should be a description of what to think about or - /// a problem to solve. - content: String, -} - -pub struct ThinkingTool; - -impl Tool for ThinkingTool { - fn name(&self) -> String { - "thinking".to_string() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn description(&self) -> String { - include_str!("./thinking_tool/description.md").to_string() - } - - fn icon(&self) -> IconName { - IconName::ToolThink - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, _input: &serde_json::Value) -> String { - "Thinking".to_string() - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - _project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - _cx: &mut App, - ) -> ToolResult { - // This tool just "thinks out loud" and doesn't perform any actions. - Task::ready(match serde_json::from_value::(input) { - Ok(_input) => Ok("Finished thinking.".to_string().into()), - Err(err) => Err(anyhow!(err)), - }) - .into() - } -} diff --git a/crates/assistant_tools/src/thinking_tool/description.md b/crates/assistant_tools/src/thinking_tool/description.md deleted file mode 100644 index b625d22f321fa427945fdb9c42aaaed9ab86f6be..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/thinking_tool/description.md +++ /dev/null @@ -1 +0,0 @@ -A tool for thinking through problems, brainstorming ideas, or planning without executing any actions. Use this tool when you need to work through complex problems, develop strategies, or outline approaches before taking action. diff --git a/crates/assistant_tools/src/ui.rs b/crates/assistant_tools/src/ui.rs deleted file mode 100644 index 793427385456939eb1a7070fff5bba928a6c2643..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/ui.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod tool_call_card_header; -mod tool_output_preview; - -pub use tool_call_card_header::*; -pub use tool_output_preview::*; diff --git a/crates/assistant_tools/src/ui/tool_call_card_header.rs b/crates/assistant_tools/src/ui/tool_call_card_header.rs deleted file mode 100644 index b41f19432f99685cf745f684228169b53939fffb..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/ui/tool_call_card_header.rs +++ /dev/null @@ -1,131 +0,0 @@ -use gpui::{Animation, AnimationExt, AnyElement, App, IntoElement, pulsating_between}; -use std::time::Duration; -use ui::{Tooltip, prelude::*}; - -/// A reusable header component for tool call cards. -#[derive(IntoElement)] -pub struct ToolCallCardHeader { - icon: IconName, - primary_text: SharedString, - secondary_text: Option, - code_path: Option, - disclosure_slot: Option, - is_loading: bool, - error: Option, -} - -impl ToolCallCardHeader { - pub fn new(icon: IconName, primary_text: impl Into) -> Self { - Self { - icon, - primary_text: primary_text.into(), - secondary_text: None, - code_path: None, - disclosure_slot: None, - is_loading: false, - error: None, - } - } - - pub fn with_secondary_text(mut self, text: impl Into) -> Self { - self.secondary_text = Some(text.into()); - self - } - - pub fn with_code_path(mut self, text: impl Into) -> Self { - self.code_path = Some(text.into()); - self - } - - pub fn disclosure_slot(mut self, element: impl IntoElement) -> Self { - self.disclosure_slot = Some(element.into_any_element()); - self - } - - pub fn loading(mut self) -> Self { - self.is_loading = true; - self - } - - pub fn with_error(mut self, error: impl Into) -> Self { - self.error = Some(error.into()); - self - } -} - -impl RenderOnce for ToolCallCardHeader { - fn render(self, window: &mut Window, cx: &mut App) -> impl IntoElement { - let font_size = rems(0.8125); - let line_height = window.line_height(); - - let secondary_text = self.secondary_text; - let code_path = self.code_path; - - let bullet_divider = || { - div() - .size(px(3.)) - .rounded_full() - .bg(cx.theme().colors().text) - }; - - h_flex() - .id("tool-label-container") - .gap_2() - .max_w_full() - .overflow_x_scroll() - .opacity(0.8) - .child( - h_flex() - .h(line_height) - .gap_1p5() - .text_size(font_size) - .child( - h_flex().h(line_height).justify_center().child( - Icon::new(self.icon) - .size(IconSize::Small) - .color(Color::Muted), - ), - ) - .map(|this| { - if let Some(error) = &self.error { - this.child(format!("{} failed", self.primary_text)).child( - IconButton::new("error_info", IconName::Warning) - .shape(ui::IconButtonShape::Square) - .icon_size(IconSize::XSmall) - .icon_color(Color::Warning) - .tooltip(Tooltip::text(error.clone())), - ) - } else { - this.child(self.primary_text.clone()) - } - }) - .when_some(secondary_text, |this, secondary_text| { - this.child(bullet_divider()) - .child(div().text_size(font_size).child(secondary_text)) - }) - .when_some(code_path, |this, code_path| { - this.child(bullet_divider()) - .child(Label::new(code_path).size(LabelSize::Small).inline_code(cx)) - }) - .with_animation( - "loading-label", - Animation::new(Duration::from_secs(2)) - .repeat() - .with_easing(pulsating_between(0.6, 1.)), - move |this, delta| { - if self.is_loading { - this.opacity(delta) - } else { - this - } - }, - ), - ) - .when_some(self.disclosure_slot, |container, disclosure_slot| { - container - .group("disclosure") - .justify_between() - .child(div().visible_on_hover("disclosure").child(disclosure_slot)) - }) - } -} diff --git a/crates/assistant_tools/src/ui/tool_output_preview.rs b/crates/assistant_tools/src/ui/tool_output_preview.rs deleted file mode 100644 index a672bb8b99daa1fd776f59c4e8be789b8e25240c..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/ui/tool_output_preview.rs +++ /dev/null @@ -1,115 +0,0 @@ -use gpui::{AnyElement, EntityId, prelude::*}; -use ui::{Tooltip, prelude::*}; - -#[derive(IntoElement)] -pub struct ToolOutputPreview -where - F: Fn(bool, &mut Window, &mut App) + 'static, -{ - content: AnyElement, - entity_id: EntityId, - full_height: bool, - total_lines: usize, - collapsed_fade: bool, - on_toggle: Option, -} - -pub const COLLAPSED_LINES: usize = 10; - -impl ToolOutputPreview -where - F: Fn(bool, &mut Window, &mut App) + 'static, -{ - pub fn new(content: AnyElement, entity_id: EntityId) -> Self { - Self { - content, - entity_id, - full_height: true, - total_lines: 0, - collapsed_fade: false, - on_toggle: None, - } - } - - pub fn with_total_lines(mut self, total_lines: usize) -> Self { - self.total_lines = total_lines; - self - } - - pub fn toggle_state(mut self, full_height: bool) -> Self { - self.full_height = full_height; - self - } - - pub fn with_collapsed_fade(mut self) -> Self { - self.collapsed_fade = true; - self - } - - pub fn on_toggle(mut self, listener: F) -> Self { - self.on_toggle = Some(listener); - self - } -} - -impl RenderOnce for ToolOutputPreview -where - F: Fn(bool, &mut Window, &mut App) + 'static, -{ - fn render(self, _window: &mut Window, cx: &mut App) -> impl IntoElement { - if self.total_lines <= COLLAPSED_LINES { - return self.content; - } - let border_color = cx.theme().colors().border.opacity(0.6); - - let (icon, tooltip_label) = if self.full_height { - (IconName::ChevronUp, "Collapse") - } else { - (IconName::ChevronDown, "Expand") - }; - - let gradient_overlay = - if self.collapsed_fade && !self.full_height { - Some(div().absolute().bottom_5().left_0().w_full().h_2_5().bg( - gpui::linear_gradient( - 0., - gpui::linear_color_stop(cx.theme().colors().editor_background, 0.), - gpui::linear_color_stop( - cx.theme().colors().editor_background.opacity(0.), - 1., - ), - ), - )) - } else { - None - }; - - v_flex() - .relative() - .child(self.content) - .children(gradient_overlay) - .child( - h_flex() - .id(("expand-button", self.entity_id)) - .flex_none() - .cursor_pointer() - .h_5() - .justify_center() - .border_t_1() - .rounded_b_md() - .border_color(border_color) - .bg(cx.theme().colors().editor_background) - .hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1))) - .child(Icon::new(icon).size(IconSize::Small).color(Color::Muted)) - .tooltip(Tooltip::text(tooltip_label)) - .when_some(self.on_toggle, |this, on_toggle| { - this.on_click({ - move |_, window, cx| { - on_toggle(!self.full_height, window, cx); - } - }) - }), - ) - .into_any() - } -} diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs deleted file mode 100644 index dbcca0a1f6f2d5f679fd240a5bfe64c6c9705256..0000000000000000000000000000000000000000 --- a/crates/assistant_tools/src/web_search_tool.rs +++ /dev/null @@ -1,327 +0,0 @@ -use std::{sync::Arc, time::Duration}; - -use crate::schema::json_schema_for; -use crate::ui::ToolCallCardHeader; -use action_log::ActionLog; -use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{ - Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus, -}; -use cloud_llm_client::{WebSearchResponse, WebSearchResult}; -use futures::{Future, FutureExt, TryFutureExt}; -use gpui::{ - AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window, -}; -use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; -use project::Project; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use ui::{IconName, Tooltip, prelude::*}; -use web_search::WebSearchRegistry; -use workspace::Workspace; - -#[derive(Debug, Serialize, Deserialize, JsonSchema)] -pub struct WebSearchToolInput { - /// The search term or question to query on the web. - query: String, -} - -pub struct WebSearchTool; - -impl Tool for WebSearchTool { - fn name(&self) -> String { - "web_search".into() - } - - fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity, _: &App) -> bool { - false - } - - fn may_perform_edits(&self) -> bool { - false - } - - fn description(&self) -> String { - "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into() - } - - fn icon(&self) -> IconName { - IconName::ToolWeb - } - - fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { - json_schema_for::(format) - } - - fn ui_text(&self, _input: &serde_json::Value) -> String { - "Searching the Web".to_string() - } - - fn run( - self: Arc, - input: serde_json::Value, - _request: Arc, - _project: Entity, - _action_log: Entity, - _model: Arc, - _window: Option, - cx: &mut App, - ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else { - return Task::ready(Err(anyhow!("Web search is not available."))).into(); - }; - - let search_task = provider.search(input.query, cx).map_err(Arc::new).shared(); - let output = cx.background_spawn({ - let search_task = search_task.clone(); - async move { - let response = search_task.await.map_err(|err| anyhow!(err))?; - Ok(ToolResultOutput { - content: ToolResultContent::Text( - serde_json::to_string(&response) - .context("Failed to serialize search results")?, - ), - output: Some(serde_json::to_value(response)?), - }) - } - }); - - ToolResult { - output, - card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()), - } - } - - fn deserialize_card( - self: Arc, - output: serde_json::Value, - _project: Entity, - _window: &mut Window, - cx: &mut App, - ) -> Option { - let output = serde_json::from_value::(output).ok()?; - let card = cx.new(|cx| WebSearchToolCard::new(Task::ready(Ok(output)), cx)); - Some(card.into()) - } -} - -#[derive(RegisterComponent)] -struct WebSearchToolCard { - response: Option>, - _task: Task<()>, -} - -impl WebSearchToolCard { - fn new( - search_task: impl 'static + Future>>, - cx: &mut Context, - ) -> Self { - let _task = cx.spawn(async move |this, cx| { - let response = search_task.await.map_err(|err| anyhow!(err)); - this.update(cx, |this, cx| { - this.response = Some(response); - cx.notify(); - }) - .ok(); - }); - - Self { - response: None, - _task, - } - } -} - -impl ToolCard for WebSearchToolCard { - fn render( - &mut self, - _status: &ToolUseStatus, - _window: &mut Window, - _workspace: WeakEntity, - cx: &mut Context, - ) -> impl IntoElement { - let icon = IconName::ToolWeb; - - let header = match self.response.as_ref() { - Some(Ok(response)) => { - let text: SharedString = if response.results.len() == 1 { - "1 result".into() - } else { - format!("{} results", response.results.len()).into() - }; - ToolCallCardHeader::new(icon, "Searched the Web").with_secondary_text(text) - } - Some(Err(error)) => { - ToolCallCardHeader::new(icon, "Web Search").with_error(error.to_string()) - } - None => ToolCallCardHeader::new(icon, "Searching the Web").loading(), - }; - - let content = self.response.as_ref().and_then(|response| match response { - Ok(response) => Some( - v_flex() - .overflow_hidden() - .ml_1p5() - .pl(px(5.)) - .border_l_1() - .border_color(cx.theme().colors().border_variant) - .gap_1() - .children(response.results.iter().enumerate().map(|(index, result)| { - let title = result.title.clone(); - let url = SharedString::from(result.url.clone()); - - Button::new(("result", index), title) - .label_size(LabelSize::Small) - .color(Color::Muted) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Small) - .icon_position(IconPosition::End) - .truncate(true) - .tooltip({ - let url = url.clone(); - move |window, cx| { - Tooltip::with_meta( - "Web Search Result", - None, - url.clone(), - window, - cx, - ) - } - }) - .on_click(move |_, _, cx| cx.open_url(&url)) - })) - .into_any(), - ), - Err(_) => None, - }); - - v_flex().mb_3().gap_1().child(header).children(content) - } -} - -impl Component for WebSearchToolCard { - fn scope() -> ComponentScope { - ComponentScope::Agent - } - - fn preview(window: &mut Window, cx: &mut App) -> Option { - let in_progress_search = cx.new(|cx| WebSearchToolCard { - response: None, - _task: cx.spawn(async move |_this, cx| { - loop { - cx.background_executor() - .timer(Duration::from_secs(60)) - .await - } - }), - }); - - let successful_search = cx.new(|_cx| WebSearchToolCard { - response: Some(Ok(example_search_response())), - _task: Task::ready(()), - }); - - let error_search = cx.new(|_cx| WebSearchToolCard { - response: Some(Err(anyhow!("Failed to resolve https://google.com"))), - _task: Task::ready(()), - }); - - Some( - v_flex() - .gap_6() - .children(vec![example_group(vec![ - single_example( - "In Progress", - div() - .size_full() - .child(in_progress_search.update(cx, |tool, cx| { - tool.render( - &ToolUseStatus::Pending, - window, - WeakEntity::new_invalid(), - cx, - ) - .into_any_element() - })) - .into_any_element(), - ), - single_example( - "Successful", - div() - .size_full() - .child(successful_search.update(cx, |tool, cx| { - tool.render( - &ToolUseStatus::Finished("".into()), - window, - WeakEntity::new_invalid(), - cx, - ) - .into_any_element() - })) - .into_any_element(), - ), - single_example( - "Error", - div() - .size_full() - .child(error_search.update(cx, |tool, cx| { - tool.render( - &ToolUseStatus::Error("".into()), - window, - WeakEntity::new_invalid(), - cx, - ) - .into_any_element() - })) - .into_any_element(), - ), - ])]) - .into_any_element(), - ) - } -} - -fn example_search_response() -> WebSearchResponse { - WebSearchResponse { - results: vec![ - WebSearchResult { - title: "Alo".to_string(), - url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(), - text: "Alo is a popular restaurant in Toronto.".to_string(), - }, - WebSearchResult { - title: "Alo".to_string(), - url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(), - text: "Information about Alo restaurant in Toronto.".to_string(), - }, - WebSearchResult { - title: "Edulis".to_string(), - url: "https://www.google.com/maps/search/Edulis%2C+Toronto%2C+Canada".to_string(), - text: "Details about Edulis restaurant in Toronto.".to_string(), - }, - WebSearchResult { - title: "Sushi Masaki Saito".to_string(), - url: "https://www.google.com/maps/search/Sushi+Masaki+Saito%2C+Toronto%2C+Canada" - .to_string(), - text: "Information about Sushi Masaki Saito in Toronto.".to_string(), - }, - WebSearchResult { - title: "Shoushin".to_string(), - url: "https://www.google.com/maps/search/Shoushin%2C+Toronto%2C+Canada".to_string(), - text: "Details about Shoushin restaurant in Toronto.".to_string(), - }, - WebSearchResult { - title: "Restaurant 20 Victoria".to_string(), - url: - "https://www.google.com/maps/search/Restaurant+20+Victoria%2C+Toronto%2C+Canada" - .to_string(), - text: "Information about Restaurant 20 Victoria in Toronto.".to_string(), - }, - ], - } -} diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index a0214c76a1c7230e071cbc65c1eadbc44c7d6ca8..42dd07b8c610746850923cb9eb96fc900e5206db 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -18,12 +18,12 @@ name = "explorer" path = "src/explorer.rs" [dependencies] +acp_thread.workspace = true agent.workspace = true +agent-client-protocol.workspace = true agent_settings.workspace = true agent_ui.workspace = true anyhow.workspace = true -assistant_tool.workspace = true -assistant_tools.workspace = true async-trait.workspace = true buffer_diff.workspace = true chrono.workspace = true diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 40d8c14f4f7ddc441f31581951ee4d6c26376a04..3afcc32a930ab32746352e81577d55a25c807cb4 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -429,7 +429,6 @@ pub fn init(cx: &mut App) -> Arc { true, cx, ); - assistant_tools::init(client.http_client(), cx); SettingsStore::update_global(cx, |store, cx| { store.set_user_settings(include_str!("../runner_settings.json"), cx) diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index c0f0900a6cfa5dd942bd27eed852ee4a52896c2c..22a8f9484c9f2c1d4ad01a107841b57e8b96f67b 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -9,7 +9,9 @@ use crate::{ ToolMetrics, assertions::{AssertionsReport, RanAssertion, RanAssertionResult}, }; -use agent::{ContextLoadResult, Thread, ThreadEvent}; +use acp_thread::UserMessageId; +use agent::{Thread, ThreadEvent, UserMessageContent}; +use agent_client_protocol as acp; use agent_settings::AgentProfileId; use anyhow::{Result, anyhow}; use async_trait::async_trait; diff --git a/crates/eval/src/examples/comment_translation.rs b/crates/eval/src/examples/comment_translation.rs index b6c9f7376f05fdc38e9f8128c78eb1761bc59c37..893166f3f13207e3444cb03bb17b2dea650170e7 100644 --- a/crates/eval/src/examples/comment_translation.rs +++ b/crates/eval/src/examples/comment_translation.rs @@ -1,7 +1,7 @@ use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion}; +use agent::{EditFileMode, EditFileToolInput}; use agent_settings::AgentProfileId; use anyhow::Result; -use assistant_tools::{EditFileMode, EditFileToolInput}; use async_trait::async_trait; pub struct CommentTranslation; diff --git a/crates/eval/src/examples/file_search.rs b/crates/eval/src/examples/file_search.rs index f1a482a41a952e889b6053e90e9e243ed546d2db..c893aef14299a6086e8c50072d69b0cbed7e9fde 100644 --- a/crates/eval/src/examples/file_search.rs +++ b/crates/eval/src/examples/file_search.rs @@ -1,6 +1,6 @@ +use agent::FindPathToolInput; use agent_settings::AgentProfileId; use anyhow::Result; -use assistant_tools::FindPathToolInput; use async_trait::async_trait; use regex::Regex; diff --git a/crates/eval/src/examples/grep_params_escapement.rs b/crates/eval/src/examples/grep_params_escapement.rs index 0532698ba28b45bd8111767eb51ea1336e18fa13..face6451572725ed402f23aac7bdc2c70a670b67 100644 --- a/crates/eval/src/examples/grep_params_escapement.rs +++ b/crates/eval/src/examples/grep_params_escapement.rs @@ -1,6 +1,5 @@ use agent_settings::AgentProfileId; use anyhow::Result; -use assistant_tools::GrepToolInput; use async_trait::async_trait; use crate::example::{Example, ExampleContext, ExampleMetadata}; diff --git a/crates/eval/src/examples/overwrite_file.rs b/crates/eval/src/examples/overwrite_file.rs index df0b75294c31bf7ff365e96aea18c371b817e710..d4b73aaec4d7d9a18be411ba7d453db9ffcb18a1 100644 --- a/crates/eval/src/examples/overwrite_file.rs +++ b/crates/eval/src/examples/overwrite_file.rs @@ -1,6 +1,5 @@ use agent_settings::AgentProfileId; use anyhow::Result; -use assistant_tools::{EditFileMode, EditFileToolInput}; use async_trait::async_trait; use crate::example::{Example, ExampleContext, ExampleMetadata}; diff --git a/crates/eval/src/examples/planets.rs b/crates/eval/src/examples/planets.rs index f3a69332d2c544479ca4f367699dc3def4d83370..caa15c728400a82b4223fb9ea8522b0815b36b5a 100644 --- a/crates/eval/src/examples/planets.rs +++ b/crates/eval/src/examples/planets.rs @@ -1,7 +1,6 @@ +use agent::{AgentTool, OpenTool, TerminalTool}; use agent_settings::AgentProfileId; use anyhow::Result; -use assistant_tool::Tool; -use assistant_tools::{OpenTool, TerminalTool}; use async_trait::async_trait; use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion}; @@ -38,9 +37,9 @@ impl Example for Planets { let mut terminal_tool_uses = 0; for tool_use in response.tool_uses() { - if tool_use.name == OpenTool.name() { + if tool_use.name == OpenTool::name() { open_tool_uses += 1; - } else if tool_use.name == TerminalTool::NAME { + } else if tool_use.name == TerminalTool::name() { terminal_tool_uses += 1; } } diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 208147e2f04b26a7337c071d36f4f687ca0fe184..e95264c3c3b726244abe4edb61dee474d3bff51a 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -1,6 +1,5 @@ -use agent::{Message, MessageSegment, SerializedThread, ThreadStore}; +use agent::Message; use anyhow::{Context as _, Result, anyhow, bail}; -use assistant_tool::ToolWorkingSet; use client::proto::LspWorkProgress; use futures::channel::mpsc; use futures::{FutureExt as _, StreamExt as _, future}; diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index a85283cf121bc10a82e1022071d6a136dd5716f5..2f0fe67034875ebe8c240093b87d66f44e247c2b 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -32,7 +32,6 @@ image.workspace = true log.workspace = true parking_lot.workspace = true proto.workspace = true -schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 38f2b0959072599900cb8a13c16f4e2f8e9c55db..24f9b84afcfa7b9a40b4a1b7684e9a9b036a5a85 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -19,8 +19,7 @@ use http_client::{StatusCode, http}; use icons::IconName; use open_router::OpenRouterError; use parking_lot::Mutex; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use serde::{Deserialize, Serialize}; pub use settings::LanguageModelCacheConfiguration; use std::ops::{Add, Sub}; use std::str::FromStr; @@ -669,11 +668,6 @@ pub trait LanguageModelExt: LanguageModel { } impl LanguageModelExt for dyn LanguageModel {} -pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema { - fn name() -> String; - fn description() -> String; -} - /// An error that occurred when trying to authenticate the language model provider. #[derive(Debug, Error)] pub enum AuthenticateError { diff --git a/crates/remote_server/Cargo.toml b/crates/remote_server/Cargo.toml index 37c77299ef4657ab62324fdef93f71e95ef026d1..3d28f6ba565330a5fc3c0ea0249aaf760c880439 100644 --- a/crates/remote_server/Cargo.toml +++ b/crates/remote_server/Cargo.toml @@ -75,8 +75,7 @@ minidumper.workspace = true [dev-dependencies] action_log.workspace = true -assistant_tool.workspace = true -assistant_tools.workspace = true +agent.workspace = true client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } collections.workspace = true diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index f6cddc65688a35b6ed67bfaa13bccb1ff5bde2c2..4010d033c09473cb475ae40b977af70fca390b82 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -2,12 +2,11 @@ /// The tests in this file assume that server_cx is running on Windows too. /// We neead to find a way to test Windows-Non-Windows interactions. use crate::headless_project::HeadlessProject; -use assistant_tool::{Tool as _, ToolResultContent}; -use assistant_tools::{ReadFileTool, ReadFileToolInput}; +use agent::{AgentTool, ReadFileTool, ReadFileToolInput, ToolCallEventStream}; use client::{Client, UserStore}; use clock::FakeSystemClock; use collections::{HashMap, HashSet}; -use language_model::{LanguageModelRequest, fake_provider::FakeLanguageModel}; +use language_model::LanguageModelToolResultContent; use extension::ExtensionHostProxy; use fs::{FakeFs, Fs}; @@ -1721,47 +1720,26 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu .unwrap(); let action_log = cx.new(|_| action_log::ActionLog::new(project.clone())); - let model = Arc::new(FakeLanguageModel::default()); - let request = Arc::new(LanguageModelRequest::default()); let input = ReadFileToolInput { path: "project/b.txt".into(), start_line: None, end_line: None, }; - let exists_result = cx.update(|cx| { - ReadFileTool::run( - Arc::new(ReadFileTool), - serde_json::to_value(input).unwrap(), - request.clone(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }); - let output = exists_result.output.await.unwrap().content; - assert_eq!(output, ToolResultContent::Text("B".to_string())); + let read_tool = Arc::new(ReadFileTool::new(project, action_log)); + let (event_stream, _) = ToolCallEventStream::test(); + + let exists_result = cx.update(|cx| read_tool.clone().run(input, event_stream.clone(), cx)); + let output = exists_result.await.unwrap(); + assert_eq!(output, LanguageModelToolResultContent::Text("B".into())); let input = ReadFileToolInput { path: "project/c.txt".into(), start_line: None, end_line: None, }; - let does_not_exist_result = cx.update(|cx| { - ReadFileTool::run( - Arc::new(ReadFileTool), - serde_json::to_value(input).unwrap(), - request.clone(), - project.clone(), - action_log.clone(), - model.clone(), - None, - cx, - ) - }); - does_not_exist_result.output.await.unwrap_err(); + let does_not_exist_result = cx.update(|cx| read_tool.run(input, event_stream, cx)); + does_not_exist_result.await.unwrap_err(); } #[gpui::test] diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index fd9ffdead941a506c251dbae306988642e1926c7..44ab6c2285cf2a9d75393edbb053d21d35ea1840 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -21,13 +21,11 @@ path = "src/main.rs" [dependencies] acp_tools.workspace = true activity_indicator.workspace = true -agent.workspace = true agent_settings.workspace = true agent_ui.workspace = true anyhow.workspace = true askpass.workspace = true assets.workspace = true -assistant_tools.workspace = true audio.workspace = true auto_update.workspace = true auto_update_ui.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index cc05cdfd822bd41135034dbaa3c174fd0af667cb..92897bc3344c710f4d694667a21040100a23a3cc 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -582,7 +582,6 @@ pub fn main() { false, cx, ); - assistant_tools::init(app_state.client.http_client(), cx); repl::init(app_state.fs.clone(), cx); recent_projects::init(cx); diff --git a/script/danger/dangerfile.ts b/script/danger/dangerfile.ts index 6ed4a27fedb0bea7882ad4bcdd1016929bdd40e3..88dc5c5e71c640a83315ac5f1b14c216763023fd 100644 --- a/script/danger/dangerfile.ts +++ b/script/danger/dangerfile.ts @@ -61,12 +61,11 @@ if (includesIssueUrl) { const PROMPT_PATHS = [ "assets/prompts/content_prompt.hbs", "assets/prompts/terminal_assistant_prompt.hbs", - "crates/agent/src/prompts/stale_files_prompt_header.txt", - "crates/agent/src/prompts/summarize_thread_detailed_prompt.txt", - "crates/agent/src/prompts/summarize_thread_prompt.txt", - "crates/assistant_tools/src/templates/create_file_prompt.hbs", - "crates/assistant_tools/src/templates/edit_file_prompt_xml.hbs", - "crates/assistant_tools/src/templates/edit_file_prompt_diff_fenced.hbs", + "crates/agent_settings/src/prompts/summarize_thread_detailed_prompt.txt", + "crates/agent_settings/src/prompts/summarize_thread_prompt.txt", + "crates/agent/src/templates/create_file_prompt.hbs", + "crates/agent/src/templates/edit_file_prompt_xml.hbs", + "crates/agent/src/templates/edit_file_prompt_diff_fenced.hbs", "crates/git_ui/src/commit_message_prompt.txt", ];