diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 623316916ba0a095047c11d00251e4f9f7be69df..3d1cefa07f07a169f78fbf27c3454f5371ce21b6 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -43,6 +43,8 @@ pub struct AvailableModel { pub max_tokens: usize, pub max_output_tokens: Option, pub max_completion_tokens: Option, + pub supports_tools: Option, + pub supports_images: Option, } pub struct OpenRouterLanguageModelProvider { @@ -227,7 +229,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider { name: model.name.clone(), display_name: model.display_name.clone(), max_tokens: model.max_tokens, - supports_tools: Some(false), + supports_tools: model.supports_tools, + supports_images: model.supports_images, }); } @@ -345,7 +348,7 @@ impl LanguageModel for OpenRouterLanguageModel { } fn supports_images(&self) -> bool { - false + self.model.supports_images.unwrap_or(false) } fn count_tokens( @@ -386,20 +389,26 @@ pub fn into_open_router( max_output_tokens: Option, ) -> open_router::Request { let mut messages = Vec::new(); - for req_message in request.messages { - for content in req_message.content { + for message in request.messages { + for content in message.content { match content { - MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages - .push(match req_message.role { - Role::User => open_router::RequestMessage::User { content: text }, - Role::Assistant => open_router::RequestMessage::Assistant { - content: Some(text), - tool_calls: Vec::new(), - }, - Role::System => open_router::RequestMessage::System { content: text }, - }), + MessageContent::Text(text) | MessageContent::Thinking { text, .. } => { + add_message_content_part( + open_router::MessagePart::Text { text: text }, + message.role, + &mut messages, + ) + } MessageContent::RedactedThinking(_) => {} - MessageContent::Image(_) => {} + MessageContent::Image(image) => { + add_message_content_part( + open_router::MessagePart::Image { + image_url: image.to_base64_url(), + }, + message.role, + &mut messages, + ); + } MessageContent::ToolUse(tool_use) => { let tool_call = open_router::ToolCall { id: tool_use.id.to_string(), @@ -425,16 +434,20 @@ pub fn into_open_router( } MessageContent::ToolResult(tool_result) => { let content = match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - text.to_string() + LanguageModelToolResultContent::Text(text) => { + vec![open_router::MessagePart::Text { + text: text.to_string(), + }] } - LanguageModelToolResultContent::Image(_) => { - "[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string() + LanguageModelToolResultContent::Image(image) => { + vec![open_router::MessagePart::Image { + image_url: image.to_base64_url(), + }] } }; messages.push(open_router::RequestMessage::Tool { - content: content, + content: content.into(), tool_call_id: tool_result.tool_use_id.to_string(), }); } @@ -473,6 +486,42 @@ pub fn into_open_router( } } +fn add_message_content_part( + new_part: open_router::MessagePart, + role: Role, + messages: &mut Vec, +) { + match (role, messages.last_mut()) { + (Role::User, Some(open_router::RequestMessage::User { content })) + | (Role::System, Some(open_router::RequestMessage::System { content })) => { + content.push_part(new_part); + } + ( + Role::Assistant, + Some(open_router::RequestMessage::Assistant { + content: Some(content), + .. + }), + ) => { + content.push_part(new_part); + } + _ => { + messages.push(match role { + Role::User => open_router::RequestMessage::User { + content: open_router::MessageContent::from(vec![new_part]), + }, + Role::Assistant => open_router::RequestMessage::Assistant { + content: Some(open_router::MessageContent::from(vec![new_part])), + tool_calls: Vec::new(), + }, + Role::System => open_router::RequestMessage::System { + content: open_router::MessageContent::from(vec![new_part]), + }, + }); + } + } +} + pub struct OpenRouterEventMapper { tool_calls_by_index: HashMap, } diff --git a/crates/open_router/src/open_router.rs b/crates/open_router/src/open_router.rs index f0fe07150358dd85a46351b8fa2bd56a9f5bb0e6..ad3009b48ff184b664fb9237af5af179af2ad837 100644 --- a/crates/open_router/src/open_router.rs +++ b/crates/open_router/src/open_router.rs @@ -52,6 +52,7 @@ pub struct Model { pub display_name: Option, pub max_tokens: usize, pub supports_tools: Option, + pub supports_images: Option, } impl Model { @@ -61,6 +62,7 @@ impl Model { Some("Auto Router"), Some(2000000), Some(true), + Some(false), ) } @@ -73,12 +75,14 @@ impl Model { display_name: Option<&str>, max_tokens: Option, supports_tools: Option, + supports_images: Option, ) -> Self { Self { name: name.to_owned(), display_name: display_name.map(|s| s.to_owned()), max_tokens: max_tokens.unwrap_or(2000000), supports_tools, + supports_images, } } @@ -154,22 +158,118 @@ pub struct FunctionDefinition { #[serde(tag = "role", rename_all = "lowercase")] pub enum RequestMessage { Assistant { - content: Option, + content: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] tool_calls: Vec, }, User { - content: String, + content: MessageContent, }, System { - content: String, + content: MessageContent, }, Tool { - content: String, + content: MessageContent, tool_call_id: String, }, } +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(untagged)] +pub enum MessageContent { + Plain(String), + Multipart(Vec), +} + +impl MessageContent { + pub fn empty() -> Self { + Self::Plain(String::new()) + } + + pub fn push_part(&mut self, part: MessagePart) { + match self { + Self::Plain(text) if text.is_empty() => { + *self = Self::Multipart(vec![part]); + } + Self::Plain(text) => { + let text_part = MessagePart::Text { + text: std::mem::take(text), + }; + *self = Self::Multipart(vec![text_part, part]); + } + Self::Multipart(parts) => parts.push(part), + } + } +} + +impl From> for MessageContent { + fn from(parts: Vec) -> Self { + if parts.len() == 1 { + if let MessagePart::Text { text } = &parts[0] { + return Self::Plain(text.clone()); + } + } + Self::Multipart(parts) + } +} + +impl From for MessageContent { + fn from(text: String) -> Self { + Self::Plain(text) + } +} + +impl From<&str> for MessageContent { + fn from(text: &str) -> Self { + Self::Plain(text.to_string()) + } +} + +impl MessageContent { + pub fn as_text(&self) -> Option<&str> { + match self { + Self::Plain(text) => Some(text), + Self::Multipart(parts) if parts.len() == 1 => { + if let MessagePart::Text { text } = &parts[0] { + Some(text) + } else { + None + } + } + _ => None, + } + } + + pub fn to_text(&self) -> String { + match self { + Self::Plain(text) => text.clone(), + Self::Multipart(parts) => parts + .iter() + .filter_map(|part| { + if let MessagePart::Text { text } = part { + Some(text.as_str()) + } else { + None + } + }) + .collect::>() + .join(""), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum MessagePart { + Text { + text: String, + }, + #[serde(rename = "image_url")] + Image { + image_url: String, + }, +} + #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ToolCall { pub id: String, @@ -266,6 +366,14 @@ pub struct ModelEntry { pub context_length: Option, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub supported_parameters: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub architecture: Option, +} + +#[derive(Default, Debug, Clone, PartialEq, Deserialize)] +pub struct ModelArchitecture { + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub input_modalities: Vec, } pub async fn complete( @@ -470,6 +578,13 @@ pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result