agent: Improve error handling and retry for zed-provided models (#33565)

Created by Michael Sloan, Richard Feldman, and Richard

* Updates to `zed_llm_client-0.8.5`, which adds support for `retry_after`
when Anthropic provides it; the header parsing involved is sketched after this list.

* Distinguishes upstream provider errors and rate limits from errors
that originate from Zed's servers (the code-prefix dispatch is sketched below).

* Moves `LanguageModelCompletionError::BadInputJson` to
`LanguageModelCompletionEvent::ToolUseJsonParseError`. While this is
arguably an error case, the logic in `Thread` is cleaner with this move.
There is also precedent for including errors in the event type:
`CompletionRequestStatus::Failed` is how cloud errors arrive. A trimmed
sketch of the new variant follows this list.

* Updates `PROVIDER_ID` / `PROVIDER_NAME` constants to use proper types
instead of `&str`, since they can be constructed in a `const` fashion (see the sketch below).

* Removes use of `CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME`
as the server no longer reads this header and just defaults to that
behavior.
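
A minimal sketch of the `retry-after` handling from the first bullet, mirroring the `parse_retry_after` helper added in `crates/anthropic/src/anthropic.rs` below (the bare `http` crate types stand in for Zed's `http_client` re-exports):

```rust
use std::time::Duration;

use http::{HeaderMap, HeaderValue};

/// Parse `Retry-After` as a whole number of seconds (the format Anthropic
/// uses). Other services may send an HTTP date instead, which this ignores.
fn parse_retry_after(headers: &HeaderMap<HeaderValue>) -> Option<Duration> {
    headers
        .get("retry-after")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.parse::<u64>().ok())
        .map(Duration::from_secs)
}
```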
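
For the upstream-versus-Zed distinction, cloud failures arrive with a string code, and the prefix decides which party gets blamed. A condensed, self-contained sketch of the dispatch in the new `from_cloud_failure` (the returned labels are illustrative):

```rust
use std::str::FromStr as _;

use http::StatusCode;

/// Codes prefixed `upstream_http_*` are attributed to the upstream provider;
/// plain `http_*` codes are attributed to Zed's own servers.
fn attribute_failure(code: &str) -> Option<(&'static str, StatusCode)> {
    if let Some(status) = code
        .strip_prefix("upstream_http_")
        .and_then(|rest| StatusCode::from_str(rest).ok())
    {
        Some(("upstream provider", status))
    } else if let Some(status) = code
        .strip_prefix("http_")
        .and_then(|rest| StatusCode::from_str(rest).ok())
    {
        Some(("zed.dev", status))
    } else {
        None
    }
}
```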
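
For the `ToolUseJsonParseError` move, here is a trimmed-down mirror of the event type showing why the thread logic gets cleaner: malformed tool input becomes an event that a consumer matches on, not an error that tears down the stream (`id` and the other real variants are elided):

```rust
use std::sync::Arc;

// Trimmed-down mirror of `LanguageModelCompletionEvent`; the real enum has
// many more variants.
enum CompletionEvent {
    Text(String),
    ToolUseJsonParseError {
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
}

fn handle(event: CompletionEvent) {
    match event {
        CompletionEvent::Text(text) => print!("{text}"),
        // Malformed tool input is just another event: the stream keeps going
        // and the parse error can be reported back to the model.
        CompletionEvent::ToolUseJsonParseError {
            tool_name,
            raw_input,
            json_parse_error,
        } => eprintln!("tool `{tool_name}`: bad input JSON ({json_parse_error}): {raw_input}"),
    }
}
```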
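
And for the typed `PROVIDER_ID` / `PROVIDER_NAME` constants, the pattern reduces to a `const fn` constructor on the wrapper type; a simplified sketch, with `&'static str` standing in for gpui's `SharedString` (whose `new_static` constructor is `const`):

```rust
// The real type wraps gpui's `SharedString`; a `&'static str` stands in here.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LanguageModelProviderId(pub &'static str);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(id)
    }
}

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("zed.dev");
```

With `PartialEq` on the wrapper, call sites can compare `model.provider.id() == ZED_CLOUD_PROVIDER_ID` directly instead of reaching into `.0`.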

Release notes for this are covered by #33275

Release Notes:

- N/A

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Co-authored-by: Richard <richard@zed.dev>

Change summary

Cargo.lock                                          |   4 
Cargo.toml                                          |   2 
crates/agent/src/thread.rs                          | 282 +++++------
crates/agent_ui/src/agent_panel.rs                  |   4 
crates/agent_ui/src/message_editor.rs               |   4 
crates/anthropic/src/anthropic.rs                   |  49 +
crates/assistant_context/src/assistant_context.rs   |   3 
crates/assistant_tools/src/edit_agent/evals.rs      |   9 
crates/eval/src/instance.rs                         |  20 
crates/language_model/src/language_model.rs         | 337 ++++++++++----
crates/language_model/src/registry.rs               |   2 
crates/language_model/src/telemetry.rs              |   5 
crates/language_models/src/provider/anthropic.rs    |  32 
crates/language_models/src/provider/bedrock.rs      |  12 
crates/language_models/src/provider/cloud.rs        | 137 +++--
crates/language_models/src/provider/copilot_chat.rs |  49 +-
crates/language_models/src/provider/deepseek.rs     |  20 
crates/language_models/src/provider/google.rs       |  18 
crates/language_models/src/provider/lmstudio.rs     |  18 
crates/language_models/src/provider/mistral.rs      |  40 
crates/language_models/src/provider/ollama.rs       |  14 
crates/language_models/src/provider/open_ai.rs      |  26 
crates/language_models/src/provider/open_router.rs  |  22 
crates/language_models/src/provider/vercel.rs       |  18 
crates/web_search_providers/src/cloud.rs            |   6 
25 files changed, 655 insertions(+), 478 deletions(-)

Detailed changes

Cargo.lock

@@ -20139,9 +20139,9 @@ dependencies = [
 
 [[package]]
 name = "zed_llm_client"
-version = "0.8.4"
+version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "de7d9523255f4e00ee3d0918e5407bd252d798a4a8e71f6d37f23317a1588203"
+checksum = "c740e29260b8797ad252c202ea09a255b3cbc13f30faaf92fb6b2490336106e0"
 dependencies = [
  "anyhow",
  "serde",

Cargo.toml

@@ -625,7 +625,7 @@ wasmtime = { version = "29", default-features = false, features = [
 wasmtime-wasi = "29"
 which = "6.0.0"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.8.4"
+zed_llm_client = "0.8.5"
 zstd = "0.11"
 
 [workspace.dependencies.async-stripe]

crates/agent/src/thread.rs

@@ -23,11 +23,10 @@ use gpui::{
 };
 use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
-    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
-    LanguageModelToolResultContent, LanguageModelToolUseId, MessageContent,
-    ModelRequestLimitReachedError, PaymentRequiredError, Role, SelectedModel, StopReason,
-    TokenUsage,
+    LanguageModelId, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
+    LanguageModelToolUseId, MessageContent, ModelRequestLimitReachedError, PaymentRequiredError,
+    Role, SelectedModel, StopReason, TokenUsage,
 };
 use postage::stream::Stream as _;
 use project::{
@@ -1531,82 +1530,7 @@ impl Thread {
                     }
 
                     thread.update(cx, |thread, cx| {
-                        let event = match event {
-                            Ok(event) => event,
-                            Err(error) => {
-                                match error {
-                                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
-                                        anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
-                                    }
-                                    LanguageModelCompletionError::Overloaded => {
-                                        anyhow::bail!(LanguageModelKnownError::Overloaded);
-                                    }
-                                    LanguageModelCompletionError::ApiInternalServerError =>{
-                                        anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
-                                    }
-                                    LanguageModelCompletionError::PromptTooLarge { tokens } => {
-                                        let tokens = tokens.unwrap_or_else(|| {
-                                            // We didn't get an exact token count from the API, so fall back on our estimate.
-                                            thread.total_token_usage()
-                                                .map(|usage| usage.total)
-                                                .unwrap_or(0)
-                                                // We know the context window was exceeded in practice, so if our estimate was
-                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
-                                                .max(model.max_token_count().saturating_add(1))
-                                        });
-
-                                        anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
-                                    }
-                                    LanguageModelCompletionError::ApiReadResponseError(io_error) => {
-                                        anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
-                                    }
-                                    LanguageModelCompletionError::UnknownResponseFormat(error) => {
-                                        anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
-                                    }
-                                    LanguageModelCompletionError::HttpResponseError { status, ref body } => {
-                                        if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
-                                            anyhow::bail!(known_error);
-                                        } else {
-                                            return Err(error.into());
-                                        }
-                                    }
-                                    LanguageModelCompletionError::DeserializeResponse(error) => {
-                                        anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
-                                    }
-                                    LanguageModelCompletionError::BadInputJson {
-                                        id,
-                                        tool_name,
-                                        raw_input: invalid_input_json,
-                                        json_parse_error,
-                                    } => {
-                                        thread.receive_invalid_tool_json(
-                                            id,
-                                            tool_name,
-                                            invalid_input_json,
-                                            json_parse_error,
-                                            window,
-                                            cx,
-                                        );
-                                        return Ok(());
-                                    }
-                                    // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
-                                    err @ LanguageModelCompletionError::BadRequestFormat |
-                                    err @ LanguageModelCompletionError::AuthenticationError |
-                                    err @ LanguageModelCompletionError::PermissionError |
-                                    err @ LanguageModelCompletionError::ApiEndpointNotFound |
-                                    err @ LanguageModelCompletionError::SerializeRequest(_) |
-                                    err @ LanguageModelCompletionError::BuildRequestBody(_) |
-                                    err @ LanguageModelCompletionError::HttpSend(_) => {
-                                        anyhow::bail!(err);
-                                    }
-                                    LanguageModelCompletionError::Other(error) => {
-                                        return Err(error);
-                                    }
-                                }
-                            }
-                        };
-
-                        match event {
+                        match event? {
                             LanguageModelCompletionEvent::StartMessage { .. } => {
                                 request_assistant_message_id =
                                     Some(thread.insert_assistant_message(
@@ -1683,9 +1607,7 @@ impl Thread {
                                     };
                                 }
                             }
-                            LanguageModelCompletionEvent::RedactedThinking {
-                                data
-                            } => {
+                            LanguageModelCompletionEvent::RedactedThinking { data } => {
                                 thread.received_chunk();
 
                                 if let Some(last_message) = thread.messages.last_mut() {
@@ -1734,6 +1656,21 @@ impl Thread {
                                     });
                                 }
                             }
+                            LanguageModelCompletionEvent::ToolUseJsonParseError {
+                                id,
+                                tool_name,
+                                raw_input: invalid_input_json,
+                                json_parse_error,
+                            } => {
+                                thread.receive_invalid_tool_json(
+                                    id,
+                                    tool_name,
+                                    invalid_input_json,
+                                    json_parse_error,
+                                    window,
+                                    cx,
+                                );
+                            }
                             LanguageModelCompletionEvent::StatusUpdate(status_update) => {
                                 if let Some(completion) = thread
                                     .pending_completions
@@ -1741,23 +1678,34 @@ impl Thread {
                                     .find(|completion| completion.id == pending_completion_id)
                                 {
                                     match status_update {
-                                        CompletionRequestStatus::Queued {
-                                            position,
-                                        } => {
-                                            completion.queue_state = QueueState::Queued { position };
+                                        CompletionRequestStatus::Queued { position } => {
+                                            completion.queue_state =
+                                                QueueState::Queued { position };
                                         }
                                         CompletionRequestStatus::Started => {
-                                            completion.queue_state =  QueueState::Started;
+                                            completion.queue_state = QueueState::Started;
                                         }
                                         CompletionRequestStatus::Failed {
-                                            code, message, request_id
+                                            code,
+                                            message,
+                                            request_id: _,
+                                            retry_after,
                                         } => {
-                                            anyhow::bail!("completion request failed. request_id: {request_id}, code: {code}, message: {message}");
+                                            return Err(
+                                                LanguageModelCompletionError::from_cloud_failure(
+                                                    model.upstream_provider_name(),
+                                                    code,
+                                                    message,
+                                                    retry_after.map(Duration::from_secs_f64),
+                                                ),
+                                            );
                                         }
-                                        CompletionRequestStatus::UsageUpdated {
-                                            amount, limit
-                                        } => {
-                                            thread.update_model_request_usage(amount as u32, limit, cx);
+                                        CompletionRequestStatus::UsageUpdated { amount, limit } => {
+                                            thread.update_model_request_usage(
+                                                amount as u32,
+                                                limit,
+                                                cx,
+                                            );
                                         }
                                         CompletionRequestStatus::ToolUseLimitReached => {
                                             thread.tool_use_limit_reached = true;
@@ -1808,10 +1756,11 @@ impl Thread {
                         Ok(stop_reason) => {
                             match stop_reason {
                                 StopReason::ToolUse => {
-                                    let tool_uses = thread.use_pending_tools(window, model.clone(), cx);
+                                    let tool_uses =
+                                        thread.use_pending_tools(window, model.clone(), cx);
                                     cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                                 }
-                                StopReason::EndTurn | StopReason::MaxTokens  => {
+                                StopReason::EndTurn | StopReason::MaxTokens => {
                                     thread.project.update(cx, |project, cx| {
                                         project.set_agent_location(None, cx);
                                     });
@@ -1827,7 +1776,9 @@ impl Thread {
                                     {
                                         let mut messages_to_remove = Vec::new();
 
-                                        for (ix, message) in thread.messages.iter().enumerate().rev() {
+                                        for (ix, message) in
+                                            thread.messages.iter().enumerate().rev()
+                                        {
                                             messages_to_remove.push(message.id);
 
                                             if message.role == Role::User {
@@ -1835,7 +1786,9 @@ impl Thread {
                                                     break;
                                                 }
 
-                                                if let Some(prev_message) = thread.messages.get(ix - 1) {
+                                                if let Some(prev_message) =
+                                                    thread.messages.get(ix - 1)
+                                                {
                                                     if prev_message.role == Role::Assistant {
                                                         break;
                                                     }
@@ -1850,14 +1803,16 @@ impl Thread {
 
                                     cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                         header: "Language model refusal".into(),
-                                        message: "Model refused to generate content for safety reasons.".into(),
+                                        message:
+                                            "Model refused to generate content for safety reasons."
+                                                .into(),
                                     }));
                                 }
                             }
 
                             // We successfully completed, so cancel any remaining retries.
                             thread.retry_state = None;
-                        },
+                        }
                         Err(error) => {
                             thread.project.update(cx, |project, cx| {
                                 project.set_agent_location(None, cx);
@@ -1883,26 +1838,38 @@ impl Thread {
                                 cx.emit(ThreadEvent::ShowError(
                                     ThreadError::ModelRequestLimitReached { plan: error.plan },
                                 ));
-                            } else if let Some(known_error) =
-                                error.downcast_ref::<LanguageModelKnownError>()
+                            } else if let Some(completion_error) =
+                                error.downcast_ref::<LanguageModelCompletionError>()
                             {
-                                match known_error {
-                                    LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
+                                use LanguageModelCompletionError::*;
+                                match &completion_error {
+                                    PromptTooLarge { tokens, .. } => {
+                                        let tokens = tokens.unwrap_or_else(|| {
+                                            // We didn't get an exact token count from the API, so fall back on our estimate.
+                                            thread
+                                                .total_token_usage()
+                                                .map(|usage| usage.total)
+                                                .unwrap_or(0)
+                                                // We know the context window was exceeded in practice, so if our estimate was
+                                                // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
+                                                .max(model.max_token_count().saturating_add(1))
+                                        });
                                         thread.exceeded_window_error = Some(ExceededWindowError {
                                             model_id: model.id(),
-                                            token_count: *tokens,
+                                            token_count: tokens,
                                         });
                                         cx.notify();
                                     }
-                                    LanguageModelKnownError::RateLimitExceeded { retry_after } => {
-                                        let provider_name = model.provider_name();
-                                        let error_message = format!(
-                                            "{}'s API rate limit exceeded",
-                                            provider_name.0.as_ref()
-                                        );
-
+                                    RateLimitExceeded {
+                                        retry_after: Some(retry_after),
+                                        ..
+                                    }
+                                    | ServerOverloaded {
+                                        retry_after: Some(retry_after),
+                                        ..
+                                    } => {
                                         thread.handle_rate_limit_error(
-                                            &error_message,
+                                            &completion_error,
                                             *retry_after,
                                             model.clone(),
                                             intent,
@@ -1911,15 +1878,9 @@ impl Thread {
                                         );
                                         retry_scheduled = true;
                                     }
-                                    LanguageModelKnownError::Overloaded => {
-                                        let provider_name = model.provider_name();
-                                        let error_message = format!(
-                                            "{}'s API servers are overloaded right now",
-                                            provider_name.0.as_ref()
-                                        );
-
+                                    RateLimitExceeded { .. } | ServerOverloaded { .. } => {
                                         retry_scheduled = thread.handle_retryable_error(
-                                            &error_message,
+                                            &completion_error,
                                             model.clone(),
                                             intent,
                                             window,
@@ -1929,15 +1890,11 @@ impl Thread {
                                             emit_generic_error(error, cx);
                                         }
                                     }
-                                    LanguageModelKnownError::ApiInternalServerError => {
-                                        let provider_name = model.provider_name();
-                                        let error_message = format!(
-                                            "{}'s API server reported an internal server error",
-                                            provider_name.0.as_ref()
-                                        );
-
+                                    ApiInternalServerError { .. }
+                                    | ApiReadResponseError { .. }
+                                    | HttpSend { .. } => {
                                         retry_scheduled = thread.handle_retryable_error(
-                                            &error_message,
+                                            &completion_error,
                                             model.clone(),
                                             intent,
                                             window,
@@ -1947,12 +1904,16 @@ impl Thread {
                                             emit_generic_error(error, cx);
                                         }
                                     }
-                                    LanguageModelKnownError::ReadResponseError(_) |
-                                    LanguageModelKnownError::DeserializeResponse(_) |
-                                    LanguageModelKnownError::UnknownResponseFormat(_) => {
-                                        // In the future we will attempt to re-roll response, but only once
-                                        emit_generic_error(error, cx);
-                                    }
+                                    NoApiKey { .. }
+                                    | HttpResponseError { .. }
+                                    | BadRequestFormat { .. }
+                                    | AuthenticationError { .. }
+                                    | PermissionError { .. }
+                                    | ApiEndpointNotFound { .. }
+                                    | SerializeRequest { .. }
+                                    | BuildRequestBody { .. }
+                                    | DeserializeResponse { .. }
+                                    | Other { .. } => emit_generic_error(error, cx),
                                 }
                             } else {
                                 emit_generic_error(error, cx);
@@ -2084,7 +2045,7 @@ impl Thread {
 
     fn handle_rate_limit_error(
         &mut self,
-        error_message: &str,
+        error: &LanguageModelCompletionError,
         retry_after: Duration,
         model: Arc<dyn LanguageModel>,
         intent: CompletionIntent,
@@ -2092,9 +2053,10 @@ impl Thread {
         cx: &mut Context<Self>,
     ) {
         // For rate limit errors, we only retry once with the specified duration
-        let retry_message = format!(
-            "{error_message}. Retrying in {} seconds…",
-            retry_after.as_secs()
+        let retry_message = format!("{error}. Retrying in {} seconds…", retry_after.as_secs());
+        log::warn!(
+            "Retrying completion request in {} seconds: {error:?}",
+            retry_after.as_secs(),
         );
 
         // Add a UI-only message instead of a regular message
@@ -2127,18 +2089,18 @@ impl Thread {
 
     fn handle_retryable_error(
         &mut self,
-        error_message: &str,
+        error: &LanguageModelCompletionError,
         model: Arc<dyn LanguageModel>,
         intent: CompletionIntent,
         window: Option<AnyWindowHandle>,
         cx: &mut Context<Self>,
     ) -> bool {
-        self.handle_retryable_error_with_delay(error_message, None, model, intent, window, cx)
+        self.handle_retryable_error_with_delay(error, None, model, intent, window, cx)
     }
 
     fn handle_retryable_error_with_delay(
         &mut self,
-        error_message: &str,
+        error: &LanguageModelCompletionError,
         custom_delay: Option<Duration>,
         model: Arc<dyn LanguageModel>,
         intent: CompletionIntent,
@@ -2168,8 +2130,12 @@ impl Thread {
             // Add a transient message to inform the user
             let delay_secs = delay.as_secs();
             let retry_message = format!(
-                "{}. Retrying (attempt {} of {}) in {} seconds...",
-                error_message, attempt, max_attempts, delay_secs
+                "{error}. Retrying (attempt {attempt} of {max_attempts}) \
+                in {delay_secs} seconds..."
+            );
+            log::warn!(
+                "Retrying completion request (attempt {attempt} of {max_attempts}) \
+                in {delay_secs} seconds: {error:?}",
             );
 
             // Add a UI-only message instead of a regular message
@@ -4139,9 +4105,15 @@ fn main() {{
             >,
         > {
             let error = match self.error_type {
-                TestError::Overloaded => LanguageModelCompletionError::Overloaded,
+                TestError::Overloaded => LanguageModelCompletionError::ServerOverloaded {
+                    provider: self.provider_name(),
+                    retry_after: None,
+                },
                 TestError::InternalServerError => {
-                    LanguageModelCompletionError::ApiInternalServerError
+                    LanguageModelCompletionError::ApiInternalServerError {
+                        provider: self.provider_name(),
+                        message: "I'm a teapot orbiting the sun".to_string(),
+                    }
                 }
             };
             async move {
@@ -4649,9 +4621,13 @@ fn main() {{
             > {
                 if !*self.failed_once.lock() {
                     *self.failed_once.lock() = true;
+                    let provider = self.provider_name();
                     // Return error on first attempt
                     let stream = futures::stream::once(async move {
-                        Err(LanguageModelCompletionError::Overloaded)
+                        Err(LanguageModelCompletionError::ServerOverloaded {
+                            provider,
+                            retry_after: None,
+                        })
                     });
                     async move { Ok(stream.boxed()) }.boxed()
                 } else {
@@ -4814,9 +4790,13 @@ fn main() {{
             > {
                 if !*self.failed_once.lock() {
                     *self.failed_once.lock() = true;
+                    let provider = self.provider_name();
                     // Return error on first attempt
                     let stream = futures::stream::once(async move {
-                        Err(LanguageModelCompletionError::Overloaded)
+                        Err(LanguageModelCompletionError::ServerOverloaded {
+                            provider,
+                            retry_after: None,
+                        })
                     });
                     async move { Ok(stream.boxed()) }.boxed()
                 } else {
@@ -4969,10 +4949,12 @@ fn main() {{
                     LanguageModelCompletionError,
                 >,
             > {
+                let provider = self.provider_name();
                 async move {
                     let stream = futures::stream::once(async move {
                         Err(LanguageModelCompletionError::RateLimitExceeded {
-                            retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS),
+                            provider,
+                            retry_after: Some(Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS)),
                         })
                     });
                     Ok(stream.boxed())

crates/agent_ui/src/agent_panel.rs

@@ -2025,9 +2025,7 @@ impl AgentPanel {
                     .thread()
                     .read(cx)
                     .configured_model()
-                    .map_or(false, |model| {
-                        model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
-                    });
+                    .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID);
 
                 if !is_using_zed_provider {
                     return false;

crates/agent_ui/src/message_editor.rs

@@ -1250,9 +1250,7 @@ impl MessageEditor {
         self.thread
             .read(cx)
             .configured_model()
-            .map_or(false, |model| {
-                model.provider.id().0 == ZED_CLOUD_PROVIDER_ID
-            })
+            .map_or(false, |model| model.provider.id() == ZED_CLOUD_PROVIDER_ID)
     }
 
     fn render_usage_callout(&self, line_height: Pixels, cx: &mut Context<Self>) -> Option<Div> {

crates/anthropic/src/anthropic.rs

@@ -6,7 +6,7 @@ use anyhow::{Context as _, Result, anyhow};
 use chrono::{DateTime, Utc};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::http::{self, HeaderMap, HeaderValue};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
 use serde::{Deserialize, Serialize};
 use strum::{EnumIter, EnumString};
 use thiserror::Error;
@@ -356,7 +356,7 @@ pub async fn complete(
         .send(request)
         .await
         .map_err(AnthropicError::HttpSend)?;
-    let status = response.status();
+    let status_code = response.status();
     let mut body = String::new();
     response
         .body_mut()
@@ -364,12 +364,12 @@ pub async fn complete(
         .await
         .map_err(AnthropicError::ReadResponse)?;
 
-    if status.is_success() {
+    if status_code.is_success() {
         Ok(serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)?)
     } else {
         Err(AnthropicError::HttpResponseError {
-            status: status.as_u16(),
-            body,
+            status_code,
+            message: body,
         })
     }
 }
@@ -444,11 +444,7 @@ impl RateLimitInfo {
         }
 
         Self {
-            retry_after: headers
-                .get("retry-after")
-                .and_then(|v| v.to_str().ok())
-                .and_then(|v| v.parse::<u64>().ok())
-                .map(Duration::from_secs),
+            retry_after: parse_retry_after(headers),
             requests: RateLimit::from_headers("requests", headers).ok(),
             tokens: RateLimit::from_headers("tokens", headers).ok(),
             input_tokens: RateLimit::from_headers("input-tokens", headers).ok(),
@@ -457,6 +453,17 @@ impl RateLimitInfo {
     }
 }
 
+/// Parses the Retry-After header value as an integer number of seconds (anthropic always uses
+/// seconds). Note that other services might specify an HTTP date or some other format for this
+/// header. Returns `None` if the header is not present or cannot be parsed.
+pub fn parse_retry_after(headers: &HeaderMap<HeaderValue>) -> Option<Duration> {
+    headers
+        .get("retry-after")
+        .and_then(|v| v.to_str().ok())
+        .and_then(|v| v.parse::<u64>().ok())
+        .map(Duration::from_secs)
+}
+
 fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> anyhow::Result<&'a str> {
     Ok(headers
         .get(key)
@@ -520,6 +527,10 @@ pub async fn stream_completion_with_rate_limit_info(
             })
             .boxed();
         Ok((stream, Some(rate_limits)))
+    } else if response.status().as_u16() == 529 {
+        Err(AnthropicError::ServerOverloaded {
+            retry_after: rate_limits.retry_after,
+        })
     } else if let Some(retry_after) = rate_limits.retry_after {
         Err(AnthropicError::RateLimit { retry_after })
     } else {
@@ -532,10 +543,9 @@ pub async fn stream_completion_with_rate_limit_info(
 
         match serde_json::from_str::<Event>(&body) {
             Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
-            Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
-            Err(_) => Err(AnthropicError::HttpResponseError {
-                status: response.status().as_u16(),
-                body: body,
+            Ok(_) | Err(_) => Err(AnthropicError::HttpResponseError {
+                status_code: response.status(),
+                message: body,
             }),
         }
     }
@@ -801,16 +811,19 @@ pub enum AnthropicError {
     ReadResponse(io::Error),
 
     /// HTTP error response from the API
-    HttpResponseError { status: u16, body: String },
+    HttpResponseError {
+        status_code: StatusCode,
+        message: String,
+    },
 
     /// Rate limit exceeded
     RateLimit { retry_after: Duration },
 
+    /// Server overloaded
+    ServerOverloaded { retry_after: Option<Duration> },
+
     /// API returned an error response
     ApiError(ApiError),
-
-    /// Unexpected response format
-    UnexpectedResponseFormat(String),
 }
 
 #[derive(Debug, Serialize, Deserialize, Error)]

crates/assistant_context/src/assistant_context.rs

@@ -2140,7 +2140,8 @@ impl AssistantContext {
                                         );
                                     }
                                     LanguageModelCompletionEvent::ToolUse(_) |
-                                    LanguageModelCompletionEvent::UsageUpdate(_)  => {}
+                                    LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
+                                    LanguageModelCompletionEvent::UsageUpdate(_) => {}
                                 }
                             });
 

crates/assistant_tools/src/edit_agent/evals.rs

@@ -29,6 +29,7 @@ use std::{
     path::Path,
     str::FromStr,
     sync::mpsc,
+    time::Duration,
 };
 use util::path;
 
@@ -1658,12 +1659,14 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
         match request().await {
             Ok(result) => return Ok(result),
             Err(err) => match err.downcast::<LanguageModelCompletionError>() {
-                Ok(err) => match err {
-                    LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
+                Ok(err) => match &err {
+                    LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
+                    | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
+                        let retry_after = retry_after.unwrap_or(Duration::from_secs(5));
                         // Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
                         let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
                         eprintln!(
-                            "Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
+                            "Attempt #{attempt}: {err}. Retry after {retry_after:?} + jitter of {jitter:?}"
                         );
                         Timer::after(retry_after + jitter).await;
                         continue;

crates/eval/src/instance.rs

@@ -1054,6 +1054,15 @@ pub fn response_events_to_markdown(
                 | LanguageModelCompletionEvent::StartMessage { .. }
                 | LanguageModelCompletionEvent::StatusUpdate { .. },
             ) => {}
+            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                json_parse_error, ..
+            }) => {
+                flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
+                response.push_str(&format!(
+                    "**Error**: parse error in tool use JSON: {}\n\n",
+                    json_parse_error
+                ));
+            }
             Err(error) => {
                 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
                 response.push_str(&format!("**Error**: {}\n\n", error));
@@ -1132,6 +1141,17 @@ impl ThreadDialog {
                 | Ok(LanguageModelCompletionEvent::StartMessage { .. })
                 | Ok(LanguageModelCompletionEvent::Stop(_)) => {}
 
+                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                    json_parse_error,
+                    ..
+                }) => {
+                    flush_text(&mut current_text, &mut content);
+                    content.push(MessageContent::Text(format!(
+                        "ERROR: parse error in tool use JSON: {}",
+                        json_parse_error
+                    )));
+                }
+
                 Err(error) => {
                     flush_text(&mut current_text, &mut content);
                     content.push(MessageContent::Text(format!("ERROR: {}", error)));

crates/language_model/src/language_model.rs

@@ -9,17 +9,18 @@ mod telemetry;
 pub mod fake_provider;
 
 use anthropic::{AnthropicError, parse_prompt_too_long};
-use anyhow::Result;
+use anyhow::{Result, anyhow};
 use client::Client;
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
-use http_client::http;
+use http_client::{StatusCode, http};
 use icons::IconName;
 use parking_lot::Mutex;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize, de::DeserializeOwned};
 use std::ops::{Add, Sub};
+use std::str::FromStr;
 use std::sync::Arc;
 use std::time::Duration;
 use std::{fmt, io};
@@ -34,11 +35,22 @@ pub use crate::request::*;
 pub use crate::role::*;
 pub use crate::telemetry::*;
 
-pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
+pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
+    LanguageModelProviderId::new("anthropic");
+pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
+    LanguageModelProviderName::new("Anthropic");
 
-/// If we get a rate limit error that doesn't tell us when we can retry,
-/// default to waiting this long before retrying.
-const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
+pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
+pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
+    LanguageModelProviderName::new("Google AI");
+
+pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
+pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
+    LanguageModelProviderName::new("OpenAI");
+
+pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
+pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
+    LanguageModelProviderName::new("Zed");
 
 pub fn init(client: Arc<Client>, cx: &mut App) {
     init_settings(cx);
@@ -71,6 +83,12 @@ pub enum LanguageModelCompletionEvent {
         data: String,
     },
     ToolUse(LanguageModelToolUse),
+    ToolUseJsonParseError {
+        id: LanguageModelToolUseId,
+        tool_name: Arc<str>,
+        raw_input: Arc<str>,
+        json_parse_error: String,
+    },
     StartMessage {
         message_id: String,
     },
@@ -79,61 +97,179 @@ pub enum LanguageModelCompletionEvent {
 
 #[derive(Error, Debug)]
 pub enum LanguageModelCompletionError {
-    #[error("rate limit exceeded, retry after {retry_after:?}")]
-    RateLimitExceeded { retry_after: Duration },
-    #[error("received bad input JSON")]
-    BadInputJson {
-        id: LanguageModelToolUseId,
-        tool_name: Arc<str>,
-        raw_input: Arc<str>,
-        json_parse_error: String,
+    #[error("prompt too large for context window")]
+    PromptTooLarge { tokens: Option<u64> },
+    #[error("missing {provider} API key")]
+    NoApiKey { provider: LanguageModelProviderName },
+    #[error("{provider}'s API rate limit exceeded")]
+    RateLimitExceeded {
+        provider: LanguageModelProviderName,
+        retry_after: Option<Duration>,
+    },
+    #[error("{provider}'s API servers are overloaded right now")]
+    ServerOverloaded {
+        provider: LanguageModelProviderName,
+        retry_after: Option<Duration>,
+    },
+    #[error("{provider}'s API server reported an internal server error: {message}")]
+    ApiInternalServerError {
+        provider: LanguageModelProviderName,
+        message: String,
+    },
+    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
+    HttpResponseError {
+        provider: LanguageModelProviderName,
+        status_code: StatusCode,
+        message: String,
+    },
+
+    // Client errors
+    #[error("invalid request format to {provider}'s API: {message}")]
+    BadRequestFormat {
+        provider: LanguageModelProviderName,
+        message: String,
     },
-    #[error("language model provider's API is overloaded")]
-    Overloaded,
+    #[error("authentication error with {provider}'s API: {message}")]
+    AuthenticationError {
+        provider: LanguageModelProviderName,
+        message: String,
+    },
+    #[error("permission error with {provider}'s API: {message}")]
+    PermissionError {
+        provider: LanguageModelProviderName,
+        message: String,
+    },
+    #[error("language model provider API endpoint not found")]
+    ApiEndpointNotFound { provider: LanguageModelProviderName },
+    #[error("I/O error reading response from {provider}'s API")]
+    ApiReadResponseError {
+        provider: LanguageModelProviderName,
+        #[source]
+        error: io::Error,
+    },
+    #[error("error serializing request to {provider} API")]
+    SerializeRequest {
+        provider: LanguageModelProviderName,
+        #[source]
+        error: serde_json::Error,
+    },
+    #[error("error building request body to {provider} API")]
+    BuildRequestBody {
+        provider: LanguageModelProviderName,
+        #[source]
+        error: http::Error,
+    },
+    #[error("error sending HTTP request to {provider} API")]
+    HttpSend {
+        provider: LanguageModelProviderName,
+        #[source]
+        error: anyhow::Error,
+    },
+    #[error("error deserializing {provider} API response")]
+    DeserializeResponse {
+        provider: LanguageModelProviderName,
+        #[source]
+        error: serde_json::Error,
+    },
+
+    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
     #[error(transparent)]
     Other(#[from] anyhow::Error),
-    #[error("invalid request format to language model provider's API")]
-    BadRequestFormat,
-    #[error("authentication error with language model provider's API")]
-    AuthenticationError,
-    #[error("permission error with language model provider's API")]
-    PermissionError,
-    #[error("language model provider API endpoint not found")]
-    ApiEndpointNotFound,
-    #[error("prompt too large for context window")]
-    PromptTooLarge { tokens: Option<u64> },
-    #[error("internal server error in language model provider's API")]
-    ApiInternalServerError,
-    #[error("I/O error reading response from language model provider's API: {0:?}")]
-    ApiReadResponseError(io::Error),
-    #[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
-    HttpResponseError { status: u16, body: String },
-    #[error("error serializing request to language model provider API: {0}")]
-    SerializeRequest(serde_json::Error),
-    #[error("error building request body to language model provider API: {0}")]
-    BuildRequestBody(http::Error),
-    #[error("error sending HTTP request to language model provider API: {0}")]
-    HttpSend(anyhow::Error),
-    #[error("error deserializing language model provider API response: {0}")]
-    DeserializeResponse(serde_json::Error),
-    #[error("unexpected language model provider API response format: {0}")]
-    UnknownResponseFormat(String),
+}
+
+impl LanguageModelCompletionError {
+    pub fn from_cloud_failure(
+        upstream_provider: LanguageModelProviderName,
+        code: String,
+        message: String,
+        retry_after: Option<Duration>,
+    ) -> Self {
+        if let Some(tokens) = parse_prompt_too_long(&message) {
+            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
+            // to be reported. This is a temporary workaround to handle this in the case where the
+            // token limit has been exceeded.
+            Self::PromptTooLarge {
+                tokens: Some(tokens),
+            }
+        } else if let Some(status_code) = code
+            .strip_prefix("upstream_http_")
+            .and_then(|code| StatusCode::from_str(code).ok())
+        {
+            Self::from_http_status(upstream_provider, status_code, message, retry_after)
+        } else if let Some(status_code) = code
+            .strip_prefix("http_")
+            .and_then(|code| StatusCode::from_str(code).ok())
+        {
+            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
+        } else {
+            anyhow!("completion request failed, code: {code}, message: {message}").into()
+        }
+    }
+
+    pub fn from_http_status(
+        provider: LanguageModelProviderName,
+        status_code: StatusCode,
+        message: String,
+        retry_after: Option<Duration>,
+    ) -> Self {
+        match status_code {
+            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
+            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
+            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
+            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
+            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
+                tokens: parse_prompt_too_long(&message),
+            },
+            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
+                provider,
+                retry_after,
+            },
+            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
+            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
+                provider,
+                retry_after,
+            },
+            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
+                provider,
+                retry_after,
+            },
+            _ => Self::HttpResponseError {
+                provider,
+                status_code,
+                message,
+            },
+        }
+    }
 }
 
 impl From<AnthropicError> for LanguageModelCompletionError {
     fn from(error: AnthropicError) -> Self {
+        let provider = ANTHROPIC_PROVIDER_NAME;
         match error {
-            AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
-            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
-            AnthropicError::HttpSend(error) => Self::HttpSend(error),
-            AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
-            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
-            AnthropicError::HttpResponseError { status, body } => {
-                Self::HttpResponseError { status, body }
+            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
+            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
+            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
+            AnthropicError::DeserializeResponse(error) => {
+                Self::DeserializeResponse { provider, error }
             }
-            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
+            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
+            AnthropicError::HttpResponseError {
+                status_code,
+                message,
+            } => Self::HttpResponseError {
+                provider,
+                status_code,
+                message,
+            },
+            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
+                provider,
+                retry_after: Some(retry_after),
+            },
+            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
+                provider,
+                retry_after: retry_after,
+            },
             AnthropicError::ApiError(api_error) => api_error.into(),
-            AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
         }
     }
 }
@@ -141,23 +277,39 @@ impl From<AnthropicError> for LanguageModelCompletionError {
 impl From<anthropic::ApiError> for LanguageModelCompletionError {
     fn from(error: anthropic::ApiError) -> Self {
         use anthropic::ApiErrorCode::*;
-
+        let provider = ANTHROPIC_PROVIDER_NAME;
         match error.code() {
             Some(code) => match code {
-                InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
-                AuthenticationError => LanguageModelCompletionError::AuthenticationError,
-                PermissionError => LanguageModelCompletionError::PermissionError,
-                NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
-                RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
+                InvalidRequestError => Self::BadRequestFormat {
+                    provider,
+                    message: error.message,
+                },
+                AuthenticationError => Self::AuthenticationError {
+                    provider,
+                    message: error.message,
+                },
+                PermissionError => Self::PermissionError {
+                    provider,
+                    message: error.message,
+                },
+                NotFoundError => Self::ApiEndpointNotFound { provider },
+                RequestTooLarge => Self::PromptTooLarge {
                     tokens: parse_prompt_too_long(&error.message),
                 },
-                RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
-                    retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
+                RateLimitError => Self::RateLimitExceeded {
+                    provider,
+                    retry_after: None,
+                },
+                ApiError => Self::ApiInternalServerError {
+                    provider,
+                    message: error.message,
+                },
+                OverloadedError => Self::ServerOverloaded {
+                    provider,
+                    retry_after: None,
                 },
-                ApiError => LanguageModelCompletionError::ApiInternalServerError,
-                OverloadedError => LanguageModelCompletionError::Overloaded,
             },
-            None => LanguageModelCompletionError::Other(error.into()),
+            None => Self::Other(error.into()),
         }
     }
 }
@@ -278,6 +430,13 @@ pub trait LanguageModel: Send + Sync {
     fn name(&self) -> LanguageModelName;
     fn provider_id(&self) -> LanguageModelProviderId;
     fn provider_name(&self) -> LanguageModelProviderName;
+    fn upstream_provider_id(&self) -> LanguageModelProviderId {
+        self.provider_id()
+    }
+    fn upstream_provider_name(&self) -> LanguageModelProviderName {
+        self.provider_name()
+    }
+
     fn telemetry_id(&self) -> String;
 
     fn api_key(&self, _cx: &App) -> Option<String> {
@@ -365,6 +524,9 @@ pub trait LanguageModel: Send + Sync {
                                 Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                                    ..
+                                }) => None,
                                 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                     *last_token_usage.lock() = token_usage;
                                     None
@@ -395,39 +557,6 @@ pub trait LanguageModel: Send + Sync {
     }
 }
 
-#[derive(Debug, Error)]
-pub enum LanguageModelKnownError {
-    #[error("Context window limit exceeded ({tokens})")]
-    ContextWindowLimitExceeded { tokens: u64 },
-    #[error("Language model provider's API is currently overloaded")]
-    Overloaded,
-    #[error("Language model provider's API encountered an internal server error")]
-    ApiInternalServerError,
-    #[error("I/O error while reading response from language model provider's API: {0:?}")]
-    ReadResponseError(io::Error),
-    #[error("Error deserializing response from language model provider's API: {0:?}")]
-    DeserializeResponse(serde_json::Error),
-    #[error("Language model provider's API returned a response in an unknown format")]
-    UnknownResponseFormat(String),
-    #[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
-    RateLimitExceeded { retry_after: Duration },
-}
-
-impl LanguageModelKnownError {
-    /// Attempts to map an HTTP response status code to a known error type.
-    /// Returns None if the status code doesn't map to a specific known error.
-    pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
-        match status {
-            429 => Some(Self::RateLimitExceeded {
-                retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
-            }),
-            503 => Some(Self::Overloaded),
-            500..=599 => Some(Self::ApiInternalServerError),
-            _ => None,
-        }
-    }
-}
-
 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
     fn name() -> String;
     fn description() -> String;
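
The removal above retires `LanguageModelKnownError`; its HTTP-status mapping moves onto `LanguageModelCompletionError`, whose variants now carry the provider. A minimal sketch of the replacement mapping, assuming the variant shapes visible in the `From` impl fragment at the top of this section; the real constructor is `from_http_status` (used in cloud.rs below) and may differ in detail:

fn classify_status(
    provider: LanguageModelProviderName,
    status: u16,
    message: String,
) -> LanguageModelCompletionError {
    match status {
        // Upstream rate limit; retry_after is populated when the provider
        // sends one (zed_llm_client 0.8.5 forwards Anthropic's value).
        429 => LanguageModelCompletionError::RateLimitExceeded {
            provider,
            retry_after: None,
        },
        // Seen in the wild as 503, plus Anthropic's 529 "overloaded".
        503 | 529 => LanguageModelCompletionError::ServerOverloaded {
            provider,
            retry_after: None,
        },
        500..=599 => LanguageModelCompletionError::ApiInternalServerError { provider, message },
        _ => LanguageModelCompletionError::Other(anyhow::anyhow!(message)),
    }
}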
@@ -509,12 +638,30 @@ pub struct LanguageModelProviderId(pub SharedString);
 #[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 pub struct LanguageModelProviderName(pub SharedString);
 
+impl LanguageModelProviderId {
+    pub const fn new(id: &'static str) -> Self {
+        Self(SharedString::new_static(id))
+    }
+}
+
+impl LanguageModelProviderName {
+    pub const fn new(id: &'static str) -> Self {
+        Self(SharedString::new_static(id))
+    }
+}
+
 impl fmt::Display for LanguageModelProviderId {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "{}", self.0)
     }
 }
 
+impl fmt::Display for LanguageModelProviderName {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
 impl From<String> for LanguageModelId {
     fn from(value: String) -> Self {
         Self(SharedString::from(value))

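These `const fn` constructors are what let every provider module below replace the `LanguageModelProviderId(PROVIDER_ID.into())` boilerplate with a plain constant, and the derived `PartialEq` makes the registry comparison below a direct `==`. The pattern, with a placeholder provider name:

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("example-provider");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Example Provider");

// Both types now implement Display, so they drop straight into format strings:
// format!("{PROVIDER_NAME} is unavailable")
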
crates/language_model/src/registry.rs πŸ”—

@@ -98,7 +98,7 @@ impl ConfiguredModel {
     }
 
     pub fn is_provided_by_zed(&self) -> bool {
-        self.provider.id().0 == crate::ZED_CLOUD_PROVIDER_ID
+        self.provider.id() == crate::ZED_CLOUD_PROVIDER_ID
     }
 }
 

crates/language_model/src/telemetry.rs πŸ”—

@@ -1,3 +1,4 @@
+use crate::ANTHROPIC_PROVIDER_ID;
 use anthropic::ANTHROPIC_API_URL;
 use anyhow::{Context as _, anyhow};
 use client::telemetry::Telemetry;
@@ -8,8 +9,6 @@ use std::sync::Arc;
 use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
 use util::ResultExt;
 
-pub const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
-
 pub fn report_assistant_event(
     event: AssistantEventData,
     telemetry: Option<Arc<Telemetry>>,
@@ -19,7 +18,7 @@ pub fn report_assistant_event(
 ) {
     if let Some(telemetry) = telemetry.as_ref() {
         telemetry.report_assistant_event(event.clone());
-        if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID {
+        if telemetry.metrics_enabled() && event.model_provider == ANTHROPIC_PROVIDER_ID.0 {
             if let Some(api_key) = model_api_key {
                 executor
                     .spawn(async move {

crates/language_models/src/provider/anthropic.rs πŸ”—

@@ -33,8 +33,8 @@ use theme::ThemeSettings;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
 use util::ResultExt;
 
-const PROVIDER_ID: &str = language_model::ANTHROPIC_PROVIDER_ID;
-const PROVIDER_NAME: &str = "Anthropic";
+const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME;
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct AnthropicSettings {
@@ -218,11 +218,11 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
 
 impl LanguageModelProvider for AnthropicLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -403,7 +403,11 @@ impl AnthropicModel {
         };
 
         async move {
-            let api_key = api_key.context("Missing Anthropic API Key")?;
+            let Some(api_key) = api_key else {
+                return Err(LanguageModelCompletionError::NoApiKey {
+                    provider: PROVIDER_NAME,
+                });
+            };
             let request =
                 anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
             request.await.map_err(Into::into)
@@ -422,11 +426,11 @@ impl LanguageModel for AnthropicModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {
@@ -806,12 +810,14 @@ impl AnthropicEventMapper {
                                 raw_input: tool_use.input_json.clone(),
                             },
                         )),
-                        Err(json_parse_err) => Err(LanguageModelCompletionError::BadInputJson {
-                            id: tool_use.id.into(),
-                            tool_name: tool_use.name.into(),
-                            raw_input: input_json.into(),
-                            json_parse_error: json_parse_err.to_string(),
-                        }),
+                        Err(json_parse_err) => {
+                            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                                id: tool_use.id.into(),
+                                tool_name: tool_use.name.into(),
+                                raw_input: input_json.into(),
+                                json_parse_error: json_parse_err.to_string(),
+                            })
+                        }
                     };
 
                     vec![event_result]

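With `BadInputJson` gone, a malformed tool-call payload no longer terminates the event stream as an `Err`; it flows through as a `ToolUseJsonParseError` event that consumers handle per tool call. A rough sketch of consumer-side handling, assuming the event shape shown above; the real logic in crates/agent/src/thread.rs is more involved:

fn handle_event(event: LanguageModelCompletionEvent) {
    match event {
        LanguageModelCompletionEvent::ToolUse(_tool_use) => {
            // Arguments parsed cleanly; dispatch the tool call.
        }
        LanguageModelCompletionEvent::ToolUseJsonParseError {
            tool_name,
            json_parse_error,
            ..
        } => {
            // Surface one failed tool call without aborting the stream;
            // previously this arrived as Err(LanguageModelCompletionError::BadInputJson).
            eprintln!("tool call `{tool_name}` sent invalid JSON: {json_parse_error}");
        }
        _ => {}
    }
}
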
crates/language_models/src/provider/bedrock.rs πŸ”—

@@ -52,8 +52,8 @@ use util::ResultExt;
 
 use crate::AllLanguageModelSettings;
 
-const PROVIDER_ID: &str = "amazon-bedrock";
-const PROVIDER_NAME: &str = "Amazon Bedrock";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("amazon-bedrock");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Amazon Bedrock");
 
 #[derive(Default, Clone, Deserialize, Serialize, PartialEq, Debug)]
 pub struct BedrockCredentials {
@@ -285,11 +285,11 @@ impl BedrockLanguageModelProvider {
 
 impl LanguageModelProvider for BedrockLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -489,11 +489,11 @@ impl LanguageModel for BedrockModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {

crates/language_models/src/provider/cloud.rs πŸ”—

@@ -1,4 +1,4 @@
-use anthropic::{AnthropicModelMode, parse_prompt_too_long};
+use anthropic::AnthropicModelMode;
 use anyhow::{Context as _, Result, anyhow};
 use client::{Client, ModelRequestUsage, UserStore, zed_urls};
 use futures::{
@@ -8,25 +8,21 @@ use google_ai::GoogleModelMode;
 use gpui::{
     AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
 };
+use http_client::http::{HeaderMap, HeaderValue};
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
-    LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolChoice,
-    LanguageModelToolSchemaFormat, ModelRequestLimitReachedError, RateLimiter,
-    ZED_CLOUD_PROVIDER_ID,
-};
-use language_model::{
-    LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken, PaymentRequiredError,
-    RefreshLlmTokenListener,
+    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
+    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken,
+    ModelRequestLimitReachedError, PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
 };
 use proto::Plan;
 use release_channel::AppVersion;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize, de::DeserializeOwned};
 use settings::SettingsStore;
-use smol::Timer;
 use smol::io::{AsyncReadExt, BufReader};
 use std::pin::Pin;
 use std::str::FromStr as _;
@@ -47,7 +43,8 @@ use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, i
 use crate::provider::google::{GoogleEventMapper, into_google};
 use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
 
-pub const PROVIDER_NAME: &str = "Zed";
+const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct ZedDotDevSettings {
@@ -351,11 +348,11 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
 
 impl LanguageModelProvider for CloudLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -536,8 +533,6 @@ struct PerformLlmCompletionResponse {
 }
 
 impl CloudLanguageModel {
-    const MAX_RETRIES: usize = 3;
-
     async fn perform_llm_completion(
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
@@ -547,8 +542,7 @@ impl CloudLanguageModel {
         let http_client = &client.http_client();
 
         let mut token = llm_api_token.acquire(&client).await?;
-        let mut retries_remaining = Self::MAX_RETRIES;
-        let mut retry_delay = Duration::from_secs(1);
+        let mut refreshed_token = false;
 
         loop {
             let request_builder = http_client::Request::builder()
@@ -590,14 +584,20 @@ impl CloudLanguageModel {
                     includes_status_messages,
                     tool_use_limit_reached,
                 });
-            } else if response
-                .headers()
-                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
-                .is_some()
+            }
+
+            if !refreshed_token
+                && response
+                    .headers()
+                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                    .is_some()
             {
-                retries_remaining -= 1;
                 token = llm_api_token.refresh(&client).await?;
-            } else if status == StatusCode::FORBIDDEN
+                refreshed_token = true;
+                continue;
+            }
+
+            if status == StatusCode::FORBIDDEN
                 && response
                     .headers()
                     .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
@@ -622,35 +622,18 @@ impl CloudLanguageModel {
                         return Err(anyhow!(ModelRequestLimitReachedError { plan }));
                     }
                 }
-
-                anyhow::bail!("Forbidden");
-            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
-                // If we encounter an error in the 500 range, retry after a delay.
-                // We've seen at least these in the wild from API providers:
-                // * 500 Internal Server Error
-                // * 502 Bad Gateway
-                // * 529 Service Overloaded
-
-                if retries_remaining == 0 {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    anyhow::bail!(
-                        "cloud language model completion failed after {} retries with status {status}: {body}",
-                        Self::MAX_RETRIES
-                    );
-                }
-
-                Timer::after(retry_delay).await;
-
-                retries_remaining -= 1;
-                retry_delay *= 2; // If it fails again, wait longer.
             } else if status == StatusCode::PAYMENT_REQUIRED {
                 return Err(anyhow!(PaymentRequiredError));
-            } else {
-                let mut body = String::new();
-                response.body_mut().read_to_string(&mut body).await?;
-                return Err(anyhow!(ApiError { status, body }));
             }
+
+            let mut body = String::new();
+            let headers = response.headers().clone();
+            response.body_mut().read_to_string(&mut body).await?;
+            return Err(anyhow!(ApiError {
+                status,
+                body,
+                headers
+            }));
         }
     }
 }
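
The rewritten loop keeps exactly one in-function retry: a single LLM-token refresh when the server flags an expired token. All other non-success responses fall through to a typed `ApiError` that now preserves the response headers, so backoff and `retry_after` handling can happen in the caller rather than inside a hidden exponential-backoff loop. Schematically (a control-flow sketch only; `send`, `token_expired`, `refresh`, and `read_error` are hypothetical stand-ins for the plumbing above):

loop {
    let response = send(&token).await?;
    if response.status().is_success() {
        return Ok(parse_success(response));
    }
    if !refreshed_token && token_expired(&response) {
        token = refresh().await?; // at most one refresh, then fall through
        refreshed_token = true;
        continue;
    }
    // FORBIDDEN and PAYMENT_REQUIRED map to their dedicated errors first;
    // everything else becomes ApiError { status, body, headers }.
    return Err(read_error(response).await);
}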
@@ -660,6 +643,19 @@ impl CloudLanguageModel {
 struct ApiError {
     status: StatusCode,
     body: String,
+    headers: HeaderMap<HeaderValue>,
+}
+
+impl From<ApiError> for LanguageModelCompletionError {
+    fn from(error: ApiError) -> Self {
+        let retry_after = None;
+        LanguageModelCompletionError::from_http_status(
+            PROVIDER_NAME,
+            error.status,
+            error.body,
+            retry_after,
+        )
+    }
 }
 
 impl LanguageModel for CloudLanguageModel {
@@ -672,11 +668,29 @@ impl LanguageModel for CloudLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
+    }
+
+    fn upstream_provider_id(&self) -> LanguageModelProviderId {
+        use zed_llm_client::LanguageModelProvider::*;
+        match self.model.provider {
+            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
+            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
+            Google => language_model::GOOGLE_PROVIDER_ID,
+        }
+    }
+
+    fn upstream_provider_name(&self) -> LanguageModelProviderName {
+        use zed_llm_client::LanguageModelProvider::*;
+        match self.model.provider {
+            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
+            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
+            Google => language_model::GOOGLE_PROVIDER_NAME,
+        }
     }
 
     fn supports_tools(&self) -> bool {
@@ -776,6 +790,7 @@ impl LanguageModel for CloudLanguageModel {
                         .body(serde_json::to_string(&request_body)?.into())?;
                     let mut response = http_client.send(request).await?;
                     let status = response.status();
+                    let headers = response.headers().clone();
                     let mut response_body = String::new();
                     response
                         .body_mut()
@@ -790,7 +805,8 @@ impl LanguageModel for CloudLanguageModel {
                     } else {
                         Err(anyhow!(ApiError {
                             status,
-                            body: response_body
+                            body: response_body,
+                            headers
                         }))
                     }
                 }
@@ -855,18 +871,7 @@ impl LanguageModel for CloudLanguageModel {
                     )
                     .await
                     .map_err(|err| match err.downcast::<ApiError>() {
-                        Ok(api_err) => {
-                            if api_err.status == StatusCode::BAD_REQUEST {
-                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
-                                    return anyhow!(
-                                        LanguageModelKnownError::ContextWindowLimitExceeded {
-                                            tokens
-                                        }
-                                    );
-                                }
-                            }
-                            anyhow!(api_err)
-                        }
+                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                         Err(err) => anyhow!(err),
                     })?;
 
@@ -995,7 +1000,7 @@ where
         .flat_map(move |event| {
             futures::stream::iter(match event {
                 Err(error) => {
-                    vec![Err(LanguageModelCompletionError::Other(error))]
+                    vec![Err(LanguageModelCompletionError::from(error))]
                 }
                 Ok(CloudCompletionEvent::Status(event)) => {
                     vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]

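The `upstream_provider_id`/`upstream_provider_name` overrides are what let error reporting distinguish an upstream outage ("Anthropic is overloaded") from a failure in Zed's own servers, even though `provider_id()` still reports zed.dev. A small illustration of the intended use, assuming a `&dyn LanguageModel`; the actual call sites live in the error-reporting paths of the agent crates:

fn describe_overload(model: &dyn LanguageModel) -> String {
    // For Zed-served models this names the upstream provider (Anthropic,
    // OpenAI, or Google); direct providers fall back to themselves via
    // the trait's default implementation.
    format!(
        "{} (via {}) is overloaded; retrying shortly",
        model.upstream_provider_name(),
        model.provider_name(),
    )
}
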
crates/language_models/src/provider/copilot_chat.rs πŸ”—

@@ -35,8 +35,9 @@ use super::anthropic::count_anthropic_tokens;
 use super::google::count_google_tokens;
 use super::open_ai::count_open_ai_tokens;
 
-const PROVIDER_ID: &str = "copilot_chat";
-const PROVIDER_NAME: &str = "GitHub Copilot Chat";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
+const PROVIDER_NAME: LanguageModelProviderName =
+    LanguageModelProviderName::new("GitHub Copilot Chat");
 
 pub struct CopilotChatLanguageModelProvider {
     state: Entity<State>,
@@ -102,11 +103,11 @@ impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
 
 impl LanguageModelProvider for CopilotChatLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -201,11 +202,11 @@ impl LanguageModel for CopilotChatLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {
@@ -391,24 +392,24 @@ pub fn map_to_language_model_completion_events(
                                             serde_json::Value::from_str(&tool_call.arguments)
                                         };
                                         match arguments {
-                                            Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
-                                                LanguageModelToolUse {
-                                                    id: tool_call.id.clone().into(),
-                                                    name: tool_call.name.as_str().into(),
-                                                    is_input_complete: true,
-                                                    input,
-                                                    raw_input: tool_call.arguments.clone(),
-                                                },
-                                            )),
-                                            Err(error) => {
-                                                Err(LanguageModelCompletionError::BadInputJson {
-                                                    id: tool_call.id.into(),
-                                                    tool_name: tool_call.name.as_str().into(),
-                                                    raw_input: tool_call.arguments.into(),
-                                                    json_parse_error: error.to_string(),
-                                                })
-                                            }
-                                        }
+                                            Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+                                                LanguageModelToolUse {
+                                                    id: tool_call.id.clone().into(),
+                                                    name: tool_call.name.as_str().into(),
+                                                    is_input_complete: true,
+                                                    input,
+                                                    raw_input: tool_call.arguments.clone(),
+                                                },
+                                            )),
+                                            Err(error) => Ok(
+                                                LanguageModelCompletionEvent::ToolUseJsonParseError {
+                                                    id: tool_call.id.into(),
+                                                    tool_name: tool_call.name.as_str().into(),
+                                                    raw_input: tool_call.arguments.into(),
+                                                    json_parse_error: error.to_string(),
+                                                },
+                                            ),
+                                        }
                                     },
                                 ));
 

crates/language_models/src/provider/deepseek.rs πŸ”—

@@ -28,8 +28,8 @@ use util::ResultExt;
 
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 
-const PROVIDER_ID: &str = "deepseek";
-const PROVIDER_NAME: &str = "DeepSeek";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek");
 const DEEPSEEK_API_KEY_VAR: &str = "DEEPSEEK_API_KEY";
 
 #[derive(Default)]
@@ -174,11 +174,11 @@ impl LanguageModelProviderState for DeepSeekLanguageModelProvider {
 
 impl LanguageModelProvider for DeepSeekLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -283,11 +283,11 @@ impl LanguageModel for DeepSeekLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {
@@ -466,7 +466,7 @@ impl DeepSeekEventMapper {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
             })
         })
     }
@@ -476,7 +476,7 @@ impl DeepSeekEventMapper {
         event: deepseek::StreamResponse,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
         let Some(choice) = event.choices.first() else {
-            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                 "Response contained no choices"
             )))];
         };
@@ -538,8 +538,8 @@ impl DeepSeekEventMapper {
                                 raw_input: tool_call.arguments.clone(),
                             },
                         )),
-                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
-                            id: tool_call.id.into(),
+                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                            id: tool_call.id.clone().into(),
                             tool_name: tool_call.name.as_str().into(),
                             raw_input: tool_call.arguments.into(),
                             json_parse_error: error.to_string(),

crates/language_models/src/provider/google.rs πŸ”—

@@ -37,8 +37,8 @@ use util::ResultExt;
 use crate::AllLanguageModelSettings;
 use crate::ui::InstructionListItem;
 
-const PROVIDER_ID: &str = "google";
-const PROVIDER_NAME: &str = "Google AI";
+const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct GoogleSettings {
@@ -207,11 +207,11 @@ impl LanguageModelProviderState for GoogleLanguageModelProvider {
 
 impl LanguageModelProvider for GoogleLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -334,11 +334,11 @@ impl LanguageModel for GoogleLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {
@@ -423,9 +423,7 @@ impl LanguageModel for GoogleLanguageModel {
         );
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
-            let response = request
-                .await
-                .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
+            let response = request.await.map_err(LanguageModelCompletionError::from)?;
             Ok(GoogleEventMapper::new().map_stream(response))
         });
         async move { Ok(future.await?.boxed()) }.boxed()
@@ -622,7 +620,7 @@ impl GoogleEventMapper {
                 futures::stream::iter(match event {
                     Some(Ok(event)) => self.map_event(event),
                     Some(Err(error)) => {
-                        vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))]
+                        vec![Err(LanguageModelCompletionError::from(error))]
                     }
                     None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                 })

crates/language_models/src/provider/lmstudio.rs πŸ”—

@@ -31,8 +31,8 @@ const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
 const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
 const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
 
-const PROVIDER_ID: &str = "lmstudio";
-const PROVIDER_NAME: &str = "LM Studio";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");
 
 #[derive(Default, Debug, Clone, PartialEq)]
 pub struct LmStudioSettings {
@@ -156,11 +156,11 @@ impl LanguageModelProviderState for LmStudioLanguageModelProvider {
 
 impl LanguageModelProvider for LmStudioLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -386,11 +386,11 @@ impl LanguageModel for LmStudioLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {
@@ -474,7 +474,7 @@ impl LmStudioEventMapper {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
             })
         })
     }
@@ -484,7 +484,7 @@ impl LmStudioEventMapper {
         event: lmstudio::ResponseStreamEvent,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
         let Some(choice) = event.choices.into_iter().next() else {
-            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                 "Response contained no choices"
             )))];
         };
@@ -553,7 +553,7 @@ impl LmStudioEventMapper {
                                 raw_input: tool_call.arguments,
                             },
                         )),
-                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                             id: tool_call.id.into(),
                             tool_name: tool_call.name.into(),
                             raw_input: tool_call.arguments.into(),

crates/language_models/src/provider/mistral.rs πŸ”—

@@ -2,8 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
 use collections::BTreeMap;
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
-use futures::stream::BoxStream;
-use futures::{FutureExt, StreamExt, future::BoxFuture};
+use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
 };
@@ -15,6 +14,7 @@ use language_model::{
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
     RateLimiter, Role, StopReason, TokenUsage,
 };
+use mistral::StreamResponse;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
@@ -29,8 +29,8 @@ use util::ResultExt;
 
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 
-const PROVIDER_ID: &str = "mistral";
-const PROVIDER_NAME: &str = "Mistral";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral");
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct MistralSettings {
@@ -171,11 +171,11 @@ impl LanguageModelProviderState for MistralLanguageModelProvider {
 
 impl LanguageModelProvider for MistralLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -298,11 +298,11 @@ impl LanguageModel for MistralLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {
@@ -579,13 +579,13 @@ impl MistralEventMapper {
 
     pub fn map_stream(
         mut self,
-        events: Pin<Box<dyn Send + futures::Stream<Item = Result<mistral::StreamResponse>>>>,
-    ) -> impl futures::Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+        events: Pin<Box<dyn Send + Stream<Item = Result<StreamResponse>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
     {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(LanguageModelCompletionError::from(error))],
             })
         })
     }
@@ -595,7 +595,7 @@ impl MistralEventMapper {
         event: mistral::StreamResponse,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
         let Some(choice) = event.choices.first() else {
-            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                 "Response contained no choices"
             )))];
         };
@@ -660,7 +660,7 @@ impl MistralEventMapper {
 
         for (_, tool_call) in self.tool_calls_by_index.drain() {
             if tool_call.id.is_empty() || tool_call.name.is_empty() {
-                results.push(Err(LanguageModelCompletionError::Other(anyhow!(
+                results.push(Err(LanguageModelCompletionError::from(anyhow!(
                     "Received incomplete tool call: missing id or name"
                 ))));
                 continue;
@@ -676,12 +676,14 @@ impl MistralEventMapper {
                         raw_input: tool_call.arguments,
                     },
                 ))),
-                Err(error) => results.push(Err(LanguageModelCompletionError::BadInputJson {
-                    id: tool_call.id.into(),
-                    tool_name: tool_call.name.into(),
-                    raw_input: tool_call.arguments.into(),
-                    json_parse_error: error.to_string(),
-                })),
+                Err(error) => {
+                    results.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                        id: tool_call.id.into(),
+                        tool_name: tool_call.name.into(),
+                        raw_input: tool_call.arguments.into(),
+                        json_parse_error: error.to_string(),
+                    }))
+                }
             }
         }
 

crates/language_models/src/provider/ollama.rs πŸ”—

@@ -30,8 +30,8 @@ const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
 const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
 const OLLAMA_SITE: &str = "https://ollama.com/";
 
-const PROVIDER_ID: &str = "ollama";
-const PROVIDER_NAME: &str = "Ollama";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
 
 #[derive(Default, Debug, Clone, PartialEq)]
 pub struct OllamaSettings {
@@ -181,11 +181,11 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
 
 impl LanguageModelProvider for OllamaLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -350,11 +350,11 @@ impl LanguageModel for OllamaLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {
@@ -453,7 +453,7 @@ fn map_to_language_model_completion_events(
             let delta = match response {
                 Ok(delta) => delta,
                 Err(e) => {
-                    let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
+                    let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
                     return Some((vec![event], state));
                 }
             };

crates/language_models/src/provider/open_ai.rs πŸ”—

@@ -31,8 +31,8 @@ use util::ResultExt;
 use crate::OpenAiSettingsContent;
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 
-const PROVIDER_ID: &str = "openai";
-const PROVIDER_NAME: &str = "OpenAI";
+const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
+const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct OpenAiSettings {
@@ -173,11 +173,11 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
 
 impl LanguageModelProvider for OpenAiLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -267,7 +267,11 @@ impl OpenAiLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
-            let api_key = api_key.context("Missing OpenAI API Key")?;
+            let Some(api_key) = api_key else {
+                return Err(LanguageModelCompletionError::NoApiKey {
+                    provider: PROVIDER_NAME,
+                });
+            };
             let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
             let response = request.await?;
             Ok(response)
@@ -287,11 +291,11 @@ impl LanguageModel for OpenAiLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {
@@ -525,7 +529,7 @@ impl OpenAiEventMapper {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
             })
         })
     }
@@ -588,10 +592,10 @@ impl OpenAiEventMapper {
                                 raw_input: tool_call.arguments.clone(),
                             },
                         )),
-                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                             id: tool_call.id.into(),
-                            tool_name: tool_call.name.as_str().into(),
-                            raw_input: tool_call.arguments.into(),
+                            tool_name: tool_call.name.into(),
+                            raw_input: tool_call.arguments.clone().into(),
                             json_parse_error: error.to_string(),
                         }),
                     }

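The missing-key path in this provider (mirrored in anthropic.rs above and vercel.rs below) now returns a typed `NoApiKey { provider }` instead of an `anyhow` context string, giving callers a stable case to match on rather than a message to parse. For example, a hypothetical UI-side check:

// Hedged sketch: match on the typed error instead of inspecting strings.
fn needs_credentials(err: &LanguageModelCompletionError) -> bool {
    matches!(err, LanguageModelCompletionError::NoApiKey { .. })
}
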
crates/language_models/src/provider/open_router.rs πŸ”—

@@ -29,8 +29,8 @@ use util::ResultExt;
 
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 
-const PROVIDER_ID: &str = "openrouter";
-const PROVIDER_NAME: &str = "OpenRouter";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct OpenRouterSettings {
@@ -244,11 +244,11 @@ impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
 
 impl LanguageModelProvider for OpenRouterLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -363,11 +363,11 @@ impl LanguageModel for OpenRouterLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {
@@ -607,7 +607,7 @@ impl OpenRouterEventMapper {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
             })
         })
     }
@@ -617,7 +617,7 @@ impl OpenRouterEventMapper {
         event: ResponseStreamEvent,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
         let Some(choice) = event.choices.first() else {
-            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+            return vec![Err(LanguageModelCompletionError::from(anyhow!(
                 "Response contained no choices"
             )))];
         };
@@ -683,10 +683,10 @@ impl OpenRouterEventMapper {
                                 raw_input: tool_call.arguments.clone(),
                             },
                         )),
-                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
-                            id: tool_call.id.into(),
+                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                            id: tool_call.id.clone().into(),
                             tool_name: tool_call.name.as_str().into(),
-                            raw_input: tool_call.arguments.into(),
+                            raw_input: tool_call.arguments.clone().into(),
                             json_parse_error: error.to_string(),
                         }),
                     }

crates/language_models/src/provider/vercel.rs πŸ”—

@@ -25,8 +25,8 @@ use util::ResultExt;
 
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 
-const PROVIDER_ID: &str = "vercel";
-const PROVIDER_NAME: &str = "Vercel";
+const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel");
+const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Vercel");
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct VercelSettings {
@@ -172,11 +172,11 @@ impl LanguageModelProviderState for VercelLanguageModelProvider {
 
 impl LanguageModelProvider for VercelLanguageModelProvider {
     fn id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn icon(&self) -> IconName {
@@ -269,7 +269,11 @@ impl VercelLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
-            let api_key = api_key.context("Missing Vercel API Key")?;
+            let Some(api_key) = api_key else {
+                return Err(LanguageModelCompletionError::NoApiKey {
+                    provider: PROVIDER_NAME,
+                });
+            };
             let request =
                 open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
             let response = request.await?;
@@ -290,11 +294,11 @@ impl LanguageModel for VercelLanguageModel {
     }
 
     fn provider_id(&self) -> LanguageModelProviderId {
-        LanguageModelProviderId(PROVIDER_ID.into())
+        PROVIDER_ID
     }
 
     fn provider_name(&self) -> LanguageModelProviderName {
-        LanguageModelProviderName(PROVIDER_NAME.into())
+        PROVIDER_NAME
     }
 
     fn supports_tools(&self) -> bool {

crates/web_search_providers/src/cloud.rs πŸ”—

@@ -7,10 +7,7 @@ use gpui::{App, AppContext, Context, Entity, Subscription, Task};
 use http_client::{HttpClient, Method};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use web_search::{WebSearchProvider, WebSearchProviderId};
-use zed_llm_client::{
-    CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
-    WebSearchBody, WebSearchResponse,
-};
+use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
 
 pub struct CloudWebSearchProvider {
     state: Entity<State>,
@@ -92,7 +89,6 @@ async fn perform_web_search(
             .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
             .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {token}"))
-            .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
             .body(serde_json::to_string(&body)?.into())?;
         let mut response = http_client
             .send(request)