language_model.rs

  1mod model;
  2mod rate_limiter;
  3mod registry;
  4mod request;
  5mod role;
  6mod telemetry;
  7
  8#[cfg(any(test, feature = "test-support"))]
  9pub mod fake_provider;
 10
 11use anthropic::{AnthropicError, parse_prompt_too_long};
 12use anyhow::Result;
 13use client::Client;
 14use futures::FutureExt;
 15use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 16use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
 17use http_client::http;
 18use icons::IconName;
 19use parking_lot::Mutex;
 20use schemars::JsonSchema;
 21use serde::{Deserialize, Serialize, de::DeserializeOwned};
 22use std::ops::{Add, Sub};
 23use std::sync::Arc;
 24use std::time::Duration;
 25use std::{fmt, io};
 26use thiserror::Error;
 27use util::serde::is_default;
 28use zed_llm_client::CompletionRequestStatus;
 29
 30pub use crate::model::*;
 31pub use crate::rate_limiter::*;
 32pub use crate::registry::*;
 33pub use crate::request::*;
 34pub use crate::role::*;
 35pub use crate::telemetry::*;
 36
 37pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
 38
 39/// If we get a rate limit error that doesn't tell us when we can retry,
 40/// default to waiting this long before retrying.
 41const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
 42
 43pub fn init(client: Arc<Client>, cx: &mut App) {
 44    init_settings(cx);
 45    RefreshLlmTokenListener::register(client.clone(), cx);
 46}
 47
 48pub fn init_settings(cx: &mut App) {
 49    registry::init(cx);
 50}
 51
/// Configuration for caching language model messages.
// NOTE(review): the field notes below are plain `//` comments on purpose —
// `///` doc comments would be picked up by the `JsonSchema` derive and change
// the generated schema descriptions.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    // Presumably the maximum number of cache anchors (cache breakpoints) to
    // place in one request — TODO confirm against the providers that use it.
    pub max_cache_anchors: usize,
    // Presumably whether to speculatively mark messages for caching — TODO confirm.
    pub should_speculate: bool,
    // Presumably the minimum total token count before caching is worthwhile —
    // TODO confirm.
    pub min_total_token: u64,
}
 59
/// A completion event from a language model.
///
/// Variant names are part of the serialized form (serde derives), so renaming
/// or reordering them affects persisted/transmitted data.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// A request status update (`zed_llm_client::CompletionRequestStatus`).
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of generated text.
    Text(String),
    /// Reasoning ("thinking") text, with an optional provider-supplied signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// A tool call requested by the model; its input may still be partial
    /// (see `LanguageModelToolUse::is_input_complete`).
    ToolUse(LanguageModelToolUse),
    /// Start of a message, carrying the message's ID.
    StartMessage {
        message_id: String,
    },
    /// Token usage for the request; `stream_completion_text` keeps only the
    /// most recent one it sees.
    UsageUpdate(TokenUsage),
}
 76
