Detailed changes
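
This change threads queue-state reporting through the whole completion pipeline. While a request waits in Zed's cloud queue, the agent panel now shows a "Sending" spinner and then a "Queue position: N" label; a new `LanguageModelCompletionEvent::QueueUpdate` variant carries that state from the server to the thread; and the per-provider stream adapters are rewritten from `futures::stream::unfold` free functions into reusable event-mapper structs so the cloud provider can drive them one event at a time. The UI side comes first: `ActiveThread` picks a label based on the thread's queue state.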
@@ -4,8 +4,8 @@ use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::message_editor::insert_message_creases;
use crate::thread::{
- LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, Thread, ThreadError,
- ThreadEvent, ThreadFeedback,
+ LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, QueueState, Thread,
+ ThreadError, ThreadEvent, ThreadFeedback,
};
use crate::thread_store::{RulesLoadingError, ThreadStore};
use crate::tool_use::{PendingToolUseStatus, ToolUse};
@@ -1733,8 +1733,27 @@ impl ActiveThread {
let show_feedback = thread.is_turn_end(ix);
- let generating_label = (is_generating && is_last_message)
- .then(|| AnimatedLabel::new("Generating").size(LabelSize::Small));
+ let generating_label = is_last_message
+ .then(|| match (thread.queue_state(), is_generating) {
+ (Some(QueueState::Sending), _) => Some(
+ AnimatedLabel::new("Sending")
+ .size(LabelSize::Small)
+ .into_any_element(),
+ ),
+ (Some(QueueState::Queued { position }), _) => Some(
+ Label::new(format!("Queue position: {position}"))
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .into_any_element(),
+ ),
+ (_, true) => Some(
+ AnimatedLabel::new("Generating")
+ .size(LabelSize::Small)
+ .into_any_element(),
+ ),
+ _ => None,
+ })
+ .flatten();
let editing_message_state = self
.editing_message
@@ -2105,7 +2124,7 @@ impl ActiveThread {
parent.child(self.render_rules_item(cx))
})
.child(styled_message)
- .when(generating_label.is_some(), |this| {
+ .when_some(generating_label, |this, generating_label| {
this.child(
h_flex()
.h_8()
@@ -2113,7 +2132,7 @@ impl ActiveThread {
.mb_4()
.ml_4()
.py_1p5()
- .child(generating_label.unwrap()),
+ .child(generating_label),
)
})
.when(show_feedback, move |parent| {
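
On the thread itself, a new `QueueState` enum tracks where the first pending completion sits: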
@@ -320,6 +320,13 @@ fn default_completion_mode(cx: &App) -> CompletionMode {
}
}
+#[derive(Debug, Clone, Copy)]
+pub enum QueueState {
+ Sending,
+ Queued { position: usize },
+ Started,
+}
+
/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
@@ -625,6 +632,12 @@ impl Thread {
!self.pending_completions.is_empty() || !self.all_tools_finished()
}
+ pub fn queue_state(&self) -> Option<QueueState> {
+ self.pending_completions
+ .first()
+ .map(|pending_completion| pending_completion.queue_state)
+ }
+
pub fn tools(&self) -> &Entity<ToolWorkingSet> {
&self.tools
}
@@ -1470,6 +1483,20 @@ impl Thread {
});
}
}
+ LanguageModelCompletionEvent::QueueUpdate(queue_event) => {
+ if let Some(completion) = thread
+ .pending_completions
+ .iter_mut()
+ .find(|completion| completion.id == pending_completion_id)
+ {
+ completion.queue_state = match queue_event {
+ language_model::QueueState::Queued { position } => {
+ QueueState::Queued { position }
+ }
+ language_model::QueueState::Started => QueueState::Started,
+ }
+ }
+ }
}
thread.touch_updated_at();
@@ -1590,6 +1617,7 @@ impl Thread {
self.pending_completions.push(PendingCompletion {
id: pending_completion_id,
+ queue_state: QueueState::Sending,
_task: task,
});
}
@@ -2499,6 +2527,7 @@ impl EventEmitter<ThreadEvent> for Thread {}
struct PendingCompletion {
id: usize,
+ queue_state: QueueState,
_task: Task<()>,
}
@@ -2371,6 +2371,7 @@ impl AssistantContext {
});
match event {
+ LanguageModelCompletionEvent::QueueUpdate { .. } => {}
LanguageModelCompletionEvent::StartMessage { .. } => {}
LanguageModelCompletionEvent::Stop(reason) => {
stop_reason = reason;
@@ -1017,7 +1017,8 @@ pub fn response_events_to_markdown(
}
Ok(
LanguageModelCompletionEvent::UsageUpdate(_)
- | LanguageModelCompletionEvent::StartMessage { .. },
+ | LanguageModelCompletionEvent::StartMessage { .. }
+ | LanguageModelCompletionEvent::QueueUpdate { .. },
) => {}
Err(error) => {
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
@@ -1092,6 +1093,7 @@ impl ThreadDialog {
// Skip these
Ok(LanguageModelCompletionEvent::UsageUpdate(_))
+ | Ok(LanguageModelCompletionEvent::QueueUpdate { .. })
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
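
The protocol-level type lives in the `language_model` crate, with a serialized representation shared with the server: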
@@ -64,9 +64,17 @@ pub struct LanguageModelCacheConfiguration {
pub min_total_token: usize,
}
+#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
+#[serde(tag = "status", rename_all = "snake_case")]
+pub enum QueueState {
+ Queued { position: usize },
+ Started,
+}
+
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
+ QueueUpdate(QueueState),
Stop(StopReason),
Text(String),
Thinking {
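
Since `QueueState` is internally tagged on `status`, a hedged sketch of the wire format (assuming `serde_json`; not part of the diff):

    // QueueState::Queued { position: 3 } serializes as:
    //   {"status":"queued","position":3}
    // QueueState::Started serializes as:
    //   {"status":"started"}
    assert_eq!(
        serde_json::to_string(&QueueState::Queued { position: 3 }).unwrap(),
        r#"{"status":"queued","position":3}"#
    );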
@@ -349,6 +357,7 @@ pub trait LanguageModel: Send + Sync {
let last_token_usage = last_token_usage.clone();
async move {
match result {
+ Ok(LanguageModelCompletionEvent::QueueUpdate { .. }) => None,
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
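
In the Anthropic provider, the `unfold`-based `map_to_language_model_completion_events` becomes a stateful `AnthropicEventMapper`; the event-handling logic is unchanged, but it is now callable one event at a time: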
@@ -469,7 +469,7 @@ impl LanguageModel for AnthropicModel {
Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
Err(err) => anyhow!(err),
})?;
- Ok(map_to_language_model_completion_events(response))
+ Ok(AnthropicEventMapper::new().map_stream(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -629,215 +629,186 @@ pub fn into_anthropic(
}
}
-pub fn map_to_language_model_completion_events(
- events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- struct RawToolUse {
- id: String,
- name: String,
- input_json: String,
+pub struct AnthropicEventMapper {
+ tool_uses_by_index: HashMap<usize, RawToolUse>,
+ usage: Usage,
+ stop_reason: StopReason,
+}
+
+impl AnthropicEventMapper {
+ pub fn new() -> Self {
+ Self {
+ tool_uses_by_index: HashMap::default(),
+ usage: Usage::default(),
+ stop_reason: StopReason::EndTurn,
+ }
}
- struct State {
+ pub fn map_stream(
+ mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
- tool_uses_by_index: HashMap<usize, RawToolUse>,
- usage: Usage,
- stop_reason: StopReason,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ })
+ })
}
- futures::stream::unfold(
- State {
- events,
- tool_uses_by_index: HashMap::default(),
- usage: Usage::default(),
- stop_reason: StopReason::EndTurn,
- },
- |mut state| async move {
- while let Some(event) = state.events.next().await {
- match event {
- Ok(event) => match event {
- Event::ContentBlockStart {
- index,
- content_block,
- } => match content_block {
- ResponseContent::Text { text } => {
- return Some((
- vec![Ok(LanguageModelCompletionEvent::Text(text))],
- state,
- ));
- }
- ResponseContent::Thinking { thinking } => {
- return Some((
- vec![Ok(LanguageModelCompletionEvent::Thinking {
- text: thinking,
- signature: None,
- })],
- state,
- ));
- }
- ResponseContent::RedactedThinking { .. } => {
- // Redacted thinking is encrypted and not accessible to the user, see:
- // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
- }
- ResponseContent::ToolUse { id, name, .. } => {
- state.tool_uses_by_index.insert(
- index,
- RawToolUse {
- id,
- name,
- input_json: String::new(),
- },
- );
- }
- },
- Event::ContentBlockDelta { index, delta } => match delta {
- ContentDelta::TextDelta { text } => {
- return Some((
- vec![Ok(LanguageModelCompletionEvent::Text(text))],
- state,
- ));
- }
- ContentDelta::ThinkingDelta { thinking } => {
- return Some((
- vec![Ok(LanguageModelCompletionEvent::Thinking {
- text: thinking,
- signature: None,
- })],
- state,
- ));
- }
- ContentDelta::SignatureDelta { signature } => {
- return Some((
- vec![Ok(LanguageModelCompletionEvent::Thinking {
- text: "".to_string(),
- signature: Some(signature),
- })],
- state,
- ));
- }
- ContentDelta::InputJsonDelta { partial_json } => {
- if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
- tool_use.input_json.push_str(&partial_json);
-
- // Try to convert invalid (incomplete) JSON into
- // valid JSON that serde can accept, e.g. by closing
- // unclosed delimiters. This way, we can update the
- // UI with whatever has been streamed back so far.
- if let Ok(input) = serde_json::Value::from_str(
- &partial_json_fixer::fix_json(&tool_use.input_json),
- ) {
- return Some((
- vec![Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_use.id.clone().into(),
- name: tool_use.name.clone().into(),
- is_input_complete: false,
- raw_input: tool_use.input_json.clone(),
- input,
- },
- ))],
- state,
- ));
- }
- }
- }
+ pub fn map_event(
+ &mut self,
+ event: Event,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ match event {
+ Event::ContentBlockStart {
+ index,
+ content_block,
+ } => match content_block {
+ ResponseContent::Text { text } => {
+ vec![Ok(LanguageModelCompletionEvent::Text(text))]
+ }
+ ResponseContent::Thinking { thinking } => {
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: thinking,
+ signature: None,
+ })]
+ }
+ ResponseContent::RedactedThinking { .. } => {
+ // Redacted thinking is encrypted and not accessible to the user, see:
+ // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#suggestions-for-handling-redacted-thinking-in-production
+ Vec::new()
+ }
+ ResponseContent::ToolUse { id, name, .. } => {
+ self.tool_uses_by_index.insert(
+ index,
+ RawToolUse {
+ id,
+ name,
+ input_json: String::new(),
},
- Event::ContentBlockStop { index } => {
- if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
- let input_json = tool_use.input_json.trim();
- let input_value = if input_json.is_empty() {
- Ok(serde_json::Value::Object(serde_json::Map::default()))
- } else {
- serde_json::Value::from_str(input_json)
- };
- let event_result = match input_value {
- Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_use.id.into(),
- name: tool_use.name.into(),
- is_input_complete: true,
- input,
- raw_input: tool_use.input_json.clone(),
- },
- )),
- Err(json_parse_err) => {
- Err(LanguageModelCompletionError::BadInputJson {
- id: tool_use.id.into(),
- tool_name: tool_use.name.into(),
- raw_input: input_json.into(),
- json_parse_error: json_parse_err.to_string(),
- })
- }
- };
-
- return Some((vec![event_result], state));
- }
- }
- Event::MessageStart { message } => {
- update_usage(&mut state.usage, &message.usage);
- return Some((
- vec![
- Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
- &state.usage,
- ))),
- Ok(LanguageModelCompletionEvent::StartMessage {
- message_id: message.id,
- }),
- ],
- state,
- ));
- }
- Event::MessageDelta { delta, usage } => {
- update_usage(&mut state.usage, &usage);
- if let Some(stop_reason) = delta.stop_reason.as_deref() {
- state.stop_reason = match stop_reason {
- "end_turn" => StopReason::EndTurn,
- "max_tokens" => StopReason::MaxTokens,
- "tool_use" => StopReason::ToolUse,
- _ => {
- log::error!(
- "Unexpected anthropic stop_reason: {stop_reason}"
- );
- StopReason::EndTurn
- }
- };
- }
- return Some((
- vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
- convert_usage(&state.usage),
- ))],
- state,
- ));
- }
- Event::MessageStop => {
- return Some((
- vec![Ok(LanguageModelCompletionEvent::Stop(state.stop_reason))],
- state,
- ));
- }
- Event::Error { error } => {
- return Some((
- vec![Err(LanguageModelCompletionError::Other(anyhow!(
- AnthropicError::ApiError(error)
- )))],
- state,
- ));
+ );
+ Vec::new()
+ }
+ },
+ Event::ContentBlockDelta { index, delta } => match delta {
+ ContentDelta::TextDelta { text } => {
+ vec![Ok(LanguageModelCompletionEvent::Text(text))]
+ }
+ ContentDelta::ThinkingDelta { thinking } => {
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: thinking,
+ signature: None,
+ })]
+ }
+ ContentDelta::SignatureDelta { signature } => {
+ vec![Ok(LanguageModelCompletionEvent::Thinking {
+ text: "".to_string(),
+ signature: Some(signature),
+ })]
+ }
+ ContentDelta::InputJsonDelta { partial_json } => {
+ if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
+ tool_use.input_json.push_str(&partial_json);
+
+ // Try to convert invalid (incomplete) JSON into
+ // valid JSON that serde can accept, e.g. by closing
+ // unclosed delimiters. This way, we can update the
+ // UI with whatever has been streamed back so far.
+ if let Ok(input) = serde_json::Value::from_str(
+ &partial_json_fixer::fix_json(&tool_use.input_json),
+ ) {
+ return vec![Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_use.id.clone().into(),
+ name: tool_use.name.clone().into(),
+ is_input_complete: false,
+ raw_input: tool_use.input_json.clone(),
+ input,
+ },
+ ))];
}
- _ => {}
- },
- Err(err) => {
- return Some((
- vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
- state,
- ));
}
+ return vec![];
+ }
+ },
+ Event::ContentBlockStop { index } => {
+ if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
+ let input_json = tool_use.input_json.trim();
+ let input_value = if input_json.is_empty() {
+ Ok(serde_json::Value::Object(serde_json::Map::default()))
+ } else {
+ serde_json::Value::from_str(input_json)
+ };
+ let event_result = match input_value {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_use.id.into(),
+ name: tool_use.name.into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_use.input_json.clone(),
+ },
+ )),
+ Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
+ id: tool_use.id.into(),
+ tool_name: tool_use.name.into(),
+ raw_input: input_json.into(),
+ json_parse_error: json_parse_err.to_string(),
+ }),
+ };
+
+ vec![event_result]
+ } else {
+ Vec::new()
+ }
+ }
+ Event::MessageStart { message } => {
+ update_usage(&mut self.usage, &message.usage);
+ vec![
+ Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
+ &self.usage,
+ ))),
+ Ok(LanguageModelCompletionEvent::StartMessage {
+ message_id: message.id,
+ }),
+ ]
+ }
+ Event::MessageDelta { delta, usage } => {
+ update_usage(&mut self.usage, &usage);
+ if let Some(stop_reason) = delta.stop_reason.as_deref() {
+ self.stop_reason = match stop_reason {
+ "end_turn" => StopReason::EndTurn,
+ "max_tokens" => StopReason::MaxTokens,
+ "tool_use" => StopReason::ToolUse,
+ _ => {
+ log::error!("Unexpected anthropic stop_reason: {stop_reason}");
+ StopReason::EndTurn
+ }
+ };
}
+ vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
+ convert_usage(&self.usage),
+ ))]
}
+ Event::MessageStop => {
+ vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
+ }
+ Event::Error { error } => {
+ vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ AnthropicError::ApiError(error)
+ )))]
+ }
+ _ => Vec::new(),
+ }
+ }
+}
- None
- },
- )
- .flat_map(futures::stream::iter)
+struct RawToolUse {
+ id: String,
+ name: String,
+ input_json: String,
}
pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
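
The mapper exposes two entry points, matching its two call sites in this diff; an illustrative sketch (variable names are placeholders):

    // Direct Anthropic API path: map an entire event stream.
    let events = AnthropicEventMapper::new().map_stream(anthropic_events);

    // Zed cloud path: events arrive already deserialized, one line at a
    // time, so the mapper is driven per event instead.
    let mut mapper = AnthropicEventMapper::new();
    let mapped = mapper.map_event(event);

The cloud provider below is the consumer of that second entry point.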
@@ -1,11 +1,10 @@
-use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
+use anthropic::{AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::{Client, UserStore, zed_urls};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
use futures::{
- AsyncBufReadExt, FutureExt, Stream, StreamExt, TryStreamExt as _, future::BoxFuture,
- stream::BoxStream,
+ AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
@@ -14,7 +13,7 @@ use language_model::{
LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
- ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
+ ModelRequestLimitReachedError, QueueState, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -26,6 +25,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::{Settings, SettingsStore};
use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
+use std::pin::Pin;
use std::str::FromStr as _;
use std::{
sync::{Arc, LazyLock},
@@ -41,9 +41,9 @@ use zed_llm_client::{
};
use crate::AllLanguageModelSettings;
-use crate::provider::anthropic::{count_anthropic_tokens, into_anthropic};
-use crate::provider::google::into_google;
-use crate::provider::open_ai::{count_open_ai_tokens, into_open_ai};
+use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
+use crate::provider::google::{GoogleEventMapper, into_google};
+use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
pub const PROVIDER_NAME: &str = "Zed";
@@ -518,7 +518,7 @@ impl CloudLanguageModel {
client: Arc<Client>,
llm_api_token: LlmApiToken,
body: CompletionBody,
- ) -> Result<(Response<AsyncBody>, Option<RequestUsage>)> {
+ ) -> Result<(Response<AsyncBody>, Option<RequestUsage>, bool)> {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
@@ -536,13 +536,18 @@ impl CloudLanguageModel {
let request = request_builder
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
+ .header("x-zed-client-supports-queueing", "true")
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client.send(request).await?;
let status = response.status();
if status.is_success() {
+ let includes_queue_events = response
+ .headers()
+ .get("x-zed-server-supports-queueing")
+ .is_some();
let usage = RequestUsage::from_headers(response.headers()).ok();
- return Ok((response, usage));
+ return Ok((response, usage, includes_queue_events));
} else if response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
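
Queue events are negotiated per request: the client advertises support via the `x-zed-client-supports-queueing` request header and only treats the stream as queue-aware if the server replies with `x-zed-server-supports-queueing`. Older servers omit the header, and the body is parsed as plain provider events (see `response_lines` below).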
@@ -782,7 +787,7 @@ impl LanguageModel for CloudLanguageModel {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream_with_usage(async move {
- let (response, usage) = Self::perform_llm_completion(
+ let (response, usage, includes_queue_events) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@@ -811,9 +816,11 @@ impl LanguageModel for CloudLanguageModel {
Err(err) => anyhow!(err),
})?;
+ let mut mapper = AnthropicEventMapper::new();
Ok((
- crate::provider::anthropic::map_to_language_model_completion_events(
- Box::pin(response_lines(response).map_err(AnthropicError::Other)),
+ map_cloud_completion_events(
+ Box::pin(response_lines(response, includes_queue_events)),
+ move |event| mapper.map_event(event),
),
usage,
))
@@ -829,7 +836,7 @@ impl LanguageModel for CloudLanguageModel {
let request = into_open_ai(request, model, model.max_output_tokens());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream_with_usage(async move {
- let (response, usage) = Self::perform_llm_completion(
+ let (response, usage, includes_queue_events) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@@ -842,9 +849,12 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
+
+ let mut mapper = OpenAiEventMapper::new();
Ok((
- crate::provider::open_ai::map_to_language_model_completion_events(
- Box::pin(response_lines(response)),
+ map_cloud_completion_events(
+ Box::pin(response_lines(response, includes_queue_events)),
+ move |event| mapper.map_event(event),
),
usage,
))
@@ -860,7 +870,7 @@ impl LanguageModel for CloudLanguageModel {
let request = into_google(request, model.id().into());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream_with_usage(async move {
- let (response, usage) = Self::perform_llm_completion(
+ let (response, usage, includes_queue_events) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
@@ -873,10 +883,12 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
+ let mut mapper = GoogleEventMapper::new();
Ok((
- crate::provider::google::map_to_language_model_completion_events(Box::pin(
- response_lines(response),
- )),
+ map_cloud_completion_events(
+ Box::pin(response_lines(response, includes_queue_events)),
+ move |event| mapper.map_event(event),
+ ),
usage,
))
});
@@ -890,16 +902,54 @@ impl LanguageModel for CloudLanguageModel {
}
}
+#[derive(Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum CloudCompletionEvent<T> {
+ Queue(QueueState),
+ Event(T),
+}
+
+fn map_cloud_completion_events<T, F>(
+ stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
+ mut map_callback: F,
+) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+where
+ T: DeserializeOwned + 'static,
+ F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ + Send
+ + 'static,
+{
+ stream
+ .flat_map(move |event| {
+ futures::stream::iter(match event {
+ Err(error) => {
+ vec![Err(LanguageModelCompletionError::Other(error))]
+ }
+ Ok(CloudCompletionEvent::Queue(event)) => {
+ vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
+ }
+ Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
+ })
+ })
+ .boxed()
+}
+
fn response_lines<T: DeserializeOwned>(
response: Response<AsyncBody>,
-) -> impl Stream<Item = Result<T>> {
+ includes_queue_events: bool,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
futures::stream::try_unfold(
(String::new(), BufReader::new(response.into_body())),
- move |(mut line, mut body)| async {
+ move |(mut line, mut body)| async move {
match body.read_line(&mut line).await {
Ok(0) => Ok(None),
Ok(_) => {
- let event: T = serde_json::from_str(&line)?;
+ let event = if includes_queue_events {
+ serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
+ } else {
+ CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
+ };
+
line.clear();
Ok(Some((event, (line, body))))
}
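
`CloudCompletionEvent` uses serde's default external tagging, so each line from a queue-aware server looks like one of the following (illustrative; the `...` stands for the provider-specific payload):

    {"queue":{"status":"queued","position":2}}
    {"queue":{"status":"started"}}
    {"event":{ ... }}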
@@ -24,7 +24,10 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
-use std::sync::Arc;
+use std::sync::{
+ Arc,
+ atomic::{self, AtomicU64},
+};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
@@ -371,7 +374,7 @@ impl LanguageModel for GoogleLanguageModel {
let response = request
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
- Ok(map_to_language_model_completion_events(response))
+ Ok(GoogleEventMapper::new().map_stream(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
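
The Google provider gets the same mapper refactor, with the tool-call counter moving from the stream closure into `map_event`: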
@@ -486,108 +489,98 @@ pub fn into_google(
}
}
-pub fn map_to_language_model_completion_events(
- events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- use std::sync::atomic::{AtomicU64, Ordering};
+pub struct GoogleEventMapper {
+ usage: UsageMetadata,
+ stop_reason: StopReason,
+}
- static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
+impl GoogleEventMapper {
+ pub fn new() -> Self {
+ Self {
+ usage: UsageMetadata::default(),
+ stop_reason: StopReason::EndTurn,
+ }
+ }
- struct State {
+ pub fn map_stream(
+ mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
- usage: UsageMetadata,
- stop_reason: StopReason,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ })
+ })
}
- futures::stream::unfold(
- State {
- events,
- usage: UsageMetadata::default(),
- stop_reason: StopReason::EndTurn,
- },
- |mut state| async move {
- if let Some(event) = state.events.next().await {
- match event {
- Ok(event) => {
- let mut events: Vec<_> = Vec::new();
- let mut wants_to_use_tool = false;
- if let Some(usage_metadata) = event.usage_metadata {
- update_usage(&mut state.usage, &usage_metadata);
- events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
- convert_usage(&state.usage),
- )))
- }
- if let Some(candidates) = event.candidates {
- for candidate in candidates {
- if let Some(finish_reason) = candidate.finish_reason.as_deref() {
- state.stop_reason = match finish_reason {
- "STOP" => StopReason::EndTurn,
- "MAX_TOKENS" => StopReason::MaxTokens,
- _ => {
- log::error!(
- "Unexpected google finish_reason: {finish_reason}"
- );
- StopReason::EndTurn
- }
- };
- }
- candidate
- .content
- .parts
- .into_iter()
- .for_each(|part| match part {
- Part::TextPart(text_part) => events.push(Ok(
- LanguageModelCompletionEvent::Text(text_part.text),
- )),
- Part::InlineDataPart(_) => {}
- Part::FunctionCallPart(function_call_part) => {
- wants_to_use_tool = true;
- let name: Arc<str> =
- function_call_part.function_call.name.into();
- let next_tool_id =
- TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
- let id: LanguageModelToolUseId =
- format!("{}-{}", name, next_tool_id).into();
-
- events.push(Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id,
- name,
- is_input_complete: true,
- raw_input: function_call_part
- .function_call
- .args
- .to_string(),
- input: function_call_part.function_call.args,
- },
- )));
- }
- Part::FunctionResponsePart(_) => {}
- });
- }
- }
+ pub fn map_event(
+ &mut self,
+ event: GenerateContentResponse,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
- // Even when Gemini wants to use a Tool, the API
- // responds with `finish_reason: STOP`
- if wants_to_use_tool {
- state.stop_reason = StopReason::ToolUse;
+ let mut events: Vec<_> = Vec::new();
+ let mut wants_to_use_tool = false;
+ if let Some(usage_metadata) = event.usage_metadata {
+ update_usage(&mut self.usage, &usage_metadata);
+ events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
+ convert_usage(&self.usage),
+ )))
+ }
+ if let Some(candidates) = event.candidates {
+ for candidate in candidates {
+ if let Some(finish_reason) = candidate.finish_reason.as_deref() {
+ self.stop_reason = match finish_reason {
+ "STOP" => StopReason::EndTurn,
+ "MAX_TOKENS" => StopReason::MaxTokens,
+ _ => {
+ log::error!("Unexpected google finish_reason: {finish_reason}");
+ StopReason::EndTurn
}
- events.push(Ok(LanguageModelCompletionEvent::Stop(state.stop_reason)));
- return Some((events, state));
- }
- Err(err) => {
- return Some((
- vec![Err(LanguageModelCompletionError::Other(anyhow!(err)))],
- state,
- ));
- }
+ };
}
+ candidate
+ .content
+ .parts
+ .into_iter()
+ .for_each(|part| match part {
+ Part::TextPart(text_part) => {
+ events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
+ }
+ Part::InlineDataPart(_) => {}
+ Part::FunctionCallPart(function_call_part) => {
+ wants_to_use_tool = true;
+ let name: Arc<str> = function_call_part.function_call.name.into();
+ let next_tool_id =
+ TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
+ let id: LanguageModelToolUseId =
+ format!("{}-{}", name, next_tool_id).into();
+
+ events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id,
+ name,
+ is_input_complete: true,
+ raw_input: function_call_part.function_call.args.to_string(),
+ input: function_call_part.function_call.args,
+ },
+ )));
+ }
+ Part::FunctionResponsePart(_) => {}
+ });
}
+ }
- None
- },
- )
- .flat_map(futures::stream::iter)
+ // Even when Gemini wants to use a Tool, the API
+ // responds with `finish_reason: STOP`
+ if wants_to_use_tool {
+ self.stop_reason = StopReason::ToolUse;
+ }
+ events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
+ events
+ }
}
pub fn count_google_tokens(
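
And the OpenAI provider follows suit: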
@@ -330,8 +330,11 @@ impl LanguageModel for OpenAiLanguageModel {
> {
let request = into_open_ai(request, &self.model, self.max_output_tokens());
let completions = self.stream_completion(request, cx);
- async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
- .boxed()
+ async move {
+ let mapper = OpenAiEventMapper::new();
+ Ok(mapper.map_stream(completions.await?).boxed())
+ }
+ .boxed()
}
}
@@ -422,123 +425,108 @@ pub fn into_open_ai(
}
}
-pub fn map_to_language_model_completion_events(
- events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
-) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
- #[derive(Default)]
- struct RawToolCall {
- id: String,
- name: String,
- arguments: String,
+pub struct OpenAiEventMapper {
+ tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
+
+impl OpenAiEventMapper {
+ pub fn new() -> Self {
+ Self {
+ tool_calls_by_index: HashMap::default(),
+ }
}
- struct State {
+ pub fn map_stream(
+ mut self,
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
- tool_calls_by_index: HashMap<usize, RawToolCall>,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ })
+ })
}
- futures::stream::unfold(
- State {
- events,
- tool_calls_by_index: HashMap::default(),
- },
- |mut state| async move {
- if let Some(event) = state.events.next().await {
- match event {
- Ok(event) => {
- let Some(choice) = event.choices.first() else {
- return Some((
- vec![Err(LanguageModelCompletionError::Other(anyhow!(
- "Response contained no choices"
- )))],
- state,
- ));
- };
-
- let mut events = Vec::new();
- if let Some(content) = choice.delta.content.clone() {
- events.push(Ok(LanguageModelCompletionEvent::Text(content)));
- }
+ pub fn map_event(
+ &mut self,
+ event: ResponseStreamEvent,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ let Some(choice) = event.choices.first() else {
+ return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ "Response contained no choices"
+ )))];
+ };
- if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
- for tool_call in tool_calls {
- let entry = state
- .tool_calls_by_index
- .entry(tool_call.index)
- .or_default();
-
- if let Some(tool_id) = tool_call.id.clone() {
- entry.id = tool_id;
- }
-
- if let Some(function) = tool_call.function.as_ref() {
- if let Some(name) = function.name.clone() {
- entry.name = name;
- }
-
- if let Some(arguments) = function.arguments.clone() {
- entry.arguments.push_str(&arguments);
- }
- }
- }
- }
+ let mut events = Vec::new();
+ if let Some(content) = choice.delta.content.clone() {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
- match choice.finish_reason.as_deref() {
- Some("stop") => {
- events.push(Ok(LanguageModelCompletionEvent::Stop(
- StopReason::EndTurn,
- )));
- }
- Some("tool_calls") => {
- events.extend(state.tool_calls_by_index.drain().map(
- |(_, tool_call)| match serde_json::Value::from_str(
- &tool_call.arguments,
- ) {
- Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_call.id.clone().into(),
- name: tool_call.name.as_str().into(),
- is_input_complete: true,
- input,
- raw_input: tool_call.arguments.clone(),
- },
- )),
- Err(error) => {
- Err(LanguageModelCompletionError::BadInputJson {
- id: tool_call.id.into(),
- tool_name: tool_call.name.as_str().into(),
- raw_input: tool_call.arguments.into(),
- json_parse_error: error.to_string(),
- })
- }
- },
- ));
-
- events.push(Ok(LanguageModelCompletionEvent::Stop(
- StopReason::ToolUse,
- )));
- }
- Some(stop_reason) => {
- log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
- events.push(Ok(LanguageModelCompletionEvent::Stop(
- StopReason::EndTurn,
- )));
- }
- None => {}
- }
+ if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+ for tool_call in tool_calls {
+ let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+ if let Some(tool_id) = tool_call.id.clone() {
+ entry.id = tool_id;
+ }
- return Some((events, state));
+ if let Some(function) = tool_call.function.as_ref() {
+ if let Some(name) = function.name.clone() {
+ entry.name = name;
}
- Err(err) => {
- return Some((vec![Err(LanguageModelCompletionError::Other(err))], state));
+
+ if let Some(arguments) = function.arguments.clone() {
+ entry.arguments.push_str(&arguments);
}
}
}
+ }
- None
- },
- )
- .flat_map(futures::stream::iter)
+ match choice.finish_reason.as_deref() {
+ Some("stop") => {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ Some("tool_calls") => {
+ events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+ match serde_json::Value::from_str(&tool_call.arguments) {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.clone().into(),
+ name: tool_call.name.as_str().into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments.clone(),
+ },
+ )),
+ Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.as_str().into(),
+ raw_input: tool_call.arguments.into(),
+ json_parse_error: error.to_string(),
+ }),
+ }
+ }));
+
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+ }
+ Some(stop_reason) => {
+ log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ None => {}
+ }
+
+ events
+ }
+}
+
+#[derive(Default)]
+struct RawToolCall {
+ id: String,
+ name: String,
+ arguments: String,
}
pub fn count_open_ai_tokens(
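
With that, all three providers expose the same `new()` / `map_event()` / `map_stream()` surface, which is exactly what `map_cloud_completion_events` needs in order to reuse them for cloud-routed completions.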