google.rs

  1use anyhow::{anyhow, Result};
  2use collections::BTreeMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{future::BoxFuture, FutureExt, StreamExt};
  5use google_ai::stream_generate_content;
  6use gpui::{
  7    AnyView, AppContext, AsyncAppContext, FocusHandle, FocusableView, FontStyle, ModelContext,
  8    Subscription, Task, TextStyle, View, WhiteSpace,
  9};
 10use http_client::HttpClient;
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use settings::{Settings, SettingsStore};
 14use std::{future, sync::Arc, time::Duration};
 15use strum::IntoEnumIterator;
 16use theme::ThemeSettings;
 17use ui::{prelude::*, Indicator};
 18use util::ResultExt;
 19
 20use crate::{
 21    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
 22    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
 23    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
 24};
 25
 26const PROVIDER_ID: &str = "google";
 27const PROVIDER_NAME: &str = "Google AI";
 28
 29#[derive(Default, Clone, Debug, PartialEq)]
 30pub struct GoogleSettings {
 31    pub api_url: String,
 32    pub low_speed_timeout: Option<Duration>,
 33    pub available_models: Vec<AvailableModel>,
 34}
 35
 36#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 37pub struct AvailableModel {
 38    name: String,
 39    max_tokens: usize,
 40}
 41
 42pub struct GoogleLanguageModelProvider {
 43    http_client: Arc<dyn HttpClient>,
 44    state: gpui::Model<State>,
 45}
 46
 47pub struct State {
 48    api_key: Option<String>,
 49    _subscription: Subscription,
 50}
 51
 52impl State {
 53    fn is_authenticated(&self) -> bool {
 54        self.api_key.is_some()
 55    }
 56
 57    fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 58        let delete_credentials =
 59            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
 60        cx.spawn(|this, mut cx| async move {
 61            delete_credentials.await.ok();
 62            this.update(&mut cx, |this, cx| {
 63                this.api_key = None;
 64                cx.notify();
 65            })
 66        })
 67    }
 68}
 69
 70impl GoogleLanguageModelProvider {
 71    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
 72        let state = cx.new_model(|cx| State {
 73            api_key: None,
 74            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 75                cx.notify();
 76            }),
 77        });
 78
 79        Self { http_client, state }
 80    }
 81}
 82
 83impl LanguageModelProviderState for GoogleLanguageModelProvider {
 84    type ObservableEntity = State;
 85
 86    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
 87        Some(self.state.clone())
 88    }
 89}
 90
 91impl LanguageModelProvider for GoogleLanguageModelProvider {
 92    fn id(&self) -> LanguageModelProviderId {
 93        LanguageModelProviderId(PROVIDER_ID.into())
 94    }
 95
 96    fn name(&self) -> LanguageModelProviderName {
 97        LanguageModelProviderName(PROVIDER_NAME.into())
 98    }
 99
100    fn icon(&self) -> IconName {
101        IconName::AiGoogle
102    }
103
104    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
105        let mut models = BTreeMap::default();
106
107        // Add base models from google_ai::Model::iter()
108        for model in google_ai::Model::iter() {
109            if !matches!(model, google_ai::Model::Custom { .. }) {
110                models.insert(model.id().to_string(), model);
111            }
112        }
113
114        // Override with available models from settings
115        for model in &AllLanguageModelSettings::get_global(cx)
116            .google
117            .available_models
118        {
119            models.insert(
120                model.name.clone(),
121                google_ai::Model::Custom {
122                    name: model.name.clone(),
123                    max_tokens: model.max_tokens,
124                },
125            );
126        }
127
128        models
129            .into_values()
130            .map(|model| {
131                Arc::new(GoogleLanguageModel {
132                    id: LanguageModelId::from(model.id().to_string()),
133                    model,
134                    state: self.state.clone(),
135                    http_client: self.http_client.clone(),
136                    rate_limiter: RateLimiter::new(4),
137                }) as Arc<dyn LanguageModel>
138            })
139            .collect()
140    }
141
142    fn is_authenticated(&self, cx: &AppContext) -> bool {
143        self.state.read(cx).is_authenticated()
144    }
145
146    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
147        if self.is_authenticated(cx) {
148            Task::ready(Ok(()))
149        } else {
150            let api_url = AllLanguageModelSettings::get_global(cx)
151                .google
152                .api_url
153                .clone();
154            let state = self.state.clone();
155            cx.spawn(|mut cx| async move {
156                let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") {
157                    api_key
158                } else {
159                    let (_, api_key) = cx
160                        .update(|cx| cx.read_credentials(&api_url))?
161                        .await?
162                        .ok_or_else(|| anyhow!("credentials not found"))?;
163                    String::from_utf8(api_key)?
164                };
165
166                state.update(&mut cx, |this, cx| {
167                    this.api_key = Some(api_key);
168                    cx.notify();
169                })
170            })
171        }
172    }
173
174    fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option<FocusHandle>) {
175        let view = cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx));
176
177        let focus_handle = view.focus_handle(cx);
178        (view.into(), Some(focus_handle))
179    }
180
181    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
182        let state = self.state.clone();
183        let delete_credentials =
184            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
185        cx.spawn(|mut cx| async move {
186            delete_credentials.await.log_err();
187            state.update(&mut cx, |this, cx| {
188                this.api_key = None;
189                cx.notify();
190            })
191        })
192    }
193}
194
195pub struct GoogleLanguageModel {
196    id: LanguageModelId,
197    model: google_ai::Model,
198    state: gpui::Model<State>,
199    http_client: Arc<dyn HttpClient>,
200    rate_limiter: RateLimiter,
201}
202
203impl LanguageModel for GoogleLanguageModel {
204    fn id(&self) -> LanguageModelId {
205        self.id.clone()
206    }
207
208    fn name(&self) -> LanguageModelName {
209        LanguageModelName::from(self.model.display_name().to_string())
210    }
211
212    fn provider_id(&self) -> LanguageModelProviderId {
213        LanguageModelProviderId(PROVIDER_ID.into())
214    }
215
216    fn provider_name(&self) -> LanguageModelProviderName {
217        LanguageModelProviderName(PROVIDER_NAME.into())
218    }
219
220    fn telemetry_id(&self) -> String {
221        format!("google/{}", self.model.id())
222    }
223
224    fn max_token_count(&self) -> usize {
225        self.model.max_token_count()
226    }
227
228    fn count_tokens(
229        &self,
230        request: LanguageModelRequest,
231        cx: &AppContext,
232    ) -> BoxFuture<'static, Result<usize>> {
233        let request = request.into_google(self.model.id().to_string());
234        let http_client = self.http_client.clone();
235        let api_key = self.state.read(cx).api_key.clone();
236        let api_url = AllLanguageModelSettings::get_global(cx)
237            .google
238            .api_url
239            .clone();
240
241        async move {
242            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
243            let response = google_ai::count_tokens(
244                http_client.as_ref(),
245                &api_url,
246                &api_key,
247                google_ai::CountTokensRequest {
248                    contents: request.contents,
249                },
250            )
251            .await?;
252            Ok(response.total_tokens)
253        }
254        .boxed()
255    }
256
257    fn stream_completion(
258        &self,
259        request: LanguageModelRequest,
260        cx: &AsyncAppContext,
261    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
262        let request = request.into_google(self.model.id().to_string());
263
264        let http_client = self.http_client.clone();
265        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
266            let settings = &AllLanguageModelSettings::get_global(cx).google;
267            (state.api_key.clone(), settings.api_url.clone())
268        }) else {
269            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
270        };
271
272        let future = self.rate_limiter.stream(async move {
273            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
274            let response =
275                stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
276            let events = response.await?;
277            Ok(google_ai::extract_text_from_events(events).boxed())
278        });
279        async move { Ok(future.await?.boxed()) }.boxed()
280    }
281
282    fn use_any_tool(
283        &self,
284        _request: LanguageModelRequest,
285        _name: String,
286        _description: String,
287        _schema: serde_json::Value,
288        _cx: &AsyncAppContext,
289    ) -> BoxFuture<'static, Result<serde_json::Value>> {
290        future::ready(Err(anyhow!("not implemented"))).boxed()
291    }
292}
293
294struct ConfigurationView {
295    focus_handle: FocusHandle,
296    api_key_editor: View<Editor>,
297    state: gpui::Model<State>,
298}
299
300impl ConfigurationView {
301    fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
302        let focus_handle = cx.focus_handle();
303
304        cx.on_focus(&focus_handle, |this, cx| {
305            if this.should_render_editor(cx) {
306                this.api_key_editor.read(cx).focus_handle(cx).focus(cx)
307            }
308        })
309        .detach();
310
311        Self {
312            api_key_editor: cx.new_view(|cx| {
313                let mut editor = Editor::single_line(cx);
314                editor.set_placeholder_text("AIzaSy...", cx);
315                editor
316            }),
317            state,
318            focus_handle,
319        }
320    }
321
322    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
323        let api_key = self.api_key_editor.read(cx).text(cx);
324        if api_key.is_empty() {
325            return;
326        }
327
328        let settings = &AllLanguageModelSettings::get_global(cx).google;
329        let write_credentials =
330            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
331        let state = self.state.clone();
332        cx.spawn(|_, mut cx| async move {
333            write_credentials.await?;
334            state.update(&mut cx, |this, cx| {
335                this.api_key = Some(api_key);
336                cx.notify();
337            })
338        })
339        .detach_and_log_err(cx);
340    }
341
342    fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
343        self.api_key_editor
344            .update(cx, |editor, cx| editor.set_text("", cx));
345        self.state
346            .update(cx, |state, cx| state.reset_api_key(cx))
347            .detach_and_log_err(cx);
348    }
349
350    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
351        let settings = ThemeSettings::get_global(cx);
352        let text_style = TextStyle {
353            color: cx.theme().colors().text,
354            font_family: settings.ui_font.family.clone(),
355            font_features: settings.ui_font.features.clone(),
356            font_fallbacks: settings.ui_font.fallbacks.clone(),
357            font_size: rems(0.875).into(),
358            font_weight: settings.ui_font.weight,
359            font_style: FontStyle::Normal,
360            line_height: relative(1.3),
361            background_color: None,
362            underline: None,
363            strikethrough: None,
364            white_space: WhiteSpace::Normal,
365        };
366        EditorElement::new(
367            &self.api_key_editor,
368            EditorStyle {
369                background: cx.theme().colors().editor_background,
370                local_player: cx.theme().players().local(),
371                text: text_style,
372                ..Default::default()
373            },
374        )
375    }
376
377    fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
378        !self.state.read(cx).is_authenticated()
379    }
380}
381
382impl FocusableView for ConfigurationView {
383    fn focus_handle(&self, _cx: &AppContext) -> FocusHandle {
384        self.focus_handle.clone()
385    }
386}
387
388impl Render for ConfigurationView {
389    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
390        const INSTRUCTIONS: [&str; 4] = [
391            "To use the Google AI assistant, you need to add your Google AI API key.",
392            "You can create an API key at: https://makersuite.google.com/app/apikey",
393            "",
394            "Paste your Google AI API key below and hit enter to use the assistant:",
395        ];
396
397        if self.should_render_editor(cx) {
398            v_flex()
399                .id("google-ai-configuration-view")
400                .track_focus(&self.focus_handle)
401                .size_full()
402                .on_action(cx.listener(Self::save_api_key))
403                .children(
404                    INSTRUCTIONS.map(|instruction| Label::new(instruction)),
405                )
406                .child(
407                    h_flex()
408                        .w_full()
409                        .my_2()
410                        .px_2()
411                        .py_1()
412                        .bg(cx.theme().colors().editor_background)
413                        .rounded_md()
414                        .child(self.render_api_key_editor(cx)),
415                )
416                .child(
417                    Label::new(
418                        "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.",
419                    )
420                    .size(LabelSize::Small),
421                )
422                .into_any()
423        } else {
424            h_flex()
425                .id("google-ai-configuration-view")
426                .track_focus(&self.focus_handle)
427                .size_full()
428                .justify_between()
429                .child(
430                    h_flex()
431                        .gap_2()
432                        .child(Indicator::dot().color(Color::Success))
433                        .child(Label::new("API Key configured").size(LabelSize::Small)),
434                )
435                .child(
436                    Button::new("reset-key", "Reset key")
437                        .icon(Some(IconName::Trash))
438                        .icon_size(IconSize::Small)
439                        .icon_position(IconPosition::Start)
440                        .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
441                )
442                .into_any()
443        }
444    }
445}