1use anyhow::{Result, anyhow};
2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::{AuthenticateError, LanguageModelCompletionEvent};
6use language_model::{
7 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
8 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
9 LanguageModelRequest, RateLimiter, Role,
10};
11use lmstudio::{
12 ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
13 stream_chat_completion,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use settings::{Settings, SettingsStore};
18use std::{collections::BTreeMap, sync::Arc};
19use ui::{ButtonLike, Indicator, prelude::*};
20use util::ResultExt;
21
22use crate::AllLanguageModelSettings;
23
// Where to send users who don't have LM Studio installed yet.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
// Browsable catalog of models that can be pulled into LM Studio.
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
// LM Studio's home page, linked from the configuration view once connected.
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

// Stable identifier and display name for this provider.
const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";
30
/// User-facing settings for the LM Studio provider: where the local server
/// lives and which models (beyond the ones the server reports) to offer.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// Base URL of the local LM Studio API server.
    pub api_url: String,
    /// Models declared directly in settings; these override models of the
    /// same name reported by the server (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
36
/// A model manually declared in the user's settings for the LM Studio provider.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
46
/// Language-model provider backed by a locally running LM Studio server.
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    /// Shared provider state: fetched model list plus the settings subscription.
    state: gpui::Entity<State>,
}
51
/// Entity state for the LM Studio provider, observed by the configuration UI.
pub struct State {
    http_client: Arc<dyn HttpClient>,
    /// Models reported by the LM Studio server (embedding models excluded).
    available_models: Vec<lmstudio::Model>,
    /// The in-flight model fetch, if any; replaced when settings change.
    fetch_model_task: Option<Task<Result<()>>>,
    /// Keeps the settings-change observer alive for the entity's lifetime.
    _subscription: Subscription,
}
58
59impl State {
60 fn is_authenticated(&self) -> bool {
61 !self.available_models.is_empty()
62 }
63
64 fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
65 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
66 let http_client = self.http_client.clone();
67 let api_url = settings.api_url.clone();
68
69 // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
70 cx.spawn(async move |this, cx| {
71 let models = get_models(http_client.as_ref(), &api_url, None).await?;
72
73 let mut models: Vec<lmstudio::Model> = models
74 .into_iter()
75 .filter(|model| model.r#type != ModelType::Embeddings)
76 .map(|model| lmstudio::Model::new(&model.id, None, None))
77 .collect();
78
79 models.sort_by(|a, b| a.name.cmp(&b.name));
80
81 this.update(cx, |this, cx| {
82 this.available_models = models;
83 cx.notify();
84 })
85 })
86 }
87
88 fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
89 let task = self.fetch_models(cx);
90 self.fetch_model_task.replace(task);
91 }
92
93 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
94 if self.is_authenticated() {
95 return Task::ready(Ok(()));
96 }
97
98 let fetch_models_task = self.fetch_models(cx);
99 cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
100 }
101}
102
103impl LmStudioLanguageModelProvider {
104 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
105 let this = Self {
106 http_client: http_client.clone(),
107 state: cx.new(|cx| {
108 let subscription = cx.observe_global::<SettingsStore>({
109 let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
110 move |this: &mut State, cx| {
111 let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
112 if &settings != new_settings {
113 settings = new_settings.clone();
114 this.restart_fetch_models_task(cx);
115 cx.notify();
116 }
117 }
118 });
119
120 State {
121 http_client,
122 available_models: Default::default(),
123 fetch_model_task: None,
124 _subscription: subscription,
125 }
126 }),
127 };
128 this.state
129 .update(cx, |state, cx| state.restart_fetch_models_task(cx));
130 this
131 }
132}
133
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's state entity so the UI can observe model-list changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
141
142impl LanguageModelProvider for LmStudioLanguageModelProvider {
143 fn id(&self) -> LanguageModelProviderId {
144 LanguageModelProviderId(PROVIDER_ID.into())
145 }
146
147 fn name(&self) -> LanguageModelProviderName {
148 LanguageModelProviderName(PROVIDER_NAME.into())
149 }
150
151 fn icon(&self) -> IconName {
152 IconName::AiLmStudio
153 }
154
155 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
156 self.provided_models(cx).into_iter().next()
157 }
158
159 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
160 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
161
162 // Add models from the LM Studio API
163 for model in self.state.read(cx).available_models.iter() {
164 models.insert(model.name.clone(), model.clone());
165 }
166
167 // Override with available models from settings
168 for model in AllLanguageModelSettings::get_global(cx)
169 .lmstudio
170 .available_models
171 .iter()
172 {
173 models.insert(
174 model.name.clone(),
175 lmstudio::Model {
176 name: model.name.clone(),
177 display_name: model.display_name.clone(),
178 max_tokens: model.max_tokens,
179 },
180 );
181 }
182
183 models
184 .into_values()
185 .map(|model| {
186 Arc::new(LmStudioLanguageModel {
187 id: LanguageModelId::from(model.name.clone()),
188 model: model.clone(),
189 http_client: self.http_client.clone(),
190 request_limiter: RateLimiter::new(4),
191 }) as Arc<dyn LanguageModel>
192 })
193 .collect()
194 }
195
196 fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
197 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
198 let http_client = self.http_client.clone();
199 let api_url = settings.api_url.clone();
200 let id = model.id().0.to_string();
201 cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
202 .detach_and_log_err(cx);
203 }
204
205 fn is_authenticated(&self, cx: &App) -> bool {
206 self.state.read(cx).is_authenticated()
207 }
208
209 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
210 self.state.update(cx, |state, cx| state.authenticate(cx))
211 }
212
213 fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
214 let state = self.state.clone();
215 cx.new(|cx| ConfigurationView::new(state, cx)).into()
216 }
217
218 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
219 self.state.update(cx, |state, cx| state.fetch_models(cx))
220 }
221}
222
/// A single LM Studio-served model exposed through the `LanguageModel` trait.
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    /// Caps concurrent completion requests to this model.
    request_limiter: RateLimiter,
}
229
230impl LmStudioLanguageModel {
231 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
232 ChatCompletionRequest {
233 model: self.model.name.clone(),
234 messages: request
235 .messages
236 .into_iter()
237 .map(|msg| match msg.role {
238 Role::User => ChatMessage::User {
239 content: msg.string_contents(),
240 },
241 Role::Assistant => ChatMessage::Assistant {
242 content: Some(msg.string_contents()),
243 tool_calls: None,
244 },
245 Role::System => ChatMessage::System {
246 content: msg.string_contents(),
247 },
248 })
249 .collect(),
250 stream: true,
251 max_tokens: Some(-1),
252 stop: Some(request.stop),
253 temperature: request.temperature.or(Some(0.0)),
254 tools: vec![],
255 }
256 }
257}
258
259impl LanguageModel for LmStudioLanguageModel {
260 fn id(&self) -> LanguageModelId {
261 self.id.clone()
262 }
263
264 fn name(&self) -> LanguageModelName {
265 LanguageModelName::from(self.model.display_name().to_string())
266 }
267
268 fn provider_id(&self) -> LanguageModelProviderId {
269 LanguageModelProviderId(PROVIDER_ID.into())
270 }
271
272 fn provider_name(&self) -> LanguageModelProviderName {
273 LanguageModelProviderName(PROVIDER_NAME.into())
274 }
275
276 fn supports_tools(&self) -> bool {
277 false
278 }
279
280 fn telemetry_id(&self) -> String {
281 format!("lmstudio/{}", self.model.id())
282 }
283
284 fn max_token_count(&self) -> usize {
285 self.model.max_token_count()
286 }
287
288 fn count_tokens(
289 &self,
290 request: LanguageModelRequest,
291 _cx: &App,
292 ) -> BoxFuture<'static, Result<usize>> {
293 // Endpoint for this is coming soon. In the meantime, hacky estimation
294 let token_count = request
295 .messages
296 .iter()
297 .map(|msg| msg.string_contents().split_whitespace().count())
298 .sum::<usize>();
299
300 let estimated_tokens = (token_count as f64 * 0.75) as usize;
301 async move { Ok(estimated_tokens) }.boxed()
302 }
303
304 fn stream_completion(
305 &self,
306 request: LanguageModelRequest,
307 cx: &AsyncApp,
308 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
309 let request = self.to_lmstudio_request(request);
310
311 let http_client = self.http_client.clone();
312 let Ok(api_url) = cx.update(|cx| {
313 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
314 settings.api_url.clone()
315 }) else {
316 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
317 };
318
319 let future = self.request_limiter.stream(async move {
320 let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
321 let stream = response
322 .filter_map(|response| async move {
323 match response {
324 Ok(fragment) => {
325 // Skip empty deltas
326 if fragment.choices[0].delta.is_object()
327 && fragment.choices[0].delta.as_object().unwrap().is_empty()
328 {
329 return None;
330 }
331
332 // Try to parse the delta as ChatMessage
333 if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
334 fragment.choices[0].delta.clone(),
335 ) {
336 let content = match chat_message {
337 ChatMessage::User { content } => content,
338 ChatMessage::Assistant { content, .. } => {
339 content.unwrap_or_default()
340 }
341 ChatMessage::System { content } => content,
342 };
343 if !content.is_empty() {
344 Some(Ok(content))
345 } else {
346 None
347 }
348 } else {
349 None
350 }
351 }
352 Err(error) => Some(Err(error)),
353 }
354 })
355 .boxed();
356 Ok(stream)
357 });
358
359 async move {
360 Ok(future
361 .await?
362 .map(|result| result.map(LanguageModelCompletionEvent::Text))
363 .boxed())
364 }
365 .boxed()
366 }
367}
368
/// Configuration panel shown in the assistant settings for LM Studio.
struct ConfigurationView {
    state: gpui::Entity<State>,
    /// Set while the initial model fetch runs; cleared when it finishes so the
    /// render switches from the loading placeholder to the full form.
    loading_models_task: Option<Task<()>>,
}
373
374impl ConfigurationView {
375 pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
376 let loading_models_task = Some(cx.spawn({
377 let state = state.clone();
378 async move |this, cx| {
379 if let Some(task) = state
380 .update(cx, |state, cx| state.authenticate(cx))
381 .log_err()
382 {
383 task.await.log_err();
384 }
385 this.update(cx, |this, cx| {
386 this.loading_models_task = None;
387 cx.notify();
388 })
389 .log_err();
390 }
391 }));
392
393 Self {
394 state,
395 loading_models_task,
396 }
397 }
398
399 fn retry_connection(&self, cx: &mut App) {
400 self.state
401 .update(cx, |state, cx| state.fetch_models(cx))
402 .detach_and_log_err(cx);
403 }
404}
405
impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
        let lmstudio_reqs = "To use LM Studio as a provider for Zed assistant, it needs to be running with at least one model downloaded.";

        // Subtle tinted background for the inline `lms get …` code snippet.
        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);

        if self.loading_models_task.is_some() {
            // Initial model fetch still in flight: show a placeholder instead
            // of the full form.
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                // Intro copy plus the "get your first model" hint.
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(lmstudio_intro))
                        .child(Label::new(lmstudio_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("To get your first model, try running"))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_sm()
                                        .child(Label::new("lms get qwen2.5-coder-7b")),
                                ),
                        ),
                )
                // Bottom row: external links on the left, connection state on
                // the right.
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                // Once connected, link to the LM Studio site;
                                // otherwise offer the download button.
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is correct
                            // it should stay disabled
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            // Not connected yet: offer a retry that re-fetches
                            // the model list.
                            Button::new("retry_lmstudio_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, _window, cx| {
                                    this.retry_connection(cx)
                                }))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}