@@ -18,6 +18,8 @@ use language_model::{
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
+use std::collections::HashMap;
+use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use strum::IntoEnumIterator;
@@ -27,9 +29,6 @@ use util::ResultExt;
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
-use std::collections::HashMap;
-use std::pin::Pin;
-
const PROVIDER_ID: &str = "mistral";
const PROVIDER_NAME: &str = "Mistral";
@@ -48,6 +47,7 @@ pub struct AvailableModel {
pub max_output_tokens: Option<u32>,
pub max_completion_tokens: Option<u32>,
pub supports_tools: Option<bool>,
+ pub supports_images: Option<bool>,
}
pub struct MistralLanguageModelProvider {
@@ -215,6 +215,7 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
max_output_tokens: model.max_output_tokens,
max_completion_tokens: model.max_completion_tokens,
supports_tools: model.supports_tools,
+ supports_images: model.supports_images,
},
);
}
@@ -314,7 +315,7 @@ impl LanguageModel for MistralLanguageModel {
}
fn supports_images(&self) -> bool {
- false
+ self.model.supports_images()
}
fn telemetry_id(&self) -> String {
@@ -389,58 +390,113 @@ pub fn into_mistral(
let stream = true;
let mut messages = Vec::new();
- for message in request.messages {
- for content in message.content {
- match content {
- MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
- .push(match message.role {
- Role::User => mistral::RequestMessage::User { content: text },
- Role::Assistant => mistral::RequestMessage::Assistant {
- content: Some(text),
- tool_calls: Vec::new(),
- },
- Role::System => mistral::RequestMessage::System { content: text },
- }),
- MessageContent::RedactedThinking(_) => {}
- MessageContent::Image(_) => {}
- MessageContent::ToolUse(tool_use) => {
- let tool_call = mistral::ToolCall {
- id: tool_use.id.to_string(),
- content: mistral::ToolCallContent::Function {
- function: mistral::FunctionContent {
- name: tool_use.name.to_string(),
- arguments: serde_json::to_string(&tool_use.input)
- .unwrap_or_default(),
- },
- },
- };
-
- if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
- messages.last_mut()
- {
- tool_calls.push(tool_call);
- } else {
- messages.push(mistral::RequestMessage::Assistant {
- content: None,
- tool_calls: vec![tool_call],
- });
+ for message in &request.messages {
+ match message.role {
+ Role::User => {
+ let mut message_content = mistral::MessageContent::empty();
+ for content in &message.content {
+ match content {
+ MessageContent::Text(text) => {
+ message_content
+ .push_part(mistral::MessagePart::Text { text: text.clone() });
+ }
+ MessageContent::Image(image_content) => {
+ message_content.push_part(mistral::MessagePart::ImageUrl {
+ image_url: image_content.to_base64_url(),
+ });
+ }
+ MessageContent::Thinking { text, .. } => {
+ message_content
+ .push_part(mistral::MessagePart::Text { text: text.clone() });
+ }
+ MessageContent::RedactedThinking(_) => {}
+ MessageContent::ToolUse(_) | MessageContent::ToolResult(_) => {
+                        // Tool calls never appear in user messages; tool results are
+                        // converted to Tool messages in a second pass below.
+ }
}
}
- MessageContent::ToolResult(tool_result) => {
- let content = match &tool_result.content {
- LanguageModelToolResultContent::Text(text) => text.to_string(),
- LanguageModelToolResultContent::Image(_) => {
- // TODO: Mistral image support
- "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
- }
- };
-
- messages.push(mistral::RequestMessage::Tool {
- content,
- tool_call_id: tool_result.tool_use_id.to_string(),
+ if !matches!(message_content, mistral::MessageContent::Plain { ref content } if content.is_empty())
+ {
+ messages.push(mistral::RequestMessage::User {
+ content: message_content,
});
}
}
+ Role::Assistant => {
+ for content in &message.content {
+ match content {
+ MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
+ messages.push(mistral::RequestMessage::Assistant {
+ content: Some(text.clone()),
+ tool_calls: Vec::new(),
+ });
+ }
+ MessageContent::RedactedThinking(_) => {}
+ MessageContent::Image(_) => {}
+ MessageContent::ToolUse(tool_use) => {
+ let tool_call = mistral::ToolCall {
+ id: tool_use.id.to_string(),
+ content: mistral::ToolCallContent::Function {
+ function: mistral::FunctionContent {
+ name: tool_use.name.to_string(),
+ arguments: serde_json::to_string(&tool_use.input)
+ .unwrap_or_default(),
+ },
+ },
+ };
+
+ if let Some(mistral::RequestMessage::Assistant { tool_calls, .. }) =
+ messages.last_mut()
+ {
+ tool_calls.push(tool_call);
+ } else {
+ messages.push(mistral::RequestMessage::Assistant {
+ content: None,
+ tool_calls: vec![tool_call],
+ });
+ }
+ }
+ MessageContent::ToolResult(_) => {
+                        // Tool results don't belong in assistant messages; any present
+                        // are picked up by the tool-result pass below.
+ }
+ }
+ }
+ }
+ Role::System => {
+ for content in &message.content {
+ match content {
+ MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
+ messages.push(mistral::RequestMessage::System {
+ content: text.clone(),
+ });
+ }
+ MessageContent::RedactedThinking(_) => {}
+ MessageContent::Image(_)
+ | MessageContent::ToolUse(_)
+ | MessageContent::ToolResult(_) => {
+ // Images and tools are not supported in System messages
+ }
+ }
+ }
+ }
+ }
+ }
+
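+    // Second pass: the Mistral API carries tool output in dedicated `tool`-role
+    // messages, so each tool result is emitted as its own RequestMessage::Tool.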
+ for message in &request.messages {
+ for content in &message.content {
+ if let MessageContent::ToolResult(tool_result) = content {
+ let content = match &tool_result.content {
+ LanguageModelToolResultContent::Text(text) => text.to_string(),
+ LanguageModelToolResultContent::Image(_) => {
+ "[Tool responded with an image, but Zed doesn't support these in Mistral models yet]".to_string()
+ }
+ };
+
+ messages.push(mistral::RequestMessage::Tool {
+ content,
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ });
+ }
}
}
@@ -819,62 +875,88 @@ impl Render for ConfigurationView {
#[cfg(test)]
mod tests {
use super::*;
- use language_model;
+ use language_model::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};
#[test]
- fn test_into_mistral_conversion() {
- let request = language_model::LanguageModelRequest {
+ fn test_into_mistral_basic_conversion() {
+ let request = LanguageModelRequest {
messages: vec![
- language_model::LanguageModelRequestMessage {
- role: language_model::Role::System,
- content: vec![language_model::MessageContent::Text(
- "You are a helpful assistant.".to_string(),
- )],
+ LanguageModelRequestMessage {
+ role: Role::System,
+ content: vec![MessageContent::Text("System prompt".into())],
cache: false,
},
- language_model::LanguageModelRequestMessage {
- role: language_model::Role::User,
- content: vec![language_model::MessageContent::Text(
- "Hello, how are you?".to_string(),
- )],
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::Text("Hello".into())],
cache: false,
},
],
- temperature: Some(0.7),
- tools: Vec::new(),
+ temperature: Some(0.5),
+ tools: vec![],
tool_choice: None,
thread_id: None,
prompt_id: None,
intent: None,
mode: None,
- stop: Vec::new(),
+ stop: vec![],
};
- let model_name = "mistral-medium-latest".to_string();
- let max_output_tokens = Some(1000);
- let mistral_request = into_mistral(request, model_name, max_output_tokens);
+ let mistral_request = into_mistral(request, "mistral-small-latest".into(), None);
- assert_eq!(mistral_request.model, "mistral-medium-latest");
- assert_eq!(mistral_request.temperature, Some(0.7));
- assert_eq!(mistral_request.max_tokens, Some(1000));
+ assert_eq!(mistral_request.model, "mistral-small-latest");
+ assert_eq!(mistral_request.temperature, Some(0.5));
+ assert_eq!(mistral_request.messages.len(), 2);
assert!(mistral_request.stream);
- assert!(mistral_request.tools.is_empty());
- assert!(mistral_request.tool_choice.is_none());
+ }
- assert_eq!(mistral_request.messages.len(), 2);
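+    // Exercises the multipart path: a user message mixing text and an image
+    // should be converted into MessageContent::Multipart.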
+ #[test]
+ fn test_into_mistral_with_image() {
+ let request = LanguageModelRequest {
+ messages: vec![LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![
+ MessageContent::Text("What's in this image?".into()),
+ MessageContent::Image(LanguageModelImage {
+ source: "base64data".into(),
+ size: Default::default(),
+ }),
+ ],
+ cache: false,
+ }],
+ tools: vec![],
+ tool_choice: None,
+ temperature: None,
+ thread_id: None,
+ prompt_id: None,
+ intent: None,
+ mode: None,
+ stop: vec![],
+ };
- match &mistral_request.messages[0] {
- mistral::RequestMessage::System { content } => {
- assert_eq!(content, "You are a helpful assistant.");
- }
- _ => panic!("Expected System message"),
- }
+ let mistral_request = into_mistral(request, "pixtral-12b-latest".into(), None);
- match &mistral_request.messages[1] {
- mistral::RequestMessage::User { content } => {
- assert_eq!(content, "Hello, how are you?");
+ assert_eq!(mistral_request.messages.len(), 1);
+ assert!(matches!(
+ &mistral_request.messages[0],
+ mistral::RequestMessage::User {
+ content: mistral::MessageContent::Multipart { .. }
}
- _ => panic!("Expected User message"),
+ ));
+
+ if let mistral::RequestMessage::User {
+ content: mistral::MessageContent::Multipart { content },
+ } = &mistral_request.messages[0]
+ {
+ assert_eq!(content.len(), 2);
+ assert!(matches!(
+ &content[0],
+ mistral::MessagePart::Text { text } if text == "What's in this image?"
+ ));
+ assert!(matches!(
+ &content[1],
+ mistral::MessagePart::ImageUrl { image_url } if image_url.starts_with("data:image/png;base64,")
+ ));
}
}
}
@@ -60,6 +60,10 @@ pub enum Model {
OpenCodestralMamba,
#[serde(rename = "devstral-small-latest", alias = "devstral-small-latest")]
DevstralSmallLatest,
+ #[serde(rename = "pixtral-12b-latest", alias = "pixtral-12b-latest")]
+ Pixtral12BLatest,
+ #[serde(rename = "pixtral-large-latest", alias = "pixtral-large-latest")]
+ PixtralLargeLatest,
#[serde(rename = "custom")]
Custom {
@@ -70,6 +74,7 @@ pub enum Model {
max_output_tokens: Option<u32>,
max_completion_tokens: Option<u32>,
supports_tools: Option<bool>,
+ supports_images: Option<bool>,
},
}
@@ -86,6 +91,9 @@ impl Model {
"mistral-small-latest" => Ok(Self::MistralSmallLatest),
"open-mistral-nemo" => Ok(Self::OpenMistralNemo),
"open-codestral-mamba" => Ok(Self::OpenCodestralMamba),
+ "devstral-small-latest" => Ok(Self::DevstralSmallLatest),
+ "pixtral-12b-latest" => Ok(Self::Pixtral12BLatest),
+ "pixtral-large-latest" => Ok(Self::PixtralLargeLatest),
invalid_id => anyhow::bail!("invalid model id '{invalid_id}'"),
}
}
@@ -99,6 +107,8 @@ impl Model {
Self::OpenMistralNemo => "open-mistral-nemo",
Self::OpenCodestralMamba => "open-codestral-mamba",
Self::DevstralSmallLatest => "devstral-small-latest",
+ Self::Pixtral12BLatest => "pixtral-12b-latest",
+ Self::PixtralLargeLatest => "pixtral-large-latest",
Self::Custom { name, .. } => name,
}
}
@@ -112,6 +122,8 @@ impl Model {
Self::OpenMistralNemo => "open-mistral-nemo",
Self::OpenCodestralMamba => "open-codestral-mamba",
Self::DevstralSmallLatest => "devstral-small-latest",
+ Self::Pixtral12BLatest => "pixtral-12b-latest",
+ Self::PixtralLargeLatest => "pixtral-large-latest",
Self::Custom {
name, display_name, ..
} => display_name.as_ref().unwrap_or(name),
@@ -127,6 +139,8 @@ impl Model {
Self::OpenMistralNemo => 131000,
Self::OpenCodestralMamba => 256000,
Self::DevstralSmallLatest => 262144,
+ Self::Pixtral12BLatest => 128000,
+ Self::PixtralLargeLatest => 128000,
Self::Custom { max_tokens, .. } => *max_tokens,
}
}
@@ -148,10 +162,29 @@ impl Model {
| Self::MistralSmallLatest
| Self::OpenMistralNemo
| Self::OpenCodestralMamba
- | Self::DevstralSmallLatest => true,
+ | Self::DevstralSmallLatest
+ | Self::Pixtral12BLatest
+ | Self::PixtralLargeLatest => true,
Self::Custom { supports_tools, .. } => supports_tools.unwrap_or(false),
}
}
+
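+    /// Whether the model accepts image (vision) input.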
+ pub fn supports_images(&self) -> bool {
+ match self {
+ Self::Pixtral12BLatest
+ | Self::PixtralLargeLatest
+ | Self::MistralMediumLatest
+ | Self::MistralSmallLatest => true,
+ Self::CodestralLatest
+ | Self::MistralLargeLatest
+ | Self::OpenMistralNemo
+ | Self::OpenCodestralMamba
+ | Self::DevstralSmallLatest => false,
+ Self::Custom {
+ supports_images, ..
+ } => supports_images.unwrap_or(false),
+ }
+ }
}
#[derive(Debug, Serialize, Deserialize)]
@@ -231,7 +264,8 @@ pub enum RequestMessage {
tool_calls: Vec<ToolCall>,
},
User {
- content: String,
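+        // Flattened so the JSON keeps a single `content` key that is either a
+        // plain string or an array of typed parts.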
+ #[serde(flatten)]
+ content: MessageContent,
},
System {
content: String,
@@ -242,6 +276,54 @@ pub enum RequestMessage {
},
}
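+/// Content of a user message: either a plain string or a list of typed parts
+/// (text and image URLs), mirroring Mistral's chat `content` field.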
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(untagged)]
+pub enum MessageContent {
+ #[serde(rename = "content")]
+ Plain { content: String },
+ #[serde(rename = "content")]
+ Multipart { content: Vec<MessagePart> },
+}
+
+impl MessageContent {
+ pub fn empty() -> Self {
+ Self::Plain {
+ content: String::new(),
+ }
+ }
+
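+    /// Appends a part, promoting plain text content to `Multipart` as soon as
+    /// a non-text part (such as an image) is added.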
+ pub fn push_part(&mut self, part: MessagePart) {
+ match self {
+ Self::Plain { content } => match part {
+ MessagePart::Text { text } => {
+ content.push_str(&text);
+ }
+ part => {
+ let mut parts = if content.is_empty() {
+ Vec::new()
+ } else {
+ vec![MessagePart::Text {
+ text: content.clone(),
+ }]
+ };
+ parts.push(part);
+ *self = Self::Multipart { content: parts };
+ }
+ },
+ Self::Multipart { content } => {
+ content.push(part);
+ }
+ }
+ }
+}
+
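+/// A single chunk of multipart content, tagged with `"type"` when serialized,
+/// e.g. `{"type": "text", "text": "..."}` or `{"type": "image_url", "image_url": "data:..."}`.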
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum MessagePart {
+ Text { text: String },
+ ImageUrl { image_url: String },
+}
+
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
pub id: String,