1use std::fmt::Write;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use crate::schema::json_schema_for;
6use anyhow::{Result, anyhow};
7use assistant_tool::{ActionLog, Tool};
8use collections::IndexMap;
9use gpui::{App, AsyncApp, Entity, Task};
10use language::{OutlineItem, ParseStatus, Point};
11use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
12use project::{Project, Symbol};
13use regex::{Regex, RegexBuilder};
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use ui::IconName;
17use util::markdown::MarkdownString;
18
// Input accepted by `CodeSymbolsTool`, deserialized from the tool-call JSON in `run`.
// `input_schema` builds this tool's JSON schema from this type via `json_schema_for`, and
// schemars turns the `///` field docs below into schema descriptions — so that `///` text is
// part of the tool's runtime contract. Keep maintainer-only notes in `//` comments like this.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeSymbolsInput {
    /// The relative path of the source code file to read and get the symbols for.
    /// This tool should only be used on source code files, never on any other type of file.
    ///
    /// This path should never be absolute, and the first component
    /// of the path should always be a root directory in a project.
    ///
    /// If no path is specified, this tool returns a flat list of all symbols in the project
    /// instead of a hierarchical outline of a specific file.
    ///
    /// <example>
    /// If the project has the following root directories:
    ///
    /// - directory1
    /// - directory2
    ///
    /// If you want to access `file.md` in `directory1`, you should use the path `directory1/file.md`.
    /// If you want to access `file.md` in `directory2`, you should use the path `directory2/file.md`.
    /// </example>
    // `Some` selects the per-file outline (`file_outline`); `None` selects the project-wide
    // listing (`project_symbols`). A missing field deserializes to `None`.
    #[serde(default)]
    pub path: Option<String>,

    /// Optional regex pattern to filter symbols by name.
    /// When provided, only symbols whose names match this pattern will be included in the results.
    ///
    /// <example>
    /// To find only symbols that contain the word "test", use the regex pattern "test".
    /// To find methods that start with "get_", use the regex pattern "^get_".
    /// </example>
    // Compiled in `run` with `RegexBuilder`; an invalid pattern fails the tool call up front.
    #[serde(default)]
    pub regex: Option<String>,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    ///
    /// <example>
    /// Set to `true` to make regex matching case-sensitive.
    /// </example>
    // `#[serde(default)]` on a `bool` yields `false` when the field is omitted.
    #[serde(default)]
    pub case_sensitive: bool,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // Omitted => 0. Also used by `page()` to derive the 1-based page number for UI text.
    #[serde(default)]
    pub offset: u32,
}
65
66impl CodeSymbolsInput {
67 /// Which page of search results this is.
68 pub fn page(&self) -> u32 {
69 1 + (self.offset / RESULTS_PER_PAGE)
70 }
71}
72
/// Maximum number of symbols shown on one page of results, shared by the per-file outline
/// and the project-wide listing; `CodeSymbolsInput::page` also divides `offset` by it.
const RESULTS_PER_PAGE: u32 = 2000;
74
75pub struct CodeSymbolsTool;
76
77impl Tool for CodeSymbolsTool {
78 fn name(&self) -> String {
79 "code_symbols".into()
80 }
81
82 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
83 false
84 }
85
86 fn description(&self) -> String {
87 include_str!("./code_symbols_tool/description.md").into()
88 }
89
90 fn icon(&self) -> IconName {
91 IconName::Code
92 }
93
94 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
95 json_schema_for::<CodeSymbolsInput>(format)
96 }
97
98 fn ui_text(&self, input: &serde_json::Value) -> String {
99 match serde_json::from_value::<CodeSymbolsInput>(input.clone()) {
100 Ok(input) => {
101 let page = input.page();
102
103 match &input.path {
104 Some(path) => {
105 let path = MarkdownString::inline_code(path);
106 if page > 1 {
107 format!("List page {page} of code symbols for {path}")
108 } else {
109 format!("List code symbols for {path}")
110 }
111 }
112 None => {
113 if page > 1 {
114 format!("List page {page} of project symbols")
115 } else {
116 "List all project symbols".to_string()
117 }
118 }
119 }
120 }
121 Err(_) => "List code symbols".to_string(),
122 }
123 }
124
125 fn run(
126 self: Arc<Self>,
127 input: serde_json::Value,
128 _messages: &[LanguageModelRequestMessage],
129 project: Entity<Project>,
130 action_log: Entity<ActionLog>,
131 cx: &mut App,
132 ) -> Task<Result<String>> {
133 let input = match serde_json::from_value::<CodeSymbolsInput>(input) {
134 Ok(input) => input,
135 Err(err) => return Task::ready(Err(anyhow!(err))),
136 };
137
138 let regex = match input.regex {
139 Some(regex_str) => match RegexBuilder::new(®ex_str)
140 .case_insensitive(!input.case_sensitive)
141 .build()
142 {
143 Ok(regex) => Some(regex),
144 Err(err) => return Task::ready(Err(anyhow!("Invalid regex: {err}"))),
145 },
146 None => None,
147 };
148
149 cx.spawn(async move |cx| match input.path {
150 Some(path) => file_outline(project, path, action_log, regex, input.offset, cx).await,
151 None => project_symbols(project, regex, input.offset, cx).await,
152 })
153 }
154}
155
156pub async fn file_outline(
157 project: Entity<Project>,
158 path: String,
159 action_log: Entity<ActionLog>,
160 regex: Option<Regex>,
161 offset: u32,
162 cx: &mut AsyncApp,
163) -> anyhow::Result<String> {
164 let buffer = {
165 let project_path = project.read_with(cx, |project, cx| {
166 project
167 .find_project_path(&path, cx)
168 .ok_or_else(|| anyhow!("Path {path} not found in project"))
169 })??;
170
171 project
172 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
173 .await?
174 };
175
176 action_log.update(cx, |action_log, cx| {
177 action_log.buffer_read(buffer.clone(), cx);
178 })?;
179
180 // Wait until the buffer has been fully parsed, so that we can read its outline.
181 let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?;
182 while *parse_status.borrow() != ParseStatus::Idle {
183 parse_status.changed().await?;
184 }
185
186 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
187 let Some(outline) = snapshot.outline(None) else {
188 return Err(anyhow!("No outline information available for this file."));
189 };
190
191 render_outline(
192 outline
193 .items
194 .into_iter()
195 .map(|item| item.to_point(&snapshot)),
196 regex,
197 offset,
198 )
199 .await
200}
201
202async fn project_symbols(
203 project: Entity<Project>,
204 regex: Option<Regex>,
205 offset: u32,
206 cx: &mut AsyncApp,
207) -> anyhow::Result<String> {
208 let symbols = project
209 .update(cx, |project, cx| project.symbols("", cx))?
210 .await?;
211
212 if symbols.is_empty() {
213 return Err(anyhow!("No symbols found in project."));
214 }
215
216 let mut symbols_by_path: IndexMap<PathBuf, Vec<&Symbol>> = IndexMap::default();
217
218 for symbol in symbols
219 .iter()
220 .filter(|symbol| {
221 if let Some(regex) = ®ex {
222 regex.is_match(&symbol.name)
223 } else {
224 true
225 }
226 })
227 .skip(offset as usize)
228 // Take 1 more than RESULTS_PER_PAGE so we can tell if there are more results.
229 .take((RESULTS_PER_PAGE as usize).saturating_add(1))
230 {
231 if let Some(worktree_path) = project.read_with(cx, |project, cx| {
232 project
233 .worktree_for_id(symbol.path.worktree_id, cx)
234 .map(|worktree| PathBuf::from(worktree.read(cx).root_name()))
235 })? {
236 let path = worktree_path.join(&symbol.path.path);
237 symbols_by_path.entry(path).or_default().push(symbol);
238 }
239 }
240
241 // If no symbols matched the filter, return early
242 if symbols_by_path.is_empty() {
243 return Err(anyhow!("No symbols found matching the criteria."));
244 }
245
246 let mut symbols_rendered = 0;
247 let mut has_more_symbols = false;
248 let mut output = String::new();
249
250 'outer: for (file_path, file_symbols) in symbols_by_path {
251 if symbols_rendered > 0 {
252 output.push('\n');
253 }
254
255 writeln!(&mut output, "{}", file_path.display()).ok();
256
257 for symbol in file_symbols {
258 if symbols_rendered >= RESULTS_PER_PAGE {
259 has_more_symbols = true;
260 break 'outer;
261 }
262
263 write!(&mut output, " {} ", symbol.label.text()).ok();
264
265 // Convert to 1-based line numbers for display
266 let start_line = symbol.range.start.0.row as usize + 1;
267 let end_line = symbol.range.end.0.row as usize + 1;
268
269 if start_line == end_line {
270 writeln!(&mut output, "[L{}]", start_line).ok();
271 } else {
272 writeln!(&mut output, "[L{}-{}]", start_line, end_line).ok();
273 }
274
275 symbols_rendered += 1;
276 }
277 }
278
279 Ok(if symbols_rendered == 0 {
280 "No symbols found in the requested page.".to_string()
281 } else if has_more_symbols {
282 format!(
283 "{output}\nShowing symbols {}-{} (more symbols were found; use offset: {} to see next page)",
284 offset + 1,
285 offset + symbols_rendered,
286 offset + RESULTS_PER_PAGE,
287 )
288 } else {
289 output
290 })
291}
292
293async fn render_outline(
294 items: impl IntoIterator<Item = OutlineItem<Point>>,
295 regex: Option<Regex>,
296 offset: u32,
297) -> Result<String> {
298 const RESULTS_PER_PAGE_USIZE: usize = RESULTS_PER_PAGE as usize;
299
300 let mut items = items.into_iter().skip(offset as usize);
301
302 let entries = items
303 .by_ref()
304 .filter(|item| {
305 regex
306 .as_ref()
307 .is_none_or(|regex| regex.is_match(&item.text))
308 })
309 .take(RESULTS_PER_PAGE_USIZE)
310 .collect::<Vec<_>>();
311 let has_more = items.next().is_some();
312
313 let mut output = String::new();
314 let entries_rendered = render_entries(&mut output, entries);
315
316 // Calculate pagination information
317 let page_start = offset + 1;
318 let page_end = offset + entries_rendered;
319 let total_symbols = if has_more {
320 format!("more than {}", page_end)
321 } else {
322 page_end.to_string()
323 };
324
325 // Add pagination information
326 if has_more {
327 writeln!(&mut output, "\nShowing symbols {page_start}-{page_end} (there were more symbols found; use offset: {page_end} to see next page)",
328 )
329 } else {
330 writeln!(
331 &mut output,
332 "\nShowing symbols {page_start}-{page_end} (total symbols: {total_symbols})",
333 )
334 }
335 .ok();
336
337 Ok(output)
338}
339
340fn render_entries(output: &mut String, items: impl IntoIterator<Item = OutlineItem<Point>>) -> u32 {
341 let mut entries_rendered = 0;
342
343 for item in items {
344 // Indent based on depth ("" for level 0, " " for level 1, etc.)
345 for _ in 0..item.depth {
346 output.push(' ');
347 }
348 output.push_str(&item.text);
349
350 // Add position information - convert to 1-based line numbers for display
351 let start_line = item.range.start.row + 1;
352 let end_line = item.range.end.row + 1;
353
354 if start_line == end_line {
355 writeln!(output, " [L{}]", start_line).ok();
356 } else {
357 writeln!(output, " [L{}-{}]", start_line, end_line).ok();
358 }
359 entries_rendered += 1;
360 }
361
362 entries_rendered
363}