cloud.rs

  1use super::open_ai::count_open_ai_tokens;
  2use crate::{
  3    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
  4    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
  5    LanguageModelProviderState, LanguageModelRequest,
  6};
  7use anyhow::Result;
  8use client::Client;
  9use collections::BTreeMap;
 10use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 11use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use settings::{Settings, SettingsStore};
 15use std::sync::Arc;
 16use strum::IntoEnumIterator;
 17use ui::prelude::*;
 18
 19use crate::LanguageModelProvider;
 20
 21use super::anthropic::count_anthropic_tokens;
 22
/// Provider id reported for models served through zed.dev.
pub const PROVIDER_ID: &str = "zed.dev";
/// Human-readable provider name (identical to the id for this provider).
pub const PROVIDER_NAME: &str = "zed.dev";
 25
/// User settings for the zed.dev cloud language-model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
    /// Custom models declared in settings; merged over the built-in model
    /// list by `provided_models`.
    pub available_models: Vec<AvailableModel>,
}
 30
/// Upstream provider that actually serves a settings-declared model.
/// Serialized in lowercase (e.g. `"anthropic"`, `"openai"`, `"google"`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AvailableProvider {
    Anthropic,
    OpenAi,
    Google,
}
 38
/// A custom model entry as declared in the user's `zed.dev` settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Which upstream provider serves this model.
    provider: AvailableProvider,
    // Provider-side model name/id.
    name: String,
    // Maximum context window size, in tokens.
    max_tokens: usize,
}
 45
/// Language-model provider backed by the zed.dev cloud service; all requests
/// are proxied through the signed-in `Client`.
pub struct CloudLanguageModelProvider {
    client: Arc<Client>,
    // Observable connection state shared with views (e.g. the auth prompt).
    state: gpui::Model<State>,
    // Background task that mirrors client status changes into `state`;
    // stored only to keep it alive for the provider's lifetime.
    _maintain_client_status: Task<()>,
}
 51
/// Shared provider state, held in a gpui model so that views and the
/// provider can observe connection-status changes.
struct State {
    client: Arc<Client>,
    // Latest connection status received from the client.
    status: client::Status,
    // Keeps the global-settings observation alive; notifies observers when
    // settings change.
    _subscription: Subscription,
}
 57
 58impl State {
 59    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 60        let client = self.client.clone();
 61        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
 62    }
 63}
 64
impl CloudLanguageModelProvider {
    /// Creates the provider, capturing the client's current connection status
    /// and spawning a background task that mirrors subsequent status changes
    /// into the provider's `State` model.
    pub fn new(client: Arc<Client>, cx: &mut AppContext) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();

        let state = cx.new_model(|cx| State {
            client: client.clone(),
            status,
            // Re-notify observers whenever global settings change (e.g. the
            // user edits `available_models`).
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        // Hold only a weak handle so this background task doesn't keep the
        // state model alive; stop looping once the model has been dropped.
        let state_ref = state.downgrade();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                if let Some(this) = state_ref.upgrade() {
                    // Ignore update errors (the app may be shutting down).
                    _ = this.update(&mut cx, |this, cx| {
                        this.status = status;
                        cx.notify();
                    });
                } else {
                    break;
                }
            }
        });

        Self {
            client,
            state,
            _maintain_client_status: maintain_client_status,
        }
    }
}
 99
