language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently an Anthropic PAYLOAD_TOO_LARGE response may cause
            // INTERNAL_SERVER_ERROR to be reported. This is a temporary workaround to handle
            // the case where the token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (if `supports_burn_mode` is `false`, this returns `None`).
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }
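
    // A hedged sketch of the fallback path: a failure code that is not
    // "upstream_http_error" and carries no "upstream_http_"/"http_" status
    // prefix should land in the catch-all `Other` variant. The code string
    // below is made up for the test.
    #[test]
    fn test_from_cloud_failure_with_unrecognized_code() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "some_unrecognized_code".to_string(),
            "Something went wrong".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::Other(_) => {}
            _ => panic!("Expected Other for an unrecognized code, got: {:?}", error),
        }
    }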

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
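
    // Exercises the non-standard 529 status code, which some providers
    // (notably Anthropic) use to signal overload; `from_http_status` maps it
    // to `ServerOverloaded` just like 503.
    #[test]
    fn test_from_http_status_529_maps_to_server_overloaded() {
        let status = StatusCode::from_u16(529).expect("529 is a valid status code");
        let error = LanguageModelCompletionError::from_http_status(
            ANTHROPIC_PROVIDER_NAME,
            status,
            "Overloaded".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider, ANTHROPIC_PROVIDER_NAME);
            }
            _ => panic!("Expected ServerOverloaded for status 529, got: {:?}", error),
        }
    }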

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }
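
    // A quick check of `StopReason`'s wire format: the enum is tagged
    // `rename_all = "snake_case"`, so variants serialize as snake_case strings.
    #[test]
    fn test_stop_reason_serializes_as_snake_case() {
        use serde_json::json;

        assert_eq!(
            serde_json::to_value(StopReason::EndTurn).unwrap(),
            json!("end_turn")
        );
        assert_eq!(
            serde_json::to_value(StopReason::MaxTokens).unwrap(),
            json!("max_tokens")
        );

        let deserialized: StopReason = serde_json::from_value(json!("tool_use")).unwrap();
        assert_eq!(deserialized, StopReason::ToolUse);
    }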

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
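
    // A small sanity check of `TokenUsage` arithmetic: `total_tokens` sums all
    // four counters, and the `Add`/`Sub` impls operate field-wise.
    #[test]
    fn test_token_usage_arithmetic() {
        let a = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 10,
            cache_read_input_tokens: 5,
        };
        let b = TokenUsage {
            input_tokens: 20,
            output_tokens: 10,
            cache_creation_input_tokens: 2,
            cache_read_input_tokens: 1,
        };

        assert_eq!(a.total_tokens(), 165);
        assert_eq!((a + b).total_tokens(), 198);
        assert_eq!((a + b) - b, a);
    }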
}