language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anyhow::{Result, anyhow};
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http::{HeaderMap, HeaderValue};
use icons::IconName;
use parking_lot::Mutex;
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::fmt;
use std::ops::{Add, Sub};
use std::str::FromStr as _;
use std::sync::Arc;
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::{
    MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}

/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    /// The maximum number of cache anchors the model supports.
    pub max_cache_anchors: usize,
    /// Whether the cache should be populated speculatively.
    pub should_speculate: bool,
    /// The minimum total token count before caching is used.
    pub min_total_token: usize,
}

/// The status of an in-flight completion request.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum CompletionRequestStatus {
    /// The request is waiting in the provider's queue at the given position.
    Queued { position: usize },
    /// The provider has started processing the request.
    Started,
    /// The request stopped because the tool-use limit was reached.
    ToolUseLimitReached,
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The queue status of the request changed.
    QueueUpdate(CompletionRequestStatus),
    /// The model stopped generating, with the given reason.
    Stop(StopReason),
    /// The model emitted a chunk of response text.
    Text(String),
    /// The model emitted a chunk of its reasoning.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// The model requested a tool use.
    ToolUse(LanguageModelToolUse),
    /// A new message started, carrying the provider-assigned message ID.
    StartMessage {
        message_id: String,
    },
    /// Updated token usage for the request so far.
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("received bad input JSON")]
    BadInputJson {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>.
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
}

#[derive(Debug, Clone, Copy)]
pub struct RequestUsage {
    pub limit: UsageLimit,
    pub amount: i32,
}

impl RequestUsage {
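    /// Parses request usage from the response headers sent by the Zed LLM
    /// service, reading both the usage-limit and usage-amount headers.
    ///
    /// A minimal usage sketch (the `response` variable is illustrative):
    ///
    /// ```ignore
    /// let usage = RequestUsage::from_headers(response.headers())?;
    /// println!("used {} requests (limit: {:?})", usage.amount, usage.limit);
    /// ```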
    pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
        let limit = headers
            .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
            .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header"))?;
        let limit = UsageLimit::from_str(limit.to_str()?)?;

        let amount = headers
            .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
            .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header"))?;
        let amount = amount.to_str()?.parse::<i32>()?;

        Ok(Self { limit, amount })
    }
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u32,
}

impl TokenUsage {
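    /// Returns the sum of all four token counters.
    ///
    /// A small illustrative example (values are made up):
    ///
    /// ```ignore
    /// let usage = TokenUsage {
    ///     input_tokens: 100,
    ///     output_tokens: 50,
    ///     cache_creation_input_tokens: 10,
    ///     cache_read_input_tokens: 40,
    /// };
    /// assert_eq!(usage.total_tokens(), 200);
    /// ```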
    pub fn total_tokens(&self) -> u32 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

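// Element-wise arithmetic over the four token counters. Note that `Sub` uses
// plain `u32` subtraction, so subtracting a larger usage from a smaller one
// panics in debug builds and wraps in release builds.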
impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

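/// A tool call requested by the model. `raw_input` preserves the original
/// JSON text, while `input` is the parsed value; `is_input_complete` is
/// `false` while the input is still being streamed.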
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

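/// A language model exposed by some [`LanguageModelProvider`].
///
/// A hedged sketch of streaming a completion as plain text (`model`, `request`,
/// and `cx` are assumed to be in scope; error handling elided):
///
/// ```ignore
/// let mut text_stream = model.stream_completion_text(request, &cx).await?;
/// while let Some(chunk) = text_stream.stream.next().await {
///     print!("{}", chunk?);
/// }
/// let total = text_stream.last_token_usage.lock().total_tokens();
/// ```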
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Returns the availability of this language model.
    fn availability(&self) -> LanguageModelAvailability {
        LanguageModelAvailability::Public
    }

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Returns whether this model supports "max mode".
    fn supports_max_mode(&self) -> bool {
        if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID {
            return false;
        }

        const MAX_MODE_CAPABLE_MODELS: &[CloudModel] = &[
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ];

        for model in MAX_MODE_CAPABLE_MODELS {
            if self.id().0 == model.id() {
                return true;
            }
        }

        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> usize;

    fn max_output_tokens(&self) -> Option<u32> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    >;

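    /// Like [`Self::stream_completion`], but also surfaces [`RequestUsage`]
    /// when the provider reports it. The default implementation delegates to
    /// `stream_completion` and returns `None` for the usage.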
    fn stream_completion_with_usage(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<(
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            Option<RequestUsage>,
        )>,
    > {
        self.stream_completion(request, cx)
            .map(|result| result.map(|stream| (stream, None)))
            .boxed()
    }

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
        self.stream_completion_text_with_usage(request, cx)
            .map(|result| result.map(|(stream, _usage)| stream))
            .boxed()
    }

    fn stream_completion_text_with_usage(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<(LanguageModelTextStream, Option<RequestUsage>)>> {
        let future = self.stream_completion_with_usage(request, cx);

        async move {
            let (events, usage) = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event: a leading `StartMessage` supplies the
            // message ID, and a leading text chunk is carried over into the
            // stream built below.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id.clone());
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            // Forward only text chunks, recording token usage updates as a
            // side effect; all other events are dropped.
            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::QueueUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok((
                LanguageModelTextStream {
                    message_id,
                    stream,
                    last_token_usage,
                },
                usage,
            ))
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}

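/// A tool that a language model can be asked to invoke, described to the
/// model via the type's JSON schema.
///
/// A minimal sketch of an implementor (the struct and strings below are
/// illustrative, not part of this crate):
///
/// ```ignore
/// #[derive(serde::Deserialize, schemars::JsonSchema)]
/// struct SearchTool {
///     /// The string to search for.
///     query: String,
/// }
///
/// impl LanguageModelTool for SearchTool {
///     fn name() -> String {
///         "search".into()
///     }
///
///     fn description() -> String {
///         "Searches the project for the given query.".into()
///     }
/// }
/// ```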
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

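/// A provider of language models, e.g. a cloud service or a local runtime.
/// Implementations enumerate their models and handle authentication and
/// configuration UI.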
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    PromptEditorPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

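    /// Invokes `callback` whenever the provider's observable entity changes.
    /// Returns `None` if the provider exposes no observable state.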
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}