cloud.rs

  1use crate::{
  2    assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
  3    LanguageModelRequest,
  4};
  5use anyhow::{anyhow, Result};
  6use client::{proto, Client};
  7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
  8use gpui::{AnyView, AppContext, Task};
  9use std::{future, sync::Arc};
 10use strum::IntoEnumIterator;
 11use ui::prelude::*;
 12
 13pub struct CloudCompletionProvider {
 14    client: Arc<Client>,
 15    model: CloudModel,
 16    settings_version: usize,
 17    status: client::Status,
 18    _maintain_client_status: Task<()>,
 19}
 20
 21impl CloudCompletionProvider {
 22    pub fn new(
 23        model: CloudModel,
 24        client: Arc<Client>,
 25        settings_version: usize,
 26        cx: &mut AppContext,
 27    ) -> Self {
 28        let mut status_rx = client.status();
 29        let status = *status_rx.borrow();
 30        let maintain_client_status = cx.spawn(|mut cx| async move {
 31            while let Some(status) = status_rx.next().await {
 32                let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 33                    if let CompletionProvider::Cloud(provider) = provider {
 34                        provider.status = status;
 35                    } else {
 36                        unreachable!()
 37                    }
 38                });
 39            }
 40        });
 41        Self {
 42            client,
 43            model,
 44            settings_version,
 45            status,
 46            _maintain_client_status: maintain_client_status,
 47        }
 48    }
 49
 50    pub fn update(&mut self, model: CloudModel, settings_version: usize) {
 51        self.model = model;
 52        self.settings_version = settings_version;
 53    }
 54
 55    pub fn available_models(&self) -> impl Iterator<Item = CloudModel> {
 56        let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
 57            Some(custom_model)
 58        } else {
 59            None
 60        };
 61        CloudModel::iter().filter_map(move |model| {
 62            if let CloudModel::Custom(_) = model {
 63                Some(CloudModel::Custom(custom_model.take()?))
 64            } else {
 65                Some(model)
 66            }
 67        })
 68    }
 69
 70    pub fn settings_version(&self) -> usize {
 71        self.settings_version
 72    }
 73
 74    pub fn model(&self) -> CloudModel {
 75        self.model.clone()
 76    }
 77
 78    pub fn is_authenticated(&self) -> bool {
 79        self.status.is_connected()
 80    }
 81
 82    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 83        let client = self.client.clone();
 84        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
 85    }
 86
 87    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
 88        cx.new_view(|_cx| AuthenticationPrompt).into()
 89    }
 90
 91    pub fn count_tokens(
 92        &self,
 93        request: LanguageModelRequest,
 94        cx: &AppContext,
 95    ) -> BoxFuture<'static, Result<usize>> {
 96        match request.model {
 97            LanguageModel::Cloud(CloudModel::Gpt4)
 98            | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
 99            | LanguageModel::Cloud(CloudModel::Gpt4Omni)
100            | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
101                count_open_ai_tokens(request, cx.background_executor())
102            }
103            LanguageModel::Cloud(
104                CloudModel::Claude3_5Sonnet
105                | CloudModel::Claude3Opus
106                | CloudModel::Claude3Sonnet
107                | CloudModel::Claude3Haiku,
108            ) => {
109                // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
110                count_open_ai_tokens(request, cx.background_executor())
111            }
112            LanguageModel::Cloud(CloudModel::Custom(model)) => {
113                let request = self.client.request(proto::CountTokensWithLanguageModel {
114                    model,
115                    messages: request
116                        .messages
117                        .iter()
118                        .map(|message| message.to_proto())
119                        .collect(),
120                });
121                async move {
122                    let response = request.await?;
123                    Ok(response.token_count as usize)
124                }
125                .boxed()
126            }
127            _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
128        }
129    }
130
131    pub fn complete(
132        &self,
133        mut request: LanguageModelRequest,
134    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
135        request.preprocess();
136
137        let request = proto::CompleteWithLanguageModel {
138            model: request.model.id().to_string(),
139            messages: request
140                .messages
141                .iter()
142                .map(|message| message.to_proto())
143                .collect(),
144            stop: request.stop,
145            temperature: request.temperature,
146            tools: Vec::new(),
147            tool_choice: None,
148        };
149
150        self.client
151            .request_stream(request)
152            .map_ok(|stream| {
153                stream
154                    .filter_map(|response| async move {
155                        match response {
156                            Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
157                            Err(error) => Some(Err(error)),
158                        }
159                    })
160                    .boxed()
161            })
162            .boxed()
163    }
164}
165
166struct AuthenticationPrompt;
167
168impl Render for AuthenticationPrompt {
169    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
170        const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
171
172        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
173            v_flex()
174                .gap_2()
175                .child(
176                    Button::new("sign_in", "Sign in")
177                        .icon_color(Color::Muted)
178                        .icon(IconName::Github)
179                        .icon_position(IconPosition::Start)
180                        .style(ButtonStyle::Filled)
181                        .full_width()
182                        .on_click(|_, cx| {
183                            CompletionProvider::global(cx)
184                                .authenticate(cx)
185                                .detach_and_log_err(cx);
186                        }),
187                )
188                .child(
189                    div().flex().w_full().items_center().child(
190                        Label::new("Sign in to enable collaboration.")
191                            .color(Color::Muted)
192                            .size(LabelSize::Small),
193                    ),
194                ),
195        )
196    }
197}