language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::Result;
use client::Client;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http;
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::ops::{Add, Sub};
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::CompletionRequestStatus;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

/// If we get a rate limit error that doesn't tell us when we can retry,
/// default to waiting this long before retrying.
const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// Configuration for caching language model messages.
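///
/// A hedged construction sketch; the values below are illustrative, not
/// canonical limits from any particular provider:
///
/// ```ignore
/// let cache_config = LanguageModelCacheConfiguration {
///     max_cache_anchors: 4,
///     should_speculate: true,
///     min_total_token: 10_000,
/// };
/// ```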
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: u64,
}

/// A completion event from a language model.
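///
/// A minimal sketch of how a caller might consume these events from a
/// completion stream; `events` and `output` are hypothetical names, not part
/// of this crate:
///
/// ```ignore
/// while let Some(event) = events.next().await {
///     match event? {
///         LanguageModelCompletionEvent::Text(chunk) => output.push_str(&chunk),
///         LanguageModelCompletionEvent::Stop(_reason) => break,
///         _ => {}
///     }
/// }
/// ```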
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

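/// An error that occurred while streaming a completion from a language model
/// provider's API.
///
/// A hedged sketch of how a caller might honor the rate-limit hint; the
/// `wait` helper is hypothetical:
///
/// ```ignore
/// if let Err(LanguageModelCompletionError::RateLimitExceeded { retry_after }) = result {
///     wait(retry_after).await;
///     // ...then retry the request
/// }
/// ```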
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("rate limit exceeded, retry after {retry_after:?}")]
    RateLimitExceeded { retry_after: Duration },
    #[error("received bad input JSON")]
    BadInputJson {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    #[error("language model provider's API is overloaded")]
    Overloaded,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
    #[error("invalid request format to language model provider's API")]
    BadRequestFormat,
    #[error("authentication error with language model provider's API")]
    AuthenticationError,
    #[error("permission error with language model provider's API")]
    PermissionError,
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound,
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("internal server error in language model provider's API")]
    ApiInternalServerError,
    #[error("I/O error reading response from language model provider's API: {0:?}")]
    ApiReadResponseError(io::Error),
    #[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
    HttpResponseError { status: u16, body: String },
    #[error("error serializing request to language model provider API: {0}")]
    SerializeRequest(serde_json::Error),
    #[error("error building request body to language model provider API: {0}")]
    BuildRequestBody(http::Error),
    #[error("error sending HTTP request to language model provider API: {0}")]
    HttpSend(anyhow::Error),
    #[error("error deserializing language model provider API response: {0}")]
    DeserializeResponse(serde_json::Error),
    #[error("unexpected language model provider API response format: {0}")]
    UnknownResponseFormat(String),
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
            AnthropicError::HttpSend(error) => Self::HttpSend(error),
            AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
            AnthropicError::HttpResponseError { status, body } => {
                Self::HttpResponseError { status, body }
            }
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
            AnthropicError::ApiError(api_error) => api_error.into(),
            AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;

        match error.code() {
            Some(code) => match code {
                InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
                AuthenticationError => LanguageModelCompletionError::AuthenticationError,
                PermissionError => LanguageModelCompletionError::PermissionError,
                NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
                RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
                    retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
                },
                ApiError => LanguageModelCompletionError::ApiInternalServerError,
                OverloadedError => LanguageModelCompletionError::Overloaded,
            },
            None => LanguageModelCompletionError::Other(error.into()),
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

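/// Token usage counts reported by a provider, broken down into input, output,
/// and cache-related input tokens.
///
/// Because `Add` is implemented, usage from multiple requests can be summed.
/// A small illustrative sketch:
///
/// ```ignore
/// let a = TokenUsage { input_tokens: 10, output_tokens: 5, ..Default::default() };
/// let b = TokenUsage { input_tokens: 2, output_tokens: 1, ..Default::default() };
/// assert_eq!((a + b).total_tokens(), 18);
/// ```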
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Has the complete token usage after the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

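/// A handle to a specific model from a language model provider: it reports
/// its capabilities, counts tokens, and streams completions.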
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;

    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

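    /// Streams a completion as plain text, dropping tool-use, thinking, and
    /// status events and recording token usage as it arrives.
    ///
    /// A hedged usage sketch; `model` and `request` are hypothetical:
    ///
    /// ```ignore
    /// let mut text_stream = model.stream_completion_text(request, &cx).await?;
    /// while let Some(chunk) = text_stream.stream.next().await {
    ///     print!("{}", chunk?);
    /// }
    /// let usage = *text_stream.last_token_usage.lock();
    /// ```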
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id.clone());
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate(_)) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: u64 },
    #[error("Language model provider's API is currently overloaded")]
    Overloaded,
    #[error("Language model provider's API encountered an internal server error")]
    ApiInternalServerError,
    #[error("I/O error while reading response from language model provider's API: {0:?}")]
    ReadResponseError(io::Error),
    #[error("Error deserializing response from language model provider's API: {0:?}")]
    DeserializeResponse(serde_json::Error),
    #[error("Language model provider's API returned a response in an unknown format")]
    UnknownResponseFormat(String),
    #[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
    RateLimitExceeded { retry_after: Duration },
}

impl LanguageModelKnownError {
    /// Attempts to map an HTTP response status code to a known error type.
    /// Returns `None` if the status code doesn't map to a specific known error.
    pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
        match status {
            429 => Some(Self::RateLimitExceeded {
                retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
            }),
            503 => Some(Self::Overloaded),
            500..=599 => Some(Self::ApiInternalServerError),
            _ => None,
        }
    }
}

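/// A tool that a language model can invoke, with its input schema derived
/// from the implementing type via `JsonSchema`.
///
/// A minimal implementation sketch; the `SearchTool` type is hypothetical:
///
/// ```ignore
/// #[derive(serde::Deserialize, schemars::JsonSchema)]
/// struct SearchTool {
///     /// The query to search for.
///     query: String,
/// }
///
/// impl LanguageModelTool for SearchTool {
///     fn name() -> String {
///         "search".into()
///     }
///     fn description() -> String {
///         "Searches the project for a query".into()
///     }
/// }
/// ```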
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

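/// A source of language models (for example, a cloud API or a local runtime)
/// that can enumerate its models and manage authentication and configuration.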
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    PromptEditorPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}