diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index 281d5fdcd6e8f550bde1497455ddebe984545dc9..0acf61c7e118a1c8e08269eb50dc6be54a9dde10 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -14,6 +14,10 @@ use language_model::{
     TokenUsage, env_var,
 };
 use menu;
+use open_ai::responses::{
+    ResponseFunctionCallItem, ResponseFunctionCallOutputItem, ResponseInputContent,
+    ResponseInputItem, ResponseMessageItem,
+};
 use open_ai::{
     ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent,
     responses::{
@@ -22,7 +26,6 @@ use open_ai::{
     },
     stream_completion,
 };
-use serde_json::{Value, json};
 use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore};
 use std::pin::Pin;
 use std::str::FromStr as _;
@@ -585,9 +588,9 @@ pub fn into_open_ai_response(
 fn append_message_to_response_items(
     message: LanguageModelRequestMessage,
     index: usize,
-    input_items: &mut Vec<Value>,
+    input_items: &mut Vec<ResponseInputItem>,
 ) {
-    let mut content_parts: Vec<Value> = Vec::new();
+    let mut content_parts: Vec<ResponseInputContent> = Vec::new();
 
     for content in message.content {
         match content {
@@ -604,20 +607,20 @@ fn append_message_to_response_items(
             MessageContent::ToolUse(tool_use) => {
                 flush_response_parts(&message.role, index, &mut content_parts, input_items);
                 let call_id = tool_use.id.to_string();
-                input_items.push(json!({
-                    "type": "function_call",
-                    "call_id": call_id,
-                    "name": tool_use.name,
-                    "arguments": tool_use.raw_input,
+                input_items.push(ResponseInputItem::FunctionCall(ResponseFunctionCallItem {
+                    call_id,
+                    name: tool_use.name.to_string(),
+                    arguments: tool_use.raw_input,
                 }));
             }
             MessageContent::ToolResult(tool_result) => {
                 flush_response_parts(&message.role, index, &mut content_parts, input_items);
-                input_items.push(json!({
-                    "type": "function_call_output",
-                    "call_id": tool_result.tool_use_id.to_string(),
-                    "output": tool_result_output(&tool_result),
-                }));
+                input_items.push(ResponseInputItem::FunctionCallOutput(
+                    ResponseFunctionCallOutputItem {
+                        call_id: tool_result.tool_use_id.to_string(),
+                        output: tool_result_output(&tool_result),
+                    },
+                ));
             }
         }
     }
@@ -625,67 +628,59 @@ fn append_message_to_response_items(
     flush_response_parts(&message.role, index, &mut content_parts, input_items);
 }
 
-fn push_response_text_part(role: &Role, text: impl Into<String>, parts: &mut Vec<Value>) {
+fn push_response_text_part(
+    role: &Role,
+    text: impl Into<String>,
+    parts: &mut Vec<ResponseInputContent>,
+) {
     let text = text.into();
     if text.trim().is_empty() {
         return;
     }
     match role {
-        Role::Assistant => parts.push(json!({
-            "type": "output_text",
-            "text": text,
-            "annotations": [],
-        })),
-        _ => parts.push(json!({
-            "type": "input_text",
-            "text": text,
-        })),
+        Role::Assistant => parts.push(ResponseInputContent::OutputText {
+            text,
+            annotations: Vec::new(),
+        }),
+        _ => parts.push(ResponseInputContent::Text { text }),
     }
 }
 
-fn push_response_image_part(role: &Role, image: LanguageModelImage, parts: &mut Vec<Value>) {
+fn push_response_image_part(
+    role: &Role,
+    image: LanguageModelImage,
+    parts: &mut Vec<ResponseInputContent>,
+) {
     match role {
-        Role::Assistant => parts.push(json!({
-            "type": "output_text",
-            "text": "[image omitted]",
-            "annotations": [],
-        })),
-        _ => parts.push(json!({
-            "type": "input_image",
-            "image_url": image.to_base64_url(),
-        })),
+        Role::Assistant => parts.push(ResponseInputContent::OutputText {
+            text: "[image omitted]".to_string(),
+            annotations: Vec::new(),
+        }),
+        _ => parts.push(ResponseInputContent::Image {
+            image_url: image.to_base64_url(),
+        }),
     }
 }
 
 fn flush_response_parts(
     role: &Role,
     _index: usize,
-    parts: &mut Vec<Value>,
-    input_items: &mut Vec<Value>,
+    parts: &mut Vec<ResponseInputContent>,
+    input_items: &mut Vec<ResponseInputItem>,
 ) {
     if parts.is_empty() {
         return;
     }
 
-    let item = match role {
-        Role::Assistant => json!({
-            "type": "message",
-            "role": "assistant",
-            "status": "completed",
-            "content": parts.clone(),
-        }),
-        Role::User => json!({
-            "type": "message",
-            "role": "user",
-            "content": parts.clone(),
-        }),
-        Role::System => json!({
-            "type": "message",
-            "role": "system",
-            "content": parts.clone(),
-        }),
-    };
+    let item = ResponseInputItem::Message(ResponseMessageItem {
+        role: match role {
+            Role::User => open_ai::Role::User,
+            Role::Assistant => open_ai::Role::Assistant,
+            Role::System => open_ai::Role::System,
+        },
+        content: parts.clone(),
+    });
 
     input_items.push(item);
     parts.clear();
 }
@@ -1358,7 +1353,6 @@ impl Render for ConfigurationView {
 
 #[cfg(test)]
 mod tests {
-    use super::*;
     use futures::{StreamExt, executor::block_on};
     use gpui::TestAppContext;
     use language_model::{LanguageModelRequestMessage, LanguageModelRequestTool};
@@ -1367,6 +1361,9 @@ mod tests {
         ResponseSummary, ResponseUsage, StreamEvent as ResponsesStreamEvent,
     };
     use pretty_assertions::assert_eq;
+    use serde_json::json;
+
+    use super::*;
 
     fn map_response_events(events: Vec<ResponsesStreamEvent>) -> Vec<String> {
         block_on(async {
@@ -1587,7 +1584,6 @@
                 {
                     "type": "message",
                     "role": "assistant",
-                    "status": "completed",
                     "content": [
                         { "type": "output_text", "text": "Looking that up.", "annotations": [] }
                     ]
diff --git a/crates/open_ai/src/responses.rs b/crates/open_ai/src/responses.rs
index e135f19fcb69254a81ad047221751f828f7d1f33..9196b4a11fbaeeabb9ebe7e59cf106c4d260c267 100644
--- a/crates/open_ai/src/responses.rs
+++ b/crates/open_ai/src/responses.rs
@@ -4,13 +4,13 @@
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use crate::{ReasoningEffort, RequestError, ToolChoice};
+use crate::{ReasoningEffort, RequestError, Role, ToolChoice};
 
 #[derive(Serialize, Debug)]
 pub struct Request {
     pub model: String,
     #[serde(skip_serializing_if = "Vec::is_empty")]
-    pub input: Vec<Value>,
+    pub input: Vec<ResponseInputItem>,
     #[serde(default)]
     pub stream: bool,
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -31,6 +31,50 @@ pub struct Request {
     pub reasoning: Option<ReasoningConfig>,
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ResponseInputItem {
+    Message(ResponseMessageItem),
+    FunctionCall(ResponseFunctionCallItem),
+    FunctionCallOutput(ResponseFunctionCallOutputItem),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ResponseMessageItem {
+    pub role: Role,
+    pub content: Vec<ResponseInputContent>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ResponseFunctionCallItem {
+    pub call_id: String,
+    pub name: String,
+    pub arguments: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ResponseFunctionCallOutputItem {
+    pub call_id: String,
+    pub output: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum ResponseInputContent {
+    #[serde(rename = "input_text")]
+    Text { text: String },
+    #[serde(rename = "input_image")]
+    Image { image_url: String },
+    #[serde(rename = "output_text")]
+    OutputText {
+        text: String,
+        #[serde(default)]
+        annotations: Vec<Value>,
+    },
+    #[serde(rename = "refusal")]
+    Refusal { refusal: String },
+}
+
 #[derive(Serialize, Debug)]
 pub struct ReasoningConfig {
     pub effort: ReasoningEffort,