@@ -14,10 +14,7 @@ use language_model::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role,
};
-use lmstudio::{
- ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models,
- stream_chat_completion,
-};
+use lmstudio::{ModelType, get_models};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
@@ -49,6 +46,7 @@ pub struct AvailableModel {
pub display_name: Option<String>,
pub max_tokens: usize,
pub supports_tool_calls: bool,
+ pub supports_images: bool,
}
pub struct LmStudioLanguageModelProvider {
@@ -88,6 +86,7 @@ impl State {
.loaded_context_length
.or_else(|| model.max_context_length),
model.capabilities.supports_tool_calls(),
+ model.capabilities.supports_images() || model.r#type == ModelType::Vlm,
)
})
.collect();
@@ -201,6 +200,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
supports_tool_calls: model.supports_tool_calls,
+ supports_images: model.supports_images,
},
);
}
@@ -244,23 +244,34 @@ pub struct LmStudioLanguageModel {
}
impl LmStudioLanguageModel {
- fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
+ fn to_lmstudio_request(
+ &self,
+ request: LanguageModelRequest,
+ ) -> lmstudio::ChatCompletionRequest {
let mut messages = Vec::new();
for message in request.messages {
for content in message.content {
match content {
- MessageContent::Text(text) => messages.push(match message.role {
- Role::User => ChatMessage::User { content: text },
- Role::Assistant => ChatMessage::Assistant {
- content: Some(text),
- tool_calls: Vec::new(),
- },
- Role::System => ChatMessage::System { content: text },
- }),
+ MessageContent::Text(text) => add_message_content_part(
+ lmstudio::MessagePart::Text { text },
+ message.role,
+ &mut messages,
+ ),
MessageContent::Thinking { .. } => {}
MessageContent::RedactedThinking(_) => {}
- MessageContent::Image(_) => {}
+ MessageContent::Image(image) => {
+ add_message_content_part(
+ lmstudio::MessagePart::Image {
+ image_url: lmstudio::ImageUrl {
+ url: image.to_base64_url(),
+ detail: None,
+ },
+ },
+ message.role,
+ &mut messages,
+ );
+ }
MessageContent::ToolUse(tool_use) => {
let tool_call = lmstudio::ToolCall {
id: tool_use.id.to_string(),
@@ -285,23 +296,32 @@ impl LmStudioLanguageModel {
}
}
MessageContent::ToolResult(tool_result) => {
- match &tool_result.content {
+ let content = match &tool_result.content {
LanguageModelToolResultContent::Text(text) => {
- messages.push(lmstudio::ChatMessage::Tool {
- content: text.to_string(),
- tool_call_id: tool_result.tool_use_id.to_string(),
- });
+ vec![lmstudio::MessagePart::Text {
+ text: text.to_string(),
+ }]
}
- LanguageModelToolResultContent::Image(_) => {
- // no support for images for now
+ LanguageModelToolResultContent::Image(image) => {
+ vec![lmstudio::MessagePart::Image {
+ image_url: lmstudio::ImageUrl {
+ url: image.to_base64_url(),
+ detail: None,
+ },
+ }]
}
};
+
+ messages.push(lmstudio::ChatMessage::Tool {
+ content: content.into(),
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ });
}
}
}
}
- ChatCompletionRequest {
+ lmstudio::ChatCompletionRequest {
model: self.model.name.clone(),
messages,
stream: true,
@@ -332,10 +352,12 @@ impl LmStudioLanguageModel {
fn stream_completion(
&self,
- request: ChatCompletionRequest,
+ request: lmstudio::ChatCompletionRequest,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
- {
+ ) -> BoxFuture<
+ 'static,
+ Result<futures::stream::BoxStream<'static, Result<lmstudio::ResponseStreamEvent>>>,
+ > {
let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
@@ -345,7 +367,7 @@ impl LmStudioLanguageModel {
};
let future = self.request_limiter.stream(async move {
- let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
+ let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
let response = request.await?;
Ok(response)
});
@@ -385,7 +407,7 @@ impl LanguageModel for LmStudioLanguageModel {
}
fn supports_images(&self) -> bool {
- false
+ self.model.supports_images
}
fn telemetry_id(&self) -> String {
@@ -446,7 +468,7 @@ impl LmStudioEventMapper {
pub fn map_stream(
mut self,
- events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+ events: Pin<Box<dyn Send + Stream<Item = Result<lmstudio::ResponseStreamEvent>>>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
@@ -459,7 +481,7 @@ impl LmStudioEventMapper {
pub fn map_event(
&mut self,
- event: ResponseStreamEvent,
+ event: lmstudio::ResponseStreamEvent,
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let Some(choice) = event.choices.into_iter().next() else {
return vec![Err(LanguageModelCompletionError::Other(anyhow!(
@@ -551,6 +573,40 @@ struct RawToolCall {
arguments: String,
}
+fn add_message_content_part(
+ new_part: lmstudio::MessagePart,
+ role: Role,
+ messages: &mut Vec<lmstudio::ChatMessage>,
+) {
+ match (role, messages.last_mut()) {
+ (Role::User, Some(lmstudio::ChatMessage::User { content }))
+ | (
+ Role::Assistant,
+ Some(lmstudio::ChatMessage::Assistant {
+ content: Some(content),
+ ..
+ }),
+ )
+ | (Role::System, Some(lmstudio::ChatMessage::System { content })) => {
+ content.push_part(new_part);
+ }
+ _ => {
+ messages.push(match role {
+ Role::User => lmstudio::ChatMessage::User {
+ content: lmstudio::MessageContent::from(vec![new_part]),
+ },
+ Role::Assistant => lmstudio::ChatMessage::Assistant {
+ content: Some(lmstudio::MessageContent::from(vec![new_part])),
+ tool_calls: Vec::new(),
+ },
+ Role::System => lmstudio::ChatMessage::System {
+ content: lmstudio::MessageContent::from(vec![new_part]),
+ },
+ });
+ }
+ }
+}
+
struct ConfigurationView {
state: gpui::Entity<State>,
loading_models_task: Option<Task<()>>,
@@ -48,6 +48,7 @@ pub struct Model {
pub display_name: Option<String>,
pub max_tokens: usize,
pub supports_tool_calls: bool,
+ pub supports_images: bool,
}
impl Model {
@@ -56,12 +57,14 @@ impl Model {
display_name: Option<&str>,
max_tokens: Option<usize>,
supports_tool_calls: bool,
+ supports_images: bool,
) -> Self {
Self {
name: name.to_owned(),
display_name: display_name.map(|s| s.to_owned()),
max_tokens: max_tokens.unwrap_or(2048),
supports_tool_calls,
+ supports_images,
}
}
@@ -110,22 +113,78 @@ pub struct FunctionDefinition {
pub enum ChatMessage {
Assistant {
#[serde(default)]
- content: Option<String>,
+ content: Option<MessageContent>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>,
},
User {
- content: String,
+ content: MessageContent,
},
System {
- content: String,
+ content: MessageContent,
},
Tool {
- content: String,
+ content: MessageContent,
tool_call_id: String,
},
}
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(untagged)]
+pub enum MessageContent {
+ Plain(String),
+ Multipart(Vec<MessagePart>),
+}
+
+impl MessageContent {
+ pub fn empty() -> Self {
+ MessageContent::Multipart(vec![])
+ }
+
+ pub fn push_part(&mut self, part: MessagePart) {
+ match self {
+            MessageContent::Plain(text) => {
+                let text = std::mem::take(text);
+                *self = MessageContent::Multipart(vec![MessagePart::Text { text }, part]);
+            }
+ MessageContent::Multipart(parts) if parts.is_empty() => match part {
+ MessagePart::Text { text } => *self = MessageContent::Plain(text),
+ MessagePart::Image { .. } => *self = MessageContent::Multipart(vec![part]),
+ },
+ MessageContent::Multipart(parts) => parts.push(part),
+ }
+ }
+}
+
+impl From<Vec<MessagePart>> for MessageContent {
+ fn from(mut parts: Vec<MessagePart>) -> Self {
+ if let [MessagePart::Text { text }] = parts.as_mut_slice() {
+ MessageContent::Plain(std::mem::take(text))
+ } else {
+ MessageContent::Multipart(parts)
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum MessagePart {
+ Text {
+ text: String,
+ },
+ #[serde(rename = "image_url")]
+ Image {
+ image_url: ImageUrl,
+ },
+}
+
+#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
+pub struct ImageUrl {
+ pub url: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub detail: Option<String>,
+}
+
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
pub id: String,
@@ -210,6 +269,10 @@ impl Capabilities {
pub fn supports_tool_calls(&self) -> bool {
self.0.iter().any(|cap| cap == "tool_use")
}
+
+ pub fn supports_images(&self) -> bool {
+ self.0.iter().any(|cap| cap == "vision")
+ }
}
#[derive(Serialize, Deserialize, Debug)]
@@ -393,3 +456,38 @@ pub async fn get_models(
serde_json::from_str(&body).context("Unable to parse LM Studio models response")?;
Ok(response.data)
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_image_message_part_serialization() {
+ let image_part = MessagePart::Image {
+ image_url: ImageUrl {
+ url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==".to_string(),
+ detail: None,
+ },
+ };
+
+ let json = serde_json::to_string(&image_part).unwrap();
+
+
+ // Verify the structure matches what LM Studio expects
+ let expected_structure = r#"{"type":"image_url","image_url":{"url":"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="}}"#;
+ assert_eq!(json, expected_structure);
+ }
+
+ #[test]
+ fn test_text_message_part_serialization() {
+ let text_part = MessagePart::Text {
+ text: "Hello, world!".to_string(),
+ };
+
+ let json = serde_json::to_string(&text_part).unwrap();
+
+
+ let expected_structure = r#"{"type":"text","text":"Hello, world!"}"#;
+ assert_eq!(json, expected_structure);
+ }
+}