1use anyhow::{anyhow, Context as _, Result};
2use assistant_tool::{ActionLog, Tool};
3use gpui::{App, AsyncApp, Entity, Task};
4use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
5use language_model::LanguageModelRequestMessage;
6use project::Project;
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::{fmt::Write, ops::Range, sync::Arc};
10use ui::IconName;
11use util::markdown::MarkdownString;
12
// Arguments accepted by `SymbolInfoTool`, deserialized from the tool call's JSON input.
//
// The `///` comments below are part of the tool's behavior, not just documentation: the
// `JsonSchema` derive makes schemars copy them into the schema returned by
// `SymbolInfoTool::input_schema`, which is the text the model reads. Treat them as prompt
// text and edit them deliberately.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SymbolInfoToolInput {
    // Resolved with `Project::find_project_path` in `SymbolInfoTool::run`; an unresolvable
    // path fails the call with "Path not found in project".
    /// The relative path to the file containing the symbol.
    ///
    /// WARNING: you MUST start this path with one of the project's root directories.
    pub path: String,

    // Selects which `Project` query `run` performs.
    /// The information to get about the symbol.
    pub command: Info,

    // `find_symbol_range` concatenates this, `symbol` and `context_after_symbol` and
    // requires the combined string to occur exactly once in the buffer.
    /// The text that comes immediately before the symbol in the file.
    pub context_before_symbol: String,

    // The lookup position handed to the project query is the start of this text.
    /// The symbol name. This text must appear in the file right between `context_before_symbol`
    /// and `context_after_symbol`.
    ///
    /// The file must contain exactly one occurrence of `context_before_symbol` followed by
    /// `symbol` followed by `context_after_symbol`. If the file contains zero occurrences,
    /// or if it contains more than one occurrence, the tool will fail, so it is absolutely
    /// critical that you verify ahead of time that the string is unique. You can search
    /// the file's contents to verify this ahead of time.
    ///
    /// To make the string more likely to be unique, include a minimum of 1 line of context
    /// before the symbol, as well as a minimum of 1 line of context after the symbol.
    /// If these lines of context are not enough to obtain a string that appears only once
    /// in the file, then double the number of context lines until the string becomes unique.
    /// (Start with 1 line before and 1 line after though, because too much context is
    /// needlessly costly.)
    ///
    /// Do not alter the context lines of code in any way, and make sure to preserve all
    /// whitespace and indentation for all lines of code. The combined string must be exactly
    /// as it appears in the file, or else this tool call will fail.
    pub symbol: String,

    // Matched verbatim (no whitespace normalization) as the suffix of the combined string.
    /// The text that comes immediately after the symbol in the file.
    pub context_after_symbol: String,
}
50
// Which lookup `SymbolInfoTool` performs. Each variant maps to exactly one `Project` query
// in `SymbolInfoTool::run` (`definition`, `declaration`, `implementation`,
// `type_definition`, `references`).
//
// With `rename_all = "snake_case"` the JSON values are "definition", "declaration",
// "implementation", "type_definition" and "references". The `///` variant docs are copied
// into the generated JSON schema by the `JsonSchema` derive, so they are prompt text too.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum Info {
    /// Get the symbol's definition (where it's first assigned, even if it's declared elsewhere)
    Definition,
    /// Get the symbol's declaration (where it's first declared)
    Declaration,
    /// Get the symbol's implementation
    Implementation,
    /// Get the symbol's type definition
    TypeDefinition,
    /// Find all references to the symbol in the project
    References,
}
65
66pub struct SymbolInfoTool;
67
68impl Tool for SymbolInfoTool {
69 fn name(&self) -> String {
70 "symbol-info".into()
71 }
72
73 fn needs_confirmation(&self) -> bool {
74 false
75 }
76
77 fn description(&self) -> String {
78 include_str!("./symbol_info_tool/description.md").into()
79 }
80
81 fn icon(&self) -> IconName {
82 IconName::Eye
83 }
84
85 fn input_schema(&self) -> serde_json::Value {
86 let schema = schemars::schema_for!(SymbolInfoToolInput);
87 serde_json::to_value(&schema).unwrap()
88 }
89
90 fn ui_text(&self, input: &serde_json::Value) -> String {
91 match serde_json::from_value::<SymbolInfoToolInput>(input.clone()) {
92 Ok(input) => {
93 let symbol = MarkdownString::inline_code(&input.symbol);
94
95 match input.command {
96 Info::Definition => {
97 format!("Find definition for {symbol}")
98 }
99 Info::Declaration => {
100 format!("Find declaration for {symbol}")
101 }
102 Info::Implementation => {
103 format!("Find implementation for {symbol}")
104 }
105 Info::TypeDefinition => {
106 format!("Find type definition for {symbol}")
107 }
108 Info::References => {
109 format!("Find references for {symbol}")
110 }
111 }
112 }
113 Err(_) => "Get symbol info".to_string(),
114 }
115 }
116
117 fn run(
118 self: Arc<Self>,
119 input: serde_json::Value,
120 _messages: &[LanguageModelRequestMessage],
121 project: Entity<Project>,
122 action_log: Entity<ActionLog>,
123 cx: &mut App,
124 ) -> Task<Result<String>> {
125 let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
126 Ok(input) => input,
127 Err(err) => return Task::ready(Err(anyhow!(err))),
128 };
129
130 cx.spawn(async move |cx| {
131 let buffer = {
132 let project_path = project.read_with(cx, |project, cx| {
133 project
134 .find_project_path(&input.path, cx)
135 .context("Path not found in project")
136 })??;
137
138 project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
139 };
140
141 action_log.update(cx, |action_log, cx| {
142 action_log.buffer_read(buffer.clone(), cx);
143 })?;
144
145 let position = {
146 let Some(range) = buffer.read_with(cx, |buffer, _cx| {
147 find_symbol_range(&buffer, &input.context_before_symbol, &input.symbol, &input.context_after_symbol)
148 })? else {
149 return Err(anyhow!(
150 "Failed to locate the text specified by context_before_symbol, symbol, and context_after_symbol. Make sure context_before_symbol and context_after_symbol each match exactly once in the file."
151 ));
152 };
153
154 buffer.read_with(cx, |buffer, _| {
155 range.start.to_point_utf16(&buffer.snapshot())
156 })?
157 };
158
159 let output: String = match input.command {
160 Info::Definition => {
161 render_locations(project
162 .update(cx, |project, cx| {
163 project.definition(&buffer, position, cx)
164 })?
165 .await?.into_iter().map(|link| link.target),
166 cx)
167 }
168 Info::Declaration => {
169 render_locations(project
170 .update(cx, |project, cx| {
171 project.declaration(&buffer, position, cx)
172 })?
173 .await?.into_iter().map(|link| link.target),
174 cx)
175 }
176 Info::Implementation => {
177 render_locations(project
178 .update(cx, |project, cx| {
179 project.implementation(&buffer, position, cx)
180 })?
181 .await?.into_iter().map(|link| link.target),
182 cx)
183 }
184 Info::TypeDefinition => {
185 render_locations(project
186 .update(cx, |project, cx| {
187 project.type_definition(&buffer, position, cx)
188 })?
189 .await?.into_iter().map(|link| link.target),
190 cx)
191 }
192 Info::References => {
193 render_locations(project
194 .update(cx, |project, cx| {
195 project.references(&buffer, position, cx)
196 })?
197 .await?,
198 cx)
199 }
200 };
201
202 if output.is_empty() {
203 Err(anyhow!("None found."))
204 } else {
205 Ok(output)
206 }
207 })
208 }
209}
210
211/// Finds the range of the symbol in the buffer, if it appears between context_before_symbol
212/// and context_after_symbol, and if that combined string has one unique result in the buffer.
213fn find_symbol_range(
214 buffer: &Buffer,
215 context_before_symbol: &str,
216 symbol: &str,
217 context_after_symbol: &str,
218) -> Option<Range<Anchor>> {
219 let snapshot = buffer.snapshot();
220 let text = snapshot.text();
221 let search_string = format!("{context_before_symbol}{symbol}{context_after_symbol}");
222 let mut positions = text.match_indices(&search_string);
223 let position = positions.next()?.0;
224
225 // The combined string must appear exactly once.
226 if positions.next().is_some() {
227 return None;
228 }
229
230 let symbol_start = position + context_before_symbol.len();
231 let symbol_end = symbol_start + symbol.len();
232 let symbol_start_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_start));
233 let symbol_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_end));
234
235 Some(symbol_start_anchor..symbol_end_anchor)
236}
237
238fn render_locations(locations: impl IntoIterator<Item = Location>, cx: &mut AsyncApp) -> String {
239 let mut answer = String::new();
240
241 for location in locations {
242 location
243 .buffer
244 .read_with(cx, |buffer, _cx| {
245 if let Some(target_path) = buffer
246 .file()
247 .and_then(|file| file.path().as_os_str().to_str())
248 {
249 let snapshot = buffer.snapshot();
250 let start = location.range.start.to_point(&snapshot);
251 let end = location.range.end.to_point(&snapshot);
252 let start_line = start.row + 1;
253 let start_col = start.column + 1;
254 let end_line = end.row + 1;
255 let end_col = end.column + 1;
256
257 if start_line == end_line {
258 writeln!(answer, "{target_path}:{start_line},{start_col}")
259 } else {
260 writeln!(
261 answer,
262 "{target_path}:{start_line},{start_col}-{end_line},{end_col}",
263 )
264 }
265 .ok();
266
267 write_code_excerpt(&mut answer, &snapshot, &location.range);
268 }
269 })
270 .ok();
271 }
272
273 // Trim trailing newlines without reallocating.
274 answer.truncate(answer.trim_end().len());
275
276 answer
277}
278
279fn write_code_excerpt(buf: &mut String, snapshot: &BufferSnapshot, range: &Range<Anchor>) {
280 const MAX_LINE_LEN: u32 = 200;
281
282 let start = range.start.to_point(snapshot);
283 let end = range.end.to_point(snapshot);
284
285 for row in start.row..=end.row {
286 let row_start = Point::new(row, 0);
287 let row_end = if row < snapshot.max_point().row {
288 Point::new(row + 1, 0)
289 } else {
290 Point::new(row, u32::MAX)
291 };
292
293 buf.extend(
294 snapshot
295 .text_for_range(row_start..row_end)
296 .take(MAX_LINE_LEN as usize),
297 );
298
299 if row_end.column > MAX_LINE_LEN {
300 buf.push_str("…\n");
301 }
302
303 buf.push('\n');
304 }
305}