code_action_tool.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use assistant_tool::{ActionLog, Tool, ToolResult};
  3use gpui::{AnyWindowHandle, App, Entity, Task};
  4use language::{self, Anchor, Buffer, ToPointUtf16};
  5use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  6use project::{self, LspAction, Project};
  7use regex::Regex;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::{ops::Range, sync::Arc};
 11use ui::IconName;
 12
 13use crate::schema::json_schema_for;
 14
// Input accepted by `CodeActionTool`. `JsonSchema` is derived here and `input_schema` builds
// the tool's input schema from this type via `json_schema_for`; schemars turns the `///`
// doc comments below into the schema's `description` text, which is what the model reads.
// Treat those `///` comments as prompt text: editing them changes what the model is told.
// Plain `//` comments such as this one are not part of the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CodeActionToolInput {
    /// The relative path to the file containing the text range.
    ///
    /// WARNING: you MUST start this path with one of the project's root directories.
    pub path: String,

    // Apart from the exact "textDocument/rename" special case, `run` compiles this value as a
    // regular expression and fails unless exactly one available code action matches it.
    // `None` makes `run` list the available actions instead of executing one.
    /// The specific code action to execute.
    ///
    /// If this field is provided, the tool will execute the specified action.
    /// If omitted, the tool will list all available code actions for the text range.
    ///
    /// Here are some actions that are commonly supported (but may not be for this particular
    /// text range; you can omit this field to list all the actions, if you want to know
    /// what your options are, or you can just try an action and if it fails I'll tell you
    /// what the available actions were instead):
    /// - "quickfix.all" - applies all available quick fixes in the range
    /// - "source.organizeImports" - sorts and cleans up import statements
    /// - "source.fixAll" - applies all available auto fixes
    /// - "refactor.extract" - extracts selected code into a new function or variable
    /// - "refactor.inline" - inlines a variable by replacing references with its value
    /// - "refactor.rewrite" - general code rewriting operations
    /// - "source.addMissingImports" - adds imports for references that lack them
    /// - "source.removeUnusedImports" - removes imports that aren't being used
    /// - "source.implementInterface" - generates methods required by an interface/trait
    /// - "source.generateAccessors" - creates getter/setter methods
    /// - "source.convertToAsyncFunction" - converts callback-style code to async/await
    ///
    /// Also, there is a special case: if you specify exactly "textDocument/rename" as the action,
    /// then this will rename the symbol to whatever string you specified for the `arguments` field.
    pub action: Option<String>,

    // For renames, `run` accepts only a JSON string here and errors otherwise; `ui_text`
    // shows a non-string value as "invalid name" and a missing one as "missing name".
    /// Optional arguments to pass to the code action.
    ///
    /// For rename operations (when action="textDocument/rename"), this should contain the new name.
    /// For other code actions, these arguments may be passed to the language server.
    pub arguments: Option<serde_json::Value>,

    /// The text that comes immediately before the text range in the file.
    pub context_before_range: String,

    // Located by `find_text_range`, which searches for the concatenation of the three range
    // fields and, if that is not unique, retries with a newline between the context and
    // the range.
    /// The text range. This text must appear in the file right between `context_before_range`
    /// and `context_after_range`.
    ///
    /// The file must contain exactly one occurrence of `context_before_range` followed by
    /// `text_range` followed by `context_after_range`. If the file contains zero occurrences,
    /// or if it contains more than one occurrence, the tool will fail, so it is absolutely
    /// critical that you verify ahead of time that the string is unique. You can search
    /// the file's contents to verify this ahead of time.
    ///
    /// To make the string more likely to be unique, include a minimum of 1 line of context
    /// before the text range, as well as a minimum of 1 line of context after the text range.
    /// If these lines of context are not enough to obtain a string that appears only once
    /// in the file, then double the number of context lines until the string becomes unique.
    /// (Start with 1 line before and 1 line after though, because too much context is
    /// needlessly costly.)
    ///
    /// Do not alter the context lines of code in any way, and make sure to preserve all
    /// whitespace and indentation for all lines of code. The combined string must be exactly
    /// as it appears in the file, or else this tool call will fail.
    pub text_range: String,

    /// The text that comes immediately after the text range in the file.
    pub context_after_range: String,
}
 80