/// An error produced while requesting or streaming a completion from a
/// language model provider.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The provider rate-limited the request; retry after `retry_after`.
    #[error("rate limit exceeded, retry after {retry_after:?}")]
    RateLimitExceeded { retry_after: Duration },
    /// Tool-use input received from the provider could not be parsed as JSON.
    #[error("received bad input JSON")]
    BadInputJson {
        /// ID of the tool use whose input failed to parse.
        id: LanguageModelToolUseId,
        /// Name of the tool being invoked.
        tool_name: Arc<str>,
        /// The raw, unparsed input text.
        raw_input: Arc<str>,
        /// Message from the JSON parser.
        json_parse_error: String,
    },
    /// The provider reported that its API is overloaded.
    #[error("language model provider's API is overloaded")]
    Overloaded,
    /// Any other error; its `Display` is forwarded unchanged.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
    /// The provider rejected the request as malformed.
    #[error("invalid request format to language model provider's API")]
    BadRequestFormat,
    /// The provider rejected the request's credentials.
    #[error("authentication error with language model provider's API")]
    AuthenticationError,
    /// The credentials lack permission for the request.
    #[error("permission error with language model provider's API")]
    PermissionError,
    /// The requested API endpoint does not exist.
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound,
    /// The prompt exceeds the model's context window; `tokens` is the prompt
    /// size reported by the provider, when it could be parsed.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    /// The provider reported an internal server error.
    #[error("internal server error in language model provider's API")]
    ApiInternalServerError,
    /// Reading the response body failed.
    #[error("I/O error reading response from language model provider's API: {0:?}")]
    ApiReadResponseError(io::Error),
    /// The provider answered with an unexpected HTTP status.
    #[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
    HttpResponseError { status: u16, body: String },
    /// The request could not be serialized.
    #[error("error serializing request to language model provider API: {0}")]
    SerializeRequest(serde_json::Error),
    /// The HTTP request could not be built.
    #[error("error building request body to language model provider API: {0}")]
    BuildRequestBody(http::Error),
    /// Sending the HTTP request failed.
    #[error("error sending HTTP request to language model provider API: {0}")]
    HttpSend(anyhow::Error),
    /// The response could not be deserialized.
    #[error("error deserializing language model provider API response: {0}")]
    DeserializeResponse(serde_json::Error),
    /// The response had a shape this client does not understand.
    #[error("unexpected language model provider API response format: {0}")]
    UnknownResponseFormat(String),
}
119
120impl From<AnthropicError> for LanguageModelCompletionError {
121    fn from(error: AnthropicError) -> Self {
122        match error {
123            AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
124            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
125            AnthropicError::HttpSend(error) => Self::HttpSend(error),
126            AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
127            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
128            AnthropicError::HttpResponseError { status, body } => {
129                Self::HttpResponseError { status, body }
130            }
131            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
132            AnthropicError::ApiError(api_error) => api_error.into(),
133            AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
134        }
135    }
136}
137
138impl From<anthropic::ApiError> for LanguageModelCompletionError {
139    fn from(error: anthropic::ApiError) -> Self {
140        use anthropic::ApiErrorCode::*;
141
142        match error.code() {
143            Some(code) => match code {
144                InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
145                AuthenticationError => LanguageModelCompletionError::AuthenticationError,
146                PermissionError => LanguageModelCompletionError::PermissionError,
147                NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
148                RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
149                    tokens: parse_prompt_too_long(&error.message),
150                },
151                RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
152                    retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
153                },
154                ApiError => LanguageModelCompletionError::ApiInternalServerError,
155                OverloadedError => LanguageModelCompletionError::Overloaded,
156            },
157            None => LanguageModelCompletionError::Other(error.into()),
158        }
159    }
160}
161
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>.
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see
    /// <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}