100impl LanguageModelProviderState for CloudLanguageModelProvider {
101    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
102        Some(cx.observe(&self.state, |_, _, cx| {
103            cx.notify();
104        }))
105    }
106}
107
108impl LanguageModelProvider for CloudLanguageModelProvider {
109    fn id(&self) -> LanguageModelProviderId {
110        LanguageModelProviderId(PROVIDER_ID.into())
111    }
112
113    fn name(&self) -> LanguageModelProviderName {
114        LanguageModelProviderName(PROVIDER_NAME.into())
115    }
116
117    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
118        let mut models = BTreeMap::default();
119
120        for model in anthropic::Model::iter() {
121            if !matches!(model, anthropic::Model::Custom { .. }) {
122                models.insert(model.id().to_string(), CloudModel::Anthropic(model));
123            }
124        }
125        for model in open_ai::Model::iter() {
126            if !matches!(model, open_ai::Model::Custom { .. }) {
127                models.insert(model.id().to_string(), CloudModel::OpenAi(model));
128            }
129        }
130        for model in google_ai::Model::iter() {
131            if !matches!(model, google_ai::Model::Custom { .. }) {
132                models.insert(model.id().to_string(), CloudModel::Google(model));
133            }
134        }
135
136        // Override with available models from settings
137        for model in &AllLanguageModelSettings::get_global(cx)
138            .zed_dot_dev
139            .available_models
140        {
141            let model = match model.provider {
142                AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
143                    name: model.name.clone(),
144                    max_tokens: model.max_tokens,
145                }),
146                AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
147                    name: model.name.clone(),
148                    max_tokens: model.max_tokens,
149                }),
150                AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
151                    name: model.name.clone(),
152                    max_tokens: model.max_tokens,
153                }),
154            };
155            models.insert(model.id().to_string(), model.clone());
156        }
157
158        models
159            .into_values()
160            .map(|model| {
161                Arc::new(CloudLanguageModel {
162                    id: LanguageModelId::from(model.id().to_string()),
163                    model,
164                    client: self.client.clone(),
165                }) as Arc<dyn LanguageModel>
166            })
167            .collect()
168    }
169
170    fn is_authenticated(&self, cx: &AppContext) -> bool {
171        self.state.read(cx).status.is_connected()
172    }
173
174    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
175        self.state.read(cx).authenticate(cx)
176    }
177
178    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
179        cx.new_view(|_cx| AuthenticationPrompt {
180            state: self.state.clone(),
181        })
182        .into()
183    }
184
185    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
186        Task::ready(Ok(()))
187    }
188}
189
/// A single model exposed by the zed.dev provider; wraps a `CloudModel`
/// variant and the client used to proxy requests through the server.
pub struct CloudLanguageModel {
    id: LanguageModelId,
    model: CloudModel,
    client: Arc<Client>,
}
195
impl LanguageModel for CloudLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    /// Identifier used in telemetry, e.g. `zed.dev/<model-id>`.
    fn telemetry_id(&self) -> String {
        format!("zed.dev/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Counts the tokens `request` would consume. Anthropic and OpenAI counts
    /// come from local helpers (`count_anthropic_tokens` /
    /// `count_open_ai_tokens`); Google counts are obtained by sending a
    /// `CountTokens` query through the zed.dev server.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match self.model.clone() {
            CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
            CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
            CloudModel::Google(model) => {
                let client = self.client.clone();
                // Convert to Google's wire format, keeping only the contents
                // (the CountTokens endpoint doesn't need the full request).
                let request = request.into_google(model.id().into());
                let request = google_ai::CountTokensRequest {
                    contents: request.contents,
                };
                async move {
                    // The server expects the provider-specific request as a
                    // JSON string inside the proto envelope.
                    let request = serde_json::to_string(&request)?;
                    let response = client.request(proto::QueryLanguageModel {
                        provider: proto::LanguageModelProvider::Google as i32,
                        kind: proto::LanguageModelRequestKind::CountTokens as i32,
                        request,
                    });
                    let response = response.await?;
                    let response =
                        serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
                    Ok(response.total_tokens)
                }
                .boxed()
            }
        }
    }

    /// Streams a completion for `request` by proxying it through the zed.dev
    /// server. Each arm serializes the provider-specific request to JSON,
    /// sends it in a proto envelope, deserializes each streamed chunk back
    /// into the provider's event type, and extracts the text deltas.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        _: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        match &self.model {
            CloudModel::Anthropic(model) => {
                let client = self.client.clone();
                let request = request.into_anthropic(model.id().into());
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client.request_stream(proto::QueryLanguageModel {
                        provider: proto::LanguageModelProvider::Anthropic as i32,
                        kind: proto::LanguageModelRequestKind::Complete as i32,
                        request,
                    });
                    let chunks = response.await?;
                    Ok(anthropic::extract_text_from_events(
                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
                    )
                    .boxed())
                }
                .boxed()
            }
            CloudModel::OpenAi(model) => {
                let client = self.client.clone();
                let request = request.into_open_ai(model.id().into());
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client.request_stream(proto::QueryLanguageModel {
                        provider: proto::LanguageModelProvider::OpenAi as i32,
                        kind: proto::LanguageModelRequestKind::Complete as i32,
                        request,
                    });
                    let chunks = response.await?;
                    Ok(open_ai::extract_text_from_events(
                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
                    )
                    .boxed())
                }
                .boxed()
            }
            CloudModel::Google(model) => {
                let client = self.client.clone();
                let request = request.into_google(model.id().into());
                async move {
                    let request = serde_json::to_string(&request)?;
                    let response = client.request_stream(proto::QueryLanguageModel {
                        provider: proto::LanguageModelProvider::Google as i32,
                        kind: proto::LanguageModelRequestKind::Complete as i32,
                        request,
                    });
                    let chunks = response.await?;
                    Ok(google_ai::extract_text_from_events(
                        chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
                    )
                    .boxed())
                }
                .boxed()
            }
        }
    }
}
315
/// Sign-in view shown by `authentication_prompt` when the user isn't
/// connected to zed.dev.
struct AuthenticationPrompt {
    state: gpui::Model<State>,
}
319
impl Render for AuthenticationPrompt {
    /// Renders an explanatory label plus a "Sign in" button that triggers
    /// `State::authenticate` when clicked.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";

        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
            v_flex()
                .gap_2()
                .child(
                    Button::new("sign_in", "Sign in")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .style(ButtonStyle::Filled)
                        .full_width()
                        .on_click(cx.listener(move |this, _, cx| {
                            // Kick off authentication in the background and
                            // log (rather than surface) any error.
                            this.state.update(cx, |provider, cx| {
                                provider.authenticate(cx).detach_and_log_err(cx);
                                cx.notify();
                            });
                        })),
                )
                .child(
                    div().flex().w_full().items_center().child(
                        Label::new("Sign in to enable collaboration.")
                            .color(Color::Muted)
                            .size(LabelSize::Small),
                    ),
                ),
        )
    }
}