zed.rs

  1use crate::{
  2    assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
  3    LanguageModelRequest,
  4};
  5use anyhow::{anyhow, Result};
  6use client::{proto, Client};
  7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
  8use gpui::{AnyView, AppContext, Task};
  9use std::{future, sync::Arc};
 10use ui::prelude::*;
 11
/// Completion provider that forwards language-model requests to zed.dev over the RPC
/// [`Client`].
pub struct ZedDotDevCompletionProvider {
    /// Connection used for completion and token-counting requests.
    client: Arc<Client>,
    /// Model reported by `default_model()`; replaced by `update()`.
    default_model: ZedDotDevModel,
    /// Settings version this provider was created or last updated with.
    settings_version: usize,
    /// Most recently observed connection status of `client`.
    status: client::Status,
    /// Background task that mirrors `client` status changes into `status`. Held so it lives
    /// as long as the provider (presumably cancelled on drop, per gpui `Task` — confirm).
    _maintain_client_status: Task<()>,
}
 19
 20impl ZedDotDevCompletionProvider {
 21    pub fn new(
 22        default_model: ZedDotDevModel,
 23        client: Arc<Client>,
 24        settings_version: usize,
 25        cx: &mut AppContext,
 26    ) -> Self {
 27        let mut status_rx = client.status();
 28        let status = *status_rx.borrow();
 29        let maintain_client_status = cx.spawn(|mut cx| async move {
 30            while let Some(status) = status_rx.next().await {
 31                let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 32                    if let CompletionProvider::ZedDotDev(provider) = provider {
 33                        provider.status = status;
 34                    } else {
 35                        unreachable!()
 36                    }
 37                });
 38            }
 39        });
 40        Self {
 41            client,
 42            default_model,
 43            settings_version,
 44            status,
 45            _maintain_client_status: maintain_client_status,
 46        }
 47    }
 48
 49    pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
 50        self.default_model = default_model;
 51        self.settings_version = settings_version;
 52    }
 53
 54    pub fn settings_version(&self) -> usize {
 55        self.settings_version
 56    }
 57
 58    pub fn default_model(&self) -> ZedDotDevModel {
 59        self.default_model.clone()
 60    }
 61
 62    pub fn is_authenticated(&self) -> bool {
 63        self.status.is_connected()
 64    }
 65
 66    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 67        let client = self.client.clone();
 68        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
 69    }
 70
 71    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
 72        cx.new_view(|_cx| AuthenticationPrompt).into()
 73    }
 74
 75    pub fn count_tokens(
 76        &self,
 77        request: LanguageModelRequest,
 78        cx: &AppContext,
 79    ) -> BoxFuture<'static, Result<usize>> {
 80        match request.model {
 81            LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
 82            LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
 83            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
 84            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
 85                count_open_ai_tokens(request, cx.background_executor())
 86            }
 87            LanguageModel::ZedDotDev(
 88                ZedDotDevModel::Claude3Opus
 89                | ZedDotDevModel::Claude3Sonnet
 90                | ZedDotDevModel::Claude3Haiku,
 91            ) => {
 92                // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
 93                count_open_ai_tokens(request, cx.background_executor())
 94            }
 95            LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
 96                let request = self.client.request(proto::CountTokensWithLanguageModel {
 97                    model,
 98                    messages: request
 99                        .messages
100                        .iter()
101                        .map(|message| message.to_proto())
102                        .collect(),
103                });
104                async move {
105                    let response = request.await?;
106                    Ok(response.token_count as usize)
107                }
108                .boxed()
109            }
110        }
111    }
112
113    pub fn complete(
114        &self,
115        request: LanguageModelRequest,
116    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
117        let request = proto::CompleteWithLanguageModel {
118            model: request.model.id().to_string(),
119            messages: request
120                .messages
121                .iter()
122                .map(|message| message.to_proto())
123                .collect(),
124            stop: request.stop,
125            temperature: request.temperature,
126        };
127
128        self.client
129            .request_stream(request)
130            .map_ok(|stream| {
131                stream
132                    .filter_map(|response| async move {
133                        match response {
134                            Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
135                            Err(error) => Some(Err(error)),
136                        }
137                    })
138                    .boxed()
139            })
140            .boxed()
141    }
142}
143
/// Sign-in view returned by `ZedDotDevCompletionProvider::authentication_prompt`.
struct AuthenticationPrompt;
145
146impl Render for AuthenticationPrompt {
147    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
148        const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
149
150        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
151            v_flex()
152                .gap_2()
153                .child(
154                    Button::new("sign_in", "Sign in")
155                        .icon_color(Color::Muted)
156                        .icon(IconName::Github)
157                        .icon_position(IconPosition::Start)
158                        .style(ButtonStyle::Filled)
159                        .full_width()
160                        .on_click(|_, cx| {
161                            CompletionProvider::global(cx)
162                                .authenticate(cx)
163                                .detach_and_log_err(cx);
164                        }),
165                )
166                .child(
167                    div().flex().w_full().items_center().child(
168                        Label::new("Sign in to enable collaboration.")
169                            .color(Color::Muted)
170                            .size(LabelSize::Small),
171                    ),
172                ),
173        )
174    }
175}