1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6mod telemetry;
7pub mod tool_schema;
8
9#[cfg(any(test, feature = "test-support"))]
10pub mod fake_provider;
11
12use anthropic::{AnthropicError, parse_prompt_too_long};
13use anyhow::{Result, anyhow};
14use client::Client;
15use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
16use futures::FutureExt;
17use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
18use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
19use http_client::{StatusCode, http};
20use icons::IconName;
21use open_router::OpenRouterError;
22use parking_lot::Mutex;
23use serde::{Deserialize, Serialize};
24pub use settings::LanguageModelCacheConfiguration;
25use std::ops::{Add, Sub};
26use std::str::FromStr;
27use std::sync::Arc;
28use std::time::Duration;
29use std::{fmt, io};
30use thiserror::Error;
31use util::serde::is_default;
32
33pub use crate::model::*;
34pub use crate::rate_limiter::*;
35pub use crate::registry::*;
36pub use crate::request::*;
37pub use crate::role::*;
38pub use crate::telemetry::*;
39pub use crate::tool_schema::LanguageModelToolSchemaFormat;
40
/// Stable id for the built-in Anthropic provider.
pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
/// Display name for the built-in Anthropic provider.
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

/// Stable id for the built-in Google AI provider.
pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
/// Display name for the built-in Google AI provider.
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

/// Stable id for the built-in OpenAI provider.
pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
/// Display name for the built-in OpenAI provider.
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

/// Stable id for the built-in xAI provider.
pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
/// Display name for the built-in xAI provider.
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