170
/// Why a model stopped generating.
///
/// Serialized in snake_case (e.g. `"end_turn"`, `"max_tokens"`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn.
    EndTurn,
    /// Generation hit the output token limit.
    MaxTokens,
    /// The model stopped to request a tool call.
    ToolUse,
    /// The model refused to respond.
    Refusal,
}
179
/// Token counts reported for a completion request.
///
/// Every field defaults to `0` when missing on deserialization and is skipped
/// on serialization when it equals its default (via `util::serde::is_default`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Tokens in the (non-cached) input.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    /// Tokens generated by the model.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    /// Input tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    /// Input tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
191
192impl TokenUsage {
193    pub fn total_tokens(&self) -> u64 {
194        self.input_tokens
195            + self.output_tokens
196            + self.cache_read_input_tokens
197            + self.cache_creation_input_tokens
198    }
199}
200
201impl Add<TokenUsage> for TokenUsage {
202    type Output = Self;
203
204    fn add(self, other: Self) -> Self {
205        Self {
206            input_tokens: self.input_tokens + other.input_tokens,
207            output_tokens: self.output_tokens + other.output_tokens,
208            cache_creation_input_tokens: self.cache_creation_input_tokens
209                + other.cache_creation_input_tokens,
210            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
211        }
212    }
213}
214
215impl Sub<TokenUsage> for TokenUsage {
216    type Output = Self;
217
218    fn sub(self, other: Self) -> Self {
219        Self {
220            input_tokens: self.input_tokens - other.input_tokens,
221            output_tokens: self.output_tokens - other.output_tokens,
222            cache_creation_input_tokens: self.cache_creation_input_tokens
223                - other.cache_creation_input_tokens,
224            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
225        }
226    }
227}
228
/// Identifier of a single tool use within a completion.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
231
232impl fmt::Display for LanguageModelToolUseId {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        write!(f, "{}", self.0)
235    }
236}
237
238impl<T> From<T> for LanguageModelToolUseId
239where
240    T: Into<Arc<str>>,
241{
242    fn from(value: T) -> Self {
243        Self(value.into())
244    }
245}
246
/// A tool call requested by a language model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being called.
    pub name: Arc<str>,
    /// Presumably the unparsed input text as received from the provider —
    /// TODO confirm.
    pub raw_input: String,
    /// The tool input as a JSON value.
    pub input: serde_json::Value,
    /// Whether `input` is final; when `false` more input is presumably still
    /// streaming — TODO confirm with providers.
    pub is_input_complete: bool,
}
255
/// A text-only view of a completion, as produced by
/// `LanguageModel::stream_completion_text`.
pub struct LanguageModelTextStream {
    /// Message ID, when the completion's first event was `StartMessage`.
    pub message_id: Option<String>,
    /// Text chunks and errors from the completion.
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
262
263impl Default for LanguageModelTextStream {
264    fn default() -> Self {
265        Self {
266            message_id: None,
267            stream: Box::pin(futures::stream::empty()),
268            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
269        }
270    }
271}
272
273pub trait LanguageModel: Send + Sync {
274    fn id(&self) -> LanguageModelId;
275    fn name(&self) -> LanguageModelName;
276    fn provider_id(&self) -> LanguageModelProviderId;
277    fn provider_name(&self) -> LanguageModelProviderName;
278    fn telemetry_id(&self) -> String;
279
280    fn api_key(&self, _cx: &App) -> Option<String> {
281        None
282    }
283
284    /// Whether this model supports images
285    fn supports_images(&self) -> bool;
286
287    /// Whether this model supports tools.
288    fn supports_tools(&self) -> bool;
289
290    /// Whether this model supports choosing which tool to use.
291    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
292
293    /// Returns whether this model supports "burn mode";
294    fn supports_max_mode(&self) -> bool {
295        false
296    }
297
298    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
299        LanguageModelToolSchemaFormat::JsonSchema
300    }
301
302    fn max_token_count(&self) -> u64;
303    fn max_output_tokens(&self) -> Option<u64> {
304        None
305    }
306
307    fn count_tokens(
308        &self,
309        request: LanguageModelRequest,
310        cx: &App,
311    ) -> BoxFuture<'static, Result<u64>>;
312
313    fn stream_completion(
314        &self,
315        request: LanguageModelRequest,
316        cx: &AsyncApp,
317    ) -> BoxFuture<
318        'static,
319        Result<
320            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
321            LanguageModelCompletionError,
322        >,
323    >;
324
325    fn stream_completion_text(
326        &self,
327        request: LanguageModelRequest,
328        cx: &AsyncApp,
329    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
330        let future = self.stream_completion(request, cx);
331
332        async move {
333            let events = future.await?;
334            let mut events = events.fuse();
335            let mut message_id = None;
336            let mut first_item_text = None;
337            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
338
339            if let Some(first_event) = events.next().await {
340                match first_event {
341                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
342                        message_id = Some(id.clone());
343                    }
344                    Ok(LanguageModelCompletionEvent::Text(text)) => {
345                        first_item_text = Some(text);
346                    }
347                    _ => (),
348                }
349            }
350
351            let stream = futures::stream::iter(first_item_text.map(Ok))
352                .chain(events.filter_map({
353                    let last_token_usage = last_token_usage.clone();
354                    move |result| {
355                        let last_token_usage = last_token_usage.clone();
356                        async move {
357                            match result {
358                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
359                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
360                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
361                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
362                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
363                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
364                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
365                                    *last_token_usage.lock() = token_usage;
366                                    None
367                                }
368                                Err(err) => Some(Err(err)),
369                            }
370                        }
371                    }
372                }))
373                .boxed();
374
375            Ok(LanguageModelTextStream {
376                message_id,
377                stream,
378                last_token_usage,
379            })
380        }
381        .boxed()
382    }
383
384    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
385        None
386    }
387
388    #[cfg(any(test, feature = "test-support"))]
389    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
390        unimplemented!()
391    }
392}
393
/// Provider errors with a known meaning that callers may handle specifically.
#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    /// The request exceeded the context window; `tokens` is the reported size.
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: u64 },
    /// The provider's API is overloaded.
    #[error("Language model provider's API is currently overloaded")]
    Overloaded,
    /// The provider's API hit an internal server error.
    #[error("Language model provider's API encountered an internal server error")]
    ApiInternalServerError,
    /// Reading the response failed.
    #[error("I/O error while reading response from language model provider's API: {0:?}")]
    ReadResponseError(io::Error),
    /// The response could not be deserialized.
    #[error("Error deserializing response from language model provider's API: {0:?}")]
    DeserializeResponse(serde_json::Error),
    /// The response had an unrecognized format; the `String` is not included
    /// in the `Display` output.
    #[error("Language model provider's API returned a response in an unknown format")]
    UnknownResponseFormat(String),
    /// The request was rate-limited; retry after `retry_after`.
    #[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
    RateLimitExceeded { retry_after: Duration },
}
411
412impl LanguageModelKnownError {
413    /// Attempts to map an HTTP response status code to a known error type.
414    /// Returns None if the status code doesn't map to a specific known error.
415    pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
416        match status {
417            429 => Some(Self::RateLimitExceeded {
418                retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
419            }),
420            503 => Some(Self::Overloaded),
421            500..=599 => Some(Self::ApiInternalServerError),
422            _ => None,
423        }
424    }
425}
426
/// A tool whose input type is described by `JsonSchema` and deserialized from
/// the model's JSON input.
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// Name under which the tool is exposed to the model.
    fn name() -> String;
    /// Description of the tool given to the model.
    fn description() -> String;
}
431
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// No credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other failure; its `Display` is forwarded unchanged.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
440
/// A provider of language models: lists its models and manages authentication.
pub trait LanguageModelProvider: 'static {
    /// Identifier of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon for this provider. Defaults to `IconName::ZedAssistant`.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The model to use by default, if any.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// Presumably a faster/cheaper default model for lightweight tasks, if any
    /// — TODO confirm with implementors.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models offered by this provider.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Models this provider recommends. Defaults to none.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Whether the provider is currently authenticated.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate the provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// View used to configure this provider.
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    /// Whether the user must accept terms before using this provider.
    /// Defaults to `false`.
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    /// Renders the terms-acceptance UI for the given view location, if any.
    /// Defaults to `None`.
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    /// Clears this provider's stored credentials.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
468
/// Where terms-of-service UI is being rendered; passed to
/// `LanguageModelProvider::render_accept_terms`.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    // NOTE(review): "Threadt" is a typo; kept because renaming the variant
    // would break existing callers.
    ThreadtEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// Presumably the inline prompt editor popup — TODO confirm.
    PromptEditorPopup,
    /// Presumably the provider configuration view — TODO confirm.
    Configuration,
}
478
479pub trait LanguageModelProviderState: 'static {
480    type ObservableEntity;
481
482    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
483
484    fn subscribe<T: 'static>(
485        &self,
486        cx: &mut gpui::Context<T>,
487        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
488    ) -> Option<gpui::Subscription> {
489        let entity = self.observable_entity()?;
490        Some(cx.observe(&entity, move |this, _, cx| {
491            callback(this, cx);
492        }))
493    }
494}
495
/// Identifier of a language model (serializable).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Identifier of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Name of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
507
508impl fmt::Display for LanguageModelProviderId {
509    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
510        write!(f, "{}", self.0)
511    }
512}
513
514impl From<String> for LanguageModelId {
515    fn from(value: String) -> Self {
516        Self(SharedString::from(value))
517    }
518}
519
520impl From<String> for LanguageModelName {
521    fn from(value: String) -> Self {
522        Self(SharedString::from(value))
523    }
524}
525
526impl From<String> for LanguageModelProviderId {
527    fn from(value: String) -> Self {
528        Self(SharedString::from(value))
529    }
530}
531
532impl From<String> for LanguageModelProviderName {
533    fn from(value: String) -> Self {
534        Self(SharedString::from(value))
535    }
536}