Add image input support for OpenAI models (#30639)

Agus Zubiaga created

Release Notes:

- Added image input support for OpenAI models

Change summary

crates/anthropic/src/anthropic.rs                |   2 
crates/copilot/src/copilot_chat.rs               |   6 
crates/language_models/src/provider/anthropic.rs |   2 
crates/language_models/src/provider/open_ai.rs   |  79 +++++++++--
crates/open_ai/src/open_ai.rs                    | 116 ++++++++++++++---
5 files changed, 162 insertions(+), 43 deletions(-)

Detailed changes

crates/anthropic/src/anthropic.rs 🔗

@@ -543,7 +543,7 @@ pub enum RequestContent {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(untagged)]
 pub enum ToolResultContent {
-    JustText(String),
+    Plain(String),
     Multipart(Vec<ToolResultPart>),
 }
 

crates/copilot/src/copilot_chat.rs 🔗

@@ -217,7 +217,7 @@ pub enum ChatMessage {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(untagged)]
 pub enum ChatMessageContent {
-    OnlyText(String),
+    Plain(String),
     Multipart(Vec<ChatMessagePart>),
 }
 
@@ -230,7 +230,7 @@ impl ChatMessageContent {
 impl From<Vec<ChatMessagePart>> for ChatMessageContent {
     fn from(mut parts: Vec<ChatMessagePart>) -> Self {
         if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
-            ChatMessageContent::OnlyText(std::mem::take(text))
+            ChatMessageContent::Plain(std::mem::take(text))
         } else {
             ChatMessageContent::Multipart(parts)
         }
@@ -239,7 +239,7 @@ impl From<Vec<ChatMessagePart>> for ChatMessageContent {
 
 impl From<String> for ChatMessageContent {
     fn from(text: String) -> Self {
-        ChatMessageContent::OnlyText(text)
+        ChatMessageContent::Plain(text)
     }
 }
 

crates/language_models/src/provider/anthropic.rs 🔗

@@ -589,7 +589,7 @@ pub fn into_anthropic(
                                 is_error: tool_result.is_error,
                                 content: match tool_result.content {
                                     LanguageModelToolResultContent::Text(text) => {
-                                        ToolResultContent::JustText(text.to_string())
+                                        ToolResultContent::Plain(text.to_string())
                                     }
                                     LanguageModelToolResultContent::Image(image) => {
                                         ToolResultContent::Multipart(vec![ToolResultPart::Image {

crates/language_models/src/provider/open_ai.rs 🔗

@@ -15,7 +15,7 @@ use language_model::{
     LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
     RateLimiter, Role, StopReason,
 };
-use open_ai::{Model, ResponseStreamEvent, stream_completion};
+use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
@@ -362,17 +362,26 @@ pub fn into_open_ai(
     for message in request.messages {
         for content in message.content {
             match content {
-                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
-                    .push(match message.role {
-                        Role::User => open_ai::RequestMessage::User { content: text },
-                        Role::Assistant => open_ai::RequestMessage::Assistant {
-                            content: Some(text),
-                            tool_calls: Vec::new(),
-                        },
-                        Role::System => open_ai::RequestMessage::System { content: text },
-                    }),
+                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
+                    add_message_content_part(
+                        open_ai::MessagePart::Text { text: text },
+                        message.role,
+                        &mut messages,
+                    )
+                }
                 MessageContent::RedactedThinking(_) => {}
-                MessageContent::Image(_) => {}
+                MessageContent::Image(image) => {
+                    add_message_content_part(
+                        open_ai::MessagePart::Image {
+                            image_url: ImageUrl {
+                                url: image.to_base64_url(),
+                                detail: None,
+                            },
+                        },
+                        message.role,
+                        &mut messages,
+                    );
+                }
                 MessageContent::ToolUse(tool_use) => {
                     let tool_call = open_ai::ToolCall {
                         id: tool_use.id.to_string(),
@@ -391,22 +400,30 @@ pub fn into_open_ai(
                         tool_calls.push(tool_call);
                     } else {
                         messages.push(open_ai::RequestMessage::Assistant {
-                            content: None,
+                            content: open_ai::MessageContent::empty(),
                             tool_calls: vec![tool_call],
                         });
                     }
                 }
                 MessageContent::ToolResult(tool_result) => {
                     let content = match &tool_result.content {
-                        LanguageModelToolResultContent::Text(text) => text.to_string(),
-                        LanguageModelToolResultContent::Image(_) => {
-                            // TODO: Open AI image support
-                            "[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string()
+                        LanguageModelToolResultContent::Text(text) => {
+                            vec![open_ai::MessagePart::Text {
+                                text: text.to_string(),
+                            }]
+                        }
+                        LanguageModelToolResultContent::Image(image) => {
+                            vec![open_ai::MessagePart::Image {
+                                image_url: ImageUrl {
+                                    url: image.to_base64_url(),
+                                    detail: None,
+                                },
+                            }]
                         }
                     };
 
                     messages.push(open_ai::RequestMessage::Tool {
-                        content,
+                        content: content.into(),
                         tool_call_id: tool_result.tool_use_id.to_string(),
                     });
                 }
@@ -446,6 +463,34 @@ pub fn into_open_ai(
     }
 }
 
+/// Appends `new_part` to the trailing request message when that message has
+/// the same `role`; otherwise starts a new message seeded with the part.
+/// This merges consecutive content chunks (text and images) for one role
+/// into a single multipart message instead of emitting one message per part.
+fn add_message_content_part(
+    new_part: open_ai::MessagePart,
+    role: Role,
+    messages: &mut Vec<open_ai::RequestMessage>,
+) {
+    match (role, messages.last_mut()) {
+        (Role::User, Some(open_ai::RequestMessage::User { content }))
+        | (Role::Assistant, Some(open_ai::RequestMessage::Assistant { content, .. }))
+        | (Role::System, Some(open_ai::RequestMessage::System { content })) => {
+            content.push_part(new_part);
+        }
+        _ => {
+            // No trailing message with this role: start a new one. The new
+            // message must carry `new_part` — pushing an empty content body
+            // here would silently drop the first part of every message.
+            messages.push(match role {
+                Role::User => open_ai::RequestMessage::User {
+                    content: open_ai::MessageContent::from(vec![new_part]),
+                },
+                Role::Assistant => open_ai::RequestMessage::Assistant {
+                    content: open_ai::MessageContent::from(vec![new_part]),
+                    tool_calls: Vec::new(),
+                },
+                Role::System => open_ai::RequestMessage::System {
+                    content: open_ai::MessageContent::from(vec![new_part]),
+                },
+            });
+        }
+    }
+}
+
 pub struct OpenAiEventMapper {
     tool_calls_by_index: HashMap<usize, RawToolCall>,
 }

crates/open_ai/src/open_ai.rs 🔗

@@ -278,22 +278,75 @@ pub struct FunctionDefinition {
 #[serde(tag = "role", rename_all = "lowercase")]
 pub enum RequestMessage {
     Assistant {
-        content: Option<String>,
+        content: MessageContent,
         #[serde(default, skip_serializing_if = "Vec::is_empty")]
         tool_calls: Vec<ToolCall>,
     },
     User {
-        content: String,
+        content: MessageContent,
     },
     System {
-        content: String,
+        content: MessageContent,
     },
     Tool {
-        content: String,
+        content: MessageContent,
         tool_call_id: String,
     },
 }
 
+/// Body of an OpenAI chat message: either the classic bare string form
+/// (`"content": "..."`) or a list of typed parts (text and/or images).
+/// `#[serde(untagged)]` lets both shapes round-trip through the API.
+#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
+#[serde(untagged)]
+pub enum MessageContent {
+    Plain(String),
+    Multipart(Vec<MessagePart>),
+}
+
+impl MessageContent {
+    /// An empty multipart body; parts are appended via [`MessageContent::push_part`].
+    pub fn empty() -> Self {
+        MessageContent::Multipart(vec![])
+    }
+
+    /// Appends `part`, upgrading `Plain` to `Multipart` when needed. A lone
+    /// text part collapses back to `Plain` so text-only messages keep
+    /// serializing as a plain string.
+    pub fn push_part(&mut self, part: MessagePart) {
+        match self {
+            MessageContent::Plain(text) => {
+                // Take the existing text rather than cloning it, avoiding an
+                // extra allocation when upgrading to multipart.
+                let text = std::mem::take(text);
+                *self = MessageContent::Multipart(vec![MessagePart::Text { text }, part]);
+            }
+            MessageContent::Multipart(parts) if parts.is_empty() => match part {
+                MessagePart::Text { text } => *self = MessageContent::Plain(text),
+                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
+            },
+            MessageContent::Multipart(parts) => parts.push(part),
+        }
+    }
+}
+
+/// Collapses a single text part into the plain-string form (the shape the
+/// API has always accepted); anything else stays multipart.
+impl From<Vec<MessagePart>> for MessageContent {
+    fn from(mut parts: Vec<MessagePart>) -> Self {
+        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
+            MessageContent::Plain(std::mem::take(text))
+        } else {
+            MessageContent::Multipart(parts)
+        }
+    }
+}
+
+/// One element of a multipart message body. Serialized with a `"type"` tag
+/// (`"text"` or `"image_url"`) as the OpenAI chat API expects.
+#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
+#[serde(tag = "type")]
+pub enum MessagePart {
+    #[serde(rename = "text")]
+    Text { text: String },
+    #[serde(rename = "image_url")]
+    Image { image_url: ImageUrl },
+}
+
+/// Image reference for an `image_url` part: a URL or base64 data URL (the
+/// callers in this patch pass `image.to_base64_url()`), plus an optional
+/// detail hint (presumably "low"/"high"/"auto" per the OpenAI API — the
+/// callers always pass `None`, so it is omitted from the JSON).
+#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
+pub struct ImageUrl {
+    pub url: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub detail: Option<String>,
+}
+
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 pub struct ToolCall {
     pub id: String,
@@ -509,24 +562,45 @@ fn adapt_response_to_stream(response: Response) -> ResponseStreamEvent {
         choices: response
             .choices
             .into_iter()
-            .map(|choice| ChoiceDelta {
-                index: choice.index,
-                delta: ResponseMessageDelta {
-                    role: Some(match choice.message {
-                        RequestMessage::Assistant { .. } => Role::Assistant,
-                        RequestMessage::User { .. } => Role::User,
-                        RequestMessage::System { .. } => Role::System,
-                        RequestMessage::Tool { .. } => Role::Tool,
-                    }),
-                    content: match choice.message {
-                        RequestMessage::Assistant { content, .. } => content,
-                        RequestMessage::User { content } => Some(content),
-                        RequestMessage::System { content } => Some(content),
-                        RequestMessage::Tool { content, .. } => Some(content),
+            .map(|choice| {
+                let content = match &choice.message {
+                    RequestMessage::Assistant { content, .. } => content,
+                    RequestMessage::User { content } => content,
+                    RequestMessage::System { content } => content,
+                    RequestMessage::Tool { content, .. } => content,
+                };
+
+                let mut text_content = String::new();
+                match content {
+                    MessageContent::Plain(text) => text_content.push_str(&text),
+                    MessageContent::Multipart(parts) => {
+                        for part in parts {
+                            match part {
+                                MessagePart::Text { text } => text_content.push_str(&text),
+                                MessagePart::Image { .. } => {}
+                            }
+                        }
+                    }
+                };
+
+                ChoiceDelta {
+                    index: choice.index,
+                    delta: ResponseMessageDelta {
+                        role: Some(match choice.message {
+                            RequestMessage::Assistant { .. } => Role::Assistant,
+                            RequestMessage::User { .. } => Role::User,
+                            RequestMessage::System { .. } => Role::System,
+                            RequestMessage::Tool { .. } => Role::Tool,
+                        }),
+                        content: if text_content.is_empty() {
+                            None
+                        } else {
+                            Some(text_content)
+                        },
+                        tool_calls: None,
                     },
-                    tool_calls: None,
-                },
-                finish_reason: choice.finish_reason,
+                    finish_reason: choice.finish_reason,
+                }
             })
             .collect(),
         usage: Some(response.usage),