symbol_info_tool.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use assistant_tool::{ActionLog, Tool, ToolResult};
  3use gpui::{AnyWindowHandle, App, AsyncApp, Entity, Task};
  4use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
  5use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::{fmt::Write, ops::Range, sync::Arc};
 10use ui::IconName;
 11use util::markdown::MarkdownInlineCode;
 12
 13use crate::schema::json_schema_for;
 14
// Arguments accepted by `SymbolInfoTool`, deserialized from the model's JSON tool call.
//
// NOTE(maintainers): the `///` comments on this struct are not only rustdoc. schemars'
// `JsonSchema` derive turns doc comments into `description` entries, so they presumably
// end up in the input schema shown to the model. Editing them changes the prompt. Put
// maintainer-only notes in `//` comments like this one.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SymbolInfoToolInput {
    /// The relative path to the file containing the symbol.
    ///
    /// WARNING: you MUST start this path with one of the project's root directories.
    pub path: String,

    /// The information to get about the symbol.
    pub command: Info,

    /// The text that comes immediately before the symbol in the file.
    pub context_before_symbol: String,

    // The three text fields are concatenated (before + symbol + after) and searched
    // for as one string. The docs below ask the model to keep that string unique.
    /// The symbol name. This text must appear in the file right between `context_before_symbol`
    /// and `context_after_symbol`.
    ///
    /// The file must contain exactly one occurrence of `context_before_symbol` followed by
    /// `symbol` followed by `context_after_symbol`. If the file contains zero occurrences,
    /// or if it contains more than one occurrence, the tool will fail, so it is absolutely
    /// critical that you verify ahead of time that the string is unique. You can search
    /// the file's contents to verify this ahead of time.
    ///
    /// To make the string more likely to be unique, include a minimum of 1 line of context
    /// before the symbol, as well as a minimum of 1 line of context after the symbol.
    /// If these lines of context are not enough to obtain a string that appears only once
    /// in the file, then double the number of context lines until the string becomes unique.
    /// (Start with 1 line before and 1 line after though, because too much context is
    /// needlessly costly.)
    ///
    /// Do not alter the context lines of code in any way, and make sure to preserve all
    /// whitespace and indentation for all lines of code. The combined string must be exactly
    /// as it appears in the file, or else this tool call will fail.
    pub symbol: String,

    /// The text that comes immediately after the symbol in the file.
    pub context_after_symbol: String,
}
 52
// The kind of lookup to perform for the located symbol.
//
// `rename_all = "snake_case"` means the JSON values are `definition`, `declaration`,
// `implementation`, `type_definition`, and `references`.
//
// NOTE(maintainers): the `///` variant comments are consumed by the `JsonSchema` derive
// as schema descriptions, so treat them as model-facing text.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum Info {
    /// Get the symbol's definition (where it's first assigned, even if it's declared elsewhere)
    Definition,
    /// Get the symbol's declaration (where it's first declared)
    Declaration,
    /// Get the symbol's implementation
    Implementation,
    /// Get the symbol's type definition
    TypeDefinition,
    /// Find all references to the symbol in the project
    References,
}
 67
