mod api_key;
mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;
pub mod tool_schema;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
use open_router::OpenRouterError;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
pub use settings::LanguageModelCacheConfiguration;
use std::ops::{Add, Sub};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;

pub use crate::api_key::{ApiKey, ApiKeyState};
pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;
pub use crate::tool_schema::LanguageModelToolSchemaFormat;
pub use zed_env_vars::{EnvVar, env_var};

pub const ANTHROPIC_PROVIDER_ID: LanguageModelProviderId =
    LanguageModelProviderId::new("anthropic");
pub const ANTHROPIC_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Anthropic");

pub const GOOGLE_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("google");
pub const GOOGLE_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Google AI");

pub const OPEN_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai");
pub const OPEN_AI_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("OpenAI");

pub const X_AI_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("x_ai");
pub const X_AI_PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("xAI");

pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("zed.dev");
pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Zed");

pub fn init(client: Arc<Client>, cx: &mut App) {
    init_settings(cx);
    RefreshLlmTokenListener::register(client, cx);
}

pub fn init_settings(cx: &mut App) {
    registry::init(cx);
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    Queued {
        position: usize,
    },
    Started,
    UsageUpdated {
        amount: usize,
        limit: UsageLimit,
    },
    ToolUseLimitReached,
    Stop(StopReason),
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    RedactedThinking {
        data: String,
    },
    ToolUse(LanguageModelToolUse),
    ToolUseJsonParseError {
        id: LanguageModelToolUseId,
        tool_name: Arc<str>,
        raw_input: Arc<str>,
        json_parse_error: String,
    },
    StartMessage {
        message_id: String,
    },
    ReasoningDetails(serde_json::Value),
    UsageUpdate(TokenUsage),
}

impl LanguageModelCompletionEvent {
    pub fn from_completion_request_status(
        status: CompletionRequestStatus,
        upstream_provider: LanguageModelProviderName,
    ) -> Result<Self, LanguageModelCompletionError> {
        match status {
            CompletionRequestStatus::Queued { position } => {
                Ok(LanguageModelCompletionEvent::Queued { position })
            }
            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
            CompletionRequestStatus::UsageUpdated { amount, limit } => {
                Ok(LanguageModelCompletionEvent::UsageUpdated { amount, limit })
            }
            CompletionRequestStatus::ToolUseLimitReached => {
                Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
            }
            CompletionRequestStatus::Failed {
                code,
                message,
                request_id: _,
                retry_after,
            } => Err(LanguageModelCompletionError::from_cloud_failure(
                upstream_provider,
                code,
                message,
                retry_after.map(Duration::from_secs_f64),
            )),
        }
    }
}

#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
    #[error("prompt too large for context window")]
    PromptTooLarge { tokens: Option<u64> },
    #[error("missing {provider} API key")]
    NoApiKey { provider: LanguageModelProviderName },
    #[error("{provider}'s API rate limit exceeded")]
    RateLimitExceeded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API servers are overloaded right now")]
    ServerOverloaded {
        provider: LanguageModelProviderName,
        retry_after: Option<Duration>,
    },
    #[error("{provider}'s API server reported an internal server error: {message}")]
    ApiInternalServerError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("{message}")]
    UpstreamProviderError {
        message: String,
        status: StatusCode,
        retry_after: Option<Duration>,
    },
    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
    HttpResponseError {
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
    },

    // Client errors
    #[error("invalid request format to {provider}'s API: {message}")]
    BadRequestFormat {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("authentication error with {provider}'s API: {message}")]
    AuthenticationError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("permission error with {provider}'s API: {message}")]
    PermissionError {
        provider: LanguageModelProviderName,
        message: String,
    },
    #[error("language model provider API endpoint not found")]
    ApiEndpointNotFound { provider: LanguageModelProviderName },
    #[error("I/O error reading response from {provider}'s API")]
    ApiReadResponseError {
        provider: LanguageModelProviderName,
        #[source]
        error: io::Error,
    },
    #[error("error serializing request to {provider} API")]
    SerializeRequest {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },
    #[error("error building request body to {provider} API")]
    BuildRequestBody {
        provider: LanguageModelProviderName,
        #[source]
        error: http::Error,
    },
    #[error("error sending HTTP request to {provider} API")]
    HttpSend {
        provider: LanguageModelProviderName,
        #[source]
        error: anyhow::Error,
    },
    #[error("error deserializing {provider} API response")]
    DeserializeResponse {
        provider: LanguageModelProviderName,
        #[source]
        error: serde_json::Error,
    },

    // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl LanguageModelCompletionError {
    fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
        let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
        let upstream_status = error_json
            .get("upstream_status")
            .and_then(|v| v.as_u64())
            .and_then(|status| u16::try_from(status).ok())
            .and_then(|status| StatusCode::from_u16(status).ok())?;
        let inner_message = error_json
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or(message)
            .to_string();
        Some((upstream_status, inner_message))
    }

    pub fn from_cloud_failure(
        upstream_provider: LanguageModelProviderName,
        code: String,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        if let Some(tokens) = parse_prompt_too_long(&message) {
            // TODO: currently Anthropic PAYLOAD_TOO_LARGE response may cause INTERNAL_SERVER_ERROR
            // to be reported. This is a temporary workaround to handle this in the case where the
            // token limit has been exceeded.
            Self::PromptTooLarge {
                tokens: Some(tokens),
            }
        } else if code == "upstream_http_error" {
            if let Some((upstream_status, inner_message)) =
                Self::parse_upstream_error_json(&message)
            {
                return Self::from_http_status(
                    upstream_provider,
                    upstream_status,
                    inner_message,
                    retry_after,
                );
            }
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        } else if let Some(status_code) = code
            .strip_prefix("upstream_http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(upstream_provider, status_code, message, retry_after)
        } else if let Some(status_code) = code
            .strip_prefix("http_")
            .and_then(|code| StatusCode::from_str(code).ok())
        {
            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
        } else {
            anyhow!("completion request failed, code: {code}, message: {message}").into()
        }
    }

    pub fn from_http_status(
        provider: LanguageModelProviderName,
        status_code: StatusCode,
        message: String,
        retry_after: Option<Duration>,
    ) -> Self {
        match status_code {
            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
            StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                tokens: parse_prompt_too_long(&message),
            },
            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
                provider,
                retry_after,
            },
            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            _ => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
        }
    }
}

impl From<AnthropicError> for LanguageModelCompletionError {
    fn from(error: AnthropicError) -> Self {
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error {
            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
            AnthropicError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            AnthropicError::HttpResponseError {
                status_code,
                message,
            } => Self::HttpResponseError {
                provider,
                status_code,
                message,
            },
            AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            AnthropicError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<anthropic::ApiError> for LanguageModelCompletionError {
    fn from(error: anthropic::ApiError) -> Self {
        use anthropic::ApiErrorCode::*;
        let provider = ANTHROPIC_PROVIDER_NAME;
        match error.code() {
            Some(code) => match code {
                InvalidRequestError => Self::BadRequestFormat {
                    provider,
                    message: error.message,
                },
                AuthenticationError => Self::AuthenticationError {
                    provider,
                    message: error.message,
                },
                PermissionError => Self::PermissionError {
                    provider,
                    message: error.message,
                },
                NotFoundError => Self::ApiEndpointNotFound { provider },
                RequestTooLarge => Self::PromptTooLarge {
                    tokens: parse_prompt_too_long(&error.message),
                },
                RateLimitError => Self::RateLimitExceeded {
                    provider,
                    retry_after: None,
                },
                ApiError => Self::ApiInternalServerError {
                    provider,
                    message: error.message,
                },
                OverloadedError => Self::ServerOverloaded {
                    provider,
                    retry_after: None,
                },
            },
            None => Self::Other(error.into()),
        }
    }
}

impl From<open_ai::RequestError> for LanguageModelCompletionError {
    fn from(error: open_ai::RequestError) -> Self {
        match error {
            open_ai::RequestError::HttpResponseError {
                provider,
                status_code,
                body,
                headers,
            } => {
                let retry_after = headers
                    .get(http::header::RETRY_AFTER)
                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
                    .map(Duration::from_secs);

                Self::from_http_status(provider.into(), status_code, body, retry_after)
            }
            open_ai::RequestError::Other(e) => Self::Other(e),
        }
    }
}

impl From<OpenRouterError> for LanguageModelCompletionError {
    fn from(error: OpenRouterError) -> Self {
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error {
            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
            OpenRouterError::DeserializeResponse(error) => {
                Self::DeserializeResponse { provider, error }
            }
            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
                provider,
                retry_after: Some(retry_after),
            },
            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
                provider,
                retry_after,
            },
            OpenRouterError::ApiError(api_error) => api_error.into(),
        }
    }
}

impl From<open_router::ApiError> for LanguageModelCompletionError {
    fn from(error: open_router::ApiError) -> Self {
        use open_router::ApiErrorCode::*;
        let provider = LanguageModelProviderName::new("OpenRouter");
        match error.code {
            InvalidRequestError => Self::BadRequestFormat {
                provider,
                message: error.message,
            },
            AuthenticationError => Self::AuthenticationError {
                provider,
                message: error.message,
            },
            PaymentRequiredError => Self::AuthenticationError {
                provider,
                message: format!("Payment required: {}", error.message),
            },
            PermissionError => Self::PermissionError {
                provider,
                message: error.message,
            },
            RequestTimedOut => Self::HttpResponseError {
                provider,
                status_code: StatusCode::REQUEST_TIMEOUT,
                message: error.message,
            },
            RateLimitError => Self::RateLimitExceeded {
                provider,
                retry_after: None,
            },
            ApiError => Self::ApiInternalServerError {
                provider,
                message: error.message,
            },
            OverloadedError => Self::ServerOverloaded {
                provider,
                retry_after: None,
            },
        }
    }
}

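/// The reason a model stopped generating, mirroring the wire format's
/// snake_case strings.
///
/// A minimal sketch of the serialized form (assumes this crate is importable
/// as `language_model`; the exact package name is an assumption):
///
/// ```
/// use language_model::StopReason;
///
/// // `rename_all = "snake_case"` turns `EndTurn` into "end_turn" on the wire.
/// assert_eq!(serde_json::to_string(&StopReason::EndTurn).unwrap(), "\"end_turn\"");
/// assert_eq!(serde_json::to_string(&StopReason::ToolUse).unwrap(), "\"tool_use\"");
/// ```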
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
    Refusal,
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize, Default)]
pub struct TokenUsage {
    #[serde(default, skip_serializing_if = "is_default")]
    pub input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub output_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_creation_input_tokens: u64,
    #[serde(default, skip_serializing_if = "is_default")]
    pub cache_read_input_tokens: u64,
}

impl TokenUsage {
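    /// Sums all four token categories into one total.
    ///
    /// A small illustrative doctest (assumes this crate is importable as
    /// `language_model`; the numbers are made up):
    ///
    /// ```
    /// use language_model::TokenUsage;
    ///
    /// let usage = TokenUsage {
    ///     input_tokens: 100,
    ///     output_tokens: 50,
    ///     cache_creation_input_tokens: 10,
    ///     cache_read_input_tokens: 40,
    /// };
    /// assert_eq!(usage.total_tokens(), 200);
    /// ```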
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens
            + self.output_tokens
            + self.cache_read_input_tokens
            + self.cache_creation_input_tokens
    }
}

impl Add<TokenUsage> for TokenUsage {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens + other.input_tokens,
            output_tokens: self.output_tokens + other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                + other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens + other.cache_read_input_tokens,
        }
    }
}

impl Sub<TokenUsage> for TokenUsage {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self {
            input_tokens: self.input_tokens - other.input_tokens,
            output_tokens: self.output_tokens - other.output_tokens,
            cache_creation_input_tokens: self.cache_creation_input_tokens
                - other.cache_creation_input_tokens,
            cache_read_input_tokens: self.cache_read_input_tokens - other.cache_read_input_tokens,
        }
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: Arc<str>,
    pub raw_input: String,
    pub input: serde_json::Value,
    pub is_input_complete: bool,
    /// Thought signature the model sent us. Some models require that this
    /// signature be preserved and sent back in conversation history for validation.
    pub thought_signature: Option<String>,
}

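/// A text-only view of a completion stream, carrying the message id (when the
/// provider reports one) and a shared slot for the final token usage.
///
/// A minimal sketch of the empty default (assumes this crate is importable as
/// `language_model`; the exact package name is an assumption):
///
/// ```
/// use language_model::LanguageModelTextStream;
///
/// let stream = LanguageModelTextStream::default();
/// assert!(stream.message_id.is_none());
/// assert_eq!(stream.last_token_usage.lock().total_tokens(), 0);
/// ```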
pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String, LanguageModelCompletionError>>,
    // Has complete token usage after the stream has finished
    pub last_token_usage: Arc<Mutex<TokenUsage>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
        }
    }
}

pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        self.provider_id()
    }
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        self.provider_name()
    }

    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Whether this model supports images.
    fn supports_images(&self) -> bool;

    /// Whether this model supports tools.
    fn supports_tools(&self) -> bool;

    /// Whether this model supports choosing which tool to use.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool;

    /// Returns whether this model supports "burn mode".
    fn supports_burn_mode(&self) -> bool {
        false
    }

    /// Returns whether this model or provider supports streaming tool calls.
    fn supports_streaming_tools(&self) -> bool {
        false
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchema
    }

    fn max_token_count(&self) -> u64;
    /// Returns the maximum token count for this model in burn mode (if
    /// `supports_burn_mode` is `false`, this returns `None`).
    fn max_token_count_in_burn_mode(&self) -> Option<u64> {
        None
    }
    fn max_output_tokens(&self) -> Option<u64> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    >;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();
            let mut message_id = None;
            let mut first_item_text = None;
            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map({
                    let last_token_usage = last_token_usage.clone();
                    move |result| {
                        let last_token_usage = last_token_usage.clone();
                        async move {
                            match result {
                                Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Started) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdated { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                                Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) => None,
                                Ok(LanguageModelCompletionEvent::ReasoningDetails(_)) => None,
                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                                Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                    ..
                                }) => None,
                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
                                    *last_token_usage.lock() = token_usage;
                                    None
                                }
                                Err(err) => Some(Err(err)),
                            }
                        }
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream {
                message_id,
                stream,
                last_token_usage,
            })
        }
        .boxed()
    }

    fn stream_completion_tool(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelToolUse, LanguageModelCompletionError>> {
        let future = self.stream_completion(request, cx);

        async move {
            let events = future.await?;
            let mut events = events.fuse();

            // Iterate through events until we find a complete ToolUse
            while let Some(event) = events.next().await {
                match event {
                    Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
                        if tool_use.is_input_complete =>
                    {
                        return Ok(tool_use);
                    }
                    Err(err) => {
                        return Err(err);
                    }
                    _ => {}
                }
            }

            // Stream ended without a complete tool use
            Err(LanguageModelCompletionError::Other(anyhow::anyhow!(
                "Stream ended without receiving a complete tool use"
            )))
        }
        .boxed()
    }

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

pub trait LanguageModelExt: LanguageModel {
    fn max_token_count_for_mode(&self, mode: CompletionMode) -> u64 {
        match mode {
            CompletionMode::Normal => self.max_token_count(),
            CompletionMode::Max => self
                .max_token_count_in_burn_mode()
                .unwrap_or_else(|| self.max_token_count()),
        }
    }
}
impl LanguageModelExt for dyn LanguageModel {}

impl std::fmt::Debug for dyn LanguageModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("<dyn LanguageModel>")
            .field("id", &self.id())
            .field("name", &self.name())
            .field("provider_id", &self.provider_id())
            .field("provider_name", &self.provider_name())
            .field("upstream_provider_name", &self.upstream_provider_name())
            .field("upstream_provider_id", &self.upstream_provider_id())
            .field("supports_streaming_tools", &self.supports_streaming_tools())
            .finish()
    }
}

788
789/// An error that occurred when trying to authenticate the language model provider.
790#[derive(Debug, Error)]
791pub enum AuthenticateError {
792 #[error("connection refused")]
793 ConnectionRefused,
794 #[error("credentials not found")]
795 CredentialsNotFound,
796 #[error(transparent)]
797 Other(#[from] anyhow::Error),
798}
799
800pub trait LanguageModelProvider: 'static {
801 fn id(&self) -> LanguageModelProviderId;
802 fn name(&self) -> LanguageModelProviderName;
803 fn icon(&self) -> IconName {
804 IconName::ZedAssistant
805 }
806 /// Returns the path to an external SVG icon for this provider, if any.
807 /// When present, this takes precedence over `icon()`.
808 fn icon_path(&self) -> Option<SharedString> {
809 None
810 }
811 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
812 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
813 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
814 fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
815 Vec::new()
816 }
817 fn is_authenticated(&self, cx: &App) -> bool;
818 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
819 fn configuration_view(
820 &self,
821 target_agent: ConfigurationViewTargetAgent,
822 window: &mut Window,
823 cx: &mut App,
824 ) -> AnyView;
825 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
826}
827
828#[derive(Default, Clone, PartialEq, Eq)]
829pub enum ConfigurationViewTargetAgent {
830 #[default]
831 ZedAgent,
832 EditPrediction,
833 Other(SharedString),
834}
835
836#[derive(PartialEq, Eq)]
837pub enum LanguageModelProviderTosView {
838 /// When there are some past interactions in the Agent Panel.
839 ThreadEmptyState,
840 /// When there are no past interactions in the Agent Panel.
841 ThreadFreshStart,
842 TextThreadPopup,
843 Configuration,
844}
845
846pub trait LanguageModelProviderState: 'static {
847 type ObservableEntity;
848
849 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
850
851 fn subscribe<T: 'static>(
852 &self,
853 cx: &mut gpui::Context<T>,
854 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
855 ) -> Option<gpui::Subscription> {
856 let entity = self.observable_entity()?;
857 Some(cx.observe(&entity, move |this, _, cx| {
858 callback(this, cx);
859 }))
860 }
861}
862
863#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
864pub struct LanguageModelId(pub SharedString);
865
866#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
867pub struct LanguageModelName(pub SharedString);
868
869#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
870pub struct LanguageModelProviderId(pub SharedString);
871
872#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
873pub struct LanguageModelProviderName(pub SharedString);
874
875impl LanguageModelProviderId {
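    /// Creates a provider id from a static string, usable in `const` contexts.
    ///
    /// A tiny sketch (assumes this crate is importable as `language_model`;
    /// the exact package name is an assumption):
    ///
    /// ```
    /// use language_model::LanguageModelProviderId;
    ///
    /// const EXAMPLE_ID: LanguageModelProviderId = LanguageModelProviderId::new("example");
    /// assert_eq!(EXAMPLE_ID.to_string(), "example");
    /// ```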
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl LanguageModelProviderName {
    pub const fn new(id: &'static str) -> Self {
        Self(SharedString::new_static(id))
    }
}

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl fmt::Display for LanguageModelProviderName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderId {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<Arc<str>> for LanguageModelProviderName {
    fn from(value: Arc<str>) -> Self {
        Self(SharedString::from(value))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_cloud_failure_with_upstream_http_error() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(message, "Internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_from_cloud_failure_with_standard_format() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_503".to_string(),
            "Service unavailable".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
        }
    }

    #[test]
    fn test_upstream_http_error_connection_timeout() {
        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider.0, "anthropic");
            }
            _ => panic!(
                "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                error
            ),
        }

        let error = LanguageModelCompletionError::from_cloud_failure(
            String::from("anthropic").into(),
            "upstream_http_error".to_string(),
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider.0, "anthropic");
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected ApiInternalServerError for connection timeout with 500 status, got: {:?}",
                error
            ),
        }
    }

    #[test]
    fn test_language_model_tool_use_serializes_with_signature() {
        use serde_json::json;

        let tool_use = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature".to_string()),
        };

        let serialized = serde_json::to_value(&tool_use).unwrap();

        assert_eq!(serialized["id"], "test_id");
        assert_eq!(serialized["name"], "test_tool");
        assert_eq!(serialized["thought_signature"], "test_signature");
    }

    #[test]
    fn test_language_model_tool_use_deserializes_with_missing_signature() {
        use serde_json::json;

        let json = json!({
            "id": "test_id",
            "name": "test_tool",
            "raw_input": "{\"arg\":\"value\"}",
            "input": {"arg": "value"},
            "is_input_complete": true
        });

        let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap();

        assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id"));
        assert_eq!(tool_use.name.as_ref(), "test_tool");
        assert_eq!(tool_use.thought_signature, None);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_with_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("round_trip_id"),
            name: "round_trip_tool".into(),
            raw_input: json!({"key": "value"}).to_string(),
            input: json!({"key": "value"}),
            is_input_complete: true,
            thought_signature: Some("round_trip_sig".to_string()),
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, original.thought_signature);
    }

    #[test]
    fn test_language_model_tool_use_round_trip_without_signature() {
        use serde_json::json;

        let original = LanguageModelToolUse {
            id: LanguageModelToolUseId::from("no_sig_id"),
            name: "no_sig_tool".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let serialized = serde_json::to_value(&original).unwrap();
        let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap();

        assert_eq!(deserialized.id, original.id);
        assert_eq!(deserialized.name, original.name);
        assert_eq!(deserialized.thought_signature, None);
    }
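
    // Two additional sketches (not from the original suite): one exercising the
    // Add/Sub arithmetic on TokenUsage, one checking the nonstandard 529 status
    // mapping in from_http_status. The numbers are arbitrary.
    #[test]
    fn test_token_usage_add_and_sub() {
        let a = TokenUsage {
            input_tokens: 10,
            output_tokens: 5,
            cache_creation_input_tokens: 2,
            cache_read_input_tokens: 3,
        };
        let b = TokenUsage {
            input_tokens: 1,
            output_tokens: 1,
            cache_creation_input_tokens: 1,
            cache_read_input_tokens: 1,
        };

        let sum = a + b;
        assert_eq!(sum.total_tokens(), 24);
        assert_eq!(sum - b, a);
    }

    #[test]
    fn test_from_http_status_529_maps_to_server_overloaded() {
        // 529 has no StatusCode constant, so from_http_status matches it by value.
        let status = StatusCode::from_u16(529).unwrap();
        let error = LanguageModelCompletionError::from_http_status(
            ANTHROPIC_PROVIDER_NAME,
            status,
            "overloaded".to_string(),
            None,
        );

        match error {
            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
                assert_eq!(provider, ANTHROPIC_PROVIDER_NAME);
            }
            _ => panic!("expected ServerOverloaded for status 529, got: {:?}", error),
        }
    }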
}