diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs
index d8fa8967a862053ccf2a820878f450c38ea18fad..41c8a17c2d251e23f7c2d6b27fbd2ff488c1c0e4 100644
--- a/crates/copilot/src/copilot.rs
+++ b/crates/copilot/src/copilot.rs
@@ -1,5 +1,6 @@
 pub mod copilot_chat;
 mod copilot_completion_provider;
+pub mod copilot_responses;
 pub mod request;
 mod sign_in;
diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs
index a6758ce53c0aa18d04dcd376c2e0afb93add6ab5..5d22760942dbbcfd72f1dacb83c249a08f2fe72a 100644
--- a/crates/copilot/src/copilot_chat.rs
+++ b/crates/copilot/src/copilot_chat.rs
@@ -15,6 +15,8 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use itertools::Itertools;
 use paths::home_dir;
 use serde::{Deserialize, Serialize};
+
+use crate::copilot_responses as responses;
 use settings::watch_config_dir;
 
 pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
@@ -42,10 +44,14 @@ impl CopilotChatConfiguration {
         }
     }
 
-    pub fn api_url_from_endpoint(&self, endpoint: &str) -> String {
+    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
         format!("{}/chat/completions", endpoint)
     }
 
+    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
+        format!("{}/responses", endpoint)
+    }
+
     pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
         format!("{}/models", endpoint)
     }
@@ -71,6 +77,14 @@ pub enum Role {
     System,
 }
 
+#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
+pub enum ModelSupportedEndpoint {
+    #[serde(rename = "/chat/completions")]
+    ChatCompletions,
+    #[serde(rename = "/responses")]
+    Responses,
+}
+
 #[derive(Deserialize)]
 struct ModelSchema {
     #[serde(deserialize_with = "deserialize_models_skip_errors")]
@@ -109,6 +123,8 @@ pub struct Model {
     // reached. Zed does not currently implement this behaviour
     is_chat_fallback: bool,
     model_picker_enabled: bool,
+    #[serde(default)]
+    supported_endpoints: Vec<ModelSupportedEndpoint>,
 }
 
 #[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
@@ -224,6 +240,16 @@ impl Model {
     pub fn tokenizer(&self) -> Option<&str> {
         self.capabilities.tokenizer.as_deref()
     }
+
+    pub fn supports_response(&self) -> bool {
+        !self.supported_endpoints.is_empty()
+            && !self
+                .supported_endpoints
+                .contains(&ModelSupportedEndpoint::ChatCompletions)
+            && self
+                .supported_endpoints
+                .contains(&ModelSupportedEndpoint::Responses)
+    }
 }
 
 #[derive(Serialize, Deserialize)]
@@ -253,7 +279,7 @@ pub enum Tool {
     Function { function: Function },
 }
 
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
 pub enum ToolChoice {
     Auto,
@@ -346,7 +372,7 @@ pub struct Usage {
 
 #[derive(Debug, Deserialize)]
 pub struct ResponseChoice {
-    pub index: usize,
+    pub index: Option<usize>,
     pub finish_reason: Option<String>,
     pub delta: Option<ResponseDelta>,
     pub message: Option<ResponseDelta>,
@@ -359,10 +385,9 @@ pub struct ResponseDelta {
     #[serde(default)]
     pub tool_calls: Vec<ToolCallChunk>,
 }
-
 #[derive(Deserialize, Debug, Eq, PartialEq)]
 pub struct ToolCallChunk {
-    pub index: usize,
+    pub index: Option<usize>,
     pub id: Option<String>,
     pub function: Option<FunctionChunk>,
 }
@@ -554,13 +579,47 @@ impl CopilotChat {
         is_user_initiated: bool,
         mut cx: AsyncApp,
     ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
+        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
+
+        let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
+        stream_completion(
+            client.clone(),
+            token.api_key,
+            api_url.into(),
+            request,
+            is_user_initiated,
+        )
+        .await
+    }
+
+    pub async fn stream_response(
+        request: responses::Request,
+        is_user_initiated: bool,
+        mut cx: AsyncApp,
+    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
+        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
+
+        let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
+        responses::stream_response(
+            client.clone(),
+            token.api_key,
+            api_url,
+            request,
+            is_user_initiated,
+        )
+        .await
+    }
+
+    async fn get_auth_details(
+        cx: &mut AsyncApp,
+    ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
         let this = cx
             .update(|cx| Self::global(cx))
             .ok()
             .flatten()
             .context("Copilot chat is not enabled")?;
 
-        let (oauth_token, api_token, client, configuration) = this.read_with(&cx, |this, _| {
+        let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
             (
                 this.oauth_token.clone(),
                 this.api_token.clone(),
@@ -572,12 +631,12 @@ impl CopilotChat {
         let oauth_token = oauth_token.context("No OAuth token available")?;
 
         let token = match api_token {
-            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
+            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
             _ => {
                 let token_url = configuration.token_url();
                 let token =
                     request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
-                this.update(&mut cx, |this, cx| {
+                this.update(cx, |this, cx| {
                     this.api_token = Some(token.clone());
                     cx.notify();
                 })?;
@@ -585,15 +644,7 @@ impl CopilotChat {
             }
         };
 
-        let api_url = configuration.api_url_from_endpoint(&token.api_endpoint);
-        stream_completion(
-            client.clone(),
-            token.api_key,
-            api_url.into(),
-            request,
-            is_user_initiated,
-        )
-        .await
+        Ok((client, token, configuration))
     }
 
     pub fn set_configuration(
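The `supports_response` gating above is the crux of the routing change: a model is sent to `/responses` only when its `/models` entry advertises that endpoint and does not also advertise `/chat/completions`. A minimal, self-contained sketch of the predicate, assuming serde/serde_json as dependencies (the enum mirrors the diff; the JSON payloads are invented, not captured `/models` output):

```rust
use serde::Deserialize;

// Local mirror of ModelSupportedEndpoint from the diff.
#[derive(Deserialize, PartialEq)]
enum Endpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
}

fn supports_response(endpoints: &[Endpoint]) -> bool {
    // Route to /responses only when the model advertises it exclusively:
    // a non-empty list containing Responses but not ChatCompletions.
    !endpoints.is_empty()
        && !endpoints.contains(&Endpoint::ChatCompletions)
        && endpoints.contains(&Endpoint::Responses)
}

fn main() {
    let only_responses: Vec<Endpoint> = serde_json::from_str(r#"["/responses"]"#).unwrap();
    let both: Vec<Endpoint> =
        serde_json::from_str(r#"["/chat/completions", "/responses"]"#).unwrap();

    assert!(supports_response(&only_responses));
    assert!(!supports_response(&both)); // chat completions wins when both are offered
    assert!(!supports_response(&[])); // absent field defaults to empty => old code path
    println!("endpoint gating behaves as expected");
}
```

Because `supported_endpoints` is `#[serde(default)]`, models listings that predate the field deserialize to an empty vector and keep using the existing chat-completions path.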
diff --git a/crates/copilot/src/copilot_responses.rs b/crates/copilot/src/copilot_responses.rs
new file mode 100644
index 0000000000000000000000000000000000000000..c1e066208823dcab34a32096cfa447dd0ec9592f
--- /dev/null
+++ b/crates/copilot/src/copilot_responses.rs
@@ -0,0 +1,414 @@
+use super::*;
+use anyhow::{Result, anyhow};
+use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+pub use settings::OpenAiReasoningEffort as ReasoningEffort;
+
+#[derive(Serialize, Debug)]
+pub struct Request {
+    pub model: String,
+    pub input: Vec<ResponseInputItem>,
+    #[serde(default)]
+    pub stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning: Option<ReasoningConfig>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub include: Option<Vec<ResponseIncludable>>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(rename_all = "snake_case")]
+pub enum ResponseIncludable {
+    #[serde(rename = "reasoning.encrypted_content")]
+    ReasoningEncryptedContent,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ToolDefinition {
+    Function {
+        name: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        description: Option<String>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        parameters: Option<Value>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        strict: Option<bool>,
+    },
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "lowercase")]
+pub enum ToolChoice {
+    Auto,
+    Any,
+    None,
+    #[serde(untagged)]
+    Other(ToolDefinition),
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "lowercase")]
+pub enum ReasoningSummary {
+    Auto,
+    Concise,
+    Detailed,
+}
+
+#[derive(Serialize, Debug)]
+pub struct ReasoningConfig {
+    pub effort: ReasoningEffort,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub summary: Option<ReasoningSummary>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone, Default)]
+#[serde(rename_all = "snake_case")]
+pub enum ResponseImageDetail {
+    Low,
+    High,
+    #[default]
+    Auto,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ResponseInputContent {
+    InputText {
+        text: String,
+    },
+    OutputText {
+        text: String,
+    },
+    InputImage {
+        #[serde(skip_serializing_if = "Option::is_none")]
+        image_url: Option<String>,
+        #[serde(default)]
+        detail: ResponseImageDetail,
+    },
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(rename_all = "snake_case")]
+pub enum ItemStatus {
+    InProgress,
+    Completed,
+    Incomplete,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(untagged)]
+pub enum ResponseFunctionOutput {
+    Text(String),
+    Content(Vec<ResponseInputContent>),
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ResponseInputItem {
+    Message {
+        role: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        content: Option<Vec<ResponseInputContent>>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        status: Option<String>,
+    },
+    FunctionCall {
+        call_id: String,
+        name: String,
+        arguments: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        status: Option<ItemStatus>,
+    },
+    FunctionCallOutput {
+        call_id: String,
+        output: ResponseFunctionOutput,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        status: Option<ItemStatus>,
+    },
+    Reasoning {
+        #[serde(skip_serializing_if = "Option::is_none")]
+        id: Option<String>,
+        summary: Vec<ResponseReasoningItem>,
+        encrypted_content: String,
+    },
+}
+
+#[derive(Deserialize, Debug, Clone)]
+#[serde(rename_all = "snake_case")]
+pub enum IncompleteReason {
+    #[serde(rename = "max_output_tokens")]
+    MaxOutputTokens,
+    #[serde(rename = "content_filter")]
+    ContentFilter,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+pub struct IncompleteDetails {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reason: Option<IncompleteReason>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub struct ResponseReasoningItem {
+    #[serde(rename = "type")]
+    pub kind: String,
+    pub text: String,
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(tag = "type")]
+pub enum StreamEvent {
+    #[serde(rename = "error")]
+    GenericError { error: ResponseError },
+
+    #[serde(rename = "response.created")]
+    Created { response: Response },
+
+    #[serde(rename = "response.output_item.added")]
+    OutputItemAdded {
+        output_index: usize,
+        #[serde(default)]
+        sequence_number: Option<u64>,
+        item: ResponseOutputItem,
+    },
+
+    #[serde(rename = "response.output_text.delta")]
+    OutputTextDelta {
+        item_id: String,
+        output_index: usize,
+        delta: String,
+    },
+
+    #[serde(rename = "response.output_item.done")]
+    OutputItemDone {
+        output_index: usize,
+        #[serde(default)]
+        sequence_number: Option<u64>,
+        item: ResponseOutputItem,
+    },
+
+    #[serde(rename = "response.incomplete")]
+    Incomplete { response: Response },
+
+    #[serde(rename = "response.completed")]
+    Completed { response: Response },
+
+    #[serde(rename = "response.failed")]
+    Failed { response: Response },
+
+    #[serde(other)]
+    Unknown,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+pub struct ResponseError {
+    pub code: String,
+    pub message: String,
+}
+
+#[derive(Deserialize, Debug, Default, Clone)]
+pub struct Response {
+    pub id: Option<String>,
+    pub status: Option<String>,
+    pub usage: Option<ResponseUsage>,
+    pub output: Vec<ResponseOutputItem>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub incomplete_details: Option<IncompleteDetails>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<ResponseError>,
+}
+
+#[derive(Deserialize, Debug, Default, Clone)]
+pub struct ResponseUsage {
+    pub input_tokens: Option<u64>,
+    pub output_tokens: Option<u64>,
+    pub total_tokens: Option<u64>,
+}
+
+#[derive(Deserialize, Debug, Clone)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ResponseOutputItem {
+    Message {
+        id: String,
+        role: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        content: Option<Vec<ResponseOutputContent>>,
+    },
+    FunctionCall {
+        #[serde(skip_serializing_if = "Option::is_none")]
+        id: Option<String>,
+        call_id: String,
+        name: String,
+        arguments: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        status: Option<ItemStatus>,
+    },
+    Reasoning {
+        id: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        summary: Option<Vec<ResponseReasoningItem>>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        encrypted_content: Option<String>,
+    },
+}
+
+#[derive(Deserialize, Debug, Clone)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ResponseOutputContent {
+    OutputText { text: String },
+    Refusal { refusal: String },
+}
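The wire types above lean on serde's internally tagged representation, with a `#[serde(other)]` catch-all so unrecognized `type` tags degrade to `Unknown` instead of failing the whole stream. A trimmed, runnable illustration of that pattern (event names mirror the diff; the payloads are invented):

```rust
use serde::Deserialize;

// Trimmed mirror of the diff's StreamEvent: internally tagged on "type",
// with #[serde(other)] soaking up event kinds the client doesn't model.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
enum Event {
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta { item_id: String, delta: String },
    #[serde(other)]
    Unknown,
}

fn main() {
    let known: Event = serde_json::from_str(
        r#"{"type":"response.output_text.delta","item_id":"msg_1","delta":"Hi"}"#,
    )
    .unwrap();
    println!("{known:?}"); // OutputTextDelta { .. }

    // An event type the client doesn't know about parses as Unknown
    // rather than erroring out and killing the stream.
    let unknown: Event = serde_json::from_str(r#"{"type":"response.queued"}"#).unwrap();
    println!("{unknown:?}"); // Unknown
}
```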
"Option::is_none")] + id: Option, + summary: Vec, + encrypted_content: String, + }, +} + +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +pub enum IncompleteReason { + #[serde(rename = "max_output_tokens")] + MaxOutputTokens, + #[serde(rename = "content_filter")] + ContentFilter, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct IncompleteDetails { + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ResponseReasoningItem { + #[serde(rename = "type")] + pub kind: String, + pub text: String, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type")] +pub enum StreamEvent { + #[serde(rename = "error")] + GenericError { error: ResponseError }, + + #[serde(rename = "response.created")] + Created { response: Response }, + + #[serde(rename = "response.output_item.added")] + OutputItemAdded { + output_index: usize, + #[serde(default)] + sequence_number: Option, + item: ResponseOutputItem, + }, + + #[serde(rename = "response.output_text.delta")] + OutputTextDelta { + item_id: String, + output_index: usize, + delta: String, + }, + + #[serde(rename = "response.output_item.done")] + OutputItemDone { + output_index: usize, + #[serde(default)] + sequence_number: Option, + item: ResponseOutputItem, + }, + + #[serde(rename = "response.incomplete")] + Incomplete { response: Response }, + + #[serde(rename = "response.completed")] + Completed { response: Response }, + + #[serde(rename = "response.failed")] + Failed { response: Response }, + + #[serde(other)] + Unknown, +} + +#[derive(Deserialize, Debug, Clone)] +pub struct ResponseError { + pub code: String, + pub message: String, +} + +#[derive(Deserialize, Debug, Default, Clone)] +pub struct Response { + pub id: Option, + pub status: Option, + pub usage: Option, + pub output: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub incomplete_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Deserialize, Debug, Default, Clone)] +pub struct ResponseUsage { + pub input_tokens: Option, + pub output_tokens: Option, + pub total_tokens: Option, +} + +#[derive(Deserialize, Debug, Clone)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ResponseOutputItem { + Message { + id: String, + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option>, + }, + FunctionCall { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + call_id: String, + name: String, + arguments: String, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, + Reasoning { + id: String, + #[serde(skip_serializing_if = "Option::is_none")] + summary: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + encrypted_content: Option, + }, +} + +#[derive(Deserialize, Debug, Clone)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ResponseOutputContent { + OutputText { text: String }, + Refusal { refusal: String }, +} + +pub async fn stream_response( + client: Arc, + api_key: String, + api_url: String, + request: Request, + is_user_initiated: bool, +) -> Result>> { + let is_vision_request = request.input.iter().any(|item| match item { + ResponseInputItem::Message { + content: Some(parts), + .. + } => parts + .iter() + .any(|p| matches!(p, ResponseInputContent::InputImage { .. 
diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs
index 64a2c65f0d2bcc4240e980922930e24240ce3249..1941bd903951420266ba5c4609cb34c15130224e 100644
--- a/crates/language_models/src/provider/copilot_chat.rs
+++ b/crates/language_models/src/provider/copilot_chat.rs
@@ -15,6 +15,7 @@ use futures::future::BoxFuture;
 use futures::stream::BoxStream;
 use futures::{FutureExt, Stream, StreamExt};
 use gpui::{Action, AnyView, App, AsyncApp, Entity, Render, Subscription, Task, svg};
+use http_client::StatusCode;
 use language::language_settings::all_language_settings;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -306,6 +307,23 @@ impl LanguageModel for CopilotChatLanguageModel {
             | CompletionIntent::EditFile => false,
         });
 
+        if self.model.supports_response() {
+            let responses_request = into_copilot_responses(&self.model, request);
+            let request_limiter = self.request_limiter.clone();
+            let future = cx.spawn(async move |cx| {
+                let request =
+                    CopilotChat::stream_response(responses_request, is_user_initiated, cx.clone());
+                request_limiter
+                    .stream(async move {
+                        let stream = request.await?;
+                        let mapper = CopilotResponsesEventMapper::new();
+                        Ok(mapper.map_stream(stream).boxed())
+                    })
+                    .await
+            });
+            return async move { Ok(future.await?.boxed()) }.boxed();
+        }
+
         let copilot_request = match into_copilot_chat(&self.model, request) {
             Ok(request) => request,
             Err(err) => return futures::future::ready(Err(err.into())).boxed(),
@@ -380,11 +398,9 @@ pub fn map_to_language_model_completion_events(
                             events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                         }
 
-                        for tool_call in &delta.tool_calls {
-                            let entry = state
-                                .tool_calls_by_index
-                                .entry(tool_call.index)
-                                .or_default();
+                        for (index, tool_call) in delta.tool_calls.iter().enumerate() {
+                            let tool_index = tool_call.index.unwrap_or(index);
+                            let entry = state.tool_calls_by_index.entry(tool_index).or_default();
 
                             if let Some(tool_id) = tool_call.id.clone() {
                                 entry.id = tool_id;
@@ -433,11 +449,11 @@ pub fn map_to_language_model_completion_events(
                         match arguments {
                             Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                                 LanguageModelToolUse {
-                                    id: tool_call.id.clone().into(),
+                                    id: tool_call.id.into(),
                                     name: tool_call.name.as_str().into(),
                                     is_input_complete: true,
                                     input,
-                                    raw_input: tool_call.arguments.clone(),
+                                    raw_input: tool_call.arguments,
                                 },
                             )),
                             Err(error) => Ok(
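The mapper change above tolerates tool-call deltas that omit `index` (now `Option<usize>`), falling back to the chunk's position within the delta so distinct calls never collapse into a shared accumulator slot. A small sketch of the grouping rule (the chunk data is invented):

```rust
use std::collections::HashMap;

fn main() {
    // (server-provided index, partial tool name), mirroring ToolCallChunk loosely.
    let chunks: Vec<(Option<usize>, &str)> = vec![
        (Some(0), "get_weather"), // indexed chunk
        (None, "search_web"),     // unindexed: falls back to position 1
    ];

    let mut by_index: HashMap<usize, Vec<&str>> = HashMap::new();
    for (position, (index, name)) in chunks.into_iter().enumerate() {
        // Prefer the server-provided index, otherwise the enumeration position.
        let key = index.unwrap_or(position);
        by_index.entry(key).or_default().push(name);
    }

    assert_eq!(by_index.len(), 2); // the two calls stay separate
    println!("{by_index:?}");
}
```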
@@ -477,6 +493,191 @@ pub fn map_to_language_model_completion_events(
         .flat_map(futures::stream::iter)
 }
 
+pub struct CopilotResponsesEventMapper {
+    pending_stop_reason: Option<StopReason>,
+}
+
+impl CopilotResponsesEventMapper {
+    pub fn new() -> Self {
+        Self {
+            pending_stop_reason: None,
+        }
+    }
+
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<copilot::copilot_responses::StreamEvent>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
+            })
+        })
+    }
+
+    fn map_event(
+        &mut self,
+        event: copilot::copilot_responses::StreamEvent,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        match event {
+            copilot::copilot_responses::StreamEvent::OutputItemAdded { item, .. } => match item {
+                copilot::copilot_responses::ResponseOutputItem::Message { id, .. } => {
+                    vec![Ok(LanguageModelCompletionEvent::StartMessage { message_id: id })]
+                }
+                _ => Vec::new(),
+            },
+
+            copilot::copilot_responses::StreamEvent::OutputTextDelta { delta, .. } => {
+                if delta.is_empty() {
+                    Vec::new()
+                } else {
+                    vec![Ok(LanguageModelCompletionEvent::Text(delta))]
+                }
+            }
+
+            copilot::copilot_responses::StreamEvent::OutputItemDone { item, .. } => match item {
+                copilot::copilot_responses::ResponseOutputItem::Message { .. } => Vec::new(),
+                copilot::copilot_responses::ResponseOutputItem::FunctionCall {
+                    call_id,
+                    name,
+                    arguments,
+                    ..
+                } => {
+                    let mut events = Vec::new();
+                    match serde_json::from_str::<serde_json::Value>(&arguments) {
+                        Ok(input) => events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: call_id.into(),
+                                name: name.as_str().into(),
+                                is_input_complete: true,
+                                input,
+                                raw_input: arguments.clone(),
+                            },
+                        ))),
+                        Err(error) => {
+                            events.push(Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+                                id: call_id.into(),
+                                tool_name: name.as_str().into(),
+                                raw_input: arguments.clone().into(),
+                                json_parse_error: error.to_string(),
+                            }))
+                        }
+                    }
+                    // Record that we already emitted a tool-use stop so we can avoid duplicating
+                    // a Stop event on Completed.
+                    self.pending_stop_reason = Some(StopReason::ToolUse);
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+                    events
+                }
+                copilot::copilot_responses::ResponseOutputItem::Reasoning {
+                    summary,
+                    encrypted_content,
+                    ..
+                } => {
+                    let mut events = Vec::new();
+
+                    if let Some(blocks) = summary {
+                        let mut text = String::new();
+                        for block in blocks {
+                            text.push_str(&block.text);
+                        }
+                        if !text.is_empty() {
+                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
+                                text,
+                                signature: None,
+                            }));
+                        }
+                    }
+
+                    if let Some(data) = encrypted_content {
+                        events.push(Ok(LanguageModelCompletionEvent::RedactedThinking { data }));
+                    }
+
+                    events
+                }
+            },
+
+            copilot::copilot_responses::StreamEvent::Completed { response } => {
+                let mut events = Vec::new();
+                if let Some(usage) = response.usage {
+                    events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                        input_tokens: usage.input_tokens.unwrap_or(0),
+                        output_tokens: usage.output_tokens.unwrap_or(0),
+                        cache_creation_input_tokens: 0,
+                        cache_read_input_tokens: 0,
+                    })));
+                }
+                if self.pending_stop_reason.take() != Some(StopReason::ToolUse) {
+                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+                }
+                events
+            }
+
+            copilot::copilot_responses::StreamEvent::Incomplete { response } => {
+                let reason = response
+                    .incomplete_details
+                    .as_ref()
+                    .and_then(|details| details.reason.as_ref());
+                let stop_reason = match reason {
+                    Some(copilot::copilot_responses::IncompleteReason::MaxOutputTokens) => {
+                        StopReason::MaxTokens
+                    }
+                    Some(copilot::copilot_responses::IncompleteReason::ContentFilter) => {
+                        StopReason::Refusal
+                    }
+                    _ => self
+                        .pending_stop_reason
+                        .take()
+                        .unwrap_or(StopReason::EndTurn),
+                };
+
+                let mut events = Vec::new();
+                if let Some(usage) = response.usage {
+                    events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                        input_tokens: usage.input_tokens.unwrap_or(0),
+                        output_tokens: usage.output_tokens.unwrap_or(0),
+                        cache_creation_input_tokens: 0,
+                        cache_read_input_tokens: 0,
+                    })));
+                }
+                events.push(Ok(LanguageModelCompletionEvent::Stop(stop_reason)));
+                events
+            }
+
+            copilot::copilot_responses::StreamEvent::Failed { response } => {
+                let provider = PROVIDER_NAME;
+                let (status_code, message) = match response.error {
+                    Some(error) => {
+                        let status_code = StatusCode::from_str(&error.code)
+                            .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
+                        (status_code, error.message)
+                    }
+                    None => (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        "response.failed".to_string(),
+                    ),
+                };
+                vec![Err(LanguageModelCompletionError::HttpResponseError {
+                    provider,
+                    status_code,
+                    message,
+                })]
+            }
+
+            copilot::copilot_responses::StreamEvent::GenericError { error } => vec![Err(
+                LanguageModelCompletionError::Other(anyhow!(format!("{error:?}"))),
+            )],
+
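The `pending_stop_reason` bookkeeping above prevents a duplicate stop: a finished tool call emits `Stop(ToolUse)` immediately, and the later `response.completed` only emits `Stop(EndTurn)` if no tool-use stop was recorded. A local sketch of that state machine (the types here are stand-ins, not the `language_model` ones):

```rust
#[derive(Debug, PartialEq, Clone, Copy)]
enum StopReason {
    ToolUse,
    EndTurn,
}

struct Mapper {
    pending_stop_reason: Option<StopReason>,
}

impl Mapper {
    fn on_tool_call_done(&mut self) -> Vec<StopReason> {
        // Stop immediately for the tool call and remember that we did.
        self.pending_stop_reason = Some(StopReason::ToolUse);
        vec![StopReason::ToolUse]
    }

    fn on_completed(&mut self) -> Vec<StopReason> {
        if self.pending_stop_reason.take() == Some(StopReason::ToolUse) {
            Vec::new() // already stopped for the tool call
        } else {
            vec![StopReason::EndTurn]
        }
    }
}

fn main() {
    let mut mapper = Mapper { pending_stop_reason: None };
    let mut stops = mapper.on_tool_call_done();
    stops.extend(mapper.on_completed());
    assert_eq!(stops, vec![StopReason::ToolUse]); // exactly one Stop
    println!("{stops:?}");
}
```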
+            copilot::copilot_responses::StreamEvent::Created { .. }
+            | copilot::copilot_responses::StreamEvent::Unknown => Vec::new(),
+        }
+    }
+}
+
 fn into_copilot_chat(
     model: &copilot::copilot_chat::Model,
     request: LanguageModelRequest,
@@ -635,6 +836,470 @@ fn into_copilot_chat(
     })
 }
 
+fn into_copilot_responses(
+    model: &copilot::copilot_chat::Model,
+    request: LanguageModelRequest,
+) -> copilot::copilot_responses::Request {
+    use copilot::copilot_responses as responses;
+
+    let LanguageModelRequest {
+        thread_id: _,
+        prompt_id: _,
+        intent: _,
+        mode: _,
+        messages,
+        tools,
+        tool_choice,
+        stop: _,
+        temperature,
+        thinking_allowed: _,
+    } = request;
+
+    let mut input_items: Vec<responses::ResponseInputItem> = Vec::new();
+
+    for message in messages {
+        match message.role {
+            Role::User => {
+                for content in &message.content {
+                    if let MessageContent::ToolResult(tool_result) = content {
+                        let output = if let Some(out) = &tool_result.output {
+                            match out {
+                                serde_json::Value::String(s) => {
+                                    responses::ResponseFunctionOutput::Text(s.clone())
+                                }
+                                serde_json::Value::Null => {
+                                    responses::ResponseFunctionOutput::Text(String::new())
+                                }
+                                other => responses::ResponseFunctionOutput::Text(other.to_string()),
+                            }
+                        } else {
+                            match &tool_result.content {
+                                LanguageModelToolResultContent::Text(text) => {
+                                    responses::ResponseFunctionOutput::Text(text.to_string())
+                                }
+                                LanguageModelToolResultContent::Image(image) => {
+                                    if model.supports_vision() {
+                                        responses::ResponseFunctionOutput::Content(vec![
+                                            responses::ResponseInputContent::InputImage {
+                                                image_url: Some(image.to_base64_url()),
+                                                detail: Default::default(),
+                                            },
+                                        ])
+                                    } else {
+                                        debug_panic!(
+                                            "This should be caught at {} level",
+                                            tool_result.tool_name
+                                        );
+                                        responses::ResponseFunctionOutput::Text(
+                                            "[Tool responded with an image, but this model does not support vision]".into(),
+                                        )
+                                    }
+                                }
+                            }
+                        };
+
+                        input_items.push(responses::ResponseInputItem::FunctionCallOutput {
+                            call_id: tool_result.tool_use_id.to_string(),
+                            output,
+                            status: None,
+                        });
+                    }
+                }
+
+                let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
+                for content in &message.content {
+                    match content {
+                        MessageContent::Text(text) => {
+                            parts.push(responses::ResponseInputContent::InputText {
+                                text: text.clone(),
+                            });
+                        }
+                        MessageContent::Image(image) => {
+                            if model.supports_vision() {
+                                parts.push(responses::ResponseInputContent::InputImage {
+                                    image_url: Some(image.to_base64_url()),
+                                    detail: Default::default(),
+                                });
+                            }
+                        }
+                        _ => {}
+                    }
+                }
+
+                if !parts.is_empty() {
+                    input_items.push(responses::ResponseInputItem::Message {
+                        role: "user".into(),
+                        content: Some(parts),
+                        status: None,
+                    });
+                }
+            }
+
+            Role::Assistant => {
+                for content in &message.content {
+                    if let MessageContent::ToolUse(tool_use) = content {
+                        input_items.push(responses::ResponseInputItem::FunctionCall {
+                            call_id: tool_use.id.to_string(),
+                            name: tool_use.name.to_string(),
+                            arguments: tool_use.raw_input.clone(),
+                            status: None,
+                        });
+                    }
+                }
+
+                for content in &message.content {
+                    if let MessageContent::RedactedThinking(data) = content {
+                        input_items.push(responses::ResponseInputItem::Reasoning {
+                            id: None,
+                            summary: Vec::new(),
+                            encrypted_content: data.clone(),
+                        });
+                    }
+                }
+
+                let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
+                for content in &message.content {
+                    match content {
+                        MessageContent::Text(text) => {
+                            parts.push(responses::ResponseInputContent::OutputText {
+                                text: text.clone(),
+                            });
+                        }
+                        MessageContent::Image(_) => {
+                            parts.push(responses::ResponseInputContent::OutputText {
+                                text: "[image omitted]".to_string(),
+                            });
+                        }
+                        _ => {}
+                    }
+                }
+
+                if !parts.is_empty() {
+                    input_items.push(responses::ResponseInputItem::Message {
+                        role: "assistant".into(),
+                        content: Some(parts),
+                        status: Some("completed".into()),
+                    });
+                }
+            }
+
+            Role::System => {
+                let mut parts: Vec<responses::ResponseInputContent> = Vec::new();
+                for content in &message.content {
+                    if let MessageContent::Text(text) = content {
+                        parts.push(responses::ResponseInputContent::InputText {
+                            text: text.clone(),
+                        });
+                    }
+                }
+
+                if !parts.is_empty() {
+                    input_items.push(responses::ResponseInputItem::Message {
+                        role: "system".into(),
+                        content: Some(parts),
+                        status: None,
+                    });
+                }
+            }
+        }
+    }
+
+    let converted_tools: Vec<responses::ToolDefinition> = tools
+        .into_iter()
+        .map(|tool| responses::ToolDefinition::Function {
+            name: tool.name,
+            description: Some(tool.description),
+            parameters: Some(tool.input_schema),
+            strict: None,
+        })
+        .collect();
+
+    let mapped_tool_choice = tool_choice.map(|choice| match choice {
+        LanguageModelToolChoice::Auto => responses::ToolChoice::Auto,
+        LanguageModelToolChoice::Any => responses::ToolChoice::Any,
+        LanguageModelToolChoice::None => responses::ToolChoice::None,
+    });
+
+    responses::Request {
+        model: model.id().to_string(),
+        input: input_items,
+        stream: model.uses_streaming(),
+        temperature,
+        tools: converted_tools,
+        tool_choice: mapped_tool_choice,
+        reasoning: None, // Not yet configurable; would need to be wired through user settings.
+        include: Some(vec![
+            copilot::copilot_responses::ResponseIncludable::ReasoningEncryptedContent,
+        ]),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use copilot::copilot_responses as responses;
+    use futures::StreamExt;
+
+    fn map_events(events: Vec<responses::StreamEvent>) -> Vec<LanguageModelCompletionEvent> {
+        futures::executor::block_on(async {
+            CopilotResponsesEventMapper::new()
+                .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
+                .collect::<Vec<_>>()
+                .await
+                .into_iter()
+                .map(Result::unwrap)
+                .collect()
+        })
+    }
+
+    #[test]
+    fn responses_stream_maps_text_and_usage() {
+        let events = vec![
+            responses::StreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: responses::ResponseOutputItem::Message {
+                    id: "msg_1".into(),
+                    role: "assistant".into(),
+                    content: Some(Vec::new()),
+                },
+            },
+            responses::StreamEvent::OutputTextDelta {
+                item_id: "msg_1".into(),
+                output_index: 0,
+                delta: "Hello".into(),
+            },
+            responses::StreamEvent::Completed {
+                response: responses::Response {
+                    usage: Some(responses::ResponseUsage {
+                        input_tokens: Some(5),
+                        output_tokens: Some(3),
+                        total_tokens: Some(8),
+                    }),
+                    ..Default::default()
+                },
+            },
+        ];
+
+        let mapped = map_events(events);
+        assert!(matches!(
+            mapped[0],
+            LanguageModelCompletionEvent::StartMessage { ref message_id } if message_id == "msg_1"
+        ));
+        assert!(matches!(
+            mapped[1],
+            LanguageModelCompletionEvent::Text(ref text) if text == "Hello"
+        ));
+        assert!(matches!(
+            mapped[2],
+            LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: 5,
+                output_tokens: 3,
+                ..
+            })
+        ));
+        assert!(matches!(
+            mapped[3],
+            LanguageModelCompletionEvent::Stop(StopReason::EndTurn)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_maps_tool_calls() {
+        let events = vec![responses::StreamEvent::OutputItemDone {
+            output_index: 0,
+            sequence_number: None,
+            item: responses::ResponseOutputItem::FunctionCall {
+                id: Some("fn_1".into()),
+                call_id: "call_1".into(),
+                name: "do_it".into(),
+                arguments: "{\"x\":1}".into(),
+                status: None,
+            },
+        }];
+
+        let mapped = map_events(events);
+        assert!(matches!(
+            mapped[0],
+            LanguageModelCompletionEvent::ToolUse(ref use_) if use_.id.to_string() == "call_1" && use_.name.as_ref() == "do_it"
+        ));
+        assert!(matches!(
+            mapped[1],
+            LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_handles_json_parse_error() {
+        let events = vec![responses::StreamEvent::OutputItemDone {
+            output_index: 0,
+            sequence_number: None,
+            item: responses::ResponseOutputItem::FunctionCall {
+                id: Some("fn_1".into()),
+                call_id: "call_1".into(),
+                name: "do_it".into(),
+                arguments: "{not json}".into(),
+                status: None,
+            },
+        }];
+
+        let mapped = map_events(events);
+        assert!(matches!(
+            mapped[0],
+            LanguageModelCompletionEvent::ToolUseJsonParseError { ref id, ref tool_name, .. }
+                if id.to_string() == "call_1" && tool_name.as_ref() == "do_it"
+        ));
+        assert!(matches!(
+            mapped[1],
+            LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_maps_reasoning_summary_and_encrypted_content() {
+        let events = vec![responses::StreamEvent::OutputItemDone {
+            output_index: 0,
+            sequence_number: None,
+            item: responses::ResponseOutputItem::Reasoning {
+                id: "r1".into(),
+                summary: Some(vec![responses::ResponseReasoningItem {
+                    kind: "summary_text".into(),
+                    text: "Chain".into(),
+                }]),
+                encrypted_content: Some("ENC".into()),
+            },
+        }];
+
+        let mapped = map_events(events);
+        assert!(matches!(
+            mapped[0],
+            LanguageModelCompletionEvent::Thinking { ref text, signature: None } if text == "Chain"
+        ));
+        assert!(matches!(
+            mapped[1],
+            LanguageModelCompletionEvent::RedactedThinking { ref data } if data == "ENC"
+        ));
+    }
+
+    #[test]
+    fn responses_stream_handles_incomplete_max_tokens() {
+        let events = vec![responses::StreamEvent::Incomplete {
+            response: responses::Response {
+                usage: Some(responses::ResponseUsage {
+                    input_tokens: Some(10),
+                    output_tokens: Some(0),
+                    total_tokens: Some(10),
+                }),
+                incomplete_details: Some(responses::IncompleteDetails {
+                    reason: Some(responses::IncompleteReason::MaxOutputTokens),
+                }),
+                ..Default::default()
+            },
+        }];
+
+        let mapped = map_events(events);
+        assert!(matches!(
+            mapped[0],
+            LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                input_tokens: 10,
+                output_tokens: 0,
+                ..
+            })
+        ));
+        assert!(matches!(
+            mapped[1],
+            LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_handles_incomplete_content_filter() {
+        let events = vec![responses::StreamEvent::Incomplete {
+            response: responses::Response {
+                usage: None,
+                incomplete_details: Some(responses::IncompleteDetails {
+                    reason: Some(responses::IncompleteReason::ContentFilter),
+                }),
+                ..Default::default()
+            },
+        }];
+
+        let mapped = map_events(events);
+        assert!(matches!(
+            mapped.last().unwrap(),
+            LanguageModelCompletionEvent::Stop(StopReason::Refusal)
+        ));
+    }
+
+    #[test]
+    fn responses_stream_completed_no_duplicate_after_tool_use() {
+        let events = vec![
+            responses::StreamEvent::OutputItemDone {
+                output_index: 0,
+                sequence_number: None,
+                item: responses::ResponseOutputItem::FunctionCall {
+                    id: Some("fn_1".into()),
+                    call_id: "call_1".into(),
+                    name: "do_it".into(),
+                    arguments: "{}".into(),
+                    status: None,
+                },
+            },
+            responses::StreamEvent::Completed {
+                response: responses::Response::default(),
+            },
+        ];
+
+        let mapped = map_events(events);
+
+        let mut stop_count = 0usize;
+        let mut saw_tool_use_stop = false;
+        for event in mapped {
+            if let LanguageModelCompletionEvent::Stop(reason) = event {
+                stop_count += 1;
+                if matches!(reason, StopReason::ToolUse) {
+                    saw_tool_use_stop = true;
+                }
+            }
+        }
+        assert_eq!(stop_count, 1, "should emit exactly one Stop event");
+        assert!(saw_tool_use_stop, "Stop reason should be ToolUse");
+    }
+
+    #[test]
+    fn responses_stream_failed_maps_http_response_error() {
+        let events = vec![responses::StreamEvent::Failed {
+            response: responses::Response {
+                error: Some(responses::ResponseError {
+                    code: "429".into(),
+                    message: "too many requests".into(),
+                }),
+                ..Default::default()
+            },
+        }];
+
+        let mapped_results = futures::executor::block_on(async {
+            CopilotResponsesEventMapper::new()
+                .map_stream(Box::pin(futures::stream::iter(events.into_iter().map(Ok))))
+                .collect::<Vec<_>>()
+                .await
+        });
+
+        assert_eq!(mapped_results.len(), 1);
+        match &mapped_results[0] {
+            Err(LanguageModelCompletionError::HttpResponseError {
+                status_code,
+                message,
+                ..
+            }) => {
+                assert_eq!(*status_code, http_client::StatusCode::TOO_MANY_REQUESTS);
+                assert_eq!(message, "too many requests");
+            }
+            other => panic!("expected HttpResponseError, got {:?}", other),
+        }
+    }
+}
 
 struct ConfigurationView {
     copilot_status: Option<Status>,
     state: Entity<State>,