symbol_info_tool.rs

  1use anyhow::{anyhow, Context as _, Result};
  2use assistant_tool::{ActionLog, Tool};
  3use gpui::{App, AsyncApp, Entity, Task};
  4use language::{self, Anchor, Buffer, BufferSnapshot, Location, Point, ToPoint, ToPointUtf16};
  5use language_model::LanguageModelRequestMessage;
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::{fmt::Write, ops::Range, sync::Arc};
 10use ui::IconName;
 11use util::markdown::MarkdownString;
 12
 13#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 14pub struct SymbolInfoToolInput {
 15    /// The relative path to the file containing the symbol.
 16    ///
 17    /// WARNING: you MUST start this path with one of the project's root directories.
 18    pub path: String,
 19
 20    /// The information to get about the symbol.
 21    pub command: Info,
 22
 23    /// The text that comes immediately before the symbol in the file.
 24    pub context_before_symbol: String,
 25
 26    /// The symbol name. This text must appear in the file right between `context_before_symbol`
 27    /// and `context_after_symbol`.
 28    ///
 29    /// The file must contain exactly one occurrence of `context_before_symbol` followed by
 30    /// `symbol` followed by `context_after_symbol`. If the file contains zero occurrences,
 31    /// or if it contains more than one occurrence, the tool will fail, so it is absolutely
 32    /// critical that you verify ahead of time that the string is unique. You can search
 33    /// the file's contents to verify this ahead of time.
 34    ///
 35    /// To make the string more likely to be unique, include a minimum of 1 line of context
 36    /// before the symbol, as well as a minimum of 1 line of context after the symbol.
 37    /// If these lines of context are not enough to obtain a string that appears only once
 38    /// in the file, then double the number of context lines until the string becomes unique.
 39    /// (Start with 1 line before and 1 line after though, because too much context is
 40    /// needlessly costly.)
 41    ///
 42    /// Do not alter the context lines of code in any way, and make sure to preserve all
 43    /// whitespace and indentation for all lines of code. The combined string must be exactly
 44    /// as it appears in the file, or else this tool call will fail.
 45    pub symbol: String,
 46
 47    /// The text that comes immediately after the symbol in the file.
 48    pub context_after_symbol: String,
 49}
 50
 51#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 52#[serde(rename_all = "snake_case")]
 53pub enum Info {
 54    /// Get the symbol's definition (where it's first assigned, even if it's declared elsewhere)
 55    Definition,
 56    /// Get the symbol's declaration (where it's first declared)
 57    Declaration,
 58    /// Get the symbol's implementation
 59    Implementation,
 60    /// Get the symbol's type definition
 61    TypeDefinition,
 62    /// Find all references to the symbol in the project
 63    References,
 64}
 65
 66pub struct SymbolInfoTool;
 67
 68impl Tool for SymbolInfoTool {
 69    fn name(&self) -> String {
 70        "symbol-info".into()
 71    }
 72
 73    fn needs_confirmation(&self) -> bool {
 74        false
 75    }
 76
 77    fn description(&self) -> String {
 78        include_str!("./symbol_info_tool/description.md").into()
 79    }
 80
 81    fn icon(&self) -> IconName {
 82        IconName::Eye
 83    }
 84
 85    fn input_schema(&self) -> serde_json::Value {
 86        let schema = schemars::schema_for!(SymbolInfoToolInput);
 87        serde_json::to_value(&schema).unwrap()
 88    }
 89
 90    fn ui_text(&self, input: &serde_json::Value) -> String {
 91        match serde_json::from_value::<SymbolInfoToolInput>(input.clone()) {
 92            Ok(input) => {
 93                let symbol = MarkdownString::inline_code(&input.symbol);
 94
 95                match input.command {
 96                    Info::Definition => {
 97                        format!("Find definition for {symbol}")
 98                    }
 99                    Info::Declaration => {
100                        format!("Find declaration for {symbol}")
101                    }
102                    Info::Implementation => {
103                        format!("Find implementation for {symbol}")
104                    }
105                    Info::TypeDefinition => {
106                        format!("Find type definition for {symbol}")
107                    }
108                    Info::References => {
109                        format!("Find references for {symbol}")
110                    }
111                }
112            }
113            Err(_) => "Get symbol info".to_string(),
114        }
115    }
116
117    fn run(
118        self: Arc<Self>,
119        input: serde_json::Value,
120        _messages: &[LanguageModelRequestMessage],
121        project: Entity<Project>,
122        action_log: Entity<ActionLog>,
123        cx: &mut App,
124    ) -> Task<Result<String>> {
125        let input = match serde_json::from_value::<SymbolInfoToolInput>(input) {
126            Ok(input) => input,
127            Err(err) => return Task::ready(Err(anyhow!(err))),
128        };
129
130        cx.spawn(async move |cx| {
131            let buffer = {
132                let project_path = project.read_with(cx, |project, cx| {
133                    project
134                        .find_project_path(&input.path, cx)
135                        .context("Path not found in project")
136                })??;
137
138                project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
139            };
140
141            action_log.update(cx, |action_log, cx| {
142                action_log.buffer_read(buffer.clone(), cx);
143            })?;
144
145            let position = {
146                let Some(range) = buffer.read_with(cx, |buffer, _cx| {
147                    find_symbol_range(&buffer, &input.context_before_symbol, &input.symbol, &input.context_after_symbol)
148                })? else {
149                    return Err(anyhow!(
150                        "Failed to locate the text specified by context_before_symbol, symbol, and context_after_symbol. Make sure context_before_symbol and context_after_symbol each match exactly once in the file."
151                    ));
152                };
153
154                buffer.read_with(cx, |buffer, _| {
155                    range.start.to_point_utf16(&buffer.snapshot())
156                })?
157            };
158
159            let output: String = match input.command {
160                Info::Definition => {
161                    render_locations(project
162                        .update(cx, |project, cx| {
163                            project.definition(&buffer, position, cx)
164                        })?
165                        .await?.into_iter().map(|link| link.target),
166                        cx)
167                }
168                Info::Declaration => {
169                    render_locations(project
170                        .update(cx, |project, cx| {
171                            project.declaration(&buffer, position, cx)
172                        })?
173                        .await?.into_iter().map(|link| link.target),
174                        cx)
175                }
176                Info::Implementation => {
177                    render_locations(project
178                        .update(cx, |project, cx| {
179                            project.implementation(&buffer, position, cx)
180                        })?
181                        .await?.into_iter().map(|link| link.target),
182                        cx)
183                }
184                Info::TypeDefinition => {
185                    render_locations(project
186                        .update(cx, |project, cx| {
187                            project.type_definition(&buffer, position, cx)
188                        })?
189                        .await?.into_iter().map(|link| link.target),
190                        cx)
191                }
192                Info::References => {
193                    render_locations(project
194                        .update(cx, |project, cx| {
195                            project.references(&buffer, position, cx)
196                        })?
197                        .await?,
198                        cx)
199                }
200            };
201
202            if output.is_empty() {
203                Err(anyhow!("None found."))
204            } else {
205                Ok(output)
206            }
207        })
208    }
209}
210
211/// Finds the range of the symbol in the buffer, if it appears between context_before_symbol
212/// and context_after_symbol, and if that combined string has one unique result in the buffer.
213fn find_symbol_range(
214    buffer: &Buffer,
215    context_before_symbol: &str,
216    symbol: &str,
217    context_after_symbol: &str,
218) -> Option<Range<Anchor>> {
219    let snapshot = buffer.snapshot();
220    let text = snapshot.text();
221    let search_string = format!("{context_before_symbol}{symbol}{context_after_symbol}");
222    let mut positions = text.match_indices(&search_string);
223    let position = positions.next()?.0;
224
225    // The combined string must appear exactly once.
226    if positions.next().is_some() {
227        return None;
228    }
229
230    let symbol_start = position + context_before_symbol.len();
231    let symbol_end = symbol_start + symbol.len();
232    let symbol_start_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_start));
233    let symbol_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(symbol_end));
234
235    Some(symbol_start_anchor..symbol_end_anchor)
236}
237
238fn render_locations(locations: impl IntoIterator<Item = Location>, cx: &mut AsyncApp) -> String {
239    let mut answer = String::new();
240
241    for location in locations {
242        location
243            .buffer
244            .read_with(cx, |buffer, _cx| {
245                if let Some(target_path) = buffer
246                    .file()
247                    .and_then(|file| file.path().as_os_str().to_str())
248                {
249                    let snapshot = buffer.snapshot();
250                    let start = location.range.start.to_point(&snapshot);
251                    let end = location.range.end.to_point(&snapshot);
252                    let start_line = start.row + 1;
253                    let start_col = start.column + 1;
254                    let end_line = end.row + 1;
255                    let end_col = end.column + 1;
256
257                    if start_line == end_line {
258                        writeln!(answer, "{target_path}:{start_line},{start_col}")
259                    } else {
260                        writeln!(
261                            answer,
262                            "{target_path}:{start_line},{start_col}-{end_line},{end_col}",
263                        )
264                    }
265                    .ok();
266
267                    write_code_excerpt(&mut answer, &snapshot, &location.range);
268                }
269            })
270            .ok();
271    }
272
273    // Trim trailing newlines without reallocating.
274    answer.truncate(answer.trim_end().len());
275
276    answer
277}
278
279fn write_code_excerpt(buf: &mut String, snapshot: &BufferSnapshot, range: &Range<Anchor>) {
280    const MAX_LINE_LEN: u32 = 200;
281
282    let start = range.start.to_point(snapshot);
283    let end = range.end.to_point(snapshot);
284
285    for row in start.row..=end.row {
286        let row_start = Point::new(row, 0);
287        let row_end = if row < snapshot.max_point().row {
288            Point::new(row + 1, 0)
289        } else {
290            Point::new(row, u32::MAX)
291        };
292
293        buf.extend(
294            snapshot
295                .text_for_range(row_start..row_end)
296                .take(MAX_LINE_LEN as usize),
297        );
298
299        if row_end.column > MAX_LINE_LEN {
300            buf.push_str("\n");
301        }
302
303        buf.push('\n');
304    }
305}