1use ai_onboarding::YoungAccountBanner;
2use anthropic::AnthropicModelMode;
3use anyhow::{Context as _, Result, anyhow};
4use client::{Client, UserStore, zed_urls};
5use cloud_api_types::{OrganizationId, Plan};
6use cloud_llm_client::{
7 CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
8 CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, CompletionEvent, CompletionRequestStatus,
9 CountTokensBody, CountTokensResponse, ListModelsResponse,
10 SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
11};
12use futures::{
13 AsyncBufReadExt, FutureExt, Stream, StreamExt,
14 future::BoxFuture,
15 stream::{self, BoxStream},
16};
17use google_ai::GoogleModelMode;
18use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
19use http_client::http::{HeaderMap, HeaderValue};
20use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Response, StatusCode};
21use language_model::{
22 AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCacheConfiguration,
23 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
24 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
25 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
26 LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
27 PaymentRequiredError, RateLimiter, RefreshLlmTokenListener,
28};
29use release_channel::AppVersion;
30use schemars::JsonSchema;
31use semver::Version;
32use serde::{Deserialize, Serialize, de::DeserializeOwned};
33use settings::SettingsStore;
34pub use settings::ZedDotDevAvailableModel as AvailableModel;
35pub use settings::ZedDotDevAvailableProvider as AvailableProvider;
36use smol::io::{AsyncReadExt, BufReader};
37use std::collections::VecDeque;
38use std::pin::Pin;
39use std::str::FromStr;
40use std::sync::Arc;
41use std::task::Poll;
42use std::time::Duration;
43use thiserror::Error;
44use ui::{TintColor, prelude::*};
45
46use crate::provider::anthropic::{
47 AnthropicEventMapper, count_anthropic_tokens_with_tiktoken, into_anthropic,
48};
49use crate::provider::google::{GoogleEventMapper, into_google};
50use crate::provider::open_ai::{
51 OpenAiEventMapper, OpenAiResponseEventMapper, count_open_ai_tokens, into_open_ai,
52 into_open_ai_response,
53};
54use crate::provider::x_ai::count_xai_tokens;
55
/// Provider id under which the Zed-hosted models are registered (`language_model::ZED_CLOUD_PROVIDER_ID`).
const PROVIDER_ID: LanguageModelProviderId = language_model::ZED_CLOUD_PROVIDER_ID;
/// Provider name reported for the Zed-hosted models (`language_model::ZED_CLOUD_PROVIDER_NAME`).
const PROVIDER_NAME: LanguageModelProviderName = language_model::ZED_CLOUD_PROVIDER_NAME;
58
/// Settings for the Zed cloud ("zed.dev") language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Additional models declared in settings for this provider.
    pub available_models: Vec<AvailableModel>,
}
/// Reasoning mode requested for a model.
///
/// Serialized as an internally tagged object keyed by `"type"` with lowercase variant names,
/// e.g. `{"type": "default"}` or `{"type": "thinking", "budget_tokens": 4096}`.
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Regular completion without an explicit reasoning phase.
    #[default]
    Default,
    /// Extended reasoning ("thinking") mode.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
