1use anyhow::{Context as _, Result, anyhow};
2use assistant_tool::{ActionLog, Tool, ToolResult};
3use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
4use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
5use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
6use project::Project;
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::{fmt::Write, ops::Range, sync::Arc};
10use ui::IconName;
11use util::markdown::MarkdownString;
12
13use crate::schema::json_schema_for;
14
// Arguments of a `symbol_info` tool call, deserialized from the model's JSON input in
// `SymbolInfoTool::run` (and `ui_text`).
//
// NOTE: the `///` comments on this struct are prompt text, not just documentation:
// the `JsonSchema` derive uses them as the `description` of each field, and
// `input_schema` builds the tool's schema from this type via `json_schema_for`.
// Editing them changes what the model is told.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SymbolInfoToolInput {
    /// The relative path to the file containing the symbol.
    ///
    /// WARNING: you MUST start this path with one of the project's root directories.
    // Resolved with `Project::find_project_path`; `run` fails with
    // "Path not found in project" if it does not resolve.
    pub path: String,

    /// The information to get about the symbol.
    pub command: Info,

    /// The text that comes immediately before the symbol in the file.
    pub context_before_symbol: String,

    /// The symbol name. This text must appear in the file right between `context_before_symbol`
    /// and `context_after_symbol`.
    ///
    /// The file must contain exactly one occurrence of `context_before_symbol` followed by
    /// `symbol` followed by `context_after_symbol`. If the file contains zero occurrences,
    /// or if it contains more than one occurrence, the tool will fail, so it is absolutely
    /// critical that you verify ahead of time that the string is unique. You can search
    /// the file's contents to verify this ahead of time.
    ///
    /// To make the string more likely to be unique, include a minimum of 1 line of context
    /// before the symbol, as well as a minimum of 1 line of context after the symbol.
    /// If these lines of context are not enough to obtain a string that appears only once
    /// in the file, then double the number of context lines until the string becomes unique.
    /// (Start with 1 line before and 1 line after though, because too much context is
    /// needlessly costly.)
    ///
    /// Do not alter the context lines of code in any way, and make sure to preserve all
    /// whitespace and indentation for all lines of code. The combined string must be exactly
    /// as it appears in the file, or else this tool call will fail.
    // `find_symbol_range` searches for the exact concatenation
    // `context_before_symbol + symbol + context_after_symbol`; the lookup is then issued
    // at the start of `symbol`.
    pub symbol: String,

    /// The text that comes immediately after the symbol in the file.
    pub context_after_symbol: String,
}
52
// The kind of lookup requested by a `symbol_info` call. Because of
// `rename_all = "snake_case"`, the model sends `definition`, `declaration`,
// `implementation`, `type_definition`, or `references`.
//
// NOTE: as with `SymbolInfoToolInput`, the `///` comments on the variants are used
// by the `JsonSchema` derive as schema descriptions shown to the model.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum Info {
    /// Get the symbol's definition (where it's first assigned, even if it's declared elsewhere)
    // Handled in `run` via `Project::definition`.
    Definition,
    /// Get the symbol's declaration (where it's first declared)
    // Handled in `run` via `Project::declaration`.
    Declaration,
    /// Get the symbol's implementation
    // Handled in `run` via `Project::implementation`.
    Implementation,
    /// Get the symbol's type definition
    // Handled in `run` via `Project::type_definition`.
    TypeDefinition,
    /// Find all references to the symbol in the project
    // Handled in `run` via `Project::references`.
    References,
}
67
/// The `symbol_info` agent tool: locates a symbol in a project file by its surrounding
/// text and reports its definition, declaration, implementation, type definition, or
/// references, using the corresponding `Project` queries.
pub struct SymbolInfoTool;
69
70impl Tool for SymbolInfoTool {
71 fn name(&self) -> String {
72 "symbol_info".into()
73 }
74
75 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
76 false
77 }
78
79 fn description(&self) -> String {
80 include_str!("./symbol_info_tool/description.md").into()
81 }
82
83 fn icon(&self) -> IconName {
84 IconName::Code
85 }
86
87 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
88 json_schema_for::<SymbolInfoToolInput>(format)
89 }
90
91 fn ui_text(&self, input: &serde_json::Value) -> String {
92 match serde_json::from_value::<SymbolInfoToolInput>(input.clone()) {
93 Ok(input) => {
94 let symbol = MarkdownString::inline_code(&input.symbol);
95
96 match input.command {
97 Info::Definition => {
98 format!("Find definition for {symbol}")
99 }
100 Info::Declaration => {
101 format!("Find declaration for {symbol}")
102 }
103 Info::Implementation => {
104 format!("Find implementation for {symbol}")
105 }
106 Info::TypeDefinition => {
107 format!("Find type definition for {symbol}")
108 }
109 Info::References => {
110 format!("Find references for {symbol}")
111 }
112 }
113 }
114 Err(_) => "Get symbol info".to_string(),
115 }
116 }
117
118 fn run(
119 self: Arc<Self>,
120 input: serde_json::Value,
121 _messages: &[LanguageModelRequestMessage],
122 project: Entity<Project>,
123 action_log: Entity<ActionLog>,
124 _window: Option<AnyWindowHandle>,
125 cx: &mut App,
126 ) -> ToolResult {
127 let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
128 Ok(input) => input,
129 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
130 };
131
132 cx.spawn(async move |cx| {
133 let buffer = {
134 let project_path = project.read_with(cx, |project, cx| {
135 project
136 .find_project_path(&input.path, cx)
137 .context("Path not found in project")
138 })??;
139
140 project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
141 };
142
143 action_log.update(cx, |action_log, cx| {
144 action_log.track_buffer(buffer.clone(), cx);
145 })?;
146
147 let position = {
148 let Some(range) = buffer.read_with(cx, |buffer, _cx| {
149 find_symbol_range(&buffer, &input.context_before_symbol, &input.symbol, &input.context_after_symbol)
150 })? else {
151 return Err(anyhow!(
152 "Failed to locate the text specified by context_before_symbol, symbol, and context_after_symbol. Make sure context_before_symbol and context_after_symbol each match exactly once in the file."
153 ));
154 };
155
156 buffer.read_with(cx, |buffer, _| {
157 range.start.to_point_utf16(&buffer.snapshot())
158 })?
159 };
160
161 let output: String = match input.command {
162 Info::Definition => {
163 render_locations(project
164 .update(cx, |project, cx| {
165 project.definition(&buffer, position, cx)
166 })?
167 .await?.into_iter().map(|link| link.target),
168 cx)
169 }
170 Info::Declaration => {
171 render_locations(project
172 .update(cx, |project, cx| {
173 project.declaration(&buffer, position, cx)
174 })?
175 .await?.into_iter().map(|link| link.target),
176 cx)
177 }
178 Info::Implementation => {
179 render_locations(project
180 .update(cx, |project, cx| {
181 project.implementation(&buffer, position, cx)
182 })?
183 .await?.into_iter().map(|link| link.target),
184 cx)
185 }
186 Info::TypeDefinition => {
187 render_locations(project
188 .update(cx, |project, cx| {
189 project.type_definition(&buffer, position, cx)
190 })?
191 .await?.into_iter().map(|link| link.target),
192 cx)
193 }
194 Info::References => {
195 render_locations(project
196 .update(cx, |project, cx| {
197 project.references(&buffer, position, cx)
198 })?
199 .await?,
200 cx)
201 }
202 };
203
204 if output.is_empty() {
205 Err(anyhow!("None found."))
206 } else {
207 Ok(output)
208 }
209 }).into()
210 }
211}
212
213/// Finds the range of the symbol in the buffer, if it appears between context_before_symbol
214/// and context_after_symbol, and if that combined string has one unique result in the buffer.
215fn find_symbol_range(
216 buffer: &Buffer,
217 context_before_symbol: &str,
218 symbol: &str,
219 context_after_symbol: &str,
220) -> Option<Range<Anchor>> {
221 let snapshot = buffer.snapshot();
222 let text = snapshot.text();
223 let search_string = format!("{context_before_symbol}{symbol}{context_after_symbol}");
224 let mut positions = text.match_indices(&search_string);
225 let position = positions.next()?.0;
226
227 // The combined string must appear exactly once.
228 if positions.next().is_some() {
229 return None;
230 }
231
232 let symbol_start = position + context_before_symbol.len();
233 let symbol_end = symbol_start + symbol.len();
234 let symbol_start_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_start));
235 let symbol_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_end));
236
237 Some(symbol_start_anchor..symbol_end_anchor)
238}
239
240fn render_locations(locations: impl IntoIterator<Item = Location>, cx: &mut AsyncApp) -> String {
241 let mut answer = String::new();
242
243 for location in locations {
244 location
245 .buffer
246 .read_with(cx, |buffer, _cx| {
247 if let Some(target_path) = buffer
248 .file()
249 .and_then(|file| file.path().as_os_str().to_str())
250 {
251 let snapshot = buffer.snapshot();
252 let start = location.range.start.to_point(&snapshot);
253 let end = location.range.end.to_point(&snapshot);
254 let start_line = start.row + 1;
255 let start_col = start.column + 1;
256 let end_line = end.row + 1;
257 let end_col = end.column + 1;
258
259 if start_line == end_line {
260 writeln!(answer, "{target_path}:{start_line},{start_col}")
261 } else {
262 writeln!(
263 answer,
264 "{target_path}:{start_line},{start_col}-{end_line},{end_col}",
265 )
266 }
267 .ok();
268
269 write_code_excerpt(&mut answer, &snapshot, &location.range);
270 }
271 })
272 .ok();
273 }
274
275 // Trim trailing newlines without reallocating.
276 answer.truncate(answer.trim_end().len());
277
278 answer
279}
280
281fn write_code_excerpt(buf: &mut String, snapshot: &BufferSnapshot, range: &Range<Anchor>) {
282 const MAX_LINE_LEN: u32 = 200;
283
284 let start = range.start.to_point(snapshot);
285 let end = range.end.to_point(snapshot);
286
287 for row in start.row..=end.row {
288 let row_start = Point::new(row, 0);
289 let row_end = if row < snapshot.max_point().row {
290 Point::new(row + 1, 0)
291 } else {
292 Point::new(row, u32::MAX)
293 };
294
295 buf.extend(
296 snapshot
297 .text_for_range(row_start..row_end)
298 .take(MAX_LINE_LEN as usize),
299 );
300
301 if row_end.column > MAX_LINE_LEN {
302 buf.push_str("…\n");
303 }
304
305 buf.push('\n');
306 }
307}