/// Stable id for the Zed cloud provider (requests proxied through zed.dev).
pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
/// Display name for the Zed cloud provider.
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");
60
/// Initializes the language-model crate: registers settings and the
/// LLM-token refresh listener for the given client.
pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}
65
/// Registers language-model settings (the model registry) with the app.
/// Split out from [`init`] so tests can set up settings without a client.
pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}
69
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The request is waiting in a queue at the given position.
    Queued {
        position: usize,
    },
    /// The provider has begun processing the request.
    Started,
    /// Progress against a usage limit, reported while the request runs.
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    /// The request's tool-use limit was reached.
    ToolUseLimitReached,
    /// Generation stopped, with the reason why.
    Stop(StopReason),
    /// A chunk of response text.
    Text(String),
    /// A chunk of model "thinking" output, with an optional signature that
    /// some models require to be echoed back in later requests.
    Thinking {
        text: String,
        signature: Option<String>,
    },
    /// Thinking content the provider returned in redacted (opaque) form.
    RedactedThinking {
        data: String,
    },
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
    /// The model produced tool input that failed to parse as JSON.
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    /// A new message began, carrying the provider-assigned message id.
    StartMessage {
        message_id: String,
    },
    /// Provider-specific reasoning metadata, passed through as raw JSON.
    ReasoningDetails(serde_json::Value),
    /// Token usage for the request; the last value observed is treated as
    /// the request total (see `LanguageModelTextStream::last_token_usage`).
    UsageUpdate(TokenUsage),
}
104
105impl LanguageModelCompletionEvent {
106 pub fn from_completion_request_status(
107 status: CompletionRequestStatus,
108 upstream_provider: LanguageModelProviderName,
109 ) -> Result<Self, LanguageModelCompletionError> {
110 match status {
111 CompletionRequestStatus::Queued { position } => {
112 Ok(LanguageModelCompletionEvent::Queued { position })
113 }
114 CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
115 CompletionRequestStatus::UsageUpdated { amount, limit } => {
116 Ok(LanguageModelCompletionEvent::UsageUpdated { amount, limit })
117 }
118 CompletionRequestStatus::ToolUseLimitReached => {
119 Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
120 }
121 CompletionRequestStatus::Failed {
122 code,
123 message,
124 request_id: _,
125 retry_after,
126 } => Err(LanguageModelCompletionError::from_cloud_failure(
127 upstream_provider,
128 code,
129 message,
130 retry_after.map(Duration::from_secs_f64),
131 )),
132 }
133 }
134}
135
/// Failure modes for a language-model completion request, spanning both
/// upstream provider errors and local request-construction errors.
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    /// The prompt exceeded the model's context window; `tokens` is the
    /// parsed token count when the provider reported one.
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    /// An error forwarded verbatim from an upstream provider via the cloud.
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    /// Catch-all for HTTP statuses with no dedicated variant.
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("Permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
223
224impl LanguageModelCompletionError {
225 fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
226 let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
227 let upstream_status = error_json
228 .get("upstream_status")
229 .and_then(|v| v.as_u64())
230 .and_then(|status| u16::try_from(status).ok())
231 .and_then(|status| StatusCode::from_u16(status).ok())?;
232 let inner_message = error_json
233 .get("message")
234 .and_then(|v| v.as_str())
235 .unwrap_or(message)
236 .to_string();
237 Some((upstream_status, inner_message))
238 }
239
240 pub fn from_cloud_failure(
241 upstream_provider: LanguageModelProviderName,
242 code: String,
243 message: String,
244 retry_after: Option<Duration>,
245 ) -> Self {
246 if let Some(tokens) = parse_prompt_too_long(&message) {
247 // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
248 // to be reported. This is a temporary workaround to handle this in the case where the
249 // token limit has been exceeded.
250 Self::PromptTooLarge {
251 tokens: Some(tokens),
252 }
253 } else if code == "upstream_http_error" {
254 if let Some((upstream_status, inner_message)) =
255 Self::parse_upstream_error_json(&message)
256 {
257 return Self::from_http_status(
258 upstream_provider,
259 upstream_status,
260 inner_message,
261 retry_after,
262 );
263 }
264 anyhow!("completion request failed, code: {code}, message: {message}").into()
265 } else if let Some(status_code) = code
266 .strip_prefix("upstream_http_")
267 .and_then(|code| StatusCode::from_str(code).ok())
268 {
269 Self::from_http_status(upstream_provider, status_code, message, retry_after)
270 } else if let Some(status_code) = code
271 .strip_prefix("http_")
272 .and_then(|code| StatusCode::from_str(code).ok())
273 {
274 Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
275 } else {
276 anyhow!("completion request failed, code: {code}, message: {message}").into()
277 }
278 }
279
280 pub fn from_http_status(
281 provider: LanguageModelProviderName,
282 status_code: StatusCode,
283 message: String,
284 retry_after: Option<Duration>,
285 ) -> Self {
286 match status_code {
287 StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
288 StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
289 StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
290 StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
291 StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
292 tokens: parse_prompt_too_long(&message),
293 },
294 StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
295 provider,
296 retry_after,
297 },
298 StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
299 StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
300 provider,
301 retry_after,
302 },
303 _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
304 provider,
305 retry_after,
306 },
307 _ => Self::HttpResponseError {
308 provider,
309 status_code,
310 message,
311 },
312 }
313 }
314}
315
316impl From<AnthropicError> for LanguageModelCompletionError {
317 fn from(error: AnthropicError) -> Self {
318 let provider = ANTHROPIC_PROVIDER_NAME;
319 match error {
320 AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
321 AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
322 AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
323 AnthropicError::DeserializeResponse(error) => {
324 Self::DeserializeResponse { provider, error }
325 }
326 AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
327 AnthropicError::HttpResponseError {
328 status_code,
329 message,
330 } => Self::HttpResponseError {
331 provider,
332 status_code,
333 message,
334 },
335 AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
336 provider,
337 retry_after: Some(retry_after),
338 },
339 AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
340 provider,
341 retry_after,
342 },
343 AnthropicError::ApiError(api_error) => api_error.into(),
344 }
345 }
346}
347
348impl From<anthropic::ApiError> for LanguageModelCompletionError {
349 fn from(error: anthropic::ApiError) -> Self {
350 use anthropic::ApiErrorCode::*;
351 let provider = ANTHROPIC_PROVIDER_NAME;
352 match error.code() {
353 Some(code) => match code {
354 InvalidRequestError => Self::BadRequestFormat {
355 provider,
356 message: error.message,
357 },
358 AuthenticationError => Self::AuthenticationError {
359 provider,
360 message: error.message,
361 },
362 PermissionError => Self::PermissionError {
363 provider,
364 message: error.message,
365 },
366 NotFoundError => Self::ApiEndpointNotFound { provider },
367 RequestTooLarge => Self::PromptTooLarge {
368 tokens: parse_prompt_too_long(&error.message),
369 },
370 RateLimitError => Self::RateLimitExceeded {
371 provider,
372 retry_after: None,
373 },
374 ApiError => Self::ApiInternalServerError {
375 provider,
376 message: error.message,
377 },
378 OverloadedError => Self::ServerOverloaded {
379 provider,
380 retry_after: None,
381 },
382 },
383 None => Self::Other(error.into()),
384 }
385 }
386}
387
388impl From<open_ai::RequestError> for LanguageModelCompletionError {
389 fn from(error: open_ai::RequestError) -> Self {
390 match error {
391 open_ai::RequestError::HttpResponseError {
392 provider,
393 status_code,
394 body,
395 headers,
396 } => {
397 let retry_after = headers
398 .get(http::header::RETRY_AFTER)
399 .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
400 .map(Duration::from_secs);
401
402 Self::from_http_status(provider.into(), status_code, body, retry_after)
403 }
404 open_ai::RequestError::Other(e) => Self::Other(e),
405 }
406 }
407}
408
409impl From<OpenRouterError> for LanguageModelCompletionError {
410 fn from(error: OpenRouterError) -> Self {
411 let provider = LanguageModelProviderName::new("OpenRouter");
412 match error {
413 OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
414 OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
415 OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
416 OpenRouterError::DeserializeResponse(error) => {
417 Self::DeserializeResponse { provider, error }
418 }
419 OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
420 OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
421 provider,
422 retry_after: Some(retry_after),
423 },
424 OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
425 provider,
426 retry_after,
427 },
428 OpenRouterError::ApiError(api_error) => api_error.into(),
429 }
430 }
431}
432
433impl From<open_router::ApiError> for LanguageModelCompletionError {
434 fn from(error: open_router::ApiError) -> Self {
435 use open_router::ApiErrorCode::*;
436 let provider = LanguageModelProviderName::new("OpenRouter");
437 match error.code {
438 InvalidRequestError => Self::BadRequestFormat {
439 provider,
440 message: error.message,
441 },
442 AuthenticationError => Self::AuthenticationError {
443 provider,
444 message: error.message,
445 },
446 PaymentRequiredError => Self::AuthenticationError {
447 provider,
448 message: format!("Payment required: {}", error.message),
449 },
450 PermissionError => Self::PermissionError {
451 provider,
452 message: error.message,
453 },
454 RequestTimedOut => Self::HttpResponseError {
455 provider,
456 status_code: StatusCode::REQUEST_TIMEOUT,
457 message: error.message,
458 },
459 RateLimitError => Self::RateLimitExceeded {
460 provider,
461 retry_after: None,
462 },
463 ApiError => Self::ApiInternalServerError {
464 provider,
465 message: error.message,
466 },
467 OverloadedError => Self::ServerOverloaded {
468 provider,
469 retry_after: None,
470 },
471 }
472 }
473}
474
/// Why the model stopped generating.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output-token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
    /// The model declined to produce a completion.
    Refusal,
}
483
/// Token counts reported by a provider for a single request.
///
/// Every field defaults to zero and is omitted from serialization when zero.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    // Input (prompt) tokens; cache activity is tracked separately below.
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    // Output (completion) tokens.
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    // Tokens written to the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    // Tokens served from the provider's prompt cache.
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}
495
496impl TokenUsage {
497 pub fn total_tokens(&self) -> u64 {
498 self.input_tokens
499 + self.output_tokens
500 + self.cache_read_input_tokens
501 + self.cache_creation_input_tokens
502 }
503}
504
505impl Add<TokenUsage> for TokenUsage {
506 type Output = Self;
507
508 fn add(self, other: Self) -> Self {
509 Self {
510 input_tokens: self.input_tokens + other.input_tokens,
511 output_tokens: self.output_tokens + other.output_tokens,
512 cache_creation_input_tokens: self.cache_creation_input_tokens
513 + other.cache_creation_input_tokens,
514 cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
515 }
516 }
517}
518
519impl Sub<TokenUsage> for TokenUsage {
520 type Output = Self;
521
522 fn sub(self, other: Self) -> Self {
523 Self {
524 input_tokens: self.input_tokens - other.input_tokens,
525 output_tokens: self.output_tokens - other.output_tokens,
526 cache_creation_input_tokens: self.cache_creation_input_tokens
527 - other.cache_creation_input_tokens,
528 cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
529 }
530 }
531}
532
/// Opaque identifier for a tool use emitted by a model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
535
536impl fmt::Display for LanguageModelToolUseId {
537 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
538 write!(f, "{}", self.0)
539 }
540}
541
542impl<T> From<T> for LanguageModelToolUseId
543where
544 T: Into<Arc<str>>,
545{
546 fn from(value: T) -> Self {
547 Self(value.into())
548 }
549}
550
/// A tool invocation requested by the model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    // Provider-assigned id for this tool use.
    pub id: LanguageModelToolUseId,
    // Name of the tool to invoke.
    pub name: Arc<str>,
    // The tool input exactly as streamed from the model.
    pub raw_input: String,
    // Parsed JSON form of `raw_input`.
    pub input: serde_json::Value,
    // Whether the input has finished streaming.
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}
562
/// A text-only view over a completion stream, produced by
/// `LanguageModel::stream_completion_text`.
pub struct LanguageModelTextStream {
    // Provider-assigned message id, when the first event carried one.
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}
569
570impl Default for LanguageModelTextStream {
571 fn default() -> Self {
572 Self {
573 message_id: None,
574 stream: Box::pin(futures::stream::empty()),
575 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
576 }
577 }
578}
579
580pub trait LanguageModel: Send + Sync {
581 fn id(&self) -> LanguageModelId;
582 fn name(&self) -> LanguageModelName;
583 fn provider_id(&self) -> LanguageModelProviderId;
584 fn provider_name(&self) -> LanguageModelProviderName;
585 fn upstream_provider_id(&self) -> LanguageModelProviderId {
586 self.provider_id()
587 }
588 fn upstream_provider_name(&self) -> LanguageModelProviderName {
589 self.provider_name()
590 }
591
592 fn telemetry_id(&self) -> String;
593
594 fn api_key(&self, _cx: &App) -> Option<String> {
595 None
596 }
597
598 /// Whether this model supports images
599 fn supports_images(&self) -> bool;
600
601 /// Whether this model supports tools.
602 fn supports_tools(&self) -> bool;
603
604 /// Whether this model supports choosing which tool to use.
605 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;
606
607 /// Returns whether this model supports "burn mode";
608 fn supports_burn_mode(&self) -> bool {
609 false
610 }
611
612 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
613 LanguageModelToolSchemaFormat::JsonSchema
614 }
615
616 fn max_token_count(&self) -> u64;
617 /// Returns the maximum token count for this model in burn mode (If `supports_burn_mode` is `false` this returns `None`)
618 fn max_token_count_in_burn_mode(&self) -> Option<u64> {
619 None
620 }
621 fn max_output_tokens(&self) -> Option<u64> {
622 None
623 }
624
625 fn count_tokens(
626 &self,
627 request: LanguageModelRequest,
628 cx: &App,
629 ) -> BoxFuture<'static, Result<u64>>;
630
631 fn stream_completion(
632 &self,
633 request: LanguageModelRequest,
634 cx: &AsyncApp,
635 ) -> BoxFuture<
636 'static,
637 Result<
638 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
639 LanguageModelCompletionError,
640 >,
641 >;
642
643 fn stream_completion_text(
644 &self,
645 request: LanguageModelRequest,
646 cx: &AsyncApp,
647 ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
648 let future = self.stream_completion(request, cx);
649
650 async move {
651 let events = future.await?;
652 let mut events = events.fuse();
653 let mut message_id = None;
654 let mut first_item_text = None;
655 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
656
657 if let Some(first_event) = events.next().await {
658 match first_event {
659 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
660 message_id = Some(id);
661 }
662 Ok(LanguageModelCompletionEvent::Text(text)) => {
663 first_item_text = Some(text);
664 }
665 _ => (),
666 }
667 }
668
669 let stream = futures::stream::iter(first_item_text.map(Ok))
670 .chain(events.filter_map({
671 let last_token_usage = last_token_usage.clone();
672 move |result| {
673 let last_token_usage = last_token_usage.clone();
674 async move {
675 match result {
676 Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
677 Ok(LanguageModelCompletionEvent::Started) => None,
678 Ok(LanguageModelCompletionEvent::UsageUpdated { .. }) => None,
679 Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
680 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
681 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
682 Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
683 Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
684 Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
685 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
686 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
687 Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
688 ..
689 }) => None,
690 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
691 *last_token_usage.lock() = token_usage;
692 None
693 }
694 Err(err) => Some(Err(err)),
695 }
696 }
697 }
698 }))
699 .boxed();
700
701 Ok(LanguageModelTextStream {
702 message_id,
703 stream,
704 last_token_usage,
705 })
706 }
707 .boxed()
708 }
709
710 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
711 None
712 }
713
714 #[cfg(any(test, feature = "test-support"))]
715 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
716 unimplemented!()
717 }
718}
719
720pub trait LanguageModelExt: LanguageModel {
721 fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
722 match mode {
723 CompletionMode::Normal => self.max_token_count(),
724 CompletionMode::Max => self
725 .max_token_count_in_burn_mode()
726 .unwrap_or_else(|| self.max_token_count()),
727 }
728 }
729}
730impl LanguageModelExt for dyn LanguageModel {}
731
/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// The provider endpoint could not be reached.
    #[error("connection refused")]
    ConnectionRefused,
    /// No stored credentials were found for the provider.
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
742
/// Interface implemented by each language-model provider (Anthropic, OpenAI,
/// Zed cloud, ...), covering model discovery, auth, and configuration UI.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown next to the provider in pickers; defaults to Zed's own.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// The provider's default model, if one can be determined.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// A cheaper/faster default, used where latency matters more than quality.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Subset of models to surface prominently; empty by default.
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the provider's settings/configuration view.
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
765
/// Which agent a provider configuration view is being shown for.
#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    /// Zed's built-in agent (the default).
    #[default]
    ZedAgent,
    /// An external agent, identified by display name.
    Other(SharedString),
}
772
/// Where a provider's terms-of-service prompt is being rendered.
// NOTE(review): the descriptions on `ThreadEmptyState` and `ThreadFreshStart`
// appear swapped relative to the variant names ("empty state" described as
// having past interactions) — confirm against the Agent Panel call sites.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}
782
783pub trait LanguageModelProviderState: 'static {
784 type ObservableEntity;
785
786 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
787
788 fn subscribe<T: 'static>(
789 &self,
790 cx: &mut gpui::Context<T>,
791 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
792 ) -> Option<gpui::Subscription> {
793 let entity = self.observable_entity()?;
794 Some(cx.observe(&entity, move |this, _, cx| {
795 callback(this, cx);
796 }))
797 }
798}
799
/// Identifier for a specific model within a provider (e.g. "claude-3-5-sonnet").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable model name shown in pickers.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Stable provider identifier (e.g. "anthropic", "zed.dev").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable provider name (e.g. "Anthropic", "Zed").
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
811
impl LanguageModelProviderId {
    /// Creates a provider id from a static string without allocating.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
817
impl LanguageModelProviderName {
    /// Creates a provider name from a static string without allocating.
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}
823
824impl fmt::Display for LanguageModelProviderId {
825 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
826 write!(f, "{}", self.0)
827 }
828}
829
830impl fmt::Display for LanguageModelProviderName {
831 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
832 write!(f, "{}", self.0)
833 }
834}
835
836impl From<String> for LanguageModelId {
837 fn from(value: String) -> Self {
838 Self(SharedString::from(value))
839 }
840}
841
842impl From<String> for LanguageModelName {
843 fn from(value: String) -> Self {
844 Self(SharedString::from(value))
845 }
846}
847
848impl From<String> for LanguageModelProviderId {
849 fn from(value: String) -> Self {
850 Self(SharedString::from(value))
851 }
852}
853
854impl From<String> for LanguageModelProviderName {
855 fn from(value: String) -> Self {
856 Self(SharedString::from(value))
857 }
858}
859
860impl From<Arc<str>> for LanguageModelProviderId {
861 fn from(value: Arc<str>) -> Self {
862 Self(SharedString::from(value))
863 }
864}
865
866impl From<Arc<str>> for LanguageModelProviderName {
867 fn from(value: Arc<str>) -> Self {
868 Self(SharedString::from(value))
869 }
870}
871
#[cfg(test)]
mod tests {
    use super::*;

    // Verifies the `upstream_http_error` path: the JSON payload's
    // `upstream_status` (503/500) selects the variant, not the outer code.
    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                // The inner JSON "message" field is extracted, not the raw payload.
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    // Verifies the `upstream_http_NNN` code format, where the status is
    // embedded in the code string rather than a JSON payload.
    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    // NOTE(review): the first half of this test duplicates
    // `test_from_cloud_failure_with_upstream_http_error` exactly; only the
    // 500-status case with the full connection-timeout message is new.
    // Consider consolidating.
    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    // A present `thought_signature` round-trips through serialization.
    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    // Older payloads without `thought_signature` must still deserialize
    // (backward compatibility), defaulting the field to `None`.
    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    // Serialize-then-deserialize preserves a present signature.
    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    // Serialize-then-deserialize preserves an absent signature.
    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
}