From a3d9269567a08fb46111a6d59c26e6ce10dbb071 Mon Sep 17 00:00:00 2001 From: Antoine Mathie Date: Fri, 6 Mar 2026 17:49:55 +0100 Subject: [PATCH] ai: Add LMStudio API URL & API key support (#48309) Hello, This pull request aims to improve usage of lmstudio ai provider for remote lmstudio nodes and support api key authentication. This has been tested on my local network from a headless lms node. See attached demo vid Release Notes: - lmstudio: Added support for specifying an API key via the UI https://github.com/user-attachments/assets/7594cf49-3198-4171-b3e9-c3264cf35b6e --------- Co-authored-by: Bennet Bo Fenner --- .../language_models/src/provider/lmstudio.rs | 456 +++++++++++++----- crates/lmstudio/src/lmstudio.rs | 21 +- crates/settings_content/src/language_model.rs | 1 + 3 files changed, 357 insertions(+), 121 deletions(-) diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index 9af8559c722d1fe726f7f871c9863cd85a3d2678..ee08f1689aeea9cfa18346108cd2d314b2259583 100644 --- a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -1,26 +1,30 @@ use anyhow::{Result, anyhow}; use collections::HashMap; +use fs::Fs; use futures::Stream; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task}; +use gpui::{AnyView, App, AsyncApp, Context, CursorStyle, Entity, Subscription, Task}; use http_client::HttpClient; use language_model::{ - AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent, - LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, - StopReason, TokenUsage, + ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError, + LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolResultContent, + LanguageModelToolUse, MessageContent, StopReason, 
TokenUsage, env_var, }; use language_model::{ - IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, - LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, - LanguageModelRequest, RateLimiter, Role, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; -use lmstudio::{ModelType, get_models}; +use lmstudio::{LMSTUDIO_API_URL, ModelType, get_models}; + pub use settings::LmStudioAvailableModel as AvailableModel; -use settings::{Settings, SettingsStore}; +use settings::{Settings, SettingsStore, update_settings_file}; use std::pin::Pin; +use std::sync::LazyLock; use std::{collections::BTreeMap, sync::Arc}; -use ui::{ButtonLike, Indicator, List, ListBulletItem, prelude::*}; -use util::ResultExt; +use ui::{ + ButtonLike, ConfiguredApiCard, ElevationIndex, List, ListBulletItem, Tooltip, prelude::*, +}; +use ui_input::InputField; use crate::AllLanguageModelSettings; use crate::provider::util::parse_tool_arguments; @@ -32,6 +36,9 @@ const LMSTUDIO_SITE: &str = "https://lmstudio.ai/"; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio"); +const API_KEY_ENV_VAR_NAME: &str = "LMSTUDIO_API_KEY"; +static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); + #[derive(Default, Debug, Clone, PartialEq)] pub struct LmStudioSettings { pub api_url: String, @@ -44,6 +51,7 @@ pub struct LmStudioLanguageModelProvider { } pub struct State { + api_key_state: ApiKeyState, http_client: Arc, available_models: Vec, fetch_model_task: Option>>, @@ -55,14 +63,25 @@ impl State { !self.available_models.is_empty() } + fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { + let api_url = LmStudioLanguageModelProvider::api_url(cx).into(); + let task = self + 
.api_key_state + .store(api_url, api_key, |this| &mut this.api_key_state, cx); + self.restart_fetch_models_task(cx); + task + } + fn fetch_models(&mut self, cx: &mut Context) -> Task> { let settings = &AllLanguageModelSettings::get_global(cx).lmstudio; let http_client = self.http_client.clone(); let api_url = settings.api_url.clone(); + let api_key = self.api_key_state.key(&api_url); // As a proxy for the server being "authenticated", we'll check if its up by fetching the models cx.spawn(async move |this, cx| { - let models = get_models(http_client.as_ref(), &api_url, None).await?; + let models = + get_models(http_client.as_ref(), &api_url, api_key.as_deref(), None).await?; let mut models: Vec = models .into_iter() @@ -95,6 +114,11 @@ impl State { } fn authenticate(&mut self, cx: &mut Context) -> Task> { + let api_url = LmStudioLanguageModelProvider::api_url(cx).into(); + let _task = self + .api_key_state + .load_if_needed(api_url, |this| &mut this.api_key_state, cx); + if self.is_authenticated() { return Task::ready(Ok(())); } @@ -145,6 +169,10 @@ impl LmStudioLanguageModelProvider { }); State { + api_key_state: ApiKeyState::new( + Self::api_url(cx).into(), + (*API_KEY_ENV_VAR).clone(), + ), http_client, available_models: Default::default(), fetch_model_task: None, @@ -156,6 +184,17 @@ impl LmStudioLanguageModelProvider { .update(cx, |state, cx| state.restart_fetch_models_task(cx)); this } + + fn api_url(cx: &App) -> String { + AllLanguageModelSettings::get_global(cx) + .lmstudio + .api_url + .clone() + } + + fn has_custom_url(cx: &App) -> bool { + Self::api_url(cx) != LMSTUDIO_API_URL + } } impl LanguageModelProviderState for LmStudioLanguageModelProvider { @@ -225,6 +264,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { model, http_client: self.http_client.clone(), request_limiter: RateLimiter::new(4), + state: self.state.clone(), }) as Arc }) .collect() @@ -244,12 +284,13 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider { 
_window: &mut Window, cx: &mut App, ) -> AnyView { - let state = self.state.clone(); - cx.new(|cx| ConfigurationView::new(state, cx)).into() + cx.new(|cx| ConfigurationView::new(self.state.clone(), _window, cx)) + .into() } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.fetch_models(cx)) + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) } } @@ -258,6 +299,7 @@ pub struct LmStudioLanguageModel { model: lmstudio::Model, http_client: Arc, request_limiter: RateLimiter, + state: Entity, } impl LmStudioLanguageModel { @@ -376,15 +418,20 @@ impl LmStudioLanguageModel { Result>>, > { let http_client = self.http_client.clone(); - let api_url = cx.update(|cx| { - let settings = &AllLanguageModelSettings::get_global(cx).lmstudio; - settings.api_url.clone() + let (api_key, api_url) = self.state.read_with(cx, |state, cx| { + let api_url = LmStudioLanguageModelProvider::api_url(cx); + (state.api_key_state.key(&api_url), api_url) }); let future = self.request_limiter.stream(async move { - let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request); - let response = request.await?; - Ok(response) + let stream = lmstudio::stream_chat_completion( + http_client.as_ref(), + &api_url, + api_key.as_deref(), + request, + ) + .await?; + Ok(stream) }); async move { Ok(future.await?.boxed()) }.boxed() @@ -634,53 +681,212 @@ fn add_message_content_part( struct ConfigurationView { state: Entity, - loading_models_task: Option>, + api_key_editor: Entity, + api_url_editor: Entity, } impl ConfigurationView { - pub fn new(state: Entity, cx: &mut Context) -> Self { - let loading_models_task = Some(cx.spawn({ - let state = state.clone(); - async move |this, cx| { - state - .update(cx, |state, cx| state.authenticate(cx)) - .await - .log_err(); - - this.update(cx, |this, cx| { - this.loading_models_task = None; - cx.notify(); - }) - .log_err(); - } - })); + pub fn new(state: Entity, _window: &mut Window, cx: 
&mut Context) -> Self { + let api_key_editor = cx.new(|cx| InputField::new(_window, cx, "sk-...").label("API key")); + + let api_url_editor = cx.new(|cx| { + let input = InputField::new(_window, cx, LMSTUDIO_API_URL).label("API URL"); + input.set_text(&LmStudioLanguageModelProvider::api_url(cx), _window, cx); + input + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); Self { state, - loading_models_task, + api_key_editor, + api_url_editor, } } - fn retry_connection(&self, cx: &mut App) { + fn retry_connection(&mut self, _window: &mut Window, cx: &mut Context) { + let has_api_url = LmStudioLanguageModelProvider::has_custom_url(cx); + let has_api_key = self + .state + .read_with(cx, |state, _| state.api_key_state.has_key()); + if !has_api_url { + self.save_api_url(cx); + } + if !has_api_key { + self.save_api_key(&Default::default(), _window, cx); + } + + self.state.update(cx, |state, cx| { + state.restart_fetch_models_task(cx); + }); + } + + fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string(); + if api_key.is_empty() { + return; + } + + self.api_key_editor + .update(cx, |input, cx| input.set_text("", _window, cx)); + + let state = self.state.clone(); + cx.spawn_in(_window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(Some(api_key), cx)) + .await + }) + .detach_and_log_err(cx); + } + + fn reset_api_key(&mut self, _window: &mut Window, cx: &mut Context) { + self.api_key_editor + .update(cx, |input, cx| input.set_text("", _window, cx)); + + let state = self.state.clone(); + cx.spawn_in(_window, async move |_, cx| { + state + .update(cx, |state, cx| state.set_api_key(None, cx)) + .await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn save_api_url(&self, cx: &mut Context) { + let api_url = self.api_url_editor.read(cx).text(cx).trim().to_string(); + let current_url = 
LmStudioLanguageModelProvider::api_url(cx); + if !api_url.is_empty() && &api_url != &current_url { + self.state + .update(cx, |state, cx| state.set_api_key(None, cx)) + .detach_and_log_err(cx); + + let fs = ::global(cx); + update_settings_file(fs, cx, move |settings, _| { + settings + .language_models + .get_or_insert_default() + .lmstudio + .get_or_insert_default() + .api_url = Some(api_url); + }); + } + } + + fn reset_api_url(&mut self, _window: &mut Window, cx: &mut Context) { + self.api_url_editor + .update(cx, |input, cx| input.set_text("", _window, cx)); + + // Clear API key when URL changes since keys are URL-specific self.state - .update(cx, |state, cx| state.fetch_models(cx)) + .update(cx, |state, cx| state.set_api_key(None, cx)) .detach_and_log_err(cx); - } -} -impl Render for ConfigurationView { - fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { - let is_authenticated = self.state.read(cx).is_authenticated(); + let fs = ::global(cx); + update_settings_file(fs, cx, |settings, _cx| { + if let Some(settings) = settings + .language_models + .as_mut() + .and_then(|models| models.lmstudio.as_mut()) + { + settings.api_url = Some(LMSTUDIO_API_URL.into()); + } + }); + cx.notify(); + } - let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen."; + fn render_api_url_editor(&self, cx: &Context) -> impl IntoElement { + let api_url = LmStudioLanguageModelProvider::api_url(cx); + let custom_api_url_set = api_url != LMSTUDIO_API_URL; - if self.loading_models_task.is_some() { - div().child(Label::new("Loading models...")).into_any() + if custom_api_url_set { + h_flex() + .p_3() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().elevated_surface_background) + .child( + h_flex() + .gap_2() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(v_flex().gap_1().child(Label::new(api_url))), + ) + .child( + Button::new("reset-api-url", "Reset API URL")
.label_size(LabelSize::Small) + .icon(IconName::Undo) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .layer(ElevationIndex::ModalSurface) + .on_click( + cx.listener(|this, _, _window, cx| this.reset_api_url(_window, cx)), + ), + ) + .into_any_element() } else { v_flex() + .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| { + this.save_api_url(cx); + cx.notify(); + })) .gap_2() + .child(self.api_url_editor.clone()) + .into_any_element() + } + } + + fn render_api_key_editor(&self, cx: &Context) -> impl IntoElement { + let state = self.state.read(cx); + let env_var_set = state.api_key_state.is_from_env_var(); + let configured_card_label = if env_var_set { + format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.") + } else { + "API key configured".to_string() + }; + + if !state.api_key_state.has_key() { + v_flex() + .on_action(cx.listener(Self::save_api_key)) + .child(self.api_key_editor.clone()) .child( - v_flex().gap_1().child(Label::new(lmstudio_intro)).child( + Label::new(format!( + "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed." + )) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .into_any_element() + } else { + ConfiguredApiCard::new(configured_card_label) + .disabled(env_var_set) + .on_click(cx.listener(|this, _, _window, cx| this.reset_api_key(_window, cx))) + .when(env_var_set, |this| { + this.tooltip_label(format!( + "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable." 
+ )) + }) + .into_any_element() + } + } +} + +impl Render for ConfigurationView { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let is_authenticated = self.state.read(cx).is_authenticated(); + + v_flex() + .gap_2() + .child( + v_flex() + .gap_1() + .child(Label::new("Run local LLMs like Llama, Phi, and Qwen.")) + .child( List::new() .child(ListBulletItem::new( "LM Studio needs to be running with at least one model downloaded.", @@ -690,86 +896,100 @@ impl Render for ConfigurationView { .child(Label::new("To get your first model, try running")) .child(Label::new("lms get qwen2.5-coder-7b").inline_code(cx)), ), - ), - ) - .child( - h_flex() - .w_full() - .justify_between() - .gap_2() - .child( - h_flex() - .w_full() - .gap_2() - .map(|this| { - if is_authenticated { - this.child( - Button::new("lmstudio-site", "LM Studio") - .style(ButtonStyle::Subtle) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .on_click(move |_, _window, cx| { - cx.open_url(LMSTUDIO_SITE) - }) - .into_any_element(), - ) - } else { - this.child( - Button::new( - "download_lmstudio_button", - "Download LM Studio", - ) + ) + .child(Label::new( + "Alternatively, you can connect to an LM Studio server by specifying its \ + URL and API key (may not be required):", + )), + ) + .child(self.render_api_url_editor(cx)) + .child(self.render_api_key_editor(cx)) + .child( + h_flex() + .w_full() + .justify_between() + .gap_2() + .child( + h_flex() + .w_full() + .gap_2() + .map(|this| { + if is_authenticated { + this.child( + Button::new("lmstudio-site", "LM Studio") .style(ButtonStyle::Subtle) .icon(IconName::ArrowUpRight) .icon_size(IconSize::Small) .icon_color(Color::Muted) .on_click(move |_, _window, cx| { - cx.open_url(LMSTUDIO_DOWNLOAD_URL) + cx.open_url(LMSTUDIO_SITE) }) .into_any_element(), + ) + } else { + this.child( + Button::new( + "download_lmstudio_button", + "Download LM Studio", ) - } - }) - .child( - 
Button::new("view-models", "Model Catalog") .style(ButtonStyle::Subtle) .icon(IconName::ArrowUpRight) .icon_size(IconSize::Small) .icon_color(Color::Muted) .on_click(move |_, _window, cx| { - cx.open_url(LMSTUDIO_CATALOG_URL) - }), - ), - ) - .map(|this| { - if is_authenticated { - this.child( - ButtonLike::new("connected") - .disabled(true) - .cursor_style(gpui::CursorStyle::Arrow) - .child( - h_flex() - .gap_2() - .child(Indicator::dot().color(Color::Success)) - .child(Label::new("Connected")) - .into_any_element(), - ), - ) - } else { - this.child( - Button::new("retry_lmstudio_models", "Connect") - .icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon(IconName::PlayFilled) - .on_click(cx.listener(move |this, _, _window, cx| { - this.retry_connection(cx) - })), - ) - } - }), - ) - .into_any() - } + cx.open_url(LMSTUDIO_DOWNLOAD_URL) + }) + .into_any_element(), + ) + } + }) + .child( + Button::new("view-models", "Model Catalog") + .style(ButtonStyle::Subtle) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::Small) + .icon_color(Color::Muted) + .on_click(move |_, _window, cx| { + cx.open_url(LMSTUDIO_CATALOG_URL) + }), + ), + ) + .map(|this| { + if is_authenticated { + this.child( + ButtonLike::new("connected") + .disabled(true) + .cursor_style(CursorStyle::Arrow) + .child( + h_flex() + .gap_2() + .child(Icon::new(IconName::Check).color(Color::Success)) + .child(Label::new("Connected")) + .into_any_element(), + ) + .child( + IconButton::new("refresh-models", IconName::RotateCcw) + .tooltip(Tooltip::text("Refresh Models")) + .on_click(cx.listener(|this, _, _window, cx| { + this.state.update(cx, |state, _| { + state.available_models.clear(); + }); + this.retry_connection(_window, cx); + })), + ), + ) + } else { + this.child( + Button::new("retry_lmstudio_models", "Connect") + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon(IconName::PlayFilled) + .on_click(cx.listener(move |this, _, _window, cx| { + 
this.retry_connection(_window, cx) + })), + ) + } + }), + ) } } diff --git a/crates/lmstudio/src/lmstudio.rs b/crates/lmstudio/src/lmstudio.rs index ef2f7b6208f62e079609049b8eff83a80034741e..8a44b7fdefe5262d955606b0413b2b2425014296 100644 --- a/crates/lmstudio/src/lmstudio.rs +++ b/crates/lmstudio/src/lmstudio.rs @@ -354,14 +354,19 @@ pub struct ResponseMessageDelta { pub async fn complete( client: &dyn HttpClient, api_url: &str, + api_key: Option<&str>, request: ChatCompletionRequest, ) -> Result { let uri = format!("{api_url}/chat/completions"); - let request_builder = HttpRequest::builder() + let mut request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json"); + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key)); + } + let serialized_request = serde_json::to_string(&request)?; let request = request_builder.body(AsyncBody::from(serialized_request))?; @@ -386,14 +391,19 @@ pub async fn complete( pub async fn stream_chat_completion( client: &dyn HttpClient, api_url: &str, + api_key: Option<&str>, request: ChatCompletionRequest, ) -> Result>> { let uri = format!("{api_url}/chat/completions"); - let request_builder = http::Request::builder() + let mut request_builder = http::Request::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json"); + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key)); + } + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { @@ -434,14 +444,19 @@ pub async fn stream_chat_completion( pub async fn get_models( client: &dyn HttpClient, api_url: &str, + api_key: Option<&str>, _: Option, ) -> Result> { let uri = format!("{api_url}/models"); - let request_builder = HttpRequest::builder() + let mut 
request_builder = HttpRequest::builder() .method(Method::GET) .uri(uri) .header("Accept", "application/json"); + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key)); + } + let request = request_builder.body(AsyncBody::default())?; let mut response = client.send(request).await?; diff --git a/crates/settings_content/src/language_model.rs b/crates/settings_content/src/language_model.rs index 6af419119d819931f3ad826ff416f1b47c89824f..8ced6e0b487a673ff4dba34cae9c1e2c7ee45d13 100644 --- a/crates/settings_content/src/language_model.rs +++ b/crates/settings_content/src/language_model.rs @@ -148,6 +148,7 @@ impl Default for KeepAlive { #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)] pub struct LmStudioSettingsContent { pub api_url: Option, + pub api_key: Option, pub available_models: Option>, }