google.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use collections::BTreeMap;
   3use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
   4use google_ai::{
   5    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
   6    ThinkingConfig, UsageMetadata,
   7};
   8use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
   9use http_client::HttpClient;
  10use language_model::{
  11    AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
  12    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
  13    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
  14};
  15use language_model::{
  16    IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  17    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  18    LanguageModelRequest, RateLimiter, Role,
  19};
  20use schemars::JsonSchema;
  21use serde::{Deserialize, Serialize};
  22pub use settings::GoogleAvailableModel as AvailableModel;
  23use settings::{Settings, SettingsStore};
  24use std::pin::Pin;
  25use std::sync::{
  26    Arc, LazyLock,
  27    atomic::{self, AtomicU64},
  28};
  29use strum::IntoEnumIterator;
  30use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
  31use ui_input::InputField;
  32use util::ResultExt;
  33
  34use language_model::ApiKeyState;
  35
/// Stable provider identifier shared with the rest of the app via `language_model`.
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
/// Human-readable provider name, also sourced from `language_model`.
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
  38
/// User-configurable settings for the Google AI provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    /// Base URL for the Google AI API. An empty string means "use the default"
    /// (`google_ai::API_URL`) — see `GoogleLanguageModelProvider::api_url`.
    pub api_url: String,
    /// Extra models declared in settings; merged over the built-in model list
    /// in `provided_models` (same-named entries override built-ins).
    pub available_models: Vec<AvailableModel>,
}
  44
/// Whether a model runs in plain generation mode or with "thinking" (reasoning)
/// enabled. Serialized with an internal `type` tag, lowercased
/// (e.g. `{"type": "thinking", "budget_tokens": 1024}`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Standard generation with no reasoning budget.
    #[default]
    Default,
    /// Reasoning-enabled generation.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  55
/// Language-model provider backed by the Google AI (Gemini) HTTP API.
pub struct GoogleLanguageModelProvider {
    /// Shared HTTP client used for all API requests.
    http_client: Arc<dyn HttpClient>,
    /// Provider-wide state entity (API-key handling); observed by the UI.
    state: Entity<State>,
}
  60
/// Provider-wide state: tracks the API key associated with the configured API URL.
pub struct State {
    // Handles loading/storing the key and falling back to environment variables.
    api_key_state: ApiKeyState,
}
  64
/// Primary environment variable consulted for the API key.
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
/// Alternate environment variable, used as a fallback.
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";

/// Environment-variable source for the API key, resolved lazily on first use.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
    // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
});
  72
  73impl State {
  74    fn is_authenticated(&self) -> bool {
  75        self.api_key_state.has_key()
  76    }
  77
  78    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
  79        let api_url = GoogleLanguageModelProvider::api_url(cx);
  80        self.api_key_state
  81            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
  82    }
  83
  84    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
  85        let api_url = GoogleLanguageModelProvider::api_url(cx);
  86        self.api_key_state
  87            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
  88    }
  89}
  90
  91impl GoogleLanguageModelProvider {
  92    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
  93        let state = cx.new(|cx| {
  94            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
  95                let api_url = Self::api_url(cx);
  96                this.api_key_state
  97                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
  98                cx.notify();
  99            })
 100            .detach();
 101            State {
 102                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
 103            }
 104        });
 105
 106        Self { http_client, state }
 107    }
 108
 109    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
 110        Arc::new(GoogleLanguageModel {
 111            id: LanguageModelId::from(model.id().to_string()),
 112            model,
 113            state: self.state.clone(),
 114            http_client: self.http_client.clone(),
 115            request_limiter: RateLimiter::new(4),
 116        })
 117    }
 118
 119    fn settings(cx: &App) -> &GoogleSettings {
 120        &crate::AllLanguageModelSettings::get_global(cx).google
 121    }
 122
 123    fn api_url(cx: &App) -> SharedString {
 124        let api_url = &Self::settings(cx).api_url;
 125        if api_url.is_empty() {
 126            google_ai::API_URL.into()
 127        } else {
 128            SharedString::new(api_url.as_str())
 129        }
 130    }
 131}
 132
impl LanguageModelProviderState for GoogleLanguageModelProvider {
    type ObservableEntity = State;

