code_symbols_tool.rs

  1use std::fmt::{self, Write};
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use assistant_tool::{ActionLog, Tool};
  7use collections::IndexMap;
  8use gpui::{App, AsyncApp, Entity, Task};
  9use language::{CodeLabel, Language, LanguageRegistry};
 10use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 11use lsp::SymbolKind;
 12use project::{DocumentSymbol, Project, Symbol};
 13use regex::{Regex, RegexBuilder};
 14use schemars::JsonSchema;
 15use serde::{Deserialize, Serialize};
 16use ui::IconName;
 17use util::markdown::MarkdownString;
 18
 19use crate::code_symbol_iter::{CodeSymbolIterator, Entry};
 20use crate::schema::json_schema_for;
 21
 22#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 23pub struct CodeSymbolsInput {
 24    /// The relative path of the source code file to read and get the symbols for.
 25    /// This tool should only be used on source code files, never on any other type of file.
 26    ///
 27    /// This path should never be absolute, and the first component
 28    /// of the path should always be a root directory in a project.
 29    ///
 30    /// If no path is specified, this tool returns a flat list of all symbols in the project
 31    /// instead of a hierarchical outline of a specific file.
 32    ///
 33    /// <example>
 34    /// If the project has the following root directories:
 35    ///
 36    /// - directory1
 37    /// - directory2
 38    ///
 39    /// If you want to access `file.md` in `directory1`, you should use the path `directory1/file.md`.
 40    /// If you want to access `file.md` in `directory2`, you should use the path `directory2/file.md`.
 41    /// </example>
 42    #[serde(default)]
 43    pub path: Option<String>,
 44
 45    /// Optional regex pattern to filter symbols by name.
 46    /// When provided, only symbols whose names match this pattern will be included in the results.
 47    ///
 48    /// <example>
 49    /// To find only symbols that contain the word "test", use the regex pattern "test".
 50    /// To find methods that start with "get_", use the regex pattern "^get_".
 51    /// </example>
 52    #[serde(default)]
 53    pub regex: Option<String>,
 54
 55    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
 56    ///
 57    /// <example>
 58    /// Set to `true` to make regex matching case-sensitive.
 59    /// </example>
 60    #[serde(default)]
 61    pub case_sensitive: bool,
 62
 63    /// Optional starting position for paginated results (0-based).
 64    /// When not provided, starts from the beginning.
 65    #[serde(default)]
 66    pub offset: u32,
 67}
 68
 69impl CodeSymbolsInput {
 70    /// Which page of search results this is.
 71    pub fn page(&self) -> u32 {
 72        1 + (self.offset / RESULTS_PER_PAGE)
 73    }
 74}
 75
 76const RESULTS_PER_PAGE: u32 = 2000;
 77
 78pub struct CodeSymbolsTool;
 79
 80impl Tool for CodeSymbolsTool {
 81    fn name(&self) -> String {
 82        "code_symbols".into()
 83    }
 84
 85    fn needs_confirmation(&self) -> bool {
 86        false
 87    }
 88
 89    fn description(&self) -> String {
 90        include_str!("./code_symbols_tool/description.md").into()
 91    }
 92
 93    fn icon(&self) -> IconName {
 94        IconName::Code
 95    }
 96
 97    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
 98        json_schema_for::<CodeSymbolsInput>(format)
 99    }
