1use anyhow::{Result, anyhow};
2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::{AuthenticateError, LanguageModelCompletionEvent};
6use language_model::{
7 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
8 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
9 LanguageModelRequest, RateLimiter, Role,
10};
11use lmstudio::{
12 ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
13 stream_chat_completion,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use settings::{Settings, SettingsStore};
18use std::{collections::BTreeMap, sync::Arc};
19use ui::{ButtonLike, Indicator, prelude::*};
20use util::ResultExt;
21
22use crate::AllLanguageModelSettings;
23
/// Where users can download the LM Studio desktop app.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
/// Browsable catalog of models available through LM Studio.
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
/// LM Studio homepage.
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

/// Stable machine identifier for this provider (used in ids and telemetry).
const PROVIDER_ID: &str = "lmstudio";
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: &str = "LM Studio";
30
/// User-configurable settings for the LM Studio provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// Base URL of the local LM Studio server's API.
    pub api_url: String,
    /// Models declared in settings; these override same-named models
    /// reported by the server (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
36
/// A model declared manually in Zed settings rather than discovered from the
/// running LM Studio server.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
46
/// Language-model provider backed by a locally running LM Studio server.
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    /// Shared provider state (fetched model list + settings subscription).
    state: gpui::Entity<State>,
}
51
/// Observable provider state: the model list fetched from the server and the
/// machinery to keep it fresh.
pub struct State {
    http_client: Arc<dyn HttpClient>,
    /// Models reported by the LM Studio server (empty until a fetch succeeds).
    available_models: Vec<lmstudio::Model>,
    /// Most recent fetch task; kept alive so it isn't dropped mid-flight.
    fetch_model_task: Option<Task<Result<()>>>,
    /// Settings-store subscription that triggers refetches on settings changes.
    _subscription: Subscription,
}
58
59impl State {
60 fn is_authenticated(&self) -> bool {
61 !self.available_models.is_empty()
62 }
63
64 fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
65 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
66 let http_client = self.http_client.clone();
67 let api_url = settings.api_url.clone();
68
69 // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
70 cx.spawn(async move |this, cx| {
71 let models = get_models(http_client.as_ref(), &api_url, None).await?;
72
73 let mut models: Vec<lmstudio::Model> = models
74 .into_iter()
75 .filter(|model| model.r#type != ModelType::Embeddings)
76 .map(|model| lmstudio::Model::new(&model.id, None, None))
77 .collect();
78
79 models.sort_by(|a, b| a.name.cmp(&b.name));
80
81 this.update(cx, |this, cx| {
82 this.available_models = models;
83 cx.notify();
84 })
85 })
86 }
87
88 fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
89 let task = self.fetch_models(cx);
90 self.fetch_model_task.replace(task);
91 }
92
93 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
94 if self.is_authenticated() {
95 return Task::ready(Ok(()));
96 }
97
98 let fetch_models_task = self.fetch_models(cx);
99 cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
100 }
101}
102
impl LmStudioLanguageModelProvider {
    /// Creates the provider, subscribing to settings changes and immediately
    /// kicking off a model fetch against the configured server.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    // Snapshot the current LM Studio settings so the observer
                    // can detect when that section actually changes.
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        // Only refetch when the LM Studio settings changed;
                        // unrelated global settings updates are ignored.
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        // Populate the model list right away rather than waiting for a
        // settings change or an explicit authenticate call.
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}
133
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's state entity so the UI can observe changes
    /// (e.g. the model list updating after a fetch).
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
141
142impl LanguageModelProvider for LmStudioLanguageModelProvider {
143 fn id(&self) -> LanguageModelProviderId {
144 LanguageModelProviderId(PROVIDER_ID.into())
145 }
146
147 fn name(&self) -> LanguageModelProviderName {
148 LanguageModelProviderName(PROVIDER_NAME.into())
149 }
150
151 fn icon(&self) -> IconName {
152 IconName::AiLmStudio
153 }
154
155 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
156 self.provided_models(cx).into_iter().next()
157 }
158
159 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
160 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
161
162 // Add models from the LM Studio API
163 for model in self.state.read(cx).available_models.iter() {
164 models.insert(model.name.clone(), model.clone());
165 }
166
167 // Override with available models from settings
168 for model in AllLanguageModelSettings::get_global(cx)
169 .lmstudio
170 .available_models
171 .iter()
172 {
173 models.insert(
174 model.name.clone(),
175 lmstudio::Model {
176 name: model.name.clone(),
177 display_name: model.display_name.clone(),
178 max_tokens: model.max_tokens,
179 },
180 );
181 }
182
183 models
184 .into_values()
185 .map(|model| {
186 Arc::new(LmStudioLanguageModel {
187 id: LanguageModelId::from(model.name.clone()),
188 model: model.clone(),
189 http_client: self.http_client.clone(),
190 request_limiter: RateLimiter::new(4),
191 }) as Arc<dyn LanguageModel>
192 })
193 .collect()
194 }
195
196 fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
197 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
198 let http_client = self.http_client.clone();
199 let api_url = settings.api_url.clone();
200 let id = model.id().0.to_string();
201 cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
202 .detach_and_log_err(cx);
203 }
204
205 fn is_authenticated(&self, cx: &App) -> bool {
206 self.state.read(cx).is_authenticated()
207 }
208
209 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
210 self.state.update(cx, |state, cx| state.authenticate(cx))
211 }
212
213 fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
214 let state = self.state.clone();
215 cx.new(|cx| ConfigurationView::new(state, cx)).into()
216 }
217
218 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
219 self.state.update(cx, |state, cx| state.fetch_models(cx))
220 }
221}
222
/// A single LM Studio-served model exposed through the `LanguageModel` trait.
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    // Caps concurrent requests to this model (constructed with a limit of 4).
    request_limiter: RateLimiter,
}
229
230impl LmStudioLanguageModel {
231 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
232 ChatCompletionRequest {
233 model: self.model.name.clone(),
234 messages: request
235 .messages
236 .into_iter()
237 .map(|msg| match msg.role {
238 Role::User => ChatMessage::User {
239 content: msg.string_contents(),
240 },
241 Role::Assistant => ChatMessage::Assistant {
242 content: Some(msg.string_contents()),
243 tool_calls: None,
244 },
245 Role::System => ChatMessage::System {
246 content: msg.string_contents(),
247 },
248 })
249 .collect(),
250 stream: true,
251 max_tokens: Some(-1),
252 stop: Some(request.stop),
253 temperature: request.temperature.or(Some(0.0)),
254 tools: vec![],
255 }
256 }
257}
258
259impl LanguageModel for LmStudioLanguageModel {
260 fn id(&self) -> LanguageModelId {
261 self.id.clone()
262 }
263
264 fn name(&self) -> LanguageModelName {
265 LanguageModelName::from(self.model.display_name().to_string())
266 }
267
268 fn provider_id(&self) -> LanguageModelProviderId {
269 LanguageModelProviderId(PROVIDER_ID.into())
270 }
271
272 fn provider_name(&self) -> LanguageModelProviderName {
273 LanguageModelProviderName(PROVIDER_NAME.into())
274 }
275
276 fn supports_tools(&self) -> bool {
277 false
278 }
279
280 fn telemetry_id(&self) -> String {
281 format!("lmstudio/{}", self.model.id())
282 }
283
284 fn max_token_count(&self) -> usize {
285 self.model.max_token_count()
286 }
287
288 fn count_tokens(
289 &self,
290 request: LanguageModelRequest,
291 _cx: &App,
292 ) -> BoxFuture<'static, Result<usize>> {
293 // Endpoint for this is coming soon. In the meantime, hacky estimation
294 let token_count = request
295 .messages
296 .iter()
297 .map(|msg| msg.string_contents().split_whitespace().count())
298 .sum::<usize>();
299
300 let estimated_tokens = (token_count as f64 * 0.75) as usize;
301 async move { Ok(estimated_tokens) }.boxed()
302 }
303
304 fn stream_completion(
305 &self,
306 request: LanguageModelRequest,
307 cx: &AsyncApp,
308 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
309 let request = self.to_lmstudio_request(request);
310
311 let http_client = self.http_client.clone();
312 let Ok(api_url) = cx.update(|cx| {
313 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
314 settings.api_url.clone()
315 }) else {
316 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
317 };
318
319 let future = self.request_limiter.stream(async move {
320 let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
321 let stream = response
322 .filter_map(|response| async move {
323 match response {
324 Ok(fragment) => {
325 // Skip empty deltas
326 if fragment.choices[0].delta.is_object()
327 && fragment.choices[0].delta.as_object().unwrap().is_empty()
328 {
329 return None;
330 }
331
332 // Try to parse the delta as ChatMessage
333 if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
334 fragment.choices[0].delta.clone(),
335 ) {
336 let content = match chat_message {
337 ChatMessage::User { content } => content,
338 ChatMessage::Assistant { content, .. } => {
339 content.unwrap_or_default()
340 }
341 ChatMessage::System { content } => content,
342 };
343 if !content.is_empty() {
344 Some(Ok(content))
345 } else {
346 None
347 }
348 } else {
349 None
350 }
351 }
352 Err(error) => Some(Err(error)),
353 }
354 })
355 .boxed();
356 Ok(stream)
357 });
358
359 async move {
360 Ok(future
361 .await?
362 .map(|result| result.map(LanguageModelCompletionEvent::Text))
363 .boxed())
364 }
365 .boxed()
366 }
367
368 fn use_any_tool(
369 &self,
370 _request: LanguageModelRequest,
371 _tool_name: String,
372 _tool_description: String,
373 _schema: serde_json::Value,
374 _cx: &AsyncApp,
375 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
376 async move { Ok(futures::stream::empty().boxed()) }.boxed()
377 }
378}
379
/// Settings-panel view for the LM Studio provider.
struct ConfigurationView {
    state: gpui::Entity<State>,
    // Some while the initial model fetch is running; cleared when it finishes
    // so `render` stops showing the loading placeholder.
    loading_models_task: Option<Task<()>>,
}
384
385impl ConfigurationView {
386 pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
387 let loading_models_task = Some(cx.spawn({
388 let state = state.clone();
389 async move |this, cx| {
390 if let Some(task) = state
391 .update(cx, |state, cx| state.authenticate(cx))
392 .log_err()
393 {
394 task.await.log_err();
395 }
396 this.update(cx, |this, cx| {
397 this.loading_models_task = None;
398 cx.notify();
399 })
400 .log_err();
401 }
402 }));
403
404 Self {
405 state,
406 loading_models_task,
407 }
408 }
409
410 fn retry_connection(&self, cx: &mut App) {
411 self.state
412 .update(cx, |state, cx| state.fetch_models(cx))
413 .detach_and_log_err(cx);
414 }
415}
416
impl Render for ConfigurationView {
    /// Renders either a "Loading models..." placeholder (while the initial
    /// fetch is in flight) or the full configuration panel: intro text, a
    /// sample `lms get` command, site/download/catalog links, and either a
    /// "Connected" indicator or a "Connect" retry button.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
        let lmstudio_reqs = "To use LM Studio as a provider for Zed assistant, it needs to be running with at least one model downloaded.";

        // Subtle background for the inline `lms get ...` command sample.
        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(lmstudio_intro))
                        .child(Label::new(lmstudio_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("To get your first model, try running"))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_sm()
                                        .child(Label::new("lms get qwen2.5-coder-7b")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                // Link to the site when connected, otherwise
                                // prompt the user to download LM Studio.
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is correct
                            // it should stay disabled
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_lmstudio_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, _window, cx| {
                                    this.retry_connection(cx)
                                }))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}