mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::Result;
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http;
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::ops::{Add, Sub};
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::CompletionRequestStatus;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

/// If we get a rate limit error that doesn't tell us when we can retry,
/// default to waiting this long before retrying.
const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// Configuration for caching language model messages.
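///
/// A minimal construction sketch; the field values below are illustrative,
/// not recommended defaults:
///
/// ```ignore
/// let cache_config = LanguageModelCacheConfiguration {
///     max_cache_anchors: 4,
///     should_speculate: true,
///     min_total_token: 10_000,
/// };
/// ```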
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: u64,
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("rate limit exceeded, retry after {retry_after:?}")]
    RateLimitExceeded { retry_after: Duration },
    #[error("received bad input JSON")]
    BadInputJson {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    #[error("language model provider's API is overloaded")]
    Overloaded,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
    #[error("invalid request format to language model provider's API")]
    BadRequestFormat,
    #[error("authentication error with language model provider's API")]
    AuthenticationError,
    #[error("permission error with language model provider's API")]
    PermissionError,
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound,
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("internal server error in language model provider's API")]
    ApiInternalServerError,
    #[error("I/O error reading response from language model provider's API: {0:?}")]
    ApiReadResponseError(io::Error),
    #[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
    HttpResponseError { status: u16, body: String },
    #[error("error serializing request to language model provider API: {0}")]
    SerializeRequest(serde_json::Error),
    #[error("error building request body to language model provider API: {0}")]
    BuildRequestBody(http::Error),
    #[error("error sending HTTP request to language model provider API: {0}")]
    HttpSend(anyhow::Error),
    #[error("error deserializing language model provider API response: {0}")]
    DeserializeResponse(serde_json::Error),
    #[error("unexpected language model provider API response format: {0}")]
    UnknownResponseFormat(String),
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
            AnthropicError::HttpSend(error) => Self::HttpSend(error),
            AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
            AnthropicError::HttpResponseError { status, body } => {
                Self::HttpResponseError { status, body }
            }
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
            AnthropicError::ApiError(api_error) => api_error.into(),
            AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;

        match error.code() {
            Some(code) => match code {
                InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
                AuthenticationError => LanguageModelCompletionError::AuthenticationError,
                PermissionError => LanguageModelCompletionError::PermissionError,
                NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
                RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
                    retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
                },
                ApiError => LanguageModelCompletionError::ApiInternalServerError,
                OverloadedError => LanguageModelCompletionError::Overloaded,
            },
            None => LanguageModelCompletionError::Other(error.into()),
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>.
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
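    /// Sum of all token categories. A minimal illustration (the values are
    /// made up):
    ///
    /// ```ignore
    /// let usage = TokenUsage { input_tokens: 100, output_tokens: 50, ..Default::default() };
    /// assert_eq!(usage.total_tokens(), 150);
    /// ```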
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

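// Note: the subtraction below is unchecked; subtracting a larger `TokenUsage`
// from a smaller one panics in debug builds and wraps in release builds.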
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

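/// A stream of plain text chunks from a model, plus the total token usage
/// once the stream has finished.
///
/// A hedged consumption sketch; the `model`, `request`, and `cx` bindings are
/// assumed to exist in the caller's context:
///
/// ```ignore
/// let mut text_stream = model.stream_completion_text(request, &cx).await?;
/// while let Some(chunk) = text_stream.stream.next().await {
///     print!("{}", chunk?);
/// }
/// let usage = *text_stream.last_token_usage.lock();
/// ```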
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Holds the complete token usage after the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id.clone());
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: u64 },
    #[error("Language model provider's API is currently overloaded")]
    Overloaded,
    #[error("Language model provider's API encountered an internal server error")]
    ApiInternalServerError,
    #[error("I/O error while reading response from language model provider's API: {0:?}")]
    ReadResponseError(io::Error),
    #[error("Error deserializing response from language model provider's API: {0:?}")]
    DeserializeResponse(serde_json::Error),
    #[error("Language model provider's API returned a response in an unknown format")]
    UnknownResponseFormat(String),
    #[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
    RateLimitExceeded { retry_after: Duration },
}

impl LanguageModelKnownError {
    /// Attempts to map an HTTP response status code to a known error type.
    /// Returns `None` if the status code doesn't map to a specific known error.
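    ///
    /// A minimal illustration: a 429 maps to `RateLimitExceeded` with the
    /// default retry delay.
    ///
    /// ```ignore
    /// let err = LanguageModelKnownError::from_http_response(429, "");
    /// assert!(matches!(
    ///     err,
    ///     Some(LanguageModelKnownError::RateLimitExceeded { .. })
    /// ));
    /// ```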
    pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
        match status {
            429 => Some(Self::RateLimitExceeded {
                retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
            }),
            503 => Some(Self::Overloaded),
            500..=599 => Some(Self::ApiInternalServerError),
            _ => None,
        }
    }
}

pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}
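
// A hedged sketch of a type implementing `LanguageModelTool`; `WeatherQuery`
// and its field are illustrative, not part of this crate:
//
// ```ignore
// #[derive(serde::Deserialize, schemars::JsonSchema)]
// struct WeatherQuery {
//     city: String,
// }
//
// impl LanguageModelTool for WeatherQuery {
//     fn name() -> String {
//         "get_weather".into()
//     }
//
//     fn description() -> String {
//         "Look up the current weather for a city.".into()
//     }
// }
// ```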

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    PromptEditorPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}
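
// A hedged usage sketch for `subscribe`; it assumes `provider` implements
// `LanguageModelProviderState`, `cx` is a `gpui::Context<MyView>`, and
// `refresh_model_list` is a hypothetical handler on `MyView`:
//
// ```ignore
// let _subscription = provider.subscribe(cx, |this: &mut MyView, cx| {
//     this.refresh_model_list(cx);
// });
// ```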

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}