google.rs

   1use anyhow::{Context as _, Result};
   2use collections::BTreeMap;
   3use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
   4use google_ai::{
   5    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
   6    ThinkingConfig, UsageMetadata,
   7};
   8use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
   9use http_client::HttpClient;
  10use language_model::{
  11    AuthenticateError, ConfigurationViewTargetAgent, EnvVar, LanguageModelCompletionError,
  12    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
  13    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
  14};
  15use language_model::{
  16    IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  17    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  18    LanguageModelRequest, RateLimiter, Role,
  19};
  20use schemars::JsonSchema;
  21use serde::{Deserialize, Serialize};
  22pub use settings::GoogleAvailableModel as AvailableModel;
  23use settings::{Settings, SettingsStore};
  24use std::pin::Pin;
  25use std::sync::{
  26    Arc, LazyLock,
  27    atomic::{self, AtomicU64},
  28};
  29use strum::IntoEnumIterator;
  30use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
  31use ui_input::InputField;
  32use util::ResultExt;
  33
  34use language_model::ApiKeyState;
  35
  36const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
  37const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
  38
/// Provider-level settings for Google AI: an optional custom API URL and any
/// user-declared models exposed in addition to the built-in ones.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    /// Base URL for the Google AI API; empty string means "use `google_ai::API_URL`".
    pub api_url: String,
    /// Extra models declared in settings, merged over the built-in model list.
    pub available_models: Vec<AvailableModel>,
}
  44
/// Serialized form of a model's reasoning mode, as configured in settings
/// (externally tagged as `{"type": "default"}` / `{"type": "thinking", ...}`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// Plain generation with no thinking budget.
    #[default]
    Default,
    /// Extended "thinking" mode with an optional reasoning budget.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
  55
/// Language-model provider backed by Google's (Gemini) API.
pub struct GoogleLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    /// Shared, observable provider state (currently API-key handling).
    state: Entity<State>,
}
  60
/// Observable provider state; tracks the API key scoped to the current API URL.
pub struct State {
    api_key_state: ApiKeyState,
}
  64
/// Primary environment variable consulted for the API key.
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
/// Alternate environment variable, used as a fallback.
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";

static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
    // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
});
  72
impl State {
    /// Whether an API key is currently available.
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    /// Stores (or clears, when `api_key` is `None`) the key for the current
    /// API URL. Returns a task that resolves when the store completes.
    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        self.api_key_state
            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
    }

    /// Loads the API key for the current URL if it hasn't been loaded yet.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        self.api_key_state
            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
    }
}
  90
impl GoogleLanguageModelProvider {
    /// Creates the provider and its observable `State`, re-resolving the API
    /// key whenever global settings (and thus possibly the API URL) change.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                // A settings change may alter the API URL; the key is scoped
                // to the URL, so let the key state react to the change.
                let api_url = Self::api_url(cx);
                this.api_key_state
                    .handle_url_change(api_url, |this| &mut this.api_key_state, cx);
                cx.notify();
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
            }
        });

        Self { http_client, state }
    }

    /// Wraps a `google_ai::Model` in the `LanguageModel` implementation,
    /// sharing this provider's HTTP client and state.
    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(GoogleLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

    /// Google-specific slice of the global language-model settings.
    fn settings(cx: &App) -> &GoogleSettings {
        &crate::AllLanguageModelSettings::get_global(cx).google
    }

    /// Effective API base URL: the configured one, or the library default
    /// when the setting is empty.
    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            google_ai::API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }
}
 132
