code_symbols_tool.rs

  1use std::fmt::Write;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use crate::schema::json_schema_for;
  6use anyhow::{Result, anyhow};
  7use assistant_tool::{ActionLog, Tool, ToolResult};
  8use collections::IndexMap;
  9use gpui::{App, AsyncApp, Entity, Task};
 10use language::{OutlineItem, ParseStatus, Point};
 11use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 12use project::{Project, Symbol};
 13use regex::{Regex, RegexBuilder};
 14use schemars::JsonSchema;
 15use serde::{Deserialize, Serialize};
 16use ui::IconName;
 17use util::markdown::MarkdownString;
 18
 19#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 20pub struct CodeSymbolsInput {
 21    /// The relative path of the source code file to read and get the symbols for.
 22    /// This tool should only be used on source code files, never on any other type of file.
 23    ///
 24    /// This path should never be absolute, and the first component
 25    /// of the path should always be a root directory in a project.
 26    ///
 27    /// If no path is specified, this tool returns a flat list of all symbols in the project
 28    /// instead of a hierarchical outline of a specific file.
 29    ///
 30    /// <example>
 31    /// If the project has the following root directories:
 32    ///
 33    /// - directory1
 34    /// - directory2
 35    ///
 36    /// If you want to access `file.md` in `directory1`, you should use the path `directory1/file.md`.
 37    /// If you want to access `file.md` in `directory2`, you should use the path `directory2/file.md`.
 38    /// </example>
 39    #[serde(default)]
 40    pub path: Option<String>,
 41
 42    /// Optional regex pattern to filter symbols by name.
 43    /// When provided, only symbols whose names match this pattern will be included in the results.
 44    ///
 45    /// <example>
 46    /// To find only symbols that contain the word "test", use the regex pattern "test".
 47    /// To find methods that start with "get_", use the regex pattern "^get_".
 48    /// </example>
 49    #[serde(default)]
 50    pub regex: Option<String>,
 51
 52    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
 53    ///
 54    /// <example>
 55    /// Set to `true` to make regex matching case-sensitive.
 56    /// </example>
 57    #[serde(default)]
 58    pub case_sensitive: bool,
 59
 60    /// Optional starting position for paginated results (0-based).
 61    /// When not provided, starts from the beginning.
 62    #[serde(default)]
 63    pub offset: u32,
 64}
 65
 66impl CodeSymbolsInput {
 67    /// Which page of search results this is.
 68    pub fn page(&self) -> u32 {
 69        1 + (self.offset / RESULTS_PER_PAGE)
 70    }
 71}
 72
 73const RESULTS_PER_PAGE: u32 = 2000;
 74
 75pub struct CodeSymbolsTool;
 76
 77impl Tool for CodeSymbolsTool {
 78    fn name(&self) -> String {
 79        "code_symbols".into()
 80    }
 81
 82    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 83        false
 84    }
 85
 86    fn description(&self) -> String {
 87        include_str!("./code_symbols_tool/description.md").into()
 88    }
 89
 90    fn icon(&self) -> IconName {
 91        IconName::Code
 92    }
 93
 94    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 95        json_schema_for::<CodeSymbolsInput>(format)
 96    }
 97
 98    fn ui_text(&self, input: &serde_json::Value) -> String {
 99        match serde_json::from_value::<CodeSymbolsInput>(input.clone()) {
100            Ok(input) => {
101                let page = input.page();
102
103                match &input.path {
104                    Some(path) => {
105                        let path = MarkdownString::inline_code(path);
106                        if page > 1 {
107                            format!("List page {page} of code symbols for {path}")
108                        } else {
109                            format!("List code symbols for {path}")
110                        }
111                    }
112                    None => {
113                        if page > 1 {
114                            format!("List page {page} of project symbols")
115                        } else {
116                            "List all project symbols".to_string()
117                        }
118                    }
119                }
120            }
121            Err(_) => "List code symbols".to_string(),
122        }
123    }
124
125    fn run(
126        self: Arc<Self>,
127        input: serde_json::Value,
128        _messages: &[LanguageModelRequestMessage],
129        project: Entity<Project>,
130        action_log: Entity<ActionLog>,
131        cx: &mut App,
132    ) -> ToolResult {
133        let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
134            Ok(input) => input,
135            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
136        };
137
138        let regex = match input.regex {
139            Some(regex_str) => match RegexBuilder::new(&regex_str)
140                .case_insensitive(!input.case_sensitive)
141                .build()
142            {
143                Ok(regex) => Some(regex),
144                Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))).into(),
145            },
146            None => None,
147        };
148
149        cx.spawn(async move |cx| match input.path {
150            Some(path) => file_outline(project, path, action_log, regex, input.offset, cx).await,
151            None => project_symbols(project, regex, input.offset, cx).await,
152        })
153        .into()
154    }
155}
156
157pub async fn file_outline(
158    project: Entity<Project>,
159    path: String,
160    action_log: Entity<ActionLog>,
161    regex: Option<Regex>,
162    offset: u32,
163    cx: &mut AsyncApp,
164) -> anyhow::Result<String> {
165    let buffer = {
166        let project_path = project.read_with(cx, |project, cx| {
167            project
168                .find_project_path(&path, cx)
169                .ok_or_else(|| anyhow!("Path {path} not found in project"))
170        })??;
171
172        project
173            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
174            .await?
175    };
176
177    action_log.update(cx, |action_log, cx| {
178        action_log.buffer_read(buffer.clone(), cx);
179    })?;
180
181    // Wait until the buffer has been fully parsed, so that we can read its outline.
182    let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
183    while *parse_status.borrow() != ParseStatus::Idle {
184        parse_status.changed().await?;
185    }
186
187    let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
188    let Some(outline) = snapshot.outline(None) else {
189        return Err(anyhow!("No outline information available for this file."));
190    };
191
192    render_outline(
193        outline
194            .items
195            .into_iter()
196            .map(|item| item.to_point(&snapshot)),
197        regex,
198        offset,
199    )
200    .await
201}
202
203async fn project_symbols(
204    project: Entity<Project>,
205    regex: Option<Regex>,
206    offset: u32,
207    cx: &mut AsyncApp,
208) -> anyhow::Result<String> {
209    let symbols = project
210        .update(cx, |project, cx| project.symbols("", cx))?
211        .await?;
212
213    if symbols.is_empty() {
214        return Err(anyhow!("No symbols found in project."));
215    }
216
217    let mut symbols_by_path: IndexMap<PathBuf, Vec<&Symbol>> = IndexMap::default();
218
219    for symbol in symbols
220        .iter()
221        .filter(|symbol| {
222            if let Some(regex) = &regex {
223                regex.is_match(&symbol.name)
224            } else {
225                true
226            }
227        })
228        .skip(offset as usize)
229        // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
230        .take((RESULTS_PER_PAGE as usize).saturating_add(1))
231    {
232        if let Some(worktree_path) = project.read_with(cx, |project, cx| {
233            project
234                .worktree_for_id(symbol.path.worktree_id, cx)
235                .map(|worktree| PathBuf::from(worktree.read(cx).root_name()))
236        })? {
237            let path = worktree_path.join(&symbol.path.path);
238            symbols_by_path.entry(path).or_default().push(symbol);
239        }
240    }
241
242    // If no symbols matched the filter, return early
243    if symbols_by_path.is_empty() {
244        return Err(anyhow!("No symbols found matching the criteria."));
245    }
246
247    let mut symbols_rendered = 0;
248    let mut has_more_symbols = false;
249    let mut output = String::new();
250
251    'outer: for (file_path, file_symbols) in symbols_by_path {
252        if symbols_rendered > 0 {
253            output.push('\n');
254        }
255
256        writeln!(&mut output, "{}", file_path.display()).ok();
257
258        for symbol in file_symbols {
259            if symbols_rendered >= RESULTS_PER_PAGE {
260                has_more_symbols = true;
261                break 'outer;
262            }
263
264            write!(&mut output, "  {} ", symbol.label.text()).ok();
265
266            // Convert to 1-based line numbers for display
267            let start_line = symbol.range.start.0.row as usize + 1;
268            let end_line = symbol.range.end.0.row as usize + 1;
269
270            if start_line == end_line {
271                writeln!(&mut output, "[L{}]", start_line).ok();
272            } else {
273                writeln!(&mut output, "[L{}-{}]", start_line, end_line).ok();
274            }
275
276            symbols_rendered += 1;
277        }
278    }
279
280    Ok(if symbols_rendered == 0 {
281        "No symbols found in the requested page.".to_string()
282    } else if has_more_symbols {
283        format!(
284            "{output}\nShowing symbols {}-{} (more symbols were found; use offset: {} to see next page)",
285            offset + 1,
286            offset + symbols_rendered,
287            offset + RESULTS_PER_PAGE,
288        )
289    } else {
290        output
291    })
292}
293
294async fn render_outline(
295    items: impl IntoIterator<Item = OutlineItem<Point>>,
296    regex: Option<Regex>,
297    offset: u32,
298) -> Result<String> {
299    const RESULTS_PER_PAGE_USIZE: usize = RESULTS_PER_PAGE as usize;
300
301    let mut items = items.into_iter().skip(offset as usize);
302
303    let entries = items
304        .by_ref()
305        .filter(|item| {
306            regex
307                .as_ref()
308                .is_none_or(|regex| regex.is_match(&item.text))
309        })
310        .take(RESULTS_PER_PAGE_USIZE)
311        .collect::<Vec<_>>();
312    let has_more = items.next().is_some();
313
314    let mut output = String::new();
315    let entries_rendered = render_entries(&mut output, entries);
316
317    // Calculate pagination information
318    let page_start = offset + 1;
319    let page_end = offset + entries_rendered;
320    let total_symbols = if has_more {
321        format!("more than {}", page_end)
322    } else {
323        page_end.to_string()
324    };
325
326    // Add pagination information
327    if has_more {
328        writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
329        )
330    } else {
331        writeln!(
332            &mut output,
333            "\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
334        )
335    }
336    .ok();
337
338    Ok(output)
339}
340
341fn render_entries(output: &mut String, items: impl IntoIterator<Item = OutlineItem<Point>>) -> u32 {
342    let mut entries_rendered = 0;
343
344    for item in items {
345        // Indent based on depth ("" for level 0, "  " for level 1, etc.)
346        for _ in 0..item.depth {
347            output.push(' ');
348        }
349        output.push_str(&item.text);
350
351        // Add position information - convert to 1-based line numbers for display
352        let start_line = item.range.start.row + 1;
353        let end_line = item.range.end.row + 1;
354
355        if start_line == end_line {
356            writeln!(output, " [L{}]", start_line).ok();
357        } else {
358            writeln!(output, " [L{}-{}]", start_line, end_line).ok();
359        }
360        entries_rendered += 1;
361    }
362
363    entries_rendered
364}