// language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anyhow::{Result, anyhow};
use client::Client;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http::{HeaderMap, HeaderValue};
use icons::IconName;
use parking_lot::Mutex;
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::fmt;
use std::ops::{Add, Sub};
use std::str::FromStr as _;
use std::sync::Arc;
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::{
    MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

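// A minimal sketch of wiring this crate up at startup, assuming an
// `Arc<Client>` is already in hand; `init_language_models` is a hypothetical
// caller, and provider registration itself happens elsewhere:
//
//     fn init_language_models(client: Arc<Client>, cx: &mut App) {
//         language_model::init(client, cx);
//     }
//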
/// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}

/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: usize,
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    ToolUse(LanguageModelToolUse),
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

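// A sketch of draining a stream of these events; `stream`, `handle_tool_use`,
// and `record_usage` are hypothetical stand-ins for the caller's own state:
//
//     let mut text = String::new();
//     while let Some(event) = stream.next().await {
//         match event? {
//             LanguageModelCompletionEvent::StartMessage { .. } => {}
//             LanguageModelCompletionEvent::Text(chunk) => text.push_str(&chunk),
//             LanguageModelCompletionEvent::Thinking { text: thought, .. } => {
//                 log::debug!("model is thinking: {thought}");
//             }
//             LanguageModelCompletionEvent::ToolUse(tool_use) => handle_tool_use(tool_use),
//             LanguageModelCompletionEvent::UsageUpdate(usage) => record_usage(usage),
//             LanguageModelCompletionEvent::Stop(reason) => {
//                 log::debug!("stopped: {reason:?}");
//                 break;
//             }
//         }
//     }
//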
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("received bad input JSON")]
    BadInputJson {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>.
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
}

#[derive(Debug, Clone, Copy)]
pub struct RequestUsage {
    pub limit: UsageLimit,
    pub amount: i32,
}

impl RequestUsage {
    pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
        let limit = headers
            .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
            .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header"))?;
        let limit = UsageLimit::from_str(limit.to_str()?)?;

        let amount = headers
            .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
            .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header"))?;
        let amount = amount.to_str()?.parse::<i32>()?;

        Ok(Self { limit, amount })
    }
}

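// Example: recovering usage from a response's headers. The header names come
// from `zed_llm_client`; the values shown are illustrative, and the limit's
// wire format is whatever `UsageLimit::from_str` accepts:
//
//     let mut headers = HeaderMap::new();
//     headers.insert(
//         MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME,
//         HeaderValue::from_static("50"),
//     );
//     headers.insert(
//         MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
//         HeaderValue::from_static("31"),
//     );
//     let usage = RequestUsage::from_headers(&headers)?;
//     assert_eq!(usage.amount, 31);
//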
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u32,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u32 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

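// The `Add`/`Sub` impls make per-request deltas straightforward. Note that
// `Sub` uses plain `-`, so subtracting a larger usage from a smaller one will
// panic in debug builds. A quick illustration:
//
//     let before = TokenUsage { input_tokens: 100, output_tokens: 40, ..Default::default() };
//     let after = TokenUsage { input_tokens: 160, output_tokens: 90, ..Default::default() };
//     let delta = after - before;
//     assert_eq!(delta.total_tokens(), 110);
//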
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

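// A sketch of consuming a `LanguageModelTextStream`, assuming `model` and
// `request` are in scope; `last_token_usage` is only complete once the
// stream has been fully drained:
//
//     let text_stream = model.stream_completion_text(request, &cx).await?;
//     let mut chunks = text_stream.stream;
//     let mut response = String::new();
//     while let Some(chunk) = chunks.next().await {
//         response.push_str(&chunk?);
//     }
//     let usage = *text_stream.last_token_usage.lock();
//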
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Returns the availability of this language model.
    fn availability(&self) -> LanguageModelAvailability {
        LanguageModelAvailability::Public
    }

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Returns whether this model supports "max mode".
    fn supports_max_mode(&self) -> bool {
        if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID {
            return false;
        }

        const MAX_MODE_CAPABLE_MODELS: &[CloudModel] = &[
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ];

        for model in MAX_MODE_CAPABLE_MODELS {
            if self.id().0 == model.id() {
                return true;
            }
        }

        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> usize;
    fn max_output_tokens(&self) -> Option<u32> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    >;

    fn stream_completion_with_usage(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<(
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            Option<RequestUsage>,
        )>,
    > {
        self.stream_completion(request, cx)
            .map(|result| result.map(|stream| (stream, None)))
            .boxed()
    }

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
        self.stream_completion_text_with_usage(request, cx)
            .map(|result| result.map(|(stream, _usage)| stream))
            .boxed()
    }

    fn stream_completion_text_with_usage(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<(LanguageModelTextStream, Option<RequestUsage>)>> {
        let future = self.stream_completion_with_usage(request, cx);

        async move {
            let (events, usage) = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event to capture the message ID or any leading text.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    // Surface an immediate failure instead of silently dropping it.
                    Err(err) => return Err(err.into()),
                    // Other first events (thinking, tool use, stop, usage) carry no text.
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok((
                LanguageModelTextStream {
                    message_id,
                    stream,
                    last_token_usage,
                },
                usage,
            ))
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

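// A condensed sketch of implementing this trait; `MyModel` and all of the
// choices here are hypothetical, and only the required methods are shown:
//
//     struct MyModel;
//
//     impl LanguageModel for MyModel {
//         fn id(&self) -> LanguageModelId {
//             LanguageModelId::from("my-model".to_string())
//         }
//         fn name(&self) -> LanguageModelName {
//             LanguageModelName::from("My Model".to_string())
//         }
//         fn provider_id(&self) -> LanguageModelProviderId {
//             LanguageModelProviderId::from("my-provider".to_string())
//         }
//         fn provider_name(&self) -> LanguageModelProviderName {
//             LanguageModelProviderName::from("My Provider".to_string())
//         }
//         fn telemetry_id(&self) -> String {
//             "my-provider/my-model".to_string()
//         }
//         fn supports_tools(&self) -> bool {
//             false
//         }
//         fn max_token_count(&self) -> usize {
//             128_000
//         }
//         fn count_tokens(
//             &self,
//             _request: LanguageModelRequest,
//             _cx: &App,
//         ) -> BoxFuture<'static, Result<usize>> {
//             // A real implementation would run the provider's tokenizer here.
//             async move { Ok(0) }.boxed()
//         }
//         fn stream_completion(
//             &self,
//             _request: LanguageModelRequest,
//             _cx: &AsyncApp,
//         ) -> BoxFuture<
//             'static,
//             Result<
//                 BoxStream<
//                     'static,
//                     Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
//                 >,
//             >,
//         > {
//             // A real implementation would call the provider's API and map its
//             // wire events into `LanguageModelCompletionEvent`s.
//             async move { Ok(futures::stream::empty().boxed()) }.boxed()
//         }
//     }
//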
#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}

pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

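// A sketch of a tool definition; `ReadFileTool` is hypothetical. The derived
// `JsonSchema` impl is what produces the input schema handed to the model:
//
//     #[derive(Deserialize, JsonSchema)]
//     struct ReadFileTool {
//         /// The path of the file to read.
//         path: String,
//     }
//
//     impl LanguageModelTool for ReadFileTool {
//         fn name() -> String {
//             "read_file".to_string()
//         }
//         fn description() -> String {
//             "Reads the contents of the file at the given path.".to_string()
//         }
//     }
//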
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

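// A sketch of the authentication dance a caller might drive, assuming a
// `provider: &dyn LanguageModelProvider` and a `&mut App` context (the spawn
// plumbing is abbreviated):
//
//     if !provider.is_authenticated(cx) {
//         let task = provider.authenticate(cx);
//         cx.spawn(async move |_cx| {
//             match task.await {
//                 Ok(()) => {}
//                 Err(AuthenticateError::CredentialsNotFound) => {
//                     // Prompt for credentials, e.g. via `configuration_view`.
//                 }
//                 Err(err) => log::error!("authentication failed: {err}"),
//             }
//         })
//         .detach();
//     }
//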
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    PromptEditorPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

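// Example: keeping a view in sync with a provider's observable state. `MyView`
// and its fields are hypothetical; note the returned `gpui::Subscription` must
// be stored, or the observation is dropped:
//
//     if let Some(subscription) = provider.subscribe(cx, |this: &mut MyView, cx| {
//         this.refresh_models(cx);
//         cx.notify();
//     }) {
//         self.subscriptions.push(subscription);
//     }
//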
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}
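
// The `From<String>` impls above make it cheap to build identifiers from
// configuration or wire data; the model ID shown is illustrative:
//
//     let provider_id = LanguageModelProviderId::from(ZED_CLOUD_PROVIDER_ID.to_string());
//     let model_id = LanguageModelId::from("claude-3-7-sonnet".to_string());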