use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
use google_ai::{
    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
    ThinkingConfig, UsageMetadata,
};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
pub use settings::GoogleAvailableModel as AvailableModel;
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::{
    Arc, LazyLock,
    atomic::{self, AtomicU64},
};
use strum::IntoEnumIterator;
use ui::{List, prelude::*};
use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::EnvVar;

use crate::api_key::ApiKeyState;
use crate::ui::{ConfiguredApiCard, InstructionListItem};

const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;

#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

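/// Model mode as written in settings; with the serde `type` tag below this
/// deserializes from e.g. `{ "type": "thinking", "budget_tokens": 4096 }`.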
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

pub struct GoogleLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: Entity<State>,
}

pub struct State {
    api_key_state: ApiKeyState,
}

const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";

static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
    // Prefer GEMINI_API_KEY, falling back to GOOGLE_AI_API_KEY.
    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
});

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        self.api_key_state
            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            &API_KEY_ENV_VAR,
            |this| &mut this.api_key_state,
            cx,
        )
    }
}

impl GoogleLanguageModelProvider {
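    /// Creates the provider along with a `State` entity that re-resolves the
    /// API key whenever global settings (and thus possibly the API URL) change.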
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let api_url = Self::api_url(cx);
                this.api_key_state.handle_url_change(
                    api_url,
                    &API_KEY_ENV_VAR,
                    |this| &mut this.api_key_state,
                    cx,
                );
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx)),
            }
        });

        Self { http_client, state }
    }

    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(GoogleLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    fn settings(cx: &App) -> &GoogleSettings {
        &crate::AllLanguageModelSettings::get_global(cx).google
    }

    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            google_ai::API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}

impl LanguageModelProviderState for GoogleLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for GoogleLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiGoogle
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(google_ai::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(google_ai::Model::default_fast()))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from google_ai::Model::iter()
        for model in google_ai::Model::iter() {
            if !matches!(model, google_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in &GoogleLanguageModelProvider::settings(cx).available_models {
            models.insert(
                model.name.clone(),
                google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    mode: model.mode.unwrap_or_default(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(GoogleLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    state: self.state.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state
            .update(cx, |state, cx| state.set_api_key(None, cx))
    }
}

pub struct GoogleLanguageModel {
    id: LanguageModelId,
    model: google_ai::Model,
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl GoogleLanguageModel {
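    // The API key and URL are read from app state up front so the returned
    // future is 'static and does not capture the state entity.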
    fn stream_completion(
        &self,
        request: google_ai::GenerateContentRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
    > {
        let http_client = self.http_client.clone();

        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        }) else {
            return future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move {
            let api_key = api_key.context("Missing Google API key")?;
            let request = google_ai::stream_generate_content(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
            );
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}

impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.request_id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let model_id = self.model.request_id().to_string();
        let request = into_google(request, model_id, self.model.mode());
        let http_client = self.http_client.clone();
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        let api_key = self.state.read(cx).api_key_state.key(&api_url);

        async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                }
                .into());
            };
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    generate_content_request: request,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_google(
            request,
            self.model.request_id().to_string(),
            self.model.mode(),
        );
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await.map_err(LanguageModelCompletionError::from)?;
            Ok(GoogleEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

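/// Converts a provider-agnostic `LanguageModelRequest` into Google's
/// `GenerateContentRequest`. A leading system message becomes the
/// `system_instruction`; the remaining messages are mapped to `contents`,
/// with empty messages dropped and assistant messages sent as the `Model` role.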
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                language_model::MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(google_ai::TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking { .. } => {
                    vec![]
                }
                language_model::MessageContent::RedactedThinking(_) => vec![],
                language_model::MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
                        inline_data: google_ai::GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                language_model::MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
                        function_call: google_ai::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                language_model::MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(
                                google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": text
                                        }),
                                    },
                                },
                            )]
                        }
                        language_model::LanguageModelToolResultContent::Image(image) => {
                            vec![
                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(google_ai::InlineDataPart {
                                    inline_data: google_ai::GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    google_ai::GenerateContentRequest {
        model: google_ai::ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(google_ai::Content {
                        parts,
                        role: match message.role {
                            Role::User => google_ai::Role::User,
                            Role::Assistant => google_ai::Role::Model,
                            Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(google_ai::GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
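            // Forward a thinking budget only when the request permits thinking
            // and the model is in thinking mode with an explicit budget.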
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        tools: (!request.tools.is_empty()).then(|| {
            vec![google_ai::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
            function_calling_config: google_ai::FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}

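/// Translates Google's streaming `GenerateContentResponse` chunks into
/// `LanguageModelCompletionEvent`s, accumulating usage metadata and the
/// final stop reason along the way.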
pub struct GoogleEventMapper {
    usage: UsageMetadata,
    stop_reason: StopReason,
}

impl GoogleEventMapper {
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
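        // Chain a `None` sentinel onto the stream so a final Stop event,
        // carrying the accumulated stop reason, is emitted after the
        // underlying stream ends.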
        events
            .map(Some)
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }

        if let Some(prompt_feedback) = event.prompt_feedback
            && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
        {
            self.stop_reason = match block_reason {
                "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
                    StopReason::Refusal
                }
                _ => {
                    log::error!("Unexpected Google block_reason: {block_reason}");
                    StopReason::Refusal
                }
            };
            events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));

            return events;
        }

        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        _ => {
                            log::error!("Unexpected Google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
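                            // Gemini function calls carry no id of their own,
                            // so a unique one is synthesized from the function
                            // name and a process-wide counter.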
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();

                            // Normalize empty string signatures to None
                            let thought_signature = function_call_part
                                .thought_signature
                                .filter(|s| !s.is_empty());

                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                    thought_signature,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a Tool, the API
        // responds with `finish_reason: STOP`
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}

pub fn count_google_tokens(
    request: LanguageModelRequest,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    // We can't count tokens via GoogleLanguageModelProvider here because GitHub
    // Copilot doesn't have direct access to google_ai, so we approximate with
    // the tiktoken_rs tokenizer instead.
    cx.background_spawn(async move {
        let messages = request
            .messages
            .into_iter()
            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                role: match message.role {
                    Role::User => "user".into(),
                    Role::Assistant => "assistant".into(),
                    Role::System => "system".into(),
                },
                content: Some(message.string_contents()),
                name: None,
                function_call: None,
            })
            .collect::<Vec<_>>();

        // Tiktoken doesn't yet support these models, so we manually use the
        // same tokenizer as GPT-4.
        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
    })
    .boxed()
}

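// Usage metadata in each streamed chunk appears to carry running totals, so
// newer non-`None` fields simply overwrite the accumulated values.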
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
    if let Some(prompt_token_count) = new.prompt_token_count {
        usage.prompt_token_count = Some(prompt_token_count);
    }
    if let Some(cached_content_token_count) = new.cached_content_token_count {
        usage.cached_content_token_count = Some(cached_content_token_count);
    }
    if let Some(candidates_token_count) = new.candidates_token_count {
        usage.candidates_token_count = Some(candidates_token_count);
    }
    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
    }
    if let Some(thoughts_token_count) = new.thoughts_token_count {
        usage.thoughts_token_count = Some(thoughts_token_count);
    }
    if let Some(total_token_count) = new.total_token_count {
        usage.total_token_count = Some(total_token_count);
    }
}

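// Google's `prompt_token_count` includes cached content, so the uncached
// input is computed as the difference between the two.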
fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
    let input_tokens = prompt_tokens - cached_tokens;
    let output_tokens = usage.candidates_token_count.unwrap_or(0);

    language_model::TokenUsage {
        input_tokens,
        output_tokens,
        cache_read_input_tokens: cached_tokens,
        cache_creation_input_tokens: 0,
    }
}

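/// Configuration UI for entering, resetting, or inspecting the Google AI API key.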
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    target_agent: language_model::ConfigurationViewTargetAgent,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(
        state: Entity<State>,
        target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
            target_agent,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        // URL changes can cause the editor to be displayed again
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!(
                "API key set in {} environment variable",
                API_KEY_ENV_VAR.name
            )
        } else {
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            if api_url == google_ai::API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div()
                .child(Label::new("Loading credentials..."))
                .into_any_element()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
                })))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("Google AI's console"),
                            Some("https://aistudio.google.com/app/apikey"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any_element()
        } else {
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
                })
                .into_any_element()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use google_ai::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole, TextPart,
    };
    use language_model::{LanguageModelToolUseId, MessageContent, Role};
    use serde_json::json;

    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let mut mapper = GoogleEventMapper::new();

        let response = GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature: Some("test_signature_123".to_string()),
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        };

        let events = mapper.map_event(response);

        assert_eq!(events.len(), 2); // ToolUse event + Stop event

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert_eq!(tool_use.name.as_ref(), "test_function");
            assert_eq!(
                tool_use.thought_signature.as_deref(),
                Some("test_signature_123")
            );
        } else {
            panic!("Expected ToolUse event");
        }
    }

    #[test]
    fn test_function_call_without_signature_has_none() {
        let mut mapper = GoogleEventMapper::new();

        let response = GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature: None,
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        };

        let events = mapper.map_event(response);

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert_eq!(tool_use.thought_signature, None);
        } else {
            panic!("Expected ToolUse event");
        }
    }

    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let mut mapper = GoogleEventMapper::new();

        let response = GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature: Some("".to_string()),
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        };

        let events = mapper.map_event(response);

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert_eq!(tool_use.thought_signature, None);
        } else {
            panic!("Expected ToolUse event");
        }
    }

    #[test]
    fn test_parallel_function_calls_preserve_signatures() {
        let mut mapper = GoogleEventMapper::new();

        let response = GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![
                        Part::FunctionCallPart(FunctionCallPart {
                            function_call: FunctionCall {
                                name: "function_1".to_string(),
                                args: json!({"arg": "value1"}),
                            },
                            thought_signature: Some("signature_1".to_string()),
                        }),
                        Part::FunctionCallPart(FunctionCallPart {
                            function_call: FunctionCall {
                                name: "function_2".to_string(),
                                args: json!({"arg": "value2"}),
                            },
                            thought_signature: None,
                        }),
                    ],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        };

        let events = mapper.map_event(response);

        assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert_eq!(tool_use.name.as_ref(), "function_1");
            assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
        } else {
            panic!("Expected ToolUse event for function_1");
        }

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
            assert_eq!(tool_use.name.as_ref(), "function_2");
            assert_eq!(tool_use.thought_signature, None);
        } else {
            panic!("Expected ToolUse event for function_2");
        }
    }

    #[test]
    fn test_tool_use_with_signature_converts_to_function_call_part() {
        let tool_use = language_model::LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_function".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("test_signature_456".to_string()),
        };

        let request = super::into_google(
            LanguageModelRequest {
                messages: vec![language_model::LanguageModelRequestMessage {
                    role: Role::Assistant,
                    content: vec![MessageContent::ToolUse(tool_use)],
                    cache: false,
                    reasoning_details: None,
                }],
                ..Default::default()
            },
            "gemini-2.5-flash".to_string(),
            GoogleModelMode::Default,
        );

        assert_eq!(request.contents[0].parts.len(), 1);
        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
            assert_eq!(fc_part.function_call.name, "test_function");
            assert_eq!(
                fc_part.thought_signature.as_deref(),
                Some("test_signature_456")
            );
        } else {
            panic!("Expected FunctionCallPart");
        }
    }

    #[test]
    fn test_tool_use_without_signature_omits_field() {
        let tool_use = language_model::LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_function".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: None,
        };

        let request = super::into_google(
            LanguageModelRequest {
                messages: vec![language_model::LanguageModelRequestMessage {
                    role: Role::Assistant,
                    content: vec![MessageContent::ToolUse(tool_use)],
                    cache: false,
                    reasoning_details: None,
                }],
                ..Default::default()
            },
            "gemini-2.5-flash".to_string(),
            GoogleModelMode::Default,
        );

        assert_eq!(request.contents[0].parts.len(), 1);
        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
            assert_eq!(fc_part.thought_signature, None);
        } else {
            panic!("Expected FunctionCallPart");
        }
    }

    #[test]
    fn test_empty_signature_in_tool_use_normalized_to_none() {
        let tool_use = language_model::LanguageModelToolUse {
            id: LanguageModelToolUseId::from("test_id"),
            name: "test_function".into(),
            raw_input: json!({"arg": "value"}).to_string(),
            input: json!({"arg": "value"}),
            is_input_complete: true,
            thought_signature: Some("".to_string()),
        };

        let request = super::into_google(
            LanguageModelRequest {
                messages: vec![language_model::LanguageModelRequestMessage {
                    role: Role::Assistant,
                    content: vec![MessageContent::ToolUse(tool_use)],
                    cache: false,
                    reasoning_details: None,
                }],
                ..Default::default()
            },
            "gemini-2.5-flash".to_string(),
            GoogleModelMode::Default,
        );

        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
            assert_eq!(fc_part.thought_signature, None);
        } else {
            panic!("Expected FunctionCallPart");
        }
    }

    #[test]
    fn test_round_trip_preserves_signature() {
        let mut mapper = GoogleEventMapper::new();

        // Simulate receiving a response from Google with a signature
        let response = GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature: Some("round_trip_sig".to_string()),
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        };

        let events = mapper.map_event(response);

        let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            tool_use.clone()
        } else {
            panic!("Expected ToolUse event");
        };

        // Convert back to Google format
        let request = super::into_google(
            LanguageModelRequest {
                messages: vec![language_model::LanguageModelRequestMessage {
                    role: Role::Assistant,
                    content: vec![MessageContent::ToolUse(tool_use)],
                    cache: false,
                    reasoning_details: None,
                }],
                ..Default::default()
            },
            "gemini-2.5-flash".to_string(),
            GoogleModelMode::Default,
        );

        // Verify signature is preserved
        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
            assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
        } else {
            panic!("Expected FunctionCallPart");
        }
    }

    #[test]
    fn test_mixed_text_and_function_call_with_signature() {
        let mut mapper = GoogleEventMapper::new();

        let response = GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![
                        Part::TextPart(TextPart {
                            text: "I'll help with that.".to_string(),
                        }),
                        Part::FunctionCallPart(FunctionCallPart {
                            function_call: FunctionCall {
                                name: "helper_function".to_string(),
                                args: json!({"query": "help"}),
                            },
                            thought_signature: Some("mixed_sig".to_string()),
                        }),
                    ],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        };

        let events = mapper.map_event(response);

        assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event

        if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
            assert_eq!(text, "I'll help with that.");
        } else {
            panic!("Expected Text event");
        }

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
            assert_eq!(tool_use.name.as_ref(), "helper_function");
            assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
        } else {
            panic!("Expected ToolUse event");
        }
    }

    #[test]
    fn test_special_characters_in_signature_preserved() {
        let mut mapper = GoogleEventMapper::new();

        let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();

        let response = GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature: Some(signature_with_special_chars.clone()),
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        };

        let events = mapper.map_event(response);

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert_eq!(
                tool_use.thought_signature.as_deref(),
                Some(signature_with_special_chars.as_str())
            );
        } else {
            panic!("Expected ToolUse event");
        }
    }
}