Detailed changes
@@ -20139,9 +20139,9 @@ dependencies = [
[[package]]
name = "zed_llm_client"
-version = "0.8.4"
+version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203"
+checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
dependencies = [
"anyhow",
"serde",
@@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
wasmtime-wasi = "29"
which = "6.0.0"
workspace-hack = "0.1.0"
-zed_llm_client = "0.8.4"
+zed_llm_client = "0.8.5"
zstd = "0.11"
[workspace.dependencies.async-stripe]
@@ -23,11 +23,10 @@ use gpui::{
};
use language_model::{
ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
- LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
- ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
- TokenUsage,
+ LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+ LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
+ LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
+ Role, SelectedModel, StopReason, TokenUsage,
};
use postage::stream::Stream as _;
use project::{
@@ -1531,82 +1530,7 @@ impl Thread {
}
thread.update(cx, |thread, cx| {
- let event = match event {
- Ok(event) => event,
- Err(error) => {
- match error {
- LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
- anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
- }
- LanguageModelCompletionError::Overloaded => {
- anyhow::bail!(LanguageModelKnownError::Overloaded);
- }
- LanguageModelCompletionError::ApiInternalServerError =>{
- anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
- }
- LanguageModelCompletionError::PromptTooLarge { tokens } => {
- let tokens = tokens.unwrap_or_else(|| {
- // We didn't get an exact token count from the API, so fall back on our estimate.
- thread.total_token_usage()
- .map(|usage| usage.total)
- .unwrap_or(0)
- // We know the context window was exceeded in practice, so if our estimate was
- // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
- .max(model.max_token_count().saturating_add(1))
- });
-
- anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
- }
- LanguageModelCompletionError::ApiReadResponseError(io_error) => {
- anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
- }
- LanguageModelCompletionError::UnknownResponseFormat(error) => {
- anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
- }
- LanguageModelCompletionError::HttpResponseError { status, ref body } => {
- if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
- anyhow::bail!(known_error);
- } else {
- return Err(error.into());
- }
- }
- LanguageModelCompletionError::DeserializeResponse(error) => {
- anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
- }
- LanguageModelCompletionError::BadInputJson {
- id,
- tool_name,
- raw_input: invalid_input_json,
- json_parse_error,
- } => {
- thread.receive_invalid_tool_json(
- id,
- tool_name,
- invalid_input_json,
- json_parse_error,
- window,
- cx,
- );
- return Ok(());
- }
- // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
- err @ LanguageModelCompletionError::BadRequestFormat |
- err @ LanguageModelCompletionError::AuthenticationError |
- err @ LanguageModelCompletionError::PermissionError |
- err @ LanguageModelCompletionError::ApiEndpointNotFound |
- err @ LanguageModelCompletionError::SerializeRequest(_) |
- err @ LanguageModelCompletionError::BuildRequestBody(_) |
- err @ LanguageModelCompletionError::HttpSend(_) => {
- anyhow::bail!(err);
- }
- LanguageModelCompletionError::Other(error) => {
- return Err(error);
- }
- }
- }
- };
-
- match event {
+ match event? {
LanguageModelCompletionEvent::StartMessage { .. } => {
request_assistant_message_id =
Some(thread.insert_assistant_message(
@@ -1683,9 +1607,7 @@ impl Thread {
};
}
}
- LanguageModelCompletionEvent::RedactedThinking {
- data
- } => {
+ LanguageModelCompletionEvent::RedactedThinking { data } => {
thread.received_chunk();
if let Some(last_message) = thread.messages.last_mut() {
@@ -1734,6 +1656,21 @@ impl Thread {
});
}
}
+ LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id,
+ tool_name,
+ raw_input: invalid_input_json,
+ json_parse_error,
+ } => {
+ thread.receive_invalid_tool_json(
+ id,
+ tool_name,
+ invalid_input_json,
+ json_parse_error,
+ window,
+ cx,
+ );
+ }
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
if let Some(completion) = thread
.pending_completions
@@ -1741,23 +1678,34 @@ impl Thread {
.find(|completion| completion.id == pending_completion_id)
{
match status_update {
- CompletionRequestStatus::Queued {
- position,
- } => {
- completion.queue_state = QueueState::Queued { position };
+ CompletionRequestStatus::Queued { position } => {
+ completion.queue_state =
+ QueueState::Queued { position };
}
CompletionRequestStatus::Started => {
- completion.queue_state = QueueState::Started;
+ completion.queue_state = QueueState::Started;
}
CompletionRequestStatus::Failed {
- code, message, request_id
+ code,
+ message,
+ request_id: _,
+ retry_after,
} => {
- anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
+ return Err(
+ LanguageModelCompletionError::from_cloud_failure(
+ model.upstream_provider_name(),
+ code,
+ message,
+ retry_after.map(Duration::from_secs_f64),
+ ),
+ );
}
- CompletionRequestStatus::UsageUpdated {
- amount, limit
- } => {
- thread.update_model_request_usage(amount as u32, limit, cx);
+ CompletionRequestStatus::UsageUpdated { amount, limit } => {
+ thread.update_model_request_usage(
+ amount as u32,
+ limit,
+ cx,
+ );
}
CompletionRequestStatus::ToolUseLimitReached => {
thread.tool_use_limit_reached = true;
@@ -1808,10 +1756,11 @@ impl Thread {
Ok(stop_reason) => {
match stop_reason {
StopReason::ToolUse => {
- let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
+ let tool_uses =
+ thread.use_pending_tools(window, model.clone(), cx);
cx.emit(ThreadEvent::UsePendingTools { tool_uses });
}
- StopReason::EndTurn | StopReason::MaxTokens => {
+ StopReason::EndTurn | StopReason::MaxTokens => {
thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
});
@@ -1827,7 +1776,9 @@ impl Thread {
{
let mut messages_to_remove = Vec::new();
- for (ix, message) in thread.messages.iter().enumerate().rev() {
+ for (ix, message) in
+ thread.messages.iter().enumerate().rev()
+ {
messages_to_remove.push(message.id);
if message.role == Role::User {
@@ -1835,7 +1786,9 @@ impl Thread {
break;
}
- if let Some(prev_message) = thread.messages.get(ix - 1) {
+ if let Some(prev_message) =
+ thread.messages.get(ix - 1)
+ {
if prev_message.role == Role::Assistant {
break;
}
@@ -1850,14 +1803,16 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(ThreadError::Message {
header: "Language model refusal".into(),
- message: "Model refused to generate content for safety reasons.".into(),
+ message:
+ "Model refused to generate content for safety reasons."
+ .into(),
}));
}
}
// We successfully completed, so cancel any remaining retries.
thread.retry_state = None;
- },
+ }
Err(error) => {
thread.project.update(cx, |project, cx| {
project.set_agent_location(None, cx);
@@ -1883,26 +1838,38 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(
ThreadError::ModelRequestLimitReached { plan: error.plan },
));
- } else if let Some(known_error) =
- error.downcast_ref::<LanguageModelKnownError>()
+ } else if let Some(completion_error) =
+ error.downcast_ref::<LanguageModelCompletionError>()
{
- match known_error {
- LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
+ use LanguageModelCompletionError::*;
+ match &completion_error {
+ PromptTooLarge { tokens, .. } => {
+ let tokens = tokens.unwrap_or_else(|| {
+ // We didn't get an exact token count from the API, so fall back on our estimate.
+ thread
+ .total_token_usage()
+ .map(|usage| usage.total)
+ .unwrap_or(0)
+ // We know the context window was exceeded in practice, so if our estimate was
+ // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
+ .max(model.max_token_count().saturating_add(1))
+ });
thread.exceeded_window_error = Some(ExceededWindowError {
model_id: model.id(),
- token_count: *tokens,
+ token_count: tokens,
});
cx.notify();
}
- LanguageModelKnownError::RateLimitExceeded { retry_after } => {
- let provider_name = model.provider_name();
- let error_message = format!(
- "{}'s API rate limit exceeded",
- provider_name.0.as_ref()
- );
-
+ RateLimitExceeded {
+ retry_after: Some(retry_after),
+ ..
+ }
+ | ServerOverloaded {
+ retry_after: Some(retry_after),
+ ..
+ } => {
thread.handle_rate_limit_error(
- &error_message,
+ &completion_error,
*retry_after,
model.clone(),
intent,
@@ -1911,15 +1878,9 @@ impl Thread {
);
retry_scheduled = true;
}
- LanguageModelKnownError::Overloaded => {
- let provider_name = model.provider_name();
- let error_message = format!(
- "{}'s API servers are overloaded right now",
- provider_name.0.as_ref()
- );
-
+ RateLimitExceeded { .. } | ServerOverloaded { .. } => {
retry_scheduled = thread.handle_retryable_error(
- &error_message,
+ &completion_error,
model.clone(),
intent,
window,
@@ -1929,15 +1890,11 @@ impl Thread {
emit_generic_error(error, cx);
}
}
- LanguageModelKnownError::ApiInternalServerError => {
- let provider_name = model.provider_name();
- let error_message = format!(
- "{}'s API server reported an internal server error",
- provider_name.0.as_ref()
- );
-
+ ApiInternalServerError { .. }
+ | ApiReadResponseError { .. }
+ | HttpSend { .. } => {
retry_scheduled = thread.handle_retryable_error(
- &error_message,
+ &completion_error,
model.clone(),
intent,
window,
@@ -1947,12 +1904,16 @@ impl Thread {
emit_generic_error(error, cx);
}
}
- LanguageModelKnownError::ReadResponseError(_) |
- LanguageModelKnownError::DeserializeResponse(_) |
- LanguageModelKnownError::UnknownResponseFormat(_) => {
- // In the future we will attempt to re-roll response, but only once
- emit_generic_error(error, cx);
- }
+ NoApiKey { .. }
+ | HttpResponseError { .. }
+ | BadRequestFormat { .. }
+ | AuthenticationError { .. }
+ | PermissionError { .. }
+ | ApiEndpointNotFound { .. }
+ | SerializeRequest { .. }
+ | BuildRequestBody { .. }
+ | DeserializeResponse { .. }
+ | Other { .. } => emit_generic_error(error, cx),
}
} else {
emit_generic_error(error, cx);
@@ -2084,7 +2045,7 @@ impl Thread {
fn handle_rate_limit_error(
&mut self,
- error_message: &str,
+ error: &LanguageModelCompletionError,
retry_after: Duration,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
@@ -2092,9 +2053,10 @@ impl Thread {
cx: &mut Context<Self>,
) {
// For rate limit errors, we only retry once with the specified duration
- let retry_message = format!(
- "{error_message}. Retrying in {} secondsβ¦",
- retry_after.as_secs()
+ let retry_message = format!("{error}. Retrying in {} secondsβ¦", retry_after.as_secs());
+ log::warn!(
+ "Retrying completion request in {} seconds: {error:?}",
+ retry_after.as_secs(),
);
// Add a UI-only message instead of a regular message
@@ -2127,18 +2089,18 @@ impl Thread {
fn handle_retryable_error(
&mut self,
- error_message: &str,
+ error: &LanguageModelCompletionError,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
window: Option<AnyWindowHandle>,
cx: &mut Context<Self>,
) -> bool {
- self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
+ self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
}
fn handle_retryable_error_with_delay(
&mut self,
- error_message: &str,
+ error: &LanguageModelCompletionError,
custom_delay: Option<Duration>,
model: Arc<dyn LanguageModel>,
intent: CompletionIntent,
@@ -2168,8 +2130,12 @@ impl Thread {
// Add a transient message to inform the user
let delay_secs = delay.as_secs();
let retry_message = format!(
- "{}. Retrying (attempt {} of {}) in {} seconds...",
- error_message, attempt, max_attempts, delay_secs
+ "{error}. Retrying (attempt {attempt} of {max_attempts}) \
+ in {delay_secs} seconds..."
+ );
+ log::warn!(
+ "Retrying completion request (attempt {attempt} of {max_attempts}) \
+ in {delay_secs} seconds: {error:?}",
);
// Add a UI-only message instead of a regular message
@@ -4139,9 +4105,15 @@ fn main() {{
>,
> {
let error = match self.error_type {
- TestError::Overloaded => LanguageModelCompletionError::Overloaded,
+ TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
+ provider: self.provider_name(),
+ retry_after: None,
+ },
TestError::InternalServerError => {
- LanguageModelCompletionError::ApiInternalServerError
+ LanguageModelCompletionError::ApiInternalServerError {
+ provider: self.provider_name(),
+ message: "I'm a teapot orbiting the sun".to_string(),
+ }
}
};
async move {
@@ -4649,9 +4621,13 @@ fn main() {{
> {
if !*self.failed_once.lock() {
*self.failed_once.lock() = true;
+ let provider = self.provider_name();
// Return error on first attempt
let stream = futures::stream::once(async move {
- Err(LanguageModelCompletionError::Overloaded)
+ Err(LanguageModelCompletionError::ServerOverloaded {
+ provider,
+ retry_after: None,
+ })
});
async move { Ok(stream.boxed()) }.boxed()
} else {
@@ -4814,9 +4790,13 @@ fn main() {{
> {
if !*self.failed_once.lock() {
*self.failed_once.lock() = true;
+ let provider = self.provider_name();
// Return error on first attempt
let stream = futures::stream::once(async move {
- Err(LanguageModelCompletionError::Overloaded)
+ Err(LanguageModelCompletionError::ServerOverloaded {
+ provider,
+ retry_after: None,
+ })
});
async move { Ok(stream.boxed()) }.boxed()
} else {
@@ -4969,10 +4949,12 @@ fn main() {{
LanguageModelCompletionError,
>,
> {
+ let provider = self.provider_name();
async move {
let stream = futures::stream::once(async move {
Err(LanguageModelCompletionError::RateLimitExceeded {
- retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
+ provider,
+ retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
})
});
Ok(stream.boxed())
@@ -2025,9 +2025,7 @@ impl AgentPanel {
.thread()
.read(cx)
.configured_model()
- .map_or(false, |model| {
- model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
- });
+ .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID);
if !is_using_zed_provider {
return false;
@@ -1250,9 +1250,7 @@ impl MessageEditor {
self.thread
.read(cx)
.configured_model()
- .map_or(false, |model| {
- model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
- })
+ .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
}
fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {
@@ -6,7 +6,7 @@ use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::http::{self, HeaderMap, HeaderValue};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
use thiserror::Error;
@@ -356,7 +356,7 @@ pub async fn complete(
.send(request)
.await
.map_err(AnthropicError::HttpSend)?;
- let status = response.status();
+ let status_code = response.status();
let mut body = String::new();
response
.body_mut()
@@ -364,12 +364,12 @@ pub async fn complete(
.await
.map_err(AnthropicError::ReadResponse)?;
- if status.is_success() {
+ if status_code.is_success() {
Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
} else {
Err(AnthropicError::HttpResponseError {
- status: status.as_u16(),
- body,
+ status_code,
+ message: body,
})
}
}
@@ -444,11 +444,7 @@ impl RateLimitInfo {
}
Self {
- retry_after: headers
- .get("retry-after")
- .and_then(|v| v.to_str().ok())
- .and_then(|v| v.parse::<u64>().ok())
- .map(Duration::from_secs),
+ retry_after: parse_retry_after(headers),
requests: RateLimit::from_headers("requests", headers).ok(),
tokens: RateLimit::from_headers("tokens", headers).ok(),
input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
@@ -457,6 +453,17 @@ impl RateLimitInfo {
}
}
+/// Parses the Retry-After header value as an integer number of seconds (anthropic always uses
+/// seconds). Note that other services might specify an HTTP date or some other format for this
+/// header. Returns `None` if the header is not present or cannot be parsed.
+pub fn parse_retry_after(headers: &HeaderMap<HeaderValue>) -> Option<Duration> {
+ headers
+ .get("retry-after")
+ .and_then(|v| v.to_str().ok())
+ .and_then(|v| v.parse::<u64>().ok())
+ .map(Duration::from_secs)
+}
+
fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
Ok(headers
.get(key)
@@ -520,6 +527,10 @@ pub async fn stream_completion_with_rate_limit_info(
})
.boxed();
Ok((stream, Some(rate_limits)))
+ } else if response.status().as_u16() == 529 {
+ Err(AnthropicError::ServerOverloaded {
+ retry_after: rate_limits.retry_after,
+ })
} else if let Some(retry_after) = rate_limits.retry_after {
Err(AnthropicError::RateLimit { retry_after })
} else {
@@ -532,10 +543,9 @@ pub async fn stream_completion_with_rate_limit_info(
match serde_json::from_str::<Event>(&body) {
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
- Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
- Err(_) => Err(AnthropicError::HttpResponseError {
- status: response.status().as_u16(),
- body: body,
+ Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
+ status_code: response.status(),
+ message: body,
}),
}
}
@@ -801,16 +811,19 @@ pub enum AnthropicError {
ReadResponse(io::Error),
/// HTTP error response from the API
- HttpResponseError { status: u16, body: String },
+ HttpResponseError {
+ status_code: StatusCode,
+ message: String,
+ },
/// Rate limit exceeded
RateLimit { retry_after: Duration },
+ /// Server overloaded
+ ServerOverloaded { retry_after: Option<Duration> },
+
/// API returned an error response
ApiError(ApiError),
-
- /// Unexpected response format
- UnexpectedResponseFormat(String),
}
#[derive(Debug, Serialize, Deserialize, Error)]
@@ -2140,7 +2140,8 @@ impl AssistantContext {
);
}
LanguageModelCompletionEvent::ToolUse(_) |
- LanguageModelCompletionEvent::UsageUpdate(_) => {}
+ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
+ LanguageModelCompletionEvent::UsageUpdate(_) => {}
}
});
@@ -29,6 +29,7 @@ use std::{
path::Path,
str::FromStr,
sync::mpsc,
+ time::Duration,
};
use util::path;
@@ -1658,12 +1659,14 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
match request().await {
Ok(result) => return Ok(result),
Err(err) => match err.downcast::<LanguageModelCompletionError>() {
- Ok(err) => match err {
- LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
+ Ok(err) => match &err {
+ LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
+ | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
+ let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
// Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
eprintln!(
- "Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
+ "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
);
Timer::after(retry_after + jitter).await;
continue;
@@ -1054,6 +1054,15 @@ pub fn response_events_to_markdown(
| LanguageModelCompletionEvent::StartMessage { .. }
| LanguageModelCompletionEvent::StatusUpdate { .. },
) => {}
+ Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ json_parse_error, ..
+ }) => {
+ flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
+ response.push_str(&format!(
+ "**Error**: parse error in tool use JSON: {}\n\n",
+ json_parse_error
+ ));
+ }
Err(error) => {
flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
response.push_str(&format!("**Error**: {}\n\n", error));
@@ -1132,6 +1141,17 @@ impl ThreadDialog {
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
+ Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ json_parse_error,
+ ..
+ }) => {
+ flush_text(&mut current_text, &mut content);
+ content.push(MessageContent::Text(format!(
+ "ERROR: parse error in tool use JSON: {}",
+ json_parse_error
+ )));
+ }
+
Err(error) => {
flush_text(&mut current_text, &mut content);
content.push(MessageContent::Text(format!("ERROR: {}", error)));
@@ -9,17 +9,18 @@ mod telemetry;
pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long};
-use anyhow::Result;
+use anyhow::{Result, anyhow};
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
-use http_client::http;
+use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::ops::{Add, Sub};
+use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
@@ -34,11 +35,22 @@ pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
-pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
+pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
+ LanguageModelProviderId::new("anthropic");
+pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
+ LanguageModelProviderName::new("Anthropic");
-/// If we get a rate limit error that doesn't tell us when we can retry,
-/// default to waiting this long before retrying.
-const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
+pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
+pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
+ LanguageModelProviderName::new("Google AI");
+
+pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
+pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
+ LanguageModelProviderName::new("OpenAI");
+
+pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
+pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
+ LanguageModelProviderName::new("Zed");
pub fn init(client: Arc<Client>, cx: &mut App) {
init_settings(cx);
@@ -71,6 +83,12 @@ pub enum LanguageModelCompletionEvent {
data: String,
},
ToolUse(LanguageModelToolUse),
+ ToolUseJsonParseError {
+ id: LanguageModelToolUseId,
+ tool_name: Arc<str>,
+ raw_input: Arc<str>,
+ json_parse_error: String,
+ },
StartMessage {
message_id: String,
},
@@ -79,61 +97,179 @@ pub enum LanguageModelCompletionEvent {
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
- #[error("rate limit exceeded, retry after {retry_after:?}")]
- RateLimitExceeded { retry_after: Duration },
- #[error("received bad input JSON")]
- BadInputJson {
- id: LanguageModelToolUseId,
- tool_name: Arc<str>,
- raw_input: Arc<str>,
- json_parse_error: String,
+ #[error("prompt too large for context window")]
+ PromptTooLarge { tokens: Option<u64> },
+ #[error("missing {provider} API key")]
+ NoApiKey { provider: LanguageModelProviderName },
+ #[error("{provider}'s API rate limit exceeded")]
+ RateLimitExceeded {
+ provider: LanguageModelProviderName,
+ retry_after: Option<Duration>,
+ },
+ #[error("{provider}'s API servers are overloaded right now")]
+ ServerOverloaded {
+ provider: LanguageModelProviderName,
+ retry_after: Option<Duration>,
+ },
+ #[error("{provider}'s API server reported an internal server error: {message}")]
+ ApiInternalServerError {
+ provider: LanguageModelProviderName,
+ message: String,
+ },
+ #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
+ HttpResponseError {
+ provider: LanguageModelProviderName,
+ status_code: StatusCode,
+ message: String,
+ },
+
+ // Client errors
+ #[error("invalid request format to {provider}'s API: {message}")]
+ BadRequestFormat {
+ provider: LanguageModelProviderName,
+ message: String,
},
- #[error("language model provider's API is overloaded")]
- Overloaded,
+ #[error("authentication error with {provider}'s API: {message}")]
+ AuthenticationError {
+ provider: LanguageModelProviderName,
+ message: String,
+ },
+ #[error("permission error with {provider}'s API: {message}")]
+ PermissionError {
+ provider: LanguageModelProviderName,
+ message: String,
+ },
+ #[error("language model provider API endpoint not found")]
+ ApiEndpointNotFound { provider: LanguageModelProviderName },
+ #[error("I/O error reading response from {provider}'s API")]
+ ApiReadResponseError {
+ provider: LanguageModelProviderName,
+ #[source]
+ error: io::Error,
+ },
+ #[error("error serializing request to {provider} API")]
+ SerializeRequest {
+ provider: LanguageModelProviderName,
+ #[source]
+ error: serde_json::Error,
+ },
+ #[error("error building request body to {provider} API")]
+ BuildRequestBody {
+ provider: LanguageModelProviderName,
+ #[source]
+ error: http::Error,
+ },
+ #[error("error sending HTTP request to {provider} API")]
+ HttpSend {
+ provider: LanguageModelProviderName,
+ #[source]
+ error: anyhow::Error,
+ },
+ #[error("error deserializing {provider} API response")]
+ DeserializeResponse {
+ provider: LanguageModelProviderName,
+ #[source]
+ error: serde_json::Error,
+ },
+
+ // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
#[error(transparent)]
Other(#[from] anyhow::Error),
- #[error("invalid request format to language model provider's API")]
- BadRequestFormat,
- #[error("authentication error with language model provider's API")]
- AuthenticationError,
- #[error("permission error with language model provider's API")]
- PermissionError,
- #[error("language model provider API endpoint not found")]
- ApiEndpointNotFound,
- #[error("prompt too large for context window")]
- PromptTooLarge { tokens: Option<u64> },
- #[error("internal server error in language model provider's API")]
- ApiInternalServerError,
- #[error("I/O error reading response from language model provider's API: {0:?}")]
- ApiReadResponseError(io::Error),
- #[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
- HttpResponseError { status: u16, body: String },
- #[error("error serializing request to language model provider API: {0}")]
- SerializeRequest(serde_json::Error),
- #[error("error building request body to language model provider API: {0}")]
- BuildRequestBody(http::Error),
- #[error("error sending HTTP request to language model provider API: {0}")]
- HttpSend(anyhow::Error),
- #[error("error deserializing language model provider API response: {0}")]
- DeserializeResponse(serde_json::Error),
- #[error("unexpected language model provider API response format: {0}")]
- UnknownResponseFormat(String),
+}
+
+impl LanguageModelCompletionError {
+ pub fn from_cloud_failure(
+ upstream_provider: LanguageModelProviderName,
+ code: String,
+ message: String,
+ retry_after: Option<Duration>,
+ ) -> Self {
+ if let Some(tokens) = parse_prompt_too_long(&message) {
+ // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
+ // to be reported. This is a temporary workaround to handle this in the case where the
+ // token limit has been exceeded.
+ Self::PromptTooLarge {
+ tokens: Some(tokens),
+ }
+ } else if let Some(status_code) = code
+ .strip_prefix("upstream_http_")
+ .and_then(|code| StatusCode::from_str(code).ok())
+ {
+ Self::from_http_status(upstream_provider, status_code, message, retry_after)
+ } else if let Some(status_code) = code
+ .strip_prefix("http_")
+ .and_then(|code| StatusCode::from_str(code).ok())
+ {
+ Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
+ } else {
+ anyhow!("completion request failed, code: {code}, message: {message}").into()
+ }
+ }
+
+ pub fn from_http_status(
+ provider: LanguageModelProviderName,
+ status_code: StatusCode,
+ message: String,
+ retry_after: Option<Duration>,
+ ) -> Self {
+ match status_code {
+ StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
+ StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
+ StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
+ StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
+ StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
+ tokens: parse_prompt_too_long(&message),
+ },
+ StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
+ provider,
+ retry_after,
+ },
+ StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
+ StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
+ provider,
+ retry_after,
+ },
+ _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
+ provider,
+ retry_after,
+ },
+ _ => Self::HttpResponseError {
+ provider,
+ status_code,
+ message,
+ },
+ }
+ }
}
impl From<AnthropicError> for LanguageModelCompletionError {
fn from(error: AnthropicError) -> Self {
+ let provider = ANTHROPIC_PROVIDER_NAME;
match error {
- AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
- AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
- AnthropicError::HttpSend(error) => Self::HttpSend(error),
- AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
- AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
- AnthropicError::HttpResponseError { status, body } => {
- Self::HttpResponseError { status, body }
+ AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
+ AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
+ AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
+ AnthropicError::DeserializeResponse(error) => {
+ Self::DeserializeResponse { provider, error }
}
- AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
+ AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
+ AnthropicError::HttpResponseError {
+ status_code,
+ message,
+ } => Self::HttpResponseError {
+ provider,
+ status_code,
+ message,
+ },
+ AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
+ provider,
+ retry_after: Some(retry_after),
+ },
+ AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
+ provider,
+ retry_after: retry_after,
+ },
AnthropicError::ApiError(api_error) => api_error.into(),
- AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
}
}
}
@@ -141,23 +277,39 @@ impl From<AnthropicError> for LanguageModelCompletionError {
impl From<anthropic::ApiError> for LanguageModelCompletionError {
fn from(error: anthropic::ApiError) -> Self {
use anthropic::ApiErrorCode::*;
-
+ let provider = ANTHROPIC_PROVIDER_NAME;
match error.code() {
Some(code) => match code {
- InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
- AuthenticationError => LanguageModelCompletionError::AuthenticationError,
- PermissionError => LanguageModelCompletionError::PermissionError,
- NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
- RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
+ InvalidRequestError => Self::BadRequestFormat {
+ provider,
+ message: error.message,
+ },
+ AuthenticationError => Self::AuthenticationError {
+ provider,
+ message: error.message,
+ },
+ PermissionError => Self::PermissionError {
+ provider,
+ message: error.message,
+ },
+ NotFoundError => Self::ApiEndpointNotFound { provider },
+ RequestTooLarge => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&error.message),
},
- RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
- retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
+ RateLimitError => Self::RateLimitExceeded {
+ provider,
+ retry_after: None,
+ },
+ ApiError => Self::ApiInternalServerError {
+ provider,
+ message: error.message,
+ },
+ OverloadedError => Self::ServerOverloaded {
+ provider,
+ retry_after: None,
},
- ApiError => LanguageModelCompletionError::ApiInternalServerError,
- OverloadedError => LanguageModelCompletionError::Overloaded,
},
- None => LanguageModelCompletionError::Other(error.into()),
+ None => Self::Other(error.into()),
}
}
}
@@ -278,6 +430,13 @@ pub trait LanguageModel: Send + Sync {
fn name(&self) -> LanguageModelName;
fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName;
+ fn upstream_provider_id(&self) -> LanguageModelProviderId {
+ self.provider_id()
+ }
+ fn upstream_provider_name(&self) -> LanguageModelProviderName {
+ self.provider_name()
+ }
+
fn telemetry_id(&self) -> String;
fn api_key(&self, _cx: &App) -> Option<String> {
@@ -365,6 +524,9 @@ pub trait LanguageModel: Send + Sync {
Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
Ok(LanguageModelCompletionEvent::Stop(_)) => None,
Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+ Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ ..
+ }) => None,
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
*last_token_usage.lock() = token_usage;
None
@@ -395,39 +557,6 @@ pub trait LanguageModel: Send + Sync {
}
}
-#[derive(Debug, Error)]
-pub enum LanguageModelKnownError {
- #[error("Context window limit exceeded ({tokens})")]
- ContextWindowLimitExceeded { tokens: u64 },
- #[error("Language model provider's API is currently overloaded")]
- Overloaded,
- #[error("Language model provider's API encountered an internal server error")]
- ApiInternalServerError,
- #[error("I/O error while reading response from language model provider's API: {0:?}")]
- ReadResponseError(io::Error),
- #[error("Error deserializing response from language model provider's API: {0:?}")]
- DeserializeResponse(serde_json::Error),
- #[error("Language model provider's API returned a response in an unknown format")]
- UnknownResponseFormat(String),
- #[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
- RateLimitExceeded { retry_after: Duration },
-}
-
-impl LanguageModelKnownError {
- /// Attempts to map an HTTP response status code to a known error type.
- /// Returns None if the status code doesn't map to a specific known error.
- pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
- match status {
- 429 => Some(Self::RateLimitExceeded {
- retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
- }),
- 503 => Some(Self::Overloaded),
- 500..=599 => Some(Self::ApiInternalServerError),
- _ => None,
- }
- }
-}
-
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String;
fn description() -> String;
@@ -509,12 +638,30 @@ pub struct LanguageModelProviderId(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
+impl LanguageModelProviderId {
+ pub const fn new(id: &'static str) -> Self {
+ Self(SharedString::new_static(id))
+ }
+}
+
+impl LanguageModelProviderName {
+ pub const fn new(id: &'static str) -> Self {
+ Self(SharedString::new_static(id))
+ }
+}
+
impl fmt::Display for LanguageModelProviderId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
+impl fmt::Display for LanguageModelProviderName {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{}", self.0)
+ }
+}
+
impl From<String> for LanguageModelId {
fn from(value: String) -> Self {
Self(SharedString::from(value))
@@ -98,7 +98,7 @@ impl ConfiguredModel {
}
pub fn is_provided_by_zed(&self) -> bool {
- self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
+ self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
}
}
@@ -1,3 +1,4 @@
+use crate::ANTHROPIC_PROVIDER_ID;
use anthropic::ANTHROPIC_API_URL;
use anyhow::{Context as _, anyhow};
use client::telemetry::Telemetry;
@@ -8,8 +9,6 @@ use std::sync::Arc;
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use util::ResultExt;
-pub const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
-
pub fn report_assistant_event(
event: AssistantEventData,
telemetry: Option<Arc<Telemetry>>,
@@ -19,7 +18,7 @@ pub fn report_assistant_event(
) {
if let Some(telemetry) = telemetry.as_ref() {
telemetry.report_assistant_event(event.clone());
- if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID {
+ if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID.0 {
if let Some(api_key) = model_api_key {
executor
.spawn(async move {
@@ -33,8 +33,8 @@ use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
-const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
-const PROVIDER_NAME: &str = "Anthropic";
+const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
@@ -218,11 +218,11 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
impl LanguageModelProvider for AnthropicLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -403,7 +403,11 @@ impl AnthropicModel {
};
async move {
- let api_key = api_key.context("Missing Anthropic API Key")?;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
let request =
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
request.await.map_err(Into::into)
@@ -422,11 +426,11 @@ impl LanguageModel for AnthropicModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -806,12 +810,14 @@ impl AnthropicEventMapper {
raw_input: tool_use.input_json.clone(),
},
)),
- Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
- id: tool_use.id.into(),
- tool_name: tool_use.name.into(),
- raw_input: input_json.into(),
- json_parse_error: json_parse_err.to_string(),
- }),
+ Err(json_parse_err) => {
+ Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_use.id.into(),
+ tool_name: tool_use.name.into(),
+ raw_input: input_json.into(),
+ json_parse_error: json_parse_err.to_string(),
+ })
+ }
};
vec![event_result]
@@ -52,8 +52,8 @@ use util::ResultExt;
use crate::AllLanguageModelSettings;
-const PROVIDER_ID: &str = "amazon-bedrock";
-const PROVIDER_NAME: &str = "Amazon Bedrock";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("amazon-bedrock");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Amazon Bedrock");
#[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
pub struct BedrockCredentials {
@@ -285,11 +285,11 @@ impl BedrockLanguageModelProvider {
impl LanguageModelProvider for BedrockLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -489,11 +489,11 @@ impl LanguageModel for BedrockModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -1,4 +1,4 @@
-use anthropic::{AnthropicModelMode, parse_prompt_too_long};
+use anthropic::AnthropicModelMode;
use anyhow::{Context as _, Result, anyhow};
use client::{Client, ModelRequestUsage, UserStore, zed_urls};
use futures::{
@@ -8,25 +8,21 @@ use google_ai::GoogleModelMode;
use gpui::{
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
+use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
- LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
- LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
- LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
- ZED_CLOUD_PROVIDER_ID,
-};
-use language_model::{
- LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
- RefreshLlmTokenListener,
+ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
+ LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
+ ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
};
use proto::Plan;
use release_channel::AppVersion;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use settings::SettingsStore;
-use smol::Timer;
use smol::io::{AsyncReadExt, BufReader};
use std::pin::Pin;
use std::str::FromStr as _;
@@ -47,7 +43,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
use crate::provider::google::{GoogleEventMapper, into_google};
use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
-pub const PROVIDER_NAME: &str = "Zed";
+const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
@@ -351,11 +348,11 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
impl LanguageModelProvider for CloudLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -536,8 +533,6 @@ struct PerformLlmCompletionResponse {
}
impl CloudLanguageModel {
- const MAX_RETRIES: usize = 3;
-
async fn perform_llm_completion(
client: Arc<Client>,
llm_api_token: LlmApiToken,
@@ -547,8 +542,7 @@ impl CloudLanguageModel {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
- let mut retries_remaining = Self::MAX_RETRIES;
- let mut retry_delay = Duration::from_secs(1);
+ let mut refreshed_token = false;
loop {
let request_builder = http_client::Request::builder()
@@ -590,14 +584,20 @@ impl CloudLanguageModel {
includes_status_messages,
tool_use_limit_reached,
});
- } else if response
- .headers()
- .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
- .is_some()
+ }
+
+ if !refreshed_token
+ && response
+ .headers()
+ .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+ .is_some()
{
- retries_remaining -= 1;
token = llm_api_token.refresh(&client).await?;
- } else if status == StatusCode::FORBIDDEN
+ refreshed_token = true;
+ continue;
+ }
+
+ if status == StatusCode::FORBIDDEN
&& response
.headers()
.get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
@@ -622,35 +622,18 @@ impl CloudLanguageModel {
return Err(anyhow!(ModelRequestLimitReachedError { plan }));
}
}
-
- anyhow::bail!("Forbidden");
- } else if status.as_u16() >= 500 && status.as_u16() < 600 {
- // If we encounter an error in the 500 range, retry after a delay.
- // We've seen at least these in the wild from API providers:
- // * 500 Internal Server Error
- // * 502 Bad Gateway
- // * 529 Service Overloaded
-
- if retries_remaining == 0 {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- anyhow::bail!(
- "cloud language model completion failed after {} retries with status {status}: {body}",
- Self::MAX_RETRIES
- );
- }
-
- Timer::after(retry_delay).await;
-
- retries_remaining -= 1;
- retry_delay *= 2; // If it fails again, wait longer.
} else if status == StatusCode::PAYMENT_REQUIRED {
return Err(anyhow!(PaymentRequiredError));
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- return Err(anyhow!(ApiError { status, body }));
}
+
+ let mut body = String::new();
+ let headers = response.headers().clone();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(ApiError {
+ status,
+ body,
+ headers
+ }));
}
}
}
@@ -660,6 +643,19 @@ impl CloudLanguageModel {
struct ApiError {
status: StatusCode,
body: String,
+ headers: HeaderMap<HeaderValue>,
+}
+
+impl From<ApiError> for LanguageModelCompletionError {
+ fn from(error: ApiError) -> Self {
+ let retry_after = None;
+ LanguageModelCompletionError::from_http_status(
+ PROVIDER_NAME,
+ error.status,
+ error.body,
+ retry_after,
+ )
+ }
}
impl LanguageModel for CloudLanguageModel {
@@ -672,11 +668,29 @@ impl LanguageModel for CloudLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
+ }
+
+ fn upstream_provider_id(&self) -> LanguageModelProviderId {
+ use zed_llm_client::LanguageModelProvider::*;
+ match self.model.provider {
+ Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
+ OpenAi => language_model::OPEN_AI_PROVIDER_ID,
+ Google => language_model::GOOGLE_PROVIDER_ID,
+ }
+ }
+
+ fn upstream_provider_name(&self) -> LanguageModelProviderName {
+ use zed_llm_client::LanguageModelProvider::*;
+ match self.model.provider {
+ Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
+ OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
+ Google => language_model::GOOGLE_PROVIDER_NAME,
+ }
}
fn supports_tools(&self) -> bool {
@@ -776,6 +790,7 @@ impl LanguageModel for CloudLanguageModel {
.body(serde_json::to_string(&request_body)?.into())?;
let mut response = http_client.send(request).await?;
let status = response.status();
+ let headers = response.headers().clone();
let mut response_body = String::new();
response
.body_mut()
@@ -790,7 +805,8 @@ impl LanguageModel for CloudLanguageModel {
} else {
Err(anyhow!(ApiError {
status,
- body: response_body
+ body: response_body,
+ headers
}))
}
}
@@ -855,18 +871,7 @@ impl LanguageModel for CloudLanguageModel {
)
.await
.map_err(|err| match err.downcast::<ApiError>() {
- Ok(api_err) => {
- if api_err.status == StatusCode::BAD_REQUEST {
- if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
- return anyhow!(
- LanguageModelKnownError::ContextWindowLimitExceeded {
- tokens
- }
- );
- }
- }
- anyhow!(api_err)
- }
+ Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
Err(err) => anyhow!(err),
})?;
@@ -995,7 +1000,7 @@ where
.flat_map(move |event| {
futures::stream::iter(match event {
Err(error) => {
- vec![Err(LanguageModelCompletionError::Other(error))]
+ vec![Err(LanguageModelCompletionError::from(error))]
}
Ok(CloudCompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
@@ -35,8 +35,9 @@ use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
use super::open_ai::count_open_ai_tokens;
-const PROVIDER_ID: &str = "copilot_chat";
-const PROVIDER_NAME: &str = "GitHub Copilot Chat";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
+const PROVIDER_NAME: LanguageModelProviderName =
+ LanguageModelProviderName::new("GitHub Copilot Chat");
pub struct CopilotChatLanguageModelProvider {
state: Entity<State>,
@@ -102,11 +103,11 @@ impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
impl LanguageModelProvider for CopilotChatLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -201,11 +202,11 @@ impl LanguageModel for CopilotChatLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -391,24 +392,24 @@ pub fn map_to_language_model_completion_events(
serde_json::Value::from_str(&tool_call.arguments)
};
match arguments {
- Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
- LanguageModelToolUse {
- id: tool_call.id.clone().into(),
- name: tool_call.name.as_str().into(),
- is_input_complete: true,
- input,
- raw_input: tool_call.arguments.clone(),
- },
- )),
- Err(error) => {
- Err(LanguageModelCompletionError::BadInputJson {
- id: tool_call.id.into(),
- tool_name: tool_call.name.as_str().into(),
- raw_input: tool_call.arguments.into(),
- json_parse_error: error.to_string(),
- })
- }
- }
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.clone().into(),
+ name: tool_call.name.as_str().into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments.clone(),
+ },
+ )),
+ Err(error) => Ok(
+ LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.as_str().into(),
+ raw_input: tool_call.arguments.into(),
+ json_parse_error: error.to_string(),
+ },
+ ),
+ }
},
));
@@ -28,8 +28,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
-const PROVIDER_ID: &str = "deepseek";
-const PROVIDER_NAME: &str = "DeepSeek";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
#[derive(Default)]
@@ -174,11 +174,11 @@ impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
impl LanguageModelProvider for DeepSeekLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -283,11 +283,11 @@ impl LanguageModel for DeepSeekLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -466,7 +466,7 @@ impl DeepSeekEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
- Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
})
})
}
@@ -476,7 +476,7 @@ impl DeepSeekEventMapper {
event: deepseek::StreamResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else {
- return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@@ -538,8 +538,8 @@ impl DeepSeekEventMapper {
raw_input: tool_call.arguments.clone(),
},
)),
- Err(error) => Err(LanguageModelCompletionError::BadInputJson {
- id: tool_call.id.into(),
+ Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_call.id.clone().into(),
tool_name: tool_call.name.as_str().into(),
raw_input: tool_call.arguments.into(),
json_parse_error: error.to_string(),
@@ -37,8 +37,8 @@ use util::ResultExt;
use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;
-const PROVIDER_ID: &str = "google";
-const PROVIDER_NAME: &str = "Google AI";
+const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
@@ -207,11 +207,11 @@ impl LanguageModelProviderState for GoogleLanguageModelProvider {
impl LanguageModelProvider for GoogleLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -334,11 +334,11 @@ impl LanguageModel for GoogleLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -423,9 +423,7 @@ impl LanguageModel for GoogleLanguageModel {
);
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
- let response = request
- .await
- .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
+ let response = request.await.map_err(LanguageModelCompletionError::from)?;
Ok(GoogleEventMapper::new().map_stream(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
@@ -622,7 +620,7 @@ impl GoogleEventMapper {
futures::stream::iter(match event {
Some(Ok(event)) => self.map_event(event),
Some(Err(error)) => {
- vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))]
+ vec![Err(LanguageModelCompletionError::from(error))]
}
None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
})
@@ -31,8 +31,8 @@ const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
-const PROVIDER_ID: &str = "lmstudio";
-const PROVIDER_NAME: &str = "LM Studio";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
@@ -156,11 +156,11 @@ impl LanguageModelProviderState for LmStudioLanguageModelProvider {
impl LanguageModelProvider for LmStudioLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -386,11 +386,11 @@ impl LanguageModel for LmStudioLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -474,7 +474,7 @@ impl LmStudioEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
- Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
})
})
}
@@ -484,7 +484,7 @@ impl LmStudioEventMapper {
event: lmstudio::ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.into_iter().next() else {
- return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@@ -553,7 +553,7 @@ impl LmStudioEventMapper {
raw_input: tool_call.arguments,
},
)),
- Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+ Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(),
tool_name: tool_call.name.into(),
raw_input: tool_call.arguments.into(),
@@ -2,8 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
-use futures::stream::BoxStream;
-use futures::{FutureExt, StreamExt, future::BoxFuture};
+use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
@@ -15,6 +14,7 @@ use language_model::{
LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
RateLimiter, Role, StopReason, TokenUsage,
};
+use mistral::StreamResponse;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
@@ -29,8 +29,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
-const PROVIDER_ID: &str = "mistral";
-const PROVIDER_NAME: &str = "Mistral";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
#[derive(Default, Clone, Debug, PartialEq)]
pub struct MistralSettings {
@@ -171,11 +171,11 @@ impl LanguageModelProviderState for MistralLanguageModelProvider {
impl LanguageModelProvider for MistralLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -298,11 +298,11 @@ impl LanguageModel for MistralLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -579,13 +579,13 @@ impl MistralEventMapper {
pub fn map_stream(
mut self,
- events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
- ) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
- Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
})
})
}
@@ -595,7 +595,7 @@ impl MistralEventMapper {
event: mistral::StreamResponse,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else {
- return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@@ -660,7 +660,7 @@ impl MistralEventMapper {
for (_, tool_call) in self.tool_calls_by_index.drain() {
if tool_call.id.is_empty() || tool_call.name.is_empty() {
- results.push(Err(LanguageModelCompletionError::Other(anyhow!(
+ results.push(Err(LanguageModelCompletionError::from(anyhow!(
"Received incomplete tool call: missing id or name"
))));
continue;
@@ -676,12 +676,14 @@ impl MistralEventMapper {
raw_input: tool_call.arguments,
},
))),
- Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
- id: tool_call.id.into(),
- tool_name: tool_call.name.into(),
- raw_input: tool_call.arguments.into(),
- json_parse_error: error.to_string(),
- })),
+ Err(error) => {
+ results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.into(),
+ raw_input: tool_call.arguments.into(),
+ json_parse_error: error.to_string(),
+ }))
+ }
}
}
@@ -30,8 +30,8 @@ const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";
-const PROVIDER_ID: &str = "ollama";
-const PROVIDER_NAME: &str = "Ollama";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
@@ -181,11 +181,11 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
impl LanguageModelProvider for OllamaLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -350,11 +350,11 @@ impl LanguageModel for OllamaLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -453,7 +453,7 @@ fn map_to_language_model_completion_events(
let delta = match response {
Ok(delta) => delta,
Err(e) => {
- let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
+ let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
return Some((vec![event], state));
}
};
@@ -31,8 +31,8 @@ use util::ResultExt;
use crate::OpenAiSettingsContent;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
-const PROVIDER_ID: &str = "openai";
-const PROVIDER_NAME: &str = "OpenAI";
+const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
@@ -173,11 +173,11 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
impl LanguageModelProvider for OpenAiLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -267,7 +267,11 @@ impl OpenAiLanguageModel {
};
let future = self.request_limiter.stream(async move {
- let api_key = api_key.context("Missing OpenAI API Key")?;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
Ok(response)
@@ -287,11 +291,11 @@ impl LanguageModel for OpenAiLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -525,7 +529,7 @@ impl OpenAiEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
- Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
})
})
}
@@ -588,10 +592,10 @@ impl OpenAiEventMapper {
raw_input: tool_call.arguments.clone(),
},
)),
- Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+ Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
id: tool_call.id.into(),
- tool_name: tool_call.name.as_str().into(),
- raw_input: tool_call.arguments.into(),
+ tool_name: tool_call.name.into(),
+ raw_input: tool_call.arguments.clone().into(),
json_parse_error: error.to_string(),
}),
}
@@ -29,8 +29,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
-const PROVIDER_ID: &str = "openrouter";
-const PROVIDER_NAME: &str = "OpenRouter";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenRouterSettings {
@@ -244,11 +244,11 @@ impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
impl LanguageModelProvider for OpenRouterLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -363,11 +363,11 @@ impl LanguageModel for OpenRouterLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -607,7 +607,7 @@ impl OpenRouterEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
- Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
})
})
}
@@ -617,7 +617,7 @@ impl OpenRouterEventMapper {
event: ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.first() else {
- return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ return vec![Err(LanguageModelCompletionError::from(anyhow!(
"Response contained no choices"
)))];
};
@@ -683,10 +683,10 @@ impl OpenRouterEventMapper {
raw_input: tool_call.arguments.clone(),
},
)),
- Err(error) => Err(LanguageModelCompletionError::BadInputJson {
- id: tool_call.id.into(),
+ Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ id: tool_call.id.clone().into(),
tool_name: tool_call.name.as_str().into(),
- raw_input: tool_call.arguments.into(),
+ raw_input: tool_call.arguments.clone().into(),
json_parse_error: error.to_string(),
}),
}
@@ -25,8 +25,8 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
-const PROVIDER_ID: &str = "vercel";
-const PROVIDER_NAME: &str = "Vercel";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelSettings {
@@ -172,11 +172,11 @@ impl LanguageModelProviderState for VercelLanguageModelProvider {
impl LanguageModelProvider for VercelLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn icon(&self) -> IconName {
@@ -269,7 +269,11 @@ impl VercelLanguageModel {
};
let future = self.request_limiter.stream(async move {
- let api_key = api_key.context("Missing Vercel API Key")?;
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
let request =
open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
let response = request.await?;
@@ -290,11 +294,11 @@ impl LanguageModel for VercelLanguageModel {
}
fn provider_id(&self) -> LanguageModelProviderId {
- LanguageModelProviderId(PROVIDER_ID.into())
+ PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
+ PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
@@ -7,10 +7,7 @@ use gpui::{App, AppContext, Context, Entity, Subscription, Task};
use http_client::{HttpClient, Method};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use web_search::{WebSearchProvider, WebSearchProviderId};
-use zed_llm_client::{
- CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
- WebSearchBody, WebSearchResponse,
-};
+use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
pub struct CloudWebSearchProvider {
state: Entity<State>,
@@ -92,7 +89,6 @@ async fn perform_web_search(
.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
- .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client
.send(request)