1use anthropic::{AnthropicModelMode, parse_prompt_too_long};
2use anyhow::{Result, anyhow};
3use client::{Client, UserStore, zed_urls};
4use collections::BTreeMap;
5use feature_flags::{FeatureFlagAppExt, LlmClosedBetaFeatureFlag, ZedProFeatureFlag};
6use futures::{
7 AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
8};
9use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
10use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
11use language_model::{
12 AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
13 LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
14 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
15 LanguageModelProviderTosView, LanguageModelRequest, LanguageModelToolSchemaFormat,
16 ModelRequestLimitReachedError, RateLimiter, RequestUsage, ZED_CLOUD_PROVIDER_ID,
17};
18use language_model::{
19 LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider, LlmApiToken,
20 MaxMonthlySpendReachedError, PaymentRequiredError, RefreshLlmTokenListener,
21};
22use proto::Plan;
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize, de::DeserializeOwned};
25use settings::{Settings, SettingsStore};
26use smol::Timer;
27use smol::io::{AsyncReadExt, BufReader};
28use std::pin::Pin;
29use std::str::FromStr as _;
30use std::{
31 sync::{Arc, LazyLock},
32 time::Duration,
33};
34use strum::IntoEnumIterator;
35use thiserror::Error;
36use ui::{TintColor, prelude::*};
37use zed_llm_client::{
38 CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CURRENT_PLAN_HEADER_NAME, CompletionBody,
39 CompletionRequestStatus, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
40 MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
41 SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
42 TOOL_USE_LIMIT_REACHED_HEADER_NAME,
43};
44
45use crate::AllLanguageModelSettings;
46use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic};
47use crate::provider::google::{GoogleEventMapper, into_google};
48use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai};
49
50pub const PROVIDER_NAME: &str = "Zed";
51
52const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
53 option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");
54
55fn zed_cloud_provider_additional_models() -> &'static [AvailableModel] {
56 static ADDITIONAL_MODELS: LazyLock<Vec<AvailableModel>> = LazyLock::new(|| {
57 ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON
58 .map(|json| serde_json::from_str(json).unwrap())
59 .unwrap_or_default()
60 });
61 ADDITIONAL_MODELS.as_slice()
62}
63
/// User-configurable settings for the `zed.dev` language model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Extra models (beyond the built-in set) made available via settings.
    pub available_models: Vec<AvailableModel>,
}
68
/// The upstream vendor that actually serves a custom model configured for
/// the Zed cloud provider. Serialized in lowercase (e.g. `"anthropic"`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
76
/// A custom model definition, as it appears in settings or in the
/// compile-time closed-beta model list.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The provider of the language model.
    pub provider: AvailableProvider,
    /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
    pub name: String,
    /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
    pub display_name: Option<String>,
    /// The size of the context window, indicating the maximum number of tokens the model can process.
    pub max_tokens: usize,
    /// The maximum number of output tokens allowed by the model.
    pub max_output_tokens: Option<u32>,
    /// The maximum number of completion tokens allowed by the model (o1-* only)
    pub max_completion_tokens: Option<u32>,
    /// Override this model with a different Anthropic model for tool calls.
    pub tool_override: Option<String>,
    /// Indicates whether this custom model supports caching.
    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
    /// The default temperature to use for this model.
    pub default_temperature: Option<f32>,
    /// Any extra beta headers to provide when using the model.
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    /// The model's mode (e.g. thinking)
    pub mode: Option<ModelMode>,
}
103
/// Operating mode for a model, serialized as a tagged enum
/// (`{"type": "default"}` or `{"type": "thinking", ...}`).
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
114
115impl From<ModelMode> for AnthropicModelMode {
116 fn from(value: ModelMode) -> Self {
117 match value {
118 ModelMode::Default => AnthropicModelMode::Default,
119 ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
120 }
121 }
122}
123
/// Language model provider backed by Zed's hosted LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Entity<State>,
    // Background task that mirrors client connection-status changes into
    // `state`; cancelled when the provider is dropped.
    _maintain_client_status: Task<()>,
}
129
/// Shared, observable state for the cloud provider: connection status,
/// the LLM API token, and the in-flight terms-of-service acceptance task.
pub struct State {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    user_store: Entity<UserStore>,
    status: client::Status,
    // `Some` while a terms-of-service acceptance is in flight; the UI uses
    // this to disable the accept button.
    accept_terms: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
    _llm_token_subscription: Subscription,
}
139
impl State {
    /// Builds the shared provider state and wires up the two subscriptions
    /// that keep it fresh: global settings changes and LLM-token refresh
    /// events.
    fn new(
        client: Arc<Client>,
        user_store: Entity<UserStore>,
        status: client::Status,
        cx: &mut Context<Self>,
    ) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            client: client.clone(),
            llm_api_token: LlmApiToken::default(),
            user_store,
            status,
            accept_terms: None,
            // Re-render whenever any global setting changes.
            _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
            // On a token-refresh event, fetch a new LLM API token in the
            // background; failures are logged rather than surfaced.
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_api_token = this.llm_api_token.clone();
                    cx.spawn(async move |_this, _cx| {
                        llm_api_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
        }
    }

