diff --git a/Cargo.lock b/Cargo.lock index 731551ee52957c5fa25453f2653159580e615712..601907b9ec8db783f741c2b4174f3e18aebf5e10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -728,9 +728,12 @@ dependencies = [ "itertools 0.14.0", "language", "language_model", + "log", + "lsp", "open", "project", "rand 0.8.5", + "regex", "release_channel", "schemars", "serde", diff --git a/assets/settings/default.json b/assets/settings/default.json index ee1b79a32e050008a61862cae074cfaff7f9198c..46dae6ccecf5ec05f340109842107a456ed237aa 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -653,6 +653,7 @@ "tools": { "bash": true, "batch-tool": true, + "code-symbols": true, "copy-path": true, "create-file": true, "delete-path": true, diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 839da4b9a812fc1552f26c295638fd8d684d22cf..a2f4df4cb0708c3d530faf4c06418c50d85cb01b 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -25,7 +25,10 @@ http_client.workspace = true itertools.workspace = true language.workspace = true language_model.workspace = true +log.workspace = true +lsp.workspace = true project.workspace = true +regex.workspace = true release_channel.workspace = true schemars.workspace = true serde.workspace = true diff --git a/crates/assistant_tools/src/assistant_tools.rs b/crates/assistant_tools/src/assistant_tools.rs index 9c3636e8338fa19c25a24834e899459c425ed83f..2904fee60f3ba3cf07aef3a642b1f0f743cd4502 100644 --- a/crates/assistant_tools/src/assistant_tools.rs +++ b/crates/assistant_tools/src/assistant_tools.rs @@ -1,5 +1,7 @@ mod bash_tool; mod batch_tool; +mod code_symbol_iter; +mod code_symbols_tool; mod copy_path_tool; mod create_directory_tool; mod create_file_tool; @@ -29,6 +31,7 @@ use move_path_tool::MovePathTool; use crate::bash_tool::BashTool; use crate::batch_tool::BatchTool; +use crate::code_symbols_tool::CodeSymbolsTool; use 
crate::create_directory_tool::CreateDirectoryTool;
 use crate::create_file_tool::CreateFileTool;
 use crate::delete_path_tool::DeletePathTool;
@@ -64,6 +67,7 @@ pub fn init(http_client: Arc, cx: &mut App) {
     registry.register_tool(ListDirectoryTool);
     registry.register_tool(NowTool);
     registry.register_tool(OpenTool);
+    registry.register_tool(CodeSymbolsTool);
     registry.register_tool(PathSearchTool);
     registry.register_tool(ReadFileTool);
     registry.register_tool(RegexSearchTool);
diff --git a/crates/assistant_tools/src/code_symbol_iter.rs b/crates/assistant_tools/src/code_symbol_iter.rs
new file mode 100644
index 0000000000000000000000000000000000000000..e982ab4d2a6d6b4ba4c6c87c3fea3bdebea03770
--- /dev/null
+++ b/crates/assistant_tools/src/code_symbol_iter.rs
@@ -0,0 +1,72 @@
+use project::DocumentSymbol;
+use regex::Regex;
+
+/// A single flattened outline entry produced by [`CodeSymbolIterator`].
+#[derive(Debug, Clone)]
+pub struct Entry {
+    /// The symbol's name, copied from the document symbol.
+    pub name: String,
+    /// The LSP kind of the symbol (function, struct, ...).
+    pub kind: lsp::SymbolKind,
+    /// Nesting depth in the symbol tree; top-level symbols have depth 0.
+    pub depth: u32,
+    /// Row on which the symbol's range starts, as reported by the document symbol.
+    pub start_line: usize,
+    /// Row on which the symbol's range ends, as reported by the document symbol.
+    pub end_line: usize,
+}
+
+/// An iterator that filters document symbols based on a regex pattern.
+/// This iterator traverses the document symbol tree in pre-order, incrementing depth for child symbols.
+///
+/// When a regex is given, every symbol at every depth is tested against it. A symbol that
+/// does not match is skipped, but its descendants are still visited, so a matching child
+/// is reported (at its original depth) even if its parent is filtered out.
+#[derive(Debug, Clone)]
+pub struct CodeSymbolIterator<'a> {
+    regex: Option<Regex>,
+    // Stack of (symbol, depth) pairs still to visit; the top of the stack is the next
+    // symbol in pre-order.
+    pending_symbols: Vec<(&'a DocumentSymbol, u32)>,
+}
+
+impl<'a> CodeSymbolIterator<'a> {
+    /// Creates an iterator over `symbols` and all of their descendants.
+    /// Passing `None` for `regex` yields every symbol.
+    pub fn new(symbols: &'a [DocumentSymbol], regex: Option<Regex>) -> Self {
+        Self {
+            regex,
+            // Push in reverse order so the first top-level symbol is popped first.
+            pending_symbols: symbols.iter().rev().map(|symbol| (symbol, 0)).collect(),
+        }
+    }
+}
+
+impl Iterator for CodeSymbolIterator<'_> {
+    type Item = Entry;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        while let Some((symbol, depth)) = self.pending_symbols.pop() {
+            // Queue the children whether or not this symbol matches, so matches nested
+            // under a filtered-out parent are still found. Reverse order keeps source order.
+            self.pending_symbols
+                .extend(symbol.children.iter().rev().map(|child| (child, depth + 1)));
+
+            let matches = self
+                .regex
+                .as_ref()
+                .is_none_or(|regex| regex.is_match(&symbol.name));
+            if matches {
+                return Some(Entry {
+                    name: symbol.name.clone(),
+                    kind: symbol.kind,
+                    depth,
+                    start_line: symbol.range.start.0.row as usize,
+                    end_line: symbol.range.end.0.row as usize,
+                });
+            }
+        }
+
+        None
+    }
+}
diff --git a/crates/assistant_tools/src/code_symbols_tool.rs b/crates/assistant_tools/src/code_symbols_tool.rs
new file mode 100644
index
0000000000000000000000000000000000000000..0e7b7a7c796f71b26927dea4e0feff33becdb343
--- /dev/null
+++ b/crates/assistant_tools/src/code_symbols_tool.rs
@@ -0,0 +1,468 @@
+use std::fmt::{self, Write};
+use std::path::PathBuf;
+use std::sync::Arc;
+
+use anyhow::{anyhow, Result};
+use assistant_tool::{ActionLog, Tool};
+use collections::IndexMap;
+use gpui::{App, AsyncApp, Entity, Task};
+use language::{CodeLabel, Language, LanguageRegistry};
+use language_model::LanguageModelRequestMessage;
+use lsp::SymbolKind;
+use project::{DocumentSymbol, Project, Symbol};
+use regex::{Regex, RegexBuilder};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use ui::IconName;
+use util::markdown::MarkdownString;
+
+use crate::code_symbol_iter::{CodeSymbolIterator, Entry};
+
+// Input accepted by the `code-symbols` tool; deserialized from the JSON given to `run`.
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct CodeSymbolsInput {
+    /// The relative path of the source code file to read and get the symbols for.
+    /// This tool should only be used on source code files, never on any other type of file.
+    ///
+    /// This path should never be absolute, and the first component
+    /// of the path should always be a root directory in a project.
+    ///
+    /// If no path is specified, this tool returns a flat list of all symbols in the project
+    /// instead of a hierarchical outline of a specific file.
+    ///
+    /// <example>
+    /// If the project has the following root directories:
+    ///
+    /// - directory1
+    /// - directory2
+    ///
+    /// If you want to access `file.md` in `directory1`, you should use the path `directory1/file.md`.
+    /// If you want to access `file.md` in `directory2`, you should use the path `directory2/file.md`.
+    /// </example>
+    #[serde(default)]
+    pub path: Option<String>,
+
+    /// Optional regex pattern to filter symbols by name.
+    /// When provided, only symbols whose names match this pattern will be included in the results.
+    ///
+    /// <example>
+    /// To find only symbols that contain the word "test", use the regex pattern "test".
+    /// To find methods that start with "get_", use the regex pattern "^get_".
+    /// </example>
+    #[serde(default)]
+    pub regex: Option<String>,
+
+    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
+    ///
+    /// <example>
+    /// Set to `true` to make regex matching case-sensitive.
+    /// </example>
+    #[serde(default)]
+    pub case_sensitive: bool,
+
+    /// Optional starting position for paginated results (0-based).
+    /// When not provided, starts from the beginning.
+    #[serde(default)]
+    pub offset: u32,
+}
+
+impl CodeSymbolsInput {
+    /// Which page of search results this is.
+    pub fn page(&self) -> u32 {
+        1 + (self.offset / RESULTS_PER_PAGE)
+    }
+}
+
+/// Maximum number of symbols rendered in a single response.
+const RESULTS_PER_PAGE: u32 = 2000;
+
+/// Tool that lists code symbols for a single file or for the whole project.
+pub struct CodeSymbolsTool;
+
+impl Tool for CodeSymbolsTool {
+    fn name(&self) -> String {
+        "code-symbols".into()
+    }
+
+    fn needs_confirmation(&self) -> bool {
+        false
+    }
+
+    fn description(&self) -> String {
+        include_str!("./code_symbols_tool/description.md").into()
+    }
+
+    fn icon(&self) -> IconName {
+        IconName::Eye
+    }
+
+    fn input_schema(&self) -> serde_json::Value {
+        let schema = schemars::schema_for!(CodeSymbolsInput);
+        serde_json::to_value(&schema).unwrap()
+    }
+
+    fn ui_text(&self, input: &serde_json::Value) -> String {
+        match serde_json::from_value::<CodeSymbolsInput>(input.clone()) {
+            Ok(input) => {
+                let page = input.page();
+
+                match &input.path {
+                    Some(path) => {
+                        let path = MarkdownString::inline_code(path);
+                        if page > 1 {
+                            format!("List page {page} of code symbols for {path}")
+                        } else {
+                            format!("List code symbols for {path}")
+                        }
+                    }
+                    None => {
+                        if page > 1 {
+                            format!("List page {page} of project symbols")
+                        } else {
+                            "List all project symbols".to_string()
+                        }
+                    }
+                }
+            }
+            Err(_) => "List code symbols".to_string(),
+        }
+    }
+
+    fn run(
+        self: Arc<Self>,
+        input: serde_json::Value,
+        _messages: &[LanguageModelRequestMessage],
+        project: Entity<Project>,
+        action_log: Entity<ActionLog>,
+        cx: &mut App,
+    ) -> Task<Result<String>> {
+        let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
+            Ok(input) => input,
+            Err(err) => return Task::ready(Err(anyhow!(err))),
+        };
+
+        let regex = match input.regex {
+            Some(regex_str) => match RegexBuilder::new(&regex_str)
+                .case_insensitive(!input.case_sensitive)
+                .build()
+            {
+                Ok(regex) => Some(regex),
+                Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))),
+            },
+            None => None,
+        };
+
+        cx.spawn(async move |cx| match input.path {
+            Some(path) => file_outline(project, path, action_log, regex, input.offset, cx).await,
+            None => project_symbols(project, regex, input.offset, cx).await,
+        })
+    }
+}
+
+/// Renders a hierarchical outline of the symbols in the file at `path`.
+///
+/// Opens the buffer, records the read in `action_log`, and fails if the path is not in the
+/// project or the file yields no document symbols.
+async fn file_outline(
+    project: Entity<Project>,
+    path: String,
+    action_log: Entity<ActionLog>,
+    regex: Option<Regex>,
+    offset: u32,
+    cx: &mut AsyncApp,
+) -> anyhow::Result<String> {
+    let buffer = {
+        let project_path = project.read_with(cx, |project, cx| {
+            project
+                .find_project_path(&path, cx)
+                .ok_or_else(|| anyhow!("Path {path} not found in project"))
+        })??;
+
+        project
+            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+            .await?
+    };
+
+    action_log.update(cx, |action_log, cx| {
+        action_log.buffer_read(buffer.clone(), cx);
+    })?;
+
+    let symbols = project
+        .update(cx, |project, cx| project.document_symbols(&buffer, cx))?
+        .await?;
+
+    if symbols.is_empty() {
+        return Err(
+            if buffer.read_with(cx, |buffer, _| buffer.snapshot().is_empty())? {
+                anyhow!("This file is empty.")
+            } else {
+                anyhow!("No outline information available for this file.")
+            },
+        );
+    }
+
+    let language = buffer.read_with(cx, |buffer, _| buffer.language().cloned())?;
+    let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
+
+    render_outline(&symbols, language, language_registry, regex, offset).await
+}
+
+/// Renders a flat list of project symbols, grouped by the file that contains them.
+///
+/// Symbols are filtered by `regex` (when given) and paginated starting at `offset`.
+async fn project_symbols(
+    project: Entity<Project>,
+    regex: Option<Regex>,
+    offset: u32,
+    cx: &mut AsyncApp,
+) -> anyhow::Result<String> {
+    let symbols = project
+        .update(cx, |project, cx| project.symbols("", cx))?
+        .await?;
+
+    if symbols.is_empty() {
+        return Err(anyhow!("No symbols found in project."));
+    }
+
+    let mut symbols_by_path: IndexMap<PathBuf, Vec<&Symbol>> = IndexMap::default();
+
+    for symbol in symbols
+        .iter()
+        .filter(|symbol| {
+            if let Some(regex) = &regex {
+                regex.is_match(&symbol.name)
+            } else {
+                true
+            }
+        })
+        .skip(offset as usize)
+        // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
+        .take((RESULTS_PER_PAGE as usize).saturating_add(1))
+    {
+        if let Some(worktree_path) = project.read_with(cx, |project, cx| {
+            project
+                .worktree_for_id(symbol.path.worktree_id, cx)
+                .map(|worktree| PathBuf::from(worktree.read(cx).root_name()))
+        })? {
+            let path = worktree_path.join(&symbol.path.path);
+            symbols_by_path.entry(path).or_default().push(symbol);
+        }
+    }
+
+    // If no symbols matched the filter, return early
+    if symbols_by_path.is_empty() {
+        return Err(anyhow!("No symbols found matching the criteria."));
+    }
+
+    let mut symbols_rendered = 0;
+    let mut has_more_symbols = false;
+    let mut output = String::new();
+
+    'outer: for (file_path, file_symbols) in symbols_by_path {
+        // Stop before writing a file heading that would have no symbols under it.
+        if symbols_rendered >= RESULTS_PER_PAGE {
+            has_more_symbols = true;
+            break;
+        }
+
+        if symbols_rendered > 0 {
+            output.push('\n');
+        }
+
+        writeln!(&mut output, "{}", file_path.display()).ok();
+
+        for symbol in file_symbols {
+            if symbols_rendered >= RESULTS_PER_PAGE {
+                has_more_symbols = true;
+                break 'outer;
+            }
+
+            write!(&mut output, " {} ", symbol.label.text()).ok();
+
+            // Convert to 1-based line numbers for display
+            let start_line = symbol.range.start.0.row as usize + 1;
+            let end_line = symbol.range.end.0.row as usize + 1;
+
+            if start_line == end_line {
+                writeln!(&mut output, "[L{}]", start_line).ok();
+            } else {
+                writeln!(&mut output, "[L{}-{}]", start_line, end_line).ok();
+            }
+
+            symbols_rendered += 1;
+        }
+    }
+
+    Ok(if symbols_rendered == 0 {
+        "No symbols found in the requested page.".to_string()
+    } else if has_more_symbols {
+        format!(
+            "{output}\nShowing symbols {}-{} (more symbols were found; use offset: {} to see next page)",
+            offset + 1,
+            offset + symbols_rendered,
+            offset + RESULTS_PER_PAGE,
+        )
+    } else {
+        output
+    })
+}
+
+/// Renders one page of the outline for `symbols`, followed by a pagination footer.
+///
+/// Entries are labeled by the first LSP adapter registered for `language` when there is
+/// one; otherwise they fall back to a generic "<kind> <name>" rendering.
+async fn render_outline(
+    symbols: &[DocumentSymbol],
+    language: Option<Arc<Language>>,
+    registry: Arc<LanguageRegistry>,
+    regex: Option<Regex>,
+    offset: u32,
+) -> Result<String> {
+    const RESULTS_PER_PAGE_USIZE: usize = RESULTS_PER_PAGE as usize;
+    let entries = CodeSymbolIterator::new(symbols, regex)
+        .skip(offset as usize)
+        // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
+        .take(RESULTS_PER_PAGE_USIZE.saturating_add(1))
+        .collect::<Vec<_>>();
+    let has_more = entries.len() > RESULTS_PER_PAGE_USIZE;
+
+    // Nothing matched the filter, or `offset` is past the last symbol. Say so instead of
+    // emitting a footer such as "Showing symbols 1-0".
+    if entries.is_empty() {
+        return Ok("No symbols found in the requested page.".to_string());
+    }
+
+    // Get language-specific labels, if available
+    let labels = match &language {
+        Some(lang) => {
+            let entries_for_labels: Vec<(String, SymbolKind)> = entries
+                .iter()
+                .take(RESULTS_PER_PAGE_USIZE)
+                .map(|entry| (entry.name.clone(), entry.kind))
+                .collect();
+
+            let lang_name = lang.name();
+            if let Some(lsp_adapter) = registry.lsp_adapters(&lang_name).first().cloned() {
+                lsp_adapter
+                    .labels_for_symbols(&entries_for_labels, lang)
+                    .await
+                    .ok()
+            } else {
+                None
+            }
+        }
+        None => None,
+    };
+
+    let mut output = String::new();
+
+    // Pair every entry with its label. When there are no labels (or fewer labels than
+    // entries), the remaining entries get `None` and use the generic rendering rather
+    // than being silently dropped by `zip`.
+    let label_list = labels.as_deref().unwrap_or_default();
+    let entries_rendered = render_entries(
+        &mut output,
+        entries
+            .into_iter()
+            .take(RESULTS_PER_PAGE_USIZE)
+            .zip(
+                label_list
+                    .iter()
+                    .map(|label| label.as_ref())
+                    .chain(std::iter::repeat(None)),
+            ),
+    );
+
+    // Calculate pagination information
+    let page_start = offset + 1;
+    let page_end = offset + entries_rendered;
+
+    // Add pagination information
+    if has_more {
+        writeln!(
+            &mut output,
+            "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
+        )
+    } else {
+        writeln!(
+            &mut output,
+            "\nShowing symbols {page_start}-{page_end} (total symbols: {page_end})",
+        )
+    }
+    .ok();
+
+    Ok(output)
+}
+
+/// Appends one line per entry to `output` and returns the number of entries written.
+fn render_entries<'a>(
+    output: &mut String,
+    entries: impl IntoIterator<Item = (Entry, Option<&'a CodeLabel>)>,
+) -> u32 {
+    let mut entries_rendered = 0;
+
+    for (entry, label) in entries {
+        // Indent based on depth ("" for level 0, " " for level 1, etc.)
+        for _ in 0..entry.depth {
+            output.push_str(" ");
+        }
+
+        match label {
+            Some(label) => {
+                output.push_str(label.text());
+            }
+            None => {
+                write_symbol_kind(output, entry.kind).ok();
+                output.push_str(&entry.name);
+            }
+        }
+
+        // Add position information - convert to 1-based line numbers for display
+        let start_line = entry.start_line + 1;
+        let end_line = entry.end_line + 1;
+
+        if start_line == end_line {
+            writeln!(output, " [L{}]", start_line).ok();
+        } else {
+            writeln!(output, " [L{}-{}]", start_line, end_line).ok();
+        }
+        entries_rendered += 1;
+    }
+
+    entries_rendered
+}
+
+// We may not have a language server adapter to have language-specific
+// ways to translate SymbolKind into a string. In that situation,
+// fall back on some reasonable default strings to render.
+fn write_symbol_kind(buf: &mut String, kind: SymbolKind) -> Result<(), fmt::Error> {
+    match kind {
+        SymbolKind::FILE => write!(buf, "file "),
+        SymbolKind::MODULE => write!(buf, "module "),
+        SymbolKind::NAMESPACE => write!(buf, "namespace "),
+        SymbolKind::PACKAGE => write!(buf, "package "),
+        SymbolKind::CLASS => write!(buf, "class "),
+        SymbolKind::METHOD => write!(buf, "method "),
+        SymbolKind::PROPERTY => write!(buf, "property "),
+        SymbolKind::FIELD => write!(buf, "field "),
+        SymbolKind::CONSTRUCTOR => write!(buf, "constructor "),
+        SymbolKind::ENUM => write!(buf, "enum "),
+        SymbolKind::INTERFACE => write!(buf, "interface "),
+        SymbolKind::FUNCTION => write!(buf, "function "),
+        SymbolKind::VARIABLE => write!(buf, "variable "),
+        SymbolKind::CONSTANT => write!(buf, "constant "),
+        SymbolKind::STRING => write!(buf, "string "),
+        SymbolKind::NUMBER => write!(buf, "number "),
+        SymbolKind::BOOLEAN => write!(buf, "boolean "),
+        SymbolKind::ARRAY => write!(buf, "array "),
+        SymbolKind::OBJECT => write!(buf, "object "),
+        SymbolKind::KEY => write!(buf, "key "),
+        SymbolKind::NULL => write!(buf, "null "),
+        SymbolKind::ENUM_MEMBER => write!(buf, "enum member "),
+        SymbolKind::STRUCT => write!(buf, "struct "),
+        SymbolKind::EVENT => write!(buf, "event "),
+        SymbolKind::OPERATOR => write!(buf, "operator "),
+        SymbolKind::TYPE_PARAMETER => write!(buf, "type parameter "),
+        _ => Ok(()),
+    }
+}
diff --git a/crates/assistant_tools/src/code_symbols_tool/description.md b/crates/assistant_tools/src/code_symbols_tool/description.md
new file mode 100644
index 0000000000000000000000000000000000000000..8916b38797940cd001d5ba3bf3a6c745bbd4f8bb
--- /dev/null
+++ b/crates/assistant_tools/src/code_symbols_tool/description.md
@@ -0,0 +1,39 @@
+Returns either an outline of the public code symbols in the entire project (grouped by file) or else an outline of both the public and private code symbols within a particular file.
+
+When a path is provided, this tool returns a hierarchical outline of code symbols for that specific file.
+When no path is provided, it returns a list of all public code symbols in the project, organized by file.
+
+You can also provide an optional regular expression which filters the output by only showing code symbols which match that regex.
+
+Results are paginated with 2000 entries per page. Use the optional 'offset' parameter to request subsequent pages.
+
+Indentation indicates the structure of the output; just like
+with a nested list, the further a line is indented,
+the deeper it is in the hierarchy.
+
+Each code symbol entry ends with a line number or range, which tells you what portion of the
+underlying source code file corresponds to that part of the outline. You can use
+that line information with other tools, to strategically read portions of the source code.
+
+For example, you can use this tool to find a relevant symbol in the project, then get the outline of the file which contains that symbol, then use the line number information from that file's outline to read different sections of that file, without having to read the entire file all at once (which can be slow, or use a lot of tokens).
+
+<example>
+class Foo [L123-136]
+ method do_something(arg1, arg2) [L124-126]
+ method process_data(data) [L128-135]
+class Bar [L145-162]
+ method initialize() [L146-149]
+ method update_state(new_state) [L160]
+ private method _validate_state(state) [L161-162]
+</example>
+
+This example shows how the outline reflects the structure of source code:
+
+1. `class Foo` is defined on lines 123-136
+ - It contains a method `do_something` spanning lines 124-126
+ - It also has a method `process_data` spanning lines 128-135
+
+2. `class Bar` is defined on lines 145-162
+ - It has an `initialize` method spanning lines 146-149
+ - It has an `update_state` method on line 160
+ - It has a private method `_validate_state` spanning lines 161-162
diff --git a/crates/assistant_tools/src/path_search_tool.rs b/crates/assistant_tools/src/path_search_tool.rs index c3bdfbe4e02eba872f6018aeaf177209a648166b..cae5749b0ed41999ca64437afd49608eb83603fe 100644 --- a/crates/assistant_tools/src/path_search_tool.rs +++ b/crates/assistant_tools/src/path_search_tool.rs @@ -28,7 +28,7 @@ pub struct PathSearchToolInput { /// Optional starting position for paginated results (0-based). /// When not provided, starts from the beginning. #[serde(default)] - pub offset: Option, + pub offset: u32, } const RESULTS_PER_PAGE: usize = 50; @@ -73,7 +73,7 @@ impl Tool for PathSearchTool { cx: &mut App, ) -> Task> { let (offset, glob) = match serde_json::from_value::(input) { - Ok(input) => (input.offset.unwrap_or(0), input.glob), + Ok(input) => (input.offset, input.glob), Err(err) => return Task::ready(Err(anyhow!(err))), }; @@ -116,10 +116,10 @@ impl Tool for PathSearchTool { matches.sort(); let total_matches = matches.len(); - let response = if total_matches > offset + RESULTS_PER_PAGE { - let paginated_matches: Vec<_> = matches + let response = if total_matches > RESULTS_PER_PAGE + offset as usize { + let paginated_matches: Vec<_> = matches .into_iter() - .skip(offset) + .skip(offset as usize) .take(RESULTS_PER_PAGE) .collect(); @@ -127,7 +127,7 @@ impl Tool for PathSearchTool { "Found {} total matches. 
Showing results {}-{} (provide 'offset' parameter for more results):\n\n{}", total_matches, offset + 1, - offset + paginated_matches.len(), + offset as usize + paginated_matches.len(), paginated_matches.join("\n") ) } else { diff --git a/crates/assistant_tools/src/regex_search_tool.rs b/crates/assistant_tools/src/regex_search_tool.rs index 635f5439c6c3c53e75ab4c8c3b1a560c806806c2..29452c4cb7375567547fdf4a9d9133561ba5f5ed 100644 --- a/crates/assistant_tools/src/regex_search_tool.rs +++ b/crates/assistant_tools/src/regex_search_tool.rs @@ -24,13 +24,13 @@ pub struct RegexSearchToolInput { /// Optional starting position for paginated results (0-based). /// When not provided, starts from the beginning. #[serde(default)] - pub offset: Option, + pub offset: u32, } impl RegexSearchToolInput { /// Which page of search results this is. pub fn page(&self) -> u32 { - 1 + (self.offset.unwrap_or(0) / RESULTS_PER_PAGE) + 1 + (self.offset / RESULTS_PER_PAGE) } } @@ -87,7 +87,7 @@ impl Tool for RegexSearchTool { const CONTEXT_LINES: u32 = 2; let (offset, regex) = match serde_json::from_value::(input) { - Ok(input) => (input.offset.unwrap_or(0), input.regex), + Ok(input) => (input.offset, input.regex), Err(err) => return Task::ready(Err(anyhow!(err))), };