google.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use collections::BTreeMap;
  3use credentials_provider::CredentialsProvider;
  4use editor::{Editor, EditorElement, EditorStyle};
  5use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
  6use google_ai::{
  7    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
  8    ThinkingConfig, UsageMetadata,
  9};
 10use gpui::{
 11    AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
 12    Window,
 13};
 14use http_client::HttpClient;
 15use language_model::{
 16    AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
 17    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
 18    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
 19};
 20use language_model::{
 21    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
 22    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
 23    LanguageModelRequest, RateLimiter, Role,
 24};
 25use schemars::JsonSchema;
 26use serde::{Deserialize, Serialize};
 27use settings::{Settings, SettingsStore};
 28use std::pin::Pin;
 29use std::sync::{
 30    Arc, LazyLock,
 31    atomic::{self, AtomicU64},
 32};
 33use strum::IntoEnumIterator;
 34use theme::ThemeSettings;
 35use ui::{Icon, IconName, List, Tooltip, prelude::*};
 36use util::{ResultExt, truncate_and_trailoff};
 37use zed_env_vars::EnvVar;
 38
 39use crate::api_key::ApiKey;
 40use crate::api_key::ApiKeyState;
 41use crate::ui::InstructionListItem;
 42
 43const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
 44const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
 45
/// Settings for the Google language-model provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    /// Base URL for API requests. When this is empty,
    /// `GoogleLanguageModelProvider::api_url` falls back to `google_ai::API_URL`.
    pub api_url: String,
    /// Models declared in settings. `provided_models` registers each one as a
    /// `google_ai::Model::Custom` keyed by its name, so an entry whose name equals
    /// the id of a built-in model replaces that built-in.
    pub available_models: Vec<AvailableModel>,
}
 51
/// Settings-side counterpart of `GoogleModelMode`; the two convert into each
/// other variant-for-variant.
///
/// Serialized with an internal `type` tag and lowercase variant names.
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    /// The default mode, used when a configured model does not specify one.
    #[default]
    Default,
    /// Thinking mode with an optional token budget.
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}
 62
 63impl From<ModelMode> for GoogleModelMode {
 64    fn from(value: ModelMode) -> Self {
 65        match value {
 66            ModelMode::Default => GoogleModelMode::Default,
 67            ModelMode::Thinking { budget_tokens } => GoogleModelMode::Thinking { budget_tokens },
 68        }
 69    }
 70}
 71
 72impl From<GoogleModelMode> for ModelMode {
 73    fn from(value: GoogleModelMode) -> Self {
 74        match value {
 75            GoogleModelMode::Default => ModelMode::Default,
 76            GoogleModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
 77        }
 78    }
 79}
 80
/// A model declared in settings; `provided_models` turns each entry into a
/// `google_ai::Model::Custom`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Model name; also the key under which the model is registered.
    name: String,
    /// Optional display name forwarded to the custom model.
    display_name: Option<String>,
    /// Token limit forwarded to the custom model as `max_tokens`.
    max_tokens: u64,
    /// Optional mode; `ModelMode::Default` is used when absent.
    mode: Option<ModelMode>,
}
 88
/// Language-model provider for Google AI models.
pub struct GoogleLanguageModelProvider {
    /// HTTP client handed to every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Shared entity holding the API-key state; cloned into every created model.
    state: gpui::Entity<State>,
}
 93
/// Observable provider state: the API key for the configured API URL.
pub struct State {
    /// API-key bookkeeping; `is_authenticated` reports whether it holds a key.
    api_key_state: ApiKeyState,
}
 97
/// Name of the environment variable consulted first for the API key.
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
/// Name of the fallback environment variable for the API key.
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";

