1use anyhow::{Result, anyhow};
2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::{AuthenticateError, LanguageModelCompletionEvent};
6use language_model::{
7 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
8 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
9 LanguageModelRequest, RateLimiter, Role,
10};
11use lmstudio::{
12 ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
13 stream_chat_completion,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use settings::{Settings, SettingsStore};
18use std::{collections::BTreeMap, sync::Arc};
19use ui::{ButtonLike, Indicator, List, prelude::*};
20use util::ResultExt;
21
22use crate::AllLanguageModelSettings;
23use crate::ui::InstructionListItem;
24
// External LM Studio URLs surfaced as links in the configuration UI below.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

// Stable identifiers for this provider within Zed's language-model registry.
const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";
31
/// Settings for the LM Studio provider, read from Zed's settings store.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    /// Base URL of the local LM Studio server's API.
    pub api_url: String,
    /// User-declared models; these override models reported by the server
    /// on name collision (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
37
/// A model entry declared by the user in settings, as opposed to one
/// discovered from the running LM Studio server.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size.
    pub max_tokens: usize,
}
47
/// Language-model provider backed by a locally running LM Studio server.
pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared mutable state (model list, in-flight fetch) observed by the UI.
    state: gpui::Entity<State>,
}
52
/// Mutable provider state held in a gpui entity.
pub struct State {
    http_client: Arc<dyn HttpClient>,
    // Models most recently fetched from the LM Studio server.
    available_models: Vec<lmstudio::Model>,
    // Keeps the in-flight fetch alive; replaced whenever settings change.
    fetch_model_task: Option<Task<Result<()>>>,
    // Keeps the settings-store observation alive for the entity's lifetime.
    _subscription: Subscription,
}
59
impl State {
    /// Treats a non-empty model list as proof the local server is reachable;
    /// LM Studio has no real auth handshake.
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    /// Queries the LM Studio server for its models, filters out embedding
    /// models, sorts by name, and stores the result on `self`.
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                // Embedding models can't serve chat completions, so hide them.
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    /// Starts a fresh fetch, dropping (and thereby cancelling) any previously
    /// running fetch task.
    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    /// "Authenticating" just means the model list has been fetched; returns
    /// immediately if that has already happened.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}
103
impl LmStudioLanguageModelProvider {
    /// Creates the provider, subscribes to settings-store changes (re-fetching
    /// models whenever the LM Studio settings actually change), and kicks off
    /// an initial model fetch.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    // Snapshot the current settings so the observer only
                    // reacts to real changes, not every global notification.
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}
134
135impl LanguageModelProviderState for LmStudioLanguageModelProvider {
136 type ObservableEntity = State;
137
138 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
139 Some(self.state.clone())
140 }
141}
142
143impl LanguageModelProvider for LmStudioLanguageModelProvider {
144 fn id(&self) -> LanguageModelProviderId {
145 LanguageModelProviderId(PROVIDER_ID.into())
146 }
147
148 fn name(&self) -> LanguageModelProviderName {
149 LanguageModelProviderName(PROVIDER_NAME.into())
150 }
151
152 fn icon(&self) -> IconName {
153 IconName::AiLmStudio
154 }
155
156 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
157 self.provided_models(cx).into_iter().next()
158 }
159
160 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
161 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
162
163 // Add models from the LM Studio API
164 for model in self.state.read(cx).available_models.iter() {
165 models.insert(model.name.clone(), model.clone());
166 }
167
168 // Override with available models from settings
169 for model in AllLanguageModelSettings::get_global(cx)
170 .lmstudio
171 .available_models
172 .iter()
173 {
174 models.insert(
175 model.name.clone(),
176 lmstudio::Model {
177 name: model.name.clone(),
178 display_name: model.display_name.clone(),
179 max_tokens: model.max_tokens,
180 },
181 );
182 }
183
184 models
185 .into_values()
186 .map(|model| {
187 Arc::new(LmStudioLanguageModel {
188 id: LanguageModelId::from(model.name.clone()),
189 model: model.clone(),
190 http_client: self.http_client.clone(),
191 request_limiter: RateLimiter::new(4),
192 }) as Arc<dyn LanguageModel>
193 })
194 .collect()
195 }
196
197 fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
198 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
199 let http_client = self.http_client.clone();
200 let api_url = settings.api_url.clone();
201 let id = model.id().0.to_string();
202 cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
203 .detach_and_log_err(cx);
204 }
205
206 fn is_authenticated(&self, cx: &App) -> bool {
207 self.state.read(cx).is_authenticated()
208 }
209
210 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
211 self.state.update(cx, |state, cx| state.authenticate(cx))
212 }
213
214 fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
215 let state = self.state.clone();
216 cx.new(|cx| ConfigurationView::new(state, cx)).into()
217 }
218
219 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
220 self.state.update(cx, |state, cx| state.fetch_models(cx))
221 }
222}
223
/// A single LM Studio model exposed to Zed as a `LanguageModel`.
pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    // Caps concurrent in-flight requests to the local server (4 at a time).
    request_limiter: RateLimiter,
}
230
231impl LmStudioLanguageModel {
232 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
233 ChatCompletionRequest {
234 model: self.model.name.clone(),
235 messages: request
236 .messages
237 .into_iter()
238 .map(|msg| match msg.role {
239 Role::User => ChatMessage::User {
240 content: msg.string_contents(),
241 },
242 Role::Assistant => ChatMessage::Assistant {
243 content: Some(msg.string_contents()),
244 tool_calls: None,
245 },
246 Role::System => ChatMessage::System {
247 content: msg.string_contents(),
248 },
249 })
250 .collect(),
251 stream: true,
252 max_tokens: Some(-1),
253 stop: Some(request.stop),
254 temperature: request.temperature.or(Some(0.0)),
255 tools: vec![],
256 }
257 }
258}
259
260impl LanguageModel for LmStudioLanguageModel {
261 fn id(&self) -> LanguageModelId {
262 self.id.clone()
263 }
264
265 fn name(&self) -> LanguageModelName {
266 LanguageModelName::from(self.model.display_name().to_string())
267 }
268
269 fn provider_id(&self) -> LanguageModelProviderId {
270 LanguageModelProviderId(PROVIDER_ID.into())
271 }
272
273 fn provider_name(&self) -> LanguageModelProviderName {
274 LanguageModelProviderName(PROVIDER_NAME.into())
275 }
276
277 fn supports_tools(&self) -> bool {
278 false
279 }
280
281 fn telemetry_id(&self) -> String {
282 format!("lmstudio/{}", self.model.id())
283 }
284
285 fn max_token_count(&self) -> usize {
286 self.model.max_token_count()
287 }
288
289 fn count_tokens(
290 &self,
291 request: LanguageModelRequest,
292 _cx: &App,
293 ) -> BoxFuture<'static, Result<usize>> {
294 // Endpoint for this is coming soon. In the meantime, hacky estimation
295 let token_count = request
296 .messages
297 .iter()
298 .map(|msg| msg.string_contents().split_whitespace().count())
299 .sum::<usize>();
300
301 let estimated_tokens = (token_count as f64 * 0.75) as usize;
302 async move { Ok(estimated_tokens) }.boxed()
303 }
304
305 fn stream_completion(
306 &self,
307 request: LanguageModelRequest,
308 cx: &AsyncApp,
309 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
310 let request = self.to_lmstudio_request(request);
311
312 let http_client = self.http_client.clone();
313 let Ok(api_url) = cx.update(|cx| {
314 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
315 settings.api_url.clone()
316 }) else {
317 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
318 };
319
320 let future = self.request_limiter.stream(async move {
321 let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
322 let stream = response
323 .filter_map(|response| async move {
324 match response {
325 Ok(fragment) => {
326 // Skip empty deltas
327 if fragment.choices[0].delta.is_object()
328 && fragment.choices[0].delta.as_object().unwrap().is_empty()
329 {
330 return None;
331 }
332
333 // Try to parse the delta as ChatMessage
334 if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
335 fragment.choices[0].delta.clone(),
336 ) {
337 let content = match chat_message {
338 ChatMessage::User { content } => content,
339 ChatMessage::Assistant { content, .. } => {
340 content.unwrap_or_default()
341 }
342 ChatMessage::System { content } => content,
343 };
344 if !content.is_empty() {
345 Some(Ok(content))
346 } else {
347 None
348 }
349 } else {
350 None
351 }
352 }
353 Err(error) => Some(Err(error)),
354 }
355 })
356 .boxed();
357 Ok(stream)
358 });
359
360 async move {
361 Ok(future
362 .await?
363 .map(|result| result.map(LanguageModelCompletionEvent::Text))
364 .boxed())
365 }
366 .boxed()
367 }
368}
369
/// Configuration UI shown for this provider in the assistant settings panel.
struct ConfigurationView {
    state: gpui::Entity<State>,
    // Some while the initial model fetch is running; cleared on completion so
    // the view can stop showing the loading state.
    loading_models_task: Option<Task<()>>,
}
374
impl ConfigurationView {
    /// Creates the view and immediately starts loading the model list; the
    /// spawned task clears `loading_models_task` when it finishes so the UI
    /// can switch out of the loading state.
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // Failures are logged, not surfaced; the UI falls back to
                    // the "not connected" state.
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    /// Re-attempts to reach the LM Studio server by fetching models again.
    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}
406
impl Render for ConfigurationView {
    /// Shows a loading label while models are being fetched; otherwise renders
    /// setup instructions, site/catalog links, and a connection status control.
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(InstructionListItem::text_only(
                                "To get your first model, try running `lms get qwen2.5-coder-7b`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                // Left side: link to the site when connected,
                                // otherwise a download button.
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        // Right side: "Connected" indicator, or a retry button
                        // that re-fetches models.
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_lmstudio_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _window, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}