@@ -19,8 +19,9 @@ use collections::{HashMap, HashSet};
use feature_flags::{self, FeatureFlagAppExt};
use futures::{
FutureExt, StreamExt as _,
- channel::oneshot,
- future::{Either, Shared},
+ channel::{mpsc, oneshot},
+ future::{BoxFuture, Either, LocalBoxFuture, Shared},
+ stream::{BoxStream, LocalBoxStream},
};
use git::repository::DiffType;
use gpui::{
@@ -46,7 +47,7 @@ use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
-use std::fmt::Write;
+use std::{collections::VecDeque, fmt::Write};
use std::{
ops::Range,
sync::Arc,
@@ -980,6 +981,33 @@ impl<T: Into<String>> From<T> for UserMessageParams {
}
}
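+
+/// A single turn of conversation: the id of the user message that started
+/// it, plus the stream of events produced while the model responds.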
+pub struct Turn {
+ user_message_id: MessageId,
+ response_events: LocalBoxStream<'static, Result<ResponseEvent>>,
+}
+
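+/// What a `ResponseEvent::ToolCall`'s `run` callback returns: the task
+/// driving the tool to completion, plus an optional card for rendering the
+/// call in the UI.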
+struct ToolCallResult {
+ task: Task<Result<()>>,
+ card: Option<AnyToolCard>,
+}
+
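+/// Events streamed to the caller while the model responds: text and
+/// thinking deltas, previews of tool calls that are still streaming, and
+/// completed tool calls that are ready to run.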
+pub enum ResponseEvent {
+ Text(String),
+ Thinking(String),
+ ToolCallChunk {
+ id: LanguageModelToolUseId,
+ label: String,
+ input: serde_json::Value,
+ },
+ ToolCall {
+ id: LanguageModelToolUseId,
+ needs_confirmation: bool,
+ label: String,
+ run: Box<dyn FnOnce(Option<AnyWindowHandle>, &mut App) -> ToolCallResult>,
+ },
+ InvalidToolCallChunk(LanguageModelToolUse),
+}
+
impl ZedAgentThread {
pub fn new(
project: Entity<Project>,
@@ -1610,6 +1638,516 @@ impl ZedAgentThread {
}
}
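+ /// Experimental successor to `send_message`: cancels any in-flight turn,
+ /// inserts the user message, and drives the turn loop in a background
+ /// task. Resolves to a [`Turn`] whose `response_events` stream the caller
+ /// drains to observe the response, e.g.:
+ ///
+ /// ```ignore
+ /// let mut turn = thread.send_message2("hello", model, None, cx).await?;
+ /// while let Some(event) = turn.response_events.next().await { /* … */ }
+ /// ```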
+ pub fn send_message2(
+ &mut self,
+ user_message: impl Into<UserMessageParams>,
+ model: Arc<dyn LanguageModel>,
+ window: Option<AnyWindowHandle>,
+ cx: &mut Context<Self>,
+ ) -> LocalBoxFuture<'static, Result<Turn>> {
+ self.advance_prompt_id();
+
+ let user_message = user_message.into();
+ let prev_turn = self.cancel();
+ let (cancel_tx, cancel_rx) = oneshot::channel();
+ let (turn_tx, turn_rx) = oneshot::channel();
+ self.pending_turn = Some(PendingTurn {
+ task: cx.spawn(async move |this, cx| {
+ if let Some(prev_turn) = prev_turn {
+ prev_turn.await?;
+ }
+
+ let user_message_id =
+ this.update(cx, |this, cx| this.insert_user_message(user_message, cx))?;
+ let (response_events_tx, response_events_rx) = mpsc::unbounded();
+ turn_tx
+ .send(Turn {
+ user_message_id,
+ response_events: response_events_rx.boxed_local(),
+ })
+ .ok();
+
+ // Run the turn loop, making sure `pending_turn` is cleared even
+ // if the loop fails.
+ let result = Self::turn_loop2(
+ &this,
+ model,
+ CompletionIntent::UserPrompt,
+ cancel_rx,
+ response_events_tx,
+ window,
+ cx,
+ )
+ .await;
+
+ this.update(cx, |this, _cx| this.pending_turn.take()).ok();
+
+ result
+ }),
+ cancel_tx,
+ });
+
+ async move { turn_rx.await.map_err(|_| anyhow!("Turn loop failed")) }.boxed_local()
+ }
+
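+ /// Drives one turn to completion: streams completions from the model,
+ /// forwards events to the caller, awaits any tools the model invokes, and
+ /// retries transient provider errors with exponential backoff.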
+ async fn turn_loop2(
+ this: &WeakEntity<Self>,
+ model: Arc<dyn LanguageModel>,
+ mut intent: CompletionIntent,
+ mut cancel_rx: oneshot::Receiver<()>,
+ response_events_tx: mpsc::UnboundedSender<Result<ResponseEvent>>,
+ window: Option<AnyWindowHandle>,
+ cx: &mut AsyncApp,
+ ) -> Result<()> {
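+ // How many times the current request has been retried, plus any
+ // server-provided delay (e.g. from a rate-limit response).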
+ struct RetryState {
+ attempts: u8,
+ custom_delay: Option<Duration>,
+ }
+ let mut retry_state: Option<RetryState> = None;
+
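+ // Accumulates the streamed assistant response; consecutive text and
+ // thinking deltas are coalesced into single chunks.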
+ struct PendingAssistantMessage {
+ chunks: VecDeque<PendingAssistantMessageChunk>,
+ }
+
+ impl PendingAssistantMessage {
+ fn push_text(&mut self, text: String) {
+ if let Some(PendingAssistantMessageChunk::Text(existing_text)) =
+ self.chunks.back_mut()
+ {
+ existing_text.push_str(&text);
+ } else {
+ self.chunks
+ .push_back(PendingAssistantMessageChunk::Text(text));
+ }
+ }
+
+ fn push_thinking(&mut self, text: String, signature: Option<String>) {
+ if let Some(PendingAssistantMessageChunk::Thinking {
+ text: existing_text,
+ signature: existing_signature,
+ }) = self.chunks.back_mut()
+ {
+ *existing_signature = existing_signature.take().or(signature);
+ existing_text.push_str(&text);
+ } else {
+ self.chunks
+ .push_back(PendingAssistantMessageChunk::Thinking { text, signature });
+ }
+ }
+ }
+
+ enum PendingAssistantMessageChunk {
+ Text(String),
+ Thinking {
+ text: String,
+ signature: Option<String>,
+ },
+ RedactedThinking {
+ data: String,
+ },
+ ToolCall(PendingAssistantToolCall),
+ }
+
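+ // A tool call that was handed to the caller via `ResponseEvent::ToolCall`;
+ // `output` resolves once the tool finishes running.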
+ struct PendingAssistantToolCall {
+ request: LanguageModelToolUse,
+ output: oneshot::Receiver<Result<ToolResultOutput>>,
+ }
+
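+ // One iteration per completion request: stream the response, then either
+ // finish the turn, retry, or send tool results back to the model.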
+ loop {
+ // Whether this completion requested any tools; if it did, we loop
+ // again to send the tool results back to the model.
+ let mut used_tools = false;
+ let mut assistant_message = PendingAssistantMessage {
+ chunks: VecDeque::new(),
+ };
+
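+ // The streaming future for this request; raced against cancellation
+ // below.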
+ let send = async {
+ if let Some(retry_state) = retry_state.as_ref() {
+ let delay = retry_state.custom_delay.unwrap_or_else(|| {
+ BASE_RETRY_DELAY * 2_u32.pow((retry_state.attempts - 1) as u32)
+ });
+ cx.background_executor().timer(delay).await;
+ }
+
+ let request = this.update(cx, |this, cx| this.build_request(&model, intent, cx))?;
+ let mut events = model.stream_completion(request.clone(), cx).await?;
+
+ while let Some(event) = events.next().await {
+ let event = event?;
+ match event {
+ LanguageModelCompletionEvent::StartMessage { .. } => {
+ // no-op, todo!("do we wanna insert a new message here?")
+ }
+ LanguageModelCompletionEvent::Text(chunk) => {
+ response_events_tx
+ .unbounded_send(Ok(ResponseEvent::Text(chunk.clone())))
+ .ok();
+ assistant_message.push_text(chunk);
+ }
+ LanguageModelCompletionEvent::Thinking { text, signature } => {
+ response_events_tx
+ .unbounded_send(Ok(ResponseEvent::Thinking(text.clone())))
+ .ok();
+ assistant_message.push_thinking(text, signature);
+ }
+ LanguageModelCompletionEvent::RedactedThinking { data } => {
+ assistant_message
+ .chunks
+ .push_back(PendingAssistantMessageChunk::RedactedThinking { data });
+ }
+ LanguageModelCompletionEvent::ToolUse(tool_use) => {
+ match this
+ .read_with(cx, |this, cx| this.tool_for_name(&tool_use.name, cx))?
+ {
+ Ok(tool) => {
+ if tool_use.is_input_complete {
+ let (output_tx, output_rx) = oneshot::channel();
+ let mut request = request.clone();
+ // todo!("add the pending assistant message (excluding the tool calls)")
+ response_events_tx.unbounded_send(Ok(
+ ResponseEvent::ToolCall {
+ id: tool_use.id.clone(),
+ needs_confirmation: cx.update(|cx| {
+ tool.needs_confirmation(&tool_use.input, cx)
+ })?,
+ label: tool.ui_text(&tool_use.input),
+ run: Box::new({
+ // Clone these so the `move` closure below doesn't
+ // consume `tool_use` or the turn's `model`.
+ let input = tool_use.input.clone();
+ let model = model.clone();
+ let project = this
+ .read_with(cx, |this, _| {
+ this.project.clone()
+ })?;
+ let action_log = this
+ .read_with(cx, |this, _| {
+ this.action_log.clone()
+ })?;
+ move |window, cx| {
+ let assistant_tool::ToolResult {
+ output,
+ card,
+ } = tool.run(
+ input,
+ Arc::new(request),
+ project,
+ action_log,
+ model,
+ window,
+ cx,
+ );
+
+ ToolCallResult {
+ task: cx.foreground_executor().spawn(
+ async move {
+ match output.await {
+ Ok(output) => {
+ output_tx
+ .send(Ok(output))
+ .ok();
+ Ok(())
+ }
+ Err(error) => {
+ let error =
+ Arc::new(error);
+ output_tx
+ .send(Err(anyhow!(
+ error.clone()
+ )))
+ .ok();
+ Err(anyhow!(error))
+ }
+ }
+ },
+ ),
+ card,
+ }
+ }
+ }),
+ },
+ )).ok();
+ used_tools = true;
+ assistant_message.chunks.push_back(
+ PendingAssistantMessageChunk::ToolCall(
+ PendingAssistantToolCall {
+ request: tool_use,
+ output: output_rx,
+ },
+ ),
+ );
+ } else {
+ response_events_tx.unbounded_send(Ok(
+ ResponseEvent::ToolCallChunk {
+ id: tool_use.id,
+ label: tool
+ .still_streaming_ui_text(&tool_use.input),
+ input: tool_use.input,
+ },
+ )).ok();
+ }
+ }
+ Err(error) => {
+ response_events_tx.unbounded_send(Ok(
+ ResponseEvent::InvalidToolCallChunk(tool_use.clone()),
+ )).ok();
+ if tool_use.is_input_complete {
+ let (output_tx, output_rx) = oneshot::channel();
+ output_tx.send(Err(error)).ok();
+ used_tools = true;
+ assistant_message.chunks.push_back(
+ PendingAssistantMessageChunk::ToolCall(
+ PendingAssistantToolCall {
+ request: tool_use,
+ output: output_rx,
+ },
+ ),
+ );
+ }
+ }
+ }
+ }
+ LanguageModelCompletionEvent::UsageUpdate(_token_usage) => {
+ // todo!
+ }
+ LanguageModelCompletionEvent::StatusUpdate(_completion_request_status) => {
+ // todo!
+ }
+ LanguageModelCompletionEvent::Stop(StopReason::EndTurn) => {
+ // todo!
+ }
+ LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) => {
+ // todo!
+ }
+ LanguageModelCompletionEvent::Stop(StopReason::Refusal) => {
+ // todo!
+ }
+ LanguageModelCompletionEvent::Stop(StopReason::ToolUse) => {}
+ }
+ }
+
+ while let Some(chunk) = assistant_message.chunks.pop_front() {
+ match chunk {
+ PendingAssistantMessageChunk::Text(_) => todo!(),
+ PendingAssistantMessageChunk::Thinking { .. } => todo!(),
+ PendingAssistantMessageChunk::RedactedThinking { .. } => todo!(),
+ PendingAssistantMessageChunk::ToolCall(pending_assistant_tool_call) => {
+ // todo!("record the tool result on the thread and push the
+ // matching ToolUse/ToolResult content, as the old loop did
+ // with `set_tool_call_result`")
+ let _output = pending_assistant_tool_call.output.await;
+ }
+ }
+ }
+
+ anyhow::Ok(())
+ }
+ .boxed_local();
+
+ enum SendStatus {
+ Canceled,
+ Finished(Result<()>),
+ }
+
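+ // Race the in-flight completion against user cancellation.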
+ let status = match futures::future::select(&mut cancel_rx, send).await {
+ Either::Left(_) => SendStatus::Canceled,
+ Either::Right((result, _)) => SendStatus::Finished(result),
+ };
+
+ match status {
+ SendStatus::Canceled => {
+ // Resolve every pending tool call with a canceled result so
+ // the conversation history stays well-formed.
+ let canceled_tool_results = assistant_message
+ .chunks
+ .iter()
+ .filter_map(|chunk| match chunk {
+ PendingAssistantMessageChunk::ToolCall(pending_tool_call) => {
+ Some(LanguageModelToolResult {
+ tool_use_id: pending_tool_call.request.id.clone(),
+ tool_name: pending_tool_call.request.name.clone(),
+ is_error: true,
+ content: LanguageModelToolResultContent::Text(
+ "<User cancelled tool use>".into(),
+ ),
+ output: None,
+ })
+ }
+ _ => None,
+ })
+ .collect::<Vec<_>>();
+
+ // todo!("push the partial assistant message and
+ // `canceled_tool_results` onto the thread once the new
+ // message representation lands")
+ let _ = canceled_tool_results;
+
+ break;
+ }
+ SendStatus::Finished(result) => {
+ // Tool results were already awaited and recorded in the drain
+ // loop above, so all that's left is deciding whether to finish
+ // the turn, retry, or send the tool results back to the model.
+
+ match result {
+ Ok(_) => {
+ retry_state = None;
+ }
+ Err(error) => {
+ let mut retry = |custom_delay: Option<Duration>| -> bool {
+ let retry_state = retry_state.get_or_insert(RetryState {
+ attempts: 0,
+ custom_delay: None,
+ });
+ retry_state.attempts += 1;
+ // Always honor the most recent server-provided delay;
+ // fall back to exponential backoff when there isn't one.
+ retry_state.custom_delay = custom_delay;
+ retry_state.attempts <= MAX_RETRY_ATTEMPTS
+ };
+
+ if error.is::<PaymentRequiredError>() {
+ // todo!
+ // cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
+ break;
+ } else if let Some(_error) =
+ error.downcast_ref::<ModelRequestLimitReachedError>()
+ {
+ // todo!
+ // cx.emit(ThreadEvent::ShowError(
+ // ThreadError::ModelRequestLimitReached { plan: error.plan },
+ // ));
+ break;
+ } else if let Some(completion_error) =
+ error.downcast_ref::<LanguageModelCompletionError>()
+ {
+ match completion_error {
+ LanguageModelCompletionError::RateLimitExceeded {
+ retry_after,
+ } => {
+ if !retry(Some(*retry_after)) {
+ break;
+ }
+ }
+ LanguageModelCompletionError::Overloaded => {
+ if !retry(None) {
+ break;
+ }
+ }
+ LanguageModelCompletionError::ApiInternalServerError => {
+ if !retry(None) {
+ break;
+ }
+ // todo!
+ }
+ _ => {
+ // todo!(emit_generic_error(error, cx);)
+ break;
+ }
+ }
+ } else if let Some(known_error) =
+ error.downcast_ref::<LanguageModelKnownError>()
+ {
+ match known_error {
+ LanguageModelKnownError::ContextWindowLimitExceeded {
+ tokens: _,
+ } => {
+ // todo!
+ // this.exceeded_window_error =
+ // Some(ExceededWindowError {
+ // model_id: model.id(),
+ // token_count: *tokens,
+ // });
+ // cx.notify();
+ break;
+ }
+ LanguageModelKnownError::RateLimitExceeded { retry_after } => {
+ // let provider_name = model.provider_name();
+ // let error_message = format!(
+ // "{}'s API rate limit exceeded",
+ // provider_name.0.as_ref()
+ // );
+ if !retry(Some(*retry_after)) {
+ // todo! show err
+ break;
+ }
+ }
+ LanguageModelKnownError::Overloaded => {
+ //todo!
+ // let provider_name = model.provider_name();
+ // let error_message = format!(
+ // "{}'s API servers are overloaded right now",
+ // provider_name.0.as_ref()
+ // );
+
+ if !retry(None) {
+ // todo! show err
+ break;
+ }
+ }
+ LanguageModelKnownError::ApiInternalServerError => {
+ // let provider_name = model.provider_name();
+ // let error_message = format!(
+ // "{}'s API server reported an internal server error",
+ // provider_name.0.as_ref()
+ // );
+
+ if !retry(None) {
+ break;
+ }
+ }
+ LanguageModelKnownError::ReadResponseError(_)
+ | LanguageModelKnownError::DeserializeResponse(_)
+ | LanguageModelKnownError::UnknownResponseFormat(_) => {
+ // In the future we will attempt to re-roll response, but only once
+ // todo!(emit_generic_error(error, cx);)
+ break;
+ }
+ }
+ } else {
+ // todo!(emit_generic_error(error, cx));
+ break;
+ }
+ }
+ }
+
+ let done = this.update(cx, |this, cx| {
+ // The turn is done once a completion finishes without
+ // requesting any tools; otherwise we loop again with the
+ // tool results.
+ // todo!("push the completed assistant message and tool
+ // results onto the thread once the new message
+ // representation lands")
+ let done = !used_tools;
+
+ let summary_pending = matches!(this.summary(), ThreadSummary::Pending);
+
+ if summary_pending && (done || this.messages.len() > 6) {
+ this.summarize(cx);
+ }
+
+ done
+ })?;
+
+ if done && retry_state.is_none() {
+ break;
+ } else if retry_state.is_none() {
+ // Not done yet: run another completion carrying the tool
+ // results. On a retry, keep the original intent so the
+ // request is rebuilt unchanged.
+ intent = CompletionIntent::ToolResults;
+ }
+ }
+ }
+ }
+
+ Ok(())
+ }
+
pub fn send_message(
&mut self,
params: impl Into<UserMessageParams>,