google.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use collections::BTreeMap;
  3use credentials_provider::CredentialsProvider;
  4use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
  5use google_ai::{
  6    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
  7    ThinkingConfig, UsageMetadata,
  8};
  9use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
 10use http_client::HttpClient;
 11use language_model::{
 12    AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
 13    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
 14    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
 15};
 16use language_model::{
 17    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
 18    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
 19    LanguageModelRequest, RateLimiter, Role,
 20};
 21use schemars::JsonSchema;
 22use serde::{Deserialize, Serialize};
 23pub use settings::GoogleAvailableModel as AvailableModel;
 24use settings::{Settings, SettingsStore};
 25use std::pin::Pin;
 26use std::sync::{
 27    Arc, LazyLock,
 28    atomic::{self, AtomicU64},
 29};
 30use strum::IntoEnumIterator;
 31use ui::{Icon, IconName, List, Tooltip, prelude::*};
 32use ui_input::SingleLineInput;
 33use util::{ResultExt, truncate_and_trailoff};
 34use zed_env_vars::EnvVar;
 35
 36use crate::api_key::ApiKey;
 37use crate::api_key::ApiKeyState;
 38use crate::ui::InstructionListItem;
 39
/// Identifier under which this provider is registered (taken from `language_model`).
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
/// Display name of this provider (taken from `language_model`).
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
 42
/// Google AI provider settings, read from the global language-model settings.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    /// Base URL of the Google AI API; when empty, `google_ai::API_URL` is used.
    pub api_url: String,
    /// User-declared models; each is exposed as `google_ai::Model::Custom` and replaces a
    /// built-in model with the same id.
    pub available_models: Vec<AvailableModel>,
}
 48
/// Reasoning mode for a Google model.
///
/// Serialized as an internally tagged enum with lowercase variant names, e.g.
/// `{"type": "default"}` or `{"type": "thinking", "budget_tokens": 4096}`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// The default mode; carries no thinking budget.
    #[default]
    Default,
    /// Thinking mode with an optional reasoning-token budget.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
 59
/// Language-model provider backed by the Google AI (Gemini) API.
pub struct GoogleLanguageModelProvider {
    /// HTTP client shared with every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state shared by the provider, its models and the configuration view.
    state: Entity<State>,
}
 64
