@@ -43,6 +43,8 @@ pub struct AvailableModel {
pub max_tokens: usize,
pub max_output_tokens: Option<u32>,
pub max_completion_tokens: Option<u32>,
+ pub supports_tools: Option<bool>,
+ pub supports_images: Option<bool>,
}
pub struct OpenRouterLanguageModelProvider {
@@ -227,7 +229,8 @@ impl LanguageModelProvider for OpenRouterLanguageModelProvider {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
- supports_tools: Some(false),
+ supports_tools: model.supports_tools,
+ supports_images: model.supports_images,
});
}
@@ -345,7 +348,7 @@ impl LanguageModel for OpenRouterLanguageModel {
}
fn supports_images(&self) -> bool {
- false
+ self.model.supports_images.unwrap_or(false)
}
fn count_tokens(
@@ -386,20 +389,26 @@ pub fn into_open_router(
max_output_tokens: Option<u32>,
) -> open_router::Request {
let mut messages = Vec::new();
- for req_message in request.messages {
- for content in req_message.content {
+ for message in request.messages {
+ for content in message.content {
match content {
- MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
- .push(match req_message.role {
- Role::User => open_router::RequestMessage::User { content: text },
- Role::Assistant => open_router::RequestMessage::Assistant {
- content: Some(text),
- tool_calls: Vec::new(),
- },
- Role::System => open_router::RequestMessage::System { content: text },
- }),
+ MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
+ add_message_content_part(
+ open_router::MessagePart::Text { text: text },
+ message.role,
+ &mut messages,
+ )
+ }
MessageContent::RedactedThinking(_) => {}
- MessageContent::Image(_) => {}
+ MessageContent::Image(image) => {
+ add_message_content_part(
+ open_router::MessagePart::Image {
+ image_url: image.to_base64_url(),
+ },
+ message.role,
+ &mut messages,
+ );
+ }
MessageContent::ToolUse(tool_use) => {
let tool_call = open_router::ToolCall {
id: tool_use.id.to_string(),
@@ -425,16 +434,20 @@ pub fn into_open_router(
}
MessageContent::ToolResult(tool_result) => {
let content = match &tool_result.content {
- LanguageModelToolResultContent::Text(text) => {
- text.to_string()
+ LanguageModelToolResultContent::Text(text) => {
+ vec![open_router::MessagePart::Text {
+ text: text.to_string(),
+ }]
}
- LanguageModelToolResultContent::Image(_) => {
- "[Tool responded with an image, but Zed doesn't support these in Open AI models yet]".to_string()
+ LanguageModelToolResultContent::Image(image) => {
+ vec![open_router::MessagePart::Image {
+ image_url: image.to_base64_url(),
+ }]
}
};
messages.push(open_router::RequestMessage::Tool {
- content: content,
+ content: content.into(),
tool_call_id: tool_result.tool_use_id.to_string(),
});
}
@@ -473,6 +486,42 @@ pub fn into_open_router(
}
}
+fn add_message_content_part(
+ new_part: open_router::MessagePart,
+ role: Role,
+ messages: &mut Vec<open_router::RequestMessage>,
+) {
+ match (role, messages.last_mut()) {
+ (Role::User, Some(open_router::RequestMessage::User { content }))
+ | (Role::System, Some(open_router::RequestMessage::System { content })) => {
+ content.push_part(new_part);
+ }
+ (
+ Role::Assistant,
+ Some(open_router::RequestMessage::Assistant {
+ content: Some(content),
+ ..
+ }),
+ ) => {
+ content.push_part(new_part);
+ }
+ _ => {
+ messages.push(match role {
+ Role::User => open_router::RequestMessage::User {
+ content: open_router::MessageContent::from(vec![new_part]),
+ },
+ Role::Assistant => open_router::RequestMessage::Assistant {
+ content: Some(open_router::MessageContent::from(vec![new_part])),
+ tool_calls: Vec::new(),
+ },
+ Role::System => open_router::RequestMessage::System {
+ content: open_router::MessageContent::from(vec![new_part]),
+ },
+ });
+ }
+ }
+}
+
pub struct OpenRouterEventMapper {
tool_calls_by_index: HashMap<usize, RawToolCall>,
}
@@ -52,6 +52,7 @@ pub struct Model {
pub display_name: Option<String>,
pub max_tokens: usize,
pub supports_tools: Option<bool>,
+ pub supports_images: Option<bool>,
}
impl Model {
@@ -61,6 +62,7 @@ impl Model {
Some("Auto Router"),
Some(2000000),
Some(true),
+ Some(false),
)
}
@@ -73,12 +75,14 @@ impl Model {
display_name: Option<&str>,
max_tokens: Option<usize>,
supports_tools: Option<bool>,
+ supports_images: Option<bool>,
) -> Self {
Self {
name: name.to_owned(),
display_name: display_name.map(|s| s.to_owned()),
max_tokens: max_tokens.unwrap_or(2000000),
supports_tools,
+ supports_images,
}
}
@@ -154,22 +158,118 @@ pub struct FunctionDefinition {
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
Assistant {
- content: Option<String>,
+ content: Option<MessageContent>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>,
},
User {
- content: String,
+ content: MessageContent,
},
System {
- content: String,
+ content: MessageContent,
},
Tool {
- content: String,
+ content: MessageContent,
tool_call_id: String,
},
}
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(untagged)]
+pub enum MessageContent {
+ Plain(String),
+ Multipart(Vec<MessagePart>),
+}
+
+impl MessageContent {
+ pub fn empty() -> Self {
+ Self::Plain(String::new())
+ }
+
+ pub fn push_part(&mut self, part: MessagePart) {
+ match self {
+ Self::Plain(text) if text.is_empty() => {
+ *self = Self::Multipart(vec![part]);
+ }
+ Self::Plain(text) => {
+ let text_part = MessagePart::Text {
+ text: std::mem::take(text),
+ };
+ *self = Self::Multipart(vec![text_part, part]);
+ }
+ Self::Multipart(parts) => parts.push(part),
+ }
+ }
+}
+
+impl From<Vec<MessagePart>> for MessageContent {
+ fn from(parts: Vec<MessagePart>) -> Self {
+ if parts.len() == 1 {
+ if let MessagePart::Text { text } = &parts[0] {
+ return Self::Plain(text.clone());
+ }
+ }
+ Self::Multipart(parts)
+ }
+}
+
+impl From<String> for MessageContent {
+ fn from(text: String) -> Self {
+ Self::Plain(text)
+ }
+}
+
+impl From<&str> for MessageContent {
+ fn from(text: &str) -> Self {
+ Self::Plain(text.to_string())
+ }
+}
+
+impl MessageContent {
+ pub fn as_text(&self) -> Option<&str> {
+ match self {
+ Self::Plain(text) => Some(text),
+ Self::Multipart(parts) if parts.len() == 1 => {
+ if let MessagePart::Text { text } = &parts[0] {
+ Some(text)
+ } else {
+ None
+ }
+ }
+ _ => None,
+ }
+ }
+
+ pub fn to_text(&self) -> String {
+ match self {
+ Self::Plain(text) => text.clone(),
+ Self::Multipart(parts) => parts
+ .iter()
+ .filter_map(|part| {
+ if let MessagePart::Text { text } = part {
+ Some(text.as_str())
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>()
+ .join(""),
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum MessagePart {
+ Text {
+ text: String,
+ },
+ #[serde(rename = "image_url")]
+ Image {
+ image_url: String,
+ },
+}
+
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
pub id: String,
@@ -266,6 +366,14 @@ pub struct ModelEntry {
pub context_length: Option<usize>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub supported_parameters: Vec<String>,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub architecture: Option<ModelArchitecture>,
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
+pub struct ModelArchitecture {
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub input_modalities: Vec<String>,
}
pub async fn complete(
@@ -470,6 +578,13 @@ pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<M
),
max_tokens: entry.context_length.unwrap_or(2000000),
supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
+ supports_images: Some(
+ entry
+ .architecture
+ .as_ref()
+ .map(|arch| arch.input_modalities.contains(&"image".to_string()))
+ .unwrap_or(false),
+ ),
})
.collect();