google.rs

  1use anyhow::{anyhow, Result};
  2use collections::BTreeMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{future::BoxFuture, FutureExt, StreamExt};
  5use google_ai::stream_generate_content;
  6use gpui::{
  7    AnyView, AppContext, AsyncAppContext, FocusHandle, FocusableView, FontStyle, ModelContext,
  8    Subscription, Task, TextStyle, View, WhiteSpace,
  9};
 10use http_client::HttpClient;
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use settings::{Settings, SettingsStore};
 14use std::{future, sync::Arc, time::Duration};
 15use strum::IntoEnumIterator;
 16use theme::ThemeSettings;
 17use ui::{prelude::*, Indicator};
 18use util::ResultExt;
 19
 20use crate::{
 21    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
 22    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
 23    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
 24};
 25
 26const PROVIDER_ID: &str = "google";
 27const PROVIDER_NAME: &str = "Google AI";
 28
 29#[derive(Default, Clone, Debug, PartialEq)]
 30pub struct GoogleSettings {
 31    pub api_url: String,
 32    pub low_speed_timeout: Option<Duration>,
 33    pub available_models: Vec<AvailableModel>,
 34}
 35
 36#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 37pub struct AvailableModel {
 38    name: String,
 39    max_tokens: usize,
 40}
 41
 42pub struct GoogleLanguageModelProvider {
 43    http_client: Arc<dyn HttpClient>,
 44    state: gpui::Model<State>,
 45}
 46
 47pub struct State {
 48    api_key: Option<String>,
 49    _subscription: Subscription,
 50}
 51
 52impl State {
 53    fn is_authenticated(&self) -> bool {
 54        self.api_key.is_some()
 55    }
 56
 57    fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
 58        let delete_credentials =
 59            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
 60        cx.spawn(|this, mut cx| async move {
 61            delete_credentials.await.ok();
 62            this.update(&mut cx, |this, cx| {
 63                this.api_key = None;
 64                cx.notify();
 65            })
 66        })
 67    }
 68}
 69
 70impl GoogleLanguageModelProvider {
 71    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
 72        let state = cx.new_model(|cx| State {
 73            api_key: None,
 74            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
 75                cx.notify();
 76            }),
 77        });
 78
 79        Self { http_client, state }
 80    }
 81}
 82
 83impl LanguageModelProviderState for GoogleLanguageModelProvider {
 84    type ObservableEntity = State;
 85
 86    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
 87        Some(self.state.clone())
 88    }
 89}
 90
 91impl LanguageModelProvider for GoogleLanguageModelProvider {
 92    fn id(&self) -> LanguageModelProviderId {
 93        LanguageModelProviderId(PROVIDER_ID.into())
 94    }
 95
 96    fn name(&self) -> LanguageModelProviderName {
 97        LanguageModelProviderName(PROVIDER_NAME.into())
 98    }
 99
100    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
101        let mut models = BTreeMap::default();
102
103        // Add base models from google_ai::Model::iter()
104        for model in google_ai::Model::iter() {
105            if !matches!(model, google_ai::Model::Custom { .. }) {
106                models.insert(model.id().to_string(), model);
107            }
108        }
109
110        // Override with available models from settings
111        for model in &AllLanguageModelSettings::get_global(cx)
112            .google
113            .available_models
114        {
115            models.insert(
116                model.name.clone(),
117                google_ai::Model::Custom {
118                    name: model.name.clone(),
119                    max_tokens: model.max_tokens,
120                },
121            );
122        }
123
124        models
125            .into_values()
126            .map(|model| {
127                Arc::new(GoogleLanguageModel {
128                    id: LanguageModelId::from(model.id().to_string()),
129                    model,
130                    state: self.state.clone(),
131                    http_client: self.http_client.clone(),
132                    rate_limiter: RateLimiter::new(4),
133                }) as Arc<dyn LanguageModel>
134            })
135            .collect()
136    }
137
138    fn is_authenticated(&self, cx: &AppContext) -> bool {
139        self.state.read(cx).is_authenticated()
140    }
141
142    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
143        if self.is_authenticated(cx) {
144            Task::ready(Ok(()))
145        } else {
146            let api_url = AllLanguageModelSettings::get_global(cx)
147                .google
148                .api_url
149                .clone();
150            let state = self.state.clone();
151            cx.spawn(|mut cx| async move {
152                let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") {
153                    api_key
154                } else {
155                    let (_, api_key) = cx
156                        .update(|cx| cx.read_credentials(&api_url))?
157                        .await?
158                        .ok_or_else(|| anyhow!("credentials not found"))?;
159                    String::from_utf8(api_key)?
160                };
161
162                state.update(&mut cx, |this, cx| {
163                    this.api_key = Some(api_key);
164                    cx.notify();
165                })
166            })
167        }
168    }
169
170    fn configuration_view(&self, cx: &mut WindowContext) -> (AnyView, Option<FocusHandle>) {
171        let view = cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx));
172
173        let focus_handle = view.focus_handle(cx);
174        (view.into(), Some(focus_handle))
175    }
176
177    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
178        let state = self.state.clone();
179        let delete_credentials =
180            cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
181        cx.spawn(|mut cx| async move {
182            delete_credentials.await.log_err();
183            state.update(&mut cx, |this, cx| {
184                this.api_key = None;
185                cx.notify();
186            })
187        })
188    }
189}
190
191pub struct GoogleLanguageModel {
192    id: LanguageModelId,
193    model: google_ai::Model,
194    state: gpui::Model<State>,
195    http_client: Arc<dyn HttpClient>,
196    rate_limiter: RateLimiter,
197}
198
199impl LanguageModel for GoogleLanguageModel {
200    fn id(&self) -> LanguageModelId {
201        self.id.clone()
202    }
203
204    fn name(&self) -> LanguageModelName {
205        LanguageModelName::from(self.model.display_name().to_string())
206    }
207
208    fn provider_id(&self) -> LanguageModelProviderId {
209        LanguageModelProviderId(PROVIDER_ID.into())
210    }
211
212    fn provider_name(&self) -> LanguageModelProviderName {
213        LanguageModelProviderName(PROVIDER_NAME.into())
214    }
215
216    fn telemetry_id(&self) -> String {
217        format!("google/{}", self.model.id())
218    }
219
220    fn max_token_count(&self) -> usize {
221        self.model.max_token_count()
222    }
223
224    fn count_tokens(
225        &self,
226        request: LanguageModelRequest,
227        cx: &AppContext,
228    ) -> BoxFuture<'static, Result<usize>> {
229        let request = request.into_google(self.model.id().to_string());
230        let http_client = self.http_client.clone();
231        let api_key = self.state.read(cx).api_key.clone();
232        let api_url = AllLanguageModelSettings::get_global(cx)
233            .google
234            .api_url
235            .clone();
236
237        async move {
238            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
239            let response = google_ai::count_tokens(
240                http_client.as_ref(),
241                &api_url,
242                &api_key,
243                google_ai::CountTokensRequest {
244                    contents: request.contents,
245                },
246            )
247            .await?;
248            Ok(response.total_tokens)
249        }
250        .boxed()
251    }
252
253    fn stream_completion(
254        &self,
255        request: LanguageModelRequest,
256        cx: &AsyncAppContext,
257    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
258        let request = request.into_google(self.model.id().to_string());
259
260        let http_client = self.http_client.clone();
261        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
262            let settings = &AllLanguageModelSettings::get_global(cx).google;
263            (state.api_key.clone(), settings.api_url.clone())
264        }) else {
265            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
266        };
267
268        let future = self.rate_limiter.stream(async move {
269            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
270            let response =
271                stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
272            let events = response.await?;
273            Ok(google_ai::extract_text_from_events(events).boxed())
274        });
275        async move { Ok(future.await?.boxed()) }.boxed()
276    }
277
278    fn use_any_tool(
279        &self,
280        _request: LanguageModelRequest,
281        _name: String,
282        _description: String,
283        _schema: serde_json::Value,
284        _cx: &AsyncAppContext,
285    ) -> BoxFuture<'static, Result<serde_json::Value>> {
286        future::ready(Err(anyhow!("not implemented"))).boxed()
287    }
288}
289
290struct ConfigurationView {
291    api_key_editor: View<Editor>,
292    state: gpui::Model<State>,
293}
294
295impl ConfigurationView {
296    fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
297        Self {
298            api_key_editor: cx.new_view(|cx| {
299                let mut editor = Editor::single_line(cx);
300                editor.set_placeholder_text("AIzaSy...", cx);
301                editor
302            }),
303            state,
304        }
305    }
306
307    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
308        let api_key = self.api_key_editor.read(cx).text(cx);
309        if api_key.is_empty() {
310            return;
311        }
312
313        let settings = &AllLanguageModelSettings::get_global(cx).google;
314        let write_credentials =
315            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
316        let state = self.state.clone();
317        cx.spawn(|_, mut cx| async move {
318            write_credentials.await?;
319            state.update(&mut cx, |this, cx| {
320                this.api_key = Some(api_key);
321                cx.notify();
322            })
323        })
324        .detach_and_log_err(cx);
325    }
326
327    fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
328        self.api_key_editor
329            .update(cx, |editor, cx| editor.set_text("", cx));
330        self.state
331            .update(cx, |state, cx| state.reset_api_key(cx))
332            .detach_and_log_err(cx);
333    }
334
335    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
336        let settings = ThemeSettings::get_global(cx);
337        let text_style = TextStyle {
338            color: cx.theme().colors().text,
339            font_family: settings.ui_font.family.clone(),
340            font_features: settings.ui_font.features.clone(),
341            font_fallbacks: settings.ui_font.fallbacks.clone(),
342            font_size: rems(0.875).into(),
343            font_weight: settings.ui_font.weight,
344            font_style: FontStyle::Normal,
345            line_height: relative(1.3),
346            background_color: None,
347            underline: None,
348            strikethrough: None,
349            white_space: WhiteSpace::Normal,
350        };
351        EditorElement::new(
352            &self.api_key_editor,
353            EditorStyle {
354                background: cx.theme().colors().editor_background,
355                local_player: cx.theme().players().local(),
356                text: text_style,
357                ..Default::default()
358            },
359        )
360    }
361}
362
363impl FocusableView for ConfigurationView {
364    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
365        self.api_key_editor.read(cx).focus_handle(cx)
366    }
367}
368
369impl Render for ConfigurationView {
370    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
371        const INSTRUCTIONS: [&str; 4] = [
372            "To use the Google AI assistant, you need to add your Google AI API key.",
373            "You can create an API key at: https://makersuite.google.com/app/apikey",
374            "",
375            "Paste your Google AI API key below and hit enter to use the assistant:",
376        ];
377
378        if self.state.read(cx).is_authenticated() {
379            h_flex()
380                .size_full()
381                .justify_between()
382                .child(
383                    h_flex()
384                        .gap_2()
385                        .child(Indicator::dot().color(Color::Success))
386                        .child(Label::new("API Key configured").size(LabelSize::Small)),
387                )
388                .child(
389                    Button::new("reset-key", "Reset key")
390                        .icon(Some(IconName::Trash))
391                        .icon_size(IconSize::Small)
392                        .icon_position(IconPosition::Start)
393                        .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
394                )
395                .into_any()
396        } else {
397            v_flex()
398                .size_full()
399                .on_action(cx.listener(Self::save_api_key))
400                .children(
401                    INSTRUCTIONS.map(|instruction| Label::new(instruction)),
402                )
403                .child(
404                    h_flex()
405                        .w_full()
406                        .my_2()
407                        .px_2()
408                        .py_1()
409                        .bg(cx.theme().colors().editor_background)
410                        .rounded_md()
411                        .child(self.render_api_key_editor(cx)),
412                )
413                .child(
414                    Label::new(
415                        "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.",
416                    )
417                    .size(LabelSize::Small),
418                )
419                .into_any()
420        }
421    }
422}