use super::open_ai::count_open_ai_tokens;
use crate::{
    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anyhow::{anyhow, bail, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LanguageModels};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use gpui::{
    AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
    Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
    io::BufReader,
    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
};
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;

use crate::{LanguageModelAvailability, LanguageModelProvider};

use super::anthropic::count_anthropic_tokens;

pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "Zed";

#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}

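/// A custom model exposed through the zed.dev provider, as declared in the
/// user's settings.
///
/// A minimal sketch of how one entry deserializes (the values here are
/// hypothetical; `provider` keys are lowercased by the serde derive above,
/// and `tool_override` may be omitted since `Option` fields default to
/// `None`):
///
/// ```ignore
/// let model: AvailableModel = serde_json::from_value(serde_json::json!({
///     "provider": "anthropic",
///     "name": "some-custom-model",
///     "max_tokens": 200000
/// }))
/// .unwrap();
/// ```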
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    provider: AvailableProvider,
    name: String,
    max_tokens: usize,
    tool_override: Option<String>,
}

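/// The `zed.dev` language model provider, which proxies requests for hosted
/// models through the Zed server.
///
/// A minimal construction sketch (assumes a `Client` and a `UserStore` model
/// are already available, as they are during app initialization):
///
/// ```ignore
/// let provider = CloudLanguageModelProvider::new(user_store, client.clone(), cx);
/// ```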
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    _maintain_client_status: Task<()>,
}

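/// Observable state shared by the provider and its configuration view: the
/// client's connection status and the in-flight terms-of-service task.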
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    accept_terms: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }

    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
            let _ = user_store
                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(&mut cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}

impl CloudLanguageModelProvider {
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            accept_terms: None,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            llm_api_token: LlmApiToken::default(),
            _maintain_client_status: maintain_client_status,
        }
    }
}

impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CloudLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiZed
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();
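        // Keyed by model ID, so the final list is de-duplicated and stably ordered.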

        let is_user = !cx.has_flag::<LanguageModels>();
        if is_user {
            models.insert(
                anthropic::Model::Claude3_5Sonnet.id().to_string(),
                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
            );
        } else {
            for model in anthropic::Model::iter() {
                if !matches!(model, anthropic::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
                }
            }
            for model in open_ai::Model::iter() {
                if !matches!(model, open_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
                }
            }
            for model in google_ai::Model::iter() {
                if !matches!(model, google_ai::Model::Custom { .. }) {
                    models.insert(model.id().to_string(), CloudModel::Google(model));
                }
            }
            for model in ZedModel::iter() {
                models.insert(model.id().to_string(), CloudModel::Zed(model));
            }

            // Override with available models from settings
            for model in &AllLanguageModelSettings::get_global(cx)
                .zed_dot_dev
                .available_models
            {
                let model = match model.provider {
                    AvailableProvider::Anthropic => {
                        CloudModel::Anthropic(anthropic::Model::Custom {
                            name: model.name.clone(),
                            max_tokens: model.max_tokens,
                            tool_override: model.tool_override.clone(),
                        })
                    }
                    AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                        name: model.name.clone(),
                        max_tokens: model.max_tokens,
                    }),
                    AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                        name: model.name.clone(),
                        max_tokens: model.max_tokens,
                    }),
                };
                models.insert(model.id().to_string(), model.clone());
            }
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(CloudLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    llm_api_token: self.llm_api_token.clone(),
                    client: self.client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).is_signed_out()
    }

    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| ConfigurationView {
            state: self.state.clone(),
        })
        .into()
    }

    fn must_accept_terms(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).has_accepted_terms_of_service(cx)
    }

    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
        let state = self.state.read(cx);

        let terms = [(
            "anthropic_terms_of_service",
            "Anthropic Terms of Service",
            "https://www.anthropic.com/legal/consumer-terms",
        )]
        .map(|(id, label, url)| {
            Button::new(id, label)
                .style(ButtonStyle::Subtle)
                .icon(IconName::ExternalLink)
                .icon_size(IconSize::XSmall)
                .icon_color(Color::Muted)
                .on_click(move |_, cx| cx.open_url(url))
        });

        if state.has_accepted_terms_of_service(cx) {
            None
        } else {
            let disabled = state.accept_terms.is_some();
            Some(
                v_flex()
                    .child(Label::new("Terms & Conditions").weight(FontWeight::SEMIBOLD))
                    .child("Please read and accept the terms and conditions of Zed AI and our provider partners to continue.")
                    .child(v_flex().m_2().gap_1().children(terms))
                    .child(
                        h_flex().justify_end().mt_1().child(
                            Button::new("accept_terms", "Accept")
                                .disabled(disabled)
                                .on_click({
                                    let state = self.state.downgrade();
                                    move |_, cx| {
                                        state
                                            .update(cx, |state, cx| {
                                                state.accept_terms_of_service(cx)
                                            })
                                            .ok();
                                    }
                                }),
                        ),
                    )
                    .into_any(),
            )
        }
    }

    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
        Task::ready(Ok(()))
    }
}

pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    request_limiter: RateLimiter,
}

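/// The provider's current LLM API token, shared across all of its models.
///
/// Readers hold an upgradable read lock and only upgrade to a write lock when
/// the token has to be (re)fetched, so concurrent requests don't serialize on
/// every call. A minimal usage sketch:
///
/// ```ignore
/// let token = llm_api_token.acquire(&client).await?; // fetches on first use
/// // ...after a response carrying EXPIRED_LLM_TOKEN_HEADER_NAME...
/// let token = llm_api_token.refresh(&client).await?; // forces a re-fetch
/// ```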
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);

impl CloudLanguageModel {
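    /// Sends a completion request to the zed.dev LLM API.
    ///
    /// If the response carries `EXPIRED_LLM_TOKEN_HEADER_NAME`, the token is
    /// refreshed and the request is retried once; any other non-success
    /// status becomes an error.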
    async fn perform_llm_completion(
        client: Arc<Client>,
        llm_api_token: LlmApiToken,
        body: PerformCompletionParams,
    ) -> Result<Response<AsyncBody>> {
        let http_client = &client.http_client();

        let mut token = llm_api_token.acquire(&client).await?;
        let mut did_retry = false;

        let response = loop {
            let request = http_client::Request::builder()
                .method(Method::POST)
                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {token}"))
                .body(serde_json::to_string(&body)?.into())?;
            let response = http_client.send(request).await?;
            if response.status().is_success() {
                break response;
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_api_token.refresh(&client).await?;
            } else {
                break Err(anyhow!(
                    "cloud language model completion failed with status {}",
                    response.status()
                ))?;
            }
        };

        Ok(response)
    }
}

impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
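                // Token counting for Google models is proxied through the Zed
                // server rather than computed locally.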
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(model.id().into());
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
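                    // The body is newline-delimited JSON: each line is one
                    // event, and a zero-length read marks the end of stream.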
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: anthropic::Event = serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(anthropic::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: open_ai::ResponseStreamEvent =
                                    serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(open_ai::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: google_ai::GenerateContentResponse =
                                    serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(google_ai::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Zed(model) => {
                let client = self.client.clone();
                let mut request = request.into_open_ai(model.id().into());
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: open_ai::ResponseStreamEvent =
                                    serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(open_ai::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }

    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let mut request = request.into_anthropic(model.tool_model_id().into());
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                let llm_api_token = self.llm_api_token.clone();
                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Anthropic,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

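                        // The tool's JSON input streams in as deltas; remember
                        // which content block belongs to our tool and stitch
                        // its fragments back together.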
                        let mut tool_use_index = None;
                        let mut tool_input = String::new();
                        let mut body = BufReader::new(response.into_body());
                        let mut line = String::new();
                        while body.read_line(&mut line).await? > 0 {
                            let event: anthropic::Event = serde_json::from_str(&line)?;
                            line.clear();

                            match event {
                                anthropic::Event::ContentBlockStart {
                                    content_block,
                                    index,
                                } => {
                                    if let anthropic::Content::ToolUse { name, .. } = content_block
                                    {
                                        if name == tool_name {
                                            tool_use_index = Some(index);
                                        }
                                    }
                                }
                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
                                {
                                    anthropic::ContentDelta::TextDelta { .. } => {}
                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
                                        if Some(index) == tool_use_index {
                                            tool_input.push_str(&partial_json);
                                        }
                                    }
                                },
                                anthropic::Event::ContentBlockStop { index } => {
                                    if Some(index) == tool_use_index {
                                        return Ok(serde_json::from_str(&tool_input)?);
                                    }
                                }
                                _ => {}
                            }
                        }

                        if tool_use_index.is_some() {
                            Err(anyhow!("tool content incomplete"))
                        } else {
                            Err(anyhow!("tool not used"))
                        }
                    })
                    .boxed()
            }
            CloudModel::OpenAi(model) => {
                let mut request = request.into_open_ai(model.id().into());
                let client = self.client.clone();
                let mut function = open_ai::FunctionDefinition {
                    name: tool_name.clone(),
                    description: None,
                    parameters: None,
                };
                let func = open_ai::ToolDefinition::Function {
                    function: function.clone(),
                };
                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
                // Fill in description and params separately, as they're not needed for tool_choice field.
                function.description = Some(tool_description);
                function.parameters = Some(input_schema);
                request.tools = vec![open_ai::ToolDefinition::Function { function }];

                let llm_api_token = self.llm_api_token.clone();
                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::OpenAi,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        let mut body = BufReader::new(response.into_body());
                        let mut line = String::new();
                        let mut load_state = None;

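                        // Tool-call arguments arrive as streamed fragments;
                        // track the matching call's index and concatenate its
                        // pieces into one JSON string.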
                        while body.read_line(&mut line).await? > 0 {
                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
                            line.clear();

                            for choice in part.choices {
                                let Some(tool_calls) = choice.delta.tool_calls else {
                                    continue;
                                };

                                for call in tool_calls {
                                    if let Some(func) = call.function {
                                        if func.name.as_deref() == Some(tool_name.as_str()) {
                                            load_state = Some((String::default(), call.index));
                                        }
                                        if let Some((arguments, (output, index))) =
                                            func.arguments.zip(load_state.as_mut())
                                        {
                                            if call.index == *index {
                                                output.push_str(&arguments);
                                            }
                                        }
                                    }
                                }
                            }
                        }

                        if let Some((arguments, _)) = load_state {
                            return Ok(serde_json::from_str(&arguments)?);
                        } else {
                            bail!("tool not used");
                        }
                    })
                    .boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
            CloudModel::Zed(model) => {
                // All Zed models are OpenAI-based at the time of writing.
                let mut request = request.into_open_ai(model.id().into());
                let client = self.client.clone();
                let mut function = open_ai::FunctionDefinition {
                    name: tool_name.clone(),
                    description: None,
                    parameters: None,
                };
                let func = open_ai::ToolDefinition::Function {
                    function: function.clone(),
                };
                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
                // Fill in description and params separately, as they're not needed for tool_choice field.
                function.description = Some(tool_description);
                function.parameters = Some(input_schema);
                request.tools = vec![open_ai::ToolDefinition::Function { function }];

                let llm_api_token = self.llm_api_token.clone();
                self.request_limiter
                    .run(async move {
                        let response = Self::perform_llm_completion(
                            client.clone(),
                            llm_api_token,
                            PerformCompletionParams {
                                provider: client::LanguageModelProvider::Zed,
                                model: request.model.clone(),
                                provider_request: RawValue::from_string(serde_json::to_string(
                                    &request,
                                )?)?,
                            },
                        )
                        .await?;

                        let mut body = BufReader::new(response.into_body());
                        let mut line = String::new();
                        let mut load_state = None;

                        while body.read_line(&mut line).await? > 0 {
                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
                            line.clear();

                            for choice in part.choices {
                                let Some(tool_calls) = choice.delta.tool_calls else {
                                    continue;
                                };

                                for call in tool_calls {
                                    if let Some(func) = call.function {
                                        if func.name.as_deref() == Some(tool_name.as_str()) {
                                            load_state = Some((String::default(), call.index));
                                        }
                                        if let Some((arguments, (output, index))) =
                                            func.arguments.zip(load_state.as_mut())
                                        {
                                            if call.index == *index {
                                                output.push_str(&arguments);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        if let Some((arguments, _)) = load_state {
                            return Ok(serde_json::from_str(&arguments)?);
                        } else {
                            bail!("tool not used");
                        }
                    })
                    .boxed()
            }
        }
    }
}

impl LlmApiToken {
    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
        let lock = self.0.upgradable_read().await;
        if let Some(token) = lock.as_ref() {
            Ok(token.to_string())
        } else {
            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
        }
    }

    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
        Self::fetch(self.0.write().await, &client).await
    }

    async fn fetch<'a>(
        mut lock: RwLockWriteGuard<'a, Option<String>>,
        client: &Arc<Client>,
    ) -> Result<String> {
        let response = client.request(proto::GetLlmToken {}).await?;
        *lock = Some(response.token.clone());
        Ok(response.token.clone())
    }
}

struct ConfigurationView {
    state: gpui::Model<State>,
}

impl ConfigurationView {
    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
        self.state.update(cx, |state, cx| {
            state.authenticate(cx).detach_and_log_err(cx);
        });
        cx.notify();
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const ZED_AI_URL: &str = "https://zed.dev/ai";
        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";

        let is_connected = !self.state.read(cx).is_signed_out();
        let plan = self.state.read(cx).user_store.read(cx).current_plan();
        let must_accept_terms = !self.state.read(cx).has_accepted_terms_of_service(cx);

        let is_pro = plan == Some(proto::Plan::ZedPro);

        if is_connected {
            v_flex()
                .gap_3()
                .max_w_4_5()
                .when(must_accept_terms, |this| {
                    this.child(Label::new(
                        "You must accept the terms of service to use this provider.",
                    ))
                })
                .child(Label::new(
                    if is_pro {
                        "You have full access to Zed's hosted models from Anthropic, OpenAI, and Google, with faster speeds and higher limits through Zed Pro."
                    } else {
                        "You have basic access to models from Anthropic, OpenAI, Google, and more through the Zed AI Free plan."
                    }))
                .child(
                    if is_pro {
                        h_flex().child(
                            Button::new("manage_settings", "Manage Subscription")
                                .style(ButtonStyle::Filled)
                                .on_click(cx.listener(|_, _, cx| {
                                    cx.open_url(ACCOUNT_SETTINGS_URL)
                                })))
                    } else {
                        h_flex()
                            .gap_2()
                            .child(
                                Button::new("learn_more", "Learn more")
                                    .style(ButtonStyle::Subtle)
                                    .on_click(cx.listener(|_, _, cx| {
                                        cx.open_url(ZED_AI_URL)
                                    })))
                            .child(
                                Button::new("upgrade", "Upgrade")
                                    .style(ButtonStyle::Subtle)
                                    .color(Color::Accent)
                                    .on_click(cx.listener(|_, _, cx| {
                                        cx.open_url(ACCOUNT_SETTINGS_URL)
                                    })))
                    },
                )
        } else {
            v_flex()
                .gap_6()
                .child(Label::new("Use zed.dev to access language models."))
                .child(
                    v_flex()
                        .gap_2()
                        .child(
                            Button::new("sign_in", "Sign in")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .style(ButtonStyle::Filled)
                                .full_width()
                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                        )
                        .child(
                            div().flex().w_full().items_center().child(
                                Label::new("Sign in to enable collaboration.")
                                    .color(Color::Muted)
                                    .size(LabelSize::Small),
                            ),
                        ),
                )
        }
    }
}