@@ -1,10 +1,13 @@
use anyhow::{Result, anyhow};
+use collections::HashMap;
+use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelToolChoice,
+ LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+ StopReason, WrappedTextContent,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -12,12 +15,14 @@ use language_model::{
LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
- ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
+ ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models, preload_model,
stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;
@@ -40,12 +45,10 @@ pub struct LmStudioSettings {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
/// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
pub name: String,
/// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
pub display_name: Option<String>,
/// The model's context window size.
pub max_tokens: usize,
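+ /// Whether this model supports tool calls.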
+ pub supports_tool_calls: bool,
}
pub struct LmStudioLanguageModelProvider {
@@ -77,7 +80,14 @@ impl State {
let mut models: Vec<lmstudio::Model> = models
.into_iter()
.filter(|model| model.r#type != ModelType::Embeddings)
- .map(|model| lmstudio::Model::new(&model.id, None, None))
+ .map(|model| {
+ lmstudio::Model::new(
+ &model.id,
+ None,
+ None,
+ model.capabilities.supports_tool_calls(),
+ )
+ })
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
@@ -156,12 +166,16 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
IconName::AiLmStudio
}
- fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
- self.provided_models(cx).into_iter().next()
+ fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
+ // We deliberately don't pick a default model, because doing so could trigger loading a model that isn't loaded yet.
+ // In a constrained environment where the user may not have enough resources, it would be poor UX to select
+ // something to load by default.
+ None
}
- fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
- self.default_model(cx)
+ fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
+ // See the explanation in default_model above.
+ None
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -184,6 +198,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
+ supports_tool_calls: model.supports_tool_calls,
},
);
}
@@ -237,31 +252,117 @@ pub struct LmStudioLanguageModel {
impl LmStudioLanguageModel {
fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
+ let mut messages = Vec::new();
+
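+ // Flatten each Zed message into one or more LM Studio chat messages. Tool calls are
+ // attached to the preceding assistant message when one exists; tool results become
+ // dedicated `Tool` messages.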
+ for message in request.messages {
+ for content in message.content {
+ match content {
+ MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+ .push(match message.role {
+ Role::User => ChatMessage::User { content: text },
+ Role::Assistant => ChatMessage::Assistant {
+ content: Some(text),
+ tool_calls: Vec::new(),
+ },
+ Role::System => ChatMessage::System { content: text },
+ }),
+ MessageContent::RedactedThinking(_) => {}
+ MessageContent::Image(_) => {}
+ MessageContent::ToolUse(tool_use) => {
+ let tool_call = lmstudio::ToolCall {
+ id: tool_use.id.to_string(),
+ content: lmstudio::ToolCallContent::Function {
+ function: lmstudio::FunctionContent {
+ name: tool_use.name.to_string(),
+ arguments: serde_json::to_string(&tool_use.input)
+ .unwrap_or_default(),
+ },
+ },
+ };
+
+ if let Some(lmstudio::ChatMessage::Assistant { tool_calls, .. }) =
+ messages.last_mut()
+ {
+ tool_calls.push(tool_call);
+ } else {
+ messages.push(lmstudio::ChatMessage::Assistant {
+ content: None,
+ tool_calls: vec![tool_call],
+ });
+ }
+ }
+ MessageContent::ToolResult(tool_result) => {
+ match &tool_result.content {
+ LanguageModelToolResultContent::Text(text)
+ | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
+ text,
+ ..
+ }) => {
+ messages.push(lmstudio::ChatMessage::Tool {
+ content: text.to_string(),
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ });
+ }
+ LanguageModelToolResultContent::Image(_) => {
+ // no support for images for now
+ }
+ };
+ }
+ }
+ }
+ }
+
ChatCompletionRequest {
model: self.model.name.clone(),
- messages: request
- .messages
+ messages,
+ stream: true,
+ max_tokens: Some(-1),
+ stop: Some(request.stop),
+ // In LM Studio, users can configure model-specific settings; for example, Qwen3 is
+ // recommended to run with a temperature of 0.7. Silently overriding those settings from
+ // Zed would be bad UX, so we send no temperature by default.
+ temperature: request.temperature,
+ tools: request
+ .tools
.into_iter()
- .map(|msg| match msg.role {
- Role::User => ChatMessage::User {
- content: msg.string_contents(),
- },
- Role::Assistant => ChatMessage::Assistant {
- content: Some(msg.string_contents()),
- tool_calls: None,
- },
- Role::System => ChatMessage::System {
- content: msg.string_contents(),
+ .map(|tool| lmstudio::ToolDefinition::Function {
+ function: lmstudio::FunctionDefinition {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
},
})
.collect(),
- stream: true,
- max_tokens: Some(-1),
- stop: Some(request.stop),
- temperature: request.temperature.or(Some(0.0)),
- tools: vec![],
+ tool_choice: request.tool_choice.map(|choice| match choice {
+ LanguageModelToolChoice::Auto => lmstudio::ToolChoice::Auto,
+ LanguageModelToolChoice::Any => lmstudio::ToolChoice::Required,
+ LanguageModelToolChoice::None => lmstudio::ToolChoice::None,
+ }),
}
}
+
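+ /// Reads the LM Studio API URL from settings and streams the chat completion
+ /// through the provider's request rate limiter.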
+ fn stream_completion(
+ &self,
+ request: ChatCompletionRequest,
+ cx: &AsyncApp,
+ ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
+ {
+ let http_client = self.http_client.clone();
+ let Ok(api_url) = cx.update(|cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
+ settings.api_url.clone()
+ }) else {
+ return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
+
+ let future = self.request_limiter.stream(async move {
+ let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
}
impl LanguageModel for LmStudioLanguageModel {
@@ -282,14 +383,19 @@ impl LanguageModel for LmStudioLanguageModel {
}
fn supports_tools(&self) -> bool {
- false
+ self.model.supports_tool_calls()
}
- fn supports_images(&self) -> bool {
- false
+ fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+ self.supports_tools()
+ && match choice {
+ LanguageModelToolChoice::Auto => true,
+ LanguageModelToolChoice::Any => true,
+ LanguageModelToolChoice::None => true,
+ }
}
- fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+ fn supports_images(&self) -> bool {
false
}
@@ -328,85 +434,126 @@ impl LanguageModel for LmStudioLanguageModel {
>,
> {
let request = self.to_lmstudio_request(request);
-
- let http_client = self.http_client.clone();
- let Ok(api_url) = cx.update(|cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
- settings.api_url.clone()
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
- };
-
- let future = self.request_limiter.stream(async move {
- let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
-
- // Create a stream mapper to handle content across multiple deltas
- let stream_mapper = LmStudioStreamMapper::new();
-
- let stream = response
- .map(move |response| {
- response.and_then(|fragment| stream_mapper.process_fragment(fragment))
- })
- .filter_map(|result| async move {
- match result {
- Ok(Some(content)) => Some(Ok(content)),
- Ok(None) => None,
- Err(error) => Some(Err(error)),
- }
- })
- .boxed();
-
- Ok(stream)
- });
-
+ let completions = self.stream_completion(request, cx);
async move {
- Ok(future
- .await?
- .map(|result| {
- result
- .map(LanguageModelCompletionEvent::Text)
- .map_err(LanguageModelCompletionError::Other)
- })
- .boxed())
+ let mapper = LmStudioEventMapper::new();
+ Ok(mapper.map_stream(completions.await?).boxed())
}
.boxed()
}
}
-// This will be more useful when we implement tool calling. Currently keeping it empty.
-struct LmStudioStreamMapper {}
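+/// Maps LM Studio response stream events to language model completion events,
+/// accumulating partial tool calls across streamed chunks.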
+struct LmStudioEventMapper {
+ tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
-impl LmStudioStreamMapper {
+impl LmStudioEventMapper {
fn new() -> Self {
- Self {}
+ Self {
+ tool_calls_by_index: HashMap::default(),
+ }
+ }
+
+ pub fn map_stream(
+ mut self,
+ events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+ ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+ {
+ events.flat_map(move |event| {
+ futures::stream::iter(match event {
+ Ok(event) => self.map_event(event),
+ Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ })
+ })
}
- fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> {
- // Most of the time, there will be only one choice
- let Some(choice) = fragment.choices.first() else {
- return Ok(None);
+ pub fn map_event(
+ &mut self,
+ event: ResponseStreamEvent,
+ ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+ let Some(choice) = event.choices.into_iter().next() else {
+ return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+ "Response contained no choices"
+ )))];
};
- // Extract the delta content
- if let Ok(delta) =
- serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone())
- {
- if let Some(content) = delta.content {
- if !content.is_empty() {
- return Ok(Some(content));
+ let mut events = Vec::new();
+ if let Some(content) = choice.delta.content {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+
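+ // Tool calls arrive as fragments keyed by index; merge each fragment into the
+ // entry for that index until the finish_reason reports the calls are complete.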
+ if let Some(tool_calls) = choice.delta.tool_calls {
+ for tool_call in tool_calls {
+ let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+ if let Some(tool_id) = tool_call.id {
+ entry.id = tool_id;
+ }
+
+ if let Some(function) = tool_call.function {
+ if let Some(name) = function.name {
+ // At the time of writing, LM Studio (0.3.15) deviates from the OpenAI streaming API:
+ // 1. It sends the function name in the first chunk.
+ // 2. It sends an empty string in the function name field of every subsequent argument chunk.
+ // According to https://platform.openai.com/docs/guides/function-calling?api-mode=responses#streaming
+ // the function name field should only be present in the first chunk.
+ if !name.is_empty() {
+ entry.name = name;
+ }
+ }
+
+ if let Some(arguments) = function.arguments {
+ entry.arguments.push_str(&arguments);
+ }
}
}
}
- // If there's a finish_reason, we're done
- if choice.finish_reason.is_some() {
- return Ok(None);
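+ // Translate the finish reason into a Stop event; "tool_calls" first flushes the
+ // accumulated tool calls as ToolUse events.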
+ match choice.finish_reason.as_deref() {
+ Some("stop") => {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ Some("tool_calls") => {
+ events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+ match serde_json::Value::from_str(&tool_call.arguments) {
+ Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.into(),
+ name: tool_call.name.into(),
+ is_input_complete: true,
+ input,
+ raw_input: tool_call.arguments,
+ },
+ )),
+ Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+ id: tool_call.id.into(),
+ tool_name: tool_call.name.into(),
+ raw_input: tool_call.arguments.into(),
+ json_parse_error: error.to_string(),
+ }),
+ }
+ }));
+
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+ }
+ Some(stop_reason) => {
+ log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+ events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+ }
+ None => {}
}
- Ok(None)
+ events
}
}
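+/// A tool call assembled from streamed fragments; `arguments` is the
+/// concatenation of the argument deltas received so far.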
+#[derive(Default)]
+struct RawToolCall {
+ id: String,
+ name: String,
+ arguments: String,
+}
+
struct ConfigurationView {
state: gpui::Entity<State>,
loading_models_task: Option<Task<()>>,
@@ -2,7 +2,7 @@ use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
-use serde_json::{Value, value::RawValue};
+use serde_json::Value;
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
@@ -47,14 +47,21 @@ pub struct Model {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: usize,
+ pub supports_tool_calls: bool,
}
impl Model {
- pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
+ pub fn new(
+ name: &str,
+ display_name: Option<&str>,
+ max_tokens: Option<usize>,
+ supports_tool_calls: bool,
+ ) -> Self {
Self {
name: name.to_owned(),
display_name: display_name.map(|s| s.to_owned()),
max_tokens: max_tokens.unwrap_or(2048),
+ supports_tool_calls,
}
}
@@ -69,15 +76,43 @@ impl Model {
pub fn max_token_count(&self) -> usize {
self.max_tokens
}
+
+ pub fn supports_tool_calls(&self) -> bool {
+ self.supports_tool_calls
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+ Auto,
+ Required,
+ None,
+ Other(ToolDefinition),
}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ToolDefinition {
+ #[allow(dead_code)]
+ Function { function: FunctionDefinition },
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+ pub name: String,
+ pub description: Option<String>,
+ pub parameters: Option<Value>,
+}
+
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
Assistant {
#[serde(default)]
content: Option<String>,
- #[serde(default)]
- tool_calls: Option<Vec<LmStudioToolCall>>,
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ tool_calls: Vec<ToolCall>,
},
User {
content: String,
@@ -85,31 +120,29 @@ pub enum ChatMessage {
System {
content: String,
},
+ Tool {
+ content: String,
+ tool_call_id: String,
+ },
}
-#[derive(Serialize, Deserialize, Debug)]
-#[serde(rename_all = "lowercase")]
-pub enum LmStudioToolCall {
- Function(LmStudioFunctionCall),
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct LmStudioFunctionCall {
- pub name: String,
- pub arguments: Box<RawValue>,
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ToolCall {
+ pub id: String,
+ #[serde(flatten)]
+ pub content: ToolCallContent,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct LmStudioFunctionTool {
- pub name: String,
- pub description: Option<String>,
- pub parameters: Option<Value>,
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolCallContent {
+ Function { function: FunctionContent },
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(tag = "type", rename_all = "lowercase")]
-pub enum LmStudioTool {
- Function { function: LmStudioFunctionTool },
+pub struct FunctionContent {
+ pub name: String,
+ pub arguments: String,
}
#[derive(Serialize, Debug)]
@@ -117,10 +150,16 @@ pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub stream: bool,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<i32>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
- pub tools: Vec<LmStudioTool>,
+ #[serde(skip_serializing_if = "Vec::is_empty")]
+ pub tools: Vec<ToolDefinition>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub tool_choice: Option<ToolChoice>,
}
#[derive(Serialize, Deserialize, Debug)]
@@ -135,8 +174,7 @@ pub struct ChatResponse {
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
pub index: u32,
- #[serde(default)]
- pub delta: serde_json::Value,
+ pub delta: ResponseMessageDelta,
pub finish_reason: Option<String>,
}
@@ -164,6 +202,16 @@ pub struct Usage {
pub total_tokens: u32,
}
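+/// Capability strings reported by LM Studio for a model, e.g. "tool_use".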
+#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
+#[serde(transparent)]
+pub struct Capabilities(Vec<String>);
+
+impl Capabilities {
+ pub fn supports_tool_calls(&self) -> bool {
+ self.0.iter().any(|cap| cap == "tool_use")
+ }
+}
+
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
@@ -175,16 +223,17 @@ pub enum ResponseStreamResult {
pub struct ResponseStreamEvent {
pub created: u32,
pub model: String,
+ pub object: String,
pub choices: Vec<ChoiceDelta>,
pub usage: Option<Usage>,
}
-#[derive(Serialize, Deserialize)]
+#[derive(Deserialize)]
pub struct ListModelsResponse {
pub data: Vec<ModelEntry>,
}
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct ModelEntry {
pub id: String,
pub object: String,
@@ -196,6 +245,8 @@ pub struct ModelEntry {
pub state: ModelState,
pub max_context_length: Option<u32>,
pub loaded_context_length: Option<u32>,
+ #[serde(default)]
+ pub capabilities: Capabilities,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
@@ -265,7 +316,7 @@ pub async fn stream_chat_completion(
client: &dyn HttpClient,
api_url: &str,
request: ChatCompletionRequest,
-) -> Result<BoxStream<'static, Result<ChatResponse>>> {
+) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
let uri = format!("{api_url}/chat/completions");
let request_builder = http::Request::builder()
.method(Method::POST)