Detailed changes
@@ -33,13 +33,23 @@ pub struct UserMessage {
pub id: Option<UserMessageId>,
pub content: ContentBlock,
pub chunks: Vec<acp::ContentBlock>,
- pub checkpoint: Option<GitStoreCheckpoint>,
+ pub checkpoint: Option<Checkpoint>,
+}
+
+#[derive(Debug)]
+pub struct Checkpoint {
+ git_checkpoint: GitStoreCheckpoint,
+ pub show: bool,
}
impl UserMessage {
fn to_markdown(&self, cx: &App) -> String {
let mut markdown = String::new();
- if let Some(_) = self.checkpoint {
+ if self
+ .checkpoint
+ .as_ref()
+ .map_or(false, |checkpoint| checkpoint.show)
+ {
writeln!(markdown, "## User (checkpoint)").unwrap();
} else {
writeln!(markdown, "## User").unwrap();
@@ -1145,9 +1155,12 @@ impl AcpThread {
self.project.read(cx).languages().clone(),
cx,
);
+ let request = acp::PromptRequest {
+ prompt: message.clone(),
+ session_id: self.session_id.clone(),
+ };
let git_store = self.project.read(cx).git_store().clone();
- let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
let message_id = if self
.connection
.session_editor(&self.session_id, cx)
@@ -1161,68 +1174,63 @@ impl AcpThread {
AgentThreadEntry::UserMessage(UserMessage {
id: message_id.clone(),
content: block,
- chunks: message.clone(),
+ chunks: message,
checkpoint: None,
}),
cx,
);
+
+ self.run_turn(cx, async move |this, cx| {
+ let old_checkpoint = git_store
+ .update(cx, |git, cx| git.checkpoint(cx))?
+ .await
+ .context("failed to get old checkpoint")
+ .log_err();
+ this.update(cx, |this, cx| {
+ if let Some((_ix, message)) = this.last_user_message() {
+ message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
+ git_checkpoint,
+ show: false,
+ });
+ }
+ this.connection.prompt(message_id, request, cx)
+ })?
+ .await
+ })
+ }
+
+ pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
+ self.run_turn(cx, async move |this, cx| {
+ this.update(cx, |this, cx| {
+ this.connection
+ .resume(&this.session_id, cx)
+ .map(|resume| resume.run(cx))
+ })?
+ .context("resuming a session is not supported")?
+ .await
+ })
+ }
+
+ fn run_turn(
+ &mut self,
+ cx: &mut Context<Self>,
+ f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
+ ) -> BoxFuture<'static, Result<()>> {
self.clear_completed_plan_entries(cx);
- let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
let (tx, rx) = oneshot::channel();
let cancel_task = self.cancel(cx);
- let request = acp::PromptRequest {
- prompt: message,
- session_id: self.session_id.clone(),
- };
-
- self.send_task = Some(cx.spawn({
- let message_id = message_id.clone();
- async move |this, cx| {
- cancel_task.await;
- old_checkpoint_tx.send(old_checkpoint.await).ok();
- if let Ok(result) = this.update(cx, |this, cx| {
- this.connection.prompt(message_id, request, cx)
- }) {
- tx.send(result.await).log_err();
- }
- }
+ self.send_task = Some(cx.spawn(async move |this, cx| {
+ cancel_task.await;
+ tx.send(f(this, cx).await).ok();
}));
cx.spawn(async move |this, cx| {
- let old_checkpoint = old_checkpoint_rx
- .await
- .map_err(|_| anyhow!("send canceled"))
- .flatten()
- .context("failed to get old checkpoint")
- .log_err();
-
let response = rx.await;
- if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
- let new_checkpoint = git_store
- .update(cx, |git, cx| git.checkpoint(cx))?
- .await
- .context("failed to get new checkpoint")
- .log_err();
- if let Some(new_checkpoint) = new_checkpoint {
- let equal = git_store
- .update(cx, |git, cx| {
- git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
- })?
- .await
- .unwrap_or(true);
- if !equal {
- this.update(cx, |this, cx| {
- if let Some((ix, message)) = this.user_message_mut(&message_id) {
- message.checkpoint = Some(old_checkpoint);
- cx.emit(AcpThreadEvent::EntryUpdated(ix));
- }
- })?;
- }
- }
- }
+ this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
+ .await?;
this.update(cx, |this, cx| {
match response {
@@ -1294,7 +1302,10 @@ impl AcpThread {
return Task::ready(Err(anyhow!("message not found")));
};
- let checkpoint = message.checkpoint.clone();
+ let checkpoint = message
+ .checkpoint
+ .as_ref()
+ .map(|c| c.git_checkpoint.clone());
let git_store = self.project.read(cx).git_store().clone();
cx.spawn(async move |this, cx| {
@@ -1316,6 +1327,59 @@ impl AcpThread {
})
}
+ fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let git_store = self.project.read(cx).git_store().clone();
+
+ let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
+ if let Some(checkpoint) = message.checkpoint.as_ref() {
+ checkpoint.git_checkpoint.clone()
+ } else {
+ return Task::ready(Ok(()));
+ }
+ } else {
+ return Task::ready(Ok(()));
+ };
+
+ let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
+ cx.spawn(async move |this, cx| {
+ let new_checkpoint = new_checkpoint
+ .await
+ .context("failed to get new checkpoint")
+ .log_err();
+ if let Some(new_checkpoint) = new_checkpoint {
+ let equal = git_store
+ .update(cx, |git, cx| {
+ git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
+ })?
+ .await
+ .unwrap_or(true);
+ this.update(cx, |this, cx| {
+ let (ix, message) = this.last_user_message().context("no user message")?;
+ let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
+ checkpoint.show = !equal;
+ cx.emit(AcpThreadEvent::EntryUpdated(ix));
+ anyhow::Ok(())
+ })??;
+ }
+
+ Ok(())
+ })
+ }
+
+ fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
+ self.entries
+ .iter_mut()
+ .enumerate()
+ .rev()
+ .find_map(|(ix, entry)| {
+ if let AgentThreadEntry::UserMessage(message) = entry {
+ Some((ix, message))
+ } else {
+ None
+ }
+ })
+ }
+
fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
self.entries.iter().find_map(|entry| {
if let AgentThreadEntry::UserMessage(message) = entry {
@@ -1552,6 +1616,7 @@ mod tests {
use settings::SettingsStore;
use smol::stream::StreamExt as _;
use std::{
+ any::Any,
cell::RefCell,
path::Path,
rc::Rc,
@@ -2284,6 +2349,10 @@ mod tests {
_session_id: session_id.clone(),
}))
}
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
}
struct FakeAgentSessionEditor {
@@ -4,7 +4,7 @@ use anyhow::Result;
use collections::IndexMap;
use gpui::{Entity, SharedString, Task};
use project::Project;
-use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
+use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
use ui::{App, IconName};
use uuid::Uuid;
@@ -36,6 +36,14 @@ pub trait AgentConnection {
cx: &mut App,
) -> Task<Result<acp::PromptResponse>>;
+ fn resume(
+ &self,
+ _session_id: &acp::SessionId,
+ _cx: &mut App,
+ ) -> Option<Rc<dyn AgentSessionResume>> {
+ None
+ }
+
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
fn session_editor(
@@ -53,12 +61,24 @@ pub trait AgentConnection {
fn model_selector(&self) -> Option<Rc<dyn AgentModelSelector>> {
None
}
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any>;
+}
+
+impl dyn AgentConnection {
+ pub fn downcast<T: 'static + AgentConnection + Sized>(self: Rc<Self>) -> Option<Rc<T>> {
+ self.into_any().downcast().ok()
+ }
}
pub trait AgentSessionEditor {
fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
}
+pub trait AgentSessionResume {
+ fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>>;
+}
+
#[derive(Debug)]
pub struct AuthRequired;
@@ -299,6 +319,10 @@ mod test_support {
) -> Option<Rc<dyn AgentSessionEditor>> {
Some(Rc::new(StubAgentSessionEditor))
}
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
}
struct StubAgentSessionEditor;
@@ -1,9 +1,8 @@
-use crate::{AgentResponseEvent, Thread, templates::Templates};
use crate::{
- ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool, DiagnosticsTool,
- EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool,
- OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
- WebSearchTool,
+ AgentResponseEvent, ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DeletePathTool,
+ DiagnosticsTool, EditFileTool, FetchTool, FindPathTool, GrepTool, ListDirectoryTool,
+ MovePathTool, NowTool, OpenTool, ReadFileTool, TerminalTool, ThinkingTool, Thread,
+ ToolCallAuthorization, UserMessageContent, WebSearchTool, templates::Templates,
};
use acp_thread::AgentModelSelector;
use agent_client_protocol as acp;
@@ -11,6 +10,7 @@ use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use collections::{HashSet, IndexMap};
use fs::Fs;
+use futures::channel::mpsc;
use futures::{StreamExt, future};
use gpui::{
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
@@ -21,6 +21,7 @@ use prompt_store::{
ProjectContext, PromptId, PromptStore, RulesFileContext, UserRulesContext, WorktreeContext,
};
use settings::update_settings_file;
+use std::any::Any;
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
@@ -426,9 +427,9 @@ impl NativeAgent {
self.models.refresh_list(cx);
for session in self.sessions.values_mut() {
session.thread.update(cx, |thread, _| {
- let model_id = LanguageModels::model_id(&thread.selected_model);
+ let model_id = LanguageModels::model_id(&thread.model());
if let Some(model) = self.models.model_from_id(&model_id) {
- thread.selected_model = model.clone();
+ thread.set_model(model.clone());
}
});
}
@@ -439,6 +440,124 @@ impl NativeAgent {
#[derive(Clone)]
pub struct NativeAgentConnection(pub Entity<NativeAgent>);
+impl NativeAgentConnection {
+ pub fn thread(&self, session_id: &acp::SessionId, cx: &App) -> Option<Entity<Thread>> {
+ self.0
+ .read(cx)
+ .sessions
+ .get(session_id)
+ .map(|session| session.thread.clone())
+ }
+
+ fn run_turn(
+ &self,
+ session_id: acp::SessionId,
+ cx: &mut App,
+ f: impl 'static
+ + FnOnce(
+ Entity<Thread>,
+ &mut App,
+ ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>>,
+ ) -> Task<Result<acp::PromptResponse>> {
+ let Some((thread, acp_thread)) = self.0.update(cx, |agent, _cx| {
+ agent
+ .sessions
+ .get_mut(&session_id)
+ .map(|s| (s.thread.clone(), s.acp_thread.clone()))
+ }) else {
+ return Task::ready(Err(anyhow!("Session not found")));
+ };
+ log::debug!("Found session for: {}", session_id);
+
+ let mut response_stream = match f(thread, cx) {
+ Ok(stream) => stream,
+ Err(err) => return Task::ready(Err(err)),
+ };
+ cx.spawn(async move |cx| {
+ // Handle response stream and forward to session.acp_thread
+ while let Some(result) = response_stream.next().await {
+ match result {
+ Ok(event) => {
+ log::trace!("Received completion event: {:?}", event);
+
+ match event {
+ AgentResponseEvent::Text(text) => {
+ acp_thread.update(cx, |thread, cx| {
+ thread.push_assistant_content_block(
+ acp::ContentBlock::Text(acp::TextContent {
+ text,
+ annotations: None,
+ }),
+ false,
+ cx,
+ )
+ })?;
+ }
+ AgentResponseEvent::Thinking(text) => {
+ acp_thread.update(cx, |thread, cx| {
+ thread.push_assistant_content_block(
+ acp::ContentBlock::Text(acp::TextContent {
+ text,
+ annotations: None,
+ }),
+ true,
+ cx,
+ )
+ })?;
+ }
+ AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
+ tool_call,
+ options,
+ response,
+ }) => {
+ let recv = acp_thread.update(cx, |thread, cx| {
+ thread.request_tool_call_authorization(tool_call, options, cx)
+ })?;
+ cx.background_spawn(async move {
+ if let Some(option) = recv
+ .await
+ .context("authorization sender was dropped")
+ .log_err()
+ {
+ response
+ .send(option)
+ .map(|_| anyhow!("authorization receiver was dropped"))
+ .log_err();
+ }
+ })
+ .detach();
+ }
+ AgentResponseEvent::ToolCall(tool_call) => {
+ acp_thread.update(cx, |thread, cx| {
+ thread.upsert_tool_call(tool_call, cx)
+ })?;
+ }
+ AgentResponseEvent::ToolCallUpdate(update) => {
+ acp_thread.update(cx, |thread, cx| {
+ thread.update_tool_call(update, cx)
+ })??;
+ }
+ AgentResponseEvent::Stop(stop_reason) => {
+ log::debug!("Assistant message complete: {:?}", stop_reason);
+ return Ok(acp::PromptResponse { stop_reason });
+ }
+ }
+ }
+ Err(e) => {
+ log::error!("Error in model response stream: {:?}", e);
+ return Err(e);
+ }
+ }
+ }
+
+ log::info!("Response stream completed");
+ anyhow::Ok(acp::PromptResponse {
+ stop_reason: acp::StopReason::EndTurn,
+ })
+ })
+ }
+}
+
impl AgentModelSelector for NativeAgentConnection {
fn list_models(&self, cx: &mut App) -> Task<Result<acp_thread::AgentModelList>> {
log::debug!("NativeAgentConnection::list_models called");
@@ -472,7 +591,7 @@ impl AgentModelSelector for NativeAgentConnection {
};
thread.update(cx, |thread, _cx| {
- thread.selected_model = model.clone();
+ thread.set_model(model.clone());
});
update_settings_file::<AgentSettings>(
@@ -502,7 +621,7 @@ impl AgentModelSelector for NativeAgentConnection {
else {
return Task::ready(Err(anyhow!("Session not found")));
};
- let model = thread.read(cx).selected_model.clone();
+ let model = thread.read(cx).model().clone();
let Some(provider) = LanguageModelRegistry::read_global(cx).provider(&model.provider_id())
else {
return Task::ready(Err(anyhow!("Provider not found")));
@@ -644,25 +763,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
) -> Task<Result<acp::PromptResponse>> {
let id = id.expect("UserMessageId is required");
let session_id = params.session_id.clone();
- let agent = self.0.clone();
log::info!("Received prompt request for session: {}", session_id);
log::debug!("Prompt blocks count: {}", params.prompt.len());
- cx.spawn(async move |cx| {
- // Get session
- let (thread, acp_thread) = agent
- .update(cx, |agent, _| {
- agent
- .sessions
- .get_mut(&session_id)
- .map(|s| (s.thread.clone(), s.acp_thread.clone()))
- })?
- .ok_or_else(|| {
- log::error!("Session not found: {}", session_id);
- anyhow::anyhow!("Session not found")
- })?;
- log::debug!("Found session for: {}", session_id);
-
+ self.run_turn(session_id, cx, |thread, cx| {
let content: Vec<UserMessageContent> = params
.prompt
.into_iter()
@@ -672,99 +776,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
log::debug!("Message id: {:?}", id);
log::debug!("Message content: {:?}", content);
- // Get model using the ModelSelector capability (always available for agent2)
- // Get the selected model from the thread directly
- let model = thread.read_with(cx, |thread, _| thread.selected_model.clone())?;
-
- // Send to thread
- log::info!("Sending message to thread with model: {:?}", model.name());
- let mut response_stream =
- thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
-
- // Handle response stream and forward to session.acp_thread
- while let Some(result) = response_stream.next().await {
- match result {
- Ok(event) => {
- log::trace!("Received completion event: {:?}", event);
-
- match event {
- AgentResponseEvent::Text(text) => {
- acp_thread.update(cx, |thread, cx| {
- thread.push_assistant_content_block(
- acp::ContentBlock::Text(acp::TextContent {
- text,
- annotations: None,
- }),
- false,
- cx,
- )
- })?;
- }
- AgentResponseEvent::Thinking(text) => {
- acp_thread.update(cx, |thread, cx| {
- thread.push_assistant_content_block(
- acp::ContentBlock::Text(acp::TextContent {
- text,
- annotations: None,
- }),
- true,
- cx,
- )
- })?;
- }
- AgentResponseEvent::ToolCallAuthorization(ToolCallAuthorization {
- tool_call,
- options,
- response,
- }) => {
- let recv = acp_thread.update(cx, |thread, cx| {
- thread.request_tool_call_authorization(tool_call, options, cx)
- })?;
- cx.background_spawn(async move {
- if let Some(option) = recv
- .await
- .context("authorization sender was dropped")
- .log_err()
- {
- response
- .send(option)
- .map(|_| anyhow!("authorization receiver was dropped"))
- .log_err();
- }
- })
- .detach();
- }
- AgentResponseEvent::ToolCall(tool_call) => {
- acp_thread.update(cx, |thread, cx| {
- thread.upsert_tool_call(tool_call, cx)
- })?;
- }
- AgentResponseEvent::ToolCallUpdate(update) => {
- acp_thread.update(cx, |thread, cx| {
- thread.update_tool_call(update, cx)
- })??;
- }
- AgentResponseEvent::Stop(stop_reason) => {
- log::debug!("Assistant message complete: {:?}", stop_reason);
- return Ok(acp::PromptResponse { stop_reason });
- }
- }
- }
- Err(e) => {
- log::error!("Error in model response stream: {:?}", e);
- // TODO: Consider sending an error message to the UI
- break;
- }
- }
- }
-
- log::info!("Response stream completed");
- anyhow::Ok(acp::PromptResponse {
- stop_reason: acp::StopReason::EndTurn,
- })
+ Ok(thread.update(cx, |thread, cx| {
+ log::info!(
+ "Sending message to thread with model: {:?}",
+ thread.model().name()
+ );
+ thread.send(id, content, cx)
+ }))
})
}
+ fn resume(
+ &self,
+ session_id: &acp::SessionId,
+ _cx: &mut App,
+ ) -> Option<Rc<dyn acp_thread::AgentSessionResume>> {
+ Some(Rc::new(NativeAgentSessionResume {
+ connection: self.clone(),
+ session_id: session_id.clone(),
+ }) as _)
+ }
+
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
log::info!("Cancelling on session: {}", session_id);
self.0.update(cx, |agent, cx| {
@@ -786,6 +818,10 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
.map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
})
}
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
}
struct NativeAgentSessionEditor(Entity<Thread>);
@@ -796,6 +832,20 @@ impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
}
}
+struct NativeAgentSessionResume {
+ connection: NativeAgentConnection,
+ session_id: acp::SessionId,
+}
+
+impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
+ fn run(&self, cx: &mut App) -> Task<Result<acp::PromptResponse>> {
+ self.connection
+ .run_turn(self.session_id.clone(), cx, |thread, cx| {
+ thread.update(cx, |thread, cx| thread.resume(cx))
+ })
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -957,7 +1007,7 @@ mod tests {
agent.read_with(cx, |agent, _| {
let session = agent.sessions.get(&session_id).unwrap();
session.thread.read_with(cx, |thread, _| {
- assert_eq!(thread.selected_model.id().0, "fake");
+ assert_eq!(thread.model().id().0, "fake");
});
});
@@ -12,9 +12,9 @@ use gpui::{
};
use indoc::indoc;
use language_model::{
- LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
- LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
- fake_provider::FakeLanguageModel,
+ LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
+ LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
+ Role, StopReason, fake_provider::FakeLanguageModel,
};
use project::Project;
use prompt_store::ProjectContext;
@@ -394,8 +394,194 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
+#[gpui::test]
+async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let events = thread.update(cx, |thread, cx| {
+ thread.add_tool(EchoTool);
+ thread.send(UserMessageId::new(), ["abc"], cx)
+ });
+ cx.run_until_parked();
+ let tool_use = LanguageModelToolUse {
+ id: "tool_id_1".into(),
+ name: EchoTool.name().into(),
+ raw_input: "{}".into(),
+ input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
+ is_input_complete: true,
+ };
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
+ fake_model.end_last_completion_stream();
+
+ cx.run_until_parked();
+ let completion = fake_model.pending_completions().pop().unwrap();
+ let tool_result = LanguageModelToolResult {
+ tool_use_id: "tool_id_1".into(),
+ tool_name: EchoTool.name().into(),
+ is_error: false,
+ content: "def".into(),
+ output: Some("def".into()),
+ };
+ assert_eq!(
+ completion.messages[1..],
+ vec![
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["abc".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![MessageContent::ToolUse(tool_use.clone())],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::ToolResult(tool_result.clone())],
+ cache: false
+ },
+ ]
+ );
+
+ // Simulate reaching tool use limit.
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
+ cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
+ ));
+ fake_model.end_last_completion_stream();
+ let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
+ assert!(
+ last_event
+ .unwrap_err()
+ .is::<language_model::ToolUseLimitReachedError>()
+ );
+
+ let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
+ cx.run_until_parked();
+ let completion = fake_model.pending_completions().pop().unwrap();
+ assert_eq!(
+ completion.messages[1..],
+ vec![
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["abc".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![MessageContent::ToolUse(tool_use)],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::ToolResult(tool_result)],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Continue where you left off".into()],
+ cache: false
+ }
+ ]
+ );
+
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
+ fake_model.end_last_completion_stream();
+ events.collect::<Vec<_>>().await;
+ thread.read_with(cx, |thread, _cx| {
+ assert_eq!(
+ thread.last_message().unwrap().to_markdown(),
+ indoc! {"
+ ## Assistant
+
+ Done
+ "}
+ )
+ });
+
+ // Ensure we error if calling resume when tool use limit was *not* reached.
+ let error = thread
+ .update(cx, |thread, cx| thread.resume(cx))
+ .unwrap_err();
+ assert_eq!(
+ error.to_string(),
+ "can only resume after tool use limit is reached"
+ )
+}
+
+#[gpui::test]
+async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
+ let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+ let fake_model = model.as_fake();
+
+ let events = thread.update(cx, |thread, cx| {
+ thread.add_tool(EchoTool);
+ thread.send(UserMessageId::new(), ["abc"], cx)
+ });
+ cx.run_until_parked();
+
+ let tool_use = LanguageModelToolUse {
+ id: "tool_id_1".into(),
+ name: EchoTool.name().into(),
+ raw_input: "{}".into(),
+ input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
+ is_input_complete: true,
+ };
+ let tool_result = LanguageModelToolResult {
+ tool_use_id: "tool_id_1".into(),
+ tool_name: EchoTool.name().into(),
+ is_error: false,
+ content: "def".into(),
+ output: Some("def".into()),
+ };
+ fake_model
+ .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
+ cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
+ ));
+ fake_model.end_last_completion_stream();
+ let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
+ assert!(
+ last_event
+ .unwrap_err()
+ .is::<language_model::ToolUseLimitReachedError>()
+ );
+
+ thread.update(cx, |thread, cx| {
+ thread.send(UserMessageId::new(), vec!["ghi"], cx)
+ });
+ cx.run_until_parked();
+ let completion = fake_model.pending_completions().pop().unwrap();
+ assert_eq!(
+ completion.messages[1..],
+ vec![
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["abc".into()],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![MessageContent::ToolUse(tool_use)],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::ToolResult(tool_result)],
+ cache: false
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["ghi".into()],
+ cache: false
+ }
+ ]
+ );
+}
+
async fn expect_tool_call(
- events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
+ events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCall {
let event = events
.next()
@@ -411,7 +597,7 @@ async fn expect_tool_call(
}
async fn expect_tool_call_update_fields(
- events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
+ events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCallUpdate {
let event = events
.next()
@@ -429,7 +615,7 @@ async fn expect_tool_call_update_fields(
}
async fn next_tool_call_authorization(
- events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
+ events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> ToolCallAuthorization {
loop {
let event = events
@@ -1007,9 +1193,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
}
/// Filters out the stop events for asserting against in tests
-fn stop_events(
- result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
-) -> Vec<acp::StopReason> {
+fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
result_events
.into_iter()
.filter_map(|event| match event.unwrap() {
@@ -7,7 +7,7 @@ use std::future;
#[derive(JsonSchema, Serialize, Deserialize)]
pub struct EchoToolInput {
/// The text to echo.
- text: String,
+ pub text: String,
}
pub struct EchoTool;
@@ -2,10 +2,10 @@ use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
use acp_thread::{MentionUri, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol as acp;
-use agent_settings::{AgentProfileId, AgentSettings};
+use agent_settings::{AgentProfileId, AgentSettings, CompletionMode};
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
-use cloud_llm_client::{CompletionIntent, CompletionMode};
+use cloud_llm_client::{CompletionIntent, CompletionRequestStatus};
use collections::IndexMap;
use fs::Fs;
use futures::{
@@ -14,10 +14,10 @@ use futures::{
};
use gpui::{App, Context, Entity, SharedString, Task};
use language_model::{
- LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelImage,
- LanguageModelProviderId, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
- LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
+ LanguageModel, LanguageModelCompletionEvent, LanguageModelImage, LanguageModelProviderId,
+ LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
+ LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
+ LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
};
use project::Project;
use prompt_store::ProjectContext;
@@ -33,6 +33,7 @@ use util::{ResultExt, markdown::MarkdownCodeBlock};
pub enum Message {
User(UserMessage),
Agent(AgentMessage),
+ Resume,
}
impl Message {
@@ -47,6 +48,7 @@ impl Message {
match self {
Message::User(message) => message.to_markdown(),
Message::Agent(message) => message.to_markdown(),
+ Message::Resume => "[resumed after tool use limit was reached]".into(),
}
}
}
@@ -320,7 +322,11 @@ impl AgentMessage {
}
pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
- let mut content = Vec::with_capacity(self.content.len());
+ let mut assistant_message = LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: Vec::with_capacity(self.content.len()),
+ cache: false,
+ };
for chunk in &self.content {
let chunk = match chunk {
AgentMessageContent::Text(text) => {
@@ -342,29 +348,30 @@ impl AgentMessage {
language_model::MessageContent::Image(value.clone())
}
};
- content.push(chunk);
+ assistant_message.content.push(chunk);
}
- let mut messages = vec![LanguageModelRequestMessage {
- role: Role::Assistant,
- content,
+ let mut user_message = LanguageModelRequestMessage {
+ role: Role::User,
+ content: Vec::new(),
cache: false,
- }];
+ };
- if !self.tool_results.is_empty() {
- let mut tool_results = Vec::with_capacity(self.tool_results.len());
- for tool_result in self.tool_results.values() {
- tool_results.push(language_model::MessageContent::ToolResult(
+ for tool_result in self.tool_results.values() {
+ user_message
+ .content
+ .push(language_model::MessageContent::ToolResult(
tool_result.clone(),
));
- }
- messages.push(LanguageModelRequestMessage {
- role: Role::User,
- content: tool_results,
- cache: false,
- });
}
+ let mut messages = Vec::new();
+ if !assistant_message.content.is_empty() {
+ messages.push(assistant_message);
+ }
+ if !user_message.content.is_empty() {
+ messages.push(user_message);
+ }
messages
}
}
@@ -413,11 +420,12 @@ pub struct Thread {
running_turn: Option<Task<()>>,
pending_message: Option<AgentMessage>,
tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
+ tool_use_limit_reached: bool,
context_server_registry: Entity<ContextServerRegistry>,
profile_id: AgentProfileId,
project_context: Rc<RefCell<ProjectContext>>,
templates: Arc<Templates>,
- pub selected_model: Arc<dyn LanguageModel>,
+ model: Arc<dyn LanguageModel>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
}
@@ -429,7 +437,7 @@ impl Thread {
context_server_registry: Entity<ContextServerRegistry>,
action_log: Entity<ActionLog>,
templates: Arc<Templates>,
- default_model: Arc<dyn LanguageModel>,
+ model: Arc<dyn LanguageModel>,
cx: &mut Context<Self>,
) -> Self {
let profile_id = AgentSettings::get_global(cx).default_profile.clone();
@@ -439,11 +447,12 @@ impl Thread {
running_turn: None,
pending_message: None,
tools: BTreeMap::default(),
+ tool_use_limit_reached: false,
context_server_registry,
profile_id,
project_context,
templates,
- selected_model: default_model,
+ model,
project,
action_log,
}
@@ -457,7 +466,19 @@ impl Thread {
&self.action_log
}
- pub fn set_mode(&mut self, mode: CompletionMode) {
+ pub fn model(&self) -> &Arc<dyn LanguageModel> {
+ &self.model
+ }
+
+ pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
+ self.model = model;
+ }
+
+ pub fn completion_mode(&self) -> CompletionMode {
+ self.completion_mode
+ }
+
+ pub fn set_completion_mode(&mut self, mode: CompletionMode) {
self.completion_mode = mode;
}
@@ -499,36 +520,59 @@ impl Thread {
Ok(())
}
+ pub fn resume(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> Result<mpsc::UnboundedReceiver<Result<AgentResponseEvent>>> {
+ anyhow::ensure!(
+ self.tool_use_limit_reached,
+ "can only resume after tool use limit is reached"
+ );
+
+ self.messages.push(Message::Resume);
+ cx.notify();
+
+ log::info!("Total messages in thread: {}", self.messages.len());
+ Ok(self.run_turn(cx))
+ }
+
/// Sending a message results in the model streaming a response, which could include tool calls.
/// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
/// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
pub fn send<T>(
&mut self,
- message_id: UserMessageId,
+ id: UserMessageId,
content: impl IntoIterator<Item = T>,
cx: &mut Context<Self>,
- ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
+ ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>>
where
T: Into<UserMessageContent>,
{
- let model = self.selected_model.clone();
+ log::info!("Thread::send called with model: {:?}", self.model.name());
+
let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
- log::info!("Thread::send called with model: {:?}", model.name());
log::debug!("Thread::send content: {:?}", content);
+ self.messages
+ .push(Message::User(UserMessage { id, content }));
cx.notify();
- let (events_tx, events_rx) =
- mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
- let event_stream = AgentResponseEventStream(events_tx);
- self.messages.push(Message::User(UserMessage {
- id: message_id.clone(),
- content,
- }));
log::info!("Total messages in thread: {}", self.messages.len());
+ self.run_turn(cx)
+ }
+
+ fn run_turn(
+ &mut self,
+ cx: &mut Context<Self>,
+ ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent>> {
+ let model = self.model.clone();
+ let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
+ let event_stream = AgentResponseEventStream(events_tx);
+ let message_ix = self.messages.len().saturating_sub(1);
+ self.tool_use_limit_reached = false;
self.running_turn = Some(cx.spawn(async move |this, cx| {
log::info!("Starting agent turn execution");
- let turn_result = async {
+ let turn_result: Result<()> = async {
let mut completion_intent = CompletionIntent::UserPrompt;
loop {
log::debug!(
@@ -543,13 +587,22 @@ impl Thread {
let mut events = model.stream_completion(request, cx).await?;
log::debug!("Stream completion started successfully");
+ let mut tool_use_limit_reached = false;
let mut tool_uses = FuturesUnordered::new();
while let Some(event) = events.next().await {
match event? {
+ LanguageModelCompletionEvent::StatusUpdate(
+ CompletionRequestStatus::ToolUseLimitReached,
+ ) => {
+ tool_use_limit_reached = true;
+ }
LanguageModelCompletionEvent::Stop(reason) => {
event_stream.send_stop(reason);
if reason == StopReason::Refusal {
- this.update(cx, |this, _cx| this.truncate(message_id))??;
+ this.update(cx, |this, _cx| {
+ this.flush_pending_message();
+ this.messages.truncate(message_ix);
+ })?;
return Ok(());
}
}
@@ -567,12 +620,7 @@ impl Thread {
}
}
- if tool_uses.is_empty() {
- log::info!("No tool uses found, completing turn");
- return Ok(());
- }
- log::info!("Found {} tool uses to execute", tool_uses.len());
-
+ let used_tools = tool_uses.is_empty();
while let Some(tool_result) = tool_uses.next().await {
log::info!("Tool finished {:?}", tool_result);
@@ -596,8 +644,17 @@ impl Thread {
.ok();
}
- this.update(cx, |this, _| this.flush_pending_message())?;
- completion_intent = CompletionIntent::ToolResults;
+ if tool_use_limit_reached {
+ log::info!("Tool use limit reached, completing turn");
+ this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
+ return Err(language_model::ToolUseLimitReachedError.into());
+ } else if used_tools {
+ log::info!("No tool uses found, completing turn");
+ return Ok(());
+ } else {
+ this.update(cx, |this, _| this.flush_pending_message())?;
+ completion_intent = CompletionIntent::ToolResults;
+ }
}
}
.await;
@@ -678,10 +735,10 @@ impl Thread {
fn handle_text_event(
&mut self,
new_text: String,
- events_stream: &AgentResponseEventStream,
+ event_stream: &AgentResponseEventStream,
cx: &mut Context<Self>,
) {
- events_stream.send_text(&new_text);
+ event_stream.send_text(&new_text);
let last_message = self.pending_message();
if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
@@ -798,8 +855,9 @@ impl Thread {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
});
- let supports_images = self.selected_model.supports_images();
+ let supports_images = self.model.supports_images();
let tool_result = tool.run(tool_use.input, tool_event_stream, cx);
+ log::info!("Running tool {}", tool_use.name);
Some(cx.foreground_executor().spawn(async move {
let tool_result = tool_result.await.and_then(|output| {
if let LanguageModelToolResultContent::Image(_) = &output.llm_output {
@@ -902,7 +960,7 @@ impl Thread {
name: tool_name,
description: tool.description().to_string(),
input_schema: tool
- .input_schema(self.selected_model.tool_input_format())
+ .input_schema(self.model.tool_input_format())
.log_err()?,
})
})
@@ -917,7 +975,7 @@ impl Thread {
thread_id: None,
prompt_id: None,
intent: Some(completion_intent),
- mode: Some(self.completion_mode),
+ mode: Some(self.completion_mode.into()),
messages,
tools,
tool_choice: None,
@@ -935,7 +993,7 @@ impl Thread {
.profiles
.get(&self.profile_id)
.context("profile not found")?;
- let provider_id = self.selected_model.provider_id();
+ let provider_id = self.model.provider_id();
Ok(self
.tools
@@ -971,6 +1029,11 @@ impl Thread {
match message {
Message::User(message) => messages.push(message.to_request()),
Message::Agent(message) => messages.extend(message.to_request()),
+ Message::Resume => messages.push(LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Continue where you left off".into()],
+ cache: false,
+ }),
}
}
@@ -1123,9 +1186,7 @@ where
}
#[derive(Clone)]
-struct AgentResponseEventStream(
- mpsc::UnboundedSender<Result<AgentResponseEvent, LanguageModelCompletionError>>,
-);
+struct AgentResponseEventStream(mpsc::UnboundedSender<Result<AgentResponseEvent>>);
impl AgentResponseEventStream {
fn send_text(&self, text: &str) {
@@ -1212,8 +1273,8 @@ impl AgentResponseEventStream {
}
}
- fn send_error(&self, error: LanguageModelCompletionError) {
- self.0.unbounded_send(Err(error)).ok();
+ fn send_error(&self, error: impl Into<anyhow::Error>) {
+ self.0.unbounded_send(Err(error.into())).ok();
}
}
@@ -1229,8 +1290,7 @@ pub struct ToolCallEventStream {
impl ToolCallEventStream {
#[cfg(test)]
pub fn test() -> (Self, ToolCallEventStreamReceiver) {
- let (events_tx, events_rx) =
- mpsc::unbounded::<Result<AgentResponseEvent, LanguageModelCompletionError>>();
+ let (events_tx, events_rx) = mpsc::unbounded::<Result<AgentResponseEvent>>();
let stream = ToolCallEventStream::new(
&LanguageModelToolUse {
@@ -1351,9 +1411,7 @@ impl ToolCallEventStream {
}
#[cfg(test)]
-pub struct ToolCallEventStreamReceiver(
- mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
-);
+pub struct ToolCallEventStreamReceiver(mpsc::UnboundedReceiver<Result<AgentResponseEvent>>);
#[cfg(test)]
impl ToolCallEventStreamReceiver {
@@ -1381,7 +1439,7 @@ impl ToolCallEventStreamReceiver {
#[cfg(test)]
impl std::ops::Deref for ToolCallEventStreamReceiver {
- type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>;
+ type Target = mpsc::UnboundedReceiver<Result<AgentResponseEvent>>;
fn deref(&self) -> &Self::Target {
&self.0
@@ -241,7 +241,7 @@ impl AgentTool for EditFileTool {
thread.build_completion_request(CompletionIntent::ToolResults, cx)
});
let thread = self.thread.read(cx);
- let model = thread.selected_model.clone();
+ let model = thread.model().clone();
let action_log = thread.action_log().clone();
let authorize = self.authorize(&input, &event_stream, cx);
@@ -5,7 +5,7 @@ use anyhow::{Context as _, Result, anyhow};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
-use std::{cell::RefCell, path::Path, rc::Rc};
+use std::{any::Any, cell::RefCell, path::Path, rc::Rc};
use ui::App;
use util::ResultExt as _;
@@ -507,4 +507,8 @@ impl AgentConnection for AcpConnection {
})
.detach_and_log_err(cx)
}
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
}
@@ -3,9 +3,9 @@ use anyhow::anyhow;
use collections::HashMap;
use futures::channel::oneshot;
use project::Project;
-use std::cell::RefCell;
use std::path::Path;
use std::rc::Rc;
+use std::{any::Any, cell::RefCell};
use anyhow::{Context as _, Result};
use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
@@ -191,6 +191,10 @@ impl AgentConnection for AcpConnection {
.spawn(async move { conn.cancel(params).await })
.detach();
}
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
}
struct ClientDelegate {
@@ -6,6 +6,7 @@ use context_server::listener::McpServerTool;
use project::Project;
use settings::SettingsStore;
use smol::process::Child;
+use std::any::Any;
use std::cell::RefCell;
use std::fmt::Display;
use std::path::Path;
@@ -289,6 +290,10 @@ impl AgentConnection for ClaudeAgentConnection {
})
.log_err();
}
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
}
#[derive(Clone, Copy)]
@@ -7,20 +7,21 @@ use action_log::ActionLog;
use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol::{self as acp};
use agent_servers::AgentServer;
-use agent_settings::{AgentSettings, NotifyWhenAgentWaiting};
+use agent_settings::{AgentSettings, CompletionMode, NotifyWhenAgentWaiting};
use anyhow::bail;
use audio::{Audio, Sound};
use buffer_diff::BufferDiff;
+use client::zed_urls;
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
use file_icons::FileIcons;
use gpui::{
- Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, EdgesRefinement, Empty, Entity,
- FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay,
- SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement,
- Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop,
- linear_gradient, list, percentage, point, prelude::*, pulsating_between,
+ Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, ClipboardItem, EdgesRefinement,
+ Empty, Entity, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton,
+ PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle,
+ TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
+ linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between,
};
use language::Buffer;
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
@@ -32,8 +33,8 @@ use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
use text::Anchor;
use theme::ThemeSettings;
use ui::{
- Disclosure, Divider, DividerColor, KeyBinding, PopoverMenuHandle, Scrollbar, ScrollbarState,
- Tooltip, prelude::*,
+ Callout, Disclosure, Divider, DividerColor, ElevationIndex, KeyBinding, PopoverMenuHandle,
+ Scrollbar, ScrollbarState, Tooltip, prelude::*,
};
use util::{ResultExt, size::format_file_size, time::duration_alt_display};
use workspace::{CollaboratorId, Workspace};
@@ -44,16 +45,39 @@ use super::entry_view_state::EntryViewState;
use crate::acp::AcpModelSelectorPopover;
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
use crate::agent_diff::AgentDiff;
-use crate::ui::{AgentNotification, AgentNotificationEvent};
+use crate::ui::{AgentNotification, AgentNotificationEvent, BurnModeTooltip};
use crate::{
- AgentDiffPane, AgentPanel, ExpandMessageEditor, Follow, KeepAll, OpenAgentDiff, RejectAll,
+ AgentDiffPane, AgentPanel, ContinueThread, ContinueWithBurnMode, ExpandMessageEditor, Follow,
+ KeepAll, OpenAgentDiff, RejectAll, ToggleBurnMode,
};
const RESPONSE_PADDING_X: Pixels = px(19.);
-
pub const MIN_EDITOR_LINES: usize = 4;
pub const MAX_EDITOR_LINES: usize = 8;
+enum ThreadError {
+ PaymentRequired,
+ ModelRequestLimitReached(cloud_llm_client::Plan),
+ ToolUseLimitReached,
+ Other(SharedString),
+}
+
+impl ThreadError {
+ fn from_err(error: anyhow::Error) -> Self {
+ if error.is::<language_model::PaymentRequiredError>() {
+ Self::PaymentRequired
+ } else if error.is::<language_model::ToolUseLimitReachedError>() {
+ Self::ToolUseLimitReached
+ } else if let Some(error) =
+ error.downcast_ref::<language_model::ModelRequestLimitReachedError>()
+ {
+ Self::ModelRequestLimitReached(error.plan)
+ } else {
+ Self::Other(error.to_string().into())
+ }
+ }
+}
+
pub struct AcpThreadView {
agent: Rc<dyn AgentServer>,
workspace: WeakEntity<Workspace>,
@@ -66,7 +90,7 @@ pub struct AcpThreadView {
model_selector: Option<Entity<AcpModelSelectorPopover>>,
notifications: Vec<WindowHandle<AgentNotification>>,
notification_subscriptions: HashMap<WindowHandle<AgentNotification>, Vec<Subscription>>,
- last_error: Option<Entity<Markdown>>,
+ thread_error: Option<ThreadError>,
list_state: ListState,
scrollbar_state: ScrollbarState,
auth_task: Option<Task<()>>,
@@ -151,7 +175,7 @@ impl AcpThreadView {
entry_view_state: EntryViewState::default(),
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
- last_error: None,
+ thread_error: None,
auth_task: None,
expanded_tool_calls: HashSet::default(),
expanded_thinking_blocks: HashSet::default(),
@@ -316,7 +340,7 @@ impl AcpThreadView {
}
pub fn cancel_generation(&mut self, cx: &mut Context<Self>) {
- self.last_error.take();
+ self.thread_error.take();
if let Some(thread) = self.thread() {
self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
@@ -371,6 +395,25 @@ impl AcpThreadView {
}
}
+ fn resume_chat(&mut self, cx: &mut Context<Self>) {
+ self.thread_error.take();
+ let Some(thread) = self.thread() else {
+ return;
+ };
+
+ let task = thread.update(cx, |thread, cx| thread.resume(cx));
+ cx.spawn(async move |this, cx| {
+ let result = task.await;
+
+ this.update(cx, |this, cx| {
+ if let Err(err) = result {
+ this.handle_thread_error(err, cx);
+ }
+ })
+ })
+ .detach();
+ }
+
fn send(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let contents = self
.message_editor
@@ -384,7 +427,7 @@ impl AcpThreadView {
window: &mut Window,
cx: &mut Context<Self>,
) {
- self.last_error.take();
+ self.thread_error.take();
self.editing_message.take();
let Some(thread) = self.thread().cloned() else {
@@ -409,11 +452,9 @@ impl AcpThreadView {
});
cx.spawn(async move |this, cx| {
- if let Err(e) = task.await {
+ if let Err(err) = task.await {
this.update(cx, |this, cx| {
- this.last_error =
- Some(cx.new(|cx| Markdown::new(e.to_string().into(), None, None, cx)));
- cx.notify()
+ this.handle_thread_error(err, cx);
})
.ok();
}
@@ -476,6 +517,16 @@ impl AcpThreadView {
})
}
+ fn handle_thread_error(&mut self, error: anyhow::Error, cx: &mut Context<Self>) {
+ self.thread_error = Some(ThreadError::from_err(error));
+ cx.notify();
+ }
+
+ fn clear_thread_error(&mut self, cx: &mut Context<Self>) {
+ self.thread_error = None;
+ cx.notify();
+ }
+
fn handle_thread_event(
&mut self,
thread: &Entity<AcpThread>,
@@ -551,7 +602,7 @@ impl AcpThreadView {
return;
};
- self.last_error.take();
+ self.thread_error.take();
let authenticate = connection.authenticate(method, cx);
self.auth_task = Some(cx.spawn_in(window, {
let project = self.project.clone();
@@ -561,9 +612,7 @@ impl AcpThreadView {
this.update_in(cx, |this, window, cx| {
if let Err(err) = result {
- this.last_error = Some(cx.new(|cx| {
- Markdown::new(format!("Error: {err}").into(), None, None, cx)
- }))
+ this.handle_thread_error(err, cx);
} else {
this.thread_state = Self::initial_state(
agent,
@@ -620,9 +669,7 @@ impl AcpThreadView {
.py_4()
.px_2()
.children(message.id.clone().and_then(|message_id| {
- message.checkpoint.as_ref()?;
-
- Some(
+ message.checkpoint.as_ref()?.show.then(|| {
Button::new("restore-checkpoint", "Restore Checkpoint")
.icon(IconName::Undo)
.icon_size(IconSize::XSmall)
@@ -630,8 +677,8 @@ impl AcpThreadView {
.label_size(LabelSize::XSmall)
.on_click(cx.listener(move |this, _, _window, cx| {
this.rewind(&message_id, cx);
- })),
- )
+ }))
+ })
}))
.child(
v_flex()
@@ -2322,7 +2369,12 @@ impl AcpThreadView {
h_flex()
.flex_none()
.justify_between()
- .child(self.render_follow_toggle(cx))
+ .child(
+ h_flex()
+ .gap_1()
+ .child(self.render_follow_toggle(cx))
+ .children(self.render_burn_mode_toggle(cx)),
+ )
.child(
h_flex()
.gap_1()
@@ -2333,6 +2385,68 @@ impl AcpThreadView {
.into_any()
}
+ fn as_native_connection(&self, cx: &App) -> Option<Rc<agent2::NativeAgentConnection>> {
+ let acp_thread = self.thread()?.read(cx);
+ acp_thread.connection().clone().downcast()
+ }
+
+ fn as_native_thread(&self, cx: &App) -> Option<Entity<agent2::Thread>> {
+ let acp_thread = self.thread()?.read(cx);
+ self.as_native_connection(cx)?
+ .thread(acp_thread.session_id(), cx)
+ }
+
+ fn toggle_burn_mode(
+ &mut self,
+ _: &ToggleBurnMode,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(thread) = self.as_native_thread(cx) else {
+ return;
+ };
+
+ thread.update(cx, |thread, _cx| {
+ let current_mode = thread.completion_mode();
+ thread.set_completion_mode(match current_mode {
+ CompletionMode::Burn => CompletionMode::Normal,
+ CompletionMode::Normal => CompletionMode::Burn,
+ });
+ });
+ }
+
+ fn render_burn_mode_toggle(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
+ let thread = self.as_native_thread(cx)?.read(cx);
+
+ if !thread.model().supports_burn_mode() {
+ return None;
+ }
+
+ let active_completion_mode = thread.completion_mode();
+ let burn_mode_enabled = active_completion_mode == CompletionMode::Burn;
+ let icon = if burn_mode_enabled {
+ IconName::ZedBurnModeOn
+ } else {
+ IconName::ZedBurnMode
+ };
+
+ Some(
+ IconButton::new("burn-mode", icon)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .toggle_state(burn_mode_enabled)
+ .selected_icon_color(Color::Error)
+ .on_click(cx.listener(|this, _event, window, cx| {
+ this.toggle_burn_mode(&ToggleBurnMode, window, cx);
+ }))
+ .tooltip(move |_window, cx| {
+ cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled))
+ .into()
+ })
+ .into_any_element(),
+ )
+ }
+
fn start_editing_message(&mut self, index: usize, window: &mut Window, cx: &mut Context<Self>) {
let Some(thread) = self.thread() else {
return;
@@ -3002,6 +3116,187 @@ impl AcpThreadView {
}
}
+impl AcpThreadView {
+ fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
+ let content = match self.thread_error.as_ref()? {
+ ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
+ ThreadError::PaymentRequired => self.render_payment_required_error(cx),
+ ThreadError::ModelRequestLimitReached(plan) => {
+ self.render_model_request_limit_reached_error(*plan, cx)
+ }
+ ThreadError::ToolUseLimitReached => {
+ self.render_tool_use_limit_reached_error(window, cx)?
+ }
+ };
+
+ Some(
+ div()
+ .border_t_1()
+ .border_color(cx.theme().colors().border)
+ .child(content),
+ )
+ }
+
+ fn render_any_thread_error(&self, error: SharedString, cx: &mut Context<'_, Self>) -> Callout {
+ let icon = Icon::new(IconName::XCircle)
+ .size(IconSize::Small)
+ .color(Color::Error);
+
+ Callout::new()
+ .icon(icon)
+ .title("Error")
+ .description(error.clone())
+ .secondary_action(self.create_copy_button(error.to_string()))
+ .primary_action(self.dismiss_error_button(cx))
+ .bg_color(self.error_callout_bg(cx))
+ }
+
+ fn render_payment_required_error(&self, cx: &mut Context<Self>) -> Callout {
+ const ERROR_MESSAGE: &str =
+ "You reached your free usage limit. Upgrade to Zed Pro for more prompts.";
+
+ let icon = Icon::new(IconName::XCircle)
+ .size(IconSize::Small)
+ .color(Color::Error);
+
+ Callout::new()
+ .icon(icon)
+ .title("Free Usage Exceeded")
+ .description(ERROR_MESSAGE)
+ .tertiary_action(self.upgrade_button(cx))
+ .secondary_action(self.create_copy_button(ERROR_MESSAGE))
+ .primary_action(self.dismiss_error_button(cx))
+ .bg_color(self.error_callout_bg(cx))
+ }
+
+ fn render_model_request_limit_reached_error(
+ &self,
+ plan: cloud_llm_client::Plan,
+ cx: &mut Context<Self>,
+ ) -> Callout {
+ let error_message = match plan {
+ cloud_llm_client::Plan::ZedPro => "Upgrade to usage-based billing for more prompts.",
+ cloud_llm_client::Plan::ZedProTrial | cloud_llm_client::Plan::ZedFree => {
+ "Upgrade to Zed Pro for more prompts."
+ }
+ };
+
+ let icon = Icon::new(IconName::XCircle)
+ .size(IconSize::Small)
+ .color(Color::Error);
+
+ Callout::new()
+ .icon(icon)
+ .title("Model Prompt Limit Reached")
+ .description(error_message)
+ .tertiary_action(self.upgrade_button(cx))
+ .secondary_action(self.create_copy_button(error_message))
+ .primary_action(self.dismiss_error_button(cx))
+ .bg_color(self.error_callout_bg(cx))
+ }
+
+ fn render_tool_use_limit_reached_error(
+ &self,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Option<Callout> {
+ let thread = self.as_native_thread(cx)?;
+ let supports_burn_mode = thread.read(cx).model().supports_burn_mode();
+
+ let focus_handle = self.focus_handle(cx);
+
+ let icon = Icon::new(IconName::Info)
+ .size(IconSize::Small)
+ .color(Color::Info);
+
+ Some(
+ Callout::new()
+ .icon(icon)
+ .title("Consecutive tool use limit reached.")
+ .when(supports_burn_mode, |this| {
+ this.secondary_action(
+ Button::new("continue-burn-mode", "Continue with Burn Mode")
+ .style(ButtonStyle::Filled)
+ .style(ButtonStyle::Tinted(ui::TintColor::Accent))
+ .layer(ElevationIndex::ModalSurface)
+ .label_size(LabelSize::Small)
+ .key_binding(
+ KeyBinding::for_action_in(
+ &ContinueWithBurnMode,
+ &focus_handle,
+ window,
+ cx,
+ )
+ .map(|kb| kb.size(rems_from_px(10.))),
+ )
+ .tooltip(Tooltip::text("Enable Burn Mode for unlimited tool use."))
+ .on_click({
+ cx.listener(move |this, _, _window, cx| {
+ thread.update(cx, |thread, _cx| {
+ thread.set_completion_mode(CompletionMode::Burn);
+ });
+ this.resume_chat(cx);
+ })
+ }),
+ )
+ })
+ .primary_action(
+ Button::new("continue-conversation", "Continue")
+ .layer(ElevationIndex::ModalSurface)
+ .label_size(LabelSize::Small)
+ .key_binding(
+ KeyBinding::for_action_in(&ContinueThread, &focus_handle, window, cx)
+ .map(|kb| kb.size(rems_from_px(10.))),
+ )
+ .on_click(cx.listener(|this, _, _window, cx| {
+ this.resume_chat(cx);
+ })),
+ ),
+ )
+ }
+
+ fn create_copy_button(&self, message: impl Into<String>) -> impl IntoElement {
+ let message = message.into();
+
+ IconButton::new("copy", IconName::Copy)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .tooltip(Tooltip::text("Copy Error Message"))
+ .on_click(move |_, _, cx| {
+ cx.write_to_clipboard(ClipboardItem::new_string(message.clone()))
+ })
+ }
+
+ fn dismiss_error_button(&self, cx: &mut Context<Self>) -> impl IntoElement {
+ IconButton::new("dismiss", IconName::Close)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .tooltip(Tooltip::text("Dismiss Error"))
+ .on_click(cx.listener({
+ move |this, _, _, cx| {
+ this.clear_thread_error(cx);
+ cx.notify();
+ }
+ }))
+ }
+
+ fn upgrade_button(&self, cx: &mut Context<Self>) -> impl IntoElement {
+ Button::new("upgrade", "Upgrade")
+ .label_size(LabelSize::Small)
+ .style(ButtonStyle::Tinted(ui::TintColor::Accent))
+ .on_click(cx.listener({
+ move |this, _, _, cx| {
+ this.clear_thread_error(cx);
+ cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx));
+ }
+ }))
+ }
+
+ fn error_callout_bg(&self, cx: &Context<Self>) -> Hsla {
+ cx.theme().status().error.opacity(0.08)
+ }
+}
+
impl Focusable for AcpThreadView {
fn focus_handle(&self, cx: &App) -> FocusHandle {
self.message_editor.focus_handle(cx)
@@ -3016,6 +3311,7 @@ impl Render for AcpThreadView {
.size_full()
.key_context("AcpThread")
.on_action(cx.listener(Self::open_agent_diff))
+ .on_action(cx.listener(Self::toggle_burn_mode))
.bg(cx.theme().colors().panel_background)
.child(match &self.thread_state {
ThreadState::Unauthenticated { connection } => v_flex()
@@ -3100,19 +3396,7 @@ impl Render for AcpThreadView {
}
_ => this,
})
- .when_some(self.last_error.clone(), |el, error| {
- el.child(
- div()
- .p_2()
- .text_xs()
- .border_t_1()
- .border_color(cx.theme().colors().border)
- .bg(cx.theme().status().error_background)
- .child(
- self.render_markdown(error, default_markdown_style(false, window, cx)),
- ),
- )
- })
+ .children(self.render_thread_error(window, cx))
.child(self.render_message_editor(window, cx))
}
}
@@ -3299,8 +3583,6 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
#[cfg(test)]
pub(crate) mod tests {
- use std::path::Path;
-
use acp_thread::StubAgentConnection;
use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol::SessionId;
@@ -3310,6 +3592,8 @@ pub(crate) mod tests {
use project::Project;
use serde_json::json;
use settings::SettingsStore;
+ use std::any::Any;
+ use std::path::Path;
use super::*;
@@ -3547,6 +3831,10 @@ pub(crate) mod tests {
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
unimplemented!()
}
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
}
pub(crate) fn init_test(cx: &mut TestAppContext) {
@@ -5,7 +5,6 @@ mod agent_diff;
mod agent_model_selector;
mod agent_panel;
mod buffer_codegen;
-mod burn_mode_tooltip;
mod context_picker;
mod context_server_configuration;
mod context_strip;
@@ -1,61 +0,0 @@
-use gpui::{Context, FontWeight, IntoElement, Render, Window};
-use ui::{prelude::*, tooltip_container};
-
-pub struct BurnModeTooltip {
- selected: bool,
-}
-
-impl BurnModeTooltip {
- pub fn new() -> Self {
- Self { selected: false }
- }
-
- pub fn selected(mut self, selected: bool) -> Self {
- self.selected = selected;
- self
- }
-}
-
-impl Render for BurnModeTooltip {
- fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let (icon, color) = if self.selected {
- (IconName::ZedBurnModeOn, Color::Error)
- } else {
- (IconName::ZedBurnMode, Color::Default)
- };
-
- let turned_on = h_flex()
- .h_4()
- .px_1()
- .border_1()
- .border_color(cx.theme().colors().border)
- .bg(cx.theme().colors().text_accent.opacity(0.1))
- .rounded_sm()
- .child(
- Label::new("ON")
- .size(LabelSize::XSmall)
- .weight(FontWeight::SEMIBOLD)
- .color(Color::Accent),
- );
-
- let title = h_flex()
- .gap_1p5()
- .child(Icon::new(icon).size(IconSize::Small).color(color))
- .child(Label::new("Burn Mode"))
- .when(self.selected, |title| title.child(turned_on));
-
- tooltip_container(window, cx, |this, _, _| {
- this
- .child(title)
- .child(
- div()
- .max_w_64()
- .child(
- Label::new("Enables models to use large context windows, unlimited tool calls, and other capabilities for expanded reasoning.")
- .size(LabelSize::Small)
- .color(Color::Muted)
- )
- )
- })
- }
-}
@@ -6,7 +6,7 @@ use crate::agent_diff::AgentDiffThread;
use crate::agent_model_selector::AgentModelSelector;
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use crate::ui::{
- MaxModeTooltip,
+ BurnModeTooltip,
preview::{AgentPreview, UsageCallout},
};
use agent::history_store::HistoryStore;
@@ -605,7 +605,7 @@ impl MessageEditor {
this.toggle_burn_mode(&ToggleBurnMode, window, cx);
}))
.tooltip(move |_window, cx| {
- cx.new(|_| MaxModeTooltip::new().selected(burn_mode_enabled))
+ cx.new(|_| BurnModeTooltip::new().selected(burn_mode_enabled))
.into()
})
.into_any_element(),
@@ -1,6 +1,6 @@
use crate::{
- burn_mode_tooltip::BurnModeTooltip,
language_model_selector::{LanguageModelSelector, language_model_selector},
+ ui::BurnModeTooltip,
};
use agent_settings::{AgentSettings, CompletionMode};
use anyhow::Result;
@@ -2,11 +2,11 @@ use crate::ToggleBurnMode;
use gpui::{Context, FontWeight, IntoElement, Render, Window};
use ui::{KeyBinding, prelude::*, tooltip_container};
-pub struct MaxModeTooltip {
+pub struct BurnModeTooltip {
selected: bool,
}
-impl MaxModeTooltip {
+impl BurnModeTooltip {
pub fn new() -> Self {
Self { selected: false }
}
@@ -17,7 +17,7 @@ impl MaxModeTooltip {
}
}
-impl Render for MaxModeTooltip {
+impl Render for BurnModeTooltip {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let (icon, color) = if self.selected {
(IconName::ZedBurnModeOn, Color::Error)
@@ -42,6 +42,18 @@ impl fmt::Display for ModelRequestLimitReachedError {
}
}
+#[derive(Error, Debug)]
+pub struct ToolUseLimitReachedError;
+
+impl fmt::Display for ToolUseLimitReachedError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "Consecutive tool use limit reached. Enable Burn Mode for unlimited tool use."
+ )
+ }
+}
+
#[derive(Clone, Default)]
pub struct LlmApiToken(Arc<RwLock<Option<String>>>);