/// Lazily-initialized environment-variable source for the API key.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = LazyLock::new(|| {
    // Try GEMINI_API_KEY first as primary, fallback to GOOGLE_AI_API_KEY
    EnvVar::new(GEMINI_API_KEY_VAR_NAME.into()).or(EnvVar::new(GOOGLE_AI_API_KEY_VAR_NAME.into()))
});
105
106impl State {
107    fn is_authenticated(&self) -> bool {
108        self.api_key_state.has_key()
109    }
110
111    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
112        let api_url = GoogleLanguageModelProvider::api_url(cx);
113        self.api_key_state
114            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
115    }
116
117    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
118        let api_url = GoogleLanguageModelProvider::api_url(cx);
119        self.api_key_state.load_if_needed(
120            api_url,
121            &API_KEY_ENV_VAR,
122            |this| &mut this.api_key_state,
123            cx,
124        )
125    }
126}
127
128impl GoogleLanguageModelProvider {
129    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
130        let state = cx.new(|cx| {
131            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
132                let api_url = Self::api_url(cx);
133                this.api_key_state.handle_url_change(
134                    api_url,
135                    &API_KEY_ENV_VAR,
136                    |this| &mut this.api_key_state,
137                    cx,
138                );
139                cx.notify();
140            })
141            .detach();
142            State {
143                api_key_state: ApiKeyState::new(Self::api_url(cx)),
144            }
145        });
146
147        Self { http_client, state }
148    }
149
150    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
151        Arc::new(GoogleLanguageModel {
152            id: LanguageModelId::from(model.id().to_string()),
153            model,
154            state: self.state.clone(),
155            http_client: self.http_client.clone(),
156            request_limiter: RateLimiter::new(4),
157        })
158    }
159
160    pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
161        if let Some(key) = API_KEY_ENV_VAR.value.clone() {
162            return Task::ready(Ok(key));
163        }
164        let credentials_provider = <dyn CredentialsProvider>::global(cx);
165        let api_url = Self::api_url(cx).to_string();
166        cx.spawn(async move |cx| {
167            Ok(
168                ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
169                    .await?
170                    .key()
171                    .to_string(),
172            )
173        })
174    }
175
176    fn settings(cx: &App) -> &GoogleSettings {
177        &crate::AllLanguageModelSettings::get_global(cx).google
178    }
179
180    fn api_url(cx: &App) -> SharedString {
181        let api_url = &Self::settings(cx).api_url;
182        if api_url.is_empty() {
183            google_ai::API_URL.into()
184        } else {
185            SharedString::new(api_url.as_str())
186        }
187    }
188}
189
190impl LanguageModelProviderState for GoogleLanguageModelProvider {
191    type ObservableEntity = State;
192
193    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
194        Some(self.state.clone())
195    }
196}
197
198impl LanguageModelProvider for GoogleLanguageModelProvider {
199    fn id(&self) -> LanguageModelProviderId {
200        PROVIDER_ID
201    }
202
203    fn name(&self) -> LanguageModelProviderName {
204        PROVIDER_NAME
205    }
206
207    fn icon(&self) -> IconName {
208        IconName::AiGoogle
209    }
210
211    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
212        Some(self.create_language_model(google_ai::Model::default()))
213    }
214
215    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
216        Some(self.create_language_model(google_ai::Model::default_fast()))
217    }
218
219    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
220        let mut models = BTreeMap::default();
221
222        // Add base models from google_ai::Model::iter()
223        for model in google_ai::Model::iter() {
224            if !matches!(model, google_ai::Model::Custom { .. }) {
225                models.insert(model.id().to_string(), model);
226            }
227        }
228
229        // Override with available models from settings
230        for model in &GoogleLanguageModelProvider::settings(cx).available_models {
231            models.insert(
232                model.name.clone(),
233                google_ai::Model::Custom {
234                    name: model.name.clone(),
235                    display_name: model.display_name.clone(),
236                    max_tokens: model.max_tokens,
237                    mode: model.mode.unwrap_or_default().into(),
238                },
239            );
240        }
241
242        models
243            .into_values()
244            .map(|model| {
245                Arc::new(GoogleLanguageModel {
246                    id: LanguageModelId::from(model.id().to_string()),
247                    model,
248                    state: self.state.clone(),
249                    http_client: self.http_client.clone(),
250                    request_limiter: RateLimiter::new(4),
251                }) as Arc<dyn LanguageModel>
252            })
253            .collect()
254    }
255
256    fn is_authenticated(&self, cx: &App) -> bool {
257        self.state.read(cx).is_authenticated()
258    }
259
260    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
261        self.state.update(cx, |state, cx| state.authenticate(cx))
262    }
263
264    fn configuration_view(
265        &self,
266        target_agent: language_model::ConfigurationViewTargetAgent,
267        window: &mut Window,
268        cx: &mut App,
269    ) -> AnyView {
270        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
271            .into()
272    }
273
274    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
275        self.state
276            .update(cx, |state, cx| state.set_api_key(None, cx))
277    }
278}
279
/// A single Google AI model exposed through the `LanguageModel` trait.
pub struct GoogleLanguageModel {
    /// Identifier derived from `model.id()` when the model is created.
    id: LanguageModelId,
    /// The underlying `google_ai` model description.
    model: google_ai::Model,
    /// Provider state shared with the provider; the API key is read from it.
    state: gpui::Entity<State>,
    /// HTTP client used for API requests.
    http_client: Arc<dyn HttpClient>,
    /// Rate limiter through which streaming completions are issued.
    request_limiter: RateLimiter,
}
287
288impl GoogleLanguageModel {
289    fn stream_completion(
290        &self,
291        request: google_ai::GenerateContentRequest,
292        cx: &AsyncApp,
293    ) -> BoxFuture<
294        'static,
295        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
296    > {
297        let http_client = self.http_client.clone();
298
299        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
300            let api_url = GoogleLanguageModelProvider::api_url(cx);
301            (state.api_key_state.key(&api_url), api_url)
302        }) else {
303            return future::ready(Err(anyhow!("App state dropped"))).boxed();
304        };
305
306        async move {
307            let api_key = api_key.context("Missing Google API key")?;
308            let request = google_ai::stream_generate_content(
309                http_client.as_ref(),
310                &api_url,
311                &api_key,
312                request,
313            );
314            request.await.context("failed to stream completion")
315        }
316        .boxed()
317    }
318}
319
320impl LanguageModel for GoogleLanguageModel {
321    fn id(&self) -> LanguageModelId {
322        self.id.clone()
323    }
324
325    fn name(&self) -> LanguageModelName {
326        LanguageModelName::from(self.model.display_name().to_string())
327    }
328
329    fn provider_id(&self) -> LanguageModelProviderId {
330        PROVIDER_ID
331    }
332
333    fn provider_name(&self) -> LanguageModelProviderName {
334        PROVIDER_NAME
335    }
336
337    fn supports_tools(&self) -> bool {
338        self.model.supports_tools()
339    }
340
341    fn supports_images(&self) -> bool {
342        self.model.supports_images()
343    }
344
345    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
346        match choice {
347            LanguageModelToolChoice::Auto
348            | LanguageModelToolChoice::Any
349            | LanguageModelToolChoice::None => true,
350        }
351    }
352
353    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
354        LanguageModelToolSchemaFormat::JsonSchemaSubset
355    }
356
357    fn telemetry_id(&self) -> String {
358        format!("google/{}", self.model.request_id())
359    }
360
361    fn max_token_count(&self) -> u64 {
362        self.model.max_token_count()
363    }
364
365    fn max_output_tokens(&self) -> Option<u64> {
366        self.model.max_output_tokens()
367    }
368
369    fn count_tokens(
370        &self,
371        request: LanguageModelRequest,
372        cx: &App,
373    ) -> BoxFuture<'static, Result<u64>> {
374        let model_id = self.model.request_id().to_string();
375        let request = into_google(request, model_id, self.model.mode());
376        let http_client = self.http_client.clone();
377        let api_url = GoogleLanguageModelProvider::api_url(cx);
378        let api_key = self.state.read(cx).api_key_state.key(&api_url);
379
380        async move {
381            let Some(api_key) = api_key else {
382                return Err(LanguageModelCompletionError::NoApiKey {
383                    provider: PROVIDER_NAME,
384                }
385                .into());
386            };
387            let response = google_ai::count_tokens(
388                http_client.as_ref(),
389                &api_url,
390                &api_key,
391                google_ai::CountTokensRequest {
392                    generate_content_request: request,
393                },
394            )
395            .await?;
396            Ok(response.total_tokens)
397        }
398        .boxed()
399    }
400
401    fn stream_completion(
402        &self,
403        request: LanguageModelRequest,
404        cx: &AsyncApp,
405    ) -> BoxFuture<
406        'static,
407        Result<
408            futures::stream::BoxStream<
409                'static,
410                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
411            >,
412            LanguageModelCompletionError,
413        >,
414    > {
415        let request = into_google(
416            request,
417            self.model.request_id().to_string(),
418            self.model.mode(),
419        );
420        let request = self.stream_completion(request, cx);
421        let future = self.request_limiter.stream(async move {
422            let response = request.await.map_err(LanguageModelCompletionError::from)?;
423            Ok(GoogleEventMapper::new().map_stream(response))
424        });
425        async move { Ok(future.await?.boxed()) }.boxed()
426    }
427}
428
/// Converts a `LanguageModelRequest` into a `google_ai::GenerateContentRequest`
/// addressed to `model_id`.
///
/// - If the first message has `Role::System`, it is removed from the
///   conversation and sent as the `system_instruction`.
/// - Messages whose content maps to no parts are dropped from `contents`.
/// - A thinking config is sent only when the request allows thinking, `mode` is
///   `Thinking`, and a budget is set.
/// - The temperature defaults to `1.0` when the request does not specify one.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
    /// Maps message content onto Google `Part`s. Empty text, thinking content
    /// without a non-empty signature, and redacted thinking produce no part.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                language_model::MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(google_ai::TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                // Only the signature is forwarded; the thinking text itself is dropped.
                language_model::MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking { .. } => {
                    vec![]
                }
                language_model::MessageContent::RedactedThinking(_) => vec![],
                // Images are always labelled `image/png`.
                // NOTE(review): presumably `image.source` is PNG data — confirm.
                language_model::MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
                        inline_data: google_ai::GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                language_model::MessageContent::ToolUse(tool_use) => {
                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
                        function_call: google_ai::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                    })]
                }
                language_model::MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(
                                google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": text
                                        }),
                                    },
                                },
                            )]
                        }
                        // An image result becomes a placeholder function response
                        // followed by the image itself as inline data.
                        language_model::LanguageModelToolResultContent::Image(image) => {
                            vec![
                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(google_ai::InlineDataPart {
                                    inline_data: google_ai::GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // Only a system message in first position is lifted into the system
    // instruction; it is removed so it does not also appear in `contents`.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    google_ai::GenerateContentRequest {
        model: google_ai::ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                // Skip messages that produced no parts.
                if parts.is_empty() {
                    None
                } else {
                    Some(google_ai::Content {
                        parts,
                        role: match message.role {
                            Role::User => google_ai::Role::User,
                            Role::Assistant => google_ai::Role::Model,
                            Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(google_ai::GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // Thinking mode without a budget sends no thinking config at all.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        // `tools` is omitted entirely when the request declares no tools.
        tools: (!request.tools.is_empty()).then(|| {
            vec![google_ai::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
            function_calling_config: google_ai::FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
593
/// Stateful translator from Google `GenerateContentResponse` chunks to
/// `LanguageModelCompletionEvent`s.
pub struct GoogleEventMapper {
    /// Usage metadata merged across all chunks seen so far.
    usage: UsageMetadata,
    /// Stop reason reported in the final `Stop` event; starts as `EndTurn`.
    stop_reason: StopReason,
}
598
impl GoogleEventMapper {
    /// Creates a mapper with empty usage and an `EndTurn` stop reason.
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Maps a stream of response chunks into completion events.
    ///
    /// Chunk errors are forwarded as `LanguageModelCompletionError`s. After the
    /// input stream ends, one extra `Stop` event carrying the mapper's current
    /// stop reason is appended.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events
            .map(Some)
            // A trailing `None` marks the end of the input so the final `Stop`
            // can be emitted from the same closure that owns `self`.
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

    /// Translates one response chunk into zero or more completion events.
    ///
    /// Emits a usage update when the chunk carries usage metadata, then one
    /// event per text, function-call, and thought part of every candidate.
    /// Inline-data and function-response parts are ignored. If the chunk
    /// contained a function call, the stop reason becomes `ToolUse` and a
    /// `Stop(ToolUse)` event is appended.
    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        // Shared by every mapper in the process; it makes tool-use ids unique.
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }
        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        // Unknown reasons are logged and treated as a normal end of turn.
                        _ => {
                            log::error!("Unexpected google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            // The id is built locally as `<tool name>-<counter>`.
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();

                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a Tool, the API
        // responds with `finish_reason: STOP`
        // This runs after the finish-reason handling above, so within one chunk
        // a function call takes precedence over the reported finish reason.
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}
699
700pub fn count_google_tokens(
701    request: LanguageModelRequest,
702    cx: &App,
703) -> BoxFuture<'static, Result<u64>> {
704    // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
705    // So we have to use tokenizer from tiktoken_rs to count tokens.
706    cx.background_spawn(async move {
707        let messages = request
708            .messages
709            .into_iter()
710            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
711                role: match message.role {
712                    Role::User => "user".into(),
713                    Role::Assistant => "assistant".into(),
714                    Role::System => "system".into(),
715                },
716                content: Some(message.string_contents()),
717                name: None,
718                function_call: None,
719            })
720            .collect::<Vec<_>>();
721
722        // Tiktoken doesn't yet support these models, so we manually use the
723        // same tokenizer as GPT-4.
724        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
725    })
726    .boxed()
727}
728
729fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
730    if let Some(prompt_token_count) = new.prompt_token_count {
731        usage.prompt_token_count = Some(prompt_token_count);
732    }
733    if let Some(cached_content_token_count) = new.cached_content_token_count {
734        usage.cached_content_token_count = Some(cached_content_token_count);
735    }
736    if let Some(candidates_token_count) = new.candidates_token_count {
737        usage.candidates_token_count = Some(candidates_token_count);
738    }
739    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
740        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
741    }
742    if let Some(thoughts_token_count) = new.thoughts_token_count {
743        usage.thoughts_token_count = Some(thoughts_token_count);
744    }
745    if let Some(total_token_count) = new.total_token_count {
746        usage.total_token_count = Some(total_token_count);
747    }
748}
749
750fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
751    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
752    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
753    let input_tokens = prompt_tokens - cached_tokens;
754    let output_tokens = usage.candidates_token_count.unwrap_or(0);
755
756    language_model::TokenUsage {
757        input_tokens,
758        output_tokens,
759        cache_read_input_tokens: cached_tokens,
760        cache_creation_input_tokens: 0,
761    }
762}
763
/// View for entering and resetting the Google API key.
struct ConfigurationView {
    /// Single-line editor in which the user types the API key.
    api_key_editor: Entity<Editor>,
    /// Provider state that stores the key and reports authentication status.
    state: gpui::Entity<State>,
    /// The target agent this view was created for.
    target_agent: language_model::ConfigurationViewTargetAgent,
    /// Task performing the initial credential load; set back to `None` once
    /// that load has finished.
    load_credentials_task: Option<Task<()>>,
}
770
impl ConfigurationView {
    /// Builds the view, re-rendering whenever `state` changes, and starts a
    /// task that authenticates against the stored credentials.
    fn new(
        state: gpui::Entity<State>,
        target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        // Re-render whenever the provider state notifies.
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                // Clear the task handle and re-render now that loading is done.
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor: cx.new(|cx| {
                let mut editor = Editor::single_line(window, cx);
                editor.set_placeholder_text("AIzaSy...", window, cx);
                editor
            }),
            target_agent,
            state,
            load_credentials_task,
        }
    }

    /// Stores the trimmed editor text as the API key; does nothing when the
    /// trimmed text is empty. The editor is cleared before the key is stored.
    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
        if api_key.is_empty() {
            return;
        }

        // url changes can cause the editor to be displayed again
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    /// Clears the editor and removes the stored API key by storing `None`.
    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(None, cx))?
                .await
        })
        .detach_and_log_err(cx);
    }

    /// Builds the editor element for the API-key field, styled with the UI
    /// font from the theme settings and colors from the current theme.
    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_fallbacks: settings.ui_font.fallbacks.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            white_space: WhiteSpace::Normal,
            ..Default::default()
        };
        EditorElement::new(
            &self.api_key_editor,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }

    /// The key editor is shown only while the provider is not authenticated.
    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}
874
875impl Render for ConfigurationView {
876    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
877        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
878
879        if self.load_credentials_task.is_some() {
880            div().child(Label::new("Loading credentials...")).into_any()
881        } else if self.should_render_editor(cx) {
882            v_flex()
883                .size_full()
884                .on_action(cx.listener(Self::save_api_key))
885                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
886                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
887                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
888                })))
889                .child(
890                    List::new()
891                        .child(InstructionListItem::new(
892                            "Create one by visiting",
893                            Some("Google AI's console"),
894                            Some("https://aistudio.google.com/app/apikey"),
895                        ))
896                        .child(InstructionListItem::text_only(
897                            "Paste your API key below and hit enter to start using the assistant",
898                        )),
899                )
900                .child(
901                    h_flex()
902                        .w_full()
903                        .my_2()
904                        .px_2()
905                        .py_1()
906                        .bg(cx.theme().colors().editor_background)
907                        .border_1()
908                        .border_color(cx.theme().colors().border)
909                        .rounded_sm()
910                        .child(self.render_api_key_editor(cx)),
911                )
912                .child(
913                    Label::new(
914                        format!("You can also assign the {GEMINI_API_KEY_VAR_NAME} environment variable and restart Zed."),
915                    )
916                    .size(LabelSize::Small).color(Color::Muted),
917                )
918                .into_any()
919        } else {
920            h_flex()
921                .mt_1()
922                .p_1()
923                .justify_between()
924                .rounded_md()
925                .border_1()
926                .border_color(cx.theme().colors().border)
927                .bg(cx.theme().colors().background)
928                .child(
929                    h_flex()
930                        .gap_1()
931                        .child(Icon::new(IconName::Check).color(Color::Success))
932                        .child(Label::new(if env_var_set {
933                            format!("API key set in {} environment variable", API_KEY_ENV_VAR.name)
934                        } else {
935                            let api_url = GoogleLanguageModelProvider::api_url(cx);
936                            if api_url == google_ai::API_URL {
937                                "API key configured".to_string()
938                            } else {
939                                format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
940                            }
941                        })),
942                )
943                .child(
944                    Button::new("reset-key", "Reset Key")
945                        .label_size(LabelSize::Small)
946                        .icon(Some(IconName::Trash))
947                        .icon_size(IconSize::Small)
948                        .icon_position(IconPosition::Start)
949                        .disabled(env_var_set)
950                        .when(env_var_set, |this| {
951                            this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR_NAME} and {GOOGLE_AI_API_KEY_VAR_NAME} environment variables are unset.")))
952                        })
953                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
954                )
955                .into_any()
956        }
957    }
958}