// google.rs

   1use anyhow::{Context as _, Result};
   2use collections::BTreeMap;
   3use credentials_provider::CredentialsProvider;
   4use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
   5use google_ai::{
   6    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
   7    ThinkingConfig, UsageMetadata,
   8};
   9use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
  10use http_client::HttpClient;
  11use language_model::{
  12    AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
  13    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
  14    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
  15};
  16use language_model::{
  17    GOOGLE_PROVIDER_ID, GOOGLE_PROVIDER_NAME, IconOrSvg, LanguageModel, LanguageModelId,
  18    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
  19    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
  20};
  21use schemars::JsonSchema;
  22use serde::{Deserialize, Serialize};
  23pub use settings::GoogleAvailableModel as AvailableModel;
  24use settings::{Settings, SettingsStore};
  25use std::pin::Pin;
  26use std::sync::{
  27    Arc, LazyLock,
  28    atomic::{self, AtomicU64},
  29};
  30use strum::IntoEnumIterator;
  31use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
  32use ui_input::InputField;
  33use util::ResultExt;
  34
  35use language_model::ApiKeyState;
  36
/// Canonical provider id, shared with the `language_model` crate.
const PROVIDER_ID: LanguageModelProviderId = GOOGLE_PROVIDER_ID;
/// Human-readable provider name, shared with the `language_model` crate.
const PROVIDER_NAME: LanguageModelProviderName = GOOGLE_PROVIDER_NAME;
  39
/// Provider-level settings for Google AI, read from the user's settings file.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    /// Base API URL; an empty string means "use the default `google_ai::API_URL`".
    pub api_url: String,
    /// Extra models declared in settings; merged over the built-in model list.
    pub available_models: Vec<AvailableModel>,
}
  45
/// Whether a model runs in plain or "thinking" (extended reasoning) mode.
///
/// Serialized with an internal `"type"` tag, lowercase (e.g. `{"type": "thinking"}`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  56
/// Language-model provider backed by the Google AI (Gemini) HTTP API.
pub struct GoogleLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared, observable authentication state (API key) for this provider.
    state: Entity<State>,
}
  61
/// Observable authentication state: the API key plus where it is persisted.
pub struct State {
    // Tracks the key for the current API URL (from env var or stored credential).
    api_key_state: ApiKeyState,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
  66
// Environment variables that may hold the API key, in order of precedence.
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";

static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
    // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
});
  74
impl State {
    /// True when an API key is available (from an env var or stored credentials).
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    /// Stores (or, with `None`, clears) the API key for the current API URL,
    /// persisting the change through the credentials provider.
    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = self.credentials_provider.clone();
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        self.api_key_state.store(
            api_url,
            api_key,
            |this| &mut this.api_key_state,
            credentials_provider,
            cx,
        )
    }

    /// Loads the API key for the current API URL if it hasn't been loaded yet.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let credentials_provider = self.credentials_provider.clone();
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            |this| &mut this.api_key_state,
            credentials_provider,
            cx,
        )
    }
}
 103
impl GoogleLanguageModelProvider {
    /// Creates the provider and its shared auth state.
    ///
    /// Settings changes are observed so that an API-URL edit re-resolves which
    /// stored credential / environment variable applies to the new URL.
    pub fn new(
        http_client: Arc<dyn HttpClient>,
        credentials_provider: Arc<dyn CredentialsProvider>,
        cx: &mut App,
    ) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let credentials_provider = this.credentials_provider.clone();
                let api_url = Self::api_url(cx);
                this.api_key_state.handle_url_change(
                    api_url,
                    |this| &mut this.api_key_state,
                    credentials_provider,
                    cx,
                );
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
                credentials_provider,
            }
        });

        Self { http_client, state }
    }

    /// Wraps a `google_ai::Model` in this provider's `LanguageModel` implementation.
    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(GoogleLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            // Cap concurrent in-flight requests per model at 4.
            request_limiter: RateLimiter::new(4),
        })
    }

    /// The Google-specific section of the global language-model settings.
    fn settings(cx: &App) -> &GoogleSettings {
        &crate::AllLanguageModelSettings::get_global(cx).google
    }

    /// Effective API base URL: the configured one, or the default when unset.
    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            google_ai::API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}
 155
