language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::CompletionRequestStatus;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// Configuration for caching language model messages.
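///
/// An illustrative value (the numbers below are made up, not recommended
/// defaults):
///
/// ```ignore
/// let config = LanguageModelCacheConfiguration {
///     max_cache_anchors: 4,
///     should_speculate: true,
///     min_total_token: 1024,
/// };
/// ```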
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: u64,
}

/// A completion event from a language model.
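///
/// A minimal sketch of draining a completion stream, assuming `events` is the
/// stream returned by [`LanguageModel::stream_completion`] and `output` and
/// `latest_usage` are locals defined by the caller:
///
/// ```ignore
/// while let Some(event) = events.next().await {
///     match event? {
///         LanguageModelCompletionEvent::Text(chunk) => output.push_str(&chunk),
///         LanguageModelCompletionEvent::UsageUpdate(usage) => latest_usage = usage,
///         LanguageModelCompletionEvent::Stop(_) => break,
///         _ => {}
///     }
/// }
/// ```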
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
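    /// Maps a failure reported by Zed's cloud proxy to a completion error.
    ///
    /// Codes prefixed with `upstream_http_` are attributed to the upstream
    /// provider, while codes prefixed with `http_` are attributed to Zed
    /// itself. A sketch (the code and message strings are assumed examples):
    ///
    /// ```ignore
    /// let error = LanguageModelCompletionError::from_cloud_failure(
    ///     ANTHROPIC_PROVIDER_NAME,
    ///     "upstream_http_429".to_string(),
    ///     "rate limited".to_string(),
    ///     Some(Duration::from_secs(30)),
    /// );
    /// assert!(matches!(
    ///     error,
    ///     LanguageModelCompletionError::RateLimitExceeded { .. }
    /// ));
    /// ```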
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

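    /// Maps an HTTP status code returned by a provider's API to a completion
    /// error. A sketch (the message text is an arbitrary example):
    ///
    /// ```ignore
    /// let error = LanguageModelCompletionError::from_http_status(
    ///     OPEN_AI_PROVIDER_NAME,
    ///     StatusCode::SERVICE_UNAVAILABLE,
    ///     "try again later".to_string(),
    ///     None,
    /// );
    /// assert!(matches!(
    ///     error,
    ///     LanguageModelCompletionError::ServerOverloaded { .. }
    /// ));
    /// ```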
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

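/// The reason a model stopped generating, serialized in `snake_case`.
///
/// For example (a sketch; assumes `serde_json` is available):
///
/// ```ignore
/// assert_eq!(serde_json::to_string(&StopReason::EndTurn)?, "\"end_turn\"");
/// ```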
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
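    /// Sums input, output, and cache token counts.
    ///
    /// A small illustration (the field values are made up):
    ///
    /// ```ignore
    /// let a = TokenUsage { input_tokens: 100, output_tokens: 20, ..Default::default() };
    /// let b = TokenUsage { input_tokens: 50, output_tokens: 5, ..Default::default() };
    /// assert_eq!((a + b).total_tokens(), 175);
    /// ```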
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

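/// An opaque identifier for a single tool use, as issued by the model.
///
/// Any string-like value converts into one (the id below is an assumed,
/// Anthropic-style example):
///
/// ```ignore
/// let id = LanguageModelToolUseId::from("toolu_01A");
/// assert_eq!(id.to_string(), "toolu_01A");
/// ```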
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

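    /// Streams only the text chunks of a completion, exposing the final token
    /// usage through [`LanguageModelTextStream::last_token_usage`].
    ///
    /// A usage sketch, assuming `model: Arc<dyn LanguageModel>` and a
    /// `request` built by the caller:
    ///
    /// ```ignore
    /// let mut text_stream = model.stream_completion_text(request, &cx).await?;
    /// while let Some(chunk) = text_stream.stream.next().await {
    ///     print!("{}", chunk?);
    /// }
    /// let usage = *text_stream.last_token_usage.lock();
    /// ```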
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

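/// A tool definition whose input schema is derived from a deserializable,
/// schema-bearing type.
///
/// A minimal sketch of an implementor (`SearchTool` is a hypothetical tool,
/// not part of this crate):
///
/// ```ignore
/// #[derive(Deserialize, JsonSchema)]
/// struct SearchTool {
///     query: String,
/// }
///
/// impl LanguageModelTool for SearchTool {
///     fn name() -> String {
///         "search".into()
///     }
///
///     fn description() -> String {
///         "Searches the project for a query string.".into()
///     }
/// }
/// ```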
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    PromptEditorPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

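    /// Re-runs `callback` whenever the provider's observable entity changes,
    /// returning `None` if the provider exposes no such entity.
    ///
    /// A wiring sketch (`this.refresh_models(cx)` is a hypothetical handler
    /// in the subscribing entity):
    ///
    /// ```ignore
    /// let _subscription = provider.subscribe(cx, |this, cx| {
    ///     this.refresh_models(cx);
    /// });
    /// ```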
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
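    /// Creates a provider ID from a static string; being `const`, it can back
    /// `const` items like the `ANTHROPIC_PROVIDER_ID` constant above. A sketch
    /// with a hypothetical provider:
    ///
    /// ```ignore
    /// const MY_PROVIDER_ID: LanguageModelProviderId =
    ///     LanguageModelProviderId::new("my-provider");
    /// ```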
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}