copilot: Add support for Gemini 2.0 Flash model to Copilot Chat (#24952)

Richard Hao and Peter Tripp created

Co-authored-by: Peter Tripp <peter@zed.dev>

Change summary

crates/copilot/src/copilot_chat.rs                  |  8 +++
crates/language_models/src/provider/copilot_chat.rs |  6 ++
crates/language_models/src/provider/google.rs       | 32 ++++++++++++++
3 files changed, 43 insertions(+), 3 deletions(-)

Detailed changes

crates/copilot/src/copilot_chat.rs 🔗

@@ -40,13 +40,15 @@ pub enum Model {
     O3Mini,
     #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
     Claude3_5Sonnet,
+    #[serde(alias = "gemini-2.0-flash", rename = "gemini-2.0-flash-001")]
+    Gemini20Flash,
 }
 
 impl Model {
     pub fn uses_streaming(&self) -> bool {
         match self {
             Self::Gpt4o | Self::Gpt4 | Self::Gpt3_5Turbo | Self::Claude3_5Sonnet => true,
-            Self::O3Mini | Self::O1 => false,
+            Self::O3Mini | Self::O1 | Self::Gemini20Flash => false,
         }
     }
 
@@ -58,6 +60,7 @@ impl Model {
             "o1" => Ok(Self::O1),
             "o3-mini" => Ok(Self::O3Mini),
             "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
+            "gemini-2.0-flash-001" => Ok(Self::Gemini20Flash),
             _ => Err(anyhow!("Invalid model id: {}", id)),
         }
     }
@@ -70,6 +73,7 @@ impl Model {
             Self::O3Mini => "o3-mini",
             Self::O1 => "o1",
             Self::Claude3_5Sonnet => "claude-3-5-sonnet",
+            Self::Gemini20Flash => "gemini-2.0-flash-001",
         }
     }
 
@@ -81,6 +85,7 @@ impl Model {
             Self::O3Mini => "o3-mini",
             Self::O1 => "o1",
             Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
+            Self::Gemini20Flash => "Gemini 2.0 Flash",
         }
     }
 
@@ -92,6 +97,7 @@ impl Model {
             Self::O3Mini => 20000,
             Self::O1 => 20000,
             Self::Claude3_5Sonnet => 200_000,
+            Model::Gemini20Flash => 128_000,
         }
     }
 }

crates/language_models/src/provider/copilot_chat.rs 🔗

@@ -25,6 +25,7 @@ use strum::IntoEnumIterator;
 use ui::prelude::*;
 
 use super::anthropic::count_anthropic_tokens;
+use super::google::count_google_tokens;
 use super::open_ai::count_open_ai_tokens;
 
 const PROVIDER_ID: &str = "copilot_chat";
@@ -174,13 +175,16 @@ impl LanguageModel for CopilotChatLanguageModel {
     ) -> BoxFuture<'static, Result<usize>> {
         match self.model {
             CopilotChatModel::Claude3_5Sonnet => count_anthropic_tokens(request, cx),
+            CopilotChatModel::Gemini20Flash => count_google_tokens(request, cx),
             _ => {
                 let model = match self.model {
                     CopilotChatModel::Gpt4o => open_ai::Model::FourOmni,
                     CopilotChatModel::Gpt4 => open_ai::Model::Four,
                     CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
                     CopilotChatModel::O1 | CopilotChatModel::O3Mini => open_ai::Model::Four,
-                    CopilotChatModel::Claude3_5Sonnet => unreachable!(),
+                    CopilotChatModel::Claude3_5Sonnet | CopilotChatModel::Gemini20Flash => {
+                        unreachable!()
+                    }
                 };
                 count_open_ai_tokens(request, model, cx)
             }

crates/language_models/src/provider/google.rs 🔗

@@ -11,7 +11,7 @@ use language_model::LanguageModelCompletionEvent;
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, RateLimiter,
+    LanguageModelRequest, RateLimiter, Role,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -324,6 +324,36 @@ impl LanguageModel for GoogleLanguageModel {
     }
 }
 
+pub fn count_google_tokens(
+    request: LanguageModelRequest,
+    cx: &App,
+) -> BoxFuture<'static, Result<usize>> {
+    // We couldn't use the GoogleLanguageModelProvider to count tokens because the github copilot doesn't have the access to google_ai directly.
+    // So we have to use tokenizer from tiktoken_rs to count tokens.
+    cx.background_executor()
+        .spawn(async move {
+            let messages = request
+                .messages
+                .into_iter()
+                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+                    role: match message.role {
+                        Role::User => "user".into(),
+                        Role::Assistant => "assistant".into(),
+                        Role::System => "system".into(),
+                    },
+                    content: Some(message.string_contents()),
+                    name: None,
+                    function_call: None,
+                })
+                .collect::<Vec<_>>();
+
+            // Tiktoken doesn't yet support these models, so we manually use the
+            // same tokenizer as GPT-4.
+            tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
+        })
+        .boxed()
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: gpui::Entity<State>,