@@ -718,7 +718,6 @@ dependencies = [
"itertools 0.14.0",
"language",
"language_model",
- "pretty_assertions",
"project",
"rand 0.8.5",
"release_channel",
@@ -728,7 +727,6 @@ dependencies = [
"settings",
"theme",
"ui",
- "unindent",
"util",
"workspace",
"worktree",
@@ -1,6 +1,5 @@
mod edit_action;
pub mod log;
-mod resolve_search_block;
use anyhow::{anyhow, Context, Result};
use assistant_tool::{ActionLog, Tool};
@@ -8,17 +7,16 @@ use collections::HashSet;
use edit_action::{EditAction, EditActionParser};
use futures::StreamExt;
use gpui::{App, AsyncApp, Entity, Task};
-use language::OffsetRangeExt;
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
use log::{EditToolLog, EditToolRequestId};
-use project::Project;
-use resolve_search_block::resolve_search_block;
+use project::{search::SearchQuery, Project};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::sync::Arc;
+use util::paths::PathMatcher;
use util::ResultExt;
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -131,11 +129,24 @@ struct EditToolRequest {
parser: EditActionParser,
output: String,
changed_buffers: HashSet<Entity<language::Buffer>>,
+ bad_searches: Vec<BadSearch>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
+#[derive(Debug)]
+enum DiffResult {
+ BadSearch(BadSearch),
+ Diff(language::Diff),
+}
+
+#[derive(Debug)]
+struct BadSearch {
+ file_path: String,
+ search: String,
+}
+
impl EditToolRequest {
fn new(
input: EditFilesToolInput,
@@ -193,6 +204,7 @@ impl EditToolRequest {
// we start with the success header so we don't need to shift the output in the common case
output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
changed_buffers: HashSet::default(),
+ bad_searches: Vec::new(),
action_log,
project,
tool_log,
@@ -239,30 +251,36 @@ impl EditToolRequest {
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
- let diff = match action {
+ let result = match action {
EditAction::Replace {
old,
new,
- file_path: _,
+ file_path,
} => {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
- let diff = cx
- .background_executor()
- .spawn(Self::replace_diff(old, new, snapshot))
- .await;
-
- anyhow::Ok(diff)
+ cx.background_executor()
+ .spawn(Self::replace_diff(old, new, file_path, snapshot))
+ .await
}
- EditAction::Write { content, .. } => Ok(buffer
- .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
- .await),
+ EditAction::Write { content, .. } => Ok(DiffResult::Diff(
+ buffer
+ .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
+ .await,
+ )),
}?;
- let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
+ match result {
+ DiffResult::BadSearch(invalid_replace) => {
+ self.bad_searches.push(invalid_replace);
+ }
+ DiffResult::Diff(diff) => {
+ let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
- write!(&mut self.output, "\n\n{}", source)?;
- self.changed_buffers.insert(buffer);
+ write!(&mut self.output, "\n\n{}", source)?;
+ self.changed_buffers.insert(buffer);
+ }
+ }
Ok(())
}
@@ -270,9 +288,29 @@ impl EditToolRequest {
async fn replace_diff(
old: String,
new: String,
+ file_path: std::path::PathBuf,
snapshot: language::BufferSnapshot,
- ) -> language::Diff {
- let edit_range = resolve_search_block(&snapshot, &old).to_offset(&snapshot);
+ ) -> Result<DiffResult> {
+ let query = SearchQuery::text(
+ old.clone(),
+ false,
+ true,
+ true,
+ PathMatcher::new(&[])?,
+ PathMatcher::new(&[])?,
+ None,
+ )?;
+
+ let matches = query.search(&snapshot, None).await;
+
+ if matches.is_empty() {
+ return Ok(DiffResult::BadSearch(BadSearch {
+ search: new.clone(),
+ file_path: file_path.display().to_string(),
+ }));
+ }
+
+ let edit_range = matches[0].clone();
let diff = language::text_diff(&old, &new);
let edits = diff
@@ -290,7 +328,7 @@ impl EditToolRequest {
edits,
};
- diff
+ anyhow::Ok(DiffResult::Diff(diff))
}
const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
@@ -314,7 +352,7 @@ impl EditToolRequest {
let errors = self.parser.errors();
- if errors.is_empty() {
+ if errors.is_empty() && self.bad_searches.is_empty() {
if changed_buffer_count == 0 {
return Err(anyhow!(
"The instructions didn't lead to any changes. You might need to consult the file contents first."
@@ -337,6 +375,24 @@ impl EditToolRequest {
);
}
+ if !self.bad_searches.is_empty() {
+ writeln!(
+ &mut output,
+ "\n\nThese searches failed because they didn't match any strings:"
+ )?;
+
+ for replace in self.bad_searches {
+ writeln!(
+ &mut output,
+ "- '{}' does not appear in `{}`",
+ replace.search.replace("\r", "\\r").replace("\n", "\\n"),
+ replace.file_path
+ )?;
+ }
+
+ write!(&mut output, "Make sure to use exact searches.")?;
+ }
+
if !errors.is_empty() {
writeln!(
&mut output,
@@ -1,226 +0,0 @@
-use language::{Anchor, Bias, BufferSnapshot};
-use std::ops::Range;
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
-enum SearchDirection {
- Up,
- Left,
- Diagonal,
-}
-
-#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
-struct SearchState {
- cost: u32,
- direction: SearchDirection,
-}
-
-impl SearchState {
- fn new(cost: u32, direction: SearchDirection) -> Self {
- Self { cost, direction }
- }
-}
-
-struct SearchMatrix {
- cols: usize,
- data: Vec<SearchState>,
-}
-
-impl SearchMatrix {
- fn new(rows: usize, cols: usize) -> Self {
- SearchMatrix {
- cols,
- data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
- }
- }
-
- fn get(&self, row: usize, col: usize) -> SearchState {
- self.data[row * self.cols + col]
- }
-
- fn set(&mut self, row: usize, col: usize, cost: SearchState) {
- self.data[row * self.cols + col] = cost;
- }
-}
-
-pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
- const INSERTION_COST: u32 = 3;
- const DELETION_COST: u32 = 10;
- const WHITESPACE_INSERTION_COST: u32 = 1;
- const WHITESPACE_DELETION_COST: u32 = 1;
-
- let buffer_len = buffer.len();
- let query_len = search_query.len();
- let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
- let mut leading_deletion_cost = 0_u32;
- for (row, query_byte) in search_query.bytes().enumerate() {
- let deletion_cost = if query_byte.is_ascii_whitespace() {
- WHITESPACE_DELETION_COST
- } else {
- DELETION_COST
- };
-
- leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
- matrix.set(
- row + 1,
- 0,
- SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
- );
-
- for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
- let insertion_cost = if buffer_byte.is_ascii_whitespace() {
- WHITESPACE_INSERTION_COST
- } else {
- INSERTION_COST
- };
-
- let up = SearchState::new(
- matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
- SearchDirection::Up,
- );
- let left = SearchState::new(
- matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
- SearchDirection::Left,
- );
- let diagonal = SearchState::new(
- if query_byte == *buffer_byte {
- matrix.get(row, col).cost
- } else {
- matrix
- .get(row, col)
- .cost
- .saturating_add(deletion_cost + insertion_cost)
- },
- SearchDirection::Diagonal,
- );
- matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
- }
- }
-
- // Traceback to find the best match
- let mut best_buffer_end = buffer_len;
- let mut best_cost = u32::MAX;
- for col in 1..=buffer_len {
- let cost = matrix.get(query_len, col).cost;
- if cost < best_cost {
- best_cost = cost;
- best_buffer_end = col;
- }
- }
-
- let mut query_ix = query_len;
- let mut buffer_ix = best_buffer_end;
- while query_ix > 0 && buffer_ix > 0 {
- let current = matrix.get(query_ix, buffer_ix);
- match current.direction {
- SearchDirection::Diagonal => {
- query_ix -= 1;
- buffer_ix -= 1;
- }
- SearchDirection::Up => {
- query_ix -= 1;
- }
- SearchDirection::Left => {
- buffer_ix -= 1;
- }
- }
- }
-
- let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
- start.column = 0;
- let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
- if end.column > 0 {
- end.column = buffer.line_len(end.row);
- }
-
- buffer.anchor_after(start)..buffer.anchor_before(end)
-}
-
-#[cfg(test)]
-mod tests {
- use crate::edit_files_tool::resolve_search_block::resolve_search_block;
- use gpui::{prelude::*, App};
- use language::{Buffer, OffsetRangeExt as _};
- use unindent::Unindent as _;
- use util::test::{generate_marked_text, marked_text_ranges};
-
- #[gpui::test]
- fn test_resolve_search_block(cx: &mut App) {
- assert_resolved(
- concat!(
- " Lorem\n",
- "« ipsum\n",
- " dolor sit amet»\n",
- " consecteur",
- ),
- "ipsum\ndolor",
- cx,
- );
-
- assert_resolved(
- &"
- «fn foo1(a: usize) -> usize {
- 40
- }»
-
- fn foo2(b: usize) -> usize {
- 42
- }
- "
- .unindent(),
- "fn foo1(b: usize) {\n40\n}",
- cx,
- );
-
- assert_resolved(
- &"
- fn main() {
- « Foo
- .bar()
- .baz()
- .qux()»
- }
-
- fn foo2(b: usize) -> usize {
- 42
- }
- "
- .unindent(),
- "Foo.bar.baz.qux()",
- cx,
- );
-
- assert_resolved(
- &"
- class Something {
- one() { return 1; }
- « two() { return 2222; }
- three() { return 333; }
- four() { return 4444; }
- five() { return 5555; }
- six() { return 6666; }
- » seven() { return 7; }
- eight() { return 8; }
- }
- "
- .unindent(),
- &"
- two() { return 2222; }
- four() { return 4444; }
- five() { return 5555; }
- six() { return 6666; }
- "
- .unindent(),
- cx,
- );
- }
-
- #[track_caller]
- fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) {
- let (text, _) = marked_text_ranges(text_with_expected_range, false);
- let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
- let snapshot = buffer.read(cx).snapshot();
- let range = resolve_search_block(&snapshot, query).to_offset(&snapshot);
- let text_with_actual_range = generate_marked_text(&text, &[range], false);
- pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
- }
-}