1use std::fmt::Write;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use crate::schema::json_schema_for;
6use anyhow::{Result, anyhow};
7use assistant_tool::{ActionLog, Tool, ToolResult};
8use collections::IndexMap;
9use gpui::{App, AsyncApp, Entity, Task};
10use language::{OutlineItem, ParseStatus, Point};
11use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
12use project::{Project, Symbol};
13use regex::{Regex, RegexBuilder};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use ui::IconName;
17use util::markdown::MarkdownString;
18
// Input arguments for the `code_symbols` tool. `CodeSymbolsTool` deserializes the raw JSON
// tool input into this struct, and `json_schema_for::<CodeSymbolsInput>` derives the tool's
// input schema from it.
//
// NOTE(review): the `///` comments on the fields presumably become the field descriptions in
// that derived schema (schemars reads doc comments), which would make them model-facing text
// rather than ordinary documentation — confirm before rewording them. Maintainer-only notes
// in this struct use plain `//` comments so they stay out of the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeSymbolsInput {
    // `None` (the serde default) selects the project-wide symbol listing instead of a
    // single-file outline.
    /// The relative path of the source code file to read and get the symbols for.
    /// This tool should only be used on source code files, never on any other type of file.
    ///
    /// This path should never be absolute, and the first component
    /// of the path should always be a root directory in a project.
    ///
    /// If no path is specified, this tool returns a flat list of all symbols in the project
    /// instead of a hierarchical outline of a specific file.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1
    /// - directory2
    ///
    /// If you want to access `file.md` in `directory1`, you should use the path `directory1/file.md`.
    /// If you want to access `file.md` in `directory2`, you should use the path `directory2/file.md`.
    /// </example>
    #[serde(default)]
    pub path: Option<String>,

    // Compiled with `RegexBuilder` when the tool runs; an invalid pattern fails the tool call.
    /// Optional regex pattern to filter symbols by name.
    /// When provided, only symbols whose names match this pattern will be included in the results.
    ///
    /// <example>
    /// To find only symbols that contain the word "test", use the regex pattern "test".
    /// To find methods that start with "get_", use the regex pattern "^get_".
    /// </example>
    #[serde(default)]
    pub regex: Option<String>,

    // The serde default is `false`; the flag is negated when it is handed to
    // `RegexBuilder::case_insensitive`.
    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    ///
    /// <example>
    /// Set to `true` to make regex matching case-sensitive.
    /// </example>
    #[serde(default)]
    pub case_sensitive: bool,

    // Only forwarded to the project-wide listing; the single-file outline is always rendered
    // from the first entry.
    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    #[serde(default)]
    pub offset: u32,
}
65
66impl CodeSymbolsInput {
67 /// Which page of search results this is.
68 pub fn page(&self) -> u32 {
69 1 + (self.offset / RESULTS_PER_PAGE)
70 }
71}
72
/// Page size for the project-wide symbol listing: the maximum number of symbols rendered
/// per call, and the divisor `CodeSymbolsInput::page` uses to turn an offset into a page number.
const RESULTS_PER_PAGE: u32 = 2000;
74
/// Stateless tool, registered under the name `code_symbols`, that lists the symbols of a
/// single file (as an outline) or of the whole project (as a flat, paginated list).
pub struct CodeSymbolsTool;
76
77impl Tool for CodeSymbolsTool {
78 fn name(&self) -> String {
79 "code_symbols".into()
80 }
81
82 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
83 false
84 }
85
86 fn description(&self) -> String {
87 include_str!("./code_symbols_tool/description.md").into()
88 }
89
90 fn icon(&self) -> IconName {
91 IconName::Code
92 }
93
94 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
95 json_schema_for::<CodeSymbolsInput>(format)
96 }
97
98 fn ui_text(&self, input: &serde_json::Value) -> String {
99 match serde_json::from_value::<CodeSymbolsInput>(input.clone()) {
100 Ok(input) => {
101 let page = input.page();
102
103 match &input.path {
104 Some(path) => {
105 let path = MarkdownString::inline_code(path);
106 if page > 1 {
107 format!("List page {page} of code symbols for {path}")
108 } else {
109 format!("List code symbols for {path}")
110 }
111 }
112 None => {
113 if page > 1 {
114 format!("List page {page} of project symbols")
115 } else {
116 "List all project symbols".to_string()
117 }
118 }
119 }
120 }
121 Err(_) => "List code symbols".to_string(),
122 }
123 }
124
125 fn run(
126 self: Arc<Self>,
127 input: serde_json::Value,
128 _messages: &[LanguageModelRequestMessage],
129 project: Entity<Project>,
130 action_log: Entity<ActionLog>,
131 cx: &mut App,
132 ) -> ToolResult {
133 let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
134 Ok(input) => input,
135 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
136 };
137
138 let regex = match input.regex {
139 Some(regex_str) => match RegexBuilder::new(®ex_str)
140 .case_insensitive(!input.case_sensitive)
141 .build()
142 {
143 Ok(regex) => Some(regex),
144 Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))).into(),
145 },
146 None => None,
147 };
148
149 cx.spawn(async move |cx| match input.path {
150 Some(path) => file_outline(project, path, action_log, regex, cx).await,
151 None => project_symbols(project, regex, input.offset, cx).await,
152 })
153 .into()
154 }
155}
156
157pub async fn file_outline(
158 project: Entity<Project>,
159 path: String,
160 action_log: Entity<ActionLog>,
161 regex: Option<Regex>,
162 cx: &mut AsyncApp,
163) -> anyhow::Result<String> {
164 let buffer = {
165 let project_path = project.read_with(cx, |project, cx| {
166 project
167 .find_project_path(&path, cx)
168 .ok_or_else(|| anyhow!("Path {path} not found in project"))
169 })??;
170
171 project
172 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
173 .await?
174 };
175
176 action_log.update(cx, |action_log, cx| {
177 action_log.track_buffer(buffer.clone(), cx);
178 })?;
179
180 // Wait until the buffer has been fully parsed, so that we can read its outline.
181 let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
182 while *parse_status.borrow() != ParseStatus::Idle {
183 parse_status.changed().await?;
184 }
185
186 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
187 let Some(outline) = snapshot.outline(None) else {
188 return Err(anyhow!("No outline information available for this file."));
189 };
190
191 render_outline(
192 outline
193 .items
194 .into_iter()
195 .map(|item| item.to_point(&snapshot)),
196 regex,
197 0,
198 usize::MAX,
199 )
200 .await
201}
202
203async fn project_symbols(
204 project: Entity<Project>,
205 regex: Option<Regex>,
206 offset: u32,
207 cx: &mut AsyncApp,
208) -> anyhow::Result<String> {
209 let symbols = project
210 .update(cx, |project, cx| project.symbols("", cx))?
211 .await?;
212
213 if symbols.is_empty() {
214 return Err(anyhow!("No symbols found in project."));
215 }
216
217 let mut symbols_by_path: IndexMap<PathBuf, Vec<&Symbol>> = IndexMap::default();
218
219 for symbol in symbols
220 .iter()
221 .filter(|symbol| {
222 if let Some(regex) = ®ex {
223 regex.is_match(&symbol.name)
224 } else {
225 true
226 }
227 })
228 .skip(offset as usize)
229 // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
230 .take((RESULTS_PER_PAGE as usize).saturating_add(1))
231 {
232 if let Some(worktree_path) = project.read_with(cx, |project, cx| {
233 project
234 .worktree_for_id(symbol.path.worktree_id, cx)
235 .map(|worktree| PathBuf::from(worktree.read(cx).root_name()))
236 })? {
237 let path = worktree_path.join(&symbol.path.path);
238 symbols_by_path.entry(path).or_default().push(symbol);
239 }
240 }
241
242 // If no symbols matched the filter, return early
243 if symbols_by_path.is_empty() {
244 return Err(anyhow!("No symbols found matching the criteria."));
245 }
246
247 let mut symbols_rendered = 0;
248 let mut has_more_symbols = false;
249 let mut output = String::new();
250
251 'outer: for (file_path, file_symbols) in symbols_by_path {
252 if symbols_rendered > 0 {
253 output.push('\n');
254 }
255
256 writeln!(&mut output, "{}", file_path.display()).ok();
257
258 for symbol in file_symbols {
259 if symbols_rendered >= RESULTS_PER_PAGE {
260 has_more_symbols = true;
261 break 'outer;
262 }
263
264 write!(&mut output, " {} ", symbol.label.text()).ok();
265
266 // Convert to 1-based line numbers for display
267 let start_line = symbol.range.start.0.row as usize + 1;
268 let end_line = symbol.range.end.0.row as usize + 1;
269
270 if start_line == end_line {
271 writeln!(&mut output, "[L{}]", start_line).ok();
272 } else {
273 writeln!(&mut output, "[L{}-{}]", start_line, end_line).ok();
274 }
275
276 symbols_rendered += 1;
277 }
278 }
279
280 Ok(if symbols_rendered == 0 {
281 "No symbols found in the requested page.".to_string()
282 } else if has_more_symbols {
283 format!(
284 "{output}\nShowing symbols {}-{} (more symbols were found; use offset: {} to see next page)",
285 offset + 1,
286 offset + symbols_rendered,
287 offset + RESULTS_PER_PAGE,
288 )
289 } else {
290 output
291 })
292}
293
294async fn render_outline(
295 items: impl IntoIterator<Item = OutlineItem<Point>>,
296 regex: Option<Regex>,
297 offset: usize,
298 results_per_page: usize,
299) -> Result<String> {
300 let mut items = items.into_iter().skip(offset);
301
302 let entries = items
303 .by_ref()
304 .filter(|item| {
305 regex
306 .as_ref()
307 .is_none_or(|regex| regex.is_match(&item.text))
308 })
309 .take(results_per_page)
310 .collect::<Vec<_>>();
311 let has_more = items.next().is_some();
312
313 let mut output = String::new();
314 let entries_rendered = render_entries(&mut output, entries);
315
316 // Calculate pagination information
317 let page_start = offset + 1;
318 let page_end = offset + entries_rendered;
319 let total_symbols = if has_more {
320 format!("more than {}", page_end)
321 } else {
322 page_end.to_string()
323 };
324
325 // Add pagination information
326 if has_more {
327 writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
328 )
329 } else {
330 writeln!(
331 &mut output,
332 "\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
333 )
334 }
335 .ok();
336
337 Ok(output)
338}
339
340fn render_entries(
341 output: &mut String,
342 items: impl IntoIterator<Item = OutlineItem<Point>>,
343) -> usize {
344 let mut entries_rendered = 0;
345
346 for item in items {
347 // Indent based on depth ("" for level 0, " " for level 1, etc.)
348 for _ in 0..item.depth {
349 output.push(' ');
350 }
351 output.push_str(&item.text);
352
353 // Add position information - convert to 1-based line numbers for display
354 let start_line = item.range.start.row + 1;
355 let end_line = item.range.end.row + 1;
356
357 if start_line == end_line {
358 writeln!(output, " [L{}]", start_line).ok();
359 } else {
360 writeln!(output, " [L{}-{}]", start_line, end_line).ok();
361 }
362 entries_rendered += 1;
363 }
364
365 entries_rendered
366}