@@ -37,7 +37,7 @@ use settings::Settings;
use thiserror::Error;
use util::{ResultExt as _, TryFutureExt as _, post_inc};
use uuid::Uuid;
-use zed_llm_client::CompletionMode;
+use zed_llm_client::{CompletionMode, CompletionRequestStatus};
use crate::ThreadStore;
use crate::context::{AgentContext, AgentContextHandle, ContextLoadResult, LoadedContext};
@@ -1356,20 +1356,17 @@ impl Thread {
self.last_received_chunk_at = Some(Instant::now());
let task = cx.spawn(async move |thread, cx| {
- let stream_completion_future = model.stream_completion_with_usage(request, &cx);
+ let stream_completion_future = model.stream_completion(request, &cx);
let initial_token_usage =
thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
let stream_completion = async {
- let (mut events, usage) = stream_completion_future.await?;
+ let mut events = stream_completion_future.await?;
let mut stop_reason = StopReason::EndTurn;
let mut current_token_usage = TokenUsage::default();
thread
.update(cx, |_thread, cx| {
- if let Some(usage) = usage {
- cx.emit(ThreadEvent::UsageUpdated(usage));
- }
cx.emit(ThreadEvent::NewRequest);
})
.ok();
@@ -1515,27 +1512,34 @@ impl Thread {
});
}
}
- LanguageModelCompletionEvent::QueueUpdate(status) => {
+ LanguageModelCompletionEvent::StatusUpdate(status_update) => {
if let Some(completion) = thread
.pending_completions
.iter_mut()
.find(|completion| completion.id == pending_completion_id)
{
- let queue_state = match status {
- language_model::CompletionRequestStatus::Queued {
+ match status_update {
+ CompletionRequestStatus::Queued {
position,
- } => Some(QueueState::Queued { position }),
- language_model::CompletionRequestStatus::Started => {
- Some(QueueState::Started)
+ } => {
+ completion.queue_state = QueueState::Queued { position };
+ }
+ CompletionRequestStatus::Started => {
+ completion.queue_state = QueueState::Started;
+ }
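+ // A failure reported by the server terminates the stream with an error.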
+ CompletionRequestStatus::Failed { code, message } => {
+ return Err(anyhow!(
+ "completion request failed. code: {code}, message: {message}"
+ ));
+ }
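+ // Usage now arrives as an in-band status message rather than a separate return value.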
+ CompletionRequestStatus::UsageUpdated { amount, limit } => {
+ cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
+ limit,
+ amount: amount as i32,
+ }));
}
- language_model::CompletionRequestStatus::ToolUseLimitReached => {
+ CompletionRequestStatus::ToolUseLimitReached => {
thread.tool_use_limit_reached = true;
- None
}
- };
-
- if let Some(queue_state) = queue_state {
- completion.queue_state = queue_state;
}
}
}
@@ -1690,19 +1694,27 @@ impl Thread {
self.pending_summary = cx.spawn(async move |this, cx| {
async move {
- let stream = model.model.stream_completion_text_with_usage(request, &cx);
- let (mut messages, usage) = stream.await?;
-
- if let Some(usage) = usage {
- this.update(cx, |_thread, cx| {
- cx.emit(ThreadEvent::UsageUpdated(usage));
- })
- .ok();
- }
+ let mut messages = model.model.stream_completion(request, &cx).await?;
let mut new_summary = String::new();
- while let Some(message) = messages.stream.next().await {
- let text = message?;
+ while let Some(event) = messages.next().await {
+ let event = event?;
+ let text = match event {
+ LanguageModelCompletionEvent::Text(text) => text,
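+ // Re-emit usage updates that arrive while the summary is streaming.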
+ LanguageModelCompletionEvent::StatusUpdate(
+ CompletionRequestStatus::UsageUpdated { amount, limit },
+ ) => {
+ this.update(cx, |_, cx| {
+ cx.emit(ThreadEvent::UsageUpdated(RequestUsage {
+ limit,
+ amount: amount as i32,
+ }));
+ })?;
+ continue;
+ }
+ _ => continue,
+ };
+
let mut lines = text.lines();
new_summary.extend(lines.next());
@@ -26,7 +26,8 @@ use std::sync::Arc;
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::{
- MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
+ CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
+ MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};
pub use crate::model::*;
@@ -64,18 +65,10 @@ pub struct LanguageModelCacheConfiguration {
pub min_total_token: usize,
}
-#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
-#[serde(tag = "status", rename_all = "snake_case")]
-pub enum CompletionRequestStatus {
- Queued { position: usize },
- Started,
- ToolUseLimitReached,
-}
-
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
- QueueUpdate(CompletionRequestStatus),
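+ /// A status update about the request: queue position, usage, a failure, or the tool-use limit.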
+ StatusUpdate(CompletionRequestStatus),
Stop(StopReason),
Text(String),
Thinking {
@@ -299,41 +292,15 @@ pub trait LanguageModel: Send + Sync {
>,
>;
- fn stream_completion_with_usage(
- &self,
- request: LanguageModelRequest,
- cx: &AsyncApp,
- ) -> BoxFuture<
- 'static,
- Result<(
- BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
- Option<RequestUsage>,
- )>,
- > {
- self.stream_completion(request, cx)
- .map(|result| result.map(|stream| (stream, None)))
- .boxed()
- }
-
fn stream_completion_text(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
- self.stream_completion_text_with_usage(request, cx)
- .map(|result| result.map(|(stream, _usage)| stream))
- .boxed()
- }
-
- fn stream_completion_text_with_usage(
- &self,
- request: LanguageModelRequest,
- cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<(LanguageModelTextStream, Option<RequestUsage>)>> {
- let future = self.stream_completion_with_usage(request, cx);
+ let future = self.stream_completion(request, cx);
async move {
- let (events, usage) = future.await?;
+ let events = future.await?;
let mut events = events.fuse();
let mut message_id = None;
let mut first_item_text = None;
@@ -358,7 +325,7 @@ pub trait LanguageModel: Send + Sync {
let last_token_usage = last_token_usage.clone();
async move {
match result {
- Ok(LanguageModelCompletionEvent::QueueUpdate { .. }) => None,
+ Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
@@ -375,14 +342,11 @@ pub trait LanguageModel: Send + Sync {
}))
.boxed();
- Ok((
- LanguageModelTextStream {
- message_id,
- stream,
- last_token_usage,
- },
- usage,
- ))
+ Ok(LanguageModelTextStream {
+ message_id,
+ stream,
+ last_token_usage,
+ })
}
.boxed()
}
@@ -9,12 +9,11 @@ use futures::{
use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
- AuthenticateError, CloudModel, CompletionRequestStatus, LanguageModel,
- LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId,
- LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
- LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter, RequestUsage,
- ZED_CLOUD_PROVIDER_ID,
+ AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
+ LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
+ ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
};
use language_model::{
LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
@@ -36,9 +35,10 @@ use strum::IntoEnumIterator;
use thiserror::Error;
use ui::{TintColor, prelude::*};
use zed_llm_client::{
- CURRENT_PLAN_HEADER_NAME, CompletionBody, CountTokensBody, CountTokensResponse,
- EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
- MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
+ CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
+ MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
+ SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
TOOL_USE_LIMIT_REACHED_HEADER_NAME,
};
@@ -517,7 +517,7 @@ struct PerformLlmCompletionResponse {
response: Response<AsyncBody>,
usage: Option<RequestUsage>,
tool_use_limit_reached: bool,
- includes_queue_events: bool,
+ includes_status_messages: bool,
}
impl CloudLanguageModel {
@@ -545,25 +545,31 @@ impl CloudLanguageModel {
let request = request_builder
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
- .header("x-zed-client-supports-queueing", "true")
+ .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client.send(request).await?;
let status = response.status();
if status.is_success() {
- let includes_queue_events = response
+ let includes_status_messages = response
.headers()
- .get("x-zed-server-supports-queueing")
+ .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
.is_some();
+
let tool_use_limit_reached = response
.headers()
.get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
.is_some();
- let usage = RequestUsage::from_headers(response.headers()).ok();
+
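+ // When status messages are streamed, usage arrives in-band; otherwise fall back to the response headers.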
+ let usage = if includes_status_messages {
+ None
+ } else {
+ RequestUsage::from_headers(response.headers()).ok()
+ };
return Ok(PerformLlmCompletionResponse {
response,
usage,
- includes_queue_events,
+ includes_status_messages,
tool_use_limit_reached,
});
} else if response
@@ -767,28 +773,12 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
request: LanguageModelRequest,
- cx: &AsyncApp,
+ _cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
>,
- > {
- self.stream_completion_with_usage(request, cx)
- .map(|result| result.map(|(stream, _)| stream))
- .boxed()
- }
-
- fn stream_completion_with_usage(
- &self,
- request: LanguageModelRequest,
- _cx: &AsyncApp,
- ) -> BoxFuture<
- 'static,
- Result<(
- BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
- Option<RequestUsage>,
- )>,
> {
let thread_id = request.thread_id.clone();
let prompt_id = request.prompt_id.clone();
@@ -804,11 +794,11 @@ impl LanguageModel for CloudLanguageModel {
);
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream_with_usage(async move {
+ let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {
response,
usage,
- includes_queue_events,
+ includes_status_messages,
tool_use_limit_reached,
} = Self::perform_llm_completion(
client.clone(),
@@ -840,32 +830,26 @@ impl LanguageModel for CloudLanguageModel {
})?;
let mut mapper = AnthropicEventMapper::new();
- Ok((
- map_cloud_completion_events(
- Box::pin(
- response_lines(response, includes_queue_events)
- .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
- ),
- move |event| mapper.map_event(event),
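+ // Append synthetic usage and tool-use-limit events after the response stream completes.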
+ Ok(map_cloud_completion_events(
+ Box::pin(
+ response_lines(response, includes_status_messages)
+ .chain(usage_updated_event(usage))
+ .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
),
- usage,
+ move |event| mapper.map_event(event),
))
});
- async move {
- let (stream, usage) = future.await?;
- Ok((stream.boxed(), usage))
- }
- .boxed()
+ async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
let request = into_open_ai(request, model, model.max_output_tokens());
let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream_with_usage(async move {
+ let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {
response,
usage,
- includes_queue_events,
+ includes_status_messages,
tool_use_limit_reached,
} = Self::perform_llm_completion(
client.clone(),
@@ -882,32 +866,26 @@ impl LanguageModel for CloudLanguageModel {
.await?;
let mut mapper = OpenAiEventMapper::new();
- Ok((
- map_cloud_completion_events(
- Box::pin(
- response_lines(response, includes_queue_events)
- .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
- ),
- move |event| mapper.map_event(event),
+ Ok(map_cloud_completion_events(
+ Box::pin(
+ response_lines(response, includes_status_messages)
+ .chain(usage_updated_event(usage))
+ .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
),
- usage,
+ move |event| mapper.map_event(event),
))
});
- async move {
- let (stream, usage) = future.await?;
- Ok((stream.boxed(), usage))
- }
- .boxed()
+ async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::Google(model) => {
let client = self.client.clone();
let request = into_google(request, model.id().into());
let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream_with_usage(async move {
+ let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {
response,
usage,
- includes_queue_events,
+ includes_status_messages,
tool_use_limit_reached,
} = Self::perform_llm_completion(
client.clone(),
@@ -924,22 +902,16 @@ impl LanguageModel for CloudLanguageModel {
.await?;
let mut mapper = GoogleEventMapper::new();
- Ok((
- map_cloud_completion_events(
- Box::pin(
- response_lines(response, includes_queue_events)
- .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
- ),
- move |event| mapper.map_event(event),
+ Ok(map_cloud_completion_events(
+ Box::pin(
+ response_lines(response, includes_status_messages)
+ .chain(usage_updated_event(usage))
+ .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
),
- usage,
+ move |event| mapper.map_event(event),
))
});
- async move {
- let (stream, usage) = future.await?;
- Ok((stream.boxed(), usage))
- }
- .boxed()
+ async move { Ok(future.await?.boxed()) }.boxed()
}
}
}
@@ -948,7 +920,7 @@ impl LanguageModel for CloudLanguageModel {
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
- System(CompletionRequestStatus),
+ Status(CompletionRequestStatus),
Event(T),
}
@@ -968,8 +940,8 @@ where
Err(error) => {
vec![Err(LanguageModelCompletionError::Other(error))]
}
- Ok(CloudCompletionEvent::System(event)) => {
- vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
+ Ok(CloudCompletionEvent::Status(event)) => {
+ vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
}
Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
})
@@ -977,11 +949,24 @@ where
.boxed()
}
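+ /// Wraps header-derived usage in a synthetic `UsageUpdated` status event,
+ /// for responses from servers that don't stream status messages.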
+fn usage_updated_event<T>(
+ usage: Option<RequestUsage>,
+) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+ futures::stream::iter(usage.map(|usage| {
+ Ok(CloudCompletionEvent::Status(
+ CompletionRequestStatus::UsageUpdated {
+ amount: usage.amount as usize,
+ limit: usage.limit,
+ },
+ ))
+ }))
+}
+
fn tool_use_limit_reached_event<T>(
tool_use_limit_reached: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
futures::stream::iter(tool_use_limit_reached.then(|| {
- Ok(CloudCompletionEvent::System(
+ Ok(CloudCompletionEvent::Status(
CompletionRequestStatus::ToolUseLimitReached,
))
}))
@@ -989,7 +974,7 @@ fn tool_use_limit_reached_event<T>(
fn response_lines<T: DeserializeOwned>(
response: Response<AsyncBody>,
- includes_queue_events: bool,
+ includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
futures::stream::try_unfold(
(String::new(), BufReader::new(response.into_body())),
@@ -997,7 +982,7 @@ fn response_lines<T: DeserializeOwned>(
match body.read_line(&mut line).await {
Ok(0) => Ok(None),
Ok(_) => {
- let event = if includes_queue_events {
+ let event = if includes_status_messages {
serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
} else {
CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)