73
74impl From<ModelMode> for AnthropicModelMode {
75 fn from(value: ModelMode) -> Self {
76 match value {
77 ModelMode::Default => AnthropicModelMode::Default,
78 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
79 }
80 }
81}
82
/// Language model provider backed by Zed's hosted (cloud) LLM service.
pub struct CloudLanguageModelProvider {
    /// Client used for authentication and for building the models handed out by this provider.
    client: Arc<Client>,
    /// Shared, observable provider state (model list, connection status, tokens).
    state: Entity<State>,
    /// Background task mirroring the client's connection status into `state`; kept alive for
    /// as long as the provider exists.
    _maintain_client_status: Task<()>,
}
88
/// Observable state of the cloud provider: the model list fetched from the server and the
/// handles needed to fetch it again.
pub struct State {
    client: Arc<Client>,
    /// Global LLM API token used to authenticate requests to the cloud LLM endpoints.
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    /// Last observed client connection status.
    status: client::Status,
    /// All models returned by the most recent `/models` response.
    models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    /// Entry of `models` matching the server-designated default model, if any.
    default_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Entry of `models` matching the server-designated default fast model, if any.
    default_fast_model: Option<Arc<cloud_llm_client::LanguageModel>>,
    /// Entries of `models` listed as recommended, in the server's order.
    recommended_models: Vec<Arc<cloud_llm_client::LanguageModel>>,
    _user_store_subscription: Subscription,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
102
103impl State {
104 fn new(
105 client: Arc<Client>,
106 user_store: Entity<UserStore>,
107 status: client::Status,
108 cx: &mut Context<Self>,
109 ) -> Self {
110 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
111 let llm_api_token = LlmApiToken::global(cx);
112 Self {
113 client: client.clone(),
114 llm_api_token,
115 user_store: user_store.clone(),
116 status,
117 models: Vec::new(),
118 default_model: None,
119 default_fast_model: None,
120 recommended_models: Vec::new(),
121 _user_store_subscription: cx.subscribe(
122 &user_store,
123 move |this, _user_store, event, cx| match event {
124 client::user::Event::PrivateUserInfoUpdated => {
125 let status = *client.status().borrow();
126 if status.is_signed_out() {
127 return;
128 }
129
130 let client = this.client.clone();
131 let llm_api_token = this.llm_api_token.clone();
132 let organization_id = this
133 .user_store
134 .read(cx)
135 .current_organization()
136 .map(|organization| organization.id.clone());
137 cx.spawn(async move |this, cx| {
138 let response =
139 Self::fetch_models(client, llm_api_token, organization_id).await?;
140 this.update(cx, |this, cx| this.update_models(response, cx))
141 })
142 .detach_and_log_err(cx);
143 }
144 _ => {}
145 },
146 ),
147 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
148 cx.notify();
149 }),
150 _llm_token_subscription: cx.subscribe(
151 &refresh_llm_token_listener,
152 move |this, _listener, _event, cx| {
153 let client = this.client.clone();
154 let llm_api_token = this.llm_api_token.clone();
155 let organization_id = this
156 .user_store
157 .read(cx)
158 .current_organization()
159 .map(|organization| organization.id.clone());
160 cx.spawn(async move |this, cx| {
161 let response =
162 Self::fetch_models(client, llm_api_token, organization_id).await?;
163 this.update(cx, |this, cx| {
164 this.update_models(response, cx);
165 })
166 })
167 .detach_and_log_err(cx);
168 },
169 ),
170 }
171 }
172
173 fn is_signed_out(&self, cx: &App) -> bool {
174 self.user_store.read(cx).current_user().is_none()
175 }
176
177 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
178 let client = self.client.clone();
179 cx.spawn(async move |state, cx| {
180 client.sign_in_with_optional_connect(true, cx).await?;
181 state.update(cx, |_, cx| cx.notify())
182 })
183 }
184
185 fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
186 let mut models = Vec::new();
187
188 for model in response.models {
189 models.push(Arc::new(model.clone()));
190 }
191
192 self.default_model = models
193 .iter()
194 .find(|model| {
195 response
196 .default_model
197 .as_ref()
198 .is_some_and(|default_model_id| &model.id == default_model_id)
199 })
200 .cloned();
201 self.default_fast_model = models
202 .iter()
203 .find(|model| {
204 response
205 .default_fast_model
206 .as_ref()
207 .is_some_and(|default_fast_model_id| &model.id == default_fast_model_id)
208 })
209 .cloned();
210 self.recommended_models = response
211 .recommended_models
212 .iter()
213 .filter_map(|id| models.iter().find(|model| &model.id == id))
214 .cloned()
215 .collect();
216 self.models = models;
217 cx.notify();
218 }
219
220 async fn fetch_models(
221 client: Arc<Client>,
222 llm_api_token: LlmApiToken,
223 organization_id: Option<OrganizationId>,
224 ) -> Result<ListModelsResponse> {
225 let http_client = &client.http_client();
226 let token = llm_api_token.acquire(&client, organization_id).await?;
227
228 let request = http_client::Request::builder()
229 .method(Method::GET)
230 .header(CLIENT_SUPPORTS_X_AI_HEADER_NAME, "true")
231 .uri(http_client.build_zed_llm_url("/models", &[])?.as_ref())
232 .header("Authorization", format!("Bearer {token}"))
233 .body(AsyncBody::empty())?;
234 let mut response = http_client
235 .send(request)
236 .await
237 .context("failed to send list models request")?;
238
239 if response.status().is_success() {
240 let mut body = String::new();
241 response.body_mut().read_to_string(&mut body).await?;
242 Ok(serde_json::from_str(&body)?)
243 } else {
244 let mut body = String::new();
245 response.body_mut().read_to_string(&mut body).await?;
246 anyhow::bail!(
247 "error listing models.\nStatus: {:?}\nBody: {body}",
248 response.status(),
249 );
250 }
251 }
252}
253
254impl CloudLanguageModelProvider {
255 pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
256 let mut status_rx = client.status();
257 let status = *status_rx.borrow();
258
259 let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
260
261 let state_ref = state.downgrade();
262 let maintain_client_status = cx.spawn(async move |cx| {
263 while let Some(status) = status_rx.next().await {
264 if let Some(this) = state_ref.upgrade() {
265 _ = this.update(cx, |this, cx| {
266 if this.status != status {
267 this.status = status;
268 cx.notify();
269 }
270 });
271 } else {
272 break;
273 }
274 }
275 });
276
277 Self {
278 client,
279 state,
280 _maintain_client_status: maintain_client_status,
281 }
282 }
283
284 fn create_language_model(
285 &self,
286 model: Arc<cloud_llm_client::LanguageModel>,
287 llm_api_token: LlmApiToken,
288 user_store: Entity<UserStore>,
289 ) -> Arc<dyn LanguageModel> {
290 Arc::new(CloudLanguageModel {
291 id: LanguageModelId(SharedString::from(model.id.0.clone())),
292 model,
293 llm_api_token,
294 user_store,
295 client: self.client.clone(),
296 request_limiter: RateLimiter::new(4),
297 })
298 }
299}
300
301impl LanguageModelProviderState for CloudLanguageModelProvider {
302 type ObservableEntity = State;
303
304 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
305 Some(self.state.clone())
306 }
307}
308
309impl LanguageModelProvider for CloudLanguageModelProvider {
310 fn id(&self) -> LanguageModelProviderId {
311 PROVIDER_ID
312 }
313
314 fn name(&self) -> LanguageModelProviderName {
315 PROVIDER_NAME
316 }
317
318 fn icon(&self) -> IconOrSvg {
319 IconOrSvg::Icon(IconName::AiZed)
320 }
321
322 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
323 let state = self.state.read(cx);
324 let default_model = state.default_model.clone()?;
325 let llm_api_token = state.llm_api_token.clone();
326 let user_store = state.user_store.clone();
327 Some(self.create_language_model(default_model, llm_api_token, user_store))
328 }
329
330 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
331 let state = self.state.read(cx);
332 let default_fast_model = state.default_fast_model.clone()?;
333 let llm_api_token = state.llm_api_token.clone();
334 let user_store = state.user_store.clone();
335 Some(self.create_language_model(default_fast_model, llm_api_token, user_store))
336 }
337
338 fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
339 let state = self.state.read(cx);
340 let llm_api_token = state.llm_api_token.clone();
341 let user_store = state.user_store.clone();
342 state
343 .recommended_models
344 .iter()
345 .cloned()
346 .map(|model| {
347 self.create_language_model(model, llm_api_token.clone(), user_store.clone())
348 })
349 .collect()
350 }
351
352 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
353 let state = self.state.read(cx);
354 let llm_api_token = state.llm_api_token.clone();
355 let user_store = state.user_store.clone();
356 state
357 .models
358 .iter()
359 .cloned()
360 .map(|model| {
361 self.create_language_model(model, llm_api_token.clone(), user_store.clone())
362 })
363 .collect()
364 }
365
366 fn is_authenticated(&self, cx: &App) -> bool {
367 let state = self.state.read(cx);
368 !state.is_signed_out(cx)
369 }
370
371 fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
372 Task::ready(Ok(()))
373 }
374
375 fn configuration_view(
376 &self,
377 _target_agent: language_model::ConfigurationViewTargetAgent,
378 _: &mut Window,
379 cx: &mut App,
380 ) -> AnyView {
381 cx.new(|_| ConfigurationView::new(self.state.clone()))
382 .into()
383 }
384
385 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
386 Task::ready(Ok(()))
387 }
388}
389
/// A single model served through Zed's cloud LLM service.
pub struct CloudLanguageModel {
    /// Identifier exposed through [`LanguageModel::id`].
    id: LanguageModelId,
    /// Metadata for this model as reported by the cloud `/models` endpoint.
    model: Arc<cloud_llm_client::LanguageModel>,
    /// Token used to authenticate requests to the cloud LLM endpoints.
    llm_api_token: LlmApiToken,
    /// Used to look up the current organization when issuing requests.
    user_store: Entity<UserStore>,
    client: Arc<Client>,
    /// Throttles completion requests issued through this model instance
    /// (presumably a cap on concurrent in-flight requests — confirm against `RateLimiter`).
    request_limiter: RateLimiter,
}
398
/// Successful response from the cloud `/completions` endpoint.
struct PerformLlmCompletionResponse {
    /// Raw HTTP response whose body carries newline-delimited JSON events.
    response: Response<AsyncBody>,
    /// Whether the server sent `SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME`, i.e. whether
    /// each body line is a full `CompletionEvent` rather than a bare provider event.
    includes_status_messages: bool,
}
403
404impl CloudLanguageModel {
405 async fn perform_llm_completion(
406 client: Arc<Client>,
407 llm_api_token: LlmApiToken,
408 organization_id: Option<OrganizationId>,
409 app_version: Option<Version>,
410 body: CompletionBody,
411 ) -> Result<PerformLlmCompletionResponse> {
412 let http_client = &client.http_client();
413
414 let mut token = llm_api_token
415 .acquire(&client, organization_id.clone())
416 .await?;
417 let mut refreshed_token = false;
418
419 loop {
420 let request = http_client::Request::builder()
421 .method(Method::POST)
422 .uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
423 .when_some(app_version.as_ref(), |builder, app_version| {
424 builder.header(ZED_VERSION_HEADER_NAME, app_version.to_string())
425 })
426 .header("Content-Type", "application/json")
427 .header("Authorization", format!("Bearer {token}"))
428 .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
429 .header(CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, "true")
430 .body(serde_json::to_string(&body)?.into())?;
431
432 let mut response = http_client.send(request).await?;
433 let status = response.status();
434 if status.is_success() {
435 let includes_status_messages = response
436 .headers()
437 .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
438 .is_some();
439
440 return Ok(PerformLlmCompletionResponse {
441 response,
442 includes_status_messages,
443 });
444 }
445
446 if !refreshed_token && response.needs_llm_token_refresh() {
447 token = llm_api_token
448 .refresh(&client, organization_id.clone())
449 .await?;
450 refreshed_token = true;
451 continue;
452 }
453
454 if status == StatusCode::PAYMENT_REQUIRED {
455 return Err(anyhow!(PaymentRequiredError));
456 }
457
458 let mut body = String::new();
459 let headers = response.headers().clone();
460 response.body_mut().read_to_string(&mut body).await?;
461 return Err(anyhow!(ApiError {
462 status,
463 body,
464 headers
465 }));
466 }
467 }
468}
469
/// A non-success HTTP response from Zed's cloud API, captured verbatim.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    /// HTTP status of the cloud response.
    status: StatusCode,
    /// Raw response body; may contain a JSON [`CloudApiError`].
    body: String,
    /// Response headers captured alongside the status and body.
    headers: HeaderMap<HeaderValue>,
}
477
/// Represents error responses from Zed's cloud API.
///
/// Example JSON for an upstream HTTP error:
/// ```json
/// {
///   "code": "upstream_http_error",
///   "message": "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout",
///   "upstream_status": 503
/// }
/// ```
#[derive(Debug, serde::Deserialize)]
struct CloudApiError {
    /// Machine-readable error code; codes starting with `upstream_http_` denote failures
    /// reported by the upstream model provider.
    code: String,
    /// Human-readable error description.
    message: String,
    /// Status reported by the upstream provider; absent or invalid codes become `None`.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_optional_status_code")]
    upstream_status: Option<StatusCode>,
    /// Suggested retry delay, interpreted as seconds when converted to a `Duration`.
    #[serde(default)]
    retry_after: Option<f64>,
}
498
499fn deserialize_optional_status_code<'de, D>(deserializer: D) -> Result<Option<StatusCode>, D::Error>
500where
501 D: serde::Deserializer<'de>,
502{
503 let opt: Option<u16> = Option::deserialize(deserializer)?;
504 Ok(opt.and_then(|code| StatusCode::from_u16(code).ok()))
505}
506
507impl From<ApiError> for LanguageModelCompletionError {
508 fn from(error: ApiError) -> Self {
509 if let Ok(cloud_error) = serde_json::from_str::<CloudApiError>(&error.body) {
510 if cloud_error.code.starts_with("upstream_http_") {
511 let status = if let Some(status) = cloud_error.upstream_status {
512 status
513 } else if cloud_error.code.ends_with("_error") {
514 error.status
515 } else {
516 // If there's a status code in the code string (e.g. "upstream_http_429")
517 // then use that; otherwise, see if the JSON contains a status code.
518 cloud_error
519 .code
520 .strip_prefix("upstream_http_")
521 .and_then(|code_str| code_str.parse::<u16>().ok())
522 .and_then(|code| StatusCode::from_u16(code).ok())
523 .unwrap_or(error.status)
524 };
525
526 return LanguageModelCompletionError::UpstreamProviderError {
527 message: cloud_error.message,
528 status,
529 retry_after: cloud_error.retry_after.map(Duration::from_secs_f64),
530 };
531 }
532
533 return LanguageModelCompletionError::from_http_status(
534 PROVIDER_NAME,
535 error.status,
536 cloud_error.message,
537 None,
538 );
539 }
540
541 let retry_after = None;
542 LanguageModelCompletionError::from_http_status(
543 PROVIDER_NAME,
544 error.status,
545 error.body,
546 retry_after,
547 )
548 }
549}
550
/// [`LanguageModel`] implementation for a model served through Zed's cloud.
///
/// Capability queries are answered from the cloud-provided model metadata. Completions are
/// converted into the upstream provider's request format and sent to the cloud
/// `/completions` endpoint through `perform_llm_completion`.
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    /// Display name as reported by the cloud model metadata.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name.clone())
    }

    /// Always the Zed cloud provider, regardless of which upstream serves the model.
    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    /// Id of the provider that actually serves this model behind Zed's cloud.
    fn upstream_provider_id(&self) -> LanguageModelProviderId {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_ID,
            OpenAi => language_model::OPEN_AI_PROVIDER_ID,
            Google => language_model::GOOGLE_PROVIDER_ID,
            XAi => language_model::X_AI_PROVIDER_ID,
        }
    }

    /// Name of the provider that actually serves this model behind Zed's cloud.
    fn upstream_provider_name(&self) -> LanguageModelProviderName {
        use cloud_llm_client::LanguageModelProvider::*;
        match self.model.provider {
            Anthropic => language_model::ANTHROPIC_PROVIDER_NAME,
            OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
            Google => language_model::GOOGLE_PROVIDER_NAME,
            XAi => language_model::X_AI_PROVIDER_NAME,
        }
    }

    // The capability flags below mirror the cloud model metadata verbatim.

    fn is_latest(&self) -> bool {
        self.model.is_latest
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images
    }

    fn supports_thinking(&self) -> bool {
        self.model.supports_thinking
    }

    fn supports_fast_mode(&self) -> bool {
        self.model.supports_fast_mode
    }

    /// Effort levels advertised by the cloud metadata; a missing `is_default` counts as `false`.
    fn supported_effort_levels(&self) -> Vec<LanguageModelEffortLevel> {
        self.model
            .supported_effort_levels
            .iter()
            .map(|effort_level| LanguageModelEffortLevel {
                name: effort_level.name.clone().into(),
                value: effort_level.value.clone().into(),
                is_default: effort_level.is_default.unwrap_or(false),
            })
            .collect()
    }

    fn supports_streaming_tools(&self) -> bool {
        self.model.supports_streaming_tools
    }

    /// Every tool-choice mode is accepted for cloud models.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    /// Split token display is enabled only for OpenAI- and xAI-backed models.
    fn supports_split_token_display(&self) -> bool {
        use cloud_llm_client::LanguageModelProvider::*;
        matches!(self.model.provider, OpenAi | XAi)
    }

    /// Telemetry id of the form `zed.dev/<model id>`.
    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id)
    }

    /// Anthropic and OpenAI accept full JSON Schema; Google and xAI get the restricted subset.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic
            | cloud_llm_client::LanguageModelProvider::OpenAi => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            cloud_llm_client::LanguageModelProvider::Google
            | cloud_llm_client::LanguageModelProvider::XAi => {
                LanguageModelToolSchemaFormat::JsonSchemaSubset
            }
        }
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count as u64
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens as u64)
    }

    /// Prompt caching is configured only for Anthropic-backed models.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                Some(LanguageModelCacheConfiguration {
                    min_total_token: 2_048,
                    should_speculate: true,
                    max_cache_anchors: 4,
                })
            }
            cloud_llm_client::LanguageModelProvider::OpenAi
            | cloud_llm_client::LanguageModelProvider::XAi
            | cloud_llm_client::LanguageModelProvider::Google => None,
        }
    }

    /// Counts the tokens in `request`.
    ///
    /// Anthropic, OpenAI and xAI models are counted locally (Anthropic via tiktoken on a
    /// background task); an OpenAI/xAI model id the local crate does not recognise yields an
    /// error. Google models are counted by the cloud `/count_tokens` endpoint.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => cx
                .background_spawn(async move { count_anthropic_tokens_with_tiktoken(request) })
                .boxed(),
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let model = match open_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_open_ai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let model = match x_ai::Model::from_id(&self.model.id.0) {
                    Ok(model) => model,
                    Err(err) => return async move { Err(anyhow!(err)) }.boxed(),
                };
                count_xai_tokens(request, model, cx)
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = self
                    .user_store
                    .read(cx)
                    .current_organization()
                    .map(|organization| organization.id.clone());
                let model_id = self.model.id.to_string();
                let generate_content_request =
                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client, organization_id).await?;

                    let request_body = CountTokensBody {
                        provider: cloud_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = http_client::Request::builder()
                        .method(Method::POST)
                        .uri(
                            http_client
                                .build_zed_llm_url("/count_tokens", &[])?
                                .as_ref(),
                        )
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let headers = response.headers().clone();
                    // Read the body up front: it is parsed on success and kept for the error otherwise.
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens as u64)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body,
                            headers
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    /// Streams a completion for `request` through Zed's cloud.
    ///
    /// The request is converted into the upstream provider's format, sent via
    /// `perform_llm_completion` under this model's request limiter, and the returned lines are
    /// decoded by the provider-specific event mapper.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        // Request metadata forwarded to the cloud alongside the provider payload.
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let intent = request.intent;
        let app_version = Some(cx.update(|cx| AppVersion::global(cx)));
        let user_store = self.user_store.clone();
        let organization_id = cx.update(|cx| {
            user_store
                .read(cx)
                .current_organization()
                .map(|organization| organization.id.clone())
        });
        // Thinking is used only when the request allows it and the model supports it.
        let thinking_allowed = request.thinking_allowed;
        let enable_thinking = thinking_allowed && self.model.supports_thinking;
        let provider_name = provider_name(&self.model.provider);
        match self.model.provider {
            cloud_llm_client::LanguageModelProvider::Anthropic => {
                // Unrecognised effort strings are ignored.
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| anthropic::Effort::from_str(effort).ok());

                let mut request = into_anthropic(
                    request,
                    self.model.id.to_string(),
                    1.0,
                    self.model.max_output_tokens as u64,
                    if enable_thinking {
                        AnthropicModelMode::Thinking {
                            budget_tokens: Some(4_096),
                        }
                    } else {
                        AnthropicModelMode::Default
                    },
                );

                // With an explicit effort, switch from the fixed thinking budget to adaptive
                // thinking driven by that effort.
                if enable_thinking && effort.is_some() {
                    request.thinking = Some(anthropic::Thinking::Adaptive);
                    request.output_config = Some(anthropic::OutputConfig { effort });
                }

                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await
                    // Convert a raw `ApiError` into a structured completion error; any other
                    // error passes through unchanged.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => anyhow!(LanguageModelCompletionError::from(api_err)),
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::OpenAi => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                // Unrecognised effort strings are ignored.
                let effort = request
                    .thinking_effort
                    .as_ref()
                    .and_then(|effort| open_ai::ReasoningEffort::from_str(effort).ok());

                // OpenAI models go through the Responses API request format.
                let mut request = into_open_ai_response(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    true,
                    None,
                    None,
                );

                if enable_thinking && let Some(effort) = effort {
                    request.reasoning = Some(open_ai::responses::ReasoningConfig {
                        effort,
                        summary: Some(open_ai::responses::ReasoningSummaryMode::Auto),
                    });
                }

                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiResponseEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::XAi => {
                let client = self.client.clone();
                // xAI models use the OpenAI chat-completions request format.
                let request = into_open_ai(
                    request,
                    &self.model.id.0,
                    self.model.supports_parallel_tool_calls,
                    false,
                    None,
                    None,
                );
                let llm_api_token = self.llm_api_token.clone();
                let organization_id = organization_id.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::XAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            cloud_llm_client::LanguageModelProvider::Google => {
                let client = self.client.clone();
                let request =
                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        includes_status_messages,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        organization_id,
                        app_version,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            intent,
                            provider: cloud_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)
                                .map_err(|e| anyhow!(e))?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(response_lines(response, includes_status_messages)),
                        &provider_name,
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}
982
/// Adapts a stream of cloud [`CompletionEvent`]s into language-model completion events.
///
/// - Provider events are passed to `map_callback`, which may yield zero or more results.
/// - Status messages are converted via
///   [`LanguageModelCompletionEvent::from_completion_request_status`].
/// - Transport and parse errors are converted into [`LanguageModelCompletionError`]s.
///
/// If the inner stream ends without a `StreamEnded` status, one final
/// [`LanguageModelCompletionError::StreamEndedUnexpectedly`] is emitted before the stream
/// terminates.
fn map_cloud_completion_events<T, F>(
    stream: Pin<Box<dyn Stream<Item = Result<CompletionEvent<T>>> + Send>>,
    provider: &LanguageModelProviderName,
    mut map_callback: F,
) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
where
    T: DeserializeOwned + 'static,
    F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
        + Send
        + 'static,
{
    let provider = provider.clone();
    let mut stream = stream.fuse();

    // Set once the server explicitly signals the end of the stream.
    // NOTE(review): when the server did not advertise status messages, lines are parsed as
    // bare events and no `StreamEnded` status can arrive, so such streams always finish with
    // `StreamEndedUnexpectedly` — confirm the server always sends status messages.
    let mut saw_stream_ended = false;

    // `done`: the inner stream is exhausted. `pending`: mapped items not yet yielded; a single
    // inner event may expand into several outputs.
    let mut done = false;
    let mut pending = VecDeque::new();

    stream::poll_fn(move |cx| {
        loop {
            // Drain buffered items before polling the inner stream again.
            if let Some(item) = pending.pop_front() {
                return Poll::Ready(Some(item));
            }

            if done {
                return Poll::Ready(None);
            }

            match stream.poll_next_unpin(cx) {
                Poll::Ready(Some(event)) => {
                    let items = match event {
                        Err(error) => {
                            vec![Err(LanguageModelCompletionError::from(error))]
                        }
                        Ok(CompletionEvent::Status(CompletionRequestStatus::StreamEnded)) => {
                            saw_stream_ended = true;
                            vec![]
                        }
                        Ok(CompletionEvent::Status(status)) => {
                            LanguageModelCompletionEvent::from_completion_request_status(
                                status,
                                provider.clone(),
                            )
                            .transpose()
                            .map(|event| vec![event])
                            .unwrap_or_default()
                        }
                        Ok(CompletionEvent::Event(event)) => map_callback(event),
                    };
                    pending.extend(items);
                }
                Poll::Ready(None) => {
                    done = true;

                    // Surface a truncated stream as an error instead of ending silently.
                    if !saw_stream_ended {
                        return Poll::Ready(Some(Err(
                            LanguageModelCompletionError::StreamEndedUnexpectedly {
                                provider: provider.clone(),
                            },
                        )));
                    }
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    })
    .boxed()
}
1052
1053fn provider_name(provider: &cloud_llm_client::LanguageModelProvider) -> LanguageModelProviderName {
1054 match provider {
1055 cloud_llm_client::LanguageModelProvider::Anthropic => {
1056 language_model::ANTHROPIC_PROVIDER_NAME
1057 }
1058 cloud_llm_client::LanguageModelProvider::OpenAi => language_model::OPEN_AI_PROVIDER_NAME,
1059 cloud_llm_client::LanguageModelProvider::Google => language_model::GOOGLE_PROVIDER_NAME,
1060 cloud_llm_client::LanguageModelProvider::XAi => language_model::X_AI_PROVIDER_NAME,
1061 }
1062}
1063
1064fn response_lines<T: DeserializeOwned>(
1065 response: Response<AsyncBody>,
1066 includes_status_messages: bool,
1067) -> impl Stream<Item = Result<CompletionEvent<T>>> {
1068 futures::stream::try_unfold(
1069 (String::new(), BufReader::new(response.into_body())),
1070 move |(mut line, mut body)| async move {
1071 match body.read_line(&mut line).await {
1072 Ok(0) => Ok(None),
1073 Ok(_) => {
1074 let event = if includes_status_messages {
1075 serde_json::from_str::<CompletionEvent<T>>(&line)?
1076 } else {
1077 CompletionEvent::Event(serde_json::from_str::<T>(&line)?)
1078 };
1079
1080 line.clear();
1081 Ok(Some((event, (line, body))))
1082 }
1083 Err(e) => Err(e.into()),
1084 }
1085 },
1086 )
1087}
1088
/// Inputs for rendering the Zed AI section of the provider configuration UI.
///
/// When `is_connected` is false only a sign-in prompt is rendered. Otherwise, a young-account
/// banner is shown if `account_too_young` is set; failing that, a plan-specific description
/// followed by a single subscription call-to-action button.
#[derive(IntoElement, RegisterComponent)]
struct ZedAiConfiguration {
    /// Whether the user is signed in; when false, only the sign-in prompt is rendered.
    is_connected: bool,
    /// The user's plan, which selects the description text and the call-to-action button.
    plan: Option<Plan>,
    /// Whether the user can still start a Pro trial; selects trial-specific copy and button.
    eligible_for_trial: bool,
    /// When true, the plan UI is replaced by the young-account banner and an upgrade button.
    account_too_young: bool,
    /// Invoked by the "Sign In to use Zed AI" button.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1097
1098impl RenderOnce for ZedAiConfiguration {
1099 fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
1100 let (subscription_text, has_paid_plan) = match self.plan {
1101 Some(Plan::ZedPro) => (
1102 "You have access to Zed's hosted models through your Pro subscription.",
1103 true,
1104 ),
1105 Some(Plan::ZedProTrial) => (
1106 "You have access to Zed's hosted models through your Pro trial.",
1107 false,
1108 ),
1109 Some(Plan::ZedStudent) => (
1110 "You have access to Zed's hosted models through your Student subscription.",
1111 true,
1112 ),
1113 Some(Plan::ZedBusiness) => (
1114 "You have access to Zed's hosted models through your Organization.",
1115 true,
1116 ),
1117 Some(Plan::ZedFree) | None => (
1118 if self.eligible_for_trial {
1119 "Subscribe for access to Zed's hosted models. Start with a 14 day free trial."
1120 } else {
1121 "Subscribe for access to Zed's hosted models."
1122 },
1123 false,
1124 ),
1125 };
1126
1127 let manage_subscription_buttons = if has_paid_plan {
1128 Button::new("manage_settings", "Manage Subscription")
1129 .full_width()
1130 .label_size(LabelSize::Small)
1131 .style(ButtonStyle::Tinted(TintColor::Accent))
1132 .on_click(|_, _, cx| cx.open_url(&zed_urls::account_url(cx)))
1133 .into_any_element()
1134 } else if self.plan.is_none() || self.eligible_for_trial {
1135 Button::new("start_trial", "Start 14-day Free Pro Trial")
1136 .full_width()
1137 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1138 .on_click(|_, _, cx| cx.open_url(&zed_urls::start_trial_url(cx)))
1139 .into_any_element()
1140 } else {
1141 Button::new("upgrade", "Upgrade to Pro")
1142 .full_width()
1143 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1144 .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx)))
1145 .into_any_element()
1146 };
1147
1148 if !self.is_connected {
1149 return v_flex()
1150 .gap_2()
1151 .child(Label::new("Sign in to have access to Zed's complete agentic experience with hosted models."))
1152 .child(
1153 Button::new("sign_in", "Sign In to use Zed AI")
1154 .start_icon(Icon::new(IconName::Github).size(IconSize::Small).color(Color::Muted))
1155 .full_width()
1156 .on_click({
1157 let callback = self.sign_in_callback.clone();
1158 move |_, window, cx| (callback)(window, cx)
1159 }),
1160 );
1161 }
1162
1163 v_flex().gap_2().w_full().map(|this| {
1164 if self.account_too_young {
1165 this.child(YoungAccountBanner).child(
1166 Button::new("upgrade", "Upgrade to Pro")
1167 .style(ui::ButtonStyle::Tinted(ui::TintColor::Accent))
1168 .full_width()
1169 .on_click(|_, _, cx| cx.open_url(&zed_urls::upgrade_to_zed_pro_url(cx))),
1170 )
1171 } else {
1172 this.text_sm()
1173 .child(subscription_text)
1174 .child(manage_subscription_buttons)
1175 }
1176 })
1177 }
1178}
1179
/// Configuration view for the Zed AI provider; each render produces a `ZedAiConfiguration`
/// built from the provider state and its user store.
struct ConfigurationView {
    /// Provider state, read on every render for sign-in status and the user store.
    state: Entity<State>,
    /// Sign-in action handed to `ZedAiConfiguration`; built once in `ConfigurationView::new`.
    sign_in_callback: Arc<dyn Fn(&mut Window, &mut App) + Send + Sync>,
}
1184
1185impl ConfigurationView {
1186 fn new(state: Entity<State>) -> Self {
1187 let sign_in_callback = Arc::new({
1188 let state = state.clone();
1189 move |_window: &mut Window, cx: &mut App| {
1190 state.update(cx, |state, cx| {
1191 state.authenticate(cx).detach_and_log_err(cx);
1192 });
1193 }
1194 });
1195
1196 Self {
1197 state,
1198 sign_in_callback,
1199 }
1200 }
1201}
1202
1203impl Render for ConfigurationView {
1204 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1205 let state = self.state.read(cx);
1206 let user_store = state.user_store.read(cx);
1207
1208 ZedAiConfiguration {
1209 is_connected: !state.is_signed_out(cx),
1210 plan: user_store.plan(),
1211 eligible_for_trial: user_store.trial_started_at().is_none(),
1212 account_too_young: user_store.account_too_young(),
1213 sign_in_callback: self.sign_in_callback.clone(),
1214 }
1215 }
1216}
1217
1218impl Component for ZedAiConfiguration {
1219 fn name() -> &'static str {
1220 "AI Configuration Content"
1221 }
1222
1223 fn sort_name() -> &'static str {
1224 "AI Configuration Content"
1225 }
1226
1227 fn scope() -> ComponentScope {
1228 ComponentScope::Onboarding
1229 }
1230
1231 fn preview(_window: &mut Window, _cx: &mut App) -> Option<AnyElement> {
1232 fn configuration(
1233 is_connected: bool,
1234 plan: Option<Plan>,
1235 eligible_for_trial: bool,
1236 account_too_young: bool,
1237 ) -> AnyElement {
1238 ZedAiConfiguration {
1239 is_connected,
1240 plan,
1241 eligible_for_trial,
1242 account_too_young,
1243 sign_in_callback: Arc::new(|_, _| {}),
1244 }
1245 .into_any_element()
1246 }
1247
1248 Some(
1249 v_flex()
1250 .p_4()
1251 .gap_4()
1252 .children(vec![
1253 single_example("Not connected", configuration(false, None, false, false)),
1254 single_example(
1255 "Accept Terms of Service",
1256 configuration(true, None, true, false),
1257 ),
1258 single_example(
1259 "No Plan - Not eligible for trial",
1260 configuration(true, None, false, false),
1261 ),
1262 single_example(
1263 "No Plan - Eligible for trial",
1264 configuration(true, None, true, false),
1265 ),
1266 single_example(
1267 "Free Plan",
1268 configuration(true, Some(Plan::ZedFree), true, false),
1269 ),
1270 single_example(
1271 "Zed Pro Trial Plan",
1272 configuration(true, Some(Plan::ZedProTrial), true, false),
1273 ),
1274 single_example(
1275 "Zed Pro Plan",
1276 configuration(true, Some(Plan::ZedPro), true, false),
1277 ),
1278 ])
1279 .into_any_element(),
1280 )
1281 }
1282}
1283
#[cfg(test)]
mod tests {
    use super::*;
    use http_client::http::{HeaderMap, StatusCode};
    use language_model::LanguageModelCompletionError;

