1use anyhow::{anyhow, Result};
2use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
4use http_client::HttpClient;
5use language_model::LanguageModelCompletionEvent;
6use language_model::{
7 LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
8 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
9 LanguageModelRequest, RateLimiter, Role,
10};
11use lmstudio::{
12 get_models, preload_model, stream_chat_completion, ChatCompletionRequest, ChatMessage,
13 ModelType,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use settings::{Settings, SettingsStore};
18use std::{collections::BTreeMap, sync::Arc};
19use ui::{prelude::*, ButtonLike, Indicator};
20use util::ResultExt;
21
22use crate::AllLanguageModelSettings;
23
/// Opened by the "Download LM Studio" button in the configuration view.
const LMSTUDIO_DOWNLOAD_URL: &str = "https://lmstudio.ai/download";
/// Opened by the "Model Catalog" button in the configuration view.
const LMSTUDIO_CATALOG_URL: &str = "https://lmstudio.ai/models";
/// Opened by the "LM Studio" button in the configuration view.
const LMSTUDIO_SITE: &str = "https://lmstudio.ai/";

/// Identifier reported for this provider and for every model it exposes.
const PROVIDER_ID: &str = "lmstudio";
/// Human-readable provider name reported alongside `PROVIDER_ID`.
const PROVIDER_NAME: &str = "LM Studio";
30
31#[derive(Default, Debug, Clone, PartialEq)]
32pub struct LmStudioSettings {
33 pub api_url: String,
34 pub available_models: Vec<AvailableModel>,
35}
36
37#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
38pub struct AvailableModel {
39 /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
40 pub name: String,
41 /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
42 pub display_name: Option<String>,
43 /// The model's context window size.
44 pub max_tokens: usize,
45}
46
47pub struct LmStudioLanguageModelProvider {
48 http_client: Arc<dyn HttpClient>,
49 state: gpui::Entity<State>,
50}
51
52pub struct State {
53 http_client: Arc<dyn HttpClient>,
54 available_models: Vec<lmstudio::Model>,
55 fetch_model_task: Option<Task<Result<()>>>,
56 _subscription: Subscription,
57}
58
59impl State {
60 fn is_authenticated(&self) -> bool {
61 !self.available_models.is_empty()
62 }
63
64 fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
65 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
66 let http_client = self.http_client.clone();
67 let api_url = settings.api_url.clone();
68
69 // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
70 cx.spawn(|this, mut cx| async move {
71 let models = get_models(http_client.as_ref(), &api_url, None).await?;
72
73 let mut models: Vec<lmstudio::Model> = models
74 .into_iter()
75 .filter(|model| model.r#type != ModelType::Embeddings)
76 .map(|model| lmstudio::Model::new(&model.id, None, None))
77 .collect();
78
79 models.sort_by(|a, b| a.name.cmp(&b.name));
80
81 this.update(&mut cx, |this, cx| {
82 this.available_models = models;
83 cx.notify();
84 })
85 })
86 }
87
88 fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
89 let task = self.fetch_models(cx);
90 self.fetch_model_task.replace(task);
91 }
92
93 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
94 if self.is_authenticated() {
95 Task::ready(Ok(()))
96 } else {
97 self.fetch_models(cx)
98 }
99 }
100}
101
102impl LmStudioLanguageModelProvider {
103 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
104 let this = Self {
105 http_client: http_client.clone(),
106 state: cx.new(|cx| {
107 let subscription = cx.observe_global::<SettingsStore>({
108 let mut settings = AllLanguageModelSettings::get_global(cx).lmstudio.clone();
109 move |this: &mut State, cx| {
110 let new_settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
111 if &settings != new_settings {
112 settings = new_settings.clone();
113 this.restart_fetch_models_task(cx);
114 cx.notify();
115 }
116 }
117 });
118
119 State {
120 http_client,
121 available_models: Default::default(),
122 fetch_model_task: None,
123 _subscription: subscription,
124 }
125 }),
126 };
127 this.state
128 .update(cx, |state, cx| state.restart_fetch_models_task(cx));
129 this
130 }
131}
132
133impl LanguageModelProviderState for LmStudioLanguageModelProvider {
134 type ObservableEntity = State;
135
136 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
137 Some(self.state.clone())
138 }
139}
140
141impl LanguageModelProvider for LmStudioLanguageModelProvider {
142 fn id(&self) -> LanguageModelProviderId {
143 LanguageModelProviderId(PROVIDER_ID.into())
144 }
145
146 fn name(&self) -> LanguageModelProviderName {
147 LanguageModelProviderName(PROVIDER_NAME.into())
148 }
149
150 fn icon(&self) -> IconName {
151 IconName::AiLmStudio
152 }
153
154 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
155 let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();
156
157 // Add models from the LM Studio API
158 for model in self.state.read(cx).available_models.iter() {
159 models.insert(model.name.clone(), model.clone());
160 }
161
162 // Override with available models from settings
163 for model in AllLanguageModelSettings::get_global(cx)
164 .lmstudio
165 .available_models
166 .iter()
167 {
168 models.insert(
169 model.name.clone(),
170 lmstudio::Model {
171 name: model.name.clone(),
172 display_name: model.display_name.clone(),
173 max_tokens: model.max_tokens,
174 },
175 );
176 }
177
178 models
179 .into_values()
180 .map(|model| {
181 Arc::new(LmStudioLanguageModel {
182 id: LanguageModelId::from(model.name.clone()),
183 model: model.clone(),
184 http_client: self.http_client.clone(),
185 request_limiter: RateLimiter::new(4),
186 }) as Arc<dyn LanguageModel>
187 })
188 .collect()
189 }
190
191 fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
192 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
193 let http_client = self.http_client.clone();
194 let api_url = settings.api_url.clone();
195 let id = model.id().0.to_string();
196 cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
197 .detach_and_log_err(cx);
198 }
199
200 fn is_authenticated(&self, cx: &App) -> bool {
201 self.state.read(cx).is_authenticated()
202 }
203
204 fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
205 self.state.update(cx, |state, cx| state.authenticate(cx))
206 }
207
208 fn configuration_view(&self, _window: &mut Window, cx: &mut App) -> AnyView {
209 let state = self.state.clone();
210 cx.new(|cx| ConfigurationView::new(state, cx)).into()
211 }
212
213 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
214 self.state.update(cx, |state, cx| state.fetch_models(cx))
215 }
216}
217
218pub struct LmStudioLanguageModel {
219 id: LanguageModelId,
220 model: lmstudio::Model,
221 http_client: Arc<dyn HttpClient>,
222 request_limiter: RateLimiter,
223}
224
225impl LmStudioLanguageModel {
226 fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
227 ChatCompletionRequest {
228 model: self.model.name.clone(),
229 messages: request
230 .messages
231 .into_iter()
232 .map(|msg| match msg.role {
233 Role::User => ChatMessage::User {
234 content: msg.string_contents(),
235 },
236 Role::Assistant => ChatMessage::Assistant {
237 content: Some(msg.string_contents()),
238 tool_calls: None,
239 },
240 Role::System => ChatMessage::System {
241 content: msg.string_contents(),
242 },
243 })
244 .collect(),
245 stream: true,
246 max_tokens: Some(-1),
247 stop: Some(request.stop),
248 temperature: request.temperature.or(Some(0.0)),
249 tools: vec![],
250 }
251 }
252}
253
254impl LanguageModel for LmStudioLanguageModel {
255 fn id(&self) -> LanguageModelId {
256 self.id.clone()
257 }
258
259 fn name(&self) -> LanguageModelName {
260 LanguageModelName::from(self.model.display_name().to_string())
261 }
262
263 fn provider_id(&self) -> LanguageModelProviderId {
264 LanguageModelProviderId(PROVIDER_ID.into())
265 }
266
267 fn provider_name(&self) -> LanguageModelProviderName {
268 LanguageModelProviderName(PROVIDER_NAME.into())
269 }
270
271 fn telemetry_id(&self) -> String {
272 format!("lmstudio/{}", self.model.id())
273 }
274
275 fn max_token_count(&self) -> usize {
276 self.model.max_token_count()
277 }
278
279 fn count_tokens(
280 &self,
281 request: LanguageModelRequest,
282 _cx: &App,
283 ) -> BoxFuture<'static, Result<usize>> {
284 // Endpoint for this is coming soon. In the meantime, hacky estimation
285 let token_count = request
286 .messages
287 .iter()
288 .map(|msg| msg.string_contents().split_whitespace().count())
289 .sum::<usize>();
290
291 let estimated_tokens = (token_count as f64 * 0.75) as usize;
292 async move { Ok(estimated_tokens) }.boxed()
293 }
294
295 fn stream_completion(
296 &self,
297 request: LanguageModelRequest,
298 cx: &AsyncApp,
299 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
300 let request = self.to_lmstudio_request(request);
301
302 let http_client = self.http_client.clone();
303 let Ok(api_url) = cx.update(|cx| {
304 let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
305 settings.api_url.clone()
306 }) else {
307 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
308 };
309
310 let future = self.request_limiter.stream(async move {
311 let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
312 let stream = response
313 .filter_map(|response| async move {
314 match response {
315 Ok(fragment) => {
316 // Skip empty deltas
317 if fragment.choices[0].delta.is_object()
318 && fragment.choices[0].delta.as_object().unwrap().is_empty()
319 {
320 return None;
321 }
322
323 // Try to parse the delta as ChatMessage
324 if let Ok(chat_message) = serde_json::from_value::<ChatMessage>(
325 fragment.choices[0].delta.clone(),
326 ) {
327 let content = match chat_message {
328 ChatMessage::User { content } => content,
329 ChatMessage::Assistant { content, .. } => {
330 content.unwrap_or_default()
331 }
332 ChatMessage::System { content } => content,
333 };
334 if !content.is_empty() {
335 Some(Ok(content))
336 } else {
337 None
338 }
339 } else {
340 None
341 }
342 }
343 Err(error) => Some(Err(error)),
344 }
345 })
346 .boxed();
347 Ok(stream)
348 });
349
350 async move {
351 Ok(future
352 .await?
353 .map(|result| result.map(LanguageModelCompletionEvent::Text))
354 .boxed())
355 }
356 .boxed()
357 }
358
359 fn use_any_tool(
360 &self,
361 _request: LanguageModelRequest,
362 _tool_name: String,
363 _tool_description: String,
364 _schema: serde_json::Value,
365 _cx: &AsyncApp,
366 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
367 async move { Ok(futures::stream::empty().boxed()) }.boxed()
368 }
369}
370
371struct ConfigurationView {
372 state: gpui::Entity<State>,
373 loading_models_task: Option<Task<()>>,
374}
375
376impl ConfigurationView {
377 pub fn new(state: gpui::Entity<State>, cx: &mut Context<Self>) -> Self {
378 let loading_models_task = Some(cx.spawn({
379 let state = state.clone();
380 |this, mut cx| async move {
381 if let Some(task) = state
382 .update(&mut cx, |state, cx| state.authenticate(cx))
383 .log_err()
384 {
385 task.await.log_err();
386 }
387 this.update(&mut cx, |this, cx| {
388 this.loading_models_task = None;
389 cx.notify();
390 })
391 .log_err();
392 }
393 }));
394
395 Self {
396 state,
397 loading_models_task,
398 }
399 }
400
401 fn retry_connection(&self, cx: &mut App) {
402 self.state
403 .update(cx, |state, cx| state.fetch_models(cx))
404 .detach_and_log_err(cx);
405 }
406}
407
408impl Render for ConfigurationView {
409 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
410 let is_authenticated = self.state.read(cx).is_authenticated();
411
412 let lmstudio_intro = "Run local LLMs like Llama, Phi, and Qwen.";
413 let lmstudio_reqs =
414 "To use LM Studio as a provider for Zed assistant, it needs to be running with at least one model downloaded.";
415
416 let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);
417
418 if self.loading_models_task.is_some() {
419 div().child(Label::new("Loading models...")).into_any()
420 } else {
421 v_flex()
422 .size_full()
423 .gap_3()
424 .child(
425 v_flex()
426 .size_full()
427 .gap_2()
428 .p_1()
429 .child(Label::new(lmstudio_intro))
430 .child(Label::new(lmstudio_reqs))
431 .child(
432 h_flex()
433 .gap_0p5()
434 .child(Label::new("To get your first model, try running"))
435 .child(
436 div()
437 .bg(inline_code_bg)
438 .px_1p5()
439 .rounded_md()
440 .child(Label::new("lms get qwen2.5-coder-7b")),
441 ),
442 ),
443 )
444 .child(
445 h_flex()
446 .w_full()
447 .pt_2()
448 .justify_between()
449 .gap_2()
450 .child(
451 h_flex()
452 .w_full()
453 .gap_2()
454 .map(|this| {
455 if is_authenticated {
456 this.child(
457 Button::new("lmstudio-site", "LM Studio")
458 .style(ButtonStyle::Subtle)
459 .icon(IconName::ArrowUpRight)
460 .icon_size(IconSize::XSmall)
461 .icon_color(Color::Muted)
462 .on_click(move |_, _window, cx| {
463 cx.open_url(LMSTUDIO_SITE)
464 })
465 .into_any_element(),
466 )
467 } else {
468 this.child(
469 Button::new(
470 "download_lmstudio_button",
471 "Download LM Studio",
472 )
473 .style(ButtonStyle::Subtle)
474 .icon(IconName::ArrowUpRight)
475 .icon_size(IconSize::XSmall)
476 .icon_color(Color::Muted)
477 .on_click(move |_, _window, cx| {
478 cx.open_url(LMSTUDIO_DOWNLOAD_URL)
479 })
480 .into_any_element(),
481 )
482 }
483 })
484 .child(
485 Button::new("view-models", "Model Catalog")
486 .style(ButtonStyle::Subtle)
487 .icon(IconName::ArrowUpRight)
488 .icon_size(IconSize::XSmall)
489 .icon_color(Color::Muted)
490 .on_click(move |_, _window, cx| {
491 cx.open_url(LMSTUDIO_CATALOG_URL)
492 }),
493 ),
494 )
495 .child(if is_authenticated {
496 // This is only a button to ensure the spacing is correct
497 // it should stay disabled
498 ButtonLike::new("connected")
499 .disabled(true)
500 // Since this won't ever be clickable, we can use the arrow cursor
501 .cursor_style(gpui::CursorStyle::Arrow)
502 .child(
503 h_flex()
504 .gap_2()
505 .child(Indicator::dot().color(Color::Success))
506 .child(Label::new("Connected"))
507 .into_any_element(),
508 )
509 .into_any_element()
510 } else {
511 Button::new("retry_lmstudio_models", "Connect")
512 .icon_position(IconPosition::Start)
513 .icon(IconName::ArrowCircle)
514 .on_click(cx.listener(move |this, _, _window, cx| {
515 this.retry_connection(cx)
516 }))
517 .into_any_element()
518 }),
519 )
520 .into_any()
521 }
522 }
523}