impl LanguageModelProviderState for GoogleLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the auth state so observers can re-render when it changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 163
 164impl LanguageModelProvider for GoogleLanguageModelProvider {
 165    fn id(&self) -> LanguageModelProviderId {
 166        PROVIDER_ID
 167    }
 168
 169    fn name(&self) -> LanguageModelProviderName {
 170        PROVIDER_NAME
 171    }
 172
 173    fn icon(&self) -> IconOrSvg {
 174        IconOrSvg::Icon(IconName::AiGoogle)
 175    }
 176
 177    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 178        Some(self.create_language_model(google_ai::Model::default()))
 179    }
 180
 181    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 182        Some(self.create_language_model(google_ai::Model::default_fast()))
 183    }
 184
 185    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 186        let mut models = BTreeMap::default();
 187
 188        // Add base models from google_ai::Model::iter()
 189        for model in google_ai::Model::iter() {
 190            if !matches!(model, google_ai::Model::Custom { .. }) {
 191                models.insert(model.id().to_string(), model);
 192            }
 193        }
 194
 195        // Override with available models from settings
 196        for model in &GoogleLanguageModelProvider::settings(cx).available_models {
 197            models.insert(
 198                model.name.clone(),
 199                google_ai::Model::Custom {
 200                    name: model.name.clone(),
 201                    display_name: model.display_name.clone(),
 202                    max_tokens: model.max_tokens,
 203                    mode: model.mode.unwrap_or_default(),
 204                },
 205            );
 206        }
 207
 208        models
 209            .into_values()
 210            .map(|model| {
 211                Arc::new(GoogleLanguageModel {
 212                    id: LanguageModelId::from(model.id().to_string()),
 213                    model,
 214                    state: self.state.clone(),
 215                    http_client: self.http_client.clone(),
 216                    request_limiter: RateLimiter::new(4),
 217                }) as Arc<dyn LanguageModel>
 218            })
 219            .collect()
 220    }
 221
 222    fn is_authenticated(&self, cx: &App) -> bool {
 223        self.state.read(cx).is_authenticated()
 224    }
 225
 226    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 227        self.state.update(cx, |state, cx| state.authenticate(cx))
 228    }
 229
 230    fn configuration_view(
 231        &self,
 232        target_agent: language_model::ConfigurationViewTargetAgent,
 233        window: &mut Window,
 234        cx: &mut App,
 235    ) -> AnyView {
 236        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
 237            .into()
 238    }
 239
 240    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 241        self.state
 242            .update(cx, |state, cx| state.set_api_key(None, cx))
 243    }
 244}
 245
/// A single Google model exposed through the `LanguageModel` trait.
pub struct GoogleLanguageModel {
    id: LanguageModelId,
    model: google_ai::Model,
    // Shared auth state (API key / URL) owned by the provider.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Limits concurrent in-flight requests for this model.
    request_limiter: RateLimiter,
}
 253
