1use std::fmt::Write;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use crate::schema::json_schema_for;
6use anyhow::{Result, anyhow};
7use assistant_tool::{ActionLog, Tool, ToolResult};
8use collections::IndexMap;
9use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
10use language::{OutlineItem, ParseStatus, Point};
11use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
12use project::{Project, Symbol};
13use regex::{Regex, RegexBuilder};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use ui::IconName;
17use util::markdown::MarkdownInlineCode;
18
// Arguments accepted by the `code_symbols` tool; deserialized from the model's JSON input
// in `run` and `ui_text`.
//
// NOTE: the `///` comments on the fields below are not only documentation. This type derives
// `JsonSchema` and `input_schema` passes it to `json_schema_for`, so schemars uses these
// comments as the field descriptions shown to the model. Treat them like user-facing strings.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeSymbolsInput {
    /// The relative path of the source code file to read and get the symbols for.
    /// This tool should only be used on source code files, never on any other type of file.
    ///
    /// This path should never be absolute, and the first component
    /// of the path should always be a root directory in a project.
    ///
    /// If no path is specified, this tool returns a flat list of all symbols in the project
    /// instead of a hierarchical outline of a specific file.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1
    /// - directory2
    ///
    /// If you want to access `file.md` in `directory1`, you should use the path `directory1/file.md`.
    /// If you want to access `file.md` in `directory2`, you should use the path `directory2/file.md`.
    /// </example>
    // `Some` selects the per-file outline branch in `run`; `None` selects the project-wide list.
    #[serde(default)]
    pub path: Option<String>,

    /// Optional regex pattern to filter symbols by name.
    /// When provided, only symbols whose names match this pattern will be included in the results.
    ///
    /// <example>
    /// To find only symbols that contain the word "test", use the regex pattern "test".
    /// To find methods that start with "get_", use the regex pattern "^get_".
    /// </example>
    // Compiled once in `run`; an invalid pattern makes the tool fail with "Invalid regex: ...".
    #[serde(default)]
    pub regex: Option<String>,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    ///
    /// <example>
    /// Set to `true` to make regex matching case-sensitive.
    /// </example>
    // A missing field deserializes to `false` via `#[serde(default)]`.
    #[serde(default)]
    pub case_sensitive: bool,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // A missing field deserializes to 0. In `run`, only the project-wide branch
    // (`project_symbols`) receives this value; the file-outline branch does not pass it on.
    #[serde(default)]
    pub offset: u32,
}
65
66impl CodeSymbolsInput {
67 /// Which page of search results this is.
68 pub fn page(&self) -> u32 {
69 1 + (self.offset / RESULTS_PER_PAGE)
70 }
71}
72
73const RESULTS_PER_PAGE: u32 = 2000;
74
75pub struct CodeSymbolsTool;
76
77impl Tool for CodeSymbolsTool {
78 fn name(&self) -> String {
79 "code_symbols".into()
80 }
81
82 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
83 false
84 }
85
86 fn description(&self) -> String {
87 include_str!("./code_symbols_tool/description.md").into()
88 }
89
90 fn icon(&self) -> IconName {
91 IconName::Code
92 }
93
94 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
95 json_schema_for::<CodeSymbolsInput>(format)
96 }
97
98 fn ui_text(&self, input: &serde_json::Value) -> String {
99 match serde_json::from_value::<CodeSymbolsInput>(input.clone()) {
100 Ok(input) => {
101 let page = input.page();
102
103 match &input.path {
104 Some(path) => {
105 let path = MarkdownInlineCode(path);
106 if page > 1 {
107 format!("List page {page} of code symbols for {path}")
108 } else {
109 format!("List code symbols for {path}")
110 }
111 }
112 None => {
113 if page > 1 {
114 format!("List page {page} of project symbols")
115 } else {
116 "List all project symbols".to_string()
117 }
118 }
119 }
120 }
121 Err(_) => "List code symbols".to_string(),
122 }
123 }
124
125 fn run(
126 self: Arc<Self>,
127 input: serde_json::Value,
128 _messages: &[LanguageModelRequestMessage],
129 project: Entity<Project>,
130 action_log: Entity<ActionLog>,
131 _window: Option<AnyWindowHandle>,
132 cx: &mut App,
133 ) -> ToolResult {
134 let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
135 Ok(input) => input,
136 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
137 };
138
139 let regex = match input.regex {
140 Some(regex_str) => match RegexBuilder::new(®ex_str)
141 .case_insensitive(!input.case_sensitive)
142 .build()
143 {
144 Ok(regex) => Some(regex),
145 Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))).into(),
146 },
147 None => None,
148 };
149
150 cx.spawn(async move |cx| match input.path {
151 Some(path) => file_outline(project, path, action_log, regex, cx).await,
152 None => project_symbols(project, regex, input.offset, cx).await,
153 })
154 .into()
155 }
156}
157
158pub async fn file_outline(
159 project: Entity<Project>,
160 path: String,
161 action_log: Entity<ActionLog>,
162 regex: Option<Regex>,
163 cx: &mut AsyncApp,
164) -> anyhow::Result<String> {
165 let buffer = {
166 let project_path = project.read_with(cx, |project, cx| {
167 project
168 .find_project_path(&path, cx)
169 .ok_or_else(|| anyhow!("Path {path} not found in project"))
170 })??;
171
172 project
173 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
174 .await?
175 };
176
177 action_log.update(cx, |action_log, cx| {
178 action_log.track_buffer(buffer.clone(), cx);
179 })?;
180
181 // Wait until the buffer has been fully parsed, so that we can read its outline.
182 let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
183 while *parse_status.borrow() != ParseStatus::Idle {
184 parse_status.changed().await?;
185 }
186
187 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
188 let Some(outline) = snapshot.outline(None) else {
189 return Err(anyhow!("No outline information available for this file."));
190 };
191
192 render_outline(
193 outline
194 .items
195 .into_iter()
196 .map(|item| item.to_point(&snapshot)),
197 regex,
198 0,
199 usize::MAX,
200 )
201 .await
202}
203
204async fn project_symbols(
205 project: Entity<Project>,
206 regex: Option<Regex>,
207 offset: u32,
208 cx: &mut AsyncApp,
209) -> anyhow::Result<String> {
210 let symbols = project
211 .update(cx, |project, cx| project.symbols("", cx))?
212 .await?;
213
214 if symbols.is_empty() {
215 return Err(anyhow!("No symbols found in project."));
216 }
217
218 let mut symbols_by_path: IndexMap<PathBuf, Vec<&Symbol>> = IndexMap::default();
219
220 for symbol in symbols
221 .iter()
222 .filter(|symbol| {
223 if let Some(regex) = ®ex {
224 regex.is_match(&symbol.name)
225 } else {
226 true
227 }
228 })
229 .skip(offset as usize)
230 // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
231 .take((RESULTS_PER_PAGE as usize).saturating_add(1))
232 {
233 if let Some(worktree_path) = project.read_with(cx, |project, cx| {
234 project
235 .worktree_for_id(symbol.path.worktree_id, cx)
236 .map(|worktree| PathBuf::from(worktree.read(cx).root_name()))
237 })? {
238 let path = worktree_path.join(&symbol.path.path);
239 symbols_by_path.entry(path).or_default().push(symbol);
240 }
241 }
242
243 // If no symbols matched the filter, return early
244 if symbols_by_path.is_empty() {
245 return Err(anyhow!("No symbols found matching the criteria."));
246 }
247
248 let mut symbols_rendered = 0;
249 let mut has_more_symbols = false;
250 let mut output = String::new();
251
252 'outer: for (file_path, file_symbols) in symbols_by_path {
253 if symbols_rendered > 0 {
254 output.push('\n');
255 }
256
257 writeln!(&mut output, "{}", file_path.display()).ok();
258
259 for symbol in file_symbols {
260 if symbols_rendered >= RESULTS_PER_PAGE {
261 has_more_symbols = true;
262 break 'outer;
263 }
264
265 write!(&mut output, " {} ", symbol.label.text()).ok();
266
267 // Convert to 1-based line numbers for display
268 let start_line = symbol.range.start.0.row as usize + 1;
269 let end_line = symbol.range.end.0.row as usize + 1;
270
271 if start_line == end_line {
272 writeln!(&mut output, "[L{}]", start_line).ok();
273 } else {
274 writeln!(&mut output, "[L{}-{}]", start_line, end_line).ok();
275 }
276
277 symbols_rendered += 1;
278 }
279 }
280
281 Ok(if symbols_rendered == 0 {
282 "No symbols found in the requested page.".to_string()
283 } else if has_more_symbols {
284 format!(
285 "{output}\nShowing symbols {}-{} (more symbols were found; use offset: {} to see next page)",
286 offset + 1,
287 offset + symbols_rendered,
288 offset + RESULTS_PER_PAGE,
289 )
290 } else {
291 output
292 })
293}
294
295async fn render_outline(
296 items: impl IntoIterator<Item = OutlineItem<Point>>,
297 regex: Option<Regex>,
298 offset: usize,
299 results_per_page: usize,
300) -> Result<String> {
301 let mut items = items.into_iter().skip(offset);
302
303 let entries = items
304 .by_ref()
305 .filter(|item| {
306 regex
307 .as_ref()
308 .is_none_or(|regex| regex.is_match(&item.text))
309 })
310 .take(results_per_page)
311 .collect::<Vec<_>>();
312 let has_more = items.next().is_some();
313
314 let mut output = String::new();
315 let entries_rendered = render_entries(&mut output, entries);
316
317 // Calculate pagination information
318 let page_start = offset + 1;
319 let page_end = offset + entries_rendered;
320 let total_symbols = if has_more {
321 format!("more than {}", page_end)
322 } else {
323 page_end.to_string()
324 };
325
326 // Add pagination information
327 if has_more {
328 writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
329 )
330 } else {
331 writeln!(
332 &mut output,
333 "\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
334 )
335 }
336 .ok();
337
338 Ok(output)
339}
340
341fn render_entries(
342 output: &mut String,
343 items: impl IntoIterator<Item = OutlineItem<Point>>,
344) -> usize {
345 let mut entries_rendered = 0;
346
347 for item in items {
348 // Indent based on depth ("" for level 0, " " for level 1, etc.)
349 for _ in 0..item.depth {
350 output.push(' ');
351 }
352 output.push_str(&item.text);
353
354 // Add position information - convert to 1-based line numbers for display
355 let start_line = item.range.start.row + 1;
356 let end_line = item.range.end.row + 1;
357
358 if start_line == end_line {
359 writeln!(output, " [L{}]", start_line).ok();
360 } else {
361 writeln!(output, " [L{}-{}]", start_line, end_line).ok();
362 }
363 entries_rendered += 1;
364 }
365
366 entries_rendered
367}