language_models: Add images support to LMStudio provider (#32741)

Umesh Yadav created

Tested with gemma3:4b
LMStudio: beta version 0.3.17

Release Notes:

- Add images support to LMStudio provider

Change summary

crates/agent_settings/src/agent_settings.rs     |   4 
crates/language_models/src/provider/lmstudio.rs | 114 ++++++++++++++----
crates/lmstudio/src/lmstudio.rs                 | 106 +++++++++++++++++
3 files changed, 190 insertions(+), 34 deletions(-)

Detailed changes

crates/agent_settings/src/agent_settings.rs 🔗

@@ -386,7 +386,9 @@ impl AgentSettingsContent {
                             _ => None,
                         };
                         settings.provider = Some(AgentProviderContentV1::LmStudio {
-                            default_model: Some(lmstudio::Model::new(&model, None, None, false)),
+                            default_model: Some(lmstudio::Model::new(
+                                &model, None, None, false, false,
+                            )),
                             api_url,
                         });
                     }

crates/language_models/src/provider/lmstudio.rs 🔗

@@ -14,10 +14,7 @@ use language_model::{
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest, RateLimiter, Role,
 };
-use lmstudio::{
-    ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models,
-    stream_chat_completion,
-};
+use lmstudio::{ModelType, get_models};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
@@ -49,6 +46,7 @@ pub struct AvailableModel {
     pub display_name: Option<String>,
     pub max_tokens: usize,
     pub supports_tool_calls: bool,
+    pub supports_images: bool,
 }
 
 pub struct LmStudioLanguageModelProvider {
@@ -88,6 +86,7 @@ impl State {
                             .loaded_context_length
                             .or_else(|| model.max_context_length),
                         model.capabilities.supports_tool_calls(),
+                        model.capabilities.supports_images() || model.r#type == ModelType::Vlm,
                     )
                 })
                 .collect();
@@ -201,6 +200,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
                     supports_tool_calls: model.supports_tool_calls,
+                    supports_images: model.supports_images,
                 },
             );
         }
