google.rs

use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
    ThinkingConfig, UsageMetadata,
};
use gpui::{
    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, ConfigurationViewTargetAgent, LanguageModelCompletionError,
    LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolSchemaFormat,
    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::{
    Arc,
    atomic::{self, AtomicU64},
};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

use super::anthropic::ApiKey;

const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;

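/// User-configurable settings for the Google AI provider: the API URL to use
/// and any additional models made available beyond the built-in list.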
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

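/// Whether a model runs in its default mode or with "thinking" (reasoning)
/// enabled, optionally capped by a token budget.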
#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for GoogleModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => GoogleModelMode::Default,
            ModelMode::Thinking { budget_tokens } => GoogleModelMode::Thinking { budget_tokens },
        }
    }
}

impl From<GoogleModelMode> for ModelMode {
    fn from(value: GoogleModelMode) -> Self {
        match value {
            GoogleModelMode::Default => ModelMode::Default,
            GoogleModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
        }
    }
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    name: String,
    display_name: Option<String>,
    max_tokens: u64,
    mode: Option<ModelMode>,
}

pub struct GoogleLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

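/// Authentication state for the provider: the API key (if any), whether it came
/// from an environment variable, and a subscription that re-notifies observers
/// when settings change.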
pub struct State {
    api_key: Option<String>,
    api_key_from_env: bool,
    _subscription: Subscription,
}

const GEMINI_API_KEY_VAR: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR: &str = "GOOGLE_AI_API_KEY";

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = AllLanguageModelSettings::get_global(cx)
            .google
            .api_url
            .clone();
        cx.spawn(async move |this, cx| {
            credentials_provider
                .delete_credentials(&api_url, cx)
                .await
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = None;
                this.api_key_from_env = false;
                cx.notify();
            })
        })
    }

    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = AllLanguageModelSettings::get_global(cx)
            .google
            .api_url
            .clone();
        cx.spawn(async move |this, cx| {
            credentials_provider
                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
                .await?;
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                cx.notify();
            })
        })
    }

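    /// Loads the API key if one isn't already set, preferring the
    /// `GOOGLE_AI_API_KEY` and `GEMINI_API_KEY` environment variables (in that
    /// order) and falling back to the credentials store.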
    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = AllLanguageModelSettings::get_global(cx)
            .google
            .api_url
            .clone();

        cx.spawn(async move |this, cx| {
            let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR) {
                (api_key, true)
            } else if let Ok(api_key) = std::env::var(GEMINI_API_KEY_VAR) {
                (api_key, true)
            } else {
                let (_, api_key) = credentials_provider
                    .read_credentials(&api_url, cx)
                    .await?
                    .ok_or(AuthenticateError::CredentialsNotFound)?;
                (
                    String::from_utf8(api_key)
                        .context(format!("invalid {PROVIDER_NAME} API key"))?,
                    false,
                )
            };

            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                this.api_key_from_env = from_env;
                cx.notify();
            })?;

            Ok(())
        })
    }
}

impl GoogleLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| State {
            api_key: None,
            api_key_from_env: false,
            _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                cx.notify();
            }),
        });

        Self { http_client, state }
    }

    fn create_language_model(&self, model: google_ai::Model) -> Arc<dyn LanguageModel> {
        Arc::new(GoogleLanguageModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }

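    /// Returns the API key for this provider: the `GEMINI_API_KEY` environment
    /// variable if set, otherwise the key stored in the credentials store.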
    pub fn api_key(cx: &mut App) -> Task<Result<ApiKey>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = AllLanguageModelSettings::get_global(cx)
            .google
            .api_url
            .clone();

        if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR) {
            Task::ready(Ok(ApiKey {
                key,
                from_env: true,
            }))
        } else {
            cx.spawn(async move |cx| {
                let (_, api_key) = credentials_provider
                    .read_credentials(&api_url, cx)
                    .await?
                    .ok_or(AuthenticateError::CredentialsNotFound)?;

                Ok(ApiKey {
                    key: String::from_utf8(api_key)
                        .context(format!("invalid {PROVIDER_NAME} API key"))?,
                    from_env: false,
                })
            })
        }
    }
}

impl LanguageModelProviderState for GoogleLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for GoogleLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiGoogle
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(google_ai::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(google_ai::Model::default_fast()))
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from google_ai::Model::iter()
        for model in google_ai::Model::iter() {
            if !matches!(model, google_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in &AllLanguageModelSettings::get_global(cx)
            .google
            .available_models
        {
            models.insert(
                model.name.clone(),
                google_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    mode: model.mode.unwrap_or_default().into(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(GoogleLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    state: self.state.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.reset_api_key(cx))
    }
}

pub struct GoogleLanguageModel {
    id: LanguageModelId,
    model: google_ai::Model,
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl GoogleLanguageModel {
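    /// Starts a streaming completion request against the Google AI API, using
    /// the API key held in [`State`] and the API URL from settings.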
    fn stream_completion(
        &self,
        request: google_ai::GenerateContentRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
    > {
        let http_client = self.http_client.clone();

        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).google;
            (state.api_key.clone(), settings.api_url.clone())
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move {
            let api_key = api_key.context("Missing Google API key")?;
            let request = google_ai::stream_generate_content(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
            );
            request.await.context("failed to stream completion")
        }
        .boxed()
    }
}

impl LanguageModel for GoogleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_images()
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => true,
        }
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn telemetry_id(&self) -> String {
        format!("google/{}", self.model.request_id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let model_id = self.model.request_id().to_string();
        let request = into_google(request, model_id, self.model.mode());
        let http_client = self.http_client.clone();
        let api_key = self.state.read(cx).api_key.clone();

        let settings = &AllLanguageModelSettings::get_global(cx).google;
        let api_url = settings.api_url.clone();

        async move {
            let api_key = api_key.context("Missing Google API key")?;
            let response = google_ai::count_tokens(
                http_client.as_ref(),
                &api_url,
                &api_key,
                google_ai::CountTokensRequest {
                    generate_content_request: request,
                },
            )
            .await?;
            Ok(response.total_tokens)
        }
        .boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_google(
            request,
            self.model.request_id().to_string(),
            self.model.mode(),
        );
        let request = self.stream_completion(request, cx);
        let future = self.request_limiter.stream(async move {
            let response = request.await.map_err(LanguageModelCompletionError::from)?;
            Ok(GoogleEventMapper::new().map_stream(response))
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

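/// Converts a [`LanguageModelRequest`] into a Google AI `GenerateContentRequest`,
/// splitting off a leading system message into a system instruction and mapping
/// message content, tools, tool choice, and the thinking configuration.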
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                language_model::MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(google_ai::TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(google_ai::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                language_model::MessageContent::Thinking { .. } => {
                    vec![]
                }
                language_model::MessageContent::RedactedThinking(_) => vec![],
                language_model::MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(google_ai::InlineDataPart {
                        inline_data: google_ai::GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                language_model::MessageContent::ToolUse(tool_use) => {
                    vec![Part::FunctionCallPart(google_ai::FunctionCallPart {
                        function_call: google_ai::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                    })]
                }
                language_model::MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(
                                google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": text
                                        }),
                                    },
                                },
                            )]
                        }
                        language_model::LanguageModelToolResultContent::Image(image) => {
                            vec![
                                Part::FunctionResponsePart(google_ai::FunctionResponsePart {
                                    function_response: google_ai::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(google_ai::InlineDataPart {
                                    inline_data: google_ai::GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    google_ai::GenerateContentRequest {
        model: google_ai::ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(google_ai::Content {
                        parts,
                        role: match message.role {
                            Role::User => google_ai::Role::User,
                            Role::Assistant => google_ai::Role::Model,
                            Role::System => google_ai::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(google_ai::GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        tools: (!request.tools.is_empty()).then(|| {
            vec![google_ai::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| google_ai::ToolConfig {
            function_calling_config: google_ai::FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => google_ai::FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => google_ai::FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => google_ai::FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}

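/// Translates streamed [`GenerateContentResponse`]s into
/// [`LanguageModelCompletionEvent`]s, accumulating usage metadata and tracking
/// the final stop reason.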
pub struct GoogleEventMapper {
    usage: UsageMetadata,
    stop_reason: StopReason,
}

impl GoogleEventMapper {
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

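    /// Consumes a response stream, yielding mapped completion events and
    /// emitting a final `Stop` event with the recorded stop reason once the
    /// underlying stream ends.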
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events
            .map(Some)
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

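    /// Maps a single streamed response into zero or more completion events
    /// (usage updates, text, thinking, and tool-use events), and records
    /// whether the model asked to call a tool.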
    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }
        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        _ => {
                            log::error!("Unexpected google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();

                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a tool, the API
        // responds with `finish_reason: STOP`
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}

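/// Estimates the token count of a request without calling the Google AI API,
/// approximating with the tiktoken GPT-4 tokenizer.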
pub fn count_google_tokens(
    request: LanguageModelRequest,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    // We can't use GoogleLanguageModelProvider to count tokens here, because the GitHub Copilot
    // provider doesn't have direct access to google_ai. Instead we approximate with tiktoken_rs.
    cx.background_spawn(async move {
        let messages = request
            .messages
            .into_iter()
            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                role: match message.role {
                    Role::User => "user".into(),
                    Role::Assistant => "assistant".into(),
                    Role::System => "system".into(),
                },
                content: Some(message.string_contents()),
                name: None,
                function_call: None,
            })
            .collect::<Vec<_>>();

        // Tiktoken doesn't yet support these models, so we manually use the
        // same tokenizer as GPT-4.
        tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
    })
    .boxed()
}

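/// Merges newly reported usage metadata into the running totals, keeping the
/// previous value for any field the new report omits.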
fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
    if let Some(prompt_token_count) = new.prompt_token_count {
        usage.prompt_token_count = Some(prompt_token_count);
    }
    if let Some(cached_content_token_count) = new.cached_content_token_count {
        usage.cached_content_token_count = Some(cached_content_token_count);
    }
    if let Some(candidates_token_count) = new.candidates_token_count {
        usage.candidates_token_count = Some(candidates_token_count);
    }
    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
    }
    if let Some(thoughts_token_count) = new.thoughts_token_count {
        usage.thoughts_token_count = Some(thoughts_token_count);
    }
    if let Some(total_token_count) = new.total_token_count {
        usage.total_token_count = Some(total_token_count);
    }
}

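/// Converts accumulated Google usage metadata into [`language_model::TokenUsage`],
/// reporting cached prompt tokens separately from uncached input tokens.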
fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
    // The prompt count includes cached tokens; saturate to avoid underflow if it doesn't.
    let input_tokens = prompt_tokens.saturating_sub(cached_tokens);
    let output_tokens = usage.candidates_token_count.unwrap_or(0);

    language_model::TokenUsage {
        input_tokens,
        output_tokens,
        cache_read_input_tokens: cached_tokens,
        cache_creation_input_tokens: 0,
    }
}

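/// Configuration UI for the Google AI provider: prompts for an API key, or shows
/// that one is already configured (possibly via an environment variable).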
struct ConfigurationView {
    api_key_editor: Entity<Editor>,
    state: gpui::Entity<State>,
    target_agent: language_model::ConfigurationViewTargetAgent,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(
        state: gpui::Entity<State>,
        target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log the error here, because "not signed in" also surfaces as an error.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor: cx.new(|cx| {
                let mut editor = Editor::single_line(window, cx);
                editor.set_placeholder_text("AIzaSy...", window, cx);
                editor
            }),
            target_agent,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", window, cx));

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_fallbacks: settings.ui_font.fallbacks.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            white_space: WhiteSpace::Normal,
            ..Default::default()
        };
        EditorElement::new(
            &self.api_key_editor,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_from_env;

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
                    ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
                    ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
                })))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("Google AI's console"),
                            Some("https://aistudio.google.com/app/apikey"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(
                    h_flex()
                        .w_full()
                        .my_2()
                        .px_2()
                        .py_1()
                        .bg(cx.theme().colors().editor_background)
                        .border_1()
                        .border_color(cx.theme().colors().border)
                        .rounded_sm()
                        .child(self.render_api_key_editor(cx)),
                )
                .child(
                    Label::new(
                        format!("You can also assign the {GEMINI_API_KEY_VAR} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {GEMINI_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-key", "Reset Key")
                        .label_size(LabelSize::Small)
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, make sure {GEMINI_API_KEY_VAR} and {GOOGLE_AI_API_KEY_VAR} environment variables are unset.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        }
    }
}