From bab28560ef56092bd99e17d595523644f18f6764 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 18 Apr 2025 20:47:59 -0600 Subject: [PATCH] Systematically optimize agentic editing performance (#28961) Now that we've established a proper eval in tree, this PR is reboots of our agent loop back to a set of minimal tools and simpler prompts. We should aim to get this branch feeling subjectively competitive with what's on main and then merge it, and build from there. Let's invest in our eval and use it to drive better performance of the agent loop. How you can help: Pick an example, and then make the outcome faster or better. It's fine to even use your own subjective judgment, as our evaluation criteria likely need tuning as well at this point. Focus on making the agent work better in your own subjective experience first. Let's focus on simple/practical improvements to make this thing work better, then determine how we can craft our judgment criteria to lock those improvements in. Release Notes: - N/A --------- Co-authored-by: Max Co-authored-by: Antonio Co-authored-by: Agus Co-authored-by: Richard Co-authored-by: Max Brunsfeld Co-authored-by: Antonio Scandurra Co-authored-by: Michael Sloan --- Cargo.lock | 5 + assets/prompts/assistant_system_prompt.hbs | 203 ++---- assets/settings/default.json | 173 +++-- crates/agent/src/thread.rs | 53 +- crates/agent/src/thread_store.rs | 3 - crates/assistant_tools/Cargo.toml | 4 + crates/assistant_tools/src/assistant_tools.rs | 22 +- .../assistant_tools/src/code_symbols_tool.rs | 20 +- crates/assistant_tools/src/contents_tool.rs | 2 +- .../src/diagnostics_tool/description.md | 7 +- crates/assistant_tools/src/edit_file_tool.rs | 183 ++++++ .../src/edit_file_tool/description.md | 45 ++ .../src/list_directory_tool.rs | 2 +- .../src/list_directory_tool/description.md | 2 +- .../assistant_tools/src/path_search_tool.rs | 171 +++-- .../src/path_search_tool/description.md | 8 +- crates/assistant_tools/src/read_file_tool.rs | 243 ++++++- .../src/read_file_tool/description.md | 5 +- .../src/regex_search_tool/description.md | 9 +- crates/eval/Cargo.toml | 3 +- .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../email_verification_refactor/base.toml | 1 + .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../find_and_replace_diff_card/prompt.md | 2 +- .../thread_criteria.md | 3 + .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../eval/examples/metal_i64_support/base.toml | 1 + .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../virtio_block_request_refactor/base.toml | 1 + .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 .../{criteria.md => diff_criteria.md} | 0 crates/eval/src/eval.rs | 193 ++++-- crates/eval/src/example.rs | 618 +++++++++++++++--- ...judge_prompt.hbs => judge_diff_prompt.hbs} | 22 + crates/eval/src/judge_thread_prompt.hbs | 22 + crates/prompt_store/src/prompts.rs | 2 - crates/worktree/src/worktree.rs | 17 + typos.toml | 4 +- 68 files changed, 1573 insertions(+), 476 deletions(-) create mode 100644 crates/assistant_tools/src/edit_file_tool.rs create mode 100644 crates/assistant_tools/src/edit_file_tool/description.md rename crates/eval/examples/add_arp_protocol_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/auth_session_management/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/buffer_string_input_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/checkpoint_stability/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/dd_iaptic_mcp_server_integration/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/debian_image_builder/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/docs_restructure/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/email_verification_refactor/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/exif_rotation_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/expand_laravel_php_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/find_and_replace_diff_card/{criteria.md => diff_criteria.md} (100%) create mode 100644 crates/eval/examples/find_and_replace_diff_card/thread_criteria.md rename crates/eval/examples/finnish_translation/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/language_model_file_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/lhs_join_update_callbacks/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/libdevice_symbol_reexport/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/license_management/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/metal_i64_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/metrics_data_size_updates/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/nan_diff_handling/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/never_type_workaround/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/optimizer_schema_refactor/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/rate_limit_endpoints/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/replace_hold_with_drain_on_exit/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/request_to_axios_migration/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/restore_version_api_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/runtime_script_refactor/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/standardized_docker_dependency_checks/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/table_metrics_sorting/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/tax_id_validation/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/test_infrastructure/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/time_detail_merge_update/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/tool_response_handling/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/toolbar_endpoints/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/virtio_block_request_refactor/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/war_and_uri_corrections/{criteria.md => diff_criteria.md} (100%) rename crates/eval/examples/window_title_support/{criteria.md => diff_criteria.md} (100%) rename crates/eval/src/{judge_prompt.hbs => judge_diff_prompt.hbs} (64%) create mode 100644 crates/eval/src/judge_thread_prompt.hbs diff --git a/Cargo.lock b/Cargo.lock index 1b872d02fa29ea9603e76ae1e76f8aa75ebfd0ad..27f918391376ffd9d458a9ad3f2cce60f2c5e504 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -710,17 +710,21 @@ dependencies = [ "gpui", "html_to_markdown", "http_client", + "indoc", "itertools 0.14.0", "language", "language_model", "linkme", "open", + "pretty_assertions", "project", "rand 0.8.5", "regex", "schemars", "serde", "serde_json", + "settings", + "tree-sitter-rust", "ui", "unindent", "util", @@ -4914,6 +4918,7 @@ dependencies = [ "release_channel", "reqwest_client", "serde", + "serde_json", "settings", "shellexpand 2.1.2", "telemetry", diff --git a/assets/prompts/assistant_system_prompt.hbs b/assets/prompts/assistant_system_prompt.hbs index 60b2cee74eb9cd3fbcb0675596e963bbad947609..be186f1738c7882dd9a3c50c7e0a6eecd53d1cb2 100644 --- a/assets/prompts/assistant_system_prompt.hbs +++ b/assets/prompts/assistant_system_prompt.hbs @@ -1,148 +1,77 @@ -You are an AI assistant integrated into a code editor. You have the programming ability of an expert programmer who takes pride in writing high-quality code and is driven to the point of obsession about solving problems effectively. Your goal is to do one of the following two things: +You are a highly skilled software engineer with extensive knowledge in many programming languages, frameworks, design patterns, and best practices. -1. Help users answer questions and perform tasks related to their codebase. -2. Answer general-purpose questions unrelated to their particular codebase. +## Communication -It will be up to you to decide which of these you are doing based on what the user has told you. When unclear, ask clarifying questions to understand the user's intent before proceeding. +1. Be conversational but professional. +2. Refer to the USER in the second person and yourself in the first person. +3. Format your responses in markdown. Use backticks to format file, directory, function, and class names. +4. NEVER lie or make things up. +5. Refrain from apologizing all the time when results are unexpected. Instead, just try your best to proceed or explain the circumstances to the user without apologizing. -You should only perform actions that modify the user's system if explicitly requested by the user: -- If the user asks a question about how to accomplish a task, provide guidance or information, and use read-only tools (e.g., search) to assist. You may suggest potential actions, but do not directly modify the user's system without explicit instruction. -- If the user clearly requests that you perform an action, carry out the action directly without explaining why you are doing so. +## Searching and Reading -When answering questions, it's okay to give incomplete examples containing comments about what would go there in a real version. When being asked to directly perform tasks on the code base, you must ALWAYS make fully working code. You may never "simplify" the code by omitting or deleting functionality you know the user has requested, and you must NEVER write comments like "in a full version, this would..." - instead, you must actually implement the real version. Don't be lazy! +If you are unsure about the answer to the user's request or how to satiate their request, you should gather more information. +This can be done with additional tool calls, asking clarifying questions, etc. -Note that project files are automatically backed up. The user can always get them back later if anything goes wrong, so there's -no need to create backup files (e.g. `.bak` files) because these files will just take up unnecessary space on the user's disk. +For example, if you've performed a semantic search, and the results may not fully answer the user's request, or merit gathering more information, feel free to call more tools. Similarly, if you've performed an edit that may partially +satiate the user's query, but you're not confident, gather more information or use more tools before ending your turn. -When attempting to resolve issues around failing tests, never simply remove the failing tests. Unless the user explicitly asks you to remove tests, ALWAYS attempt to fix the code causing the tests to fail. +Bias towards not asking the user for help if you can find the answer yourself. -Ignore "TODO"-type comments unless they're relevant to the user's explicit request or the user specifically asks you to address them. It is, however, okay to include them in codebase summaries. +## Tool Use - +{{#if has_rules}} +There are project rules that apply to these root directories: +{{#each worktrees}} +{{#if rules_file}} +`{{root_name}}/{{rules_file.path_in_worktree}}`: +`````` +{{{rules_file.text}}} +`````` +{{/if}} +{{/each}} +{{/if}} {{#if has_default_user_rules}} The user has specified the following rules that should be applied: @@ -152,32 +81,8 @@ The user has specified the following rules that should be applied: Rules title: {{title}} {{/if}} `````` -{{contents}} +{{contents}}} `````` {{/each}} - {{/if}} -The user has opened a project that contains the following root directories/files. Whenever you specify a path in the project, it must be a relative path which begins with one of these root directories/files: - -{{#each worktrees}} -- `{{root_name}}` (absolute path: `{{abs_path}}`) -{{/each}} -{{#if has_rules}} - -There are project rules that apply to these root directories: -{{#each worktrees}} -{{#if rules_file}} - -`{{root_name}}/{{rules_file.path_in_worktree}}`: - -`````` -{{{rules_file.text}}} -`````` {{/if}} -{{/each}} -{{/if}} - - -Operating System: {{os}} ({{arch}}) -Shell: {{shell}} - diff --git a/assets/settings/default.json b/assets/settings/default.json index 98ee37f21366db2cc3bacf32cb04f51faf812439..f31feb73568b67da349a7054f51e32eed9987422 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -214,7 +214,14 @@ // The default number of lines to expand excerpts in the multibuffer by. "expand_excerpt_lines": 3, // Globs to match against file paths to determine if a file is private. - "private_files": ["**/.env*", "**/*.pem", "**/*.key", "**/*.cert", "**/*.crt", "**/secrets.yml"], + "private_files": [ + "**/.env*", + "**/*.pem", + "**/*.key", + "**/*.cert", + "**/*.crt", + "**/secrets.yml" + ], // Whether to use additional LSP queries to format (and amend) the code after // every "trigger" symbol input, defined by LSP server capabilities. "use_on_type_format": true, @@ -587,7 +594,6 @@ // // Default: main "fallback_branch_name": "main", - "scrollbar": { // When to show the scrollbar in the git panel. // @@ -660,25 +666,25 @@ "name": "Write", "enable_all_context_servers": true, "tools": { - "terminal": true, - "batch_tool": true, - "code_actions": true, - "code_symbols": true, - "contents": true, + "batch_tool": false, + "code_actions": false, + "code_symbols": false, + "contents": false, "copy_path": false, "create_file": true, "delete_path": false, "diagnostics": true, - "find_replace_file": true, + "edit_file": true, "fetch": true, - "list_directory": false, + "list_directory": true, "move_path": false, - "now": true, + "now": false, "path_search": true, "read_file": true, "regex_search": true, - "rename": true, - "symbol_info": true, + "rename": false, + "symbol_info": false, + "terminal": true, "thinking": true, "web_search": true } @@ -715,7 +721,9 @@ // The list of language servers to use (or disable) for all languages. // // This is typically customized on a per-language basis. - "language_servers": ["..."], + "language_servers": [ + "..." + ], // When to automatically save edited buffers. This setting can // take four values. // @@ -911,7 +919,9 @@ // for files that are not tracked by git, but are still important to your project. Note that globs // that are overly broad can slow down Zed's file scanning. `file_scan_exclusions` takes // precedence over these inclusions. - "file_scan_inclusions": [".env*"], + "file_scan_inclusions": [ + ".env*" + ], // Git gutter behavior configuration. "git": { // Control whether the git gutter is shown. May take 2 values: @@ -963,7 +973,15 @@ // Any addition to this list will be merged with the default list. // Globs are matched relative to the worktree root, // except when starting with a slash (/) or equivalent in Windows. - "disabled_globs": ["**/.env*", "**/*.pem", "**/*.key", "**/*.cert", "**/*.crt", "**/.dev.vars", "**/secrets.yml"], + "disabled_globs": [ + "**/.env*", + "**/*.pem", + "**/*.key", + "**/*.cert", + "**/*.crt", + "**/.dev.vars", + "**/secrets.yml" + ], // When to show edit predictions previews in buffer. // This setting takes two possible values: // 1. Display predictions inline when there are no language server completions available. @@ -1096,7 +1114,12 @@ // Default directories to search for virtual environments, relative // to the current working directory. We recommend overriding this // in your project's settings, rather than globally. - "directories": [".env", "env", ".venv", "venv"], + "directories": [ + ".env", + "env", + ".venv", + "venv" + ], // Can also be `csh`, `fish`, `nushell` and `power_shell` "activate_script": "default" } @@ -1160,8 +1183,15 @@ // } // "file_types": { - "JSONC": ["**/.zed/**/*.json", "**/zed/**/*.json", "**/Zed/**/*.json", "**/.vscode/**/*.json"], - "Shell Script": [".env.*"] + "JSONC": [ + "**/.zed/**/*.json", + "**/zed/**/*.json", + "**/Zed/**/*.json", + "**/.vscode/**/*.json" + ], + "Shell Script": [ + ".env.*" + ] }, // By default use a recent system version of node, or install our own. // You can override this to use a version of node that is not in $PATH with: @@ -1234,10 +1264,15 @@ // Different settings for specific languages. "languages": { "Astro": { - "language_servers": ["astro-language-server", "..."], + "language_servers": [ + "astro-language-server", + "..." + ], "prettier": { "allowed": true, - "plugins": ["prettier-plugin-astro"] + "plugins": [ + "prettier-plugin-astro" + ] } }, "Blade": { @@ -1273,10 +1308,19 @@ "ensure_final_newline_on_save": false }, "Elixir": { - "language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."] + "language_servers": [ + "elixir-ls", + "!next-ls", + "!lexical", + "..." + ] }, "Erlang": { - "language_servers": ["erlang-ls", "!elp", "..."] + "language_servers": [ + "erlang-ls", + "!elp", + "..." + ] }, "Git Commit": { "allow_rewrap": "anywhere" @@ -1292,7 +1336,12 @@ } }, "HEEX": { - "language_servers": ["elixir-ls", "!next-ls", "!lexical", "..."] + "language_servers": [ + "elixir-ls", + "!next-ls", + "!lexical", + "..." + ] }, "HTML": { "prettier": { @@ -1302,11 +1351,17 @@ "Java": { "prettier": { "allowed": true, - "plugins": ["prettier-plugin-java"] + "plugins": [ + "prettier-plugin-java" + ] } }, "JavaScript": { - "language_servers": ["!typescript-language-server", "vtsls", "..."], + "language_servers": [ + "!typescript-language-server", + "vtsls", + "..." + ], "prettier": { "allowed": true } @@ -1324,7 +1379,10 @@ "LaTeX": { "format_on_save": "on", "formatter": "language_server", - "language_servers": ["texlab", "..."], + "language_servers": [ + "texlab", + "..." + ], "prettier": { "allowed": false } @@ -1339,10 +1397,16 @@ } }, "PHP": { - "language_servers": ["phpactor", "!intelephense", "..."], + "language_servers": [ + "phpactor", + "!intelephense", + "..." + ], "prettier": { "allowed": true, - "plugins": ["@prettier/plugin-php"], + "plugins": [ + "@prettier/plugin-php" + ], "parser": "php" } }, @@ -1350,7 +1414,12 @@ "allow_rewrap": "anywhere" }, "Ruby": { - "language_servers": ["solargraph", "!ruby-lsp", "!rubocop", "..."] + "language_servers": [ + "solargraph", + "!ruby-lsp", + "!rubocop", + "..." + ] }, "SCSS": { "prettier": { @@ -1360,21 +1429,36 @@ "SQL": { "prettier": { "allowed": true, - "plugins": ["prettier-plugin-sql"] + "plugins": [ + "prettier-plugin-sql" + ] } }, "Starlark": { - "language_servers": ["starpls", "!buck2-lsp", "..."] + "language_servers": [ + "starpls", + "!buck2-lsp", + "..." + ] }, "Svelte": { - "language_servers": ["svelte-language-server", "..."], + "language_servers": [ + "svelte-language-server", + "..." + ], "prettier": { "allowed": true, - "plugins": ["prettier-plugin-svelte"] + "plugins": [ + "prettier-plugin-svelte" + ] } }, "TSX": { - "language_servers": ["!typescript-language-server", "vtsls", "..."], + "language_servers": [ + "!typescript-language-server", + "vtsls", + "..." + ], "prettier": { "allowed": true } @@ -1385,13 +1469,20 @@ } }, "TypeScript": { - "language_servers": ["!typescript-language-server", "vtsls", "..."], + "language_servers": [ + "!typescript-language-server", + "vtsls", + "..." + ], "prettier": { "allowed": true } }, "Vue.js": { - "language_servers": ["vue-language-server", "..."], + "language_servers": [ + "vue-language-server", + "..." + ], "prettier": { "allowed": true } @@ -1399,7 +1490,9 @@ "XML": { "prettier": { "allowed": true, - "plugins": ["@prettier/plugin-xml"] + "plugins": [ + "@prettier/plugin-xml" + ] } }, "YAML": { @@ -1408,7 +1501,10 @@ } }, "Zig": { - "language_servers": ["zls", "..."] + "language_servers": [ + "zls", + "..." + ] } }, // Different settings for specific language models. @@ -1562,7 +1658,6 @@ // } // ] "ssh_connections": [], - // Configures context servers for use in the Assistant. "context_servers": {}, "debugger": { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index a7044b01004eb38120dda397ea774762ec35f104..c50c001cd2f8aab54acea5a927dfb00a52031f4c 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -313,6 +313,9 @@ pub struct Thread { feedback: Option, message_feedback: HashMap, last_auto_capture_at: Option, + request_callback: Option< + Box])>, + >, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -365,6 +368,7 @@ impl Thread { feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, + request_callback: None, } } @@ -434,9 +438,18 @@ impl Thread { feedback: None, message_feedback: HashMap::default(), last_auto_capture_at: None, + request_callback: None, } } + pub fn set_request_callback( + &mut self, + callback: impl 'static + + FnMut(&LanguageModelRequest, &[Result]), + ) { + self.request_callback = Some(Box::new(callback)); + } + pub fn id(&self) -> &ThreadId { &self.id } @@ -1083,15 +1096,6 @@ impl Thread { content.push(stale_message.into()); } - if action_log.has_edited_files_since_project_diagnostics_check() { - content.push( - "\n\nWhen you're done making changes, make sure to check project diagnostics \ - and fix all errors AND warnings you introduced! \ - DO NOT mention you're going to do this until you're done." - .into(), - ); - } - if !content.is_empty() { let context_message = LanguageModelRequestMessage { role: Role::User, @@ -1110,6 +1114,11 @@ impl Thread { cx: &mut Context, ) { let pending_completion_id = post_inc(&mut self.completion_count); + let mut request_callback_parameters = if self.request_callback.is_some() { + Some((request.clone(), Vec::new())) + } else { + None + }; let prompt_id = self.last_prompt_id.clone(); let task = cx.spawn(async move |thread, cx| { let stream_completion_future = model.stream_completion_with_usage(request, &cx); @@ -1117,6 +1126,7 @@ impl Thread { thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage); let stream_completion = async { let (mut events, usage) = stream_completion_future.await?; + let mut stop_reason = StopReason::EndTurn; let mut current_token_usage = TokenUsage::default(); @@ -1129,6 +1139,11 @@ impl Thread { } while let Some(event) = events.next().await { + if let Some((_, response_events)) = request_callback_parameters.as_mut() { + response_events + .push(event.as_ref().map_err(|error| error.to_string()).cloned()); + } + let event = event?; thread.update(cx, |thread, cx| { @@ -1293,6 +1308,14 @@ impl Thread { } cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new))); + if let Some((request_callback, (request, response_events))) = thread + .request_callback + .as_mut() + .zip(request_callback_parameters.as_ref()) + { + request_callback(request, response_events); + } + thread.auto_capture_telemetry(cx); if let Ok(initial_usage) = initial_token_usage { @@ -1587,17 +1610,11 @@ impl Thread { }); } + /// Insert an empty message to be populated with tool results upon send. pub fn attach_tool_results(&mut self, cx: &mut Context) { + // TODO: Don't insert a dummy user message here. Ensure this works with the thinking model. // Insert a user message to contain the tool results. - self.insert_user_message( - // TODO: Sending up a user message without any content results in the model sending back - // responses that also don't have any content. We currently don't handle this case well, - // so for now we provide some text to keep the model on track. - "Here are the tool results.", - Vec::new(), - None, - cx, - ); + self.insert_user_message("Here are the tool results.", Vec::new(), None, cx); } /// Cancels the last pending completion, if there are any pending. diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 74787016fd6f8330f58aee863afd189c039b3089..e2f9e3d2de7a44b7319798d4da4837050fc94359 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -279,14 +279,12 @@ impl ThreadStore { cx: &App, ) -> Task<(WorktreeContext, Option)> { let root_name = worktree.root_name().into(); - let abs_path = worktree.abs_path(); let rules_task = Self::load_worktree_rules_file(fs, worktree, cx); let Some(rules_task) = rules_task else { return Task::ready(( WorktreeContext { root_name, - abs_path, rules_file: None, }, None, @@ -305,7 +303,6 @@ impl ThreadStore { }; let worktree_info = WorktreeContext { root_name, - abs_path, rules_file, }; (worktree_info, rules_file_error) diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index be116a65343551bd9259e46aca979d957381954c..eaaeff1e47c854f346c87b85e79a71df9579e4b4 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -22,6 +22,7 @@ futures.workspace = true gpui.workspace = true html_to_markdown.workspace = true http_client.workspace = true +indoc.workspace = true itertools.workspace = true language.workspace = true language_model.workspace = true @@ -45,5 +46,8 @@ gpui = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] } rand.workspace = true +pretty_assertions.workspace = true +settings = { workspace = true, features = ["test-support"] } +tree-sitter-rust.workspace = true workspace = { workspace = true, features = ["test-support"] } unindent.workspace = true diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 33e06466e25825059edb6b58d4699db11ccb4a7a..b68a273b4e4090aec0dfa53abcc6d61bf3a8027c 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -7,8 +7,8 @@ mod create_directory_tool; mod create_file_tool; mod delete_path_tool; mod diagnostics_tool; +mod edit_file_tool; mod fetch_tool; -mod find_replace_file_tool; mod list_directory_tool; mod move_path_tool; mod now_tool; @@ -42,8 +42,8 @@ use crate::create_directory_tool::CreateDirectoryTool; use crate::create_file_tool::CreateFileTool; use crate::delete_path_tool::DeletePathTool; use crate::diagnostics_tool::DiagnosticsTool; +use crate::edit_file_tool::EditFileTool; use crate::fetch_tool::FetchTool; -use crate::find_replace_file_tool::FindReplaceFileTool; use crate::list_directory_tool::ListDirectoryTool; use crate::now_tool::NowTool; use crate::open_tool::OpenTool; @@ -59,28 +59,28 @@ pub fn init(http_client: Arc, cx: &mut App) { assistant_tool::init(cx); let registry = ToolRegistry::global(cx); + registry.register_tool(TerminalTool); registry.register_tool(BatchTool); - registry.register_tool(CodeActionTool); - registry.register_tool(CodeSymbolsTool); - registry.register_tool(ContentsTool); - registry.register_tool(CopyPathTool); registry.register_tool(CreateDirectoryTool); registry.register_tool(CreateFileTool); + registry.register_tool(CopyPathTool); registry.register_tool(DeletePathTool); + registry.register_tool(EditFileTool); + registry.register_tool(SymbolInfoTool); + registry.register_tool(CodeActionTool); + registry.register_tool(MovePathTool); registry.register_tool(DiagnosticsTool); - registry.register_tool(FetchTool::new(http_client)); - registry.register_tool(FindReplaceFileTool); registry.register_tool(ListDirectoryTool); - registry.register_tool(MovePathTool); registry.register_tool(NowTool); registry.register_tool(OpenTool); + registry.register_tool(CodeSymbolsTool); + registry.register_tool(ContentsTool); registry.register_tool(PathSearchTool); registry.register_tool(ReadFileTool); registry.register_tool(RegexSearchTool); registry.register_tool(RenameTool); - registry.register_tool(SymbolInfoTool); - registry.register_tool(TerminalTool); registry.register_tool(ThinkingTool); + registry.register_tool(FetchTool::new(http_client)); cx.observe_flag::({ move |is_enabled, cx| { diff --git a/crates/assistant_tools/src/code_symbols_tool.rs b/crates/assistant_tools/src/code_symbols_tool.rs index 4743c88720192b9a147f2a56327bffa2ab4d6e6c..9cd63ab02c61d7fb6c505072eec36c0a9c325483 100644 --- a/crates/assistant_tools/src/code_symbols_tool.rs +++ b/crates/assistant_tools/src/code_symbols_tool.rs @@ -147,7 +147,7 @@ impl Tool for CodeSymbolsTool { }; cx.spawn(async move |cx| match input.path { - Some(path) => file_outline(project, path, action_log, regex, input.offset, cx).await, + Some(path) => file_outline(project, path, action_log, regex, cx).await, None => project_symbols(project, regex, input.offset, cx).await, }) .into() @@ -159,7 +159,6 @@ pub async fn file_outline( path: String, action_log: Entity, regex: Option, - offset: u32, cx: &mut AsyncApp, ) -> anyhow::Result { let buffer = { @@ -195,7 +194,8 @@ pub async fn file_outline( .into_iter() .map(|item| item.to_point(&snapshot)), regex, - offset, + 0, + usize::MAX, ) .await } @@ -294,11 +294,10 @@ async fn project_symbols( async fn render_outline( items: impl IntoIterator>, regex: Option, - offset: u32, + offset: usize, + results_per_page: usize, ) -> Result { - const RESULTS_PER_PAGE_USIZE: usize = RESULTS_PER_PAGE as usize; - - let mut items = items.into_iter().skip(offset as usize); + let mut items = items.into_iter().skip(offset); let entries = items .by_ref() @@ -307,7 +306,7 @@ async fn render_outline( .as_ref() .is_none_or(|regex| regex.is_match(&item.text)) }) - .take(RESULTS_PER_PAGE_USIZE) + .take(results_per_page) .collect::>(); let has_more = items.next().is_some(); @@ -338,7 +337,10 @@ async fn render_outline( Ok(output) } -fn render_entries(output: &mut String, items: impl IntoIterator>) -> u32 { +fn render_entries( + output: &mut String, + items: impl IntoIterator>, +) -> usize { let mut entries_rendered = 0; for item in items { diff --git a/crates/assistant_tools/src/contents_tool.rs b/crates/assistant_tools/src/contents_tool.rs index 5281cfa7c788e7eef9fcb902598d07b50784b8aa..183fa29e8c5aa5efff2f1eed0252bc122e58fe92 100644 --- a/crates/assistant_tools/src/contents_tool.rs +++ b/crates/assistant_tools/src/contents_tool.rs @@ -228,7 +228,7 @@ impl Tool for ContentsTool { } else { // File is too big, so return its outline and a suggestion to // read again with a line number range specified. - let outline = file_outline(project, file_path, action_log, None, 0, cx).await?; + let outline = file_outline(project, file_path, action_log, None, cx).await?; Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start and end fields to see the implementations of symbols in the outline.")) } diff --git a/crates/assistant_tools/src/diagnostics_tool/description.md b/crates/assistant_tools/src/diagnostics_tool/description.md index ab250e09aded5d97a38b568d557b91a2e0a50fc4..90dc00f1e408c0bd4d79de68833db9d4bafc0d2c 100644 --- a/crates/assistant_tools/src/diagnostics_tool/description.md +++ b/crates/assistant_tools/src/diagnostics_tool/description.md @@ -15,6 +15,7 @@ To get a project-wide diagnostic summary: {} -IMPORTANT: When you're done making changes, you **MUST** get the **project** diagnostics (input: `{}`) at the end of your edits so you can fix any problems you might have introduced. **DO NOT** tell the user you're done before doing this! - -You may only attempt to fix these up to 3 times. If you have tried 3 times to fix them, and there are still problems remaining, you must not continue trying to fix them, and must instead tell the user that there are problems remaining - and ask if the user would like you to attempt to solve them further. + +- If you think you can fix a diagnostic, make 1-2 attempts and then give up. +- Don't remove code you've generated just because you can't fix an error. The user can help you fix it. + diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs new file mode 100644 index 0000000000000000000000000000000000000000..136dd60bfed431f86cef07f445ce9e0ca88ef331 --- /dev/null +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -0,0 +1,183 @@ +use crate::{replace::replace_with_flexible_indent, schema::json_schema_for}; +use anyhow::{Context as _, Result, anyhow}; +use assistant_tool::{ActionLog, Tool, ToolResult}; +use gpui::{App, AppContext, AsyncApp, Entity, Task}; +use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{path::PathBuf, sync::Arc}; +use ui::IconName; + +use crate::replace::replace_exact; + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct EditFileToolInput { + /// The full path of the file to modify in the project. + /// + /// WARNING: When specifying which file path need changing, you MUST + /// start each path with one of the project's root directories. + /// + /// The following examples assume we have two root directories in the project: + /// - backend + /// - frontend + /// + /// + /// `backend/src/main.rs` + /// + /// Notice how the file path starts with root-1. Without that, the path + /// would be ambiguous and the call would fail! + /// + /// + /// + /// `frontend/db.js` + /// + pub path: PathBuf, + + /// A user-friendly markdown description of what's being replaced. This will be shown in the UI. + /// + /// Fix API endpoint URLs + /// Update copyright year in `page_footer` + pub display_description: String, + + /// The text to replace. + pub old_string: String, + + /// The text to replace it with. + pub new_string: String, +} + +pub struct EditFileTool; + +impl Tool for EditFileTool { + fn name(&self) -> String { + "edit_file".into() + } + + fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + false + } + + fn description(&self) -> String { + include_str!("edit_file_tool/description.md").to_string() + } + + fn icon(&self) -> IconName { + IconName::Pencil + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + json_schema_for::(format) + } + + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(input) => input.display_description, + Err(_) => "Edit file".to_string(), + } + } + + fn run( + self: Arc, + input: serde_json::Value, + _messages: &[LanguageModelRequestMessage], + project: Entity, + action_log: Entity, + cx: &mut App, + ) -> ToolResult { + let input = match serde_json::from_value::(input) { + Ok(input) => input, + Err(err) => return Task::ready(Err(anyhow!(err))).into(), + }; + + cx.spawn(async move |cx: &mut AsyncApp| { + let project_path = project.read_with(cx, |project, cx| { + project + .find_project_path(&input.path, cx) + .context("Path not found in project") + })??; + + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? + .await?; + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + if input.old_string.is_empty() { + return Err(anyhow!("`old_string` cannot be empty. Use a different tool if you want to create a file.")); + } + + if input.old_string == input.new_string { + return Err(anyhow!("The `old_string` and `new_string` are identical, so no changes would be made.")); + } + + let result = cx + .background_spawn(async move { + // Try to match exactly + let diff = replace_exact(&input.old_string, &input.new_string, &snapshot) + .await + // If that fails, try being flexible about indentation + .or_else(|| replace_with_flexible_indent(&input.old_string, &input.new_string, &snapshot))?; + + if diff.edits.is_empty() { + return None; + } + + let old_text = snapshot.text(); + + Some((old_text, diff)) + }) + .await; + + let Some((old_text, diff)) = result else { + let err = buffer.read_with(cx, |buffer, _cx| { + let file_exists = buffer + .file() + .map_or(false, |file| file.disk_state().exists()); + + if !file_exists { + anyhow!("{} does not exist", input.path.display()) + } else if buffer.is_empty() { + anyhow!( + "{} is empty, so the provided `old_string` wasn't found.", + input.path.display() + ) + } else { + anyhow!("Failed to match the provided `old_string`") + } + })?; + + return Err(err) + }; + + let snapshot = cx.update(|cx| { + action_log.update(cx, |log, cx| { + log.buffer_read(buffer.clone(), cx) + }); + let snapshot = buffer.update(cx, |buffer, cx| { + buffer.finalize_last_transaction(); + buffer.apply_diff(diff, cx); + buffer.finalize_last_transaction(); + buffer.snapshot() + }); + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx) + }); + snapshot + })?; + + project.update( cx, |project, cx| { + project.save_buffer(buffer, cx) + })?.await?; + + let diff_str = cx.background_spawn(async move { + let new_text = snapshot.text(); + language::unified_diff(&old_text, &new_text) + }).await; + + + Ok(format!("Edited {}:\n\n```diff\n{}\n```", input.path.display(), diff_str)) + + }).into() + } +} diff --git a/crates/assistant_tools/src/edit_file_tool/description.md b/crates/assistant_tools/src/edit_file_tool/description.md new file mode 100644 index 0000000000000000000000000000000000000000..51f2db7808684c2f19bd6ccbf7fad9882a6d3b96 --- /dev/null +++ b/crates/assistant_tools/src/edit_file_tool/description.md @@ -0,0 +1,45 @@ +This is a tool for editing files. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead. For larger edits, use the `create_file` tool to overwrite files. + +Before using this tool: + +1. Use the `read_file` tool to understand the file's contents and context + +2. Verify the directory path is correct (only applicable when creating new files): + - Use the `list_directory` tool to verify the parent directory exists and is the correct location + +To make a file edit, provide the following: +1. path: The full path to the file you wish to modify in the project. This path must include the root directory in the project. +2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation) +3. new_string: The edited text, which will replace the old_string in the file. + +The tool will replace ONE occurrence of old_string with new_string in the specified file. + +CRITICAL REQUIREMENTS FOR USING THIS TOOL: + +1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means: + - Include AT LEAST 3-5 lines of context BEFORE the change point + - Include AT LEAST 3-5 lines of context AFTER the change point + - Include all whitespace, indentation, and surrounding code exactly as it appears in the file + +2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances: + - Make separate calls to this tool for each instance + - Each call must uniquely identify its specific instance using extensive context + +3. VERIFICATION: Before using this tool: + - Check how many instances of the target text exist in the file + - If multiple instances exist, gather enough context to uniquely identify each one + - Plan separate tool calls for each instance + +WARNING: If you do not follow these requirements: + - The tool will fail if old_string matches multiple locations + - The tool will fail if old_string doesn't match exactly (including whitespace) + - You may change the wrong instance if you don't include enough context + +When making edits: + - Ensure the edit results in idiomatic, correct code + - Do not leave the code in a broken state + - Always use fully-qualified project paths (starting with the name of one of the project's root directories) + +If you want to create a new file, use the `create_file` tool instead of this tool. Don't pass an empty `old_string`. + +Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each. diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index 9db00d765d58d266bdd448a2f4c4517ef865f2f4..ef0c2838e859d9fef29b95e9b5739d25d739fd62 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -12,7 +12,7 @@ use util::markdown::MarkdownString; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct ListDirectoryToolInput { - /// The relative path of the directory to list. + /// The fully-qualified path of the directory to list in the project. /// /// This path should never be absolute, and the first component /// of the path should always be a root directory in a project. diff --git a/crates/assistant_tools/src/list_directory_tool/description.md b/crates/assistant_tools/src/list_directory_tool/description.md index a7d364ae6348479db645cd44098a298a14e12c2a..1daf3e3a9f29a4885f2954235b790a803df7eacc 100644 --- a/crates/assistant_tools/src/list_directory_tool/description.md +++ b/crates/assistant_tools/src/list_directory_tool/description.md @@ -1 +1 @@ -Lists files and directories in a given path. +Lists files and directories in a given path. Prefer the `regex_search` or `path_search` tools when searching the codebase. diff --git a/crates/assistant_tools/src/path_search_tool.rs b/crates/assistant_tools/src/path_search_tool.rs index 17b85f82782e6f18c9ad39f2c0e45c3d8cc9bbd4..ea19cb1dee9bf11c10a1a4089d59dbcacb6cd8cb 100644 --- a/crates/assistant_tools/src/path_search_tool.rs +++ b/crates/assistant_tools/src/path_search_tool.rs @@ -6,14 +6,14 @@ use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat} use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::{path::PathBuf, sync::Arc}; +use std::{cmp, fmt::Write as _, path::PathBuf, sync::Arc}; use ui::IconName; use util::paths::PathMatcher; use worktree::Snapshot; #[derive(Debug, Serialize, Deserialize, JsonSchema)] pub struct PathSearchToolInput { - /// The glob to search all project paths for. + /// The glob to match against every path in the project. /// /// /// If the project has the following root directories: @@ -76,66 +76,125 @@ impl Tool for PathSearchTool { Ok(input) => (input.offset, input.glob), Err(err) => return Task::ready(Err(anyhow!(err))).into(), }; - - let path_matcher = match PathMatcher::new([ - // Sometimes models try to search for "". In this case, return all paths in the project. - if glob.is_empty() { "*" } else { &glob }, - ]) { - Ok(matcher) => matcher, - Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))).into(), - }; - let snapshots: Vec = project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect(); - + let offset = offset as usize; + let task = search_paths(&glob, project, cx); cx.background_spawn(async move { - let mut matches = Vec::new(); - - for worktree in snapshots { - let root_name = worktree.root_name(); - - // Don't consider ignored entries. - for entry in worktree.entries(false, 0) { - if path_matcher.is_match(&entry.path) { - matches.push( - PathBuf::from(root_name) - .join(&entry.path) - .to_string_lossy() - .to_string(), - ); - } - } - } + let matches = task.await?; + let paginated_matches = &matches[cmp::min(offset, matches.len()) + ..cmp::min(offset + RESULTS_PER_PAGE, matches.len())]; if matches.is_empty() { - Ok(format!("No paths in the project matched the glob {glob:?}")) + Ok("No matches found".to_string()) } else { - // Sort to group entries in the same directory together. - matches.sort(); - - let total_matches = matches.len(); - let response = if total_matches > RESULTS_PER_PAGE + offset as usize { - let paginated_matches: Vec<_> = matches - .into_iter() - .skip(offset as usize) - .take(RESULTS_PER_PAGE) - .collect(); - - format!( - "Found {} total matches. Showing results {}-{} (provide 'offset' parameter for more results):\n\n{}", - total_matches, + let mut message = format!("Found {} total matches.", matches.len()); + if matches.len() > RESULTS_PER_PAGE { + write!( + &mut message, + "\nShowing results {}-{} (provide 'offset' parameter for more results):", offset + 1, - offset as usize + paginated_matches.len(), - paginated_matches.join("\n") + offset + paginated_matches.len() ) - } else { - matches.join("\n") - }; - - Ok(response) + .unwrap(); + } + for mat in matches.into_iter().skip(offset).take(RESULTS_PER_PAGE) { + write!(&mut message, "\n{}", mat.display()).unwrap(); + } + Ok(message) } - }).into() + }) + .into() + } +} + +fn search_paths(glob: &str, project: Entity, cx: &mut App) -> Task>> { + let path_matcher = match PathMatcher::new([ + // Sometimes models try to search for "". In this case, return all paths in the project. + if glob.is_empty() { "*" } else { glob }, + ]) { + Ok(matcher) => matcher, + Err(err) => return Task::ready(Err(anyhow!("Invalid glob: {err}"))), + }; + let snapshots: Vec = project + .read(cx) + .worktrees(cx) + .map(|worktree| worktree.read(cx).snapshot()) + .collect(); + + cx.background_spawn(async move { + Ok(snapshots + .iter() + .flat_map(|snapshot| { + let root_name = PathBuf::from(snapshot.root_name()); + snapshot + .entries(false, 0) + .map(move |entry| root_name.join(&entry.path)) + .filter(|path| path_matcher.is_match(&path)) + }) + .collect()) + }) +} + +#[cfg(test)] +mod test { + use super::*; + use gpui::TestAppContext; + use project::{FakeFs, Project}; + use settings::SettingsStore; + use util::path; + + #[gpui::test] + async fn test_path_search_tool(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + serde_json::json!({ + "apple": { + "banana": { + "carrot": "1", + }, + "bandana": { + "carbonara": "2", + }, + "endive": "3" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + let matches = cx + .update(|cx| search_paths("root/**/car*", project.clone(), cx)) + .await + .unwrap(); + assert_eq!( + matches, + &[ + PathBuf::from("root/apple/banana/carrot"), + PathBuf::from("root/apple/bandana/carbonara") + ] + ); + + let matches = cx + .update(|cx| search_paths("**/car*", project.clone(), cx)) + .await + .unwrap(); + assert_eq!( + matches, + &[ + PathBuf::from("root/apple/banana/carrot"), + PathBuf::from("root/apple/bandana/carbonara") + ] + ); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); } } diff --git a/crates/assistant_tools/src/path_search_tool/description.md b/crates/assistant_tools/src/path_search_tool/description.md index 129aaa7c8e29f961f5d3e7306d764ca5494bc9ca..73345bff8e807124a6b9fce48accc600235995c8 100644 --- a/crates/assistant_tools/src/path_search_tool/description.md +++ b/crates/assistant_tools/src/path_search_tool/description.md @@ -1,3 +1,7 @@ -Returns paths in the project which match the given glob. +Fast file pattern matching tool that works with any codebase size -Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages. +- Supports glob patterns like "**/*.js" or "src/**/*.ts" +- Returns matching file paths sorted alphabetically +- Prefer the `regex_search` tool to this tool when searching for symbols unless you have specific information about paths. +- Use this tool when you need to find files by name patterns +- Results are paginated with 50 matches per page. Use the optional 'offset' parameter to request subsequent pages. diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index 5fe5cf9e9780541480bbb81521f72151df82298c..87e6fd96a70946a0daeec4c9e44ac43137e4267f 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -1,14 +1,14 @@ -use std::sync::Arc; - use crate::{code_symbols_tool::file_outline, schema::json_schema_for}; use anyhow::{Result, anyhow}; use assistant_tool::{ActionLog, Tool, ToolResult}; use gpui::{App, Entity, Task}; +use indoc::formatdoc; use itertools::Itertools; use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat}; use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::sync::Arc; use ui::IconName; use util::markdown::MarkdownString; @@ -95,11 +95,24 @@ impl Tool for ReadFileTool { }; let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { - return Task::ready(Err(anyhow!("Path {} not found in project", &input.path,))).into(); + return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into(); + }; + let Some(worktree) = project + .read(cx) + .worktree_for_id(project_path.worktree_id, cx) + else { + return Task::ready(Err(anyhow!("Worktree not found for project path"))).into(); }; + let exists = worktree.update(cx, |worktree, cx| { + worktree.file_exists(&project_path.path, cx) + }); let file_path = input.path.clone(); cx.spawn(async move |cx| { + if !exists.await? { + return Err(anyhow!("{} not found", file_path)) + } + let buffer = cx .update(|cx| { project.update(cx, |project, cx| project.open_buffer(project_path, cx)) @@ -141,11 +154,231 @@ impl Tool for ReadFileTool { } else { // File is too big, so return an error with the outline // and a suggestion to read again with line numbers. - let outline = file_outline(project, file_path, action_log, None, 0, cx).await?; + let outline = file_outline(project, file_path, action_log, None, cx).await?; + Ok(formatdoc! {" + This file was too big to read all at once. Here is an outline of its symbols: + + {outline} - Ok(format!("This file was too big to read all at once. Here is an outline of its symbols:\n\n{outline}\n\nUsing the line numbers in this outline, you can call this tool again while specifying the start_line and end_line fields to see the implementations of symbols in the outline.")) + Using the line numbers in this outline, you can call this tool again while specifying + the start_line and end_line fields to see the implementations of symbols in the outline." + }) } } }).into() } } + +#[cfg(test)] +mod test { + use super::*; + use gpui::{AppContext, TestAppContext}; + use language::{Language, LanguageConfig, LanguageMatcher}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + #[gpui::test] + async fn test_read_nonexistent_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({})).await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let result = cx + .update(|cx| { + let input = json!({ + "path": "root/nonexistent_file.txt" + }); + Arc::new(ReadFileTool) + .run(input, &[], project.clone(), action_log, cx) + .output + }) + .await; + assert_eq!( + result.unwrap_err().to_string(), + "root/nonexistent_file.txt not found" + ); + } + + #[gpui::test] + async fn test_read_small_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "small_file.txt": "This is a small file content" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let result = cx + .update(|cx| { + let input = json!({ + "path": "root/small_file.txt" + }); + Arc::new(ReadFileTool) + .run(input, &[], project.clone(), action_log, cx) + .output + }) + .await; + assert_eq!(result.unwrap(), "This is a small file content"); + } + + #[gpui::test] + async fn test_read_large_file(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "large_file.rs": (0..1000).map(|i| format!("struct Test{} {{\n a: u32,\n b: usize,\n}}", i)).collect::>().join("\n") + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(Arc::new(rust_lang())); + let action_log = cx.new(|_| ActionLog::new(project.clone())); + + let result = cx + .update(|cx| { + let input = json!({ + "path": "root/large_file.rs" + }); + Arc::new(ReadFileTool) + .run(input, &[], project.clone(), action_log.clone(), cx) + .output + }) + .await; + let content = result.unwrap(); + assert_eq!( + content.lines().skip(2).take(6).collect::>(), + vec![ + "struct Test0 [L1-4]", + " a [L2]", + " b [L3]", + "struct Test1 [L5-8]", + " a [L6]", + " b [L7]", + ] + ); + + let result = cx + .update(|cx| { + let input = json!({ + "path": "root/large_file.rs", + "offset": 1 + }); + Arc::new(ReadFileTool) + .run(input, &[], project.clone(), action_log, cx) + .output + }) + .await; + let content = result.unwrap(); + let expected_content = (0..1000) + .flat_map(|i| { + vec![ + format!("struct Test{} [L{}-{}]", i, i * 4 + 1, i * 4 + 4), + format!(" a [L{}]", i * 4 + 2), + format!(" b [L{}]", i * 4 + 3), + ] + }) + .collect::>(); + pretty_assertions::assert_eq!( + content + .lines() + .skip(2) + .take(expected_content.len()) + .collect::>(), + expected_content + ); + } + + #[gpui::test] + async fn test_read_file_with_line_range(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "multiline.txt": "Line 1\nLine 2\nLine 3\nLine 4\nLine 5" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let result = cx + .update(|cx| { + let input = json!({ + "path": "root/multiline.txt", + "start_line": 2, + "end_line": 4 + }); + Arc::new(ReadFileTool) + .run(input, &[], project.clone(), action_log, cx) + .output + }) + .await; + assert_eq!(result.unwrap(), "Line 2\nLine 3"); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query( + r#" + (line_comment) @annotation + + (struct_item + "struct" @context + name: (_) @name) @item + (enum_item + "enum" @context + name: (_) @name) @item + (enum_variant + name: (_) @name) @item + (field_declaration + name: (_) @name) @item + (impl_item + "impl" @context + trait: (_)? @name + "for"? @context + type: (_) @name + body: (_ "{" (_)* "}")) @item + (function_item + "fn" @context + name: (_) @name) @item + (mod_item + "mod" @context + name: (_) @name) @item + "#, + ) + .unwrap() + } +} diff --git a/crates/assistant_tools/src/read_file_tool/description.md b/crates/assistant_tools/src/read_file_tool/description.md index b14898a2c30b0989832d9f37f311ecfa944b1e48..7bcebc03341541496ab090090ab7ef8beb3f2ebe 100644 --- a/crates/assistant_tools/src/read_file_tool/description.md +++ b/crates/assistant_tools/src/read_file_tool/description.md @@ -1,6 +1,3 @@ Reads the content of the given file in the project. -If the file is too big to read all at once, and neither a start line -nor an end line was specified, then this returns an outline of the -file's symbols (with line numbers) instead of the file's contents, -so that it can be called again with line ranges. +- Never attempt to read a path that hasn't been previously mentioned. diff --git a/crates/assistant_tools/src/regex_search_tool/description.md b/crates/assistant_tools/src/regex_search_tool/description.md index 160ecedebba60839ef4a20c6e605667bddda1f45..674dd6043b6efc4ae41dcf3cac1fc4a8b6c00e73 100644 --- a/crates/assistant_tools/src/regex_search_tool/description.md +++ b/crates/assistant_tools/src/regex_search_tool/description.md @@ -1,7 +1,6 @@ Searches the entire project for the given regular expression. -Returns a list of paths that matched the query. For each path, it returns some excerpts of the matched text. - -Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages. - -This tool is not aware of semantics and does not use any information from language servers, so it should only be used when no available semantic tool (e.g. one that uses language servers) could fit a particular use case instead. +- Prefer this tool when searching for files containing symbols in the project. +- Supports full regex syntax (eg. "log.*Error", "function\\s+\\w+", etc.) +- Use this tool when you need to find files containing specific patterns +- Results are paginated with 20 matches per page. Use the optional 'offset' parameter to request subsequent pages. diff --git a/crates/eval/Cargo.toml b/crates/eval/Cargo.toml index d8c97f9a1185c5d2af9b1c15f88c7014382e0e51..e494ce296dc5775e228faad92ce0b57ebe24cd7d 100644 --- a/crates/eval/Cargo.toml +++ b/crates/eval/Cargo.toml @@ -28,7 +28,7 @@ language.workspace = true language_extension.workspace = true language_model.workspace = true language_models.workspace = true -languages.workspace = true +languages = { workspace = true, features = ["load-grammars"] } node_runtime.workspace = true paths.workspace = true project.workspace = true @@ -36,6 +36,7 @@ prompt_store.workspace = true release_channel.workspace = true reqwest_client.workspace = true serde.workspace = true +serde_json.workspace = true settings.workspace = true shellexpand.workspace = true telemetry.workspace = true diff --git a/crates/eval/examples/add_arp_protocol_support/criteria.md b/crates/eval/examples/add_arp_protocol_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/add_arp_protocol_support/criteria.md rename to crates/eval/examples/add_arp_protocol_support/diff_criteria.md diff --git a/crates/eval/examples/auth_session_management/criteria.md b/crates/eval/examples/auth_session_management/diff_criteria.md similarity index 100% rename from crates/eval/examples/auth_session_management/criteria.md rename to crates/eval/examples/auth_session_management/diff_criteria.md diff --git a/crates/eval/examples/buffer_string_input_support/criteria.md b/crates/eval/examples/buffer_string_input_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/buffer_string_input_support/criteria.md rename to crates/eval/examples/buffer_string_input_support/diff_criteria.md diff --git a/crates/eval/examples/checkpoint_stability/criteria.md b/crates/eval/examples/checkpoint_stability/diff_criteria.md similarity index 100% rename from crates/eval/examples/checkpoint_stability/criteria.md rename to crates/eval/examples/checkpoint_stability/diff_criteria.md diff --git a/crates/eval/examples/dd_iaptic_mcp_server_integration/criteria.md b/crates/eval/examples/dd_iaptic_mcp_server_integration/diff_criteria.md similarity index 100% rename from crates/eval/examples/dd_iaptic_mcp_server_integration/criteria.md rename to crates/eval/examples/dd_iaptic_mcp_server_integration/diff_criteria.md diff --git a/crates/eval/examples/debian_image_builder/criteria.md b/crates/eval/examples/debian_image_builder/diff_criteria.md similarity index 100% rename from crates/eval/examples/debian_image_builder/criteria.md rename to crates/eval/examples/debian_image_builder/diff_criteria.md diff --git a/crates/eval/examples/docs_restructure/criteria.md b/crates/eval/examples/docs_restructure/diff_criteria.md similarity index 100% rename from crates/eval/examples/docs_restructure/criteria.md rename to crates/eval/examples/docs_restructure/diff_criteria.md diff --git a/crates/eval/examples/email_verification_refactor/base.toml b/crates/eval/examples/email_verification_refactor/base.toml index a8851fdddad79359fb6401b841c98db61ef33cfa..04c26ca6b9efcceb61d82f5df33e6ffcb5314c99 100644 --- a/crates/eval/examples/email_verification_refactor/base.toml +++ b/crates/eval/examples/email_verification_refactor/base.toml @@ -1,3 +1,4 @@ url = "https://github.com/dani-garcia/vaultwarden.git" revision = "3a1f1bae002bebf26ce3a38b879c1ba26529af1e" language_extension = "rs" +allow_preexisting_diagnostics = true diff --git a/crates/eval/examples/email_verification_refactor/criteria.md b/crates/eval/examples/email_verification_refactor/diff_criteria.md similarity index 100% rename from crates/eval/examples/email_verification_refactor/criteria.md rename to crates/eval/examples/email_verification_refactor/diff_criteria.md diff --git a/crates/eval/examples/exif_rotation_support/criteria.md b/crates/eval/examples/exif_rotation_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/exif_rotation_support/criteria.md rename to crates/eval/examples/exif_rotation_support/diff_criteria.md diff --git a/crates/eval/examples/expand_laravel_php_support/criteria.md b/crates/eval/examples/expand_laravel_php_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/expand_laravel_php_support/criteria.md rename to crates/eval/examples/expand_laravel_php_support/diff_criteria.md diff --git a/crates/eval/examples/find_and_replace_diff_card/criteria.md b/crates/eval/examples/find_and_replace_diff_card/diff_criteria.md similarity index 100% rename from crates/eval/examples/find_and_replace_diff_card/criteria.md rename to crates/eval/examples/find_and_replace_diff_card/diff_criteria.md diff --git a/crates/eval/examples/find_and_replace_diff_card/prompt.md b/crates/eval/examples/find_and_replace_diff_card/prompt.md index efd23cbba3c8c9ce58d4840afb0de55934f0acaf..a4c2cfdb0c3a0151a6bf18f65e487cf6c821aebb 100644 --- a/crates/eval/examples/find_and_replace_diff_card/prompt.md +++ b/crates/eval/examples/find_and_replace_diff_card/prompt.md @@ -1,3 +1,3 @@ -Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should be a brand new `Entity` with a `Render` implementation. +Look at the `find_replace_file_tool.rs`. I want to implement a card for it. The card should implement the `Render` trait. The card should show a diff. It should be a beautifully presented diff. The card "box" should look like what we show for markdown codeblocks (look at `MarkdownElement`). I want to see a red background for lines that were deleted and a green background for lines that were added. We should have a div per diff line. diff --git a/crates/eval/examples/find_and_replace_diff_card/thread_criteria.md b/crates/eval/examples/find_and_replace_diff_card/thread_criteria.md new file mode 100644 index 0000000000000000000000000000000000000000..ee5c640a44154ade73ee69edf15c8e9216f870fc --- /dev/null +++ b/crates/eval/examples/find_and_replace_diff_card/thread_criteria.md @@ -0,0 +1,3 @@ +1. The first tool call should be to path search including "find_replace_file_tool.rs" in the string. (*Not* regex_search, for example, or reading the file based on a guess at the path.) This is because we gave the model a filename and it needs to turn that into a real path. +2. After obtaining the correct path of "zed/crates/assistant_tools/src/find_replace_file_tool.rs", it should read the contents of that path. +3. When trying to find information about the Render trait, it should *not* begin with a path search, because it doesn't yet have any information on what path the Render trait might be in. diff --git a/crates/eval/examples/finnish_translation/criteria.md b/crates/eval/examples/finnish_translation/diff_criteria.md similarity index 100% rename from crates/eval/examples/finnish_translation/criteria.md rename to crates/eval/examples/finnish_translation/diff_criteria.md diff --git a/crates/eval/examples/language_model_file_support/criteria.md b/crates/eval/examples/language_model_file_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/language_model_file_support/criteria.md rename to crates/eval/examples/language_model_file_support/diff_criteria.md diff --git a/crates/eval/examples/lhs_join_update_callbacks/criteria.md b/crates/eval/examples/lhs_join_update_callbacks/diff_criteria.md similarity index 100% rename from crates/eval/examples/lhs_join_update_callbacks/criteria.md rename to crates/eval/examples/lhs_join_update_callbacks/diff_criteria.md diff --git a/crates/eval/examples/libdevice_symbol_reexport/criteria.md b/crates/eval/examples/libdevice_symbol_reexport/diff_criteria.md similarity index 100% rename from crates/eval/examples/libdevice_symbol_reexport/criteria.md rename to crates/eval/examples/libdevice_symbol_reexport/diff_criteria.md diff --git a/crates/eval/examples/license_management/criteria.md b/crates/eval/examples/license_management/diff_criteria.md similarity index 100% rename from crates/eval/examples/license_management/criteria.md rename to crates/eval/examples/license_management/diff_criteria.md diff --git a/crates/eval/examples/metal_i64_support/base.toml b/crates/eval/examples/metal_i64_support/base.toml index 01b07032313a35c4bdf233b6022d5b5b47773b72..4648f148b84b24237188a3df9a8006cd42420a16 100644 --- a/crates/eval/examples/metal_i64_support/base.toml +++ b/crates/eval/examples/metal_i64_support/base.toml @@ -1,3 +1,4 @@ url = "https://github.com/huggingface/candle.git" revision = "3164a19a5dc18f5e0f7a063ae85a0cfd289e98f1" language_extension = "rs" +allow_preexisting_diagnostics = true diff --git a/crates/eval/examples/metal_i64_support/criteria.md b/crates/eval/examples/metal_i64_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/metal_i64_support/criteria.md rename to crates/eval/examples/metal_i64_support/diff_criteria.md diff --git a/crates/eval/examples/metrics_data_size_updates/criteria.md b/crates/eval/examples/metrics_data_size_updates/diff_criteria.md similarity index 100% rename from crates/eval/examples/metrics_data_size_updates/criteria.md rename to crates/eval/examples/metrics_data_size_updates/diff_criteria.md diff --git a/crates/eval/examples/nan_diff_handling/criteria.md b/crates/eval/examples/nan_diff_handling/diff_criteria.md similarity index 100% rename from crates/eval/examples/nan_diff_handling/criteria.md rename to crates/eval/examples/nan_diff_handling/diff_criteria.md diff --git a/crates/eval/examples/never_type_workaround/criteria.md b/crates/eval/examples/never_type_workaround/diff_criteria.md similarity index 100% rename from crates/eval/examples/never_type_workaround/criteria.md rename to crates/eval/examples/never_type_workaround/diff_criteria.md diff --git a/crates/eval/examples/optimizer_schema_refactor/criteria.md b/crates/eval/examples/optimizer_schema_refactor/diff_criteria.md similarity index 100% rename from crates/eval/examples/optimizer_schema_refactor/criteria.md rename to crates/eval/examples/optimizer_schema_refactor/diff_criteria.md diff --git a/crates/eval/examples/rate_limit_endpoints/criteria.md b/crates/eval/examples/rate_limit_endpoints/diff_criteria.md similarity index 100% rename from crates/eval/examples/rate_limit_endpoints/criteria.md rename to crates/eval/examples/rate_limit_endpoints/diff_criteria.md diff --git a/crates/eval/examples/replace_hold_with_drain_on_exit/criteria.md b/crates/eval/examples/replace_hold_with_drain_on_exit/diff_criteria.md similarity index 100% rename from crates/eval/examples/replace_hold_with_drain_on_exit/criteria.md rename to crates/eval/examples/replace_hold_with_drain_on_exit/diff_criteria.md diff --git a/crates/eval/examples/request_to_axios_migration/criteria.md b/crates/eval/examples/request_to_axios_migration/diff_criteria.md similarity index 100% rename from crates/eval/examples/request_to_axios_migration/criteria.md rename to crates/eval/examples/request_to_axios_migration/diff_criteria.md diff --git a/crates/eval/examples/restore_version_api_support/criteria.md b/crates/eval/examples/restore_version_api_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/restore_version_api_support/criteria.md rename to crates/eval/examples/restore_version_api_support/diff_criteria.md diff --git a/crates/eval/examples/runtime_script_refactor/criteria.md b/crates/eval/examples/runtime_script_refactor/diff_criteria.md similarity index 100% rename from crates/eval/examples/runtime_script_refactor/criteria.md rename to crates/eval/examples/runtime_script_refactor/diff_criteria.md diff --git a/crates/eval/examples/standardized_docker_dependency_checks/criteria.md b/crates/eval/examples/standardized_docker_dependency_checks/diff_criteria.md similarity index 100% rename from crates/eval/examples/standardized_docker_dependency_checks/criteria.md rename to crates/eval/examples/standardized_docker_dependency_checks/diff_criteria.md diff --git a/crates/eval/examples/table_metrics_sorting/criteria.md b/crates/eval/examples/table_metrics_sorting/diff_criteria.md similarity index 100% rename from crates/eval/examples/table_metrics_sorting/criteria.md rename to crates/eval/examples/table_metrics_sorting/diff_criteria.md diff --git a/crates/eval/examples/tax_id_validation/criteria.md b/crates/eval/examples/tax_id_validation/diff_criteria.md similarity index 100% rename from crates/eval/examples/tax_id_validation/criteria.md rename to crates/eval/examples/tax_id_validation/diff_criteria.md diff --git a/crates/eval/examples/test_infrastructure/criteria.md b/crates/eval/examples/test_infrastructure/diff_criteria.md similarity index 100% rename from crates/eval/examples/test_infrastructure/criteria.md rename to crates/eval/examples/test_infrastructure/diff_criteria.md diff --git a/crates/eval/examples/time_detail_merge_update/criteria.md b/crates/eval/examples/time_detail_merge_update/diff_criteria.md similarity index 100% rename from crates/eval/examples/time_detail_merge_update/criteria.md rename to crates/eval/examples/time_detail_merge_update/diff_criteria.md diff --git a/crates/eval/examples/tool_response_handling/criteria.md b/crates/eval/examples/tool_response_handling/diff_criteria.md similarity index 100% rename from crates/eval/examples/tool_response_handling/criteria.md rename to crates/eval/examples/tool_response_handling/diff_criteria.md diff --git a/crates/eval/examples/toolbar_endpoints/criteria.md b/crates/eval/examples/toolbar_endpoints/diff_criteria.md similarity index 100% rename from crates/eval/examples/toolbar_endpoints/criteria.md rename to crates/eval/examples/toolbar_endpoints/diff_criteria.md diff --git a/crates/eval/examples/virtio_block_request_refactor/base.toml b/crates/eval/examples/virtio_block_request_refactor/base.toml index a207fb3a100cbe1660244a45dc22bf7edbd7dcb1..58fdc0a963ed39eabb5e3c628ef53212de0977b4 100644 --- a/crates/eval/examples/virtio_block_request_refactor/base.toml +++ b/crates/eval/examples/virtio_block_request_refactor/base.toml @@ -1,3 +1,4 @@ url = "https://github.com/firecracker-microvm/firecracker.git" revision = "5eaa6e08e350cd38c8102848913a096312e59097" language_extension = "rs" +allow_preexisting_diagnostics = true diff --git a/crates/eval/examples/virtio_block_request_refactor/criteria.md b/crates/eval/examples/virtio_block_request_refactor/diff_criteria.md similarity index 100% rename from crates/eval/examples/virtio_block_request_refactor/criteria.md rename to crates/eval/examples/virtio_block_request_refactor/diff_criteria.md diff --git a/crates/eval/examples/war_and_uri_corrections/criteria.md b/crates/eval/examples/war_and_uri_corrections/diff_criteria.md similarity index 100% rename from crates/eval/examples/war_and_uri_corrections/criteria.md rename to crates/eval/examples/war_and_uri_corrections/diff_criteria.md diff --git a/crates/eval/examples/window_title_support/criteria.md b/crates/eval/examples/window_title_support/diff_criteria.md similarity index 100% rename from crates/eval/examples/window_title_support/criteria.md rename to crates/eval/examples/window_title_support/diff_criteria.md diff --git a/crates/eval/src/eval.rs b/crates/eval/src/eval.rs index 49c08201a815ff15810ff503d58ef202c8ee301e..5b465a5e74e9a092c28e53ea5b09a3cbeb765708 100644 --- a/crates/eval/src/eval.rs +++ b/crates/eval/src/eval.rs @@ -9,8 +9,7 @@ use ::fs::RealFs; use anyhow::{Result, anyhow}; use clap::Parser; use extension::ExtensionHostProxy; -use futures::future; -use futures::stream::StreamExt; +use futures::{StreamExt, future}; use gpui::http_client::{Uri, read_proxy_from_env}; use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal}; use gpui_tokio::Tokio; @@ -183,7 +182,7 @@ fn main() { println!( "{}Logging to: {}", example.log_prefix, - example.output_file_path.display() + example.example_output_directory().display() ); let repo_url = example.base.url.clone(); @@ -192,7 +191,7 @@ fn main() { if !repo_path.join(".git").is_dir() { println!( - "{:>(); + }); let results = futures::stream::iter(tasks) .buffer_unordered(concurrency) - .collect::>>, Example)>>() + .collect::>() .await; println!("\n\n"); @@ -259,26 +256,41 @@ fn main() { println!("========================================"); println!(""); - let mut judge_scores = Vec::new(); + let mut diff_scores = Vec::new(); + let mut thread_scores = Vec::new(); + let mut error_count = 0; for (result, example) in results { match result { Err(err) => { println!("💥 {}{:?}", example.log_prefix, err); + error_count += 1; } Ok(judge_results) => { for judge_result in judge_results { match judge_result { Ok(judge_output) => { const SCORES: [&str; 6] = ["💀", "😭", "😔", "😐", "🙂", "🤩"]; - let score: u32 = judge_output.score; - let score_index = (score.min(5)) as usize; + let diff_score: u32 = judge_output.diff.score; + let score_index = (diff_score.min(5)) as usize; println!( - "{} {}{}", - SCORES[score_index], example.log_prefix, judge_output.score, + "{} {}{} (Diff)", + SCORES[score_index], + example.log_prefix, + judge_output.diff.score, ); - judge_scores.push(judge_output.score); + diff_scores.push(judge_output.diff.score); + + if let Some(thread) = judge_output.thread { + let process_score: u32 = thread.score; + let score_index = (process_score.min(5)) as usize; + println!( + "{} {}{} (Thread)", + SCORES[score_index], example.log_prefix, thread.score, + ); + thread_scores.push(thread.score); + } } Err(err) => { println!("💥 {}{:?}", example.log_prefix, err); @@ -290,17 +302,39 @@ fn main() { println!( "{} > {}", " ".repeat(max_name_width), - example.output_file_path.display() + example.example_output_directory().display() ); } - let score_count = judge_scores.len(); - let average_score = judge_scores + let diff_score_count = diff_scores.len(); + let average_diff_score = diff_scores .into_iter() .map(|score| score as f32) .sum::() - / (score_count as f32); - println!("\nAverage score: {average_score}"); + / (diff_score_count as f32); + + if error_count > 0 { + println!("\n{error_count} examples failed to run!"); + } + + if diff_score_count > 0 { + println!("\nAverage code diff score: {average_diff_score}"); + } + + let thread_score_count = thread_scores.len(); + + // We might have gotten no thread scores if we weren't asked to judge the thread. + if thread_score_count > 0 { + let average_thread_score = thread_scores + .into_iter() + .map(|score| score as f32) + .sum::() + / (thread_score_count as f32); + + if diff_score_count > 0 { + println!("\nAverage thread score: {average_thread_score}"); + } + } std::thread::sleep(std::time::Duration::from_secs(2)); @@ -322,45 +356,11 @@ async fn run_example( let run_output = cx .update(|cx| example.run(model.clone(), app_state.clone(), cx))? .await?; - let diff = example.repository_diff().await?; - - // Run judge for each repetition - let mut results = Vec::new(); - for round in 0..judge_repetitions { - let judge_result = example.judge(model.clone(), diff.clone(), round, cx).await; - if let Ok(judge_output) = &judge_result { - let cohort_id = example - .output_file_path - .parent() - .and_then(|p| p.file_name()) - .map(|name| name.to_string_lossy().to_string()) - .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string()); + let judge_tasks = (0..judge_repetitions) + .map(|round| run_judge_repetition(example.clone(), model.clone(), &run_output, round, cx)); - let path = std::path::Path::new("."); - let commit_id = get_current_commit_id(path).await.unwrap_or_default(); - - telemetry::event!( - "Agent Eval Completed", - cohort_id = cohort_id, - example_name = example.name.clone(), - round = round, - score = judge_output.score, - analysis = judge_output.analysis, - tool_use_counts = run_output.tool_use_counts, - response_count = run_output.response_count, - token_usage = run_output.token_usage, - model = model.telemetry_id(), - model_provider = model.provider_id().to_string(), - repository_url = example.base.url.clone(), - repository_revision = example.base.revision.clone(), - diagnostics_summary = run_output.diagnostics, - commit_id = commit_id - ); - } - - results.push(judge_result); - } + let results = future::join_all(judge_tasks).await; app_state.client.telemetry().flush_events(); @@ -537,3 +537,68 @@ pub fn get_current_commit_id_sync(repo_path: &Path) -> String { get_current_commit_id(repo_path).await.unwrap_or_default() }) } + +async fn run_judge_repetition( + example: Example, + model: Arc, + run_output: &RunOutput, + round: u32, + cx: &AsyncApp, +) -> Result { + let judge_result = example.judge(model.clone(), &run_output, round, cx).await; + + if let Ok(judge_output) = &judge_result { + let cohort_id = example + .run_directory_path + .file_name() + .map(|name| name.to_string_lossy().to_string()) + .unwrap_or(chrono::Local::now().format("%Y-%m-%d_%H-%M-%S").to_string()); + + let path = std::path::Path::new("."); + let commit_id = get_current_commit_id(path).await.unwrap_or_default(); + + if let Some(thread) = &judge_output.thread { + telemetry::event!( + "Agent Eval Completed", + cohort_id = cohort_id, + example_name = example.name.clone(), + round = round, + diff_score = judge_output.diff.score, + diff_analysis = judge_output.diff.analysis, + thread_score = thread.score, + thread_analysis = thread.analysis, + tool_use_counts = run_output.tool_use_counts, + response_count = run_output.response_count, + token_usage = run_output.token_usage, + model = model.telemetry_id(), + model_provider = model.provider_id().to_string(), + repository_url = example.base.url.clone(), + repository_revision = example.base.revision.clone(), + diagnostics_before = run_output.diagnostics_before, + diagnostics_after = run_output.diagnostics_after, + commit_id = commit_id + ); + } else { + telemetry::event!( + "Agent Eval Completed", + cohort_id = cohort_id, + example_name = example.name.clone(), + round = round, + diff_score = judge_output.diff.score, + diff_analysis = judge_output.diff.analysis, + tool_use_counts = run_output.tool_use_counts, + response_count = run_output.response_count, + token_usage = run_output.token_usage, + model = model.telemetry_id(), + model_provider = model.provider_id().to_string(), + repository_url = example.base.url.clone(), + repository_revision = example.base.revision.clone(), + diagnostics_before = run_output.diagnostics_before, + diagnostics_after = run_output.diagnostics_after, + commit_id = commit_id + ); + } + } + + judge_result +} diff --git a/crates/eval/src/example.rs b/crates/eval/src/example.rs index 36f0fe7fdf054eefc59ae9f8f62902aca65305b5..6b6fcb72904c5d9d0a56076dfd7951a1a84b48e8 100644 --- a/crates/eval/src/example.rs +++ b/crates/eval/src/example.rs @@ -10,14 +10,16 @@ use gpui::{App, AppContext as _, AsyncApp, Entity, Task}; use handlebars::Handlebars; use language::{DiagnosticSeverity, OffsetRangeExt}; use language_model::{ - LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, - StopReason, TokenUsage, + LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, + MessageContent, Role, StopReason, TokenUsage, }; use project::{LspStore, Project, ProjectPath}; use serde::{Deserialize, Serialize}; +use std::cell::RefCell; use std::fmt::Write as _; use std::fs::File; use std::io::Write as _; +use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::time::Duration; use std::{ @@ -45,6 +47,19 @@ pub struct ExampleBase { pub insert_id: Option, #[serde(default = "default_true")] pub require_lsp: bool, + #[serde(default)] + pub allow_preexisting_diagnostics: bool, +} + +impl ExampleBase { + pub fn repo_name(&self) -> String { + self.url + .split('/') + .next_back() + .unwrap_or(&"") + .trim_end_matches(".git") + .into() + } } #[derive(Clone, Debug)] @@ -54,14 +69,12 @@ pub struct Example { pub base: ExampleBase, /// Content of `prompt.md` pub prompt: String, - /// Content of `criteria.md` - pub criteria: String, - /// Markdown output file to append to - pub output_file: Option>>, - /// Path to the output run directory. - pub run_dir: PathBuf, - /// Path to markdown output file - pub output_file_path: PathBuf, + /// Content of `diff_criteria.md` + pub diff_criteria: String, + /// Content of `thread_criteria.md`, if that file exists (it's optional) + pub thread_criteria: Option, + /// Path to the directory containing the requests and responses for the agentic loop + pub run_directory_path: PathBuf, /// Prefix used for logging that identifies this example pub log_prefix: String, } @@ -69,41 +82,65 @@ pub struct Example { #[derive(Debug, Serialize, Deserialize, Clone)] pub struct RunOutput { pub repository_diff: String, - pub diagnostics: String, + pub ran_diagnostics_check: bool, + pub diagnostics_before: Option, + pub diagnostics_after: Option, pub response_count: usize, pub token_usage: TokenUsage, pub tool_use_counts: HashMap, u32>, + pub last_request: LanguageModelRequest, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct JudgeInput { +pub struct JudgeDiffInput { pub repository_diff: String, + pub ran_diagnostics_check: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub diagnostics_before: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub diagnostics_after: Option, pub criteria: String, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct JudgeOutput { +pub struct JudgeThreadInput { + pub messages: String, + pub criteria: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeResponse { pub analysis: String, pub score: u32, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JudgeOutput { + pub thread: Option, + pub diff: JudgeResponse, +} + impl Example { /// Load an example from a directory containing base.toml, prompt.md, and criteria.md pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result { let name = Self::name_from_path(dir_path); let base_path = dir_path.join("base.toml"); let prompt_path = dir_path.join("prompt.md"); - let criteria_path = dir_path.join("criteria.md"); - let output_file_path = run_dir.join(format!("{}.md", name)); + let diff_criteria_path = dir_path.join("diff_criteria.md"); + let thread_criteria_path = dir_path.join("thread_criteria.md"); + let thread_criteria = if thread_criteria_path.exists() { + Some(fs::read_to_string(thread_criteria_path.clone())?) + } else { + None + }; Ok(Example { name: name.clone(), base: toml::from_str(&fs::read_to_string(&base_path)?)?, prompt: fs::read_to_string(prompt_path.clone())?, - criteria: fs::read_to_string(criteria_path.clone())?, - run_dir: run_dir.to_path_buf(), - output_file: None, - output_file_path, + thread_criteria, + diff_criteria: fs::read_to_string(diff_criteria_path.clone())?, + run_directory_path: run_dir.to_path_buf(), log_prefix: name, }) } @@ -111,10 +148,13 @@ impl Example { pub fn set_repetition_number(&mut self, repetition_number: u32) { if repetition_number > 0 { self.name = format!("{}-{}", self.name, repetition_number); - self.output_file_path = self.run_dir.join(format!("{}.md", self.name)); } } + pub fn example_output_directory(&self) -> PathBuf { + self.run_directory_path.join(&self.name) + } + pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) { self.log_prefix = format!( "{}{: Arc> { - self.output_file - .clone() - .expect("Output file not created. Call setup() first.") - } - pub fn run( &self, model: Arc, @@ -305,6 +337,11 @@ impl Example { None }; + let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?; + if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics { + return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`")); + } + if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() { return Err(anyhow!("Setup only mode")); } @@ -312,15 +349,32 @@ impl Example { let thread_store = thread_store.await?; let thread = thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?; + let last_request = Rc::new(RefCell::new(None)); - { - let output_file_ref = this.output_file(); - let mut output_file = output_file_ref.lock().unwrap(); - writeln!(&mut output_file, "👤 USER:").log_err(); - writeln!(&mut output_file, "{}", this.prompt).log_err(); - writeln!(&mut output_file, "🤖 ASSISTANT:").log_err(); - output_file.flush().log_err(); - } + thread.update(cx, |thread, _cx| { + let mut request_count = 0; + let example_dir_path = this.example_output_directory(); + + let last_request = Rc::clone(&last_request); + thread.set_request_callback(move |request, response_events| { + *last_request.borrow_mut() = Some(request.clone()); + + request_count += 1; + let messages_file_path = example_dir_path.join(format!("{request_count}.messages.md")); + let last_messages_file_path = example_dir_path.join("last.messages.md"); + let request_markdown = RequestMarkdown::new(request); + let response_events_markdown = response_events_to_markdown(response_events); + + let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown); + fs::write(messages_file_path, messages.clone()).expect("failed to write messages file"); + fs::write(last_messages_file_path, messages).expect("failed to write last messages file"); + + if request_count == 1 { + let tools_file_path = example_dir_path.join("tools.md"); + fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file"); + } + }); + })?; let tool_use_counts: Arc, u32>>> = Mutex::new(HashMap::default()).into(); @@ -332,8 +386,6 @@ impl Example { }); let event_handler_task = cx.spawn({ - // Need to clone the Arc here because the reference from output_file() won't live long enough - let output_file = this.output_file.clone().unwrap(); let log_prefix = this.log_prefix.clone(); let tool_use_counts = tool_use_counts.clone(); let thread = thread.downgrade(); @@ -349,8 +401,6 @@ impl Example { return Err(anyhow!("ThreadEvent channel ended early")); }; - let mut output_file = output_file.lock().unwrap(); - match event { ThreadEvent::Stopped(reason) => match reason { Ok(StopReason::EndTurn) => { @@ -371,18 +421,7 @@ impl Example { ThreadEvent::ShowError(thread_error) => { break Err(anyhow!(thread_error.clone())); } - ThreadEvent::StreamedAssistantText(_, chunk) => { - write!(&mut output_file, "{}", chunk).log_err(); - } - ThreadEvent::StreamedAssistantThinking(_, chunk) => { - write!(&mut output_file, "{}", chunk).log_err(); - } - ThreadEvent::UsePendingTools { tool_uses } => { - writeln!(&mut output_file, "\n\nUSING TOOLS:").log_err(); - for tool_use in tool_uses { - writeln!(&mut output_file, "{}: {}", tool_use.name, tool_use.input) - .log_err(); - } + ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => { } ThreadEvent::ToolFinished { tool_use_id, @@ -398,8 +437,6 @@ impl Example { format!("TOOL FINISHED: {}", tool_use.name) }; println!("{log_prefix}{message}"); - writeln!(&mut output_file, "\n{}", message).log_err(); - writeln!(&mut output_file, "\n{}\n", tool_result.content).log_err(); let mut tool_use_counts = tool_use_counts.lock().unwrap(); *tool_use_counts .entry(tool_result.tool_name.clone()) @@ -407,7 +444,6 @@ impl Example { } else { let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name); println!("{log_prefix}{message}"); - writeln!(&mut output_file, "\n{}", message).log_err(); } } })?; @@ -428,8 +464,6 @@ impl Example { } } } - - output_file.flush().log_err(); } } }); @@ -451,21 +485,35 @@ impl Example { println!("{}Getting repository diff", this.log_prefix); let repository_diff = this.repository_diff().await?; - let repository_diff_path = this.run_dir.join(format!("{}.diff", this.name)); + let example_output_dir = this.example_output_directory(); + let repository_diff_path = example_output_dir.join("patch.diff"); let mut repository_diff_output_file = File::create(&repository_diff_path)?; writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err(); println!("{}Getting diagnostics", this.log_prefix); - let diagnostics = cx + let diagnostics_after = cx .update(move |cx| { cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await) })? .await?; println!("{}Got diagnostics", this.log_prefix); + let Some(last_request) = last_request.borrow_mut().take() else { + return Err(anyhow!("No requests ran.")); + }; + drop(subscription); drop(lsp_open_handle_and_store); + if let Some(diagnostics_before) = &diagnostics_before { + fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?; + } + + if let Some(diagnostics_after) = &diagnostics_after { + fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?; + } + + thread.update(cx, |thread, _cx| { let response_count = thread .messages() @@ -473,31 +521,38 @@ impl Example { .count(); RunOutput { repository_diff, - diagnostics, + ran_diagnostics_check: this.base.require_lsp, + diagnostics_before, + diagnostics_after, response_count, token_usage: thread.cumulative_token_usage(), tool_use_counts: tool_use_counts.lock().unwrap().clone(), + last_request, } }) }) } - pub async fn judge( + async fn judge_diff( &self, model: Arc, - repository_diff: String, - judge_repetitions: u32, + run_output: &RunOutput, + judge_number: u32, cx: &AsyncApp, - ) -> Result { - let judge_prompt = include_str!("judge_prompt.hbs"); - let judge_prompt_name = "judge_prompt"; - let mut handlebars = Handlebars::new(); - handlebars.register_template_string(judge_prompt_name, judge_prompt)?; - let prompt = handlebars.render( - judge_prompt_name, - &JudgeInput { - repository_diff, - criteria: self.criteria.clone(), + ) -> Result<(String, JudgeResponse)> { + let judge_diff_prompt = include_str!("judge_diff_prompt.hbs"); + let judge_diff_prompt_name = "judge_diff_prompt"; + let mut hbs = Handlebars::new(); + hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?; + + let diff_prompt = hbs.render( + judge_diff_prompt_name, + &JudgeDiffInput { + repository_diff: run_output.repository_diff.clone(), + ran_diagnostics_check: run_output.ran_diagnostics_check, + diagnostics_before: run_output.diagnostics_before.clone(), + diagnostics_after: run_output.diagnostics_after.clone(), + criteria: self.diff_criteria.clone(), }, )?; @@ -506,7 +561,7 @@ impl Example { prompt_id: None, messages: vec![LanguageModelRequestMessage { role: Role::User, - content: vec![MessageContent::Text(prompt)], + content: vec![MessageContent::Text(diff_prompt)], cache: false, }], temperature: None, @@ -514,24 +569,106 @@ impl Example { stop: Vec::new(), }; - let response = send_language_model_request(model, request, cx).await?; + let diff_response = send_language_model_request(model, request, cx).await?; + let diff_output = JudgeResponse::parse(&diff_response)?; - let judge_file_path = self.run_dir.join(format!( - "{}_judge_{}.md", - self.name, // This is the eval_name - judge_repetitions - )); + println!( + "{}Judge #{judge_number} - Diff score: {}", + self.log_prefix, diff_output.score + ); - let mut judge_output_file = File::create(&judge_file_path)?; - writeln!(&mut judge_output_file, "{}", &response).log_err(); + Ok((diff_response, diff_output)) + } + + async fn judge_thread( + &self, + model: Arc, + run_output: &RunOutput, + judge_number: u32, + cx: &AsyncApp, + ) -> Result<(String, Option)> { + if let Some(criteria) = self.thread_criteria.clone() { + let judge_thread_prompt = include_str!("judge_thread_prompt.hbs"); + let judge_thread_prompt_name = "judge_thread_prompt"; + let mut hbs = Handlebars::new(); + hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?; + + let request_markdown = RequestMarkdown::new(&run_output.last_request); + let thread_prompt = hbs.render( + judge_thread_prompt_name, + &JudgeThreadInput { + messages: request_markdown.messages, + criteria, + }, + )?; + + let request = LanguageModelRequest { + thread_id: None, + prompt_id: None, + messages: vec![LanguageModelRequestMessage { + role: Role::User, + content: vec![MessageContent::Text(thread_prompt)], + cache: false, + }], + temperature: None, + tools: Vec::new(), + stop: Vec::new(), + }; + + let thread_response = send_language_model_request(model, request, cx).await?; + let thread_output = JudgeResponse::parse(&thread_response)?; + + println!( + "{}Judge #{judge_number} - Thread score: {}", + self.log_prefix, thread_output.score + ); + + Ok((thread_response, Some(thread_output))) + } else { + let msg = "There were no criteria specified for this thread, so this example was not judged on its thread.".to_string(); + Ok((msg, None)) + } + } + + pub async fn judge( + &self, + model: Arc, + run_output: &RunOutput, + judge_number: u32, + cx: &AsyncApp, + ) -> Result { + let mut output_file = File::create( + self.example_output_directory() + .join(format!("judge_{}.md", judge_number)), + ) + .expect("failed to create judge.md"); + + println!("{}Running judge #{judge_number}", self.log_prefix); + + let diff_task = self.judge_diff(model.clone(), &run_output, judge_number, cx); + let thread_task = self.judge_thread(model.clone(), &run_output, judge_number, cx); - parse_judge_output(&response) + let (diff_result, thread_result) = futures::join!(diff_task, thread_task); + + let (diff_response, diff_output) = diff_result?; + let (thread_response, thread_output) = thread_result?; + + writeln!( + &mut output_file, + "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}", + ) + .log_err(); + + Ok(JudgeOutput { + thread: thread_output, + diff: diff_output, + }) } - pub async fn repository_diff(&self) -> Result { + async fn repository_diff(&self) -> Result { let worktree_path = self.worktree_path(); - run_git(&worktree_path, &["add", "-N"]).await?; - run_git(&worktree_path, &["diff"]).await + run_git(&worktree_path, &["add", "."]).await?; + run_git(&worktree_path, &["diff", "--staged"]).await } } @@ -599,7 +736,10 @@ fn has_pending_lang_server_work(lsp_store: &Entity, cx: &App) -> bool .any(|(_, status)| !status.pending_work.is_empty()) } -async fn query_lsp_diagnostics(project: Entity, cx: &mut AsyncApp) -> Result { +async fn query_lsp_diagnostics( + project: Entity, + cx: &mut AsyncApp, +) -> Result> { let paths_with_diagnostics = project.update(cx, |project, cx| { project .diagnostic_summaries(true, cx) @@ -608,6 +748,10 @@ async fn query_lsp_diagnostics(project: Entity, cx: &mut AsyncApp) -> R .collect::>() })?; + if paths_with_diagnostics.is_empty() { + return Ok(None); + } + let mut output = String::new(); for project_path in paths_with_diagnostics { let buffer = project @@ -633,16 +777,18 @@ async fn query_lsp_diagnostics(project: Entity, cx: &mut AsyncApp) -> R )?; } } - anyhow::Ok(output) + anyhow::Ok(Some(output)) } -fn parse_judge_output(response: &str) -> Result { - let analysis = get_tag("analysis", response)?.to_string(); - let score = get_tag("score", response)? - .parse() - .context("error parsing score")?; +impl JudgeResponse { + fn parse(response: &str) -> Result { + let analysis = get_tag("analysis", response)?.to_string(); + let score = get_tag("score", response)? + .parse() + .context("error parsing score")?; - Ok(JudgeOutput { analysis, score }) + Ok(Self { analysis, score }) + } } fn get_tag(name: &'static str, response: &str) -> Result { @@ -724,9 +870,135 @@ pub async fn send_language_model_request( } } +struct RequestMarkdown { + tools: String, + messages: String, +} + +impl RequestMarkdown { + fn new(request: &LanguageModelRequest) -> Self { + let mut tools = String::new(); + let mut messages = String::new(); + + // Print the tools + if !request.tools.is_empty() { + for tool in &request.tools { + write!(&mut tools, "# {}\n\n", tool.name).unwrap(); + write!(&mut tools, "{}\n\n", tool.description).unwrap(); + write!( + &mut tools, + "```json\n{}\n```\n\n", + serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default() + ) + .unwrap(); + } + } + + // Print the messages + for message in &request.messages { + let role_str = match message.role { + Role::User => "👤 USER", + Role::Assistant => "🤖 ASSISTANT", + Role::System => "⚙️ SYSTEM", + }; + + messages.push_str(&format!("# {}\n\n", role_str)); + + for content in &message.content { + match content { + MessageContent::Text(text) => { + messages.push_str(text); + messages.push_str("\n\n"); + } + MessageContent::Image(_) => { + messages.push_str("[IMAGE DATA]\n\n"); + } + MessageContent::ToolUse(tool_use) => { + messages.push_str(&format!( + "**Tool Use**: {} (ID: {})\n", + tool_use.name, tool_use.id + )); + messages.push_str(&format!("```json\n{}\n```\n\n", tool_use.input)); + } + MessageContent::ToolResult(tool_result) => { + messages.push_str(&format!( + "**Tool Result**: {} (ID: {})\n\n", + tool_result.tool_name, tool_result.tool_use_id + )); + if tool_result.is_error { + messages.push_str("**ERROR:**\n"); + } + messages.push_str(&format!("{}\n", tool_result.content)); + } + } + } + } + + Self { tools, messages } + } +} + +fn response_events_to_markdown( + response_events: &[std::result::Result], +) -> String { + let mut response = String::new(); + // Print the response events if any + response.push_str("# Response\n\n"); + let mut text_buffer = String::new(); + let mut thinking_buffer = String::new(); + + let flush_buffers = + |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| { + if !text_buffer.is_empty() { + output.push_str(&format!("**Text**:\n{}\n\n", text_buffer)); + text_buffer.clear(); + } + if !thinking_buffer.is_empty() { + output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer)); + thinking_buffer.clear(); + } + }; + + for event in response_events { + match event { + Ok(LanguageModelCompletionEvent::Text(text)) => { + text_buffer.push_str(text); + } + Ok(LanguageModelCompletionEvent::Thinking(text)) => { + thinking_buffer.push_str(text); + } + Ok(LanguageModelCompletionEvent::Stop(reason)) => { + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + response.push_str(&format!("**Stop**: {:?}\n\n", reason)); + } + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => { + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + response.push_str(&format!( + "**Tool Use**: {} (ID: {})\n", + tool_use.name, tool_use.id + )); + response.push_str(&format!("```json\n{}\n```\n\n", tool_use.input)); + } + Ok( + LanguageModelCompletionEvent::UsageUpdate(_) + | LanguageModelCompletionEvent::StartMessage { .. }, + ) => {} + Err(error) => { + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + response.push_str(&format!("**Error**: {}\n\n", error)); + } + } + } + + flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); + + response +} + #[cfg(test)] mod test { use super::*; + use handlebars::Handlebars; #[test] fn test_parse_judge_output() { @@ -736,7 +1008,7 @@ mod test { "# .unindent(); - let output = parse_judge_output(&response).unwrap(); + let output = JudgeResponse::parse(&response).unwrap(); assert_eq!( output.analysis, "The model did a good job but there were still compilations errors." @@ -756,8 +1028,158 @@ mod test { "# .unindent(); - let output = parse_judge_output(&response).unwrap(); + let output = JudgeResponse::parse(&response).unwrap(); assert_eq!(output.analysis, "Failed to compile:\n- Error 1\n- Error 2"); assert_eq!(output.score, 1); } + + #[test] + fn test_judge_prompt_with_diagnostics() { + // Case 1: Both diagnostics before and after are present + let input = JudgeDiffInput { + repository_diff: "diff content goes here".to_string(), + ran_diagnostics_check: true, + diagnostics_before: Some("Error at line 10: variable not found".to_string()), + diagnostics_after: Some("Error at line 15: missing semicolon".to_string()), + criteria: "Fix all bugs".to_string(), + }; + + let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap(); + + let expected_diagnostics_section = r#" + Take into account the diagnostics before and after applying the change: + + + Error at line 10: variable not found + + + + Error at line 15: missing semicolon + + "# + .unindent(); + + assert!(rendered.contains(&expected_diagnostics_section)); + } + + #[test] + fn test_judge_prompt_with_empty_diagnostics() { + // Case 2: Diagnostics check run but no diagnostics found + let input = JudgeDiffInput { + repository_diff: "diff content goes here".to_string(), + ran_diagnostics_check: true, + diagnostics_before: None, + diagnostics_after: None, + criteria: "Fix all bugs".to_string(), + }; + + let rendered = templates().render(JUDGE_PROMPT_NAME, &input).unwrap(); + + let expected_diagnostics_section = r#" + Take into account the diagnostics before and after applying the change: + + + No diagnostics before applying the edits. + + + + No diagnostics after applying the edits. + + "# + .unindent(); + + assert!(rendered.contains(&expected_diagnostics_section)); + } + + #[test] + fn test_judge_prompt_with_mixed_diagnostics() { + let templates = templates(); + + // Case 3: Before diagnostics present, after diagnostics absent + let input = JudgeDiffInput { + repository_diff: "diff content goes here".to_string(), + ran_diagnostics_check: true, + diagnostics_before: Some("Error at line 10: variable not found".to_string()), + diagnostics_after: None, + criteria: "Fix all bugs".to_string(), + }; + + let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap(); + + let expected_diagnostics_section = r#" + Take into account the diagnostics before and after applying the change: + + + Error at line 10: variable not found + + + + No diagnostics after applying the edits. + + "# + .unindent(); + + assert!(rendered.contains(&expected_diagnostics_section)); + + // Case 4: Before diagnostics absent, after diagnostics present + let input = JudgeDiffInput { + repository_diff: "diff content goes here".to_string(), + ran_diagnostics_check: true, + diagnostics_before: None, + diagnostics_after: Some("Error at line 15: missing semicolon".to_string()), + criteria: "Fix all bugs".to_string(), + }; + + let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap(); + + let expected_diagnostics_section = r#" + Take into account the diagnostics before and after applying the change: + + + No diagnostics before applying the edits. + + + + Error at line 15: missing semicolon + + "# + .unindent(); + + assert!(rendered.contains(&expected_diagnostics_section)); + } + + #[test] + fn test_judge_prompt_without_diagnostics() { + let templates = templates(); + + // Case 5: No diagnostics check run + let input = JudgeDiffInput { + repository_diff: "diff content goes here".to_string(), + ran_diagnostics_check: false, + diagnostics_before: None, + diagnostics_after: None, + criteria: "Fix all bugs".to_string(), + }; + + let rendered = templates.render(JUDGE_PROMPT_NAME, &input).unwrap(); + + // Check for the message when no diagnostics were performed + let diagnostics_message = "No diagnostic checks were performed."; + + assert!(rendered.contains(diagnostics_message)); + assert!(!rendered.contains("")); + assert!(!rendered.contains("")); + } + + const JUDGE_PROMPT_NAME: &str = "judge_prompt"; + + fn templates() -> Handlebars<'static> { + let mut judge_prompt = include_str!("judge_diff_prompt.hbs").to_string(); + language::LineEnding::normalize(&mut judge_prompt); + let mut handlebars = Handlebars::new(); + handlebars + .register_template_string(JUDGE_PROMPT_NAME, judge_prompt) + .unwrap(); + handlebars + } } diff --git a/crates/eval/src/judge_prompt.hbs b/crates/eval/src/judge_diff_prompt.hbs similarity index 64% rename from crates/eval/src/judge_prompt.hbs rename to crates/eval/src/judge_diff_prompt.hbs index 862cc0985c441857b89861873ffd40a03f1eef69..4e9aaf68c0e5d9c1423716a6e652ee32be31dde9 100644 --- a/crates/eval/src/judge_prompt.hbs +++ b/crates/eval/src/judge_diff_prompt.hbs @@ -10,6 +10,28 @@ Use the following criteria to score the above changes. {{criteria}} +{{#if ran_diagnostics_check}} +Take into account the diagnostics before and after applying the change: + + +{{#if diagnostics_before}} +{{{diagnostics_before}}} +{{else}} +No diagnostics before applying the edits. +{{/if}} + + + +{{#if diagnostics_after}} +{{{diagnostics_after}}} +{{else}} +No diagnostics after applying the edits. +{{/if}} + +{{else}} +No diagnostic checks were performed. +{{/if}} + Based on these criteria, give the test output a score between 0 and 5. The output score should ONLY INCLUDE whole numbers. DO NOT return decimals or floats. diff --git a/crates/eval/src/judge_thread_prompt.hbs b/crates/eval/src/judge_thread_prompt.hbs new file mode 100644 index 0000000000000000000000000000000000000000..a84ce8e698c4939a6134a286d234d5195fbc1c6e --- /dev/null +++ b/crates/eval/src/judge_thread_prompt.hbs @@ -0,0 +1,22 @@ +You are an expert software developer tasked with evaluating an AI agent's messages and tool calls in this conversation: + + +{{{messages}}} + + +Use the following criteria to score the above messages. + + +{{criteria}} + + +Based on these criteria, give the messages a score between 0 and 5. +The output score should ONLY INCLUDE whole numbers. DO NOT return decimals or floats. + +- 5 means: messages meet all criteria +- 0 means: messages don't meet any criteria + +``` +{YOUR ANALYSIS HERE} +{YOUR SCORE HERE} +``` diff --git a/crates/prompt_store/src/prompts.rs b/crates/prompt_store/src/prompts.rs index 6b143440c2b5e156789ca02840b5e4f7c076c424..9dd3df3523eb74cb47ce9a3421648f26e8ba5e16 100644 --- a/crates/prompt_store/src/prompts.rs +++ b/crates/prompt_store/src/prompts.rs @@ -60,7 +60,6 @@ pub struct DefaultUserRulesContext { #[derive(Debug, Clone, Serialize)] pub struct WorktreeContext { pub root_name: String, - pub abs_path: Arc, pub rules_file: Option, } @@ -403,7 +402,6 @@ mod test { fn test_assistant_system_prompt_renders() { let worktrees = vec![WorktreeContext { root_name: "path".into(), - abs_path: Path::new("/some/path").into(), rules_file: Some(RulesFileContext { path_in_worktree: Path::new(".rules").into(), abs_path: Path::new("/some/path/.rules").into(), diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index 38e39fd77403a055fdf4c1ade1dc7ad841058fce..b00b1d6b63b5e5ef17d20d9ca522223914969a88 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -806,6 +806,23 @@ impl Worktree { } } + pub fn file_exists(&self, path: &Path, cx: &Context) -> Task> { + match self { + Worktree::Local(this) => { + let fs = this.fs.clone(); + let path = this.absolutize(path); + cx.background_spawn(async move { + let path = path?; + let metadata = fs.metadata(&path).await?; + Ok(metadata.map_or(false, |metadata| !metadata.is_dir)) + }) + } + Worktree::Remote(_) => Task::ready(Err(anyhow!( + "remote worktrees can't yet check file existence" + ))), + } + } + pub fn load_file(&self, path: &Path, cx: &Context) -> Task> { match self { Worktree::Local(this) => this.load_file(path, cx), diff --git a/typos.toml b/typos.toml index 72bc3e8ccf3ed732fb1d28309afde91d38e766ff..4952c61c6ac20227d9be46f8b7b020cbd89667ee 100644 --- a/typos.toml +++ b/typos.toml @@ -45,9 +45,7 @@ extend-exclude = [ # Spellcheck triggers on `|Fixe[sd]|` regex part. "script/danger/dangerfile.ts", # Eval examples for prompts and criteria - "crates/eval/examples/checkpoint_stability/criteria.md", - "crates/eval/examples/tax_id_validation/prompt.md", - "crates/eval/examples/tax_id_validation/criteria.md" + "crates/eval/examples/", ] [default]