/// Agent tool that looks up a symbol's definition, declaration, implementation, type
/// definition, or references. The symbol is identified by its text plus the context text
/// that surrounds it in a project file.
pub struct SymbolInfoTool;
 69
 70impl Tool for SymbolInfoTool {
 71    fn name(&self) -> String {
 72        "symbol_info".into()
 73    }
 74
 75    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 76        false
 77    }
 78
 79    fn description(&self) -> String {
 80        include_str!("./symbol_info_tool/description.md").into()
 81    }
 82
 83    fn icon(&self) -> IconName {
 84        IconName::Code
 85    }
 86
 87    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 88        json_schema_for::<SymbolInfoToolInput>(format)
 89    }
 90
 91    fn ui_text(&self, input: &serde_json::Value) -> String {
 92        match serde_json::from_value::<SymbolInfoToolInput>(input.clone()) {
 93            Ok(input) => {
 94                let symbol = MarkdownInlineCode(&input.symbol);
 95
 96                match input.command {
 97                    Info::Definition => {
 98                        format!("Find definition for {symbol}")
 99                    }
100                    Info::Declaration => {
101                        format!("Find declaration for {symbol}")
102                    }
103                    Info::Implementation => {
104                        format!("Find implementation for {symbol}")
105                    }
106                    Info::TypeDefinition => {
107                        format!("Find type definition for {symbol}")
108                    }
109                    Info::References => {
110                        format!("Find references for {symbol}")
111                    }
112                }
113            }
114            Err(_) => "Get symbol info".to_string(),
115        }
116    }
117
118    fn run(
119        self: Arc<Self>,
120        input: serde_json::Value,
121        _messages: &[LanguageModelRequestMessage],
122        project: Entity<Project>,
123        action_log: Entity<ActionLog>,
124        _window: Option<AnyWindowHandle>,
125        cx: &mut App,
126    ) -> ToolResult {
127        let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
128            Ok(input) => input,
129            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
130        };
131
132        cx.spawn(async move |cx| {
133            let buffer = {
134                let project_path = project.read_with(cx, |project, cx| {
135                    project
136                        .find_project_path(&input.path, cx)
137                        .context("Path not found in project")
138                })??;
139
140                project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
141            };
142
143            action_log.update(cx, |action_log, cx| {
144                action_log.track_buffer(buffer.clone(), cx);
145            })?;
146
147            let position = {
148                let Some(range) = buffer.read_with(cx, |buffer, _cx| {
149                    find_symbol_range(&buffer, &input.context_before_symbol, &input.symbol, &input.context_after_symbol)
150                })? else {
151                    return Err(anyhow!(
152                        "Failed to locate the text specified by context_before_symbol, symbol, and context_after_symbol. Make sure context_before_symbol and context_after_symbol each match exactly once in the file."
153                    ));
154                };
155
156                buffer.read_with(cx, |buffer, _| {
157                    range.start.to_point_utf16(&buffer.snapshot())
158                })?
159            };
160
161            let output: String = match input.command {
162                Info::Definition => {
163                    render_locations(project
164                        .update(cx, |project, cx| {
165                            project.definition(&buffer, position, cx)
166                        })?
167                        .await?.into_iter().map(|link| link.target),
168                        cx)
169                }
170                Info::Declaration => {
171                    render_locations(project
172                        .update(cx, |project, cx| {
173                            project.declaration(&buffer, position, cx)
174                        })?
175                        .await?.into_iter().map(|link| link.target),
176                        cx)
177                }
178                Info::Implementation => {
179                    render_locations(project
180                        .update(cx, |project, cx| {
181                            project.implementation(&buffer, position, cx)
182                        })?
183                        .await?.into_iter().map(|link| link.target),
184                        cx)
185                }
186                Info::TypeDefinition => {
187                    render_locations(project
188                        .update(cx, |project, cx| {
189                            project.type_definition(&buffer, position, cx)
190                        })?
191                        .await?.into_iter().map(|link| link.target),
192                        cx)
193                }
194                Info::References => {
195                    render_locations(project
196                        .update(cx, |project, cx| {
197                            project.references(&buffer, position, cx)
198                        })?
199                        .await?,
200                        cx)
201                }
202            };
203
204            if output.is_empty() {
205                Err(anyhow!("None found."))
206            } else {
207                Ok(output)
208            }
209        }).into()
210    }
211}
212
213/// Finds the range of the symbol in the buffer, if it appears between context_before_symbol
214/// and context_after_symbol, and if that combined string has one unique result in the buffer.
215fn find_symbol_range(
216    buffer: &Buffer,
217    context_before_symbol: &str,
218    symbol: &str,
219    context_after_symbol: &str,
220) -> Option<Range<Anchor>> {
221    let snapshot = buffer.snapshot();
222    let text = snapshot.text();
223    let search_string = format!("{context_before_symbol}{symbol}{context_after_symbol}");
224    let mut positions = text.match_indices(&search_string);
225    let position = positions.next()?.0;
226
227    // The combined string must appear exactly once.
228    if positions.next().is_some() {
229        return None;
230    }
231
232    let symbol_start = position + context_before_symbol.len();
233    let symbol_end = symbol_start + symbol.len();
234    let symbol_start_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_start));
235    let symbol_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_end));
236
237    Some(symbol_start_anchor..symbol_end_anchor)
238}
239
240fn render_locations(locations: impl IntoIterator<Item = Location>, cx: &mut AsyncApp) -> String {
241    let mut answer = String::new();
242
243    for location in locations {
244        location
245            .buffer
246            .read_with(cx, |buffer, _cx| {
247                if let Some(target_path) = buffer
248                    .file()
249                    .and_then(|file| file.path().as_os_str().to_str())
250                {
251                    let snapshot = buffer.snapshot();
252                    let start = location.range.start.to_point(&snapshot);
253                    let end = location.range.end.to_point(&snapshot);
254                    let start_line = start.row + 1;
255                    let start_col = start.column + 1;
256                    let end_line = end.row + 1;
257                    let end_col = end.column + 1;
258
259                    if start_line == end_line {
260                        writeln!(answer, "{target_path}:{start_line},{start_col}")
261                    } else {
262                        writeln!(
263                            answer,
264                            "{target_path}:{start_line},{start_col}-{end_line},{end_col}",
265                        )
266                    }
267                    .ok();
268
269                    write_code_excerpt(&mut answer, &snapshot, &location.range);
270                }
271            })
272            .ok();
273    }
274
275    // Trim trailing newlines without reallocating.
276    answer.truncate(answer.trim_end().len());
277
278    answer
279}
280
281fn write_code_excerpt(buf: &mut String, snapshot: &BufferSnapshot, range: &Range<Anchor>) {
282    const MAX_LINE_LEN: u32 = 200;
283
284    let start = range.start.to_point(snapshot);
285    let end = range.end.to_point(snapshot);
286
287    for row in start.row..=end.row {
288        let row_start = Point::new(row, 0);
289        let row_end = if row < snapshot.max_point().row {
290            Point::new(row + 1, 0)
291        } else {
292            Point::new(row, u32::MAX)
293        };
294
295        buf.extend(
296            snapshot
297                .text_for_range(row_start..row_end)
298                .take(MAX_LINE_LEN as usize),
299        );
300
301        if row_end.column > MAX_LINE_LEN {
302            buf.push_str("\n");
303        }
304
305        buf.push('\n');
306    }
307}