mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;
pub mod tool_schema;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    Queued {
        position: usize,
    },
    Started,
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    ToolUseLimitReached,
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    UsageUpdate(TokenUsage),
}

impl LanguageModelCompletionEvent {
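    /// Converts a cloud `CompletionRequestStatus` into a completion event.
    /// The `Failed` status is turned into a `LanguageModelCompletionError`
    /// attributed to the given upstream provider.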
    pub fn from_completion_request_status(
        status: CompletionRequestStatus,
        upstream_provider: LanguageModelProviderName,
    ) -> Result<Self, LanguageModelCompletionError> {
        match status {
            CompletionRequestStatus::Queued { position } => {
                Ok(LanguageModelCompletionEvent::Queued { position })
            }
            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
            CompletionRequestStatus::UsageUpdated { amount, limit } => {
                Ok(LanguageModelCompletionEvent::UsageUpdated { amount, limit })
            }
            CompletionRequestStatus::ToolUseLimitReached => {
                Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
            }
            CompletionRequestStatus::Failed {
                code,
                message,
                request_id: _,
                retry_after,
            } => Err(LanguageModelCompletionError::from_cloud_failure(
                upstream_provider,
                code,
                message,
                retry_after.map(Duration::from_secs_f64),
            )),
        }
    }
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
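    /// Attempts to parse a cloud error payload of the form
    /// `{"upstream_status": 503, "message": "..."}`, returning the upstream
    /// status code and inner message when both are present and valid.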
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

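    /// Builds an error from a cloud completion failure. Recognized `code`
    /// values are `upstream_http_error` (with a JSON payload carrying the
    /// upstream status), `upstream_http_<status>` for upstream provider
    /// errors, and `http_<status>` for errors from Zed's own servers;
    /// anything else falls back to a generic error.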
    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently an Anthropic PAYLOAD_TOO_LARGE response may be reported as an
            // INTERNAL_SERVER_ERROR. This is a temporary workaround for the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

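    /// Maps an HTTP status code from a provider's API to a structured error.
    /// Status 529, a nonstandard code some providers (notably Anthropic) use
    /// to signal overload, is also mapped to `ServerOverloaded`.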
    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<open_ai::RequestError> for LanguageModelCompletionError {
    fn from(error: open_ai::RequestError) -> Self {
        match error {
            open_ai::RequestError::HttpResponseError {
                provider,
                status_code,
                body,
                headers,
            } => {
                let retry_after = headers
                    .get(http::header::RETRY_AFTER)
                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
                    .map(Duration::from_secs);

                Self::from_http_status(provider.into(), status_code, body, retry_after)
            }
            open_ai::RequestError::Other(e) => Self::Other(e),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

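// Field-wise subtraction. Note that with default overflow settings, u64
// subtraction panics in debug builds (and wraps in release) if a counter in
// `other` exceeds its counterpart in `self`.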
impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
    /// The thought signature sent by the model. Some models require that this
    /// signature be preserved and sent back in the conversation history for validation.
    pub thought_signature: Option<String>,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    /// Holds the complete token usage once the stream has finished.
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

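/// A handle to a language model exposed by a provider. Implementors describe
/// the model's identity and capabilities and provide `count_tokens` and
/// `stream_completion`; conveniences such as `stream_completion_text` are
/// layered on top.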
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode, or `None`
    /// if `supports_burn_mode` is `false`.
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

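    /// Adapts `stream_completion` into a text-only stream: captures the
    /// message id from an initial `StartMessage` event, forwards `Text`
    /// chunks, records the final `UsageUpdate` into `last_token_usage`, and
    /// drops all other events.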
    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdated { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
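    /// Returns the context window for the given completion mode: the
    /// burn-mode limit when `mode` is `Max` and the model reports one,
    /// otherwise the normal `max_token_count`.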
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("connection refused")]
    ConnectionRefused,
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        Vec::new()
    }
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(
        &self,
        target_agent: ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView;
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(Default, Clone)]
pub enum ConfigurationViewTargetAgent {
    #[default]
    ZedAgent,
    Other(SharedString),
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// When there are some past interactions in the Agent Panel.
    ThreadEmptyState,
    /// When there are no past interactions in the Agent Panel.
    ThreadFreshStart,
    TextThreadPopup,
    Configuration,
}

pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl LanguageModelProviderId {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(name: &'static str) -> Self {
        Self(SharedString::new_static(name))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }
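
    // Added sanity check: status 529, which some providers (notably
    // Anthropic) use to signal overload, maps to ServerOverloaded.
    #[test]
    fn test_from_http_status_529_maps_to_server_overloaded() {
        let error = LanguageModelCompletionError::from_http_status(
            ANTHROPIC_PROVIDER_NAME,
            StatusCode::from_u16(529).unwrap(),
            "overloaded".to_string(),
            None,
        );
        assert!(matches!(
            error,
            LanguageModelCompletionError::ServerOverloaded { .. }
        ));
    }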

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
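
    // Added sanity check: `TokenUsage` addition and subtraction are
    // field-wise, and `total_tokens` sums all four counters.
    #[test]
    fn test_token_usage_arithmetic() {
        let a = TokenUsage {
            input_tokens: 10,
            output_tokens: 5,
            cache_creation_input_tokens: 3,
            cache_read_input_tokens: 2,
        };
        let b = TokenUsage {
            input_tokens: 1,
            output_tokens: 1,
            cache_creation_input_tokens: 1,
            cache_read_input_tokens: 1,
        };

        assert_eq!(a.total_tokens(), 20);
        assert_eq!((a + b).total_tokens(), 24);
        assert_eq!((a - b).total_tokens(), 16);
    }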
}