/// Authentication state of the Google provider.
pub struct State {
    /// API key storage, looked up against the currently configured API URL.
    api_key_state: ApiKeyState,
}
 68
 69const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
 70const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
 71
 72static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
 73    // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
 74    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
 75});
 76
 77impl State {
 78    const fn is_authenticated(&self) -> bool {
 79        self.api_key_state.has_key()
 80    }
 81
 82    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
 83        let api_url = GoogleLanguageModelProvider::api_url(cx);
 84        self.api_key_state
 85            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
 86    }
 87
 88    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 89        let api_url = GoogleLanguageModelProvider::api_url(cx);
 90        self.api_key_state.load_if_needed(
 91            api_url,
 92            &API_KEY_ENV_VAR,
 93            |this| &mut this.api_key_state,
 94            cx,
 95        )
 96    }
 97}
 98
 99impl GoogleLanguageModelProvider {
100    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
101        let state = cx.new(|cx| {
102            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
103                let api_url = Self::api_url(cx);
104                this.api_key_state.handle_url_change(
105                    api_url,
106                    &API_KEY_ENV_VAR,
107                    |this| &mut this.api_key_state,
108                    cx,
109                );
110                cx.notify();
111            })
112            .detach();
113            State {
114                api_key_state: ApiKeyState::new(Self::api_url(cx)),
115            }
116        });
117
118        Self { http_client, state }
119    }
120
121    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
122        Arc::new(GoogleLanguageModel {
123            id: LanguageModelId::from(model.id().to_string()),
124            model,
125            state: self.state.clone(),
126            http_client: self.http_client.clone(),
127            request_limiter: RateLimiter::new(4),
128        })
129    }
130
131    pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
132        if let Some(key) = API_KEY_ENV_VAR.value.clone() {
133            return Task::ready(Ok(key));
134        }
135        let credentials_provider = <dyn CredentialsProvider>::global(cx);
136        let api_url = Self::api_url(cx).to_string();
137        cx.spawn(async move |cx| {
138            Ok(
139                ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
140                    .await?
141                    .key()
142                    .to_string(),
143            )
144        })
145    }
146
147    fn settings(cx: &App) -> &GoogleSettings {
148        &crate::AllLanguageModelSettings::get_global(cx).google
149    }
150
151    fn api_url(cx: &App) -> SharedString {
152        let api_url = &Self::settings(cx).api_url;
153        if api_url.is_empty() {
154            google_ai::API_URL.into()
155        } else {
156            SharedString::new(api_url.as_str())
157        }
158    }
159}
160
161impl LanguageModelProviderState for GoogleLanguageModelProvider {
162    type ObservableEntity = State;
163
164    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
165        Some(self.state.clone())
166    }
167}
168
169impl LanguageModelProvider for GoogleLanguageModelProvider {
170    fn id(&self) -> LanguageModelProviderId {
171        PROVIDER_ID
172    }
173
174    fn name(&self) -> LanguageModelProviderName {
175        PROVIDER_NAME
176    }
177
178    fn icon(&self) -> IconName {
179        IconName::AiGoogle
180    }
181
182    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
183        Some(self.create_language_model(google_ai::Model::default()))
184    }
185
186    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
187        Some(self.create_language_model(google_ai::Model::default_fast()))
188    }
189
190    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
191        let mut models = BTreeMap::default();
192
193        // Add base models from google_ai::Model::iter()
194        for model in google_ai::Model::iter() {
195            if !matches!(model, google_ai::Model::Custom { .. }) {
196                models.insert(model.id().to_string(), model);
197            }
198        }
199
200        // Override with available models from settings
201        for model in &GoogleLanguageModelProvider::settings(cx).available_models {
202            models.insert(
203                model.name.clone(),
204                google_ai::Model::Custom {
205                    name: model.name.clone(),
206                    display_name: model.display_name.clone(),
207                    max_tokens: model.max_tokens,
208                    mode: model.mode.unwrap_or_default(),
209                },
210            );
211        }
212
213        models
214            .into_values()
215            .map(|model| {
216                Arc::new(GoogleLanguageModel {
217                    id: LanguageModelId::from(model.id().to_string()),
218                    model,
219                    state: self.state.clone(),
220                    http_client: self.http_client.clone(),
221                    request_limiter: RateLimiter::new(4),
222                }) as Arc<dyn LanguageModel>
223            })
224            .collect()
225    }
226
227    fn is_authenticated(&self, cx: &App) -> bool {
228        self.state.read(cx).is_authenticated()
229    }
230
231    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
232        self.state.update(cx, |state, cx| state.authenticate(cx))
233    }
234
235    fn configuration_view(
236        &self,
237        target_agent: language_model::ConfigurationViewTargetAgent,
238        window: &mut Window,
239        cx: &mut App,
240    ) -> AnyView {
241        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
242            .into()
243    }
244
245    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
246        self.state
247            .update(cx, |state, cx| state.set_api_key(None, cx))
248    }
249}
250
/// A single Google model exposed through the [`LanguageModel`] trait.
pub struct GoogleLanguageModel {
    /// Identifier built from `model.id()`.
    id: LanguageModelId,
    /// The underlying `google_ai` model description.
    model: google_ai::Model,
    /// Shared provider state, read to obtain the API key.
    state: Entity<State>,
    /// HTTP client used for completion and token-count requests.
    http_client: Arc<dyn HttpClient>,
    /// Throttles completion streams; the provider constructs it with `RateLimiter::new(4)`.
    request_limiter: RateLimiter,
}
258
259impl GoogleLanguageModel {
260    fn stream_completion(
261        &self,
262        request: google_ai::GenerateContentRequest,
263        cx: &AsyncApp,
264    ) -> BoxFuture<
265        'static,
266        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
267    > {
268        let http_client = self.http_client.clone();
269
270        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
271            let api_url = GoogleLanguageModelProvider::api_url(cx);
272            (state.api_key_state.key(&api_url), api_url)
273        }) else {
274            return future::ready(Err(anyhow!("App state dropped"))).boxed();
275        };
276
277        async move {
278            let api_key = api_key.context("Missing Google API key")?;
279            let request = google_ai::stream_generate_content(
280                http_client.as_ref(),
281                &api_url,
282                &api_key,
283                request,
284            );
285            request.await.context("failed to stream completion")
286        }
287        .boxed()
288    }
289}
290
291impl LanguageModel for GoogleLanguageModel {
292    fn id(&self) -> LanguageModelId {
293        self.id.clone()
294    }
295
296    fn name(&self) -> LanguageModelName {
297        LanguageModelName::from(self.model.display_name().to_string())
298    }
299
300    fn provider_id(&self) -> LanguageModelProviderId {
301        PROVIDER_ID
302    }
303
304    fn provider_name(&self) -> LanguageModelProviderName {
305        PROVIDER_NAME
306    }
307
308    fn supports_tools(&self) -> bool {
309        self.model.supports_tools()
310    }
311
312    fn supports_images(&self) -> bool {
313        self.model.supports_images()
314    }
315
316    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
317        match choice {
318            LanguageModelToolChoice::Auto
319            | LanguageModelToolChoice::Any
320            | LanguageModelToolChoice::None => true,
321        }
322    }
323
324    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
325        LanguageModelToolSchemaFormat::JsonSchemaSubset
326    }
327
328    fn telemetry_id(&self) -> String {
329        format!("google/{}", self.model.request_id())
330    }
331
332    fn max_token_count(&self) -> u64 {
333        self.model.max_token_count()
334    }
335
336    fn max_output_tokens(&self) -> Option<u64> {
337        self.model.max_output_tokens()
338    }
339
340    fn count_tokens(
341        &self,
342        request: LanguageModelRequest,
343        cx: &App,
344    ) -> BoxFuture<'static, Result<u64>> {
345        let model_id = self.model.request_id().to_string();
346        let request = into_google(request, model_id, self.model.mode());
347        let http_client = self.http_client.clone();
348        let api_url = GoogleLanguageModelProvider::api_url(cx);
349        let api_key = self.state.read(cx).api_key_state.key(&api_url);
350
351        async move {
352            let Some(api_key) = api_key else {
353                return Err(LanguageModelCompletionError::NoApiKey {
354                    provider: PROVIDER_NAME,
355                }
356                .into());
357            };
358            let response = google_ai::count_tokens(
359                http_client.as_ref(),
360                &api_url,
361                &api_key,
362                google_ai::CountTokensRequest {
363                    generate_content_request: request,
364                },
365            )
366            .await?;
367            Ok(response.total_tokens)
368        }
369        .boxed()
370    }
371
372    fn stream_completion(
373        &self,
374        request: LanguageModelRequest,
375        cx: &AsyncApp,
376    ) -> BoxFuture<
377        'static,
378        Result<
379            futures::stream::BoxStream<
380                'static,
381                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
382            >,
383            LanguageModelCompletionError,
384        >,
385    > {
386        let request = into_google(
387            request,
388            self.model.request_id().to_string(),
389            self.model.mode(),
390        );
391        let request = self.stream_completion(request, cx);
392        let future = self.request_limiter.stream(async move {
393            let response = request.await.map_err(LanguageModelCompletionError::from)?;
394            Ok(GoogleEventMapper::new().map_stream(response))
395        });
396        async move { Ok(future.await?.boxed()) }.boxed()
397    }
398}
399
400pub fn into_google(
401    mut request: LanguageModelRequest,
402    model_id: String,
403    mode: GoogleModelMode,
404) -> google_ai::GenerateContentRequest {
405    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
406        content
407            .into_iter()
408            .flat_map(|content| match content {
409                language_model::MessageContent::Text(text) => {
410                    if !text.is_empty() {
411                        vec![Part::TextPart(google_ai::TextPart { text })]
412                    } else {
413                        vec![]
414                    }
415                }
416                language_model::MessageContent::Thinking {
417                    text: _,
418                    signature: Some(signature),
419                } => {
420                    if !signature.is_empty() {
421                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
422                            thought: true,
423                            thought_signature: signature,
424                        })]
425                    } else {
426                        vec![]
427                    }
428                }
429                language_model::MessageContent::Thinking { .. } => {
430                    vec![]
431                }
432                language_model::MessageContent::RedactedThinking(_) => vec![],
433                language_model::MessageContent::Image(image) => {
434                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
435                        inline_data: google_ai::GenerativeContentBlob {
436                            mime_type: "image/png".to_string(),
437                            data: image.source.to_string(),
438                        },
439                    })]
440                }
441                language_model::MessageContent::ToolUse(tool_use) => {
442                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
443                        function_call: google_ai::FunctionCall {
444                            name: tool_use.name.to_string(),
445                            args: tool_use.input,
446                        },
447                    })]
448                }
449                language_model::MessageContent::ToolResult(tool_result) => {
450                    match tool_result.content {
451                        language_model::LanguageModelToolResultContent::Text(text) => {
452                            vec![Part::FunctionResponsePart(
453                                google_ai::FunctionResponsePart {
454                                    function_response: google_ai::FunctionResponse {
455                                        name: tool_result.tool_name.to_string(),
456                                        // The API expects a valid JSON object
457                                        response: serde_json::json!({
458                                            "output": text
459                                        }),
460                                    },
461                                },
462                            )]
463                        }
464                        language_model::LanguageModelToolResultContent::Image(image) => {
465                            vec![
466                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
467                                    function_response: google_ai::FunctionResponse {
468                                        name: tool_result.tool_name.to_string(),
469                                        // The API expects a valid JSON object
470                                        response: serde_json::json!({
471                                            "output": "Tool responded with an image"
472                                        }),
473                                    },
474                                }),
475                                Part::InlineDataPart(google_ai::InlineDataPart {
476                                    inline_data: google_ai::GenerativeContentBlob {
477                                        mime_type: "image/png".to_string(),
478                                        data: image.source.to_string(),
479                                    },
480                                }),
481                            ]
482                        }
483                    }
484                }
485            })
486            .collect()
487    }
488
489    let system_instructions = if request
490        .messages
491        .first()
492        .is_some_and(|msg| matches!(msg.role, Role::System))
493    {
494        let message = request.messages.remove(0);
495        Some(SystemInstruction {
496            parts: map_content(message.content),
497        })
498    } else {
499        None
500    };
501
502    google_ai::GenerateContentRequest {
503        model: google_ai::ModelName { model_id },
504        system_instruction: system_instructions,
505        contents: request
506            .messages
507            .into_iter()
508            .filter_map(|message| {
509                let parts = map_content(message.content);
510                if parts.is_empty() {
511                    None
512                } else {
513                    Some(google_ai::Content {
514                        parts,
515                        role: match message.role {
516                            Role::User => google_ai::Role::User,
517                            Role::Assistant => google_ai::Role::Model,
518                            Role::System => google_ai::Role::User, // Google AI doesn't have a system role
519                        },
520                    })
521                }
522            })
523            .collect(),
524        generation_config: Some(google_ai::GenerationConfig {
525            candidate_count: Some(1),
526            stop_sequences: Some(request.stop),
527            max_output_tokens: None,
528            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
529            thinking_config: match (request.thinking_allowed, mode) {
530                (true, GoogleModelMode::Thinking { budget_tokens }) => {
531                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
532                }
533                _ => None,
534            },
535            top_p: None,
536            top_k: None,
537        }),
538        safety_settings: None,
539        tools: (!request.tools.is_empty()).then(|| {
540            vec![google_ai::Tool {
541                function_declarations: request
542                    .tools
543                    .into_iter()
544                    .map(|tool| FunctionDeclaration {
545                        name: tool.name,
546                        description: tool.description,
547                        parameters: tool.input_schema,
548                    })
549                    .collect(),
550            }]
551        }),
552        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
553            function_calling_config: google_ai::FunctionCallingConfig {
554                mode: match choice {
555                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
556                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
557                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
558                },
559                allowed_function_names: None,
560            },
561        }),
562    }
563}
564
565pub struct GoogleEventMapper {
566    usage: UsageMetadata,
567    stop_reason: StopReason,
568}
569
570impl GoogleEventMapper {
571    pub fn new() -> Self {
572        Self {
573            usage: UsageMetadata::default(),
574            stop_reason: StopReason::EndTurn,
575        }
576    }
577
578    pub fn map_stream(
579        mut self,
580        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
581    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
582    {
583        events
584            .map(Some)
585            .chain(futures::stream::once(async { None }))
586            .flat_map(move |event| {
587                futures::stream::iter(match event {
588                    Some(Ok(event)) => self.map_event(event),
589                    Some(Err(error)) => {
590                        vec![Err(LanguageModelCompletionError::from(error))]
591                    }
592                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
593                })
594            })
595    }
596
597    pub fn map_event(
598        &mut self,
599        event: GenerateContentResponse,
600    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
601        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
602
603        let mut events: Vec<_> = Vec::new();
604        let mut wants_to_use_tool = false;
605        if let Some(usage_metadata) = event.usage_metadata {
606            update_usage(&mut self.usage, &usage_metadata);
607            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
608                convert_usage(&self.usage),
609            )))
610        }
611
612        if let Some(prompt_feedback) = event.prompt_feedback
613            && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
614        {
615            self.stop_reason = match block_reason {
616                "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
617                    StopReason::Refusal
618                }
619                _ => {
620                    log::error!("Unexpected Google block_reason: {block_reason}");
621                    StopReason::Refusal
622                }
623            };
624            events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
625
626            return events;
627        }
628
629        if let Some(candidates) = event.candidates {
630            for candidate in candidates {
631                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
632                    self.stop_reason = match finish_reason {
633                        "STOP" => StopReason::EndTurn,
634                        "MAX_TOKENS" => StopReason::MaxTokens,
635                        _ => {
636                            log::error!("Unexpected google finish_reason: {finish_reason}");
637                            StopReason::EndTurn
638                        }
639                    };
640                }
641                candidate
642                    .content
643                    .parts
644                    .into_iter()
645                    .for_each(|part| match part {
646                        Part::TextPart(text_part) => {
647                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
648                        }
649                        Part::InlineDataPart(_) => {}
650                        Part::FunctionCallPart(function_call_part) => {
651                            wants_to_use_tool = true;
652                            let name: Arc<str> = function_call_part.function_call.name.into();
653                            let next_tool_id =
654                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
655                            let id: LanguageModelToolUseId =
656                                format!("{}-{}", name, next_tool_id).into();
657
658                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
659                                LanguageModelToolUse {
660                                    id,
661                                    name,
662                                    is_input_complete: true,
663                                    raw_input: function_call_part.function_call.args.to_string(),
664                                    input: function_call_part.function_call.args,
665                                },
666                            )));
667                        }
668                        Part::FunctionResponsePart(_) => {}
669                        Part::ThoughtPart(part) => {
670                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
671                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
672                                signature: Some(part.thought_signature),
673                            }));
674                        }
675                    });
676            }
677        }
678
679        // Even when Gemini wants to use a Tool, the API
680        // responds with `finish_reason: STOP`
681        if wants_to_use_tool {
682            self.stop_reason = StopReason::ToolUse;
683            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
684        }
685        events
686    }
687}
688
689pub fn count_google_tokens(
690    request: LanguageModelRequest,
691    cx: &App,
692) -> BoxFuture<'static, Result<u64>> {
693    // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
694    // So we have to use tokenizer from tiktoken_rs to count tokens.
695    cx.background_spawn(async move {
696        let messages = request
697            .messages
698            .into_iter()
699            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
700                role: match message.role {
701                    Role::User => "user".into(),
702                    Role::Assistant => "assistant".into(),
703                    Role::System => "system".into(),
704                },
705                content: Some(message.string_contents()),
706                name: None,
707                function_call: None,
708            })
709            .collect::<Vec<_>>();
710
711        // Tiktoken doesn't yet support these models, so we manually use the
712        // same tokenizer as GPT-4.
713        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
714    })
715    .boxed()
716}
717
718const fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
719    if let Some(prompt_token_count) = new.prompt_token_count {
720        usage.prompt_token_count = Some(prompt_token_count);
721    }
722    if let Some(cached_content_token_count) = new.cached_content_token_count {
723        usage.cached_content_token_count = Some(cached_content_token_count);
724    }
725    if let Some(candidates_token_count) = new.candidates_token_count {
726        usage.candidates_token_count = Some(candidates_token_count);
727    }
728    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
729        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
730    }
731    if let Some(thoughts_token_count) = new.thoughts_token_count {
732        usage.thoughts_token_count = Some(thoughts_token_count);
733    }
734    if let Some(total_token_count) = new.total_token_count {
735        usage.total_token_count = Some(total_token_count);
736    }
737}
738
739fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
740    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
741    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
742    let input_tokens = prompt_tokens - cached_tokens;
743    let output_tokens = usage.candidates_token_count.unwrap_or(0);
744
745    language_model::TokenUsage {
746        input_tokens,
747        output_tokens,
748        cache_read_input_tokens: cached_tokens,
749        cache_creation_input_tokens: 0,
750    }
751}
752
/// Settings UI for entering or resetting the Google AI API key.
struct ConfigurationView {
    /// Single-line input the user pastes the API key into.
    api_key_editor: Entity<SingleLineInput>,
    /// Shared provider state; observed so the view re-renders when it changes.
    state: Entity<State>,
    /// Agent the setup instructions are phrased for.
    target_agent: language_model::ConfigurationViewTargetAgent,
    /// Initial credential load; `Some` while loading, reset to `None` once it finishes.
    load_credentials_task: Option<Task<()>>,
}
759
760impl ConfigurationView {
761    fn new(
762        state: Entity<State>,
763        target_agent: language_model::ConfigurationViewTargetAgent,
764        window: &mut Window,
765        cx: &mut Context<Self>,
766    ) -> Self {
767        cx.observe(&state, |_, _, cx| {
768            cx.notify();
769        })
770        .detach();
771
772        let load_credentials_task = Some(cx.spawn_in(window, {
773            let state = state.clone();
774            async move |this, cx| {
775                if let Some(task) = state
776                    .update(cx, |state, cx| state.authenticate(cx))
777                    .log_err()
778                {
779                    // We don't log an error, because "not signed in" is also an error.
780                    let _ = task.await;
781                }
782                this.update(cx, |this, cx| {
783                    this.load_credentials_task = None;
784                    cx.notify();
785                })
786                .log_err();
787            }
788        }));
789
790        Self {
791            api_key_editor: cx.new(|cx| SingleLineInput::new(window, cx, "AIzaSy...")),
792            target_agent,
793            state,
794            load_credentials_task,
795        }
796    }
797
798    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
799        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
800        if api_key.is_empty() {
801            return;
802        }
803
804        // url changes can cause the editor to be displayed again
805        self.api_key_editor
806            .update(cx, |editor, cx| editor.set_text("", window, cx));
807
808        let state = self.state.clone();
809        cx.spawn_in(window, async move |_, cx| {
810            state
811                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
812                .await
813        })
814        .detach_and_log_err(cx);
815    }
816
817    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
818        self.api_key_editor
819            .update(cx, |editor, cx| editor.set_text("", window, cx));
820
821        let state = self.state.clone();
822        cx.spawn_in(window, async move |_, cx| {
823            state
824                .update(cx, |state, cx| state.set_api_key(None, cx))?
825                .await
826        })
827        .detach_and_log_err(cx);
828    }
829
830    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
831        !self.state.read(cx).is_authenticated()
832    }
833}
834
835impl Render for ConfigurationView {
836    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
837        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
838
839        if self.load_credentials_task.is_some() {
840            div().child(Label::new("Loading credentials...")).into_any()
841        } else if self.should_render_editor(cx) {
842            v_flex()
843                .size_full()
844                .on_action(cx.listener(Self::save_api_key))
845                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
846                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
847                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
848                })))
849                .child(
850                    List::new()
851                        .child(InstructionListItem::new(
852                            "Create one by visiting",
853                            Some("Google AI's console"),
854                            Some("https://aistudio.google.com/app/apikey"),
855                        ))
856                        .child(InstructionListItem::text_only(
857                            "Paste your API key below and hit enter to start using the assistant",
858                        )),
859                )
860                .child(self.api_key_editor.clone())
861                .child(
862                    Label::new(
863                        format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
864                    )
865                    .size(LabelSize::Small).color(Color::Muted),
866                )
867                .into_any()
868        } else {
869            h_flex()
870                .mt_1()
871                .p_1()
872                .justify_between()
873                .rounded_md()
874                .border_1()
875                .border_color(cx.theme().colors().border)
876                .bg(cx.theme().colors().background)
877                .child(
878                    h_flex()
879                        .gap_1()
880                        .child(Icon::new(IconName::Check).color(Color::Success))
881                        .child(Label::new(if env_var_set {
882                            format!("API key set in {} environment variable", API_KEY_ENV_VAR.name)
883                        } else {
884                            let api_url = GoogleLanguageModelProvider::api_url(cx);
885                            if api_url == google_ai::API_URL {
886                                "API key configured".to_string()
887                            } else {
888                                format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
889                            }
890                        })),
891                )
892                .child(
893                    Button::new("reset-key", "Reset Key")
894                        .label_size(LabelSize::Small)
895                        .icon(Some(IconName::Trash))
896                        .icon_size(IconSize::Small)
897                        .icon_position(IconPosition::Start)
898                        .disabled(env_var_set)
899                        .when(env_var_set, |this| {
900                            this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset.")))
901                        })
902                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
903                )
904                .into_any()
905        }
906    }
907}