Cargo.lock 🔗
@@ -244,6 +244,7 @@ dependencies = [
"terminal",
"text",
"theme",
+ "thiserror 2.0.12",
"tree-sitter-rust",
"ui",
"unindent",
Antonio Scandurra created
Release Notes:
- N/A
Cargo.lock | 1
crates/agent2/Cargo.toml | 1
crates/agent2/src/thread.rs | 295 +++++++++++++++++---------------------
3 files changed, 134 insertions(+), 163 deletions(-)
@@ -244,6 +244,7 @@ dependencies = [
"terminal",
"text",
"theme",
+ "thiserror 2.0.12",
"tree-sitter-rust",
"ui",
"unindent",
@@ -61,6 +61,7 @@ sqlez.workspace = true
task.workspace = true
telemetry.workspace = true
terminal.workspace = true
+thiserror.workspace = true
text.workspace = true
ui.workspace = true
util.workspace = true
@@ -499,6 +499,16 @@ pub struct ToolCallAuthorization {
pub response: oneshot::Sender<acp::PermissionOptionId>,
}
+#[derive(Debug, thiserror::Error)]
+enum CompletionError {
+ #[error("max tokens")]
+ MaxTokens,
+ #[error("refusal")]
+ Refusal,
+ #[error(transparent)]
+ Other(#[from] anyhow::Error),
+}
+
pub struct Thread {
id: acp::SessionId,
prompt_id: PromptId,
@@ -1077,101 +1087,62 @@ impl Thread {
_task: cx.spawn(async move |this, cx| {
log::info!("Starting agent turn execution");
let mut update_title = None;
- let turn_result: Result<StopReason> = async {
- let mut completion_intent = CompletionIntent::UserPrompt;
+ let turn_result: Result<()> = async {
+ let mut intent = CompletionIntent::UserPrompt;
loop {
- log::debug!(
- "Building completion request with intent: {:?}",
- completion_intent
- );
- let request = this.update(cx, |this, cx| {
- this.build_completion_request(completion_intent, cx)
- })??;
-
- log::info!("Calling model.stream_completion");
-
- let mut tool_use_limit_reached = false;
- let mut refused = false;
- let mut reached_max_tokens = false;
- let mut tool_uses = Self::stream_completion_with_retries(
- this.clone(),
- model.clone(),
- request,
- &event_stream,
- &mut tool_use_limit_reached,
- &mut refused,
- &mut reached_max_tokens,
- cx,
- )
- .await?;
-
- if refused {
- return Ok(StopReason::Refusal);
- } else if reached_max_tokens {
- return Ok(StopReason::MaxTokens);
- }
-
- let end_turn = tool_uses.is_empty();
- while let Some(tool_result) = tool_uses.next().await {
- log::info!("Tool finished {:?}", tool_result);
-
- event_stream.update_tool_call_fields(
- &tool_result.tool_use_id,
- acp::ToolCallUpdateFields {
- status: Some(if tool_result.is_error {
- acp::ToolCallStatus::Failed
- } else {
- acp::ToolCallStatus::Completed
- }),
- raw_output: tool_result.output.clone(),
- ..Default::default()
- },
- );
- this.update(cx, |this, _cx| {
- this.pending_message()
- .tool_results
- .insert(tool_result.tool_use_id.clone(), tool_result);
- })?;
- }
+ Self::stream_completion(&this, &model, intent, &event_stream, cx).await?;
+ let mut end_turn = true;
this.update(cx, |this, cx| {
+ // Generate title if needed.
if this.title.is_none() && update_title.is_none() {
update_title = Some(this.update_title(&event_stream, cx));
}
+
+ // End the turn if the model didn't use tools.
+ let message = this.pending_message.as_ref();
+ end_turn =
+ message.map_or(true, |message| message.tool_results.is_empty());
+ this.flush_pending_message(cx);
})?;
- if tool_use_limit_reached {
+ if this.read_with(cx, |this, _| this.tool_use_limit_reached)? {
log::info!("Tool use limit reached, completing turn");
- this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
return Err(language_model::ToolUseLimitReachedError.into());
} else if end_turn {
log::info!("No tool uses found, completing turn");
- return Ok(StopReason::EndTurn);
+ return Ok(());
} else {
- this.update(cx, |this, cx| this.flush_pending_message(cx))?;
- completion_intent = CompletionIntent::ToolResults;
+ intent = CompletionIntent::ToolResults;
}
}
}
.await;
_ = this.update(cx, |this, cx| this.flush_pending_message(cx));
- match turn_result {
- Ok(reason) => {
- log::info!("Turn execution completed: {:?}", reason);
-
- if let Some(update_title) = update_title {
- update_title.await.context("update title failed").log_err();
- }
+ if let Some(update_title) = update_title {
+ update_title.await.context("update title failed").log_err();
+ }
- event_stream.send_stop(reason);
- if reason == StopReason::Refusal {
- _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
- }
+ match turn_result {
+ Ok(()) => {
+ log::info!("Turn execution completed");
+ event_stream.send_stop(acp::StopReason::EndTurn);
}
Err(error) => {
log::error!("Turn execution failed: {:?}", error);
- event_stream.send_error(error);
+ match error.downcast::<CompletionError>() {
+ Ok(CompletionError::Refusal) => {
+ event_stream.send_stop(acp::StopReason::Refusal);
+ _ = this.update(cx, |this, _| this.messages.truncate(message_ix));
+ }
+ Ok(CompletionError::MaxTokens) => {
+ event_stream.send_stop(acp::StopReason::MaxTokens);
+ }
+ Ok(CompletionError::Other(error)) | Err(error) => {
+ event_stream.send_error(error);
+ }
+ }
}
}
@@ -1181,17 +1152,17 @@ impl Thread {
Ok(events_rx)
}
- async fn stream_completion_with_retries(
- this: WeakEntity<Self>,
- model: Arc<dyn LanguageModel>,
- request: LanguageModelRequest,
+ async fn stream_completion(
+ this: &WeakEntity<Self>,
+ model: &Arc<dyn LanguageModel>,
+ completion_intent: CompletionIntent,
event_stream: &ThreadEventStream,
- tool_use_limit_reached: &mut bool,
- refusal: &mut bool,
- max_tokens_reached: &mut bool,
cx: &mut AsyncApp,
- ) -> Result<FuturesUnordered<Task<LanguageModelToolResult>>> {
+ ) -> Result<()> {
log::debug!("Stream completion started successfully");
+ let request = this.update(cx, |this, cx| {
+ this.build_completion_request(completion_intent, cx)
+ })??;
let mut attempt = None;
'retry: loop {
@@ -1204,68 +1175,33 @@ impl Thread {
attempt
);
- let mut events = model.stream_completion(request.clone(), cx).await?;
- let mut tool_uses = FuturesUnordered::new();
+ log::info!(
+ "Calling model.stream_completion, attempt {}",
+ attempt.unwrap_or(0)
+ );
+ let mut events = model
+ .stream_completion(request.clone(), cx)
+ .await
+ .map_err(|error| anyhow!(error))?;
+ let mut tool_results = FuturesUnordered::new();
+
while let Some(event) = events.next().await {
match event {
- Ok(LanguageModelCompletionEvent::StatusUpdate(
- CompletionRequestStatus::ToolUseLimitReached,
- )) => {
- *tool_use_limit_reached = true;
- }
- Ok(LanguageModelCompletionEvent::StatusUpdate(
- CompletionRequestStatus::UsageUpdated { amount, limit },
- )) => {
- this.update(cx, |this, cx| {
- this.update_model_request_usage(amount, limit, cx)
- })?;
- }
- Ok(LanguageModelCompletionEvent::UsageUpdate(usage)) => {
- telemetry::event!(
- "Agent Thread Completion Usage Updated",
- thread_id = this.read_with(cx, |this, _| this.id.to_string())?,
- prompt_id = this.read_with(cx, |this, _| this.prompt_id.to_string())?,
- model = model.telemetry_id(),
- model_provider = model.provider_id().to_string(),
- attempt,
- input_tokens = usage.input_tokens,
- output_tokens = usage.output_tokens,
- cache_creation_input_tokens = usage.cache_creation_input_tokens,
- cache_read_input_tokens = usage.cache_read_input_tokens,
- );
-
- this.update(cx, |this, cx| this.update_token_usage(usage, cx))?;
- }
- Ok(LanguageModelCompletionEvent::Stop(StopReason::Refusal)) => {
- *refusal = true;
- return Ok(FuturesUnordered::default());
- }
- Ok(LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)) => {
- *max_tokens_reached = true;
- return Ok(FuturesUnordered::default());
- }
- Ok(LanguageModelCompletionEvent::Stop(
- StopReason::ToolUse | StopReason::EndTurn,
- )) => break,
Ok(event) => {
log::trace!("Received completion event: {:?}", event);
- this.update(cx, |this, cx| {
- tool_uses.extend(this.handle_streamed_completion_event(
- event,
- event_stream,
- cx,
- ));
- })?;
+ tool_results.extend(this.update(cx, |this, cx| {
+ this.handle_streamed_completion_event(event, event_stream, cx)
+ })??);
}
Err(error) => {
let completion_mode =
this.read_with(cx, |thread, _cx| thread.completion_mode())?;
if completion_mode == CompletionMode::Normal {
- return Err(error.into());
+ return Err(anyhow!(error))?;
}
let Some(strategy) = Self::retry_strategy_for(&error) else {
- return Err(error.into());
+ return Err(anyhow!(error))?;
};
let max_attempts = match &strategy {
@@ -1279,7 +1215,7 @@ impl Thread {
let attempt = *attempt;
if attempt > max_attempts {
- return Err(error.into());
+ return Err(anyhow!(error))?;
}
let delay = match &strategy {
@@ -1306,7 +1242,29 @@ impl Thread {
}
}
- return Ok(tool_uses);
+ while let Some(tool_result) = tool_results.next().await {
+ log::info!("Tool finished {:?}", tool_result);
+
+ event_stream.update_tool_call_fields(
+ &tool_result.tool_use_id,
+ acp::ToolCallUpdateFields {
+ status: Some(if tool_result.is_error {
+ acp::ToolCallStatus::Failed
+ } else {
+ acp::ToolCallStatus::Completed
+ }),
+ raw_output: tool_result.output.clone(),
+ ..Default::default()
+ },
+ );
+ this.update(cx, |this, _cx| {
+ this.pending_message()
+ .tool_results
+ .insert(tool_result.tool_use_id.clone(), tool_result);
+ })?;
+ }
+
+ return Ok(());
}
}
@@ -1328,14 +1286,14 @@ impl Thread {
}
/// A helper method that's called on every streamed completion event.
- /// Returns an optional tool result task, which the main agentic loop in
- /// send will send back to the model when it resolves.
+ /// Returns an optional tool result task, which the main agentic loop will
+ /// send back to the model when it resolves.
fn handle_streamed_completion_event(
&mut self,
event: LanguageModelCompletionEvent,
event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
- ) -> Option<Task<LanguageModelToolResult>> {
+ ) -> Result<Option<Task<LanguageModelToolResult>>> {
log::trace!("Handling streamed completion event: {:?}", event);
use LanguageModelCompletionEvent::*;
@@ -1350,7 +1308,7 @@ impl Thread {
}
RedactedThinking { data } => self.handle_redacted_thinking_event(data, cx),
ToolUse(tool_use) => {
- return self.handle_tool_use_event(tool_use, event_stream, cx);
+ return Ok(self.handle_tool_use_event(tool_use, event_stream, cx));
}
ToolUseJsonParseError {
id,
@@ -1358,18 +1316,46 @@ impl Thread {
raw_input,
json_parse_error,
} => {
- return Some(Task::ready(self.handle_tool_use_json_parse_error_event(
- id,
- tool_name,
- raw_input,
- json_parse_error,
+ return Ok(Some(Task::ready(
+ self.handle_tool_use_json_parse_error_event(
+ id,
+ tool_name,
+ raw_input,
+ json_parse_error,
+ ),
)));
}
- StatusUpdate(_) => {}
- UsageUpdate(_) | Stop(_) => unreachable!(),
+ UsageUpdate(usage) => {
+ telemetry::event!(
+ "Agent Thread Completion Usage Updated",
+ thread_id = self.id.to_string(),
+ prompt_id = self.prompt_id.to_string(),
+ model = self.model.as_ref().map(|m| m.telemetry_id()),
+ model_provider = self.model.as_ref().map(|m| m.provider_id().to_string()),
+ input_tokens = usage.input_tokens,
+ output_tokens = usage.output_tokens,
+ cache_creation_input_tokens = usage.cache_creation_input_tokens,
+ cache_read_input_tokens = usage.cache_read_input_tokens,
+ );
+ self.update_token_usage(usage, cx);
+ }
+ StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
+ self.update_model_request_usage(amount, limit, cx);
+ }
+ StatusUpdate(
+ CompletionRequestStatus::Started
+ | CompletionRequestStatus::Queued { .. }
+ | CompletionRequestStatus::Failed { .. },
+ ) => {}
+ StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
+ self.tool_use_limit_reached = true;
+ }
+ Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
+ Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
+ Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
}
- None
+ Ok(None)
}
fn handle_text_event(
@@ -2225,25 +2211,8 @@ impl ThreadEventStream {
self.0.unbounded_send(Ok(ThreadEvent::Retry(status))).ok();
}
- fn send_stop(&self, reason: StopReason) {
- match reason {
- StopReason::EndTurn => {
- self.0
- .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::EndTurn)))
- .ok();
- }
- StopReason::MaxTokens => {
- self.0
- .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::MaxTokens)))
- .ok();
- }
- StopReason::Refusal => {
- self.0
- .unbounded_send(Ok(ThreadEvent::Stop(acp::StopReason::Refusal)))
- .ok();
- }
- StopReason::ToolUse => {}
- }
+ fn send_stop(&self, reason: acp::StopReason) {
+ self.0.unbounded_send(Ok(ThreadEvent::Stop(reason))).ok();
}
fn send_canceled(&self) {