symbol_info_tool.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use assistant_tool::{ActionLog, Tool, ToolResult};
  3use gpui::{App, AsyncApp, Entity, Task};
  4use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
  5use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::{fmt::Write, ops::Range, sync::Arc};
 10use ui::IconName;
 11use util::markdown::MarkdownString;
 12
 13use crate::schema::json_schema_for;
 14
 15#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 16pub struct SymbolInfoToolInput {
 17    /// The relative path to the file containing the symbol.
 18    ///
 19    /// WARNING: you MUST start this path with one of the project's root directories.
 20    pub path: String,
 21
 22    /// The information to get about the symbol.
 23    pub command: Info,
 24
 25    /// The text that comes immediately before the symbol in the file.
 26    pub context_before_symbol: String,
 27
 28    /// The symbol name. This text must appear in the file right between `context_before_symbol`
 29    /// and `context_after_symbol`.
 30    ///
 31    /// The file must contain exactly one occurrence of `context_before_symbol` followed by
 32    /// `symbol` followed by `context_after_symbol`. If the file contains zero occurrences,
 33    /// or if it contains more than one occurrence, the tool will fail, so it is absolutely
 34    /// critical that you verify ahead of time that the string is unique. You can search
 35    /// the file's contents to verify this ahead of time.
 36    ///
 37    /// To make the string more likely to be unique, include a minimum of 1 line of context
 38    /// before the symbol, as well as a minimum of 1 line of context after the symbol.
 39    /// If these lines of context are not enough to obtain a string that appears only once
 40    /// in the file, then double the number of context lines until the string becomes unique.
 41    /// (Start with 1 line before and 1 line after though, because too much context is
 42    /// needlessly costly.)
 43    ///
 44    /// Do not alter the context lines of code in any way, and make sure to preserve all
 45    /// whitespace and indentation for all lines of code. The combined string must be exactly
 46    /// as it appears in the file, or else this tool call will fail.
 47    pub symbol: String,
 48
 49    /// The text that comes immediately after the symbol in the file.
 50    pub context_after_symbol: String,
 51}
 52
 53#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 54#[serde(rename_all = "snake_case")]
 55pub enum Info {
 56    /// Get the symbol's definition (where it's first assigned, even if it's declared elsewhere)
 57    Definition,
 58    /// Get the symbol's declaration (where it's first declared)
 59    Declaration,
 60    /// Get the symbol's implementation
 61    Implementation,
 62    /// Get the symbol's type definition
 63    TypeDefinition,
 64    /// Find all references to the symbol in the project
 65    References,
 66}
 67
 68pub struct SymbolInfoTool;
 69
 70impl Tool for SymbolInfoTool {
 71    fn name(&self) -> String {
 72        "symbol_info".into()
 73    }
 74
 75    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 76        false
 77    }
 78
 79    fn description(&self) -> String {
 80        include_str!("./symbol_info_tool/description.md").into()
 81    }
 82
 83    fn icon(&self) -> IconName {
 84        IconName::Code
 85    }
 86
 87    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 88        json_schema_for::<SymbolInfoToolInput>(format)
 89    }
 90
 91    fn ui_text(&self, input: &serde_json::Value) -> String {
 92        match serde_json::from_value::<SymbolInfoToolInput>(input.clone()) {
 93            Ok(input) => {
 94                let symbol = MarkdownString::inline_code(&input.symbol);
 95
 96                match input.command {
 97                    Info::Definition => {
 98                        format!("Find definition for {symbol}")
 99                    }
100                    Info::Declaration => {
101                        format!("Find declaration for {symbol}")
102                    }
103                    Info::Implementation => {
104                        format!("Find implementation for {symbol}")
105                    }
106                    Info::TypeDefinition => {
107                        format!("Find type definition for {symbol}")
108                    }
109                    Info::References => {
110                        format!("Find references for {symbol}")
111                    }
112                }
113            }
114            Err(_) => "Get symbol info".to_string(),
115        }
116    }
117
118    fn run(
119        self: Arc<Self>,
120        input: serde_json::Value,
121        _messages: &[LanguageModelRequestMessage],
122        project: Entity<Project>,
123        action_log: Entity<ActionLog>,
124        cx: &mut App,
125    ) -> ToolResult {
126        let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
127            Ok(input) => input,
128            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
129        };
130
131        cx.spawn(async move |cx| {
132            let buffer = {
133                let project_path = project.read_with(cx, |project, cx| {
134                    project
135                        .find_project_path(&input.path, cx)
136                        .context("Path not found in project")
137                })??;
138
139                project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
140            };
141
142            action_log.update(cx, |action_log, cx| {
143                action_log.track_buffer(buffer.clone(), cx);
144            })?;
145
146            let position = {
147                let Some(range) = buffer.read_with(cx, |buffer, _cx| {
148                    find_symbol_range(&buffer, &input.context_before_symbol, &input.symbol, &input.context_after_symbol)
149                })? else {
150                    return Err(anyhow!(
151                        "Failed to locate the text specified by context_before_symbol, symbol, and context_after_symbol. Make sure context_before_symbol and context_after_symbol each match exactly once in the file."
152                    ));
153                };
154
155                buffer.read_with(cx, |buffer, _| {
156                    range.start.to_point_utf16(&buffer.snapshot())
157                })?
158            };
159
160            let output: String = match input.command {
161                Info::Definition => {
162                    render_locations(project
163                        .update(cx, |project, cx| {
164                            project.definition(&buffer, position, cx)
165                        })?
166                        .await?.into_iter().map(|link| link.target),
167                        cx)
168                }
169                Info::Declaration => {
170                    render_locations(project
171                        .update(cx, |project, cx| {
172                            project.declaration(&buffer, position, cx)
173                        })?
174                        .await?.into_iter().map(|link| link.target),
175                        cx)
176                }
177                Info::Implementation => {
178                    render_locations(project
179                        .update(cx, |project, cx| {
180                            project.implementation(&buffer, position, cx)
181                        })?
182                        .await?.into_iter().map(|link| link.target),
183                        cx)
184                }
185                Info::TypeDefinition => {
186                    render_locations(project
187                        .update(cx, |project, cx| {
188                            project.type_definition(&buffer, position, cx)
189                        })?
190                        .await?.into_iter().map(|link| link.target),
191                        cx)
192                }
193                Info::References => {
194                    render_locations(project
195                        .update(cx, |project, cx| {
196                            project.references(&buffer, position, cx)
197                        })?
198                        .await?,
199                        cx)
200                }
201            };
202
203            if output.is_empty() {
204                Err(anyhow!("None found."))
205            } else {
206                Ok(output)
207            }
208        }).into()
209    }
210}
211
212/// Finds the range of the symbol in the buffer, if it appears between context_before_symbol
213/// and context_after_symbol, and if that combined string has one unique result in the buffer.
214fn find_symbol_range(
215    buffer: &Buffer,
216    context_before_symbol: &str,
217    symbol: &str,
218    context_after_symbol: &str,
219) -> Option<Range<Anchor>> {
220    let snapshot = buffer.snapshot();
221    let text = snapshot.text();
222    let search_string = format!("{context_before_symbol}{symbol}{context_after_symbol}");
223    let mut positions = text.match_indices(&search_string);
224    let position = positions.next()?.0;
225
226    // The combined string must appear exactly once.
227    if positions.next().is_some() {
228        return None;
229    }
230
231    let symbol_start = position + context_before_symbol.len();
232    let symbol_end = symbol_start + symbol.len();
233    let symbol_start_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_start));
234    let symbol_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_end));
235
236    Some(symbol_start_anchor..symbol_end_anchor)
237}
238
239fn render_locations(locations: impl IntoIterator<Item = Location>, cx: &mut AsyncApp) -> String {
240    let mut answer = String::new();
241
242    for location in locations {
243        location
244            .buffer
245            .read_with(cx, |buffer, _cx| {
246                if let Some(target_path) = buffer
247                    .file()
248                    .and_then(|file| file.path().as_os_str().to_str())
249                {
250                    let snapshot = buffer.snapshot();
251                    let start = location.range.start.to_point(&snapshot);
252                    let end = location.range.end.to_point(&snapshot);
253                    let start_line = start.row + 1;
254                    let start_col = start.column + 1;
255                    let end_line = end.row + 1;
256                    let end_col = end.column + 1;
257
258                    if start_line == end_line {
259                        writeln!(answer, "{target_path}:{start_line},{start_col}")
260                    } else {
261                        writeln!(
262                            answer,
263                            "{target_path}:{start_line},{start_col}-{end_line},{end_col}",
264                        )
265                    }
266                    .ok();
267
268                    write_code_excerpt(&mut answer, &snapshot, &location.range);
269                }
270            })
271            .ok();
272    }
273
274    // Trim trailing newlines without reallocating.
275    answer.truncate(answer.trim_end().len());
276
277    answer
278}
279
280fn write_code_excerpt(buf: &mut String, snapshot: &BufferSnapshot, range: &Range<Anchor>) {
281    const MAX_LINE_LEN: u32 = 200;
282
283    let start = range.start.to_point(snapshot);
284    let end = range.end.to_point(snapshot);
285
286    for row in start.row..=end.row {
287        let row_start = Point::new(row, 0);
288        let row_end = if row < snapshot.max_point().row {
289            Point::new(row + 1, 0)
290        } else {
291            Point::new(row, u32::MAX)
292        };
293
294        buf.extend(
295            snapshot
296                .text_for_range(row_start..row_end)
297                .take(MAX_LINE_LEN as usize),
298        );
299
300        if row_end.column > MAX_LINE_LEN {
301            buf.push_str("\n");
302        }
303
304        buf.push('\n');
305    }
306}