@@ -35,8 +35,8 @@ use ui::{TintColor, prelude::*};
use util::{ResultExt as _, maybe};
use zed_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
- CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
- ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
+ CompletionEvent, CompletionRequestStatus, CountTokensBody, CountTokensResponse,
+ EXPIRED_LLM_TOKEN_HEADER_NAME, ListModelsResponse, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
TOOL_USE_LIMIT_REACHED_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
@@ -1040,15 +1040,8 @@ impl LanguageModel for CloudLanguageModel {
}
}
-#[derive(Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub enum CloudCompletionEvent<T> {
- Status(CompletionRequestStatus),
- Event(T),
-}
-
fn map_cloud_completion_events<T, F>(
- stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
+ stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
@@ -1063,10 +1056,10 @@ where
Err(error) => {
vec![Err(LanguageModelCompletionError::from(error))]
}
- Ok(CloudCompletionEvent::Status(event)) => {
+ Ok(CompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
}
- Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
+ Ok(CompletionEvent::Event(event)) => map_callback(event),
})
})
.boxed()
@@ -1074,9 +1067,9 @@ where
fn usage_updated_event<T>(
usage: Option<ModelRequestUsage>,
-) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+) -> impl Stream<Item = Result<CompletionEvent<T>>> {
futures::stream::iter(usage.map(|usage| {
- Ok(CloudCompletionEvent::Status(
+ Ok(CompletionEvent::Status(
CompletionRequestStatus::UsageUpdated {
amount: usage.amount as usize,
limit: usage.limit,
@@ -1087,9 +1080,9 @@ fn usage_updated_event<T>(
fn tool_use_limit_reached_event<T>(
tool_use_limit_reached: bool,
-) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+) -> impl Stream<Item = Result<CompletionEvent<T>>> {
futures::stream::iter(tool_use_limit_reached.then(|| {
- Ok(CloudCompletionEvent::Status(
+ Ok(CompletionEvent::Status(
CompletionRequestStatus::ToolUseLimitReached,
))
}))
@@ -1098,7 +1091,7 @@ fn tool_use_limit_reached_event<T>(
fn response_lines<T: DeserializeOwned>(
response: Response<AsyncBody>,
includes_status_messages: bool,
-) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
+) -> impl Stream<Item = Result<CompletionEvent<T>>> {
futures::stream::try_unfold(
(String::new(), BufReader::new(response.into_body())),
move |(mut line, mut body)| async move {
@@ -1106,9 +1099,9 @@ fn response_lines<T: DeserializeOwned>(
Ok(0) => Ok(None),
Ok(_) => {
let event = if includes_status_messages {
- serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
+ serde_json::from_str::<CompletionEvent<T>>(&line)?
} else {
- CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
+ CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
};
line.clear();