Add tool calling support for GitHub Copilot Chat (#28035)

Created by Bennet Bo Fenner and Marshall Bowers

This PR adds tool calling support for GitHub Copilot Chat models.

Currently only supports the Claude family of models.

Release Notes:

- agent: Added tool calling support for Claude models in GitHub Copilot
Chat.

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>

Change summary

crates/copilot/src/copilot_chat.rs                  |  95 +++
crates/language_models/src/provider/copilot_chat.rs | 302 +++++++++++---
2 files changed, 315 insertions(+), 82 deletions(-)

Detailed changes

crates/copilot/src/copilot_chat.rs 🔗

@@ -131,25 +131,70 @@ pub struct Request {
     pub temperature: f32,
     pub model: Model,
     pub messages: Vec<ChatMessage>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<Tool>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
 }
 
-impl Request {
-    pub fn new(model: Model, messages: Vec<ChatMessage>) -> Self {
-        Self {
-            intent: true,
-            n: 1,
-            stream: model.uses_streaming(),
-            temperature: 0.1,
-            model,
-            messages,
-        }
-    }
+#[derive(Serialize, Deserialize)]
+pub struct Function {
+    pub name: String,
+    pub description: String,
+    pub parameters: serde_json::Value,
+}
+
+#[derive(Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Tool {
+    Function { function: Function },
+}
+
+#[derive(Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolChoice {
+    Auto,
+    Any,
+    Tool { name: String },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum ChatMessage {
+    Assistant {
+        content: Option<String>,
+        #[serde(default, skip_serializing_if = "Vec::is_empty")]
+        tool_calls: Vec<ToolCall>,
+    },
+    User {
+        content: String,
+    },
+    System {
+        content: String,
+    },
+    Tool {
+        content: String,
+        tool_call_id: String,
+    },
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ChatMessage {
-    pub role: Role,
-    pub content: String,
+pub struct ToolCall {
+    pub id: String,
+    #[serde(flatten)]
+    pub content: ToolCallContent,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolCallContent {
+    Function { function: FunctionContent },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionContent {
+    pub name: String,
+    pub arguments: String,
 }
 
 #[derive(Deserialize, Debug)]
@@ -172,6 +217,21 @@ pub struct ResponseChoice {
 pub struct ResponseDelta {
     pub content: Option<String>,
     pub role: Option<Role>,
+    #[serde(default)]
+    pub tool_calls: Vec<ToolCallChunk>,
+}
+
+#[derive(Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCallChunk {
+    pub index: usize,
+    pub id: Option<String>,
+    pub function: Option<FunctionChunk>,
+}
+
+#[derive(Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionChunk {
+    pub name: Option<String>,
+    pub arguments: Option<String>,
 }
 
 #[derive(Deserialize)]
@@ -385,7 +445,8 @@ async fn stream_completion(
 
     let is_streaming = request.stream;
 
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let json = serde_json::to_string(&request)?;
+    let request = request_builder.body(AsyncBody::from(json))?;
     let mut response = client.send(request).await?;
 
     if !response.status().is_success() {
@@ -413,9 +474,7 @@ async fn stream_completion(
 
                         match serde_json::from_str::<ResponseEvent>(line) {
                             Ok(response) => {
-                                if response.choices.is_empty()
-                                    || response.choices.first().unwrap().finish_reason.is_some()
-                                {
+                                if response.choices.is_empty() {
                                     None
                                 } else {
                                     Some(Ok(response))

crates/language_models/src/provider/copilot_chat.rs 🔗

@@ -1,14 +1,17 @@
+use std::pin::Pin;
+use std::str::FromStr as _;
 use std::sync::Arc;
 
 use anyhow::{Result, anyhow};
+use collections::HashMap;
 use copilot::copilot_chat::{
     ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
-    Role as CopilotChatRole,
+    ResponseEvent, Tool, ToolCall,
 };
 use copilot::{Copilot, Status};
 use futures::future::BoxFuture;
 use futures::stream::BoxStream;
-use futures::{FutureExt, StreamExt};
+use futures::{FutureExt, Stream, StreamExt};
 use gpui::{
     Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
     Transformation, percentage, svg,
@@ -16,12 +19,14 @@ use gpui::{
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
     LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
+    LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
 };
 use settings::SettingsStore;
 use std::time::Duration;
 use strum::IntoEnumIterator;
 use ui::prelude::*;
+use util::maybe;
 
 use super::anthropic::count_anthropic_tokens;
 use super::google::count_google_tokens;
@@ -180,7 +185,12 @@ impl LanguageModel for CopilotChatLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        match self.model {
+            CopilotChatModel::Claude3_5Sonnet
+            | CopilotChatModel::Claude3_7Sonnet
+            | CopilotChatModel::Claude3_7SonnetThinking => true,
+            _ => false,
+        }
     }
 
     fn telemetry_id(&self) -> String {
@@ -240,77 +250,241 @@ impl LanguageModel for CopilotChatLanguageModel {
             }
         }
 
-        let copilot_request = self.to_copilot_chat_request(request);
-        let is_streaming = copilot_request.stream;
+        let copilot_request = match self.to_copilot_chat_request(request) {
+            Ok(request) => request,
+            Err(err) => return futures::future::ready(Err(err)).boxed(),
+        };
 
         let request_limiter = self.request_limiter.clone();
         let future = cx.spawn(async move |cx| {
-            let response = CopilotChat::stream_completion(copilot_request, cx.clone());
-            request_limiter.stream(async move {
-                let response = response.await?;
-                let stream = response
-                    .filter_map(move |response| async move {
-                        match response {
-                            Ok(result) => {
-                                let choice = result.choices.first();
-                                match choice {
-                                    Some(choice) if !is_streaming => {
-                                        match &choice.message {
-                                            Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
-                                            None => Some(Err(anyhow::anyhow!(
-                                                "The Copilot Chat API returned a response with no message content"
-                                            ))),
-                                        }
-                                    },
-                                    Some(choice) => {
-                                        match &choice.delta {
-                                            Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
-                                            None => Some(Err(anyhow::anyhow!(
-                                                "The Copilot Chat API returned a response with no delta content"
-                                            ))),
-                                        }
-                                    },
-                                    None => Some(Err(anyhow::anyhow!(
-                                        "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
-                                    ))),
+            let request = CopilotChat::stream_completion(copilot_request, cx.clone());
+            request_limiter
+                .stream(async move {
+                    let response = request.await?;
+                    Ok(map_to_language_model_completion_events(response))
+                })
+                .await
+        });
+        async move { Ok(future.await?.boxed()) }.boxed()
+    }
+}
+
+pub fn map_to_language_model_completion_events(
+    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+    #[derive(Default)]
+    struct RawToolCall {
+        id: String,
+        name: String,
+        arguments: String,
+    }
+
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
+        tool_calls_by_index: HashMap<usize, RawToolCall>,
+    }
+
+    futures::stream::unfold(
+        State {
+            events,
+            tool_calls_by_index: HashMap::default(),
+        },
+        |mut state| async move {
+            if let Some(event) = state.events.next().await {
+                match event {
+                    Ok(event) => {
+                        let Some(choice) = event.choices.first() else {
+                            return Some((
+                                vec![Err(anyhow!("Response contained no choices"))],
+                                state,
+                            ));
+                        };
+
+                        let Some(delta) = choice.delta.as_ref() else {
+                            return Some((
+                                vec![Err(anyhow!("Response contained no delta"))],
+                                state,
+                            ));
+                        };
+
+                        let mut events = Vec::new();
+                        if let Some(content) = delta.content.clone() {
+                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                        }
+
+                            for tool_call in &delta.tool_calls {
+                                let entry = state
+                                    .tool_calls_by_index
+                                    .entry(tool_call.index)
+                                    .or_default();
+
+                                if let Some(tool_id) = tool_call.id.clone() {
+                                    entry.id = tool_id;
+                                }
+
+                                if let Some(function) = tool_call.function.as_ref() {
+                                    if let Some(name) = function.name.clone() {
+                                        entry.name = name;
+                                    }
+
+                                    if let Some(arguments) = function.arguments.clone() {
+                                        entry.arguments.push_str(&arguments);
+                                    }
                                 }
                             }
-                            Err(err) => Some(Err(err)),
+
+                        match choice.finish_reason.as_deref() {
+                            Some("stop") => {
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            Some("tool_calls") => {
+                                events.extend(state.tool_calls_by_index.drain().map(
+                                    |(_, tool_call)| {
+                                        maybe!({
+                                            Ok(LanguageModelCompletionEvent::ToolUse(
+                                                LanguageModelToolUse {
+                                                    id: tool_call.id.into(),
+                                                    name: tool_call.name.as_str().into(),
+                                                    input: serde_json::Value::from_str(
+                                                        &tool_call.arguments,
+                                                    )?,
+                                                },
+                                            ))
+                                        })
+                                    },
+                                ));
+
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::ToolUse,
+                                )));
+                            }
+                            Some(stop_reason) => {
+                                log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}",);
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            None => {}
                         }
-                    })
-                    .boxed();
 
-                Ok(stream)
-            }).await
-        });
+                        return Some((events, state));
+                    }
+                    Err(err) => return Some((vec![Err(err)], state)),
+                }
+            }
 
-        async move {
-            Ok(future
-                .await?
-                .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                .boxed())
-        }
-        .boxed()
-    }
+            None
+        },
+    )
+    .flat_map(futures::stream::iter)
 }
 
 impl CopilotChatLanguageModel {
-    pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
-        CopilotChatRequest::new(
-            self.model.clone(),
-            request
-                .messages
-                .into_iter()
-                .map(|msg| ChatMessage {
-                    role: match msg.role {
-                        Role::User => CopilotChatRole::User,
-                        Role::Assistant => CopilotChatRole::Assistant,
-                        Role::System => CopilotChatRole::System,
-                    },
-                    content: msg.string_contents(),
-                })
-                .collect(),
-        )
+    pub fn to_copilot_chat_request(
+        &self,
+        request: LanguageModelRequest,
+    ) -> Result<CopilotChatRequest> {
+        let model = self.model.clone();
+
+        let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+        for message in request.messages {
+            if let Some(last_message) = request_messages.last_mut() {
+                if last_message.role == message.role {
+                    last_message.content.extend(message.content);
+                } else {
+                    request_messages.push(message);
+                }
+            } else {
+                request_messages.push(message);
+            }
+        }
+
+        let mut messages: Vec<ChatMessage> = Vec::new();
+        for message in request_messages {
+            let text_content = {
+                let mut buffer = String::new();
+                for string in message.content.iter().filter_map(|content| match content {
+                    MessageContent::Text(text) => Some(text.as_str()),
+                    MessageContent::ToolUse(_)
+                    | MessageContent::ToolResult(_)
+                    | MessageContent::Image(_) => None,
+                }) {
+                    buffer.push_str(string);
+                }
+
+                buffer
+            };
+
+            match message.role {
+                Role::User => {
+                    for content in &message.content {
+                        if let MessageContent::ToolResult(tool_result) = content {
+                            messages.push(ChatMessage::Tool {
+                                tool_call_id: tool_result.tool_use_id.to_string(),
+                                content: tool_result.content.to_string(),
+                            });
+                        }
+                    }
+
+                    messages.push(ChatMessage::User {
+                        content: text_content,
+                    });
+                }
+                Role::Assistant => {
+                    let mut tool_calls = Vec::new();
+                    for content in &message.content {
+                        if let MessageContent::ToolUse(tool_use) = content {
+                            tool_calls.push(ToolCall {
+                                id: tool_use.id.to_string(),
+                                content: copilot::copilot_chat::ToolCallContent::Function {
+                                    function: copilot::copilot_chat::FunctionContent {
+                                        name: tool_use.name.to_string(),
+                                        arguments: serde_json::to_string(&tool_use.input)?,
+                                    },
+                                },
+                            });
+                        }
+                    }
+
+                    messages.push(ChatMessage::Assistant {
+                        content: if text_content.is_empty() {
+                            None
+                        } else {
+                            Some(text_content)
+                        },
+                        tool_calls,
+                    });
+                }
+                Role::System => messages.push(ChatMessage::System {
+                    content: message.string_contents(),
+                }),
+            }
+        }
+
+        let tools = request
+            .tools
+            .iter()
+            .map(|tool| Tool::Function {
+                function: copilot::copilot_chat::Function {
+                    name: tool.name.clone(),
+                    description: tool.description.clone(),
+                    parameters: tool.input_schema.clone(),
+                },
+            })
+            .collect();
+
+        Ok(CopilotChatRequest {
+            intent: true,
+            n: 1,
+            stream: model.uses_streaming(),
+            temperature: 0.1,
+            model,
+            messages,
+            tools,
+            tool_choice: None,
+        })
     }
 }