impl LanguageModelProviderState for GoogleLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes `State` so observers can react to authentication changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 140
 141impl LanguageModelProvider for GoogleLanguageModelProvider {
 142    fn id(&self) -> LanguageModelProviderId {
 143        PROVIDER_ID
 144    }
 145
 146    fn name(&self) -> LanguageModelProviderName {
 147        PROVIDER_NAME
 148    }
 149
 150    fn icon(&self) -> IconOrSvg {
 151        IconOrSvg::Icon(IconName::AiGoogle)
 152    }
 153
 154    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 155        Some(self.create_language_model(google_ai::Model::default()))
 156    }
 157
 158    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 159        Some(self.create_language_model(google_ai::Model::default_fast()))
 160    }
 161
 162    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 163        let mut models = BTreeMap::default();
 164
 165        // Add base models from google_ai::Model::iter()
 166        for model in google_ai::Model::iter() {
 167            if !matches!(model, google_ai::Model::Custom { .. }) {
 168                models.insert(model.id().to_string(), model);
 169            }
 170        }
 171
 172        // Override with available models from settings
 173        for model in &GoogleLanguageModelProvider::settings(cx).available_models {
 174            models.insert(
 175                model.name.clone(),
 176                google_ai::Model::Custom {
 177                    name: model.name.clone(),
 178                    display_name: model.display_name.clone(),
 179                    max_tokens: model.max_tokens,
 180                    mode: model.mode.unwrap_or_default(),
 181                },
 182            );
 183        }
 184
 185        models
 186            .into_values()
 187            .map(|model| {
 188                Arc::new(GoogleLanguageModel {
 189                    id: LanguageModelId::from(model.id().to_string()),
 190                    model,
 191                    state: self.state.clone(),
 192                    http_client: self.http_client.clone(),
 193                    request_limiter: RateLimiter::new(4),
 194                }) as Arc<dyn LanguageModel>
 195            })
 196            .collect()
 197    }
 198
 199    fn is_authenticated(&self, cx: &App) -> bool {
 200        self.state.read(cx).is_authenticated()
 201    }
 202
 203    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 204        self.state.update(cx, |state, cx| state.authenticate(cx))
 205    }
 206
 207    fn configuration_view(
 208        &self,
 209        target_agent: language_model::ConfigurationViewTargetAgent,
 210        window: &mut Window,
 211        cx: &mut App,
 212    ) -> AnyView {
 213        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
 214            .into()
 215    }
 216
 217    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 218        self.state
 219            .update(cx, |state, cx| state.set_api_key(None, cx))
 220    }
 221}
 222
/// A single Google (Gemini) model exposed through the `LanguageModel` trait.
pub struct GoogleLanguageModel {
    id: LanguageModelId,
    model: google_ai::Model,
    /// Shared provider state, consulted per request for the API key.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    /// Limits concurrent in-flight requests (constructed with a limit of 4).
    request_limiter: RateLimiter,
}
 230
