cloud.rs

  1use super::open_ai::count_open_ai_tokens;
  2use crate::{
  3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
  4    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
  5    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
  6};
  7use anyhow::{anyhow, Context as _, Result};
  8use client::{Client, UserStore};
  9use collections::BTreeMap;
 10use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 11use gpui::{
 12    AnyView, AppContext, AsyncAppContext, FocusHandle, Model, ModelContext, Subscription, Task,
 13};
 14use schemars::JsonSchema;
 15use serde::{Deserialize, Serialize};
 16use settings::{Settings, SettingsStore};
 17use std::{future, sync::Arc};
 18use strum::IntoEnumIterator;
 19use ui::prelude::*;
 20
 21use crate::LanguageModelProvider;
 22
 23use super::anthropic::count_anthropic_tokens;
 24
/// Stable identifier for the zed.dev-hosted language model provider.
pub const PROVIDER_ID: &str = "zed.dev";
/// Human-readable provider name shown in the UI.
pub const PROVIDER_NAME: &str = "Zed AI";
 27
/// User-configurable settings for the zed.dev provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    // Custom models declared in settings; merged over the built-in model
    // lists by `provided_models`.
    pub available_models: Vec<AvailableModel>,
}
 32
/// Upstream provider that actually serves a settings-declared custom model.
/// Serialized in lowercase (e.g. "anthropic", "openai", "google").
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
 40
/// One custom model entry from user settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Which upstream service serves this model.
    provider: AvailableProvider,
    // Model id/name as understood by the upstream provider.
    name: String,
    // Maximum token window the model accepts.
    max_tokens: usize,
    // Alternate model id used for tool calls; only forwarded for Anthropic
    // models in `provided_models` (OpenAI/Google custom variants don't take it).
    tool_override: Option<String>,
}
 48
/// Language model provider backed by models hosted through zed.dev.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    state: gpui::Model<State>,
    // Background task that mirrors client connection-status changes into `state`.
    _maintain_client_status: Task<()>,
}
 54
/// Shared, observable connection state for the zed.dev provider.
pub struct State {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    // Latest known client connection status; kept current by the status task
    // spawned in `CloudLanguageModelProvider::new`.
    status: client::Status,
    // Keeps the settings-store observation alive for this model's lifetime.
    _subscription: Subscription,
}
 61
 62impl State {
 63    fn is_connected(&self) -> bool {
 64        self.status.is_connected()
 65    }
 66
 67    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 68        let client = self.client.clone();
 69        cx.spawn(move |this, mut cx| async move {
 70            client.authenticate_and_connect(true, &cx).await?;
 71            this.update(&mut cx, |_, cx| cx.notify())
 72        })
 73    }
 74}
 75
 76impl CloudLanguageModelProvider {
 77    pub fn new(user_store: Model<UserStore>, client: Arc<Client>, cx: &mut AppContext) -> Self {
 78        let mut status_rx = client.status();
 79        let status = *status_rx.borrow();
 80
 81        let state = cx.new_model(|cx| State {
 82            client: client.clone(),
 83            user_store,
 84            status,
 85            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 86                cx.notify();
 87            }),
 88        });
 89
 90        let state_ref = state.downgrade();
 91        let maintain_client_status = cx.spawn(|mut cx| async move {
 92            while let Some(status) = status_rx.next().await {
 93                if let Some(this) = state_ref.upgrade() {
 94                    _ = this.update(&mut cx, |this, cx| {
 95                        if this.status != status {
 96                            this.status = status;
 97                            cx.notify();
 98                        }
 99                    });
100                } else {
101                    break;
102                }
103            }
104        });
105
106        Self {
107            client,
108            state,
109            _maintain_client_status: maintain_client_status,
110        }
111    }
112}
113
impl LanguageModelProviderState for CloudLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the shared `State` model so callers can observe connection
    /// and settings changes.
    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
