1use std::fmt::{self, Write};
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use assistant_tool::{ActionLog, Tool};
7use collections::IndexMap;
8use gpui::{App, AsyncApp, Entity, Task};
9use language::{CodeLabel, Language, LanguageRegistry};
10use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
11use lsp::SymbolKind;
12use project::{DocumentSymbol, Project, Symbol};
13use regex::{Regex, RegexBuilder};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use ui::IconName;
17use util::markdown::MarkdownString;
18
19use crate::code_symbol_iter::{CodeSymbolIterator, Entry};
20use crate::schema::json_schema_for;
21
// Arguments accepted by the code-symbols tool, deserialized from a JSON value.
// Every field is marked `#[serde(default)]`, so each one may be omitted by the
// caller: `path` and `regex` then become `None`, `case_sensitive` becomes
// `false`, and `offset` becomes `0`.
//
// NOTE(review): the `///` field docs below are presumably emitted as field
// descriptions by the `JsonSchema` derive, which would make them part of the
// schema handed to the model — confirm before rewording them.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeSymbolsInput {
    /// The relative path of the source code file to read and get the symbols for.
    /// This tool should only be used on source code files, never on any other type of file.
    ///
    /// This path should never be absolute, and the first component
    /// of the path should always be a root directory in a project.
    ///
    /// If no path is specified, this tool returns a flat list of all symbols in the project
    /// instead of a hierarchical outline of a specific file.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1
    /// - directory2
    ///
    /// If you want to access `file.md` in `directory1`, you should use the path `directory1/file.md`.
    /// If you want to access `file.md` in `directory2`, you should use the path `directory2/file.md`.
    /// </example>
    #[serde(default)]
    pub path: Option<String>,

    /// Optional regex pattern to filter symbols by name.
    /// When provided, only symbols whose names match this pattern will be included in the results.
    ///
    /// <example>
    /// To find only symbols that contain the word "test", use the regex pattern "test".
    /// To find methods that start with "get_", use the regex pattern "^get_".
    /// </example>
    #[serde(default)]
    pub regex: Option<String>,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    ///
    /// <example>
    /// Set to `true` to make regex matching case-sensitive.
    /// </example>
    #[serde(default)]
    pub case_sensitive: bool,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    #[serde(default)]
    pub offset: u32,
}
68
69impl CodeSymbolsInput {
70 /// Which page of search results this is.
71 pub fn page(&self) -> u32 {
72 1 + (self.offset / RESULTS_PER_PAGE)
73 }
74}
75
/// Maximum number of symbols emitted in a single page of tool output.
/// Also the divisor used to turn an `offset` into a page number.
const RESULTS_PER_PAGE: u32 = 2000;
77
/// Stateless tool that lists code symbols: a per-file outline when the input
/// names a path, or a project-wide symbol list when it does not.
pub struct CodeSymbolsTool;
79
80impl Tool for CodeSymbolsTool {
81 fn name(&self) -> String {
82 "code-symbols".into()
83 }
84
85 fn needs_confirmation(&self) -> bool {
86 false
87 }
88
89 fn description(&self) -> String {
90 include_str!("./code_symbols_tool/description.md").into()
91 }
92
93 fn icon(&self) -> IconName {
94 IconName::Eye
95 }
96
97 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
98 json_schema_for::<CodeSymbolsInput>(format)
99 }
100
101 fn ui_text(&self, input: &serde_json::Value) -> String {
102 match serde_json::from_value::<CodeSymbolsInput>(input.clone()) {
103 Ok(input) => {
104 let page = input.page();
105
106 match &input.path {
107 Some(path) => {
108 let path = MarkdownString::inline_code(path);
109 if page > 1 {
110 format!("List page {page} of code symbols for {path}")
111 } else {
112 format!("List code symbols for {path}")
113 }
114 }
115 None => {
116 if page > 1 {
117 format!("List page {page} of project symbols")
118 } else {
119 "List all project symbols".to_string()
120 }
121 }
122 }
123 }
124 Err(_) => "List code symbols".to_string(),
125 }
126 }
127
128 fn run(
129 self: Arc<Self>,
130 input: serde_json::Value,
131 _messages: &[LanguageModelRequestMessage],
132 project: Entity<Project>,
133 action_log: Entity<ActionLog>,
134 cx: &mut App,
135 ) -> Task<Result<String>> {
136 let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
137 Ok(input) => input,
138 Err(err) => return Task::ready(Err(anyhow!(err))),
139 };
140
141 let regex = match input.regex {
142 Some(regex_str) => match RegexBuilder::new(®ex_str)
143 .case_insensitive(!input.case_sensitive)
144 .build()
145 {
146 Ok(regex) => Some(regex),
147 Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))),
148 },
149 None => None,
150 };
151
152 cx.spawn(async move |cx| match input.path {
153 Some(path) => file_outline(project, path, action_log, regex, input.offset, cx).await,
154 None => project_symbols(project, regex, input.offset, cx).await,
155 })
156 }
157}
158
159async fn file_outline(
160 project: Entity<Project>,
161 path: String,
162 action_log: Entity<ActionLog>,
163 regex: Option<Regex>,
164 offset: u32,
165 cx: &mut AsyncApp,
166) -> anyhow::Result<String> {
167 let buffer = {
168 let project_path = project.read_with(cx, |project, cx| {
169 project
170 .find_project_path(&path, cx)
171 .ok_or_else(|| anyhow!("Path {path} not found in project"))
172 })??;
173
174 project
175 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
176 .await?
177 };
178
179 action_log.update(cx, |action_log, cx| {
180 action_log.buffer_read(buffer.clone(), cx);
181 })?;
182
183 let symbols = project
184 .update(cx, |project, cx| project.document_symbols(&buffer, cx))?
185 .await?;
186
187 if symbols.is_empty() {
188 return Err(
189 if buffer.read_with(cx, |buffer, _| buffer.snapshot().is_empty())? {
190 anyhow!("This file is empty.")
191 } else {
192 anyhow!("No outline information available for this file.")
193 },
194 );
195 }
196
197 let language = buffer.read_with(cx, |buffer, _| buffer.language().cloned())?;
198 let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
199
200 render_outline(&symbols, language, language_registry, regex, offset).await
201}
202
203async fn project_symbols(
204 project: Entity<Project>,
205 regex: Option<Regex>,
206 offset: u32,
207 cx: &mut AsyncApp,
208) -> anyhow::Result<String> {
209 let symbols = project
210 .update(cx, |project, cx| project.symbols("", cx))?
211 .await?;
212
213 if symbols.is_empty() {
214 return Err(anyhow!("No symbols found in project."));
215 }
216
217 let mut symbols_by_path: IndexMap<PathBuf, Vec<&Symbol>> = IndexMap::default();
218
219 for symbol in symbols
220 .iter()
221 .filter(|symbol| {
222 if let Some(regex) = ®ex {
223 regex.is_match(&symbol.name)
224 } else {
225 true
226 }
227 })
228 .skip(offset as usize)
229 // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
230 .take((RESULTS_PER_PAGE as usize).saturating_add(1))
231 {
232 if let Some(worktree_path) = project.read_with(cx, |project, cx| {
233 project
234 .worktree_for_id(symbol.path.worktree_id, cx)
235 .map(|worktree| PathBuf::from(worktree.read(cx).root_name()))
236 })? {
237 let path = worktree_path.join(&symbol.path.path);
238 symbols_by_path.entry(path).or_default().push(symbol);
239 }
240 }
241
242 // If no symbols matched the filter, return early
243 if symbols_by_path.is_empty() {
244 return Err(anyhow!("No symbols found matching the criteria."));
245 }
246
247 let mut symbols_rendered = 0;
248 let mut has_more_symbols = false;
249 let mut output = String::new();
250
251 'outer: for (file_path, file_symbols) in symbols_by_path {
252 if symbols_rendered > 0 {
253 output.push('\n');
254 }
255
256 writeln!(&mut output, "{}", file_path.display()).ok();
257
258 for symbol in file_symbols {
259 if symbols_rendered >= RESULTS_PER_PAGE {
260 has_more_symbols = true;
261 break 'outer;
262 }
263
264 write!(&mut output, " {} ", symbol.label.text()).ok();
265
266 // Convert to 1-based line numbers for display
267 let start_line = symbol.range.start.0.row as usize + 1;
268 let end_line = symbol.range.end.0.row as usize + 1;
269
270 if start_line == end_line {
271 writeln!(&mut output, "[L{}]", start_line).ok();
272 } else {
273 writeln!(&mut output, "[L{}-{}]", start_line, end_line).ok();
274 }
275
276 symbols_rendered += 1;
277 }
278 }
279
280 Ok(if symbols_rendered == 0 {
281 "No symbols found in the requested page.".to_string()
282 } else if has_more_symbols {
283 format!(
284 "{output}\nShowing symbols {}-{} (more symbols were found; use offset: {} to see next page)",
285 offset + 1,
286 offset + symbols_rendered,
287 offset + RESULTS_PER_PAGE,
288 )
289 } else {
290 output
291 })
292}
293
294async fn render_outline(
295 symbols: &[DocumentSymbol],
296 language: Option<Arc<Language>>,
297 registry: Arc<LanguageRegistry>,
298 regex: Option<Regex>,
299 offset: u32,
300) -> Result<String> {
301 const RESULTS_PER_PAGE_USIZE: usize = RESULTS_PER_PAGE as usize;
302 let entries = CodeSymbolIterator::new(symbols, regex.clone())
303 .skip(offset as usize)
304 // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
305 .take(RESULTS_PER_PAGE_USIZE.saturating_add(1))
306 .collect::<Vec<Entry>>();
307 let has_more = entries.len() > RESULTS_PER_PAGE_USIZE;
308
309 // Get language-specific labels, if available
310 let labels = match &language {
311 Some(lang) => {
312 let entries_for_labels: Vec<(String, SymbolKind)> = entries
313 .iter()
314 .take(RESULTS_PER_PAGE_USIZE)
315 .map(|entry| (entry.name.clone(), entry.kind))
316 .collect();
317
318 let lang_name = lang.name();
319 if let Some(lsp_adapter) = registry.lsp_adapters(&lang_name).first().cloned() {
320 lsp_adapter
321 .labels_for_symbols(&entries_for_labels, lang)
322 .await
323 .ok()
324 } else {
325 None
326 }
327 }
328 None => None,
329 };
330
331 let mut output = String::new();
332
333 let entries_rendered = match &labels {
334 Some(label_list) => render_entries(
335 &mut output,
336 entries
337 .into_iter()
338 .take(RESULTS_PER_PAGE_USIZE)
339 .zip(label_list.iter())
340 .map(|(entry, label)| (entry, label.as_ref())),
341 ),
342 None => render_entries(
343 &mut output,
344 entries
345 .into_iter()
346 .take(RESULTS_PER_PAGE_USIZE)
347 .map(|entry| (entry, None)),
348 ),
349 };
350
351 // Calculate pagination information
352 let page_start = offset + 1;
353 let page_end = offset + entries_rendered;
354 let total_symbols = if has_more {
355 format!("more than {}", page_end)
356 } else {
357 page_end.to_string()
358 };
359
360 // Add pagination information
361 if has_more {
362 writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
363 )
364 } else {
365 writeln!(
366 &mut output,
367 "\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
368 )
369 }
370 .ok();
371
372 Ok(output)
373}
374
375fn render_entries<'a>(
376 output: &mut String,
377 entries: impl IntoIterator<Item = (Entry, Option<&'a CodeLabel>)>,
378) -> u32 {
379 let mut entries_rendered = 0;
380
381 for (entry, label) in entries {
382 // Indent based on depth ("" for level 0, " " for level 1, etc.)
383 for _ in 0..entry.depth {
384 output.push_str(" ");
385 }
386
387 match label {
388 Some(label) => {
389 output.push_str(label.text());
390 }
391 None => {
392 write_symbol_kind(output, entry.kind).ok();
393 output.push_str(&entry.name);
394 }
395 }
396
397 // Add position information - convert to 1-based line numbers for display
398 let start_line = entry.start_line + 1;
399 let end_line = entry.end_line + 1;
400
401 if start_line == end_line {
402 writeln!(output, " [L{}]", start_line).ok();
403 } else {
404 writeln!(output, " [L{}-{}]", start_line, end_line).ok();
405 }
406 entries_rendered += 1;
407 }
408
409 entries_rendered
410}
411
412// We may not have a language server adapter to have language-specific
413// ways to translate SymbolKnd into a string. In that situation,
414// fall back on some reasonable default strings to render.
415fn write_symbol_kind(buf: &mut String, kind: SymbolKind) -> Result<(), fmt::Error> {
416 match kind {
417 SymbolKind::FILE => write!(buf, "file "),
418 SymbolKind::MODULE => write!(buf, "module "),
419 SymbolKind::NAMESPACE => write!(buf, "namespace "),
420 SymbolKind::PACKAGE => write!(buf, "package "),
421 SymbolKind::CLASS => write!(buf, "class "),
422 SymbolKind::METHOD => write!(buf, "method "),
423 SymbolKind::PROPERTY => write!(buf, "property "),
424 SymbolKind::FIELD => write!(buf, "field "),
425 SymbolKind::CONSTRUCTOR => write!(buf, "constructor "),
426 SymbolKind::ENUM => write!(buf, "enum "),
427 SymbolKind::INTERFACE => write!(buf, "interface "),
428 SymbolKind::FUNCTION => write!(buf, "function "),
429 SymbolKind::VARIABLE => write!(buf, "variable "),
430 SymbolKind::CONSTANT => write!(buf, "constant "),
431 SymbolKind::STRING => write!(buf, "string "),
432 SymbolKind::NUMBER => write!(buf, "number "),
433 SymbolKind::BOOLEAN => write!(buf, "boolean "),
434 SymbolKind::ARRAY => write!(buf, "array "),
435 SymbolKind::OBJECT => write!(buf, "object "),
436 SymbolKind::KEY => write!(buf, "key "),
437 SymbolKind::NULL => write!(buf, "null "),
438 SymbolKind::ENUM_MEMBER => write!(buf, "enum member "),
439 SymbolKind::STRUCT => write!(buf, "struct "),
440 SymbolKind::EVENT => write!(buf, "event "),
441 SymbolKind::OPERATOR => write!(buf, "operator "),
442 SymbolKind::TYPE_PARAMETER => write!(buf, "type parameter "),
443 _ => Ok(()),
444 }
445}