From 04772bf17d9d80d890ca71bce420c4c746c18d45 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 2 May 2025 13:36:39 -0700 Subject: [PATCH] Add support for queuing status updates in cloud language model provider (#29818) This sets us up to display queue position information to the user, once our language model backend is updated to support request queuing. The JSON returned by the LLM backend will need to look like this: ```json {"queue": {"status": "queued", "position": 1}} {"queue": {"status": "started"}} {"event": {"THE_UPSTREAM_MODEL_PROVIDER_EVENT": "..."}} ``` Release Notes: - N/A --------- Co-authored-by: Marshall Bowers --- crates/agent/src/active_thread.rs | 31 +- crates/agent/src/thread.rs | 29 ++ .../assistant_context_editor/src/context.rs | 1 + crates/eval/src/instance.rs | 4 +- crates/language_model/src/language_model.rs | 9 + .../language_models/src/provider/anthropic.rs | 369 ++++++++---------- crates/language_models/src/provider/cloud.rs | 94 +++-- crates/language_models/src/provider/google.rs | 181 +++++---- .../language_models/src/provider/open_ai.rs | 200 +++++----- 9 files changed, 490 insertions(+), 428 deletions(-) diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index af62fedc5c598484181b90d4331ac5f2bc25d2cc..58822247efa936245c41d765073d27eef0d56e4a 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -4,8 +4,8 @@ use crate::context_store::ContextStore; use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::message_editor::insert_message_creases; use crate::thread::{ - LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, Thread, ThreadError, - ThreadEvent, ThreadFeedback, + LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, QueueState, Thread, + ThreadError, ThreadEvent, ThreadFeedback, }; use crate::thread_store::{RulesLoadingError, ThreadStore}; use crate::tool_use::{PendingToolUseStatus, ToolUse}; @@ -1733,8 +1733,27 @@ impl ActiveThread { let show_feedback = thread.is_turn_end(ix); - let generating_label = (is_generating && is_last_message) - .then(|| AnimatedLabel::new("Generating").size(LabelSize::Small)); + let generating_label = is_last_message + .then(|| match (thread.queue_state(), is_generating) { + (Some(QueueState::Sending), _) => Some( + AnimatedLabel::new("Sending") + .size(LabelSize::Small) + .into_any_element(), + ), + (Some(QueueState::Queued { position }), _) => Some( + Label::new(format!("Queue position: {position}")) + .size(LabelSize::Small) + .color(Color::Muted) + .into_any_element(), + ), + (_, true) => Some( + AnimatedLabel::new("Generating") + .size(LabelSize::Small) + .into_any_element(), + ), + _ => None, + }) + .flatten(); let editing_message_state = self .editing_message @@ -2105,7 +2124,7 @@ impl ActiveThread { parent.child(self.render_rules_item(cx)) }) .child(styled_message) - .when(generating_label.is_some(), |this| { + .when_some(generating_label, |this, generating_label| { this.child( h_flex() .h_8() @@ -2113,7 +2132,7 @@ impl ActiveThread { .mb_4() .ml_4() .py_1p5() - .child(generating_label.unwrap()), + .child(generating_label), ) }) .when(show_feedback, move |parent| { diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 5fb85c4ede09395763e69b88290c3b8be74606ec..5f0561bb6697323ab3bea7bb489bf73b744f0f3e 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -320,6 +320,13 @@ fn default_completion_mode(cx: &App) -> CompletionMode { } } +#[derive(Debug, Clone, Copy)] +pub enum QueueState { + Sending, + Queued { position: usize }, + Started, +} + /// A thread of conversation with the LLM. pub struct Thread { id: ThreadId, @@ -625,6 +632,12 @@ impl Thread { !self.pending_completions.is_empty() || !self.all_tools_finished() } + pub fn queue_state(&self) -> Option { + self.pending_completions + .first() + .map(|pending_completion| pending_completion.queue_state) + } + pub fn tools(&self) -> &Entity { &self.tools } @@ -1470,6 +1483,20 @@ impl Thread { }); } } + LanguageModelCompletionEvent::QueueUpdate(queue_event) => { + if let Some(completion) = thread + .pending_completions + .iter_mut() + .find(|completion| completion.id == pending_completion_id) + { + completion.queue_state = match queue_event { + language_model::QueueState::Queued { position } => { + QueueState::Queued { position } + } + language_model::QueueState::Started => QueueState::Started, + } + } + } } thread.touch_updated_at(); @@ -1590,6 +1617,7 @@ impl Thread { self.pending_completions.push(PendingCompletion { id: pending_completion_id, + queue_state: QueueState::Sending, _task: task, }); } @@ -2499,6 +2527,7 @@ impl EventEmitter for Thread {} struct PendingCompletion { id: usize, + queue_state: QueueState, _task: Task<()>, } diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs index f29bcbe7530826b43dee948063444639e1bf4b43..6beeaf34614b47ece3a62dada22f9bdb327f8b4f 100644 --- a/crates/assistant_context_editor/src/context.rs +++ b/crates/assistant_context_editor/src/context.rs @@ -2371,6 +2371,7 @@ impl AssistantContext { }); match event { + LanguageModelCompletionEvent::QueueUpdate { .. } => {} LanguageModelCompletionEvent::StartMessage { .. } => {} LanguageModelCompletionEvent::Stop(reason) => { stop_reason = reason; diff --git a/crates/eval/src/instance.rs b/crates/eval/src/instance.rs index 0c477143e6f5fc87e94053aeadde41b625edb7e2..d3c5fdb29c76bbf447113949831af875aa8c7a82 100644 --- a/crates/eval/src/instance.rs +++ b/crates/eval/src/instance.rs @@ -1017,7 +1017,8 @@ pub fn response_events_to_markdown( } Ok( LanguageModelCompletionEvent::UsageUpdate(_) - | LanguageModelCompletionEvent::StartMessage { .. }, + | LanguageModelCompletionEvent::StartMessage { .. } + | LanguageModelCompletionEvent::QueueUpdate { .. }, ) => {} Err(error) => { flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer); @@ -1092,6 +1093,7 @@ impl ThreadDialog { // Skip these Ok(LanguageModelCompletionEvent::UsageUpdate(_)) + | Ok(LanguageModelCompletionEvent::QueueUpdate { .. }) | Ok(LanguageModelCompletionEvent::StartMessage { .. }) | Ok(LanguageModelCompletionEvent::Stop(_)) => {} diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 4c9e918756c64320022b3c65c885efebe3dd6732..1146bbc1370e499f62ba5fe56c4ca094597c7178 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -64,9 +64,17 @@ pub struct LanguageModelCacheConfiguration { pub min_total_token: usize, } +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] +#[serde(tag = "status", rename_all = "snake_case")] +pub enum QueueState { + Queued { position: usize }, + Started, +} + /// A completion event from a language model. #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub enum LanguageModelCompletionEvent { + QueueUpdate(QueueState), Stop(StopReason), Text(String), Thinking { @@ -349,6 +357,7 @@ pub trait LanguageModel: Send + Sync { let last_token_usage = last_token_usage.clone(); async move { match result { + Ok(LanguageModelCompletionEvent::QueueUpdate { .. }) => None, Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None, Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)), Ok(LanguageModelCompletionEvent::Thinking { .. }) => None, diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 52f98068320808ec33ba7852d8b54abfc3be9917..a22a027534db974d55c840c77730e4e742e88560 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -469,7 +469,7 @@ impl LanguageModel for AnthropicModel { Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err), Err(err) => anyhow!(err), })?; - Ok(map_to_language_model_completion_events(response)) + Ok(AnthropicEventMapper::new().map_stream(response)) }); async move { Ok(future.await?.boxed()) }.boxed() } @@ -629,215 +629,186 @@ pub fn into_anthropic( } } -pub fn map_to_language_model_completion_events( - events: Pin>>>, -) -> impl Stream> { - struct RawToolUse { - id: String, - name: String, - input_json: String, +pub struct AnthropicEventMapper { + tool_uses_by_index: HashMap, + usage: Usage, + stop_reason: StopReason, +} + +impl AnthropicEventMapper { + pub fn new() -> Self { + Self { + tool_uses_by_index: HashMap::default(), + usage: Usage::default(), + stop_reason: StopReason::EndTurn, + } } - struct State { + pub fn map_stream( + mut self, events: Pin>>>, - tool_uses_by_index: HashMap, - usage: Usage, - stop_reason: StopReason, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], + }) + }) } - futures::stream::unfold( - State { - events, - tool_uses_by_index: HashMap::default(), - usage: Usage::default(), - stop_reason: StopReason::EndTurn, - }, - |mut state| async move { - while let Some(event) = state.events.next().await { - match event { - Ok(event) => match event { - Event::ContentBlockStart { - index, - content_block, - } => match content_block { - ResponseContent::Text { text } => { - return Some(( - vec![Ok(LanguageModelCompletionEvent::Text(text))], - state, - )); - } - ResponseContent::Thinking { thinking } => { - return Some(( - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: thinking, - signature: None, - })], - state, - )); - } - ResponseContent::RedactedThinking { .. } => { - // Redacted thinking is encrypted and not accessible to the user, see: - // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production - } - ResponseContent::ToolUse { id, name, .. } => { - state.tool_uses_by_index.insert( - index, - RawToolUse { - id, - name, - input_json: String::new(), - }, - ); - } - }, - Event::ContentBlockDelta { index, delta } => match delta { - ContentDelta::TextDelta { text } => { - return Some(( - vec![Ok(LanguageModelCompletionEvent::Text(text))], - state, - )); - } - ContentDelta::ThinkingDelta { thinking } => { - return Some(( - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: thinking, - signature: None, - })], - state, - )); - } - ContentDelta::SignatureDelta { signature } => { - return Some(( - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: "".to_string(), - signature: Some(signature), - })], - state, - )); - } - ContentDelta::InputJsonDelta { partial_json } => { - if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) { - tool_use.input_json.push_str(&partial_json); - - // Try to convert invalid (incomplete) JSON into - // valid JSON that serde can accept, e.g. by closing - // unclosed delimiters. This way, we can update the - // UI with whatever has been streamed back so far. - if let Ok(input) = serde_json::Value::from_str( - &partial_json_fixer::fix_json(&tool_use.input_json), - ) { - return Some(( - vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.clone().into(), - name: tool_use.name.clone().into(), - is_input_complete: false, - raw_input: tool_use.input_json.clone(), - input, - }, - ))], - state, - )); - } - } - } + pub fn map_event( + &mut self, + event: Event, + ) -> Vec> { + match event { + Event::ContentBlockStart { + index, + content_block, + } => match content_block { + ResponseContent::Text { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ResponseContent::Thinking { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ResponseContent::RedactedThinking { .. } => { + // Redacted thinking is encrypted and not accessible to the user, see: + // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production + Vec::new() + } + ResponseContent::ToolUse { id, name, .. } => { + self.tool_uses_by_index.insert( + index, + RawToolUse { + id, + name, + input_json: String::new(), }, - Event::ContentBlockStop { index } => { - if let Some(tool_use) = state.tool_uses_by_index.remove(&index) { - let input_json = tool_use.input_json.trim(); - let input_value = if input_json.is_empty() { - Ok(serde_json::Value::Object(serde_json::Map::default())) - } else { - serde_json::Value::from_str(input_json) - }; - let event_result = match input_value { - Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.into(), - name: tool_use.name.into(), - is_input_complete: true, - input, - raw_input: tool_use.input_json.clone(), - }, - )), - Err(json_parse_err) => { - Err(LanguageModelCompletionError::BadInputJson { - id: tool_use.id.into(), - tool_name: tool_use.name.into(), - raw_input: input_json.into(), - json_parse_error: json_parse_err.to_string(), - }) - } - }; - - return Some((vec![event_result], state)); - } - } - Event::MessageStart { message } => { - update_usage(&mut state.usage, &message.usage); - return Some(( - vec![ - Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( - &state.usage, - ))), - Ok(LanguageModelCompletionEvent::StartMessage { - message_id: message.id, - }), - ], - state, - )); - } - Event::MessageDelta { delta, usage } => { - update_usage(&mut state.usage, &usage); - if let Some(stop_reason) = delta.stop_reason.as_deref() { - state.stop_reason = match stop_reason { - "end_turn" => StopReason::EndTurn, - "max_tokens" => StopReason::MaxTokens, - "tool_use" => StopReason::ToolUse, - _ => { - log::error!( - "Unexpected anthropic stop_reason: {stop_reason}" - ); - StopReason::EndTurn - } - }; - } - return Some(( - vec![Ok(LanguageModelCompletionEvent::UsageUpdate( - convert_usage(&state.usage), - ))], - state, - )); - } - Event::MessageStop => { - return Some(( - vec![Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))], - state, - )); - } - Event::Error { error } => { - return Some(( - vec![Err(LanguageModelCompletionError::Other(anyhow!( - AnthropicError::ApiError(error) - )))], - state, - )); + ); + Vec::new() + } + }, + Event::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ContentDelta::ThinkingDelta { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ContentDelta::SignatureDelta { signature } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "".to_string(), + signature: Some(signature), + })] + } + ContentDelta::InputJsonDelta { partial_json } => { + if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { + tool_use.input_json.push_str(&partial_json); + + // Try to convert invalid (incomplete) JSON into + // valid JSON that serde can accept, e.g. by closing + // unclosed delimiters. This way, we can update the + // UI with whatever has been streamed back so far. + if let Ok(input) = serde_json::Value::from_str( + &partial_json_fixer::fix_json(&tool_use.input_json), + ) { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.clone().into(), + name: tool_use.name.clone().into(), + is_input_complete: false, + raw_input: tool_use.input_json.clone(), + input, + }, + ))]; } - _ => {} - }, - Err(err) => { - return Some(( - vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))], - state, - )); } + return vec![]; + } + }, + Event::ContentBlockStop { index } => { + if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { + let input_json = tool_use.input_json.trim(); + let input_value = if input_json.is_empty() { + Ok(serde_json::Value::Object(serde_json::Map::default())) + } else { + serde_json::Value::from_str(input_json) + }; + let event_result = match input_value { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.into(), + name: tool_use.name.into(), + is_input_complete: true, + input, + raw_input: tool_use.input_json.clone(), + }, + )), + Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson { + id: tool_use.id.into(), + tool_name: tool_use.name.into(), + raw_input: input_json.into(), + json_parse_error: json_parse_err.to_string(), + }), + }; + + vec![event_result] + } else { + Vec::new() + } + } + Event::MessageStart { message } => { + update_usage(&mut self.usage, &message.usage); + vec![ + Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( + &self.usage, + ))), + Ok(LanguageModelCompletionEvent::StartMessage { + message_id: message.id, + }), + ] + } + Event::MessageDelta { delta, usage } => { + update_usage(&mut self.usage, &usage); + if let Some(stop_reason) = delta.stop_reason.as_deref() { + self.stop_reason = match stop_reason { + "end_turn" => StopReason::EndTurn, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + _ => { + log::error!("Unexpected anthropic stop_reason: {stop_reason}"); + StopReason::EndTurn + } + }; } + vec![Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_usage(&self.usage), + ))] } + Event::MessageStop => { + vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] + } + Event::Error { error } => { + vec![Err(LanguageModelCompletionError::Other(anyhow!( + AnthropicError::ApiError(error) + )))] + } + _ => Vec::new(), + } + } +} - None - }, - ) - .flat_map(futures::stream::iter) +struct RawToolUse { + id: String, + name: String, + input_json: String, } pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error { diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index d0f8ba275af102b6cdd19fc0b5663afe06b5b411..556be2c75d1b86d4b0e70dfed9114e420bf5aa16 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,11 +1,10 @@ -use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long}; +use anthropic::{AnthropicModelMode, parse_prompt_too_long}; use anyhow::{Result, anyhow}; use client::{Client, UserStore, zed_urls}; use collections::BTreeMap; use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag}; use futures::{ - AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, future::BoxFuture, - stream::BoxStream, + AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream, }; use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task}; use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode}; @@ -14,7 +13,7 @@ use language_model::{ LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat, - ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID, + ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID, }; use language_model::{ LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, @@ -26,6 +25,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; use settings::{Settings, SettingsStore}; use smol::Timer; use smol::io::{AsyncReadExt, BufReader}; +use std::pin::Pin; use std::str::FromStr as _; use std::{ sync::{Arc, LazyLock}, @@ -41,9 +41,9 @@ use zed_llm_client::{ }; use crate::AllLanguageModelSettings; -use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic}; -use crate::provider::google::into_google; -use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai}; +use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic}; +use crate::provider::google::{GoogleEventMapper, into_google}; +use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai}; pub const PROVIDER_NAME: &str = "Zed"; @@ -518,7 +518,7 @@ impl CloudLanguageModel { client: Arc, llm_api_token: LlmApiToken, body: CompletionBody, - ) -> Result<(Response, Option)> { + ) -> Result<(Response, Option, bool)> { let http_client = &client.http_client(); let mut token = llm_api_token.acquire(&client).await?; @@ -536,13 +536,18 @@ impl CloudLanguageModel { let request = request_builder .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {token}")) + .header("x-zed-client-supports-queueing", "true") .body(serde_json::to_string(&body)?.into())?; let mut response = http_client.send(request).await?; let status = response.status(); if status.is_success() { + let includes_queue_events = response + .headers() + .get("x-zed-server-supports-queueing") + .is_some(); let usage = RequestUsage::from_headers(response.headers()).ok(); - return Ok((response, usage)); + return Ok((response, usage, includes_queue_events)); } else if response .headers() .get(EXPIRED_LLM_TOKEN_HEADER_NAME) @@ -782,7 +787,7 @@ impl LanguageModel for CloudLanguageModel { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); let future = self.request_limiter.stream_with_usage(async move { - let (response, usage) = Self::perform_llm_completion( + let (response, usage, includes_queue_events) = Self::perform_llm_completion( client.clone(), llm_api_token, CompletionBody { @@ -811,9 +816,11 @@ impl LanguageModel for CloudLanguageModel { Err(err) => anyhow!(err), })?; + let mut mapper = AnthropicEventMapper::new(); Ok(( - crate::provider::anthropic::map_to_language_model_completion_events( - Box::pin(response_lines(response).map_err(AnthropicError::Other)), + map_cloud_completion_events( + Box::pin(response_lines(response, includes_queue_events)), + move |event| mapper.map_event(event), ), usage, )) @@ -829,7 +836,7 @@ impl LanguageModel for CloudLanguageModel { let request = into_open_ai(request, model, model.max_output_tokens()); let llm_api_token = self.llm_api_token.clone(); let future = self.request_limiter.stream_with_usage(async move { - let (response, usage) = Self::perform_llm_completion( + let (response, usage, includes_queue_events) = Self::perform_llm_completion( client.clone(), llm_api_token, CompletionBody { @@ -842,9 +849,12 @@ impl LanguageModel for CloudLanguageModel { }, ) .await?; + + let mut mapper = OpenAiEventMapper::new(); Ok(( - crate::provider::open_ai::map_to_language_model_completion_events( - Box::pin(response_lines(response)), + map_cloud_completion_events( + Box::pin(response_lines(response, includes_queue_events)), + move |event| mapper.map_event(event), ), usage, )) @@ -860,7 +870,7 @@ impl LanguageModel for CloudLanguageModel { let request = into_google(request, model.id().into()); let llm_api_token = self.llm_api_token.clone(); let future = self.request_limiter.stream_with_usage(async move { - let (response, usage) = Self::perform_llm_completion( + let (response, usage, includes_queue_events) = Self::perform_llm_completion( client.clone(), llm_api_token, CompletionBody { @@ -873,10 +883,12 @@ impl LanguageModel for CloudLanguageModel { }, ) .await?; + let mut mapper = GoogleEventMapper::new(); Ok(( - crate::provider::google::map_to_language_model_completion_events(Box::pin( - response_lines(response), - )), + map_cloud_completion_events( + Box::pin(response_lines(response, includes_queue_events)), + move |event| mapper.map_event(event), + ), usage, )) }); @@ -890,16 +902,54 @@ impl LanguageModel for CloudLanguageModel { } } +#[derive(Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum CloudCompletionEvent { + Queue(QueueState), + Event(T), +} + +fn map_cloud_completion_events( + stream: Pin>> + Send>>, + mut map_callback: F, +) -> BoxStream<'static, Result> +where + T: DeserializeOwned + 'static, + F: FnMut(T) -> Vec> + + Send + + 'static, +{ + stream + .flat_map(move |event| { + futures::stream::iter(match event { + Err(error) => { + vec![Err(LanguageModelCompletionError::Other(error))] + } + Ok(CloudCompletionEvent::Queue(event)) => { + vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))] + } + Ok(CloudCompletionEvent::Event(event)) => map_callback(event), + }) + }) + .boxed() +} + fn response_lines( response: Response, -) -> impl Stream> { + includes_queue_events: bool, +) -> impl Stream>> { futures::stream::try_unfold( (String::new(), BufReader::new(response.into_body())), - move |(mut line, mut body)| async { + move |(mut line, mut body)| async move { match body.read_line(&mut line).await { Ok(0) => Ok(None), Ok(_) => { - let event: T = serde_json::from_str(&line)?; + let event = if includes_queue_events { + serde_json::from_str::>(&line)? + } else { + CloudCompletionEvent::Event(serde_json::from_str::(&line)?) + }; + line.clear(); Ok(Some((event, (line, body)))) } diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 1bb0df310e3d56ad09254a2f4d6f8c5f3cf3a8a7..b5751bb35977ca4696ae1401c1c56546c30ed512 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -24,7 +24,10 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::pin::Pin; -use std::sync::Arc; +use std::sync::{ + Arc, + atomic::{self, AtomicU64}, +}; use strum::IntoEnumIterator; use theme::ThemeSettings; use ui::{Icon, IconName, List, Tooltip, prelude::*}; @@ -371,7 +374,7 @@ impl LanguageModel for GoogleLanguageModel { let response = request .await .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?; - Ok(map_to_language_model_completion_events(response)) + Ok(GoogleEventMapper::new().map_stream(response)) }); async move { Ok(future.await?.boxed()) }.boxed() } @@ -486,108 +489,98 @@ pub fn into_google( } } -pub fn map_to_language_model_completion_events( - events: Pin>>>, -) -> impl Stream> { - use std::sync::atomic::{AtomicU64, Ordering}; +pub struct GoogleEventMapper { + usage: UsageMetadata, + stop_reason: StopReason, +} - static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); +impl GoogleEventMapper { + pub fn new() -> Self { + Self { + usage: UsageMetadata::default(), + stop_reason: StopReason::EndTurn, + } + } - struct State { + pub fn map_stream( + mut self, events: Pin>>>, - usage: UsageMetadata, - stop_reason: StopReason, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], + }) + }) } - futures::stream::unfold( - State { - events, - usage: UsageMetadata::default(), - stop_reason: StopReason::EndTurn, - }, - |mut state| async move { - if let Some(event) = state.events.next().await { - match event { - Ok(event) => { - let mut events: Vec<_> = Vec::new(); - let mut wants_to_use_tool = false; - if let Some(usage_metadata) = event.usage_metadata { - update_usage(&mut state.usage, &usage_metadata); - events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( - convert_usage(&state.usage), - ))) - } - if let Some(candidates) = event.candidates { - for candidate in candidates { - if let Some(finish_reason) = candidate.finish_reason.as_deref() { - state.stop_reason = match finish_reason { - "STOP" => StopReason::EndTurn, - "MAX_TOKENS" => StopReason::MaxTokens, - _ => { - log::error!( - "Unexpected google finish_reason: {finish_reason}" - ); - StopReason::EndTurn - } - }; - } - candidate - .content - .parts - .into_iter() - .for_each(|part| match part { - Part::TextPart(text_part) => events.push(Ok( - LanguageModelCompletionEvent::Text(text_part.text), - )), - Part::InlineDataPart(_) => {} - Part::FunctionCallPart(function_call_part) => { - wants_to_use_tool = true; - let name: Arc = - function_call_part.function_call.name.into(); - let next_tool_id = - TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst); - let id: LanguageModelToolUseId = - format!("{}-{}", name, next_tool_id).into(); - - events.push(Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id, - name, - is_input_complete: true, - raw_input: function_call_part - .function_call - .args - .to_string(), - input: function_call_part.function_call.args, - }, - ))); - } - Part::FunctionResponsePart(_) => {} - }); - } - } + pub fn map_event( + &mut self, + event: GenerateContentResponse, + ) -> Vec> { + static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); - // Even when Gemini wants to use a Tool, the API - // responds with `finish_reason: STOP` - if wants_to_use_tool { - state.stop_reason = StopReason::ToolUse; + let mut events: Vec<_> = Vec::new(); + let mut wants_to_use_tool = false; + if let Some(usage_metadata) = event.usage_metadata { + update_usage(&mut self.usage, &usage_metadata); + events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_usage(&self.usage), + ))) + } + if let Some(candidates) = event.candidates { + for candidate in candidates { + if let Some(finish_reason) = candidate.finish_reason.as_deref() { + self.stop_reason = match finish_reason { + "STOP" => StopReason::EndTurn, + "MAX_TOKENS" => StopReason::MaxTokens, + _ => { + log::error!("Unexpected google finish_reason: {finish_reason}"); + StopReason::EndTurn } - events.push(Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))); - return Some((events, state)); - } - Err(err) => { - return Some(( - vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))], - state, - )); - } + }; } + candidate + .content + .parts + .into_iter() + .for_each(|part| match part { + Part::TextPart(text_part) => { + events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text))) + } + Part::InlineDataPart(_) => {} + Part::FunctionCallPart(function_call_part) => { + wants_to_use_tool = true; + let name: Arc = function_call_part.function_call.name.into(); + let next_tool_id = + TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst); + let id: LanguageModelToolUseId = + format!("{}-{}", name, next_tool_id).into(); + + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id, + name, + is_input_complete: true, + raw_input: function_call_part.function_call.args.to_string(), + input: function_call_part.function_call.args, + }, + ))); + } + Part::FunctionResponsePart(_) => {} + }); } + } - None - }, - ) - .flat_map(futures::stream::iter) + // Even when Gemini wants to use a Tool, the API + // responds with `finish_reason: STOP` + if wants_to_use_tool { + self.stop_reason = StopReason::ToolUse; + } + events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))); + events + } } pub fn count_google_tokens( diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 54f27b1727bda7799327278122391cc1634c5233..78d183cb28aa47cb821e3d804f20dceced6b880f 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -330,8 +330,11 @@ impl LanguageModel for OpenAiLanguageModel { > { let request = into_open_ai(request, &self.model, self.max_output_tokens()); let completions = self.stream_completion(request, cx); - async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) } - .boxed() + async move { + let mapper = OpenAiEventMapper::new(); + Ok(mapper.map_stream(completions.await?).boxed()) + } + .boxed() } } @@ -422,123 +425,108 @@ pub fn into_open_ai( } } -pub fn map_to_language_model_completion_events( - events: Pin>>>, -) -> impl Stream> { - #[derive(Default)] - struct RawToolCall { - id: String, - name: String, - arguments: String, +pub struct OpenAiEventMapper { + tool_calls_by_index: HashMap, +} + +impl OpenAiEventMapper { + pub fn new() -> Self { + Self { + tool_calls_by_index: HashMap::default(), + } } - struct State { + pub fn map_stream( + mut self, events: Pin>>>, - tool_calls_by_index: HashMap, + ) -> impl Stream> + { + events.flat_map(move |event| { + futures::stream::iter(match event { + Ok(event) => self.map_event(event), + Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))], + }) + }) } - futures::stream::unfold( - State { - events, - tool_calls_by_index: HashMap::default(), - }, - |mut state| async move { - if let Some(event) = state.events.next().await { - match event { - Ok(event) => { - let Some(choice) = event.choices.first() else { - return Some(( - vec![Err(LanguageModelCompletionError::Other(anyhow!( - "Response contained no choices" - )))], - state, - )); - }; - - let mut events = Vec::new(); - if let Some(content) = choice.delta.content.clone() { - events.push(Ok(LanguageModelCompletionEvent::Text(content))); - } + pub fn map_event( + &mut self, + event: ResponseStreamEvent, + ) -> Vec> { + let Some(choice) = event.choices.first() else { + return vec![Err(LanguageModelCompletionError::Other(anyhow!( + "Response contained no choices" + )))]; + }; - if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { - for tool_call in tool_calls { - let entry = state - .tool_calls_by_index - .entry(tool_call.index) - .or_default(); - - if let Some(tool_id) = tool_call.id.clone() { - entry.id = tool_id; - } - - if let Some(function) = tool_call.function.as_ref() { - if let Some(name) = function.name.clone() { - entry.name = name; - } - - if let Some(arguments) = function.arguments.clone() { - entry.arguments.push_str(&arguments); - } - } - } - } + let mut events = Vec::new(); + if let Some(content) = choice.delta.content.clone() { + events.push(Ok(LanguageModelCompletionEvent::Text(content))); + } - match choice.finish_reason.as_deref() { - Some("stop") => { - events.push(Ok(LanguageModelCompletionEvent::Stop( - StopReason::EndTurn, - ))); - } - Some("tool_calls") => { - events.extend(state.tool_calls_by_index.drain().map( - |(_, tool_call)| match serde_json::Value::from_str( - &tool_call.arguments, - ) { - Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_call.id.clone().into(), - name: tool_call.name.as_str().into(), - is_input_complete: true, - input, - raw_input: tool_call.arguments.clone(), - }, - )), - Err(error) => { - Err(LanguageModelCompletionError::BadInputJson { - id: tool_call.id.into(), - tool_name: tool_call.name.as_str().into(), - raw_input: tool_call.arguments.into(), - json_parse_error: error.to_string(), - }) - } - }, - )); - - events.push(Ok(LanguageModelCompletionEvent::Stop( - StopReason::ToolUse, - ))); - } - Some(stop_reason) => { - log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); - events.push(Ok(LanguageModelCompletionEvent::Stop( - StopReason::EndTurn, - ))); - } - None => {} - } + if let Some(tool_calls) = choice.delta.tool_calls.as_ref() { + for tool_call in tool_calls { + let entry = self.tool_calls_by_index.entry(tool_call.index).or_default(); + + if let Some(tool_id) = tool_call.id.clone() { + entry.id = tool_id; + } - return Some((events, state)); + if let Some(function) = tool_call.function.as_ref() { + if let Some(name) = function.name.clone() { + entry.name = name; } - Err(err) => { - return Some((vec![Err(LanguageModelCompletionError::Other(err))], state)); + + if let Some(arguments) = function.arguments.clone() { + entry.arguments.push_str(&arguments); } } } + } - None - }, - ) - .flat_map(futures::stream::iter) + match choice.finish_reason.as_deref() { + Some("stop") => { + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + Some("tool_calls") => { + events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| { + match serde_json::Value::from_str(&tool_call.arguments) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_call.id.clone().into(), + name: tool_call.name.as_str().into(), + is_input_complete: true, + input, + raw_input: tool_call.arguments.clone(), + }, + )), + Err(error) => Err(LanguageModelCompletionError::BadInputJson { + id: tool_call.id.into(), + tool_name: tool_call.name.as_str().into(), + raw_input: tool_call.arguments.into(), + json_parse_error: error.to_string(), + }), + } + })); + + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse))); + } + Some(stop_reason) => { + log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",); + events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))); + } + None => {} + } + + events + } +} + +#[derive(Default)] +struct RawToolCall { + id: String, + name: String, + arguments: String, } pub fn count_open_ai_tokens(