google.rs

  1use anyhow::{Context as _, Result};
  2use collections::BTreeMap;
  3use credentials_provider::CredentialsProvider;
  4use futures::{FutureExt, StreamExt, future::BoxFuture};
  5pub use google_ai::completion::{GoogleEventMapper, count_google_tokens, into_google};
  6use google_ai::{GenerateContentResponse, GoogleModelMode};
  7use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
  8use http_client::HttpClient;
  9use language_model::{
 10    AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
 11    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
 12};
 13use language_model::{
 14    GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelId,
 15    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
 16    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
 17};
 18use schemars::JsonSchema;
 19use serde::{Deserialize, Serialize};
 20pub use settings::GoogleAvailableModel as AvailableModel;
 21use settings::{Settings, SettingsStore};
 22use std::sync::{Arc, LazyLock};
 23use strum::IntoEnumIterator;
 24use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 25use ui_input::InputField;
 26use util::ResultExt;
 27
 28use language_model::ApiKeyState;
 29
 30const PROVIDER_ID: LanguageModelProviderId = GOOGLE_PROVIDER_ID;
 31const PROVIDER_NAME: LanguageModelProviderName = GOOGLE_PROVIDER_NAME;
 32
 33#[derive(Default, Clone, Debug, PartialEq)]
 34pub struct GoogleSettings {
 35    pub api_url: String,
 36    pub available_models: Vec<AvailableModel>,
 37}
 38
 39#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
 40#[serde(tag = "type", rename_all = "lowercase")]
 41pub enum ModelMode {
 42    #[default]
 43    Default,
 44    Thinking {
 45        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
 46        budget_tokens: Option<u32>,
 47    },
 48}
 49
 50pub struct GoogleLanguageModelProvider {
 51    http_client: Arc<dyn HttpClient>,
 52    state: Entity<State>,
 53}
 54
 55pub struct State {
 56    api_key_state: ApiKeyState,
 57    credentials_provider: Arc<dyn CredentialsProvider>,
 58}
 59
 60const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
 61const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
 62
 63static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
 64    // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
 65    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
 66});
 67
 68impl State {
 69    fn is_authenticated(&self) -> bool {
 70        self.api_key_state.has_key()
 71    }
 72
 73    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
 74        let credentials_provider = self.credentials_provider.clone();
 75        let api_url = GoogleLanguageModelProvider::api_url(cx);
 76        self.api_key_state.store(
 77            api_url,
 78            api_key,
 79            |this| &mut this.api_key_state,
 80            credentials_provider,
 81            cx,
 82        )
 83    }
 84
 85    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 86        let credentials_provider = self.credentials_provider.clone();
 87        let api_url = GoogleLanguageModelProvider::api_url(cx);
 88        self.api_key_state.load_if_needed(
 89            api_url,
 90            |this| &mut this.api_key_state,
 91            credentials_provider,
 92            cx,
 93        )
 94    }
 95}
 96
 97impl GoogleLanguageModelProvider {
 98    pub fn new(
 99        http_client: Arc<dyn HttpClient>,
100        credentials_provider: Arc<dyn CredentialsProvider>,
101        cx: &mut App,
102    ) -> Self {
103        let state = cx.new(|cx| {
104            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
105                let credentials_provider = this.credentials_provider.clone();
106                let api_url = Self::api_url(cx);
107                this.api_key_state.handle_url_change(
108                    api_url,
109                    |this| &mut this.api_key_state,
110                    credentials_provider,
111                    cx,
112                );
113                cx.notify();
114            })
115            .detach();
116            State {
117                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
118                credentials_provider,
119            }
120        });
121
122        Self { http_client, state }
123    }
124
125    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
126        Arc::new(GoogleLanguageModel {
127            id: LanguageModelId::from(model.id().to_string()),
128            model,
129            state: self.state.clone(),
130            http_client: self.http_client.clone(),
131            request_limiter: RateLimiter::new(4),
132        })
133    }
134
135    fn settings(cx: &App) -> &GoogleSettings {
136        &crate::AllLanguageModelSettings::get_global(cx).google
137    }
138
139    fn api_url(cx: &App) -> SharedString {
140        let api_url = &Self::settings(cx).api_url;
141        if api_url.is_empty() {
142            google_ai::API_URL.into()
143        } else {
144            SharedString::new(api_url.as_str())
145        }
146    }
147}
148
149impl LanguageModelProviderState for GoogleLanguageModelProvider {
150    type ObservableEntity = State;
151
152    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
153        Some(self.state.clone())
154    }
155}
156
157impl LanguageModelProvider for GoogleLanguageModelProvider {
158    fn id(&self) -> LanguageModelProviderId {
159        PROVIDER_ID
160    }
161
162    fn name(&self) -> LanguageModelProviderName {
163        PROVIDER_NAME
164    }
165
166    fn icon(&self) -> IconOrSvg {
167        IconOrSvg::Icon(IconName::AiGoogle)
168    }
169
170    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
171        Some(self.create_language_model(google_ai::Model::default()))
172    }
173
174    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
175        Some(self.create_language_model(google_ai::Model::default_fast()))
176    }
177
178    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
179        let mut models = BTreeMap::default();
180
181        // Add base models from google_ai::Model::iter()
182        for model in google_ai::Model::iter() {
183            if !matches!(model, google_ai::Model::Custom { .. }) {
184                models.insert(model.id().to_string(), model);
185            }
186        }
187
188        // Override with available models from settings
189        for model in &GoogleLanguageModelProvider::settings(cx).available_models {
190            models.insert(
191                model.name.clone(),
192                google_ai::Model::Custom {
193                    name: model.name.clone(),
194                    display_name: model.display_name.clone(),
195                    max_tokens: model.max_tokens,
196                    mode: model.mode.unwrap_or_default(),
197                },
198            );
199        }
200
201        models
202            .into_values()
203            .map(|model| {
204                Arc::new(GoogleLanguageModel {
205                    id: LanguageModelId::from(model.id().to_string()),
206                    model,
207                    state: self.state.clone(),
208                    http_client: self.http_client.clone(),
209                    request_limiter: RateLimiter::new(4),
210                }) as Arc<dyn LanguageModel>
211            })
212            .collect()
213    }
214
215    fn is_authenticated(&self, cx: &App) -> bool {
216        self.state.read(cx).is_authenticated()
217    }
218
219    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
220        self.state.update(cx, |state, cx| state.authenticate(cx))
221    }
222
223    fn configuration_view(
224        &self,
225        target_agent: language_model::ConfigurationViewTargetAgent,
226        window: &mut Window,
227        cx: &mut App,
228    ) -> AnyView {
229        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
230            .into()
231    }
232
233    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
234        self.state
235            .update(cx, |state, cx| state.set_api_key(None, cx))
236    }
237}
238
239pub struct GoogleLanguageModel {
240    id: LanguageModelId,
241    model: google_ai::Model,
242    state: Entity<State>,
243    http_client: Arc<dyn HttpClient>,
244    request_limiter: RateLimiter,
245}
246
247impl GoogleLanguageModel {
248    fn stream_completion(
249        &self,
250        request: google_ai::GenerateContentRequest,
251        cx: &AsyncApp,
252    ) -> BoxFuture<
253        'static,
254        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
255    > {
256        let http_client = self.http_client.clone();
257
258        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
259            let api_url = GoogleLanguageModelProvider::api_url(cx);
260            (state.api_key_state.key(&api_url), api_url)
261        });
262
263        async move {
264            let api_key = api_key.context("Missing Google API key")?;
265            let request = google_ai::stream_generate_content(
266                http_client.as_ref(),
267                &api_url,
268                &api_key,
269                request,
270            );
271            request.await.context("failed to stream completion")
272        }
273        .boxed()
274    }
275}
276
277impl LanguageModel for GoogleLanguageModel {
278    fn id(&self) -> LanguageModelId {
279        self.id.clone()
280    }
281
282    fn name(&self) -> LanguageModelName {
283        LanguageModelName::from(self.model.display_name().to_string())
284    }
285
286    fn provider_id(&self) -> LanguageModelProviderId {
287        PROVIDER_ID
288    }
289
290    fn provider_name(&self) -> LanguageModelProviderName {
291        PROVIDER_NAME
292    }
293
294    fn supports_tools(&self) -> bool {
295        self.model.supports_tools()
296    }
297
298    fn supports_images(&self) -> bool {
299        self.model.supports_images()
300    }
301
302    fn supports_thinking(&self) -> bool {
303        matches!(self.model.mode(), GoogleModelMode::Thinking { .. })
304    }
305
306    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
307        match choice {
308            LanguageModelToolChoice::Auto
309            | LanguageModelToolChoice::Any
310            | LanguageModelToolChoice::None => true,
311        }
312    }
313
314    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
315        LanguageModelToolSchemaFormat::JsonSchemaSubset
316    }
317
318    fn telemetry_id(&self) -> String {
319        format!("google/{}", self.model.request_id())
320    }
321
322    fn max_token_count(&self) -> u64 {
323        self.model.max_token_count()
324    }
325
326    fn max_output_tokens(&self) -> Option<u64> {
327        self.model.max_output_tokens()
328    }
329
330    fn count_tokens(
331        &self,
332        request: LanguageModelRequest,
333        cx: &App,
334    ) -> BoxFuture<'static, Result<u64>> {
335        let model_id = self.model.request_id().to_string();
336        let request = into_google(request, model_id, self.model.mode());
337        let http_client = self.http_client.clone();
338        let api_url = GoogleLanguageModelProvider::api_url(cx);
339        let api_key = self.state.read(cx).api_key_state.key(&api_url);
340
341        async move {
342            let Some(api_key) = api_key else {
343                return Err(LanguageModelCompletionError::NoApiKey {
344                    provider: PROVIDER_NAME,
345                }
346                .into());
347            };
348            let response = google_ai::count_tokens(
349                http_client.as_ref(),
350                &api_url,
351                &api_key,
352                google_ai::CountTokensRequest {
353                    generate_content_request: request,
354                },
355            )
356            .await?;
357            Ok(response.total_tokens)
358        }
359        .boxed()
360    }
361
362    fn stream_completion(
363        &self,
364        request: LanguageModelRequest,
365        cx: &AsyncApp,
366    ) -> BoxFuture<
367        'static,
368        Result<
369            futures::stream::BoxStream<
370                'static,
371                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
372            >,
373            LanguageModelCompletionError,
374        >,
375    > {
376        let request = into_google(
377            request,
378            self.model.request_id().to_string(),
379            self.model.mode(),
380        );
381        let request = self.stream_completion(request, cx);
382        let future = self.request_limiter.stream(async move {
383            let response = request.await.map_err(LanguageModelCompletionError::from)?;
384            Ok(GoogleEventMapper::new().map_stream(response))
385        });
386        async move { Ok(future.await?.boxed()) }.boxed()
387    }
388}
389
390struct ConfigurationView {
391    api_key_editor: Entity<InputField>,
392    state: Entity<State>,
393    target_agent: language_model::ConfigurationViewTargetAgent,
394    load_credentials_task: Option<Task<()>>,
395}
396
397impl ConfigurationView {
398    fn new(
399        state: Entity<State>,
400        target_agent: language_model::ConfigurationViewTargetAgent,
401        window: &mut Window,
402        cx: &mut Context<Self>,
403    ) -> Self {
404        cx.observe(&state, |_, _, cx| {
405            cx.notify();
406        })
407        .detach();
408
409        let load_credentials_task = Some(cx.spawn_in(window, {
410            let state = state.clone();
411            async move |this, cx| {
412                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
413                    // We don't log an error, because "not signed in" is also an error.
414                    let _ = task.await;
415                }
416                this.update(cx, |this, cx| {
417                    this.load_credentials_task = None;
418                    cx.notify();
419                })
420                .log_err();
421            }
422        }));
423
424        Self {
425            api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
426            target_agent,
427            state,
428            load_credentials_task,
429        }
430    }
431
432    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
433        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
434        if api_key.is_empty() {
435            return;
436        }
437
438        // url changes can cause the editor to be displayed again
439        self.api_key_editor
440            .update(cx, |editor, cx| editor.set_text("", window, cx));
441
442        let state = self.state.clone();
443        cx.spawn_in(window, async move |_, cx| {
444            state
445                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
446                .await
447        })
448        .detach_and_log_err(cx);
449    }
450
451    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
452        self.api_key_editor
453            .update(cx, |editor, cx| editor.set_text("", window, cx));
454
455        let state = self.state.clone();
456        cx.spawn_in(window, async move |_, cx| {
457            state
458                .update(cx, |state, cx| state.set_api_key(None, cx))
459                .await
460        })
461        .detach_and_log_err(cx);
462    }
463
464    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
465        !self.state.read(cx).is_authenticated()
466    }
467}
468
469impl Render for ConfigurationView {
470    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
471        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
472        let configured_card_label = if env_var_set {
473            format!(
474                "API key set in {} environment variable",
475                API_KEY_ENV_VAR.name
476            )
477        } else {
478            let api_url = GoogleLanguageModelProvider::api_url(cx);
479            if api_url == google_ai::API_URL {
480                "API key configured".to_string()
481            } else {
482                format!("API key configured for {}", api_url)
483            }
484        };
485
486        if self.load_credentials_task.is_some() {
487            div()
488                .child(Label::new("Loading credentials..."))
489                .into_any_element()
490        } else if self.should_render_editor(cx) {
491            v_flex()
492                .size_full()
493                .on_action(cx.listener(Self::save_api_key))
494                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
495                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
496                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
497                })))
498                .child(
499                    List::new()
500                        .child(
501                            ListBulletItem::new("")
502                                .child(Label::new("Create one by visiting"))
503                                .child(ButtonLink::new("Google AI's console", "https://aistudio.google.com/app/apikey"))
504                        )
505                        .child(
506                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
507                        )
508                )
509                .child(self.api_key_editor.clone())
510                .child(
511                    Label::new(
512                        format!("You can also set the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
513                    )
514                    .size(LabelSize::Small).color(Color::Muted),
515                )
516                .into_any_element()
517        } else {
518            ConfiguredApiCard::new(configured_card_label)
519                .disabled(env_var_set)
520                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
521                .when(env_var_set, |this| {
522                    this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
523                })
524                .into_any_element()
525        }
526    }
527}