impl GoogleLanguageModel {
    /// Starts a streaming `generateContent` request against the Google API.
    ///
    /// The API key and URL are captured synchronously from the shared state;
    /// the returned future fails with "Missing Google API key" when no key is
    /// available for the current URL.
    fn stream_completion(
        &self,
        request: google_ai::GenerateContentRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
    > {
        let http_client = self.http_client.clone();

        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        async move {
            let api_key = api_key.context("Missing Google API key")?;
            let request = google_ai::stream_generate_content(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
            );
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}
 283
impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    /// Thinking support is derived from the model's configured mode.
    fn supports_thinking(&self) -> bool {
        matches!(self.model.mode(), GoogleModelMode::Thinking { .. })
    }

    /// All tool-choice variants are supported (each maps to a Google
    /// function-calling mode in `into_google`).
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    /// Tool schemas are restricted to the JSON Schema subset Google accepts.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.request_id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Counts tokens via Google's `countTokens` endpoint, using the fully
    /// converted request so the count matches what generation will see.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let model_id = self.model.request_id().to_string();
        let request = into_google(request, model_id, self.model.mode());
        let http_client = self.http_client.clone();
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        let api_key = self.state.read(cx).api_key_state.key(&api_url);

        async move {
            // Fail fast with a typed "no key" error rather than a network error.
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                }
                .into());
            };
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    generate_content_request: request,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    /// Streams a completion: converts the request, issues it through the
    /// per-model rate limiter, and maps Google chunks to completion events.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_google(
            request,
            self.model.request_id().to_string(),
            self.model.mode(),
        );
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await.map_err(LanguageModelCompletionError::from)?;
            Ok(GoogleEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
 396
/// Converts Zed's provider-agnostic request into Google's `generateContent`
/// request shape.
///
/// A leading `System` message becomes the `system_instruction`; the remaining
/// messages become `contents` (messages that map to no parts are dropped).
/// Tools, tool choice, and thinking configuration are translated to their
/// Google equivalents.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
    /// Flattens message content into Google `Part`s, dropping empty pieces.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                language_model::MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(google_ai::TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                // Thinking content is echoed back only as its signature; the
                // thought text itself is not re-sent to the API.
                language_model::MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                // Thinking without a signature can't be round-tripped; drop it.
                language_model::MessageContent::Thinking { .. } => {
                    vec![]
                }
                language_model::MessageContent::RedactedThinking(_) => vec![],
                language_model::MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
                        inline_data: google_ai::GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                language_model::MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
                        function_call: google_ai::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                language_model::MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(
                                google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": text
                                        }),
                                    },
                                },
                            )]
                        }
                        // Image tool results become a placeholder function
                        // response plus the image as inline data.
                        language_model::LanguageModelToolResultContent::Image(image) => {
                            vec![
                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(google_ai::InlineDataPart {
                                    inline_data: google_ai::GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // Peel off a leading System message into Google's dedicated
    // `system_instruction` field.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    google_ai::GenerateContentRequest {
        model: google_ai::ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(google_ai::Content {
                        parts,
                        role: match message.role {
                            Role::User => google_ai::Role::User,
                            Role::Assistant => google_ai::Role::Model,
                            Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(google_ai::GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Default to temperature 1.0 when the request doesn't specify one.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // Thinking config is sent only when the request allows thinking,
            // the model is in Thinking mode, and a budget was provided.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        // Omit the tools field entirely when no tools were requested.
        tools: (!request.tools.is_empty()).then(|| {
            vec![google_ai::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
            function_calling_config: google_ai::FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
 565
/// Maps streaming `GenerateContentResponse` chunks to completion events.
pub struct GoogleEventMapper {
    // Running usage totals, merged across chunks via `update_usage`.
    usage: UsageMetadata,
    // Most recently observed stop reason; emitted when the stream ends.
    stop_reason: StopReason,
}
 570
impl GoogleEventMapper {
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Adapts a raw response stream into a stream of completion events,
    /// appending a final `Stop` event (with the last observed stop reason)
    /// after the underlying stream ends.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events
            .map(Some)
            // The trailing `None` marks end-of-stream so we can emit Stop.
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

    /// Converts one response chunk into zero or more completion events,
    /// updating the running usage totals and stop reason as a side effect.
    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        // Google doesn't assign tool-call ids, so synthesize process-unique ones.
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }

        // A blocked prompt terminates the turn immediately as a refusal.
        if let Some(prompt_feedback) = event.prompt_feedback
            && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
        {
            self.stop_reason = match block_reason {
                "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
                    StopReason::Refusal
                }
                _ => {
                    log::error!("Unexpected Google block_reason: {block_reason}");
                    StopReason::Refusal
                }
            };
            events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));

            return events;
        }

        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        _ => {
                            log::error!("Unexpected google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();

                            // Normalize empty string signatures to None
                            let thought_signature = function_call_part
                                .thought_signature
                                .filter(|s| !s.is_empty());

                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                    thought_signature,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a Tool, the API
        // responds with `finish_reason: STOP`
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}
 695
 696pub fn count_google_tokens(
 697    request: LanguageModelRequest,
 698    cx: &App,
 699) -> BoxFuture<'static, Result<u64>> {
 700    // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
 701    // So we have to use tokenizer from tiktoken_rs to count tokens.
 702    cx.background_spawn(async move {
 703        let messages = request
 704            .messages
 705            .into_iter()
 706            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 707                role: match message.role {
 708                    Role::User => "user".into(),
 709                    Role::Assistant => "assistant".into(),
 710                    Role::System => "system".into(),
 711                },
 712                content: Some(message.string_contents()),
 713                name: None,
 714                function_call: None,
 715            })
 716            .collect::<Vec<_>>();
 717
 718        // Tiktoken doesn't yet support these models, so we manually use the
 719        // same tokenizer as GPT-4.
 720        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
 721    })
 722    .boxed()
 723}
 724
 725fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
 726    if let Some(prompt_token_count) = new.prompt_token_count {
 727        usage.prompt_token_count = Some(prompt_token_count);
 728    }
 729    if let Some(cached_content_token_count) = new.cached_content_token_count {
 730        usage.cached_content_token_count = Some(cached_content_token_count);
 731    }
 732    if let Some(candidates_token_count) = new.candidates_token_count {
 733        usage.candidates_token_count = Some(candidates_token_count);
 734    }
 735    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
 736        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
 737    }
 738    if let Some(thoughts_token_count) = new.thoughts_token_count {
 739        usage.thoughts_token_count = Some(thoughts_token_count);
 740    }
 741    if let Some(total_token_count) = new.total_token_count {
 742        usage.total_token_count = Some(total_token_count);
 743    }
 744}
 745
 746fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
 747    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
 748    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
 749    let input_tokens = prompt_tokens - cached_tokens;
 750    let output_tokens = usage.candidates_token_count.unwrap_or(0);
 751
 752    language_model::TokenUsage {
 753        input_tokens,
 754        output_tokens,
 755        cache_read_input_tokens: cached_tokens,
 756        cache_creation_input_tokens: 0,
 757    }
 758}
 759
/// UI state for entering / resetting the Google API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    target_agent: language_model::ConfigurationViewTargetAgent,
    // Some(...) while the initial credential load is still in flight.
    load_credentials_task: Option<Task<()>>,
}
 766
 767impl ConfigurationView {
 768    fn new(
 769        state: Entity<State>,
 770        target_agent: language_model::ConfigurationViewTargetAgent,
 771        window: &mut Window,
 772        cx: &mut Context<Self>,
 773    ) -> Self {
 774        cx.observe(&state, |_, _, cx| {
 775            cx.notify();
 776        })
 777        .detach();
 778
 779        let load_credentials_task = Some(cx.spawn_in(window, {
 780            let state = state.clone();
 781            async move |this, cx| {
 782                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
 783                    // We don't log an error, because "not signed in" is also an error.
 784                    let _ = task.await;
 785                }
 786                this.update(cx, |this, cx| {
 787                    this.load_credentials_task = None;
 788                    cx.notify();
 789                })
 790                .log_err();
 791            }
 792        }));
 793
 794        Self {
 795            api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
 796            target_agent,
 797            state,
 798            load_credentials_task,
 799        }
 800    }
 801
 802    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 803        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 804        if api_key.is_empty() {
 805            return;
 806        }
 807
 808        // url changes can cause the editor to be displayed again
 809        self.api_key_editor
 810            .update(cx, |editor, cx| editor.set_text("", window, cx));
 811
 812        let state = self.state.clone();
 813        cx.spawn_in(window, async move |_, cx| {
 814            state
 815                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
 816                .await
 817        })
 818        .detach_and_log_err(cx);
 819    }
 820
 821    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 822        self.api_key_editor
 823            .update(cx, |editor, cx| editor.set_text("", window, cx));
 824
 825        let state = self.state.clone();
 826        cx.spawn_in(window, async move |_, cx| {
 827            state
 828                .update(cx, |state, cx| state.set_api_key(None, cx))
 829                .await
 830        })
 831        .detach_and_log_err(cx);
 832    }
 833
 834    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
 835        !self.state.read(cx).is_authenticated()
 836    }
 837}
 838
