//! cloud.rs — completion provider backed by Zed's cloud service.

  1use crate::{
  2    count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
  3    LanguageModelRequest,
  4};
  5use anyhow::{anyhow, Result};
  6use client::{proto, Client};
  7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
  8use gpui::{AnyView, AppContext, Task};
  9use language_model::CloudModel;
 10use std::{future, sync::Arc};
 11use strum::IntoEnumIterator;
 12use ui::prelude::*;
 13
/// Completion provider that proxies language-model requests through the
/// collab `Client` to Zed's cloud service rather than calling a model API
/// directly.
pub struct CloudCompletionProvider {
    client: Arc<Client>,
    // Currently selected cloud model (may be a user-configured `Custom` one).
    model: CloudModel,
    // Version of the settings that last configured this provider.
    settings_version: usize,
    // Latest known client connection status; kept current by the
    // `_maintain_client_status` task below.
    status: client::Status,
    // Background task that mirrors `client.status()` updates into `status`.
    // Held only so it is dropped (and thus cancelled) with the provider.
    _maintain_client_status: Task<()>,
}
 21
 22impl CloudCompletionProvider {
 23    pub fn new(
 24        model: CloudModel,
 25        client: Arc<Client>,
 26        settings_version: usize,
 27        cx: &mut AppContext,
 28    ) -> Self {
 29        let mut status_rx = client.status();
 30        let status = *status_rx.borrow();
 31        let maintain_client_status = cx.spawn(|mut cx| async move {
 32            while let Some(status) = status_rx.next().await {
 33                let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 34                    provider.update_current_as::<_, Self>(|provider| {
 35                        provider.status = status;
 36                    });
 37                });
 38            }
 39        });
 40        Self {
 41            client,
 42            model,
 43            settings_version,
 44            status,
 45            _maintain_client_status: maintain_client_status,
 46        }
 47    }
 48
 49    pub fn update(&mut self, model: CloudModel, settings_version: usize) {
 50        self.model = model;
 51        self.settings_version = settings_version;
 52    }
 53}
 54
 55impl LanguageModelCompletionProvider for CloudCompletionProvider {
 56    fn available_models(&self) -> Vec<LanguageModel> {
 57        let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
 58            Some(self.model.clone())
 59        } else {
 60            None
 61        };
 62        CloudModel::iter()
 63            .filter_map(move |model| {
 64                if let CloudModel::Custom { .. } = model {
 65                    custom_model.take()
 66                } else {
 67                    Some(model)
 68                }
 69            })
 70            .map(LanguageModel::Cloud)
 71            .collect()
 72    }
 73
 74    fn settings_version(&self) -> usize {
 75        self.settings_version
 76    }
 77
 78    fn is_authenticated(&self) -> bool {
 79        self.status.is_connected()
 80    }
 81
 82    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 83        let client = self.client.clone();
 84        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
 85    }
 86
 87    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
 88        cx.new_view(|_cx| AuthenticationPrompt).into()
 89    }
 90
 91    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
 92        Task::ready(Ok(()))
 93    }
 94
 95    fn model(&self) -> LanguageModel {
 96        LanguageModel::Cloud(self.model.clone())
 97    }
 98
 99    fn count_tokens(
100        &self,
101        request: LanguageModelRequest,
102        cx: &AppContext,
103    ) -> BoxFuture<'static, Result<usize>> {
104        match request.model {
105            LanguageModel::Cloud(CloudModel::Gpt4)
106            | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
107            | LanguageModel::Cloud(CloudModel::Gpt4Omni)
108            | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
109                count_open_ai_tokens(request, cx.background_executor())
110            }
111            LanguageModel::Cloud(
112                CloudModel::Claude3_5Sonnet
113                | CloudModel::Claude3Opus
114                | CloudModel::Claude3Sonnet
115                | CloudModel::Claude3Haiku,
116            ) => {
117                // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
118                count_open_ai_tokens(request, cx.background_executor())
119            }
120            LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
121                let request = self.client.request(proto::CountTokensWithLanguageModel {
122                    model: name,
123                    messages: request
124                        .messages
125                        .iter()
126                        .map(|message| message.to_proto())
127                        .collect(),
128                });
129                async move {
130                    let response = request.await?;
131                    Ok(response.token_count as usize)
132                }
133                .boxed()
134            }
135            _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
136        }
137    }
138
139    fn stream_completion(
140        &self,
141        mut request: LanguageModelRequest,
142    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
143        request.preprocess();
144
145        let request = proto::CompleteWithLanguageModel {
146            model: request.model.id().to_string(),
147            messages: request
148                .messages
149                .iter()
150                .map(|message| message.to_proto())
151                .collect(),
152            stop: request.stop,
153            temperature: request.temperature,
154            tools: Vec::new(),
155            tool_choice: None,
156        };
157
158        self.client
159            .request_stream(request)
160            .map_ok(|stream| {
161                stream
162                    .filter_map(|response| async move {
163                        match response {
164                            Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
165                            Err(error) => Some(Err(error)),
166                        }
167                    })
168                    .boxed()
169            })
170            .boxed()
171    }
172
173    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
174        self
175    }
176}
177
/// Zero-sized view state for the sign-in prompt shown while unauthenticated.
struct AuthenticationPrompt;
179
180impl Render for AuthenticationPrompt {
181    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
182        const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
183
184        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
185            v_flex()
186                .gap_2()
187                .child(
188                    Button::new("sign_in", "Sign in")
189                        .icon_color(Color::Muted)
190                        .icon(IconName::Github)
191                        .icon_position(IconPosition::Start)
192                        .style(ButtonStyle::Filled)
193                        .full_width()
194                        .on_click(|_, cx| {
195                            CompletionProvider::global(cx)
196                                .authenticate(cx)
197                                .detach_and_log_err(cx);
198                        }),
199                )
200                .child(
201                    div().flex().w_full().items_center().child(
202                        Label::new("Sign in to enable collaboration.")
203                            .color(Color::Muted)
204                            .size(LabelSize::Small),
205                    ),
206                ),
207        )
208    }
209}