1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7
8#[cfg(any(test, feature = "test-support"))]
9pub mod fake_provider;
10
11use anyhow::{Context as _, Result};
12use client::Client;
13use futures::FutureExt;
14use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
15use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
16use http_client::http::{HeaderMap, HeaderValue};
17use icons::IconName;
18use parking_lot::Mutex;
19use proto::Plan;
20use schemars::JsonSchema;
21use serde::{Deserialize, Serialize, de::DeserializeOwned};
22use std::fmt;
23use std::ops::{Add, Sub};
24use std::str::FromStr as _;
25use std::sync::Arc;
26use thiserror::Error;
27use util::serde::is_default;
28use zed_llm_client::{
29 CompletionRequestStatus, MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME,
30 MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME, UsageLimit,
31};
32
33pub use crate::model::*;
34pub use crate::rate_limiter::*;
35pub use crate::registry::*;
36pub use crate::request::*;
37pub use crate::role::*;
38pub use crate::telemetry::*;
39
/// The provider id of the Zed-hosted ("zed.dev") language model provider.
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
41
42pub fn init(client: Arc<Client>, cx: &mut App) {
43 init_settings(cx);
44 RefreshLlmTokenListener::register(client.clone(), cx);
45}
46
/// Registers this crate's settings (delegates to the `registry` module).
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
50
/// The availability of a [`LanguageModel`].
///
/// Used to gate models behind a subscription plan; see [`Plan`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}
59
/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    /// Maximum number of cache anchors the model supports.
    pub max_cache_anchors: usize,
    /// Whether cache writes should be performed speculatively.
    pub should_speculate: bool,
    /// Minimum total token count before caching is worthwhile.
    // NOTE(review): field name is singular (`min_total_token`); presumably a
    // token *count* threshold — confirm against the providers that read it.
    pub min_total_token: usize,
}
67
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Status forwarded from the backing completion request.
    StatusUpdate(CompletionRequestStatus),
    /// The model stopped producing output for the given reason.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of the model's reasoning output, with an optional
    /// provider-supplied signature.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// A new message began; carries its provider-assigned id.
    StartMessage {
        message_id: String,
    },
    /// Updated token usage for the in-flight request.
    UsageUpdate(TokenUsage),
}
84
/// An error produced while streaming a completion.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The model emitted tool-call input that failed to parse as JSON.
    #[error("received bad input JSON")]
    BadInputJson {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        /// The raw, unparsed input that triggered the parse failure.
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// Any other error, forwarded transparently.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
97
/// Indicates the format used to define the input schema for a language model tool.
///
/// Providers that cannot consume full JSON Schema (e.g. Google AI) use the
/// subset variant.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see https://json-schema.org
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see https://ai.google.dev/api/caching#Schema
    JsonSchemaSubset,
}
106
/// Why the model stopped producing output.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Output was cut off by the token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
}
114
/// Model-request usage as reported by the LLM service's response headers
/// (see [`RequestUsage::from_headers`]).
#[derive(Debug, Clone, Copy)]
pub struct RequestUsage {
    /// The usage limit in effect for this account.
    pub limit: UsageLimit,
    /// The amount of the limit consumed so far.
    pub amount: i32,
}
120
121impl RequestUsage {
122 pub fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
123 let limit = headers
124 .get(MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME)
125 .with_context(|| {
126 format!("missing {MODEL_REQUESTS_USAGE_LIMIT_HEADER_NAME:?} header")
127 })?;
128 let limit = UsageLimit::from_str(limit.to_str()?)?;
129
130 let amount = headers
131 .get(MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME)
132 .with_context(|| {
133 format!("missing {MODEL_REQUESTS_USAGE_AMOUNT_HEADER_NAME:?} header")
134 })?;
135 let amount = amount.to_str()?.parse::<i32>()?;
136
137 Ok(Self { limit, amount })
138 }
139}
140
/// Token counts for a single request/response exchange.
///
/// All fields default to zero and are omitted from serialization when zero
/// (via `is_default`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    /// Tokens consumed by the request input.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u32,
    /// Tokens produced in the response.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u32,
    /// Input tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u32,
    /// Input tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u32,
}
152
153impl TokenUsage {
154 pub fn total_tokens(&self) -> u32 {
155 self.input_tokens
156 + self.output_tokens
157 + self.cache_read_input_tokens
158 + self.cache_creation_input_tokens
159 }
160}
161
162impl Add<TokenUsage> for TokenUsage {
163 type Output = Self;
164
165 fn add(self, other: Self) -> Self {
166 Self {
167 input_tokens: self.input_tokens + other.input_tokens,
168 output_tokens: self.output_tokens + other.output_tokens,
169 cache_creation_input_tokens: self.cache_creation_input_tokens
170 + other.cache_creation_input_tokens,
171 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
172 }
173 }
174}
175
176impl Sub<TokenUsage> for TokenUsage {
177 type Output = Self;
178
179 fn sub(self, other: Self) -> Self {
180 Self {
181 input_tokens: self.input_tokens - other.input_tokens,
182 output_tokens: self.output_tokens - other.output_tokens,
183 cache_creation_input_tokens: self.cache_creation_input_tokens
184 - other.cache_creation_input_tokens,
185 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
186 }
187 }
188}
189
/// Provider-assigned identifier for a single tool invocation.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
192
193impl fmt::Display for LanguageModelToolUseId {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 write!(f, "{}", self.0)
196 }
197}
198
/// Builds a tool-use id from anything convertible into `Arc<str>`
/// (e.g. `&str` or `String`).
impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}
207
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    /// The name of the tool the model wants to call.
    pub name: Arc<str>,
    /// The raw input text as produced by the model.
    pub raw_input: String,
    /// The input parsed into JSON.
    pub input: serde_json::Value,
    /// Whether the input has finished streaming.
    pub is_input_complete: bool,
}
216
/// A text-only view over a completion stream, produced by
/// [`LanguageModel::stream_completion_text`].
pub struct LanguageModelTextStream {
    /// Provider-assigned message id, when the stream opened with one.
    pub message_id: Option<String>,
    /// The stream of text chunks (non-text events are filtered out).
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
223
224impl Default for LanguageModelTextStream {
225 fn default() -> Self {
226 Self {
227 message_id: None,
228 stream: Box::pin(futures::stream::empty()),
229 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
230 }
231 }
232}
233
234pub trait LanguageModel: Send + Sync {
235 fn id(&self) -> LanguageModelId;
236 fn name(&self) -> LanguageModelName;
237 fn provider_id(&self) -> LanguageModelProviderId;
238 fn provider_name(&self) -> LanguageModelProviderName;
239 fn telemetry_id(&self) -> String;
240
241 fn api_key(&self, _cx: &App) -> Option<String> {
242 None
243 }
244
245 /// Returns the availability of this language model.
246 fn availability(&self) -> LanguageModelAvailability {
247 LanguageModelAvailability::Public
248 }
249
250 /// Whether this model supports images
251 fn supports_images(&self) -> bool;
252
253 /// Whether this model supports tools.
254 fn supports_tools(&self) -> bool;
255
256 /// Whether this model supports choosing which tool to use.
257 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
258
259 /// Returns whether this model supports "max mode";
260 fn supports_max_mode(&self) -> bool {
261 if self.provider_id().0 != ZED_CLOUD_PROVIDER_ID {
262 return false;
263 }
264
265 const MAX_MODE_CAPABLE_MODELS: &[CloudModel] = &[
266 CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
267 CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
268 ];
269
270 for model in MAX_MODE_CAPABLE_MODELS {
271 if self.id().0 == model.id() {
272 return true;
273 }
274 }
275
276 false
277 }
278
279 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
280 LanguageModelToolSchemaFormat::JsonSchema
281 }
282
283 fn max_token_count(&self) -> usize;
284 fn max_output_tokens(&self) -> Option<u32> {
285 None
286 }
287
288 fn count_tokens(
289 &self,
290 request: LanguageModelRequest,
291 cx: &App,
292 ) -> BoxFuture<'static, Result<usize>>;
293
294 fn stream_completion(
295 &self,
296 request: LanguageModelRequest,
297 cx: &AsyncApp,
298 ) -> BoxFuture<
299 'static,
300 Result<
301 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
302 >,
303 >;
304
305 fn stream_completion_text(
306 &self,
307 request: LanguageModelRequest,
308 cx: &AsyncApp,
309 ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
310 let future = self.stream_completion(request, cx);
311
312 async move {
313 let events = future.await?;
314 let mut events = events.fuse();
315 let mut message_id = None;
316 let mut first_item_text = None;
317 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
318
319 if let Some(first_event) = events.next().await {
320 match first_event {
321 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
322 message_id = Some(id.clone());
323 }
324 Ok(LanguageModelCompletionEvent::Text(text)) => {
325 first_item_text = Some(text);
326 }
327 _ => (),
328 }
329 }
330
331 let stream = futures::stream::iter(first_item_text.map(Ok))
332 .chain(events.filter_map({
333 let last_token_usage = last_token_usage.clone();
334 move |result| {
335 let last_token_usage = last_token_usage.clone();
336 async move {
337 match result {
338 Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
339 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
340 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
341 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
342 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
343 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
344 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
345 *last_token_usage.lock() = token_usage;
346 None
347 }
348 Err(err) => Some(Err(err)),
349 }
350 }
351 }
352 }))
353 .boxed();
354
355 Ok(LanguageModelTextStream {
356 message_id,
357 stream,
358 last_token_usage,
359 })
360 }
361 .boxed()
362 }
363
364 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
365 None
366 }
367
368 #[cfg(any(test, feature = "test-support"))]
369 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
370 unimplemented!()
371 }
372}
373
/// Well-known error conditions that callers may want to handle specially.
#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    /// The request exceeded the model's context window.
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}
379
/// A tool that can be described to and invoked by a language model.
///
/// The tool's input schema is derived via [`JsonSchema`], and the model's
/// arguments are decoded through [`DeserializeOwned`].
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}
384
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other authentication failure, forwarded transparently.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
393
/// A source of language models, such as an API vendor or local backend.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The model to use by default, if any.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// The provider's default "fast" model, if any.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently exposes.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// A curated subset of models to highlight; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    /// Optional hook to eagerly load a model; a no-op by default.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Attempts to authenticate with this provider.
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the view used to configure this provider.
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    /// Whether the user must accept terms of service before use; false by default.
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    /// Renders the accept-terms UI for the given context, if needed.
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
422
/// The UI context in which a provider's accept-terms prompt is shown
/// (see [`LanguageModelProvider::render_accept_terms`]).
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    // NOTE(review): `Threadt` looks like a typo for `Thread`, but renaming a
    // public variant is a breaking change — left as-is.
    ThreadtEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    /// Shown in the prompt editor popup.
    PromptEditorPopup,
    /// Shown in the provider configuration view.
    Configuration,
}
432
433pub trait LanguageModelProviderState: 'static {
434 type ObservableEntity;
435
436 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
437
438 fn subscribe<T: 'static>(
439 &self,
440 cx: &mut gpui::Context<T>,
441 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
442 ) -> Option<gpui::Subscription> {
443 let entity = self.observable_entity()?;
444 Some(cx.observe(&entity, move |this, _, cx| {
445 callback(this, cx);
446 }))
447 }
448}
449
/// Newtype identifier for a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Newtype for a language model's human-readable name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Newtype identifier for a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Newtype for a provider's human-readable name.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
461
462impl fmt::Display for LanguageModelProviderId {
463 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
464 write!(f, "{}", self.0)
465 }
466}
467
468impl From<String> for LanguageModelId {
469 fn from(value: String) -> Self {
470 Self(SharedString::from(value))
471 }
472}
473
474impl From<String> for LanguageModelName {
475 fn from(value: String) -> Self {
476 Self(SharedString::from(value))
477 }
478}
479
480impl From<String> for LanguageModelProviderId {
481 fn from(value: String) -> Self {
482 Self(SharedString::from(value))
483 }
484}
485
486impl From<String> for LanguageModelProviderName {
487 fn from(value: String) -> Self {
488 Self(SharedString::from(value))
489 }
490}