impl Render for ConfigurationView {
    /// Renders one of three states: a loading placeholder while credentials
    /// are being checked, the API-key entry form when unauthenticated, or a
    /// "configured" card once a key is present.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // When the key comes from an environment variable, resetting it from
        // the UI is disabled (the env var would just re-apply it).
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!(
                "API key set in {} environment variable",
                API_KEY_ENV_VAR.name
            )
        } else {
            // Mention the endpoint only when a custom (non-default) URL is set.
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            if api_url == google_ai::API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            // Credentials still loading: show a placeholder.
            div()
                .child(Label::new("Loading credentials..."))
                .into_any_element()
        } else if self.should_render_editor(cx) {
            // Not authenticated: show setup instructions and the key editor.
            // `menu::Confirm` (Enter) in this subtree triggers `save_api_key`.
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
                })))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create one by visiting"))
                                .child(ButtonLink::new("Google AI's console", "https://aistudio.google.com/app/apikey"))
                        )
                        .child(
                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
                        )
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also set the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any_element()
        } else {
            // Authenticated: show the configured card; the click handler
            // resets the key unless it came from the environment.
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
                })
                .into_any_element()
        }
    }
}
 898
 899#[cfg(test)]
 900mod tests {
 901    use super::*;
 902    use google_ai::{
 903        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
 904        Part, Role as GoogleRole, TextPart,
 905    };
 906    use language_model::{LanguageModelToolUseId, MessageContent, Role};
 907    use serde_json::json;
 908
 909    #[test]
 910    fn test_function_call_with_signature_creates_tool_use_with_signature() {
 911        let mut mapper = GoogleEventMapper::new();
 912
 913        let response = GenerateContentResponse {
 914            candidates: Some(vec![GenerateContentCandidate {
 915                index: Some(0),
 916                content: Content {
 917                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 918                        function_call: FunctionCall {
 919                            name: "test_function".to_string(),
 920                            args: json!({"arg": "value"}),
 921                        },
 922                        thought_signature: Some("test_signature_123".to_string()),
 923                    })],
 924                    role: GoogleRole::Model,
 925                },
 926                finish_reason: None,
 927                finish_message: None,
 928                safety_ratings: None,
 929                citation_metadata: None,
 930            }]),
 931            prompt_feedback: None,
 932            usage_metadata: None,
 933        };
 934
 935        let events = mapper.map_event(response);
 936
 937        assert_eq!(events.len(), 2); // ToolUse event + Stop event
 938
 939        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 940            assert_eq!(tool_use.name.as_ref(), "test_function");
 941            assert_eq!(
 942                tool_use.thought_signature.as_deref(),
 943                Some("test_signature_123")
 944            );
 945        } else {
 946            panic!("Expected ToolUse event");
 947        }
 948    }
 949
 950    #[test]
 951    fn test_function_call_without_signature_has_none() {
 952        let mut mapper = GoogleEventMapper::new();
 953
 954        let response = GenerateContentResponse {
 955            candidates: Some(vec![GenerateContentCandidate {
 956                index: Some(0),
 957                content: Content {
 958                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 959                        function_call: FunctionCall {
 960                            name: "test_function".to_string(),
 961                            args: json!({"arg": "value"}),
 962                        },
 963                        thought_signature: None,
 964                    })],
 965                    role: GoogleRole::Model,
 966                },
 967                finish_reason: None,
 968                finish_message: None,
 969                safety_ratings: None,
 970                citation_metadata: None,
 971            }]),
 972            prompt_feedback: None,
 973            usage_metadata: None,
 974        };
 975
 976        let events = mapper.map_event(response);
 977
 978        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 979            assert_eq!(tool_use.thought_signature, None);
 980        } else {
 981            panic!("Expected ToolUse event");
 982        }
 983    }
 984
 985    #[test]
 986    fn test_empty_string_signature_normalized_to_none() {
 987        let mut mapper = GoogleEventMapper::new();
 988
 989        let response = GenerateContentResponse {
 990            candidates: Some(vec![GenerateContentCandidate {
 991                index: Some(0),
 992                content: Content {
 993                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 994                        function_call: FunctionCall {
 995                            name: "test_function".to_string(),
 996                            args: json!({"arg": "value"}),
 997                        },
 998                        thought_signature: Some("".to_string()),
 999                    })],
1000                    role: GoogleRole::Model,
1001                },
1002                finish_reason: None,
1003                finish_message: None,
1004                safety_ratings: None,
1005                citation_metadata: None,
1006            }]),
1007            prompt_feedback: None,
1008            usage_metadata: None,
1009        };
1010
1011        let events = mapper.map_event(response);
1012
1013        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1014            assert_eq!(tool_use.thought_signature, None);
1015        } else {
1016            panic!("Expected ToolUse event");
1017        }
1018    }
1019
1020    #[test]
1021    fn test_parallel_function_calls_preserve_signatures() {
1022        let mut mapper = GoogleEventMapper::new();
1023
1024        let response = GenerateContentResponse {
1025            candidates: Some(vec![GenerateContentCandidate {
1026                index: Some(0),
1027                content: Content {
1028                    parts: vec![
1029                        Part::FunctionCallPart(FunctionCallPart {
1030                            function_call: FunctionCall {
1031                                name: "function_1".to_string(),
1032                                args: json!({"arg": "value1"}),
1033                            },
1034                            thought_signature: Some("signature_1".to_string()),
1035                        }),
1036                        Part::FunctionCallPart(FunctionCallPart {
1037                            function_call: FunctionCall {
1038                                name: "function_2".to_string(),
1039                                args: json!({"arg": "value2"}),
1040                            },
1041                            thought_signature: None,
1042                        }),
1043                    ],
1044                    role: GoogleRole::Model,
1045                },
1046                finish_reason: None,
1047                finish_message: None,
1048                safety_ratings: None,
1049                citation_metadata: None,
1050            }]),
1051            prompt_feedback: None,
1052            usage_metadata: None,
1053        };
1054
1055        let events = mapper.map_event(response);
1056
1057        assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
1058
1059        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1060            assert_eq!(tool_use.name.as_ref(), "function_1");
1061            assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
1062        } else {
1063            panic!("Expected ToolUse event for function_1");
1064        }
1065
1066        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1067            assert_eq!(tool_use.name.as_ref(), "function_2");
1068            assert_eq!(tool_use.thought_signature, None);
1069        } else {
1070            panic!("Expected ToolUse event for function_2");
1071        }
1072    }
1073
1074    #[test]
1075    fn test_tool_use_with_signature_converts_to_function_call_part() {
1076        let tool_use = language_model::LanguageModelToolUse {
1077            id: LanguageModelToolUseId::from("test_id"),
1078            name: "test_function".into(),
1079            raw_input: json!({"arg": "value"}).to_string(),
1080            input: json!({"arg": "value"}),
1081            is_input_complete: true,
1082            thought_signature: Some("test_signature_456".to_string()),
1083        };
1084
1085        let request = super::into_google(
1086            LanguageModelRequest {
1087                messages: vec![language_model::LanguageModelRequestMessage {
1088                    role: Role::Assistant,
1089                    content: vec![MessageContent::ToolUse(tool_use)],
1090                    cache: false,
1091                    reasoning_details: None,
1092                }],
1093                ..Default::default()
1094            },
1095            "gemini-2.5-flash".to_string(),
1096            GoogleModelMode::Default,
1097        );
1098
1099        assert_eq!(request.contents[0].parts.len(), 1);
1100        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1101            assert_eq!(fc_part.function_call.name, "test_function");
1102            assert_eq!(
1103                fc_part.thought_signature.as_deref(),
1104                Some("test_signature_456")
1105            );
1106        } else {
1107            panic!("Expected FunctionCallPart");
1108        }
1109    }
1110
1111    #[test]
1112    fn test_tool_use_without_signature_omits_field() {
1113        let tool_use = language_model::LanguageModelToolUse {
1114            id: LanguageModelToolUseId::from("test_id"),
1115            name: "test_function".into(),
1116            raw_input: json!({"arg": "value"}).to_string(),
1117            input: json!({"arg": "value"}),
1118            is_input_complete: true,
1119            thought_signature: None,
1120        };
1121
1122        let request = super::into_google(
1123            LanguageModelRequest {
1124                messages: vec![language_model::LanguageModelRequestMessage {
1125                    role: Role::Assistant,
1126                    content: vec![MessageContent::ToolUse(tool_use)],
1127                    cache: false,
1128                    reasoning_details: None,
1129                }],
1130                ..Default::default()
1131            },
1132            "gemini-2.5-flash".to_string(),
1133            GoogleModelMode::Default,
1134        );
1135
1136        assert_eq!(request.contents[0].parts.len(), 1);
1137        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1138            assert_eq!(fc_part.thought_signature, None);
1139        } else {
1140            panic!("Expected FunctionCallPart");
1141        }
1142    }
1143
1144    #[test]
1145    fn test_empty_signature_in_tool_use_normalized_to_none() {
1146        let tool_use = language_model::LanguageModelToolUse {
1147            id: LanguageModelToolUseId::from("test_id"),
1148            name: "test_function".into(),
1149            raw_input: json!({"arg": "value"}).to_string(),
1150            input: json!({"arg": "value"}),
1151            is_input_complete: true,
1152            thought_signature: Some("".to_string()),
1153        };
1154
1155        let request = super::into_google(
1156            LanguageModelRequest {
1157                messages: vec![language_model::LanguageModelRequestMessage {
1158                    role: Role::Assistant,
1159                    content: vec![MessageContent::ToolUse(tool_use)],
1160                    cache: false,
1161                    reasoning_details: None,
1162                }],
1163                ..Default::default()
1164            },
1165            "gemini-2.5-flash".to_string(),
1166            GoogleModelMode::Default,
1167        );
1168
1169        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1170            assert_eq!(fc_part.thought_signature, None);
1171        } else {
1172            panic!("Expected FunctionCallPart");
1173        }
1174    }
1175
1176    #[test]
1177    fn test_round_trip_preserves_signature() {
1178        let mut mapper = GoogleEventMapper::new();
1179
1180        // Simulate receiving a response from Google with a signature
1181        let response = GenerateContentResponse {
1182            candidates: Some(vec![GenerateContentCandidate {
1183                index: Some(0),
1184                content: Content {
1185                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
1186                        function_call: FunctionCall {
1187                            name: "test_function".to_string(),
1188                            args: json!({"arg": "value"}),
1189                        },
1190                        thought_signature: Some("round_trip_sig".to_string()),
1191                    })],
1192                    role: GoogleRole::Model,
1193                },
1194                finish_reason: None,
1195                finish_message: None,
1196                safety_ratings: None,
1197                citation_metadata: None,
1198            }]),
1199            prompt_feedback: None,
1200            usage_metadata: None,
1201        };
1202
1203        let events = mapper.map_event(response);
1204
1205        let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1206            tool_use.clone()
1207        } else {
1208            panic!("Expected ToolUse event");
1209        };
1210
1211        // Convert back to Google format
1212        let request = super::into_google(
1213            LanguageModelRequest {
1214                messages: vec![language_model::LanguageModelRequestMessage {
1215                    role: Role::Assistant,
1216                    content: vec![MessageContent::ToolUse(tool_use)],
1217                    cache: false,
1218                    reasoning_details: None,
1219                }],
1220                ..Default::default()
1221            },
1222            "gemini-2.5-flash".to_string(),
1223            GoogleModelMode::Default,
1224        );
1225
1226        // Verify signature is preserved
1227        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1228            assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
1229        } else {
1230            panic!("Expected FunctionCallPart");
1231        }
1232    }
1233
1234    #[test]
1235    fn test_mixed_text_and_function_call_with_signature() {
1236        let mut mapper = GoogleEventMapper::new();
1237
1238        let response = GenerateContentResponse {
1239            candidates: Some(vec![GenerateContentCandidate {
1240                index: Some(0),
1241                content: Content {
1242                    parts: vec![
1243                        Part::TextPart(TextPart {
1244                            text: "I'll help with that.".to_string(),
1245                        }),
1246                        Part::FunctionCallPart(FunctionCallPart {
1247                            function_call: FunctionCall {
1248                                name: "helper_function".to_string(),
1249                                args: json!({"query": "help"}),
1250                            },
1251                            thought_signature: Some("mixed_sig".to_string()),
1252                        }),
1253                    ],
1254                    role: GoogleRole::Model,
1255                },
1256                finish_reason: None,
1257                finish_message: None,
1258                safety_ratings: None,
1259                citation_metadata: None,
1260            }]),
1261            prompt_feedback: None,
1262            usage_metadata: None,
1263        };
1264
1265        let events = mapper.map_event(response);
1266
1267        assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
1268
1269        if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
1270            assert_eq!(text, "I'll help with that.");
1271        } else {
1272            panic!("Expected Text event");
1273        }
1274
1275        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1276            assert_eq!(tool_use.name.as_ref(), "helper_function");
1277            assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
1278        } else {
1279            panic!("Expected ToolUse event");
1280        }
1281    }
1282
1283    #[test]
1284    fn test_special_characters_in_signature_preserved() {
1285        let mut mapper = GoogleEventMapper::new();
1286
1287        let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
1288
1289        let response = GenerateContentResponse {
1290            candidates: Some(vec![GenerateContentCandidate {
1291                index: Some(0),
1292                content: Content {
1293                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
1294                        function_call: FunctionCall {
1295                            name: "test_function".to_string(),
1296                            args: json!({"arg": "value"}),
1297                        },
1298                        thought_signature: Some(signature_with_special_chars.clone()),
1299                    })],
1300                    role: GoogleRole::Model,
1301                },
1302                finish_reason: None,
1303                finish_message: None,
1304                safety_ratings: None,
1305                citation_metadata: None,
1306            }]),
1307            prompt_feedback: None,
1308            usage_metadata: None,
1309        };
1310
1311        let events = mapper.map_event(response);
1312
1313        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1314            assert_eq!(
1315                tool_use.thought_signature.as_deref(),
1316                Some(signature_with_special_chars.as_str())
1317            );
1318        } else {
1319            panic!("Expected ToolUse event");
1320        }
1321    }
1322}