1use anyhow::{anyhow, Result};
2use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::{AuthenticateError, LanguageModelCompletionEvent};
6use language_model::{
7 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
8 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
9 LanguageModelRequest, RateLimiter, Role,
10};
11use lmstudio::{
12 get_models, preload_model, stream_chat_completion, ChatCompletionRequest, ChatMessage,
13 ModelType,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use settings::{Settings, SettingsStore};
18use std::{collections::BTreeMap, sync::Arc};
19use ui::{prelude::*, ButtonLike, Indicator};
20use util::ResultExt;
21
22use crate::AllLanguageModelSettings;
23
// URL for downloading the LM Studio desktop app, shown when no server is reachable.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
// Browsable catalog of models LM Studio can run; linked from the config view.
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
// LM Studio home page; linked from the config view once connected.
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

// Stable identifier for this provider (used as the `LanguageModelProviderId`).
const PROVIDER_ID: &str = "lmstudio";
// Human-readable provider name shown in the UI.
const PROVIDER_NAME: &str = "LM Studio";
30
/// User-configurable settings for the LM Studio provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// Base URL of the local LM Studio server's API.
    pub api_url: String,
    /// Models declared in settings; these override same-named models
    /// reported by the server (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
36
/// A model entry as declared in Zed's settings for the LM Studio provider.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
46
/// Language-model provider backed by a locally running LM Studio server.
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared state entity: fetched model list + settings-change subscription.
    state: gpui::Entity<State>,
}
51
/// Shared provider state: the model list fetched from the server plus the
/// machinery that keeps it up to date.
pub struct State {
    http_client: Arc<dyn HttpClient>,
    // Models reported by the server (non-embedding only, sorted by name).
    available_models: Vec<lmstudio::Model>,
    // Holds the in-flight fetch task; replaced whenever a re-fetch starts.
    fetch_model_task: Option<Task<Result<()>>>,
    // Keeps the settings-change observer alive for the lifetime of this state.
    _subscription: Subscription,
}
58
impl State {
    /// "Authenticated" here means the server responded with at least one model;
    /// there are no real credentials for a local server.
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    /// Queries the LM Studio server for its model list and stores the result,
    /// filtering out embedding models and sorting by name.
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                // Embedding models are excluded; this provider only surfaces
                // models usable for chat completions.
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    /// Starts a new fetch, replacing (and thereby dropping — presumably
    /// cancelling) any previous in-flight fetch task.
    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    /// For a local server, authenticating is just a successful model fetch.
    /// No-op if we already have models.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(|_this, _cx| async move { Ok(fetch_models_task.await?) })
    }
}
102
impl LmStudioLanguageModelProvider {
    /// Creates the provider and kicks off an initial model fetch.
    ///
    /// Also subscribes to global settings changes so that an edited LM Studio
    /// configuration (e.g. a new `api_url`) triggers a re-fetch.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    // Capture the last-seen settings so we only re-fetch when
                    // the LM Studio section actually changed.
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        // Eagerly fetch models so the provider is usable without explicit setup.
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}
133
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the internal state entity so callers can observe model-list changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
141
142impl LanguageModelProvider for LmStudioLanguageModelProvider {
143 fn id(&self) -> LanguageModelProviderId {
144 LanguageModelProviderId(PROVIDER_ID.into())
145 }
146
147 fn name(&self) -> LanguageModelProviderName {
148 LanguageModelProviderName(PROVIDER_NAME.into())
149 }
150
151 fn icon(&self) -> IconName {
152 IconName::AiLmStudio
153 }
154
155 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
156 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
157
158 // Add models from the LM Studio API
159 for model in self.state.read(cx).available_models.iter() {
160 models.insert(model.name.clone(), model.clone());
161 }
162
163 // Override with available models from settings
164 for model in AllLanguageModelSettings::get_global(cx)
165 .lmstudio
166 .available_models
167 .iter()
168 {
169 models.insert(
170 model.name.clone(),
171 lmstudio::Model {
172 name: model.name.clone(),
173 display_name: model.display_name.clone(),
174 max_tokens: model.max_tokens,
175 },
176 );
177 }
178
179 models
180 .into_values()
181 .map(|model| {
182 Arc::new(LmStudioLanguageModel {
183 id: LanguageModelId::from(model.name.clone()),
184 model: model.clone(),
185 http_client: self.http_client.clone(),
186 request_limiter: RateLimiter::new(4),
187 }) as Arc<dyn LanguageModel>
188 })
189 .collect()
190 }
191
192 fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
193 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
194 let http_client = self.http_client.clone();
195 let api_url = settings.api_url.clone();
196 let id = model.id().0.to_string();
197 cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
198 .detach_and_log_err(cx);
199 }
200
201 fn is_authenticated(&self, cx: &App) -> bool {
202 self.state.read(cx).is_authenticated()
203 }
204
205 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
206 self.state.update(cx, |state, cx| state.authenticate(cx))
207 }
208
209 fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
210 let state = self.state.clone();
211 cx.new(|cx| ConfigurationView::new(state, cx)).into()
212 }
213
214 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
215 self.state.update(cx, |state, cx| state.fetch_models(cx))
216 }
217}
218
/// A single LM Studio-hosted model exposed to Zed as a `LanguageModel`.
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    // Limits concurrent completion requests to this model (constructed with 4).
    request_limiter: RateLimiter,
}
225
226impl LmStudioLanguageModel {
227 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
228 ChatCompletionRequest {
229 model: self.model.name.clone(),
230 messages: request
231 .messages
232 .into_iter()
233 .map(|msg| match msg.role {
234 Role::User => ChatMessage::User {
235 content: msg.string_contents(),
236 },
237 Role::Assistant => ChatMessage::Assistant {
238 content: Some(msg.string_contents()),
239 tool_calls: None,
240 },
241 Role::System => ChatMessage::System {
242 content: msg.string_contents(),
243 },
244 })
245 .collect(),
246 stream: true,
247 max_tokens: Some(-1),
248 stop: Some(request.stop),
249 temperature: request.temperature.or(Some(0.0)),
250 tools: vec![],
251 }
252 }
253}
254
255impl LanguageModel for LmStudioLanguageModel {
256 fn id(&self) -> LanguageModelId {
257 self.id.clone()
258 }
259
260 fn name(&self) -> LanguageModelName {
261 LanguageModelName::from(self.model.display_name().to_string())
262 }
263
264 fn provider_id(&self) -> LanguageModelProviderId {
265 LanguageModelProviderId(PROVIDER_ID.into())
266 }
267
268 fn provider_name(&self) -> LanguageModelProviderName {
269 LanguageModelProviderName(PROVIDER_NAME.into())
270 }
271
272 fn telemetry_id(&self) -> String {
273 format!("lmstudio/{}", self.model.id())
274 }
275
276 fn max_token_count(&self) -> usize {
277 self.model.max_token_count()
278 }
279
280 fn count_tokens(
281 &self,
282 request: LanguageModelRequest,
283 _cx: &App,
284 ) -> BoxFuture<'static, Result<usize>> {
285 // Endpoint for this is coming soon. In the meantime, hacky estimation
286 let token_count = request
287 .messages
288 .iter()
289 .map(|msg| msg.string_contents().split_whitespace().count())
290 .sum::<usize>();
291
292 let estimated_tokens = (token_count as f64 * 0.75) as usize;
293 async move { Ok(estimated_tokens) }.boxed()
294 }
295
296 fn stream_completion(
297 &self,
298 request: LanguageModelRequest,
299 cx: &AsyncApp,
300 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
301 let request = self.to_lmstudio_request(request);
302
303 let http_client = self.http_client.clone();
304 let Ok(api_url) = cx.update(|cx| {
305 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
306 settings.api_url.clone()
307 }) else {
308 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
309 };
310
311 let future = self.request_limiter.stream(async move {
312 let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
313 let stream = response
314 .filter_map(|response| async move {
315 match response {
316 Ok(fragment) => {
317 // Skip empty deltas
318 if fragment.choices[0].delta.is_object()
319 && fragment.choices[0].delta.as_object().unwrap().is_empty()
320 {
321 return None;
322 }
323
324 // Try to parse the delta as ChatMessage
325 if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
326 fragment.choices[0].delta.clone(),
327 ) {
328 let content = match chat_message {
329 ChatMessage::User { content } => content,
330 ChatMessage::Assistant { content, .. } => {
331 content.unwrap_or_default()
332 }
333 ChatMessage::System { content } => content,
334 };
335 if !content.is_empty() {
336 Some(Ok(content))
337 } else {
338 None
339 }
340 } else {
341 None
342 }
343 }
344 Err(error) => Some(Err(error)),
345 }
346 })
347 .boxed();
348 Ok(stream)
349 });
350
351 async move {
352 Ok(future
353 .await?
354 .map(|result| result.map(LanguageModelCompletionEvent::Text))
355 .boxed())
356 }
357 .boxed()
358 }
359
360 fn use_any_tool(
361 &self,
362 _request: LanguageModelRequest,
363 _tool_name: String,
364 _tool_description: String,
365 _schema: serde_json::Value,
366 _cx: &AsyncApp,
367 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
368 async move { Ok(futures::stream::empty().boxed()) }.boxed()
369 }
370}
371
/// Settings-panel view for the LM Studio provider.
struct ConfigurationView {
    state: gpui::Entity<State>,
    // `Some` while the initial model fetch runs; cleared when it finishes,
    // which switches the view out of its loading state.
    loading_models_task: Option<Task<()>>,
}
376
impl ConfigurationView {
    /// Creates the view and immediately starts authenticating (i.e. fetching
    /// the model list). The spawned task clears `loading_models_task` when it
    /// finishes so `render` stops showing the loading state.
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // Failures are logged only; the view falls back to the
                    // disconnected UI with a retry button.
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    /// Re-attempts connection to the server by re-fetching the model list.
    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}
408
impl Render for ConfigurationView {
    /// Renders either a "Loading models..." placeholder (while the initial
    /// fetch is in flight) or the full setup UI: intro text, a sample
    /// `lms get` command, site/download/catalog links, and a connection
    /// status or retry control.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
        let lmstudio_reqs =
            "To use LM Studio as a provider for Zed assistant, it needs to be running with at least one model downloaded.";

        // Subtle background for the inline `lms get ...` command snippet.
        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(lmstudio_intro))
                        .child(Label::new(lmstudio_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("To get your first model, try running"))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("lms get qwen2.5-coder-7b")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                // When connected, link to the LM Studio site;
                                // otherwise link to the installer download.
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is correct
                            // it should stay disabled
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_lmstudio_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, _window, cx| {
                                    this.retry_connection(cx)
                                }))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}