From 587ed1e3140741b446503ecbe824e50042dd7042 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 1 Jul 2025 12:52:08 +0200 Subject: [PATCH] WIP --- crates/agent/src/agent.rs | 2 +- crates/agent/src/thread.rs | 544 +++++++++++++++++- crates/agent_ui/src/agent_ui.rs | 4 +- .../src/context_picker/completion_provider.rs | 4 +- crates/agent_ui/src/profile_selector.rs | 6 +- crates/agent_ui/src/tool_compatibility.rs | 6 +- 6 files changed, 552 insertions(+), 14 deletions(-) diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 561e38a4369037da446c2bc3b7eb2e0f0d15539b..ccb20f0627292ff7db99f88204aec605e435bbdb 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -10,7 +10,7 @@ pub use context::{AgentContext, ContextId, ContextLoadResult}; pub use context_store::ContextStore; pub use thread::{ LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, ThreadError, - ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio, ZedAgent, + ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio, ZedAgentThread, }; pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore}; diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 33ecb6fd3afe392fa9534a578d2471986fefd969..d33d09505d222cc5a82edb5caf716aebc7e7e211 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -19,8 +19,9 @@ use collections::{HashMap, HashSet}; use feature_flags::{self, FeatureFlagAppExt}; use futures::{ FutureExt, StreamExt as _, - channel::oneshot, - future::{Either, Shared}, + channel::{mpsc, oneshot}, + future::{BoxFuture, Either, LocalBoxFuture, Shared}, + stream::{BoxStream, LocalBoxStream}, }; use git::repository::DiffType; use gpui::{ @@ -46,7 +47,7 @@ use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; -use std::fmt::Write; +use std::{collections::VecDeque, fmt::Write}; use std::{ ops::Range, sync::Arc, @@ -980,6 +981,33 @@ impl> From for UserMessageParams { } } +pub struct Turn { + user_message_id: MessageId, + response_events: LocalBoxStream<'static, Result>, +} + +struct ToolCallResult { + task: Task>, + card: Option, +} + +pub enum ResponseEvent { + Text(String), + Thinking(String), + ToolCallChunk { + id: LanguageModelToolUseId, + label: String, + input: serde_json::Value, + }, + ToolCall { + id: LanguageModelToolUseId, + needs_confirmation: bool, + label: String, + run: Box, &mut App) -> ToolCallResult>, + }, + InvalidToolCallChunk(LanguageModelToolUse), +} + impl ZedAgentThread { pub fn new( project: Entity, @@ -1610,6 +1638,516 @@ impl ZedAgentThread { } } + pub fn send_message2( + &mut self, + user_message: impl Into, + model: Arc, + window: Option, + cx: &mut Context, + ) -> LocalBoxFuture<'static, Result> { + self.advance_prompt_id(); + + let user_message = user_message.into(); + let prev_turn = self.cancel(); + let (cancel_tx, cancel_rx) = oneshot::channel(); + let (turn_tx, turn_rx) = oneshot::channel(); + self.pending_turn = Some(PendingTurn { + task: cx.spawn(async move |this, cx| { + if let Some(prev_turn) = prev_turn { + prev_turn.await?; + } + + let user_message_id = + this.update(cx, |this, cx| this.insert_user_message(user_message, cx))?; + let (response_events_tx, response_events_rx) = mpsc::unbounded(); + turn_tx + .send(Turn { + user_message_id, + response_events: response_events_rx.boxed_local(), + }) + .ok(); + + Self::turn_loop2( + &this, + model, + CompletionIntent::UserPrompt, + cancel_rx, + 
response_events_tx, + window, + cx, + ) + .await?; + + this.update(cx, |this, _cx| this.pending_turn.take()).ok(); + + Ok(()) + }), + cancel_tx, + }); + + async move { turn_rx.await.map_err(|_| anyhow!("Turn loop failed")) }.boxed_local() + } + + async fn turn_loop2( + this: &WeakEntity, + model: Arc, + mut intent: CompletionIntent, + mut cancel_rx: oneshot::Receiver<()>, + mut response_events_tx: mpsc::UnboundedSender>, + window: Option, + cx: &mut AsyncApp, + ) -> Result<()> { + struct RetryState { + attempts: u8, + custom_delay: Option, + } + let mut retry_state: Option = None; + + struct PendingAssistantMessage { + chunks: VecDeque, + } + + impl PendingAssistantMessage { + fn push_text(&mut self, text: String) { + if let Some(PendingAssistantMessageChunk::Text(existing_text)) = + self.chunks.back_mut() + { + existing_text.push_str(&text); + } else { + self.chunks + .push_back(PendingAssistantMessageChunk::Text(text)); + } + } + + fn push_thinking(&mut self, text: String, signature: Option) { + if let Some(PendingAssistantMessageChunk::Thinking { + text: existing_text, + signature: existing_signature, + }) = self.chunks.back_mut() + { + *existing_signature = existing_signature.take().or(signature); + existing_text.push_str(&text); + } else { + self.chunks + .push_back(PendingAssistantMessageChunk::Thinking { text, signature }); + } + } + } + + enum PendingAssistantMessageChunk { + Text(String), + Thinking { + text: String, + signature: Option, + }, + RedactedThinking { + data: String, + }, + ToolCall(PendingAssistantToolCall), + } + + struct PendingAssistantToolCall { + request: LanguageModelToolUse, + output: oneshot::Receiver>, + } + + loop { + let mut segments = Vec::new(); + let mut assistant_message = PendingAssistantMessage { + chunks: VecDeque::new(), + }; + + let send = async { + if let Some(retry_state) = retry_state.as_ref() { + let delay = retry_state.custom_delay.unwrap_or_else(|| { + BASE_RETRY_DELAY * 2_u32.pow((retry_state.attempts - 1) as u32) + }); + cx.background_executor().timer(delay).await; + } + + let request = this.update(cx, |this, cx| this.build_request(&model, intent, cx))?; + let mut events = model.stream_completion(request.clone(), cx).await?; + + while let Some(event) = events.next().await { + let event = event?; + match event { + LanguageModelCompletionEvent::StartMessage { .. } => { + // no-op, todo!("do we wanna insert a new message here?") + } + LanguageModelCompletionEvent::Text(chunk) => { + response_events_tx + .unbounded_send(Ok(ResponseEvent::Text(chunk.clone()))); + assistant_message.push_text(chunk); + } + LanguageModelCompletionEvent::Thinking { text, signature } => { + response_events_tx + .unbounded_send(Ok(ResponseEvent::Thinking(text.clone()))); + assistant_message.push_thinking(text, signature); + } + LanguageModelCompletionEvent::RedactedThinking { data } => { + assistant_message + .chunks + .push_back(PendingAssistantMessageChunk::RedactedThinking { data }); + } + LanguageModelCompletionEvent::ToolUse(tool_use) => { + match this + .read_with(cx, |this, cx| this.tool_for_name(&tool_use.name, cx))? 
+ { + Ok(tool) => { + if tool_use.is_input_complete { + let (output_tx, output_rx) = oneshot::channel(); + let mut request = request.clone(); + // todo!("add the pending assistant message (excluding the tool calls)") + response_events_tx.unbounded_send(Ok( + ResponseEvent::ToolCall { + id: tool_use.id, + needs_confirmation: cx.update(|cx| { + tool.needs_confirmation(&tool_use.input, cx) + })?, + label: tool.ui_text(&tool_use.input), + run: Box::new({ + let project = this + .read_with(cx, |this, _| { + this.project.clone() + })?; + let action_log = this + .read_with(cx, |this, _| { + this.action_log.clone() + })?; + move |window, cx| { + let assistant_tool::ToolResult { + output, + card, + } = tool.run( + tool_use.input, + Arc::new(request), + project, + action_log, + model, + window, + cx, + ); + + ToolCallResult { + task: cx.foreground_executor().spawn( + async move { + match output.await { + Ok(output) => { + output_tx + .send(Ok(output)) + .ok(); + Ok(()) + } + Err(error) => { + let error = + Arc::new(error); + output_tx + .send(Err(anyhow!( + error.clone() + ))) + .ok(); + Err(anyhow!(error)) + } + } + }, + ), + card, + } + } + }), + }, + )); + assistant_message.chunks.push_back( + PendingAssistantMessageChunk::ToolCall( + PendingAssistantToolCall { + request: tool_use, + output: output_rx, + }, + ), + ); + } else { + response_events_tx.unbounded_send(Ok( + ResponseEvent::ToolCallChunk { + id: tool_use.id, + label: tool + .still_streaming_ui_text(&tool_use.input), + input: tool_use.input, + }, + )); + } + } + Err(error) => { + response_events_tx.unbounded_send(Ok( + ResponseEvent::InvalidToolCallChunk(tool_use.clone()), + )); + if tool_use.is_input_complete { + let (output_tx, output_rx) = oneshot::channel(); + output_tx.send(Err(error)).unwrap(); + assistant_message.chunks.push_back( + PendingAssistantMessageChunk::ToolCall( + PendingAssistantToolCall { + request: tool_use, + output: output_rx, + }, + ), + ); + } + } + } + } + LanguageModelCompletionEvent::UsageUpdate(_token_usage) => { + // todo! + } + LanguageModelCompletionEvent::StatusUpdate(_completion_request_status) => { + // todo! + } + LanguageModelCompletionEvent::Stop(StopReason::EndTurn) => { + // todo! + } + LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) => { + // todo! + } + LanguageModelCompletionEvent::Stop(StopReason::Refusal) => { + // todo! 
+ } + LanguageModelCompletionEvent::Stop(StopReason::ToolUse) => {} + } + } + + while let Some(chunk) = assistant_message.chunks.pop_front() { + match chunk { + PendingAssistantMessageChunk::Text(_) => todo!(), + PendingAssistantMessageChunk::Thinking { text, signature } => todo!(), + PendingAssistantMessageChunk::RedactedThinking { data } => todo!(), + PendingAssistantMessageChunk::ToolCall(pending_assistant_tool_call) => { + pending_assistant_tool_call.output.await; + } + } + + let (tool_result, thread_result) = pending_tool_use.result().await; + this.update(cx, |thread, cx| { + thread.set_tool_call_result( + pending_tool_use.index_in_message, + thread_result, + cx, + ) + })?; + assistant_message.push(MessageContent::ToolUse(pending_tool_use.request)); + tool_results_message + .content + .push(MessageContent::ToolResult(tool_result)); + } + + anyhow::Ok(()) + } + .boxed_local(); + + enum SendStatus { + Canceled, + Finished(Result<()>), + } + + let status = match futures::future::select(&mut cancel_rx, send).await { + Either::Left(_) => SendStatus::Canceled, + Either::Right((result, _)) => SendStatus::Finished(result), + }; + + match status { + SendStatus::Canceled => { + for pending_tool_use in pending_tool_uses { + tool_results_message + .content + .push(MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: pending_tool_use.request.id.clone(), + tool_name: pending_tool_use.request.name.clone(), + is_error: true, + content: LanguageModelToolResultContent::Text( + "".into(), + ), + output: None, + })); + assistant_message.push(MessageContent::ToolUse(pending_tool_use.request)); + } + + this.update(cx, |this, _cx| { + if !assistant_message.content.is_empty() { + this.messages.push(assistant_message); + } + + if !tool_results_message.content.is_empty() { + this.messages.push(tool_results_message); + } + })?; + + break; + } + SendStatus::Finished(result) => { + for mut pending_tool_use in pending_tool_uses { + let (tool_result, thread_result) = pending_tool_use.result().await; + this.update(cx, |thread, cx| { + thread.set_tool_call_result( + pending_tool_use.index_in_message, + thread_result, + cx, + ) + })?; + assistant_message.push(MessageContent::ToolUse(pending_tool_use.request)); + tool_results_message + .content + .push(MessageContent::ToolResult(tool_result)); + } + + match result { + Ok(_) => { + retry_state = None; + } + Err(error) => { + let mut retry = |custom_delay: Option| -> bool { + let retry_state = retry_state.get_or_insert_with(|| RetryState { + attempts: 0, + custom_delay, + }); + retry_state.attempts += 1; + retry_state.attempts <= MAX_RETRY_ATTEMPTS + }; + + if error.is::() { + // todo! + // cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired)); + } else if let Some(_error) = + error.downcast_ref::() + { + // todo! + // cx.emit(ThreadEvent::ShowError( + // ThreadError::ModelRequestLimitReached { plan: error.plan }, + // )); + } else if let Some(completion_error) = + error.downcast_ref::() + { + match completion_error { + LanguageModelCompletionError::RateLimitExceeded { + retry_after, + } => { + if !retry(Some(*retry_after)) { + break; + } + } + LanguageModelCompletionError::Overloaded => { + if !retry(None) { + break; + } + } + LanguageModelCompletionError::ApiInternalServerError => { + if !retry(None) { + break; + } + // todo! 
+ } + _ => { + // todo!(emit_generic_error(error, cx);) + break; + } + } + } else if let Some(known_error) = + error.downcast_ref::() + { + match known_error { + LanguageModelKnownError::ContextWindowLimitExceeded { + tokens: _, + } => { + // todo! + // this.exceeded_window_error = + // Some(ExceededWindowError { + // model_id: model.id(), + // token_count: *tokens, + // }); + // cx.notify(); + break; + } + LanguageModelKnownError::RateLimitExceeded { retry_after } => { + // let provider_name = model.provider_name(); + // let error_message = format!( + // "{}'s API rate limit exceeded", + // provider_name.0.as_ref() + // ); + if !retry(Some(*retry_after)) { + // todo! show err + break; + } + } + LanguageModelKnownError::Overloaded => { + //todo! + // let provider_name = model.provider_name(); + // let error_message = format!( + // "{}'s API servers are overloaded right now", + // provider_name.0.as_ref() + // ); + + if !retry(None) { + // todo! show err + break; + } + } + LanguageModelKnownError::ApiInternalServerError => { + // let provider_name = model.provider_name(); + // let error_message = format!( + // "{}'s API server reported an internal server error", + // provider_name.0.as_ref() + // ); + + if !retry(None) { + break; + } + } + LanguageModelKnownError::ReadResponseError(_) + | LanguageModelKnownError::DeserializeResponse(_) + | LanguageModelKnownError::UnknownResponseFormat(_) => { + // In the future we will attempt to re-roll response, but only once + // todo!(emit_generic_error(error, cx);) + break; + } + } + } else { + // todo!(emit_generic_error(error, cx)); + break; + } + } + } + + let done = this.update(cx, |this, cx| { + let done = if assistant_message.content.is_empty() { + true + } else { + this.messages.push(assistant_message); + if tool_results_message.content.is_empty() { + true + } else { + this.messages.push(tool_results_message); + false + } + }; + + let summary_pending = matches!(this.summary(), ThreadSummary::Pending); + + if summary_pending && (done || this.messages.len() > 6) { + this.summarize(cx); + } + + done + })?; + + if done && retry_state.is_none() { + break; + } else { + intent = CompletionIntent::ToolResults; + } + } + } + } + + Ok(()) + } + pub fn send_message( &mut self, params: impl Into, diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 9076d1b27413d83ead22b8da01c2a92fb0c906c7..0198fa607a1e5eb04e0999481d7c144c2a92787d 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -26,7 +26,7 @@ mod ui; use std::sync::Arc; -use agent::{ThreadId, ZedAgent}; +use agent::{ThreadId, ZedAgentThread}; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection}; use assistant_slash_command::SlashCommandRegistry; use client::Client; @@ -114,7 +114,7 @@ impl ManageProfiles { #[derive(Clone)] pub(crate) enum ModelUsageContext { - Thread(Entity), + Thread(Entity), InlineAssistant, } diff --git a/crates/agent_ui/src/context_picker/completion_provider.rs b/crates/agent_ui/src/context_picker/completion_provider.rs index cfc85578ef724e7d81a8da783c889acfacc43107..f5c62c0abe52f17eb92393e5a89f361354a43f21 100644 --- a/crates/agent_ui/src/context_picker/completion_provider.rs +++ b/crates/agent_ui/src/context_picker/completion_provider.rs @@ -22,7 +22,7 @@ use util::ResultExt as _; use workspace::Workspace; use agent::{ - ZedAgent, + ZedAgentThread, context::{AgentContextHandle, AgentContextKey, RULES_ICON}, thread_store::{TextThreadStore, ThreadStore}, }; @@ -449,7 +449,7 @@ impl 
ContextPickerCompletionProvider { let context_store = context_store.clone(); let thread_store = thread_store.clone(); window.spawn::<_, Option<_>>(cx, async move |cx| { - let thread: Entity = thread_store + let thread: Entity = thread_store .update_in(cx, |thread_store, window, cx| { thread_store.open_thread(&thread_id, window, cx) }) diff --git a/crates/agent_ui/src/profile_selector.rs b/crates/agent_ui/src/profile_selector.rs index da76390eadd5417fcdf0f65f69bf734af7cb7c6d..fc28e512169b094259548dd9fd194eed45577c03 100644 --- a/crates/agent_ui/src/profile_selector.rs +++ b/crates/agent_ui/src/profile_selector.rs @@ -1,6 +1,6 @@ use crate::{ManageProfiles, ToggleProfileSelector}; use agent::{ - ZedAgent, + ZedAgentThread, agent_profile::{AgentProfile, AvailableProfiles}, }; use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles}; @@ -17,7 +17,7 @@ use ui::{ pub struct ProfileSelector { profiles: AvailableProfiles, fs: Arc, - thread: Entity, + thread: Entity, menu_handle: PopoverMenuHandle, focus_handle: FocusHandle, _subscriptions: Vec, @@ -26,7 +26,7 @@ pub struct ProfileSelector { impl ProfileSelector { pub fn new( fs: Arc, - thread: Entity, + thread: Entity, focus_handle: FocusHandle, cx: &mut Context, ) -> Self { diff --git a/crates/agent_ui/src/tool_compatibility.rs b/crates/agent_ui/src/tool_compatibility.rs index 0e25cb38ac07454551ef0376b75fcc4d8ff551c4..b51d145bb95c412d4ba9957d10e9c3e324113dd4 100644 --- a/crates/agent_ui/src/tool_compatibility.rs +++ b/crates/agent_ui/src/tool_compatibility.rs @@ -1,4 +1,4 @@ -use agent::{ThreadEvent, ZedAgent}; +use agent::{ThreadEvent, ZedAgentThread}; use assistant_tool::{Tool, ToolSource}; use collections::HashMap; use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window}; @@ -8,12 +8,12 @@ use ui::prelude::*; pub struct IncompatibleToolsState { cache: HashMap>>, - thread: Entity, + thread: Entity, _thread_subscription: Subscription, } impl IncompatibleToolsState { - pub fn new(thread: Entity, cx: &mut Context) -> Self { + pub fn new(thread: Entity, cx: &mut Context) -> Self { let _tool_working_set_subscription = cx.subscribe(&thread, |this, _, event, _| match event { ThreadEvent::ProfileChanged => {
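
Usage sketch (editor's note, not part of the patch): a minimal illustration of how a caller inside crates/agent might drive the new send_message2 API and consume ResponseEvents from the returned Turn. It assumes the Turn fields are reachable from the call site (or gain accessors), that a &str converts into UserMessageParams, and that the generic parameters elided in the patch read roughly as written below; all names and signatures outside the patch are assumptions, not the confirmed API.

use anyhow::Result;
use futures::StreamExt as _;

// Hypothetical driver: send one user message and stream the response.
// `thread`, `model`, and `cx` are assumed to come from surrounding agent code.
fn drive_turn(
    thread: gpui::Entity<ZedAgentThread>,
    model: std::sync::Arc<dyn language_model::LanguageModel>,
    cx: &mut gpui::App,
) {
    let turn = thread.update(cx, |thread, cx| {
        // Assumes &str -> UserMessageParams conversion exists.
        thread.send_message2("Summarize the open buffer", model, None, cx)
    });

    cx.spawn(async move |cx| {
        let mut turn = turn.await?;
        while let Some(event) = turn.response_events.next().await {
            match event? {
                ResponseEvent::Text(chunk) => print!("{chunk}"),
                ResponseEvent::Thinking(chunk) => log::trace!("thinking: {chunk}"),
                ResponseEvent::ToolCallChunk { label, .. } => {
                    log::trace!("streaming tool call: {label}");
                }
                ResponseEvent::ToolCall { label, needs_confirmation, run, .. } => {
                    // A real UI would prompt the user when `needs_confirmation`
                    // is true; this sketch just runs the tool immediately,
                    // without a window handle.
                    let _ = needs_confirmation;
                    let result = cx.update(|cx| run(None, cx))?;
                    result.task.await?;
                    log::info!("tool finished: {label}");
                }
                ResponseEvent::InvalidToolCallChunk(tool_use) => {
                    log::warn!("model requested unknown tool {}", tool_use.name);
                }
            }
        }
        anyhow::Ok(())
    })
    .detach();
}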