google.rs

  1use anyhow::{anyhow, Result};
  2use collections::BTreeMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{future::BoxFuture, FutureExt, StreamExt};
  5use google_ai::stream_generate_content;
  6use gpui::{
  7    AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
  8    WhiteSpace,
  9};
 10use http_client::HttpClient;
 11use settings::{Settings, SettingsStore};
 12use std::{future, sync::Arc, time::Duration};
 13use strum::IntoEnumIterator;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use util::ResultExt;
 17
 18use crate::{
 19    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
 20    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
 21    LanguageModelProviderState, LanguageModelRequest,
 22};
 23
 24const PROVIDER_ID: &str = "google";
 25const PROVIDER_NAME: &str = "Google AI";
 26
 27#[derive(Default, Clone, Debug, PartialEq)]
 28pub struct GoogleSettings {
 29    pub api_url: String,
 30    pub low_speed_timeout: Option<Duration>,
 31    pub available_models: Vec<google_ai::Model>,
 32}
 33
 34pub struct GoogleLanguageModelProvider {
 35    http_client: Arc<dyn HttpClient>,
 36    state: gpui::Model<State>,
 37}
 38
 39struct State {
 40    api_key: Option<String>,
 41    _subscription: Subscription,
 42}
 43
 44impl GoogleLanguageModelProvider {
 45    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
 46        let state = cx.new_model(|cx| State {
 47            api_key: None,
 48            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 49                cx.notify();
 50            }),
 51        });
 52
 53        Self { http_client, state }
 54    }
 55}
 56
 57impl LanguageModelProviderState for GoogleLanguageModelProvider {
 58    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
 59        Some(cx.observe(&self.state, |_, _, cx| {
 60            cx.notify();
 61        }))
 62    }
 63}
 64
 65impl LanguageModelProvider for GoogleLanguageModelProvider {
 66    fn id(&self) -> LanguageModelProviderId {
 67        LanguageModelProviderId(PROVIDER_ID.into())
 68    }
 69
 70    fn name(&self) -> LanguageModelProviderName {
 71        LanguageModelProviderName(PROVIDER_NAME.into())
 72    }
 73
 74    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 75        let mut models = BTreeMap::default();
 76
 77        // Add base models from google_ai::Model::iter()
 78        for model in google_ai::Model::iter() {
 79            if !matches!(model, google_ai::Model::Custom { .. }) {
 80                models.insert(model.id().to_string(), model);
 81            }
 82        }
 83
 84        // Override with available models from settings
 85        for model in &AllLanguageModelSettings::get_global(cx)
 86            .google
 87            .available_models
 88        {
 89            models.insert(model.id().to_string(), model.clone());
 90        }
 91
 92        models
 93            .into_values()
 94            .map(|model| {
 95                Arc::new(GoogleLanguageModel {
 96                    id: LanguageModelId::from(model.id().to_string()),
 97                    model,
 98                    state: self.state.clone(),
 99                    http_client: self.http_client.clone(),
100                }) as Arc<dyn LanguageModel>
101            })
102            .collect()
103    }
104
105    fn is_authenticated(&self, cx: &AppContext) -> bool {
106        self.state.read(cx).api_key.is_some()
107    }
108
109    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
110        if self.is_authenticated(cx) {
111            Task::ready(Ok(()))
112        } else {
113            let api_url = AllLanguageModelSettings::get_global(cx)
114                .google
115                .api_url
116                .clone();
117            let state = self.state.clone();
118            cx.spawn(|mut cx| async move {
119                let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") {
120                    api_key
121                } else {
122                    let (_, api_key) = cx
123                        .update(|cx| cx.read_credentials(&api_url))?
124                        .await?
125                        .ok_or_else(|| anyhow!("credentials not found"))?;
126                    String::from_utf8(api_key)?
127                };
128
129                state.update(&mut cx, |this, cx| {
130                    this.api_key = Some(api_key);
131                    cx.notify();
132                })
133            })
134        }
135    }
136
137    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
138        cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
139            .into()
140    }
141
142    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
143        let state = self.state.clone();
144        let delete_credentials =
145            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
146        cx.spawn(|mut cx| async move {
147            delete_credentials.await.log_err();
148            state.update(&mut cx, |this, cx| {
149                this.api_key = None;
150                cx.notify();
151            })
152        })
153    }
154}
155
156pub struct GoogleLanguageModel {
157    id: LanguageModelId,
158    model: google_ai::Model,
159    state: gpui::Model<State>,
160    http_client: Arc<dyn HttpClient>,
161}
162
163impl LanguageModel for GoogleLanguageModel {
164    fn id(&self) -> LanguageModelId {
165        self.id.clone()
166    }
167
168    fn name(&self) -> LanguageModelName {
169        LanguageModelName::from(self.model.display_name().to_string())
170    }
171
172    fn provider_id(&self) -> LanguageModelProviderId {
173        LanguageModelProviderId(PROVIDER_ID.into())
174    }
175
176    fn provider_name(&self) -> LanguageModelProviderName {
177        LanguageModelProviderName(PROVIDER_NAME.into())
178    }
179
180    fn telemetry_id(&self) -> String {
181        format!("google/{}", self.model.id())
182    }
183
184    fn max_token_count(&self) -> usize {
185        self.model.max_token_count()
186    }
187
188    fn count_tokens(
189        &self,
190        request: LanguageModelRequest,
191        cx: &AppContext,
192    ) -> BoxFuture<'static, Result<usize>> {
193        let request = request.into_google(self.model.id().to_string());
194        let http_client = self.http_client.clone();
195        let api_key = self.state.read(cx).api_key.clone();
196        let api_url = AllLanguageModelSettings::get_global(cx)
197            .google
198            .api_url
199            .clone();
200
201        async move {
202            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
203            let response = google_ai::count_tokens(
204                http_client.as_ref(),
205                &api_url,
206                &api_key,
207                google_ai::CountTokensRequest {
208                    contents: request.contents,
209                },
210            )
211            .await?;
212            Ok(response.total_tokens)
213        }
214        .boxed()
215    }
216
217    fn stream_completion(
218        &self,
219        request: LanguageModelRequest,
220        cx: &AsyncAppContext,
221    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
222        let request = request.into_google(self.model.id().to_string());
223
224        let http_client = self.http_client.clone();
225        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
226            let settings = &AllLanguageModelSettings::get_global(cx).google;
227            (state.api_key.clone(), settings.api_url.clone())
228        }) else {
229            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
230        };
231
232        async move {
233            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
234            let response =
235                stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
236            let events = response.await?;
237            Ok(google_ai::extract_text_from_events(events).boxed())
238        }
239        .boxed()
240    }
241
242    fn use_tool(
243        &self,
244        _request: LanguageModelRequest,
245        _name: String,
246        _description: String,
247        _schema: serde_json::Value,
248        _cx: &AsyncAppContext,
249    ) -> BoxFuture<'static, Result<serde_json::Value>> {
250        future::ready(Err(anyhow!("not implemented"))).boxed()
251    }
252}
253
254struct AuthenticationPrompt {
255    api_key: View<Editor>,
256    state: gpui::Model<State>,
257}
258
259impl AuthenticationPrompt {
260    fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
261        Self {
262            api_key: cx.new_view(|cx| {
263                let mut editor = Editor::single_line(cx);
264                editor.set_placeholder_text("AIzaSy...", cx);
265                editor
266            }),
267            state,
268        }
269    }
270
271    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
272        let api_key = self.api_key.read(cx).text(cx);
273        if api_key.is_empty() {
274            return;
275        }
276
277        let settings = &AllLanguageModelSettings::get_global(cx).google;
278        let write_credentials =
279            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
280        let state = self.state.clone();
281        cx.spawn(|_, mut cx| async move {
282            write_credentials.await?;
283            state.update(&mut cx, |this, cx| {
284                this.api_key = Some(api_key);
285                cx.notify();
286            })
287        })
288        .detach_and_log_err(cx);
289    }
290
291    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
292        let settings = ThemeSettings::get_global(cx);
293        let text_style = TextStyle {
294            color: cx.theme().colors().text,
295            font_family: settings.ui_font.family.clone(),
296            font_features: settings.ui_font.features.clone(),
297            font_fallbacks: settings.ui_font.fallbacks.clone(),
298            font_size: rems(0.875).into(),
299            font_weight: settings.ui_font.weight,
300            font_style: FontStyle::Normal,
301            line_height: relative(1.3),
302            background_color: None,
303            underline: None,
304            strikethrough: None,
305            white_space: WhiteSpace::Normal,
306        };
307        EditorElement::new(
308            &self.api_key,
309            EditorStyle {
310                background: cx.theme().colors().editor_background,
311                local_player: cx.theme().players().local(),
312                text: text_style,
313                ..Default::default()
314            },
315        )
316    }
317}
318
319impl Render for AuthenticationPrompt {
320    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
321        const INSTRUCTIONS: [&str; 4] = [
322            "To use the Google AI assistant, you need to add your Google AI API key.",
323            "You can create an API key at: https://makersuite.google.com/app/apikey",
324            "",
325            "Paste your Google AI API key below and hit enter to use the assistant:",
326        ];
327
328        v_flex()
329            .p_4()
330            .size_full()
331            .on_action(cx.listener(Self::save_api_key))
332            .children(
333                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
334            )
335            .child(
336                h_flex()
337                    .w_full()
338                    .my_2()
339                    .px_2()
340                    .py_1()
341                    .bg(cx.theme().colors().editor_background)
342                    .rounded_md()
343                    .child(self.render_api_key_editor(cx)),
344            )
345            .child(
346                Label::new(
347                    "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.",
348                )
349                .size(LabelSize::Small),
350            )
351            .child(
352                h_flex()
353                    .gap_2()
354                    .child(Label::new("Click on").size(LabelSize::Small))
355                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
356                    .child(
357                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
358                    ),
359            )
360            .into_any()
361    }
362}