100
101    fn ui_text(&self, input: &serde_json::Value) -> String {
102        match serde_json::from_value::<CodeSymbolsInput>(input.clone()) {
103            Ok(input) => {
104                let page = input.page();
105
106                match &input.path {
107                    Some(path) => {
108                        let path = MarkdownString::inline_code(path);
109                        if page > 1 {
110                            format!("List page {page} of code symbols for {path}")
111                        } else {
112                            format!("List code symbols for {path}")
113                        }
114                    }
115                    None => {
116                        if page > 1 {
117                            format!("List page {page} of project symbols")
118                        } else {
119                            "List all project symbols".to_string()
120                        }
121                    }
122                }
123            }
124            Err(_) => "List code symbols".to_string(),
125        }
126    }
127
128    fn run(
129        self: Arc<Self>,
130        input: serde_json::Value,
131        _messages: &[LanguageModelRequestMessage],
132        project: Entity<Project>,
133        action_log: Entity<ActionLog>,
134        cx: &mut App,
135    ) -> Task<Result<String>> {
136        let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
137            Ok(input) => input,
138            Err(err) => return Task::ready(Err(anyhow!(err))),
139        };
140
141        let regex = match input.regex {
142            Some(regex_str) => match RegexBuilder::new(&regex_str)
143                .case_insensitive(!input.case_sensitive)
144                .build()
145            {
146                Ok(regex) => Some(regex),
147                Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))),
148            },
149            None => None,
150        };
151
152        cx.spawn(async move |cx| match input.path {
153            Some(path) => file_outline(project, path, action_log, regex, input.offset, cx).await,
154            None => project_symbols(project, regex, input.offset, cx).await,
155        })
156    }
157}
158
159pub async fn file_outline(
160    project: Entity<Project>,
161    path: String,
162    action_log: Entity<ActionLog>,
163    regex: Option<Regex>,
164    offset: u32,
165    cx: &mut AsyncApp,
166) -> anyhow::Result<String> {
167    let buffer = {
168        let project_path = project.read_with(cx, |project, cx| {
169            project
170                .find_project_path(&path, cx)
171                .ok_or_else(|| anyhow!("Path {path} not found in project"))
172        })??;
173
174        project
175            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
176            .await?
177    };
178
179    action_log.update(cx, |action_log, cx| {
180        action_log.buffer_read(buffer.clone(), cx);
181    })?;
182
183    let symbols = project
184        .update(cx, |project, cx| project.document_symbols(&buffer, cx))?
185        .await?;
186
187    if symbols.is_empty() {
188        return Err(
189            if buffer.read_with(cx, |buffer, _| buffer.snapshot().is_empty())? {
190                anyhow!("This file is empty.")
191            } else {
192                anyhow!("No outline information available for this file.")
193            },
194        );
195    }
196
197    let language = buffer.read_with(cx, |buffer, _| buffer.language().cloned())?;
198    let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
199
200    render_outline(&symbols, language, language_registry, regex, offset).await
201}
202
203async fn project_symbols(
204    project: Entity<Project>,
205    regex: Option<Regex>,
206    offset: u32,
207    cx: &mut AsyncApp,
208) -> anyhow::Result<String> {
209    let symbols = project
210        .update(cx, |project, cx| project.symbols("", cx))?
211        .await?;
212
213    if symbols.is_empty() {
214        return Err(anyhow!("No symbols found in project."));
215    }
216
217    let mut symbols_by_path: IndexMap<PathBuf, Vec<&Symbol>> = IndexMap::default();
218
219    for symbol in symbols
220        .iter()
221        .filter(|symbol| {
222            if let Some(regex) = &regex {
223                regex.is_match(&symbol.name)
224            } else {
225                true
226            }
227        })
228        .skip(offset as usize)
229        // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
230        .take((RESULTS_PER_PAGE as usize).saturating_add(1))
231    {
232        if let Some(worktree_path) = project.read_with(cx, |project, cx| {
233            project
234                .worktree_for_id(symbol.path.worktree_id, cx)
235                .map(|worktree| PathBuf::from(worktree.read(cx).root_name()))
236        })? {
237            let path = worktree_path.join(&symbol.path.path);
238            symbols_by_path.entry(path).or_default().push(symbol);
239        }
240    }
241
242    // If no symbols matched the filter, return early
243    if symbols_by_path.is_empty() {
244        return Err(anyhow!("No symbols found matching the criteria."));
245    }
246
247    let mut symbols_rendered = 0;
248    let mut has_more_symbols = false;
249    let mut output = String::new();
250
251    'outer: for (file_path, file_symbols) in symbols_by_path {
252        if symbols_rendered > 0 {
253            output.push('\n');
254        }
255
256        writeln!(&mut output, "{}", file_path.display()).ok();
257
258        for symbol in file_symbols {
259            if symbols_rendered >= RESULTS_PER_PAGE {
260                has_more_symbols = true;
261                break 'outer;
262            }
263
264            write!(&mut output, "  {} ", symbol.label.text()).ok();
265
266            // Convert to 1-based line numbers for display
267            let start_line = symbol.range.start.0.row as usize + 1;
268            let end_line = symbol.range.end.0.row as usize + 1;
269
270            if start_line == end_line {
271                writeln!(&mut output, "[L{}]", start_line).ok();
272            } else {
273                writeln!(&mut output, "[L{}-{}]", start_line, end_line).ok();
274            }
275
276            symbols_rendered += 1;
277        }
278    }
279
280    Ok(if symbols_rendered == 0 {
281        "No symbols found in the requested page.".to_string()
282    } else if has_more_symbols {
283        format!(
284            "{output}\nShowing symbols {}-{} (more symbols were found; use offset: {} to see next page)",
285            offset + 1,
286            offset + symbols_rendered,
287            offset + RESULTS_PER_PAGE,
288        )
289    } else {
290        output
291    })
292}
293
294async fn render_outline(
295    symbols: &[DocumentSymbol],
296    language: Option<Arc<Language>>,
297    registry: Arc<LanguageRegistry>,
298    regex: Option<Regex>,
299    offset: u32,
300) -> Result<String> {
301    const RESULTS_PER_PAGE_USIZE: usize = RESULTS_PER_PAGE as usize;
302    let entries = CodeSymbolIterator::new(symbols, regex.clone())
303        .skip(offset as usize)
304        // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
305        .take(RESULTS_PER_PAGE_USIZE.saturating_add(1))
306        .collect::<Vec<Entry>>();
307    let has_more = entries.len() > RESULTS_PER_PAGE_USIZE;
308
309    // Get language-specific labels, if available
310    let labels = match &language {
311        Some(lang) => {
312            let entries_for_labels: Vec<(String, SymbolKind)> = entries
313                .iter()
314                .take(RESULTS_PER_PAGE_USIZE)
315                .map(|entry| (entry.name.clone(), entry.kind))
316                .collect();
317
318            let lang_name = lang.name();
319            if let Some(lsp_adapter) = registry.lsp_adapters(&lang_name).first().cloned() {
320                lsp_adapter
321                    .labels_for_symbols(&entries_for_labels, lang)
322                    .await
323                    .ok()
324            } else {
325                None
326            }
327        }
328        None => None,
329    };
330
331    let mut output = String::new();
332
333    let entries_rendered = match &labels {
334        Some(label_list) => render_entries(
335            &mut output,
336            entries
337                .into_iter()
338                .take(RESULTS_PER_PAGE_USIZE)
339                .zip(label_list.iter())
340                .map(|(entry, label)| (entry, label.as_ref())),
341        ),
342        None => render_entries(
343            &mut output,
344            entries
345                .into_iter()
346                .take(RESULTS_PER_PAGE_USIZE)
347                .map(|entry| (entry, None)),
348        ),
349    };
350
351    // Calculate pagination information
352    let page_start = offset + 1;
353    let page_end = offset + entries_rendered;
354    let total_symbols = if has_more {
355        format!("more than {}", page_end)
356    } else {
357        page_end.to_string()
358    };
359
360    // Add pagination information
361    if has_more {
362        writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
363        )
364    } else {
365        writeln!(
366            &mut output,
367            "\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
368        )
369    }
370    .ok();
371
372    Ok(output)
373}
374
375fn render_entries<'a>(
376    output: &mut String,
377    entries: impl IntoIterator<Item = (Entry, Option<&'a CodeLabel>)>,
378) -> u32 {
379    let mut entries_rendered = 0;
380
381    for (entry, label) in entries {
382        // Indent based on depth ("" for level 0, "  " for level 1, etc.)
383        for _ in 0..entry.depth {
384            output.push_str("  ");
385        }
386
387        match label {
388            Some(label) => {
389                output.push_str(label.text());
390            }
391            None => {
392                write_symbol_kind(output, entry.kind).ok();
393                output.push_str(&entry.name);
394            }
395        }
396
397        // Add position information - convert to 1-based line numbers for display
398        let start_line = entry.start_line + 1;
399        let end_line = entry.end_line + 1;
400
401        if start_line == end_line {
402            writeln!(output, " [L{}]", start_line).ok();
403        } else {
404            writeln!(output, " [L{}-{}]", start_line, end_line).ok();
405        }
406        entries_rendered += 1;
407    }
408
409    entries_rendered
410}
411
412// We may not have a language server adapter to have language-specific
413// ways to translate SymbolKnd into a string. In that situation,
414// fall back on some reasonable default strings to render.
415fn write_symbol_kind(buf: &mut String, kind: SymbolKind) -> Result<(), fmt::Error> {
416    match kind {
417        SymbolKind::FILE => write!(buf, "file "),
418        SymbolKind::MODULE => write!(buf, "module "),
419        SymbolKind::NAMESPACE => write!(buf, "namespace "),
420        SymbolKind::PACKAGE => write!(buf, "package "),
421        SymbolKind::CLASS => write!(buf, "class "),
422        SymbolKind::METHOD => write!(buf, "method "),
423        SymbolKind::PROPERTY => write!(buf, "property "),
424        SymbolKind::FIELD => write!(buf, "field "),
425        SymbolKind::CONSTRUCTOR => write!(buf, "constructor "),
426        SymbolKind::ENUM => write!(buf, "enum "),
427        SymbolKind::INTERFACE => write!(buf, "interface "),
428        SymbolKind::FUNCTION => write!(buf, "function "),
429        SymbolKind::VARIABLE => write!(buf, "variable "),
430        SymbolKind::CONSTANT => write!(buf, "constant "),
431        SymbolKind::STRING => write!(buf, "string "),
432        SymbolKind::NUMBER => write!(buf, "number "),
433        SymbolKind::BOOLEAN => write!(buf, "boolean "),
434        SymbolKind::ARRAY => write!(buf, "array "),
435        SymbolKind::OBJECT => write!(buf, "object "),
436        SymbolKind::KEY => write!(buf, "key "),
437        SymbolKind::NULL => write!(buf, "null "),
438        SymbolKind::ENUM_MEMBER => write!(buf, "enum member "),
439        SymbolKind::STRUCT => write!(buf, "struct "),
440        SymbolKind::EVENT => write!(buf, "event "),
441        SymbolKind::OPERATOR => write!(buf, "operator "),
442        SymbolKind::TYPE_PARAMETER => write!(buf, "type parameter "),
443        _ => Ok(()),
444    }
445}