Detailed changes
@@ -664,9 +664,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
);
// Simulate reaching tool use limit.
- fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
- cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
- ));
+ fake_model
+     .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
@@ -749,9 +747,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
};
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
- fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
- cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
- ));
+ fake_model
+     .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
@@ -15,7 +15,7 @@ use agent_settings::{
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage, UserStore};
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
+use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
use futures::stream;
@@ -1430,20 +1430,16 @@ impl Thread {
);
self.update_token_usage(usage, cx);
}
- StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
+ UsageUpdated { amount, limit } => {
self.update_model_request_usage(amount, limit, cx);
}
- StatusUpdate(
- CompletionRequestStatus::Started
- | CompletionRequestStatus::Queued { .. }
- | CompletionRequestStatus::Failed { .. },
- ) => {}
- StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
+ ToolUseLimitReached => {
self.tool_use_limit_reached = true;
}
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
+ Started | Queued { .. } => {}
}
Ok(None)
@@ -1687,9 +1683,7 @@ impl Thread {
let event = event.log_err()?;
let text = match event {
LanguageModelCompletionEvent::Text(text) => text,
- LanguageModelCompletionEvent::StatusUpdate(
- CompletionRequestStatus::UsageUpdated { amount, limit },
- ) => {
+ LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
this.update(cx, |thread, cx| {
thread.update_model_request_usage(amount, limit, cx);
})
@@ -1753,9 +1747,7 @@ impl Thread {
let event = event?;
let text = match event {
LanguageModelCompletionEvent::Text(text) => text,
- LanguageModelCompletionEvent::StatusUpdate(
- CompletionRequestStatus::UsageUpdated { amount, limit },
- ) => {
+ LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
this.update(cx, |thread, cx| {
thread.update_model_request_usage(amount, limit, cx);
})?;
@@ -7,9 +7,10 @@ use assistant_slash_command::{
use assistant_slash_commands::FileCommandMetadata;
use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
use clock::ReplicaId;
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
+use cloud_llm_client::{CompletionIntent, UsageLimit};
use collections::{HashMap, HashSet};
use fs::{Fs, RenameOptions};
use futures::{FutureExt, StreamExt, future::Shared};
use gpui::{
App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
@@ -2073,14 +2074,15 @@ impl TextThread {
});
match event {
- LanguageModelCompletionEvent::StatusUpdate(status_update) => {
- if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update {
- this.update_model_request_usage(
- amount as u32,
- limit,
- cx,
- );
- }
+ LanguageModelCompletionEvent::Started
+ | LanguageModelCompletionEvent::Queued { .. }
+ | LanguageModelCompletionEvent::ToolUseLimitReached => {}
+ LanguageModelCompletionEvent::UsageUpdated { amount, limit } => {
+ this.update_model_request_usage(amount as u32, limit, cx);
}
LanguageModelCompletionEvent::StartMessage { .. } => {}
LanguageModelCompletionEvent::Stop(reason) => {
@@ -1251,8 +1251,11 @@ pub fn response_events_to_markdown(
}
Ok(
LanguageModelCompletionEvent::UsageUpdate(_)
+ | LanguageModelCompletionEvent::ToolUseLimitReached
| LanguageModelCompletionEvent::StartMessage { .. }
- | LanguageModelCompletionEvent::StatusUpdate { .. },
+ | LanguageModelCompletionEvent::UsageUpdated { .. }
+ | LanguageModelCompletionEvent::Queued { .. }
+ | LanguageModelCompletionEvent::Started,
) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error, ..
@@ -1337,9 +1340,12 @@ impl ThreadDialog {
// Skip these
Ok(LanguageModelCompletionEvent::UsageUpdate(_))
| Ok(LanguageModelCompletionEvent::RedactedThinking { .. })
- | Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
- | Ok(LanguageModelCompletionEvent::Stop(_)) => {}
+ | Ok(LanguageModelCompletionEvent::Stop(_))
+ | Ok(LanguageModelCompletionEvent::Queued { .. })
+ | Ok(LanguageModelCompletionEvent::Started)
+ | Ok(LanguageModelCompletionEvent::UsageUpdated { .. })
+ | Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error,
@@ -12,7 +12,7 @@ pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
-use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
+use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
@@ -70,7 +70,15 @@ pub fn init_settings(cx: &mut App) {
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
- StatusUpdate(CompletionRequestStatus),
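+ /// The request is queued upstream; `position` is its place in the queue.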
+ Queued {
+ position: usize,
+ },
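+ /// The upstream provider has started processing the request.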
+ Started,
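+ /// The user's model request usage has been updated.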
+ UsageUpdated {
+ amount: usize,
+ limit: UsageLimit,
+ },
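+ /// The request reached the tool use limit.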
+ ToolUseLimitReached,
Stop(StopReason),
Text(String),
Thinking {
@@ -93,6 +101,37 @@ pub enum LanguageModelCompletionEvent {
UsageUpdate(TokenUsage),
}
+impl LanguageModelCompletionEvent {
+ pub fn from_completion_request_status(
+ status: CompletionRequestStatus,
+ upstream_provider: LanguageModelProviderName,
+ ) -> Result<Self, LanguageModelCompletionError> {
+ match status {
+ CompletionRequestStatus::Queued { position } => {
+ Ok(LanguageModelCompletionEvent::Queued { position })
+ }
+ CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
+ CompletionRequestStatus::UsageUpdated { amount, limit } => {
+ Ok(LanguageModelCompletionEvent::UsageUpdated { amount, limit })
+ }
+ CompletionRequestStatus::ToolUseLimitReached => {
+ Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
+ }
+ CompletionRequestStatus::Failed {
+ code,
+ message,
+ request_id: _,
+ retry_after,
+ } => Err(LanguageModelCompletionError::from_cloud_failure(
+ upstream_provider,
+ code,
+ message,
+ retry_after.map(Duration::from_secs_f64),
+ )),
+ }
+ }
+}
+
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("prompt too large for context window")]
@@ -633,7 +672,10 @@ pub trait LanguageModel: Send + Sync {
let last_token_usage = last_token_usage.clone();
async move {
match result {
- Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
+ Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
+ Ok(LanguageModelCompletionEvent::Started) => None,
+ Ok(LanguageModelCompletionEvent::UsageUpdated { .. }) => None,
+ Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
@@ -752,6 +752,7 @@ impl LanguageModel for CloudLanguageModel {
let mode = request.mode;
let app_version = cx.update(|cx| AppVersion::global(cx)).ok();
let thinking_allowed = request.thinking_allowed;
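+ // Resolve the upstream provider name once so stream errors can be attributed to it.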
+ let provider_name = provider_name(&self.model.provider);
match self.model.provider {
cloud_llm_client::LanguageModelProvider::Anthropic => {
let request = into_anthropic(
@@ -801,8 +802,9 @@ impl LanguageModel for CloudLanguageModel {
Box::pin(
response_lines(response, includes_status_messages)
.chain(usage_updated_event(usage))
.chain(tool_use_limit_reached_event(tool_use_limit_reached)),
),
+ &provider_name,
move |event| mapper.map_event(event),
))
});
@@ -849,6 +851,7 @@ impl LanguageModel for CloudLanguageModel {
.chain(usage_updated_event(usage))
.chain(tool_use_limit_reached_event(tool_use_limit_reached)),
),
+ &provider_name,
move |event| mapper.map_event(event),
))
});
@@ -895,6 +898,7 @@ impl LanguageModel for CloudLanguageModel {
.chain(usage_updated_event(usage))
.chain(tool_use_limit_reached_event(tool_use_limit_reached)),
),
+ &provider_name,
move |event| mapper.map_event(event),
))
});
@@ -935,6 +939,7 @@ impl LanguageModel for CloudLanguageModel {
.chain(usage_updated_event(usage))
.chain(tool_use_limit_reached_event(tool_use_limit_reached)),
),
+ &provider_name,
move |event| mapper.map_event(event),
))
});
@@ -946,6 +951,7 @@ impl LanguageModel for CloudLanguageModel {
fn map_cloud_completion_events<T, F>(
stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
+ provider: &LanguageModelProviderName,
mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
@@ -954,6 +960,7 @@ where
+ Send
+ 'static,
{
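+ // Clone so the `move` closure passed to `flat_map` owns the provider name.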
+ let provider = provider.clone();
stream
.flat_map(move |event| {
futures::stream::iter(match event {
@@ -961,7 +968,12 @@ where
vec![Err(LanguageModelCompletionError::from(error))]
}
Ok(CompletionEvent::Status(event)) => {
- vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
+ vec![
+ LanguageModelCompletionEvent::from_completion_request_status(
+ event,
+ provider.clone(),
+ ),
+ ]
}
Ok(CompletionEvent::Event(event)) => map_callback(event),
})
@@ -969,6 +981,17 @@ where
.boxed()
}
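+/// Maps the cloud provider enum to its `LanguageModelProviderName` constant so
+/// failures can be attributed to the correct upstream provider.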
+fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
+ match provider {
+ cloud_llm_client::LanguageModelProvider::Anthropic => {
+ language_model::ANTHROPIC_PROVIDER_NAME
+ }
+ cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
+ cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
+ cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
+ }
+}
+
fn usage_updated_event<T>(
usage: Option<ModelRequestUsage>,
) -> impl Stream<Item = Result<CompletionEvent<T>>> {