@@ -131,25 +131,70 @@ pub struct Request {
pub temperature: f32,
pub model: Model,
pub messages: Vec<ChatMessage>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub tools: Vec<Tool>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub tool_choice: Option<ToolChoice>,
}
-impl Request {
- pub fn new(model: Model, messages: Vec<ChatMessage>) -> Self {
- Self {
- intent: true,
- n: 1,
- stream: model.uses_streaming(),
- temperature: 0.1,
- model,
- messages,
- }
- }
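+/// A function the model can invoke; `parameters` holds the argument schema as raw JSON.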
+#[derive(Serialize, Deserialize)]
+pub struct Function {
+ pub name: String,
+ pub description: String,
+ pub parameters: serde_json::Value,
+}
+
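+/// A tool definition advertised to the model; function tools are the only variant so far.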
+#[derive(Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum Tool {
+ Function { function: Function },
+}
+
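+/// Constrains which tool, if any, the model should call.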
+#[derive(Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolChoice {
+ Auto,
+ Any,
+ Tool { name: String },
+}
+
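+/// A chat message, tagged by role. `Tool` messages carry a tool's output, linked by `tool_call_id`.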
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "role", rename_all = "lowercase")]
+pub enum ChatMessage {
+ Assistant {
+ content: Option<String>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ tool_calls: Vec<ToolCall>,
+ },
+ User {
+ content: String,
+ },
+ System {
+ content: String,
+ },
+ Tool {
+ content: String,
+ tool_call_id: String,
+ },
}
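+/// A tool invocation requested by the assistant.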
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ChatMessage {
- pub role: Role,
- pub content: String,
+pub struct ToolCall {
+ pub id: String,
+ #[serde(flatten)]
+ pub content: ToolCallContent,
+}
+
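+/// The typed payload of a tool call; mirrors the shape of `Tool` on the wire.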
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolCallContent {
+ Function { function: FunctionContent },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionContent {
+ pub name: String,
+ pub arguments: String,
}
#[derive(Deserialize, Debug)]
@@ -172,6 +217,21 @@ pub struct ResponseChoice {
pub struct ResponseDelta {
pub content: Option<String>,
pub role: Option<Role>,
+ #[serde(default)]
+ pub tool_calls: Vec<ToolCallChunk>,
+}
+
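+/// A streamed fragment of a tool call; pieces arrive across deltas and are reassembled by `index`.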
+#[derive(Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCallChunk {
+ pub index: usize,
+ pub id: Option<String>,
+ pub function: Option<FunctionChunk>,
+}
+
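+/// Partial function-call data carried by a single streamed chunk.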
+#[derive(Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionChunk {
+ pub name: Option<String>,
+ pub arguments: Option<String>,
}
#[derive(Deserialize)]
@@ -385,7 +445,8 @@ async fn stream_completion(
let is_streaming = request.stream;
- let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ let json = serde_json::to_string(&request)?;
+ let request = request_builder.body(AsyncBody::from(json))?;
let mut response = client.send(request).await?;
if !response.status().is_success() {
@@ -413,9 +474,7 @@ async fn stream_completion(
match serde_json::from_str::<ResponseEvent>(line) {
Ok(response) => {
- if response.choices.is_empty()
- || response.choices.first().unwrap().finish_reason.is_some()
- {
+ if response.choices.is_empty() {
None
} else {
Some(Ok(response))
@@ -1,14 +1,17 @@
+use std::pin::Pin;
+use std::str::FromStr as _;
use std::sync::Arc;
use anyhow::{Result, anyhow};
+use collections::HashMap;
use copilot::copilot_chat::{
ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest,
- Role as CopilotChatRole,
+ ResponseEvent, Tool, ToolCall,
};
use copilot::{Copilot, Status};
use futures::future::BoxFuture;
use futures::stream::BoxStream;
-use futures::{FutureExt, StreamExt};
+use futures::{FutureExt, Stream, StreamExt};
use gpui::{
Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
Transformation, percentage, svg,
@@ -16,12 +19,14 @@ use gpui::{
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+ LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage,
+ LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason,
};
use settings::SettingsStore;
use std::time::Duration;
use strum::IntoEnumIterator;
use ui::prelude::*;
+use util::maybe;
use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
@@ -180,7 +185,12 @@ impl LanguageModel for CopilotChatLanguageModel {
}
fn supports_tools(&self) -> bool {
- false
+ matches!(
+ self.model,
+ CopilotChatModel::Claude3_5Sonnet
+ | CopilotChatModel::Claude3_7Sonnet
+ | CopilotChatModel::Claude3_7SonnetThinking
+ )
}
fn telemetry_id(&self) -> String {
@@ -240,77 +250,241 @@ impl LanguageModel for CopilotChatLanguageModel {
}
}
- let copilot_request = self.to_copilot_chat_request(request);
- let is_streaming = copilot_request.stream;
+ let copilot_request = match self.to_copilot_chat_request(request) {
+ Ok(request) => request,
+ Err(err) => return futures::future::ready(Err(err)).boxed(),
+ };
let request_limiter = self.request_limiter.clone();
let future = cx.spawn(async move |cx| {
- let response = CopilotChat::stream_completion(copilot_request, cx.clone());
- request_limiter.stream(async move {
- let response = response.await?;
- let stream = response
- .filter_map(move |response| async move {
- match response {
- Ok(result) => {
- let choice = result.choices.first();
- match choice {
- Some(choice) if !is_streaming => {
- match &choice.message {
- Some(msg) => Some(Ok(msg.content.clone().unwrap_or_default())),
- None => Some(Err(anyhow::anyhow!(
- "The Copilot Chat API returned a response with no message content"
- ))),
- }
- },
- Some(choice) => {
- match &choice.delta {
- Some(delta) => Some(Ok(delta.content.clone().unwrap_or_default())),
- None => Some(Err(anyhow::anyhow!(
- "The Copilot Chat API returned a response with no delta content"
- ))),
- }
- },
- None => Some(Err(anyhow::anyhow!(
- "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
- ))),
+ let request = CopilotChat::stream_completion(copilot_request, cx.clone());
+ request_limiter
+ .stream(async move {
+ let response = request.await?;
+ Ok(map_to_language_model_completion_events(response))
+ })
+ .await
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
+
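+/// Adapts the raw Copilot Chat event stream into `LanguageModelCompletionEvent`s,
+/// buffering tool-call fragments until a terminal `finish_reason` flushes them.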
+pub fn map_to_language_model_completion_events(
+ events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
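+ // A tool call assembled incrementally from streamed fragments.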
+ #[derive(Default)]
+ struct RawToolCall {
+ id: String,
+ name: String,
+ arguments: String,
+ }
+
+ struct State {
+ events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
+ tool_calls_by_index: HashMap<usize, RawToolCall>,
+ }
+
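+ // Unfold drives the stream manually so partially accumulated tool calls survive across polls.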
+ futures::stream::unfold(
+ State {
+ events,
+ tool_calls_by_index: HashMap::default(),
+ },
+ |mut state| async move {
+ if let Some(event) = state.events.next().await {
+ match event {
+ Ok(event) => {
+ let Some(choice) = event.choices.first() else {
+ return Some((
+ vec![Err(anyhow!("Response contained no choices"))],
+ state,
+ ));
+ };
+
+ let Some(delta) = choice.delta.as_ref() else {
+ return Some((
+ vec![Err(anyhow!("Response contained no delta"))],
+ state,
+ ));
+ };
+
+ let mut events = Vec::new();
+ if let Some(content) = delta.content.clone() {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+
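+ // Fold this chunk's fragments into the tool call accumulating at its index.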
+ for tool_call in &delta.tool_calls {
+ let entry = state
+ .tool_calls_by_index
+ .entry(tool_call.index)
+ .or_default();
+
+ if let Some(tool_id) = tool_call.id.clone() {
+ entry.id = tool_id;
+ }
+
+ if let Some(function) = tool_call.function.as_ref() {
+ if let Some(name) = function.name.clone() {
+ entry.name = name;
+ }
+
+ if let Some(arguments) = function.arguments.clone() {
+ entry.arguments.push_str(&arguments);
+ }
}
}
- Err(err) => Some(Err(err)),
+
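+ // A finish_reason ends the turn; "tool_calls" flushes everything buffered above.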
+ match choice.finish_reason.as_deref() {
+ Some("stop") => {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(
+ StopReason::EndTurn,
+ )));
+ }
+ Some("tool_calls") => {
+ events.extend(state.tool_calls_by_index.drain().map(
+ |(_, tool_call)| {
+ maybe!({
+ Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.into(),
+ name: tool_call.name.as_str().into(),
+ input: serde_json::Value::from_str(
+ &tool_call.arguments,
+ )?,
+ },
+ ))
+ })
+ },
+ ));
+
+ events.push(Ok(LanguageModelCompletionEvent::Stop(
+ StopReason::ToolUse,
+ )));
+ }
+ Some(stop_reason) => {
+ log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
+ events.push(Ok(LanguageModelCompletionEvent::Stop(
+ StopReason::EndTurn,
+ )));
+ }
+ None => {}
}
- })
- .boxed();
- Ok(stream)
- }).await
- });
+ return Some((events, state));
+ }
+ Err(err) => return Some((vec![Err(err)], state)),
+ }
+ }
- async move {
- Ok(future
- .await?
- .map(|result| result.map(LanguageModelCompletionEvent::Text))
- .boxed())
- }
- .boxed()
- }
+ None
+ },
+ )
+ .flat_map(futures::stream::iter)
}
impl CopilotChatLanguageModel {
- pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest {
- CopilotChatRequest::new(
- self.model.clone(),
- request
- .messages
- .into_iter()
- .map(|msg| ChatMessage {
- role: match msg.role {
- Role::User => CopilotChatRole::User,
- Role::Assistant => CopilotChatRole::Assistant,
- Role::System => CopilotChatRole::System,
- },
- content: msg.string_contents(),
- })
- .collect(),
- )
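+ /// Converts a `LanguageModelRequest` into the Copilot Chat wire format, splitting
+ /// tool uses and tool results into the message shapes the API expects.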
+ pub fn to_copilot_chat_request(
+ &self,
+ request: LanguageModelRequest,
+ ) -> Result<CopilotChatRequest> {
+ let model = self.model.clone();
+
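+ // Merge consecutive messages from the same role into a single message.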
+ let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+ for message in request.messages {
+ if let Some(last_message) = request_messages.last_mut() {
+ if last_message.role == message.role {
+ last_message.content.extend(message.content);
+ } else {
+ request_messages.push(message);
+ }
+ } else {
+ request_messages.push(message);
+ }
+ }
+
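+ // Convert each merged message into its Copilot Chat counterpart.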
+ let mut messages: Vec<ChatMessage> = Vec::new();
+ for message in request_messages {
+ let text_content = {
+ let mut buffer = String::new();
+ for string in message.content.iter().filter_map(|content| match content {
+ MessageContent::Text(text) => Some(text.as_str()),
+ MessageContent::ToolUse(_)
+ | MessageContent::ToolResult(_)
+ | MessageContent::Image(_) => None,
+ }) {
+ buffer.push_str(string);
+ }
+
+ buffer
+ };
+
+ match message.role {
+ Role::User => {
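+ // Emit any tool results as dedicated `tool` messages ahead of the user's text.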
+ for content in &message.content {
+ if let MessageContent::ToolResult(tool_result) = content {
+ messages.push(ChatMessage::Tool {
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ content: tool_result.content.to_string(),
+ });
+ }
+ }
+
+ messages.push(ChatMessage::User {
+ content: text_content,
+ });
+ }
+ Role::Assistant => {
+ let mut tool_calls = Vec::new();
+ for content in &message.content {
+ if let MessageContent::ToolUse(tool_use) = content {
+ tool_calls.push(ToolCall {
+ id: tool_use.id.to_string(),
+ content: copilot::copilot_chat::ToolCallContent::Function {
+ function: copilot::copilot_chat::FunctionContent {
+ name: tool_use.name.to_string(),
+ arguments: serde_json::to_string(&tool_use.input)?,
+ },
+ },
+ });
+ }
+ }
+
+ messages.push(ChatMessage::Assistant {
+ content: if text_content.is_empty() {
+ None
+ } else {
+ Some(text_content)
+ },
+ tool_calls,
+ });
+ }
+ Role::System => messages.push(ChatMessage::System {
+ content: message.string_contents(),
+ }),
+ }
+ }
+
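+ // Advertise the request's tools as function declarations.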
+ let tools = request
+ .tools
+ .iter()
+ .map(|tool| Tool::Function {
+ function: copilot::copilot_chat::Function {
+ name: tool.name.clone(),
+ description: tool.description.clone(),
+ parameters: tool.input_schema.clone(),
+ },
+ })
+ .collect();
+
+ Ok(CopilotChatRequest {
+ intent: true,
+ n: 1,
+ stream: model.uses_streaming(),
+ temperature: 0.1,
+ model,
+ messages,
+ tools,
+ tool_choice: None,
+ })
}
}