// cloud.rs

  1use super::open_ai::count_open_ai_tokens;
  2use crate::{
  3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
  4    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
  5    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
  6};
  7use anyhow::{anyhow, bail, Result};
  8use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
  9use collections::BTreeMap;
 10use feature_flags::{FeatureFlagAppExt, LanguageModels};
 11use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 12use gpui::{
 13    AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
 14    Subscription, Task,
 15};
 16use http_client::{AsyncBody, HttpClient, Method, Response};
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use serde_json::value::RawValue;
 20use settings::{Settings, SettingsStore};
 21use smol::{
 22    io::BufReader,
 23    lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard},
 24};
 25use std::{future, sync::Arc};
 26use strum::IntoEnumIterator;
 27use ui::prelude::*;
 28
 29use crate::{LanguageModelAvailability, LanguageModelProvider};
 30
 31use super::anthropic::count_anthropic_tokens;
 32
/// Provider id reported for the Zed-hosted cloud provider.
pub const PROVIDER_ID: &str = "zed.dev";
/// Human-readable provider name shown in the UI.
pub const PROVIDER_NAME: &str = "Zed";
 35
/// User-configurable settings for the zed.dev provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    // Extra models (beyond the built-in ones) made available via settings.
    pub available_models: Vec<AvailableModel>,
}
 40
/// Upstream provider backing a settings-declared custom model.
///
/// Serialized in lowercase (`"anthropic"`, `"openai"`, `"google"`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
 48
/// A custom model declared in settings for the zed.dev provider.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Which upstream provider serves this model.
    provider: AvailableProvider,
    // Provider-side model name/identifier.
    name: String,
    // Maximum context size in tokens.
    max_tokens: usize,
    // Optional alternate model id used for tool calls; only consumed by the
    // Anthropic custom-model arm below.
    tool_override: Option<String>,
}
 56
/// Language-model provider backed by the Zed (zed.dev) LLM service.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    // Shared, refreshable token for the LLM API; cloned into each model.
    llm_api_token: LlmApiToken,
    state: gpui::Model<State>,
    // Background task keeping `State::status` in sync with the client.
    _maintain_client_status: Task<()>,
}
 63
/// Observable state for the zed.dev provider: connection status, the user
/// store (for terms-of-service acceptance), and any in-flight accept task.
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    status: client::Status,
    // `Some` while a terms-of-service acceptance task is running; the UI
    // uses this to disable the accept button.
    accept_terms: Option<Task<Result<()>>>,
    // Keeps the global-settings observation alive.
    _subscription: Subscription,
}
 71
impl State {
    /// Whether the user is currently signed out of the Zed client.
    fn is_signed_out(&self) -> bool {
        self.status.is_signed_out()
    }

    /// Kicks off sign-in and connection, notifying observers on completion.
    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |this, mut cx| async move {
            client.authenticate_and_connect(true, &cx).await?;
            this.update(&mut cx, |_, cx| cx.notify())
        })
    }

    /// Whether the current user has accepted the terms of service.
    /// Falls back to `false` when no answer is available.
    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
        self.user_store
            .read(cx)
            .current_user_has_accepted_terms()
            .unwrap_or(false)
    }

    /// Accepts the terms of service in the background. While the task is in
    /// flight `accept_terms` is `Some`; it is cleared (with a notify) when
    /// the task finishes. The acceptance result itself is ignored here.
    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
        let user_store = self.user_store.clone();
        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
            let _ = user_store
                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
                .await;
            this.update(&mut cx, |this, cx| {
                this.accept_terms = None;
                cx.notify()
            })
        }));
    }
}
105
impl CloudLanguageModelProvider {
    /// Creates the provider, snapshotting the current client status and
    /// spawning a background task that keeps `State::status` in sync with
    /// the client's connection status stream.
    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            user_store,
            status,
            accept_terms: None,
            // Re-notify observers whenever global settings change, so
            // settings-driven model lists refresh.
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        // Hold only a weak handle so this task does not keep the state alive.
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    _ = this.update(&mut cx, |this, cx| {
                        if this.status != status {
                            this.status = status;
                            cx.notify();
                        }
                    });
                } else {
                    // The state was dropped; stop maintaining it.
                    break;
                }
            }
        });

        Self {
            client,
            state,
            llm_api_token: LlmApiToken::default(),
            _maintain_client_status: maintain_client_status,
        }
    }
}
145
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's state model so callers can observe changes.
    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
