@@ -1,26 +1,30 @@
use anyhow::{Result, anyhow};
use collections::HashMap;
+use fs::Fs;
use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task};
+use gpui::{AnyView, App, AsyncApp, Context, CursorStyle, Entity, Subscription, Task};
use http_client::HttpClient;
use language_model::{
- AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
- LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
- StopReason, TokenUsage,
+ ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
+ LanguageModelCompletionEvent, LanguageModelToolChoice, LanguageModelToolResultContent,
+ LanguageModelToolUse, MessageContent, StopReason, TokenUsage, env_var,
};
use language_model::{
- IconOrSvg, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
- LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelRequest, RateLimiter, Role,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
-use lmstudio::{ModelType, get_models};
+use lmstudio::{LMSTUDIO_API_URL, ModelType, get_models};
+
pub use settings::LmStudioAvailableModel as AvailableModel;
-use settings::{Settings, SettingsStore};
+use settings::{Settings, SettingsStore, update_settings_file};
use std::pin::Pin;
+use std::sync::LazyLock;
use std::{collections::BTreeMap, sync::Arc};
-use ui::{ButtonLike, Indicator, List, ListBulletItem, prelude::*};
-use util::ResultExt;
+use ui::{
+ ButtonLike, ConfiguredApiCard, ElevationIndex, List, ListBulletItem, Tooltip, prelude::*,
+};
+use ui_input::InputField;
use crate::AllLanguageModelSettings;
use crate::provider::util::parse_tool_arguments;
@@ -32,6 +36,9 @@ const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("lmstudio");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("LM Studio");
+const API_KEY_ENV_VAR_NAME: &str = "LMSTUDIO_API_KEY";
+static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
+
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
pub api_url: String,
@@ -44,6 +51,7 @@ pub struct LmStudioLanguageModelProvider {
}
pub struct State {
+ api_key_state: ApiKeyState,
http_client: Arc<dyn HttpClient>,
available_models: Vec<lmstudio::Model>,
fetch_model_task: Option<Task<Result<()>>>,
@@ -55,14 +63,25 @@ impl State {
!self.available_models.is_empty()
}
+ fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
+ let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
+ let task = self
+ .api_key_state
+ .store(api_url, api_key, |this| &mut this.api_key_state, cx);
+ self.restart_fetch_models_task(cx);
+ task
+ }
+
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
+ let api_key = self.api_key_state.key(&api_url);
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
cx.spawn(async move |this, cx| {
- let models = get_models(http_client.as_ref(), &api_url, None).await?;
+ let models =
+ get_models(http_client.as_ref(), &api_url, api_key.as_deref(), None).await?;
let mut models: Vec<lmstudio::Model> = models
.into_iter()
@@ -95,6 +114,11 @@ impl State {
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
+ let api_url = LmStudioLanguageModelProvider::api_url(cx).into();
+ let _task = self
+ .api_key_state
+ .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
+
if self.is_authenticated() {
return Task::ready(Ok(()));
}
@@ -145,6 +169,10 @@ impl LmStudioLanguageModelProvider {
});
State {
+ api_key_state: ApiKeyState::new(
+ Self::api_url(cx).into(),
+ (*API_KEY_ENV_VAR).clone(),
+ ),
http_client,
available_models: Default::default(),
fetch_model_task: None,
@@ -156,6 +184,17 @@ impl LmStudioLanguageModelProvider {
.update(cx, |state, cx| state.restart_fetch_models_task(cx));
this
}
+
+ fn api_url(cx: &App) -> String {
+ AllLanguageModelSettings::get_global(cx)
+ .lmstudio
+ .api_url
+ .clone()
+ }
+
+ fn has_custom_url(cx: &App) -> bool {
+ Self::api_url(cx) != LMSTUDIO_API_URL
+ }
}
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
@@ -225,6 +264,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
model,
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
+ state: self.state.clone(),
}) as Arc<dyn LanguageModel>
})
.collect()
@@ -244,12 +284,13 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
_window: &mut Window,
cx: &mut App,
) -> AnyView {
- let state = self.state.clone();
- cx.new(|cx| ConfigurationView::new(state, cx)).into()
+ cx.new(|cx| ConfigurationView::new(self.state.clone(), _window, cx))
+ .into()
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.state.update(cx, |state, cx| state.fetch_models(cx))
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
}
}
@@ -258,6 +299,7 @@ pub struct LmStudioLanguageModel {
model: lmstudio::Model,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
+ state: Entity<State>,
}
impl LmStudioLanguageModel {
@@ -376,15 +418,20 @@ impl LmStudioLanguageModel {
Result<futures::stream::BoxStream<'static, Result<lmstudio::ResponseStreamEvent>>>,
> {
let http_client = self.http_client.clone();
- let api_url = cx.update(|cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
- settings.api_url.clone()
+ let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
+ let api_url = LmStudioLanguageModelProvider::api_url(cx);
+ (state.api_key_state.key(&api_url), api_url)
});
let future = self.request_limiter.stream(async move {
- let request = lmstudio::stream_chat_completion(http_client.as_ref(), &api_url, request);
- let response = request.await?;
- Ok(response)
+ let stream = lmstudio::stream_chat_completion(
+ http_client.as_ref(),
+ &api_url,
+ api_key.as_deref(),
+ request,
+ )
+ .await?;
+ Ok(stream)
});
async move { Ok(future.await?.boxed()) }.boxed()
@@ -634,53 +681,212 @@ fn add_message_content_part(
struct ConfigurationView {
state: Entity<State>,
- loading_models_task: Option<Task<()>>,
+ api_key_editor: Entity<InputField>,
+ api_url_editor: Entity<InputField>,
}
impl ConfigurationView {
- pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
- let loading_models_task = Some(cx.spawn({
- let state = state.clone();
- async move |this, cx| {
- state
- .update(cx, |state, cx| state.authenticate(cx))
- .await
- .log_err();
-
- this.update(cx, |this, cx| {
- this.loading_models_task = None;
- cx.notify();
- })
- .log_err();
- }
- }));
+ pub fn new(state: Entity<State>, _window: &mut Window, cx: &mut Context<Self>) -> Self {
+ let api_key_editor = cx.new(|cx| InputField::new(_window, cx, "sk-...").label("API key"));
+
+ let api_url_editor = cx.new(|cx| {
+ let input = InputField::new(_window, cx, LMSTUDIO_API_URL).label("API URL");
+ input.set_text(&LmStudioLanguageModelProvider::api_url(cx), _window, cx);
+ input
+ });
+
+ cx.observe(&state, |_, _, cx| {
+ cx.notify();
+ })
+ .detach();
Self {
state,
- loading_models_task,
+ api_key_editor,
+ api_url_editor,
}
}
- fn retry_connection(&self, cx: &mut App) {
+ fn retry_connection(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
+ let has_api_url = LmStudioLanguageModelProvider::has_custom_url(cx);
+ let has_api_key = self
+ .state
+ .read_with(cx, |state, _| state.api_key_state.has_key());
+ if !has_api_url {
+ self.save_api_url(cx);
+ }
+ if !has_api_key {
+ self.save_api_key(&Default::default(), _window, cx);
+ }
+
+ self.state.update(cx, |state, cx| {
+ state.restart_fetch_models_task(cx);
+ });
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
+ let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
+ if api_key.is_empty() {
+ return;
+ }
+
+ self.api_key_editor
+ .update(cx, |input, cx| input.set_text("", _window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(_window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
+ .await
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn reset_api_key(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
+ self.api_key_editor
+ .update(cx, |input, cx| input.set_text("", _window, cx));
+
+ let state = self.state.clone();
+ cx.spawn_in(_window, async move |_, cx| {
+ state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
+ .await
+ })
+ .detach_and_log_err(cx);
+
+ cx.notify();
+ }
+
+ fn save_api_url(&self, cx: &mut Context<Self>) {
+ let api_url = self.api_url_editor.read(cx).text(cx).trim().to_string();
+ let current_url = LmStudioLanguageModelProvider::api_url(cx);
+ if !api_url.is_empty() && &api_url != ¤t_url {
+ self.state
+ .update(cx, |state, cx| state.set_api_key(None, cx))
+ .detach_and_log_err(cx);
+
+ let fs = <dyn Fs>::global(cx);
+ update_settings_file(fs, cx, move |settings, _| {
+ settings
+ .language_models
+ .get_or_insert_default()
+ .lmstudio
+ .get_or_insert_default()
+ .api_url = Some(api_url);
+ });
+ }
+ }
+
+ fn reset_api_url(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
+ self.api_url_editor
+ .update(cx, |input, cx| input.set_text("", _window, cx));
+
+ // Clear API key when URL changes since keys are URL-specific
self.state
- .update(cx, |state, cx| state.fetch_models(cx))
+ .update(cx, |state, cx| state.set_api_key(None, cx))
.detach_and_log_err(cx);
- }
-}
-impl Render for ConfigurationView {
- fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
- let is_authenticated = self.state.read(cx).is_authenticated();
+ let fs = <dyn Fs>::global(cx);
+ update_settings_file(fs, cx, |settings, _cx| {
+ if let Some(settings) = settings
+ .language_models
+ .as_mut()
+ .and_then(|models| models.lmstudio.as_mut())
+ {
+ settings.api_url = Some(LMSTUDIO_API_URL.into());
+ }
+ });
+ cx.notify();
+ }
- let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
+ fn render_api_url_editor(&self, cx: &Context<Self>) -> impl IntoElement {
+ let api_url = LmStudioLanguageModelProvider::api_url(cx);
+ let custom_api_url_set = api_url != LMSTUDIO_API_URL;
- if self.loading_models_task.is_some() {
- div().child(Label::new("Loading models...")).into_any()
+ if custom_api_url_set {
+ h_flex()
+ .p_3()
+ .justify_between()
+ .rounded_md()
+ .border_1()
+ .border_color(cx.theme().colors().border)
+ .bg(cx.theme().colors().elevated_surface_background)
+ .child(
+ h_flex()
+ .gap_2()
+ .child(Icon::new(IconName::Check).color(Color::Success))
+ .child(v_flex().gap_1().child(Label::new(api_url))),
+ )
+ .child(
+ Button::new("reset-api-url", "Reset API URL")
+ .label_size(LabelSize::Small)
+ .icon(IconName::Undo)
+ .icon_size(IconSize::Small)
+ .icon_position(IconPosition::Start)
+ .layer(ElevationIndex::ModalSurface)
+ .on_click(
+ cx.listener(|this, _, _window, cx| this.reset_api_url(_window, cx)),
+ ),
+ )
+ .into_any_element()
} else {
v_flex()
+ .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
+ this.save_api_url(cx);
+ cx.notify();
+ }))
.gap_2()
+ .child(self.api_url_editor.clone())
+ .into_any_element()
+ }
+ }
+
+ fn render_api_key_editor(&self, cx: &Context<Self>) -> impl IntoElement {
+ let state = self.state.read(cx);
+ let env_var_set = state.api_key_state.is_from_env_var();
+ let configured_card_label = if env_var_set {
+ format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
+ } else {
+ "API key configured".to_string()
+ };
+
+ if !state.api_key_state.has_key() {
+ v_flex()
+ .on_action(cx.listener(Self::save_api_key))
+ .child(self.api_key_editor.clone())
.child(
- v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
+ Label::new(format!(
+ "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
+ ))
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element()
+ } else {
+ ConfiguredApiCard::new(configured_card_label)
+ .disabled(env_var_set)
+ .on_click(cx.listener(|this, _, _window, cx| this.reset_api_key(_window, cx)))
+ .when(env_var_set, |this| {
+ this.tooltip_label(format!(
+ "To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."
+ ))
+ })
+ .into_any_element()
+ }
+ }
+}
+
+impl Render for ConfigurationView {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
+ let is_authenticated = self.state.read(cx).is_authenticated();
+
+ v_flex()
+ .gap_2()
+ .child(
+ v_flex()
+ .gap_1()
+ .child(Label::new("Run local LLMs like Llama, Phi, and Qwen."))
+ .child(
List::new()
.child(ListBulletItem::new(
"LM Studio needs to be running with at least one model downloaded.",
@@ -690,86 +896,100 @@ impl Render for ConfigurationView {
.child(Label::new("To get your first model, try running"))
.child(Label::new("lms get qwen2.5-coder-7b").inline_code(cx)),
),
- ),
- )
- .child(
- h_flex()
- .w_full()
- .justify_between()
- .gap_2()
- .child(
- h_flex()
- .w_full()
- .gap_2()
- .map(|this| {
- if is_authenticated {
- this.child(
- Button::new("lmstudio-site", "LM Studio")
- .style(ButtonStyle::Subtle)
- .icon(IconName::ArrowUpRight)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .on_click(move |_, _window, cx| {
- cx.open_url(LMSTUDIO_SITE)
- })
- .into_any_element(),
- )
- } else {
- this.child(
- Button::new(
- "download_lmstudio_button",
- "Download LM Studio",
- )
+ )
+ .child(Label::new(
+ "Alternatively, you can connect to an LM Studio server by specifying its \
+ URL and API key (may not be required):",
+ )),
+ )
+ .child(self.render_api_url_editor(cx))
+ .child(self.render_api_key_editor(cx))
+ .child(
+ h_flex()
+ .w_full()
+ .justify_between()
+ .gap_2()
+ .child(
+ h_flex()
+ .w_full()
+ .gap_2()
+ .map(|this| {
+ if is_authenticated {
+ this.child(
+ Button::new("lmstudio-site", "LM Studio")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| {
- cx.open_url(LMSTUDIO_DOWNLOAD_URL)
+ cx.open_url(LMSTUDIO_SITE)
})
.into_any_element(),
+ )
+ } else {
+ this.child(
+ Button::new(
+ "download_lmstudio_button",
+ "Download LM Studio",
)
- }
- })
- .child(
- Button::new("view-models", "Model Catalog")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _window, cx| {
- cx.open_url(LMSTUDIO_CATALOG_URL)
- }),
- ),
- )
- .map(|this| {
- if is_authenticated {
- this.child(
- ButtonLike::new("connected")
- .disabled(true)
- .cursor_style(gpui::CursorStyle::Arrow)
- .child(
- h_flex()
- .gap_2()
- .child(Indicator::dot().color(Color::Success))
- .child(Label::new("Connected"))
- .into_any_element(),
- ),
- )
- } else {
- this.child(
- Button::new("retry_lmstudio_models", "Connect")
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::XSmall)
- .icon(IconName::PlayFilled)
- .on_click(cx.listener(move |this, _, _window, cx| {
- this.retry_connection(cx)
- })),
- )
- }
- }),
- )
- .into_any()
- }
+ cx.open_url(LMSTUDIO_DOWNLOAD_URL)
+ })
+ .into_any_element(),
+ )
+ }
+ })
+ .child(
+ Button::new("view-models", "Model Catalog")
+ .style(ButtonStyle::Subtle)
+ .icon(IconName::ArrowUpRight)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .on_click(move |_, _window, cx| {
+ cx.open_url(LMSTUDIO_CATALOG_URL)
+ }),
+ ),
+ )
+ .map(|this| {
+ if is_authenticated {
+ this.child(
+ ButtonLike::new("connected")
+ .disabled(true)
+ .cursor_style(CursorStyle::Arrow)
+ .child(
+ h_flex()
+ .gap_2()
+ .child(Icon::new(IconName::Check).color(Color::Success))
+ .child(Label::new("Connected"))
+ .into_any_element(),
+ )
+ .child(
+ IconButton::new("refresh-models", IconName::RotateCcw)
+ .tooltip(Tooltip::text("Refresh Models"))
+ .on_click(cx.listener(|this, _, _window, cx| {
+ this.state.update(cx, |state, _| {
+ state.available_models.clear();
+ });
+ this.retry_connection(_window, cx);
+ })),
+ ),
+ )
+ } else {
+ this.child(
+ Button::new("retry_lmstudio_models", "Connect")
+ .icon_position(IconPosition::Start)
+ .icon_size(IconSize::XSmall)
+ .icon(IconName::PlayFilled)
+ .on_click(cx.listener(move |this, _, _window, cx| {
+ this.retry_connection(_window, cx)
+ })),
+ )
+ }
+ }),
+ )
}
}