diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs
index ffc5dbc6d30e58b5d819c3778b063951b0ed0861..f43cbed952afd434c4262da486ce11dffa40a5c8 100644
--- a/crates/agent/src/tests/mod.rs
+++ b/crates/agent/src/tests/mod.rs
@@ -664,9 +664,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
     );
 
     // Simulate reaching tool use limit.
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
-        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
-    ));
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
     fake_model.end_last_completion_stream();
     let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
     assert!(
@@ -749,9 +747,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
     };
     fake_model
         .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
-        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
-    ));
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
     let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
     assert!(
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index 45c09675b2470bc399e7ad38fbf976fb2b06eea6..928b60eee4bc3ccdf296e8ba7f4f0bdc49cb9fa3 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -15,7 +15,7 @@ use agent_settings::{
 use anyhow::{Context as _, Result, anyhow};
 use chrono::{DateTime, Utc};
 use client::{ModelRequestUsage, RequestUsage, UserStore};
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
+use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
 use collections::{HashMap, HashSet, IndexMap};
 use fs::Fs;
 use futures::stream;
@@ -1430,20 +1430,16 @@ impl Thread {
                 );
                 self.update_token_usage(usage, cx);
             }
-            StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
+            UsageUpdated { amount, limit } => {
                 self.update_model_request_usage(amount, limit, cx);
             }
-            StatusUpdate(
-                CompletionRequestStatus::Started
-                | CompletionRequestStatus::Queued { .. }
-                | CompletionRequestStatus::Failed { .. },
-            ) => {}
-            StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
+            ToolUseLimitReached => {
                 self.tool_use_limit_reached = true;
             }
             Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
             Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
             Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
+            Started | Queued { .. } => {}
         }
 
         Ok(None)