153
154impl LanguageModelProvider for CloudLanguageModelProvider {
155    fn id(&self) -> LanguageModelProviderId {
156        LanguageModelProviderId(PROVIDER_ID.into())
157    }
158
159    fn name(&self) -> LanguageModelProviderName {
160        LanguageModelProviderName(PROVIDER_NAME.into())
161    }
162
163    fn icon(&self) -> IconName {
164        IconName::AiZed
165    }
166
167    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
168        let mut models = BTreeMap::default();
169
170        let is_user = !cx.has_flag::<LanguageModels>();
171        if is_user {
172            models.insert(
173                anthropic::Model::Claude3_5Sonnet.id().to_string(),
174                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
175            );
176        } else {
177            for model in anthropic::Model::iter() {
178                if !matches!(model, anthropic::Model::Custom { .. }) {
179                    models.insert(model.id().to_string(), CloudModel::Anthropic(model));
180                }
181            }
182            for model in open_ai::Model::iter() {
183                if !matches!(model, open_ai::Model::Custom { .. }) {
184                    models.insert(model.id().to_string(), CloudModel::OpenAi(model));
185                }
186            }
187            for model in google_ai::Model::iter() {
188                if !matches!(model, google_ai::Model::Custom { .. }) {
189                    models.insert(model.id().to_string(), CloudModel::Google(model));
190                }
191            }
192            for model in ZedModel::iter() {
193                models.insert(model.id().to_string(), CloudModel::Zed(model));
194            }
195
196            // Override with available models from settings
197            for model in &AllLanguageModelSettings::get_global(cx)
198                .zed_dot_dev
199                .available_models
200            {
201                let model = match model.provider {
202                    AvailableProvider::Anthropic => {
203                        CloudModel::Anthropic(anthropic::Model::Custom {
204                            name: model.name.clone(),
205                            max_tokens: model.max_tokens,
206                            tool_override: model.tool_override.clone(),
207                        })
208                    }
209                    AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
210                        name: model.name.clone(),
211                        max_tokens: model.max_tokens,
212                    }),
213                    AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
214                        name: model.name.clone(),
215                        max_tokens: model.max_tokens,
216                    }),
217                };
218                models.insert(model.id().to_string(), model.clone());
219            }
220        }
221
222        models
223            .into_values()
224            .map(|model| {
225                Arc::new(CloudLanguageModel {
226                    id: LanguageModelId::from(model.id().to_string()),
227                    model,
228                    llm_api_token: self.llm_api_token.clone(),
229                    client: self.client.clone(),
230                    request_limiter: RateLimiter::new(4),
231                }) as Arc<dyn LanguageModel>
232            })
233            .collect()
234    }
235
236    fn is_authenticated(&self, cx: &AppContext) -> bool {
237        !self.state.read(cx).is_signed_out()
238    }
239
240    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
241        Task::ready(Ok(()))
242    }
243
244    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
245        cx.new_view(|_cx| ConfigurationView {
246            state: self.state.clone(),
247        })
248        .into()
249    }
250
251    fn must_accept_terms(&self, cx: &AppContext) -> bool {
252        !self.state.read(cx).has_accepted_terms_of_service(cx)
253    }
254
255    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
256        let state = self.state.read(cx);
257
258        let terms = [(
259            "anthropic_terms_of_service",
260            "Anthropic Terms of Service",
261            "https://www.anthropic.com/legal/consumer-terms",
262        )]
263        .map(|(id, label, url)| {
264            Button::new(id, label)
265                .style(ButtonStyle::Subtle)
266                .icon(IconName::ExternalLink)
267                .icon_size(IconSize::XSmall)
268                .icon_color(Color::Muted)
269                .on_click(move |_, cx| cx.open_url(url))
270        });
271
272        if state.has_accepted_terms_of_service(cx) {
273            None
274        } else {
275            let disabled = state.accept_terms.is_some();
276            Some(
277                v_flex()
278                    .child(Label::new("Terms & Conditions").weight(FontWeight::SEMIBOLD))
279                    .child("Please read and accept the terms and conditions of Zed AI and our provider partners to continue.")
280                    .child(v_flex().m_2().gap_1().children(terms))
281                    .child(
282                        h_flex().justify_end().mt_1().child(
283                            Button::new("accept_terms", "Accept")
284                                .disabled(disabled)
285                                .on_click({
286                                    let state = self.state.downgrade();
287                                    move |_, cx| {
288                                        state
289                                            .update(cx, |state, cx| {
290                                                state.accept_terms_of_service(cx)
291                                            })
292                                            .ok();
293                                    }
294                                }),
295                        ),
296                    )
297                    .into_any(),
298            )
299        }
300    }
301
302    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
303        Task::ready(Ok(()))
304    }
305}
306
/// A single model served through the Zed LLM service.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    // Shared with the provider so a refresh benefits all models.
    llm_api_token: LlmApiToken,
    client: Arc<Client>,
    // Limits this model to 4 concurrent requests (see `provided_models`).
    request_limiter: RateLimiter,
}
314
/// Bearer token for the Zed LLM service, shared behind an `RwLock` so
/// concurrent requests can read it and a single refresh can replace it.
/// (`acquire`/`refresh` are implemented elsewhere in this file.)
#[derive(Clone, Default)]
struct LlmApiToken(Arc<RwLock<Option<String>>>);
317
318impl CloudLanguageModel {
319    async fn perform_llm_completion(
320        client: Arc<Client>,
321        llm_api_token: LlmApiToken,
322        body: PerformCompletionParams,
323    ) -> Result<Response<AsyncBody>> {
324        let http_client = &client.http_client();
325
326        let mut token = llm_api_token.acquire(&client).await?;
327        let mut did_retry = false;
328
329        let response = loop {
330            let request = http_client::Request::builder()
331                .method(Method::POST)
332                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
333                .header("Content-Type", "application/json")
334                .header("Authorization", format!("Bearer {token}"))
335                .body(serde_json::to_string(&body)?.into())?;
336            let response = http_client.send(request).await?;
337            if response.status().is_success() {
338                break response;
339            } else if !did_retry
340                && response
341                    .headers()
342                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
343                    .is_some()
344            {
345                did_retry = true;
346                token = llm_api_token.refresh(&client).await?;
347            } else {
348                break Err(anyhow!(
349                    "cloud language model completion failed with status {}",
350                    response.status()
351                ))?;
352            }
353        };
354
355        Ok(response)
356    }
357}
358
359impl LanguageModel for CloudLanguageModel {
    /// Returns the identifier assigned to this model instance.
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }
363
    /// Returns the model's human-readable display name.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }
367
    /// Returns the zed.dev provider id.
    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }
371
    /// Returns the zed.dev provider display name.
    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }
375
376    fn telemetry_id(&self) -> String {
377        format!("zed.dev/{}", self.model.id())
378    }
379
    /// Delegates availability to the underlying cloud model.
    fn availability(&self) -> LanguageModelAvailability {
        self.model.availability()
    }
383
    /// Maximum context size, in tokens, of the underlying model.
    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }
387
    /// Estimates the token count of `request` for the underlying model.
    ///
    /// Anthropic and OpenAI counts are computed locally; Google counts are
    /// delegated to the server via a `CountLanguageModelTokens` RPC.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                // Only the contents are needed for server-side counting.
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
            CloudModel::Zed(_) => {
                // NOTE(review): Zed-hosted models approximate their token
                // usage with the gpt-3.5-turbo tokenizer — presumably close
                // enough for these models; confirm if precision matters.
                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
            }
        }
    }
419
    /// Streams completion text for `request` through the Zed LLM service.
    ///
    /// Each arm converts the request into the upstream provider's wire
    /// format, performs the completion via `perform_llm_completion`, parses
    /// the response body line by line into provider-specific events, and
    /// extracts the text deltas. Requests are throttled by
    /// `request_limiter`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let request = request.into_anthropic(model.id().into());
                let client = self.client.clone();
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Anthropic,
                            model: request.model.clone(),
                            // The provider-specific request is forwarded as
                            // raw JSON for the server to relay.
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    // The body is a stream of newline-delimited JSON events;
                    // a read of 0 bytes means end of stream.
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: anthropic::Event = serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(anthropic::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::OpenAi,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: open_ai::ResponseStreamEvent =
                                    serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(open_ai::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Google,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: google_ai::GenerateContentResponse =
                                    serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(google_ai::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Zed(model) => {
                // Zed-hosted models speak the OpenAI wire format.
                let client = self.client.clone();
                let mut request = request.into_open_ai(model.id().into());
                // Cap output length for Zed-hosted models.
                request.max_tokens = Some(4000);
                let llm_api_token = self.llm_api_token.clone();
                let future = self.request_limiter.stream(async move {
                    let response = Self::perform_llm_completion(
                        client.clone(),
                        llm_api_token,
                        PerformCompletionParams {
                            provider: client::LanguageModelProvider::Zed,
                            model: request.model.clone(),
                            provider_request: RawValue::from_string(serde_json::to_string(
                                &request,
                            )?)?,
                        },
                    )
                    .await?;
                    let body = BufReader::new(response.into_body());
                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
                        let mut buffer = String::new();
                        match body.read_line(&mut buffer).await {
                            Ok(0) => Ok(None),
                            Ok(_) => {
                                let event: open_ai::ResponseStreamEvent =
                                    serde_json::from_str(&buffer)?;
                                Ok(Some((event, body)))
                            }
                            Err(e) => Err(e.into()),
                        }
                    });

                    Ok(open_ai::extract_text_from_events(stream))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }
568
569    fn use_any_tool(
570        &self,
571        request: LanguageModelRequest,
572        tool_name: String,
573        tool_description: String,
574        input_schema: serde_json::Value,
575        _cx: &AsyncAppContext,
576    ) -> BoxFuture<'static, Result<serde_json::Value>> {
577        match &self.model {
578            CloudModel::Anthropic(model) => {
579                let client = self.client.clone();
580                let mut request = request.into_anthropic(model.tool_model_id().into());
581                request.tool_choice = Some(anthropic::ToolChoice::Tool {
582                    name: tool_name.clone(),
583                });
584                request.tools = vec![anthropic::Tool {
585                    name: tool_name.clone(),
586                    description: tool_description,
587                    input_schema,
588                }];
589
590                let llm_api_token = self.llm_api_token.clone();
591                self.request_limiter
592                    .run(async move {
593                        let response = Self::perform_llm_completion(
594                            client.clone(),
595                            llm_api_token,
596                            PerformCompletionParams {
597                                provider: client::LanguageModelProvider::Anthropic,
598                                model: request.model.clone(),
599                                provider_request: RawValue::from_string(serde_json::to_string(
600                                    &request,
601                                )?)?,
602                            },
603                        )
604                        .await?;
605
606                        let mut tool_use_index = None;
607                        let mut tool_input = String::new();
608                        let mut body = BufReader::new(response.into_body());
609                        let mut line = String::new();
610                        while body.read_line(&mut line).await? > 0 {
611                            let event: anthropic::Event = serde_json::from_str(&line)?;
612                            line.clear();
613
614                            match event {
615                                anthropic::Event::ContentBlockStart {
616                                    content_block,
617                                    index,
618                                } => {
619                                    if let anthropic::Content::ToolUse { name, .. } = content_block
620                                    {
621                                        if name == tool_name {
622                                            tool_use_index = Some(index);
623                                        }
624                                    }
625                                }
626                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
627                                {
628                                    anthropic::ContentDelta::TextDelta { .. } => {}
629                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
630                                        if Some(index) == tool_use_index {
631                                            tool_input.push_str(&partial_json);
632                                        }
633                                    }
634                                },
635                                anthropic::Event::ContentBlockStop { index } => {
636                                    if Some(index) == tool_use_index {
637                                        return Ok(serde_json::from_str(&tool_input)?);
638                                    }
639                                }
640                                _ => {}
641                            }
642                        }
643
644                        if tool_use_index.is_some() {
645                            Err(anyhow!("tool content incomplete"))
646                        } else {
647                            Err(anyhow!("tool not used"))
648                        }
649                    })
650                    .boxed()
651            }
652            CloudModel::OpenAi(model) => {
653                let mut request = request.into_open_ai(model.id().into());
654                let client = self.client.clone();
655                let mut function = open_ai::FunctionDefinition {
656                    name: tool_name.clone(),
657                    description: None,
658                    parameters: None,
659                };
660                let func = open_ai::ToolDefinition::Function {
661                    function: function.clone(),
662                };
663                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
664                // Fill in description and params separately, as they're not needed for tool_choice field.
665                function.description = Some(tool_description);
666                function.parameters = Some(input_schema);
667                request.tools = vec![open_ai::ToolDefinition::Function { function }];
668
669                let llm_api_token = self.llm_api_token.clone();
670                self.request_limiter
671                    .run(async move {
672                        let response = Self::perform_llm_completion(
673                            client.clone(),
674                            llm_api_token,
675                            PerformCompletionParams {
676                                provider: client::LanguageModelProvider::OpenAi,
677                                model: request.model.clone(),
678                                provider_request: RawValue::from_string(serde_json::to_string(
679                                    &request,
680                                )?)?,
681                            },
682                        )
683                        .await?;
684
685                        let mut body = BufReader::new(response.into_body());
686                        let mut line = String::new();
687                        let mut load_state = None;
688
689                        while body.read_line(&mut line).await? > 0 {
690                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
691                            line.clear();
692
693                            for choice in part.choices {
694                                let Some(tool_calls) = choice.delta.tool_calls else {
695                                    continue;
696                                };
697
698                                for call in tool_calls {
699                                    if let Some(func) = call.function {
700                                        if func.name.as_deref() == Some(tool_name.as_str()) {
701                                            load_state = Some((String::default(), call.index));
702                                        }
703                                        if let Some((arguments, (output, index))) =
704                                            func.arguments.zip(load_state.as_mut())
705                                        {
706                                            if call.index == *index {
707                                                output.push_str(&arguments);
708                                            }
709                                        }
710                                    }
711                                }
712                            }
713                        }
714
715                        if let Some((arguments, _)) = load_state {
716                            return Ok(serde_json::from_str(&arguments)?);
717                        } else {
718                            bail!("tool not used");
719                        }
720                    })
721                    .boxed()
722            }
723            CloudModel::Google(_) => {
724                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
725            }
726            CloudModel::Zed(model) => {
727                // All Zed models are OpenAI-based at the time of writing.
728                let mut request = request.into_open_ai(model.id().into());
729                let client = self.client.clone();
730                let mut function = open_ai::FunctionDefinition {
731                    name: tool_name.clone(),
732                    description: None,
733                    parameters: None,
734                };
735                let func = open_ai::ToolDefinition::Function {
736                    function: function.clone(),
737                };
738                request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
739                // Fill in description and params separately, as they're not needed for tool_choice field.
740                function.description = Some(tool_description);
741                function.parameters = Some(input_schema);
742                request.tools = vec![open_ai::ToolDefinition::Function { function }];
743
744                let llm_api_token = self.llm_api_token.clone();
745                self.request_limiter
746                    .run(async move {
747                        let response = Self::perform_llm_completion(
748                            client.clone(),
749                            llm_api_token,
750                            PerformCompletionParams {
751                                provider: client::LanguageModelProvider::Zed,
752                                model: request.model.clone(),
753                                provider_request: RawValue::from_string(serde_json::to_string(
754                                    &request,
755                                )?)?,
756                            },
757                        )
758                        .await?;
759
760                        let mut body = BufReader::new(response.into_body());
761                        let mut line = String::new();
762                        let mut load_state = None;
763
764                        while body.read_line(&mut line).await? > 0 {
765                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
766                            line.clear();
767
768                            for choice in part.choices {
769                                let Some(tool_calls) = choice.delta.tool_calls else {
770                                    continue;
771                                };
772
773                                for call in tool_calls {
774                                    if let Some(func) = call.function {
775                                        if func.name.as_deref() == Some(tool_name.as_str()) {
776                                            load_state = Some((String::default(), call.index));
777                                        }
778                                        if let Some((arguments, (output, index))) =
779                                            func.arguments.zip(load_state.as_mut())
780                                        {
781                                            if call.index == *index {
782                                                output.push_str(&arguments);
783                                            }
784                                        }
785                                    }
786                                }
787                            }
788                        }
789                        if let Some((arguments, _)) = load_state {
790                            return Ok(serde_json::from_str(&arguments)?);
791                        } else {
792                            bail!("tool not used");
793                        }
794                    })
795                    .boxed()
796            }
797        }
798    }
799}
800
801impl LlmApiToken {
802    async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
803        let lock = self.0.upgradable_read().await;
804        if let Some(token) = lock.as_ref() {
805            Ok(token.to_string())
806        } else {
807            Self::fetch(RwLockUpgradableReadGuard::upgrade(lock).await, &client).await
808        }
809    }
810
811    async fn refresh(&self, client: &Arc<Client>) -> Result<String> {
812        Self::fetch(self.0.write().await, &client).await
813    }
814
815    async fn fetch<'a>(
816        mut lock: RwLockWriteGuard<'a, Option<String>>,
817        client: &Arc<Client>,
818    ) -> Result<String> {
819        let response = client.request(proto::GetLlmToken {}).await?;
820        *lock = Some(response.token.clone());
821        Ok(response.token.clone())
822    }
823}
824
/// Settings-panel view for the zed.dev language model provider, showing either
/// sign-in UI or the current plan / subscription controls.
struct ConfigurationView {
    // Shared provider state; read for sign-in status, the user store's current
    // plan, and terms-of-service acceptance when rendering.
    state: gpui::Model<State>,
}
828
829impl ConfigurationView {
830    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
831        self.state.update(cx, |state, cx| {
832            state.authenticate(cx).detach_and_log_err(cx);
833        });
834        cx.notify();
835    }
836}
837
838impl Render for ConfigurationView {
839    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
840        const ZED_AI_URL: &str = "https://zed.dev/ai";
841        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/account";
842
843        let is_connected = !self.state.read(cx).is_signed_out();
844        let plan = self.state.read(cx).user_store.read(cx).current_plan();
845        let must_accept_terms = !self.state.read(cx).has_accepted_terms_of_service(cx);
846
847        let is_pro = plan == Some(proto::Plan::ZedPro);
848
849        if is_connected {
850            v_flex()
851                .gap_3()
852                .max_w_4_5()
853                .when(must_accept_terms, |this| {
854                    this.child(Label::new(
855                        "You must accept the terms of service to use this provider.",
856                    ))
857                })
858                .child(Label::new(
859                    if is_pro {
860                        "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
861                    } else {
862                        "You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan."
863                    }))
864                .child(
865                    if is_pro {
866                        h_flex().child(
867                        Button::new("manage_settings", "Manage Subscription")
868                            .style(ButtonStyle::Filled)
869                            .on_click(cx.listener(|_, _, cx| {
870                                cx.open_url(ACCOUNT_SETTINGS_URL)
871                            })))
872                    } else {
873                        h_flex()
874                            .gap_2()
875                            .child(
876                        Button::new("learn_more", "Learn more")
877                            .style(ButtonStyle::Subtle)
878                            .on_click(cx.listener(|_, _, cx| {
879                                cx.open_url(ZED_AI_URL)
880                            })))
881                            .child(
882                        Button::new("upgrade", "Upgrade")
883                            .style(ButtonStyle::Subtle)
884                            .color(Color::Accent)
885                            .on_click(cx.listener(|_, _, cx| {
886                                cx.open_url(ACCOUNT_SETTINGS_URL)
887                            })))
888                    },
889                )
890        } else {
891            v_flex()
892                .gap_6()
893                .child(Label::new("Use the zed.dev to access language models."))
894                .child(
895                    v_flex()
896                        .gap_2()
897                        .child(
898                            Button::new("sign_in", "Sign in")
899                                .icon_color(Color::Muted)
900                                .icon(IconName::Github)
901                                .icon_position(IconPosition::Start)
902                                .style(ButtonStyle::Filled)
903                                .full_width()
904                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
905                        )
906                        .child(
907                            div().flex().w_full().items_center().child(
908                                Label::new("Sign in to enable collaboration.")
909                                    .color(Color::Muted)
910                                    .size(LabelSize::Small),
911                            ),
912                        ),
913                )
914        }
915    }
916}