    /// Expose the state entity so observers can react to authentication changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 140
 141impl LanguageModelProvider for GoogleLanguageModelProvider {
 142    fn id(&self) -> LanguageModelProviderId {
 143        PROVIDER_ID
 144    }
 145
 146    fn name(&self) -> LanguageModelProviderName {
 147        PROVIDER_NAME
 148    }
 149
 150    fn icon(&self) -> IconOrSvg {
 151        IconOrSvg::Icon(IconName::AiGoogle)
 152    }
 153
 154    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 155        Some(self.create_language_model(google_ai::Model::default()))
 156    }
 157
 158    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 159        Some(self.create_language_model(google_ai::Model::default_fast()))
 160    }
 161
 162    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 163        let mut models = BTreeMap::default();
 164
 165        // Add base models from google_ai::Model::iter()
 166        for model in google_ai::Model::iter() {
 167            if !matches!(model, google_ai::Model::Custom { .. }) {
 168                models.insert(model.id().to_string(), model);
 169            }
 170        }
 171
 172        // Override with available models from settings
 173        for model in &GoogleLanguageModelProvider::settings(cx).available_models {
 174            models.insert(
 175                model.name.clone(),
 176                google_ai::Model::Custom {
 177                    name: model.name.clone(),
 178                    display_name: model.display_name.clone(),
 179                    max_tokens: model.max_tokens,
 180                    mode: model.mode.unwrap_or_default(),
 181                },
 182            );
 183        }
 184
 185        models
 186            .into_values()
 187            .map(|model| {
 188                Arc::new(GoogleLanguageModel {
 189                    id: LanguageModelId::from(model.id().to_string()),
 190                    model,
 191                    state: self.state.clone(),
 192                    http_client: self.http_client.clone(),
 193                    request_limiter: RateLimiter::new(4),
 194                }) as Arc<dyn LanguageModel>
 195            })
 196            .collect()
 197    }
 198
 199    fn is_authenticated(&self, cx: &App) -> bool {
 200        self.state.read(cx).is_authenticated()
 201    }
 202
 203    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 204        self.state.update(cx, |state, cx| state.authenticate(cx))
 205    }
 206
 207    fn configuration_view(
 208        &self,
 209        target_agent: language_model::ConfigurationViewTargetAgent,
 210        window: &mut Window,
 211        cx: &mut App,
 212    ) -> AnyView {
 213        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
 214            .into()
 215    }
 216
 217    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 218        self.state
 219            .update(cx, |state, cx| state.set_api_key(None, cx))
 220    }
 221}
 222
/// A single Google/Gemini model exposed through the `LanguageModel` trait.
pub struct GoogleLanguageModel {
    /// Stable id derived from the underlying model's id.
    id: LanguageModelId,
    /// The concrete Gemini model (built-in or custom from settings).
    model: google_ai::Model,
    /// Shared provider state, consulted for the API key at request time.
    state: Entity<State>,
    /// HTTP client used for completion and token-count requests.
    http_client: Arc<dyn HttpClient>,
    /// Caps concurrent in-flight requests (configured to 4 at construction).
    request_limiter: RateLimiter,
}
 230
