WIP

Antonio Scandurra created

Change summary

crates/agent/src/agent.rs                                 |   2 
crates/agent/src/thread.rs                                | 544 ++++++++
crates/agent_ui/src/agent_ui.rs                           |   4 
crates/agent_ui/src/context_picker/completion_provider.rs |   4 
crates/agent_ui/src/profile_selector.rs                   |   6 
crates/agent_ui/src/tool_compatibility.rs                 |   6 
6 files changed, 552 insertions(+), 14 deletions(-)

Detailed changes

crates/agent/src/agent.rs 🔗

@@ -10,7 +10,7 @@ pub use context::{AgentContext, ContextId, ContextLoadResult};
 pub use context_store::ContextStore;
 pub use thread::{
     LastRestoreCheckpoint, Message, MessageCrease, MessageId, MessageSegment, ThreadError,
-    ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio, ZedAgent,
+    ThreadEvent, ThreadFeedback, ThreadId, ThreadSummary, TokenUsageRatio, ZedAgentThread,
 };
 pub use thread_store::{SerializedThread, TextThreadStore, ThreadStore};
 

crates/agent/src/thread.rs 🔗

@@ -19,8 +19,9 @@ use collections::{HashMap, HashSet};
 use feature_flags::{self, FeatureFlagAppExt};
 use futures::{
     FutureExt, StreamExt as _,
-    channel::oneshot,
-    future::{Either, Shared},
+    channel::{mpsc, oneshot},
+    future::{BoxFuture, Either, LocalBoxFuture, Shared},
+    stream::{BoxStream, LocalBoxStream},
 };
 use git::repository::DiffType;
 use gpui::{
@@ -46,7 +47,7 @@ use proto::Plan;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::Settings;
-use std::fmt::Write;
+use std::{collections::VecDeque, fmt::Write};
 use std::{
     ops::Range,
     sync::Arc,
@@ -980,6 +981,33 @@ impl<T: Into<String>> From<T> for UserMessageParams {
     }
 }
 
+pub struct Turn {
+    user_message_id: MessageId,
+    response_events: LocalBoxStream<'static, Result<ResponseEvent>>,
+}
+
+struct ToolCallResult {
+    task: Task<Result<()>>,
+    card: Option<AnyToolCard>,
+}
+
+pub enum ResponseEvent {
+    Text(String),
+    Thinking(String),
+    ToolCallChunk {
+        id: LanguageModelToolUseId,
+        label: String,
+        input: serde_json::Value,
+    },
+    ToolCall {
+        id: LanguageModelToolUseId,
+        needs_confirmation: bool,
+        label: String,
+        run: Box<dyn FnOnce(Option<AnyWindowHandle>, &mut App) -> ToolCallResult>,
+    },
+    InvalidToolCallChunk(LanguageModelToolUse),
+}
+
 impl ZedAgentThread {
     pub fn new(
         project: Entity<Project>,
@@ -1610,6 +1638,516 @@ impl ZedAgentThread {
         }
     }
 
+    pub fn send_message2(
+        &mut self,
+        user_message: impl Into<UserMessageParams>,
+        model: Arc<dyn LanguageModel>,
+        window: Option<AnyWindowHandle>,
+        cx: &mut Context<Self>,
+    ) -> LocalBoxFuture<'static, Result<Turn>> {
+        self.advance_prompt_id();
+
+        let user_message = user_message.into();
+        let prev_turn = self.cancel();
+        let (cancel_tx, cancel_rx) = oneshot::channel();
+        let (turn_tx, turn_rx) = oneshot::channel();
+        self.pending_turn = Some(PendingTurn {
+            task: cx.spawn(async move |this, cx| {
+                if let Some(prev_turn) = prev_turn {
+                    prev_turn.await?;
+                }
+
+                let user_message_id =
+                    this.update(cx, |this, cx| this.insert_user_message(user_message, cx))?;
+                let (response_events_tx, response_events_rx) = mpsc::unbounded();
+                turn_tx
+                    .send(Turn {
+                        user_message_id,
+                        response_events: response_events_rx.boxed_local(),
+                    })
+                    .ok();
+
+                Self::turn_loop2(
+                    &this,
+                    model,
+                    CompletionIntent::UserPrompt,
+                    cancel_rx,
+                    response_events_tx,
+                    window,
+                    cx,
+                )
+                .await?;
+
+                this.update(cx, |this, _cx| this.pending_turn.take()).ok();
+
+                Ok(())
+            }),
+            cancel_tx,
+        });
+
+        async move { turn_rx.await.map_err(|_| anyhow!("Turn loop failed")) }.boxed_local()
+    }
+
+    async fn turn_loop2(
+        this: &WeakEntity<Self>,
+        model: Arc<dyn LanguageModel>,
+        mut intent: CompletionIntent,
+        mut cancel_rx: oneshot::Receiver<()>,
+        mut response_events_tx: mpsc::UnboundedSender<Result<ResponseEvent>>,
+        window: Option<AnyWindowHandle>,
+        cx: &mut AsyncApp,
+    ) -> Result<()> {
+        struct RetryState {
+            attempts: u8,
+            custom_delay: Option<Duration>,
+        }
+        let mut retry_state: Option<RetryState> = None;
+
+        struct PendingAssistantMessage {
+            chunks: VecDeque<PendingAssistantMessageChunk>,
+        }
+
+        impl PendingAssistantMessage {
+            fn push_text(&mut self, text: String) {
+                if let Some(PendingAssistantMessageChunk::Text(existing_text)) =
+                    self.chunks.back_mut()
+                {
+                    existing_text.push_str(&text);
+                } else {
+                    self.chunks
+                        .push_back(PendingAssistantMessageChunk::Text(text));
+                }
+            }
+
+            fn push_thinking(&mut self, text: String, signature: Option<String>) {
+                if let Some(PendingAssistantMessageChunk::Thinking {
+                    text: existing_text,
+                    signature: existing_signature,
+                }) = self.chunks.back_mut()
+                {
+                    *existing_signature = existing_signature.take().or(signature);
+                    existing_text.push_str(&text);
+                } else {
+                    self.chunks
+                        .push_back(PendingAssistantMessageChunk::Thinking { text, signature });
+                }
+            }
+        }
+
+        enum PendingAssistantMessageChunk {
+            Text(String),
+            Thinking {
+                text: String,
+                signature: Option<String>,
+            },
+            RedactedThinking {
+                data: String,
+            },
+            ToolCall(PendingAssistantToolCall),
+        }
+
+        struct PendingAssistantToolCall {
+            request: LanguageModelToolUse,
+            output: oneshot::Receiver<Result<ToolResultOutput>>,
+        }
+
+        loop {
+            let mut segments = Vec::new();
+            let mut assistant_message = PendingAssistantMessage {
+                chunks: VecDeque::new(),
+            };
+
+            let send = async {
+                if let Some(retry_state) = retry_state.as_ref() {
+                    let delay = retry_state.custom_delay.unwrap_or_else(|| {
+                        BASE_RETRY_DELAY * 2_u32.pow((retry_state.attempts - 1) as u32)
+                    });
+                    cx.background_executor().timer(delay).await;
+                }
+
+                let request = this.update(cx, |this, cx| this.build_request(&model, intent, cx))?;
+                let mut events = model.stream_completion(request.clone(), cx).await?;
+
+                while let Some(event) = events.next().await {
+                    let event = event?;
+                    match event {
+                        LanguageModelCompletionEvent::StartMessage { .. } => {
+                            // no-op, todo!("do we wanna insert a new message here?")
+                        }
+                        LanguageModelCompletionEvent::Text(chunk) => {
+                            response_events_tx
+                                .unbounded_send(Ok(ResponseEvent::Text(chunk.clone())));
+                            assistant_message.push_text(chunk);
+                        }
+                        LanguageModelCompletionEvent::Thinking { text, signature } => {
+                            response_events_tx
+                                .unbounded_send(Ok(ResponseEvent::Thinking(text.clone())));
+                            assistant_message.push_thinking(text, signature);
+                        }
+                        LanguageModelCompletionEvent::RedactedThinking { data } => {
+                            assistant_message
+                                .chunks
+                                .push_back(PendingAssistantMessageChunk::RedactedThinking { data });
+                        }
+                        LanguageModelCompletionEvent::ToolUse(tool_use) => {
+                            match this
+                                .read_with(cx, |this, cx| this.tool_for_name(&tool_use.name, cx))?
+                            {
+                                Ok(tool) => {
+                                    if tool_use.is_input_complete {
+                                        let (output_tx, output_rx) = oneshot::channel();
+                                        let mut request = request.clone();
+                                        // todo!("add the pending assistant message (excluding the tool calls)")
+                                        response_events_tx.unbounded_send(Ok(
+                                            ResponseEvent::ToolCall {
+                                                id: tool_use.id,
+                                                needs_confirmation: cx.update(|cx| {
+                                                    tool.needs_confirmation(&tool_use.input, cx)
+                                                })?,
+                                                label: tool.ui_text(&tool_use.input),
+                                                run: Box::new({
+                                                    let project = this
+                                                        .read_with(cx, |this, _| {
+                                                            this.project.clone()
+                                                        })?;
+                                                    let action_log = this
+                                                        .read_with(cx, |this, _| {
+                                                            this.action_log.clone()
+                                                        })?;
+                                                    move |window, cx| {
+                                                        let assistant_tool::ToolResult {
+                                                            output,
+                                                            card,
+                                                        } = tool.run(
+                                                            tool_use.input,
+                                                            Arc::new(request),
+                                                            project,
+                                                            action_log,
+                                                            model,
+                                                            window,
+                                                            cx,
+                                                        );
+
+                                                        ToolCallResult {
+                                                            task: cx.foreground_executor().spawn(
+                                                                async move {
+                                                                    match output.await {
+                                                                        Ok(output) => {
+                                                                            output_tx
+                                                                                .send(Ok(output))
+                                                                                .ok();
+                                                                            Ok(())
+                                                                        }
+                                                                        Err(error) => {
+                                                                            let error =
+                                                                                Arc::new(error);
+                                                                            output_tx
+                                                                                .send(Err(anyhow!(
+                                                                                    error.clone()
+                                                                                )))
+                                                                                .ok();
+                                                                            Err(anyhow!(error))
+                                                                        }
+                                                                    }
+                                                                },
+                                                            ),
+                                                            card,
+                                                        }
+                                                    }
+                                                }),
+                                            },
+                                        ));
+                                        assistant_message.chunks.push_back(
+                                            PendingAssistantMessageChunk::ToolCall(
+                                                PendingAssistantToolCall {
+                                                    request: tool_use,
+                                                    output: output_rx,
+                                                },
+                                            ),
+                                        );
+                                    } else {
+                                        response_events_tx.unbounded_send(Ok(
+                                            ResponseEvent::ToolCallChunk {
+                                                id: tool_use.id,
+                                                label: tool
+                                                    .still_streaming_ui_text(&tool_use.input),
+                                                input: tool_use.input,
+                                            },
+                                        ));
+                                    }
+                                }
+                                Err(error) => {
+                                    response_events_tx.unbounded_send(Ok(
+                                        ResponseEvent::InvalidToolCallChunk(tool_use.clone()),
+                                    ));
+                                    if tool_use.is_input_complete {
+                                        let (output_tx, output_rx) = oneshot::channel();
+                                        output_tx.send(Err(error)).unwrap();
+                                        assistant_message.chunks.push_back(
+                                            PendingAssistantMessageChunk::ToolCall(
+                                                PendingAssistantToolCall {
+                                                    request: tool_use,
+                                                    output: output_rx,
+                                                },
+                                            ),
+                                        );
+                                    }
+                                }
+                            }
+                        }
+                        LanguageModelCompletionEvent::UsageUpdate(_token_usage) => {
+                            // todo!
+                        }
+                        LanguageModelCompletionEvent::StatusUpdate(_completion_request_status) => {
+                            // todo!
+                        }
+                        LanguageModelCompletionEvent::Stop(StopReason::EndTurn) => {
+                            // todo!
+                        }
+                        LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) => {
+                            // todo!
+                        }
+                        LanguageModelCompletionEvent::Stop(StopReason::Refusal) => {
+                            // todo!
+                        }
+                        LanguageModelCompletionEvent::Stop(StopReason::ToolUse) => {}
+                    }
+                }
+
+                while let Some(chunk) = assistant_message.chunks.pop_front() {
+                    match chunk {
+                        PendingAssistantMessageChunk::Text(_) => todo!(),
+                        PendingAssistantMessageChunk::Thinking { text, signature } => todo!(),
+                        PendingAssistantMessageChunk::RedactedThinking { data } => todo!(),
+                        PendingAssistantMessageChunk::ToolCall(pending_assistant_tool_call) => {
+                            pending_assistant_tool_call.output.await;
+                        }
+                    }
+
+                    let (tool_result, thread_result) = pending_tool_use.result().await;
+                    this.update(cx, |thread, cx| {
+                        thread.set_tool_call_result(
+                            pending_tool_use.index_in_message,
+                            thread_result,
+                            cx,
+                        )
+                    })?;
+                    assistant_message.push(MessageContent::ToolUse(pending_tool_use.request));
+                    tool_results_message
+                        .content
+                        .push(MessageContent::ToolResult(tool_result));
+                }
+
+                anyhow::Ok(())
+            }
+            .boxed_local();
+
+            enum SendStatus {
+                Canceled,
+                Finished(Result<()>),
+            }
+
+            let status = match futures::future::select(&mut cancel_rx, send).await {
+                Either::Left(_) => SendStatus::Canceled,
+                Either::Right((result, _)) => SendStatus::Finished(result),
+            };
+
+            match status {
+                SendStatus::Canceled => {
+                    for pending_tool_use in pending_tool_uses {
+                        tool_results_message
+                            .content
+                            .push(MessageContent::ToolResult(LanguageModelToolResult {
+                                tool_use_id: pending_tool_use.request.id.clone(),
+                                tool_name: pending_tool_use.request.name.clone(),
+                                is_error: true,
+                                content: LanguageModelToolResultContent::Text(
+                                    "<User cancelled tool use>".into(),
+                                ),
+                                output: None,
+                            }));
+                        assistant_message.push(MessageContent::ToolUse(pending_tool_use.request));
+                    }
+
+                    this.update(cx, |this, _cx| {
+                        if !assistant_message.content.is_empty() {
+                            this.messages.push(assistant_message);
+                        }
+
+                        if !tool_results_message.content.is_empty() {
+                            this.messages.push(tool_results_message);
+                        }
+                    })?;
+
+                    break;
+                }
+                SendStatus::Finished(result) => {
+                    for mut pending_tool_use in pending_tool_uses {
+                        let (tool_result, thread_result) = pending_tool_use.result().await;
+                        this.update(cx, |thread, cx| {
+                            thread.set_tool_call_result(
+                                pending_tool_use.index_in_message,
+                                thread_result,
+                                cx,
+                            )
+                        })?;
+                        assistant_message.push(MessageContent::ToolUse(pending_tool_use.request));
+                        tool_results_message
+                            .content
+                            .push(MessageContent::ToolResult(tool_result));
+                    }
+
+                    match result {
+                        Ok(_) => {
+                            retry_state = None;
+                        }
+                        Err(error) => {
+                            let mut retry = |custom_delay: Option<Duration>| -> bool {
+                                let retry_state = retry_state.get_or_insert_with(|| RetryState {
+                                    attempts: 0,
+                                    custom_delay,
+                                });
+                                retry_state.attempts += 1;
+                                retry_state.attempts <= MAX_RETRY_ATTEMPTS
+                            };
+
+                            if error.is::<PaymentRequiredError>() {
+                                // todo!
+                                // cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
+                            } else if let Some(_error) =
+                                error.downcast_ref::<ModelRequestLimitReachedError>()
+                            {
+                                // todo!
+                                // cx.emit(ThreadEvent::ShowError(
+                                //     ThreadError::ModelRequestLimitReached { plan: error.plan },
+                                // ));
+                            } else if let Some(completion_error) =
+                                error.downcast_ref::<LanguageModelCompletionError>()
+                            {
+                                match completion_error {
+                                    LanguageModelCompletionError::RateLimitExceeded {
+                                        retry_after,
+                                    } => {
+                                        if !retry(Some(*retry_after)) {
+                                            break;
+                                        }
+                                    }
+                                    LanguageModelCompletionError::Overloaded => {
+                                        if !retry(None) {
+                                            break;
+                                        }
+                                    }
+                                    LanguageModelCompletionError::ApiInternalServerError => {
+                                        if !retry(None) {
+                                            break;
+                                        }
+                                        // todo!
+                                    }
+                                    _ => {
+                                        // todo!(emit_generic_error(error, cx);)
+                                        break;
+                                    }
+                                }
+                            } else if let Some(known_error) =
+                                error.downcast_ref::<LanguageModelKnownError>()
+                            {
+                                match known_error {
+                                    LanguageModelKnownError::ContextWindowLimitExceeded {
+                                        tokens: _,
+                                    } => {
+                                        // todo!
+                                        // this.exceeded_window_error =
+                                        //     Some(ExceededWindowError {
+                                        //         model_id: model.id(),
+                                        //         token_count: *tokens,
+                                        //     });
+                                        // cx.notify();
+                                        break;
+                                    }
+                                    LanguageModelKnownError::RateLimitExceeded { retry_after } => {
+                                        // let provider_name = model.provider_name();
+                                        // let error_message = format!(
+                                        //     "{}'s API rate limit exceeded",
+                                        //     provider_name.0.as_ref()
+                                        // );
+                                        if !retry(Some(*retry_after)) {
+                                            // todo! show err
+                                            break;
+                                        }
+                                    }
+                                    LanguageModelKnownError::Overloaded => {
+                                        //todo!
+                                        // let provider_name = model.provider_name();
+                                        // let error_message = format!(
+                                        //     "{}'s API servers are overloaded right now",
+                                        //     provider_name.0.as_ref()
+                                        // );
+
+                                        if !retry(None) {
+                                            // todo! show err
+                                            break;
+                                        }
+                                    }
+                                    LanguageModelKnownError::ApiInternalServerError => {
+                                        // let provider_name = model.provider_name();
+                                        // let error_message = format!(
+                                        //     "{}'s API server reported an internal server error",
+                                        //     provider_name.0.as_ref()
+                                        // );
+
+                                        if !retry(None) {
+                                            break;
+                                        }
+                                    }
+                                    LanguageModelKnownError::ReadResponseError(_)
+                                    | LanguageModelKnownError::DeserializeResponse(_)
+                                    | LanguageModelKnownError::UnknownResponseFormat(_) => {
+                                        // In the future we will attempt to re-roll response, but only once
+                                        // todo!(emit_generic_error(error, cx);)
+                                        break;
+                                    }
+                                }
+                            } else {
+                                // todo!(emit_generic_error(error, cx));
+                                break;
+                            }
+                        }
+                    }
+
+                    let done = this.update(cx, |this, cx| {
+                        let done = if assistant_message.content.is_empty() {
+                            true
+                        } else {
+                            this.messages.push(assistant_message);
+                            if tool_results_message.content.is_empty() {
+                                true
+                            } else {
+                                this.messages.push(tool_results_message);
+                                false
+                            }
+                        };
+
+                        let summary_pending = matches!(this.summary(), ThreadSummary::Pending);
+
+                        if summary_pending && (done || this.messages.len() > 6) {
+                            this.summarize(cx);
+                        }
+
+                        done
+                    })?;
+
+                    if done && retry_state.is_none() {
+                        break;
+                    } else {
+                        intent = CompletionIntent::ToolResults;
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+
     pub fn send_message(
         &mut self,
         params: impl Into<UserMessageParams>,

crates/agent_ui/src/agent_ui.rs 🔗

@@ -26,7 +26,7 @@ mod ui;
 
 use std::sync::Arc;
 
-use agent::{ThreadId, ZedAgent};
+use agent::{ThreadId, ZedAgentThread};
 use agent_settings::{AgentProfileId, AgentSettings, LanguageModelSelection};
 use assistant_slash_command::SlashCommandRegistry;
 use client::Client;
@@ -114,7 +114,7 @@ impl ManageProfiles {
 
 #[derive(Clone)]
 pub(crate) enum ModelUsageContext {
-    Thread(Entity<ZedAgent>),
+    Thread(Entity<ZedAgentThread>),
     InlineAssistant,
 }
 

crates/agent_ui/src/context_picker/completion_provider.rs 🔗

@@ -22,7 +22,7 @@ use util::ResultExt as _;
 use workspace::Workspace;
 
 use agent::{
-    ZedAgent,
+    ZedAgentThread,
     context::{AgentContextHandle, AgentContextKey, RULES_ICON},
     thread_store::{TextThreadStore, ThreadStore},
 };
@@ -449,7 +449,7 @@ impl ContextPickerCompletionProvider {
                         let context_store = context_store.clone();
                         let thread_store = thread_store.clone();
                         window.spawn::<_, Option<_>>(cx, async move |cx| {
-                            let thread: Entity<ZedAgent> = thread_store
+                            let thread: Entity<ZedAgentThread> = thread_store
                                 .update_in(cx, |thread_store, window, cx| {
                                     thread_store.open_thread(&thread_id, window, cx)
                                 })

crates/agent_ui/src/profile_selector.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{ManageProfiles, ToggleProfileSelector};
 use agent::{
-    ZedAgent,
+    ZedAgentThread,
     agent_profile::{AgentProfile, AvailableProfiles},
 };
 use agent_settings::{AgentDockPosition, AgentProfileId, AgentSettings, builtin_profiles};
@@ -17,7 +17,7 @@ use ui::{
 pub struct ProfileSelector {
     profiles: AvailableProfiles,
     fs: Arc<dyn Fs>,
-    thread: Entity<ZedAgent>,
+    thread: Entity<ZedAgentThread>,
     menu_handle: PopoverMenuHandle<ContextMenu>,
     focus_handle: FocusHandle,
     _subscriptions: Vec<Subscription>,
@@ -26,7 +26,7 @@ pub struct ProfileSelector {
 impl ProfileSelector {
     pub fn new(
         fs: Arc<dyn Fs>,
-        thread: Entity<ZedAgent>,
+        thread: Entity<ZedAgentThread>,
         focus_handle: FocusHandle,
         cx: &mut Context<Self>,
     ) -> Self {

crates/agent_ui/src/tool_compatibility.rs 🔗

@@ -1,4 +1,4 @@
-use agent::{ThreadEvent, ZedAgent};
+use agent::{ThreadEvent, ZedAgentThread};
 use assistant_tool::{Tool, ToolSource};
 use collections::HashMap;
 use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window};
@@ -8,12 +8,12 @@ use ui::prelude::*;
 
 pub struct IncompatibleToolsState {
     cache: HashMap<LanguageModelToolSchemaFormat, Vec<Arc<dyn Tool>>>,
-    thread: Entity<ZedAgent>,
+    thread: Entity<ZedAgentThread>,
     _thread_subscription: Subscription,
 }
 
 impl IncompatibleToolsState {
-    pub fn new(thread: Entity<ZedAgent>, cx: &mut Context<Self>) -> Self {
+    pub fn new(thread: Entity<ZedAgentThread>, cx: &mut Context<Self>) -> Self {
         let _tool_working_set_subscription =
             cx.subscribe(&thread, |this, _, event, _| match event {
                 ThreadEvent::ProfileChanged => {