impl GoogleLanguageModel {
    /// Starts a streaming `generateContent` request against the Google API.
    ///
    /// Reads the API key and URL synchronously from provider state, then
    /// returns a future resolving to the raw response event stream. Fails
    /// with an error if no API key is configured.
    fn stream_completion(
        &self,
        request: google_ai::GenerateContentRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
    > {
        let http_client = self.http_client.clone();

        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        async move {
            let api_key = api_key.context("Missing Google API key")?;
            let request = google_ai::stream_generate_content(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
            );
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}
 260
impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    /// All three tool-choice modes map onto Google's function-calling config.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    /// Tool parameter schemas are restricted to the JSON Schema subset
    /// accepted by the Google API.
    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.request_id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Counts tokens by sending the converted request to Google's
    /// `countTokens` endpoint; errors if no API key is configured.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let model_id = self.model.request_id().to_string();
        let request = into_google(request, model_id, self.model.mode());
        let http_client = self.http_client.clone();
        let api_url = GoogleLanguageModelProvider::api_url(cx);
        let api_key = self.state.read(cx).api_key_state.key(&api_url);

        async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                }
                .into());
            };
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    generate_content_request: request,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    /// Streams a completion: converts the request into Google's wire format,
    /// runs it through the rate limiter, and maps raw API responses to
    /// completion events via `GoogleEventMapper`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_google(
            request,
            self.model.request_id().to_string(),
            self.model.mode(),
        );
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await.map_err(LanguageModelCompletionError::from)?;
            Ok(GoogleEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
 369
/// Converts a provider-agnostic `LanguageModelRequest` into Google's
/// `GenerateContentRequest` wire format.
///
/// A leading `System` message becomes `system_instruction`; the remaining
/// messages map to `contents` (assistant -> "model" role, everything else ->
/// "user"). Messages whose content maps to no parts are dropped. Tools and
/// tool-choice are translated, and a thinking budget is applied only when
/// the request allows thinking and the model mode supplies a budget.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
    // Flattens message content into Google `Part`s, dropping anything with
    // no wire representation (empty text, thinking without a signature,
    // redacted thinking).
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                language_model::MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(google_ai::TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                // Signed thinking round-trips as a thought part; the text
                // itself is not sent back.
                language_model::MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking { .. } => {
                    vec![]
                }
                language_model::MessageContent::RedactedThinking(_) => vec![],
                language_model::MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
                        inline_data: google_ai::GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                language_model::MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
                        function_call: google_ai::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                language_model::MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(
                                google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": text
                                        }),
                                    },
                                },
                            )]
                        }
                        // Image results become a placeholder function response
                        // plus the image itself as inline data.
                        language_model::LanguageModelToolResultContent::Image(image) => {
                            vec![
                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(google_ai::InlineDataPart {
                                    inline_data: google_ai::GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // Peel off a leading System message into the dedicated field.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    google_ai::GenerateContentRequest {
        model: google_ai::ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(google_ai::Content {
                        parts,
                        role: match message.role {
                            Role::User => google_ai::Role::User,
                            Role::Assistant => google_ai::Role::Model,
                            Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(google_ai::GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Default temperature is 1.0 when the request leaves it unset.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        tools: (!request.tools.is_empty()).then(|| {
            vec![google_ai::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
            function_calling_config: google_ai::FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
 538
/// Accumulates per-stream state (cumulative usage and the eventual stop
/// reason) while converting Google API responses into completion events.
pub struct GoogleEventMapper {
    usage: UsageMetadata,
    stop_reason: StopReason,
}
 543
impl GoogleEventMapper {
    /// Fresh mapper with zeroed usage and a default `EndTurn` stop reason.
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Adapts a raw Google response stream into completion events, emitting
    /// a final `Stop(stop_reason)` event after the underlying stream ends.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events
            .map(Some)
            // `None` sentinel marks end-of-stream so the final Stop event
            // can be emitted with the last-known stop reason.
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

    /// Maps one API response chunk to zero or more completion events,
    /// updating cumulative usage and the tracked stop reason along the way.
    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        // Process-wide counter used to synthesize unique tool-use ids;
        // Google function calls don't carry ids of their own.
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }

        // A prompt-level block ends the turn immediately as a refusal.
        if let Some(prompt_feedback) = event.prompt_feedback
            && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
        {
            self.stop_reason = match block_reason {
                "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
                    StopReason::Refusal
                }
                _ => {
                    // Unknown reason: log it but still treat as a refusal.
                    log::error!("Unexpected Google block_reason: {block_reason}");
                    StopReason::Refusal
                }
            };
            events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));

            return events;
        }

        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        _ => {
                            log::error!("Unexpected google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        // Inline data and function responses in the model's
                        // output have no completion-event mapping.
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();

                            // Normalize empty string signatures to None
                            let thought_signature = function_call_part
                                .thought_signature
                                .filter(|s| !s.is_empty());

                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                    thought_signature,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a Tool, the API
        // responds with `finish_reason: STOP`
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}
 668
/// Estimates token usage for a request locally, without calling Google's API.
///
/// We couldn't use the GoogleLanguageModelProvider to count tokens because
/// GitHub Copilot doesn't have access to google_ai directly, so we use the
/// tokenizer from tiktoken_rs instead; the result is an approximation.
pub fn count_google_tokens(
    request: LanguageModelRequest,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    // Tokenization is CPU-bound, so run it off the main thread.
    cx.background_spawn(async move {
        let messages = request
            .messages
            .into_iter()
            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                role: match message.role {
                    Role::User => "user".into(),
                    Role::Assistant => "assistant".into(),
                    Role::System => "system".into(),
                },
                content: Some(message.string_contents()),
                name: None,
                function_call: None,
            })
            .collect::<Vec<_>>();

        // Tiktoken doesn't yet support these models, so we manually use the
        // same tokenizer as GPT-4.
        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
    })
    .boxed()
}
 697
 698fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
 699    if let Some(prompt_token_count) = new.prompt_token_count {
 700        usage.prompt_token_count = Some(prompt_token_count);
 701    }
 702    if let Some(cached_content_token_count) = new.cached_content_token_count {
 703        usage.cached_content_token_count = Some(cached_content_token_count);
 704    }
 705    if let Some(candidates_token_count) = new.candidates_token_count {
 706        usage.candidates_token_count = Some(candidates_token_count);
 707    }
 708    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
 709        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
 710    }
 711    if let Some(thoughts_token_count) = new.thoughts_token_count {
 712        usage.thoughts_token_count = Some(thoughts_token_count);
 713    }
 714    if let Some(total_token_count) = new.total_token_count {
 715        usage.total_token_count = Some(total_token_count);
 716    }
 717}
 718
 719fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
 720    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
 721    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
 722    let input_tokens = prompt_tokens - cached_tokens;
 723    let output_tokens = usage.candidates_token_count.unwrap_or(0);
 724
 725    language_model::TokenUsage {
 726        input_tokens,
 727        output_tokens,
 728        cache_read_input_tokens: cached_tokens,
 729        cache_creation_input_tokens: 0,
 730    }
 731}
 732
/// Settings-panel view for entering or resetting the Google API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    target_agent: language_model::ConfigurationViewTargetAgent,
    /// `Some` while the initial credential load is in flight; cleared once done.
    load_credentials_task: Option<Task<()>>,
}
 739
 740impl ConfigurationView {
 741    fn new(
 742        state: Entity<State>,
 743        target_agent: language_model::ConfigurationViewTargetAgent,
 744        window: &mut Window,
 745        cx: &mut Context<Self>,
 746    ) -> Self {
 747        cx.observe(&state, |_, _, cx| {
 748            cx.notify();
 749        })
 750        .detach();
 751
 752        let load_credentials_task = Some(cx.spawn_in(window, {
 753            let state = state.clone();
 754            async move |this, cx| {
 755                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
 756                    // We don't log an error, because "not signed in" is also an error.
 757                    let _ = task.await;
 758                }
 759                this.update(cx, |this, cx| {
 760                    this.load_credentials_task = None;
 761                    cx.notify();
 762                })
 763                .log_err();
 764            }
 765        }));
 766
 767        Self {
 768            api_key_editor: cx.new(|cx| InputField::new(window, cx, "AIzaSy...")),
 769            target_agent,
 770            state,
 771            load_credentials_task,
 772        }
 773    }
 774
 775    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 776        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 777        if api_key.is_empty() {
 778            return;
 779        }
 780
 781        // url changes can cause the editor to be displayed again
 782        self.api_key_editor
 783            .update(cx, |editor, cx| editor.set_text("", window, cx));
 784
 785        let state = self.state.clone();
 786        cx.spawn_in(window, async move |_, cx| {
 787            state
 788                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
 789                .await
 790        })
 791        .detach_and_log_err(cx);
 792    }
 793
 794    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 795        self.api_key_editor
 796            .update(cx, |editor, cx| editor.set_text("", window, cx));
 797
 798        let state = self.state.clone();
 799        cx.spawn_in(window, async move |_, cx| {
 800            state
 801                .update(cx, |state, cx| state.set_api_key(None, cx))
 802                .await
 803        })
 804        .detach_and_log_err(cx);
 805    }
 806
 807    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
 808        !self.state.read(cx).is_authenticated()
 809    }
 810}
 811
impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // Whether the key came from an environment variable; if so the reset
        // button below is disabled (the UI can't unset an env var).
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        // Label for the "already configured" card; mentions the custom API
        // URL when one differs from the default Google endpoint.
        let configured_card_label = if env_var_set {
            format!(
                "API key set in {} environment variable",
                API_KEY_ENV_VAR.name
            )
        } else {
            let api_url = GoogleLanguageModelProvider::api_url(cx);
            if api_url == google_ai::API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        // Three mutually exclusive states: credentials still loading, no key
        // yet (show editor + instructions), or key configured (show card).
        if self.load_credentials_task.is_some() {
            div()
                .child(Label::new("Loading credentials..."))
                .into_any_element()
        } else if self.should_render_editor(cx) {
            // Not authenticated: instructions, key editor, and the env-var hint.
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
                })))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create one by visiting"))
                                .child(ButtonLink::new("Google AI's console", "https://aistudio.google.com/app/apikey"))
                        )
                        .child(
                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
                        )
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also set the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any_element()
        } else {
            // Authenticated: card with a reset action, disabled (with a
            // tooltip explaining why) when the key comes from an env var.
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset."))
                })
                .into_any_element()
        }
    }
}
 871
 872#[cfg(test)]
 873mod tests {
 874    use super::*;
 875    use google_ai::{
 876        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
 877        Part, Role as GoogleRole, TextPart,
 878    };
 879    use language_model::{LanguageModelToolUseId, MessageContent, Role};
 880    use serde_json::json;
 881
 882    #[test]
 883    fn test_function_call_with_signature_creates_tool_use_with_signature() {
 884        let mut mapper = GoogleEventMapper::new();
 885
 886        let response = GenerateContentResponse {
 887            candidates: Some(vec![GenerateContentCandidate {
 888                index: Some(0),
 889                content: Content {
 890                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 891                        function_call: FunctionCall {
 892                            name: "test_function".to_string(),
 893                            args: json!({"arg": "value"}),
 894                        },
 895                        thought_signature: Some("test_signature_123".to_string()),
 896                    })],
 897                    role: GoogleRole::Model,
 898                },
 899                finish_reason: None,
 900                finish_message: None,
 901                safety_ratings: None,
 902                citation_metadata: None,
 903            }]),
 904            prompt_feedback: None,
 905            usage_metadata: None,
 906        };
 907
 908        let events = mapper.map_event(response);
 909
 910        assert_eq!(events.len(), 2); // ToolUse event + Stop event
 911
 912        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 913            assert_eq!(tool_use.name.as_ref(), "test_function");
 914            assert_eq!(
 915                tool_use.thought_signature.as_deref(),
 916                Some("test_signature_123")
 917            );
 918        } else {
 919            panic!("Expected ToolUse event");
 920        }
 921    }
 922
 923    #[test]
 924    fn test_function_call_without_signature_has_none() {
 925        let mut mapper = GoogleEventMapper::new();
 926
 927        let response = GenerateContentResponse {
 928            candidates: Some(vec![GenerateContentCandidate {
 929                index: Some(0),
 930                content: Content {
 931                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 932                        function_call: FunctionCall {
 933                            name: "test_function".to_string(),
 934                            args: json!({"arg": "value"}),
 935                        },
 936                        thought_signature: None,
 937                    })],
 938                    role: GoogleRole::Model,
 939                },
 940                finish_reason: None,
 941                finish_message: None,
 942                safety_ratings: None,
 943                citation_metadata: None,
 944            }]),
 945            prompt_feedback: None,
 946            usage_metadata: None,
 947        };
 948
 949        let events = mapper.map_event(response);
 950
 951        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 952            assert_eq!(tool_use.thought_signature, None);
 953        } else {
 954            panic!("Expected ToolUse event");
 955        }
 956    }
 957
 958    #[test]
 959    fn test_empty_string_signature_normalized_to_none() {
 960        let mut mapper = GoogleEventMapper::new();
 961
 962        let response = GenerateContentResponse {
 963            candidates: Some(vec![GenerateContentCandidate {
 964                index: Some(0),
 965                content: Content {
 966                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
 967                        function_call: FunctionCall {
 968                            name: "test_function".to_string(),
 969                            args: json!({"arg": "value"}),
 970                        },
 971                        thought_signature: Some("".to_string()),
 972                    })],
 973                    role: GoogleRole::Model,
 974                },
 975                finish_reason: None,
 976                finish_message: None,
 977                safety_ratings: None,
 978                citation_metadata: None,
 979            }]),
 980            prompt_feedback: None,
 981            usage_metadata: None,
 982        };
 983
 984        let events = mapper.map_event(response);
 985
 986        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
 987            assert_eq!(tool_use.thought_signature, None);
 988        } else {
 989            panic!("Expected ToolUse event");
 990        }
 991    }
 992
 993    #[test]
 994    fn test_parallel_function_calls_preserve_signatures() {
 995        let mut mapper = GoogleEventMapper::new();
 996
 997        let response = GenerateContentResponse {
 998            candidates: Some(vec![GenerateContentCandidate {
 999                index: Some(0),
1000                content: Content {
1001                    parts: vec![
1002                        Part::FunctionCallPart(FunctionCallPart {
1003                            function_call: FunctionCall {
1004                                name: "function_1".to_string(),
1005                                args: json!({"arg": "value1"}),
1006                            },
1007                            thought_signature: Some("signature_1".to_string()),
1008                        }),
1009                        Part::FunctionCallPart(FunctionCallPart {
1010                            function_call: FunctionCall {
1011                                name: "function_2".to_string(),
1012                                args: json!({"arg": "value2"}),
1013                            },
1014                            thought_signature: None,
1015                        }),
1016                    ],
1017                    role: GoogleRole::Model,
1018                },
1019                finish_reason: None,
1020                finish_message: None,
1021                safety_ratings: None,
1022                citation_metadata: None,
1023            }]),
1024            prompt_feedback: None,
1025            usage_metadata: None,
1026        };
1027
1028        let events = mapper.map_event(response);
1029
1030        assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event
1031
1032        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1033            assert_eq!(tool_use.name.as_ref(), "function_1");
1034            assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1"));
1035        } else {
1036            panic!("Expected ToolUse event for function_1");
1037        }
1038
1039        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1040            assert_eq!(tool_use.name.as_ref(), "function_2");
1041            assert_eq!(tool_use.thought_signature, None);
1042        } else {
1043            panic!("Expected ToolUse event for function_2");
1044        }
1045    }
1046
1047    #[test]
1048    fn test_tool_use_with_signature_converts_to_function_call_part() {
1049        let tool_use = language_model::LanguageModelToolUse {
1050            id: LanguageModelToolUseId::from("test_id"),
1051            name: "test_function".into(),
1052            raw_input: json!({"arg": "value"}).to_string(),
1053            input: json!({"arg": "value"}),
1054            is_input_complete: true,
1055            thought_signature: Some("test_signature_456".to_string()),
1056        };
1057
1058        let request = super::into_google(
1059            LanguageModelRequest {
1060                messages: vec![language_model::LanguageModelRequestMessage {
1061                    role: Role::Assistant,
1062                    content: vec![MessageContent::ToolUse(tool_use)],
1063                    cache: false,
1064                    reasoning_details: None,
1065                }],
1066                ..Default::default()
1067            },
1068            "gemini-2.5-flash".to_string(),
1069            GoogleModelMode::Default,
1070        );
1071
1072        assert_eq!(request.contents[0].parts.len(), 1);
1073        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1074            assert_eq!(fc_part.function_call.name, "test_function");
1075            assert_eq!(
1076                fc_part.thought_signature.as_deref(),
1077                Some("test_signature_456")
1078            );
1079        } else {
1080            panic!("Expected FunctionCallPart");
1081        }
1082    }
1083
1084    #[test]
1085    fn test_tool_use_without_signature_omits_field() {
1086        let tool_use = language_model::LanguageModelToolUse {
1087            id: LanguageModelToolUseId::from("test_id"),
1088            name: "test_function".into(),
1089            raw_input: json!({"arg": "value"}).to_string(),
1090            input: json!({"arg": "value"}),
1091            is_input_complete: true,
1092            thought_signature: None,
1093        };
1094
1095        let request = super::into_google(
1096            LanguageModelRequest {
1097                messages: vec![language_model::LanguageModelRequestMessage {
1098                    role: Role::Assistant,
1099                    content: vec![MessageContent::ToolUse(tool_use)],
1100                    cache: false,
1101                    reasoning_details: None,
1102                }],
1103                ..Default::default()
1104            },
1105            "gemini-2.5-flash".to_string(),
1106            GoogleModelMode::Default,
1107        );
1108
1109        assert_eq!(request.contents[0].parts.len(), 1);
1110        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1111            assert_eq!(fc_part.thought_signature, None);
1112        } else {
1113            panic!("Expected FunctionCallPart");
1114        }
1115    }
1116
1117    #[test]
1118    fn test_empty_signature_in_tool_use_normalized_to_none() {
1119        let tool_use = language_model::LanguageModelToolUse {
1120            id: LanguageModelToolUseId::from("test_id"),
1121            name: "test_function".into(),
1122            raw_input: json!({"arg": "value"}).to_string(),
1123            input: json!({"arg": "value"}),
1124            is_input_complete: true,
1125            thought_signature: Some("".to_string()),
1126        };
1127
1128        let request = super::into_google(
1129            LanguageModelRequest {
1130                messages: vec![language_model::LanguageModelRequestMessage {
1131                    role: Role::Assistant,
1132                    content: vec![MessageContent::ToolUse(tool_use)],
1133                    cache: false,
1134                    reasoning_details: None,
1135                }],
1136                ..Default::default()
1137            },
1138            "gemini-2.5-flash".to_string(),
1139            GoogleModelMode::Default,
1140        );
1141
1142        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1143            assert_eq!(fc_part.thought_signature, None);
1144        } else {
1145            panic!("Expected FunctionCallPart");
1146        }
1147    }
1148
1149    #[test]
1150    fn test_round_trip_preserves_signature() {
1151        let mut mapper = GoogleEventMapper::new();
1152
1153        // Simulate receiving a response from Google with a signature
1154        let response = GenerateContentResponse {
1155            candidates: Some(vec![GenerateContentCandidate {
1156                index: Some(0),
1157                content: Content {
1158                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
1159                        function_call: FunctionCall {
1160                            name: "test_function".to_string(),
1161                            args: json!({"arg": "value"}),
1162                        },
1163                        thought_signature: Some("round_trip_sig".to_string()),
1164                    })],
1165                    role: GoogleRole::Model,
1166                },
1167                finish_reason: None,
1168                finish_message: None,
1169                safety_ratings: None,
1170                citation_metadata: None,
1171            }]),
1172            prompt_feedback: None,
1173            usage_metadata: None,
1174        };
1175
1176        let events = mapper.map_event(response);
1177
1178        let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1179            tool_use.clone()
1180        } else {
1181            panic!("Expected ToolUse event");
1182        };
1183
1184        // Convert back to Google format
1185        let request = super::into_google(
1186            LanguageModelRequest {
1187                messages: vec![language_model::LanguageModelRequestMessage {
1188                    role: Role::Assistant,
1189                    content: vec![MessageContent::ToolUse(tool_use)],
1190                    cache: false,
1191                    reasoning_details: None,
1192                }],
1193                ..Default::default()
1194            },
1195            "gemini-2.5-flash".to_string(),
1196            GoogleModelMode::Default,
1197        );
1198
1199        // Verify signature is preserved
1200        if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] {
1201            assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig"));
1202        } else {
1203            panic!("Expected FunctionCallPart");
1204        }
1205    }
1206
1207    #[test]
1208    fn test_mixed_text_and_function_call_with_signature() {
1209        let mut mapper = GoogleEventMapper::new();
1210
1211        let response = GenerateContentResponse {
1212            candidates: Some(vec![GenerateContentCandidate {
1213                index: Some(0),
1214                content: Content {
1215                    parts: vec![
1216                        Part::TextPart(TextPart {
1217                            text: "I'll help with that.".to_string(),
1218                        }),
1219                        Part::FunctionCallPart(FunctionCallPart {
1220                            function_call: FunctionCall {
1221                                name: "helper_function".to_string(),
1222                                args: json!({"query": "help"}),
1223                            },
1224                            thought_signature: Some("mixed_sig".to_string()),
1225                        }),
1226                    ],
1227                    role: GoogleRole::Model,
1228                },
1229                finish_reason: None,
1230                finish_message: None,
1231                safety_ratings: None,
1232                citation_metadata: None,
1233            }]),
1234            prompt_feedback: None,
1235            usage_metadata: None,
1236        };
1237
1238        let events = mapper.map_event(response);
1239
1240        assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event
1241
1242        if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] {
1243            assert_eq!(text, "I'll help with that.");
1244        } else {
1245            panic!("Expected Text event");
1246        }
1247
1248        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] {
1249            assert_eq!(tool_use.name.as_ref(), "helper_function");
1250            assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig"));
1251        } else {
1252            panic!("Expected ToolUse event");
1253        }
1254    }
1255
1256    #[test]
1257    fn test_special_characters_in_signature_preserved() {
1258        let mut mapper = GoogleEventMapper::new();
1259
1260        let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string();
1261
1262        let response = GenerateContentResponse {
1263            candidates: Some(vec![GenerateContentCandidate {
1264                index: Some(0),
1265                content: Content {
1266                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
1267                        function_call: FunctionCall {
1268                            name: "test_function".to_string(),
1269                            args: json!({"arg": "value"}),
1270                        },
1271                        thought_signature: Some(signature_with_special_chars.clone()),
1272                    })],
1273                    role: GoogleRole::Model,
1274                },
1275                finish_reason: None,
1276                finish_message: None,
1277                safety_ratings: None,
1278                citation_metadata: None,
1279            }]),
1280            prompt_feedback: None,
1281            usage_metadata: None,
1282        };
1283
1284        let events = mapper.map_event(response);
1285
1286        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
1287            assert_eq!(
1288                tool_use.thought_signature.as_deref(),
1289                Some(signature_with_special_chars.as_str())
1290            );
1291        } else {
1292            panic!("Expected ToolUse event");
1293        }
1294    }
1295}