1use anyhow::{Result, anyhow};
2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::{
6 AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
7 LanguageModelToolChoice,
8};
9use language_model::{
10 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
11 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
12 LanguageModelRequest, RateLimiter, Role,
13};
14use lmstudio::{
15 ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
16 stream_chat_completion,
17};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::{Settings, SettingsStore};
21use std::{collections::BTreeMap, sync::Arc};
22use ui::{ButtonLike, Indicator, List, prelude::*};
23use util::ResultExt;
24
25use crate::AllLanguageModelSettings;
26use crate::ui::InstructionListItem;
27
/// Where users can download the LM Studio desktop app.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
/// Browsable catalog of models runnable in LM Studio.
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
/// LM Studio's home page.
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

/// Machine-readable provider id (used by `id()` / `provider_id()`).
const PROVIDER_ID: &str = "lmstudio";
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: &str = "LM Studio";
34
/// User-configurable settings for the LM Studio provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// Base URL of the local LM Studio server's API.
    pub api_url: String,
    /// Models declared in settings; these override server-reported models
    /// with the same name (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
40
/// A model entry as declared in user settings (as opposed to one reported
/// by the running LM Studio server).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
50
/// Language-model provider backed by a locally running LM Studio server.
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    /// Shared observable state (fetched models, subscriptions).
    state: gpui::Entity<State>,
}
55
/// Observable provider state: the model list fetched from the server plus
/// the machinery to keep it fresh.
pub struct State {
    http_client: Arc<dyn HttpClient>,
    /// Models reported by the LM Studio server (embeddings filtered out).
    available_models: Vec<lmstudio::Model>,
    /// In-flight fetch, if any; replaced (and thus cancelled) on restart.
    fetch_model_task: Option<Task<Result<()>>>,
    /// Keeps the settings-change observer alive for the lifetime of the state.
    _subscription: Subscription,
}
62
63impl State {
64 fn is_authenticated(&self) -> bool {
65 !self.available_models.is_empty()
66 }
67
68 fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
69 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
70 let http_client = self.http_client.clone();
71 let api_url = settings.api_url.clone();
72
73 // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
74 cx.spawn(async move |this, cx| {
75 let models = get_models(http_client.as_ref(), &api_url, None).await?;
76
77 let mut models: Vec<lmstudio::Model> = models
78 .into_iter()
79 .filter(|model| model.r#type != ModelType::Embeddings)
80 .map(|model| lmstudio::Model::new(&model.id, None, None))
81 .collect();
82
83 models.sort_by(|a, b| a.name.cmp(&b.name));
84
85 this.update(cx, |this, cx| {
86 this.available_models = models;
87 cx.notify();
88 })
89 })
90 }
91
92 fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
93 let task = self.fetch_models(cx);
94 self.fetch_model_task.replace(task);
95 }
96
97 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
98 if self.is_authenticated() {
99 return Task::ready(Ok(()));
100 }
101
102 let fetch_models_task = self.fetch_models(cx);
103 cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
104 }
105}
106
impl LmStudioLanguageModelProvider {
    /// Creates the provider and immediately kicks off a fetch of the server's
    /// model list. A `SettingsStore` observer re-fetches whenever the LM Studio
    /// settings actually change.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    // Keep a copy of the last-seen settings so we only refetch
                    // when they differ from the new global value.
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        // Initial fetch so models are available as soon as possible.
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}
137
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's state entity so the UI can observe model-list changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
145
146impl LanguageModelProvider for LmStudioLanguageModelProvider {
147 fn id(&self) -> LanguageModelProviderId {
148 LanguageModelProviderId(PROVIDER_ID.into())
149 }
150
151 fn name(&self) -> LanguageModelProviderName {
152 LanguageModelProviderName(PROVIDER_NAME.into())
153 }
154
155 fn icon(&self) -> IconName {
156 IconName::AiLmStudio
157 }
158
159 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
160 self.provided_models(cx).into_iter().next()
161 }
162
163 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
164 self.default_model(cx)
165 }
166
167 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
168 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
169
170 // Add models from the LM Studio API
171 for model in self.state.read(cx).available_models.iter() {
172 models.insert(model.name.clone(), model.clone());
173 }
174
175 // Override with available models from settings
176 for model in AllLanguageModelSettings::get_global(cx)
177 .lmstudio
178 .available_models
179 .iter()
180 {
181 models.insert(
182 model.name.clone(),
183 lmstudio::Model {
184 name: model.name.clone(),
185 display_name: model.display_name.clone(),
186 max_tokens: model.max_tokens,
187 },
188 );
189 }
190
191 models
192 .into_values()
193 .map(|model| {
194 Arc::new(LmStudioLanguageModel {
195 id: LanguageModelId::from(model.name.clone()),
196 model: model.clone(),
197 http_client: self.http_client.clone(),
198 request_limiter: RateLimiter::new(4),
199 }) as Arc<dyn LanguageModel>
200 })
201 .collect()
202 }
203
204 fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
205 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
206 let http_client = self.http_client.clone();
207 let api_url = settings.api_url.clone();
208 let id = model.id().0.to_string();
209 cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
210 .detach_and_log_err(cx);
211 }
212
213 fn is_authenticated(&self, cx: &App) -> bool {
214 self.state.read(cx).is_authenticated()
215 }
216
217 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
218 self.state.update(cx, |state, cx| state.authenticate(cx))
219 }
220
221 fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
222 let state = self.state.clone();
223 cx.new(|cx| ConfigurationView::new(state, cx)).into()
224 }
225
226 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
227 self.state.update(cx, |state, cx| state.fetch_models(cx))
228 }
229}
230
/// A single LM Studio model exposed to Zed as a `LanguageModel`.
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    /// Caps concurrent requests to this model (4 at construction time).
    request_limiter: RateLimiter,
}
237
238impl LmStudioLanguageModel {
239 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
240 ChatCompletionRequest {
241 model: self.model.name.clone(),
242 messages: request
243 .messages
244 .into_iter()
245 .map(|msg| match msg.role {
246 Role::User => ChatMessage::User {
247 content: msg.string_contents(),
248 },
249 Role::Assistant => ChatMessage::Assistant {
250 content: Some(msg.string_contents()),
251 tool_calls: None,
252 },
253 Role::System => ChatMessage::System {
254 content: msg.string_contents(),
255 },
256 })
257 .collect(),
258 stream: true,
259 max_tokens: Some(-1),
260 stop: Some(request.stop),
261 temperature: request.temperature.or(Some(0.0)),
262 tools: vec![],
263 }
264 }
265}
266
impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    /// Tool calling is not implemented for LM Studio yet (see `LmStudioStreamMapper`).
    fn supports_tools(&self) -> bool {
        false
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Rough token estimate: whitespace-delimited word count scaled by 0.75.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // Endpoint for this is coming soon. In the meantime, hacky estimation
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

    /// Streams completion text from the LM Studio server, gated through the
    /// shared `request_limiter`. Empty deltas are filtered out of the stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        // Read the API URL up front; the app may be gone once the future runs.
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;

            // Create a stream mapper to handle content across multiple deltas
            let stream_mapper = LmStudioStreamMapper::new();

            let stream = response
                .map(move |response| {
                    response.and_then(|fragment| stream_mapper.process_fragment(fragment))
                })
                // Drop `Ok(None)` fragments (empty deltas, finish markers).
                .filter_map(|result| async move {
                    match result {
                        Ok(Some(content)) => Some(Ok(content)),
                        Ok(None) => None,
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();

            Ok(stream)
        });

        // Wrap raw strings/errors into the completion-event types callers expect.
        async move {
            Ok(future
                .await?
                .map(|result| {
                    result
                        .map(LanguageModelCompletionEvent::Text)
                        .map_err(LanguageModelCompletionError::Other)
                })
                .boxed())
        }
        .boxed()
    }
}
375
376// This will be more useful when we implement tool calling. Currently keeping it empty.
377struct LmStudioStreamMapper {}
378
379impl LmStudioStreamMapper {
380 fn new() -> Self {
381 Self {}
382 }
383
384 fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> {
385 // Most of the time, there will be only one choice
386 let Some(choice) = fragment.choices.first() else {
387 return Ok(None);
388 };
389
390 // Extract the delta content
391 if let Ok(delta) =
392 serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone())
393 {
394 if let Some(content) = delta.content {
395 if !content.is_empty() {
396 return Ok(Some(content));
397 }
398 }
399 }
400
401 // If there's a finish_reason, we're done
402 if choice.finish_reason.is_some() {
403 return Ok(None);
404 }
405
406 Ok(None)
407 }
408}
409
/// Settings-panel view for connecting Zed to a local LM Studio server.
struct ConfigurationView {
    state: gpui::Entity<State>,
    /// `Some` while the initial authentication/model fetch is running;
    /// drives the "Loading models..." UI state.
    loading_models_task: Option<Task<()>>,
}
414
impl ConfigurationView {
    /// Builds the view and spawns a background authentication attempt; the
    /// loading state clears when the task finishes, whether or not it succeeded.
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // Errors are logged, not propagated; the UI just shows
                    // the unauthenticated state afterwards.
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    /// Re-attempts the server connection by refetching the model list.
    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}
446
impl Render for ConfigurationView {
    /// Renders either a bare loading label (while the initial fetch runs) or
    /// the full setup UI: intro text, site/download + catalog buttons, and a
    /// "Connected" indicator or "Connect" retry button depending on state.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(InstructionListItem::text_only(
                                "To get your first model, try running `lms get qwen2.5-coder-7b`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                // Show a site link once connected, otherwise a
                                // download button for first-time setup.
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        // Right side: connection status or a retry affordance.
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_lmstudio_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _window, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}