Detailed changes
@@ -31,6 +31,7 @@ dependencies = [
"ui",
"url",
"util",
+ "uuid",
"watch",
"workspace-hack",
]
@@ -6446,6 +6447,7 @@ dependencies = [
"log",
"parking_lot",
"pretty_assertions",
+ "rand 0.8.5",
"regex",
"rope",
"schemars",
@@ -36,6 +36,7 @@ terminal.workspace = true
ui.workspace = true
url.workspace = true
util.workspace = true
+uuid.workspace = true
watch.workspace = true
workspace-hack.workspace = true
@@ -9,18 +9,19 @@ pub use mention::*;
pub use terminal::*;
use action_log::ActionLog;
-use agent_client_protocol::{self as acp};
-use anyhow::{Context as _, Result};
+use agent_client_protocol as acp;
+use anyhow::{Context as _, Result, anyhow};
use editor::Bias;
use futures::{FutureExt, channel::oneshot, future::BoxFuture};
use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use itertools::Itertools;
use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
use markdown::Markdown;
-use project::{AgentLocation, Project};
+use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
use std::collections::HashMap;
use std::error::Error;
-use std::fmt::Formatter;
+use std::fmt::{Formatter, Write};
+use std::ops::Range;
use std::process::ExitStatus;
use std::rc::Rc;
use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
@@ -29,24 +30,23 @@ use util::ResultExt;
#[derive(Debug)]
pub struct UserMessage {
+ pub id: Option<UserMessageId>,
pub content: ContentBlock,
+ pub checkpoint: Option<GitStoreCheckpoint>,
}
impl UserMessage {
- pub fn from_acp(
- message: impl IntoIterator<Item = acp::ContentBlock>,
- language_registry: Arc<LanguageRegistry>,
- cx: &mut App,
- ) -> Self {
- let mut content = ContentBlock::Empty;
- for chunk in message {
- content.append(chunk, &language_registry, cx)
- }
- Self { content: content }
- }
-
fn to_markdown(&self, cx: &App) -> String {
- format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
+ let mut markdown = String::new();
+ if let Some(_) = self.checkpoint {
+ writeln!(markdown, "## User (checkpoint)").unwrap();
+ } else {
+ writeln!(markdown, "## User").unwrap();
+ }
+ writeln!(markdown).unwrap();
+ writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
+ writeln!(markdown).unwrap();
+ markdown
}
}
@@ -633,6 +633,7 @@ pub struct AcpThread {
pub enum AcpThreadEvent {
NewEntry,
EntryUpdated(usize),
+ EntriesRemoved(Range<usize>),
ToolAuthorizationRequired,
Stopped,
Error,
@@ -772,7 +773,7 @@ impl AcpThread {
) -> Result<()> {
match update {
acp::SessionUpdate::UserMessageChunk { content } => {
- self.push_user_content_block(content, cx);
+ self.push_user_content_block(None, content, cx);
}
acp::SessionUpdate::AgentMessageChunk { content } => {
self.push_assistant_content_block(content, false, cx);
@@ -793,18 +794,32 @@ impl AcpThread {
Ok(())
}
- pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
+ pub fn push_user_content_block(
+ &mut self,
+ message_id: Option<UserMessageId>,
+ chunk: acp::ContentBlock,
+ cx: &mut Context<Self>,
+ ) {
let language_registry = self.project.read(cx).languages().clone();
let entries_len = self.entries.len();
if let Some(last_entry) = self.entries.last_mut()
- && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
+ && let AgentThreadEntry::UserMessage(UserMessage { id, content, .. }) = last_entry
{
+ *id = message_id.or(id.take());
content.append(chunk, &language_registry, cx);
- cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
+ let idx = entries_len - 1;
+ cx.emit(AcpThreadEvent::EntryUpdated(idx));
} else {
let content = ContentBlock::new(chunk, &language_registry, cx);
- self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
+ self.push_entry(
+ AgentThreadEntry::UserMessage(UserMessage {
+ id: message_id,
+ content,
+ checkpoint: None,
+ }),
+ cx,
+ );
}
}
@@ -819,7 +834,8 @@ impl AcpThread {
if let Some(last_entry) = self.entries.last_mut()
&& let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
{
- cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
+ let idx = entries_len - 1;
+ cx.emit(AcpThreadEvent::EntryUpdated(idx));
match (chunks.last_mut(), is_thought) {
(Some(AssistantMessageChunk::Message { block }), false)
| (Some(AssistantMessageChunk::Thought { block }), true) => {
@@ -1118,69 +1134,113 @@ impl AcpThread {
self.project.read(cx).languages().clone(),
cx,
);
+ let git_store = self.project.read(cx).git_store().clone();
+
+ let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
+ let message_id = if self
+ .connection
+ .session_editor(&self.session_id, cx)
+ .is_some()
+ {
+ Some(UserMessageId::new())
+ } else {
+ None
+ };
self.push_entry(
- AgentThreadEntry::UserMessage(UserMessage { content: block }),
+ AgentThreadEntry::UserMessage(UserMessage {
+ id: message_id.clone(),
+ content: block,
+ checkpoint: None,
+ }),
cx,
);
self.clear_completed_plan_entries(cx);
+ let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
let (tx, rx) = oneshot::channel();
let cancel_task = self.cancel(cx);
+ let request = acp::PromptRequest {
+ prompt: message,
+ session_id: self.session_id.clone(),
+ };
- self.send_task = Some(cx.spawn(async move |this, cx| {
- async {
+ self.send_task = Some(cx.spawn({
+ let message_id = message_id.clone();
+ async move |this, cx| {
cancel_task.await;
- let result = this
- .update(cx, |this, cx| {
- this.connection.prompt(
- acp::PromptRequest {
- prompt: message,
- session_id: this.session_id.clone(),
- },
- cx,
- )
- })?
- .await;
-
- tx.send(result).log_err();
-
- anyhow::Ok(())
+ old_checkpoint_tx.send(old_checkpoint.await).ok();
+ if let Ok(result) = this.update(cx, |this, cx| {
+ this.connection.prompt(message_id, request, cx)
+ }) {
+ tx.send(result.await).log_err();
+ }
}
- .await
- .log_err();
}));
- cx.spawn(async move |this, cx| match rx.await {
- Ok(Err(e)) => {
- this.update(cx, |this, cx| {
- this.send_task.take();
- cx.emit(AcpThreadEvent::Error)
- })
+ cx.spawn(async move |this, cx| {
+ let old_checkpoint = old_checkpoint_rx
+ .await
+ .map_err(|_| anyhow!("send canceled"))
+ .flatten()
+ .context("failed to get old checkpoint")
.log_err();
- Err(e)?
- }
- result => {
- let cancelled = matches!(
- result,
- Ok(Ok(acp::PromptResponse {
- stop_reason: acp::StopReason::Cancelled
- }))
- );
- // We only take the task if the current prompt wasn't cancelled.
- //
- // This prompt may have been cancelled because another one was sent
- // while it was still generating. In these cases, dropping `send_task`
- // would cause the next generation to be cancelled.
- if !cancelled {
- this.update(cx, |this, _cx| this.send_task.take()).ok();
- }
+ let response = rx.await;
- this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
+ if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
+ let new_checkpoint = git_store
+ .update(cx, |git, cx| git.checkpoint(cx))?
+ .await
+ .context("failed to get new checkpoint")
.log_err();
- Ok(())
+ if let Some(new_checkpoint) = new_checkpoint {
+ let equal = git_store
+ .update(cx, |git, cx| {
+ git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
+ })?
+ .await
+ .unwrap_or(true);
+ if !equal {
+ this.update(cx, |this, cx| {
+ if let Some((ix, message)) = this.user_message_mut(&message_id) {
+ message.checkpoint = Some(old_checkpoint);
+ cx.emit(AcpThreadEvent::EntryUpdated(ix));
+ }
+ })?;
+ }
+ }
}
+
+ this.update(cx, |this, cx| {
+ match response {
+ Ok(Err(e)) => {
+ this.send_task.take();
+ cx.emit(AcpThreadEvent::Error);
+ Err(e)
+ }
+ result => {
+ let cancelled = matches!(
+ result,
+ Ok(Ok(acp::PromptResponse {
+ stop_reason: acp::StopReason::Cancelled
+ }))
+ );
+
+ // We only take the task if the current prompt wasn't cancelled.
+ //
+ // This prompt may have been cancelled because another one was sent
+ // while it was still generating. In these cases, dropping `send_task`
+ // would cause the next generation to be cancelled.
+ if !cancelled {
+ this.send_task.take();
+ }
+
+ cx.emit(AcpThreadEvent::Stopped);
+ Ok(())
+ }
+ }
+ })?
})
.boxed()
}
@@ -1212,6 +1272,66 @@ impl AcpThread {
cx.foreground_executor().spawn(send_task)
}
+ /// Rewinds this thread to before the entry at `index`, removing it and all
+ /// subsequent entries while reverting any changes made from that point.
+ pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
+ return Task::ready(Err(anyhow!("not supported")));
+ };
+ let Some(message) = self.user_message(&id) else {
+ return Task::ready(Err(anyhow!("message not found")));
+ };
+
+ let checkpoint = message.checkpoint.clone();
+
+ let git_store = self.project.read(cx).git_store().clone();
+ cx.spawn(async move |this, cx| {
+ if let Some(checkpoint) = checkpoint {
+ git_store
+ .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
+ .await?;
+ }
+
+ cx.update(|cx| session_editor.truncate(id.clone(), cx))?
+ .await?;
+ this.update(cx, |this, cx| {
+ if let Some((ix, _)) = this.user_message_mut(&id) {
+ let range = ix..this.entries.len();
+ this.entries.truncate(ix);
+ cx.emit(AcpThreadEvent::EntriesRemoved(range));
+ }
+ })
+ })
+ }
+
+ fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
+ self.entries.iter().find_map(|entry| {
+ if let AgentThreadEntry::UserMessage(message) = entry {
+ if message.id.as_ref() == Some(&id) {
+ Some(message)
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ })
+ }
+
+ fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
+ self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
+ if let AgentThreadEntry::UserMessage(message) = entry {
+ if message.id.as_ref() == Some(&id) {
+ Some((ix, message))
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ })
+ }
+
pub fn read_text_file(
&self,
path: PathBuf,
@@ -1414,13 +1534,18 @@ mod tests {
use futures::{channel::mpsc, future::LocalBoxFuture, select};
use gpui::{AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc;
- use project::FakeFs;
+ use project::{FakeFs, Fs};
use rand::Rng as _;
use serde_json::json;
use settings::SettingsStore;
use smol::stream::StreamExt as _;
- use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
-
+ use std::{
+ cell::RefCell,
+ path::Path,
+ rc::Rc,
+ sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
+ time::Duration,
+ };
use util::path;
fn init_test(cx: &mut TestAppContext) {
@@ -1452,6 +1577,7 @@ mod tests {
// Test creating a new user message
thread.update(cx, |thread, cx| {
thread.push_user_content_block(
+ None,
acp::ContentBlock::Text(acp::TextContent {
annotations: None,
text: "Hello, ".to_string(),
@@ -1463,6 +1589,7 @@ mod tests {
thread.update(cx, |thread, cx| {
assert_eq!(thread.entries.len(), 1);
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
+ assert_eq!(user_msg.id, None);
assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
} else {
panic!("Expected UserMessage");
@@ -1470,8 +1597,10 @@ mod tests {
});
// Test appending to existing user message
+ let message_1_id = UserMessageId::new();
thread.update(cx, |thread, cx| {
thread.push_user_content_block(
+ Some(message_1_id.clone()),
acp::ContentBlock::Text(acp::TextContent {
annotations: None,
text: "world!".to_string(),
@@ -1483,6 +1612,7 @@ mod tests {
thread.update(cx, |thread, cx| {
assert_eq!(thread.entries.len(), 1);
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
+ assert_eq!(user_msg.id, Some(message_1_id));
assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
} else {
panic!("Expected UserMessage");
@@ -1501,8 +1631,10 @@ mod tests {
);
});
+ let message_2_id = UserMessageId::new();
thread.update(cx, |thread, cx| {
thread.push_user_content_block(
+ Some(message_2_id.clone()),
acp::ContentBlock::Text(acp::TextContent {
annotations: None,
text: "New user message".to_string(),
@@ -1514,6 +1646,7 @@ mod tests {
thread.update(cx, |thread, cx| {
assert_eq!(thread.entries.len(), 3);
if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
+ assert_eq!(user_msg.id, Some(message_2_id));
assert_eq!(user_msg.content.to_markdown(cx), "New user message");
} else {
panic!("Expected UserMessage at index 2");
@@ -1830,6 +1963,180 @@ mod tests {
assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
}
+ #[gpui::test(iterations = 10)]
+ async fn test_checkpoints(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(
+ path!("/test"),
+ json!({
+ ".git": {}
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
+
+ let simulate_changes = Arc::new(AtomicBool::new(true));
+ let next_filename = Arc::new(AtomicUsize::new(0));
+ let connection = Rc::new(FakeAgentConnection::new().on_user_message({
+ let simulate_changes = simulate_changes.clone();
+ let next_filename = next_filename.clone();
+ let fs = fs.clone();
+ move |request, thread, mut cx| {
+ let fs = fs.clone();
+ let simulate_changes = simulate_changes.clone();
+ let next_filename = next_filename.clone();
+ async move {
+ if simulate_changes.load(SeqCst) {
+ let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
+ fs.write(Path::new(&filename), b"").await?;
+ }
+
+ let acp::ContentBlock::Text(content) = &request.prompt[0] else {
+ panic!("expected text content block");
+ };
+ thread.update(&mut cx, |thread, cx| {
+ thread
+ .handle_session_update(
+ acp::SessionUpdate::AgentMessageChunk {
+ content: content.text.to_uppercase().into(),
+ },
+ cx,
+ )
+ .unwrap();
+ })?;
+ Ok(acp::PromptResponse {
+ stop_reason: acp::StopReason::EndTurn,
+ })
+ }
+ .boxed_local()
+ }
+ }));
+ let thread = connection
+ .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
+ .await
+ .unwrap();
+
+ cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
+ .await
+ .unwrap();
+ thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User (checkpoint)
+
+ Lorem
+
+ ## Assistant
+
+ LOREM
+
+ "}
+ );
+ });
+ assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
+
+ cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
+ .await
+ .unwrap();
+ thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User (checkpoint)
+
+ Lorem
+
+ ## Assistant
+
+ LOREM
+
+ ## User (checkpoint)
+
+ ipsum
+
+ ## Assistant
+
+ IPSUM
+
+ "}
+ );
+ });
+ assert_eq!(
+ fs.files(),
+ vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
+ );
+
+ // Checkpoint isn't stored when there are no changes.
+ simulate_changes.store(false, SeqCst);
+ cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
+ .await
+ .unwrap();
+ thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User (checkpoint)
+
+ Lorem
+
+ ## Assistant
+
+ LOREM
+
+ ## User (checkpoint)
+
+ ipsum
+
+ ## Assistant
+
+ IPSUM
+
+ ## User
+
+ dolor
+
+ ## Assistant
+
+ DOLOR
+
+ "}
+ );
+ });
+ assert_eq!(
+ fs.files(),
+ vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
+ );
+
+ // Rewinding the conversation truncates the history and restores the checkpoint.
+ thread
+ .update(cx, |thread, cx| {
+ let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
+ panic!("unexpected entries {:?}", thread.entries)
+ };
+ thread.rewind(message.id.clone().unwrap(), cx)
+ })
+ .await
+ .unwrap();
+ thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User (checkpoint)
+
+ Lorem
+
+ ## Assistant
+
+ LOREM
+
+ "}
+ );
+ });
+ assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
+ }
+
async fn run_until_first_tool_call(
thread: &Entity<AcpThread>,
cx: &mut TestAppContext,
@@ -1938,6 +2245,7 @@ mod tests {
fn prompt(
&self,
+ _id: Option<UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
@@ -1966,5 +2274,25 @@ mod tests {
})
.detach();
}
+
+ fn session_editor(
+ &self,
+ session_id: &acp::SessionId,
+ _cx: &mut App,
+ ) -> Option<Rc<dyn AgentSessionEditor>> {
+ Some(Rc::new(FakeAgentSessionEditor {
+ _session_id: session_id.clone(),
+ }))
+ }
+ }
+
+ struct FakeAgentSessionEditor {
+ _session_id: acp::SessionId,
+ }
+
+ impl AgentSessionEditor for FakeAgentSessionEditor {
+ fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
+ Task::ready(Ok(()))
+ }
}
}
@@ -1,13 +1,21 @@
-use std::{error::Error, fmt, path::Path, rc::Rc};
-
+use crate::AcpThread;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use collections::IndexMap;
use gpui::{AsyncApp, Entity, SharedString, Task};
use project::Project;
+use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
+use uuid::Uuid;
-use crate::AcpThread;
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct UserMessageId(Arc<str>);
+
+impl UserMessageId {
+ pub fn new() -> Self {
+ Self(Uuid::new_v4().to_string().into())
+ }
+}
pub trait AgentConnection {
fn new_thread(
@@ -21,11 +29,23 @@ pub trait AgentConnection {
fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
- fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
- -> Task<Result<acp::PromptResponse>>;
+ fn prompt(
+ &self,
+ user_message_id: Option<UserMessageId>,
+ params: acp::PromptRequest,
+ cx: &mut App,
+ ) -> Task<Result<acp::PromptResponse>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
+ fn session_editor(
+ &self,
+ _session_id: &acp::SessionId,
+ _cx: &mut App,
+ ) -> Option<Rc<dyn AgentSessionEditor>> {
+ None
+ }
+
/// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
///
/// If the agent does not support model selection, returns [None].
@@ -35,6 +55,10 @@ pub trait AgentConnection {
}
}
+pub trait AgentSessionEditor {
+ fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
+}
+
#[derive(Debug)]
pub struct AuthRequired;
@@ -1,8 +1,9 @@
use crate::{AgentResponseEvent, Thread, templates::Templates};
use crate::{
ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
- FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
- OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
+ FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
+ ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
+ WebSearchTool,
};
use acp_thread::AgentModelSelector;
use agent_client_protocol as acp;
@@ -637,9 +638,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
fn prompt(
&self,
+ id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
+ let id = id.expect("UserMessageId is required");
let session_id = params.session_id.clone();
let agent = self.0.clone();
log::info!("Received prompt request for session: {}", session_id);
@@ -660,13 +663,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
})?;
log::debug!("Found session for: {}", session_id);
- let message: Vec<MessageContent> = params
+ let content: Vec<UserMessageContent> = params
.prompt
.into_iter()
.map(Into::into)
.collect::<Vec<_>>();
- log::info!("Converted prompt to message: {} chars", message.len());
- log::debug!("Message content: {:?}", message);
+ log::info!("Converted prompt to message: {} chars", content.len());
+ log::debug!("Message id: {:?}", id);
+ log::debug!("Message content: {:?}", content);
// Get model using the ModelSelector capability (always available for agent2)
// Get the selected model from the thread directly
@@ -674,7 +678,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
// Send to thread
log::info!("Sending message to thread with model: {:?}", model.name());
- let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
+ let mut response_stream =
+ thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
// Handle response stream and forward to session.acp_thread
while let Some(result) = response_stream.next().await {
@@ -768,6 +773,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
}
});
}
+
+ fn session_editor(
+ &self,
+ session_id: &agent_client_protocol::SessionId,
+ cx: &mut App,
+ ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
+ self.0.update(cx, |agent, _cx| {
+ agent
+ .sessions
+ .get(session_id)
+ .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
+ })
+ }
+}
+
+struct NativeAgentSessionEditor(Entity<Thread>);
+
+impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
+ fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
+ Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
+ }
}
#[cfg(test)]
@@ -1,6 +1,5 @@
use super::*;
-use crate::MessageContent;
-use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList};
+use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
@@ -38,15 +37,19 @@ async fn test_echo(cx: &mut TestAppContext) {
let events = thread
.update(cx, |thread, cx| {
- thread.send("Testing: Reply with 'Hello'", cx)
+ thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
})
.collect()
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
- thread.messages().last().unwrap().content,
- vec![MessageContent::Text("Hello".to_string())]
- );
+ thread.last_message().unwrap().to_markdown(),
+ indoc! {"
+ ## Assistant
+
+ Hello
+ "}
+ )
});
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
@@ -59,12 +62,13 @@ async fn test_thinking(cx: &mut TestAppContext) {
let events = thread
.update(cx, |thread, cx| {
thread.send(
- indoc! {"
+ UserMessageId::new(),
+ [indoc! {"
Testing:
Generate a thinking step where you just think the word 'Think',
and have your final answer be 'Hello'
- "},
+ "}],
cx,
)
})
@@ -72,9 +76,10 @@ async fn test_thinking(cx: &mut TestAppContext) {
.await;
thread.update(cx, |thread, _cx| {
assert_eq!(
- thread.messages().last().unwrap().to_markdown(),
+ thread.last_message().unwrap().to_markdown(),
indoc! {"
- ## assistant
+ ## Assistant
+
<think>Think</think>
Hello
"}
@@ -95,7 +100,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
project_context.borrow_mut().shell = "test-shell".into();
thread.update(cx, |thread, _| thread.add_tool(EchoTool));
- thread.update(cx, |thread, cx| thread.send("abc", cx));
+ thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["abc"], cx)
+ });
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
assert_eq!(
@@ -132,7 +139,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
.update(cx, |thread, cx| {
thread.add_tool(EchoTool);
thread.send(
- "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
+ UserMessageId::new(),
+ ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
cx,
)
})
@@ -146,7 +154,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
thread.remove_tool(&AgentTool::name(&EchoTool));
thread.add_tool(DelayTool);
thread.send(
- "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
+ UserMessageId::new(),
+ [
+ "Now call the delay tool with 200ms.",
+ "When the timer goes off, then you echo the output of the tool.",
+ ],
cx,
)
})
@@ -156,13 +168,14 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
thread.update(cx, |thread, _cx| {
assert!(
thread
- .messages()
- .last()
+ .last_message()
+ .unwrap()
+ .as_agent_message()
.unwrap()
.content
.iter()
.any(|content| {
- if let MessageContent::Text(text) = content {
+ if let AgentMessageContent::Text(text) = content {
text.contains("Ding")
} else {
false
@@ -182,7 +195,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
// Test a tool call that's likely to complete *before* streaming stops.
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(WordListTool);
- thread.send("Test the word_list tool.", cx)
+ thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
});
let mut saw_partial_tool_use = false;
@@ -190,8 +203,10 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
thread.update(cx, |thread, _cx| {
// Look for a tool use in the thread's last message
- let last_content = thread.messages().last().unwrap().content.last().unwrap();
- if let MessageContent::ToolUse(last_tool_use) = last_content {
+ let message = thread.last_message().unwrap();
+ let agent_message = message.as_agent_message().unwrap();
+ let last_content = agent_message.content.last().unwrap();
+ if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
assert_eq!(last_tool_use.name.as_ref(), "word_list");
if tool_call.status == acp::ToolCallStatus::Pending {
if !last_tool_use.is_input_complete
@@ -229,7 +244,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
let mut events = thread.update(cx, |thread, cx| {
thread.add_tool(ToolRequiringPermission);
- thread.send("abc", cx)
+ thread.send(UserMessageId::new(), ["abc"], cx)
});
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
@@ -357,7 +372,9 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
- let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
+ let mut events = thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["abc"], cx)
+ });
cx.run_until_parked();
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
LanguageModelToolUse {
@@ -449,7 +466,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
.update(cx, |thread, cx| {
thread.add_tool(DelayTool);
thread.send(
- "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
+ UserMessageId::new(),
+ [
+ "Call the delay tool twice in the same message.",
+ "Once with 100ms. Once with 300ms.",
+ "When both timers are complete, describe the outputs.",
+ ],
cx,
)
})
@@ -460,12 +482,13 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
thread.update(cx, |thread, _cx| {
- let last_message = thread.messages().last().unwrap();
- let text = last_message
+ let last_message = thread.last_message().unwrap();
+ let agent_message = last_message.as_agent_message().unwrap();
+ let text = agent_message
.content
.iter()
.filter_map(|content| {
- if let MessageContent::Text(text) = content {
+ if let AgentMessageContent::Text(text) = content {
Some(text.as_str())
} else {
None
@@ -521,7 +544,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
// Test that test-1 profile (default) has echo and delay tools
thread.update(cx, |thread, cx| {
thread.set_profile(AgentProfileId("test-1".into()));
- thread.send("test", cx);
+ thread.send(UserMessageId::new(), ["test"], cx);
});
cx.run_until_parked();
@@ -539,7 +562,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
// Switch to test-2 profile, and verify that it has only the infinite tool.
thread.update(cx, |thread, cx| {
thread.set_profile(AgentProfileId("test-2".into()));
- thread.send("test2", cx)
+ thread.send(UserMessageId::new(), ["test2"], cx)
});
cx.run_until_parked();
let mut pending_completions = fake_model.pending_completions();
@@ -562,7 +585,8 @@ async fn test_cancellation(cx: &mut TestAppContext) {
thread.add_tool(InfiniteTool);
thread.add_tool(EchoTool);
thread.send(
- "Call the echo tool and then call the infinite tool, then explain their output",
+ UserMessageId::new(),
+ ["Call the echo tool, then call the infinite tool, then explain their output"],
cx,
)
});
@@ -607,14 +631,20 @@ async fn test_cancellation(cx: &mut TestAppContext) {
// Ensure we can still send a new message after cancellation.
let events = thread
.update(cx, |thread, cx| {
- thread.send("Testing: reply with 'Hello' then stop.", cx)
+ thread.send(
+ UserMessageId::new(),
+ ["Testing: reply with 'Hello' then stop."],
+ cx,
+ )
})
.collect::<Vec<_>>()
.await;
thread.update(cx, |thread, _cx| {
+ let message = thread.last_message().unwrap();
+ let agent_message = message.as_agent_message().unwrap();
assert_eq!(
- thread.messages().last().unwrap().content,
- vec![MessageContent::Text("Hello".to_string())]
+ agent_message.content,
+ vec![AgentMessageContent::Text("Hello".to_string())]
);
});
assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
@@ -625,13 +655,16 @@ async fn test_refusal(cx: &mut TestAppContext) {
let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
let fake_model = model.as_fake();
- let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
+ let events = thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hello"], cx)
+ });
cx.run_until_parked();
thread.read_with(cx, |thread, _| {
assert_eq!(
thread.to_markdown(),
indoc! {"
- ## user
+ ## User
+
Hello
"}
);
@@ -643,9 +676,12 @@ async fn test_refusal(cx: &mut TestAppContext) {
assert_eq!(
thread.to_markdown(),
indoc! {"
- ## user
+ ## User
+
Hello
- ## assistant
+
+ ## Assistant
+
Hey!
"}
);
@@ -661,6 +697,85 @@ async fn test_refusal(cx: &mut TestAppContext) {
});
}
+#[gpui::test]
+async fn test_truncate(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let message_id = UserMessageId::new();
+ thread.update(cx, |thread, cx| {
+ thread.send(message_id.clone(), ["Hello"], cx)
+ });
+ cx.run_until_parked();
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.to_markdown(),
+ indoc! {"
+ ## User
+
+ Hello
+ "}
+ );
+ });
+
+ fake_model.send_last_completion_stream_text_chunk("Hey!");
+ cx.run_until_parked();
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.to_markdown(),
+ indoc! {"
+ ## User
+
+ Hello
+
+ ## Assistant
+
+ Hey!
+ "}
+ );
+ });
+
+ thread
+ .update(cx, |thread, _cx| thread.truncate(message_id))
+ .unwrap();
+ cx.run_until_parked();
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.to_markdown(), "");
+ });
+
+ // Ensure we can still send a new message after truncation.
+ thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Hi"], cx)
+ });
+ thread.update(cx, |thread, _cx| {
+ assert_eq!(
+ thread.to_markdown(),
+ indoc! {"
+ ## User
+
+ Hi
+ "}
+ );
+ });
+ cx.run_until_parked();
+ fake_model.send_last_completion_stream_text_chunk("Ahoy!");
+ cx.run_until_parked();
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(
+ thread.to_markdown(),
+ indoc! {"
+ ## User
+
+ Hi
+
+ ## Assistant
+
+ Ahoy!
+ "}
+ );
+ });
+}
+
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
cx.update(settings::init);
@@ -774,6 +889,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
let result = cx
.update(|cx| {
connection.prompt(
+ Some(acp_thread::UserMessageId::new()),
acp::PromptRequest {
session_id: session_id.clone(),
prompt: vec!["ghi".into()],
@@ -796,7 +912,9 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
let fake_model = model.as_fake();
- let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
+ let mut events = thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), ["Think"], cx)
+ });
cx.run_until_parked();
// Simulate streaming partial input.
@@ -1,12 +1,12 @@
use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
-use acp_thread::MentionUri;
+use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol as acp;
use agent_settings::{AgentProfileId, AgentSettings};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use cloud_llm_client::{CompletionIntent, CompletionMode};
-use collections::HashMap;
+use collections::IndexMap;
use fs::Fs;
use futures::{
channel::{mpsc, oneshot},
@@ -19,7 +19,6 @@ use language_model::{
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
};
-use log;
use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
@@ -30,49 +29,199 @@ use std::fmt::Write;
use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
use util::{ResultExt, markdown::MarkdownCodeBlock};
-#[derive(Debug, Clone)]
-pub struct AgentMessage {
- pub role: Role,
- pub content: Vec<MessageContent>,
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum Message {
+ User(UserMessage),
+ Agent(AgentMessage),
+}
+
+impl Message {
+ pub fn as_agent_message(&self) -> Option<&AgentMessage> {
+ match self {
+ Message::Agent(agent_message) => Some(agent_message),
+ _ => None,
+ }
+ }
+
+ pub fn to_markdown(&self) -> String {
+ match self {
+ Message::User(message) => message.to_markdown(),
+ Message::Agent(message) => message.to_markdown(),
+ }
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct UserMessage {
+ pub id: UserMessageId,
+ pub content: Vec<UserMessageContent>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum MessageContent {
+pub enum UserMessageContent {
Text(String),
- Thinking {
- text: String,
- signature: Option<String>,
- },
- Mention {
- uri: MentionUri,
- content: String,
- },
- RedactedThinking(String),
+ Mention { uri: MentionUri, content: String },
Image(LanguageModelImage),
- ToolUse(LanguageModelToolUse),
- ToolResult(LanguageModelToolResult),
+}
+
+impl UserMessage {
+ pub fn to_markdown(&self) -> String {
+ let mut markdown = String::from("## User\n\n");
+
+ for content in &self.content {
+ match content {
+ UserMessageContent::Text(text) => {
+ markdown.push_str(text);
+ markdown.push('\n');
+ }
+ UserMessageContent::Image(_) => {
+ markdown.push_str("<image />\n");
+ }
+ UserMessageContent::Mention { uri, content } => {
+ if !content.is_empty() {
+ markdown.push_str(&format!("{}\n\n{}\n", uri.to_link(), content));
+ } else {
+ markdown.push_str(&format!("{}\n", uri.to_link()));
+ }
+ }
+ }
+ }
+
+ markdown
+ }
+
+ fn to_request(&self) -> LanguageModelRequestMessage {
+ let mut message = LanguageModelRequestMessage {
+ role: Role::User,
+ content: Vec::with_capacity(self.content.len()),
+ cache: false,
+ };
+
+ const OPEN_CONTEXT: &str = "<context>\n\
+ The following items were attached by the user. \
+ They are up-to-date and don't need to be re-read.\n\n";
+
+ const OPEN_FILES_TAG: &str = "<files>";
+ const OPEN_SYMBOLS_TAG: &str = "<symbols>";
+ const OPEN_THREADS_TAG: &str = "<threads>";
+ const OPEN_RULES_TAG: &str =
+ "<rules>\nThe user has specified the following rules that should be applied:\n";
+
+ let mut file_context = OPEN_FILES_TAG.to_string();
+ let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
+ let mut thread_context = OPEN_THREADS_TAG.to_string();
+ let mut rules_context = OPEN_RULES_TAG.to_string();
+
+ for chunk in &self.content {
+ let chunk = match chunk {
+ UserMessageContent::Text(text) => {
+ language_model::MessageContent::Text(text.clone())
+ }
+ UserMessageContent::Image(value) => {
+ language_model::MessageContent::Image(value.clone())
+ }
+ UserMessageContent::Mention { uri, content } => {
+ match uri {
+ MentionUri::File(path) | MentionUri::Symbol(path, _) => {
+ write!(
+ &mut symbol_context,
+ "\n{}",
+ MarkdownCodeBlock {
+ tag: &codeblock_tag(&path),
+ text: &content.to_string(),
+ }
+ )
+ .ok();
+ }
+ MentionUri::Thread(_session_id) => {
+ write!(&mut thread_context, "\n{}\n", content).ok();
+ }
+ MentionUri::Rule(_user_prompt_id) => {
+ write!(
+ &mut rules_context,
+ "\n{}",
+ MarkdownCodeBlock {
+ tag: "",
+ text: &content
+ }
+ )
+ .ok();
+ }
+ }
+
+ language_model::MessageContent::Text(uri.to_link())
+ }
+ };
+
+ message.content.push(chunk);
+ }
+
+ let len_before_context = message.content.len();
+
+ if file_context.len() > OPEN_FILES_TAG.len() {
+ file_context.push_str("</files>\n");
+ message
+ .content
+ .push(language_model::MessageContent::Text(file_context));
+ }
+
+ if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
+ symbol_context.push_str("</symbols>\n");
+ message
+ .content
+ .push(language_model::MessageContent::Text(symbol_context));
+ }
+
+ if thread_context.len() > OPEN_THREADS_TAG.len() {
+ thread_context.push_str("</threads>\n");
+ message
+ .content
+ .push(language_model::MessageContent::Text(thread_context));
+ }
+
+ if rules_context.len() > OPEN_RULES_TAG.len() {
+ rules_context.push_str("</user_rules>\n");
+ message
+ .content
+ .push(language_model::MessageContent::Text(rules_context));
+ }
+
+ if message.content.len() > len_before_context {
+ message.content.insert(
+ len_before_context,
+ language_model::MessageContent::Text(OPEN_CONTEXT.into()),
+ );
+ message
+ .content
+ .push(language_model::MessageContent::Text("</context>".into()));
+ }
+
+ message
+ }
}
impl AgentMessage {
pub fn to_markdown(&self) -> String {
- let mut markdown = format!("## {}\n", self.role);
+ let mut markdown = String::from("## Assistant\n\n");
for content in &self.content {
match content {
- MessageContent::Text(text) => {
+ AgentMessageContent::Text(text) => {
markdown.push_str(text);
markdown.push('\n');
}
- MessageContent::Thinking { text, .. } => {
+ AgentMessageContent::Thinking { text, .. } => {
markdown.push_str("<think>");
markdown.push_str(text);
markdown.push_str("</think>\n");
}
- MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
- MessageContent::Image(_) => {
+ AgentMessageContent::RedactedThinking(_) => {
+ markdown.push_str("<redacted_thinking />\n")
+ }
+ AgentMessageContent::Image(_) => {
markdown.push_str("<image />\n");
}
- MessageContent::ToolUse(tool_use) => {
+ AgentMessageContent::ToolUse(tool_use) => {
markdown.push_str(&format!(
"**Tool Use**: {} (ID: {})\n",
tool_use.name, tool_use.id
@@ -85,41 +234,106 @@ impl AgentMessage {
}
));
}
- MessageContent::ToolResult(tool_result) => {
- markdown.push_str(&format!(
- "**Tool Result**: {} (ID: {})\n\n",
- tool_result.tool_name, tool_result.tool_use_id
- ));
- if tool_result.is_error {
- markdown.push_str("**ERROR:**\n");
- }
+ }
+ }
- match &tool_result.content {
- LanguageModelToolResultContent::Text(text) => {
- writeln!(markdown, "{text}\n").ok();
- }
- LanguageModelToolResultContent::Image(_) => {
- writeln!(markdown, "<image />\n").ok();
- }
- }
+ for tool_result in self.tool_results.values() {
+ markdown.push_str(&format!(
+ "**Tool Result**: {} (ID: {})\n\n",
+ tool_result.tool_name, tool_result.tool_use_id
+ ));
+ if tool_result.is_error {
+ markdown.push_str("**ERROR:**\n");
+ }
- if let Some(output) = tool_result.output.as_ref() {
- writeln!(
- markdown,
- "**Debug Output**:\n\n```json\n{}\n```\n",
- serde_json::to_string_pretty(output).unwrap()
- )
- .unwrap();
- }
+ match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => {
+ writeln!(markdown, "{text}\n").ok();
}
- MessageContent::Mention { uri, .. } => {
- write!(markdown, "{}", uri.to_link()).ok();
+ LanguageModelToolResultContent::Image(_) => {
+ writeln!(markdown, "<image />\n").ok();
}
}
+
+ if let Some(output) = tool_result.output.as_ref() {
+ writeln!(
+ markdown,
+ "**Debug Output**:\n\n```json\n{}\n```\n",
+ serde_json::to_string_pretty(output).unwrap()
+ )
+ .unwrap();
+ }
}
markdown
}
+
+ pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
+ let mut content = Vec::with_capacity(self.content.len());
+ for chunk in &self.content {
+ let chunk = match chunk {
+ AgentMessageContent::Text(text) => {
+ language_model::MessageContent::Text(text.clone())
+ }
+ AgentMessageContent::Thinking { text, signature } => {
+ language_model::MessageContent::Thinking {
+ text: text.clone(),
+ signature: signature.clone(),
+ }
+ }
+ AgentMessageContent::RedactedThinking(value) => {
+ language_model::MessageContent::RedactedThinking(value.clone())
+ }
+ AgentMessageContent::ToolUse(value) => {
+ language_model::MessageContent::ToolUse(value.clone())
+ }
+ AgentMessageContent::Image(value) => {
+ language_model::MessageContent::Image(value.clone())
+ }
+ };
+ content.push(chunk);
+ }
+
+ let mut messages = vec![LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content,
+ cache: false,
+ }];
+
+ if !self.tool_results.is_empty() {
+ let mut tool_results = Vec::with_capacity(self.tool_results.len());
+ for tool_result in self.tool_results.values() {
+ tool_results.push(language_model::MessageContent::ToolResult(
+ tool_result.clone(),
+ ));
+ }
+ messages.push(LanguageModelRequestMessage {
+ role: Role::User,
+ content: tool_results,
+ cache: false,
+ });
+ }
+
+ messages
+ }
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Eq)]
+pub struct AgentMessage {
+ pub content: Vec<AgentMessageContent>,
+ pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum AgentMessageContent {
+ Text(String),
+ Thinking {
+ text: String,
+ signature: Option<String>,
+ },
+ RedactedThinking(String),
+ Image(LanguageModelImage),
+ ToolUse(LanguageModelToolUse),
}
#[derive(Debug)]
@@ -140,13 +354,13 @@ pub struct ToolCallAuthorization {
}
pub struct Thread {
- messages: Vec<AgentMessage>,
+ messages: Vec<Message>,
completion_mode: CompletionMode,
/// Holds the task that handles agent interaction until the end of the turn.
/// Survives across multiple requests as the model performs tool calls and
/// we run tools, report their results.
running_turn: Option<Task<()>>,
- pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
+ pending_agent_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
context_server_registry: Entity<ContextServerRegistry>,
profile_id: AgentProfileId,
@@ -172,7 +386,7 @@ impl Thread {
messages: Vec::new(),
completion_mode: CompletionMode::Normal,
running_turn: None,
- pending_tool_uses: HashMap::default(),
+ pending_agent_message: None,
tools: BTreeMap::default(),
context_server_registry,
profile_id,
@@ -196,8 +410,13 @@ impl Thread {
self.completion_mode = mode;
}
- pub fn messages(&self) -> &[AgentMessage] {
- &self.messages
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn last_message(&self) -> Option<Message> {
+ if let Some(message) = self.pending_agent_message.clone() {
+ Some(Message::Agent(message))
+ } else {
+ self.messages.last().cloned()
+ }
}
pub fn add_tool(&mut self, tool: impl AgentTool) {
@@ -213,35 +432,36 @@ impl Thread {
}
pub fn cancel(&mut self) {
+ // TODO: do we need to emit a stop::cancel for ACP?
self.running_turn.take();
+ self.flush_pending_agent_message();
+ }
- let tool_results = self
- .pending_tool_uses
- .drain()
- .map(|(tool_use_id, tool_use)| {
- MessageContent::ToolResult(LanguageModelToolResult {
- tool_use_id,
- tool_name: tool_use.name.clone(),
- is_error: true,
- content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
- output: None,
- })
- })
- .collect::<Vec<_>>();
- self.last_user_message().content.extend(tool_results);
+ pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
+ self.cancel();
+ let Some(position) = self.messages.iter().position(
+ |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
+ ) else {
+ return Err(anyhow!("Message not found"));
+ };
+ self.messages.truncate(position);
+ Ok(())
}
/// Sending a message results in the model streaming a response, which could include tool calls.
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
- pub fn send(
+ pub fn send<T>(
&mut self,
- content: impl Into<UserMessage>,
+ message_id: UserMessageId,
+ content: impl IntoIterator<Item = T>,
cx: &mut Context<Self>,
- ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
- let content = content.into().0;
-
+ ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
+ where
+ T: Into<UserMessageContent>,
+ {
let model = self.selected_model.clone();
+ let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
log::info!("Thread::send called with model: {:?}", model.name());
log::debug!("Thread::send content: {:?}", content);
@@ -251,10 +471,10 @@ impl Thread {
let event_stream = AgentResponseEventStream(events_tx);
let user_message_ix = self.messages.len();
- self.messages.push(AgentMessage {
- role: Role::User,
+ self.messages.push(Message::User(UserMessage {
+ id: message_id,
content,
- });
+ }));
log::info!("Total messages in thread: {}", self.messages.len());
self.running_turn = Some(cx.spawn(async move |thread, cx| {
log::info!("Starting agent turn execution");
@@ -270,15 +490,11 @@ impl Thread {
thread.build_completion_request(completion_intent, cx)
})?;
- // println!(
- // "request: {}",
- // serde_json::to_string_pretty(&request).unwrap()
- // );
-
// Stream events, appending to messages and collecting up tool uses.
log::info!("Calling model.stream_completion");
let mut events = model.stream_completion(request, cx).await?;
log::debug!("Stream completion started successfully");
+
let mut tool_uses = FuturesUnordered::new();
while let Some(event) = events.next().await {
match event {
@@ -286,6 +502,7 @@ impl Thread {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
thread.update(cx, |thread, _cx| {
+ thread.pending_agent_message = None;
thread.messages.truncate(user_message_ix);
})?;
break 'outer;
@@ -338,15 +555,16 @@ impl Thread {
);
thread
.update(cx, |thread, _cx| {
- thread.pending_tool_uses.remove(&tool_result.tool_use_id);
thread
- .last_user_message()
- .content
- .push(MessageContent::ToolResult(tool_result));
+ .pending_agent_message()
+ .tool_results
+ .insert(tool_result.tool_use_id.clone(), tool_result);
})
.ok();
}
+ thread.update(cx, |thread, _cx| thread.flush_pending_agent_message())?;
+
completion_intent = CompletionIntent::ToolResults;
}
@@ -354,6 +572,10 @@ impl Thread {
}
.await;
+ thread
+ .update(cx, |thread, _cx| thread.flush_pending_agent_message())
+ .ok();
+
if let Err(error) = turn_result {
log::error!("Turn execution failed: {:?}", error);
event_stream.send_error(error);
@@ -364,7 +586,7 @@ impl Thread {
events_rx
}
- pub fn build_system_message(&self) -> AgentMessage {
+ pub fn build_system_message(&self) -> LanguageModelRequestMessage {
log::debug!("Building system message");
let prompt = SystemPromptTemplate {
project: &self.project_context.borrow(),
@@ -374,9 +596,10 @@ impl Thread {
.context("failed to build system prompt")
.expect("Invalid template");
log::debug!("System message built");
- AgentMessage {
+ LanguageModelRequestMessage {
role: Role::System,
- content: vec![prompt.as_str().into()],
+ content: vec![prompt.into()],
+ cache: true,
}
}
@@ -394,10 +617,7 @@ impl Thread {
match event {
StartMessage { .. } => {
- self.messages.push(AgentMessage {
- role: Role::Assistant,
- content: Vec::new(),
- });
+ self.messages.push(Message::Agent(AgentMessage::default()));
}
Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
Thinking { text, signature } => {
@@ -435,11 +655,13 @@ impl Thread {
) {
events_stream.send_text(&new_text);
- let last_message = self.last_assistant_message();
- if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
+ let last_message = self.pending_agent_message();
+ if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
text.push_str(&new_text);
} else {
- last_message.content.push(MessageContent::Text(new_text));
+ last_message
+ .content
+ .push(AgentMessageContent::Text(new_text));
}
cx.notify();
@@ -454,13 +676,14 @@ impl Thread {
) {
event_stream.send_thinking(&new_text);
- let last_message = self.last_assistant_message();
- if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
+ let last_message = self.pending_agent_message();
+ if let Some(AgentMessageContent::Thinking { text, signature }) =
+ last_message.content.last_mut()
{
text.push_str(&new_text);
*signature = new_signature.or(signature.take());
} else {
- last_message.content.push(MessageContent::Thinking {
+ last_message.content.push(AgentMessageContent::Thinking {
text: new_text,
signature: new_signature,
});
@@ -470,10 +693,10 @@ impl Thread {
}
fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
- let last_message = self.last_assistant_message();
+ let last_message = self.pending_agent_message();
last_message
.content
- .push(MessageContent::RedactedThinking(data));
+ .push(AgentMessageContent::RedactedThinking(data));
cx.notify();
}
@@ -486,14 +709,17 @@ impl Thread {
cx.notify();
let tool = self.tools.get(tool_use.name.as_ref()).cloned();
-
- self.pending_tool_uses
- .insert(tool_use.id.clone(), tool_use.clone());
- let last_message = self.last_assistant_message();
+ let mut title = SharedString::from(&tool_use.name);
+ let mut kind = acp::ToolKind::Other;
+ if let Some(tool) = tool.as_ref() {
+ title = tool.initial_title(tool_use.input.clone());
+ kind = tool.kind();
+ }
// Ensure the last message ends in the current tool use
+ let last_message = self.pending_agent_message();
let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
- if let MessageContent::ToolUse(last_tool_use) = content {
+ if let AgentMessageContent::ToolUse(last_tool_use) = content {
if last_tool_use.id == tool_use.id {
*last_tool_use = tool_use.clone();
false
@@ -505,18 +731,11 @@ impl Thread {
}
});
- let mut title = SharedString::from(&tool_use.name);
- let mut kind = acp::ToolKind::Other;
- if let Some(tool) = tool.as_ref() {
- title = tool.initial_title(tool_use.input.clone());
- kind = tool.kind();
- }
-
if push_new_tool_use {
event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
last_message
.content
- .push(MessageContent::ToolUse(tool_use.clone()));
+ .push(AgentMessageContent::ToolUse(tool_use.clone()));
} else {
event_stream.update_tool_call_fields(
&tool_use.id,
@@ -601,30 +820,37 @@ impl Thread {
}
}
- /// Guarantees the last message is from the assistant and returns a mutable reference.
- fn last_assistant_message(&mut self) -> &mut AgentMessage {
- if self
- .messages
- .last()
- .map_or(true, |m| m.role != Role::Assistant)
- {
- self.messages.push(AgentMessage {
- role: Role::Assistant,
- content: Vec::new(),
- });
- }
- self.messages.last_mut().unwrap()
+ fn pending_agent_message(&mut self) -> &mut AgentMessage {
+ self.pending_agent_message.get_or_insert_default()
}
- /// Guarantees the last message is from the user and returns a mutable reference.
- fn last_user_message(&mut self) -> &mut AgentMessage {
- if self.messages.last().map_or(true, |m| m.role != Role::User) {
- self.messages.push(AgentMessage {
- role: Role::User,
- content: Vec::new(),
- });
+ fn flush_pending_agent_message(&mut self) {
+ let Some(mut message) = self.pending_agent_message.take() else {
+ return;
+ };
+
+ for content in &message.content {
+ let AgentMessageContent::ToolUse(tool_use) = content else {
+ continue;
+ };
+
+ if !message.tool_results.contains_key(&tool_use.id) {
+ message.tool_results.insert(
+ tool_use.id.clone(),
+ LanguageModelToolResult {
+ tool_use_id: tool_use.id.clone(),
+ tool_name: tool_use.name.clone(),
+ is_error: true,
+ content: LanguageModelToolResultContent::Text(
+ "Tool canceled by user".into(),
+ ),
+ output: None,
+ },
+ );
+ }
}
- self.messages.last_mut().unwrap()
+
+ self.messages.push(Message::Agent(message));
}
pub(crate) fn build_completion_request(
@@ -712,46 +938,36 @@ impl Thread {
"Building request messages from {} thread messages",
self.messages.len()
);
+ let mut messages = vec![self.build_system_message()];
+ for message in &self.messages {
+ match message {
+ Message::User(message) => messages.push(message.to_request()),
+ Message::Agent(message) => messages.extend(message.to_request()),
+ }
+ }
+
+ if let Some(message) = self.pending_agent_message.as_ref() {
+ messages.extend(message.to_request());
+ }
- let messages = Some(self.build_system_message())
- .iter()
- .chain(self.messages.iter())
- .map(|message| {
- log::trace!(
- " - {} message with {} content items",
- match message.role {
- Role::System => "System",
- Role::User => "User",
- Role::Assistant => "Assistant",
- },
- message.content.len()
- );
- message.to_request()
- })
- .collect();
messages
}
pub fn to_markdown(&self) -> String {
let mut markdown = String::new();
- for message in &self.messages {
+ for (ix, message) in self.messages.iter().enumerate() {
+ if ix > 0 {
+ markdown.push('\n');
+ }
markdown.push_str(&message.to_markdown());
}
- markdown
- }
-}
-pub struct UserMessage(Vec<MessageContent>);
-
-impl From<Vec<MessageContent>> for UserMessage {
- fn from(content: Vec<MessageContent>) -> Self {
- UserMessage(content)
- }
-}
+ if let Some(message) = self.pending_agent_message.as_ref() {
+ markdown.push('\n');
+ markdown.push_str(&message.to_markdown());
+ }
-impl<T: Into<MessageContent>> From<T> for UserMessage {
- fn from(content: T) -> Self {
- UserMessage(vec![content.into()])
+ markdown
}
}
@@ -1151,130 +1367,6 @@ impl std::ops::DerefMut for ToolCallEventStreamReceiver {
}
}
-impl AgentMessage {
- fn to_request(&self) -> language_model::LanguageModelRequestMessage {
- let mut message = LanguageModelRequestMessage {
- role: self.role,
- content: Vec::with_capacity(self.content.len()),
- cache: false,
- };
-
- const OPEN_CONTEXT: &str = "<context>\n\
- The following items were attached by the user. \
- They are up-to-date and don't need to be re-read.\n\n";
-
- const OPEN_FILES_TAG: &str = "<files>";
- const OPEN_SYMBOLS_TAG: &str = "<symbols>";
- const OPEN_THREADS_TAG: &str = "<threads>";
- const OPEN_RULES_TAG: &str =
- "<rules>\nThe user has specified the following rules that should be applied:\n";
-
- let mut file_context = OPEN_FILES_TAG.to_string();
- let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
- let mut thread_context = OPEN_THREADS_TAG.to_string();
- let mut rules_context = OPEN_RULES_TAG.to_string();
-
- for chunk in &self.content {
- let chunk = match chunk {
- MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()),
- MessageContent::Thinking { text, signature } => {
- language_model::MessageContent::Thinking {
- text: text.clone(),
- signature: signature.clone(),
- }
- }
- MessageContent::RedactedThinking(value) => {
- language_model::MessageContent::RedactedThinking(value.clone())
- }
- MessageContent::ToolUse(value) => {
- language_model::MessageContent::ToolUse(value.clone())
- }
- MessageContent::ToolResult(value) => {
- language_model::MessageContent::ToolResult(value.clone())
- }
- MessageContent::Image(value) => {
- language_model::MessageContent::Image(value.clone())
- }
- MessageContent::Mention { uri, content } => {
- match uri {
- MentionUri::File(path) | MentionUri::Symbol(path, _) => {
- write!(
- &mut symbol_context,
- "\n{}",
- MarkdownCodeBlock {
- tag: &codeblock_tag(&path),
- text: &content.to_string(),
- }
- )
- .ok();
- }
- MentionUri::Thread(_session_id) => {
- write!(&mut thread_context, "\n{}\n", content).ok();
- }
- MentionUri::Rule(_user_prompt_id) => {
- write!(
- &mut rules_context,
- "\n{}",
- MarkdownCodeBlock {
- tag: "",
- text: &content
- }
- )
- .ok();
- }
- }
-
- language_model::MessageContent::Text(uri.to_link())
- }
- };
-
- message.content.push(chunk);
- }
-
- let len_before_context = message.content.len();
-
- if file_context.len() > OPEN_FILES_TAG.len() {
- file_context.push_str("</files>\n");
- message
- .content
- .push(language_model::MessageContent::Text(file_context));
- }
-
- if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
- symbol_context.push_str("</symbols>\n");
- message
- .content
- .push(language_model::MessageContent::Text(symbol_context));
- }
-
- if thread_context.len() > OPEN_THREADS_TAG.len() {
- thread_context.push_str("</threads>\n");
- message
- .content
- .push(language_model::MessageContent::Text(thread_context));
- }
-
- if rules_context.len() > OPEN_RULES_TAG.len() {
- rules_context.push_str("</user_rules>\n");
- message
- .content
- .push(language_model::MessageContent::Text(rules_context));
- }
-
- if message.content.len() > len_before_context {
- message.content.insert(
- len_before_context,
- language_model::MessageContent::Text(OPEN_CONTEXT.into()),
- );
- message
- .content
- .push(language_model::MessageContent::Text("</context>".into()));
- }
-
- message
- }
-}
-
fn codeblock_tag(full_path: &Path) -> String {
let mut result = String::new();
@@ -1287,16 +1379,20 @@ fn codeblock_tag(full_path: &Path) -> String {
result
}
-impl From<acp::ContentBlock> for MessageContent {
+impl From<&str> for UserMessageContent {
+ fn from(text: &str) -> Self {
+ Self::Text(text.into())
+ }
+}
+
+impl From<acp::ContentBlock> for UserMessageContent {
fn from(value: acp::ContentBlock) -> Self {
match value {
- acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text),
- acp::ContentBlock::Image(image_content) => {
- MessageContent::Image(convert_image(image_content))
- }
+ acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
+ acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
acp::ContentBlock::Audio(_) => {
// TODO
- MessageContent::Text("[audio]".to_string())
+ Self::Text("[audio]".to_string())
}
acp::ContentBlock::ResourceLink(resource_link) => {
match MentionUri::parse(&resource_link.uri) {
@@ -1306,10 +1402,7 @@ impl From<acp::ContentBlock> for MessageContent {
},
Err(err) => {
log::error!("Failed to parse mention link: {}", err);
- MessageContent::Text(format!(
- "[{}]({})",
- resource_link.name, resource_link.uri
- ))
+ Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
}
}
}
@@ -1322,7 +1415,7 @@ impl From<acp::ContentBlock> for MessageContent {
},
Err(err) => {
log::error!("Failed to parse mention link: {}", err);
- MessageContent::Text(
+ Self::Text(
MarkdownCodeBlock {
tag: &resource.uri,
text: &resource.text,
@@ -1334,7 +1427,7 @@ impl From<acp::ContentBlock> for MessageContent {
}
acp::EmbeddedResourceResource::BlobResourceContents(_) => {
// TODO
- MessageContent::Text("[blob]".to_string())
+ Self::Text("[blob]".to_string())
}
},
}
@@ -1348,9 +1441,3 @@ fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
size: gpui::Size::new(0.into(), 0.into()),
}
}
-
-impl From<&str> for MessageContent {
- fn from(text: &str) -> Self {
- MessageContent::Text(text.into())
- }
-}
@@ -467,6 +467,7 @@ impl AgentConnection for AcpConnection {
fn prompt(
&self,
+ _id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
@@ -171,6 +171,7 @@ impl AgentConnection for AcpConnection {
fn prompt(
&self,
+ _id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
@@ -210,6 +210,7 @@ impl AgentConnection for ClaudeAgentConnection {
fn prompt(
&self,
+ _id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<Result<acp::PromptResponse>> {
@@ -423,7 +424,7 @@ impl ClaudeAgentSession {
if !turn_state.borrow().is_cancelled() {
thread
.update(cx, |thread, cx| {
- thread.push_user_content_block(text.into(), cx)
+ thread.push_user_content_block(None, text.into(), cx)
})
.log_err();
}
@@ -679,17 +679,19 @@ impl AcpThreadView {
window: &mut Window,
cx: &mut Context<Self>,
) {
- let count = self.list_state.item_count();
match event {
AcpThreadEvent::NewEntry => {
let index = thread.read(cx).entries().len() - 1;
self.sync_thread_entry_view(index, window, cx);
- self.list_state.splice(count..count, 1);
+ self.list_state.splice(index..index, 1);
}
AcpThreadEvent::EntryUpdated(index) => {
- let index = *index;
- self.sync_thread_entry_view(index, window, cx);
- self.list_state.splice(index..index + 1, 1);
+ self.sync_thread_entry_view(*index, window, cx);
+ self.list_state.splice(*index..index + 1, 1);
+ }
+ AcpThreadEvent::EntriesRemoved(range) => {
+ // TODO: Clean up unused diff editors and terminal views
+ self.list_state.splice(range.clone(), 0);
}
AcpThreadEvent::ToolAuthorizationRequired => {
self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
@@ -3789,6 +3791,7 @@ mod tests {
fn prompt(
&self,
+ _id: Option<acp_thread::UserMessageId>,
params: acp::PromptRequest,
cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
@@ -3873,6 +3876,7 @@ mod tests {
fn prompt(
&self,
+ _id: Option<acp_thread::UserMessageId>,
_params: acp::PromptRequest,
_cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
@@ -1521,7 +1521,8 @@ impl AgentDiff {
self.update_reviewing_editors(workspace, window, cx);
}
}
- AcpThreadEvent::Stopped
+ AcpThreadEvent::EntriesRemoved(_)
+ | AcpThreadEvent::Stopped
| AcpThreadEvent::ToolAuthorizationRequired
| AcpThreadEvent::Error
| AcpThreadEvent::ServerExited(_) => {}
@@ -51,6 +51,7 @@ ashpd.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
+git = { workspace = true, features = ["test-support"] }
[features]
test-support = ["gpui/test-support", "git/test-support"]
@@ -1,8 +1,9 @@
-use crate::{FakeFs, Fs};
+use crate::{FakeFs, FakeFsEntry, Fs};
use anyhow::{Context as _, Result};
use collections::{HashMap, HashSet};
use futures::future::{self, BoxFuture, join_all};
use git::{
+ Oid,
blame::Blame,
repository::{
AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository,
@@ -12,6 +13,7 @@ use git::{
};
use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task};
use ignore::gitignore::GitignoreBuilder;
+use parking_lot::Mutex;
use rope::Rope;
use smol::future::FutureExt as _;
use std::{path::PathBuf, sync::Arc};
@@ -19,6 +21,7 @@ use std::{path::PathBuf, sync::Arc};
#[derive(Clone)]
pub struct FakeGitRepository {
pub(crate) fs: Arc<FakeFs>,
+ pub(crate) checkpoints: Arc<Mutex<HashMap<Oid, FakeFsEntry>>>,
pub(crate) executor: BackgroundExecutor,
pub(crate) dot_git_path: PathBuf,
pub(crate) repository_dir_path: PathBuf,
@@ -469,22 +472,57 @@ impl GitRepository for FakeGitRepository {
}
fn checkpoint(&self) -> BoxFuture<'static, Result<GitRepositoryCheckpoint>> {
- unimplemented!()
+ let executor = self.executor.clone();
+ let fs = self.fs.clone();
+ let checkpoints = self.checkpoints.clone();
+ let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
+ async move {
+ executor.simulate_random_delay().await;
+ let oid = Oid::random(&mut executor.rng());
+ let entry = fs.entry(&repository_dir_path)?;
+ checkpoints.lock().insert(oid, entry);
+ Ok(GitRepositoryCheckpoint { commit_sha: oid })
+ }
+ .boxed()
}
- fn restore_checkpoint(
- &self,
- _checkpoint: GitRepositoryCheckpoint,
- ) -> BoxFuture<'_, Result<()>> {
- unimplemented!()
+ fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> {
+ let executor = self.executor.clone();
+ let fs = self.fs.clone();
+ let checkpoints = self.checkpoints.clone();
+ let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
+ async move {
+ executor.simulate_random_delay().await;
+ let checkpoints = checkpoints.lock();
+ let entry = checkpoints
+ .get(&checkpoint.commit_sha)
+ .context(format!("invalid checkpoint: {}", checkpoint.commit_sha))?;
+ fs.insert_entry(&repository_dir_path, entry.clone())?;
+ Ok(())
+ }
+ .boxed()
}
fn compare_checkpoints(
&self,
- _left: GitRepositoryCheckpoint,
- _right: GitRepositoryCheckpoint,
+ left: GitRepositoryCheckpoint,
+ right: GitRepositoryCheckpoint,
) -> BoxFuture<'_, Result<bool>> {
- unimplemented!()
+ let executor = self.executor.clone();
+ let checkpoints = self.checkpoints.clone();
+ async move {
+ executor.simulate_random_delay().await;
+ let checkpoints = checkpoints.lock();
+ let left = checkpoints
+ .get(&left.commit_sha)
+ .context(format!("invalid left checkpoint: {}", left.commit_sha))?;
+ let right = checkpoints
+ .get(&right.commit_sha)
+ .context(format!("invalid right checkpoint: {}", right.commit_sha))?;
+
+ Ok(left == right)
+ }
+ .boxed()
}
fn diff_checkpoints(
@@ -499,3 +537,63 @@ impl GitRepository for FakeGitRepository {
unimplemented!()
}
}
+
+#[cfg(test)]
+mod tests {
+ use crate::{FakeFs, Fs};
+ use gpui::BackgroundExecutor;
+ use serde_json::json;
+ use std::path::Path;
+ use util::path;
+
+ #[gpui::test]
+ async fn test_checkpoints(executor: BackgroundExecutor) {
+ let fs = FakeFs::new(executor);
+ fs.insert_tree(
+ path!("/"),
+ json!({
+ "bar": {
+ "baz": "qux"
+ },
+ "foo": {
+ ".git": {},
+ "a": "lorem",
+ "b": "ipsum",
+ },
+ }),
+ )
+ .await;
+ fs.with_git_state(Path::new("/foo/.git"), true, |_git| {})
+ .unwrap();
+ let repository = fs.open_repo(Path::new("/foo/.git")).unwrap();
+
+ let checkpoint_1 = repository.checkpoint().await.unwrap();
+ fs.write(Path::new("/foo/b"), b"IPSUM").await.unwrap();
+ fs.write(Path::new("/foo/c"), b"dolor").await.unwrap();
+ let checkpoint_2 = repository.checkpoint().await.unwrap();
+ let checkpoint_3 = repository.checkpoint().await.unwrap();
+
+ assert!(
+ repository
+ .compare_checkpoints(checkpoint_2.clone(), checkpoint_3.clone())
+ .await
+ .unwrap()
+ );
+ assert!(
+ !repository
+ .compare_checkpoints(checkpoint_1.clone(), checkpoint_2.clone())
+ .await
+ .unwrap()
+ );
+
+ repository.restore_checkpoint(checkpoint_1).await.unwrap();
+ assert_eq!(
+ fs.files_with_contents(Path::new("")),
+ [
+ (Path::new("/bar/baz").into(), b"qux".into()),
+ (Path::new("/foo/a").into(), b"lorem".into()),
+ (Path::new("/foo/b").into(), b"ipsum".into())
+ ]
+ );
+ }
+}
@@ -924,7 +924,7 @@ pub struct FakeFs {
#[cfg(any(test, feature = "test-support"))]
struct FakeFsState {
- root: Arc<Mutex<FakeFsEntry>>,
+ root: FakeFsEntry,
next_inode: u64,
next_mtime: SystemTime,
git_event_tx: smol::channel::Sender<PathBuf>,
@@ -939,7 +939,7 @@ struct FakeFsState {
}
#[cfg(any(test, feature = "test-support"))]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
enum FakeFsEntry {
File {
inode: u64,
@@ -953,7 +953,7 @@ enum FakeFsEntry {
inode: u64,
mtime: MTime,
len: u64,
- entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
+ entries: BTreeMap<String, FakeFsEntry>,
git_repo_state: Option<Arc<Mutex<FakeGitRepositoryState>>>,
},
Symlink {
@@ -961,6 +961,67 @@ enum FakeFsEntry {
},
}
+#[cfg(any(test, feature = "test-support"))]
+impl PartialEq for FakeFsEntry {
+ fn eq(&self, other: &Self) -> bool {
+ match (self, other) {
+ (
+ Self::File {
+ inode: l_inode,
+ mtime: l_mtime,
+ len: l_len,
+ content: l_content,
+ git_dir_path: l_git_dir_path,
+ },
+ Self::File {
+ inode: r_inode,
+ mtime: r_mtime,
+ len: r_len,
+ content: r_content,
+ git_dir_path: r_git_dir_path,
+ },
+ ) => {
+ l_inode == r_inode
+ && l_mtime == r_mtime
+ && l_len == r_len
+ && l_content == r_content
+ && l_git_dir_path == r_git_dir_path
+ }
+ (
+ Self::Dir {
+ inode: l_inode,
+ mtime: l_mtime,
+ len: l_len,
+ entries: l_entries,
+ git_repo_state: l_git_repo_state,
+ },
+ Self::Dir {
+ inode: r_inode,
+ mtime: r_mtime,
+ len: r_len,
+ entries: r_entries,
+ git_repo_state: r_git_repo_state,
+ },
+ ) => {
+ let same_repo_state = match (l_git_repo_state.as_ref(), r_git_repo_state.as_ref()) {
+ (Some(l), Some(r)) => Arc::ptr_eq(l, r),
+ (None, None) => true,
+ _ => false,
+ };
+ l_inode == r_inode
+ && l_mtime == r_mtime
+ && l_len == r_len
+ && l_entries == r_entries
+ && same_repo_state
+ }
+ (Self::Symlink { target: l_target }, Self::Symlink { target: r_target }) => {
+ l_target == r_target
+ }
+ _ => false,
+ }
+ }
+}
+
#[cfg(any(test, feature = "test-support"))]
impl FakeFsState {
fn get_and_increment_mtime(&mut self) -> MTime {
@@ -975,25 +1036,9 @@ impl FakeFsState {
inode
}
- fn read_path(&self, target: &Path) -> Result<Arc<Mutex<FakeFsEntry>>> {
- Ok(self
- .try_read_path(target, true)
- .ok_or_else(|| {
- anyhow!(io::Error::new(
- io::ErrorKind::NotFound,
- format!("not found: {target:?}")
- ))
- })?
- .0)
- }
-
- fn try_read_path(
- &self,
- target: &Path,
- follow_symlink: bool,
- ) -> Option<(Arc<Mutex<FakeFsEntry>>, PathBuf)> {
- let mut path = target.to_path_buf();
+ fn canonicalize(&self, target: &Path, follow_symlink: bool) -> Option<PathBuf> {
let mut canonical_path = PathBuf::new();
+ let mut path = target.to_path_buf();
let mut entry_stack = Vec::new();
'outer: loop {
let mut path_components = path.components().peekable();
@@ -1003,7 +1048,7 @@ impl FakeFsState {
Component::Prefix(prefix_component) => prefix = Some(prefix_component),
Component::RootDir => {
entry_stack.clear();
- entry_stack.push(self.root.clone());
+ entry_stack.push(&self.root);
canonical_path.clear();
match prefix {
Some(prefix_component) => {
@@ -1020,20 +1065,18 @@ impl FakeFsState {
canonical_path.pop();
}
Component::Normal(name) => {
- let current_entry = entry_stack.last().cloned()?;
- let current_entry = current_entry.lock();
- if let FakeFsEntry::Dir { entries, .. } = &*current_entry {
- let entry = entries.get(name.to_str().unwrap()).cloned()?;
+ let current_entry = *entry_stack.last()?;
+ if let FakeFsEntry::Dir { entries, .. } = current_entry {
+ let entry = entries.get(name.to_str().unwrap())?;
if path_components.peek().is_some() || follow_symlink {
- let entry = entry.lock();
- if let FakeFsEntry::Symlink { target, .. } = &*entry {
+ if let FakeFsEntry::Symlink { target, .. } = entry {
let mut target = target.clone();
target.extend(path_components);
path = target;
continue 'outer;
}
}
- entry_stack.push(entry.clone());
+ entry_stack.push(entry);
canonical_path = canonical_path.join(name);
} else {
return None;
@@ -1043,19 +1086,72 @@ impl FakeFsState {
}
break;
}
- Some((entry_stack.pop()?, canonical_path))
+
+ if entry_stack.is_empty() {
+ None
+ } else {
+ Some(canonical_path)
+ }
+ }
+
+ fn try_entry(
+ &mut self,
+ target: &Path,
+ follow_symlink: bool,
+ ) -> Option<(&mut FakeFsEntry, PathBuf)> {
+ let canonical_path = self.canonicalize(target, follow_symlink)?;
+
+ let mut components = canonical_path.components();
+ let Some(Component::RootDir) = components.next() else {
+ panic!(
+ "the path {:?} was not canonicalized properly {:?}",
+ target, canonical_path
+ )
+ };
+
+ let mut entry = &mut self.root;
+ for component in components {
+ match component {
+ Component::Normal(name) => {
+ if let FakeFsEntry::Dir { entries, .. } = entry {
+ entry = entries.get_mut(name.to_str().unwrap())?;
+ } else {
+ return None;
+ }
+ }
+ _ => {
+ panic!(
+ "the path {:?} was not canonicalized properly {:?}",
+ target, canonical_path
+ )
+ }
+ }
+ }
+
+ Some((entry, canonical_path))
}
- fn write_path<Fn, T>(&self, path: &Path, callback: Fn) -> Result<T>
+ fn entry(&mut self, target: &Path) -> Result<&mut FakeFsEntry> {
+ Ok(self
+ .try_entry(target, true)
+ .ok_or_else(|| {
+ anyhow!(io::Error::new(
+ io::ErrorKind::NotFound,
+ format!("not found: {target:?}")
+ ))
+ })?
+ .0)
+ }
+
+ fn write_path<Fn, T>(&mut self, path: &Path, callback: Fn) -> Result<T>
where
- Fn: FnOnce(btree_map::Entry<String, Arc<Mutex<FakeFsEntry>>>) -> Result<T>,
+ Fn: FnOnce(btree_map::Entry<String, FakeFsEntry>) -> Result<T>,
{
let path = normalize_path(path);
let filename = path.file_name().context("cannot overwrite the root")?;
let parent_path = path.parent().unwrap();
- let parent = self.read_path(parent_path)?;
- let mut parent = parent.lock();
+ let parent = self.entry(parent_path)?;
let new_entry = parent
.dir_entries(parent_path)?
.entry(filename.to_str().unwrap().into());
@@ -1105,13 +1201,13 @@ impl FakeFs {
this: this.clone(),
executor: executor.clone(),
state: Arc::new(Mutex::new(FakeFsState {
- root: Arc::new(Mutex::new(FakeFsEntry::Dir {
+ root: FakeFsEntry::Dir {
inode: 0,
mtime: MTime(UNIX_EPOCH),
len: 0,
entries: Default::default(),
git_repo_state: None,
- })),
+ },
git_event_tx: tx,
next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL,
next_inode: 1,
@@ -1161,15 +1257,15 @@ impl FakeFs {
.write_path(path, move |entry| {
match entry {
btree_map::Entry::Vacant(e) => {
- e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
+ e.insert(FakeFsEntry::File {
inode: new_inode,
mtime: new_mtime,
content: Vec::new(),
len: 0,
git_dir_path: None,
- })));
+ });
}
- btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {
+ btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut() {
FakeFsEntry::File { mtime, .. } => *mtime = new_mtime,
FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime,
FakeFsEntry::Symlink { .. } => {}
@@ -1188,7 +1284,7 @@ impl FakeFs {
pub async fn insert_symlink(&self, path: impl AsRef<Path>, target: PathBuf) {
let mut state = self.state.lock();
let path = path.as_ref();
- let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
+ let file = FakeFsEntry::Symlink { target };
state
.write_path(path.as_ref(), move |e| match e {
btree_map::Entry::Vacant(e) => {
@@ -1221,13 +1317,13 @@ impl FakeFs {
match entry {
btree_map::Entry::Vacant(e) => {
kind = Some(PathEventKind::Created);
- e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
+ e.insert(FakeFsEntry::File {
inode: new_inode,
mtime: new_mtime,
len: new_len,
content: new_content,
git_dir_path: None,
- })));
+ });
}
btree_map::Entry::Occupied(mut e) => {
kind = Some(PathEventKind::Changed);
@@ -1237,7 +1333,7 @@ impl FakeFs {
len,
content,
..
- } = &mut *e.get_mut().lock()
+ } = e.get_mut()
{
*mtime = new_mtime;
*content = new_content;
@@ -1259,9 +1355,8 @@ impl FakeFs {
pub fn read_file_sync(&self, path: impl AsRef<Path>) -> Result<Vec<u8>> {
let path = path.as_ref();
let path = normalize_path(path);
- let state = self.state.lock();
- let entry = state.read_path(&path)?;
- let entry = entry.lock();
+ let mut state = self.state.lock();
+ let entry = state.entry(&path)?;
entry.file_content(&path).cloned()
}
@@ -1269,9 +1364,8 @@ impl FakeFs {
let path = path.as_ref();
let path = normalize_path(path);
self.simulate_random_delay().await;
- let state = self.state.lock();
- let entry = state.read_path(&path)?;
- let entry = entry.lock();
+ let mut state = self.state.lock();
+ let entry = state.entry(&path)?;
entry.file_content(&path).cloned()
}
@@ -1292,6 +1386,25 @@ impl FakeFs {
self.state.lock().flush_events(count);
}
+ pub(crate) fn entry(&self, target: &Path) -> Result<FakeFsEntry> {
+ self.state.lock().entry(target).cloned()
+ }
+
+ pub(crate) fn insert_entry(&self, target: &Path, new_entry: FakeFsEntry) -> Result<()> {
+ let mut state = self.state.lock();
+ state.write_path(target, |entry| {
+ match entry {
+ btree_map::Entry::Vacant(vacant_entry) => {
+ vacant_entry.insert(new_entry);
+ }
+ btree_map::Entry::Occupied(mut occupied_entry) => {
+ occupied_entry.insert(new_entry);
+ }
+ }
+ Ok(())
+ })
+ }
+
#[must_use]
pub fn insert_tree<'a>(
&'a self,
@@ -1361,20 +1474,19 @@ impl FakeFs {
F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T,
{
let mut state = self.state.lock();
- let entry = state.read_path(dot_git).context("open .git")?;
- let mut entry = entry.lock();
+ let git_event_tx = state.git_event_tx.clone();
+ let entry = state.entry(dot_git).context("open .git")?;
- if let FakeFsEntry::Dir { git_repo_state, .. } = &mut *entry {
+ if let FakeFsEntry::Dir { git_repo_state, .. } = entry {
let repo_state = git_repo_state.get_or_insert_with(|| {
log::debug!("insert git state for {dot_git:?}");
- Arc::new(Mutex::new(FakeGitRepositoryState::new(
- state.git_event_tx.clone(),
- )))
+ Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
});
let mut repo_state = repo_state.lock();
let result = f(&mut repo_state, dot_git, dot_git);
+ drop(repo_state);
if emit_git_event {
state.emit_event([(dot_git, None)]);
}
@@ -1398,21 +1510,20 @@ impl FakeFs {
}
}
.clone();
- drop(entry);
- let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else {
+ let Some((git_dir_entry, canonical_path)) = state.try_entry(&path, true) else {
anyhow::bail!("pointed-to git dir {path:?} not found")
};
let FakeFsEntry::Dir {
git_repo_state,
entries,
..
- } = &mut *git_dir_entry.lock()
+ } = git_dir_entry
else {
anyhow::bail!("gitfile points to a non-directory")
};
let common_dir = if let Some(child) = entries.get("commondir") {
Path::new(
- std::str::from_utf8(child.lock().file_content("commondir".as_ref())?)
+ std::str::from_utf8(child.file_content("commondir".as_ref())?)
.context("commondir content")?,
)
.to_owned()
@@ -1420,15 +1531,14 @@ impl FakeFs {
canonical_path.clone()
};
let repo_state = git_repo_state.get_or_insert_with(|| {
- Arc::new(Mutex::new(FakeGitRepositoryState::new(
- state.git_event_tx.clone(),
- )))
+ Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
});
let mut repo_state = repo_state.lock();
let result = f(&mut repo_state, &canonical_path, &common_dir);
if emit_git_event {
+ drop(repo_state);
state.emit_event([(canonical_path, None)]);
}
@@ -1655,14 +1765,12 @@ impl FakeFs {
pub fn paths(&self, include_dot_git: bool) -> Vec<PathBuf> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
- queue.push_back((
- PathBuf::from(util::path!("/")),
- self.state.lock().root.clone(),
- ));
+ let state = &*self.state.lock();
+ queue.push_back((PathBuf::from(util::path!("/")), &state.root));
while let Some((path, entry)) = queue.pop_front() {
- if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
+ if let FakeFsEntry::Dir { entries, .. } = entry {
for (name, entry) in entries {
- queue.push_back((path.join(name), entry.clone()));
+ queue.push_back((path.join(name), entry));
}
}
if include_dot_git
@@ -1679,14 +1787,12 @@ impl FakeFs {
pub fn directories(&self, include_dot_git: bool) -> Vec<PathBuf> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
- queue.push_back((
- PathBuf::from(util::path!("/")),
- self.state.lock().root.clone(),
- ));
+ let state = &*self.state.lock();
+ queue.push_back((PathBuf::from(util::path!("/")), &state.root));
while let Some((path, entry)) = queue.pop_front() {
- if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
+ if let FakeFsEntry::Dir { entries, .. } = entry {
for (name, entry) in entries {
- queue.push_back((path.join(name), entry.clone()));
+ queue.push_back((path.join(name), entry));
}
if include_dot_git
|| !path
@@ -1703,17 +1809,14 @@ impl FakeFs {
pub fn files(&self) -> Vec<PathBuf> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
- queue.push_back((
- PathBuf::from(util::path!("/")),
- self.state.lock().root.clone(),
- ));
+ let state = &*self.state.lock();
+ queue.push_back((PathBuf::from(util::path!("/")), &state.root));
while let Some((path, entry)) = queue.pop_front() {
- let e = entry.lock();
- match &*e {
+ match entry {
FakeFsEntry::File { .. } => result.push(path),
FakeFsEntry::Dir { entries, .. } => {
for (name, entry) in entries {
- queue.push_back((path.join(name), entry.clone()));
+ queue.push_back((path.join(name), entry));
}
}
FakeFsEntry::Symlink { .. } => {}
@@ -1725,13 +1828,10 @@ impl FakeFs {
pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec<u8>)> {
let mut result = Vec::new();
let mut queue = collections::VecDeque::new();
- queue.push_back((
- PathBuf::from(util::path!("/")),
- self.state.lock().root.clone(),
- ));
+ let state = &*self.state.lock();
+ queue.push_back((PathBuf::from(util::path!("/")), &state.root));
while let Some((path, entry)) = queue.pop_front() {
- let e = entry.lock();
- match &*e {
+ match entry {
FakeFsEntry::File { content, .. } => {
if path.starts_with(prefix) {
result.push((path, content.clone()));
@@ -1739,7 +1839,7 @@ impl FakeFs {
}
FakeFsEntry::Dir { entries, .. } => {
for (name, entry) in entries {
- queue.push_back((path.join(name), entry.clone()));
+ queue.push_back((path.join(name), entry));
}
}
FakeFsEntry::Symlink { .. } => {}
@@ -1805,10 +1905,7 @@ impl FakeFsEntry {
}
}
- fn dir_entries(
- &mut self,
- path: &Path,
- ) -> Result<&mut BTreeMap<String, Arc<Mutex<FakeFsEntry>>>> {
+ fn dir_entries(&mut self, path: &Path) -> Result<&mut BTreeMap<String, FakeFsEntry>> {
if let Self::Dir { entries, .. } = self {
Ok(entries)
} else {
@@ -1855,12 +1952,12 @@ struct FakeHandle {
impl FileHandle for FakeHandle {
fn current_path(&self, fs: &Arc<dyn Fs>) -> Result<PathBuf> {
let fs = fs.as_fake();
- let state = fs.state.lock();
- let Some(target) = state.moves.get(&self.inode) else {
+ let mut state = fs.state.lock();
+ let Some(target) = state.moves.get(&self.inode).cloned() else {
anyhow::bail!("fake fd not moved")
};
- if state.try_read_path(&target, false).is_some() {
+ if state.try_entry(&target, false).is_some() {
return Ok(target.clone());
}
anyhow::bail!("fake fd target not found")
@@ -1888,13 +1985,13 @@ impl Fs for FakeFs {
state.write_path(&cur_path, |entry| {
entry.or_insert_with(|| {
created_dirs.push((cur_path.clone(), Some(PathEventKind::Created)));
- Arc::new(Mutex::new(FakeFsEntry::Dir {
+ FakeFsEntry::Dir {
inode,
mtime,
len: 0,
entries: Default::default(),
git_repo_state: None,
- }))
+ }
});
Ok(())
})?
@@ -1909,13 +2006,13 @@ impl Fs for FakeFs {
let mut state = self.state.lock();
let inode = state.get_and_increment_inode();
let mtime = state.get_and_increment_mtime();
- let file = Arc::new(Mutex::new(FakeFsEntry::File {
+ let file = FakeFsEntry::File {
inode,
mtime,
len: 0,
content: Vec::new(),
git_dir_path: None,
- }));
+ };
let mut kind = Some(PathEventKind::Created);
state.write_path(path, |entry| {
match entry {
@@ -1939,7 +2036,7 @@ impl Fs for FakeFs {
async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> {
let mut state = self.state.lock();
- let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
+ let file = FakeFsEntry::Symlink { target };
state
.write_path(path.as_ref(), move |e| match e {
btree_map::Entry::Vacant(e) => {
@@ -2002,7 +2099,7 @@ impl Fs for FakeFs {
}
})?;
- let inode = match *moved_entry.lock() {
+ let inode = match moved_entry {
FakeFsEntry::File { inode, .. } => inode,
FakeFsEntry::Dir { inode, .. } => inode,
_ => 0,
@@ -2051,8 +2148,8 @@ impl Fs for FakeFs {
let mut state = self.state.lock();
let mtime = state.get_and_increment_mtime();
let inode = state.get_and_increment_inode();
- let source_entry = state.read_path(&source)?;
- let content = source_entry.lock().file_content(&source)?.clone();
+ let source_entry = state.entry(&source)?;
+ let content = source_entry.file_content(&source)?.clone();
let mut kind = Some(PathEventKind::Created);
state.write_path(&target, |e| match e {
btree_map::Entry::Occupied(e) => {
@@ -2066,13 +2163,13 @@ impl Fs for FakeFs {
}
}
btree_map::Entry::Vacant(e) => Ok(Some(
- e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
+ e.insert(FakeFsEntry::File {
inode,
mtime,
len: content.len() as u64,
content,
git_dir_path: None,
- })))
+ })
.clone(),
)),
})?;
@@ -2088,8 +2185,7 @@ impl Fs for FakeFs {
let base_name = path.file_name().context("cannot remove the root")?;
let mut state = self.state.lock();
- let parent_entry = state.read_path(parent_path)?;
- let mut parent_entry = parent_entry.lock();
+ let parent_entry = state.entry(parent_path)?;
let entry = parent_entry
.dir_entries(parent_path)?
.entry(base_name.to_str().unwrap().into());
@@ -2100,15 +2196,14 @@ impl Fs for FakeFs {
anyhow::bail!("{path:?} does not exist");
}
}
- btree_map::Entry::Occupied(e) => {
+ btree_map::Entry::Occupied(mut entry) => {
{
- let mut entry = e.get().lock();
- let children = entry.dir_entries(&path)?;
+ let children = entry.get_mut().dir_entries(&path)?;
if !options.recursive && !children.is_empty() {
anyhow::bail!("{path:?} is not empty");
}
}
- e.remove();
+ entry.remove();
}
}
state.emit_event([(path, Some(PathEventKind::Removed))]);
@@ -2122,8 +2217,7 @@ impl Fs for FakeFs {
let parent_path = path.parent().context("cannot remove the root")?;
let base_name = path.file_name().unwrap();
let mut state = self.state.lock();
- let parent_entry = state.read_path(parent_path)?;
- let mut parent_entry = parent_entry.lock();
+ let parent_entry = state.entry(parent_path)?;
let entry = parent_entry
.dir_entries(parent_path)?
.entry(base_name.to_str().unwrap().into());
@@ -2133,9 +2227,9 @@ impl Fs for FakeFs {
anyhow::bail!("{path:?} does not exist");
}
}
- btree_map::Entry::Occupied(e) => {
- e.get().lock().file_content(&path)?;
- e.remove();
+ btree_map::Entry::Occupied(mut entry) => {
+ entry.get_mut().file_content(&path)?;
+ entry.remove();
}
}
state.emit_event([(path, Some(PathEventKind::Removed))]);
@@ -2149,12 +2243,10 @@ impl Fs for FakeFs {
async fn open_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>> {
self.simulate_random_delay().await;
- let state = self.state.lock();
- let entry = state.read_path(&path)?;
- let entry = entry.lock();
- let inode = match *entry {
- FakeFsEntry::File { inode, .. } => inode,
- FakeFsEntry::Dir { inode, .. } => inode,
+ let mut state = self.state.lock();
+ let inode = match state.entry(&path)? {
+ FakeFsEntry::File { inode, .. } => *inode,
+ FakeFsEntry::Dir { inode, .. } => *inode,
_ => unreachable!(),
};
Ok(Arc::new(FakeHandle { inode }))
@@ -2204,8 +2296,8 @@ impl Fs for FakeFs {
let path = normalize_path(path);
self.simulate_random_delay().await;
let state = self.state.lock();
- let (_, canonical_path) = state
- .try_read_path(&path, true)
+ let canonical_path = state
+ .canonicalize(&path, true)
.with_context(|| format!("path does not exist: {path:?}"))?;
Ok(canonical_path)
}
@@ -2213,9 +2305,9 @@ impl Fs for FakeFs {
async fn is_file(&self, path: &Path) -> bool {
let path = normalize_path(path);
self.simulate_random_delay().await;
- let state = self.state.lock();
- if let Some((entry, _)) = state.try_read_path(&path, true) {
- entry.lock().is_file()
+ let mut state = self.state.lock();
+ if let Some((entry, _)) = state.try_entry(&path, true) {
+ entry.is_file()
} else {
false
}
@@ -2232,17 +2324,16 @@ impl Fs for FakeFs {
let path = normalize_path(path);
let mut state = self.state.lock();
state.metadata_call_count += 1;
- if let Some((mut entry, _)) = state.try_read_path(&path, false) {
- let is_symlink = entry.lock().is_symlink();
+ if let Some((mut entry, _)) = state.try_entry(&path, false) {
+ let is_symlink = entry.is_symlink();
if is_symlink {
- if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) {
+ if let Some(e) = state.try_entry(&path, true).map(|e| e.0) {
entry = e;
} else {
return Ok(None);
}
}
- let entry = entry.lock();
Ok(Some(match &*entry {
FakeFsEntry::File {
inode, mtime, len, ..
@@ -2274,12 +2365,11 @@ impl Fs for FakeFs {
async fn read_link(&self, path: &Path) -> Result<PathBuf> {
self.simulate_random_delay().await;
let path = normalize_path(path);
- let state = self.state.lock();
+ let mut state = self.state.lock();
let (entry, _) = state
- .try_read_path(&path, false)
+ .try_entry(&path, false)
.with_context(|| format!("path does not exist: {path:?}"))?;
- let entry = entry.lock();
- if let FakeFsEntry::Symlink { target } = &*entry {
+ if let FakeFsEntry::Symlink { target } = entry {
Ok(target.clone())
} else {
anyhow::bail!("not a symlink: {path:?}")
@@ -2294,8 +2384,7 @@ impl Fs for FakeFs {
let path = normalize_path(path);
let mut state = self.state.lock();
state.read_dir_call_count += 1;
- let entry = state.read_path(&path)?;
- let mut entry = entry.lock();
+ let entry = state.entry(&path)?;
let children = entry.dir_entries(&path)?;
let paths = children
.keys()
@@ -2359,6 +2448,7 @@ impl Fs for FakeFs {
dot_git_path: abs_dot_git.to_path_buf(),
repository_dir_path: repository_dir_path.to_owned(),
common_dir_path: common_dir_path.to_owned(),
+ checkpoints: Arc::default(),
}) as _
},
)
@@ -12,7 +12,7 @@ workspace = true
path = "src/git.rs"
[features]
-test-support = []
+test-support = ["rand"]
[dependencies]
anyhow.workspace = true
@@ -26,6 +26,7 @@ http_client.workspace = true
log.workspace = true
parking_lot.workspace = true
regex.workspace = true
+rand = { workspace = true, optional = true }
rope.workspace = true
schemars.workspace = true
serde.workspace = true
@@ -47,3 +48,4 @@ text = { workspace = true, features = ["test-support"] }
unindent.workspace = true
gpui = { workspace = true, features = ["test-support"] }
tempfile.workspace = true
+rand.workspace = true
@@ -119,6 +119,13 @@ impl Oid {
Ok(Self(oid))
}
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn random(rng: &mut impl rand::Rng) -> Self {
+ let mut bytes = [0; 20];
+ rng.fill(&mut bytes);
+ Self::from_bytes(&bytes).unwrap()
+ }
+
pub fn as_bytes(&self) -> &[u8] {
self.0.as_bytes()
}