    /// Wraps `body` in an `ApiError` with status 500 and empty headers, then converts it into a
    /// `LanguageModelCompletionError` with `.into()` — the conversion under test.
    fn convert_internal_server_error(body: &str) -> LanguageModelCompletionError {
        ApiError {
            status: StatusCode::INTERNAL_SERVER_ERROR,
            body: body.to_string(),
            headers: HeaderMap::new(),
        }
        .into()
    }

    #[test]
    fn test_api_error_conversion_with_upstream_http_error() {
        // `upstream_http_error` with upstream_status 503 is surfaced as UpstreamProviderError,
        // preserving the upstream message.
        let completion_error = convert_internal_server_error(
            r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout","upstream_status":503}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers, reset reason: connection timeout"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 503, got: {:?}",
                completion_error
            ),
        }

        // `upstream_http_error` with upstream_status 500 is likewise an UpstreamProviderError,
        // not Zed's own ApiInternalServerError.
        let completion_error = convert_internal_server_error(
            r#"{"code":"upstream_http_error","message":"Received an error from the OpenAI API: internal server error","upstream_status":500}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the OpenAI API: internal server error"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 500, got: {:?}",
                completion_error
            ),
        }

        // `upstream_http_error` with upstream_status 429 is also an UpstreamProviderError.
        let completion_error = convert_internal_server_error(
            r#"{"code":"upstream_http_error","message":"Received an error from the Google API: rate limit exceeded","upstream_status":429}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError { message, .. } => {
                assert_eq!(
                    message,
                    "Received an error from the Google API: rate limit exceeded"
                );
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream 429, got: {:?}",
                completion_error
            ),
        }

        // A plain 500 without an upstream error code stays an ApiInternalServerError
        // attributed to Zed's own provider.
        let completion_error = convert_internal_server_error("Regular internal server error");

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
                assert_eq!(provider, PROVIDER_NAME);
                assert_eq!(message, "Regular internal server error");
            }
            _ => panic!(
                "Expected ApiInternalServerError for regular 500, got: {:?}",
                completion_error
            ),
        }

        // The `upstream_http_429` format becomes an UpstreamProviderError with status 429 and
        // `retry_after` parsed from fractional seconds.
        let completion_error = convert_internal_server_error(
            r#"{"code":"upstream_http_429","message":"Upstream Anthropic rate limit exceeded.","retry_after":30.5}"#,
        );

        match completion_error {
            LanguageModelCompletionError::UpstreamProviderError {
                message,
                status,
                retry_after,
            } => {
                assert_eq!(message, "Upstream Anthropic rate limit exceeded.");
                assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
                assert_eq!(retry_after, Some(Duration::from_secs_f64(30.5)));
            }
            _ => panic!(
                "Expected UpstreamProviderError for upstream_http_429, got: {:?}",
                completion_error
            ),
        }

        // A body that is not JSON falls back to regular status-based handling.
        let completion_error = convert_internal_server_error("Not JSON at all");

        match completion_error {
            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
                assert_eq!(provider, PROVIDER_NAME);
            }
            _ => panic!(
                "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                completion_error
            ),
        }
    }
}