use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{AuthenticateError, LanguageModelCompletionEvent};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use lmstudio::{
    ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
    stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

const PROVIDER_ID: &str = "lmstudio";
const PROVIDER_NAME: &str = "LM Studio";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct LmStudioSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

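/// A model entry configured by the user in Zed's settings. As a sketch of the
/// expected shape (assuming these fields are read from the `lmstudio` section
/// of the language-model settings; the exact nesting in `settings.json` may
/// differ), an entry could look like:
///
/// ```json
/// {
///   "name": "qwen2.5-coder-7b",
///   "display_name": "Qwen 2.5 Coder 7B",
///   "max_tokens": 32768
/// }
/// ```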
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model's name in the LM Studio API, e.g. `qwen2.5-coder-7b` or `phi-4`.
    pub name: String,
    /// The model's name as shown in Zed's UI, such as the model selector dropdown in the assistant panel.
    pub display_name: Option<String>,
    /// The size of the model's context window, in tokens.
    pub max_tokens: usize,
}

pub struct LmStudioLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<lmstudio::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
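    // LM Studio's local server has no credentials; "authenticated" is a proxy
    // for "the server is reachable and reports at least one model".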
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we check that it's up
        // by fetching the available models.
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<lmstudio::Model> = models
                .into_iter()
                .filter(|model| model.r#type != ModelType::Embeddings)
                .map(|model| lmstudio::Model::new(&model.id, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl LmStudioLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
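                // Watch the global settings store and refetch the model list
                // whenever the LM Studio settings change.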
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for LmStudioLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for LmStudioLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiLmStudio
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

        // Add models reported by the LM Studio API.
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings.
        for model in AllLanguageModelSettings::get_global(cx)
            .lmstudio
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                lmstudio::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(LmStudioLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

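    // Preloading asks the LM Studio server to load the model into memory ahead
    // of time, so the first completion request doesn't pay the load cost.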
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        // There are no credentials to reset; refetch the model list instead.
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct LmStudioLanguageModel {
    id: LanguageModelId,
    model: lmstudio::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl LmStudioLanguageModel {
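    /// Converts a `LanguageModelRequest` into the OpenAI-compatible payload that
    /// LM Studio's chat completions endpoint expects. As a rough sketch of the
    /// wire format (field names follow the `lmstudio` crate's
    /// `ChatCompletionRequest`; the exact serialization may differ), the body
    /// looks something like:
    ///
    /// ```json
    /// {
    ///   "model": "qwen2.5-coder-7b",
    ///   "messages": [{ "role": "user", "content": "Hello" }],
    ///   "stream": true,
    ///   "max_tokens": -1,
    ///   "stop": [],
    ///   "temperature": 0.0
    /// }
    /// ```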
    fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: Some(msg.string_contents()),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            stream: true,
            // -1 tells LM Studio not to cap the completion length.
            max_tokens: Some(-1),
            stop: Some(request.stop),
            // Default to deterministic sampling when no temperature is given.
            temperature: request.temperature.or(Some(0.0)),
            tools: vec![],
        }
    }
}

impl LanguageModel for LmStudioLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        false
    }

    fn telemetry_id(&self) -> String {
        format!("lmstudio/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // LM Studio doesn't expose a token-counting endpoint yet, so estimate:
        // count whitespace-delimited words and scale by 0.75 as a rough
        // words-to-tokens ratio.
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().split_whitespace().count())
            .sum::<usize>();

        let estimated_tokens = (token_count as f64 * 0.75) as usize;
        async move { Ok(estimated_tokens) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_lmstudio_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(fragment) => {
                            // Guard against fragments with no choices instead of
                            // panicking on an out-of-bounds index.
                            let delta = &fragment.choices.first()?.delta;

                            // Skip empty deltas.
                            if delta.as_object().is_some_and(|delta| delta.is_empty()) {
                                return None;
                            }

                            // Try to parse the delta as a ChatMessage.
                            if let Ok(chat_message) =
                                serde_json::from_value::<ChatMessage>(delta.clone())
                            {
                                let content = match chat_message {
                                    ChatMessage::User { content } => content,
                                    ChatMessage::Assistant { content, .. } => {
                                        content.unwrap_or_default()
                                    }
                                    ChatMessage::System { content } => content,
                                };
                                if !content.is_empty() {
                                    Some(Ok(content))
                                } else {
                                    None
                                }
                            } else {
                                None
                            }
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                // Surface each text chunk as a completion event.
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }
}

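/// Settings UI for the LM Studio provider. It kicks off a model fetch when
/// opened and lets the user retry the connection or follow setup links.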
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(lmstudio_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "LM Studio needs to be running with at least one model downloaded.",
                            ))
                            .child(InstructionListItem::text_only(
                                "To get your first model, try running `lms get qwen2.5-coder-7b`.",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("lmstudio-site", "LM Studio")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _window, cx| {
                                                    cx.open_url(LMSTUDIO_SITE)
                                                })
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_lmstudio_button",
                                                "Download LM Studio",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _window, cx| {
                                                cx.open_url(LMSTUDIO_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "Model Catalog")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _window, cx| {
                                            cx.open_url(LMSTUDIO_CATALOG_URL)
                                        }),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_lmstudio_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _window, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}