1mod model;
  2mod rate_limiter;
  3mod registry;
  4mod request;
  5mod role;
  6mod telemetry;
  7
  8#[cfg(any(test, feature = "test-support"))]
  9pub mod fake_provider;
 10
 11use anthropic::{AnthropicError, parse_prompt_too_long};
 12use anyhow::{Result, anyhow};
 13use client::Client;
 14use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
 15use futures::FutureExt;
 16use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 17use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
 18use http_client::{StatusCode, http};
 19use icons::IconName;
 20use open_router::OpenRouterError;
 21use parking_lot::Mutex;
 22use serde::{Deserialize, Serialize};
 23pub use settings::LanguageModelCacheConfiguration;
 24use std::ops::{Add, Sub};
 25use std::str::FromStr;
 26use std::sync::Arc;
 27use std::time::Duration;
 28use std::{fmt, io};
 29use thiserror::Error;
 30use util::serde::is_default;
 31
 32pub use crate::model::*;
 33pub use crate::rate_limiter::*;
 34pub use crate::registry::*;
 35pub use crate::request::*;
 36pub use crate::role::*;
 37pub use crate::telemetry::*;
 38
/// Provider id of the Anthropic provider.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
/// Provider name of the Anthropic provider. Errors converted from `AnthropicError` and
/// `anthropic::ApiError` in this module are tagged with this name.
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

/// Provider id of the Google AI provider.
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
/// Provider name of the Google AI provider.
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

/// Provider id of the OpenAI provider.
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
/// Provider name of the OpenAI provider.
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

/// Provider id of the xAI provider.
pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
/// Provider name of the xAI provider.
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

/// Provider id of Zed's own hosted provider.
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
/// Provider name of Zed's own hosted provider. `from_cloud_failure` attributes
/// `http_<status>` failure codes to this name.
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

/// Provider id of the Ollama provider.
pub const OLLAMA_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
/// Provider name of the Ollama provider.
pub const OLLAMA_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Ollama");
 62
/// Initializes this crate.
///
/// First runs [`init_settings`], then registers `RefreshLlmTokenListener` with `client`.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
 67
/// Performs the settings-related part of initialization.
///
/// Currently this only delegates to `registry::init`.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
 71
/// A completion event from a language model.
///
/// These are the items yielded by [`LanguageModel::stream_completion`].
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// A request status update (`cloud_llm_client::CompletionRequestStatus`).
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped, with the reason it stopped.
    Stop(StopReason),
    /// A chunk of generated text.
    Text(String),
    /// A chunk of thinking text, with an optional signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Redacted thinking content. `data` is carried as-is (presumably opaque to
    /// the client — TODO confirm).
    RedactedThinking {
        data: String,
    },
    /// A tool call requested by the model.
    ToolUse(LanguageModelToolUse),
    /// A tool call whose `raw_input` could not be parsed as JSON. `json_parse_error`
    /// holds the parser's error text.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Start of a message with its id. `stream_completion_text` reports a leading
    /// one of these as `LanguageModelTextStream::message_id`.
    StartMessage {
        message_id: String,
    },
    /// A token usage report. `stream_completion_text` overwrites its
    /// `last_token_usage` with each one it sees.
    UsageUpdate(TokenUsage),
}
 97
/// Errors produced while requesting a completion from a language model provider.
///
/// Most variants carry the [`LanguageModelProviderName`] involved, which the `#[error]`
/// display messages interpolate.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt did not fit in the context window. `tokens` is set when a token count
    /// could be extracted from the provider's message (via `parse_prompt_too_long`).
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    /// No API key was available for `provider`.
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    /// HTTP 429 or a provider rate-limit error. `retry_after` is the suggested wait, when known.
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// HTTP 503/529 or a provider "overloaded" error. `retry_after` is the suggested wait,
    /// when known.
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    /// HTTP 500 or a provider `api_error` code.
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error from an upstream provider with its HTTP status; displays only `message`.
    /// NOTE(review): not constructed anywhere in this file.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// An HTTP error status with no more specific variant. The display uses `{message:?}`,
    /// so the message is shown Debug-quoted.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    /// HTTP 400 or a provider `invalid_request_error`.
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 401 or a provider authentication error. OpenRouter's payment-required error is
    /// also mapped here.
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 403 or a provider permission error.
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// HTTP 404 or a provider not-found error. Any response message is discarded.
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    /// An I/O error while reading the provider's response.
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    /// The request could not be serialized to JSON.
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    /// The HTTP request could not be built.
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    /// Sending the HTTP request failed.
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    /// The provider's response could not be deserialized from JSON.
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    /// Catch-all for errors without a dedicated variant. For example, `from_cloud_failure`
    /// produces this for unrecognized failure codes.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