    /// True when the client currently has no authenticated connection.
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    /// Signs in (connecting if necessary), then re-renders observers.
    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(async move |this, cx| {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(cx, |_, cx| cx.notify())
        })
    }

    /// Whether the signed-in user has accepted the terms of service.
    /// Defaults to `false` when no user record is available.
    fn has_accepted_terms_of_service(&self, cx: &App) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    /// Kicks off terms-of-service acceptance. The in-flight task is stored in
    /// `accept_terms` (so the UI can disable the button) and cleared when it
    /// completes.
    fn accept_terms_of_service(&mut self, cx: &mut Context<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(async move |this, cx| {
            let _ = user_store
                .update(cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}
205
206impl CloudLanguageModelProvider {
207 pub fn new(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) -> Self {
208 let mut status_rx = client.status();
209 let status = *status_rx.borrow();
210
211 let state = cx.new(|cx| State::new(client.clone(), user_store.clone(), status, cx));
212
213 let state_ref = state.downgrade();
214 let maintain_client_status = cx.spawn(async move |cx| {
215 while let Some(status) = status_rx.next().await {
216 if let Some(this) = state_ref.upgrade() {
217 _ = this.update(cx, |this, cx| {
218 if this.status != status {
219 this.status = status;
220 cx.notify();
221 }
222 });
223 } else {
224 break;
225 }
226 }
227 });
228
229 Self {
230 client,
231 state: state.clone(),
232 _maintain_client_status: maintain_client_status,
233 }
234 }
235
236 fn create_language_model(
237 &self,
238 model: CloudModel,
239 llm_api_token: LlmApiToken,
240 ) -> Arc<dyn LanguageModel> {
241 Arc::new(CloudLanguageModel {
242 id: LanguageModelId::from(model.id().to_string()),
243 model,
244 llm_api_token: llm_api_token.clone(),
245 client: self.client.clone(),
246 request_limiter: RateLimiter::new(4),
247 })
248 }
249}
250
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the shared state entity so the UI can observe changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
258
impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    /// The model used when the user has not picked one explicitly.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet);
        Some(self.create_language_model(model, llm_api_token))
    }

    /// A cheaper default for latency-sensitive features.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        let model = CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet);
        Some(self.create_language_model(model, llm_api_token))
    }

    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        [
            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
        ]
        .into_iter()
        .map(|model| self.create_language_model(model, llm_api_token.clone()))
        .collect()
    }

    /// Full model list. Staff see every non-custom model from all three
    /// vendors; everyone else gets a curated Anthropic set. Models from
    /// settings (and, behind the closed-beta flag, build-time extras) are
    /// applied last so they override built-ins with the same id.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        // Keyed by model id: later inserts override earlier ones, and the
        // resulting list is sorted by id.
        let mut models = BTreeMap::default();

        if cx.is_staff() {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
        } else {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
            );
            models.insert(
                anthropic::Model::Claude3_7SonnetThinking.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
            );
        }

        let llm_closed_beta_models = if cx.has_flag::<LlmClosedBetaFeatureFlag>() {
            zed_cloud_provider_additional_models()
        } else {
            &[]
        };

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .zed_dot_dev
            .available_models
            .iter()
            .chain(llm_closed_beta_models)
            .cloned()
        {
            // Translate the settings-level description into the matching
            // vendor-specific custom model variant.
            let model = match model.provider {
                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    tool_override: model.tool_override.clone(),
                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
                        anthropic::AnthropicModelCacheConfiguration {
                            max_cache_anchors: config.max_cache_anchors,
                            should_speculate: config.should_speculate,
                            min_total_token: config.min_total_token,
                        }
                    }),
                    default_temperature: model.default_temperature,
                    max_output_tokens: model.max_output_tokens,
                    extra_beta_headers: model.extra_beta_headers.clone(),
                    mode: model.mode.unwrap_or_default().into(),
                }),
                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                }),
                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                }),
            };
            models.insert(model.id().to_string(), model.clone());
        }

        let llm_api_token = self.state.read(cx).llm_api_token.clone();
        models
            .into_values()
            .map(|model| self.create_language_model(model, llm_api_token.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    // Authentication happens through the Zed client sign-in flow, not here;
    // this provider reports ready immediately.
    fn authenticate(&self, _cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        cx.new(|_| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &App) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(
        &self,
        view: LanguageModelProviderTosView,
        cx: &mut App,
    ) -> Option<AnyElement> {
        render_accept_terms(self.state.clone(), view, cx)
    }

    // There are no stored credentials to reset for the cloud provider.
    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}
415
/// Renders the terms-of-service acceptance UI, or `None` once the user has
/// already accepted. Layout varies by `view_kind` (configuration page, prompt
/// editor popup, or the thread fresh-start/empty states).
fn render_accept_terms(
    state: Entity<State>,
    view_kind: LanguageModelProviderTosView,
    cx: &mut App,
) -> Option<AnyElement> {
    if state.read(cx).has_accepted_terms_of_service(cx) {
        return None;
    }

    // Disable the accept button while an acceptance request is in flight.
    let accept_terms_disabled = state.read(cx).accept_terms.is_some();

    let thread_fresh_start = matches!(view_kind, LanguageModelProviderTosView::ThreadFreshStart);
    let thread_empty_state = matches!(view_kind, LanguageModelProviderTosView::ThreadtEmptyState);

    let terms_button = Button::new("terms_of_service", "Terms of Service")
        .style(ButtonStyle::Subtle)
        .icon(IconName::ArrowUpRight)
        .icon_color(Color::Muted)
        .icon_size(IconSize::XSmall)
        .when(thread_empty_state, |this| this.label_size(LabelSize::Small))
        .on_click(move |_, _window, cx| cx.open_url("https://zed.dev/terms-of-service"));

    let button_container = h_flex().child(
        Button::new("accept_terms", "I accept the Terms of Service")
            .when(!thread_empty_state, |this| {
                this.full_width()
                    .style(ButtonStyle::Tinted(TintColor::Accent))
                    .icon(IconName::Check)
                    .icon_position(IconPosition::Start)
                    .icon_size(IconSize::Small)
            })
            .when(thread_empty_state, |this| {
                this.style(ButtonStyle::Tinted(TintColor::Warning))
                    .label_size(LabelSize::Small)
            })
            .disabled(accept_terms_disabled)
            .on_click({
                // Downgrade so the click handler doesn't keep the state alive.
                let state = state.downgrade();
                move |_, _window, cx| {
                    state
                        .update(cx, |state, cx| state.accept_terms_of_service(cx))
                        .ok();
                }
            }),
    );

    // Empty-state threads get a compact single-row layout; everything else
    // uses a stacked column with view-specific button alignment.
    let form = if thread_empty_state {
        h_flex()
            .w_full()
            .flex_wrap()
            .justify_between()
            .child(
                h_flex()
                    .child(
                        Label::new("To start using Zed AI, please read and accept the")
                            .size(LabelSize::Small),
                    )
                    .child(terms_button),
            )
            .child(button_container)
    } else {
        v_flex()
            .w_full()
            .gap_2()
            .child(
                h_flex()
                    .flex_wrap()
                    .when(thread_fresh_start, |this| this.justify_center())
                    .child(Label::new(
                        "To start using Zed AI, please read and accept the",
                    ))
                    .child(terms_button),
            )
            .child({
                match view_kind {
                    LanguageModelProviderTosView::PromptEditorPopup => {
                        button_container.w_full().justify_end()
                    }
                    LanguageModelProviderTosView::Configuration => {
                        button_container.w_full().justify_start()
                    }
                    LanguageModelProviderTosView::ThreadFreshStart => {
                        button_container.w_full().justify_center()
                    }
                    LanguageModelProviderTosView::ThreadtEmptyState => div().w_0(),
                }
            })
    };

    Some(form.into_any())
}
507
/// A single hosted model exposed through Zed's LLM service, pairing the model
/// description with the client, API token, and a per-model rate limiter.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}
515
/// Successful completion response plus metadata decoded from response headers.
struct PerformLlmCompletionResponse {
    response: Response<AsyncBody>,
    // Usage parsed from headers; `None` when the server streams status
    // messages in the body instead.
    usage: Option<RequestUsage>,
    tool_use_limit_reached: bool,
    // True when the server interleaves status messages with events in the
    // response body (newer protocol).
    includes_status_messages: bool,
}
522
523impl CloudLanguageModel {
524 const MAX_RETRIES: usize = 3;
525
526 async fn perform_llm_completion(
527 client: Arc<Client>,
528 llm_api_token: LlmApiToken,
529 body: CompletionBody,
530 ) -> Result<PerformLlmCompletionResponse> {
531 let http_client = &client.http_client();
532
533 let mut token = llm_api_token.acquire(&client).await?;
534 let mut retries_remaining = Self::MAX_RETRIES;
535 let mut retry_delay = Duration::from_secs(1);
536
537 loop {
538 let request_builder = http_client::Request::builder().method(Method::POST);
539 let request_builder = if let Ok(completions_url) = std::env::var("ZED_COMPLETIONS_URL")
540 {
541 request_builder.uri(completions_url)
542 } else {
543 request_builder.uri(http_client.build_zed_llm_url("/completions", &[])?.as_ref())
544 };
545 let request = request_builder
546 .header("Content-Type", "application/json")
547 .header("Authorization", format!("Bearer {token}"))
548 .header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, "true")
549 .body(serde_json::to_string(&body)?.into())?;
550 let mut response = http_client.send(request).await?;
551 let status = response.status();
552 if status.is_success() {
553 let includes_status_messages = response
554 .headers()
555 .get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME)
556 .is_some();
557
558 let tool_use_limit_reached = response
559 .headers()
560 .get(TOOL_USE_LIMIT_REACHED_HEADER_NAME)
561 .is_some();
562
563 let usage = if includes_status_messages {
564 None
565 } else {
566 RequestUsage::from_headers(response.headers()).ok()
567 };
568
569 return Ok(PerformLlmCompletionResponse {
570 response,
571 usage,
572 includes_status_messages,
573 tool_use_limit_reached,
574 });
575 } else if response
576 .headers()
577 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
578 .is_some()
579 {
580 retries_remaining -= 1;
581 token = llm_api_token.refresh(&client).await?;
582 } else if status == StatusCode::FORBIDDEN
583 && response
584 .headers()
585 .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
586 .is_some()
587 {
588 return Err(anyhow!(MaxMonthlySpendReachedError));
589 } else if status == StatusCode::FORBIDDEN
590 && response
591 .headers()
592 .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
593 .is_some()
594 {
595 if let Some(MODEL_REQUESTS_RESOURCE_HEADER_VALUE) = response
596 .headers()
597 .get(SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME)
598 .and_then(|resource| resource.to_str().ok())
599 {
600 if let Some(plan) = response
601 .headers()
602 .get(CURRENT_PLAN_HEADER_NAME)
603 .and_then(|plan| plan.to_str().ok())
604 .and_then(|plan| zed_llm_client::Plan::from_str(plan).ok())
605 {
606 let plan = match plan {
607 zed_llm_client::Plan::Free => Plan::Free,
608 zed_llm_client::Plan::ZedPro => Plan::ZedPro,
609 zed_llm_client::Plan::ZedProTrial => Plan::ZedProTrial,
610 };
611 return Err(anyhow!(ModelRequestLimitReachedError { plan }));
612 }
613 }
614
615 return Err(anyhow!("Forbidden"));
616 } else if status.as_u16() >= 500 && status.as_u16() < 600 {
617 // If we encounter an error in the 500 range, retry after a delay.
618 // We've seen at least these in the wild from API providers:
619 // * 500 Internal Server Error
620 // * 502 Bad Gateway
621 // * 529 Service Overloaded
622
623 if retries_remaining == 0 {
624 let mut body = String::new();
625 response.body_mut().read_to_string(&mut body).await?;
626 return Err(anyhow!(
627 "cloud language model completion failed after {} retries with status {status}: {body}",
628 Self::MAX_RETRIES
629 ));
630 }
631
632 Timer::after(retry_delay).await;
633
634 retries_remaining -= 1;
635 retry_delay *= 2; // If it fails again, wait longer.
636 } else if status == StatusCode::PAYMENT_REQUIRED {
637 return Err(anyhow!(PaymentRequiredError));
638 } else {
639 let mut body = String::new();
640 response.body_mut().read_to_string(&mut body).await?;
641 return Err(anyhow!(ApiError { status, body }));
642 }
643 }
644 }
645}
646
/// Generic error for non-2xx responses from the LLM service, carrying the
/// HTTP status and the raw response body for diagnostics.
#[derive(Debug, Error)]
#[error("cloud language model request failed with status {status}: {body}")]
struct ApiError {
    status: StatusCode,
    body: String,
}
653
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    // All hosted vendors currently support tool use.
    fn supports_tools(&self) -> bool {
        match self.model {
            CloudModel::Anthropic(_) => true,
            CloudModel::Google(_) => true,
            CloudModel::OpenAi(_) => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        self.model.tool_input_format()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Prompt-caching configuration; only Anthropic models support it.
    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                model
                    .cache_configuration()
                    .map(|cache| LanguageModelCacheConfiguration {
                        max_cache_anchors: cache.max_cache_anchors,
                        should_speculate: cache.should_speculate,
                        min_total_token: cache.min_total_token,
                    })
            }
            CloudModel::OpenAi(_) | CloudModel::Google(_) => None,
        }
    }

    /// Counts tokens for a request. Anthropic and OpenAI counts are computed
    /// locally; Google counts go through the service's `/count_tokens`
    /// endpoint (overridable via `ZED_COUNT_TOKENS_URL`).
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let model_id = model.id().to_string();
                let generate_content_request = into_google(request, model_id.clone());
                async move {
                    let http_client = &client.http_client();
                    let token = llm_api_token.acquire(&client).await?;

                    let request_builder = http_client::Request::builder().method(Method::POST);
                    let request_builder =
                        if let Ok(completions_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
                            request_builder.uri(completions_url)
                        } else {
                            request_builder.uri(
                                http_client
                                    .build_zed_llm_url("/count_tokens", &[])?
                                    .as_ref(),
                            )
                        };
                    let request_body = CountTokensBody {
                        provider: zed_llm_client::LanguageModelProvider::Google,
                        model: model_id,
                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
                            generate_content_request,
                        })?,
                    };
                    let request = request_builder
                        .header("Content-Type", "application/json")
                        .header("Authorization", format!("Bearer {token}"))
                        .body(serde_json::to_string(&request_body)?.into())?;
                    let mut response = http_client.send(request).await?;
                    let status = response.status();
                    let mut response_body = String::new();
                    response
                        .body_mut()
                        .read_to_string(&mut response_body)
                        .await?;

                    if status.is_success() {
                        let response_body: CountTokensResponse =
                            serde_json::from_str(&response_body)?;

                        Ok(response_body.tokens)
                    } else {
                        Err(anyhow!(ApiError {
                            status,
                            body: response_body
                        }))
                    }
                }
                .boxed()
            }
        }
    }

    /// Streams a completion. Each vendor arm converts the generic request
    /// into the vendor's wire format, posts it via `perform_llm_completion`
    /// (rate-limited), and maps the line-delimited response plus trailing
    /// usage / tool-limit status events through the vendor's event mapper.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let thread_id = request.thread_id.clone();
        let prompt_id = request.prompt_id.clone();
        let mode = request.mode;
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = into_anthropic(
                    request,
                    model.request_id().into(),
                    model.default_temperature(),
                    model.max_output_tokens(),
                    model.mode(),
                );
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await
                    // Anthropic-only: translate a BAD_REQUEST "prompt too
                    // long" body into a typed context-window error.
                    .map_err(|err| match err.downcast::<ApiError>() {
                        Ok(api_err) => {
                            if api_err.status == StatusCode::BAD_REQUEST {
                                if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
                                    return anyhow!(
                                        LanguageModelKnownError::ContextWindowLimitExceeded {
                                            tokens
                                        }
                                    );
                                }
                            }
                            anyhow!(api_err)
                        }
                        Err(err) => anyhow!(err),
                    })?;

                    let mut mapper = AnthropicEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = into_open_ai(request, model, model.max_output_tokens());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = OpenAiEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = into_google(request, model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let PerformLlmCompletionResponse {
                        response,
                        usage,
                        includes_status_messages,
                        tool_use_limit_reached,
                    } = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        CompletionBody {
                            thread_id,
                            prompt_id,
                            mode,
                            provider: zed_llm_client::LanguageModelProvider::Google,
                            model: request.model.model_id.clone(),
                            provider_request: serde_json::to_value(&request)?,
                        },
                    )
                    .await?;

                    let mut mapper = GoogleEventMapper::new();
                    Ok(map_cloud_completion_events(
                        Box::pin(
                            response_lines(response, includes_status_messages)
                                .chain(usage_updated_event(usage))
                                .chain(tool_use_limit_reached_event(tool_use_limit_reached)),
                        ),
                        move |event| mapper.map_event(event),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
}
920
/// One line of a streamed completion response: either a service-level status
/// message or a vendor-specific event payload.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
    Status(CompletionRequestStatus),
    Event(T),
}
927
928fn map_cloud_completion_events<T, F>(
929 stream: Pin<Box<dyn Stream<Item = Result<CloudCompletionEvent<T>>> + Send>>,
930 mut map_callback: F,
931) -> BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
932where
933 T: DeserializeOwned + 'static,
934 F: FnMut(T) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
935 + Send
936 + 'static,
937{
938 stream
939 .flat_map(move |event| {
940 futures::stream::iter(match event {
941 Err(error) => {
942 vec![Err(LanguageModelCompletionError::Other(error))]
943 }
944 Ok(CloudCompletionEvent::Status(event)) => {
945 vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
946 }
947 Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
948 })
949 })
950 .boxed()
951}
952
953fn usage_updated_event<T>(
954 usage: Option<RequestUsage>,
955) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
956 futures::stream::iter(usage.map(|usage| {
957 Ok(CloudCompletionEvent::Status(
958 CompletionRequestStatus::UsageUpdated {
959 amount: usage.amount as usize,
960 limit: usage.limit,
961 },
962 ))
963 }))
964}
965
966fn tool_use_limit_reached_event<T>(
967 tool_use_limit_reached: bool,
968) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
969 futures::stream::iter(tool_use_limit_reached.then(|| {
970 Ok(CloudCompletionEvent::Status(
971 CompletionRequestStatus::ToolUseLimitReached,
972 ))
973 }))
974}
975
/// Turns a line-delimited JSON response body into a stream of completion
/// events, reusing one `String` buffer across lines.
///
/// When `includes_status_messages` is true each line is already a tagged
/// `CloudCompletionEvent`; otherwise (older servers) each line is a bare
/// vendor event and is wrapped in `CloudCompletionEvent::Event`.
fn response_lines<T: DeserializeOwned>(
    response: Response<AsyncBody>,
    includes_status_messages: bool,
) -> impl Stream<Item = Result<CloudCompletionEvent<T>>> {
    futures::stream::try_unfold(
        (String::new(), BufReader::new(response.into_body())),
        move |(mut line, mut body)| async move {
            match body.read_line(&mut line).await {
                // EOF: end the stream.
                Ok(0) => Ok(None),
                Ok(_) => {
                    let event = if includes_status_messages {
                        serde_json::from_str::<CloudCompletionEvent<T>>(&line)?
                    } else {
                        CloudCompletionEvent::Event(serde_json::from_str::<T>(&line)?)
                    };

                    // Clear (not reallocate) the buffer before threading it
                    // back into the next unfold step.
                    line.clear();
                    Ok(Some((event, (line, body))))
                }
                Err(e) => Err(e.into()),
            }
        },
    )
}
1000
/// Settings-panel view for the Zed cloud provider: sign-in, terms of
/// service, and subscription management.
struct ConfigurationView {
    state: gpui::Entity<State>,
}
1004
1005impl ConfigurationView {
1006 fn authenticate(&mut self, cx: &mut Context<Self>) {
1007 self.state.update(cx, |state, cx| {
1008 state.authenticate(cx).detach_and_log_err(cx);
1009 });
1010 cx.notify();
1011 }
1012}
1013
impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let has_accepted_terms = self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);
        let subscription_text = Label::new(if is_pro {
            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
        } else {
            "You have basic access to models from Anthropic through the Zed AI Free plan."
        });
        // Pro users get a manage-subscription button; free users see the
        // upgrade buttons only when the Zed Pro feature flag is enabled.
        let manage_subscription_button = if is_pro {
            Some(
                h_flex().child(
                    Button::new("manage_settings", "Manage Subscription")
                        .style(ButtonStyle::Tinted(TintColor::Accent))
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                        ),
                ),
            )
        } else if cx.has_flag::<ZedProFeatureFlag>() {
            Some(
                h_flex()
                    .gap_2()
                    .child(
                        Button::new("learn_more", "Learn more")
                            .style(ButtonStyle::Subtle)
                            .on_click(cx.listener(|_, _, _, cx| cx.open_url(ZED_AI_URL))),
                    )
                    .child(
                        Button::new("upgrade", "Upgrade")
                            .style(ButtonStyle::Subtle)
                            .color(Color::Accent)
                            .on_click(
                                cx.listener(|_, _, _, cx| cx.open_url(&zed_urls::account_url(cx))),
                            ),
                    ),
            )
        } else {
            None
        };

        if is_connected {
            // Signed in: show the ToS prompt until accepted, then the plan
            // description and subscription controls.
            v_flex()
                .gap_3()
                .w_full()
                .children(render_accept_terms(
                    self.state.clone(),
                    LanguageModelProviderTosView::Configuration,
                    cx,
                ))
                .when(has_accepted_terms, |this| {
                    this.child(subscription_text)
                        .children(manage_subscription_button)
                })
        } else {
            // Signed out: prompt to sign in.
            v_flex()
                .gap_2()
                .child(Label::new("Use Zed AI to access hosted language models."))
                .child(
                    Button::new("sign_in", "Sign In")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .on_click(cx.listener(move |this, _, _, cx| this.authenticate(cx))),
                )
        }
    }
}