@@ -1687,9 +1683,7 @@ impl Thread {
                 let event = event.log_err()?;
                 let text = match event {
                     LanguageModelCompletionEvent::Text(text) => text,
-                    LanguageModelCompletionEvent::StatusUpdate(
-                        CompletionRequestStatus::UsageUpdated { amount, limit },
-                    ) => {
+                    LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
                         this.update(cx, |thread, cx| {
                             thread.update_model_request_usage(amount, limit, cx);
                         })
@@ -1753,9 +1747,7 @@ impl Thread {
                 let event = event?;
                 let text = match event {
                     LanguageModelCompletionEvent::Text(text) => text,
-                    LanguageModelCompletionEvent::StatusUpdate(
-                        CompletionRequestStatus::UsageUpdated { amount, limit },
-                    ) => {
+                    LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
                         this.update(cx, |thread, cx| {
                             thread.update_model_request_usage(amount, limit, cx);
                         })?;
diff --git a/crates/assistant_text_thread/src/text_thread.rs b/crates/assistant_text_thread/src/text_thread.rs
index ae5fe25d430e80b7be68000162b6f0b21807e2a2..9f065e9ca7a1daf933c1313dd1d5f092cbed2771 100644
--- a/crates/assistant_text_thread/src/text_thread.rs
+++ b/crates/assistant_text_thread/src/text_thread.rs
@@ -7,9 +7,10 @@ use assistant_slash_command::{
 use assistant_slash_commands::FileCommandMetadata;
 use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
 use clock::ReplicaId;
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
+use cloud_llm_client::{CompletionIntent, UsageLimit};
 use collections::{HashMap, HashSet};
 use fs::{Fs, RenameOptions};
+
 use futures::{FutureExt, StreamExt, future::Shared};
 use gpui::{
     App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
@@ -2073,14 +2074,15 @@ impl TextThread {
                 });
 
                 match event {
-                    LanguageModelCompletionEvent::StatusUpdate(status_update) => {
-                        if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update {
-                            this.update_model_request_usage(
-                                amount as u32,
-                                limit,
-                                cx,
-                            );
-                        }
+                    LanguageModelCompletionEvent::Started
+                    | LanguageModelCompletionEvent::Queued { .. }
+                    | LanguageModelCompletionEvent::ToolUseLimitReached { .. } => {}
+                    LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
+                        this.update_model_request_usage(
+                            amount as u32,
+                            limit,
+                            cx,
+                        );
                     }
                     LanguageModelCompletionEvent::StartMessage { .. } => {}
                     LanguageModelCompletionEvent::Stop(reason) => {
diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs
index 035f1ec0ac8d0c6490dc39637e03e377ee3d194b..075a1a5cea1da782d40778befeb04bf2e6bac316 100644
--- a/crates/eval/src/instance.rs
+++ b/crates/eval/src/instance.rs
@@ -1251,8 +1251,11 @@ pub fn response_events_to_markdown(
             }
             Ok(
                 LanguageModelCompletionEvent::UsageUpdate(_)
+                | LanguageModelCompletionEvent::ToolUseLimitReached
                 | LanguageModelCompletionEvent::StartMessage { .. }
-                | LanguageModelCompletionEvent::StatusUpdate { .. },
+                | LanguageModelCompletionEvent::UsageUpdated { .. }
+                | LanguageModelCompletionEvent::Queued { .. }
+                | LanguageModelCompletionEvent::Started,
             ) => {}
             Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                 json_parse_error, ..
@@ -1337,9 +1340,12 @@ impl ThreadDialog {
                     // Skip these
                     Ok(LanguageModelCompletionEvent::UsageUpdate(_))
                     | Ok(LanguageModelCompletionEvent::RedactedThinking { .. })
-                    | Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
                     | Ok(LanguageModelCompletionEvent::StartMessage { .. })
-                    | Ok(LanguageModelCompletionEvent::Stop(_)) => {}
+                    | Ok(LanguageModelCompletionEvent::Stop(_))
+                    | Ok(LanguageModelCompletionEvent::Queued { .. })
+                    | Ok(LanguageModelCompletionEvent::Started)
+                    | Ok(LanguageModelCompletionEvent::UsageUpdated { .. })
+                    | Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => {}
 
                     Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                         json_parse_error,
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 785bb0dbdc7b6bb82d052cce16eb1c4b2fd66a48..3322409c09399b3ec957d8288b45e1833b77c106 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -12,7 +12,7 @@ pub mod fake_provider;
 use anthropic::{AnthropicError, parse_prompt_too_long};
 use anyhow::{Result, anyhow};
 use client::Client;
-use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
+use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
@@ -70,7 +70,15 @@ pub fn init_settings(cx: &mut App) {
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
-    StatusUpdate(CompletionRequestStatus),
+    Queued {
+        position: usize,
+    },
+    Started,
+    UsageUpdated {
+        amount: usize,
+        limit: UsageLimit,
+    },
+    ToolUseLimitReached,
     Stop(StopReason),
     Text(String),
     Thinking {
@@ -93,6 +101,37 @@ pub enum LanguageModelCompletionEvent {
     UsageUpdate(TokenUsage),
 }
 
+impl LanguageModelCompletionEvent {
+    pub fn from_completion_request_status(
+        status: CompletionRequestStatus,
+        upstream_provider: LanguageModelProviderName,
+    ) -> Result<Self, LanguageModelCompletionError> {
+        match status {
+            CompletionRequestStatus::Queued { position } => {
+                Ok(LanguageModelCompletionEvent::Queued { position })
+            }
+            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
+            CompletionRequestStatus::UsageUpdated { amount, limit } => {
+                Ok(LanguageModelCompletionEvent::UsageUpdated { amount, limit })
+            }
+            CompletionRequestStatus::ToolUseLimitReached => {
+                Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
+            }
+            CompletionRequestStatus::Failed {
+                code,
+                message,
+                request_id: _,
+                retry_after,
+            } => Err(LanguageModelCompletionError::from_cloud_failure(
+                upstream_provider,
+                code,
+                message,
+                retry_after.map(Duration::from_secs_f64),
+            )),
+        }
+    }
+}
+
 #[derive(Error, Debug)]
 pub enum LanguageModelCompletionError {
     #[error("prompt too large for context window")]
@@ -633,7 +672,10 @@ pub trait LanguageModel: Send + Sync {
             let last_token_usage = last_token_usage.clone();
             async move {
                 match result {
-                    Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
+                    Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
+                    Ok(LanguageModelCompletionEvent::Started) => None,
+                    Ok(LanguageModelCompletionEvent::UsageUpdated { .. }) => None,
+                    Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
                     Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                     Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                     Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index d85533ecce63441fe5aaa7a382bf04af79992f63..a9ff767146287db25fb0b42685525fd56d29d71e 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -752,6 +752,7 @@ impl LanguageModel for CloudLanguageModel {
         let mode = request.mode;
         let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
         let thinking_allowed = request.thinking_allowed;
+        let provider_name = provider_name(&self.model.provider);
         match self.model.provider {
             cloud_llm_client::LanguageModelProvider::Anthropic => {
                 let request = into_anthropic(
@@ -801,8 +802,9 @@ impl LanguageModel for CloudLanguageModel {
                     Box::pin(
                         response_lines(response, includes_status_messages)
                             .chain(usage_updated_event(usage))
                             .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                     ),
+                    &provider_name,
                     move |event| mapper.map_event(event),
                 ))
             });
@@ -849,6 +851,7 @@ impl LanguageModel for CloudLanguageModel {
                             .chain(usage_updated_event(usage))
                             .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                     ),
+                    &provider_name,
                     move |event| mapper.map_event(event),
                 ))
             });
@@ -895,6 +898,7 @@ impl LanguageModel for CloudLanguageModel {
                             .chain(usage_updated_event(usage))
                             .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                     ),
+                    &provider_name,
                     move |event| mapper.map_event(event),
                 ))
             });
@@ -935,6 +939,7 @@ impl LanguageModel for CloudLanguageModel {
                             .chain(usage_updated_event(usage))
                             .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                     ),
+                    &provider_name,
                     move |event| mapper.map_event(event),
                 ))
             });
@@ -946,6 +951,7 @@ impl LanguageModel for CloudLanguageModel {
 
 fn map_cloud_completion_events<T, F>(
     stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
+    provider: &LanguageModelProviderName,
     mut map_callback: F,
 ) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
 where
@@ -954,21 +960,38 @@ where
         + Send
         + 'static,
 {
+    let provider = provider.clone();
     stream
         .flat_map(move |event| {
             futures::stream::iter(match event {
                 Err(error) => {
                     vec![Err(LanguageModelCompletionError::from(error))]
                 }
                 Ok(CompletionEvent::Status(event)) => {
-                    vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
+                    vec![
+                        LanguageModelCompletionEvent::from_completion_request_status(
+                            event,
+                            provider.clone(),
+                        ),
+                    ]
                 }
                 Ok(CompletionEvent::Event(event)) => map_callback(event),
             })
         })
         .boxed()
 }
 
+fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
+    match provider {
+        cloud_llm_client::LanguageModelProvider::Anthropic => {
+            language_model::ANTHROPIC_PROVIDER_NAME
+        }
+        cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
+        cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
+        cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
+    }
+}
+
 fn usage_updated_event(
     usage: Option<ModelRequestUsage>,
 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {