1use anyhow::{Result, anyhow};
2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::{
6 AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
7 LanguageModelToolChoice,
8};
9use language_model::{
10 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
11 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
12 LanguageModelRequest, RateLimiter, Role,
13};
14use lmstudio::{
15 ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
16 stream_chat_completion,
17};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::{Settings, SettingsStore};
21use std::{collections::BTreeMap, sync::Arc};
22use ui::{ButtonLike, Indicator, List, prelude::*};
23use util::ResultExt;
24
25use crate::AllLanguageModelSettings;
26use crate::ui::InstructionListItem;
27
28const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
29const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
30const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";
31
32const PROVIDER_ID: &str = "lmstudio";
33const PROVIDER_NAME: &str = "LM Studio";
34
/// Provider-level settings for LM Studio, as read from Zed's settings store
/// (see `AllLanguageModelSettings::get_global(cx).lmstudio`).
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// Base URL of the local LM Studio server's HTTP API.
    pub api_url: String,
    /// Models declared in settings; these override same-named models reported
    /// by the server (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
40
/// A model entry the user can declare in settings to expose (or override)
/// an LM Studio model in the model picker.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
50
/// Language-model provider backed by a locally running LM Studio server.
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared state holding the fetched model list; observable by the UI.
    state: gpui::Entity<State>,
}
55
/// Shared provider state: the model list fetched from the server plus the
/// plumbing needed to refresh it when settings change.
pub struct State {
    http_client: Arc<dyn HttpClient>,
    // Models reported by the LM Studio server (embeddings models filtered out).
    available_models: Vec<lmstudio::Model>,
    // Holds the in-flight fetch; replaced on every refresh.
    fetch_model_task: Option<Task<Result<()>>>,
    // Settings-store observation that triggers a refetch when LM Studio
    // settings change; kept alive for the lifetime of this state.
    _subscription: Subscription,
}
62
impl State {
    /// LM Studio has no real credentials; a non-empty model list serves as
    /// the "authenticated" signal that the local server is reachable.
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    /// Fetches the model list from the server configured in settings,
    /// dropping embeddings models, sorting by name, and storing the result
    /// (with a notify so observers re-render).
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    /// Starts a fresh fetch, replacing (and thereby dropping) any previously
    /// stored fetch task.
    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    /// No-op once models are loaded; otherwise fetches them, surfacing any
    /// failure as an `AuthenticateError`.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}
106
impl LmStudioLanguageModelProvider {
    /// Creates the provider: wires up a settings observer that refetches the
    /// model list whenever LM Studio settings actually change, then kicks off
    /// the initial fetch.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    // Snapshot of the current settings, used to ignore
                    // global-settings notifications that don't touch LM Studio.
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        // Fetch models eagerly so the provider is usable without waiting for
        // a settings change or an explicit authenticate call.
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}
137
impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider state so callers can observe model-list changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
145
impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    /// The default model is the first entry of `provided_models`; since models
    /// are keyed by name in a `BTreeMap`, that is the alphabetically first one.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    // No separate "fast" model distinction for LM Studio.
    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    /// Merges server-reported models with models declared in settings
    /// (settings win on name collision) and wraps each in a
    /// `LmStudioLanguageModel`, sorted by name via the `BTreeMap`.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models from the LM Studio API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    // Allow up to 4 concurrent requests per model.
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    /// Fire-and-forget request asking the server to preload the given model;
    /// failures are only logged.
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    /// There are no stored credentials for LM Studio; "resetting" just
    /// refetches the model list from the server.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}
230
/// A single LM Studio chat model exposed through Zed's `LanguageModel` trait.
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    // Bounds concurrent requests to this model (used in `stream_completion`).
    request_limiter: RateLimiter,
}
237
238impl LmStudioLanguageModel {
239 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
240 ChatCompletionRequest {
241 model: self.model.name.clone(),
242 messages: request
243 .messages
244 .into_iter()
245 .map(|msg| match msg.role {
246 Role::User => ChatMessage::User {
247 content: msg.string_contents(),
248 },
249 Role::Assistant => ChatMessage::Assistant {
250 content: Some(msg.string_contents()),
251 tool_calls: None,
252 },
253 Role::System => ChatMessage::System {
254 content: msg.string_contents(),
255 },
256 })
257 .collect(),
258 stream: true,
259 max_tokens: Some(-1),
260 stop: Some(request.stop),
261 temperature: request.temperature.or(Some(0.0)),
262 tools: vec![],
263 }
264 }
265}
266
impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    // Tool calling is not implemented yet (see `LmStudioStreamMapper`).
    fn supports_tools(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    /// Rough token estimate: 0.75 × the whitespace-separated word count
    /// across all request messages.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // Endpoint for this is coming soon. In the meantime, hacky estimation
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

    /// Streams a chat completion from the LM Studio server, mapping each
    /// text fragment to a `LanguageModelCompletionEvent::Text` event.
    /// Concurrency is bounded by `request_limiter`.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        // Read the API URL from settings; this errors only when the app
        // state has already been dropped.
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;

            // Create a stream mapper to handle content across multiple deltas
            let stream_mapper = LmStudioStreamMapper::new();

            let stream = response
                .map(move |response| {
                    response.and_then(|fragment| stream_mapper.process_fragment(fragment))
                })
                // Drop fragments that carried no text (empty or undecodable
                // deltas, terminal fragments); keep text and errors.
                .filter_map(|result| async move {
                    match result {
                        Ok(Some(content)) => Some(Ok(content)),
                        Ok(None) => None,
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();

            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| {
                    result
                        .map(LanguageModelCompletionEvent::Text)
                        .map_err(LanguageModelCompletionError::Other)
                })
                .boxed())
        }
        .boxed()
    }
}
371
// This will be more useful when we implement tool calling. Currently keeping it empty.
/// Maps raw LM Studio stream fragments into plain text chunks.
struct LmStudioStreamMapper {}
374
375impl LmStudioStreamMapper {
376 fn new() -> Self {
377 Self {}
378 }
379
380 fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> {
381 // Most of the time, there will be only one choice
382 let Some(choice) = fragment.choices.first() else {
383 return Ok(None);
384 };
385
386 // Extract the delta content
387 if let Ok(delta) =
388 serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone())
389 {
390 if let Some(content) = delta.content {
391 if !content.is_empty() {
392 return Ok(Some(content));
393 }
394 }
395 }
396
397 // If there's a finish_reason, we're done
398 if choice.finish_reason.is_some() {
399 return Ok(None);
400 }
401
402 Ok(None)
403 }
404}
405
/// Settings-panel view for the LM Studio provider.
struct ConfigurationView {
    state: gpui::Entity<State>,
    // Present while the initial model fetch is running; drives the
    // "Loading models..." placeholder in `render`.
    loading_models_task: Option<Task<()>>,
}
410
impl ConfigurationView {
    /// Creates the view and immediately starts authenticating (i.e. fetching
    /// models); clears `loading_models_task` when the attempt finishes —
    /// success or failure — so the UI leaves its loading state.
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    /// Re-attempts the model fetch; errors are logged rather than surfaced.
    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}
442
impl Render for ConfigurationView {
    /// Renders either a loading placeholder (while the initial model fetch is
    /// in flight) or the full configuration panel: intro text, setup
    /// instructions, external links, and a connection status / retry control.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";

        if self.loading_models_task.is_some() {
            // Initial fetch still running: show only a placeholder.
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    // Intro blurb plus the two setup instructions.
                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(InstructionListItem::text_only(
                                "To get your first model, try running `lms get qwen2.5-coder-7b`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            // Left side: site link when connected, download
                            // link otherwise, plus the model-catalog link.
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .map(|this| {
                            // Right side: "Connected" indicator, or a button
                            // that retries the model fetch.
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_lmstudio_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _window, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}