Detailed changes
@@ -658,6 +658,7 @@ dependencies = [
"collections",
"derive_more",
"gpui",
+ "language_model",
"parking_lot",
"project",
"serde",
@@ -671,11 +672,16 @@ dependencies = [
"anyhow",
"assistant_tool",
"chrono",
+ "collections",
+ "futures 0.3.31",
"gpui",
+ "language_model",
"project",
+ "rand 0.8.5",
"schemars",
"serde",
"serde_json",
+ "util",
]
[[package]]
@@ -3128,6 +3134,7 @@ dependencies = [
"extension",
"futures 0.3.31",
"gpui",
+ "language_model",
"log",
"parking_lot",
"postage",
@@ -600,6 +600,13 @@
"provider": "zed.dev",
// The model to use.
"model": "claude-3-5-sonnet-latest"
+ },
+ // The model to use when applying edits from the assistant.
+ "editor_model": {
+ // The provider to use.
+ "provider": "zed.dev",
+ // The model to use.
+ "model": "claude-3-5-sonnet-latest"
}
},
// The settings for slash commands.
@@ -186,8 +186,12 @@ fn init_language_model_settings(cx: &mut App) {
fn update_active_language_model_from_settings(cx: &mut App) {
let settings = AssistantSettings::get_global(cx);
- let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
- let model_id = LanguageModelId::from(settings.default_model.model.clone());
+ let active_model_provider_name =
+ LanguageModelProviderId::from(settings.default_model.provider.clone());
+ let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
+ let editor_provider_name =
+ LanguageModelProviderId::from(settings.editor_model.provider.clone());
+ let editor_model_id = LanguageModelId::from(settings.editor_model.model.clone());
let inline_alternatives = settings
.inline_alternatives
.iter()
@@ -199,7 +203,8 @@ fn update_active_language_model_from_settings(cx: &mut App) {
})
.collect::<Vec<_>>();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry.select_active_model(&provider_name, &model_id, cx);
+ registry.select_active_model(&active_model_provider_name, &active_model_id, cx);
+ registry.select_editor_model(&editor_provider_name, &editor_model_id, cx);
registry.select_inline_alternative_models(inline_alternatives, cx);
});
}
@@ -297,7 +297,8 @@ impl AssistantPanel {
&LanguageModelRegistry::global(cx),
window,
|this, _, event: &language_model::Event, window, cx| match event {
- language_model::Event::ActiveModelChanged => {
+ language_model::Event::ActiveModelChanged
+ | language_model::Event::EditorModelChanged => {
this.completion_provider_changed(window, cx);
}
language_model::Event::ProviderStateChanged => {
@@ -652,7 +652,7 @@ impl ActiveThread {
)
.child(message_content),
),
- Role::Assistant => div()
+ Role::Assistant => v_flex()
.id(("message-container", ix))
.child(message_content)
.map(|parent| {
@@ -623,6 +623,7 @@ impl Thread {
}
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
+ let request = self.to_completion_request(RequestKind::Chat, cx);
let pending_tool_uses = self
.tool_use
.pending_tool_uses()
@@ -633,7 +634,7 @@ impl Thread {
for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
- let task = tool.run(tool_use.input, self.project.clone(), cx);
+ let task = tool.run(tool_use.input, &request.messages, self.project.clone(), cx);
self.insert_tool_output(tool_use.id.clone(), task, cx);
}
@@ -62,6 +62,7 @@ pub struct AssistantSettings {
pub default_width: Pixels,
pub default_height: Pixels,
pub default_model: LanguageModelSelection,
+ pub editor_model: LanguageModelSelection,
pub inline_alternatives: Vec<LanguageModelSelection>,
pub using_outdated_settings_version: bool,
pub enable_experimental_live_diffs: bool,
@@ -162,6 +163,7 @@ impl AssistantSettingsContent {
})
}
}),
+ editor_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
},
@@ -182,6 +184,7 @@ impl AssistantSettingsContent {
.id()
.to_string(),
}),
+ editor_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
},
@@ -310,6 +313,7 @@ impl Default for VersionedAssistantSettingsContent {
default_width: None,
default_height: None,
default_model: None,
+ editor_model: None,
inline_alternatives: None,
enable_experimental_live_diffs: None,
})
@@ -340,6 +344,8 @@ pub struct AssistantSettingsContentV2 {
default_height: Option<f32>,
/// The default model to use when creating new chats.
default_model: Option<LanguageModelSelection>,
+ /// The model to use when applying edits from the assistant.
+ editor_model: Option<LanguageModelSelection>,
/// Additional models with which to generate alternatives when performing inline assists.
inline_alternatives: Option<Vec<LanguageModelSelection>>,
/// Enable experimental live diffs in the assistant panel.
@@ -470,6 +476,7 @@ impl Settings for AssistantSettings {
value.default_height.map(Into::into),
);
merge(&mut settings.default_model, value.default_model);
+ merge(&mut settings.editor_model, value.editor_model);
merge(&mut settings.inline_alternatives, value.inline_alternatives);
merge(
&mut settings.enable_experimental_live_diffs,
@@ -528,6 +535,10 @@ mod tests {
provider: "test-provider".into(),
model: "gpt-99".into(),
}),
+ editor_model: Some(LanguageModelSelection {
+ provider: "test-provider".into(),
+ model: "gpt-99".into(),
+ }),
inline_alternatives: None,
enabled: None,
button: None,
@@ -15,6 +15,7 @@ path = "src/assistant_tool.rs"
anyhow.workspace = true
collections.workspace = true
derive_more.workspace = true
+language_model.workspace = true
gpui.workspace = true
parking_lot.workspace = true
project.workspace = true
@@ -5,6 +5,7 @@ use std::sync::Arc;
use anyhow::Result;
use gpui::{App, Entity, SharedString, Task};
+use language_model::LanguageModelRequestMessage;
use project::Project;
pub use crate::tool_registry::*;
@@ -44,6 +45,7 @@ pub trait Tool: 'static + Send + Sync {
fn run(
self: Arc<Self>,
input: serde_json::Value,
+ messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
cx: &mut App,
) -> Task<Result<String>>;
@@ -15,8 +15,18 @@ path = "src/assistant_tools.rs"
anyhow.workspace = true
assistant_tool.workspace = true
chrono.workspace = true
+collections.workspace = true
+futures.workspace = true
gpui.workspace = true
+language_model.workspace = true
project.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
+util.workspace = true
+
+[dev-dependencies]
+rand.workspace = true
+collections = { workspace = true, features = ["test-support"] }
+gpui = { workspace = true, features = ["test-support"] }
+project = { workspace = true, features = ["test-support"] }
@@ -1,3 +1,4 @@
+mod edit_files_tool;
mod list_worktrees_tool;
mod now_tool;
mod read_file_tool;
@@ -5,6 +6,7 @@ mod read_file_tool;
use assistant_tool::ToolRegistry;
use gpui::App;
+use crate::edit_files_tool::EditFilesTool;
use crate::list_worktrees_tool::ListWorktreesTool;
use crate::now_tool::NowTool;
use crate::read_file_tool::ReadFileTool;
@@ -16,4 +18,5 @@ pub fn init(cx: &mut App) {
registry.register_tool(NowTool);
registry.register_tool(ListWorktreesTool);
registry.register_tool(ReadFileTool);
+ registry.register_tool(EditFilesTool);
}
@@ -0,0 +1,155 @@
+mod edit_action;
+
+use collections::HashSet;
+use std::{path::Path, sync::Arc};
+
+use anyhow::{anyhow, Result};
+use assistant_tool::Tool;
+use edit_action::{EditAction, EditActionParser};
+use futures::StreamExt;
+use gpui::{App, Entity, Task};
+use language_model::{
+ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+};
+use project::{Project, ProjectPath, WorktreeId};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+pub struct EditFilesToolInput {
+ /// The ID of the worktree in which the files reside.
+ pub worktree_id: usize,
+ /// Instruct how to modify the files.
+ pub edit_instructions: String,
+}
+
+pub struct EditFilesTool;
+
+impl Tool for EditFilesTool {
+ fn name(&self) -> String {
+ "edit-files".into()
+ }
+
+ fn description(&self) -> String {
+ include_str!("./edit_files_tool/description.md").into()
+ }
+
+ fn input_schema(&self) -> serde_json::Value {
+ let schema = schemars::schema_for!(EditFilesToolInput);
+ serde_json::to_value(&schema).unwrap()
+ }
+
+ fn run(
+ self: Arc<Self>,
+ input: serde_json::Value,
+ messages: &[LanguageModelRequestMessage],
+ project: Entity<Project>,
+ cx: &mut App,
+ ) -> Task<Result<String>> {
+ let input = match serde_json::from_value::<EditFilesToolInput>(input) {
+ Ok(input) => input,
+ Err(err) => return Task::ready(Err(anyhow!(err))),
+ };
+
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ let Some(model) = model_registry.editor_model() else {
+ return Task::ready(Err(anyhow!("No editor model configured")));
+ };
+
+ let mut messages = messages.to_vec();
+ if let Some(last_message) = messages.last_mut() {
+ // Strip out tool use from the last message because we're in the middle of executing a tool call.
+ last_message
+ .content
+ .retain(|content| !matches!(content, language_model::MessageContent::ToolUse(_)))
+ }
+ messages.push(LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![
+ include_str!("./edit_files_tool/edit_prompt.md").into(),
+ input.edit_instructions.into(),
+ ],
+ cache: false,
+ });
+
+ cx.spawn(|mut cx| async move {
+ let request = LanguageModelRequest {
+ messages,
+ tools: vec![],
+ stop: vec![],
+ temperature: None,
+ };
+
+ let mut parser = EditActionParser::new();
+
+ let stream = model.stream_completion_text(request, &cx);
+ let mut chunks = stream.await?;
+
+ let mut changed_buffers = HashSet::default();
+ let mut applied_edits = 0;
+
+ while let Some(chunk) = chunks.stream.next().await {
+ for action in parser.parse_chunk(&chunk?) {
+ let project_path = ProjectPath {
+ worktree_id: WorktreeId::from_usize(input.worktree_id),
+ path: Path::new(action.file_path()).into(),
+ };
+
+ let buffer = project
+ .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
+ .await?;
+
+ let diff = buffer
+ .read_with(&cx, |buffer, cx| {
+ let new_text = match action {
+ EditAction::Replace { old, new, .. } => {
+ // TODO: Replace in background?
+ buffer.text().replace(&old, &new)
+ }
+ EditAction::Write { content, .. } => content,
+ };
+
+ buffer.diff(new_text, cx)
+ })?
+ .await;
+
+ let _clock =
+ buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
+
+ changed_buffers.insert(buffer);
+
+ applied_edits += 1;
+ }
+ }
+
+ // Save each buffer once at the end
+ for buffer in changed_buffers {
+ project
+ .update(&mut cx, |project, cx| project.save_buffer(buffer, cx))?
+ .await?;
+ }
+
+ let errors = parser.errors();
+
+ if errors.is_empty() {
+ Ok("Successfully applied all edits".into())
+ } else {
+ let error_message = errors
+ .iter()
+ .map(|e| e.to_string())
+ .collect::<Vec<_>>()
+ .join("\n");
+
+ if applied_edits > 0 {
+ Err(anyhow!(
+ "Applied {} edit(s), but some blocks failed to parse:\n{}",
+ applied_edits,
+ error_message
+ ))
+ } else {
+ Err(anyhow!(error_message))
+ }
+ }
+ })
+ }
+}
@@ -0,0 +1,3 @@
+Edit files in a worktree by providing its id and a description of how to modify the code to complete the request.
+
+Make instructions unambiguous and complete. Explain all needed code changes clearly and completely, but concisely. Just show the changes needed. DO NOT show the entire updated function/file/etc!
@@ -0,0 +1,807 @@
+use util::ResultExt;
+
+/// Represents an edit action to be performed on a file.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum EditAction {
+ /// Replace specific content in a file with new content
+ Replace {
+ file_path: String,
+ old: String,
+ new: String,
+ },
+ /// Write content to a file (create or overwrite)
+ Write { file_path: String, content: String },
+}
+
+impl EditAction {
+ pub fn file_path(&self) -> &str {
+ match self {
+ EditAction::Replace { file_path, .. } => file_path,
+ EditAction::Write { file_path, .. } => file_path,
+ }
+ }
+}
+
+/// Parses edit actions from an LLM response.
+/// See system.md for more details on the format.
+#[derive(Debug)]
+pub struct EditActionParser {
+ state: State,
+ pre_fence_line: Vec<u8>,
+ marker_ix: usize,
+ line: usize,
+ column: usize,
+ old_bytes: Vec<u8>,
+ new_bytes: Vec<u8>,
+ errors: Vec<ParseError>,
+}
+
+#[derive(Debug, PartialEq, Eq)]
+enum State {
+ /// Anywhere outside an action
+ Default,
+ /// After opening ```, in optional language tag
+ OpenFence,
+ /// In SEARCH marker
+ SearchMarker,
+ /// In search block or divider
+ SearchBlock,
+ /// In replace block or REPLACE marker
+ ReplaceBlock,
+ /// In closing ```
+ CloseFence,
+}
+
+impl EditActionParser {
+ /// Creates a new `EditActionParser`
+ pub fn new() -> Self {
+ Self {
+ state: State::Default,
+ pre_fence_line: Vec::new(),
+ marker_ix: 0,
+ line: 1,
+ column: 0,
+ old_bytes: Vec::new(),
+ new_bytes: Vec::new(),
+ errors: Vec::new(),
+ }
+ }
+
+ /// Processes a chunk of input text and returns any completed edit actions.
+ ///
+ /// This method can be called repeatedly with fragments of input. The parser
+ /// maintains its state between calls, allowing you to process streaming input
+ /// as it becomes available. Actions are only inserted once they are fully parsed.
+ ///
+ /// If a block fails to parse, it will simply be skipped and an error will be recorded.
+ /// All errors can be accessed through the `EditActionsParser::errors` method.
+ pub fn parse_chunk(&mut self, input: &str) -> Vec<EditAction> {
+ use State::*;
+
+ const FENCE: &[u8] = b"\n```";
+ const SEARCH_MARKER: &[u8] = b"<<<<<<< SEARCH\n";
+ const DIVIDER: &[u8] = b"=======\n";
+ const NL_DIVIDER: &[u8] = b"\n=======\n";
+ const REPLACE_MARKER: &[u8] = b">>>>>>> REPLACE";
+ const NL_REPLACE_MARKER: &[u8] = b"\n>>>>>>> REPLACE";
+
+ let mut actions = Vec::new();
+
+ for byte in input.bytes() {
+ // Update line and column tracking
+ if byte == b'\n' {
+ self.line += 1;
+ self.column = 0;
+ } else {
+ self.column += 1;
+ }
+
+ match self.state {
+ Default => match match_marker(byte, FENCE, &mut self.marker_ix) {
+ MarkerMatch::Complete => {
+ self.to_state(OpenFence);
+ }
+ MarkerMatch::Partial => {}
+ MarkerMatch::None => {
+ if self.marker_ix > 0 {
+ self.marker_ix = 0;
+ self.pre_fence_line.clear();
+ }
+
+ if byte != b'\n' {
+ self.pre_fence_line.push(byte);
+ }
+ }
+ },
+ OpenFence => {
+ // skip language tag
+ if byte == b'\n' {
+ self.to_state(SearchMarker);
+ }
+ }
+ SearchMarker => {
+ if self.expect_marker(byte, SEARCH_MARKER) {
+ self.to_state(SearchBlock);
+ }
+ }
+ SearchBlock => {
+ if collect_until_marker(
+ byte,
+ DIVIDER,
+ NL_DIVIDER,
+ &mut self.marker_ix,
+ &mut self.old_bytes,
+ ) {
+ self.to_state(ReplaceBlock);
+ }
+ }
+ ReplaceBlock => {
+ if collect_until_marker(
+ byte,
+ REPLACE_MARKER,
+ NL_REPLACE_MARKER,
+ &mut self.marker_ix,
+ &mut self.new_bytes,
+ ) {
+ self.to_state(CloseFence);
+ }
+ }
+ CloseFence => {
+ if self.expect_marker(byte, FENCE) {
+ if let Some(action) = self.action() {
+ actions.push(action);
+ }
+ self.reset();
+ }
+ }
+ };
+ }
+
+ actions
+ }
+
+ /// Returns a reference to the errors encountered during parsing.
+ pub fn errors(&self) -> &[ParseError] {
+ &self.errors
+ }
+
+ fn action(&mut self) -> Option<EditAction> {
+ if self.old_bytes.is_empty() && self.new_bytes.is_empty() {
+ self.push_error(ParseErrorKind::NoOp);
+ return None;
+ }
+
+ let file_path = String::from_utf8(std::mem::take(&mut self.pre_fence_line)).log_err()?;
+ let content = String::from_utf8(std::mem::take(&mut self.new_bytes)).log_err()?;
+
+ if self.old_bytes.is_empty() {
+ Some(EditAction::Write { file_path, content })
+ } else {
+ let old = String::from_utf8(std::mem::take(&mut self.old_bytes)).log_err()?;
+
+ Some(EditAction::Replace {
+ file_path,
+ old,
+ new: content,
+ })
+ }
+ }
+
+ fn expect_marker(&mut self, byte: u8, marker: &'static [u8]) -> bool {
+ match match_marker(byte, marker, &mut self.marker_ix) {
+ MarkerMatch::Complete => true,
+ MarkerMatch::Partial => false,
+ MarkerMatch::None => {
+ self.push_error(ParseErrorKind::ExpectedMarker {
+ expected: marker,
+ found: byte,
+ });
+ self.reset();
+ false
+ }
+ }
+ }
+
+ fn to_state(&mut self, state: State) {
+ self.state = state;
+ self.marker_ix = 0;
+ }
+
+ fn reset(&mut self) {
+ self.pre_fence_line.clear();
+ self.old_bytes.clear();
+ self.new_bytes.clear();
+ self.to_state(State::Default);
+ }
+
+ fn push_error(&mut self, kind: ParseErrorKind) {
+ self.errors.push(ParseError {
+ line: self.line,
+ column: self.column,
+ kind,
+ });
+ }
+}
+
+#[derive(Debug)]
+enum MarkerMatch {
+ None,
+ Partial,
+ Complete,
+}
+
+fn match_marker(byte: u8, marker: &[u8], marker_ix: &mut usize) -> MarkerMatch {
+ if byte == marker[*marker_ix] {
+ *marker_ix += 1;
+
+ if *marker_ix >= marker.len() {
+ MarkerMatch::Complete
+ } else {
+ MarkerMatch::Partial
+ }
+ } else {
+ MarkerMatch::None
+ }
+}
+
+fn collect_until_marker(
+ byte: u8,
+ marker: &[u8],
+ nl_marker: &[u8],
+ marker_ix: &mut usize,
+ buf: &mut Vec<u8>,
+) -> bool {
+ let marker = if buf.is_empty() {
+ // do not require another newline if block is empty
+ marker
+ } else {
+ nl_marker
+ };
+
+ match match_marker(byte, marker, marker_ix) {
+ MarkerMatch::Complete => true,
+ MarkerMatch::Partial => false,
+ MarkerMatch::None => {
+ if *marker_ix > 0 {
+ buf.extend_from_slice(&marker[..*marker_ix]);
+ *marker_ix = 0;
+
+ // The beginning of marker might match current byte
+ match match_marker(byte, marker, marker_ix) {
+ MarkerMatch::Complete => return true,
+ MarkerMatch::Partial => return false,
+ MarkerMatch::None => { /* no match, keep collecting */ }
+ }
+ }
+
+ buf.push(byte);
+
+ false
+ }
+ }
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct ParseError {
+ line: usize,
+ column: usize,
+ kind: ParseErrorKind,
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub enum ParseErrorKind {
+ ExpectedMarker { expected: &'static [u8], found: u8 },
+ NoOp,
+}
+
+impl std::fmt::Display for ParseErrorKind {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ ParseErrorKind::ExpectedMarker { expected, found } => {
+ write!(
+ f,
+ "Expected marker {:?}, found {:?}",
+ String::from_utf8_lossy(expected),
+ *found as char
+ )
+ }
+ ParseErrorKind::NoOp => {
+ write!(f, "No search or replace")
+ }
+ }
+ }
+}
+
+impl std::fmt::Display for ParseError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "input:{}:{}: {}", self.line, self.column, self.kind)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use rand::prelude::*;
+
+ #[test]
+ fn test_simple_edit_action() {
+ let input = r#"src/main.rs
+```
+<<<<<<< SEARCH
+fn original() {}
+=======
+fn replacement() {}
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(input);
+
+ assert_eq!(actions.len(), 1);
+ assert_eq!(
+ actions[0],
+ EditAction::Replace {
+ file_path: "src/main.rs".to_string(),
+ old: "fn original() {}".to_string(),
+ new: "fn replacement() {}".to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_with_language_tag() {
+ let input = r#"src/main.rs
+```rust
+<<<<<<< SEARCH
+fn original() {}
+=======
+fn replacement() {}
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(input);
+
+ assert_eq!(actions.len(), 1);
+ assert_eq!(
+ actions[0],
+ EditAction::Replace {
+ file_path: "src/main.rs".to_string(),
+ old: "fn original() {}".to_string(),
+ new: "fn replacement() {}".to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_with_surrounding_text() {
+ let input = r#"Here's a modification I'd like to make to the file:
+
+src/main.rs
+```rust
+<<<<<<< SEARCH
+fn original() {}
+=======
+fn replacement() {}
+>>>>>>> REPLACE
+```
+
+This change makes the function better.
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(input);
+
+ assert_eq!(actions.len(), 1);
+ assert_eq!(
+ actions[0],
+ EditAction::Replace {
+ file_path: "src/main.rs".to_string(),
+ old: "fn original() {}".to_string(),
+ new: "fn replacement() {}".to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_multiple_edit_actions() {
+ let input = r#"First change:
+src/main.rs
+```
+<<<<<<< SEARCH
+fn original() {}
+=======
+fn replacement() {}
+>>>>>>> REPLACE
+```
+
+Second change:
+src/utils.rs
+```rust
+<<<<<<< SEARCH
+fn old_util() -> bool { false }
+=======
+fn new_util() -> bool { true }
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(input);
+
+ assert_eq!(actions.len(), 2);
+ assert_eq!(
+ actions[0],
+ EditAction::Replace {
+ file_path: "src/main.rs".to_string(),
+ old: "fn original() {}".to_string(),
+ new: "fn replacement() {}".to_string(),
+ }
+ );
+ assert_eq!(
+ actions[1],
+ EditAction::Replace {
+ file_path: "src/utils.rs".to_string(),
+ old: "fn old_util() -> bool { false }".to_string(),
+ new: "fn new_util() -> bool { true }".to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_multiline() {
+ let input = r#"src/main.rs
+```rust
+<<<<<<< SEARCH
+fn original() {
+ println!("This is the original function");
+ let x = 42;
+ if x > 0 {
+ println!("Positive number");
+ }
+}
+=======
+fn replacement() {
+ println!("This is the replacement function");
+ let x = 100;
+ if x > 50 {
+ println!("Large number");
+ } else {
+ println!("Small number");
+ }
+}
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(input);
+
+ assert_eq!(actions.len(), 1);
+ assert_eq!(
+ actions[0],
+ EditAction::Replace {
+ file_path: "src/main.rs".to_string(),
+ old: "fn original() {\n println!(\"This is the original function\");\n let x = 42;\n if x > 0 {\n println!(\"Positive number\");\n }\n}".to_string(),
+ new: "fn replacement() {\n println!(\"This is the replacement function\");\n let x = 100;\n if x > 50 {\n println!(\"Large number\");\n } else {\n println!(\"Small number\");\n }\n}".to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_write_action() {
+ let input = r#"Create a new main.rs file:
+
+src/main.rs
+```rust
+<<<<<<< SEARCH
+=======
+fn new_function() {
+ println!("This function is being added");
+}
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(input);
+
+ assert_eq!(actions.len(), 1);
+ assert_eq!(
+ actions[0],
+ EditAction::Write {
+ file_path: "src/main.rs".to_string(),
+ content: "fn new_function() {\n println!(\"This function is being added\");\n}"
+ .to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_empty_replace() {
+ let input = r#"src/main.rs
+```rust
+<<<<<<< SEARCH
+fn this_will_be_deleted() {
+ println!("Deleting this function");
+}
+=======
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(input);
+
+ assert_eq!(actions.len(), 1);
+ assert_eq!(
+ actions[0],
+ EditAction::Replace {
+ file_path: "src/main.rs".to_string(),
+ old: "fn this_will_be_deleted() {\n println!(\"Deleting this function\");\n}"
+ .to_string(),
+ new: "".to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_empty_both() {
+ let input = r#"src/main.rs
+```rust
+<<<<<<< SEARCH
+=======
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(input);
+
+ // Should not create an action when both sections are empty
+ assert_eq!(actions.len(), 0);
+
+ // Check that the NoOp error was added
+ assert_eq!(parser.errors().len(), 1);
+ match parser.errors()[0].kind {
+ ParseErrorKind::NoOp => {}
+ _ => panic!("Expected NoOp error"),
+ }
+ }
+
+ #[test]
+ fn test_resumability() {
+ let input_part1 = r#"src/main.rs
+```rust
+<<<<<<< SEARCH
+fn ori"#;
+
+ let input_part2 = r#"ginal() {}
+=======
+fn replacement() {}"#;
+
+ let input_part3 = r#"
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions1 = parser.parse_chunk(input_part1);
+ assert_eq!(actions1.len(), 0);
+
+ let actions2 = parser.parse_chunk(input_part2);
+ // No actions should be complete yet
+ assert_eq!(actions2.len(), 0);
+
+ let actions3 = parser.parse_chunk(input_part3);
+ // The third chunk should complete the action
+ assert_eq!(actions3.len(), 1);
+ assert_eq!(
+ actions3[0],
+ EditAction::Replace {
+ file_path: "src/main.rs".to_string(),
+ old: "fn original() {}".to_string(),
+ new: "fn replacement() {}".to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_parser_state_preservation() {
+ let mut parser = EditActionParser::new();
+ let actions1 = parser.parse_chunk("src/main.rs\n```rust\n<<<<<<< SEARCH\n");
+
+ // Check parser is in the correct state
+ assert_eq!(parser.state, State::SearchBlock);
+ assert_eq!(parser.pre_fence_line, b"src/main.rs");
+
+ // Continue parsing
+ let actions2 = parser.parse_chunk("original code\n=======\n");
+ assert_eq!(parser.state, State::ReplaceBlock);
+ assert_eq!(parser.old_bytes, b"original code");
+
+ let actions3 = parser.parse_chunk("replacement code\n>>>>>>> REPLACE\n```\n");
+
+ // After complete parsing, state should reset
+ assert_eq!(parser.state, State::Default);
+ assert!(parser.pre_fence_line.is_empty());
+ assert!(parser.old_bytes.is_empty());
+ assert!(parser.new_bytes.is_empty());
+
+ assert_eq!(actions1.len(), 0);
+ assert_eq!(actions2.len(), 0);
+ assert_eq!(actions3.len(), 1);
+ }
+
+ #[test]
+ fn test_invalid_search_marker() {
+ let input = r#"src/main.rs
+```rust
+<<<<<<< WRONG_MARKER
+fn original() {}
+=======
+fn replacement() {}
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(input);
+ assert_eq!(actions.len(), 0);
+
+ assert_eq!(parser.errors().len(), 1);
+ let error = &parser.errors()[0];
+
+ assert_eq!(error.line, 3);
+ assert_eq!(error.column, 9);
+ assert_eq!(
+ error.kind,
+ ParseErrorKind::ExpectedMarker {
+ expected: b"<<<<<<< SEARCH\n",
+ found: b'W'
+ }
+ );
+ }
+
+ #[test]
+ fn test_missing_closing_fence() {
+ let input = r#"src/main.rs
+```rust
+<<<<<<< SEARCH
+fn original() {}
+=======
+fn replacement() {}
+>>>>>>> REPLACE
+<!-- Missing closing fence -->
+
+src/utils.rs
+```rust
+<<<<<<< SEARCH
+fn utils_func() {}
+=======
+fn new_utils_func() {}
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(input);
+
+ // Only the second block should be parsed
+ assert_eq!(actions.len(), 1);
+ assert_eq!(
+ actions[0],
+ EditAction::Replace {
+ file_path: "src/utils.rs".to_string(),
+ old: "fn utils_func() {}".to_string(),
+ new: "fn new_utils_func() {}".to_string(),
+ }
+ );
+
+ // The parser should continue after an error
+ assert_eq!(parser.state, State::Default);
+ }
+
+ const SYSTEM_PROMPT: &str = include_str!("./edit_prompt.md");
+
+ #[test]
+ fn test_parse_examples_in_system_prompt() {
+ let mut parser = EditActionParser::new();
+ let actions = parser.parse_chunk(SYSTEM_PROMPT);
+ assert_examples_in_system_prompt(&actions, parser.errors());
+ }
+
+ #[gpui::test(iterations = 10)]
+ fn test_random_chunking_of_system_prompt(mut rng: StdRng) {
+ let mut parser = EditActionParser::new();
+ let mut remaining = SYSTEM_PROMPT;
+ let mut actions = Vec::with_capacity(5);
+
+ while !remaining.is_empty() {
+ let chunk_size = rng.gen_range(1..=std::cmp::min(remaining.len(), 100));
+
+ let (chunk, rest) = remaining.split_at(chunk_size);
+
+ actions.extend(parser.parse_chunk(chunk));
+ remaining = rest;
+ }
+
+ assert_examples_in_system_prompt(&actions, parser.errors());
+ }
+
+ fn assert_examples_in_system_prompt(actions: &[EditAction], errors: &[ParseError]) {
+ assert_eq!(actions.len(), 5);
+
+ assert_eq!(
+ actions[0],
+ EditAction::Replace {
+ file_path: "mathweb/flask/app.py".to_string(),
+ old: "from flask import Flask".to_string(),
+ new: "import math\nfrom flask import Flask".to_string(),
+ }
+ );
+
+ assert_eq!(
+ actions[1],
+ EditAction::Replace {
+ file_path: "mathweb/flask/app.py".to_string(),
+ old: "def factorial(n):\n \"compute factorial\"\n\n if n == 0:\n return 1\n else:\n return n * factorial(n-1)\n".to_string(),
+ new: "".to_string(),
+ }
+ );
+
+ assert_eq!(
+ actions[2],
+ EditAction::Replace {
+ file_path: "mathweb/flask/app.py".to_string(),
+ old: " return str(factorial(n))".to_string(),
+ new: " return str(math.factorial(n))".to_string(),
+ }
+ );
+
+ assert_eq!(
+ actions[3],
+ EditAction::Write {
+ file_path: "hello.py".to_string(),
+ content: "def hello():\n \"print a greeting\"\n\n print(\"hello\")"
+ .to_string(),
+ }
+ );
+
+ assert_eq!(
+ actions[4],
+ EditAction::Replace {
+ file_path: "main.py".to_string(),
+ old: "def hello():\n \"print a greeting\"\n\n print(\"hello\")".to_string(),
+ new: "from hello import hello".to_string(),
+ }
+ );
+
+ // Ensure we have no parsing errors
+ assert!(errors.is_empty(), "Parsing errors found: {:?}", errors);
+ }
+
+ #[test]
+ fn test_print_error() {
+ let input = r#"src/main.rs
+```rust
+<<<<<<< WRONG_MARKER
+fn original() {}
+=======
+fn replacement() {}
+>>>>>>> REPLACE
+```
+"#;
+
+ let mut parser = EditActionParser::new();
+ parser.parse_chunk(input);
+
+ assert_eq!(parser.errors().len(), 1);
+ let error = &parser.errors()[0];
+ let expected_error = r#"input:3:9: Expected marker "<<<<<<< SEARCH\n", found 'W'"#;
+
+ assert_eq!(format!("{}", error), expected_error);
+ }
+}
@@ -0,0 +1,134 @@
+Act as an expert software developer.
+Always use best practices when coding.
+Respect and use existing conventions, libraries, etc that are already present in the code base.
+
+Take requests for changes to the supplied code.
+Always reply to the user in the same language they are using.
+
+Once you understand the request, you MUST:
+
+1. Decide if you need to propose *SEARCH/REPLACE* edits to any files that haven't been added to the chat. You can create new files without asking!
+
+2. Think step-by-step and explain the needed changes in a few short sentences.
+
+3. Describe each change with a *SEARCH/REPLACE block* per the examples below.
+
+All changes to files must use this *SEARCH/REPLACE block* format.
+ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
+
+# Example conversations:
+
+## USER: Change get_factorial() to use math.factorial
+
+## ASSISTANT: To make this change we need to modify `mathweb/flask/app.py` to:
+
+1. Import the math package.
+2. Remove the existing factorial() function.
+3. Update get_factorial() to call math.factorial instead.
+
+Here are the *SEARCH/REPLACE* blocks:
+
+mathweb/flask/app.py
+```python
+<<<<<<< SEARCH
+from flask import Flask
+=======
+import math
+from flask import Flask
+>>>>>>> REPLACE
+```
+
+mathweb/flask/app.py
+```python
+<<<<<<< SEARCH
+def factorial(n):
+ "compute factorial"
+
+ if n == 0:
+ return 1
+ else:
+ return n * factorial(n-1)
+
+=======
+>>>>>>> REPLACE
+```
+
+mathweb/flask/app.py
+```python
+<<<<<<< SEARCH
+ return str(factorial(n))
+=======
+ return str(math.factorial(n))
+>>>>>>> REPLACE
+```
+
+
+## USER: Refactor hello() into its own file.
+
+## ASSISTANT: To make this change we need to modify `main.py` and make a new file `hello.py`:
+
+1. Make a new hello.py file with hello() in it.
+2. Remove hello() from main.py and replace it with an import.
+
+Here are the *SEARCH/REPLACE* blocks:
+
+hello.py
+```python
+<<<<<<< SEARCH
+=======
+def hello():
+ "print a greeting"
+
+ print("hello")
+>>>>>>> REPLACE
+```
+
+main.py
+```python
+<<<<<<< SEARCH
+def hello():
+ "print a greeting"
+
+ print("hello")
+=======
+from hello import hello
+>>>>>>> REPLACE
+```
+# *SEARCH/REPLACE block* Rules:
+
+Every *SEARCH/REPLACE block* must use this format:
+1. The *FULL* file path alone on a line, verbatim. No bold asterisks, no quotes around it, no escaping of characters, etc.
+2. The opening fence and code language, eg: ```python
+3. The start of search block: <<<<<<< SEARCH
+4. A contiguous chunk of lines to search for in the existing source code
+5. The dividing line: =======
+6. The lines to replace into the source code
+7. The end of the replace block: >>>>>>> REPLACE
+8. The closing fence: ```
+
+Use the *FULL* file path, as shown to you by the user.
+
+Every *SEARCH* section must *EXACTLY MATCH* the existing file content, character for character, including all comments, docstrings, etc.
+If the file contains code or other data wrapped/escaped in json/xml/quotes or other containers, you need to propose edits to the literal contents of the file, including the container markup.
+
+*SEARCH/REPLACE* blocks will *only* replace the first match occurrence.
+Including multiple unique *SEARCH/REPLACE* blocks if needed.
+Include enough lines in each SEARCH section to uniquely match each set of lines that need to change.
+
+Keep *SEARCH/REPLACE* blocks concise.
+Break large *SEARCH/REPLACE* blocks into a series of smaller blocks that each change a small portion of the file.
+Include just the changing lines, and a few surrounding lines if needed for uniqueness.
+Do not include long runs of unchanging lines in *SEARCH/REPLACE* blocks.
+
+Only create *SEARCH/REPLACE* blocks for files that the user has added to the chat!
+
+To move code within a file, use 2 *SEARCH/REPLACE* blocks: 1 to delete it from its current location, 1 to insert it in the new location.
+
+Pay attention to which filenames the user wants you to edit, especially if they are asking you to create a new file.
+
+If you want to put code in a new file, use a *SEARCH/REPLACE block* with:
+- A new file path, including dir name if needed
+- An empty `SEARCH` section
+- The new file's contents in the `REPLACE` section
+
+ONLY EVER RETURN CODE IN A *SEARCH/REPLACE BLOCK*!
@@ -3,6 +3,7 @@ use std::sync::Arc;
use anyhow::Result;
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
+use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -34,6 +35,7 @@ impl Tool for ListWorktreesTool {
fn run(
self: Arc<Self>,
_input: serde_json::Value,
+ _messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
cx: &mut App,
) -> Task<Result<String>> {
@@ -4,6 +4,7 @@ use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use chrono::{Local, Utc};
use gpui::{App, Entity, Task};
+use language_model::LanguageModelRequestMessage;
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -42,6 +43,7 @@ impl Tool for NowTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
+ _messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
_cx: &mut App,
) -> Task<Result<String>> {
@@ -4,6 +4,7 @@ use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use gpui::{App, Entity, Task};
+use language_model::LanguageModelRequestMessage;
use project::{Project, ProjectPath, WorktreeId};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -37,6 +38,7 @@ impl Tool for ReadFileTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
+ _messages: &[LanguageModelRequestMessage],
project: Entity<Project>,
cx: &mut App,
) -> Task<Result<String>> {
@@ -56,7 +58,16 @@ impl Tool for ReadFileTool {
})?
.await?;
- cx.update(|cx| buffer.read(cx).text())
+ buffer.read_with(&cx, |buffer, _cx| {
+ if buffer
+ .file()
+ .map_or(false, |file| file.disk_state().exists())
+ {
+ Ok(buffer.text())
+ } else {
+ Err(anyhow!("File does not exist"))
+ }
+ })?
})
}
}
@@ -21,6 +21,7 @@ context_server_settings.workspace = true
extension.workspace = true
futures.workspace = true
gpui.workspace = true
+language_model.workspace = true
log.workspace = true
parking_lot.workspace = true
postage.workspace = true
@@ -3,6 +3,7 @@ use std::sync::Arc;
use anyhow::{anyhow, bail, Result};
use assistant_tool::{Tool, ToolSource};
use gpui::{App, Entity, Task};
+use language_model::LanguageModelRequestMessage;
use project::Project;
use crate::manager::ContextServerManager;
@@ -58,6 +59,7 @@ impl Tool for ContextServerTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
+ _messages: &[LanguageModelRequestMessage],
_project: Entity<Project>,
cx: &mut App,
) -> Task<Result<String>> {
@@ -18,6 +18,7 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
active_model: Option<ActiveModel>,
+ editor_model: Option<ActiveModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
}
@@ -29,6 +30,7 @@ pub struct ActiveModel {
pub enum Event {
ActiveModelChanged,
+ EditorModelChanged,
ProviderStateChanged,
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
@@ -128,6 +130,22 @@ impl LanguageModelRegistry {
}
}
+ pub fn select_editor_model(
+ &mut self,
+ provider: &LanguageModelProviderId,
+ model_id: &LanguageModelId,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(provider) = self.provider(provider) else {
+ return;
+ };
+
+ let models = provider.provided_models(cx);
+ if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
+ self.set_editor_model(Some(model), cx);
+ }
+ }
+
pub fn set_active_provider(
&mut self,
provider: Option<Arc<dyn LanguageModelProvider>>,
@@ -162,6 +180,28 @@ impl LanguageModelRegistry {
}
}
+ pub fn set_editor_model(
+ &mut self,
+ model: Option<Arc<dyn LanguageModel>>,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(model) = model {
+ let provider_id = model.provider_id();
+ if let Some(provider) = self.providers.get(&provider_id).cloned() {
+ self.editor_model = Some(ActiveModel {
+ provider,
+ model: Some(model),
+ });
+ cx.emit(Event::EditorModelChanged);
+ } else {
+ log::warn!("Active model's provider not found in registry");
+ }
+ } else {
+ self.editor_model = None;
+ cx.emit(Event::EditorModelChanged);
+ }
+ }
+
pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
Some(self.active_model.as_ref()?.provider.clone())
}
@@ -170,6 +210,10 @@ impl LanguageModelRegistry {
self.active_model.as_ref()?.model.clone()
}
+ pub fn editor_model(&self) -> Option<Arc<dyn LanguageModel>> {
+ self.editor_model.as_ref()?.model.clone()
+ }
+
/// Selects and sets the inline alternatives for language models based on
/// provider name and id.
pub fn select_inline_alternative_models(