185
186impl LanguageModelCompletionError {
187    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
188        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
189        let upstream_status = error_json
190            .get("upstream_status")
191            .and_then(|v| v.as_u64())
192            .and_then(|status| u16::try_from(status).ok())
193            .and_then(|status| StatusCode::from_u16(status).ok())?;
194        let inner_message = error_json
195            .get("message")
196            .and_then(|v| v.as_str())
197            .unwrap_or(message)
198            .to_string();
199        Some((upstream_status, inner_message))
200    }
201
202    pub fn from_cloud_failure(
203        upstream_provider: LanguageModelProviderName,
204        code: String,
205        message: String,
206        retry_after: Option<Duration>,
207    ) -> Self {
208        if let Some(tokens) = parse_prompt_too_long(&message) {
209            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
210            // to be reported. This is a temporary workaround to handle this in the case where the
211            // token limit has been exceeded.
212            Self::PromptTooLarge {
213                tokens: Some(tokens),
214            }
215        } else if code == "upstream_http_error" {
216            if let Some((upstream_status, inner_message)) =
217                Self::parse_upstream_error_json(&message)
218            {
219                return Self::from_http_status(
220                    upstream_provider,
221                    upstream_status,
222                    inner_message,
223                    retry_after,
224                );
225            }
226            anyhow!("completion request failed, code: {code}, message: {message}").into()
227        } else if let Some(status_code) = code
228            .strip_prefix("upstream_http_")
229            .and_then(|code| StatusCode::from_str(code).ok())
230        {
231            Self::from_http_status(upstream_provider, status_code, message, retry_after)
232        } else if let Some(status_code) = code
233            .strip_prefix("http_")
234            .and_then(|code| StatusCode::from_str(code).ok())
235        {
236            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
237        } else {
238            anyhow!("completion request failed, code: {code}, message: {message}").into()
239        }
240    }
241
242    pub fn from_http_status(
243        provider: LanguageModelProviderName,
244        status_code: StatusCode,
245        message: String,
246        retry_after: Option<Duration>,
247    ) -> Self {
248        match status_code {
249            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
250            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
251            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
252            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
253            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
254                tokens: parse_prompt_too_long(&message),
255            },
256            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
257                provider,
258                retry_after,
259            },
260            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
261            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
262                provider,
263                retry_after,
264            },
265            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
266                provider,
267                retry_after,
268            },
269            _ => Self::HttpResponseError {
270                provider,
271                status_code,
272                message,
273            },
274        }
275    }
276}
277
278impl From<AnthropicError> for LanguageModelCompletionError {
279    fn from(error: AnthropicError) -> Self {
280        let provider = ANTHROPIC_PROVIDER_NAME;
281        match error {
282            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
283            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
284            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
285            AnthropicError::DeserializeResponse(error) => {
286                Self::DeserializeResponse { provider, error }
287            }
288            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
289            AnthropicError::HttpResponseError {
290                status_code,
291                message,
292            } => Self::HttpResponseError {
293                provider,
294                status_code,
295                message,
296            },
297            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
298                provider,
299                retry_after: Some(retry_after),
300            },
301            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
302                provider,
303                retry_after,
304            },
305            AnthropicError::ApiError(api_error) => api_error.into(),
306        }
307    }
308}
309
310impl From<anthropic::ApiError> for LanguageModelCompletionError {
311    fn from(error: anthropic::ApiError) -> Self {
312        use anthropic::ApiErrorCode::*;
313        let provider = ANTHROPIC_PROVIDER_NAME;
314        match error.code() {
315            Some(code) => match code {
316                InvalidRequestError => Self::BadRequestFormat {
317                    provider,
318                    message: error.message,
319                },
320                AuthenticationError => Self::AuthenticationError {
321                    provider,
322                    message: error.message,
323                },
324                PermissionError => Self::PermissionError {
325                    provider,
326                    message: error.message,
327                },
328                NotFoundError => Self::ApiEndpointNotFound { provider },
329                RequestTooLarge => Self::PromptTooLarge {
330                    tokens: parse_prompt_too_long(&error.message),
331                },
332                RateLimitError => Self::RateLimitExceeded {
333                    provider,
334                    retry_after: None,
335                },
336                ApiError => Self::ApiInternalServerError {
337                    provider,
338                    message: error.message,
339                },
340                OverloadedError => Self::ServerOverloaded {
341                    provider,
342                    retry_after: None,
343                },
344            },
345            None => Self::Other(error.into()),
346        }
347    }
348}
349
350impl From<OpenRouterError> for LanguageModelCompletionError {
351    fn from(error: OpenRouterError) -> Self {
352        let provider = LanguageModelProviderName::new("OpenRouter");
353        match error {
354            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
355            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
356            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
357            OpenRouterError::DeserializeResponse(error) => {
358                Self::DeserializeResponse { provider, error }
359            }
360            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
361            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
362                provider,
363                retry_after: Some(retry_after),
364            },
365            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
366                provider,
367                retry_after,
368            },
369            OpenRouterError::ApiError(api_error) => api_error.into(),
370        }
371    }
372}
373
374impl From<open_router::ApiError> for LanguageModelCompletionError {
375    fn from(error: open_router::ApiError) -> Self {
376        use open_router::ApiErrorCode::*;
377        let provider = LanguageModelProviderName::new("OpenRouter");
378        match error.code {
379            InvalidRequestError => Self::BadRequestFormat {
380                provider,
381                message: error.message,
382            },
383            AuthenticationError => Self::AuthenticationError {
384                provider,
385                message: error.message,
386            },
387            PaymentRequiredError => Self::AuthenticationError {
388                provider,
389                message: format!("Payment required: {}", error.message),
390            },
391            PermissionError => Self::PermissionError {
392                provider,
393                message: error.message,
394            },
395            RequestTimedOut => Self::HttpResponseError {
396                provider,
397                status_code: StatusCode::REQUEST_TIMEOUT,
398                message: error.message,
399            },
400            RateLimitError => Self::RateLimitExceeded {
401                provider,
402                retry_after: None,
403            },
404            ApiError => Self::ApiInternalServerError {
405                provider,
406                message: error.message,
407            },
408            OverloadedError => Self::ServerOverloaded {
409                provider,
410                retry_after: None,
411            },
412        }
413    }
414}
415
/// Indicates the format used to define the input schema for a language model tool.
///
/// Models report theirs via `LanguageModel::tool_input_format`, which defaults to
/// [`LanguageModelToolSchemaFormat::JsonSchema`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see
    /// <https://ai.google.dev/api/caching#Schema>
    JsonSchemaSubset,
}
424
/// Why a completion stopped; carried by [`LanguageModelCompletionEvent::Stop`].
///
/// Serialized in `snake_case` (for example `end_turn`, `max_tokens`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn.
    EndTurn,
    /// The output token limit was reached.
    MaxTokens,
    /// The model stopped in order to use a tool.
    ToolUse,
    /// The model refused to continue.
    Refusal,
}
433
/// Token counts reported for a completion.
///
/// Missing fields deserialize as `0`. A field is omitted when serializing if
/// `util::serde::is_default` says so (presumably when it equals `0` — TODO confirm).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Input (prompt) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Output (generated) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens written to the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens read from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
445
446impl TokenUsage {
447    pub fn total_tokens(&self) -> u64 {
448        self.input_tokens
449            + self.output_tokens
450            + self.cache_read_input_tokens
451            + self.cache_creation_input_tokens
452    }
453}
454
455impl Add<TokenUsage> for TokenUsage {
456    type Output = Self;
457
458    fn add(self, other: Self) -> Self {
459        Self {
460            input_tokens: self.input_tokens + other.input_tokens,
461            output_tokens: self.output_tokens + other.output_tokens,
462            cache_creation_input_tokens: self.cache_creation_input_tokens
463                + other.cache_creation_input_tokens,
464            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
465        }
466    }
467}
468
469impl Sub<TokenUsage> for TokenUsage {
470    type Output = Self;
471
472    fn sub(self, other: Self) -> Self {
473        Self {
474            input_tokens: self.input_tokens - other.input_tokens,
475            output_tokens: self.output_tokens - other.output_tokens,
476            cache_creation_input_tokens: self.cache_creation_input_tokens
477                - other.cache_creation_input_tokens,
478            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
479        }
480    }
481}
482
/// Identifier of a tool use, stored as a shared `Arc<str>`.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
485
486impl fmt::Display for LanguageModelToolUseId {
487    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
488        write!(f, "{}", self.0)
489    }
490}
491
492impl<T> From<T> for LanguageModelToolUseId
493where
494    T: Into<Arc<str>>,
495{
496    fn from(value: T) -> Self {
497        Self(value.into())
498    }
499}
500
/// A tool call requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being called.
    pub name: Arc<str>,
    /// The tool input as the raw text received from the model.
    pub raw_input: String,
    /// The tool input as a JSON value.
    pub input: serde_json::Value,
    /// Whether the input has been fully received. Presumably `false` for partial,
    /// still-streaming tool calls — TODO confirm.
    pub is_input_complete: bool,
}
509
/// Text-only view of a completion, as returned by `LanguageModel::stream_completion_text`.
pub struct LanguageModelTextStream {
    /// Id from a leading `StartMessage` event, if the completion began with one.
    pub message_id: Option<String>,
    /// The generated text chunks, or errors from the underlying completion stream.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    /// Latest token usage reported by the stream; updated as `UsageUpdate` events arrive.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
516
517impl Default for LanguageModelTextStream {
518    fn default() -> Self {
519        Self {
520            message_id: None,
521            stream: Box::pin(futures::stream::empty()),
522            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
523        }
524    }
525}
526
527pub trait LanguageModel: Send + Sync {
528    fn id(&self) -> LanguageModelId;
529    fn name(&self) -> LanguageModelName;
530    fn provider_id(&self) -> LanguageModelProviderId;
531    fn provider_name(&self) -> LanguageModelProviderName;
532    fn upstream_provider_id(&self) -> LanguageModelProviderId {
533        self.provider_id()
534    }
535    fn upstream_provider_name(&self) -> LanguageModelProviderName {
536        self.provider_name()
537    }
538
539    fn telemetry_id(&self) -> String;
540
541    fn api_key(&self, _cx: &App) -> Option<String> {
542        None
543    }
544
545    /// Whether this model supports images
546    fn supports_images(&self) -> bool;
547
548    /// Whether this model supports tools.
549    fn supports_tools(&self) -> bool;
550
551    /// Whether this model supports choosing which tool to use.
552    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
553
554    /// Returns whether this model supports "burn mode";
555    fn supports_burn_mode(&self) -> bool {
556        false
557    }
558
559    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
560        LanguageModelToolSchemaFormat::JsonSchema
561    }
562
563    fn max_token_count(&self) -> u64;
564    /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
565    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
566        None
567    }
568    fn max_output_tokens(&self) -> Option<u64> {
569        None
570    }
571
572    fn count_tokens(
573        &self,
574        request: LanguageModelRequest,
575        cx: &App,
576    ) -> BoxFuture<'static, Result<u64>>;
577
578    fn stream_completion(
579        &self,
580        request: LanguageModelRequest,
581        cx: &AsyncApp,
582    ) -> BoxFuture<
583        'static,
584        Result<
585            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
586            LanguageModelCompletionError,
587        >,
588    >;
589
590    fn stream_completion_text(
591        &self,
592        request: LanguageModelRequest,
593        cx: &AsyncApp,
594    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
595        let future = self.stream_completion(request, cx);
596
597        async move {
598            let events = future.await?;
599            let mut events = events.fuse();
600            let mut message_id = None;
601            let mut first_item_text = None;
602            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
603
604            if let Some(first_event) = events.next().await {
605                match first_event {
606                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
607                        message_id = Some(id);
608                    }
609                    Ok(LanguageModelCompletionEvent::Text(text)) => {
610                        first_item_text = Some(text);
611                    }
612                    _ => (),
613                }
614            }
615
616            let stream = futures::stream::iter(first_item_text.map(Ok))
617                .chain(events.filter_map({
618                    let last_token_usage = last_token_usage.clone();
619                    move |result| {
620                        let last_token_usage = last_token_usage.clone();
621                        async move {
622                            match result {
623                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
624                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
625                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
626                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
627                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
628                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
629                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
630                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
631                                    ..
632                                }) => None,
633                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
634                                    *last_token_usage.lock() = token_usage;
635                                    None
636                                }
637                                Err(err) => Some(Err(err)),
638                            }
639                        }
640                    }
641                }))
642                .boxed();
643
644            Ok(LanguageModelTextStream {
645                message_id,
646                stream,
647                last_token_usage,
648            })
649        }
650        .boxed()
651    }
652
653    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
654        None
655    }
656
657    #[cfg(any(test, feature = "test-support"))]
658    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
659        unimplemented!()
660    }
661}
662
663pub trait LanguageModelExt: LanguageModel {
664    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
665        match mode {
666            CompletionMode::Normal => self.max_token_count(),
667            CompletionMode::Max => self
668                .max_token_count_in_burn_mode()
669                .unwrap_or_else(|| self.max_token_count()),
670        }
671    }
672}
673impl LanguageModelExt for dyn LanguageModel {}
674
/// An error that occurred when trying to authenticate the language model provider.
///
/// Returned by `LanguageModelProvider::authenticate`.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The connection to the provider was refused.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other failure.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
685
/// A provider that exposes [`LanguageModel`]s and manages its own authentication.
pub trait LanguageModelProvider: 'static {
    /// Identifier of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon for this provider. Defaults to `IconName::ZedAssistant`.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The model to use by default, if any.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The model to use when a fast model is wanted, if any.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models to present as recommended. Defaults to none.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether the provider is currently authenticated.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Starts authenticating; the task resolves to an [`AuthenticateError`] on failure.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the view used to configure this provider for `target_agent`.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Resets this provider's credentials (presumably clearing stored keys — TODO confirm).
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
708
/// The agent a provider's configuration view is being shown for.
///
/// Passed to `LanguageModelProvider::configuration_view`.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// Another agent, identified by name (presumably an external agent — TODO confirm).
    Other(SharedString),
}
715
/// Where a provider's terms-of-service view is being rendered.
///
/// NOTE(review): "Tos" is read here as "terms of service" from the name alone — confirm.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// Presumably the popup shown for text threads — TODO confirm.
    TextThreadPopup,
    /// Presumably the provider configuration UI — TODO confirm.
    Configuration,
}
725
726pub trait LanguageModelProviderState: 'static {
727    type ObservableEntity;
728
729    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
730
731    fn subscribe<T: 'static>(
732        &self,
733        cx: &mut gpui::Context<T>,
734        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
735    ) -> Option<gpui::Subscription> {
736        let entity = self.observable_entity()?;
737        Some(cx.observe(&entity, move |this, _, cx| {
738            callback(this, cx);
739        }))
740    }
741}
742
/// Identifier of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Identifier of a language model provider (e.g. `"anthropic"`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Name of a language model provider (e.g. `"Anthropic"`).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
754
impl LanguageModelProviderId {
    /// Creates a provider id from a static string.
    ///
    /// This is a `const fn`, so it can be used in `const` items such as `ANTHROPIC_PROVIDER_ID`.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    /// Creates a provider name from a static string.
    ///
    /// This is a `const fn`, so it can be used in `const` items such as
    /// `ANTHROPIC_PROVIDER_NAME`. The parameter is named `id` but holds the name text.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
766
767impl fmt::Display for LanguageModelProviderId {
768    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
769        write!(f, "{}", self.0)
770    }
771}
772
773impl fmt::Display for LanguageModelProviderName {
774    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
775        write!(f, "{}", self.0)
776    }
777}
778
779impl From<String> for LanguageModelId {
780    fn from(value: String) -> Self {
781        Self(SharedString::from(value))
782    }
783}
784
785impl From<String> for LanguageModelName {
786    fn from(value: String) -> Self {
787        Self(SharedString::from(value))
788    }
789}
790
791impl From<String> for LanguageModelProviderId {
792    fn from(value: String) -> Self {
793        Self(SharedString::from(value))
794    }
795}
796
797impl From<String> for LanguageModelProviderName {
798    fn from(value: String) -> Self {
799        Self(SharedString::from(value))
800    }
801}
802
803impl From<Arc<str>> for LanguageModelProviderId {
804    fn from(value: Arc<str>) -> Self {
805        Self(SharedString::from(value))
806    }
807}
808
809impl From<Arc<str>> for LanguageModelProviderName {
810    fn from(value: Arc<str>) -> Self {
811        Self(SharedString::from(value))
812    }
813}
814
#[cfg(test)]
mod tests {
    use super::*;

