1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7
8#[cfg(any(test, feature = "test-support"))]
9pub mod fake_provider;
10
11use anyhow::{Result, anyhow};
12use client::Client;
13use futures::FutureExt;
14use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
15use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
16use http_client::http::{HeaderMap, HeaderValue};
17use icons::IconName;
18use parking_lot::Mutex;
19use proto::Plan;
20use schemars::JsonSchema;
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use std::fmt;
23use std::ops::{Add, Sub};
24use std::str::FromStr as _;
25use std::sync::Arc;
26use thiserror::Error;
27use util::serde::is_default;
28use zed_llm_client::{
29 MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME, MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
30};
31
32pub use crate::model::*;
33pub use crate::rate_limiter::*;
34pub use crate::registry::*;
35pub use crate::request::*;
36pub use crate::role::*;
37pub use crate::telemetry::*;
38
/// Identifier of the Zed-hosted language model provider.
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
40
41pub fn init(client: Arc<Client>, cx: &mut App) {
42 init_settings(cx);
43 RefreshLlmTokenListener::register(client.clone(), cx);
44}
45
/// Registers this crate's settings with the application; split out from
/// [`init`] so tests can set up settings without a client connection.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
49
/// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}
58
/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    /// Maximum number of cache anchors (breakpoints) allowed per request.
    pub max_cache_anchors: usize,
    /// Whether to speculatively populate the cache — NOTE(review): exact
    /// provider semantics are not visible here; confirm against providers.
    pub should_speculate: bool,
    /// Minimum total token count before caching is applied — presumably to
    /// skip caching small prompts; verify against provider implementations.
    pub min_total_token: usize,
}
66
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The model stopped generating, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's thinking output.
    Thinking {
        text: String,
        /// Provider-supplied opaque signature for the thinking block, if any.
        signature: Option<String>,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// Marks the beginning of a new assistant message.
    StartMessage {
        message_id: String,
    },
    /// Updated token usage statistics for the request.
    UsageUpdate(TokenUsage),
}
82
83#[derive(Error, Debug)]
84pub enum LanguageModelCompletionError {
85 #[error("received bad input JSON")]
86 BadInputJson {
87 id: LanguageModelToolUseId,
88 tool_name: Arc<str>,
89 raw_input: Arc<str>,
90 json_parse_error: String,
91 },
92 #[error(transparent)]
93 Other(#[from] anyhow::Error),
94}
95
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}
104
/// Why the model stopped emitting tokens.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
}
112
/// Model-request usage reported by the LLM service via response headers.
#[derive(Debug, Clone, Copy)]
pub struct RequestUsage {
    /// The usage limit in effect for the account.
    pub limit: UsageLimit,
    /// Amount of the limit consumed so far.
    pub amount: i32,
}
118
119impl RequestUsage {
120 pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
121 let limit = headers
122 .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
123 .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header"))?;
124 let limit = UsageLimit::from_str(limit.to_str()?)?;
125
126 let amount = headers
127 .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
128 .ok_or_else(|| anyhow!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header"))?;
129 let amount = amount.to_str()?.parse::<i32>()?;
130
131 Ok(Self { limit, amount })
132 }
133}
134
/// Token counts reported by a model for a single request.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Input (prompt) tokens; cache-related tokens are tracked separately below.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u32,
    /// Output (completion) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u32,
    /// Input tokens written to the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u32,
    /// Input tokens served from the prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u32,
}
146
147impl TokenUsage {
148 pub fn total_tokens(&self) -> u32 {
149 self.input_tokens
150 + self.output_tokens
151 + self.cache_read_input_tokens
152 + self.cache_creation_input_tokens
153 }
154}
155
156impl Add<TokenUsage> for TokenUsage {
157 type Output = Self;
158
159 fn add(self, other: Self) -> Self {
160 Self {
161 input_tokens: self.input_tokens + other.input_tokens,
162 output_tokens: self.output_tokens + other.output_tokens,
163 cache_creation_input_tokens: self.cache_creation_input_tokens
164 + other.cache_creation_input_tokens,
165 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
166 }
167 }
168}
169
170impl Sub<TokenUsage> for TokenUsage {
171 type Output = Self;
172
173 fn sub(self, other: Self) -> Self {
174 Self {
175 input_tokens: self.input_tokens - other.input_tokens,
176 output_tokens: self.output_tokens - other.output_tokens,
177 cache_creation_input_tokens: self.cache_creation_input_tokens
178 - other.cache_creation_input_tokens,
179 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
180 }
181 }
182}
183
/// Identifier of a tool-use request, assigned by the model/provider.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
186
187impl fmt::Display for LanguageModelToolUseId {
188 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189 write!(f, "{}", self.0)
190 }
191}
192
/// Allows constructing an id from anything convertible to `Arc<str>`
/// (e.g. `&str`, `String`).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
201
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Provider-assigned id for this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: Arc<str>,
    /// The raw input string as streamed from the model (may be partial JSON
    /// while streaming).
    pub raw_input: String,
    /// The parsed form of `raw_input`.
    pub input: serde_json::Value,
    /// Whether the model has finished streaming the input.
    pub is_input_complete: bool,
}
210
/// A text-only view of a completion stream.
pub struct LanguageModelTextStream {
    /// Id of the assistant message, when the provider reported one.
    pub message_id: Option<String>,
    /// Stream of text chunks (non-text events are filtered out).
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
217
218impl Default for LanguageModelTextStream {
219 fn default() -> Self {
220 Self {
221 message_id: None,
222 stream: Box::pin(futures::stream::empty()),
223 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
224 }
225 }
226}
227
228pub trait LanguageModel: Send + Sync {
229 fn id(&self) -> LanguageModelId;
230 fn name(&self) -> LanguageModelName;
231 fn provider_id(&self) -> LanguageModelProviderId;
232 fn provider_name(&self) -> LanguageModelProviderName;
233 fn telemetry_id(&self) -> String;
234
235 fn api_key(&self, _cx: &App) -> Option<String> {
236 None
237 }
238
239 /// Returns the availability of this language model.
240 fn availability(&self) -> LanguageModelAvailability {
241 LanguageModelAvailability::Public
242 }
243
244 /// Whether this model supports tools.
245 fn supports_tools(&self) -> bool;
246
247 /// Returns whether this model supports "max mode";
248 fn supports_max_mode(&self) -> bool {
249 if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID {
250 return false;
251 }
252
253 const MAX_MODE_CAPABLE_MODELS: &[CloudModel] = &[
254 CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
255 CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
256 ];
257
258 for model in MAX_MODE_CAPABLE_MODELS {
259 if self.id().0 == model.id() {
260 return true;
261 }
262 }
263
264 false
265 }
266
267 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
268 LanguageModelToolSchemaFormat::JsonSchema
269 }
270
271 fn max_token_count(&self) -> usize;
272 fn max_output_tokens(&self) -> Option<u32> {
273 None
274 }
275
276 fn count_tokens(
277 &self,
278 request: LanguageModelRequest,
279 cx: &App,
280 ) -> BoxFuture<'static, Result<usize>>;
281
282 fn stream_completion(
283 &self,
284 request: LanguageModelRequest,
285 cx: &AsyncApp,
286 ) -> BoxFuture<
287 'static,
288 Result<
289 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
290 >,
291 >;
292
293 fn stream_completion_with_usage(
294 &self,
295 request: LanguageModelRequest,
296 cx: &AsyncApp,
297 ) -> BoxFuture<
298 'static,
299 Result<(
300 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
301 Option<RequestUsage>,
302 )>,
303 > {
304 self.stream_completion(request, cx)
305 .map(|result| result.map(|stream| (stream, None)))
306 .boxed()
307 }
308
309 fn stream_completion_text(
310 &self,
311 request: LanguageModelRequest,
312 cx: &AsyncApp,
313 ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
314 self.stream_completion_text_with_usage(request, cx)
315 .map(|result| result.map(|(stream, _usage)| stream))
316 .boxed()
317 }
318
319 fn stream_completion_text_with_usage(
320 &self,
321 request: LanguageModelRequest,
322 cx: &AsyncApp,
323 ) -> BoxFuture<'static, Result<(LanguageModelTextStream, Option<RequestUsage>)>> {
324 let future = self.stream_completion_with_usage(request, cx);
325
326 async move {
327 let (events, usage) = future.await?;
328 let mut events = events.fuse();
329 let mut message_id = None;
330 let mut first_item_text = None;
331 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
332
333 if let Some(first_event) = events.next().await {
334 match first_event {
335 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
336 message_id = Some(id.clone());
337 }
338 Ok(LanguageModelCompletionEvent::Text(text)) => {
339 first_item_text = Some(text);
340 }
341 _ => (),
342 }
343 }
344
345 let stream = futures::stream::iter(first_item_text.map(Ok))
346 .chain(events.filter_map({
347 let last_token_usage = last_token_usage.clone();
348 move |result| {
349 let last_token_usage = last_token_usage.clone();
350 async move {
351 match result {
352 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
353 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
354 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
355 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
356 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
357 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
358 *last_token_usage.lock() = token_usage;
359 None
360 }
361 Err(err) => Some(Err(err)),
362 }
363 }
364 }
365 }))
366 .boxed();
367
368 Ok((
369 LanguageModelTextStream {
370 message_id,
371 stream,
372 last_token_usage,
373 },
374 usage,
375 ))
376 }
377 .boxed()
378 }
379
380 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
381 None
382 }
383
384 #[cfg(any(test, feature = "test-support"))]
385 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
386 unimplemented!()
387 }
388}
389
/// Well-known error conditions providers can surface in a structured way.
#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    /// The request exceeded the model's context window of `tokens` tokens.
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}
395
/// A tool invocable by a language model; the implementing type is the
/// deserialized form of the tool's input, described by its JSON schema.
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// The tool's name, as advertised to the model.
    fn name() -> String;
    /// A description of the tool, shown to the model.
    fn description() -> String;
}
400
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure, preserved with its original source.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
409
/// A source of language models (an API vendor or local runtime).
pub trait LanguageModelProvider: 'static {
    /// Stable identifier for this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Human-readable provider name.
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider in the UI.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The model used by default, if any.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A faster default model for lightweight tasks, if any.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// The subset of models recommended to users; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Optional hook to prepare a model ahead of use; no-op by default.
    /// Semantics are provider-specific.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
    /// Whether the user is currently authenticated with this provider.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate with this provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the configuration UI for this provider.
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    /// Whether the user still needs to accept the provider's terms of service.
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    /// Renders the accept-terms UI for the given context, if applicable.
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
438
/// Where the provider's terms-of-service prompt is being rendered.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    // FIXME(review): variant name has a typo ("Threadt") — should be
    // `ThreadEmptyState`, but renaming is a breaking change for callers.
    ThreadtEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// Rendered in the inline prompt editor popup.
    PromptEditorPopup,
    /// Rendered in the provider configuration view.
    Configuration,
}
448
449pub trait LanguageModelProviderState: 'static {
450 type ObservableEntity;
451
452 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
453
454 fn subscribe<T: 'static>(
455 &self,
456 cx: &mut gpui::Context<T>,
457 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
458 ) -> Option<gpui::Subscription> {
459 let entity = self.observable_entity()?;
460 Some(cx.observe(&entity, move |this, _, cx| {
461 callback(this, cx);
462 }))
463 }
464}
465
/// Identifier for a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Identifier for a language model provider (e.g. [`ZED_CLOUD_PROVIDER_ID`]).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable name of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
477
478impl fmt::Display for LanguageModelProviderId {
479 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
480 write!(f, "{}", self.0)
481 }
482}
483
484impl From<String> for LanguageModelId {
485 fn from(value: String) -> Self {
486 Self(SharedString::from(value))
487 }
488}
489
490impl From<String> for LanguageModelName {
491 fn from(value: String) -> Self {
492 Self(SharedString::from(value))
493 }
494}
495
496impl From<String> for LanguageModelProviderId {
497 fn from(value: String) -> Self {
498 Self(SharedString::from(value))
499 }
500}
501
502impl From<String> for LanguageModelProviderName {
503 fn from(value: String) -> Self {
504 Self(SharedString::from(value))
505 }
506}