/// Agent tool (`code_actions`) that lists or runs language-server code actions, including
/// the `textDocument/rename` special case, for a text range in a project file that is
/// located by its surrounding context.
pub struct CodeActionTool;
 82
 83impl Tool for CodeActionTool {
 84    fn name(&self) -> String {
 85        "code_actions".into()
 86    }
 87
 88    fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool {
 89        false
 90    }
 91
 92    fn description(&self) -> String {
 93        include_str!("./code_action_tool/description.md").into()
 94    }
 95
 96    fn icon(&self) -> IconName {
 97        IconName::Wand
 98    }
 99
100    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
101        json_schema_for::<CodeActionToolInput>(format)
102    }
103
104    fn ui_text(&self, input: &serde_json::Value) -> String {
105        match serde_json::from_value::<CodeActionToolInput>(input.clone()) {
106            Ok(input) => {
107                if let Some(action) = &input.action {
108                    if action == "textDocument/rename" {
109                        let new_name = match &input.arguments {
110                            Some(serde_json::Value::String(new_name)) => new_name.clone(),
111                            Some(value) => {
112                                if let Ok(new_name) =
113                                    serde_json::from_value::<String>(value.clone())
114                                {
115                                    new_name
116                                } else {
117                                    "invalid name".to_string()
118                                }
119                            }
120                            None => "missing name".to_string(),
121                        };
122                        format!("Rename '{}' to '{}'", input.text_range, new_name)
123                    } else {
124                        format!(
125                            "Execute code action '{}' for '{}'",
126                            action, input.text_range
127                        )
128                    }
129                } else {
130                    format!("List available code actions for '{}'", input.text_range)
131                }
132            }
133            Err(_) => "Perform code action".to_string(),
134        }
135    }
136
137    fn run(
138        self: Arc<Self>,
139        input: serde_json::Value,
140        _messages: &[LanguageModelRequestMessage],
141        project: Entity<Project>,
142        action_log: Entity<ActionLog>,
143        _window: Option<AnyWindowHandle>,
144        cx: &mut App,
145    ) -> ToolResult {
146        let input = match serde_json::from_value::<CodeActionToolInput>(input) {
147            Ok(input) => input,
148            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
149        };
150
151        cx.spawn(async move |cx| {
152            let buffer = {
153                let project_path = project.read_with(cx, |project, cx| {
154                    project
155                        .find_project_path(&input.path, cx)
156                        .context("Path not found in project")
157                })??;
158
159                project.update(cx, |project, cx| project.open_buffer(project_path, cx))?.await?
160            };
161
162            action_log.update(cx, |action_log, cx| {
163                action_log.track_buffer(buffer.clone(), cx);
164            })?;
165
166            let range = {
167                let Some(range) = buffer.read_with(cx, |buffer, _cx| {
168                    find_text_range(&buffer, &input.context_before_range, &input.text_range, &input.context_after_range)
169                })? else {
170                    return Err(anyhow!(
171                        "Failed to locate the text specified by context_before_range, text_range, and context_after_range. Make sure context_before_range and context_after_range each match exactly once in the file."
172                    ));
173                };
174
175                range
176            };
177
178            if let Some(action_type) = &input.action {
179                // Special-case the `rename` operation
180                let response = if action_type == "textDocument/rename" {
181                    let Some(new_name) = input.arguments.and_then(|args| serde_json::from_value::<String>(args).ok()) else {
182                        return Err(anyhow!("For rename operations, 'arguments' must be a string containing the new name"));
183                    };
184
185                    let position = buffer.read_with(cx, |buffer, _| {
186                        range.start.to_point_utf16(&buffer.snapshot())
187                    })?;
188
189                    project
190                        .update(cx, |project, cx| {
191                            project.perform_rename(buffer.clone(), position, new_name.clone(), cx)
192                        })?
193                        .await?;
194
195                    format!("Renamed '{}' to '{}'", input.text_range, new_name)
196                } else {
197                    // Get code actions for the range
198                    let actions = project
199                        .update(cx, |project, cx| {
200                            project.code_actions(&buffer, range.clone(), None, cx)
201                        })?
202                        .await?;
203
204                    if actions.is_empty() {
205                        return Err(anyhow!("No code actions available for this range"));
206                    }
207
208                    // Find all matching actions
209                    let regex = match Regex::new(action_type) {
210                        Ok(regex) => regex,
211                        Err(err) => return Err(anyhow!("Invalid regex pattern: {}", err)),
212                    };
213                    let mut matching_actions = actions
214                        .into_iter()
215                        .filter(|action| { regex.is_match(action.lsp_action.title()) });
216
217                    let Some(action) = matching_actions.next() else {
218                        return Err(anyhow!("No code actions match the pattern: {}", action_type));
219                    };
220
221                    // There should have been exactly one matching action.
222                    if let Some(second) = matching_actions.next() {
223                        let mut all_matches = vec![action, second];
224
225                        all_matches.extend(matching_actions);
226
227                        return Err(anyhow!(
228                            "Pattern '{}' matches multiple code actions: {}",
229                            action_type,
230                            all_matches.into_iter().map(|action| action.lsp_action.title().to_string()).collect::<Vec<_>>().join(", ")
231                        ));
232                    }
233
234                    let title = action.lsp_action.title().to_string();
235
236                    project
237                        .update(cx, |project, cx| {
238                            project.apply_code_action(buffer.clone(), action, true, cx)
239                        })?
240                        .await?;
241
242                    format!("Completed code action: {}", title)
243                };
244
245                project
246                    .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
247                    .await?;
248
249                action_log.update(cx, |log, cx| {
250                    log.buffer_edited(buffer.clone(), cx)
251                })?;
252
253                Ok(response)
254            } else {
255                // No action specified, so list the available ones.
256                let (position_start, position_end) = buffer.read_with(cx, |buffer, _| {
257                    let snapshot = buffer.snapshot();
258                    (
259                        range.start.to_point_utf16(&snapshot),
260                        range.end.to_point_utf16(&snapshot)
261                    )
262                })?;
263
264                // Convert position to display coordinates (1-based)
265                let position_start_display = language::Point {
266                    row: position_start.row + 1,
267                    column: position_start.column + 1,
268                };
269
270                let position_end_display = language::Point {
271                    row: position_end.row + 1,
272                    column: position_end.column + 1,
273                };
274
275                // Get code actions for the range
276                let actions = project
277                    .update(cx, |project, cx| {
278                        project.code_actions(&buffer, range.clone(), None, cx)
279                    })?
280                    .await?;
281
282                let mut response = format!(
283                    "Available code actions for text range '{}' at position {}:{} to {}:{} (UTF-16 coordinates):\n\n",
284                    input.text_range,
285                    position_start_display.row, position_start_display.column,
286                    position_end_display.row, position_end_display.column
287                );
288
289                if actions.is_empty() {
290                    response.push_str("No code actions available for this range.");
291                } else {
292                    for (i, action) in actions.iter().enumerate() {
293                        let title = match &action.lsp_action {
294                            LspAction::Action(code_action) => code_action.title.as_str(),
295                            LspAction::Command(command) => command.title.as_str(),
296                            LspAction::CodeLens(code_lens) => {
297                                if let Some(cmd) = &code_lens.command {
298                                    cmd.title.as_str()
299                                } else {
300                                    "Unknown code lens"
301                                }
302                            },
303                        };
304
305                        let kind = match &action.lsp_action {
306                            LspAction::Action(code_action) => {
307                                if let Some(kind) = &code_action.kind {
308                                    kind.as_str()
309                                } else {
310                                    "unknown"
311                                }
312                            },
313                            LspAction::Command(_) => "command",
314                            LspAction::CodeLens(_) => "code_lens",
315                        };
316
317                        response.push_str(&format!("{}. {title} ({kind})\n", i + 1));
318                    }
319                }
320
321                Ok(response)
322            }
323        }).into()
324    }
325}
326
327/// Finds the range of the text in the buffer, if it appears between context_before_range
328/// and context_after_range, and if that combined string has one unique result in the buffer.
329///
330/// If an exact match fails, it tries adding a newline to the end of context_before_range and
331/// to the beginning of context_after_range to accommodate line-based context matching.
332fn find_text_range(
333    buffer: &Buffer,
334    context_before_range: &str,
335    text_range: &str,
336    context_after_range: &str,
337) -> Option<Range<Anchor>> {
338    let snapshot = buffer.snapshot();
339    let text = snapshot.text();
340
341    // First try with exact match
342    let search_string = format!("{context_before_range}{text_range}{context_after_range}");
343    let mut positions = text.match_indices(&search_string);
344    let position_result = positions.next();
345
346    if let Some(position) = position_result {
347        // Check if the matched string is unique
348        if positions.next().is_none() {
349            let range_start = position.0 + context_before_range.len();
350            let range_end = range_start + text_range.len();
351            let range_start_anchor = snapshot.anchor_before(snapshot.offset_to_point(range_start));
352            let range_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(range_end));
353
354            return Some(range_start_anchor..range_end_anchor);
355        }
356    }
357
358    // If exact match fails or is not unique, try with line-based context
359    // Add a newline to the end of before context and beginning of after context
360    let line_based_before = if context_before_range.ends_with('\n') {
361        context_before_range.to_string()
362    } else {
363        format!("{context_before_range}\n")
364    };
365
366    let line_based_after = if context_after_range.starts_with('\n') {
367        context_after_range.to_string()
368    } else {
369        format!("\n{context_after_range}")
370    };
371
372    let line_search_string = format!("{line_based_before}{text_range}{line_based_after}");
373    let mut line_positions = text.match_indices(&line_search_string);
374    let line_position = line_positions.next()?;
375
376    // The line-based search string must also appear exactly once
377    if line_positions.next().is_some() {
378        return None;
379    }
380
381    let line_range_start = line_position.0 + line_based_before.len();
382    let line_range_end = line_range_start + text_range.len();
383    let line_range_start_anchor =
384        snapshot.anchor_before(snapshot.offset_to_point(line_range_start));
385    let line_range_end_anchor = snapshot.anchor_before(snapshot.offset_to_point(line_range_end));
386
387    Some(line_range_start_anchor..line_range_end_anchor)
388}