copilot: Add support for new models (#19968)

Jonathan Toledo and Bennet Bo Fenner created

Closes #19963

This PR implements integration with the newly announced GitHub Copilot
LLM models, including:
- Claude 3.5 Sonnet
- o1-mini
- o1-preview

Release Notes:

- N/A

---------

Co-authored-by: Bennet Bo Fenner <bennet@zed.dev>

Change summary

crates/copilot/src/copilot_chat.rs                 | 60 ++++++++++++----
crates/language_model/src/provider/copilot_chat.rs | 45 +++++++++--
2 files changed, 79 insertions(+), 26 deletions(-)

Detailed changes

crates/copilot/src/copilot_chat.rs 🔗

@@ -35,14 +35,30 @@ pub enum Model {
     Gpt4,
     #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
     Gpt3_5Turbo,
+    #[serde(alias = "o1-preview", rename = "o1-preview-2024-09-12")]
+    O1Preview,
+    #[serde(alias = "o1-mini", rename = "o1-mini-2024-09-12")]
+    O1Mini,
+    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3.5-sonnet")]
+    Claude3_5Sonnet,
 }
 
 impl Model {
+    pub fn uses_streaming(&self) -> bool {
+        match self {
+            Self::Gpt4o | Self::Gpt4 | Self::Gpt3_5Turbo | Self::Claude3_5Sonnet => true,
+            Self::O1Mini | Self::O1Preview => false,
+        }
+    }
+
     pub fn from_id(id: &str) -> Result<Self> {
         match id {
             "gpt-4o" => Ok(Self::Gpt4o),
             "gpt-4" => Ok(Self::Gpt4),
             "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
+            "o1-preview" => Ok(Self::O1Preview),
+            "o1-mini" => Ok(Self::O1Mini),
+            "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
             _ => Err(anyhow!("Invalid model id: {}", id)),
         }
     }
@@ -52,6 +68,9 @@ impl Model {
             Self::Gpt3_5Turbo => "gpt-3.5-turbo",
             Self::Gpt4 => "gpt-4",
             Self::Gpt4o => "gpt-4o",
+            Self::O1Mini => "o1-mini",
+            Self::O1Preview => "o1-preview",
+            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
         }
     }
 
@@ -60,6 +79,9 @@ impl Model {
             Self::Gpt3_5Turbo => "GPT-3.5",
             Self::Gpt4 => "GPT-4",
             Self::Gpt4o => "GPT-4o",
+            Self::O1Mini => "o1-mini",
+            Self::O1Preview => "o1-preview",
+            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
         }
     }
 
@@ -68,6 +90,9 @@ impl Model {
             Self::Gpt4o => 128000,
             Self::Gpt4 => 8192,
             Self::Gpt3_5Turbo => 16385,
+            Self::O1Mini => 128000,
+            Self::O1Preview => 128000,
+            Self::Claude3_5Sonnet => 200_000,
         }
     }
 }
