mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::Result;
use client::Client;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::http;
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::ops::{Add, Sub};
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::CompletionRequestStatus;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

/// If we get a rate limit error that doesn't tell us when we can retry,
/// default to waiting this long before retrying.
const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client.clone(), cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// Configuration for caching language model messages.
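///
/// A construction sketch; field semantics are inferred from the names, and the
/// values shown are illustrative, not recommended defaults:
/// ```ignore
/// let cache_config = LanguageModelCacheConfiguration {
///     max_cache_anchors: 4,     // at most four cache breakpoints per request
///     should_speculate: true,   // cache speculatively, before the reply arrives
///     min_total_token: 2048,    // skip caching for prompts smaller than this
/// };
/// ```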
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: u64,
}

/// A completion event from a language model.
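///
/// A consumption sketch (assumes `events` is the stream returned by
/// [`LanguageModel::stream_completion`] and that `StreamExt` is in scope):
/// ```ignore
/// let mut response = String::new();
/// let mut latest_usage = TokenUsage::default();
/// while let Some(event) = events.next().await {
///     match event? {
///         LanguageModelCompletionEvent::Text(chunk) => response.push_str(&chunk),
///         LanguageModelCompletionEvent::UsageUpdate(usage) => latest_usage = usage,
///         LanguageModelCompletionEvent::Stop(_reason) => break,
///         _ => {}
///     }
/// }
/// ```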
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    StatusUpdate(CompletionRequestStatus),
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    ToolUse(LanguageModelToolUse),
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

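/// An error produced while streaming a completion.
///
/// A retry-handling sketch for the transient cases; `sleep` stands in for any
/// async timer, and the policy shown is illustrative:
/// ```ignore
/// match error {
///     LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
///         sleep(retry_after).await; // then retry the request
///     }
///     LanguageModelCompletionError::Overloaded => {
///         sleep(DEFAULT_RATE_LIMIT_RETRY_AFTER).await; // then retry
///     }
///     other => return Err(other.into()),
/// }
/// ```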
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("rate limit exceeded, retry after {retry_after:?}")]
    RateLimitExceeded { retry_after: Duration },
    #[error("received bad input JSON")]
    BadInputJson {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    #[error("language model provider's API is overloaded")]
    Overloaded,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
    #[error("invalid request format to language model provider's API")]
    BadRequestFormat,
    #[error("authentication error with language model provider's API")]
    AuthenticationError,
    #[error("permission error with language model provider's API")]
    PermissionError,
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound,
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("internal server error in language model provider's API")]
    ApiInternalServerError,
    #[error("I/O error reading response from language model provider's API: {0:?}")]
    ApiReadResponseError(io::Error),
    #[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
    HttpResponseError { status: u16, body: String },
    #[error("error serializing request to language model provider API: {0}")]
    SerializeRequest(serde_json::Error),
    #[error("error building request body to language model provider API: {0}")]
    BuildRequestBody(http::Error),
    #[error("error sending HTTP request to language model provider API: {0}")]
    HttpSend(anyhow::Error),
    #[error("error deserializing language model provider API response: {0}")]
    DeserializeResponse(serde_json::Error),
    #[error("unexpected language model provider API response format: {0}")]
    UnknownResponseFormat(String),
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
            AnthropicError::HttpSend(error) => Self::HttpSend(error),
            AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
            AnthropicError::HttpResponseError { status, body } => {
                Self::HttpResponseError { status, body }
            }
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
            AnthropicError::ApiError(api_error) => api_error.into(),
            AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;

        match error.code() {
            Some(code) => match code {
                InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
                AuthenticationError => LanguageModelCompletionError::AuthenticationError,
                PermissionError => LanguageModelCompletionError::PermissionError,
                NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
                RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
                    retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
                },
                ApiError => LanguageModelCompletionError::ApiInternalServerError,
                OverloadedError => LanguageModelCompletionError::Overloaded,
            },
            None => LanguageModelCompletionError::Other(error.into()),
        }
    }
}

/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
    /// A JSON schema, see <https://json-schema.org>.
    JsonSchema,
    /// A subset of an OpenAPI 3.0 schema object supported by Google AI, see <https://ai.google.dev/api/caching#Schema>.
    JsonSchemaSubset,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
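    /// Sums all four counters. A worked sketch (field values are hypothetical):
    /// ```ignore
    /// let usage = TokenUsage {
    ///     input_tokens: 100,
    ///     output_tokens: 50,
    ///     cache_creation_input_tokens: 10,
    ///     cache_read_input_tokens: 40,
    /// };
    /// assert_eq!(usage.total_tokens(), 200);
    /// ```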
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

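    // Note: plain `u64` subtraction, so this panics in debug builds (and wraps
    // in release) if a field of `other` exceeds the corresponding field of
    // `self`; presumably callers subtract an earlier usage snapshot from a
    // later, larger one.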
    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

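/// A tool invocation requested by the model. `raw_input` accumulates the
/// argument JSON as it streams in, `input` holds the parsed value, and
/// `is_input_complete` indicates whether the model has finished emitting the
/// arguments.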
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

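/// A single language model exposed by a [`LanguageModelProvider`], able to
/// count tokens for a request and stream completion events.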
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Whether this model supports "burn mode".
    fn supports_max_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;

    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

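    /// Convenience wrapper around [`Self::stream_completion`] that yields only
    /// text chunks. The message ID (if any) is captured from the first event,
    /// and the final token usage is published through
    /// [`LanguageModelTextStream::last_token_usage`].
    ///
    /// A consumption sketch (assumes a `model: &dyn LanguageModel` and an
    /// `AsyncApp` handle `cx` are in scope):
    /// ```ignore
    /// let mut text_stream = model.stream_completion_text(request, &cx).await?;
    /// while let Some(chunk) = text_stream.stream.next().await {
    ///     print!("{}", chunk?);
    /// }
    /// let usage = *text_stream.last_token_usage.lock();
    /// ```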
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            // Peek at the first event to capture the message ID, keeping any
            // first text chunk so it isn't lost from the stream built below.
            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::StatusUpdate(_)) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

#[derive(Debug, Error)]
pub enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: u64 },
    #[error("Language model provider's API is currently overloaded")]
    Overloaded,
    #[error("Language model provider's API encountered an internal server error")]
    ApiInternalServerError,
    #[error("I/O error while reading response from language model provider's API: {0:?}")]
    ReadResponseError(io::Error),
    #[error("Error deserializing response from language model provider's API: {0:?}")]
    DeserializeResponse(serde_json::Error),
    #[error("Language model provider's API returned a response in an unknown format: {0}")]
    UnknownResponseFormat(String),
    #[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
    RateLimitExceeded { retry_after: Duration },
}

impl LanguageModelKnownError {
    /// Attempts to map an HTTP response status code to a known error type.
    /// Returns `None` if the status code doesn't map to a specific known error.
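    ///
    /// For example, `from_http_response(429, "")` yields
    /// `Self::RateLimitExceeded` with [`DEFAULT_RATE_LIMIT_RETRY_AFTER`].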
    pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
        match status {
            429 => Some(Self::RateLimitExceeded {
                retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
            }),
            503 => Some(Self::Overloaded),
            500..=599 => Some(Self::ApiInternalServerError),
            _ => None,
        }
    }
}

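/// A tool that a language model can invoke, with its input schema derived via
/// [`JsonSchema`] and its arguments deserialized from the model's JSON output.
///
/// An implementation sketch (`SearchTool` is a hypothetical example type):
/// ```ignore
/// #[derive(Deserialize, JsonSchema)]
/// struct SearchTool {
///     /// The query to search the project for.
///     query: String,
/// }
///
/// impl LanguageModelTool for SearchTool {
///     fn name() -> String {
///         "search".into()
///     }
///
///     fn description() -> String {
///         "Searches the project for a query.".into()
///     }
/// }
/// ```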
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

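/// A source of language models (for example, a hosted API or a local runtime),
/// responsible for authentication, configuration UI, and enumerating the
/// models it makes available.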
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    PromptEditorPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

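    /// Observes the provider's state entity, if there is one, re-running
    /// `callback` whenever that entity changes.
    ///
    /// A usage sketch (`MyView` is a hypothetical gpui entity type):
    /// ```ignore
    /// let _subscription = provider.subscribe(cx, |this: &mut MyView, cx| {
    ///     cx.notify();
    /// });
    /// ```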
    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}