@@ -244,23 +244,34 @@ pub struct LmStudioLanguageModel {
 }
 
 impl LmStudioLanguageModel {
-    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
+    fn to_lmstudio_request(
+        &self,
+        request: LanguageModelRequest,
+    ) -> lmstudio::ChatCompletionRequest {
         let mut messages = Vec::new();
 
         for message in request.messages {
             for content in message.content {
                 match content {
-                    MessageContent::Text(text) => messages.push(match message.role {
-                        Role::User => ChatMessage::User { content: text },
-                        Role::Assistant => ChatMessage::Assistant {
-                            content: Some(text),
-                            tool_calls: Vec::new(),
-                        },
-                        Role::System => ChatMessage::System { content: text },
-                    }),
+                    MessageContent::Text(text) => add_message_content_part(
+                        lmstudio::MessagePart::Text { text },
+                        message.role,
+                        &mut messages,
+                    ),
                     MessageContent::Thinking { .. } => {}
                     MessageContent::RedactedThinking(_) => {}
-                    MessageContent::Image(_) => {}
+                    MessageContent::Image(image) => {
+                        add_message_content_part(
+                            lmstudio::MessagePart::Image {
+                                image_url: lmstudio::ImageUrl {
+                                    url: image.to_base64_url(),
+                                    detail: None,
+                                },
+                            },
+                            message.role,
+                            &mut messages,
+                        );
+                    }
                     MessageContent::ToolUse(tool_use) => {
                         let tool_call = lmstudio::ToolCall {
                             id: tool_use.id.to_string(),
@@ -285,23 +296,32 @@ impl LmStudioLanguageModel {
                         }
                     }
                     MessageContent::ToolResult(tool_result) => {
-                        match &tool_result.content {
+                        let content = match &tool_result.content {
                             LanguageModelToolResultContent::Text(text) => {
-                                messages.push(lmstudio::ChatMessage::Tool {
-                                    content: text.to_string(),
-                                    tool_call_id: tool_result.tool_use_id.to_string(),
-                                });
+                                vec![lmstudio::MessagePart::Text {
+                                    text: text.to_string(),
+                                }]
                             }
-                            LanguageModelToolResultContent::Image(_) => {
-                                // no support for images for now
+                            LanguageModelToolResultContent::Image(image) => {
+                                vec![lmstudio::MessagePart::Image {
+                                    image_url: lmstudio::ImageUrl {
+                                        url: image.to_base64_url(),
+                                        detail: None,
+                                    },
+                                }]
                             }
                         };
+
+                        messages.push(lmstudio::ChatMessage::Tool {
+                            content: content.into(),
+                            tool_call_id: tool_result.tool_use_id.to_string(),
+                        });
                     }
                 }
             }
         }
 
-        ChatCompletionRequest {
+        lmstudio::ChatCompletionRequest {
             model: self.model.name.clone(),
             messages,
             stream: true,
@@ -332,10 +352,12 @@ impl LmStudioLanguageModel {
 
     fn stream_completion(
         &self,
-        request: ChatCompletionRequest,
+        request: lmstudio::ChatCompletionRequest,
         cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
-    {
+    ) -> BoxFuture<
+        'static,
+        Result<futures::stream::BoxStream<'static, Result<lmstudio::ResponseStreamEvent>>>,
+    > {
         let http_client = self.http_client.clone();
         let Ok(api_url) = cx.update(|cx| {
             let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
@@ -345,7 +367,7 @@ impl LmStudioLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
-            let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
+            let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
             let response = request.await?;
             Ok(response)
         });
@@ -385,7 +407,7 @@ impl LanguageModel for LmStudioLanguageModel {
     }
 
     fn supports_images(&self) -> bool {
-        false
+        self.model.supports_images
     }
 
     fn telemetry_id(&self) -> String {
@@ -446,7 +468,7 @@ impl LmStudioEventMapper {
 
     pub fn map_stream(
         mut self,
-        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+        events: Pin<Box<dyn Send + Stream<Item = Result<lmstudio::ResponseStreamEvent>>>>,
     ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
     {
         events.flat_map(move |event| {
@@ -459,7 +481,7 @@ impl LmStudioEventMapper {
 
     pub fn map_event(
         &mut self,
-        event: ResponseStreamEvent,
+        event: lmstudio::ResponseStreamEvent,
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
         let Some(choice) = event.choices.into_iter().next() else {
             return vec![Err(LanguageModelCompletionError::Other(anyhow!(
@@ -551,6 +573,40 @@ struct RawToolCall {
     arguments: String,
 }
 
+fn add_message_content_part(
+    new_part: lmstudio::MessagePart,
+    role: Role,
+    messages: &mut Vec<lmstudio::ChatMessage>,
+) {
+    match (role, messages.last_mut()) {
+        (Role::User, Some(lmstudio::ChatMessage::User { content }))
+        | (
+            Role::Assistant,
+            Some(lmstudio::ChatMessage::Assistant {
+                content: Some(content),
+                ..
+            }),
+        )
+        | (Role::System, Some(lmstudio::ChatMessage::System { content })) => {
+            content.push_part(new_part);
+        }
+        _ => {
+            messages.push(match role {
+                Role::User => lmstudio::ChatMessage::User {
+                    content: lmstudio::MessageContent::from(vec![new_part]),
+                },
+                Role::Assistant => lmstudio::ChatMessage::Assistant {
+                    content: Some(lmstudio::MessageContent::from(vec![new_part])),
+                    tool_calls: Vec::new(),
+                },
+                Role::System => lmstudio::ChatMessage::System {
+                    content: lmstudio::MessageContent::from(vec![new_part]),
+                },
+            });
+        }
+    }
+}
+
 struct ConfigurationView {
     state: gpui::Entity<State>,
     loading_models_task: Option<Task<()>>,

crates/lmstudio/src/lmstudio.rs 🔗

@@ -48,6 +48,7 @@ pub struct Model {
     pub display_name: Option<String>,
     pub max_tokens: usize,
     pub supports_tool_calls: bool,
+    pub supports_images: bool,
 }
 
 impl Model {
@@ -56,12 +57,14 @@ impl Model {
         display_name: Option<&str>,
         max_tokens: Option<usize>,
         supports_tool_calls: bool,
+        supports_images: bool,
     ) -> Self {
         Self {
             name: name.to_owned(),
             display_name: display_name.map(|s| s.to_owned()),
             max_tokens: max_tokens.unwrap_or(2048),
             supports_tool_calls,
+            supports_images,
         }
     }
 
@@ -110,22 +113,78 @@ pub struct FunctionDefinition {
 pub enum ChatMessage {
     Assistant {
         #[serde(default)]
-        content: Option<String>,
+        content: Option<MessageContent>,
         #[serde(default, skip_serializing_if = "Vec::is_empty")]
         tool_calls: Vec<ToolCall>,
     },
     User {
-        content: String,
+        content: MessageContent,
     },
     System {
-        content: String,
+        content: MessageContent,
     },
     Tool {
-        content: String,
+        content: MessageContent,
         tool_call_id: String,
     },
 }
 
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(untagged)]
+pub enum MessageContent {
+    Plain(String),
+    Multipart(Vec<MessagePart>),
+}
+
+impl MessageContent {
+    pub fn empty() -> Self {
+        MessageContent::Multipart(vec![])
+    }
+
+    pub fn push_part(&mut self, part: MessagePart) {
+        match self {
+            MessageContent::Plain(text) => {
+                *self =
+                    MessageContent::Multipart(vec![MessagePart::Text { text: text.clone() }, part]);
+            }
+            MessageContent::Multipart(parts) if parts.is_empty() => match part {
+                MessagePart::Text { text } => *self = MessageContent::Plain(text),
+                MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
+            },
+            MessageContent::Multipart(parts) => parts.push(part),
+        }
+    }
+}
+
+impl From<Vec<MessagePart>> for MessageContent {
+    fn from(mut parts: Vec<MessagePart>) -> Self {
+        if let [MessagePart::Text { text }] = parts.as_mut_slice() {
+            MessageContent::Plain(std::mem::take(text))
+        } else {
+            MessageContent::Multipart(parts)
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum MessagePart {
+    Text {
+        text: String,
+    },
+    #[serde(rename = "image_url")]
+    Image {
+        image_url: ImageUrl,
+    },
+}
+
+#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
+pub struct ImageUrl {
+    pub url: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub detail: Option<String>,
+}
+
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 pub struct ToolCall {
     pub id: String,
@@ -210,6 +269,10 @@ impl Capabilities {
     pub fn supports_tool_calls(&self) -> bool {
         self.0.iter().any(|cap| cap == "tool_use")
     }
+
+    pub fn supports_images(&self) -> bool {
+        self.0.iter().any(|cap| cap == "vision")
+    }
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -393,3 +456,38 @@ pub async fn get_models(
         serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
     Ok(response.data)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_image_message_part_serialization() {
+        let image_part = MessagePart::Image {
+            image_url: ImageUrl {
+                url: "".to_string(),
+                detail: None,
+            },
+        };
+
+        let json = serde_json::to_string(&image_part).unwrap();
+        println!("Serialized image part: {}", json);
+
+        // Verify the structure matches what LM Studio expects
+        let expected_structure = r#"{"type":"image_url","image_url":{"url":""}}"#;
+        assert_eq!(json, expected_structure);
+    }
+
+    #[test]
+    fn test_text_message_part_serialization() {
+        let text_part = MessagePart::Text {
+            text: "Hello, world!".to_string(),
+        };
+
+        let json = serde_json::to_string(&text_part).unwrap();
+        println!("Serialized text part: {}", json);
+
+        let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
+        assert_eq!(json, expected_structure);
+    }
+}