impl GoogleLanguageModel {
    /// Kick off a streaming `generateContent` request against the Google API.
    ///
    /// The API key and URL are snapshotted from `State` before the async block
    /// so the returned future is `'static` and does not borrow the app.
    fn stream_completion(
        &self,
        request: google_ai::GenerateContentRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
    > {
        let http_client = self.http_client.clone();

        // `read_with` only fails if the app/state entity has been dropped.
        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        }) else {
            return future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move {
            // No key stored for this URL: surface a descriptive error.
            let api_key = api_key.context("Missing Google API key")?;
            let request = google_ai::stream_generate_content(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
            );
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}
 262
impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    // All three tool-choice modes map onto Google's function-calling modes
    // (see `into_google`), so every choice is supported.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    // Google accepts a subset of JSON Schema for tool parameters.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.request_id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Count tokens by calling Google's `countTokens` endpoint with the same
    /// request payload that would be sent for generation.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let model_id = self.model.request_id().to_string();
        let request = into_google(request, model_id, self.model.mode());
        let http_client = self.http_client.clone();
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        let api_key = self.state.read(cx).api_key_state.key(&api_url);

        async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                }
                .into());
            };
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    generate_content_request: request,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    /// Convert the request, stream the completion through the rate limiter, and
    /// map raw Google responses into `LanguageModelCompletionEvent`s.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_google(
            request,
            self.model.request_id().to_string(),
            self.model.mode(),
        );
        // Calls the inherent `GoogleLanguageModel::stream_completion` helper.
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await.map_err(LanguageModelCompletionError::from)?;
            Ok(GoogleEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
 371
/// Convert an app-level `LanguageModelRequest` into Google's
/// `GenerateContentRequest` wire format.
///
/// A leading `Role::System` message becomes the `system_instruction`; all other
/// messages are mapped to `contents`, with empty messages dropped. Tools and
/// tool choice are translated to Google's function-calling config, and a
/// thinking budget is set only when both the request allows thinking and the
/// model mode provides a budget.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
    // Translate one message's content items into Google `Part`s. Items that
    // have no wire representation (empty text, unsigned thinking, redacted
    // thinking) map to no parts at all.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                language_model::MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(google_ai::TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                // Thinking is only round-tripped via its signature; the text
                // itself is not sent back to the API.
                language_model::MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking { .. } => {
                    vec![]
                }
                language_model::MessageContent::RedactedThinking(_) => vec![],
                language_model::MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
                        inline_data: google_ai::GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                language_model::MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
                        function_call: google_ai::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                language_model::MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(
                                google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": text
                                        }),
                                    },
                                },
                            )]
                        }
                        language_model::LanguageModelToolResultContent::Image(image) => {
                            // Image results become a placeholder function
                            // response plus the image bytes as inline data.
                            vec![
                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(google_ai::InlineDataPart {
                                    inline_data: google_ai::GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // Peel off a leading system message into Google's dedicated
    // `system_instruction` field.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    google_ai::GenerateContentRequest {
        model: google_ai::ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(google_ai::Content {
                        parts,
                        role: match message.role {
                            Role::User => google_ai::Role::User,
                            Role::Assistant => google_ai::Role::Model,
                            Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(google_ai::GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Default temperature of 1.0 when the request doesn't specify one.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // A thinking budget is only sent when the request allows thinking
            // AND the model mode carries a budget.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        tools: (!request.tools.is_empty()).then(|| {
            vec![google_ai::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
            function_calling_config: google_ai::FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
 540
/// Stateful translator from raw `GenerateContentResponse` chunks to
/// `LanguageModelCompletionEvent`s, accumulating usage and the stop reason
/// across the stream.
pub struct GoogleEventMapper {
    /// Running usage totals, updated as chunks report `usage_metadata`.
    usage: UsageMetadata,
    /// Last stop reason seen; emitted when the stream ends.
    stop_reason: StopReason,
}
 545
impl GoogleEventMapper {
    /// Fresh mapper with zeroed usage and `EndTurn` as the default stop reason.
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Adapt a stream of raw responses into completion events. A sentinel
    /// `None` is chained after the source stream so a final `Stop` event
    /// carrying the accumulated stop reason is always emitted.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events
            .map(Some)
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

    /// Translate one response chunk into zero or more completion events,
    /// updating usage totals and the pending stop reason as a side effect.
    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        // Google doesn't supply tool-call ids here, so ids are synthesized from
        // the function name plus a process-wide counter.
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }

        // A block reason means the prompt was refused outright: emit a Stop
        // and skip candidate processing for this chunk.
        if let Some(prompt_feedback) = event.prompt_feedback
            && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
        {
            self.stop_reason = match block_reason {
                "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
                    StopReason::Refusal
                }
                _ => {
                    log::error!("Unexpected Google block_reason: {block_reason}");
                    StopReason::Refusal
                }
            };
            events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));

            return events;
        }

        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        _ => {
                            log::error!("Unexpected google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        // Inline data and function responses are never sent by
                        // the model in this direction; ignore them.
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();

                            // Normalize empty string signatures to None
                            let thought_signature = function_call_part
                                .thought_signature
                                .filter(|s| !s.is_empty());

                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                    thought_signature,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a Tool, the API
        // responds with `finish_reason: STOP`
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}
 670
 671pub fn count_google_tokens(
 672    request: LanguageModelRequest,
 673    cx: &App,
 674) -> BoxFuture<'static, Result<u64>> {
 675    // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
 676    // So we have to use tokenizer from tiktoken_rs to count tokens.
 677    cx.background_spawn(async move {
 678        let messages = request
 679            .messages
 680            .into_iter()
 681            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 682                role: match message.role {
 683                    Role::User => "user".into(),
 684                    Role::Assistant => "assistant".into(),
 685                    Role::System => "system".into(),
 686                },
 687                content: Some(message.string_contents()),
 688                name: None,
 689                function_call: None,
 690            })
 691            .collect::<Vec<_>>();
 692
 693        // Tiktoken doesn't support these models, so we manually use the
 694        // same tokenizer as GPT-4.
 695        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
 696    })
 697    .boxed()
 698}
 699
 700fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
 701    if let Some(prompt_token_count) = new.prompt_token_count {
 702        usage.prompt_token_count = Some(prompt_token_count);
 703    }
 704    if let Some(cached_content_token_count) = new.cached_content_token_count {
 705        usage.cached_content_token_count = Some(cached_content_token_count);
 706    }
 707    if let Some(candidates_token_count) = new.candidates_token_count {
 708        usage.candidates_token_count = Some(candidates_token_count);
 709    }
 710    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
 711        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
 712    }
 713    if let Some(thoughts_token_count) = new.thoughts_token_count {
 714        usage.thoughts_token_count = Some(thoughts_token_count);
 715    }
 716    if let Some(total_token_count) = new.total_token_count {
 717        usage.total_token_count = Some(total_token_count);
 718    }
 719}
 720
 721fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
 722    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
 723    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
 724    let input_tokens = prompt_tokens - cached_tokens;
 725    let output_tokens = usage.candidates_token_count.unwrap_or(0);
 726
 727    language_model::TokenUsage {
 728        input_tokens,
 729        output_tokens,
 730        cache_read_input_tokens: cached_tokens,
 731        cache_creation_input_tokens: 0,
 732    }
 733}
 734
/// UI for entering/resetting the Google API key in the provider settings.
struct ConfigurationView {
    /// Text input for the API key.
    api_key_editor: Entity<InputField>,
    /// Shared provider state the view reads and writes.
    state: Entity<State>,
    /// Which agent this configuration view was opened for.
    target_agent: language_model::ConfigurationViewTargetAgent,
    /// In-flight credential load; `Some` while loading, cleared when done.
    load_credentials_task: Option<Task<()>>,
}
 741
impl ConfigurationView {
    /// Build the view, re-rendering on state changes and kicking off an
    /// initial credential load in the background.
    fn new(
        state: Entity<State>,
        target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                // Clear the loading flag regardless of the auth outcome.
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
            target_agent,
            state,
            load_credentials_task,
        }
    }

    /// Store the key currently typed into the editor (no-op when empty).
    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        // url changes can cause the editor to be displayed again
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    /// Clear the editor and delete the stored key.
    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    /// The key-entry editor is shown only while unauthenticated.
    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}
 816
 817impl Render for ConfigurationView {
 818    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 819        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
 820        let configured_card_label = if env_var_set {
 821            format!(
 822                "API key set in {} environment variable",
 823                API_KEY_ENV_VAR.name
 824            )
 825        } else {
 826            let api_url = GoogleLanguageModelProvider::api_url(cx);
 827            if api_url == google_ai::API_URL {
 828                "API key configured".to_string()
 829            } else {
 830                format!("API key configured for {}", api_url)
 831            }
 832        };
 833
 834        if self.load_credentials_task.is_some() {
 835            div()
 836                .child(Label::new("Loading credentials..."))
 837                .into_any_element()
 838        } else if self.should_render_editor(cx) {
 839            v_flex()
 840                .size_full()
 841                .on_action(cx.listener(Self::save_api_key))
 842                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
 843                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
 844                    ConfigurationViewTargetAgent::EditPrediction => "Google AI for edit predictions".into(),
 845                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
 846                })))
 847                .child(
 848                    List::new()
 849                        .child(
 850                            ListBulletItem::new("")
 851                                .child(Label::new("Create one by visiting"))
 852                                .child(ButtonLink::new("Google AI's console", "https://aistudio.google.com/app/apikey"))
 853                        )
 854                        .child(
 855                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
 856                        )
 857                )
 858                .child(self.api_key_editor.clone())
 859                .child(
 860                    Label::new(
 861                        format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
 862                    )
 863                    .size(LabelSize::Small).color(Color::Muted),
 864                )
 865                .into_any_element()
 866        } else {
 867            ConfiguredApiCard::new(configured_card_label)
 868                .disabled(env_var_set)
 869                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
 870                .when(env_var_set, |this| {
 871                    this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
 872                })
 873                .into_any_element()
 874        }
 875    }
 876}
 877
 878#[cfg(test)]
 879mod tests {
 880    use super::*;
 881    use google_ai::{
 882        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
 883        Part, Role as GoogleRole, TextPart,
 884    };
 885    use language_model::{LanguageModelToolUseId, MessageContent, Role};
 886    use serde_json::json;
 887
 888    #[test]
 889    fn test_function_call_with_signature_creates_tool_use_with_signature() {
 890        let mut mapper = GoogleEventMapper::new();
 891
 892        let response = GenerateContentResponse {
 893            candidates: Some(vec![GenerateContentCandidate {
 894                index: Some(0),
 895                content: Content {
 896                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 897                        function_call: FunctionCall {
 898                            name: "test_function".to_string(),
 899                            args: json!({"arg": "value"}),
 900                        },
 901                        thought_signature: Some("test_signature_123".to_string()),
 902                    })],
 903                    role: GoogleRole::Model,
 904                },
 905                finish_reason: None,
 906                finish_message: None,
 907                safety_ratings: None,
 908                citation_metadata: None,
 909            }]),
 910            prompt_feedback: None,
 911            usage_metadata: None,
 912        };
 913
 914        let events = mapper.map_event(response);
 915
 916        assert_eq!(events.len(), 2); // ToolUse event + Stop event
 917
 918        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 919            assert_eq!(tool_use.name.as_ref(), "test_function");
 920            assert_eq!(
 921                tool_use.thought_signature.as_deref(),
 922                Some("test_signature_123")
 923            );
 924        } else {
 925            panic!("Expected ToolUse event");
 926        }
 927    }
 928
 929    #[test]
 930    fn test_function_call_without_signature_has_none() {
 931        let mut mapper = GoogleEventMapper::new();
 932
 933        let response = GenerateContentResponse {
 934            candidates: Some(vec![GenerateContentCandidate {
 935                index: Some(0),
 936                content: Content {
 937                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 938                        function_call: FunctionCall {
 939                            name: "test_function".to_string(),
 940                            args: json!({"arg": "value"}),
 941                        },
 942                        thought_signature: None,
 943                    })],
 944                    role: GoogleRole::Model,
 945                },
 946                finish_reason: None,
 947                finish_message: None,
 948                safety_ratings: None,
 949                citation_metadata: None,
 950            }]),
 951            prompt_feedback: None,
 952            usage_metadata: None,
 953        };
 954
 955        let events = mapper.map_event(response);
 956
 957        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 958            assert_eq!(tool_use.thought_signature, None);
 959        } else {
 960            panic!("Expected ToolUse event");
 961        }
 962    }
 963
 964    #[test]
 965    fn test_empty_string_signature_normalized_to_none() {
 966        let mut mapper = GoogleEventMapper::new();
 967
 968        let response = GenerateContentResponse {
 969            candidates: Some(vec![GenerateContentCandidate {
 970                index: Some(0),
 971                content: Content {
 972                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 973                        function_call: FunctionCall {
 974                            name: "test_function".to_string(),
 975                            args: json!({"arg": "value"}),
 976                        },
 977                        thought_signature: Some("".to_string()),
 978                    })],
 979                    role: GoogleRole::Model,
 980                },
 981                finish_reason: None,
 982                finish_message: None,
 983                safety_ratings: None,
 984                citation_metadata: None,
 985            }]),
 986            prompt_feedback: None,
 987            usage_metadata: None,
 988        };
 989
 990        let events = mapper.map_event(response);
 991
 992        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 993            assert_eq!(tool_use.thought_signature, None);
 994        } else {
 995            panic!("Expected ToolUse event");
 996        }
 997    }
 998
 999    #[test]