121
122impl LanguageModelProvider for CloudLanguageModelProvider {
123    fn id(&self) -> LanguageModelProviderId {
124        LanguageModelProviderId(PROVIDER_ID.into())
125    }
126
127    fn name(&self) -> LanguageModelProviderName {
128        LanguageModelProviderName(PROVIDER_NAME.into())
129    }
130
131    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
132        let mut models = BTreeMap::default();
133
134        for model in anthropic::Model::iter() {
135            if !matches!(model, anthropic::Model::Custom { .. }) {
136                models.insert(model.id().to_string(), CloudModel::Anthropic(model));
137            }
138        }
139        for model in open_ai::Model::iter() {
140            if !matches!(model, open_ai::Model::Custom { .. }) {
141                models.insert(model.id().to_string(), CloudModel::OpenAi(model));
142            }
143        }
144        for model in google_ai::Model::iter() {
145            if !matches!(model, google_ai::Model::Custom { .. }) {
146                models.insert(model.id().to_string(), CloudModel::Google(model));
147            }
148        }
149
150        // Override with available models from settings
151        for model in &AllLanguageModelSettings::get_global(cx)
152            .zed_dot_dev
153            .available_models
154        {
155            let model = match model.provider {
156                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
157                    name: model.name.clone(),
158                    max_tokens: model.max_tokens,
159                    tool_override: model.tool_override.clone(),
160                }),
161                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
162                    name: model.name.clone(),
163                    max_tokens: model.max_tokens,
164                }),
165                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
166                    name: model.name.clone(),
167                    max_tokens: model.max_tokens,
168                }),
169            };
170            models.insert(model.id().to_string(), model.clone());
171        }
172
173        models
174            .into_values()
175            .map(|model| {
176                Arc::new(CloudLanguageModel {
177                    id: LanguageModelId::from(model.id().to_string()),
178                    model,
179                    client: self.client.clone(),
180                    request_limiter: RateLimiter::new(4),
181                }) as Arc<dyn LanguageModel>
182            })
183            .collect()
184    }
185
186    fn is_authenticated(&self, cx: &AppContext) -> bool {
187        self.state.read(cx).status.is_connected()
188    }
189
190    fn authenticate(&self, _cx: &mut AppContext) -> Task<Result<()>> {
191        Task::ready(Ok(()))
192    }
193
194    fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option<FocusHandle>) {
195        let view = cx
196            .new_view(|_cx| ConfigurationView {
197                state: self.state.clone(),
198            })
199            .into();
200        (view, None)
201    }
202
203    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
204        Task::ready(Ok(()))
205    }
206}
207
/// A single model exposed through the zed.dev provider.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    client: Arc<Client>,
    // Throttles requests to this model (constructed with a limit of 4 in
    // `provided_models`; see `RateLimiter` for exact semantics).
    request_limiter: RateLimiter,
}
214
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Counts the tokens `request` would consume on this model.
    /// Anthropic and OpenAI counts go through the local helpers; Google
    /// counts are obtained by sending a `CountLanguageModelTokens` RPC
    /// through the zed.dev server.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                // Only the contents are needed for counting.
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    // The proto message carries the provider request as a JSON string.
                    let request = serde_json::to_string(&request)?;
                    let response = client
                        .request(proto::CountLanguageModelTokens {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
        }
    }

    /// Streams completion text for `request`. Each arm serializes the
    /// provider-specific request to JSON, proxies it through zed.dev's
    /// `StreamCompleteWithLanguageModel` RPC (subject to `request_limiter`),
    /// and extracts plain text from the provider's event stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let request = request.into_anthropic(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Anthropic as i32,
                            request,
                        })
                        .await?;
                    // Each streamed item wraps a JSON-encoded provider event.
                    Ok(anthropic::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::OpenAi as i32,
                            request,
                        })
                        .await?;
                    Ok(open_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                let future = self.request_limiter.stream(async move {
                    let request = serde_json::to_string(&request)?;
                    let stream = client
                        .request_stream(proto::StreamCompleteWithLanguageModel {
                            provider: proto::LanguageModelProvider::Google as i32,
                            request,
                        })
                        .await?;
                    Ok(google_ai::extract_text_from_events(
                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                    ))
                });
                async move { Ok(future.await?.boxed()) }.boxed()
            }
        }
    }

    /// Forces the model to invoke the named tool and returns the tool's input
    /// JSON. Implemented only for Anthropic models (using the model's
    /// `tool_model_id`); OpenAI and Google arms return an error.
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        input_schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let mut request = request.into_anthropic(model.tool_model_id().into());
                // Constrain the model to the single requested tool.
                request.tool_choice = Some(anthropic::ToolChoice::Tool {
                    name: tool_name.clone(),
                });
                request.tools = vec![anthropic::Tool {
                    name: tool_name.clone(),
                    description: tool_description,
                    input_schema,
                }];

                self.request_limiter
                    .run(async move {
                        let request = serde_json::to_string(&request)?;
                        let response = client
                            .request(proto::CompleteWithLanguageModel {
                                provider: proto::LanguageModelProvider::Anthropic as i32,
                                request,
                            })
                            .await?;
                        let response: anthropic::Response =
                            serde_json::from_str(&response.completion)?;
                        // Pull out the ToolUse content block matching the
                        // requested tool; error if the model never used it.
                        response
                            .content
                            .into_iter()
                            .find_map(|content| {
                                if let anthropic::Content::ToolUse { name, input, .. } = content {
                                    if name == tool_name {
                                        Some(input)
                                    } else {
                                        None
                                    }
                                } else {
                                    None
                                }
                            })
                            .context("tool not used")
                    })
                    .boxed()
            }
            CloudModel::OpenAi(_) => {
                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
            }
            CloudModel::Google(_) => {
                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
            }
        }
    }
}
388
/// Settings-panel view for signing in to and managing the zed.dev provider.
struct ConfigurationView {
    state: gpui::Model<State>,
}
392
393impl ConfigurationView {
394    fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
395        self.state.update(cx, |state, cx| {
396            state.authenticate(cx).detach_and_log_err(cx);
397        });
398        cx.notify();
399    }
400}
401
402impl Render for ConfigurationView {
403    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
404        const ZED_AI_URL: &str = "https://zed.dev/ai";
405        const ACCOUNT_SETTINGS_URL: &str = "https://zed.dev/settings";
406
407        let is_connected = self.state.read(cx).is_connected();
408        let plan = self.state.read(cx).user_store.read(cx).current_plan();
409
410        let is_pro = plan == Some(proto::Plan::ZedPro);
411
412        if is_connected {
413            v_flex()
414                .gap_3()
415                .max_w_4_5()
416                .child(Label::new(
417                    if is_pro {
418                        "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
419                    } else {
420                        "You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan."
421                    }))
422                .child(
423                    if is_pro {
424                        h_flex().child(
425                        Button::new("manage_settings", "Manage Subscription")
426                            .style(ButtonStyle::Filled)
427                            .on_click(cx.listener(|_, _, cx| {
428                                cx.open_url(ACCOUNT_SETTINGS_URL)
429                            })))
430                    } else {
431                        h_flex()
432                            .gap_2()
433                            .child(
434                        Button::new("learn_more", "Learn more")
435                            .style(ButtonStyle::Subtle)
436                            .on_click(cx.listener(|_, _, cx| {
437                                cx.open_url(ZED_AI_URL)
438                            })))
439                            .child(
440                        Button::new("upgrade", "Upgrade")
441                            .style(ButtonStyle::Subtle)
442                            .color(Color::Accent)
443                            .on_click(cx.listener(|_, _, cx| {
444                                cx.open_url(ACCOUNT_SETTINGS_URL)
445                            })))
446                    },
447                )
448        } else {
449            v_flex()
450                .gap_6()
451                .child(Label::new("Use the zed.dev to access language models."))
452                .child(
453                    v_flex()
454                        .gap_2()
455                        .child(
456                            Button::new("sign_in", "Sign in")
457                                .icon_color(Color::Muted)
458                                .icon(IconName::Github)
459                                .icon_position(IconPosition::Start)
460                                .style(ButtonStyle::Filled)
461                                .full_width()
462                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
463                        )
464                        .child(
465                            div().flex().w_full().items_center().child(
466                                Label::new("Sign in to enable collaboration.")
467                                    .color(Color::Muted)
468                                    .size(LabelSize::Small),
469                            ),
470                        ),
471                )
472        }
473    }
474}