Use tree-sitter when returning symbols to the model for a given file (#28352)

Antonio Scandurra created

This also increases the threshold for when we return an outline during
`read_file`.

Release Notes:

- Fixed an issue that caused the agent to fail reading large files if
the LSP hadn't started yet.

Change summary

Cargo.lock                                      |   1 
crates/assistant_tools/Cargo.toml               |   1 
crates/assistant_tools/src/assistant_tools.rs   |   1 
crates/assistant_tools/src/code_symbol_iter.rs  |  88 ---------
crates/assistant_tools/src/code_symbols_tool.rs | 174 +++++-------------
crates/assistant_tools/src/read_file_tool.rs    |   5 
6 files changed, 49 insertions(+), 221 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -746,7 +746,6 @@ dependencies = [
  "itertools 0.14.0",
  "language",
  "language_model",
- "lsp",
  "open",
  "project",
  "rand 0.8.5",

crates/assistant_tools/Cargo.toml 🔗

@@ -23,7 +23,6 @@ http_client.workspace = true
 itertools.workspace = true
 language.workspace = true
 language_model.workspace = true
-lsp.workspace = true
 project.workspace = true
 regex.workspace = true
 schemars.workspace = true

crates/assistant_tools/src/code_symbol_iter.rs 🔗

@@ -1,88 +0,0 @@
-use project::DocumentSymbol;
-use regex::Regex;
-
-#[derive(Debug, Clone)]
-pub struct Entry {
-    pub name: String,
-    pub kind: lsp::SymbolKind,
-    pub depth: u32,
-    pub start_line: usize,
-    pub end_line: usize,
-}
-
-/// An iterator that filters document symbols based on a regex pattern.
-/// This iterator recursively traverses the document symbol tree, incrementing depth for child symbols.
-#[derive(Debug, Clone)]
-pub struct CodeSymbolIterator<'a> {
-    symbols: &'a [DocumentSymbol],
-    regex: Option<Regex>,
-    // Stack of (symbol, depth) pairs to process
-    pending_symbols: Vec<(&'a DocumentSymbol, u32)>,
-    current_index: usize,
-    current_depth: u32,
-}
-
-impl<'a> CodeSymbolIterator<'a> {
-    pub fn new(symbols: &'a [DocumentSymbol], regex: Option<Regex>) -> Self {
-        Self {
-            symbols,
-            regex,
-            pending_symbols: Vec::new(),
-            current_index: 0,
-            current_depth: 0,
-        }
-    }
-}
-
-impl Iterator for CodeSymbolIterator<'_> {
-    type Item = Entry;
-
-    fn next(&mut self) -> Option<Self::Item> {
-        if let Some((symbol, depth)) = self.pending_symbols.pop() {
-            for child in symbol.children.iter().rev() {
-                self.pending_symbols.push((child, depth + 1));
-            }
-
-            return Some(Entry {
-                name: symbol.name.clone(),
-                kind: symbol.kind,
-                depth,
-                start_line: symbol.range.start.0.row as usize,
-                end_line: symbol.range.end.0.row as usize,
-            });
-        }
-
-        while self.current_index < self.symbols.len() {
-            let regex = self.regex.as_ref();
-            let symbol = &self.symbols[self.current_index];
-            self.current_index += 1;
-
-            if regex.is_none_or(|regex| regex.is_match(&symbol.name)) {
-                // Push in reverse order to maintain traversal order
-                for child in symbol.children.iter().rev() {
-                    self.pending_symbols.push((child, self.current_depth + 1));
-                }
-
-                return Some(Entry {
-                    name: symbol.name.clone(),
-                    kind: symbol.kind,
-                    depth: self.current_depth,
-                    start_line: symbol.range.start.0.row as usize,
-                    end_line: symbol.range.end.0.row as usize,
-                });
-            } else {
-                // Even if parent doesn't match, push children to check them later
-                for child in symbol.children.iter().rev() {
-                    self.pending_symbols.push((child, self.current_depth + 1));
-                }
-
-                // Check if any pending children match our criteria
-                if let Some(result) = self.next() {
-                    return Some(result);
-                }
-            }
-        }
-
-        None
-    }
-}

crates/assistant_tools/src/code_symbols_tool.rs 🔗

@@ -1,24 +1,21 @@
-use std::fmt::{self, Write};
+use std::fmt::Write;
 use std::path::PathBuf;
 use std::sync::Arc;
 
+use crate::schema::json_schema_for;
 use anyhow::{Result, anyhow};
 use assistant_tool::{ActionLog, Tool};
 use collections::IndexMap;
 use gpui::{App, AsyncApp, Entity, Task};
-use language::{CodeLabel, Language, LanguageRegistry};
+use language::{OutlineItem, ParseStatus, Point};
 use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
-use lsp::SymbolKind;
-use project::{DocumentSymbol, Project, Symbol};
+use project::{Project, Symbol};
 use regex::{Regex, RegexBuilder};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use ui::IconName;
 use util::markdown::MarkdownString;
 
-use crate::code_symbol_iter::{CodeSymbolIterator, Entry};
-use crate::schema::json_schema_for;
-
 #[derive(Debug, Serialize, Deserialize, JsonSchema)]
 pub struct CodeSymbolsInput {
     /// The relative path of the source code file to read and get the symbols for.
@@ -180,24 +177,28 @@ pub async fn file_outline(
         action_log.buffer_read(buffer.clone(), cx);
     })?;
 
-    let symbols = project
-        .update(cx, |project, cx| project.document_symbols(&buffer, cx))?
-        .await?;
-
-    if symbols.is_empty() {
-        return Err(
-            if buffer.read_with(cx, |buffer, _| buffer.snapshot().is_empty())? {
-                anyhow!("This file is empty.")
-            } else {
-                anyhow!("No outline information available for this file.")
-            },
-        );
-    }
-
-    let language = buffer.read_with(cx, |buffer, _| buffer.language().cloned())?;
-    let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
+    // Wait until the buffer has been fully parsed, so that we can read its outline.
+    let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
+    while parse_status
+        .recv()
+        .await
+        .map_or(false, |status| status != ParseStatus::Idle)
+    {}
+
+    let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
+    let Some(outline) = snapshot.outline(None) else {
+        return Err(anyhow!("No outline information available for this file."));
+    };
 
-    render_outline(&symbols, language, language_registry, regex, offset).await
+    render_outline(
+        outline
+            .items
+            .into_iter()
+            .map(|item| item.to_point(&snapshot)),
+        regex,
+        offset,
+    )
+    .await
 }
 
 async fn project_symbols(
@@ -292,61 +293,27 @@ async fn project_symbols(
 }
 
 async fn render_outline(
-    symbols: &[DocumentSymbol],
-    language: Option<Arc<Language>>,
-    registry: Arc<LanguageRegistry>,
+    items: impl IntoIterator<Item = OutlineItem<Point>>,
     regex: Option<Regex>,
     offset: u32,
 ) -> Result<String> {
     const RESULTS_PER_PAGE_USIZE: usize = RESULTS_PER_PAGE as usize;
-    let entries = CodeSymbolIterator::new(symbols, regex.clone())
-        .skip(offset as usize)
-        // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
-        .take(RESULTS_PER_PAGE_USIZE.saturating_add(1))
-        .collect::<Vec<Entry>>();
-    let has_more = entries.len() > RESULTS_PER_PAGE_USIZE;
-
-    // Get language-specific labels, if available
-    let labels = match &language {
-        Some(lang) => {
-            let entries_for_labels: Vec<(String, SymbolKind)> = entries
-                .iter()
-                .take(RESULTS_PER_PAGE_USIZE)
-                .map(|entry| (entry.name.clone(), entry.kind))
-                .collect();
-
-            let lang_name = lang.name();
-            if let Some(lsp_adapter) = registry.lsp_adapters(&lang_name).first().cloned() {
-                lsp_adapter
-                    .labels_for_symbols(&entries_for_labels, lang)
-                    .await
-                    .ok()
-            } else {
-                None
-            }
-        }
-        None => None,
-    };
 
-    let mut output = String::new();
+    let mut items = items.into_iter().skip(offset as usize);
 
-    let entries_rendered = match &labels {
-        Some(label_list) => render_entries(
-            &mut output,
-            entries
-                .into_iter()
-                .take(RESULTS_PER_PAGE_USIZE)
-                .zip(label_list.iter())
-                .map(|(entry, label)| (entry, label.as_ref())),
-        ),
-        None => render_entries(
-            &mut output,
-            entries
-                .into_iter()
-                .take(RESULTS_PER_PAGE_USIZE)
-                .map(|entry| (entry, None)),
-        ),
-    };
+    let entries = items
+        .by_ref()
+        .filter(|item| {
+            regex
+                .as_ref()
+                .is_none_or(|regex| regex.is_match(&item.text))
+        })
+        .take(RESULTS_PER_PAGE_USIZE)
+        .collect::<Vec<_>>();
+    let has_more = items.next().is_some();
+
+    let mut output = String::new();
+    let entries_rendered = render_entries(&mut output, entries);
 
     // Calculate pagination information
     let page_start = offset + 1;
@@ -372,31 +339,19 @@ async fn render_outline(
     Ok(output)
 }
 
-fn render_entries<'a>(
-    output: &mut String,
-    entries: impl IntoIterator<Item = (Entry, Option<&'a CodeLabel>)>,
-) -> u32 {
+fn render_entries(output: &mut String, items: impl IntoIterator<Item = OutlineItem<Point>>) -> u32 {
     let mut entries_rendered = 0;
 
-    for (entry, label) in entries {
+    for item in items {
         // Indent based on depth ("" for level 0, "  " for level 1, etc.)
-        for _ in 0..entry.depth {
-            output.push_str("  ");
-        }
-
-        match label {
-            Some(label) => {
-                output.push_str(label.text());
-            }
-            None => {
-                write_symbol_kind(output, entry.kind).ok();
-                output.push_str(&entry.name);
-            }
+        for _ in 0..item.depth {
+            output.push(' ');
         }
+        output.push_str(&item.text);
 
         // Add position information - convert to 1-based line numbers for display
-        let start_line = entry.start_line + 1;
-        let end_line = entry.end_line + 1;
+        let start_line = item.range.start.row + 1;
+        let end_line = item.range.end.row + 1;
 
         if start_line == end_line {
             writeln!(output, " [L{}]", start_line).ok();
@@ -408,38 +363,3 @@ fn render_entries<'a>(
 
     entries_rendered
 }
-
-// We may not have a language server adapter to have language-specific
-// ways to translate SymbolKnd into a string. In that situation,
-// fall back on some reasonable default strings to render.
-fn write_symbol_kind(buf: &mut String, kind: SymbolKind) -> Result<(), fmt::Error> {
-    match kind {
-        SymbolKind::FILE => write!(buf, "file "),
-        SymbolKind::MODULE => write!(buf, "module "),
-        SymbolKind::NAMESPACE => write!(buf, "namespace "),
-        SymbolKind::PACKAGE => write!(buf, "package "),
-        SymbolKind::CLASS => write!(buf, "class "),
-        SymbolKind::METHOD => write!(buf, "method "),
-        SymbolKind::PROPERTY => write!(buf, "property "),
-        SymbolKind::FIELD => write!(buf, "field "),
-        SymbolKind::CONSTRUCTOR => write!(buf, "constructor "),
-        SymbolKind::ENUM => write!(buf, "enum "),
-        SymbolKind::INTERFACE => write!(buf, "interface "),
-        SymbolKind::FUNCTION => write!(buf, "function "),
-        SymbolKind::VARIABLE => write!(buf, "variable "),
-        SymbolKind::CONSTANT => write!(buf, "constant "),
-        SymbolKind::STRING => write!(buf, "string "),
-        SymbolKind::NUMBER => write!(buf, "number "),
-        SymbolKind::BOOLEAN => write!(buf, "boolean "),
-        SymbolKind::ARRAY => write!(buf, "array "),
-        SymbolKind::OBJECT => write!(buf, "object "),
-        SymbolKind::KEY => write!(buf, "key "),
-        SymbolKind::NULL => write!(buf, "null "),
-        SymbolKind::ENUM_MEMBER => write!(buf, "enum member "),
-        SymbolKind::STRUCT => write!(buf, "struct "),
-        SymbolKind::EVENT => write!(buf, "event "),
-        SymbolKind::OPERATOR => write!(buf, "operator "),
-        SymbolKind::TYPE_PARAMETER => write!(buf, "type parameter "),
-        _ => Ok(()),
-    }
-}

crates/assistant_tools/src/read_file_tool.rs 🔗

@@ -1,7 +1,6 @@
 use std::sync::Arc;
 
-use crate::code_symbols_tool::file_outline;
-use crate::schema::json_schema_for;
+use crate::{code_symbols_tool::file_outline, schema::json_schema_for};
 use anyhow::{Result, anyhow};
 use assistant_tool::{ActionLog, Tool};
 use gpui::{App, Entity, Task};
@@ -16,7 +15,7 @@ use util::markdown::MarkdownString;
 /// If the model requests to read a file whose size exceeds this, then
 /// the tool will return an error along with the model's symbol outline,
 /// and suggest trying again using line ranges from the outline.
-const MAX_FILE_SIZE_TO_READ: usize = 4096;
+const MAX_FILE_SIZE_TO_READ: usize = 16384;
 
 #[derive(Debug, Serialize, Deserialize, JsonSchema)]
 pub struct ReadFileToolInput {