1use crate::{
2 count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
3 LanguageModelRequest,
4};
5use anyhow::{anyhow, Result};
6use client::{proto, Client};
7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
8use gpui::{AnyView, AppContext, Task};
9use language_model::CloudModel;
10use std::{future, sync::Arc};
11use strum::IntoEnumIterator;
12use ui::prelude::*;
13
14pub struct CloudCompletionProvider {
15 client: Arc<Client>,
16 model: CloudModel,
17 settings_version: usize,
18 status: client::Status,
19 _maintain_client_status: Task<()>,
20}
21
22impl CloudCompletionProvider {
23 pub fn new(
24 model: CloudModel,
25 client: Arc<Client>,
26 settings_version: usize,
27 cx: &mut AppContext,
28 ) -> Self {
29 let mut status_rx = client.status();
30 let status = *status_rx.borrow();
31 let maintain_client_status = cx.spawn(|mut cx| async move {
32 while let Some(status) = status_rx.next().await {
33 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
34 provider.update_current_as::<_, Self>(|provider| {
35 provider.status = status;
36 });
37 });
38 }
39 });
40 Self {
41 client,
42 model,
43 settings_version,
44 status,
45 _maintain_client_status: maintain_client_status,
46 }
47 }
48
49 pub fn update(&mut self, model: CloudModel, settings_version: usize) {
50 self.model = model;
51 self.settings_version = settings_version;
52 }
53}
54
55impl LanguageModelCompletionProvider for CloudCompletionProvider {
56 fn available_models(&self) -> Vec<LanguageModel> {
57 let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
58 Some(self.model.clone())
59 } else {
60 None
61 };
62 CloudModel::iter()
63 .filter_map(move |model| {
64 if let CloudModel::Custom { .. } = model {
65 custom_model.take()
66 } else {
67 Some(model)
68 }
69 })
70 .map(LanguageModel::Cloud)
71 .collect()
72 }
73
74 fn settings_version(&self) -> usize {
75 self.settings_version
76 }
77
78 fn is_authenticated(&self) -> bool {
79 self.status.is_connected()
80 }
81
82 fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
83 let client = self.client.clone();
84 cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
85 }
86
87 fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
88 cx.new_view(|_cx| AuthenticationPrompt).into()
89 }
90
91 fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
92 Task::ready(Ok(()))
93 }
94
95 fn model(&self) -> LanguageModel {
96 LanguageModel::Cloud(self.model.clone())
97 }
98
99 fn count_tokens(
100 &self,
101 request: LanguageModelRequest,
102 cx: &AppContext,
103 ) -> BoxFuture<'static, Result<usize>> {
104 match request.model {
105 LanguageModel::Cloud(CloudModel::Gpt4)
106 | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
107 | LanguageModel::Cloud(CloudModel::Gpt4Omni)
108 | LanguageModel::Cloud(CloudModel::Gpt3Point5Turbo) => {
109 count_open_ai_tokens(request, cx.background_executor())
110 }
111 LanguageModel::Cloud(
112 CloudModel::Claude3_5Sonnet
113 | CloudModel::Claude3Opus
114 | CloudModel::Claude3Sonnet
115 | CloudModel::Claude3Haiku,
116 ) => {
117 // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
118 count_open_ai_tokens(request, cx.background_executor())
119 }
120 LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
121 let request = self.client.request(proto::CountTokensWithLanguageModel {
122 model: name,
123 messages: request
124 .messages
125 .iter()
126 .map(|message| message.to_proto())
127 .collect(),
128 });
129 async move {
130 let response = request.await?;
131 Ok(response.token_count as usize)
132 }
133 .boxed()
134 }
135 _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
136 }
137 }
138
139 fn stream_completion(
140 &self,
141 mut request: LanguageModelRequest,
142 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
143 request.preprocess();
144
145 let request = proto::CompleteWithLanguageModel {
146 model: request.model.id().to_string(),
147 messages: request
148 .messages
149 .iter()
150 .map(|message| message.to_proto())
151 .collect(),
152 stop: request.stop,
153 temperature: request.temperature,
154 tools: Vec::new(),
155 tool_choice: None,
156 };
157
158 self.client
159 .request_stream(request)
160 .map_ok(|stream| {
161 stream
162 .filter_map(|response| async move {
163 match response {
164 Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
165 Err(error) => Some(Err(error)),
166 }
167 })
168 .boxed()
169 })
170 .boxed()
171 }
172
173 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
174 self
175 }
176}
177
/// Stateless view returned by `CloudCompletionProvider::authentication_prompt`; renders a
/// sign-in button for the cloud provider.
struct AuthenticationPrompt;
179
180impl Render for AuthenticationPrompt {
181 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
182 const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
183
184 v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
185 v_flex()
186 .gap_2()
187 .child(
188 Button::new("sign_in", "Sign in")
189 .icon_color(Color::Muted)
190 .icon(IconName::Github)
191 .icon_position(IconPosition::Start)
192 .style(ButtonStyle::Filled)
193 .full_width()
194 .on_click(|_, cx| {
195 CompletionProvider::global(cx)
196 .authenticate(cx)
197 .detach_and_log_err(cx);
198 }),
199 )
200 .child(
201 div().flex().w_full().items_center().child(
202 Label::new("Sign in to enable collaboration.")
203 .color(Color::Muted)
204 .size(LabelSize::Small),
205 ),
206 ),
207 )
208 }
209}