1use anyhow::{anyhow, Result};
2use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::{AuthenticateError, LanguageModelCompletionEvent};
6use language_model::{
7 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
8 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
9 LanguageModelRequest, RateLimiter, Role,
10};
11use lmstudio::{
12 get_models, preload_model, stream_chat_completion, ChatCompletionRequest, ChatMessage,
13 ModelType,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use settings::{Settings, SettingsStore};
18use std::{collections::BTreeMap, sync::Arc};
19use ui::{prelude::*, ButtonLike, Indicator};
20use util::ResultExt;
21
22use crate::AllLanguageModelSettings;
23
/// Where users can download the LM Studio desktop app.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
/// Public catalog of models runnable in LM Studio.
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
/// LM Studio home page.
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

/// Stable identifier for this provider (used in provider ids/telemetry).
const PROVIDER_ID: &str = "lmstudio";
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: &str = "LM Studio";
30
/// User-configurable settings for the LM Studio provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// Base URL of the local LM Studio API server.
    pub api_url: String,
    /// Models explicitly declared in settings; these override models
    /// discovered via the API (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
36
/// A model entry as declared in the user's settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
46
/// Language-model provider backed by a locally running LM Studio server.
pub struct LmStudioLanguageModelProvider {
    // Shared HTTP client used for all LM Studio API calls.
    http_client: Arc<dyn HttpClient>,
    // Observable state holding the fetched model list.
    state: gpui::Entity<State>,
}
51
/// Observable provider state: the model list fetched from the server plus
/// the plumbing that keeps it up to date.
pub struct State {
    http_client: Arc<dyn HttpClient>,
    // Models discovered from the LM Studio API (embeddings models excluded).
    available_models: Vec<lmstudio::Model>,
    // Handle to the in-flight fetch; replacing/dropping it cancels the fetch.
    fetch_model_task: Option<Task<Result<()>>>,
    // Keeps the settings observer alive for the lifetime of this state.
    _subscription: Subscription,
}
58
impl State {
    /// Treats "we have fetched at least one model" as a proxy for being
    /// connected/authenticated, since LM Studio has no real credentials.
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    /// Queries the LM Studio server for its model list and stores the result
    /// (sorted by name, embeddings models excluded) on `self`.
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            // Embeddings models can't be used for chat completion, so drop them.
            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    /// Kicks off a fresh `fetch_models` task, replacing (and thereby
    /// cancelling) any previously running fetch.
    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    /// Returns a task that resolves once the model list has been fetched;
    /// resolves immediately if models are already loaded.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}
102
impl LmStudioLanguageModelProvider {
    /// Creates the provider, wiring up a settings observer so the model list
    /// is re-fetched whenever the LM Studio settings change, and starting an
    /// initial fetch immediately.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    // Remember the last-seen settings so we only refetch on an
                    // actual change, not on every global settings event.
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        // Eagerly fetch models so the provider is usable without waiting for
        // a settings change.
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}
133
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    // Exposes the internal `State` entity so the UI can observe model-list
    // changes and re-render.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
141
142impl LanguageModelProvider for LmStudioLanguageModelProvider {
143 fn id(&self) -> LanguageModelProviderId {
144 LanguageModelProviderId(PROVIDER_ID.into())
145 }
146
147 fn name(&self) -> LanguageModelProviderName {
148 LanguageModelProviderName(PROVIDER_NAME.into())
149 }
150
151 fn icon(&self) -> IconName {
152 IconName::AiLmStudio
153 }
154
155 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
156 self.provided_models(cx).into_iter().next()
157 }
158
159 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
160 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
161
162 // Add models from the LM Studio API
163 for model in self.state.read(cx).available_models.iter() {
164 models.insert(model.name.clone(), model.clone());
165 }
166
167 // Override with available models from settings
168 for model in AllLanguageModelSettings::get_global(cx)
169 .lmstudio
170 .available_models
171 .iter()
172 {
173 models.insert(
174 model.name.clone(),
175 lmstudio::Model {
176 name: model.name.clone(),
177 display_name: model.display_name.clone(),
178 max_tokens: model.max_tokens,
179 },
180 );
181 }
182
183 models
184 .into_values()
185 .map(|model| {
186 Arc::new(LmStudioLanguageModel {
187 id: LanguageModelId::from(model.name.clone()),
188 model: model.clone(),
189 http_client: self.http_client.clone(),
190 request_limiter: RateLimiter::new(4),
191 }) as Arc<dyn LanguageModel>
192 })
193 .collect()
194 }
195
196 fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
197 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
198 let http_client = self.http_client.clone();
199 let api_url = settings.api_url.clone();
200 let id = model.id().0.to_string();
201 cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
202 .detach_and_log_err(cx);
203 }
204
205 fn is_authenticated(&self, cx: &App) -> bool {
206 self.state.read(cx).is_authenticated()
207 }
208
209 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
210 self.state.update(cx, |state, cx| state.authenticate(cx))
211 }
212
213 fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
214 let state = self.state.clone();
215 cx.new(|cx| ConfigurationView::new(state, cx)).into()
216 }
217
218 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
219 self.state.update(cx, |state, cx| state.fetch_models(cx))
220 }
221}
222
/// A single chat model served by the LM Studio API.
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    // Limits concurrent completion requests to this model.
    request_limiter: RateLimiter,
}
229
230impl LmStudioLanguageModel {
231 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
232 ChatCompletionRequest {
233 model: self.model.name.clone(),
234 messages: request
235 .messages
236 .into_iter()
237 .map(|msg| match msg.role {
238 Role::User => ChatMessage::User {
239 content: msg.string_contents(),
240 },
241 Role::Assistant => ChatMessage::Assistant {
242 content: Some(msg.string_contents()),
243 tool_calls: None,
244 },
245 Role::System => ChatMessage::System {
246 content: msg.string_contents(),
247 },
248 })
249 .collect(),
250 stream: true,
251 max_tokens: Some(-1),
252 stop: Some(request.stop),
253 temperature: request.temperature.or(Some(0.0)),
254 tools: vec![],
255 }
256 }
257}
258
259impl LanguageModel for LmStudioLanguageModel {
260 fn id(&self) -> LanguageModelId {
261 self.id.clone()
262 }
263
264 fn name(&self) -> LanguageModelName {
265 LanguageModelName::from(self.model.display_name().to_string())
266 }
267
268 fn provider_id(&self) -> LanguageModelProviderId {
269 LanguageModelProviderId(PROVIDER_ID.into())
270 }
271
272 fn provider_name(&self) -> LanguageModelProviderName {
273 LanguageModelProviderName(PROVIDER_NAME.into())
274 }
275
276 fn telemetry_id(&self) -> String {
277 format!("lmstudio/{}", self.model.id())
278 }
279
280 fn max_token_count(&self) -> usize {
281 self.model.max_token_count()
282 }
283
284 fn count_tokens(
285 &self,
286 request: LanguageModelRequest,
287 _cx: &App,
288 ) -> BoxFuture<'static, Result<usize>> {
289 // Endpoint for this is coming soon. In the meantime, hacky estimation
290 let token_count = request
291 .messages
292 .iter()
293 .map(|msg| msg.string_contents().split_whitespace().count())
294 .sum::<usize>();
295
296 let estimated_tokens = (token_count as f64 * 0.75) as usize;
297 async move { Ok(estimated_tokens) }.boxed()
298 }
299
300 fn stream_completion(
301 &self,
302 request: LanguageModelRequest,
303 cx: &AsyncApp,
304 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
305 let request = self.to_lmstudio_request(request);
306
307 let http_client = self.http_client.clone();
308 let Ok(api_url) = cx.update(|cx| {
309 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
310 settings.api_url.clone()
311 }) else {
312 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
313 };
314
315 let future = self.request_limiter.stream(async move {
316 let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
317 let stream = response
318 .filter_map(|response| async move {
319 match response {
320 Ok(fragment) => {
321 // Skip empty deltas
322 if fragment.choices[0].delta.is_object()
323 && fragment.choices[0].delta.as_object().unwrap().is_empty()
324 {
325 return None;
326 }
327
328 // Try to parse the delta as ChatMessage
329 if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
330 fragment.choices[0].delta.clone(),
331 ) {
332 let content = match chat_message {
333 ChatMessage::User { content } => content,
334 ChatMessage::Assistant { content, .. } => {
335 content.unwrap_or_default()
336 }
337 ChatMessage::System { content } => content,
338 };
339 if !content.is_empty() {
340 Some(Ok(content))
341 } else {
342 None
343 }
344 } else {
345 None
346 }
347 }
348 Err(error) => Some(Err(error)),
349 }
350 })
351 .boxed();
352 Ok(stream)
353 });
354
355 async move {
356 Ok(future
357 .await?
358 .map(|result| result.map(LanguageModelCompletionEvent::Text))
359 .boxed())
360 }
361 .boxed()
362 }
363
364 fn use_any_tool(
365 &self,
366 _request: LanguageModelRequest,
367 _tool_name: String,
368 _tool_description: String,
369 _schema: serde_json::Value,
370 _cx: &AsyncApp,
371 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
372 async move { Ok(futures::stream::empty().boxed()) }.boxed()
373 }
374}
375
/// Settings-panel view for configuring the LM Studio provider.
struct ConfigurationView {
    state: gpui::Entity<State>,
    // `Some` while the initial model fetch is running; cleared when it ends.
    loading_models_task: Option<Task<()>>,
}
380
impl ConfigurationView {
    /// Creates the view and immediately kicks off authentication (i.e. the
    /// initial model fetch); the stored task handle doubles as the "loading"
    /// indicator read by `render`.
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                // Clear the loading flag whether or not the fetch succeeded.
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    /// Re-attempts the connection by fetching the model list again.
    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}
412
impl Render for ConfigurationView {
    /// Renders either a "Loading models..." placeholder (while the initial
    /// fetch is in flight) or the full setup panel with intro text, helpful
    /// links, and a connection status / retry control.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
        let lmstudio_reqs = "To use LM Studio as a provider for Zed assistant, it needs to be running with at least one model downloaded.";

        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                // Intro copy plus an inline-code hint for pulling a first model.
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(lmstudio_intro))
                        .child(Label::new(lmstudio_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("To get your first model, try running"))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_sm()
                                        .child(Label::new("lms get qwen2.5-coder-7b")),
                                ),
                        ),
                )
                // Footer row: external links on the left, status on the right.
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                // Once connected, link to the site; otherwise
                                // prompt the user to download LM Studio.
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is correct
                            // it should stay disabled
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_lmstudio_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, _window, cx| {
                                    this.retry_connection(cx)
                                }))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}