1000    fn test_parallel_function_calls_preserve_signatures() {
1001        let mut mapper = GoogleEventMapper::new();
1002
1003        let response = GenerateContentResponse {
1004            candidates: Some(vec![GenerateContentCandidate {
1005                index: Some(0),
1006                content: Content {
1007                    parts: vec![
1008                        Part::FunctionCallPart(FunctionCallPart {
1009                            function_call: FunctionCall {
1010                                name: "function_1".to_string(),
1011                                args: json!({"arg": "value1"}),
1012                            },
1013                            thought_signature: Some("signature_1".to_string()),
1014                        }),
1015                        Part::FunctionCallPart(FunctionCallPart {
1016                            function_call: FunctionCall {
1017                                name: "function_2".to_string(),
1018                                args: json!({"arg": "value2"}),
1019                            },
1020                            thought_signature: None,
1021                        }),
1022                    ],
1023                    role: GoogleRole::Model,
1024                },
1025                finish_reason: None,
1026                finish_message: None,
1027                safety_ratings: None,
1028                citation_metadata: None,
1029            }]),
1030            prompt_feedback: None,
1031            usage_metadata: None,
1032        };
1033
1034        let events = mapper.map_event(response);
1035
1036        assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
1037
1038        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1039            assert_eq!(tool_use.name.as_ref(), "function_1");
1040            assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
1041        } else {
1042            panic!("Expected ToolUse event for function_1");
1043        }
1044
1045        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1046            assert_eq!(tool_use.name.as_ref(), "function_2");
1047            assert_eq!(tool_use.thought_signature, None);
1048        } else {
1049            panic!("Expected ToolUse event for function_2");
1050        }
1051    }
1052
1053    #[test]
1054    fn test_tool_use_with_signature_converts_to_function_call_part() {
1055        let tool_use = language_model::LanguageModelToolUse {
1056            id: LanguageModelToolUseId::from("test_id"),
1057            name: "test_function".into(),
1058            raw_input: json!({"arg": "value"}).to_string(),
1059            input: json!({"arg": "value"}),
1060            is_input_complete: true,
1061            thought_signature: Some("test_signature_456".to_string()),
1062        };
1063
1064        let request = super::into_google(
1065            LanguageModelRequest {
1066                messages: vec![language_model::LanguageModelRequestMessage {
1067                    role: Role::Assistant,
1068                    content: vec![MessageContent::ToolUse(tool_use)],
1069                    cache: false,
1070                    reasoning_details: None,
1071                }],
1072                ..Default::default()
1073            },
1074            "gemini-2.5-flash".to_string(),
1075            GoogleModelMode::Default,
1076        );
1077
1078        assert_eq!(request.contents[0].parts.len(), 1);
1079        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1080            assert_eq!(fc_part.function_call.name, "test_function");
1081            assert_eq!(
1082                fc_part.thought_signature.as_deref(),
1083                Some("test_signature_456")
1084            );
1085        } else {
1086            panic!("Expected FunctionCallPart");
1087        }
1088    }
1089
1090    #[test]
1091    fn test_tool_use_without_signature_omits_field() {
1092        let tool_use = language_model::LanguageModelToolUse {
1093            id: LanguageModelToolUseId::from("test_id"),
1094            name: "test_function".into(),
1095            raw_input: json!({"arg": "value"}).to_string(),
1096            input: json!({"arg": "value"}),
1097            is_input_complete: true,
1098            thought_signature: None,
1099        };
1100
1101        let request = super::into_google(
1102            LanguageModelRequest {
1103                messages: vec![language_model::LanguageModelRequestMessage {
1104                    role: Role::Assistant,
1105                    content: vec![MessageContent::ToolUse(tool_use)],
1106                    cache: false,
1107                    reasoning_details: None,
1108                }],
1109                ..Default::default()
1110            },
1111            "gemini-2.5-flash".to_string(),
1112            GoogleModelMode::Default,
1113        );
1114
1115        assert_eq!(request.contents[0].parts.len(), 1);
1116        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1117            assert_eq!(fc_part.thought_signature, None);
1118        } else {
1119            panic!("Expected FunctionCallPart");
1120        }
1121    }
1122
1123    #[test]
1124    fn test_empty_signature_in_tool_use_normalized_to_none() {
1125        let tool_use = language_model::LanguageModelToolUse {
1126            id: LanguageModelToolUseId::from("test_id"),
1127            name: "test_function".into(),
1128            raw_input: json!({"arg": "value"}).to_string(),
1129            input: json!({"arg": "value"}),
1130            is_input_complete: true,
1131            thought_signature: Some("".to_string()),
1132        };
1133
1134        let request = super::into_google(
1135            LanguageModelRequest {
1136                messages: vec![language_model::LanguageModelRequestMessage {
1137                    role: Role::Assistant,
1138                    content: vec![MessageContent::ToolUse(tool_use)],
1139                    cache: false,
1140                    reasoning_details: None,
1141                }],
1142                ..Default::default()
1143            },
1144            "gemini-2.5-flash".to_string(),
1145            GoogleModelMode::Default,
1146        );
1147
1148        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1149            assert_eq!(fc_part.thought_signature, None);
1150        } else {
1151            panic!("Expected FunctionCallPart");
1152        }
1153    }
1154
1155    #[test]
1156    fn test_round_trip_preserves_signature() {
1157        let mut mapper = GoogleEventMapper::new();
1158
1159        // Simulate receiving a response from Google with a signature
1160        let response = GenerateContentResponse {
1161            candidates: Some(vec![GenerateContentCandidate {
1162                index: Some(0),
1163                content: Content {
1164                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
1165                        function_call: FunctionCall {
1166                            name: "test_function".to_string(),
1167                            args: json!({"arg": "value"}),
1168                        },
1169                        thought_signature: Some("round_trip_sig".to_string()),
1170                    })],
1171                    role: GoogleRole::Model,
1172                },
1173                finish_reason: None,
1174                finish_message: None,
1175                safety_ratings: None,
1176                citation_metadata: None,
1177            }]),
1178            prompt_feedback: None,
1179            usage_metadata: None,
1180        };
1181
1182        let events = mapper.map_event(response);
1183
1184        let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1185            tool_use.clone()
1186        } else {
1187            panic!("Expected ToolUse event");
1188        };
1189
1190        // Convert back to Google format
1191        let request = super::into_google(
1192            LanguageModelRequest {
1193                messages: vec![language_model::LanguageModelRequestMessage {
1194                    role: Role::Assistant,
1195                    content: vec![MessageContent::ToolUse(tool_use)],
1196                    cache: false,
1197                    reasoning_details: None,
1198                }],
1199                ..Default::default()
1200            },
1201            "gemini-2.5-flash".to_string(),
1202            GoogleModelMode::Default,
1203        );
1204
1205        // Verify signature is preserved
1206        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1207            assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
1208        } else {
1209            panic!("Expected FunctionCallPart");
1210        }
1211    }
1212
1213    #[test]
1214    fn test_mixed_text_and_function_call_with_signature() {
1215        let mut mapper = GoogleEventMapper::new();
1216
1217        let response = GenerateContentResponse {
1218            candidates: Some(vec![GenerateContentCandidate {
1219                index: Some(0),
1220                content: Content {
1221                    parts: vec![
1222                        Part::TextPart(TextPart {
1223                            text: "I'll help with that.".to_string(),
1224                        }),
1225                        Part::FunctionCallPart(FunctionCallPart {
1226                            function_call: FunctionCall {
1227                                name: "helper_function".to_string(),
1228                                args: json!({"query": "help"}),
1229                            },
1230                            thought_signature: Some("mixed_sig".to_string()),
1231                        }),
1232                    ],
1233                    role: GoogleRole::Model,
1234                },
1235                finish_reason: None,
1236                finish_message: None,
1237                safety_ratings: None,
1238                citation_metadata: None,
1239            }]),
1240            prompt_feedback: None,
1241            usage_metadata: None,
1242        };
1243
1244        let events = mapper.map_event(response);
1245
1246        assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
1247
1248        if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
1249            assert_eq!(text, "I'll help with that.");
1250        } else {
1251            panic!("Expected Text event");
1252        }
1253
1254        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1255            assert_eq!(tool_use.name.as_ref(), "helper_function");
1256            assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
1257        } else {
1258            panic!("Expected ToolUse event");
1259        }
1260    }
1261
1262    #[test]
1263    fn test_special_characters_in_signature_preserved() {
1264        let mut mapper = GoogleEventMapper::new();
1265
1266        let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
1267
1268        let response = GenerateContentResponse {
1269            candidates: Some(vec![GenerateContentCandidate {
1270                index: Some(0),
1271                content: Content {
1272                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
1273                        function_call: FunctionCall {
1274                            name: "test_function".to_string(),
1275                            args: json!({"arg": "value"}),
1276                        },
1277                        thought_signature: Some(signature_with_special_chars.clone()),
1278                    })],
1279                    role: GoogleRole::Model,
1280                },
1281                finish_reason: None,
1282                finish_message: None,
1283                safety_ratings: None,
1284                citation_metadata: None,
1285            }]),
1286            prompt_feedback: None,
1287            usage_metadata: None,
1288        };
1289
1290        let events = mapper.map_event(response);
1291
1292        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1293            assert_eq!(
1294                tool_use.thought_signature.as_deref(),
1295                Some(signature_with_special_chars.as_str())
1296            );
1297        } else {
1298            panic!("Expected ToolUse event");
1299        }
1300    }
1301}