@@ -87,7 +112,7 @@ impl Request {
         Self {
             intent: true,
             n: 1,
-            stream: true,
+            stream: model.uses_streaming(),
             temperature: 0.1,
             model,
             messages,
@@ -113,7 +138,8 @@ pub struct ResponseEvent {
 pub struct ResponseChoice {
     pub index: usize,
     pub finish_reason: Option<String>,
-    pub delta: ResponseDelta,
+    pub delta: Option<ResponseDelta>,
+    pub message: Option<ResponseDelta>,
 }
 
 #[derive(Debug, Deserialize)]
@@ -333,9 +359,23 @@ async fn stream_completion(
     if let Some(low_speed_timeout) = low_speed_timeout {
         request_builder = request_builder.read_timeout(low_speed_timeout);
     }
+    let is_streaming = request.stream;
+
     let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
     let mut response = client.send(request).await?;
-    if response.status().is_success() {
+
+    if !response.status().is_success() {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+        let body_str = std::str::from_utf8(&body)?;
+        return Err(anyhow!(
+            "Failed to connect to API: {} {}",
+            response.status(),
+            body_str
+        ));
+    }
+
+    if is_streaming {
         let reader = BufReader::new(response.into_body());
         Ok(reader
             .lines()
@@ -367,19 +407,9 @@ async fn stream_completion(
     } else {
         let mut body = Vec::new();
         response.body_mut().read_to_end(&mut body).await?;
-
         let body_str = std::str::from_utf8(&body)?;
+        let response: ResponseEvent = serde_json::from_str(body_str)?;
 
-        match serde_json::from_str::<ResponseEvent>(body_str) {
-            Ok(_) => Err(anyhow!(
-                "Unexpected success response while expecting an error: {}",
-                body_str,
-            )),
-            Err(_) => Err(anyhow!(
-                "Failed to connect to API: {} {}",
-                response.status(),
-                body_str,
-            )),
-        }
+        Ok(futures::stream::once(async move { Ok(response) }).boxed())
     }
 }

crates/language_model/src/provider/copilot_chat.rs 🔗

@@ -30,6 +30,7 @@ use crate::{
 };
 use crate::{LanguageModelCompletionEvent, LanguageModelProviderState};
 
+use super::anthropic::count_anthropic_tokens;
 use super::open_ai::count_open_ai_tokens;
 
 const PROVIDER_ID: &str = "copilot_chat";
@@ -179,13 +180,19 @@ impl LanguageModel for CopilotChatLanguageModel {
         request: LanguageModelRequest,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
-        let model = match self.model {
-            CopilotChatModel::Gpt4o => open_ai::Model::FourOmni,
-            CopilotChatModel::Gpt4 => open_ai::Model::Four,
-            CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
-        };
-
-        count_open_ai_tokens(request, model, cx)
+        match self.model {
+            CopilotChatModel::Claude3_5Sonnet => count_anthropic_tokens(request, cx),
+            _ => {
+                let model = match self.model {
+                    CopilotChatModel::Gpt4o => open_ai::Model::FourOmni,
+                    CopilotChatModel::Gpt4 => open_ai::Model::Four,
+                    CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
+                    CopilotChatModel::O1Preview | CopilotChatModel::O1Mini => open_ai::Model::Four,
+                    CopilotChatModel::Claude3_5Sonnet => unreachable!(),
+                };
+                count_open_ai_tokens(request, model, cx)
+            }
+        }
     }
 
     fn stream_completion(
@@ -209,7 +216,8 @@ impl LanguageModel for CopilotChatLanguageModel {
             }
         }
 
-        let request = self.to_copilot_chat_request(request);
+        let copilot_request = self.to_copilot_chat_request(request);
+        let is_streaming = copilot_request.stream;
         let Ok(low_speed_timeout) = cx.update(|cx| {
             AllLanguageModelSettings::get_global(cx)
                 .copilot_chat
@@ -220,16 +228,31 @@ impl LanguageModel for CopilotChatLanguageModel {
 
         let request_limiter = self.request_limiter.clone();
         let future = cx.spawn(|cx| async move {
-            let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
+            let response = CopilotChat::stream_completion(copilot_request, low_speed_timeout, cx);
             request_limiter.stream(async move {
                 let response = response.await?;
                 let stream = response
-                    .filter_map(|response| async move {
+                    .filter_map(move |response| async move {
                         match response {
                             Ok(result) => {
                                 let choice = result.choices.first();
                                 match choice {
-                                    Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
+                                    Some(choice) if !is_streaming => {
+                                        match &choice.message {
+                                            Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
+                                            None => Some(Err(anyhow::anyhow!(
+                                                "The Copilot Chat API returned a response with no message content"
+                                            ))),
+                                        }
+                                    },
+                                    Some(choice) => {
+                                        match &choice.delta {
+                                            Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
+                                            None => Some(Err(anyhow::anyhow!(
+                                                "The Copilot Chat API returned a response with no delta content"
+                                            ))),
+                                        }
+                                    },
                                     None => Some(Err(anyhow::anyhow!(
                                         "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
                                     ))),