    /// Upstream message used by the connection-timeout fixtures.
    const CONNECTION_TIMEOUT_MESSAGE: &str = "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout";

    /// Builds a cloud failure attributed to the `anthropic` upstream provider.
    fn anthropic_failure(
        code: &str,
        message: &str,
        retry_after: Option<std::time::Duration>,
    ) -> LanguageModelCompletionError {
        LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            code.to_string(),
            message.to_string(),
            retry_after,
        )
    }

    /// Wraps `message` in the JSON body the cloud sends with `upstream_http_error`.
    fn upstream_error_body(message: &str, upstream_status: u16) -> String {
        serde_json::json!({
            "code": "upstream_http_error",
            "message": message,
            "upstream_status": upstream_status,
        })
        .to_string()
    }

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = anthropic_failure(
            "upstream_http_error",
            &upstream_error_body(CONNECTION_TIMEOUT_MESSAGE, 503),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = anthropic_failure(
            "upstream_http_error",
            &upstream_error_body("Internal server error", 500),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = anthropic_failure("upstream_http_503", "Service unavailable", None);

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = anthropic_failure(
            "upstream_http_error",
            &upstream_error_body(CONNECTION_TIMEOUT_MESSAGE, 503),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = anthropic_failure(
            "upstream_http_error",
            &upstream_error_body(CONNECTION_TIMEOUT_MESSAGE, 500),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, CONNECTION_TIMEOUT_MESSAGE);
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_zed_http_status() {
        // `http_<status>` codes come from Zed's own service, not the upstream provider.
        let retry_after = Some(std::time::Duration::from_secs(3));
        let error = anthropic_failure("http_429", "Too many requests", retry_after);

        match error {
            LanguageModelCompletionError::RateLimitExceeded {
                provider,
                retry_after: actual_retry_after,
            } => {
                assert_eq!(provider, ZED_CLOUD_PROVIDER_NAME);
                assert_eq!(actual_retry_after, retry_after);
            }
            _ => panic!("Expected RateLimitExceeded for http_429, got: {:?}", error),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_upstream_529() {
        let error = anthropic_failure("upstream_http_529", "Overloaded", None);

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded for upstream_http_529, got: {:?}", error),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_upstream_413() {
        let error = anthropic_failure("upstream_http_413", "Request too large", None);

        assert!(
            matches!(error, LanguageModelCompletionError::PromptTooLarge { .. }),
            "Expected PromptTooLarge for upstream_http_413, got: {error:?}"
        );
    }

    #[test]
    fn test_upstream_http_error_without_message_uses_whole_body() {
        let body = r#"{"upstream_status":400}"#;
        let error = anthropic_failure("upstream_http_error", body, None);

        match error {
            LanguageModelCompletionError::BadRequestFormat { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, body);
            }
            _ => panic!("Expected BadRequestFormat for upstream 400, got: {:?}", error),
        }
    }

    #[test]
    fn test_from_cloud_failure_falls_back_to_other() {
        for (code, message) in [
            ("upstream_http_error", "not json"),
            ("upstream_http_error", r#"{"message":"no status"}"#),
            ("mystery_code", "boom"),
        ] {
            let error = anthropic_failure(code, message, None);
            assert!(
                matches!(error, LanguageModelCompletionError::Other(_)),
                "Expected Other for code {code:?}, got: {error:?}"
            );
        }
    }
}