language_model.rs

mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anyhow::{Result, anyhow};
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http::{HeaderMap, HeaderValue};
use icons::IconName;
use parking_lot::Mutex;
use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::fmt;
use std::ops::{Add, Sub};
use std::str::FromStr as _;
use std::sync::Arc;
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::{
    MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
};

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}

/// Configuration for caching language model messages.
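///
/// Illustrative sketch of constructing a configuration (the field values here
/// are invented for illustration, not recommended defaults):
///
/// ```ignore
/// let config = LanguageModelCacheConfiguration {
///     max_cache_anchors: 4,
///     should_speculate: true,
///     min_total_token: 1024,
/// };
/// ```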
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: usize,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum QueueState {
    Queued { position: usize },
    Started,
}

/// A completion event from a language model.
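///
/// Illustrative sketch of consuming events from a completion stream (stream
/// setup elided; see [`LanguageModel::stream_completion`]):
///
/// ```ignore
/// while let Some(event) = events.next().await {
///     match event? {
///         LanguageModelCompletionEvent::Text(chunk) => print!("{chunk}"),
///         LanguageModelCompletionEvent::UsageUpdate(usage) => {
///             eprintln!("tokens so far: {}", usage.total_tokens());
///         }
///         LanguageModelCompletionEvent::Stop(_) => break,
///         _ => {}
///     }
/// }
/// ```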
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    QueueUpdate(QueueState),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    ToolUse(LanguageModelToolUse),
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("received bad input JSON")]
    BadInputJson {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see <https://ai.google.dev/api/caching#Schema>
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
}

#[derive(Debug, Clone, Copy)]
pub struct RequestUsage {
    pub limit: UsageLimit,
    pub amount: i32,
}

impl RequestUsage {
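    /// Parses request usage from the response headers.
    ///
    /// Illustrative sketch (the header value formats are assumptions; see
    /// [`zed_llm_client::UsageLimit`] for the accepted syntax):
    ///
    /// ```ignore
    /// let mut headers = HeaderMap::new();
    /// headers.insert(
    ///     MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME,
    ///     HeaderValue::from_static("100"),
    /// );
    /// headers.insert(
    ///     MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
    ///     HeaderValue::from_static("42"),
    /// );
    /// let usage = RequestUsage::from_headers(&headers)?;
    /// assert_eq!(usage.amount, 42);
    /// ```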
    pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
        let limit = headers
            .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
            .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header"))?;
        let limit = UsageLimit::from_str(limit.to_str()?)?;

        let amount = headers
            .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
            .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header"))?;
        let amount = amount.to_str()?.parse::<i32>()?;

        Ok(Self { limit, amount })
    }
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u32,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u32,
}

impl TokenUsage {
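    /// Returns the sum of all input, output, and cache token counts.
    ///
    /// Illustrative sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// let usage = TokenUsage {
    ///     input_tokens: 10,
    ///     output_tokens: 5,
    ///     ..Default::default()
    /// };
    /// assert_eq!(usage.total_tokens(), 15);
    ///
    /// // `TokenUsage` also implements `Add` and `Sub` for accumulating
    /// // usage across requests.
    /// let more = usage + TokenUsage { input_tokens: 1, ..Default::default() };
    /// assert_eq!(more.total_tokens(), 16);
    /// ```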
    pub fn total_tokens(&self) -> u32 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

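/// A stream of plain-text completion chunks, plus per-request metadata.
///
/// Illustrative sketch of draining the stream and then reading the final
/// token usage (stream creation elided; see
/// [`LanguageModel::stream_completion_text`]):
///
/// ```ignore
/// let mut text_stream = model.stream_completion_text(request, &cx).await?;
/// while let Some(chunk) = text_stream.stream.next().await {
///     print!("{}", chunk?);
/// }
/// // Only complete once the stream has finished.
/// let usage = *text_stream.last_token_usage.lock();
/// ```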
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Contains the complete token usage only after the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

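/// An interface to a language model, exposing metadata, token counting, and
/// streaming completions.
///
/// Illustrative sketch of checking a request against the model's context
/// window before sending it (assumes a `model: Arc<dyn LanguageModel>` and an
/// `App` context; not compiled as a doc-test):
///
/// ```ignore
/// let token_count = model.count_tokens(request.clone(), cx).await?;
/// if token_count > model.max_token_count() {
///     return Err(LanguageModelKnownError::ContextWindowLimitExceeded {
///         tokens: token_count,
///     }
///     .into());
/// }
/// ```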
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Returns the availability of this language model.
    fn availability(&self) -> LanguageModelAvailability {
        LanguageModelAvailability::Public
    }

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Returns whether this model supports "max mode".
    fn supports_max_mode(&self) -> bool {
        if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID {
            return false;
        }

        const MAX_MODE_CAPABLE_MODELS: &[CloudModel] = &[
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ];

        for model in MAX_MODE_CAPABLE_MODELS {
            if self.id().0 == model.id() {
                return true;
            }
        }

        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> usize;
    fn max_output_tokens(&self) -> Option<u32> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    >;

    fn stream_completion_with_usage(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<(
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            Option<RequestUsage>,
        )>,
    > {
        self.stream_completion(request, cx)
            .map(|result| result.map(|stream| (stream, None)))
            .boxed()
    }

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
        self.stream_completion_text_with_usage(request, cx)
            .map(|result| result.map(|(stream, _usage)| stream))
            .boxed()
    }

    fn stream_completion_text_with_usage(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<(LanguageModelTextStream, Option<RequestUsage>)>> {
        let future = self.stream_completion_with_usage(request, cx);

        async move {
            let (events, usage) = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event to capture the message id or the first
            // text chunk before handing the rest of the stream to the caller.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::QueueUpdate(_)) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok((
                LanguageModelTextStream {
                    message_id,
                    stream,
                    last_token_usage,
                },
                usage,
            ))
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}

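/// A tool that can be invoked by a language model. The input schema is
/// derived via [`JsonSchema`], and model-produced input is decoded via
/// [`DeserializeOwned`].
///
/// Illustrative sketch of a tool definition (the `SearchTool` type is
/// invented for illustration):
///
/// ```ignore
/// #[derive(Deserialize, JsonSchema)]
/// struct SearchTool {
///     /// The query to search the project for.
///     query: String,
/// }
///
/// impl LanguageModelTool for SearchTool {
///     fn name() -> String {
///         "search".into()
///     }
///
///     fn description() -> String {
///         "Searches the project for the given query.".into()
///     }
/// }
/// ```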
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

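/// A provider of one or more language models, responsible for authentication,
/// model discovery, and configuration UI.
///
/// Illustrative sketch of selecting a model (assumes a
/// `provider: &dyn LanguageModelProvider` and an `App` context; not compiled
/// as a doc-test):
///
/// ```ignore
/// if provider.is_authenticated(cx) {
///     let model = provider
///         .default_model(cx)
///         .or_else(|| provider.provided_models(cx).into_iter().next());
/// } else {
///     // `authenticate` returns a `Task` that must be awaited (or detached)
///     // before the provider's models can be used.
///     let _task = provider.authenticate(cx);
/// }
/// ```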
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    PromptEditorPopup,
    Configuration,
}

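/// Gives access to a provider's observable state so that UI can re-render
/// when the provider changes (for example, after authentication completes).
///
/// Illustrative sketch of subscribing from within an entity (`MyView` is
/// invented for illustration):
///
/// ```ignore
/// let subscription = provider.subscribe(cx, |this: &mut MyView, cx| {
///     // Re-render whenever the provider's state changes.
///     cx.notify();
/// });
/// ```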
pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}